// Command loader imports a MaxMind City database (MMDB) into Postgres.
//
// It hashes the MMDB file, optionally skips the load when the stored hash
// matches, bulk-copies every network into a staging table, swaps it with the
// live table inside one transaction, and finally records the new hash.
package main

import (
	"context"
	"crypto/sha256"
	"database/sql"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"log"
	"os"
	"strconv"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/oschwald/maxminddb-golang"
)

const (
	defaultMMDBPath      = "/data/GeoLite2-City.mmdb"
	defaultSchema        = "geoip"
	defaultLoaderTimeout = 30 * time.Minute
)

// subdivision is one entry of the "subdivisions" array of an MMDB record.
//
// It is declared as an alias (not a defined type) so that it stays identical
// to the anonymous struct type used in the signatures of firstName/firstISO.
type subdivision = struct {
	IsoCode string            `maxminddb:"iso_code"`
	Names   map[string]string `maxminddb:"names"`
}

// cityRecord is the subset of an MMDB record that the loader decodes.
type cityRecord struct {
	City struct {
		GeoNameID uint              `maxminddb:"geoname_id"`
		Names     map[string]string `maxminddb:"names"`
	} `maxminddb:"city"`
	Country struct {
		IsoCode string            `maxminddb:"iso_code"`
		Names   map[string]string `maxminddb:"names"`
	} `maxminddb:"country"`
	Subdivisions []subdivision `maxminddb:"subdivisions"`
	Location     struct {
		Latitude  float64 `maxminddb:"latitude"`
		Longitude float64 `maxminddb:"longitude"`
		TimeZone  string  `maxminddb:"time_zone"`
	} `maxminddb:"location"`
}

// cityRow is one flattened row destined for the city_lookup table.
type cityRow struct {
	network    string
	geonameID  int
	country    string
	countryISO string
	region     string
	regionISO  string
	city       string
	latitude   float64
	longitude  float64
	timeZone   string
}

// main runs the loader and exits with status 1 on failure.
func main() {
	// All work happens in run so that its deferred cleanups execute before
	// log.Fatal calls os.Exit (which skips deferred functions).
	if err := run(); err != nil {
		log.Fatal(err)
	}
}

// run performs a complete load cycle driven by environment variables:
//
//   - DATABASE_URL (required): Postgres connection string.
//   - GEOIP_DB_PATH: path to the MMDB file (default defaultMMDBPath).
//   - GEOIP_LOADER_TIMEOUT: overall deadline (default defaultLoaderTimeout).
//   - GEOIP_LOADER_SKIP_IF_SAME_HASH: skip when the stored hash matches (default true).
//   - GEOIP_LOADER_FORCE: load even when the hash matches (default false).
func run() error {
	dbURL := os.Getenv("DATABASE_URL")
	if dbURL == "" {
		return errors.New("DATABASE_URL is required")
	}

	mmdbPath := env("GEOIP_DB_PATH", defaultMMDBPath)
	timeout := envDuration("GEOIP_LOADER_TIMEOUT", defaultLoaderTimeout)
	skipIfSame := envBool("GEOIP_LOADER_SKIP_IF_SAME_HASH", true)
	force := envBool("GEOIP_LOADER_FORCE", false)

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	log.Printf("starting mmdb load from %s", mmdbPath)

	hash, err := fileSHA256(mmdbPath)
	if err != nil {
		return fmt.Errorf("failed to hash mmdb: %w", err)
	}

	conn, err := pgx.Connect(ctx, dbURL)
	if err != nil {
		return fmt.Errorf("failed to connect database: %w", err)
	}
	// Close with a fresh context so the connection is shut down cleanly even
	// when ctx has already expired; a close error is not actionable here.
	defer func() { _ = conn.Close(context.Background()) }()

	if err := ensureSchema(ctx, conn); err != nil {
		return fmt.Errorf("failed to ensure schema: %w", err)
	}

	existingHash, err := currentHash(ctx, conn)
	if err != nil {
		return fmt.Errorf("failed to read metadata: %w", err)
	}
	if skipIfSame && !force && existingHash == hash {
		log.Printf("mmdb hash unchanged (%s), skipping load", hash)
		return nil
	}

	rowSource, err := newNetworkSource(mmdbPath)
	if err != nil {
		return fmt.Errorf("failed to open mmdb: %w", err)
	}
	defer rowSource.Close()

	if err := loadNetworks(ctx, conn, rowSource); err != nil {
		return fmt.Errorf("failed to load networks: %w", err)
	}
	// The hash is written only after the data transaction committed, so a
	// failure here merely causes a redundant reload on the next run.
	if err := upsertHash(ctx, conn, hash); err != nil {
		return fmt.Errorf("failed to update metadata: %w", err)
	}

	log.Printf("loaded mmdb into Postgres (%d rows), hash=%s", rowSource.Rows(), hash)
	return nil
}

