main.rs 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. #![allow(clippy::new_without_default)]
  2. use anyhow::Context as _;
  3. use futures::{future::FutureExt, stream::StreamExt};
  4. use hyper::{header, Body, Request, Response, Server, StatusCode};
  5. use reqwest::Client;
  6. use std::{env, net::SocketAddr, sync::Arc};
  7. use triagebot::{db, github, handlers::Context, notification_listing, payload, EventName};
  8. use uuid::Uuid;
  9. mod logger;
  10. async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Body>, hyper::Error> {
  11. log::info!("request = {:?}", req);
  12. let (req, body_stream) = req.into_parts();
  13. if req.uri.path() == "/" {
  14. return Ok(Response::builder()
  15. .status(StatusCode::OK)
  16. .body(Body::from("Triagebot is awaiting triage."))
  17. .unwrap());
  18. }
  19. if req.uri.path() == "/bors-commit-list" {
  20. let res = db::rustc_commits::get_commits_with_artifacts(&ctx.db).await;
  21. let res = match res {
  22. Ok(r) => r,
  23. Err(e) => {
  24. return Ok(Response::builder()
  25. .status(StatusCode::INTERNAL_SERVER_ERROR)
  26. .body(Body::from(format!("{:?}", e)))
  27. .unwrap());
  28. }
  29. };
  30. return Ok(Response::builder()
  31. .status(StatusCode::OK)
  32. .header("Content-Type", "application/json")
  33. .body(Body::from(serde_json::to_string(&res).unwrap()))
  34. .unwrap());
  35. }
  36. if req.uri.path() == "/notifications" {
  37. if let Some(query) = req.uri.query() {
  38. let user = url::form_urlencoded::parse(query.as_bytes()).find(|(k, _)| k == "user");
  39. if let Some((_, name)) = user {
  40. return Ok(Response::builder()
  41. .status(StatusCode::OK)
  42. .body(Body::from(
  43. notification_listing::render(&ctx.db, &*name).await,
  44. ))
  45. .unwrap());
  46. }
  47. }
  48. return Ok(Response::builder()
  49. .status(StatusCode::OK)
  50. .body(Body::from(String::from(
  51. "Please provide `?user=<username>` query param on URL.",
  52. )))
  53. .unwrap());
  54. }
  55. if req.uri.path() == "/zulip-hook" {
  56. let mut c = body_stream;
  57. let mut payload = Vec::new();
  58. while let Some(chunk) = c.next().await {
  59. let chunk = chunk?;
  60. payload.extend_from_slice(&chunk);
  61. }
  62. let req = match serde_json::from_slice(&payload) {
  63. Ok(r) => r,
  64. Err(e) => {
  65. return Ok(Response::builder()
  66. .status(StatusCode::BAD_REQUEST)
  67. .body(Body::from(format!(
  68. "Did not send valid JSON request: {}",
  69. e
  70. )))
  71. .unwrap());
  72. }
  73. };
  74. return Ok(Response::builder()
  75. .status(StatusCode::OK)
  76. .body(Body::from(triagebot::zulip::respond(&ctx, req).await))
  77. .unwrap());
  78. }
  79. if req.uri.path() != "/github-hook" {
  80. return Ok(Response::builder()
  81. .status(StatusCode::NOT_FOUND)
  82. .body(Body::empty())
  83. .unwrap());
  84. }
  85. if req.method != hyper::Method::POST {
  86. return Ok(Response::builder()
  87. .status(StatusCode::METHOD_NOT_ALLOWED)
  88. .header(header::ALLOW, "POST")
  89. .body(Body::empty())
  90. .unwrap());
  91. }
  92. let event = if let Some(ev) = req.headers.get("X-GitHub-Event") {
  93. let ev = match ev.to_str().ok() {
  94. Some(v) => v,
  95. None => {
  96. return Ok(Response::builder()
  97. .status(StatusCode::BAD_REQUEST)
  98. .body(Body::from("X-GitHub-Event header must be UTF-8 encoded"))
  99. .unwrap());
  100. }
  101. };
  102. match ev.parse::<EventName>() {
  103. Ok(v) => v,
  104. Err(_) => unreachable!(),
  105. }
  106. } else {
  107. return Ok(Response::builder()
  108. .status(StatusCode::BAD_REQUEST)
  109. .body(Body::from("X-GitHub-Event header must be set"))
  110. .unwrap());
  111. };
  112. log::debug!("event={}", event);
  113. let signature = if let Some(sig) = req.headers.get("X-Hub-Signature") {
  114. match sig.to_str().ok() {
  115. Some(v) => v,
  116. None => {
  117. return Ok(Response::builder()
  118. .status(StatusCode::BAD_REQUEST)
  119. .body(Body::from("X-Hub-Signature header must be UTF-8 encoded"))
  120. .unwrap());
  121. }
  122. }
  123. } else {
  124. return Ok(Response::builder()
  125. .status(StatusCode::BAD_REQUEST)
  126. .body(Body::from("X-Hub-Signature header must be set"))
  127. .unwrap());
  128. };
  129. log::debug!("signature={}", signature);
  130. let mut c = body_stream;
  131. let mut payload = Vec::new();
  132. while let Some(chunk) = c.next().await {
  133. let chunk = chunk?;
  134. payload.extend_from_slice(&chunk);
  135. }
  136. if let Err(_) = payload::assert_signed(signature, &payload) {
  137. return Ok(Response::builder()
  138. .status(StatusCode::FORBIDDEN)
  139. .body(Body::from("Wrong signature"))
  140. .unwrap());
  141. }
  142. let payload = match String::from_utf8(payload) {
  143. Ok(p) => p,
  144. Err(_) => {
  145. return Ok(Response::builder()
  146. .status(StatusCode::BAD_REQUEST)
  147. .body(Body::from("Payload must be UTF-8"))
  148. .unwrap());
  149. }
  150. };
  151. match triagebot::webhook(event, payload, &ctx).await {
  152. Ok(true) => Ok(Response::new(Body::from("processed request"))),
  153. Ok(false) => Ok(Response::new(Body::from("ignored request"))),
  154. Err(err) => {
  155. log::error!("request failed: {:?}", err);
  156. Ok(Response::builder()
  157. .status(StatusCode::INTERNAL_SERVER_ERROR)
  158. .body(Body::from(format!("request failed: {:?}", err)))
  159. .unwrap())
  160. }
  161. }
  162. }
  163. async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
  164. log::info!("Listening on http://{}", addr);
  165. let db_client = db::make_client()
  166. .await
  167. .context("open database connection")?;
  168. db::run_migrations(&db_client)
  169. .await
  170. .context("database migrations")?;
  171. let client = Client::new();
  172. let gh = github::GithubClient::new(
  173. client.clone(),
  174. env::var("GITHUB_API_TOKEN").expect("Missing GITHUB_API_TOKEN"),
  175. );
  176. let oc = octocrab::OctocrabBuilder::new()
  177. .personal_token(env::var("GITHUB_API_TOKEN").expect("Missing GITHUB_API_TOKEN"))
  178. .build()
  179. .expect("Failed to build octograb.");
  180. let ctx = Arc::new(Context {
  181. username: github::User::current(&gh).await.unwrap().login,
  182. db: db_client,
  183. github: gh,
  184. octocrab: oc,
  185. });
  186. let svc = hyper::service::make_service_fn(move |_conn| {
  187. let ctx = ctx.clone();
  188. async move {
  189. let uuid = Uuid::new_v4();
  190. Ok::<_, hyper::Error>(hyper::service::service_fn(move |req| {
  191. logger::LogFuture::new(
  192. uuid,
  193. serve_req(req, ctx.clone()).map(move |mut resp| {
  194. if let Ok(resp) = &mut resp {
  195. resp.headers_mut()
  196. .insert("X-Request-Id", uuid.to_string().parse().unwrap());
  197. }
  198. log::info!("response = {:?}", resp);
  199. resp
  200. }),
  201. )
  202. }))
  203. }
  204. });
  205. let serve_future = Server::bind(&addr).serve(svc);
  206. serve_future.await?;
  207. Ok(())
  208. }
  209. #[tokio::main]
  210. async fn main() {
  211. dotenv::dotenv().ok();
  212. logger::init();
  213. let port = env::var("PORT")
  214. .ok()
  215. .map(|p| p.parse::<u16>().expect("parsed PORT"))
  216. .unwrap_or(8000);
  217. let addr = ([0, 0, 0, 0], port).into();
  218. if let Err(e) = run_server(addr).await {
  219. eprintln!("Failed to run server: {:?}", e);
  220. }
  221. }