diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f390076..745b9a5 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -6,7 +6,7 @@ repos:
       - id: end-of-file-fixer
       - id: trailing-whitespace
   - repo: https://github.com/pycqa/isort
-    rev: 5.10.1
+    rev: 5.12.0
     hooks:
       - id: isort
         args: [--settings-path, pyproject.toml]
diff --git a/pyproject.toml b/pyproject.toml
index 6122a80..dffca27 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ python = "^3.7"
 
 # Python lack of functionalities from future versions
 importlib-metadata = { version = "*", python = "<3.8" }
 
-neptune-client = ">=0.10.0"
+neptune-client = ">=0.16.16"
 numpy = "<1.24.0"
 fvcore = "<0.1.5.post20221220"
diff --git a/src/neptune_detectron2/impl/__init__.py b/src/neptune_detectron2/impl/__init__.py
index 3c4d43a..68da0b0 100644
--- a/src/neptune_detectron2/impl/__init__.py
+++ b/src/neptune_detectron2/impl/__init__.py
@@ -18,6 +18,8 @@
 The class is used for automatic metadata logging to Neptune, during training
 and validation of detectron2 models."""
 
+from __future__ import annotations
+
 __all__ = [
     "__version__",
     "NeptuneHook",
@@ -43,8 +45,10 @@
 import detectron2
 from detectron2.checkpoint import Checkpointer
 from detectron2.engine import hooks
+from neptune.new.handler import Handler
 from neptune.new.metadata_containers import Run
 from neptune.new.types import File
+from neptune.new.utils import stringify_unsupported
 from torch.nn import Module
 
 from neptune_detectron2.impl.version import __version__
@@ -56,7 +60,7 @@ class NeptuneHook(hooks.HookBase):
     """Hook implementation that sends the logs to Neptune.
 
     Args:
-        run: Pass a Neptune run object if you want to continue logging to an existing run.
+        run: Pass a Neptune run or namespace handler object if you want to continue logging to an existing run.
             Learn more about resuming runs in the docs: https://docs.neptune.ai/logging/to_existing_object
         base_namespace: In the Neptune run, the root namespace that will contain all the logged metadata.
         smoothing_window_size: How often NeptuneHook should log metrics (and checkpoints, if
@@ -90,7 +94,7 @@ class NeptuneHook(hooks.HookBase):
     def __init__(
         self,
         *,
-        run: Optional[Run] = None,
+        run: Optional[Run | Handler] = None,
         base_namespace: str = "training",
         smoothing_window_size: int = 20,
         log_model: bool = False,
@@ -103,15 +107,17 @@ def __init__(
 
         self._verify_window_size()
 
-        self._run = neptune.init_run(**kwargs) if not isinstance(run, Run) else run
+        self._run = neptune.init_run(**kwargs) if not run else run
 
-        verify_type("run", self._run, Run)
+        verify_type("run", self._run, (Run, Handler))
 
         if base_namespace.endswith("/"):
-            self._base_namespace = base_namespace[:-1]
+            base_namespace = base_namespace[:-1]
 
         self.base_handler = self._run[base_namespace]
 
+        self._root_object = self._run.get_root_object() if isinstance(self._run, Handler) else self._run
+
     def _verify_window_size(self) -> None:
         if self._window_size <= 0:
             raise ValueError(f"Update freq should be greater than 0. Got {self._window_size}.")
@@ -119,11 +125,11 @@ def _verify_window_size(self) -> None:
             raise TypeError(f"Smoothing window size should be of type int. Got {type(self._window_size)} instead.")
 
     def _log_integration_version(self) -> None:
-        self.base_handler[INTEGRATION_VERSION_KEY] = detectron2.__version__
+        self._root_object[INTEGRATION_VERSION_KEY] = detectron2.__version__
 
     def _log_config(self) -> None:
         if hasattr(self.trainer, "cfg") and isinstance(self.trainer.cfg, dict):
-            self.base_handler["config"] = self.trainer.cfg
+            self.base_handler["config"] = stringify_unsupported(self.trainer.cfg)
 
     def _log_model(self) -> None:
         if hasattr(self.trainer, "model") and isinstance(self.trainer.model, Module):
@@ -148,7 +154,7 @@ def _log_checkpoint(self, final: bool = False) -> None:
     def _log_metrics(self) -> None:
         storage = detectron2.utils.events.get_event_storage()
         for k, (v, _) in storage.latest_with_smoothing_hint(self._window_size).items():
-            self.base_handler[f"metrics/{k}"].log(v)
+            self.base_handler[f"metrics/{k}"].append(v)
 
     def _can_save_checkpoint(self) -> bool:
        return hasattr(self.trainer, "checkpointer") and isinstance(self.trainer.checkpointer, Checkpointer)
@@ -177,5 +183,5 @@ def after_train(self) -> None:
         if self.log_model:
             self._log_checkpoint(final=True)
 
-        self._run.sync()
-        self._run.stop()
+        self._root_object.sync()
+        self._root_object.stop()
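
Usage sketch (not part of the patch): with these changes, NeptuneHook accepts a namespace Handler as well as a Run, so metadata can be logged under a sub-namespace of an existing run, and sync()/stop() are resolved via Handler.get_root_object(). A minimal example, assuming a standard detectron2 DefaultTrainer setup; the project name, namespace, and dataset name below are placeholders:

    import neptune.new as neptune
    from detectron2 import model_zoo
    from detectron2.config import get_cfg
    from detectron2.engine import DefaultTrainer

    from neptune_detectron2 import NeptuneHook

    # Resume or create a run; "workspace/project" is a placeholder.
    run = neptune.init_run(project="workspace/project")

    # Passing a namespace handler (run["finetuning"]) instead of the run itself
    # is what this patch enables; the hook finds the root Run object via
    # Handler.get_root_object() when it needs to sync and stop.
    hook = NeptuneHook(run=run["finetuning"], log_model=True, smoothing_window_size=20)

    # A stock model-zoo config; assumes "my_dataset_train" was registered beforehand.
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml"))
    cfg.DATASETS.TRAIN = ("my_dataset_train",)

    trainer = DefaultTrainer(cfg)
    trainer.register_hooks([hook])
    trainer.train()

Note also that metric series are now created with append() rather than the deprecated .log(), and trainer.cfg is passed through stringify_unsupported() so config values of types Neptune cannot store natively are logged as strings instead of raising.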