tide_disco/api.rs
1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the tide-disco library.
3
4// You should have received a copy of the MIT License
5// along with the tide-disco library. If not, see <https://mit-license.org/>.
6
7use crate::{
8 healthcheck::{HealthCheck, HealthStatus},
9 method::{Method, ReadState, WriteState},
10 metrics::Metrics,
11 middleware::{error_handler, ErrorHandler},
12 request::RequestParams,
13 route::{self, *},
14 socket, Html,
15};
16use async_std::sync::Arc;
17use async_trait::async_trait;
18use derivative::Derivative;
19use futures::{
20 future::{BoxFuture, FutureExt},
21 stream::BoxStream,
22};
23use maud::{html, PreEscaped};
24use semver::Version;
25use serde::{de::DeserializeOwned, Deserialize, Serialize};
26use serde_with::{serde_as, DisplayFromStr};
27use snafu::{OptionExt, ResultExt, Snafu};
28use std::{
29 borrow::Cow,
30 collections::hash_map::{Entry, HashMap, IntoValues, Values},
31 convert::Infallible,
32 fmt::Display,
33 fs,
34 marker::PhantomData,
35 ops::Index,
36 path::{Path, PathBuf},
37};
38use tide::http::content::Accept;
39use vbs::version::StaticVersionType;
40
41/// An error encountered when parsing or constructing an [Api].
42#[derive(Clone, Debug, Snafu, PartialEq, Eq)]
43pub enum ApiError {
44 Route { source: RouteParseError },
45 ApiMustBeTable,
46 MissingRoutesTable,
47 RoutesMustBeTable,
48 UndefinedRoute,
49 HandlerAlreadyRegistered,
50 IncorrectMethod { expected: Method, actual: Method },
51 InvalidMetaTable { source: toml::de::Error },
52 MissingFormatVersion,
53 InvalidFormatVersion,
54 AmbiguousRoutes { route1: String, route2: String },
55 CannotReadToml { reason: String },
56}
57
/// Version information about an API.
///
/// Both versions are (de)serialized as strings, using the `Display` and `FromStr`
/// implementations of [Version].
#[serde_as]
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ApiVersion {
    /// The version of this API, or `None` if no version was set.
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub api_version: Option<Version>,

    /// The format version of the TOML specification used to load this API.
    #[serde_as(as = "DisplayFromStr")]
    pub spec_version: Version,
}
70
71/// Metadata used for describing and documenting an API.
72///
73/// [ApiMetadata] contains version information about the API, as well as optional HTML fragments to
74/// customize the formatting of automatically generated API documentation. Each of the supported
75/// HTML fragments is optional and will be filled in with a reasonable default if not provided. Some
76/// of the HTML fragments may contain "placeholders", which are identifiers enclosed in `{{ }}`,
77/// like `{{SOME_PLACEHOLDER}}`. These will be replaced by contextual information when the
78/// documentation is generated. The placeholders supported by each HTML fragment are documented
79/// below.
80#[serde_as]
81#[derive(Clone, Debug, Deserialize, Serialize)]
82#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
83pub struct ApiMetadata {
84 /// The name of this API.
85 ///
86 /// Note that the name of the API may be overridden if the API is registered with an app using
87 /// a different name.
88 #[serde(default = "meta_defaults::name")]
89 pub name: String,
90
91 /// A description of this API.
92 #[serde(default = "meta_defaults::description")]
93 pub description: String,
94
95 /// The version of the Tide Disco API specification format.
96 ///
97 /// If not specified, the version of this crate will be used.
98 #[serde_as(as = "DisplayFromStr")]
99 #[serde(default = "meta_defaults::format_version")]
100 pub format_version: Version,
101
102 /// HTML to be prepended to automatically generated documentation.
103 ///
104 /// # Placeholders
105 ///
106 /// * `NAME`: the name of the API
107 /// * `DESCRIPTION`: the description provided in `Cargo.toml`
108 /// * `VERSION`: the version of the API
109 /// * `FORMAT_VERSION`: the `FORMAT_VERSION` of the API
110 /// * `PUBLIC`: the URL where the public directory for this API is being served
111 #[serde(default = "meta_defaults::html_top")]
112 pub html_top: String,
113
114 /// HTML to be appended to automatically generated documentation.
115 #[serde(default = "meta_defaults::html_bottom")]
116 pub html_bottom: String,
117
118 /// The heading for documentation of a route.
119 ///
120 /// # Placeholders
121 ///
122 /// * `METHOD`: the method of the route
123 /// * `NAME`: the name of the route
124 #[serde(default = "meta_defaults::heading_entry")]
125 pub heading_entry: String,
126
127 /// The heading preceding documentation of all routes in this API.
128 #[serde(default = "meta_defaults::heading_routes")]
129 pub heading_routes: String,
130
131 /// The heading preceding documentation of route parameters.
132 #[serde(default = "meta_defaults::heading_parameters")]
133 pub heading_parameters: String,
134
135 /// The heading preceding documentation of a route description.
136 #[serde(default = "meta_defaults::heading_description")]
137 pub heading_description: String,
138
139 /// HTML formatting the path of a route.
140 ///
141 /// # Placeholders
142 ///
143 /// * `PATH`: the path being formatted
144 #[serde(default = "meta_defaults::route_path")]
145 pub route_path: String,
146
147 /// HTML preceding the contents of a table documenting the parameters of a route.
148 #[serde(default = "meta_defaults::parameter_table_open")]
149 pub parameter_table_open: String,
150
151 /// HTML closing a table documenting the parameters of a route.
152 #[serde(default = "meta_defaults::parameter_table_close")]
153 pub parameter_table_close: String,
154
155 /// HTML formatting an entry in a table documenting the parameters of a route.
156 ///
157 /// # Placeholders
158 ///
159 /// * `NAME`: the parameter being documented
160 /// * `TYPE`: the type of the parameter being documented
161 #[serde(default = "meta_defaults::parameter_row")]
162 pub parameter_row: String,
163
164 /// Documentation to insert in the parameters section of a route with no parameters.
165 #[serde(default = "meta_defaults::parameter_none")]
166 pub parameter_none: String,
167}
168
169impl Default for ApiMetadata {
170 fn default() -> Self {
171 // Deserialize an empty table, using the `serde` defaults for every field.
172 toml::Value::Table(Default::default()).try_into().unwrap()
173 }
174}
175
176mod meta_defaults {
177 use super::Version;
178
179 pub fn name() -> String {
180 "default-tide-disco-api".to_string()
181 }
182
183 pub fn description() -> String {
184 "Default Tide Disco API".to_string()
185 }
186
187 pub fn format_version() -> Version {
188 "0.1.0".parse().unwrap()
189 }
190
191 pub fn html_top() -> String {
192 "
193 <!DOCTYPE html>
194 <html lang='en'>
195 <head>
196 <meta charset='utf-8'>
197 <title>{{NAME}} Reference</title>
198 <link rel='stylesheet' href='{{PUBLIC}}/css/style.css'>
199 <script src='{{PUBLIC}}/js/script.js'></script>
200 <link rel='icon' type='image/svg+xml'
201 href='/public/favicon.svg'>
202 </head>
203 <body>
204 <div><a href='/'><img src='{{PUBLIC}}/espressosys_logo.svg'
205 alt='Espresso Systems Logo'
206 /></a></div>
207 <h1>{{NAME}} API {{VERSION}} Reference</h1>
208 <p>{{SHORT_DESCRIPTION}}</p><br/>
209 {{LONG_DESCRIPTION}}
210 "
211 .to_string()
212 }
213
214 pub fn html_bottom() -> String {
215 "
216 <h1> </h1>
217 <p>Copyright © 2022 Espresso Systems. All rights reserved.</p>
218 </body>
219 </html>
220 "
221 .to_string()
222 }
223
224 pub fn heading_entry() -> String {
225 "<a name='{{NAME}}'><h3 class='entry'><span class='meth'>{{METHOD}}</span> {{NAME}}</h3></a>\n".to_string()
226 }
227
228 pub fn heading_routes() -> String {
229 "<h3>Routes</h3>\n".to_string()
230 }
231 pub fn heading_parameters() -> String {
232 "<h3>Parameters</h3>\n".to_string()
233 }
234 pub fn heading_description() -> String {
235 "<h3>Description</h3>\n".to_string()
236 }
237
238 pub fn route_path() -> String {
239 "<p class='path'>{{PATH}}</p>\n".to_string()
240 }
241
242 pub fn parameter_table_open() -> String {
243 "<table>\n".to_string()
244 }
245 pub fn parameter_table_close() -> String {
246 "</table>\n\n".to_string()
247 }
248 pub fn parameter_row() -> String {
249 "<tr><td class='parameter'>{{NAME}}</td><td class='type'>{{TYPE}}</td></tr>\n".to_string()
250 }
251 pub fn parameter_none() -> String {
252 "<div class='meta'>None</div>".to_string()
253 }
254}
255
/// A description of an API.
///
/// An [Api] is a structured representation of an `api.toml` specification. It contains API-level
/// metadata and descriptions of all of the routes in the specification. It can be parsed from a
/// TOML file and registered as a module of an [App](crate::App).
///
/// The `VER` type parameter is the static binary serialization format version used by the route
/// handlers registered on this API.
#[derive(Derivative)]
// `bound = ""` so that `Debug` does not require `State`, `Error` or `VER` to implement `Debug`.
#[derivative(Debug(bound = ""))]
pub struct Api<State, Error, VER> {
    /// The version-erased contents of the API.
    inner: ApiInner<State, Error>,
    /// Marker for the static format version; no value of type `VER` is stored.
    _version: PhantomData<VER>,
}
267
/// A version-erased description of an API.
///
/// This type contains all the details of the API, with the version of the binary serialization
/// format type-erased and encapsulated into the route handlers. This type is used internally by
/// [`App`](crate::App), to allow dynamic registration of different versions of APIs with different
/// versions of the binary format.
///
/// It is exposed publicly and manipulated _only_ via [`Api`], which wraps this type with a static
/// format version type parameter, which provides compile-time enforcement of format version
/// consistency as the API is being constructed, until it is registered with an
/// [`App`](crate::App) and type-erased.
#[derive(Derivative)]
#[derivative(Debug(bound = ""))]
pub(crate) struct ApiInner<State, Error> {
    /// Metadata from the specification's `meta` table, shared with every route.
    meta: Arc<ApiMetadata>,
    /// The name of the API. Initialized from `meta.name`, but can be overridden with `set_name`.
    name: String,
    /// All routes in the specification, keyed by route name.
    routes: HashMap<String, Route<State, Error>>,
    /// Names of the routes registered under each path pattern.
    routes_by_path: HashMap<String, Vec<String>>,
    /// Handler invoked by `health` to report the health status of the server.
    #[derivative(Debug = "ignore")]
    health_check: HealthCheckHandler<State>,
    /// The version set with `Api::with_version`, if any.
    api_version: Option<Version>,
    /// Error handler encapsulating the serialization format version for errors.
    ///
    /// This field is optional so it can be bound late, potentially after a `map_err` changes the
    /// error type. However, it will always be set after `Api::into_inner` is called.
    #[derivative(Debug = "ignore")]
    error_handler: Option<Arc<dyn ErrorHandler<Error>>>,
    /// Response handler encapsulating the serialization format version for version requests.
    #[derivative(Debug = "ignore")]
    version_handler: Arc<dyn VersionHandler>,
    /// Directory of static files set with `Api::with_public`, if any.
    public: Option<PathBuf>,
    /// The first Markdown block of the description, rendered to HTML without its enclosing `<p>`.
    short_description: String,
    /// The remaining Markdown blocks of the description, rendered to HTML.
    long_description: String,
}
302
303pub(crate) trait VersionHandler:
304 Send + Sync + Fn(&Accept, ApiVersion) -> Result<tide::Response, RouteError<Infallible>>
305{
306}
307impl<F> VersionHandler for F where
308 F: Send + Sync + Fn(&Accept, ApiVersion) -> Result<tide::Response, RouteError<Infallible>>
309{
310}
311
312impl<'a, State, Error> IntoIterator for &'a ApiInner<State, Error> {
313 type Item = &'a Route<State, Error>;
314 type IntoIter = Values<'a, String, Route<State, Error>>;
315
316 fn into_iter(self) -> Self::IntoIter {
317 self.routes.values()
318 }
319}
320
321impl<State, Error> IntoIterator for ApiInner<State, Error> {
322 type Item = Route<State, Error>;
323 type IntoIter = IntoValues<String, Route<State, Error>>;
324
325 fn into_iter(self) -> Self::IntoIter {
326 self.routes.into_values()
327 }
328}
329
330impl<State, Error> Index<&str> for ApiInner<State, Error> {
331 type Output = Route<State, Error>;
332
333 fn index(&self, index: &str) -> &Route<State, Error> {
334 &self.routes[index]
335 }
336}
337
/// Iterator for [routes_by_path](ApiInner::routes_by_path).
///
/// This type iterates over all of the routes that have a given path.
/// [routes_by_path](ApiInner::routes_by_path), in turn, returns an iterator over paths whose items
/// contain a [RoutesWithPath] iterator.
pub(crate) struct RoutesWithPath<'a, State, Error> {
    /// Names of the remaining routes registered under this path.
    routes: std::slice::Iter<'a, String>,
    /// The API that owns the routes, used to resolve each name to its [Route].
    api: &'a ApiInner<State, Error>,
}
347
348impl<'a, State, Error> Iterator for RoutesWithPath<'a, State, Error> {
349 type Item = &'a Route<State, Error>;
350
351 fn next(&mut self) -> Option<Self::Item> {
352 Some(&self.api.routes[self.routes.next()?])
353 }
354}
355
356impl<State, Error> ApiInner<State, Error> {
357 /// Iterate over groups of routes with the same path.
358 pub(crate) fn routes_by_path(
359 &self,
360 ) -> impl Iterator<Item = (&str, RoutesWithPath<'_, State, Error>)> {
361 self.routes_by_path.iter().map(|(path, routes)| {
362 (
363 path.as_str(),
364 RoutesWithPath {
365 routes: routes.iter(),
366 api: self,
367 },
368 )
369 })
370 }
371
372 /// Check the health status of a server with the given state.
373 pub(crate) async fn health(&self, req: RequestParams, state: &State) -> tide::Response {
374 (self.health_check)(req, state).await
375 }
376
377 /// Get the version of this API.
378 pub(crate) fn version(&self) -> ApiVersion {
379 ApiVersion {
380 api_version: self.api_version.clone(),
381 spec_version: self.meta.format_version.clone(),
382 }
383 }
384
385 pub(crate) fn public(&self) -> Option<&PathBuf> {
386 self.public.as_ref()
387 }
388
389 pub(crate) fn set_name(&mut self, name: String) {
390 self.name = name;
391 }
392
393 /// Compose an HTML page documenting all the routes in this API.
394 pub(crate) fn documentation(&self) -> Html {
395 html! {
396 (PreEscaped(self.meta.html_top
397 .replace("{{NAME}}", &self.name)
398 .replace("{{SHORT_DESCRIPTION}}", &self.short_description)
399 .replace("{{LONG_DESCRIPTION}}", &self.long_description)
400 .replace("{{VERSION}}", &match &self.api_version {
401 Some(version) => version.to_string(),
402 None => "(no version)".to_string(),
403 })
404 .replace("{{FORMAT_VERSION}}", &self.meta.format_version.to_string())
405 .replace("{{PUBLIC}}", &format!("/public/{}", self.name))))
406 @for route in self.routes.values() {
407 (route.documentation())
408 }
409 (PreEscaped(&self.meta.html_bottom))
410 }
411 }
412
413 /// The short description of this API from the specification.
414 pub(crate) fn short_description(&self) -> &str {
415 &self.short_description
416 }
417
418 pub(crate) fn error_handler(&self) -> Arc<dyn ErrorHandler<Error>> {
419 self.error_handler.clone().unwrap()
420 }
421
422 pub(crate) fn version_handler(&self) -> Arc<dyn VersionHandler> {
423 self.version_handler.clone()
424 }
425}
426
427impl<State, Error, VER> Api<State, Error, VER>
428where
429 State: 'static,
430 Error: 'static,
431 VER: StaticVersionType + 'static,
432{
433 /// Parse an API from a TOML specification.
434 pub fn new(api: impl Into<toml::Value>) -> Result<Self, ApiError> {
435 let mut api = api.into();
436 let meta = match api
437 .as_table_mut()
438 .context(ApiMustBeTableSnafu)?
439 .remove("meta")
440 {
441 Some(meta) => toml::Value::try_into(meta)
442 .map_err(|source| ApiError::InvalidMetaTable { source })?,
443 None => ApiMetadata::default(),
444 };
445 let meta = Arc::new(meta);
446 let routes = match api.get("route") {
447 Some(routes) => routes.as_table().context(RoutesMustBeTableSnafu)?,
448 None => return Err(ApiError::MissingRoutesTable),
449 };
450 // Collect routes into a [HashMap] indexed by route name.
451 let routes = routes
452 .into_iter()
453 .map(|(name, spec)| {
454 let route = Route::new(name.clone(), spec, meta.clone()).context(RouteSnafu)?;
455 Ok((route.name(), route))
456 })
457 .collect::<Result<HashMap<_, _>, _>>()?;
458 // Collect routes into groups of route names indexed by route pattern.
459 let mut routes_by_path = HashMap::new();
460 for route in routes.values() {
461 for path in route.patterns() {
462 match routes_by_path.entry(path.clone()) {
463 Entry::Vacant(e) => e.insert(Vec::new()).push(route.name().clone()),
464 Entry::Occupied(mut e) => {
465 // If there is already a route with this path and method, then dispatch is
466 // ambiguous.
467 if let Some(ambiguous_name) = e
468 .get()
469 .iter()
470 .find(|name| routes[*name].method() == route.method())
471 {
472 return Err(ApiError::AmbiguousRoutes {
473 route1: route.name(),
474 route2: ambiguous_name.clone(),
475 });
476 }
477 e.get_mut().push(route.name());
478 }
479 }
480 }
481 }
482
483 // Parse description: the first line is a short description, to display when briefly
484 // describing this API in a list. The rest is the long description, to display on this API's
485 // own documentation page. Both are rendered to HTML via Markdown.
486 let blocks = markdown::tokenize(&meta.description);
487 let (short_description, long_description) = match blocks.split_first() {
488 Some((short, long)) => {
489 let render = |blocks| markdown::to_html(&markdown::generate_markdown(blocks));
490
491 let short = render(vec![short.clone()]);
492 let long = render(long.to_vec());
493
494 // The short description is only one block, and sometimes we would like to display
495 // it inline (as a `span`). Markdown automatically wraps blocks in `<p>`. We will
496 // strip this outer tag so that we can wrap it in either `<p>` or `<span>`,
497 // depending on the context.
498 let short = short.strip_prefix("<p>").unwrap_or(&short);
499 let short = short.strip_suffix("</p>").unwrap_or(short);
500 let short = short.to_string();
501
502 (short, long)
503 }
504 None => Default::default(),
505 };
506
507 Ok(Self {
508 inner: ApiInner {
509 name: meta.name.clone(),
510 meta,
511 routes,
512 routes_by_path,
513 health_check: Box::new(Self::default_health_check),
514 api_version: None,
515 error_handler: None,
516 version_handler: Arc::new(|accept, version| {
517 respond_with(accept, version, VER::instance())
518 }),
519 public: None,
520 short_description,
521 long_description,
522 },
523 _version: Default::default(),
524 })
525 }
526
527 /// Create an [Api] by reading a TOML specification from a file.
528 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, ApiError> {
529 let bytes = fs::read(path).map_err(|err| ApiError::CannotReadToml {
530 reason: err.to_string(),
531 })?;
532 let string = std::str::from_utf8(&bytes).map_err(|err| ApiError::CannotReadToml {
533 reason: err.to_string(),
534 })?;
535 Self::new(toml::from_str::<toml::Value>(string).map_err(|err| {
536 ApiError::CannotReadToml {
537 reason: err.to_string(),
538 }
539 })?)
540 }
541
542 /// Set the API version.
543 ///
544 /// The version information will automatically be included in responses to `GET /version`. This
545 /// version can also be used to serve multiple major versions of the same API simultaneously,
546 /// under a version prefix. For more information, see
547 /// [App::register_module](crate::App::register_module).
548 ///
549 /// This is the version of the application or sub-application which this instance of [Api]
550 /// represents. The versioning corresponds to the API specification passed to [new](Api::new),
551 /// and may be different from the version of the Rust crate implementing the route handlers for
552 /// the API.
553 pub fn with_version(&mut self, version: Version) -> &mut Self {
554 self.inner.api_version = Some(version);
555 self
556 }
557
558 /// Serve the contents of `dir` at the URL `/public/{{NAME}}`.
559 pub fn with_public(&mut self, dir: PathBuf) -> &mut Self {
560 self.inner.public = Some(dir);
561 self
562 }
563
564 /// Register a handler for a route.
565 ///
566 /// When the server receives a request whose URL matches the pattern of the route `name`,
567 /// `handler` will be invoked with the parameters of the request and a reference to the current
568 /// state, and the result will be serialized into a response.
569 ///
570 /// # Examples
571 ///
572 /// A simple getter route for a state object.
573 ///
574 /// `api.toml`
575 ///
576 /// ```toml
577 /// [route.getstate]
578 /// PATH = ["/getstate"]
579 /// DOC = "Gets the current state."
580 /// ```
581 ///
582 /// ```
583 /// use futures::FutureExt;
584 /// # use tide_disco::Api;
585 /// # use vbs::version::StaticVersion;
586 ///
587 /// type State = u64;
588 /// type StaticVer01 = StaticVersion<0, 1>;
589 ///
590 /// # fn ex(api: &mut Api<State, (), StaticVer01>) {
591 /// api.at("getstate", |req, state| async { Ok(*state) }.boxed());
592 /// # }
593 /// ```
594 ///
595 /// A counter endpoint which increments a mutable state. Notice how we use `METHOD = "POST"` to
596 /// ensure that the HTTP method for this route is compatible with mutable access.
597 ///
598 /// `api.toml`
599 ///
600 /// ```toml
601 /// [route.increment]
602 /// PATH = ["/increment"]
603 /// METHOD = "POST"
604 /// DOC = "Increment the current state and return the new value."
605 /// ```
606 ///
607 /// ```
608 /// use async_std::sync::Mutex;
609 /// use futures::FutureExt;
610 /// # use tide_disco::Api;
611 /// # use vbs::version::StaticVersion;
612 ///
613 /// type State = Mutex<u64>;
614 /// type StaticVer01 = StaticVersion<0, 1>;
615 ///
616 /// # fn ex(api: &mut Api<State, (), StaticVer01>) {
617 /// api.at("increment", |req, state| async {
618 /// let mut guard = state.lock().await;
619 /// *guard += 1;
620 /// Ok(*guard)
621 /// }.boxed());
622 /// # }
623 /// ```
624 ///
625 /// # Warnings
626 /// The route will use the HTTP method specified in the TOML specification for the named route
627 /// (or GET if the method is not specified). Some HTTP methods imply constraints on mutability.
628 /// For example, GET routes must be "pure", and not mutate any server state. Violating this
629 /// constraint may lead to confusing and unpredictable behavior. If the `State` type has
630 /// interior mutability (for instance, [RwLock](async_std::sync::RwLock)) it is up to the
631 /// `handler` not to use the interior mutability if the HTTP method suggests it shouldn't.
632 ///
633 /// If you know the HTTP method when you are registering the route, we recommend you use the
634 /// safer versions of this function, which enforce the appropriate mutability constraints. For
635 /// example,
636 /// * [get](Self::get)
637 /// * [post](Self::post)
638 /// * [put](Self::put)
639 /// * [delete](Self::delete)
640 ///
641 /// # Errors
642 ///
643 /// If the route `name` does not exist in the API specification, or if the route already has a
644 /// handler registered, an error is returned. Note that all routes are initialized with a
645 /// default handler that echoes parameters and shows documentation, but this default handler can
646 /// replaced by this function without raising [ApiError::HandlerAlreadyRegistered].
647 ///
648 /// If the route `name` exists, but it is not an HTTP route (for example, `METHOD = "SOCKET"`
649 /// was used when defining the route in the API specification), [ApiError::IncorrectMethod] is
650 /// returned.
651 ///
652 /// # Limitations
653 ///
654 /// [Like many function parameters](crate#boxed-futures) in [tide_disco](crate), the
655 /// handler function is required to return a [BoxFuture].
656 pub fn at<F, T>(&mut self, name: &str, handler: F) -> Result<&mut Self, ApiError>
657 where
658 F: 'static + Send + Sync + Fn(RequestParams, &State) -> BoxFuture<'_, Result<T, Error>>,
659 T: Serialize,
660 State: 'static + Send + Sync,
661 VER: 'static + Send + Sync,
662 {
663 let route = self
664 .inner
665 .routes
666 .get_mut(name)
667 .ok_or(ApiError::UndefinedRoute)?;
668 if route.has_handler() {
669 return Err(ApiError::HandlerAlreadyRegistered);
670 }
671
672 if !route.method().is_http() {
673 return Err(ApiError::IncorrectMethod {
674 // Just pick any HTTP method as the expected method.
675 expected: Method::get(),
676 actual: route.method(),
677 });
678 }
679
680 // `set_fn_handler` only fails if the route is not an HTTP route; since we have already
681 // checked that it is, this cannot fail.
682 route
683 .set_fn_handler(handler, VER::instance())
684 .unwrap_or_else(|_| panic!("unexpected failure in set_fn_handler"));
685
686 Ok(self)
687 }
688
689 fn method_immutable<F, T>(
690 &mut self,
691 method: Method,
692 name: &str,
693 handler: F,
694 ) -> Result<&mut Self, ApiError>
695 where
696 F: 'static
697 + Send
698 + Sync
699 + Fn(RequestParams, &<State as ReadState>::State) -> BoxFuture<'_, Result<T, Error>>,
700 T: Serialize,
701 State: 'static + Send + Sync + ReadState,
702 VER: 'static + Send + Sync + StaticVersionType,
703 {
704 assert!(method.is_http() && !method.is_mutable());
705 let route = self
706 .inner
707 .routes
708 .get_mut(name)
709 .ok_or(ApiError::UndefinedRoute)?;
710 if route.method() != method {
711 return Err(ApiError::IncorrectMethod {
712 expected: method,
713 actual: route.method(),
714 });
715 }
716 if route.has_handler() {
717 return Err(ApiError::HandlerAlreadyRegistered);
718 }
719 // `set_handler` only fails if the route is not an HTTP route; since we have already checked
720 // that it is, this cannot fail.
721 route
722 .set_handler(ReadHandler::<_, VER>::from(handler))
723 .unwrap_or_else(|_| panic!("unexpected failure in set_handler"));
724 Ok(self)
725 }
726
727 /// Register a handler for a GET route.
728 ///
729 /// When the server receives a GET request whose URL matches the pattern of the route `name`,
730 /// `handler` will be invoked with the parameters of the request and immutable access to the
731 /// current state, and the result will be serialized into a response.
732 ///
733 /// The [ReadState] trait is used to acquire immutable access to the state, so the state
734 /// reference passed to `handler` is actually [`<State as ReadState>::State`](ReadState::State).
735 /// For example, if `State` is `RwLock<T>`, the lock will automatically be acquired for reading,
736 /// and the handler will be passed a `&T`.
737 ///
738 /// # Examples
739 ///
740 /// A simple getter route for a state object.
741 ///
742 /// `api.toml`
743 ///
744 /// ```toml
745 /// [route.getstate]
746 /// PATH = ["/getstate"]
747 /// DOC = "Gets the current state."
748 /// ```
749 ///
750 /// ```
751 /// use async_std::sync::RwLock;
752 /// use futures::FutureExt;
753 /// # use tide_disco::Api;
754 /// # use vbs::{Serializer, version::StaticVersion};
755 ///
756 /// type State = RwLock<u64>;
757 /// type StaticVer01 = StaticVersion<0, 1>;
758 ///
759 /// # fn ex(api: &mut Api<State, (), StaticVer01>) {
760 /// api.get("getstate", |req, state| async { Ok(*state) }.boxed());
761 /// # }
762 /// ```
763 ///
764 /// # Errors
765 ///
766 /// If the route `name` does not exist in the API specification, or if the route already has a
767 /// handler registered, an error is returned. Note that all routes are initialized with a
768 /// default handler that echoes parameters and shows documentation, but this default handler can
769 /// replaced by this function without raising [ApiError::HandlerAlreadyRegistered].
770 ///
771 /// If the route `name` exists, but the method is not GET (that is, `METHOD = "M"` was used in
772 /// the route definition in `api.toml`, with `M` other than `GET`) the error
773 /// [IncorrectMethod](ApiError::IncorrectMethod) is returned.
774 ///
775 /// # Limitations
776 ///
777 /// [Like many function parameters](crate#boxed-futures) in [tide_disco](crate), the
778 /// handler function is required to return a [BoxFuture].
779 pub fn get<F, T>(&mut self, name: &str, handler: F) -> Result<&mut Self, ApiError>
780 where
781 F: 'static
782 + Send
783 + Sync
784 + Fn(RequestParams, &<State as ReadState>::State) -> BoxFuture<'_, Result<T, Error>>,
785 T: Serialize,
786 State: 'static + Send + Sync + ReadState,
787 VER: 'static + Send + Sync,
788 {
789 self.method_immutable(Method::get(), name, handler)
790 }
791
792 fn method_mutable<F, T>(
793 &mut self,
794 method: Method,
795 name: &str,
796 handler: F,
797 ) -> Result<&mut Self, ApiError>
798 where
799 F: 'static
800 + Send
801 + Sync
802 + Fn(RequestParams, &mut <State as ReadState>::State) -> BoxFuture<'_, Result<T, Error>>,
803 T: Serialize,
804 State: 'static + Send + Sync + WriteState,
805 VER: 'static + Send + Sync,
806 {
807 assert!(method.is_http() && method.is_mutable());
808 let route = self
809 .inner
810 .routes
811 .get_mut(name)
812 .ok_or(ApiError::UndefinedRoute)?;
813 if route.method() != method {
814 return Err(ApiError::IncorrectMethod {
815 expected: method,
816 actual: route.method(),
817 });
818 }
819 if route.has_handler() {
820 return Err(ApiError::HandlerAlreadyRegistered);
821 }
822
823 // `set_handler` only fails if the route is not an HTTP route; since we have already checked
824 // that it is, this cannot fail.
825 route
826 .set_handler(WriteHandler::<_, VER>::from(handler))
827 .unwrap_or_else(|_| panic!("unexpected failure in set_handler"));
828 Ok(self)
829 }
830
831 /// Register a handler for a POST route.
832 ///
833 /// When the server receives a POST request whose URL matches the pattern of the route `name`,
834 /// `handler` will be invoked with the parameters of the request and exclusive, mutable access
835 /// to the current state, and the result will be serialized into a response.
836 ///
837 /// The [WriteState] trait is used to acquire mutable access to the state, so the state
838 /// reference passed to `handler` is actually [`<State as ReadState>::State`](ReadState::State).
839 /// For example, if `State` is `RwLock<T>`, the lock will automatically be acquired for writing,
840 /// and the handler will be passed a `&mut T`.
841 ///
842 /// # Examples
843 ///
844 /// A counter endpoint which increments the state and returns the new state.
845 ///
846 /// `api.toml`
847 ///
848 /// ```toml
849 /// [route.increment]
850 /// PATH = ["/increment"]
851 /// METHOD = "POST"
852 /// DOC = "Increment the current state and return the new value."
853 /// ```
854 ///
855 /// ```
856 /// use async_std::sync::RwLock;
857 /// use futures::FutureExt;
858 /// # use tide_disco::Api;
859 /// # use vbs::version::StaticVersion;
860 ///
861 /// type State = RwLock<u64>;
862 /// type StaticVer01 = StaticVersion<0, 1>;
863 ///
864 /// # fn ex(api: &mut Api<State, (), StaticVer01>) {
865 /// api.post("increment", |req, state| async {
866 /// *state += 1;
867 /// Ok(*state)
868 /// }.boxed());
869 /// # }
870 /// ```
871 ///
872 /// # Errors
873 ///
874 /// If the route `name` does not exist in the API specification, or if the route already has a
875 /// handler registered, an error is returned. Note that all routes are initialized with a
876 /// default handler that echoes parameters and shows documentation, but this default handler can
877 /// replaced by this function without raising [ApiError::HandlerAlreadyRegistered].
878 ///
879 /// If the route `name` exists, but the method is not POST (that is, `METHOD = "M"` was used in
880 /// the route definition in `api.toml`, with `M` other than `POST`) the error
881 /// [IncorrectMethod](ApiError::IncorrectMethod) is returned.
882 ///
883 /// # Limitations
884 ///
885 /// [Like many function parameters](crate#boxed-futures) in [tide_disco](crate), the
886 /// handler function is required to return a [BoxFuture].
887 pub fn post<F, T>(&mut self, name: &str, handler: F) -> Result<&mut Self, ApiError>
888 where
889 F: 'static
890 + Send
891 + Sync
892 + Fn(RequestParams, &mut <State as ReadState>::State) -> BoxFuture<'_, Result<T, Error>>,
893 T: Serialize,
894 State: 'static + Send + Sync + WriteState,
895 VER: 'static + Send + Sync,
896 {
897 self.method_mutable(Method::post(), name, handler)
898 }
899
900 /// Register a handler for a PUT route.
901 ///
902 /// When the server receives a PUT request whose URL matches the pattern of the route `name`,
903 /// `handler` will be invoked with the parameters of the request and exclusive, mutable access
904 /// to the current state, and the result will be serialized into a response.
905 ///
906 /// The [WriteState] trait is used to acquire mutable access to the state, so the state
907 /// reference passed to `handler` is actually [`<State as ReadState>::State`](ReadState::State).
908 /// For example, if `State` is `RwLock<T>`, the lock will automatically be acquired for writing,
909 /// and the handler will be passed a `&mut T`.
910 ///
911 /// # Examples
912 ///
913 /// An endpoint which replaces the current state with a new value.
914 ///
915 /// `api.toml`
916 ///
917 /// ```toml
918 /// [route.replace]
919 /// PATH = ["/replace/:new_state"]
920 /// METHOD = "PUT"
921 /// ":new_state" = "Integer"
922 /// DOC = "Set the state to `:new_state`."
923 /// ```
924 ///
925 /// ```
926 /// use async_std::sync::RwLock;
927 /// use futures::FutureExt;
928 /// # use tide_disco::Api;
929 /// # use vbs::version::StaticVersion;
930 ///
931 /// type State = RwLock<u64>;
932 /// type StaticVer01 = StaticVersion<0, 1>;
933 ///
934 /// # fn ex(api: &mut Api<State, tide_disco::RequestError, StaticVer01>) {
935 /// api.post("replace", |req, state| async move {
936 /// *state = req.integer_param("new_state")?;
937 /// Ok(())
938 /// }.boxed());
939 /// # }
940 /// ```
941 ///
942 /// # Errors
943 ///
944 /// If the route `name` does not exist in the API specification, or if the route already has a
945 /// handler registered, an error is returned. Note that all routes are initialized with a
946 /// default handler that echoes parameters and shows documentation, but this default handler can
947 /// replaced by this function without raising [ApiError::HandlerAlreadyRegistered].
948 ///
949 /// If the route `name` exists, but the method is not PUT (that is, `METHOD = "M"` was used in
950 /// the route definition in `api.toml`, with `M` other than `PUT`) the error
951 /// [IncorrectMethod](ApiError::IncorrectMethod) is returned.
952 ///
953 /// # Limitations
954 ///
955 /// [Like many function parameters](crate#boxed-futures) in [tide_disco](crate), the
956 /// handler function is required to return a [BoxFuture].
957 pub fn put<F, T>(&mut self, name: &str, handler: F) -> Result<&mut Self, ApiError>
958 where
959 F: 'static
960 + Send
961 + Sync
962 + Fn(RequestParams, &mut <State as ReadState>::State) -> BoxFuture<'_, Result<T, Error>>,
963 T: Serialize,
964 State: 'static + Send + Sync + WriteState,
965 VER: 'static + Send + Sync,
966 {
967 self.method_mutable(Method::put(), name, handler)
968 }
969
970 /// Register a handler for a DELETE route.
971 ///
972 /// When the server receives a DELETE request whose URL matches the pattern of the route `name`,
973 /// `handler` will be invoked with the parameters of the request and exclusive, mutable access
974 /// to the current state, and the result will be serialized into a response.
975 ///
976 /// The [WriteState] trait is used to acquire mutable access to the state, so the state
977 /// reference passed to `handler` is actually [`<State as ReadState>::State`](ReadState::State).
978 /// For example, if `State` is `RwLock<T>`, the lock will automatically be acquired for writing,
979 /// and the handler will be passed a `&mut T`.
980 ///
981 /// # Examples
982 ///
983 /// An endpoint which clears the current state.
984 ///
985 /// `api.toml`
986 ///
987 /// ```toml
988 /// [route.state]
989 /// PATH = ["/state"]
990 /// METHOD = "DELETE"
991 /// DOC = "Clear the state."
992 /// ```
993 ///
994 /// ```
995 /// use async_std::sync::RwLock;
996 /// use futures::FutureExt;
997 /// # use tide_disco::Api;
998 /// # use vbs::version::StaticVersion;
999 ///
1000 /// type State = RwLock<Option<u64>>;
1001 /// type StaticVer01 = StaticVersion<0, 1>;
1002 ///
1003 /// # fn ex(api: &mut Api<State, (), StaticVer01>) {
1004 /// api.delete("state", |req, state| async {
1005 /// *state = None;
1006 /// Ok(())
1007 /// }.boxed());
1008 /// # }
1009 /// ```
1010 ///
1011 /// # Errors
1012 ///
1013 /// If the route `name` does not exist in the API specification, or if the route already has a
1014 /// handler registered, an error is returned. Note that all routes are initialized with a
1015 /// default handler that echoes parameters and shows documentation, but this default handler can
1016 /// replaced by this function without raising [ApiError::HandlerAlreadyRegistered].
1017 ///
1018 /// If the route `name` exists, but the method is not DELETE (that is, `METHOD = "M"` was used
1019 /// in the route definition in `api.toml`, with `M` other than `DELETE`) the error
1020 /// [IncorrectMethod](ApiError::IncorrectMethod) is returned.
1021 ///
1022 /// # Limitations
1023 ///
1024 /// [Like many function parameters](crate#boxed-futures) in [tide_disco](crate), the
1025 /// handler function is required to return a [BoxFuture].
1026 pub fn delete<F, T>(&mut self, name: &str, handler: F) -> Result<&mut Self, ApiError>
1027 where
1028 F: 'static
1029 + Send
1030 + Sync
1031 + Fn(RequestParams, &mut <State as ReadState>::State) -> BoxFuture<'_, Result<T, Error>>,
1032 T: Serialize,
1033 State: 'static + Send + Sync + WriteState,
1034 VER: 'static + Send + Sync,
1035 {
1036 self.method_mutable(Method::delete(), name, handler)
1037 }
1038
1039 /// Register a handler for a SOCKET route.
1040 ///
1041 /// When the server receives any request whose URL matches the pattern for this route and which
1042 /// includes the WebSockets upgrade headers, the server will negotiate a protocol upgrade with
1043 /// the client, establishing a WebSockets connection, and then invoke `handler`. `handler` will
1044 /// be given the parameters of the request which initiated the connection and a reference to the
1045 /// application state, as well as a [Connection](socket::Connection) object which it can then
1046 /// use for asynchronous, bi-directional communication with the client.
1047 ///
1048 /// The server side of the connection will remain open as long as the future returned by
1049 /// `handler` is remains unresolved. The handler can terminate the connection by returning. If
1050 /// it returns an error, the error message will be included in the
1051 /// [CloseFrame](tide_websockets::tungstenite::protocol::CloseFrame) sent to the client when
1052 /// tearing down the connection.
1053 ///
1054 /// # Examples
1055 ///
1056 /// A socket endpoint which receives amounts from the client and returns a running sum.
1057 ///
1058 /// `api.toml`
1059 ///
1060 /// ```toml
1061 /// [route.sum]
1062 /// PATH = ["/sum"]
1063 /// METHOD = "SOCKET"
1064 /// DOC = "Stream a running sum."
1065 /// ```
1066 ///
1067 /// ```
1068 /// use futures::{FutureExt, SinkExt, StreamExt};
1069 /// use tide_disco::{error::ServerError, socket::Connection, Api};
1070 /// # use vbs::version::StaticVersion;
1071 ///
1072 /// # fn ex(api: &mut Api<(), ServerError, StaticVersion<0, 1>>) {
1073 /// api.socket("sum", |_req, mut conn: Connection<i32, i32, ServerError, StaticVersion<0, 1>>, _state| async move {
1074 /// let mut sum = 0;
1075 /// while let Some(amount) = conn.next().await {
1076 /// sum += amount?;
1077 /// conn.send(&sum).await?;
1078 /// }
1079 /// Ok(())
1080 /// }.boxed());
1081 /// # }
1082 /// ```
1083 ///
1084 /// In some cases, it may be desirable to handle messages to and from the client in separate
1085 /// tasks. There are two ways of doing this:
1086 ///
1087 /// ## Split the connection into separate stream and sink
1088 ///
1089 /// ```
1090 /// use async_std::task::spawn;
1091 /// use futures::{future::{join, FutureExt}, sink::SinkExt, stream::StreamExt};
1092 /// use tide_disco::{error::ServerError, socket::Connection, Api};
1093 /// # use vbs::version::StaticVersion;
1094 ///
1095 /// # fn ex(api: &mut Api<(), ServerError, StaticVersion<0, 1>>) {
1096 /// api.socket("endpoint", |_req, mut conn: Connection<i32, i32, ServerError, StaticVersion<0, 1>>, _state| async move {
1097 /// let (mut sink, mut stream) = conn.split();
1098 /// let recv = spawn(async move {
1099 /// while let Some(Ok(msg)) = stream.next().await {
1100 /// // Handle message from client.
1101 /// }
1102 /// });
1103 /// let send = spawn(async move {
1104 /// loop {
1105 /// let msg = // get message to send to client
1106 /// # 0;
1107 /// sink.send(msg).await;
1108 /// }
1109 /// });
1110 ///
1111 /// join(send, recv).await;
1112 /// Ok(())
1113 /// }.boxed());
1114 /// # }
1115 /// ```
1116 ///
1117 /// This approach requires messages to be sent to the client by value, consuming the message.
1118 /// This is because, if we were to use the `Sync<&ToClient>` implementation for `Connection`,
1119 /// the lifetime for `&ToClient` would be fixed after `split` is called, since the lifetime
1120 /// appears in the return type, `SplitSink<Connection<...>, &ToClient>`. Thus, this lifetime
1121 /// outlives any scoped local variables created after the `split` call, such as `msg` in the
1122 /// `loop`.
1123 ///
1124 /// If we want to use the message after sending it to the client, we would have to clone it,
1125 /// which may be inefficient or impossible. Thus, there is another approach:
1126 ///
1127 /// ## Clone the connection
1128 ///
1129 /// ```
1130 /// use async_std::task::spawn;
1131 /// use futures::{future::{join, FutureExt}, sink::SinkExt, stream::StreamExt};
1132 /// use tide_disco::{error::ServerError, socket::Connection, Api};
1133 /// # use vbs::version::StaticVersion;
1134 ///
1135 /// # fn ex(api: &mut Api<(), ServerError, StaticVersion<0, 1>>) {
1136 /// api.socket("endpoint", |_req, mut conn: Connection<i32, i32, ServerError, StaticVersion<0, 1>>, _state| async move {
1137 /// let recv = {
1138 /// let mut conn = conn.clone();
1139 /// spawn(async move {
1140 /// while let Some(Ok(msg)) = conn.next().await {
1141 /// // Handle message from client.
1142 /// }
1143 /// })
1144 /// };
1145 /// let send = spawn(async move {
1146 /// loop {
1147 /// let msg = // get message to send to client
1148 /// # 0;
1149 /// conn.send(&msg).await;
1150 /// // msg is still live at this point.
1151 /// drop(msg);
1152 /// }
1153 /// });
1154 ///
1155 /// join(send, recv).await;
1156 /// Ok(())
1157 /// }.boxed());
1158 /// # }
1159 /// ```
1160 ///
1161 /// Depending on the exact situation, this method may end up being more verbose than the
1162 /// previous example. But it allows us to retain the higher-ranked trait bound `conn: for<'a>
1163 /// Sink<&'a ToClient>` instead of fixing the lifetime, which can prevent an unnecessary clone
1164 /// in certain situations.
1165 ///
1166 /// # Errors
1167 ///
1168 /// If the route `name` does not exist in the API specification, or if the route already has a
1169 /// handler registered, an error is returned. Note that all routes are initialized with a
1170 /// default handler that echoes parameters and shows documentation, but this default handler can
1171 /// replaced by this function without raising [ApiError::HandlerAlreadyRegistered].
1172 ///
1173 /// If the route `name` exists, but the method is not SOCKET (that is, `METHOD = "M"` was used
1174 /// in the route definition in `api.toml`, with `M` other than `SOCKET`) the error
1175 /// [IncorrectMethod](ApiError::IncorrectMethod) is returned.
1176 ///
1177 /// # Limitations
1178 ///
1179 /// [Like many function parameters](crate#boxed-futures) in [tide_disco](crate), the
1180 /// handler function is required to return a [BoxFuture].
1181 pub fn socket<F, ToClient, FromClient>(
1182 &mut self,
1183 name: &str,
1184 handler: F,
1185 ) -> Result<&mut Self, ApiError>
1186 where
1187 F: 'static
1188 + Send
1189 + Sync
1190 + Fn(
1191 RequestParams,
1192 socket::Connection<ToClient, FromClient, Error, VER>,
1193 &State,
1194 ) -> BoxFuture<'_, Result<(), Error>>,
1195 ToClient: 'static + Serialize + ?Sized,
1196 FromClient: 'static + DeserializeOwned,
1197 State: 'static + Send + Sync,
1198 Error: 'static + Send + Display,
1199 {
1200 self.register_socket_handler(name, socket::handler(handler))
1201 }
1202
1203 /// Register a uni-directional handler for a SOCKET route.
1204 ///
1205 /// This function is very similar to [socket](Self::socket), but it permits the handler only to
1206 /// send messages to the client, not to receive messages back. As such, the handler does not
1207 /// take a [Connection](socket::Connection). Instead, it simply returns a stream of messages
1208 /// which are forwarded to the client as they are generated. If the stream ever yields an error,
1209 /// the error is propagated to the client and then the connection is closed.
1210 ///
1211 /// This function can be simpler to use than [socket](Self::socket) in case the handler does not
1212 /// need to receive messages from the client.
1213 pub fn stream<F, Msg>(&mut self, name: &str, handler: F) -> Result<&mut Self, ApiError>
1214 where
1215 F: 'static + Send + Sync + Fn(RequestParams, &State) -> BoxStream<Result<Msg, Error>>,
1216 Msg: 'static + Serialize + Send + Sync,
1217 State: 'static + Send + Sync,
1218 Error: 'static + Send + Display,
1219 VER: 'static + Send + Sync,
1220 {
1221 self.register_socket_handler(name, socket::stream_handler::<_, _, _, _, VER>(handler))
1222 }
1223
1224 fn register_socket_handler(
1225 &mut self,
1226 name: &str,
1227 handler: socket::Handler<State, Error>,
1228 ) -> Result<&mut Self, ApiError> {
1229 let route = self
1230 .inner
1231 .routes
1232 .get_mut(name)
1233 .ok_or(ApiError::UndefinedRoute)?;
1234 if route.method() != Method::Socket {
1235 return Err(ApiError::IncorrectMethod {
1236 expected: Method::Socket,
1237 actual: route.method(),
1238 });
1239 }
1240 if route.has_handler() {
1241 return Err(ApiError::HandlerAlreadyRegistered);
1242 }
1243
1244 // `set_handler` only fails if the route is not a socket route; since we have already
1245 // checked that it is, this cannot fail.
1246 route
1247 .set_socket_handler(handler)
1248 .unwrap_or_else(|_| panic!("unexpected failure in set_socket_handler"));
1249 Ok(self)
1250 }
1251
1252 /// Register a handler for a METRICS route.
1253 ///
1254 /// When the server receives any request whose URL matches the pattern for this route and whose
1255 /// headers indicate it is a request for metrics, the server will invoke this `handler` instead
1256 /// of the regular HTTP handler for the endpoint. Instead of returning a typed object to
1257 /// serialize, `handler` will return a [Metrics] object which will be serialized to plaintext
1258 /// using the Prometheus format.
1259 ///
1260 /// A request is considered a request for metrics, for the purpose of dispatching to this
1261 /// handler, if the method is GET and the `Accept` header specifies `text/plain` as a better
1262 /// response type than `application/json` and `application/octet-stream` (other Tide Disco
1263 /// handlers respond to the content types `application/json` or `application/octet-stream`). As
1264 /// a special case, a request with no `Accept` header or `Accept: *` will return metrics when
1265 /// there is a metrics route matching the request URL, since metrics are given priority over
1266 /// other content types when multiple routes match the URL.
1267 ///
1268 /// # Examples
1269 ///
1270 /// A metrics endpoint which keeps track of how many times it has been called.
1271 ///
1272 /// `api.toml`
1273 ///
1274 /// ```toml
1275 /// [route.metrics]
1276 /// PATH = ["/metrics"]
1277 /// METHOD = "METRICS"
1278 /// DOC = "Export Prometheus metrics."
1279 /// ```
1280 ///
1281 /// ```
1282 /// # use async_std::sync::Mutex;
1283 /// # use futures::FutureExt;
1284 /// # use tide_disco::{api::{Api, ApiError}, error::ServerError};
1285 /// # use std::borrow::Cow;
1286 /// # use vbs::version::StaticVersion;
1287 /// use prometheus::{Counter, Registry};
1288 ///
1289 /// struct State {
1290 /// counter: Counter,
1291 /// metrics: Registry,
1292 /// }
1293 /// type StaticVer01 = StaticVersion<0, 1>;
1294 ///
1295 /// # fn ex(_api: Api<Mutex<State>, ServerError, StaticVer01>) -> Result<(), ApiError> {
1296 /// let mut api: Api<Mutex<State>, ServerError, StaticVer01>;
1297 /// # api = _api;
1298 /// api.metrics("metrics", |_req, state| async move {
1299 /// state.counter.inc();
1300 /// Ok(Cow::Borrowed(&state.metrics))
1301 /// }.boxed())?;
1302 /// # Ok(())
1303 /// # }
1304 /// ```
1305 //
1306 /// # Errors
1307 ///
1308 /// If the route `name` does not exist in the API specification, or if the route already has a
1309 /// handler registered, an error is returned. Note that all routes are initialized with a
1310 /// default handler that echoes parameters and shows documentation, but this default handler can
1311 /// replaced by this function without raising [ApiError::HandlerAlreadyRegistered].
1312 ///
1313 /// If the route `name` exists, but the method is not METRICS (that is, `METHOD = "M"` was used
1314 /// in the route definition in `api.toml`, with `M` other than `METRICS`) the error
1315 /// [IncorrectMethod](ApiError::IncorrectMethod) is returned.
1316 ///
1317 /// # Limitations
1318 ///
1319 /// [Like many function parameters](crate#boxed-futures) in [tide_disco](crate), the
1320 /// handler function is required to return a [BoxFuture].
1321 pub fn metrics<F, T>(&mut self, name: &str, handler: F) -> Result<&mut Self, ApiError>
1322 where
1323 F: 'static
1324 + Send
1325 + Sync
1326 + Fn(RequestParams, &State::State) -> BoxFuture<Result<Cow<T>, Error>>,
1327 T: 'static + Clone + Metrics,
1328 State: 'static + Send + Sync + ReadState,
1329 Error: 'static,
1330 VER: 'static + Send + Sync,
1331 {
1332 let route = self
1333 .inner
1334 .routes
1335 .get_mut(name)
1336 .ok_or(ApiError::UndefinedRoute)?;
1337 if route.method() != Method::Metrics {
1338 return Err(ApiError::IncorrectMethod {
1339 expected: Method::Metrics,
1340 actual: route.method(),
1341 });
1342 }
1343 if route.has_handler() {
1344 return Err(ApiError::HandlerAlreadyRegistered);
1345 }
1346 // `set_metrics_handler` only fails if the route is not a metrics route; since we have
1347 // already checked that it is, this cannot fail.
1348 route
1349 .set_metrics_handler(handler)
1350 .unwrap_or_else(|_| panic!("unexpected failure in set_metrics_handler"));
1351 Ok(self)
1352 }
1353
1354 /// Set the health check handler for this API.
1355 ///
1356 /// This overrides the existing handler. If `health_check` has not yet been called, the default
1357 /// handler is one which simply returns `Health::default()`.
1358 pub fn with_health_check<H>(
1359 &mut self,
1360 handler: impl 'static + Send + Sync + Fn(&State) -> BoxFuture<H>,
1361 ) -> &mut Self
1362 where
1363 State: 'static + Send + Sync,
1364 H: 'static + HealthCheck,
1365 VER: 'static + Send + Sync,
1366 {
1367 self.inner.health_check = route::health_check_handler::<_, _, VER>(handler);
1368 self
1369 }
1370
1371 /// Create a new [Api] which is just like this one, except has a transformed `Error` type.
1372 pub(crate) fn map_err<Error2>(
1373 self,
1374 f: impl 'static + Clone + Send + Sync + Fn(Error) -> Error2,
1375 ) -> Api<State, Error2, VER>
1376 where
1377 Error: 'static + Send + Sync,
1378 Error2: 'static,
1379 State: 'static + Send + Sync,
1380 {
1381 Api {
1382 inner: ApiInner {
1383 meta: self.inner.meta,
1384 name: self.inner.name,
1385 routes: self
1386 .inner
1387 .routes
1388 .into_iter()
1389 .map(|(name, route)| (name, route.map_err(f.clone())))
1390 .collect(),
1391 routes_by_path: self.inner.routes_by_path,
1392 health_check: self.inner.health_check,
1393 api_version: self.inner.api_version,
1394 error_handler: None,
1395 version_handler: self.inner.version_handler,
1396 public: self.inner.public,
1397 short_description: self.inner.short_description,
1398 long_description: self.inner.long_description,
1399 },
1400 _version: Default::default(),
1401 }
1402 }
1403
1404 pub(crate) fn into_inner(mut self) -> ApiInner<State, Error>
1405 where
1406 Error: crate::Error,
1407 {
1408 // This `into_inner` finalizes the error type for the API. At this point, ensure
1409 // `error_handler` is set.
1410 self.inner.error_handler = Some(error_handler::<Error, VER>());
1411 self.inner
1412 }
1413
1414 fn default_health_check(req: RequestParams, _state: &State) -> BoxFuture<'_, tide::Response> {
1415 async move {
1416 // If there is no healthcheck handler registered, just return [HealthStatus::Available]
1417 // by default; after all, if this handler is getting hit at all, the service must be up.
1418 route::health_check_response::<_, VER>(
1419 &req.accept().unwrap_or_else(|_| {
1420 // The healthcheck endpoint is not allowed to fail, so just use the default
1421 // content type if we can't parse the Accept header.
1422 let mut accept = Accept::new();
1423 accept.set_wildcard(true);
1424 accept
1425 }),
1426 HealthStatus::Available,
1427 )
1428 }
1429 .boxed()
1430 }
1431}
1432
1433// `ReadHandler { handler }` essentially represents a handler function
1434// `move |req, state| async { state.read(|state| handler(req, state)).await.await }`. However, I
1435// cannot convince Rust that the future returned by this closure moves out of `req` while borrowing
1436// from `handler`, which is owned by the closure itself and thus outlives the closure body. This is
1437// partly due to the limitation where _all_ closure parameters must be captured either by value or
1438// by reference, and probably partly due to my lack of creativity. In any case, writing out the
1439// closure object and [Handler] implementation by hand seems to convince Rust that this code is
1440// memory safe.
struct ReadHandler<F, VER> {
    /// The wrapped user handler; it is called with shared (read) access to the state.
    handler: F,
    /// Marker for the version type `VER`, whose instance is passed to `response_from_result`;
    /// no value of it is stored.
    _version: PhantomData<VER>,
}

