use crate::method::Method;
use serde::{Deserialize, Serialize};
use snafu::{OptionExt, Snafu};
use std::any::type_name;
use std::collections::HashMap;
use std::fmt::Display;
use strum_macros::EnumString;
use tagged_base64::TaggedBase64;
use tide::http::{self, content::Accept, mime::Mime, Headers};
use vbs::{version::StaticVersionType, BinarySerializer, Serializer};
/// Errors that can arise while extracting parameters or body data from an HTTP request.
///
/// NOTE(review): this type derives `Serialize`/`Deserialize`, so variant names and order may be
/// part of its encoded form (e.g. for binary serializers) — confirm before reordering variants.
#[derive(Clone, Debug, Snafu, Deserialize, Serialize)]
pub enum RequestError {
/// A parameter looked up by name is not present in the request.
#[snafu(display("missing required parameter: {}", name))]
MissingParam { name: String },
/// A parameter is present, but its parsed type is not the one the accessor asked for.
#[snafu(display(
"incorrect parameter type: {} cannot be converted to {}",
actual,
expected
))]
IncorrectParamType {
actual: RequestParamType,
expected: RequestParamType,
},
/// An integer parameter does not fit in the requested integer type (named in `expected`).
#[snafu(display("value {} is too large for type {}", value, expected))]
IntegerOverflow { value: u128, expected: String },
/// The request body could not be deserialized as JSON.
#[snafu(display("Unable to deserialize from JSON"))]
Json,
/// The request body could not be deserialized with the binary serializer.
#[snafu(display("Unable to deserialize from binary"))]
Binary,
/// A `TaggedBase64` parameter could not be converted to the requested type.
#[snafu(display("Unable to deserialise from tagged base 64: {}", reason))]
TaggedBase64 { reason: String },
/// The `Content-Type` is missing or unsupported, or no acceptable response type exists.
#[snafu(display("Content type not specified or type not supported"))]
UnsupportedContentType,
/// An HTTP-level error, such as a malformed `Accept` header.
#[snafu(display("HTTP protocol error: {}", reason))]
Http { reason: String },
/// A parameter's string form could not be parsed as its declared type.
#[snafu(display("error parsing {} parameter: {}", param_type, reason))]
InvalidParam { param_type: String, reason: String },
/// A `TaggedBase64` value carries a different tag than expected.
/// NOTE(review): not constructed by the code visible here; presumably produced elsewhere.
#[snafu(display("unexpected tag in TaggedBase64: {} (expected {})", actual, expected))]
TagMismatch { actual: String, expected: String },
}
/// Parameters and body data extracted from an incoming request.
///
/// The body is read in full when the value is constructed (see `RequestParams::new`), so all
/// accessors below are synchronous.
#[derive(Clone, Debug)]
pub struct RequestParams {
/// The underlying HTTP request, used for the method, headers and remote address.
req: http::Request,
/// The complete request body, read at construction time.
post_data: Vec<u8>,
/// Route parameters that were present in the request, keyed by parameter name.
params: HashMap<String, RequestParamValue>,
}
impl RequestParams {
    /// Builds a `RequestParams` from a tide request and the route's formal parameters.
    ///
    /// The request body is read eagerly. Every formal parameter present in the request is
    /// parsed according to its declared [`RequestParamType`]. Formal parameters absent from
    /// the request are omitted; they are not an error here.
    ///
    /// # Errors
    ///
    /// * [`RequestError::Http`] if the request body cannot be read.
    /// * Any error from [`RequestParamValue::parse`] for a present parameter that fails to parse.
    pub(crate) async fn new<S>(
        mut req: tide::Request<S>,
        formal_params: &[RequestParam],
    ) -> Result<Self, RequestError> {
        // Reading the body is I/O on a client-controlled stream and can fail; report it as a
        // request error instead of panicking the handler.
        let post_data = req.body_bytes().await.map_err(|err| RequestError::Http {
            reason: err.to_string(),
        })?;
        let params: HashMap<String, RequestParamValue> = formal_params
            .iter()
            .filter_map(|param| {
                // `Ok(None)` (parameter absent) is dropped; values and errors are kept.
                RequestParamValue::new(&req, param)
                    .transpose()
                    .map(|parsed| parsed.map(|value| (param.name.clone(), value)))
            })
            .collect::<Result<_, _>>()?;
        Ok(Self {
            req: req.into(),
            post_data,
            params,
        })
    }

