cli.py (5991B)
"""Click CLI — download and info subcommands matching the Go wrds-dl interface."""

from __future__ import annotations

import json

import click
import psycopg

from wrds_dl.config import apply_credentials
from wrds_dl.db import build_query, connect, dsn_from_env, table_meta

# Output formats accepted by ``download --format`` (compared case-insensitively).
_FORMATS = ("parquet", "csv")


@click.group()
def cli() -> None:
    """Download data from the WRDS PostgreSQL database to Parquet or CSV."""


@cli.command()
@click.option("--schema", default="", help="Schema name (e.g. crsp)")
@click.option("--table", default="", help="Table name (e.g. dsf)")
@click.option("-c", "--columns", default="*", help="Columns to select (comma-separated, default *)")
@click.option("--where", "where_clause", default="", help="SQL WHERE clause (without the WHERE keyword)")
@click.option("--query", default="", help="Full SQL query (overrides --schema/--table/--where)")
@click.option("--out", default="", help="Output file path (required unless --dry-run)")
@click.option("--format", "fmt", default="", help="Output format: parquet or csv (inferred from extension)")
@click.option("--limit", default=0, type=int, help="Limit number of rows (0 = no limit)")
@click.option("--dry-run", is_flag=True, help="Preview query, row count, and first 5 rows")
def download(
    schema: str,
    table: str,
    columns: str,
    where_clause: str,
    query: str,
    out: str,
    fmt: str,
    limit: int,
    dry_run: bool,
) -> None:
    """Download WRDS data to Parquet or CSV."""
    # A whitespace-only --query counts as "not given"; a trailing ';' is removed
    # because the dry-run preview embeds the query as a subquery.
    query = _clean_query(query)

    # Validate argument combinations before loading credentials, so usage
    # mistakes are reported without any credential/database side effects.
    if not query and not (schema and table):
        raise click.UsageError("Either --query or both --schema and --table must be specified")
    if not dry_run and not out:
        raise click.UsageError('Required option "--out" not provided')
    resolved_fmt = "" if dry_run else _resolve_format(fmt, out)

    apply_credentials()

    # --query is used as-is; --columns, --where and --limit only shape the
    # query built from --schema/--table.
    sql = query if query else build_query(schema, table, columns, where_clause, limit)

    if dry_run:
        _run_dry_run(sql)
        return

    click.echo(f"Exporting to {out} ({resolved_fmt})...", err=True)

    # Imported lazily (presumably to keep `info` / --dry-run startup light).
    from wrds_dl.export import export_data

    def progress(rows: int) -> None:
        click.echo(f"Exported {rows} rows...", err=True)

    export_data(sql, out, resolved_fmt, progress)
    click.echo(f"Done: {out}", err=True)


def _clean_query(sql: str) -> str:
    """Strip surrounding whitespace and any trailing semicolons from *sql*.

    Returns ``""`` for a whitespace-only query.
    """
    cleaned = sql.strip()
    while cleaned.endswith(";"):
        cleaned = cleaned[:-1].rstrip()
    return cleaned


def _resolve_format(fmt: str, out: str) -> str:
    """Return the lower-cased output format for ``download``.

    When *fmt* is empty the format is inferred from *out*: ``csv`` for a
    ``.csv`` extension (case-insensitive), otherwise ``parquet``.

    Raises:
        click.BadParameter: if *fmt* is given but is not parquet or csv.
    """
    if not fmt:
        return "csv" if out.lower().endswith(".csv") else "parquet"
    resolved = fmt.lower()
    if resolved not in _FORMATS:
        raise click.BadParameter(
            f"must be one of {', '.join(_FORMATS)} (got {fmt!r})",
            param_hint="'--format'",
        )
    return resolved


