Skip to content

Commit af81618

Browse files
committedJan 3, 2023
sqlite: add basic implementation
1 parent 0337ea3 commit af81618

20 files changed

+1062
-259
lines changed
 

Diff for: ‎.github/workflows/tests.yml

+2-4
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,13 @@ concurrency:
1515

1616
jobs:
1717
tests:
18-
timeout-minutes: 10
18+
timeout-minutes: 20
1919
runs-on: ${{ matrix.os }}
2020
strategy:
2121
fail-fast: false
2222
matrix:
23-
os: [ubuntu-20.04, windows-latest, macos-latest]
23+
os: [ubuntu-22.04, windows-latest, macos-latest]
2424
pyv: ['3.8', '3.9', '3.10', '3.11']
25-
include:
26-
- {os: ubuntu-latest, pyv: 'pypy3.8'}
2725

2826
steps:
2927
- name: Check out the repository

Diff for: ‎.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,6 @@ dmypy.json
136136

137137
# Cython debug symbols
138138
cython_debug/
139+
140+
# vim
141+
*.swp

Diff for: ‎.pre-commit-config.yaml

+6
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,9 @@ repos:
5353
- id: bandit
5454
args: [-c, pyproject.toml]
5555
additional_dependencies: ["toml"]
56+
# NOTE: temporarily skipped
57+
# - repo: https://github.com/sqlfluff/sqlfluff
58+
# rev: 1.4.2
59+
# hooks:
60+
# - id: sqlfluff-fix
61+
# args: [--FIX-EVEN-UNPARSABLE, --force]

Diff for: ‎MANIFEST.in

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
global-include *.sql

Diff for: ‎README.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
SQLTrie
2-
=======
2+
========
33

44
|PyPI| |Status| |Python Version| |License|
55

Diff for: ‎pyproject.toml

+33
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,24 @@ show_error_codes = true
5656
show_error_context = true
5757
show_traceback = true
5858
pretty = true
59+
ignore_missing_imports = true
5960
check_untyped_defs = false
6061
# Warnings
6162
warn_no_return = true
6263
warn_redundant_casts = true
6364
warn_unreachable = true
6465
files = ["src", "tests"]
6566

67+
[tool.pylint.master]
68+
load-plugins = ["pylint_pytest"]
69+
6670
[tool.pylint.message_control]
6771
enable = ["c-extension-no-member", "no-else-return"]
72+
disable = [
73+
"fixme",
74+
"missing-function-docstring", "missing-module-docstring",
75+
"missing-class-docstring",
76+
]
6877

6978
[tool.pylint.variables]
7079
dummy-variables-rgx = "_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_"
@@ -76,3 +85,27 @@ ignore-words-list = " "
7685
[tool.bandit]
7786
exclude_dirs = ["tests"]
7887
skips = ["B101"]
88+
89+
[tool.sqlfluff.core]
90+
dialect = "sqlite"
91+
exclude_rules = "L031"
92+
93+
[tool.sqlfluff.rules]
94+
tab_space_size = 4
95+
max_line_length = 80
96+
indent_unit = "space"
97+
allow_scalar = true
98+
single_table_references = "consistent"
99+
unquoted_identifiers_policy = "all"
100+
101+
[tool.sqlfluff.rules.L010]
102+
capitalisation_policy = "upper"
103+
104+
[tool.sqlfluff.rules.L029]
105+
# these are not reserved in sqlite,
106+
# see https://www.sqlite.org/lang_keywords.html
107+
ignore_words = ["name", "value", "depth"]
108+
109+
[tool.sqlfluff.rules.L063]
110+
# Data Types
111+
extended_capitalisation_policy = "upper"

Diff for: ‎setup.cfg

+14-3
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,18 @@ long_description = file: README.rst
55
long_description_content_type = text/x-rst
66
license = Apache-2.0
77
license_file = LICENSE
8-
url = https://github.com/efiop/sqltrie
8+
url = https://github.com/iterative/sqltrie
99
platforms=any
10-
authors = Ruslan Kuprieiev
11-
maintainer_email = ruslan@iterative.ai
10+
authors = DVC team
11+
maintainer_email = support@dvc.org
12+
keywords =
13+
sqlite
14+
sqlite3
15+
sql
16+
trie
17+
prefix tree
18+
data-science
19+
diskcache
1220
classifiers =
1321
Programming Language :: Python :: 3
1422
Programming Language :: Python :: 3.8
@@ -23,16 +31,19 @@ zip_safe = False
2331
package_dir=
2432
=src
2533
packages = find:
34+
include_package_data = True
2635
install_requires=
2736

2837
[options.extras_require]
2938
tests =
3039
pytest==7.2.0
40+
pytest-benchmark
3141
pytest-sugar==0.9.5
3242
pytest-cov==3.0.0
3343
pytest-mock==3.8.2
3444
pylint==2.15.0
3545
mypy==0.971
46+
pygtrie
3647
dev =
3748
%(tests)s
3849

Diff for: ‎src/sqltrie/.trie.py.swp

-12 KB
Binary file not shown.

