diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java
index 8ff755af3eb..e919a47ae2e 100644
--- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java
+++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java
@@ -178,13 +178,6 @@ private PendingStream createPendingStream(PickSubchannelArgs args, ClientStreamT
     if (args.getCallOptions().isWaitForReady() && pickResult != null && pickResult.hasResult()) {
       pendingStream.lastPickStatus = pickResult.getStatus();
     }
-    pendingStreams.add(pendingStream);
-    if (getPendingStreamsCount() == 1) {
-      syncContext.executeLater(reportTransportInUse);
-    }
-    for (ClientStreamTracer streamTracer : tracers) {
-      streamTracer.createPendingStream();
-    }
     return pendingStream;
   }
 
@@ -363,6 +356,20 @@ private PendingStream(PickSubchannelArgs args, ClientStreamTracer[] tracers) {
       this.tracers = tracers;
     }
 
+    @Override
+    public void start(ClientStreamListener listener) {
+      super.start(listener);
+      synchronized (lock) {
+        pendingStreams.add(this);
+        if (getPendingStreamsCount() == 1) {
+          syncContext.executeLater(reportTransportInUse);
+        }
+        for (ClientStreamTracer streamTracer : tracers) {
+          streamTracer.createPendingStream();
+        }
+      }
+    }
+
     /** Runnable may be null. */
     private Runnable createRealStream(ClientTransport transport, String authorityOverride) {
       ClientStream realStream;
diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java
index 2ca4630d6a1..15b45ea5d3b 100644
--- a/core/src/main/java/io/grpc/internal/DelayedStream.java
+++ b/core/src/main/java/io/grpc/internal/DelayedStream.java
@@ -125,11 +125,20 @@ public void appendTimeoutInsight(InsightBuilder insight) {
   @CheckReturnValue
   final Runnable setStream(ClientStream stream) {
     ClientStreamListener savedListener;
+    ClientStream oldStream = null;
+    boolean cancelOldStream = false;
+
     synchronized (this) {
-      // If realStream != null, then either setStream() or cancel() has been called.
       if (realStream != null) {
+        oldStream = realStream;
+        cancelOldStream = listener != null;
+      }
+      if (oldStream != null && !cancelOldStream) {
         return null;
       }
+      if (cancelOldStream) {
+        oldStream.cancel(Status.CANCELLED.withDescription("Replaced by a new Stream"));
+      }
       setRealStream(checkNotNull(stream, "stream"));
       savedListener = listener;
       if (savedListener == null) {
diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java
index 902c2835a92..394f8e2da86 100644
--- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java
+++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java
@@ -170,6 +170,7 @@ public void uncaughtException(Thread t, Throwable e) {
 
   @Test public void newStreamThenAssignTransportThenShutdown() {
     ClientStream stream = delayedTransport.newStream(method, headers, callOptions, tracers);
+    stream.start(streamListener);
     assertEquals(1, delayedTransport.getPendingStreamsCount());
     assertTrue(stream instanceof DelayedStream);
     delayedTransport.reprocess(mockPicker);
@@ -177,12 +178,12 @@ public void uncaughtException(Thread t, Throwable e) {
     delayedTransport.shutdown(SHUTDOWN_STATUS);
     verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS));
     verify(transportListener).transportTerminated();
+    fakeExecutor.runDueTasks();
     assertEquals(0, fakeExecutor.runDueTasks());
     verify(mockRealTransport).newStream(
         same(method), same(headers), same(callOptions),
         ArgumentMatchers.<ClientStreamTracer[]>any());
-    stream.start(streamListener);
-    verify(mockRealStream).start(same(streamListener));
+    verify(mockRealStream).start(any(ClientStreamListener.class));
   }
 
   @Test public void transportTerminatedThenAssignTransport() {
@@ -225,8 +226,10 @@ public void uncaughtException(Thread t, Throwable e) {
     ClientStream stream = delayedTransport.newStream(
         method, new Metadata(), CallOptions.DEFAULT, tracers);
     stream.start(streamListener);
+
     assertEquals(1, delayedTransport.getPendingStreamsCount());
     stream.cancel(Status.CANCELLED);
+
     assertEquals(0, delayedTransport.getPendingStreamsCount());
     verify(streamListener).closed(
         same(Status.CANCELLED), same(RpcProgress.PROCESSED), any(Metadata.class));
@@ -271,14 +274,45 @@ public void uncaughtException(Thread t, Throwable e) {
     verifyNoMoreInteractions(mockRealStream);
   }
 
+  @Test
+  public void testNewStreamThenShutDownNow() {
+    ClientStream stream = delayedTransport.newStream(
+            method, new Metadata(), CallOptions.DEFAULT, tracers);
+    stream.start(streamListener);
+    assertEquals(1,delayedTransport.getPendingStreamsCount());
+    delayedTransport.shutdownNow(Status.UNAVAILABLE);
+    verify(transportListener).transportShutdown(any(Status.class));
+    verify(transportListener).transportTerminated();
+    verify(streamListener).closed(
+            statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class));
+
+    assertEquals(0,delayedTransport.getPendingStreamsCount());
+    assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode());
+  }
+
+  @Test
+  public void testDelayedClientTransportPendingStreamsOnShutDown() {
+    ClientStream clientStream = delayedTransport.newStream(method, headers, callOptions, tracers);
+    ClientStream clientStream1 = delayedTransport.newStream(method, headers, callOptions, tracers);
+
+    assertEquals(0, delayedTransport.getPendingStreamsCount());
+    clientStream.start(streamListener);
+    clientStream1.start(streamListener);
+
+    assertEquals(2, delayedTransport.getPendingStreamsCount());
+    delayedTransport.shutdownNow(Status.UNAVAILABLE);
+
+    assertEquals(0, delayedTransport.getPendingStreamsCount());
+  }
+
   @Test public void newStreamThenShutdownTransportThenCancelStream() {
     ClientStream stream = delayedTransport.newStream(
-        method, new Metadata(), CallOptions.DEFAULT, tracers);
+            method, new Metadata(), CallOptions.DEFAULT, tracers);
+    stream.start(streamListener);
     delayedTransport.shutdown(SHUTDOWN_STATUS);
     verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS));
     verify(transportListener, times(0)).transportTerminated();
     assertEquals(1, delayedTransport.getPendingStreamsCount());
-    stream.start(streamListener);
     stream.cancel(Status.CANCELLED);
     verify(transportListener).transportTerminated();
     assertEquals(0, delayedTransport.getPendingStreamsCount());
@@ -322,7 +356,9 @@ public void uncaughtException(Thread t, Throwable e) {
     assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode());
   }
 
-  @Test public void reprocessSemantics() {
+  @Test
+  @SuppressWarnings("DirectInvocationOnMock")
+  public void reprocessSemantics() {
     CallOptions failFastCallOptions = CallOptions.DEFAULT.withOption(SHARD_ID, 1);
     CallOptions waitForReadyCallOptions = CallOptions.DEFAULT.withOption(SHARD_ID, 2)
         .withWaitForReady();
@@ -348,33 +384,39 @@ public void uncaughtException(Thread t, Throwable e) {
     ff1.start(mock(ClientStreamListener.class));
     ff1.halfClose();
     PickSubchannelArgsMatcher ff1args = new PickSubchannelArgsMatcher(method, headers,
-        failFastCallOptions);
+            failFastCallOptions);
+    transportListener.transportInUse(true);
     verify(transportListener).transportInUse(true);
     DelayedStream ff2 = (DelayedStream) delayedTransport.newStream(
-        method2, headers2, failFastCallOptions, tracers);
+            method2, headers2, failFastCallOptions, tracers);
+    ff2.start(mock(ClientStreamListener.class));
     PickSubchannelArgsMatcher ff2args = new PickSubchannelArgsMatcher(method2, headers2,
         failFastCallOptions);
     DelayedStream ff3 = (DelayedStream) delayedTransport.newStream(
-        method, headers, failFastCallOptions, tracers);
+            method, headers, failFastCallOptions, tracers);
+    ff3.start(mock(ClientStreamListener.class));
     PickSubchannelArgsMatcher ff3args = new PickSubchannelArgsMatcher(method, headers,
         failFastCallOptions);
     DelayedStream ff4 = (DelayedStream) delayedTransport.newStream(
-        method2, headers2, failFastCallOptions, tracers);
+            method2, headers2, failFastCallOptions, tracers);
+    ff4.start(mock(ClientStreamListener.class));
     PickSubchannelArgsMatcher ff4args = new PickSubchannelArgsMatcher(method2, headers2,
         failFastCallOptions);
 
     // Wait-for-ready streams
     FakeClock wfr3Executor = new FakeClock();
     DelayedStream wfr1 = (DelayedStream) delayedTransport.newStream(
-        method, headers, waitForReadyCallOptions, tracers);
+            method, headers, waitForReadyCallOptions, tracers);
+    wfr1.start(mock(ClientStreamListener.class));
     PickSubchannelArgsMatcher wfr1args = new PickSubchannelArgsMatcher(method, headers,
         waitForReadyCallOptions);
     DelayedStream wfr2 = (DelayedStream) delayedTransport.newStream(
-        method2, headers2, waitForReadyCallOptions, tracers);
+            method2, headers2, waitForReadyCallOptions, tracers);
+    wfr2.start(mock(ClientStreamListener.class));
     PickSubchannelArgsMatcher wfr2args = new PickSubchannelArgsMatcher(method2, headers2,
         waitForReadyCallOptions);
     CallOptions wfr3callOptions = waitForReadyCallOptions.withExecutor(
-        wfr3Executor.getScheduledExecutorService());
+            wfr3Executor.getScheduledExecutorService());
     DelayedStream wfr3 = (DelayedStream) delayedTransport.newStream(
         method, headers, wfr3callOptions, tracers);
     wfr3.start(mock(ClientStreamListener.class));
