diff --git a/dvc_http/__init__.py b/dvc_http/__init__.py index e68c07c..3f05464 100644 --- a/dvc_http/__init__.py +++ b/dvc_http/__init__.py @@ -43,59 +43,56 @@ class HTTPFileSystem(FileSystem): SESSION_BACKOFF_FACTOR = 0.1 REQUEST_TIMEOUT = 60 + def __init__(self, fs=None, timeout=REQUEST_TIMEOUT, **kwargs): + super().__init__(fs, **kwargs) + self.upload_method = kwargs.get("method", "POST") + + client_kwargs = self.fs_args.setdefault("client_kwargs", {}) + client_kwargs.update( + { + "ssl_verify": kwargs.get("ssl_verify"), + "read_timeout": kwargs.get("read_timeout", timeout), + "connect_timeout": kwargs.get("connect_timeout", timeout), + "trust_env": True, # Allow reading proxy configs from the env + } + ) + def _prepare_credentials(self, **config): import aiohttp - credentials = {} - client_kwargs = credentials.setdefault("client_kwargs", {}) - - if config.get("auth"): - user = config.get("user") - password = config.get("password") - custom_auth_header = config.get("custom_auth_header") - - if password is None and config.get("ask_password"): - password = ask_password(config.get("url"), user or "custom") - - auth_method = config["auth"] - if auth_method == "basic": - if user is None or password is None: - raise ConfigError( - "HTTP 'basic' authentication require both " - "'user' and 'password'" - ) - - client_kwargs["auth"] = aiohttp.BasicAuth(user, password) - elif auth_method == "custom": - if custom_auth_header is None or password is None: - raise ConfigError( - "HTTP 'custom' authentication require both " - "'custom_auth_header' and 'password'" - ) - credentials["headers"] = {custom_auth_header: password} - else: - raise NotImplementedError( - f"Auth method {auth_method!r} is not supported." - ) - - if "ssl_verify" in config: - client_kwargs["ssl_verify"] = config["ssl_verify"] + auth_method = config.get("auth") + if not auth_method: + return {} - for timeout in ("connect_timeout", "read_timeout"): - if timeout in config: - client_kwargs[timeout] = config.get(timeout) + user = config.get("user") + password = config.get("password") - # Allow reading proxy configurations from the environment. - client_kwargs["trust_env"] = True + if password is None and config.get("ask_password"): + password = ask_password(config.get("url"), user or "custom") - credentials["get_client"] = self.get_client - self.upload_method = config.get("method", "POST") - return credentials - - async def get_client( - self, - **kwargs, - ): + client_kwargs = {} + if auth_method == "basic": + if user is None or password is None: + raise ConfigError( + "HTTP 'basic' authentication require both " + "'user' and 'password'" + ) + client_kwargs["auth"] = aiohttp.BasicAuth(user, password) + elif auth_method == "custom": + custom_auth_header = config.get("custom_auth_header") + if custom_auth_header is None or password is None: + raise ConfigError( + "HTTP 'custom' authentication require both " + "'custom_auth_header' and 'password'" + ) + client_kwargs["headers"] = {custom_auth_header: password} + else: + raise NotImplementedError( + f"Auth method {auth_method!r} is not supported." + ) + return {"client_kwargs": client_kwargs} + + async def get_client(self, **kwargs): import aiohttp from aiohttp_retry import ExponentialRetry @@ -113,12 +110,12 @@ async def get_client( # data blobs. We remove the total timeout, and only limit the time # that is spent when connecting to the remote server and waiting # for new data portions. - connect_timeout = kwargs.pop("connect_timeout", self.REQUEST_TIMEOUT) + connect_timeout = kwargs.pop("connect_timeout") kwargs["timeout"] = aiohttp.ClientTimeout( total=None, connect=connect_timeout, sock_connect=connect_timeout, - sock_read=kwargs.pop("read_timeout", self.REQUEST_TIMEOUT), + sock_read=kwargs.pop("read_timeout"), ) kwargs["connector"] = aiohttp.TCPConnector( @@ -136,7 +133,7 @@ def fs(self): HTTPFileSystem as _HTTPFileSystem, ) - return _HTTPFileSystem(**self.fs_args) + return _HTTPFileSystem(get_client=self.get_client, **self.fs_args) def unstrip_protocol(self, path: str) -> str: return path diff --git a/dvc_http/tests/test_config.py b/dvc_http/tests/test_config.py index a75fefd..86cd848 100644 --- a/dvc_http/tests/test_config.py +++ b/dvc_http/tests/test_config.py @@ -50,7 +50,7 @@ def test_custom_auth_method(): fs = HTTPFileSystem(**config) - headers = fs.fs_args["headers"] + headers = fs.fs_args["client_kwargs"]["headers"] assert header in headers assert headers[header] == password