@@ -519,48 +519,56 @@ void ggml_metal_graph_compute(
519
519
520
520
[encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
521
521
} break ;
522
- case GGML_OP_SILU:
523
- {
524
- if (encoder == nil ) {
525
- encoder = [command_buffer computeCommandEncoder ];
526
- }
527
-
528
- [encoder setComputePipelineState: ctx->pipeline_silu];
529
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
530
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
531
-
532
- const int64_t n = ggml_nelements (dst);
533
-
534
- [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
535
- } break ;
536
- case GGML_OP_RELU:
537
- {
538
- if (encoder == nil ) {
539
- encoder = [command_buffer computeCommandEncoder ];
540
- }
541
-
542
- [encoder setComputePipelineState: ctx->pipeline_relu];
543
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
544
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
545
-
546
- const int64_t n = ggml_nelements (dst);
547
-
548
- [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
522
+ case GGML_OP_UNARY:
523
+ switch (ggml_get_unary_op (gf->nodes [i])) {
524
+ case GGML_UNARY_OP_SILU:
525
+ {
526
+ if (encoder == nil ) {
527
+ encoder = [command_buffer computeCommandEncoder ];
528
+ }
529
+
530
+ [encoder setComputePipelineState: ctx->pipeline_silu];
531
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
532
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
533
+
534
+ const int64_t n = ggml_nelements (dst);
535
+
536
+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
537
+ } break ;
538
+ case GGML_UNARY_OP_RELU:
539
+ {
540
+ if (encoder == nil ) {
541
+ encoder = [command_buffer computeCommandEncoder ];
542
+ }
543
+
544
+ [encoder setComputePipelineState: ctx->pipeline_relu];
545
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
546
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
547
+
548
+ const int64_t n = ggml_nelements (dst);
549
+
550
+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
551
+ } break ;
552
+ case GGML_UNARY_OP_GELU:
553
+ {
554
+ if (encoder == nil ) {
555
+ encoder = [command_buffer computeCommandEncoder ];
556
+ }
557
+
558
+ [encoder setComputePipelineState: ctx->pipeline_gelu];
559
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
560
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
561
+
562
+ const int64_t n = ggml_nelements (dst);
563
+
564
+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
565
+ } break ;
566
+ default :
567
+ {
568
+ fprintf (stderr, " %s : node %3d , op = %8s not implemented\n " , __func__, i, ggml_op_name (dst->op ));
569
+ GGML_ASSERT (false );
570
+ }
549
571
} break ;
550
- case GGML_OP_GELU:
551
- {
552
- if (encoder == nil ) {
553
- encoder = [command_buffer computeCommandEncoder ];
554
- }
555
-
556
- [encoder setComputePipelineState: ctx->pipeline_gelu];
557
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
558
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
559
-
560
- const int64_t n = ggml_nelements (dst);
561
-
562
- [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
563
- } break ;
564
572
case GGML_OP_SOFT_MAX:
565
573
{
566
574
if (encoder == nil ) {
@@ -979,8 +987,10 @@ void ggml_metal_graph_compute(
979
987
[encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
980
988
} break ;
981
989
default :
982
- fprintf (stderr, " %s : node %3d , op = %8s not implemented\n " , __func__, i, ggml_op_name (dst->op ));
983
- GGML_ASSERT (false );
990
+ {
991
+ fprintf (stderr, " %s : node %3d , op = %8s not implemented\n " , __func__, i, ggml_op_name (dst->op ));
992
+ GGML_ASSERT (false );
993
+ }
984
994
}
985
995
}
986
996
0 commit comments