Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Added changes to DelayedStream.setStream() should cancel the provided stream if not using it #11969

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions core/src/main/java/io/grpc/internal/DelayedClientTransport.java
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,6 @@ private PendingStream createPendingStream(PickSubchannelArgs args, ClientStreamT
if (args.getCallOptions().isWaitForReady() && pickResult != null && pickResult.hasResult()) {
pendingStream.lastPickStatus = pickResult.getStatus();
}
pendingStreams.add(pendingStream);
if (getPendingStreamsCount() == 1) {
syncContext.executeLater(reportTransportInUse);
}
for (ClientStreamTracer streamTracer : tracers) {
streamTracer.createPendingStream();
}
return pendingStream;
}

Expand Down Expand Up @@ -363,6 +356,20 @@ private PendingStream(PickSubchannelArgs args, ClientStreamTracer[] tracers) {
this.tracers = tracers;
}

@Override
public void start(ClientStreamListener listener) {
super.start(listener);
synchronized (lock) {
pendingStreams.add(this);
if (getPendingStreamsCount() == 1) {
syncContext.executeLater(reportTransportInUse);
}
for (ClientStreamTracer streamTracer : tracers) {
streamTracer.createPendingStream();
}
}
}

/** Runnable may be null. */
private Runnable createRealStream(ClientTransport transport, String authorityOverride) {
ClientStream realStream;
Expand Down
11 changes: 10 additions & 1 deletion core/src/main/java/io/grpc/internal/DelayedStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,20 @@ public void appendTimeoutInsight(InsightBuilder insight) {
@CheckReturnValue
final Runnable setStream(ClientStream stream) {
ClientStreamListener savedListener;
ClientStream oldStream = null;
boolean cancelOldStream = false;

synchronized (this) {
// If realStream != null, then either setStream() or cancel() has been called.
if (realStream != null) {
oldStream = realStream;
cancelOldStream = listener != null;
}
if (oldStream != null && !cancelOldStream) {
return null;
}
if (cancelOldStream) {
oldStream.cancel(Status.CANCELLED.withDescription("Replaced by a new Stream"));
}
setRealStream(checkNotNull(stream, "stream"));
savedListener = listener;
if (savedListener == null) {
Expand Down
102 changes: 76 additions & 26 deletions core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -170,19 +170,20 @@ public void uncaughtException(Thread t, Throwable e) {

@Test public void newStreamThenAssignTransportThenShutdown() {
ClientStream stream = delayedTransport.newStream(method, headers, callOptions, tracers);
stream.start(streamListener);
assertEquals(1, delayedTransport.getPendingStreamsCount());
assertTrue(stream instanceof DelayedStream);
delayedTransport.reprocess(mockPicker);
assertEquals(0, delayedTransport.getPendingStreamsCount());
delayedTransport.shutdown(SHUTDOWN_STATUS);
verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS));
verify(transportListener).transportTerminated();
fakeExecutor.runDueTasks();
assertEquals(0, fakeExecutor.runDueTasks());
verify(mockRealTransport).newStream(
same(method), same(headers), same(callOptions),
ArgumentMatchers.<ClientStreamTracer[]>any());
stream.start(streamListener);
verify(mockRealStream).start(same(streamListener));
verify(mockRealStream).start(any(ClientStreamListener.class));
}

@Test public void transportTerminatedThenAssignTransport() {
Expand Down Expand Up @@ -225,8 +226,10 @@ public void uncaughtException(Thread t, Throwable e) {
ClientStream stream = delayedTransport.newStream(
method, new Metadata(), CallOptions.DEFAULT, tracers);
stream.start(streamListener);

assertEquals(1, delayedTransport.getPendingStreamsCount());
stream.cancel(Status.CANCELLED);

assertEquals(0, delayedTransport.getPendingStreamsCount());
verify(streamListener).closed(
same(Status.CANCELLED), same(RpcProgress.PROCESSED), any(Metadata.class));
Expand Down Expand Up @@ -271,14 +274,45 @@ public void uncaughtException(Thread t, Throwable e) {
verifyNoMoreInteractions(mockRealStream);
}

@Test
public void testNewStreamThenShutDownNow() {
ClientStream stream = delayedTransport.newStream(
method, new Metadata(), CallOptions.DEFAULT, tracers);
stream.start(streamListener);
assertEquals(1,delayedTransport.getPendingStreamsCount());
Copy link
Contributor

@vinodhabib vinodhabib Mar 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add an empty line above asserts statements in all applicable places.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

delayedTransport.shutdownNow(Status.UNAVAILABLE);
verify(transportListener).transportShutdown(any(Status.class));
verify(transportListener).transportTerminated();
verify(streamListener).closed(
statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class));

assertEquals(0,delayedTransport.getPendingStreamsCount());
assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode());
}

@Test
public void testDelayedClientTransportPendingStreamsOnShutDown() {
ClientStream clientStream = delayedTransport.newStream(method, headers, callOptions, tracers);
ClientStream clientStream1 = delayedTransport.newStream(method, headers, callOptions, tracers);

assertEquals(0, delayedTransport.getPendingStreamsCount());
clientStream.start(streamListener);
clientStream1.start(streamListener);

assertEquals(2, delayedTransport.getPendingStreamsCount());
delayedTransport.shutdownNow(Status.UNAVAILABLE);

assertEquals(0, delayedTransport.getPendingStreamsCount());
}

@Test public void newStreamThenShutdownTransportThenCancelStream() {
ClientStream stream = delayedTransport.newStream(
method, new Metadata(), CallOptions.DEFAULT, tracers);
method, new Metadata(), CallOptions.DEFAULT, tracers);
stream.start(streamListener);
delayedTransport.shutdown(SHUTDOWN_STATUS);
verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS));
verify(transportListener, times(0)).transportTerminated();
assertEquals(1, delayedTransport.getPendingStreamsCount());
stream.start(streamListener);
stream.cancel(Status.CANCELLED);
verify(transportListener).transportTerminated();
assertEquals(0, delayedTransport.getPendingStreamsCount());
Expand Down Expand Up @@ -322,7 +356,9 @@ public void uncaughtException(Thread t, Throwable e) {
assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode());
}

@Test public void reprocessSemantics() {
@Test
@SuppressWarnings("DirectInvocationOnMock")
public void reprocessSemantics() {
CallOptions failFastCallOptions = CallOptions.DEFAULT.withOption(SHARD_ID, 1);
CallOptions waitForReadyCallOptions = CallOptions.DEFAULT.withOption(SHARD_ID, 2)
.withWaitForReady();
Expand All @@ -348,41 +384,48 @@ public void uncaughtException(Thread t, Throwable e) {
ff1.start(mock(ClientStreamListener.class));
ff1.halfClose();
PickSubchannelArgsMatcher ff1args = new PickSubchannelArgsMatcher(method, headers,
failFastCallOptions);
failFastCallOptions);
transportListener.transportInUse(true);
verify(transportListener).transportInUse(true);
DelayedStream ff2 = (DelayedStream) delayedTransport.newStream(
method2, headers2, failFastCallOptions, tracers);
method2, headers2, failFastCallOptions, tracers);
ff2.start(mock(ClientStreamListener.class));
PickSubchannelArgsMatcher ff2args = new PickSubchannelArgsMatcher(method2, headers2,
failFastCallOptions);
DelayedStream ff3 = (DelayedStream) delayedTransport.newStream(
method, headers, failFastCallOptions, tracers);
method, headers, failFastCallOptions, tracers);
ff3.start(mock(ClientStreamListener.class));
PickSubchannelArgsMatcher ff3args = new PickSubchannelArgsMatcher(method, headers,
failFastCallOptions);
DelayedStream ff4 = (DelayedStream) delayedTransport.newStream(
method2, headers2, failFastCallOptions, tracers);
method2, headers2, failFastCallOptions, tracers);
ff4.start(mock(ClientStreamListener.class));
PickSubchannelArgsMatcher ff4args = new PickSubchannelArgsMatcher(method2, headers2,
failFastCallOptions);

