diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 5505ae2..5ef95a7 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,10 +1,10 @@ version: 2 updates: - - package-ecosystem: "cargo" - directory: "/" + - package-ecosystem: cargo + directory: '/' commit-message: - prefix: "deps: " + prefix: 'deps: ' schedule: - day: "saturday" - interval: "weekly" - time: "07:15" \ No newline at end of file + day: saturday + interval: weekly + time: '07:15' \ No newline at end of file diff --git a/src/webhooks/actix_web.rs b/src/webhooks/actix_web.rs index 9393140..ce3209f 100644 --- a/src/webhooks/actix_web.rs +++ b/src/webhooks/actix_web.rs @@ -1,62 +1,119 @@ -use crate::Incoming; -use actix_web::{ - dev::Payload, - error::{Error, ErrorBadRequest, ErrorUnauthorized}, - web::Json, - FromRequest, HttpRequest, -}; -use serde::de::DeserializeOwned; +use super::IncomingPayload; use std::{ + fmt::{self, Display, Formatter}, future::Future, pin::Pin, - task::{ready, Context, Poll}, + task::{Context, Poll, ready}, + time::{Duration, Instant}, +}; + +use actix_web::{ + FromRequest, HttpRequest, HttpResponse, ResponseError, body::BoxBody, dev::Payload, + http::StatusCode, }; +use chrono::{DateTime, Utc}; +use futures_core::stream::Stream; #[doc(hidden)] -pub struct IncomingFut { +#[derive(Debug)] +pub enum IncomingPayloadError { + BadRequest, + Timeout, +} + +impl Display for IncomingPayloadError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Self::BadRequest => "Bad Request.", + + Self::Timeout => "Request timed out.", + }) + } +} + +impl ResponseError for IncomingPayloadError { + fn error_response(&self) -> HttpResponse { + match self { + Self::BadRequest => HttpResponse::BadRequest().body("Bad Request"), + + Self::Timeout => HttpResponse::RequestTimeout().body("Request timed out"), + } + } + + fn status_code(&self) -> StatusCode { + match self { + Self::BadRequest => StatusCode::BAD_REQUEST, + + Self::Timeout => StatusCode::REQUEST_TIMEOUT, + } + } +} + +#[doc(hidden)] +pub struct IncomingPayloadFut { req: HttpRequest, - json_fut: as FromRequest>::Future, + payload: Payload, + body: Vec, + start: Instant, + now: DateTime, +} + +impl IncomingPayloadFut { + fn timed_out(&self) -> bool { + self.start.elapsed() > Duration::from_secs(5) + } } -impl Future for IncomingFut -where - T: DeserializeOwned, -{ - type Output = Result, Error>; +impl Future for IncomingPayloadFut { + type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if let Ok(json) = ready!(Pin::new(&mut self.json_fut).poll(cx)) { - let headers = self.req.headers(); - - if let Some(authorization) = headers.get("Authorization") { - if let Ok(authorization) = authorization.to_str() { - return Poll::Ready(Ok(Incoming { - authorization: authorization.to_owned(), - data: json.into_inner(), - })); - } + if self.timed_out() { + return Poll::Ready(Err(IncomingPayloadError::Timeout)); + } + + while let Some(body) = ready!(Pin::new(&mut self.payload).poll_next(cx)) { + if self.timed_out() { + return Poll::Ready(Err(IncomingPayloadError::Timeout)); } - return Poll::Ready(Err(ErrorUnauthorized("401"))); + if let Ok(body) = body { + self.body.extend_from_slice(&body); + } else { + return Poll::Ready(Err(IncomingPayloadError::BadRequest)); + } } - Poll::Ready(Err(ErrorBadRequest("400"))) + let headers = self.req.headers(); + + if let (Some(signature), Some(trace)) = ( + headers.get("x-topgg-signature"), + headers.get("x-topgg-trace"), + ) && let (Ok(signature), Ok(trace), Ok(body)) = ( + signature.to_str(), + trace.to_str(), + str::from_utf8(&self.body), + ) && let Some(incoming) = IncomingPayload::new(self.now, body.into(), signature, trace) + { + return Poll::Ready(Ok(incoming)); + } + + Poll::Ready(Err(IncomingPayloadError::BadRequest)) } } #[cfg_attr(docsrs, doc(cfg(feature = "actix-web")))] -impl FromRequest for Incoming -where - T: DeserializeOwned, -{ - type Error = Error; - type Future = IncomingFut; - - #[inline(always)] +impl FromRequest for IncomingPayload { + type Error = IncomingPayloadError; + type Future = IncomingPayloadFut; + fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - IncomingFut { + IncomingPayloadFut { req: req.clone(), - json_fut: Json::from_request(req, payload), + payload: payload.take(), + body: vec![], + start: Instant::now(), + now: Utc::now(), } } } diff --git a/src/webhooks/axum.rs b/src/webhooks/axum.rs index 4175b1b..5ea8a13 100644 --- a/src/webhooks/axum.rs +++ b/src/webhooks/axum.rs @@ -1,45 +1,77 @@ -use super::Webhook; +use super::{Payload, PayloadResult}; +use std::{sync::Arc, time::Duration}; + use axum::{ - extract::State, + BoxError, Router, + error_handling::HandleErrorLayer, + extract::{DefaultBodyLimit, State}, http::{HeaderMap, StatusCode}, - response::IntoResponse, + response::{IntoResponse, Response}, routing::post, - Router, }; -use serde::de::DeserializeOwned; -use std::sync::Arc; +use chrono::Utc; +use tower::{ServiceBuilder, timeout::error::Elapsed}; + +/// An axum webhook listener for listening to payloads. +/// +/// # Example +/// +/// ```rust,no_run +/// struct MyTopggListener {} +/// +/// #[async_trait::async_trait] +/// impl topgg::axum::Listener for MyTopggListener { +/// async fn callback(self: Arc, payload: Payload, _trace: &str) -> Response { +/// println!("{payload:?}"); +/// +/// (StatusCode::NO_CONTENT, ()).into_response() +/// } +/// } +/// ``` +#[async_trait::async_trait] +#[cfg_attr(docsrs, doc(cfg(feature = "axum")))] +pub trait Listener: Send + Sync + 'static { + async fn callback(self: Arc, payload: Payload, trace: &str) -> Response; +} struct WebhookState { state: Arc, - password: Arc, + secret: Arc, } impl Clone for WebhookState { - #[inline(always)] fn clone(&self) -> Self { Self { - state: Arc::clone(&self.state), - password: Arc::clone(&self.password), + state: self.state.clone(), + secret: self.secret.clone(), } } } -/// Creates a new axum [`Router`] for receiving vote events. +/// Creates a new axum [`Router`] for receiving webhook payloads. /// /// # Example /// /// ```rust,no_run -/// use axum::{routing::get, Router}; -/// use topgg::{VoteEvent, Webhook}; -/// use tokio::net::TcpListener; +/// use topgg::Payload; /// use std::sync::Arc; /// -/// struct MyVoteListener {} +/// use axum::{ +/// Router, +/// http::status::StatusCode, +/// response::{IntoResponse, Response}, +/// routing::get, +/// }; +/// use tokio::net::TcpListener; +/// +/// struct MyTopggListener {} /// /// #[async_trait::async_trait] -/// impl Webhook for MyVoteListener { -/// async fn callback(&self, vote: VoteEvent) { -/// println!("A user with the ID of {} has voted us on Top.gg!", vote.voter_id); +/// impl topgg::axum::Listener for MyTopggListener { +/// async fn callback(self: Arc, payload: Payload, _trace: &str) -> Response { +/// println!("{payload:?}"); +/// +/// (StatusCode::NO_CONTENT, ()).into_response() /// } /// } /// @@ -49,11 +81,12 @@ impl Clone for WebhookState { /// /// #[tokio::main] /// async fn main() { -/// let state = Arc::new(MyVoteListener {}); +/// let state = Arc::new(MyTopggListener {}); /// +/// // POST /webhook /// let router = Router::new().route("/", get(index)).nest( -/// "/votes", -/// topgg::axum::webhook(env!("MY_TOPGG_WEBHOOK_SECRET").to_string(), Arc::clone(&state)), +/// "/webhook", +/// topgg::axum::webhook(Arc::clone(&state), env!("TOPGG_WEBHOOK_SECRET").into()), /// ); /// /// let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap(); @@ -61,36 +94,62 @@ impl Clone for WebhookState { /// axum::serve(listener, router).await.unwrap(); /// } /// ``` -#[inline(always)] #[cfg_attr(docsrs, doc(cfg(feature = "axum")))] -pub fn webhook(password: String, state: Arc) -> Router +pub fn webhook(state: Arc, secret: String) -> Router where - D: DeserializeOwned + Send, - T: Webhook, + S: Listener, { + let timeout_layer = ServiceBuilder::new() + .layer(HandleErrorLayer::new(|err: BoxError| async move { + if err.is::() { + (StatusCode::REQUEST_TIMEOUT, "Request timed out") + } else { + (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error") + } + })) + .timeout(Duration::from_secs(5)); + Router::new() .route( "/", post( - async |headers: HeaderMap, State(webhook): State>, body: String| { - if let Some(authorization) = headers.get("Authorization") { - if let Ok(authorization) = authorization.to_str() { - if authorization == *(webhook.password) { - if let Ok(data) = serde_json::from_str(&body) { - webhook.state.callback(data).await; + async |headers: HeaderMap, State(wrapped_state): State>, body: String| { + let now = Utc::now(); - return (StatusCode::NO_CONTENT, ()).into_response(); - } + if let Some(signature) = headers.get("x-topgg-signature") + && let Ok(signature) = signature.to_str() + && let Some(trace) = headers.get("x-topgg-trace") + && let Ok(trace) = trace.to_str() + { + match Payload::new(now, body, signature, &wrapped_state.secret) { + PayloadResult::Accepted(payload) => { + wrapped_state.state.callback(payload, trace).await + } + + PayloadResult::Forbidden => (StatusCode::FORBIDDEN, "Forbidden").into_response(), + + PayloadResult::BadRequest => (StatusCode::BAD_REQUEST, "Bad Request").into_response(), + + PayloadResult::Unauthorized => { + (StatusCode::UNAUTHORIZED, "Unauthorized").into_response() + } + + PayloadResult::DeserializationFailure => (StatusCode::NO_CONTENT, "").into_response(), + + PayloadResult::InternalServerError => { + (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response() } } + } else { + (StatusCode::BAD_REQUEST, "Bad Request").into_response() } - - (StatusCode::UNAUTHORIZED, ()).into_response() }, ), ) + .layer(timeout_layer.into_inner()) + .layer(DefaultBodyLimit::max(2 * 1024 * 1024)) .with_state(WebhookState { state, - password: Arc::new(password), + secret: Arc::new(secret), }) } diff --git a/src/webhooks/mod.rs b/src/webhooks/mod.rs index a5ec591..d358965 100644 --- a/src/webhooks/mod.rs +++ b/src/webhooks/mod.rs @@ -1,6 +1,6 @@ -mod vote; -#[cfg_attr(docsrs, doc(cfg(feature = "webhooks")))] -pub use vote::*; +mod payload; + +pub use payload::{IncomingPayload, Payload, PayloadResult}; #[cfg(feature = "actix-web")] mod actix_web; @@ -23,52 +23,3 @@ cfg_if::cfg_if! { pub mod warp; } } - -cfg_if::cfg_if! { - if #[cfg(any(feature = "actix-web", feature = "rocket"))] { - /// An unauthenticated incoming Top.gg webhook request. - #[must_use] - #[cfg_attr(docsrs, doc(cfg(any(feature = "actix-web", feature = "rocket"))))] - pub struct Incoming { - pub(crate) authorization: String, - pub(crate) data: T, - } - - impl Incoming { - /// Authenticates a valid password with this request. - #[must_use] - #[inline(always)] - pub fn authenticate(self, password: &str) -> Option { - if self.authorization == password { - Some(self.data) - } else { - None - } - } - } - - impl Clone for Incoming - where - T: Clone, - { - #[inline(always)] - fn clone(&self) -> Self { - Self { - authorization: self.authorization.clone(), - data: self.data.clone(), - } - } - } - } -} - -cfg_if::cfg_if! { - if #[cfg(any(feature = "axum", feature = "warp"))] { - /// Webhook event handler. - #[cfg_attr(docsrs, doc(cfg(any(feature = "axum", feature = "warp"))))] - #[async_trait::async_trait] - pub trait Webhook: Send + Sync + 'static { - async fn callback(&self, data: T); - } - } -} diff --git a/src/webhooks/payload.rs b/src/webhooks/payload.rs new file mode 100644 index 0000000..05dcfec --- /dev/null +++ b/src/webhooks/payload.rs @@ -0,0 +1,276 @@ +use super::super::{PartialProject, User, snowflake}; +use std::collections::HashMap; + +use chrono::{DateTime, Utc}; +use hmac::{Hmac, Mac}; +use log::warn; +use serde::Deserialize; +use sha2::Sha256; + +/// A webhook payload. +#[non_exhaustive] +#[derive(Clone, Debug, Deserialize)] +#[serde(tag = "type", content = "data")] +#[cfg_attr(docsrs, doc(cfg(feature = "webhooks")))] +pub enum Payload { + /// An `integration.create` webhook payload. Fires when a user has connected to your webhook integration. + #[serde(rename = "integration.create")] + IntegrationCreate { + /// The unique identifier for this connection. + #[serde(deserialize_with = "snowflake::deserialize")] + connection_id: u64, + + /// The secret used to verify future webhook deliveries. + #[serde(rename = "webhook_secret")] + secret: String, + + /// The project that the integration refers to. + project: PartialProject, + + /// The user who triggered this event. + user: User, + }, + + /// An `integration.delete` webhook payload. Fires when a user has disconnected from your webhook integration. + #[serde(rename = "integration.delete")] + IntegrationDelete { + /// The unique identifier for this connection. + #[serde(deserialize_with = "snowflake::deserialize")] + connection_id: u64, + }, + + /// A `webhook.test` webhook payload. Fires upon sent test from the project dashboard. + #[serde(rename = "webhook.test")] + Test { + /// The project that the test refers to. + project: PartialProject, + + /// The user who triggered this test. + user: User, + }, + + /// A `vote.create` webhook payload. Fires when a user votes for your project. + #[serde(rename = "vote.create")] + VoteCreate { + /// The vote's ID. + #[serde(deserialize_with = "snowflake::deserialize")] + id: u64, + + /// The number of votes this vote counted for. This is a rounded integer value which determines how many points this individual vote was worth. + weight: u64, + + /// When the vote was cast. + #[serde(rename = "created_at")] + voted_at: DateTime, + + /// When the vote expires and the user is required to vote again. + expires_at: DateTime, + + /// The project that received this vote. + project: PartialProject, + + /// The user who voted for this project. + user: User, + }, +} + +impl Payload { + #[allow(clippy::new_ret_no_self)] + #[cfg(any(feature = "axum", feature = "warp"))] + pub(super) fn new( + now: DateTime, + body: String, + signature: &str, + secret: &str, + ) -> PayloadResult { + IncomingPayload::new(now, body, signature, "").map_or(PayloadResult::BadRequest, |incoming| { + incoming.authenticate(secret) + }) + } +} + +/// A processed [`Payload`]. +#[cfg_attr(docsrs, doc(cfg(feature = "webhooks")))] +pub enum PayloadResult { + /// The payload has been successfully authenticated. + Accepted(Payload), + + /// The timestamp is outside of the accepted time window, possibly being a part of a replay attack. + Forbidden, + + /// The request's headers are missing or invalid. + BadRequest, + + /// The request's signature cannot be authenticated with the correct webhook secret. + Unauthorized, + + /// Unable deserialize payload. This could possibly be a bug with the SDK. + /// + /// It's recommended to return a 200 and 204 status code and report this to the SDK's maintainers when this happens. + DeserializationFailure, + + /// Unable to create a SHA-256 HMAC instance from the specified webhook secret. + InternalServerError, +} + +/// An incoming [`Payload`] that is yet to be [authenticated with a secret][IncomingPayload::authenticate]. +/// +/// # Examples +/// +/// With actix-web: +/// +/// ```rust,no_run +/// use topgg::{IncomingPayload, PayloadResult}; +/// use std::io; +/// +/// use actix_web::{ +/// App, HttpServer, +/// error::{Error, ErrorBadRequest, ErrorForbidden, ErrorInternalServerError, ErrorUnauthorized}, +/// get, post, +/// }; +/// +/// #[get("/")] +/// async fn index() -> &'static str { +/// "Hello, World!" +/// } +/// +/// // POST /webhook +/// #[post("/webhook")] +/// async fn webhook(payload: IncomingPayload) -> Result<&'static str, Error> { +/// match payload.authenticate(env!("TOPGG_WEBHOOK_SECRET")) { +/// PayloadResult::Accepted(payload) => { +/// println!("{payload:?}"); +/// +/// Ok("ok") +/// } +/// +/// PayloadResult::Forbidden => Err(ErrorForbidden("Forbidden")), +/// +/// PayloadResult::BadRequest => Err(ErrorBadRequest("Bad Request")), +/// +/// PayloadResult::Unauthorized => Err(ErrorUnauthorized("Unauthorized")), +/// +/// PayloadResult::DeserializationFailure => Ok(""), +/// +/// PayloadResult::InternalServerError => Err(ErrorInternalServerError("Internal Server Error")), +/// } +/// } +/// +/// #[actix_web::main] +/// async fn main() -> io::Result<()> { +/// HttpServer::new(|| App::new().service(index).service(webhook)) +/// .bind("127.0.0.1:8080")? +/// .run() +/// .await +/// } +/// ``` +/// +/// With rocket: +/// +/// ```rust,no_run +/// use topgg::{IncomingPayload, PayloadResult}; +/// +/// use rocket::{Build, Rocket, get, http::Status, launch, post, routes}; +/// +/// #[get("/")] +/// fn index() -> &'static str { +/// "Hello, World!" +/// } +/// +/// // POST /webhook +/// #[post("/webhook", data = "")] +/// fn webhook(payload: IncomingPayload) -> Status { +/// match payload.authenticate(env!("TOPGG_WEBHOOK_SECRET")) { +/// PayloadResult::Accepted(payload) => { +/// println!("{payload:?}"); +/// +/// Status::NoContent +/// } +/// +/// PayloadResult::Forbidden => Status::Forbidden, +/// +/// PayloadResult::BadRequest => Status::BadRequest, +/// +/// PayloadResult::Unauthorized => Status::Unauthorized, +/// +/// PayloadResult::DeserializationFailure => Status::NoContent, +/// +/// PayloadResult::InternalServerError => Status::InternalServerError, +/// } +/// } +/// +/// #[launch] +/// fn rocket() -> Rocket { +/// rocket::build().mount("/", routes![index, webhook]) +/// } +/// ``` +#[cfg_attr(docsrs, doc(cfg(feature = "webhooks")))] +pub struct IncomingPayload { + timestamp: i64, + now: i64, + signature: Vec, + body: String, + trace: String, +} + +impl IncomingPayload { + /// Tries to create a new incoming payload from the current timestamp, a request body, an `x-topgg-signature` header, and an `x-topgg-trace` header. Returns [`None`] if the header values cannot be parsed. + #[must_use] + pub fn new(now: DateTime, body: String, signature: &str, trace: &str) -> Option { + let signature = signature + .split(',') + .filter_map(|p| p.split_once('=')) + .collect::>(); + + if let (Some(timestamp), Some(signature)) = (signature.get("t"), signature.get("v1")) + && let (Ok(timestamp), Ok(signature)) = (timestamp.parse(), hex::decode(signature)) + { + Some(Self { + timestamp, + now: now.timestamp_millis(), + signature, + body, + trace: trace.into(), + }) + } else { + None + } + } + + /// Tries to authenticate a valid secret with this request. + #[must_use] + pub fn authenticate(&self, secret: &str) -> PayloadResult { + if (self.now - (self.timestamp * 1000)).abs() > 30000 { + return PayloadResult::Forbidden; + } + + let Ok(mut hmac) = Hmac::::new_from_slice(secret.as_bytes()) else { + warn!( + "Unable to create a SHA-256 HMAC instance from the specified webhook secret. Dismissing payload request." + ); + + return PayloadResult::InternalServerError; + }; + + hmac.update(format!("{}.{}", self.timestamp, self.body).as_bytes()); + + if hmac.verify_slice(&self.signature).is_ok() { + serde_json::from_str(&self.body).map_or_else(|_| { + warn!( + "Unable to parse Top.gg webhook payload. Please report this bug to the SDK maintainers.\n--- BEGIN BODY DUMP ---\n{}\n--- END BODY DUMP ---", + self.body + ); + + PayloadResult::DeserializationFailure + }, PayloadResult::Accepted) + } else { + PayloadResult::Unauthorized + } + } + + /// Retrieves the payload's `x-topgg-trace` header for debugging and correlating requests with Top.gg support. + #[must_use] + pub fn get_trace(&self) -> &str { + &self.trace + } +} diff --git a/src/webhooks/rocket.rs b/src/webhooks/rocket.rs index 0bfb988..38ecc6a 100644 --- a/src/webhooks/rocket.rs +++ b/src/webhooks/rocket.rs @@ -1,33 +1,47 @@ -use crate::Incoming; +use super::IncomingPayload; +use std::time::Duration; + +use chrono::Utc; use rocket::{ - data::{Data, FromData, Outcome}, + data::{Data, FromData, Outcome, ToByteUnit}, http::Status, request::Request, - serde::json::Json, }; -use serde::de::DeserializeOwned; +use tokio::time::timeout; #[cfg_attr(docsrs, doc(cfg(feature = "rocket")))] #[rocket::async_trait] -impl<'r, T> FromData<'r> for Incoming -where - T: DeserializeOwned, -{ +impl<'r> FromData<'r> for IncomingPayload { type Error = (); async fn from_data(request: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> { + let now = Utc::now(); let headers = request.headers(); - if let Some(authorization) = headers.get_one("Authorization") { - return match as FromData>::from_data(request, data).await { - Outcome::Success(data) => Outcome::Success(Self { - authorization: authorization.to_owned(), - data: data.into_inner(), - }), - _ => Outcome::Error((Status::BadRequest, ())), - }; + if let (Some(signature), Some(trace)) = ( + headers.get_one("x-topgg-signature"), + headers.get_one("x-topgg-trace"), + ) { + match timeout( + Duration::from_secs(5), + data.open(2.mebibytes()).into_bytes(), + ) + .await + { + Ok(Ok(body)) => { + if let Ok(body) = String::from_utf8(body.into_inner()) + && let Some(payload) = Self::new(now, body, signature, trace) + { + return Outcome::Success(payload); + } + } + + Err(_) => return Outcome::Error((Status::RequestTimeout, ())), + + _ => {} + } } - Outcome::Error((Status::Unauthorized, ())) + Outcome::Error((Status::BadRequest, ())) } } diff --git a/src/webhooks/vote.rs b/src/webhooks/vote.rs deleted file mode 100644 index 0bbbb54..0000000 --- a/src/webhooks/vote.rs +++ /dev/null @@ -1,67 +0,0 @@ -use crate::snowflake; -use serde::{Deserialize, Deserializer}; -use std::collections::HashMap; - -#[inline(always)] -fn deserialize_is_test<'de, D>(deserializer: D) -> Result -where - D: Deserializer<'de>, -{ - String::deserialize(deserializer).map(|s| s == "test") -} - -fn deserialize_query_string<'de, D>(deserializer: D) -> Result, D::Error> -where - D: Deserializer<'de>, -{ - Ok( - String::deserialize(deserializer) - .map(|s| { - let mut output = HashMap::new(); - - for mut it in s - .trim_start_matches('?') - .split('&') - .map(|pair| pair.split('=')) - { - if let (Some(k), Some(v)) = (it.next(), it.next()) { - if let Ok(v) = urlencoding::decode(v) { - output.insert(k.to_owned(), v.into_owned()); - } - } - } - - output - }) - .unwrap_or_default(), - ) -} - -/// A dispatched Top.gg vote event. -#[must_use] -#[derive(Clone, Debug, Deserialize)] -pub struct VoteEvent { - /// The ID of the project that received a vote. - #[serde( - deserialize_with = "snowflake::deserialize", - alias = "bot", - alias = "guild" - )] - pub receiver_id: u64, - - /// The ID of the Top.gg user who voted. - #[serde(deserialize_with = "snowflake::deserialize", rename = "user")] - pub voter_id: u64, - - /// Whether this vote is just a test done from the page settings. - #[serde(deserialize_with = "deserialize_is_test", rename = "type")] - pub is_test: bool, - - /// Whether the weekend multiplier is active, where a single vote counts as two. - #[serde(default, rename = "isWeekend")] - pub is_weekend: bool, - - /// Query strings found on the vote page. - #[serde(default, deserialize_with = "deserialize_query_string")] - pub query: HashMap, -} diff --git a/src/webhooks/warp.rs b/src/webhooks/warp.rs index 51c108b..6c1a763 100644 --- a/src/webhooks/warp.rs +++ b/src/webhooks/warp.rs @@ -1,36 +1,44 @@ -use super::Webhook; -use serde::de::DeserializeOwned; -use std::sync::Arc; -use warp::{body, header, http::StatusCode, path, Filter, Rejection, Reply}; +use super::{Payload, PayloadResult}; + +use bytes::Bytes; +use chrono::Utc; +use warp::{Filter, Rejection, body, header, path}; /// Creates a new warp [`Filter`] for receiving webhook events. /// /// # Example /// /// ```rust,no_run -/// use std::{net::SocketAddr, sync::Arc}; -/// use topgg::{VoteEvent, Webhook}; -/// use warp::Filter; -/// -/// struct MyVoteListener {} +/// use topgg::PayloadResult; +/// use std::net::SocketAddr; /// -/// #[async_trait::async_trait] -/// impl Webhook for MyVoteListener { -/// async fn callback(&self, vote: VoteEvent) { -/// println!("A user with the ID of {} has voted us on Top.gg!", vote.voter_id); -/// } -/// } +/// use warp::{Filter, http::StatusCode, reply}; /// /// #[tokio::main] /// async fn main() { -/// let state = Arc::new(MyVoteListener {}); +/// // POST /webhook +/// let webhook = +/// topgg::warp::webhook("webhook", env!("TOPGG_WEBHOOK_SECRET").into()).then(|payload, _trace| async move { +/// match payload { +/// PayloadResult::Accepted(payload) => { +/// println!("{payload:?}"); +/// +/// reply::with_status("", StatusCode::NO_CONTENT) +/// } +/// +/// PayloadResult::Forbidden => reply::with_status("Forbidden", StatusCode::FORBIDDEN), /// -/// // POST /votes -/// let webhook = topgg::warp::webhook( -/// "votes", -/// env!("MY_TOPGG_WEBHOOK_SECRET").to_string(), -/// Arc::clone(&state), -/// ); +/// PayloadResult::BadRequest => reply::with_status("Bad Request", StatusCode::BAD_REQUEST), +/// +/// PayloadResult::Unauthorized => reply::with_status("Unauthorized", StatusCode::UNAUTHORIZED), +/// +/// PayloadResult::DeserializationFailure => reply::with_status("", StatusCode::NO_CONTENT), +/// +/// PayloadResult::InternalServerError => { +/// reply::with_status("Internal Server Error", StatusCode::INTERNAL_SERVER_ERROR) +/// } +/// } +/// }); /// /// let routes = warp::get().map(|| "Hello, World!").or(webhook); /// @@ -39,34 +47,23 @@ use warp::{body, header, http::StatusCode, path, Filter, Rejection, Reply}; /// warp::serve(routes).run(addr).await /// } /// ``` +#[must_use] #[cfg_attr(docsrs, doc(cfg(feature = "warp")))] -pub fn webhook( +pub fn webhook( endpoint: &'static str, - password: String, - state: Arc, -) -> impl Filter + Clone -where - D: DeserializeOwned + Send, - T: Webhook, -{ - let password = Arc::new(password); - + secret: String, +) -> impl Filter + Clone { warp::post() .and(path(endpoint)) - .and(header("Authorization")) - .and(body::json()) - .then(move |auth: String, data: D| { - let current_state = Arc::clone(&state); - let current_password = Arc::clone(&password); - - async move { - if auth == *current_password { - current_state.callback(data).await; + .and(header("x-topgg-signature")) + .and(body::content_length_limit(2 * 1024 * 1024)) + .and(body::bytes()) + .map(move |signature: String, body: Bytes| { + let now = Utc::now(); - StatusCode::NO_CONTENT - } else { - StatusCode::UNAUTHORIZED - } - } + String::from_utf8(body.to_vec()).map_or(PayloadResult::BadRequest, |body| { + Payload::new(now, body, &signature, &secret) + }) }) + .and(header("x-topgg-trace")) }