Implementing graceful shutdown #105

Merged
21 commits merged on Aug 8, 2022
Changes from 6 commits
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
.idea
/target
*.deb
.vscode
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -30,4 +30,4 @@ sha2 = "0.10"
base64 = "0.13"
stringprep = "0.1"
tokio-rustls = "0.23"
rustls-pemfile = "1"
rustls-pemfile = "1"
Collaborator: Can you please add a terminal newline.

62 changes: 56 additions & 6 deletions src/client.rs
@@ -4,6 +4,7 @@ use log::{debug, error, info, trace};
use std::collections::HashMap;
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio::sync::broadcast::Receiver;

use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::config::get_config;
@@ -73,12 +74,15 @@ pub struct Client<S, T> {
last_server_id: Option<i32>,

target_pool: ConnectionPool,

shutdown_event_receiver: Receiver<()>,
}

/// Client entrypoint.
pub async fn client_entrypoint(
mut stream: TcpStream,
client_server_map: ClientServerMap,
shutdown_event_receiver: Receiver<()>,
) -> Result<(), Error> {
// Figure out if the client wants TLS or not.
let addr = stream.peer_addr().unwrap();
@@ -97,7 +101,7 @@ pub async fn client_entrypoint(
write_all(&mut stream, yes).await?;

// Negotiate TLS.
match startup_tls(stream, client_server_map).await {
match startup_tls(stream, client_server_map, shutdown_event_receiver).await {
Ok(mut client) => {
info!("Client {:?} connected (TLS)", addr);

@@ -121,7 +125,16 @@ pub async fn client_entrypoint(
let (read, write) = split(stream);

// Continue with regular startup.
match Client::startup(read, write, addr, bytes, client_server_map).await {
match Client::startup(
read,
write,
addr,
bytes,
client_server_map,
shutdown_event_receiver,
)
.await
{
Ok(mut client) => {
info!("Client {:?} connected (plain)", addr);

@@ -142,7 +155,16 @@ pub async fn client_entrypoint(
let (read, write) = split(stream);

// Continue with regular startup.
match Client::startup(read, write, addr, bytes, client_server_map).await {
match Client::startup(
read,
write,
addr,
bytes,
client_server_map,
shutdown_event_receiver,
)
.await
{
Ok(mut client) => {
info!("Client {:?} connected (plain)", addr);

@@ -157,7 +179,16 @@ pub async fn client_entrypoint(
let (read, write) = split(stream);

// Continue with cancel query request.
match Client::cancel(read, write, addr, bytes, client_server_map).await {
match Client::cancel(
read,
write,
addr,
bytes,
client_server_map,
shutdown_event_receiver,
)
.await
{
Ok(mut client) => {
info!("Client {:?} issued a cancel query request", addr);

@@ -214,6 +245,7 @@ where
pub async fn startup_tls(
stream: TcpStream,
client_server_map: ClientServerMap,
shutdown_event_receiver: Receiver<()>,
) -> Result<Client<ReadHalf<TlsStream<TcpStream>>, WriteHalf<TlsStream<TcpStream>>>, Error> {
// Negotiate TLS.
let tls = Tls::new()?;
@@ -237,7 +269,15 @@ pub async fn startup_tls(
Ok((ClientConnectionType::Startup, bytes)) => {
let (read, write) = split(stream);

Client::startup(read, write, addr, bytes, client_server_map).await
Client::startup(
read,
write,
addr,
bytes,
client_server_map,
shutdown_event_receiver,
)
.await
}

// Bad Postgres client.
@@ -258,6 +298,7 @@
addr: std::net::SocketAddr,
bytes: BytesMut, // The rest of the startup message.
client_server_map: ClientServerMap,
shutdown_event_receiver: Receiver<()>,
) -> Result<Client<S, T>, Error> {
let config = get_config();
let stats = get_reporter();
@@ -384,6 +425,7 @@
last_address_id: None,
last_server_id: None,
target_pool: target_pool,
shutdown_event_receiver: shutdown_event_receiver,
});
}

@@ -394,6 +436,7 @@
addr: std::net::SocketAddr,
mut bytes: BytesMut, // The rest of the startup message.
client_server_map: ClientServerMap,
shutdown_event_receiver: Receiver<()>,
) -> Result<Client<S, T>, Error> {
let process_id = bytes.get_i32();
let secret_key = bytes.get_i32();
Expand All @@ -413,6 +456,7 @@ where
last_address_id: None,
last_server_id: None,
target_pool: ConnectionPool::default(),
shutdown_event_receiver: shutdown_event_receiver,
});
}

@@ -467,7 +511,13 @@
// We can parse it here before grabbing a server from the pool,
// in case the client is sending some custom protocol messages, e.g.
// SET SHARDING KEY TO 'bigint';
let mut message = read_message(&mut self.read).await?;

let mut message = tokio::select! {
_ = self.shutdown_event_receiver.recv() => {
return Ok(())
},
message_result = read_message(&mut self.read) => message_result?
};

// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
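For context, the hunk above races the shutdown broadcast against the next client message. Below is a standalone sketch of that select-on-shutdown pattern; the line-based framing and the names (handle_client, shutdown_rx) are illustrative only, not the pooler's actual protocol handling:

use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::net::TcpStream;
use tokio::sync::broadcast::Receiver;

// Reads requests until either the peer disconnects or a shutdown is broadcast.
async fn handle_client(stream: TcpStream, mut shutdown_rx: Receiver<()>) -> std::io::Result<()> {
    let mut reader = BufReader::new(stream);
    let mut line = String::new();

    loop {
        line.clear();
        tokio::select! {
            // The shutdown broadcast wins the race: stop serving this client.
            _ = shutdown_rx.recv() => return Ok(()),
            read = reader.read_line(&mut line) => {
                if read? == 0 {
                    return Ok(()); // peer closed the connection
                }
                // ... process `line` as a request ...
            }
        }
    }
}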
69 changes: 54 additions & 15 deletions src/main.rs
@@ -36,17 +36,17 @@ extern crate tokio;
extern crate tokio_rustls;
extern crate toml;

use log::{debug, error, info};
use log::{debug, error, info, warn};
use parking_lot::Mutex;
use tokio::net::TcpListener;
use tokio::{
signal,
signal::unix::{signal as unix_signal, SignalKind},
sync::mpsc,
};

use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::broadcast;

mod admin;
mod client;
@@ -139,24 +139,51 @@ async fn main() {

info!("Waiting for clients");

let (shutdown_event_tx, mut shutdown_event_rx) = broadcast::channel::<()>(1);

let shutdown_event_tx_clone = shutdown_event_tx.clone();

// Client connection loop.
tokio::task::spawn(async move {
// Creates event subscriber for shutdown event, this is dropped when shutdown event is broadcast
let mut listener_rx = shutdown_event_tx_clone.subscribe();
Contributor: Please rename this to shutdown_event_rx, because it can be confused with the TCP listener by the reader.

loop {
let client_server_map = client_server_map.clone();

let (socket, addr) = match listener.accept().await {
Ok((socket, addr)) => (socket, addr),
Err(err) => {
error!("{:?}", err);
continue;
// Listen for shutdown event and client connection at the same time
let (socket, addr) = tokio::select! {
_ = listener_rx.recv() => {
break;
}

listener_response = listener.accept() => {
match listener_response {
Ok((socket, addr)) => (socket, addr),
Err(err) => {
error!("{:?}", err);
continue;
}
}
}
};

// Used to signal shutdown
let client_shutdown_handler_rx = shutdown_event_tx_clone.subscribe();

// Used to signal that the task has completed
let dummy_tx = shutdown_event_tx_clone.clone();

// Handle client.
tokio::task::spawn(async move {
let start = chrono::offset::Utc::now().naive_utc();

match client::client_entrypoint(socket, client_server_map).await {
match client::client_entrypoint(
socket,
client_server_map,
client_shutdown_handler_rx,
)
.await
{
Ok(_) => {
let duration = chrono::offset::Utc::now().naive_utc() - start;

@@ -171,6 +198,8 @@ async fn main() {
debug!("Client disconnected with error {:?}", err);
}
};
// Drop this transmitter so receiver knows that the task is completed
drop(dummy_tx);
});
}
});
@@ -214,15 +243,25 @@ async fn main() {
});
}

// Exit on Ctrl-C (SIGINT) and SIGTERM.
let mut term_signal = unix_signal(SignalKind::terminate()).unwrap();
// Initiate graceful shutdown sequence on sig int
let mut stream = unix_signal(SignalKind::interrupt()).unwrap();

tokio::select! {
_ = signal::ctrl_c() => (),
Contributor: Curious why you removed the Ctrl-C and SIGTERM handlers? SIGTERM is commonly used to stop Docker containers.

Contributor (Author): PgBouncer interprets SIGTERM and SIGINT differently. Since SIGINT is the graceful shutdown, we want to catch and handle it; SIGTERM, which is immediate shutdown, doesn't need any special handling. From the PgBouncer docs:

SIGINT: Safe shutdown. Same as issuing PAUSE and SHUTDOWN on the console.
SIGTERM: Immediate shutdown. Same as issuing SHUTDOWN on the console.

Contributor: I'm curious about the reasoning here. In Kubernetes, SIGTERM is meant to initiate a graceful shutdown. https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/
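Not part of this PR, which keeps SIGTERM as an immediate shutdown; purely a sketch of the alternative raised in this thread, assuming one wanted SIGTERM to start the same drain as SIGINT (the Kubernetes-style behaviour), wired to the same kind of broadcast channel:

use tokio::signal::unix::{signal as unix_signal, SignalKind};
use tokio::sync::broadcast;

#[tokio::main]
async fn main() {
    // Same kind of broadcast channel the PR uses to tell client tasks to drain.
    let (shutdown_event_tx, _shutdown_event_rx) = broadcast::channel::<()>(1);

    let mut sigint = unix_signal(SignalKind::interrupt()).unwrap();
    let mut sigterm = unix_signal(SignalKind::terminate()).unwrap();

    // Hypothetical: treat either signal as the start of a graceful drain.
    tokio::select! {
        _ = sigint.recv() => eprintln!("Got SIGINT, draining client connections"),
        _ = sigterm.recv() => eprintln!("Got SIGTERM, draining client connections"),
    };

    shutdown_event_tx.send(()).unwrap();
    drop(shutdown_event_tx);
    // ...then wait for client tasks to finish, as the loop further down does.
}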

_ = term_signal.recv() => (),
};
stream.recv().await;
warn!("Got SIGINT, waiting for client connection drain now");

info!("Shutting down...");
// Broadcast that client tasks need to finish
shutdown_event_tx.send(()).unwrap();
// Closes transmitter
drop(shutdown_event_tx);

loop {
Contributor: Why are you looping?

Contributor (Author): The receiver should get two events: the first is the one sent by shutdown_event_tx.send(()).unwrap(), and the second is the error on the receive channel once all senders are closed, which means all clients have completed their work. I made it a loop so that a timeout could be applied to all of these operations.

(A standalone sketch of this sender-drop pattern appears after the diff below.)

match shutdown_event_rx.recv().await {
// The first event the receiver gets is from the initial broadcast, so we ignore that
Ok(_) => {}
// Expect to receive a closed error when all transmitters are closed. Which means all clients have completed their work
Err(_) => break,
};
}
}
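The looping receive above works because a tokio broadcast receiver only returns Err(RecvError::Closed) after every Sender clone has been dropped. Here is a minimal, self-contained sketch of that sender-drop pattern; the task bodies and the names (shutdown_tx, done_tx) are illustrative, not code from this PR:

use tokio::sync::broadcast;
use tokio::time::{sleep, Duration};

#[tokio::main]
async fn main() {
    let (shutdown_tx, mut shutdown_rx) = broadcast::channel::<()>(1);

    // Spawn a few "client" tasks; each holds a clone of the sender purely so
    // that dropping it (when the task finishes) signals completion to main().
    for i in 0..3u64 {
        let mut task_rx = shutdown_tx.subscribe();
        let done_tx = shutdown_tx.clone();
        tokio::spawn(async move {
            // Wait for the shutdown broadcast, then simulate draining work.
            let _ = task_rx.recv().await;
            sleep(Duration::from_millis(50 * i)).await;
            println!("task {} drained", i);
            drop(done_tx); // would also happen implicitly when the task ends
        });
    }

    // Broadcast shutdown, then drop our own sender so only the tasks hold one.
    shutdown_tx.send(()).unwrap();
    drop(shutdown_tx);

    loop {
        match shutdown_rx.recv().await {
            // First event: the broadcast we sent ourselves; ignore it.
            Ok(_) => {}
            // Closed error: every sender clone is gone, so every task is done.
            Err(_) => break,
        }
    }
    println!("all tasks finished, exiting");
}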

/// Format chrono::Duration to be more human-friendly.