Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xds: add support for custom per-target credentials on the transport. #11951

Merged
merged 8 commits into from
Mar 21, 2025
Merged
54 changes: 38 additions & 16 deletions xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.base.Preconditions.checkNotNull;

import com.google.common.annotations.VisibleForTesting;
import io.grpc.CallCredentials;
import io.grpc.CallOptions;
import io.grpc.ChannelCredentials;
import io.grpc.ClientCall;
Expand All @@ -34,35 +35,50 @@

final class GrpcXdsTransportFactory implements XdsTransportFactory {

static final GrpcXdsTransportFactory DEFAULT_XDS_TRANSPORT_FACTORY =
new GrpcXdsTransportFactory();
private final CallCredentials callCredentials;

GrpcXdsTransportFactory(CallCredentials callCredentials) {
this.callCredentials = callCredentials;
}

@Override
public XdsTransport create(Bootstrapper.ServerInfo serverInfo) {
return new GrpcXdsTransport(serverInfo);
return new GrpcXdsTransport(serverInfo, callCredentials);
}

@VisibleForTesting
public XdsTransport createForTest(ManagedChannel channel) {
return new GrpcXdsTransport(channel);
return new GrpcXdsTransport(channel, callCredentials);
}

@VisibleForTesting
static class GrpcXdsTransport implements XdsTransport {

private final ManagedChannel channel;
private final CallCredentials callCredentials;

public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo) {
this(serverInfo, null);
}

Check warning on line 62 in xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java#L61-L62

Added lines #L61 - L62 were not covered by tests

@VisibleForTesting
public GrpcXdsTransport(ManagedChannel channel) {
this(channel, null);
}

public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, CallCredentials callCredentials) {
String target = serverInfo.target();
ChannelCredentials channelCredentials = (ChannelCredentials) serverInfo.implSpecificConfig();
this.channel = Grpc.newChannelBuilder(target, channelCredentials)
.keepAliveTime(5, TimeUnit.MINUTES)
.build();
this.callCredentials = callCredentials;
}

