Skip to content

Commit 2cbcba8

Browse files
committed
metal : add more general support for ggml_get_rows + tests
1 parent 9064b1c commit 2cbcba8

File tree

4 files changed

+78
-25
lines changed

ggml-metal.m

+9-7
Original file line number | Diff line number | Diff line change
@@ -805,8 +805,9 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
805805
case GGML_OP_NONE:
806806
case GGML_OP_RESHAPE:
807807
case GGML_OP_VIEW:
808-
case GGML_OP_TRANSPOSE:
809808
case GGML_OP_PERMUTE:
809+
case GGML_OP_TRANSPOSE:
810+
case GGML_OP_GET_ROWS:
810811
case GGML_OP_CONCAT:
811812
case GGML_OP_ADD:
812813
case GGML_OP_MUL:
@@ -828,7 +829,6 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
828829
case GGML_OP_MUL_MAT_ID:
829830
return true;
830831
case GGML_OP_DIAG_MASK_INF:
831-
case GGML_OP_GET_ROWS:
832832
{
833833
return op->ne[0] % 4 == 0;
834834
}
@@ -1568,16 +1568,18 @@ void ggml_metal_graph_compute(
15681568
default: GGML_ASSERT(false && "not implemented");
15691569
}
15701570

1571-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1572-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1573-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1571+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1572+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1573+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
15741574
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
15751575
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
1576-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
1576+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
1577+
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
1578+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:7];
15771579

15781580
const int64_t n = ggml_nelements(src1);
15791581

1580-
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1582+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
15811583
} break;
15821584
case GGML_OP_RMS_NORM:
15831585
{

ggml-metal.metal

+56-6
Original file line number | Diff line number | Diff line change
@@ -3223,21 +3223,69 @@ kernel void kernel_get_rows(
32233223
device float * dst,
32243224
constant int64_t & ne00,
32253225
constant uint64_t & nb01,
3226+
constant uint64_t & nb02,
3227+
constant int64_t & ne10,
32263228
constant uint64_t & nb1,
32273229
uint tgpig[[threadgroup_position_in_grid]],
32283230
uint tiitg[[thread_index_in_threadgroup]],
3229-
uint tptg[[threads_per_threadgroup]]) {
3230-
const int i = tgpig;
3231-
const int r = ((device int32_t *) src1)[i];
3231+
uint tptg [[threads_per_threadgroup]]) {
3232+
const int64_t i = tgpig;
3233+
const int64_t r = ((device int32_t *) src1)[i];
32323234

3233-
for (int ind = tiitg; ind < ne00/16; ind += tptg) {
3235+
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg) {
32343236
float4x4 temp;
32353237
dequantize_func(
32363238
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
32373239
*(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
32383240
}
32393241
}
32403242

3243+
kernel void kernel_get_rows_f32(
3244+
device const void * src0,
3245+
device const int * src1,
3246+
device float * dst,
3247+
constant int64_t & ne00,
3248+
constant uint64_t & nb01,
3249+
constant uint64_t & nb02,
3250+
constant int64_t & ne10,
3251+
constant uint64_t & nb1,
3252+
uint tgpig[[threadgroup_position_in_grid]],
3253+
uint tiitg[[thread_index_in_threadgroup]],
3254+
uint tptg [[threads_per_threadgroup]]) {
3255+
const int64_t i = tgpig;
3256+
const int64_t r = ((device int32_t *) src1)[i];
3257+
3258+
const int64_t i02 = i/ne10;
3259+
3260+
for (int ind = tiitg; ind < ne00; ind += tptg) {
3261+
((device float *) ((device char *) dst + i*nb1))[ind] =
3262+
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3263+
}
3264+
}
3265+
3266+
kernel void kernel_get_rows_f16(
3267+
device const void * src0,
3268+
device const int * src1,
3269+
device float * dst,
3270+
constant int64_t & ne00,
3271+
constant uint64_t & nb01,
3272+
constant uint64_t & nb02,
3273+
constant int64_t & ne10,
3274+
constant uint64_t & nb1,
3275+
uint tgpig[[threadgroup_position_in_grid]],
3276+
uint tiitg[[thread_index_in_threadgroup]],
3277+
uint tptg [[threads_per_threadgroup]]) {
3278+
const int64_t i = tgpig;
3279+
const int64_t r = ((device int32_t *) src1)[i];
3280+
3281+
const int64_t i02 = i/ne10;
3282+
3283+
for (int ind = tiitg; ind < ne00; ind += tptg) {
3284+
((device float *) ((device char *) dst + i*nb1))[ind] =
3285+
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3286+
}
3287+
}
3288+
32413289
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
32423290
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
32433291
#define BLOCK_SIZE_K 32
@@ -3490,11 +3538,13 @@ typedef void (get_rows_t)(
34903538
device float * dst,
34913539
constant int64_t & ne00,
34923540
constant uint64_t & nb01,
3541+
constant uint64_t & nb02,
3542+
constant int64_t & ne10,
34933543
constant uint64_t & nb1,
34943544
uint, uint, uint);
34953545

3496-
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
3497-
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
3546+
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
3547+
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
34983548
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
34993549
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
35003550
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;

ggml.c

+3-3
Original file line number | Diff line number | Diff line change
@@ -10363,7 +10363,7 @@ static void ggml_compute_forward_get_rows_q(
1036310363

1036410364
dequantize_row_q(
1036510365
(const void *) ((char *) src0->data + i02*nb02 + r*nb01),
10366-
(float *) ((char *) dst->data + i*dst->nb[1]), nc);
10366+
(float *) ((char *) dst->data + i*nb1), nc);
1036710367
}
1036810368
}
1036910369

@@ -10396,7 +10396,7 @@ static void ggml_compute_forward_get_rows_f16(
1039610396

1039710397
for (int j = 0; j < nc; ++j) {
1039810398
ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01))[j];
10399-
((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
10399+
((float *) ((char *) dst->data + i*nb1))[j] = GGML_FP16_TO_FP32(v);
1040010400
}
1040110401
}
1040210402
}
@@ -10429,7 +10429,7 @@ static void ggml_compute_forward_get_rows_f32(
1042910429
const int64_t i02 = i/ne10;
1043010430

1043110431
ggml_vec_cpy_f32(nc,
10432-
(float *) ((char *) dst->data + i*dst->nb[1]),
10432+
(float *) ((char *) dst->data + i*nb1),
1043310433
(float *) ((char *) src0->data + i02*nb02 + r*nb01));
1043410434
}
1043510435
}

