Skip to content

Commit 315964a

Browse files
author
Andy Ferris
committed
Make strides into a generic trait
Returns `nothing` for non-strided arrays, otherwise gives the give strides in memory. Useful as an extensible trait in generic contexts, and simpler to overload for cases of "wrapped" arrays where "stridedness" can be deferred to the parent rather than a complex (and inextensible) method signature.
1 parent a0b7a76 commit 315964a

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed

base/abstractarray.jl

+12-3
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,9 @@ last(a) = a[end]
311311
"""
312312
strides(A)
313313
314-
Return a tuple of the memory strides in each dimension.
314+
Return a tuple of the memory strides in each dimension, for an `AbstractArray` with a
315+
strided memory layout. For arrays with a non-strided layout (such as sparse arrays), return
316+
`nothing`.
315317
316318
# Examples
317319
```jldoctest
@@ -321,7 +323,7 @@ julia> strides(A)
321323
(1, 3, 12)
322324
```
323325
"""
324-
function strides end
326+
strides(::AbstractArray) = nothing
325327

326328
"""
327329
stride(A, k::Integer)
@@ -339,7 +341,14 @@ julia> stride(A,3)
339341
12
340342
```
341343
"""
342-
stride(A::AbstractArray, k::Integer) = strides(A)[k]
344+
function stride(A::AbstractArray, k::Integer)
345+
str = strides(A)
346+
if str === nothing
347+
return nothing
348+
else
349+
return str[k]
350+
end
351+
end
343352

344353
@inline size_to_strides(s, d, sz...) = (s, size_to_strides(s * d, sz...)...)
345354
size_to_strides(s, d) = (s,)

base/permuteddimsarray.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,11 @@ Base.pointer(A::PermutedDimsArray, i::Integer) = throw(ArgumentError("pointer(A,
6060

6161
function Base.strides(A::PermutedDimsArray{T,N,perm}) where {T,N,perm}
6262
s = strides(parent(A))
63-
ntuple(d->s[perm[d]], Val(N))
63+
if s === nothing
64+
return nothing
65+
else
66+
return ntuple(d->s[perm[d]], Val(N))
67+
end
6468
end
6569

6670
@inline function Base.getindex(A::PermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}) where {T,N,perm,iperm}

stdlib/LinearAlgebra/src/adjtrans.jl

+19
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,25 @@ vec(v::TransposeAbsVec) = parent(v)
155155
cmp(A::AdjOrTransAbsVec, B::AdjOrTransAbsVec) = cmp(parent(A), parent(B))
156156
isless(A::AdjOrTransAbsVec, B::AdjOrTransAbsVec) = isless(parent(A), parent(B))
157157

158+
# provide strides, but only for eltypes that are directly stored in memory (i.e. unaffected
159+
# by recursive `adjoint` and `transpose`, being `Real` and `Number` respectively)
160+
function Base.strides(a::Union{Adjoint{<:Real, <:AbstractVector}, Transpose{<:Number, <:AbstractVector}})
161+
str = strides(a.parent)
162+
if str === nothing
163+
return nothing
164+
else
165+
return (1, str[1])
166+
end
167+
end
168+
function Base.strides(a::Union{Adjoint{<:Real, <:AbstractMatrix}, Transpose{<:Number, <:AbstractMatrix}})
169+
str = strides(a.parent)
170+
if str === nothing
171+
return nothing
172+
else
173+
return (str[2], str[1])
174+
end
175+
end
176+
158177
### concatenation
159178
# preserve Adjoint/Transpose wrapper around vectors
160179
# to retain the associated semantics post-concatenation

0 commit comments

Comments
 (0)