Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Comet fix #481

Merged
merged 8 commits into from
Nov 12, 2019
1 change: 1 addition & 0 deletions .run_local_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
rm -rf _ckpt_*
rm -rf tests/save_dir*
rm -rf tests/mlruns_*
rm -rf tests/cometruns*
rm -rf tests/tests/*
rm -rf lightning_logs
coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules
Expand Down
17 changes: 16 additions & 1 deletion docs/Trainer/Logging.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,27 @@ def any_lightning_module_function_or_hook(...):

Log using [comet](https://www.comet.ml)

Comet logger can be used in either online or offline mode.
To log in online mode, CometLogger requries an API key:
```{.python}
from pytorch_lightning.logging import CometLogger
# arguments made to CometLogger are passed on to the comet_ml.Experiment class
comet_logger = CometLogger(
api_key=os.environ["COMET_KEY"],
workspace=os.environ["COMET_WORKSPACE"],
workspace=os.environ["COMET_WORKSPACE"], # Optional
project_name="default_project", # Optional
rest_api_key=os.environ["COMET_REST_KEY"], # Optional
experiment_name="default" # Optional
)
trainer = Trainer(logger=comet_logger)
```
To log in offline mode, CometLogger requires a path to a local directory:
```{.python}
from pytorch_lightning.logging import CometLogger
# arguments made to CometLogger are passed on to the comet_ml.Experiment class
comet_logger = CometLogger(
save_dir=".",
workspace=os.environ["COMET_WORKSPACE"], # Optional
project_name="default_project", # Optional
rest_api_key=os.environ["COMET_REST_KEY"], # Optional
experiment_name="default" # Optional
Expand Down
60 changes: 43 additions & 17 deletions pytorch_lightning/logging/comet_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

try:
from comet_ml import Experiment as CometExperiment
from comet_ml import OfflineExperiment as CometOfflineExperiment
from comet_ml.papi import API
except ImportError:
raise ImportError('Missing comet_ml package.')
Expand All @@ -14,12 +15,14 @@


class CometLogger(LightningLoggerBase):
def __init__(self, api_key, workspace, rest_api_key=None, project_name=None, experiment_name=None, **kwargs):
def __init__(self, api_key=None, save_dir=None, workspace=None,
rest_api_key=None, project_name=None, experiment_name=None, **kwargs):
"""
Initialize a Comet.ml logger
Initialize a Comet.ml logger. Requires either an API Key (online mode) or a local directory path (offline mode)

:param str api_key: API key, found on Comet.ml
:param str workspace: Name of workspace for this user
:param str api_key: Required in online mode. API key, found on Comet.ml
:param str save_dir: Required in offline mode. The path for the directory to save local comet logs
:param str workspace: Optional. Name of workspace for this user
:param str project_name: Optional. Send your experiment to a specific project.
Otherwise will be sent to Uncategorized Experiments.
If project name does not already exists Comet.ml will create a new project.
Expand All @@ -30,10 +33,25 @@ def __init__(self, api_key, workspace, rest_api_key=None, project_name=None, exp
super().__init__()
self._experiment = None

self.api_key = api_key
# Determine online or offline mode based on which arguments were passed to CometLogger
if save_dir is not None and api_key is not None:
# If arguments are passed for both save_dir and api_key, preference is given to online mode
self.mode = "online"
self.api_key = api_key
elif api_key is not None:
self.mode = "online"
self.api_key = api_key
elif save_dir is not None:
self.mode = "offline"
self.save_dir = save_dir
else:
# If neither api_key nor save_dir are passed as arguments, raise an exception
raise Exception("CometLogger requires either api_key or save_dir during initialization.")

logger.info(f"CometLogger will be initialized in {self.mode} mode")

self.workspace = workspace
self.project_name = project_name

self._kwargs = kwargs

if rest_api_key is not None:
Expand All @@ -46,7 +64,7 @@ def __init__(self, api_key, workspace, rest_api_key=None, project_name=None, exp

if experiment_name:
try:
self._set_experiment_name(experiment_name)
self.name = experiment_name
except TypeError as e:
logger.exception("Failed to set experiment name for comet.ml logger")

Expand All @@ -55,12 +73,20 @@ def experiment(self):
if self._experiment is not None:
return self._experiment

self._experiment = CometExperiment(
api_key=self.api_key,
workspace=self.workspace,
project_name=self.project_name,
**self._kwargs
)
if self.mode == "online":
self._experiment = CometExperiment(
api_key=self.api_key,
workspace=self.workspace,
project_name=self.project_name,
**self._kwargs
)
else:
self._experiment = CometOfflineExperiment(
offline_directory=self.save_dir,
workspace=self.workspace,
project_name=self.project_name,
**self._kwargs
)

return self._experiment

Expand All @@ -81,14 +107,14 @@ def log_metrics(self, metrics, step_num=None):
def finalize(self, status):
self.experiment.end()

@rank_zero_only
def _set_experiment_name(self, experiment_name):
self.experiment.set_name(experiment_name)

@property
def name(self):
return self.experiment.project_name

@name.setter
def name(self, value):
self.experiment.set_name(value)

@property
def version(self):
if self.project_name and self.rest_api_key:
Expand Down
23 changes: 16 additions & 7 deletions tests/test_y_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,12 @@ def test_comet_logger():
hparams = testing_utils.get_hparams()
model = LightningTestModel(hparams)

# API key for dummy Comet.ml account
root_dir = os.path.dirname(os.path.realpath(__file__))
comet_dir = os.path.join(root_dir, "cometruns")

# We test CometLogger in offline mode with local saves
logger = CometLogger(
api_key="KnmgASRHHyxWXOpwUfgrAFz8C",
save_dir=comet_dir,
project_name="general",
workspace="dummy-test",
)
Expand All @@ -170,10 +173,12 @@ def test_comet_logger():
print('result finished')
assert result == 1, "Training failed"

testing_utils.clear_save_dir()


def test_comet_pickle():
"""
verify that pickling trainer with mlflow logger works
verify that pickling trainer with comet logger works
"""
reset_seed()

Expand All @@ -185,11 +190,14 @@ def test_comet_pickle():
hparams = testing_utils.get_hparams()
model = LightningTestModel(hparams)

# API key for dummy Comet.ml account
root_dir = os.path.dirname(os.path.realpath(__file__))
comet_dir = os.path.join(root_dir, "cometruns")

# We test CometLogger in offline mode with local saves
logger = CometLogger(
api_key="KnmgASRHHyxWXOpwUfgrAFz8C",
save_dir=comet_dir,
project_name="general",
workspace="dummy-test"
workspace="dummy-test",
)

trainer_options = dict(
Expand All @@ -202,9 +210,10 @@ def test_comet_pickle():
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({"acc": 1.0})

testing_utils.clear_save_dir()

def test_custom_logger(tmpdir):

def test_custom_logger(tmpdir):
class CustomLogger(LightningLoggerBase):
def __init__(self):
super().__init__()
Expand Down