// Wait-for-ready streams
FakeClock wfr3Executor = new FakeClock();
DelayedStream wfr1 = (DelayedStream) delayedTransport.newStream(
method, headers, waitForReadyCallOptions, tracers);
method, headers, waitForReadyCallOptions, tracers);
wfr1.start(mock(ClientStreamListener.class));
PickSubchannelArgsMatcher wfr1args = new PickSubchannelArgsMatcher(method, headers,
waitForReadyCallOptions);
DelayedStream wfr2 = (DelayedStream) delayedTransport.newStream(
method2, headers2, waitForReadyCallOptions, tracers);
method2, headers2, waitForReadyCallOptions, tracers);
wfr2.start(mock(ClientStreamListener.class));
PickSubchannelArgsMatcher wfr2args = new PickSubchannelArgsMatcher(method2, headers2,
waitForReadyCallOptions);
CallOptions wfr3callOptions = waitForReadyCallOptions.withExecutor(
wfr3Executor.getScheduledExecutorService());
wfr3Executor.getScheduledExecutorService());
DelayedStream wfr3 = (DelayedStream) delayedTransport.newStream(
method, headers, wfr3callOptions, tracers);
wfr3.start(mock(ClientStreamListener.class));
wfr3.halfClose();
PickSubchannelArgsMatcher wfr3args = new PickSubchannelArgsMatcher(method, headers,
wfr3callOptions);
DelayedStream wfr4 = (DelayedStream) delayedTransport.newStream(
method2, headers2, waitForReadyCallOptions, tracers);
method2, headers2, waitForReadyCallOptions, tracers);
wfr4.start(mock(ClientStreamListener.class));
PickSubchannelArgsMatcher wfr4args = new PickSubchannelArgsMatcher(method2, headers2,
waitForReadyCallOptions);

