瀏覽代碼

Move rocket into main.rs and out of primary library

Mark Rousskov 5 年之前
父節點
當前提交
2b84e897b0
共有 3 個文件被更改,包括 155 次插入141 次删除
  1. 80 0
      src/lib.rs
  2. 44 78
      src/main.rs
  3. 31 63
      src/payload.rs

+ 80 - 0
src/lib.rs

@@ -0,0 +1,80 @@
+#![allow(clippy::new_without_default)]
+
+use failure::{Error, ResultExt};
+
+use interactions::ErrorComment;
+
+pub mod config;
+pub mod github;
+pub mod handlers;
+pub mod interactions;
+pub mod payload;
+pub mod team;
+
+pub enum EventName {
+    IssueComment,
+    Issue,
+    Other,
+}
+
+impl std::str::FromStr for EventName {
+    type Err = std::convert::Infallible;
+    fn from_str(s: &str) -> Result<EventName, Self::Err> {
+        Ok(match s {
+            "issue_comment" => EventName::IssueComment,
+            "issues" => EventName::Issue,
+            _ => EventName::Other,
+        })
+    }
+}
+
+#[derive(Debug)]
+pub struct WebhookError(Error);
+
+impl From<Error> for WebhookError {
+    fn from(e: Error) -> WebhookError {
+        WebhookError(e)
+    }
+}
+
+pub fn deserialize_payload<T: serde::de::DeserializeOwned>(v: &str) -> Result<T, Error> {
+    Ok(serde_json::from_str(&v).with_context(|_| format!("input: {:?}", v))?)
+}
+
+pub fn webhook(
+    event: EventName,
+    payload: String,
+    ctx: &handlers::Context,
+) -> Result<(), WebhookError> {
+    match event {
+        EventName::IssueComment => {
+            let payload = deserialize_payload::<github::IssueCommentEvent>(&payload)
+                .context("IssueCommentEvent failed to deserialize")
+                .map_err(Error::from)?;
+
+            let event = github::Event::IssueComment(payload);
+            if let Err(err) = handlers::handle(&ctx, &event) {
+                if let Some(issue) = event.issue() {
+                    ErrorComment::new(issue, err.to_string()).post(&ctx.github)?;
+                }
+                return Err(err.into());
+            }
+        }
+        EventName::Issue => {
+            let payload = deserialize_payload::<github::IssuesEvent>(&payload)
+                .context("IssuesEvent failed to deserialize")
+                .map_err(Error::from)?;
+
+            let event = github::Event::Issue(payload);
+            if let Err(err) = handlers::handle(&ctx, &event) {
+                if let Some(issue) = event.issue() {
+                    ErrorComment::new(issue, err.to_string()).post(&ctx.github)?;
+                }
+                return Err(err.into());
+            }
+        }
+        // Other events need not be handled
+        EventName::Other => {}
+    }
+    Ok(())
+}

+ 44 - 78
src/main.rs

@@ -4,105 +4,71 @@
 #[macro_use]
 extern crate rocket;
 
-use failure::{Error, ResultExt};
+use failure::ResultExt;
 use reqwest::Client;
-use rocket::request;
-use rocket::State;
-use rocket::{http::Status, Outcome, Request};
-use std::env;
+use rocket::{
+    data::Data,
+    http::Status,
+    request::{self, FromRequest, Request},
+    Outcome, State,
+};
+use std::{env, io::Read};
+use triagebot::{github, handlers, payload, EventName, WebhookError};
 
-mod config;
-mod github;
-mod handlers;
-mod interactions;
-mod payload;
-mod team;
+struct XGitHubEvent<'r>(&'r str);
 
