1use crate::{
10 http::{content::Accept, mime},
11 method::Method,
12 request::{best_response_type, RequestError, RequestParams},
13 StatusCode,
14};
15use async_std::sync::Arc;
16use futures::{
17 future::BoxFuture,
18 sink,
19 stream::BoxStream,
20 task::{Context, Poll},
21 FutureExt, Sink, SinkExt, Stream, StreamExt, TryFutureExt,
22};
23use pin_project::pin_project;
24use serde::{de::DeserializeOwned, Serialize};
25use std::borrow::Cow;
26use std::fmt::{self, Display, Formatter};
27use std::marker::PhantomData;
28use std::pin::Pin;
29use tide_websockets::{
30 tungstenite::protocol::frame::{coding::CloseCode, CloseFrame},
31 Message, WebSocketConnection,
32};
33use vbs::{version::StaticVersionType, BinarySerializer, Serializer};
34
35#[derive(Debug)]
41pub enum SocketError<E> {
42 AppSpecific(E),
43 Request(RequestError),
44 Binary(anyhow::Error),
45 Json(serde_json::Error),
46 WebSockets(tide_websockets::Error),
47 UnsupportedMessageType,
48 Closed,
49 IncorrectMethod { expected: Method, actual: Method },
50}
51
52impl<E> SocketError<E> {
53 pub fn status(&self) -> StatusCode {
54 match self {
55 Self::Request(_) | Self::UnsupportedMessageType | Self::IncorrectMethod { .. } => {
56 StatusCode::BAD_REQUEST
57 }
58 _ => StatusCode::INTERNAL_SERVER_ERROR,
59 }
60 }
61
62 pub fn code(&self) -> CloseCode {
63 CloseCode::Error
64 }
65
66 pub fn map_app_specific<E2>(self, f: &impl Fn(E) -> E2) -> SocketError<E2> {
67 match self {
68 Self::AppSpecific(e) => SocketError::AppSpecific(f(e)),
69 Self::Request(e) => SocketError::Request(e),
70 Self::Binary(e) => SocketError::Binary(e),
71 Self::Json(e) => SocketError::Json(e),
72 Self::WebSockets(e) => SocketError::WebSockets(e),
73 Self::UnsupportedMessageType => SocketError::UnsupportedMessageType,
74 Self::Closed => SocketError::Closed,
75 Self::IncorrectMethod { expected, actual } => {
76 SocketError::IncorrectMethod { expected, actual }
77 }
78 }
79 }
80}
81
82impl<E: Display> Display for SocketError<E> {
83 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
84 match self {
85 Self::AppSpecific(e) => write!(f, "{}", e),
86 Self::Request(e) => write!(f, "{}", e),
87 Self::Binary(e) => write!(f, "error creating byte stream: {}", e),
88 Self::Json(e) => write!(f, "error creating JSON message: {}", e),
89 Self::WebSockets(e) => write!(f, "WebSockets protocol error: {}", e),
90 Self::UnsupportedMessageType => {
91 write!(f, "unsupported content type for WebSockets message")
92 }
93 Self::Closed => write!(f, "connection closed"),
94 Self::IncorrectMethod { expected, actual } => write!(
95 f,
96 "endpoint must be called as {}, but was called as {}",
97 expected, actual
98 ),
99 }
100 }
101}
102
103impl<E> From<RequestError> for SocketError<E> {
104 fn from(err: RequestError) -> Self {
105 Self::Request(err)
106 }
107}
108
109impl<E> From<anyhow::Error> for SocketError<E> {
110 fn from(err: anyhow::Error) -> Self {
111 Self::Binary(err)
112 }
113}
114
115impl<E> From<serde_json::Error> for SocketError<E> {
116 fn from(err: serde_json::Error) -> Self {
117 Self::Json(err)
118 }
119}
120
121impl<E> From<tide_websockets::Error> for SocketError<E> {
122 fn from(err: tide_websockets::Error) -> Self {
123 Self::WebSockets(err)
124 }
125}
126
/// Encoding used for messages sent to the client, chosen from its `Accept` header.
#[derive(Clone, Copy, Debug)]
enum MessageType {
    /// Binary WebSockets messages produced by the versioned binary serializer.
    Binary,
    /// Text WebSockets messages containing JSON.
    Json,
}
132
133#[pin_project]
138pub struct Connection<ToClient: ?Sized, FromClient, Error, VER: StaticVersionType> {
139 #[pin]
140 conn: WebSocketConnection,
141 sink: Pin<Box<dyn Send + Sink<Message, Error = SocketError<Error>>>>,
143 accept: MessageType,
144 #[allow(clippy::type_complexity)]
145 _phantom: PhantomData<fn(&ToClient, &FromClient, &Error, &VER) -> ()>,
146}
147
148impl<ToClient: ?Sized, FromClient: DeserializeOwned, E, VER: StaticVersionType> Stream
149 for Connection<ToClient, FromClient, E, VER>
150{
151 type Item = Result<FromClient, SocketError<E>>;
152
153 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
154 match self.project().conn.poll_next(cx) {
157 Poll::Ready(None) => Poll::Ready(None),
158 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))),
159 Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(match msg {
160 Message::Binary(bytes) => {
161 Serializer::<VER>::deserialize(&bytes).map_err(SocketError::from)
162 }
163 Message::Text(s) => serde_json::from_str(&s).map_err(SocketError::from),
164 _ => Err(SocketError::UnsupportedMessageType),
165 })),
166 Poll::Pending => Poll::Pending,
167 }
168 }
169}
170
171impl<ToClient: Serialize + ?Sized, FromClient, E, VER: StaticVersionType> Sink<&ToClient>
172 for Connection<ToClient, FromClient, E, VER>
173{
174 type Error = SocketError<E>;
175
176 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177 self.sink.as_mut().poll_ready(cx).map_err(SocketError::from)
178 }
179
180 fn start_send(mut self: Pin<&mut Self>, item: &ToClient) -> Result<(), Self::Error> {
181 let msg = match self.accept {
182 MessageType::Binary => Message::Binary(Serializer::<VER>::serialize(item)?),
183 MessageType::Json => Message::Text(serde_json::to_string(item)?),
184 };
185 self.sink.as_mut().start_send(msg)
186 }
187
188 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189 self.sink.as_mut().poll_flush(cx).map_err(SocketError::from)
190 }
191
192 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
193 self.sink.as_mut().poll_close(cx).map_err(SocketError::from)
194 }
195}
196
197impl<ToClient: Serialize, FromClient, E, VER: StaticVersionType> Sink<ToClient>
198 for Connection<ToClient, FromClient, E, VER>
199{
200 type Error = SocketError<E>;
201
202 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
203 Sink::<&ToClient>::poll_ready(self, cx)
204 }
205
206 fn start_send(self: Pin<&mut Self>, item: ToClient) -> Result<(), Self::Error> {
207 self.start_send(&item)
208 }
209
210 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
211 Sink::<&ToClient>::poll_flush(self, cx)
212 }
213
214 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
215 Sink::<&ToClient>::poll_close(self, cx)
216 }
217}
218
219impl<ToClient: ?Sized, FromClient, E, VER: StaticVersionType>
220 Connection<ToClient, FromClient, E, VER>
221{
222 #[allow(clippy::result_large_err)]
223 fn new(accept: &Accept, conn: WebSocketConnection) -> Result<Self, SocketError<E>> {
224 let ty = best_response_type(accept, &[mime::JSON, mime::BYTE_STREAM])?;
225 let ty = if ty == mime::JSON {
226 MessageType::Json
227 } else if ty == mime::BYTE_STREAM {
228 MessageType::Binary
229 } else {
230 unreachable!()
231 };
232 Ok(Self {
233 sink: Self::sink(conn.clone()),
234 conn,
235 accept: ty,
236 _phantom: Default::default(),
237 })
238 }
239
240 fn sink(
242 conn: WebSocketConnection,
243 ) -> Pin<Box<dyn Send + Sink<Message, Error = SocketError<E>>>> {
244 Box::pin(sink::unfold(conn, |conn, msg| async move {
245 conn.send(msg).await?;
246 Ok(conn)
247 }))
248 }
249}
250
251impl<ToClient: ?Sized, FromClient, E, VER: StaticVersionType> Clone
252 for Connection<ToClient, FromClient, E, VER>
253{
254 fn clone(&self) -> Self {
255 Self {
256 sink: Self::sink(self.conn.clone()),
257 conn: self.conn.clone(),
258 accept: self.accept,
259 _phantom: Default::default(),
260 }
261 }
262}
263
264pub(crate) type Handler<State, Error> = Box<
265 dyn 'static
266 + Send
267 + Sync
268 + Fn(RequestParams, WebSocketConnection, &State) -> BoxFuture<Result<(), SocketError<Error>>>,
269>;
270
271pub(crate) fn handler<State, Error, ToClient, FromClient, F, VER: StaticVersionType>(
272 f: F,
273) -> Handler<State, Error>
274where
275 F: 'static
276 + Send
277 + Sync
278 + Fn(
279 RequestParams,
280 Connection<ToClient, FromClient, Error, VER>,
281 &State,
282 ) -> BoxFuture<Result<(), Error>>,
283 State: 'static + Send + Sync,
284 ToClient: 'static + Serialize + ?Sized,
285 FromClient: 'static + DeserializeOwned,
286 Error: 'static + Send + Display,
287{
288 raw_handler(move |req, conn, state| {
289 f(req, conn, state)
290 .map_err(SocketError::AppSpecific)
291 .boxed()
292 })
293}
294
295struct StreamHandler<F, VER: StaticVersionType>(F, PhantomData<VER>);
296
297impl<F, VER: StaticVersionType> StreamHandler<F, VER> {
298 fn handle<'a, State, Error, Msg>(
299 &self,
300 req: RequestParams,
301 mut conn: Connection<Msg, (), Error, VER>,
302 state: &'a State,
303 ) -> BoxFuture<'a, Result<(), SocketError<Error>>>
304 where
305 F: 'static + Send + Sync + Fn(RequestParams, &State) -> BoxStream<Result<Msg, Error>>,
306 State: 'static + Send + Sync,
307 Msg: 'static + Serialize + Send + Sync,
308 Error: 'static + Send,
309 VER: 'static + Send + Sync,
310 {
311 let mut stream = (self.0)(req, state);
312 async move {
313 while let Some(msg) = stream.next().await {
314 conn.send(&msg.map_err(SocketError::AppSpecific)?).await?;
315 }
316 Ok(())
317 }
318 .boxed()
319 }
320}
321
322pub(crate) fn stream_handler<State, Error, Msg, F, VER>(f: F) -> Handler<State, Error>
323where
324 F: 'static + Send + Sync + Fn(RequestParams, &State) -> BoxStream<Result<Msg, Error>>,
325 State: 'static + Send + Sync,
326 Msg: 'static + Serialize + Send + Sync,
327 Error: 'static + Send + Display,
328 VER: 'static + Send + Sync + StaticVersionType,
329{
330 let handler: StreamHandler<F, VER> = StreamHandler(f, Default::default());
331 raw_handler(move |req, conn, state| handler.handle(req, conn, state))
332}
333
334fn raw_handler<State, Error, ToClient, FromClient, F, VER>(f: F) -> Handler<State, Error>
335where
336 F: 'static
337 + Send
338 + Sync
339 + Fn(
340 RequestParams,
341 Connection<ToClient, FromClient, Error, VER>,
342 &State,
343 ) -> BoxFuture<Result<(), SocketError<Error>>>,
344 State: 'static + Send + Sync,
345 ToClient: 'static + Serialize + ?Sized,
346 FromClient: 'static + DeserializeOwned,
347 Error: 'static + Send + Display,
348 VER: StaticVersionType,
349{
350 let close = |conn: WebSocketConnection, res: Result<(), SocketError<Error>>| async move {
351 let msg = res.as_ref().err().map(|err| CloseFrame {
354 code: err.code(),
355 reason: Cow::Owned(err.to_string()),
356 });
357 conn.send(Message::Close(msg)).await?;
358 res
359 };
360 Box::new(move |req, raw_conn, state| {
361 let accept = match req.accept() {
362 Ok(accept) => accept,
363 Err(err) => return close(raw_conn, Err(err.into())).boxed(),
364 };
365 let conn = match Connection::new(&accept, raw_conn.clone()) {
366 Ok(conn) => conn,
367 Err(err) => return close(raw_conn, Err(err)).boxed(),
368 };
369 f(req, conn, state)
370 .then(move |res| close(raw_conn, res))
371 .boxed()
372 })
373}
374
375struct MapErr<State, Error, F> {
376 handler: Handler<State, Error>,
377 map: Arc<F>,
378}
379
380impl<State, Error, F> MapErr<State, Error, F> {
381 fn handle<'a, Error2>(
382 &self,
383 req: RequestParams,
384 conn: WebSocketConnection,
385 state: &'a State,
386 ) -> BoxFuture<'a, Result<(), SocketError<Error2>>>
387 where
388 F: 'static + Send + Sync + Fn(Error) -> Error2,
389 State: 'static + Send + Sync,
390 Error: 'static,
391 {
392 let map = self.map.clone();
393 let fut = (self.handler)(req, conn, state);
394 async move { fut.await.map_err(|err| err.map_app_specific(&*map)) }.boxed()
395 }
396}
397
398pub(crate) fn map_err<State, Error, Error2>(
399 h: Handler<State, Error>,
400 f: impl 'static + Send + Sync + Fn(Error) -> Error2,
401) -> Handler<State, Error2>
402where
403 State: 'static + Send + Sync,
404 Error: 'static,
405{
406 let handler = MapErr {
407 handler: h,
408 map: Arc::new(f),
409 };
410 Box::new(move |req, conn, state| handler.handle(req, conn, state))
411}