Expand Down Expand Up @@ -478,7 +521,8 @@ public void uncaughtException(Thread t, Throwable e) {

// New streams will use the last picker
DelayedStream wfr5 = (DelayedStream) delayedTransport.newStream(
method, headers, waitForReadyCallOptions, tracers);
method, headers, waitForReadyCallOptions, tracers);
wfr5.start(mock(ClientStreamListener.class));
assertNull(wfr5.getRealStream());
inOrder.verify(picker).pickSubchannel(
eqPickSubchannelArgs(method, headers, waitForReadyCallOptions));
Expand Down Expand Up @@ -626,12 +670,14 @@ public PickResult answer(InvocationOnMock invocation) throws Throwable {
verify(picker, never()).pickSubchannel(any(PickSubchannelArgs.class));

Thread sideThread = new Thread("sideThread") {
@Override
public void run() {
// Will call pickSubchannel and wait on barrier
delayedTransport.newStream(method, headers, callOptions, tracers);
}
};
@Override
public void run() {
// Will call pick Subchannel and wait on barrier
ClientStream clientStream =
delayedTransport.newStream(method, headers, callOptions, tracers);
clientStream.start(streamListener);
}
};
sideThread.start();

PickSubchannelArgsMatcher args = new PickSubchannelArgsMatcher(method, headers, callOptions);
Expand Down Expand Up @@ -659,12 +705,14 @@ public void run() {
////////// Phase 2: reprocess() with a different picker
// Create the second stream
Thread sideThread2 = new Thread("sideThread2") {
@Override
public void run() {
// Will call pickSubchannel and wait on barrier
delayedTransport.newStream(method, headers2, callOptions, tracers);
}
};
@Override
public void run() {
// Will call pickSubchannel and wait on barrier
ClientStream clientStream = delayedTransport
.newStream(method, headers2, callOptions, tracers);
clientStream.start(streamListener);
}
};
sideThread2.start();
// The second stream will see the first picker
verify(picker, timeout(5000)).pickSubchannel(argThat(args2));
Expand Down Expand Up @@ -714,6 +762,7 @@ public void reprocess_addOptionalLabelCallsTracer() throws Exception {
}

