Skip to content

Commit e49ec67

Browse files
authored
Merge pull request #68 from Tokazama/master
Non-mutating versions of pop, popfirst, etc. (#66)
2 parents 85e93de + bc3447d commit e49ec67

File tree

4 files changed

+169
-62
lines changed

4 files changed

+169
-62
lines changed

src/ArrayInterface.jl

+98-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Requires
44
using LinearAlgebra
55
using SparseArrays
66

7-
using Base: OneTo
7+
using Base: OneTo, @propagate_inbounds
88

99
Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
1010
parameterless_type(x) = parameterless_type(typeof(x))
@@ -543,6 +543,103 @@ function restructure(x::Array,y)
543543
reshape(convert(Array,y),size(x)...)
544544
end
545545

546+
"""
547+
insert(collection, index, item)
548+
549+
Return a new instance of `collection` with `item` inserted into at the given `index`.
550+
"""
551+
Base.@propagate_inbounds function insert(collection, index, item)
552+
@boundscheck checkbounds(collection, index)
553+
ret = similar(collection, length(collection) + 1)
554+
@inbounds for i in firstindex(ret):(index - 1)
555+
ret[i] = collection[i]
556+
end
557+
@inbounds ret[index] = item
558+
@inbounds for i in (index + 1):lastindex(ret)
559+
ret[i] = collection[i - 1]
560+
end
561+
return ret
562+
end
563+
564+
function insert(x::Tuple, index::Integer, item)
565+
@boundscheck if !checkindex(Bool, static_first(x):static_last(x), index)
566+
throw(BoundsError(x, index))
567+
end
568+
return unsafe_insert(x, Int(index), item)
569+
end
570+
571+
@inline function unsafe_insert(x::Tuple, i::Int, item)
572+
if i === 1
573+
return (item, x...)
574+
else
575+
return (first(x), unsafe_insert(Base.tail(x), i - 1, item)...)
576+
end
577+
end
578+
579+
"""
580+
deleteat(collection, index)
581+
582+
Return a new instance of `collection` with the item at the given `index` removed.
583+
"""
584+
@propagate_inbounds function deleteat(collection::AbstractVector, index)
585+
@boundscheck if !checkindex(Bool, eachindex(collection), index)
586+
throw(BoundsError(collection, index))
587+
end
588+
return unsafe_deleteat(collection, index)
589+
end
590+
@propagate_inbounds function deleteat(collection::Tuple, index)
591+
@boundscheck if !checkindex(Bool, static_first(collection):static_last(collection), index)
592+
throw(BoundsError(collection, index))
593+
end
594+
return unsafe_deleteat(collection, index)
595+
end
596+
597+
function unsafe_deleteat(src::AbstractVector, index::Integer)
598+
dst = similar(src, length(src) - 1)
599+
@inbounds for i in indices(dst)
600+
if i < index
601+
dst[i] = src[i]
602+
else
603+
dst[i] = src[i + 1]
604+
end
605+
end
606+
return dst
607+
end
608+
609+
@inline function unsafe_deleteat(src::AbstractVector, inds::AbstractVector)
610+
dst = similar(src, length(src) - length(inds))
611+
dst_index = firstindex(dst)
612+
@inbounds for src_index in indices(src)
613+
if !in(src_index, inds)
614+
dst[dst_index] = src[src_index]
615+
dst_index += one(dst_index)
616+
end
617+
end
618+
return dst
619+
end
620+
621+
@inline function unsafe_deleteat(src::Tuple, inds::AbstractVector)
622+
dst = Vector{eltype(src)}(undef, length(src) - length(inds))
623+
dst_index = firstindex(dst)
624+
@inbounds for src_index in OneTo(length(src))
625+
if !in(src_index, inds)
626+
dst[dst_index] = src[src_index]
627+
dst_index += one(dst_index)
628+
end
629+
end
630+
return Tuple(dst)
631+
end
632+
633+
@inline function unsafe_deleteat(x::Tuple, i::Integer)
634+
if i === one(i)
635+
return Base.tail(x)
636+
elseif i == length(x)
637+
return Base.front(x)
638+
else
639+
return (first(x), unsafe_deleteat(Base.tail(x), i - one(i))...)
640+
end
641+
end
642+
546643
function __init__()
547644

548645
@require SuiteSparse="4607b0f0-06f3-5cda-b6b1-a6196a1729e9" begin

src/ranges.jl

+21-40
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,6 @@ known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)
4242

4343
# add methods to support ArrayInterface
4444