Diff for: ‎src/sqltrie/__init__.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
1-
"""SQLTrie."""
2-
3-
from .trie import AbstractTrie, ShortKeyError
4-
from .sqlite import SQLiteTrie
5-
1+
from .serialized import ( # noqa: F401, pylint: disable=unused-import
2+
JSONTrie,
3+
SerializedTrie,
4+
)
5+
from .sqlite import SQLiteTrie # noqa: F401, pylint: disable=unused-import
6+
from .trie import ( # noqa: F401, pylint: disable=unused-import
7+
ADD,
8+
DELETE,
9+
MODIFY,
10+
RENAME,
11+
UNCHANGED,
12+
AbstractTrie,
13+
Change,
14+
ShortKeyError,
15+
TrieKey,
16+
TrieNode,
17+
)

Diff for: ‎src/sqltrie/serialized.py

+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import json
2+
from abc import abstractmethod
3+
from typing import Any, Optional
4+
5+
from .trie import AbstractTrie, Iterator, TrieKey
6+
7+
8+
class SerializedTrie(AbstractTrie):
9+
@property
10+
@abstractmethod
11+
def _trie(self):
12+
pass
13+
14+
def close(self):
15+
self._trie.close()
16+
17+
def commit(self):
18+
self._trie.commit()
19+
20+
def rollback(self):
21+
self._trie.rollback()
22+
23+
@abstractmethod
24+
def _load(self, key: TrieKey, value: Optional[bytes]) -> Optional[Any]:
25+
pass
26+
27+
@abstractmethod
28+
def _dump(self, key: TrieKey, value: Optional[Any]) -> Optional[bytes]:
29+
pass
30+
31+
def __setitem__(self, key, value):
32+
self._trie[key] = self._dump(key, value)
33+
34+
def __getitem__(self, key):
35+
raw = self._trie[key]
36+
return self._load(key, raw)
37+
38+
def __delitem__(self, key):
39+
del self._trie[key]
40+
41+
def __len__(self):
42+
return len(self._trie)
43+
44+
def view(self, key: Optional[TrieKey] = None) -> "SerializedTrie":
45+
if not key:
46+
return self
47+
48+
raw_trie = self._trie.view(key)
49+
trie = type(self)()
50+
# pylint: disable-next=protected-access
51+
trie._trie = raw_trie # type: ignore
52+
return trie
53+
54+
def items(self, *args, **kwargs):
55+
yield from (
56+
(key, self._load(key, raw))
57+
for key, raw in self._trie.items(*args, **kwargs)
58+
)
59+
60+
def ls(self, key, with_values=False):
61+
entries = self._trie.ls(key, with_values=with_values)
62+
if with_values:
63+
yield from (
64+
(ekey, self._load(ekey, evalue)) for ekey, evalue in entries
65+
)
66+
else:
67+
yield from entries
68+
69+
def traverse(self, node_factory, prefix=None):
70+
def _node_factory_wrapper(path_conv, path, children, value):
71+
return node_factory(
72+
path_conv, path, children, self._load(path, value)
73+
)
74+
75+
return self._trie.traverse(_node_factory_wrapper, prefix=prefix)
76+
77+
def diff(self, *args, **kwargs):
78+
yield from self._trie.diff(*args, **kwargs)
79+
80+
def has_node(self, key):
81+
return self._trie.has_node(key)
82+
83+
def shortest_prefix(self, key):
84+
sprefix = self._trie.shortest_prefix(key)
85+
if sprefix is None:
86+
return None
87+
88+
skey, raw = sprefix
89+
return key, self._load(skey, raw)
90+
91+
def prefixes(self, key):
92+
for prefix, raw in self._trie.prefixes(key):
93+
yield (prefix, self._load(prefix, raw))
94+
95+
def longest_prefix(self, key):
96+
lprefix = self._trie.longest_prefix(key)
97+
if lprefix is None:
98+
return None
99+
100+
lkey, raw = lprefix
101+
return lkey, self._load(lkey, raw)
102+
103+
def __iter__(self) -> Iterator[TrieKey]:
104+
yield from self._trie
105+
106+
107+
class JSONTrie(SerializedTrie): # pylint: disable=abstract-method
108+
def _load(self, key: TrieKey, value: Optional[bytes]) -> Optional[Any]:
109+
if value is None:
110+
return None
111+
return json.loads(value.decode("utf-8"))
112+
113+
def _dump(self, key: TrieKey, value: Optional[Any]) -> Optional[bytes]:
114+
if value is None:
115+
return None
116+
return json.dumps(value).encode("utf-8")

Diff for: ‎src/sqltrie/sqlite.py

-159
This file was deleted.

Diff for: ‎src/sqltrie/sqlite/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .sqlite import SQLiteTrie # noqa: F401, pylint: disable=unused-import

Diff for: ‎src/sqltrie/sqlite/diff.sql

