Clean up cancellation token

This commit is contained in:
Rayhaan Jaufeerally
2024-07-16 18:58:27 +00:00
parent 75dbfc319a
commit 9dfa39b99d
4 changed files with 78 additions and 83 deletions

View File

@ -9,7 +9,6 @@ use netlink_packet_route::route::RouteMessage;
use netlink_packet_route::route::RouteProtocol; use netlink_packet_route::route::RouteProtocol;
use netlink_packet_route::route::RouteType; use netlink_packet_route::route::RouteType;
use netlink_packet_route::AddressFamily as NetlinkAddressFamily; use netlink_packet_route::AddressFamily as NetlinkAddressFamily;
use netlink_packet_utils::nla::Nla;
use rtnetlink::IpVersion; use rtnetlink::IpVersion;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::{convert::TryInto, io::ErrorKind}; use std::{convert::TryInto, io::ErrorKind};

View File

@ -28,10 +28,10 @@ use std::net::Ipv6Addr;
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::broadcast;
use tokio::sync::mpsc::unbounded_channel; use tokio::sync::mpsc::unbounded_channel;
use tokio::sync::mpsc::UnboundedSender; use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn}; use tracing::{info, warn};
use warp::Filter; use warp::Filter;
use warp::Reply; use warp::Reply;
@ -44,7 +44,7 @@ async fn socket_listener(
c: UnboundedSender<(TcpStream, SocketAddr)>, c: UnboundedSender<(TcpStream, SocketAddr)>,
listen_addr: String, listen_addr: String,
notifier: oneshot::Sender<Result<(), String>>, notifier: oneshot::Sender<Result<(), String>>,
mut shutdown: broadcast::Receiver<()>, shutdown: CancellationToken,
) { ) {
info!("Starting to listen on addr: {}", listen_addr); info!("Starting to listen on addr: {}", listen_addr);
let listener_result = TcpListener::bind(&listen_addr).await; let listener_result = TcpListener::bind(&listen_addr).await;
@ -52,7 +52,7 @@ async fn socket_listener(
warn!("Listener for {} failed: {}", listen_addr, e.to_string()); warn!("Listener for {} failed: {}", listen_addr, e.to_string());
match notifier.send(Err(e.to_string())) { match notifier.send(Err(e.to_string())) {
Ok(_) => {} Ok(_) => {}
Err(e) => warn!("Failed to send notification of channel error: {:?}", e), Err(e) => warn!(?e, "Failed to send notification of channel error"),
} }
return; return;
} }
@ -62,16 +62,19 @@ async fn socket_listener(
Ok(_) => {} Ok(_) => {}
Err(e) => warn!("Failed to send notification of channel ready: {:?}", e), Err(e) => warn!("Failed to send notification of channel ready: {:?}", e),
} }
info!("Sucessfully spawned listner for: {}", listen_addr);
info!(listen_addr, "Spawned listner");
loop { loop {
let conn = tokio::select! { let conn = tokio::select! {
res = listener.accept() => res, res = listener.accept() => res,
_ = shutdown.recv() => { _ = shutdown.cancelled() => {
info!("Shutting down listener"); info!("Shutting down listener");
return; return;
} }
}; };
info!("Got something: {:?}", conn);
info!(?conn, "New inbound connection");
match conn { match conn {
Ok((stream, addr)) => { Ok((stream, addr)) => {
info!("Accepted socket connection from {}", addr); info!("Accepted socket connection from {}", addr);
@ -98,7 +101,7 @@ async fn start_http_server(
manager6: UnboundedSender<RouteManagerCommands<Ipv6Addr>>, manager6: UnboundedSender<RouteManagerCommands<Ipv6Addr>>,
peers: HashMap<String, UnboundedSender<PeerCommands>>, peers: HashMap<String, UnboundedSender<PeerCommands>>,
listen_addr: SocketAddr, listen_addr: SocketAddr,
mut shutdown: broadcast::Receiver<()>, shutdown: CancellationToken,
) -> Result<tokio::task::JoinHandle<()>, String> { ) -> Result<tokio::task::JoinHandle<()>, String> {
async fn manager_get_routes_handler<T: serde::ser::Serialize>( async fn manager_get_routes_handler<T: serde::ser::Serialize>(
channel: UnboundedSender<RouteManagerCommands<T>>, channel: UnboundedSender<RouteManagerCommands<T>>,
@ -118,48 +121,48 @@ async fn start_http_server(
} }
} }
async fn rm_large_community( // async fn rm_large_community(
chan: UnboundedSender<PeerCommands>, // chan: UnboundedSender<PeerCommands>,
ld1: u32, // ld1: u32,
ld2: u32, // ld2: u32,
) -> Result<impl warp::Reply, warp::Rejection> { // ) -> Result<impl warp::Reply, warp::Rejection> {
let (tx, rx) = tokio::sync::oneshot::channel::<String>(); // let (tx, rx) = tokio::sync::oneshot::channel::<String>();
if let Err(e) = chan.send(PeerCommands::RemoveLargeCommunity((ld1, ld2), tx)) { // if let Err(e) = chan.send(PeerCommands::RemoveLargeCommunity((ld1, ld2), tx)) {
warn!("Failed to send RemoveLargeCommunity request: {}", e); // warn!("Failed to send RemoveLargeCommunity request: {}", e);
return Err(warp::reject()); // return Err(warp::reject());
} // }
match rx.await { // match rx.await {
Ok(result) => Ok(warp::reply::json(&result)), // Ok(result) => Ok(warp::reply::json(&result)),
Err(e) => { // Err(e) => {
warn!( // warn!(
"RemoveLargeCommunity response from peer state machine: {}", // "RemoveLargeCommunity response from peer state machine: {}",
e // e
); // );
Err(warp::reject()) // Err(warp::reject())
} // }
} // }
} // }
async fn add_large_community( // async fn add_large_community(
chan: UnboundedSender<PeerCommands>, // chan: UnboundedSender<PeerCommands>,
ld1: u32, // ld1: u32,
ld2: u32, // ld2: u32,
) -> Result<impl warp::Reply, warp::Rejection> { // ) -> Result<impl warp::Reply, warp::Rejection> {
let (tx, rx) = tokio::sync::oneshot::channel::<String>(); // let (tx, rx) = tokio::sync::oneshot::channel::<String>();
if let Err(e) = chan.send(PeerCommands::AddLargeCommunity((ld1, ld2), tx)) { // if let Err(e) = chan.send(PeerCommands::AddLargeCommunity((ld1, ld2), tx)) {
warn!("Failed to send AddLargeCommunity request: {}", e); // warn!("Failed to send AddLargeCommunity request: {}", e);
return Err(warp::reject()); // return Err(warp::reject());
} // }
match rx.await { // match rx.await {
Ok(result) => Ok(warp::reply::json(&result)), // Ok(result) => Ok(warp::reply::json(&result)),
Err(e) => { // Err(e) => {
warn!("AddLargeCommunity response from peer state machine: {}", e); // warn!("AddLargeCommunity response from peer state machine: {}", e);
Err(warp::reject()) // Err(warp::reject())
} // }
} // }
} // }
// reset_peer_connection causes the PSM to close the connection, flush state, and reconnect to the peer. // reset_peer_connection causes the PSM to close the connection, flush state, and reconnect to the peer.
async fn reset_peer_connection( async fn reset_peer_connection(
@ -292,7 +295,7 @@ async fn start_http_server(
.or(peers_restart_route); .or(peers_restart_route);
let (_, server) = warp::serve(routes) let (_, server) = warp::serve(routes)
.try_bind_with_graceful_shutdown(listen_addr, async move { .try_bind_with_graceful_shutdown(listen_addr, async move {
shutdown.recv().await.ok(); shutdown.cancelled().await;
}) })
.map_err(|e| e.to_string())?; .map_err(|e| e.to_string())?;
Ok(tokio::task::spawn(server)) Ok(tokio::task::spawn(server))
@ -303,7 +306,7 @@ pub struct Server {
config: ServerConfig, config: ServerConfig,
// shutdown is a channel that a // shutdown is a channel that a
shutdown: broadcast::Sender<()>, shutdown: CancellationToken,
// worker_handles contains the JoinHandle of tasks spawned by the server so that // worker_handles contains the JoinHandle of tasks spawned by the server so that
// we can wait on them for shutdown. // we can wait on them for shutdown.
@ -315,7 +318,7 @@ pub struct Server {
impl Server { impl Server {
pub fn new(config: ServerConfig) -> Server { pub fn new(config: ServerConfig) -> Server {
let (shutdown, _) = broadcast::channel(1); let shutdown = CancellationToken::new();
Server { Server {
config, config,
shutdown, shutdown,
@ -345,9 +348,12 @@ impl Server {
info!("Starting listener for {}", listen_addr.to_string()); info!("Starting listener for {}", listen_addr.to_string());
let sender = tcp_in_tx.clone(); let sender = tcp_in_tx.clone();
let (ready_tx, ready_rx) = oneshot::channel(); let (ready_tx, ready_rx) = oneshot::channel();
let shutdown_channel = self.shutdown.subscribe();
let listen_handle = tokio::spawn(async move { let listen_handle = tokio::spawn({
socket_listener(sender, listen_addr.to_string(), ready_tx, shutdown_channel).await; let shutdown = self.shutdown.clone();
async move {
socket_listener(sender, listen_addr.to_string(), ready_tx, shutdown).await;
}
}); });
self.worker_handles.push(listen_handle); self.worker_handles.push(listen_handle);
if wait_startup { if wait_startup {
@ -363,7 +369,7 @@ impl Server {
let (rp6_tx, rp6_rx) = unbounded_channel::<RouteManagerCommands<Ipv6Addr>>(); let (rp6_tx, rp6_rx) = unbounded_channel::<RouteManagerCommands<Ipv6Addr>>();
self.mgr_v6 = Some(rp6_tx.clone()); self.mgr_v6 = Some(rp6_tx.clone());
let mut rib_manager6: RibManager<Ipv6Addr> = let mut rib_manager6: RibManager<Ipv6Addr> =
RibManager::<Ipv6Addr>::new(rp6_rx, self.shutdown.subscribe()).unwrap(); RibManager::<Ipv6Addr>::new(rp6_rx, self.shutdown.clone()).unwrap();
tokio::spawn(async move { tokio::spawn(async move {
match rib_manager6.run().await { match rib_manager6.run().await {
Ok(_) => {} Ok(_) => {}
@ -376,7 +382,7 @@ impl Server {
let (rp4_tx, rp4_rx) = unbounded_channel::<RouteManagerCommands<Ipv4Addr>>(); let (rp4_tx, rp4_rx) = unbounded_channel::<RouteManagerCommands<Ipv4Addr>>();
self.mgr_v4 = Some(rp4_tx.clone()); self.mgr_v4 = Some(rp4_tx.clone());
let mut rib_manager4: RibManager<Ipv4Addr> = let mut rib_manager4: RibManager<Ipv4Addr> =
RibManager::<Ipv4Addr>::new(rp4_rx, self.shutdown.subscribe()).unwrap(); RibManager::<Ipv4Addr>::new(rp4_rx, self.shutdown.clone()).unwrap();
tokio::spawn(async move { tokio::spawn(async move {
match rib_manager4.run().await { match rib_manager4.run().await {
Ok(_) => {} Ok(_) => {}
@ -388,7 +394,6 @@ impl Server {
// Start a PeerStateMachine for every peer that is configured and store its channel so that // Start a PeerStateMachine for every peer that is configured and store its channel so that
// we can communicate with it. // we can communicate with it.
let mut peer_statemachines: HashMap<String, (PeerConfig, UnboundedSender<PeerCommands>)> = let mut peer_statemachines: HashMap<String, (PeerConfig, UnboundedSender<PeerCommands>)> =
HashMap::new(); HashMap::new();
@ -402,11 +407,10 @@ impl Server {
psm_rx, psm_rx,
psm_tx.clone(), psm_tx.clone(),
rp6_tx.clone(), rp6_tx.clone(),
self.shutdown.subscribe(), self.shutdown.clone(),
); );
self.worker_handles.push(tokio::spawn(async move { self.worker_handles.push(tokio::spawn(async move {
psm.run().await; psm.run().await;
warn!("Should not reach here");
})); }));
} }
AddressFamilyIdentifier::Ipv4 => { AddressFamilyIdentifier::Ipv4 => {
@ -416,11 +420,10 @@ impl Server {
psm_rx, psm_rx,
psm_tx.clone(), psm_tx.clone(),
rp4_tx.clone(), rp4_tx.clone(),
self.shutdown.subscribe(), self.shutdown.clone(),
); );
self.worker_handles.push(tokio::spawn(async move { self.worker_handles.push(tokio::spawn(async move {
psm.run().await; psm.run().await;
warn!("Should not reach here");
})); }));
} }
_ => panic!("Unsupported address family: {}", peer_config.afi), _ => panic!("Unsupported address family: {}", peer_config.afi),
@ -442,7 +445,7 @@ impl Server {
rp6_tx.clone(), rp6_tx.clone(),
peer_chan_map.clone(), peer_chan_map.clone(),
addr, addr,
self.shutdown.subscribe(), self.shutdown.clone(),
) )
.await .await
.unwrap(); .unwrap();
@ -471,12 +474,12 @@ impl Server {
} }
// Event loop for processing inbound connections. // Event loop for processing inbound connections.
let mut shutdown_recv = self.shutdown.subscribe(); let shutdown = self.shutdown.clone();
self.worker_handles.push(tokio::spawn(async move { self.worker_handles.push(tokio::spawn(async move {
loop { loop {
let next = tokio::select! { let next = tokio::select! {
cmd = tcp_in_rx.recv() => cmd, cmd = tcp_in_rx.recv() => cmd,
_ = shutdown_recv.recv() => { _ = shutdown.cancelled() => {
warn!("Peer connection dispatcher shutting down due to shutdown signal."); warn!("Peer connection dispatcher shutting down due to shutdown signal.");
return; return;
} }
@ -508,13 +511,7 @@ impl Server {
} }
pub async fn shutdown(&mut self) { pub async fn shutdown(&mut self) {
match self.shutdown.send(()) { self.shutdown.cancel();
Ok(_) => {}
Err(e) => {
warn!("Failed to send shutdown signal: {}", e);
return;
}
}
for handle in &mut self.worker_handles { for handle in &mut self.worker_handles {
match handle.await { match handle.await {
Ok(_) => {} Ok(_) => {}

View File

@ -59,7 +59,6 @@ use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use tokio::net::tcp; use tokio::net::tcp;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::broadcast;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tokio::sync::Mutex; use tokio::sync::Mutex;
@ -355,7 +354,7 @@ pub struct PeerStateMachine<A: Address> {
keepalive_timer: Option<(JoinHandle<()>, CancellationToken)>, keepalive_timer: Option<(JoinHandle<()>, CancellationToken)>,
read_cancel_token: Option<CancellationToken>, read_cancel_token: Option<CancellationToken>,
shutdown: broadcast::Receiver<()>, shutdown: CancellationToken,
} }
impl<A: Address> PeerStateMachine<A> impl<A: Address> PeerStateMachine<A>
@ -370,7 +369,7 @@ where
iface_rx: mpsc::UnboundedReceiver<PeerCommands>, iface_rx: mpsc::UnboundedReceiver<PeerCommands>,
iface_tx: mpsc::UnboundedSender<PeerCommands>, iface_tx: mpsc::UnboundedSender<PeerCommands>,
route_manager: mpsc::UnboundedSender<RouteManagerCommands<A>>, route_manager: mpsc::UnboundedSender<RouteManagerCommands<A>>,
shutdown: broadcast::Receiver<()>, shutdown: CancellationToken,
) -> PeerStateMachine<A> { ) -> PeerStateMachine<A> {
let afi = config.afi; let afi = config.afi;
PeerStateMachine { PeerStateMachine {
@ -423,7 +422,7 @@ where
loop { loop {
let next = tokio::select! { let next = tokio::select! {
cmd = self.iface_rx.recv() => cmd, cmd = self.iface_rx.recv() => cmd,
_ = self.shutdown.recv() => { _ = self.shutdown.cancelled() => {
warn!("PSM shutting down due to shutdown signal."); warn!("PSM shutting down due to shutdown signal.");
return; return;
}, },
@ -528,7 +527,7 @@ where
} }
PeerCommands::AddLargeCommunity(c, sender) => { PeerCommands::AddLargeCommunity(c, sender) => {
for mut a in self.config.announcements.iter_mut() { for a in self.config.announcements.iter_mut() {
if let Some(lcs) = a.large_communities.as_mut() { if let Some(lcs) = a.large_communities.as_mut() {
lcs.push(format!("{}:{}:{}", self.config.asn, c.0, c.1)); lcs.push(format!("{}:{}:{}", self.config.asn, c.0, c.1));
} else { } else {

View File

@ -21,6 +21,7 @@ use crate::server::config::PeerConfig;
use crate::server::data_structures::RouteUpdate; use crate::server::data_structures::RouteUpdate;
use crate::server::peer::PeerCommands; use crate::server::peer::PeerCommands;
use tokio_util::sync::CancellationToken;
use tracing::{info, trace, warn}; use tracing::{info, trace, warn};
use std::cmp::Eq; use std::cmp::Eq;
@ -116,7 +117,7 @@ pub struct RibManager<A: Address> {
// Handle for streaming updates to PathSets in the RIB. // Handle for streaming updates to PathSets in the RIB.
pathset_streaming_handle: broadcast::Sender<(u64, PathSet<A>)>, pathset_streaming_handle: broadcast::Sender<(u64, PathSet<A>)>,
shutdown: broadcast::Receiver<()>, shutdown: CancellationToken,
} }
impl<A: Address> RibManager<A> impl<A: Address> RibManager<A>
@ -127,7 +128,7 @@ where
{ {
pub fn new( pub fn new(
chan: mpsc::UnboundedReceiver<RouteManagerCommands<A>>, chan: mpsc::UnboundedReceiver<RouteManagerCommands<A>>,
shutdown: broadcast::Receiver<()>, shutdown: CancellationToken,
) -> Result<Self, std::io::Error> { ) -> Result<Self, std::io::Error> {
// TODO: Make this a flag that can be configured. // TODO: Make this a flag that can be configured.
let (pathset_tx, _) = broadcast::channel(10_000_000); let (pathset_tx, _) = broadcast::channel(10_000_000);
@ -145,8 +146,8 @@ where
loop { loop {
let next = tokio::select! { let next = tokio::select! {
cmd = self.mgr_rx.recv() => cmd, cmd = self.mgr_rx.recv() => cmd,
_ = self.shutdown.recv() => { _ = self.shutdown.cancelled() => {
warn!("RIB manager shutting down due to shutdown signal."); warn!("RIB manager shutting down.");
return Ok(()); return Ok(());
} }
}; };
@ -227,7 +228,7 @@ where
// reannouncement or fresh announcement. // reannouncement or fresh announcement.
match path_set.paths.get_mut(&update.peer) { match path_set.paths.get_mut(&update.peer) {
// Peer already announced this route before. // Peer already announced this route before.
Some(mut existing) => { Some(existing) => {
trace!( trace!(
"Updating existing path attributes for NLRI: {}/{}", "Updating existing path attributes for NLRI: {}/{}",
addr, addr,
@ -339,14 +340,13 @@ mod tests {
use std::net::Ipv6Addr; use std::net::Ipv6Addr;
use std::str::FromStr; use std::str::FromStr;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
#[test] #[test]
fn test_manager_process_single() { fn test_manager_process_single() {
let (_, rp_rx) = mpsc::unbounded_channel::<RouteManagerCommands<Ipv6Addr>>(); let (_, rp_rx) = mpsc::unbounded_channel::<RouteManagerCommands<Ipv6Addr>>();
// Nothing spaawned here so no need to send the shutdown signal.
let (_shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel(1);
let mut rib_manager: RibManager<Ipv6Addr> = let mut rib_manager: RibManager<Ipv6Addr> =
RibManager::<Ipv6Addr>::new(rp_rx, shutdown_rx).unwrap(); RibManager::<Ipv6Addr>::new(rp_rx, CancellationToken::new()).unwrap();
let nexthop = Ipv6Addr::new(0x20, 0x01, 0xd, 0xb8, 0, 0, 0, 0x1); let nexthop = Ipv6Addr::new(0x20, 0x01, 0xd, 0xb8, 0, 0, 0, 0x1);