@@ -33,7 +33,7 @@ class WandbLogger(LightningLoggerBase):
33
33
"""
34
34
35
35
def __init__ (self , name = None , save_dir = None , offline = False , id = None , anonymous = False ,
36
- version = None , project = None , tags = None , experiment = None ):
36
+ version = None , project = None , tags = None , experiment = None , entity = None ):
37
37
super ().__init__ ()
38
38
self ._name = name
39
39
self ._save_dir = save_dir
@@ -43,6 +43,7 @@ def __init__(self, name=None, save_dir=None, offline=False, id=None, anonymous=F
43
43
self ._project = project
44
44
self ._experiment = experiment
45
45
self ._offline = offline
46
+ self ._entity = entity
46
47
47
48
def __getstate__ (self ):
48
49
state = self .__dict__ .copy ()
@@ -68,7 +69,7 @@ def experiment(self):
68
69
os .environ ["WANDB_MODE" ] = "dryrun"
69
70
self ._experiment = wandb .init (
70
71
name = self ._name , dir = self ._save_dir , project = self ._project , anonymous = self ._anonymous ,
71
- id = self ._id , resume = "allow" , tags = self ._tags )
72
+ id = self ._id , resume = "allow" , tags = self ._tags , entity = self . _entity )
72
73
return self ._experiment
73
74
74
75
def watch (self , model , log = "gradients" , log_freq = 100 ):
0 commit comments