tide_disco/
app.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the tide-disco library.
3
4// You should have received a copy of the MIT License
5// along with the tide-disco library. If not, see <https://mit-license.org/>.
6
7use crate::{
8    api::{Api, ApiError, ApiInner, ApiVersion},
9    dispatch::{self, DispatchError, Trie},
10    healthcheck::{HealthCheck, HealthStatus},
11    http,
12    method::Method,
13    middleware::{request_params, AddErrorBody, MetricsMiddleware},
14    request::RequestParams,
15    route::{health_check_response, respond_with, Handler, Route, RouteError},
16    socket::SocketError,
17    Html, StatusCode,
18};
19use async_std::sync::Arc;
20use derive_more::From;
21use futures::future::{BoxFuture, FutureExt};
22use include_dir::{include_dir, Dir};
23use lazy_static::lazy_static;
24use maud::{html, PreEscaped};
25use rand::Rng;
26use semver::Version;
27use serde::{Deserialize, Serialize};
28use serde_with::{serde_as, DisplayFromStr};
29use snafu::{ResultExt, Snafu};
30use std::{
31    collections::btree_map::BTreeMap,
32    convert::Infallible,
33    env, fs, io,
34    ops::{Deref, DerefMut},
35    path::PathBuf,
36};
37use tide::{
38    http::{headers::HeaderValue, mime::HTML},
39    security::{CorsMiddleware, Origin},
40};
41use tide_websockets::WebSocket;
42use vbs::version::StaticVersionType;
43
44pub use tide::listener::{Listener, ToListener};
45
/// A tide-disco server application.
///
/// An [App] is a collection of API modules, plus a global `State`. Modules can be registered by
/// constructing an [Api] for each module and calling [App::register_module]. Once all of the
/// desired modules are registered, the app can be converted into an asynchronous server task using
/// [App::serve].
///
/// Note that the binary serialization version `VER` is not a parameter of the [`App`] type itself;
/// it is chosen when the app is served, as a type parameter of [App::serve]. This format only
/// applies to application-level endpoints like `/version` and `/healthcheck`. The binary format
/// version in use by any given API module may differ, depending on the supported version of the
/// API.
#[derive(Debug)]
pub struct App<State, Error> {
    /// The registered API modules, keyed by their URL prefix and, within each prefix, by major
    /// version.
    pub(crate) modules: Trie<ApiInner<State, Error>>,
    /// The global application state handed to route and healthcheck handlers.
    pub(crate) state: Arc<State>,
    /// The optional application-level version, set via [App::with_version].
    app_version: Option<Version>,
}
62
/// An error encountered while building an [App].
///
/// Each variant wraps the underlying error as its `source`.
#[derive(Clone, Debug, From, Snafu, PartialEq, Eq)]
pub enum AppError {
    // NOTE: the variant notes below are deliberately plain `//` comments rather than `///` doc
    // comments, so that they cannot influence the `Display` output derived by `Snafu`.

