Fast dataset resume #1082
Conversation
Thank you for making resuming from a dataset checkpoint efficient! I had some inline comments. Please see if they make sense to you.

Regarding

> On the other hand, for iterable-style datasets, calling `.skip` after `split_dataset_by_node` splits the number of elements to skip across the ranks (e.g. calling `.skip(10)` after `split_dataset_by_node(<rank>, 2)` effectively skips 5 (`10 // 2 = 5`) elements on each rank), which isn't what we want/expect, so removing `.skip` was justified there.

I'm a bit confused by your comment. It sounds like the behavior of `skip` for `IterableDataset` is oblivious to whether it has gone through `split_dataset_by_node` or not, which is not intuitive?
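For anyone following along, here is a toy repro sketch (hypothetical, not from the PR; the data and shard counts are made up) that can be run to observe what each rank actually yields after `split_dataset_by_node` followed by `.skip`:

```python
# Toy repro sketch (not from the PR): split an iterable dataset across two
# "ranks", call .skip() on each split, and print what each rank yields.
from datasets import Dataset
from datasets.distributed import split_dataset_by_node

world_size = 2
base = Dataset.from_dict({"x": list(range(40))}).to_iterable_dataset(num_shards=4)

for rank in range(world_size):
    shard = split_dataset_by_node(base, rank=rank, world_size=world_size)
    resumed = shard.skip(10)  # how many elements of the *global* stream get skipped?
    print(rank, [ex["x"] for ex in resumed])
```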
assert torch.equal(input_ids["input"], expected_input_ids["input"])
assert torch.equal(labels, expected_labels)
for dataset_name in ["c4_test", "c4_test_streaming"]:
    dataset_name = "c4_test"
should remove
torchtitan/datasets/hf_datasets.py
Outdated
"c4_test_streaming": DatasetConfig( | ||
path="tests/assets/c4_test", | ||
loader=lambda path: load_dataset(path, split="train", streaming=True), | ||
text_processor=_process_c4_text, | ||
), |
Since this is only used in the unit test, can we register it only in the test file?
torchtitan/datasets/hf_datasets.py
Outdated
@@ -97,15 +102,21 @@ def __init__(
    # Variables for checkpointing
    self._sample_idx = 0
    self._all_tokens: list[int] = []
    self._data_state_dict_loaded = False
Thanks for considering backward compatibility. At this point, I think we should just move forward with the right and efficient behavior, without worrying too much about it.
}

if isinstance(self._data, IterableDataset):
    _state_dict["data"] = self._data.state_dict()
Is this supported in `datasets` 2.21.0, or does it require a newer version? If so, we need to update requirements.txt.
Also, let's add a comment on the efficiency of this solution, with a reference to https://huggingface.co/docs/datasets/v3.5.0/en/stream#save-a-dataset-checkpoint-and-resume-iteration
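For reference, the checkpoint/resume pattern from the linked docs looks roughly like this (a toy sketch with made-up data, not the PR's code); per the docs, resuming re-reads only the shard the iteration stopped in rather than everything before it:

```python
# Toy sketch of the IterableDataset checkpoint/resume API from the linked docs
# page (made-up data; illustrative only).
from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(6))}).to_iterable_dataset(num_shards=3)

state = None
for idx, example in enumerate(ds):
    if idx == 2:
        state = ds.state_dict()  # records the current shard and the offset inside it
        break

ds.load_state_dict(state)        # restore the position
print(list(ds))                  # yields only the remaining examples
```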
torchtitan/datasets/hf_datasets.py
Outdated
@@ -138,8 +149,23 @@ def load_state_dict(self, state_dict):
    self._sample_idx = state_dict["sample_idx"]
    self._all_tokens = state_dict["token_buffer"]

    if isinstance(self._data, IterableDataset):
        if "data" in state_dict:  # backward compatibility
Since we're not worrying about breaking BC, let's assert the existence of `"data"`.
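A minimal sketch of that suggestion, mirroring the diff above (hypothetical fragment, not the final PR code):

```python
# Hypothetical fragment mirroring the diff above: a hard assert replaces the
# backward-compatibility check on "data".
from datasets import IterableDataset

class _LoadStateDictSketch:
    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]
        self._all_tokens = state_dict["token_buffer"]

        if isinstance(self._data, IterableDataset):
            assert "data" in state_dict, "checkpoint is missing iterable dataset state"
            self._data.load_state_dict(state_dict["data"])
```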
torchtitan/datasets/hf_datasets.py
Outdated
elif not self._data_state_dict_loaded:  # backward compatibility
    it = iter(self._data)
    for _ in range(self._sample_idx):
        next(it)
    return it
let's remove, for simplicity
torchtitan/datasets/hf_datasets.py
Outdated
@@ -138,8 +149,23 @@ def load_state_dict(self, state_dict):
    self._sample_idx = state_dict["sample_idx"]
    self._all_tokens = state_dict["token_buffer"]

    if isinstance(self._data, IterableDataset):
One thing I'm curious about is: why don't map-style datasets provide the same API, so that the API stays consistent? It could be as simple as what we do here, using `self._data.skip(self._sample_idx)`.
Dunno, I'm no longer involved in the `datasets` project, so I don't know exactly, but I assume the reasoning is that it's not too hard to resume map-style datasets manually, as explained in huggingface/datasets#5454 (comment).
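For concreteness, a toy sketch of what resuming a map-style `Dataset` "manually" can look like (illustrative only; `sample_idx` is a made-up variable standing in for the checkpointed position):

```python
# Toy illustration (not torchtitan code): "manual" resume of a map-style
# Dataset by remembering how many samples were already consumed.
from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(10))})
sample_idx = 4                                   # stands in for the checkpointed position

resumed = ds.select(range(sample_idx, len(ds)))  # drop the first sample_idx rows
print(resumed[0])                                # {'x': 4}
```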
torchtitan/datasets/hf_datasets.py
Outdated
def state_dict(self):
    return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}
    _state_dict = {
        "token_buffer": self._all_tokens,
Could you do torchtitan a favor and rename `self._all_tokens` to `self._token_buffer`, to make things consistent?
torchtitan/datasets/hf_datasets.py
Outdated
return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx} | ||
_state_dict = { | ||
"token_buffer": self._all_tokens, | ||
"sample_idx": self._sample_idx, |
Can we save this in the state dict only when it's a map-style dataset?
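One possible shape for that (hypothetical fragment, also assuming the `_token_buffer` rename suggested above):

```python
# Hypothetical fragment (not the final PR code): only map-style datasets need
# to store sample_idx; iterable-style datasets carry their own state instead.
from datasets import IterableDataset

class _StateDictSketch:
    def state_dict(self):
        _state_dict = {"token_buffer": self._token_buffer}
        if isinstance(self._data, IterableDataset):
            _state_dict["data"] = self._data.state_dict()
        else:
            _state_dict["sample_idx"] = self._sample_idx
        return _state_dict
```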
Can we have an accuracy verification for this PR? I believe llama3 8B can reproduce the loss issue if the dataset doesn't resume correctly.
Sorry for the delay, I addressed the comments and made the test much more robust.
It's hard to explain because the
I don't have access to A100 / H100 GPUs right now, so it would be great if someone else could do the run. I improved the test significantly (e.g. now it re-loops the test datasets), so I'm not sure if this is really needed, though.
Can you also fix the linter error and the integration test error? I will see if I can verify with llama3.
Unfortunately, one more bug needs to be fixed, this time directly in
EDIT:
This PR makes resuming dataset iteration from a checkpoint fast again.

This performance regression comes from #838. In that PR, `.skip` was removed for both map-style and iterable-style datasets for correctness reasons. However, `.skip` works as expected for map-style datasets, so the change can be reverted for that case. On the other hand, for iterable-style datasets, calling `.skip` after `split_dataset_by_node` splits the number of elements to skip across the ranks (e.g. calling `.skip(10)` after `split_dataset_by_node(<rank>, 2)` effectively skips 5 (`10 // 2 = 5`) elements on each rank), which isn't what we want/expect, so removing `.skip` was justified there. Still, we can make the whole thing much faster by using the `state_dict` API for iterable-style datasets, which avoids re-iterating past shards/files when resuming.
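To tie the two cases together, here is a simplified, hypothetical sketch of how iteration can resume under this scheme (not the PR's exact code; it assumes a `datasets` version where map-style `Dataset.skip` is available):

```python
# Hypothetical sketch of the resume strategy described above (simplified).
from datasets import Dataset

class ResumableTextDataset:
    def __init__(self, data, sample_idx: int = 0):
        self._data = data            # map-style Dataset or IterableDataset
        self._sample_idx = sample_idx

    def _get_data_iter(self):
        if isinstance(self._data, Dataset):
            # Map-style: .skip() behaves as expected, so jump straight past the
            # samples consumed before the checkpoint.
            return iter(self._data.skip(self._sample_idx))
        # Iterable-style: the position is restored via the dataset's own
        # load_state_dict(), which avoids re-iterating already-consumed
        # shards/files one example at a time.
        return iter(self._data)

# Toy usage with a map-style dataset:
ds = Dataset.from_dict({"x": list(range(10))})
print([ex["x"] for ex in ResumableTextDataset(ds, sample_idx=7)._get_data_iter()])  # [7, 8, 9]
```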