The iternz API
iternz(x) returns an iterator over the structural non-zero elements of the array, that is, the elements that are not forced to be zero by the array's structure (as opposed to the elements whose value happens to be non-zero), i.e.
all(iternz(x)) do (v, k...)
x[k...] == v
end
The big idea is to abstract away all of the special loops needed to iterate over sparse containers. These include special linear algebra matrices like Diagonal and UpperTriangular, as well as SparseMatrixCSC. Furthermore, it can be used recursively: an iteration over a Diagonal{SparseVector} will skip the zero elements of the SparseVector (if they are not stored).
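The recursive case can be illustrated with the standard library alone. The following is a minimal sketch, not the package's implementation: the function name stored is hypothetical and only mimics the (v, i, j) tuples described above.

```julia
using LinearAlgebra, SparseArrays

# `stored` is a hypothetical stand-in for the idea behind iternz.
# Sparse vector: yield (value, index) for the stored entries only.
function stored(v::SparseVector)
    inds, vals = findnz(v)
    return zip(vals, inds)
end

# Diagonal: recurse into the wrapped vector and place each entry on the diagonal.
stored(D::Diagonal) = ((v, i, i) for (v, i) in stored(D.diag))

d = sparsevec([2, 5], [10.0, 20.0], 6)   # length 6, two stored entries
D = Diagonal(d)

entries = collect(stored(D))             # only the two stored diagonal entries
contract_holds = all(D[i, j] == v for (v, i, j) in entries)
```

Because the Diagonal method delegates to the method for its wrapped vector, the unstored entries of the sparse vector are never visited, which is the behaviour the paragraph above describes.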
As an example, let's sum the elements of a matrix at the positions where (i + j) % 7 == 0. The most general way of writing it is
using BenchmarkTools, SparseArrays
const n = 10_000
const A = sprandn(n, n, max(1000, 0.1 * n*n) / n / n);
function general(x::AbstractMatrix)
    s = zero(eltype(x))
    @inbounds for j in axes(x, 2), i in axes(x, 1)
        if (i + j) % 7 == 0
            s += x[i, j]
        end
    end
    return s
end
@benchmark general($A)
BenchmarkTools.Trial: 11 samples with 1 evaluation.
Range (min … max): 461.048 ms … 468.458 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 463.496 ms ┊ GC (median): 0.00%
Time (mean ± σ): 463.928 ms ± 2.200 ms ┊ GC (mean ± σ): 0.00% ± 0.00%
█
▇▁▇▁▁▁▁▁▁▁▇▁▁▁▇▇▁▁▁▁▇▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇ ▁
461 ms Histogram: frequency by time 468 ms <
Memory estimate: 0 bytes, allocs estimate: 0.
This is pretty slow. We can improve the performance by exploiting the sparse structure of the problem:
using SparseArrays: getcolptr, nonzeros, rowvals
function sparse_only(x::SparseMatrixCSC)
    s = zero(eltype(x))
    @inbounds for j in axes(x, 2),
                  ind in getcolptr(x)[j]:getcolptr(x)[j + 1] - 1
        i = rowvals(x)[ind]
        if (i + j) % 7 == 0
            s += nonzeros(x)[ind]
        end
    end
    return s
end
sparse_only (generic function with 1 method)
We can test for correctness
sparse_only(A) == general(A)
true
and benchmark the function
@benchmark sparse_only($A)
BenchmarkTools.Trial: 301 samples with 1 evaluation.
Range (min … max): 15.765 ms … 19.260 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 16.511 ms ┊ GC (median): 0.00%
Time (mean ± σ): 16.646 ms ± 566.611 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▁▃▅▄█▃▄ ▃ ▁▄ ▁
▃▁▃▃▄█████████▇██▆▅▇██▅▆▅█▇▅▃▆▅▃▆▅▆▄▃▄▁▁▃▁▄▃▅▃▁▃▁▁▃▃▁▁▁▁▁▃▁▃ ▄
15.8 ms Histogram: frequency by time 18.6 ms <
Memory estimate: 0 bytes, allocs estimate: 0.
We can see that, while writing the function requires understanding how CSC matrices are stored, the code is roughly 28x faster (a median of 463 ms versus 16.5 ms). The thing is that this pattern gets repeated everywhere, so we might try to abstract it away. My proposal is the iternz API.
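Part of the gap comes from how many entries each loop touches. As a rough illustration (using a smaller matrix than the benchmark, and the variable names B, dense_visits and sparse_visits, which are introduced only for this sketch), the generic loop visits every (i, j) pair while the sparse loop visits only the stored values:

```julia
using SparseArrays

n_small = 200                      # smaller than the benchmark matrix, to keep this quick
B = sprandn(n_small, n_small, 0.1) # same target density as A above

dense_visits  = length(B)          # entries touched by a loop over every (i, j)
sparse_visits = nnz(B)             # entries touched by a loop over the stored values

visit_ratio = dense_visits / max(sparse_visits, 1)
```

The visit count alone does not account for the whole difference: each x[i, j] lookup on a CSC matrix also has to search the row indices of column j, whereas the sparse loop reads the stored arrays directly.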
using SparseExtra
function iternz_only(x::AbstractMatrix)
    s = zero(eltype(x))
    for (v, i, j) in iternz(x)
        if (i + j) % 7 == 0
            s += v
        end
    end
    return s
end
iternz_only(A) == general(A)
true
@benchmark sparse_only($A)
BenchmarkTools.Trial: 266 samples with 1 evaluation.
Range (min … max): 15.932 ms … 32.230 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 17.630 ms ┊ GC (median): 0.00%
Time (mean ± σ): 18.806 ms ± 2.971 ms ┊ GC (mean ± σ): 0.00% ± 0.00%
▄█▅▂█▁▃
▅█████████▄▅▅▆▅▆▆▄▄▄▄▅▃▁▃▄▃▄▃▁▅▄▁▃▃▃▁▁▃▁▃▃▃▁▄▃▃▁▁▃▃▃▁▃▃▁▁▃▃ ▃
15.9 ms Histogram: frequency by time 28.3 ms <
Memory estimate: 0 bytes, allocs estimate: 0.
The speed is the same as the specialized version, but there is no @inbounds, no need for ugly loops, etc. As a bonus, it works on all of the specialized matrices:
using LinearAlgebra
all(iternz_only(i(A)) ≈ general(i(A)) for i in [Transpose, UpperTriangular, LowerTriangular, Diagonal, Symmetric]) # Symmetric changes the order of execution, so compare with ≈.
true
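The comparison above uses ≈ rather than == because floating-point addition is not associative, so summing the same values in a different order can change the last bits of the result. A small standalone illustration:

```julia
# The same three numbers, added in two different orders.
a = (0.1 + 0.2) + 0.3
b = 0.1 + (0.2 + 0.3)

exactly_equal = a == b   # false: the results differ in the last bit
approx_equal  = a ≈ b    # true: isapprox tolerates the rounding difference
```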
Since these wrappers are themselves written using the iternz interface, the code generalizes to cases where these special matrices are combined, removing the need to write these tedious specializations.
For instance, the three-argument dot can be written as
function iternz_dot(x::AbstractVector, A::AbstractMatrix, y::AbstractVector)
    (length(x), length(y)) == size(A) || throw(ArgumentError("bad shape"))
    acc = zero(promote_type(eltype(x), eltype(A), eltype(y)))
    @inbounds for (v, i, j) in iternz(A)
        acc += x[i] * v * y[j]
    end
    acc
end
const (x, y) = randn(n), randn(n);
const SA = Symmetric(A);
Correctness tests
dot(x, A, y) ≈ iternz_dot(x, A, y) && dot(x, SA, y) ≈ iternz_dot(x, SA, y)
true
Benchmarks
@benchmark dot($x, $A, $y)
BenchmarkTools.Trial: 500 samples with 1 evaluation.
Range (min … max): 9.153 ms … 16.797 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 9.794 ms ┊ GC (median): 0.00%
Time (mean ± σ): 9.998 ms ± 744.517 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▃ ▂▄█▃ ▁
██████▇██████▆█▅▄▄▅▄▄▄▄▄▄▅▄▄▄▃▄▃▄▄▃▃▃▁▃▁▁▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▂ ▄
9.15 ms Histogram: frequency by time 12.9 ms <
Memory estimate: 0 bytes, allocs estimate: 0.
@benchmark iternz_dot($x, $A, $y)
BenchmarkTools.Trial: 426 samples with 1 evaluation.
Range (min … max): 10.760 ms … 16.682 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 11.519 ms ┊ GC (median): 0.00%
Time (mean ± σ): 11.727 ms ± 741.895 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▅█▄▁▂▂
▅▇▇▆█▇▇████████▆█▄▆██▆▄▆▄▆▅▆▃▄▄▃▄▂▄▃▃▂▂▁▃▁▁▁▂▃▁▁▃▁▃▁▁▁▃▃▁▁▁▂ ▄
10.8 ms Histogram: frequency by time 14.5 ms <
Memory estimate: 0 bytes, allocs estimate: 0.
@benchmark dot($x, $SA, $y)
BenchmarkTools.Trial: 9 samples with 1 evaluation.
Range (min … max): 607.881 ms … 623.882 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 609.728 ms ┊ GC (median): 0.00%
Time (mean ± σ): 611.193 ms ± 4.900 ms ┊ GC (mean ± σ): 0.00% ± 0.00%
▁ █ ▁ ▁█ ▁ ▁
█▁▁█▁█▁██▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
608 ms Histogram: frequency by time 624 ms <
Memory estimate: 0 bytes, allocs estimate: 0.
@benchmark iternz_dot($x, $SA, $y)
BenchmarkTools.Trial: 440 samples with 1 evaluation.
Range (min … max): 10.576 ms … 17.101 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 11.173 ms ┊ GC (median): 0.00%
Time (mean ± σ): 11.371 ms ± 740.788 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▂▃ █▆▃▄▃
▆███████████▇▆▇▇▅▆▆▆▅▄▅▅▃▄▃▃▃▃▃▂▁▂▁▁▂▂▃▁▃▁▁▂▂▁▂▁▁▁▂▁▁▁▃▁▁▁▁▂ ▃
10.6 ms Histogram: frequency by time 14.4 ms <
Memory estimate: 0 bytes, allocs estimate: 0.
API:
The API is pretty simple: iternz(A) should return an iterable such that
all(A[ind...] == v for (v, ind...) in iternz(A))
If the matrix is a container for a different type, the inner iteration should be done via iternz. This repo provides the IterateNZ container, whose sole purpose is to hold the array so that Base.iterate can be overloaded. Additionally, matrices have the skip_col and skip_row_to functions defined. The idea is that, where meaningful, skip_col should return a state such that iterating from that state gives the first element of the next column; in the case of skip_row_to(cont, state, i), iterate should return (i, j), where j is the current column.
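The wrapper-plus-Base.iterate pattern described above can be sketched with the standard library alone. Everything named here is hypothetical and chosen for illustration: CSCNonzeros is not the package's IterateNZ, and next_column_state only mimics the idea behind skip_col; the package's actual types, state layout, and signatures may differ.

```julia
using SparseArrays

# Hypothetical wrapper used only for this sketch; it is not SparseExtra's IterateNZ.
struct CSCNonzeros{T<:SparseMatrixCSC}
    m::T
end

Base.length(it::CSCNonzeros) = nnz(it.m)

# State is (ind, j): the next position in the stored-value arrays and the
# column from which to start looking for the column that contains it.
function Base.iterate(it::CSCNonzeros, state=(1, 1))
    ind, j = state
    ind > nnz(it.m) && return nothing
    colptr = SparseArrays.getcolptr(it.m)
    while ind >= colptr[j + 1]   # step over columns (possibly empty) that end before ind
        j += 1
    end
    return (nonzeros(it.m)[ind], rowvals(it.m)[ind], j), (ind + 1, j)
end

# Hypothetical analogue of the skip_col idea: a state pointing at the start of the next column.
next_column_state(it::CSCNonzeros, (ind, j)) = (Int(SparseArrays.getcolptr(it.m)[j + 1]), j + 1)

M = sparse([1, 3, 2], [1, 1, 3], [1.0, 2.0, 3.0], 3, 3)   # column 2 is empty
it = CSCNonzeros(M)
contract_holds = all(M[i, j] == v for (v, i, j) in it)

first_entry, state = iterate(it)                           # first stored entry of column 1
skipped_entry, _ = iterate(it, next_column_state(it, state))  # jumps past the rest of column 1
```

Keeping the column index in the state means each call to iterate only moves forward, and skipping a column reduces to reading one column pointer instead of stepping through the remaining entries.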
TODO
- test with non-one-based indexing
This page was generated using Literate.jl.