Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also compare across forks or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also compare across forks. Learn more about diff comparisons here.
base repository: iterative/dvc-task
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 0.2.0
Choose a base ref
...
head repository: iterative/dvc-task
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: 0.2.1
Choose a head ref
  • 2 commits
  • 3 files changed
  • 1 contributor

Commits on Apr 27, 2023

  1. Copy the full SHA
    e07e276 View commit details
  2. Copy the full SHA
    f23893b View commit details
Showing with 49 additions and 23 deletions.
  1. +25 −4 src/dvc_task/app/filesystem.py
  2. +23 −16 src/dvc_task/worker/temporary.py
  3. +1 −3 tests/worker/test_temporary.py
29 changes: 25 additions & 4 deletions src/dvc_task/app/filesystem.py
Original file line number Diff line number Diff line change
@@ -244,10 +244,10 @@ def _delete_expired(
cache: Dict[str, str],
include_tickets: bool = False,
):
assert isinstance(msg.properties, dict)
properties = cast(Dict[str, Any], msg.properties)
delivery_info: Dict[str, str] = properties.get("delivery_info", {})
if queues:
assert isinstance(msg.properties, dict)
properties = cast(Dict[str, Any], msg.properties)
delivery_info: Dict[str, str] = properties.get("delivery_info", {})
routing_key = delivery_info.get("routing_key")
if routing_key and routing_key in queues:
return
@@ -256,7 +256,10 @@ def _delete_expired(
ticket = msg.headers.get("ticket")
if include_tickets and ticket or (expires is not None and expires <= now):
assert msg.delivery_tag
self._delete_msg(msg.delivery_tag, [], cache)
try:
self._delete_msg(msg.delivery_tag, [], cache)
except ValueError:
pass

queues = set(exclude) if exclude else set()
now = datetime.now().timestamp()
@@ -270,3 +273,21 @@ def _delete_expired(
def clean(self):
    """Remove leftover celery messages from this FSApp.

    Calls ``self._gc`` with the default task queue passed as ``exclude``,
    then calls ``self._clean_pidbox`` for the
    ``reply.<default queue>.pidbox`` exchange.
    """
    default_queue = self.conf.task_default_queue
    self._gc(exclude=[default_queue])
    self._clean_pidbox(f"reply.{default_queue}.pidbox")

def _clean_pidbox(self, exchange: str):
    """Delete every stored message whose delivery exchange equals *exchange*.

    Iterates the messages yielded by ``self._iter_data_folder()`` and, for
    each one whose ``delivery_info["exchange"]`` matches, hands its delivery
    tag to ``self._delete_msg`` together with ``self._queued_msg_path_cache``.
    """
    for message in self._iter_data_folder():
        path_cache = self._queued_msg_path_cache
        assert isinstance(message.properties, dict)
        props = cast(Dict[str, Any], message.properties)
        info: Dict[str, str] = props.get("delivery_info", {})
        if info.get("exchange", "") != exchange:
            # Not a reply for the requested exchange; leave it alone.
            continue
        assert message.delivery_tag
        try:
            self._delete_msg(message.delivery_tag, [], path_cache)
        except ValueError:
            # Best-effort cleanup: a ValueError from _delete_msg is ignored
            # and the scan moves on to the next message.
            continue
39 changes: 23 additions & 16 deletions src/dvc_task/worker/temporary.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
import os
import threading
import time
from typing import Any, List, Mapping
from typing import Any, Dict, List, Mapping, Optional

from celery import Celery
from celery.utils.nodenames import default_nodename
@@ -36,6 +36,15 @@ def __init__( # pylint: disable=too-many-arguments
self.timeout = timeout
self.config = kwargs

def ping(self, name: str, timeout: float = 1.0) -> Optional[List[Dict[str, Any]]]:
    """Ping the worker identified by *name*.

    The name is passed through ``default_nodename`` and the result is used
    as the single destination handed to ``self._ping``; *timeout* is
    forwarded unchanged. Returns whatever ``self._ping`` returns.
    """
    target = default_nodename(name)
    return self._ping(destination=[target], timeout=timeout)

def _ping(
    self, *, destination: Optional[List[str]] = None, timeout: float = 1.0
) -> Optional[List[Dict[str, Any]]]:
    """Forward a ping request to ``self.app.control.ping``.

    Both keyword-only arguments are passed through as-is, and the value
    returned by ``control.ping`` is returned unchanged.
    """
    control = self.app.control
    replies = control.ping(destination=destination, timeout=timeout)
    return replies

def start(self, name: str, fsapp_clean: bool = False) -> None:
"""Start the worker if it does not already exist.
@@ -50,19 +59,22 @@ def start(self, name: str, fsapp_clean: bool = False) -> None:
# see https://github.com/celery/billiard/issues/247
os.environ["FORKED_BY_MULTIPROCESSING"] = "1"

if not self.app.control.ping(destination=[name]):
if not self.ping(name):
monitor = threading.Thread(
target=self.monitor,
daemon=True,
args=(name,),
kwargs={"fsapp_clean": fsapp_clean},
)
monitor.start()
config = dict(self.config)
config["hostname"] = name
argv = ["worker"]
argv.extend(self._parse_config(config))
self.app.worker_main(argv=argv)
if fsapp_clean and isinstance(self.app, FSApp): # type: ignore[unreachable]
logger.info("cleaning up FSApp broker.")
self.app.clean()
logger.info("done")

@staticmethod
def _parse_config(config: Mapping[str, Any]) -> List[str]:
@@ -85,13 +97,9 @@ def _parse_config(config: Mapping[str, Any]) -> List[str]:
argv.append("-E")
return argv

def monitor(self, name: str, fsapp_clean: bool = False) -> None:
def monitor(self, name: str) -> None:
"""Monitor the worker and stop it when the queue is empty."""
logger.debug("monitor: waiting for worker to start")
nodename = default_nodename(name)
while not self.app.control.ping(destination=[nodename]):
# wait for worker to start
time.sleep(1)

def _tasksets(nodes):
for taskset in (
@@ -105,17 +113,16 @@ def _tasksets(nodes):
if isinstance(self.app, FSApp):
yield from self.app.iter_queued()

logger.info("monitor: watching celery worker '%s'", nodename)
while self.app.control.ping(destination=[nodename]):
logger.debug("monitor: watching celery worker '%s'", nodename)
while True:
time.sleep(self.timeout)
nodes = self.app.control.inspect( # type: ignore[call-arg]
destination=[nodename]
destination=[nodename],
limit=1,
)
if nodes is None or not any(tasks for tasks in _tasksets(nodes)):
logger.info("monitor: shutting down due to empty queue.")
self.app.control.shutdown(destination=[nodename])
break
if fsapp_clean and isinstance(self.app, FSApp):
logger.info("monitor: cleanup FSApp broker.")
self.app.clean()
logger.info("monitor: done")
logger.debug("monitor: sending shutdown to '%s'.", nodename)
self.app.control.shutdown()
logger.debug("monitor: done")
4 changes: 1 addition & 3 deletions tests/worker/test_temporary.py
Original file line number Diff line number Diff line change
@@ -24,9 +24,7 @@ def test_start(celery_app: Celery, mocker: MockerFixture):
assert kwargs["pool"] == TaskPool
assert kwargs["concurrency"] == 1
assert kwargs["prefetch_multiplier"] == 1
thread.assert_called_once_with(
target=worker.monitor, daemon=True, args=(name,), kwargs={"fsapp_clean": False}
)
thread.assert_called_once_with(target=worker.monitor, daemon=True, args=(name,))


@pytest.mark.flaky(