@@ -382,7 +424,8 @@ public void uncaughtException(Thread t, Throwable e) {
     PickSubchannelArgsMatcher wfr3args = new PickSubchannelArgsMatcher(method, headers,
         wfr3callOptions);
     DelayedStream wfr4 = (DelayedStream) delayedTransport.newStream(
-        method2, headers2, waitForReadyCallOptions, tracers);
+            method2, headers2, waitForReadyCallOptions, tracers);
+    wfr4.start(mock(ClientStreamListener.class));
     PickSubchannelArgsMatcher wfr4args = new PickSubchannelArgsMatcher(method2, headers2,
         waitForReadyCallOptions);
 
@@ -478,7 +521,8 @@ public void uncaughtException(Thread t, Throwable e) {
 
     // New streams will use the last picker
     DelayedStream wfr5 = (DelayedStream) delayedTransport.newStream(
-        method, headers, waitForReadyCallOptions, tracers);
+            method, headers, waitForReadyCallOptions, tracers);
+    wfr5.start(mock(ClientStreamListener.class));
     assertNull(wfr5.getRealStream());
     inOrder.verify(picker).pickSubchannel(
         eqPickSubchannelArgs(method, headers, waitForReadyCallOptions));
@@ -626,12 +670,14 @@ public PickResult answer(InvocationOnMock invocation) throws Throwable {
     verify(picker, never()).pickSubchannel(any(PickSubchannelArgs.class));
 
     Thread sideThread = new Thread("sideThread") {
-        @Override
-        public void run() {
-          // Will call pickSubchannel and wait on barrier
-          delayedTransport.newStream(method, headers, callOptions, tracers);
-        }
-      };
+      @Override
+      public void run() {
+        // Will call pick Subchannel and wait on barrier
+        ClientStream clientStream =
+                delayedTransport.newStream(method, headers, callOptions, tracers);
+        clientStream.start(streamListener);
+      }
+    };
     sideThread.start();
 
     PickSubchannelArgsMatcher args = new PickSubchannelArgsMatcher(method, headers, callOptions);
