Llama3.1 with torchtune #1123
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1123
Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 1599c2b with merge base 964d437. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
This PR makes torchchat support multi-modality model definition and construction. To demonstrate this capability, we integrate the Flamingo component into the system. Note that this is only bare-minimum support for model definition. Please check the openai_api_multimodal branch for e2e, and #1123 (comment) for a better structure and Llama 3.1 support.
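To illustrate the kind of multi-modality construction the PR describes, here is a minimal, self-contained sketch of building per-component modules from a config. The `ModelConfig` dataclass and `build_modules` helper are hypothetical stand-ins (not torchchat's actual classes), and the `nn.Module` plumbing is stripped out so the sketch has no torch dependency; it only mirrors the `transformer_args` dispatch pattern visible in the diff below.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ModelConfig:
    # Hypothetical stand-in for torchchat's model config: maps a component
    # name ("text", "encoder", ...) to either a kwargs dict or a single value.
    transformer_args: Dict[str, Any] = field(default_factory=dict)


def build_modules(config: ModelConfig, registry: Dict[str, type]) -> Dict[str, Any]:
    """Instantiate each registered component class from its config entry."""
    modules = {}
    for name, module_class in registry.items():
        args = config.transformer_args[name]
        if isinstance(args, dict):
            modules[name] = module_class(**args)  # kwargs-style entry
        else:
            modules[name] = module_class(args)  # single positional config
    return modules
```

In torchchat itself the result would be wrapped in an `nn.ModuleDict`; the dict return here is purely to keep the sketch dependency-free.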
The PR itself looks fine to me, with some minor nits.
torchchat/cli/builder.py
Outdated
@@ -35,6 +35,14 @@
from torchchat.utils.measure_time import measure_time
from torchchat.utils.quantize import quantize_model

# bypass the import issue before torchao is ready on macos
try:
    from torchtune.training import set_default_dtype
Unused import?
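The guarded import in the diff above is truncated before its `except` clause. A complete sketch of the pattern might look like the following; the fallback branch is an assumption for illustration, not the PR's actual code.

```python
# bypass the import issue before torchao is ready on macos
try:
    from torchtune.training import set_default_dtype
except ImportError:
    # Hypothetical fallback: a no-op context manager so callers can still
    # write `with set_default_dtype(dtype): ...` when torchtune is absent.
    from contextlib import contextmanager

    @contextmanager
    def set_default_dtype(dtype):
        yield
```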
@@ -11,6 +11,7 @@
from enum import Enum
from pathlib import Path
from typing import Callable, Dict, Optional, Union
from abc import ABC, abstractmethod
ABC is unused?
ABC should be one of Model's parents. Fixed it.
torchchat/model.py
Outdated
if isinstance(self.config.transformer_args[name], dict):
    modules[name] = module_class(**self.config.transformer_args[name])
else:
    modules[name] = module_class(self.config.transformer_args[name])
Suggested change:
-if isinstance(self.config.transformer_args[name], dict):
-    modules[name] = module_class(**self.config.transformer_args[name])
-else:
-    modules[name] = module_class(self.config.transformer_args[name])
+if isinstance(config_args := self.config.transformer_args[name], dict):
+    modules[name] = module_class(**config_args)
+else:
+    modules[name] = module_class(config_args)
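The suggestion above uses an assignment expression (the walrus operator, Python 3.8+) so the dictionary lookup is written, and performed, only once. A minimal standalone illustration of the same pattern, with hypothetical names:

```python
transformer_args = {"text": {"dim": 8}, "vision": 16}


def describe(name: str) -> str:
    # `:=` binds the looked-up value, avoiding a second dict access.
    if isinstance(config_args := transformer_args[name], dict):
        return f"{name}: kwargs {sorted(config_args)}"
    return f"{name}: scalar {config_args}"
```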
torchchat/model.py
Outdated
class FlamingoModel(Model):
    def forward(self, tokens: Tensor, encoder_input: Optional[Dict[str, Tensor]] = None, encoder_mask: Optional[Tensor] = None) -> Tensor:
Lint: this line exceeds the line-length limit.
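One common way to address the long-line lint warning is to wrap the signature across lines. A self-contained sketch (with `Tensor` and `Model` as stand-in classes, since the real `torch.Tensor` and torchchat `Model` are not needed to show the formatting):

```python
from typing import Dict, Optional


class Tensor:  # stand-in for torch.Tensor, to keep the sketch dependency-free
    pass


class Model:  # stand-in for torchchat's Model base class
    pass


class FlamingoModel(Model):
    # Same signature as the flagged line, wrapped to satisfy line-length lint.
    def forward(
        self,
        tokens: Tensor,
        encoder_input: Optional[Dict[str, Tensor]] = None,
        encoder_mask: Optional[Tensor] = None,
    ) -> Tensor:
        ...
```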
This PR aims to add torchtune Llama 3.1 support while keeping the original torchchat Llama 3.1 for reference.
To play with it:
The command for the original model should be unchanged:
If you want to try the torchtune Llama 3.1 model, consider using: