
Commit 36eed0c

stablelm : StableLM support (#3586)
* Add support for stablelm-3b-4e1t
* Supports GPU offloading of (n-1) layers
1 parent b46d12f commit 36eed0c

6 files changed, +322 -12 lines changed

README.md (+1)
@@ -93,6 +93,7 @@ as the main playground for developing new features for the [ggml](https://github
 - [X] [Persimmon 8B](https://github.com/ggerganov/llama.cpp/pull/3410)
 - [X] [MPT](https://github.com/ggerganov/llama.cpp/pull/3417)
 - [X] [Bloom](https://github.com/ggerganov/llama.cpp/pull/3553)
+- [X] [StableLM-3b-4e1t](https://github.com/ggerganov/llama.cpp/pull/3586)


 **Bindings:**

convert-hf-to-gguf.py (+19 -11)
@@ -150,8 +150,6 @@ def load_hparams(dir_model):

     @staticmethod
     def from_model_architecture(model_architecture):
-        if model_architecture == "StableLMEpochForCausalLM":
-            return StableLMModel
         if model_architecture == "GPTNeoXForCausalLM":
             return GPTNeoXModel
         if model_architecture == "BloomForCausalLM":

@@ -168,6 +166,8 @@ def from_model_architecture(model_architecture):
             return RefactModel
         if model_architecture == "PersimmonForCausalLM":
             return PersimmonModel
+        if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
+            return StableLMModel
         return Model

     def _is_model_safetensors(self) -> bool:

@@ -201,6 +201,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.REFACT
         if arch == "PersimmonForCausalLM":
             return gguf.MODEL_ARCH.PERSIMMON
+        if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
+            return gguf.MODEL_ARCH.STABLELM

         raise NotImplementedError(f'Architecture "{arch}" not supported!')

@@ -294,15 +296,6 @@ def _set_vocab_sentencepiece(self):
         special_vocab.add_to_gguf(self.gguf_writer)


-class StableLMModel(Model):
-    def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        self.gguf_writer.add_rope_dimension_count(
-            int(self.hparams["rope_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
-        )
-        self.gguf_writer.add_layer_norm_eps(1e-5)
-
-
 class GPTNeoXModel(Model):
     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]

@@ -824,6 +817,21 @@ def write_tensors(self):
             self.gguf_writer.add_tensor(new_name, data)


+class StableLMModel(Model):
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        self.gguf_writer.add_name(dir_model.name)
+        self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+        self.gguf_writer.add_rope_dimension_count(int(hparams["rope_pct"]*(hparams["hidden_size"] // hparams["num_attention_heads"])))
+        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
+        self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
+        self.gguf_writer.add_layer_norm_eps(1e-5)
+
 ###### CONVERSION LOGIC ######

 def parse_args() -> argparse.Namespace:
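
The new StableLMModel.set_gguf_parameters above is a straight mapping from the Hugging Face config.json keys to GGUF metadata. For reference, here is a minimal standalone sketch of that mapping; the helper name and the numbers in the comments are illustrative assumptions about stablelm-3b-4e1t's config, not part of this commit.

# Sketch only: mirrors the key/value mapping performed by set_gguf_parameters.
import json
from pathlib import Path

def stablelm_gguf_metadata(config_path: Path) -> dict:
    hparams = json.loads(config_path.read_text())
    head_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
    return {
        "context_length":       hparams["max_position_embeddings"],
        "embedding_length":     hparams["hidden_size"],
        "block_count":          hparams["num_hidden_layers"],
        "feed_forward_length":  hparams["intermediate_size"],
        # rope_pct scales the per-head rotary dimensions; e.g. a config with
        # hidden_size=2560, 32 heads and rope_pct=0.25 gives int(0.25 * 80) = 20.
        "rope_dimension_count": int(hparams["rope_pct"] * head_dim),
        "head_count":           hparams["num_attention_heads"],
        # older configs may omit the flag; the converter defaults it to True
        "parallel_residual":    hparams.get("use_parallel_residual", True),
        "layer_norm_eps":       1e-5,
    }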

gguf-py/gguf/constants.py (+17)
@@ -90,6 +90,7 @@ class MODEL_ARCH(IntEnum):
     REFACT = auto()
     BERT = auto()
     BLOOM = auto()
+    STABLELM = auto()


 class MODEL_TENSOR(IntEnum):

@@ -129,6 +130,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.REFACT: "refact",
     MODEL_ARCH.BERT: "bert",
     MODEL_ARCH.BLOOM: "bloom",
+    MODEL_ARCH.STABLELM: "stablelm",
 }

 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {

@@ -299,6 +301,21 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.STABLELM: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.GPT2: [
         # TODO
     ],
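
With the STABLELM entries in place, downstream tooling can look up which GGUF tensors the architecture declares. A small sketch, assuming the architecture-to-tensor map is exposed as MODEL_TENSORS (as in current gguf-py); adjust the import if the local copy names it differently.

# Sketch: print the GGUF tensor names declared for MODEL_ARCH.STABLELM.
from gguf.constants import MODEL_ARCH, MODEL_TENSORS, TENSOR_NAMES

for tensor in MODEL_TENSORS[MODEL_ARCH.STABLELM]:
    # Per-block names carry a "{bid}" placeholder for the block index;
    # global tensors (token_embd, output_norm, ...) format unchanged.
    print(TENSOR_NAMES[tensor].format(bid=0))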