// env returns the value of the environment variable key, or fallback when the
// variable is unset or empty.
func env(key, fallback string) string {
	if val := os.Getenv(key); val != "" {
		return val
	}
	return fallback
}

// envBool parses the environment variable key with strconv.ParseBool.
// It returns fallback when the variable is unset, empty, or not a valid
// boolean; in the invalid case a warning is logged so that typos are visible.
func envBool(key string, fallback bool) bool {
	val := os.Getenv(key)
	if val == "" {
		return fallback
	}
	parsed, err := strconv.ParseBool(val)
	if err != nil {
		log.Printf("ignoring invalid boolean %s=%q, using default %t", key, val, fallback)
		return fallback
	}
	return parsed
}

// envDuration parses the environment variable key with time.ParseDuration.
// It returns fallback when the variable is unset, empty, malformed, or not
// strictly positive. A zero or negative value is rejected because the result
// is used as a timeout, where it would yield an already-expired context.
func envDuration(key string, fallback time.Duration) time.Duration {
	val := os.Getenv(key)
	if val == "" {
		return fallback
	}
	d, err := time.ParseDuration(val)
	if err != nil {
		log.Printf("ignoring invalid duration %s=%q, using default %s", key, val, fallback)
		return fallback
	}
	if d <= 0 {
		log.Printf("ignoring non-positive duration %s=%q, using default %s", key, val, fallback)
		return fallback
	}
	return d
}

// fileSHA256 streams the file at path through SHA-256 and returns the digest
// as a lowercase hex string.
func fileSHA256(path string) (string, error) {
	f, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer f.Close() // read-only handle; a close error carries no information

	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return "", err
	}
	return hex.EncodeToString(h.Sum(nil)), nil
}

// ensureSchema creates the schema, the geoip_metadata table and the
// city_lookup table when they do not exist yet. It is safe to call on every
// run because every statement uses IF NOT EXISTS.
func ensureSchema(ctx context.Context, conn *pgx.Conn) error {
	ddl := fmt.Sprintf(`
CREATE SCHEMA IF NOT EXISTS %s;
CREATE TABLE IF NOT EXISTS %s.geoip_metadata (
	key text PRIMARY KEY,
	value text NOT NULL,
	updated_at timestamptz NOT NULL DEFAULT now()
);
CREATE TABLE IF NOT EXISTS %s.city_lookup (
	network cidr PRIMARY KEY,
	geoname_id integer,
	country text,
	country_iso_code text,
	region text,
	region_iso_code text,
	city text,
	latitude double precision,
	longitude double precision,
	time_zone text
);
`, defaultSchema, defaultSchema, defaultSchema)
	_, err := conn.Exec(ctx, ddl)
	return err
}

// currentHash returns the MMDB hash stored in geoip_metadata, or "" when no
// hash has been recorded yet.
func currentHash(ctx context.Context, conn *pgx.Conn) (string, error) {
	var hash sql.NullString
	err := conn.QueryRow(ctx, `SELECT value FROM geoip.geoip_metadata WHERE key = 'mmdb_sha256'`).Scan(&hash)
	if errors.Is(err, pgx.ErrNoRows) {
		return "", nil
	}
	if err != nil {
		return "", err
	}
	return hash.String, nil
}

