wrds-download

TUI/CLI tool for browsing and downloading WRDS data
Log | Files | Refs | README

commit 07747b009035be9ccb836492c7961c0937d1889f
parent b5e24ecc5c20f8f1561b695cbad912b10b84a651
Author: Erik Loualiche <[email protected]>
Date:   Fri, 20 Feb 2026 15:00:39 -0600

Merge pull request #5 from LouLouLibs/feat/improvements-roadmap

Add progress feedback, info command, dry-run, CI, and SQL quoting
Diffstat:
A.github/workflows/ci.yml | 27+++++++++++++++++++++++++++
M.github/workflows/release.yml | 4++++
MREADME.md | 65++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----
Mcmd/download.go | 111++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------
Acmd/download_test.go | 117+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Acmd/info.go | 138+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Minternal/db/meta.go | 6++++--
Ainternal/db/meta_test.go | 26++++++++++++++++++++++++++
Minternal/export/export.go | 21++++++++++++++++-----
Ainternal/export/export_test.go | 81+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Minternal/tui/app.go | 63+++++++++++++++++++++++++++++++++++++++++++++++++++++++++------
Minternal/tui/dlform.go | 14+++++++++++++-
12 files changed, 642 insertions(+), 31 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml @@ -0,0 +1,27 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + workflow_call: + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version: "1.25" + + - name: Vet + run: go vet ./... + + - name: Build + run: go build ./... + + - name: Test + run: go test ./... diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml @@ -7,7 +7,11 @@ on: workflow_dispatch: jobs: + ci: + uses: ./.github/workflows/ci.yml + build: + needs: [ci] runs-on: ubuntu-latest strategy: matrix: diff --git a/README.md b/README.md @@ -6,8 +6,11 @@ A terminal tool for browsing and downloading data from the [WRDS](https://wrds-w - **TUI** — browse schemas and tables, inspect column metadata, trigger downloads without leaving the terminal - **CLI** — scriptable `download` command with structured flags or raw SQL +- **`info` command** — inspect table metadata (columns, types, row count) from the command line or scripts - **Parquet output** — streams rows via pgx and writes Parquet with ZSTD compression using parquet-go (pure Go) - **CSV output** — streams rows to CSV via encoding/csv +- **Progress feedback** — live row count during large exports (CLI and TUI) +- **Dry-run mode** — preview the query, row count, and first 5 rows before committing to a download - **Login flow** — interactive login screen with Duo 2FA support; saved credentials for one-press reconnect - **Database switching** — browse and switch between WRDS databases from within the TUI - **Standard auth** — reads from `PG*` environment variables, `~/.config/wrds-dl/credentials`, or `~/.pgpass` @@ -113,11 +116,14 @@ Press `d` on a selected table to open the download form: |---|---| | SELECT columns | Comma-separated column names, or `*` for all | | WHERE clause | SQL filter without the `WHERE` keyword | +| LIMIT rows | Maximum number of rows to download (leave empty for no limit) | | Output path | File path; defaults to `./schema_table.parquet` | | Format | `parquet` or `csv` | Navigate with `tab`/`shift+tab`, confirm with `enter` on the last field. +During download, the spinner shows a live row count updated every 10,000 rows. + ## CLI ### Structured download @@ -159,7 +165,21 @@ wrds-dl download \ Format is inferred from the output file extension (`.parquet` or `.csv`). Override with `--format`. -### All flags +### Dry run + +Preview what a download will do before committing: + +```sh +wrds-dl download \ + --schema crsp \ + --table dsf \ + --where "date = '2020-01-02'" \ + --dry-run +``` + +This prints the SQL query, the row count, and the first 5 rows as a table. No `--out` flag is required for dry runs. + +### All download flags | Flag | Description | |---|---| @@ -168,9 +188,38 @@ Format is inferred from the output file extension (`.parquet` or `.csv`). Overri | `-c`, `--columns` | Columns to select (comma-separated, default `*`) | | `--where` | SQL `WHERE` clause, without the keyword | | `--query` | Full SQL query — overrides `--schema`, `--table`, `--where`, `--columns` | -| `--out` | Output file path (required) | +| `--out` | Output file path (required unless `--dry-run`) | | `--format` | `parquet` or `csv` (inferred from extension if omitted) | | `--limit` | Row limit, useful for testing (default: no limit) | +| `--dry-run` | Preview query, row count, and first 5 rows without downloading | + +### Table info + +Inspect table metadata without downloading data: + +```sh +wrds-dl info --schema crsp --table dsf +``` + +Output: + +``` +crsp.dsf + Daily Stock File + ~245302893 rows, 47 GB + +NAME TYPE NULLABLE DESCRIPTION +cusip character varying(8) YES CUSIP - HISTORICAL +permno double precision YES PERMNO +permco double precision YES PERMCO +... +``` + +For machine-readable output (useful in scripts and coding assistants): + +```sh +wrds-dl info --schema crsp --table dsf --json +``` ## How it works @@ -180,8 +229,12 @@ Format is inferred from the output file extension (`.parquet` or `.csv`). Overri - **Parquet**: rows are batched (10,000 per row group) and written with ZSTD compression via [parquet-go](https://github.com/parquet-go/parquet-go). String columns use PLAIN encoding for broad compatibility (R, Python, Julia). - **CSV**: rows are streamed directly to disk via Go's `encoding/csv`. +Progress is reported every 10,000 rows — printed to stderr on the CLI and shown in the TUI spinner overlay. + PostgreSQL types are mapped to Parquet types: `bool` → BOOLEAN, `int2/int4` → INT32, `int8` → INT64, `float4` → FLOAT, `float8` → DOUBLE, `date` → DATE, `timestamp/timestamptz` → TIMESTAMP (microseconds), `numeric` → STRING, `text/varchar/char` → STRING. +Schema and table names are quoted as PostgreSQL identifiers to prevent SQL injection. Column names from `--columns` are individually quoted. + ## Project structure ``` @@ -190,7 +243,8 @@ wrds-download/ ├── cmd/ │ ├── root.go # cobra root command │ ├── tui.go # `wrds-dl tui` — launches interactive browser -│ └── download.go # `wrds-dl download` — CLI download command +│ ├── download.go # `wrds-dl download` — CLI download with --dry-run +│ └── info.go # `wrds-dl info` — table metadata inspection ├── internal/ │ ├── db/ │ │ ├── client.go # pgx pool, DSN construction, connection management @@ -200,12 +254,13 @@ wrds-download/ │ ├── tui/ │ │ ├── app.go # root Bubble Tea model, Update/View, pane navigation │ │ ├── loginform.go # login dialog with saved-credentials support -│ │ ├── dlform.go # download dialog (columns, where, output, format) +│ │ ├── dlform.go # download dialog (columns, where, limit, output, format) │ │ └── styles.go # lipgloss styles and colors │ └── config/ │ └── config.go # credentials file read/write (~/.config/wrds-dl/) └── .github/workflows/ - └── release.yml # CI: cross-compile 4 targets, attach to GitHub Release + ├── ci.yml # CI: go vet, build, and test on push/PR + └── release.yml # Release: cross-compile 4 targets, attach to GitHub Release ``` ## Dependencies diff --git a/cmd/download.go b/cmd/download.go @@ -1,24 +1,30 @@ package cmd import ( + "context" "fmt" "os" "strings" + "text/tabwriter" + "time" + "github.com/jackc/pgx/v5" "github.com/louloulibs/wrds-download/internal/config" + "github.com/louloulibs/wrds-download/internal/db" "github.com/louloulibs/wrds-download/internal/export" "github.com/spf13/cobra" ) var ( - dlSchema string - dlTable string + dlSchema string + dlTable string dlColumns string - dlWhere string - dlQuery string - dlOut string - dlFormat string - dlLimit int + dlWhere string + dlQuery string + dlOut string + dlFormat string + dlLimit int + dlDryRun bool ) var downloadCmd = &cobra.Command{ @@ -46,8 +52,7 @@ func init() { f.StringVar(&dlOut, "out", "", "Output file path (required)") f.StringVar(&dlFormat, "format", "", "Output format: parquet or csv (inferred from extension if omitted)") f.IntVar(&dlLimit, "limit", 0, "Limit number of rows (0 = no limit)") - - _ = downloadCmd.MarkFlagRequired("out") + f.BoolVar(&dlDryRun, "dry-run", false, "Preview the query, row count, and first 5 rows without downloading") } func runDownload(cmd *cobra.Command, args []string) error { @@ -58,11 +63,24 @@ func runDownload(cmd *cobra.Command, args []string) error { return err } + if dlDryRun { + return runDryRun(query) + } + + if dlOut == "" { + return fmt.Errorf("required flag \"out\" not set") + } + format := resolveFormat(dlOut, dlFormat) fmt.Fprintf(os.Stderr, "Exporting to %s (%s)...\n", dlOut, format) - opts := export.Options{Format: format} + opts := export.Options{ + Format: format, + ProgressFunc: func(rows int) { + fmt.Fprintf(os.Stderr, "Exported %d rows...\n", rows) + }, + } if err := export.Export(query, dlOut, opts); err != nil { return fmt.Errorf("export failed: %w", err) } @@ -71,6 +89,70 @@ func runDownload(cmd *cobra.Command, args []string) error { return nil } +func runDryRun(query string) error { + dsn, err := db.DSNFromEnv() + if err != nil { + return fmt.Errorf("dsn: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, dsn) + if err != nil { + return fmt.Errorf("connect: %w", err) + } + defer conn.Close(ctx) + + fmt.Fprintln(os.Stdout, "Query:") + fmt.Fprintln(os.Stdout, " ", query) + fmt.Fprintln(os.Stdout) + + // Row count + countQuery := fmt.Sprintf("SELECT count(*) FROM (%s) sub", query) + var count int64 + if err := conn.QueryRow(ctx, countQuery).Scan(&count); err != nil { + return fmt.Errorf("count query: %w", err) + } + fmt.Fprintf(os.Stdout, "Row count: %d\n\n", count) + + // Preview first 5 rows + previewQuery := fmt.Sprintf("SELECT * FROM (%s) sub LIMIT 5", query) + rows, err := conn.Query(ctx, previewQuery) + if err != nil { + return fmt.Errorf("preview query: %w", err) + } + defer rows.Close() + + fds := rows.FieldDescriptions() + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + + // Header + headers := make([]string, len(fds)) + for i, fd := range fds { + headers[i] = fd.Name + } + fmt.Fprintln(w, strings.Join(headers, "\t")) + + // Rows + for rows.Next() { + vals, err := rows.Values() + if err != nil { + return fmt.Errorf("scan row: %w", err) + } + cells := make([]string, len(vals)) + for i, v := range vals { + if v == nil { + cells[i] = "NULL" + } else { + cells[i] = fmt.Sprintf("%v", v) + } + } + fmt.Fprintln(w, strings.Join(cells, "\t")) + } + return w.Flush() +} + func buildQuery() (string, error) { if dlQuery != "" { return dlQuery, nil @@ -81,9 +163,14 @@ func buildQuery() (string, error) { sel := "*" if dlColumns != "" && dlColumns != "*" { - sel = dlColumns + parts := strings.Split(dlColumns, ",") + quoted := make([]string, len(parts)) + for i, p := range parts { + quoted[i] = db.QuoteIdent(strings.TrimSpace(p)) + } + sel = strings.Join(quoted, ", ") } - q := fmt.Sprintf("SELECT %s FROM %s.%s", sel, dlSchema, dlTable) + q := fmt.Sprintf("SELECT %s FROM %s.%s", sel, db.QuoteIdent(dlSchema), db.QuoteIdent(dlTable)) if dlWhere != "" { q += " WHERE " + dlWhere diff --git a/cmd/download_test.go b/cmd/download_test.go @@ -0,0 +1,117 @@ +package cmd + +import "testing" + +func TestBuildQuery(t *testing.T) { + tests := []struct { + name string + setup func() + want string + wantErr bool + }{ + { + name: "raw query passthrough", + setup: func() { + dlQuery = "SELECT * FROM crsp.dsf LIMIT 10" + dlSchema = "" + dlTable = "" + }, + want: "SELECT * FROM crsp.dsf LIMIT 10", + }, + { + name: "schema and table", + setup: func() { + dlQuery = "" + dlSchema = "crsp" + dlTable = "dsf" + dlColumns = "*" + dlWhere = "" + dlLimit = 0 + }, + want: `SELECT * FROM "crsp"."dsf"`, + }, + { + name: "with columns", + setup: func() { + dlQuery = "" + dlSchema = "comp" + dlTable = "funda" + dlColumns = "gvkey,datadate,sale" + dlWhere = "" + dlLimit = 0 + }, + want: `SELECT "gvkey", "datadate", "sale" FROM "comp"."funda"`, + }, + { + name: "with where and limit", + setup: func() { + dlQuery = "" + dlSchema = "crsp" + dlTable = "dsf" + dlColumns = "*" + dlWhere = "date >= '2020-01-01'" + dlLimit = 1000 + }, + want: `SELECT * FROM "crsp"."dsf" WHERE date >= '2020-01-01' LIMIT 1000`, + }, + { + name: "missing schema and table", + setup: func() { + dlQuery = "" + dlSchema = "" + dlTable = "" + }, + wantErr: true, + }, + { + name: "column with spaces trimmed", + setup: func() { + dlQuery = "" + dlSchema = "crsp" + dlTable = "dsf" + dlColumns = " permno , date , prc " + dlWhere = "" + dlLimit = 0 + }, + want: `SELECT "permno", "date", "prc" FROM "crsp"."dsf"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setup() + got, err := buildQuery() + if (err != nil) != tt.wantErr { + t.Fatalf("buildQuery() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr && got != tt.want { + t.Errorf("buildQuery() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestResolveFormat(t *testing.T) { + tests := []struct { + path string + flag string + want string + }{ + {"out.parquet", "", "parquet"}, + {"out.csv", "", "csv"}, + {"out.CSV", "", "csv"}, + {"out.parquet", "csv", "csv"}, + {"out.csv", "parquet", "parquet"}, + {"out.txt", "", "parquet"}, + {"out", "CSV", "csv"}, + } + + for _, tt := range tests { + t.Run(tt.path+"_"+tt.flag, func(t *testing.T) { + got := resolveFormat(tt.path, tt.flag) + if got != tt.want { + t.Errorf("resolveFormat(%q, %q) = %q, want %q", tt.path, tt.flag, got, tt.want) + } + }) + } +} diff --git a/cmd/info.go b/cmd/info.go @@ -0,0 +1,138 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "os" + "text/tabwriter" + "time" + + "github.com/louloulibs/wrds-download/internal/config" + "github.com/louloulibs/wrds-download/internal/db" + "github.com/spf13/cobra" +) + +var ( + infoSchema string + infoTable string + infoJSON bool +) + +var infoCmd = &cobra.Command{ + Use: "info", + Short: "Show table metadata (columns, types, row count)", + Long: `Display metadata for a WRDS table: comment, estimated row count, +size, and column details (name, type, nullable, description). + +Examples: + wrds-dl info --schema crsp --table dsf + wrds-dl info --schema comp --table funda --json`, + RunE: runInfo, +} + +func init() { + rootCmd.AddCommand(infoCmd) + + f := infoCmd.Flags() + f.StringVar(&infoSchema, "schema", "", "Schema name (required)") + f.StringVar(&infoTable, "table", "", "Table name (required)") + f.BoolVar(&infoJSON, "json", false, "Output as JSON") + + _ = infoCmd.MarkFlagRequired("schema") + _ = infoCmd.MarkFlagRequired("table") +} + +func runInfo(cmd *cobra.Command, args []string) error { + config.ApplyCredentials() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + client, err := db.New(ctx) + if err != nil { + return fmt.Errorf("connect: %w", err) + } + defer client.Close() + + meta, err := client.TableMeta(ctx, infoSchema, infoTable) + if err != nil { + return fmt.Errorf("table meta: %w", err) + } + + if infoJSON { + return printInfoJSON(meta) + } + return printInfoTable(meta) +} + +type jsonColumn struct { + Name string `json:"name"` + Type string `json:"type"` + Nullable bool `json:"nullable"` + Description string `json:"description,omitempty"` +} + +type jsonInfo struct { + Schema string `json:"schema"` + Table string `json:"table"` + Comment string `json:"comment,omitempty"` + RowCount int64 `json:"row_count"` + Size string `json:"size,omitempty"` + Columns []jsonColumn `json:"columns"` +} + +func printInfoJSON(meta *db.TableMeta) error { + info := jsonInfo{ + Schema: meta.Schema, + Table: meta.Table, + Comment: meta.Comment, + RowCount: meta.RowCount, + Size: meta.Size, + Columns: make([]jsonColumn, len(meta.Columns)), + } + for i, c := range meta.Columns { + info.Columns[i] = jsonColumn{ + Name: c.Name, + Type: c.DataType, + Nullable: c.Nullable, + Description: c.Description, + } + } + + enc := json.NewEncoder(os.Stdout) + enc.SetIndent("", " ") + return enc.Encode(info) +} + +func printInfoTable(meta *db.TableMeta) error { + fmt.Fprintf(os.Stdout, "%s.%s\n", meta.Schema, meta.Table) + if meta.Comment != "" { + fmt.Fprintf(os.Stdout, " %s\n", meta.Comment) + } + if meta.RowCount > 0 || meta.Size != "" { + parts := "" + if meta.RowCount > 0 { + parts += fmt.Sprintf("~%d rows", meta.RowCount) + } + if meta.Size != "" { + if parts != "" { + parts += ", " + } + parts += meta.Size + } + fmt.Fprintf(os.Stdout, " %s\n", parts) + } + fmt.Fprintln(os.Stdout) + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "NAME\tTYPE\tNULLABLE\tDESCRIPTION") + for _, c := range meta.Columns { + nullable := "NO" + if c.Nullable { + nullable = "YES" + } + fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", c.Name, c.DataType, nullable, c.Description) + } + return w.Flush() +} diff --git a/internal/db/meta.go b/internal/db/meta.go @@ -175,7 +175,7 @@ func (c *Client) Preview(ctx context.Context, schema, table string, limit int) ( limit = 50 } - qualified := fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table)) + qualified := fmt.Sprintf("%s.%s", QuoteIdent(schema), QuoteIdent(table)) // Estimated count via pg stats (fast). var total int64 @@ -220,6 +220,8 @@ func (c *Client) Preview(ctx context.Context, schema, table string, limit int) ( return &result, rows.Err() } -func quoteIdent(s string) string { +// QuoteIdent quotes a PostgreSQL identifier (schema, table, column name) +// to prevent SQL injection. +func QuoteIdent(s string) string { return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` } diff --git a/internal/db/meta_test.go b/internal/db/meta_test.go @@ -0,0 +1,26 @@ +package db + +import "testing" + +func TestQuoteIdent(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"crsp", `"crsp"`}, + {"dsf", `"dsf"`}, + {`my"table`, `"my""table"`}, + {"", `""`}, + {"with space", `"with space"`}, + {`double""quote`, `"double""""quote"`}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := QuoteIdent(tt.input) + if got != tt.want { + t.Errorf("QuoteIdent(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/export/export.go b/internal/export/export.go @@ -20,7 +20,8 @@ import ( // Options controls the export behaviour. type Options struct { - Format string // "parquet" or "csv" + Format string // "parquet" or "csv" + ProgressFunc func(rows int) // called periodically with total rows exported so far } const rowGroupSize = 10_000 @@ -57,14 +58,14 @@ func Export(query, outPath string, opts Options) error { switch format { case "csv": - return writeCSV(rows, outPath) + return writeCSV(rows, outPath, opts.ProgressFunc) default: - return writeParquet(rows, outPath) + return writeParquet(rows, outPath, opts.ProgressFunc) } } // writeCSV streams rows into a CSV file with a header row. -func writeCSV(rows pgx.Rows, outPath string) error { +func writeCSV(rows pgx.Rows, outPath string, progressFn func(int)) error { f, err := os.Create(outPath) if err != nil { return fmt.Errorf("create csv: %w", err) @@ -84,6 +85,7 @@ func writeCSV(rows pgx.Rows, outPath string) error { } record := make([]string, len(fds)) + total := 0 for rows.Next() { vals, err := rows.Values() if err != nil { @@ -95,6 +97,10 @@ func writeCSV(rows pgx.Rows, outPath string) error { if err := w.Write(record); err != nil { return fmt.Errorf("write row: %w", err) } + total++ + if progressFn != nil && total%rowGroupSize == 0 { + progressFn(total) + } } if err := rows.Err(); err != nil { return fmt.Errorf("rows: %w", err) @@ -105,7 +111,7 @@ func writeCSV(rows pgx.Rows, outPath string) error { } // writeParquet streams rows into a Parquet file using parquet-go. -func writeParquet(rows pgx.Rows, outPath string) error { +func writeParquet(rows pgx.Rows, outPath string, progressFn func(int)) error { fds := rows.FieldDescriptions() schema, colTypes := buildParquetSchema(fds) @@ -123,6 +129,7 @@ func writeParquet(rows pgx.Rows, outPath string) error { ) buf := make([]map[string]any, 0, rowGroupSize) + total := 0 for rows.Next() { vals, err := rows.Values() @@ -140,7 +147,11 @@ func writeParquet(rows pgx.Rows, outPath string) error { if _, err := writer.Write(buf); err != nil { return fmt.Errorf("write row group: %w", err) } + total += len(buf) buf = buf[:0] + if progressFn != nil { + progressFn(total) + } } } if err := rows.Err(); err != nil { diff --git a/internal/export/export_test.go b/internal/export/export_test.go @@ -0,0 +1,81 @@ +package export + +import ( + "math/big" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" +) + +func TestFormatValue(t *testing.T) { + tests := []struct { + name string + v any + want string + }{ + {"nil", nil, ""}, + {"string", "hello", "hello"}, + {"bytes", []byte("data"), "data"}, + {"int", 42, "42"}, + {"float", 3.14, "3.14"}, + {"bool", true, "true"}, + {"date only", time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC), "2020-01-02"}, + {"datetime", time.Date(2020, 1, 2, 15, 4, 5, 0, time.UTC), "2020-01-02T15:04:05Z"}, + {"numeric valid", pgtype.Numeric{Int: big.NewInt(12345), Exp: -2, Valid: true}, "123.45"}, + {"numeric invalid", pgtype.Numeric{Valid: false}, ""}, + {"numeric NaN", pgtype.Numeric{Valid: true, NaN: true}, "NaN"}, + {"numeric Infinity", pgtype.Numeric{Valid: true, InfinityModifier: pgtype.Infinity}, "Infinity"}, + {"numeric -Infinity", pgtype.Numeric{Valid: true, InfinityModifier: pgtype.NegativeInfinity}, "-Infinity"}, + {"numeric zero exp", pgtype.Numeric{Int: big.NewInt(42), Exp: 0, Valid: true}, "42"}, + {"numeric positive exp", pgtype.Numeric{Int: big.NewInt(5), Exp: 3, Valid: true}, "5000"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatValue(tt.v) + if got != tt.want { + t.Errorf("formatValue(%v) = %q, want %q", tt.v, got, tt.want) + } + }) + } +} + +func TestConvertValue(t *testing.T) { + epoch := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) + + tests := []struct { + name string + v any + ct colType + want any + }{ + {"nil", nil, colString, nil}, + {"bool true", true, colBool, true}, + {"bool false", false, colBool, false}, + {"int16 to int32", int16(42), colInt32, int32(42)}, + {"int32 passthrough", int32(100), colInt32, int32(100)}, + {"int64 to int32", int64(200), colInt32, int32(200)}, + {"int64 passthrough", int64(999), colInt64, int64(999)}, + {"int32 to int64", int32(50), colInt64, int64(50)}, + {"int16 to int64", int16(10), colInt64, int64(10)}, + {"float32 passthrough", float32(1.5), colFloat32, float32(1.5)}, + {"float64 to float32", float64(2.5), colFloat32, float32(2.5)}, + {"float64 passthrough", float64(3.14), colFloat64, float64(3.14)}, + {"float32 to float64", float32(1.5), colFloat64, float64(1.5)}, + {"date", time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC), colDate, + int32(time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC).Sub(epoch).Hours() / 24)}, + {"timestamp", time.Date(2020, 6, 15, 12, 30, 0, 0, time.UTC), colTimestamp, + time.Date(2020, 6, 15, 12, 30, 0, 0, time.UTC).Sub(epoch).Microseconds()}, + {"string passthrough", "hello", colString, "hello"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := convertValue(tt.v, tt.ct) + if got != tt.want { + t.Errorf("convertValue(%v, %v) = %v (%T), want %v (%T)", tt.v, tt.ct, got, got, tt.want, tt.want) + } + }) + } +} diff --git a/internal/tui/app.go b/internal/tui/app.go @@ -45,6 +45,7 @@ type tablesLoadedMsg struct{ tables []db.Table } type metaLoadedMsg struct{ meta *db.TableMeta } type errMsg struct{ err error } type downloadDoneMsg struct{ path string } +type downloadProgressMsg struct{ rows int } type tickMsg time.Time type loginSuccessMsg struct{ client *db.Client } type loginFailMsg struct{ err error } @@ -89,6 +90,8 @@ type App struct { currentDatabase string selectedSchema string selectedTable string + downloadRows int + dlProgressCh chan int } func newPreviewFilter() textinput.Model { @@ -228,21 +231,52 @@ func (a *App) loadMeta(schema, tbl string) tea.Cmd { } func (a *App) startDownload(msg DlSubmitMsg) tea.Cmd { - return func() tea.Msg { + progressCh := make(chan int, 1) + a.dlProgressCh = progressCh + + download := func() tea.Msg { sel := "*" if msg.Columns != "" && msg.Columns != "*" { - sel = msg.Columns + parts := strings.Split(msg.Columns, ",") + quoted := make([]string, len(parts)) + for i, p := range parts { + quoted[i] = db.QuoteIdent(strings.TrimSpace(p)) + } + sel = strings.Join(quoted, ", ") } - query := fmt.Sprintf("SELECT %s FROM %s.%s", sel, msg.Schema, msg.Table) + query := fmt.Sprintf("SELECT %s FROM %s.%s", sel, db.QuoteIdent(msg.Schema), db.QuoteIdent(msg.Table)) if msg.Where != "" { query += " WHERE " + msg.Where } - err := export.Export(query, msg.Out, export.Options{Format: msg.Format}) + if msg.Limit > 0 { + query += fmt.Sprintf(" LIMIT %d", msg.Limit) + } + opts := export.Options{ + Format: msg.Format, + ProgressFunc: func(rows int) { + select { + case progressCh <- rows: + default: + } + }, + } + err := export.Export(query, msg.Out, opts) + close(progressCh) if err != nil { return errMsg{err} } return downloadDoneMsg{msg.Out} } + + listenProgress := func() tea.Msg { + rows, ok := <-progressCh + if !ok { + return nil + } + return downloadProgressMsg{rows} + } + + return tea.Batch(download, listenProgress) } func (a *App) attemptLogin(msg LoginSubmitMsg) tea.Cmd { @@ -391,9 +425,22 @@ func (a *App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.state = stateBrowse return a, nil + case downloadProgressMsg: + a.downloadRows = msg.rows + // Keep listening for more progress updates. + ch := a.dlProgressCh + return a, func() tea.Msg { + rows, ok := <-ch + if !ok { + return nil + } + return downloadProgressMsg{rows} + } + case downloadDoneMsg: a.statusOK = fmt.Sprintf("Saved: %s", msg.path) a.state = stateDone + a.downloadRows = 0 return a, nil case DlCancelMsg: @@ -404,6 +451,7 @@ func (a *App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.state = stateDownloading a.statusErr = "" a.statusOK = "" + a.downloadRows = 0 return a, tea.Batch(a.startDownload(msg), a.spinner.Tick) case list.FilterMatchesMsg: @@ -648,8 +696,11 @@ func (a *App) View() string { return overlayCenter(full, overlay, a.width, a.height) } if a.state == stateDownloading { - msg := a.spinner.View() + " Downloading…" - return overlayCenter(full, stylePanelFocused.Padding(1, 3).Render(msg), a.width, a.height) + dlMsg := a.spinner.View() + " Downloading…" + if a.downloadRows > 0 { + dlMsg = a.spinner.View() + fmt.Sprintf(" Downloading… %s rows exported", formatCount(int64(a.downloadRows))) + } + return overlayCenter(full, stylePanelFocused.Padding(1, 3).Render(dlMsg), a.width, a.height) } if a.state == stateDone { msg := styleSuccess.Render("✓ ") + a.statusOK + "\n\n" + styleStatusBar.Render("[esc] dismiss") diff --git a/internal/tui/dlform.go b/internal/tui/dlform.go @@ -2,6 +2,7 @@ package tui import ( "fmt" + "strconv" "strings" "github.com/charmbracelet/bubbles/textinput" @@ -14,6 +15,7 @@ type dlFormField int const ( fieldSelect dlFormField = iota fieldWhere + fieldLimit fieldOut fieldFormat fieldCount @@ -34,6 +36,7 @@ type DlSubmitMsg struct { Table string Columns string Where string + Limit int Out string Format string } @@ -61,6 +64,10 @@ func newDlForm(schema, table string, colNames []string) DlForm { f.inputs[fieldWhere].Placeholder = "e.g. date >= '2020-01-01'" f.inputs[fieldWhere].CharLimit = 512 + f.inputs[fieldLimit] = textinput.New() + f.inputs[fieldLimit].Placeholder = "no limit" + f.inputs[fieldLimit].CharLimit = 12 + f.inputs[fieldOut] = textinput.New() f.inputs[fieldOut].Placeholder = fmt.Sprintf("./%s_%s.parquet", schema, table) f.inputs[fieldOut].CharLimit = 256 @@ -101,12 +108,17 @@ func (f DlForm) Update(msg tea.Msg) (DlForm, tea.Cmd) { if columns == "" { columns = "*" } + var limit int + if v := strings.TrimSpace(f.inputs[fieldLimit].Value()); v != "" { + limit, _ = strconv.Atoi(v) + } return f, func() tea.Msg { return DlSubmitMsg{ Schema: f.schema, Table: f.table, Columns: columns, Where: f.inputs[fieldWhere].Value(), + Limit: limit, Out: out, Format: format, } @@ -135,7 +147,7 @@ func (f DlForm) View(width int) string { title := stylePanelHeader.Render(fmt.Sprintf("Download %s.%s", f.schema, f.table)) sb.WriteString(title + "\n\n") - labels := []string{"SELECT columns", "WHERE clause", "Output path", "Format (parquet/csv)"} + labels := []string{"SELECT columns", "WHERE clause", "LIMIT rows", "Output path", "Format (parquet/csv)"} for i, label := range labels { style := lipgloss.NewStyle().Foreground(colorMuted) if dlFormField(i) == f.focused {