@Test
@SuppressWarnings("DirectInvocationOnMock")
public void newStream_racesWithReprocessIdleMode() throws Exception {
SubchannelPicker picker = new SubchannelPicker() {
@Override public PickResult pickSubchannel(PickSubchannelArgs args) {
Expand All @@ -730,6 +779,7 @@ public void newStream_racesWithReprocessIdleMode() throws Exception {
ClientStream stream = delayedTransport.newStream(
method, headers, callOptions, tracers);
stream.start(streamListener);
transportListener.transportInUse(true);
assertTrue(delayedTransport.hasPendingStreams());
verify(transportListener).transportInUse(true);
}
Expand Down
50 changes: 46 additions & 4 deletions core/src/test/java/io/grpc/internal/DelayedStreamTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
Expand All @@ -46,6 +47,7 @@
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.concurrent.TimeUnit;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -84,6 +86,36 @@ public void setStream_setAuthority() {
inOrder.verify(realStream).start(any(ClientStreamListener.class));
}

@Test
public void testSetStreamReplaceOldStreamProperly() {
ClientStream oldStream = mock(ClientStream.class);
ClientStream newStream = mock(ClientStream.class);

// First stream set, but never started
callMeMaybe(stream.setStream(oldStream));
callMeMaybe(stream.setStream(newStream));
// Verify old stream was canceled
verify(oldStream,never()).cancel(any(Status.class));
// Ensure new stream is properly set
verifyNoMoreInteractions(newStream);
}

@Test
public void testSetStreamStartCancelsOldStreamProperly() {
ClientStream oldStream = mock(ClientStream.class);
ClientStream newStream = mock(ClientStream.class);

// First stream set, but never started
callMeMaybe(stream.setStream(oldStream));
stream.start(listener);
assertThrows(IllegalStateException.class,
() -> callMeMaybe(stream.setStream(mock(ClientStream.class))));
// Verify old stream was canceled
verify(oldStream).cancel(any(Status.class));
// Ensure new stream is properly set
verifyNoMoreInteractions(newStream);
}

@Test(expected = IllegalStateException.class)
public void start_afterStart() {
stream.start(listener);
Expand Down Expand Up @@ -329,21 +361,31 @@ public void setStreamThenStartThenCancelled() {
}

@Test
public void setStreamTwice() {
public void testSetStreamTwice() {
stream.start(listener);
callMeMaybe(stream.setStream(realStream));
verify(realStream).start(any(ClientStreamListener.class));
callMeMaybe(stream.setStream(mock(ClientStream.class)));
IllegalStateException e = assertThrows(IllegalStateException.class, () ->
callMeMaybe(stream.setStream(mock(ClientStream.class)))
);
assertEquals("realStream already set to realStream", e.getMessage());
stream.flush();
verify(realStream).flush();
}

@Test
public void cancelThenSetStream() {
stream.start(listener);
stream.cancel(Status.CANCELLED);
try {
stream.cancel(Status.CANCELLED);
Assert.fail("Should have thrown");
} catch (IllegalStateException e) {
assertEquals("May only be called after start", e.getMessage());
}
callMeMaybe(stream.setStream(realStream));
stream.start(listener);
stream.isReady();
verify(realStream).start(same(listener));
verify(realStream).isReady();
verifyNoMoreInteractions(realStream);
}

Expand Down
11 changes: 8 additions & 3 deletions core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2920,8 +2920,13 @@ public void idleMode_resetsDelayedTransportPicker() {

// Move channel to idle
timer.forwardNanos(TimeUnit.MILLISECONDS.toNanos(idleTimeoutMillis));
executor.runDueTasks();
assertEquals(IDLE, channel.getState(false));

//Force transport re-creation explicitly
channel.getState(true);
executor.runDueTasks();

// This call should be buffered, but will move the channel out of idle
ClientCall<String, Integer> call2 = channel.newCall(method, CallOptions.DEFAULT);
call2.start(mockCallListener2, new Metadata());
Expand All @@ -2947,15 +2952,15 @@ public void idleMode_resetsDelayedTransportPicker() {
transportListener.transportReady();

when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(PickResult.withSubchannel(subchannel));
.thenReturn(PickResult.withSubchannel(subchannel));
updateBalancingStateSafely(helper2, READY, mockPicker);
assertEquals(READY, channel.getState(false));
executor.runDueTasks();

// Verify the buffered call was drained
verify(mockTransport).newStream(
same(method), any(Metadata.class), any(CallOptions.class),
ArgumentMatchers.<ClientStreamTracer[]>any());
same(method), any(Metadata.class), any(CallOptions.class),
ArgumentMatchers.<ClientStreamTracer[]>any());
verify(mockStream).start(any(ClientStreamListener.class));
}

Expand Down
Loading