tests/test-backend-ops.cpp

+10-9
Original file line number | Diff line number | Diff line change
@@ -488,17 +488,18 @@ struct test_get_rows : public test_case {
488488
const int n; // cols
489489
const int m; // rows
490490
const int r; // rows to get
491+
const int b; // batch size
491492

492493
std::string vars() override {
493494
return VARS_TO_STR4(type, n, m, r);
494495
}
495496

496-
test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3)
497-
: type(type), n(n), m(m), r(r) {}
497+
test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1)
498+
: type(type), n(n), m(m), r(r), b(b) {}
498499

499500
ggml_tensor * build_graph(ggml_context * ctx) override {
500-
ggml_tensor * in = ggml_new_tensor_2d(ctx, type, n, m);
501-
ggml_tensor * rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, r);
501+
ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b);
502+
ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
502503
ggml_tensor * out = ggml_get_rows(ctx, in, rows);
503504
return out;
504505
}
@@ -507,11 +508,11 @@ struct test_get_rows : public test_case {
507508
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
508509
if (t->type == GGML_TYPE_I32) {
509510
// rows
510-
std::vector<int> data(r);
511-
for (int i = 0; i < r; i++) {
511+
std::vector<int> data(r*b);
512+
for (int i = 0; i < r*b; i++) {
512513
data[i] = rand() % m;
513514
}
514-
ggml_backend_tensor_set(t, data.data(), 0, r * sizeof(int));
515+
ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
515516
} else {
516517
init_tensor_uniform(t);
517518
}
@@ -1125,8 +1126,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
11251126
}
11261127

11271128
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
1128-
test_cases.emplace_back(new test_get_rows(type, 10, 5, 3));
1129-
test_cases.emplace_back(new test_get_rows(type, 16, 5, 3));
1129+
test_cases.emplace_back(new test_get_rows(type, 10, 5, 3, 7));
1130+
test_cases.emplace_back(new test_get_rows(type, 16, 5, 3, 7));
11301131
}
11311132

11321133
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));

0 commit comments

Comments (0)