Skip to content

Commit cde46d2

Browse files
committed
move BLAS to a separate backend
1 parent bde7cd3 commit cde46d2

7 files changed

+356
-219
lines changed

CMakeLists.txt

+4
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,9 @@ if (LLAMA_BLAS)
381381
add_compile_definitions(GGML_BLAS_USE_MKL)
382382
endif()
383383

384+
set(GGML_HEADERS_BLAS ggml-blas.h)
385+
set(GGML_SOURCES_BLAS ggml-blas.c)
386+
384387
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES})
385388
set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS})
386389
else()
@@ -1268,6 +1271,7 @@ add_library(ggml OBJECT
12681271
${GGML_SOURCES_KOMPUTE} ${GGML_HEADERS_KOMPUTE}
12691272
${GGML_SOURCES_VULKAN} ${GGML_HEADERS_VULKAN}
12701273
${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM}
1274+
${GGML_SOURCES_BLAS} ${GGML_HEADERS_BLAS}
12711275
${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
12721276
)
12731277

Makefile

+17-4
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ ifndef LLAMA_NO_ACCELERATE
408408
MK_CPPFLAGS += -DACCELERATE_NEW_LAPACK
409409
MK_CPPFLAGS += -DACCELERATE_LAPACK_ILP64
410410
MK_LDFLAGS += -framework Accelerate
411+
OBJS += ggml-blas.o
411412
endif
412413
endif # LLAMA_NO_ACCELERATE
413414

@@ -421,23 +422,35 @@ ifdef LLAMA_OPENBLAS
421422
MK_CPPFLAGS += -DGGML_USE_OPENBLAS $(shell pkg-config --cflags-only-I openblas)
422423
MK_CFLAGS += $(shell pkg-config --cflags-only-other openblas)
423424
MK_LDFLAGS += $(shell pkg-config --libs openblas)
425+
OBJS += ggml-blas.o
424426
endif # LLAMA_OPENBLAS
425427

426-
ifndef LLAMA_NO_LLAMAFILE
427-
MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
428-
OBJS += sgemm.o
429-
endif
428+
ifdef LLAMA_OPENBLAS64
429+
MK_CPPFLAGS += -DGGML_USE_OPENBLAS $(shell pkg-config --cflags-only-I openblas64)
430+
MK_CFLAGS += $(shell pkg-config --cflags-only-other openblas64)
431+
MK_LDFLAGS += $(shell pkg-config --libs openblas64)
432+
OBJS += ggml-blas.o
433+
endif # LLAMA_OPENBLAS64
430434

431435
ifdef LLAMA_BLIS
432436
MK_CPPFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/blis -I/usr/include/blis
433437
MK_LDFLAGS += -lblis -L/usr/local/lib
438+
OBJS += ggml-blas.o
434439
endif # LLAMA_BLIS
435440

441+
ifndef LLAMA_NO_LLAMAFILE
442+
MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
443+
OBJS += sgemm.o
444+
endif
445+
436446
ifdef LLAMA_RPC
437447
MK_CPPFLAGS += -DGGML_USE_RPC
438448
OBJS += ggml-rpc.o
439449
endif # LLAMA_RPC
440450

451+
ggml-blas.o: ggml-blas.c ggml-blas.h
452+
$(CC) $(CFLAGS) -c $< -o $@
453+
441454
ifdef LLAMA_CUBLAS
442455
# LLAMA_CUBLAS is deprecated and will be removed in the future
443456
LLAMA_CUDA := 1

ggml-backend.c

+61-25
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,9 @@ GGML_CALL static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_
640640
}
641641

642642
GGML_CALL static bool ggml_backend_cpu_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
643-
return ggml_backend_is_cpu(backend);
643+
// HACK
644+
static ggml_guid blas_guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };
645+
return ggml_backend_is_cpu(backend) || ggml_guid_matches(backend->guid, &blas_guid);
644646

645647
GGML_UNUSED(buft);
646648
}
@@ -1097,15 +1099,16 @@ static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backen
10971099
return -1;
10981100
}
10991101

1100-
static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor) {
1102+
static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) {
11011103
ggml_backend_buffer_t buffer = tensor->buffer;
11021104
if (buffer == NULL) {
11031105
return -1;
11041106
}
11051107

1106-
// find highest prio backend that supports the buffer type
1108+
// find highest prio backend that supports the buffer type and the op
11071109
for (int i = 0; i < sched->n_backends; i++) {
1108-
if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) {
1110+
if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i]) &&
1111+
ggml_backend_supports_op(sched->backends[i], op)) {
11091112
return i;
11101113
}
11111114
}
@@ -1126,20 +1129,25 @@ static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS*GGML_SCHED
11261129
#define GET_CAUSE(node) ""
11271130
#endif
11281131

1132+
//#define DEBUG_PASS1
1133+
//#define DEBUG_PASS2
1134+
//#define DEBUG_PASS3
1135+
//#define DEBUG_PASS4
1136+
11291137
// returns the backend that should be used for the node based on the current locations
11301138
static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) {
11311139
// TODO: use supports_op to check if the backend supports the op
11321140

11331141
// assign pre-allocated nodes to their backend
1134-
int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor);
1142+
int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor);
11351143
if (cur_backend_id != -1) {
11361144
SET_CAUSE(tensor, "1.dst");
11371145
return cur_backend_id;
11381146
}
11391147

