24 | 24 | from colossalai.utils import get_current_device, get_non_persistent_buffers_set
25 | 25 | from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
26 | 26 |
| 27 | +from .distributed_checkpoint_utils import (
| 28 | +    create_model_metadata,
| 29 | +    is_pytorch_model_meta_dist_file,
| 30 | +    load_dist_model,
| 31 | +    save_dist_sharded_model,
| 32 | +    save_dist_unshard_model,
| 33 | +)
27 | 34 | from .general_checkpoint_io import GeneralCheckpointIO
28 | 35 | from .index_file import CheckpointIndexFile
29 | 36 | from .utils import (

47 | 54 |     sharded_optimizer_loading_epilogue,
48 | 55 | )
49 | 56 |
50 | | -from .distributed_checkpoint_utils import (
51 | | -    save_dist_sharded_model,
52 | | -    save_dist_unshard_model,
53 | | -    load_dist_model,
54 | | -    is_pytorch_model_meta_dist_file,
55 | | -    create_model_metadata
56 | | -)
57 | | -
58 | 57 | try:
59 | 58 |     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
60 | 59 | except ImportError:
@@ -244,10 +243,20 @@ def save_sharded_model(
244 | 243 | return
245 | 244 | dist_id = self.tp_size * self.pp_rank + self.tp_rank
246 | 245 | model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
247 | | -async_writers = save_dist_sharded_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, prefix=prefix, size_per_shard=size_per_shard, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts)
| 246 | +async_writers = save_dist_sharded_model(
| 247 | +    model=model,
| 248 | +    model_metadata=model_metadata,
| 249 | +    checkpoint=checkpoint,
| 250 | +    prefix=prefix,
| 251 | +    size_per_shard=size_per_shard,
| 252 | +    use_safetensors=use_safetensors,
| 253 | +    use_async=use_async,
| 254 | +    dist_id=dist_id,
| 255 | +    pinned_state_dicts=self.pinned_state_dicts,
| 256 | +)
248 | 257 | self.async_writers.extend(async_writers)
249 | 258 | return
250 | | -
| 259 | +
251 | 260 | model = model.unwrap()
252 | 261 |
253 | 262 | if os.path.isfile(checkpoint):
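Aside on the `dist_id` used in this hunk: `self.tp_size * self.pp_rank + self.tp_rank` gives each (pp_rank, tp_rank) saver a unique index, so shard files written by different pipeline/tensor ranks cannot collide. A minimal standalone sketch of that mapping (the parallel sizes are illustrative assumptions, not values from this diff):

    tp_size, pp_size = 2, 3  # assumed sizes, for illustration only
    dist_ids = {(pp, tp): tp_size * pp + tp for pp in range(pp_size) for tp in range(tp_size)}
    # every (pp_rank, tp_rank) pair gets a distinct id in [0, tp_size * pp_size)
    assert sorted(dist_ids.values()) == list(range(tp_size * pp_size))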
@@ -396,9 +405,15 @@ def load_sharded_model(
396 | 405 | if is_pytorch_model_meta_dist_file(checkpoint_index_file):
397 | 406 |     model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
398 | 407 |     checkpoint = checkpoint_index_file.parent
399 | | -    load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads)
| 408 | +    load_dist_model(
| 409 | +        model=model,
| 410 | +        model_metadata=model_metadata,
| 411 | +        checkpoint=checkpoint,
| 412 | +        low_cpu_mem_mode=low_cpu_mem_mode,
| 413 | +        num_threads=num_threads,
| 414 | +    )
400 | 415 |     return
401 | | -
| 416 | +
402 | 417 | model_before_wrapping = model  # backup for model before wrapping
403 | 418 | model = model.unwrap()
404 | 419 |
@@ -794,11 +809,19 @@ def save_unsharded_model(
794 | 809 | if self.dp_rank != 0 and self.sp_rank != 0:
795 | 810 |     return
796 | 811 | dist_id = self.tp_size * self.pp_rank + self.tp_rank
797 | | -writer= save_dist_unshard_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts)
| 812 | +writer = save_dist_unshard_model(
| 813 | +    model=model,
| 814 | +    model_metadata=model_metadata,
| 815 | +    checkpoint=checkpoint,
| 816 | +    use_safetensors=use_safetensors,
| 817 | +    use_async=use_async,
| 818 | +    dist_id=dist_id,
| 819 | +    pinned_state_dicts=self.pinned_state_dicts,
| 820 | +)
798 | 821 | if writer is not None:
799 | 822 |     self.async_writers.append(writer)
800 | 823 | return
801 | | -
| 824 | +
802 | 825 | model = model.unwrap()
803 | 826 | if self.dp_rank != 0:
804 | 827 |     return
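The two save paths in this file differ in how they collect async writers: the sharded save returns a list of writers (hence the `extend` in the earlier hunk), while `save_dist_unshard_model` returns a single writer or `None` (hence the guarded `append` above). A hedged sketch of that collection pattern; the writer class below is a stand-in, since the real writer type does not appear in this diff:

    from typing import List, Optional

    class StubWriter:  # stand-in: the real async writer type is not shown in this diff
        pass

    async_writers: List[StubWriter] = []

    def collect_sharded(writers: List[StubWriter]) -> None:
        async_writers.extend(writers)  # one writer per shard file

    def collect_unsharded(writer: Optional[StubWriter]) -> None:
        if writer is not None:  # None when this rank has nothing to write
            async_writers.append(writer)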
@@ -871,7 +894,13 @@ def load_unsharded_model(
871 | 894 | for filename in os.listdir(checkpoint):
872 | 895 |     if is_pytorch_model_meta_dist_file(filename):
873 | 896 |         model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
874 | | -        load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads)
| 897 | +        load_dist_model(
| 898 | +            model=model,
| 899 | +            model_metadata=model_metadata,
| 900 | +            checkpoint=checkpoint,
| 901 | +            low_cpu_mem_mode=low_cpu_mem_mode,
| 902 | +            num_threads=num_threads,
| 903 | +        )
875 | 904 |         return
876 | 905 |
877 | 906 | strict = False
@@ -1103,7 +1132,6 @@ def gather_from_sharded_optimizer_state(
1103 | 1132 | dist.all_gather(gather_tensor, v, group=dp_group)
1104 | 1133 | v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
1105 | 1134 |
1106 | | -
1107 | 1135 | # Then gather TP shards.
1108 | 1136 | partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
1109 | 1137 | if partition_dim is not None:
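The last hunk touches the optimizer-state gather: the all-gathered ZeRO shards are stacked, flattened, truncated to `param.numel()` to drop the padding, and reshaped back to the parameter's shape. A self-contained sketch of just that reshape step, with assumed sizes and no actual `dist.all_gather` call:

    import torch

    param = torch.arange(10.0).reshape(2, 5)                 # original parameter (assumed shape)
    dp_size = 4                                              # assumed data-parallel size
    flat = torch.nn.functional.pad(param.flatten(), (0, 2))  # pad 10 -> 12 so shards split evenly
    gather_tensor = list(flat.chunk(dp_size))                # what all_gather would have produced
    v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
    assert torch.equal(v, param)                             # padding dropped, shape restored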