main.rs 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. #![allow(clippy::new_without_default)]
  2. use anyhow::Context as _;
  3. use futures::future::FutureExt;
  4. use futures::StreamExt;
  5. use hyper::{header, Body, Request, Response, Server, StatusCode};
  6. use reqwest::Client;
  7. use route_recognizer::Router;
  8. use std::{env, net::SocketAddr, sync::Arc};
  9. use tracing as log;
  10. use tracing::Instrument;
  11. use triagebot::{db, github, handlers::Context, notification_listing, payload, EventName};
  12. async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Body>, hyper::Error> {
  13. log::info!("request = {:?}", req);
  14. let mut router = Router::new();
  15. router.add("/triage", "index".to_string());
  16. router.add("/triage/:owner/:repo", "pulls".to_string());
  17. let (req, body_stream) = req.into_parts();
  18. if let Ok(matcher) = router.recognize(req.uri.path()) {
  19. if matcher.handler().as_str() == "pulls" {
  20. let params = matcher.params();
  21. let owner = params.find("owner");
  22. let repo = params.find("repo");
  23. return triagebot::triage::pulls(ctx, owner.unwrap(), repo.unwrap()).await;
  24. } else {
  25. return triagebot::triage::index();
  26. }
  27. }
  28. if req.uri.path() == "/" {
  29. return Ok(Response::builder()
  30. .status(StatusCode::OK)
  31. .body(Body::from("Triagebot is awaiting triage."))
  32. .unwrap());
  33. }
  34. if req.uri.path() == "/bors-commit-list" {
  35. let res = db::rustc_commits::get_commits_with_artifacts(&*ctx.db.get().await).await;
  36. let res = match res {
  37. Ok(r) => r,
  38. Err(e) => {
  39. return Ok(Response::builder()
  40. .status(StatusCode::INTERNAL_SERVER_ERROR)
  41. .body(Body::from(format!("{:?}", e)))
  42. .unwrap());
  43. }
  44. };
  45. return Ok(Response::builder()
  46. .status(StatusCode::OK)
  47. .header("Content-Type", "application/json")
  48. .body(Body::from(serde_json::to_string(&res).unwrap()))
  49. .unwrap());
  50. }
  51. if req.uri.path() == "/notifications" {
  52. if let Some(query) = req.uri.query() {
  53. let user = url::form_urlencoded::parse(query.as_bytes()).find(|(k, _)| k == "user");
  54. if let Some((_, name)) = user {
  55. return Ok(Response::builder()
  56. .status(StatusCode::OK)
  57. .body(Body::from(
  58. notification_listing::render(&ctx.db.get().await, &*name).await,
  59. ))
  60. .unwrap());
  61. }
  62. }
  63. return Ok(Response::builder()
  64. .status(StatusCode::OK)
  65. .body(Body::from(String::from(
  66. "Please provide `?user=<username>` query param on URL.",
  67. )))
  68. .unwrap());
  69. }
  70. if req.uri.path() == "/zulip-hook" {
  71. let mut c = body_stream;
  72. let mut payload = Vec::new();
  73. while let Some(chunk) = c.next().await {
  74. let chunk = chunk?;
  75. payload.extend_from_slice(&chunk);
  76. }
  77. let req = match serde_json::from_slice(&payload) {
  78. Ok(r) => r,
  79. Err(e) => {
  80. return Ok(Response::builder()
  81. .status(StatusCode::BAD_REQUEST)
  82. .body(Body::from(format!(
  83. "Did not send valid JSON request: {}",
  84. e
  85. )))
  86. .unwrap());
  87. }
  88. };
  89. return Ok(Response::builder()
  90. .status(StatusCode::OK)
  91. .body(Body::from(triagebot::zulip::respond(&ctx, req).await))
  92. .unwrap());
  93. }
  94. if req.uri.path() != "/github-hook" {
  95. return Ok(Response::builder()
  96. .status(StatusCode::NOT_FOUND)
  97. .body(Body::empty())
  98. .unwrap());
  99. }
  100. if req.method != hyper::Method::POST {
  101. return Ok(Response::builder()
  102. .status(StatusCode::METHOD_NOT_ALLOWED)
  103. .header(header::ALLOW, "POST")
  104. .body(Body::empty())
  105. .unwrap());
  106. }
  107. let event = if let Some(ev) = req.headers.get("X-GitHub-Event") {
  108. let ev = match ev.to_str().ok() {
  109. Some(v) => v,
  110. None => {
  111. return Ok(Response::builder()
  112. .status(StatusCode::BAD_REQUEST)
  113. .body(Body::from("X-GitHub-Event header must be UTF-8 encoded"))
  114. .unwrap());
  115. }
  116. };
  117. match ev.parse::<EventName>() {
  118. Ok(v) => v,
  119. Err(_) => unreachable!(),
  120. }
  121. } else {
  122. return Ok(Response::builder()
  123. .status(StatusCode::BAD_REQUEST)
  124. .body(Body::from("X-GitHub-Event header must be set"))
  125. .unwrap());
  126. };
  127. log::debug!("event={}", event);
  128. let signature = if let Some(sig) = req.headers.get("X-Hub-Signature") {
  129. match sig.to_str().ok() {
  130. Some(v) => v,
  131. None => {
  132. return Ok(Response::builder()
  133. .status(StatusCode::BAD_REQUEST)
  134. .body(Body::from("X-Hub-Signature header must be UTF-8 encoded"))
  135. .unwrap());
  136. }
  137. }
  138. } else {
  139. return Ok(Response::builder()
  140. .status(StatusCode::BAD_REQUEST)
  141. .body(Body::from("X-Hub-Signature header must be set"))
  142. .unwrap());
  143. };
  144. log::debug!("signature={}", signature);
  145. let mut c = body_stream;
  146. let mut payload = Vec::new();
  147. while let Some(chunk) = c.next().await {
  148. let chunk = chunk?;
  149. payload.extend_from_slice(&chunk);
  150. }
  151. if let Err(_) = payload::assert_signed(signature, &payload) {
  152. return Ok(Response::builder()
  153. .status(StatusCode::FORBIDDEN)
  154. .body(Body::from("Wrong signature"))
  155. .unwrap());
  156. }
  157. let payload = match String::from_utf8(payload) {
  158. Ok(p) => p,
  159. Err(_) => {
  160. return Ok(Response::builder()
  161. .status(StatusCode::BAD_REQUEST)
  162. .body(Body::from("Payload must be UTF-8"))
  163. .unwrap());
  164. }
  165. };
  166. match triagebot::webhook(event, payload, &ctx).await {
  167. Ok(true) => Ok(Response::new(Body::from("processed request"))),
  168. Ok(false) => Ok(Response::new(Body::from("ignored request"))),
  169. Err(err) => {
  170. log::error!("request failed: {:?}", err);
  171. Ok(Response::builder()
  172. .status(StatusCode::INTERNAL_SERVER_ERROR)
  173. .body(Body::from(format!("request failed: {:?}", err)))
  174. .unwrap())
  175. }
  176. }
  177. }
  178. async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
  179. log::info!("Listening on http://{}", addr);
  180. let pool = db::ClientPool::new();
  181. db::run_migrations(&*pool.get().await)
  182. .await
  183. .context("database migrations")?;
  184. let client = Client::new();
  185. let gh = github::GithubClient::new_with_default_token(client.clone());
  186. let oc = octocrab::OctocrabBuilder::new()
  187. .personal_token(github::default_token_from_env())
  188. .build()
  189. .expect("Failed to build octograb.");
  190. let ctx = Arc::new(Context {
  191. username: String::from("rustbot"),
  192. db: pool,
  193. github: gh,
  194. octocrab: oc,
  195. });
  196. let svc = hyper::service::make_service_fn(move |_conn| {
  197. let ctx = ctx.clone();
  198. async move {
  199. Ok::<_, hyper::Error>(hyper::service::service_fn(move |req| {
  200. let uuid = uuid::Uuid::new_v4();
  201. let span = tracing::span!(tracing::Level::INFO, "request", ?uuid);
  202. serve_req(req, ctx.clone())
  203. .map(move |mut resp| {
  204. if let Ok(resp) = &mut resp {
  205. resp.headers_mut()
  206. .insert("X-Request-Id", uuid.to_string().parse().unwrap());
  207. }
  208. log::info!("response = {:?}", resp);
  209. resp
  210. })
  211. .instrument(span)
  212. }))
  213. }
  214. });
  215. let serve_future = Server::bind(&addr).serve(svc);
  216. serve_future.await?;
  217. Ok(())
  218. }
  219. #[tokio::main(flavor = "current_thread")]
  220. async fn main() {
  221. dotenv::dotenv().ok();
  222. tracing_subscriber::fmt::Subscriber::builder()
  223. .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
  224. .with_ansi(std::env::var_os("DISABLE_COLOR").is_none())
  225. .try_init()
  226. .unwrap();
  227. let port = env::var("PORT")
  228. .ok()
  229. .map(|p| p.parse::<u16>().expect("parsed PORT"))
  230. .unwrap_or(8000);
  231. let addr = ([0, 0, 0, 0], port).into();
  232. if let Err(e) = run_server(addr).await {
  233. eprintln!("Failed to run server: {:?}", e);
  234. }
  235. }