update integration after recent deprecations #6

Merged
merged 6 commits on Jan 30, 2023
Changes from 1 commit
update integration after recent deprecations
AleksanderWWW committed Jan 26, 2023
commit 9990cb3f1adcfd62daa65c438f4b7d290cfa58a5
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -14,7 +14,7 @@ python = "^3.7"
# Python lack of functionalities from future versions
importlib-metadata = { version = "*", python = "<3.8" }

-neptune-client = ">=0.10.0"
+neptune-client = ">=0.16.16"
numpy = "<1.24.0"
fvcore = "<0.1.5.post20221220"
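Not part of the diff: a minimal sketch of what the raised floor means at import time, assuming the packaging library is available in the environment. The integration itself relies on the pyproject constraint above rather than any runtime check like this.

# Illustrative only (not part of this PR): fail fast when the installed
# neptune-client predates the 0.16.16 floor the integration now assumes.
import sys

if sys.version_info >= (3, 8):
    from importlib.metadata import version
else:
    # importlib-metadata backport, pinned above for Python < 3.8
    from importlib_metadata import version

from packaging.version import Version  # packaging is an assumed extra dependency

if Version(version("neptune-client")) < Version("0.16.16"):
    raise RuntimeError("neptune-detectron2 requires neptune-client >= 0.16.16")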

21 changes: 14 additions & 7 deletions src/neptune_detectron2/impl/__init__.py
@@ -18,6 +18,8 @@
The class is used for automatic metadata logging to Neptune, during training and validation of detectron2 models."""

+from __future__ import annotations

__all__ = [
"__version__",
"NeptuneHook",
@@ -43,8 +45,10 @@
import detectron2
from detectron2.checkpoint import Checkpointer
from detectron2.engine import hooks
+from neptune.new.handler import Handler
from neptune.new.metadata_containers import Run
from neptune.new.types import File
+from neptune.new.utils import stringify_unsupported
from torch.nn import Module

from neptune_detectron2.impl.version import __version__
@@ -56,7 +60,7 @@ class NeptuneHook(hooks.HookBase):
"""Hook implementation that sends the logs to Neptune.
Args:
-run: Pass a Neptune run object if you want to continue logging to an existing run.
+run: Pass a Neptune run or handler object if you want to continue logging to an existing run.
Learn more about resuming runs in the docs: https://docs.neptune.ai/logging/to_existing_object
base_namespace: In the Neptune run, the root namespace that will contain all the logged metadata.
smoothing_window_size: How often NeptuneHook should log metrics (and checkpoints, if
@@ -90,7 +94,7 @@ class NeptuneHook(hooks.HookBase):
def __init__(
self,
*,
-run: Optional[Run] = None,
+run: Optional[Run | Handler] = None,
base_namespace: str = "training",
smoothing_window_size: int = 20,
log_model: bool = False,
@@ -103,14 +107,14 @@ def __init__(

self._verify_window_size()

-self._run = neptune.init_run(**kwargs) if not isinstance(run, Run) else run
+self._run = neptune.init_run(**kwargs) if not run else run

-verify_type("run", self._run, Run)
+verify_type("run", self._run, (Run, Handler))

if base_namespace.endswith("/"):
self._base_namespace = base_namespace[:-1]

-self.base_handler = self._run[base_namespace]
+self.base_handler = self._run[self._base_namespace]

def _verify_window_size(self) -> None:
if self._window_size <= 0:
@@ -123,7 +127,7 @@ def _log_integration_version(self) -> None:

def _log_config(self) -> None:
if hasattr(self.trainer, "cfg") and isinstance(self.trainer.cfg, dict):
-self.base_handler["config"] = self.trainer.cfg
+self.base_handler["config"] = stringify_unsupported(self.trainer.cfg)

def _log_model(self) -> None:
if hasattr(self.trainer, "model") and isinstance(self.trainer.model, Module):
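As a side note (not from this PR): stringify_unsupported casts values Neptune cannot store natively, such as tuples or None, to strings instead of raising an error, which is why _log_config now wraps the trainer config. A rough, self-contained sketch with a made-up config-like dict (assumes NEPTUNE_API_TOKEN and NEPTUNE_PROJECT are set):

# Rough sketch of the effect of stringify_unsupported; cfg_like is a stand-in,
# not an actual detectron2 config.
import neptune.new as neptune
from neptune.new.utils import stringify_unsupported

run = neptune.init_run()
cfg_like = {"SOLVER": {"BASE_LR": 0.00025, "STEPS": (30000, 40000)}, "MODEL": {"DEVICE": "cuda"}}
run["training/config"] = stringify_unsupported(cfg_like)  # tuples are stored as strings instead of erroring
run.stop()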
@@ -148,7 +152,7 @@ def _log_checkpoint(self, final: bool = False) -> None:
def _log_metrics(self) -> None:
storage = detectron2.utils.events.get_event_storage()
for k, (v, _) in storage.latest_with_smoothing_hint(self._window_size).items():
-self.base_handler[f"metrics/{k}"].log(v)
+self.base_handler[f"metrics/{k}"].append(v)

def _can_save_checkpoint(self) -> bool:
return hasattr(self.trainer, "checkpointer") and isinstance(self.trainer.checkpointer, Checkpointer)
@@ -177,5 +181,8 @@ def after_train(self) -> None:
if self.log_model:
self._log_checkpoint(final=True)

+if isinstance(self._run, Handler):
+    self._run = self._run.get_root_object()

self._run.sync()
self._run.stop()
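For context, a minimal usage sketch of the hook after these changes; it is not taken from the PR. It assumes NEPTUNE_API_TOKEN and NEPTUNE_PROJECT are set in the environment and that cfg is completed with model and dataset settings elsewhere. Passing a namespace handler instead of the Run is what the new Run | Handler typing and the get_root_object() call in after_train enable.

# Minimal usage sketch (assumptions noted above); not part of the commit shown.
import neptune.new as neptune
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer

from neptune_detectron2 import NeptuneHook

run = neptune.init_run()
cfg = get_cfg()  # model/dataset settings assumed to be merged in elsewhere

trainer = DefaultTrainer(cfg)

# A namespace handler is now accepted; the hook resolves the root Run via
# get_root_object() in after_train() before calling sync() and stop().
hook = NeptuneHook(run=run["detectron2"], log_model=True, smoothing_window_size=20)
trainer.register_hooks([hook])
trainer.train()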