+143
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
DROP TABLE IF EXISTS temp_old_items;
2+
DROP TABLE IF EXISTS temp_new_items;
3+
DROP TABLE IF EXISTS temp_diff;
4+
DROP INDEX IF EXISTS temp_old_items_path_idx;
5+
DROP INDEX IF EXISTS temp_new_items_path_idx;
6+
7+
CREATE TEMP TABLE temp_old_items (
8+
id INTEGER PRIMARY KEY,
9+
pid INTEGER,
10+
name TEXT,
11+
path TEXT,
12+
value BLOB,
13+
UNIQUE(pid, name),
14+
UNIQUE(id, pid),
15+
CHECK(id != pid)
16+
);
17+
CREATE INDEX IF NOT EXISTS temp_old_items_path_idx ON temp_old_items (path);
18+
19+
INSERT INTO temp_old_items
20+
WITH RECURSIVE old_items (id, pid, name, path, value) AS (
21+
SELECT
22+
nodes.id,
23+
nodes.pid,
24+
nodes.name,
25+
nodes.name,
26+
nodes.value
27+
FROM nodes WHERE nodes.pid == {old_root}
28+
29+
UNION ALL
30+
31+
SELECT
32+
nodes.id,
33+
nodes.pid,
34+
nodes.name,
35+
old_items.path || '/' || nodes.name,
36+
nodes.value
37+
FROM nodes, old_items WHERE old_items.id == nodes.pid
38+
)
39+
40+
SELECT * FROM old_items;
41+
42+
CREATE TEMP TABLE temp_new_items (
43+
id INTEGER PRIMARY KEY,
44+
pid INTEGER,
45+
name TEXT,
46+
path TEXT,
47+
value BLOB,
48+
UNIQUE(pid, name),
49+
UNIQUE(id, pid),
50+
CHECK(id != pid)
51+
);
52+
CREATE INDEX IF NOT EXISTS temp_new_items_path_idx ON temp_new_items (path);
53+
54+
INSERT INTO temp_new_items
55+
WITH RECURSIVE new_items (id, pid, name, path, value) AS (
56+
SELECT
57+
nodes.id,
58+
nodes.pid,
59+
nodes.name,
60+
nodes.name,
61+
nodes.value
62+
FROM nodes WHERE nodes.pid == {new_root}
63+
64+
UNION ALL
65+
66+
SELECT
67+
nodes.id,
68+
nodes.pid,
69+
nodes.name,
70+
new_items.path || '/' || nodes.name,
71+
nodes.value
72+
FROM nodes, new_items WHERE new_items.id == nodes.pid
73+
)
74+
75+
SELECT * FROM new_items;
76+
77+
CREATE TEMP TABLE temp_diff AS
78+
WITH RECURSIVE diff (
79+
old_id,
80+
old_pid,
81+
old_name,
82+
old_path,
83+
old_value,
84+
new_id,
85+
new_pid,
86+
new_name,
87+
new_path,
88+
new_value
89+
) AS (
90+
/* FULL OUTER JOIN is not supported, so we have to use two LEFT JOINs :( */
91+
SELECT
92+
old.id,
93+
old.pid,
94+
old.name,
95+
old.path,
96+
old.value,
97+
new.id,
98+
new.pid,
99+
new.name,
100+
new.path,
101+
new.value
102+
FROM
103+
temp_old_items AS old
104+
LEFT JOIN
105+
temp_new_items AS new
106+
ON old.path == new.path
107+
108+
UNION
109+
110+
SELECT
111+
old.id,
112+
old.pid,
113+
old.name,
114+
old.path,
115+
old.value,
116+
new.id,
117+
new.pid,
118+
new.name,
119+
new.path,
120+
new.value
121+
FROM
122+
temp_new_items AS new
123+
LEFT JOIN
124+
temp_old_items AS old
125+
ON old.path == new.path
126+
)
127+
128+
SELECT
129+
(
130+
CASE WHEN old_id IS NULL THEN 'add' ELSE (
131+
CASE WHEN new_id IS NULL THEN 'delete' ELSE (
132+
CASE
133+
WHEN old_value != new_value THEN 'modify' ELSE 'unchanged'
134+
END
135+
) END
136+
) END
137+
) AS type,
138+
*
139+
FROM diff
140+
WHERE (
141+
{with_unchanged}
142+
OR type != 'unchanged'
143+
);

Diff for: ‎src/sqltrie/sqlite/init.sql

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
CREATE TABLE IF NOT EXISTS nodes (
2+
id INTEGER PRIMARY KEY AUTOINCREMENT,
3+
pid INTEGER,
4+
name TEXT,
5+
has_value BOOLEAN,
6+
value BLOB,
7+
UNIQUE(pid, name),
8+
UNIQUE(id, pid),
9+
CHECK(id != pid)
10+
);
11+
CREATE INDEX IF NOT EXISTS nodes_pid_idx ON nodes (pid);
12+
INSERT OR IGNORE INTO nodes (id, pid, name, has_value, value)
13+
VALUES (1, NULL, "", FALSE, NULL);

Diff for: ‎src/sqltrie/sqlite/items.sql

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
DROP TABLE IF EXISTS temp_items;
2+
3+
CREATE TEMP TABLE temp_items AS
4+
WITH RECURSIVE children (
5+
id, pid, name, path, has_value, value, found_value
6+
) AS (
7+
SELECT
8+
nodes.id,
9+
nodes.pid,
10+
nodes.name,
11+
nodes.name,
12+
nodes.has_value,
13+
nodes.value,
14+
nodes.has_value
15+
FROM nodes WHERE nodes.pid == {root}
16+
17+
UNION ALL
18+
19+
SELECT
20+
nodes.id,
21+
nodes.pid,
22+
nodes.name,
23+
children.path || '/' || nodes.name,
24+
nodes.has_value,
25+
nodes.value,
26+
children.found_value OR nodes.has_value
27+
FROM nodes, children
28+
WHERE children.id == nodes.pid AND (NOT {shallow} OR NOT children.found_value OR nodes.has_value)
29+
)
30+
31+
SELECT
32+
id,
33+
pid,
34+
name,
35+
path,
36+
has_value,
37+
value
38+
FROM children WHERE has_value;

Diff for: ‎src/sqltrie/sqlite/sqlite.py

