Commit 572c388
Merge pull request #96 from bytedance/add_amd_backend
support AMD backend for micro_perf.
2 parents 3b3ea9d + 76cfd5f commit 572c388

File tree

6 files changed: +380 -1 lines changed

Diff for: byte_micro_perf/backends/AMD/backend_amd.py

+299
@@ -0,0 +1,299 @@
# Copyright 2023 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import math
import os
from datetime import timedelta
from typing import Any, Dict, List

import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as dist_c10d

from backends.backend import Backend
from backends.module_store import *
from backends.utils import get_dtype_bytes

from .custom_ops import GPUGemmOp, GPUBatchGemmOp, GPUGroupGemmOp


logging.basicConfig(level=logging.INFO)
log = logging.getLogger("PerfEngine")


class BackendAMD(Backend):

    def get_device_count(self):
        return torch.cuda.device_count()

    def set_device(self, device_index):
        torch.cuda.set_device(device_index)

    def get_device(self):
        return torch.cuda.current_device()

    def all_gather_object(self, obj):
        gather_object_list = [None for _ in range(self.world_size)]
        dist.all_gather_object(
            object_list=gather_object_list,
            obj=obj,
            group=self.group
        )
        return gather_object_list


    def get_device_name(self):
        return torch.cuda.get_device_name(0)

    def get_backend_properties(self):
        self.memory_limit = int(
            torch.cuda.get_device_properties(0).total_memory / (1024**3)
        )

        if self.vendor_path is not None and os.path.exists(self.vendor_path) and (self.vendor_path).endswith(".json"):
            with open(self.vendor_path, "r") as f:
                self.hw_info_dict = json.load(f)
            # if the vendor path does not exist, set this param manually
            # keys come from the vendor spec json ("memory params" -> "memory" -> "memory bandwidth (GB/s)")
            self.bandwidth_limit = self.hw_info_dict["内存参数"]["内存"]["内存带宽(GB/s)"]
        else:
            log.warning(
                "Vendor_path: [ {} ] was not found or is not a full path to a json file, please check your path. Otherwise, please set the hardware info manually.".format(
                    self.vendor_path
                )
            )


    # device/host ops
    def host2device(self):
        self.op = Host2DeviceOp(torch.device("cuda"))

    def device2host(self):
        self.op = Device2HostOp()


    # communication ops
    def allreduce(self):
        self.op = AllReduceOp(self.group)

    def allgather(self):
        self.op = AllGatherOp(self.group)

    def reducescatter(self):
        self.op = ReduceScatterOp(self.group)

    def alltoall(self):
        self.op = AllToAllOp(self.group)

    def broadcast(self):
        self.op = BroadcastOp(self.group)

    def p2p(self):
        self.op = P2POp(self.group, self.ranks, self.rank)

    # compute ops
    # unary ops
    def sin(self):
        self.op = SinOp()

    def cos(self):
        self.op = CosOp()

    def exp(self):
        self.op = ExpOp()

    def exponential(self):
        self.op = ExponentialOp()

    def silu(self):
        self.op = SiluOp()

    def gelu(self):
        self.op = GeluOp()

    def swiglu(self):
        self.op = SwiGLUOp()

    def cast(self):
        self.op = CastOp()


    # binary ops
    def add(self):
        self.op = AddOp()

    def mul(self):
        self.op = MulOp()

    def sub(self):
        self.op = SubOp()

    def div(self):
        self.op = DivOp()


    # reduce ops
    def layernorm(self):
        self.op = LayerNormOp()

    def softmax(self):
        self.op = SoftmaxOp()

    def reduce_sum(self):
        self.op = ReduceSumOp()

    def reduce_min(self):
        self.op = ReduceMinOp()

    def reduce_max(self):
        self.op = ReduceMaxOp()


    # index ops
    def index_add(self):
        self.op = IndexAddOp()

    def sort(self):
        self.op = SortOp()

    def unique(self):
        self.op = UniqueOp()

    def scatter(self):
        self.op = ScatterOp()

    def gather(self):
        self.op = GatherOp()

    # gemm ops
    def gemm(self):
        self.op = GPUGemmOp()

    def gemv(self):
        self.op = GPUGemmOp()

    def batch_gemm(self):
        self.op = GPUBatchGemmOp()

    def group_gemm(self):
        self.op = GPUGroupGemmOp()


    # create input tensors
    def build_tensor(self, input_shapes, dtype):
        torch.cuda.empty_cache()
        torch_dtype = getattr(torch, dtype)

        # compute size of input and output tensors
        if hasattr(self.op, "compute_size"):
            bytes_per_cnt = self.op.compute_size(input_shapes, dtype)
        # default: input_tensors_size == output_tensor_size, all tensors have same dtype
        else:
            dtype_size = get_dtype_bytes(dtype)
            element_num = 2 * sum([math.prod(shape) for shape in input_shapes])
            bytes_per_cnt = dtype_size * element_num

        # compute max avail tensors for compute
        avail_bytes = (self.memory_limit - 4) * 1024**3
        avail_cnts = avail_bytes // bytes_per_cnt
        max_data_cnt = min(self.iterations, avail_cnts)

        # create input tensors for each op
        input_tensors_list = []
        for _ in range(max_data_cnt):
            # create input tensors
            if hasattr(self.op, "custom_create_tensors"):
                input_tensors = self.op.custom_create_tensors(input_shapes, torch_dtype, "cuda")
                input_tensors_list.append(input_tensors)
            # default: all input tensors have same dtype
            else:
                if torch_dtype in [torch.int8, torch.int32]:
                    input_tensors = [
                        torch.randint(-3, 3, size=shape, dtype=torch_dtype, device="cuda")
                        for shape in input_shapes
                    ]
                else:
                    input_tensors = [
                        torch.randn(shape, dtype=torch_dtype, device="cuda")
                        for shape in input_shapes
                    ]
                input_tensors_list.append(input_tensors)
        if hasattr(self.op, "process_inputs"):
            input_tensors_list = [
                self.op.process_inputs(*(input_tensor))
                for input_tensor in input_tensors_list
            ]
        return input_tensors_list, max_data_cnt, bytes_per_cnt


    def _run_operation(self, operation, inputs):
        result = operation(*inputs)
        return result

    def device_synchronize(self):
        torch.cuda.synchronize()
        return True

    def initialize_ccl(self, rank, world_size):
        """
        initialize distributed process groups and relevant ENVs
        """
        # check device_count
        device_count = torch.cuda.device_count()
        if world_size > device_count:
            world_size = device_count
        if rank >= world_size:
            return False

        # set envs
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "49373"
        os.environ["LOCAL_RANK"] = str(rank)
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)

        torch.cuda.set_device(rank)

        # Call the init process
        timeout_seconds = int(os.environ.get("MEGATRON_NCCL_TIMEOUT_SECOND", 30))
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=world_size,
            rank=rank,
            store=None,
            timeout=timedelta(seconds=timeout_seconds),
        )
        self.setup_2d_group()
        log.warning("DIST: rank {}, world_size {}".format(rank, world_size))
        return True

    def setup_2d_group(self):
        self.rank = dist.get_rank()
        torch.cuda.set_device(self.rank)
        origin_store_based_barrier = dist_c10d._store_based_barrier
        dist_c10d._store_based_barrier = lambda *a, **kw: None
        self.world_size = dist.get_world_size()
        self.ranks = range(0, self.world_size)
        group = dist.new_group(self.ranks)
        if self.rank in self.ranks:
            self.group = group
        dist_c10d._store_based_barrier = origin_store_based_barrier
        # wait for all ranks finish group initializing
        torch.distributed.barrier()

    def destroy_process_group(self):
        dist.destroy_process_group()
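
Note that the AMD backend drives the device entirely through the torch.cuda namespace; on a ROCm build of PyTorch these calls are dispatched to HIP, which is why the CUDA-named APIs work unchanged on AMD GPUs. Below is a minimal, illustrative sketch of the synchronize-run-synchronize timing pattern the backend exposes via device_synchronize() and _run_operation(); the shapes, iteration count, and driver loop are assumptions for the example, not the repo's PerfEngine.

import time
import torch

# Illustrative only: times one GEMM with the same synchronize-then-run pattern
# BackendAMD uses. Shapes and iteration count are made up for the example.
device = torch.device("cuda")            # HIP device on a ROCm build
a = torch.randn(4096, 4096, dtype=torch.float16, device=device)
b = torch.randn(4096, 4096, dtype=torch.float16, device=device)

torch.cuda.synchronize()                 # what device_synchronize() wraps
start = time.perf_counter_ns()
for _ in range(10):
    c = torch.mm(a, b)                   # what GPUGemmOp.forward does for fp16
torch.cuda.synchronize()
elapsed_us = (time.perf_counter_ns() - start) / 10 / 1e3
print(f"avg gemm latency: {elapsed_us:.1f} us")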

Diff for: byte_micro_perf/backends/AMD/custom_ops.py

+63
@@ -0,0 +1,63 @@
from typing import List

import torch

from backends.module_store import GemmOp, BatchGemmOp, GroupGemmOp


# gemm(pytorch) float32/float16/bfloat16 --> float32/float16/bfloat16
# gemm(cutlass) int8 --> int32
class GPUGemmOp(GemmOp):
    def __init__(self):
        super().__init__()

    def forward(
        self,
        input_tensor_a : torch.Tensor,
        input_tensor_b : torch.Tensor
    ):
        compute_dtype = input_tensor_a.dtype
        if compute_dtype == torch.int8:
            # no int8 GEMM path here; the input is passed through unchanged
            output_tensor = input_tensor_a
        else:
            output_tensor = torch.mm(input_tensor_a, input_tensor_b)
        return output_tensor


# batch_gemm(pytorch) float32/float16/bfloat16 --> float32/float16/bfloat16
# batch_gemm(cutlass) int8 --> int32
class GPUBatchGemmOp(BatchGemmOp):
    def __init__(self):
        super().__init__()

    def forward(
        self,
        input_tensor_a : torch.Tensor,
        input_tensor_b : torch.Tensor
    ):
        compute_dtype = input_tensor_a.dtype

        output_tensor = None
        if compute_dtype == torch.int8:
            output_tensor = input_tensor_a
        else:
            output_tensor = torch.bmm(input_tensor_a, input_tensor_b)
        return output_tensor


# group_gemm(pytorch) float32/float16/bfloat16 --> float32/float16/bfloat16
# group_gemm(cutlass) int8 --> int32
class GPUGroupGemmOp(GroupGemmOp):
    def __init__(self):
        super().__init__()

    def forward(self,
        a_list : List[torch.Tensor],
        b_list : List[torch.Tensor]
    ):
        compute_dtype = a_list[0].dtype
        if compute_dtype == torch.int8:
            output_tensors = a_list
        else:
            output_tensors = [a @ b for a, b in zip(a_list, b_list)]
        return output_tensors
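
For floating-point dtypes these ops reduce to plain torch matmuls (torch.mm, torch.bmm, and a per-pair a @ b for the grouped case), while the int8 branch currently passes the input tensor(s) through unchanged. A small sketch of the grouped-GEMM semantics implemented by GPUGroupGemmOp.forward, reproduced with bare tensors since the op constructors come from module_store; the shapes are made up for illustration.

import torch

# Mimics GPUGroupGemmOp.forward for a float dtype: one independent matmul per group.
a_list = [torch.randn(8, 16), torch.randn(4, 32)]
b_list = [torch.randn(16, 8), torch.randn(32, 4)]
outputs = [a @ b for a, b in zip(a_list, b_list)]
print([t.shape for t in outputs])  # [torch.Size([8, 8]), torch.Size([4, 4])]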

Diff for: byte_micro_perf/backends/AMD/requirements.txt

+2
@@ -0,0 +1,2 @@
-i https://download.pytorch.org/whl/rocm6.1
torch
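
The -i line points pip at the ROCm 6.1 wheel index, so torch resolves to a ROCm build rather than the default CUDA wheel. A quick sanity check after installation (a sketch; on ROCm wheels torch.version.hip is populated, on CUDA wheels it is None):

import torch

# ROCm builds report a HIP runtime version here; CUDA builds report None.
print("hip runtime:", torch.version.hip)
print("device available:", torch.cuda.is_available())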

Diff for: byte_micro_perf/backends/GPU/backend_gpu.py

+3
@@ -302,3 +302,6 @@ def setup_2d_group(self):
 
         # wait for all ranks finish group initializing
         torch.distributed.barrier()
+
+    def destroy_process_group(self):
+        dist.destroy_process_group()

Diff for: byte_micro_perf/backends/backend.py

+5
@@ -88,6 +88,11 @@ def initialize_ccl(self, rank, world_size):
     def setup_2d_group(self):
         pass
 
+    @abstractmethod
+    def destroy_process_group(self):
+        pass
+
+
 
     # communication ops
     def host2device(self):
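
With destroy_process_group added as an abstract method, every vendor backend is now expected to implement it (the AMD and GPU backends above simply call dist.destroy_process_group()). A hedged sketch of what a subclass override looks like; the class name is hypothetical and the is_initialized() guard is an extra safety check, not something the repo's backends do.

import torch.distributed as dist

from backends.backend import Backend


class MyVendorBackend(Backend):  # hypothetical subclass, for illustration only
    def destroy_process_group(self):
        # tear down the default process group created in initialize_ccl
        if dist.is_initialized():
            dist.destroy_process_group()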
