diff --git a/smartsim/_core/control/interval.py b/smartsim/_core/control/interval.py new file mode 100644 index 000000000..e35b1c694 --- /dev/null +++ b/smartsim/_core/control/interval.py @@ -0,0 +1,112 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import time +import typing as t + +Seconds = t.NewType("Seconds", float) + + +class SynchronousTimeInterval: + """A utility class to represent and synchronously block the execution of a + thread for an interval of time. 
+ """ + + def __init__(self, delta: float | None) -> None: + """Initialize a new `SynchronousTimeInterval` interval + + :param delta: The difference in time the interval represents in + seconds. If `None`, the interval will represent an infinite amount + of time. + :raises ValueError: The `delta` is negative + """ + if delta is not None and delta < 0: + raise ValueError("Timeout value cannot be less than 0") + if delta is None: + delta = float("inf") + self._delta = Seconds(delta) + """The amount of time, in seconds, the interval spans.""" + self._start = time.perf_counter() + """The time of the creation of the interval""" + + @property + def delta(self) -> Seconds: + """The difference in time the interval represents + + :returns: The difference in time the interval represents + """ + return self._delta + + @property + def elapsed(self) -> Seconds: + """The amount of time that has passed since the interval was created + + :returns: The amount of time that has passed since the interval was + created + """ + return Seconds(time.perf_counter() - self._start) + + @property + def remaining(self) -> Seconds: + """The amount of time remaining in the interval + + :returns: The amount of time remaining in the interval + """ + return Seconds(max(self.delta - self.elapsed, 0)) + + @property + def expired(self) -> bool: + """Whether the interval has fully elapsed + + :returns: `True` if no time remains in the interval, `False` otherwise + """ + return self.remaining <= 0 + + @property + def infinite(self) -> bool: + """Return true if the timeout interval is infinitely long + + :returns: `True` if the delta is infinite, `False` otherwise + """ + return self.remaining == float("inf") + + def new_interval(self) -> SynchronousTimeInterval: + """Make a new timeout with the same interval + + :returns: The new time interval + """ + return type(self)(self.delta) + + def block(self) -> None: + """Block the thread until the timeout completes + + :raises RuntimeError: The thread would be blocked forever + 
""" + if self.remaining == float("inf"): + raise RuntimeError("Cannot block thread forever") + time.sleep(self.remaining) diff --git a/smartsim/_core/utils/helpers.py b/smartsim/_core/utils/helpers.py index 56eaa98d3..1133358a6 100644 --- a/smartsim/_core/utils/helpers.py +++ b/smartsim/_core/utils/helpers.py @@ -31,6 +31,7 @@ import base64 import collections.abc +import functools import itertools import os import signal @@ -40,12 +41,15 @@ import uuid import warnings from datetime import datetime -from functools import lru_cache from shutil import which if t.TYPE_CHECKING: from types import FrameType + from typing_extensions import TypeVarTuple, Unpack + + _Ts = TypeVarTuple("_Ts") + _T = t.TypeVar("_T") _HashableT = t.TypeVar("_HashableT", bound=t.Hashable) @@ -97,7 +101,7 @@ def create_lockfile_name() -> str: return f"smartsim-{lock_suffix}.lock" -@lru_cache(maxsize=20, typed=False) +@functools.lru_cache(maxsize=20, typed=False) def check_dev_log_level() -> bool: lvl = os.environ.get("SMARTSIM_LOG_LEVEL", "") return lvl == "developer" @@ -454,6 +458,43 @@ def group_by( return dict(groups) +def pack_params( + fn: t.Callable[[Unpack[_Ts]], _T] +) -> t.Callable[[tuple[Unpack[_Ts]]], _T]: + r"""Take a function that takes an unspecified number of positional arguments + and turn it into a function that takes one argument of type `tuple` of + unspecified length. The main use case is largely just for iterating over an + iterable where arguments are "pre-zipped" into tuples. E.g. + + .. highlight:: python + .. code-block:: python + + def pretty_print_dict(d): + fmt_pair = lambda key, value: f"{repr(key)}: {repr(value)}," + body = "\n".join(map(pack_params(fmt_pair), d.items())) + # ^^^^^^^^^^^^^^^^^^^^^ + print(f"{{\n{textwrap.indent(body, ' ')}\n}}") + + pretty_print_dict({"spam": "eggs", "foo": "bar", "hello": "world"}) + # prints: + # { + # 'spam': 'eggs', + # 'foo': 'bar', + # 'hello': 'world', + # } + + :param fn: A callable that takes many positional parameters. 
+ :returns: A callable that takes a single positional parameter of type tuple + with the same shape as the original callable parameter list. + """ + + @functools.wraps(fn) + def packed(args: tuple[Unpack[_Ts]]) -> _T: + return fn(*args) + + return packed + + @t.final class SignalInterceptionStack(collections.abc.Collection[_TSignalHandlerFn]): """Registers a stack of callables to be called when a signal is diff --git a/smartsim/experiment.py b/smartsim/experiment.py index ea7cccc3d..8701f62ce 100644 --- a/smartsim/experiment.py +++ b/smartsim/experiment.py @@ -24,8 +24,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# pylint: disable=too-many-lines - from __future__ import annotations import datetime @@ -39,9 +37,11 @@ from smartsim._core import dispatch from smartsim._core.config import CONFIG +from smartsim._core.control import interval as _interval from smartsim._core.control.launch_history import LaunchHistory as _LaunchHistory +from smartsim._core.utils import helpers as _helpers from smartsim.error import errors -from smartsim.status import InvalidJobStatus, JobStatus +from smartsim.status import TERMINAL_STATUSES, InvalidJobStatus, JobStatus from ._core import Generator, Manifest, previewrenderer from .entity import TelemetryConfiguration @@ -254,6 +254,84 @@ def get_status( stats = (stats_map.get(i, InvalidJobStatus.NEVER_STARTED) for i in ids) return tuple(stats) + def wait( + self, *ids: LaunchedJobID, timeout: float | None = None, verbose: bool = True + ) -> None: + """Block execution until all of the provided launched jobs, represented + by an ID, have entered a terminal status. + + :param ids: The ids of the launched jobs to wait for. + :param timeout: The max time to wait for all of the launched jobs to end. + :param verbose: Whether found statuses should be displayed in the console. + :raises ValueError: No IDs were provided. 
+ """ + if not ids: + raise ValueError("No job ids to wait on provided") + self._poll_for_statuses( + ids, TERMINAL_STATUSES, timeout=timeout, verbose=verbose + ) + + def _poll_for_statuses( + self, + ids: t.Sequence[LaunchedJobID], + statuses: t.Collection[JobStatus], + timeout: float | None = None, + interval: float = 5.0, + verbose: bool = True, + ) -> dict[LaunchedJobID, JobStatus | InvalidJobStatus]: + """Poll the experiment's launchers for the statuses of the launched + jobs with the provided ids, until the status of each job changes to one of + the provided statuses. + + :param ids: The ids of the launched jobs to wait for. + :param statuses: A collection of statuses to poll for. + :param timeout: The maximum amount of time to spend polling all jobs to + reach one of the supplied statuses. If not supplied or `None`, the + experiment will poll indefinitely. + :param interval: The minimum time between polling launchers. + :param verbose: Whether or not to log polled states to the console. + :raises ValueError: The interval between polling launchers is infinite + :raises TimeoutError: The timeout elapsed before every job finished polling. + :returns: A mapping of ids to the status they entered that ended + polling. 
+ """ + terminal = frozenset(itertools.chain(statuses, InvalidJobStatus)) + log = logger.info if verbose else lambda *_, **__: None + method_timeout = _interval.SynchronousTimeInterval(timeout) + iter_timeout = _interval.SynchronousTimeInterval(interval) + final: dict[LaunchedJobID, JobStatus | InvalidJobStatus] = {} + + def is_finished( + id_: LaunchedJobID, status: JobStatus | InvalidJobStatus + ) -> bool: + job_title = f"Job({id_}): " + if done := status in terminal: + log(f"{job_title}Finished with status '{status.value}'") + else: + log(f"{job_title}Running with status '{status.value}'") + return done + + if iter_timeout.infinite: + raise ValueError("Polling interval cannot be infinite") + while ids and not method_timeout.expired: + iter_timeout = iter_timeout.new_interval() + stats = zip(ids, self.get_status(*ids)) + is_done = _helpers.group_by(_helpers.pack_params(is_finished), stats) + final |= dict(is_done.get(True, ())) + ids = tuple(id_ for id_, _ in is_done.get(False, ())) + if ids: + ( + iter_timeout + if iter_timeout.remaining < method_timeout.remaining + else method_timeout + ).block() + if ids: + raise TimeoutError( + f"Job ID(s) {', '.join(map(str, ids))} failed to reach " + "terminal status before timeout" + ) + return final + @_contextualize def _generate( self, generator: Generator, job: Job, job_index: int diff --git a/tests/test_experiment.py b/tests/test_experiment.py index 39f2b9b11..8dfda1012 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -27,14 +27,18 @@ from __future__ import annotations import dataclasses +import io import itertools import random +import re +import time import typing as t import uuid import pytest from smartsim._core import dispatch +from smartsim._core.control.interval import SynchronousTimeInterval from smartsim._core.control.launch_history import LaunchHistory from smartsim._core.utils.launcher import LauncherProtocol, create_job_id from smartsim.entity import entity @@ -316,7 +320,7 @@ def 
get_status(self, *ids: LaunchedJobID): @pytest.fixture -def make_populated_experment(monkeypatch, experiment): +def make_populated_experiment(monkeypatch, experiment): def impl(num_active_launchers): new_launchers = (GetStatusLauncher() for _ in range(num_active_launchers)) id_to_launcher = { @@ -330,8 +334,8 @@ def impl(num_active_launchers): yield impl -def test_experiment_can_get_statuses(make_populated_experment): - exp = make_populated_experment(num_active_launchers=1) +def test_experiment_can_get_statuses(make_populated_experiment): + exp = make_populated_experiment(num_active_launchers=1) (launcher,) = exp._launch_history.iter_past_launchers() ids = tuple(launcher.known_ids) recieved_stats = exp.get_status(*ids) @@ -346,9 +350,9 @@ def test_experiment_can_get_statuses(make_populated_experment): [pytest.param(i, id=f"{i} launcher(s)") for i in (2, 3, 5, 10, 20, 100)], ) def test_experiment_can_get_statuses_from_many_launchers( - make_populated_experment, num_launchers + make_populated_experiment, num_launchers ): - exp = make_populated_experment(num_active_launchers=num_launchers) + exp = make_populated_experiment(num_active_launchers=num_launchers) launcher_and_rand_ids = ( (launcher, random.choice(tuple(launcher.id_to_status))) for launcher in exp._launch_history.iter_past_launchers() @@ -363,9 +367,9 @@ def test_experiment_can_get_statuses_from_many_launchers( def test_get_status_returns_not_started_for_unrecognized_ids( - monkeypatch, make_populated_experment + monkeypatch, make_populated_experiment ): - exp = make_populated_experment(num_active_launchers=1) + exp = make_populated_experiment(num_active_launchers=1) brand_new_id = create_job_id() ((launcher, (id_not_known_by_exp, *rest)),) = ( exp._launch_history.group_by_launcher().items() @@ -378,7 +382,7 @@ def test_get_status_returns_not_started_for_unrecognized_ids( def test_get_status_de_dups_ids_passed_to_launchers( - monkeypatch, make_populated_experment + monkeypatch, make_populated_experiment ): 
def track_calls(fn): calls = [] @@ -389,7 +393,7 @@ def impl(*a, **kw): return calls, impl - exp = make_populated_experment(num_active_launchers=1) + exp = make_populated_experiment(num_active_launchers=1) ((launcher, (id_, *_)),) = exp._launch_history.group_by_launcher().items() calls, tracked_get_status = track_calls(launcher.get_status) monkeypatch.setattr(launcher, "get_status", tracked_get_status) @@ -399,3 +403,131 @@ def impl(*a, **kw): assert len(calls) == 1, "Launcher's `get_status` was called more than once" (call,) = calls assert call == ((id_,), {}), "IDs were not de-duplicated" + + +def test_wait_handles_empty_call_args(experiment): + """An exception is raised when there are no jobs to complete""" + with pytest.raises(ValueError, match="No job ids"): + experiment.wait() + + +def test_wait_does_not_block_unknown_id(experiment): + """If an experiment does not recognize a job id, it should not wait for its + completion + """ + now = time.perf_counter() + experiment.wait(create_job_id()) + assert time.perf_counter() - now < 1 + + +def test_wait_calls_prefered_impl(make_populated_experiment, monkeypatch): + """Make sure wait is calling the expected method for checking job statuses. + Right now we only have the "polling" impl, but in the future this might change + to an event based system. 
+ """ + exp = make_populated_experiment(1) + ((_, (id_, *_)),) = exp._launch_history.group_by_launcher().items() + was_called = False + + def mocked_impl(*args, **kwargs): + nonlocal was_called + was_called = True + + monkeypatch.setattr(exp, "_poll_for_statuses", mocked_impl) + exp.wait(id_) + assert was_called + + +@pytest.mark.parametrize( + "num_polls", + [ + pytest.param(i, id=f"Poll for status {i} times") + for i in (1, 5, 10, 20, 100, 1_000) + ], +) +@pytest.mark.parametrize("verbose", [True, False]) +def test_poll_status_blocks_until_job_is_completed( + monkeypatch, make_populated_experiment, num_polls, verbose +): + """Make sure that the polling based implementation blocks the calling + thread. Use varying number of polls to simulate varying lengths of job time + for a job to complete. + + Additionally check to make sure that the expected log messages are present + """ + exp = make_populated_experiment(1) + ((launcher, (id_, *_)),) = exp._launch_history.group_by_launcher().items() + (current_status,) = launcher.get_status(id_).values() + different_statuses = set(JobStatus) - {current_status} + (new_status, *_) = different_statuses + mock_log = io.StringIO() + + @dataclasses.dataclass + class ChangeStatusAfterNPolls: + n: int + from_: JobStatus + to: JobStatus + num_calls: int = dataclasses.field(default=0, init=False) + + def __call__(self, *args, **kwargs): + self.num_calls += 1 + ret_status = self.to if self.num_calls >= self.n else self.from_ + return (ret_status,) + + mock_get_status = ChangeStatusAfterNPolls(num_polls, current_status, new_status) + monkeypatch.setattr(exp, "get_status", mock_get_status) + monkeypatch.setattr( + "smartsim.experiment.logger.info", lambda s: mock_log.write(f"{s}\n") + ) + final_statuses = exp._poll_for_statuses( + [id_], different_statuses, timeout=10, interval=0, verbose=verbose + ) + assert final_statuses == {id_: new_status} + + expected_log = io.StringIO() + expected_log.writelines( + f"Job({id_}): Running with 
status '{current_status.value}'\n" + for _ in range(num_polls - 1) + ) + expected_log.write(f"Job({id_}): Finished with status '{new_status.value}'\n") + assert mock_get_status.num_calls == num_polls + assert mock_log.getvalue() == (expected_log.getvalue() if verbose else "") + + +def test_poll_status_raises_when_called_with_infinite_iter_wait( + make_populated_experiment, +): + """Cannot wait forever between polls. That will just block the thread after + the first poll + """ + exp = make_populated_experiment(1) + ((_, (id_, *_)),) = exp._launch_history.group_by_launcher().items() + with pytest.raises(ValueError, match="Polling interval cannot be infinite"): + exp._poll_for_statuses( + [id_], + [], + timeout=10, + interval=float("inf"), + ) + + +def test_poll_for_status_raises_if_ids_not_found_within_timeout( + make_populated_experiment, +): + """If there is a timeout, a timeout error should be raised when it is exceeded""" + exp = make_populated_experiment(1) + ((launcher, (id_, *_)),) = exp._launch_history.group_by_launcher().items() + (current_status,) = launcher.get_status(id_).values() + different_statuses = set(JobStatus) - {current_status} + with pytest.raises( + TimeoutError, + match=re.escape( + f"Job ID(s) {id_} failed to reach terminal status before timeout" + ), + ): + exp._poll_for_statuses( + [id_], + different_statuses, + timeout=1, + interval=0, + ) diff --git a/tests/test_intervals.py b/tests/test_intervals.py new file mode 100644 index 000000000..1b865867f --- /dev/null +++ b/tests/test_intervals.py @@ -0,0 +1,87 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import contextlib +import operator +import time + +import pytest + +from smartsim._core.control.interval import SynchronousTimeInterval + +pytestmark = pytest.mark.group_a + + +@pytest.mark.parametrize( + "timeout", [pytest.param(i, id=f"{i} second(s)") for i in range(10)] +) +def test_sync_timeout_finite(timeout, monkeypatch): + """Test that the sync timeout intervals are correctly calculated""" + monkeypatch.setattr(time, "perf_counter", lambda *_, **__: 0) + t = SynchronousTimeInterval(timeout) + assert t.delta == timeout + assert t.elapsed == 0 + assert t.remaining == timeout + assert (operator.not_ if timeout > 0 else bool)(t.expired) + assert not t.infinite + future = timeout + 2 + monkeypatch.setattr(time, "perf_counter", lambda *_, **__: future) + assert t.elapsed == future + assert t.remaining == 0 + assert t.expired + assert not t.infinite + new_t = t.new_interval() + assert new_t.delta == timeout + assert new_t.elapsed == 0 + assert new_t.remaining == timeout + assert 
(operator.not_ if timeout > 0 else bool)(new_t.expired) + assert not new_t.infinite + + +def test_sync_timeout_can_block_thread(): + """Test that the sync timeout can block the calling thread""" + timeout = 1 + now = time.perf_counter() + SynchronousTimeInterval(timeout).block() + later = time.perf_counter() + assert abs(later - now - timeout) <= 0.25 + + +def test_sync_timeout_infinte(): + """Passing in `None` to a sync timeout creates a timeout with an infinite + delta time + """ + t = SynchronousTimeInterval(None) + assert t.remaining == float("inf") + assert t.infinite + with pytest.raises(RuntimeError, match="block thread forever"): + t.block() + + +def test_sync_timeout_raises_on_invalid_value(monkeypatch): + """Cannot make a sync time interval with a negative time delta""" + with pytest.raises(ValueError): + SynchronousTimeInterval(-1)