-use interactions::ErrorComment;
-use payload::SignedPayload;
-
-enum EventName {
-    IssueComment,
-    Issue,
-    Other,
-}
-
-impl<'a, 'r> request::FromRequest<'a, 'r> for EventName {
-    type Error = String;
+impl<'a, 'r> FromRequest<'a, 'r> for XGitHubEvent<'a> {
+    type Error = &'static str;
     fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
         let ev = if let Some(ev) = req.headers().get_one("X-GitHub-Event") {
             ev
         } else {
-            return Outcome::Failure((Status::BadRequest, "Needs a X-GitHub-Event".into()));
-        };
-        let ev = match ev {
-            "issue_comment" => EventName::IssueComment,
-            "issues" => EventName::Issue,
-            _ => EventName::Other,
+            return Outcome::Failure((Status::BadRequest, "Needs a X-GitHub-Event"));
         };
-        Outcome::Success(ev)
+        Outcome::Success(XGitHubEvent(ev))
     }
 }
 
-#[derive(Debug)]
-struct WebhookError(Error);
-
-impl<'r> rocket::response::Responder<'r> for WebhookError {
-    fn respond_to(self, _: &Request) -> rocket::response::Result<'r> {
-        let body = format!("{:?}", self.0);
-        rocket::Response::build()
-            .header(rocket::http::ContentType::Plain)
-            .status(rocket::http::Status::InternalServerError)
-            .sized_body(std::io::Cursor::new(body))
-            .ok()
-    }
-}
+struct XHubSignature<'r>(&'r str);
 
