Skip to content

Commit 2048753

Browse files
committed
Poll Based Waiting for Job Completion
`Experiment` was given a `wait` method that takes a collection of Launched Job IDs and will wait until the launch reaches a terminal state by either completing or erroring out. Implements a polling based solution.
1 parent ddde9c5 commit 2048753

File tree

5 files changed

+456
-5
lines changed

5 files changed

+456
-5
lines changed

Diff for: smartsim/_core/control/interval.py

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# BSD 2-Clause License
2+
#
3+
# Copyright (c) 2021-2024, Hewlett Packard Enterprise
4+
# All rights reserved.
5+
#
6+
# Redistribution and use in source and binary forms, with or without
7+
# modification, are permitted provided that the following conditions are met:
8+
#
9+
# 1. Redistributions of source code must retain the above copyright notice, this
10+
# list of conditions and the following disclaimer.
11+
#
12+
# 2. Redistributions in binary form must reproduce the above copyright notice,
13+
# this list of conditions and the following disclaimer in the documentation
14+
# and/or other materials provided with the distribution.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
from __future__ import annotations
28+
29+
import time
30+
import typing as t
31+
32+
Seconds = t.NewType("Seconds", float)
33+
34+
35+
class SynchronousTimeInterval:
36+
"""A utility class to represent and synchronously block the execution of a
37+
thread for an interval of time.
38+
"""
39+
40+
def __init__(self, delta: float | None, strict: bool = True) -> None:
41+
"""Initialize a new `SynchronousTimeInterval` interval
42+
43+
:param delta: The difference in time the interval represents. If
44+
`None`, the interval will represent an infinite amount of time in
45+
seconds.
46+
:param strict: Wether or not to raise in the case where negative time
47+
delta is provided. If `False`, the `SynchronousTimeInterval` and a negative `delta`
48+
is provided, the interval will be infinite.
49+
:raises ValueError: The `delta` is negative and `strict` is `True`
50+
"""
51+
if delta is not None and delta < 0 and strict:
52+
raise ValueError("Timeout value cannot be less than 0")
53+
if delta is None or delta < 0:
54+
delta = float("inf")
55+
self._delta = Seconds(delta)
56+
"""The amount of time, in seconds the interval spans."""
57+
self._start = time.perf_counter()
58+
"""The time of the creation of the interval"""
59+
60+
@property
61+
def delta(self) -> Seconds:
62+
"""The difference in time the interval represents
63+
64+
:returns: The difference in time the interval represents
65+
"""
66+
return self._delta
67+
68+
@property
69+
def elapsed(self) -> Seconds:
70+
"""The amount of time that has passed since the interval was created
71+
72+
:returns: The amount of time that has passed since the interval was
73+
created
74+
"""
75+
return Seconds(time.perf_counter() - self._start)
76+
77+
@property
78+
def remaining(self) -> Seconds:
79+
"""The amount of time remaining in the interval
80+
81+
:returns: The amount of time remaining in the interval
82+
"""
83+
return Seconds(max(self.delta - self.elapsed, 0))
84+
85+
@property
86+
def expired(self) -> bool:
87+
"""The amount of time remaining in interval
88+
89+
:returns: The amount of time left in the interval
90+
"""
91+
return self.remaining <= 0
92+
93+
@property
94+
def infinite(self) -> bool:
95+
"""Is the timeout interval infinitely long.
96+
97+
:returns: `True` if the delta is infinite, `False` otherwise
98+
"""
99+
return self.remaining == float("inf")
100+
101+
def new_interval(self) -> SynchronousTimeInterval:
102+
"""Make a new timeout with the same interval
103+
104+
:returns: The new timeout
105+
"""
106+
return type(self)(self.delta)
107+
108+
def wait(self) -> None:
109+
"""Block the thread until the timeout completes
110+
111+
:raises RuntimeError: The thread would be blocked forever
112+
"""
113+
if self.remaining == float("inf"):
114+
raise RuntimeError("Cannot block thread forever")
115+
time.sleep(self.remaining)

