Skip to content

Commit 761658e

Browse files
williamFalconakarnachev
authored and
akarnachev
committed
quick patch __code__ (Lightning-AI#1352)
* quick patch * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix
1 parent b8e094a commit 761658e

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

pytorch_lightning/trainer/model_hooks.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,17 @@ def is_overriden(self, method_name: str, model: LightningModule = None) -> bool:
2020
# in case of calling deprecated method
2121
return False
2222

23-
# when code pointers are different, it was overriden
24-
is_overriden = getattr(model, method_name).__code__ is not getattr(super_object, method_name).__code__
23+
instance_attr = getattr(model, method_name)
24+
super_attr = getattr(super_object, method_name)
25+
26+
# when code pointers are different, it was implemented
27+
if hasattr(instance_attr, 'patch_loader_code'):
28+
# cannot pickle __code__ so cannot verify if PatchDataloader
29+
# exists which shows dataloader methods have been overwritten.
30+
# so, we hack it by using the string representation
31+
is_overriden = instance_attr.patch_loader_code != str(super_attr.__code__)
32+
else:
33+
is_overriden = instance_attr.__code__ is not super_attr.__code__
2534
return is_overriden
2635

2736
def has_arg(self, f_name, arg_name):

pytorch_lightning/trainer/trainer.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -971,8 +971,10 @@ class _PatchDataLoader(object):
971971
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
972972
self.dataloader = dataloader
973973

974-
# Assign __code__, needed for checking if method has been overriden
975-
self.__code__ = self.__call__.__code__
974+
# cannot pickle __code__ so cannot verify if PatchDataloader
975+
# exists which shows dataloader methods have been overwritten.
976+
# so, we hack it by using the string representation
977+
self.patch_loader_code = str(self.__call__.__code__)
976978

977979
def __call__(self) -> Union[List[DataLoader], DataLoader]:
978980
return self.dataloader

0 commit comments

Comments
 (0)