|
1 |
| -import threading |
2 |
| -from getpass import getpass |
3 |
| -from typing import TYPE_CHECKING, BinaryIO, Union |
4 |
| - |
5 |
| -from dvc_objects.fs.base import AnyFSPath, FileSystem |
6 |
| -from dvc_objects.fs.callbacks import DEFAULT_CALLBACK, Callback |
7 |
| -from dvc_objects.fs.errors import ConfigError |
8 |
| -from funcy import cached_property, memoize, wrap_with |
9 |
| - |
10 |
| -if TYPE_CHECKING: |
11 |
| - from ssl import SSLContext |
12 |
| - |
13 |
| - |
14 |
| -@wrap_with(threading.Lock()) |
15 |
| -@memoize |
16 |
| -def ask_password(host, user): |
17 |
| - return getpass(f"Enter a password for host '{host}' user '{user}':\n") |
18 |
| - |
19 |
| - |
20 |
| -def make_context( |
21 |
| - ssl_verify: Union[bool, str, None] |
22 |
| -) -> Union["SSLContext", bool, None]: |
23 |
| - if isinstance(ssl_verify, bool) or ssl_verify is None: |
24 |
| - return ssl_verify |
25 |
| - |
26 |
| - # If this is a path, then we will create an |
27 |
| - # SSL context for it, and load the given certificate. |
28 |
| - import ssl |
29 |
| - |
30 |
| - context = ssl.create_default_context() |
31 |
| - context.load_verify_locations(ssl_verify) |
32 |
| - return context |
33 |
| - |
34 |
| - |
35 |
| -# pylint: disable=abstract-method |
36 |
| -class HTTPFileSystem(FileSystem): |
37 |
| - protocol = "http" |
38 |
| - PARAM_CHECKSUM = "checksum" |
39 |
| - REQUIRES = {"aiohttp": "aiohttp", "aiohttp-retry": "aiohttp_retry"} |
40 |
| - CAN_TRAVERSE = False |
41 |
| - |
42 |
| - SESSION_RETRIES = 5 |
43 |
| - SESSION_BACKOFF_FACTOR = 0.1 |
44 |
| - REQUEST_TIMEOUT = 60 |
45 |
| - |
46 |
| - def __init__(self, fs=None, timeout=REQUEST_TIMEOUT, **kwargs): |
47 |
| - super().__init__(fs, **kwargs) |
48 |
| - self.upload_method = kwargs.get("method", "POST") |
49 |
| - |
50 |
| - client_kwargs = self.fs_args.setdefault("client_kwargs", {}) |
51 |
| - client_kwargs.update( |
52 |
| - { |
53 |
| - "ssl_verify": kwargs.get("ssl_verify"), |
54 |
| - "read_timeout": kwargs.get("read_timeout", timeout), |
55 |
| - "connect_timeout": kwargs.get("connect_timeout", timeout), |
56 |
| - "trust_env": True, # Allow reading proxy configs from the env |
57 |
| - } |
58 |
| - ) |
59 |
| - |
60 |
| - def _prepare_credentials(self, **config): |
61 |
| - import aiohttp |
62 |
| - |
63 |
| - auth_method = config.get("auth") |
64 |
| - if not auth_method: |
65 |
| - return {} |
66 |
| - |
67 |
| - user = config.get("user") |
68 |
| - password = config.get("password") |
69 |
| - |
70 |
| - if password is None and config.get("ask_password"): |
71 |
| - password = ask_password(config.get("url"), user or "custom") |
72 |
| - |
73 |
| - client_kwargs = {} |
74 |
| - if auth_method == "basic": |
75 |
| - if user is None or password is None: |
76 |
| - raise ConfigError( |
77 |
| - "HTTP 'basic' authentication require both " |
78 |
| - "'user' and 'password'" |
79 |
| - ) |
80 |
| - client_kwargs["auth"] = aiohttp.BasicAuth(user, password) |
81 |
| - elif auth_method == "custom": |
82 |
| - custom_auth_header = config.get("custom_auth_header") |
83 |
| - if custom_auth_header is None or password is None: |
84 |
| - raise ConfigError( |
85 |
| - "HTTP 'custom' authentication require both " |
86 |
| - "'custom_auth_header' and 'password'" |
87 |
| - ) |
88 |
| - client_kwargs["headers"] = {custom_auth_header: password} |
89 |
| - else: |
90 |
| - raise NotImplementedError( |
91 |
| - f"Auth method {auth_method!r} is not supported." |
92 |
| - ) |
93 |
| - return {"client_kwargs": client_kwargs} |
94 |
| - |
95 |
| - async def get_client(self, **kwargs): |
96 |
| - import aiohttp |
97 |
| - from aiohttp_retry import ExponentialRetry |
98 |
| - |
99 |
| - from .retry import ReadOnlyRetryClient |
100 |
| - |
101 |
| - kwargs["retry_options"] = ExponentialRetry( |
102 |
| - attempts=self.SESSION_RETRIES, |
103 |
| - factor=self.SESSION_BACKOFF_FACTOR, |
104 |
| - max_timeout=self.REQUEST_TIMEOUT, |
105 |
| - exceptions={aiohttp.ClientError}, |
106 |
| - ) |
107 |
| - |
108 |
| - # The default total timeout for an aiohttp request is 300 seconds |
109 |
| - # which is too low for DVC's interactions when dealing with large |
110 |
| - # data blobs. We remove the total timeout, and only limit the time |
111 |
| - # that is spent when connecting to the remote server and waiting |
112 |
| - # for new data portions. |
113 |
| - connect_timeout = kwargs.pop("connect_timeout") |
114 |
| - kwargs["timeout"] = aiohttp.ClientTimeout( |
115 |
| - total=None, |
116 |
| - connect=connect_timeout, |
117 |
| - sock_connect=connect_timeout, |
118 |
| - sock_read=kwargs.pop("read_timeout"), |
119 |
| - ) |
120 |
| - |
121 |
| - kwargs["connector"] = aiohttp.TCPConnector( |
122 |
| - # Force cleanup of closed SSL transports. |
123 |
| - # See https://github.com/iterative/dvc/issues/7414 |
124 |
| - enable_cleanup_closed=True, |
125 |
| - ssl=make_context(kwargs.pop("ssl_verify", None)), |
126 |
| - ) |
127 |
| - |
128 |
| - return ReadOnlyRetryClient(**kwargs) |
129 |
| - |
130 |
| - @cached_property |
131 |
| - def fs(self): |
132 |
| - from fsspec.implementations.http import ( |
133 |
| - HTTPFileSystem as _HTTPFileSystem, |
134 |
| - ) |
135 |
| - |
136 |
| - return _HTTPFileSystem(get_client=self.get_client, **self.fs_args) |
137 |
| - |
138 |
| - def unstrip_protocol(self, path: str) -> str: |
139 |
| - return path |
140 |
| - |
141 |
| - def put_file( |
142 |
| - self, |
143 |
| - from_file: Union[AnyFSPath, BinaryIO], |
144 |
| - to_info: AnyFSPath, |
145 |
| - callback: Callback = DEFAULT_CALLBACK, |
146 |
| - size: int = None, |
147 |
| - **kwargs, |
148 |
| - ) -> None: |
149 |
| - kwargs.setdefault("method", self.upload_method) |
150 |
| - super().put_file( |
151 |
| - from_file, to_info, callback=callback, size=size, **kwargs |
152 |
| - ) |
153 |
| - |
154 |
| - # pylint: disable=arguments-differ |
155 |
| - |
156 |
| - def find(self, *args, **kwargs): |
157 |
| - raise NotImplementedError |
158 |
| - |
159 |
| - def isdir(self, *args, **kwargs): |
160 |
| - return False |
161 |
| - |
162 |
| - def ls(self, *args, **kwargs): |
163 |
| - raise NotImplementedError |
164 |
| - |
165 |
| - def walk(self, *args, **kwargs): |
166 |
| - raise NotImplementedError |
167 |
| - |
168 |
| - # pylint: enable=arguments-differ |
169 |
| - |
170 |
| - |
171 |
| -class HTTPSFileSystem(HTTPFileSystem): # pylint:disable=abstract-method |
172 |
| - protocol = "https" |
| 1 | +import threading |
| 2 | +from getpass import getpass |
| 3 | +from typing import TYPE_CHECKING, BinaryIO, Union |
| 4 | + |
| 5 | +from dvc_objects.fs.base import AnyFSPath, FileSystem |
| 6 | +from dvc_objects.fs.callbacks import DEFAULT_CALLBACK, Callback |
| 7 | +from dvc_objects.fs.errors import ConfigError |
| 8 | +from funcy import cached_property, memoize, wrap_with |
| 9 | + |
| 10 | +if TYPE_CHECKING: |
| 11 | + from ssl import SSLContext |
| 12 | + |
| 13 | + |
@wrap_with(threading.Lock())
@memoize
def ask_password(host, user):
    """Interactively prompt for the password of ``user`` on ``host``.

    The function is wrapped with a single shared ``threading.Lock`` and
    memoized, so concurrent callers are serialized and a given
    ``(host, user)`` pair is only prompted for once per process.
    """
    prompt = f"Enter a password for host '{host}' user '{user}':\n"
    return getpass(prompt)
