Skip to content

Commit 09b9c77

Browse files
committed
update to reflect new changes to vacation api, add enum for whether to process packets async
1 parent 0df5320 commit 09b9c77

File tree

10 files changed

+86
-66
lines changed

10 files changed

+86
-66
lines changed

.github/workflows/CI.yml

+3-2
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,9 @@ jobs:
6262
run: |
6363
cargo test --locked --all
6464
cargo test --locked -p tokio-rustls --features early-data --test early-data
65-
# we run all test suites against this feature since it shifts the default behavior globally
66-
cargo test --locked -p tokio-rustls --features compute-heavy-future-executor
65+
# we run all test suites against this feature
66+
# to capture any regressions that come from changes to the handshake future state machine
67+
cargo test --locked -p tokio-rustls --features vacation
6768
6869
lints:
6970
name: Lints

Cargo.lock

+10-10
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+3-7
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@ rust-version = "1.70"
1313
exclude = ["/.github", "/examples", "/scripts"]
1414

1515
[dependencies]
16-
# implicitly enables the tokio feature for compute-heavy-future-executor
17-
# (defaulting to strategy of spawn_blocking w/ concurrency conctorl)
18-
compute-heavy-future-executor = { version = "0.1", optional = true}
16+
vacation = { version = "0.1", optional = true, default-features = false }
1917
pin-project-lite = { version = "0.2.15", optional = true }
2018
rustls = { version = "0.23.15", default-features = false, features = ["std"] }
2119
tokio = "1.0"
@@ -24,7 +22,7 @@ tokio = "1.0"
2422
default = ["logging", "tls12", "aws_lc_rs"]
2523
aws_lc_rs = ["rustls/aws_lc_rs"]
2624
aws-lc-rs = ["aws_lc_rs"] # Alias because Cargo features commonly use `-`
27-
compute-heavy-future-executor = ["dep:compute-heavy-future-executor", "pin-project-lite"]
25+
vacation = ["dep:vacation", "pin-project-lite"]
2826
early-data = []
2927
fips = ["rustls/fips"]
3028
logging = ["rustls/logging"]
@@ -37,7 +35,5 @@ futures-util = "0.3.1"
3735
lazy_static = "1.1"
3836
rcgen = { version = "0.13", features = ["pem"] }
3937
tokio = { version = "1.0", features = ["full"] }
38+
vacation = { version = "0.1", features = ["tokio"] }
4039
webpki-roots = "0.26"
41-
42-
[patch.crates-io]
43-
compute-heavy-future-executor = { path = "../compute-heavy-future-executor" }

src/client.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ fn poll_handle_early_data<IO>(
288288
where
289289
IO: AsyncRead + AsyncWrite + Unpin,
290290
{
291+
use crate::common::PacketProcessingMode;
292+
291293
if let TlsState::EarlyData(pos, data) = state {
292294
use std::io::Write;
293295

@@ -321,7 +323,8 @@ where
321323

322324
// complete handshake
323325
while stream.session.is_handshaking() {
324-
ready!(stream.handshake(cx, false))?;
326+
// TODO: also model as using `vacation` executor
327+
ready!(stream.handshake(cx, PacketProcessingMode::Sync))?;
325328
}
326329

327330
// write early data (fallback)

src/common/async_session.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use super::{Stream, TlsState};
1717
/// Full result of sync closure
1818
type SessionResult<S> = Result<S, (Option<S>, io::Error)>;
1919
/// Executor result wrapping sync closure result
20-
type SyncExecutorResult<S> = Result<SessionResult<S>, compute_heavy_future_executor::Error>;
20+
type SyncExecutorResult<S> = Result<SessionResult<S>, vacation::Error>;
2121
/// Future wrapping waiting on executor
2222
type SessionFuture<S> = Box<dyn Future<Output = SyncExecutorResult<S>> + Unpin + Send>;
2323

@@ -53,7 +53,9 @@ where
5353
)),
5454
};
5555

