Add the ability to bound the response's size limit to simple-request

This commit is contained in:
Luke Parker
2025-09-11 17:24:47 -04:00
parent 02a5f15535
commit 4db78b1787
4 changed files with 55 additions and 19 deletions

View File

@@ -1,6 +1,6 @@
[package]
name = "simple-request"
version = "0.1.0"
version = "0.1.1"
description = "A simple HTTP(S) request library"
license = "MIT"
repository = "https://github.com/serai-dex/serai/tree/develop/common/simple-request"
@@ -21,6 +21,7 @@ tower-service = { version = "0.3", default-features = false }
hyper = { version = "1", default-features = false, features = ["http1", "client"] }
hyper-util = { version = "0.1", default-features = false, features = ["http1", "client-legacy", "tokio"] }
http-body-util = { version = "0.1", default-features = false }
futures-util = { version = "0.3", default-features = false, features = ["std"] }
tokio = { version = "1", default-features = false }
hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "ring", "rustls-native-certs", "native-tokio"], optional = true }

View File

@@ -97,7 +97,7 @@ impl Client {
pub async fn request<R: Into<Request>>(&self, request: R) -> Result<Response<'_>, Error> {
let request: Request = request.into();
let mut request = request.0;
let Request { mut request, response_size_limit } = request;
if let Some(header_host) = request.headers().get(hyper::header::HOST) {
match &self.connection {
Connection::ConnectionPool(_) => {}
@@ -153,11 +153,11 @@ impl Client {
let mut err = connection.ready().await.err();
if err.is_none() {
// Send the request
let res = connection.send_request(request).await;
if let Ok(res) = res {
return Ok(Response(res, self));
let response = connection.send_request(request).await;
if let Ok(response) = response {
return Ok(Response { response, size_limit: response_size_limit, client: self });
}
err = res.err();
err = response.err();
}
// Since this connection has been put into an error state, drop it
*connection_lock = None;
@@ -165,6 +165,6 @@ impl Client {
}
};
Ok(Response(response, self))
Ok(Response { response, size_limit: response_size_limit, client: self })
}
}

View File

@@ -7,11 +7,15 @@ pub use http_body_util::Full;
use crate::Error;
#[derive(Debug)]
pub struct Request(pub(crate) hyper::Request<Full<Bytes>>);
pub struct Request {
pub(crate) request: hyper::Request<Full<Bytes>>,
pub(crate) response_size_limit: Option<usize>,
}
impl Request {
#[cfg(feature = "basic-auth")]
fn username_password_from_uri(&self) -> Result<(String, String), Error> {
if let Some(authority) = self.0.uri().authority() {
if let Some(authority) = self.request.uri().authority() {
let authority = authority.as_str();
if authority.contains('@') {
// Decode the username and password from the URI
@@ -36,7 +40,7 @@ impl Request {
let mut formatted = format!("{username}:{password}");
let mut encoded = Base64::encode_string(formatted.as_bytes());
formatted.zeroize();
self.0.headers_mut().insert(
self.request.headers_mut().insert(
hyper::header::AUTHORIZATION,
HeaderValue::from_str(&format!("Basic {encoded}")).unwrap(),
);
@@ -59,9 +63,17 @@ impl Request {
pub fn with_basic_auth(&mut self) {
let _ = self.basic_auth_from_uri();
}
/// Set a size limit for the response.
///
/// This may be exceeded by a single HTTP frame and accordingly isn't perfect.
pub fn set_response_size_limit(&mut self, response_size_limit: Option<usize>) {
self.response_size_limit = response_size_limit;
}
}
impl From<hyper::Request<Full<Bytes>>> for Request {
fn from(request: hyper::Request<Full<Bytes>>) -> Request {
Request(request)
Request { request, response_size_limit: None }
}
}

View File

@@ -1,24 +1,47 @@
use std::io;
use hyper::{
StatusCode,
header::{HeaderValue, HeaderMap},
body::{Buf, Incoming},
body::Incoming,
};
use http_body_util::BodyExt;
use futures_util::{Stream, StreamExt};
use crate::{Client, Error};
// Borrows the client so its async task lives as long as this response exists.
#[allow(dead_code)]
#[derive(Debug)]
pub struct Response<'a>(pub(crate) hyper::Response<Incoming>, pub(crate) &'a Client);
pub struct Response<'a> {
pub(crate) response: hyper::Response<Incoming>,
pub(crate) size_limit: Option<usize>,
pub(crate) client: &'a Client,
}
impl Response<'_> {
pub fn status(&self) -> StatusCode {
self.0.status()
self.response.status()
}
pub fn headers(&self) -> &HeaderMap<HeaderValue> {
self.0.headers()
self.response.headers()
}
pub async fn body(self) -> Result<impl std::io::Read, Error> {
Ok(self.0.into_body().collect().await.map_err(Error::Hyper)?.aggregate().reader())
let mut body = self.response.into_body().into_data_stream();
let mut res: Vec<u8> = vec![];
loop {
if let Some(size_limit) = self.size_limit {
let (lower, upper) = body.size_hint();
if res.len().wrapping_add(upper.unwrap_or(lower)) > size_limit.min(usize::MAX - 1) {
Err(Error::ConnectionError("response exceeded size limit".into()))?;
}
}
let Some(part) = body.next().await else { break };
let part = part.map_err(Error::Hyper)?;
res.extend(part.as_ref());
}
Ok(io::Cursor::new(res))
}
}