tide_disco/
listener.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the tide-disco library.
3
4// You should have received a copy of the MIT License
5// along with the tide-disco library. If not, see <https://mit-license.org/>.
6
7use crate::StatusCode;
8use async_lock::Semaphore;
9use async_std::{
10    net::TcpListener,
11    sync::Arc,
12    task::{sleep, spawn},
13};
14use async_trait::async_trait;
15use derivative::Derivative;
16use futures::stream::StreamExt;
17use std::{
18    fmt::{self, Display, Formatter},
19    io::{self, ErrorKind},
20    net::SocketAddr,
21    time::Duration,
22};
23use tide::{
24    http,
25    listener::{ListenInfo, Listener, ToListener},
26    Server,
27};
28
29/// TCP listener which accepts only a limited number of connections at a time.
30///
31/// This listener is based on `tide::listener::TcpListener` and should match the semantics of that
32/// listener in every way, accept that when there are more simultaneous outstanding requests than
33/// the configured limit, excess requests will fail immediately with error code 429 (Too Many
34/// Requests).
35#[derive(Derivative)]
36#[derivative(Debug(bound = "State: Send + Sync + 'static"))]
37pub struct RateLimitListener<State> {
38    addr: SocketAddr,
39    listener: Option<TcpListener>,
40    server: Option<Server<State>>,
41    info: Option<ListenInfo>,
42    permit: Arc<Semaphore>,
43}
44
45impl<State> RateLimitListener<State> {
46    /// Listen at the given address.
47    pub fn new(addr: SocketAddr, limit: usize) -> Self {
48        Self {
49            addr,
50            listener: None,
51            server: None,
52            info: None,
53            permit: Arc::new(Semaphore::new(limit)),
54        }
55    }
56
57    /// Listen at the given port on all interfaces.
58    pub fn with_port(port: u16, limit: usize) -> Self {
59        Self::new(([0, 0, 0, 0], port).into(), limit)
60    }
61}
62
63#[async_trait]
64impl<State> Listener<State> for RateLimitListener<State>
65where
66    State: Clone + Send + Sync + 'static,
67{
68    async fn bind(&mut self, app: Server<State>) -> io::Result<()> {
69        if self.server.is_some() {
70            return Err(io::Error::new(
71                ErrorKind::AlreadyExists,
72                "`bind` should only be called once",
73            ));
74        }
75        self.server = Some(app);
76        self.listener = Some(TcpListener::bind(&[self.addr][..]).await?);
77
78        // Format the listen information.
79        let conn_string = format!("{}", self);
80        let transport = "tcp".to_owned();
81        let tls = false;
82        self.info = Some(ListenInfo::new(conn_string, transport, tls));
83
84        Ok(())
85    }
86
87    async fn accept(&mut self) -> io::Result<()> {
88        let server = self.server.take().ok_or_else(|| {
89            io::Error::other("`Listener::bind` must be called before `Listener::accept`")
90        })?;
91        let listener = self.listener.take().ok_or_else(|| {
92            io::Error::other("`Listener::bind` must be called before `Listener::accept`")
93        })?;
94
95        let mut incoming = listener.incoming();
96        while let Some(stream) = incoming.next().await {
97            match stream {
98                Err(err) if is_transient_error(&err) => continue,
99                Err(err) => {
100                    tracing::warn!(%err, "TCP error");
101                    sleep(Duration::from_millis(500)).await;
102                    continue;
103                }
104                Ok(stream) => {
105                    let app = server.clone();
106                    let permit = self.permit.clone();
107                    spawn(async move {
108                        let local_addr = stream.local_addr().ok();
109                        let peer_addr = stream.peer_addr().ok();
110
111                        let fut = async_h1::accept(stream, |mut req| async {
112                            // Handle the request if we can get a permit.
113                            if let Some(_guard) = permit.try_acquire() {
114                                req.set_local_addr(local_addr);
115                                req.set_peer_addr(peer_addr);
116                                app.respond(req).await
117                            } else {
118                                // Otherwise, we are rate limited. Respond immediately with an
119                                // error.
120                                Ok(http::Response::new(StatusCode::TOO_MANY_REQUESTS))
121                            }
122                        });
123
124                        if let Err(error) = fut.await {
125                            tracing::error!(%error, "HTTP error");
126                        }
127                    });
128                }
129            };
130        }
131        Ok(())
132    }
133
134    fn info(&self) -> Vec<ListenInfo> {
135        match &self.info {
136            Some(info) => vec![info.clone()],
137            None => vec![],
138        }
139    }
140}
141
142impl<State> ToListener<State> for RateLimitListener<State>
143where
144    State: Clone + Send + Sync + 'static,
145{
146    type Listener = Self;
147
148    fn to_listener(self) -> io::Result<Self::Listener> {
149        Ok(self)
150    }
151}
152
153impl<State> Display for RateLimitListener<State> {
154    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
155        match &self.listener {
156            Some(listener) => {
157                let addr = listener.local_addr().expect("Could not get local addr");
158                write!(f, "http://{}", addr)
159            }
160            None => write!(f, "http://{}", self.addr),
161        }
162    }
163}
164
/// Returns `true` for errors that concern only a single incoming connection: refused,
/// aborted, or reset.
///
/// The accept loop skips such errors immediately instead of logging them and pausing.
fn is_transient_error(e: &io::Error) -> bool {
    let kind = e.kind();
    kind == ErrorKind::ConnectionRefused
        || kind == ErrorKind::ConnectionAborted
        || kind == ErrorKind::ConnectionReset
}
171
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        error::ServerError,
        testing::{setup_test, Client},
        App,
    };
    use futures::future::{try_join_all, FutureExt};
    use portpicker::pick_unused_port;
    use toml::toml;
    use vbs::version::{StaticVersion, StaticVersionType};

    type StaticVer01 = StaticVersion<0, 1>;

    /// Maximum number of requests the listener under test may handle concurrently.
    const LIMIT: usize = 10;

    /// Occupy every permit with slow requests, check that one extra request is rejected with
    /// 429, then check that the requests holding permits still succeed.
    #[async_std::test]
    async fn test_rate_limiting() {
        setup_test();

        // An app with a single GET endpoint.
        let mut app = App::<_, ServerError>::with_state(());
        let api_toml = toml! {
            [route.test]
            PATH = ["/test"]
            METHOD = "GET"
        };
        {
            let mut api = app
                .module::<ServerError, StaticVer01>("mod", api_toml)
                .unwrap();
            // The handler is deliberately slow so that every request keeps its permit long
            // enough to overlap with all the others.
            api.get("test", |_req, _state| {
                async move {
                    sleep(Duration::from_secs(30)).await;
                    Ok(())
                }
                .boxed()
            })
            .unwrap();
        }

        let port = pick_unused_port().unwrap();
        spawn(app.serve(
            RateLimitListener::with_port(port, LIMIT),
            StaticVer01::instance(),
        ));
        let client = Client::new(format!("http://localhost:{port}").parse().unwrap()).await;

        // Use up every permit with a request that will not finish for a long time.
        let in_flight = (0..LIMIT)
            .map(|_| spawn(client.get("mod/test").send()))
            .collect::<Vec<_>>();

        // Give the server time to accept all of them before sending the extra request.
        sleep(Duration::from_secs(5)).await;

        // With no permits left, the extra request is rejected immediately.
        let rejected = client.get("mod/test").send().await.unwrap();
        assert_eq!(StatusCode::TOO_MANY_REQUESTS, rejected.status());

        // The requests holding permits are unaffected and succeed once the handler returns.
        let completed = try_join_all(in_flight).await.unwrap();
        for response in completed {
            assert_eq!(StatusCode::OK, response.status());
        }
    }
}