Skip to content

Commit a12f4ed

Browse files
Sihui Han and facebook-github-bot
Sihui Han
authored and committed
: basic tbe input dump framework (pytorch#3593)
Summary: Plugin capability to dump TBE input and no-ops in OSS Reviewed By: damianr99 Differential Revision: D68446857
1 parent b858408 commit a12f4ed

File tree

2 files changed

+94
-1
lines changed

2 files changed

+94
-1
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
# @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
2727
import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
28-
2928
from fbgemm_gpu.config import FeatureGate, FeatureGateName
3029
from fbgemm_gpu.runtime_monitor import (
3130
AsyncSeriesTimer,
@@ -49,6 +48,7 @@
4948
generate_vbe_metadata,
5049
is_torchdynamo_compiling,
5150
)
51+
from fbgemm_gpu.tbe_input_dump import TBEInputDump, TBEInputDumpConfig
5252

5353
from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
5454

@@ -647,6 +647,7 @@ def __init__( # noqa C901
647647
global_weight_decay: Optional[GlobalWeightDecayDefinition] = None,
648648
uvm_host_mapped: bool = False,
649649
extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None,
650+
tbe_input_dump_config: Optional[TBEInputDumpConfig] = None,
650651
) -> None:
651652
super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
652653

@@ -820,6 +821,21 @@ def __init__( # noqa C901
820821
self.feature_table_map: List[int] = (
821822
feature_table_map if feature_table_map is not None else list(range(T_))
822823
)
824+
825+
self.tbe_input_dump: Optional[TBEInputDump] = (
826+
tbe_input_dump_config.create_tbe_input_dump(
827+
table_names=(
828+
table_names
829+
if table_names
830+
else [f"table-{i}" for i in range(len(embedding_specs))]
831+
),
832+
table_heights=rows,
833+
tbe_uuid=self.uuid,
834+
feature_table_map=self.feature_table_map,
835+
)
836+
if tbe_input_dump_config is not None
837+
else None
838+
)
823839
T = len(self.feature_table_map)
824840
assert T_ <= T
825841
table_has_feature = [False] * T_
@@ -1789,6 +1805,11 @@ def forward( # noqa: C901
17891805
self._report_io_size_count("fwd_input", indices)
17901806
self._report_tbe_mem_usage()
17911807

1808+
if self.tbe_input_dump is not None:
1809+
tbe_input_dump: TBEInputDump = self.tbe_input_dump
1810+
if tbe_input_dump.should_dump(self.step):
1811+
tbe_input_dump.run(indices, offsets, batch_size_per_feature_per_rank)
1812+
17921813
if len(self.timesteps_prefetched) == 0:
17931814
# In forward, we don't enable multi-pass prefetch as we want the process
17941815
# to be as fast as possible and memory usage doesn't matter (will be recycled
+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import abc
9+
10+
from dataclasses import dataclass
11+
from typing import List, Optional
12+
13+
from torch import Tensor
14+
15+
16+
class TBEInputDump(abc.ABC):
    """
    Interface for dumping TBE input data; a concrete implementation may
    persist the data (e.g. to files).
    """

    @abc.abstractmethod
    def should_dump(self, step: int) -> bool:
        """
        Decide whether a dump should be triggered at the given step.

        Args:
            step: the current step

        Returns:
            True if the dump should be triggered, otherwise False
        """
        ...

    @abc.abstractmethod
    def run(
        self,
        indices: Tensor,
        offsets: Tensor,
        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
    ) -> None:
        """
        Perform the TBE input dump; invoked for every batch that needs
        to be dumped.

        Args:
            indices: a 1D tensor containing the indices to be looked up
                from all embedding tables.
            offsets: a 1D tensor containing the offsets of ``indices``.
            batch_size_per_feature_per_rank: optional batch sizes for every
                rank and every feature; needed to support VBE.
        """
        ...
49+
50+
51+
@dataclass(frozen=True)
class TBEInputDumpConfig:
    """
    Configuration for TBEInputDump.
    """

    # First batch at which dumping starts; -1 means dumping is disabled.
    monitored_batch_start: int = -1
    # Total number of batches to dump.
    monitored_total_batch: int = 0

    def create_tbe_input_dump(
        self,
        table_names: List[str],
        table_heights: List[int],
        tbe_uuid: str,
        feature_table_map: List[int],
    ) -> Optional[TBEInputDump]:
        # This OSS stub provides no dumper implementation, so requesting a
        # dump start batch is a configuration error.
        dump_requested = self.monitored_batch_start != -1
        assert (
            not dump_requested
        ), "Cannot specify monitored_batch_start without an actual implementation of tbe dump"
        return None

0 commit comments

Comments
 (0)