From b80c14cf85fcb8ee765018c050490f7190acbb3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Fri, 6 Jan 2023 13:25:56 +0545 Subject: [PATCH] tests: use threaded http server --- dvc_http/tests/fixtures.py | 6 +++--- dvc_http/tests/httpd.py | 26 ++++++-------------------- 2 files changed, 9 insertions(+), 23 deletions(-) diff --git a/dvc_http/tests/fixtures.py b/dvc_http/tests/fixtures.py index f4c9089..e663322 100644 --- a/dvc_http/tests/fixtures.py +++ b/dvc_http/tests/fixtures.py @@ -4,14 +4,14 @@ import pytest from .cloud import HTTP -from .httpd import StaticFileServer +from .httpd import static_file_server @pytest.fixture(scope="session") def http_server(tmp_path_factory): directory = os.fspath(tmp_path_factory.mktemp("http")) - with StaticFileServer(directory=directory) as httpd: - yield httpd + with static_file_server(directory) as server: + yield server @pytest.fixture diff --git a/dvc_http/tests/httpd.py b/dvc_http/tests/httpd.py index 3a657f8..ad2b426 100644 --- a/dvc_http/tests/httpd.py +++ b/dvc_http/tests/httpd.py @@ -2,8 +2,9 @@ import os import threading from contextlib import contextmanager +from functools import partial from http import HTTPStatus -from http.server import HTTPServer +from http.server import ThreadingHTTPServer from RangeHTTPServer import RangeRequestHandler @@ -67,22 +68,7 @@ def run_server_on_thread(server): server.server_close() -class StaticFileServer: - _lock = threading.Lock() - - def __init__(self, directory): - from functools import partial - - addr = ("localhost", 0) - req = partial(TestRequestHandler, directory=directory) - server = HTTPServer(addr, req) - self.runner = run_server_on_thread(server) - - # pylint: disable=no-member - def __enter__(self): - self._lock.acquire() - return self.runner.__enter__() - - def __exit__(self, *args): - self.runner.__exit__(*args) - self._lock.release() +def static_file_server(directory): + req = partial(TestRequestHandler, directory=directory) + server = ThreadingHTTPServer(("localhost", 0), req) + return run_server_on_thread(server)