    // An API specification could not be turned into an `Api` (see `App::module`).
    Api { source: ApiError },
    // A module could not be inserted into the dispatch trie (see `App::register_module`).
    Dispatch { source: DispatchError },
}
69
70impl<State: Send + Sync + 'static, Error: 'static> App<State, Error> {
71    /// Create a new [App] with a given state.
72    pub fn with_state(state: State) -> Self {
73        Self {
74            modules: Default::default(),
75            state: Arc::new(state),
76            app_version: None,
77        }
78    }
79
80    /// Create and register an API module.
81    ///
82    /// Creates a new [`Api`] with the given `api` specification and returns an RAII guard for this
83    /// API. The guard can be used to access the API module, configure it, and populate its
84    /// handlers. When [`Module::register`] is called on the guard (or the guard is dropped), the
85    /// module will be registered in this [`App`] as if by calling
86    /// [`register_module`](Self::register_module).
87    pub fn module<'a, ModuleError, ModuleVersion>(
88        &'a mut self,
89        base_url: &'a str,
90        api: impl Into<toml::Value>,
91    ) -> Result<Module<'a, State, Error, ModuleError, ModuleVersion>, AppError>
92    where
93        Error: crate::Error + From<ModuleError>,
94        ModuleError: Send + Sync + 'static,
95        ModuleVersion: StaticVersionType + 'static,
96    {
97        Ok(Module {
98            app: self,
99            base_url,
100            api: Some(Api::new(api).context(ApiSnafu)?),
101        })
102    }
103
104    /// Register an API module.
105    ///
106    /// The module `api` will be registered as an implementation of the module hosted under the URL
107    /// prefix `base_url`.
108    ///
109    /// # Versioning
110    ///
111    /// Multiple versions of the same [`Api`] may be registered by calling this function several
112    /// times with the same `base_url`, and passing in different APIs which must have different
113    /// _major_ versions. The API version can be set using [`Api::with_version`].
114    ///
115    /// When multiple versions of the same API are registered, requests for endpoints directly under
116    /// the base URL, like `GET /base_url/endpoint`, will always be dispatched to the latest
117    /// available version of the API. There will in addition be an extension of `base_url` for each
118    /// major version registered, so `GET /base_url/v1/endpoint` will always dispatch to the
119    /// `endpoint` handler in the module with major version 1, if it exists, regardless of what the
120    /// latest version is.
121    ///
122    /// It is an error to register multiple versions of the same module with the same major version.
123    /// It is _not_ an error to register non-sequential versions of a module. For example, you could
124    /// have `/base_url/v2` and `/base_url/v4`, but not `v1` or `v3`. Requests for `v1` or `v3` will
125    /// simply fail.
126    ///
127    /// The intention of this functionality is to allow for non-disruptive breaking updates. Rather
128    /// than deploying a new major version of the API with breaking changes _in place of_ the old
129    /// version, breaking all your clients, you can continue to serve the old version for some
130    /// period of time under a version prefix. Clients can point at this version prefix until they
131    /// update their software to use the new version, on their own time.
132    ///
133    /// Note that non-breaking changes (e.g. new endpoints) can be deployed in place of an existing
134    /// API without even incrementing the major version. The need for serving two versions of an API
135    /// simultaneously only arises when you have breaking changes.
136    pub fn register_module<ModuleError, ModuleVersion>(
137        &mut self,
138        base_url: &str,
139        api: Api<State, ModuleError, ModuleVersion>,
140    ) -> Result<&mut Self, AppError>
141    where
142        Error: crate::Error + From<ModuleError>,
143        ModuleError: Send + Sync + 'static,
144        ModuleVersion: StaticVersionType + 'static,
145    {
146        let mut api = api.map_err(Error::from).into_inner();
147        api.set_name(base_url.to_string());
148
149        let major_version = match api.version().api_version {
150            Some(version) => version.major,
151            None => {
152                // If no version is explicitly specified, default to 0.
153                0
154            }
155        };
156
157        self.modules
158            .insert(dispatch::split(base_url), major_version, api)?;
159        Ok(self)
160    }
161
162    /// Set the application version.
163    ///
164    /// The version information will automatically be included in responses to `GET /version`.
165    ///
166    /// This is the version of the overall application, which may encompass several APIs, each with
167    /// their own version. Changes to the version of any of the APIs which make up this application
168    /// should imply a change to the application version, but the application version may also
169    /// change without changing any of the API versions.
170    ///
171    /// This version is optional, as the `/version` endpoint will automatically include the version
172    /// of each registered API, which is usually enough to uniquely identify the application. Set
173    /// this explicitly if you want to track the version of additional behavior or interfaces which
174    /// are not encompassed by the sub-modules of this application.
175    ///
176    /// If you set an application version, it is a good idea to use the version of the application
177    /// crate found in Cargo.toml. This can be automatically found at build time using the
178    /// environment variable `CARGO_PKG_VERSION` and the [env!] macro. As long as the following code
179    /// is contained in the application crate, it should result in a reasonable version:
180    ///
181    /// ```
182    /// # use vbs::version::StaticVersion;
183    /// # type StaticVer01 = StaticVersion<0, 1>;
184    /// # fn ex(app: &mut tide_disco::App<(), ()>) {
185    /// app.with_version(env!("CARGO_PKG_VERSION").parse().unwrap());
186    /// # }
187    /// ```
188    pub fn with_version(&mut self, version: Version) -> &mut Self {
189        self.app_version = Some(version);
190        self
191    }
192
193    /// Get the version of this application.
194    pub fn version(&self) -> AppVersion {
195        AppVersion {
196            app_version: self.app_version.clone(),
197            disco_version: env!("CARGO_PKG_VERSION").parse().unwrap(),
198            modules: self
199                .modules
200                .iter()
201                .map(|module| {
202                    (
203                        module.path(),
204                        module
205                            .versions
206                            .values()
207                            .rev()
208                            .map(|api| api.version())
209                            .collect(),
210                    )
211                })
212                .collect(),
213        }
214    }
215
216    /// Check the health of each registered module in response to a request.
217    ///
218    /// The response includes a status code for each module, which will be [StatusCode::OK] if the
219    /// module is healthy. Detailed health status from each module is not included in the response
220    /// (due to type erasure) but can be queried using [module_health](Self::module_health) or by
221    /// hitting the endpoint `GET /:module/healthcheck`.
222    pub async fn health(&self, req: RequestParams, state: &State) -> AppHealth {
223        let mut modules_health = BTreeMap::<String, BTreeMap<_, _>>::new();
224        let mut status = HealthStatus::Available;
225        for module in &self.modules {
226            let versions_health = modules_health.entry(module.path()).or_default();
227            for (version, api) in &module.versions {
228                let health = StatusCode::from(api.health(req.clone(), state).await.status());
229                if health != StatusCode::OK {
230                    status = HealthStatus::Unhealthy;
231                }
232                versions_health.insert(*version, health);
233            }
234        }
235        AppHealth {
236            status,
237            modules: modules_health,
238        }
239    }
240
241    /// Check the health of the named module.
242    ///
243    /// The resulting [Response](tide::Response) has a status code which is [StatusCode::OK] if the
244    /// module is healthy. The response body is constructed from the results of the module's
245    /// registered healthcheck handler. If the module does not have an explicit healthcheck
246    /// handler, the response will be a [HealthStatus].
247    ///
248    /// `major_version` can be used to query the health status of a specific version of the desired
249    /// module. If it is not provided, the most recent supported version will be queried.
250    ///
251    /// If there is no module with the given name or version, returns [None].
252    pub async fn module_health(
253        &self,
254        req: RequestParams,
255        state: &State,
256        module: &str,
257        major_version: Option<u64>,
258    ) -> Option<tide::Response> {
259        let module = self.modules.get(dispatch::split(module))?;
260        let api = match major_version {
261            Some(v) => module.versions.get(&v)?,
262            None => module.versions.last_key_value()?.1,
263        };
264        Some(api.health(req, state).await)
265    }
266}
267
268static DEFAULT_PUBLIC_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/public/media");
269lazy_static! {
270    static ref DEFAULT_PUBLIC_PATH: PathBuf = {
271        // Generate a random number to index into `/tmp` with
272        let mut rng = rand::thread_rng();
273        let index: u64 = rng.gen();
274
275        // The contents of the default public directory are included in the binary. The first time
276        // the default directory is used, if ever, we extract them to a directory on the host file
277        // system and return the path to that directory.
278        let path = PathBuf::from(format!("/tmp/tide-disco/{}/public/media", index));
279        // If the path already exists, move it aside so we can update it.
280        let _ = fs::rename(&path, path.with_extension("old"));
281        DEFAULT_PUBLIC_DIR.extract(&path).unwrap();
282        path
283    };
284}
285
286impl<State, Error> App<State, Error>
287where
288    State: Send + Sync + 'static,
289    Error: 'static + crate::Error,
290{
291    /// Serve the [App] asynchronously.
292    ///
293    /// `VER` controls the binary format version used for responses to top-level endpoints like
294    /// `/version` and `/healthcheck`. All endpoints for specific API modules will use the format
295    /// version of that module (`ModuleVersion` when the module was
296    /// [registered](Self::register_module)).
297    pub async fn serve<L, VER>(self, listener: L, bind_version: VER) -> io::Result<()>
298    where
299        L: ToListener<Arc<Self>>,
300        VER: StaticVersionType + 'static,
301    {
302        let state = Arc::new(self);
303        let mut server = tide::Server::with_state(state.clone());
304        server.with(Self::version_middleware);
305        server.with(AddErrorBody::<Error>::with_version::<VER>());
306        server.with(
307            CorsMiddleware::new()
308                .allow_methods("GET, POST".parse::<HeaderValue>().unwrap())
309                .allow_headers("*".parse::<HeaderValue>().unwrap())
310                .allow_origin(Origin::from("*"))
311                .allow_credentials(true),
312        );
313
314        for module in &state.modules {
315            Self::register_api(&mut server, module.prefix.clone(), &module.versions)?;
316        }
317
318        // Register app-level routes summarizing the status and documentation of all the registered
319        // modules. We skip this step if this is a singleton app with only one module registered at
320        // the root URL, as these app-level endpoints would conflict with the (probably more
321        // specific) API-level status endpoints.
322        if !state.modules.is_singleton() {
323            // Register app-level automatic routes: `healthcheck` and `version`.
324            server
325                .at("healthcheck")
326                .get(move |req: tide::Request<Arc<Self>>| async move {
327                    let state = req.state().clone();
328                    let app_state = &*state.state;
329                    let req = request_params(req, &[]).await?;
330                    let accept = req.accept()?;
331                    let res = state.health(req, app_state).await;
332                    Ok(health_check_response::<_, VER>(&accept, res))
333                });
334            server
335                .at("version")
336                .get(move |req: tide::Request<Arc<Self>>| async move {
337                    let accept = RequestParams::accept_from_headers(&req)?;
338                    respond_with(&accept, req.state().version(), bind_version)
339                        .map_err(|err| Error::from_route_error::<Infallible>(err).into_tide_error())
340                });
341
342            // Serve documentation at the root URL for discoverability
343            server
344                .at("/")
345                .all(move |req: tide::Request<Arc<Self>>| async move {
346                    Ok(tide::Response::from(Self::top_level_docs(req)))
347                });
348        }
349
350        server.listen(listener).await
351    }
352
353    fn list_apis(&self) -> Html {
354        html! {
355            ul {
356                @for module in &self.modules {
357                    li {
358                        // Link to the alias for the latest version as the primary link.
359                        a href=(format!("/{}", module.path())) {(module.path())}
360                        // Add a superscript link (link a footnote) for each specific supported
361                        // version, linking to documentation for that specific version.
362                        @for version in module.versions.keys().rev() {
363                            sup {
364                                a href=(format!("/v{version}/{}", module.path())) {
365                                    (format!("[v{version}]"))
366                                }
367                            }
368                        }
369                        " "
370                        // Take the description of the latest supported version.
371                        (PreEscaped(module.versions.last_key_value().unwrap().1.short_description()))
372                    }
373                }
374            }
375        }
376    }
377
378    fn register_api(
379        server: &mut tide::Server<Arc<Self>>,
380        prefix: Vec<String>,
381        versions: &BTreeMap<u64, ApiInner<State, Error>>,
382    ) -> io::Result<()> {
383        for (version, api) in versions {
384            Self::register_api_version(server, &prefix, *version, api)?;
385        }
386        Ok(())
387    }
388
389    fn register_api_version(
390        server: &mut tide::Server<Arc<Self>>,
391        prefix: &[String],
392        version: u64,
393        api: &ApiInner<State, Error>,
394    ) -> io::Result<()> {
395        // Clippy complains if the only non-trivial operation in an `unwrap_or_else` closure is
396        // a deref, but for `lazy_static` types, deref is an effectful operation that (in this
397        // case) causes a directory to be renamed and another extracted. We only want to execute
398        // this if we need to (if `api.public()` is `None`) so we disable the lint.
399        #[allow(clippy::unnecessary_lazy_evaluations)]
400        server
401            .at("/public")
402            .at(&format!("v{version}"))
403            .at(&prefix.join("/"))
404            .serve_dir(api.public().unwrap_or_else(|| &DEFAULT_PUBLIC_PATH))?;
405
406        // Register routes for this API.
407        let mut version_endpoint = server.at(&format!("/v{version}"));
408        let mut api_endpoint = if prefix.is_empty() {
409            version_endpoint
410        } else {
411            version_endpoint.at(&prefix.join("/"))
412        };
413        api_endpoint.with(AddErrorBody::new(api.error_handler()));
414        for (path, routes) in api.routes_by_path() {
415            let mut endpoint = api_endpoint.at(path);
416            let routes = routes.collect::<Vec<_>>();
417
418            // Register socket and metrics middlewares. These must be registered before any
419            // regular HTTP routes, because Tide only applies middlewares to routes which were
420            // already registered before the route handler.
421            if let Some(socket_route) = routes.iter().find(|route| route.method() == Method::Socket)
422            {
423                // If there is a socket route with this pattern, add the socket middleware to
424                // all endpoints registered under this pattern, so that any request with any
425                // method that has the socket upgrade headers will trigger a WebSockets upgrade.
426                Self::register_socket(prefix.to_vec(), version, &mut endpoint, socket_route);
427            }
428            if let Some(metrics_route) = routes
429                .iter()
430                .find(|route| route.method() == Method::Metrics)
431            {
432                // If there is a metrics route with this pattern, add the metrics middleware to
433                // all endpoints registered under this pattern, so that a request to this path
434                // with the right headers will return metrics instead of going through the
435                // normal method-based dispatching.
436                Self::register_metrics(prefix.to_vec(), version, &mut endpoint, metrics_route);
437            }
438
439            // Register the HTTP routes.
440            for route in routes {
441                if let Method::Http(method) = route.method() {
442                    Self::register_route(prefix.to_vec(), version, &mut endpoint, route, method);
443                }
444            }
445        }
446
447        // Register automatic routes for this API: documentation, `healthcheck` and `version`. Serve
448        // documentation at the root of the API (with or without a trailing slash).
449        for path in ["", "/"] {
450            let prefix = prefix.to_vec();
451            api_endpoint
452                .at(path)
453                .all(move |req: tide::Request<Arc<Self>>| {
454                    let prefix = prefix.clone();
455                    async move {
456                        let api = &req.state().clone().modules[&prefix].versions[&version];
457                        Ok(api.documentation())
458                    }
459                });
460        }
461        {
462            let prefix = prefix.to_vec();
463            api_endpoint
464                .at("*path")
465                .all(move |req: tide::Request<Arc<Self>>| {
466                    let prefix = prefix.clone();
467                    async move {
468                        // The request did not match any route. Serve documentation for the API.
469                        let api = &req.state().clone().modules[&prefix].versions[&version];
470                        let docs = html! {
471                            "No route matches /" (req.param("path")?)
472                            br{}
473                            (api.documentation())
474                        };
475                        Ok(tide::Response::builder(StatusCode::NOT_FOUND)
476                            .body(docs.into_string())
477                            .build())
478                    }
479                });
480        }
481        {
482            let prefix = prefix.to_vec();
483            api_endpoint
484                .at("healthcheck")
485                .get(move |req: tide::Request<Arc<Self>>| {
486                    let prefix = prefix.clone();
487                    async move {
488                        let api = &req.state().clone().modules[&prefix].versions[&version];
489                        let state = req.state().clone();
490                        Ok(api
491                            .health(request_params(req, &[]).await?, &state.state)
492                            .await)
493                    }
494                });
495        }
496        {
497            let prefix = prefix.to_vec();
498            api_endpoint
499                .at("version")
500                .get(move |req: tide::Request<Arc<Self>>| {
501                    let prefix = prefix.clone();
502                    async move {
503                        let api = &req.state().modules[&prefix].versions[&version];
504                        let accept = RequestParams::accept_from_headers(&req)?;
505                        api.version_handler()(&accept, api.version())
506                            .map_err(|err| Error::from_route_error(err).into_tide_error())
507                    }
508                });
509        }
510
511        Ok(())
512    }
513
514    fn register_route(
515        api: Vec<String>,
516        version: u64,
517        endpoint: &mut tide::Route<Arc<Self>>,
518        route: &Route<State, Error>,
519        method: http::Method,
520    ) {
521        let name = route.name();
522        endpoint.method(method, move |req: tide::Request<Arc<Self>>| {
523            let name = name.clone();
524            let api = api.clone();
525            async move {
526                let route = &req.state().clone().modules[&api].versions[&version][&name];
527                let state = &*req.state().clone().state;
528                let req = request_params(req, route.params()).await?;
529                route
530                    .handle(req, state)
531                    .await
532                    .map_err(|err| match err {
533                        RouteError::AppSpecific(err) => err,
534                        _ => Error::from_route_error(err),
535                    })
536                    .map_err(|err| err.into_tide_error())
537            }
538        });
539    }
540
541    fn register_metrics(
542        api: Vec<String>,
543        version: u64,
544        endpoint: &mut tide::Route<Arc<Self>>,
545        route: &Route<State, Error>,
546    ) {
547        let name = route.name();
548        if route.has_handler() {
549            // If there is a metrics handler, add middleware to the endpoint to intercept the
550            // request and respond with metrics, rather than the usual HTTP dispatching, if the
551            // appropriate headers are set.
552            endpoint.with(MetricsMiddleware::new(name.clone(), api.clone(), version));
553        }
554
555        // Register a catch-all HTTP handler for the route, which serves the route documentation as
556        // HTML. This ensures that there is at least one endpoint registered with the Tide
557        // dispatcher, so that the middleware actually fires on requests to this path. In addition,
558        // this handler will trigger for requests that are not otherwise valid, aiding in
559        // discoverability.
560        //
561        // We register the default handler using `all`, which makes it act as a fallback handler.
562        // This means if there are other, non-metrics routes with this same path, we will still
563        // dispatch to them if the path is hit with the appropriate method.
564        Self::register_fallback(api, version, endpoint, route);
565    }
566
567    fn register_socket(
568        api: Vec<String>,
569        version: u64,
570        endpoint: &mut tide::Route<Arc<Self>>,
571        route: &Route<State, Error>,
572    ) {
573        let name = route.name();
574        if route.has_handler() {
575            // If there is a socket handler, add the [WebSocket] middleware to the endpoint, so that
576            // upgrade requests will automatically upgrade to a WebSockets connection.
577            let name = name.clone();
578            let api = api.clone();
579            endpoint.with(WebSocket::new(
580                move |req: tide::Request<Arc<Self>>, conn| {
581                    let name = name.clone();
582                    let api = api.clone();
583                    async move {
584                        let route = &req.state().clone().modules[&api].versions[&version][&name];
585                        let state = &*req.state().clone().state;
586                        let req = request_params(req, route.params()).await?;
587                        route
588                            .handle_socket(req, conn, state)
589                            .await
590                            .map_err(|err| match err {
591                                SocketError::AppSpecific(err) => err,
592                                _ => Error::from_socket_error(err),
593                            })
594                            .map_err(|err| err.into_tide_error())
595                    }
596                },
597            ));
598        }
599
600        // Register a catch-all HTTP handler for the route, which serves the route documentation as
601        // HTML. This ensures that there is at least one endpoint registered with the Tide
602        // dispatcher, so that the middleware actually fires on requests to this path. In addition,
603        // this handler will trigger for requests that are not valid WebSockets handshakes. The
604        // documentation should make clear that this is a WebSockets endpoint, aiding in
605        // discoverability. This will also trigger if there is no socket handler for this route,
606        // which will signal to the developer that they need to implement a socket handler for this
607        // route to work.
608        //
609        // We register the default handler using `all`, which makes it act as a fallback handler.
610        // This means if there are other, non-socket routes with this same path, we will still
611        // dispatch to them if the path is hit with the appropriate method.
612        Self::register_fallback(api, version, endpoint, route);
613    }
614
615    fn register_fallback(
616        api: Vec<String>,
617        version: u64,
618        endpoint: &mut tide::Route<Arc<Self>>,
619        route: &Route<State, Error>,
620    ) {
621        let name = route.name();
622        endpoint.all(move |req: tide::Request<Arc<Self>>| {
623            let name = name.clone();
624            let api = api.clone();
625            async move {
626                let route = &req.state().clone().modules[&api].versions[&version][&name];
627                route
628                    .default_handler()
629                    .map_err(|err| match err {
630                        RouteError::AppSpecific(err) => err,
631                        _ => Error::from_route_error(err),
632                    })
633                    .map_err(|err| err.into_tide_error())
634            }
635        });
636    }
637
638    /// Server middleware which returns redirect responses for requests lacking an explicit version
639    /// prefix.
640    fn version_middleware(
641        req: tide::Request<Arc<Self>>,
642        next: tide::Next<Arc<Self>>,
643    ) -> BoxFuture<tide::Result> {
644        async move {
645            let Some(path) = req.url().path_segments() else {
646                // If we can't parse the path, we can't run this middleware. Do our best by
647                // continuing the request processing lifecycle.
648                return Ok(next.run(req).await);
649            };
650            let path = path.collect::<Vec<_>>();
651            let Some(seg1) = path.first() else {
652                // This is the root URL, with no path segments. Nothing for this middleware to do.
653                return Ok(next.run(req).await);
654            };
655            if seg1.is_empty() {
656                // This is the root URL, with no path segments. Nothing for this middleware to do.
657                return Ok(next.run(req).await);
658            }
659
660            // The first segment is either a version identifier or (part of) an API identifier
661            // (implicitly requesting the latest version of the API). We handle these cases
662            // differently.
663            if let Some(version) = seg1.strip_prefix('v').and_then(|n| n.parse().ok()) {
664                // If the version identifier is present, we probably don't need a redirect. However,
665                // we still check if this is a valid version for the request API. If not, we will
666                // serve documentation listing the available versions.
667                let Some(module) = req.state().modules.search(&path[1..]) else {
668                    let message = html! {
669                        ("No API matches ")
670                        span style = "font-family: monospace" {
671                            (format!("/{}", path[1..].join("/")))
672                        }
673                    };
674                    return Ok(Self::top_level_error(req, StatusCode::NOT_FOUND, message));
675                };
676                if !module.versions.contains_key(&version) {
677                    // This version is not supported, list suported versions.
678                    return Ok(html! {
679                        "Unsupported version v" (version) ". Supported versions are:"
680                        ul {
681                            @for v in module.versions.keys().rev() {
682                                li {
683                                    a href=(format!("/v{v}/{}", module.path())) { "v" (v) }
684                                }
685                            }
686                        }
687                    }
688                    .into());
689                }
690
691                // This is a valid request with a specific version. It should be handled
692                // successfully by the route handlers for this API.
693                Ok(next.run(req).await)
694            } else {
695                // If the first path segment is not a version prefix, then the path is either the
696                // name of an API (implicitly requesting the latest version) or one of the magic
697                // top-level endpoints (version, healthcheck). Validate the API and then redirect.
698                if !req.state().modules.is_singleton() && ["version", "healthcheck"].contains(seg1)
699                {
700                    return Ok(next.run(req).await);
701                }
702                let Some(module) = req.state().modules.search(&path) else {
703                    let message = html! {
704                        ("No API matches ")
705                        span style = "font-family: monospace" {
706                            (format!("/{}", path.join("/")))
707                        }
708                    };
709                    return Ok(Self::top_level_error(req, StatusCode::NOT_FOUND, message));
710                };
711
712                let latest_version = *module.versions.last_key_value().unwrap().0;
713                let path = path.join("/");
714                Ok(tide::Redirect::temporary(format!("/v{latest_version}/{path}")).into())
715            }
716        }
717        .boxed()
718    }
719
720    /// Top-level documentation about the app.
721    fn top_level_docs(req: tide::Request<Arc<Self>>) -> PreEscaped<String> {
722        html! {
723            "This is a Tide Disco app composed of the following modules:"
724            (req.state().list_apis())
725        }
726    }
727
728    /// Documentation served when there is a routing error at the app level.
729    fn top_level_error(
730        req: tide::Request<Arc<Self>>,
731        status: StatusCode,
732        message: PreEscaped<String>,
733    ) -> tide::Response {
734        let docs = html! {
735            p style = "color:red" {
736                (message)
737            }
738            (Self::top_level_docs(req))
739        };
740        tide::Response::builder(status)
741            .body(docs.into_string())
742            .content_type(HTML)
743            .build()
744    }
745}
746
/// The health status of an application.
///
/// This combines one overall [`HealthStatus`] with a per-module, per-version breakdown of HTTP
/// status codes. The type is (de)serializable, so field names and order are part of its encoded
/// form.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct AppHealth {
    /// The status of the overall application.
    ///
    /// [HealthStatus::Available] if all of the application's modules are healthy, otherwise a
    /// [HealthStatus] variant with [status](HealthCheck::status) other than 200.
    pub status: HealthStatus,
    /// The status of each registered module, indexed by version.
    ///
    /// The outer map is keyed by module name (for example `"mod"`). The inner map is keyed by an
    /// integer version of that module and holds the status code reported for that version.
    pub modules: BTreeMap<String, BTreeMap<u64, StatusCode>>,
}
758
759impl HealthCheck for AppHealth {
760    fn status(&self) -> StatusCode {
761        self.status.status()
762    }
763}
764
/// Version information about an application.
///
/// The two [`Version`] fields are (de)serialized through their string representation (see the
/// `DisplayFromStr` adapters below) rather than as structured values.
#[serde_as]
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct AppVersion {
    /// The supported versions of each module registered with this application.
    ///
    /// The map is keyed by module name. Versions for each module are ordered from newest to
    /// oldest.
    pub modules: BTreeMap<String, Vec<ApiVersion>>,

    /// The version of this application.
    ///
    /// `None` when the application has no version of its own to report.
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub app_version: Option<Version>,

    /// The version of the Tide Disco server framework.
    #[serde_as(as = "DisplayFromStr")]
    pub disco_version: Version,
}
782
/// RAII guard to ensure a module is registered after it is configured.
///
/// This type allows the owner to configure an [`Api`] module via the [`Deref`] and [`DerefMut`]
/// traits. Once the API is configured, this object can be dropped, which will automatically
/// register the module with the [`App`].
///
/// # Panics
///
/// Note that if anything goes wrong during module registration (for example, there is already an
/// incompatible module registered with the same name), the drop implementation may panic. To handle
/// errors without panicking, call [`register`](Self::register) explicitly.
// The trait bounds are repeated on the struct itself (not only on the impls) because this type
// implements `Drop`, and a `Drop` impl may not require bounds that the type definition lacks.
#[derive(Debug)]
pub struct Module<'a, State, Error, ModuleError, ModuleVersion>
where
    State: Send + Sync + 'static,
    Error: crate::Error + From<ModuleError> + 'static,
    ModuleError: Send + Sync + 'static,
    ModuleVersion: StaticVersionType + 'static,
{
    // The application this module will be registered with.
    app: &'a mut App<State, Error>,
    // Passed as the first argument to `App::register_module` when the module is registered.
    base_url: &'a str,
    // This is only an [Option] so we can [take] out of it during [drop].
    api: Option<Api<State, ModuleError, ModuleVersion>>,
}
807
808impl<State, Error, ModuleError, ModuleVersion> Deref
809    for Module<'_, State, Error, ModuleError, ModuleVersion>
810where
811    State: Send + Sync + 'static,
812    Error: crate::Error + From<ModuleError> + 'static,
813    ModuleError: Send + Sync + 'static,
814    ModuleVersion: StaticVersionType + 'static,
815{
816    type Target = Api<State, ModuleError, ModuleVersion>;
817
818    fn deref(&self) -> &Self::Target {
819        self.api.as_ref().unwrap()
820    }
821}
822
823impl<State, Error, ModuleError, ModuleVersion> DerefMut
824    for Module<'_, State, Error, ModuleError, ModuleVersion>
825where
826    State: Send + Sync + 'static,
827    Error: crate::Error + From<ModuleError> + 'static,
828    ModuleError: Send + Sync + 'static,
829    ModuleVersion: StaticVersionType + 'static,
830{
831    fn deref_mut(&mut self) -> &mut Self::Target {
832        self.api.as_mut().unwrap()
833    }
834}
835
836impl<State, Error, ModuleError, ModuleVersion> Drop
837    for Module<'_, State, Error, ModuleError, ModuleVersion>
838where
839    State: Send + Sync + 'static,
840    Error: crate::Error + From<ModuleError> + 'static,
841    ModuleError: Send + Sync + 'static,
842    ModuleVersion: StaticVersionType + 'static,
843{
844    fn drop(&mut self) {
845        self.register_impl().unwrap();
846    }
847}
848
849impl<State, Error, ModuleError, ModuleVersion> Module<'_, State, Error, ModuleError, ModuleVersion>
850where
851    State: Send + Sync + 'static,
852    Error: crate::Error + From<ModuleError> + 'static,
853    ModuleError: Send + Sync + 'static,
854    ModuleVersion: StaticVersionType + 'static,
855{
856    /// Register this module with the linked app.
857    pub fn register(mut self) -> Result<(), AppError> {
858        self.register_impl()
859    }
860
861    /// Perform the logic of [`Self::register`] without consuming `self`, so this can be called from
862    /// `drop`.
863    fn register_impl(&mut self) -> Result<(), AppError> {
864        if let Some(api) = self.api.take() {
865            self.app.register_module(self.base_url, api)?;
866            Ok(())
867        } else {
868            // Already registered.
869            Ok(())
870        }
871    }
872}
873
874#[cfg(test)]
875mod test {
876    use super::*;
877    use crate::{
878        error::{Error, ServerError},
879        metrics::Metrics,
880        socket::Connection,
881        testing::{setup_test, test_ws_client, Client},
882        Url,
883    };
884    use async_std::{sync::RwLock, task::spawn};
885    use async_tungstenite::tungstenite::Message;
886    use futures::{FutureExt, SinkExt, StreamExt};
887    use portpicker::pick_unused_port;
888    use regex::Regex;
889    use serde::de::DeserializeOwned;
890    use std::{borrow::Cow, fmt::Debug};
891    use toml::toml;
892    use vbs::{version::StaticVersion, BinarySerializer, Serializer};
893
    // Version marker `StaticVersion<0, 1>` and the serializer parameterized by it.
    type StaticVer01 = StaticVersion<0, 1>;
    type SerializerV01 = Serializer<StaticVer01>;

    // Version marker `StaticVersion<0, 2>` and the serializer parameterized by it.
    type StaticVer02 = StaticVersion<0, 2>;
    type SerializerV02 = Serializer<StaticVer02>;

    // Version marker `StaticVersion<0, 3>` and the serializer parameterized by it.
    type StaticVer03 = StaticVersion<0, 3>;
    type SerializerV03 = Serializer<StaticVer03>;
902
    /// Stateless stand-in for a metrics source, used as the app state in tests that exercise
    /// metrics routes.
    #[derive(Clone, Copy, Debug)]
    struct FakeMetrics;
905
906    impl Metrics for FakeMetrics {
907        type Error = ServerError;
908
909        fn export(&self) -> Result<String, Self::Error> {
910            Ok("METRICS".into())
911        }
912    }
913
    /// Test route dispatching for routes with the same path and different methods.
    ///
    /// Six routes are declared on the single path `/test`, one each for GET, POST, PUT, DELETE,
    /// SOCKET and METRICS. The test then checks that every kind of request is answered by the
    /// handler registered for that method.
    #[async_std::test]
    async fn test_method_dispatch() {
        setup_test();

        use crate::http::Method::*;

        let mut app = App::<_, ServerError>::with_state(RwLock::new(FakeMetrics));
        let api_toml = toml! {
            [meta]
            FORMAT_VERSION = "0.1.0"

            [route.get_test]
            PATH = ["/test"]
            METHOD = "GET"

            [route.post_test]
            PATH = ["/test"]
            METHOD = "POST"

            [route.put_test]
            PATH = ["/test"]
            METHOD = "PUT"

            [route.delete_test]
            PATH = ["/test"]
            METHOD = "DELETE"

            [route.socket_test]
            PATH = ["/test"]
            METHOD = "SOCKET"

            [route.metrics_test]
            PATH = ["/test"]
            METHOD = "METRICS"
        };
        {
            // The module is configured inside its own scope so that the guard returned by
            // `module` is dropped (and the module thereby registered) before the app is served.
            let mut api = app
                .module::<ServerError, StaticVer01>("mod", api_toml)
                .unwrap();
            // Each plain HTTP handler answers with the name of its own method, so the response
            // body identifies which handler ran.
            api.get("get_test", |_req, _state| {
                async move { Ok(Get.to_string()) }.boxed()
            })
            .unwrap()
            .post("post_test", |_req, _state| {
                async move { Ok(Post.to_string()) }.boxed()
            })
            .unwrap()
            .put("put_test", |_req, _state| {
                async move { Ok(Put.to_string()) }.boxed()
            })
            .unwrap()
            .delete("delete_test", |_req, _state| {
                async move { Ok(Delete.to_string()) }.boxed()
            })
            .unwrap()
            .socket(
                "socket_test",
                |_req, mut conn: Connection<str, (), _, StaticVer01>, _state| {
                    async move {
                        conn.send("SOCKET").await.unwrap();
                        Ok(())
                    }
                    .boxed()
                },
            )
            .unwrap()
            // The metrics handler hands back the app state itself, whose `export` yields
            // "METRICS".
            .metrics("metrics_test", |_req, state| {
                async move { Ok(Cow::Borrowed(state)) }.boxed()
            })
            .unwrap();
        }
        let port = pick_unused_port().unwrap();
        let url: Url = format!("http://localhost:{}", port).parse().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance()));
        let client = Client::new(url.clone()).await;

        // Regular HTTP methods.
        for method in [Get, Post, Put, Delete] {
            let res = client
                .request(method, "mod/test")
                .header("Accept", "application/json")
                .send()
                .await
                .unwrap();
            assert_eq!(res.status(), StatusCode::OK);
            assert_eq!(res.json::<String>().await.unwrap(), method.to_string());
        }

        // Metrics with Accept header.
        // This is a GET on the same path as `get_test`; only the `Accept` header differs from the
        // JSON request above.
        let res = client
            .get("mod/test")
            .header("Accept", "text/plain")
            .send()
            .await
            .unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await.unwrap(), "METRICS");

        // Metrics without Accept header.
        let res = client.get("mod/test").send().await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await.unwrap(), "METRICS");

        // Socket.
        let mut conn = test_ws_client(url.join("mod/test").unwrap()).await;
        let msg = conn.next().await.unwrap().unwrap();
        // Accept either framing: text frames are decoded as JSON, binary frames with the
        // version 0.1 serializer.
        let body: String = match msg {
            Message::Text(m) => serde_json::from_str(&m).unwrap(),
            Message::Binary(m) => SerializerV01::deserialize(&m).unwrap(),
            m => panic!("expected Text or Binary message, but got {}", m),
        };
        assert_eq!(body, "SOCKET");
    }
