Skip to content

Commit a315023

Browse files
committed
Greatly simplify the template saving logic
1 parent 61921ee commit a315023

33 files changed

+551
-1004
lines changed

superduper/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from .components.plugin import Plugin
3636
from .components.streamlit import Streamlit
3737
from .components.table import Table
38-
from .components.template import QueryTemplate, Template
38+
from .components.template import Template
3939
from .components.vector_index import VectorIndex
4040

4141
REQUIRES = [
@@ -65,7 +65,6 @@
6565
'Table',
6666
'Application',
6767
'Template',
68-
'QueryTemplate',
6968
'Application',
7069
'Component',
7170
'pickle_serializer',

superduper/backends/base/backends.py

+61-17
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import defaultdict
12
import typing as t
23
from abc import ABC, abstractmethod
34

@@ -6,6 +7,64 @@
67
from superduper.components.component import Component
78

89

10+
class Bookkeeping(ABC):
11+
def __init__(self):
12+
self.component_uuid_mapping = defaultdict(set)
13+
self.uuid_component_mapping = {}
14+
self.tool_uuid_mapping = defaultdict(set)
15+
self.uuid_tool_mapping = {}
16+
self.tools = {}
17+
18+
def build_tool(self, component: 'Component'):
19+
pass
20+
21+
def get_tool(self, uuid: str):
22+
tool_id = self.uuid_tool_mapping[uuid]
23+
return self.tools[tool_id]
24+
25+
def put_component(self, component: 'Component', **kwargs):
26+
tool = self.build_tool(component)
27+
tool.db = self.db
28+
self.component_uuid_mapping[(component.component, component.identifier)].add(component.uuid)
29+
self.uuid_component_mapping[component.uuid] = (component.component, component.identifier)
30+
self.uuid_tool_mapping[component.uuid] = tool.identifier
31+
self.tool_uuid_mapping[tool.identifier].add(component.uuid)
32+
self.tools[tool.identifier] = tool
33+
tool.initialize(**kwargs)
34+
35+
def drop_component(self, component: str, identifier: str):
36+
uuids = self.component_uuid_mapping[(component, identifier)]
37+
tool_ids = []
38+
for uuid in uuids:
39+
del self.uuid_component_mapping[uuid]
40+
tool_id = self.uuid_tool_mapping[uuid]
41+
tool_ids.append(tool_id)
42+
del self.uuid_tool_mapping[uuid]
43+
self.tool_uuid_mapping[tool_id].remove(uuid)
44+
if not self.tool_uuid_mapping[tool_id]:
45+
self.tools.drop()
46+
del self.tools[tool_id]
47+
del self.component_uuid_mapping[(component, identifier)]
48+
49+
def drop(self):
50+
for tool in self.tools.values():
51+
tool.drop()
52+
self.component_uuid_mapping = defaultdict(set)
53+
self.uuid_component_mapping = {}
54+
self.tool_uuid_mapping = defaultdict(set)
55+
self.uuid_tool_mapping = {}
56+
self.tools = {}
57+
58+
def list_components(self):
59+
return list(self.component_uuid_mapping.keys())
60+
61+
def list_tools(self):
62+
return list(self.tools.keys())
63+
64+
def list_uuids(self):
65+
return list(self.uuid_component_mapping.keys())
66+
67+
968
class BaseBackend(ABC):
1069
"""Base backend class for cluster client."""
1170

@@ -34,28 +93,13 @@ def initialize(self):
3493
"""To be called on program start."""
3594
pass
3695

37-
def put_component(self, component: 'Component', **kwargs):
96+
@abstractmethod
97+
def put_component(self, component: 'Component'):
3898
"""Add a component to the deployment.
3999
40100
:param component: ``Component`` to put.
41101
:param kwargs: kwargs dictionary.
42102
"""
43-
# This is to make sure that we only have 1 version
44-
# of each component implemented at any given time
45-
# TODO: get identifier in string component argument.
46-
identifier = ''
47-
if isinstance(component, str):
48-
uuid = component
49-
else:
50-
uuid = component.uuid
51-
identifier = component.identifier
52-
53-
if uuid in self.list_uuids():
54-
return
55-
if identifier in self.list_components():
56-
del self[component.identifier]
57-
58-
self._put(component, **kwargs)
59103

60104
@abstractmethod
61105
def drop_component(self, component: str, identifier: str):

superduper/backends/base/scheduler.py

-57
Original file line numberDiff line numberDiff line change
@@ -19,59 +19,6 @@
1919
BATCH_SIZE = 100
2020

2121

22-
def _chunked_list(lst, batch_size=BATCH_SIZE):
23-
if len(lst) <= batch_size:
24-
return [lst]
25-
return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]
26-
27-
28-
class BaseQueueConsumer(ABC):
29-
"""
30-
Base class for handling consumer process.
31-
32-
This class is an implementation of message broker between
33-
producers (superduper db client) and consumers i.e listeners.
34-
35-
:param uri: Uri to connect.
36-
:param queue_name: Queue to consume.
37-
:param callback: Callback for consumed messages.
38-
"""
39-
40-
def __init__(
41-
self,
42-
queue_name: str = '',
43-
callback: t.Optional[t.Callable] = None,
44-
):
45-
self.callback = callback
46-
self.queue_name = queue_name
47-
self.futures: t.DefaultDict = defaultdict(lambda: {})
48-
49-
@abstractmethod
50-
def start_consuming(self):
51-
"""Abstract method to start consuming messages."""
52-
pass
53-
54-
@abstractmethod
55-
def close_connection(self):
56-
"""Abstract method to close connection."""
57-
pass
58-
59-
def consume(self, *args, **kwargs):
60-
"""Start consuming messages from queue.
61-
62-
:param args: positional arguments
63-
:param kwargs: keyword arguments
64-
"""
65-
logging.info(f"Started consuming on queue: {self.queue_name}")
66-
try:
67-
self.start_consuming()
68-
except KeyboardInterrupt:
69-
logging.info("KeyboardInterrupt: Stopping consumer...")
70-
finally:
71-
self.close_connection()
72-
logging.info(f"Stopped consuming on queue: {self.queue_name}")
73-
74-
7522
class BaseScheduler(BaseBackend):
7623
"""
7724
Base class for handling publisher and consumer process.
@@ -82,10 +29,6 @@ class BaseScheduler(BaseBackend):
8229
:param uri: Uri to connect.
8330
"""
8431