// upsertHash stores hash under the 'mmdb_sha256' key in geoip_metadata,
// inserting the row or updating the existing one.
func upsertHash(ctx context.Context, conn *pgx.Conn, hash string) error {
	_, err := conn.Exec(ctx, `
INSERT INTO geoip.geoip_metadata(key, value, updated_at)
VALUES ('mmdb_sha256', $1, now())
ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value, updated_at = EXCLUDED.updated_at;
`, hash)
	return err
}

// networkSource adapts an MMDB network iterator to pgx.CopyFromSource so that
// rows can be streamed into COPY without being buffered in memory.
type networkSource struct {
	reader *maxminddb.Reader
	iter   *maxminddb.Networks
	err    error   // first error seen by Next
	row    cityRow // row produced by the most recent successful Next
	count  int     // number of rows produced so far
}

// Compile-time check that networkSource can be passed to CopyFrom.
var _ pgx.CopyFromSource = (*networkSource)(nil)

// newNetworkSource opens the MMDB file at path and positions an iterator
// before its first network. The caller must call Close when done.
func newNetworkSource(path string) (*networkSource, error) {
	reader, err := maxminddb.Open(path)
	if err != nil {
		return nil, err
	}
	// NOTE(review): Networks is called without options. The maxminddb-golang
	// docs suggest SkipAliasedNetworks for IPv6 databases to avoid iterating
	// aliases of the IPv4 subtree; enabling it would change the loaded row
	// set, so confirm with consumers of city_lookup before switching.
	return &networkSource{
		reader: reader,
		iter:   reader.Networks(),
	}, nil
}

// Close releases the underlying MMDB reader. It is a no-op when the reader
// is nil.
func (s *networkSource) Close() {
	if s.reader != nil {
		_ = s.reader.Close() // nothing useful to do with a close error
	}
}

// Rows reports how many rows Next has produced so far.
func (s *networkSource) Rows() int {
	return s.count
}

// Next advances to the next network and decodes it into the current row.
// It returns false at the end of the database or on the first error; the
// error, if any, is then available from Err.
func (s *networkSource) Next() bool {
	if !s.iter.Next() {
		s.err = s.iter.Err()
		return false
	}

	var rec cityRecord
	network, err := s.iter.Network(&rec)
	if err != nil {
		s.err = err
		return false
	}

	s.row = cityRow{
		network:    network.String(),
		geonameID:  int(rec.City.GeoNameID),
		country:    rec.Country.Names["en"],
		countryISO: rec.Country.IsoCode,
		region:     firstName(rec.Subdivisions),
		regionISO:  firstISO(rec.Subdivisions),
		city:       rec.City.Names["en"],
		latitude:   rec.Location.Latitude,
		longitude:  rec.Location.Longitude,
		timeZone:   rec.Location.TimeZone,
	}

	s.count++
	if s.count%500000 == 0 {
		log.Printf("loader progress: %d rows processed", s.count)
	}
	return true
}

// Values returns the current row in the column order used by loadNetworks.
func (s *networkSource) Values() ([]any, error) {
	return []any{
		s.row.network,
		s.row.geonameID,
		s.row.country,
		s.row.countryISO,
		s.row.region,
		s.row.regionISO,
		s.row.city,
		s.row.latitude,
		s.row.longitude,
		s.row.timeZone,
	}, nil
}

// Err returns the first error recorded by Next, falling back to the
// iterator's own error.
func (s *networkSource) Err() error {
	if s.err != nil {
		return s.err
	}
	return s.iter.Err()
}

// firstName returns the English name of the first subdivision, or "" when
// there are no subdivisions (or the first one has no "en" name).
func firstName(subdivisions []subdivision) string {
	if len(subdivisions) == 0 {
		return ""
	}
	return subdivisions[0].Names["en"]
}

