|
8 | 8 | # pyre-strict
|
9 | 9 |
|
10 | 10 | import unittest
|
11 |
| -from unittest.mock import MagicMock, patch |
12 | 11 |
|
13 |
| -import torch |
14 |
| -from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType |
15 |
| -from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( |
16 |
| - EmbeddingLocation, |
17 |
| - PoolingMode, |
18 |
| -) |
| 12 | +import hypothesis.strategies as st |
19 | 13 |
|
| 14 | +import torch |
| 15 | +from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation |
20 | 16 | from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
|
21 | 17 | ComputeDevice,
|
| 18 | + get_available_compute_device, |
22 | 19 | SplitTableBatchedEmbeddingBagsCodegen,
|
23 | 20 | )
|
24 | 21 | from fbgemm_gpu.tbe.bench import (
|
|
29 | 26 | )
|
30 | 27 | from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter
|
31 | 28 | from fbgemm_gpu.tbe.utils import get_device
|
| 29 | +from hypothesis import given, settings |
32 | 30 |
|
33 | 31 |
|
34 | 32 | class TestTBEBenchmarkParamsReporter(unittest.TestCase):
|
35 |
| - @patch("fbgemm_gpu.utils.FileStore") # Mock FileStore |
| 33 | + # pyre-ignore[56] |
| 34 | + @given( |
| 35 | + T=st.integers(1, 10), |
| 36 | + E=st.integers(100, 10000), |
| 37 | + D=st.sampled_from([32, 64, 128, 256]), |
| 38 | + L=st.integers(1, 10), |
| 39 | + B=st.integers(20, 100), |
| 40 | + ) |
| 41 | + @settings(max_examples=1, deadline=None) |
36 | 42 | def test_report_stats(
|
37 | 43 | self,
|
38 |
| - mock_filestore: MagicMock, # Mock FileStore |
| 44 | + T: int, |
| 45 | + E: int, |
| 46 | + D: int, |
| 47 | + L: int, |
| 48 | + B: int, |
39 | 49 | ) -> None:
|
| 50 | + """Test that the reporter can extract a valid JSON configuration from the embedding operation and requests.""" |
40 | 51 |
|
| 52 | + # Generate a TBEDataConfig |
41 | 53 | tbeconfig = TBEDataConfig(
|
42 |
| - T=2, |
43 |
| - E=1024, |
44 |
| - D=32, |
45 |
| - mixed_dim=True, |
| 54 | + T=T, |
| 55 | + E=E, |
| 56 | + D=D, |
| 57 | + mixed_dim=False, |
46 | 58 | weighted=False,
|
47 |
| - batch_params=BatchParams(B=512), |
| 59 | + batch_params=BatchParams(B=B), |
48 | 60 | indices_params=IndicesParams(
|
49 | 61 | heavy_hitters=torch.tensor([]),
|
50 | 62 | zipf_q=0.1,
|
51 | 63 | zipf_s=0.1,
|
52 | 64 | index_dtype=torch.int64,
|
53 | 65 | offset_dtype=torch.int64,
|
54 | 66 | ),
|
55 |
| - pooling_params=PoolingParams(L=2), |
56 |
| - use_cpu=True, |
| 67 | + pooling_params=PoolingParams(L=L), |
| 68 | + use_cpu=get_available_compute_device() == ComputeDevice.CPU, |
57 | 69 | )
|
58 | 70 |
|
59 |
| - embedding_location = EmbeddingLocation.HOST |
| 71 | + embedding_location = ( |
| 72 | + EmbeddingLocation.DEVICE |
| 73 | + if torch.cuda.is_available() |
| 74 | + else EmbeddingLocation.HOST |
| 75 | + ) |
60 | 76 |
|
| 77 | + # Generate the embedding dimension list |
61 | 78 | _, Ds = tbeconfig.generate_embedding_dims()
|
| 79 | + |
| 80 | + # Generate the embedding operation |
62 | 81 | embedding_op = SplitTableBatchedEmbeddingBagsCodegen(
|
63 | 82 | [
|
64 | 83 | (
|
65 | 84 | tbeconfig.E,
|
66 | 85 | D,
|
67 | 86 | embedding_location,
|
68 |
| - ComputeDevice.CPU, |
| 87 | + ComputeDevice.CUDA if get_device() == "cuda" else ComputeDevice.CPU, |
69 | 88 | )
|
70 | 89 | for D in Ds
|
71 | 90 | ],
|
72 |
| - optimizer=OptimType.EXACT_ROWWISE_ADAGRAD, |
73 |
| - learning_rate=0.01, |
74 |
| - weights_precision=SparseType.FP32, |
75 |
| - pooling_mode=PoolingMode.SUM, |
76 |
| - output_dtype=SparseType.FP32, |
77 | 91 | )
|
78 | 92 |
|
79 | 93 | embedding_op = embedding_op.to(get_device())
|
80 | 94 |
|
81 |
| - requests = tbeconfig.generate_requests(1) |
82 |
| - |
83 | 95 | # Initialize the reporter
|
84 | 96 | reporter = TBEBenchmarkParamsReporter(report_interval=1)
|
85 |
| - # Set the mock filestore as the reporter's filestore |
86 |
| - reporter.filestore = mock_filestore |
87 | 97 |
|
88 |
| - request = requests[0] |
| 98 | + # Generate indices and offsets |
| 99 | + request = tbeconfig.generate_requests(1)[0] |
| 100 | + |
89 | 101 | # Call the report_stats method
|
90 | 102 | extracted_config = reporter.extract_params(
|
91 | 103 | embedding_op=embedding_op,
|
92 | 104 | indices=request.indices,
|
93 | 105 | offsets=request.offsets,
|
94 | 106 | )
|
95 | 107 |
|
96 |
| - reporter.report_stats( |
97 |
| - embedding_op=embedding_op, |
98 |
| - indices=request.indices, |
99 |
| - offsets=request.offsets, |
100 |
| - ) |
101 |
| - |
102 |
| - # TODO: This is not working because need more details in initial config |
103 |
| - # Assert that the reconstructed configuration matches the original |
104 |
| - # assert ( |
105 |
| - # extracted_config == tbeconfig |
106 |
| - # ), "Extracted configuration does not match the original TBEDataConfig" |
107 |
| - |
108 |
| - # Check if the write method was called on the FileStore |
109 | 108 | assert (
|
110 |
| - reporter.filestore.write.assert_called_once |
111 |
| - ), "FileStore.write() was not called" |
| 109 | + extracted_config.T == tbeconfig.T |
| 110 | + and extracted_config.E == tbeconfig.E |
| 111 | + and extracted_config.D == tbeconfig.D |
| 112 | + and extracted_config.pooling_params.L == tbeconfig.pooling_params.L |
| 113 | + and extracted_config.batch_params.B == tbeconfig.batch_params.B |
| 114 | + ), "Extracted config does not match the original TBEDataConfig" |
| 115 | + # Attempt to reconstruct TBEDataConfig from extracted_json_config |
0 commit comments