tide_disco/
api.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the tide-disco library.
3
4// You should have received a copy of the MIT License
5// along with the tide-disco library. If not, see <https://mit-license.org/>.
6
7use crate::{
8    healthcheck::{HealthCheck, HealthStatus},
9    method::{Method, ReadState, WriteState},
10    metrics::Metrics,
11    middleware::{error_handler, ErrorHandler},
12    request::RequestParams,
13    route::{self, *},
14    socket, Html,
15};
16use async_std::sync::Arc;
17use async_trait::async_trait;
18use derivative::Derivative;
19use futures::{
20    future::{BoxFuture, FutureExt},
21    stream::BoxStream,
22};
23use maud::{html, PreEscaped};
24use semver::Version;
25use serde::{de::DeserializeOwned, Deserialize, Serialize};
26use serde_with::{serde_as, DisplayFromStr};
27use snafu::{OptionExt, ResultExt, Snafu};
28use std::{
29    borrow::Cow,
30    collections::hash_map::{Entry, HashMap, IntoValues, Values},
31    convert::Infallible,
32    fmt::Display,
33    fs,
34    marker::PhantomData,
35    ops::Index,
36    path::{Path, PathBuf},
37};
38use tide::http::content::Accept;
39use vbs::version::StaticVersionType;
40
41/// An error encountered when parsing or constructing an [Api].
42#[derive(Clone, Debug, Snafu, PartialEq, Eq)]
43pub enum ApiError {
44    Route { source: RouteParseError },
45    ApiMustBeTable,
46    MissingRoutesTable,
47    RoutesMustBeTable,
48    UndefinedRoute,
49    HandlerAlreadyRegistered,
50    IncorrectMethod { expected: Method, actual: Method },
51    InvalidMetaTable { source: toml::de::Error },
52    MissingFormatVersion,
53    InvalidFormatVersion,
54    AmbiguousRoutes { route1: String, route2: String },
55    CannotReadToml { reason: String },
56}
57
/// Version information about an API.
///
/// This is the information included in responses to `GET /version` for the API.
#[serde_as]
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ApiVersion {
    /// The version of this API.
    ///
    /// `None` if no version was set with [Api::with_version]. When present it is serialized in its
    /// string (`Display`) form.
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub api_version: Option<Version>,

    /// The format version of the TOML specification used to load this API.
    ///
    /// Taken from [ApiMetadata::format_version] and serialized in its string (`Display`) form.
    #[serde_as(as = "DisplayFromStr")]
    pub spec_version: Version,
}
70
71/// Metadata used for describing and documenting an API.
72///
73/// [ApiMetadata] contains version information about the API, as well as optional HTML fragments to
74/// customize the formatting of automatically generated API documentation. Each of the supported
75/// HTML fragments is optional and will be filled in with a reasonable default if not provided. Some
76/// of the HTML fragments may contain "placeholders", which are identifiers enclosed in `{{ }}`,
77/// like `{{SOME_PLACEHOLDER}}`. These will be replaced by contextual information when the
78/// documentation is generated. The placeholders supported by each HTML fragment are documented
79/// below.
80#[serde_as]
81#[derive(Clone, Debug, Deserialize, Serialize)]
82#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
83pub struct ApiMetadata {
84    /// The name of this API.
85    ///
86    /// Note that the name of the API may be overridden if the API is registered with an app using
87    /// a different name.
88    #[serde(default = "meta_defaults::name")]
89    pub name: String,
90
91    /// A description of this API.
92    #[serde(default = "meta_defaults::description")]
93    pub description: String,
94
95    /// The version of the Tide Disco API specification format.
96    ///
97    /// If not specified, the version of this crate will be used.
98    #[serde_as(as = "DisplayFromStr")]
99    #[serde(default = "meta_defaults::format_version")]
100    pub format_version: Version,
101
102    /// HTML to be prepended to automatically generated documentation.
103    ///
104    /// # Placeholders
105    ///
106    /// * `NAME`: the name of the API
107    /// * `DESCRIPTION`: the description provided in `Cargo.toml`
108    /// * `VERSION`: the version of the API
109    /// * `FORMAT_VERSION`: the `FORMAT_VERSION` of the API
110    /// * `PUBLIC`: the URL where the public directory for this API is being served
111    #[serde(default = "meta_defaults::html_top")]
112    pub html_top: String,
113
114    /// HTML to be appended to automatically generated documentation.
115    #[serde(default = "meta_defaults::html_bottom")]
116    pub html_bottom: String,
117
118    /// The heading for documentation of a route.
119    ///
120    /// # Placeholders
121    ///
122    /// * `METHOD`: the method of the route
123    /// * `NAME`: the name of the route
124    #[serde(default = "meta_defaults::heading_entry")]
125    pub heading_entry: String,
126
127    /// The heading preceding documentation of all routes in this API.
128    #[serde(default = "meta_defaults::heading_routes")]
129    pub heading_routes: String,
130
131    /// The heading preceding documentation of route parameters.
132    #[serde(default = "meta_defaults::heading_parameters")]
133    pub heading_parameters: String,
134
135    /// The heading preceding documentation of a route description.
136    #[serde(default = "meta_defaults::heading_description")]
137    pub heading_description: String,
138
139    /// HTML formatting the path of a route.
140    ///
141    /// # Placeholders
142    ///
143    /// * `PATH`: the path being formatted
144    #[serde(default = "meta_defaults::route_path")]
145    pub route_path: String,
146
147    /// HTML preceding the contents of a table documenting the parameters of a route.
148    #[serde(default = "meta_defaults::parameter_table_open")]
149    pub parameter_table_open: String,
150
151    /// HTML closing a table documenting the parameters of a route.
152    #[serde(default = "meta_defaults::parameter_table_close")]
153    pub parameter_table_close: String,
154
155    /// HTML formatting an entry in a table documenting the parameters of a route.
156    ///
157    /// # Placeholders
158    ///
159    /// * `NAME`: the parameter being documented
160    /// * `TYPE`: the type of the parameter being documented
161    #[serde(default = "meta_defaults::parameter_row")]
162    pub parameter_row: String,
163
164    /// Documentation to insert in the parameters section of a route with no parameters.
165    #[serde(default = "meta_defaults::parameter_none")]
166    pub parameter_none: String,
167}
168
169impl Default for ApiMetadata {
170    fn default() -> Self {
171        // Deserialize an empty table, using the `serde` defaults for every field.
172        toml::Value::Table(Default::default()).try_into().unwrap()
173    }
174}
175
/// Default values for the fields of [ApiMetadata], referenced by its
/// `#[serde(default = "meta_defaults::...")]` attributes.
///
/// Several of the HTML fragments contain `{{...}}` placeholders; see the field documentation on
/// [ApiMetadata] for the placeholders each fragment supports.
mod meta_defaults {
    use super::Version;

    /// Default API name, used when the `meta` table has no `NAME`.
    pub fn name() -> String {
        "default-tide-disco-api".to_string()
    }

    /// Default API description, used when the `meta` table has no `DESCRIPTION`.
    pub fn description() -> String {
        "Default Tide Disco API".to_string()
    }

    /// Default specification format version (`0.1.0`), used when the `meta` table has no
    /// `FORMAT_VERSION`. The literal is a valid semver string, so the `unwrap` cannot fail.
    pub fn format_version() -> Version {
        "0.1.0".parse().unwrap()
    }

    /// Default HTML opening the documentation page: document head, logo, title, and descriptions.
    ///
    /// NOTE(review): the favicon link uses the absolute path `/public/favicon.svg` rather than the
    /// `{{PUBLIC}}` placeholder used for the other assets — confirm this is intentional.
    pub fn html_top() -> String {
        "
        <!DOCTYPE html>
        <html lang='en'>
          <head>
            <meta charset='utf-8'>
            <title>{{NAME}} Reference</title>
            <link rel='stylesheet' href='{{PUBLIC}}/css/style.css'>
            <script src='{{PUBLIC}}/js/script.js'></script>
            <link rel='icon' type='image/svg+xml'
             href='/public/favicon.svg'>
          </head>
          <body>
            <div><a href='/'><img src='{{PUBLIC}}/espressosys_logo.svg'
                      alt='Espresso Systems Logo'
                      /></a></div>
            <h1>{{NAME}} API {{VERSION}} Reference</h1>
            <p>{{SHORT_DESCRIPTION}}</p><br/>
            {{LONG_DESCRIPTION}}
        "
        .to_string()
    }

    /// Default HTML closing the documentation page (copyright footer and closing tags).
    pub fn html_bottom() -> String {
        "
            <h1>&nbsp;</h1>
            <p>Copyright © 2022 Espresso Systems. All rights reserved.</p>
          </body>
        </html>
        "
        .to_string()
    }

    /// Default heading for a single route's documentation, with an anchor named after the route.
    pub fn heading_entry() -> String {
        "<a name='{{NAME}}'><h3 class='entry'><span class='meth'>{{METHOD}}</span> {{NAME}}</h3></a>\n".to_string()
    }

    /// Default heading preceding the list of routes.
    pub fn heading_routes() -> String {
        "<h3>Routes</h3>\n".to_string()
    }
    /// Default heading preceding a route's parameters.
    pub fn heading_parameters() -> String {
        "<h3>Parameters</h3>\n".to_string()
    }
    /// Default heading preceding a route's description.
    pub fn heading_description() -> String {
        "<h3>Description</h3>\n".to_string()
    }

