@@ -156,30 +156,31 @@ def remat(f):
156
156
pass, the forward computation is recomputed as needed.
157
157
158
158
Example:
159
- ```python
160
- from keras import Model
161
- class CustomRematLayer(layers.Layer):
162
- def __init__(self, **kwargs):
163
- super().__init__(**kwargs)
164
- self.remat_function = remat(self.intermediate_function)
165
-
166
- def intermediate_function(self, x):
167
- for _ in range(2):
168
- x = x + x * 0.1 # Simple scaled transformation
169
- return x
170
-
171
- def call(self, inputs):
172
- return self.remat_function(inputs)
173
-
174
- # Define a simple model using the custom layer
175
- inputs = layers.Input(shape=(4,))
176
- x = layers.Dense(4, activation="relu")(inputs)
177
- x = CustomRematLayer()(x) # Custom layer with rematerialization
178
- outputs = layers.Dense(1)(x)
179
-
180
- # Create and compile the model
181
- model = Model(inputs=inputs, outputs=outputs)
182
- model.compile(optimizer="sgd", loss="mse")
183
- ```
159
+
160
+ ```python
161
+ from keras import Model
162
+ class CustomRematLayer(layers.Layer):
163
+ def __init__(self, **kwargs):
164
+ super().__init__(**kwargs)
165
+ self.remat_function = remat(self.intermediate_function)
166
+
167
+ def intermediate_function(self, x):
168
+ for _ in range(2):
169
+ x = x + x * 0.1 # Simple scaled transformation
170
+ return x
171
+
172
+ def call(self, inputs):
173
+ return self.remat_function(inputs)
174
+
175
+ # Define a simple model using the custom layer
176
+ inputs = layers.Input(shape=(4,))
177
+ x = layers.Dense(4, activation="relu")(inputs)
178
+ x = CustomRematLayer()(x) # Custom layer with rematerialization
179
+ outputs = layers.Dense(1)(x)
180
+
181
+ # Create and compile the model
182
+ model = Model(inputs=inputs, outputs=outputs)
183
+ model.compile(optimizer="sgd", loss="mse")
184
+ ```
184
185
"""
185
186
return backend .core .remat (f )
0 commit comments