@@ -640,7 +640,9 @@ GGML_CALL static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_
 }
 
 GGML_CALL static bool ggml_backend_cpu_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
-    return ggml_backend_is_cpu(backend);
+    // HACK: hard-coded GUID of the BLAS backend so that CPU buffers are also reported as supported for it
+    static ggml_guid blas_guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };
+    return ggml_backend_is_cpu(backend) || ggml_guid_matches(backend->guid, &blas_guid);
 
     GGML_UNUSED(buft);
 }
@@ -1097,15 +1099,16 @@ static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backen
     return -1;
 }
 
-static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor) {
+static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) {
     ggml_backend_buffer_t buffer = tensor->buffer;
     if (buffer == NULL) {
         return -1;
     }
 
-    // find highest prio backend that supports the buffer type
+    // find highest prio backend that supports the buffer type and the op
     for (int i = 0; i < sched->n_backends; i++) {
-        if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) {
+        if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i]) &&
+            ggml_backend_supports_op(sched->backends[i], op)) {
             return i;
         }
     }
@@ -1126,20 +1129,25 @@ static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS*GGML_SCHED
 #define GET_CAUSE(node) ""
 #endif
 
+//#define DEBUG_PASS1
+//#define DEBUG_PASS2
+//#define DEBUG_PASS3
+//#define DEBUG_PASS4
+
 // returns the backend that should be used for the node based on the current locations
 static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) {
     // TODO: use supports_op to check if the backend supports the op
 
     // assign pre-allocated nodes to their backend
-    int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor);
+    int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor);
     if (cur_backend_id != -1) {
         SET_CAUSE(tensor, "1.dst");
         return cur_backend_id;
     }
 
     // view_src
     if (tensor->view_src != NULL) {
-        cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src);
+        cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src, tensor);
         if (cur_backend_id != -1) {
             SET_CAUSE(tensor, "1.vsrc");
             return cur_backend_id;
@@ -1161,7 +1169,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
             continue;
         }
         if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
-            int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src);
+            int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
             // check if a backend with higher prio wants to offload the op
             if (src_backend_id == sched->n_backends - 1) {
                 for (int b = 0; b < src_backend_id; b++) {
@@ -1223,10 +1231,30 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
     }
 }
 
-//#define DEBUG_PASS1
-//#define DEBUG_PASS2
-//#define DEBUG_PASS3
-//#define DEBUG_PASS4
+static int set_if_supports(ggml_backend_sched_t sched, struct ggml_tensor * node, int cur_backend_id, int * node_backend_id) {
+    if (ggml_backend_supports_op(sched->backends[cur_backend_id], node)) {
+        *node_backend_id = cur_backend_id;
+        SET_CAUSE(node, "2.2");
+    } else {
+        for (int b = 0; b < sched->n_backends; b++) {
+            if (b == cur_backend_id) {
+                continue;
+            }
+            if (ggml_backend_supports_op(sched->backends[b], node)) {
+                *node_backend_id = b;
+                cur_backend_id = b;
+                SET_CAUSE(node, "2.2");
+                break;
+            }
+        }
+    }
+    return cur_backend_id;
+}
+
+static bool buffer_supported(ggml_backend_sched_t sched, const struct ggml_tensor * t, int cur_backend_id) {
+    ggml_backend_buffer_t buf = t->view_src ? t->view_src->buffer : t->buffer;
+    return buf != NULL && ggml_backend_buft_supports_backend(buf->buft, sched->backends[cur_backend_id]);
+}
 
 // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
 static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
@@ -1306,9 +1334,13 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 } else {
                     cur_backend_id = *node_backend_id;
                 }
-            } else {
-                *node_backend_id = cur_backend_id;
-                SET_CAUSE(node, "2.2");
+            } else if (cur_backend_id != -1) {
+                // FIXME: clean this
+                cur_backend_id = set_if_supports(sched, node, cur_backend_id, node_backend_id);
+                if (cur_backend_id == sched->n_backends - 1) {
+                    // skip cpu (lowest prio backend)
+                    cur_backend_id = -1;
+                }
             }
         }
     }
@@ -1328,9 +1360,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 } else {
                     cur_backend_id = *node_backend_id;
                 }
-            } else {
-                *node_backend_id = cur_backend_id;
-                SET_CAUSE(node, "2.1");
+            } else if (cur_backend_id != -1) {
+                cur_backend_id = set_if_supports(sched, node, cur_backend_id, node_backend_id);
+                if (cur_backend_id == sched->n_backends - 1) {
+                    // skip cpu (lowest prio backend)
+                    cur_backend_id = -1;
+                }
             }
         }
     }
@@ -1345,9 +1380,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             int * node_backend_id = &tensor_backend_id(node);
             if (*node_backend_id != -1) {
                 cur_backend_id = *node_backend_id;
-            } else {
-                *node_backend_id = cur_backend_id;
-                SET_CAUSE(node, "2.4");
+            } else if (cur_backend_id != -1) {
+                cur_backend_id = set_if_supports(sched, node, cur_backend_id, node_backend_id);
             }
         }
     }
@@ -1362,9 +1396,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             int * node_backend_id = &tensor_backend_id(node);
            if (*node_backend_id != -1) {
                 cur_backend_id = *node_backend_id;
-            } else {
-                *node_backend_id = cur_backend_id;
-                SET_CAUSE(node, "2.3");
+            } else if (cur_backend_id != -1) {
+                cur_backend_id = set_if_supports(sched, node, cur_backend_id, node_backend_id);
             }
         }
     }
@@ -1448,10 +1481,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                    }
                }
                // check if the split has too many inputs
+               // FIXME: count the number of inputs instead of only checking when full
                if (split->n_inputs == GGML_SCHED_MAX_SPLIT_INPUTS) {
                    const size_t id = hash_id(src);
                    int src_backend_id = sched->tensor_backend_id[id];
-                   if (src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL) {
+                   bool supported = buffer_supported(sched, src, cur_backend_id);
+                   if (src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL && !supported) {
                        //printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name);
                        need_new_split = true;
                        break;
@@ -1511,7 +1546,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                }
            }
 
-           if (src_backend_id != node_backend_id) {
+           bool supported = buffer_supported(sched, src, cur_backend_id);
+           if (src_backend_id != cur_backend_id && !supported) {
                // create a copy of the input in the split's backend
                const size_t id = hash_id(src);
                if (sched->tensor_copies[id][cur_backend_id][0] == NULL) {