1028
    /// Test route dispatching for routes with patterns containing different parameters.
    ///
    /// A single route is reachable through two path patterns, `/test/a/:a` (integer) and
    /// `/test/b/:b` (boolean). The handler reports which parameter it received, and the test
    /// checks both patterns.
    #[async_std::test]
    async fn test_param_dispatch() {
        setup_test();

        let mut app = App::<_, ServerError>::with_state(RwLock::new(()));
        let api_toml = toml! {
            [meta]
            FORMAT_VERSION = "0.1.0"

            [route.test]
            PATH = ["/test/a/:a", "/test/b/:b"]
            ":a" = "Integer"
            ":b" = "Boolean"
        };
        {
            let mut api = app
                .module::<ServerError, StaticVer01>("mod", api_toml)
                .unwrap();
            api.get("test", |req, _state| {
                async move {
                    // `a` is optional from the handler's point of view: when it is absent the
                    // request must have matched the `:b` pattern instead.
                    if let Some(a) = req.opt_integer_param::<_, i32>("a")? {
                        Ok(("a", a.to_string()))
                    } else {
                        Ok(("b", req.boolean_param("b")?.to_string()))
                    }
                }
                .boxed()
            })
            .unwrap();
        }
        let port = pick_unused_port().unwrap();
        let url: Url = format!("http://localhost:{}", port).parse().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance()));
        let client = Client::new(url.clone()).await;

        // Integer parameter via the first pattern.
        let res = client.get("mod/test/a/42").send().await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.json::<(String, String)>().await.unwrap(),
            ("a".to_string(), "42".to_string())
        );

        // Boolean parameter via the second pattern.
        let res = client.get("mod/test/b/true").send().await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.json::<(String, String)>().await.unwrap(),
            ("b".to_string(), "true".to_string())
        );
    }