    /// Returns the HTTP method of the request.
    pub fn method(&self) -> Method {
        self.req.method().into()
    }

    /// Returns the request headers.
    pub fn headers(&self) -> &Headers {
        self.req.as_ref()
    }

    /// Returns the request's `Accept` header, parsed and sorted.
    ///
    /// A request without an `Accept` header is treated as accepting any content type.
    ///
    /// # Errors
    ///
    /// [`RequestError::Http`] if the `Accept` header is present but cannot be parsed.
    pub fn accept(&self) -> Result<Accept, RequestError> {
        Self::accept_from_headers(self.headers())
    }

    /// Parses the `Accept` header out of `headers`.
    ///
    /// If the header is present, its proposals are sorted with [`Accept::sort`]. If it is
    /// absent, an empty `Accept` with its wildcard flag set is returned.
    ///
    /// # Errors
    ///
    /// [`RequestError::Http`] if the header is present but cannot be parsed.
    pub(crate) fn accept_from_headers(
        headers: impl AsRef<Headers>,
    ) -> Result<Accept, RequestError> {
        let parsed = Accept::from_headers(headers).map_err(|err| RequestError::Http {
            reason: err.to_string(),
        })?;
        Ok(match parsed {
            Some(mut accept) => {
                accept.sort();
                accept
            }
            // No `Accept` header at all: accept any content type.
            None => {
                let mut accept = Accept::new();
                accept.set_wildcard(true);
                accept
            }
        })
    }

    /// Returns the remote address reported by the underlying `http::Request`, if any.
    pub fn remote(&self) -> Option<&str> {
        self.req.remote()
    }

    /// Returns the parameter `name`, whatever its type.
    ///
    /// # Errors
    ///
    /// [`RequestError::MissingParam`] if the request has no such parameter.
    pub fn param<Name>(&self, name: &Name) -> Result<&RequestParamValue, RequestError>
    where
        Name: ?Sized + Display,
    {
        self.opt_param(name).context(MissingParamSnafu {
            name: name.to_string(),
        })
    }

    /// Returns the parameter `name`, or `None` if the request has no such parameter.
    pub fn opt_param<Name>(&self, name: &Name) -> Option<&RequestParamValue>
    where
        Name: ?Sized + Display,
    {
        self.params.get(&name.to_string())
    }

    /// Returns the integer parameter `name` converted to `T`.
    ///
    /// Both `Integer` and `Hexadecimal` parameters are accepted.
    ///
    /// # Errors
    ///
    /// * [`RequestError::MissingParam`] if the parameter is absent.
    /// * [`RequestError::IncorrectParamType`] if it is not an integer parameter.
    /// * [`RequestError::IntegerOverflow`] if the value does not fit in `T`.
    pub fn integer_param<Name, T>(&self, name: &Name) -> Result<T, RequestError>
    where
        Name: ?Sized + Display,
        T: TryFrom<u128>,
    {
        self.opt_integer_param(name)?.context(MissingParamSnafu {
            name: name.to_string(),
        })
    }

    /// Like [`Self::integer_param`], but returns `Ok(None)` if the parameter is absent.
    pub fn opt_integer_param<Name, T>(&self, name: &Name) -> Result<Option<T>, RequestError>
    where
        Name: ?Sized + Display,
        T: TryFrom<u128>,
    {
        self.opt_param(name).map(|val| val.as_integer()).transpose()
    }

    /// Returns the boolean parameter `name`.
    ///
    /// # Errors
    ///
    /// * [`RequestError::MissingParam`] if the parameter is absent.
    /// * [`RequestError::IncorrectParamType`] if it is not a `Boolean` parameter.
    pub fn boolean_param<Name>(&self, name: &Name) -> Result<bool, RequestError>
    where
        Name: ?Sized + Display,
    {
        self.opt_boolean_param(name)?.context(MissingParamSnafu {
            name: name.to_string(),
        })
    }

    /// Like [`Self::boolean_param`], but returns `Ok(None)` if the parameter is absent.
    pub fn opt_boolean_param<Name>(&self, name: &Name) -> Result<Option<bool>, RequestError>
    where
        Name: ?Sized + Display,
    {
        self.opt_param(name).map(|val| val.as_boolean()).transpose()
    }