+321
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
import sqlite3
2+
import threading
3+
from pathlib import Path
4+
from typing import Iterator, Optional, Union
5+
from uuid import uuid4
6+
7+
from ..trie import (
8+
AbstractTrie,
9+
Change,
10+
ShortKeyError,
11+
TrieKey,
12+
TrieNode,
13+
TrieStep,
14+
)
15+
16+
# NOTE: seems like "named" doesn't work without changing this global var,
17+
# so unfortunately we have to stick with qmark.
18+
assert sqlite3.paramstyle == "qmark"
19+
20+
scripts = Path(__file__).parent
21+
22+
ROOT_KEY = ()
23+
ROOT_ID = 1
24+
ROOT_NAME = "/"
25+
26+
INIT_SQL = (scripts / "init.sql").read_text()
27+
28+
STEPS_SQL = (scripts / "steps.sql").read_text()
29+
STEPS_TABLE = "temp_steps"
30+
31+
ITEMS_SQL = (scripts / "items.sql").read_text()
32+
ITEMS_TABLE = "temp_items"
33+
34+
DIFF_SQL = (scripts / "diff.sql").read_text()
35+
DIFF_TABLE = "temp_diff"
36+
37+
DEFAULT_DB_FMT = "file:sqlitetrie_{id}?mode=memory&cache=shared"
38+
39+
40+
class SQLiteTrie(AbstractTrie):
41+
def __init__(self, *args, **kwargs):
42+
self._root_key = ROOT_KEY
43+
self._root_id = ROOT_ID
44+
self._path = DEFAULT_DB_FMT.format(id=uuid4())
45+
self._local = threading.local()
46+
self._ids = {}
47+
super().__init__(*args, **kwargs)
48+
49+
@classmethod
50+
def open(cls, path):
51+
trie = cls()
52+
trie._path = path
53+
return trie
54+
55+
def close(self):
56+
conn = getattr(self._local, "conn", None)
57+
if conn is None:
58+
return
59+
60+
conn.close()
61+
62+
try:
63+
delattr(self._local, "conn")
64+
except AttributeError:
65+
pass
66+
67+
def commit(self):
68+
self._conn.commit()
69+
70+
def rollback(self):
71+
self._conn.rollback()
72+
73+
@property
74+
def _conn(self): # pylint: disable=method-hidden
75+
conn = getattr(self._local, "conn", None)
76+
if conn is None:
77+
conn = self._local.conn = sqlite3.connect(self._path)
78+
conn.row_factory = sqlite3.Row
79+
conn.executescript(INIT_SQL)
80+
81+
return conn
82+
83+
def _create_node(self, key):
84+
try:
85+
return self._ids[key]
86+
except KeyError:
87+
pass
88+
89+
rows = self._traverse(key)
90+
if rows:
91+
longest_prefix = tuple(rows[-1]["path"].split("/"))
92+
pid = rows[-1]["id"]
93+
else:
94+
longest_prefix = ()
95+
pid = self._root_id
96+
self._ids[longest_prefix] = pid
97+
98+
node_key = longest_prefix
99+
for name in key[len(longest_prefix) :]:
100+
node_key = (*node_key, name)
101+
row = self._conn.execute(
102+
"""
103+
INSERT OR IGNORE
104+
INTO nodes (pid, name)
105+
VALUES (?, ?)
106+
RETURNING id
107+
""",
108+
(pid, name),
109+
).fetchone()
110+
nid = row["id"]
111+
self._ids[node_key] = nid
112+
pid = nid
113+
114+
return pid
115+
116+
def _traverse(self, key):
117+
self._conn.executescript(
118+
STEPS_SQL.format(path="/".join(key), root=self._root_id)
119+
)
120+
121+
return self._conn.execute( # nosec
122+
f"SELECT * FROM {STEPS_TABLE}"
123+
).fetchall()
124+
125+
def _get_node(self, key):
126+
if not key:
127+
return {
128+
"id": self._root_id,
129+
"pid": None,
130+
"name": None,
131+
"value": None,
132+
}
133+
134+
rows = list(self._traverse(key))
135+
if len(rows) != len(key):
136+
raise KeyError(key)
137+
138+
return rows[-1]
139+
140+
def _get_children(self, key, limit=None):
141+
node = self._get_node(key)
142+
143+
limit_sql = ""
144+
if limit:
145+
limit_sql = f"LIMIT {limit}"
146+
147+
return self._conn.execute( # nosec
148+
f"""
149+
SELECT * FROM nodes WHERE nodes.pid == ? {limit_sql}
150+
""",
151+
(node["id"],),
152+
).fetchall()
153+
154+
def _delete_node(self, key):
155+
node = self._get_node(key)
156+
del self._ids[key]
157+
self._conn.execute(
158+
"""
159+
DELETE FROM nodes WHERE id = ?
160+
""",
161+
(node["id"],),
162+
)
163+
164+
def __setitem__(self, key, value):
165+
pid = self._create_node(key[:-1])
166+
self._conn.execute(
167+
"""
168+
INSERT INTO
169+
nodes (pid, name, has_value, value)
170+
VALUES (?1, ?2, True, ?3)
171+
ON CONFLICT (pid, name) DO UPDATE SET value=?3
172+
""",
173+
(
174+
pid,
175+
key[-1],
176+
value,
177+
),
178+
)
179+
180+
def __iter__(self):
181+
yield from (key for key, _ in self.items())
182+
183+
def __getitem__(self, key):
184+
row = self._get_node(key)
185+
has_value = row["has_value"]
186+
if not has_value:
187+
raise ShortKeyError(key)
188+
return row["value"]
189+
190+
def __delitem__(self, key):
191+
node = self._get_node(key)
192+
self._conn.execute(
193+
"""
194+
UPDATE nodes SET has_value = False, value = NULL WHERE id == ?
195+
""",
196+
(node["id"],),
197+
)
198+
199+
def __len__(self):
200+
self._conn.executescript(
201+
ITEMS_SQL.format(root=self._root_id, shallow=False)
202+
)
203+
return self._conn.execute( # nosec
204+
f"""
205+
SELECT COUNT(*) AS count FROM {ITEMS_TABLE}
206+
"""
207+
).fetchone()["count"]
208+
209+
def prefixes(self, key: TrieKey) -> Iterator[TrieStep]:
210+
for row in self._traverse(key):
211+
if not row["has_value"]:
212+
continue
213+
214+
yield (
215+
tuple(row["path"].split("/")), # type: ignore
216+
row["value"],
217+
)
218+
219+
def shortest_prefix(self, key: TrieKey) -> Optional[TrieStep]:
220+
return next(self.prefixes(key), None)
221+
222+
def longest_prefix(self, key) -> Optional[TrieStep]:
223+
ret = None
224+
for step in self.prefixes(key):
225+
ret = step
226+
return ret
227+
228+
def view( # type: ignore
229+
self,
230+
key: Optional[TrieKey] = None,
231+
) -> "SQLiteTrie":
232+
if not key:
233+
return self
234+
235+
self.commit()
236+
node = self._get_node(key)
237+
238+
trie = SQLiteTrie()
239+
trie._path = self._path # pylint: disable=protected-access
240+
trie._root_key = key # pylint: disable=protected-access
241+
trie._root_id = node["id"] # pylint: disable=protected-access
242+
return trie
243+
244+
def items(self, prefix=None, shallow=False):
245+
if prefix:
246+
pid = self._get_node(prefix)["id"]
247+
else:
248+
prefix = ()
249+
pid = self._root_id
250+
251+
self._conn.executescript(ITEMS_SQL.format(root=pid, shallow=shallow))
252+
rows = self._conn.execute(f"SELECT * FROM {ITEMS_TABLE}") # nosec
253+
254+
yield from (
255+
((*prefix, *row["path"].split("/")), row["value"]) for row in rows
256+
)
257+
258+
def clear(self):
259+
self._conn.execute("DELETE FROM nodes")
260+
261+
def has_node(self, key: TrieKey) -> bool:
262+
try:
263+
value = self[key]
264+
return value is not None
265+
except KeyError:
266+
return False
267+
268+
def ls(
269+
self, key: TrieKey, with_values: Optional[bool] = False
270+
) -> Iterator[Union[TrieKey, TrieNode]]:
271+
if with_values:
272+
yield from ( # type: ignore
273+
((*key, row["name"]), row["value"])
274+
for row in self._get_children(key)
275+
)
276+
else:
277+
yield from ( # type: ignore
278+
(*key, row["name"]) for row in self._get_children(key)
279+
)
280+
281+
def traverse(self, node_factory, prefix=None):
282+
key = prefix or ()
283+
row = self._get_node(prefix)
284+
value = row["value"]
285+
286+
children_keys = (
287+
(*key, row["name"]) for row in self._get_children(key)
288+
)
289+
children = (
290+
self.traverse(node_factory, child) for child in children_keys
291+
)
292+
293+
return node_factory(None, key, children, value)
294+
295+
def diff(self, old, new, with_unchanged=False):
296+
old_id = self._get_node(old)["id"]
297+
new_id = self._get_node(new)["id"]
298+
299+
self._conn.executescript(
300+
DIFF_SQL.format(
301+
old_root=old_id,
302+
new_root=new_id,
303+
with_unchanged=with_unchanged,
304+
)
305+
)
306+
307+
rows = self._conn.execute(f"SELECT * FROM {DIFF_TABLE}") # nosec
308+
yield from (
309+
Change(
310+
row["type"],
311+
TrieNode(
312+
tuple(row["old_path"].split("/")),
313+
row["old_value"],
314+
),
315+
TrieNode(
316+
tuple(row["new_path"].split("/")),
317+
row["new_value"],
318+
),
319+
)
320+
for row in rows
321+
)

