Skip to content

Commit e67ac8f

Browse files
authored
Add checks to deserialization. (#20751)
In particular for functional models.
1 parent 3d20616 commit e67ac8f

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

keras/src/models/functional.py

+6
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from keras.src.ops.function import make_node_key
2020
from keras.src.ops.node import KerasHistory
2121
from keras.src.ops.node import Node
22+
from keras.src.ops.operation import Operation
2223
from keras.src.saving import serialization_lib
2324
from keras.src.utils import tracking
2425

@@ -523,6 +524,11 @@ def process_layer(layer_data):
523524
layer = serialization_lib.deserialize_keras_object(
524525
layer_data, custom_objects=custom_objects
525526
)
527+
if not isinstance(layer, Operation):
528+
raise ValueError(
529+
"Unexpected object from deserialization, expected a layer or "
530+
f"operation, got a {type(layer)}"
531+
)
526532
created_layers[layer_name] = layer
527533

528534
# Gather layer inputs.

keras/src/saving/serialization_lib.py

+12-16
Original file line numberDiff line numberDiff line change
@@ -783,22 +783,18 @@ def _retrieve_class_or_fn(
783783

784784
# Otherwise, attempt to retrieve the class object given the `module`
785785
# and `class_name`. Import the module, find the class.
786-
try:
787-
mod = importlib.import_module(module)
788-
except ModuleNotFoundError:
789-
raise TypeError(
790-
f"Could not deserialize {obj_type} '{name}' because "
791-
f"its parent module {module} cannot be imported. "
792-
f"Full object config: {full_config}"
793-
)
794-
obj = vars(mod).get(name, None)
795-
796-
# Special case for keras.metrics.metrics
797-
if obj is None and registered_name is not None:
798-
obj = vars(mod).get(registered_name, None)
799-
800-
if obj is not None:
801-
return obj
786+
if module == "keras.src" or module.startswith("keras.src."):
787+
try:
788+
mod = importlib.import_module(module)
789+
obj = vars(mod).get(name, None)
790+
if obj is not None:
791+
return obj
792+
except ModuleNotFoundError:
793+
raise TypeError(
794+
f"Could not deserialize {obj_type} '{name}' because "
795+
f"its parent module {module} cannot be imported. "
796+
f"Full object config: {full_config}"
797+
)
802798

803799
raise TypeError(
804800
f"Could not locate {obj_type} '{name}'. "

0 commit comments

Comments
 (0)