Fast dataset resume #1082


Open · wants to merge 6 commits into main

Conversation

@mariosasko commented Apr 9, 2025

This PR makes resuming dataset iteration from a checkpoint fast again.

This performance regression comes from #838. In that PR, .skip was removed for both map-style and iterable-style datasets for correctness reasons. However, .skip works as expected for map-style datasets, so the change can be reverted for that case. On the other hand, for iterable-style datasets, calling .skip after split_dataset_by_node splits the number of elements to skip across the ranks (e.g. calling .skip(10) after split_dataset_by_node(<rank>, 2) effectively skips 5 (10 // 2 = 5) elements on each rank), which isn't what we want/expect, so removing .skip was justified there. Still, we can make the whole thing much faster by using the state_dict API for iterable-style datasets, which avoids re-iterating over past shards/files when resuming.
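To make the skip issue concrete, here is a toy sketch of the interaction described above. The dataset size, shard count, and the halved skip count are assumptions taken from this description rather than a documented contract, so the exact output may vary across datasets versions:

from datasets import Dataset
from datasets.distributed import split_dataset_by_node

ds = Dataset.from_dict({"x": list(range(10))}).to_iterable_dataset(num_shards=2)

for rank in range(2):
    shard = split_dataset_by_node(ds, rank=rank, world_size=2)  # 5 examples per rank
    # Intuitively, .skip(4) should leave 1 of this rank's 5 examples, but per
    # the description above it effectively skips 4 // 2 = 2, leaving 3.
    print(rank, [example["x"] for example in shard.skip(4)])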

@facebook-github-bot added the CLA Signed label on Apr 9, 2025
@tianyu-l (Contributor) left a comment

Thank you for making resuming from a dataset checkpoint efficient! I had some inline comments. Please see if they make sense to you.

Regarding

On the other hand, for iterable-style datasets, calling .skip after split_dataset_by_node splits the number of elements to skip across the ranks (e.g. calling .skip(10) after split_dataset_by_node(<rank>, 2) effectively skips 5 (10 // 2 = 5) elements on each rank), which isn't what we want/expect, so removing .skip was justified there.

I'm a bit confused by your comment. It sounds like the behavior of skip for IterableDataset is oblivious to whether it has gone through split_dataset_by_node or not, which is not intuitive?

assert torch.equal(input_ids["input"], expected_input_ids["input"])
assert torch.equal(labels, expected_labels)
for dataset_name in ["c4_test", "c4_test_streaming"]:
dataset_name = "c4_test"
Contributor:

should remove

Comment on lines 52 to 56
"c4_test_streaming": DatasetConfig(
path="tests/assets/c4_test",
loader=lambda path: load_dataset(path, split="train", streaming=True),
text_processor=_process_c4_text,
),
Contributor:

Since this is only used in a unit test, can we register it only in the test file?

@@ -97,15 +102,21 @@ def __init__(
# Variables for checkpointing
self._sample_idx = 0
self._all_tokens: list[int] = []
self._data_state_dict_loaded = False
Contributor:

Thanks for considering backward compatibility. At this point, I think we should just go with the right and efficient behavior, without worrying too much about it.

}

if isinstance(self._data, IterableDataset):
_state_dict["data"] = self._data.state_dict()
Contributor:

Is this supported in datasets 2.21.0, or only in a newer version? If the latter, we need to update requirements.txt.

Contributor:

Also, let's add a comment on the efficiency of this solution, with a reference to https://huggingface.co/docs/datasets/v3.5.0/en/stream#save-a-dataset-checkpoint-and-resume-iteration
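For reference, the resume pattern from the linked docs looks roughly like this (a minimal sketch; per the docs, resuming from a state dict only re-reads from the start of the current shard rather than re-iterating everything, which is what makes it fast):

from datasets import Dataset

iterable_ds = Dataset.from_dict({"a": range(6)}).to_iterable_dataset(num_shards=3)

state = None
for idx, example in enumerate(iterable_ds):
    if idx == 2:
        state = iterable_ds.state_dict()  # checkpoint mid-iteration
        break

# Later, e.g. after a job restart: restore the state and keep iterating.
iterable_ds.load_state_dict(state)
for example in iterable_ds:
    print(example)  # continues from the checkpoint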

@@ -138,8 +149,23 @@ def load_state_dict(self, state_dict):
self._sample_idx = state_dict["sample_idx"]
self._all_tokens = state_dict["token_buffer"]

if isinstance(self._data, IterableDataset):
if "data" in state_dict: # backward compatibility
Contributor:

Since we're not worrying about BC breaking, let's assert the existence of "data".

Comment on lines 113 to 117
elif not self._data_state_dict_loaded: # backward compatibility
it = iter(self._data)
for _ in range(self._sample_idx):
next(it)
return it
Contributor:

let's remove, for simplicity

@@ -138,8 +149,23 @@ def load_state_dict(self, state_dict):
self._sample_idx = state_dict["sample_idx"]
self._all_tokens = state_dict["token_buffer"]

if isinstance(self._data, IterableDataset):
Contributor:

One thing I'm curious about: why don't map-style datasets provide the same API, to keep the APIs consistent? It could be as simple as what we do here, using self._data.skip(self._sample_idx).

Author:

Dunno, I'm no longer involved in the datasets project, so I don't know exactly, but I assume the reasoning is that it's not too hard to resume map-style datasets manually, as explained in huggingface/datasets#5454 (comment)
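For what it's worth, manual resumption for a map-style dataset along those lines is just index math (a sketch; select over the remaining range is one of several equivalent options, and sample_idx here stands in for whatever counter the caller checkpoints):

from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(10))})
sample_idx = 4  # examples already consumed, restored from our own state dict

# Map-style datasets support random access, so resuming needs no re-iteration;
# select() only builds an indices mapping, so this is cheap even for large data.
resumed = ds.select(range(sample_idx, len(ds)))
assert resumed[0]["x"] == 4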

def state_dict(self):
return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}
_state_dict = {
"token_buffer": self._all_tokens,
Contributor:

Could you do torchtitan a favor and rename self._all_tokens to self._token_buffer, to make things consistent?

return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}
_state_dict = {
"token_buffer": self._all_tokens,
"sample_idx": self._sample_idx,
Contributor:

Can we save this in the state dict only when it's a map-style dataset?

@fegin (Contributor) commented Apr 14, 2025

Can we have an accuracy verification for this PR? I believe llama3 8B can reproduce the loss issue if the dataset doesn't resume correctly.

@mariosasko (Author) commented Apr 22, 2025

Sorry for the delay, I addressed the comments and made the test much more robust.

I'm a bit confused by your comment. It sounds like the behavior of skip for IterableDataset is oblivious to whether it has gone through split_dataset_by_node or not, which is not intuitive?

It's hard to explain because the datasets logic seems buggy (e.g. this test in datasets should fail with correct parentheses). It should be easier to understand with a toy example; see the sketch below. What surprised me is that changing the number of data shards results in completely different behaviour.
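A toy probe of that surprise might look like the following (hypothetical sizes and counts; the point is only that the output of the same skip call changes when num_shards changes, which is the behaviour described above):

from datasets import Dataset
from datasets.distributed import split_dataset_by_node

for num_shards in (2, 4):
    ds = Dataset.from_dict({"x": list(range(8))}).to_iterable_dataset(num_shards=num_shards)
    for rank in range(2):
        shard = split_dataset_by_node(ds, rank=rank, world_size=2)
        # Per the comment above, which elements get skipped depends on the
        # shard layout, not just on the skip count and the rank.
        print(num_shards, rank, [example["x"] for example in shard.skip(2)])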

Can we have an accuracy verification for this PR? I believe llama3 8B can reproduce the loss issue if the dataset doesn't resume correctly.

I don't have access to A100 / H100 GPUs right now, so it would be great if someone else could do the run. I improved the test significantly (e.g. it now re-loops the test datasets), so I'm not sure this is really needed, though.

@fegin (Contributor) commented Apr 24, 2025

Can you also fix the linter error and the integration test error? I will see if I can verify with llama3.

@mariosasko (Author) commented Apr 27, 2025

Unfortunately, one more bug needs to be fixed, this time directly in datasets ...

EDIT:
Reported in datasets: huggingface/datasets#7538
