Skip to content

Commit da703eb

Browse files
Fishrock123jbr
andcommitted
Server: require State to be Clone
Alternative to http-rs#642 This approach is more flexible but requires the user ensure that their state implements/derives `Clone`, or is wrapped in an `Arc`. Co-authored-by: Jacob Rothstein <[email protected]>
1 parent 3778706 commit da703eb

File tree

10 files changed

+57
-51
lines changed

10 files changed

+57
-51
lines changed

Diff for: Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ serde = "1.0.102"
4242
serde_json = "1.0.41"
4343
route-recognizer = "0.2.0"
4444
logtest = "2.0.0"
45+
pin-project-lite = "0.1.7"
4546

4647
[dev-dependencies]
4748
async-std = { version = "1.6.0", features = ["unstable", "attributes"] }

Diff for: examples/graphql.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use async_std::task;
24
use juniper::{http::graphiql, http::GraphQLRequest, RootNode};
35
use std::sync::RwLock;
@@ -72,7 +74,7 @@ fn create_schema() -> Schema {
7274
Schema::new(QueryRoot {}, MutationRoot {})
7375
}
7476

75-
async fn handle_graphql(mut request: Request<State>) -> tide::Result {
77+
async fn handle_graphql(mut request: Request<Arc<State>>) -> tide::Result {
7678
let query: GraphQLRequest = request.body_json().await?;
7779
let schema = create_schema(); // probably worth making the schema a singleton using lazy_static library
7880
let response = query.execute(&schema, request.state());
@@ -87,17 +89,17 @@ async fn handle_graphql(mut request: Request<State>) -> tide::Result {
8789
.build())
8890
}
8991

90-
async fn handle_graphiql(_: Request<State>) -> tide::Result<impl Into<Response>> {
92+
async fn handle_graphiql(_: Request<Arc<State>>) -> tide::Result<impl Into<Response>> {
9193
Ok(Response::builder(200)
9294
.body(graphiql::graphiql_source("/graphql"))
9395
.content_type(mime::HTML))
9496
}
9597

9698
fn main() -> std::io::Result<()> {
9799
task::block_on(async {
98-
let mut app = Server::with_state(State {
100+
let mut app = Server::with_state(Arc::new(State {
99101
users: RwLock::new(Vec::new()),
100-
});
102+
}));
101103
app.at("/").get(Redirect::permanent("/graphiql"));
102104
app.at("/graphql").post(handle_graphql);
103105
app.at("/graphiql").get(handle_graphiql);

Diff for: examples/middleware.rs

+9-7
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ impl UserDatabase {
2525
// application state. Because it depends on a specific request state,
2626
// it would likely be closely tied to a specific application
2727
fn user_loader<'a>(
28-
mut request: Request<UserDatabase>,
29-
next: Next<'a, UserDatabase>,
28+
mut request: Request<Arc<UserDatabase>>,
29+
next: Next<'a, Arc<UserDatabase>>,
3030
) -> Pin<Box<dyn Future<Output = Result> + Send + 'a>> {
3131
Box::pin(async {
3232
if let Some(user) = request.state().find_user().await {
@@ -98,7 +98,7 @@ const INTERNAL_SERVER_ERROR_HTML_PAGE: &str = "<html><body>
9898
#[async_std::main]
9999
async fn main() -> Result<()> {
100100
tide::log::start();
101-
let mut app = tide::with_state(UserDatabase::default());
101+
let mut app = tide::with_state(Arc::new(UserDatabase::default()));
102102

103103
app.middleware(After(|response: Response| async move {
104104
let response = match response.status() {
@@ -120,10 +120,12 @@ async fn main() -> Result<()> {
120120

121121
app.middleware(user_loader);
122122
app.middleware(RequestCounterMiddleware::new(0));
123-
app.middleware(Before(|mut request: Request<UserDatabase>| async move {
124-
request.set_ext(std::time::Instant::now());
125-
request
126-
}));
123+
app.middleware(Before(
124+
|mut request: Request<Arc<UserDatabase>>| async move {
125+
request.set_ext(std::time::Instant::now());
126+
request
127+
},
128+
));
127129

128130
app.at("/").get(|req: Request<_>| async move {
129131
let count: &RequestCount = req.ext().unwrap();

Diff for: examples/upload.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use async_std::{fs::OpenOptions, io};
24
use tempfile::TempDir;
35
use tide::prelude::*;
@@ -6,15 +8,15 @@ use tide::{Body, Request, Response, StatusCode};
68
#[async_std::main]
79
async fn main() -> Result<(), std::io::Error> {
810
tide::log::start();
9-
let mut app = tide::with_state(tempfile::tempdir()?);
11+
let mut app = tide::with_state(Arc::new(tempfile::tempdir()?));
1012

1113
// To test this example:
1214
// $ cargo run --example upload
1315
// $ curl -T ./README.md locahost:8080 # this writes the file to a temp directory
1416
// $ curl localhost:8080/README.md # this reads the file from the same temp directory
1517

1618
app.at(":file")
17-
.put(|req: Request<TempDir>| async move {
19+
.put(|req: Request<Arc<TempDir>>| async move {
1820
let path: String = req.param("file")?;
1921
let fs_path = req.state().path().join(path);
2022

@@ -33,7 +35,7 @@ async fn main() -> Result<(), std::io::Error> {
3335

3436
Ok(json!({ "bytes": bytes_written }))
3537
})
36-
.get(|req: Request<TempDir>| async move {
38+
.get(|req: Request<Arc<TempDir>>| async move {
3739
let path: String = req.param("file")?;
3840
let fs_path = req.state().path().join(path);
3941

Diff for: src/fs/serve_dir.rs

+1-3
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,6 @@ where
6262
mod test {
6363
use super::*;
6464

65-
use async_std::sync::Arc;
66-
6765
use std::fs::{self, File};
6866
use std::io::Write;
6967

@@ -85,7 +83,7 @@ mod test {
8583
let request = crate::http::Request::get(
8684
crate::http::Url::parse(&format!("http://localhost/{}", path)).unwrap(),
8785
);
88-
crate::Request::new(Arc::new(()), request, vec![])
86+
crate::Request::new((), request, vec![])
8987
}
9088

9189
#[async_std::test]

Diff for: src/lib.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ pub fn new() -> server::Server<()> {
255255
/// # use async_std::task::block_on;
256256
/// # fn main() -> Result<(), std::io::Error> { block_on(async {
257257
/// #
258+
/// use std::sync::Arc;
258259
/// use tide::Request;
259260
///
260261
/// /// The shared application state.
@@ -268,8 +269,8 @@ pub fn new() -> server::Server<()> {
268269
/// };
269270
///
270271
/// // Initialize the application with state.
271-
/// let mut app = tide::with_state(state);
272-
/// app.at("/").get(|req: Request<State>| async move {
272+
/// let mut app = tide::with_state(Arc::new(state));
273+
/// app.at("/").get(|req: Request<Arc<State>>| async move {
273274
/// Ok(format!("Hello, {}!", &req.state().name))
274275
/// });
275276
/// app.listen("127.0.0.1:8080").await?;
@@ -278,7 +279,7 @@ pub fn new() -> server::Server<()> {
278279
/// ```
279280
pub fn with_state<State>(state: State) -> server::Server<State>
280281
where
281-
State: Send + Sync + 'static,
282+
State: Clone + Send + Sync + 'static,
282283
{
283284
Server::with_state(state)
284285
}

Diff for: src/request.rs

+19-20
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,29 @@ use route_recognizer::Params;
44

55
use std::ops::Index;
66
use std::pin::Pin;
7-
use std::{fmt, str::FromStr, sync::Arc};
7+
use std::{fmt, str::FromStr};
88

99
use crate::cookies::CookieData;
1010
use crate::http::cookies::Cookie;
1111
use crate::http::headers::{self, HeaderName, HeaderValues, ToHeaderValues};
1212
use crate::http::{self, Body, Method, Mime, StatusCode, Url, Version};
1313
use crate::Response;
1414

15-
/// An HTTP request.
16-
///
17-
/// The `Request` gives endpoints access to basic information about the incoming
18-
/// request, route parameters, and various ways of accessing the request's body.
19-
///
20-
/// Requests also provide *extensions*, a type map primarily used for low-level
21-
/// communication between middleware and endpoints.
22-
#[derive(Debug)]
23-
pub struct Request<State> {
24-
pub(crate) state: Arc<State>,
25-
pub(crate) req: http::Request,
26-
pub(crate) route_params: Vec<Params>,
15+
pin_project_lite::pin_project! {
16+
/// An HTTP request.
17+
///
18+
/// The `Request` gives endpoints access to basic information about the incoming
19+
/// request, route parameters, and various ways of accessing the request's body.
20+
///
21+
/// Requests also provide *extensions*, a type map primarily used for low-level
22+
/// communication between middleware and endpoints.
23+
#[derive(Debug)]
24+
pub struct Request<State> {
25+
pub(crate) state: State,
26+
#[pin]
27+
pub(crate) req: http::Request,
28+
pub(crate) route_params: Vec<Params>,
29+
}
2730
}
2831

2932
#[derive(Debug)]
@@ -45,11 +48,7 @@ impl<T: fmt::Debug + fmt::Display> std::error::Error for ParamError<T> {}
4548

4649
impl<State> Request<State> {
4750
/// Create a new `Request`.
48-
pub(crate) fn new(
49-
state: Arc<State>,
50-
req: http_types::Request,
51-
route_params: Vec<Params>,
52-
) -> Self {
51+
pub(crate) fn new(state: State, req: http_types::Request, route_params: Vec<Params>) -> Self {
5352
Self {
5453
state,
5554
req,
@@ -550,11 +549,11 @@ impl<State> AsMut<http::Headers> for Request<State> {
550549

551550
impl<State> Read for Request<State> {
552551
fn poll_read(
553-
mut self: Pin<&mut Self>,
552+
self: Pin<&mut Self>,
554553
cx: &mut Context<'_>,
555554
buf: &mut [u8],
556555
) -> Poll<io::Result<usize>> {
557-
Pin::new(&mut self.req).poll_read(cx, buf)
556+
self.project().req.poll_read(cx, buf)
558557
}
559558
}
560559

Diff for: src/route.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ impl<'a, State: Send + Sync + 'static> Route<'a, State> {
102102
/// [`Server`]: struct.Server.html
103103
pub fn nest<InnerState>(&mut self, service: crate::Server<InnerState>) -> &mut Self
104104
where
105-
State: Send + Sync + 'static,
106-
InnerState: Send + Sync + 'static,
105+
State: Clone + Send + Sync + 'static,
106+
InnerState: Clone + Send + Sync + 'static,
107107
{
108108
self.prefix = true;
109109
self.all(service);

Diff for: src/server.rs

+9-8
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ use crate::{Endpoint, Request, Route};
129129
#[allow(missing_debug_implementations)]
130130
pub struct Server<State> {
131131
router: Arc<Router<State>>,
132-
state: Arc<State>,
132+
state: State,
133133
middleware: Arc<Vec<Arc<dyn Middleware<State>>>>,
134134
}
135135

@@ -166,7 +166,7 @@ impl Default for Server<()> {
166166
}
167167
}
168168

169-
impl<State: Send + Sync + 'static> Server<State> {
169+
impl<State: Clone + Send + Sync + 'static> Server<State> {
170170
/// Create a new Tide server with shared application scoped state.
171171
///
172172
/// Application scoped state is useful for storing items
@@ -177,6 +177,7 @@ impl<State: Send + Sync + 'static> Server<State> {
177177
/// # use async_std::task::block_on;
178178
/// # fn main() -> Result<(), std::io::Error> { block_on(async {
179179
/// #
180+
/// use std::sync::Arc;
180181
/// use tide::Request;
181182
///
182183
/// /// The shared application state.
@@ -190,8 +191,8 @@ impl<State: Send + Sync + 'static> Server<State> {
190191
/// };
191192
///
192193
/// // Initialize the application with state.
193-
/// let mut app = tide::with_state(state);
194-
/// app.at("/").get(|req: Request<State>| async move {
194+
/// let mut app = tide::with_state(Arc::new(state));
195+
/// app.at("/").get(|req: Request<Arc<State>>| async move {
195196
/// Ok(format!("Hello, {}!", &req.state().name))
196197
/// });
197198
/// app.listen("127.0.0.1:8080").await?;
@@ -202,7 +203,7 @@ impl<State: Send + Sync + 'static> Server<State> {
202203
let mut server = Self {
203204
router: Arc::new(Router::new()),
204205
middleware: Arc::new(vec![]),
205-
state: Arc::new(state),
206+
state,
206207
};
207208
server.middleware(cookies::CookiesMiddleware::new());
208209
server.middleware(log::LogMiddleware::new());
@@ -429,7 +430,7 @@ impl<State: Send + Sync + 'static> Server<State> {
429430
}
430431
}
431432

432-
impl<State> Clone for Server<State> {
433+
impl<State: Clone> Clone for Server<State> {
433434
fn clone(&self) -> Self {
434435
Self {
435436
router: self.router.clone(),
@@ -439,8 +440,8 @@ impl<State> Clone for Server<State> {
439440
}
440441
}
441442

442-
impl<State: Sync + Send + 'static, InnerState: Sync + Send + 'static> Endpoint<State>
443-
for Server<InnerState>
443+
impl<State: Clone + Sync + Send + 'static, InnerState: Clone + Sync + Send + 'static>
444+
Endpoint<State> for Server<InnerState>
444445
{
445446
fn call<'a>(&'a self, req: Request<State>) -> BoxFuture<'a, crate::Result> {
446447
let Request {

Diff for: tests/test_utils.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ pub trait ServerTestingExt {
3535

3636
impl<State> ServerTestingExt for Server<State>
3737
where
38-
State: Send + Sync + 'static,
38+
State: Clone + Send + Sync + 'static,
3939
{
4040
fn request<'a>(
4141
&'a self,

0 commit comments

Comments
 (0)