Skip to content

Commit d7b00dd

Browse files
committedDec 28, 2013
implement new array assignment shape matching rule. fixes #4048, fixes #4383
this rule ignores singleton dimensions, and allows the last dimension of one side to match all trailing dimensions of the other.
1 parent 851b2c8 commit d7b00dd

File tree

2 files changed

+73
-24
lines changed

2 files changed

+73
-24
lines changed
 

‎base/array.jl

+1-14
Original file line numberDiff line numberDiff line change
@@ -593,20 +593,7 @@ function setindex!(A::Array, x, I::Union(Real,AbstractArray)...)
593593
assign_cache = Dict()
594594
end
595595
X = x
596-
nel = 1
597-
for idx in I
598-
nel *= length(idx)
599-
end
600-
if length(X) != nel
601-
throw(DimensionMismatch(""))
602-
end
603-
if ndims(X) > 1
604-
for i = 1:length(I)
605-
if size(X,i) != length(I[i])
606-
throw(DimensionMismatch(""))
607-
end
608-
end
609-
end
596+
setindex_shape_check(X, I...)
610597
gen_array_index_map(assign_cache, storeind -> quote
611598
A[$storeind] = X[refind]
612599
refind += 1

‎base/operators.jl

+72-10
Original file line numberDiff line numberDiff line change
@@ -206,21 +206,83 @@ index_shape(I::Real...) = ()
206206
index_shape(i, I...) = tuple(length(i), index_shape(I...)...)
207207

208208
# check for valid sizes in A[I...] = X where X <: AbstractArray
209+
# we want to allow dimensions that are equal up to permutation, but only
210+
# for permutations that leave array elements in the same linear order.
211+
# those are the permutations that preserve the order of the non-singleton
212+
# dimensions.
209213
function setindex_shape_check(X::AbstractArray, I...)
214+
li = length(I)
215+
ii = 1
210216
nel = 1
211-
for idx in I
212-
nel *= length(idx)
213-
end
214-
if length(X) != nel
215-
error("dimensions must match")
216-
end
217-
if ndims(X) > 1
218-
for i = 1:length(I)
219-
if size(X,i) != length(I[i])
220-
error("dimensions must match")
217+
xi = 1
218+
ndx = ndims(X)
219+
match = true
220+
while ii < li
221+
lii = length(I[ii])::Int
222+
ii += 1
223+
if lii != 1
224+
nel *= lii
225+
local lxi
226+
while true
227+
lxi = size(X,xi)
228+
xi += 1
229+
if lxi != 1 || xi > ndx
230+
break
231+
end
232+
end
233+
if xi > ndx
234+
trailing = lii
235+
while ii <= li
236+
lii = length(I[ii])::Int
237+
trailing *= lii
238+
ii += 1
239+
end
240+
# X's last dimension can match all the trailing indexes
241+
if lxi == trailing && match
242+
return
243+
else
244+
throw(DimensionMismatch(""))
245+
end
246+
else
247+
if lxi != lii
248+
match = false
249+
end
221250
end
222251
end
223252
end
253+
254+
# last index can match X's trailing dimensions
255+
lii = length(I[ii])::Int
256+
nel *= lii
257+
if lii != trailingsize(X,xi)
258+
match = false
259+
end
260+
261+
if !(match && length(X)==nel)
262+
throw(DimensionMismatch(""))
263+
end
264+
end
265+
266+
setindex_shape_check(X::AbstractArray) = (length(X)==1 || throw(DimensionMismatch("")))
267+
268+
setindex_shape_check(X::AbstractArray, i) =
269+
(length(X)==length(i) || throw(DimensionMismatch("")))
270+
271+
setindex_shape_check{T}(X::AbstractArray{T,1}, i) =
272+
(length(X)==length(i) || throw(DimensionMismatch("")))
273+
274+
setindex_shape_check{T}(X::AbstractArray{T,1}, i, j) =
275+
(length(X)==length(i)*length(j) || throw(DimensionMismatch("")))
276+
277+
function setindex_shape_check{T}(X::AbstractArray{T,2}, i, j)
278+
li, lj = length(i), length(j)
279+
if length(X) != li*lj
280+
throw(DimensionMismatch(""))
281+
end
282+
sx1 = size(X,1)
283+
if !(li == 1 || li == sx1 || sx1 == 1)
284+
throw(DimensionMismatch(""))
285+
end
224286
end
225287

226288
# convert to integer index

0 commit comments

Comments
 (0)
Please sign in to comment.