@@ -204,7 +204,7 @@ def __init__(
204
204
205
205
self .__set_db (db )
206
206
if all ([self .__db , name ]):
207
- self .__set_graph (name , default_node_type , edge_type_func )
207
+ self .__set_graph (name , overwrite_graph , default_node_type , edge_type_func )
208
208
self .__set_edge_collections_attributes (edge_collections_attributes )
209
209
210
210
# NOTE: Need to revisit these...
@@ -232,23 +232,6 @@ def __init__(
232
232
self ._set_factory_methods (read_parallelism , read_batch_size )
233
233
self .__set_arangodb_backend_config ()
234
234
235
- if overwrite_graph :
236
- logger .info ("Overwriting graph..." )
237
-
238
- properties = self .adb_graph .properties ()
239
- self .db .delete_graph (name , drop_collections = True )
240
- self .db .create_graph (
241
- name = name ,
242
- edge_definitions = properties ["edge_definitions" ],
243
- orphan_collections = properties ["orphan_collections" ],
244
- smart = properties .get ("smart" ),
245
- disjoint = properties .get ("disjoint" ),
246
- smart_field = properties .get ("smart_field" ),
247
- shard_count = properties .get ("shard_count" ),
248
- replication_factor = properties .get ("replication_factor" ),
249
- write_concern = properties .get ("write_concern" ),
250
- )
251
-
252
235
if isinstance (incoming_graph_data , nx .Graph ):
253
236
self ._load_nx_graph (incoming_graph_data , write_batch_size , write_async )
254
237
self ._loaded_incoming_graph_data = True
@@ -367,13 +350,33 @@ def __set_db(self, db: Any = None) -> None:
367
350
def __set_graph (
368
351
self ,
369
352
name : Any ,
353
+ overwrite_graph : bool ,
370
354
default_node_type : str | None = None ,
371
355
edge_type_func : Callable [[str , str ], str ] | None = None ,
372
356
) -> None :
373
357
if not isinstance (name , str ):
374
358
raise TypeError ("**name** must be a string" )
375
359
376
- if self .db .has_graph (name ):
360
+ graph_exists = self .db .has_graph (name )
361
+
362
+ if graph_exists and overwrite_graph :
363
+ logger .info (f"Overwriting graph '{ name } '" )
364
+
365
+ properties = self .db .graph (name ).properties ()
366
+ self .db .delete_graph (name , drop_collections = True )
367
+ self .db .create_graph (
368
+ name = name ,
369
+ edge_definitions = properties ["edge_definitions" ],
370
+ orphan_collections = properties ["orphan_collections" ],
371
+ smart = properties .get ("smart" ),
372
+ disjoint = properties .get ("disjoint" ),
373
+ smart_field = properties .get ("smart_field" ),
374
+ shard_count = properties .get ("shard_count" ),
375
+ replication_factor = properties .get ("replication_factor" ),
376
+ write_concern = properties .get ("write_concern" ),
377
+ )
378
+
379
+ if graph_exists :
377
380
logger .info (f"Graph '{ name } ' exists." )
378
381
379
382
if edge_type_func is not None :
@@ -613,9 +616,14 @@ def chat(
613
616
if llm is None :
614
617
llm = ChatOpenAI (temperature = 0 , model_name = "gpt-4" )
615
618
619
+ graph = ArangoGraph (
620
+ self .db ,
621
+ # graph_name=self.name # not yet supported
622
+ )
623
+
616
624
chain = ArangoGraphQAChain .from_llm (
617
625
llm = llm ,
618
- graph = ArangoGraph ( self . db ) ,
626
+ graph = graph ,
619
627
verbose = verbose ,
620
628
)
621
629
0 commit comments