@@ -206,21 +206,83 @@ index_shape(I::Real...) = ()
206
206
index_shape (i, I... ) = tuple (length (i), index_shape (I... )... )
207
207
208
208
# check for valid sizes in A[I...] = X where X <: AbstractArray
209
+ # we want to allow dimensions that are equal up to permutation, but only
210
+ # for permutations that leave array elements in the same linear order.
211
+ # those are the permutations that preserve the order of the non-singleton
212
+ # dimensions.
209
213
function setindex_shape_check (X:: AbstractArray , I... )
214
+ li = length (I)
215
+ ii = 1
210
216
nel = 1
211
- for idx in I
212
- nel *= length (idx)
213
- end
214
- if length (X) != nel
215
- error (" dimensions must match" )
216
- end
217
- if ndims (X) > 1
218
- for i = 1 : length (I)
219
- if size (X,i) != length (I[i])
220
- error (" dimensions must match" )
217
+ xi = 1
218
+ ndx = ndims (X)
219
+ match = true
220
+ while ii < li
221
+ lii = length (I[ii]):: Int
222
+ ii += 1
223
+ if lii != 1
224
+ nel *= lii
225
+ local lxi
226
+ while true
227
+ lxi = size (X,xi)
228
+ xi += 1
229
+ if lxi != 1 || xi > ndx
230
+ break
231
+ end
232
+ end
233
+ if xi > ndx
234
+ trailing = lii
235
+ while ii <= li
236
+ lii = length (I[ii]):: Int
237
+ trailing *= lii
238
+ ii += 1
239
+ end
240
+ # X's last dimension can match all the trailing indexes
241
+ if lxi == trailing && match
242
+ return
243
+ else
244
+ throw (DimensionMismatch (" " ))
245
+ end
246
+ else
247
+ if lxi != lii
248
+ match = false
249
+ end
221
250
end
222
251
end
223
252
end
253
+
254
+ # last index can match X's trailing dimensions
255
+ lii = length (I[ii]):: Int
256
+ nel *= lii
257
+ if lii != trailingsize (X,xi)
258
+ match = false
259
+ end
260
+
261
+ if ! (match && length (X)== nel)
262
+ throw (DimensionMismatch (" " ))
263
+ end
264
+ end
265
+
266
+ setindex_shape_check (X:: AbstractArray ) = (length (X)== 1 || throw (DimensionMismatch (" " )))
267
+
268
+ setindex_shape_check (X:: AbstractArray , i) =
269
+ (length (X)== length (i) || throw (DimensionMismatch (" " )))
270
+
271
+ setindex_shape_check {T} (X:: AbstractArray{T,1} , i) =
272
+ (length (X)== length (i) || throw (DimensionMismatch (" " )))
273
+
274
+ setindex_shape_check {T} (X:: AbstractArray{T,1} , i, j) =
275
+ (length (X)== length (i)* length (j) || throw (DimensionMismatch (" " )))
276
+
277
+ function setindex_shape_check {T} (X:: AbstractArray{T,2} , i, j)
278
+ li, lj = length (i), length (j)
279
+ if length (X) != li* lj
280
+ throw (DimensionMismatch (" " ))
281
+ end
282
+ sx1 = size (X,1 )
283
+ if ! (li == 1 || li == sx1 || sx1 == 1 )
284
+ throw (DimensionMismatch (" " ))
285
+ end
224
286
end
225
287
226
288
# convert to integer index
0 commit comments