wrds-download

TUI/CLI tool for browsing and downloading WRDS data
Log | Files | Refs | README

commit 50acd5dd300ab27ade59be50b463ce63fe3c525c
parent 916fb139cc5c14a3d2f513fff60c7678e6afaaaf
Author: Erik Loualiche <[email protected]>
Date:   Fri, 20 Feb 2026 09:13:52 -0600

Merge pull request #1 from eloualiche/fix/preview-panel

Replace data preview with metadata column catalog
Diffstat:
Mcmd/download.go | 3+++
Mcmd/tui.go | 24+++++++++++++++++++++---
Ainternal/config/config.go | 79+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Minternal/db/client.go | 58+++++++++++++++++++++++++++++++++++++++++++++++++++-------
Minternal/db/meta.go | 63+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Minternal/tui/app.go | 613+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------
Ainternal/tui/loginform.go | 192+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
7 files changed, 935 insertions(+), 97 deletions(-)

diff --git a/cmd/download.go b/cmd/download.go @@ -5,6 +5,7 @@ import ( "os" "strings" + "github.com/eloualiche/wrds-download/internal/config" "github.com/eloualiche/wrds-download/internal/export" "github.com/spf13/cobra" ) @@ -47,6 +48,8 @@ func init() { } func runDownload(cmd *cobra.Command, args []string) error { + config.ApplyCredentials() + query, err := buildQuery() if err != nil { return err diff --git a/cmd/tui.go b/cmd/tui.go @@ -5,9 +5,10 @@ import ( "fmt" "os" + tea "github.com/charmbracelet/bubbletea" + "github.com/eloualiche/wrds-download/internal/config" "github.com/eloualiche/wrds-download/internal/db" "github.com/eloualiche/wrds-download/internal/tui" - tea "github.com/charmbracelet/bubbletea" "github.com/spf13/cobra" ) @@ -22,18 +23,35 @@ func init() { } func runTUI(cmd *cobra.Command, args []string) error { + config.ApplyCredentials() + ctx := context.Background() client, err := db.New(ctx) if err != nil { - return fmt.Errorf("connect to WRDS: %w", err) + // Launch TUI in login mode instead of crashing + m := tui.NewAppNoClient() + p := tea.NewProgram(m, tea.WithAltScreen()) + final, err := p.Run() + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + if a, ok := final.(*tui.App); ok && a.Err() != "" { + fmt.Fprintln(os.Stderr, a.Err()) + } + return nil } defer client.Close() m := tui.NewApp(client) p := tea.NewProgram(m, tea.WithAltScreen()) - if _, err := p.Run(); err != nil { + final, err := p.Run() + if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } + if a, ok := final.(*tui.App); ok && a.Err() != "" { + fmt.Fprintln(os.Stderr, a.Err()) + } return nil } diff --git a/internal/config/config.go b/internal/config/config.go @@ -0,0 +1,79 @@ +package config + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "strings" +) + +// CredentialsPath returns the path to the credentials file, +// respecting $XDG_CONFIG_HOME with fallback to ~/.config/wrds-dl/credentials. +func CredentialsPath() string { + base := os.Getenv("XDG_CONFIG_HOME") + if base == "" { + home, _ := os.UserHomeDir() + base = filepath.Join(home, ".config") + } + return filepath.Join(base, "wrds-dl", "credentials") +} + +// LoadCredentials reads PGUSER, PGPASSWORD, and PGDATABASE from the credentials file. +func LoadCredentials() (user, password, database string, err error) { + f, err := os.Open(CredentialsPath()) + if err != nil { + return "", "", "", err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + key, val, ok := strings.Cut(line, "=") + if !ok { + continue + } + switch strings.TrimSpace(key) { + case "PGUSER": + user = strings.TrimSpace(val) + case "PGPASSWORD": + password = strings.TrimSpace(val) + case "PGDATABASE": + database = strings.TrimSpace(val) + } + } + return user, password, database, scanner.Err() +} + +// SaveCredentials writes PGUSER, PGPASSWORD, and PGDATABASE to the credentials file. +// It creates the parent directories with 0700 and the file with 0600 permissions. +func SaveCredentials(user, password, database string) error { + path := CredentialsPath() + if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { + return fmt.Errorf("create config dir: %w", err) + } + content := fmt.Sprintf("PGUSER=%s\nPGPASSWORD=%s\nPGDATABASE=%s\n", user, password, database) + return os.WriteFile(path, []byte(content), 0600) +} + +// ApplyCredentials loads credentials from the config file and sets +// environment variables for any values not already set. +func ApplyCredentials() { + user, password, database, err := LoadCredentials() + if err != nil { + return // file doesn't exist or unreadable — silently skip + } + if os.Getenv("PGUSER") == "" && user != "" { + os.Setenv("PGUSER", user) + } + if os.Getenv("PGPASSWORD") == "" && password != "" { + os.Setenv("PGPASSWORD", password) + } + if os.Getenv("PGDATABASE") == "" && database != "" { + os.Setenv("PGDATABASE", database) + } +} diff --git a/internal/db/client.go b/internal/db/client.go @@ -2,6 +2,7 @@ package db import ( "context" + "errors" "fmt" "os" "strconv" @@ -9,25 +10,35 @@ import ( "github.com/jackc/pgx/v5/pgxpool" ) +// ErrNoUser is returned when PGUSER is not set. +var ErrNoUser = errors.New("PGUSER not set") + // Client wraps a pgx connection pool. type Client struct { Pool *pgxpool.Pool } // DSNFromEnv builds a PostgreSQL DSN from standard PG environment variables. -func DSNFromEnv() string { +// Returns ("", ErrNoUser) if PGUSER is empty. +func DSNFromEnv() (string, error) { host := getenv("PGHOST", "wrds-pgdata.wharton.upenn.edu") port := getenv("PGPORT", "9737") user := getenv("PGUSER", "") password := getenv("PGPASSWORD", "") - database := getenv("PGDATABASE", user) // WRDS default db = username + database := getenv("PGDATABASE", "wrds") + + if user == "" { + return "", ErrNoUser + } + dsn := fmt.Sprintf("host=%s port=%s user=%s sslmode=require", host, port, user) if password != "" { - return fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=require", - host, port, user, password, database) + dsn += fmt.Sprintf(" password=%s", password) + } + if database != "" { + dsn += fmt.Sprintf(" dbname=%s", database) } - return fmt.Sprintf("host=%s port=%s user=%s dbname=%s sslmode=require", - host, port, user, database) + return dsn, nil } // PortFromEnv returns the port as an integer (for DuckDB attach). @@ -49,7 +60,10 @@ func getenv(key, fallback string) string { // New creates and pings a pgx pool using DSNFromEnv. func New(ctx context.Context) (*Client, error) { - dsn := DSNFromEnv() + dsn, err := DSNFromEnv() + if err != nil { + return nil, err + } pool, err := pgxpool.New(ctx, dsn) if err != nil { return nil, fmt.Errorf("pgxpool.New: %w", err) @@ -61,6 +75,36 @@ func New(ctx context.Context) (*Client, error) { return &Client{Pool: pool}, nil } +// NewWithCredentials sets PGUSER/PGPASSWORD/PGDATABASE env vars then creates and pings a pool. +func NewWithCredentials(ctx context.Context, user, password, database string) (*Client, error) { + os.Setenv("PGUSER", user) + os.Setenv("PGPASSWORD", password) + if database != "" { + os.Setenv("PGDATABASE", database) + } + return New(ctx) +} + +// Databases returns the list of connectable databases. +func (c *Client) Databases(ctx context.Context) ([]string, error) { + rows, err := c.Pool.Query(ctx, + "SELECT datname FROM pg_database WHERE datallowconn = true ORDER BY datname") + if err != nil { + return nil, fmt.Errorf("databases query: %w", err) + } + defer rows.Close() + + var dbs []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + dbs = append(dbs, name) + } + return dbs, rows.Err() +} + // Close releases the pool. func (c *Client) Close() { c.Pool.Close() diff --git a/internal/db/meta.go b/internal/db/meta.go @@ -23,6 +23,24 @@ type Column struct { DataType string } +// ColumnMeta holds catalog metadata about a single column. +type ColumnMeta struct { + Name string + DataType string + Nullable bool + Description string // from pg_description (WRDS variable label) +} + +// TableMeta holds catalog metadata for a table (no data scan required). +type TableMeta struct { + Schema string + Table string + Comment string // table-level comment from pg_description + RowCount int64 // estimated from pg_class.reltuples + Size string // human-readable, from pg_size_pretty + Columns []ColumnMeta +} + // PreviewResult holds sample rows and row count for a table. type PreviewResult struct { Columns []string @@ -106,6 +124,51 @@ func (c *Client) Columns(ctx context.Context, schema, table string) ([]Column, e return cols, rows.Err() } +// TableMeta fetches catalog metadata for a table: column info with +// descriptions, estimated row count, and table size. All queries hit +// pg_catalog only — no table data is scanned. +func (c *Client) TableMeta(ctx context.Context, schema, table string) (*TableMeta, error) { + meta := &TableMeta{Schema: schema, Table: table} + + // Table-level stats (best effort — some may require permissions). + _ = c.Pool.QueryRow(ctx, ` + SELECT c.reltuples::bigint, + COALESCE(pg_size_pretty(pg_total_relation_size(c.oid)), ''), + COALESCE(obj_description(c.oid), '') + FROM pg_class c + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = $1 AND c.relname = $2 + `, schema, table).Scan(&meta.RowCount, &meta.Size, &meta.Comment) + + // Column metadata with descriptions from pg_description. + rows, err := c.Pool.Query(ctx, ` + SELECT a.attname, + pg_catalog.format_type(a.atttypid, a.atttypmod), + NOT a.attnotnull, + COALESCE(d.description, '') + FROM pg_attribute a + JOIN pg_class c ON a.attrelid = c.oid + JOIN pg_namespace n ON c.relnamespace = n.oid + LEFT JOIN pg_description d ON d.objoid = c.oid AND d.objsubid = a.attnum + WHERE n.nspname = $1 AND c.relname = $2 + AND a.attnum > 0 AND NOT a.attisdropped + ORDER BY a.attnum + `, schema, table) + if err != nil { + return nil, fmt.Errorf("table meta: %w", err) + } + defer rows.Close() + + for rows.Next() { + var col ColumnMeta + if err := rows.Scan(&col.Name, &col.DataType, &col.Nullable, &col.Description); err != nil { + return nil, err + } + meta.Columns = append(meta.Columns, col) + } + return meta, rows.Err() +} + // Preview fetches the first `limit` rows and an estimated row count. func (c *Client) Preview(ctx context.Context, schema, table string, limit int) (*PreviewResult, error) { if limit <= 0 { diff --git a/internal/tui/app.go b/internal/tui/app.go @@ -3,14 +3,16 @@ package tui import ( "context" "fmt" + "os" "strings" "time" "github.com/charmbracelet/bubbles/list" "github.com/charmbracelet/bubbles/spinner" - "github.com/charmbracelet/bubbles/table" + "github.com/charmbracelet/bubbles/textinput" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/eloualiche/wrds-download/internal/config" "github.com/eloualiche/wrds-download/internal/db" "github.com/eloualiche/wrds-download/internal/export" ) @@ -28,7 +30,9 @@ const ( type appState int const ( - stateBrowse appState = iota + stateLogin appState = iota + stateBrowse + stateDatabaseSelect stateDownloadForm stateDownloading stateDone @@ -38,10 +42,15 @@ const ( type schemasLoadedMsg struct{ schemas []db.Schema } type tablesLoadedMsg struct{ tables []db.Table } -type previewLoadedMsg struct{ result *db.PreviewResult } +type metaLoadedMsg struct{ meta *db.TableMeta } type errMsg struct{ err error } type downloadDoneMsg struct{ path string } type tickMsg time.Time +type loginSuccessMsg struct{ client *db.Client } +type loginFailMsg struct{ err error } +type databasesLoadedMsg struct{ databases []string } +type databaseSwitchedMsg struct{ client *db.Client } +type databaseSwitchFailMsg struct{ err error } func errCmd(err error) tea.Cmd { return func() tea.Msg { return errMsg{err} } @@ -62,50 +71,124 @@ type App struct { focus pane state appState - schemaList list.Model - tableList list.Model - previewTbl table.Model - previewInfo string // "~2.1M rows" etc. - - dlForm DlForm - spinner spinner.Model - statusOK string + schemaList list.Model + tableList list.Model + previewMeta *db.TableMeta + previewScroll int + previewFilter textinput.Model + previewFiltering bool + + loginForm LoginForm + loginErr string + dlForm DlForm + dbList list.Model + spinner spinner.Model + statusOK string statusErr string - selectedSchema string - selectedTable string + currentDatabase string + selectedSchema string + selectedTable string +} + +func newPreviewFilter() textinput.Model { + pf := textinput.New() + pf.Prompt = "/ " + pf.Placeholder = "filter columns…" + pf.CharLimit = 64 + return pf } // NewApp constructs the root model. func NewApp(client *db.Client) *App { del := list.NewDefaultDelegate() del.ShowDescription = false + del.SetSpacing(0) schemaList := list.New(nil, del, 0, 0) schemaList.Title = "Schemas" schemaList.SetShowStatusBar(false) schemaList.SetFilteringEnabled(true) + schemaList.DisableQuitKeybindings() + schemaList.Styles.TitleBar = schemaList.Styles.TitleBar.Padding(0, 0, 0, 2) tableList := list.New(nil, del, 0, 0) tableList.Title = "Tables" tableList.SetShowStatusBar(false) tableList.SetFilteringEnabled(true) + tableList.DisableQuitKeybindings() + tableList.Styles.TitleBar = tableList.Styles.TitleBar.Padding(0, 0, 0, 2) + + dbList := list.New(nil, del, 0, 0) + dbList.Title = "Databases" + dbList.SetShowStatusBar(false) + dbList.SetFilteringEnabled(true) + dbList.DisableQuitKeybindings() + dbList.Styles.TitleBar = dbList.Styles.TitleBar.Padding(0, 0, 0, 2) sp := spinner.New() sp.Spinner = spinner.Dot return &App{ - client: client, - schemaList: schemaList, - tableList: tableList, - spinner: sp, - focus: paneSchema, - state: stateBrowse, + client: client, + currentDatabase: os.Getenv("PGDATABASE"), + schemaList: schemaList, + tableList: tableList, + dbList: dbList, + spinner: sp, + previewFilter: newPreviewFilter(), + focus: paneSchema, + state: stateBrowse, } } -// Init loads schemas on startup. +// NewAppNoClient creates an App in login state (no DB connection yet). +func NewAppNoClient() *App { + del := list.NewDefaultDelegate() + del.ShowDescription = false + del.SetSpacing(0) + + schemaList := list.New(nil, del, 0, 0) + schemaList.Title = "Schemas" + schemaList.SetShowStatusBar(false) + schemaList.SetFilteringEnabled(true) + schemaList.DisableQuitKeybindings() + schemaList.Styles.TitleBar = schemaList.Styles.TitleBar.Padding(0, 0, 0, 2) + + tableList := list.New(nil, del, 0, 0) + tableList.Title = "Tables" + tableList.SetShowStatusBar(false) + tableList.SetFilteringEnabled(true) + tableList.DisableQuitKeybindings() + tableList.Styles.TitleBar = tableList.Styles.TitleBar.Padding(0, 0, 0, 2) + + dbList := list.New(nil, del, 0, 0) + dbList.Title = "Databases" + dbList.SetShowStatusBar(false) + dbList.SetFilteringEnabled(true) + dbList.DisableQuitKeybindings() + dbList.Styles.TitleBar = dbList.Styles.TitleBar.Padding(0, 0, 0, 2) + + sp := spinner.New() + sp.Spinner = spinner.Dot + + return &App{ + schemaList: schemaList, + tableList: tableList, + dbList: dbList, + spinner: sp, + previewFilter: newPreviewFilter(), + focus: paneSchema, + state: stateLogin, + loginForm: newLoginForm(), + } +} + +// Init loads schemas on startup, or starts login form blink if in login state. func (a *App) Init() tea.Cmd { + if a.state == stateLogin { + return textinput.Blink + } return tea.Batch( a.loadSchemas(), a.spinner.Tick, @@ -132,13 +215,15 @@ func (a *App) loadTables(schema string) tea.Cmd { } } -func (a *App) loadPreview(schema, tbl string) tea.Cmd { +func (a *App) loadMeta(schema, tbl string) tea.Cmd { return func() tea.Msg { - result, err := a.client.Preview(context.Background(), schema, tbl, 50) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + meta, err := a.client.TableMeta(ctx, schema, tbl) if err != nil { return errMsg{err} } - return previewLoadedMsg{result} + return metaLoadedMsg{meta} } } @@ -158,6 +243,60 @@ func (a *App) startDownload(msg DlSubmitMsg) tea.Cmd { } } +func (a *App) attemptLogin(msg LoginSubmitMsg) tea.Cmd { + return func() tea.Msg { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + client, err := db.NewWithCredentials(ctx, msg.User, msg.Password, msg.Database) + if err != nil { + return loginFailMsg{err} + } + if msg.Save { + _ = config.SaveCredentials(msg.User, msg.Password, msg.Database) + } + return loginSuccessMsg{client} + } +} + +func (a *App) loadDatabases() tea.Cmd { + return func() tea.Msg { + dbs, err := a.client.Databases(context.Background()) + if err != nil { + return errMsg{err} + } + return databasesLoadedMsg{dbs} + } +} + +func (a *App) switchDatabase(name string) tea.Cmd { + return func() tea.Msg { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + user := os.Getenv("PGUSER") + password := os.Getenv("PGPASSWORD") + client, err := db.NewWithCredentials(ctx, user, password, name) + if err != nil { + return databaseSwitchFailMsg{err} + } + return databaseSwitchedMsg{client} + } +} + +// friendlyError extracts a short, readable message from verbose pgx errors. +func friendlyError(err error) string { + s := err.Error() + // pgx errors look like: "ping: failed to connect to `host=... user=...`: <reason>" + // Extract just the reason after the last colon-space following the backtick-quoted section. + if idx := strings.LastIndex(s, "`: "); idx != -1 { + return s[idx+3:] + } + // Fall back to stripping common prefixes. + for _, prefix := range []string{"ping: ", "pgxpool.New: "} { + s = strings.TrimPrefix(s, prefix) + } + return s +} + // Update handles all messages. func (a *App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { @@ -187,40 +326,66 @@ func (a *App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { items[i] = item{t.Name} } a.tableList.SetItems(items) - a.previewTbl = table.Model{} // clear preview - a.previewInfo = "" + a.previewMeta = nil + a.previewScroll = 0 + a.previewFilter.SetValue("") + a.previewFiltering = false return a, nil - case previewLoadedMsg: - r := msg.result - cols := make([]table.Column, len(r.Columns)) - for i, c := range r.Columns { - w := maxWidth(c, r.Rows, i, 20) - cols[i] = table.Column{Title: c, Width: w} - } - rows := make([]table.Row, len(r.Rows)) - for i, row := range r.Rows { - rows[i] = table.Row(row) - } - t := table.New( - table.WithColumns(cols), - table.WithRows(rows), - table.WithFocused(false), - table.WithHeight(a.previewHeight()-4), - ) - ts := table.DefaultStyles() - ts.Header = ts.Header.BorderStyle(lipgloss.NormalBorder()).BorderForeground(colorMuted).BorderBottom(true).Bold(true) - ts.Selected = ts.Selected.Foreground(colorPrimary).Bold(false) - t.SetStyles(ts) - - a.previewTbl = t - if r.Total > 0 { - a.previewInfo = fmt.Sprintf("~%s rows", formatCount(r.Total)) + case metaLoadedMsg: + a.previewMeta = msg.meta + a.previewScroll = 0 + a.previewFilter.SetValue("") + a.previewFiltering = false + return a, nil + + case LoginSubmitMsg: + a.loginErr = "" + return a, a.attemptLogin(msg) + + case LoginCancelMsg: + return a, tea.Quit + + case loginSuccessMsg: + a.client = msg.client + a.currentDatabase = os.Getenv("PGDATABASE") + a.state = stateBrowse + return a, tea.Batch(a.loadSchemas(), a.spinner.Tick) + + case loginFailMsg: + a.loginErr = friendlyError(msg.err) + a.state = stateLogin + return a, nil + + case databasesLoadedMsg: + items := make([]list.Item, len(msg.databases)) + for i, d := range msg.databases { + items[i] = item{d} } + a.dbList.SetItems(items) + a.state = stateDatabaseSelect + return a, nil + + case databaseSwitchedMsg: + a.client.Close() + a.client = msg.client + a.currentDatabase = os.Getenv("PGDATABASE") + a.selectedSchema = "" + a.selectedTable = "" + a.previewMeta = nil + a.previewScroll = 0 + a.previewFilter.SetValue("") + a.tableList.SetItems(nil) + a.state = stateBrowse + return a, a.loadSchemas() + + case databaseSwitchFailMsg: + a.statusErr = friendlyError(msg.err) + a.state = stateBrowse return a, nil case errMsg: - a.statusErr = msg.err.Error() + a.statusErr = friendlyError(msg.err) a.state = stateBrowse return a, nil @@ -239,26 +404,97 @@ func (a *App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.statusOK = "" return a, tea.Batch(a.startDownload(msg), a.spinner.Tick) + case list.FilterMatchesMsg: + // Route async filter results back to the list that initiated filtering. + var cmd tea.Cmd + switch { + case a.schemaList.FilterState() == list.Filtering: + a.schemaList, cmd = a.schemaList.Update(msg) + case a.tableList.FilterState() == list.Filtering: + a.tableList, cmd = a.tableList.Update(msg) + case a.dbList.FilterState() == list.Filtering: + a.dbList, cmd = a.dbList.Update(msg) + } + return a, cmd + case tea.KeyMsg: + if a.state == stateLogin { + var cmd tea.Cmd + a.loginForm, cmd = a.loginForm.Update(msg) + return a, cmd + } if a.state == stateDownloadForm { var cmd tea.Cmd a.dlForm, cmd = a.dlForm.Update(msg) return a, cmd } + if a.state == stateDatabaseSelect { + if a.dbList.FilterState() == list.Filtering { + var cmd tea.Cmd + a.dbList, cmd = a.dbList.Update(msg) + return a, cmd + } + switch msg.String() { + case "esc": + a.state = stateBrowse + return a, nil + case "enter": + if sel := selectedItemTitle(a.dbList); sel != "" { + a.state = stateDownloading + return a, tea.Batch(a.switchDatabase(sel), a.spinner.Tick) + } + } + var cmd tea.Cmd + a.dbList, cmd = a.dbList.Update(msg) + return a, cmd + } + + // Preview column filter: intercept all keys when active. + if a.focus == panePreview && a.previewFiltering { + switch msg.String() { + case "esc": + a.previewFiltering = false + a.previewFilter.SetValue("") + a.previewFilter.Blur() + return a, nil + case "enter": + a.previewFiltering = false + a.previewFilter.Blur() + return a, nil + } + var cmd tea.Cmd + a.previewFilter, cmd = a.previewFilter.Update(msg) + a.previewScroll = 0 + return a, cmd + } switch msg.String() { case "q", "ctrl+c": + if a.focusedListFiltering() { + break // let list handle it + } return a, tea.Quit case "tab": + if a.focusedListFiltering() { + break + } + a.statusErr = "" a.focus = (a.focus + 1) % 3 return a, nil case "shift+tab": + if a.focusedListFiltering() { + break + } + a.statusErr = "" a.focus = (a.focus + 2) % 3 return a, nil - case "enter": + case "right", "l": + if a.focusedListFiltering() { + break + } switch a.focus { case paneSchema: if sel := selectedItemTitle(a.schemaList); sel != "" { @@ -270,19 +506,44 @@ func (a *App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case paneTable: if sel := selectedItemTitle(a.tableList); sel != "" { a.selectedTable = sel + a.previewMeta = nil + a.previewScroll = 0 + a.previewFilter.SetValue("") a.focus = panePreview - return a, a.loadPreview(a.selectedSchema, sel) + return a, a.loadMeta(a.selectedSchema, sel) } } + return a, nil + + case "left", "h": + if a.focusedListFiltering() { + break + } + if a.focus > paneSchema { + a.focus-- + } + return a, nil case "d": + if a.focusedListFiltering() { + break + } if a.selectedSchema != "" && a.selectedTable != "" { a.dlForm = newDlForm(a.selectedSchema, a.selectedTable) a.state = stateDownloadForm return a, nil } + case "b": + if a.focusedListFiltering() { + break + } + return a, a.loadDatabases() + case "esc": + if a.focusedListFiltering() { + break // let list cancel filter + } if a.state == stateDone { a.state = stateBrowse a.statusOK = "" @@ -290,7 +551,7 @@ func (a *App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return a, nil } - // Delegate keyboard events to the focused list. + // All other keys (including enter, /, letters) go to the focused list/pane. var cmd tea.Cmd switch a.focus { case paneSchema: @@ -298,11 +559,32 @@ func (a *App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case paneTable: a.tableList, cmd = a.tableList.Update(msg) case panePreview: - a.previewTbl, cmd = a.previewTbl.Update(msg) + switch msg.String() { + case "/": + a.previewFiltering = true + a.previewFilter.Focus() + cmd = textinput.Blink + case "j", "down": + cols := a.filteredColumns() + if a.previewScroll < len(cols)-1 { + a.previewScroll++ + } + case "k", "up": + if a.previewScroll > 0 { + a.previewScroll-- + } + } } return a, cmd } + // Forward cursor blink messages to the active text input. + if a.previewFiltering { + var cmd tea.Cmd + a.previewFilter, cmd = a.previewFilter.Update(msg) + return a, cmd + } + // Forward spinner ticks when downloading. if a.state == stateDownloading { var cmd tea.Cmd @@ -319,7 +601,15 @@ func (a *App) View() string { return "Loading…" } - header := styleTitle.Render(" WRDS") + styleStatusBar.Render(" Wharton Research Data Services") + if a.state == stateLogin { + return a.loginView() + } + + dbLabel := "" + if a.currentDatabase != "" { + dbLabel = " db:" + a.currentDatabase + } + header := styleTitle.Render(" WRDS") + styleStatusBar.Render(" Wharton Research Data Services"+dbLabel) footer := a.footerView() // Content area height. @@ -327,13 +617,24 @@ func (a *App) View() string { schemaPanelW, tablePanelW, previewPanelW := a.panelWidths() - schemaPanel := a.renderListPanel(a.schemaList, "Schemas", paneSchema, schemaPanelW, contentH) - tablePanel := a.renderListPanel(a.tableList, fmt.Sprintf("Tables (%s)", a.selectedSchema), paneTable, tablePanelW, contentH) + schemaPanel := a.renderListPanel(a.schemaList, "Schemas", paneSchema, schemaPanelW, contentH, 1) + tablePanel := a.renderListPanel(a.tableList, fmt.Sprintf("Tables (%s)", a.selectedSchema), paneTable, tablePanelW, contentH, 1) previewPanel := a.renderPreviewPanel(previewPanelW, contentH) body := lipgloss.JoinHorizontal(lipgloss.Top, schemaPanel, tablePanel, previewPanel) full := lipgloss.JoinVertical(lipgloss.Left, header, body, footer) + if a.state == stateDatabaseSelect { + a.dbList.SetSize(40, a.height/2) + content := a.dbList.View() + hint := styleStatusBar.Render("[enter] switch [esc] cancel [/] filter") + box := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(colorFocus). + Padding(1, 2). + Render(content + "\n" + hint) + return overlayCenter(full, box, a.width, a.height) + } if a.state == stateDownloadForm { overlay := a.dlForm.View(a.width) return overlayCenter(full, overlay, a.width, a.height) @@ -350,23 +651,41 @@ func (a *App) View() string { return full } +func (a *App) loginView() string { + var sb strings.Builder + sb.WriteString(styleTitle.Render(" WRDS") + styleStatusBar.Render(" Wharton Research Data Services") + "\n\n") + sb.WriteString(a.loginForm.View(a.width, a.loginErr)) + return lipgloss.Place(a.width, a.height, lipgloss.Center, lipgloss.Center, sb.String()) +} + func (a *App) footerView() string { - keys := "[tab] switch pane [enter] select [d] download [/] filter [q] quit" - status := "" + keys := "[tab] pane [→/l] select [←/h] back [d] download [b] databases [/] filter [q] quit" + footer := styleStatusBar.Render(keys) if a.statusErr != "" { - status = " " + styleError.Render("Error: "+a.statusErr) + errText := a.statusErr + maxLen := a.width - 12 + if maxLen > 0 && len(errText) > maxLen { + errText = errText[:maxLen-1] + "…" + } + errBar := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#FFFFFF")). + Background(colorError). + Width(a.width). + Padding(0, 1). + Render("Error: " + errText) + footer = errBar + "\n" + footer } - return styleStatusBar.Render(keys + status) + return footer } -func (a *App) renderListPanel(l list.Model, title string, p pane, w, h int) string { - l.SetSize(w-4, h-2) +func (a *App) renderListPanel(l list.Model, title string, p pane, w, h, mr int) string { + l.SetSize(w-2, h-2) content := l.View() style := stylePanelBlurred if a.focus == p { style = stylePanelFocused } - return style.Width(w - 2).Height(h).Render(content) + return style.Width(w - 2).Height(h).MarginRight(mr).Render(content) } func (a *App) renderPreviewPanel(w, h int) string { @@ -377,16 +696,103 @@ func (a *App) renderPreviewPanel(w, h int) string { } sb.WriteString(stylePanelHeader.Render(label) + "\n") - if len(a.previewTbl.Columns()) > 0 { - a.previewTbl.SetHeight(h - 4) - sb.WriteString(a.previewTbl.View()) - if a.previewInfo != "" { - sb.WriteString("\n" + styleRowCount.Render(a.previewInfo)) + contentW := w - 4 // panel border + internal padding + + if a.previewMeta != nil { + meta := a.previewMeta + + // Stats line: "~245.3M rows · 1.2 GB" + var stats []string + if meta.RowCount > 0 { + stats = append(stats, "~"+formatCount(meta.RowCount)+" rows") + } + if meta.Size != "" { + stats = append(stats, meta.Size) + } + if len(stats) > 0 { + sb.WriteString(styleRowCount.Render(strings.Join(stats, " · ")) + "\n") + } + if meta.Comment != "" { + sb.WriteString(styleStatusBar.Render(meta.Comment) + "\n") } + + // Filter bar + if a.previewFiltering { + sb.WriteString(a.previewFilter.View() + "\n") + } else if a.previewFilter.Value() != "" { + sb.WriteString(styleStatusBar.Render("/ "+a.previewFilter.Value()) + "\n") + } + + cols := a.filteredColumns() + + if len(cols) > 0 { + // Calculate column widths from data. + nameW, typeW := len("Column"), len("Type") + for _, c := range cols { + if len(c.Name) > nameW { + nameW = len(c.Name) + } + if len(c.DataType) > typeW { + typeW = len(c.DataType) + } + } + if nameW > 22 { + nameW = 22 + } + if typeW > 20 { + typeW = 20 + } + descW := contentW - nameW - typeW - 4 // 2-char gaps + if descW < 8 { + descW = 8 + } + + // Column header + hdr := fmt.Sprintf("%-*s %-*s %-*s", nameW, "Column", typeW, "Type", descW, "Description") + sb.WriteString(styleCellHeader.Render(truncStr(hdr, contentW)) + "\n") + sb.WriteString(lipgloss.NewStyle().Foreground(colorMuted).Render(strings.Repeat("─", contentW)) + "\n") + + // How many rows fit? + usedLines := lipgloss.Height(sb.String()) + footerLines := 1 + availRows := h - usedLines - footerLines - 2 + if availRows < 1 { + availRows = 1 + } + + start := a.previewScroll + end := start + availRows + if end > len(cols) { + end = len(cols) + } + + for i := start; i < end; i++ { + c := cols[i] + line := fmt.Sprintf("%-*s %-*s %s", + nameW, truncStr(c.Name, nameW), + typeW, truncStr(c.DataType, typeW), + truncStr(c.Description, descW)) + style := styleCellNormal + if i%2 == 0 { + style = style.Foreground(lipgloss.Color("#D1D5DB")) + } + sb.WriteString(style.Render(line) + "\n") + } + } + + // Column count footer + total := len(meta.Columns) + shown := len(cols) + countStr := fmt.Sprintf("%d columns", total) + if shown < total { + countStr = fmt.Sprintf("%d/%d columns", shown, total) + } + sb.WriteString(styleRowCount.Render(countStr)) + } else if a.selectedTable != "" { sb.WriteString(styleStatusBar.Render("Loading…")) } else { - sb.WriteString(styleStatusBar.Render("Select a table to preview rows")) + sb.WriteString(styleStatusBar.Render("Select a table to preview")) } style := stylePanelBlurred @@ -397,20 +803,55 @@ func (a *App) renderPreviewPanel(w, h int) string { } func (a *App) panelWidths() (int, int, int) { - schema := 22 - table := 28 - preview := a.width - schema - table + schema := 24 + tbl := 30 + margins := 2 // MarginRight(1) on schema + table panels + preview := a.width - schema - tbl - margins if preview < 30 { preview = 30 } - return schema, table, preview + return schema, tbl, preview } -func (a *App) previewHeight() int { - return a.height - 4 +func (a *App) resizePanels() {} + +// focusedListFiltering returns true if the currently focused list is in filter mode. +func (a *App) focusedListFiltering() bool { + switch a.focus { + case paneSchema: + return a.schemaList.FilterState() == list.Filtering + case paneTable: + return a.tableList.FilterState() == list.Filtering + } + return false } -func (a *App) resizePanels() {} +// filteredColumns returns the columns matching the current filter text. +func (a *App) filteredColumns() []db.ColumnMeta { + if a.previewMeta == nil { + return nil + } + filter := strings.ToLower(a.previewFilter.Value()) + if filter == "" { + return a.previewMeta.Columns + } + var out []db.ColumnMeta + for _, col := range a.previewMeta.Columns { + if strings.Contains(strings.ToLower(col.Name), filter) || + strings.Contains(strings.ToLower(col.Description), filter) { + out = append(out, col) + } + } + return out +} + +// Err returns the last error message (login or status), if any. +func (a *App) Err() string { + if a.loginErr != "" { + return a.loginErr + } + return a.statusErr +} // -- helpers -- @@ -421,17 +862,15 @@ func selectedItemTitle(l list.Model) string { return "" } -func maxWidth(header string, rows [][]string, col, max int) int { - w := len(header) - for _, row := range rows { - if col < len(row) && len(row[col]) > w { - w = len(row[col]) - } +func truncStr(s string, max int) string { + r := []rune(s) + if len(r) <= max { + return s } - if w > max { - return max + if max <= 1 { + return "…" } - return w + 2 + return string(r[:max-1]) + "…" } func formatCount(n int64) string { diff --git a/internal/tui/loginform.go b/internal/tui/loginform.go @@ -0,0 +1,192 @@ +package tui + +import ( + "strings" + + "github.com/charmbracelet/bubbles/textinput" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +type loginField int + +const ( + loginFieldUser loginField = iota + loginFieldPassword + loginFieldDatabase + loginFieldSave + loginFieldCount +) + +const loginTextInputs = 3 // number of text input fields (before the save toggle) + +// LoginForm is the login dialog overlay shown when credentials are missing. +type LoginForm struct { + inputs [loginTextInputs]textinput.Model + save bool + focused loginField +} + +// LoginSubmitMsg is sent when the user confirms the login form. +type LoginSubmitMsg struct { + User string + Password string + Database string + Save bool +} + +// LoginCancelMsg is sent when the user cancels the login form. +type LoginCancelMsg struct{} + +func newLoginForm() LoginForm { + f := LoginForm{} + + f.inputs[loginFieldUser] = textinput.New() + f.inputs[loginFieldUser].Placeholder = "WRDS username" + f.inputs[loginFieldUser].CharLimit = 128 + + f.inputs[loginFieldPassword] = textinput.New() + f.inputs[loginFieldPassword].Placeholder = "WRDS password" + f.inputs[loginFieldPassword].CharLimit = 128 + f.inputs[loginFieldPassword].EchoMode = textinput.EchoPassword + f.inputs[loginFieldPassword].EchoCharacter = '*' + + f.inputs[loginFieldDatabase] = textinput.New() + f.inputs[loginFieldDatabase].Placeholder = "wrds" + f.inputs[loginFieldDatabase].CharLimit = 128 + f.inputs[loginFieldDatabase].SetValue("wrds") + + f.save = true + f.inputs[loginFieldUser].Focus() + return f +} + +func (f LoginForm) Update(msg tea.Msg) (LoginForm, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyMsg: + switch msg.String() { + case "esc": + return f, func() tea.Msg { return LoginCancelMsg{} } + + case "enter": + if f.focused == loginFieldSave { + // Submit + user := strings.TrimSpace(f.inputs[loginFieldUser].Value()) + pw := f.inputs[loginFieldPassword].Value() + if user == "" || pw == "" { + return f, nil + } + database := strings.TrimSpace(f.inputs[loginFieldDatabase].Value()) + if database == "" { + database = "wrds" + } + return f, func() tea.Msg { + return LoginSubmitMsg{User: user, Password: pw, Database: database, Save: f.save} + } + } + // Advance to next field + if int(f.focused) < loginTextInputs { + f.inputs[f.focused].Blur() + } + f.focused++ + if int(f.focused) < loginTextInputs { + f.inputs[f.focused].Focus() + return f, textinput.Blink + } + return f, nil + + case "tab", "down": + if int(f.focused) < loginTextInputs { + f.inputs[f.focused].Blur() + } + f.focused = loginField((int(f.focused) + 1) % int(loginFieldCount)) + if int(f.focused) < loginTextInputs { + f.inputs[f.focused].Focus() + return f, textinput.Blink + } + return f, nil + + case "shift+tab", "up": + if int(f.focused) < loginTextInputs { + f.inputs[f.focused].Blur() + } + f.focused = loginField((int(f.focused) + int(loginFieldCount) - 1) % int(loginFieldCount)) + if int(f.focused) < loginTextInputs { + f.inputs[f.focused].Focus() + return f, textinput.Blink + } + return f, nil + + case " ": + if f.focused == loginFieldSave { + f.save = !f.save + return f, nil + } + } + } + + // Forward to focused text input + if int(f.focused) < loginTextInputs { + var cmd tea.Cmd + f.inputs[f.focused], cmd = f.inputs[f.focused].Update(msg) + return f, cmd + } + return f, nil +} + +func (f LoginForm) View(width int, errMsg string) string { + var sb strings.Builder + + title := stylePanelHeader.Render("WRDS Login") + sb.WriteString(title + "\n\n") + + labels := []string{"Username", "Password", "Database"} + for i, label := range labels { + style := lipgloss.NewStyle().Foreground(colorMuted) + if loginField(i) == f.focused { + style = lipgloss.NewStyle().Foreground(colorFocus) + } + sb.WriteString(style.Render(label+" ") + "\n") + sb.WriteString(f.inputs[i].View() + "\n\n") + } + + // Save toggle + check := "[ ]" + if f.save { + check = "[x]" + } + saveStyle := lipgloss.NewStyle().Foreground(colorMuted) + if f.focused == loginFieldSave { + saveStyle = lipgloss.NewStyle().Foreground(colorFocus) + } + sb.WriteString(saveStyle.Render(check+" Save to ~/.config/wrds-dl/credentials") + "\n\n") + + if errMsg != "" { + maxLen := 52 + if len(errMsg) > maxLen { + errMsg = errMsg[:maxLen-1] + "…" + } + sb.WriteString(styleError.Render("Error: "+errMsg) + "\n\n") + } + + hint := styleStatusBar.Render("[tab] next field [enter] submit [esc] quit") + sb.WriteString(hint) + + content := sb.String() + boxWidth := 60 + if boxWidth > width-4 { + boxWidth = width - 4 + } + if boxWidth < 40 { + boxWidth = 40 + } + + box := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(colorFocus). + Padding(1, 2). + Width(boxWidth). + Render(content) + + return lipgloss.Place(width, 24, lipgloss.Center, lipgloss.Center, box) +}