Skip to content

Commit 95ee305

Browse files
q10facebook-github-bot
authored andcommitted
Migrate TBE EEG Python code to OSS (pytorch#3774)
Summary: X-link: facebookresearch/FBGEMM#854 Pull Request resolved: pytorch#3774 - Migrate TBE EEG Python code to OSS Reviewed By: shintaro-iwasaki Differential Revision: D70536525
1 parent 428e671 commit 95ee305

8 files changed

+1033
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-unsafe
9+
10+
from .bench_config import ( # noqa F401
11+
TBEBenchmarkingConfig,
12+
TBEBenchmarkingConfigLoader,
13+
)
14+
from .config import TBEDataConfig # noqa F401
15+
from .config_loader import TBEDataConfigLoader # noqa F401
16+
from .config_param_models import BatchParams, IndicesParams, PoolingParams # noqa F401
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import dataclasses
11+
import json
12+
from typing import Any, Dict, Optional
13+
14+
import click
15+
16+
17+
@dataclasses.dataclass(frozen=True)
18+
class TBEBenchmarkingConfig:
19+
# Number of iterations
20+
iterations: int
21+
# Number of input TBE batches to generate for testing
22+
num_requests: int
23+
# Number of warmup iterations to run before making measurements
24+
warmup_iterations: int
25+
# Amount of memory to use for flushing the GPU cache after each iteration
26+
flush_gpu_cache_size_mb: int
27+
# If set, trace will be exported to the path specified in trace_url
28+
export_trace: bool
29+
# The path for exporting the trace
30+
trace_url: Optional[str]
31+
32+
@classmethod
33+
# pyre-ignore [3]
34+
def from_dict(cls, data: Dict[str, Any]):
35+
return cls(**data)
36+
37+
@classmethod
38+
# pyre-ignore [3]
39+
def from_json(cls, data: str):
40+
return cls.from_dict(json.loads(data))
41+
42+
def dict(self) -> Dict[str, Any]:
43+
return dataclasses.asdict(self)
44+
45+
def json(self, format: bool = False) -> str:
46+
return json.dumps(self.dict(), indent=(2 if format else -1), sort_keys=True)
47+
48+
# pyre-ignore [3]
49+
def validate(self):
50+
assert self.iterations > 0, "iterations must be positive"
51+
assert self.num_requests > 0, "num_requests must be positive"
52+
assert self.warmup_iterations >= 0, "warmup_iterations must be non-negative"
53+
assert (
54+
self.flush_gpu_cache_size_mb >= 0
55+
), "flush_gpu_cache_size_mb must be non-negative"
56+
return self
57+
58+
59+
class TBEBenchmarkingConfigLoader:
60+
@classmethod
61+
# pyre-ignore [2]
62+
def options(cls, func) -> click.Command:
63+
options = [
64+
click.option(
65+
"--bench-iterations",
66+
type=int,
67+
default=100,
68+
help="Number of benchmark iterations to run",
69+
),
70+
click.option(
71+
"--bench-num-requests",
72+
type=int,
73+
default=-1,
74+
help="Number of input batches to generate. If the value is smaller than the number of benchmark iterations, input batches will be re-used",
75+
),
76+
click.option(
77+
"--bench-warmup-iterations",
78+
type=int,
79+
default=0,
80+
help="Number of warmup iterations to run before making measurements",
81+
),
82+
click.option(
83+
"--bench-flush-gpu-cache-size",
84+
type=int,
85+
default=0,
86+
help="Amount of memory to use for flushing the GPU cache after each iteration (MB)",
87+
),
88+
click.option(
89+
"--bench-export-trace",
90+
is_flag=True,
91+
default=False,
92+
help="If set, a trace will be exported",
93+
),
94+
click.option(
95+
"--bench-trace-url",
96+
type=str,
97+
required=False,
98+
default="{emb_op_type}_tbe_{phase}_trace_{ospid}.json",
99+
help="The path for exporting the trace",
100+
),
101+
]
102+
103+
for option in reversed(options):
104+
func = option(func)
105+
return func
106+
107+
@classmethod
108+
def load(cls, context: click.Context) -> TBEBenchmarkingConfig:
109+
params = context.params
110+
111+
iterations = params["bench_iterations"]
112+
num_requests = params["bench_num_requests"]
113+
warmup_iterations = params["bench_warmup_iterations"]
114+
flush_gpu_cache_size = params["bench_flush_gpu_cache_size"]
115+
export_trace = params["bench_export_trace"]
116+
trace_url = params["bench_trace_url"]
117+
118+
# Default the number of TBE requests to number of iterations specified
119+
num_requests = iterations if num_requests == -1 else num_requests
120+
121+
return TBEBenchmarkingConfig(
122+
iterations,
123+
num_requests,
124+
warmup_iterations,
125+
flush_gpu_cache_size,
126+
export_trace,
127+
trace_url,
128+
).validate()

0 commit comments

Comments
 (0)