Diff for: ‎src/sqltrie/sqlite/steps.sql

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
DROP TABLE IF EXISTS temp_split;
2+
DROP TABLE IF EXISTS temp_steps;
3+
4+
CREATE TEMP TABLE temp_split AS
5+
WITH RECURSIVE
6+
path (path) AS (
7+
VALUES('{path}')
8+
),
9+
10+
split (depth, name, rpath) AS (
11+
SELECT
12+
1,
13+
(
14+
CASE WHEN instr(path, '/') == 0 THEN path ELSE substr(path, 0, instr(path, '/')) END
15+
),
16+
(
17+
CASE WHEN instr(path, '/') == 0 THEN '' ELSE substr(path, instr(path, '/') + 1) END
18+
)
19+
FROM path
20+
21+
UNION ALL
22+
23+
SELECT
24+
split.depth + 1,
25+
(
26+
CASE WHEN instr(split.rpath, '/') == 0 THEN split.rpath ELSE substr(split.rpath, 0, instr(split.rpath, '/')) END
27+
),
28+
(
29+
CASE WHEN instr(split.rpath, '/') == 0 THEN '' ELSE substr(split.rpath, instr(split.rpath, '/') + 1) END
30+
)
31+
FROM split WHERE split.rpath != ''
32+
)
33+
34+
SELECT
35+
depth,
36+
name
37+
FROM split;
38+
39+
CREATE TEMP TABLE temp_steps AS
40+
WITH RECURSIVE
41+
steps (id, pid, name, path, has_value, value, depth) AS (
42+
SELECT
43+
nodes.id,
44+
nodes.pid,
45+
nodes.name,
46+
nodes.name,
47+
nodes.has_value,
48+
nodes.value,
49+
temp_split.depth
50+
FROM nodes, temp_split
51+
WHERE
52+
temp_split.depth == 1 AND nodes.pid == {root} AND nodes.name == temp_split.name
53+
54+
UNION ALL
55+
56+
SELECT
57+
nodes.id,
58+
nodes.pid,
59+
nodes.name,
60+
steps.path || '/' || nodes.name,
61+
nodes.has_value,
62+
nodes.value,
63+
steps.depth + 1
64+
FROM nodes, steps, temp_split
65+
WHERE
66+
nodes.pid == steps.id
67+
AND temp_split.depth == steps.depth + 1
68+
AND temp_split.name == nodes.name
69+
)
70+
71+
SELECT
72+
id,
73+
pid,
74+
name,
75+
path,
76+
has_value,
77+
value
78+
FROM steps;

