Skip to content

Commit 09cf743

Browse files
authored
split credentials and config handling, set custom auth header to the client (#37)
1 parent f2e536e commit 09cf743

File tree

2 files changed

+48
-51
lines changed

2 files changed

+48
-51
lines changed

dvc_http/__init__.py

+47-50
Original file line numberDiff line numberDiff line change
@@ -43,59 +43,56 @@ class HTTPFileSystem(FileSystem):
4343
SESSION_BACKOFF_FACTOR = 0.1
4444
REQUEST_TIMEOUT = 60
4545

46+
def __init__(self, fs=None, timeout=REQUEST_TIMEOUT, **kwargs):
47+
super().__init__(fs, **kwargs)
48+
self.upload_method = kwargs.get("method", "POST")
49+
50+
client_kwargs = self.fs_args.setdefault("client_kwargs", {})
51+
client_kwargs.update(
52+
{
53+
"ssl_verify": kwargs.get("ssl_verify"),
54+
"read_timeout": kwargs.get("read_timeout", timeout),
55+
"connect_timeout": kwargs.get("connect_timeout", timeout),
56+
"trust_env": True, # Allow reading proxy configs from the env
57+
}
58+
)
59+
4660
def _prepare_credentials(self, **config):
4761
import aiohttp
4862

49-
credentials = {}
50-
client_kwargs = credentials.setdefault("client_kwargs", {})
51-
52-
if config.get("auth"):
53-
user = config.get("user")
54-
password = config.get("password")
55-
custom_auth_header = config.get("custom_auth_header")
56-
57-
if password is None and config.get("ask_password"):
58-
password = ask_password(config.get("url"), user or "custom")
59-
60-
auth_method = config["auth"]
61-
if auth_method == "basic":
62-
if user is None or password is None:
63-
raise ConfigError(
64-
"HTTP 'basic' authentication require both "
65-
"'user' and 'password'"
66-
)
67-
68-
client_kwargs["auth"] = aiohttp.BasicAuth(user, password)
69-
elif auth_method == "custom":
70-
if custom_auth_header is None or password is None:
71-
raise ConfigError(
72-
"HTTP 'custom' authentication require both "
73-
"'custom_auth_header' and 'password'"
74-
)
75-
credentials["headers"] = {custom_auth_header: password}
76-
else:
77-
raise NotImplementedError(
78-
f"Auth method {auth_method!r} is not supported."
79-
)
80-
81-
if "ssl_verify" in config:
82-
client_kwargs["ssl_verify"] = config["ssl_verify"]
63+
auth_method = config.get("auth")
64+
if not auth_method:
65+
return {}
8366

84-
for timeout in ("connect_timeout", "read_timeout"):
85-
if timeout in config:
86-
client_kwargs[timeout] = config.get(timeout)
67+
user = config.get("user")
68+
password = config.get("password")
8769

88-
# Allow reading proxy configurations from the environment.
89-
client_kwargs["trust_env"] = True
70+
if password is None and config.get("ask_password"):
71+
password = ask_password(config.get("url"), user or "custom")
9072

91-
credentials["get_client"] = self.get_client
92-
self.upload_method = config.get("method", "POST")
93-
return credentials
94-
95-
async def get_client(
96-
self,
97-
**kwargs,
98-
):
73+
client_kwargs = {}
74+
if auth_method == "basic":
75+
if user is None or password is None:
76+
raise ConfigError(
77+
"HTTP 'basic' authentication require both "
78+
"'user' and 'password'"
79+
)
80+
client_kwargs["auth"] = aiohttp.BasicAuth(user, password)
81+
elif auth_method == "custom":
82+
custom_auth_header = config.get("custom_auth_header")
83+
if custom_auth_header is None or password is None:
84+
raise ConfigError(
85+
"HTTP 'custom' authentication require both "
86+
"'custom_auth_header' and 'password'"
87+
)
88+
client_kwargs["headers"] = {custom_auth_header: password}
89+
else:
90+
raise NotImplementedError(
91+
f"Auth method {auth_method!r} is not supported."
92+
)
93+
return {"client_kwargs": client_kwargs}
94+
95+
async def get_client(self, **kwargs):
9996
import aiohttp
10097
from aiohttp_retry import ExponentialRetry
10198

@@ -113,12 +110,12 @@ async def get_client(
113110
# data blobs. We remove the total timeout, and only limit the time
114111
# that is spent when connecting to the remote server and waiting
115112
# for new data portions.
116-
connect_timeout = kwargs.pop("connect_timeout", self.REQUEST_TIMEOUT)
113+
connect_timeout = kwargs.pop("connect_timeout")
117114
kwargs["timeout"] = aiohttp.ClientTimeout(
118115
total=None,
119116
connect=connect_timeout,
120117
sock_connect=connect_timeout,
121-
sock_read=kwargs.pop("read_timeout", self.REQUEST_TIMEOUT),
118+
sock_read=kwargs.pop("read_timeout"),
122119
)
123120

124121
kwargs["connector"] = aiohttp.TCPConnector(
@@ -136,7 +133,7 @@ def fs(self):
136133
HTTPFileSystem as _HTTPFileSystem,
137134
)
138135

139-
return _HTTPFileSystem(**self.fs_args)
136+
return _HTTPFileSystem(get_client=self.get_client, **self.fs_args)
140137

141138
def unstrip_protocol(self, path: str) -> str:
142139
return path

dvc_http/tests/test_config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_custom_auth_method():
5050

5151
fs = HTTPFileSystem(**config)
5252

53-
headers = fs.fs_args["headers"]
53+
headers = fs.fs_args["client_kwargs"]["headers"]
5454
assert header in headers
5555
assert headers[header] == password
5656

0 commit comments

Comments
 (0)