Skip to content

Llama3.1 with torchtune #1123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 53 commits into from
Sep 11, 2024
Merged

Llama3.1 with torchtune #1123

merged 53 commits into from
Sep 11, 2024

Conversation

Gasoonjia
Copy link
Contributor

@Gasoonjia Gasoonjia commented Sep 9, 2024

This PR aims to add torchtune llama3.1 support while keep the original torchchat llama3.1 for reference.
To play with it:

The command for original model should be the same:

python3 torchchat.py generate llama3.1 --prompt "write me a story about a boy and his bear"

If you want to play with torchtune 3.1 model, consider using:

python3 torchchat.py generate llama3.1-tune --prompt "write me a story about a boy and his bear"

Copy link

pytorch-bot bot commented Sep 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1123

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1599c2b with merge base 964d437 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Sep 9, 2024
Gasoonjia added a commit that referenced this pull request Sep 11, 2024
This PR makes torchchat support multi-modality model definition and constructions. To show our power in multi-modality area, we integrate flamingo component into our system.
Note that this is only for bare-minimum support for model definition. Please check openai_api_multimodal branch for e2e, and #1123 (comment) for better structure and llama3.1 support
Copy link
Contributor

@Jack-Khuu Jack-Khuu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR itself looks find to me with some minor nits.

@@ -35,6 +35,14 @@
from torchchat.utils.measure_time import measure_time
from torchchat.utils.quantize import quantize_model

# bypass the import issue before torchao is ready on macos
try:
from torchtune.training import set_default_dtype
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused import?

@@ -11,6 +11,7 @@
from enum import Enum
from pathlib import Path
from typing import Callable, Dict, Optional, Union
from abc import ABC, abstractmethod
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ABC is unused?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be one of the Model's parents. Fixed it.

Comment on lines 304 to 307
if isinstance(self.config.transformer_args[name], dict):
modules[name] = module_class(**self.config.transformer_args[name])
else:
modules[name] = module_class(self.config.transformer_args[name])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(self.config.transformer_args[name], dict):
modules[name] = module_class(**self.config.transformer_args[name])
else:
modules[name] = module_class(self.config.transformer_args[name])
if isinstance(config_args := self.config.transformer_args[name], dict):
modules[name] = module_class(**config_args)
else:
modules[name] = module_class(config_args)



class FlamingoModel(Model):
def forward(self, tokens: Tensor, encoder_input: Optional[Dict[str, Tensor]] = None, encoder_mask: Optional[Tensor] = None) -> Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lint long line

@Gasoonjia Gasoonjia merged commit e2049f4 into main Sep 11, 2024
51 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants