mirror of
https://github.com/serai-dex/serai.git
synced 2025-12-08 12:19:24 +00:00
Don't default to basic-auth if it's enabled, yet require it to be specified
This commit is contained in:
@@ -2,27 +2,19 @@
|
||||
#![doc = include_str!("../README.md")]
|
||||
|
||||
use hyper_rustls::{HttpsConnectorBuilder, HttpsConnector};
|
||||
use hyper::{
|
||||
StatusCode,
|
||||
header::{HeaderValue, HeaderMap},
|
||||
body::{Buf, Body},
|
||||
Response as HyperResponse,
|
||||
client::HttpConnector,
|
||||
};
|
||||
pub use hyper::{self, Request};
|
||||
use hyper::{header::HeaderValue, client::HttpConnector};
|
||||
pub use hyper;
|
||||
|
||||
mod request;
|
||||
pub use request::*;
|
||||
|
||||
mod response;
|
||||
pub use response::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Response(HyperResponse<Body>);
|
||||
impl Response {
|
||||
pub fn status(&self) -> StatusCode {
|
||||
self.0.status()
|
||||
}
|
||||
pub fn headers(&self) -> &HeaderMap<HeaderValue> {
|
||||
self.0.headers()
|
||||
}
|
||||
pub async fn body(self) -> Result<impl std::io::Read, hyper::Error> {
|
||||
Ok(hyper::body::aggregate(self.0.into_body()).await?.reader())
|
||||
}
|
||||
pub enum Error {
|
||||
InvalidUri,
|
||||
Hyper(hyper::Error),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -35,12 +27,6 @@ pub struct Client {
|
||||
connection: Connection,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
InvalidHost,
|
||||
Hyper(hyper::Error),
|
||||
}
|
||||
|
||||
impl Client {
|
||||
fn https_builder() -> HttpsConnector<HttpConnector> {
|
||||
HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1().build()
|
||||
@@ -56,38 +42,14 @@ impl Client {
|
||||
fn without_connection_pool() -> Client {}
|
||||
*/
|
||||
|
||||
pub async fn request(&self, mut request: Request<Body>) -> Result<Response, Error> {
|
||||
pub async fn request<R: Into<Request>>(&self, request: R) -> Result<Response, Error> {
|
||||
let request: Request = request.into();
|
||||
let mut request = request.0;
|
||||
if request.headers().get(hyper::header::HOST).is_none() {
|
||||
let host = request.uri().host().ok_or(Error::InvalidHost)?.to_string();
|
||||
let host = request.uri().host().ok_or(Error::InvalidUri)?.to_string();
|
||||
request
|
||||
.headers_mut()
|
||||
.insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidHost)?);
|
||||
}
|
||||
|
||||
#[cfg(feature = "basic-auth")]
|
||||
if request.headers().get(hyper::header::AUTHORIZATION).is_none() {
|
||||
if let Some(authority) = request.uri().authority() {
|
||||
let authority = authority.as_str();
|
||||
if authority.contains('@') {
|
||||
// Decode the username and password from the URI
|
||||
let mut userpass = authority.split('@').next().unwrap().to_string();
|
||||
// If the password is "", the URI may omit :, yet the authentication will still expect it
|
||||
if !userpass.contains(':') {
|
||||
userpass.push(':');
|
||||
}
|
||||
|
||||
use zeroize::Zeroize;
|
||||
use base64ct::{Encoding, Base64};
|
||||
|
||||
let mut encoded = Base64::encode_string(userpass.as_bytes());
|
||||
userpass.zeroize();
|
||||
request.headers_mut().insert(
|
||||
hyper::header::AUTHORIZATION,
|
||||
HeaderValue::from_str(&format!("Basic {encoded}")).unwrap(),
|
||||
);
|
||||
encoded.zeroize();
|
||||
}
|
||||
}
|
||||
.insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidUri)?);
|
||||
}
|
||||
|
||||
Ok(Response(match &self.connection {
|
||||
|
||||
66
common/request/src/request.rs
Normal file
66
common/request/src/request.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
use hyper::body::Body;
|
||||
#[cfg(feature = "basic-auth")]
|
||||
use hyper::header::HeaderValue;
|
||||
|
||||
#[cfg(feature = "basic-auth")]
|
||||
use crate::Error;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Request(pub(crate) hyper::Request<Body>);
|
||||
impl Request {
|
||||
#[cfg(feature = "basic-auth")]
|
||||
fn username_password_from_uri(&self) -> Result<(String, String), Error> {
|
||||
if let Some(authority) = self.0.uri().authority() {
|
||||
let authority = authority.as_str();
|
||||
if authority.contains('@') {
|
||||
// Decode the username and password from the URI
|
||||
let mut userpass = authority.split('@').next().unwrap().to_string();
|
||||
|
||||
let mut userpass_iter = userpass.split(':');
|
||||
let username = userpass_iter.next().unwrap().to_string();
|
||||
let password = userpass_iter.next().map(str::to_string).unwrap_or_else(String::new);
|
||||
zeroize::Zeroize::zeroize(&mut userpass);
|
||||
|
||||
return Ok((username, password));
|
||||
}
|
||||
}
|
||||
Err(Error::InvalidUri)
|
||||
}
|
||||
|
||||
#[cfg(feature = "basic-auth")]
|
||||
pub fn basic_auth(&mut self, username: &str, password: &str) {
|
||||
use zeroize::Zeroize;
|
||||
use base64ct::{Encoding, Base64};
|
||||
|
||||
let mut formatted = format!("{username}:{password}");
|
||||
let mut encoded = Base64::encode_string(formatted.as_bytes());
|
||||
formatted.zeroize();
|
||||
self.0.headers_mut().insert(
|
||||
hyper::header::AUTHORIZATION,
|
||||
HeaderValue::from_str(&format!("Basic {encoded}")).unwrap(),
|
||||
);
|
||||
encoded.zeroize();
|
||||
}
|
||||
|
||||
#[cfg(feature = "basic-auth")]
|
||||
pub fn basic_auth_from_uri(&mut self) -> Result<(), Error> {
|
||||
let (mut username, mut password) = self.username_password_from_uri()?;
|
||||
self.basic_auth(&username, &password);
|
||||
|
||||
use zeroize::Zeroize;
|
||||
username.zeroize();
|
||||
password.zeroize();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "basic-auth")]
|
||||
pub fn with_basic_auth(&mut self) {
|
||||
let _ = self.basic_auth_from_uri();
|
||||
}
|
||||
}
|
||||
impl From<hyper::Request<Body>> for Request {
|
||||
fn from(request: hyper::Request<Body>) -> Request {
|
||||
Request(request)
|
||||
}
|
||||
}
|
||||
21
common/request/src/response.rs
Normal file
21
common/request/src/response.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
use hyper::{
|
||||
StatusCode,
|
||||
header::{HeaderValue, HeaderMap},
|
||||
body::{Buf, Body},
|
||||
};
|
||||
|
||||
use crate::Error;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Response(pub(crate) hyper::Response<Body>);
|
||||
impl Response {
|
||||
pub fn status(&self) -> StatusCode {
|
||||
self.0.status()
|
||||
}
|
||||
pub fn headers(&self) -> &HeaderMap<HeaderValue> {
|
||||
self.0.headers()
|
||||
}
|
||||
pub async fn body(self) -> Result<impl std::io::Read, Error> {
|
||||
hyper::body::aggregate(self.0.into_body()).await.map(Buf::reader).map_err(Error::Hyper)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user