diff --git a/cmd/user_program_import/main.go b/cmd/user_program_import/main.go index d5d9d79..d972be6 100644 --- a/cmd/user_program_import/main.go +++ b/cmd/user_program_import/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "fmt" "log" "os" "time" @@ -12,18 +13,18 @@ import ( ) const ( - defaultCSVPath = "./initial_data/user_program_info_init_20251208.csv" - defaultUpdateDir = "./update_data" + defaultCSVPath = "/initial_data/user_program_info_init_20251208.csv" + defaultUpdateDir = "/update_data" defaultTimeout = 10 * time.Minute defaultSchema = "public" - defaultLogDir = "./log" + defaultLogDir = "/log" targetTableName = "user_program_info_replica" ) func main() { - dbURL := os.Getenv("DATABASE_URL") - if dbURL == "" { - log.Fatal("DATABASE_URL is required") + dbURL, err := databaseURL() + if err != nil { + log.Fatalf("database config: %v", err) } csvPath := env("USER_PROGRAM_INFO_CSV", defaultCSVPath) @@ -57,3 +58,18 @@ func env(key, fallback string) string { } return fallback } + +func databaseURL() (string, error) { + if url := os.Getenv("DATABASE_URL"); url != "" { + return url, nil + } + user := os.Getenv("POSTGRES_USER") + pass := os.Getenv("POSTGRES_PASSWORD") + host := env("POSTGRES_HOST", "localhost") + port := env("POSTGRES_PORT", "5432") + db := os.Getenv("POSTGRES_DB") + if user == "" || pass == "" || db == "" { + return "", fmt.Errorf("DATABASE_URL or POSTGRES_{USER,PASSWORD,DB} is required") + } + return fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", user, pass, host, port, db), nil +} diff --git a/cmd/user_program_sync/main.go b/cmd/user_program_sync/main.go index 984d4d9..ca00344 100644 --- a/cmd/user_program_sync/main.go +++ b/cmd/user_program_sync/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "fmt" "log" "os" "time" @@ -14,9 +15,9 @@ const defaultTimeout = 30 * time.Minute func main() { logger := log.New(os.Stdout, "[sync] ", log.LstdFlags) - dbURL := os.Getenv("DATABASE_URL") - if dbURL == "" { - logger.Fatal("DATABASE_URL is required") + dbURL, err := databaseURL() + if err != nil { + logger.Fatalf("database config: %v", err) } mysqlCfg, err := userprogram.NewMySQLConfigFromEnv() @@ -43,3 +44,24 @@ func main() { logger.Fatalf("sync failed: %v", err) } } + +func databaseURL() (string, error) { + if url := os.Getenv("DATABASE_URL"); url != "" { + return url, nil + } + user := os.Getenv("POSTGRES_USER") + pass := os.Getenv("POSTGRES_PASSWORD") + host := os.Getenv("POSTGRES_HOST") + if host == "" { + host = "localhost" + } + port := os.Getenv("POSTGRES_PORT") + if port == "" { + port = "5432" + } + db := os.Getenv("POSTGRES_DB") + if user == "" || pass == "" || db == "" { + return "", fmt.Errorf("DATABASE_URL or POSTGRES_{USER,PASSWORD,DB} is required") + } + return fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", user, pass, host, port, db), nil +} diff --git a/internal/importer/user_program_info.go b/internal/importer/user_program_info.go index 14cc32e..9c1137f 100644 --- a/internal/importer/user_program_info.go +++ b/internal/importer/user_program_info.go @@ -177,9 +177,9 @@ func copyAndUpsertCSV(ctx context.Context, conn *pgx.Conn, path, schema, table s _ = tx.Rollback(ctx) }() - tempTable := fmt.Sprintf("%s_import_tmp", table) + tempTable := fmt.Sprintf("%s_import_tmp_%d", table, time.Now().UnixNano()) - if _, err := tx.Exec(ctx, fmt.Sprintf(`CREATE TEMP TABLE %s (LIKE %s INCLUDING ALL);`, quoteIdent(tempTable), pgx.Identifier{schema, table}.Sanitize())); err != nil { + if _, err := tx.Exec(ctx, fmt.Sprintf(`CREATE TEMP TABLE %s (LIKE %s INCLUDING ALL) ON COMMIT DROP;`, quoteIdent(tempTable), pgx.Identifier{schema, table}.Sanitize())); err != nil { return res, err }