    /// Returns the `Literal` (string) parameter `name`.
    ///
    /// # Errors
    ///
    /// * [`RequestError::MissingParam`] if the parameter is absent.
    /// * [`RequestError::IncorrectParamType`] if it is not a `Literal` parameter.
    pub fn string_param<Name>(&self, name: &Name) -> Result<&str, RequestError>
    where
        Name: ?Sized + Display,
    {
        self.opt_string_param(name)?.context(MissingParamSnafu {
            name: name.to_string(),
        })
    }

    /// Like [`Self::string_param`], but returns `Ok(None)` if the parameter is absent.
    pub fn opt_string_param<Name>(&self, name: &Name) -> Result<Option<&str>, RequestError>
    where
        Name: ?Sized + Display,
    {
        self.opt_param(name).map(|val| val.as_string()).transpose()
    }

    /// Returns the `TaggedBase64` parameter `name`.
    ///
    /// # Errors
    ///
    /// * [`RequestError::MissingParam`] if the parameter is absent.
    /// * [`RequestError::IncorrectParamType`] if it is not a `TaggedBase64` parameter.
    pub fn tagged_base64_param<Name>(&self, name: &Name) -> Result<&TaggedBase64, RequestError>
    where
        Name: ?Sized + Display,
    {
        self.opt_tagged_base64_param(name)?
            .context(MissingParamSnafu {
                name: name.to_string(),
            })
    }

    /// Like [`Self::tagged_base64_param`], but returns `Ok(None)` if the parameter is absent.
    pub fn opt_tagged_base64_param<Name>(
        &self,
        name: &Name,
    ) -> Result<Option<&TaggedBase64>, RequestError>
    where
        Name: ?Sized + Display,
    {
        self.opt_param(name)
            .map(|val| val.as_tagged_base64())
            .transpose()
    }

