commit fdb6e793588099732b112541b18e1018ca08f65c
parent b3b6222eb5461b6e738b6e3c7bcd2a0348ce496f
Author: Erik Loualiche <[email protected]>
Date: Wed, 25 Feb 2026 09:16:04 -0600
optimize tlag/tlead for Month/Year with Int64 pre-computation dispatch
Pre-compute Date values as Int64 for expensive Month/Year arithmetic,
keeping cheap Int/Day cases on the direct scan path. ~10-15% faster
for Month/Year with no regression on Int/Day.
Co-Authored-By: Claude Opus 4.6 <[email protected]>
Diffstat:
| M | src/TimeShift.jl | | | 126 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------- |
| A | test/bench_timeshift.jl | | | 161 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
2 files changed, 253 insertions(+), 34 deletions(-)
diff --git a/src/TimeShift.jl b/src/TimeShift.jl
@@ -25,6 +25,86 @@ function _validate_tshift_args(x, t_vec; n=nothing, checksorted=true, verbose=fa
return n, N
end
+
+
+# Direct pointer scan for lags. Time values are compared as they are, without
+# any conversion; `tlag` uses this whenever `n` is not a `Dates.OtherPeriod`
+# (integer and `Day` offsets, where `t - n` is cheap to evaluate).
+#
+# For each row `i` in `1:N`, stores `x[j]` in `x_shift[i]` when some
+# `t_vec[j] == t_vec[i] - n`; rows without a match are left untouched.
+# Expects `t_vec` in ascending order. Returns `x_shift`.
+function _scan_lag!(x_shift, x, t_vec, n, N)
+    nxt = 1  # first index whose time has not yet been passed by a target
+    @inbounds for row in 1:N
+        wanted = t_vec[row] - n
+        # The pointer only ever moves forward; it is shared across rows.
+        while nxt <= N && t_vec[nxt] <= wanted
+            nxt += 1
+        end
+        prev = nxt - 1  # last index with t_vec[prev] <= wanted (0 if none)
+        if prev >= 1 && t_vec[prev] == wanted
+            x_shift[row] = x[prev]
+        end
+    end
+    return x_shift
+end
+
+# Direct pointer scan for leads, comparing time values without conversion.
+#
+# For each row `i` in `1:N`, stores `x[j]` in `x_shift[i]` when some
+# `t_vec[j] == t_vec[i] + n`; rows without a match are left untouched.
+# Expects `t_vec` in ascending order. Returns `x_shift`.
+function _scan_lead!(x_shift, x, t_vec, n, N)
+    behind = 0  # number of leading entries known to lie strictly before the target
+    @inbounds for row in 1:N
+        wanted = t_vec[row] + n
+        # Target beyond the last time point: stop, later rows are not examined.
+        wanted > t_vec[N] && break
+        while behind < N
+            t_vec[behind + 1] < wanted || break
+            behind += 1
+        end
+        candidate = behind + 1
+        if candidate <= N && t_vec[candidate] == wanted
+            x_shift[row] = x[candidate]
+        end
+    end
+    return x_shift
+end
+
+
+# Lag scan with all date arithmetic hoisted out of the matching loop.
+# `t - n` for calendar periods such as Month/Year is costly, so every time
+# point and every lag target is first reduced to an `Int64` via `Dates.value`;
+# the pointer scan that follows then compares plain integers only.
+#
+# For each row `i` in `1:N`, stores `x[j]` in `x_shift[i]` when
+# `t_vec[j] == t_vec[i] - n`; other rows are left untouched. Returns `x_shift`.
+function _scan_lag_int64!(x_shift, x, t_vec, n, N)
+    stamps = Vector{Int64}(undef, N)  # Dates.value of each time point
+    goals = Vector{Int64}(undef, N)   # Dates.value of each lagged target
+    @inbounds for k in 1:N
+        t = t_vec[k]
+        stamps[k] = Dates.value(t)
+        goals[k] = Dates.value(t - n)
+    end
+
+    nxt = 1  # first index whose stamp has not yet been passed by a goal
+    @inbounds for row in 1:N
+        goal = goals[row]
+        while nxt <= N && stamps[nxt] <= goal
+            nxt += 1
+        end
+        prev = nxt - 1  # last index with stamps[prev] <= goal (0 if none)
+        if prev >= 1 && stamps[prev] == goal
+            x_shift[row] = x[prev]
+        end
+    end
+    return x_shift
+end
+
+# Lead scan with all date arithmetic hoisted out of the matching loop.
+# Every time point and every lead target (`t + n`) is first reduced to an
+# `Int64` via `Dates.value`; the pointer scan then compares integers only.
+#
+# For each row `i` in `1:N`, stores `x[j]` in `x_shift[i]` when
+# `t_vec[j] == t_vec[i] + n`; other rows are left untouched. Returns `x_shift`.
+function _scan_lead_int64!(x_shift, x, t_vec, n, N)
+    stamps = Vector{Int64}(undef, N)  # Dates.value of each time point
+    goals = Vector{Int64}(undef, N)   # Dates.value of each lead target
+    @inbounds for k in 1:N
+        t = t_vec[k]
+        stamps[k] = Dates.value(t)
+        goals[k] = Dates.value(t + n)
+    end
+
+    behind = 0  # number of leading stamps known to lie strictly before the goal
+    @inbounds for row in 1:N
+        goal = goals[row]
+        # Goal beyond the last stamp: stop, later rows are not examined.
+        goal > stamps[N] && break
+        while behind < N
+            stamps[behind + 1] < goal || break
+            behind += 1
+        end
+        candidate = behind + 1
+        if candidate <= N && stamps[candidate] == goal
+            x_shift[row] = x[candidate]
+        end
+    end
+    return x_shift
+end
# --------------------------------------------------------------------------------------------------
@@ -51,7 +131,6 @@ backward in time by a specified amount `n`.
# Notes
- Time vectors must be strictly sorted (ascending order)
- The time gap `n` must be positive
-- Uses linear scan to match time points
- For `Date` types, no type checking is performed on `n`
- Elements at the beginning will be `missing` if they don't have values from `n` time units ago
- See PanelShift.jl for original implementation
@@ -83,22 +162,15 @@ function tlag(x, t_vec;
n, N = _validate_tshift_args(x, t_vec; n=n, checksorted=checksorted, verbose=verbose)
x_shift = Array{Union{Missing, eltype(x)}}(missing, N)
- _linear_scan!(x_shift, x, t_vec, n, N)
-
- return x_shift
-end
-function _linear_scan!(x_shift, x, t_vec, n, N)
- j = 0
- @inbounds for i in 1:N
- lagt = t_vec[i] - n
- while j < N && t_vec[j + 1] <= lagt
- j += 1
- end
- if j > 0 && t_vec[j] == lagt
- x_shift[i] = x[j]
- end
+ # Month/Year arithmetic is expensive; pre-compute Int64 targets for those.
+ # Day and integer arithmetic is cheap; scan directly.
+ if n isa Dates.OtherPeriod
+ _scan_lag_int64!(x_shift, x, t_vec, n, N)
+ else
+ _scan_lag!(x_shift, x, t_vec, n, N)
end
+
return x_shift
end
# --------------------------------------------------------------------------------------------------
@@ -127,7 +199,6 @@ forward in time by a specified amount `n`.
# Notes
- Time vectors must be strictly sorted (ascending order)
- The time gap `n` must be positive
-- Uses linear scan to match time points
- For `Date` types, no type checking is performed on `n`
- Elements at the end will be `missing` if they don't have values from `n` time units in the future
- See PanelShift.jl for original implementation
@@ -159,26 +230,13 @@ function tlead(x, t_vec;
n, N = _validate_tshift_args(x, t_vec; n=n, checksorted=checksorted, verbose=verbose)
x_shift = Array{Union{Missing, eltype(x)}}(missing, N)
- _linear_scan_lead!(x_shift, x, t_vec, n, N)
- return x_shift
-end
-
-function _linear_scan_lead!(x_shift, x, t_vec, n, N)
- j = 0
-
- @inbounds for i in 1:N
- leadt = t_vec[i] + n
- if leadt > t_vec[N]
- break
- end
- while j < N && t_vec[j + 1] < leadt
- j += 1
- end
- if j + 1 <= N && t_vec[j + 1] == leadt
- x_shift[i] = x[j + 1]
- end
+ if n isa Dates.OtherPeriod
+ _scan_lead_int64!(x_shift, x, t_vec, n, N)
+ else
+ _scan_lead!(x_shift, x, t_vec, n, N)
end
+
return x_shift
end
# --------------------------------------------------------------------------------------------------
diff --git a/test/bench_timeshift.jl b/test/bench_timeshift.jl
@@ -0,0 +1,161 @@
+#!/usr/bin/env julia
+# Benchmark: tlag/tlead performance comparison
+#
+# Compares three approaches:
+# 1. Old linear scan (Date arithmetic interleaved with comparisons)
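+# 2. New: pre-compute Int64 targets, scan in pure Int64
+#    (tlag/tlead take this path only when `n` is a `Dates.OtherPeriod`, e.g. Month/Year;
+#    Int and Day offsets still use the direct scan)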
+# 3. Dict-based O(1) lookup
+#
+# Run with: julia --project test/bench_timeshift.jl
+
+using BazerData
+using Dates
+using Random
+using Statistics
+
+# --- Old linear scan (Date objects in hot loop) ---
+# Baseline lag implementation for the benchmark: a single forward pointer scan
+# that evaluates `t_vec[i] - n` inside the loop and compares time values directly.
+# Returns a new vector of length `length(t_vec)` holding `x[j]` where
+# `t_vec[j] == t_vec[i] - n`, and `missing` elsewhere.
+# Performs no validation of its arguments.
+function tlag_oldscan(x, t_vec, n)
+ N = length(t_vec)
+ # Output starts all-missing; only matched rows are overwritten.
+ x_shift = Array{Union{Missing, eltype(x)}}(missing, N)
+ # j is the index of the last time point known to be <= the current target (0 = none yet).
+ j = 0
+ @inbounds for i in 1:N
+ lagt = t_vec[i] - n
+ # Advance the shared pointer while the next time point does not overshoot the target.
+ while j < N && t_vec[j + 1] <= lagt
+ j += 1
+ end
+ # Exact match required; a gap in the time vector leaves the row missing.
+ if j > 0 && t_vec[j] == lagt
+ x_shift[i] = x[j]
+ end
+ end
+ return x_shift
+end
+
+# Baseline lead implementation for the benchmark: a single forward pointer scan
+# that evaluates `t_vec[i] + n` inside the loop and compares time values directly.
+# Returns a new vector of length `length(t_vec)` holding `x[j]` where
+# `t_vec[j] == t_vec[i] + n`, and `missing` elsewhere.
+# Performs no validation of its arguments.
+function tlead_oldscan(x, t_vec, n)
+ N = length(t_vec)
+ # Output starts all-missing; only matched rows are overwritten.
+ x_shift = Array{Union{Missing, eltype(x)}}(missing, N)
+ # j counts the leading time points known to be strictly before the current target.
+ j = 0
+ @inbounds for i in 1:N
+ leadt = t_vec[i] + n
+ # Target lies past the last time point: stop scanning, remaining rows stay missing.
+ if leadt > t_vec[N]; break; end
+ while j < N && t_vec[j + 1] < leadt
+ j += 1
+ end
+ # t_vec[j + 1] is the first time point >= target; take it only on an exact match.
+ if j + 1 <= N && t_vec[j + 1] == leadt
+ x_shift[i] = x[j + 1]
+ end
+ end
+ return x_shift
+end
+
+# --- Dict-based lookup ---
+# Hash-based lag for comparison: index every time point in a Dict, then look
+# up `t_vec[i] - n` for each row. Returns a new vector with `x[j]` where
+# `t_vec[j] == t_vec[i] - n` and `missing` where the lookup fails.
+function tlag_dict(x, t_vec, n)
+    N = length(t_vec)
+
+    # time value => position in t_vec
+    position_of = Dict{eltype(t_vec), Int}()
+    sizehint!(position_of, N)
+    @inbounds for pos in 1:N
+        position_of[t_vec[pos]] = pos
+    end
+
+    x_shift = Array{Union{Missing, eltype(x)}}(missing, N)
+    @inbounds for pos in 1:N
+        hit = get(position_of, t_vec[pos] - n, 0)  # 0 marks "not found"
+        hit > 0 || continue
+        x_shift[pos] = x[hit]
+    end
+    return x_shift
+end
+
+
+# --- Benchmark harness ---
+# Time the zero-argument callable `f`.
+# Runs `warmup` untimed calls, forces a garbage collection, then records the
+# wall-clock duration of `trials` calls. Returns a named tuple with the
+# `median` and `min` of those durations, in milliseconds.
+function bench(f; warmup=3, trials=15)
+    for _ in 1:warmup
+        f()
+    end
+    GC.gc()
+
+    elapsed_ms = Float64[]
+    for _ in 1:trials
+        started = time_ns()
+        f()
+        finished = time_ns()
+        push!(elapsed_ms, (finished - started) / 1e6)  # ns -> ms
+    end
+    return (median=median(elapsed_ms), min=minimum(elapsed_ms))
+end
+
+# Print one result row comparing `old` and `new` timings (named tuples with a
+# `median` field, in ms). The speedup `old.median / new.median` is shown in
+# green when >= 1.0 and red otherwise. When `dict` is given, its median and
+# its own speedup relative to `old` are appended to the same row.
+function report(label, old, new; dict=nothing)
+    green = "\033[32m"
+    red = "\033[31m"
+    reset = "\033[0m"
+
+    speedup = old.median / new.median
+    color = speedup >= 1.0 ? green : red
+
+    old_part = " $(rpad(label, 28)) old=$(rpad(round(old.median, digits=2), 8))ms "
+    new_part = "new=$(rpad(round(new.median, digits=2), 8))ms "
+    ratio_part = "$(color)$(round(speedup, digits=2))x$(reset)"
+    line = old_part * new_part * ratio_part
+
+    if dict !== nothing
+        ds = old.median / dict.median
+        dc = ds >= 1.0 ? green : red
+        line *= " dict=$(rpad(round(dict.median, digits=2), 8))ms $(dc)$(round(ds, digits=2))x$(reset)"
+    end
+    println(line)
+    return nothing
+end
+
+
+# --- Generate test data ---
+# Build `n` strictly increasing dates starting at 2000-01-01.
+# Consecutive dates are one day apart, except that with probability `gap_prob`
+# the step is a random 2-5 days instead. Seeds the global RNG with `seed`,
+# so the same arguments always produce the same dates.
+function make_daily_dates(n; gap_prob=0.1, seed=42)
+    Random.seed!(seed)
+    dates = Vector{Date}(undef, n)
+    current = Date(2000, 1, 1)
+    for slot in 1:n
+        dates[slot] = current
+        # One uniform draw per slot; a second draw only when a gap is chosen.
+        step = if rand() < gap_prob
+            rand(2:5)
+        else
+            1
+        end
+        current += Day(step)
+    end
+    return dates
+end
+
+# Build `n` strictly increasing integers starting at 1.
+# Consecutive values differ by 1, except that with probability `gap_prob`
+# the step is a random 2-5 instead. Seeds the global RNG with `seed`,
+# so the same arguments always produce the same sequence.
+function make_integers(n; gap_prob=0.1, seed=42)
+    Random.seed!(seed)
+    ts = Vector{Int}(undef, n)
+    current = 1
+    for slot in 1:n
+        ts[slot] = current
+        # One uniform draw per slot; a second draw only when a gap is chosen.
+        step = if rand() < gap_prob
+            rand(2:5)
+        else
+            1
+        end
+        current += step
+    end
+    return ts
+end
+
+
+# --- Run benchmarks ---
+banner = "="^80
+println("\n" * banner)
+println(" TimeShift Benchmark")
+println(" old = linear scan on Date objects")
+println(" new = pre-compute Int64 targets, scan in Int64")
+println(" dict = Dict{T,Int} lookup")
+println(banner)
+
+for N in [100_000, 1_000_000]
+    println("\n--- N = $(N ÷ 1000)K elements ---")
+
+    # Generation order matters: both generators reseed the global RNG, and the
+    # value vectors are drawn from it afterwards.
+    dates = make_daily_dates(N)
+    ints = make_integers(N)
+    x_f = rand(N)
+    x_i = rand(1:1000, N)
+
+    # (label, time vector, values, offset) for each lag scenario
+    lag_cases = [
+        ("Int, n=1", ints, x_i, 1),
+        ("Int, n=365", ints, x_i, 365),
+        ("Date, n=Day(1)", dates, x_f, Day(1)),
+        ("Date, n=Month(1)", dates, x_f, Month(1)),
+        ("Date, n=Year(1)", dates, x_f, Year(1)),
+    ]
+    println("\n tlag:")
+    for (lbl, t, x, n) in lag_cases
+        baseline = bench(() -> tlag_oldscan(x, t, n))
+        current = bench(() -> tlag(x, t; n=n, checksorted=false))
+        hashed = bench(() -> tlag_dict(x, t, n))
+        report(lbl, baseline, current; dict=hashed)
+    end
+
+    # (label, time vector, values, offset) for each lead scenario
+    lead_cases = [
+        ("Int, n=1", ints, x_i, 1),
+        ("Date, n=Day(1)", dates, x_f, Day(1)),
+        ("Date, n=Month(1)", dates, x_f, Month(1)),
+        ("Date, n=Year(1)", dates, x_f, Year(1)),
+    ]
+    println("\n tlead:")
+    for (lbl, t, x, n) in lead_cases
+        baseline = bench(() -> tlead_oldscan(x, t, n))
+        current = bench(() -> tlead(x, t; n=n, checksorted=false))
+        report(lbl, baseline, current)
+    end
+end
+
+println("\n" * banner)