Skip to content

Commit 3719c22

Browse files
authored
Implementing graceful shutdown (#105)
* Initial commit for graceful shutdown * fmt * Add .vscode to gitignore * Updates shutdown logic to use channels * fmt * fmt * Adds shutdown timeout * Fmt and updates tomls * Updates readme * fmt and updates log levels * Update python tests to test shutdown * merge changes * Rename listener rx and update bash to be in line with master * Update python test bash script ordering * Adds error response message before shutdown * Add details on shutdown event loop * Fixes response length for error * Adds handler for sigterm * Uses ready for query function and fixes number of bytes * fmt
1 parent 106ebee commit 3719c22

File tree

12 files changed

+308
-47
lines changed

12 files changed

+308
-47
lines changed

.circleci/pgcat.toml

+3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ connect_timeout = 100
1717
# How much time to give the health check query to return with a result (ms).
1818
healthcheck_timeout = 100
1919

20+
# How much time to give clients during shutdown before forcibly killing client connections (ms).
21+
shutdown_timeout = 5000
22+
2023
# For how long to ban a server if it fails a health check (seconds).
2124
ban_time = 60 # Seconds
2225

.circleci/run_tests.sh

+4-4
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,12 @@ cd ../..
7474

7575
#
7676
# Python tests
77+
# These tests will start and stop the pgcat server so it will need to be restarted after the tests
7778
#
78-
cd tests/python
79-
pip3 install -r requirements.txt
80-
python3 tests.py
81-
cd ../..
79+
pip3 install -r tests/python/requirements.txt
80+
python3 tests/python/tests.py
8281

82+
start_pgcat "info"
8383

8484
# Admin tests
8585
export PGPASSWORD=admin_pass

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.idea
22
/target
33
*.deb
4+
.vscode

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ psql -h 127.0.0.1 -p 6432 -c 'SELECT 1'
4747
| `pool_mode` | The pool mode to use, i.e. `session` or `transaction`. | `transaction` |
4848
| `connect_timeout` | Maximum time to establish a connection to a server (milliseconds). If reached, the server is banned and the next target is attempted. | `5000` |
4949
| `healthcheck_timeout` | Maximum time to pass a health check (`SELECT 1`, milliseconds). If reached, the server is banned and the next target is attempted. | `1000` |
50+
| `shutdown_timeout` | Maximum time to give clients during shutdown before forcibly killing client connections (ms). | `60000` |
5051
| `ban_time` | Ban time for a server (seconds). It won't be allowed to serve transactions until the ban expires; failover targets will be used instead. | `60` |
5152
| | | |
5253
| **`user`** | | |
@@ -250,6 +251,7 @@ The config can be reloaded by sending a `kill -s SIGHUP` to the process or by qu
250251
| `pool_mode` | no |
251252
| `connect_timeout` | yes |
252253
| `healthcheck_timeout` | no |
254+
| `shutdown_timeout` | no |
253255
| `ban_time` | no |
254256
| `user` | yes |
255257
| `shards` | yes |

examples/docker/pgcat.toml

+3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ connect_timeout = 5000
1717
# How much time to give `SELECT 1` health check query to return with a result (ms).
1818
healthcheck_timeout = 1000
1919

20+
# How much time to give clients during shutdown before forcibly killing client connections (ms).
21+
shutdown_timeout = 60000
22+
2023
# For how long to ban a server if it fails a health check (seconds).
2124
ban_time = 60 # seconds
2225

pgcat.toml

+3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ connect_timeout = 5000
1717
# How much time to give the health check query to return with a result (ms).
1818
healthcheck_timeout = 1000
1919

20+
# How much time to give clients during shutdown before forcibly killing client connections (ms).
21+
shutdown_timeout = 60000
22+
2023
# For how long to ban a server if it fails a health check (seconds).
2124
ban_time = 60 # seconds
2225

src/client.rs

+57-6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use log::{debug, error, info, trace};
44
use std::collections::HashMap;
55
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
66
use tokio::net::TcpStream;
7+
use tokio::sync::broadcast::Receiver;
78

89
use crate::admin::{generate_server_info_for_admin, handle_admin};
910
use crate::config::get_config;
@@ -73,12 +74,15 @@ pub struct Client<S, T> {
7374
last_server_id: Option<i32>,
7475

7576
target_pool: ConnectionPool,
77+
78+
shutdown_event_receiver: Receiver<()>,
7679
}
7780

7881
/// Client entrypoint.
7982
pub async fn client_entrypoint(
8083
mut stream: TcpStream,
8184
client_server_map: ClientServerMap,
85+
shutdown_event_receiver: Receiver<()>,
8286
) -> Result<(), Error> {
8387
// Figure out if the client wants TLS or not.
8488
let addr = stream.peer_addr().unwrap();
@@ -97,7 +101,7 @@ pub async fn client_entrypoint(
97101
write_all(&mut stream, yes).await?;
98102

99103
// Negotiate TLS.
100-
match startup_tls(stream, client_server_map).await {
104+
match startup_tls(stream, client_server_map, shutdown_event_receiver).await {
101105
Ok(mut client) => {
102106
info!("Client {:?} connected (TLS)", addr);
103107

@@ -121,7 +125,16 @@ pub async fn client_entrypoint(
121125
let (read, write) = split(stream);
122126

123127
// Continue with regular startup.
124-
match Client::startup(read, write, addr, bytes, client_server_map).await {
128+
match Client::startup(
129+
read,
130+
write,
131+
addr,
132+
bytes,
133+
client_server_map,
134+
shutdown_event_receiver,
135+
)
136+
.await
137+
{
125138
Ok(mut client) => {
126139
info!("Client {:?} connected (plain)", addr);
127140

@@ -142,7 +155,16 @@ pub async fn client_entrypoint(
142155
let (read, write) = split(stream);
143156

144157
// Continue with regular startup.
145-
match Client::startup(read, write, addr, bytes, client_server_map).await {
158+
match Client::startup(
159+
read,
160+
write,
161+
addr,
162+
bytes,
163+
client_server_map,
164+
shutdown_event_receiver,
165+
)
166+
.await
167+
{
146168
Ok(mut client) => {
147169
info!("Client {:?} connected (plain)", addr);
148170

@@ -157,7 +179,16 @@ pub async fn client_entrypoint(
157179
let (read, write) = split(stream);
158180

159181
// Continue with cancel query request.
160-
match Client::cancel(read, write, addr, bytes, client_server_map).await {
182+
match Client::cancel(
183+
read,
184+
write,
185+
addr,
186+
bytes,
187+
client_server_map,
188+
shutdown_event_receiver,
189+
)
190+
.await
191+
{
161192
Ok(mut client) => {
162193
info!("Client {:?} issued a cancel query request", addr);
163194

@@ -214,6 +245,7 @@ where
214245
pub async fn startup_tls(
215246
stream: TcpStream,
216247
client_server_map: ClientServerMap,
248+
shutdown_event_receiver: Receiver<()>,
217249
) -> Result<Client<ReadHalf<TlsStream<TcpStream>>, WriteHalf<TlsStream<TcpStream>>>, Error> {
218250
// Negotiate TLS.
219251
let tls = Tls::new()?;
@@ -237,7 +269,15 @@ pub async fn startup_tls(
237269
Ok((ClientConnectionType::Startup, bytes)) => {
238270
let (read, write) = split(stream);
239271

240-
Client::startup(read, write, addr, bytes, client_server_map).await
272+
Client::startup(
273+
read,
274+
write,
275+
addr,
276+
bytes,
277+
client_server_map,
278+
shutdown_event_receiver,
279+
)
280+
.await
241281
}
242282

243283
// Bad Postgres client.
@@ -258,6 +298,7 @@ where
258298
addr: std::net::SocketAddr,
259299
bytes: BytesMut, // The rest of the startup message.
260300
client_server_map: ClientServerMap,
301+
shutdown_event_receiver: Receiver<()>,
261302
) -> Result<Client<S, T>, Error> {
262303
let config = get_config();
263304
let stats = get_reporter();
@@ -384,6 +425,7 @@ where
384425
last_address_id: None,
385426
last_server_id: None,
386427
target_pool: target_pool,
428+
shutdown_event_receiver: shutdown_event_receiver,
387429
});
388430
}
389431

@@ -394,6 +436,7 @@ where
394436
addr: std::net::SocketAddr,
395437
mut bytes: BytesMut, // The rest of the startup message.
396438
client_server_map: ClientServerMap,
439+
shutdown_event_receiver: Receiver<()>,
397440
) -> Result<Client<S, T>, Error> {
398441
let process_id = bytes.get_i32();
399442
let secret_key = bytes.get_i32();
@@ -413,6 +456,7 @@ where
413456
last_address_id: None,
414457
last_server_id: None,
415458
target_pool: ConnectionPool::default(),
459+
shutdown_event_receiver: shutdown_event_receiver,
416460
});
417461
}
418462

@@ -467,7 +511,14 @@ where
467511
// We can parse it here before grabbing a server from the pool,
468512
// in case the client is sending some custom protocol messages, e.g.
469513
// SET SHARDING KEY TO 'bigint';
470-
let mut message = read_message(&mut self.read).await?;
514+
515+
let mut message = tokio::select! {
516+
_ = self.shutdown_event_receiver.recv() => {
517+
error_response_terminal(&mut self.write, &format!("terminating connection due to administrator command")).await?;
518+
return Ok(())
519+
},
520+
message_result = read_message(&mut self.read) => message_result?
521+
};
471522

472523
// Get a pool instance referenced by the most up-to-date
473524
// pointer. This ensures we always read the latest config

src/config.rs

+7
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ pub struct General {
119119
pub port: i16,
120120
pub connect_timeout: u64,
121121
pub healthcheck_timeout: u64,
122+
pub shutdown_timeout: u64,
122123
pub ban_time: i64,
123124
pub autoreload: bool,
124125
pub tls_certificate: Option<String>,
@@ -134,6 +135,7 @@ impl Default for General {
134135
port: 5432,
135136
connect_timeout: 5000,
136137
healthcheck_timeout: 1000,
138+
shutdown_timeout: 60000,
137139
ban_time: 60,
138140
autoreload: false,
139141
tls_certificate: None,
@@ -273,6 +275,10 @@ impl From<&Config> for std::collections::HashMap<String, String> {
273275
"healthcheck_timeout".to_string(),
274276
config.general.healthcheck_timeout.to_string(),
275277
),
278+
(
279+
"shutdown_timeout".to_string(),
280+
config.general.shutdown_timeout.to_string(),
281+
),
276282
("ban_time".to_string(), config.general.ban_time.to_string()),
277283
];
278284

@@ -290,6 +296,7 @@ impl Config {
290296
self.general.healthcheck_timeout
291297
);
292298
info!("Connection timeout: {}ms", self.general.connect_timeout);
299+
info!("Shutdown timeout: {}ms", self.general.shutdown_timeout);
293300
match self.general.tls_certificate.clone() {
294301
Some(tls_certificate) => {
295302
info!("TLS certificate: {}", tls_certificate);

src/main.rs

+68-10
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,13 @@ use log::{debug, error, info};
4040
use parking_lot::Mutex;
4141
use tokio::net::TcpListener;
4242
use tokio::{
43-
signal,
4443
signal::unix::{signal as unix_signal, SignalKind},
4544
sync::mpsc,
4645
};
4746

4847
use std::collections::HashMap;
4948
use std::sync::Arc;
49+
use tokio::sync::broadcast;
5050

5151
mod admin;
5252
mod client;
@@ -139,24 +139,52 @@ async fn main() {
139139

140140
info!("Waiting for clients");
141141

142+
let (shutdown_event_tx, mut shutdown_event_rx) = broadcast::channel::<()>(1);
143+
144+
let shutdown_event_tx_clone = shutdown_event_tx.clone();
145+
142146
// Client connection loop.
143147
tokio::task::spawn(async move {
148+
// Creates event subscriber for shutdown event, this is dropped when shutdown event is broadcast
149+
let mut listener_shutdown_event_rx = shutdown_event_tx_clone.subscribe();
144150
loop {
145151
let client_server_map = client_server_map.clone();
146152

147-
let (socket, addr) = match listener.accept().await {
148-
Ok((socket, addr)) => (socket, addr),
149-
Err(err) => {
150-
error!("{:?}", err);
151-
continue;
153+
// Listen for shutdown event and client connection at the same time
154+
let (socket, addr) = tokio::select! {
155+
_ = listener_shutdown_event_rx.recv() => {
156+
// Exits client connection loop which drops listener, listener_shutdown_event_rx and shutdown_event_tx_clone
157+
break;
158+
}
159+
160+
listener_response = listener.accept() => {
161+
match listener_response {
162+
Ok((socket, addr)) => (socket, addr),
163+
Err(err) => {
164+
error!("{:?}", err);
165+
continue;
166+
}
167+
}
152168
}
153169
};
154170

171+
// Used to signal shutdown
172+
let client_shutdown_handler_rx = shutdown_event_tx_clone.subscribe();
173+
174+
// Used to signal that the task has completed
175+
let dummy_tx = shutdown_event_tx_clone.clone();
176+
155177
// Handle client.
156178
tokio::task::spawn(async move {
157179
let start = chrono::offset::Utc::now().naive_utc();
158180

159-
match client::client_entrypoint(socket, client_server_map).await {
181+
match client::client_entrypoint(
182+
socket,
183+
client_server_map,
184+
client_shutdown_handler_rx,
185+
)
186+
.await
187+
{
160188
Ok(_) => {
161189
let duration = chrono::offset::Utc::now().naive_utc() - start;
162190

@@ -171,6 +199,8 @@ async fn main() {
171199
debug!("Client disconnected with error {:?}", err);
172200
}
173201
};
202+
// Drop this transmitter so receiver knows that the task is completed
203+
drop(dummy_tx);
174204
});
175205
}
176206
});
@@ -214,13 +244,41 @@ async fn main() {
214244
});
215245
}
216246

217-
// Exit on Ctrl-C (SIGINT) and SIGTERM.
218247
let mut term_signal = unix_signal(SignalKind::terminate()).unwrap();
248+
let mut interrupt_signal = unix_signal(SignalKind::interrupt()).unwrap();
219249

220250
tokio::select! {
221-
_ = signal::ctrl_c() => (),
251+
// Initiate graceful shutdown sequence on sig int
252+
_ = interrupt_signal.recv() => {
253+
info!("Got SIGINT, waiting for client connection drain now");
254+
255+
// Broadcast that client tasks need to finish
256+
shutdown_event_tx.send(()).unwrap();
257+
// Closes transmitter
258+
drop(shutdown_event_tx);
259+
260+
// This is in a loop because the first event that the receiver receives will be the shutdown event
261+
// This is not what we are waiting for instead, we want the receiver to send an error once all senders are closed which is reached after the shutdown event is received
262+
loop {
263+
match tokio::time::timeout(
264+
tokio::time::Duration::from_millis(config.general.shutdown_timeout),
265+
shutdown_event_rx.recv(),
266+
)
267+
.await
268+
{
269+
Ok(res) => match res {
270+
Ok(_) => {}
271+
Err(_) => break,
272+
},
273+
Err(_) => {
274+
info!("Timed out while waiting for clients to shutdown");
275+
break;
276+
}
277+
}
278+
}
279+
},
222280
_ = term_signal.recv() => (),
223-
};
281+
}
224282

225283
info!("Shutting down...");
226284
}

0 commit comments

Comments
 (0)