Skip to content

Commit 228e9a8

Browse files
author
Dave Berenbaum
authoredApr 19, 2024··
post to studio in thread to avoid blocking (#814)
* post to studio in thread to avoid blocking * queue for studio data posts * fix test_post_to_studio_if_done_skipped * catch and warn in src/dvclive/studio.py:post_to_studio
1 parent edb5ee3 commit 228e9a8

File tree

3 files changed

+77
-30
lines changed

3 files changed

+77
-30
lines changed
 

‎src/dvclive/live.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import math
77
import os
88
import shutil
9+
import queue
910
import tempfile
11+
import threading
1012

1113
from pathlib import Path, PurePath
1214
from typing import Any, Dict, List, Optional, Set, Tuple, Union, TYPE_CHECKING, Literal
@@ -171,6 +173,7 @@ def __init__(
171173
self._studio_events_to_skip: Set[str] = set()
172174
self._dvc_studio_config: Dict[str, Any] = {}
173175
self._num_points_sent_to_studio: Dict[str, int] = {}
176+
self._studio_queue = None
174177
self._init_studio()
175178

176179
self._system_monitor: Optional[_SystemMonitor] = None # Monitoring thread
@@ -296,7 +299,7 @@ def _init_studio(self):
296299
self._studio_events_to_skip.add("start")
297300
self._studio_events_to_skip.add("done")
298301
else:
299-
self.post_to_studio("start")
302+
post_to_studio(self, "start")
300303

301304
def _init_report(self):
302305
if self._report_mode not in {None, "html", "notebook", "md"}:
@@ -428,7 +431,7 @@ def sync(self):
428431

429432
self.make_report()
430433

431-
self.post_to_studio("data")
434+
self.post_data_to_studio()
432435

433436
def next_step(self):
434437
"""
@@ -880,9 +883,19 @@ def make_dvcyaml(self):
880883
"""
881884
make_dvcyaml(self)
882885

883-
@catch_and_warn(DvcException, logger)
884-
def post_to_studio(self, event: Literal["start", "data", "done"]):
885-
post_to_studio(self, event)
886+
def post_data_to_studio(self):
887+
if not self._studio_queue:
888+
self._studio_queue = queue.Queue()
889+
890+
def worker():
891+
while True:
892+
item = self._studio_queue.get()
893+
post_to_studio(item, "data")
894+
self._studio_queue.task_done()
895+
896+
threading.Thread(target=worker, daemon=True).start()
897+
898+
self._studio_queue.put(self)
886899

887900
def end(self):
888901
"""
@@ -926,7 +939,7 @@ def end(self):
926939
self.save_dvc_exp()
927940

928941
# Mark experiment as done
929-
self.post_to_studio("done")
942+
post_to_studio(self, "done")
930943

931944
cleanup_dvclive_step_completed()
932945

‎src/dvclive/studio.py

+4
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
from pathlib import PureWindowsPath
88
from typing import TYPE_CHECKING, Literal, Mapping
99

10+
from dvc.exceptions import DvcException
1011
from dvc_studio_client.config import get_studio_config
1112
from dvc_studio_client.post_live_metrics import post_live_metrics
1213

14+
from .utils import catch_and_warn
15+
1316
if TYPE_CHECKING:
1417
from dvclive.live import Live
1518
from dvclive.serialize import load_yaml
@@ -96,6 +99,7 @@ def increment_num_points_sent_to_studio(live, plots):
9699
return live
97100

98101

102+
@catch_and_warn(DvcException, logger)
99103
def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa: C901
100104
if event in live._studio_events_to_skip:
101105
return

‎tests/test_post_to_studio.py

+54-24
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from dvclive import Live
1010
from dvclive.env import DVC_EXP_BASELINE_REV, DVC_EXP_NAME, DVC_ROOT
1111
from dvclive.plots import Image, Metric
12-
from dvclive.studio import _adapt_image, get_dvc_studio_config
12+
from dvclive.studio import _adapt_image, get_dvc_studio_config, post_to_studio
1313

1414

1515
def get_studio_call(event_type, exp_name, **kwargs):
@@ -46,7 +46,9 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
4646
)
4747

4848
live.log_metric("foo", 1)
49-
live.next_step()
49+
live.step = 0
50+
live.make_summary()
51+
post_to_studio(live, "data")
5052

5153
mocked_post.assert_called_with(
5254
"https://0.0.0.0/api/live",
@@ -58,8 +60,10 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
5860
),
5961
)
6062

63+
live.step += 1
6164
live.log_metric("foo", 2)
62-
live.next_step()
65+
live.make_summary()
66+
post_to_studio(live, "data")
6367

6468
mocked_post.assert_called_with(
6569
"https://0.0.0.0/api/live",
@@ -72,7 +76,8 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
7276
)
7377

7478
mocked_post.reset_mock()
75-
live.end()
79+
live.save_dvc_exp()
80+
post_to_studio(live, "done")
7681

7782
mocked_post.assert_called_with(
7883
"https://0.0.0.0/api/live",
@@ -118,11 +123,15 @@ def test_post_to_studio_failed_data_request(
118123
error_response.status_code = 400
119124
mocker.patch("requests.post", return_value=error_response)
120125
live.log_metric("foo", 1)
121-
live.next_step()
126+
live.step = 0
127+
live.make_summary()
128+
post_to_studio(live, "data")
122129

123130
mocked_post = mocker.patch("requests.post", return_value=valid_response)
131+
live.step += 1
124132
live.log_metric("foo", 2)
125-
live.next_step()
133+
live.make_summary()
134+
post_to_studio(live, "data")
126135
mocked_post.assert_called_with(
127136
"https://0.0.0.0/api/live",
128137
**get_studio_call(
@@ -154,6 +163,7 @@ def test_post_to_studio_failed_start_request(
154163
live.next_step()
155164

156165
assert mocked_post.call_count == 1
166+
assert live._studio_events_to_skip == {"start", "data", "done"}
157167

158168

159169
def test_post_to_studio_done_only_once(tmp_dir, mocked_dvc_repo, mocked_studio_post):
@@ -210,7 +220,9 @@ def test_post_to_studio_dvc_studio_config(
210220

211221
with Live() as live:
212222
live.log_metric("foo", 1)
213-
live.next_step()
223+
live.step = 0
224+
live.make_summary()
225+
post_to_studio(live, "data")
214226

215227
assert mocked_post.call_args.kwargs["headers"]["Authorization"] == "token token"
216228

@@ -231,7 +243,9 @@ def test_post_to_studio_skip_if_no_token(
231243

232244
with Live() as live:
233245
live.log_metric("foo", 1)
234-
live.next_step()
246+
live.step = 0
247+
live.make_summary()
248+
post_to_studio(live, "data")
235249

236250
assert mocked_post.call_count == 0
237251

@@ -241,7 +255,8 @@ def test_post_to_studio_shorten_names(tmp_dir, mocked_dvc_repo, mocked_studio_po
241255

242256
live = Live()
243257
live.log_metric("eval/loss", 1)
244-
live.next_step()
258+
live.make_summary()
259+
post_to_studio(live, "data")
245260

246261
plots_path = Path(live.plots_dir)
247262
loss_path = (plots_path / Metric.subfolder / "eval/loss.tsv").as_posix()
@@ -269,7 +284,9 @@ def test_post_to_studio_inside_dvc_exp(
269284

270285
with Live() as live:
271286
live.log_metric("foo", 1)
272-
live.next_step()
287+
live.step = 0
288+
live.make_summary()
289+
post_to_studio(live, "data")
273290

274291
call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list]
275292
assert "start" not in call_types
@@ -287,7 +304,8 @@ def test_post_to_studio_inside_subdir(
287304

288305
live = Live()
289306
live.log_metric("foo", 1)
290-
live.next_step()
307+
live.make_summary()
308+
post_to_studio(live, "data")
291309

292310
foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix()
293311

@@ -317,7 +335,8 @@ def test_post_to_studio_inside_subdir_dvc_exp(
317335

318336
live = Live()
319337
live.log_metric("foo", 1)
320-
live.next_step()
338+
live.make_summary()
339+
post_to_studio(live, "data")
321340

322341
foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix()
323342

@@ -370,7 +389,9 @@ def test_post_to_studio_images(tmp_dir, mocked_dvc_repo, mocked_studio_post):
370389

371390
live = Live()
372391
live.log_image("foo.png", ImagePIL.new("RGB", (10, 10), (0, 0, 0)))
373-
live.next_step()
392+
live.step = 0
393+
live.make_summary()
394+
post_to_studio(live, "data")
374395

375396
foo_path = (Path(live.plots_dir) / Image.subfolder / "foo.png").as_posix()
376397

@@ -409,11 +430,13 @@ def test_post_to_studio_name(tmp_dir, mocked_dvc_repo, mocked_studio_post):
409430

410431

411432
def test_post_to_studio_if_done_skipped(tmp_dir, mocked_dvc_repo, mocked_studio_post):
412-
live = Live()
413-
live._studio_events_to_skip.add("start")
414-
live._studio_events_to_skip.add("done")
415-
live.log_metric("foo", 1)
416-
live.end()
433+
with Live() as live:
434+
live._studio_events_to_skip.add("start")
435+
live._studio_events_to_skip.add("done")
436+
live.log_metric("foo", 1)
437+
live.step = 0
438+
live.make_summary()
439+
post_to_studio(live, "data")
417440

418441
mocked_post, _ = mocked_studio_post
419442
call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list]
@@ -439,8 +462,9 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post):
439462
)
440463

441464
live.log_metric("foo", 1)
465+
live.make_summary()
466+
post_to_studio(live, "data")
442467

443-
live.next_step()
444468
mocked_post.assert_called_with(
445469
"https://0.0.0.0/api/live",
446470
**get_studio_call(
@@ -452,9 +476,11 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post):
452476
),
453477
)
454478

479+
live.step += 1
455480
live.log_metric("foo", 2)
481+
live.make_summary()
482+
post_to_studio(live, "data")
456483

457-
live.next_step()
458484
mocked_post.assert_called_with(
459485
"https://0.0.0.0/api/live",
460486
**get_studio_call(
@@ -466,7 +492,7 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post):
466492
),
467493
)
468494

469-
live.end()
495+
post_to_studio(live, "done")
470496
mocked_post.assert_called_with(
471497
"https://0.0.0.0/api/live",
472498
**get_studio_call("done", baseline_sha="0" * 40, exp_name=live._exp_name),
@@ -485,7 +511,9 @@ def test_post_to_studio_skip_if_no_repo_url(
485511

486512
with Live() as live:
487513
live.log_metric("foo", 1)
488-
live.next_step()
514+
live.step = 0
515+
live.make_summary()
516+
post_to_studio(live, "data")
489517

490518
assert mocked_post.call_count == 0
491519

@@ -503,7 +531,8 @@ def test_post_to_studio_repeat_step(tmp_dir, mocked_dvc_repo, mocked_studio_post
503531
live.step = 0
504532
live.log_metric("foo", 1)
505533
live.log_metric("bar", 0.1)
506-
live.sync()
534+
live.make_summary()
535+
post_to_studio(live, "data")
507536

508537
mocked_post.assert_called_with(
509538
"https://0.0.0.0/api/live",
@@ -521,7 +550,8 @@ def test_post_to_studio_repeat_step(tmp_dir, mocked_dvc_repo, mocked_studio_post
521550
live.log_metric("foo", 2)
522551
live.log_metric("foo", 3)
523552
live.log_metric("bar", 0.2)
524-
live.sync()
553+
live.make_summary()
554+
post_to_studio(live, "data")
525555

526556
mocked_post.assert_called_with(
527557
"https://0.0.0.0/api/live",

0 commit comments

Comments
 (0)
Please sign in to comment.