client.go (2633B)
1 package db 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "os" 8 9 "github.com/jackc/pgx/v5/pgxpool" 10 ) 11 12 // ErrNoUser is returned when PGUSER is not set. 13 var ErrNoUser = errors.New("PGUSER not set") 14 15 // Client wraps a pgx connection pool. 16 type Client struct { 17 Pool *pgxpool.Pool 18 } 19 20 // DSNFromEnv builds a PostgreSQL DSN from standard PG environment variables. 21 // Returns ("", ErrNoUser) if PGUSER is empty. 22 func DSNFromEnv() (string, error) { 23 host := getenv("PGHOST", "wrds-pgdata.wharton.upenn.edu") 24 port := getenv("PGPORT", "9737") 25 user := getenv("PGUSER", "") 26 password := getenv("PGPASSWORD", "") 27 database := getenv("PGDATABASE", "wrds") 28 29 if user == "" { 30 return "", ErrNoUser 31 } 32 33 dsn := fmt.Sprintf("host=%s port=%s user=%s sslmode=require", host, port, user) 34 if password != "" { 35 dsn += fmt.Sprintf(" password=%s", password) 36 } 37 if database != "" { 38 dsn += fmt.Sprintf(" dbname=%s", database) 39 } 40 return dsn, nil 41 } 42 43 func getenv(key, fallback string) string { 44 if v := os.Getenv(key); v != "" { 45 return v 46 } 47 return fallback 48 } 49 50 // New creates and pings a pgx pool using DSNFromEnv. 51 // The pool is limited to a single connection to avoid triggering 52 // multiple authentication prompts (e.g. Duo 2FA on WRDS). 53 func New(ctx context.Context) (*Client, error) { 54 dsn, err := DSNFromEnv() 55 if err != nil { 56 return nil, err 57 } 58 cfg, err := pgxpool.ParseConfig(dsn) 59 if err != nil { 60 return nil, fmt.Errorf("parse dsn: %w", err) 61 } 62 cfg.MaxConns = 1 63 cfg.MinConns = 0 64 pool, err := pgxpool.NewWithConfig(ctx, cfg) 65 if err != nil { 66 return nil, fmt.Errorf("pgxpool.New: %w", err) 67 } 68 if err := pool.Ping(ctx); err != nil { 69 pool.Close() 70 return nil, fmt.Errorf("ping: %w", err) 71 } 72 return &Client{Pool: pool}, nil 73 } 74 75 // NewWithCredentials sets PGUSER/PGPASSWORD/PGDATABASE env vars then creates and pings a pool. 76 func NewWithCredentials(ctx context.Context, user, password, database string) (*Client, error) { 77 os.Setenv("PGUSER", user) 78 os.Setenv("PGPASSWORD", password) 79 if database != "" { 80 os.Setenv("PGDATABASE", database) 81 } 82 return New(ctx) 83 } 84 85 // Databases returns the list of connectable databases. 86 func (c *Client) Databases(ctx context.Context) ([]string, error) { 87 rows, err := c.Pool.Query(ctx, 88 "SELECT datname FROM pg_database WHERE datallowconn = true ORDER BY datname") 89 if err != nil { 90 return nil, fmt.Errorf("databases query: %w", err) 91 } 92 defer rows.Close() 93 94 var dbs []string 95 for rows.Next() { 96 var name string 97 if err := rows.Scan(&name); err != nil { 98 return nil, err 99 } 100 dbs = append(dbs, name) 101 } 102 return dbs, rows.Err() 103 } 104 105 // Close releases the pool. 106 func (c *Client) Close() { 107 c.Pool.Close() 108 }