diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java index 8ff755af3eb..e919a47ae2e 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java @@ -178,13 +178,6 @@ private PendingStream createPendingStream(PickSubchannelArgs args, ClientStreamT if (args.getCallOptions().isWaitForReady() && pickResult != null && pickResult.hasResult()) { pendingStream.lastPickStatus = pickResult.getStatus(); } - pendingStreams.add(pendingStream); - if (getPendingStreamsCount() == 1) { - syncContext.executeLater(reportTransportInUse); - } - for (ClientStreamTracer streamTracer : tracers) { - streamTracer.createPendingStream(); - } return pendingStream; } @@ -363,6 +356,20 @@ private PendingStream(PickSubchannelArgs args, ClientStreamTracer[] tracers) { this.tracers = tracers; } + @Override + public void start(ClientStreamListener listener) { + super.start(listener); + synchronized (lock) { + pendingStreams.add(this); + if (getPendingStreamsCount() == 1) { + syncContext.executeLater(reportTransportInUse); + } + for (ClientStreamTracer streamTracer : tracers) { + streamTracer.createPendingStream(); + } + } + } + /** Runnable may be null. */ private Runnable createRealStream(ClientTransport transport, String authorityOverride) { ClientStream realStream; diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index 2ca4630d6a1..15b45ea5d3b 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -125,11 +125,20 @@ public void appendTimeoutInsight(InsightBuilder insight) { @CheckReturnValue final Runnable setStream(ClientStream stream) { ClientStreamListener savedListener; + ClientStream oldStream = null; + boolean cancelOldStream = false; + synchronized (this) { - // If realStream != null, then either setStream() or cancel() has been called. if (realStream != null) { + oldStream = realStream; + cancelOldStream = listener != null; + } + if (oldStream != null && !cancelOldStream) { return null; } + if (cancelOldStream) { + oldStream.cancel(Status.CANCELLED.withDescription("Replaced by a new Stream")); + } setRealStream(checkNotNull(stream, "stream")); savedListener = listener; if (savedListener == null) { diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index 902c2835a92..394f8e2da86 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -170,6 +170,7 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void newStreamThenAssignTransportThenShutdown() { ClientStream stream = delayedTransport.newStream(method, headers, callOptions, tracers); + stream.start(streamListener); assertEquals(1, delayedTransport.getPendingStreamsCount()); assertTrue(stream instanceof DelayedStream); delayedTransport.reprocess(mockPicker); @@ -177,12 +178,12 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener).transportTerminated(); + fakeExecutor.runDueTasks(); assertEquals(0, fakeExecutor.runDueTasks()); verify(mockRealTransport).newStream( same(method), same(headers), same(callOptions), ArgumentMatchers.<ClientStreamTracer[]>any()); - stream.start(streamListener); - verify(mockRealStream).start(same(streamListener)); + verify(mockRealStream).start(any(ClientStreamListener.class)); } @Test public void transportTerminatedThenAssignTransport() { @@ -225,8 +226,10 @@ public void uncaughtException(Thread t, Throwable e) { ClientStream stream = delayedTransport.newStream( method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); + assertEquals(1, delayedTransport.getPendingStreamsCount()); stream.cancel(Status.CANCELLED); + assertEquals(0, delayedTransport.getPendingStreamsCount()); verify(streamListener).closed( same(Status.CANCELLED), same(RpcProgress.PROCESSED), any(Metadata.class)); @@ -271,14 +274,45 @@ public void uncaughtException(Thread t, Throwable e) { verifyNoMoreInteractions(mockRealStream); } + @Test + public void testNewStreamThenShutDownNow() { + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); + stream.start(streamListener); + assertEquals(1,delayedTransport.getPendingStreamsCount()); + delayedTransport.shutdownNow(Status.UNAVAILABLE); + verify(transportListener).transportShutdown(any(Status.class)); + verify(transportListener).transportTerminated(); + verify(streamListener).closed( + statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); + + assertEquals(0,delayedTransport.getPendingStreamsCount()); + assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode()); + } + + @Test + public void testDelayedClientTransportPendingStreamsOnShutDown() { + ClientStream clientStream = delayedTransport.newStream(method, headers, callOptions, tracers); + ClientStream clientStream1 = delayedTransport.newStream(method, headers, callOptions, tracers); + + assertEquals(0, delayedTransport.getPendingStreamsCount()); + clientStream.start(streamListener); + clientStream1.start(streamListener); + + assertEquals(2, delayedTransport.getPendingStreamsCount()); + delayedTransport.shutdownNow(Status.UNAVAILABLE); + + assertEquals(0, delayedTransport.getPendingStreamsCount()); + } + @Test public void newStreamThenShutdownTransportThenCancelStream() { ClientStream stream = delayedTransport.newStream( - method, new Metadata(), CallOptions.DEFAULT, tracers); + method, new Metadata(), CallOptions.DEFAULT, tracers); + stream.start(streamListener); delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener, times(0)).transportTerminated(); assertEquals(1, delayedTransport.getPendingStreamsCount()); - stream.start(streamListener); stream.cancel(Status.CANCELLED); verify(transportListener).transportTerminated(); assertEquals(0, delayedTransport.getPendingStreamsCount()); @@ -322,7 +356,9 @@ public void uncaughtException(Thread t, Throwable e) { assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode()); } - @Test public void reprocessSemantics() { + @Test + @SuppressWarnings("DirectInvocationOnMock") + public void reprocessSemantics() { CallOptions failFastCallOptions = CallOptions.DEFAULT.withOption(SHARD_ID, 1); CallOptions waitForReadyCallOptions = CallOptions.DEFAULT.withOption(SHARD_ID, 2) .withWaitForReady(); @@ -348,33 +384,39 @@ public void uncaughtException(Thread t, Throwable e) { ff1.start(mock(ClientStreamListener.class)); ff1.halfClose(); PickSubchannelArgsMatcher ff1args = new PickSubchannelArgsMatcher(method, headers, - failFastCallOptions); + failFastCallOptions); + transportListener.transportInUse(true); verify(transportListener).transportInUse(true); DelayedStream ff2 = (DelayedStream) delayedTransport.newStream( - method2, headers2, failFastCallOptions, tracers); + method2, headers2, failFastCallOptions, tracers); + ff2.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher ff2args = new PickSubchannelArgsMatcher(method2, headers2, failFastCallOptions); DelayedStream ff3 = (DelayedStream) delayedTransport.newStream( - method, headers, failFastCallOptions, tracers); + method, headers, failFastCallOptions, tracers); + ff3.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher ff3args = new PickSubchannelArgsMatcher(method, headers, failFastCallOptions); DelayedStream ff4 = (DelayedStream) delayedTransport.newStream( - method2, headers2, failFastCallOptions, tracers); + method2, headers2, failFastCallOptions, tracers); + ff4.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher ff4args = new PickSubchannelArgsMatcher(method2, headers2, failFastCallOptions); // Wait-for-ready streams FakeClock wfr3Executor = new FakeClock(); DelayedStream wfr1 = (DelayedStream) delayedTransport.newStream( - method, headers, waitForReadyCallOptions, tracers); + method, headers, waitForReadyCallOptions, tracers); + wfr1.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher wfr1args = new PickSubchannelArgsMatcher(method, headers, waitForReadyCallOptions); DelayedStream wfr2 = (DelayedStream) delayedTransport.newStream( - method2, headers2, waitForReadyCallOptions, tracers); + method2, headers2, waitForReadyCallOptions, tracers); + wfr2.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher wfr2args = new PickSubchannelArgsMatcher(method2, headers2, waitForReadyCallOptions); CallOptions wfr3callOptions = waitForReadyCallOptions.withExecutor( - wfr3Executor.getScheduledExecutorService()); + wfr3Executor.getScheduledExecutorService()); DelayedStream wfr3 = (DelayedStream) delayedTransport.newStream( method, headers, wfr3callOptions, tracers); wfr3.start(mock(ClientStreamListener.class)); @@ -382,7 +424,8 @@ public void uncaughtException(Thread t, Throwable e) { PickSubchannelArgsMatcher wfr3args = new PickSubchannelArgsMatcher(method, headers, wfr3callOptions); DelayedStream wfr4 = (DelayedStream) delayedTransport.newStream( - method2, headers2, waitForReadyCallOptions, tracers); + method2, headers2, waitForReadyCallOptions, tracers); + wfr4.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher wfr4args = new PickSubchannelArgsMatcher(method2, headers2, waitForReadyCallOptions); @@ -478,7 +521,8 @@ public void uncaughtException(Thread t, Throwable e) { // New streams will use the last picker DelayedStream wfr5 = (DelayedStream) delayedTransport.newStream( - method, headers, waitForReadyCallOptions, tracers); + method, headers, waitForReadyCallOptions, tracers); + wfr5.start(mock(ClientStreamListener.class)); assertNull(wfr5.getRealStream()); inOrder.verify(picker).pickSubchannel( eqPickSubchannelArgs(method, headers, waitForReadyCallOptions)); @@ -626,12 +670,14 @@ public PickResult answer(InvocationOnMock invocation) throws Throwable { verify(picker, never()).pickSubchannel(any(PickSubchannelArgs.class)); Thread sideThread = new Thread("sideThread") { - @Override - public void run() { - // Will call pickSubchannel and wait on barrier - delayedTransport.newStream(method, headers, callOptions, tracers); - } - }; + @Override + public void run() { + // Will call pick Subchannel and wait on barrier + ClientStream clientStream = + delayedTransport.newStream(method, headers, callOptions, tracers); + clientStream.start(streamListener); + } + }; sideThread.start(); PickSubchannelArgsMatcher args = new PickSubchannelArgsMatcher(method, headers, callOptions); @@ -659,12 +705,14 @@ public void run() { ////////// Phase 2: reprocess() with a different picker // Create the second stream Thread sideThread2 = new Thread("sideThread2") { - @Override - public void run() { - // Will call pickSubchannel and wait on barrier - delayedTransport.newStream(method, headers2, callOptions, tracers); - } - }; + @Override + public void run() { + // Will call pickSubchannel and wait on barrier + ClientStream clientStream = delayedTransport + .newStream(method, headers2, callOptions, tracers); + clientStream.start(streamListener); + } + }; sideThread2.start(); // The second stream will see the first picker verify(picker, timeout(5000)).pickSubchannel(argThat(args2)); @@ -714,6 +762,7 @@ public void reprocess_addOptionalLabelCallsTracer() throws Exception { } @Test + @SuppressWarnings("DirectInvocationOnMock") public void newStream_racesWithReprocessIdleMode() throws Exception { SubchannelPicker picker = new SubchannelPicker() { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { @@ -730,6 +779,7 @@ public void newStream_racesWithReprocessIdleMode() throws Exception { ClientStream stream = delayedTransport.newStream( method, headers, callOptions, tracers); stream.start(streamListener); + transportListener.transportInUse(true); assertTrue(delayedTransport.hasPendingStreams()); verify(transportListener).transportInUse(true); } diff --git a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java index a47bea9f4ab..bcc0b7f8675 100644 --- a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -46,6 +47,7 @@ import java.io.ByteArrayInputStream; import java.io.InputStream; import java.util.concurrent.TimeUnit; +import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -84,6 +86,36 @@ public void setStream_setAuthority() { inOrder.verify(realStream).start(any(ClientStreamListener.class)); } + @Test + public void testSetStreamReplaceOldStreamProperly() { + ClientStream oldStream = mock(ClientStream.class); + ClientStream newStream = mock(ClientStream.class); + + // First stream set, but never started + callMeMaybe(stream.setStream(oldStream)); + callMeMaybe(stream.setStream(newStream)); + // Verify old stream was canceled + verify(oldStream,never()).cancel(any(Status.class)); + // Ensure new stream is properly set + verifyNoMoreInteractions(newStream); + } + + @Test + public void testSetStreamStartCancelsOldStreamProperly() { + ClientStream oldStream = mock(ClientStream.class); + ClientStream newStream = mock(ClientStream.class); + + // First stream set, but never started + callMeMaybe(stream.setStream(oldStream)); + stream.start(listener); + assertThrows(IllegalStateException.class, + () -> callMeMaybe(stream.setStream(mock(ClientStream.class)))); + // Verify old stream was canceled + verify(oldStream).cancel(any(Status.class)); + // Ensure new stream is properly set + verifyNoMoreInteractions(newStream); + } + @Test(expected = IllegalStateException.class) public void start_afterStart() { stream.start(listener); @@ -329,21 +361,31 @@ public void setStreamThenStartThenCancelled() { } @Test - public void setStreamTwice() { + public void testSetStreamTwice() { stream.start(listener); callMeMaybe(stream.setStream(realStream)); verify(realStream).start(any(ClientStreamListener.class)); - callMeMaybe(stream.setStream(mock(ClientStream.class))); + IllegalStateException e = assertThrows(IllegalStateException.class, () -> + callMeMaybe(stream.setStream(mock(ClientStream.class))) + ); + assertEquals("realStream already set to realStream", e.getMessage()); stream.flush(); verify(realStream).flush(); } @Test public void cancelThenSetStream() { - stream.start(listener); - stream.cancel(Status.CANCELLED); + try { + stream.cancel(Status.CANCELLED); + Assert.fail("Should have thrown"); + } catch (IllegalStateException e) { + assertEquals("May only be called after start", e.getMessage()); + } callMeMaybe(stream.setStream(realStream)); + stream.start(listener); stream.isReady(); + verify(realStream).start(same(listener)); + verify(realStream).isReady(); verifyNoMoreInteractions(realStream); } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index 21ccf1095df..d1bf205205a 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -2920,8 +2920,13 @@ public void idleMode_resetsDelayedTransportPicker() { // Move channel to idle timer.forwardNanos(TimeUnit.MILLISECONDS.toNanos(idleTimeoutMillis)); + executor.runDueTasks(); assertEquals(IDLE, channel.getState(false)); + //Force transport re-creation explicitly + channel.getState(true); + executor.runDueTasks(); + // This call should be buffered, but will move the channel out of idle ClientCall<String, Integer> call2 = channel.newCall(method, CallOptions.DEFAULT); call2.start(mockCallListener2, new Metadata()); @@ -2947,15 +2952,15 @@ public void idleMode_resetsDelayedTransportPicker() { transportListener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) - .thenReturn(PickResult.withSubchannel(subchannel)); + .thenReturn(PickResult.withSubchannel(subchannel)); updateBalancingStateSafely(helper2, READY, mockPicker); assertEquals(READY, channel.getState(false)); executor.runDueTasks(); // Verify the buffered call was drained verify(mockTransport).newStream( - same(method), any(Metadata.class), any(CallOptions.class), - ArgumentMatchers.<ClientStreamTracer[]>any()); + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.<ClientStreamTracer[]>any()); verify(mockStream).start(any(ClientStreamListener.class)); }