
Commit 70db7e7

Open source 0725 patch (#42)
1 parent aa95bb7 commit 70db7e7

29 files changed: +74 -1620 lines

README.md

+7 -1

@@ -20,6 +20,8 @@ _**An Industrial-Level Framework for Easy-of-Use**_
 
 ## Latest News
 
+- [2024-7-25] veScale's [pipeline parallelism](https://github.com/volcengine/veScale/blob/main/vescale/pipe/README.md) open sourced with API, graph parser, stage abstraction, schedules and execution runtime along with [nD distributed timeline](https://github.com/volcengine/veScale/blob/main/vescale/ndtimeline/README.md).
+
 - [2024-5-31] veScale's [fast checkpointing system](https://github.com/volcengine/veScale/blob/main/vescale/checkpoint/README.md) open sourced with automatic checkpoint resharding, caching, load-balancing, fast copying, deduplicating, and asynchronous io.
 
 - [2024-5-21] veScale's examples ([Mixtral](https://github.com/volcengine/veScale/tree/main/examples/mixtral_4D_training), [LLama2](https://github.com/volcengine/veScale/tree/main/examples/llama2_4D_finetune), and [nanoGPT](https://github.com/volcengine/veScale/tree/main/examples/nanogpt_4D_finetune)) open sourced with bit-wise correctness of training loss curves.
@@ -32,7 +34,11 @@ _**An Industrial-Level Framework for Easy-of-Use**_
 
 _**veScale**_ is still in its early phase. We are refactoring our internal LLM training system components to meet open source standard. The tentative timeline is as follows:
 
-- by end of July, CUDA event monitor, pipeline parallelism and supporting components for large-scale training
+- High-level [nD parallel api](https://github.com/volcengine/veScale/issues/39) for extreme ease of use
+
+- Power-user plan api for easy customization of nD parallel training
+
+- End-to-end vescale/examples with 5D parallel training (TP, SP, DP, ZeRO, PP)
 
 ## Table of Content ([web view](https://volcengine.github.io/veScaleWeb/))

examples/open_llama_4D_benchmark/download_open_llama_ckpt.py

+1 -1

@@ -1,6 +1,6 @@
 ################################################################################
 #
-# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

examples/open_llama_4D_benchmark/llama_mfu_calculator.py

+1 -1

@@ -1,6 +1,6 @@
 ################################################################################
 #
-# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

examples/open_llama_4D_benchmark/run_open_llama_w_vescale.py

+1 -1

@@ -1,6 +1,6 @@
 ################################################################################
 #
-# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

examples/open_llama_4D_benchmark/sharding_plan.py

+1 -1

@@ -1,6 +1,6 @@
 ################################################################################
 #
-# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

test/checkpoint/open_llama/test_open_llama_dp_reshard.py

+1 -1

@@ -1,6 +1,6 @@
 ################################################################################
 #
-# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

test/checkpoint/open_llama/test_open_llama_load_save.py

+1 -1

@@ -1,6 +1,6 @@
 ################################################################################
 #
-# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

test/checkpoint/open_llama/test_open_llama_tp_reshard.py

+1 -1

@@ -1,6 +1,6 @@
 ################################################################################
 #
-# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

test/parallel/pipeline/api/test_pipe_single_stage_ops.py

+12 -14

@@ -21,7 +21,7 @@
 from torch.testing._internal.common_utils import run_tests
 from vescale.devicemesh_api import VESCALE_DEVICE_MESH
 from vescale.plan import PipelineScheduleType, PipelineParallelPlan, ModeType, PipelineSplitMethodType
-from vescale.pipe.pipe_stage import PipeModule, construct_stage_modules
+from vescale.pipe.pipe_stage import construct_pipeline_stage
 from vescale.engine import PipeEngine
 from common_dtensor import DTensorTestBase, with_comms
 from torch.optim import SGD
@@ -132,9 +132,7 @@ def test_stage_forward(self):
     def _run_no_pp_model(self):
         os.environ["model_name"] = "golden"
         model = EightMLP().to("cuda:0")
-        optimizer = torch.optim.SGD(
-            model.parameters(), lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False
-        )
+        optimizer = SGD(model.parameters(), lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False)
         torch.manual_seed(9999)
         batch = [torch.ones(microbatch_size, 128, 32, dtype=torch.float32).to("cuda:0") for _ in range(factor)]
         for mb in batch:
@@ -166,13 +164,6 @@ def _run_stage_forward(self):
             mesh_dim_names=("PP", "DP", "TP"),
         )
 
-        stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
-            model,
-            config,
-            VESCALE_DEVICE_MESH,
-            update_split_points=True,
-        )
-
         optimizer_fn_kwargs = {
             "lr": 0.01,
             "momentum": 0,
@@ -183,9 +174,16 @@
             "foreach": None,
             "differentiable": False,
         }
-        _parameters = list(stage_modules[0].parameters()) + list(stage_modules[1].parameters())
-        optimizer = SGD(_parameters, **optimizer_fn_kwargs)
-        pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, config)
+
+        pipe_module = construct_pipeline_stage(
+            model,
+            config,
+            VESCALE_DEVICE_MESH,
+            lr_scheduler=None,
+            update_split_points=True,
+        )
+        optimizer = SGD(pipe_module.parameters(), **optimizer_fn_kwargs)
+        pipe_module.doptimizer = optimizer
 
         engine = PipeEngine(
             pipe_module,
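Taken together, these hunks replace the old two-step construction (construct_stage_modules followed by a hand-built PipeModule) with a single construct_pipeline_stage call; the optimizer is then built from the stage's own parameters and attached afterwards. A minimal sketch of that new pattern, assuming a model and a PipelineParallelPlan prepared the same way as elsewhere in this test (build_model and build_plan are placeholders, not veScale APIs):

    from torch.optim import SGD
    from vescale.devicemesh_api import VESCALE_DEVICE_MESH
    from vescale.pipe.pipe_stage import construct_pipeline_stage

    model = build_model()   # placeholder: e.g. the EightMLP used in this test
    config = build_plan()   # placeholder: a PipelineParallelPlan configured as above

    # One call now replaces the explicit stage_modules / stage_dependency /
    # p2p_index_mapping plumbing and returns the assembled pipeline stage.
    pipe_module = construct_pipeline_stage(
        model,
        config,
        VESCALE_DEVICE_MESH,
        lr_scheduler=None,
        update_split_points=True,
    )

    # The optimizer is created over the stage's parameters and attached afterwards,
    # exactly as the updated tests do.
    optimizer = SGD(pipe_module.parameters(), lr=0.01, momentum=0)
    pipe_module.doptimizer = optimizer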

test/parallel/pipeline/api/test_schedule_engine.py

+10 -11

@@ -20,7 +20,7 @@
 from common_dtensor import DTensorTestBase, with_comms
 from torch.testing._internal.common_utils import run_tests
 from vescale.devicemesh_api import VESCALE_DEVICE_MESH
-from vescale.pipe.pipe_stage import PipeModule, construct_stage_modules
+from vescale.pipe.pipe_stage import construct_pipeline_stage
 from vescale.pipe._schedules.instruction_base import StageDeps
 from vescale.pipe.pipe_emmiter import ScheduleEngine
 from vescale.plan.spec import PipelineScheduleType, ModeType, PipelineSplitMethodType
@@ -79,13 +79,6 @@ def test_simple_1f1b(self):
             schedule_type=PipelineScheduleType.SIMPLE_1F1B,
         )
 
-        stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
-            model,
-            config,
-            VESCALE_DEVICE_MESH,
-            update_split_points=True,
-        )
-
         optimizer_fn_kwargs = {
             "lr": 0.01,
             "momentum": 0,
@@ -96,9 +89,15 @@
             "foreach": None,
             "differentiable": False,
         }
-        _parameters = list(stage_modules[0].parameters())
-        optimizer = SGD(_parameters, **optimizer_fn_kwargs)
-        pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, config)
+        pipe_module = construct_pipeline_stage(
+            model,
+            config,
+            VESCALE_DEVICE_MESH,
+            lr_scheduler=None,
+            update_split_points=True,
+        )
+        optimizer = SGD(pipe_module.parameters(), **optimizer_fn_kwargs)
+        pipe_module.doptimizer = optimizer
 
         dep = pipe_module.stage_deps
         device_mesh_list = VESCALE_DEVICE_MESH.get_global_tensor_parallel_meshes()

test/parallel/pipeline/api/test_simple_api.py

-1

@@ -20,7 +20,6 @@
 import torch
 import torch.nn as nn
 from torch.testing._internal.common_utils import run_tests
-from vescale.debug.pdb import ForkedPdb
 from vescale.optim.base_optimizer import BasicOptimizer
 from vescale.pipe.pipe_stage import construct_pipeline_stage
 from vescale.devicemesh_api import VESCALE_DEVICE_MESH

test/parallel/pipeline/e2e/test_pp_accuracy_alignment.py

+1 -1

@@ -225,7 +225,7 @@ def _run_engine_with_1f1b(self, fixed_size=True):
             pipe_config,
         )
 
-        engine.forward_backward(batch)
+        engine(batch)
         optimizer = engine.get_optimizer
         optimizer.step()
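The engine is now invoked as a callable instead of through forward_backward. A small sketch of the calling convention this test exercises (batch construction is omitted, and the diff does not show the callable's return values, so only the call shape is taken from the change):

    engine(batch)                      # was: engine.forward_backward(batch)
    optimizer = engine.get_optimizer   # property access, as in the test
    optimizer.step()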

vescale/engine/pipe.py

+12 -12

@@ -36,7 +36,7 @@ def __init__(
         module: PipeModule,
         global_mesh: VeDeviceMesh,
         loss_fn: Callable,
-        config: PipelineParallelPlan,
+        plan: PipelineParallelPlan,
    ):
        """
        Training engine for pipeline parallelism and multi-dimensional
@@ -46,8 +46,8 @@ def __init__(
        training, and optimizer synchronization.
        """
        self.module = module
-        self.virtual_chunks_per_stage = config.virtual_chunks
-        self.engine_config = config
+        self.virtual_chunks_per_stage = plan.virtual_chunks
+        self.engine_plan = plan
        self.optimizer = self.module.get_optimizer
        self.lr_scheduler = self.module.get_lr_scheduler
        self.global_mesh = global_mesh
@@ -59,16 +59,16 @@ def __init__(
        except: # noqa: E722
            self.loss_fn = loss_fn
        self.schedule_engine = None
-        self.reuse_comm_shape = self.engine_config.reuse_p2p_tensor_shape
+        self.reuse_comm_shape = self.engine_plan.reuse_p2p_tensor_shape
        if self.reuse_comm_shape:
            os.environ["REUSE_COMM_SHAPE"] = "1"
        if (
-            self.engine_config.schedule_type == PipelineScheduleType.INTERLEAVED_1F1B
+            self.engine_plan.schedule_type == PipelineScheduleType.INTERLEAVED_1F1B
            and self.virtual_chunks_per_stage == 1
        ):
            print("[warning]: #virtual pipeline chunks is 1. Falling back to simple 1F1B schedule.")
-            self.engine_config.schedule_type = PipelineScheduleType.SIMPLE_1F1B
-        self.schedule_type = self.engine_config.schedule_type
+            self.engine_plan.schedule_type = PipelineScheduleType.SIMPLE_1F1B
+        self.schedule_type = self.engine_plan.schedule_type
 
    def build_schedule(self, minibatches, data_shape=None):
        """
@@ -105,7 +105,7 @@ def _locate_tp_mesh(_rank):
        )
        num_minibatches = self._align_num_batches(first_stage_rank, len(minibatches))
        # TODO: insert shape inference
-        batch_p2p_comm = self.engine_config.batch_p2p_comm
+        batch_p2p_comm = self.engine_plan.batch_p2p_comm
        # if on interleaved 1f1b schedule, set batch_p2p_comm to False to execute p2p communication
        schedule_type = self.schedule_type
        if schedule_type in [PipelineScheduleType.INTERLEAVED_1F1B, PipelineScheduleType.ZERO_BUBBLE]:
@@ -123,16 +123,16 @@ def _locate_tp_mesh(_rank):
            data_iterator=data_iterator,
            stage_id=self.global_mesh.get_pipeline_parallel_rank(),
            shape=data_shape,
-            dtype=self.engine_config.p2p_tensor_dtype,
+            dtype=self.engine_plan.p2p_tensor_dtype,
            num_chunks=self.virtual_chunks_per_stage,
            input_shapes=None,
            input_shapes_unpad=None,
            # send_dtypes_map=self.module.recv_dtypes_dict,
-            overlap_p2p_comm=self.engine_config.overlap_p2p_comm,
+            overlap_p2p_comm=self.engine_plan.overlap_p2p_comm,
            batch_p2p_comm=batch_p2p_comm,
            loss_fn=self.loss_fn,
            global_mesh=self.global_mesh,
-            forward_only=self.engine_config.forward_only,
+            forward_only=self.engine_plan.forward_only,
        )
 
    def forward_backward(
@@ -211,7 +211,7 @@ def parameters(self, including_frozen=False):
    def sync_shared_params(self, group_id: int = 0, share_params=True) -> None:
        """
        Synchronize gradients and weights among groups of specified units, dictated by
-        "partition_units" in PipelineConfig. Typically, this function is used for
+        "partition_units" in PipelineParallelPlan. Typically, this function is used for
        synchronizing gradients and weights of embeddings layers in Transformer-based
        architecture.
        Args:

vescale/model/base_gpt/__init__.py

-5
This file was deleted.
