main.rs 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. #![allow(clippy::new_without_default)]
  2. use anyhow::Context as _;
  3. use futures::{future::FutureExt, stream::StreamExt};
  4. use hyper::{header, Body, Request, Response, Server, StatusCode};
  5. use reqwest::Client;
  6. use route_recognizer::Router;
  7. use std::{env, net::SocketAddr, sync::Arc};
  8. use triagebot::{db, github, handlers::Context, logger, notification_listing, payload, EventName};
  9. use uuid::Uuid;
  10. async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Body>, hyper::Error> {
  11. log::info!("request = {:?}", req);
  12. let mut router = Router::new();
  13. router.add("/triage", "index".to_string());
  14. router.add("/triage/:owner/:repo", "pulls".to_string());
  15. let (req, body_stream) = req.into_parts();
  16. if let Ok(matcher) = router.recognize(req.uri.path()) {
  17. if matcher.handler().as_str() == "pulls" {
  18. let params = matcher.params();
  19. let owner = params.find("owner");
  20. let repo = params.find("repo");
  21. return triagebot::triage::pulls(ctx, owner.unwrap(), repo.unwrap()).await;
  22. } else {
  23. return triagebot::triage::index();
  24. }
  25. }
  26. if req.uri.path() == "/" {
  27. return Ok(Response::builder()
  28. .status(StatusCode::OK)
  29. .body(Body::from("Triagebot is awaiting triage."))
  30. .unwrap());
  31. }
  32. if req.uri.path() == "/bors-commit-list" {
  33. let res = db::rustc_commits::get_commits_with_artifacts(&ctx.db).await;
  34. let res = match res {
  35. Ok(r) => r,
  36. Err(e) => {
  37. return Ok(Response::builder()
  38. .status(StatusCode::INTERNAL_SERVER_ERROR)
  39. .body(Body::from(format!("{:?}", e)))
  40. .unwrap());
  41. }
  42. };
  43. return Ok(Response::builder()
  44. .status(StatusCode::OK)
  45. .header("Content-Type", "application/json")
  46. .body(Body::from(serde_json::to_string(&res).unwrap()))
  47. .unwrap());
  48. }
  49. if req.uri.path() == "/notifications" {
  50. if let Some(query) = req.uri.query() {
  51. let user = url::form_urlencoded::parse(query.as_bytes()).find(|(k, _)| k == "user");
  52. if let Some((_, name)) = user {
  53. return Ok(Response::builder()
  54. .status(StatusCode::OK)
  55. .body(Body::from(
  56. notification_listing::render(&ctx.db, &*name).await,
  57. ))
  58. .unwrap());
  59. }
  60. }
  61. return Ok(Response::builder()
  62. .status(StatusCode::OK)
  63. .body(Body::from(String::from(
  64. "Please provide `?user=<username>` query param on URL.",
  65. )))
  66. .unwrap());
  67. }
  68. if req.uri.path() == "/zulip-hook" {
  69. let mut c = body_stream;
  70. let mut payload = Vec::new();
  71. while let Some(chunk) = c.next().await {
  72. let chunk = chunk?;
  73. payload.extend_from_slice(&chunk);
  74. }
  75. let req = match serde_json::from_slice(&payload) {
  76. Ok(r) => r,
  77. Err(e) => {
  78. return Ok(Response::builder()
  79. .status(StatusCode::BAD_REQUEST)
  80. .body(Body::from(format!(
  81. "Did not send valid JSON request: {}",
  82. e
  83. )))
  84. .unwrap());
  85. }
  86. };
  87. return Ok(Response::builder()
  88. .status(StatusCode::OK)
  89. .body(Body::from(triagebot::zulip::respond(&ctx, req).await))
  90. .unwrap());
  91. }
  92. if req.uri.path() != "/github-hook" {
  93. return Ok(Response::builder()
  94. .status(StatusCode::NOT_FOUND)
  95. .body(Body::empty())
  96. .unwrap());
  97. }
  98. if req.method != hyper::Method::POST {
  99. return Ok(Response::builder()
  100. .status(StatusCode::METHOD_NOT_ALLOWED)
  101. .header(header::ALLOW, "POST")
  102. .body(Body::empty())
  103. .unwrap());
  104. }
  105. let event = if let Some(ev) = req.headers.get("X-GitHub-Event") {
  106. let ev = match ev.to_str().ok() {
  107. Some(v) => v,
  108. None => {
  109. return Ok(Response::builder()
  110. .status(StatusCode::BAD_REQUEST)
  111. .body(Body::from("X-GitHub-Event header must be UTF-8 encoded"))
  112. .unwrap());
  113. }
  114. };
  115. match ev.parse::<EventName>() {
  116. Ok(v) => v,
  117. Err(_) => unreachable!(),
  118. }
  119. } else {
  120. return Ok(Response::builder()
  121. .status(StatusCode::BAD_REQUEST)
  122. .body(Body::from("X-GitHub-Event header must be set"))
  123. .unwrap());
  124. };
  125. log::debug!("event={}", event);
  126. let signature = if let Some(sig) = req.headers.get("X-Hub-Signature") {
  127. match sig.to_str().ok() {
  128. Some(v) => v,
  129. None => {
  130. return Ok(Response::builder()
  131. .status(StatusCode::BAD_REQUEST)
  132. .body(Body::from("X-Hub-Signature header must be UTF-8 encoded"))
  133. .unwrap());
  134. }
  135. }
  136. } else {
  137. return Ok(Response::builder()
  138. .status(StatusCode::BAD_REQUEST)
  139. .body(Body::from("X-Hub-Signature header must be set"))
  140. .unwrap());
  141. };
  142. log::debug!("signature={}", signature);
  143. let mut c = body_stream;
  144. let mut payload = Vec::new();
  145. while let Some(chunk) = c.next().await {
  146. let chunk = chunk?;
  147. payload.extend_from_slice(&chunk);
  148. }
  149. if let Err(_) = payload::assert_signed(signature, &payload) {
  150. return Ok(Response::builder()
  151. .status(StatusCode::FORBIDDEN)
  152. .body(Body::from("Wrong signature"))
  153. .unwrap());
  154. }
  155. let payload = match String::from_utf8(payload) {
  156. Ok(p) => p,
  157. Err(_) => {
  158. return Ok(Response::builder()
  159. .status(StatusCode::BAD_REQUEST)
  160. .body(Body::from("Payload must be UTF-8"))
  161. .unwrap());
  162. }
  163. };
  164. match triagebot::webhook(event, payload, &ctx).await {
  165. Ok(true) => Ok(Response::new(Body::from("processed request"))),
  166. Ok(false) => Ok(Response::new(Body::from("ignored request"))),
  167. Err(err) => {
  168. log::error!("request failed: {:?}", err);
  169. Ok(Response::builder()
  170. .status(StatusCode::INTERNAL_SERVER_ERROR)
  171. .body(Body::from(format!("request failed: {:?}", err)))
  172. .unwrap())
  173. }
  174. }
  175. }
  176. async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
  177. log::info!("Listening on http://{}", addr);
  178. let db_client = db::make_client()
  179. .await
  180. .context("open database connection")?;
  181. db::run_migrations(&db_client)
  182. .await
  183. .context("database migrations")?;
  184. let client = Client::new();
  185. let gh = github::GithubClient::new_with_default_token(client.clone());
  186. let oc = octocrab::OctocrabBuilder::new()
  187. .personal_token(github::default_token_from_env())
  188. .build()
  189. .expect("Failed to build octograb.");
  190. let ctx = Arc::new(Context {
  191. username: String::from("rustbot"),
  192. db: db_client,
  193. github: gh,
  194. octocrab: oc,
  195. });
  196. let svc = hyper::service::make_service_fn(move |_conn| {
  197. let ctx = ctx.clone();
  198. async move {
  199. let uuid = Uuid::new_v4();
  200. Ok::<_, hyper::Error>(hyper::service::service_fn(move |req| {
  201. logger::LogFuture::new(
  202. uuid,
  203. serve_req(req, ctx.clone()).map(move |mut resp| {
  204. if let Ok(resp) = &mut resp {
  205. resp.headers_mut()
  206. .insert("X-Request-Id", uuid.to_string().parse().unwrap());
  207. }
  208. log::info!("response = {:?}", resp);
  209. resp
  210. }),
  211. )
  212. }))
  213. }
  214. });
  215. let serve_future = Server::bind(&addr).serve(svc);
  216. serve_future.await?;
  217. Ok(())
  218. }
  219. #[tokio::main]
  220. async fn main() {
  221. dotenv::dotenv().ok();
  222. logger::init();
  223. let port = env::var("PORT")
  224. .ok()
  225. .map(|p| p.parse::<u16>().expect("parsed PORT"))
  226. .unwrap_or(8000);
  227. let addr = ([0, 0, 0, 0], port).into();
  228. if let Err(e) = run_server(addr).await {
  229. eprintln!("Failed to run server: {:?}", e);
  230. }
  231. }