@@ -10342,20 +10342,27 @@ static void ggml_compute_forward_get_rows_q(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
     const enum ggml_type type = src0->type;
     ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
 
-    assert( dst->ne[0] == nc);
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == ggml_type_size(type));
     assert(ggml_nrows(dst) == nr);
-    assert(src0->nb[0] == ggml_type_size(type));
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    // TODO: multi-thread
+    for (int64_t i = 0; i < nr; ++i) {
+        const int64_t r = ((int32_t *) src1->data)[i];
+
+        const int64_t i02 = i/ne10;
 
         dequantize_row_q(
-                (const void *) ((char *) src0->data + r*src0->nb[1]),
+                (const void *) ((char *) src0->data + i02*nb02 + r*nb01),
                 (float *) ((char *) dst->data + i*dst->nb[1]), nc);
     }
 }
@@ -10371,19 +10378,25 @@ static void ggml_compute_forward_get_rows_f16(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
 
-    assert( dst->ne[0] == nc);
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(ggml_fp16_t));
     assert(ggml_nrows(dst) == nr);
-    assert(src0->nb[0] == sizeof(ggml_fp16_t));
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    // TODO: multi-thread
+    for (int64_t i = 0; i < nr; ++i) {
+        const int64_t r = ((int32_t *) src1->data)[i];
+
+        const int64_t i02 = i/ne10;
 
         for (int j = 0; j < nc; ++j) {
-            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
-            ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
+            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01))[j];
+            ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
         }
     }
 }
@@ -10399,19 +10412,25 @@ static void ggml_compute_forward_get_rows_f32(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
 
-    assert( dst->ne[0] == nc);
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(float));
     assert(ggml_nrows(dst) == nr);
-    assert(src0->nb[0] == sizeof(float));
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    // TODO: multi-thread
+    for (int64_t i = 0; i < nr; ++i) {
+        const int64_t r = ((int32_t *) src1->data)[i];
+
+        const int64_t i02 = i/ne10;
 
         ggml_vec_cpy_f32(nc,
                 (float *) ((char *) dst->data + i*dst->nb[1]),
-                (float *) ((char *) src0->data + r*src0->nb[1]));
+                (float *) ((char *) src0->data + i02*nb02 + r*nb01));
     }
 }
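All three hunks make the same change: get_rows is extended from a 2-D to a 3-D (batched) source. src1 supplies ne10 row indices per batch, so for the i-th gathered row, i02 = i/ne10 selects which src0 slice (byte stride nb02) the row index r refers to, and the source pointer becomes src0->data + i02*nb02 + r*nb01 instead of src0->data + r*src0->nb[1]. A minimal standalone sketch of that indexing follows; it is not ggml code, and the sizes, data values, and NE*/nb* names are made up for illustration, mirroring the ggml locals:

```c
// Sketch of the batched row-gather indexing introduced above (illustrative
// sizes and data; only the i02*nb02 + r*nb01 arithmetic mirrors the diff).
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define NE00 4   // row length (nc)
#define NE01 3   // rows per src0 slice
#define NE02 2   // number of src0 slices (batch); must equal ne11
#define NE10 2   // row indices per batch in src1

int main(void) {
    // src0: NE02 contiguous slices of NE01 x NE00 floats
    float src0[NE02][NE01][NE00];
    for (int b = 0; b < NE02; ++b)
        for (int r = 0; r < NE01; ++r)
            for (int c = 0; c < NE00; ++c)
                src0[b][r][c] = 100*b + 10*r + c;

    // src1: nr = NE02*NE10 row indices, NE10 per batch
    const int32_t src1[NE02*NE10] = {2, 0, 1, 1};
    float dst[NE02*NE10][NE00];

    // byte strides, as ggml stores them in nb[]
    const size_t nb01 = NE00*sizeof(float);       // row stride
    const size_t nb02 = NE01*NE00*sizeof(float);  // slice stride

    const int64_t nr = NE02*NE10;
    for (int64_t i = 0; i < nr; ++i) {
        const int64_t r   = src1[i];
        const int64_t i02 = i/NE10;  // which src0 slice this index addresses
        memcpy(dst[i], (char *) src0 + i02*nb02 + r*nb01, nb01);
    }

    // prints 20, 0, 110, 110: rows 2 and 0 of slice 0, then row 1 of slice 1 twice
    for (int64_t i = 0; i < nr; ++i)
        printf("dst[%lld] starts with %.0f\n", (long long) i, dst[i][0]);
    return 0;
}
```

Keeping a single flat loop over all nr gathered rows (rather than nesting over batch and index) leaves the kernel trivially divisible by row, which is presumably what the new "TODO: multi-thread" comment anticipates.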