@@ -73,17 +73,21 @@ def foo(x, y):
73
73
74
74
######################################################################
75
75
# Alternatively, we can decorate the function.
76
+ t1 = torch .randn (10 , 10 )
77
+ t2 = torch .randn (10 , 10 )
76
78
77
79
@torch .compile
78
80
def opt_foo2 (x , y ):
79
81
a = torch .sin (x )
80
82
b = torch .cos (y )
81
83
return a + b
82
- print (opt_foo2 (torch . randn ( 10 , 10 ), torch . randn ( 10 , 10 ) ))
84
+ print (opt_foo2 (t1 , t2 ))
83
85
84
86
######################################################################
85
87
# We can also optimize ``torch.nn.Module`` instances.
86
88
89
+ t = torch .randn (10 , 100 )
90
+
87
91
class MyModule (torch .nn .Module ):
88
92
def __init__ (self ):
89
93
super ().__init__ ()
@@ -94,7 +98,101 @@ def forward(self, x):
94
98
95
99
mod = MyModule ()
96
100
opt_mod = torch .compile (mod )
97
- print (opt_mod (torch .randn (10 , 100 )))
101
+ print (opt_mod (t ))
102
+
103
+ ######################################################################
104
+ # torch.compile and Nested Calls
105
+ # ------------------------------
106
+ # Nested function calls within the decorated function will also be compiled.
107
+
108
+ def nested_function (x ):
109
+ return torch .sin (x )
110
+
111
+ @torch .compile
112
+ def outer_function (x , y ):
113
+ a = nested_function (x )
114
+ b = torch .cos (y )
115
+ return a + b
116
+
117
+ print (outer_function (t1 , t2 ))
118
+
119
+ ######################################################################
120
+ # In the same fashion, when compiling a module all sub-modules and methods
121
+ # within it, that are not in a skip list, are also compiled.
122
+
123
+ class OuterModule (torch .nn .Module ):
124
+ def __init__ (self ):
125
+ super ().__init__ ()
126
+ self .inner_module = MyModule ()
127
+ self .outer_lin = torch .nn .Linear (10 , 2 )
128
+
129
+ def forward (self , x ):
130
+ x = self .inner_module (x )
131
+ return torch .nn .functional .relu (self .outer_lin (x ))
132
+
133
+ outer_mod = OuterModule ()
134
+ opt_outer_mod = torch .compile (outer_mod )
135
+ print (opt_outer_mod (t ))
136
+
137
+ ######################################################################
138
+ # We can also disable some functions from being compiled by using
139
+ # ``torch.compiler.disable``. Suppose you want to disable the tracing on just
140
+ # the ``complex_function`` function, but want to continue the tracing back in
141
+ # ``complex_conjugate``. In this case, you can use
142
+ # ``torch.compiler.disable(recursive=False)`` option. Otherwise, the default is
143
+ # ``recursive=True``.
144
+
145
+ def complex_conjugate (z ):
146
+ return torch .conj (z )
147
+
148
+ @torch .compiler .disable (recursive = False )
149
+ def complex_function (real , imag ):
150
+ # Assuming this function cause problems in the compilation
151
+ z = torch .complex (real , imag )
152
+ return complex_conjugate (z )
153
+
154
+ def outer_function ():
155
+ real = torch .tensor ([2 , 3 ], dtype = torch .float32 )
156
+ imag = torch .tensor ([4 , 5 ], dtype = torch .float32 )
157
+ z = complex_function (real , imag )
158
+ return torch .abs (z )
159
+
160
+ # Try to compile the outer_function
161
+ try :
162
+ opt_outer_function = torch .compile (outer_function )
163
+ print (opt_outer_function ())
164
+ except Exception as e :
165
+ print ("Compilation of outer_function failed:" , e )
166
+
167
+ ######################################################################
168
+ # Best Practices and Recommendations
169
+ # ----------------------------------
170
+ #
171
+ # Behavior of ``torch.compile`` with Nested Modules and Function Calls
172
+ #
173
+ # When you use ``torch.compile``, the compiler will try to recursively compile
174
+ # every function call inside the target function or module inside the target
175
+ # function or module that is not in a skip list (such as built-ins, some functions in
176
+ # the torch.* namespace).
177
+ #
178
+ # **Best Practices:**
179
+ #
180
+ # 1. **Top-Level Compilation:** One approach is to compile at the highest level
181
+ # possible (i.e., when the top-level module is initialized/called) and
182
+ # selectively disable compilation when encountering excessive graph breaks or
183
+ # errors. If there are still many compile issues, compile individual
184
+ # subcomponents instead.
185
+ #
186
+ # 2. **Modular Testing:** Test individual functions and modules with ``torch.compile``
187
+ # before integrating them into larger models to isolate potential issues.
188
+ #
189
+ # 3. **Disable Compilation Selectively:** If certain functions or sub-modules
190
+ # cannot be handled by `torch.compile`, use the `torch.compiler.disable` context
191
+ # managers to recursively exclude them from compilation.
192
+ #
193
+ # 4. **Compile Leaf Functions First:** In complex models with multiple nested
194
+ # functions and modules, start by compiling the leaf functions or modules first.
195
+ # For more information see `TorchDynamo APIs for fine-grained tracing <https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html>`__.
98
196
99
197
######################################################################
100
198
# Demonstrating Speedups
0 commit comments