test_integration.py (4739B)
1 """Integration test: download a small CRSP MSF sample and verify output. 2 3 Requires WRDS credentials (PGUSER/PGPASSWORD or ~/.config/wrds-dl/credentials). 4 Skipped automatically when credentials are unavailable. 5 6 If the Go wrds-dl binary is found, downloads the same data with both 7 implementations and asserts their content hashes match. 8 """ 9 10 from __future__ import annotations 11 12 import hashlib 13 import os 14 import subprocess 15 import tempfile 16 from pathlib import Path 17 18 import pyarrow.parquet as pq 19 import pytest 20 21 from wrds_dl.config import load_credentials 22 23 # A narrow, deterministic query: 10 rows from crsp.msf for Jan 2020. 24 QUERY = ( 25 "SELECT permno, date, prc, ret, shrout " 26 "FROM crsp.msf " 27 "WHERE date = '2020-01-31' " 28 "ORDER BY permno " 29 "LIMIT 10" 30 ) 31 32 REPO_ROOT = Path(__file__).resolve().parents[2] 33 GO_BINARY = REPO_ROOT / "wrds-dl" # pre-built binary at repo root 34 35 36 def _has_credentials() -> bool: 37 if os.environ.get("PGUSER"): 38 return True 39 user, pw, _ = load_credentials() 40 return bool(user and pw) 41 42 43 pytestmark = pytest.mark.skipif( 44 not _has_credentials(), 45 reason="WRDS credentials not available", 46 ) 47 48 49 def _content_hash(parquet_path: str) -> str: 50 """Read a parquet file, sort deterministically, and return a SHA-256 of the content. 51 52 Converts all values to their repr() for a canonical representation 53 that is independent of the parquet writer (parquet-go vs pyarrow). 54 """ 55 table = pq.read_table(parquet_path) 56 # Normalize column order alphabetically. 57 col_names = sorted(table.column_names) 58 table = table.select(col_names) 59 # Sort rows by all columns. 60 sort_keys = [(col, "ascending") for col in col_names] 61 table = table.sort_by(sort_keys) 62 # Hash a canonical string representation of every cell. 63 h = hashlib.sha256() 64 h.update(",".join(col_names).encode()) 65 for i in range(table.num_rows): 66 for col_name in col_names: 67 val = table.column(col_name)[i].as_py() 68 h.update(repr(val).encode()) 69 h.update(b"|") 70 h.update(b"\n") 71 return h.hexdigest() 72 73 74 def test_python_download_parquet(): 75 """Download a small sample with the Python CLI and verify the parquet output.""" 76 with tempfile.TemporaryDirectory() as tmpdir: 77 out = os.path.join(tmpdir, "test_py.parquet") 78 79 from click.testing import CliRunner 80 from wrds_dl.cli import cli 81 82 runner = CliRunner() 83 result = runner.invoke(cli, ["download", "--query", QUERY, "--out", out]) 84 assert result.exit_code == 0, f"Python download failed: {result.output}" 85 86 # Verify parquet file. 87 table = pq.read_table(out) 88 assert table.num_rows == 10 89 assert set(table.column_names) == {"permno", "date", "prc", "ret", "shrout"} 90 91 py_hash = _content_hash(out) 92 assert len(py_hash) == 64 # valid sha256 93 94 95 @pytest.mark.skipif( 96 not GO_BINARY.is_file(), 97 reason=f"Go binary not found at {GO_BINARY}", 98 ) 99 def test_go_python_parity(): 100 """Download the same data with Go and Python, assert content hashes match.""" 101 with tempfile.TemporaryDirectory() as tmpdir: 102 py_out = os.path.join(tmpdir, "py.parquet") 103 go_out = os.path.join(tmpdir, "go.parquet") 104 105 # Python download. 106 from click.testing import CliRunner 107 from wrds_dl.cli import cli 108 109 runner = CliRunner() 110 result = runner.invoke(cli, ["download", "--query", QUERY, "--out", py_out]) 111 assert result.exit_code == 0, f"Python download failed: {result.output}" 112 113 # Go download. 114 env = os.environ.copy() 115 proc = subprocess.run( 116 [str(GO_BINARY), "download", "--query", QUERY, "--out", go_out], 117 capture_output=True, 118 text=True, 119 env=env, 120 timeout=60, 121 ) 122 assert proc.returncode == 0, f"Go download failed: {proc.stderr}" 123 124 # Compare content hashes. 125 py_hash = _content_hash(py_out) 126 go_hash = _content_hash(go_out) 127 128 # Read both tables for diagnostics on failure. 129 py_table = pq.read_table(py_out) 130 go_table = pq.read_table(go_out) 131 132 assert py_table.num_rows == go_table.num_rows, ( 133 f"Row count mismatch: Python={py_table.num_rows}, Go={go_table.num_rows}" 134 ) 135 assert set(py_table.column_names) == set(go_table.column_names), ( 136 f"Column mismatch: Python={py_table.column_names}, Go={go_table.column_names}" 137 ) 138 assert py_hash == go_hash, ( 139 f"Content hash mismatch:\n" 140 f" Python: {py_hash}\n" 141 f" Go: {go_hash}\n" 142 f" Python schema:\n{py_table.schema}\n" 143 f" Go schema:\n{go_table.schema}\n" 144 )