| 18 | + |
| 19 | + |
def make_context(
    ssl_verify: Union[bool, str, None]
) -> Union["SSLContext", bool, None]:
    """Translate the ``ssl_verify`` option into a value usable as ``ssl=``.

    Args:
        ssl_verify: ``True``/``False``/``None`` are passed through
            untouched. Any other value is treated as a filesystem path to
            custom CA certificates: either a single bundle file or a
            directory of certificates.

    Returns:
        The unchanged ``ssl_verify`` when it is a bool or ``None``,
        otherwise a default ``ssl.SSLContext`` with the given
        certificates loaded as additional verify locations.

    Raises:
        Whatever ``SSLContext.load_verify_locations`` raises for an
        unreadable or invalid path.
    """
    if isinstance(ssl_verify, bool) or ssl_verify is None:
        return ssl_verify

    # Imported lazily so that plain bool/None configurations never pay for
    # (or depend on) the ssl module.
    import os
    import ssl

    context = ssl.create_default_context()
    if os.path.isdir(ssl_verify):
        # A directory cannot be loaded as ``cafile``; OpenSSL expects
        # certificate directories to be passed as ``capath`` instead.
        context.load_verify_locations(capath=ssl_verify)
    else:
        context.load_verify_locations(cafile=ssl_verify)
    return context