11401148
// view_src
11411149
if (tensor->view_src != NULL) {
1142-
cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src);
1150+
cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src, tensor);
11431151
if (cur_backend_id != -1) {
11441152
SET_CAUSE(tensor, "1.vsrc");
11451153
return cur_backend_id;
@@ -1161,7 +1169,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
11611169
continue;
11621170
}
11631171
if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
1164-
int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src);
1172+
int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
11651173
// check if a backend with higher prio wants to offload the op
11661174
if (src_backend_id == sched->n_backends - 1) {
11671175
for (int b = 0; b < src_backend_id; b++) {
@@ -1223,10 +1231,30 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
12231231
}
12241232
}
12251233

1226-
//#define DEBUG_PASS1
1227-
//#define DEBUG_PASS2
1228-
//#define DEBUG_PASS3
1229-
//#define DEBUG_PASS4
1234+
static int set_if_supports(ggml_backend_sched_t sched, struct ggml_tensor * node, int cur_backend_id, int * node_backend_id) {
1235+
if (ggml_backend_supports_op(sched->backends[cur_backend_id], node)) {
1236+
*node_backend_id = cur_backend_id;
1237+
SET_CAUSE(node, "2.2");
1238+
} else {
1239+
for (int b = 0; b < sched->n_backends; b++) {
1240+
if (b == cur_backend_id) {
1241+
continue;
1242+
}
1243+
if (ggml_backend_supports_op(sched->backends[b], node)) {
1244+
*node_backend_id = b;
1245+
cur_backend_id = b;
1246+
SET_CAUSE(node, "2.2");
1247+
break;
1248+
}
1249+
}
1250+
}
1251+
return cur_backend_id;
1252+
}
1253+
1254+
static bool buffer_supported(ggml_backend_sched_t sched, const struct ggml_tensor * t, int cur_backend_id) {
1255+
ggml_backend_buffer_t buf = t->view_src ? t->view_src->buffer : t->buffer;
1256+
return buf != NULL && ggml_backend_buft_supports_backend(buf->buft, sched->backends[cur_backend_id]);
1257+
}
12301258

12311259
// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
12321260
static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
@@ -1306,9 +1334,13 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
13061334
} else {
13071335
cur_backend_id = *node_backend_id;
13081336
}
1309-
} else {
1310-
*node_backend_id = cur_backend_id;
1311-
SET_CAUSE(node, "2.2");
1337+
} else if (cur_backend_id != -1) {
1338+
// FIXME: clean this
1339+
cur_backend_id = set_if_supports(sched, node, cur_backend_id, node_backend_id);
1340+
if (cur_backend_id == sched->n_backends - 1) {
1341+
// skip cpu (lowest prio backend)
1342+
cur_backend_id = -1;
1343+
}
13121344
}
13131345
}
13141346
}
@@ -1328,9 +1360,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
13281360
} else {
13291361
cur_backend_id = *node_backend_id;
13301362
}
1331-
} else {
1332-
*node_backend_id = cur_backend_id;
1333-
SET_CAUSE(node, "2.1");
1363+
} else if (cur_backend_id != -1) {
1364+
cur_backend_id = set_if_supports(sched, node, cur_backend_id, node_backend_id);
1365+
if (cur_backend_id == sched->n_backends - 1) {
1366+
// skip cpu (lowest prio backend)
1367+
cur_backend_id = -1;
1368+
}
13341369
}
13351370
}
13361371
}
@@ -1345,9 +1380,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
13451380
int * node_backend_id = &tensor_backend_id(node);
13461381
if (*node_backend_id != -1) {
13471382
cur_backend_id = *node_backend_id;
1348-
} else {
1349-
*node_backend_id = cur_backend_id;
1350-
SET_CAUSE(node, "2.4");
1383+
} else if (cur_backend_id != -1) {
1384+
cur_backend_id = set_if_supports(sched, node, cur_backend_id, node_backend_id);
13511385
}
13521386
}
13531387
}
@@ -1362,9 +1396,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
13621396
int * node_backend_id = &tensor_backend_id(node);
13631397
if (*node_backend_id != -1) {
13641398
cur_backend_id = *node_backend_id;
1365-
} else {
1366-
*node_backend_id = cur_backend_id;
1367-
SET_CAUSE(node, "2.3");
1399+
} else if (cur_backend_id != -1) {
1400+
cur_backend_id = set_if_supports(sched, node, cur_backend_id, node_backend_id);
13681401
}
13691402
}
13701403
}
@@ -1448,10 +1481,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
14481481
}
14491482
}
14501483
// check if the split has too many inputs
1484+
// FIXME: count the number of inputs instead of only checking when full
14511485
if (split->n_inputs == GGML_SCHED_MAX_SPLIT_INPUTS) {
14521486
const size_t id = hash_id(src);
14531487
int src_backend_id = sched->tensor_backend_id[id];
1454-
if (src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL) {
1488+
bool supported = buffer_supported(sched, src, cur_backend_id);
1489+
if (src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL && !supported) {
14551490
//printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name);
14561491
need_new_split = true;
14571492
break;
@@ -1511,7 +1546,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
15111546
}
15121547
}
15131548

1514-
if (src_backend_id != node_backend_id) {
1549+
bool supported = buffer_supported(sched, src, cur_backend_id);
1550+
if (src_backend_id != cur_backend_id && !supported) {
15151551
// create a copy of the input in the split's backend
15161552
const size_t id = hash_id(src);
15171553
if (sched->tensor_copies[id][cur_backend_id][0] == NULL) {

0 commit comments

Comments
 (0)