Skip to content

Commit 7994159

Browse files
q10facebook-github-bot
authored andcommitted
Clean up WeightRow in preparation for optimizer state offloading (#4021)
Summary: Pull Request resolved: #4021 X-link: facebookresearch/FBGEMM#1109 - Clean up `WeightRow` implementation in preparation for optimizer state offloading - Add documentation for the class Differential Revision: D73473546
1 parent 0f27e6d commit 7994159

File tree

4 files changed

+180
-122
lines changed

4 files changed

+180
-122
lines changed

fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh

+172-99
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
namespace fbgemm_gpu {
1919

20+
template <typename T, typename... Ts>
21+
constexpr inline bool is_one_of_v = (std::is_same_v<T, Ts> || ...);
22+
2023
////////////////////////////////////////////////////////////////////////////////
2124
// Quantized Load and Store
2225
////////////////////////////////////////////////////////////////////////////////
@@ -37,32 +40,19 @@ DEVICE_INLINE void quantize_store(
3740
template <typename dst_t, typename src_t>
3841
DEVICE_INLINE Vec4T<dst_t> dequantize_load(
3942
const src_t* value,
40-
const float2 /* unused */) {
41-
return Vec4T<dst_t>(value);
42-
}
43-
44-
template <>
45-
DEVICE_INLINE Vec4T<float> dequantize_load(
46-
const uint8_t* value,
47-
const float2 qparams) {
48-
Vec4T<float> out;
49-
out.acc.x = value[0] * qparams.x + qparams.y;
50-
out.acc.y = value[1] * qparams.x + qparams.y;
51-
out.acc.z = value[2] * qparams.x + qparams.y;
52-
out.acc.w = value[3] * qparams.x + qparams.y;
53-
return out;
54-
}
43+
[[maybe_unused]] const float2 qparams) {
44+
if constexpr (
45+
std::is_same_v<src_t, uint8_t> && is_one_of_v<dst_t, float, at::Half>) {
46+
Vec4T<dst_t> out;
47+
out.acc.x = value[0] * qparams.x + qparams.y;
48+
out.acc.y = value[1] * qparams.x + qparams.y;
49+
out.acc.z = value[2] * qparams.x + qparams.y;
50+
out.acc.w = value[3] * qparams.x + qparams.y;
51+
return out;
5552

56-
template <>
57-
DEVICE_INLINE Vec4T<at::Half> dequantize_load(
58-
const uint8_t* value,
59-
const float2 qparams) {
60-
Vec4T<at::Half> out;
61-
out.acc.x = value[0] * qparams.x + qparams.y;
62-
out.acc.y = value[1] * qparams.x + qparams.y;
63-
out.acc.z = value[2] * qparams.x + qparams.y;
64-
out.acc.w = value[3] * qparams.x + qparams.y;
65-
return out;
53+
} else {
54+
return Vec4T<dst_t>(value);
55+
}
6656
}
6757

6858
template <typename emb_t>
@@ -74,12 +64,6 @@ DEVICE_INLINE float2 load_qparams_from_row(emb_t* qparam_ptr) {
7464
return qparams;
7565
}
7666

77-
template <typename emb_t>
78-
DEVICE_INLINE void store_qparams_to_row(emb_t* ptr, float2 qparams) {
79-
CUDA_KERNEL_ASSERT(false); // Only int8 embeddding should call this
80-
}
81-
82-
template <>
8367
DEVICE_INLINE void store_qparams_to_row(uint8_t* ptr, float2 qparams) {
8468
auto ptr_as_uint = reinterpret_cast<uintptr_t>(ptr);
8569
if (ptr_as_uint % 8 == 0) {
@@ -112,12 +96,24 @@ DEVICE_INLINE void store_qparams_to_row(uint8_t* ptr, float2 qparams) {
11296

11397
////////////////////////////////////////////////////////////////////////////////
11498
// Weight Row
99+
//
100+
// This is a memory accessor around a row of dim_ number of embedding weights.
101+
// It provides for loading and storing of 4 elements at a time (Vec4T<dst_t>)
102+
// from and to the embedding table or cache. It also provides for quantization
103+
// and de-quantization of the data. The cache row pointer is optional, and if
104+
// not provided, then the embedding table is assumed to be the source of truth.
105+
//
106+
// Template parameters:
107+
// emb_t : The type of the embedding table (e.g. uint8_t, float, at::Half)
108+
// cache_t : The type of the cache
109+
// dst_t : The type of the registers
115110
////////////////////////////////////////////////////////////////////////////////
116111

117112
template <typename emb_t, typename cache_t, typename dst_t>
118113
// TODO: pass in dimension info and calculate qparams for rowwise integer
119114
// quantization
120-
struct WeightRow {
115+
class WeightRow {
116+
public:
121117
// Constructor for no stochastic rounding
122118
DEVICE_INLINE WeightRow(emb_t* row, cache_t* cache_row, int dim)
123119
: row_(row),
@@ -144,65 +140,54 @@ struct WeightRow {
144140
}
145141
}
146142

147-
emb_t* row_;
148-
cache_t* cache_row_;
149-
int dim_;
150-
StochasticRoundingRNGState stoc_rounding_state_;
151-
StochasticRoundingRNGState* stoc_rounding_state_ptr_;
143+
//////////////////////////////////////////////////////////////////////////////
144+
// Load 4 elements from the table row at element offset d into a register
145+
// variable (Vec4T<dst_t>)
146+
//
147+
// If the cache row pointer is valid, then data will be read from the cache
148+
// instead of embedding table.
149+
//////////////////////////////////////////////////////////////////////////////
152150

153-
// Load from cache if resident; else load from embedding
154151
DEVICE_INLINE Vec4T<dst_t> load(const int32_t d, const float2 qparams) const {
152+
// Load from the cache if resident; else load from the embedding table.
153+
//
154+
// Note: This method assumes that dst_t is of higher precision than cache_t
155+
// and emb_t
155156
if (cache_row_) {
156157
return dequantize_load<dst_t, cache_t>(cache_row_ + d, qparams);
157158
} else {
158159
return dequantize_load<dst_t, emb_t>(row_ + d, qparams);
159160
}
160161
}
161162

162-
// Write back weight (high precision) to cache if resident; else write to
163-
// embedding assume dst_t is higher precision than cache_t and emb_t
163+
//////////////////////////////////////////////////////////////////////////////
164+
// Store regster variable of 4 elements (Vec4T<dst_t>) back into the table
165+
// into the table row at element offset d
166+
//
167+
// If the cache row pointer is valid, then data will be written to the cache
168+
// instead of embedding table.
169+
//////////////////////////////////////////////////////////////////////////////
170+
164171
DEVICE_INLINE void
165172
store(const Vec4T<dst_t>& v, const int32_t d, const float2 qparams) {
173+
// Write back weight (high precision) to cache if resident; else write to
174+
// embedding table.
175+
//
176+
// Note: This method assumes that dst_t is of higher precision than cache_t
177+
// and emb_t
166178
if (cache_row_) {
167179
quantize_store(cache_row_ + d, v, stoc_rounding_state_ptr_, qparams);
168180
} else {
169181
quantize_store(row_ + d, v, stoc_rounding_state_ptr_, qparams);
170182
}
171183
}
172184

173-
// Copy vector from src_vec to dst_vec (both are float)
174-
DEVICE_INLINE void same_type_vector_copy(
175-
float* dst_vec,
176-
const float* src_vec) {
177-
*reinterpret_cast<float4*>(dst_vec) =
178-
*reinterpret_cast<const float4*>(src_vec);
179-
}
180-
181-
// Copy vector from src_vec to dst_vec (both are at::Half)
182-
DEVICE_INLINE void same_type_vector_copy(
183-
at::Half* dst_vec,
184-
const at::Half* src_vec) {
185-
*reinterpret_cast<float2*>(dst_vec) =
186-
*reinterpret_cast<const float2*>(src_vec);
187-
}
188-
189-
// Evict cached row into embedding row (high prec -> low prec)
190-
DEVICE_INLINE void evict_cache(const int32_t d, const float2 qparams) {
191-
if constexpr (std::is_same_v<emb_t, cache_t>) {
192-
// No conversion required when emb_t and cache_t are the same type
193-
same_type_vector_copy(
194-
reinterpret_cast<cache_t*>(row_ + d),
195-
reinterpret_cast<const cache_t*>(cache_row_ + d));
196-
} else {
197-
// Does 2-step conversion: cache_t -> FP32 -> weight_t
198-
const auto cache_slice = load(d, qparams);
199-
quantize_store(row_ + d, cache_slice, stoc_rounding_state_ptr_, qparams);
200-
}
201-
}
202-
203-
DEVICE_INLINE void store_qparams(const float2 qparams) {
204-
store_qparams_to_row(row_ + dim_, qparams);
205-
}
185+
//////////////////////////////////////////////////////////////////////////////
186+
// Fetch the quantization parameters of the table row
187+
//
188+
// Qparams are fetched from the end of the row in the embedding table, not the
189+
// cache.
190+
//////////////////////////////////////////////////////////////////////////////
206191

207192
DEVICE_INLINE float2 load_qparams() const {
208193
if constexpr (std::is_same_v<emb_t, uint8_t>) {
@@ -212,32 +197,77 @@ struct WeightRow {
212197
}
213198
}
214199

215-
DEVICE_INLINE void warp_copy_to_cache(
216-
cache_t* dst_row,
217-
const uint32_t dim_length,
200+
//////////////////////////////////////////////////////////////////////////////
201+
// Update the quantization parameters of the table row
202+
//
203+
// Qparams are stored at the end of the row in the embedding table, not the
204+
// cache.
205+
//////////////////////////////////////////////////////////////////////////////
206+
207+
template <typename T = emb_t>
208+
DEVICE_INLINE auto store_qparams(const float2 qparams) const
209+
-> std::enable_if_t<std::is_same_v<T, uint8_t>, void> {
210+
store_qparams_to_row(row_ + dim_, qparams);
211+
}
212+
213+
//////////////////////////////////////////////////////////////////////////////
214+
// Load the row from the embedding table into the cache
215+
//
216+
// De-quantization will be applied if the embedding table type is uint8_t (low
217+
// prec -> high prec).
218+
//////////////////////////////////////////////////////////////////////////////
219+
220+
DEVICE_INLINE void warp_cache_load(
221+
// cache_t* dst_row,
218222
const uint32_t num_lanes,
219223
const uint32_t lane_id) {
220224
if constexpr (std::is_same_v<emb_t, cache_t>) {
221-
// No conversion required when emb_t and cache_t are the same type
222-
for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
225+
// If the embedding table and cache types are the same, then simply copy
226+
// data from cache to embedding table.
227+
for (auto d = lane_id * 4; d < dim_; d += num_lanes * 4) {
223228
same_type_vector_copy(
224-
dst_row + d, reinterpret_cast<const cache_t*>(row_ + d));
229+
cache_row_ + d, reinterpret_cast<const cache_t*>(row_ + d));
225230
}
226231
} else {
227232
// Load quantization params from embedding row
228233
const auto qparams = load_qparams();
229234

230235
// Copy over for each warp-sized slice of Vec4's
231236
// Does 2-step conversion: weight_t -> FP32 -> cache_t
232-
for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
237+
for (auto d = lane_id * 4; d < dim_; d += num_lanes * 4) {
233238
const auto slice = load(d, qparams);
234-
quantize_store(dst_row + d, slice, stoc_rounding_state_ptr_, qparams);
239+
quantize_store(
240+
cache_row_ + d, slice, stoc_rounding_state_ptr_, qparams);
235241
}
236242
}
237243
}
238244

239-
DEVICE_INLINE void warp_evict_cache(
240-
const uint32_t dim_length,
245+
//////////////////////////////////////////////////////////////////////////////
246+
// Copy the row from the embedding table into the cache
247+
//////////////////////////////////////////////////////////////////////////////
248+
249+
DEVICE_INLINE void evict_cache(const uint32_t d, const float2 qparams) {
250+
if constexpr (std::is_same_v<emb_t, cache_t>) {
251+
// If the embedding table and cache types are the same, then simply copy
252+
// data from cache to embedding table.
253+
same_type_vector_copy(
254+
reinterpret_cast<emb_t*>(row_ + d),
255+
reinterpret_cast<const cache_t*>(cache_row_ + d));
256+
} else {
257+
// Else, do 2-step conversion: cache_t -> FP32 (register) -> weight_t
258+
const auto cache_slice = load(d, qparams);
259+
quantize_store(row_ + d, cache_slice, stoc_rounding_state_ptr_, qparams);
260+
}
261+
}
262+
263+
//////////////////////////////////////////////////////////////////////////////
264+
// Evict the row from the cache and into the embedding table.
265+
//
266+
// Quantization will be applied if the embedding table type is uint8_t (high
267+
// prec -> low prec).
268+
//////////////////////////////////////////////////////////////////////////////
269+
270+
DEVICE_INLINE void warp_cache_evict(
241271
const uint32_t num_lanes,
242272
const uint32_t lane_id) {
243273
float2 qparams;
@@ -248,7 +278,7 @@ struct WeightRow {
248278
std::numeric_limits<at::acc_type<cache_t, true>>::lowest();
249279

250280
// Compute the qparams from the cache row (not embedding row) weights
251-
for (auto d = lane_id; d * 4 < dim_length; d += num_lanes) {
281+
for (auto d = lane_id; d * 4 < dim_; d += num_lanes) {
252282
const auto cache_slice = load(d * 4, qparams); // qparams not used
253283
local_max = max(local_max, cache_slice.vmax());
254284
local_min = min(local_min, cache_slice.vmin());
@@ -263,41 +293,84 @@ struct WeightRow {
263293
}
264294
}
265295

266-
for (auto d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
296+
for (auto d = lane_id * 4; d < dim_; d += num_lanes * 4) {
267297
// Evict the slice into the embedding row
268298
evict_cache(d, qparams);
269299
}
270300
}
301+
302+
protected:
303+
emb_t* const row_;
304+
cache_t* const cache_row_;
305+
int32_t const dim_;
306+
StochasticRoundingRNGState stoc_rounding_state_;
307+
StochasticRoundingRNGState* stoc_rounding_state_ptr_;
308+
309+
//////////////////////////////////////////////////////////////////////////////
310+
// Copy 4 elements (float or at::Half) from src_vec to dst_vec
311+
//
312+
// Reinterpret cast to float4* or float2* for mass copy
313+
//////////////////////////////////////////////////////////////////////////////
314+
315+
template <
316+
typename T,
317+
typename = std::enable_if_t<is_one_of_v<T, float, at::Half>>>
318+
DEVICE_INLINE void same_type_vector_copy(T* dst_vec, const T* src_vec) {
319+
// Copy vector from src_vec to dst_vec (both are float)
320+
using ptr_t = std::conditional_t<std::is_same_v<T, float>, float4, float2>;
321+
*reinterpret_cast<ptr_t*>(dst_vec) =
322+
*reinterpret_cast<const ptr_t*>(src_vec);
323+
}
271324
};
272325

273326
////////////////////////////////////////////////////////////////////////////////
274327
// Weight Row Accessor
275328
//
276-
// This is a basic memory accessor around a row of dim_ number of embedding
277-
// weights of type row_t, and provides for loading 4 elements at a time into
278-
// Vec4T<dst_t> with de-quantization support. Unlike WeightRow, this accessor
279-
// is for reading only, and does not take into account embedding vs cache table,
280-
// etc.
329+
// This is a lightweight memory accessor around a row of dim_ number of
330+
// embedding weights of type row_t (can be HBM or UVM), and provides for loading
331+
// 4 elements at a time into Vec4T<dst_t> with de-quantization support. Unlike
332+
// the heavyweight WeightRow class, this accessor is for reading values only,
333+
// and does not handle embedding vs cache tables, etc.
334+
//
335+
// Template parameters:
336+
// row_t : The type of the table row (e.g. uint8_t, float, at::Half)
337+
// dst_t : The type of the registers
281338
////////////////////////////////////////////////////////////////////////////////
282339

283340
template <typename row_t, typename dst_t>
284-
struct WeightRowAccessor {
285-
const row_t* row_;
341+
class WeightRowAccessor {
342+
// The pointer to the row of weights in the table
343+
const row_t* const row_;
344+
345+
// The number of elements per table row.
346+
//
347+
// This is NOT necessarily equivalent to the row stride D_emb, as there may be
348+
// quantization parameters and optimizer states packed into the back of the
349+
// row.
350+
//
351+
// dim_ is presumed to be a multiple of 4, since it loads data into Vec4T for
352+
// max register occupancy.
286353
const int32_t dim_;
287-
const float2 qparams_;
288354

355+
// [OPTIONAL] The quantization parameters for the row. If the row type is not
356+
// uint8_t, i.e. not quantized, then it is set to (0.0f, 0.0f).
357+
float2 qparams_ = make_float2(0.0f, 0.0f);
358+
359+
public:
289360
DEVICE_INLINE
290361
WeightRowAccessor(const row_t* const row, const int32_t dim)
291-
: row_(row), dim_(dim), qparams_(qparams()) {}
292-
293-
DEVICE_INLINE auto qparams() const {
362+
: row_(row), dim_(dim) {
294363
if constexpr (std::is_same_v<row_t, uint8_t>) {
295-
return load_qparams_from_row<row_t>(row_ + dim_);
296-
} else {
297-
return make_float2(0.0f, 0.0f);
364+
qparams_ = qparams();
298365
}
299366
}
300367

368+
template <typename T = row_t>
369+
DEVICE_INLINE auto qparams() const
370+
-> std::enable_if_t<std::is_same_v<T, uint8_t>, float2> {
371+
return load_qparams_from_row<row_t>(row_ + dim_);
372+
}
373+
301374
DEVICE_INLINE Vec4T<dst_t> load(const int32_t d) const {
302375
return dequantize_load<dst_t, row_t>(row_ + d, qparams_);
303376
}

0 commit comments

Comments
 (0)