Skip to content

Commit bc34c9f

Browse files
committed
Greatly simplify the template saving logic
1 parent 61921ee commit bc34c9f

23 files changed

+266
-745
lines changed

superduper/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from .components.plugin import Plugin
3636
from .components.streamlit import Streamlit
3737
from .components.table import Table
38-
from .components.template import QueryTemplate, Template
38+
from .components.template import Template
3939
from .components.vector_index import VectorIndex
4040

4141
REQUIRES = [
@@ -65,7 +65,6 @@
6565
'Table',
6666
'Application',
6767
'Template',
68-
'QueryTemplate',
6968
'Application',
7069
'Component',
7170
'pickle_serializer',

superduper/base/base.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def cls_encode(cls, item: 'Base', builds, blobs, files, leaves_to_keep=()):
201201
)
202202

203203
@staticmethod
204-
def decode(r, db: t.Optional['Datalayer'] = None):
204+
def decode(r):
205205
"""Decode a dictionary component into a `Component` instance.
206206
207207
:param r: Object to be decoded.
@@ -211,9 +211,8 @@ def decode(r, db: t.Optional['Datalayer'] = None):
211211

212212
if '_path' in r:
213213
from superduper.misc.importing import import_object
214-
215214
cls = import_object(r['_path'])
216-
r = Document.decode(r, schema=cls.class_schema, db=db)
215+
r = Document.decode(r, schema=cls.class_schema)
217216
return cls.from_dict(r, db=None)
218217

219218
def encode(

superduper/base/datatype.py

+37-2
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,35 @@ def hash(cls, item):
298298
return hash_item([x.hash for x in item])
299299

300300

301+
class FDict(BaseDataType):
    """Datatype for a dictionary whose values are handled by ``File``.

    Each value is encoded, decoded and hashed with the ``File`` datatype;
    the keys are passed through unchanged.
    """

    dtype: t.ClassVar[str] = 'json'

    def encode_data(self, item, context):
        """Encode every value of a dictionary with ``File.encode_data``.

        :param item: The dictionary to encode.
        :param context: A context object containing caches.
        """
        assert isinstance(item, dict), f'Expected a dict, got {type(item)}'
        # A single ``File`` instance is reused for all entries rather than
        # constructing a new one per value.
        file_type = File()
        return {k: file_type.encode_data(v, context) for k, v in item.items()}

    def decode_data(self, item, builds, db):
        """Decode every value of an encoded dictionary with ``File.decode_data``.

        :param item: The encoded dictionary to decode.
        :param builds: The builds.
        :param db: The Datalayer.
        """
        file_type = File()
        return {k: file_type.decode_data(v, builds, db) for k, v in item.items()}

    @classmethod
    def hash(cls, item):
        """Hash a dictionary from the ``File`` hashes of its values.

        :param item: The dictionary to hash.
        """
        return hash_item({k: File.hash(v) for k, v in item.items()})
328+
329+
301330
@dc.dataclass(kw_only=True)
302331
class BaseVector(BaseDataType):
303332
"""Base class for vector.
@@ -453,8 +482,11 @@ def hash_indescript(item):
453482
"""
454483
if inspect.isfunction(item):
455484
module = item.__module__
456-
body = f'{module}\n{inspect.getsource(item)}'
457-
return hashlib.sha256(body.encode()).hexdigest()
485+
try:
486+
body = f'{module}\n{inspect.getsource(item)}'
487+
return hashlib.sha256(body.encode()).hexdigest()
488+
except OSError:
489+
return hashlib.sha256(str(item).encode()).hexdigest()
458490
if inspect.isclass(item):
459491
module = item.__module__
460492
body = f'{module}\n{inspect.getsource(item)}'
@@ -601,6 +633,8 @@ def encode_data(self, item, context):
601633
:param item: The object/instance to encode.
602634
:param context: A context object containing caches.
603635
"""
636+
if isinstance(item, FileItem):
637+
return item.reference
604638
assert os.path.exists(item)
605639
file = FileItem(identifier=self.hash(item), path=item)
606640
context.files[file.identifier] = file.path
@@ -690,6 +724,7 @@ class _DatatypeLookup:
690724
File(),
691725
LeafType(),
692726
ComponentType(),
727+
FDict(),
693728
SDict(),
694729
SList(),
695730
FieldType('str'),

