Merge branch 'cache-determine-compressor'

This commit is contained in:
Daniel Valentine 2022-12-10 18:35:38 -07:00
commit 08a20b89a6
No known key found for this signature in database
GPG key ID: C82492E4FF813823

View file

@ -243,7 +243,7 @@ impl Server {
match func.await { match func.await {
Ok(mut res) => { Ok(mut res) => {
res.headers_mut().extend(def_headers); res.headers_mut().extend(def_headers);
let _ = compress_response(req_headers, &mut res).await; let _ = compress_response(&req_headers, &mut res).await;
Ok(res) Ok(res)
} }
@ -282,7 +282,7 @@ async fn new_boilerplate(
) -> Result<Response<Body>, String> { ) -> Result<Response<Body>, String> {
match Response::builder().status(status).body(body) { match Response::builder().status(status).body(body) {
Ok(mut res) => { Ok(mut res) => {
let _ = compress_response(req_headers, &mut res).await; let _ = compress_response(&req_headers, &mut res).await;
res.headers_mut().extend(default_headers.clone()); res.headers_mut().extend(default_headers.clone());
Ok(res) Ok(res)
@ -306,7 +306,8 @@ async fn new_boilerplate(
/// Accept-Encoding: gzip, compress, br /// Accept-Encoding: gzip, compress, br
/// Accept-Encoding: br;q=1.0, gzip;q=0.8, *;q=0.1 /// Accept-Encoding: br;q=1.0, gzip;q=0.8, *;q=0.1
/// ``` /// ```
fn determine_compressor(accept_encoding: &str) -> Option<CompressionType> { #[cached]
fn determine_compressor(accept_encoding: String) -> Option<CompressionType> {
if accept_encoding.is_empty() { if accept_encoding.is_empty() {
return None; return None;
}; };
@ -473,7 +474,7 @@ fn determine_compressor(accept_encoding: &str) -> Option<CompressionType> {
/// ///
/// This function logs errors to stderr, but only in debug mode. No information /// This function logs errors to stderr, but only in debug mode. No information
/// is logged in release builds. /// is logged in release builds.
async fn compress_response(req_headers: HeaderMap<header::HeaderValue>, res: &mut Response<Body>) -> Result<(), String> { async fn compress_response(req_headers: &HeaderMap<header::HeaderValue>, res: &mut Response<Body>) -> Result<(), String> {
// Check if the data is eligible for compression. // Check if the data is eligible for compression.
if let Some(hdr) = res.headers().get(header::CONTENT_TYPE) { if let Some(hdr) = res.headers().get(header::CONTENT_TYPE) {
match from_utf8(hdr.as_bytes()) { match from_utf8(hdr.as_bytes()) {
@ -503,30 +504,22 @@ async fn compress_response(req_headers: HeaderMap<header::HeaderValue>, res: &mu
return Ok(()); return Ok(());
}; };
// Quick and dirty closure for extracting a header from the request and // Check to see which compressor is requested, and if we can use it.
// returning it as a &str. let accept_encoding: String = match req_headers.get(header::ACCEPT_ENCODING) {
let get_req_header = |k: header::HeaderName| -> Option<&str> { None => return Ok(()), // Client requested no compression.
match req_headers.get(k) {
Some(hdr) => match from_utf8(hdr.as_bytes()) { Some(hdr) => match String::from_utf8(hdr.as_bytes().into()) {
Ok(val) => Some(val), Ok(val) => val,
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
Err(e) => { Err(e) => {
dbg_msg!(e); dbg_msg!(e);
None return Ok(());
} }
#[cfg(not(debug_assertions))] #[cfg(not(debug_assertions))]
Err(_) => None, Err(_) => return Ok(()),
}, },
None => None,
}
};
// Check to see which compressor is requested, and if we can use it.
let accept_encoding: &str = match get_req_header(header::ACCEPT_ENCODING) {
Some(val) => val,
None => return Ok(()), // Client requested no compression.
}; };
let compressor: CompressionType = match determine_compressor(accept_encoding) { let compressor: CompressionType = match determine_compressor(accept_encoding) {
@ -636,18 +629,18 @@ mod tests {
#[test] #[test]
fn test_determine_compressor() { fn test_determine_compressor() {
// Single compressor given. // Single compressor given.
assert_eq!(determine_compressor("unsupported"), None); assert_eq!(determine_compressor("unsupported".to_string()), None);
assert_eq!(determine_compressor("gzip"), Some(CompressionType::Gzip)); assert_eq!(determine_compressor("gzip".to_string()), Some(CompressionType::Gzip));
assert_eq!(determine_compressor("*"), Some(DEFAULT_COMPRESSOR)); assert_eq!(determine_compressor("*".to_string()), Some(DEFAULT_COMPRESSOR));
// Multiple compressors. // Multiple compressors.
assert_eq!(determine_compressor("gzip, br"), Some(CompressionType::Brotli)); assert_eq!(determine_compressor("gzip, br".to_string()), Some(CompressionType::Brotli));
assert_eq!(determine_compressor("gzip;q=0.8, br;q=0.3"), Some(CompressionType::Gzip)); assert_eq!(determine_compressor("gzip;q=0.8, br;q=0.3".to_string()), Some(CompressionType::Gzip));
assert_eq!(determine_compressor("br, gzip"), Some(CompressionType::Brotli)); assert_eq!(determine_compressor("br, gzip".to_string()), Some(CompressionType::Brotli));
assert_eq!(determine_compressor("br;q=0.3, gzip;q=0.4"), Some(CompressionType::Gzip)); assert_eq!(determine_compressor("br;q=0.3, gzip;q=0.4".to_string()), Some(CompressionType::Gzip));
// Invalid q-values. // Invalid q-values.
assert_eq!(determine_compressor("gzip;q=NAN"), None); assert_eq!(determine_compressor("gzip;q=NAN".to_string()), None);
} }
#[test] #[test]
@ -672,9 +665,9 @@ mod tests {
] { ] {
// Determine what the expected encoding should be based on both the // Determine what the expected encoding should be based on both the
// specific encodings we accept. // specific encodings we accept.
let expected_encoding: CompressionType = match determine_compressor(accept_encoding) { let expected_encoding: CompressionType = match determine_compressor(accept_encoding.to_string()) {
Some(s) => s, Some(s) => s,
None => panic!("determine_compressor(accept_encoding) => None"), None => panic!("determine_compressor(accept_encoding.to_string()) => None"),
}; };
// Build headers with our Accept-Encoding. // Build headers with our Accept-Encoding.
@ -691,8 +684,8 @@ mod tests {
.unwrap(); .unwrap();
// Perform the compression. // Perform the compression.
if let Err(e) = block_on(compress_response(req_headers, &mut res)) { if let Err(e) = block_on(compress_response(&req_headers, &mut res)) {
panic!("compress_response(req_headers, &mut res) => Err(\"{}\")", e); panic!("compress_response(&req_headers, &mut res) => Err(\"{}\")", e);
}; };
// If the content was compressed, we expect the Content-Encoding // If the content was compressed, we expect the Content-Encoding
@ -739,9 +732,8 @@ mod tests {
}; };
let mut decompressed = Vec::<u8>::new(); let mut decompressed = Vec::<u8>::new();
match io::copy(&mut decoder, &mut decompressed) { if let Err(e) = io::copy(&mut decoder, &mut decompressed) {
Ok(_) => {} panic!("{}", e);
Err(e) => panic!("{}", e),
}; };
assert!(decompressed.eq(&expected_lorem_ipsum)); assert!(decompressed.eq(&expected_lorem_ipsum));