Line data Source code
1 : use std::future::Future;
2 : use std::io;
3 : use std::net::{IpAddr, SocketAddr};
4 : use std::time::Duration;
5 :
6 : use tokio::net::{self, TcpStream};
7 : use tokio::time;
8 :
9 : use crate::Error;
10 : use crate::config::Host;
11 :
12 0 : pub(crate) async fn connect_socket(
13 0 : host_addr: Option<IpAddr>,
14 0 : host: &Host,
15 0 : port: u16,
16 0 : connect_timeout: Option<Duration>,
17 0 : ) -> Result<TcpStream, Error> {
18 0 : match host {
19 0 : Host::Tcp(host) => {
20 0 : let addrs = match host_addr {
21 0 : Some(addr) => vec![SocketAddr::new(addr, port)],
22 0 : None => net::lookup_host((&**host, port))
23 0 : .await
24 0 : .map_err(Error::connect)?
25 0 : .collect(),
26 : };
27 :
28 0 : let mut last_err = None;
29 :
30 0 : for addr in addrs {
31 0 : let stream =
32 0 : match connect_with_timeout(TcpStream::connect(addr), connect_timeout).await {
33 0 : Ok(stream) => stream,
34 0 : Err(e) => {
35 0 : last_err = Some(e);
36 0 : continue;
37 : }
38 : };
39 :
40 0 : stream.set_nodelay(true).map_err(Error::connect)?;
41 :
42 0 : return Ok(stream);
43 : }
44 :
45 0 : Err(last_err.unwrap_or_else(|| {
46 0 : Error::connect(io::Error::new(
47 0 : io::ErrorKind::InvalidInput,
48 0 : "could not resolve any addresses",
49 0 : ))
50 0 : }))
51 : }
52 : }
53 0 : }
54 :
55 0 : async fn connect_with_timeout<F, T>(connect: F, timeout: Option<Duration>) -> Result<T, Error>
56 0 : where
57 0 : F: Future<Output = io::Result<T>>,
58 0 : {
59 0 : match timeout {
60 0 : Some(timeout) => match time::timeout(timeout, connect).await {
61 0 : Ok(Ok(socket)) => Ok(socket),
62 0 : Ok(Err(e)) => Err(Error::connect(e)),
63 0 : Err(_) => Err(Error::connect(io::Error::new(
64 0 : io::ErrorKind::TimedOut,
65 0 : "connection timed out",
66 0 : ))),
67 : },
68 0 : None => match connect.await {
69 0 : Ok(socket) => Ok(socket),
70 0 : Err(e) => Err(Error::connect(e)),
71 : },
72 : }
73 0 : }
|