45-
_get(x) = x
46-
_get(::Static{V}) where {V} = V
47-
_get(::Type{Static{V}}) where {V} = V
48-
_convert(::Type{T}, x) where {T} = convert(T, x)
49-
_convert(::Type{T}, ::Val{V}) where {T,V} = Val(convert(T, V))
50-
5145
"""
5246
OptionallyStaticUnitRange{T<:Integer}(start, stop) <: OrdinalRange{T,T}
5347
@@ -57,28 +51,23 @@ at compile time. An `OptionallyStaticUnitRange` is intended to be constructed in
5751
from other valid indices. Therefore, users should not expect the same checks are used
5852
to ensure construction of a valid `OptionallyStaticUnitRange` as a `UnitRange`.
5953
"""
60-
struct OptionallyStaticUnitRange{T <: Integer, F <: Integer, L <: Integer} <: AbstractUnitRange{T}
54+
struct OptionallyStaticUnitRange{F <: Integer, L <: Integer} <: AbstractUnitRange{Int}
6155
start::F
6256
stop::L
6357

64-
function OptionallyStaticUnitRange{T}(start, stop) where {T<:Real}
65-
if _get(start) isa T
66-
if _get(stop) isa T
67-
return new{T,typeof(start),typeof(stop)}(start, stop)
58+
function OptionallyStaticUnitRange(start, stop)
59+
if eltype(start) <: Int
60+
if eltype(stop) <: Int
61+
return new{typeof(start),typeof(stop)}(start, stop)
6862
else
69-
return OptionallyStaticUnitRange{T}(start, _convert(T, stop))
63+
return OptionallyStaticUnitRange(start, Int(stop))
7064
end
7165
else
72-
return OptionallyStaticUnitRange{T}(_convert(T, start), stop)
66+
return OptionallyStaticUnitRange(Int(start), stop)
7367
end
7468
end
7569

76-
function OptionallyStaticUnitRange(start, stop)
77-
T = promote_type(typeof(_get(start)), typeof(_get(stop)))
78-
return OptionallyStaticUnitRange{T}(start, stop)
79-
end
80-
81-
function OptionallyStaticUnitRange(x::AbstractRange)
70+
function OptionallyStaticUnitRange(x::AbstractRange)
8271
if step(x) == 1
8372
fst = static_first(x)
8473
lst = static_last(x)
@@ -94,12 +83,12 @@ Base.:(:)(::Static{L}, U::Integer) where {L} = OptionallyStaticUnitRange(Static(
9483
Base.:(:)(::Static{L}, ::Static{U}) where {L,U} = OptionallyStaticUnitRange(Static(L), Static(U))
9584

9685
Base.first(r::OptionallyStaticUnitRange) = r.start
97-
Base.step(r::OptionallyStaticUnitRange{T}) where {T} = oneunit(T)
86+
Base.step(::OptionallyStaticUnitRange) = Static(1)
9887
Base.last(r::OptionallyStaticUnitRange) = r.stop
9988

100-
known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Static{F}}}) where {F} = F
101-
known_step(::Type{<:OptionallyStaticUnitRange{T}}) where {T} = one(T)
102-
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Static{L}}}) where {L} = L
89+
known_first(::Type{<:OptionallyStaticUnitRange{Static{F}}}) where {F} = F
90+
known_step(::Type{<:OptionallyStaticUnitRange}) = 1
91+
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,Static{L}}}) where {L} = L
10392

10493
function Base.isempty(r::OptionallyStaticUnitRange)
10594
if known_first(r) === oneunit(eltype(r))
@@ -112,10 +101,8 @@ end
112101
unsafe_isempty_one_to(lst) = lst <= zero(lst)
113102
unsafe_isempty_unit_range(fst, lst) = fst > lst
114103

115-
unsafe_isempty_unit_range(fst::T, lst::T) where {T} = Integer(lst - fst + one(T))
116-
117-
unsafe_length_one_to(lst::T) where {T<:Int} = T(lst)
118-
unsafe_length_one_to(lst::T) where {T} = Integer(lst - zero(lst))
104+
unsafe_length_one_to(lst::Int) = lst
105+
unsafe_length_one_to(::Static{L}) where {L} = lst
119106