impl<F, VER> From<F> for ReadHandler<F, VER> {
    /// Wrap `handler`, tagging it with the version type `VER`.
    fn from(handler: F) -> Self {
        Self {
            handler,
            _version: PhantomData,
        }
    }
}
1454
1455#[async_trait]
1456impl<State, Error, F, R, VER> Handler<State, Error> for ReadHandler<F, VER>
1457where
1458 F: 'static
1459 + Send
1460 + Sync
1461 + Fn(RequestParams, &<State as ReadState>::State) -> BoxFuture<'_, Result<R, Error>>,
1462 R: Serialize,
1463 State: 'static + Send + Sync + ReadState,
1464 VER: 'static + Send + Sync + StaticVersionType,
1465{
1466 async fn handle(
1467 &self,
1468 req: RequestParams,
1469 state: &State,
1470 ) -> Result<tide::Response, RouteError<Error>> {
1471 let accept = req.accept()?;
1472 response_from_result(
1473 &accept,
1474 state.read(|state| (self.handler)(req, state)).await,
1475 VER::instance(),
1476 )
1477 }
1478}
1479
// A manual closure that serves the same purpose as [ReadHandler], but acquires the state for writing.
struct WriteHandler<F, VER> {
    /// The wrapped user handler; it is called with exclusive (write) access to the state.
    handler: F,
    /// Marker for the version type `VER`, whose instance is passed to `response_from_result`;
    /// no value of it is stored.
    _version: PhantomData<VER>,
}