    /// Returns the `TaggedBase64` parameter `name` converted to `T` via `TryFrom`.
    ///
    /// # Errors
    ///
    /// * [`RequestError::MissingParam`] if the parameter is absent.
    /// * [`RequestError::IncorrectParamType`] if it is not a `TaggedBase64` parameter.
    /// * [`RequestError::TaggedBase64`] if the conversion to `T` fails.
    pub fn blob_param<'a, Name, T>(&'a self, name: &Name) -> Result<T, RequestError>
    where
        Name: ?Sized + Display,
        T: TryFrom<&'a TaggedBase64>,
        <T as TryFrom<&'a TaggedBase64>>::Error: Display,
    {
        self.opt_blob_param(name)?.context(MissingParamSnafu {
            name: name.to_string(),
        })
    }

    /// Like [`Self::blob_param`], but returns `Ok(None)` if the parameter is absent.
    pub fn opt_blob_param<'a, Name, T>(&'a self, name: &Name) -> Result<Option<T>, RequestError>
    where
        Name: ?Sized + Display,
        T: TryFrom<&'a TaggedBase64>,
        <T as TryFrom<&'a TaggedBase64>>::Error: Display,
    {
        self.opt_param(name).map(|val| val.as_blob()).transpose()
    }

    /// Returns a copy of the request body.
    pub fn body_bytes(&self) -> Vec<u8> {
        self.post_data.clone()
    }

    /// Deserializes the request body as JSON.
    ///
    /// # Errors
    ///
    /// [`RequestError::Json`] if the body is not valid JSON for `T`.
    pub fn body_json<T>(&self) -> Result<T, RequestError>
    where
        T: serde::de::DeserializeOwned,
    {
        serde_json::from_slice(&self.post_data).map_err(|_| RequestError::Json)
    }

    /// Deserializes the request body according to its `Content-Type` header.
    ///
    /// An `application/json` body is decoded as JSON. An `application/octet-stream` body is
    /// decoded with the binary serializer for version `VER`. Media-type parameters such as
    /// `; charset=utf-8` are ignored, and the media type is compared case-insensitively.
    ///
    /// # Errors
    ///
    /// * [`RequestError::UnsupportedContentType`] if `Content-Type` is missing or is neither
    ///   of the types above.
    /// * [`RequestError::Json`] / [`RequestError::Binary`] if decoding fails.
    pub fn body_auto<T, VER: StaticVersionType>(&self, _: VER) -> Result<T, RequestError>
    where
        T: serde::de::DeserializeOwned,
    {
        let content_type = self
            .headers()
            .get("Content-Type")
            .ok_or(RequestError::UnsupportedContentType)?;
        // Media types are case-insensitive and may carry parameters
        // (e.g. `application/json; charset=utf-8`), so compare only the bare `type/subtype`.
        let essence = content_type
            .as_str()
            .split(';')
            .next()
            .unwrap_or("")
            .trim()
            .to_ascii_lowercase();
        match essence.as_str() {
            "application/json" => self.body_json(),
            "application/octet-stream" => Serializer::<VER>::deserialize(&self.post_data)
                .map_err(|_| RequestError::Binary),
            _ => Err(RequestError::UnsupportedContentType),
        }
    }
}
/// The parsed value of a request parameter.
///
/// Each variant corresponds to the [`RequestParamType`] of the same name
/// (see [`RequestParamValue::param_type`]).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RequestParamValue {
/// Value of a `Boolean` parameter.
Boolean(bool),
/// Value of a `Hexadecimal` parameter.
Hexadecimal(u128),
/// Value of an `Integer` parameter.
Integer(u128),
/// Value of a `TaggedBase64` parameter.
TaggedBase64(TaggedBase64),
/// Value of a `Literal` parameter: the raw string, unparsed.
Literal(String),
}
impl RequestParamValue {
pub fn new<S>(
req: &tide::Request<S>,
formal: &RequestParam,
) -> Result<Option<Self>, RequestError> {
if let Ok(param) = req.param(&formal.name) {
Self::parse(param, formal).map(Some)
} else {
Ok(None)
}
}
pub fn parse(s: &str, formal: &RequestParam) -> Result<Self, RequestError> {
match formal.param_type {
RequestParamType::Literal => Ok(RequestParamValue::Literal(s.to_string())),
RequestParamType::Boolean => Ok(RequestParamValue::Boolean(s.parse().map_err(
|err: std::str::ParseBoolError| RequestError::InvalidParam {
param_type: "Boolean".to_string(),
reason: err.to_string(),
},
)?)),
RequestParamType::Integer => Ok(RequestParamValue::Integer(s.parse().map_err(
|err: std::num::ParseIntError| RequestError::InvalidParam {
param_type: "Integer".to_string(),
reason: err.to_string(),
},
)?)),
RequestParamType::Hexadecimal => Ok(RequestParamValue::Hexadecimal(
s.parse()
.map_err(|err: std::num::ParseIntError| RequestError::InvalidParam {
param_type: "Hexadecimal".to_string(),
reason: err.to_string(),
})?,
)),
RequestParamType::TaggedBase64 => Ok(RequestParamValue::TaggedBase64(
TaggedBase64::parse(s).map_err(|err| RequestError::InvalidParam {
param_type: "TaggedBase64".to_string(),
reason: err.to_string(),
})?,
)),
}
}
pub fn param_type(&self) -> RequestParamType {
match self {
Self::Boolean(_) => RequestParamType::Boolean,
Self::Hexadecimal(_) => RequestParamType::Hexadecimal,
Self::Integer(_) => RequestParamType::Integer,
Self::TaggedBase64(_) => RequestParamType::TaggedBase64,
Self::Literal(_) => RequestParamType::Literal,
}
}
pub fn as_string(&self) -> Result<&str, RequestError> {
match self {
Self::Literal(s) => Ok(s),
_ => Err(RequestError::IncorrectParamType {
expected: RequestParamType::Literal,
actual: self.param_type(),
}),
}
}
pub fn as_integer<T: TryFrom<u128>>(&self) -> Result<T, RequestError> {
match self {
Self::Integer(x) | Self::Hexadecimal(x) => {
T::try_from(*x).map_err(|_| RequestError::IntegerOverflow {
value: *x,
expected: type_name::<T>().to_string(),
})
}
_ => Err(RequestError::IncorrectParamType {
expected: RequestParamType::Integer,
actual: self.param_type(),
}),
}
}
pub fn as_boolean(&self) -> Result<bool, RequestError> {
match self {
Self::Boolean(x) => Ok(*x),
_ => Err(RequestError::IncorrectParamType {
expected: RequestParamType::Boolean,
actual: self.param_type(),
}),
}
}
pub fn as_tagged_base64(&self) -> Result<&TaggedBase64, RequestError> {
match self {
Self::TaggedBase64(x) => Ok(x),
_ => Err(RequestError::IncorrectParamType {
expected: RequestParamType::TaggedBase64,
actual: self.param_type(),
}),
}
}
pub fn as_blob<'a, T>(&'a self) -> Result<T, RequestError>
where
T: TryFrom<&'a TaggedBase64>,
<T as TryFrom<&'a TaggedBase64>>::Error: Display,
{
let tb64 = self.as_tagged_base64()?;
tb64.try_into()
.map_err(
|err: <T as TryFrom<&'a TaggedBase64>>::Error| RequestError::TaggedBase64 {
reason: err.to_string(),
},
)
}
}
/// The declared type of a request parameter, which controls how its string form is parsed.
///
/// `strum_macros::Display` renders the variant name (used in `IncorrectParamType` messages),
/// and `EnumString` parses a variant from its name.
/// NOTE(review): the serde derives make variant names/order part of the serialized form —
/// confirm consumers before renaming or reordering.
#[derive(
Clone, Copy, Debug, EnumString, strum_macros::Display, Deserialize, Serialize, PartialEq, Eq,
)]
pub enum RequestParamType {
/// `true` / `false`.
Boolean,
/// An unsigned integer written in base 16.
Hexadecimal,
/// An unsigned integer written in base 10.
Integer,
/// A `TaggedBase64` string.
TaggedBase64,
/// An arbitrary string, kept verbatim.
Literal,
}
/// A formal parameter declaration: a parameter name and the type its value is parsed as.
#[derive(Clone, Debug)]
pub struct RequestParam {
/// Name used to look the parameter up in the request and to key it in `RequestParams`.
pub name: String,
/// How the parameter's string value is parsed.
pub param_type: RequestParamType,
}
/// Chooses the media type from `available` that best satisfies the client's `accept` header.
///
/// Proposals are examined in `accept`'s iteration order, and the first one that matches wins:
/// * a proposal whose basetype is `*` matches the first available type;
/// * `type/*` matches the first available type with the same basetype;
/// * a concrete `type/subtype` must match an available type exactly.
///
/// If no proposal matches, the first available type is returned when `accept` has its
/// wildcard flag set.
///
/// # Errors
///
/// [`RequestError::UnsupportedContentType`] if nothing matches, or if `available` is empty.
pub(crate) fn best_response_type(
    accept: &Accept,
    available: &[Mime],
) -> Result<Mime, RequestError> {
    // With nothing to offer there is nothing to negotiate; fail rather than index an
    // empty slice (the original `available[0]` panicked here).
    let fallback = available
        .first()
        .ok_or(RequestError::UnsupportedContentType)?;
    for proposed in accept.iter() {
        let found = if proposed.basetype() == "*" {
            Some(fallback)
        } else if proposed.subtype() == "*" {
            available
                .iter()
                .find(|mime| mime.basetype() == proposed.basetype())
        } else {
            available.iter().find(|mime| {
                mime.basetype() == proposed.basetype() && mime.subtype() == proposed.subtype()
            })
        };
        if let Some(mime) = found {
            return Ok(mime.clone());
        }
    }
    if accept.wildcard() {
        Ok(fallback.clone())
    } else {
        Err(RequestError::UnsupportedContentType)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use ark_serialize::*;
    use tagged_base64::tagged;

    /// A minimal GET request to back hand-built `RequestParams`.
    fn default_req() -> http::Request {
        http::Request::new(http::Method::Get, "http://localhost:12345")
    }

    /// Parses `val` as a parameter of type `ty`, panicking on failure.
    fn param(ty: RequestParamType, name: &str, val: &str) -> RequestParamValue {
        RequestParamValue::parse(
            val,
            &RequestParam {
                name: name.to_string(),
                param_type: ty,
            },
        )
        .unwrap()
    }

    /// Builds `RequestParams` with an empty body from already-parsed parameters.
    fn request_from_params(
        params: impl IntoIterator<Item = (String, RequestParamValue)>,
    ) -> RequestParams {
        RequestParams {
            req: default_req(),
            post_data: Default::default(),
            params: params.into_iter().collect(),
        }
    }

    // A tagged, canonically serializable type used to exercise the blob accessors.
    #[tagged("BLOB")]
    #[derive(Clone, Debug, PartialEq, Eq, CanonicalSerialize, CanonicalDeserialize)]
    struct Blob {
        data: String,
    }

    #[test]
    fn test_params() {
        let tb64 = TaggedBase64::new("TAG", &[0; 20]).unwrap();
        let blob = Blob {
            data: "blob".to_string(),
        };
        let string_param = param(RequestParamType::Literal, "string", "hello");
        let integer_param = param(RequestParamType::Integer, "integer", "42");
        let boolean_param = param(RequestParamType::Boolean, "boolean", "true");
        let tagged_base64_param = param(
            RequestParamType::TaggedBase64,
            "tagged_base64",
            &tb64.to_string(),
        );
        let blob_param = param(RequestParamType::TaggedBase64, "blob", &blob.to_string());
        let params = vec![
            ("string".to_string(), string_param.clone()),
            ("integer".to_string(), integer_param.clone()),
            ("boolean".to_string(), boolean_param.clone()),
            ("tagged_base64".to_string(), tagged_base64_param.clone()),
            ("blob".to_string(), blob_param.clone()),
        ];
        let req = request_from_params(params);

        // Untyped access.
        assert_eq!(*req.param("string").unwrap(), string_param);
        assert_eq!(*req.param("integer").unwrap(), integer_param);
        assert_eq!(*req.param("boolean").unwrap(), boolean_param);
        assert_eq!(*req.param("tagged_base64").unwrap(), tagged_base64_param);
        assert_eq!(*req.param("blob").unwrap(), blob_param);
        match req.param("nosuchparam").unwrap_err() {
            RequestError::MissingParam { name } if name == "nosuchparam" => {}
            err => panic!("expecting MissingParam {{ nosuchparam }}, got {:?}", err),
        }
        assert_eq!(*req.opt_param("string").unwrap(), string_param);
        assert_eq!(*req.opt_param("integer").unwrap(), integer_param);
        assert_eq!(*req.opt_param("boolean").unwrap(), boolean_param);
        assert_eq!(
            *req.opt_param("tagged_base64").unwrap(),
            tagged_base64_param
        );
        assert_eq!(*req.opt_param("blob").unwrap(), blob_param);
        assert_eq!(req.opt_param("nosuchparam"), None);

        // Typed access: string.
        assert_eq!(req.string_param("string").unwrap(), "hello");
        match req.string_param("integer").unwrap_err() {
            RequestError::IncorrectParamType { actual, expected }
                if actual == RequestParamType::Integer && expected == RequestParamType::Literal => {
            }
            err => panic!(
                "expecting IncorrectParamType {{ Integer, String }}, got {:?}",
                err
            ),
        }
        match req.string_param("nosuchparam").unwrap_err() {
            RequestError::MissingParam { name } if name == "nosuchparam" => {}
            err => panic!("expecting MissingParam {{ nosuchparam }}, got {:?}", err),
        };

        // Typed access: integer.
        assert_eq!(req.integer_param::<_, usize>("integer").unwrap(), 42);
        match req.integer_param::<_, usize>("string").unwrap_err() {
            RequestError::IncorrectParamType { actual, expected }
                if actual == RequestParamType::Literal && expected == RequestParamType::Integer => {
            }
            err => panic!(
                "expecting IncorrectParamType {{ Literal, Integer }}, got {:?}",
                err
            ),
        }
        match req.integer_param::<_, usize>("nosuchparam").unwrap_err() {
            RequestError::MissingParam { name } if name == "nosuchparam" => {}
            err => panic!("expecting MissingParam {{ nosuchparam }}, got {:?}", err),
        };

        // Typed access: boolean.
        assert!(req.boolean_param("boolean").unwrap());
        match req.boolean_param("integer").unwrap_err() {
            RequestError::IncorrectParamType { actual, expected }
                if actual == RequestParamType::Integer && expected == RequestParamType::Boolean => {
            }
            err => panic!(
                "expecting IncorrectParamType {{ Integer, Boolean }}, got {:?}",
                err
            ),
        }
        match req.boolean_param("nosuchparam").unwrap_err() {
            RequestError::MissingParam { name } if name == "nosuchparam" => {}
            err => panic!("expecting MissingParam {{ nosuchparam }}, got {:?}", err),
        };

        // Typed access: tagged base64.
        assert_eq!(*req.tagged_base64_param("tagged_base64").unwrap(), tb64);
        match req.tagged_base64_param("integer").unwrap_err() {
            RequestError::IncorrectParamType { actual, expected }
                if actual == RequestParamType::Integer
                    && expected == RequestParamType::TaggedBase64 => {}
            err => panic!(
                "expecting IncorrectParamType {{ Integer, TaggedBase64 }}, got {:?}",
                err
            ),
        }
        match req.tagged_base64_param("nosuchparam").unwrap_err() {
            RequestError::MissingParam { name } if name == "nosuchparam" => {}
            err => panic!("expecting MissingParam {{ nosuchparam }}, got {:?}", err),
        };

        // Typed access: blob (previously these error checks re-tested tagged_base64_param).
        assert_eq!(req.blob_param::<_, Blob>("blob").unwrap(), blob);
        match req.blob_param::<_, Blob>("integer").unwrap_err() {
            RequestError::IncorrectParamType { actual, expected }
                if actual == RequestParamType::Integer
                    && expected == RequestParamType::TaggedBase64 => {}
            err => panic!(
                "expecting IncorrectParamType {{ Integer, TaggedBase64 }}, got {:?}",
                err
            ),
        }
        match req.blob_param::<_, Blob>("nosuchparam").unwrap_err() {
            RequestError::MissingParam { name } if name == "nosuchparam" => {}
            err => panic!("expecting MissingParam {{ nosuchparam }}, got {:?}", err),
        };
        // A TaggedBase64 with a different tag is present but must not convert to a Blob.
        match req.blob_param::<_, Blob>("tagged_base64").unwrap_err() {
            RequestError::TaggedBase64 { .. } => {}
            err => panic!("expecting TaggedBase64 conversion error, got {:?}", err),
        };

        // Optional typed access: string.
        assert_eq!(req.opt_string_param("string").unwrap().unwrap(), "hello");
        match req.opt_string_param("integer").unwrap_err() {
            RequestError::IncorrectParamType { actual, expected }
                if actual == RequestParamType::Integer && expected == RequestParamType::Literal => {
            }
            err => panic!(
                "expecting IncorrectParamType {{ Integer, String }}, got {:?}",
                err
            ),
        }
        assert_eq!(req.opt_string_param("nosuchparam").unwrap(), None);

        // Optional typed access: integer.
        assert_eq!(
            req.opt_integer_param::<_, usize>("integer")
                .unwrap()
                .unwrap(),
            42
        );
        match req.opt_integer_param::<_, usize>("string").unwrap_err() {
            RequestError::IncorrectParamType { actual, expected }
                if actual == RequestParamType::Literal && expected == RequestParamType::Integer => {
            }
            err => panic!(
                "expecting IncorrectParamType {{ Literal, Integer }}, got {:?}",
                err
            ),
        }
        assert_eq!(
            req.opt_integer_param::<_, usize>("nosuchparam").unwrap(),
            None
        );

        // Optional typed access: boolean.
        assert!(req.opt_boolean_param("boolean").unwrap().unwrap());
        match req.opt_boolean_param("integer").unwrap_err() {
            RequestError::IncorrectParamType { actual, expected }
                if actual == RequestParamType::Integer && expected == RequestParamType::Boolean => {
            }
            err => panic!(
                "expecting IncorrectParamType {{ Integer, Boolean }}, got {:?}",
                err
            ),
        }
        assert_eq!(req.opt_boolean_param("nosuchparam").unwrap(), None);

        // Optional typed access: tagged base64.
        assert_eq!(
            *req.opt_tagged_base64_param("tagged_base64")
                .unwrap()
                .unwrap(),
            tb64
        );
        match req.opt_tagged_base64_param("integer").unwrap_err() {
            RequestError::IncorrectParamType { actual, expected }
                if actual == RequestParamType::Integer
                    && expected == RequestParamType::TaggedBase64 => {}
            err => panic!(
                "expecting IncorrectParamType {{ Integer, TaggedBase64 }}, got {:?}",
                err
            ),
        }
        assert_eq!(req.opt_tagged_base64_param("nosuchparam").unwrap(), None);

        // Optional typed access: blob.
        assert_eq!(
            req.opt_blob_param::<_, Blob>("blob").unwrap().unwrap(),
            blob
        );
        match req.opt_blob_param::<_, Blob>("integer").unwrap_err() {
            RequestError::IncorrectParamType { actual, expected }
                if actual == RequestParamType::Integer
                    && expected == RequestParamType::TaggedBase64 => {}
            err => panic!(
                "expecting IncorrectParamType {{ Integer, TaggedBase64 }}, got {:?}",
                err
            ),
        }
        assert_eq!(req.opt_blob_param::<_, Blob>("nosuchparam").unwrap(), None);
    }
}