-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathrestart_coordination_socket.rs
127 lines (110 loc) · 4.32 KB
/
restart_coordination_socket.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
//! Communication with a running process over a unix domain socket.
use crate::RestartResult;
use anyhow::{anyhow, Context};
use futures::sink::SinkExt;
use futures::stream::StreamExt;
use serde::{Deserialize, Serialize};
use tokio::net::UnixStream;
use tokio_util::codec::length_delimited::LengthDelimitedCodec;
use tokio_util::codec::{Decoder, Framed};
/// Represents the restart coordination socket, used for communicating with a running oxy process.
/// This is used to trigger a restart and receive notification of its completion or failure.
///
/// Every message travels as one length-delimited frame whose payload is the JSON encoding of a
/// [`RestartMessage`].
#[derive(Debug)]
pub struct RestartCoordinationSocket {
    /// The unix stream wrapped in a length-delimited framing codec; one frame carries one message.
    codec: Framed<UnixStream, LengthDelimitedCodec>,
}
impl RestartCoordinationSocket {
    /// Create a new RestartCoordinationSocket wrapping a unix socket.
    ///
    /// The stream is framed with a default-configured [`LengthDelimitedCodec`], so both ends of
    /// the connection must use the same framing.
    pub fn new(socket: UnixStream) -> Self {
        RestartCoordinationSocket {
            codec: LengthDelimitedCodec::new().framed(socket),
        }
    }

    /// Sends a restart command through the socket. Returns Ok(child_pid) on success or an error
    /// if the restart failed for any reason.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent, if no valid reply can be received, if the
    /// peer answers with `RestartFailed` (the attached reason becomes the error message), or if
    /// the peer answers with something that is not a response.
    pub async fn send_restart_command(&mut self) -> RestartResult<u32> {
        self.send_message(RestartMessage::Request(RestartRequest::TryRestart))
            .await?;

        match self.receive_message().await? {
            RestartMessage::Response(RestartResponse::RestartComplete(pid)) => Ok(pid),
            RestartMessage::Response(RestartResponse::RestartFailed(reason)) => {
                Err(anyhow!(reason))
            }
            // Listed explicitly (no `_` arm) so that adding a new message or response variant
            // forces this match to be revisited at compile time.
            RestartMessage::Request(_) => Err(anyhow!("unexpected message received")),
        }
    }

    /// Send a message over the socket
    ///
    /// The message is encoded as JSON and written as a single length-delimited frame.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be serialized or if writing the frame fails.
    pub async fn send_message(&mut self, msg: RestartMessage) -> RestartResult<()> {
        // Propagate serialization failures instead of panicking, mirroring how
        // `receive_message` treats deserialization failures.
        let payload = serde_json::to_string(&msg)?;
        self.codec.send(payload.into()).await?;
        Ok(())
    }

    /// Receive a message from the socket.
    ///
    /// Waits for the next frame and decodes its payload as a JSON [`RestartMessage`].
    ///
    /// # Errors
    ///
    /// Returns an error if the stream ends before a frame arrives, if reading the frame fails,
    /// or if the payload is not a valid JSON-encoded message.
    pub async fn receive_message(&mut self) -> RestartResult<RestartMessage> {
        let message = self
            .codec
            .next()
            .await
            // `None` means the stream ended; the second `?` unwraps the framing/I/O result.
            .context("connection closed while awaiting a message")??;
        Ok(serde_json::from_slice(&message)?)
    }
}
/// Represents any message that may be sent over the socket.
///
/// This is the top-level type that `RestartCoordinationSocket::send_message` serializes to JSON
/// and that `RestartCoordinationSocket::receive_message` deserializes from a received frame.
#[derive(Debug, Serialize, Deserialize)]
pub enum RestartMessage {
/// A request; `send_restart_command` sends one and then waits for a `Response`.
Request(RestartRequest),
/// A reply to a previously sent request.
Response(RestartResponse),
}
/// A request message that expects a response.
#[derive(Debug, Serialize, Deserialize)]
pub enum RestartRequest {
/// The request issued by `RestartCoordinationSocket::send_restart_command`; the sender then
/// waits for a `RestartResponse` reporting completion or failure.
TryRestart,
}
/// A response to a request message.
///
/// `RestartCoordinationSocket::send_restart_command` turns `RestartComplete` into `Ok(pid)` and
/// `RestartFailed` into an error built from the attached message.
#[derive(Debug, Serialize, Deserialize)]
pub enum RestartResponse {
/// Restart completed. The child PID is provided.
RestartComplete(u32),
/// Restart failed. The error message is attached.
RestartFailed(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a connected `(client, server)` pair of coordination sockets.
    fn socket_pair() -> (RestartCoordinationSocket, RestartCoordinationSocket) {
        let (client, server) = UnixStream::pair().unwrap();
        (
            RestartCoordinationSocket::new(client),
            RestartCoordinationSocket::new(server),
        )
    }

    /// Plays the server side of one exchange: expects a `TryRestart` request and answers it
    /// with `response`.
    async fn serve_once(mut server: RestartCoordinationSocket, response: RestartMessage) {
        let message = server.receive_message().await.unwrap();
        assert!(matches!(
            message,
            RestartMessage::Request(RestartRequest::TryRestart)
        ));
        server.send_message(response).await.unwrap();
    }

    #[tokio::test]
    async fn test_restart_complete() {
        let (mut client, server) = socket_pair();
        let child_pid = 42;
        let response = RestartMessage::Response(RestartResponse::RestartComplete(child_pid));
        let server_task = tokio::spawn(serve_once(server, response));

        assert_eq!(client.send_restart_command().await.unwrap(), child_pid);
        // Join the server task so a panic inside it fails this test instead of being dropped.
        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_restart_failed() {
        let (mut client, server) = socket_pair();
        let error_message = "huge success";
        let response =
            RestartMessage::Response(RestartResponse::RestartFailed(error_message.into()));
        let server_task = tokio::spawn(serve_once(server, response));

        let r = client.send_restart_command().await;
        assert_eq!(r.err().map(|e| e.to_string()), Some(error_message.into()));
        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_unexpected_message() {
        let (mut client, server) = socket_pair();
        // The server wrongly answers the request with another request.
        let response = RestartMessage::Request(RestartRequest::TryRestart);
        let server_task = tokio::spawn(serve_once(server, response));

        let r = client.send_restart_command().await;
        assert_eq!(
            r.err().map(|e| e.to_string()),
            Some("unexpected message received".to_string())
        );
        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_connection_closed() {
        let (mut client, server) = socket_pair();
        // Close the peer before any exchange takes place.
        drop(server);

        assert!(client.send_restart_command().await.is_err());
    }
}