Redirect /:id to canonical URL for post. (#617)
* Redirect /:id to canonical URL for post. This implements redirection of `/:id` (a short-form URL to a post) to the post's canonical URL. Libreddit issues a `HEAD /:id` to Reddit to get the canonical URL, and on success will send an HTTP 302 to a client with the canonical URL set in as the value of the `Location:` header. This also implements support for short IDs for non-ASCII posts, c/o spikecodes. Co-authored-by: spikecodes <19519553+spikecodes@users.noreply.github.com>
This commit is contained in:
parent
584cd4aac1
commit
c6487799ed
3 changed files with 152 additions and 66 deletions
182
src/client.rs
182
src/client.rs
|
@ -1,13 +1,59 @@
|
|||
use cached::proc_macro::cached;
|
||||
use futures_lite::{future::Boxed, FutureExt};
|
||||
use hyper::{body, body::Buf, client, header, Body, Request, Response, Uri};
|
||||
use hyper::{body, body::Buf, client, header, Body, Method, Request, Response, Uri};
|
||||
use libflate::gzip;
|
||||
use percent_encoding::{percent_encode, CONTROLS};
|
||||
use serde_json::Value;
|
||||
use std::{io, result::Result};
|
||||
|
||||
use crate::dbg_msg;
|
||||
use crate::server::RequestExt;
|
||||
|
||||
const REDDIT_URL_BASE: &str = "https://www.reddit.com";
|
||||
|
||||
/// Gets the canonical path for a resource on Reddit. This is accomplished by
|
||||
/// making a `HEAD` request to Reddit at the path given in `path`.
|
||||
///
|
||||
/// This function returns `Ok(Some(path))`, where `path`'s value is identical
|
||||
/// to that of the value of the argument `path`, if Reddit responds to our
|
||||
/// `HEAD` request with a 2xx-family HTTP code. It will also return an
|
||||
/// `Ok(Some(String))` if Reddit responds to our `HEAD` request with a
|
||||
/// `Location` header in the response, and the HTTP code is in the 3xx-family;
|
||||
/// the `String` will contain the path as reported in `Location`. The return
|
||||
/// value is `Ok(None)` if Reddit responded with a 3xx, but did not provide a
|
||||
/// `Location` header. An `Err(String)` is returned if Reddit responds with a
|
||||
/// 429, or if we were unable to decode the value in the `Location` header.
|
||||
#[cached(size = 1024, time = 600, result = true)]
|
||||
pub async fn canonical_path(path: String) -> Result<Option<String>, String> {
|
||||
let res = reddit_head(path.clone(), true).await?;
|
||||
|
||||
if res.status() == 429 {
|
||||
return Err("Too many requests.".to_string());
|
||||
};
|
||||
|
||||
// If Reddit responds with a 2xx, then the path is already canonical.
|
||||
if res.status().to_string().starts_with('2') {
|
||||
return Ok(Some(path));
|
||||
}
|
||||
|
||||
// If Reddit responds with anything other than 3xx (except for the 2xx as
|
||||
// above), return a None.
|
||||
if !res.status().to_string().starts_with('3') {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(
|
||||
res
|
||||
.headers()
|
||||
.get(header::LOCATION)
|
||||
.map(|val| percent_encode(val.as_bytes(), CONTROLS)
|
||||
.to_string()
|
||||
.trim_start_matches(REDDIT_URL_BASE)
|
||||
.to_string()
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn proxy(req: Request<Body>, format: &str) -> Result<Response<Body>, String> {
|
||||
let mut url = format!("{}?{}", format, req.uri().query().unwrap_or_default());
|
||||
|
||||
|
@ -63,21 +109,39 @@ async fn stream(url: &str, req: &Request<Body>) -> Result<Response<Body>, String
|
|||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
fn request(url: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
||||
/// Makes a GET request to Reddit at `path`. By default, this will honor HTTP
|
||||
/// 3xx codes Reddit returns and will automatically redirect.
|
||||
fn reddit_get(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
||||
request(&Method::GET, path, true, quarantine)
|
||||
}
|
||||
|
||||
/// Makes a HEAD request to Reddit at `path`. This will not follow redirects.
|
||||
fn reddit_head(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
||||
request(&Method::HEAD, path, false, quarantine)
|
||||
}
|
||||
|
||||
/// Makes a request to Reddit. If `redirect` is `true`, request_with_redirect
|
||||
/// will recurse on the URL that Reddit provides in the Location HTTP header
|
||||
/// in its response.
|
||||
fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
||||
// Build Reddit URL from path.
|
||||
let url = format!("{}{}", REDDIT_URL_BASE, path);
|
||||
|
||||
// Prepare the HTTPS connector.
|
||||
let https = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1().build();
|
||||
|
||||
// Construct the hyper client from the HTTPS connector.
|
||||
let client: client::Client<_, hyper::Body> = client::Client::builder().build(https);
|
||||
|
||||
// Build request
|
||||
// Build request to Reddit. When making a GET, request gzip compression.
|
||||
// (Reddit doesn't do brotli yet.)
|
||||
let builder = Request::builder()
|
||||
.method("GET")
|
||||
.method(method)
|
||||
.uri(&url)
|
||||
.header("User-Agent", format!("web:libreddit:{}", env!("CARGO_PKG_VERSION")))
|
||||
.header("Host", "www.reddit.com")
|
||||
.header("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8")
|
||||
.header("Accept-Encoding", "gzip") // Reddit doesn't do brotli yet.
|
||||
.header("Accept-Encoding", if method == Method::GET { "gzip" } else { "identity" })
|
||||
.header("Accept-Language", "en-US,en;q=0.5")
|
||||
.header("Connection", "keep-alive")
|
||||
.header("Cookie", if quarantine { "_options=%7B%22pref_quarantine_optin%22%3A%20true%7D" } else { "" })
|
||||
|
@ -87,8 +151,15 @@ fn request(url: String, quarantine: bool) -> Boxed<Result<Response<Body>, String
|
|||
match builder {
|
||||
Ok(req) => match client.request(req).await {
|
||||
Ok(mut response) => {
|
||||
// Reddit may respond with a 3xx. Decide whether or not to
|
||||
// redirect based on caller params.
|
||||
if response.status().to_string().starts_with('3') {
|
||||
request(
|
||||
if !redirect {
|
||||
return Ok(response);
|
||||
};
|
||||
|
||||
return request(
|
||||
method,
|
||||
response
|
||||
.headers()
|
||||
.get("Location")
|
||||
|
@ -98,56 +169,64 @@ fn request(url: String, quarantine: bool) -> Boxed<Result<Response<Body>, String
|
|||
})
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
true,
|
||||
quarantine,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
match response.headers().get(header::CONTENT_ENCODING) {
|
||||
// Content not compressed.
|
||||
None => Ok(response),
|
||||
.await;
|
||||
};
|
||||
|
||||
// Content gzipped.
|
||||
Some(hdr) => {
|
||||
// Since we requested gzipped content, we expect
|
||||
// to get back gzipped content. If we get
|
||||
// back anything else, that's a problem.
|
||||
if hdr.ne("gzip") {
|
||||
return Err("Reddit response was encoded with an unsupported compressor".to_string());
|
||||
}
|
||||
match response.headers().get(header::CONTENT_ENCODING) {
|
||||
// Content not compressed.
|
||||
None => Ok(response),
|
||||
|
||||
// The body must be something that implements
|
||||
// std::io::Read, hence the conversion to
|
||||
// bytes::buf::Buf and then transformation into a
|
||||
// Reader.
|
||||
let mut decompressed: Vec<u8>;
|
||||
{
|
||||
let mut aggregated_body = match body::aggregate(response.body_mut()).await {
|
||||
Ok(b) => b.reader(),
|
||||
Err(e) => return Err(e.to_string()),
|
||||
};
|
||||
|
||||
let mut decoder = match gzip::Decoder::new(&mut aggregated_body) {
|
||||
Ok(decoder) => decoder,
|
||||
Err(e) => return Err(e.to_string()),
|
||||
};
|
||||
|
||||
decompressed = Vec::<u8>::new();
|
||||
match io::copy(&mut decoder, &mut decompressed) {
|
||||
Ok(_) => {}
|
||||
Err(e) => return Err(e.to_string()),
|
||||
};
|
||||
}
|
||||
|
||||
response.headers_mut().remove(header::CONTENT_ENCODING);
|
||||
response.headers_mut().insert(header::CONTENT_LENGTH, decompressed.len().into());
|
||||
*(response.body_mut()) = Body::from(decompressed);
|
||||
|
||||
Ok(response)
|
||||
// Content encoded (hopefully with gzip).
|
||||
Some(hdr) => {
|
||||
match hdr.to_str() {
|
||||
Ok(val) => match val {
|
||||
"gzip" => {}
|
||||
"identity" => return Ok(response),
|
||||
_ => return Err("Reddit response was encoded with an unsupported compressor".to_string()),
|
||||
},
|
||||
Err(_) => return Err("Reddit response was invalid".to_string()),
|
||||
}
|
||||
|
||||
// We get here if the body is gzip-compressed.
|
||||
|
||||
// The body must be something that implements
|
||||
// std::io::Read, hence the conversion to
|
||||
// bytes::buf::Buf and then transformation into a
|
||||
// Reader.
|
||||
let mut decompressed: Vec<u8>;
|
||||
{
|
||||
let mut aggregated_body = match body::aggregate(response.body_mut()).await {
|
||||
Ok(b) => b.reader(),
|
||||
Err(e) => return Err(e.to_string()),
|
||||
};
|
||||
|
||||
let mut decoder = match gzip::Decoder::new(&mut aggregated_body) {
|
||||
Ok(decoder) => decoder,
|
||||
Err(e) => return Err(e.to_string()),
|
||||
};
|
||||
|
||||
decompressed = Vec::<u8>::new();
|
||||
if let Err(e) = io::copy(&mut decoder, &mut decompressed) {
|
||||
return Err(e.to_string());
|
||||
};
|
||||
}
|
||||
|
||||
response.headers_mut().remove(header::CONTENT_ENCODING);
|
||||
response.headers_mut().insert(header::CONTENT_LENGTH, decompressed.len().into());
|
||||
*(response.body_mut()) = Body::from(decompressed);
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => Err(e.to_string()),
|
||||
Err(e) => {
|
||||
dbg_msg!("{} {}: {}", method, path, e);
|
||||
|
||||
Err(e.to_string())
|
||||
}
|
||||
},
|
||||
Err(_) => Err("Post url contains non-ASCII characters".to_string()),
|
||||
}
|
||||
|
@ -158,9 +237,6 @@ fn request(url: String, quarantine: bool) -> Boxed<Result<Response<Body>, String
|
|||
// Make a request to a Reddit API and parse the JSON response
|
||||
#[cached(size = 100, time = 30, result = true)]
|
||||
pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
||||
// Build Reddit url from path
|
||||
let url = format!("https://www.reddit.com{}", path);
|
||||
|
||||
// Closure to quickly build errors
|
||||
let err = |msg: &str, e: String| -> Result<Value, String> {
|
||||
// eprintln!("{} - {}: {}", url, msg, e);
|
||||
|
@ -168,7 +244,7 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
|||
};
|
||||
|
||||
// Fetch the url...
|
||||
match request(url.clone(), quarantine).await {
|
||||
match reddit_get(path.clone(), quarantine).await {
|
||||
Ok(response) => {
|
||||
let status = response.status();
|
||||
|
||||
|
@ -186,7 +262,7 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
|||
.as_str()
|
||||
.unwrap_or_else(|| {
|
||||
json["message"].as_str().unwrap_or_else(|| {
|
||||
eprintln!("{} - Error parsing reddit error", url);
|
||||
eprintln!("{}{} - Error parsing reddit error", REDDIT_URL_BASE, path);
|
||||
"Error parsing reddit error"
|
||||
})
|
||||
})
|
||||
|
|
31
src/main.rs
31
src/main.rs
|
@ -17,7 +17,7 @@ use futures_lite::FutureExt;
|
|||
use hyper::{header::HeaderValue, Body, Request, Response};
|
||||
|
||||
mod client;
|
||||
use client::proxy;
|
||||
use client::{canonical_path, proxy};
|
||||
use server::RequestExt;
|
||||
use utils::{error, redirect, ThemeAssets};
|
||||
|
||||
|
@ -259,9 +259,6 @@ async fn main() {
|
|||
|
||||
app.at("/r/:sub/:sort").get(|r| subreddit::community(r).boxed());
|
||||
|
||||
// Comments handler
|
||||
app.at("/comments/:id").get(|r| post::item(r).boxed());
|
||||
|
||||
// Front page
|
||||
app.at("/").get(|r| subreddit::community(r).boxed());
|
||||
|
||||
|
@ -279,13 +276,25 @@ async fn main() {
|
|||
// Handle about pages
|
||||
app.at("/about").get(|req| error(req, "About pages aren't added yet".to_string()).boxed());
|
||||
|
||||
app.at("/:id").get(|req: Request<Body>| match req.param("id").as_deref() {
|
||||
// Sort front page
|
||||
Some("best" | "hot" | "new" | "top" | "rising" | "controversial") => subreddit::community(req).boxed(),
|
||||
// Short link for post
|
||||
Some(id) if id.len() > 4 && id.len() < 7 => post::item(req).boxed(),
|
||||
// Error message for unknown pages
|
||||
_ => error(req, "Nothing here".to_string()).boxed(),
|
||||
app.at("/:id").get(|req: Request<Body>| {
|
||||
Box::pin(async move {
|
||||
match req.param("id").as_deref() {
|
||||
// Sort front page
|
||||
Some("best" | "hot" | "new" | "top" | "rising" | "controversial") => subreddit::community(req).await,
|
||||
|
||||
// Short link for post
|
||||
Some(id) if (5..7).contains(&id.len()) => match canonical_path(format!("/{}", id)).await {
|
||||
Ok(path_opt) => match path_opt {
|
||||
Some(path) => Ok(redirect(path)),
|
||||
None => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await,
|
||||
},
|
||||
Err(e) => error(req, e).await,
|
||||
},
|
||||
|
||||
// Error message for unknown pages
|
||||
_ => error(req, "Nothing here".to_string()).await,
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
// Default service in case no routes match
|
||||
|
|
|
@ -716,10 +716,11 @@ pub fn redirect(path: String) -> Response<Body> {
|
|||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub async fn error(req: Request<Body>, msg: String) -> Result<Response<Body>, String> {
|
||||
/// Renders a generic error landing page.
|
||||
pub async fn error(req: Request<Body>, msg: impl ToString) -> Result<Response<Body>, String> {
|
||||
let url = req.uri().to_string();
|
||||
let body = ErrorTemplate {
|
||||
msg,
|
||||
msg: msg.to_string(),
|
||||
prefs: Preferences::new(req),
|
||||
url,
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue