Skip to content

Commit e665b55

Browse files
authored
Merge pull request ggml-org#523 from shouyiwang/tensor_split
Update tensor_split to match llama.cpp's change
2 parents d3bf7db + 426dbfe commit e665b55

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

llama_cpp/llama.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -273,13 +273,12 @@ def __init__(
         self.params.low_vram = low_vram

         self.tensor_split = tensor_split
-        self._c_tensor_split = None
+        self._p_tensor_split = None

         if self.tensor_split is not None:
-            #Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
-            FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
-            self._c_tensor_split = FloatArray(*tensor_split) # keep a reference to the array so it is not gc'd
-            self.params.tensor_split = self._c_tensor_split
+            FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split)
+            self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(FloatArray) # keep a reference to the array so it is not gc'd
+            self.params.tensor_split = self._p_tensor_split

         self.params.rope_freq_base = rope_freq_base
         self.params.rope_freq_scale = rope_freq_scale

0 commit comments

Comments
 (0)