Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
version: 2
updates:
- package-ecosystem: "cargo"
directory: "/"
- package-ecosystem: cargo
directory: '/'
commit-message:
prefix: "deps: "
prefix: 'deps: '
schedule:
day: "saturday"
interval: "weekly"
time: "07:15"
day: saturday
interval: weekly
time: '07:15'
133 changes: 95 additions & 38 deletions src/webhooks/actix_web.rs
Original file line number Diff line number Diff line change
@@ -1,62 +1,119 @@
use crate::Incoming;
use actix_web::{
dev::Payload,
error::{Error, ErrorBadRequest, ErrorUnauthorized},
web::Json,
FromRequest, HttpRequest,
};
use serde::de::DeserializeOwned;
use super::IncomingPayload;
use std::{
fmt::{self, Display, Formatter},
future::Future,
pin::Pin,
task::{ready, Context, Poll},
task::{Context, Poll, ready},
time::{Duration, Instant},
};

use actix_web::{
FromRequest, HttpRequest, HttpResponse, ResponseError, body::BoxBody, dev::Payload,
http::StatusCode,
};
use chrono::{DateTime, Utc};
use futures_core::stream::Stream;

#[doc(hidden)]
pub struct IncomingFut<T: DeserializeOwned> {
#[derive(Debug)]
pub enum IncomingPayloadError {
BadRequest,
Timeout,
}

impl Display for IncomingPayloadError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.write_str(match self {
Self::BadRequest => "Bad Request.",

Self::Timeout => "Request timed out.",
})
}
}

impl ResponseError for IncomingPayloadError {
fn error_response(&self) -> HttpResponse<BoxBody> {
match self {
Self::BadRequest => HttpResponse::BadRequest().body("Bad Request"),

Self::Timeout => HttpResponse::RequestTimeout().body("Request timed out"),
}
}

fn status_code(&self) -> StatusCode {
match self {
Self::BadRequest => StatusCode::BAD_REQUEST,

Self::Timeout => StatusCode::REQUEST_TIMEOUT,
}
}
}

#[doc(hidden)]
pub struct IncomingPayloadFut {
req: HttpRequest,
json_fut: <Json<T> as FromRequest>::Future,
payload: Payload,
body: Vec<u8>,
start: Instant,
now: DateTime<Utc>,
}

impl IncomingPayloadFut {
fn timed_out(&self) -> bool {
self.start.elapsed() > Duration::from_secs(5)
}
}

