This guide provides a step-by-step process for adding a new model to the PyTorch backend.
Before you begin, ensure you have the following:
- A working installation of TensorRT-LLM. Follow these instructions.
Suppose you want to support a new model named `MyModel`. If the model is already supported in HuggingFace's transformers, you should bring in the PyTorch modeling code and reuse HuggingFace's configuration class. For example, our `tensorrt_llm/_torch/models/modeling_llama.py` was adapted from HuggingFace's `modeling_llama.py`; in the modeling code, we reuse the configuration class:

```python
from transformers import LlamaConfig
```
If the model is not registered in HuggingFace's transformers, you need to define the configuration class in your `configuration_mymodel.py`, following HuggingFace's `configuration_llama.py`:

```python
from transformers.configuration_utils import PretrainedConfig


class MyConfig(PretrainedConfig):

    def __init__(self, ...):
        ...
```
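For illustration, a concrete configuration might expose the usual decoder hyperparameters. The field names and default values below are hypothetical examples (loosely modeled on `configuration_llama.py`), not a required schema:

```python
from transformers.configuration_utils import PretrainedConfig


class MyConfig(PretrainedConfig):
    model_type = "mymodel"

    def __init__(self,
                 vocab_size: int = 32000,
                 hidden_size: int = 4096,
                 intermediate_size: int = 11008,
                 num_hidden_layers: int = 32,
                 num_attention_heads: int = 32,
                 rms_norm_eps: float = 1e-6,
                 **kwargs):
        # Store the hyperparameters that the modeling code will read from
        # model_config.pretrained_config.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.rms_norm_eps = rms_norm_eps
        super().__init__(**kwargs)
```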
Remove any unnecessary code (e.g., training-specific code), and then rewrite some of the PyTorch modules. For a typical Transformer decoder model, you need to implement your `modeling_mymodel.py` like this:
```python
from typing import Optional

import torch
from torch import nn

from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import DecoderModel, DecoderModelForCausalLM
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer

from configuration_mymodel import MyConfig


class MyAttention(Attention):

    def __init__(self, model_config: ModelConfig[MyConfig], layer_idx: Optional[int] = None):
        # Use model_config to initialize the Attention module
        super().__init__(...)


class MyDecoderLayer(DecoderLayer):

    def __init__(self, model_config: ModelConfig[MyConfig], layer_idx: int):
        super().__init__()
        # Use model_config to initialize the submodules
        self.input_layernorm = ...
        self.self_attn = MyAttention(model_config, layer_idx)
        self.post_attention_layernorm = ...
        self.mlp = ...

    def forward(self, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, **kwargs):
        # Define the forward computation of a single decoder layer
        ...


class MyModel(DecoderModel):

    def __init__(self, model_config: ModelConfig[MyConfig]):
        super().__init__(model_config)
        # Use model_config to initialize the submodules
        self.embed_tokens = ...
        self.layers = nn.ModuleList([
            MyDecoderLayer(model_config, layer_idx)
            for layer_idx in range(model_config.pretrained_config.num_hidden_layers)
        ])

    def forward(self,
                attn_metadata: AttentionMetadata,
                input_ids: Optional[torch.LongTensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None):
        # Define the forward computation of the model
        ...


class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):

    def __init__(self, model_config: ModelConfig[MyConfig]):
        super().__init__(MyModel(model_config),
                         config=model_config,
                         hidden_size=model_config.pretrained_config.hidden_size,
                         vocab_size=model_config.pretrained_config.vocab_size)
```
Note that `MyAttention` inherits from our `Attention` module (in `tensorrt_llm/_torch/modules/attention.py`), so that the attention computation is compatible with our PyTorch runtime. Related to this, the module inputs should also be adapted:

- The `attn_metadata` stores the metadata from the batched input and the KV cache for the attention backend. It is created by and passed from the runtime, and model developers need to ensure that `attn_metadata` is passed correctly to the attention module.
- The input tensors (i.e., `input_ids`, `position_ids`, `hidden_states`) are in packed mode: the first dimension corresponds to the number of tokens in a batch.
Additionally, `MyDecoderLayer`, `MyModel`, and `MyModelForCausalLM` are subclasses of `DecoderLayer`, `DecoderModel`, and `DecoderModelForCausalLM`, respectively. The base classes define the interfaces and provide generic scaffolding for defining model layers, loading weights, etc.
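To make the data flow concrete, below is a minimal sketch of the decoder layer's forward pass that threads `attn_metadata` through to the attention module. It builds on the template above; the keyword arguments passed to `self.self_attn` are assumptions modeled on `modeling_llama.py`, so check the `Attention` module in your version for the exact signature.

```python
class MyDecoderLayer(DecoderLayer):
    ...

    def forward(self, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, **kwargs):
        # Packed mode: hidden_states has shape [num_tokens, hidden_size].
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Pass attn_metadata through so the attention backend can read the
        # batch layout and KV-cache state (assumed keyword arguments).
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       attn_metadata=attn_metadata,
                                       **kwargs)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
```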
Optionally, you may replace the native PyTorch modules with our implementations to enable features or achieve higher performance (see the sketch after this list):

- `Linear` (in `tensorrt_llm/_torch/modules/linear.py`): enables tensor parallelism and quantization.
- `Embedding` (in `tensorrt_llm/_torch/modules/embedding.py`): enables tensor parallelism for the embedding.
- `RotaryEmbedding` (in `tensorrt_llm/_torch/modules/rotary_embedding.py`): enables performant rotary embedding.
- `RMSNorm` (in `tensorrt_llm/_torch/modules/rms_norm.py`): enables performant RMS norm.
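For illustration, here is a hedged sketch of how the placeholders in `MyDecoderLayer.__init__` from the template above could be filled with these modules. The constructor arguments shown (e.g., `hidden_size`, `eps`, `dtype`) are assumptions loosely based on `modeling_llama.py`; check the module definitions under `tensorrt_llm/_torch/modules` for the exact signatures in your version.

```python
from tensorrt_llm._torch.modules.rms_norm import RMSNorm


class MyDecoderLayer(DecoderLayer):

    def __init__(self, model_config: ModelConfig[MyConfig], layer_idx: int):
        super().__init__()
        config = model_config.pretrained_config
        # Assumed constructor arguments; verify against
        # tensorrt_llm/_torch/modules/rms_norm.py in your version.
        self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                       eps=config.rms_norm_eps,
                                       dtype=config.torch_dtype)
        self.self_attn = MyAttention(model_config, layer_idx)
        self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                                eps=config.rms_norm_eps,
                                                dtype=config.torch_dtype)
        self.mlp = ...  # e.g., an MLP built from the Linear module listed above
```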
For a concrete reference, check out `tensorrt_llm/_torch/models/modeling_llama.py`.
The base class `DecoderModelForCausalLM` provides a `load_weights` method that loads the weights from the checkpoint file and assigns them to the corresponding layers of the model. However, if the default method does not work for `MyModelForCausalLM`, you need to implement your own `load_weights`:
```python
class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):

    def load_weights(self, weights: dict):
        # Define the weight loading logic
        ...
```
For example, HuggingFace's LLaMA model uses three linear layers for the Q/K/V projections, resulting in three weight tensors in the checkpoint:
```python
>>> weights
{
    ...,
    "model.layers.0.self_attn.q_proj.weight": torch.Tensor([hidden_size, hidden_size]),
    "model.layers.0.self_attn.k_proj.weight": torch.Tensor([hidden_size, hidden_size]),
    "model.layers.0.self_attn.v_proj.weight": torch.Tensor([hidden_size, hidden_size]),
    ...,
}
```
However, our LLaMA model fuses the three layers into one linear layer:
```python
>>> llama.model.layers[0].self_attn.qkv_proj.weight.data
torch.Tensor([hidden_size * 3, hidden_size])
```
Hence, `load_weights` needs to collect the three weight tensors from the original checkpoint, concatenate them, and assign them to the fused linear layer. Once tensor parallelism and quantization are taken into account, the process becomes more complicated, so we recommend calling the predefined module-level `load_weights` methods (e.g., those of `Linear` and `Embedding`) when implementing your model-level `load_weights`.
Overall, `load_weights` should handle any discrepancy between `MyModelForCausalLM` and the weights loaded from the checkpoint, so that `MyModelForCausalLM` can perform forward computation equivalent to the original model.
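As a minimal sketch of the idea, ignoring tensor parallelism and quantization and copying tensors directly instead of delegating to the module-level `load_weights`, the Q/K/V fusion could look like this (the layer traversal assumes the attribute layout shown earlier):

```python
class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):

    def load_weights(self, weights: dict):
        for idx, layer in enumerate(self.model.layers):
            prefix = f"model.layers.{idx}.self_attn"
            # Concatenate the separate Q/K/V projections from the checkpoint
            # into the single fused projection used by the attention module.
            fused_qkv = torch.cat([
                weights[f"{prefix}.q_proj.weight"],
                weights[f"{prefix}.k_proj.weight"],
                weights[f"{prefix}.v_proj.weight"],
            ], dim=0)
            layer.self_attn.qkv_proj.weight.data.copy_(fused_qkv)
        # ... load embeddings, layer norms, MLP, and lm_head weights similarly.
```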
The new model needs to be registered so that it can be recognized by the PyTorch runtime. Registration is done simply by adding the `register_auto_model` decorator to `MyModelForCausalLM`:
```python
from tensorrt_llm._torch.models.modeling_utils import register_auto_model


@register_auto_model("MyModelForCausalLM")
class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):

    def __init__(self, model_config: ModelConfig[MyConfig]):
        ...
```
To add the new model to the core models, `modeling_mymodel.py` (and potentially `configuration_mymodel.py`) should be placed in `tensorrt_llm/_torch/models`. Then, you need to import the modeling code in `tensorrt_llm/_torch/models/__init__.py`:
```python
from .modeling_mymodel import MyModelForCausalLM

__all__ = [
    ...,
    "MyModelForCausalLM",
]
```
Alternatively, you can register the new model as an out-of-tree model, so that you can use it without touching the TensorRT-LLM codebase. To do so, place `modeling_mymodel.py` (and potentially `configuration_mymodel.py`) in your working directory, and import the modeling code in your script:
```python
from tensorrt_llm._torch import LLM

import modeling_mymodel  # importing this module runs register_auto_model for MyModelForCausalLM


def main():
    llm = LLM(...)


if __name__ == '__main__':
    main()
```
We provide an out-of-tree modeling example in `examples/pytorch/out_of_tree_example`. The model is implemented in `modeling_opt.py`, and you can run the example with:
```bash
python examples/pytorch/out_of_tree_example/main.py
```