BazerData.jl

Data manipulation utilities for Julia
Log | Files | Refs | README | LICENSE

StataUtils.jl (17849B)


      1 # ------------------------------------------------------------------------------------------
      2 
      3 # StataUtils.jl
      4 
      5 # Collection of functions that replicate some stata utilities
      6 # ------------------------------------------------------------------------------------------
      7 
      8 
      9 
     10 # ------------------------------------------------------------------------------------------
     11 # List of exported functions
     12 # tabulate
     13 # xtile
     14 # ------------------------------------------------------------------------------------------
     15 
     16 
     17 # ------------------------------------------------------------------------------------------
     18 """
     19     tabulate(df::AbstractDataFrame, cols::Union{Symbol, Array{Symbol}};
     20         reorder_cols=true, out::Symbol=:stdout)
     21 
     22 Frequency tabulation inspired by Stata's `tabulate` command.
     23 Forked from TexTables.jl and inspired by https://github.com/matthieugomez/statar
     24 
     25 # Arguments
     26 - `df::AbstractDataFrame`: Input DataFrame to analyze
     27 - `cols::Union{Symbol, Vector{Symbol}}`: Single column name or vector of column names to tabulate
     28 - `group_type::Union{Symbol, Vector{Symbol}}=:value`: Specifies how to group each column:
     29     - `:value`: Group by the actual values in the column
     30     - `:type`: Group by the type of values in the column
     31     - `Vector{Symbol}`: Vector combining `:value` and `:type` for different columns
     32 - `reorder_cols::Bool=true`  Whether to sort the output by sortable columns
     33 - `format_tbl::Symbol=:long` How to present the results long or wide (stata twoway)
     34 - `format_stat::Symbol=:freq`  Which statistics to present for format :freq or :pct
     35 - `skip_stat::Union{Nothing, Symbol, Vector{Symbol}}=nothing`  do not print out all statistics (only for string)
     36 - `out::Symbol=:stdout`  Output format:
     37     - `:stdout`  Print formatted table to standard output (returns nothing)
     38     - `:df`  Return the result as a DataFrame
     39     - `:string` Return the formatted table as a string
     40 
     41 # Returns
     42 - `Nothing` if `out=:stdout`
     43 - `DataFrame` if `out=:df`
     44 - `String` if `out=:string`
     45 
     46 # Output Format
     47 The resulting table contains the following columns:
     48 - Specified grouping columns (from `cols`)
     49 - `freq`: Frequency count
     50 - `pct`: Percentage of total
     51 - `cum`: Cumulative percentage
     52 
     53 # Examples
     54 See the README for more examples
     55 ```julia
     56 # Simple frequency table for one column
     57 tabulate(df, :country)
     58 
     59 ## Group by value type
     60 tabulate(df, :age, group_type=:type)
     61 
     62 # Multiple columns with mixed grouping
     63 tabulate(df, [:country, :age], group_type=[:value, :type])
     64 
     65 # Return as DataFrame instead of printing
     66 result_df = tabulate(df, :country, out=:df)
     67 ```
     68 
     69 """
     70 function tabulate(
     71     df::AbstractDataFrame, cols::Union{Symbol, Vector{Symbol}};
     72     group_type::Union{Symbol, Vector{Symbol}}=:value,
     73     reorder_cols::Bool=true,
     74     format_tbl::Symbol=:long,
     75     format_stat::Symbol=:freq,
     76     skip_stat::Union{Nothing, Symbol, Vector{Symbol}}=nothing,
     77     out::Symbol=:stdout)
     78 
     79     N_COLS = cols isa Symbol ? 1 : length(cols)
     80 
     81     if !(format_tbl ∈ [:long, :wide])
     82         if N_COLS == 1
     83             @warn "Converting format_tbl to :long"
     84             format_tbl = :long
     85         else
     86             error("Table format_tbl must be :long or :wide")
     87         end
     88     end
     89 
     90     if isempty(df)
     91         @warn "Input Dataframe is empty ..."
     92         return nothing
     93     end
     94 
     95     df_out, new_cols = _tabulate_compute(df, cols, group_type, reorder_cols)
     96 
     97     if format_tbl == :long
     98         return _tabulate_render_long(df_out, new_cols, N_COLS, out, skip_stat)
     99     else  # :wide
    100         return _tabulate_render_wide(df_out, new_cols, N_COLS, format_stat, out)
    101     end
    102 end
    103 
    104 
    105 # ----- Computation: groupby, combine, sort, pct/cum transforms
    106 function _tabulate_compute(df, cols, group_type, reorder_cols)
    107     group_type_error_msg = """
    108         \ngroup_type input must specify either ':value' or ':type' for columns;
    109         options are :value, :type, or a vector combining the two;
    110         see help for more information
    111         """
    112     if group_type == :value
    113         df_out = combine(groupby(df, cols), nrow => :freq, proprow =>:pct)
    114         new_cols = cols
    115     elseif group_type == :type
    116         name_type_cols = Symbol.(cols, "_typeof")
    117         df_out = transform(df, cols .=> ByRow(typeof) .=> name_type_cols) |>
    118             (d -> combine(groupby(d, name_type_cols), nrow => :freq, proprow =>:pct))
    119         new_cols = name_type_cols
    120     elseif group_type isa Vector{Symbol}
    121         !all(s -> s in [:value, :type], group_type) && error(group_type_error_msg)
    122         (size(group_type, 1) != size(cols, 1)) &&
    123             error("group_type and cols must be the same size; see help for more information")
    124         type_cols = cols[group_type .== :type]
    125         name_type_cols = Symbol.(type_cols, "_typeof")
    126         group_cols = [cols[group_type .== :value]; name_type_cols]
    127         df_out = transform(df, type_cols .=> ByRow(typeof) .=> name_type_cols) |>
    128             (d -> combine(groupby(d, group_cols), nrow => :freq, proprow =>:pct))
    129         new_cols = group_cols
    130     else
    131         error(group_type_error_msg)
    132     end
    133     # resort columns based on the original order
    134     new_cols = sort(new_cols isa Symbol ? [new_cols] : new_cols,
    135         by= x -> findfirst(==(replace(string(x), r"_typeof$" => "")), string.(cols)) )
    136 
    137     if reorder_cols
    138         cols_sortable = [
    139             name
    140             for (name, col) in pairs(eachcol(select(df_out, new_cols)))
    141             if eltype(col) |> t -> hasmethod(isless, Tuple{t,t})
    142         ]
    143         if !isempty(cols_sortable)
    144             sort!(df_out, cols_sortable)  # order before we build cumulative
    145         end
    146     end
    147     transform!(df_out, :pct => cumsum => :cum, :freq => ByRow(Int) => :freq)
    148     transform!(df_out,
    149         :pct => (x -> x .* 100),
    150         :cum => (x -> Int.(round.(x .* 100, digits=0))), renamecols=false)
    151 
    152     return df_out, new_cols
    153 end
    154 
    155 
    156 # ----- Long format rendering
    157 function _tabulate_render_long(df_out, new_cols, N_COLS, out, skip_stat)
    158     transform!(df_out, :freq => (x->text_histogram(x, width=24)) => :freq_hist)
    159 
    160     # highlighter with gradient for the freq/pct/cum columns (rest is cyan)
    161     col_highlighters = Tuple(vcat(
    162         map(i -> Highlighter((data, row, col) -> col == i, crayon"cyan bold"), 1:N_COLS),
    163         hl_custom_gradient(cols=(N_COLS+1), colorscheme=:Oranges_9, scale=maximum(df_out.freq)),
    164         hl_custom_gradient(cols=(N_COLS+2), colorscheme=:Greens_9,  scale=ceil(Int, maximum(df_out.pct))),
    165         hl_custom_gradient(cols=(N_COLS+3), colorscheme=:Greens_9, scale=100),
    166     ))
    167 
    168     # when skip_stat is provided and output is string, filter columns
    169     if out == :string && !isnothing(skip_stat)
    170         all_stats = [:freq, :pct, :cum, :freq_hist]
    171         skip_list = skip_stat isa Vector ? skip_stat : [skip_stat]
    172         col_stat = setdiff(all_stats, skip_list)
    173         N_COL_STAT = length(col_stat)
    174 
    175         stat_headers = Dict(:freq=>"Freq.", :pct=>"Percent", :cum=>"Cum", :freq_hist=>"Hist.")
    176         stat_formats = Dict(:freq=>"%d", :pct=>"%.1f", :cum=>"%d", :freq_hist=>"%s")
    177         stat_colorschemes = Dict(
    178             :freq => (:Oranges_9, maximum(df_out.freq)),
    179             :pct  => (:Greens_9, ceil(Int, maximum(df_out.pct))),
    180             :cum  => (:Greens_9, 100),
    181         )
    182 
    183         header = vcat(string.(new_cols),
    184             [stat_headers[k] for k in col_stat])
    185         formatters = Tuple(vcat(
    186             [ft_printf("%s", i) for i in 1:N_COLS],
    187             [ft_printf(stat_formats[k], N_COLS + i) for (i, k) in enumerate(col_stat)]
    188         ))
    189         # rebuild highlighters for the filtered column layout
    190         filtered_highlighters = Tuple(vcat(
    191             map(i -> Highlighter((data, row, col) -> col == i, crayon"cyan bold"), 1:N_COLS),
    192             [haskey(stat_colorschemes, k) ?
    193                 hl_custom_gradient(cols=N_COLS+i, colorscheme=stat_colorschemes[k][1], scale=stat_colorschemes[k][2]) :
    194                 Highlighter((data, row, col) -> col == N_COLS+i, crayon"white")
    195              for (i, k) in enumerate(col_stat)]
    196         ))
    197         alignment = vcat(repeat([:l], N_COLS), repeat([:c], N_COL_STAT))
    198         cell_alignment = reduce(push!,
    199             map(i -> (i,1)=>:l, 1:N_COLS+N_COL_STAT-1),
    200             init=Dict{Tuple{Int64, Int64}, Symbol}())
    201 
    202         df_render = select(df_out, new_cols, col_stat)
    203         return _render_pretty_table(df_render, out;
    204             hlines=[1], vlines=[N_COLS],
    205             alignment=alignment, cell_alignment=cell_alignment,
    206             header=header, formatters=formatters, highlighters=filtered_highlighters)
    207     end
    208 
    209     # default: all stat columns
    210     header = [string.(new_cols); "Freq."; "Percent"; "Cum"; "Hist."]
    211     formatters = Tuple(vcat(
    212         [ft_printf("%s", i) for i in 1:N_COLS],
    213         [ft_printf("%d", N_COLS+1), ft_printf("%.1f", N_COLS+2),
    214          ft_printf("%d", N_COLS+3), ft_printf("%s", N_COLS+4)]
    215     ))
    216     alignment = vcat(repeat([:l], N_COLS), :c, :c, :c, :c)
    217     cell_alignment = reduce(push!,
    218         map(i -> (i,1)=>:l, 1:N_COLS+3),
    219         init=Dict{Tuple{Int64, Int64}, Symbol}())
    220 
    221     return _render_pretty_table(df_out, out;
    222         hlines=[1], vlines=[N_COLS],
    223         alignment=alignment, cell_alignment=cell_alignment,
    224         header=header, formatters=formatters, highlighters=col_highlighters)
    225 end
    226 
    227 
    228 # ----- Wide format rendering
    229 function _tabulate_render_wide(df_out, new_cols, N_COLS, format_stat, out)
    230     format_stat ∈ (:freq, :pct) || error("format_stat must be :freq or :pct, got :$format_stat")
    231     df_out = unstack(df_out,
    232         new_cols[1:(N_COLS-1)], new_cols[N_COLS], format_stat,
    233         allowmissing=true)
    234 
    235     N_GROUP_COLS = N_COLS - 1
    236     N_VAR_COLS   = size(df_out, 2) - N_GROUP_COLS
    237 
    238     if format_stat == :freq
    239 
    240         # frequency: add row and column totals
    241         total_row_des = "Total by $(string(new_cols[N_COLS]))"
    242         total_col_des = join(vcat("Total by ", join(string.(new_cols[1:(N_COLS-1)]), ", ")))
    243 
    244         sum_cols = sum.(skipmissing.(eachcol(df_out[:, range(1+N_GROUP_COLS; length=N_VAR_COLS)])))
    245         row_vector = vcat([total_row_des], repeat(["-"], max(0, N_GROUP_COLS-1)), sum_cols)
    246         df_out = vcat(df_out,
    247             DataFrame(permutedims(row_vector)[:, end+1-size(df_out,2):end], names(df_out)))
    248         sum_rows = sum.(skipmissing.(eachrow(df_out[:, range(1+N_GROUP_COLS; length=N_VAR_COLS)])))
    249         col_vector = rename(DataFrame(total = sum_rows), "total" => total_col_des)
    250         df_out = hcat(df_out, col_vector)
    251         rename!(df_out, [i => "-"^i for i in 1:N_GROUP_COLS])
    252 
    253         col_highlighters = Tuple(vcat(
    254             map(i -> Highlighter((data, row, col) -> col == i, crayon"cyan bold"), 1:N_GROUP_COLS),
    255             [ hl_custom_gradient(cols=i, colorscheme=:Greens_9,
    256                     scale = ceil(Int, maximum(skipmissing(df_out[1:end-1, i]))))
    257               for i in  range(1+N_GROUP_COLS; length=N_VAR_COLS) ],
    258             Highlighter((data, row, col) -> col == size(df_out, 2), crayon"green")
    259         ))
    260 
    261         formatters = Tuple(vcat(
    262             [ ft_printf("%s", i) for i in 1:N_GROUP_COLS ],
    263             [ ft_printf("%d", j) for j in range(1+N_GROUP_COLS; length=N_VAR_COLS) ],
    264             [ ft_printf("%d", 1+N_GROUP_COLS+N_VAR_COLS) ]
    265         ))
    266 
    267         hlines = [1, size(df_out, 1)]
    268         vlines = [N_GROUP_COLS, N_GROUP_COLS+N_VAR_COLS]
    269         alignment = vcat(repeat([:l], N_GROUP_COLS), repeat([:c], N_VAR_COLS), [:l])
    270 
    271     elseif format_stat == :pct
    272 
    273         col_highlighters = Tuple(vcat(
    274             map(i -> Highlighter((data, row, col) -> col == i, crayon"cyan bold"), 1:N_GROUP_COLS),
    275             [ hl_custom_gradient(cols=i, colorscheme=:Greens_9,
    276                     scale = ceil(Int, maximum(skipmissing(df_out[:, i]))) )
    277               for i in  range(1+N_GROUP_COLS; length=N_VAR_COLS) ],
    278         ))
    279 
    280         formatters = Tuple(vcat(
    281             [ ft_printf("%s", i) for i in 1:N_GROUP_COLS ],
    282             [ ft_printf("%.1f", j) for j in range(1+N_GROUP_COLS; length=N_VAR_COLS) ]
    283         ))
    284 
    285         hlines = [1]
    286         vlines = [0, N_GROUP_COLS, N_GROUP_COLS+N_VAR_COLS]
    287         alignment = vcat(repeat([:l], N_GROUP_COLS), repeat([:c], N_VAR_COLS))
    288 
    289     end
    290 
    291     cell_alignment = reduce(push!,
    292         map(i -> (i,1)=>:l, 1:N_GROUP_COLS),
    293         init=Dict{Tuple{Int64, Int64}, Symbol}())
    294 
    295     return _render_pretty_table(df_out, out;
    296         hlines=hlines, vlines=vlines,
    297         alignment=alignment, cell_alignment=cell_alignment,
    298         formatters=formatters, highlighters=col_highlighters,
    299         show_subheader=false)
    300 end
    301 
    302 
    303 # ----- Unified pretty_table output handler (stdout / df / string)
    304 function _render_pretty_table(df, out::Symbol; show_subheader=true, pt_kwargs...)
    305     common = (
    306         border_crayon = crayon"bold yellow",
    307         header_crayon = crayon"bold light_green",
    308         show_header = true,
    309         show_subheader = show_subheader,
    310     )
    311 
    312     if out ∈ [:stdout, :df]
    313         pretty_table(df; common..., vcrop_mode=:middle, pt_kwargs...)
    314         return out == :stdout ? nothing : df
    315     else  # :string
    316         return pretty_table(String, df; common..., crop=:none, pt_kwargs...)
    317     end
    318 end
    319 # --------------------------------------------------------------------------------------------------
    320 
    321 
    322 # --------------------------------------------------------------------------------------------------
    323 function hl_custom_gradient(;
    324     cols::Int=0,
    325     colorscheme::Symbol=:Oranges_9,
    326     scale::Int=1)
    327 
    328     Highlighter(
    329     (data, i, j) -> j == cols,
    330     (h, data, i, j) -> begin
    331         if ismissing(data[i, j])
    332             return Crayon(foreground=(128, 128, 128))  # Use a default color for missing values
    333         end
    334         color = get(colorschemes[colorscheme], data[i, j], (0, scale))
    335         return Crayon(foreground=(round(Int, color.r * 255),
    336                                   round(Int, color.g * 255),
    337                                   round(Int, color.b * 255)))
    338     end
    339 )
    340 
    341 end
    342 # --------------------------------------------------------------------------------------------------
    343 
    344 
    345 # --------------------------------------------------------------------------------------------------
    346 # From https://github.com/mbauman/Sparklines.jl/blob/master/src/Sparklines.jl
    347 # Unicode characters:
    348 # █ (Full block, U+2588)
    349 # ⣿ (Full Braille block, U+28FF)
    350 # ▓ (Dark shade, U+2593)
    351 # ▒ (Medium shade, U+2592)
    352 # ░ (Light shade, U+2591)
    353 # ◼ (Small black square, U+25FC)
    354 
    355 function text_histogram(frequencies; width=12)
    356     blocks = [" ", "▏", "▎", "▍", "▌", "▋", "▊", "▉", "█"]
    357     max_freq = maximum(frequencies)
    358     max_freq == 0 && return fill(" " ^ width, length(frequencies))
    359     scale = (width * 8 - 1) / max_freq  # Subtract 1 to ensure we don't exceed width
    360 
    361     function bar(f)
    362         units = round(Int, f * scale)
    363         full_blocks = div(units, 8)
    364         remainder = units % 8
    365         rpad(repeat("█", full_blocks) * blocks[remainder + 1], width)
    366     end
    367     bar.(frequencies)
    368 end
    369 # --------------------------------------------------------------------------------------------------
    370 
    371 
    372 
    373 # --------------------------------------------------------------------------------------------------
    374 
    375 """
    376     xtile(data::Vector{T}, n_quantiles::Integer,
    377                  weights::Union{Vector{Float64}, Nothing}=nothing)::Vector{Int} where T <: Real
    378 
    379 Create quantile groups using Julia's built-in weighted quantile functionality.
    380 
    381 # Arguments
    382 - `data`: Values to group
    383 - `n_quantiles`: Number of groups
    384 - `weights`: Optional weights of weight type (StatasBase)
    385 
    386 # Examples
    387 ```julia
    388 sales = rand(10_000);
    389 a = xtile(sales, 10);
    390 b = xtile(sales, 10, weights=Weights(repeat([1], length(sales))) );
    391 @assert a == b
    392 ```
    393 """
    394 function xtile(
    395     data::AbstractVector{T},
    396     n_quantiles::Integer;
    397     weights::Union{Weights{<:Real}, Nothing} = nothing
    398 )::Vector{Int} where T <: Real
    399 
    400         N = length(data)
    401         n_quantiles < 1 && error("n_quantiles must be >= 1")
    402         n_quantiles > N && (@warn "More quantiles than data")
    403 
    404         probs = range(0, 1, length=n_quantiles + 1)[2:end]
    405         if weights === nothing
    406             weights = UnitWeights{T}(N)
    407         end
    408         cuts = quantile(collect(data), weights, probs)
    409 
    410     return searchsortedlast.(Ref(cuts), data)
    411 end
    412 
    413 # String version: use lexicographic rank, then delegate to numeric xtile
    414 function xtile(
    415     data::AbstractVector{T},
    416     n_quantiles::Integer;
    417     weights::Union{Weights{<:Real}, Nothing} = nothing
    418 )::Vector{Int} where T <: AbstractString
    419 
    420     sorted_cats = sort(unique(data))
    421     rank_map = Dict(cat => i for (i, cat) in enumerate(sorted_cats))
    422     ranks = [rank_map[d] for d in data]
    423 
    424     return xtile(ranks, n_quantiles; weights=weights)
    425 end
    426 
    427 # Dealing with missing and Numbers
    428 function xtile(
    429     data::AbstractVector{T},
    430     n_quantiles::Integer;
    431     weights::Union{Weights{<:Real}, Nothing} = nothing
    432 )::Vector{Union{Int, Missing}} where {T <: Union{Missing, AbstractString, Number}}
    433 
    434     # Determine the non-missing type
    435     non_missing_type = Base.nonmissingtype(T)
    436 
    437     # Identify valid (non-missing) data
    438     data_notmissing_idx = findall(!ismissing, data)
    439 
    440     if isempty(data_notmissing_idx)  # If all values are missing, return all missing
    441         return fill(missing, length(data))
    442     end
    443 
    444     # Use @view to avoid unnecessary allocations but convert explicitly to non-missing type
    445     valid_data = convert(Vector{non_missing_type}, @view data[data_notmissing_idx])
    446     valid_weights = weights === nothing ? nothing : Weights(@view weights[data_notmissing_idx])
    447 
    448     # Compute quantile groups on valid data
    449     valid_result = xtile(valid_data, n_quantiles; weights=valid_weights)
    450 
    451     # Allocate result array with correct type
    452     result = Vector{Union{Int, Missing}}(missing, length(data))
    453     result[data_notmissing_idx] .= valid_result  # Assign computed quantile groups
    454 
    455     return result
    456 end
    457 # --------------------------------------------------------------------------------------------------
    458