Files
geoip-rest/internal/userprogram/dumper.go
2025-12-10 11:53:48 +09:00

251 lines
5.4 KiB
Go

package userprogram
import (
"context"
"database/sql"
"encoding/csv"
"fmt"
"os"
"path/filepath"
"strconv"
"time"
"github.com/go-sql-driver/mysql"
)
// Dumper exports rows of the MySQL table named by cfg.Table into CSV files.
// Construct it with NewDumper and release its connection pool with Close.
//
// Fields:
//   - cfg: connection settings and the source table name used in queries.
//   - updateDir: directory the CSV dumps are written to.
//   - db: connection pool opened by NewDumper (Close tolerates nil).
type Dumper struct {
	cfg       MySQLConfig
	updateDir string
	db        *sql.DB
}
// NewDumper opens a MySQL connection pool described by cfg and returns a
// Dumper that writes its CSV exports into updateDir (DefaultUpdateDir when
// updateDir is empty). The directory is created if it does not exist.
//
// Every pooled connection uses UTC: the driver parses DATETIME/TIMESTAMP
// values as UTC (loc=UTC) and the session time_zone is '+00:00'. The session
// variable is passed as a DSN parameter because the driver then issues the
// SET on every new connection. A one-off db.Exec("SET time_zone ...") would
// only affect whichever pooled connection happened to run it. Other
// connections, including those re-dialed after SetConnMaxIdleTime closes
// idle ones, would fall back to the server's default zone.
//
// The pool is pinged before returning so an unreachable server, bad
// credentials, or a rejected session setting fail here instead of on the
// first query.
func NewDumper(cfg MySQLConfig, updateDir string) (*Dumper, error) {
	if updateDir == "" {
		updateDir = DefaultUpdateDir
	}
	if err := os.MkdirAll(updateDir, 0o755); err != nil {
		return nil, fmt.Errorf("create update dir: %w", err)
	}
	// String-valued system variables must be quoted in the DSN value.
	dsn := (&mysql.Config{
		User:                 cfg.User,
		Passwd:               cfg.Password,
		Net:                  "tcp",
		Addr:                 netAddr(cfg.Host, cfg.Port),
		DBName:               cfg.Database,
		Params:               map[string]string{"parseTime": "true", "loc": "UTC", "charset": "utf8mb4", "time_zone": "'+00:00'"},
		AllowNativePasswords: true,
	}).FormatDSN()
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, fmt.Errorf("open mysql: %w", err)
	}
	db.SetMaxOpenConns(5)
	db.SetMaxIdleConns(2)
	db.SetConnMaxIdleTime(5 * time.Minute)
	// sql.Open does not dial; force one connection (which also applies the
	// time_zone setting) so configuration problems surface immediately.
	if err := db.Ping(); err != nil {
		_ = db.Close()
		return nil, fmt.Errorf("ping mysql: %w", err)
	}
	return &Dumper{
		cfg:       cfg,
		updateDir: updateDir,
		db:        db,
	}, nil
}
// Close releases the Dumper's connection pool. When no pool is held (db is
// nil) there is nothing to release and Close returns nil.
func (d *Dumper) Close() error {
	if db := d.db; db != nil {
		return db.Close()
	}
	return nil
}
// MaxIDUntil returns the largest id whose created_at falls on or before the
// KST calendar day that contains cutoff, or 0 when no such row exists.
//
// created_at is interpreted as UTC wall-clock time. Instead of wrapping the
// column in DATE(CONVERT_TZ(created_at, '+00:00', '+09:00')), the KST day
// boundary is computed here and created_at is compared directly. Wrapping the
// column forces MySQL to evaluate the expression for every row and prevents
// it from using an index on created_at. The two predicates are equivalent:
// "KST date of created_at <= D" holds exactly when "created_at < start of
// day D+1 in KST, expressed in UTC".
func (d *Dumper) MaxIDUntil(ctx context.Context, cutoff time.Time) (int64, error) {
	day := cutoff.In(kst())
	// time.Date normalizes Day()+1 past month/year ends.
	nextDay := time.Date(day.Year(), day.Month(), day.Day()+1, 0, 0, 0, 0, day.Location())
	bound := nextDay.UTC().Format("2006-01-02 15:04:05")

	query := fmt.Sprintf(`SELECT COALESCE(MAX(id), 0) FROM %s WHERE created_at < ?`, d.cfg.Table)
	var maxID sql.NullInt64
	if err := d.db.QueryRowContext(ctx, query, bound).Scan(&maxID); err != nil {
		return 0, err
	}
	if !maxID.Valid {
		return 0, nil
	}
	return maxID.Int64, nil
}
// CountUpToID reports how many rows of the source table have an id no
// greater than maxID. A NULL aggregate is reported as 0.
func (d *Dumper) CountUpToID(ctx context.Context, maxID int64) (int64, error) {
	row := d.db.QueryRowContext(ctx,
		fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE id <= ?`, d.cfg.Table),
		maxID,
	)
	var total sql.NullInt64
	switch err := row.Scan(&total); {
	case err != nil:
		return 0, err
	case !total.Valid:
		return 0, nil
	default:
		return total.Int64, nil
	}
}
// DumpRange exports rows with id in (startID, endID], ordered by id, to a CSV
// file in the dumper's update directory and returns the file's path.
//
// The file is named user_program_info_<label>.csv, where label is converted
// to KST and formatted with defaultTargetRange. An empty range
// (endID <= startID) is a no-op that returns "" and a nil error.
//
// Rows are written to "<path>.tmp". After the header and all rows have been
// written, the CSV is flushed, the file is synced and closed, and only then is
// it renamed into place. Readers therefore never observe a partially written
// dump. On any failure the temporary file is closed and removed.
func (d *Dumper) DumpRange(ctx context.Context, startID, endID int64, label time.Time) (string, error) {
	if endID <= startID {
		return "", nil
	}
	query := fmt.Sprintf(`
SELECT
id,
product_name,
login_id,
user_employee_id,
login_version,
login_public_ip,
login_local_ip,
user_company,
user_department,
user_position,
user_login_time,
created_at,
user_family_flag
FROM %s
WHERE id > ? AND id <= ?
ORDER BY id;`, d.cfg.Table)
	rows, err := d.db.QueryContext(ctx, query, startID, endID)
	if err != nil {
		return "", fmt.Errorf("query: %w", err)
	}
	defer rows.Close()

	filename := fmt.Sprintf("user_program_info_%s.csv", label.In(kst()).Format(defaultTargetRange))
	outPath := filepath.Join(d.updateDir, filename)
	tmpPath := outPath + ".tmp"

	f, err := os.Create(tmpPath)
	if err != nil {
		return "", fmt.Errorf("create temp file: %w", err)
	}
	// Until the rename succeeds, an early return must not leave an open
	// descriptor or a partial ".tmp" file behind. Cleanup errors are ignored
	// on purpose: the error that caused the early return is the one to report.
	committed := false
	defer func() {
		if !committed {
			_ = f.Close()
			_ = os.Remove(tmpPath)
		}
	}()

	writer := csv.NewWriter(f)
	header := []string{
		"id",
		"product_name",
		"login_id",
		"user_employee_id",
		"login_version",
		"login_public_ip",
		"login_local_ip",
		"user_company",
		"user_department",
		"user_position",
		"user_login_time",
		"created_at",
		"user_family_flag",
	}
	if err := writer.Write(header); err != nil {
		return "", fmt.Errorf("write header: %w", err)
	}
	for rows.Next() {
		record, err := scanRow(rows)
		if err != nil {
			return "", fmt.Errorf("scan row: %w", err)
		}
		if err := writer.Write(record); err != nil {
			return "", fmt.Errorf("write row: %w", err)
		}
	}
	if err := rows.Err(); err != nil {
		return "", fmt.Errorf("iterate rows: %w", err)
	}
	writer.Flush()
	if err := writer.Error(); err != nil {
		return "", fmt.Errorf("flush csv: %w", err)
	}
	// Persist the data and surface any deferred write error (Close can report
	// one) before the file becomes visible under its final name. Closing
	// before the rename also keeps the rename portable to platforms that
	// refuse to rename open files.
	if err := f.Sync(); err != nil {
		return "", fmt.Errorf("sync temp file: %w", err)
	}
	if err := f.Close(); err != nil {
		return "", fmt.Errorf("close temp file: %w", err)
	}
	if err := os.Rename(tmpPath, outPath); err != nil {
		return "", fmt.Errorf("rename temp file: %w", err)
	}
	committed = true
	return outPath, nil
}
// scanRow reads the current row of the DumpRange query and converts it into
// CSV fields in column order.
//
// NULL text columns become "". Timestamps are rendered by formatTimestamp,
// and NULL timestamps also become "". A row whose id is NULL is rejected
// with an error.
func scanRow(rows *sql.Rows) ([]string, error) {
	var id sql.NullInt64
	var product, loginID, employeeID, version, publicIP, localIP sql.NullString
	var company, department, position, familyFlag sql.NullString
	var loginTime, createdAt sql.NullTime

	// Destination order must mirror the SELECT list in DumpRange.
	dest := []interface{}{
		&id, &product, &loginID, &employeeID, &version, &publicIP, &localIP,
		&company, &department, &position, &loginTime, &createdAt, &familyFlag,
	}
	if err := rows.Scan(dest...); err != nil {
		return nil, err
	}
	if !id.Valid {
		return nil, fmt.Errorf("row missing id")
	}

	record := make([]string, 0, len(dest))
	record = append(record, strconv.FormatInt(id.Int64, 10))
	for _, text := range []sql.NullString{
		product, loginID, employeeID, version, publicIP, localIP,
		company, department, position,
	} {
		record = append(record, nullToString(text))
	}
	record = append(record,
		formatTimestamp(loginTime),
		formatTimestamp(createdAt),
		nullToString(familyFlag),
	)
	return record, nil
}
// nullToString returns the string held by v, or "" when v is NULL.
func nullToString(v sql.NullString) string {
	if !v.Valid {
		return ""
	}
	return v.String
}
// netAddr joins host and port into the "host:port" form the MySQL driver
// expects for TCP addresses. A host containing a colon (an IPv6 literal) is
// wrapped in brackets, as net.JoinHostPort does, e.g. "::1" becomes
// "[::1]:3306". Plain formatting would yield the unparseable "::1:3306".
// A host the caller already bracketed ("[::1]") is not bracketed twice, so
// such configurations keep working.
func netAddr(host string, port int) string {
	if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}
	p := strconv.Itoa(port)
	for i := 0; i < len(host); i++ {
		if host[i] == ':' {
			return "[" + host + "]:" + p
		}
	}
	return host + ":" + p
}
// formatTimestamp renders a nullable timestamp in the kst() location using
// the layout "2006-01-02 15:04:05.000" (millisecond precision). A NULL value
// yields the empty string.
func formatTimestamp(t sql.NullTime) string {
	if t.Valid {
		local := t.Time.In(kst())
		return local.Format("2006-01-02 15:04:05.000")
	}
	return ""
}