Line data Source code
1 : use std::{error::Error, sync::Arc};
2 :
3 : use futures::StreamExt;
4 : use futures::stream::FuturesUnordered;
5 : use hyper0::Body;
6 : use hyper0::server::conn::Http;
7 : use routerify::{RequestService, RequestServiceBuilder};
8 : use tokio::io::{AsyncRead, AsyncWrite};
9 : use tokio_rustls::TlsAcceptor;
10 : use tokio_util::sync::CancellationToken;
11 : use tracing::{error, info};
12 :
13 : use crate::error::ApiError;
14 :
15 : /// A simple HTTP server over hyper library.
16 : /// You may want to use it instead of [`hyper0::server::Server`] because:
17 : /// 1. hyper0's Server was removed from hyper v1.
18 : /// It's recommended to replace hyepr0's Server with a manual loop, which is done here.
19 : /// 2. hyper0's Server doesn't support TLS out of the box, and there is no way
20 : /// to support it efficiently with the Accept trait that hyper0's Server uses.
21 : /// That's one of the reasons why it was removed from v1.
22 : /// <https://github.com/hyperium/hyper/blob/115339d3df50f20c8717680aa35f48858e9a6205/docs/ROADMAP.md#higher-level-client-and-server-problems>
23 : pub struct Server {
24 : request_service: Arc<RequestServiceBuilder<Body, ApiError>>,
25 : listener: tokio::net::TcpListener,
26 : tls_acceptor: Option<TlsAcceptor>,
27 : }
28 :
29 : impl Server {
30 0 : pub fn new(
31 0 : request_service: Arc<RequestServiceBuilder<Body, ApiError>>,
32 0 : listener: std::net::TcpListener,
33 0 : tls_acceptor: Option<TlsAcceptor>,
34 0 : ) -> anyhow::Result<Self> {
35 0 : // Note: caller of from_std is responsible for setting nonblocking mode.
36 0 : listener.set_nonblocking(true)?;
37 0 : let listener = tokio::net::TcpListener::from_std(listener)?;
38 :
39 0 : Ok(Self {
40 0 : request_service,
41 0 : listener,
42 0 : tls_acceptor,
43 0 : })
44 0 : }
45 :
46 0 : pub async fn serve(self, cancel: CancellationToken) -> anyhow::Result<()> {
47 0 : fn suppress_io_error(err: &std::io::Error) -> bool {
48 : use std::io::ErrorKind::*;
49 0 : matches!(err.kind(), ConnectionReset | ConnectionAborted | BrokenPipe)
50 0 : }
51 0 : fn suppress_hyper_error(err: &hyper0::Error) -> bool {
52 0 : if err.is_incomplete_message() || err.is_closed() || err.is_timeout() {
53 0 : return true;
54 0 : }
55 0 : if let Some(inner) = err.source() {
56 0 : if let Some(io) = inner.downcast_ref::<std::io::Error>() {
57 0 : return suppress_io_error(io);
58 0 : }
59 0 : }
60 0 : false
61 0 : }
62 :
63 0 : let mut connections = FuturesUnordered::new();
64 : loop {
65 0 : tokio::select! {
66 0 : stream = self.listener.accept() => {
67 0 : let (tcp_stream, remote_addr) = match stream {
68 0 : Ok(stream) => stream,
69 0 : Err(err) => {
70 0 : if !suppress_io_error(&err) {
71 0 : info!("Failed to accept TCP connection: {err:#}");
72 0 : }
73 0 : continue;
74 : }
75 : };
76 :
77 0 : let service = self.request_service.build(remote_addr);
78 0 : let tls_acceptor = self.tls_acceptor.clone();
79 0 : let cancel = cancel.clone();
80 0 :
81 0 : connections.push(tokio::spawn(
82 0 : async move {
83 0 : match tls_acceptor {
84 0 : Some(tls_acceptor) => {
85 : // Handle HTTPS connection.
86 0 : let tls_stream = tokio::select! {
87 0 : tls_stream = tls_acceptor.accept(tcp_stream) => tls_stream,
88 0 : _ = cancel.cancelled() => return,
89 : };
90 0 : let tls_stream = match tls_stream {
91 0 : Ok(tls_stream) => tls_stream,
92 0 : Err(err) => {
93 0 : if !suppress_io_error(&err) {
94 0 : info!("Failed to accept TLS connection: {err:#}");
95 0 : }
96 0 : return;
97 : }
98 : };
99 0 : if let Err(err) = Self::serve_connection(tls_stream, service, cancel).await {
100 0 : if !suppress_hyper_error(&err) {
101 0 : info!("Failed to serve HTTPS connection: {err:#}");
102 0 : }
103 0 : }
104 : }
105 : None => {
106 : // Handle HTTP connection.
107 0 : if let Err(err) = Self::serve_connection(tcp_stream, service, cancel).await {
108 0 : if !suppress_hyper_error(&err) {
109 0 : info!("Failed to serve HTTP connection: {err:#}");
110 0 : }
111 0 : }
112 : }
113 : };
114 0 : }));
115 0 : }
116 0 : Some(conn) = connections.next() => {
117 0 : if let Err(err) = conn {
118 0 : error!("Connection panicked: {err:#}");
119 0 : }
120 : }
121 0 : _ = cancel.cancelled() => {
122 : // Wait for graceful shutdown of all connections.
123 0 : while let Some(conn) = connections.next().await {
124 0 : if let Err(err) = conn {
125 0 : error!("Connection panicked: {err:#}");
126 0 : }
127 : }
128 0 : break;
129 0 : }
130 0 : }
131 0 : }
132 0 : Ok(())
133 0 : }
134 :
135 : /// Serves HTTP connection with graceful shutdown.
136 0 : async fn serve_connection<I>(
137 0 : io: I,
138 0 : service: RequestService<Body, ApiError>,
139 0 : cancel: CancellationToken,
140 0 : ) -> Result<(), hyper0::Error>
141 0 : where
142 0 : I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
143 0 : {
144 0 : let mut conn = Http::new().serve_connection(io, service).with_upgrades();
145 0 :
146 0 : tokio::select! {
147 0 : res = &mut conn => res,
148 0 : _ = cancel.cancelled() => {
149 0 : Pin::new(&mut conn).graceful_shutdown();
150 0 : // Note: connection should still be awaited for graceful shutdown to complete.
151 0 : conn.await
152 : }
153 : }
154 0 : }
155 : }
|