120107
Base.@propagate_inbounds function Base.getindex(r::OptionallyStaticUnitRange, i::Integer)
121108
if known_first(r) === oneunit(r)
@@ -144,15 +131,15 @@ end
144131
@inline _try_static(::Static{M}, ::Static{N}) where {M, N} = @assert false "Unequal Indices: Static{$M}() != Static{$N}()"
145132
function _try_static(::Static{N}, x) where {N}
146133
@assert N == x "Unequal Indices: Static{$N}() != x == $x"
147-
Static{N}()
134+
return Static{N}()
148135
end
149136
function _try_static(x, ::Static{N}) where {N}
150137
@assert N == x "Unequal Indices: x == $x != Static{$N}()"
151-
Static{N}()
138+
return Static{N}()
152139
end
153140
function _try_static(x, y)
154141
@assert x == y "Unequal Indicess: x == $x != $y == y"
155-
x
142+
return x
156143
end
157144

158145
###
@@ -172,24 +159,19 @@ end
172159
end
173160
end
174161

175-
function Base.length(r::OptionallyStaticUnitRange{T}) where {T}
162+
function Base.length(r::OptionallyStaticUnitRange)
176163
if isempty(r)
177-
return zero(T)
164+
return 0
178165
else
179-
if known_one(r) === one(T)
166+
if known_first(r) === 0
180167
return unsafe_length_one_to(last(r))
181168
else
182169
return unsafe_length_unit_range(first(r), last(r))
183170
end
184171
end
185172
end
186173

187-
function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{Int,Int64,Int128}}
188-
return Base.checked_add(Base.checked_sub(lst, fst), one(T))
189-
end
190-
function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{UInt,UInt64,UInt128}}
191-
return Base.checked_add(lst - fst, one(T))
192-
end
174+
unsafe_length_unit_range(start::Integer, stop::Integer) = Int(start - stop + 1)
193175

