1use crate::StatusCode;
8use async_lock::Semaphore;
9use async_std::{
10 net::TcpListener,
11 sync::Arc,
12 task::{sleep, spawn},
13};
14use async_trait::async_trait;
15use derivative::Derivative;
16use futures::stream::StreamExt;
17use std::{
18 fmt::{self, Display, Formatter},
19 io::{self, ErrorKind},
20 net::SocketAddr,
21 time::Duration,
22};
23use tide::{
24 http,
25 listener::{ListenInfo, Listener, ToListener},
26 Server,
27};
28
/// A TCP listener for a tide [`Server`] that limits how many requests are handled at once.
///
/// Every incoming request must obtain a permit from a semaphore sized by the `limit` passed to
/// [`RateLimitListener::new`]. Requests that arrive while no permit is available are answered
/// with `429 Too Many Requests` instead of being passed to the application.
#[derive(Derivative)]
// NOTE(review): this bound presumably mirrors the bounds tide places on `Debug for Server<State>`,
// which the derived impl needs for the `server` field — confirm if tide's bounds change.
#[derivative(Debug(bound = "State: Send + Sync + 'static"))]
pub struct RateLimitListener<State> {
    /// Address the socket is bound to when `bind` is called.
    addr: SocketAddr,
    /// The bound socket; populated by `bind` and moved out by `accept`.
    listener: Option<TcpListener>,
    /// The application to serve; populated by `bind` and moved out by `accept`.
    server: Option<Server<State>>,
    /// Connection info reported by `info`; populated by `bind`.
    info: Option<ListenInfo>,
    /// Semaphore bounding the number of requests handled concurrently.
    permit: Arc<Semaphore>,
}
44
45impl<State> RateLimitListener<State> {
46 pub fn new(addr: SocketAddr, limit: usize) -> Self {
48 Self {
49 addr,
50 listener: None,
51 server: None,
52 info: None,
53 permit: Arc::new(Semaphore::new(limit)),
54 }
55 }
56
57 pub fn with_port(port: u16, limit: usize) -> Self {
59 Self::new(([0, 0, 0, 0], port).into(), limit)
60 }
61}
62
63#[async_trait]
64impl<State> Listener<State> for RateLimitListener<State>
65where
66 State: Clone + Send + Sync + 'static,
67{
68 async fn bind(&mut self, app: Server<State>) -> io::Result<()> {
69 if self.server.is_some() {
70 return Err(io::Error::new(
71 ErrorKind::AlreadyExists,
72 "`bind` should only be called once",
73 ));
74 }
75 self.server = Some(app);
76 self.listener = Some(TcpListener::bind(&[self.addr][..]).await?);
77
78 let conn_string = format!("{}", self);
80 let transport = "tcp".to_owned();
81 let tls = false;
82 self.info = Some(ListenInfo::new(conn_string, transport, tls));
83
84 Ok(())
85 }
86
87 async fn accept(&mut self) -> io::Result<()> {
88 let server = self.server.take().ok_or_else(|| {
89 io::Error::other("`Listener::bind` must be called before `Listener::accept`")
90 })?;
91 let listener = self.listener.take().ok_or_else(|| {
92 io::Error::other("`Listener::bind` must be called before `Listener::accept`")
93 })?;
94
95 let mut incoming = listener.incoming();
96 while let Some(stream) = incoming.next().await {
97 match stream {
98 Err(err) if is_transient_error(&err) => continue,
99 Err(err) => {
100 tracing::warn!(%err, "TCP error");
101 sleep(Duration::from_millis(500)).await;
102 continue;
103 }
104 Ok(stream) => {
105 let app = server.clone();
106 let permit = self.permit.clone();
107 spawn(async move {
108 let local_addr = stream.local_addr().ok();
109 let peer_addr = stream.peer_addr().ok();
110
111 let fut = async_h1::accept(stream, |mut req| async {
112 if let Some(_guard) = permit.try_acquire() {
114 req.set_local_addr(local_addr);
115 req.set_peer_addr(peer_addr);
116 app.respond(req).await
117 } else {
118 Ok(http::Response::new(StatusCode::TOO_MANY_REQUESTS))
121 }
122 });
123
124 if let Err(error) = fut.await {
125 tracing::error!(%error, "HTTP error");
126 }
127 });
128 }
129 };
130 }
131 Ok(())
132 }
133
134 fn info(&self) -> Vec<ListenInfo> {
135 match &self.info {
136 Some(info) => vec![info.clone()],
137 None => vec![],
138 }
139 }
140}
141
142impl<State> ToListener<State> for RateLimitListener<State>
143where
144 State: Clone + Send + Sync + 'static,
145{
146 type Listener = Self;
147
148 fn to_listener(self) -> io::Result<Self::Listener> {
149 Ok(self)
150 }
151}
152
153impl<State> Display for RateLimitListener<State> {
154 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
155 match &self.listener {
156 Some(listener) => {
157 let addr = listener.local_addr().expect("Could not get local addr");
158 write!(f, "http://{}", addr)
159 }
160 None => write!(f, "http://{}", self.addr),
161 }
162 }
163}
164
/// Whether an error from accepting a connection is transient.
///
/// Refused, aborted and reset connections are treated as transient: the accept loop skips them
/// without logging or backing off. Every other error kind returns `false`.
fn is_transient_error(e: &io::Error) -> bool {
    const TRANSIENT_KINDS: [ErrorKind; 3] = [
        ErrorKind::ConnectionRefused,
        ErrorKind::ConnectionAborted,
        ErrorKind::ConnectionReset,
    ];
    TRANSIENT_KINDS.contains(&e.kind())
}
171
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        error::ServerError,
        testing::{setup_test, Client},
        App,
    };
    use futures::future::{try_join_all, FutureExt};
    use portpicker::pick_unused_port;
    use toml::toml;
    use vbs::version::{StaticVersion, StaticVersionType};

    type StaticVer01 = StaticVersion<0, 1>;

    /// End-to-end check that `RateLimitListener` rejects requests beyond its limit.
    ///
    /// The test does four things:
    /// - starts a server whose only route sleeps for 30 seconds;
    /// - saturates it with `limit` concurrent requests;
    /// - checks that one extra request is rejected with `429 Too Many Requests`;
    /// - checks that the original `limit` requests all eventually succeed.
    #[async_std::test]
    async fn test_rate_limiting() {
        setup_test();

        let mut app = App::<_, ServerError>::with_state(());
        let api_toml = toml! {
            [route.test]
            PATH = ["/test"]
            METHOD = "GET"
        };
        {
            let mut api = app
                .module::<ServerError, StaticVer01>("mod", api_toml)
                .unwrap();
            api.get("test", |_req, _state| {
                async move {
                    // Keep the request (and therefore its permit) busy long enough for the rest
                    // of the test to observe the listener at capacity.
                    sleep(Duration::from_secs(30)).await;
                    Ok(())
                }
                .boxed()
            })
            .unwrap();
        }

        let limit = 10;
        let port = pick_unused_port().unwrap();
        spawn(app.serve(
            RateLimitListener::with_port(port, limit),
            StaticVer01::instance(),
        ));
        let client = Client::new(format!("http://localhost:{port}").parse().unwrap()).await;

        // Occupy every permit with a long-running request.
        let reqs = (0..limit)
            .map(|_| spawn(client.get("mod/test").send()))
            .collect::<Vec<_>>();

        // Give the spawned requests time to reach the server and take their permits. This is a
        // timing assumption, not a synchronization guarantee.
        sleep(Duration::from_secs(5)).await;

        // With every permit taken, one more request must be rejected.
        let res = client.get("mod/test").send().await.unwrap();
        assert_eq!(StatusCode::TOO_MANY_REQUESTS, res.status());

        // The requests that obtained permits are unaffected and complete successfully.
        for res in try_join_all(reqs).await.unwrap() {
            assert_eq!(StatusCode::OK, res.status());
        }
    }
}