392 lines
8.5 KiB
Go
392 lines
8.5 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"os"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/oschwald/maxminddb-golang"
|
|
)
|
|
|
|
// Loader defaults.
//
// NOTE: some queries in this file spell the schema "geoip" literally; keep
// them in sync with defaultSchema if it ever changes.
const (
	// defaultSchema is the Postgres schema that holds the loader's tables.
	defaultSchema = "geoip"

	// defaultMMDBPath is the database file read when GEOIP_DB_PATH is unset.
	defaultMMDBPath = "/initial_data/GeoLite2-City.mmdb"

	// defaultLoaderTimeout bounds the whole run when GEOIP_LOADER_TIMEOUT is unset.
	defaultLoaderTimeout = 30 * time.Minute
)
|
|
|
|
// cityRecord is the subset of an mmdb City record that the loader decodes.
// The maxminddb struct tags name the keys read from each record; keys that
// are absent leave the Go zero value in place.
type cityRecord struct {
	// Location carries the coordinates and time zone string.
	Location struct {
		TimeZone  string  `maxminddb:"time_zone"`
		Latitude  float64 `maxminddb:"latitude"`
		Longitude float64 `maxminddb:"longitude"`
	} `maxminddb:"location"`

	// Subdivisions lists the record's subdivisions; the loader only uses the
	// first one. The element type must stay identical (field order and tags
	// included) to the parameter type of firstName and firstISO.
	Subdivisions []struct {
		IsoCode string            `maxminddb:"iso_code"`
		Names   map[string]string `maxminddb:"names"`
	} `maxminddb:"subdivisions"`

	// Country carries the ISO code and localized names, keyed by language.
	Country struct {
		Names   map[string]string `maxminddb:"names"`
		IsoCode string            `maxminddb:"iso_code"`
	} `maxminddb:"country"`

	// City carries the GeoNames ID and localized names, keyed by language.
	City struct {
		Names     map[string]string `maxminddb:"names"`
		GeoNameID uint              `maxminddb:"geoname_id"`
	} `maxminddb:"city"`
}
|
|
|
|
// cityRow is one flattened row for the city_lookup table, filled in by
// networkSource.Next and emitted by networkSource.Values.
type cityRow struct {
	network   string // CIDR text of the network
	geonameID int

	country, countryISO string
	region, regionISO   string
	city                string

	latitude, longitude float64
	timeZone            string
}
|
|
|
|
func main() {
|
|
dbURL := os.Getenv("DATABASE_URL")
|
|
if dbURL == "" {
|
|
log.Fatal("DATABASE_URL is required")
|
|
}
|
|
|
|
mmdbPath := env("GEOIP_DB_PATH", defaultMMDBPath)
|
|
timeout := envDuration("GEOIP_LOADER_TIMEOUT", defaultLoaderTimeout)
|
|
skipIfSame := envBool("GEOIP_LOADER_SKIP_IF_SAME_HASH", true)
|
|
force := envBool("GEOIP_LOADER_FORCE", false)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
defer cancel()
|
|
|
|
log.Printf("starting mmdb load from %s", mmdbPath)
|
|
|
|
hash, err := fileSHA256(mmdbPath)
|
|
if err != nil {
|
|
log.Fatalf("failed to hash mmdb: %v", err)
|
|
}
|
|
|
|
conn, err := pgx.Connect(ctx, dbURL)
|
|
if err != nil {
|
|
log.Fatalf("failed to connect database: %v", err)
|
|
}
|
|
defer conn.Close(context.Background())
|
|
|
|
if err := ensureSchema(ctx, conn); err != nil {
|
|
log.Fatalf("failed to ensure schema: %v", err)
|
|
}
|
|
|
|
existingHash, err := currentHash(ctx, conn)
|
|
if err != nil {
|
|
log.Fatalf("failed to read metadata: %v", err)
|
|
}
|
|
if skipIfSame && !force && existingHash == hash {
|
|
log.Printf("mmdb hash unchanged (%s), skipping load", hash)
|
|
return
|
|
}
|
|
|
|
rowSource, err := newNetworkSource(mmdbPath)
|
|
if err != nil {
|
|
log.Fatalf("failed to open mmdb: %v", err)
|
|
}
|
|
defer rowSource.Close()
|
|
|
|
if err := loadNetworks(ctx, conn, rowSource); err != nil {
|
|
log.Fatalf("failed to load networks: %v", err)
|
|
}
|
|
|
|
if err := upsertHash(ctx, conn, hash); err != nil {
|
|
log.Fatalf("failed to update metadata: %v", err)
|
|
}
|
|
|
|
log.Printf("loaded mmdb into Postgres (%d rows), hash=%s", rowSource.Rows(), hash)
|
|
}
|
|
|
|
// env returns the value of the environment variable key, or fallback when
// the variable is unset or set to the empty string.
func env(key, fallback string) string {
	val := os.Getenv(key)
	if val == "" {
		return fallback
	}
	return val
}
|
|
|
|
// envBool reads the environment variable key as a boolean using
// strconv.ParseBool (accepts 1, t, T, TRUE, true, True, 0, f, F, FALSE,
// false, False).
//
// It returns fallback when the variable is unset or empty. It also returns
// fallback when the value cannot be parsed, but logs a warning first, so a
// typo such as GEOIP_LOADER_FORCE=yes is visible instead of silently ignored.
func envBool(key string, fallback bool) bool {
	val := os.Getenv(key)
	if val == "" {
		return fallback
	}
	parsed, err := strconv.ParseBool(val)
	if err != nil {
		log.Printf("ignoring invalid %s=%q (%v); using default %t", key, val, err, fallback)
		return fallback
	}
	return parsed
}
|
|
|
|
// envDuration reads the environment variable key as a Go duration
// (time.ParseDuration syntax, e.g. "45m" or "1h30m").
//
// It returns fallback when the variable is unset or empty. It also returns
// fallback, after logging a warning, when the value does not parse or is not
// strictly positive. In this file the result is used as a context timeout, and
// a zero or negative timeout would make every operation fail immediately with
// "context deadline exceeded".
func envDuration(key string, fallback time.Duration) time.Duration {
	val := os.Getenv(key)
	if val == "" {
		return fallback
	}
	d, err := time.ParseDuration(val)
	switch {
	case err != nil:
		log.Printf("ignoring invalid %s=%q (%v); using default %s", key, val, err, fallback)
		return fallback
	case d <= 0:
		log.Printf("ignoring non-positive %s=%q; using default %s", key, val, fallback)
		return fallback
	}
	return d
}
|
|
|
|
// fileSHA256 streams the file at path through SHA-256 and returns the digest
// as a lowercase hex string. Open and read errors are returned unchanged,
// together with an empty digest.
func fileSHA256(path string) (string, error) {
	file, openErr := os.Open(path)
	if openErr != nil {
		return "", openErr
	}
	defer file.Close()

	digest := sha256.New()
	if _, copyErr := io.Copy(digest, file); copyErr != nil {
		return "", copyErr
	}
	sum := digest.Sum(nil)
	return hex.EncodeToString(sum), nil
}
|
|
|
|
func ensureSchema(ctx context.Context, conn *pgx.Conn) error {
|
|
ddl := fmt.Sprintf(`
|
|
CREATE SCHEMA IF NOT EXISTS %s;
|
|
|
|
CREATE TABLE IF NOT EXISTS %s.geoip_metadata (
|
|
key text PRIMARY KEY,
|
|
value text NOT NULL,
|
|
updated_at timestamptz NOT NULL DEFAULT now()
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS %s.city_lookup (
|
|
network cidr PRIMARY KEY,
|
|
geoname_id integer,
|
|
country text,
|
|
country_iso_code text,
|
|
region text,
|
|
region_iso_code text,
|
|
city text,
|
|
latitude double precision,
|
|
longitude double precision,
|
|
time_zone text
|
|
);
|
|
`, defaultSchema, defaultSchema, defaultSchema)
|
|
|
|
_, err := conn.Exec(ctx, ddl)
|
|
return err
|
|
}
|
|
|
|
func currentHash(ctx context.Context, conn *pgx.Conn) (string, error) {
|
|
var hash sql.NullString
|
|
err := conn.QueryRow(ctx, `SELECT value FROM geoip.geoip_metadata WHERE key = 'mmdb_sha256'`).Scan(&hash)
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return "", nil
|
|
}
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return hash.String, nil
|
|
}
|
|
|
|
func upsertHash(ctx context.Context, conn *pgx.Conn, hash string) error {
|
|
_, err := conn.Exec(ctx, `
|
|
INSERT INTO geoip.geoip_metadata(key, value, updated_at)
|
|
VALUES ('mmdb_sha256', $1, now())
|
|
ON CONFLICT (key) DO UPDATE
|
|
SET value = EXCLUDED.value,
|
|
updated_at = EXCLUDED.updated_at;
|
|
`, hash)
|
|
return err
|
|
}
|
|
|
|
type networkSource struct {
|
|
reader *maxminddb.Reader
|
|
iter *maxminddb.Networks
|
|
err error
|
|
row cityRow
|
|
count int
|
|
}
|
|
|
|
func newNetworkSource(path string) (*networkSource, error) {
|
|
reader, err := maxminddb.Open(path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &networkSource{
|
|
reader: reader,
|
|
iter: reader.Networks(),
|
|
}, nil
|
|
}
|
|
|
|
func (s *networkSource) Close() {
|
|
if s.reader != nil {
|
|
_ = s.reader.Close()
|
|
}
|
|
}
|
|
|
|
func (s *networkSource) Rows() int {
|
|
return s.count
|
|
}
|
|
|
|
func (s *networkSource) Next() bool {
|
|
if !s.iter.Next() {
|
|
s.err = s.iter.Err()
|
|
return false
|
|
}
|
|
|
|
var rec cityRecord
|
|
network, err := s.iter.Network(&rec)
|
|
if err != nil {
|
|
s.err = err
|
|
return false
|
|
}
|
|
|
|
s.row = cityRow{
|
|
network: network.String(),
|
|
geonameID: int(rec.City.GeoNameID),
|
|
country: rec.Country.Names["en"],
|
|
countryISO: rec.Country.IsoCode,
|
|
region: firstName(rec.Subdivisions),
|
|
regionISO: firstISO(rec.Subdivisions),
|
|
city: rec.City.Names["en"],
|
|
latitude: rec.Location.Latitude,
|
|
longitude: rec.Location.Longitude,
|
|
timeZone: rec.Location.TimeZone,
|
|
}
|
|
s.count++
|
|
if s.count%500000 == 0 {
|
|
log.Printf("loader progress: %d rows processed", s.count)
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (s *networkSource) Values() ([]any, error) {
|
|
return []any{
|
|
s.row.network,
|
|
s.row.geonameID,
|
|
s.row.country,
|
|
s.row.countryISO,
|
|
s.row.region,
|
|
s.row.regionISO,
|
|
s.row.city,
|
|
s.row.latitude,
|
|
s.row.longitude,
|
|
s.row.timeZone,
|
|
}, nil
|
|
}
|
|
|
|
func (s *networkSource) Err() error {
|
|
if s.err != nil {
|
|
return s.err
|
|
}
|
|
return s.iter.Err()
|
|
}
|
|
|
|
// firstName returns the English ("en") name of the first subdivision. It
// returns "" when there are no subdivisions or the first has no English name.
func firstName(subdivisions []struct {
	IsoCode string            `maxminddb:"iso_code"`
	Names   map[string]string `maxminddb:"names"`
}) string {
	var name string
	if len(subdivisions) != 0 {
		name = subdivisions[0].Names["en"]
	}
	return name
}
|
|
|
|
// firstISO returns the ISO code of the first subdivision, or "" when there
// are no subdivisions.
func firstISO(subdivisions []struct {
	IsoCode string            `maxminddb:"iso_code"`
	Names   map[string]string `maxminddb:"names"`
}) string {
	iso := ""
	if len(subdivisions) != 0 {
		first := subdivisions[0]
		iso = first.IsoCode
	}
	return iso
}
|
|
|
|
func loadNetworks(ctx context.Context, conn *pgx.Conn, src *networkSource) error {
|
|
tx, err := conn.Begin(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
_ = tx.Rollback(ctx)
|
|
}()
|
|
|
|
_, err = tx.Exec(ctx, `DROP TABLE IF EXISTS geoip.city_lookup_new; CREATE TABLE geoip.city_lookup_new (LIKE geoip.city_lookup INCLUDING ALL);`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
columns := []string{
|
|
"network",
|
|
"geoname_id",
|
|
"country",
|
|
"country_iso_code",
|
|
"region",
|
|
"region_iso_code",
|
|
"city",
|
|
"latitude",
|
|
"longitude",
|
|
"time_zone",
|
|
}
|
|
|
|
log.Printf("loader copy: starting bulk copy")
|
|
copied, err := tx.CopyFrom(ctx, pgx.Identifier{defaultSchema, "city_lookup_new"}, columns, src)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.Printf("loader copy: finished bulk copy (rows=%d)", copied)
|
|
|
|
if _, err := tx.Exec(ctx, `
|
|
ALTER TABLE IF EXISTS geoip.city_lookup RENAME TO city_lookup_old;
|
|
ALTER TABLE geoip.city_lookup_new RENAME TO city_lookup;
|
|
DROP TABLE IF EXISTS geoip.city_lookup_old;
|
|
`); err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err := tx.Exec(ctx, `
|
|
CREATE INDEX IF NOT EXISTS city_lookup_network_gist ON geoip.city_lookup USING gist (network inet_ops);
|
|
CREATE INDEX IF NOT EXISTS city_lookup_geoname_id_idx ON geoip.city_lookup (geoname_id);
|
|
`); err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err := tx.Exec(ctx, `
|
|
CREATE OR REPLACE FUNCTION geoip.lookup_city(ip inet)
|
|
RETURNS TABLE (
|
|
ip inet,
|
|
country text,
|
|
region text,
|
|
city text,
|
|
latitude double precision,
|
|
longitude double precision
|
|
) LANGUAGE sql STABLE AS $$
|
|
SELECT
|
|
$1::inet AS ip,
|
|
c.country,
|
|
c.region,
|
|
c.city,
|
|
c.latitude,
|
|
c.longitude
|
|
FROM geoip.city_lookup c
|
|
WHERE c.network >>= $1
|
|
ORDER BY masklen(c.network) DESC
|
|
LIMIT 1;
|
|
$$;
|
|
`); err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit(ctx)
|
|
}
|