diff --git a/aries_cloudagent/connections/base_manager.py b/aries_cloudagent/connections/base_manager.py index bd5e281694..7bb4ed0025 100644 --- a/aries_cloudagent/connections/base_manager.py +++ b/aries_cloudagent/connections/base_manager.py @@ -5,7 +5,7 @@ """ import logging -from typing import List, Sequence, Tuple +from typing import Optional, List, Sequence, Tuple, Text from pydid import ( BaseDIDDocument as ResolvedDocument, @@ -223,7 +223,9 @@ async def remove_keys_for_did(self, did: str): storage: BaseStorage = session.inject(BaseStorage) await storage.delete_all_records(self.RECORD_TYPE_DID_KEY, {"did": did}) - async def resolve_invitation(self, did: str): + async def resolve_invitation( + self, did: str, service_accept: Optional[Sequence[Text]] = None + ): """ Resolve invitation with the DID Resolver. @@ -237,7 +239,7 @@ async def resolve_invitation(self, did: str): resolver = self._profile.inject(DIDResolver) try: - doc_dict: dict = await resolver.resolve(self._profile, did) + doc_dict: dict = await resolver.resolve(self._profile, did, service_accept) doc: ResolvedDocument = pydid.deserialize_document(doc_dict, strict=True) except ResolverError as error: raise BaseConnectionManagerError( diff --git a/aries_cloudagent/core/conductor.py b/aries_cloudagent/core/conductor.py index 377e04b6f0..f3dcd94c58 100644 --- a/aries_cloudagent/core/conductor.py +++ b/aries_cloudagent/core/conductor.py @@ -473,7 +473,8 @@ async def start(self) -> None: async def stop(self, timeout=1.0): """Stop the agent.""" # notify protcols that we are shutting down - await self.root_profile.notify(SHUTDOWN_EVENT_TOPIC, {}) + if self.root_profile: + await self.root_profile.notify(SHUTDOWN_EVENT_TOPIC, {}) shutdown = TaskQueue() if self.dispatcher: @@ -485,13 +486,13 @@ async def stop(self, timeout=1.0): if self.outbound_transport_manager: shutdown.run(self.outbound_transport_manager.stop()) - # close multitenant profiles - multitenant_mgr = self.context.inject_or(BaseMultitenantManager) - if multitenant_mgr: - for profile in multitenant_mgr.open_profiles: - shutdown.run(profile.close()) - if self.root_profile: + # close multitenant profiles + multitenant_mgr = self.context.inject_or(BaseMultitenantManager) + if multitenant_mgr: + for profile in multitenant_mgr.open_profiles: + shutdown.run(profile.close()) + shutdown.run(self.root_profile.close()) await shutdown.complete(timeout) diff --git a/aries_cloudagent/core/dispatcher.py b/aries_cloudagent/core/dispatcher.py index e3f37b45ac..2193dd20b9 100644 --- a/aries_cloudagent/core/dispatcher.py +++ b/aries_cloudagent/core/dispatcher.py @@ -10,7 +10,7 @@ import os import warnings -from typing import Callable, Coroutine, Union +from typing import Callable, Coroutine, Optional, Union, Tuple import weakref from aiohttp.web import HTTPException @@ -36,6 +36,13 @@ from .error import ProtocolMinorVersionNotSupported from .protocol_registry import ProtocolRegistry +from .util import ( + get_version_from_message_type, + validate_get_response_version, + # WARNING_DEGRADED_FEATURES, + # WARNING_VERSION_MISMATCH, + # WARNING_VERSION_NOT_SUPPORTED, +) LOGGER = logging.getLogger(__name__) @@ -133,6 +140,9 @@ async def handle_message( inbound_message: The inbound message instance send_outbound: Async function to send outbound messages + # Raises: + # MessageParseError: If the message type version is not supported + Returns: The response from the handler @@ -140,9 +150,12 @@ async def handle_message( r_time = get_timer() error_result = None + version_warning = None message = None try: - message = await self.make_message(inbound_message.payload) + (message, warning) = await self.make_message( + profile, inbound_message.payload + ) except ProblemReportParseError: pass # avoid problem report recursion except MessageParseError as e: @@ -155,6 +168,47 @@ async def handle_message( ) if inbound_message.receipt.thread_id: error_result.assign_thread_id(inbound_message.receipt.thread_id) + # if warning: + # warning_message_type = inbound_message.payload.get("@type") + # if warning == WARNING_DEGRADED_FEATURES: + # LOGGER.error( + # f"Sending {WARNING_DEGRADED_FEATURES} problem report, " + # "message type received with a minor version at or higher" + # " than protocol minimum supported and current minor version " + # f"for message_type {warning_message_type}" + # ) + # version_warning = ProblemReport( + # description={ + # "en": ( + # "message type received with a minor version at or " + # "higher than protocol minimum supported and current" + # f" minor version for message_type {warning_message_type}" + # ), + # "code": WARNING_DEGRADED_FEATURES, + # } + # ) + # elif warning == WARNING_VERSION_MISMATCH: + # LOGGER.error( + # f"Sending {WARNING_VERSION_MISMATCH} problem report, message " + # "type received with a minor version higher than current minor " + # f"version for message_type {warning_message_type}" + # ) + # version_warning = ProblemReport( + # description={ + # "en": ( + # "message type received with a minor version higher" + # " than current minor version for message_type" + # f" {warning_message_type}" + # ), + # "code": WARNING_VERSION_MISMATCH, + # } + # ) + # elif warning == WARNING_VERSION_NOT_SUPPORTED: + # raise MessageParseError( + # f"Message type version not supported for {warning_message_type}" + # ) + # if version_warning and inbound_message.receipt.thread_id: + # version_warning.assign_thread_id(inbound_message.receipt.thread_id) trace_event( self.profile.settings, @@ -199,6 +253,8 @@ async def handle_message( if error_result: await responder.send_reply(error_result) + elif version_warning: + await responder.send_reply(version_warning) elif context.message: context.injector.bind_instance(BaseResponder, responder) @@ -215,7 +271,9 @@ async def handle_message( perf_counter=r_time, ) - async def make_message(self, parsed_msg: dict) -> BaseMessage: + async def make_message( + self, profile: Profile, parsed_msg: dict + ) -> Tuple[BaseMessage, Optional[str]]: """ Deserialize a message dict into the appropriate message instance. @@ -224,6 +282,7 @@ async def make_message(self, parsed_msg: dict) -> BaseMessage: Args: parsed_msg: The parsed message + profile: Profile Returns: An instance of the corresponding message class for this message @@ -237,6 +296,7 @@ async def make_message(self, parsed_msg: dict) -> BaseMessage: if not isinstance(parsed_msg, dict): raise MessageParseError("Expected a JSON object") message_type = parsed_msg.get("@type") + message_type_rec_version = get_version_from_message_type(message_type) if not message_type: raise MessageParseError("Message does not contain '@type' parameter") @@ -256,8 +316,10 @@ async def make_message(self, parsed_msg: dict) -> BaseMessage: if "/problem-report" in message_type: raise ProblemReportParseError("Error parsing problem report message") raise MessageParseError(f"Error deserializing message: {e}") from e - - return instance + _, warning = await validate_get_response_version( + profile, message_type_rec_version, message_cls + ) + return (instance, warning) async def complete(self, timeout: float = 0.1): """Wait for pending tasks to complete.""" diff --git a/aries_cloudagent/core/protocol_registry.py b/aries_cloudagent/core/protocol_registry.py index 805c35efa7..90175f9109 100644 --- a/aries_cloudagent/core/protocol_registry.py +++ b/aries_cloudagent/core/protocol_registry.py @@ -1,13 +1,14 @@ """Handle registration and publication of supported protocols.""" import logging +import re from typing import Mapping, Sequence from ..config.injection_context import InjectionContext from ..utils.classloader import ClassLoader -from .error import ProtocolMinorVersionNotSupported +from .error import ProtocolMinorVersionNotSupported, ProtocolDefinitionValidationError LOGGER = logging.getLogger(__name__) @@ -74,6 +75,79 @@ def parse_type_string(self, message_type): "minor_version": int(version_string_tokens[1]), } + def create_msg_types_for_minor_version(self, typesets, version_definition): + """ + Return mapping of message type to module path for minor versions. + + Args: + typesets: Mappings of message types to register + version_definition: Optional version definition dict + + Returns: + Typesets mapping + + """ + updated_typeset = {} + curr_minor_version = version_definition["current_minor_version"] + min_minor_version = version_definition["minimum_minor_version"] + major_version = version_definition["major_version"] + if curr_minor_version >= min_minor_version: + for version_index in range(min_minor_version, curr_minor_version + 1): + to_check = f"{str(major_version)}.{str(version_index)}" + updated_typeset.update( + self._get_updated_typeset_dict(typesets, to_check, updated_typeset) + ) + else: + raise ProtocolDefinitionValidationError( + "min_minor_version is greater than curr_minor_version for the" + f" following typeset: {str(typesets)}" + ) + return (updated_typeset,) + + def _get_updated_typeset_dict(self, typesets, to_check, updated_typeset) -> dict: + for typeset in typesets: + for msg_type_string, module_path in typeset.items(): + updated_msg_type_string = re.sub( + r"(\d+\.)?(\*|\d+)", to_check, msg_type_string + ) + updated_typeset[updated_msg_type_string] = module_path + return updated_typeset + + def _message_type_check_for_minor_verssion(self, version_definition) -> bool: + if not version_definition: + return False + curr_minor_version = version_definition["current_minor_version"] + min_minor_version = version_definition["minimum_minor_version"] + return bool(curr_minor_version >= 1 and curr_minor_version >= min_minor_version) + + def _create_and_register_updated_typesets(self, typesets, version_definition): + updated_typesets = self.create_msg_types_for_minor_version( + typesets, version_definition + ) + update_flag = False + for typeset in updated_typesets: + if typeset: + self._typemap.update(typeset) + update_flag = True + if update_flag: + return updated_typesets + else: + return None + + def _update_version_map(self, message_type_string, module_path, version_definition): + parsed_type_string = self.parse_type_string(message_type_string) + + if version_definition["major_version"] not in self._versionmap: + self._versionmap[version_definition["major_version"]] = [] + + self._versionmap[version_definition["major_version"]].append( + { + "parsed_type_string": parsed_type_string, + "version_definition": version_definition, + "message_module": module_path, + } + ) + def register_message_types(self, *typesets, version_definition=None): """ Add new supported message types. @@ -85,24 +159,27 @@ def register_message_types(self, *typesets, version_definition=None): """ # Maintain support for versionless protocol modules - for typeset in typesets: - self._typemap.update(typeset) + updated_typesets = None + minor_versions_supported = self._message_type_check_for_minor_verssion( + version_definition + ) + if not minor_versions_supported: + for typeset in typesets: + self._typemap.update(typeset) # Track versioned modules for version routing if version_definition: + # create updated typesets for minor versions and register them + if minor_versions_supported: + updated_typesets = self._create_and_register_updated_typesets( + typesets, version_definition + ) + if updated_typesets: + typesets = updated_typesets for typeset in typesets: for message_type_string, module_path in typeset.items(): - parsed_type_string = self.parse_type_string(message_type_string) - - if version_definition["major_version"] not in self._versionmap: - self._versionmap[version_definition["major_version"]] = [] - - self._versionmap[version_definition["major_version"]].append( - { - "parsed_type_string": parsed_type_string, - "version_definition": version_definition, - "message_module": module_path, - } + self._update_version_map( + message_type_string, module_path, version_definition ) def register_controllers(self, *controller_sets, version_definition=None): diff --git a/aries_cloudagent/core/tests/test_dispatcher.py b/aries_cloudagent/core/tests/test_dispatcher.py index b9f00564b1..66b20fbac3 100644 --- a/aries_cloudagent/core/tests/test_dispatcher.py +++ b/aries_cloudagent/core/tests/test_dispatcher.py @@ -1,8 +1,9 @@ import json +from async_case import IsolatedAsyncioTestCase +import mock as async_mock import pytest -from asynctest import TestCase as AsyncTestCase, mock as async_mock from marshmallow import EXCLUDE from ...config.injection_context import InjectionContext @@ -87,7 +88,7 @@ async def handle(self, context, responder): pass -class TestDispatcher(AsyncTestCase): +class TestDispatcher(IsolatedAsyncioTestCase): async def test_dispatch(self): profile = make_profile() registry = profile.inject(ProtocolRegistry) @@ -108,9 +109,17 @@ async def test_dispatch(self): StubAgentMessageHandler, "handle", autospec=True ) as handler_mock, async_mock.patch.object( test_module, "ConnectionManager", autospec=True - ) as conn_mgr_mock: + ) as conn_mgr_mock, async_mock.patch.object( + test_module, + "get_version_from_message_type", + async_mock.AsyncMock(return_value="1.1"), + ), async_mock.patch.object( + test_module, + "validate_get_response_version", + async_mock.AsyncMock(return_value=("1.1", None)), + ): conn_mgr_mock.return_value = async_mock.MagicMock( - find_inbound_connection=async_mock.CoroutineMock( + find_inbound_connection=async_mock.AsyncMock( return_value=async_mock.MagicMock(connection_id="dummy") ) ) @@ -149,7 +158,15 @@ async def test_dispatch_versioned_message(self): with async_mock.patch.object( StubAgentMessageHandler, "handle", autospec=True - ) as handler_mock: + ) as handler_mock, async_mock.patch.object( + test_module, + "get_version_from_message_type", + async_mock.AsyncMock(return_value="1.1"), + ), async_mock.patch.object( + test_module, + "validate_get_response_version", + async_mock.AsyncMock(return_value=("1.1", None)), + ): await dispatcher.queue_message( dispatcher.profile, make_inbound(message), rcv.send ) @@ -262,7 +279,15 @@ async def test_dispatch_versioned_message_handle_greater_succeeds(self): with async_mock.patch.object( StubAgentMessageHandler, "handle", autospec=True - ) as handler_mock: + ) as handler_mock, async_mock.patch.object( + test_module, + "get_version_from_message_type", + async_mock.AsyncMock(return_value="1.1"), + ), async_mock.patch.object( + test_module, + "validate_get_response_version", + async_mock.AsyncMock(return_value=("1.1", None)), + ): await dispatcher.queue_message( dispatcher.profile, make_inbound(message), rcv.send ) @@ -314,17 +339,22 @@ async def test_bad_message_dispatch_parse_x(self): await dispatcher.setup() rcv = Receiver() bad_messages = ["not even a dict", {"bad": "message"}] - for bad in bad_messages: - await dispatcher.queue_message( - dispatcher.profile, make_inbound(bad), rcv.send - ) - await dispatcher.task_queue - assert rcv.messages and isinstance(rcv.messages[0][1], OutboundMessage) - payload = json.loads(rcv.messages[0][1].payload) - assert payload["@type"] == DIDCommPrefix.qualify_current( - ProblemReport.Meta.message_type - ) - rcv.messages.clear() + with async_mock.patch.object( + test_module, "get_version_from_message_type", async_mock.AsyncMock() + ), async_mock.patch.object( + test_module, "validate_get_response_version", async_mock.AsyncMock() + ): + for bad in bad_messages: + await dispatcher.queue_message( + dispatcher.profile, make_inbound(bad), rcv.send + ) + await dispatcher.task_queue + assert rcv.messages and isinstance(rcv.messages[0][1], OutboundMessage) + payload = json.loads(rcv.messages[0][1].payload) + assert payload["@type"] == DIDCommPrefix.qualify_current( + ProblemReport.Meta.message_type + ) + rcv.messages.clear() async def test_bad_message_dispatch_problem_report_x(self): profile = make_profile() @@ -383,7 +413,7 @@ async def test_create_send_outbound(self): message = StubAgentMessage() responder = test_module.DispatcherResponder(context, message, None) outbound_message = await responder.create_outbound(message) - with async_mock.patch.object(responder, "_send", async_mock.CoroutineMock()): + with async_mock.patch.object(responder, "_send", async_mock.AsyncMock()): await responder.send_outbound(outbound_message) async def test_create_send_webhook(self): @@ -400,7 +430,7 @@ async def test_create_enc_outbound(self): message = b"abc123xyz7890000" responder = test_module.DispatcherResponder(context, message, None) with async_mock.patch.object( - responder, "send_outbound", async_mock.CoroutineMock() + responder, "send_outbound", async_mock.AsyncMock() ) as mock_send_outbound: await responder.send(message) assert mock_send_outbound.called_once() @@ -421,3 +451,91 @@ def _smaller_scope(): with self.assertRaises(RuntimeError): await responder.send_webhook("test", {}) + + # async def test_dispatch_version_with_degraded_features(self): + # profile = make_profile() + # registry = profile.inject(ProtocolRegistry) + # registry.register_message_types( + # { + # pfx.qualify(StubAgentMessage.Meta.message_type): StubAgentMessage + # for pfx in DIDCommPrefix + # } + # ) + # dispatcher = test_module.Dispatcher(profile) + # await dispatcher.setup() + # rcv = Receiver() + # message = { + # "@type": DIDCommPrefix.qualify_current(StubAgentMessage.Meta.message_type) + # } + + # with async_mock.patch.object( + # test_module, + # "get_version_from_message_type", + # async_mock.AsyncMock(return_value="1.1"), + # ), async_mock.patch.object( + # test_module, + # "validate_get_response_version", + # async_mock.AsyncMock(return_value=("1.1", "fields-ignored-due-to-version-mismatch")), + # ): + # await dispatcher.queue_message( + # dispatcher.profile, make_inbound(message), rcv.send + # ) + + # async def test_dispatch_fields_ignored_due_to_version_mismatch(self): + # profile = make_profile() + # registry = profile.inject(ProtocolRegistry) + # registry.register_message_types( + # { + # pfx.qualify(StubAgentMessage.Meta.message_type): StubAgentMessage + # for pfx in DIDCommPrefix + # } + # ) + # dispatcher = test_module.Dispatcher(profile) + # await dispatcher.setup() + # rcv = Receiver() + # message = { + # "@type": DIDCommPrefix.qualify_current(StubAgentMessage.Meta.message_type) + # } + + # with async_mock.patch.object( + # test_module, + # "get_version_from_message_type", + # async_mock.AsyncMock(return_value="1.1"), + # ), async_mock.patch.object( + # test_module, + # "validate_get_response_version", + # async_mock.AsyncMock(return_value=("1.1", "version-with-degraded-features")), + # ): + # await dispatcher.queue_message( + # dispatcher.profile, make_inbound(message), rcv.send + # ) + + # async def test_dispatch_version_not_supported(self): + # profile = make_profile() + # registry = profile.inject(ProtocolRegistry) + # registry.register_message_types( + # { + # pfx.qualify(StubAgentMessage.Meta.message_type): StubAgentMessage + # for pfx in DIDCommPrefix + # } + # ) + # dispatcher = test_module.Dispatcher(profile) + # await dispatcher.setup() + # rcv = Receiver() + # message = { + # "@type": DIDCommPrefix.qualify_current(StubAgentMessage.Meta.message_type) + # } + + # with async_mock.patch.object( + # test_module, + # "get_version_from_message_type", + # async_mock.AsyncMock(return_value="1.1"), + # ), async_mock.patch.object( + # test_module, + # "validate_get_response_version", + # async_mock.AsyncMock(return_value=("1.1", "version-not-supported")), + # ): + # with self.assertRaises(test_module.MessageParseError): + # await dispatcher.queue_message( + # dispatcher.profile, make_inbound(message), rcv.send + # ) diff --git a/aries_cloudagent/core/tests/test_protocol_registry.py b/aries_cloudagent/core/tests/test_protocol_registry.py index 5c43668d8b..623ad1d808 100644 --- a/aries_cloudagent/core/tests/test_protocol_registry.py +++ b/aries_cloudagent/core/tests/test_protocol_registry.py @@ -44,6 +44,162 @@ def test_message_type_query(self): matches = self.registry.protocols_matching_query(q) assert matches == () + def test_create_msg_types_for_minor_version(self): + MSG_PATH = "aries_cloudagent.protocols.introduction.v0_1.messages" + test_typesets = ( + { + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/1.0/fake-forward-invitation": f"{MSG_PATH}.forward_invitation.ForwardInvitation", + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/1.0/fake-invitation": f"{MSG_PATH}.invitation.Invitation", + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/1.0/fake-invitation-request": f"{MSG_PATH}.invitation_request.InvitationRequest", + "https://didcom.org/introduction-service/1.0/fake-forward-invitation": f"{MSG_PATH}.forward_invitation.ForwardInvitation", + "https://didcom.org/introduction-service/1.0/fake-invitation": f"{MSG_PATH}.invitation.Invitation", + "https://didcom.org/introduction-service/1.0/fake-invitation-request": f"{MSG_PATH}.invitation_request.InvitationRequest", + }, + ) + test_version_def = { + "current_minor_version": 0, + "major_version": 1, + "minimum_minor_version": 0, + "path": "v0_1", + } + updated_typesets = self.registry.create_msg_types_for_minor_version( + test_typesets, test_version_def + ) + updated_typeset = updated_typesets[0] + assert ( + "https://didcom.org/introduction-service/1.0/fake-forward-invitation" + in updated_typeset + ) + assert ( + "https://didcom.org/introduction-service/1.0/fake-invitation" + in updated_typeset + ) + assert ( + "https://didcom.org/introduction-service/1.0/fake-invitation-request" + in updated_typeset + ) + assert ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/1.0/fake-forward-invitation" + in updated_typeset + ) + + def test_introduction_create_msg_types_for_minor_version(self): + MSG_PATH = "aries_cloudagent.protocols.introduction.v0_1.messages" + test_typesets = ( + { + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/0.1/invitation-request": f"{MSG_PATH}.invitation_request.InvitationRequest", + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/0.1/invitation": f"{MSG_PATH}.invitation.Invitation", + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/0.1/forward-invitation": f"{MSG_PATH}.invitation_messages.forward_invitation.ForwardInvitation", + "https://didcom.org/introduction-service/0.1/invitation-request": f"{MSG_PATH}.invitation_request.InvitationRequest", + "https://didcom.org/introduction-service/0.1/invitation": f"{MSG_PATH}.invitation.Invitation", + "https://didcom.org/introduction-service/0.1/forward-invitation": f"{MSG_PATH}.forward_invitation.ForwardInvitation", + }, + ) + test_version_def = { + "current_minor_version": 1, + "major_version": 0, + "minimum_minor_version": 1, + "path": "v0_1", + } + updated_typesets = self.registry.create_msg_types_for_minor_version( + test_typesets, test_version_def + ) + updated_typeset = updated_typesets[0] + assert ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/0.1/invitation-request" + in updated_typeset + ) + assert ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/0.1/invitation" + in updated_typeset + ) + assert ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/introduction-service/0.1/forward-invitation" + in updated_typeset + ) + assert ( + "https://didcom.org/introduction-service/0.1/invitation-request" + in updated_typeset + ) + assert ( + "https://didcom.org/introduction-service/0.1/invitation" in updated_typeset + ) + assert ( + "https://didcom.org/introduction-service/0.1/forward-invitation" + in updated_typeset + ) + + def test_oob_create_msg_types_for_minor_version(self): + MSG_PATH = "aries_cloudagent.protocols.out_of_band.v1_0.messages" + test_typesets = ( + { + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/invitation": f"{MSG_PATH}.invitation.Invitation", + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/handshake-reuse": f"{MSG_PATH}.reuse.HandshakeReuse", + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/handshake-reuse-accepted": f"{MSG_PATH}.reuse_accept.HandshakeReuseAccept", + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/problem_report": f"{MSG_PATH}.problem_report.OOBProblemReport", + "https://didcom.org/out-of-band/1.1/invitation": f"{MSG_PATH}.invitation.Invitation", + "https://didcom.org/out-of-band/1.1/handshake-reuse": f"{MSG_PATH}.reuse.HandshakeReuse", + "https://didcom.org/out-of-band/1.1/handshake-reuse-accepted": f"{MSG_PATH}.reuse_accept.HandshakeReuseAccept", + "https://didcom.org/out-of-band/1.1/problem_report": f"{MSG_PATH}.problem_report.OOBProblemReport", + }, + ) + test_version_def = { + "current_minor_version": 1, + "major_version": 1, + "minimum_minor_version": 0, + "path": "v0_1", + } + updated_typesets = self.registry.create_msg_types_for_minor_version( + test_typesets, test_version_def + ) + updated_typeset = updated_typesets[0] + assert "https://didcom.org/out-of-band/1.0/invitation" in updated_typeset + assert "https://didcom.org/out-of-band/1.0/handshake-reuse" in updated_typeset + assert ( + "https://didcom.org/out-of-band/1.0/handshake-reuse-accepted" + in updated_typeset + ) + assert "https://didcom.org/out-of-band/1.0/problem_report" in updated_typeset + assert "https://didcom.org/out-of-band/1.1/invitation" in updated_typeset + assert "https://didcom.org/out-of-band/1.1/handshake-reuse" in updated_typeset + assert ( + "https://didcom.org/out-of-band/1.1/handshake-reuse-accepted" + in updated_typeset + ) + assert "https://didcom.org/out-of-band/1.1/problem_report" in updated_typeset + assert ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.0/invitation" + in updated_typeset + ) + assert ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.0/handshake-reuse" + in updated_typeset + ) + assert ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.0/handshake-reuse-accepted" + in updated_typeset + ) + assert ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.0/problem_report" + in updated_typeset + ) + assert ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/invitation" + in updated_typeset + ) + assert ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/handshake-reuse" + in updated_typeset + ) + assert ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/handshake-reuse-accepted" + in updated_typeset + ) + assert ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/out-of-band/1.1/problem_report" + in updated_typeset + ) + async def test_disclosed(self): self.registry.register_message_types( {self.test_message_type: self.test_message_handler} diff --git a/aries_cloudagent/core/tests/test_util.py b/aries_cloudagent/core/tests/test_util.py new file mode 100644 index 0000000000..c25f944d74 --- /dev/null +++ b/aries_cloudagent/core/tests/test_util.py @@ -0,0 +1,82 @@ +from async_case import IsolatedAsyncioTestCase + +from ...cache.base import BaseCache +from ...cache.in_memory import InMemoryCache +from ...core.in_memory import InMemoryProfile +from ...core.profile import Profile +from ...protocols.didcomm_prefix import DIDCommPrefix +from ...protocols.introduction.v0_1.messages.invitation import Invitation +from ...protocols.out_of_band.v1_0.messages.reuse import HandshakeReuse + +from .. import util as test_module + + +def make_profile() -> Profile: + profile = InMemoryProfile.test_profile() + profile.context.injector.bind_instance(BaseCache, InMemoryCache()) + return profile + + +class TestUtils(IsolatedAsyncioTestCase): + async def test_validate_get_response_version(self): + profile = make_profile() + (resp_version, warning) = await test_module.validate_get_response_version( + profile, "1.1", HandshakeReuse + ) + assert resp_version == "1.1" + assert not warning + + # cached + (resp_version, warning) = await test_module.validate_get_response_version( + profile, "1.1", HandshakeReuse + ) + assert resp_version == "1.1" + assert not warning + + (resp_version, warning) = await test_module.validate_get_response_version( + profile, "1.0", HandshakeReuse + ) + assert resp_version == "1.0" + assert warning == test_module.WARNING_DEGRADED_FEATURES + + (resp_version, warning) = await test_module.validate_get_response_version( + profile, "1.2", HandshakeReuse + ) + assert resp_version == "1.1" + assert warning == test_module.WARNING_VERSION_MISMATCH + + with self.assertRaises(test_module.ProtocolMinorVersionNotSupported): + (resp_version, warning) = await test_module.validate_get_response_version( + profile, "0.0", Invitation + ) + + with self.assertRaises(Exception): + (resp_version, warning) = await test_module.validate_get_response_version( + profile, "1.0", Invitation + ) + + def test_get_version_from_message_type(self): + assert ( + test_module.get_version_from_message_type( + DIDCommPrefix.qualify_current("out-of-band/1.1/handshake-reuse") + ) + == "1.1" + ) + + def test_get_version_from_message(self): + assert test_module.get_version_from_message(HandshakeReuse()) == "1.1" + + async def test_get_proto_default_version_from_msg_class(self): + profile = make_profile() + assert ( + await test_module.get_proto_default_version_from_msg_class( + profile, HandshakeReuse + ) + ) == "1.1" + + def test_get_proto_default_version(self): + assert ( + test_module.get_proto_default_version( + "aries_cloudagent.protocols.out_of_band.definition" + ) + ) == "1.1" diff --git a/aries_cloudagent/core/util.py b/aries_cloudagent/core/util.py index 791f80c95d..58b7713e8f 100644 --- a/aries_cloudagent/core/util.py +++ b/aries_cloudagent/core/util.py @@ -1,10 +1,161 @@ """Core utilities and constants.""" +import inspect +import os import re +from typing import Optional, Tuple + +from ..cache.base import BaseCache +from ..core.profile import Profile +from ..messaging.agent_message import AgentMessage +from ..utils.classloader import ClassLoader + +from .error import ProtocolMinorVersionNotSupported, ProtocolDefinitionValidationError CORE_EVENT_PREFIX = "acapy::core::" STARTUP_EVENT_TOPIC = CORE_EVENT_PREFIX + "startup" STARTUP_EVENT_PATTERN = re.compile(f"^{STARTUP_EVENT_TOPIC}?$") SHUTDOWN_EVENT_TOPIC = CORE_EVENT_PREFIX + "shutdown" SHUTDOWN_EVENT_PATTERN = re.compile(f"^{SHUTDOWN_EVENT_TOPIC}?$") +WARNING_DEGRADED_FEATURES = "version-with-degraded-features" +WARNING_VERSION_MISMATCH = "fields-ignored-due-to-version-mismatch" +WARNING_VERSION_NOT_SUPPORTED = "version-not-supported" + + +async def validate_get_response_version( + profile: Profile, rec_version: str, msg_class: type +) -> Tuple[str, Optional[str]]: + """ + Return a tuple with version to respond with and warnings. + + Process received version and protocol version definition, + returns the tuple. + + Args: + profile: Profile + rec_version: received version from message + msg_class: type + + Returns: + Tuple with response version and any warnings + + """ + resp_version = rec_version + warning = None + version_string_tokens = rec_version.split(".") + rec_major_version = int(version_string_tokens[0]) + rec_minor_version = int(version_string_tokens[1]) + version_definition = await get_version_def_from_msg_class( + profile, msg_class, rec_major_version + ) + proto_major_version = int(version_definition["major_version"]) + proto_curr_minor_version = int(version_definition["current_minor_version"]) + proto_min_minor_version = int(version_definition["minimum_minor_version"]) + if rec_minor_version < proto_min_minor_version: + warning = WARNING_VERSION_NOT_SUPPORTED + elif ( + rec_minor_version >= proto_min_minor_version + and rec_minor_version < proto_curr_minor_version + ): + warning = WARNING_DEGRADED_FEATURES + elif rec_minor_version > proto_curr_minor_version: + warning = WARNING_VERSION_MISMATCH + if proto_major_version == rec_major_version: + if ( + proto_min_minor_version <= rec_minor_version + and proto_curr_minor_version >= rec_minor_version + ): + resp_version = f"{str(proto_major_version)}.{str(rec_minor_version)}" + elif rec_minor_version > proto_curr_minor_version: + resp_version = f"{str(proto_major_version)}.{str(proto_curr_minor_version)}" + elif rec_minor_version < proto_min_minor_version: + raise ProtocolMinorVersionNotSupported( + "Minimum supported minor version is " + + f"{proto_min_minor_version}." + + f" Received {rec_minor_version}." + ) + else: + raise ProtocolMinorVersionNotSupported( + f"Supported major version {proto_major_version}" + " is not same as received major version" + f" {rec_major_version}." + ) + return (resp_version, warning) + + +def get_version_from_message_type(msg_type: str) -> str: + """Return version from provided message_type.""" + return (re.search(r"(\d+\.)?(\*|\d+)", msg_type)).group() + + +def get_version_from_message(msg: AgentMessage) -> str: + """Return version from provided AgentMessage.""" + msg_type = msg._type + return get_version_from_message_type(msg_type) + + +async def get_proto_default_version_from_msg_class( + profile: Profile, msg_class: type, major_version: int = 1 +) -> str: + """Return default protocol version from version_definition.""" + version_definition = await get_version_def_from_msg_class( + profile, msg_class, major_version + ) + return _get_default_version_from_version_def(version_definition) + + +def get_proto_default_version(def_path: str, major_version: int = 1) -> str: + """Return default protocol version from version_definition.""" + version_definition = _get_version_def_from_path(def_path, major_version) + return _get_default_version_from_version_def(version_definition) + + +def _get_path_from_msg_class(msg_class: type) -> str: + path = os.path.normpath(inspect.getfile(msg_class)) + split_str = os.getenv("ACAPY_HOME") or "aries_cloudagent" + path = split_str + path.rsplit(split_str, 1)[1] + version = (re.search(r"v(\d+\_)?(\*|\d+)", path)).group() + path = path.split(version, 1)[0] + return (path.replace("/", ".")) + "definition" + + +def _get_version_def_from_path(definition_path: str, major_version: int = 1): + version_definition = None + definition = ClassLoader.load_module(definition_path) + for protocol_version in definition.versions: + if major_version == protocol_version["major_version"]: + version_definition = protocol_version + break + return version_definition + + +def _get_default_version_from_version_def(version_definition) -> str: + default_major_version = version_definition["major_version"] + default_minor_version = version_definition["current_minor_version"] + return f"{default_major_version}.{default_minor_version}" + + +async def get_version_def_from_msg_class( + profile: Profile, msg_class: type, major_version: int = 1 +): + """Return version_definition of a protocol from msg_class.""" + cache = profile.inject_or(BaseCache) + version_definition = None + if cache: + version_definition = await cache.get( + f"version_definition::{str(msg_class).lower()}" + ) + if version_definition: + return version_definition + definition_path = _get_path_from_msg_class(msg_class) + version_definition = _get_version_def_from_path(definition_path, major_version) + if not version_definition: + raise ProtocolDefinitionValidationError( + f"Unable to load protocol version_definition for {str(msg_class)}" + ) + if cache: + await cache.set( + f"version_definition::{str(msg_class).lower()}", version_definition + ) + return version_definition diff --git a/aries_cloudagent/messaging/agent_message.py b/aries_cloudagent/messaging/agent_message.py index ea39f14712..2d17f1d35a 100644 --- a/aries_cloudagent/messaging/agent_message.py +++ b/aries_cloudagent/messaging/agent_message.py @@ -1,9 +1,11 @@ """Agent message base class and schema.""" -from collections import OrderedDict -from typing import Mapping, Union import uuid +from collections import OrderedDict +from re import sub +from typing import Mapping, Optional, Union, Text + from marshmallow import ( EXCLUDE, fields, @@ -53,7 +55,13 @@ class Meta: schema_class = None message_type = None - def __init__(self, _id: str = None, _decorators: BaseDecoratorSet = None): + def __init__( + self, + _id: str = None, + _type: Optional[Text] = None, + _version: Optional[Text] = None, + _decorators: BaseDecoratorSet = None, + ): """ Initialize base agent message object. @@ -81,6 +89,12 @@ def __init__(self, _id: str = None, _decorators: BaseDecoratorSet = None): self.__class__.__name__ ) ) + if _type: + self._message_type = _type + elif _version: + self._message_type = self.get_updated_msg_type(_version) + else: + self._message_type = self.Meta.message_type # Not required for now # if not self.Meta.handler_class: # raise TypeError( @@ -118,7 +132,12 @@ def _type(self) -> str: Current DIDComm prefix, slash, message type defined on `Meta.message_type` """ - return DIDCommPrefix.qualify_current(self.Meta.message_type) + return DIDCommPrefix.qualify_current(self._message_type) + + @_type.setter + def _type(self, msg_type: str): + """Set the message type identifier.""" + self._message_type = msg_type @property def _id(self) -> str: @@ -146,6 +165,10 @@ def _decorators(self, value: BaseDecoratorSet): """Fetch the message's decorator set.""" self._message_decorators = value + def get_updated_msg_type(self, version: str) -> str: + """Update version to Meta.message_type.""" + return sub(r"(\d+\.)?(\*|\d+)", version, self.Meta.message_type) + def get_signature(self, field_name: str) -> SignatureDecorator: """ Get the signature for a named field. diff --git a/aries_cloudagent/messaging/models/base.py b/aries_cloudagent/messaging/models/base.py index fd00c7d68d..8eb0601a6c 100644 --- a/aries_cloudagent/messaging/models/base.py +++ b/aries_cloudagent/messaging/models/base.py @@ -319,7 +319,14 @@ def make_model(self, data: dict, **kwargs): A model instance """ - return self.Model(**data) + try: + cls_inst = self.Model(**data) + except TypeError as err: + if "_type" in str(err) and "_type" in data: + data["msg_type"] = data["_type"] + del data["_type"] + cls_inst = self.Model(**data) + return cls_inst @post_dump def remove_skipped_values(self, data, **kwargs): diff --git a/aries_cloudagent/protocols/out_of_band/definition.py b/aries_cloudagent/protocols/out_of_band/definition.py index 62bddef6f5..13c1f8a8ef 100644 --- a/aries_cloudagent/protocols/out_of_band/definition.py +++ b/aries_cloudagent/protocols/out_of_band/definition.py @@ -4,7 +4,7 @@ { "major_version": 1, "minimum_minor_version": 0, - "current_minor_version": 0, + "current_minor_version": 1, "path": "v1_0", } ] diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py index 0bbfc668bc..79bdc1a352 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py @@ -3,12 +3,13 @@ import asyncio import logging import re -from typing import Mapping, Optional, Sequence, Union +from typing import Mapping, Optional, Sequence, Union, Text import uuid from ....messaging.decorators.service_decorator import ServiceDecorator from ....core.event_bus import EventBus +from ....core.util import get_version_from_message from ....connections.base_manager import BaseConnectionManager from ....connections.models.conn_record import ConnRecord from ....connections.util import mediation_record_if_id @@ -39,6 +40,7 @@ from .models.invitation import InvitationRecord from .models.oob_record import OobRecord from .messages.service import Service +from .message_types import DEFAULT_VERSION LOGGER = logging.getLogger(__name__) REUSE_WEBHOOK_TOPIC = "acapy::webhook::connection_reuse" @@ -89,6 +91,8 @@ async def create_invitation( attachments: Sequence[Mapping] = None, metadata: dict = None, mediation_id: str = None, + service_accept: Optional[Sequence[Text]] = None, + protocol_version: Optional[Text] = None, ) -> InvitationRecord: """ Generate new connection invitation. @@ -107,6 +111,9 @@ async def create_invitation( multi_use: set to True to create an invitation for multiple-use connection alias: optional alias to apply to connection for later use attachments: list of dicts in form of {"id": ..., "type": ...} + service_accept: Optional list of mime types in the order of preference of + the sender that the receiver can use in responding to the message + protocol_version: OOB protocol version [1.0, 1.1] Returns: Invitation record @@ -130,7 +137,7 @@ async def create_invitation( multitenant_mgr = self.profile.inject_or(BaseMultitenantManager) wallet_id = self.profile.settings.get("wallet.id") - accept = bool( + auto_accept = bool( auto_accept or ( auto_accept is None @@ -235,6 +242,8 @@ async def create_invitation( handshake_protocols=handshake_protocols, requests_attach=message_attachments, services=[f"did:sov:{public_did.did}"], + accept=service_accept if protocol_version != "1.0" else None, + version=protocol_version or DEFAULT_VERSION, ) keylist_updates = await mediation_mgr.add_key( public_did.verkey, keylist_updates @@ -258,7 +267,7 @@ async def create_invitation( their_role=ConnRecord.Role.REQUESTER.rfc23, state=ConnRecord.State.INVITATION.rfc23, accept=ConnRecord.ACCEPT_AUTO - if accept + if auto_accept else ConnRecord.ACCEPT_MANUAL, alias=alias, connection_protocol=connection_protocol, @@ -291,7 +300,9 @@ async def create_invitation( await multitenant_mgr.add_key(wallet_id, connection_key.verkey) # Initializing InvitationMessage here to include # invitation_msg_id in webhook poyload - invi_msg = InvitationMessage(_id=invitation_message_id) + invi_msg = InvitationMessage( + _id=invitation_message_id, version=protocol_version or DEFAULT_VERSION + ) if handshake_protocols: invitation_mode = ( @@ -305,7 +316,7 @@ async def create_invitation( their_role=ConnRecord.Role.REQUESTER.rfc23, state=ConnRecord.State.INVITATION.rfc23, accept=ConnRecord.ACCEPT_AUTO - if accept + if auto_accept else ConnRecord.ACCEPT_MANUAL, invitation_mode=invitation_mode, alias=alias, @@ -371,6 +382,7 @@ async def create_invitation( invi_msg.label = my_label or self.profile.settings.get("default_label") invi_msg.handshake_protocols = handshake_protocols invi_msg.requests_attach = message_attachments + invi_msg.accept = service_accept if protocol_version != "1.0" else None invi_msg.services = [ ServiceMessage( _id="#inline", @@ -457,6 +469,9 @@ async def receive_invitation( # Get the single service item oob_service_item = invitation.services[0] + # service_accept + service_accept = invitation.accept + # Get the DID public did, if any public_did = None if isinstance(oob_service_item, str): @@ -488,7 +503,9 @@ async def receive_invitation( # Try to reuse the connection. If not accepted sets the conn_rec to None if conn_rec and not invitation.requests_attach: - oob_record = await self._handle_hanshake_reuse(oob_record, conn_rec) + oob_record = await self._handle_hanshake_reuse( + oob_record, conn_rec, get_version_from_message(invitation) + ) LOGGER.warning( f"Connection reuse request finished with state {oob_record.state}" @@ -509,6 +526,7 @@ async def receive_invitation( alias=alias, auto_accept=auto_accept, mediation_id=mediation_id, + service_accept=service_accept, ) LOGGER.debug( f"Performed handshake with connection {oob_record.connection_id}" @@ -716,10 +734,12 @@ async def _wait_for_state() -> ConnRecord: return None async def _handle_hanshake_reuse( - self, oob_record: OobRecord, conn_record: ConnRecord + self, oob_record: OobRecord, conn_record: ConnRecord, version: str ) -> OobRecord: # Send handshake reuse - oob_record = await self._create_handshake_reuse_message(oob_record, conn_record) + oob_record = await self._create_handshake_reuse_message( + oob_record, conn_record, version + ) # Wait for the reuse accepted message oob_record = await self._wait_for_reuse_response(oob_record.oob_id) @@ -761,6 +781,7 @@ async def _perform_handshake( alias: Optional[str] = None, auto_accept: Optional[bool] = None, mediation_id: Optional[str] = None, + service_accept: Optional[Sequence[Text]] = None, ) -> OobRecord: invitation = oob_record.invitation @@ -788,7 +809,8 @@ async def _perform_handshake( # or something else that includes the key type. We now assume # ED25519 keys endpoint, recipient_keys, routing_keys = await self.resolve_invitation( - service + service, + service_accept=service_accept, ) service = ServiceMessage.deserialize( { @@ -866,6 +888,7 @@ async def _create_handshake_reuse_message( self, oob_record: OobRecord, conn_record: ConnRecord, + version: str, ) -> OobRecord: """ Create and Send a Handshake Reuse message under RFC 0434. @@ -882,7 +905,7 @@ async def _create_handshake_reuse_message( """ try: - reuse_msg = HandshakeReuse() + reuse_msg = HandshakeReuse(version=version) reuse_msg.assign_thread_id(thid=reuse_msg._id, pthid=oob_record.invi_msg_id) connection_targets = await self.fetch_connection_targets( @@ -949,7 +972,9 @@ async def receive_reuse_message( invi_msg_id = reuse_msg._thread.pthid reuse_msg_id = reuse_msg._thread_id - reuse_accept_msg = HandshakeReuseAccept() + reuse_accept_msg = HandshakeReuseAccept( + version=get_version_from_message(reuse_msg) + ) reuse_accept_msg.assign_thread_id(thid=reuse_msg_id, pthid=invi_msg_id) connection_targets = await self.fetch_connection_targets(connection=conn_rec) diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/message_types.py b/aries_cloudagent/protocols/out_of_band/v1_0/message_types.py index d8fb709e09..e8fcd09a94 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/message_types.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/message_types.py @@ -1,5 +1,7 @@ """Message and inner object type identifiers for Out of Band messages.""" +from ....core.util import get_proto_default_version + from ...didcomm_prefix import DIDCommPrefix SPEC_URI = ( @@ -7,11 +9,17 @@ "2da7fc4ee043effa3a9960150e7ba8c9a4628b68/features/0434-outofband" ) +# Default Version +DEFAULT_VERSION = get_proto_default_version( + "aries_cloudagent.protocols.out_of_band.definition", 1 +) + # Message types -INVITATION = "out-of-band/1.0/invitation" -MESSAGE_REUSE = "out-of-band/1.0/handshake-reuse" -MESSAGE_REUSE_ACCEPT = "out-of-band/1.0/handshake-reuse-accepted" -PROBLEM_REPORT = "out-of-band/1.0/problem_report" +INVITATION = f"out-of-band/{DEFAULT_VERSION}/invitation" +MESSAGE_REUSE = f"out-of-band/{DEFAULT_VERSION}/handshake-reuse" +MESSAGE_REUSE_ACCEPT = f"out-of-band/{DEFAULT_VERSION}/handshake-reuse-accepted" +PROBLEM_REPORT = f"out-of-band/{DEFAULT_VERSION}/problem_report" + PROTOCOL_PACKAGE = "aries_cloudagent.protocols.out_of_band.v1_0" diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py index 04a2cfa3fa..14b6eaaff6 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py @@ -3,7 +3,7 @@ from collections import namedtuple from enum import Enum from re import sub -from typing import Sequence, Text, Union +from typing import Optional, Sequence, Text, Union from urllib.parse import parse_qs, urljoin, urlparse from marshmallow import ( @@ -26,7 +26,7 @@ from ....didexchange.v1_0.message_types import ARIES_PROTOCOL as DIDX_PROTO from ....connections.v1_0.message_types import ARIES_PROTOCOL as CONN_PROTO -from ..message_types import INVITATION +from ..message_types import INVITATION, DEFAULT_VERSION from .service import Service @@ -123,6 +123,9 @@ def __init__( handshake_protocols: Sequence[Text] = None, requests_attach: Sequence[AttachDecorator] = None, services: Sequence[Union[Service, Text]] = None, + accept: Optional[Sequence[Text]] = None, + version: str = DEFAULT_VERSION, + msg_type: Optional[Text] = None, **kwargs, ): """ @@ -133,13 +136,14 @@ def __init__( """ # super().__init__(_id=_id, **kwargs) - super().__init__(**kwargs) + super().__init__(_type=msg_type, _version=version, **kwargs) self.label = label self.handshake_protocols = ( list(handshake_protocols) if handshake_protocols else [] ) self.requests_attach = list(requests_attach) if requests_attach else [] self.services = services + self.accept = accept @classmethod def wrap_message(cls, message: dict) -> AttachDecorator: @@ -197,6 +201,12 @@ class Meta: model_class = InvitationMessage unknown = EXCLUDE + _type = fields.Str( + data_key="@type", + required=False, + description="Message type", + example="https://didcomm.org/my-family/1.0/my-message-type", + ) label = fields.Str(required=False, description="Optional label", example="Bob") handshake_protocols = fields.List( fields.Str( @@ -208,6 +218,12 @@ class Meta: ), required=False, ) + accept = fields.List( + fields.Str(), + example=["didcomm/aip1", "didcomm/aip2;env=rfc19"], + description=("List of mime type in order of preference"), + required=False, + ) requests_attach = fields.Nested( AttachDecoratorSchema, required=False, diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/problem_report.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/problem_report.py index f6ddb3bf86..fc2e01039e 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/problem_report.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/problem_report.py @@ -3,9 +3,11 @@ import logging from enum import Enum +from typing import Optional, Text from marshmallow import ( EXCLUDE, + fields, pre_dump, validates_schema, ValidationError, @@ -13,7 +15,7 @@ from ....problem_report.v1_0.message import ProblemReport, ProblemReportSchema -from ..message_types import PROBLEM_REPORT, PROTOCOL_PACKAGE +from ..message_types import PROBLEM_REPORT, PROTOCOL_PACKAGE, DEFAULT_VERSION HANDLER_CLASS = ( f"{PROTOCOL_PACKAGE}.handlers" @@ -40,9 +42,15 @@ class Meta: message_type = PROBLEM_REPORT schema_class = "OOBProblemReportSchema" - def __init__(self, *args, **kwargs): + def __init__( + self, + version: str = DEFAULT_VERSION, + msg_type: Optional[Text] = None, + *args, + **kwargs, + ): """Initialize a ProblemReport message instance.""" - super().__init__(*args, **kwargs) + super().__init__(_type=msg_type, _version=version, *args, **kwargs) class OOBProblemReportSchema(ProblemReportSchema): @@ -54,6 +62,13 @@ class Meta: model_class = OOBProblemReport unknown = EXCLUDE + _type = fields.Str( + data_key="@type", + required=False, + description="Message type", + example="https://didcomm.org/my-family/1.0/my-message-type", + ) + @pre_dump def check_thread_deco(self, obj, **kwargs): """Thread decorator, and its thid and pthid, are mandatory.""" diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/reuse.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/reuse.py index df40511e80..d53f5aeddb 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/reuse.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/reuse.py @@ -1,10 +1,11 @@ """Represents a Handshake Reuse message under RFC 0434.""" -from marshmallow import EXCLUDE, pre_dump, ValidationError +from marshmallow import EXCLUDE, fields, pre_dump, ValidationError +from typing import Optional, Text from .....messaging.agent_message import AgentMessage, AgentMessageSchema -from ..message_types import MESSAGE_REUSE, PROTOCOL_PACKAGE +from ..message_types import MESSAGE_REUSE, PROTOCOL_PACKAGE, DEFAULT_VERSION HANDLER_CLASS = ( f"{PROTOCOL_PACKAGE}.handlers.reuse_handler.HandshakeReuseMessageHandler" @@ -23,10 +24,12 @@ class Meta: def __init__( self, + version: str = DEFAULT_VERSION, + msg_type: Optional[Text] = None, **kwargs, ): """Initialize Handshake Reuse message object.""" - super().__init__(**kwargs) + super().__init__(_type=msg_type, _version=version, **kwargs) class HandshakeReuseSchema(AgentMessageSchema): @@ -38,6 +41,13 @@ class Meta: model_class = HandshakeReuse unknown = EXCLUDE + _type = fields.Str( + data_key="@type", + required=False, + description="Message type", + example="https://didcomm.org/my-family/1.0/my-message-type", + ) + @pre_dump def check_thread_deco(self, obj, **kwargs): """Thread decorator, and its thid and pthid, are mandatory.""" diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/reuse_accept.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/reuse_accept.py index d519ab0a2b..0b0e21a58e 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/reuse_accept.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/reuse_accept.py @@ -1,10 +1,11 @@ """Represents a Handshake Reuse Accept message under RFC 0434.""" -from marshmallow import EXCLUDE, pre_dump, ValidationError +from marshmallow import EXCLUDE, fields, pre_dump, ValidationError +from typing import Optional, Text from .....messaging.agent_message import AgentMessage, AgentMessageSchema -from ..message_types import MESSAGE_REUSE_ACCEPT, PROTOCOL_PACKAGE +from ..message_types import MESSAGE_REUSE_ACCEPT, PROTOCOL_PACKAGE, DEFAULT_VERSION HANDLER_CLASS = ( f"{PROTOCOL_PACKAGE}.handlers" @@ -24,10 +25,12 @@ class Meta: def __init__( self, + version: str = DEFAULT_VERSION, + msg_type: Optional[Text] = None, **kwargs, ): """Initialize Handshake Reuse Accept object.""" - super().__init__(**kwargs) + super().__init__(_type=msg_type, _version=version, **kwargs) class HandshakeReuseAcceptSchema(AgentMessageSchema): @@ -39,6 +42,13 @@ class Meta: model_class = HandshakeReuseAccept unknown = EXCLUDE + _type = fields.Str( + data_key="@type", + required=False, + description="Message type", + example="https://didcomm.org/my-family/1.0/my-message-type", + ) + @pre_dump def check_thread_deco(self, obj, **kwargs): """Thread decorator, and its thid and pthid, are mandatory.""" diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py index 5340dd66dc..d00da45914 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py @@ -9,6 +9,7 @@ from .....connections.v1_0.message_types import ARIES_PROTOCOL as CONN_PROTO from .....didcomm_prefix import DIDCommPrefix from .....didexchange.v1_0.message_types import ARIES_PROTOCOL as DIDX_PROTO +from .....didexchange.v1_0.messages.request import DIDXRequest from ...message_types import INVITATION @@ -44,14 +45,14 @@ def test_properties(self): class TestInvitationMessage(TestCase): def test_init(self): """Test initialization message.""" - invi = InvitationMessage( + invi_msg = InvitationMessage( comment="Hello", label="A label", handshake_protocols=[DIDCommPrefix.qualify_current(DIDX_PROTO)], services=[TEST_DID], ) - assert invi.services == [TEST_DID] - assert invi._type == DIDCommPrefix.qualify_current(INVITATION) + assert invi_msg.services == [TEST_DID] + assert "out-of-band/1.1/invitation" in invi_msg._type service = Service(_id="#inline", _type=DID_COMM, did=TEST_DID) invi_msg = InvitationMessage( @@ -59,9 +60,10 @@ def test_init(self): label="A label", handshake_protocols=[DIDCommPrefix.qualify_current(DIDX_PROTO)], services=[service], + version="1.0", ) assert invi_msg.services == [service] - assert invi_msg._type == DIDCommPrefix.qualify_current(INVITATION) + assert "out-of-band/1.0/invitation" in invi_msg._type def test_wrap_serde(self): """Test conversion of aries message to attachment decorator.""" @@ -144,3 +146,15 @@ def test_invalid_invi_wrong_type_services(self): invi_schema = InvitationMessageSchema() with pytest.raises(test_module.ValidationError): invi_schema.validate_fields(obj_x) + + def test_assign_msg_type_version_to_model_inst(self): + test_msg = InvitationMessage() + assert "1.1" in test_msg._type + assert "1.1" in InvitationMessage.Meta.message_type + test_msg = InvitationMessage(version="1.2") + assert "1.2" in test_msg._type + assert "1.1" in InvitationMessage.Meta.message_type + test_req = DIDXRequest() + assert "1.0" in test_req._type + assert "1.2" in test_msg._type + assert "1.1" in InvitationMessage.Meta.message_type diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_problem_report.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_problem_report.py index 0b605b179f..c594ad146b 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_problem_report.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_problem_report.py @@ -73,3 +73,11 @@ def test_validate_and_logger(self): self._caplog.set_level(logging.WARNING) OOBProblemReportSchema().validate_fields(data) assert "Unexpected error code received" in self._caplog.text + + def test_assign_msg_type_version_to_model_inst(self): + test_msg = OOBProblemReport() + assert "1.1" in test_msg._type + assert "1.1" in OOBProblemReport.Meta.message_type + test_msg = OOBProblemReport(version="1.2") + assert "1.2" in test_msg._type + assert "1.1" in OOBProblemReport.Meta.message_type diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_reuse.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_reuse.py index 5cd79750d0..837698e718 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_reuse.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_reuse.py @@ -35,3 +35,11 @@ def test_pre_dump_x(self): """Exercise pre-dump serialization requirements.""" with pytest.raises(BaseModelError): data = self.reuse_msg.serialize() + + def test_assign_msg_type_version_to_model_inst(self): + test_msg = HandshakeReuse() + assert "1.1" in test_msg._type + assert "1.1" in HandshakeReuse.Meta.message_type + test_msg = HandshakeReuse(version="1.2") + assert "1.2" in test_msg._type + assert "1.1" in HandshakeReuse.Meta.message_type diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_reuse_accept.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_reuse_accept.py index 3feca5439a..556493c618 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_reuse_accept.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_reuse_accept.py @@ -6,6 +6,8 @@ from ......messaging.models.base import BaseModelError +from .....didcomm_prefix import DIDCommPrefix + from ..reuse_accept import HandshakeReuseAccept, HandshakeReuseAcceptSchema @@ -31,7 +33,25 @@ def test_make_model(self): model_instance = HandshakeReuseAccept.deserialize(data) assert isinstance(model_instance, HandshakeReuseAccept) + def test_make_model_backward_comp(self): + """Make reuse-accept model.""" + self.reuse_accept_msg.assign_thread_id(thid="test_thid", pthid="test_pthid") + data = self.reuse_accept_msg.serialize() + data["@type"] = DIDCommPrefix.qualify_current( + "out-of-band/1.0/handshake-reuse-accepted" + ) + model_instance = HandshakeReuseAccept.deserialize(data) + assert isinstance(model_instance, HandshakeReuseAccept) + def test_pre_dump_x(self): """Exercise pre-dump serialization requirements.""" with pytest.raises(BaseModelError): data = self.reuse_accept_msg.serialize() + + def test_assign_msg_type_version_to_model_inst(self): + test_msg = HandshakeReuseAccept() + assert "1.1" in test_msg._type + assert "1.1" in HandshakeReuseAccept.Meta.message_type + test_msg = HandshakeReuseAccept(version="1.2") + assert "1.2" in test_msg._type + assert "1.1" in HandshakeReuseAccept.Meta.message_type diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/routes.py b/aries_cloudagent/protocols/out_of_band/v1_0/routes.py index 3f7384c164..3b9b489500 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/routes.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/routes.py @@ -75,6 +75,15 @@ class AttachmentDefSchema(OpenAPISchema): ), required=False, ) + accept = fields.List( + fields.Str(), + description=( + "List of mime type in order of preference that should be" + " use in responding to the message" + ), + example=["didcomm/aip1", "didcomm/aip2;env=rfc19"], + required=False, + ) use_public_did = fields.Boolean( default=False, description="Whether to use public DID in invitation", @@ -92,6 +101,11 @@ class AttachmentDefSchema(OpenAPISchema): required=False, example="Invitation to Barry", ) + protocol_version = fields.Str( + description="OOB protocol version", + required=False, + example="1.1", + ) alias = fields.Str( description="Alias for connection", required=False, @@ -151,11 +165,13 @@ async def invitation_create(request: web.BaseRequest): body = await request.json() if request.body_exists else {} attachments = body.get("attachments") handshake_protocols = body.get("handshake_protocols", []) + service_accept = body.get("accept") use_public_did = body.get("use_public_did", False) metadata = body.get("metadata") my_label = body.get("my_label") alias = body.get("alias") mediation_id = body.get("mediation_id") + protocol_version = body.get("protocol_version") multi_use = json.loads(request.query.get("multi_use", "false")) auto_accept = json.loads(request.query.get("auto_accept", "null")) @@ -175,6 +191,8 @@ async def invitation_create(request: web.BaseRequest): metadata=metadata, alias=alias, mediation_id=mediation_id, + service_accept=service_accept, + protocol_version=protocol_version, ) except (StorageNotFoundError, ValidationError, OutOfBandManagerError) as e: raise web.HTTPBadRequest(reason=e.roll_up) diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py index f7ff19144a..a745265b6b 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py @@ -13,6 +13,7 @@ from .....connections.models.diddoc import DIDDoc, PublicKey, PublicKeyType, Service from .....core.event_bus import EventBus from .....core.in_memory import InMemoryProfile +from .....core.util import get_version_from_message from .....core.oob_processor import OobMessageProcessor from .....did.did_key import DIDKey from .....messaging.decorators.attach_decorator import AttachDecorator @@ -101,6 +102,7 @@ class TestConfig: service_endpoint=test_endpoint, ) NOW_8601 = datetime.utcnow().replace(tzinfo=timezone.utc).isoformat(" ", "seconds") + TEST_INVI_MESSAGE_TYPE = "out-of-band/1.1/invitation" NOW_EPOCH = str_to_epoch(NOW_8601) CD_ID = "GMm4vMw8LLrLJjp81kRRLp:3:CL:12:tag" INDY_PROOF_REQ = json.loads( @@ -379,7 +381,7 @@ async def test_create_invitation_handshake_succeeds(self): ) assert invi_rec.invitation._type == DIDCommPrefix.qualify_current( - INVITATION + self.TEST_INVI_MESSAGE_TYPE ) assert not invi_rec.invitation.requests_attach assert ( @@ -414,9 +416,19 @@ async def test_create_invitation_mediation_overwrites_routing_and_endpoint(self) ) assert isinstance(invite, InvitationRecord) assert invite.invitation._type == DIDCommPrefix.qualify_current( - INVITATION + self.TEST_INVI_MESSAGE_TYPE ) assert invite.invitation.label == "test123" + assert ( + DIDKey.from_did( + invite.invitation.services[0].routing_keys[0] + ).public_key_b58 + == self.test_mediator_routing_keys[0] + ) + assert ( + invite.invitation.services[0].service_endpoint + == self.test_mediator_endpoint + ) mock_get_default_mediator.assert_not_called() async def test_create_invitation_multitenant_local(self): @@ -778,11 +790,12 @@ async def test_create_invitation_peer_did(self): public=False, hs_protos=[test_module.HSProto.RFC23], multi_use=False, + service_accept=["didcomm/aip1", "didcomm/aip2;env=rfc19"], ) assert invi_rec._invitation.ser[ "@type" - ] == DIDCommPrefix.qualify_current(INVITATION) + ] == DIDCommPrefix.qualify_current(self.TEST_INVI_MESSAGE_TYPE) assert not invi_rec._invitation.ser.get("requests~attach") assert invi_rec.invitation.label == "That guy" assert ( @@ -880,7 +893,7 @@ async def test_create_handshake_reuse_msg(self): ) oob_record = await self.manager._create_handshake_reuse_message( - oob_record, self.test_conn_rec + oob_record, self.test_conn_rec, get_version_from_message(invitation) ) _, kwargs = self.responder.send.call_args @@ -889,7 +902,9 @@ async def test_create_handshake_reuse_msg(self): assert oob_record.state == OobRecord.STATE_AWAIT_RESPONSE # Assert responder has been called with the reuse message - assert reuse_message._type == DIDCommPrefix.qualify_current(MESSAGE_REUSE) + assert reuse_message._type == DIDCommPrefix.qualify_current( + "out-of-band/1.1/handshake-reuse" + ) assert oob_record.reuse_msg_id == reuse_message._id async def test_create_handshake_reuse_msg_catch_exception(self): @@ -902,7 +917,7 @@ async def test_create_handshake_reuse_msg_catch_exception(self): oob_mgr_fetch_conn.side_effect = StorageNotFoundError() with self.assertRaises(OutOfBandManagerError) as context: await self.manager._create_handshake_reuse_message( - async_mock.MagicMock(), self.test_conn_rec + async_mock.MagicMock(), self.test_conn_rec, "1.0" ) assert "Error on creating and sending a handshake reuse message" in str( context.exception @@ -973,7 +988,7 @@ async def test_receive_reuse_message_existing_found_multi_use(self): recipient_did_public=False, ) - reuse_msg = HandshakeReuse() + reuse_msg = HandshakeReuse(version="1.0") reuse_msg.assign_thread_id(thid="the-thread-id", pthid="the-pthid") self.test_conn_rec.invitation_msg_id = "test_123" @@ -1425,7 +1440,9 @@ async def test_receive_invitation_handshake_reuse(self): ) perform_handshake.assert_not_called() - handle_handshake_reuse.assert_called_once_with(ANY, test_exist_conn) + handle_handshake_reuse.assert_called_once_with( + ANY, test_exist_conn, get_version_from_message(oob_invitation) + ) assert result.state == OobRecord.STATE_ACCEPTED @@ -1486,12 +1503,15 @@ async def test_receive_invitation_handshake_reuse_failed(self): mediation_id="mediation_id", ) - handle_handshake_reuse.assert_called_once_with(ANY, test_exist_conn) + handle_handshake_reuse.assert_called_once_with( + ANY, test_exist_conn, get_version_from_message(oob_invitation) + ) perform_handshake.assert_called_once_with( oob_record=ANY, alias="alias", auto_accept=True, mediation_id="mediation_id", + service_accept=None, ) assert mock_oob.state == OobRecord.STATE_DONE diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_routes.py b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_routes.py index d7aaffabf8..cf0349c34a 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_routes.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_routes.py @@ -57,6 +57,49 @@ async def test_invitation_create(self): metadata=body["metadata"], alias=None, mediation_id=None, + service_accept=None, + protocol_version=None, + ) + mock_json_response.assert_called_once_with({"abc": "123"}) + + async def test_invitation_create_with_accept(self): + self.request.query = { + "multi_use": "true", + "auto_accept": "true", + } + body = { + "attachments": async_mock.MagicMock(), + "handshake_protocols": [test_module.HSProto.RFC23.name], + "accept": ["didcomm/aip1", "didcomm/aip2;env=rfc19"], + "use_public_did": True, + "metadata": {"hello": "world"}, + } + self.request.json = async_mock.CoroutineMock(return_value=body) + + with async_mock.patch.object( + test_module, "OutOfBandManager", autospec=True + ) as mock_oob_mgr, async_mock.patch.object( + test_module.web, "json_response", async_mock.Mock() + ) as mock_json_response: + mock_oob_mgr.return_value.create_invitation = async_mock.CoroutineMock( + return_value=async_mock.MagicMock( + serialize=async_mock.MagicMock(return_value={"abc": "123"}) + ) + ) + + await test_module.invitation_create(self.request) + mock_oob_mgr.return_value.create_invitation.assert_called_once_with( + my_label=None, + auto_accept=True, + public=True, + multi_use=True, + hs_protos=[test_module.HSProto.RFC23], + attachments=body["attachments"], + metadata=body["metadata"], + alias=None, + mediation_id=None, + service_accept=["didcomm/aip1", "didcomm/aip2;env=rfc19"], + protocol_version=None, ) mock_json_response.assert_called_once_with({"abc": "123"}) diff --git a/aries_cloudagent/resolver/base.py b/aries_cloudagent/resolver/base.py index 994bfe94b2..df31cf4dd3 100644 --- a/aries_cloudagent/resolver/base.py +++ b/aries_cloudagent/resolver/base.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import NamedTuple, Pattern, Sequence, Union +from typing import Optional, NamedTuple, Pattern, Sequence, Union, Text from pydid import DID @@ -132,7 +132,12 @@ async def supports(self, profile: Profile, did: str) -> bool: return bool(supported_did_regex.match(did)) - async def resolve(self, profile: Profile, did: Union[str, DID]) -> dict: + async def resolve( + self, + profile: Profile, + did: Union[str, DID], + service_accept: Optional[Sequence[Text]] = None, + ) -> dict: """Resolve a DID using this resolver.""" if isinstance(did, DID): did = str(did) @@ -143,8 +148,13 @@ async def resolve(self, profile: Profile, did: Union[str, DID]) -> dict: f"{self.__class__.__name__} does not support DID method for: {did}" ) - return await self._resolve(profile, did) + return await self._resolve(profile, did, service_accept) @abstractmethod - async def _resolve(self, profile: Profile, did: str) -> dict: + async def _resolve( + self, + profile: Profile, + did: str, + service_accept: Optional[Sequence[Text]] = None, + ) -> dict: """Resolve a DID using this resolver.""" diff --git a/aries_cloudagent/resolver/default/indy.py b/aries_cloudagent/resolver/default/indy.py index d1088121ba..ccf6ef4a81 100644 --- a/aries_cloudagent/resolver/default/indy.py +++ b/aries_cloudagent/resolver/default/indy.py @@ -3,10 +3,11 @@ Resolution is performed using the IndyLedger class. """ -from typing import Pattern +import logging +from typing import Optional, Pattern, Sequence, Text from pydid import DID, DIDDocumentBuilder -from pydid.verification_method import Ed25519VerificationKey2018 +from pydid.verification_method import Ed25519VerificationKey2018, VerificationMethod from ...config.injection_context import InjectionContext from ...core.profile import Profile @@ -21,6 +22,8 @@ from ..base import BaseDIDResolver, DIDNotFound, ResolverError, ResolverType +LOGGER = logging.getLogger(__name__) + class NoIndyLedger(ResolverError): """Raised when there is no Indy ledger instance configured.""" @@ -29,6 +32,9 @@ class NoIndyLedger(ResolverError): class IndyDIDResolver(BaseDIDResolver): """Indy DID Resolver.""" + SERVICE_TYPE_DID_COMMUNICATION = "did-communication" + SERVICE_TYPE_DIDCOMM = "DIDComm" + SERVICE_TYPE_ENDPOINT = "endpoint" AGENT_SERVICE_TYPE = "did-communication" def __init__(self): @@ -43,7 +49,96 @@ def supported_did_regex(self) -> Pattern: """Return supported_did_regex of Indy DID Resolver.""" return IndyDID.PATTERN - async def _resolve(self, profile: Profile, did: str) -> dict: + def process_endpoint_types(self, types): + """Process endpoint types. + + Returns expected types, subset of expected types, + or default types. + """ + expected_types = ["endpoint", "did-communication", "DIDComm"] + default_types = ["endpoint", "did-communication"] + if len(types) <= 0: + return default_types + for type in types: + if type not in expected_types: + return default_types + return types + + def add_services( + self, + builder: DIDDocumentBuilder, + endpoints: Optional[dict], + recipient_key: VerificationMethod = None, + service_accept: Optional[Sequence[Text]] = None, + ): + """Add services.""" + if not endpoints: + return + + endpoint = endpoints.get("endpoint") + routing_keys = endpoints.get("routingKeys", []) + types = endpoints.get("types", [self.SERVICE_TYPE_DID_COMMUNICATION]) + + other_endpoints = { + key: endpoints[key] + for key in ("profile", "linked_domains") + if key in endpoints + } + + if endpoint: + processed_types = self.process_endpoint_types(types) + + if self.SERVICE_TYPE_ENDPOINT in processed_types: + builder.service.add( + ident="endpoint", + service_endpoint=endpoint, + type_=self.SERVICE_TYPE_ENDPOINT, + ) + + if self.SERVICE_TYPE_DID_COMMUNICATION in processed_types: + builder.service.add( + ident="did-communication", + type_=self.SERVICE_TYPE_DID_COMMUNICATION, + service_endpoint=endpoint, + priority=1, + routing_keys=routing_keys, + recipient_keys=[recipient_key.id], + accept=( + service_accept if service_accept else ["didcomm/aip2;env=rfc19"] + ), + ) + + if self.SERVICE_TYPE_DIDCOMM in types: + builder.service.add( + ident="#didcomm-1", + type_=self.SERVICE_TYPE_DIDCOMM, + service_endpoint=endpoint, + recipient_keys=[recipient_key.id], + routing_keys=routing_keys, + # CHECKME + # accept=(service_accept if service_accept else ["didcomm/v2"]), + accept=["didcomm/v2"], + ) + builder.context.append(self.CONTEXT_DIDCOMM_V2) + else: + LOGGER.warning( + "No endpoint for DID although endpoint attrib was resolvable" + ) + + if other_endpoints: + for type_, endpoint in other_endpoints.items(): + builder.service.add( + ident=type_, + type_=EndpointType.get(type_).w3c, + service_endpoint=endpoint, + ) + + async def _resolve( + self, + profile: Profile, + did: str, + service_accept: Optional[Sequence[Text]] = None, + ) -> dict: """Resolve an indy DID.""" multitenant_mgr = profile.inject_or(BaseMultitenantManager) if multitenant_mgr: @@ -73,24 +168,7 @@ async def _resolve(self, profile: Profile, did: str) -> dict: ) builder.authentication.reference(vmethod.id) builder.assertion_method.reference(vmethod.id) - if endpoints: - for type_, endpoint in endpoints.items(): - if type_ == EndpointType.ENDPOINT.indy: - builder.service.add_didcomm( - ident=self.AGENT_SERVICE_TYPE, - type_=self.AGENT_SERVICE_TYPE, - service_endpoint=endpoint, - priority=1, - recipient_keys=[vmethod], - routing_keys=[], - ) - else: - # Accept all service types for now - builder.service.add( - ident=type_, - type_=type_, - service_endpoint=endpoint, - ) + self.add_services(builder, endpoints, vmethod, service_accept) result = builder.build() return result.serialize() diff --git a/aries_cloudagent/resolver/default/key.py b/aries_cloudagent/resolver/default/key.py index 2f9f10edf9..0217156f81 100644 --- a/aries_cloudagent/resolver/default/key.py +++ b/aries_cloudagent/resolver/default/key.py @@ -3,7 +3,7 @@ Resolution is performed using the IndyLedger class. """ -from typing import Pattern +from typing import Optional, Pattern, Sequence, Text from ...did.did_key import DIDKey from ...config.injection_context import InjectionContext @@ -28,7 +28,12 @@ def supported_did_regex(self) -> Pattern: """Return supported_did_regex of Key DID Resolver.""" return DIDKeyType.PATTERN - async def _resolve(self, profile: Profile, did: str) -> dict: + async def _resolve( + self, + profile: Profile, + did: str, + service_accept: Optional[Sequence[Text]] = None, + ) -> dict: """Resolve a Key DID.""" try: did_key = DIDKey.from_did(did) diff --git a/aries_cloudagent/resolver/default/tests/test_indy.py b/aries_cloudagent/resolver/default/tests/test_indy.py index 7781dac13f..34464bfa7f 100644 --- a/aries_cloudagent/resolver/default/tests/test_indy.py +++ b/aries_cloudagent/resolver/default/tests/test_indy.py @@ -33,8 +33,11 @@ def resolver(): def ledger(): """Ledger fixture.""" ledger = async_mock.MagicMock(spec=BaseLedger) - ledger.get_endpoint_for_did = async_mock.CoroutineMock( - return_value="https://github.com/" + ledger.get_all_endpoints_for_did = async_mock.CoroutineMock( + return_value={ + "endpoint": "https://github.com/", + "profile": "https://example.com/profile", + } ) ledger.get_key_for_did = async_mock.CoroutineMock(return_value="key") yield ledger @@ -67,6 +70,15 @@ async def test_resolve(self, profile: Profile, resolver: IndyDIDResolver): """Test resolve method.""" assert await resolver.resolve(profile, TEST_DID0) + @pytest.mark.asyncio + async def test_resolve_with_accept( + self, profile: Profile, resolver: IndyDIDResolver + ): + """Test resolve method.""" + assert await resolver.resolve( + profile, TEST_DID0, ["didcomm/aip1", "didcomm/aip2;env=rfc19"] + ) + @pytest.mark.asyncio async def test_resolve_multitenant( self, profile: Profile, resolver: IndyDIDResolver, ledger: BaseLedger diff --git a/aries_cloudagent/resolver/default/universal.py b/aries_cloudagent/resolver/default/universal.py new file mode 100644 index 0000000000..85ca9e2dba --- /dev/null +++ b/aries_cloudagent/resolver/default/universal.py @@ -0,0 +1,107 @@ +"""HTTP Universal DID Resolver.""" + +import logging +import re +from typing import Iterable, Optional, Pattern, Sequence, Union, Text + +import aiohttp + +from ...config.injection_context import InjectionContext +from ...core.profile import Profile +from ..base import BaseDIDResolver, DIDNotFound, ResolverError, ResolverType + +LOGGER = logging.getLogger(__name__) +DEFAULT_ENDPOINT = "https://dev.uniresolver.io" + + +async def _fetch_resolver_props(endpoint: str) -> dict: + """Retrieve universal resolver properties.""" + async with aiohttp.ClientSession() as session: + async with session.get(f"{endpoint}/1.0/properties/") as resp: + if resp.status >= 200 and resp.status < 400: + return await resp.json() + raise ResolverError( + "Failed to retrieve resolver properties: " + await resp.text() + ) + + +async def _get_supported_did_regex(endpoint: str) -> Pattern: + props = await _fetch_resolver_props(endpoint) + return _compile_supported_did_regex( + driver["http"]["pattern"] for driver in props.values() + ) + + +def _compile_supported_did_regex(patterns: Iterable[Union[str, Pattern]]): + """Create regex from list of regex.""" + return re.compile( + "(?:" + + "|".join( + [ + pattern.pattern if isinstance(pattern, Pattern) else pattern + for pattern in patterns + ] + ) + + ")" + ) + + +class UniversalResolver(BaseDIDResolver): + """Universal DID Resolver with HTTP bindings.""" + + def __init__( + self, + *, + endpoint: Optional[str] = None, + supported_did_regex: Optional[Pattern] = None, + ): + """Initialize UniversalResolver.""" + super().__init__(ResolverType.NON_NATIVE) + self._endpoint = endpoint + self._supported_did_regex = supported_did_regex + + async def setup(self, context: InjectionContext): + """Preform setup, populate supported method list, configuration.""" + endpoint = context.settings.get_str("resolver.universal") + if endpoint == "DEFAULT" or not endpoint: + endpoint = DEFAULT_ENDPOINT + + supported = context.settings.get("resolver.universal.supported") + if supported is None: + supported_did_regex = await _get_supported_did_regex(endpoint) + else: + supported_did_regex = _compile_supported_did_regex(supported) + + self._endpoint = endpoint + self._supported_did_regex = supported_did_regex + + @property + def supported_did_regex(self) -> Pattern: + """Return supported methods regex.""" + if not self._supported_did_regex: + raise ResolverError("Resolver has not been set up") + + return self._supported_did_regex + + async def _resolve( + self, + _profile: Profile, + did: str, + service_accept: Optional[Sequence[Text]] = None, + ) -> dict: + """Resolve DID through remote universal resolver.""" + + async with aiohttp.ClientSession() as session: + async with session.get(f"{self._endpoint}/1.0/identifiers/{did}") as resp: + if resp.status == 200: + doc = await resp.json() + did_doc = doc["didDocument"] + LOGGER.info("Retrieved doc: %s", did_doc) + return did_doc + if resp.status == 404: + raise DIDNotFound(f"{did} not found by {self.__class__.__name__}") + + text = await resp.text() + raise ResolverError( + f"Unexecpted status from universal resolver ({resp.status}): {text}" + ) diff --git a/aries_cloudagent/resolver/default/web.py b/aries_cloudagent/resolver/default/web.py index 30aa5ce440..df2d99e6cc 100644 --- a/aries_cloudagent/resolver/default/web.py +++ b/aries_cloudagent/resolver/default/web.py @@ -2,7 +2,7 @@ import urllib.parse -from typing import Pattern +from typing import Optional, Pattern, Sequence, Text import aiohttp @@ -57,7 +57,12 @@ def __transform_to_url(self, did): return "https://" + url + "/did.json" - async def _resolve(self, profile: Profile, did: str) -> dict: + async def _resolve( + self, + profile: Profile, + did: str, + service_accept: Optional[Sequence[Text]] = None, + ) -> dict: """Resolve did:web DIDs.""" url = self.__transform_to_url(did) diff --git a/aries_cloudagent/resolver/did_resolver.py b/aries_cloudagent/resolver/did_resolver.py index f57ec47fd9..6865e7e2fd 100644 --- a/aries_cloudagent/resolver/did_resolver.py +++ b/aries_cloudagent/resolver/did_resolver.py @@ -8,7 +8,7 @@ from datetime import datetime from itertools import chain import logging -from typing import Sequence, Tuple, Type, TypeVar, Union +from typing import Optional, Sequence, Tuple, Text, Type, TypeVar, Union from pydid import DID, DIDError, DIDUrl, Resource, NonconformantDocument from pydid.doc.doc import IDNotFoundError @@ -38,7 +38,10 @@ def __init__(self, registry: DIDResolverRegistry): self.did_resolver_registry = registry async def _resolve( - self, profile: Profile, did: Union[str, DID] + self, + profile: Profile, + did: Union[str, DID], + service_accept: Optional[Sequence[Text]] = None, ) -> Tuple[BaseDIDResolver, dict]: """Retrieve doc and return with resolver.""" # TODO Cache results @@ -52,6 +55,7 @@ async def _resolve( document = await resolver.resolve( profile, did, + service_accept, ) return resolver, document except DIDNotFound: @@ -59,9 +63,14 @@ async def _resolve( raise DIDNotFound(f"DID {did} could not be resolved") - async def resolve(self, profile: Profile, did: Union[str, DID]) -> dict: + async def resolve( + self, + profile: Profile, + did: Union[str, DID], + service_accept: Optional[Sequence[Text]] = None, + ) -> dict: """Resolve a DID.""" - _, doc = await self._resolve(profile, did) + _, doc = await self._resolve(profile, did, service_accept) return doc async def resolve_with_metadata( diff --git a/aries_cloudagent/resolver/tests/test_base.py b/aries_cloudagent/resolver/tests/test_base.py index cdae76a91e..2108468458 100644 --- a/aries_cloudagent/resolver/tests/test_base.py +++ b/aries_cloudagent/resolver/tests/test_base.py @@ -22,7 +22,7 @@ async def setup(self, context): def supported_did_regex(self): return re.compile("^did:example:[a-zA-Z0-9_.-]+$") - async def _resolve(self, profile, did) -> DIDDocument: + async def _resolve(self, profile, did, accept) -> DIDDocument: return DIDDocument("did:example:123") @@ -74,7 +74,7 @@ async def setup(self, context): def supported_methods(self): return ["example"] - async def _resolve(self, profile, did) -> DIDDocument: + async def _resolve(self, profile, did, accept) -> DIDDocument: return DIDDocument("did:example:123") with pytest.deprecated_call(): diff --git a/aries_cloudagent/resolver/tests/test_did_resolver.py b/aries_cloudagent/resolver/tests/test_did_resolver.py index b08480e805..1ed26b582b 100644 --- a/aries_cloudagent/resolver/tests/test_did_resolver.py +++ b/aries_cloudagent/resolver/tests/test_did_resolver.py @@ -80,7 +80,7 @@ def supported_did_regex(self) -> Pattern: async def setup(self, context): pass - async def _resolve(self, profile, did): + async def _resolve(self, profile, did, accept): if isinstance(self.resolved, Exception): raise self.resolved return self.resolved.serialize() diff --git a/open-api/openapi.json b/open-api/openapi.json index 3494407fdf..a4eb84e6a4 100644 --- a/open-api/openapi.json +++ b/open-api/openapi.json @@ -7328,6 +7328,14 @@ "description" : "Handshake protocol to specify in invitation" } }, + "accept" : { + "type" : "array", + "items" : { + "type" : "string", + "example" : "didcomm/aip1", + "description" : "Mime type list in order of preference to be used in response" + } + }, "mediation_id" : { "type" : "string", "example" : "3fa85f64-5717-4562-b3fc-2c963f66afa6", @@ -7373,6 +7381,14 @@ "description" : "Handshake protocol" } }, + "accept" : { + "type" : "array", + "items" : { + "type" : "string", + "example" : "didcomm/aip1", + "description" : "Mime type list in order of preference to be used in response" + } + }, "label" : { "type" : "string", "example" : "Bob", diff --git a/requirements.dev.txt b/requirements.dev.txt index 22433dcadc..336d280aba 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,8 +1,10 @@ asynctest==0.13.0 +async-case~=10.1 pytest~=5.4.0 pytest-asyncio==0.14.0 pytest-cov==2.10.1 pytest-flake8==1.0.6 +mock~=4.0 flake8==3.9.0 # flake8-rst-docstrings==0.0.8 @@ -17,4 +19,4 @@ sphinx-rtd-theme>=0.4.3 ptvsd==4.3.2 pydevd==1.5.1 -pydevd-pycharm~=193.6015.39 \ No newline at end of file +pydevd-pycharm~=193.6015.39