10
10
"""Definition of the ``SqliteTempBackend`` backend."""
11
11
from __future__ import annotations
12
12
13
- from contextlib import contextmanager
13
+ from contextlib import contextmanager , nullcontext
14
+ import functools
14
15
from functools import cached_property
15
- from typing import Any , Iterator , Sequence
16
+ import hashlib
17
+ import os
18
+ from pathlib import Path
19
+ import shutil
20
+ from typing import Any , BinaryIO , Iterator , Sequence
16
21
17
22
from sqlalchemy .orm import Session
18
23
19
- from aiida .common .exceptions import ClosedStorage
20
- from aiida .manage import Profile , get_config_option
24
+ from aiida .common .exceptions import ClosedStorage , IntegrityError
25
+ from aiida .manage import Profile
21
26
from aiida .orm .entities import EntityTypes
22
27
from aiida .orm .implementation import BackendEntity , StorageBackend
23
28
from aiida .repository .backend .sandbox import SandboxRepositoryBackend
@@ -42,7 +47,8 @@ def create_profile(
42
47
name : str = 'temp' ,
43
48
default_user_email = 'user@email.com' ,
44
49
options : dict | None = None ,
45
- debug : bool = False
50
+ debug : bool = False ,
51
+ repo_path : str | Path | None = None ,
46
52
) -> Profile :
47
53
"""Create a new profile instance for this backend, from the path to the zip file."""
48
54
return Profile (
@@ -52,6 +58,7 @@ def create_profile(
52
58
'backend' : 'core.sqlite_temp' ,
53
59
'config' : {
54
60
'debug' : debug ,
61
+ 'repo_path' : repo_path ,
55
62
}
56
63
},
57
64
'process_control' : {
@@ -81,7 +88,7 @@ def migrate(cls, profile: Profile):
81
88
def __init__ (self , profile : Profile ):
82
89
super ().__init__ (profile )
83
90
self ._session : Session | None = None
84
- self ._repo : SandboxRepositoryBackend | None = None
91
+ self ._repo : SandboxShaRepositoryBackend | None = None
85
92
self ._globals : dict [str , tuple [Any , str | None ]] = {}
86
93
self ._closed = False
87
94
self .get_session () # load the database on initialization
@@ -124,12 +131,13 @@ def get_session(self) -> Session:
124
131
self ._session .commit ()
125
132
return self ._session
126
133
127
- def get_repository (self ) -> SandboxRepositoryBackend :
134
+ def get_repository (self ) -> SandboxShaRepositoryBackend :
128
135
if self ._closed :
129
136
raise ClosedStorage (str (self ))
130
137
if self ._repo is None :
131
138
# to-do this does not seem to be removing the folder on garbage collection?
132
- self ._repo = SandboxRepositoryBackend (filepath = get_config_option ('storage.sandbox' ) or None )
139
+ repo_path = self .profile .storage_config .get ('repo_path' )
140
+ self ._repo = SandboxShaRepositoryBackend (filepath = Path (repo_path ) if repo_path else None )
133
141
return self ._repo
134
142
135
143
@property
@@ -199,11 +207,122 @@ def get_info(self, detailed: bool = False) -> dict:
199
207
# results['repository'] = self.get_repository().get_info(detailed)
200
208
return results
201
209
210
+ @staticmethod
211
+ @functools .lru_cache (maxsize = 18 )
212
+ def _get_mapper_from_entity (entity_type : EntityTypes , with_pk : bool ):
213
+ """Return the Sqlalchemy mapper and fields corresponding to the given entity.
214
+
215
+ :param with_pk: if True, the fields returned will include the primary key
216
+ """
217
+ from sqlalchemy import inspect
218
+
219
+ from aiida .storage .sqlite_zip .models import (
220
+ DbAuthInfo ,
221
+ DbComment ,
222
+ DbComputer ,
223
+ DbGroup ,
224
+ DbGroupNodes ,
225
+ DbLink ,
226
+ DbLog ,
227
+ DbNode ,
228
+ DbUser ,
229
+ )
230
+
231
+ model = {
232
+ EntityTypes .AUTHINFO : DbAuthInfo ,
233
+ EntityTypes .COMMENT : DbComment ,
234
+ EntityTypes .COMPUTER : DbComputer ,
235
+ EntityTypes .GROUP : DbGroup ,
236
+ EntityTypes .LOG : DbLog ,
237
+ EntityTypes .NODE : DbNode ,
238
+ EntityTypes .USER : DbUser ,
239
+ EntityTypes .LINK : DbLink ,
240
+ EntityTypes .GROUP_NODE : DbGroupNodes ,
241
+ }[entity_type ]
242
+ mapper = inspect (model ).mapper
243
+ keys = {key for key , col in mapper .c .items () if with_pk or col not in mapper .primary_key }
244
+ return mapper , keys
245
+
202
246
def bulk_insert (self , entity_type : EntityTypes , rows : list [dict ], allow_defaults : bool = False ) -> list [int ]:
203
- raise NotImplementedError
247
+ mapper , keys = self ._get_mapper_from_entity (entity_type , False )
248
+ if not rows :
249
+ return []
250
+ if entity_type in (EntityTypes .COMPUTER , EntityTypes .LOG , EntityTypes .AUTHINFO ):
251
+ for row in rows :
252
+ row ['_metadata' ] = row .pop ('metadata' )
253
+ if allow_defaults :
254
+ for row in rows :
255
+ if not keys .issuperset (row ):
256
+ raise IntegrityError (f'Incorrect fields given for { entity_type } : { set (row )} not subset of { keys } ' )
257
+ else :
258
+ for row in rows :
259
+ if set (row ) != keys :
260
+ raise IntegrityError (f'Incorrect fields given for { entity_type } : { set (row )} != { keys } ' )
261
+ session = self .get_session ()
262
+ with (nullcontext () if self .in_transaction else self .transaction ()):
263
+ session .bulk_insert_mappings (mapper , rows , render_nulls = True , return_defaults = True )
264
+ return [row ['id' ] for row in rows ]
204
265
205
266
def bulk_update (self , entity_type : EntityTypes , rows : list [dict ]) -> None :
206
- raise NotImplementedError
267
+ mapper , keys = self ._get_mapper_from_entity (entity_type , True )
268
+ if not rows :
269
+ return None
270
+ for row in rows :
271
+ if 'id' not in row :
272
+ raise IntegrityError (f"'id' field not given for { entity_type } : { set (row )} " )
273
+ if not keys .issuperset (row ):
274
+ raise IntegrityError (f'Incorrect fields given for { entity_type } : { set (row )} not subset of { keys } ' )
275
+ session = self .get_session ()
276
+ with (nullcontext () if self .in_transaction else self .transaction ()):
277
+ session .bulk_update_mappings (mapper , rows )
207
278
208
279
    def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]):
        """Delete the nodes with the given primary keys, including their links and group membership.

        Not supported by this backend.

        :param pks_to_delete: the primary keys of the nodes to delete.
        :raises NotImplementedError: always; node deletion is not implemented for this backend.
        """
        raise NotImplementedError