Diff for: ‎src/sqltrie/trie.py

+83-76
Original file line numberDiff line numberDiff line change
@@ -1,110 +1,117 @@
1-
from collections.abc import MutableMapping
21
from abc import abstractmethod
2+
from collections.abc import MutableMapping
3+
from typing import Iterator, NamedTuple, Optional, Tuple, Union
4+
5+
from attrs import define
36

47

58
class ShortKeyError(KeyError):
69
"""Raised when given key is a prefix of an existing longer key
710
but does not have a value associated with itself."""
811

912

10-
class AbstractTrie(MutableMapping):
11-
def __init__(self, *args, **kwargs):
12-
self.update(*args, **kwargs)
13-
14-
def enable_sorting(self, enable=True):
15-
raise NotImplementedError
16-
17-
def clear(self):
18-
raise NotImplementedError
13+
TrieKey = Union[Tuple[()], Tuple[str]]
14+
TrieStep = Tuple[Optional[TrieKey], Optional[bytes]]
1915

20-
def update(self, *args, **kwargs): # pylint: disable=arguments-differ
21-
raise NotImplementedError
2216

23-
def merge(self, other, overwrite=False):
24-
raise NotImplementedError
17+
class TrieNode(NamedTuple):
18+
key: TrieKey
19+
value: Optional[bytes]
2520

26-
def copy(self, __make_copy=lambda x: x):
27-
raise NotImplementedError
2821

29-
def __copy__(self):
30-
return self.copy()
22+
ADD = "add"
23+
MODIFY = "modify"
24+
RENAME = "rename"
25+
DELETE = "delete"
26+
UNCHANGED = "unchanged"
3127

32-
def __deepcopy__(self, memo):
33-
return self.copy(lambda x: _copy.deepcopy(x, memo))
3428

35-
@classmethod
36-
def fromkeys(cls, keys, value=None):
37-
raise NotImplementedError
38-
39-
def __iter__(self):
40-
return self.iterkeys()
41-
42-
def iteritems(self, prefix=None, shallow=False):
43-
raise NotImplementedError
44-
45-
def iterkeys(self, prefix=None, shallow=False):
46-
raise NotImplementedError
29+
@define(frozen=True, hash=True, order=True)
30+
class Change:
31+
typ: str
32+
old: Optional[TrieNode]
33+
new: Optional[TrieNode]
4734

48-
def itervalues(self, prefix=None, shallow=False):
49-
raise NotImplementedError
35+
@property
36+
def key(self) -> TrieKey:
37+
if self.typ == RENAME:
38+
raise ValueError
5039

51-
def items(self, prefix=None, shallow=False):
52-
return list(self.iteritems(prefix=prefix, shallow=shallow))
40+
if self.typ == ADD:
41+
entry = self.new
42+
else:
43+
entry = self.old
5344

54-
def keys(self, prefix=None, shallow=False):
55-
return list(self.iterkeys(prefix=prefix, shallow=shallow))
45+
assert entry
46+
assert entry.key
47+
return entry.key
5648

57-
def values(self, prefix=None, shallow=False):
58-
return list(self.itervalues(prefix=prefix, shallow=shallow))
49+
def __bool__(self) -> bool:
50+
return self.typ != UNCHANGED
5951

60-
def __len__(self):
61-
raise NotImplementedError
6252

63-
def __bool__(self):
64-
raise NotImplementedError
65-
66-
__nonzero__ = __bool__
67-
__hash__ = None
68-
69-
def has_node(self, key):
70-
raise NotImplementedError
71-
72-
def has_key(self, key):
73-
return bool(self.has_node(key) & self.HAS_VALUE)
53+
class AbstractTrie(MutableMapping):
54+
def __init__(self, *args, **kwargs):
55+
self.update(*args, **kwargs)
7456

75-
def has_subtrie(self, key):
76-
return bool(self.has_node(key) & self.HAS_SUBTRIE)
57+
@classmethod
58+
@abstractmethod
59+
def open(cls, path: str) -> "AbstractTrie":
60+
pass
7761