def _format_table(headers: list[str], rows: list[list[str]]) -> list[str]:
    """Lay out *headers* and *rows* as left-aligned, space-separated columns.

    Each column is as wide as its longest cell (header included). Trailing
    padding on the last column is stripped from every line.
    """
    widths = [len(h) for h in headers]
    for row in rows:
        for i, cell in enumerate(row):
            widths[i] = max(widths[i], len(cell))
    return [
        " ".join(cell.ljust(width) for cell, width in zip(line, widths)).rstrip()
        for line in [headers, *rows]
    ]


def _run_dry_run(sql: str) -> None:
    """Print *sql*, its total row count, and a preview of its first 5 rows.

    *sql* is embedded as a subquery, so it must be a single statement without
    a trailing semicolon. It is wrapped on its own lines so that a trailing
    ``--`` comment cannot comment out the closing parenthesis.
    """
    conn = psycopg.connect(dsn_from_env())
    try:
        with conn.cursor() as cur:
            click.echo("Query:")
            click.echo(f" {sql}")
            click.echo()

            # Row count.
            cur.execute(f"SELECT count(*) FROM (\n{sql}\n) sub")
            count_row = cur.fetchone()
            count = count_row[0] if count_row else 0
            click.echo(f"Row count: {count}")
            click.echo()

            # Preview first 5 rows.
            cur.execute(f"SELECT * FROM (\n{sql}\n) sub LIMIT 5")
            if cur.description is None:
                return

            col_names = [desc.name for desc in cur.description]
            str_rows = [
                ["NULL" if v is None else str(v) for v in row] for row in cur.fetchall()
            ]
            for line in _format_table(col_names, str_rows):
                click.echo(line)
    finally:
        conn.close()


@cli.command()
@click.option("--schema", required=True, help="Schema name (required)")
@click.option("--table", required=True, help="Table name (required)")
@click.option("--json", "as_json", is_flag=True, help="Output as JSON")
def info(schema: str, table: str, as_json: bool) -> None:
    """Show table metadata (columns, types, row count)."""
    apply_credentials()

    conn = connect()
    try:
        meta = table_meta(conn, schema, table)
    finally:
        conn.close()

    if as_json:
        _print_info_json(meta)
    else:
        _print_info_table(meta)


def _print_info_json(meta) -> None:
    """Print table metadata (as returned by ``table_meta``) as indented JSON.

    Top-level keys whose value is None (empty comment/size) are omitted, and a
    column's ``description`` key is present only when it is non-empty.
    """
    data = {
        "schema": meta.schema,
        "table": meta.table,
        "comment": meta.comment or None,
        "row_count": meta.row_count,
        "size": meta.size or None,
        "columns": [
            {
                "name": c.name,
                "type": c.data_type,
                "nullable": c.nullable,
                **({"description": c.description} if c.description else {}),
            }
            for c in meta.columns
        ],
    }
    # Match Go: omit null keys
    data = {k: v for k, v in data.items() if v is not None}
    # Go's encoding/json writes non-ASCII text as UTF-8 rather than \u escapes.
    click.echo(json.dumps(data, indent=2, ensure_ascii=False))


def _print_info_table(meta) -> None:
    """Print table metadata (as returned by ``table_meta``) as human-readable text.

    The output is a ``schema.table`` title, an optional comment line, and an
    optional ``~N rows, size`` line. A column listing with NAME, TYPE,
    NULLABLE and DESCRIPTION follows.
    """
    click.echo(f"{meta.schema}.{meta.table}")
    if meta.comment:
        click.echo(f" {meta.comment}")

    parts = []
    # Guard against a missing estimate; non-positive estimates are not shown.
    if (meta.row_count or 0) > 0:
        parts.append(f"~{meta.row_count} rows")
    if meta.size:
        parts.append(meta.size)
    if parts:
        click.echo(f" {', '.join(parts)}")

    click.echo()

    # Descriptions may be empty/None (the JSON output treats them as optional).
    rows = [
        [c.name, c.data_type, "YES" if c.nullable else "NO", c.description or ""]
        for c in meta.columns
    ]
    for line in _format_table(["NAME", "TYPE", "NULLABLE", "DESCRIPTION"], rows):
        click.echo(line)