    /// Default HTML for a route path.
    pub fn route_path() -> String {
        "<p class='path'>{{PATH}}</p>\n".to_string()
    }

    /// Default HTML opening the parameter table.
    pub fn parameter_table_open() -> String {
        "<table>\n".to_string()
    }
    /// Default HTML closing the parameter table.
    pub fn parameter_table_close() -> String {
        "</table>\n\n".to_string()
    }
    /// Default HTML for one row of the parameter table.
    pub fn parameter_row() -> String {
        "<tr><td class='parameter'>{{NAME}}</td><td class='type'>{{TYPE}}</td></tr>\n".to_string()
    }
    /// Default HTML shown in the parameters section of a route without parameters.
    pub fn parameter_none() -> String {
        "<div class='meta'>None</div>".to_string()
    }
}
255
/// A description of an API.
///
/// An [Api] is a structured representation of an `api.toml` specification. It contains API-level
/// metadata and descriptions of all of the routes in the specification. It can be parsed from a
/// TOML file and registered as a module of an [App](crate::App).
#[derive(Derivative)]
#[derivative(Debug(bound = ""))]
pub struct Api<State, Error, VER> {
    /// The version-erased contents of the API (metadata, routes, and handlers).
    inner: ApiInner<State, Error>,
    /// Type-level marker for the binary serialization format version `VER`; no value is stored.
    _version: PhantomData<VER>,
}
267
/// A version-erased description of an API.
///
/// This type contains all the details of the API, with the version of the binary serialization
/// format type-erased and encapsulated into the route handlers. This type is used internally by
/// [`App`], to allow dynamic registration of different versions of APIs with different versions of
/// the binary format.
///
/// It is exposed publicly and manipulated _only_ via [`Api`], which wraps this type with a static
/// format version type parameter, which provides compile-time enforcement of format version
/// consistency as the API is being constructed, until it is registered with an [`App`] and
/// type-erased.
#[derive(Derivative)]
#[derivative(Debug(bound = ""))]
pub(crate) struct ApiInner<State, Error> {
    /// Metadata parsed from the `meta` table of the specification (or defaults). The same `Arc`
    /// is cloned into each route when the API is parsed.
    meta: Arc<ApiMetadata>,
    /// Name of the API. Initialized from the metadata name; may be changed with `set_name`.
    name: String,
    /// Every route in the API, keyed by route name.
    routes: HashMap<String, Route<State, Error>>,
    /// Route names grouped by path pattern. Within a group, no two routes have the same method.
    routes_by_path: HashMap<String, Vec<String>>,
    /// Handler invoked by `health` to answer health checks.
    #[derivative(Debug = "ignore")]
    health_check: HealthCheckHandler<State>,
    /// Version of the API, if one was set with `Api::with_version`.
    api_version: Option<Version>,
    /// Error handler encapsulating the serialization format version for errors.
    ///
    /// This field is optional so it can be bound late, potentially after a `map_err` changes the
    /// error type. However, it will always be set after `Api::into_inner` is called.
    #[derivative(Debug = "ignore")]
    error_handler: Option<Arc<dyn ErrorHandler<Error>>>,
    /// Response handler encapsulating the serialization format version for version requests.
    #[derivative(Debug = "ignore")]
    version_handler: Arc<dyn VersionHandler>,
    /// Directory of static files to serve, if one was set with `Api::with_public`.
    public: Option<PathBuf>,
    /// HTML rendering of the first Markdown block of the metadata description, with an
    /// enclosing `<p>` tag stripped if present.
    short_description: String,
    /// HTML rendering of the remaining Markdown blocks of the metadata description.
    long_description: String,
}
302
303pub(crate) trait VersionHandler:
304    Send + Sync + Fn(&Accept, ApiVersion) -> Result<tide::Response, RouteError<Infallible>>
305{
306}
307impl<F> VersionHandler for F where
308    F: Send + Sync + Fn(&Accept, ApiVersion) -> Result<tide::Response, RouteError<Infallible>>
309{
310}
311
312impl<'a, State, Error> IntoIterator for &'a ApiInner<State, Error> {
313    type Item = &'a Route<State, Error>;
314    type IntoIter = Values<'a, String, Route<State, Error>>;
315
316    fn into_iter(self) -> Self::IntoIter {
317        self.routes.values()
318    }
319}
320
321impl<State, Error> IntoIterator for ApiInner<State, Error> {
322    type Item = Route<State, Error>;
323    type IntoIter = IntoValues<String, Route<State, Error>>;
324
325    fn into_iter(self) -> Self::IntoIter {
326        self.routes.into_values()
327    }
328}
329
330impl<State, Error> Index<&str> for ApiInner<State, Error> {
331    type Output = Route<State, Error>;
332
333    fn index(&self, index: &str) -> &Route<State, Error> {
334        &self.routes[index]
335    }
336}
337
338/// Iterator for [routes_by_path](ApiInner::routes_by_path).
339///
340/// This type iterates over all of the routes that have a given path.
341/// [routes_by_path](ApiInner::routes_by_path), in turn, returns an iterator over paths whose items
342/// contain a [RoutesWithPath] iterator.
343pub(crate) struct RoutesWithPath<'a, State, Error> {
344    routes: std::slice::Iter<'a, String>,
345    api: &'a ApiInner<State, Error>,
346}
347
348impl<'a, State, Error> Iterator for RoutesWithPath<'a, State, Error> {
349    type Item = &'a Route<State, Error>;
350
351    fn next(&mut self) -> Option<Self::Item> {
352        Some(&self.api.routes[self.routes.next()?])
353    }
354}
355
356impl<State, Error> ApiInner<State, Error> {
357    /// Iterate over groups of routes with the same path.
358    pub(crate) fn routes_by_path(
359        &self,
360    ) -> impl Iterator<Item = (&str, RoutesWithPath<'_, State, Error>)> {
361        self.routes_by_path.iter().map(|(path, routes)| {
362            (
363                path.as_str(),
364                RoutesWithPath {
365                    routes: routes.iter(),
366                    api: self,
367                },
368            )
369        })
370    }
371
372    /// Check the health status of a server with the given state.
373    pub(crate) async fn health(&self, req: RequestParams, state: &State) -> tide::Response {
374        (self.health_check)(req, state).await
375    }
376
377    /// Get the version of this API.
378    pub(crate) fn version(&self) -> ApiVersion {
379        ApiVersion {
380            api_version: self.api_version.clone(),
381            spec_version: self.meta.format_version.clone(),
382        }
383    }
384
385    pub(crate) fn public(&self) -> Option<&PathBuf> {
386        self.public.as_ref()
387    }
388
389    pub(crate) fn set_name(&mut self, name: String) {
390        self.name = name;
391    }
392
393    /// Compose an HTML page documenting all the routes in this API.
394    pub(crate) fn documentation(&self) -> Html {
395        html! {
396            (PreEscaped(self.meta.html_top
397                .replace("{{NAME}}", &self.name)
398                .replace("{{SHORT_DESCRIPTION}}", &self.short_description)
399                .replace("{{LONG_DESCRIPTION}}", &self.long_description)
400                .replace("{{VERSION}}", &match &self.api_version {
401                    Some(version) => version.to_string(),
402                    None => "(no version)".to_string(),
403                })
404                .replace("{{FORMAT_VERSION}}", &self.meta.format_version.to_string())
405                .replace("{{PUBLIC}}", &format!("/public/{}", self.name))))
406            @for route in self.routes.values() {
407                (route.documentation())
408            }
409            (PreEscaped(&self.meta.html_bottom))
410        }
411    }
412
413    /// The short description of this API from the specification.
414    pub(crate) fn short_description(&self) -> &str {
415        &self.short_description
416    }
417
418    pub(crate) fn error_handler(&self) -> Arc<dyn ErrorHandler<Error>> {
419        self.error_handler.clone().unwrap()
420    }
421
422    pub(crate) fn version_handler(&self) -> Arc<dyn VersionHandler> {
423        self.version_handler.clone()
424    }
425}
426
427impl<State, Error, VER> Api<State, Error, VER>
428where
429    State: 'static,
430    Error: 'static,
431    VER: StaticVersionType + 'static,
432{
433    /// Parse an API from a TOML specification.
434    pub fn new(api: impl Into<toml::Value>) -> Result<Self, ApiError> {
435        let mut api = api.into();
436        let meta = match api
437            .as_table_mut()
438            .context(ApiMustBeTableSnafu)?
439            .remove("meta")
440        {
441            Some(meta) => toml::Value::try_into(meta)
442                .map_err(|source| ApiError::InvalidMetaTable { source })?,
443            None => ApiMetadata::default(),
444        };
445        let meta = Arc::new(meta);
446        let routes = match api.get("route") {
447            Some(routes) => routes.as_table().context(RoutesMustBeTableSnafu)?,
448            None => return Err(ApiError::MissingRoutesTable),
449        };
450        // Collect routes into a [HashMap] indexed by route name.
451        let routes = routes
452            .into_iter()
453            .map(|(name, spec)| {
454                let route = Route::new(name.clone(), spec, meta.clone()).context(RouteSnafu)?;
455                Ok((route.name(), route))
456            })
457            .collect::<Result<HashMap<_, _>, _>>()?;
458        // Collect routes into groups of route names indexed by route pattern.
459        let mut routes_by_path = HashMap::new();
460        for route in routes.values() {
461            for path in route.patterns() {
462                match routes_by_path.entry(path.clone()) {
463                    Entry::Vacant(e) => e.insert(Vec::new()).push(route.name().clone()),
464                    Entry::Occupied(mut e) => {
465                        // If there is already a route with this path and method, then dispatch is
466                        // ambiguous.
467                        if let Some(ambiguous_name) = e
468                            .get()
469                            .iter()
470                            .find(|name| routes[*name].method() == route.method())
471                        {
472                            return Err(ApiError::AmbiguousRoutes {
473                                route1: route.name(),
474                                route2: ambiguous_name.clone(),
475                            });
476                        }
477                        e.get_mut().push(route.name());
478                    }
479                }
480            }
481        }
482
483        // Parse description: the first line is a short description, to display when briefly
484        // describing this API in a list. The rest is the long description, to display on this API's
485        // own documentation page. Both are rendered to HTML via Markdown.
486        let blocks = markdown::tokenize(&meta.description);
487        let (short_description, long_description) = match blocks.split_first() {
488            Some((short, long)) => {
489                let render = |blocks| markdown::to_html(&markdown::generate_markdown(blocks));
490
491                let short = render(vec![short.clone()]);
492                let long = render(long.to_vec());
493
494                // The short description is only one block, and sometimes we would like to display
495                // it inline (as a `span`). Markdown automatically wraps blocks in `<p>`. We will
496                // strip this outer tag so that we can wrap it in either `<p>` or `<span>`,
497                // depending on the context.
498                let short = short.strip_prefix("<p>").unwrap_or(&short);
499                let short = short.strip_suffix("</p>").unwrap_or(short);
500                let short = short.to_string();
501
502                (short, long)
503            }
504            None => Default::default(),
505        };
506
507        Ok(Self {
508            inner: ApiInner {
509                name: meta.name.clone(),
510                meta,
511                routes,
512                routes_by_path,
513                health_check: Box::new(Self::default_health_check),
514                api_version: None,
515                error_handler: None,
516                version_handler: Arc::new(|accept, version| {
517                    respond_with(accept, version, VER::instance())
518                }),
519                public: None,
520                short_description,
521                long_description,
522            },
523            _version: Default::default(),
524        })
525    }
526
527    /// Create an [Api] by reading a TOML specification from a file.
528    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, ApiError> {
529        let bytes = fs::read(path).map_err(|err| ApiError::CannotReadToml {
530            reason: err.to_string(),
531        })?;
532        let string = std::str::from_utf8(&bytes).map_err(|err| ApiError::CannotReadToml {
533            reason: err.to_string(),
534        })?;
535        Self::new(toml::from_str::<toml::Value>(string).map_err(|err| {
536            ApiError::CannotReadToml {
537                reason: err.to_string(),
538            }
539        })?)
540    }
541
542    /// Set the API version.
543    ///
544    /// The version information will automatically be included in responses to `GET /version`. This
545    /// version can also be used to serve multiple major versions of the same API simultaneously,
546    /// under a version prefix. For more information, see
547    /// [App::register_module](crate::App::register_module).
548    ///
549    /// This is the version of the application or sub-application which this instance of [Api]
550    /// represents. The versioning corresponds to the API specification passed to [new](Api::new),
551    /// and may be different from the version of the Rust crate implementing the route handlers for
552    /// the API.
553    pub fn with_version(&mut self, version: Version) -> &mut Self {
554        self.inner.api_version = Some(version);
555        self
556    }
557
558    /// Serve the contents of `dir` at the URL `/public/{{NAME}}`.
559    pub fn with_public(&mut self, dir: PathBuf) -> &mut Self {
560        self.inner.public = Some(dir);
561        self
562    }
563
564    /// Register a handler for a route.
565    ///
566    /// When the server receives a request whose URL matches the pattern of the route `name`,
567    /// `handler` will be invoked with the parameters of the request and a reference to the current
568    /// state, and the result will be serialized into a response.
569    ///
570    /// # Examples
571    ///
572    /// A simple getter route for a state object.
573    ///
574    /// `api.toml`
575    ///
576    /// ```toml
577    /// [route.getstate]
578    /// PATH = ["/getstate"]
579    /// DOC = "Gets the current state."
580    /// ```
581    ///
582    /// ```
583    /// use futures::FutureExt;
584    /// # use tide_disco::Api;
585    /// # use vbs::version::StaticVersion;
586    ///
587    /// type State = u64;
588    /// type StaticVer01 = StaticVersion<0, 1>;
589    ///
590    /// # fn ex(api: &mut Api<State, (), StaticVer01>) {
591    /// api.at("getstate", |req, state| async { Ok(*state) }.boxed());
592    /// # }
593    /// ```
594    ///
595    /// A counter endpoint which increments a mutable state. Notice how we use `METHOD = "POST"` to
596    /// ensure that the HTTP method for this route is compatible with mutable access.
597    ///
598    /// `api.toml`
599    ///
600    /// ```toml
601    /// [route.increment]
602    /// PATH = ["/increment"]
603    /// METHOD = "POST"
604    /// DOC = "Increment the current state and return the new value."
605    /// ```
606    ///
607    /// ```
608    /// use async_std::sync::Mutex;
609    /// use futures::FutureExt;
610    /// # use tide_disco::Api;
611    /// # use vbs::version::StaticVersion;
612    ///
613    /// type State = Mutex<u64>;
614    /// type StaticVer01 = StaticVersion<0, 1>;
615    ///
616    /// # fn ex(api: &mut Api<State, (), StaticVer01>) {
617    /// api.at("increment", |req, state| async {
618    ///     let mut guard = state.lock().await;
619    ///     *guard += 1;
620    ///     Ok(*guard)
621    /// }.boxed());
622    /// # }
623    /// ```
624    ///
625    /// # Warnings
626    /// The route will use the HTTP method specified in the TOML specification for the named route
627    /// (or GET if the method is not specified). Some HTTP methods imply constraints on mutability.
628    /// For example, GET routes must be "pure", and not mutate any server state. Violating this
629    /// constraint may lead to confusing and unpredictable behavior. If the `State` type has
630    /// interior mutability (for instance, [RwLock](async_std::sync::RwLock)) it is up to the
631    /// `handler` not to use the interior mutability if the HTTP method suggests it shouldn't.
632    ///
633    /// If you know the HTTP method when you are registering the route, we recommend you use the
634    /// safer versions of this function, which enforce the appropriate mutability constraints. For
635    /// example,
636    /// * [get](Self::get)
637    /// * [post](Self::post)
638    /// * [put](Self::put)
639    /// * [delete](Self::delete)
640    ///
641    /// # Errors
642    ///
643    /// If the route `name` does not exist in the API specification, or if the route already has a
644    /// handler registered, an error is returned. Note that all routes are initialized with a
645    /// default handler that echoes parameters and shows documentation, but this default handler can
646    /// replaced by this function without raising [ApiError::HandlerAlreadyRegistered].
647    ///
648    /// If the route `name` exists, but it is not an HTTP route (for example, `METHOD = "SOCKET"`
649    /// was used when defining the route in the API specification), [ApiError::IncorrectMethod] is
650    /// returned.
651    ///
652    /// # Limitations
653    ///
654    /// [Like many function parameters](crate#boxed-futures) in [tide_disco](crate), the
655    /// handler function is required to return a [BoxFuture].
656    pub fn at<F, T>(&mut self, name: &str, handler: F) -> Result<&mut Self, ApiError>
657    where
658        F: 'static + Send + Sync + Fn(RequestParams, &State) -> BoxFuture<'_, Result<T, Error>>,
659        T: Serialize,
660        State: 'static + Send + Sync,
661        VER: 'static + Send + Sync,
662    {
663        let route = self
664            .inner
665            .routes
666            .get_mut(name)
667            .ok_or(ApiError::UndefinedRoute)?;
668        if route.has_handler() {
669            return Err(ApiError::HandlerAlreadyRegistered);
670        }
671
672        if !route.method().is_http() {
673            return Err(ApiError::IncorrectMethod {
674                // Just pick any HTTP method as the expected method.
675                expected: Method::get(),
676                actual: route.method(),
677            });
678        }
679
680        // `set_fn_handler` only fails if the route is not an HTTP route; since we have already
681        // checked that it is, this cannot fail.
682        route
683            .set_fn_handler(handler, VER::instance())
684            .unwrap_or_else(|_| panic!("unexpected failure in set_fn_handler"));
685
686        Ok(self)
687    }
688
689    fn method_immutable<F, T>(
690        &mut self,
691        method: Method,
692        name: &str,
693        handler: F,
694    ) -> Result<&mut Self, ApiError>
695    where
696        F: 'static
697            + Send
698            + Sync
699            + Fn(RequestParams, &<State as ReadState>::State) -> BoxFuture<'_, Result<T, Error>>,
700        T: Serialize,
701        State: 'static + Send + Sync + ReadState,
702        VER: 'static + Send + Sync + StaticVersionType,
703    {
704        assert!(method.is_http() && !method.is_mutable());
705        let route = self
706            .inner
707            .routes
708            .get_mut(name)
709            .ok_or(ApiError::UndefinedRoute)?;
710        if route.method() != method {
711            return Err(ApiError::IncorrectMethod {
712                expected: method,
713                actual: route.method(),
714            });
715        }
716        if route.has_handler() {
717            return Err(ApiError::HandlerAlreadyRegistered);
718        }
719        // `set_handler` only fails if the route is not an HTTP route; since we have already checked
720        // that it is, this cannot fail.
721        route
722            .set_handler(ReadHandler::<_, VER>::from(handler))
723            .unwrap_or_else(|_| panic!("unexpected failure in set_handler"));
724        Ok(self)
725    }
726
727    /// Register a handler for a GET route.
728    ///
729    /// When the server receives a GET request whose URL matches the pattern of the route `name`,
730    /// `handler` will be invoked with the parameters of the request and immutable access to the
731    /// current state, and the result will be serialized into a response.
732    ///
733    /// The [ReadState] trait is used to acquire immutable access to the state, so the state
734    /// reference passed to `handler` is actually [`<State as ReadState>::State`](ReadState::State).
735    /// For example, if `State` is `RwLock<T>`, the lock will automatically be acquired for reading,
736    /// and the handler will be passed a `&T`.
737    ///
738    /// # Examples
739    ///
740    /// A simple getter route for a state object.
741    ///
742    /// `api.toml`
743    ///
744    /// ```toml
745    /// [route.getstate]
746    /// PATH = ["/getstate"]
747    /// DOC = "Gets the current state."
748    /// ```
749    ///
750    /// ```
751    /// use async_std::sync::RwLock;
752    /// use futures::FutureExt;
753    /// # use tide_disco::Api;
754    /// # use vbs::{Serializer, version::StaticVersion};
755    ///
756    /// type State = RwLock<u64>;
757    /// type StaticVer01 = StaticVersion<0, 1>;
758    ///
759    /// # fn ex(api: &mut Api<State, (), StaticVer01>) {
760    /// api.get("getstate", |req, state| async { Ok(*state) }.boxed());
761    /// # }
762    /// ```
763    ///
764    /// # Errors
765    ///
766    /// If the route `name` does not exist in the API specification, or if the route already has a
767    /// handler registered, an error is returned. Note that all routes are initialized with a
768    /// default handler that echoes parameters and shows documentation, but this default handler can
769    /// replaced by this function without raising [ApiError::HandlerAlreadyRegistered].
770    ///
771    /// If the route `name` exists, but the method is not GET (that is, `METHOD = "M"` was used in
772    /// the route definition in `api.toml`, with `M` other than `GET`) the error
773    /// [IncorrectMethod](ApiError::IncorrectMethod) is returned.
774    ///
775    /// # Limitations
776    ///
777    /// [Like many function parameters](crate#boxed-futures) in [tide_disco](crate), the
778    /// handler function is required to return a [BoxFuture].
779    pub fn get<F, T>(&mut self, name: &str, handler: F) -> Result<&mut Self, ApiError>
780    where
781        F: 'static
782            + Send
783            + Sync
784            + Fn(RequestParams, &<State as ReadState>::State) -> BoxFuture<'_, Result<T, Error>>,
785        T: Serialize,
786        State: 'static + Send + Sync + ReadState,
787        VER: 'static + Send + Sync,
788    {
789        self.method_immutable(Method::get(), name, handler)
790    }
791
792    fn method_mutable<F, T>(
793        &mut self,
794        method: Method,
795        name: &str,
796        handler: F,
797    ) -> Result<&mut Self, ApiError>
798    where
799        F: 'static
800            + Send
801            + Sync
802            + Fn(RequestParams, &mut <State as ReadState>::State) -> BoxFuture<'_, Result<T, Error>>,
803        T: Serialize,
804        State: 'static + Send + Sync + WriteState,
805        VER: 'static + Send + Sync,
806    {
807        assert!(method.is_http() && method.is_mutable());
808        let route = self
809            .inner
810            .routes
811            .get_mut(name)
812            .ok_or(ApiError::UndefinedRoute)?;
813        if route.method() != method {
814            return Err(ApiError::IncorrectMethod {
815                expected: method,
816                actual: route.method(),
817            });
818        }
819        if route.has_handler() {
820            return Err(ApiError::HandlerAlreadyRegistered);
821        }
822
823        // `set_handler` only fails if the route is not an HTTP route; since we have already checked
824        // that it is, this cannot fail.
825        route
826            .set_handler(WriteHandler::<_, VER>::from(handler))
827            .unwrap_or_else(|_| panic!("unexpected failure in set_handler"));
828        Ok(self)
829    }
830
831    /// Register a handler for a POST route.
832    ///
833    /// When the server receives a POST request whose URL matches the pattern of the route `name`,
834    /// `handler` will be invoked with the parameters of the request and exclusive, mutable access
835    /// to the current state, and the result will be serialized into a response.
836    ///
837    /// The [WriteState] trait is used to acquire mutable access to the state, so the state
838    /// reference passed to `handler` is actually [`<State as ReadState>::State`](ReadState::State).
839    /// For example, if `State` is `RwLock<T>`, the lock will automatically be acquired for writing,
840    /// and the handler will be passed a `&mut T`.
841    ///
842    /// # Examples
843    ///
844    /// A counter endpoint which increments the state and returns the new state.
845    ///
846    /// `api.toml`
847    ///
848    /// ```toml
849    /// [route.increment]
850    /// PATH = ["/increment"]
851    /// METHOD = "POST"
852    /// DOC = "Increment the current state and return the new value."
853    /// ```
854    ///
855    /// ```
856    /// use async_std::sync::RwLock;
857    /// use futures::FutureExt;
858    /// # use tide_disco::Api;
859    /// # use vbs::version::StaticVersion;
860    ///
861    /// type State = RwLock<u64>;
862    /// type StaticVer01 = StaticVersion<0, 1>;
863    ///
864    /// # fn ex(api: &mut Api<State, (), StaticVer01>) {
865    /// api.post("increment", |req, state| async {
866    ///     *state += 1;
867    ///     Ok(*state)
868    /// }.boxed());
869    /// # }
870    /// ```
871    ///
872    /// # Errors
873    ///
874    /// If the route `name` does not exist in the API specification, or if the route already has a
875    /// handler registered, an error is returned. Note that all routes are initialized with a
876    /// default handler that echoes parameters and shows documentation, but this default handler can
877    /// replaced by this function without raising [ApiError::HandlerAlreadyRegistered].
878    ///
879    /// If the route `name` exists, but the method is not POST (that is, `METHOD = "M"` was used in
880    /// the route definition in `api.toml`, with `M` other than `POST`) the error
881    /// [IncorrectMethod](ApiError::IncorrectMethod) is returned.
882    ///
883    /// # Limitations
884    ///
885    /// [Like many function parameters](crate#boxed-futures) in [tide_disco](crate), the
886    /// handler function is required to return a [BoxFuture].
887    pub fn post<F, T>(&mut self, name: &str, handler: F) -> Result<&mut Self, ApiError>
888    where
889        F: 'static
890            + Send
891            + Sync
892            + Fn(RequestParams, &mut <State as ReadState>::State) -> BoxFuture<'_, Result<T, Error>>,
893        T: Serialize,
894        State: 'static + Send + Sync + WriteState,
895        VER: 'static + Send + Sync,
896    {
897        self.method_mutable(Method::post(), name, handler)
898    }
899
900    /// Register a handler for a PUT route.
901    ///
902    /// When the server receives a PUT request whose URL matches the pattern of the route `name`,
903    /// `handler` will be invoked with the parameters of the request and exclusive, mutable access
904    /// to the current state, and the result will be serialized into a response.
905    ///
906    /// The [WriteState] trait is used to acquire mutable access to the state, so the state
907    /// reference passed to `handler` is actually [`<State as ReadState>::State`](ReadState::State).
908    /// For example, if `State` is `RwLock<T>`, the lock will automatically be acquired for writing,
909    /// and the handler will be passed a `&mut T`.
910    ///
911    /// # Examples
912    ///
913    /// An endpoint which replaces the current state with a new value.
914    ///
915    /// `api.toml`
916    ///
917    /// ```toml
918    /// [route.replace]
919    /// PATH = ["/replace/:new_state"]
920    /// METHOD = "PUT"
921    /// ":new_state" = "Integer"
922    /// DOC = "Set the state to `:new_state`."
923    /// ```
924    ///
925    /// ```
926    /// use async_std::sync::RwLock;
927    /// use futures::FutureExt;
928    /// # use tide_disco::Api;
929    /// # use vbs::version::StaticVersion;
930    ///
931    /// type State = RwLock<u64>;
932    /// type StaticVer01 = StaticVersion<0, 1>;
933    ///
934    /// # fn ex(api: &mut Api<State, tide_disco::RequestError, StaticVer01>) {
935    /// api.post("replace", |req, state| async move {
936    ///     *state = req.integer_param("new_state")?;
937    ///     Ok(())
938    /// }.boxed());
939    /// # }
940    /// ```
941    ///
942    /// # Errors
943    ///
944    /// If the route `name` does not exist in the API specification, or if the route already has a
945    /// handler registered, an error is returned. Note that all routes are initialized with a
946    /// default handler that echoes parameters and shows documentation, but this default handler can
947    /// replaced by this function without raising [ApiError::HandlerAlreadyRegistered].
948    ///
949    /// If the route `name` exists, but the method is not PUT (that is, `METHOD = "M"` was used in
950    /// the route definition in `api.toml`, with `M` other than `PUT`) the error
951    /// [IncorrectMethod](ApiError::IncorrectMethod) is returned.
952    ///
953    /// # Limitations
954    ///
955    /// [Like many function parameters](crate#boxed-futures) in [tide_disco](crate), the
956    /// handler function is required to return a [BoxFuture].
957    pub fn put<F, T>(&mut self, name: &str, handler: F) -> Result<&mut Self, ApiError>
958    where
959        F: 'static
960            + Send
961            + Sync
962            + Fn(RequestParams, &mut <State as ReadState>::State) -> BoxFuture<'_, Result<T, Error>>,
963        T: Serialize,
964        State: 'static + Send + Sync + WriteState,
965        VER: 'static + Send + Sync,
966    {
967        self.method_mutable(Method::put(), name, handler)
968    }
969
970    /// Register a handler for a DELETE route.
971    ///
972    /// When the server receives a DELETE request whose URL matches the pattern of the route `name`,
973    /// `handler` will be invoked with the parameters of the request and exclusive, mutable access
974    /// to the current state, and the result will be serialized into a response.
975    ///
976    /// The [WriteState] trait is used to acquire mutable access to the state, so the state
977    /// reference passed to `handler` is actually [`<State as ReadState>::State`](ReadState::State).
978    /// For example, if `State` is `RwLock<T>`, the lock will automatically be acquired for writing,
979    /// and the handler will be passed a `&mut T`.
980    ///
981    /// # Examples
982    ///
983    /// An endpoint which clears the current state.
984    ///
985    /// `api.toml`
986    ///
987    /// ```toml
988    /// [route.state]
989    /// PATH = ["/state"]
990    /// METHOD = "DELETE"
991    /// DOC = "Clear the state."
992    /// ```
993    ///
994    /// ```
995    /// use async_std::sync::RwLock;
996    /// use futures::FutureExt;
997    /// # use tide_disco::Api;
998    /// # use vbs::version::StaticVersion;
999    ///
1000    /// type State = RwLock<Option<u64>>;
1001    /// type StaticVer01 = StaticVersion<0, 1>;
1002    ///
1003    /// # fn ex(api: &mut Api<State, (), StaticVer01>) {
1004    /// api.delete("state", |req, state| async {
1005    ///     *state = None;
1006    ///     Ok(())
1007    /// }.boxed());
1008    /// # }
1009    /// ```
1010    ///
1011    /// # Errors
1012    ///
1013    /// If the route `name` does not exist in the API specification, or if the route already has a
1014    /// handler registered, an error is returned. Note that all routes are initialized with a
1015    /// default handler that echoes parameters and shows documentation, but this default handler can
1016    /// replaced by this function without raising [ApiError::HandlerAlreadyRegistered].
1017    ///
1018    /// If the route `name` exists, but the method is not DELETE (that is, `METHOD = "M"` was used
1019    /// in the route definition in `api.toml`, with `M` other than `DELETE`) the error
1020    /// [IncorrectMethod](ApiError::IncorrectMethod) is returned.
1021    ///
1022    /// # Limitations
1023    ///
1024    /// [Like many function parameters](crate#boxed-futures) in [tide_disco](crate), the
1025    /// handler function is required to return a [BoxFuture].
1026    pub fn delete<F, T>(&mut self, name: &str, handler: F) -> Result<&mut Self, ApiError>
1027    where
1028        F: 'static
1029            + Send
1030            + Sync
1031            + Fn(RequestParams, &mut <State as ReadState>::State) -> BoxFuture<'_, Result<T, Error>>,
1032        T: Serialize,
1033        State: 'static + Send + Sync + WriteState,
1034        VER: 'static + Send + Sync,
1035    {
1036        self.method_mutable(Method::delete(), name, handler)
1037    }
1038
1039    /// Register a handler for a SOCKET route.
1040    ///
1041    /// When the server receives any request whose URL matches the pattern for this route and which
1042    /// includes the WebSockets upgrade headers, the server will negotiate a protocol upgrade with
1043    /// the client, establishing a WebSockets connection, and then invoke `handler`. `handler` will
1044    /// be given the parameters of the request which initiated the connection and a reference to the
1045    /// application state, as well as a [Connection](socket::Connection) object which it can then
1046    /// use for asynchronous, bi-directional communication with the client.
1047    ///
1048    /// The server side of the connection will remain open as long as the future returned by
1049    /// `handler` is remains unresolved. The handler can terminate the connection by returning. If
1050    /// it returns an error, the error message will be included in the
1051    /// [CloseFrame](tide_websockets::tungstenite::protocol::CloseFrame) sent to the client when
1052    /// tearing down the connection.
1053    ///
1054    /// # Examples
1055    ///
1056    /// A socket endpoint which receives amounts from the client and returns a running sum.
1057    ///
1058    /// `api.toml`
1059    ///
1060    /// ```toml
1061    /// [route.sum]
1062    /// PATH = ["/sum"]
1063    /// METHOD = "SOCKET"
1064    /// DOC = "Stream a running sum."
1065    /// ```
1066    ///
1067    /// ```
1068    /// use futures::{FutureExt, SinkExt, StreamExt};
1069    /// use tide_disco::{error::ServerError, socket::Connection, Api};
1070    /// # use vbs::version::StaticVersion;
1071    ///
1072    /// # fn ex(api: &mut Api<(), ServerError, StaticVersion<0, 1>>) {
1073    /// api.socket("sum", |_req, mut conn: Connection<i32, i32, ServerError, StaticVersion<0, 1>>, _state| async move {
1074    ///     let mut sum = 0;
1075    ///     while let Some(amount) = conn.next().await {
1076    ///         sum += amount?;
1077    ///         conn.send(&sum).await?;
1078    ///     }
1079    ///     Ok(())
1080    /// }.boxed());
1081    /// # }
1082    /// ```
1083    ///
1084    /// In some cases, it may be desirable to handle messages to and from the client in separate
1085    /// tasks. There are two ways of doing this:
1086    ///
1087    /// ## Split the connection into separate stream and sink
1088    ///
1089    /// ```
1090    /// use async_std::task::spawn;
1091    /// use futures::{future::{join, FutureExt}, sink::SinkExt, stream::StreamExt};
1092    /// use tide_disco::{error::ServerError, socket::Connection, Api};
1093    /// # use vbs::version::StaticVersion;
1094    ///
1095    /// # fn ex(api: &mut Api<(), ServerError, StaticVersion<0, 1>>) {
1096    /// api.socket("endpoint", |_req, mut conn: Connection<i32, i32, ServerError, StaticVersion<0, 1>>, _state| async move {
1097    ///     let (mut sink, mut stream) = conn.split();
1098    ///     let recv = spawn(async move {
1099    ///         while let Some(Ok(msg)) = stream.next().await {
1100    ///             // Handle message from client.
1101    ///         }
1102    ///     });
1103    ///     let send = spawn(async move {
1104    ///         loop {
1105    ///             let msg = // get message to send to client
1106    /// #               0;
1107    ///             sink.send(msg).await;
1108    ///         }
1109    ///     });
1110    ///
1111    ///     join(send, recv).await;
1112    ///     Ok(())
1113    /// }.boxed());
1114    /// # }
1115    /// ```
1116    ///
1117    /// This approach requires messages to be sent to the client by value, consuming the message.
1118    /// This is because, if we were to use the `Sync<&ToClient>` implementation for `Connection`,
1119    /// the lifetime for `&ToClient` would be fixed after `split` is called, since the lifetime
1120    /// appears in the return type, `SplitSink<Connection<...>, &ToClient>`. Thus, this lifetime
1121    /// outlives any scoped local variables created after the `split` call, such as `msg` in the
1122    /// `loop`.
1123    ///
1124    /// If we want to use the message after sending it to the client, we would have to clone it,
1125    /// which may be inefficient or impossible. Thus, there is another approach:
1126    ///
1127    /// ## Clone the connection
1128    ///
1129    /// ```
1130    /// use async_std::task::spawn;
1131    /// use futures::{future::{join, FutureExt}, sink::SinkExt, stream::StreamExt};
1132    /// use tide_disco::{error::ServerError, socket::Connection, Api};
1133    /// # use vbs::version::StaticVersion;
1134    ///
1135    /// # fn ex(api: &mut Api<(), ServerError, StaticVersion<0, 1>>) {
1136    /// api.socket("endpoint", |_req, mut conn: Connection<i32, i32, ServerError, StaticVersion<0, 1>>, _state| async move {
1137    ///     let recv = {
1138    ///         let mut conn = conn.clone();
1139    ///         spawn(async move {
1140    ///             while let Some(Ok(msg)) = conn.next().await {
1141    ///                 // Handle message from client.
1142    ///             }
1143    ///         })
1144    ///     };
1145    ///     let send = spawn(async move {
1146    ///         loop {
1147    ///             let msg = // get message to send to client
1148    /// #               0;
1149    ///             conn.send(&msg).await;
1150    ///             // msg is still live at this point.
1151    ///             drop(msg);
1152    ///         }
1153    ///     });
1154    ///
1155    ///     join(send, recv).await;
1156    ///     Ok(())
1157    /// }.boxed());
1158    /// # }
1159    /// ```
1160    ///
1161    /// Depending on the exact situation, this method may end up being more verbose than the
1162    /// previous example. But it allows us to retain the higher-ranked trait bound `conn: for<'a>
1163    /// Sink<&'a ToClient>` instead of fixing the lifetime, which can prevent an unnecessary clone
1164    /// in certain situations.
1165    ///
1166    /// # Errors
1167    ///
1168    /// If the route `name` does not exist in the API specification, or if the route already has a
1169    /// handler registered, an error is returned. Note that all routes are initialized with a
1170    /// default handler that echoes parameters and shows documentation, but this default handler can
1171    /// replaced by this function without raising [ApiError::HandlerAlreadyRegistered].
1172    ///
1173    /// If the route `name` exists, but the method is not SOCKET (that is, `METHOD = "M"` was used
1174    /// in the route definition in `api.toml`, with `M` other than `SOCKET`) the error
1175    /// [IncorrectMethod](ApiError::IncorrectMethod) is returned.
1176    ///
1177    /// # Limitations
1178    ///
1179    /// [Like many function parameters](crate#boxed-futures) in [tide_disco](crate), the
1180    /// handler function is required to return a [BoxFuture].
1181    pub fn socket<F, ToClient, FromClient>(
1182        &mut self,
1183        name: &str,
1184        handler: F,
1185    ) -> Result<&mut Self, ApiError>
1186    where
1187        F: 'static
1188            + Send
1189            + Sync
1190            + Fn(
1191                RequestParams,
1192                socket::Connection<ToClient, FromClient, Error, VER>,
1193                &State,
1194            ) -> BoxFuture<'_, Result<(), Error>>,
1195        ToClient: 'static + Serialize + ?Sized,
1196        FromClient: 'static + DeserializeOwned,
1197        State: 'static + Send + Sync,
1198        Error: 'static + Send + Display,
1199    {
1200        self.register_socket_handler(name, socket::handler(handler))
1201    }
1202
1203    /// Register a uni-directional handler for a SOCKET route.
1204    ///
1205    /// This function is very similar to [socket](Self::socket), but it permits the handler only to
1206    /// send messages to the client, not to receive messages back. As such, the handler does not
1207    /// take a [Connection](socket::Connection). Instead, it simply returns a stream of messages
1208    /// which are forwarded to the client as they are generated. If the stream ever yields an error,
1209    /// the error is propagated to the client and then the connection is closed.
1210    ///
1211    /// This function can be simpler to use than [socket](Self::socket) in case the handler does not
1212    /// need to receive messages from the client.
1213    pub fn stream<F, Msg>(&mut self, name: &str, handler: F) -> Result<&mut Self, ApiError>
1214    where
1215        F: 'static + Send + Sync + Fn(RequestParams, &State) -> BoxStream<Result<Msg, Error>>,
1216        Msg: 'static + Serialize + Send + Sync,
1217        State: 'static + Send + Sync,
1218        Error: 'static + Send + Display,
1219        VER: 'static + Send + Sync,
1220    {
1221        self.register_socket_handler(name, socket::stream_handler::<_, _, _, _, VER>(handler))
1222    }
1223
1224    fn register_socket_handler(
1225        &mut self,
1226        name: &str,
1227        handler: socket::Handler<State, Error>,
1228    ) -> Result<&mut Self, ApiError> {
1229        let route = self
1230            .inner
1231            .routes
1232            .get_mut(name)
1233            .ok_or(ApiError::UndefinedRoute)?;
1234        if route.method() != Method::Socket {
1235            return Err(ApiError::IncorrectMethod {
1236                expected: Method::Socket,
1237                actual: route.method(),
1238            });
1239        }
1240        if route.has_handler() {
1241            return Err(ApiError::HandlerAlreadyRegistered);
1242        }
1243
1244        // `set_handler` only fails if the route is not a socket route; since we have already
1245        // checked that it is, this cannot fail.
1246        route
1247            .set_socket_handler(handler)
1248            .unwrap_or_else(|_| panic!("unexpected failure in set_socket_handler"));
1249        Ok(self)
1250    }
1251
1252    /// Register a handler for a METRICS route.
1253    ///
1254    /// When the server receives any request whose URL matches the pattern for this route and whose
1255    /// headers indicate it is a request for metrics, the server will invoke this `handler` instead
1256    /// of the regular HTTP handler for the endpoint. Instead of returning a typed object to
1257    /// serialize, `handler` will return a [Metrics] object which will be serialized to plaintext
1258    /// using the Prometheus format.
1259    ///
1260    /// A request is considered a request for metrics, for the purpose of dispatching to this
1261    /// handler, if the method is GET and the `Accept` header specifies `text/plain` as a better
1262    /// response type than `application/json` and `application/octet-stream` (other Tide Disco
1263    /// handlers respond to the content types `application/json` or `application/octet-stream`). As
1264    /// a special case, a request with no `Accept` header or `Accept: *` will return metrics when
1265    /// there is a metrics route matching the request URL, since metrics are given priority over
1266    /// other content types when multiple routes match the URL.
1267    ///
1268    /// # Examples
1269    ///
1270    /// A metrics endpoint which keeps track of how many times it has been called.
1271    ///
1272    /// `api.toml`
1273    ///
1274    /// ```toml
1275    /// [route.metrics]
1276    /// PATH = ["/metrics"]
1277    /// METHOD = "METRICS"
1278    /// DOC = "Export Prometheus metrics."
1279    /// ```
1280    ///
1281    /// ```
1282    /// # use async_std::sync::Mutex;
1283    /// # use futures::FutureExt;
1284    /// # use tide_disco::{api::{Api, ApiError}, error::ServerError};
1285    /// # use std::borrow::Cow;
1286    /// # use vbs::version::StaticVersion;
1287    /// use prometheus::{Counter, Registry};
1288    ///
1289    /// struct State {
1290    ///     counter: Counter,
1291    ///     metrics: Registry,
1292    /// }
1293    /// type StaticVer01 = StaticVersion<0, 1>;
1294    ///
1295    /// # fn ex(_api: Api<Mutex<State>, ServerError, StaticVer01>) -> Result<(), ApiError> {
1296    /// let mut api: Api<Mutex<State>, ServerError, StaticVer01>;
1297    /// # api = _api;
1298    /// api.metrics("metrics", |_req, state| async move {
1299    ///     state.counter.inc();
1300    ///     Ok(Cow::Borrowed(&state.metrics))
1301    /// }.boxed())?;
1302    /// # Ok(())
1303    /// # }
1304    /// ```
1305    //
1306    /// # Errors
1307    ///
1308    /// If the route `name` does not exist in the API specification, or if the route already has a
1309    /// handler registered, an error is returned. Note that all routes are initialized with a
1310    /// default handler that echoes parameters and shows documentation, but this default handler can
1311    /// replaced by this function without raising [ApiError::HandlerAlreadyRegistered].
1312    ///
1313    /// If the route `name` exists, but the method is not METRICS (that is, `METHOD = "M"` was used
1314    /// in the route definition in `api.toml`, with `M` other than `METRICS`) the error
1315    /// [IncorrectMethod](ApiError::IncorrectMethod) is returned.
1316    ///
1317    /// # Limitations
1318    ///
1319    /// [Like many function parameters](crate#boxed-futures) in [tide_disco](crate), the
1320    /// handler function is required to return a [BoxFuture].
1321    pub fn metrics<F, T>(&mut self, name: &str, handler: F) -> Result<&mut Self, ApiError>
1322    where
1323        F: 'static
1324            + Send
1325            + Sync
1326            + Fn(RequestParams, &State::State) -> BoxFuture<Result<Cow<T>, Error>>,
1327        T: 'static + Clone + Metrics,
1328        State: 'static + Send + Sync + ReadState,
1329        Error: 'static,
1330        VER: 'static + Send + Sync,
1331    {
1332        let route = self
1333            .inner
1334            .routes
1335            .get_mut(name)
1336            .ok_or(ApiError::UndefinedRoute)?;
1337        if route.method() != Method::Metrics {
1338            return Err(ApiError::IncorrectMethod {
1339                expected: Method::Metrics,
1340                actual: route.method(),
1341            });
1342        }
1343        if route.has_handler() {
1344            return Err(ApiError::HandlerAlreadyRegistered);
1345        }
1346        // `set_metrics_handler` only fails if the route is not a metrics route; since we have
1347        // already checked that it is, this cannot fail.
1348        route
1349            .set_metrics_handler(handler)
1350            .unwrap_or_else(|_| panic!("unexpected failure in set_metrics_handler"));
1351        Ok(self)
1352    }
1353
1354    /// Set the health check handler for this API.
1355    ///
1356    /// This overrides the existing handler. If `health_check` has not yet been called, the default
1357    /// handler is one which simply returns `Health::default()`.
1358    pub fn with_health_check<H>(
1359        &mut self,
1360        handler: impl 'static + Send + Sync + Fn(&State) -> BoxFuture<H>,
1361    ) -> &mut Self
1362    where
1363        State: 'static + Send + Sync,
1364        H: 'static + HealthCheck,
1365        VER: 'static + Send + Sync,
1366    {
1367        self.inner.health_check = route::health_check_handler::<_, _, VER>(handler);
1368        self
1369    }
1370
1371    /// Create a new [Api] which is just like this one, except has a transformed `Error` type.
1372    pub(crate) fn map_err<Error2>(
1373        self,
1374        f: impl 'static + Clone + Send + Sync + Fn(Error) -> Error2,
1375    ) -> Api<State, Error2, VER>
1376    where
1377        Error: 'static + Send + Sync,
1378        Error2: 'static,
1379        State: 'static + Send + Sync,
1380    {
1381        Api {
1382            inner: ApiInner {
1383                meta: self.inner.meta,
1384                name: self.inner.name,
1385                routes: self
1386                    .inner
1387                    .routes
1388                    .into_iter()
1389                    .map(|(name, route)| (name, route.map_err(f.clone())))
1390                    .collect(),
1391                routes_by_path: self.inner.routes_by_path,
1392                health_check: self.inner.health_check,
1393                api_version: self.inner.api_version,
1394                error_handler: None,
1395                version_handler: self.inner.version_handler,
1396                public: self.inner.public,
1397                short_description: self.inner.short_description,
1398                long_description: self.inner.long_description,
1399            },
1400            _version: Default::default(),
1401        }
1402    }
1403
1404    pub(crate) fn into_inner(mut self) -> ApiInner<State, Error>
1405    where
1406        Error: crate::Error,
1407    {
1408        // This `into_inner` finalizes the error type for the API. At this point, ensure
1409        // `error_handler` is set.
1410        self.inner.error_handler = Some(error_handler::<Error, VER>());
1411        self.inner
1412    }
1413
1414    fn default_health_check(req: RequestParams, _state: &State) -> BoxFuture<'_, tide::Response> {
1415        async move {
1416            // If there is no healthcheck handler registered, just return [HealthStatus::Available]
1417            // by default; after all, if this handler is getting hit at all, the service must be up.
1418            route::health_check_response::<_, VER>(
1419                &req.accept().unwrap_or_else(|_| {
1420                    // The healthcheck endpoint is not allowed to fail, so just use the default
1421                    // content type if we can't parse the Accept header.
1422                    let mut accept = Accept::new();
1423                    accept.set_wildcard(true);
1424                    accept
1425                }),
1426                HealthStatus::Available,
1427            )
1428        }
1429        .boxed()
1430    }
1431}
1432
// `ReadHandler { handler }` essentially represents a handler function
// `move |req, state| async { state.read(|state| handler(req, state)).await.await }`. However, I
// cannot convince Rust that the future returned by this closure moves out of `req` while borrowing
// from `handler`, which is owned by the closure itself and thus outlives the closure body. This is
// partly due to the limitation where _all_ closure parameters must be captured either by value or
// by reference, and probably partly due to my lack of creativity. In any case, writing out the
// closure object and [Handler] implementation by hand seems to convince Rust that this code is
// memory safe.
struct ReadHandler<F, VER> {
    /// The user-supplied handler, called with shared access to the state.
    handler: F,
    /// Marker for the version type `VER` used when serializing responses; stores no data.
    _version: PhantomData<VER>,
}

