Skip to content

Commit 0413a3b

Browse files
committed
Greatly simplify the template saving logic
1 parent 61921ee commit 0413a3b

30 files changed

+523
-886
lines changed

superduper/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from .components.plugin import Plugin
3636
from .components.streamlit import Streamlit
3737
from .components.table import Table
38-
from .components.template import QueryTemplate, Template
38+
from .components.template import Template
3939
from .components.vector_index import VectorIndex
4040

4141
REQUIRES = [
@@ -65,7 +65,6 @@
6565
'Table',
6666
'Application',
6767
'Template',
68-
'QueryTemplate',
6968
'Application',
7069
'Component',
7170
'pickle_serializer',

superduper/backends/base/backends.py

+62-17
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import defaultdict
12
import typing as t
23
from abc import ABC, abstractmethod
34

@@ -6,6 +7,65 @@
67
from superduper.components.component import Component
78

89

10+
class Bookkeeping(ABC):
    """Mixin tracking which tool serves which component.

    A "tool" is whatever object ``build_tool`` returns for a component; it
    must expose ``identifier``, ``initialize()`` and ``drop()``, and accept a
    ``db`` attribute.  The host class is expected to provide ``self.db``
    (this mixin reads it but never sets it).

    Five mappings are kept in sync:

    * ``component_uuid_mapping``: ``(component, identifier)`` -> set of uuids
    * ``uuid_component_mapping``: uuid -> ``(component, identifier)``
    * ``tool_uuid_mapping``: tool identifier -> set of uuids served by it
    * ``uuid_tool_mapping``: uuid -> tool identifier
    * ``tools``: tool identifier -> tool instance
    """

    def __init__(self):
        self.component_uuid_mapping = defaultdict(set)
        self.uuid_component_mapping = {}
        self.tool_uuid_mapping = defaultdict(set)
        self.uuid_tool_mapping = {}
        self.tools = {}

    @abstractmethod
    def build_tool(self, component: 'Component', uuid: t.Optional[str] = None):
        """Build the tool which serves ``component``.

        :param component: ``Component`` to build a tool for.
        :param uuid: Optional uuid; ``put_component`` does not pass it, so
                     implementations must not require it.
        """
        pass

    def get_tool(self, uuid: str):
        """Return the tool registered for a component uuid.

        :param uuid: uuid of a component previously passed to
                     ``put_component``.
        :raises KeyError: if no tool is registered for ``uuid``.
        """
        tool_id = self.uuid_tool_mapping[uuid]
        return self.tools[tool_id]

    def put_component(self, component: 'Component'):
        """Build, register and initialize the tool for ``component``.

        :param component: ``Component`` to put.
        """
        tool = self.build_tool(component)
        tool.db = self.db
        # Key by (component, identifier): ``drop_component`` looks entries up
        # with exactly this tuple, so both sides must agree on the key shape.
        key = (component.component, component.identifier)
        self.component_uuid_mapping[key].add(component.uuid)
        self.uuid_component_mapping[component.uuid] = key
        self.uuid_tool_mapping[component.uuid] = tool.identifier
        self.tool_uuid_mapping[tool.identifier].add(component.uuid)
        self.tools[tool.identifier] = tool
        tool.initialize()

    def drop_component(self, component: str, identifier: str):
        """Unregister every uuid of a component and drop orphaned tools.

        A tool is dropped only once no remaining uuid refers to it.  Dropping
        a component which was never registered is a no-op.

        :param component: Component type name.
        :param identifier: Component identifier.
        """
        # ``pop`` with a default avoids the defaultdict inserting an empty
        # entry for an unknown key.
        uuids = self.component_uuid_mapping.pop((component, identifier), set())
        for uuid in uuids:
            del self.uuid_component_mapping[uuid]
            tool_id = self.uuid_tool_mapping.pop(uuid)
            remaining = self.tool_uuid_mapping[tool_id]
            remaining.discard(uuid)
            if not remaining:
                # Drop the specific tool; ``self.tools`` is a plain dict and
                # has no ``drop`` method of its own.
                self.tools[tool_id].drop()
                del self.tools[tool_id]
                del self.tool_uuid_mapping[tool_id]

    def drop(self):
        """Drop every registered tool and reset all bookkeeping."""
        for tool in self.tools.values():
            tool.drop()
        self.component_uuid_mapping = defaultdict(set)
        self.uuid_component_mapping = {}
        self.tool_uuid_mapping = defaultdict(set)
        self.uuid_tool_mapping = {}
        self.tools = {}

    def list_components(self):
        """List the ``(component, identifier)`` pairs currently registered."""
        return list(self.component_uuid_mapping)

    def list_tools(self):
        """List the identifiers of the tools currently registered."""
        return list(self.tools)

    def list_uuids(self):
        """List the component uuids currently registered."""
        return list(self.uuid_component_mapping)
67+
68+
969
class BaseBackend(ABC):
1070
"""Base backend class for cluster client."""
1171

@@ -34,28 +94,13 @@ def initialize(self):
3494
"""To be called on program start."""
3595
pass
3696

37-
def put_component(self, component: 'Component', **kwargs):
97+
@abstractmethod
98+
def put_component(self, component: 'Component'):
3899
"""Add a component to the deployment.
39100
40101
:param component: ``Component`` to put.
41102
:param kwargs: kwargs dictionary.
42103
"""
43-
# This is to make sure that we only have 1 version
44-
# of each component implemented at any given time
45-
# TODO: get identifier in string component argument.
46-
identifier = ''
47-
if isinstance(component, str):
48-
uuid = component
49-
else:
50-
uuid = component.uuid
51-
identifier = component.identifier
52-
53-
if uuid in self.list_uuids():
54-
return
55-
if identifier in self.list_components():
56-
del self[component.identifier]
57-
58-
self._put(component, **kwargs)
59104

60105
@abstractmethod
61106
def drop_component(self, component: str, identifier: str):

superduper/backends/base/vector_search.py

+68-21
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,35 @@
66
import numpy
77
import numpy.typing
88

9-
from superduper.backends.base.backends import BaseBackend
9+
from superduper.backends.base.backends import BaseBackend, Bookkeeping
1010

1111
if t.TYPE_CHECKING:
1212
from superduper.base.datalayer import Datalayer
13-
from superduper.components.vector_index import VectorIndex
13+
from superduper.components.vector_index import VectorIndex, VectorItem
1414

1515

16-
class VectorSearchBackend(BaseBackend):
16+
class VectorSearchBackend(Bookkeeping, BaseBackend):
1717
"""Base vector-search backend."""
1818

1919
def __init__(self):
20-
self._cache = {}
20+
Bookkeeping.__init__(self)
21+
BaseBackend.__init__(self)
2122

22-
@abstractmethod
23-
def __getitem__(self, identifier):
24-
pass
25-
26-
def add(self, identifier, vectors):
23+
def add(self, uuid: str, vectors: t.List['VectorItem']):
2724
"""Add vectors to a vector-index.
2825
2926
:param identifier: Identifier of index.
3027
:param vectors: Vectors.
3128
"""
32-
self.get(identifier).add(vectors)
29+
self.get_tool(uuid).add(vectors)
3330

34-
def delete(self, identifier, ids):
31+
def delete(self, uuid, ids):
3532
"""Delete ids from index.
3633
3734
:param identifier: Identifier of index.
3835
:param ids: Ids to delete.
3936
"""
40-
self.get(identifier).delete(ids)
37+
self.get_tool(uuid).delete(ids)
4138

4239
@property
4340
def db(self) -> 'Datalayer':
@@ -52,13 +49,48 @@ def db(self, value: 'Datalayer'):
5249
"""
5350
self._db = value
5451

52+
@abstractmethod
53+
def find_nearest_from_array(
54+
self,
55+
h: numpy.typing.ArrayLike,
56+
vector_index: str,
57+
n: int = 100,
58+
within_ids: t.Sequence[str] = (),
59+
) -> t.Tuple[t.List[str], t.List[float]]:
60+
"""
61+
Find the nearest vectors to the given vector.
62+
63+
:param h: vector
64+
:param n: number of nearest vectors to return
65+
:param within_ids: list of ids to search within
66+
"""
67+
68+
@abstractmethod
69+
def find_nearest_from_id(
70+
self,
71+
id: str,
72+
vector_index: str,
73+
n: int = 100,
74+
within_ids: t.Sequence[str] = (),
75+
) -> t.Tuple[t.List[str], t.List[float]]:
76+
"""
77+
Find the nearest vectors to the given vector.
78+
79+
:param id: id of the vector to search with
80+
:param n: number of nearest vectors to return
81+
:param within_ids: list of ids to search within
82+
"""
83+
5584

5685
class VectorSearcherInterface(ABC):
5786
"""Interface for vector searchers.
5887
5988
# noqa
6089
"""
6190

91+
def __init__(self, identifier: str):
92+
self.identifier = identifier
93+
6294
@abstractmethod
6395
def add(self, items: t.Sequence['VectorItem']) -> None:
6496
"""
@@ -75,31 +107,31 @@ def delete(self, ids: t.Sequence[str]) -> None:
75107
"""
76108

77109
@abstractmethod
78-
def find_nearest_from_id(
110+
def find_nearest_from_array(
79111
self,
80-
_id,
112+
h: numpy.typing.ArrayLike,
81113
n: int = 100,
82114
within_ids: t.Sequence[str] = (),
83115
) -> t.Tuple[t.List[str], t.List[float]]:
84116
"""
85-
Find the nearest vectors to the vector with the given id.
117+
Find the nearest vectors to the given vector.
86118
87-
:param _id: id of the vector
119+
:param h: vector
88120
:param n: number of nearest vectors to return
89121
:param within_ids: list of ids to search within
90122
"""
91123

92124
@abstractmethod
93-
def find_nearest_from_array(
125+
def find_nearest_from_id(
94126
self,
95-
h: numpy.typing.ArrayLike,
127+
id: str,
96128
n: int = 100,
97129
within_ids: t.Sequence[str] = (),
98130
) -> t.Tuple[t.List[str], t.List[float]]:
99131
"""
100132
Find the nearest vectors to the given vector.
101133
102-
:param h: vector
134+
:param id: id of the vector to search with
103135
:param n: number of nearest vectors to return
104136
:param within_ids: list of ids to search within
105137
"""
@@ -111,6 +143,13 @@ def post_create(self):
111143
to perform a task after all vectors have been added
112144
"""
113145

146+
def post_create(self):
147+
"""Post create method.
148+
149+
This method is used for searchers which requires
150+
to perform a task after all vectors have been added
151+
"""
152+
114153

115154
class BaseVectorSearcher(VectorSearcherInterface):
116155
"""Base class for vector searchers.
@@ -122,10 +161,18 @@ class BaseVectorSearcher(VectorSearcherInterface):
122161

123162
native_service: t.ClassVar[bool] = True
124163

164+
@property
165+
def db(self) -> 'Datalayer':
166+
return self._db
167+
168+
@db.setter
169+
def db(self, value: 'Datalayer'):
170+
self._db = value
171+
125172
@abstractmethod
126173
def __init__(
127174
self,
128-
uuid: str,
175+
identifier: str,
129176
dimensions: int,
130177
measure: str,
131178
):
@@ -137,7 +184,7 @@ def from_component(cls, index: 'VectorIndex'):
137184
138185
:param vi: ``VectorIndex`` instance
139186
"""
140-
return cls(uuid=index.uuid, dimensions=index.dimensions, measure=index.measure)
187+
return cls(identifier=index.uuid, dimensions=index.dimensions, measure=index.measure)
141188

142189
@abstractmethod
143190
def initialize(self, db):

superduper/backends/local/cdc.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ def list_uuids(self):
3232
"""List UUIDs of components."""
3333
return list(self._trigger_uuid_mapping.values())
3434

35-
def _put(self, item):
36-
assert isinstance(item, CDC)
37-
self.triggers.add((item.component, item.identifier))
35+
def put_component(self, component):
36+
assert isinstance(component, CDC)
37+
self.triggers.add((component.component, component.identifier))
3838

3939
def drop_component(self, component, identifier):
4040
c = self.db.load(component=component, identifier=identifier)

0 commit comments

Comments
 (0)