56-
let future = compute_heavy_future_executor::execute_sync(closure);
56+
// TODO: if we ever start also delegating non-handshake byte processing, make this chance of blocking
57+
// variable and set by caller
58+
let future = vacation::execute_sync(closure, vacation::ChanceOfBlocking::High);
5759

5860
Self {
5961
future: Box::new(Box::pin(future)),

src/common/handshake.rs

+9-10
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ use rustls::server::AcceptedAlert;
88
use rustls::{ConnectionCommon, SideData};
99
use tokio::io::{AsyncRead, AsyncWrite};
1010

11-
use crate::common::{Stream, SyncWriteAdapter, TlsState};
11+
use crate::common::{PacketProcessingMode, Stream, SyncWriteAdapter, TlsState};
1212

13-
#[cfg(feature = "compute-heavy-future-executor")]
13+
#[cfg(feature = "vacation")]
1414
use super::async_session::AsyncSession;
1515

1616
pub(crate) trait IoSession {
@@ -34,7 +34,7 @@ pub(crate) trait IoSession {
3434

3535
pub(crate) enum MidHandshake<IS: IoSession> {
3636
Handshaking(IS),
37-
#[cfg(feature = "compute-heavy-future-executor")]
37+
#[cfg(feature = "vacation")]
3838
AsyncSession(AsyncSession<IS>),
3939
End,
4040
SendAlert {
@@ -61,7 +61,7 @@ where
6161

6262
let mut stream = match mem::replace(this, MidHandshake::End) {
6363
MidHandshake::Handshaking(stream) => stream,
64-
#[cfg(feature = "compute-heavy-future-executor")]
64+
#[cfg(feature = "vacation")]
6565
MidHandshake::AsyncSession(mut async_session) => {
6666
let pinned = Pin::new(&mut async_session);
6767
let session_result = ready!(pinned.poll(cx));
@@ -94,7 +94,7 @@ where
9494
( $e:expr ) => {
9595
match $e {
9696
Poll::Ready(Ok(_)) => (),
97-
#[cfg(feature = "compute-heavy-future-executor")]
97+
#[cfg(feature = "vacation")]
9898
Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::WouldBlock => {
9999
// TODO: downcast to decide on closure, for now we only do this for
100100
// process_packets
@@ -132,12 +132,11 @@ where
132132
};
133133
}
134134

135-
136135
while tls_stream.session.is_handshaking() {
137-
#[cfg(feature = "compute-heavy-future-executor")]
138-
try_poll!(tls_stream.handshake(cx, true));
139-
#[cfg(not(feature = "compute-heavy-future-executor"))]
140-
try_poll!(tls_stream.handshake(cx, false));
136+
#[cfg(feature = "vacation")]
137+
try_poll!(tls_stream.handshake(cx, PacketProcessingMode::Async));
138+
#[cfg(not(feature = "vacation"))]
139+
try_poll!(tls_stream.handshake(cx, PacketProcessingMode::Sync));
141140
}
142141

143142
try_poll!(Pin::new(&mut tls_stream).poll_flush(cx));

src/common/mod.rs

+23-7
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::task::{Context, Poll};
66
use rustls::{ConnectionCommon, SideData};
77
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
88

9-
#[cfg(feature = "compute-heavy-future-executor")]
9+
#[cfg(feature = "vacation")]
1010
mod async_session;
1111
mod handshake;
1212
pub(crate) use handshake::{IoSession, MidHandshake};
@@ -21,6 +21,14 @@ pub enum TlsState {
2121
FullyShutdown,
2222
}
2323

24+
/// Whether to delegate the call to the `vacation` executor,
25+
/// only kicks in if `vacation` feature is enabled
26+
#[derive(Debug, PartialEq, Clone, Copy)]
27+
pub enum PacketProcessingMode {
28+
Async,
29+
Sync,
30+
}
31+
2432
impl TlsState {
2533
#[inline]
2634
pub fn shutdown_read(&mut self) {
@@ -92,7 +100,11 @@ where
92100
}
93101

94102
#[allow(unused_variables)]
95-
pub fn read_io(&mut self, cx: &mut Context, process_packets_async: bool) -> Poll<io::Result<usize>> {
103+
pub fn read_io(
104+
&mut self,
105+
cx: &mut Context,
106+
packet_processing_mode: PacketProcessingMode,
107+
) -> Poll<io::Result<usize>> {
96108
let mut reader = SyncReadAdapter { io: self.io, cx };
97109

98110
let n: usize = match self.session.read_tls(&mut reader) {
@@ -101,8 +113,8 @@ where
101113
Err(err) => return Poll::Ready(Err(err)),
102114
};
103115

104-
#[cfg(feature = "compute-heavy-future-executor")]
105-
if process_packets_async {
116+
#[cfg(feature = "vacation")]
117+
if packet_processing_mode == PacketProcessingMode::Async {
106118
// TODO: stop modeling errors as IO, use enum on types of async session processing
107119
return Poll::Ready(Err(io::Error::new(
108120
io::ErrorKind::WouldBlock,
@@ -131,7 +143,11 @@ where
131143
}
132144
}
133145

134-
pub fn handshake(&mut self, cx: &mut Context, process_packets_async: bool) -> Poll<io::Result<(usize, usize)>> {
146+
pub fn handshake(
147+
&mut self,
148+
cx: &mut Context,
149+
packet_processing_mode: PacketProcessingMode,
150+
) -> Poll<io::Result<(usize, usize)>> {
135151
let mut wrlen = 0;
136152
let mut rdlen = 0;
137153

@@ -164,7 +180,7 @@ where
164180
}
165181

166182
while !self.eof && self.session.wants_read() {
167-
match self.read_io(cx, process_packets_async) {
183+
match self.read_io(cx, packet_processing_mode) {
168184
Poll::Ready(Ok(0)) => self.eof = true,
169185
Poll::Ready(Ok(n)) => rdlen += n,
170186
Poll::Pending => {
@@ -208,7 +224,7 @@ where
208224

209225
// read a packet
210226
while !self.eof && self.session.wants_read() {
211-
match self.read_io(cx, false) {
227+
match self.read_io(cx, PacketProcessingMode::Sync) {
212228
Poll::Ready(Ok(0)) => {
213229
break;
214230
}

src/common/test_stream.rs

+23-20
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use rustls::pki_types::ServerName;
99
use rustls::{ClientConnection, Connection, ServerConnection};
1010
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
1111

12-
use super::Stream;
12+
use super::{PacketProcessingMode, Stream};
1313

1414
struct Good<'a>(&'a mut Connection);
1515

@@ -229,12 +229,13 @@ async fn stream_handshake() -> io::Result<()> {
229229
{
230230
let mut good = Good(&mut server);
231231
let mut stream = Stream::new(&mut good, &mut client);
232-
let (r, w) = poll_fn(|cx| stream.handshake(cx, false)).await?;
232+
let (r, w) = poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Sync)).await?;
233233

234234
assert!(r > 0);
235235
assert!(w > 0);
236236

237-
poll_fn(|cx: &mut Context<'_>| stream.handshake(cx, false)).await?; // finish server handshake
237+
poll_fn(|cx: &mut Context<'_>| stream.handshake(cx, PacketProcessingMode::Sync)).await?;
238+
// finish server handshake
238239
}
239240

240241
assert!(!server.is_handshaking());
@@ -253,12 +254,12 @@ async fn stream_buffered_handshake() -> io::Result<()> {
253254
{
254255
let mut good = BufWriter::new(Good(&mut server));
255256
let mut stream = Stream::new(&mut good, &mut client);
256-
let (r, w) = poll_fn(|cx| stream.handshake(cx, false)).await?;
257+
let (r, w) = poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Sync)).await?;
257258

258259
assert!(r > 0);
259260
assert!(w > 0);
260261

261-
poll_fn(|cx| stream.handshake(cx, false)).await?; // finish server handshake
262+
poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Sync)).await?; // finish server handshake
262263
}
263264

264265
assert!(!server.is_handshaking());
@@ -275,7 +276,7 @@ async fn stream_handshake_eof() -> io::Result<()> {
275276
let mut stream = Stream::new(&mut bad, &mut client);
276277

277278
let mut cx = Context::from_waker(noop_waker_ref());
278-
let r = stream.handshake(&mut cx, false);
279+
let r = stream.handshake(&mut cx, PacketProcessingMode::Sync);
279280
assert_eq!(
280281
r.map_err(|err| err.kind()),
281282
Poll::Ready(Err(io::ErrorKind::UnexpectedEof))
@@ -292,7 +293,7 @@ async fn stream_handshake_write_eof() -> io::Result<()> {
292293
let mut stream = Stream::new(&mut io, &mut client);
293294

294295
let mut cx = Context::from_waker(noop_waker_ref());
295-
let r = stream.handshake(&mut cx, false);
296+
let r = stream.handshake(&mut cx, PacketProcessingMode::Sync);
296297
assert_eq!(
297298
r.map_err(|err| err.kind()),
298299
Poll::Ready(Err(io::ErrorKind::WriteZero))
@@ -310,7 +311,7 @@ async fn stream_handshake_regression_issues_77() -> io::Result<()> {
310311
let mut stream = Stream::new(&mut bad, &mut client);
311312

312313
let mut cx = Context::from_waker(noop_waker_ref());
313-
let r = stream.handshake(&mut cx, false);
314+
let r = stream.handshake(&mut cx, PacketProcessingMode::Sync);
314315
assert_eq!(
315316
r.map_err(|err| err.kind()),
316317
Poll::Ready(Err(io::ErrorKind::InvalidData))
@@ -366,31 +367,33 @@ async fn async_process_packets() -> io::Result<()> {
366367
let mut stream = Stream::new(&mut good, &mut client);
367368

368369
// if feature is enabled, we expect a blocking response on process packets throughout the handshake,
369-
#[cfg(feature = "compute-heavy-future-executor")]
370-
{ let result = poll_fn(|cx| stream.handshake(cx, true)).await;
370+
#[cfg(feature = "vacation")]
371+
{
372+
let result = poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Async)).await;
371373

372374
assert_eq!(
373375
result.err().map(|e| e.kind()),
374376
Some(io::ErrorKind::WouldBlock)
375377
);
376378

377379
// finish the handshake without delegating to async session
378-
poll_fn(|cx| stream.handshake(cx, false)).await?; // client handshake
379-
poll_fn(|cx: &mut Context<'_>| stream.handshake(cx, true)).await?; // server handshake
380+
poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Sync)).await?; // client handshake
381+
poll_fn(|cx: &mut Context<'_>| stream.handshake(cx, PacketProcessingMode::Sync)).await?;
382+
// server handshake
380383
}
381384

382-
// if feature is disabled, we expect normal handling
383-
#[cfg(not(feature = "compute-heavy-future-executor"))]
385+
// if feature is disabled, we expect normal handling even if async is passed in
386+
#[cfg(not(feature = "vacation"))]
384387
{
385388
{
386-
let (r, w) = poll_fn(|cx| stream.handshake(cx, true)).await?; // client handshake
387-
389+
let (r, w) = poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Async)).await?; // client handshake
390+
388391
assert!(r > 0);
389392
assert!(w > 0);
390-
391-
poll_fn(|cx| stream.handshake(cx, true)).await?; // server handshake
393+
394+
poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Async)).await?;
395+
// server handshake
392396
}
393-
394397
}
395398

396399
// once handshake is done, there is no longer blocking sending data over the stream
@@ -426,7 +429,7 @@ fn do_handshake(
426429
let mut stream = Stream::new(&mut good, client);
427430

428431
while stream.session.is_handshaking() {
429-
ready!(stream.handshake(cx, false))?;
432+
ready!(stream.handshake(cx, PacketProcessingMode::Sync))?;
430433
}
431434

432435
while stream.session.wants_write() {

0 commit comments

Comments
 (0)