impl<F, VER> From<F> for WriteHandler<F, VER> {
    /// Wrap `handler`, tagging it with the version type `VER`.
    fn from(handler: F) -> Self {
        Self {
            handler,
            _version: PhantomData,
        }
    }
}
1494
1495#[async_trait]
1496impl<State, Error, F, R, VER> Handler<State, Error> for WriteHandler<F, VER>
1497where
1498 F: 'static
1499 + Send
1500 + Sync
1501 + Fn(RequestParams, &mut <State as ReadState>::State) -> BoxFuture<'_, Result<R, Error>>,
1502 R: Serialize,
1503 State: 'static + Send + Sync + WriteState,
1504 VER: 'static + Send + Sync + StaticVersionType,
1505{
1506 async fn handle(
1507 &self,
1508 req: RequestParams,
1509 state: &State,
1510 ) -> Result<tide::Response, RouteError<Error>> {
1511 let accept = req.accept()?;
1512 response_from_result(
1513 &accept,
1514 state.write(|state| (self.handler)(req, state)).await,
1515 VER::instance(),
1516 )
1517 }
1518}
1519
// Tests which register socket, stream, health check, and metrics handlers on an API module, serve
// the app on a free local port, and exercise the endpoints with real clients.
#[cfg(test)]
mod test {
    use crate::{
        error::{Error, ServerError},
        healthcheck::HealthStatus,
        socket::Connection,
        testing::{setup_test, test_ws_client, test_ws_client_with_headers, Client},
        App, StatusCode, Url,
    };
    use async_std::{sync::RwLock, task::spawn};
    use async_tungstenite::{
        tungstenite::{http::header::*, protocol::frame::coding::CloseCode, protocol::Message},
        WebSocketStream,
    };
    use futures::{
        stream::{iter, once, repeat},
        AsyncRead, AsyncWrite, FutureExt, SinkExt, StreamExt,
    };
    use portpicker::pick_unused_port;
    use prometheus::{Counter, Registry};
    use std::borrow::Cow;
    use toml::toml;
    use vbs::{
        version::{StaticVersion, StaticVersionType},
        BinarySerializer, Serializer,
    };

