@@ -181,7 +181,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
181
181
182
182
static void norm_f32_sycl (const float * x, float * dst, const int ncols,
183
183
const int nrows, const float eps,
184
- queue_ptr stream) {
184
+ queue_ptr stream, int device ) {
185
185
GGML_ASSERT (ncols % WARP_SIZE == 0 );
186
186
if (ncols < 1024 ) {
187
187
const sycl::range<3 > block_dims (1 , 1 , WARP_SIZE);
@@ -197,7 +197,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
197
197
});
198
198
}
199
199
else {
200
- const int work_group_size = get_work_group_size (stream-> get_device ()) ;
200
+ const int work_group_size = ggml_sycl_info (). max_work_group_sizes [device] ;
201
201
const sycl::range<3 > block_dims (1 , 1 , work_group_size);
202
202
/*
203
203
DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
@@ -222,7 +222,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
222
222
223
223
static void group_norm_f32_sycl (const float * x, float * dst,
224
224
const int num_groups, const int group_size,
225
- const int ne_elements, queue_ptr stream) {
225
+ const int ne_elements, queue_ptr stream, int device ) {
226
226
static const float eps = 1e-6f ;
227
227
if (group_size < 1024 ) {
228
228
const sycl::range<3 > block_dims (1 , 1 , WARP_SIZE);
@@ -240,7 +240,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
240
240
});
241
241
}
242
242
else {
243
- const int work_group_size = get_work_group_size (stream-> get_device ()) ;
243
+ const int work_group_size = ggml_sycl_info (). max_work_group_sizes [device] ;
244
244
const sycl::range<3 > block_dims (1 , 1 , work_group_size);
245
245
/*
246
246
DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
@@ -269,7 +269,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
269
269
270
270
static void rms_norm_f32_sycl (const float * x, float * dst, const int ncols,
271
271
const int nrows, const float eps,
272
- queue_ptr stream) {
272
+ queue_ptr stream, int device ) {
273
273
GGML_ASSERT (ncols % WARP_SIZE == 0 );
274
274
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
275
275
if (ncols < 1024 ) {
@@ -286,7 +286,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
286
286
});
287
287
}
288
288
else {
289
- const int work_group_size = get_work_group_size (stream-> get_device ()) ;
289
+ const int work_group_size = ggml_sycl_info (). max_work_group_sizes [device] ;
290
290
const sycl::range<3 > block_dims (1 , 1 , work_group_size);
291
291
/*
292
292
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
@@ -322,7 +322,7 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
322
322
float eps;
323
323
memcpy (&eps, dst->op_params , sizeof (float ));
324
324
325
- norm_f32_sycl (src0_dd, dst_dd, ne00, nrows, eps, main_stream);
325
+ norm_f32_sycl (src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx. device );
326
326
327
327
(void )src1;
328
328
(void )dst;
@@ -340,7 +340,7 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
340
340
341
341
int num_groups = dst->op_params [0 ];
342
342
int group_size = src0->ne [0 ] * src0->ne [1 ] * ((src0->ne [2 ] + num_groups - 1 ) / num_groups);
343
- group_norm_f32_sycl (src0_dd, dst_dd, num_groups, group_size, src0->ne [0 ] * src0->ne [1 ] * src0->ne [2 ], main_stream);
343
+ group_norm_f32_sycl (src0_dd, dst_dd, num_groups, group_size, src0->ne [0 ] * src0->ne [1 ] * src0->ne [2 ], main_stream, ctx. device );
344
344
345
345
(void )src1;
346
346
(void )dst;
@@ -362,7 +362,7 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
362
362
float eps;
363
363
memcpy (&eps, dst->op_params , sizeof (float ));
364
364
365
- rms_norm_f32_sycl (src0_dd, dst_dd, ne00, nrows, eps, main_stream);
365
+ rms_norm_f32_sycl (src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx. device );
366
366
367
367
(void )src1;
368
368
(void )dst;
0 commit comments