Skip to content

Commit 9064b1c

Browse files
committed
ggml : fix ggml_get_rows to take into account ne02 / ne11
1 parent ee8fb39 commit 9064b1c

File tree

1 file changed

+41
-22
lines changed

1 file changed

+41
-22
lines changed

ggml.c

+41 -22
Original file line number | Diff line number | Diff line change
@@ -10342,20 +10342,27 @@ static void ggml_compute_forward_get_rows_q(
1034210342
return;
1034310343
}
1034410344

10345-
const int nc = src0->ne[0];
10346-
const int nr = ggml_nelements(src1);
10345+
GGML_TENSOR_BINARY_OP_LOCALS
10346+
10347+
const int64_t nc = ne00;
10348+
const int64_t nr = ggml_nelements(src1);
10349+
1034710350
const enum ggml_type type = src0->type;
1034810351
ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
1034910352

10350-
assert( dst->ne[0] == nc);
10353+
assert(ne0 == nc);
10354+
assert(ne02 == ne11);
10355+
assert(nb00 == ggml_type_size(type));
1035110356
assert(ggml_nrows(dst) == nr);
10352-
assert(src0->nb[0] == ggml_type_size(type));
1035310357

10354-
for (int i = 0; i < nr; ++i) {
10355-
const int r = ((int32_t *) src1->data)[i];
10358+
// TODO: multi-thread
10359+
for (int64_t i = 0; i < nr; ++i) {
10360+
const int64_t r = ((int32_t *) src1->data)[i];
10361+
10362+
const int64_t i02 = i/ne10;
1035610363

1035710364
dequantize_row_q(
10358-
(const void *) ((char *) src0->data + r*src0->nb[1]),
10365+
(const void *) ((char *) src0->data + i02*nb02 + r*nb01),
1035910366
(float *) ((char *) dst->data + i*dst->nb[1]), nc);
1036010367
}
1036110368
}
@@ -10371,19 +10378,25 @@ static void ggml_compute_forward_get_rows_f16(
1037110378
return;
1037210379
}
1037310380

10374-
const int nc = src0->ne[0];
10375-
const int nr = ggml_nelements(src1);
10381+
GGML_TENSOR_BINARY_OP_LOCALS
10382+
10383+
const int64_t nc = ne00;
10384+
const int64_t nr = ggml_nelements(src1);
1037610385

10377-
assert( dst->ne[0] == nc);
10386+
assert(ne0 == nc);
10387+
assert(ne02 == ne11);
10388+
assert(nb00 == sizeof(ggml_fp16_t));
1037810389
assert(ggml_nrows(dst) == nr);
10379-
assert(src0->nb[0] == sizeof(ggml_fp16_t));
1038010390

10381-
for (int i = 0; i < nr; ++i) {
10382-
const int r = ((int32_t *) src1->data)[i];
10391+
// TODO: multi-thread
10392+
for (int64_t i = 0; i < nr; ++i) {
10393+
const int64_t r = ((int32_t *) src1->data)[i];
10394+
10395+
const int64_t i02 = i/ne10;
1038310396

1038410397
for (int j = 0; j < nc; ++j) {
10385-
ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
10386-
((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
10398+
ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01))[j];
10399+
((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
1038710400
}
1038810401
}
1038910402
}
@@ -10399,19 +10412,25 @@ static void ggml_compute_forward_get_rows_f32(
1039910412
return;
1040010413
}
1040110414

10402-
const int nc = src0->ne[0];
10403-
const int nr = ggml_nelements(src1);
10415+
GGML_TENSOR_BINARY_OP_LOCALS
10416+
10417+
const int64_t nc = ne00;
10418+
const int64_t nr = ggml_nelements(src1);
1040410419

10405-
assert( dst->ne[0] == nc);
10420+
assert(ne0 == nc);
10421+
assert(ne02 == ne11);
10422+
assert(nb00 == sizeof(float));
1040610423
assert(ggml_nrows(dst) == nr);
10407-
assert(src0->nb[0] == sizeof(float));
1040810424

10409-
for (int i = 0; i < nr; ++i) {
10410-
const int r = ((int32_t *) src1->data)[i];
10425+
// TODO: multi-thread
10426+
for (int64_t i = 0; i < nr; ++i) {
10427+
const int64_t r = ((int32_t *) src1->data)[i];
10428+
10429+
const int64_t i02 = i/ne10;
1041110430

1041210431
ggml_vec_cpy_f32(nc,
1041310432
(float *) ((char *) dst->data + i*dst->nb[1]),
10414-
(float *) ((char *) src0->data + r*src0->nb[1]));
10433+
(float *) ((char *) src0->data + i02*nb02 + r*nb01));
1041510434
}
1041610435
}
1041710436

0 commit comments

Comments
 (0)