Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: CI
on:
pull_request:
push:
branches: [master]
branches: [main]
tags: ['*']
workflow_dispatch:
jobs:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
..DS_Store
.DS_Store
Manifest.toml
benchmark/tune.json
settings.local.json
24 changes: 11 additions & 13 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,37 +1,35 @@
name = "FixedEffects"
uuid = "c8885935-8500-56a7-9867-7708b20db0eb"
version = "3.0.0"

version = "3.1.0"

[deps]
GroupedArrays = "6407cd72-fade-4a84-8a1e-56e431fc1533"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
GroupedArrays = "6407cd72-fade-4a84-8a1e-56e431fc1533"

[compat]
PrecompileTools = "1"
StatsBase = "0.33, 0.34"
GroupedArrays = "0.3"
julia = "1.10"
[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

[extensions]
CUDAExt = "CUDA"
MetalExt = "Metal"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
[compat]
GroupedArrays = "0.3"
PrecompileTools = "1"
StatsBase = "0.33, 0.34"
julia = "1.10"

[extras]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["CategoricalArrays", "CUDA", "Metal", "Pkg", "PooledArrays", "Test"]

169 changes: 169 additions & 0 deletions benchmark/bench_gather.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
using Random, BenchmarkTools, Base.Threads
println("Julia ", VERSION, " — ", nthreads(), " threads")
Random.seed!(1234)

##############################################################################
# Current serial gather (baseline)
##############################################################################
function gather_serial!(fecoef, refs, α, y, cache)
@fastmath @inbounds @simd for i in eachindex(y)
fecoef[refs[i]] += α * y[i] * cache[i]
end
end

##############################################################################
# Approach 1: CSC-style transposed gather
# Precompute a CSC structure: for each group k, store obs indices.
# Each group k can be processed independently → trivially parallel, zero conflicts.
##############################################################################
struct CSCIndex
offsets::Vector{Int}
indices::Vector{Int}
end

function build_csc(refs::AbstractVector{<:Integer}, n::Int)
N = length(refs)
counts = zeros(Int, n)
@inbounds for i in 1:N
counts[refs[i]] += 1
end
offsets = Vector{Int}(undef, n + 1)
offsets[1] = 1
@inbounds for k in 1:n
offsets[k+1] = offsets[k] + counts[k]
end
indices = Vector{Int}(undef, N)
fill!(counts, 0)
@inbounds for i in 1:N
k = refs[i]
counts[k] += 1
indices[offsets[k] + counts[k] - 1] = i
end
return CSCIndex(offsets, indices)
end

function gather_csc_parallel!(fecoef::AbstractVector{T}, csc::CSCIndex, α, y, cache) where T
offsets, indices = csc.offsets, csc.indices
n = length(fecoef)
Threads.@threads for k in 1:n
s = zero(T)
@fastmath @inbounds for j in offsets[k]:(offsets[k+1]-1)
i = indices[j]
s += y[i] * cache[i]
end
@inbounds fecoef[k] += α * s
end
end

function gather_csc_serial!(fecoef::AbstractVector{T}, csc::CSCIndex, α, y, cache) where T
offsets, indices = csc.offsets, csc.indices
n = length(fecoef)
for k in 1:n
s = zero(T)
@fastmath @inbounds for j in offsets[k]:(offsets[k+1]-1)
i = indices[j]
s += y[i] * cache[i]
end
@inbounds fecoef[k] += α * s
end
end

##############################################################################
# Approach 2: Per-thread accumulators with manual chunking (@spawn)
##############################################################################
struct PerThreadBuffers{T}
buffers::Vector{Vector{T}}
end
PerThreadBuffers{T}(n::Int, nt::Int) where T = PerThreadBuffers([zeros(T, n) for _ in 1:nt])

function gather_perthread!(fecoef::AbstractVector{T}, refs, α, y, cache, ptb::PerThreadBuffers{T}) where T
nt = length(ptb.buffers)
N = length(y)
for buf in ptb.buffers
fill!(buf, zero(T))
end
chunk = cld(N, nt)
@sync for t in 1:nt
Threads.@spawn begin
buf = ptb.buffers[t]
lo = (t-1)*chunk + 1
hi = min(t*chunk, N)
@fastmath @inbounds for i in lo:hi
buf[refs[i]] += y[i] * cache[i]
end
end
end
@inbounds for buf in ptb.buffers
@simd for k in eachindex(fecoef)
fecoef[k] += α * buf[k]
end
end
end

##############################################################################
# Benchmarks
##############################################################################
function run_bench(label, N, n_groups)
println("\n", "="^60)
println("$label: N=$N, n_groups=$n_groups (avg group size=$(N÷n_groups))")
println("="^60)

refs = rand(1:n_groups, N)
y = rand(N)
cache = rand(N)
α = 1.0
nt = nthreads()

csc = build_csc(refs, n_groups)
ptb = PerThreadBuffers{Float64}(n_groups, nt)

# Verify correctness
out_ref = zeros(n_groups)
gather_serial!(out_ref, refs, α, y, cache)

for (name, fn!) in [
("CSC parallel", (out) -> gather_csc_parallel!(out, csc, α, y, cache)),
("CSC serial", (out) -> gather_csc_serial!(out, csc, α, y, cache)),
("Per-thread chunked", (out) -> gather_perthread!(out, refs, α, y, cache, ptb)),
]
out_test = zeros(n_groups)
fn!(out_test)
if !isapprox(out_test, out_ref, rtol=1e-10)
println(" WARNING $name: INCORRECT (max diff = $(maximum(abs.(out_test .- out_ref))))")
end
end

print(" Serial (baseline): ")
b0 = @benchmark gather_serial!(out, $refs, $α, $y, $cache) setup=(out=zeros($n_groups)) evals=1 samples=30
show(stdout, MIME("text/plain"), b0); println()

print(" CSC serial: ")
b1 = @benchmark gather_csc_serial!(out, $csc, $α, $y, $cache) setup=(out=zeros($n_groups)) evals=1 samples=30
show(stdout, MIME("text/plain"), b1); println()

print(" CSC parallel: ")
b2 = @benchmark gather_csc_parallel!(out, $csc, $α, $y, $cache) setup=(out=zeros($n_groups)) evals=1 samples=30
show(stdout, MIME("text/plain"), b2); println()

print(" Per-thread chunked: ")
b3 = @benchmark gather_perthread!(out, $refs, $α, $y, $cache, $ptb) setup=(out=zeros($n_groups)) evals=1 samples=30
show(stdout, MIME("text/plain"), b3); println()

t0 = median(b0).time
println("\n Speedups vs serial baseline:")
println(" CSC serial: $(round(t0/median(b1).time, digits=2))x")
println(" CSC parallel: $(round(t0/median(b2).time, digits=2))x")
println(" Per-thread chunked: $(round(t0/median(b3).time, digits=2))x")
end

# Scenario 1: Few large groups (like year FE)
run_bench("Few large groups", 10_000_000, 100)

# Scenario 2: Many medium groups
run_bench("Many medium groups", 10_000_000, 100_000)

# Scenario 3: Many small groups (worker FE)
run_bench("Many small groups (worker FE)", 800_000, 400_000)

# Scenario 4: Moderate groups (firm FE)
run_bench("Moderate groups (firm FE)", 800_000, 50_000)
82 changes: 82 additions & 0 deletions benchmark/run.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
using FixedEffects, Random
try using CUDA catch end
try using Metal catch end
Random.seed!(1234)

##############################################################################
# Setup
##############################################################################

# Simple problem: N=10M, two FEs (100k × 100 groups)
N = 10_000_000
K = 100
id1 = rand(1:div(N, K), N)
id2 = rand(1:K, N)
fes_simple = [FixedEffect(id1), FixedEffect(id2)]
x_simple = rand(N)

# Hard problem: N=800k, worker-firm (400k × 50k)
N = 800_000
M = 400_000
O = 50_000
Random.seed!(1234)
pid = rand(1:M, N)
fid = [rand(max(1, div(x, 8)-10):min(O, div(x, 8)+10)) for x in pid]
x_hard = rand(N)
fes_hard = [FixedEffect(pid), FixedEffect(fid)]
y = rand(N)
fes_interact = [FixedEffect(pid), FixedEffect(pid; interaction = y)]

##############################################################################
# CPU
##############################################################################

println("Simple (N=10M, 100k×100), first run:") # ~3 s
@time solve_residuals!(deepcopy(x_simple), fes_simple)
println("Simple (N=10M, 100k×100), second run:") # ~0.5 s
@time solve_residuals!(deepcopy(x_simple), fes_simple)

println("Hard (N=800k, 400k×50k), Float32, first run:") # ~3 s
@time solve_residuals!(deepcopy(x_hard), fes_hard; double_precision = false)
println("Hard (N=800k, 400k×50k), Float32, second run:") # ~2.5 s
@time solve_residuals!(deepcopy(x_hard), fes_hard; double_precision = false)

println("Hard (N=800k, 400k×50k), maxiter=300, first run:") # ~2.5 s
@time solve_residuals!(deepcopy(x_hard), fes_hard; maxiter = 300)
println("Hard (N=800k, 400k×50k), maxiter=300, second run:") # ~2.5 s
@time solve_residuals!(deepcopy(x_hard), fes_hard; maxiter = 300)

println("Hard (N=800k, interacted), first run:") # ~3.5 s
@time solve_residuals!(deepcopy(x_hard), fes_interact; maxiter = 300)
println("Hard (N=800k, interacted), second run:") # ~3 s
@time solve_residuals!(deepcopy(x_hard), fes_interact; maxiter = 300)

##############################################################################
# CUDA
##############################################################################
if isdefined(Main, :CUDA) && CUDA.functional()
println("Simple (N=10M, 100k×100), CUDA, first run:")
@time solve_residuals!(deepcopy(x_simple), fes_simple; method = :CUDA)
println("Simple (N=10M, 100k×100), CUDA, second run:")
@time solve_residuals!(deepcopy(x_simple), fes_simple; method = :CUDA)

println("Hard (N=800k, 400k×50k), CUDA, first run:")
@time solve_residuals!(deepcopy(x_hard), fes_hard; method = :CUDA)
println("Hard (N=800k, 400k×50k), CUDA, second run:")
@time solve_residuals!(deepcopy(x_hard), fes_hard; method = :CUDA)
end

##############################################################################
# Metal
##############################################################################
if isdefined(Main, :Metal) && Metal.functional()
println("Simple (N=10M, 100k×100), Metal, first run:") # ~18 s
@time solve_residuals!(Float32.(deepcopy(x_simple)), fes_simple; method = :Metal, double_precision = false)
println("Simple (N=10M, 100k×100), Metal, second run:") # ~1.5 s
@time solve_residuals!(Float32.(deepcopy(x_simple)), fes_simple; method = :Metal, double_precision = false)

println("Hard (N=800k, 400k×50k), Metal, first run:") # ~3.3 s
@time solve_residuals!(Float32.(deepcopy(x_hard)), fes_hard; method = :Metal, double_precision = false, maxiter = 300)
println("Hard (N=800k, 400k×50k), Metal, second run:") # ~1.6 s
@time solve_residuals!(Float32.(deepcopy(x_hard)), fes_hard; method = :Metal, double_precision = false, maxiter = 300)
end
30 changes: 0 additions & 30 deletions benchmarks/benchmark.jl

This file was deleted.

50 changes: 0 additions & 50 deletions benchmarks/benchmark_CUDA.jl

This file was deleted.

Loading
Loading