Commit e67ac8f 1 parent 3d20616 commit e67ac8f Copy full SHA for e67ac8f
File tree 2 files changed +18
-16
lines changed
2 files changed +18
-16
lines changed Original file line number Diff line number Diff line change 19
19
from keras .src .ops .function import make_node_key
20
20
from keras .src .ops .node import KerasHistory
21
21
from keras .src .ops .node import Node
22
+ from keras .src .ops .operation import Operation
22
23
from keras .src .saving import serialization_lib
23
24
from keras .src .utils import tracking
24
25
@@ -523,6 +524,11 @@ def process_layer(layer_data):
523
524
layer = serialization_lib .deserialize_keras_object (
524
525
layer_data , custom_objects = custom_objects
525
526
)
527
+ if not isinstance (layer , Operation ):
528
+ raise ValueError (
529
+ "Unexpected object from deserialization, expected a layer or "
530
+ f"operation, got a { type (layer )} "
531
+ )
526
532
created_layers [layer_name ] = layer
527
533
528
534
# Gather layer inputs.
Original file line number Diff line number Diff line change @@ -783,22 +783,18 @@ def _retrieve_class_or_fn(
783
783
784
784
# Otherwise, attempt to retrieve the class object given the `module`
785
785
# and `class_name`. Import the module, find the class.
786
- try :
787
- mod = importlib .import_module (module )
788
- except ModuleNotFoundError :
789
- raise TypeError (
790
- f"Could not deserialize { obj_type } '{ name } ' because "
791
- f"its parent module { module } cannot be imported. "
792
- f"Full object config: { full_config } "
793
- )
794
- obj = vars (mod ).get (name , None )
795
-
796
- # Special case for keras.metrics.metrics
797
- if obj is None and registered_name is not None :
798
- obj = vars (mod ).get (registered_name , None )
799
-
800
- if obj is not None :
801
- return obj
786
+ if module == "keras.src" or module .startswith ("keras.src." ):
787
+ try :
788
+ mod = importlib .import_module (module )
789
+ obj = vars (mod ).get (name , None )
790
+ if obj is not None :
791
+ return obj
792
+ except ModuleNotFoundError :
793
+ raise TypeError (
794
+ f"Could not deserialize { obj_type } '{ name } ' because "
795
+ f"its parent module { module } cannot be imported. "
796
+ f"Full object config: { full_config } "
797
+ )
802
798
803
799
raise TypeError (
804
800
f"Could not locate { obj_type } '{ name } '. "
You can’t perform that action at this time.
0 commit comments