Skip to content

Commit 81a90af

Browse files
make SparseArrays a weak dependency (#134)
1 parent 04e5d89 commit 81a90af

File tree

3 files changed

+113
-90
lines changed

3 files changed

+113
-90
lines changed

Project.toml

+8-1
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,16 @@ julia = "1.9"
1212
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1414

15+
[weakdeps]
16+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
17+
18+
[extensions]
19+
SparseArraysExt = ["SparseArrays"]
20+
1521
[extras]
1622
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
23+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1724
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1825

1926
[targets]
20-
test = ["Random", "Test"]
27+
test = ["Random", "SparseArrays", "Test"]

ext/SparseArraysExt.jl

+101
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
module SparseArraysExt
2+
3+
##### SparseArrays optimizations #####
4+
5+
using Base: require_one_based_indexing
6+
using LinearAlgebra
7+
using SparseArrays
8+
using Statistics
9+
using Statistics: centralize_sumabs2, unscaled_covzm
10+
11+
# extended functions
12+
import Statistics: cov, centralize_sumabs2!
13+
14+
function cov(X::SparseMatrixCSC; dims::Int=1, corrected::Bool=true)
15+
vardim = dims
16+
a, b = size(X)
17+
n, p = vardim == 1 ? (a, b) : (b, a)
18+
19+
# The covariance can be decomposed into two terms
20+
# 1/(n - 1) ∑ (x_i - x̄)*(x_i - x̄)' = 1/(n - 1) (∑ x_i*x_i' - n*x̄*x̄')
21+
# which can be evaluated via a sparse matrix-matrix product
22+
23+
# Compute ∑ x_i*x_i' = X'X using sparse matrix-matrix product
24+
out = Matrix(unscaled_covzm(X, vardim))
25+
26+
# Compute x̄
27+
x̄ᵀ = mean(X, dims=vardim)
28+
29+
# Subtract n*x̄*x̄' from X'X
30+
@inbounds for j in 1:p, i in 1:p
31+
out[i,j] -= x̄ᵀ[i] * x̄ᵀ[j]' * n
32+
end
33+
34+
# scale with the sample size n or the corrected sample size n - 1
35+
return rmul!(out, inv(n - corrected))
36+
end
37+
38+
# This is the function that does the reduction underlying var/std
39+
function centralize_sumabs2!(R::AbstractArray{S}, A::SparseMatrixCSC{Tv,Ti}, means::AbstractArray) where {S,Tv,Ti}
40+
require_one_based_indexing(R, A, means)
41+
lsiz = Base.check_reducedims(R,A)
42+
for i in 1:max(ndims(R), ndims(means))
43+
if axes(means, i) != axes(R, i)
44+
throw(DimensionMismatch("dimension $i of `mean` should have indices $(axes(R, i)), but got $(axes(means, i))"))
45+
end
46+
end
47+
isempty(R) || fill!(R, zero(S))
48+
isempty(A) && return R
49+
50+
rowval = rowvals(A)
51+
nzval = nonzeros(A)
52+
m = size(A, 1)
53+
n = size(A, 2)
54+
55+
if size(R, 1) == size(R, 2) == 1
56+
# Reduction along both columns and rows
57+
R[1, 1] = centralize_sumabs2(A, means[1])
58+
elseif size(R, 1) == 1
59+
# Reduction along rows
60+
@inbounds for col = 1:n
61+
mu = means[col]
62+
r = convert(S, (m - length(nzrange(A, col)))*abs2(mu))
63+
@simd for j = nzrange(A, col)
64+
r += abs2(nzval[j] - mu)
65+
end
66+
R[1, col] = r
67+
end
68+
elseif size(R, 2) == 1
69+
# Reduction along columns
70+
rownz = fill(convert(Ti, n), m)
71+
@inbounds for col = 1:n
72+
@simd for j = nzrange(A, col)
73+
row = rowval[j]
74+
R[row, 1] += abs2(nzval[j] - means[row])
75+
rownz[row] -= 1
76+
end
77+
end
78+
for i = 1:m
79+
R[i, 1] += rownz[i]*abs2(means[i])
80+
end
81+
else
82+
# Reduction along a dimension > 2
83+
@inbounds for col = 1:n
84+
lastrow = 0
85+
@simd for j = nzrange(A, col)
86+
row = rowval[j]
87+
for i = lastrow+1:row-1
88+
R[i, col] = abs2(means[i, col])
89+
end
90+
R[row, col] = abs2(nzval[j] - means[row, col])
91+
lastrow = row
92+
end
93+
for i = lastrow+1:m
94+
R[i, col] = abs2(means[i, col])
95+
end
96+
end
97+
end
98+
return R
99+
end
100+
101+
end # module

src/Statistics.jl

+4-89
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Standard library module for basic statistics functionality.
77
"""
88
module Statistics
99

10-
using LinearAlgebra, SparseArrays
10+
using LinearAlgebra
1111

1212
using Base: has_offset_axes, require_one_based_indexing
1313

@@ -1095,94 +1095,9 @@ quantile(itr, p; sorted::Bool=false, alpha::Real=1.0, beta::Real=alpha) =
10951095
quantile(v::AbstractVector, p; sorted::Bool=false, alpha::Real=1.0, beta::Real=alpha) =
10961096
quantile!(sorted ? v : Base.copymutable(v), p; sorted=sorted, alpha=alpha, beta=beta)
10971097

1098-
1099-
##### SparseArrays optimizations #####
1100-
1101-
function cov(X::SparseMatrixCSC; dims::Int=1, corrected::Bool=true)
1102-
vardim = dims
1103-
a, b = size(X)
1104-
n, p = vardim == 1 ? (a, b) : (b, a)
1105-
1106-
# The covariance can be decomposed into two terms
1107-
# 1/(n - 1) ∑ (x_i - x̄)*(x_i - x̄)' = 1/(n - 1) (∑ x_i*x_i' - n*x̄*x̄')
1108-
# which can be evaluated via a sparse matrix-matrix product
1109-
1110-
# Compute ∑ x_i*x_i' = X'X using sparse matrix-matrix product
1111-
out = Matrix(unscaled_covzm(X, vardim))
1112-
1113-
# Compute x̄
1114-
x̄ᵀ = mean(X, dims=vardim)
1115-
1116-
# Subtract n*x̄*x̄' from X'X
1117-
@inbounds for j in 1:p, i in 1:p
1118-
out[i,j] -= x̄ᵀ[i] * x̄ᵀ[j]' * n
1119-
end
1120-
1121-
# scale with the sample size n or the corrected sample size n - 1
1122-
return rmul!(out, inv(n - corrected))
1123-
end
1124-
1125-
# This is the function that does the reduction underlying var/std
1126-
function centralize_sumabs2!(R::AbstractArray{S}, A::SparseMatrixCSC{Tv,Ti}, means::AbstractArray) where {S,Tv,Ti}
1127-
require_one_based_indexing(R, A, means)
1128-
lsiz = Base.check_reducedims(R,A)
1129-
for i in 1:max(ndims(R), ndims(means))
1130-
if axes(means, i) != axes(R, i)
1131-
throw(DimensionMismatch("dimension $i of `mean` should have indices $(axes(R, i)), but got $(axes(means, i))"))
1132-
end
1133-
end
1134-
isempty(R) || fill!(R, zero(S))
1135-
isempty(A) && return R
1136-
1137-
rowval = rowvals(A)
1138-
nzval = nonzeros(A)
1139-
m = size(A, 1)
1140-
n = size(A, 2)
1141-
1142-
if size(R, 1) == size(R, 2) == 1
1143-
# Reduction along both columns and rows
1144-
R[1, 1] = centralize_sumabs2(A, means[1])
1145-
elseif size(R, 1) == 1
1146-
# Reduction along rows
1147-
@inbounds for col = 1:n
1148-
mu = means[col]
1149-
r = convert(S, (m - length(nzrange(A, col)))*abs2(mu))
1150-
@simd for j = nzrange(A, col)
1151-
r += abs2(nzval[j] - mu)
1152-
end
1153-
R[1, col] = r
1154-
end
1155-
elseif size(R, 2) == 1
1156-
# Reduction along columns
1157-
rownz = fill(convert(Ti, n), m)
1158-
@inbounds for col = 1:n
1159-
@simd for j = nzrange(A, col)
1160-
row = rowval[j]
1161-
R[row, 1] += abs2(nzval[j] - means[row])
1162-
rownz[row] -= 1
1163-
end
1164-
end
1165-
for i = 1:m
1166-
R[i, 1] += rownz[i]*abs2(means[i])
1167-
end
1168-
else
1169-
# Reduction along a dimension > 2
1170-
@inbounds for col = 1:n
1171-
lastrow = 0
1172-
@simd for j = nzrange(A, col)
1173-
row = rowval[j]
1174-
for i = lastrow+1:row-1
1175-
R[i, col] = abs2(means[i, col])
1176-
end
1177-
R[row, col] = abs2(nzval[j] - means[row, col])
1178-
lastrow = row
1179-
end
1180-
for i = lastrow+1:m
1181-
R[i, col] = abs2(means[i, col])
1182-
end
1183-
end
1184-
end
1185-
return R
1098+
# If package extensions are not supported in this Julia version
1099+
if !isdefined(Base, :get_extension)
1100+
include("../ext/SparseArraysExt.jl")
11861101
end
11871102

11881103
end # module

0 commit comments

Comments
 (0)