1use crate::{
10 http::{content::Accept, mime},
11 method::Method,
12 request::{best_response_type, RequestError, RequestParams},
13 StatusCode,
14};
15use async_std::sync::Arc;
16use futures::{
17 future::BoxFuture,
18 select, sink,
19 stream::BoxStream,
20 task::{Context, Poll},
21 FutureExt, Sink, SinkExt, Stream, StreamExt, TryFutureExt,
22};
23use pin_project::pin_project;
24use serde::{de::DeserializeOwned, Serialize};
25use std::borrow::Cow;
26use std::fmt::{self, Display, Formatter};
27use std::marker::PhantomData;
28use std::pin::Pin;
29use tide_websockets::{
30 tungstenite::protocol::frame::{coding::CloseCode, CloseFrame},
31 Message, WebSocketConnection,
32};
33use vbs::{version::StaticVersionType, BinarySerializer, Serializer};
34
35#[derive(Debug)]
41pub enum SocketError<E> {
42 AppSpecific(E),
43 Request(RequestError),
44 Binary(anyhow::Error),
45 Json(serde_json::Error),
46 WebSockets(tide_websockets::Error),
47 UnsupportedMessageType,
48 Closed,
49 IncorrectMethod { expected: Method, actual: Method },
50}
51
52impl<E> SocketError<E> {
53 pub fn status(&self) -> StatusCode {
54 match self {
55 Self::Request(_) | Self::UnsupportedMessageType | Self::IncorrectMethod { .. } => {
56 StatusCode::BAD_REQUEST
57 }
58 _ => StatusCode::INTERNAL_SERVER_ERROR,
59 }
60 }
61
62 pub fn code(&self) -> CloseCode {
63 CloseCode::Error
64 }
65
66 pub fn map_app_specific<E2>(self, f: &impl Fn(E) -> E2) -> SocketError<E2> {
67 match self {
68 Self::AppSpecific(e) => SocketError::AppSpecific(f(e)),
69 Self::Request(e) => SocketError::Request(e),
70 Self::Binary(e) => SocketError::Binary(e),
71 Self::Json(e) => SocketError::Json(e),
72 Self::WebSockets(e) => SocketError::WebSockets(e),
73 Self::UnsupportedMessageType => SocketError::UnsupportedMessageType,
74 Self::Closed => SocketError::Closed,
75 Self::IncorrectMethod { expected, actual } => {
76 SocketError::IncorrectMethod { expected, actual }
77 }
78 }
79 }
80}
81
82impl<E: Display> Display for SocketError<E> {
83 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
84 match self {
85 Self::AppSpecific(e) => write!(f, "{}", e),
86 Self::Request(e) => write!(f, "{}", e),
87 Self::Binary(e) => write!(f, "error creating byte stream: {}", e),
88 Self::Json(e) => write!(f, "error creating JSON message: {}", e),
89 Self::WebSockets(e) => write!(f, "WebSockets protocol error: {}", e),
90 Self::UnsupportedMessageType => {
91 write!(f, "unsupported content type for WebSockets message")
92 }
93 Self::Closed => write!(f, "connection closed"),
94 Self::IncorrectMethod { expected, actual } => write!(
95 f,
96 "endpoint must be called as {}, but was called as {}",
97 expected, actual
98 ),
99 }
100 }
101}
102
103impl<E> From<RequestError> for SocketError<E> {
104 fn from(err: RequestError) -> Self {
105 Self::Request(err)
106 }
107}
108
109impl<E> From<anyhow::Error> for SocketError<E> {
110 fn from(err: anyhow::Error) -> Self {
111 Self::Binary(err)
112 }
113}
114
115impl<E> From<serde_json::Error> for SocketError<E> {
116 fn from(err: serde_json::Error) -> Self {
117 Self::Json(err)
118 }
119}
120
121impl<E> From<tide_websockets::Error> for SocketError<E> {
122 fn from(err: tide_websockets::Error) -> Self {
123 Self::WebSockets(err)
124 }
125}
126
/// Encoding used for messages sent to the client.
#[derive(Debug, Clone, Copy)]
enum MessageType {
    /// Outgoing messages are serialized with the binary serializer and sent as binary frames.
    Binary,
    /// Outgoing messages are serialized as JSON and sent as text frames.
    Json,
}
132
133#[pin_project]
138pub struct Connection<ToClient: ?Sized, FromClient, Error, VER: StaticVersionType> {
139 #[pin]
140 conn: WebSocketConnection,
141 sink: Pin<Box<dyn Send + Sink<Message, Error = SocketError<Error>>>>,
143 accept: MessageType,
144 #[allow(clippy::type_complexity)]
145 _phantom: PhantomData<fn(&ToClient, &FromClient, &Error, &VER) -> ()>,
146}
147
148impl<ToClient: ?Sized, FromClient: DeserializeOwned, E, VER: StaticVersionType> Stream
149 for Connection<ToClient, FromClient, E, VER>
150{
151 type Item = Result<FromClient, SocketError<E>>;
152
153 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
154 match self.project().conn.poll_next(cx) {
157 Poll::Ready(None) => Poll::Ready(None),
158 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))),
159 Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(match msg {
160 Message::Binary(bytes) => {
161 Serializer::<VER>::deserialize(&bytes).map_err(SocketError::from)
162 }
163 Message::Text(s) => serde_json::from_str(&s).map_err(SocketError::from),
164 _ => Err(SocketError::UnsupportedMessageType),
165 })),
166 Poll::Pending => Poll::Pending,
167 }
168 }
169}
170
171impl<ToClient: Serialize + ?Sized, FromClient, E, VER: StaticVersionType> Sink<&ToClient>
172 for Connection<ToClient, FromClient, E, VER>
173{
174 type Error = SocketError<E>;
175
176 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177 self.sink.as_mut().poll_ready(cx).map_err(SocketError::from)
178 }
179
180 fn start_send(mut self: Pin<&mut Self>, item: &ToClient) -> Result<(), Self::Error> {
181 let msg = match self.accept {
182 MessageType::Binary => Message::Binary(Serializer::<VER>::serialize(item)?),
183 MessageType::Json => Message::Text(serde_json::to_string(item)?),
184 };
185 self.sink.as_mut().start_send(msg)
186 }
187
188 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189 self.sink.as_mut().poll_flush(cx).map_err(SocketError::from)
190 }
191
192 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
193 self.sink.as_mut().poll_close(cx).map_err(SocketError::from)
194 }
195}
196
197impl<ToClient: Serialize, FromClient, E, VER: StaticVersionType> Sink<ToClient>
198 for Connection<ToClient, FromClient, E, VER>
199{
200 type Error = SocketError<E>;
201
202 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
203 Sink::<&ToClient>::poll_ready(self, cx)
204 }
205
206 fn start_send(self: Pin<&mut Self>, item: ToClient) -> Result<(), Self::Error> {
207 self.start_send(&item)
208 }
209
210 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
211 Sink::<&ToClient>::poll_flush(self, cx)
212 }
213
214 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
215 Sink::<&ToClient>::poll_close(self, cx)
216 }
217}
218
219impl<ToClient: ?Sized, FromClient, E, VER: StaticVersionType>
220 Connection<ToClient, FromClient, E, VER>
221{
222 #[allow(clippy::result_large_err)]
223 fn new(accept: &Accept, conn: WebSocketConnection) -> Result<Self, SocketError<E>> {
224 let ty = best_response_type(accept, &[mime::JSON, mime::BYTE_STREAM])?;
225 let ty = if ty == mime::JSON {
226 MessageType::Json
227 } else if ty == mime::BYTE_STREAM {
228 MessageType::Binary
229 } else {
230 unreachable!()
231 };
232 Ok(Self {
233 sink: Self::sink(conn.clone()),
234 conn,
235 accept: ty,
236 _phantom: Default::default(),
237 })
238 }
239
240 fn sink(
242 conn: WebSocketConnection,
243 ) -> Pin<Box<dyn Send + Sink<Message, Error = SocketError<E>>>> {
244 Box::pin(sink::unfold(conn, |conn, msg| async move {
245 conn.send(msg).await?;
246 Ok(conn)
247 }))
248 }
249}
250
251impl<ToClient: ?Sized, FromClient, E, VER: StaticVersionType> Clone
252 for Connection<ToClient, FromClient, E, VER>
253{
254 fn clone(&self) -> Self {
255 Self {
256 sink: Self::sink(self.conn.clone()),
257 conn: self.conn.clone(),
258 accept: self.accept,
259 _phantom: Default::default(),
260 }
261 }
262}
263
264pub(crate) type Handler<State, Error> = Box<
265 dyn 'static
266 + Send
267 + Sync
268 + Fn(RequestParams, WebSocketConnection, &State) -> BoxFuture<Result<(), SocketError<Error>>>,
269>;
270
271pub(crate) fn handler<State, Error, ToClient, FromClient, F, VER: StaticVersionType>(
272 f: F,
273) -> Handler<State, Error>
274where
275 F: 'static
276 + Send
277 + Sync
278 + Fn(
279 RequestParams,
280 Connection<ToClient, FromClient, Error, VER>,
281 &State,
282 ) -> BoxFuture<Result<(), Error>>,
283 State: 'static + Send + Sync,
284 ToClient: 'static + Serialize + ?Sized,
285 FromClient: 'static + DeserializeOwned,
286 Error: 'static + Send + Display,
287{
288 raw_handler(move |req, conn, state| {
289 f(req, conn, state)
290 .map_err(SocketError::AppSpecific)
291 .boxed()
292 })
293}
294
295struct StreamHandler<F, VER: StaticVersionType>(F, PhantomData<VER>);
296
297impl<F, VER: StaticVersionType> StreamHandler<F, VER> {
298 fn handle<'a, State, Error, Msg>(
299 &self,
300 req: RequestParams,
301 conn: Connection<Msg, (), Error, VER>,
302 state: &'a State,
303 ) -> BoxFuture<'a, Result<(), SocketError<Error>>>
304 where
305 F: 'static + Send + Sync + Fn(RequestParams, &State) -> BoxStream<Result<Msg, Error>>,
306 State: 'static + Send + Sync,
307 Msg: 'static + Serialize + Send + Sync,
308 Error: 'static + Send,
309 VER: 'static + Send + Sync,
310 {
311 let mut stream = (self.0)(req, state).fuse();
312 async move {
313 let (mut send, mut recv) = (conn.clone(), conn);
315
316 let mut item_fut = stream.next();
318 let mut client_fut = recv.next().fuse();
319
320 loop {
321 select! {
322 item = item_fut => {
323 match item {
324 Some(msg) => {
325 send.send(&msg.map_err(SocketError::AppSpecific)?).await?;
326 item_fut = stream.next();
327 }
328 None => {
329 break;
330 }
331 }
332 }
333 client_msg = client_fut => {
336 client_fut = recv.next().fuse();
337 match client_msg {
338 None => return Ok(()),
339 Some(Err(e)) => return Err(e),
340 _ => {}
341 }
342 }
343 };
344 }
345 Ok(())
346 }
347 .boxed()
348 }
349}
350
351pub(crate) fn stream_handler<State, Error, Msg, F, VER>(f: F) -> Handler<State, Error>
352where
353 F: 'static + Send + Sync + Fn(RequestParams, &State) -> BoxStream<Result<Msg, Error>>,
354 State: 'static + Send + Sync,
355 Msg: 'static + Serialize + Send + Sync,
356 Error: 'static + Send + Display,
357 VER: 'static + Send + Sync + StaticVersionType,
358{
359 let handler: StreamHandler<F, VER> = StreamHandler(f, Default::default());
360 raw_handler(move |req, conn, state| handler.handle(req, conn, state))
361}
362
363fn raw_handler<State, Error, ToClient, FromClient, F, VER>(f: F) -> Handler<State, Error>
364where
365 F: 'static
366 + Send
367 + Sync
368 + Fn(
369 RequestParams,
370 Connection<ToClient, FromClient, Error, VER>,
371 &State,
372 ) -> BoxFuture<Result<(), SocketError<Error>>>,
373 State: 'static + Send + Sync,
374 ToClient: 'static + Serialize + ?Sized,
375 FromClient: 'static + DeserializeOwned,
376 Error: 'static + Send + Display,
377 VER: StaticVersionType,
378{
379 let close = |conn: WebSocketConnection, res: Result<(), SocketError<Error>>| async move {
380 let msg = res.as_ref().err().map(|err| CloseFrame {
383 code: err.code(),
384 reason: Cow::Owned(err.to_string()),
385 });
386 conn.send(Message::Close(msg)).await?;
387 res
388 };
389 Box::new(move |req, raw_conn, state| {
390 let accept = match req.accept() {
391 Ok(accept) => accept,
392 Err(err) => return close(raw_conn, Err(err.into())).boxed(),
393 };
394 let conn = match Connection::new(&accept, raw_conn.clone()) {
395 Ok(conn) => conn,
396 Err(err) => return close(raw_conn, Err(err)).boxed(),
397 };
398 f(req, conn, state)
399 .then(move |res| close(raw_conn, res))
400 .boxed()
401 })
402}
403
404struct MapErr<State, Error, F> {
405 handler: Handler<State, Error>,
406 map: Arc<F>,
407}
408
409impl<State, Error, F> MapErr<State, Error, F> {
410 fn handle<'a, Error2>(
411 &self,
412 req: RequestParams,
413 conn: WebSocketConnection,
414 state: &'a State,
415 ) -> BoxFuture<'a, Result<(), SocketError<Error2>>>
416 where
417 F: 'static + Send + Sync + Fn(Error) -> Error2,
418 State: 'static + Send + Sync,
419 Error: 'static,
420 {
421 let map = self.map.clone();
422 let fut = (self.handler)(req, conn, state);
423 async move { fut.await.map_err(|err| err.map_app_specific(&*map)) }.boxed()
424 }
425}
426
427pub(crate) fn map_err<State, Error, Error2>(
428 h: Handler<State, Error>,
429 f: impl 'static + Send + Sync + Fn(Error) -> Error2,
430) -> Handler<State, Error2>
431where
432 State: 'static + Send + Sync,
433 Error: 'static,
434{
435 let handler = MapErr {
436 handler: h,
437 map: Arc::new(f),
438 };
439 Box::new(move |req, conn, state| handler.handle(req, conn, state))
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::ServerError, testing::test_ws_client, Api, App, Url};
    use async_std::task::{sleep, spawn};
    use async_tungstenite::tungstenite::Message as TungsteniteMessage;
    use futures::{stream, StreamExt};
    use pin_project::pinned_drop;
    use portpicker::pick_unused_port;
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        time::Duration,
    };
    use vbs::version::StaticVersion;

    type StaticVer01 = StaticVersion<0, 1>;

    /// Stream wrapper that sets a shared flag when it is dropped.
    #[pin_project(PinnedDrop)]
    struct DropStream<S: Stream> {
        #[pin]
        stream: S,
        dropped: Arc<AtomicBool>,
    }

    impl<S: Stream> Stream for DropStream<S> {
        type Item = S::Item;

        /// Delegate straight to the wrapped stream.
        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            self.project().stream.poll_next(cx)
        }
    }

    #[pinned_drop]
    impl<S: Stream> PinnedDrop for DropStream<S> {
        fn drop(self: Pin<&mut Self>) {
            self.dropped.store(true, Ordering::SeqCst);
        }
    }

    /// After the client closes the socket, the handler's stream must be dropped even though
    /// the stream itself never ends.
    #[async_std::test]
    async fn test_stream_handler_client_closure() {
        let port = pick_unused_port().expect("No ports available");

        let mut app = App::<(), ServerError>::with_state(());
        let api_spec = r#"
            [meta]
            FORMAT_VERSION = "0.1.0"

            [route.stream_test]
            PATH = ["/stream"]
            METHOD = "SOCKET"
        "#;

        let mut api =
            Api::<(), ServerError, StaticVer01>::new(api_spec.parse::<toml::Value>().unwrap())
                .unwrap();

        // `dropped` is checked at the end of the test; the handler gets its own handle.
        let dropped = Arc::new(AtomicBool::new(false));
        let handler_flag = dropped.clone();
        api.stream("stream_test", move |_req, _state| {
            // An endless stream of consecutive integers.
            let numbers = stream::iter(0..).map(Result::Ok);
            Box::pin(DropStream {
                stream: numbers,
                dropped: handler_flag.clone(),
            })
        })
        .unwrap();

        app.register_module("test", api).unwrap();

        // Serve in the background.
        spawn(async move {
            app.serve(format!("127.0.0.1:{}", port), StaticVer01::instance())
                .await
                .unwrap();
        });

        // Give the server time to start.
        sleep(Duration::from_millis(500)).await;

        let url = Url::parse(&format!("http://127.0.0.1:{}/test/stream", port)).unwrap();
        let mut client = test_ws_client(url).await;

        // Read up to five frames; every text frame must carry the next expected integer.
        let mut received_count = 0;
        for _attempt in 0..5 {
            let frame = client.next().await;
            if let Some(Ok(TungsteniteMessage::Text(body))) = frame {
                let value: usize = serde_json::from_str(&body).unwrap();
                assert_eq!(value, received_count);
                received_count += 1;
            }
        }

        client
            .close(None)
            .await
            .expect("Failed to close connection");

        // Give the server time to notice the closure.
        sleep(Duration::from_millis(300)).await;

        assert!(dropped.load(Ordering::SeqCst));
    }
}