Skip to content

Commit 1958e42

Browse files
authored
xds: add support for custom per-target credentials on the transport (#11951)
1 parent 94f8e93 commit 1958e42

8 files changed

+198
-51
lines changed

xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java

+38-16
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import static com.google.common.base.Preconditions.checkNotNull;
2020

2121
import com.google.common.annotations.VisibleForTesting;
22+
import io.grpc.CallCredentials;
2223
import io.grpc.CallOptions;
2324
import io.grpc.ChannelCredentials;
2425
import io.grpc.ClientCall;
@@ -34,35 +35,50 @@
3435

3536
final class GrpcXdsTransportFactory implements XdsTransportFactory {
3637

37-
static final GrpcXdsTransportFactory DEFAULT_XDS_TRANSPORT_FACTORY =
38-
new GrpcXdsTransportFactory();
38+
private final CallCredentials callCredentials;
39+
40+
GrpcXdsTransportFactory(CallCredentials callCredentials) {
41+
this.callCredentials = callCredentials;
42+
}
3943

4044
@Override
4145
public XdsTransport create(Bootstrapper.ServerInfo serverInfo) {
42-
return new GrpcXdsTransport(serverInfo);
46+
return new GrpcXdsTransport(serverInfo, callCredentials);
4347
}
4448

4549
@VisibleForTesting
4650
public XdsTransport createForTest(ManagedChannel channel) {
47-
return new GrpcXdsTransport(channel);
51+
return new GrpcXdsTransport(channel, callCredentials);
4852
}
4953

5054
@VisibleForTesting
5155
static class GrpcXdsTransport implements XdsTransport {
5256

5357
private final ManagedChannel channel;
58+
private final CallCredentials callCredentials;
5459

5560
public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo) {
61+
this(serverInfo, null);
62+
}
63+
64+
@VisibleForTesting
65+
public GrpcXdsTransport(ManagedChannel channel) {
66+
this(channel, null);
67+
}
68+
69+
public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, CallCredentials callCredentials) {
5670
String target = serverInfo.target();
5771
ChannelCredentials channelCredentials = (ChannelCredentials) serverInfo.implSpecificConfig();
5872
this.channel = Grpc.newChannelBuilder(target, channelCredentials)
5973
.keepAliveTime(5, TimeUnit.MINUTES)
6074
.build();
75+
this.callCredentials = callCredentials;
6176
}
6277

6378
@VisibleForTesting
64-
public GrpcXdsTransport(ManagedChannel channel) {
79+
public GrpcXdsTransport(ManagedChannel channel, CallCredentials callCredentials) {
6580
this.channel = checkNotNull(channel, "channel");
81+
this.callCredentials = callCredentials;
6682
}
6783

6884
@Override
@@ -72,7 +88,8 @@ public <ReqT, RespT> StreamingCall<ReqT, RespT> createStreamingCall(
7288
MethodDescriptor.Marshaller<RespT> respMarshaller) {
7389
Context prevContext = Context.ROOT.attach();
7490
try {
75-
return new XdsStreamingCall<>(fullMethodName, reqMarshaller, respMarshaller);
91+
return new XdsStreamingCall<>(
92+
fullMethodName, reqMarshaller, respMarshaller, callCredentials);
7693
} finally {
7794
Context.ROOT.detach(prevContext);
7895
}
@@ -89,16 +106,21 @@ private class XdsStreamingCall<ReqT, RespT> implements
89106

90107
private final ClientCall<ReqT, RespT> call;
91108

92-
public XdsStreamingCall(String methodName, MethodDescriptor.Marshaller<ReqT> reqMarshaller,
93-
MethodDescriptor.Marshaller<RespT> respMarshaller) {
94-
this.call = channel.newCall(
95-
MethodDescriptor.<ReqT, RespT>newBuilder()
96-
.setFullMethodName(methodName)
97-
.setType(MethodDescriptor.MethodType.BIDI_STREAMING)
98-
.setRequestMarshaller(reqMarshaller)
99-
.setResponseMarshaller(respMarshaller)
100-
.build(),
101-
CallOptions.DEFAULT); // TODO(zivy): support waitForReady
109+
public XdsStreamingCall(
110+
String methodName,
111+
MethodDescriptor.Marshaller<ReqT> reqMarshaller,
112+
MethodDescriptor.Marshaller<RespT> respMarshaller,
113+
CallCredentials callCredentials) {
114+
this.call =
115+
channel.newCall(
116+
MethodDescriptor.<ReqT, RespT>newBuilder()
117+
.setFullMethodName(methodName)
118+
.setType(MethodDescriptor.MethodType.BIDI_STREAMING)
119+
.setRequestMarshaller(reqMarshaller)
120+
.setResponseMarshaller(respMarshaller)
121+
.build(),
122+
CallOptions.DEFAULT.withCallCredentials(
123+
callCredentials)); // TODO(zivy): support waitForReady
102124
}
103125

104126
@Override

xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java

+9-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package io.grpc.xds;
1818

19+
import io.grpc.CallCredentials;
1920
import io.grpc.Internal;
2021
import io.grpc.MetricRecorder;
2122
import io.grpc.internal.ObjectPool;
@@ -42,6 +43,13 @@ public static ObjectPool<XdsClient> getOrCreate(String target)
4243

4344
public static ObjectPool<XdsClient> getOrCreate(String target, MetricRecorder metricRecorder)
4445
throws XdsInitializationException {
45-
return SharedXdsClientPoolProvider.getDefaultProvider().getOrCreate(target, metricRecorder);
46+
return getOrCreate(target, metricRecorder, null);
47+
}
48+
49+
public static ObjectPool<XdsClient> getOrCreate(
50+
String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials)
51+
throws XdsInitializationException {
52+
return SharedXdsClientPoolProvider.getDefaultProvider()
53+
.getOrCreate(target, metricRecorder, transportCallCredentials);
4654
}
4755
}

xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java

+36-14
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
package io.grpc.xds;
1818

1919
import static com.google.common.base.Preconditions.checkNotNull;
20-
import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY;
2120

2221
import com.google.common.annotations.VisibleForTesting;
2322
import com.google.common.collect.ImmutableList;
2423
import com.google.errorprone.annotations.concurrent.GuardedBy;
24+
import io.grpc.CallCredentials;
2525
import io.grpc.MetricRecorder;
2626
import io.grpc.internal.ExponentialBackoffPolicy;
2727
import io.grpc.internal.GrpcUtil;
@@ -87,6 +87,12 @@ public ObjectPool<XdsClient> get(String target) {
8787
@Override
8888
public ObjectPool<XdsClient> getOrCreate(String target, MetricRecorder metricRecorder)
8989
throws XdsInitializationException {
90+
return getOrCreate(target, metricRecorder, null);
91+
}
92+
93+
public ObjectPool<XdsClient> getOrCreate(
94+
String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials)
95+
throws XdsInitializationException {
9096
ObjectPool<XdsClient> ref = targetToXdsClientMap.get(target);
9197
if (ref == null) {
9298
synchronized (lock) {
@@ -102,7 +108,9 @@ public ObjectPool<XdsClient> getOrCreate(String target, MetricRecorder metricRec
102108
if (bootstrapInfo.servers().isEmpty()) {
103109
throw new XdsInitializationException("No xDS server provided");
104110
}
105-
ref = new RefCountedXdsClientObjectPool(bootstrapInfo, target, metricRecorder);
111+
ref =
112+
new RefCountedXdsClientObjectPool(
113+
bootstrapInfo, target, metricRecorder, transportCallCredentials);
106114
targetToXdsClientMap.put(target, ref);
107115
}
108116
}
@@ -126,6 +134,7 @@ class RefCountedXdsClientObjectPool implements ObjectPool<XdsClient> {
126134
private final BootstrapInfo bootstrapInfo;
127135
private final String target; // The target associated with the xDS client.
128136
private final MetricRecorder metricRecorder;
137+
private final CallCredentials transportCallCredentials;
129138
private final Object lock = new Object();
130139
@GuardedBy("lock")
131140
private ScheduledExecutorService scheduler;
@@ -137,11 +146,21 @@ class RefCountedXdsClientObjectPool implements ObjectPool<XdsClient> {
137146
private XdsClientMetricReporterImpl metricReporter;
138147

139148
@VisibleForTesting
140-
RefCountedXdsClientObjectPool(BootstrapInfo bootstrapInfo, String target,
141-
MetricRecorder metricRecorder) {
149+
RefCountedXdsClientObjectPool(
150+
BootstrapInfo bootstrapInfo, String target, MetricRecorder metricRecorder) {
151+
this(bootstrapInfo, target, metricRecorder, null);
152+
}
153+
154+
@VisibleForTesting
155+
RefCountedXdsClientObjectPool(
156+
BootstrapInfo bootstrapInfo,
157+
String target,
158+
MetricRecorder metricRecorder,
159+
CallCredentials transportCallCredentials) {
142160
this.bootstrapInfo = checkNotNull(bootstrapInfo);
143161
this.target = target;
144162
this.metricRecorder = metricRecorder;
163+
this.transportCallCredentials = transportCallCredentials;
145164
}
146165

147166
@Override
@@ -153,16 +172,19 @@ public XdsClient getObject() {
153172
}
154173
scheduler = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE);
155174
metricReporter = new XdsClientMetricReporterImpl(metricRecorder, target);
156-
xdsClient = new XdsClientImpl(
157-
DEFAULT_XDS_TRANSPORT_FACTORY,
158-
bootstrapInfo,
159-
scheduler,
160-
BACKOFF_POLICY_PROVIDER,
161-
GrpcUtil.STOPWATCH_SUPPLIER,
162-
TimeProvider.SYSTEM_TIME_PROVIDER,
163-
MessagePrinter.INSTANCE,
164-
new TlsContextManagerImpl(bootstrapInfo),
165-
metricReporter);
175+
GrpcXdsTransportFactory xdsTransportFactory =
176+
new GrpcXdsTransportFactory(transportCallCredentials);
177+
xdsClient =
178+
new XdsClientImpl(
179+
xdsTransportFactory,
180+
bootstrapInfo,
181+
scheduler,
182+
BACKOFF_POLICY_PROVIDER,
183+
GrpcUtil.STOPWATCH_SUPPLIER,
184+
TimeProvider.SYSTEM_TIME_PROVIDER,
185+
MessagePrinter.INSTANCE,
186+
new TlsContextManagerImpl(bootstrapInfo),
187+
metricReporter);
166188
metricReporter.setXdsClient(xdsClient);
167189
}
168190
refCount++;

xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import static com.google.common.truth.Truth.assertThat;
2020
import static com.google.common.truth.Truth.assertWithMessage;
21-
import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY;
2221
import static org.mockito.ArgumentMatchers.any;
2322
import static org.mockito.ArgumentMatchers.eq;
2423
import static org.mockito.ArgumentMatchers.isA;
@@ -4193,7 +4192,7 @@ public void serverFailureMetricReport_forRetryAndBackoff() {
41934192
private XdsClientImpl createXdsClient(String serverUri) {
41944193
BootstrapInfo bootstrapInfo = buildBootStrap(serverUri);
41954194
return new XdsClientImpl(
4196-
DEFAULT_XDS_TRANSPORT_FACTORY,
4195+
new GrpcXdsTransportFactory(null),
41974196
bootstrapInfo,
41984197
fakeClock.getScheduledExecutorService(),
41994198
backoffPolicyProvider,

xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java

+4-3
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,10 @@ public void onCompleted() {
9292
@Test
9393
public void callApis() throws Exception {
9494
XdsTransportFactory.XdsTransport xdsTransport =
95-
GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.create(
96-
Bootstrapper.ServerInfo.create("localhost:" + server.getPort(),
97-
InsecureChannelCredentials.create()));
95+
new GrpcXdsTransportFactory(null)
96+
.create(
97+
Bootstrapper.ServerInfo.create(
98+
"localhost:" + server.getPort(), InsecureChannelCredentials.create()));
9899
MethodDescriptor<DiscoveryRequest, DiscoveryResponse> methodDescriptor =
99100
AggregatedDiscoveryServiceGrpc.getStreamAggregatedResourcesMethod();
100101
XdsTransportFactory.StreamingCall<DiscoveryRequest, DiscoveryResponse> streamingCall =

xds/src/test/java/io/grpc/xds/LoadReportClientTest.java

+9-5
Original file line numberDiff line numberDiff line change
@@ -178,11 +178,15 @@ public void cancelled(Context context) {
178178
when(backoffPolicy2.nextBackoffNanos())
179179
.thenReturn(TimeUnit.SECONDS.toNanos(2L), TimeUnit.SECONDS.toNanos(20L));
180180
addFakeStatsData();
181-
lrsClient = new LoadReportClient(loadStatsManager,
182-
GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.createForTest(channel),
183-
NODE,
184-
syncContext, fakeClock.getScheduledExecutorService(), backoffPolicyProvider,
185-
fakeClock.getStopwatchSupplier());
181+
lrsClient =
182+
new LoadReportClient(
183+
loadStatsManager,
184+
new GrpcXdsTransportFactory(null).createForTest(channel),
185+
NODE,
186+
syncContext,
187+
fakeClock.getScheduledExecutorService(),
188+
backoffPolicyProvider,
189+
fakeClock.getStopwatchSupplier());
186190
syncContext.execute(new Runnable() {
187191
@Override
188192
public void run() {

xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java

+77
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,36 @@
1818

1919

2020
import static com.google.common.truth.Truth.assertThat;
21+
import static io.grpc.Metadata.ASCII_STRING_MARSHALLER;
2122
import static org.mockito.Mockito.verify;
2223
import static org.mockito.Mockito.verifyNoMoreInteractions;
2324
import static org.mockito.Mockito.when;
2425

26+
import com.google.auth.oauth2.AccessToken;
27+
import com.google.auth.oauth2.OAuth2Credentials;
28+
import com.google.common.util.concurrent.SettableFuture;
29+
import io.grpc.CallCredentials;
30+
import io.grpc.Grpc;
2531
import io.grpc.InsecureChannelCredentials;
32+
import io.grpc.InsecureServerCredentials;
33+
import io.grpc.Metadata;
2634
import io.grpc.MetricRecorder;
35+
import io.grpc.Server;
36+
import io.grpc.ServerCall;
37+
import io.grpc.ServerCallHandler;
38+
import io.grpc.ServerInterceptor;
39+
import io.grpc.auth.MoreCallCredentials;
2740
import io.grpc.internal.ObjectPool;
2841
import io.grpc.xds.SharedXdsClientPoolProvider.RefCountedXdsClientObjectPool;
42+
import io.grpc.xds.XdsListenerResource.LdsUpdate;
2943
import io.grpc.xds.client.Bootstrapper.BootstrapInfo;
3044
import io.grpc.xds.client.Bootstrapper.ServerInfo;
3145
import io.grpc.xds.client.EnvoyProtoData.Node;
3246
import io.grpc.xds.client.XdsClient;
47+
import io.grpc.xds.client.XdsClient.ResourceWatcher;
3348
import io.grpc.xds.client.XdsInitializationException;
3449
import java.util.Collections;
50+
import java.util.concurrent.TimeUnit;
3551
import org.junit.Rule;
3652
import org.junit.Test;
3753
import org.junit.rules.ExpectedException;
@@ -54,9 +70,12 @@ public class SharedXdsClientPoolProviderTest {
5470
private final Node node = Node.newBuilder().setId("SharedXdsClientPoolProviderTest").build();
5571
private final MetricRecorder metricRecorder = new MetricRecorder() {};
5672
private static final String DUMMY_TARGET = "dummy";
73+
static final Metadata.Key<String> AUTHORIZATION_METADATA_KEY =
74+
Metadata.Key.of("Authorization", ASCII_STRING_MARSHALLER);
5775

5876
@Mock
5977
private GrpcBootstrapperImpl bootstrapper;
78+
@Mock private ResourceWatcher<LdsUpdate> ldsResourceWatcher;
6079

6180
@Test
6281
public void noServer() throws XdsInitializationException {
@@ -138,4 +157,62 @@ public void refCountedXdsClientObjectPool_getObjectCreatesNewInstanceIfAlreadySh
138157
assertThat(xdsClient2).isNotSameInstanceAs(xdsClient1);
139158
xdsClientPool.returnObject(xdsClient2);
140159
}
160+
161+
private class CallCredsServerInterceptor implements ServerInterceptor {
162+
private SettableFuture<String> tokenFuture = SettableFuture.create();
163+
164+
@Override
165+
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
166+
ServerCall<ReqT, RespT> serverCall,
167+
Metadata metadata,
168+
ServerCallHandler<ReqT, RespT> next) {
169+
tokenFuture.set(metadata.get(AUTHORIZATION_METADATA_KEY));
170+
return next.startCall(serverCall, metadata);
171+
}
172+
173+
public String getTokenWithTimeout(long timeout, TimeUnit unit) throws Exception {
174+
return tokenFuture.get(timeout, unit);
175+
}
176+
}
177+
178+
@Test
179+
public void xdsClient_usesCallCredentials() throws Exception {
180+
// Set up fake xDS server
181+
XdsTestControlPlaneService fakeXdsService = new XdsTestControlPlaneService();
182+
CallCredsServerInterceptor callCredentialsInterceptor = new CallCredsServerInterceptor();
183+
Server xdsServer =
184+
Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create())
185+
.addService(fakeXdsService)
186+
.intercept(callCredentialsInterceptor)
187+
.build()
188+
.start();
189+
String xdsServerUri = "localhost:" + xdsServer.getPort();
190+
191+
// Set up bootstrap & xDS client pool provider
192+
ServerInfo server = ServerInfo.create(xdsServerUri, InsecureChannelCredentials.create());
193+
BootstrapInfo bootstrapInfo =
194+
BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build();
195+
when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo);
196+
SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(bootstrapper);
197+
198+
// Create custom xDS transport CallCredentials
199+
CallCredentials sampleCreds =
200+
MoreCallCredentials.from(
201+
OAuth2Credentials.create(new AccessToken("token", /* expirationTime= */ null)));
202+
203+
// Create xDS client that uses the CallCredentials on the transport
204+
ObjectPool<XdsClient> xdsClientPool =
205+
provider.getOrCreate("target", metricRecorder, sampleCreds);
206+
XdsClient xdsClient = xdsClientPool.getObject();
207+
xdsClient.watchXdsResource(
208+
XdsListenerResource.getInstance(), "someLDSresource", ldsResourceWatcher);
209+
210+
// Wait for xDS server to get the request and verify that it received the CallCredentials
211+
assertThat(callCredentialsInterceptor.getTokenWithTimeout(5, TimeUnit.SECONDS))
212+
.isEqualTo("Bearer token");
213+
214+
// Clean up
215+
xdsClientPool.returnObject(xdsClient);
216+
xdsServer.shutdownNow();
217+
}
141218
}

0 commit comments

Comments
 (0)