@@ -1,26 +1,12 @@
 import argparse
 import os
 import shutil
+from typing import Any, ClassVar
 
 from private_gpt.paths import local_data_path
 from private_gpt.settings.settings import settings
 
 
-def wipe() -> None:
-    WIPE_MAP = {
-        "simple": wipe_simple,  # node store
-        "chroma": wipe_chroma,  # vector store
-        "postgres": wipe_postgres,  # node, index and vector store
-    }
-    for dbtype in ("nodestore", "vectorstore"):
-        database = getattr(settings(), dbtype).database
-        func = WIPE_MAP.get(database)
-        if func:
-            func(dbtype)
-        else:
-            print(f"Unable to wipe database '{database}' for '{dbtype}'")
-
-
 def wipe_file(file: str) -> None:
     if os.path.isfile(file):
         os.remove(file)
@@ -50,62 +36,149 @@ def wipe_tree(path: str) -> None:
             continue
 
 
-def wipe_simple(dbtype: str) -> None:
-    assert dbtype == "nodestore"
-    from llama_index.core.storage.docstore.types import (
-        DEFAULT_PERSIST_FNAME as DOCSTORE,
-    )
-    from llama_index.core.storage.index_store.types import (
-        DEFAULT_PERSIST_FNAME as INDEXSTORE,
-    )
+class Postgres:
+    tables: ClassVar[dict[str, list[str]]] = {
+        "nodestore": ["data_docstore", "data_indexstore"],
+        "vectorstore": ["data_embeddings"],
+    }
+
+    def __init__(self) -> None:
+        try:
+            import psycopg2
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError("Postgres dependencies not found") from None
 
-    for store in (DOCSTORE, INDEXSTORE):
-        wipe_file(str((local_data_path / store).absolute()))
+        connection = settings().postgres.model_dump(exclude_none=True)
+        self.schema = connection.pop("schema_name")
+        self.conn = psycopg2.connect(**connection)
 
+    def wipe(self, storetype: str) -> None:
+        cur = self.conn.cursor()
+        try:
+            for table in self.tables[storetype]:
+                sql = f"DROP TABLE IF EXISTS {self.schema}.{table}"
+                cur.execute(sql)
+                print(f"Table {self.schema}.{table} dropped.")
+            self.conn.commit()
+        finally:
+            cur.close()
 
-def wipe_postgres(dbtype: str) -> None:
-    try:
-        import psycopg2
-    except ImportError as e:
-        raise ImportError("Postgres dependencies not found") from e
+    def stats(self, store_type: str) -> None:
+        template = "SELECT '{table}', COUNT(*), pg_size_pretty(pg_total_relation_size('{table}')) FROM {table}"
+        sql = " UNION ALL ".join(
+            template.format(table=tbl) for tbl in self.tables[store_type]
+        )
+
+        cur = self.conn.cursor()
+        try:
+            print(f"Storage for Postgres {store_type}.")
+            print("{:<15} | {:>15} | {:>9}".format("Table", "Rows", "Size"))
+            print("-" * 45)  # Print a line separator
 
-    cur = conn = None
-    try:
-        tables = {
-            "nodestore": ["data_docstore", "data_indexstore"],
-            "vectorstore": ["data_embeddings"],
-        }[dbtype]
-        connection = settings().postgres.model_dump(exclude_none=True)
-        schema = connection.pop("schema_name")
-        conn = psycopg2.connect(**connection)
-        cur = conn.cursor()
-        for table in tables:
-            sql = f"DROP TABLE IF EXISTS {schema}.{table}"
             cur.execute(sql)
-            print(f"Table {schema}.{table} dropped.")
-        conn.commit()
-    except psycopg2.Error as e:
-        print("Error:", e)
-    finally:
-        if cur:
+            for row in cur.fetchall():
+                formatted_row_count = f"{row[1]:,}"
+                print(f"{row[0]:<15} | {formatted_row_count:>15} | {row[2]:>9}")
+
+            print()
+        finally:
             cur.close()
-        if conn:
-            conn.close()
 
+    def __del__(self):
+        if hasattr(self, "conn") and self.conn:
+            self.conn.close()
 
-def wipe_chroma(dbtype: str):
-    assert dbtype == "vectorstore"
-    wipe_tree(str((local_data_path / "chroma_db").absolute()))
 
+class Simple:
+    def wipe(self, store_type: str) -> None:
+        assert store_type == "nodestore"
+        from llama_index.core.storage.docstore.types import (
+            DEFAULT_PERSIST_FNAME as DOCSTORE,
+        )
+        from llama_index.core.storage.index_store.types import (
+            DEFAULT_PERSIST_FNAME as INDEXSTORE,
+        )
 
-if __name__ == "__main__":
-    commands = {
-        "wipe": wipe,
+        for store in (DOCSTORE, INDEXSTORE):
+            wipe_file(str((local_data_path / store).absolute()))
+
+
+class Chroma:
+    def wipe(self, store_type: str) -> None:
+        assert store_type == "vectorstore"
+        wipe_tree(str((local_data_path / "chroma_db").absolute()))
+
+
+class Qdrant:
+    COLLECTION = (
+        "make_this_parameterizable_per_api_call"  # ?! see vector_store_component.py
+    )
+
+    def __init__(self) -> None:
+        try:
+            from qdrant_client import QdrantClient  # type: ignore
+        except ImportError:
+            raise ImportError("Qdrant dependencies not found") from None
+        self.client = QdrantClient(**settings().qdrant.model_dump(exclude_none=True))
+
+    def wipe(self, store_type: str) -> None:
+        assert store_type == "vectorstore"
+        try:
+            self.client.delete_collection(self.COLLECTION)
+            print("Collection dropped successfully.")
+        except Exception as e:
+            print("Error dropping collection:", e)
+
+    def stats(self, store_type: str) -> None:
+        print(f"Storage for Qdrant {store_type}.")
+        try:
+            collection_data = self.client.get_collection(self.COLLECTION)
+            if collection_data:
+                # Collection Info
+                # https://qdrant.tech/documentation/concepts/collections/
+                print(f"\tPoints: {collection_data.points_count:,}")
+                print(f"\tVectors: {collection_data.vectors_count:,}")
+                print(f"\tIndex Vectors: {collection_data.indexed_vectors_count:,}")
+                return
+        except ValueError:
+            pass
+        print("\t- Qdrant collection not found or empty")
+
+
+class Command:
+    DB_HANDLERS: ClassVar[dict[str, Any]] = {
+        "simple": Simple,  # node store
+        "chroma": Chroma,  # vector store
+        "postgres": Postgres,  # node, index and vector store
+        "qdrant": Qdrant,  # vector store
     }
 
+    def for_each_store(self, cmd: str):
+        for store_type in ("nodestore", "vectorstore"):
+            database = getattr(settings(), store_type).database
+            handler_class = self.DB_HANDLERS.get(database)
+            if handler_class is None:
+                print(f"No handler found for database '{database}'")
+                continue
+            handler_instance = handler_class()  # Instantiate the class
+            # If the DB can handle this cmd dispatch it.
+            if hasattr(handler_instance, cmd) and callable(
+                func := getattr(handler_instance, cmd)
+            ):
+                func(store_type)
+            else:
+                print(
+                    f"Unable to execute command '{cmd}' on '{store_type}' in database '{database}'"
+                )
+
+    def execute(self, cmd: str) -> None:
+        if cmd in ("wipe", "stats"):
+            self.for_each_store(cmd)
+
+
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "mode", help="select a mode to run", choices=list(commands.keys())
-    )
+    parser.add_argument("mode", help="select a mode to run", choices=["wipe", "stats"])
    args = parser.parse_args()
-    commands[args.mode.lower()]()
+
+    Command().execute(args.mode.lower())
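For context, a minimal usage sketch of the refactored dispatch introduced by this diff; the module path is an assumption, not something the diff confirms:

# Usage sketch (assumed script location -- adjust to wherever this file lives in the repo):
#   python scripts/utils.py wipe    # drop node store / vector store data per settings()
#   python scripts/utils.py stats   # print row counts and sizes where the backend supports it
from scripts.utils import Command  # assumption: the diff targets scripts/utils.py

Command().execute("stats")  # dispatches stats() to each configured store whose handler implements it
Command().execute("wipe")   # dispatches wipe() the same way, falling back to a warning otherwise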