
Commit 010493b

q10 authored and facebook-github-bot committed
Optimized backward pass for ROCm devices (pt 2) (pytorch#3511)
Summary:
X-link: facebookresearch/FBGEMM#594

- Break up D66310520 (pytorch#3367) into backend and frontend diffs. This is the frontend diff, and a follow-up to D66986498.

Differential Revision: D67407935
1 parent a75d8fe · commit 010493b

6 files changed, +19 −5 lines changed

fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp

+3 −3

@@ -172,7 +172,7 @@ Tensor split_embedding_codegen_lookup_dense_function(
     c10::SymInt /* max_B = -1 */,
     c10::SymInt /* max_B_feature_rank = -1 */,
     c10::SymInt /* vbe_output_size = -1 */,
-    bool /* mixed_D = true */) {
+    bool /* mixed_D = false */) {
   return SplitLookupFunction_Dense_Op::apply(
       host_weights,
       weights_offsets,
@@ -191,15 +191,15 @@ Tensor split_embedding_codegen_lookup_dense_function(
 // Deprecated for fb namespace! Please use fbgemm namespace instead!
 TORCH_LIBRARY_FRAGMENT(fb, m) {
   m.def(
-      "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=True) -> Tensor");
+      "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=False) -> Tensor");
   DISPATCH_TO_CPU(
       "dense_embedding_codegen_lookup_function",
       split_embedding_codegen_lookup_dense_function);
 }

 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
-      "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=True) -> Tensor");
+      "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=False) -> Tensor");
   DISPATCH_TO_CPU(
       "dense_embedding_codegen_lookup_function",
       split_embedding_codegen_lookup_dense_function);
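
Note: the only user-visible change in this file is the schema default, so callers that omit mixed_D now get False rather than True. A minimal sketch of how the registered default could be checked from Python (this assumes a working fbgemm_gpu install; the schema inspection below is plain PyTorch and is not part of this diff):

    import torch
    import fbgemm_gpu  # noqa: F401  -- importing the package registers the fbgemm ops

    # After this commit, the printed schema ends with "bool mixed_D=False"
    # instead of "bool mixed_D=True".
    op = torch.ops.fbgemm.dense_embedding_codegen_lookup_function
    print(op.default._schema)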

fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp

+2 −2

@@ -1083,7 +1083,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
     {%- else %}
     const c10::SymInt vbe_output_size = -1,
     {%- endif %}
-    const bool mixed_D = true
+    const bool mixed_D = false
 ) {
   // TODO: refactor into macro
   {%- if has_gpu_support %}
@@ -1200,7 +1200,7 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
           " Tensor[]? ssd_tensors=None,"
           {%- endif %}
           " float gwd_lower_bound=0, "
-          " bool mixed_D=True"
+          " bool mixed_D=False"
           ") -> Tensor",
           {PT2_COMPLIANT_TAG});

fbgemm_gpu/codegen/training/python/lookup_args.template

+1
@@ -49,6 +49,7 @@ class CommonArgs(NamedTuple):
     {%- if ssd %}
     ssd_tensors: Dict[str, torch.Tensor]
     {%- endif %}
+    mixed_D: bool


 class OptimizerArgs(NamedTuple):

fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template

+1
@@ -409,5 +409,6 @@ def invoke(
         use_homogeneous_placements=common_args.use_homogeneous_placements,
         apply_global_weight_decay=apply_global_weight_decay,
         gwd_lower_bound=gwd_lower_bound,
+        mixed_D=common_args.mixed_D,
     )
 {%- endif %}

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

+11
@@ -744,6 +744,7 @@ def __init__( # noqa C901
             not mixed_D
         ), "OptimType.NONE does not support mixed embedding dimension"

+        self.mixed_D: bool = mixed_D
         if device is None:
             self.current_device: torch.device = (
                 torch.device("cpu")
@@ -1806,6 +1807,7 @@ def forward( # noqa: C901
             is_experimental=self.is_experimental,
             use_uniq_cache_locations_bwd=self.use_uniq_cache_locations_bwd,
             use_homogeneous_placements=self.use_homogeneous_placements,
+            mixed_D=self.mixed_D,
         )

         if self.optimizer == OptimType.NONE:
@@ -3581,6 +3583,14 @@ def __init__(
         )
         assert self.D_offsets.numel() == T + 1

+        mixed_D = False
+        D = dims[0]
+        for d in dims:
+            if d != D:
+                mixed_D = True
+                break
+        self.mixed_D: bool = mixed_D
+
         # Required for VBE
         self.register_buffer(
             "feature_dims",
@@ -3694,6 +3704,7 @@ def forward(
             max_B=vbe_metadata.max_B,
             max_B_feature_rank=vbe_metadata.max_B_feature_rank,
             vbe_output_size=vbe_metadata.output_size,
+            mixed_D=self.mixed_D,
         )

     @torch.jit.export
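
Note: the loop added in __init__ above simply records whether all tables share one embedding dimension. As an illustration only (the dims values below are made up, not taken from this diff), the check is equivalent to:

    # Hypothetical per-table embedding dimensions, used purely for illustration.
    dims = [128, 128, 64]
    # Same result as the for-loop added above: mixed_D becomes True as soon as
    # any table's dimension differs from the first table's.
    mixed_D = any(d != dims[0] for d in dims)
    assert mixed_D  # 64 != 128, so these tables have mixed embedding dimensions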

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

+1
@@ -1615,6 +1615,7 @@ def forward(
             },
             # pyre-fixme[6]: Expected `lookup_args_ssd.VBEMetadata` but got `lookup_args.VBEMetadata`
             vbe_metadata=vbe_metadata,
+            mixed_D=False,
         )

         self.timesteps_prefetched.pop(0)
