Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for sort in collection.find function #359

Merged
merged 7 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,6 @@ arango/version.py

# test results
*_results.txt

# devcontainers
.devcontainer
11 changes: 10 additions & 1 deletion arango/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@
from arango.typings import Fields, Headers, Json, Jsons, Params
from arango.utils import (
build_filter_conditions,
build_sort_expression,
get_batches,
get_doc_id,
is_none_or_bool,
is_none_or_int,
is_none_or_str,
validate_sort_parameters,
)


Expand Down Expand Up @@ -718,6 +720,7 @@ def all(
:return: Document cursor.
:rtype: arango.cursor.Cursor
:raise arango.exceptions.DocumentGetError: If retrieval fails.
:raise arango.exceptions.SortValidationError: If sort parameters are invalid.
"""
assert is_none_or_int(skip), "skip must be a non-negative int"
assert is_none_or_int(limit), "limit must be a non-negative int"
Expand Down Expand Up @@ -753,6 +756,7 @@ def find(
skip: Optional[int] = None,
limit: Optional[int] = None,
allow_dirty_read: bool = False,
sort: Optional[Jsons] = None,
) -> Result[Cursor]:
"""Return all documents that match the given filters.

Expand All @@ -764,23 +768,28 @@ def find(
:type limit: int | None
:param allow_dirty_read: Allow reads from followers in a cluster.
:type allow_dirty_read: bool
:param sort: Document sort parameters
:type sort: Jsons | None
:return: Document cursor.
:rtype: arango.cursor.Cursor
:raise arango.exceptions.DocumentGetError: If retrieval fails.
:raise arango.exceptions.SortValidationError: If sort parameters are invalid.
"""
assert isinstance(filters, dict), "filters must be a dict"
assert is_none_or_int(skip), "skip must be a non-negative int"
assert is_none_or_int(limit), "limit must be a non-negative int"
if sort:
validate_sort_parameters(sort)

skip_val = skip if skip is not None else 0
limit_val = limit if limit is not None else "null"
query = f"""
FOR doc IN @@collection
{build_filter_conditions(filters)}
LIMIT {skip_val}, {limit_val}
{build_sort_expression(sort)}
RETURN doc
"""

bind_vars = {"@collection": self.name}

request = Request(
Expand Down
7 changes: 7 additions & 0 deletions arango/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,3 +1074,10 @@ class JWTRefreshError(ArangoClientError):

class JWTExpiredError(ArangoClientError):
"""JWT token has expired."""


###################################
# Parameter Validation Exceptions #
###################################
class SortValidationError(ArangoClientError):
"""Invalid sort parameters."""
43 changes: 41 additions & 2 deletions arango/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from contextlib import contextmanager
from typing import Any, Iterator, Sequence, Union

from arango.exceptions import DocumentParseError
from arango.typings import Json
from arango.exceptions import DocumentParseError, SortValidationError
from arango.typings import Json, Jsons


@contextmanager
Expand Down Expand Up @@ -126,3 +126,42 @@ def build_filter_conditions(filters: Json) -> str:
conditions.append(f"doc.{field} == {json.dumps(v)}")

return "FILTER " + " AND ".join(conditions)


def validate_sort_parameters(sort: Sequence[Json]) -> bool:
"""Validate sort parameters for an AQL query.

:param sort: Document sort parameters.
:type sort: Sequence[Json]
:return: Validation success.
:rtype: bool
:raise arango.exceptions.SortValidationError: If sort parameters are invalid.
"""
assert isinstance(sort, Sequence)
for param in sort:
if "sort_by" not in param or "sort_order" not in param:
raise SortValidationError(
"Each sort parameter must have 'sort_by' and 'sort_order'."
)
if param["sort_order"].upper() not in ["ASC", "DESC"]:
raise SortValidationError("'sort_order' must be either 'ASC' or 'DESC'")
return True


def build_sort_expression(sort: Jsons | None) -> str:
"""Build a sort condition for an AQL query.

:param sort: Document sort parameters.
:type sort: Jsons | None
:return: The complete AQL sort condition.
:rtype: str
"""
if not sort:
return ""

sort_chunks = []
for sort_param in sort:
chunk = f"doc.{sort_param['sort_by']} {sort_param['sort_order']}"
sort_chunks.append(chunk)

return "SORT " + ", ".join(sort_chunks)
6 changes: 6 additions & 0 deletions docs/document.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ Standard documents are managed via collection API wrapper:
assert student['GPA'] == 3.6
assert student['last'] == 'Kim'

# Retrieve one or more matching documents, sorted by a field.
for student in students.find({'first': 'John'}, sort=[{'sort_by': 'GPA', 'sort_order': 'DESC'}]):
assert student['_key'] == 'john'
assert student['GPA'] == 3.6
assert student['last'] == 'Kim'

# Retrieve a document by key.
students.get('john')

Expand Down
20 changes: 20 additions & 0 deletions tests/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,26 @@ def test_document_find(col, bad_col, docs):
# Set up test documents
col.import_bulk(docs)

# Test find with sort expression (single field)
found = list(col.find({}, sort=[{"sort_by": "text", "sort_order": "ASC"}]))
assert len(found) == 6
assert found[0]["text"] == "bar"
assert found[-1]["text"] == "foo"

# Test find with sort expression (multiple fields)
found = list(
col.find(
{},
sort=[
{"sort_by": "text", "sort_order": "ASC"},
{"sort_by": "val", "sort_order": "DESC"},
],
)
)
assert len(found) == 6
assert found[0]["val"] == 6
assert found[-1]["val"] == 1

# Test find (single match) with default options
found = list(col.find({"val": 2}))
assert len(found) == 1
Expand Down
Loading