
Commit c7c0972

jansel authored and pytorchmergebot committed
Move TorchDynamo into PyTorch core (pytorch#86461)
Context: pytorch/torchdynamo#1588

This PR moves [TorchDynamo](https://github.com/pytorch/torchdynamo) and TorchInductor into PyTorch core.
- `torchdynamo` becomes `torch._dynamo`
- `torchinductor` becomes `torch._inductor`

This PR was generated by running `copy_to_core.sh` in pytorch/torchdynamo#1538

Pull Request resolved: pytorch#86461
Approved by: https://github.com/voznesenskym
1 parent 97abc21 commit c7c0972

File tree

308 files changed: +85171 −14 lines


.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat

+1 −1
```diff
@@ -36,7 +36,7 @@ popd
 =======
 :: Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014
 
-pip install "ninja==1.10.0.post1" future "hypothesis==5.35.1" "expecttest==0.1.3" "librosa>=0.6.2" "scipy==1.6.3" psutil pillow "unittest-xml-reporting<=3.2.0,>=2.0.0" pytest pytest-xdist pytest-shard pytest-rerunfailures "xdoctest==1.0.2" "pygments==2.12.0" "opt-einsum>=3.3"
+pip install "ninja==1.10.0.post1" future "hypothesis==5.35.1" "expecttest==0.1.3" "librosa>=0.6.2" "scipy==1.6.3" psutil pillow "unittest-xml-reporting<=3.2.0,>=2.0.0" pytest pytest-xdist pytest-shard pytest-rerunfailures sympy "xdoctest==1.0.2" "pygments==2.12.0" "opt-einsum>=3.3"
 if errorlevel 1 exit /b
 if not errorlevel 0 exit /b
```

benchmarks/dynamo/README.md

+50 (new file)

# TorchDynamo Benchmarks

## What We Benchmark

TorchDynamo provides a benchmark harness that takes care of uniformly benchmarking different models. It interleaves runs of eager and dynamo to avoid machine noise/variability issues, and reports results as medians along with p-values.
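
A minimal, hypothetical sketch of that interleaving idea (not the harness's actual code): eager and dynamo runs alternate within each iteration so that machine-state drift affects both sides roughly equally, and the result is a ratio of medians plus a t-test p-value. `eager_fn` and `dynamo_fn` are stand-in callables.

```python
import time
import statistics
from scipy.stats import ttest_ind

def interleaved_bench(eager_fn, dynamo_fn, repeat=50):
    """Time two variants in alternation and compare their medians."""
    eager_times, dynamo_times = [], []
    for _ in range(repeat):
        # Alternating within each iteration spreads noise across both variants.
        t0 = time.perf_counter()
        eager_fn()
        t1 = time.perf_counter()
        dynamo_fn()
        t2 = time.perf_counter()
        eager_times.append(t1 - t0)
        dynamo_times.append(t2 - t1)
    speedup = statistics.median(eager_times) / statistics.median(dynamo_times)
    # p-value for the hypothesis that the two timing distributions differ
    pvalue = ttest_ind(eager_times, dynamo_times).pvalue
    return speedup, pvalue
```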

The runner integrates with models from the TorchBench, HuggingFace, and TIMM suites and covers both training and inference.

The infrastructure allows us to specify a loss function. For torchbench models, we use a `.sum().backward()` call in place of the native loss function. For TIMM models, we use a cross-entropy loss. HF models contain a loss function inside the model itself, so we don't need any special loss computation handling.

Training benchmarks approximate training by running the model forward, computing the loss, and then running backward. We entirely skip the optimizer step today.
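
As a rough sketch of one such training step under the suite conventions described above (hypothetical code; `model`, `inputs`, and `suite` are stand-ins, and real models vary in output structure):

```python
import torch
import torch.nn.functional as F

def training_step(model, inputs, suite):
    """Forward + loss + backward, deliberately skipping the optimizer step."""
    if suite == "huggingface":
        # HF models compute their own loss when labels are part of the inputs.
        loss = model(**inputs).loss
    else:
        out = model(*inputs)
        if suite == "timm":
            # TIMM models: cross-entropy against stand-in targets.
            targets = torch.zeros(out.shape[0], dtype=torch.long, device=out.device)
            loss = F.cross_entropy(out, targets)
        else:
            # torchbench models: .sum() stands in for a native loss.
            loss = out.sum()
    loss.backward()
```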

Both inference and training benchmarks measure correctness by comparing dynamo and eager model outputs given fixed inputs and seeds.
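
A minimal sketch of that comparison, assuming single-tensor outputs and illustrative tolerances (the harness's real tolerance and output handling differ):

```python
import torch

def check_correctness(eager_model, dynamo_model, example_inputs, seed=1337):
    """Run both variants from the same seed and compare outputs elementwise."""
    torch.manual_seed(seed)
    expected = eager_model(*example_inputs)
    torch.manual_seed(seed)  # reset so dropout etc. sample identically
    actual = dynamo_model(*example_inputs)
    return torch.allclose(expected, actual, rtol=1e-3, atol=1e-3)
```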

## Setup

### Machine

We run benchmarks on AWS machines (p4d.24xlarge) with 8x NVIDIA A100 40GB cards. We suggest using CUDA 11.6 for consistency.

### Benchmarks

Make sure to carefully follow the [torchbench installation](https://github.com/pytorch/benchmark#installation) instructions, taking care to build the auxiliary libraries (torchvision, torchtext) from versions matching your pytorch version.

For HF and TIMM models, the scripts install the transformers and timm packages, respectively, on the first run.

## Runbook

### Basic Usage

There are a lot of flags in the benchmark runner, and it can be confusing to know which settings to use or what machine to run it on. To support apples-to-apples comparisons, we have provided the following 'standard' settings in `runner.py`. This script is a wrapper over the common benchmarking infrastructure that simplifies the flags. We will continually update `runner.py` with the latest and most relevant compilers for training and inference. It also provides some graph utilities to visualize and compare results. Some example commands:

**Inference Commands**
* Inference compilers on torchbench models - `python benchmarks/runner.py --suites=torchbench --inference --dtypes=float16`

**Training Commands**
* Training compilers on TIMM models - `python benchmarks/runner.py --suites=timm_models --training --dtypes=float32 --output-dir=timm_logs`
* AOTAutograd training compiler on TIMM models - `python benchmarks/runner.py --suites=timm_models --training --dtypes=float32 --compilers=aot_nvfuser --output-dir=timm_logs`

Running `runner.py` generates a file named `run.sh`. This file contains the actual commands that invoke the common benchmarking infrastructure with the appropriate flags, which brings us to the advanced usage.

### Advanced Usage

One can directly call `torchbench.py`, `huggingface.py`, or `timm_models.py` with the necessary flags. There are a lot of flags in the benchmark runner; some examples follow, and they are subject to change.

**Inference Commands**
* TorchScript NVFuser Inference - `python benchmarks/torchbench.py -dcuda -n100 --speedup-ts`
* TorchInductor CUDA Graphs Inference - `python benchmarks/torchbench.py -dcuda --inductor-settings --float32 -n50 --inductor`

**Training Commands**
* TorchScript (with TorchDynamo capture) NVFuser Training - `python benchmarks/torchbench.py --float32 -dcuda --training --nvfuser --speedup-dynamo-ts --use-eval-mode`
* AOTAutograd TorchScript NVFuser Training - `python benchmarks/torchbench.py --float32 -dcuda --training --nvfuser --accuracy-aot-ts-mincut --use-eval-mode`

The above commands are for torchbench models. You can simply replace `torchbench.py` with `huggingface.py` for HF models, or with `timm_models.py` for TIMM models.

benchmarks/dynamo/__init__.py

Whitespace-only changes.
