Skip to content

Commit dde78a7

Browse files
gchalump authored and facebook-github-bot committed
Report TBE data configuration with EEG-based indices estimation (#4018)
Summary:
X-link: facebookresearch/FBGEMM#1106
Pull Request resolved: #4018
- Split out a new method in the TBEBenchmarkParamsReporter class that extracts the TBE data configuration parameters from the SplitTableBatchedEmbeddingBagsCodegen object and returns them as a TBEDataConfig.
- Add a unit test to verify the extracted TBEDataConfig.
Differential Revision: D73450767
1 parent 6f3e870 commit dde78a7

File tree

2 files changed

+53
-44
lines changed

2 files changed

+53
-44
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py

+6 -1
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ def __init__(
4545
) -> None:
4646
self.report_interval = report_interval
4747
self.report_once = report_once
48+
self.has_reported = False
4849

4950
default_bucket = "/tmp" if open_source else "tlparse_reports"
5051
bucket = (
@@ -149,7 +150,9 @@ def report_stats(
149150
per_sample_weights (Optional[Tensor]): Input per
150151
sample weights
151152
"""
152-
if embedding_op.iter.item() % self.report_interval == 0:
153+
if embedding_op.iter.item() % self.report_interval == 0 and (
154+
not self.report_once or (self.report_once and not self.has_reported)
155+
):
153156
# Extract TBE config
154157
config = self.extract_params(
155158
embedding_op, indices, offsets, per_sample_weights
@@ -160,3 +163,5 @@ def report_stats(
160163
f"tbe-{embedding_op.uuid}-config-estimation-{embedding_op.iter.item()}.json",
161164
io.BytesIO(config.json(format=True).encode()),
162165
)
166+
167+
self.has_reported = True

fbgemm_gpu/test/tbe/stats/tbe_bench_params_reporter_test.py

+47 -43
Original file line number | Diff line number | Diff line change
@@ -8,17 +8,14 @@
88
# pyre-strict
99

1010
import unittest
11-
from unittest.mock import MagicMock, patch
1211

13-
import torch
14-
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
15-
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
16-
EmbeddingLocation,
17-
PoolingMode,
18-
)
12+
import hypothesis.strategies as st
1913

14+
import torch
15+
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
2016
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
2117
ComputeDevice,
18+
get_available_compute_device,
2219
SplitTableBatchedEmbeddingBagsCodegen,
2320
)
2421
from fbgemm_gpu.tbe.bench import (
@@ -29,83 +26,90 @@
2926
)
3027
from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter
3128
from fbgemm_gpu.tbe.utils import get_device
29+
from hypothesis import given, settings
3230

3331

3432
class TestTBEBenchmarkParamsReporter(unittest.TestCase):
35-
@patch("fbgemm_gpu.utils.FileStore") # Mock FileStore
33+
# pyre-ignore[56]
34+
@given(
35+
T=st.integers(1, 10),
36+
E=st.integers(100, 10000),
37+
D=st.sampled_from([32, 64, 128, 256]),
38+
L=st.integers(1, 10),
39+
B=st.integers(20, 100),
40+
)
41+
@settings(max_examples=1, deadline=None)
3642
def test_report_stats(
3743
self,
38-
mock_filestore: MagicMock, # Mock FileStore
44+
T: int,
45+
E: int,
46+
D: int,
47+
L: int,
48+
B: int,
3949
) -> None:
50+
"""Test that the reporter can extract a valid JSON configuration from the embedding operation and requests."""
4051

52+
# Generate a TBEDataConfig
4153
tbeconfig = TBEDataConfig(
42-
T=2,
43-
E=1024,
44-
D=32,
45-
mixed_dim=True,
54+
T=T,
55+
E=E,
56+
D=D,
57+
mixed_dim=False,
4658
weighted=False,
47-
batch_params=BatchParams(B=512),
59+
batch_params=BatchParams(B=B),
4860
indices_params=IndicesParams(
4961
heavy_hitters=torch.tensor([]),
5062
zipf_q=0.1,
5163
zipf_s=0.1,
5264
index_dtype=torch.int64,
5365
offset_dtype=torch.int64,
5466
),
55-
pooling_params=PoolingParams(L=2),
56-
use_cpu=True,
67+
pooling_params=PoolingParams(L=L),
68+
use_cpu=get_available_compute_device() == ComputeDevice.CPU,
5769
)
5870

59-
embedding_location = EmbeddingLocation.HOST
71+
embedding_location = (
72+
EmbeddingLocation.DEVICE
73+
if torch.cuda.is_available()
74+
else EmbeddingLocation.HOST
75+
)
6076

77+
# Generate the embedding dimension list
6178
_, Ds = tbeconfig.generate_embedding_dims()
79+
80+
# Generate the embedding operation
6281
embedding_op = SplitTableBatchedEmbeddingBagsCodegen(
6382
[
6483
(
6584
tbeconfig.E,
6685
D,
6786
embedding_location,
68-
ComputeDevice.CPU,
87+
ComputeDevice.CUDA if get_device() == "cuda" else ComputeDevice.CPU,
6988
)
7089
for D in Ds
7190
],
72-
optimizer=OptimType.EXACT_ROWWISE_ADAGRAD,
73-
learning_rate=0.01,
74-
weights_precision=SparseType.FP32,
75-
pooling_mode=PoolingMode.SUM,
76-
output_dtype=SparseType.FP32,
7791
)
7892

7993
embedding_op = embedding_op.to(get_device())
8094

81-
requests = tbeconfig.generate_requests(1)
82-
8395
# Initialize the reporter
8496
reporter = TBEBenchmarkParamsReporter(report_interval=1)
85-
# Set the mock filestore as the reporter's filestore
86-
reporter.filestore = mock_filestore
8797

88-
request = requests[0]
98+
# Generate indices and offsets
99+
request = tbeconfig.generate_requests(1)[0]
100+
89101
# Call the report_stats method
90102
extracted_config = reporter.extract_params(
91103
embedding_op=embedding_op,
92104
indices=request.indices,
93105
offsets=request.offsets,
94106
)
95107

96-
reporter.report_stats(
97-
embedding_op=embedding_op,
98-
indices=request.indices,
99-
offsets=request.offsets,
100-
)
101-
102-
# TODO: This is not working because need more details in initial config
103-
# Assert that the reconstructed configuration matches the original
104-
# assert (
105-
# extracted_config == tbeconfig
106-
# ), "Extracted configuration does not match the original TBEDataConfig"
107-
108-
# Check if the write method was called on the FileStore
109108
assert (
110-
reporter.filestore.write.assert_called_once
111-
), "FileStore.write() was not called"
109+
extracted_config.T == tbeconfig.T
110+
and extracted_config.E == tbeconfig.E
111+
and extracted_config.D == tbeconfig.D
112+
and extracted_config.pooling_params.L == tbeconfig.pooling_params.L
113+
and extracted_config.batch_params.B == tbeconfig.batch_params.B
114+
), "Extracted config does not match the original TBEDataConfig"
115+
# Attempt to reconstruct TBEDataConfig from extracted_json_config

0 commit comments

Comments
 (0)