Skip to content

Commit 2688bfc

Browse files
committedMar 3, 2025·
Fix docstring
1 parent ff427e5 commit 2688bfc

File tree

1 file changed

+26
-25
lines changed

1 file changed

+26
-25
lines changed
 

‎keras/src/backend/common/remat.py

+26-25
Original file line numberDiff line numberDiff line change
@@ -156,30 +156,31 @@ def remat(f):
156156
pass, the forward computation is recomputed as needed.
157157
158158
Example:
159-
```python
160-
from keras import Model
161-
class CustomRematLayer(layers.Layer):
162-
def __init__(self, **kwargs):
163-
super().__init__(**kwargs)
164-
self.remat_function = remat(self.intermediate_function)
165-
166-
def intermediate_function(self, x):
167-
for _ in range(2):
168-
x = x + x * 0.1 # Simple scaled transformation
169-
return x
170-
171-
def call(self, inputs):
172-
return self.remat_function(inputs)
173-
174-
# Define a simple model using the custom layer
175-
inputs = layers.Input(shape=(4,))
176-
x = layers.Dense(4, activation="relu")(inputs)
177-
x = CustomRematLayer()(x) # Custom layer with rematerialization
178-
outputs = layers.Dense(1)(x)
179-
180-
# Create and compile the model
181-
model = Model(inputs=inputs, outputs=outputs)
182-
model.compile(optimizer="sgd", loss="mse")
183-
```
159+
160+
```python
161+
from keras import Model
162+
class CustomRematLayer(layers.Layer):
163+
def __init__(self, **kwargs):
164+
super().__init__(**kwargs)
165+
self.remat_function = remat(self.intermediate_function)
166+
167+
def intermediate_function(self, x):
168+
for _ in range(2):
169+
x = x + x * 0.1 # Simple scaled transformation
170+
return x
171+
172+
def call(self, inputs):
173+
return self.remat_function(inputs)
174+
175+
# Define a simple model using the custom layer
176+
inputs = layers.Input(shape=(4,))
177+
x = layers.Dense(4, activation="relu")(inputs)
178+
x = CustomRematLayer()(x) # Custom layer with rematerialization
179+
outputs = layers.Dense(1)(x)
180+
181+
# Create and compile the model
182+
model = Model(inputs=inputs, outputs=outputs)
183+
model.compile(optimizer="sgd", loss="mse")
184+
```
184185
"""
185186
return backend.core.remat(f)

0 commit comments

Comments
 (0)
Please sign in to comment.