| 33 | + |
| 34 | + |
| 35 | +# pylint: disable=abstract-method |
class HTTPFileSystem(FileSystem):
    """Filesystem for the ``http`` protocol built on fsspec's HTTP backend.

    Listing-style operations (``find``, ``ls``, ``walk``) are not
    supported and ``isdir`` always answers ``False``.
    """

    protocol = "http"
    PARAM_CHECKSUM = "checksum"
    REQUIRES = {"aiohttp": "aiohttp", "aiohttp-retry": "aiohttp_retry"}
    CAN_TRAVERSE = False

    # Retry policy handed to ``ExponentialRetry`` in ``get_client``.
    SESSION_RETRIES = 5
    SESSION_BACKOFF_FACTOR = 0.1
    # Default connect/read timeout; also used as the retry ``max_timeout``.
    REQUEST_TIMEOUT = 60

    def __init__(self, fs=None, timeout=REQUEST_TIMEOUT, **kwargs):
        """Store the upload method and seed the client session options.

        Args:
            fs: Forwarded to the base ``FileSystem`` initializer.
            timeout: Fallback for both ``read_timeout`` and
                ``connect_timeout`` when they are not given in ``kwargs``.
            **kwargs: Filesystem configuration; ``method``, ``ssl_verify``,
                ``read_timeout`` and ``connect_timeout`` are read here.
        """
        super().__init__(fs, **kwargs)
        self.upload_method = kwargs.get("method", "POST")

        session_options = {
            "ssl_verify": kwargs.get("ssl_verify"),
            "read_timeout": kwargs.get("read_timeout", timeout),
            "connect_timeout": kwargs.get("connect_timeout", timeout),
            # Allow reading proxy configs from the env
            "trust_env": True,
        }
        self.fs_args.setdefault("client_kwargs", {}).update(session_options)

    def _prepare_credentials(self, **config):
        """Build the auth-related ``client_kwargs`` from ``config``.

        Returns an empty dict when no ``auth`` method is configured.

        Raises:
            ConfigError: a required credential for the chosen method is
                missing.
            NotImplementedError: the ``auth`` method is neither ``basic``
                nor ``custom``.
        """
        import aiohttp

        auth_method = config.get("auth")
        if not auth_method:
            return {}

        user = config.get("user")
        password = config.get("password")
        if password is None and config.get("ask_password"):
            # Fall back to an interactive prompt for the secret.
            password = ask_password(config.get("url"), user or "custom")

        if auth_method == "basic":
            if user is None or password is None:
                raise ConfigError(
                    "HTTP 'basic' authentication require both "
                    "'user' and 'password'"
                )
            basic_auth = aiohttp.BasicAuth(user, password)
            return {"client_kwargs": {"auth": basic_auth}}

        if auth_method == "custom":
            header_name = config.get("custom_auth_header")
            if header_name is None or password is None:
                raise ConfigError(
                    "HTTP 'custom' authentication require both "
                    "'custom_auth_header' and 'password'"
                )
            return {"client_kwargs": {"headers": {header_name: password}}}

        raise NotImplementedError(
            f"Auth method {auth_method!r} is not supported."
        )

    async def get_client(self, **kwargs):
        """Create the retrying HTTP client from the session options.

        ``connect_timeout`` and ``read_timeout`` must be present in
        ``kwargs``; ``ssl_verify`` is optional. All three are consumed
        here and the remaining options are forwarded to
        ``ReadOnlyRetryClient``.
        """
        import aiohttp
        from aiohttp_retry import ExponentialRetry

        from .retry import ReadOnlyRetryClient

        kwargs["retry_options"] = ExponentialRetry(
            attempts=self.SESSION_RETRIES,
            factor=self.SESSION_BACKOFF_FACTOR,
            max_timeout=self.REQUEST_TIMEOUT,
            exceptions={aiohttp.ClientError},
        )

        # aiohttp caps a whole request at 300 seconds by default, which is
        # too short for transferring large data blobs. Drop the total cap
        # and bound only the connection phase and the wait for each new
        # portion of data.
        connect_timeout = kwargs.pop("connect_timeout")
        read_timeout = kwargs.pop("read_timeout")
        kwargs["timeout"] = aiohttp.ClientTimeout(
            total=None,
            connect=connect_timeout,
            sock_connect=connect_timeout,
            sock_read=read_timeout,
        )

        ssl_option = make_context(kwargs.pop("ssl_verify", None))
        kwargs["connector"] = aiohttp.TCPConnector(
            # Force cleanup of closed SSL transports.
            # See https://github.com/iterative/dvc/issues/7414
            enable_cleanup_closed=True,
            ssl=ssl_option,
        )

        return ReadOnlyRetryClient(**kwargs)

    @cached_property
    def fs(self):
        """The underlying fsspec HTTP filesystem, created on first access."""
        from fsspec.implementations.http import (
            HTTPFileSystem as _HTTPFileSystem,
        )

        return _HTTPFileSystem(get_client=self.get_client, **self.fs_args)

    def unstrip_protocol(self, path: str) -> str:
        """Return ``path`` unchanged."""
        return path

    def put_file(
        self,
        from_file: Union[AnyFSPath, BinaryIO],
        to_info: AnyFSPath,
        callback: Callback = DEFAULT_CALLBACK,
        size: int = None,
        **kwargs,
    ) -> None:
        """Upload ``from_file`` to ``to_info`` via the base implementation.

        Uses the configured upload method unless the caller passes an
        explicit ``method`` keyword.
        """
        kwargs.setdefault("method", self.upload_method)
        super().put_file(
            from_file,
            to_info,
            callback=callback,
            size=size,
            **kwargs,
        )

    # pylint: disable=arguments-differ

    def find(self, *args, **kwargs):
        """Not supported by this filesystem."""
        raise NotImplementedError

    def isdir(self, *args, **kwargs):
        """Always report ``False``."""
        return False

    def ls(self, *args, **kwargs):
        """Not supported by this filesystem."""
        raise NotImplementedError

    def walk(self, *args, **kwargs):
        """Not supported by this filesystem."""
        raise NotImplementedError

    # pylint: enable=arguments-differ
| 169 | + |
| 170 | + |
class HTTPSFileSystem(HTTPFileSystem):  # pylint:disable=abstract-method
    """``HTTPFileSystem`` variant whose protocol is ``https``."""

    protocol = "https"
0 commit comments