impl<F, VER> From<F> for ReadHandler<F, VER> {
    fn from(handler: F) -> Self {
        Self {
            handler,
            _version: PhantomData,
        }
    }
}
1454
1455#[async_trait]
1456impl<State, Error, F, R, VER> Handler<State, Error> for ReadHandler<F, VER>
1457where
1458    F: 'static
1459        + Send
1460        + Sync
1461        + Fn(RequestParams, &<State as ReadState>::State) -> BoxFuture<'_, Result<R, Error>>,
1462    R: Serialize,
1463    State: 'static + Send + Sync + ReadState,
1464    VER: 'static + Send + Sync + StaticVersionType,
1465{
1466    async fn handle(
1467        &self,
1468        req: RequestParams,
1469        state: &State,
1470    ) -> Result<tide::Response, RouteError<Error>> {
1471        let accept = req.accept()?;
1472        response_from_result(
1473            &accept,
1474            state.read(|state| (self.handler)(req, state)).await,
1475            VER::instance(),
1476        )
1477    }
1478}
1479
// A manual closure that serves a similar purpose as [ReadHandler].
struct WriteHandler<F, VER> {
    /// The user-supplied handler, called with exclusive, mutable access to the state.
    handler: F,
    /// Marker for the version type `VER` used when serializing responses; stores no data.
    _version: PhantomData<VER>,
}

