Skip to content

Commit 2730b5a

Browse files
committed
Fix #2644 in release-0.1
1 parent a703335 commit 2730b5a

File tree

2 files changed

+69
-31
lines changed

2 files changed

+69
-31
lines changed

base/subarray.jl

+47-31
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ function sub(A::SubArray, i::RangeIndex...)
7575
j += 1
7676
end
7777
end
78-
sub(A.parent, tuple(newindexes...))
78+
ni = tuple(newindexes...)
79+
SubArray{eltype(A),L,typeof(A.parent),typeof(ni)}(A.parent, ni)
7980
end
8081

8182
function slice{T,N}(A::AbstractArray{T,N}, i::NTuple{N,RangeIndex})
@@ -188,33 +189,63 @@ function ref{T,S<:Integer}(s::SubArray{T,1}, I::AbstractVector{S})
188189
ref(s.parent, t)
189190
end
190191

192+
function translate_indexes(s::SubArray, I::Union(Real,AbstractArray)...)
193+
I = indices(I)
194+
nds = ndims(s)
195+
n = length(I)
196+
if n > nds
197+
throw(BoundsError())
198+
end
199+
ndp = ndims(s.parent) - (nds-n)
200+
newindexes = Array(Any, ndp)
201+
sp = strides(s.parent)
202+
j = 1
203+
for i = 1:ndp
204+
t = s.indexes[i]
205+
if s.strides[j] == sp[i]
206+
#TODO: don't generate the dense vector indexes if they can be ranges
207+
if j==n && n < nds
208+
newindexes[i] = translate_linear_indexes(s, j, I[j])
209+
else
210+
newindexes[i] = isa(t, Int) ? t : t[I[j]]
211+
end
212+
j += 1
213+
else
214+
newindexes[i] = t
215+
end
216+
end
217+
newindexes
218+
end
219+
191220
# translate a linear index vector I for dim n to a linear index vector for
192221
# the parent array
193222
function translate_linear_indexes(s, n, I)
194223
idx = Array(Int, length(I))
195224
ssztail = size(s)[n:]
196-
psztail = size(s.parent)[n:]
225+
pdims = parentdims(s)
226+
psztail = size(s.parent)[pdims[n:]]
197227
for j=1:length(I)
198228
su = ind2sub(ssztail,I[j])
199-
idx[j] = sub2ind(psztail, [ s.indexes[n+k-1][su[k]] for k=1:length(su) ]...)
229+
idx[j] = sub2ind(psztail, [ s.indexes[pdims[n+k-1]][su[k]] for k=1:length(su) ]...)
200230
end
201231
idx
202232
end
203233

204-
function ref(s::SubArray, I::Union(Real,AbstractArray)...)
205-
I = indices(I)
206-
ndp = ndims(s.parent)
207-
n = length(I)
208-
newindexes = Array(Any, n)
209-
for i = 1:n
210-
t = s.indexes[i]
211-
#TODO: don't generate the dense vector indexes if they can be ranges
212-
if i==n && n < ndp
213-
newindexes[i] = translate_linear_indexes(s, i, I[i])
214-
else
215-
newindexes[i] = isa(t, Int) ? t : t[I[i]]
234+
function parentdims(s::SubArray)
235+
dimindex = Array(Int, ndims(s))
236+
sp = strides(s.parent)
237+
j = 1
238+
for i = 1:ndims(s.parent)
239+
if sp[i] == s.strides[j]
240+
dimindex[j] = i
241+
j += 1
216242
end
217243
end
244+
dimindex
245+
end
246+
247+
function ref(s::SubArray, I::Union(Real,AbstractArray)...)
248+
newindexes = translate_indexes(s, I...)
218249

219250
rs = ref_shape(I...)
220251
result = ref(s.parent, newindexes...)
@@ -270,22 +301,7 @@ function assign{T,S<:Integer}(s::SubArray{T,1}, v, I::AbstractVector{S})
270301
end
271302

272303
function assign(s::SubArray, v, I::Union(Real,AbstractArray)...)
273-
I = indices(I)
274-
j = 1 #the jth dimension in subarray
275-
ndp = ndims(s.parent)
276-
n = length(I)
277-
newindexes = cell(n)
278-
for i = 1:n
279-
t = s.indexes[i]
280-
#TODO: don't generate the dense vector indexes if they can be ranges
281-
if i==n && n < ndp
282-
newindexes[i] = translate_linear_indexes(s, i, I[i])
283-
else
284-
newindexes[i] = isa(t, Int) ? t : t[I[j]]
285-
end
286-
j += 1
287-
end
288-
304+
newindexes = translate_indexes(s, I...)
289305
assign(s.parent, v, newindexes...)
290306
end
291307

test/arrayops.jl

+22
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,28 @@ b = [4, 6, 2, -7, 1]
9090
ind = findin(a, b)
9191
@test ind == [3,4]
9292

93+
# sub
94+
A = reshape(1:120, 3, 5, 8)
95+
sA = sub(A, 2, 1:5, 1:8)
96+
@test size(sA) == (1, 5, 8)
97+
@test_fails sA[2, 1:8]
98+
@test sA[1, 2, 1:8][:] == 5:15:120
99+
sA[2:5:end] = -1
100+
@test all(sA[2:5:end] .== -1)
101+
@test all(A[5:15:120] .== -1)
102+
103+
# slice
104+
A = reshape(1:120, 3, 5, 8)
105+
sA = slice(A, 2, 1:5, 1:8)
106+
@test size(sA) == (5, 8)
107+
@test sA[2, 1:8][:] == 5:15:120
108+
@test sA[:,1] == 2:3:14
109+
@test sA[2:5:end] == 5:15:120
110+
sA[2:5:end] = -1
111+
@test all(sA[2:5:end] .== -1)
112+
@test all(A[5:15:120] .== -1)
113+
114+
93115
# get
94116
let
95117
A = reshape(1:24, 3, 8)

0 commit comments

Comments
 (0)