1
1
# TODO move to base
2
2
import dataclasses as dc
3
3
import json
4
+ import re
4
5
import typing as t
5
6
from functools import cached_property
6
7
7
- from superduper import CFG
8
+ from superduper import CFG , logging
8
9
from superduper .base .datatype import BaseDataType
9
10
from superduper .base .encoding import EncodeContext
10
11
from superduper .misc .special_dicts import dict_to_ascii_table
@@ -40,14 +41,44 @@ def __add__(self, other: 'Schema'):
40
41
new_fields .update (other .fields )
41
42
return Schema (fields = new_fields )
42
43
43
- @cached_property
44
+ @property
44
45
def trivial (self ):
45
46
"""Determine if the schema contains only trivial fields."""
46
47
return not any ([isinstance (v , BaseDataType ) for v in self .fields .values ()])
47
48
48
49
def __repr__ (self ):
49
50
return dict_to_ascii_table (self .fields )
50
51
52
+ @staticmethod
53
+ def handle_references (item , builds ):
54
+ if '?(' not in str (item ):
55
+ return item
56
+
57
+ if isinstance (item , str ):
58
+ instances = re .findall (r'\?\((.*?)\)' , item )
59
+
60
+ for k in instances :
61
+ name = k .split ('.' )[0 ]
62
+ attr = k .split ('.' )[- 1 ]
63
+
64
+ if name not in builds :
65
+ logging .warn (f'Could not find reference { name } from reference in { item } in builds' )
66
+ return item
67
+
68
+ to_replace = getattr (builds [name ], attr )
69
+ item = item .replace (f'?({ k } )' , str (to_replace ))
70
+
71
+ return item
72
+ elif isinstance (item , list ):
73
+ return [Schema .handle_references (i , builds ) for i in item ]
74
+ elif isinstance (item , dict ):
75
+ return {
76
+ Schema .handle_references (k , builds ): Schema .handle_references (v , builds )
77
+ for k , v in item .items ()
78
+ }
79
+ else :
80
+ return item
81
+
51
82
def decode_data (
52
83
self , data : dict [str , t .Any ], builds : t .Dict , db
53
84
) -> dict [str , t .Any ]:
@@ -62,9 +93,17 @@ def decode_data(
62
93
63
94
decoded = {}
64
95
96
+ # reorder the component so that references go first
97
+ is_ref = lambda x : isinstance (x , str ) and x .startswith ('?' )
98
+ data_is_ref = {k : v for k , v in data .items () if is_ref (v )}
99
+ data_not_ref = {k : v for k , v in data .items () if not is_ref (v )}
100
+ data = {** data_is_ref , ** data_not_ref }
101
+
65
102
for k , value in data .items ():
66
103
field = self .fields .get (k )
67
104
105
+ value = self .handle_references (value , builds )
106
+
68
107
if not isinstance (field , BaseDataType ) or value is None :
69
108
decoded [k ] = value
70
109
continue
@@ -133,6 +172,9 @@ def encode_data(self, out, context: t.Optional[EncodeContext] = None, **kwargs):
133
172
134
173
return result
135
174
175
+ def __getitem__ (self , item : str ):
176
+ return self .fields [item ]
177
+
136
178
137
179
def get_schema (db , schema : t .Union [Schema , str ]) -> t .Optional [Schema ]:
138
180
"""Handle schema caching and loading.
0 commit comments