78-
def __getitem__(self, key_or_slice):
79-
raise NotImplementedError
62+
@abstractmethod
63+
def close(self) -> None:
64+
pass
8065

81-
def __setitem__(self, key_or_slice, value):
82-
raise NotImplementedError
66+
@abstractmethod
67+
def commit(self) -> None:
68+
pass
8369

84-
def __delitem__(self, key_or_slice):
85-
raise NotImplementedError
70+
@abstractmethod
71+
def rollback(self) -> None:
72+
pass
8673

87-
def setdefault(self, key, default=None):
88-
raise NotImplementedError
74+
@abstractmethod
75+
def items( # type: ignore
76+
self, prefix: Optional[TrieKey] = None, shallow: Optional[bool] = False
77+
) -> Iterator[Tuple[TrieKey, bytes]]:
78+
pass
8979

90-
def pop(self, key, default=None):
91-
raise NotImplementedError
80+
@abstractmethod
81+
def view(self, key: Optional[TrieKey] = None) -> "AbstractTrie":
82+
pass
9283

93-
def popitem(self):
94-
raise NotImplementedError
84+
@abstractmethod
85+
def has_node(self, key: TrieKey) -> bool:
86+
pass
9587

96-
def walk_towards(self, key):
97-
raise NotImplementedError
88+
@abstractmethod
89+
def prefixes(self, key: TrieKey) -> Iterator[TrieStep]:
90+
pass
9891

99-
def prefixes(self, key):
100-
raise NotImplementedError
92+
@abstractmethod
93+
def shortest_prefix(self, key: TrieKey) -> Optional[TrieStep]:
94+
pass
10195

102-
def shortest_prefix(self, key):
103-
raise NotImplementedError
96+
@abstractmethod
97+
def longest_prefix(self, key: TrieKey) -> Optional[TrieStep]:
98+
pass
10499

105-
def longest_prefix(self, key):
106-
raise NotImplementedError
100+
@abstractmethod
101+
# pylint: disable-next=invalid-name
102+
def ls(
103+
self, key: TrieKey, with_values: bool = False
104+
) -> Iterator[Union[TrieKey, TrieNode]]:
105+
pass
107106

108-
def traverse(self, node_factory, prefix=None):
107+
@abstractmethod
108+
def traverse(
109+
self, node_factory, prefix: Optional[TrieKey]
110+
) -> Iterator[Tuple[TrieKey, bytes]]:
109111
pass
110112

113+
@abstractmethod
114+
def diff(
115+
self, old: TrieKey, new: TrieKey, with_unchanged: bool = False
116+
) -> Iterator[Change]:
117+
pass

Diff for: ‎tests/benchmarks/test_sqltrie.py

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import pytest
2+
from pygtrie import Trie as _GTrie
3+
4+
from sqltrie import ADD, DELETE, MODIFY, UNCHANGED, Change, SQLiteTrie
5+
6+
NFILES = 10000
7+
NSUBDIRS = 3
8+
9+
10+
class GTrie(_GTrie):
11+
def ls(self, root_key):
12+
def node_factory(_, key, children, *args):
13+
if key == root_key:
14+
return children
15+
else:
16+
return key
17+
18+
return self.traverse(node_factory, prefix=root_key)
19+
20+
def diff(self, old, new):
21+
# FIXME this is not the most optimal implementation
22+
old_keys = {key for key, _ in self.iteritems(old or ())}
23+
new_keys = {key for key, _ in self.iteritems(new or ())}
24+
25+
for key in old_keys | new_keys:
26+
old_entry = self.get(key)
27+
new_entry = self.get(key)
28+
29+
typ = UNCHANGED
30+
if old_entry and not new_entry:
31+
typ = DELETE
32+
elif not old_entry and new_entry:
33+
typ = ADD
34+
elif old_entry != new_entry:
35+
typ = MODIFY
36+
else:
37+
continue
38+
39+
yield Change(typ, old_entry, new_entry)
40+
41+
42+
@pytest.fixture(scope="session")
43+
def items():
44+
ret = {}
45+
46+
files = {str(idx): bytes(idx) for idx in range(NFILES)}
47+
for subdir in range(NSUBDIRS):
48+
ret[str(subdir)] = files.copy()
49+
50+
return ret
51+
52+
53+
@pytest.mark.parametrize("cls", [SQLiteTrie, GTrie])
54+
def test_set(benchmark, items, cls):
55+
def _set():
56+
trie = cls()
57+
58+
for subdir in range(NSUBDIRS):
59+
for idx in range(NFILES):
60+
trie[(str(subdir), str(idx))] = bytes(idx)
61+
62+
benchmark(_set)
63+
64+
65+
@pytest.mark.parametrize("cls", [SQLiteTrie, GTrie])
66+
def test_items(benchmark, items, cls):
67+
trie = cls()
68+
69+
for subdir in range(NSUBDIRS):
70+
for idx in range(NFILES):
71+
trie[(str(subdir), str(idx))] = bytes(idx)
72+
73+
def _items():
74+
list(trie.items())
75+
76+
benchmark(_items)
77+
78+
79+
@pytest.mark.parametrize("cls", [SQLiteTrie, GTrie])
80+
def test_ls(benchmark, items, cls):
81+
trie = cls()
82+
83+
for subdir in range(NSUBDIRS):
84+
for idx in range(NFILES):
85+
trie[(str(subdir), str(idx))] = bytes(idx)
86+
87+
def _ls():
88+
list(trie.ls(("1",)))
89+
90+
benchmark(_ls)
91+
92+
93+
@pytest.mark.parametrize("cls", [SQLiteTrie, GTrie])
94+
def test_diff(benchmark, items, cls):
95+
trie = cls()
96+
97+
for subdir in range(NSUBDIRS):
98+
for idx in range(NFILES):
99+
trie[(str(subdir), str(idx))] = bytes(idx)
100+
101+
def _diff():
102+
list(trie.diff(None, None))
103+
104+
benchmark(_diff)