85-
def __init__(self):
86-
super().__init__()
87-
self.queue: t.Dict = defaultdict(lambda: [])
88-
8932
@abstractmethod
9033
def publish(self, events: t.List[Event]):
9134
"""

superduper/backends/base/vector_search.py

+68-21
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,35 @@
66
import numpy
77
import numpy.typing
88

9-
from superduper.backends.base.backends import BaseBackend
9+
from superduper.backends.base.backends import BaseBackend, Bookkeeping
1010

1111
if t.TYPE_CHECKING:
1212
from superduper.base.datalayer import Datalayer
13-
from superduper.components.vector_index import VectorIndex
13+
from superduper.components.vector_index import VectorIndex, VectorItem
1414

1515

16-
class VectorSearchBackend(BaseBackend):
16+
class VectorSearchBackend(Bookkeeping, BaseBackend):
1717
"""Base vector-search backend."""
1818

1919
def __init__(self):
20-
self._cache = {}
20+
Bookkeeping.__init__(self)
21+
BaseBackend.__init__(self)
2122

22-
@abstractmethod
23-
def __getitem__(self, identifier):
24-
pass
25-
26-
def add(self, identifier, vectors):
23+
def add(self, uuid: str, vectors: t.List['VectorItem']):
2724
"""Add vectors to a vector-index.
2825
2926
:param identifier: Identifier of index.
3027
:param vectors: Vectors.
3128
"""
32-
self.get(identifier).add(vectors)
29+
self.get_tool(uuid).add(vectors)
3330

34-
def delete(self, identifier, ids):
31+
def delete(self, uuid, ids):
3532
"""Delete ids from index.
3633
3734
:param identifier: Identifier of index.
3835
:param ids: Ids to delete.
3936
"""
40-
self.get(identifier).delete(ids)
37+
self.get_tool(uuid).delete(ids)
4138

4239
@property
4340
def db(self) -> 'Datalayer':
@@ -52,13 +49,48 @@ def db(self, value: 'Datalayer'):
5249
"""
5350
self._db = value
5451

52+
@abstractmethod
53+
def find_nearest_from_array(
54+
self,
55+
h: numpy.typing.ArrayLike,
56+
vector_index: str,
57+
n: int = 100,
58+
within_ids: t.Sequence[str] = (),
59+
) -> t.Tuple[t.List[str], t.List[float]]:
60+
"""
61+
Find the nearest vectors to the given vector.
62+
63+
:param h: vector
64+
:param n: number of nearest vectors to return
65+
:param within_ids: list of ids to search within
66+
"""
67+
68+
@abstractmethod
69+
def find_nearest_from_id(
70+
self,
71+
id: str,
72+
vector_index: str,
73+
n: int = 100,
74+
within_ids: t.Sequence[str] = (),
75+
) -> t.Tuple[t.List[str], t.List[float]]:
76+
"""
77+
Find the nearest vectors to the given vector.
78+
79+
:param id: id of the vector to search with
80+
:param n: number of nearest vectors to return
81+
:param within_ids: list of ids to search within
82+
"""
83+
5584

