Clean up cancellation token

Rayhaan Jaufeerally
2024-07-16 18:58:27 +00:00
parent 75dbfc319a
commit 9dfa39b99d
4 changed files with 78 additions and 83 deletions
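
This change replaces the per-task broadcast::Receiver<()> shutdown channels with a single shared tokio_util::sync::CancellationToken. A minimal, self-contained sketch of that pattern, mirroring the listener and manager loops below (the `worker` task here is hypothetical, not one of the server's tasks):

    use std::time::Duration;
    use tokio_util::sync::CancellationToken;

    // Hypothetical worker loop: does periodic work until the shared token fires.
    async fn worker(shutdown: CancellationToken) {
        loop {
            tokio::select! {
                // cancelled() takes &self and resolves for every clone of the
                // token, so there is no `mut` receiver and no Result to unwrap.
                _ = shutdown.cancelled() => {
                    println!("worker shutting down");
                    return;
                }
                _ = tokio::time::sleep(Duration::from_millis(50)) => {
                    // ... periodic work ...
                }
            }
        }
    }

    #[tokio::main]
    async fn main() {
        let shutdown = CancellationToken::new();
        let handle = tokio::spawn(worker(shutdown.clone()));
        shutdown.cancel(); // one synchronous, infallible call stops every clone
        handle.await.unwrap();
    }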

View File

@ -9,7 +9,6 @@ use netlink_packet_route::route::RouteMessage;
use netlink_packet_route::route::RouteProtocol;
use netlink_packet_route::route::RouteType;
use netlink_packet_route::AddressFamily as NetlinkAddressFamily;
use netlink_packet_utils::nla::Nla;
use rtnetlink::IpVersion;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::{convert::TryInto, io::ErrorKind};

View File

@ -28,10 +28,10 @@ use std::net::Ipv6Addr;
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio::sync::broadcast;
use tokio::sync::mpsc::unbounded_channel;
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use warp::Filter;
use warp::Reply;
@ -44,7 +44,7 @@ async fn socket_listener(
c: UnboundedSender<(TcpStream, SocketAddr)>,
listen_addr: String,
notifier: oneshot::Sender<Result<(), String>>,
mut shutdown: broadcast::Receiver<()>,
shutdown: CancellationToken,
) {
info!("Starting to listen on addr: {}", listen_addr);
let listener_result = TcpListener::bind(&listen_addr).await;
@ -52,7 +52,7 @@ async fn socket_listener(
warn!("Listener for {} failed: {}", listen_addr, e.to_string());
match notifier.send(Err(e.to_string())) {
Ok(_) => {}
Err(e) => warn!("Failed to send notification of channel error: {:?}", e),
Err(e) => warn!(?e, "Failed to send notification of channel error"),
}
return;
}
@ -62,16 +62,19 @@ async fn socket_listener(
Ok(_) => {}
Err(e) => warn!("Failed to send notification of channel ready: {:?}", e),
}
info!("Sucessfully spawned listner for: {}", listen_addr);
info!(listen_addr, "Spawned listener");
loop {
let conn = tokio::select! {
res = listener.accept() => res,
_ = shutdown.recv() => {
_ = shutdown.cancelled() => {
info!("Shutting down listener");
return;
}
};
info!("Got something: {:?}", conn);
info!(?conn, "New inbound connection");
match conn {
Ok((stream, addr)) => {
info!("Accepted socket connection from {}", addr);
@ -98,7 +101,7 @@ async fn start_http_server(
manager6: UnboundedSender<RouteManagerCommands<Ipv6Addr>>,
peers: HashMap<String, UnboundedSender<PeerCommands>>,
listen_addr: SocketAddr,
mut shutdown: broadcast::Receiver<()>,
shutdown: CancellationToken,
) -> Result<tokio::task::JoinHandle<()>, String> {
async fn manager_get_routes_handler<T: serde::ser::Serialize>(
channel: UnboundedSender<RouteManagerCommands<T>>,
@ -118,48 +121,48 @@ async fn start_http_server(
}
}
async fn rm_large_community(
chan: UnboundedSender<PeerCommands>,
ld1: u32,
ld2: u32,
) -> Result<impl warp::Reply, warp::Rejection> {
let (tx, rx) = tokio::sync::oneshot::channel::<String>();
if let Err(e) = chan.send(PeerCommands::RemoveLargeCommunity((ld1, ld2), tx)) {
warn!("Failed to send RemoveLargeCommunity request: {}", e);
return Err(warp::reject());
}
// async fn rm_large_community(
// chan: UnboundedSender<PeerCommands>,
// ld1: u32,
// ld2: u32,
// ) -> Result<impl warp::Reply, warp::Rejection> {
// let (tx, rx) = tokio::sync::oneshot::channel::<String>();
// if let Err(e) = chan.send(PeerCommands::RemoveLargeCommunity((ld1, ld2), tx)) {
// warn!("Failed to send RemoveLargeCommunity request: {}", e);
// return Err(warp::reject());
// }
match rx.await {
Ok(result) => Ok(warp::reply::json(&result)),
Err(e) => {
warn!(
"RemoveLargeCommunity response from peer state machine: {}",
e
);
Err(warp::reject())
}
}
}
// match rx.await {
// Ok(result) => Ok(warp::reply::json(&result)),
// Err(e) => {
// warn!(
// "RemoveLargeCommunity response from peer state machine: {}",
// e
// );
// Err(warp::reject())
// }
// }
// }
async fn add_large_community(
chan: UnboundedSender<PeerCommands>,
ld1: u32,
ld2: u32,
) -> Result<impl warp::Reply, warp::Rejection> {
let (tx, rx) = tokio::sync::oneshot::channel::<String>();
if let Err(e) = chan.send(PeerCommands::AddLargeCommunity((ld1, ld2), tx)) {
warn!("Failed to send AddLargeCommunity request: {}", e);
return Err(warp::reject());
}
// async fn add_large_community(
// chan: UnboundedSender<PeerCommands>,
// ld1: u32,
// ld2: u32,
// ) -> Result<impl warp::Reply, warp::Rejection> {
// let (tx, rx) = tokio::sync::oneshot::channel::<String>();
// if let Err(e) = chan.send(PeerCommands::AddLargeCommunity((ld1, ld2), tx)) {
// warn!("Failed to send AddLargeCommunity request: {}", e);
// return Err(warp::reject());
// }
match rx.await {
Ok(result) => Ok(warp::reply::json(&result)),
Err(e) => {
warn!("AddLargeCommunity response from peer state machine: {}", e);
Err(warp::reject())
}
}
}
// match rx.await {
// Ok(result) => Ok(warp::reply::json(&result)),
// Err(e) => {
// warn!("AddLargeCommunity response from peer state machine: {}", e);
// Err(warp::reject())
// }
// }
// }
// reset_peer_connection causes the PSM to close the connection, flush state, and reconnect to the peer.
async fn reset_peer_connection(
@ -292,7 +295,7 @@ async fn start_http_server(
.or(peers_restart_route);
let (_, server) = warp::serve(routes)
.try_bind_with_graceful_shutdown(listen_addr, async move {
shutdown.recv().await.ok();
shutdown.cancelled().await;
})
.map_err(|e| e.to_string())?;
Ok(tokio::task::spawn(server))
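
The token's cancelled() future also drops straight into warp's graceful-shutdown hook, without the .ok() needed when awaiting a broadcast receiver. A standalone sketch under assumed placeholders (the healthz route and listen address are made up, not part of this server):

    use std::net::SocketAddr;
    use tokio_util::sync::CancellationToken;
    use warp::Filter;

    #[tokio::main]
    async fn main() {
        let shutdown = CancellationToken::new();
        let routes = warp::path("healthz").map(|| "ok"); // placeholder route
        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();

        // try_bind_with_graceful_shutdown accepts any Future<Output = ()>.
        let (bound, server) = warp::serve(routes)
            .try_bind_with_graceful_shutdown(addr, {
                let shutdown = shutdown.clone();
                async move { shutdown.cancelled().await }
            })
            .expect("failed to bind");
        println!("listening on {}", bound);

        let handle = tokio::spawn(server);
        shutdown.cancel(); // begins graceful shutdown; the server future then completes
        handle.await.unwrap();
    }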
@ -303,7 +306,7 @@ pub struct Server {
config: ServerConfig,
// shutdown is used to signal all worker tasks to stop.
shutdown: broadcast::Sender<()>,
shutdown: CancellationToken,
// worker_handles contains the JoinHandle of tasks spawned by the server so that
// we can wait on them for shutdown.
@ -315,7 +318,7 @@ pub struct Server {
impl Server {
pub fn new(config: ServerConfig) -> Server {
let (shutdown, _) = broadcast::channel(1);
let shutdown = CancellationToken::new();
Server {
config,
shutdown,
@ -345,9 +348,12 @@ impl Server {
info!("Starting listener for {}", listen_addr.to_string());
let sender = tcp_in_tx.clone();
let (ready_tx, ready_rx) = oneshot::channel();
let shutdown_channel = self.shutdown.subscribe();
let listen_handle = tokio::spawn(async move {
socket_listener(sender, listen_addr.to_string(), ready_tx, shutdown_channel).await;
let listen_handle = tokio::spawn({
let shutdown = self.shutdown.clone();
async move {
socket_listener(sender, listen_addr.to_string(), ready_tx, shutdown).await;
}
});
self.worker_handles.push(listen_handle);
if wait_startup {
@ -363,7 +369,7 @@ impl Server {
let (rp6_tx, rp6_rx) = unbounded_channel::<RouteManagerCommands<Ipv6Addr>>();
self.mgr_v6 = Some(rp6_tx.clone());
let mut rib_manager6: RibManager<Ipv6Addr> =
RibManager::<Ipv6Addr>::new(rp6_rx, self.shutdown.subscribe()).unwrap();
RibManager::<Ipv6Addr>::new(rp6_rx, self.shutdown.clone()).unwrap();
tokio::spawn(async move {
match rib_manager6.run().await {
Ok(_) => {}
@ -376,7 +382,7 @@ impl Server {
let (rp4_tx, rp4_rx) = unbounded_channel::<RouteManagerCommands<Ipv4Addr>>();
self.mgr_v4 = Some(rp4_tx.clone());
let mut rib_manager4: RibManager<Ipv4Addr> =
RibManager::<Ipv4Addr>::new(rp4_rx, self.shutdown.subscribe()).unwrap();
RibManager::<Ipv4Addr>::new(rp4_rx, self.shutdown.clone()).unwrap();
tokio::spawn(async move {
match rib_manager4.run().await {
Ok(_) => {}
@ -388,7 +394,6 @@ impl Server {
// Start a PeerStateMachine for every peer that is configured and store its channel so that
// we can communicate with it.
let mut peer_statemachines: HashMap<String, (PeerConfig, UnboundedSender<PeerCommands>)> =
HashMap::new();
@ -402,11 +407,10 @@ impl Server {
psm_rx,
psm_tx.clone(),
rp6_tx.clone(),
self.shutdown.subscribe(),
self.shutdown.clone(),
);
self.worker_handles.push(tokio::spawn(async move {
psm.run().await;
warn!("Should not reach here");
}));
}
AddressFamilyIdentifier::Ipv4 => {
@ -416,11 +420,10 @@ impl Server {
psm_rx,
psm_tx.clone(),
rp4_tx.clone(),
self.shutdown.subscribe(),
self.shutdown.clone(),
);
self.worker_handles.push(tokio::spawn(async move {
psm.run().await;
warn!("Should not reach here");
}));
}
_ => panic!("Unsupported address family: {}", peer_config.afi),
@ -442,7 +445,7 @@ impl Server {
rp6_tx.clone(),
peer_chan_map.clone(),
addr,
self.shutdown.subscribe(),
self.shutdown.clone(),
)
.await
.unwrap();
@ -471,12 +474,12 @@ impl Server {
}
// Event loop for processing inbound connections.
let mut shutdown_recv = self.shutdown.subscribe();
let shutdown = self.shutdown.clone();
self.worker_handles.push(tokio::spawn(async move {
loop {
let next = tokio::select! {
cmd = tcp_in_rx.recv() => cmd,
_ = shutdown_recv.recv() => {
_ = shutdown.cancelled() => {
warn!("Peer connection dispatcher shutting down due to shutdown signal.");
return;
}
@ -508,13 +511,7 @@ impl Server {
}
pub async fn shutdown(&mut self) {
match self.shutdown.send(()) {
Ok(_) => {}
Err(e) => {
warn!("Failed to send shutdown signal: {}", e);
return;
}
}
self.shutdown.cancel();
for handle in &mut self.worker_handles {
match handle.await {
Ok(_) => {}
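
Shutdown fan-out then reduces to one cancel() call followed by draining the worker handles. A hedged sketch of that shape (the type and field names here are illustrative, not the server's):

    use tokio::task::JoinHandle;
    use tokio_util::sync::CancellationToken;

    // Illustrative holder mirroring the shutdown/worker_handles pairing above.
    struct Workers {
        shutdown: CancellationToken,
        handles: Vec<JoinHandle<()>>,
    }

    impl Workers {
        async fn shutdown(&mut self) {
            // cancel() is synchronous, infallible, and idempotent: every clone
            // of the token (and any child tokens) sees cancelled() complete.
            self.shutdown.cancel();
            for handle in self.handles.drain(..) {
                if let Err(e) = handle.await {
                    eprintln!("worker task failed: {e}");
                }
            }
        }
    }

    #[tokio::main]
    async fn main() {
        let shutdown = CancellationToken::new();
        let handles = vec![tokio::spawn({
            let token = shutdown.clone();
            async move { token.cancelled().await }
        })];
        let mut workers = Workers { shutdown, handles };
        workers.shutdown().await;
    }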

View File

@ -59,7 +59,6 @@ use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::tcp;
use tokio::net::TcpStream;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::Mutex;
@ -355,7 +354,7 @@ pub struct PeerStateMachine<A: Address> {
keepalive_timer: Option<(JoinHandle<()>, CancellationToken)>,
read_cancel_token: Option<CancellationToken>,
shutdown: broadcast::Receiver<()>,
shutdown: CancellationToken,
}
impl<A: Address> PeerStateMachine<A>
@ -370,7 +369,7 @@ where
iface_rx: mpsc::UnboundedReceiver<PeerCommands>,
iface_tx: mpsc::UnboundedSender<PeerCommands>,
route_manager: mpsc::UnboundedSender<RouteManagerCommands<A>>,
shutdown: broadcast::Receiver<()>,
shutdown: CancellationToken,
) -> PeerStateMachine<A> {
let afi = config.afi;
PeerStateMachine {
@ -423,7 +422,7 @@ where
loop {
let next = tokio::select! {
cmd = self.iface_rx.recv() => cmd,
_ = self.shutdown.recv() => {
_ = self.shutdown.cancelled() => {
warn!("PSM shutting down due to shutdown signal.");
return;
},
@ -528,7 +527,7 @@ where
}
PeerCommands::AddLargeCommunity(c, sender) => {
for mut a in self.config.announcements.iter_mut() {
for a in self.config.announcements.iter_mut() {
if let Some(lcs) = a.large_communities.as_mut() {
lcs.push(format!("{}:{}:{}", self.config.asn, c.0, c.1));
} else {
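
The struct above keeps per-task tokens (keepalive_timer, read_cancel_token) next to the new server-wide shutdown token. tokio_util's child_token() is the usual way to relate such tokens; a sketch of that API, which this commit does not itself use:

    use tokio_util::sync::CancellationToken;

    #[tokio::main]
    async fn main() {
        let server_shutdown = CancellationToken::new();

        // A child token is cancelled when its parent is cancelled, but can
        // also be cancelled on its own without touching the parent.
        let read_cancel = server_shutdown.child_token();

        let reader = tokio::spawn({
            let read_cancel = read_cancel.clone();
            async move {
                read_cancel.cancelled().await;
                println!("reader stopped");
            }
        });

        // Either call unblocks the reader; the commented one stops only the reader.
        server_shutdown.cancel();
        // read_cancel.cancel();

        reader.await.unwrap();
    }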

View File

@ -21,6 +21,7 @@ use crate::server::config::PeerConfig;
use crate::server::data_structures::RouteUpdate;
use crate::server::peer::PeerCommands;
use tokio_util::sync::CancellationToken;
use tracing::{info, trace, warn};
use std::cmp::Eq;
@ -116,7 +117,7 @@ pub struct RibManager<A: Address> {
// Handle for streaming updates to PathSets in the RIB.
pathset_streaming_handle: broadcast::Sender<(u64, PathSet<A>)>,
shutdown: broadcast::Receiver<()>,
shutdown: CancellationToken,
}
impl<A: Address> RibManager<A>
@ -127,7 +128,7 @@ where
{
pub fn new(
chan: mpsc::UnboundedReceiver<RouteManagerCommands<A>>,
shutdown: broadcast::Receiver<()>,
shutdown: CancellationToken,
) -> Result<Self, std::io::Error> {
// TODO: Make this a flag that can be configured.
let (pathset_tx, _) = broadcast::channel(10_000_000);
@ -145,8 +146,8 @@ where
loop {
let next = tokio::select! {
cmd = self.mgr_rx.recv() => cmd,
_ = self.shutdown.recv() => {
warn!("RIB manager shutting down due to shutdown signal.");
_ = self.shutdown.cancelled() => {
warn!("RIB manager shutting down.");
return Ok(());
}
};
@ -227,7 +228,7 @@ where
// reannouncement or fresh announcement.
match path_set.paths.get_mut(&update.peer) {
// Peer already announced this route before.
Some(mut existing) => {
Some(existing) => {
trace!(
"Updating existing path attributes for NLRI: {}/{}",
addr,
@ -339,14 +340,13 @@ mod tests {
use std::net::Ipv6Addr;
use std::str::FromStr;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
#[test]
fn test_manager_process_single() {
let (_, rp_rx) = mpsc::unbounded_channel::<RouteManagerCommands<Ipv6Addr>>();
// Nothing spawned here so no need to send the shutdown signal.
let (_shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel(1);
let mut rib_manager: RibManager<Ipv6Addr> =
RibManager::<Ipv6Addr>::new(rp_rx, shutdown_rx).unwrap();
RibManager::<Ipv6Addr>::new(rp_rx, CancellationToken::new()).unwrap();
let nexthop = Ipv6Addr::new(0x20, 0x01, 0xd, 0xb8, 0, 0, 0, 0x1);
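
On the test side, a fresh CancellationToken::new() can be passed straight to constructors, with no paired sender to keep alive so the receiver does not report the channel as closed. A small, hypothetical illustration (Manager stands in for RibManager):

    use tokio_util::sync::CancellationToken;

    // Hypothetical stand-in for RibManager: it only stores the token it is given.
    struct Manager {
        shutdown: CancellationToken,
    }

    impl Manager {
        fn new(shutdown: CancellationToken) -> Self {
            Manager { shutdown }
        }
    }

    #[test]
    fn constructs_with_a_throwaway_token() {
        // No Sender has to stay in scope, unlike the broadcast-channel version.
        let mgr = Manager::new(CancellationToken::new());
        assert!(!mgr.shutdown.is_cancelled());
    }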