Commit a28fdde

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent e3f9de3 commit a28fdde

File tree: 4 files changed, +77 -29 lines changed

Diff for: colossalai/checkpoint_io/distributed_checkpoint_utils.py (+20 -10)
@@ -58,6 +58,7 @@ def dist_model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool =
         destination[extra_state_key] = extra_state
     return destination
 
+
 def load_state_dict_into_dist_model(
     model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False
 ):
@@ -86,11 +87,12 @@ def load_state_dict_into_dist_model(
         extra_state.copy_(state_dict[extra_state_key])
     return destination
 
+
 def create_model_metadata(
     model: nn.Module,
     prefix: str = "",
-    tp_size = None,
-    tp_rank = None,
+    tp_size=None,
+    tp_rank=None,
 ):
     param_origin_shape = model.param_origin_shape
     model = model.unwrap()
@@ -110,11 +112,12 @@ def create_model_metadata(
             partition_size = param.shape[tp_partition_dim]
             model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * tp_rank
             if tp_rank == tp_size - 1:
-                model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[
-                    tp_partition_dim
-                ] - (partition_size * (tp_size - 1))
+                model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[tp_partition_dim] - (
+                    partition_size * (tp_size - 1)
+                )
     return model_metadata
 
+
 def save_metadata(model_metadata, metadata_file, checkpoint_file=None, total_size=None):
     metadata_dicts = {
         "checkpoint_version": "1.0",
@@ -133,6 +136,7 @@ def save_metadata(model_metadata, metadata_file, checkpoint_file=None, total_siz
     with open(metadata_file, "w") as json_file:
         json.dump(metadata_dicts, json_file, indent=4)
 
+
 def load_metadata(checkpoint: str):
     metadata_dict = {}
     for filename in os.listdir(checkpoint):
@@ -197,6 +201,7 @@ def find_covering_shards(shards, target_offsets, target_lengths):
     assert total_lengths == global_shape
     return covering_shards
 
+
 def extract_weight_from_shard_partial(shard, target_offsets, target_lengths):
     """
     Extract the target range of weights from shard data, supporting partial overlap.
@@ -233,6 +238,7 @@ def extract_weight_from_shard_partial(shard, target_offsets, target_lengths):
     target_weight = weight[tuple(slices)]
     return target_weight, target_slices
 
+
 def assemble_tensor_from_shards_partial(shards, target_offsets, target_lengths, dtype):
     target_tensor = torch.zeros(target_lengths, dtype=dtype)
 
@@ -310,7 +316,13 @@ def dist_model_sharder(
 
 
 def save_dist_unshard_model(
-    model: ModelWrapper, model_metadata: Dict, checkpoint: str, use_safetensors: bool, use_async: bool = False, dist_id = 0, pinned_state_dicts = None
+    model: ModelWrapper,
+    model_metadata: Dict,
+    checkpoint: str,
+    use_safetensors: bool,
+    use_async: bool = False,
+    dist_id=0,
+    pinned_state_dicts=None,
 ):
     """
     Save model state dict to a single file with given checkpointing path.
@@ -426,7 +438,7 @@ def save_dist_sharded_model(
     use_safetensors: bool = False,
     use_async: bool = False,
     dist_id: int = 0,
-    pinned_state_dicts = None,
+    pinned_state_dicts=None,
 ) -> None:
     """
     Save sharded model checkpoint under the given checkpointing path.
@@ -463,9 +475,7 @@
         pinned_state_dicts = pinned_state_dicts[id(model)]
     else:
         pinned_state_dicts = None
-    state_dict_shard = dist_model_sharder(
-        model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts
-    )
+    state_dict_shard = dist_model_sharder(model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts)
     weights_name, _ = get_model_base_filenames(prefix, use_safetensors)
     index_file = CheckpointIndexFile(checkpoint)
471481

Diff for: colossalai/checkpoint_io/general_checkpoint_io.py (+1 -1)

@@ -309,4 +309,4 @@ def load_sharded_model(
         )
 
     def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
-        raise NotImplementedError
+        raise NotImplementedError

(The removed and added lines are textually identical; this is most likely the end-of-file fixer adding a missing trailing newline.)

Diff for: colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py (+44 -16)
@@ -24,6 +24,13 @@
 from colossalai.utils import get_current_device, get_non_persistent_buffers_set
 from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
 
+from .distributed_checkpoint_utils import (
+    create_model_metadata,
+    is_pytorch_model_meta_dist_file,
+    load_dist_model,
+    save_dist_sharded_model,
+    save_dist_unshard_model,
+)
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
@@ -47,14 +54,6 @@
     sharded_optimizer_loading_epilogue,
 )
 
-from .distributed_checkpoint_utils import (
-    save_dist_sharded_model,
-    save_dist_unshard_model,
-    load_dist_model,
-    is_pytorch_model_meta_dist_file,
-    create_model_metadata
-)
-
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
 except ImportError:
@@ -244,10 +243,20 @@ def save_sharded_model(
             return
         dist_id = self.tp_size * self.pp_rank + self.tp_rank
         model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-        async_writers = save_dist_sharded_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, prefix=prefix, size_per_shard=size_per_shard, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts)
+        async_writers = save_dist_sharded_model(
+            model=model,
+            model_metadata=model_metadata,
+            checkpoint=checkpoint,
+            prefix=prefix,
+            size_per_shard=size_per_shard,
+            use_safetensors=use_safetensors,
+            use_async=use_async,
+            dist_id=dist_id,
+            pinned_state_dicts=self.pinned_state_dicts,
+        )
         self.async_writers.extend(async_writers)
         return
-
+
         model = model.unwrap()
 
         if os.path.isfile(checkpoint):
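
A side note on the hunk above: dist_id = self.tp_size * self.pp_rank + self.tp_rank flattens the (pipeline rank, tensor rank) pair into a unique shard index, so each writer targets a distinct file. A quick sanity check of the mapping (hypothetical sizes):

# Every (pp_rank, tp_rank) pair maps to a distinct id in [0, tp_size * pp_size).
tp_size, pp_size = 2, 3
ids = {tp_size * pp + tp for pp in range(pp_size) for tp in range(tp_size)}
assert ids == set(range(tp_size * pp_size))  # {0, 1, 2, 3, 4, 5}
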
@@ -396,9 +405,15 @@ def load_sharded_model(
         if is_pytorch_model_meta_dist_file(checkpoint_index_file):
             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
             checkpoint = checkpoint_index_file.parent
-            load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads)
+            load_dist_model(
+                model=model,
+                model_metadata=model_metadata,
+                checkpoint=checkpoint,
+                low_cpu_mem_mode=low_cpu_mem_mode,
+                num_threads=num_threads,
+            )
             return
-
+
         model_before_wrapping = model  # backup for model before wrapping
         model = model.unwrap()

@@ -794,11 +809,19 @@ def save_unsharded_model(
         if self.dp_rank != 0 and self.sp_rank != 0:
             return
         dist_id = self.tp_size * self.pp_rank + self.tp_rank
-        writer= save_dist_unshard_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts)
+        writer = save_dist_unshard_model(
+            model=model,
+            model_metadata=model_metadata,
+            checkpoint=checkpoint,
+            use_safetensors=use_safetensors,
+            use_async=use_async,
+            dist_id=dist_id,
+            pinned_state_dicts=self.pinned_state_dicts,
+        )
         if writer is not None:
             self.async_writers.append(writer)
         return
-
+
         model = model.unwrap()
         if self.dp_rank != 0:
             return
@@ -871,7 +894,13 @@ def load_unsharded_model(
         for filename in os.listdir(checkpoint):
             if is_pytorch_model_meta_dist_file(filename):
                 model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-                load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads)
+                load_dist_model(
+                    model=model,
+                    model_metadata=model_metadata,
+                    checkpoint=checkpoint,
+                    low_cpu_mem_mode=low_cpu_mem_mode,
+                    num_threads=num_threads,
+                )
                 return
 
         strict = False
@@ -1103,7 +1132,6 @@ def gather_from_sharded_optimizer_state(
                     dist.all_gather(gather_tensor, v, group=dp_group)
                     v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
 
-
                 # Then gather TP shards.
                 partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
                 if partition_dim is not None:
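
For context on the hunk above: ZeRO-style data parallelism keeps each optimizer state as a flat (and possibly padded) shard per DP rank, so after all_gather the shards are stacked, flattened, trimmed to param.numel(), and reshaped to the parameter's shape. A single-process sketch of that recovery step with simulated shards (hypothetical sizes):

import torch

# A 3x3 state (9 elements) padded to 10 so it splits evenly across 2 ranks.
param = torch.arange(9.0).reshape(3, 3)
flat = torch.cat([param.flatten(), torch.zeros(1)])  # pad for an even split
shards = list(flat.chunk(2))  # stand-ins for what each DP rank would hold

# Mirrors the gather path above: stack, flatten, trim padding, reshape.
v = torch.stack(shards).view(-1)[: param.numel()].reshape_as(param)
assert torch.equal(v, param)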

Diff for: tests/test_checkpoint_io/test_dist_checkpointio.py (+12 -2)
@@ -79,7 +79,12 @@ def _preprocess_data(data):
         model_ckpt_path_0 = f"{tempdir}/model_0"
 
         booster_0.save_model(
-            model_0, model_ckpt_path_0, shard=shard, gather_dtensor=True, size_per_shard=size_per_shard, use_async=use_async
+            model_0,
+            model_ckpt_path_0,
+            shard=shard,
+            gather_dtensor=True,
+            size_per_shard=size_per_shard,
+            use_async=use_async,
         )
         booster_0.checkpoint_io._sync_d2h()
         booster_0.checkpoint_io._sync_io()
@@ -96,7 +101,12 @@ def _preprocess_data(data):
 
         model_ckpt_path_1 = f"{tempdir}/model_1"
         booster_1.save_model(
-            model_1, model_ckpt_path_1, shard=shard, gather_dtensor=True, size_per_shard=size_per_shard, use_async=use_async
+            model_1,
+            model_ckpt_path_1,
+            shard=shard,
+            gather_dtensor=True,
+            size_per_shard=size_per_shard,
+            use_async=use_async,
         )
         booster_1.checkpoint_io._sync_d2h()
         booster_1.checkpoint_io._sync_io()
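
The _sync_d2h() and _sync_io() calls after each save matter because with use_async=True, save_model returns before the checkpoint is fully on disk. A hedged sketch of the pattern as a helper, built only from the calls visible in this test:

def save_and_flush(booster, model, ckpt_path, shard, size_per_shard, use_async):
    # Hypothetical helper; wraps the exact sequence the test performs.
    booster.save_model(
        model,
        ckpt_path,
        shard=shard,
        gather_dtensor=True,
        size_per_shard=size_per_shard,
        use_async=use_async,
    )
    booster.checkpoint_io._sync_d2h()  # wait for device-to-host copies
    booster.checkpoint_io._sync_io()  # wait for asynchronous file writes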
