#![allow(clippy::new_without_default)]
use anyhow::Context as _;
use futures::future::FutureExt;
use futures::StreamExt;
use hyper::{header, Body, Request, Response, Server, StatusCode};
use reqwest::Client;
use route_recognizer::Router;
use std::{env, net::SocketAddr, sync::Arc};
use triagebot::{db, github, handlers::Context, logger, notification_listing, payload, EventName};
use uuid::Uuid;
async fn serve_req(req: Request
, ctx: Arc) -> Result, hyper::Error> {
log::info!("request = {:?}", req);
let mut router = Router::new();
router.add("/triage", "index".to_string());
router.add("/triage/:owner/:repo", "pulls".to_string());
let (req, body_stream) = req.into_parts();
if let Ok(matcher) = router.recognize(req.uri.path()) {
if matcher.handler().as_str() == "pulls" {
let params = matcher.params();
let owner = params.find("owner");
let repo = params.find("repo");
return triagebot::triage::pulls(ctx, owner.unwrap(), repo.unwrap()).await;
} else {
return triagebot::triage::index();
}
}
if req.uri.path() == "/" {
return Ok(Response::builder()
.status(StatusCode::OK)
.body(Body::from("Triagebot is awaiting triage."))
.unwrap());
}
if req.uri.path() == "/bors-commit-list" {
let res = db::rustc_commits::get_commits_with_artifacts(&ctx.db).await;
let res = match res {
Ok(r) => r,
Err(e) => {
return Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from(format!("{:?}", e)))
.unwrap());
}
};
return Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(Body::from(serde_json::to_string(&res).unwrap()))
.unwrap());
}
if req.uri.path() == "/notifications" {
if let Some(query) = req.uri.query() {
let user = url::form_urlencoded::parse(query.as_bytes()).find(|(k, _)| k == "user");
if let Some((_, name)) = user {
return Ok(Response::builder()
.status(StatusCode::OK)
.body(Body::from(
notification_listing::render(&ctx.db, &*name).await,
))
.unwrap());
}
}
return Ok(Response::builder()
.status(StatusCode::OK)
.body(Body::from(String::from(
"Please provide `?user=` query param on URL.",
)))
.unwrap());
}
if req.uri.path() == "/zulip-hook" {
let mut c = body_stream;
let mut payload = Vec::new();
while let Some(chunk) = c.next().await {
let chunk = chunk?;
payload.extend_from_slice(&chunk);
}
let req = match serde_json::from_slice(&payload) {
Ok(r) => r,
Err(e) => {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from(format!(
"Did not send valid JSON request: {}",
e
)))
.unwrap());
}
};
return Ok(Response::builder()
.status(StatusCode::OK)
.body(Body::from(triagebot::zulip::respond(&ctx, req).await))
.unwrap());
}
if req.uri.path() != "/github-hook" {
return Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::empty())
.unwrap());
}
if req.method != hyper::Method::POST {
return Ok(Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.header(header::ALLOW, "POST")
.body(Body::empty())
.unwrap());
}
let event = if let Some(ev) = req.headers.get("X-GitHub-Event") {
let ev = match ev.to_str().ok() {
Some(v) => v,
None => {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("X-GitHub-Event header must be UTF-8 encoded"))
.unwrap());
}
};
match ev.parse::() {
Ok(v) => v,
Err(_) => unreachable!(),
}
} else {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("X-GitHub-Event header must be set"))
.unwrap());
};
log::debug!("event={}", event);
let signature = if let Some(sig) = req.headers.get("X-Hub-Signature") {
match sig.to_str().ok() {
Some(v) => v,
None => {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("X-Hub-Signature header must be UTF-8 encoded"))
.unwrap());
}
}
} else {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("X-Hub-Signature header must be set"))
.unwrap());
};
log::debug!("signature={}", signature);
let mut c = body_stream;
let mut payload = Vec::new();
while let Some(chunk) = c.next().await {
let chunk = chunk?;
payload.extend_from_slice(&chunk);
}
if let Err(_) = payload::assert_signed(signature, &payload) {
return Ok(Response::builder()
.status(StatusCode::FORBIDDEN)
.body(Body::from("Wrong signature"))
.unwrap());
}
let payload = match String::from_utf8(payload) {
Ok(p) => p,
Err(_) => {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Payload must be UTF-8"))
.unwrap());
}
};
match triagebot::webhook(event, payload, &ctx).await {
Ok(true) => Ok(Response::new(Body::from("processed request"))),
Ok(false) => Ok(Response::new(Body::from("ignored request"))),
Err(err) => {
log::error!("request failed: {:?}", err);
Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from(format!("request failed: {:?}", err)))
.unwrap())
}
}
}
/// Opens the database, runs migrations, builds the shared [`Context`], and serves HTTP
/// on `addr` with `serve_req` until the server stops.
///
/// Each request is given a fresh UUID. The UUID is passed to `logger::LogFuture` and,
/// on success, echoed back in the `X-Request-Id` response header.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * the database connection or migrations fail;
/// * `addr` cannot be bound;
/// * the server terminates with an error.
///
/// # Panics
///
/// Panics if the octocrab client cannot be built.
async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
    let db_client = db::make_client()
        .await
        .context("open database connection")?;
    db::run_migrations(&db_client)
        .await
        .context("database migrations")?;
    let gh = github::GithubClient::new_with_default_token(Client::new());
    let oc = octocrab::OctocrabBuilder::new()
        .personal_token(github::default_token_from_env())
        .build()
        .expect("Failed to build octograb.");
    let ctx = Arc::new(Context {
        username: String::from("rustbot"),
        db: db_client,
        github: gh,
        octocrab: oc,
    });
    let svc = hyper::service::make_service_fn(move |_conn| {
        let ctx = ctx.clone();
        async move {
            Ok::<_, hyper::Error>(hyper::service::service_fn(move |req| {
                // Generate the id per request, not per connection. Otherwise every
                // request on a keep-alive connection would share one X-Request-Id.
                let uuid = Uuid::new_v4();
                logger::LogFuture::new(
                    uuid,
                    serve_req(req, ctx.clone()).map(move |mut resp| {
                        if let Ok(resp) = &mut resp {
                            // A hyphenated UUID is plain ASCII, so it is always a valid header value.
                            resp.headers_mut()
                                .insert("X-Request-Id", uuid.to_string().parse().unwrap());
                        }
                        log::info!("response = {:?}", resp);
                        resp
                    }),
                )
            }))
        }
    });
    // `try_bind` reports an unusable address as an error instead of panicking like `bind`.
    let server = Server::try_bind(&addr)
        .with_context(|| format!("bind {}", addr))?
        .serve(svc);
    log::info!("Listening on http://{}", addr);
    server.await?;
    Ok(())
}
#[tokio::main(flavor = "current_thread")]
async fn main() {
dotenv::dotenv().ok();
logger::init();
let port = env::var("PORT")
.ok()
.map(|p| p.parse::().expect("parsed PORT"))
.unwrap_or(8000);
let addr = ([0, 0, 0, 0], port).into();
if let Err(e) = run_server(addr).await {
eprintln!("Failed to run server: {:?}", e);
}
}