superduper/base/schema.py

+44-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# TODO move to base
22
import dataclasses as dc
33
import json
4+
import re
45
import typing as t
56
from functools import cached_property
67

7-
from superduper import CFG
8+
from superduper import CFG, logging
89
from superduper.base.datatype import BaseDataType
910
from superduper.base.encoding import EncodeContext
1011
from superduper.misc.special_dicts import dict_to_ascii_table
@@ -40,14 +41,44 @@ def __add__(self, other: 'Schema'):
4041
new_fields.update(other.fields)
4142
return Schema(fields=new_fields)
4243

43-
@cached_property
44+
@property
def trivial(self):
    """Determine if the schema contains only trivial fields.

    A schema is trivial when none of its field values is a ``BaseDataType``.
    """
    # A generator lets ``any`` stop at the first datatype it finds instead
    # of materialising the whole list first.
    return not any(isinstance(v, BaseDataType) for v in self.fields.values())
4748

4849
def __repr__(self):
    """Return ``dict_to_ascii_table`` applied to the schema's fields."""
    fields = self.fields
    return dict_to_ascii_table(fields)
5051

52+
@staticmethod
53+
def handle_references(item, builds):
54+
if '?(' not in str(item):
55+
return item
56+
57+
if isinstance(item, str):
58+
instances = re.findall(r'\?\((.*?)\)', item)
59+
60+
for k in instances:
61+
name = k.split('.')[0]
62+
attr = k.split('.')[-1]
63+
64+
if name not in builds:
65+
logging.warn(f'Could not find reference {name} from reference in {item} in builds')
66+
return item
67+
68+
to_replace = getattr(builds[name], attr)
69+
item = item.replace(f'?({k})', str(to_replace))
70+
71+
return item
72+
elif isinstance(item, list):
73+
return [Schema.handle_references(i, builds) for i in item]
74+
elif isinstance(item, dict):
75+
return {
76+
Schema.handle_references(k, builds): Schema.handle_references(v, builds)
77+
for k, v in item.items()
78+
}
79+
else:
80+
return item
81+
5182
def decode_data(
5283
self, data: dict[str, t.Any], builds: t.Dict, db
5384
) -> dict[str, t.Any]:
@@ -62,9 +93,17 @@ def decode_data(
6293

6394
decoded = {}
6495

96+
# reorder the component so that references go first
97+
is_ref = lambda x: isinstance(x, str) and x.startswith('?')
98+
data_is_ref = {k: v for k, v in data.items() if is_ref(v)}
99+
data_not_ref = {k: v for k, v in data.items() if not is_ref(v)}
100+
data = {**data_is_ref, **data_not_ref}
101+
65102
for k, value in data.items():
66103
field = self.fields.get(k)
67104

105+
value = self.handle_references(value, builds)
106+
68107
if not isinstance(field, BaseDataType) or value is None:
69108
decoded[k] = value
70109
continue
@@ -133,6 +172,9 @@ def encode_data(self, out, context: t.Optional[EncodeContext] = None, **kwargs):
133172

134173
return result
135174

175+
def __getitem__(self, item: str):
    """Look up an entry of ``self.fields`` by key.

    :param item: Key into ``self.fields``.
    """
    fields = self.fields
    return fields[item]
177+
136178

137179
def get_schema(db, schema: t.Union[Schema, str]) -> t.Optional[Schema]:
138180
"""Handle schema caching and loading.

superduper/components/component.py

+9-49
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def _build_info_from_path(path: str):
5656
files = {}
5757
for file_id in os.listdir(os.path.join(path, "files")):
5858
sub_paths = os.listdir(os.path.join(path, "files", file_id))
59-
assert len(sub_paths) == 1, f"Multiple files found in {file_id}"
60-
file_name = sub_paths[0]
59+
# assert len(sub_paths) == 1, f"Multiple files found in {file_id}"
60+
file_name = next(x for x in sub_paths if not x.startswith(".") or x.startswith("_"))
6161
files[file_id] = os.path.join(path, "files", file_id, file_name)
6262
config_object[KEY_FILES] = files
6363

@@ -484,19 +484,16 @@ def dependencies(self):
484484
"""Get dependencies on the component."""
485485
return ()
486486

487-
# TODO why both methods?
488487
def init(self):
489488
"""Method to help initiate component field dependencies."""
490-
self.unpack()
491489

492-
def unpack(self):
493-
"""Method to unpack the component.
494-
495-
This method is used to initialize all the fields of the component and leaf
496-
"""
490+
def mro(item):
491+
objects = item.__class__.__mro__
492+
return [f'{o.__module__}.{o.__name__}' for o in objects]
497493

498494
def _init(item):
499-
if isinstance(item, Component):
495+
496+
if 'superduper.components.component.Component' in mro(item):
500497
item.init()
501498
return item
502499

@@ -552,7 +549,7 @@ def declare_component(self, cluster):
552549
pass
553550

554551
@staticmethod
555-
def read(path: str, db: t.Optional['Datalayer'] = None):
552+
def read(path: str):
556553
"""
557554
Read a `Component` instance from a directory created with `.export`.
558555
@@ -566,45 +563,8 @@ def read(path: str, db: t.Optional['Datalayer'] = None):
566563
|_files/*
567564
```
568565
"""
569-
was_zipped = False
570-
if path.endswith('.zip'):
571-
was_zipped = True
572-
import shutil
573-
574-
shutil.unpack_archive(path)
575-
path = path.replace('.zip', '')
576-
577566
config_object = _build_info_from_path(path=path)
578-
579-
from superduper import Document
580-
581-
if db is not None:
582-
for blob in os.listdir(path + '/' + 'blobs'):
583-
with open(path + '/blobs/' + blob, 'rb') as f:
584-
data = f.read()
585-
db.artifact_store.put_bytes(data, blob)
586-
587-
out = Document.decode(config_object, db=db).unpack()
588-
else:
589-
from superduper.base.artifacts import FileSystemArtifactStore
590-
591-
artifact_store = FileSystemArtifactStore(
592-
conn=path,
593-
name='tmp_artifact_store',
594-
files='files',
595-
blobs='blobs',
596-
)
597-
db = namedtuple('tmp_db', field_names=('artifact_store',))(
598-
artifact_store=artifact_store
599-
)
600-
cls = import_object(config_object['_path'])
601-
out = Document.decode(
602-
config_object, schema=cls.class_schema, db=db
603-
).unpack()
604-
out = cls.from_dict(out, db=db)
605-
if was_zipped:
606-
shutil.rmtree(path)
607-
return out
567+
return Component.decode(config_object)
608568

609569
def export(
610570
self,

superduper/components/plugin.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ class Plugin(Component):
2323

2424
def postinit(self):
2525
"""Post initialization method."""
26-
if isinstance(self.path, FileItem):
27-
self._prepare_plugin()
28-
else:
29-
path_name = os.path.basename(self.path.rstrip("/"))
30-
self.identifier = self.identifier or f"plugin-{path_name}".replace(".", "_")
26+
self.init()
27+
self._prepare_plugin()
28+
29+
path_name = os.path.basename(self.path.rstrip("/"))
30+
self.identifier = self.identifier or f"plugin-{path_name}".replace(".", "_")
3131
self._install()
3232
super().postinit()
3333

@@ -91,19 +91,22 @@ def _pip_install(self, requirement_path):
9191

9292
def _prepare_plugin(self):
9393
plugin_name_tag = f"{self.identifier}"
94-
assert isinstance(self.path, FileItem)
94+
if isinstance(self.path, FileItem):
95+
self.path = self.path.unpack()
96+
9597
cache_path = os.path.expanduser(self.cache_path)
9698
uuid_path = os.path.join(cache_path, self.uuid)
99+
97100
# Check if plugin is already in cache
98101
if os.path.exists(uuid_path):
99102
names = os.listdir(uuid_path)
100103
names = [name for name in names if name != "__pycache__"]
101104
assert len(names) == 1, f"Multiple plugins found in {uuid_path}"
102105
self.path = os.path.join(uuid_path, names[0])
106+
sys.path.append(uuid_path)
103107
return
104108

105109
logging.info(f"Preparing plugin {plugin_name_tag}")
106-
self.path = self.path.unpack()
107110
assert os.path.exists(
108111
self.path
109112
), f"Plugin {plugin_name_tag} not found at {self.path}"

0 commit comments

Comments
 (0)