281
+
282
+
283
class SandboxShaRepositoryBackend(SandboxRepositoryBackend):
    """A sandbox repository backend that uses the sha256 of the file as the key.

    This allows for compatibility with the archive format (i.e. `SqliteZipBackend`).
    Which allows for temporary profiles to be exported and imported.
    """

    @property
    def key_format(self) -> str | None:
        """Return the format used for object keys (sha256 hex digests)."""
        return 'sha256'

    def get_object_hash(self, key: str) -> str:
        """Return the hash of an object; keys are already sha256 hashes."""
        return key

    def _put_object_from_filelike(self, handle: BinaryIO) -> str:
        """Store the byte contents of a file in the repository.

        :param handle: filelike object with the byte content to be stored.
        :return: the generated fully qualified identifier for the object within the repository.
        :raises TypeError: if the handle is not a byte stream.
        """
        # First pass: stream the contents through sha256 to derive the key.
        digest = hashlib.sha256()
        start = handle.tell()
        while chunk := handle.read(1024 * 1024):
            digest.update(chunk)
        key = digest.hexdigest()

        target_path = os.path.join(self.sandbox.abspath, key)
        # Storage is content-addressed: identical contents share one file, so
        # only rewind and copy when no object with this hash exists yet.
        if not os.path.exists(target_path):
            handle.seek(start)
            with open(target_path, 'wb') as target:
                shutil.copyfileobj(handle, target)

        return key

    def get_info(self, detailed: bool = False, **kwargs) -> dict:
        """Return summary information about the repository (object count only)."""
        return {'objects': {'count': len(list(self.list_objects()))}}

    def maintain(self, dry_run: bool = False, live: bool = True, **kwargs) -> None:
        """No maintenance is required for this sandbox-based backend."""
        pass
0 commit comments