Commit 748e52b

Patched docs for torch_compile_tutorial (#2936)
* Patched docs for torch_compile_tutorial

Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent e07d43b commit 748e52b

File tree

1 file changed: +100 -2 lines changed

intermediate_source/torch_compile_tutorial.py

@@ -73,17 +73,21 @@ def foo(x, y):
 
 ######################################################################
 # Alternatively, we can decorate the function.
+t1 = torch.randn(10, 10)
+t2 = torch.randn(10, 10)
 
 @torch.compile
 def opt_foo2(x, y):
     a = torch.sin(x)
     b = torch.cos(y)
     return a + b
-print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10)))
+print(opt_foo2(t1, t2))
 
 ######################################################################
 # We can also optimize ``torch.nn.Module`` instances.
 
+t = torch.randn(10, 100)
+
 class MyModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -94,7 +98,101 @@ def forward(self, x):
 
 mod = MyModule()
 opt_mod = torch.compile(mod)
-print(opt_mod(torch.randn(10, 100)))
+print(opt_mod(t))
+
+######################################################################
+# torch.compile and Nested Calls
+# ------------------------------
+# Nested function calls within the decorated function will also be compiled.
+
+def nested_function(x):
+    return torch.sin(x)
+
+@torch.compile
+def outer_function(x, y):
+    a = nested_function(x)
+    b = torch.cos(y)
+    return a + b
+
+print(outer_function(t1, t2))
+
+######################################################################
+# In the same fashion, when compiling a module, all sub-modules and methods
+# within it that are not in a skip list are also compiled.
+
+class OuterModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.inner_module = MyModule()
+        self.outer_lin = torch.nn.Linear(10, 2)
+
+    def forward(self, x):
+        x = self.inner_module(x)
+        return torch.nn.functional.relu(self.outer_lin(x))
+
+outer_mod = OuterModule()
+opt_outer_mod = torch.compile(outer_mod)
+print(opt_outer_mod(t))
+
+######################################################################
+# We can also prevent some functions from being compiled by using
+# ``torch.compiler.disable``. Suppose you want to disable tracing on just
+# the ``complex_function`` function, but want tracing to resume in
+# ``complex_conjugate``. In this case, you can use the
+# ``torch.compiler.disable(recursive=False)`` option. Otherwise, the
+# default is ``recursive=True``.
+
+def complex_conjugate(z):
+    return torch.conj(z)
+
+@torch.compiler.disable(recursive=False)
+def complex_function(real, imag):
+    # Assume this function causes problems during compilation.
+    z = torch.complex(real, imag)
+    return complex_conjugate(z)
+
+def outer_function():
+    real = torch.tensor([2, 3], dtype=torch.float32)
+    imag = torch.tensor([4, 5], dtype=torch.float32)
+    z = complex_function(real, imag)
+    return torch.abs(z)
+
+# Try to compile the outer_function
+try:
+    opt_outer_function = torch.compile(outer_function)
+    print(opt_outer_function())
+except Exception as e:
+    print("Compilation of outer_function failed:", e)
+
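For contrast, a minimal sketch (not part of the patch itself) of what the default ``recursive=True`` does in the same situation: the disable then propagates to everything ``complex_function`` calls, so ``complex_conjugate`` would run eagerly as well. The function name below is illustrative only.

# Sketch: with the default recursive=True, the disable also covers every
# function called from the decorated one.
@torch.compiler.disable
def complex_function_fully_disabled(real, imag):
    z = torch.complex(real, imag)
    return complex_conjugate(z)  # also skipped by the compiler here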
+######################################################################
+# Best Practices and Recommendations
+# ----------------------------------
+#
+# Behavior of ``torch.compile`` with Nested Modules and Function Calls
+#
+# When you use ``torch.compile``, the compiler will try to recursively compile
+# every function call inside the target function or module that is not in a
+# skip list (such as built-ins and some functions in the ``torch.*`` namespace).
+#
+# **Best Practices:**
+#
+# 1. **Top-Level Compilation:** One approach is to compile at the highest level
+#    possible (i.e., when the top-level module is initialized/called) and
+#    selectively disable compilation when encountering excessive graph breaks or
+#    errors. If there are still many compile issues, compile individual
+#    subcomponents instead (see the first sketch after this list).
+#
+# 2. **Modular Testing:** Test individual functions and modules with ``torch.compile``
+#    before integrating them into larger models to isolate potential issues
+#    (see the second sketch after this list).
+#
+# 3. **Disable Compilation Selectively:** If certain functions or sub-modules
+#    cannot be handled by ``torch.compile``, use the ``torch.compiler.disable``
+#    context manager to recursively exclude them from compilation.
+#
+# 4. **Compile Leaf Functions First:** In complex models with multiple nested
+#    functions and modules, start by compiling the leaf functions or modules
+#    first (see the third sketch after this list). For more information see
+#    `TorchDynamo APIs for fine-grained tracing <https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html>`__.
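To make practices 1 and 3 concrete, a minimal sketch (not part of the patch itself): compile once at the top level and mark a problematic helper with ``torch.compiler.disable`` so it stays eager. ``BigModel`` and ``data_dependent_helper`` are illustrative names, not from the tutorial.

class BigModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(10, 10)

    @torch.compiler.disable
    def data_dependent_helper(self, x):
        # Stand-in for code that graph-breaks or fails to compile;
        # the disable keeps it running eagerly.
        return x.clamp(min=0)

    def forward(self, x):
        return self.data_dependent_helper(self.lin(x))

# Compile at the highest level; the disabled helper is excluded.
opt_big = torch.compile(BigModel())
print(opt_big(torch.randn(4, 10)))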
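A sketch of practice 2 (again illustrative, not part of the patch): exercise one sub-module under ``torch.compile`` on its own, comparing compiled output against eager output, before wiring it into a larger model. It reuses ``MyModule`` from the tutorial above.

# Test a single component in isolation before integrating it.
sub = MyModule()
opt_sub = torch.compile(sub)
x = torch.randn(10, 100)
print(torch.allclose(sub(x), opt_sub(x), atol=1e-5))  # expect True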
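Finally, a sketch of practice 4 (illustrative, not part of the patch): compile a leaf function on its own and call it from eager code; ``leaf_fn`` and ``uncompiled_pipeline`` are made-up names.

# Compile only the leaf; the caller stays eager.
@torch.compile
def leaf_fn(x):
    return torch.sin(x) + torch.cos(x)

def uncompiled_pipeline(x):
    return leaf_fn(x) * 2  # eager code invoking the compiled leaf

print(uncompiled_pipeline(torch.randn(10)))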
 
 ######################################################################
 # Demonstrating Speedups