1079
    /// Test serving several major versions of the same module side by side.
    ///
    /// Covers: calling routes with and without a version prefix, documentation served for
    /// invalid routes, the listing served for an unsupported version, and the per-module and
    /// app-level `version` and `healthcheck` endpoints.
    #[async_std::test]
    async fn test_versions() {
        setup_test();

        let mut app = App::<_, ServerError>::with_state(RwLock::new(()));

        // Create two different, non-consecutive major versions of an API. One method exists only
        // in version 1 (it is deleted afterwards), one is added in version 3, and one is present
        // in both versions (with a different implementation).
        let v1_toml = toml! {
            [meta]
            FORMAT_VERSION = "0.1.0"

            [route.deleted]
            PATH = ["/deleted"]

            [route.unchanged]
            PATH = ["/unchanged"]
        };
        let v3_toml = toml! {
            [meta]
            FORMAT_VERSION = "0.1.0"

            [route.added]
            PATH = ["/added"]

            [route.unchanged]
            PATH = ["/unchanged"]
        };

        {
            let mut v1 = app
                .module::<ServerError, StaticVer01>("mod", v1_toml.clone())
                .unwrap();
            v1.with_version("1.0.0".parse().unwrap())
                .get("deleted", |_req, _state| {
                    async move { Ok("deleted v1") }.boxed()
                })
                .unwrap()
                .get("unchanged", |_req, _state| {
                    async move { Ok("unchanged v1") }.boxed()
                })
                .unwrap()
                // Add a custom healthcheck for the old version so we can check healthcheck routing.
                .with_health_check(|_state| {
                    async move { HealthStatus::TemporarilyUnavailable }.boxed()
                });
        }
        {
            // Registering the same major version twice is an error.
            let mut api = app
                .module::<ServerError, StaticVer01>("mod", v1_toml)
                .unwrap();
            api.with_version("1.1.1".parse().unwrap());
            // `register` is called explicitly so the failure is returned instead of panicking
            // when the module guard is dropped.
            assert_eq!(
                api.register().unwrap_err(),
                DispatchError::ModuleAlreadyExists {
                    prefix: "mod".into(),
                    version: 1,
                }
                .into()
            );
        }
        {
            let mut v3 = app
                .module::<ServerError, StaticVer01>("mod", v3_toml.clone())
                .unwrap();
            v3.with_version("3.0.0".parse().unwrap())
                .get("added", |_req, _state| {
                    async move { Ok("added v3") }.boxed()
                })
                .unwrap()
                .get("unchanged", |_req, _state| {
                    async move { Ok("unchanged v3") }.boxed()
                })
                .unwrap();
        }

        let port = pick_unused_port().unwrap();
        let url: Url = format!("http://localhost:{}", port).parse().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance()));
        let client = Client::new(url.clone()).await;

        // First check that we can call all the expected methods.
        assert_eq!(
            "deleted v1",
            client
                .get("v1/mod/deleted")
                .send()
                .await
                .unwrap()
                .json::<String>()
                .await
                .unwrap()
        );
        assert_eq!(
            "unchanged v1",
            client
                .get("v1/mod/unchanged")
                .send()
                .await
                .unwrap()
                .json::<String>()
                .await
                .unwrap()
        );
        // For the v3 methods, we can query with or without a version prefix.
        for prefix in ["", "/v3"] {
            let span = tracing::info_span!("version", prefix);
            let _enter = span.enter();

            assert_eq!(
                "added v3",
                client
                    .get(&format!("{prefix}/mod/added"))
                    .send()
                    .await
                    .unwrap()
                    .json::<String>()
                    .await
                    .unwrap()
            );
            assert_eq!(
                "unchanged v3",
                client
                    .get(&format!("{prefix}/mod/unchanged"))
                    .send()
                    .await
                    .unwrap()
                    .json::<String>()
                    .await
                    .unwrap()
            );
        }

        // Test documentation for invalid routes.
        // `version` is the optional version prefix to request with; `route` is a route that does
        // not exist in that version (or "" for the module root).
        let check_docs = |version, route: &'static str| {
            let client = &client;
            async move {
                let span = tracing::info_span!("check_docs", ?version, route);
                let _enter = span.enter();
                tracing::info!("test invalid route docs");

                let prefix = match version {
                    Some(v) => format!("/v{v}"),
                    None => "".into(),
                };

                // Invalid route or no route with no version prefix redirects to documentation for
                // the latest supported version.
                let version = version.unwrap_or(3);

                let res = client
                    .get(&format!("{prefix}/mod/{route}"))
                    .send()
                    .await
                    .unwrap();
                let docs = res.text().await.unwrap();
                if !route.is_empty() {
                    assert!(
                        docs.contains(&format!("No route matches /{route}")),
                        "{docs}"
                    );
                }
                assert!(
                    docs.contains(&format!("mod API {version}.0.0 Reference")),
                    "{docs}"
                );
            }
        };

        for route in ["", "deleted"] {
            check_docs(None, route).await;
        }
        for route in ["", "deleted"] {
            check_docs(Some(3), route).await;
        }
        for route in ["", "added"] {
            check_docs(Some(1), route).await;
        }

        // Request with an unsupported version lists the supported versions.
        // The expected listing is ordered newest first and links to each version's module root.
        let expected_html = html! {
            "Unsupported version v2. Supported versions are:"
            ul {
                li {
                    a href="/v3/mod" {"v3"}
                }
                li {
                    a href="/v1/mod" {"v1"}
                }
            }
        }
        .into_string();
        for route in ["", "/unchanged"] {
            let span = tracing::info_span!("unsupported_version_docs", route);
            let _enter = span.enter();
            tracing::info!("test unsupported version docs");

            let res = client.get(&format!("/v2/mod{route}")).send().await.unwrap();
            let docs = res.text().await.unwrap();
            assert_eq!(docs, expected_html);
        }

        // Test version endpoints.
        // Without a prefix the latest version (3) is expected to answer.
        for version in [None, Some(1), Some(3)] {
            let span = tracing::info_span!("version_endpoints", version);
            let _enter = span.enter();
            tracing::info!("test version endpoints");

            let prefix = match version {
                Some(v) => format!("/v{v}"),
                None => "".into(),
            };
            let res = client
                .get(&format!("{prefix}/mod/version"))
                .send()
                .await
                .unwrap();
            assert_eq!(
                res.json::<ApiVersion>()
                    .await
                    .unwrap()
                    .api_version
                    .unwrap()
                    .major,
                version.unwrap_or(3)
            );
        }

        // Test the application version.
        // Module versions are expected newest first.
        let res = client.get("version").send().await.unwrap();
        assert_eq!(
            res.json::<AppVersion>().await.unwrap().modules["mod"],
            [
                ApiVersion {
                    api_version: Some("3.0.0".parse().unwrap()),
                    spec_version: "0.1.0".parse().unwrap(),
                },
                ApiVersion {
                    api_version: Some("1.0.0".parse().unwrap()),
                    spec_version: "0.1.0".parse().unwrap(),
                }
            ]
        );

        // Test healthcheck endpoints.
        // Only version 1 registered a custom (unavailable) health check above.
        for version in [None, Some(1), Some(3)] {
            let span = tracing::info_span!("healthcheck_endpoints", version);
            let _enter = span.enter();
            tracing::info!("test healthcheck endpoints");

            let prefix = match version {
                Some(v) => format!("/v{v}"),
                None => "".into(),
            };
            let res = client
                .get(&format!("{prefix}/mod/healthcheck"))
                .send()
                .await
                .unwrap();
            let status = res.status();
            let health: HealthStatus = res.json().await.unwrap();
            assert_eq!(health.status(), status);
            assert_eq!(
                health,
                if version == Some(1) {
                    HealthStatus::TemporarilyUnavailable
                } else {
                    HealthStatus::Available
                }
            );
        }

        // Test the application health.
        // One unavailable module version makes the whole application report as unhealthy.
        let res = client.get("healthcheck").send().await.unwrap();
        assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
        let health: AppHealth = res.json().await.unwrap();
        assert_eq!(health.status, HealthStatus::Unhealthy);
        assert_eq!(
            health.modules["mod"],
            [(3, StatusCode::OK), (1, StatusCode::SERVICE_UNAVAILABLE)].into()
        );
    }
