Files
geoip-rest/cmd/loader/main.go
2025-12-09 19:29:34 +09:00

392 lines
8.5 KiB
Go

package main
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"io"
"log"
"os"
"strconv"
"time"
"github.com/jackc/pgx/v5"
"github.com/oschwald/maxminddb-golang"
)
// Defaults used by main when the matching environment variable is not set.
const (
	// defaultSchema is the Postgres schema name used by ensureSchema and as
	// the schema part of the bulk-copy target in loadNetworks.
	defaultSchema = "geoip"

	// defaultMMDBPath is the MMDB file read when GEOIP_DB_PATH is unset.
	defaultMMDBPath = "/initial_data/GeoLite2-City.mmdb"

	// defaultLoaderTimeout bounds the whole run when GEOIP_LOADER_TIMEOUT is
	// unset.
	defaultLoaderTimeout = 30 * time.Minute
)
// cityRecord is the portion of an MMDB record that the loader decodes. The
// maxminddb struct tags name the record keys each field is filled from.
type cityRecord struct {
	// Country holds the ISO code and the names map; the loader reads the
	// "en" entry of Names.
	Country struct {
		IsoCode string            `maxminddb:"iso_code"`
		Names   map[string]string `maxminddb:"names"`
	} `maxminddb:"country"`

	// Subdivisions keeps its anonymous element type on purpose: firstName
	// and firstISO declare the same literal type for their parameter.
	Subdivisions []struct {
		IsoCode string            `maxminddb:"iso_code"`
		Names   map[string]string `maxminddb:"names"`
	} `maxminddb:"subdivisions"`

	// City holds the GeoNames identifier and the names map; the loader reads
	// the "en" entry of Names.
	City struct {
		GeoNameID uint              `maxminddb:"geoname_id"`
		Names     map[string]string `maxminddb:"names"`
	} `maxminddb:"city"`

	// Location holds the coordinates and time zone copied verbatim into the
	// output row.
	Location struct {
		Latitude  float64 `maxminddb:"latitude"`
		Longitude float64 `maxminddb:"longitude"`
		TimeZone  string  `maxminddb:"time_zone"`
	} `maxminddb:"location"`
}
// cityRow is one flattened row destined for the city_lookup table. Its
// fields are emitted by networkSource.Values in declaration order.
type cityRow struct {
	// network is the textual form of the network (network.String() in Next).
	network string
	// geonameID is the city's GeoNames identifier converted to int.
	geonameID int

	// Display names are taken from the "en" entry of the record's names
	// maps; they are empty when the record has no such entry.
	country, countryISO string
	region, regionISO   string
	city                string

	latitude, longitude float64
	timeZone            string
}
// main runs the loader once and exits with a non-zero status if it fails.
//
// All work lives in run so that main is the only place that calls log.Fatal:
// log.Fatal exits the process immediately, which would skip the deferred
// cleanup (context cancel, connection close, mmdb reader close) if it were
// called from the function that registered those defers.
func main() {
	if err := run(); err != nil {
		log.Fatal(err)
	}
}

