diff --git a/common/request/Cargo.toml b/common/request/Cargo.toml index d960e91b..467d967a 100644 --- a/common/request/Cargo.toml +++ b/common/request/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-request" -version = "0.1.0" +version = "0.1.1" description = "A simple HTTP(S) request library" license = "MIT" repository = "https://github.com/serai-dex/serai/tree/develop/common/simple-request" @@ -21,6 +21,7 @@ tower-service = { version = "0.3", default-features = false } hyper = { version = "1", default-features = false, features = ["http1", "client"] } hyper-util = { version = "0.1", default-features = false, features = ["http1", "client-legacy", "tokio"] } http-body-util = { version = "0.1", default-features = false } +futures-util = { version = "0.3", default-features = false, features = ["std"] } tokio = { version = "1", default-features = false } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "ring", "rustls-native-certs", "native-tokio"], optional = true } diff --git a/common/request/src/lib.rs b/common/request/src/lib.rs index df9689e1..04b162aa 100644 --- a/common/request/src/lib.rs +++ b/common/request/src/lib.rs @@ -97,7 +97,7 @@ impl Client { pub async fn request>(&self, request: R) -> Result, Error> { let request: Request = request.into(); - let mut request = request.0; + let Request { mut request, response_size_limit } = request; if let Some(header_host) = request.headers().get(hyper::header::HOST) { match &self.connection { Connection::ConnectionPool(_) => {} @@ -153,11 +153,11 @@ impl Client { let mut err = connection.ready().await.err(); if err.is_none() { // Send the request - let res = connection.send_request(request).await; - if let Ok(res) = res { - return Ok(Response(res, self)); + let response = connection.send_request(request).await; + if let Ok(response) = response { + return Ok(Response { response, size_limit: response_size_limit, client: self }); } - err = res.err(); + err = response.err(); } // Since this connection has been put into an error state, drop it *connection_lock = None; @@ -165,6 +165,6 @@ impl Client { } }; - Ok(Response(response, self)) + Ok(Response { response, size_limit: response_size_limit, client: self }) } } diff --git a/common/request/src/request.rs b/common/request/src/request.rs index ced0d10b..64a10ea7 100644 --- a/common/request/src/request.rs +++ b/common/request/src/request.rs @@ -7,11 +7,15 @@ pub use http_body_util::Full; use crate::Error; #[derive(Debug)] -pub struct Request(pub(crate) hyper::Request>); +pub struct Request { + pub(crate) request: hyper::Request>, + pub(crate) response_size_limit: Option, +} + impl Request { #[cfg(feature = "basic-auth")] fn username_password_from_uri(&self) -> Result<(String, String), Error> { - if let Some(authority) = self.0.uri().authority() { + if let Some(authority) = self.request.uri().authority() { let authority = authority.as_str(); if authority.contains('@') { // Decode the username and password from the URI @@ -36,7 +40,7 @@ impl Request { let mut formatted = format!("{username}:{password}"); let mut encoded = Base64::encode_string(formatted.as_bytes()); formatted.zeroize(); - self.0.headers_mut().insert( + self.request.headers_mut().insert( hyper::header::AUTHORIZATION, HeaderValue::from_str(&format!("Basic {encoded}")).unwrap(), ); @@ -59,9 +63,17 @@ impl Request { pub fn with_basic_auth(&mut self) { let _ = self.basic_auth_from_uri(); } -} -impl From>> for Request { - fn from(request: hyper::Request>) -> Request { - Request(request) + + /// Set a size limit for the response. + /// + /// This may be exceeded by a single HTTP frame and accordingly isn't perfect. + pub fn set_response_size_limit(&mut self, response_size_limit: Option) { + self.response_size_limit = response_size_limit; + } +} + +impl From>> for Request { + fn from(request: hyper::Request>) -> Request { + Request { request, response_size_limit: None } } } diff --git a/common/request/src/response.rs b/common/request/src/response.rs index e4628f72..19e025cf 100644 --- a/common/request/src/response.rs +++ b/common/request/src/response.rs @@ -1,24 +1,47 @@ +use std::io; + use hyper::{ StatusCode, header::{HeaderValue, HeaderMap}, - body::{Buf, Incoming}, + body::Incoming, }; use http_body_util::BodyExt; +use futures_util::{Stream, StreamExt}; + use crate::{Client, Error}; // Borrows the client so its async task lives as long as this response exists. #[allow(dead_code)] #[derive(Debug)] -pub struct Response<'a>(pub(crate) hyper::Response, pub(crate) &'a Client); +pub struct Response<'a> { + pub(crate) response: hyper::Response, + pub(crate) size_limit: Option, + pub(crate) client: &'a Client, +} + impl Response<'_> { pub fn status(&self) -> StatusCode { - self.0.status() + self.response.status() } pub fn headers(&self) -> &HeaderMap { - self.0.headers() + self.response.headers() } pub async fn body(self) -> Result { - Ok(self.0.into_body().collect().await.map_err(Error::Hyper)?.aggregate().reader()) + let mut body = self.response.into_body().into_data_stream(); + let mut res: Vec = vec![]; + loop { + if let Some(size_limit) = self.size_limit { + let (lower, upper) = body.size_hint(); + if res.len().wrapping_add(upper.unwrap_or(lower)) > size_limit.min(usize::MAX - 1) { + Err(Error::ConnectionError("response exceeded size limit".into()))?; + } + } + + let Some(part) = body.next().await else { break }; + let part = part.map_err(Error::Hyper)?; + res.extend(part.as_ref()); + } + Ok(io::Cursor::new(res)) } }