1364
1365    #[async_std::test]
1366    async fn test_api_disco() {
1367        setup_test();
1368
1369        // Test discoverability documentation when a request is for an unknown API.
1370        let mut app = App::<_, ServerError>::with_state(());
1371        app.module::<ServerError, StaticVer01>(
1372            "the-correct-module",
1373            toml! {
1374                route = {}
1375            },
1376        )
1377        .unwrap()
1378        .with_version("1.0.0".parse().unwrap());
1379
1380        let port = pick_unused_port().unwrap();
1381        let url: Url = format!("http://localhost:{}", port).parse().unwrap();
1382        spawn(app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance()));
1383        let client = Client::new(url.clone()).await;
1384
1385        let expected_list_item = html! {
1386            a href="/the-correct-module" {"the-correct-module"}
1387            sup {
1388                a href="/v1/the-correct-module" {"[v1]"}
1389            }
1390        }
1391        .into_string();
1392
1393        let expected_err = Regex::new("No API matches .*/test").unwrap();
1394        for version_prefix in ["", "/v1"] {
1395            let docs = client
1396                .get(&format!("{version_prefix}/test"))
1397                .send()
1398                .await
1399                .unwrap()
1400                .text()
1401                .await
1402                .unwrap();
1403            expected_err
1404                .find(&docs)
1405                .unwrap_or_else(|| panic!("Docs contains error message:\n{docs}"));
1406            assert!(docs.contains(&expected_list_item), "{docs}");
1407        }
1408
1409        // Top level documentation.
1410        let docs = client.get("").send().await.unwrap().text().await.unwrap();
1411        assert!(!docs.contains("No API matches"), "{docs}");
1412        assert!(docs.contains(&expected_list_item), "{docs}");
1413
1414        let docs = client
1415            .get("/v1")
1416            .send()
1417            .await
1418            .unwrap()
1419            .text()
1420            .await
1421            .unwrap();
1422        Regex::new("No API matches .*/")
1423            .unwrap()
1424            .find(&docs)
1425            .unwrap_or_else(|| panic!("Docs contains error message:\n{docs}"));
1426        assert!(docs.contains(&expected_list_item), "{docs}");
1427    }
1428
1429    #[async_std::test]
1430    async fn test_post_redirect_idempotency() {
1431        setup_test();
1432
1433        let mut app = App::<_, ServerError>::with_state(RwLock::new(0));
1434
1435        let api_toml = toml! {
1436            [meta]
1437            FORMAT_VERSION = "0.1.0"
1438
1439            [route.test]
1440            METHOD = "POST"
1441            PATH = ["/test"]
1442        };
1443        {
1444            let mut api = app
1445                .module::<ServerError, StaticVer01>("mod", api_toml.clone())
1446                .unwrap();
1447            api.post("test", |_req, state| {
1448                async move {
1449                    *state += 1;
1450                    Ok(*state)
1451                }
1452                .boxed()
1453            })
1454            .unwrap();
1455        }
1456
1457        let port = pick_unused_port().unwrap();
1458        let url: Url = format!("http://localhost:{}", port).parse().unwrap();
1459        spawn(app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance()));
1460        let client = Client::new(url.clone()).await;
1461
1462        for i in 1..3 {
1463            // Request gets redirected to latest version of API and resent, but endpoint handler
1464            // only executes once.
1465            assert_eq!(
1466                client
1467                    .post("mod/test")
1468                    .send()
1469                    .await
1470                    .unwrap()
1471                    .json::<u64>()
1472                    .await
1473                    .unwrap(),
1474                i
1475            );
1476        }
1477    }
1478
1479    #[async_std::test]
1480    async fn test_format_versions() {
1481        setup_test();
1482
1483        // Register two modules with different binary format versions, each in turn different from
1484        // the app-level version. Each module has two endpoints, one which always succeeds and one
1485        // which always fails, so we can test error serialization.
1486        let mut app = App::<_, ServerError>::with_state(());
1487        let api_toml = toml! {
1488            [meta]
1489            FORMAT_VERSION = "0.1.0"
1490
1491            [route.ok]
1492            METHOD = "GET"
1493            PATH = ["/ok"]
1494
1495            [route.err]
1496            METHOD = "GET"
1497            PATH = ["/err"]
1498        };
1499
1500        fn init_api<VER: StaticVersionType + 'static>(api: &mut Api<(), ServerError, VER>) {
1501            api.get("ok", |_req, _state| async move { Ok("ok") }.boxed())
1502                .unwrap()
1503                .get("err", |_req, _state| {
1504                    async move {
1505                        Err::<String, _>(ServerError::catch_all(
1506                            StatusCode::INTERNAL_SERVER_ERROR,
1507                            "err".into(),
1508                        ))
1509                    }
1510                    .boxed()
1511                })
1512                .unwrap();
1513        }
1514
1515        {
1516            let mut api = app
1517                .module::<ServerError, StaticVer02>("mod02", api_toml.clone())
1518                .unwrap();
1519            init_api(&mut api);
1520        }
1521        {
1522            let mut api = app
1523                .module::<ServerError, StaticVer03>("mod03", api_toml.clone())
1524                .unwrap();
1525            init_api(&mut api);
1526        }
1527
1528        let port = pick_unused_port().unwrap();
1529        let url: Url = format!("http://localhost:{}", port).parse().unwrap();
1530        spawn(app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance()));
1531        let client = Client::new(url.clone()).await;
1532
1533        async fn get<S: BinarySerializer, T: DeserializeOwned>(
1534            client: &Client,
1535            endpoint: &str,
1536            expected_status: StatusCode,
1537        ) -> anyhow::Result<T> {
1538            tracing::info!("GET {endpoint} ->");
1539            let res = client
1540                .get(endpoint)
1541                .header("Accept", "application/octet-stream")
1542                .send()
1543                .await
1544                .unwrap();
1545            tracing::info!(?res, "<-");
1546            assert_eq!(res.status(), expected_status);
1547            let bytes = res.bytes().await.unwrap();
1548            anyhow::Context::context(
1549                S::deserialize(&bytes),
1550                format!("failed to deserialize bytes {bytes:?}"),
1551            )
1552        }
1553
1554        #[tracing::instrument(skip(client))]
1555        async fn check_ok<S: BinarySerializer>(
1556            client: &Client,
1557            endpoint: &str,
1558            expected: impl Debug + DeserializeOwned + Eq,
1559        ) {
1560            tracing::info!("checking successful deserialization");
1561            assert_eq!(
1562                expected,
1563                get::<S, _>(client, endpoint, StatusCode::OK).await.unwrap()
1564            );
1565        }
1566
1567        let api_version = ApiVersion {
1568            spec_version: "0.1.0".parse().unwrap(),
1569            api_version: None,
1570        };
1571
1572        check_ok::<SerializerV01>(
1573            &client,
1574            "healthcheck",
1575            AppHealth {
1576                status: HealthStatus::Available,
1577                modules: [
1578                    ("mod02".into(), [(0, StatusCode::OK)].into()),
1579                    ("mod03".into(), [(0, StatusCode::OK)].into()),
1580                ]
1581                .into(),
1582            },
1583        )
1584        .await;
1585        check_ok::<SerializerV01>(
1586            &client,
1587            "version",
1588            AppVersion {
1589                app_version: None,
1590                disco_version: env!("CARGO_PKG_VERSION").parse().unwrap(),
1591                modules: [
1592                    ("mod02".into(), vec![api_version.clone()]),
1593                    ("mod03".into(), vec![api_version.clone()]),
1594                ]
1595                .into(),
1596            },
1597        )
1598        .await;
1599        check_ok::<SerializerV02>(&client, "mod02/ok", "ok".to_string()).await;
1600        check_ok::<SerializerV02>(&client, "mod02/healthcheck", HealthStatus::Available).await;
1601        check_ok::<SerializerV02>(&client, "mod02/version", api_version.clone()).await;
1602        check_ok::<SerializerV03>(&client, "mod03/ok", "ok".to_string()).await;
1603        check_ok::<SerializerV03>(&client, "mod03/healthcheck", HealthStatus::Available).await;
1604        check_ok::<SerializerV03>(&client, "mod03/version", api_version.clone()).await;
1605
1606        #[tracing::instrument(skip(client))]
1607        async fn check_wrong_version<S: BinarySerializer, T: Debug + DeserializeOwned>(
1608            client: &Client,
1609            endpoint: &str,
1610        ) {
1611            tracing::info!("checking deserialization fails with wrong version");
1612            get::<S, T>(client, endpoint, StatusCode::OK)
1613                .await
1614                .unwrap_err();
1615        }
1616
1617        check_wrong_version::<SerializerV02, AppHealth>(&client, "healthcheck").await;
1618        check_wrong_version::<SerializerV02, AppVersion>(&client, "version").await;
1619        check_wrong_version::<SerializerV03, String>(&client, "mod02/ok").await;
1620        check_wrong_version::<SerializerV03, HealthStatus>(&client, "mod02/healthcheck").await;
1621        check_wrong_version::<SerializerV03, ApiVersion>(&client, "mod02/version").await;
1622        check_wrong_version::<SerializerV01, String>(&client, "mod03/ok").await;
1623        check_wrong_version::<SerializerV01, HealthStatus>(&client, "mod03/healthcheck").await;
1624        check_wrong_version::<SerializerV01, ApiVersion>(&client, "mod03/version").await;
1625
1626        #[tracing::instrument(skip(client))]
1627        async fn check_err<S: BinarySerializer>(client: &Client, endpoint: &str) {
1628            tracing::info!("checking error deserialization");
1629            tracing::info!("checking successful deserialization");
1630            assert_eq!(
1631                get::<S, ServerError>(client, endpoint, StatusCode::INTERNAL_SERVER_ERROR)
1632                    .await
1633                    .unwrap(),
1634                ServerError::catch_all(StatusCode::INTERNAL_SERVER_ERROR, "err".into())
1635            );
1636        }
1637
1638        check_err::<SerializerV02>(&client, "mod02/err").await;
1639        check_err::<SerializerV03>(&client, "mod03/err").await;
1640    }
1641
1642    #[async_std::test]
1643    async fn test_api_prefix() {
1644        setup_test();
1645
1646        // It is illegal to register two API modules where one is a prefix (in terms of route
1647        // segments) of another.
1648        for (api1, api2) in [
1649            ("", "api"),
1650            ("api", ""),
1651            ("path", "path/sub"),
1652            ("path/sub", "path"),
1653        ] {
1654            tracing::info!(api1, api2, "test case");
1655            let (prefix, conflict) = if api1.len() < api2.len() {
1656                (api1.to_string(), api2.to_string())
1657            } else {
1658                (api2.to_string(), api1.to_string())
1659            };
1660
1661            let mut app = App::<_, ServerError>::with_state(());
1662            let toml = toml! {
1663                route = {}
1664            };
1665            app.module::<ServerError, StaticVer01>(api1, toml.clone())
1666                .unwrap()
1667                .register()
1668                .unwrap();
1669            assert_eq!(
1670                app.module::<ServerError, StaticVer01>(api2, toml)
1671                    .unwrap()
1672                    .register()
1673                    .unwrap_err(),
1674                DispatchError::ConflictingModules { prefix, conflict }.into()
1675            );
1676        }
1677    }
1678
1679    #[async_std::test]
1680    async fn test_singleton_api() {
1681        setup_test();
1682
1683        // If there is only one API, it should be possible to register it with an empty prefix.
1684        let toml = toml! {
1685            [route.test]
1686            PATH = ["/test"]
1687        };
1688        let mut app = App::<_, ServerError>::with_state(());
1689        let mut api = app.module::<ServerError, StaticVer01>("", toml).unwrap();
1690        api.with_version("0.1.0".parse().unwrap())
1691            .get("test", |_, _| async move { Ok("response") }.boxed())
1692            .unwrap();
1693        api.register().unwrap();
1694
1695        let port = pick_unused_port().unwrap();
1696        spawn(app.serve(format!("0.0.0.0:{port}"), StaticVer01::instance()));
1697        let client = Client::new(format!("http://localhost:{port}").parse().unwrap()).await;
1698
1699        // Test an endpoint.
1700        let res = client.get("/test").send().await.unwrap();
1701        assert_eq!(
1702            res.status(),
1703            StatusCode::OK,
1704            "{}",
1705            res.text().await.unwrap()
1706        );
1707        assert_eq!(res.json::<String>().await.unwrap(), "response");
1708
1709        // Test healthcheck and version endpoints. Since these would ordinarily conflict with the
1710        // app-level healthcheck and version endpoints for an API with no prefix, we only get the
1711        // API-level endpoints, so that a singleton API behaves like a normal API, while app-level
1712        // stuff is reserved for non-trivial applications with more than one API.
1713        let res = client.get("/healthcheck").send().await.unwrap();
1714        assert_eq!(res.status(), StatusCode::OK);
1715        assert_eq!(
1716            res.json::<HealthStatus>().await.unwrap(),
1717            HealthStatus::Available
1718        );
1719
1720        let res = client.get("/version").send().await.unwrap();
1721        assert_eq!(res.status(), StatusCode::OK);
1722        assert_eq!(
1723            res.json::<ApiVersion>().await.unwrap(),
1724            ApiVersion {
1725                api_version: Some("0.1.0".parse().unwrap()),
1726                spec_version: "0.1.0".parse().unwrap(),
1727            },
1728        );
1729    }
1730
1731    #[async_std::test]
1732    async fn test_multi_segment() {
1733        setup_test();
1734
1735        let toml = toml! {
1736            [route.test]
1737            PATH = ["/test"]
1738        };
1739        let mut app = App::<_, ServerError>::with_state(());
1740
1741        for name in ["a", "b"] {
1742            let path = format!("api/{name}");
1743            let mut api = app
1744                .module::<ServerError, StaticVer01>(&path, toml.clone())
1745                .unwrap();
1746            api.with_version("0.1.0".parse().unwrap())
1747                .get("test", move |_, _| async move { Ok(name) }.boxed())
1748                .unwrap();
1749            api.register().unwrap();
1750        }
1751
1752        let port = pick_unused_port().unwrap();
1753        spawn(app.serve(format!("0.0.0.0:{port}"), StaticVer01::instance()));
1754        let client = Client::new(format!("http://localhost:{port}").parse().unwrap()).await;
1755
1756        for api in ["a", "b"] {
1757            tracing::info!(api, "testing api");
1758
1759            // Test an endpoint.
1760            let res = client.get(&format!("api/{api}/test")).send().await.unwrap();
1761            assert_eq!(res.status(), StatusCode::OK);
1762            assert_eq!(res.json::<String>().await.unwrap(), api);
1763
1764            // Test healthcheck.
1765            let res = client
1766                .get(&format!("api/{api}/healthcheck"))
1767                .send()
1768                .await
1769                .unwrap();
1770            assert_eq!(res.status(), StatusCode::OK);
1771            assert_eq!(
1772                res.json::<HealthStatus>().await.unwrap(),
1773                HealthStatus::Available
1774            );
1775
1776            // Test version.
1777            let res = client
1778                .get(&format!("api/{api}/version"))
1779                .send()
1780                .await
1781                .unwrap();
1782            assert_eq!(res.status(), StatusCode::OK);
1783            assert_eq!(
1784                res.json::<ApiVersion>().await.unwrap().api_version.unwrap(),
1785                "0.1.0".parse().unwrap()
1786            );
1787        }
1788
1789        // Test app-level healthcheck.
1790        let res = client.get("healthcheck").send().await.unwrap();
1791        assert_eq!(res.status(), StatusCode::OK);
1792        assert_eq!(
1793            res.json::<AppHealth>().await.unwrap(),
1794            AppHealth {
1795                status: HealthStatus::Available,
1796                modules: [
1797                    ("api/a".into(), [(0, StatusCode::OK)].into()),
1798                    ("api/b".into(), [(0, StatusCode::OK)].into()),
1799                ]
1800                .into()
1801            }
1802        );
1803
1804        // Test app-level version.
1805        let res = client.get("version").send().await.unwrap();
1806        assert_eq!(res.status(), StatusCode::OK);
1807        assert_eq!(
1808            res.json::<AppVersion>().await.unwrap().modules,
1809            [
1810                (
1811                    "api/a".into(),
1812                    vec![ApiVersion {
1813                        api_version: Some("0.1.0".parse().unwrap()),
1814                        spec_version: "0.1.0".parse().unwrap(),
1815                    }]
1816                ),
1817                (
1818                    "api/b".into(),
1819                    vec![ApiVersion {
1820                        api_version: Some("0.1.0".parse().unwrap()),
1821                        spec_version: "0.1.0".parse().unwrap(),
1822                    }]
1823                ),
1824            ]
1825            .into()
1826        );
1827    }
1828}