Diff for: smartsim/_core/utils/helpers.py

+43-2
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@
3131

3232
import base64
3333
import collections.abc
34+
import functools
3435
import os
3536
import signal
3637
import subprocess
3738
import typing as t
3839
import uuid
3940
from datetime import datetime
40-
from functools import lru_cache
4141
from pathlib import Path
4242
from shutil import which
4343

@@ -46,6 +46,10 @@
4646
if t.TYPE_CHECKING:
4747
from types import FrameType
4848

49+
from typing_extensions import TypeVarTuple, Unpack
50+
51+
_Ts = TypeVarTuple("_Ts")
52+
4953

5054
_T = t.TypeVar("_T")
5155
_HashableT = t.TypeVar("_HashableT", bound=t.Hashable)
@@ -97,7 +101,7 @@ def create_lockfile_name() -> str:
97101
return f"smartsim-{lock_suffix}.lock"
98102

99103

100-
@lru_cache(maxsize=20, typed=False)
104+
@functools.lru_cache(maxsize=20, typed=False)
101105
def check_dev_log_level() -> bool:
102106
lvl = os.environ.get("SMARTSIM_LOG_LEVEL", "")
103107
return lvl == "developer"
@@ -486,6 +490,43 @@ def group_by(
486490
return dict(groups)
487491

488492

493+
def pack_params(
494+
fn: t.Callable[[Unpack[_Ts]], _T]
495+
) -> t.Callable[[tuple[Unpack[_Ts]]], _T]:
496+
r"""Take a function that takes an unspecified number of positional arguments
497+
and turn it into a function that takes one argument of type `tuple` of
498+
unspecified length. The main use case is largely just for iterating over an
499+
iterable where arguments are "pre-zipped" into tuples. E.g.
500+
501+
.. highlight:: python
502+
.. code-block:: python
503+
504+
def pretty_print_dict(d):
505+
fmt_pair = lambda key, value: f"{repr(key)}: {repr(value)},"
506+
body = "\n".join(map(pack_params(fmt_pair), d.items()))
507+
# ^^^^^^^^^^^^^^^^^^^^^
508+
print(f"{{\n{textwrap.indent(body, ' ')}\n}}")
509+
510+
pretty_print_dict({"spam": "eggs", "foo": "bar", "hello": "world"})
511+
# prints:
512+
# {
513+
# 'spam': 'eggs',
514+
# 'foo': 'bar',
515+
# 'hello': 'world',
516+
# }
517+
518+
:param fn: A callable that takes many positional parameters.
519+
:returns: A callable that takes a single positional parameter of type tuple
520+
of with the same shape as the original callable parameter list.
521+
"""
522+
523+
@functools.wraps(fn)
524+
def packed(args: tuple[Unpack[_Ts]]) -> _T:
525+
return fn(*args)
526+
527+
return packed
528+
529+
489530
@t.final
490531
class SignalInterceptionStack(collections.abc.Collection[_TSignalHandlerFn]):
491532
"""Registers a stack of callables to be called when a signal is

Diff for: smartsim/experiment.py

+79-3
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

27-
# pylint: disable=too-many-lines
28-
2927
from __future__ import annotations
3028

3129
import collections
@@ -42,9 +40,11 @@
4240

4341
from smartsim._core import dispatch
4442
from smartsim._core.config import CONFIG
43+
from smartsim._core.control import interval as _interval
4544
from smartsim._core.control.launch_history import LaunchHistory as _LaunchHistory
45+
from smartsim._core.utils import helpers as _helpers
4646
from smartsim.error import errors
47-
from smartsim.status import InvalidJobStatus, JobStatus
47+
from smartsim.status import TERMINAL_STATUSES, InvalidJobStatus, JobStatus
4848

4949
from ._core import Controller, Generator, Manifest, previewrenderer
5050
from .database import FeatureStore
@@ -276,6 +276,82 @@ def get_status(
276276
stats = (stats_map.get(i, InvalidJobStatus.NEVER_STARTED) for i in ids)
277277
return tuple(stats)
278278

279+
def wait(
280+
self, *ids: LaunchedJobID, timeout: float | None = None, verbose: bool = True
281+
) -> None:
282+
"""Block execution until all of the provided launched jobs, represented
283+
by an ID, have entered a terminal status.
284+
285+
:param ids: The ids of the launched ids to wait for.
286+
:param timeout: The max time to wait for all of the launched jobs to end.
287+
:param verbose: Whether found statuses should be displayed in the console.
288+
:raises ValueError: No IDs were provided.
289+
"""
290+
if not ids:
291+
raise ValueError("No job ids to wait on provided")
292+
self._poll_for_statuses(
293+
ids,
294+
TERMINAL_STATUSES,
295+
timeout=_interval.SynchronousTimeInterval(timeout),
296+
verbose=verbose,
297+
)
298+
299+
def _poll_for_statuses(
300+
self,
301+
ids: t.Sequence[LaunchedJobID],
302+
statuses: t.Collection[JobStatus],
303+
timeout: _interval.SynchronousTimeInterval | None = None,
304+
interval: _interval.SynchronousTimeInterval | None = None,
305+
verbose: bool = True,
306+
) -> dict[LaunchedJobID, JobStatus | InvalidJobStatus]:
307+
"""Poll launchers until status until all jobs represented by a
308+
collections of ids have changed state to one of the provided statuses.
309+
310+
:param ids: IDs of launches to poll for status.
311+
:param statuses: A collection of statuses to poll for.
312+
:param timeout: The minimum amount of time to spend polling all jobs to
313+
reach one of the supplied statuses. If not supplied or `None`, the
314+
experiment will poll indefinitely.
315+
:param interval: The minimum time between polling launchers.
316+
:param verbose: Whether or not to log polled states the console.
317+
:raises ValueError: The interval between polling launchers is infinite
318+
:raises TimeoutError: The polling interval was exceeded.
319+
:returns: A mapping of ids to the status they entered that ended
320+
polling.
321+
"""
322+
terminal = frozenset(itertools.chain(statuses, InvalidJobStatus))
323+
log = logger.info if verbose else lambda *_, **__: None
324+
method_timeout = timeout or _interval.SynchronousTimeInterval(None)
325+
iter_timeout = interval or _interval.SynchronousTimeInterval(5.0)
326+
final: dict[LaunchedJobID, JobStatus | InvalidJobStatus] = {}
327+
328+
def is_finished(
329+
id_: LaunchedJobID, status: JobStatus | InvalidJobStatus
330+
) -> bool:
331+
job_title = f"Job({id_}): "
332+
if done := status in terminal:
333+
log(f"{job_title}Finished with status '{status.value}'")
334+
else:
335+
log(f"{job_title}Running with status '{status.value}'")
336+
return done
337+
338+
if iter_timeout.infinite:
339+
raise ValueError("Polling interval cannot be infinite")
340+
while ids and not method_timeout.expired:
341+
iter_timeout = iter_timeout.new_interval()
342+
stats = zip(ids, self.get_status(*ids))
343+
is_done = _helpers.group_by(_helpers.pack_params(is_finished), stats)
344+
final |= dict(is_done.get(True, ()))
345+
ids = tuple(id_ for id_, _ in is_done.get(False, ()))
346+
if ids:
347+
iter_timeout.wait()
348+
if ids:
349+
raise TimeoutError(
350+
f"Job ID(s) {', '.join(map(str, ids))} failed to reach "
351+
"terminal status before timeout"
352+
)
353+
return final
354+
279355
@_contextualize
280356
def _generate(self, generator: Generator, job: Job, job_index: int) -> pathlib.Path:
281357
"""Generate the directory structure and files for a ``Job``

0 commit comments

Comments
 (0)