194176
"""
195177
indices(x[, d])
@@ -231,4 +213,3 @@ end
231213
lst = _try_static(static_last(x), static_last(y))
232214
return Base.Slice(OptionallyStaticUnitRange(fst, lst))
233215
end
234-

src/static.jl

+28-21
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ Use `Static(N)` instead of `Val(N)` when you want it to behave like a number.
66
struct Static{N} <: Integer
77
Static{N}() where {N} = new{N::Int}()
88
end
9+
10+
const Zero = Static{0}
11+
const One = Static{1}
12+
913
Base.@pure Static(N::Int) = Static{N}()
1014
Static(N::Integer) = Static(convert(Int, N))
1115
Static(::Static{N}) where {N} = Static{N}()
@@ -33,41 +37,44 @@ end
3337
Base.promote_rule(::Type{<:Static}, ::Type{<:Static}) = Int
3438
Base.:(%)(::Static{N}, ::Type{Integer}) where {N} = N
3539

36-
Base.iszero(::Static{0}) = true
40+
Base.eltype(::Type{T}) where {T<:Static} = Int
41+
Base.iszero(::Zero) = true
3742
Base.iszero(::Static) = false
38-
Base.isone(::Static{1}) = true
43+
Base.isone(::One) = true
3944
Base.isone(::Static) = false
45+
Base.zero(::Type{T}) where {T<:Static} = Zero()
46+
Base.one(::Type{T}) where {T<:Static} = One()
4047

4148
for T = [:Real, :Rational, :Integer]
4249
@eval begin
43-
@inline Base.:(+)(i::$T, ::Static{0}) = i
50+
@inline Base.:(+)(i::$T, ::Zero) = i
4451
@inline Base.:(+)(i::$T, ::Static{M}) where {M} = i + M
45-
@inline Base.:(+)(::Static{0}, i::$T) = i
52+
@inline Base.:(+)(::Zero, i::$T) = i
4653
@inline Base.:(+)(::Static{M}, i::$T) where {M} = M + i
47-
@inline Base.:(-)(i::$T, ::Static{0}) = i
54+
@inline Base.:(-)(i::$T, ::Zero) = i
4855
@inline Base.:(-)(i::$T, ::Static{M}) where {M} = i - M
49-
@inline Base.:(*)(i::$T, ::Static{0}) = Static{0}()
50-
@inline Base.:(*)(i::$T, ::Static{1}) = i
56+
@inline Base.:(*)(i::$T, ::Zero) = Zero()
57+
@inline Base.:(*)(i::$T, ::One) = i
5158
@inline Base.:(*)(i::$T, ::Static{M}) where {M} = i * M
52-
@inline Base.:(*)(::Static{0}, i::$T) = Static{0}()
53-
@inline Base.:(*)(::Static{1}, i::$T) = i
59+
@inline Base.:(*)(::Zero, i::$T) = Zero()
60+
@inline Base.:(*)(::One, i::$T) = i
5461
@inline Base.:(*)(::Static{M}, i::$T) where {M} = M * i
5562
end
5663
end
57-
@inline Base.:(+)(::Static{0}, ::Static{0}) = Static{0}()
58-
@inline Base.:(+)(::Static{0}, ::Static{M}) where {M} = Static{M}()
59-
@inline Base.:(+)(::Static{M}, ::Static{0}) where {M} = Static{M}()
64+
@inline Base.:(+)(::Zero, ::Zero) = Zero()
65+
@inline Base.:(+)(::Zero, ::Static{M}) where {M} = Static{M}()
66+
@inline Base.:(+)(::Static{M}, ::Zero) where {M} = Static{M}()
6067

61-
@inline Base.:(-)(::Static{M}, ::Static{0}) where {M} = Static{M}()
68+
@inline Base.:(-)(::Static{M}, ::Zero) where {M} = Static{M}()
6269

63-
@inline Base.:(*)(::Static{0}, ::Static{0}) = Static{0}()
64-
@inline Base.:(*)(::Static{1}, ::Static{0}) = Static{0}()
65-
@inline Base.:(*)(::Static{0}, ::Static{1}) = Static{0}()
66-
@inline Base.:(*)(::Static{1}, ::Static{1}) = Static{1}()
67-
@inline Base.:(*)(::Static{M}, ::Static{0}) where {M} = Static{0}()
68-
@inline Base.:(*)(::Static{0}, ::Static{M}) where {M} = Static{0}()
69-
@inline Base.:(*)(::Static{M}, ::Static{1}) where {M} = Static{M}()
70-
@inline Base.:(*)(::Static{1}, ::Static{M}) where {M} = Static{M}()
70+
@inline Base.:(*)(::Zero, ::Zero) = Zero()
71+
@inline Base.:(*)(::One, ::Zero) = Zero()
72+
@inline Base.:(*)(::Zero, ::One) = Zero()
73+
@inline Base.:(*)(::One, ::One) = One()
74+
@inline Base.:(*)(::Static{M}, ::Zero) where {M} = Zero()
75+
@inline Base.:(*)(::Zero, ::Static{M}) where {M} = Zero()
76+
@inline Base.:(*)(::Static{M}, ::One) where {M} = Static{M}()
77+
@inline Base.:(*)(::One, ::Static{M}) where {M} = Static{M}()
7178
for f [:(+), :(-), :(*), :(/), :(÷), :(%), :(<<), :(>>), :(>>>), :(&), :(|), :()]
7279
@eval @generated Base.$f(::Static{M}, ::Static{N}) where {M,N} = Expr(:call, Expr(:curly, :Static, $f(M, N)))
7380
end

test/runtests.jl

+22
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,9 @@ end
252252
@testset "Static" begin
253253
@test iszero(Static(0))
254254
@test !iszero(Static(1))
255+
@test @inferred(one(Static)) === Static(1)
256+
@test @inferred(zero(Static)) === Static(0)
257+
@test eltype(one(Static)) <: Int
255258
# test for ambiguities and correctness
256259
for i [Static(0), Static(1), Static(2), 3]
257260
for j [Static(0), Static(1), Static(2), 3]
@@ -271,3 +274,22 @@ end
271274
end
272275
end
273276

277+
@testset "insert/deleteat" begin
278+
@test @inferred(ArrayInterface.insert([1,2,3], 2, -2)) == [1, -2, 2, 3]
279+
@test @inferred(ArrayInterface.deleteat([1, 2, 3], 2)) == [1, 3]
280+
281+
@test @inferred(ArrayInterface.deleteat([1, 2, 3], [1, 2])) == [3]
282+
@test @inferred(ArrayInterface.deleteat([1, 2, 3], [1, 3])) == [2]
283+
@test @inferred(ArrayInterface.deleteat([1, 2, 3], [2, 3])) == [1]
284+
285+
286+
@test @inferred(ArrayInterface.insert((1,2,3), 1, -2)) == (-2, 1, 2, 3)
287+
@test @inferred(ArrayInterface.insert((1,2,3), 2, -2)) == (1, -2, 2, 3)
288+
@test @inferred(ArrayInterface.insert((1,2,3), 3, -2)) == (1, 2, -2, 3)
289+
290+
@test @inferred(ArrayInterface.deleteat((1, 2, 3), 1)) == (2, 3)
291+
@test @inferred(ArrayInterface.deleteat((1, 2, 3), 2)) == (1, 3)
292+
@test @inferred(ArrayInterface.deleteat((1, 2, 3), 3)) == (1, 2)
293+
@test ArrayInterface.deleteat((1, 2, 3), [1, 2]) == (3,)
294+
end
295+

0 commit comments

Comments
 (0)