commit 07747b009035be9ccb836492c7961c0937d1889f
parent b5e24ecc5c20f8f1561b695cbad912b10b84a651
Author: Erik Loualiche <[email protected]>
Date: Fri, 20 Feb 2026 15:00:39 -0600
Merge pull request #5 from LouLouLibs/feat/improvements-roadmap
Add progress feedback, info command, dry-run, CI, and SQL quoting
Diffstat:
12 files changed, 642 insertions(+), 31 deletions(-)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
@@ -0,0 +1,27 @@
+name: CI
+
+on:
+ push:
+ branches: [main]
+ pull_request:
+ branches: [main]
+ workflow_call:
+
+jobs:
+ test:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+
+ - uses: actions/setup-go@v5
+ with:
+ go-version: "1.25"
+
+ - name: Vet
+ run: go vet ./...
+
+ - name: Build
+ run: go build ./...
+
+ - name: Test
+ run: go test ./...
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
@@ -7,7 +7,11 @@ on:
workflow_dispatch:
jobs:
+ ci:
+ uses: ./.github/workflows/ci.yml
+
build:
+ needs: [ci]
runs-on: ubuntu-latest
strategy:
matrix:
diff --git a/README.md b/README.md
@@ -6,8 +6,11 @@ A terminal tool for browsing and downloading data from the [WRDS](https://wrds-w
- **TUI** — browse schemas and tables, inspect column metadata, trigger downloads without leaving the terminal
- **CLI** — scriptable `download` command with structured flags or raw SQL
+- **`info` command** — inspect table metadata (columns, types, row count) from the command line or scripts
- **Parquet output** — streams rows via pgx and writes Parquet with ZSTD compression using parquet-go (pure Go)
- **CSV output** — streams rows to CSV via encoding/csv
+- **Progress feedback** — live row count during large exports (CLI and TUI)
+- **Dry-run mode** — preview the query, row count, and first 5 rows before committing to a download
- **Login flow** — interactive login screen with Duo 2FA support; saved credentials for one-press reconnect
- **Database switching** — browse and switch between WRDS databases from within the TUI
- **Standard auth** — reads from `PG*` environment variables, `~/.config/wrds-dl/credentials`, or `~/.pgpass`
@@ -113,11 +116,14 @@ Press `d` on a selected table to open the download form:
|---|---|
| SELECT columns | Comma-separated column names, or `*` for all |
| WHERE clause | SQL filter without the `WHERE` keyword |
+| LIMIT rows | Maximum number of rows to download (leave empty for no limit) |
| Output path | File path; defaults to `./schema_table.parquet` |
| Format | `parquet` or `csv` |
Navigate with `tab`/`shift+tab`, confirm with `enter` on the last field.
+During download, the spinner shows a live row count updated every 10,000 rows.
+
## CLI
### Structured download
@@ -159,7 +165,21 @@ wrds-dl download \
Format is inferred from the output file extension (`.parquet` or `.csv`). Override with `--format`.
-### All flags
+### Dry run
+
+Preview what a download will do before committing:
+
+```sh
+wrds-dl download \
+ --schema crsp \
+ --table dsf \
+ --where "date = '2020-01-02'" \
+ --dry-run
+```
+
+This prints the SQL query, the row count, and the first 5 rows as a table. No `--out` flag is required for dry runs.
+
+### All download flags
| Flag | Description |
|---|---|
@@ -168,9 +188,38 @@ Format is inferred from the output file extension (`.parquet` or `.csv`). Overri
| `-c`, `--columns` | Columns to select (comma-separated, default `*`) |
| `--where` | SQL `WHERE` clause, without the keyword |
| `--query` | Full SQL query — overrides `--schema`, `--table`, `--where`, `--columns` |
-| `--out` | Output file path (required) |
+| `--out` | Output file path (required unless `--dry-run`) |
| `--format` | `parquet` or `csv` (inferred from extension if omitted) |
| `--limit` | Row limit, useful for testing (default: no limit) |
+| `--dry-run` | Preview query, row count, and first 5 rows without downloading |
+
+### Table info
+
+Inspect table metadata without downloading data:
+
+```sh
+wrds-dl info --schema crsp --table dsf
+```
+
+Output:
+
+```
+crsp.dsf
+ Daily Stock File
+ ~245302893 rows, 47 GB
+
+NAME TYPE NULLABLE DESCRIPTION
+cusip character varying(8) YES CUSIP - HISTORICAL
+permno double precision YES PERMNO
+permco double precision YES PERMCO
+...
+```
+
+For machine-readable output (useful in scripts and for coding assistants):
+
+```sh
+wrds-dl info --schema crsp --table dsf --json
+```
## How it works
@@ -180,8 +229,12 @@ Format is inferred from the output file extension (`.parquet` or `.csv`). Overri
- **Parquet**: rows are batched (10,000 per row group) and written with ZSTD compression via [parquet-go](https://github.com/parquet-go/parquet-go). String columns use PLAIN encoding for broad compatibility (R, Python, Julia).
- **CSV**: rows are streamed directly to disk via Go's `encoding/csv`.
+Progress is reported every 10,000 rows — printed to stderr on the CLI and shown in the TUI spinner overlay.
+
PostgreSQL types are mapped to Parquet types: `bool` → BOOLEAN, `int2/int4` → INT32, `int8` → INT64, `float4` → FLOAT, `float8` → DOUBLE, `date` → DATE, `timestamp/timestamptz` → TIMESTAMP (microseconds), `numeric` → STRING, `text/varchar/char` → STRING.
+Schema and table names are quoted as PostgreSQL identifiers to prevent SQL injection. Column names from `--columns` are individually quoted.
+
## Project structure
```
@@ -190,7 +243,8 @@ wrds-download/
├── cmd/
│ ├── root.go # cobra root command
│ ├── tui.go # `wrds-dl tui` — launches interactive browser
-│ └── download.go # `wrds-dl download` — CLI download command
+│ ├── download.go # `wrds-dl download` — CLI download with --dry-run
+│ └── info.go # `wrds-dl info` — table metadata inspection
├── internal/
│ ├── db/
│ │ ├── client.go # pgx pool, DSN construction, connection management
@@ -200,12 +254,13 @@ wrds-download/
│ ├── tui/
│ │ ├── app.go # root Bubble Tea model, Update/View, pane navigation
│ │ ├── loginform.go # login dialog with saved-credentials support
-│ │ ├── dlform.go # download dialog (columns, where, output, format)
+│ │ ├── dlform.go # download dialog (columns, where, limit, output, format)
│ │ └── styles.go # lipgloss styles and colors
│ └── config/
│ └── config.go # credentials file read/write (~/.config/wrds-dl/)
└── .github/workflows/
- └── release.yml # CI: cross-compile 4 targets, attach to GitHub Release
+ ├── ci.yml # CI: go vet, build, and test on push/PR
+ └── release.yml # Release: cross-compile 4 targets, attach to GitHub Release
```
## Dependencies
diff --git a/cmd/download.go b/cmd/download.go
@@ -1,24 +1,30 @@
package cmd
import (
+ "context"
"fmt"
"os"
"strings"
+ "text/tabwriter"
+ "time"
+ "github.com/jackc/pgx/v5"
"github.com/louloulibs/wrds-download/internal/config"
+ "github.com/louloulibs/wrds-download/internal/db"
"github.com/louloulibs/wrds-download/internal/export"
"github.com/spf13/cobra"
)
var (
- dlSchema string
- dlTable string
+ dlSchema string
+ dlTable string
dlColumns string
- dlWhere string
- dlQuery string
- dlOut string
- dlFormat string
- dlLimit int
+ dlWhere string
+ dlQuery string
+ dlOut string
+ dlFormat string
+ dlLimit int
+ dlDryRun bool
)
var downloadCmd = &cobra.Command{
@@ -46,8 +52,7 @@ func init() {
f.StringVar(&dlOut, "out", "", "Output file path (required)")
f.StringVar(&dlFormat, "format", "", "Output format: parquet or csv (inferred from extension if omitted)")
f.IntVar(&dlLimit, "limit", 0, "Limit number of rows (0 = no limit)")
-
- _ = downloadCmd.MarkFlagRequired("out")
+ f.BoolVar(&dlDryRun, "dry-run", false, "Preview the query, row count, and first 5 rows without downloading")
}
func runDownload(cmd *cobra.Command, args []string) error {
@@ -58,11 +63,24 @@ func runDownload(cmd *cobra.Command, args []string) error {
return err
}
+ if dlDryRun {
+ return runDryRun(query)
+ }
+
+ if dlOut == "" {
+ return fmt.Errorf("required flag \"out\" not set")
+ }
+
format := resolveFormat(dlOut, dlFormat)
fmt.Fprintf(os.Stderr, "Exporting to %s (%s)...\n", dlOut, format)
- opts := export.Options{Format: format}
+ opts := export.Options{
+ Format: format,
+ ProgressFunc: func(rows int) {
+ fmt.Fprintf(os.Stderr, "Exported %d rows...\n", rows)
+ },
+ }
if err := export.Export(query, dlOut, opts); err != nil {
return fmt.Errorf("export failed: %w", err)
}
@@ -71,6 +89,70 @@ func runDownload(cmd *cobra.Command, args []string) error {
return nil
}
+func runDryRun(query string) error {
+ dsn, err := db.DSNFromEnv()
+ if err != nil {
+ return fmt.Errorf("dsn: %w", err)
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ conn, err := pgx.Connect(ctx, dsn)
+ if err != nil {
+ return fmt.Errorf("connect: %w", err)
+ }
+ defer conn.Close(ctx)
+
+ fmt.Fprintln(os.Stdout, "Query:")
+ fmt.Fprintln(os.Stdout, " ", query)
+ fmt.Fprintln(os.Stdout)
+
+ // Row count
+ countQuery := fmt.Sprintf("SELECT count(*) FROM (%s) sub", query)
+ var count int64
+ if err := conn.QueryRow(ctx, countQuery).Scan(&count); err != nil {
+ return fmt.Errorf("count query: %w", err)
+ }
+ fmt.Fprintf(os.Stdout, "Row count: %d\n\n", count)
+
+ // Preview first 5 rows
+ previewQuery := fmt.Sprintf("SELECT * FROM (%s) sub LIMIT 5", query)
+ rows, err := conn.Query(ctx, previewQuery)
+ if err != nil {
+ return fmt.Errorf("preview query: %w", err)
+ }
+ defer rows.Close()
+
+ fds := rows.FieldDescriptions()
+ w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
+
+ // Header
+ headers := make([]string, len(fds))
+ for i, fd := range fds {
+ headers[i] = fd.Name
+ }
+ fmt.Fprintln(w, strings.Join(headers, "\t"))
+
+ // Rows
+ for rows.Next() {
+ vals, err := rows.Values()
+ if err != nil {
+ return fmt.Errorf("scan row: %w", err)
+ }
+ cells := make([]string, len(vals))
+ for i, v := range vals {
+ if v == nil {
+ cells[i] = "NULL"
+ } else {
+ cells[i] = fmt.Sprintf("%v", v)
+ }
+ }
+ fmt.Fprintln(w, strings.Join(cells, "\t"))
+ }
+ return w.Flush()
+}
+
func buildQuery() (string, error) {
if dlQuery != "" {
return dlQuery, nil
@@ -81,9 +163,14 @@ func buildQuery() (string, error) {
sel := "*"
if dlColumns != "" && dlColumns != "*" {
- sel = dlColumns
+ parts := strings.Split(dlColumns, ",")
+ quoted := make([]string, len(parts))
+ for i, p := range parts {
+ quoted[i] = db.QuoteIdent(strings.TrimSpace(p))
+ }
+ sel = strings.Join(quoted, ", ")
}
- q := fmt.Sprintf("SELECT %s FROM %s.%s", sel, dlSchema, dlTable)
+ q := fmt.Sprintf("SELECT %s FROM %s.%s", sel, db.QuoteIdent(dlSchema), db.QuoteIdent(dlTable))
if dlWhere != "" {
q += " WHERE " + dlWhere
diff --git a/cmd/download_test.go b/cmd/download_test.go
@@ -0,0 +1,117 @@
+package cmd
+
+import "testing"
+
+func TestBuildQuery(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func()
+ want string
+ wantErr bool
+ }{
+ {
+ name: "raw query passthrough",
+ setup: func() {
+ dlQuery = "SELECT * FROM crsp.dsf LIMIT 10"
+ dlSchema = ""
+ dlTable = ""
+ },
+ want: "SELECT * FROM crsp.dsf LIMIT 10",
+ },
+ {
+ name: "schema and table",
+ setup: func() {
+ dlQuery = ""
+ dlSchema = "crsp"
+ dlTable = "dsf"
+ dlColumns = "*"
+ dlWhere = ""
+ dlLimit = 0
+ },
+ want: `SELECT * FROM "crsp"."dsf"`,
+ },
+ {
+ name: "with columns",
+ setup: func() {
+ dlQuery = ""
+ dlSchema = "comp"
+ dlTable = "funda"
+ dlColumns = "gvkey,datadate,sale"
+ dlWhere = ""
+ dlLimit = 0
+ },
+ want: `SELECT "gvkey", "datadate", "sale" FROM "comp"."funda"`,
+ },
+ {
+ name: "with where and limit",
+ setup: func() {
+ dlQuery = ""
+ dlSchema = "crsp"
+ dlTable = "dsf"
+ dlColumns = "*"
+ dlWhere = "date >= '2020-01-01'"
+ dlLimit = 1000
+ },
+ want: `SELECT * FROM "crsp"."dsf" WHERE date >= '2020-01-01' LIMIT 1000`,
+ },
+ {
+ name: "missing schema and table",
+ setup: func() {
+ dlQuery = ""
+ dlSchema = ""
+ dlTable = ""
+ },
+ wantErr: true,
+ },
+ {
+ name: "column with spaces trimmed",
+ setup: func() {
+ dlQuery = ""
+ dlSchema = "crsp"
+ dlTable = "dsf"
+ dlColumns = " permno , date , prc "
+ dlWhere = ""
+ dlLimit = 0
+ },
+ want: `SELECT "permno", "date", "prc" FROM "crsp"."dsf"`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tt.setup()
+ got, err := buildQuery()
+ if (err != nil) != tt.wantErr {
+ t.Fatalf("buildQuery() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ if !tt.wantErr && got != tt.want {
+ t.Errorf("buildQuery() = %q, want %q", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestResolveFormat(t *testing.T) {
+ tests := []struct {
+ path string
+ flag string
+ want string
+ }{
+ {"out.parquet", "", "parquet"},
+ {"out.csv", "", "csv"},
+ {"out.CSV", "", "csv"},
+ {"out.parquet", "csv", "csv"},
+ {"out.csv", "parquet", "parquet"},
+ {"out.txt", "", "parquet"},
+ {"out", "CSV", "csv"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.path+"_"+tt.flag, func(t *testing.T) {
+ got := resolveFormat(tt.path, tt.flag)
+ if got != tt.want {
+ t.Errorf("resolveFormat(%q, %q) = %q, want %q", tt.path, tt.flag, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/cmd/info.go b/cmd/info.go
@@ -0,0 +1,138 @@
+package cmd
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+ "text/tabwriter"
+ "time"
+
+ "github.com/louloulibs/wrds-download/internal/config"
+ "github.com/louloulibs/wrds-download/internal/db"
+ "github.com/spf13/cobra"
+)
+
+var (
+ infoSchema string
+ infoTable string
+ infoJSON bool
+)
+
+var infoCmd = &cobra.Command{
+ Use: "info",
+ Short: "Show table metadata (columns, types, row count)",
+ Long: `Display metadata for a WRDS table: comment, estimated row count,
+size, and column details (name, type, nullable, description).
+
+Examples:
+ wrds-dl info --schema crsp --table dsf
+ wrds-dl info --schema comp --table funda --json`,
+ RunE: runInfo,
+}
+
+func init() {
+ rootCmd.AddCommand(infoCmd)
+
+ f := infoCmd.Flags()
+ f.StringVar(&infoSchema, "schema", "", "Schema name (required)")
+ f.StringVar(&infoTable, "table", "", "Table name (required)")
+ f.BoolVar(&infoJSON, "json", false, "Output as JSON")
+
+ _ = infoCmd.MarkFlagRequired("schema")
+ _ = infoCmd.MarkFlagRequired("table")
+}
+
+func runInfo(cmd *cobra.Command, args []string) error {
+ config.ApplyCredentials()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ client, err := db.New(ctx)
+ if err != nil {
+ return fmt.Errorf("connect: %w", err)
+ }
+ defer client.Close()
+
+ meta, err := client.TableMeta(ctx, infoSchema, infoTable)
+ if err != nil {
+ return fmt.Errorf("table meta: %w", err)
+ }
+
+ if infoJSON {
+ return printInfoJSON(meta)
+ }
+ return printInfoTable(meta)
+}
+
+type jsonColumn struct {
+ Name string `json:"name"`
+ Type string `json:"type"`
+ Nullable bool `json:"nullable"`
+ Description string `json:"description,omitempty"`
+}
+
+type jsonInfo struct {
+ Schema string `json:"schema"`
+ Table string `json:"table"`
+ Comment string `json:"comment,omitempty"`
+ RowCount int64 `json:"row_count"`
+ Size string `json:"size,omitempty"`
+ Columns []jsonColumn `json:"columns"`
+}
+
+func printInfoJSON(meta *db.TableMeta) error {
+ info := jsonInfo{
+ Schema: meta.Schema,
+ Table: meta.Table,
+ Comment: meta.Comment,
+ RowCount: meta.RowCount,
+ Size: meta.Size,
+ Columns: make([]jsonColumn, len(meta.Columns)),
+ }
+ for i, c := range meta.Columns {
+ info.Columns[i] = jsonColumn{
+ Name: c.Name,
+ Type: c.DataType,
+ Nullable: c.Nullable,
+ Description: c.Description,
+ }
+ }
+
+ enc := json.NewEncoder(os.Stdout)
+ enc.SetIndent("", " ")
+ return enc.Encode(info)
+}
+
+func printInfoTable(meta *db.TableMeta) error {
+ fmt.Fprintf(os.Stdout, "%s.%s\n", meta.Schema, meta.Table)
+ if meta.Comment != "" {
+ fmt.Fprintf(os.Stdout, " %s\n", meta.Comment)
+ }
+ if meta.RowCount > 0 || meta.Size != "" {
+ parts := ""
+ if meta.RowCount > 0 {
+ parts += fmt.Sprintf("~%d rows", meta.RowCount)
+ }
+ if meta.Size != "" {
+ if parts != "" {
+ parts += ", "
+ }
+ parts += meta.Size
+ }
+ fmt.Fprintf(os.Stdout, " %s\n", parts)
+ }
+ fmt.Fprintln(os.Stdout)
+
+ w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
+ fmt.Fprintln(w, "NAME\tTYPE\tNULLABLE\tDESCRIPTION")
+ for _, c := range meta.Columns {
+ nullable := "NO"
+ if c.Nullable {
+ nullable = "YES"
+ }
+ fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", c.Name, c.DataType, nullable, c.Description)
+ }
+ return w.Flush()
+}
diff --git a/internal/db/meta.go b/internal/db/meta.go
@@ -175,7 +175,7 @@ func (c *Client) Preview(ctx context.Context, schema, table string, limit int) (
limit = 50
}
- qualified := fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table))
+ qualified := fmt.Sprintf("%s.%s", QuoteIdent(schema), QuoteIdent(table))
// Estimated count via pg stats (fast).
var total int64
@@ -220,6 +220,8 @@ func (c *Client) Preview(ctx context.Context, schema, table string, limit int) (
return &result, rows.Err()
}
-func quoteIdent(s string) string {
+// QuoteIdent quotes a PostgreSQL identifier (schema, table, column name)
+// to prevent SQL injection.
+func QuoteIdent(s string) string {
return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
}
diff --git a/internal/db/meta_test.go b/internal/db/meta_test.go
@@ -0,0 +1,26 @@
+package db
+
+import "testing"
+
+func TestQuoteIdent(t *testing.T) {
+ tests := []struct {
+ input string
+ want string
+ }{
+ {"crsp", `"crsp"`},
+ {"dsf", `"dsf"`},
+ {`my"table`, `"my""table"`},
+ {"", `""`},
+ {"with space", `"with space"`},
+ {`double""quote`, `"double""""quote"`},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ got := QuoteIdent(tt.input)
+ if got != tt.want {
+ t.Errorf("QuoteIdent(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/internal/export/export.go b/internal/export/export.go
@@ -20,7 +20,8 @@ import (
// Options controls the export behaviour.
type Options struct {
- Format string // "parquet" or "csv"
+ Format string // "parquet" or "csv"
+ ProgressFunc func(rows int) // called periodically with total rows exported so far
}
const rowGroupSize = 10_000
@@ -57,14 +58,14 @@ func Export(query, outPath string, opts Options) error {
switch format {
case "csv":
- return writeCSV(rows, outPath)
+ return writeCSV(rows, outPath, opts.ProgressFunc)
default:
- return writeParquet(rows, outPath)
+ return writeParquet(rows, outPath, opts.ProgressFunc)
}
}
// writeCSV streams rows into a CSV file with a header row.
-func writeCSV(rows pgx.Rows, outPath string) error {
+func writeCSV(rows pgx.Rows, outPath string, progressFn func(int)) error {
f, err := os.Create(outPath)
if err != nil {
return fmt.Errorf("create csv: %w", err)
@@ -84,6 +85,7 @@ func writeCSV(rows pgx.Rows, outPath string) error {
}
record := make([]string, len(fds))
+ total := 0
for rows.Next() {
vals, err := rows.Values()
if err != nil {
@@ -95,6 +97,10 @@ func writeCSV(rows pgx.Rows, outPath string) error {
if err := w.Write(record); err != nil {
return fmt.Errorf("write row: %w", err)
}
+ total++
+ if progressFn != nil && total%rowGroupSize == 0 {
+ progressFn(total)
+ }
}
if err := rows.Err(); err != nil {
return fmt.Errorf("rows: %w", err)
@@ -105,7 +111,7 @@ func writeCSV(rows pgx.Rows, outPath string) error {
}
// writeParquet streams rows into a Parquet file using parquet-go.
-func writeParquet(rows pgx.Rows, outPath string) error {
+func writeParquet(rows pgx.Rows, outPath string, progressFn func(int)) error {
fds := rows.FieldDescriptions()
schema, colTypes := buildParquetSchema(fds)
@@ -123,6 +129,7 @@ func writeParquet(rows pgx.Rows, outPath string) error {
)
buf := make([]map[string]any, 0, rowGroupSize)
+ total := 0
for rows.Next() {
vals, err := rows.Values()
@@ -140,7 +147,11 @@ func writeParquet(rows pgx.Rows, outPath string) error {
if _, err := writer.Write(buf); err != nil {
return fmt.Errorf("write row group: %w", err)
}
+ total += len(buf)
buf = buf[:0]
+ if progressFn != nil {
+ progressFn(total)
+ }
}
}
if err := rows.Err(); err != nil {
diff --git a/internal/export/export_test.go b/internal/export/export_test.go
@@ -0,0 +1,81 @@
+package export
+
+import (
+ "math/big"
+ "testing"
+ "time"
+
+ "github.com/jackc/pgx/v5/pgtype"
+)
+
+func TestFormatValue(t *testing.T) {
+ tests := []struct {
+ name string
+ v any
+ want string
+ }{
+ {"nil", nil, ""},
+ {"string", "hello", "hello"},
+ {"bytes", []byte("data"), "data"},
+ {"int", 42, "42"},
+ {"float", 3.14, "3.14"},
+ {"bool", true, "true"},
+ {"date only", time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC), "2020-01-02"},
+ {"datetime", time.Date(2020, 1, 2, 15, 4, 5, 0, time.UTC), "2020-01-02T15:04:05Z"},
+ {"numeric valid", pgtype.Numeric{Int: big.NewInt(12345), Exp: -2, Valid: true}, "123.45"},
+ {"numeric invalid", pgtype.Numeric{Valid: false}, ""},
+ {"numeric NaN", pgtype.Numeric{Valid: true, NaN: true}, "NaN"},
+ {"numeric Infinity", pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}, "Infinity"},
+ {"numeric -Infinity", pgtype.Numeric{Valid: true, InfinityModifier: pgtype.NegativeInfinity}, "-Infinity"},
+ {"numeric zero exp", pgtype.Numeric{Int: big.NewInt(42), Exp: 0, Valid: true}, "42"},
+ {"numeric positive exp", pgtype.Numeric{Int: big.NewInt(5), Exp: 3, Valid: true}, "5000"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := formatValue(tt.v)
+ if got != tt.want {
+ t.Errorf("formatValue(%v) = %q, want %q", tt.v, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestConvertValue(t *testing.T) {
+ epoch := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)
+
+ tests := []struct {
+ name string
+ v any
+ ct colType
+ want any
+ }{
+ {"nil", nil, colString, nil},
+ {"bool true", true, colBool, true},
+ {"bool false", false, colBool, false},
+ {"int16 to int32", int16(42), colInt32, int32(42)},
+ {"int32 passthrough", int32(100), colInt32, int32(100)},
+ {"int64 to int32", int64(200), colInt32, int32(200)},
+ {"int64 passthrough", int64(999), colInt64, int64(999)},
+ {"int32 to int64", int32(50), colInt64, int64(50)},
+ {"int16 to int64", int16(10), colInt64, int64(10)},
+ {"float32 passthrough", float32(1.5), colFloat32, float32(1.5)},
+ {"float64 to float32", float64(2.5), colFloat32, float32(2.5)},
+ {"float64 passthrough", float64(3.14), colFloat64, float64(3.14)},
+ {"float32 to float64", float32(1.5), colFloat64, float64(1.5)},
+ {"date", time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC), colDate,
+ int32(time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC).Sub(epoch).Hours() / 24)},
+ {"timestamp", time.Date(2020, 6, 15, 12, 30, 0, 0, time.UTC), colTimestamp,
+ time.Date(2020, 6, 15, 12, 30, 0, 0, time.UTC).Sub(epoch).Microseconds()},
+ {"string passthrough", "hello", colString, "hello"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := convertValue(tt.v, tt.ct)
+ if got != tt.want {
+ t.Errorf("convertValue(%v, %v) = %v (%T), want %v (%T)", tt.v, tt.ct, got, got, tt.want, tt.want)
+ }
+ })
+ }
+}
diff --git a/internal/tui/app.go b/internal/tui/app.go
@@ -45,6 +45,7 @@ type tablesLoadedMsg struct{ tables []db.Table }
type metaLoadedMsg struct{ meta *db.TableMeta }
type errMsg struct{ err error }
type downloadDoneMsg struct{ path string }
+type downloadProgressMsg struct{ rows int }
type tickMsg time.Time
type loginSuccessMsg struct{ client *db.Client }
type loginFailMsg struct{ err error }
@@ -89,6 +90,8 @@ type App struct {
currentDatabase string
selectedSchema string
selectedTable string
+ downloadRows int
+ dlProgressCh chan int
}
func newPreviewFilter() textinput.Model {
@@ -228,21 +231,52 @@ func (a *App) loadMeta(schema, tbl string) tea.Cmd {
}
func (a *App) startDownload(msg DlSubmitMsg) tea.Cmd {
- return func() tea.Msg {
+ progressCh := make(chan int, 1)
+ a.dlProgressCh = progressCh
+
+ download := func() tea.Msg {
sel := "*"
if msg.Columns != "" && msg.Columns != "*" {
- sel = msg.Columns
+ parts := strings.Split(msg.Columns, ",")
+ quoted := make([]string, len(parts))
+ for i, p := range parts {
+ quoted[i] = db.QuoteIdent(strings.TrimSpace(p))
+ }
+ sel = strings.Join(quoted, ", ")
}
- query := fmt.Sprintf("SELECT %s FROM %s.%s", sel, msg.Schema, msg.Table)
+ query := fmt.Sprintf("SELECT %s FROM %s.%s", sel, db.QuoteIdent(msg.Schema), db.QuoteIdent(msg.Table))
if msg.Where != "" {
query += " WHERE " + msg.Where
}
- err := export.Export(query, msg.Out, export.Options{Format: msg.Format})
+ if msg.Limit > 0 {
+ query += fmt.Sprintf(" LIMIT %d", msg.Limit)
+ }
+ opts := export.Options{
+ Format: msg.Format,
+ ProgressFunc: func(rows int) {
+ select {
+ case progressCh <- rows:
+ default:
+ }
+ },
+ }
+ err := export.Export(query, msg.Out, opts)
+ close(progressCh)
if err != nil {
return errMsg{err}
}
return downloadDoneMsg{msg.Out}
}
+
+ listenProgress := func() tea.Msg {
+ rows, ok := <-progressCh
+ if !ok {
+ return nil
+ }
+ return downloadProgressMsg{rows}
+ }
+
+ return tea.Batch(download, listenProgress)
}
func (a *App) attemptLogin(msg LoginSubmitMsg) tea.Cmd {
@@ -391,9 +425,22 @@ func (a *App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.state = stateBrowse
return a, nil
+ case downloadProgressMsg:
+ a.downloadRows = msg.rows
+ // Keep listening for more progress updates.
+ ch := a.dlProgressCh
+ return a, func() tea.Msg {
+ rows, ok := <-ch
+ if !ok {
+ return nil
+ }
+ return downloadProgressMsg{rows}
+ }
+
case downloadDoneMsg:
a.statusOK = fmt.Sprintf("Saved: %s", msg.path)
a.state = stateDone
+ a.downloadRows = 0
return a, nil
case DlCancelMsg:
@@ -404,6 +451,7 @@ func (a *App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.state = stateDownloading
a.statusErr = ""
a.statusOK = ""
+ a.downloadRows = 0
return a, tea.Batch(a.startDownload(msg), a.spinner.Tick)
case list.FilterMatchesMsg:
@@ -648,8 +696,11 @@ func (a *App) View() string {
return overlayCenter(full, overlay, a.width, a.height)
}
if a.state == stateDownloading {
- msg := a.spinner.View() + " Downloading…"
- return overlayCenter(full, stylePanelFocused.Padding(1, 3).Render(msg), a.width, a.height)
+ dlMsg := a.spinner.View() + " Downloading…"
+ if a.downloadRows > 0 {
+ dlMsg = a.spinner.View() + fmt.Sprintf(" Downloading… %s rows exported", formatCount(int64(a.downloadRows)))
+ }
+ return overlayCenter(full, stylePanelFocused.Padding(1, 3).Render(dlMsg), a.width, a.height)
}
if a.state == stateDone {
msg := styleSuccess.Render("✓ ") + a.statusOK + "\n\n" + styleStatusBar.Render("[esc] dismiss")
diff --git a/internal/tui/dlform.go b/internal/tui/dlform.go
@@ -2,6 +2,7 @@ package tui
import (
"fmt"
+ "strconv"
"strings"
"github.com/charmbracelet/bubbles/textinput"
@@ -14,6 +15,7 @@ type dlFormField int
const (
fieldSelect dlFormField = iota
fieldWhere
+ fieldLimit
fieldOut
fieldFormat
fieldCount
@@ -34,6 +36,7 @@ type DlSubmitMsg struct {
Table string
Columns string
Where string
+ Limit int
Out string
Format string
}
@@ -61,6 +64,10 @@ func newDlForm(schema, table string, colNames []string) DlForm {
f.inputs[fieldWhere].Placeholder = "e.g. date >= '2020-01-01'"
f.inputs[fieldWhere].CharLimit = 512
+ f.inputs[fieldLimit] = textinput.New()
+ f.inputs[fieldLimit].Placeholder = "no limit"
+ f.inputs[fieldLimit].CharLimit = 12
+
f.inputs[fieldOut] = textinput.New()
f.inputs[fieldOut].Placeholder = fmt.Sprintf("./%s_%s.parquet", schema, table)
f.inputs[fieldOut].CharLimit = 256
@@ -101,12 +108,17 @@ func (f DlForm) Update(msg tea.Msg) (DlForm, tea.Cmd) {
if columns == "" {
columns = "*"
}
+ var limit int
+ if v := strings.TrimSpace(f.inputs[fieldLimit].Value()); v != "" {
+ limit, _ = strconv.Atoi(v)
+ }
return f, func() tea.Msg {
return DlSubmitMsg{
Schema: f.schema,
Table: f.table,
Columns: columns,
Where: f.inputs[fieldWhere].Value(),
+ Limit: limit,
Out: out,
Format: format,
}
@@ -135,7 +147,7 @@ func (f DlForm) View(width int) string {
title := stylePanelHeader.Render(fmt.Sprintf("Download %s.%s", f.schema, f.table))
sb.WriteString(title + "\n\n")
- labels := []string{"SELECT columns", "WHERE clause", "Output path", "Format (parquet/csv)"}
+ labels := []string{"SELECT columns", "WHERE clause", "LIMIT rows", "Output path", "Format (parquet/csv)"}
for i, label := range labels {
style := lipgloss.NewStyle().Foreground(colorMuted)
if dlFormField(i) == f.focused {