impl<T> Future for IncomingFut<T>
where
T: DeserializeOwned,
{
type Output = Result<Incoming<T>, Error>;
impl Future for IncomingPayloadFut {
type Output = Result<IncomingPayload, IncomingPayloadError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Ok(json) = ready!(Pin::new(&mut self.json_fut).poll(cx)) {
let headers = self.req.headers();

if let Some(authorization) = headers.get("Authorization") {
if let Ok(authorization) = authorization.to_str() {
return Poll::Ready(Ok(Incoming {
authorization: authorization.to_owned(),
data: json.into_inner(),
}));
}
if self.timed_out() {
return Poll::Ready(Err(IncomingPayloadError::Timeout));
}

while let Some(body) = ready!(Pin::new(&mut self.payload).poll_next(cx)) {
if self.timed_out() {
return Poll::Ready(Err(IncomingPayloadError::Timeout));
}

return Poll::Ready(Err(ErrorUnauthorized("401")));
if let Ok(body) = body {
self.body.extend_from_slice(&body);
} else {
return Poll::Ready(Err(IncomingPayloadError::BadRequest));
}
}

Poll::Ready(Err(ErrorBadRequest("400")))
let headers = self.req.headers();

if let (Some(signature), Some(trace)) = (
headers.get("x-topgg-signature"),
headers.get("x-topgg-trace"),
) && let (Ok(signature), Ok(trace), Ok(body)) = (
signature.to_str(),
trace.to_str(),
str::from_utf8(&self.body),
) && let Some(incoming) = IncomingPayload::new(self.now, body.into(), signature, trace)
{
return Poll::Ready(Ok(incoming));
}

Poll::Ready(Err(IncomingPayloadError::BadRequest))
}
}

#[cfg_attr(docsrs, doc(cfg(feature = "actix-web")))]
impl<T> FromRequest for Incoming<T>
where
T: DeserializeOwned,
{
type Error = Error;
type Future = IncomingFut<T>;

#[inline(always)]
impl FromRequest for IncomingPayload {
type Error = IncomingPayloadError;
type Future = IncomingPayloadFut;

fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
IncomingFut {
IncomingPayloadFut {
req: req.clone(),
json_fut: Json::from_request(req, payload),
payload: payload.take(),
body: vec![],
start: Instant::now(),
now: Utc::now(),
}
}
}
131 changes: 95 additions & 36 deletions src/webhooks/axum.rs
Original file line number Diff line number Diff line change
@@ -1,45 +1,77 @@
use super::Webhook;
use super::{Payload, PayloadResult};
use std::{sync::Arc, time::Duration};

use axum::{
extract::State,
BoxError, Router,
error_handling::HandleErrorLayer,
extract::{DefaultBodyLimit, State},
http::{HeaderMap, StatusCode},
response::IntoResponse,
response::{IntoResponse, Response},
routing::post,
Router,
};
use serde::de::DeserializeOwned;
use std::sync::Arc;
use chrono::Utc;
use tower::{ServiceBuilder, timeout::error::Elapsed};

/// An axum webhook listener for listening to payloads.
///
/// # Example
///
/// ```rust,no_run
/// struct MyTopggListener {}
///
/// #[async_trait::async_trait]
/// impl topgg::axum::Listener for MyTopggListener {
/// async fn callback(self: Arc<Self>, payload: Payload, _trace: &str) -> Response {
/// println!("{payload:?}");
///
/// (StatusCode::NO_CONTENT, ()).into_response()
/// }
/// }
/// ```
#[async_trait::async_trait]
#[cfg_attr(docsrs, doc(cfg(feature = "axum")))]
pub trait Listener: Send + Sync + 'static {
async fn callback(self: Arc<Self>, payload: Payload, trace: &str) -> Response;
}

struct WebhookState<T> {
state: Arc<T>,
password: Arc<String>,
secret: Arc<String>,
}

impl<T> Clone for WebhookState<T> {
#[inline(always)]
fn clone(&self) -> Self {
Self {
state: Arc::clone(&self.state),
password: Arc::clone(&self.password),
state: self.state.clone(),
secret: self.secret.clone(),
}
}
}

/// Creates a new axum [`Router`] for receiving vote events.
/// Creates a new axum [`Router`] for receiving webhook payloads.
///
/// # Example
///
/// ```rust,no_run
/// use axum::{routing::get, Router};
/// use topgg::{VoteEvent, Webhook};
/// use tokio::net::TcpListener;
/// use topgg::Payload;
/// use std::sync::Arc;
///
/// struct MyVoteListener {}
/// use axum::{
/// Router,
/// http::status::StatusCode,
/// response::{IntoResponse, Response},
/// routing::get,
/// };
/// use tokio::net::TcpListener;
///
/// struct MyTopggListener {}
///
/// #[async_trait::async_trait]
/// impl Webhook<VoteEvent> for MyVoteListener {
/// async fn callback(&self, vote: VoteEvent) {
/// println!("A user with the ID of {} has voted us on Top.gg!", vote.voter_id);
/// impl topgg::axum::Listener for MyTopggListener {
/// async fn callback(self: Arc<Self>, payload: Payload, _trace: &str) -> Response {
/// println!("{payload:?}");
///
/// (StatusCode::NO_CONTENT, ()).into_response()
/// }
/// }
///
Expand All @@ -49,48 +81,75 @@ impl<T> Clone for WebhookState<T> {
///
/// #[tokio::main]
/// async fn main() {
/// let state = Arc::new(MyVoteListener {});
/// let state = Arc::new(MyTopggListener {});
///
/// // POST /webhook
/// let router = Router::new().route("/", get(index)).nest(
/// "/votes",
/// topgg::axum::webhook(env!("MY_TOPGG_WEBHOOK_SECRET").to_string(), Arc::clone(&state)),
/// "/webhook",
/// topgg::axum::webhook(Arc::clone(&state), env!("TOPGG_WEBHOOK_SECRET").into()),
/// );
///
/// let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();
///
/// axum::serve(listener, router).await.unwrap();
/// }
/// ```
#[inline(always)]
#[cfg_attr(docsrs, doc(cfg(feature = "axum")))]
pub fn webhook<D, T>(password: String, state: Arc<T>) -> Router
pub fn webhook<S>(state: Arc<S>, secret: String) -> Router
where
D: DeserializeOwned + Send,
T: Webhook<D>,
S: Listener,
{
let timeout_layer = ServiceBuilder::new()
.layer(HandleErrorLayer::new(|err: BoxError| async move {
if err.is::<Elapsed>() {
(StatusCode::REQUEST_TIMEOUT, "Request timed out")
} else {
(StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error")
}
}))
.timeout(Duration::from_secs(5));

Router::new()
.route(
"/",
post(
async |headers: HeaderMap, State(webhook): State<WebhookState<T>>, body: String| {
if let Some(authorization) = headers.get("Authorization") {
if let Ok(authorization) = authorization.to_str() {
if authorization == *(webhook.password) {
if let Ok(data) = serde_json::from_str(&body) {
webhook.state.callback(data).await;
async |headers: HeaderMap, State(wrapped_state): State<WebhookState<S>>, body: String| {
let now = Utc::now();

return (StatusCode::NO_CONTENT, ()).into_response();
}
if let Some(signature) = headers.get("x-topgg-signature")
&& let Ok(signature) = signature.to_str()
&& let Some(trace) = headers.get("x-topgg-trace")
&& let Ok(trace) = trace.to_str()
{
match Payload::new(now, body, signature, &wrapped_state.secret) {
PayloadResult::Accepted(payload) => {
wrapped_state.state.callback(payload, trace).await
}

PayloadResult::Forbidden => (StatusCode::FORBIDDEN, "Forbidden").into_response(),

PayloadResult::BadRequest => (StatusCode::BAD_REQUEST, "Bad Request").into_response(),

PayloadResult::Unauthorized => {
(StatusCode::UNAUTHORIZED, "Unauthorized").into_response()
}

PayloadResult::DeserializationFailure => (StatusCode::NO_CONTENT, "").into_response(),

PayloadResult::InternalServerError => {
(StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response()
}
}
} else {
(StatusCode::BAD_REQUEST, "Bad Request").into_response()
}

(StatusCode::UNAUTHORIZED, ()).into_response()
},
),
)
.layer(timeout_layer.into_inner())
.layer(DefaultBodyLimit::max(2 * 1024 * 1024))
.with_state(WebhookState {
state,
password: Arc::new(password),
secret: Arc::new(secret),
})
}
Loading