Commit 4257d45

basic tbe input dump framework (pytorch#3593)

Sihui Han authored and facebook-github-bot committed
Summary: Adds plugin capability to dump TBE input; the OSS implementation is a no-op.

Differential Revision: D68446857
1 parent fde11cd commit 4257d45

File tree

2 files changed: +129 −0 lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

+33
@@ -49,6 +49,12 @@
     generate_vbe_metadata,
     is_torchdynamo_compiling,
 )
+from fbgemm_gpu.tbe_input_multiplexer import (
+    TBEInfo,
+    TBEInputInfo,
+    TBEInputMultiplexer,
+    TBEInputMultiplexerConfig,
+)
 
 from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
 
@@ -647,6 +653,7 @@ def __init__( # noqa C901
         global_weight_decay: Optional[GlobalWeightDecayDefinition] = None,
         uvm_host_mapped: bool = False,
         extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None,
+        tbe_input_multiplexer_config: Optional[TBEInputMultiplexerConfig] = None,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
 
@@ -820,6 +827,23 @@ def __init__( # noqa C901
         self.feature_table_map: List[int] = (
             feature_table_map if feature_table_map is not None else list(range(T_))
         )
+
+        self.tbe_input_multiplexer: Optional[TBEInputMultiplexer] = (
+            tbe_input_multiplexer_config.create_tbe_input_multiplexer(
+                tbe_info=TBEInfo(
+                    table_names=(
+                        table_names
+                        if table_names
+                        else [f"table-{i}" for i in range(len(embedding_specs))]
+                    ),
+                    table_heights=rows,
+                    tbe_uuid=self.uuid,
+                    feature_table_map=self.feature_table_map,
+                )
+            )
+            if tbe_input_multiplexer_config is not None
+            else None
+        )
         T = len(self.feature_table_map)
         assert T_ <= T
         table_has_feature = [False] * T_
@@ -1789,6 +1813,15 @@ def forward( # noqa: C901
         self._report_io_size_count("fwd_input", indices)
         self._report_tbe_mem_usage()
 
+        if self.tbe_input_multiplexer is not None:
+            tbe_input_multiplexer: TBEInputMultiplexer = self.tbe_input_multiplexer
+            if tbe_input_multiplexer.should_run(self.step):
+                tbe_input_multiplexer.run(
+                    tbe_input_info=TBEInputInfo(
+                        indices, offsets, batch_size_per_feature_per_rank
+                    )
+                )
+
         if len(self.timesteps_prefetched) == 0:
             # In forward, we don't enable multi-pass prefetch as we want the process
             # to be as fast as possible and memory usage doesn't matter (will be recycled
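
For orientation, here is a minimal usage sketch (not part of this commit) showing how the new tbe_input_multiplexer_config argument is threaded through: __init__ builds a multiplexer from the config, and forward() calls should_run(step) / run(...) on it. The embedding_specs sizes and table_names are illustrative, and the import locations for EmbeddingLocation / ComputeDevice are assumptions:

# Hypothetical usage sketch, not part of this commit.  With the stock OSS
# TBEInputMultiplexerConfig, create_tbe_input_multiplexer() returns None
# (start_batch defaults to -1), so the new hook in forward() stays a no-op.
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,  # assumed import location
    EmbeddingLocation,  # assumed import location
    SplitTableBatchedEmbeddingBagsCodegen,
)
from fbgemm_gpu.tbe_input_multiplexer import TBEInputMultiplexerConfig

tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        # (num_embeddings, embedding_dim, location, compute_device) -- illustrative sizes
        (1000, 16, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
        (2000, 16, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
    ],
    table_names=["table_a", "table_b"],  # surfaced to the multiplexer via TBEInfo
    tbe_input_multiplexer_config=TBEInputMultiplexerConfig(),
)

A non-trivial config, one that overrides create_tbe_input_multiplexer as sketched after the new module below, is what actually turns the dump on.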
fbgemm_gpu/fbgemm_gpu/tbe_input_multiplexer.py

+96
@@ -0,0 +1,96 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import abc
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+from torch import Tensor
+
+
+@dataclass(frozen=True)
+class TBEInfo:
+    """
+    Contains selective TBE info used for multiplexing. For more info, check https://fburl.com/code/ljnd6j65
+
+    Args:
+        table_names: table names within the TBE
+        table_heights: table heights (hash sizes)
+        tbe_uuid: a unique identifier for the TBE
+        feature_table_map: feature-to-table map
+    """
+
+    table_names: List[str]
+    table_heights: List[int]
+    tbe_uuid: str
+    feature_table_map: List[int]
+
+
+@dataclass(frozen=True)
+class TBEInputInfo:
+    """
+    indices: a 1D tensor that contains indices to be looked up
+        from all embedding tables.
+    offsets: a 1D tensor that contains the offsets of indices.
+    batch_size_per_feature_per_rank: an optional 2D list that contains the batch size
+        for every rank and every feature; needed to support VBE.
+    """
+
+    indices: Tensor
+    offsets: Tensor
+    batch_size_per_feature_per_rank: Optional[List[List[int]]] = None
+
+
+class TBEInputMultiplexer(abc.ABC):
+    """
+    Interface for multiplexing TBE input data out; an actual implementation may store the data to files.
+    """
+
+    @abc.abstractmethod
+    def should_run(self, step: int) -> bool:
+        """
+        Checks whether the multiplexer should run at this step.
+        Args:
+            step: the current step
+        Returns:
+            True if it should run, otherwise False
+        """
+        pass
+
+    @abc.abstractmethod
+    def run(
+        self,
+        tbe_input_info: TBEInputInfo,
+    ) -> None:
+        """
+        Runs the TBE input multiplexing; called for every batch that needs to be dumped.
+        Args:
+            tbe_input_info: TBE input info that contains all the necessary info for further processing
+        """
+        pass
+
+
+@dataclass(frozen=True)
+class TBEInputMultiplexerConfig:
+    """
+    Configuration for TBEInputMultiplexer
+    """
+
+    # first batch to start the run; -1 means never run
+    start_batch: int = -1
+    # total number of batches to multiplex
+    total_batch: int = 0
+
+    def create_tbe_input_multiplexer(
+        self,
+        tbe_info: TBEInfo,
+    ) -> Optional[TBEInputMultiplexer]:
+        assert (
+            self.start_batch == -1
+        ), "Cannot specify start_batch without an actual implementation."
+        return None
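
Because the OSS TBEInputMultiplexerConfig.create_tbe_input_multiplexer always returns None, an actual dump requires subclassing both the interface and the config. Below is a minimal sketch under that assumption; FileDumpTBEInputMultiplexer, its constructor arguments, the output directory, and the torch.save file layout are hypothetical, not part of this commit:

# Hypothetical implementation sketch, not part of this commit: dump the TBE
# input of `total_batch` consecutive steps starting at `start_batch` to disk.
import os
from dataclasses import dataclass
from typing import Optional

import torch

from fbgemm_gpu.tbe_input_multiplexer import (
    TBEInfo,
    TBEInputInfo,
    TBEInputMultiplexer,
    TBEInputMultiplexerConfig,
)


class FileDumpTBEInputMultiplexer(TBEInputMultiplexer):
    def __init__(
        self, tbe_info: TBEInfo, start_batch: int, total_batch: int, out_dir: str
    ) -> None:
        self.tbe_info = tbe_info
        self.start_batch = start_batch
        self.end_batch = start_batch + total_batch
        self.out_dir = out_dir
        self.num_dumped = 0
        os.makedirs(out_dir, exist_ok=True)

    def should_run(self, step: int) -> bool:
        # Only dump inside the [start_batch, start_batch + total_batch) window.
        return self.start_batch <= step < self.end_batch

    def run(self, tbe_input_info: TBEInputInfo) -> None:
        # One file per dumped batch, keyed by the TBE uuid and a running counter.
        path = os.path.join(
            self.out_dir, f"{self.tbe_info.tbe_uuid}_{self.num_dumped}.pt"
        )
        torch.save(
            {
                "indices": tbe_input_info.indices.cpu(),
                "offsets": tbe_input_info.offsets.cpu(),
                "batch_size_per_feature_per_rank": tbe_input_info.batch_size_per_feature_per_rank,
                "table_names": self.tbe_info.table_names,
                "feature_table_map": self.tbe_info.feature_table_map,
            },
            path,
        )
        self.num_dumped += 1


@dataclass(frozen=True)
class FileDumpTBEInputMultiplexerConfig(TBEInputMultiplexerConfig):
    out_dir: str = "/tmp/tbe_input_dump"  # illustrative default

    def create_tbe_input_multiplexer(
        self, tbe_info: TBEInfo
    ) -> Optional[TBEInputMultiplexer]:
        # Unlike the base config, build a real multiplexer when dumping is enabled.
        if self.start_batch == -1:
            return None
        return FileDumpTBEInputMultiplexer(
            tbe_info, self.start_batch, self.total_batch, self.out_dir
        )

Passing FileDumpTBEInputMultiplexerConfig(start_batch=100, total_batch=10) as tbe_input_multiplexer_config would then dump ten consecutive batches of indices/offsets starting at step 100.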
