9
9
from dialog_lib .embeddings .generate import generate_embeddings
10
10
from dialog .llm .embeddings import EMBEDDINGS_LLM
11
11
from dialog_lib .db .models import CompanyContent
12
- from dialog .db import get_session
12
+ from dialog .db import session_scope
13
13
from dialog .settings import Settings
14
14
15
15
import logging
21
21
22
22
logger = logging .getLogger ("make_embeddings" )
23
23
24
- session = next (get_session ())
25
24
NECESSARY_COLS = ["category" , "subcategory" , "question" , "content" ]
26
25
PK_METADATA_COLS = ["category" , "subcategory" , "question" ]
27
26
@@ -33,7 +32,7 @@ def _get_csv_cols(path: str) -> List[str]:
33
32
return reader .fieldnames
34
33
35
34
36
- def retrieve_docs_from_vectordb () -> List [Document ]:
35
+ def retrieve_docs_from_vectordb (session ) -> List [Document ]:
37
36
"""Retrieve all documents from the vector store."""
38
37
company_contents : List [CompanyContent ] = session .query (CompanyContent ).all ()
39
38
return [
@@ -97,7 +96,7 @@ def load_csv_with_metadata(
97
96
98
97
99
98
def load_csv_and_generate_embeddings (
100
- path , cleardb = False , embed_columns : Optional [list [str ]] = None
99
+ path , session , cleardb = False , embed_columns : Optional [list [str ]] = None
101
100
):
102
101
"""
103
102
Load the knowledge base CSV, get their embeddings and store them into the vector store.
@@ -121,7 +120,7 @@ def load_csv_and_generate_embeddings(
121
120
session .commit ()
122
121
123
122
# Get existing docs
124
- docs_in_db : List [Document ] = retrieve_docs_from_vectordb ()
123
+ docs_in_db : List [Document ] = retrieve_docs_from_vectordb (session )
125
124
logging .info (f"Existing docs: { len (docs_in_db )} " )
126
125
existing_pks : List [str ] = [
127
126
get_document_pk (doc , PK_METADATA_COLS ) for doc in docs_in_db
@@ -160,6 +159,7 @@ def load_csv_and_generate_embeddings(
160
159
parser .add_argument ("--embed-columns" , default = "content" )
161
160
args = parser .parse_args ()
162
161
163
- load_csv_and_generate_embeddings (
164
- args .path , args .cleardb , args .embed_columns .split ("," )
165
- )
162
+ with session_scope () as session :
163
+ load_csv_and_generate_embeddings (
164
+ args .path , session , args .cleardb , args .embed_columns .split ("," )
165
+ )
0 commit comments