// run loads the MMDB file into Postgres unless its SHA-256 matches the hash
// recorded by the previous load.
//
// Environment:
//   - DATABASE_URL (required): connection string passed to pgx.Connect.
//   - GEOIP_DB_PATH: MMDB path, default defaultMMDBPath.
//   - GEOIP_LOADER_TIMEOUT: overall deadline, default defaultLoaderTimeout.
//   - GEOIP_LOADER_SKIP_IF_SAME_HASH: skip when the hash is unchanged,
//     default true.
//   - GEOIP_LOADER_FORCE: load even when the hash is unchanged, default false.
//
// The returned error already carries the step that failed.
func run() error {
	dbURL := os.Getenv("DATABASE_URL")
	if dbURL == "" {
		return errors.New("DATABASE_URL is required")
	}

	mmdbPath := env("GEOIP_DB_PATH", defaultMMDBPath)
	timeout := envDuration("GEOIP_LOADER_TIMEOUT", defaultLoaderTimeout)
	skipIfSame := envBool("GEOIP_LOADER_SKIP_IF_SAME_HASH", true)
	force := envBool("GEOIP_LOADER_FORCE", false)

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	log.Printf("starting mmdb load from %s", mmdbPath)

	hash, err := fileSHA256(mmdbPath)
	if err != nil {
		return fmt.Errorf("failed to hash mmdb: %w", err)
	}

	conn, err := pgx.Connect(ctx, dbURL)
	if err != nil {
		return fmt.Errorf("failed to connect database: %w", err)
	}
	// Close with a fresh context so the connection is still shut down
	// cleanly when ctx has already expired.
	defer conn.Close(context.Background())

	if err := ensureSchema(ctx, conn); err != nil {
		return fmt.Errorf("failed to ensure schema: %w", err)
	}

	existingHash, err := currentHash(ctx, conn)
	if err != nil {
		return fmt.Errorf("failed to read metadata: %w", err)
	}
	if skipIfSame && !force && existingHash == hash {
		log.Printf("mmdb hash unchanged (%s), skipping load", hash)
		return nil
	}

	rowSource, err := newNetworkSource(mmdbPath)
	if err != nil {
		return fmt.Errorf("failed to open mmdb: %w", err)
	}
	defer rowSource.Close()

	if err := loadNetworks(ctx, conn, rowSource); err != nil {
		return fmt.Errorf("failed to load networks: %w", err)
	}

	// The hash is recorded only after the table swap has committed, so a
	// failure here just causes the next run to reload the same file.
	if err := upsertHash(ctx, conn, hash); err != nil {
		return fmt.Errorf("failed to update metadata: %w", err)
	}

	log.Printf("loaded mmdb into Postgres (%d rows), hash=%s", rowSource.Rows(), hash)
	return nil
}
// env returns the value of the environment variable key, or fallback when
// the variable is unset or set to the empty string.
func env(key, fallback string) string {
	val := os.Getenv(key)
	if val == "" {
		return fallback
	}
	return val
}
// envBool reads the environment variable key as a boolean.
//
// It returns fallback when the variable is unset or empty. Values are parsed
// with strconv.ParseBool; a value it rejects (for example "yes") also yields
// fallback, and a warning is logged so that a mistyped flag is not silently
// ignored.
func envBool(key string, fallback bool) bool {
	val := os.Getenv(key)
	if val == "" {
		return fallback
	}
	parsed, err := strconv.ParseBool(val)
	if err != nil {
		log.Printf("ignoring invalid boolean %q for %s, using default %t", val, key, fallback)
		return fallback
	}
	return parsed
}
// envDuration reads the environment variable key as a time.Duration in the
// syntax accepted by time.ParseDuration (for example "90s" or "45m").
//
// It returns fallback when the variable is unset or empty. A value that does
// not parse, or that is zero or negative, also yields fallback and logs a
// warning: the only caller uses the result as a context timeout, and a
// non-positive timeout produces a context that is already expired.
func envDuration(key string, fallback time.Duration) time.Duration {
	val := os.Getenv(key)
	if val == "" {
		return fallback
	}
	d, err := time.ParseDuration(val)
	if err != nil {
		log.Printf("ignoring invalid duration %q for %s, using default %s", val, key, fallback)
		return fallback
	}
	if d <= 0 {
		log.Printf("ignoring non-positive duration %q for %s, using default %s", val, key, fallback)
		return fallback
	}
	return d
}
// fileSHA256 streams the file at path through SHA-256 and returns the digest
// as a lowercase hex string. On any open or read error it returns the empty
// string together with that error.
func fileSHA256(path string) (string, error) {
	file, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer file.Close()

	digest := sha256.New()
	_, err = io.Copy(digest, file)
	if err != nil {
		return "", err
	}

	sum := digest.Sum(nil)
	return hex.EncodeToString(sum), nil
}
// ensureSchema creates the schema named by defaultSchema together with the
// geoip_metadata and city_lookup tables if they do not exist yet. Every
// statement uses IF NOT EXISTS, so running it against an already prepared
// database changes nothing.
func ensureSchema(ctx context.Context, conn *pgx.Conn) error {
	// Each %s is the schema name; the statements are sent as one batch.
	const ddlTemplate = `
CREATE SCHEMA IF NOT EXISTS %s;
CREATE TABLE IF NOT EXISTS %s.geoip_metadata (
key text PRIMARY KEY,
value text NOT NULL,
updated_at timestamptz NOT NULL DEFAULT now()
);
CREATE TABLE IF NOT EXISTS %s.city_lookup (
network cidr PRIMARY KEY,
geoname_id integer,
country text,
country_iso_code text,
region text,
region_iso_code text,
city text,
latitude double precision,
longitude double precision,
time_zone text
);
`
	ddl := fmt.Sprintf(ddlTemplate, defaultSchema, defaultSchema, defaultSchema)
	if _, err := conn.Exec(ctx, ddl); err != nil {
		return err
	}
	return nil
}
// currentHash returns the value stored under the 'mmdb_sha256' key in
// geoip.geoip_metadata. When no such row exists it returns the empty string
// and a nil error; any other query or scan failure is returned as is.
func currentHash(ctx context.Context, conn *pgx.Conn) (string, error) {
	const query = `SELECT value FROM geoip.geoip_metadata WHERE key = 'mmdb_sha256'`

	var stored sql.NullString
	err := conn.QueryRow(ctx, query).Scan(&stored)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		// No hash recorded yet.
		return "", nil
	case err != nil:
		return "", err
	}
	return stored.String, nil
}
// upsertHash stores hash under the 'mmdb_sha256' key in geoip.geoip_metadata,
// inserting the row or overwriting the existing value and updated_at.
func upsertHash(ctx context.Context, conn *pgx.Conn, hash string) error {
	const stmt = `
INSERT INTO geoip.geoip_metadata(key, value, updated_at)
VALUES ('mmdb_sha256', $1, now())
ON CONFLICT (key) DO UPDATE
SET value = EXCLUDED.value,
updated_at = EXCLUDED.updated_at;
`
	if _, err := conn.Exec(ctx, stmt, hash); err != nil {
		return err
	}
	return nil
}
// networkSource walks the networks of an MMDB file and exposes them through
// Next/Values/Err, which is how loadNetworks feeds them to the bulk copy.
type networkSource struct {
	// reader owns the open MMDB file; Close releases it.
	reader *maxminddb.Reader
	// iter is the network iterator obtained from reader.
	iter *maxminddb.Networks

	// row is the most recent row produced by Next and returned by Values.
	row cityRow
	// count is the number of rows Next has produced so far.
	count int
	// err is the first iteration or decode error seen by Next.
	err error
}
// newNetworkSource opens the MMDB file at path and returns a source
// positioned before its first network. The caller must call Close.
func newNetworkSource(path string) (*networkSource, error) {
	db, err := maxminddb.Open(path)
	if err != nil {
		return nil, err
	}

	src := &networkSource{reader: db}
	// NOTE(review): Networks is called without options. The maxminddb docs
	// describe a SkipAliasedNetworks option for databases that map IPv4
	// networks into several IPv6 locations; confirm whether the extra rows
	// are wanted here.
	src.iter = db.Networks()
	return src, nil
}
// Close releases the underlying MMDB reader, if any. It is best-effort: the
// error returned by the reader's Close is discarded.
func (s *networkSource) Close() {
	if s.reader == nil {
		return
	}
	_ = s.reader.Close()
}
// Rows reports how many rows Next has produced so far.
func (s *networkSource) Rows() int {
	return s.count
}
// Next advances to the next network and decodes it into the current row.
// It returns false when the iterator is exhausted or when iterating or
// decoding fails; in the failure case the error is kept for Err.
func (s *networkSource) Next() bool {
	// Log a progress line every this many rows.
	const progressEvery = 500000

	if advanced := s.iter.Next(); !advanced {
		s.err = s.iter.Err()
		return false
	}

	var rec cityRecord
	network, err := s.iter.Network(&rec)
	if err != nil {
		s.err = err
		return false
	}

	var row cityRow
	row.network = network.String()
	row.geonameID = int(rec.City.GeoNameID)
	// A missing "en" entry yields the empty string.
	row.country = rec.Country.Names["en"]
	row.countryISO = rec.Country.IsoCode
	row.region = firstName(rec.Subdivisions)
	row.regionISO = firstISO(rec.Subdivisions)
	row.city = rec.City.Names["en"]
	row.latitude = rec.Location.Latitude
	row.longitude = rec.Location.Longitude
	row.timeZone = rec.Location.TimeZone
	s.row = row

	s.count++
	if s.count%progressEvery == 0 {
		log.Printf("loader progress: %d rows processed", s.count)
	}
	return true
}
// Values returns the current row as a slice of column values. The order
// matches the column list that loadNetworks passes to CopyFrom. The error is
// always nil.
func (s *networkSource) Values() ([]any, error) {
	r := &s.row
	values := []any{
		r.network,
		r.geonameID,
		r.country,
		r.countryISO,
		r.region,
		r.regionISO,
		r.city,
		r.latitude,
		r.longitude,
		r.timeZone,
	}
	return values, nil
}
// Err returns the error recorded by Next, or otherwise whatever the
// underlying iterator currently reports.
func (s *networkSource) Err() error {
	if s.err == nil {
		return s.iter.Err()
	}
	return s.err
}
// firstName returns the "en" name of the first subdivision, or the empty
// string when there are no subdivisions or the first one has no "en" name.
func firstName(subdivisions []struct {
	IsoCode string            `maxminddb:"iso_code"`
	Names   map[string]string `maxminddb:"names"`
}) string {
	if len(subdivisions) > 0 {
		return subdivisions[0].Names["en"]
	}
	return ""
}
// firstISO returns the ISO code of the first subdivision, or the empty
// string when there are no subdivisions.
func firstISO(subdivisions []struct {
	IsoCode string            `maxminddb:"iso_code"`
	Names   map[string]string `maxminddb:"names"`
}) string {
	if len(subdivisions) > 0 {
		return subdivisions[0].IsoCode
	}
	return ""
}
// loadNetworks replaces the contents of geoip.city_lookup with the rows
// produced by src, inside a single transaction:
//
//  1. create an index-free staging table geoip.city_lookup_new,
//  2. bulk-copy every row from src into it,
//  3. swap it in place of geoip.city_lookup and drop the previous table,
//  4. create the primary key and the lookup indexes on the new table,
//  5. (re)define the geoip.lookup_city(inet) helper function.
//
// If any step fails the transaction is rolled back and the existing table is
// left as it was. Returned errors name the step that failed.
//
// The staging table is created with EXCLUDING INDEXES on purpose. Copying the
// live table's indexes and then creating the named indexes again after the
// swap would leave one more redundant set of indexes behind on every reload,
// and the bulk copy would have to maintain all of them.
func loadNetworks(ctx context.Context, conn *pgx.Conn, src *networkSource) error {
	tx, err := conn.Begin(ctx)
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	// Rolling back after a successful Commit only reports that the
	// transaction is already closed, so the result is deliberately ignored.
	defer func() {
		_ = tx.Rollback(ctx)
	}()

	if _, err := tx.Exec(ctx, `DROP TABLE IF EXISTS geoip.city_lookup_new; CREATE TABLE geoip.city_lookup_new (LIKE geoip.city_lookup INCLUDING ALL EXCLUDING INDEXES);`); err != nil {
		return fmt.Errorf("create staging table: %w", err)
	}

	// Must stay in the same order as networkSource.Values.
	columns := []string{
		"network",
		"geoname_id",
		"country",
		"country_iso_code",
		"region",
		"region_iso_code",
		"city",
		"latitude",
		"longitude",
		"time_zone",
	}

	log.Printf("loader copy: starting bulk copy")
	copied, err := tx.CopyFrom(ctx, pgx.Identifier{defaultSchema, "city_lookup_new"}, columns, src)
	if err != nil {
		return fmt.Errorf("bulk copy into staging table: %w", err)
	}
	log.Printf("loader copy: finished bulk copy (rows=%d)", copied)

	if _, err := tx.Exec(ctx, `
ALTER TABLE IF EXISTS geoip.city_lookup RENAME TO city_lookup_old;
ALTER TABLE geoip.city_lookup_new RENAME TO city_lookup;
DROP TABLE IF EXISTS geoip.city_lookup_old;
`); err != nil {
		return fmt.Errorf("swap staging table into place: %w", err)
	}

	// The previous table and all of its indexes are gone at this point, so
	// the fixed names below are free. The primary key is built here, after
	// the data is in place, and it also rejects duplicate networks.
	if _, err := tx.Exec(ctx, `
ALTER TABLE geoip.city_lookup ADD CONSTRAINT city_lookup_pkey PRIMARY KEY (network);
CREATE INDEX IF NOT EXISTS city_lookup_network_gist ON geoip.city_lookup USING gist (network inet_ops);
CREATE INDEX IF NOT EXISTS city_lookup_geoname_id_idx ON geoip.city_lookup (geoname_id);
`); err != nil {
		return fmt.Errorf("create indexes: %w", err)
	}

	// The body refers to its argument as $1 rather than by name because the
	// output column list also contains a column called ip.
	if _, err := tx.Exec(ctx, `
CREATE OR REPLACE FUNCTION geoip.lookup_city(ip inet)
RETURNS TABLE (
ip inet,
country text,
region text,
city text,
latitude double precision,
longitude double precision
) LANGUAGE sql STABLE AS $$
SELECT
$1::inet AS ip,
c.country,
c.region,
c.city,
c.latitude,
c.longitude
FROM geoip.city_lookup c
WHERE c.network >>= $1
ORDER BY masklen(c.network) DESC
LIMIT 1;
$$;
`); err != nil {
		return fmt.Errorf("define lookup_city function: %w", err)
	}

	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("commit: %w", err)
	}
	return nil
}