tide_disco/
testing.rs

1#![cfg(any(test, feature = "testing"))]
2
3use crate::{http::Method, wait_for_server, Url, SERVER_STARTUP_RETRIES, SERVER_STARTUP_SLEEP_MS};
4use async_compatibility_layer::logging::{setup_backtrace, setup_logging};
5use async_tungstenite::{
6    async_std::{connect_async, ConnectStream},
7    tungstenite::{client::IntoClientRequest, http::header::*, Error as WsError},
8    WebSocketStream,
9};
10use reqwest::RequestBuilder;
11use std::time::Duration;
12
13pub struct Client {
14    inner: reqwest::Client,
15    base_url: Url,
16}
17
18impl Client {
19    pub async fn new(base_url: Url) -> Self {
20        wait_for_server(&base_url, SERVER_STARTUP_RETRIES, SERVER_STARTUP_SLEEP_MS).await;
21        Self {
22            inner: reqwest::Client::builder()
23                .timeout(Duration::from_secs(60))
24                .build()
25                .unwrap(),
26            base_url,
27        }
28    }
29
30    pub fn request(&self, method: Method, path: &str) -> RequestBuilder {
31        let req_method: reqwest::Method = method.to_string().parse().unwrap();
32        self.inner
33            .request(req_method, self.base_url.join(path).unwrap())
34    }
35
36    pub fn get(&self, path: &str) -> RequestBuilder {
37        self.request(Method::Get, path)
38    }
39
40    pub fn post(&self, path: &str) -> RequestBuilder {
41        self.request(Method::Post, path)
42    }
43}
44
45pub fn setup_test() {
46    setup_logging();
47    setup_backtrace();
48}
49
50pub async fn test_ws_client(url: Url) -> WebSocketStream<ConnectStream> {
51    test_ws_client_with_headers(url, &[]).await
52}
53
54pub async fn test_ws_client_with_headers(
55    mut url: Url,
56    headers: &[(HeaderName, &str)],
57) -> WebSocketStream<ConnectStream> {
58    wait_for_server(&url, SERVER_STARTUP_RETRIES, SERVER_STARTUP_SLEEP_MS).await;
59    url.set_scheme("ws").unwrap();
60
61    // Follow redirects.
62    loop {
63        let mut req = url.clone().into_client_request().unwrap();
64        for (name, value) in headers {
65            req.headers_mut().insert(name, value.parse().unwrap());
66        }
67
68        match connect_async(req).await {
69            Ok((conn, _)) => return conn,
70            Err(WsError::Http(res)) if (301..=308).contains(&u16::from(res.status())) => {
71                let location = res.headers()["location"].to_str().unwrap();
72                tracing::info!(from = %url, to = %location, "WS handshake following redirect");
73                url.set_path(location);
74            }
75            Err(err) => panic!("socket connection failed: {err}"),
76        }
77    }
78}