Skip to content

Commit a96e739

Browse files
refactor: Update Live class to handle pathlib.Path object for dvcyaml argument.
1 parent 6888462 commit a96e739

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

Diff for: src/dvclive/live.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282
resume: bool = False,
8383
report: Literal["md", "notebook", "html", None] = None,
8484
save_dvc_exp: bool = True,
85-
dvcyaml: Optional[str] = "dvc.yaml",
85+
dvcyaml: Optional[str, Path] = "dvc.yaml",
8686
cache_images: bool = False,
8787
exp_name: Optional[str] = None,
8888
exp_message: Optional[str] = None,
@@ -104,11 +104,11 @@ def __init__(
104104
part of `Live.end()`. Defaults to `True`. If you are using DVCLive
105105
inside a DVC Pipeline and running with `dvc exp run`, the option will be
106106
ignored.
107-
dvcyaml (str | None): where to write dvc.yaml file, which adds DVC
107+
dvcyaml (str | Path | None): where to write dvc.yaml file, which adds DVC
108108
configuration for metrics, plots, and parameters as part of
109109
`Live.next_step()` and `Live.end()`. If `None`, no dvc.yaml file is
110110
written. Defaults to `"dvc.yaml"`. See `Live.make_dvcyaml()`.
111-
If a string like `"subdir/dvc.yaml"`, DVCLive will write the
111+
If a string or Path like `"subdir/dvc.yaml"`, DVCLive will write the
112112
configuration to that path (file must be named "dvc.yaml").
113113
If `False`, DVCLive will not write to "dvc.yaml" (useful if you are
114114
tracking DVCLive metrics, plots, and parameters independently and
@@ -265,11 +265,16 @@ def _init_dvc(self): # noqa: C901
265265
self._include_untracked.append(self.dir)
266266

267267
def _init_dvc_file(self) -> str:
268-
if isinstance(self._dvcyaml, str):
269-
if os.path.basename(self._dvcyaml) == "dvc.yaml":
270-
return self._dvcyaml
271-
raise InvalidDvcyamlError
272-
return "dvc.yaml"
268+
if self._dvcyaml is None or isinstance(self._dvcyaml, bool):
269+
return "dvc.yaml"
270+
if isinstance(self._dvcyaml, Path):
271+
self._dvcyaml = str(self._dvcyaml)
272+
if (
273+
isinstance(self._dvcyaml, str)
274+
and os.path.basename(self._dvcyaml) == "dvc.yaml"
275+
):
276+
return self._dvcyaml
277+
raise InvalidDvcyamlError
273278

274279
def _init_dvc_pipeline(self):
275280
if os.getenv(env.DVC_EXP_BASELINE_REV, None):
@@ -334,6 +339,8 @@ def _init_test(self):
334339
"""
335340
with tempfile.TemporaryDirectory() as dirpath:
336341
self._dir = os.path.join(dirpath, self._dir)
342+
if isinstance(self._dvcyaml, Path):
343+
self._dvcyaml = str(self._dvcyaml)
337344
if isinstance(self._dvcyaml, str):
338345
self._dvc_file = os.path.join(dirpath, self._dvcyaml)
339346
self._save_dvc_exp = False

Diff for: tests/test_make_dvcyaml.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44
from PIL import Image
5+
from pathlib import Path
56

67
from dvclive import Live
78
from dvclive.dvc import make_dvcyaml
@@ -423,7 +424,7 @@ def test_warn_on_dvcyaml_output_overlap(tmp_dir, mocker, mocked_dvc_repo, dvcyam
423424

424425
@pytest.mark.parametrize(
425426
"dvcyaml",
426-
[True, False, "dvc.yaml"],
427+
[True, False, "dvc.yaml", Path("dvc.yaml")],
427428
)
428429
def test_make_dvcyaml(tmp_dir, mocked_dvc_repo, dvcyaml):
429430
dvclive = Live("logs", dvcyaml=dvcyaml)

0 commit comments

Comments
 (0)