-impl From<Error> for WebhookError {
-    fn from(e: Error) -> WebhookError {
-        WebhookError(e)
+impl<'a, 'r> FromRequest<'a, 'r> for XHubSignature<'a> {
+    type Error = &'static str;
+    fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
+        let ev = if let Some(ev) = req.headers().get_one("X-Hub-Signature") {
+            ev
+        } else {
+            return Outcome::Failure((Status::BadRequest, "Needs a X-Hub-Signature"));
+        };
+        Outcome::Success(XHubSignature(ev))
     }
 }
 
 #[post("/github-hook", data = "<payload>")]
 fn webhook(
-    event: EventName,
-    payload: SignedPayload,
+    signature: XHubSignature,
+    event_header: XGitHubEvent,
+    payload: Data,
     ctx: State<handlers::Context>,
 ) -> Result<(), WebhookError> {
-    match event {
-        EventName::IssueComment => {
-            let payload = payload
-                .deserialize::<github::IssueCommentEvent>()
-                .context("IssueCommentEvent failed to deserialize")
-                .map_err(Error::from)?;
-
-            let event = github::Event::IssueComment(payload);
-            if let Err(err) = handlers::handle(&ctx, &event) {
-                if let Some(issue) = event.issue() {
-                    ErrorComment::new(issue, err.to_string()).post(&ctx.github)?;
-                }
-                return Err(err.into());
-            }
-        }
-        EventName::Issue => {
-            let payload = payload
-                .deserialize::<github::IssuesEvent>()
-                .context("IssueCommentEvent failed to deserialize")
-                .map_err(Error::from)?;
+    let event = match event_header.0.parse::<EventName>() {
+        Ok(v) => v,
+        Err(_) => unreachable!(),
+    };
 
-            let event = github::Event::Issue(payload);
-            if let Err(err) = handlers::handle(&ctx, &event) {
-                if let Some(issue) = event.issue() {
-                    ErrorComment::new(issue, err.to_string()).post(&ctx.github)?;
-                }
-                return Err(err.into());
-            }
-        }
-        // Other events need not be handled
-        EventName::Other => {}
+    let mut stream = payload.open().take(1024 * 1024 * 5); // 5 Megabytes
+    let mut buf = Vec::new();
+    if let Err(err) = stream.read_to_end(&mut buf) {
+        log::trace!("failed to read request body: {:?}", err);
+        return Err(WebhookError::from(failure::err_msg(
+            "failed to read request body",
+        )));
     }
-    Ok(())
+
+    payload::assert_signed(signature.0, &buf).map_err(failure::Error::from)?;
+    let payload = String::from_utf8(buf)
+        .context("utf-8 payload required")
+        .map_err(failure::Error::from)?;
+    triagebot::webhook(event, payload, &ctx)
 }
 
 #[catch(404)]

+ 31 - 63
src/payload.rs

@@ -1,71 +1,39 @@
-//! This module implements the payload verification for GitHub webhook events.
-
 use openssl::{hash::MessageDigest, memcmp, pkey::PKey, sign::Signer};
-use rocket::{
-    data::{self, Data, FromDataSimple},
-    http::Status,
-    request::Request,
-    Outcome,
-};
-use std::{env, io::Read};
-
-pub struct SignedPayload(Vec<u8>);
-
-impl FromDataSimple for SignedPayload {
-    type Error = String;
-    fn from_data(req: &Request, data: Data) -> data::Outcome<Self, Self::Error> {
-        let signature = match req.headers().get_one("X-Hub-Signature") {
-            Some(s) => s,
-            None => {
-                return Outcome::Failure((
-                    Status::Unauthorized,
-                    "Unauthorized, no signature".into(),
-                ));
-            }
-        };
-        let signature = &signature["sha1=".len()..];
-        let signature = match hex::decode(&signature) {
-            Ok(e) => e,
-            Err(e) => {
-                return Outcome::Failure((
-                    Status::BadRequest,
-                    format!(
-                        "failed to convert signature {:?} from hex: {:?}",
-                        signature, e
-                    ),
-                ));
-            }
-        };
-
-        let mut stream = data.open().take(1024 * 1024 * 5); // 5 Megabytes
-        let mut buf = Vec::new();
-        if let Err(err) = stream.read_to_end(&mut buf) {
-            return Outcome::Failure((
-                Status::InternalServerError,
-                format!("failed to read request body to string: {:?}", err),
-            ));
-        }
+use std::fmt;
 
-        let key = PKey::hmac(
-            env::var("GITHUB_WEBHOOK_SECRET")
-                .expect("Missing GITHUB_WEBHOOK_SECRET")
-                .as_bytes(),
-        )
-        .unwrap();
-        let mut signer = Signer::new(MessageDigest::sha1(), &key).unwrap();
-        signer.update(&buf).unwrap();
-        let hmac = signer.sign_to_vec().unwrap();
+#[derive(Debug)]
+pub struct SignedPayloadError;
 
-        if !memcmp::eq(&hmac, &signature) {
-            return Outcome::Failure((Status::Unauthorized, "HMAC not correct".into()));
-        }
-
-        Outcome::Success(SignedPayload(buf))
+impl fmt::Display for SignedPayloadError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "failed to validate payload")
     }
 }
 
-impl SignedPayload {
-    pub fn deserialize<T: serde::de::DeserializeOwned>(self) -> Result<T, serde_json::Error> {
-        serde_json::from_slice(&self.0)
+impl std::error::Error for SignedPayloadError {}
+
+pub fn assert_signed(signature: &str, payload: &[u8]) -> Result<(), SignedPayloadError> {
+    let signature = signature.get("sha1=".len()..).ok_or(SignedPayloadError)?;
+    let signature = match hex::decode(&signature) {
+        Ok(e) => e,
+        Err(e) => {
+            log::trace!("hex decode failed for {:?}: {:?}", signature, e);
+            return Err(SignedPayloadError);
+        }
+    };
+
+    let key = PKey::hmac(
+        std::env::var("GITHUB_WEBHOOK_SECRET")
+            .expect("Missing GITHUB_WEBHOOK_SECRET")
+            .as_bytes(),
+    )
+    .unwrap();
+    let mut signer = Signer::new(MessageDigest::sha1(), &key).unwrap();
+    signer.update(&payload).unwrap();
+    let hmac = signer.sign_to_vec().unwrap();
+
+    if !memcmp::eq(&hmac, &signature) {
+        return Err(SignedPayloadError);
     }
+    Ok(())
 }