Skip to content

Commit 690d242

Browse files
committed
feat: Add AOP proxy support to MethodToolCallbackProvider
This change ensures that @tool annotated methods can be properly discovered even when the tool objects are wrapped in Spring AOP proxies, which is common when using aspects or other proxy-based features. - Enhance MethodToolCallbackProvider to properly handle AOP proxied tool objects by detecting proxies and retrieving their target classes when scanning for @tool annotated methods. - Add test suite in MethodToolCallbackProviderAopTests.java to verify AOP proxy handling Resolves #2356 Signed-off-by: Christian Tzolov <[email protected]>
1 parent bc375ab commit 690d242

File tree

4 files changed

+254
-1
lines changed

4 files changed

+254
-1
lines changed

Diff for: spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallbackProvider.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.springframework.ai.tool.definition.ToolDefinition;
2525
import org.springframework.ai.tool.metadata.ToolMetadata;
2626
import org.springframework.ai.tool.util.ToolUtils;
27+
import org.springframework.aop.support.AopUtils;
2728
import org.springframework.util.Assert;
2829
import org.springframework.util.ClassUtils;
2930
import org.springframework.util.ReflectionUtils;
@@ -59,7 +60,9 @@ private MethodToolCallbackProvider(List<Object> toolObjects) {
5960
@Override
6061
public ToolCallback[] getToolCallbacks() {
6162
var toolCallbacks = toolObjects.stream()
62-
.map(toolObject -> Stream.of(ReflectionUtils.getDeclaredMethods(toolObject.getClass()))
63+
.map(toolObject -> Stream
64+
.of(ReflectionUtils.getDeclaredMethods(
65+
AopUtils.isAopProxy(toolObject) ? AopUtils.getTargetClass(toolObject) : toolObject.getClass()))
6366
.filter(toolMethod -> toolMethod.isAnnotationPresent(Tool.class))
6467
.filter(toolMethod -> !isFunctionalType(toolMethod))
6568
.map(toolMethod -> MethodToolCallback.builder()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.tool.method;
17+
18+
import org.junit.jupiter.api.Test;
19+
import org.junit.jupiter.api.extension.ExtendWith;
20+
import org.mockito.MockedStatic;
21+
import org.mockito.Mockito;
22+
import org.mockito.junit.jupiter.MockitoExtension;
23+
import org.springframework.ai.tool.ToolCallback;
24+
import org.springframework.ai.tool.annotation.Tool;
25+
import org.springframework.aop.framework.ProxyFactory;
26+
import org.springframework.aop.support.AopUtils;
27+
import org.springframework.aop.support.DefaultPointcutAdvisor;
28+
import org.springframework.aop.support.annotation.AnnotationMatchingPointcut;
29+
import org.aopalliance.intercept.MethodInterceptor;
30+
import org.aopalliance.intercept.MethodInvocation;
31+
import org.springframework.stereotype.Component;
32+
33+
import java.util.List;
34+
import java.util.stream.Stream;
35+
36+
import static org.assertj.core.api.Assertions.assertThat;
37+
import static org.mockito.ArgumentMatchers.any;
38+
import static org.mockito.Mockito.times;
39+
40+
/**
41+
* Tests for {@link MethodToolCallbackProvider} with AOP proxies.
42+
*
43+
* @author Christian Tzolov
44+
*/
45+
@ExtendWith(MockitoExtension.class)
46+
class MethodToolCallbackProviderAopTests {
47+
48+
/**
49+
* Test annotation to simulate a Spring AOP aspect
50+
*/
51+
@java.lang.annotation.Target({ java.lang.annotation.ElementType.METHOD })
52+
@java.lang.annotation.Retention(java.lang.annotation.RetentionPolicy.RUNTIME)
53+
@java.lang.annotation.Documented
54+
public @interface LogExecution {
55+
56+
}
57+
58+
/**
59+
* Sample bean with methods annotated with both @Tool and @LogExecution
60+
*/
61+
@Component
62+
static class ToolsWithAopAnnotations {
63+
64+
@Tool(description = "Method with AOP annotation")
65+
@LogExecution
66+
public String methodWithAopAnnotation(String input) {
67+
return "Processed: " + input;
68+
}
69+
70+
@Tool(description = "Another method with AOP annotation")
71+
@LogExecution
72+
public List<String> anotherMethodWithAopAnnotation(String input) {
73+
return List.of("Item: " + input);
74+
}
75+
76+
@Tool(description = "Method without AOP annotation")
77+
public String methodWithoutAopAnnotation(String input) {
78+
return "Regular: " + input;
79+
}
80+
81+
}
82+
83+
@Test
84+
void shouldHandleAopProxiedToolObject() {
85+
// Create the original tool object
86+
ToolsWithAopAnnotations originalToolObject = new ToolsWithAopAnnotations();
87+
88+
// Create a proxy for the tool object with an aspect for @LogExecution annotation
89+
ProxyFactory proxyFactory = new ProxyFactory(originalToolObject);
90+
AnnotationMatchingPointcut pointcut = new AnnotationMatchingPointcut(null, LogExecution.class);
91+
92+
// Create a method interceptor for logging
93+
MethodInterceptor loggingInterceptor = new MethodInterceptor() {
94+
@Override
95+
public Object invoke(MethodInvocation methodInvocation) throws Throwable {
96+
// Simple logging advice
97+
System.out.println("Before executing: " + methodInvocation.getMethod().getName());
98+
Object result = methodInvocation.proceed();
99+
System.out.println("After executing: " + methodInvocation.getMethod().getName());
100+
return result;
101+
}
102+
};
103+
104+
proxyFactory.addAdvisor(new DefaultPointcutAdvisor(pointcut, loggingInterceptor));
105+
106+
Object proxiedToolObject = proxyFactory.getProxy();
107+
108+
// Verify that the object is indeed a proxy
109+
assertThat(AopUtils.isAopProxy(proxiedToolObject)).isTrue();
110+
assertThat(AopUtils.getTargetClass(proxiedToolObject)).isEqualTo(ToolsWithAopAnnotations.class);
111+
112+
// Create the provider with the proxied object
113+
MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder()
114+
.toolObjects(proxiedToolObject)
115+
.build();
116+
117+
// Get the tool callbacks
118+
ToolCallback[] callbacks = provider.getToolCallbacks();
119+
120+
// Verify that all methods with @Tool annotation are found, including those with
121+
// @LogExecution
122+
assertThat(callbacks).hasSize(3);
123+
124+
// Verify that the tool names match the expected method names
125+
assertThat(Stream.of(callbacks).map(ToolCallback::getName)).containsExactlyInAnyOrder("methodWithAopAnnotation",
126+
"anotherMethodWithAopAnnotation", "methodWithoutAopAnnotation");
127+
}
128+
129+
/**
130+
* This test specifically validates the AOP proxy handling logic in
131+
* MethodToolCallbackProvider. It uses Mockito to verify that AopUtils.isAopProxy and
132+
* AopUtils.getTargetClass are called correctly when processing a proxied object.
133+
*/
134+
@Test
135+
void shouldUseAopUtilsToHandleProxiedObjects() {
136+
// Create the original tool object
137+
ToolsWithAopAnnotations originalToolObject = new ToolsWithAopAnnotations();
138+
139+
// Create a proxy for the tool object
140+
ProxyFactory proxyFactory = new ProxyFactory(originalToolObject);
141+
AnnotationMatchingPointcut pointcut = new AnnotationMatchingPointcut(null, LogExecution.class);
142+
143+
MethodInterceptor loggingInterceptor = new MethodInterceptor() {
144+
@Override
145+
public Object invoke(MethodInvocation methodInvocation) throws Throwable {
146+
return methodInvocation.proceed();
147+
}
148+
};
149+
150+
proxyFactory.addAdvisor(new DefaultPointcutAdvisor(pointcut, loggingInterceptor));
151+
Object proxiedToolObject = proxyFactory.getProxy();
152+
153+
// Use MockedStatic to verify AopUtils static methods are called
154+
try (MockedStatic<AopUtils> mockedAopUtils = Mockito.mockStatic(AopUtils.class)) {
155+
// Set up the mocked behavior
156+
mockedAopUtils.when(() -> AopUtils.isAopProxy(any())).thenReturn(true);
157+
mockedAopUtils.when(() -> AopUtils.getTargetClass(any())).thenReturn(ToolsWithAopAnnotations.class);
158+
159+
// Create the provider with the proxied object
160+
MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder()
161+
.toolObjects(proxiedToolObject)
162+
.build();
163+
164+
// Get the tool callbacks - this should trigger the AopUtils methods
165+
provider.getToolCallbacks();
166+
167+
// Verify that AopUtils.isAopProxy was called with the proxied object
168+
mockedAopUtils.verify(() -> AopUtils.isAopProxy(proxiedToolObject), times(1));
169+
170+
// Verify that AopUtils.getTargetClass was called with the proxied object
171+
mockedAopUtils.verify(() -> AopUtils.getTargetClass(proxiedToolObject), times(1));
172+
}
173+
}
174+
175+
@Test
176+
void shouldHandleMixOfProxiedAndNonProxiedToolObjects() {
177+
// Create the original tool objects
178+
ToolsWithAopAnnotations originalToolObject = new ToolsWithAopAnnotations();
179+
180+
// Create a proxy for one of the tool objects
181+
ProxyFactory proxyFactory = new ProxyFactory(originalToolObject);
182+
AnnotationMatchingPointcut pointcut = new AnnotationMatchingPointcut(null, LogExecution.class);
183+
184+
// Create a method interceptor for logging
185+
MethodInterceptor loggingInterceptor = new MethodInterceptor() {
186+
@Override
187+
public Object invoke(MethodInvocation methodInvocation) throws Throwable {
188+
// Simple logging advice
189+
System.out.println("Before executing: " + methodInvocation.getMethod().getName());
190+
Object result = methodInvocation.proceed();
191+
System.out.println("After executing: " + methodInvocation.getMethod().getName());
192+
return result;
193+
}
194+
};
195+
196+
proxyFactory.addAdvisor(new DefaultPointcutAdvisor(pointcut, loggingInterceptor));
197+
198+
Object proxiedToolObject = proxyFactory.getProxy();
199+
200+
// Create a non-proxied tool object
201+
MethodToolCallbackProviderTests.ToolsExtra nonProxiedToolObject = new MethodToolCallbackProviderTests.ToolsExtra();
202+
203+
// Create the provider with both proxied and non-proxied objects
204+
MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder()
205+
.toolObjects(proxiedToolObject, nonProxiedToolObject)
206+
.build();
207+
208+
// Get the tool callbacks
209+
ToolCallback[] callbacks = provider.getToolCallbacks();
210+
211+
// Verify that all methods with @Tool annotation are found from both objects
212+
assertThat(callbacks).hasSize(5); // 3 from proxied + 2 from non-proxied
213+
214+
// Verify that the tool names match the expected method names
215+
assertThat(Stream.of(callbacks).map(ToolCallback::getName)).containsExactlyInAnyOrder("methodWithAopAnnotation",
216+
"anotherMethodWithAopAnnotation", "methodWithoutAopAnnotation", "extraMethod1", "extraMethod2");
217+
}
218+
219+
}

Diff for: spring-ai-core/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackProviderTests.java

+15
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
116
package org.springframework.ai.tool.method;
217

318
import org.junit.jupiter.api.Nested;

Diff for: spring-ai-core/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackTests.java

+16
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
117
package org.springframework.ai.tool.method;
218

319
import com.fasterxml.jackson.core.type.TypeReference;

0 commit comments

Comments
 (0)