Add the ability to bound the response's size limit to simple-request

This commit is contained in:
Luke Parker
2025-09-11 17:24:47 -04:00
parent d0f497dc68
commit befbbbfb84
5 changed files with 61 additions and 24 deletions

11
Cargo.lock generated
View File

@@ -2995,7 +2995,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.59.0", "windows-sys 0.60.2",
] ]
[[package]] [[package]]
@@ -6914,7 +6914,7 @@ dependencies = [
"once_cell", "once_cell",
"socket2 0.6.0", "socket2 0.6.0",
"tracing", "tracing",
"windows-sys 0.59.0", "windows-sys 0.60.2",
] ]
[[package]] [[package]]
@@ -7552,7 +7552,7 @@ dependencies = [
"errno", "errno",
"libc", "libc",
"linux-raw-sys", "linux-raw-sys",
"windows-sys 0.59.0", "windows-sys 0.60.2",
] ]
[[package]] [[package]]
@@ -9995,9 +9995,10 @@ dependencies = [
[[package]] [[package]]
name = "simple-request" name = "simple-request"
version = "0.1.0" version = "0.1.1"
dependencies = [ dependencies = [
"base64ct", "base64ct",
"futures-util",
"http-body-util", "http-body-util",
"hyper", "hyper",
"hyper-rustls", "hyper-rustls",
@@ -10931,7 +10932,7 @@ dependencies = [
"getrandom 0.3.3", "getrandom 0.3.3",
"once_cell", "once_cell",
"rustix", "rustix",
"windows-sys 0.59.0", "windows-sys 0.60.2",
] ]
[[package]] [[package]]

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "simple-request" name = "simple-request"
version = "0.1.0" version = "0.1.1"
description = "A simple HTTP(S) request library" description = "A simple HTTP(S) request library"
license = "MIT" license = "MIT"
repository = "https://github.com/serai-dex/serai/tree/develop/common/simple-request" repository = "https://github.com/serai-dex/serai/tree/develop/common/simple-request"
@@ -21,6 +21,7 @@ tower-service = { version = "0.3", default-features = false }
hyper = { version = "1", default-features = false, features = ["http1", "client"] } hyper = { version = "1", default-features = false, features = ["http1", "client"] }
hyper-util = { version = "0.1", default-features = false, features = ["http1", "client-legacy", "tokio"] } hyper-util = { version = "0.1", default-features = false, features = ["http1", "client-legacy", "tokio"] }
http-body-util = { version = "0.1", default-features = false } http-body-util = { version = "0.1", default-features = false }
futures-util = { version = "0.3", default-features = false, features = ["std"] }
tokio = { version = "1", default-features = false } tokio = { version = "1", default-features = false }
hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "ring", "rustls-native-certs", "native-tokio"], optional = true } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "ring", "rustls-native-certs", "native-tokio"], optional = true }

View File

@@ -97,7 +97,7 @@ impl Client {
pub async fn request<R: Into<Request>>(&self, request: R) -> Result<Response<'_>, Error> { pub async fn request<R: Into<Request>>(&self, request: R) -> Result<Response<'_>, Error> {
let request: Request = request.into(); let request: Request = request.into();
let mut request = request.0; let Request { mut request, response_size_limit } = request;
if let Some(header_host) = request.headers().get(hyper::header::HOST) { if let Some(header_host) = request.headers().get(hyper::header::HOST) {
match &self.connection { match &self.connection {
Connection::ConnectionPool(_) => {} Connection::ConnectionPool(_) => {}
@@ -153,11 +153,11 @@ impl Client {
let mut err = connection.ready().await.err(); let mut err = connection.ready().await.err();
if err.is_none() { if err.is_none() {
// Send the request // Send the request
let res = connection.send_request(request).await; let response = connection.send_request(request).await;
if let Ok(res) = res { if let Ok(response) = response {
return Ok(Response(res, self)); return Ok(Response { response, size_limit: response_size_limit, client: self });
} }
err = res.err(); err = response.err();
} }
// Since this connection has been put into an error state, drop it // Since this connection has been put into an error state, drop it
*connection_lock = None; *connection_lock = None;
@@ -165,6 +165,6 @@ impl Client {
} }
}; };
Ok(Response(response, self)) Ok(Response { response, size_limit: response_size_limit, client: self })
} }
} }

View File

