Skip to content

Commit aa95bb7

Browse files
authored
PP API and nD Distributed Timeline Profiling (#41)
1 parent c4afc72 commit aa95bb7

File tree

98 files changed

+18591
-23
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

98 files changed

+18591
-23
lines changed

docs/pictures/ndtimeline_arch.jpg

129 KB
Loading

docs/pictures/ndtimeline_trace.png

238 KB
Loading

docs/pictures/pp.png

82.5 KB
Loading

examples/open_llama_4D_benchmark/download_open_llama_ckpt.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
################################################################################
22
#
3-
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
3+
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
66
# You may obtain a copy of the License at

examples/open_llama_4D_benchmark/llama_mfu_calculator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
################################################################################
22
#
3-
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
3+
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
66
# You may obtain a copy of the License at

examples/open_llama_4D_benchmark/run_open_llama_w_vescale.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
################################################################################
22
#
3-
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
3+
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
66
# You may obtain a copy of the License at

examples/open_llama_4D_benchmark/sharding_plan.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
################################################################################
22
#
3-
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
3+
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
66
# You may obtain a copy of the License at

requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pytest
66
tqdm
77
optree
88
accelerate
9-
transformers==4.37.2
9+
transformers==4.40.2
1010
flash_attn
11+
matplotlib
1112
mmh3

test/checkpoint/nano_gpt/test_nano_gpt_load_save.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,7 @@ def init_method(self):
101101
@skip_unless_torch_gpu
102102
@with_comms
103103
def test_load(self):
104-
ddp_gpt, dist_optimizer, _ = build_gpt_model_optimizer_and_dataset(
105-
self.init_method, dp_size=2, tp_size=2
106-
)
104+
ddp_gpt, dist_optimizer, _ = build_gpt_model_optimizer_and_dataset(self.init_method, dp_size=2, tp_size=2)
107105

108106
# Load the model and optimizer after first data
109107

test/checkpoint/open_llama/test_open_llama_dp_reshard.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
################################################################################
22
#
3-
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
3+
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
66
# You may obtain a copy of the License at

test/checkpoint/open_llama/test_open_llama_load_save.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
################################################################################
22
#
3-
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
3+
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
66
# You may obtain a copy of the License at

test/checkpoint/open_llama/test_open_llama_tp_reshard.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
################################################################################
22
#
3-
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
3+
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
66
# You may obtain a copy of the License at

test/model/open_llama/test_attention.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ def test_attention(self):
5656
input.retain_grad()
5757
non_parallel_attention, _ = get_model()
5858
non_parallel_attention = non_parallel_attention.cuda()
59-
golden_outputs = non_parallel_attention(input)
59+
dummy_position_ids = torch.randint(low=0, high=s, size=(bsz, s)).cuda()
60+
golden_outputs = non_parallel_attention(input, position_ids=dummy_position_ids)
6061
golden_loss = golden_outputs[0].mean()
6162
golden_loss.backward()
6263

@@ -84,8 +85,9 @@ def test_attention(self):
8485
d_input = distribute_tensor(input.detach(), device_mesh, [Shard(1)])
8586
d_input.requires_grad_()
8687
d_input.retain_grad()
88+
d_position_id = distribute_tensor(dummy_position_ids.detach(), device_mesh, [Replicate()])
8789

88-
vescale_outputs = vescale_attention(d_input)
90+
vescale_outputs = vescale_attention(d_input, position_ids=d_position_id)
8991
vescale_outputs[0] = vescale_outputs[0].redistribute(placements=[Replicate()] * device_mesh.ndim)
9092
vescale_loss = vescale_outputs[0].mean()
9193

test/model/open_llama/test_decoder_layer.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ def test_decoder(self):
5656
input.retain_grad()
5757
non_parallel_decoder, _ = get_model()
5858
non_parallel_decoder = non_parallel_decoder.cuda()
59-
golden_outputs = non_parallel_decoder(input)
59+
dummy_position_id = torch.randint(low=0, high=s, size=(bsz, s)).cuda()
60+
golden_outputs = non_parallel_decoder(input, position_ids=dummy_position_id)
6061
golden_loss = golden_outputs[0].mean()
6162
golden_loss.backward()
6263

@@ -95,8 +96,9 @@ def test_decoder(self):
9596
d_input = distribute_tensor(input.detach(), device_mesh, [Shard(1)])
9697
d_input.requires_grad_()
9798
d_input.retain_grad()
99+
d_position_id = distribute_tensor(dummy_position_id.detach(), device_mesh, [Replicate()])
98100

99-
vescale_outputs = vescale_decoder(d_input)
101+
vescale_outputs = vescale_decoder(d_input, position_ids=d_position_id)
100102
vescale_outputs[0] = vescale_outputs[0].redistribute(placements=[Replicate()] * device_mesh.ndim)
101103
vescale_loss = vescale_outputs[0].mean()
102104

test/ndtimeline/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# make pylint happy
+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
################################################################################
2+
#
3+
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
################################################################################
17+
18+
import os
19+
from vescale.ndtimeline.world_info import WorldInfo
20+
from vescale.ndtimeline.handlers import LocalRawNDHandler
21+
from vescale.ndtimeline.variables import LOCAL_LOGGING_PATH
22+
23+
24+
def test_basic_usage():
    """Exercise LocalRawNDHandler and check the log files it leaves on disk.

    The handler is built with ``chunk_sz=10`` and ``backup_cnt=3``. The test
    expects the base log file to exist after the first record, a ``.2`` backup
    (but no ``.4`` backup) to exist after several more records, and finally
    removes the base file and the ``.1``-``.3`` backups it created.
    """
    h = LocalRawNDHandler(run_id=0, chunk_sz=10, backup_cnt=3)
    file_name = "timeline_run0_raw.log"
    log_path = os.path.join(LOCAL_LOGGING_PATH, file_name)

    # A single record is enough for the base log file to appear.
    h("test_metric", 1.0, [1.0], [1.0], [{}], range(0, 1), WorldInfo(0, 0), {})
    assert os.path.exists(log_path)

    for _ in range(4):
        h("test_metric", 1.0, [1.0], [1.0], [{}], range(0, 1), WorldInfo(0, 0), {})
    # NOTE(review): the nesting of this call could not be recovered from the
    # flattened source; it is placed after the loop -- confirm against the repo.
    h("test_metric2", 2.0, [1.0], [1.0], [{}], range(0, 1), WorldInfo(0, 0), {})

    # Backups are expected up to index 3 (backup_cnt), never index 4.
    assert os.path.exists(log_path + ".2")
    assert not os.path.exists(log_path + ".4")

    # Clean up everything this test wrote.
    os.remove(log_path)
    for i in range(1, 4):
        os.remove(log_path + "." + str(i))
    assert not os.path.exists(log_path + ".2")

test/ndtimeline/test_metric_level.py

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
################################################################################
2+
#
3+
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
################################################################################
17+
18+
from vescale.ndtimeline import NDMetricLevel
19+
20+
21+
def test_cmp_level():
    """Check the rich-comparison operators defined between NDMetricLevel members."""
    # Debug-style levels compare greater than INFO.
    assert NDMetricLevel.FRAMEWORK_DEBUG >= NDMetricLevel.INFO
    assert NDMetricLevel.USER_DEBUG >= NDMetricLevel.INFO
    assert NDMetricLevel.USER_DEBUG > NDMetricLevel.INFO
    # USER_INFO compares strictly below INFO.
    assert NDMetricLevel.USER_INFO < NDMetricLevel.INFO
    assert NDMetricLevel.USER_INFO <= NDMetricLevel.INFO
    assert NDMetricLevel.INFO < NDMetricLevel.DEBUG
    # A level compares equal to itself under <=, >= and ==.
    assert NDMetricLevel.TRACE <= NDMetricLevel.TRACE
    assert NDMetricLevel.TRACE >= NDMetricLevel.TRACE
    assert NDMetricLevel.TRACE == NDMetricLevel.TRACE
+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
################################################################################
2+
#
3+
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
################################################################################
17+
18+
import pytest
19+
from vescale.ndtimeline.world_info import WorldInfo
20+
from vescale.ndtimeline.handlers import ParserNDHandler
21+
from vescale.ndtimeline.exceptions import NDHandlerError
22+
23+
24+
def test_normal_input_with_tags():
    """Feed ParserNDHandler three well-formed raw parts and check the result.

    The elapsed parts, the since-start timestamps and the tags all have three
    entries, and the step range covers a single step; the handler is expected
    to return exactly one record whose ``step`` is 0.
    """
    metric_name = "test_metric"
    recent_elapsed_raw_parts = [1.0, 3.2, 1.4]
    elapsed = sum(recent_elapsed_raw_parts)
    recent_since_start_raw_parts = [1710332816.6118143, 1710332833.2222, 1710332846.1313]
    single_tag = {"is_test": True}
    # One tag per raw part; only the last one differs.
    tags = [single_tag] * (len(recent_elapsed_raw_parts) - 1) + [{"is_test": False}]
    step_range = range(0, 1)
    world_info = WorldInfo(0, 0)
    callback = ParserNDHandler()
    records = callback(
        metric_name, elapsed, recent_elapsed_raw_parts, recent_since_start_raw_parts, tags, step_range, world_info, {}
    )
    assert len(records) == 1
    assert records[0].step == 0
39+
40+
41+
def test_normal_invalid_input():
    """Check that ParserNDHandler raises NDHandlerError on inconsistent input.

    Three elapsed parts (and three tags) are paired with only two since-start
    timestamps; presumably that length mismatch is what the handler rejects.
    """
    metric_name = "test_metric"
    recent_elapsed_raw_parts = [1.0, 3.2, 1.4]
    elapsed = sum(recent_elapsed_raw_parts)
    # Deliberately one entry short compared with recent_elapsed_raw_parts.
    recent_since_start_raw_parts = [1710332816.6118143, 1710332846.1313]
    single_tag = {"is_test": True}
    tags = [single_tag] * (len(recent_elapsed_raw_parts) - 1) + [{"is_test": False}]
    step_range = range(0, 1)
    world_info = WorldInfo(0, 0)
    callback = ParserNDHandler()
    with pytest.raises(NDHandlerError):
        callback(
            metric_name,
            elapsed,
            recent_elapsed_raw_parts,
            recent_since_start_raw_parts,
            tags,
            step_range,
            world_info,
            {},
        )
+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
################################################################################
2+
#
3+
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
################################################################################
17+
18+
import torch
19+
import torch.nn as nn
20+
import os
21+
22+
23+
class MLP(nn.Module):
    """Two bias-free linear layers with a GELU in between, using constant weights.

    Every forward pass also dumps its output tensor to disk, so this module is
    meant for debugging/test comparisons rather than for training.

    Args:
        features_in: accepted for signature compatibility; not used here.
        feature_middle: accepted for signature compatibility; not used here.
        features_out: accepted for signature compatibility; not used here.
        value: constant written into every ``fc1`` weight; ``fc2`` weights are
            filled with ``value * 2``. Also embedded in the dump file name.

    Note:
        Both layers are hard-coded to 1024 -> 1024 regardless of the three
        size arguments.
    """

    def __init__(self, features_in, feature_middle, features_out, value):
        super().__init__()
        self.value = value
        # Number of forward passes so far; used to name the dumped tensors.
        self.counter = 0
        self.fc1 = nn.Linear(1024, 1024, bias=False)
        self.fc1.weight.data.fill_(value)
        self.fc2 = nn.Linear(1024, 1024, bias=False)
        self.fc2.weight.data.fill_(value * 2)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Return ``fc2(gelu(fc1(x)))`` and save that result to a ``.pt`` file.

        The file name is built from the ``model_name`` environment variable,
        ``self.value`` and the running forward counter; a missing
        ``model_name`` variable raises ``KeyError``.
        """
        t = self.fc1(x)
        t = self.gelu(t)
        t = self.fc2(t)
        torch.save(t, f"{os.environ['model_name']}_mlp{self.value}_fwd{self.counter}_out_tensor.pt")
        self.counter += 1
        return t
41+
42+
43+
class FourMLP(nn.Module):
    """Four ``MLP`` blocks applied one after another.

    Args:
        hidden: base size; multiples of it are passed to each ``MLP`` as its
            size arguments. The blocks are tagged with the values 0, 1, 2
            and 3, which ``MLP`` uses as its ``value`` argument.
    """

    def __init__(self, hidden):
        super().__init__()
        self.mlp1 = MLP(hidden * 1, hidden * 2, hidden * 3, 0)
        self.mlp2 = MLP(hidden * 3, hidden * 4, hidden * 5, 1)
        self.mlp3 = MLP(hidden * 5, hidden * 6, hidden * 7, 2)
        self.mlp4 = MLP(hidden * 7, hidden * 8, hidden * 9, 3)
        # The same four modules are also exposed as one sequential pipeline.
        self.sequence = nn.Sequential(self.mlp1, self.mlp2, self.mlp3, self.mlp4)

    def forward(self, x):
        """Run ``x`` through mlp1 -> mlp2 -> mlp3 -> mlp4 and return the result."""
        return self.sequence(x)

0 commit comments

Comments
 (0)