export.go (7700B)
1 package export 2 3 import ( 4 "context" 5 "encoding/csv" 6 "fmt" 7 "math/big" 8 "os" 9 "strings" 10 "time" 11 12 "github.com/jackc/pgx/v5" 13 "github.com/jackc/pgx/v5/pgconn" 14 "github.com/jackc/pgx/v5/pgtype" 15 "github.com/parquet-go/parquet-go" 16 "github.com/parquet-go/parquet-go/compress/zstd" 17 18 "github.com/louloulibs/wrds-download/internal/db" 19 ) 20 21 // Options controls the export behaviour. 22 type Options struct { 23 Format string // "parquet" or "csv" 24 ProgressFunc func(rows int) // called periodically with total rows exported so far 25 } 26 27 const rowGroupSize = 10_000 28 29 // Export runs query against the WRDS PostgreSQL instance and writes output to outPath. 30 // Format is determined by opts.Format (default: parquet). 31 func Export(query, outPath string, opts Options) error { 32 format := strings.ToLower(opts.Format) 33 if format == "" { 34 if strings.HasSuffix(strings.ToLower(outPath), ".csv") { 35 format = "csv" 36 } else { 37 format = "parquet" 38 } 39 } 40 41 dsn, err := db.DSNFromEnv() 42 if err != nil { 43 return fmt.Errorf("dsn: %w", err) 44 } 45 46 ctx := context.Background() 47 conn, err := pgx.Connect(ctx, dsn) 48 if err != nil { 49 return fmt.Errorf("connect: %w", err) 50 } 51 defer conn.Close(ctx) 52 53 rows, err := conn.Query(ctx, query) 54 if err != nil { 55 return fmt.Errorf("query: %w", err) 56 } 57 defer rows.Close() 58 59 switch format { 60 case "csv": 61 return writeCSV(rows, outPath, opts.ProgressFunc) 62 default: 63 return writeParquet(rows, outPath, opts.ProgressFunc) 64 } 65 } 66 67 // writeCSV streams rows into a CSV file with a header row. 68 func writeCSV(rows pgx.Rows, outPath string, progressFn func(int)) error { 69 f, err := os.Create(outPath) 70 if err != nil { 71 return fmt.Errorf("create csv: %w", err) 72 } 73 defer f.Close() 74 75 w := csv.NewWriter(f) 76 defer w.Flush() 77 78 fds := rows.FieldDescriptions() 79 header := make([]string, len(fds)) 80 for i, fd := range fds { 81 header[i] = fd.Name 82 } 83 if err := w.Write(header); err != nil { 84 return fmt.Errorf("write header: %w", err) 85 } 86 87 record := make([]string, len(fds)) 88 total := 0 89 for rows.Next() { 90 vals, err := rows.Values() 91 if err != nil { 92 return fmt.Errorf("scan row: %w", err) 93 } 94 for i, v := range vals { 95 record[i] = formatValue(v) 96 } 97 if err := w.Write(record); err != nil { 98 return fmt.Errorf("write row: %w", err) 99 } 100 total++ 101 if progressFn != nil && total%rowGroupSize == 0 { 102 progressFn(total) 103 } 104 } 105 if err := rows.Err(); err != nil { 106 return fmt.Errorf("rows: %w", err) 107 } 108 109 w.Flush() 110 return w.Error() 111 } 112 113 // writeParquet streams rows into a Parquet file using parquet-go. 114 func writeParquet(rows pgx.Rows, outPath string, progressFn func(int)) error { 115 fds := rows.FieldDescriptions() 116 117 schema, colTypes := buildParquetSchema(fds) 118 119 f, err := os.Create(outPath) 120 if err != nil { 121 return fmt.Errorf("create parquet: %w", err) 122 } 123 defer f.Close() 124 125 writer := parquet.NewGenericWriter[map[string]any](f, 126 schema, 127 parquet.Compression(&zstd.Codec{}), 128 parquet.DefaultEncodingFor(parquet.ByteArray, &parquet.Plain), 129 ) 130 131 buf := make([]map[string]any, 0, rowGroupSize) 132 total := 0 133 134 for rows.Next() { 135 vals, err := rows.Values() 136 if err != nil { 137 return fmt.Errorf("scan row: %w", err) 138 } 139 140 row := make(map[string]any, len(fds)) 141 for i, v := range vals { 142 row[fds[i].Name] = convertValue(v, colTypes[i]) 143 } 144 buf = append(buf, row) 145 146 if len(buf) >= rowGroupSize { 147 if _, err := writer.Write(buf); err != nil { 148 return fmt.Errorf("write row group: %w", err) 149 } 150 total += len(buf) 151 buf = buf[:0] 152 if progressFn != nil { 153 progressFn(total) 154 } 155 } 156 } 157 if err := rows.Err(); err != nil { 158 return fmt.Errorf("rows: %w", err) 159 } 160 161 // Flush remaining rows. 162 if len(buf) > 0 { 163 if _, err := writer.Write(buf); err != nil { 164 return fmt.Errorf("write final rows: %w", err) 165 } 166 } 167 168 return writer.Close() 169 } 170 171 // colType tags how we convert PG values for Parquet. 172 type colType int 173 174 const ( 175 colString colType = iota 176 colBool 177 colInt32 178 colInt64 179 colFloat32 180 colFloat64 181 colDate // days since epoch → int32 182 colTimestamp // microseconds since epoch → int64 183 ) 184 185 // buildParquetSchema maps PG field descriptors to a parquet schema. 186 func buildParquetSchema(fds []pgconn.FieldDescription) (*parquet.Schema, []colType) { 187 cols := make([]colType, len(fds)) 188 group := make(parquet.Group, len(fds)) 189 190 for i, fd := range fds { 191 var node parquet.Node 192 193 switch fd.DataTypeOID { 194 case 16: // bool 195 cols[i] = colBool 196 node = parquet.Optional(parquet.Leaf(parquet.BooleanType)) 197 case 21: // int2 198 cols[i] = colInt32 199 node = parquet.Optional(parquet.Leaf(parquet.Int32Type)) 200 case 23: // int4 201 cols[i] = colInt32 202 node = parquet.Optional(parquet.Leaf(parquet.Int32Type)) 203 case 20: // int8 204 cols[i] = colInt64 205 node = parquet.Optional(parquet.Leaf(parquet.Int64Type)) 206 case 700: // float4 207 cols[i] = colFloat32 208 node = parquet.Optional(parquet.Leaf(parquet.FloatType)) 209 case 701: // float8 210 cols[i] = colFloat64 211 node = parquet.Optional(parquet.Leaf(parquet.DoubleType)) 212 case 1082: // date 213 cols[i] = colDate 214 node = parquet.Optional(parquet.Date()) 215 case 1114, 1184: // timestamp, timestamptz 216 cols[i] = colTimestamp 217 node = parquet.Optional(parquet.Timestamp(parquet.Microsecond)) 218 default: 219 // text (25), varchar (1043), char (18, 1042), numeric (1700), etc. 220 cols[i] = colString 221 node = parquet.Optional(parquet.String()) 222 } 223 224 group[fd.Name] = node 225 } 226 227 return parquet.NewSchema("wrds", group), cols 228 } 229 230 var epoch = time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) 231 232 // convertValue converts a pgx-scanned value to the appropriate Go type for parquet-go. 233 func convertValue(v any, ct colType) any { 234 if v == nil { 235 return nil 236 } 237 238 switch ct { 239 case colBool: 240 if b, ok := v.(bool); ok { 241 return b 242 } 243 case colInt32: 244 switch n := v.(type) { 245 case int16: 246 return int32(n) 247 case int32: 248 return n 249 case int64: 250 return int32(n) 251 } 252 case colInt64: 253 switch n := v.(type) { 254 case int64: 255 return n 256 case int32: 257 return int64(n) 258 case int16: 259 return int64(n) 260 } 261 case colFloat32: 262 if f, ok := v.(float32); ok { 263 return f 264 } 265 if f, ok := v.(float64); ok { 266 return float32(f) 267 } 268 case colFloat64: 269 if f, ok := v.(float64); ok { 270 return f 271 } 272 if f, ok := v.(float32); ok { 273 return float64(f) 274 } 275 case colDate: 276 if t, ok := v.(time.Time); ok { 277 days := int32(t.Sub(epoch).Hours() / 24) 278 return days 279 } 280 case colTimestamp: 281 if t, ok := v.(time.Time); ok { 282 return t.Sub(epoch).Microseconds() 283 } 284 case colString: 285 return formatValue(v) 286 } 287 288 // Fallback: stringify. 289 return formatValue(v) 290 } 291 292 // formatValue converts any value to its string representation. 293 func formatValue(v any) string { 294 if v == nil { 295 return "" 296 } 297 switch val := v.(type) { 298 case string: 299 return val 300 case []byte: 301 return string(val) 302 case time.Time: 303 if val.Hour() == 0 && val.Minute() == 0 && val.Second() == 0 && val.Nanosecond() == 0 { 304 return val.Format("2006-01-02") 305 } 306 return val.Format(time.RFC3339) 307 case pgtype.Numeric: 308 if !val.Valid { 309 return "" 310 } 311 if val.NaN { 312 return "NaN" 313 } 314 if val.InfinityModifier == pgtype.Infinity { 315 return "Infinity" 316 } 317 if val.InfinityModifier == pgtype.NegativeInfinity { 318 return "-Infinity" 319 } 320 // Convert to big.Float for string representation. 321 bi := val.Int 322 if bi == nil { 323 bi = new(big.Int) 324 } 325 bf := new(big.Float).SetInt(bi) 326 if val.Exp < 0 { 327 divisor := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(-val.Exp)), nil)) 328 bf.Quo(bf, divisor) 329 } else if val.Exp > 0 { 330 multiplier := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(val.Exp)), nil)) 331 bf.Mul(bf, multiplier) 332 } 333 return bf.Text('f', -1) 334 default: 335 return fmt.Sprintf("%v", val) 336 } 337 }