# Copyright 2023 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import math
import os
from datetime import timedelta
from typing import Any, Dict, List

import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as dist_c10d

from backends.backend import Backend
from backends.module_store import *
from backends.utils import get_dtype_bytes

from .custom_ops import GPUGemmOp, GPUBatchGemmOp, GPUGroupGemmOp


logging.basicConfig(level=logging.INFO)
log = logging.getLogger("PerfEngine")


class BackendAMD(Backend):

    def get_device_count(self):
        return torch.cuda.device_count()

    def set_device(self, device_index):
        torch.cuda.set_device(device_index)

    def get_device(self):
        return torch.cuda.current_device()

    def all_gather_object(self, obj):
        gather_object_list = [None for _ in range(self.world_size)]
        dist.all_gather_object(
            object_list=gather_object_list,
            obj=obj,
            group=self.group
        )
        return gather_object_list

    def get_device_name(self):
        return torch.cuda.get_device_name(0)

    def get_backend_properties(self):
        self.memory_limit = int(
            torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
        )

        if self.vendor_path is not None and os.path.exists(self.vendor_path) and (self.vendor_path).endswith(".json"):
            with open(self.vendor_path, "r") as f:
                self.hw_info_dict = json.load(f)
                # if the vendor path does not exist, please set this param manually
                self.bandwidth_limit = self.hw_info_dict["内存参数"]["内存"]["内存带宽(GB/s)"]
        else:
            log.warning(
                "Vendor_path [{}] was not found or is not a full path to a json file, please check the path. Otherwise, please set the hardware info manually.".format(
                    self.vendor_path
                )
            )
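
    # Note: the vendor JSON above is expected to expose peak memory bandwidth
    # under Chinese keys: "内存参数" (memory parameters) -> "内存" (memory) ->
    # "内存带宽(GB/s)" (memory bandwidth, GB/s). When no valid JSON is given,
    # self.bandwidth_limit is simply left unset here.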

    # device/host ops
    def host2device(self):
        self.op = Host2DeviceOp(torch.device("cuda"))

    def device2host(self):
        self.op = Device2HostOp()

    # communication ops
    def allreduce(self):
        self.op = AllReduceOp(self.group)

    def allgather(self):
        self.op = AllGatherOp(self.group)

    def reducescatter(self):
        self.op = ReduceScatterOp(self.group)

    def alltoall(self):
        self.op = AllToAllOp(self.group)

    def broadcast(self):
        self.op = BroadcastOp(self.group)

    def p2p(self):
        self.op = P2POp(self.group, self.ranks, self.rank)
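
    # Note: all of the communication ops above run over self.group / self.ranks,
    # which are populated by setup_2d_group() during initialize_ccl().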

    # compute ops
    # unary ops
    def sin(self):
        self.op = SinOp()

    def cos(self):
        self.op = CosOp()

    def exp(self):
        self.op = ExpOp()

    def exponential(self):
        self.op = ExponentialOp()

    def silu(self):
        self.op = SiluOp()

    def gelu(self):
        self.op = GeluOp()

    def swiglu(self):
        self.op = SwiGLUOp()

    def cast(self):
        self.op = CastOp()

    # binary ops
    def add(self):
        self.op = AddOp()

    def mul(self):
        self.op = MulOp()

    def sub(self):
        self.op = SubOp()

    def div(self):
        self.op = DivOp()

    # reduce ops
    def layernorm(self):
        self.op = LayerNormOp()

    def softmax(self):
        self.op = SoftmaxOp()

    def reduce_sum(self):
        self.op = ReduceSumOp()

    def reduce_min(self):
        self.op = ReduceMinOp()

    def reduce_max(self):
        self.op = ReduceMaxOp()

    # index ops
    def index_add(self):
        self.op = IndexAddOp()

    def sort(self):
        self.op = SortOp()

    def unique(self):
        self.op = UniqueOp()

    def scatter(self):
        self.op = ScatterOp()

    def gather(self):
        self.op = GatherOp()

    # gemm ops
    def gemm(self):
        self.op = GPUGemmOp()

    def gemv(self):
        self.op = GPUGemmOp()

    def batch_gemm(self):
        self.op = GPUBatchGemmOp()

    def group_gemm(self):
        self.op = GPUGroupGemmOp()
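
    # Note: gemv() reuses GPUGemmOp; a GEMV is run as a GEMM whose second
    # operand is a vector, with the actual shapes taken from the workload's
    # input_shapes.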

    # create input tensors
    def build_tensor(self, input_shapes, dtype):
        torch.cuda.empty_cache()
        torch_dtype = getattr(torch, dtype)

        # compute size of input and output tensors
        if hasattr(self.op, "compute_size"):
            bytes_per_cnt = self.op.compute_size(input_shapes, dtype)
        # default: input_tensors_size == output_tensor_size, all tensors have same dtype
        else:
            dtype_size = get_dtype_bytes(dtype)
            element_num = 2 * sum([math.prod(shape) for shape in input_shapes])
            bytes_per_cnt = dtype_size * element_num

        # compute how many tensor sets fit in device memory (keep ~4 GiB headroom)
        avail_bytes = (self.memory_limit - 4) * 1024 ** 3
        avail_cnts = avail_bytes // bytes_per_cnt
        max_data_cnt = min(self.iterations, avail_cnts)

        # create input tensors for each op
        input_tensors_list = []
        for _ in range(max_data_cnt):
            # create input tensors
            if hasattr(self.op, "custom_create_tensors"):
                input_tensors = self.op.custom_create_tensors(input_shapes, torch_dtype, "cuda")
                input_tensors_list.append(input_tensors)
            # default: all input tensors have same dtype
            else:
                if torch_dtype in [torch.int8, torch.int32]:
                    input_tensors = [
                        torch.randint(-3, 3, size=shape, dtype=torch_dtype, device="cuda")
                        for shape in input_shapes
                    ]
                else:
                    input_tensors = [
                        torch.randn(shape, dtype=torch_dtype, device="cuda")
                        for shape in input_shapes
                    ]
                input_tensors_list.append(input_tensors)
        if hasattr(self.op, "process_inputs"):
            input_tensors_list = [
                self.op.process_inputs(*(input_tensor))
                for input_tensor in input_tensors_list
            ]
        return input_tensors_list, max_data_cnt, bytes_per_cnt
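
    # Rough sketch of how the values returned above are consumed (the timing
    # loop lives in the perf engine, not in this backend; the loop below is
    # illustrative only):
    #
    #   tensors, cnt, _ = backend.build_tensor(input_shapes, dtype)
    #   for i in range(iterations):
    #       backend._run_operation(backend.op, tensors[i % cnt])
    #   backend.device_synchronize()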

    def _run_operation(self, operation, inputs):
        result = operation(*inputs)
        return result

    def device_synchronize(self):
        torch.cuda.synchronize()
        return True

    def initialize_ccl(self, rank, world_size):
        """
        Initialize the distributed process group and related environment variables.
        """
        # clamp world_size to the number of visible devices
        device_count = torch.cuda.device_count()
        if world_size > device_count:
            world_size = device_count
        if rank >= world_size:
            return False

        # set envs
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "49373"
        os.environ["LOCAL_RANK"] = str(rank)
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)

        torch.cuda.set_device(rank)

        # init the process group; on ROCm builds of PyTorch the "nccl" backend
        # is backed by RCCL, so the backend name stays "nccl" on AMD GPUs
        timeout_seconds = int(os.environ.get("MEGATRON_NCCL_TIMEOUT_SECOND", 30))
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=world_size,
            rank=rank,
            store=None,
            timeout=timedelta(seconds=timeout_seconds),
        )
        self.setup_2d_group()
        log.warning("DIST: rank {}, world_size {}".format(rank, world_size))
        return True

    def setup_2d_group(self):
        self.rank = dist.get_rank()
        torch.cuda.set_device(self.rank)
        # temporarily disable the c10d store-based barrier so that new_group()
        # does not block while the sub-group is being created, then restore it
        origin_store_based_barrier = dist_c10d._store_based_barrier
        dist_c10d._store_based_barrier = lambda *a, **kw: None
        self.world_size = dist.get_world_size()
        self.ranks = range(0, self.world_size)
        group = dist.new_group(self.ranks)
        if self.rank in self.ranks:
            self.group = group
        dist_c10d._store_based_barrier = origin_store_based_barrier
        # wait for all ranks to finish group initialization
        torch.distributed.barrier()

    def destroy_process_group(self):
        dist.destroy_process_group()