17
17
18
18
namespace fbgemm_gpu {
19
19
20
+ template <typename T, typename ... Ts>
21
+ constexpr inline bool is_one_of_v = (std::is_same_v<T, Ts> || ...);
22
+
20
23
// //////////////////////////////////////////////////////////////////////////////
21
24
// Quantized Load and Store
22
25
// //////////////////////////////////////////////////////////////////////////////
@@ -37,32 +40,19 @@ DEVICE_INLINE void quantize_store(
37
40
template <typename dst_t , typename src_t >
38
41
DEVICE_INLINE Vec4T<dst_t > dequantize_load (
39
42
const src_t * value,
40
- const float2 /* unused */ ) {
41
- return Vec4T<dst_t >(value);
42
- }
43
-
44
- template <>
45
- DEVICE_INLINE Vec4T<float > dequantize_load (
46
- const uint8_t * value,
47
- const float2 qparams) {
48
- Vec4T<float > out;
49
- out.acc .x = value[0 ] * qparams.x + qparams.y ;
50
- out.acc .y = value[1 ] * qparams.x + qparams.y ;
51
- out.acc .z = value[2 ] * qparams.x + qparams.y ;
52
- out.acc .w = value[3 ] * qparams.x + qparams.y ;
53
- return out;
54
- }
43
+ [[maybe_unused]] const float2 qparams) {
44
+ if constexpr (
45
+ std::is_same_v<src_t , uint8_t > && is_one_of_v<dst_t , float , at::Half>) {
46
+ Vec4T<dst_t > out;
47
+ out.acc .x = value[0 ] * qparams.x + qparams.y ;
48
+ out.acc .y = value[1 ] * qparams.x + qparams.y ;
49
+ out.acc .z = value[2 ] * qparams.x + qparams.y ;
50
+ out.acc .w = value[3 ] * qparams.x + qparams.y ;
51
+ return out;
55
52
56
- template <>
57
- DEVICE_INLINE Vec4T<at::Half> dequantize_load (
58
- const uint8_t * value,
59
- const float2 qparams) {
60
- Vec4T<at::Half> out;
61
- out.acc .x = value[0 ] * qparams.x + qparams.y ;
62
- out.acc .y = value[1 ] * qparams.x + qparams.y ;
63
- out.acc .z = value[2 ] * qparams.x + qparams.y ;
64
- out.acc .w = value[3 ] * qparams.x + qparams.y ;
65
- return out;
53
+ } else {
54
+ return Vec4T<dst_t >(value);
55
+ }
66
56
}
67
57
68
58
template <typename emb_t >
@@ -74,12 +64,6 @@ DEVICE_INLINE float2 load_qparams_from_row(emb_t* qparam_ptr) {
74
64
return qparams;
75
65
}
76
66
77
- template <typename emb_t >
78
- DEVICE_INLINE void store_qparams_to_row (emb_t * ptr, float2 qparams) {
79
- CUDA_KERNEL_ASSERT (false ); // Only int8 embeddding should call this
80
- }
81
-
82
- template <>
83
67
DEVICE_INLINE void store_qparams_to_row (uint8_t * ptr, float2 qparams) {
84
68
auto ptr_as_uint = reinterpret_cast <uintptr_t >(ptr);
85
69
if (ptr_as_uint % 8 == 0 ) {
@@ -112,12 +96,24 @@ DEVICE_INLINE void store_qparams_to_row(uint8_t* ptr, float2 qparams) {
112
96
113
97
// //////////////////////////////////////////////////////////////////////////////
114
98
// Weight Row
99
+ //
100
+ // This is a memory accessor around a row of dim_ number of embedding weights.
101
+ // It provides for loading and storing of 4 elements at a time (Vec4T<dst_t>)
102
+ // from and to the embedding table or cache. It also provides for quantization
103
+ // and de-quantization of the data. The cache row pointer is optional, and if
104
+ // not provided, then the embedding table is assumed to be the source of truth.
105
+ //
106
+ // Template parameters:
107
+ // emb_t : The type of the embedding table (e.g. uint8_t, float, at::Half)
108
+ // cache_t : The type of the cache
109
+ // dst_t : The type of the registers
115
110
// //////////////////////////////////////////////////////////////////////////////
116
111
117
112
template <typename emb_t , typename cache_t , typename dst_t >
118
113
// TODO: pass in dimension info and calculate qparams for rowwise integer
119
114
// quantization
120
- struct WeightRow {
115
+ class WeightRow {
116
+ public:
121
117
// Constructor for no stochastic rounding
122
118
DEVICE_INLINE WeightRow (emb_t * row, cache_t * cache_row, int dim)
123
119
: row_(row),
@@ -144,65 +140,54 @@ struct WeightRow {
144
140
}
145
141
}
146
142
147
- emb_t * row_;
148
- cache_t * cache_row_;
149
- int dim_;
150
- StochasticRoundingRNGState stoc_rounding_state_;
151
- StochasticRoundingRNGState* stoc_rounding_state_ptr_;
143
+ // ////////////////////////////////////////////////////////////////////////////
144
+ // Load 4 elements from the table row at element offset d into a register
145
+ // variable (Vec4T<dst_t>)
146
+ //
147
+ // If the cache row pointer is valid, then data will be read from the cache
148
+ // instead of embedding table.
149
+ // ////////////////////////////////////////////////////////////////////////////
152
150
153
- // Load from cache if resident; else load from embedding
154
151
DEVICE_INLINE Vec4T<dst_t > load (const int32_t d, const float2 qparams) const {
152
+ // Load from the cache if resident; else load from the embedding table.
153
+ //
154
+ // Note: This method assumes that dst_t is of higher precision than cache_t
155
+ // and emb_t
155
156
if (cache_row_) {
156
157
return dequantize_load<dst_t , cache_t >(cache_row_ + d, qparams);
157
158
} else {
158
159
return dequantize_load<dst_t , emb_t >(row_ + d, qparams);
159
160
}
160
161
}
161
162
162
- // Write back weight (high precision) to cache if resident; else write to
163
- // embedding assume dst_t is higher precision than cache_t and emb_t
163
+ // ////////////////////////////////////////////////////////////////////////////
164
+ // Store regster variable of 4 elements (Vec4T<dst_t>) back into the table
165
+ // into the table row at element offset d
166
+ //
167
+ // If the cache row pointer is valid, then data will be written to the cache
168
+ // instead of embedding table.
169
+ // ////////////////////////////////////////////////////////////////////////////
170
+
164
171
DEVICE_INLINE void
165
172
store (const Vec4T<dst_t >& v, const int32_t d, const float2 qparams) {
173
+ // Write back weight (high precision) to cache if resident; else write to
174
+ // embedding table.
175
+ //
176
+ // Note: This method assumes that dst_t is of higher precision than cache_t
177
+ // and emb_t
166
178
if (cache_row_) {
167
179
quantize_store (cache_row_ + d, v, stoc_rounding_state_ptr_, qparams);
168
180
} else {
169
181
quantize_store (row_ + d, v, stoc_rounding_state_ptr_, qparams);
170
182
}
171
183
}
172
184
173
- // Copy vector from src_vec to dst_vec (both are float)
174
- DEVICE_INLINE void same_type_vector_copy (
175
- float * dst_vec,
176
- const float * src_vec) {
177
- *reinterpret_cast <float4 *>(dst_vec) =
178
- *reinterpret_cast <const float4 *>(src_vec);
179
- }
180
-
181
- // Copy vector from src_vec to dst_vec (both are at::Half)
182
- DEVICE_INLINE void same_type_vector_copy (
183
- at::Half* dst_vec,
184
- const at::Half* src_vec) {
185
- *reinterpret_cast <float2 *>(dst_vec) =
186
- *reinterpret_cast <const float2 *>(src_vec);
187
- }
188
-
189
- // Evict cached row into embedding row (high prec -> low prec)
190
- DEVICE_INLINE void evict_cache (const int32_t d, const float2 qparams) {
191
- if constexpr (std::is_same_v<emb_t , cache_t >) {
192
- // No conversion required when emb_t and cache_t are the same type
193
- same_type_vector_copy (
194
- reinterpret_cast <cache_t *>(row_ + d),
195
- reinterpret_cast <const cache_t *>(cache_row_ + d));
196
- } else {
197
- // Does 2-step conversion: cache_t -> FP32 -> weight_t
198
- const auto cache_slice = load (d, qparams);
199
- quantize_store (row_ + d, cache_slice, stoc_rounding_state_ptr_, qparams);
200
- }
201
- }
202
-
203
- DEVICE_INLINE void store_qparams (const float2 qparams) {
204
- store_qparams_to_row (row_ + dim_, qparams);
205
- }
185
+ // ////////////////////////////////////////////////////////////////////////////
186
+ // Fetch the quantization parameters of the table row
187
+ //
188
+ // Qparams are fetched from the end of the row in the embedding table, not the
189
+ // cache.
190
+ // ////////////////////////////////////////////////////////////////////////////
206
191
207
192
DEVICE_INLINE float2 load_qparams () const {
208
193
if constexpr (std::is_same_v<emb_t , uint8_t >) {
@@ -212,32 +197,77 @@ struct WeightRow {
212
197
}
213
198
}
214
199
215
- DEVICE_INLINE void warp_copy_to_cache (
216
- cache_t * dst_row,
217
- const uint32_t dim_length,
200
+ // ////////////////////////////////////////////////////////////////////////////
201
+ // Update the quantization parameters of the table row
202
+ //
203
+ // Qparams are stored at the end of the row in the embedding table, not the
204
+ // cache.
205
+ // ////////////////////////////////////////////////////////////////////////////
206
+
207
+ template <typename T = emb_t >
208
+ DEVICE_INLINE auto store_qparams (const float2 qparams) const
209
+ -> std::enable_if_t<std::is_same_v<T, uint8_t>, void> {
210
+ store_qparams_to_row (row_ + dim_, qparams);
211
+ }
212
+
213
+ // ////////////////////////////////////////////////////////////////////////////
214
+ // Load the row from the embedding table into the cache
215
+ //
216
+ // De-quantization will be applied if the embedding table type is uint8_t (low
217
+ // prec -> high prec).
218
+ // ////////////////////////////////////////////////////////////////////////////
219
+
220
+ DEVICE_INLINE void warp_cache_load (
221
+ // cache_t* dst_row,
218
222
const uint32_t num_lanes,
219
223
const uint32_t lane_id) {
220
224
if constexpr (std::is_same_v<emb_t , cache_t >) {
221
- // No conversion required when emb_t and cache_t are the same type
222
- for (int32_t d = lane_id * 4 ; d < dim_length; d += num_lanes * 4 ) {
225
+ // If the embedding table and cache types are the same, then simply copy
226
+ // data from cache to embedding table.
227
+ for (auto d = lane_id * 4 ; d < dim_; d += num_lanes * 4 ) {
223
228
same_type_vector_copy (
224
- dst_row + d, reinterpret_cast <const cache_t *>(row_ + d));
229
+ cache_row_ + d, reinterpret_cast <const cache_t *>(row_ + d));
225
230
}
226
231
} else {
227
232
// Load quantization params from embedding row
228
233
const auto qparams = load_qparams ();
229
234
230
235
// Copy over for each warp-sized slice of Vec4's
231
236
// Does 2-step conversion: weight_t -> FP32 -> cache_t
232
- for (int32_t d = lane_id * 4 ; d < dim_length ; d += num_lanes * 4 ) {
237
+ for (auto d = lane_id * 4 ; d < dim_ ; d += num_lanes * 4 ) {
233
238
const auto slice = load (d, qparams);
234
- quantize_store (dst_row + d, slice, stoc_rounding_state_ptr_, qparams);
239
+ quantize_store (
240
+ cache_row_ + d, slice, stoc_rounding_state_ptr_, qparams);
235
241
}
236
242
}
237
243
}
238
244
239
- DEVICE_INLINE void warp_evict_cache (
240
- const uint32_t dim_length,
245
+ // ////////////////////////////////////////////////////////////////////////////
246
+ // Copy the row from the embedding table into the cache
247
+ // ////////////////////////////////////////////////////////////////////////////
248
+
249
+ DEVICE_INLINE void evict_cache (const uint32_t d, const float2 qparams) {
250
+ if constexpr (std::is_same_v<emb_t , cache_t >) {
251
+ // If the embedding table and cache types are the same, then simply copy
252
+ // data from cache to embedding table.
253
+ same_type_vector_copy (
254
+ reinterpret_cast <emb_t *>(row_ + d),
255
+ reinterpret_cast <const cache_t *>(cache_row_ + d));
256
+ } else {
257
+ // Else, do 2-step conversion: cache_t -> FP32 (register) -> weight_t
258
+ const auto cache_slice = load (d, qparams);
259
+ quantize_store (row_ + d, cache_slice, stoc_rounding_state_ptr_, qparams);
260
+ }
261
+ }
262
+
263
+ // ////////////////////////////////////////////////////////////////////////////
264
+ // Evict the row from the cache and into the embedding table.
265
+ //
266
+ // Quantization will be applied if the embedding table type is uint8_t (high
267
+ // prec -> low prec).
268
+ // ////////////////////////////////////////////////////////////////////////////
269
+
270
+ DEVICE_INLINE void warp_cache_evict (
241
271
const uint32_t num_lanes,
242
272
const uint32_t lane_id) {
243
273
float2 qparams;
@@ -248,7 +278,7 @@ struct WeightRow {
248
278
std::numeric_limits<at::acc_type<cache_t , true >>::lowest ();
249
279
250
280
// Compute the qparams from the cache row (not embedding row) weights
251
- for (auto d = lane_id; d * 4 < dim_length ; d += num_lanes) {
281
+ for (auto d = lane_id; d * 4 < dim_ ; d += num_lanes) {
252
282
const auto cache_slice = load (d * 4 , qparams); // qparams not used
253
283
local_max = max (local_max, cache_slice.vmax ());
254
284
local_min = min (local_min, cache_slice.vmin ());
@@ -263,41 +293,84 @@ struct WeightRow {
263
293
}
264
294
}
265
295
266
- for (auto d = lane_id * 4 ; d < dim_length ; d += num_lanes * 4 ) {
296
+ for (auto d = lane_id * 4 ; d < dim_ ; d += num_lanes * 4 ) {
267
297
// Evict the slice into the embedding row
268
298
evict_cache (d, qparams);
269
299
}
270
300
}
301
+
302
+ protected:
303
+ emb_t * const row_;
304
+ cache_t * const cache_row_;
305
+ int32_t const dim_;
306
+ StochasticRoundingRNGState stoc_rounding_state_;
307
+ StochasticRoundingRNGState* stoc_rounding_state_ptr_;
308
+
309
+ // ////////////////////////////////////////////////////////////////////////////
310
+ // Copy 4 elements (float or at::Half) from src_vec to dst_vec
311
+ //
312
+ // Reinterpret cast to float4* or float2* for mass copy
313
+ // ////////////////////////////////////////////////////////////////////////////
314
+
315
+ template <
316
+ typename T,
317
+ typename = std::enable_if_t <is_one_of_v<T, float , at::Half>>>
318
+ DEVICE_INLINE void same_type_vector_copy (T* dst_vec, const T* src_vec) {
319
+ // Copy vector from src_vec to dst_vec (both are float)
320
+ using ptr_t = std::conditional_t <std::is_same_v<T, float >, float4 , float2 >;
321
+ *reinterpret_cast <ptr_t *>(dst_vec) =
322
+ *reinterpret_cast <const ptr_t *>(src_vec);
323
+ }
271
324
};
272
325
273
326
// //////////////////////////////////////////////////////////////////////////////
274
327
// Weight Row Accessor
275
328
//
276
- // This is a basic memory accessor around a row of dim_ number of embedding
277
- // weights of type row_t, and provides for loading 4 elements at a time into
278
- // Vec4T<dst_t> with de-quantization support. Unlike WeightRow, this accessor
279
- // is for reading only, and does not take into account embedding vs cache table,
280
- // etc.
329
+ // This is a lightweight memory accessor around a row of dim_ number of
330
+ // embedding weights of type row_t (can be HBM or UVM), and provides for loading
331
+ // 4 elements at a time into Vec4T<dst_t> with de-quantization support. Unlike
332
+ // the heavyweight WeightRow class, this accessor is for reading values only,
333
+ // and does not handle embedding vs cache tables, etc.
334
+ //
335
+ // Template parameters:
336
+ // row_t : The type of the table row (e.g. uint8_t, float, at::Half)
337
+ // dst_t : The type of the registers
281
338
// //////////////////////////////////////////////////////////////////////////////
282
339
283
340
template <typename row_t , typename dst_t >
284
- struct WeightRowAccessor {
285
- const row_t * row_;
341
+ class WeightRowAccessor {
342
+ // The pointer to the row of weights in the table
343
+ const row_t * const row_;
344
+
345
+ // The number of elements per table row.
346
+ //
347
+ // This is NOT necessarily equivalent to the row stride D_emb, as there may be
348
+ // quantization parameters and optimizer states packed into the back of the
349
+ // row.
350
+ //
351
+ // dim_ is presumed to be a multiple of 4, since it loads data into Vec4T for
352
+ // max register occupancy.
286
353
const int32_t dim_;
287
- const float2 qparams_;
288
354
355
+ // [OPTIONAL] The quantization parameters for the row. If the row type is not
356
+ // uint8_t, i.e. not quantized, then it is set to (0.0f, 0.0f).
357
+ float2 qparams_ = make_float2(0 .0f , 0 .0f );
358
+
359
+ public:
289
360
DEVICE_INLINE
290
361
WeightRowAccessor (const row_t * const row, const int32_t dim)
291
- : row_(row), dim_(dim), qparams_(qparams()) {}
292
-
293
- DEVICE_INLINE auto qparams () const {
362
+ : row_(row), dim_(dim) {
294
363
if constexpr (std::is_same_v<row_t , uint8_t >) {
295
- return load_qparams_from_row<row_t >(row_ + dim_);
296
- } else {
297
- return make_float2 (0 .0f , 0 .0f );
364
+ qparams_ = qparams ();
298
365
}
299
366
}
300
367
368
+ template <typename T = row_t >
369
+ DEVICE_INLINE auto qparams () const
370
+ -> std::enable_if_t<std::is_same_v<T, uint8_t>, float2> {
371
+ return load_qparams_from_row<row_t >(row_ + dim_);
372
+ }
373
+
301
374
DEVICE_INLINE Vec4T<dst_t > load (const int32_t d) const {
302
375
return dequantize_load<dst_t , row_t >(row_ + d, qparams_);
303
376
}
0 commit comments