22// SPDX-License-Identifier: AGPL-3.0-only
33
44use std:: sync:: Arc ;
5+ use std:: time:: Duration ;
56
67use async_trait:: async_trait;
78use axum:: body:: Bytes ;
@@ -19,7 +20,7 @@ use mx20022_channels::{
1920 ChannelError , ChannelHealth , DeliveryReceipt , InboundChannel , InboundMessage , OutboundChannel ,
2021 OutboundMessage ,
2122} ;
22- use tokio:: sync:: { mpsc, RwLock } ;
23+ use tokio:: sync:: { mpsc, watch , RwLock } ;
2324use tower_http:: cors:: CorsLayer ;
2425
2526#[ derive( Debug , Clone ) ]
@@ -44,13 +45,18 @@ struct InboundState {
4445pub struct HttpInboundChannel {
4546 config : HttpInboundConfig ,
4647 paused : Arc < RwLock < bool > > ,
48+ shutdown_tx : Arc < watch:: Sender < bool > > ,
49+ shutdown_rx : watch:: Receiver < bool > ,
4750}
4851
4952impl HttpInboundChannel {
5053 pub fn new ( config : HttpInboundConfig ) -> Self {
54+ let ( shutdown_tx, shutdown_rx) = watch:: channel ( false ) ;
5155 Self {
5256 config,
5357 paused : Arc :: new ( RwLock :: new ( false ) ) ,
58+ shutdown_tx : Arc :: new ( shutdown_tx) ,
59+ shutdown_rx,
5460 }
5561 }
5662}
@@ -92,7 +98,16 @@ impl InboundChannel for HttpInboundChannel {
9298 ) )
9399 } ) ?;
94100 tracing:: info!( channel = %self . config. name, bind = %self . config. bind, "http inbound channel starting with TLS" ) ;
101+ let handle = axum_server:: Handle :: new ( ) ;
102+ let shutdown_handle = handle. clone ( ) ;
103+ let mut shutdown_rx = self . shutdown_rx . clone ( ) ;
104+ tokio:: spawn ( async move {
105+ let _ = shutdown_rx. changed ( ) . await ;
106+ tracing:: info!( "TLS graceful shutdown triggered" ) ;
107+ shutdown_handle. graceful_shutdown ( Some ( Duration :: from_secs ( 30 ) ) ) ;
108+ } ) ;
95109 axum_server:: bind_rustls ( socket, tls)
110+ . handle ( handle)
96111 . serve ( app. into_make_service ( ) )
97112 . await
98113 . map_err ( |e| ChannelError :: new ( format ! ( "inbound channel serve failed: {e}" ) ) )
@@ -104,7 +119,13 @@ impl InboundChannel for HttpInboundChannel {
104119 ChannelError :: new ( format ! ( "failed to bind inbound channel: {e}" ) )
105120 } ) ?;
106121 tracing:: warn!( channel = %self . config. name, bind = %self . config. bind, "http inbound channel starting without TLS" ) ;
122+ let mut shutdown_rx = self . shutdown_rx . clone ( ) ;
123+ let shutdown_signal = async move {
124+ let _ = shutdown_rx. changed ( ) . await ;
125+ tracing:: info!( "graceful shutdown triggered" ) ;
126+ } ;
107127 axum:: serve ( listener, app)
128+ . with_graceful_shutdown ( shutdown_signal)
108129 . await
109130 . map_err ( |e| ChannelError :: new ( format ! ( "inbound channel serve failed: {e}" ) ) )
110131 }
@@ -115,6 +136,8 @@ impl InboundChannel for HttpInboundChannel {
115136 }
116137
117138 async fn shutdown ( & self ) -> Result < ( ) , ChannelError > {
139+ tracing:: info!( channel = %self . config. name, "http inbound channel shutting down" ) ;
140+ let _ = self . shutdown_tx . send ( true ) ;
118141 Ok ( ( ) )
119142 }
120143
@@ -307,16 +330,20 @@ fn now_millis() -> u128 {
307330#[ cfg( test) ]
308331mod tests {
309332 use std:: sync:: Arc ;
333+ use std:: time:: Duration ;
310334
311335 use axum:: http:: HeaderMap ;
312336 use axum:: http:: StatusCode ;
313337 use axum:: routing:: post;
314338 use axum:: { extract:: State , Router } ;
315339 use mx20022_channels:: auth:: InboundAuthConfig ;
316- use mx20022_channels:: { OutboundChannel , OutboundMessage } ;
340+ use mx20022_channels:: { InboundChannel , OutboundChannel , OutboundMessage } ;
317341 use tokio:: sync:: { mpsc, RwLock } ;
318342
319- use super :: { handle_post, HttpOutboundChannel , HttpOutboundConfig , InboundState } ;
343+ use super :: {
344+ handle_post, HttpInboundChannel , HttpInboundConfig , HttpOutboundChannel ,
345+ HttpOutboundConfig , InboundState ,
346+ } ;
320347
321348 #[ tokio:: test]
322349 async fn inbound_handler_enqueues_message ( ) {
@@ -336,6 +363,72 @@ mod tests {
336363 assert_eq ! ( queued. content_type, "application/xml" ) ;
337364 }
338365
366+ #[ tokio:: test]
367+ async fn shutdown_drain ( ) {
368+ let ( tx, mut rx) = mpsc:: channel ( 10 ) ;
369+
370+ // Bind to a free port by creating a temporary listener.
371+ let temp_listener = tokio:: net:: TcpListener :: bind ( "127.0.0.1:0" )
372+ . await
373+ . expect ( "bind temp listener" ) ;
374+ let addr = temp_listener. local_addr ( ) . expect ( "resolve addr" ) ;
375+ drop ( temp_listener) ;
376+
377+ let channel = Arc :: new ( HttpInboundChannel :: new ( HttpInboundConfig {
378+ name : "test-shutdown" . to_string ( ) ,
379+ bind : addr. to_string ( ) ,
380+ content_type : "application/xml" . to_string ( ) ,
381+ auth : InboundAuthConfig :: default ( ) ,
382+ cors_allowed_origins : vec ! [ ] ,
383+ tls_cert_path : None ,
384+ tls_key_path : None ,
385+ } ) ) ;
386+
387+ let run_channel = Arc :: clone ( & channel) ;
388+ let handle = tokio:: spawn ( async move { run_channel. run ( tx) . await } ) ;
389+
390+ // Wait for the server to be ready.
391+ tokio:: time:: sleep ( Duration :: from_millis ( 100 ) ) . await ;
392+
393+ // Send a message before shutdown — should succeed.
394+ let client = reqwest:: Client :: new ( ) ;
395+ let resp = client
396+ . post ( format ! ( "http://{addr}/" ) )
397+ . header ( "content-type" , "application/xml" )
398+ . body ( "<Document/>" )
399+ . send ( )
400+ . await
401+ . expect ( "pre-shutdown request" ) ;
402+ assert_eq ! ( resp. status( ) , StatusCode :: ACCEPTED ) ;
403+
404+ // Verify the message was queued.
405+ let msg = rx
406+ . recv ( )
407+ . await
408+ . expect ( "should receive pre-shutdown message" ) ;
409+ assert_eq ! ( msg. raw, "<Document/>" ) ;
410+
411+ // Trigger graceful shutdown.
412+ channel. shutdown ( ) . await . expect ( "shutdown" ) ;
413+
414+ // Wait for the server to drain.
415+ let result = tokio:: time:: timeout ( Duration :: from_secs ( 5 ) , handle) . await ;
416+ assert ! ( result. is_ok( ) , "server should shut down within timeout" ) ;
417+ assert ! ( result. unwrap( ) . is_ok( ) , "server task should not error" ) ;
418+
419+ // After shutdown, new requests should be rejected.
420+ let err = client
421+ . post ( format ! ( "http://{addr}/" ) )
422+ . header ( "content-type" , "application/xml" )
423+ . body ( "<AfterShutdown/>" )
424+ . send ( )
425+ . await ;
426+ assert ! (
427+ err. is_err( ) ,
428+ "post-shutdown request should fail (connection refused)"
429+ ) ;
430+ }
431+
339432 #[ tokio:: test]
340433 async fn outbound_channel_posts_payload ( ) {
341434 let listener = tokio:: net:: TcpListener :: bind ( "127.0.0.1:0" )
0 commit comments