
Commit 83fc5cf

Authored Sep 17, 2023
SqliteTempBackend: Add support for reading from and writing to archives (#5658)
To this end, the `bulk_insert` and `bulk_update` methods are implemented. The archive creation and import functionality currently requires that the repository of the storage backend uses SHA256 hashes as the keys of its objects. This is not the case for the `SandboxRepositoryBackend` that the `SqliteTempBackend` uses, so it is subclassed as `SandboxShaRepositoryBackend`, which replaces the UUID keys of its parent with SHA256 hashes.
1 parent: 7a3f108

3 files changed: +178 −12 lines
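For orientation, the round trip this commit enables looks roughly as follows. This is a minimal sketch mirroring the new test at the bottom of the commit, with placeholder paths:

```python
from aiida import orm
from aiida.storage.sqlite_temp import SqliteTempBackend
from aiida.tools.archive import create_archive, import_archive

# create a temporary profile and backend, using the new repo_path option
profile = SqliteTempBackend.create_profile(repo_path='/tmp/repo1')
backend = SqliteTempBackend(profile)

# store some data against this backend
orm.Dict({'key': 'value'}, backend=backend).store()

# export the entire temporary profile to an archive file ...
create_archive(None, backend=backend, filename='/tmp/export.aiida')

# ... and import it into a second, fresh temporary backend
backend2 = SqliteTempBackend(SqliteTempBackend.create_profile(repo_path='/tmp/repo2'))
import_archive('/tmp/export.aiida', backend=backend2)
```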
 

aiida/storage/sqlite_temp/backend.py (+129 −10)
@@ -10,14 +10,19 @@
 """Definition of the ``SqliteTempBackend`` backend."""
 from __future__ import annotations
 
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
+import functools
 from functools import cached_property
-from typing import Any, Iterator, Sequence
+import hashlib
+import os
+from pathlib import Path
+import shutil
+from typing import Any, BinaryIO, Iterator, Sequence
 
 from sqlalchemy.orm import Session
 
-from aiida.common.exceptions import ClosedStorage
-from aiida.manage import Profile, get_config_option
+from aiida.common.exceptions import ClosedStorage, IntegrityError
+from aiida.manage import Profile
 from aiida.orm.entities import EntityTypes
 from aiida.orm.implementation import BackendEntity, StorageBackend
 from aiida.repository.backend.sandbox import SandboxRepositoryBackend
@@ -42,7 +47,8 @@ def create_profile(
         name: str = 'temp',
         default_user_email='user@email.com',
         options: dict | None = None,
-        debug: bool = False
+        debug: bool = False,
+        repo_path: str | Path | None = None,
     ) -> Profile:
         """Create a new profile instance for this backend, from the path to the zip file."""
         return Profile(
@@ -52,6 +58,7 @@ def create_profile(
                 'backend': 'core.sqlite_temp',
                 'config': {
                     'debug': debug,
+                    'repo_path': repo_path,
                 }
             },
             'process_control': {
@@ -81,7 +88,7 @@ def migrate(cls, profile: Profile):
     def __init__(self, profile: Profile):
         super().__init__(profile)
         self._session: Session | None = None
-        self._repo: SandboxRepositoryBackend | None = None
+        self._repo: SandboxShaRepositoryBackend | None = None
         self._globals: dict[str, tuple[Any, str | None]] = {}
         self._closed = False
         self.get_session()  # load the database on initialization
@@ -124,12 +131,13 @@ def get_session(self) -> Session:
             self._session.commit()
         return self._session
 
-    def get_repository(self) -> SandboxRepositoryBackend:
+    def get_repository(self) -> SandboxShaRepositoryBackend:
         if self._closed:
             raise ClosedStorage(str(self))
         if self._repo is None:
             # to-do this does not seem to be removing the folder on garbage collection?
-            self._repo = SandboxRepositoryBackend(filepath=get_config_option('storage.sandbox') or None)
+            repo_path = self.profile.storage_config.get('repo_path')
+            self._repo = SandboxShaRepositoryBackend(filepath=Path(repo_path) if repo_path else None)
         return self._repo
 
     @property
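The `repo_path` option added above flows from `create_profile` into the profile's `storage_config`, from which `get_repository` lazily constructs the SHA256-keyed sandbox (the `SandboxShaRepositoryBackend` defined further down in this diff). A small sketch of the observable behaviour, assuming only what these hunks define:

```python
from aiida.storage.sqlite_temp import SqliteTempBackend

# the repo_path ends up in the profile's storage configuration
profile = SqliteTempBackend.create_profile(repo_path='/tmp/my_repo')
assert profile.storage_config['repo_path'] == '/tmp/my_repo'

# the repository is created on first access and uses sha256 keys
backend = SqliteTempBackend(profile)
repo = backend.get_repository()  # a SandboxShaRepositoryBackend
assert repo.key_format == 'sha256'
```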
@@ -199,11 +207,122 @@ def get_info(self, detailed: bool = False) -> dict:
         # results['repository'] = self.get_repository().get_info(detailed)
         return results
 
+    @staticmethod
+    @functools.lru_cache(maxsize=18)
+    def _get_mapper_from_entity(entity_type: EntityTypes, with_pk: bool):
+        """Return the Sqlalchemy mapper and fields corresponding to the given entity.
+
+        :param with_pk: if True, the fields returned will include the primary key
+        """
+        from sqlalchemy import inspect
+
+        from aiida.storage.sqlite_zip.models import (
+            DbAuthInfo,
+            DbComment,
+            DbComputer,
+            DbGroup,
+            DbGroupNodes,
+            DbLink,
+            DbLog,
+            DbNode,
+            DbUser,
+        )
+
+        model = {
+            EntityTypes.AUTHINFO: DbAuthInfo,
+            EntityTypes.COMMENT: DbComment,
+            EntityTypes.COMPUTER: DbComputer,
+            EntityTypes.GROUP: DbGroup,
+            EntityTypes.LOG: DbLog,
+            EntityTypes.NODE: DbNode,
+            EntityTypes.USER: DbUser,
+            EntityTypes.LINK: DbLink,
+            EntityTypes.GROUP_NODE: DbGroupNodes,
+        }[entity_type]
+        mapper = inspect(model).mapper
+        keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key}
+        return mapper, keys
+
     def bulk_insert(self, entity_type: EntityTypes, rows: list[dict], allow_defaults: bool = False) -> list[int]:
-        raise NotImplementedError
+        mapper, keys = self._get_mapper_from_entity(entity_type, False)
+        if not rows:
+            return []
+        if entity_type in (EntityTypes.COMPUTER, EntityTypes.LOG, EntityTypes.AUTHINFO):
+            for row in rows:
+                row['_metadata'] = row.pop('metadata')
+        if allow_defaults:
+            for row in rows:
+                if not keys.issuperset(row):
+                    raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
+        else:
+            for row in rows:
+                if set(row) != keys:
+                    raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {keys}')
+        session = self.get_session()
+        with (nullcontext() if self.in_transaction else self.transaction()):
+            session.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True)
+        return [row['id'] for row in rows]
 
     def bulk_update(self, entity_type: EntityTypes, rows: list[dict]) -> None:
-        raise NotImplementedError
+        mapper, keys = self._get_mapper_from_entity(entity_type, True)
+        if not rows:
+            return None
+        for row in rows:
+            if 'id' not in row:
+                raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}")
+            if not keys.issuperset(row):
+                raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
+        session = self.get_session()
+        with (nullcontext() if self.in_transaction else self.transaction()):
+            session.bulk_update_mappings(mapper, rows)
 
     def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]):
         raise NotImplementedError
+
+
+class SandboxShaRepositoryBackend(SandboxRepositoryBackend):
+    """A sandbox repository backend that uses the sha256 of the file as the key.
+
+    This allows for compatibility with the archive format (i.e. `SqliteZipBackend`).
+    Which allows for temporary profiles to be exported and imported.
+    """
+
+    @property
+    def key_format(self) -> str | None:
+        return 'sha256'
+
+    def get_object_hash(self, key: str) -> str:
+        return key
+
+    def _put_object_from_filelike(self, handle: BinaryIO) -> str:
+        """Store the byte contents of a file in the repository.
+
+        :param handle: filelike object with the byte content to be stored.
+        :return: the generated fully qualified identifier for the object within the repository.
+        :raises TypeError: if the handle is not a byte stream.
+        """
+        # we first compute the hash of the file contents
+        hsh = hashlib.sha256()
+        position = handle.tell()
+        while True:
+            buf = handle.read(1024 * 1024)
+            if not buf:
+                break
+            hsh.update(buf)
+        key = hsh.hexdigest()
+
+        filepath = os.path.join(self.sandbox.abspath, key)
+        if not os.path.exists(filepath):
+            # if a file with this hash does not already exist
+            # then we reset the file pointer and copy the contents
+            handle.seek(position)
+            with open(filepath, 'wb') as target:
+                shutil.copyfileobj(handle, target)
+
+        return key
+
+    def get_info(self, detailed: bool = False, **kwargs) -> dict:
+        return {'objects': {'count': len(list(self.list_objects()))}}
+
+    def maintain(self, dry_run: bool = False, live: bool = True, **kwargs) -> None:
+        pass
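A usage sketch for the new bulk methods: `_get_mapper_from_entity` derives the permitted fields from the SQLAlchemy model, `bulk_insert` validates each row against exactly those fields (unless `allow_defaults` is passed) and returns the assigned primary keys, and `bulk_update` additionally requires an `id` in every row. This example assumes `backend` is an open `SqliteTempBackend` and that the non-primary-key columns of `DbUser` are `email`, `first_name`, `last_name` and `institution`:

```python
from aiida.orm.entities import EntityTypes

# without allow_defaults, each row must carry exactly the mapper's
# non-primary-key fields (assumed here for DbUser)
pks = backend.bulk_insert(
    EntityTypes.USER,
    [{'email': 'new@email.com', 'first_name': '', 'last_name': '', 'institution': ''}],
)

# bulk_update needs 'id' plus any subset of the remaining columns
backend.bulk_update(EntityTypes.USER, [{'id': pks[0], 'first_name': 'Jane'}])
```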
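The content addressing in `_put_object_from_filelike` is what makes the repository archive-compatible: the key is a streamed SHA256 of the contents, so identical files collapse to a single stored object. A self-contained sketch of the same key computation, using only the standard library:

```python
import hashlib
import io

def sha256_key(handle, chunk_size=1024 * 1024):
    """Stream a binary handle in chunks and return its hex SHA256 digest."""
    hsh = hashlib.sha256()
    while True:
        buf = handle.read(chunk_size)
        if not buf:
            break
        hsh.update(buf)
    return hsh.hexdigest()

# identical contents always yield the same key, so duplicates are stored once
assert sha256_key(io.BytesIO(b'test')) == sha256_key(io.BytesIO(b'test'))
```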

aiida/tools/archive/imports.py (+2 −2)
@@ -778,7 +778,7 @@ def _transform(row):
     def _transform(row):
         # to-do this is probably not the most efficient way to do this
         uuid, new_mtime, new_comment = row
-        cmt = orm.Comment.collection.get(uuid=uuid)
+        cmt = orm.comments.CommentCollection(orm.Comment, backend).get(uuid=uuid)
         if cmt.mtime < new_mtime:
             cmt.set_mtime(new_mtime)
             cmt.set_content(new_comment)
@@ -1086,7 +1086,7 @@ def _make_import_group(
             break
         else:
             raise ImportUniquenessError(f'New import Group has existing label {label!r} and re-labelling failed')
-        dummy_orm = orm.ImportGroup(label)
+        dummy_orm = orm.ImportGroup(label, backend=backend_to)
         row = {
            'label': label,
            'description': 'Group generated by archive import',
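Both changes in this file apply the same fix: during import, ORM objects must be bound to the destination backend explicitly rather than to the default profile's storage. A sketch of the pattern, where `backend`, `backend_to` and `comment_uuid` stand in for values from the surrounding import code:

```python
from aiida import orm

# a collection can be instantiated against an explicit backend,
# which is what the comment handling now does
comment = orm.comments.CommentCollection(orm.Comment, backend).get(uuid=comment_uuid)

# entities likewise accept the target backend directly
group = orm.ImportGroup('import-label', backend=backend_to)
```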

tests/storage/sqlite/test_archive.py (new file, +47)
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+"""Test export and import of AiiDA archives to/from a temporary profile."""
+from pathlib import Path
+
+from aiida import orm
+from aiida.storage.sqlite_temp import SqliteTempBackend
+from aiida.tools.archive import create_archive, import_archive
+
+
+def test_basic(tmp_path):
+    """Test the creation of an archive and re-import."""
+    filename = Path(tmp_path / 'export.aiida')
+
+    # generate a temporary backend
+    profile1 = SqliteTempBackend.create_profile(repo_path=str(tmp_path / 'repo1'))
+    backend1 = SqliteTempBackend(profile1)
+
+    # add simple node
+    dict_data = {'key1': 'value1'}
+    node = orm.Dict(dict_data, backend=backend1).store()
+    # add a comment to the node
+    node.base.comments.add('test comment', backend1.default_user)
+    # add node with repository data
+    path = Path(tmp_path / 'test.txt')
+    text_data = 'test'
+    path.write_text(text_data, encoding='utf-8')
+    orm.SinglefileData(str(path), backend=backend1).store()
+
+    # export to archive
+    create_archive(None, backend=backend1, filename=filename)
+
+    # create a new temporary backend and import
+    profile2 = SqliteTempBackend.create_profile(repo_path=str(tmp_path / 'repo2'))
+    backend2 = SqliteTempBackend(profile2)
+    import_archive(filename, backend=backend2)
+
+    # check that the nodes are there
+    assert orm.QueryBuilder(backend=backend2).append(orm.Data).count() == 2
+
+    # check that we can retrieve the attributes and comment data
+    node = orm.QueryBuilder(backend=backend2).append(orm.Dict).first(flat=True)
+    assert node.get_dict() == dict_data
+    assert len(node.base.comments.all()) == 1
+
+    # check that we can retrieve the repository data
+    node = orm.QueryBuilder(backend=backend2).append(orm.SinglefileData).first(flat=True)
+    assert node.get_content() == text_data
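Assuming a development install of `aiida-core`, the new test should be runnable on its own with `pytest tests/storage/sqlite/test_archive.py`, since it relies only on pytest's built-in `tmp_path` fixture and the temporary backends it creates itself.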
