Skip to content

Commit 2eca8a9

Browse files
quick patch __code__ (#1352)
* quick patch * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix
1 parent 1576ad9 commit 2eca8a9

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

pytorch_lightning/trainer/model_hooks.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,17 @@ def is_overriden(self, method_name: str, model: LightningModule = None) -> bool:
2020
# in case of calling deprecated method
2121
return False
2222

23-
# when code pointers are different, it was overriden
24-
is_overriden = getattr(model, method_name).__code__ is not getattr(super_object, method_name).__code__
23+
instance_attr = getattr(model, method_name)
24+
super_attr = getattr(super_object, method_name)
25+
26+
# when code pointers are different, it was implemented
27+
if hasattr(instance_attr, 'patch_loader_code'):
28+
# cannot pickle __code__ so cannot verify if PatchDataloader
29+
# exists which shows dataloader methods have been overwritten.
30+
# so, we hack it by using the string representation
31+
is_overriden = instance_attr.patch_loader_code != str(super_attr.__code__)
32+
else:
33+
is_overriden = instance_attr.__code__ is not super_attr.__code__
2534
return is_overriden
2635

2736
def has_arg(self, f_name, arg_name):

pytorch_lightning/trainer/trainer.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -970,8 +970,10 @@ class _PatchDataLoader(object):
970970
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
971971
self.dataloader = dataloader
972972

973-
# Assign __code__, needed for checking if method has been overriden
974-
self.__code__ = self.__call__.__code__
973+
# cannot pickle __code__ so cannot verify if PatchDataloader
974+
# exists which shows dataloader methods have been overwritten.
975+
# so, we hack it by using the string representation
976+
self.patch_loader_code = str(self.__call__.__code__)
975977

976978
def __call__(self) -> Union[List[DataLoader], DataLoader]:
977979
return self.dataloader

0 commit comments

Comments
 (0)