impl<F, VER> From<F> for WriteHandler<F, VER> {
    fn from(handler: F) -> Self {
        Self {
            handler,
            _version: PhantomData,
        }
    }
}
1494
1495#[async_trait]
1496impl<State, Error, F, R, VER> Handler<State, Error> for WriteHandler<F, VER>
1497where
1498    F: 'static
1499        + Send
1500        + Sync
1501        + Fn(RequestParams, &mut <State as ReadState>::State) -> BoxFuture<'_, Result<R, Error>>,
1502    R: Serialize,
1503    State: 'static + Send + Sync + WriteState,
1504    VER: 'static + Send + Sync + StaticVersionType,
1505{
1506    async fn handle(
1507        &self,
1508        req: RequestParams,
1509        state: &State,
1510    ) -> Result<tide::Response, RouteError<Error>> {
1511        let accept = req.accept()?;
1512        response_from_result(
1513            &accept,
1514            state.write(|state| (self.handler)(req, state)).await,
1515            VER::instance(),
1516        )
1517    }
1518}
1519
1520#[cfg(test)]
1521mod test {
1522    use crate::{
1523        error::{Error, ServerError},
1524        healthcheck::HealthStatus,
1525        socket::Connection,
1526        testing::{setup_test, test_ws_client, test_ws_client_with_headers, Client},
1527        App, StatusCode, Url,
1528    };
1529    use async_std::{sync::RwLock, task::spawn};
1530    use async_tungstenite::{
1531        tungstenite::{http::header::*, protocol::frame::coding::CloseCode, protocol::Message},
1532        WebSocketStream,
1533    };
1534    use futures::{
1535        stream::{iter, once, repeat},
1536        AsyncRead, AsyncWrite, FutureExt, SinkExt, StreamExt,
1537    };
1538    use portpicker::pick_unused_port;
1539    use prometheus::{Counter, Registry};
1540    use std::borrow::Cow;
1541    use toml::toml;
1542    use vbs::{
1543        version::{StaticVersion, StaticVersionType},
1544        BinarySerializer, Serializer,
1545    };
1546
1547    #[cfg(windows)]
1548    use async_tungstenite::tungstenite::Error as WsError;
1549    #[cfg(windows)]
1550    use std::io::ErrorKind;
1551
    // Static API version 0.1, shared by the tests in this module.
    type StaticVer01 = StaticVersion<0, 1>;
    // Serializer bound to version 0.1 (presumably for encoding/decoding binary payloads in tests).
    type SerializerV01 = Serializer<StaticVersion<0, 1>>;
1554
1555    async fn check_stream_closed<S>(mut conn: WebSocketStream<S>)
1556    where
1557        S: AsyncRead + AsyncWrite + Unpin,
1558    {
1559        let msg = conn.next().await;
1560
1561        #[cfg(not(windows))]
1562        assert!(msg.is_none(), "{:?}", msg);
1563
1564        // Windows doesn't handle shutdown very gracefully.
1565        #[cfg(windows)]
1566        match msg {
1567            None => {}
1568            Some(Err(WsError::Io(err))) if err.kind() == ErrorKind::ConnectionAborted => {}
1569            msg => panic!(
1570                "expected end of stream or ConnectionAborted error, got {:?}",
1571                msg
1572            ),
1573        }
1574    }
1575
1576    #[async_std::test]
1577    async fn test_socket_endpoint() {
1578        setup_test();
1579
1580        let mut app = App::<_, ServerError>::with_state(RwLock::new(()));
1581        let api_toml = toml! {
1582            [meta]
1583            FORMAT_VERSION = "0.1.0"
1584
1585            [route.echo]
1586            PATH = ["/echo"]
1587            METHOD = "SOCKET"
1588
1589            [route.once]
1590            PATH = ["/once"]
1591            METHOD = "SOCKET"
1592
1593            [route.error]
1594            PATH = ["/error"]
1595            METHOD = "SOCKET"
1596        };
1597        {
1598            let mut api = app
1599                .module::<ServerError, StaticVer01>("mod", api_toml)
1600                .unwrap();
1601            api.socket(
1602                "echo",
1603                |_req, mut conn: Connection<String, String, _, StaticVer01>, _state| {
1604                    async move {
1605                        while let Some(msg) = conn.next().await {
1606                            conn.send(&msg?).await?;
1607                        }
1608                        Ok(())
1609                    }
1610                    .boxed()
1611                },
1612            )
1613            .unwrap()
1614            .socket(
1615                "once",
1616                |_req, mut conn: Connection<str, (), _, StaticVer01>, _state| {
1617                    async move {
1618                        conn.send("msg").boxed().await?;
1619                        Ok(())
1620                    }
1621                    .boxed()
1622                },
1623            )
1624            .unwrap()
1625            .socket(
1626                "error",
1627                |_req, _conn: Connection<(), (), _, StaticVer01>, _state| {
1628                    async move {
1629                        Err(ServerError::catch_all(
1630                            StatusCode::INTERNAL_SERVER_ERROR,
1631                            "an error message".to_string(),
1632                        ))
1633                    }
1634                    .boxed()
1635                },
1636            )
1637            .unwrap();
1638        }
1639        let port = pick_unused_port().unwrap();
1640        let url: Url = format!("http://localhost:{}", port).parse().unwrap();
1641        spawn(app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance()));
1642
1643        // Create a client that accepts JSON messages.
1644        let mut conn = test_ws_client_with_headers(
1645            url.join("mod/echo").unwrap(),
1646            &[(ACCEPT, "application/json")],
1647        )
1648        .await;
1649
1650        // Send a JSON message.
1651        conn.send(Message::Text(serde_json::to_string("hello").unwrap()))
1652            .await
1653            .unwrap();
1654        assert_eq!(
1655            conn.next().await.unwrap().unwrap(),
1656            Message::Text(serde_json::to_string("hello").unwrap())
1657        );
1658
1659        // Send a binary message.
1660        conn.send(Message::Binary(
1661            SerializerV01::serialize("goodbye").unwrap(),
1662        ))
1663        .await
1664        .unwrap();
1665        assert_eq!(
1666            conn.next().await.unwrap().unwrap(),
1667            Message::Text(serde_json::to_string("goodbye").unwrap())
1668        );
1669
1670        // Create a client that accepts binary messages.
1671        let mut conn = test_ws_client_with_headers(
1672            url.join("mod/echo").unwrap(),
1673            &[(ACCEPT, "application/octet-stream")],
1674        )
1675        .await;
1676
1677        // Send a JSON message.
1678        conn.send(Message::Text(serde_json::to_string("hello").unwrap()))
1679            .await
1680            .unwrap();
1681        assert_eq!(
1682            conn.next().await.unwrap().unwrap(),
1683            Message::Binary(SerializerV01::serialize("hello").unwrap())
1684        );
1685
1686        // Send a binary message.
1687        conn.send(Message::Binary(
1688            SerializerV01::serialize("goodbye").unwrap(),
1689        ))
1690        .await
1691        .unwrap();
1692        assert_eq!(
1693            conn.next().await.unwrap().unwrap(),
1694            Message::Binary(SerializerV01::serialize("goodbye").unwrap())
1695        );
1696
1697        // Test a stream that exits normally.
1698        let mut conn = test_ws_client(url.join("mod/once").unwrap()).await;
1699        assert_eq!(
1700            conn.next().await.unwrap().unwrap(),
1701            Message::Text(serde_json::to_string("msg").unwrap())
1702        );
1703        match conn.next().await.unwrap().unwrap() {
1704            Message::Close(None) => {}
1705            msg => panic!("expected normal close frame, got {:?}", msg),
1706        };
1707        check_stream_closed(conn).await;
1708
1709        // Test a stream that errors.
1710        let mut conn = test_ws_client(url.join("mod/error").unwrap()).await;
1711        match conn.next().await.unwrap().unwrap() {
1712            Message::Close(Some(frame)) => {
1713                assert_eq!(frame.code, CloseCode::Error);
1714                assert_eq!(frame.reason, "Error 500: an error message");
1715            }
1716            msg => panic!("expected error close frame, got {:?}", msg),
1717        }
1718        check_stream_closed(conn).await;
1719    }
1720
1721    #[async_std::test]
1722    async fn test_stream_endpoint() {
1723        setup_test();
1724
1725        let mut app = App::<_, ServerError>::with_state(RwLock::new(()));
1726        let api_toml = toml! {
1727            [meta]
1728            FORMAT_VERSION = "0.1.0"
1729
1730            [route.nat]
1731            PATH = ["/nat"]
1732            METHOD = "SOCKET"
1733
1734            [route.once]
1735            PATH = ["/once"]
1736            METHOD = "SOCKET"
1737
1738            [route.error]
1739            PATH = ["/error"]
1740            METHOD = "SOCKET"
1741        };
1742        {
1743            let mut api = app
1744                .module::<ServerError, StaticVer01>("mod", api_toml)
1745                .unwrap();
1746            api.stream("nat", |_req, _state| iter(0..).map(Ok).boxed())
1747                .unwrap()
1748                .stream("once", |_req, _state| once(async { Ok(0) }).boxed())
1749                .unwrap()
1750                .stream::<_, ()>("error", |_req, _state| {
1751                    // We intentionally return a stream that never terminates, to check that simply
1752                    // yielding an error causes the connection to terminate.
1753                    repeat(Err(ServerError::catch_all(
1754                        StatusCode::INTERNAL_SERVER_ERROR,
1755                        "an error message".to_string(),
1756                    )))
1757                    .boxed()
1758                })
1759                .unwrap();
1760        }
1761        let port = pick_unused_port().unwrap();
1762        let url: Url = format!("http://localhost:{}", port).parse().unwrap();
1763        spawn(app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance()));
1764
1765        // Consume the `nat` stream.
1766        let mut conn = test_ws_client(url.join("mod/nat").unwrap()).await;
1767        for i in 0..100 {
1768            assert_eq!(
1769                conn.next().await.unwrap().unwrap(),
1770                Message::Text(serde_json::to_string(&i).unwrap())
1771            );
1772        }
1773
1774        // Test a finite stream.
1775        let mut conn = test_ws_client(url.join("mod/once").unwrap()).await;
1776        assert_eq!(
1777            conn.next().await.unwrap().unwrap(),
1778            Message::Text(serde_json::to_string(&0).unwrap())
1779        );
1780        match conn.next().await.unwrap().unwrap() {
1781            Message::Close(None) => {}
1782            msg => panic!("expected normal close frame, got {:?}", msg),
1783        }
1784        check_stream_closed(conn).await;
1785
1786        // Test a stream that errors.
1787        let mut conn = test_ws_client(url.join("mod/error").unwrap()).await;
1788        match conn.next().await.unwrap().unwrap() {
1789            Message::Close(Some(frame)) => {
1790                assert_eq!(frame.code, CloseCode::Error);
1791                assert_eq!(frame.reason, "Error 500: an error message");
1792            }
1793            msg => panic!("expected error close frame, got {:?}", msg),
1794        }
1795        check_stream_closed(conn).await;
1796    }
1797
1798    #[async_std::test]
1799    async fn test_custom_healthcheck() {
1800        setup_test();
1801
1802        let mut app = App::<_, ServerError>::with_state(HealthStatus::Available);
1803        let api_toml = toml! {
1804            [meta]
1805            FORMAT_VERSION = "0.1.0"
1806
1807            [route.dummy]
1808            PATH = ["/dummy"]
1809        };
1810        {
1811            let mut api = app
1812                .module::<ServerError, StaticVer01>("mod", api_toml)
1813                .unwrap();
1814            api.with_health_check(|state| async move { *state }.boxed());
1815        }
1816        let port = pick_unused_port().unwrap();
1817        let url: Url = format!("http://localhost:{}", port).parse().unwrap();
1818        spawn(app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance()));
1819        let client = Client::new(url).await;
1820
1821        let res = client.get("/mod/healthcheck").send().await.unwrap();
1822        assert_eq!(res.status(), StatusCode::OK);
1823        assert_eq!(
1824            res.json::<HealthStatus>().await.unwrap(),
1825            HealthStatus::Available
1826        );
1827    }
1828
1829    #[async_std::test]
1830    async fn test_metrics_endpoint() {
1831        setup_test();
1832
1833        struct State {
1834            metrics: Registry,
1835            counter: Counter,
1836        }
1837
1838        let counter = Counter::new(
1839            "counter",
1840            "count of how many times metrics have been exported",
1841        )
1842        .unwrap();
1843        let metrics = Registry::new();
1844        metrics.register(Box::new(counter.clone())).unwrap();
1845        let state = State { metrics, counter };
1846
1847        let mut app = App::<_, ServerError>::with_state(RwLock::new(state));
1848        let api_toml = toml! {
1849            [meta]
1850            FORMAT_VERSION = "0.1.0"
1851
1852            [route.metrics]
1853            PATH = ["/metrics"]
1854            METHOD = "METRICS"
1855        };
1856        {
1857            let mut api = app
1858                .module::<ServerError, StaticVer01>("mod", api_toml)
1859                .unwrap();
1860            api.metrics("metrics", |_req, state| {
1861                async move {
1862                    state.counter.inc();
1863                    Ok(Cow::Borrowed(&state.metrics))
1864                }
1865                .boxed()
1866            })
1867            .unwrap();
1868        }
1869        let port = pick_unused_port().unwrap();
1870        let url: Url = format!("http://localhost:{port}").parse().unwrap();
1871        spawn(app.serve(format!("0.0.0.0:{port}"), StaticVer01::instance()));
1872        let client = Client::new(url).await;
1873
1874        for i in 1..5 {
1875            tracing::info!("making metrics request {i}");
1876            let expected = format!("# HELP counter count of how many times metrics have been exported\n# TYPE counter counter\ncounter {i}\n");
1877            let res = client.get("mod/metrics").send().await.unwrap();
1878            assert_eq!(res.status(), StatusCode::OK);
1879            assert_eq!(res.text().await.unwrap(), expected);
1880        }
1881    }
1882}