Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

0.7.5 Cherry Picks #1967

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions aries_cloudagent/connections/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import logging
from typing import List, Sequence, Tuple
from typing import Optional, List, Sequence, Tuple, Text

from pydid import (
BaseDIDDocument as ResolvedDocument,
Expand Down Expand Up @@ -223,7 +223,9 @@ async def remove_keys_for_did(self, did: str):
storage: BaseStorage = session.inject(BaseStorage)
await storage.delete_all_records(self.RECORD_TYPE_DID_KEY, {"did": did})

async def resolve_invitation(self, did: str):
async def resolve_invitation(
self, did: str, service_accept: Optional[Sequence[Text]] = None
):
"""
Resolve invitation with the DID Resolver.

Expand All @@ -237,7 +239,7 @@ async def resolve_invitation(self, did: str):

resolver = self._profile.inject(DIDResolver)
try:
doc_dict: dict = await resolver.resolve(self._profile, did)
doc_dict: dict = await resolver.resolve(self._profile, did, service_accept)
doc: ResolvedDocument = pydid.deserialize_document(doc_dict, strict=True)
except ResolverError as error:
raise BaseConnectionManagerError(
Expand Down
15 changes: 8 additions & 7 deletions aries_cloudagent/core/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,8 @@ async def start(self) -> None:
async def stop(self, timeout=1.0):
"""Stop the agent."""
# notify protcols that we are shutting down
await self.root_profile.notify(SHUTDOWN_EVENT_TOPIC, {})
if self.root_profile:
await self.root_profile.notify(SHUTDOWN_EVENT_TOPIC, {})

shutdown = TaskQueue()
if self.dispatcher:
Expand All @@ -485,13 +486,13 @@ async def stop(self, timeout=1.0):
if self.outbound_transport_manager:
shutdown.run(self.outbound_transport_manager.stop())

# close multitenant profiles
multitenant_mgr = self.context.inject_or(BaseMultitenantManager)
if multitenant_mgr:
for profile in multitenant_mgr.open_profiles:
shutdown.run(profile.close())

if self.root_profile:
# close multitenant profiles
multitenant_mgr = self.context.inject_or(BaseMultitenantManager)
if multitenant_mgr:
for profile in multitenant_mgr.open_profiles:
shutdown.run(profile.close())

shutdown.run(self.root_profile.close())

await shutdown.complete(timeout)
Expand Down
72 changes: 67 additions & 5 deletions aries_cloudagent/core/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os
import warnings

from typing import Callable, Coroutine, Union
from typing import Callable, Coroutine, Optional, Union, Tuple
import weakref

from aiohttp.web import HTTPException
Expand All @@ -36,6 +36,13 @@

from .error import ProtocolMinorVersionNotSupported
from .protocol_registry import ProtocolRegistry
from .util import (
get_version_from_message_type,
validate_get_response_version,
# WARNING_DEGRADED_FEATURES,
# WARNING_VERSION_MISMATCH,
# WARNING_VERSION_NOT_SUPPORTED,
)

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -133,16 +140,22 @@ async def handle_message(
inbound_message: The inbound message instance
send_outbound: Async function to send outbound messages

# Raises:
# MessageParseError: If the message type version is not supported

Returns:
The response from the handler

"""
r_time = get_timer()

error_result = None
version_warning = None
message = None
try:
message = await self.make_message(inbound_message.payload)
(message, warning) = await self.make_message(
profile, inbound_message.payload
)
except ProblemReportParseError:
pass # avoid problem report recursion
except MessageParseError as e:
Expand All @@ -155,6 +168,47 @@ async def handle_message(
)
if inbound_message.receipt.thread_id:
error_result.assign_thread_id(inbound_message.receipt.thread_id)
# if warning:
# warning_message_type = inbound_message.payload.get("@type")
# if warning == WARNING_DEGRADED_FEATURES:
# LOGGER.error(
# f"Sending {WARNING_DEGRADED_FEATURES} problem report, "
# "message type received with a minor version at or higher"
# " than protocol minimum supported and current minor version "
# f"for message_type {warning_message_type}"
# )
# version_warning = ProblemReport(
# description={
# "en": (
# "message type received with a minor version at or "
# "higher than protocol minimum supported and current"
# f" minor version for message_type {warning_message_type}"
# ),
# "code": WARNING_DEGRADED_FEATURES,
# }
# )
# elif warning == WARNING_VERSION_MISMATCH:
# LOGGER.error(
# f"Sending {WARNING_VERSION_MISMATCH} problem report, message "
# "type received with a minor version higher than current minor "
# f"version for message_type {warning_message_type}"
# )
# version_warning = ProblemReport(
# description={
# "en": (
# "message type received with a minor version higher"
# " than current minor version for message_type"
# f" {warning_message_type}"
# ),
# "code": WARNING_VERSION_MISMATCH,
# }
# )
# elif warning == WARNING_VERSION_NOT_SUPPORTED:
# raise MessageParseError(
# f"Message type version not supported for {warning_message_type}"
# )
# if version_warning and inbound_message.receipt.thread_id:
# version_warning.assign_thread_id(inbound_message.receipt.thread_id)

trace_event(
self.profile.settings,
Expand Down Expand Up @@ -199,6 +253,8 @@ async def handle_message(

if error_result:
await responder.send_reply(error_result)
elif version_warning:
await responder.send_reply(version_warning)
elif context.message:
context.injector.bind_instance(BaseResponder, responder)

Expand All @@ -215,7 +271,9 @@ async def handle_message(
perf_counter=r_time,
)

async def make_message(self, parsed_msg: dict) -> BaseMessage:
async def make_message(
self, profile: Profile, parsed_msg: dict
) -> Tuple[BaseMessage, Optional[str]]:
"""
Deserialize a message dict into the appropriate message instance.

Expand All @@ -224,6 +282,7 @@ async def make_message(self, parsed_msg: dict) -> BaseMessage:

Args:
parsed_msg: The parsed message
profile: Profile

Returns:
An instance of the corresponding message class for this message
Expand All @@ -237,6 +296,7 @@ async def make_message(self, parsed_msg: dict) -> BaseMessage:
if not isinstance(parsed_msg, dict):
raise MessageParseError("Expected a JSON object")
message_type = parsed_msg.get("@type")
message_type_rec_version = get_version_from_message_type(message_type)

if not message_type:
raise MessageParseError("Message does not contain '@type' parameter")
Expand All @@ -256,8 +316,10 @@ async def make_message(self, parsed_msg: dict) -> BaseMessage:
if "/problem-report" in message_type:
raise ProblemReportParseError("Error parsing problem report message")
raise MessageParseError(f"Error deserializing message: {e}") from e

return instance
_, warning = await validate_get_response_version(
profile, message_type_rec_version, message_cls
)
return (instance, warning)

async def complete(self, timeout: float = 0.1):
"""Wait for pending tasks to complete."""
Expand Down
105 changes: 91 additions & 14 deletions aries_cloudagent/core/protocol_registry.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Handle registration and publication of supported protocols."""

import logging
import re

from typing import Mapping, Sequence

from ..config.injection_context import InjectionContext
from ..utils.classloader import ClassLoader

from .error import ProtocolMinorVersionNotSupported
from .error import ProtocolMinorVersionNotSupported, ProtocolDefinitionValidationError

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -74,6 +75,79 @@ def parse_type_string(self, message_type):
"minor_version": int(version_string_tokens[1]),
}

def create_msg_types_for_minor_version(self, typesets, version_definition):
"""
Return mapping of message type to module path for minor versions.

Args:
typesets: Mappings of message types to register
version_definition: Optional version definition dict

Returns:
Typesets mapping

"""
updated_typeset = {}
curr_minor_version = version_definition["current_minor_version"]
min_minor_version = version_definition["minimum_minor_version"]
major_version = version_definition["major_version"]
if curr_minor_version >= min_minor_version:
for version_index in range(min_minor_version, curr_minor_version + 1):
to_check = f"{str(major_version)}.{str(version_index)}"
updated_typeset.update(
self._get_updated_typeset_dict(typesets, to_check, updated_typeset)
)
else:
raise ProtocolDefinitionValidationError(
"min_minor_version is greater than curr_minor_version for the"
f" following typeset: {str(typesets)}"
)
return (updated_typeset,)

def _get_updated_typeset_dict(self, typesets, to_check, updated_typeset) -> dict:
for typeset in typesets:
for msg_type_string, module_path in typeset.items():
updated_msg_type_string = re.sub(
r"(\d+\.)?(\*|\d+)", to_check, msg_type_string
)
updated_typeset[updated_msg_type_string] = module_path
return updated_typeset

def _message_type_check_for_minor_verssion(self, version_definition) -> bool:
if not version_definition:
return False
curr_minor_version = version_definition["current_minor_version"]
min_minor_version = version_definition["minimum_minor_version"]
return bool(curr_minor_version >= 1 and curr_minor_version >= min_minor_version)

def _create_and_register_updated_typesets(self, typesets, version_definition):
updated_typesets = self.create_msg_types_for_minor_version(
typesets, version_definition
)
update_flag = False
for typeset in updated_typesets:
if typeset:
self._typemap.update(typeset)
update_flag = True
if update_flag:
return updated_typesets
else:
return None

def _update_version_map(self, message_type_string, module_path, version_definition):
parsed_type_string = self.parse_type_string(message_type_string)

if version_definition["major_version"] not in self._versionmap:
self._versionmap[version_definition["major_version"]] = []

self._versionmap[version_definition["major_version"]].append(
{
"parsed_type_string": parsed_type_string,
"version_definition": version_definition,
"message_module": module_path,
}
)

def register_message_types(self, *typesets, version_definition=None):
"""
Add new supported message types.
Expand All @@ -85,24 +159,27 @@ def register_message_types(self, *typesets, version_definition=None):
"""

# Maintain support for versionless protocol modules
for typeset in typesets:
self._typemap.update(typeset)
updated_typesets = None
minor_versions_supported = self._message_type_check_for_minor_verssion(
version_definition
)
if not minor_versions_supported:
for typeset in typesets:
self._typemap.update(typeset)

# Track versioned modules for version routing
if version_definition:
# create updated typesets for minor versions and register them
if minor_versions_supported:
updated_typesets = self._create_and_register_updated_typesets(
typesets, version_definition
)
if updated_typesets:
typesets = updated_typesets
for typeset in typesets:
for message_type_string, module_path in typeset.items():
parsed_type_string = self.parse_type_string(message_type_string)

if version_definition["major_version"] not in self._versionmap:
self._versionmap[version_definition["major_version"]] = []

self._versionmap[version_definition["major_version"]].append(
{
"parsed_type_string": parsed_type_string,
"version_definition": version_definition,
"message_module": module_path,
}
self._update_version_map(
message_type_string, module_path, version_definition
)

def register_controllers(self, *controller_sets, version_definition=None):
Expand Down
Loading