@VisibleForTesting
public GrpcXdsTransport(ManagedChannel channel) {
public GrpcXdsTransport(ManagedChannel channel, CallCredentials callCredentials) {
this.channel = checkNotNull(channel, "channel");
this.callCredentials = callCredentials;
}

@Override
Expand All @@ -72,7 +88,8 @@
MethodDescriptor.Marshaller<RespT> respMarshaller) {
Context prevContext = Context.ROOT.attach();
try {
return new XdsStreamingCall<>(fullMethodName, reqMarshaller, respMarshaller);
return new XdsStreamingCall<>(
fullMethodName, reqMarshaller, respMarshaller, callCredentials);
} finally {
Context.ROOT.detach(prevContext);
}
Expand All @@ -89,16 +106,21 @@

private final ClientCall<ReqT, RespT> call;

public XdsStreamingCall(String methodName, MethodDescriptor.Marshaller<ReqT> reqMarshaller,
MethodDescriptor.Marshaller<RespT> respMarshaller) {
this.call = channel.newCall(
MethodDescriptor.<ReqT, RespT>newBuilder()
.setFullMethodName(methodName)
.setType(MethodDescriptor.MethodType.BIDI_STREAMING)
.setRequestMarshaller(reqMarshaller)
.setResponseMarshaller(respMarshaller)
.build(),
CallOptions.DEFAULT); // TODO(zivy): support waitForReady
public XdsStreamingCall(
String methodName,
MethodDescriptor.Marshaller<ReqT> reqMarshaller,
MethodDescriptor.Marshaller<RespT> respMarshaller,
CallCredentials callCredentials) {
this.call =
channel.newCall(
MethodDescriptor.<ReqT, RespT>newBuilder()
.setFullMethodName(methodName)
.setType(MethodDescriptor.MethodType.BIDI_STREAMING)
.setRequestMarshaller(reqMarshaller)
.setResponseMarshaller(respMarshaller)
.build(),
CallOptions.DEFAULT.withCallCredentials(
callCredentials)); // TODO(zivy): support waitForReady
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.grpc.xds;

import io.grpc.CallCredentials;
import io.grpc.Internal;
import io.grpc.MetricRecorder;
import io.grpc.internal.ObjectPool;
Expand All @@ -42,6 +43,18 @@

public static ObjectPool<XdsClient> getOrCreate(String target, MetricRecorder metricRecorder)
throws XdsInitializationException {
return SharedXdsClientPoolProvider.getDefaultProvider().getOrCreate(target, metricRecorder);
return getOrCreate(target, metricRecorder, null);

Check warning on line 46 in xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java#L46

Added line #L46 was not covered by tests
}

public static ObjectPool<XdsClient> getOrCreate(
String target, CallCredentials transportCallCredentials) throws XdsInitializationException {
return getOrCreate(target, new MetricRecorder() {}, transportCallCredentials);

Check warning on line 51 in xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java#L51

Added line #L51 was not covered by tests
}

public static ObjectPool<XdsClient> getOrCreate(
String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials)
throws XdsInitializationException {
return SharedXdsClientPoolProvider.getDefaultProvider()
.getOrCreate(target, metricRecorder, transportCallCredentials);

Check warning on line 58 in xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java#L57-L58

Added lines #L57 - L58 were not covered by tests
}
}
50 changes: 36 additions & 14 deletions xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
package io.grpc.xds;

import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import io.grpc.CallCredentials;
import io.grpc.MetricRecorder;
import io.grpc.internal.ExponentialBackoffPolicy;
import io.grpc.internal.GrpcUtil;
Expand Down Expand Up @@ -87,6 +87,12 @@ public ObjectPool<XdsClient> get(String target) {
@Override
public ObjectPool<XdsClient> getOrCreate(String target, MetricRecorder metricRecorder)
throws XdsInitializationException {
return getOrCreate(target, metricRecorder, null);
}

public ObjectPool<XdsClient> getOrCreate(
String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials)
throws XdsInitializationException {
ObjectPool<XdsClient> ref = targetToXdsClientMap.get(target);
if (ref == null) {
synchronized (lock) {
Expand All @@ -102,7 +108,9 @@ public ObjectPool<XdsClient> getOrCreate(String target, MetricRecorder metricRec
if (bootstrapInfo.servers().isEmpty()) {
throw new XdsInitializationException("No xDS server provided");
}
ref = new RefCountedXdsClientObjectPool(bootstrapInfo, target, metricRecorder);
ref =
new RefCountedXdsClientObjectPool(
bootstrapInfo, target, metricRecorder, transportCallCredentials);
targetToXdsClientMap.put(target, ref);
}
}
Expand All @@ -126,6 +134,7 @@ class RefCountedXdsClientObjectPool implements ObjectPool<XdsClient> {
private final BootstrapInfo bootstrapInfo;
private final String target; // The target associated with the xDS client.
private final MetricRecorder metricRecorder;
private final CallCredentials transportCallCredentials;
private final Object lock = new Object();
@GuardedBy("lock")
private ScheduledExecutorService scheduler;
Expand All @@ -137,11 +146,21 @@ class RefCountedXdsClientObjectPool implements ObjectPool<XdsClient> {
private XdsClientMetricReporterImpl metricReporter;

@VisibleForTesting
RefCountedXdsClientObjectPool(BootstrapInfo bootstrapInfo, String target,
MetricRecorder metricRecorder) {
RefCountedXdsClientObjectPool(
BootstrapInfo bootstrapInfo, String target, MetricRecorder metricRecorder) {
this(bootstrapInfo, target, metricRecorder, null);
}

@VisibleForTesting
RefCountedXdsClientObjectPool(
BootstrapInfo bootstrapInfo,
String target,
MetricRecorder metricRecorder,
CallCredentials transportCallCredentials) {
this.bootstrapInfo = checkNotNull(bootstrapInfo);
this.target = target;
this.metricRecorder = metricRecorder;
this.transportCallCredentials = transportCallCredentials;
}

@Override
Expand All @@ -153,16 +172,19 @@ public XdsClient getObject() {
}
scheduler = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE);
metricReporter = new XdsClientMetricReporterImpl(metricRecorder, target);
xdsClient = new XdsClientImpl(
DEFAULT_XDS_TRANSPORT_FACTORY,
bootstrapInfo,
scheduler,
BACKOFF_POLICY_PROVIDER,
GrpcUtil.STOPWATCH_SUPPLIER,
TimeProvider.SYSTEM_TIME_PROVIDER,
MessagePrinter.INSTANCE,
new TlsContextManagerImpl(bootstrapInfo),
metricReporter);
GrpcXdsTransportFactory xdsTransportFactory =
new GrpcXdsTransportFactory(transportCallCredentials);
xdsClient =
new XdsClientImpl(
xdsTransportFactory,
bootstrapInfo,
scheduler,
BACKOFF_POLICY_PROVIDER,
GrpcUtil.STOPWATCH_SUPPLIER,
TimeProvider.SYSTEM_TIME_PROVIDER,
MessagePrinter.INSTANCE,
new TlsContextManagerImpl(bootstrapInfo),
metricReporter);
metricReporter.setXdsClient(xdsClient);
}
refCount++;
Expand Down
3 changes: 1 addition & 2 deletions xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
Expand Down Expand Up @@ -4193,7 +4192,7 @@ public void serverFailureMetricReport_forRetryAndBackoff() {
private XdsClientImpl createXdsClient(String serverUri) {
BootstrapInfo bootstrapInfo = buildBootStrap(serverUri);
return new XdsClientImpl(
DEFAULT_XDS_TRANSPORT_FACTORY,
new GrpcXdsTransportFactory(null),
bootstrapInfo,
fakeClock.getScheduledExecutorService(),
backoffPolicyProvider,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,10 @@ public void onCompleted() {
@Test
public void callApis() throws Exception {
XdsTransportFactory.XdsTransport xdsTransport =
GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.create(
Bootstrapper.ServerInfo.create("localhost:" + server.getPort(),
InsecureChannelCredentials.create()));
new GrpcXdsTransportFactory(null)
.create(
Bootstrapper.ServerInfo.create(
"localhost:" + server.getPort(), InsecureChannelCredentials.create()));
MethodDescriptor<DiscoveryRequest, DiscoveryResponse> methodDescriptor =
AggregatedDiscoveryServiceGrpc.getStreamAggregatedResourcesMethod();
XdsTransportFactory.StreamingCall<DiscoveryRequest, DiscoveryResponse> streamingCall =
Expand Down
14 changes: 9 additions & 5 deletions xds/src/test/java/io/grpc/xds/LoadReportClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,15 @@ public void cancelled(Context context) {
when(backoffPolicy2.nextBackoffNanos())
.thenReturn(TimeUnit.SECONDS.toNanos(2L), TimeUnit.SECONDS.toNanos(20L));
addFakeStatsData();
lrsClient = new LoadReportClient(loadStatsManager,
GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.createForTest(channel),
NODE,
syncContext, fakeClock.getScheduledExecutorService(), backoffPolicyProvider,
fakeClock.getStopwatchSupplier());
lrsClient =
new LoadReportClient(
loadStatsManager,
new GrpcXdsTransportFactory(null).createForTest(channel),
NODE,
syncContext,
fakeClock.getScheduledExecutorService(),
backoffPolicyProvider,
fakeClock.getStopwatchSupplier());
syncContext.execute(new Runnable() {
@Override
public void run() {
Expand Down
Loading
Loading