Skip to content

Commit 2064ff5

Browse files
authored
cleanup: overwrite_graph and chat (#69)
* cleanup: `overwrite_graph` * fix: typo * fix: `chat` * temp: disable `chat`
1 parent bd47753 commit 2064ff5

File tree

1 file changed

+28
-20
lines changed

1 file changed

+28
-20
lines changed

nx_arangodb/classes/graph.py

+28-20
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def __init__(
204204

205205
self.__set_db(db)
206206
if all([self.__db, name]):
207-
self.__set_graph(name, default_node_type, edge_type_func)
207+
self.__set_graph(name, overwrite_graph, default_node_type, edge_type_func)
208208
self.__set_edge_collections_attributes(edge_collections_attributes)
209209

210210
# NOTE: Need to revisit these...
@@ -232,23 +232,6 @@ def __init__(
232232
self._set_factory_methods(read_parallelism, read_batch_size)
233233
self.__set_arangodb_backend_config()
234234

235-
if overwrite_graph:
236-
logger.info("Overwriting graph...")
237-
238-
properties = self.adb_graph.properties()
239-
self.db.delete_graph(name, drop_collections=True)
240-
self.db.create_graph(
241-
name=name,
242-
edge_definitions=properties["edge_definitions"],
243-
orphan_collections=properties["orphan_collections"],
244-
smart=properties.get("smart"),
245-
disjoint=properties.get("disjoint"),
246-
smart_field=properties.get("smart_field"),
247-
shard_count=properties.get("shard_count"),
248-
replication_factor=properties.get("replication_factor"),
249-
write_concern=properties.get("write_concern"),
250-
)
251-
252235
if isinstance(incoming_graph_data, nx.Graph):
253236
self._load_nx_graph(incoming_graph_data, write_batch_size, write_async)
254237
self._loaded_incoming_graph_data = True
@@ -367,13 +350,33 @@ def __set_db(self, db: Any = None) -> None:
367350
def __set_graph(
368351
self,
369352
name: Any,
353+
overwrite_graph: bool,
370354
default_node_type: str | None = None,
371355
edge_type_func: Callable[[str, str], str] | None = None,
372356
) -> None:
373357
if not isinstance(name, str):
374358
raise TypeError("**name** must be a string")
375359

376-
if self.db.has_graph(name):
360+
graph_exists = self.db.has_graph(name)
361+
362+
if graph_exists and overwrite_graph:
363+
logger.info(f"Overwriting graph '{name}'")
364+
365+
properties = self.db.graph(name).properties()
366+
self.db.delete_graph(name, drop_collections=True)
367+
self.db.create_graph(
368+
name=name,
369+
edge_definitions=properties["edge_definitions"],
370+
orphan_collections=properties["orphan_collections"],
371+
smart=properties.get("smart"),
372+
disjoint=properties.get("disjoint"),
373+
smart_field=properties.get("smart_field"),
374+
shard_count=properties.get("shard_count"),
375+
replication_factor=properties.get("replication_factor"),
376+
write_concern=properties.get("write_concern"),
377+
)
378+
379+
if graph_exists:
377380
logger.info(f"Graph '{name}' exists.")
378381

379382
if edge_type_func is not None:
@@ -613,9 +616,14 @@ def chat(
613616
if llm is None:
614617
llm = ChatOpenAI(temperature=0, model_name="gpt-4")
615618

619+
graph = ArangoGraph(
620+
self.db,
621+
# graph_name=self.name # not yet supported
622+
)
623+
616624
chain = ArangoGraphQAChain.from_llm(
617625
llm=llm,
618-
graph=ArangoGraph(self.db),
626+
graph=graph,
619627
verbose=verbose,
620628
)
621629

0 commit comments

Comments
 (0)