@@ -7,11 +7,15 @@ pub use http_body_util::Full;
use crate::Error; use crate::Error;
#[derive(Debug)] #[derive(Debug)]
pub struct Request(pub(crate) hyper::Request<Full<Bytes>>); pub struct Request {
pub(crate) request: hyper::Request<Full<Bytes>>,
pub(crate) response_size_limit: Option<usize>,
}
impl Request { impl Request {
#[cfg(feature = "basic-auth")] #[cfg(feature = "basic-auth")]
fn username_password_from_uri(&self) -> Result<(String, String), Error> { fn username_password_from_uri(&self) -> Result<(String, String), Error> {
if let Some(authority) = self.0.uri().authority() { if let Some(authority) = self.request.uri().authority() {
let authority = authority.as_str(); let authority = authority.as_str();
if authority.contains('@') { if authority.contains('@') {
// Decode the username and password from the URI // Decode the username and password from the URI
@@ -36,7 +40,7 @@ impl Request {
let mut formatted = format!("{username}:{password}"); let mut formatted = format!("{username}:{password}");
let mut encoded = Base64::encode_string(formatted.as_bytes()); let mut encoded = Base64::encode_string(formatted.as_bytes());
formatted.zeroize(); formatted.zeroize();
self.0.headers_mut().insert( self.request.headers_mut().insert(
hyper::header::AUTHORIZATION, hyper::header::AUTHORIZATION,
HeaderValue::from_str(&format!("Basic {encoded}")).unwrap(), HeaderValue::from_str(&format!("Basic {encoded}")).unwrap(),
); );
@@ -59,9 +63,17 @@ impl Request {
pub fn with_basic_auth(&mut self) { pub fn with_basic_auth(&mut self) {
let _ = self.basic_auth_from_uri(); let _ = self.basic_auth_from_uri();
} }
}
impl From<hyper::Request<Full<Bytes>>> for Request { /// Set a size limit for the response.
fn from(request: hyper::Request<Full<Bytes>>) -> Request { ///
Request(request) /// This may be exceeded by a single HTTP frame and accordingly isn't perfect.
pub fn set_response_size_limit(&mut self, response_size_limit: Option<usize>) {
self.response_size_limit = response_size_limit;
}
}
impl From<hyper::Request<Full<Bytes>>> for Request {
fn from(request: hyper::Request<Full<Bytes>>) -> Request {
Request { request, response_size_limit: None }
} }
} }

View File

@@ -1,24 +1,47 @@
use std::io;
use hyper::{ use hyper::{
StatusCode, StatusCode,
header::{HeaderValue, HeaderMap}, header::{HeaderValue, HeaderMap},
body::{Buf, Incoming}, body::Incoming,
}; };
use http_body_util::BodyExt; use http_body_util::BodyExt;
use futures_util::{Stream, StreamExt};
use crate::{Client, Error}; use crate::{Client, Error};
// Borrows the client so its async task lives as long as this response exists. // Borrows the client so its async task lives as long as this response exists.
#[allow(dead_code)] #[allow(dead_code)]
#[derive(Debug)] #[derive(Debug)]
pub struct Response<'a>(pub(crate) hyper::Response<Incoming>, pub(crate) &'a Client); pub struct Response<'a> {
pub(crate) response: hyper::Response<Incoming>,
pub(crate) size_limit: Option<usize>,
pub(crate) client: &'a Client,
}
impl Response<'_> { impl Response<'_> {
pub fn status(&self) -> StatusCode { pub fn status(&self) -> StatusCode {
self.0.status() self.response.status()
} }
pub fn headers(&self) -> &HeaderMap<HeaderValue> { pub fn headers(&self) -> &HeaderMap<HeaderValue> {
self.0.headers() self.response.headers()
} }
pub async fn body(self) -> Result<impl std::io::Read, Error> { pub async fn body(self) -> Result<impl std::io::Read, Error> {
Ok(self.0.into_body().collect().await.map_err(Error::Hyper)?.aggregate().reader()) let mut body = self.response.into_body().into_data_stream();
let mut res: Vec<u8> = vec![];
loop {
if let Some(size_limit) = self.size_limit {
let (lower, upper) = body.size_hint();
if res.len().wrapping_add(upper.unwrap_or(lower)) > size_limit.min(usize::MAX - 1) {
Err(Error::ConnectionError("response exceeded size limit".into()))?;
}
}
let Some(part) = body.next().await else { break };
let part = part.map_err(Error::Hyper)?;
res.extend(part.as_ref());
}
Ok(io::Cursor::new(res))
} }
} }