Skip to content

Commit e94750e

Browse files
authored
ES-1867 | force_one_shard_attribute_value param (#314)
1 parent 5e93203 commit e94750e

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

arango/aql.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ def execute(
276276
fill_block_cache: Optional[bool] = None,
277277
allow_dirty_read: bool = False,
278278
allow_retry: bool = False,
279+
force_one_shard_attribute_value: Optional[str] = None,
279280
) -> Result[Cursor]:
280281
"""Execute the query and return the result cursor.
281282
@@ -373,6 +374,16 @@ def execute(
373374
:param allow_retry: Make it possible to retry fetching the latest batch
374375
from a cursor.
375376
:type allow_retry: bool
377+
:param force_one_shard_attribute_value: (Enterprise Only) Explicitly set
378+
a shard key value that will be used during query snippet distribution
379+
to limit the query to a specific server in the cluster. This query option
380+
can be used in complex queries in case the query optimizer cannot
381+
automatically detect that the query can be limited to only a single
382+
server (e.g. in a disjoint smart graph case). If the option is set
383+
incorrectly, i.e. to a wrong shard key value, then the query may be
384+
shipped to a wrong DB server and may not return results
385+
(i.e. empty result set). Use at your own risk.
386+
:param force_one_shard_attribute_value: str | None
376387
:return: Result cursor.
377388
:rtype: arango.cursor.Cursor
378389
:raise arango.exceptions.AQLQueryExecuteError: If execute fails.
@@ -418,10 +429,10 @@ def execute(
418429
options["skipInaccessibleCollections"] = skip_inaccessible_cols
419430
if max_runtime is not None:
420431
options["maxRuntime"] = max_runtime
421-
422-
# New in 3.11
423432
if allow_retry is not None:
424433
options["allowRetry"] = allow_retry
434+
if force_one_shard_attribute_value is not None:
435+
options["forceOneShardAttributeValue"] = force_one_shard_attribute_value
425436

426437
if options:
427438
data["options"] = options

tests/test_aql.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
AQLQueryTrackingSetError,
1818
AQLQueryValidateError,
1919
)
20-
from tests.helpers import assert_raises, extract
20+
from tests.helpers import assert_raises, extract, generate_col_name
2121

2222

2323
def test_aql_attributes(db, username):
@@ -246,6 +246,36 @@ def test_aql_query_management(db_version, db, bad_db, col, docs):
246246
assert err.value.error_code in {11, 1228}
247247

248248

249+
def test_aql_query_force_one_shard_attribute_value(db, db_version, enterprise, cluster):
250+
if db_version < version.parse("3.10") or not enterprise or not cluster:
251+
return
252+
253+
name = generate_col_name()
254+
col = db.create_collection(name, shard_fields=["foo"], shard_count=3)
255+
256+
doc = {"foo": "bar"}
257+
col.insert(doc)
258+
259+
cursor = db.aql.execute(
260+
"FOR d IN @@c RETURN d",
261+
bind_vars={"@c": name},
262+
force_one_shard_attribute_value="bar",
263+
)
264+
265+
results = [doc for doc in cursor]
266+
assert len(results) == 1
267+
assert results[0]["foo"] == "bar"
268+
269+
cursor = db.aql.execute(
270+
"FOR d IN @@c RETURN d",
271+
bind_vars={"@c": name},
272+
force_one_shard_attribute_value="ooo",
273+
)
274+
275+
results = [doc for doc in cursor]
276+
assert len(results) == 0
277+
278+
249279
def test_aql_function_management(db, bad_db):
250280
fn_group = "functions::temperature"
251281
fn_name_1 = "functions::temperature::celsius_to_fahrenheit"

0 commit comments

Comments
 (0)