    #[cfg(windows)]
    use async_tungstenite::tungstenite::Error as WsError;
    #[cfg(windows)]
    use std::io::ErrorKind;

    type StaticVer01 = StaticVersion<0, 1>;
    type SerializerV01 = Serializer<StaticVersion<0, 1>>;

    /// Assert that the server has closed `conn`: the next read must yield end of stream (or, on
    /// Windows only, a `ConnectionAborted` I/O error).
    async fn check_stream_closed<S>(mut conn: WebSocketStream<S>)
    where
        S: AsyncRead + AsyncWrite + Unpin,
    {
        let msg = conn.next().await;

        #[cfg(not(windows))]
        assert!(msg.is_none(), "{:?}", msg);

        // Windows doesn't handle shutdown very gracefully.
        #[cfg(windows)]
        match msg {
            None => {}
            Some(Err(WsError::Io(err))) if err.kind() == ErrorKind::ConnectionAborted => {}
            msg => panic!(
                "expected end of stream or ConnectionAborted error, got {:?}",
                msg
            ),
        }
    }

    /// Bi-directional socket routes: content negotiation on an echo route, a route that sends one
    /// message and returns, and a route whose handler fails immediately.
    #[async_std::test]
    async fn test_socket_endpoint() {
        setup_test();

        let mut app = App::<_, ServerError>::with_state(RwLock::new(()));
        let api_toml = toml! {
            [meta]
            FORMAT_VERSION = "0.1.0"

            [route.echo]
            PATH = ["/echo"]
            METHOD = "SOCKET"

            [route.once]
            PATH = ["/once"]
            METHOD = "SOCKET"

            [route.error]
            PATH = ["/error"]
            METHOD = "SOCKET"
        };
        // NOTE(review): presumably scoped so the borrow of `app` held by `api` ends before
        // `app.serve` below.
        {
            let mut api = app
                .module::<ServerError, StaticVer01>("mod", api_toml)
                .unwrap();
            api.socket(
                "echo",
                |_req, mut conn: Connection<String, String, _, StaticVer01>, _state| {
                    async move {
                        // Send every received message straight back to the client.
                        while let Some(msg) = conn.next().await {
                            conn.send(&msg?).await?;
                        }
                        Ok(())
                    }
                    .boxed()
                },
            )
            .unwrap()
            .socket(
                "once",
                |_req, mut conn: Connection<str, (), _, StaticVer01>, _state| {
                    async move {
                        conn.send("msg").boxed().await?;
                        Ok(())
                    }
                    .boxed()
                },
            )
            .unwrap()
            .socket(
                "error",
                |_req, _conn: Connection<(), (), _, StaticVer01>, _state| {
                    async move {
                        Err(ServerError::catch_all(
                            StatusCode::INTERNAL_SERVER_ERROR,
                            "an error message".to_string(),
                        ))
                    }
                    .boxed()
                },
            )
            .unwrap();
        }
        let port = pick_unused_port().unwrap();
        let url: Url = format!("http://localhost:{}", port).parse().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance()));

        // Create a client that accepts JSON messages.
        let mut conn = test_ws_client_with_headers(
            url.join("mod/echo").unwrap(),
            &[(ACCEPT, "application/json")],
        )
        .await;

        // Send a JSON message.
        conn.send(Message::Text(serde_json::to_string("hello").unwrap()))
            .await
            .unwrap();
        assert_eq!(
            conn.next().await.unwrap().unwrap(),
            Message::Text(serde_json::to_string("hello").unwrap())
        );

        // Send a binary message. The reply is still JSON text, because the reply encoding follows
        // the client's Accept header rather than the encoding of the incoming message.
        conn.send(Message::Binary(
            SerializerV01::serialize("goodbye").unwrap(),
        ))
        .await
        .unwrap();
        assert_eq!(
            conn.next().await.unwrap().unwrap(),
            Message::Text(serde_json::to_string("goodbye").unwrap())
        );

        // Create a client that accepts binary messages.
        let mut conn = test_ws_client_with_headers(
            url.join("mod/echo").unwrap(),
            &[(ACCEPT, "application/octet-stream")],
        )
        .await;

        // Send a JSON message; the reply comes back binary-encoded.
        conn.send(Message::Text(serde_json::to_string("hello").unwrap()))
            .await
            .unwrap();
        assert_eq!(
            conn.next().await.unwrap().unwrap(),
            Message::Binary(SerializerV01::serialize("hello").unwrap())
        );

        // Send a binary message.
        conn.send(Message::Binary(
            SerializerV01::serialize("goodbye").unwrap(),
        ))
        .await
        .unwrap();
        assert_eq!(
            conn.next().await.unwrap().unwrap(),
            Message::Binary(SerializerV01::serialize("goodbye").unwrap())
        );

        // Test a stream that exits normally: one message, then a close frame without a reason.
        let mut conn = test_ws_client(url.join("mod/once").unwrap()).await;
        assert_eq!(
            conn.next().await.unwrap().unwrap(),
            Message::Text(serde_json::to_string("msg").unwrap())
        );
        match conn.next().await.unwrap().unwrap() {
            Message::Close(None) => {}
            msg => panic!("expected normal close frame, got {:?}", msg),
        };
        check_stream_closed(conn).await;

        // Test a stream that errors: the close frame carries the error code and message.
        let mut conn = test_ws_client(url.join("mod/error").unwrap()).await;
        match conn.next().await.unwrap().unwrap() {
            Message::Close(Some(frame)) => {
                assert_eq!(frame.code, CloseCode::Error);
                assert_eq!(frame.reason, "Error 500: an error message");
            }
            msg => panic!("expected error close frame, got {:?}", msg),
        }
        check_stream_closed(conn).await;
    }

    /// Uni-directional stream routes: an infinite stream, a single-item stream, and a stream
    /// that yields only errors.
    #[async_std::test]
    async fn test_stream_endpoint() {
        setup_test();

        let mut app = App::<_, ServerError>::with_state(RwLock::new(()));
        let api_toml = toml! {
            [meta]
            FORMAT_VERSION = "0.1.0"

            [route.nat]
            PATH = ["/nat"]
            METHOD = "SOCKET"

            [route.once]
            PATH = ["/once"]
            METHOD = "SOCKET"

            [route.error]
            PATH = ["/error"]
            METHOD = "SOCKET"
        };
        {
            let mut api = app
                .module::<ServerError, StaticVer01>("mod", api_toml)
                .unwrap();
            api.stream("nat", |_req, _state| iter(0..).map(Ok).boxed())
                .unwrap()
                .stream("once", |_req, _state| once(async { Ok(0) }).boxed())
                .unwrap()
                .stream::<_, ()>("error", |_req, _state| {
                    // We intentionally return a stream that never terminates, to check that simply
                    // yielding an error causes the connection to terminate.
                    repeat(Err(ServerError::catch_all(
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "an error message".to_string(),
                    )))
                    .boxed()
                })
                .unwrap();
        }
        let port = pick_unused_port().unwrap();
        let url: Url = format!("http://localhost:{}", port).parse().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance()));

        // Consume the `nat` stream.
        let mut conn = test_ws_client(url.join("mod/nat").unwrap()).await;
        for i in 0..100 {
            assert_eq!(
                conn.next().await.unwrap().unwrap(),
                Message::Text(serde_json::to_string(&i).unwrap())
            );
        }

        // Test a finite stream.
        let mut conn = test_ws_client(url.join("mod/once").unwrap()).await;
        assert_eq!(
            conn.next().await.unwrap().unwrap(),
            Message::Text(serde_json::to_string(&0).unwrap())
        );
        match conn.next().await.unwrap().unwrap() {
            Message::Close(None) => {}
            msg => panic!("expected normal close frame, got {:?}", msg),
        }
        check_stream_closed(conn).await;

        // Test a stream that errors.
        let mut conn = test_ws_client(url.join("mod/error").unwrap()).await;
        match conn.next().await.unwrap().unwrap() {
            Message::Close(Some(frame)) => {
                assert_eq!(frame.code, CloseCode::Error);
                assert_eq!(frame.reason, "Error 500: an error message");
            }
            msg => panic!("expected error close frame, got {:?}", msg),
        }
        check_stream_closed(conn).await;
    }

    /// A custom health check which reports the status stored in the application state.
    #[async_std::test]
    async fn test_custom_healthcheck() {
        setup_test();

        let mut app = App::<_, ServerError>::with_state(HealthStatus::Available);
        let api_toml = toml! {
            [meta]
            FORMAT_VERSION = "0.1.0"

            [route.dummy]
            PATH = ["/dummy"]
        };
        {
            let mut api = app
                .module::<ServerError, StaticVer01>("mod", api_toml)
                .unwrap();
            api.with_health_check(|state| async move { *state }.boxed());
        }
        let port = pick_unused_port().unwrap();
        let url: Url = format!("http://localhost:{}", port).parse().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance()));
        let client = Client::new(url).await;

        // NOTE(review): `HealthStatus::Available` is also what the default health check reports,
        // so this assertion would pass even if the custom handler were ignored; a state holding a
        // different status would make the test discriminating.
        let res = client.get("/mod/healthcheck").send().await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.json::<HealthStatus>().await.unwrap(),
            HealthStatus::Available
        );
    }

    /// A METRICS route whose handler increments a Prometheus counter on every export.
    #[async_std::test]
    async fn test_metrics_endpoint() {
        setup_test();

        struct State {
            metrics: Registry,
            counter: Counter,
        }

        let counter = Counter::new(
            "counter",
            "count of how many times metrics have been exported",
        )
        .unwrap();
        let metrics = Registry::new();
        metrics.register(Box::new(counter.clone())).unwrap();
        let state = State { metrics, counter };

        let mut app = App::<_, ServerError>::with_state(RwLock::new(state));
        let api_toml = toml! {
            [meta]
            FORMAT_VERSION = "0.1.0"

            [route.metrics]
            PATH = ["/metrics"]
            METHOD = "METRICS"
        };
        {
            let mut api = app
                .module::<ServerError, StaticVer01>("mod", api_toml)
                .unwrap();
            api.metrics("metrics", |_req, state| {
                async move {
                    state.counter.inc();
                    Ok(Cow::Borrowed(&state.metrics))
                }
                .boxed()
            })
            .unwrap();
        }
        let port = pick_unused_port().unwrap();
        let url: Url = format!("http://localhost:{port}").parse().unwrap();
        spawn(app.serve(format!("0.0.0.0:{port}"), StaticVer01::instance()));
        let client = Client::new(url).await;

        // The handler increments the counter before exporting, so the i-th request (1-based)
        // observes `counter i` in the Prometheus text output.
        for i in 1..5 {
            tracing::info!("making metrics request {i}");
            let expected = format!("# HELP counter count of how many times metrics have been exported\n# TYPE counter counter\ncounter {i}\n");
            let res = client.get("mod/metrics").send().await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
            assert_eq!(res.text().await.unwrap(), expected);
        }
    }
}