// firstISO returns the ISO code of the first subdivision, or "" when there
// are no subdivisions.
func firstISO(subdivisions []subdivision) string {
	if len(subdivisions) == 0 {
		return ""
	}
	return subdivisions[0].IsoCode
}

// loadNetworks replaces the contents of geoip.city_lookup with the rows from
// src inside a single transaction: it fills a staging table via COPY, swaps
// it with the live table, builds the primary key and indexes, and
// (re)creates the geoip.lookup_city function. Any failure rolls everything
// back and leaves the previous table untouched.
func loadNetworks(ctx context.Context, conn *pgx.Conn, src *networkSource) error {
	tx, err := conn.Begin(ctx)
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	// Rollback after a successful Commit only reports that the transaction
	// is already closed, so its error is deliberately ignored.
	defer func() { _ = tx.Rollback(ctx) }()

	// Indexes are excluded from the LIKE copy on purpose. Copying them gives
	// the staging table auto-named clones of every index on the live table,
	// and the name-based CREATE INDEX IF NOT EXISTS statements below would
	// then add two more on every reload. Building indexes after the bulk
	// load is also cheaper than maintaining them row by row during COPY.
	_, err = tx.Exec(ctx, `
DROP TABLE IF EXISTS geoip.city_lookup_new;
CREATE TABLE geoip.city_lookup_new (LIKE geoip.city_lookup INCLUDING ALL EXCLUDING INDEXES);
`)
	if err != nil {
		return fmt.Errorf("create staging table: %w", err)
	}

	columns := []string{
		"network",
		"geoname_id",
		"country",
		"country_iso_code",
		"region",
		"region_iso_code",
		"city",
		"latitude",
		"longitude",
		"time_zone",
	}

	log.Printf("loader copy: starting bulk copy")
	copied, err := tx.CopyFrom(ctx, pgx.Identifier{defaultSchema, "city_lookup_new"}, columns, src)
	if err != nil {
		return fmt.Errorf("copy rows: %w", err)
	}
	log.Printf("loader copy: finished bulk copy (rows=%d)", copied)

	if _, err := tx.Exec(ctx, `
ALTER TABLE IF EXISTS geoip.city_lookup RENAME TO city_lookup_old;
ALTER TABLE geoip.city_lookup_new RENAME TO city_lookup;
DROP TABLE IF EXISTS geoip.city_lookup_old;
`); err != nil {
		return fmt.Errorf("swap tables: %w", err)
	}

	// The old table (and every index it owned) is gone at this point, so the
	// index names below are free again. The primary key is re-established
	// here because it was not copied to the staging table; a duplicate
	// network makes this statement fail and rolls the whole load back.
	if _, err := tx.Exec(ctx, `
ALTER TABLE geoip.city_lookup ADD PRIMARY KEY (network);
CREATE INDEX IF NOT EXISTS city_lookup_network_gist ON geoip.city_lookup USING gist (network inet_ops);
CREATE INDEX IF NOT EXISTS city_lookup_geoname_id_idx ON geoip.city_lookup (geoname_id);
`); err != nil {
		return fmt.Errorf("create indexes: %w", err)
	}

	if _, err := tx.Exec(ctx, `
CREATE OR REPLACE FUNCTION geoip.lookup_city(ip inet)
RETURNS TABLE (
	ip inet,
	country text,
	region text,
	city text,
	latitude double precision,
	longitude double precision
) LANGUAGE sql STABLE AS $$
	SELECT $1::inet AS ip, c.country, c.region, c.city, c.latitude, c.longitude
	FROM geoip.city_lookup c
	WHERE c.network >>= $1
	ORDER BY masklen(c.network) DESC
	LIMIT 1;
$$;
`); err != nil {
		return fmt.Errorf("create lookup function: %w", err)
	}

	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("commit: %w", err)
	}
	return nil
}