Skip to content

Commit 0f5715b

Browse files
author
Dave Berenbaum
authoredAug 5, 2024··
support lists for log_plot y val (#837)
1 parent c8c32d0 commit 0f5715b

File tree

3 files changed

+32
-4
lines changed

3 files changed

+32
-4
lines changed
 

‎src/dvclive/live.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ def log_plot(
562562
name: str,
563563
datapoints: Union[pd.DataFrame, np.ndarray, List[Dict]],
564564
x: str,
565-
y: str,
565+
y: Union[str, list[str]],
566566
template: Optional[str] = "linear",
567567
title: Optional[str] = None,
568568
x_label: Optional[str] = None,
@@ -579,7 +579,8 @@ def log_plot(
579579
datapoints (pd.DataFrame | np.ndarray | List[Dict]): Pandas DataFrame, Numpy
580580
Array or List of dictionaries containing the data for the plot.
581581
x (str): name of the key (present in the dictionaries) to use as the x axis.
582-
y (str): name of the key (present in the dictionaries) to use the y axis.
582+
y (str | list[str]): name of the key or keys (present in the
583+
dictionaries) to use the y axis.
583584
template (str): name of the `DVC plots template` to use. Defaults to
584585
`"linear"`.
585586
title (str): title to be displayed. Defaults to

‎src/dvclive/plots/custom.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Optional
2+
from typing import Optional, Union
33

44
from dvclive.serialize import dump_json
55

@@ -15,7 +15,7 @@ def __init__(
1515
name: str,
1616
output_folder: str,
1717
x: str,
18-
y: str,
18+
y: Union[str, list[str]],
1919
template: Optional[str],
2020
title: Optional[str] = None,
2121
x_label: Optional[str] = None,

‎tests/plots/test_custom.py

+27
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,30 @@ def test_log_custom_plot(tmp_dir):
2929
"x_label": "x_label",
3030
"y_label": "y_label",
3131
}
32+
33+
34+
def test_log_custom_plot_multi_y(tmp_dir):
35+
live = Live()
36+
out = tmp_dir / live.plots_dir / CustomPlot.subfolder
37+
38+
datapoints = [{"x": 1, "y1": 2, "y2": 3}, {"x": 4, "y1": 5, "y2": 6}]
39+
live.log_plot(
40+
"custom_linear",
41+
datapoints,
42+
x="x",
43+
y=["y1", "y2"],
44+
template="linear",
45+
title="custom_title",
46+
x_label="x_label",
47+
y_label="y_label",
48+
)
49+
50+
assert json.loads((out / "custom_linear.json").read_text()) == datapoints
51+
assert live._plots["custom_linear"].plot_config == {
52+
"template": "linear",
53+
"title": "custom_title",
54+
"x": "x",
55+
"y": ["y1", "y2"],
56+
"x_label": "x_label",
57+
"y_label": "y_label",
58+
}

0 commit comments

Comments
 (0)
Please sign in to comment.