5685
class VectorSearcherInterface(ABC):
5786
"""Interface for vector searchers.
5887
5988
# noqa
6089
"""
6190

91+
def __init__(self, identifier: str):
92+
self.identifier = identifier
93+
6294
@abstractmethod
6395
def add(self, items: t.Sequence['VectorItem']) -> None:
6496
"""
@@ -75,31 +107,31 @@ def delete(self, ids: t.Sequence[str]) -> None:
75107
"""
76108

77109
@abstractmethod
78-
def find_nearest_from_id(
110+
def find_nearest_from_array(
79111
self,
80-
_id,
112+
h: numpy.typing.ArrayLike,
81113
n: int = 100,
82114
within_ids: t.Sequence[str] = (),
83115
) -> t.Tuple[t.List[str], t.List[float]]:
84116
"""
85-
Find the nearest vectors to the vector with the given id.
117+
Find the nearest vectors to the given vector.
86118
87-
:param _id: id of the vector
119+
:param h: vector
88120
:param n: number of nearest vectors to return
89121
:param within_ids: list of ids to search within
90122
"""
91123

92124
@abstractmethod
93-
def find_nearest_from_array(
125+
def find_nearest_from_id(
94126
self,
95-
h: numpy.typing.ArrayLike,
127+
id: str,
96128
n: int = 100,
97129
within_ids: t.Sequence[str] = (),
98130
) -> t.Tuple[t.List[str], t.List[float]]:
99131
"""
100132
Find the nearest vectors to the given vector.
101133
102-
:param h: vector
134+
:param id: id of the vector to search with
103135
:param n: number of nearest vectors to return
104136
:param within_ids: list of ids to search within
105137
"""
@@ -111,6 +143,13 @@ def post_create(self):
111143
to perform a task after all vectors have been added
112144
"""
113145

146+
def post_create(self):
147+
"""Post create method.
148+
149+
This method is used for searchers which requires
150+
to perform a task after all vectors have been added
151+
"""
152+
114153

115154
class BaseVectorSearcher(VectorSearcherInterface):
116155
"""Base class for vector searchers.
@@ -122,10 +161,18 @@ class BaseVectorSearcher(VectorSearcherInterface):
122161

123162
native_service: t.ClassVar[bool] = True
124163

164+
@property
165+
def db(self) -> 'Datalayer':
166+
return self._db
167+
168+
@db.setter
169+
def db(self, value: 'Datalayer'):
170+
self._db = value
171+
125172
@abstractmethod
126173
def __init__(
127174
self,
128-
uuid: str,
175+
identifier: str,
129176
dimensions: int,
130177
measure: str,
131178
):
@@ -137,7 +184,7 @@ def from_component(cls, index: 'VectorIndex'):
137184
138185
:param vi: ``VectorIndex`` instance
139186
"""
140-
return cls(uuid=index.uuid, dimensions=index.dimensions, measure=index.measure)
187+
return cls(identifier=index.uuid, dimensions=index.dimensions, measure=index.measure)
141188

142189
@abstractmethod
143190
def initialize(self, db):

superduper/backends/local/cache.py

-4
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@ def __delitem__(self, item):
2020
del self._cache[item]
2121

2222
def __setitem__(self, key, value):
23-
if isinstance(value, str):
24-
import pdb
25-
26-
pdb.set_trace()
2723
self._cache[key] = copy.deepcopy(value)
2824

2925
def __getitem__(self, item):

superduper/backends/local/cdc.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ def list_uuids(self):
3232
"""List UUIDs of components."""
3333
return list(self._trigger_uuid_mapping.values())
3434

35-
def _put(self, item):
36-
assert isinstance(item, CDC)
37-
self.triggers.add((item.component, item.identifier))
35+
def put_component(self, component):
36+
assert isinstance(component, CDC)
37+
self.triggers.add((component.component, component.identifier))
3838

3939
def drop_component(self, component, identifier):
4040
c = self.db.load(component=component, identifier=identifier)

0 commit comments

Comments
 (0)