
[RFC] Pass the original input to all PP stages #1130


Open · wants to merge 2 commits into main

Conversation

@fegin (Contributor) commented Apr 22, 2025

We need the original tokens to generate the document masks / block-causal masks. Since TorchTitan currently lets all ranks perform data loading, there will be no performance regression.

This is required to support document masking attention with PP.
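For context, a minimal sketch of why the raw tokens are needed (illustrative only; `build_document_block_causal_mask`, the EOS-based document boundaries, and the exact mask layout are assumptions, not this PR's implementation):

    import torch

    def build_document_block_causal_mask(tokens: torch.Tensor, eos_id: int) -> torch.Tensor:
        # tokens: [batch, seq_len] raw token ids from the dataloader.
        # Positions are grouped into documents by counting EOS tokens; attention is
        # then limited to earlier positions within the same document (block-causal).
        _, seq_len = tokens.shape
        doc_ids = torch.cumsum((tokens == eos_id).long(), dim=1)   # [batch, seq_len]
        same_doc = doc_ids.unsqueeze(2) == doc_ids.unsqueeze(1)    # [batch, q, k]
        causal = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=tokens.device)
        )
        return same_doc & causal

    # A non-first pipeline stage receives the previous stage's activations
    # ([batch, seq, dim] floats) as its forward input, so a mask like this cannot
    # be derived from that input; the original token ids must reach every stage.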

@facebook-github-bot added the CLA Signed label Apr 22, 2025
@H-Huang (Member) left a comment

The PP side looks good. Do the current llama3 and llama4 not support document masking? Why are model changes needed?

@fegin (Contributor, Author) commented Apr 22, 2025

@H-Huang llama3 doesn't use document masking. Document masking + PP is a missing feature for llama4.

fegin added 2 commits April 24, 2025 09:54
@tianyu-l (Contributor) left a comment

Left some nit comments.

CI is broken by compile + SAC. Please make sure this change works before merge.

If pipeline parallelism is enabled, this will be the input token indices
for the ranks on the first pipeline stage. This will be the activation of the
previous pipeline stage if the current rank is not on the first stage.
input_batch (torch.Tensor): The input batch read from the dataloader.
Please add a comment that this field is needed for non-first PP stages to obtain proper document masks.
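For instance, the docstring entry could read something like the following (the wording is only a suggestion, not the PR's final text):

    input_batch (torch.Tensor): The input batch read from the dataloader.
        This is needed for non-first PP stages, whose forward input is the
        previous stage's activation rather than token indices, to build the
        proper document masks.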

@@ -351,7 +355,7 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
     # Non-PP forward / backward
     with self.train_context(optional_context_parallel_ctx):
         assert len(model_parts) == 1
-        pred = model_parts[0](inputs)
+        pred = model_parts[0](inputs, input_batch=inputs)
The non-PP branch looks a bit strange. I slightly prefer the alternative of making input_batch optional and letting init_attention_mask use tokens when input_batch is None. The idea is that non-PP users see less pervasive usage of input_batch.
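A minimal sketch of that alternative (the forward signature, attribute names, and the init_attention_mask / eos_id details here are assumptions for illustration, not the repository's actual API):

    def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None):
        # When no separate batch is passed (non-PP, or the first PP stage where
        # `tokens` already holds the raw token ids), fall back to `tokens`, so
        # non-PP call sites can remain `model(inputs)`.
        if self.attn_mask_type == "block_causal":  # hypothetical config flag
            init_attention_mask(
                input_batch if input_batch is not None else tokens,
                eos_id=self.eos_id,
            )
        ...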

@tianyu-l mentioned this pull request Apr 27, 2025