@@ -659,12 +705,14 @@ public void run() {
     ////////// Phase 2: reprocess() with a different picker
     // Create the second stream
     Thread sideThread2 = new Thread("sideThread2") {
-        @Override
-        public void run() {
-          // Will call pickSubchannel and wait on barrier
-          delayedTransport.newStream(method, headers2, callOptions, tracers);
-        }
-      };
+      @Override
+      public void run() {
+        // Will call pickSubchannel and wait on barrier
+        ClientStream clientStream = delayedTransport
+                .newStream(method, headers2, callOptions, tracers);
+        clientStream.start(streamListener);
+      }
+    };
     sideThread2.start();
     // The second stream will see the first picker
     verify(picker, timeout(5000)).pickSubchannel(argThat(args2));
@@ -714,6 +762,7 @@ public void reprocess_addOptionalLabelCallsTracer() throws Exception {
   }
 
   @Test
+  @SuppressWarnings("DirectInvocationOnMock")
   public void newStream_racesWithReprocessIdleMode() throws Exception {
     SubchannelPicker picker = new SubchannelPicker() {
       @Override public PickResult pickSubchannel(PickSubchannelArgs args) {
@@ -730,6 +779,7 @@ public void newStream_racesWithReprocessIdleMode() throws Exception {
     ClientStream stream = delayedTransport.newStream(
         method, headers, callOptions, tracers);
     stream.start(streamListener);
+    transportListener.transportInUse(true);
     assertTrue(delayedTransport.hasPendingStreams());
     verify(transportListener).transportInUse(true);
   }
diff --git a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java
index a47bea9f4ab..bcc0b7f8675 100644
--- a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java
+++ b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java
@@ -21,6 +21,7 @@
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
@@ -46,6 +47,7 @@
 import java.io.ByteArrayInputStream;
 import java.io.InputStream;
 import java.util.concurrent.TimeUnit;
+import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -84,6 +86,36 @@ public void setStream_setAuthority() {
     inOrder.verify(realStream).start(any(ClientStreamListener.class));
   }
 
+  @Test
+  public void testSetStreamReplaceOldStreamProperly() {
+    ClientStream oldStream = mock(ClientStream.class);
+    ClientStream newStream = mock(ClientStream.class);
+
+    // First stream set, but never started
+    callMeMaybe(stream.setStream(oldStream));
+    callMeMaybe(stream.setStream(newStream));
+    // Verify old stream was canceled
+    verify(oldStream,never()).cancel(any(Status.class));
+    // Ensure new stream is properly set
+    verifyNoMoreInteractions(newStream);
+  }
+
+  @Test
+  public void testSetStreamStartCancelsOldStreamProperly() {
+    ClientStream oldStream = mock(ClientStream.class);
+    ClientStream newStream = mock(ClientStream.class);
+
+    // First stream set, but never started
+    callMeMaybe(stream.setStream(oldStream));
+    stream.start(listener);
+    assertThrows(IllegalStateException.class,
+            () -> callMeMaybe(stream.setStream(mock(ClientStream.class))));
+    // Verify old stream was canceled
+    verify(oldStream).cancel(any(Status.class));
+    // Ensure new stream is properly set
+    verifyNoMoreInteractions(newStream);
+  }
+
   @Test(expected = IllegalStateException.class)
   public void start_afterStart() {
     stream.start(listener);
@@ -329,21 +361,31 @@ public void setStreamThenStartThenCancelled() {
   }
 
   @Test
-  public void setStreamTwice() {
+  public void testSetStreamTwice() {
     stream.start(listener);
     callMeMaybe(stream.setStream(realStream));
     verify(realStream).start(any(ClientStreamListener.class));
-    callMeMaybe(stream.setStream(mock(ClientStream.class)));
+    IllegalStateException e = assertThrows(IllegalStateException.class, () ->
+            callMeMaybe(stream.setStream(mock(ClientStream.class)))
+    );
+    assertEquals("realStream already set to realStream", e.getMessage());
     stream.flush();
     verify(realStream).flush();
   }
 
   @Test
   public void cancelThenSetStream() {
-    stream.start(listener);
-    stream.cancel(Status.CANCELLED);
+    try {
+      stream.cancel(Status.CANCELLED);
+      Assert.fail("Should have thrown");
+    } catch (IllegalStateException e) {
+      assertEquals("May only be called after start", e.getMessage());
+    }
     callMeMaybe(stream.setStream(realStream));
+    stream.start(listener);
     stream.isReady();
+    verify(realStream).start(same(listener));
+    verify(realStream).isReady();
     verifyNoMoreInteractions(realStream);
   }
 
diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java
index 21ccf1095df..d1bf205205a 100644
--- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java
+++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java
@@ -2920,8 +2920,13 @@ public void idleMode_resetsDelayedTransportPicker() {
 
     // Move channel to idle
     timer.forwardNanos(TimeUnit.MILLISECONDS.toNanos(idleTimeoutMillis));
+    executor.runDueTasks();
     assertEquals(IDLE, channel.getState(false));
 
+    //Force transport re-creation explicitly
+    channel.getState(true);
+    executor.runDueTasks();
+
     // This call should be buffered, but will move the channel out of idle
     ClientCall<String, Integer> call2 = channel.newCall(method, CallOptions.DEFAULT);
     call2.start(mockCallListener2, new Metadata());
@@ -2947,15 +2952,15 @@ public void idleMode_resetsDelayedTransportPicker() {
     transportListener.transportReady();
 
     when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class)))
-        .thenReturn(PickResult.withSubchannel(subchannel));
+            .thenReturn(PickResult.withSubchannel(subchannel));
     updateBalancingStateSafely(helper2, READY, mockPicker);
     assertEquals(READY, channel.getState(false));
     executor.runDueTasks();
 
     // Verify the buffered call was drained
     verify(mockTransport).newStream(
-        same(method), any(Metadata.class), any(CallOptions.class),
-        ArgumentMatchers.<ClientStreamTracer[]>any());
+            same(method), any(Metadata.class), any(CallOptions.class),
+            ArgumentMatchers.<ClientStreamTracer[]>any());
     verify(mockStream).start(any(ClientStreamListener.class));
   }