@@ -25,7 +25,6 @@
 
 # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
-
 from fbgemm_gpu.config import FeatureGate, FeatureGateName
 from fbgemm_gpu.runtime_monitor import (
     AsyncSeriesTimer,
@@ -49,6 +48,7 @@
     generate_vbe_metadata,
     is_torchdynamo_compiling,
 )
+from fbgemm_gpu.tbe_input_dump import TBEInputDump, TBEInputDumpConfig
 
 from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
 
@@ -647,6 +647,7 @@ def __init__(  # noqa C901
         global_weight_decay: Optional[GlobalWeightDecayDefinition] = None,
         uvm_host_mapped: bool = False,
         extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None,
+        tbe_input_dump_config: Optional[TBEInputDumpConfig] = None,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
 
@@ -820,6 +821,21 @@ def __init__(  # noqa C901
         self.feature_table_map: List[int] = (
             feature_table_map if feature_table_map is not None else list(range(T_))
         )
+
+        self.tbe_input_dump: Optional[TBEInputDump] = (
+            tbe_input_dump_config.create_tbe_input_dump(
+                table_names=(
+                    table_names
+                    if table_names
+                    else [f"table-{i}" for i in range(len(embedding_specs))]
+                ),
+                table_heights=rows,
+                tbe_uuid=self.uuid,
+                feature_table_map=self.feature_table_map,
+            )
+            if tbe_input_dump_config is not None
+            else None
+        )
         T = len(self.feature_table_map)
         assert T_ <= T
         table_has_feature = [False] * T_
@@ -1789,6 +1805,11 @@ def forward(  # noqa: C901
         self._report_io_size_count("fwd_input", indices)
         self._report_tbe_mem_usage()
 
+        if self.tbe_input_dump is not None:
+            tbe_input_dump: TBEInputDump = self.tbe_input_dump
+            if tbe_input_dump.should_dump(self.step):
+                tbe_input_dump.run(indices, offsets, batch_size_per_feature_per_rank)
+
         if len(self.timesteps_prefetched) == 0:
             # In forward, we don't enable multi-pass prefetch as we want the process
             # to be as fast as possible and memory usage doesn't matter (will be recycled
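
Read as one change, the hunks above wire an optional input-dump hook into the TBE module: a TBEInputDumpConfig passed to __init__ is turned into a TBEInputDump via create_tbe_input_dump(...) (with per-table names defaulting to "table-{i}" when none are given), and forward() then gates on should_dump(self.step) before calling run(indices, offsets, batch_size_per_feature_per_rank). The diff does not show the internals of fbgemm_gpu.tbe_input_dump, so the sketch below is a hypothetical stand-in that only mirrors the interface these call sites rely on; EveryNStepsDump and its interval field are illustrative names, not the real API.

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class EveryNStepsDump:
    """Hypothetical dumper: records the raw TBE input every `interval` steps."""

    table_names: List[str]
    tbe_uuid: str
    interval: int = 1000

    def should_dump(self, step: int) -> bool:
        # Cheap gate, matching how forward() checks before doing any work.
        return step % self.interval == 0

    def run(
        self,
        indices: torch.Tensor,
        offsets: torch.Tensor,
        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
    ) -> None:
        # A real implementation would persist the batch; printing stands in here.
        print(
            f"[{self.tbe_uuid}] dumping {indices.numel()} indices and "
            f"{offsets.numel()} offsets for tables {self.table_names}"
        )

A config whose create_tbe_input_dump returns such a dumper could then be passed as tbe_input_dump_config to SplitTableBatchedEmbeddingBagsCodegen; with the default of None the feature is fully disabled, and forward() pays only the `is not None` check. The local re-binding `tbe_input_dump: TBEInputDump = self.tbe_input_dump` in forward() appears to be there to narrow the Optional type for the two calls that follow.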