Diff for: ‎tests/test_sqltrie.py

+88-11
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,101 @@
11
"""Tests for `sqltrie` package."""
2+
import os
3+
24
import pytest
35

4-
from sqltrie import SQLiteTrie, ShortKeyError
6+
from sqltrie import UNCHANGED, Change, ShortKeyError, SQLiteTrie, TrieNode
7+
58

69
def test_trie():
710
trie = SQLiteTrie()
811

9-
trie[("foo",)] = "foo-value"
10-
trie[("foo", "bar", "baz")] = "baz-value"
12+
trie[("foo",)] = b"foo-value"
13+
trie[("foo", "bar", "baz")] = b"baz-value"
1114

1215
assert len(trie) == 2
13-
assert trie[("foo",)] == "foo-value"
14-
assert trie[("foo", "bar")] == None
15-
assert trie[("foo", "bar", "baz")] == "baz-value"
16+
assert trie[("foo",)] == b"foo-value"
17+
with pytest.raises(ShortKeyError):
18+
trie[("foo", "bar")] # pylint: disable=pointless-statement
19+
assert trie[("foo", "bar", "baz")] == b"baz-value"
1620

1721
del trie[("foo",)]
1822
assert len(trie) == 1
19-
# FIXME the next two should raise ShortKeyError
20-
assert trie[("foo",)] == None
21-
assert trie[("foo", "bar")] == None
22-
assert trie[("foo", "bar", "baz")] == "baz-value"
23+
assert trie[("foo", "bar", "baz")] == b"baz-value"
24+
25+
with pytest.raises(ShortKeyError):
26+
trie[("foo",)] # pylint: disable=pointless-statement
27+
28+
with pytest.raises(ShortKeyError):
29+
trie[("foo", "bar")] # pylint: disable=pointless-statement
30+
31+
with pytest.raises(KeyError):
32+
trie[("non-existent",)] # pylint: disable=pointless-statement
33+
34+
with pytest.raises(KeyError):
35+
trie[("foo", "non-existent")] # pylint: disable=pointless-statement
36+
37+
assert trie.longest_prefix(()) is None
38+
assert trie.longest_prefix(("non-existent",)) is None
39+
assert trie.longest_prefix(("foo",)) is None
40+
assert trie.longest_prefix(("foo", "non-existent")) is None
41+
assert trie.longest_prefix(("foo", "bar", "baz", "qux")) == (
42+
("foo", "bar", "baz"),
43+
b"baz-value",
44+
)
45+
46+
assert set(trie.items()) == {
47+
(("foo", "bar", "baz"), b"baz-value"),
48+
}
49+
assert set(trie.items(shallow=True)) == {
50+
(("foo", "bar", "baz"), b"baz-value"),
51+
}
52+
assert set(trie.items(("foo",))) == {
53+
(("foo", "bar", "baz"), b"baz-value"),
54+
}
55+
assert set(trie.items(("foo", "bar"))) == {
56+
(("foo", "bar", "baz"), b"baz-value"),
57+
}
58+
assert set(trie.items(("foo", "bar", "baz"))) == set()
59+
60+
assert set(trie.view(("foo",)).items()) == {
61+
(("bar", "baz"), b"baz-value"),
62+
}
63+
assert set(trie.view(("foo", "bar", "baz")).items()) == set()
64+
65+
assert list(trie.ls(())) == [("foo",)]
66+
assert list(trie.ls(("foo",))) == [("foo", "bar")]
67+
assert list(trie.ls(("foo", "bar"))) == [("foo", "bar", "baz")]
68+
69+
assert not list(trie.diff(("foo",), ("foo",)))
70+
assert list(trie.diff(("foo",), ("foo",), with_unchanged=True)) == [
71+
Change(
72+
typ=UNCHANGED,
73+
old=TrieNode(key=("bar",), value=None),
74+
new=TrieNode(key=("bar",), value=None),
75+
),
76+
Change(
77+
typ=UNCHANGED,
78+
old=TrieNode(key=("bar", "baz"), value=b"baz-value"),
79+
new=TrieNode(key=("bar", "baz"), value=b"baz-value"),
80+
),
81+
]
82+
83+
84+
def test_open(tmp_path):
85+
path = os.fspath(tmp_path / "db")
86+
trie = SQLiteTrie.open(path)
87+
88+
assert len(trie) == 0
89+
90+
trie[("foo",)] = b"foo-value"
91+
trie[("foo", "bar", "baz")] = b"baz-value"
92+
93+
trie.commit()
94+
trie.close()
95+
96+
trie = SQLiteTrie.open(path)
97+
98+
assert len(trie) == 2
2399

24-
assert set(trie.iteritems()) == set()
100+
assert trie[("foo",)] == b"foo-value"
101+
assert trie[("foo", "bar", "baz")] == b"baz-value"

0 commit comments

Comments
 (0)
Please sign in to comment.