main.rs 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. #![allow(clippy::new_without_default)]
  2. use anyhow::Context as _;
  3. use futures::future::FutureExt;
  4. use futures::StreamExt;
  5. use hyper::{header, Body, Request, Response, Server, StatusCode};
  6. use reqwest::Client;
  7. use route_recognizer::Router;
  8. use std::{env, net::SocketAddr, sync::Arc};
  9. use triagebot::{db, github, handlers::Context, logger, notification_listing, payload, EventName};
  10. use uuid::Uuid;
  11. async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Body>, hyper::Error> {
  12. log::info!("request = {:?}", req);
  13. let mut router = Router::new();
  14. router.add("/triage", "index".to_string());
  15. router.add("/triage/:owner/:repo", "pulls".to_string());
  16. let (req, body_stream) = req.into_parts();
  17. if let Ok(matcher) = router.recognize(req.uri.path()) {
  18. if matcher.handler().as_str() == "pulls" {
  19. let params = matcher.params();
  20. let owner = params.find("owner");
  21. let repo = params.find("repo");
  22. return triagebot::triage::pulls(ctx, owner.unwrap(), repo.unwrap()).await;
  23. } else {
  24. return triagebot::triage::index();
  25. }
  26. }
  27. if req.uri.path() == "/" {
  28. return Ok(Response::builder()
  29. .status(StatusCode::OK)
  30. .body(Body::from("Triagebot is awaiting triage."))
  31. .unwrap());
  32. }
  33. if req.uri.path() == "/bors-commit-list" {
  34. let res = db::rustc_commits::get_commits_with_artifacts(&ctx.db).await;
  35. let res = match res {
  36. Ok(r) => r,
  37. Err(e) => {
  38. return Ok(Response::builder()
  39. .status(StatusCode::INTERNAL_SERVER_ERROR)
  40. .body(Body::from(format!("{:?}", e)))
  41. .unwrap());
  42. }
  43. };
  44. return Ok(Response::builder()
  45. .status(StatusCode::OK)
  46. .header("Content-Type", "application/json")
  47. .body(Body::from(serde_json::to_string(&res).unwrap()))
  48. .unwrap());
  49. }
  50. if req.uri.path() == "/notifications" {
  51. if let Some(query) = req.uri.query() {
  52. let user = url::form_urlencoded::parse(query.as_bytes()).find(|(k, _)| k == "user");
  53. if let Some((_, name)) = user {
  54. return Ok(Response::builder()
  55. .status(StatusCode::OK)
  56. .body(Body::from(
  57. notification_listing::render(&ctx.db, &*name).await,
  58. ))
  59. .unwrap());
  60. }
  61. }
  62. return Ok(Response::builder()
  63. .status(StatusCode::OK)
  64. .body(Body::from(String::from(
  65. "Please provide `?user=<username>` query param on URL.",
  66. )))
  67. .unwrap());
  68. }
  69. if req.uri.path() == "/zulip-hook" {
  70. let mut c = body_stream;
  71. let mut payload = Vec::new();
  72. while let Some(chunk) = c.next().await {
  73. let chunk = chunk?;
  74. payload.extend_from_slice(&chunk);
  75. }
  76. let req = match serde_json::from_slice(&payload) {
  77. Ok(r) => r,
  78. Err(e) => {
  79. return Ok(Response::builder()
  80. .status(StatusCode::BAD_REQUEST)
  81. .body(Body::from(format!(
  82. "Did not send valid JSON request: {}",
  83. e
  84. )))
  85. .unwrap());
  86. }
  87. };
  88. return Ok(Response::builder()
  89. .status(StatusCode::OK)
  90. .body(Body::from(triagebot::zulip::respond(&ctx, req).await))
  91. .unwrap());
  92. }
  93. if req.uri.path() != "/github-hook" {
  94. return Ok(Response::builder()
  95. .status(StatusCode::NOT_FOUND)
  96. .body(Body::empty())
  97. .unwrap());
  98. }
  99. if req.method != hyper::Method::POST {
  100. return Ok(Response::builder()
  101. .status(StatusCode::METHOD_NOT_ALLOWED)
  102. .header(header::ALLOW, "POST")
  103. .body(Body::empty())
  104. .unwrap());
  105. }
  106. let event = if let Some(ev) = req.headers.get("X-GitHub-Event") {
  107. let ev = match ev.to_str().ok() {
  108. Some(v) => v,
  109. None => {
  110. return Ok(Response::builder()
  111. .status(StatusCode::BAD_REQUEST)
  112. .body(Body::from("X-GitHub-Event header must be UTF-8 encoded"))
  113. .unwrap());
  114. }
  115. };
  116. match ev.parse::<EventName>() {
  117. Ok(v) => v,
  118. Err(_) => unreachable!(),
  119. }
  120. } else {
  121. return Ok(Response::builder()
  122. .status(StatusCode::BAD_REQUEST)
  123. .body(Body::from("X-GitHub-Event header must be set"))
  124. .unwrap());
  125. };
  126. log::debug!("event={}", event);
  127. let signature = if let Some(sig) = req.headers.get("X-Hub-Signature") {
  128. match sig.to_str().ok() {
  129. Some(v) => v,
  130. None => {
  131. return Ok(Response::builder()
  132. .status(StatusCode::BAD_REQUEST)
  133. .body(Body::from("X-Hub-Signature header must be UTF-8 encoded"))
  134. .unwrap());
  135. }
  136. }
  137. } else {
  138. return Ok(Response::builder()
  139. .status(StatusCode::BAD_REQUEST)
  140. .body(Body::from("X-Hub-Signature header must be set"))
  141. .unwrap());
  142. };
  143. log::debug!("signature={}", signature);
  144. let mut c = body_stream;
  145. let mut payload = Vec::new();
  146. while let Some(chunk) = c.next().await {
  147. let chunk = chunk?;
  148. payload.extend_from_slice(&chunk);
  149. }
  150. if let Err(_) = payload::assert_signed(signature, &payload) {
  151. return Ok(Response::builder()
  152. .status(StatusCode::FORBIDDEN)
  153. .body(Body::from("Wrong signature"))
  154. .unwrap());
  155. }
  156. let payload = match String::from_utf8(payload) {
  157. Ok(p) => p,
  158. Err(_) => {
  159. return Ok(Response::builder()
  160. .status(StatusCode::BAD_REQUEST)
  161. .body(Body::from("Payload must be UTF-8"))
  162. .unwrap());
  163. }
  164. };
  165. match triagebot::webhook(event, payload, &ctx).await {
  166. Ok(true) => Ok(Response::new(Body::from("processed request"))),
  167. Ok(false) => Ok(Response::new(Body::from("ignored request"))),
  168. Err(err) => {
  169. log::error!("request failed: {:?}", err);
  170. Ok(Response::builder()
  171. .status(StatusCode::INTERNAL_SERVER_ERROR)
  172. .body(Body::from(format!("request failed: {:?}", err)))
  173. .unwrap())
  174. }
  175. }
  176. }
  177. async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
  178. log::info!("Listening on http://{}", addr);
  179. let db_client = db::make_client()
  180. .await
  181. .context("open database connection")?;
  182. db::run_migrations(&db_client)
  183. .await
  184. .context("database migrations")?;
  185. let client = Client::new();
  186. let gh = github::GithubClient::new_with_default_token(client.clone());
  187. let oc = octocrab::OctocrabBuilder::new()
  188. .personal_token(github::default_token_from_env())
  189. .build()
  190. .expect("Failed to build octograb.");
  191. let ctx = Arc::new(Context {
  192. username: String::from("rustbot"),
  193. db: db_client,
  194. github: gh,
  195. octocrab: oc,
  196. });
  197. let svc = hyper::service::make_service_fn(move |_conn| {
  198. let ctx = ctx.clone();
  199. async move {
  200. let uuid = Uuid::new_v4();
  201. Ok::<_, hyper::Error>(hyper::service::service_fn(move |req| {
  202. logger::LogFuture::new(
  203. uuid,
  204. serve_req(req, ctx.clone()).map(move |mut resp| {
  205. if let Ok(resp) = &mut resp {
  206. resp.headers_mut()
  207. .insert("X-Request-Id", uuid.to_string().parse().unwrap());
  208. }
  209. log::info!("response = {:?}", resp);
  210. resp
  211. }),
  212. )
  213. }))
  214. }
  215. });
  216. let serve_future = Server::bind(&addr).serve(svc);
  217. serve_future.await?;
  218. Ok(())
  219. }
  220. #[tokio::main(flavor = "current_thread")]
  221. async fn main() {
  222. dotenv::dotenv().ok();
  223. logger::init();
  224. let port = env::var("PORT")
  225. .ok()
  226. .map(|p| p.parse::<u16>().expect("parsed PORT"))
  227. .unwrap_or(8000);
  228. let addr = ([0, 0, 0, 0], port).into();
  229. if let Err(e) = run_server(addr).await {
  230. eprintln!("Failed to run server: {:?}", e);
  231. }
  232. }