From 7c633a61c3d895060e1e7410fccf067bd47284f7 Mon Sep 17 00:00:00 2001 From: Matt Drozt Date: Mon, 12 Aug 2024 18:29:56 -0500 Subject: [PATCH 1/5] Poll Based Waiting for Job Completion `Experiment` was given a `wait` method that takes a collection of Launched Job IDs and will wait until the launch reaches a terminal state by either completing or erroring out. Implements a polling based solution. --- smartsim/_core/control/interval.py | 115 ++++++++++++++++++++++++++++ smartsim/_core/utils/helpers.py | 45 ++++++++++- smartsim/experiment.py | 82 +++++++++++++++++++- tests/test_experiment.py | 118 +++++++++++++++++++++++++++++ tests/test_intervals.py | 103 +++++++++++++++++++++++++ 5 files changed, 458 insertions(+), 5 deletions(-) create mode 100644 smartsim/_core/control/interval.py create mode 100644 tests/test_intervals.py diff --git a/smartsim/_core/control/interval.py b/smartsim/_core/control/interval.py new file mode 100644 index 000000000..e2153b85c --- /dev/null +++ b/smartsim/_core/control/interval.py @@ -0,0 +1,115 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import time +import typing as t + +Seconds = t.NewType("Seconds", float) + + +class SynchronousTimeInterval: + """A utility class to represent and synchronously block the execution of a + thread for an interval of time. + """ + + def __init__(self, delta: float | None, strict: bool = True) -> None: + """Initialize a new `SynchronousTimeInterval` interval + + :param delta: The difference in time the interval represents. If + `None`, the interval will represent an infinite amount of time in + seconds. + :param strict: Wether or not to raise in the case where negative time + delta is provided. If `False`, the `SynchronousTimeInterval` and a negative `delta` + is provided, the interval will be infinite. + :raises ValueError: The `delta` is negative and `strict` is `True` + """ + if delta is not None and delta < 0 and strict: + raise ValueError("Timeout value cannot be less than 0") + if delta is None or delta < 0: + delta = float("inf") + self._delta = Seconds(delta) + """The amount of time, in seconds the interval spans.""" + self._start = time.perf_counter() + """The time of the creation of the interval""" + + @property + def delta(self) -> Seconds: + """The difference in time the interval represents + + :returns: The difference in time the interval represents + """ + return self._delta + + @property + def elapsed(self) -> Seconds: + """The amount of time that has passed since the interval was created + + :returns: The amount of time that has passed since the interval was + created + """ + return Seconds(time.perf_counter() - self._start) + + @property + def remaining(self) -> Seconds: + """The amount of time remaining in the interval + + :returns: The amount of time remaining in the interval + """ + return Seconds(max(self.delta - self.elapsed, 0)) + + @property + def expired(self) -> bool: + """The amount of time remaining in interval + + :returns: The amount of time left in the interval + """ + return self.remaining <= 0 + + @property + def infinite(self) -> bool: + """Is the timeout interval infinitely long. + + :returns: `True` if the delta is infinite, `False` otherwise + """ + return self.remaining == float("inf") + + def new_interval(self) -> SynchronousTimeInterval: + """Make a new timeout with the same interval + + :returns: The new timeout + """ + return type(self)(self.delta) + + def wait(self) -> None: + """Block the thread until the timeout completes + + :raises RuntimeError: The thread would be blocked forever + """ + if self.remaining == float("inf"): + raise RuntimeError("Cannot block thread forever") + time.sleep(self.remaining) diff --git a/smartsim/_core/utils/helpers.py b/smartsim/_core/utils/helpers.py index 62d176259..ec0320db3 100644 --- a/smartsim/_core/utils/helpers.py +++ b/smartsim/_core/utils/helpers.py @@ -31,13 +31,13 @@ import base64 import collections.abc +import functools import os import signal import subprocess import typing as t import uuid from datetime import datetime -from functools import lru_cache from pathlib import Path from shutil import which @@ -46,6 +46,10 @@ if t.TYPE_CHECKING: from types import FrameType + from typing_extensions import TypeVarTuple, Unpack + + _Ts = TypeVarTuple("_Ts") + _T = t.TypeVar("_T") _HashableT = t.TypeVar("_HashableT", bound=t.Hashable) @@ -97,7 +101,7 @@ def create_lockfile_name() -> str: return f"smartsim-{lock_suffix}.lock" -@lru_cache(maxsize=20, typed=False) +@functools.lru_cache(maxsize=20, typed=False) def check_dev_log_level() -> bool: lvl = os.environ.get("SMARTSIM_LOG_LEVEL", "") return lvl == "developer" @@ -486,6 +490,43 @@ def group_by( return dict(groups) +def pack_params( + fn: t.Callable[[Unpack[_Ts]], _T] +) -> t.Callable[[tuple[Unpack[_Ts]]], _T]: + r"""Take a function that takes an unspecified number of positional arguments + and turn it into a function that takes one argument of type `tuple` of + unspecified length. The main use case is largely just for iterating over an + iterable where arguments are "pre-zipped" into tuples. E.g. + + .. highlight:: python + .. code-block:: python + + def pretty_print_dict(d): + fmt_pair = lambda key, value: f"{repr(key)}: {repr(value)}," + body = "\n".join(map(pack_params(fmt_pair), d.items())) + # ^^^^^^^^^^^^^^^^^^^^^ + print(f"{{\n{textwrap.indent(body, ' ')}\n}}") + + pretty_print_dict({"spam": "eggs", "foo": "bar", "hello": "world"}) + # prints: + # { + # 'spam': 'eggs', + # 'foo': 'bar', + # 'hello': 'world', + # } + + :param fn: A callable that takes many positional parameters. + :returns: A callable that takes a single positional parameter of type tuple + of with the same shape as the original callable parameter list. + """ + + @functools.wraps(fn) + def packed(args: tuple[Unpack[_Ts]]) -> _T: + return fn(*args) + + return packed + + @t.final class SignalInterceptionStack(collections.abc.Collection[_TSignalHandlerFn]): """Registers a stack of callables to be called when a signal is diff --git a/smartsim/experiment.py b/smartsim/experiment.py index 55ccea7b5..a3b9d0832 100644 --- a/smartsim/experiment.py +++ b/smartsim/experiment.py @@ -24,8 +24,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# pylint: disable=too-many-lines - from __future__ import annotations import collections @@ -42,9 +40,11 @@ from smartsim._core import dispatch from smartsim._core.config import CONFIG +from smartsim._core.control import interval as _interval from smartsim._core.control.launch_history import LaunchHistory as _LaunchHistory +from smartsim._core.utils import helpers as _helpers from smartsim.error import errors -from smartsim.status import InvalidJobStatus, JobStatus +from smartsim.status import TERMINAL_STATUSES, InvalidJobStatus, JobStatus from ._core import Controller, Generator, Manifest, previewrenderer from .database import FeatureStore @@ -276,6 +276,82 @@ def get_status( stats = (stats_map.get(i, InvalidJobStatus.NEVER_STARTED) for i in ids) return tuple(stats) + def wait( + self, *ids: LaunchedJobID, timeout: float | None = None, verbose: bool = True + ) -> None: + """Block execution until all of the provided launched jobs, represented + by an ID, have entered a terminal status. + + :param ids: The ids of the launched ids to wait for. + :param timeout: The max time to wait for all of the launched jobs to end. + :param verbose: Whether found statuses should be displayed in the console. + :raises ValueError: No IDs were provided. + """ + if not ids: + raise ValueError("No job ids to wait on provided") + self._poll_for_statuses( + ids, + TERMINAL_STATUSES, + timeout=_interval.SynchronousTimeInterval(timeout), + verbose=verbose, + ) + + def _poll_for_statuses( + self, + ids: t.Sequence[LaunchedJobID], + statuses: t.Collection[JobStatus], + timeout: _interval.SynchronousTimeInterval | None = None, + interval: _interval.SynchronousTimeInterval | None = None, + verbose: bool = True, + ) -> dict[LaunchedJobID, JobStatus | InvalidJobStatus]: + """Poll launchers until status until all jobs represented by a + collections of ids have changed state to one of the provided statuses. + + :param ids: IDs of launches to poll for status. + :param statuses: A collection of statuses to poll for. + :param timeout: The minimum amount of time to spend polling all jobs to + reach one of the supplied statuses. If not supplied or `None`, the + experiment will poll indefinitely. + :param interval: The minimum time between polling launchers. + :param verbose: Whether or not to log polled states the console. + :raises ValueError: The interval between polling launchers is infinite + :raises TimeoutError: The polling interval was exceeded. + :returns: A mapping of ids to the status they entered that ended + polling. + """ + terminal = frozenset(itertools.chain(statuses, InvalidJobStatus)) + log = logger.info if verbose else lambda *_, **__: None + method_timeout = timeout or _interval.SynchronousTimeInterval(None) + iter_timeout = interval or _interval.SynchronousTimeInterval(5.0) + final: dict[LaunchedJobID, JobStatus | InvalidJobStatus] = {} + + def is_finished( + id_: LaunchedJobID, status: JobStatus | InvalidJobStatus + ) -> bool: + job_title = f"Job({id_}): " + if done := status in terminal: + log(f"{job_title}Finished with status '{status.value}'") + else: + log(f"{job_title}Running with status '{status.value}'") + return done + + if iter_timeout.infinite: + raise ValueError("Polling interval cannot be infinite") + while ids and not method_timeout.expired: + iter_timeout = iter_timeout.new_interval() + stats = zip(ids, self.get_status(*ids)) + is_done = _helpers.group_by(_helpers.pack_params(is_finished), stats) + final |= dict(is_done.get(True, ())) + ids = tuple(id_ for id_, _ in is_done.get(False, ())) + if ids: + iter_timeout.wait() + if ids: + raise TimeoutError( + f"Job ID(s) {', '.join(map(str, ids))} failed to reach " + "terminal status before timeout" + ) + return final + @_contextualize def _generate(self, generator: Generator, job: Job, job_index: int) -> pathlib.Path: """Generate the directory structure and files for a ``Job`` diff --git a/tests/test_experiment.py b/tests/test_experiment.py index 8671bfedb..f619b1d3a 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -27,15 +27,19 @@ from __future__ import annotations import dataclasses +import io import itertools import random +import re import tempfile +import time import typing as t import uuid import pytest from smartsim._core import dispatch +from smartsim._core.control.interval import SynchronousTimeInterval from smartsim._core.control.launch_history import LaunchHistory from smartsim.entity import _mock, entity from smartsim.experiment import Experiment @@ -389,3 +393,117 @@ def impl(*a, **kw): assert len(calls) == 1, "Launcher's `get_status` was called more than once" (call,) = calls assert call == ((id_,), {}), "IDs were not de-duplicated" + + +def test_wait_handles_empty_call_args(experiment): + with pytest.raises(ValueError, match="No job ids"): + experiment.wait() + + +def test_wait_does_not_block_unknown_id(experiment): + now = time.perf_counter() + experiment.wait(dispatch.create_job_id()) + assert time.perf_counter() - now < 1 + + +def test_wait_calls_prefered_impl(make_populated_experment, monkeypatch): + exp = make_populated_experment(1) + ((_, (id_, *_)),) = exp._launch_history.group_by_launcher().items() + was_called = False + + def mocked_impl(*args, **kwargs): + nonlocal was_called + was_called = True + + monkeypatch.setattr(exp, "_poll_for_statuses", mocked_impl) + exp.wait(id_) + assert was_called + + +@pytest.mark.parametrize( + "num_polls", + [ + pytest.param(i, id=f"Poll for status {i} times") + for i in (1, 5, 10, 20, 100, 1_000) + ], +) +@pytest.mark.parametrize("verbose", [True, False]) +def test_poll_status_blocks_until_job_is_completed( + monkeypatch, make_populated_experment, num_polls, verbose +): + exp = make_populated_experment(1) + ((launcher, (id_, *_)),) = exp._launch_history.group_by_launcher().items() + (current_status,) = launcher.get_status(id_).values() + different_statuses = set(JobStatus) - {current_status} + (new_status, *_) = different_statuses + mock_log = io.StringIO() + + @dataclasses.dataclass + class ChangeStatusAfterNPolls: + n: int + from_: JobStatus + to: JobStatus + num_calls: int = dataclasses.field(default=0, init=False) + + def __call__(self, *args, **kwargs): + self.num_calls += 1 + ret_status = self.to if self.num_calls >= self.n else self.from_ + return (ret_status,) + + mock_get_status = ChangeStatusAfterNPolls(num_polls, current_status, new_status) + monkeypatch.setattr(exp, "get_status", mock_get_status) + monkeypatch.setattr( + "smartsim.experiment.logger.info", lambda s: mock_log.write(f"{s}\n") + ) + final_statuses = exp._poll_for_statuses( + [id_], + different_statuses, + timeout=SynchronousTimeInterval(10), + interval=SynchronousTimeInterval(0), + verbose=verbose, + ) + assert final_statuses == {id_: new_status} + + expected_log = io.StringIO() + expected_log.writelines( + f"Job({id_}): Running with status '{current_status.value}'\n" + for _ in range(num_polls - 1) + ) + expected_log.write(f"Job({id_}): Finished with status '{new_status.value}'\n") + assert mock_get_status.num_calls == num_polls + assert mock_log.getvalue() == (expected_log.getvalue() if verbose else "") + + +def test_poll_status_raises_when_called_with_infinite_iter_wait( + make_populated_experment, +): + exp = make_populated_experment(1) + ((_, (id_, *_)),) = exp._launch_history.group_by_launcher().items() + with pytest.raises(ValueError, match="Polling interval cannot be infinite"): + exp._poll_for_statuses( + [id_], + [], + timeout=SynchronousTimeInterval(10), + interval=SynchronousTimeInterval(None), + ) + + +def test_poll_for_status_raises_if_ids_not_found_within_timeout( + make_populated_experment, +): + exp = make_populated_experment(1) + ((launcher, (id_, *_)),) = exp._launch_history.group_by_launcher().items() + (current_status,) = launcher.get_status(id_).values() + different_statuses = set(JobStatus) - {current_status} + with pytest.raises( + TimeoutError, + match=re.escape( + f"Job ID(s) {id_} failed to reach terminal status before timeout" + ), + ): + exp._poll_for_statuses( + [id_], + different_statuses, + timeout=SynchronousTimeInterval(1), + interval=SynchronousTimeInterval(0), + ) diff --git a/tests/test_intervals.py b/tests/test_intervals.py new file mode 100644 index 000000000..db12eac38 --- /dev/null +++ b/tests/test_intervals.py @@ -0,0 +1,103 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import contextlib +import operator +import time + +import pytest + +from smartsim._core.control.interval import SynchronousTimeInterval + +pytestmark = pytest.mark.group_a + + +@pytest.mark.parametrize( + "timeout", [pytest.param(i, id=f"{i} second(s)") for i in range(10)] +) +def test_sync_timeout_finite(timeout, monkeypatch): + monkeypatch.setattr(time, "perf_counter", lambda *_, **__: 0) + t = SynchronousTimeInterval(timeout) + assert t.delta == timeout + assert t.elapsed == 0 + assert t.remaining == timeout + assert (operator.not_ if timeout > 0 else bool)(t.expired) + assert not t.infinite + future = timeout + 2 + monkeypatch.setattr(time, "perf_counter", lambda *_, **__: future) + assert t.elapsed == future + assert t.remaining == 0 + assert t.expired + assert not t.infinite + new_t = t.new_interval() + assert new_t.delta == timeout + assert new_t.elapsed == 0 + assert new_t.remaining == timeout + assert (operator.not_ if timeout > 0 else bool)(new_t.expired) + assert not new_t.infinite + + +def test_sync_timeout_can_block_thread(): + timeout = 1 + now = time.perf_counter() + SynchronousTimeInterval(timeout).wait() + later = time.perf_counter() + assert abs(later - now - timeout) <= 0.25 + + +@pytest.mark.parametrize( + "timeout", + [ + pytest.param(-1, id="Negative timeout"), + pytest.param(None, id="Nullish timeout"), + ], +) +def test_sync_timeout_infinte(timeout): + t = SynchronousTimeInterval(timeout, strict=False) + assert t.remaining == float("inf") + assert t.infinite + with pytest.raises(RuntimeError, match="block thread forever"): + t.wait() + + +@pytest.mark.parametrize( + "kwargs, ctx", + [ + pytest.param({}, pytest.raises(ValueError), id=f"Default"), + pytest.param({"strict": True}, pytest.raises(ValueError), id=f"Strict"), + pytest.param({"strict": False}, contextlib.nullcontext(), id=f"Unstrict"), + ], +) +def test_sync_timeout_raises_on_invalid_value(monkeypatch, kwargs, ctx): + with ctx: + t = SynchronousTimeInterval(-1, **kwargs) + now = time.perf_counter() + monkeypatch.setattr( + time, "perf_counter", lambda *_, **__: now + 365 * 24 * 60 * 60 + ) + assert t._delta == float("inf") + assert t.expired == False + assert t.remaining == float("inf") From 45ef45b712d8d93786fdbb21e837e9b92cd7f6f9 Mon Sep 17 00:00:00 2001 From: Matt Drozt Date: Wed, 21 Aug 2024 14:49:23 -0700 Subject: [PATCH 2/5] Address reviewer feedback --- smartsim/_core/control/interval.py | 11 ++++------- smartsim/experiment.py | 15 ++++++++++----- tests/test_experiment.py | 18 +++++++++++++++++ tests/test_intervals.py | 31 +++++++++++------------------- 4 files changed, 43 insertions(+), 32 deletions(-) diff --git a/smartsim/_core/control/interval.py b/smartsim/_core/control/interval.py index e2153b85c..de9794bc3 100644 --- a/smartsim/_core/control/interval.py +++ b/smartsim/_core/control/interval.py @@ -37,23 +37,20 @@ class SynchronousTimeInterval: thread for an interval of time. """ - def __init__(self, delta: float | None, strict: bool = True) -> None: + def __init__(self, delta: float | None) -> None: """Initialize a new `SynchronousTimeInterval` interval :param delta: The difference in time the interval represents. If `None`, the interval will represent an infinite amount of time in seconds. - :param strict: Wether or not to raise in the case where negative time - delta is provided. If `False`, the `SynchronousTimeInterval` and a negative `delta` - is provided, the interval will be infinite. :raises ValueError: The `delta` is negative and `strict` is `True` """ - if delta is not None and delta < 0 and strict: + if delta is not None and delta < 0: raise ValueError("Timeout value cannot be less than 0") if delta is None or delta < 0: delta = float("inf") self._delta = Seconds(delta) - """The amount of time, in seconds the interval spans.""" + """The amount of time, in seconds, the interval spans.""" self._start = time.perf_counter() """The time of the creation of the interval""" @@ -101,7 +98,7 @@ def infinite(self) -> bool: def new_interval(self) -> SynchronousTimeInterval: """Make a new timeout with the same interval - :returns: The new timeout + :returns: The new time interval """ return type(self)(self.delta) diff --git a/smartsim/experiment.py b/smartsim/experiment.py index a3b9d0832..edf9587b3 100644 --- a/smartsim/experiment.py +++ b/smartsim/experiment.py @@ -282,7 +282,7 @@ def wait( """Block execution until all of the provided launched jobs, represented by an ID, have entered a terminal status. - :param ids: The ids of the launched ids to wait for. + :param ids: The ids of the launched jobs to wait for. :param timeout: The max time to wait for all of the launched jobs to end. :param verbose: Whether found statuses should be displayed in the console. :raises ValueError: No IDs were provided. @@ -304,10 +304,11 @@ def _poll_for_statuses( interval: _interval.SynchronousTimeInterval | None = None, verbose: bool = True, ) -> dict[LaunchedJobID, JobStatus | InvalidJobStatus]: - """Poll launchers until status until all jobs represented by a - collections of ids have changed state to one of the provided statuses. + """Poll the experiment's launchers for the statuses of the launched + jobs with the provided ids, until the status of the changes to one of + the provided statuses. - :param ids: IDs of launches to poll for status. + :param ids: The ids of the launched jobs to wait for. :param statuses: A collection of statuses to poll for. :param timeout: The minimum amount of time to spend polling all jobs to reach one of the supplied statuses. If not supplied or `None`, the @@ -344,7 +345,11 @@ def is_finished( final |= dict(is_done.get(True, ())) ids = tuple(id_ for id_, _ in is_done.get(False, ())) if ids: - iter_timeout.wait() + ( + iter_timeout + if iter_timeout.remaining < method_timeout.remaining + else method_timeout + ).wait() if ids: raise TimeoutError( f"Job ID(s) {', '.join(map(str, ids))} failed to reach " diff --git a/tests/test_experiment.py b/tests/test_experiment.py index f619b1d3a..46318614f 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -396,17 +396,25 @@ def impl(*a, **kw): def test_wait_handles_empty_call_args(experiment): + """An exception is raised when asked to wait for the completion of no jobs""" with pytest.raises(ValueError, match="No job ids"): experiment.wait() def test_wait_does_not_block_unknown_id(experiment): + """If an experiment does not recognize a job id, it should not wait for its + completion + """ now = time.perf_counter() experiment.wait(dispatch.create_job_id()) assert time.perf_counter() - now < 1 def test_wait_calls_prefered_impl(make_populated_experment, monkeypatch): + """Make wait is calling the expected method for checking job statuses. + Right now we only have the "polling" impl, but in future this might change + to an event based system. + """ exp = make_populated_experment(1) ((_, (id_, *_)),) = exp._launch_history.group_by_launcher().items() was_called = False @@ -431,6 +439,12 @@ def mocked_impl(*args, **kwargs): def test_poll_status_blocks_until_job_is_completed( monkeypatch, make_populated_experment, num_polls, verbose ): + """Make sure that the polling based impl blocks the calling thread. Use + varying number of polls to simulate varying lengths of job time for a job + to complete. + + Additionally check to make sure that the expected log messages are present + """ exp = make_populated_experment(1) ((launcher, (id_, *_)),) = exp._launch_history.group_by_launcher().items() (current_status,) = launcher.get_status(id_).values() @@ -477,6 +491,9 @@ def __call__(self, *args, **kwargs): def test_poll_status_raises_when_called_with_infinite_iter_wait( make_populated_experment, ): + """Cannot wait forever between polls. That will just block the thread after + the first poll + """ exp = make_populated_experment(1) ((_, (id_, *_)),) = exp._launch_history.group_by_launcher().items() with pytest.raises(ValueError, match="Polling interval cannot be infinite"): @@ -491,6 +508,7 @@ def test_poll_status_raises_when_called_with_infinite_iter_wait( def test_poll_for_status_raises_if_ids_not_found_within_timeout( make_populated_experment, ): + """If there is a timeout, a timeout error should be raised when it is exceeded""" exp = make_populated_experment(1) ((launcher, (id_, *_)),) = exp._launch_history.group_by_launcher().items() (current_status,) = launcher.get_status(id_).values() diff --git a/tests/test_intervals.py b/tests/test_intervals.py index db12eac38..cdceabfab 100644 --- a/tests/test_intervals.py +++ b/tests/test_intervals.py @@ -39,6 +39,7 @@ "timeout", [pytest.param(i, id=f"{i} second(s)") for i in range(10)] ) def test_sync_timeout_finite(timeout, monkeypatch): + """Test the sync timeout intervals are correctly calcualted""" monkeypatch.setattr(time, "perf_counter", lambda *_, **__: 0) t = SynchronousTimeInterval(timeout) assert t.delta == timeout @@ -61,6 +62,7 @@ def test_sync_timeout_finite(timeout, monkeypatch): def test_sync_timeout_can_block_thread(): + """Test the sync timeout can block the calling thread""" timeout = 1 now = time.perf_counter() SynchronousTimeInterval(timeout).wait() @@ -68,32 +70,21 @@ def test_sync_timeout_can_block_thread(): assert abs(later - now - timeout) <= 0.25 -@pytest.mark.parametrize( - "timeout", - [ - pytest.param(-1, id="Negative timeout"), - pytest.param(None, id="Nullish timeout"), - ], -) -def test_sync_timeout_infinte(timeout): - t = SynchronousTimeInterval(timeout, strict=False) +def test_sync_timeout_infinte(): + """Passing in `None` to a sync timeout creates a timeout with an infinite + delta time + """ + t = SynchronousTimeInterval(None) assert t.remaining == float("inf") assert t.infinite with pytest.raises(RuntimeError, match="block thread forever"): t.wait() -@pytest.mark.parametrize( - "kwargs, ctx", - [ - pytest.param({}, pytest.raises(ValueError), id=f"Default"), - pytest.param({"strict": True}, pytest.raises(ValueError), id=f"Strict"), - pytest.param({"strict": False}, contextlib.nullcontext(), id=f"Unstrict"), - ], -) -def test_sync_timeout_raises_on_invalid_value(monkeypatch, kwargs, ctx): - with ctx: - t = SynchronousTimeInterval(-1, **kwargs) +def test_sync_timeout_raises_on_invalid_value(monkeypatch): + """Cannot make a sync time interval with a negative time delta""" + with pytest.raises(ValueError): + t = SynchronousTimeInterval(-1) now = time.perf_counter() monkeypatch.setattr( time, "perf_counter", lambda *_, **__: now + 365 * 24 * 60 * 60 From cc12ebb9956e4c2a75ec600e61f71afbb940a5cb Mon Sep 17 00:00:00 2001 From: Matt Drozt Date: Thu, 22 Aug 2024 11:31:38 -0700 Subject: [PATCH 3/5] Seriously, how is English this my native language? --- smartsim/_core/control/interval.py | 2 +- smartsim/experiment.py | 2 +- tests/test_experiment.py | 42 +++++++++++++++--------------- tests/test_intervals.py | 4 +-- 4 files changed, 25 insertions(+), 25 deletions(-) diff --git a/smartsim/_core/control/interval.py b/smartsim/_core/control/interval.py index de9794bc3..cdcfd47cb 100644 --- a/smartsim/_core/control/interval.py +++ b/smartsim/_core/control/interval.py @@ -89,7 +89,7 @@ def expired(self) -> bool: @property def infinite(self) -> bool: - """Is the timeout interval infinitely long. + """Return true if the timeout interval is infinitely long :returns: `True` if the delta is infinite, `False` otherwise """ diff --git a/smartsim/experiment.py b/smartsim/experiment.py index edf9587b3..2d812c633 100644 --- a/smartsim/experiment.py +++ b/smartsim/experiment.py @@ -314,7 +314,7 @@ def _poll_for_statuses( reach one of the supplied statuses. If not supplied or `None`, the experiment will poll indefinitely. :param interval: The minimum time between polling launchers. - :param verbose: Whether or not to log polled states the console. + :param verbose: Whether or not to log polled states to the console. :raises ValueError: The interval between polling launchers is infinite :raises TimeoutError: The polling interval was exceeded. :returns: A mapping of ids to the status they entered that ended diff --git a/tests/test_experiment.py b/tests/test_experiment.py index 46318614f..cedd57b88 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -310,7 +310,7 @@ def get_status(self, *ids: LaunchedJobID): @pytest.fixture -def make_populated_experment(monkeypatch, experiment): +def make_populated_experiment(monkeypatch, experiment): def impl(num_active_launchers): new_launchers = (GetStatusLauncher() for _ in range(num_active_launchers)) id_to_launcher = { @@ -324,8 +324,8 @@ def impl(num_active_launchers): yield impl -def test_experiment_can_get_statuses(make_populated_experment): - exp = make_populated_experment(num_active_launchers=1) +def test_experiment_can_get_statuses(make_populated_experiment): + exp = make_populated_experiment(num_active_launchers=1) (launcher,) = exp._launch_history.iter_past_launchers() ids = tuple(launcher.known_ids) recieved_stats = exp.get_status(*ids) @@ -340,9 +340,9 @@ def test_experiment_can_get_statuses(make_populated_experment): [pytest.param(i, id=f"{i} launcher(s)") for i in (2, 3, 5, 10, 20, 100)], ) def test_experiment_can_get_statuses_from_many_launchers( - make_populated_experment, num_launchers + make_populated_experiment, num_launchers ): - exp = make_populated_experment(num_active_launchers=num_launchers) + exp = make_populated_experiment(num_active_launchers=num_launchers) launcher_and_rand_ids = ( (launcher, random.choice(tuple(launcher.id_to_status))) for launcher in exp._launch_history.iter_past_launchers() @@ -357,9 +357,9 @@ def test_experiment_can_get_statuses_from_many_launchers( def test_get_status_returns_not_started_for_unrecognized_ids( - monkeypatch, make_populated_experment + monkeypatch, make_populated_experiment ): - exp = make_populated_experment(num_active_launchers=1) + exp = make_populated_experiment(num_active_launchers=1) brand_new_id = dispatch.create_job_id() ((launcher, (id_not_known_by_exp, *rest)),) = ( exp._launch_history.group_by_launcher().items() @@ -372,7 +372,7 @@ def test_get_status_returns_not_started_for_unrecognized_ids( def test_get_status_de_dups_ids_passed_to_launchers( - monkeypatch, make_populated_experment + monkeypatch, make_populated_experiment ): def track_calls(fn): calls = [] @@ -383,7 +383,7 @@ def impl(*a, **kw): return calls, impl - exp = make_populated_experment(num_active_launchers=1) + exp = make_populated_experiment(num_active_launchers=1) ((launcher, (id_, *_)),) = exp._launch_history.group_by_launcher().items() calls, tracked_get_status = track_calls(launcher.get_status) monkeypatch.setattr(launcher, "get_status", tracked_get_status) @@ -396,7 +396,7 @@ def impl(*a, **kw): def test_wait_handles_empty_call_args(experiment): - """An exception is raised when asked to wait for the completion of no jobs""" + """An exception is raised when there are no jobs to complete""" with pytest.raises(ValueError, match="No job ids"): experiment.wait() @@ -410,12 +410,12 @@ def test_wait_does_not_block_unknown_id(experiment): assert time.perf_counter() - now < 1 -def test_wait_calls_prefered_impl(make_populated_experment, monkeypatch): +def test_wait_calls_prefered_impl(make_populated_experiment, monkeypatch): """Make wait is calling the expected method for checking job statuses. Right now we only have the "polling" impl, but in future this might change to an event based system. """ - exp = make_populated_experment(1) + exp = make_populated_experiment(1) ((_, (id_, *_)),) = exp._launch_history.group_by_launcher().items() was_called = False @@ -437,15 +437,15 @@ def mocked_impl(*args, **kwargs): ) @pytest.mark.parametrize("verbose", [True, False]) def test_poll_status_blocks_until_job_is_completed( - monkeypatch, make_populated_experment, num_polls, verbose + monkeypatch, make_populated_experiment, num_polls, verbose ): - """Make sure that the polling based impl blocks the calling thread. Use - varying number of polls to simulate varying lengths of job time for a job - to complete. + """Make sure that the polling based implementation blocks the calling + thread. Use varying number of polls to simulate varying lengths of job time + for a job to complete. Additionally check to make sure that the expected log messages are present """ - exp = make_populated_experment(1) + exp = make_populated_experiment(1) ((launcher, (id_, *_)),) = exp._launch_history.group_by_launcher().items() (current_status,) = launcher.get_status(id_).values() different_statuses = set(JobStatus) - {current_status} @@ -489,12 +489,12 @@ def __call__(self, *args, **kwargs): def test_poll_status_raises_when_called_with_infinite_iter_wait( - make_populated_experment, + make_populated_experiment, ): """Cannot wait forever between polls. That will just block the thread after the first poll """ - exp = make_populated_experment(1) + exp = make_populated_experiment(1) ((_, (id_, *_)),) = exp._launch_history.group_by_launcher().items() with pytest.raises(ValueError, match="Polling interval cannot be infinite"): exp._poll_for_statuses( @@ -506,10 +506,10 @@ def test_poll_status_raises_when_called_with_infinite_iter_wait( def test_poll_for_status_raises_if_ids_not_found_within_timeout( - make_populated_experment, + make_populated_experiment, ): """If there is a timeout, a timeout error should be raised when it is exceeded""" - exp = make_populated_experment(1) + exp = make_populated_experiment(1) ((launcher, (id_, *_)),) = exp._launch_history.group_by_launcher().items() (current_status,) = launcher.get_status(id_).values() different_statuses = set(JobStatus) - {current_status} diff --git a/tests/test_intervals.py b/tests/test_intervals.py index cdceabfab..a98bfa22e 100644 --- a/tests/test_intervals.py +++ b/tests/test_intervals.py @@ -39,7 +39,7 @@ "timeout", [pytest.param(i, id=f"{i} second(s)") for i in range(10)] ) def test_sync_timeout_finite(timeout, monkeypatch): - """Test the sync timeout intervals are correctly calcualted""" + """Test that the sync timeout intervals are correctly calculated""" monkeypatch.setattr(time, "perf_counter", lambda *_, **__: 0) t = SynchronousTimeInterval(timeout) assert t.delta == timeout @@ -62,7 +62,7 @@ def test_sync_timeout_finite(timeout, monkeypatch): def test_sync_timeout_can_block_thread(): - """Test the sync timeout can block the calling thread""" + """Test that the sync timeout can block the calling thread""" timeout = 1 now = time.perf_counter() SynchronousTimeInterval(timeout).wait() From c6245657544617bcc2f27e8414e59b80842d3acf Mon Sep 17 00:00:00 2001 From: Matt Drozt Date: Wed, 28 Aug 2024 09:48:09 -0700 Subject: [PATCH 4/5] Deobfuscate API --- smartsim/_core/control/interval.py | 2 +- smartsim/experiment.py | 15 ++++++--------- tests/test_experiment.py | 14 +++++--------- tests/test_intervals.py | 13 +++---------- 4 files changed, 15 insertions(+), 29 deletions(-) diff --git a/smartsim/_core/control/interval.py b/smartsim/_core/control/interval.py index cdcfd47cb..c4e31c5d2 100644 --- a/smartsim/_core/control/interval.py +++ b/smartsim/_core/control/interval.py @@ -102,7 +102,7 @@ def new_interval(self) -> SynchronousTimeInterval: """ return type(self)(self.delta) - def wait(self) -> None: + def block(self) -> None: """Block the thread until the timeout completes :raises RuntimeError: The thread would be blocked forever diff --git a/smartsim/experiment.py b/smartsim/experiment.py index 3fa5d12b3..fcdade5c9 100644 --- a/smartsim/experiment.py +++ b/smartsim/experiment.py @@ -290,18 +290,15 @@ def wait( if not ids: raise ValueError("No job ids to wait on provided") self._poll_for_statuses( - ids, - TERMINAL_STATUSES, - timeout=_interval.SynchronousTimeInterval(timeout), - verbose=verbose, + ids, TERMINAL_STATUSES, timeout=timeout, verbose=verbose ) def _poll_for_statuses( self, ids: t.Sequence[LaunchedJobID], statuses: t.Collection[JobStatus], - timeout: _interval.SynchronousTimeInterval | None = None, - interval: _interval.SynchronousTimeInterval | None = None, + timeout: float | None = None, + interval: float = 5.0, verbose: bool = True, ) -> dict[LaunchedJobID, JobStatus | InvalidJobStatus]: """Poll the experiment's launchers for the statuses of the launched @@ -322,8 +319,8 @@ def _poll_for_statuses( """ terminal = frozenset(itertools.chain(statuses, InvalidJobStatus)) log = logger.info if verbose else lambda *_, **__: None - method_timeout = timeout or _interval.SynchronousTimeInterval(None) - iter_timeout = interval or _interval.SynchronousTimeInterval(5.0) + method_timeout = _interval.SynchronousTimeInterval(timeout) + iter_timeout = _interval.SynchronousTimeInterval(interval) final: dict[LaunchedJobID, JobStatus | InvalidJobStatus] = {} def is_finished( @@ -349,7 +346,7 @@ def is_finished( iter_timeout if iter_timeout.remaining < method_timeout.remaining else method_timeout - ).wait() + ).block() if ids: raise TimeoutError( f"Job ID(s) {', '.join(map(str, ids))} failed to reach " diff --git a/tests/test_experiment.py b/tests/test_experiment.py index 795adf542..29e2626cc 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -471,11 +471,7 @@ def __call__(self, *args, **kwargs): "smartsim.experiment.logger.info", lambda s: mock_log.write(f"{s}\n") ) final_statuses = exp._poll_for_statuses( - [id_], - different_statuses, - timeout=SynchronousTimeInterval(10), - interval=SynchronousTimeInterval(0), - verbose=verbose, + [id_], different_statuses, timeout=10, interval=0, verbose=verbose ) assert final_statuses == {id_: new_status} @@ -501,8 +497,8 @@ def test_poll_status_raises_when_called_with_infinite_iter_wait( exp._poll_for_statuses( [id_], [], - timeout=SynchronousTimeInterval(10), - interval=SynchronousTimeInterval(None), + timeout=10, + interval=float("inf"), ) @@ -523,6 +519,6 @@ def test_poll_for_status_raises_if_ids_not_found_within_timeout( exp._poll_for_statuses( [id_], different_statuses, - timeout=SynchronousTimeInterval(1), - interval=SynchronousTimeInterval(0), + timeout=1, + interval=0, ) diff --git a/tests/test_intervals.py b/tests/test_intervals.py index a98bfa22e..1b865867f 100644 --- a/tests/test_intervals.py +++ b/tests/test_intervals.py @@ -65,7 +65,7 @@ def test_sync_timeout_can_block_thread(): """Test that the sync timeout can block the calling thread""" timeout = 1 now = time.perf_counter() - SynchronousTimeInterval(timeout).wait() + SynchronousTimeInterval(timeout).block() later = time.perf_counter() assert abs(later - now - timeout) <= 0.25 @@ -78,17 +78,10 @@ def test_sync_timeout_infinte(): assert t.remaining == float("inf") assert t.infinite with pytest.raises(RuntimeError, match="block thread forever"): - t.wait() + t.block() def test_sync_timeout_raises_on_invalid_value(monkeypatch): """Cannot make a sync time interval with a negative time delta""" with pytest.raises(ValueError): - t = SynchronousTimeInterval(-1) - now = time.perf_counter() - monkeypatch.setattr( - time, "perf_counter", lambda *_, **__: now + 365 * 24 * 60 * 60 - ) - assert t._delta == float("inf") - assert t.expired == False - assert t.remaining == float("inf") + SynchronousTimeInterval(-1) From 98b745f167e823c6d5f2bdfaf73e860eba8db5e2 Mon Sep 17 00:00:00 2001 From: Matt Drozt Date: Wed, 28 Aug 2024 13:45:54 -0700 Subject: [PATCH 5/5] Remove redundant check --- smartsim/_core/control/interval.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/smartsim/_core/control/interval.py b/smartsim/_core/control/interval.py index c4e31c5d2..e35b1c694 100644 --- a/smartsim/_core/control/interval.py +++ b/smartsim/_core/control/interval.py @@ -40,14 +40,14 @@ class SynchronousTimeInterval: def __init__(self, delta: float | None) -> None: """Initialize a new `SynchronousTimeInterval` interval - :param delta: The difference in time the interval represents. If - `None`, the interval will represent an infinite amount of time in - seconds. - :raises ValueError: The `delta` is negative and `strict` is `True` + :param delta: The difference in time the interval represents in + seconds. If `None`, the interval will represent an infinite amount + of time. + :raises ValueError: The `delta` is negative """ if delta is not None and delta < 0: raise ValueError("Timeout value cannot be less than 0") - if delta is None or delta < 0: + if delta is None: delta = float("inf") self._delta = Seconds(delta) """The amount of time, in seconds, the interval spans."""