main.rs 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. #![allow(clippy::new_without_default)]
  2. use anyhow::Context as _;
  3. use futures::{future::FutureExt, stream::StreamExt};
  4. use hyper::{header, Body, Request, Response, Server, StatusCode};
  5. use reqwest::Client;
  6. use std::{env, net::SocketAddr, sync::Arc};
  7. use triagebot::{db, github, handlers::Context, logger, notification_listing, payload, EventName};
  8. use uuid::Uuid;
  9. async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Body>, hyper::Error> {
  10. log::info!("request = {:?}", req);
  11. let (req, body_stream) = req.into_parts();
  12. if req.uri.path() == "/" {
  13. return Ok(Response::builder()
  14. .status(StatusCode::OK)
  15. .body(Body::from("Triagebot is awaiting triage."))
  16. .unwrap());
  17. }
  18. if req.uri.path() == "/bors-commit-list" {
  19. let res = db::rustc_commits::get_commits_with_artifacts(&ctx.db).await;
  20. let res = match res {
  21. Ok(r) => r,
  22. Err(e) => {
  23. return Ok(Response::builder()
  24. .status(StatusCode::INTERNAL_SERVER_ERROR)
  25. .body(Body::from(format!("{:?}", e)))
  26. .unwrap());
  27. }
  28. };
  29. return Ok(Response::builder()
  30. .status(StatusCode::OK)
  31. .header("Content-Type", "application/json")
  32. .body(Body::from(serde_json::to_string(&res).unwrap()))
  33. .unwrap());
  34. }
  35. if req.uri.path() == "/notifications" {
  36. if let Some(query) = req.uri.query() {
  37. let user = url::form_urlencoded::parse(query.as_bytes()).find(|(k, _)| k == "user");
  38. if let Some((_, name)) = user {
  39. return Ok(Response::builder()
  40. .status(StatusCode::OK)
  41. .body(Body::from(
  42. notification_listing::render(&ctx.db, &*name).await,
  43. ))
  44. .unwrap());
  45. }
  46. }
  47. return Ok(Response::builder()
  48. .status(StatusCode::OK)
  49. .body(Body::from(String::from(
  50. "Please provide `?user=<username>` query param on URL.",
  51. )))
  52. .unwrap());
  53. }
  54. if req.uri.path() == "/zulip-hook" {
  55. let mut c = body_stream;
  56. let mut payload = Vec::new();
  57. while let Some(chunk) = c.next().await {
  58. let chunk = chunk?;
  59. payload.extend_from_slice(&chunk);
  60. }
  61. let req = match serde_json::from_slice(&payload) {
  62. Ok(r) => r,
  63. Err(e) => {
  64. return Ok(Response::builder()
  65. .status(StatusCode::BAD_REQUEST)
  66. .body(Body::from(format!(
  67. "Did not send valid JSON request: {}",
  68. e
  69. )))
  70. .unwrap());
  71. }
  72. };
  73. return Ok(Response::builder()
  74. .status(StatusCode::OK)
  75. .body(Body::from(triagebot::zulip::respond(&ctx, req).await))
  76. .unwrap());
  77. }
  78. if req.uri.path() != "/github-hook" {
  79. return Ok(Response::builder()
  80. .status(StatusCode::NOT_FOUND)
  81. .body(Body::empty())
  82. .unwrap());
  83. }
  84. if req.method != hyper::Method::POST {
  85. return Ok(Response::builder()
  86. .status(StatusCode::METHOD_NOT_ALLOWED)
  87. .header(header::ALLOW, "POST")
  88. .body(Body::empty())
  89. .unwrap());
  90. }
  91. let event = if let Some(ev) = req.headers.get("X-GitHub-Event") {
  92. let ev = match ev.to_str().ok() {
  93. Some(v) => v,
  94. None => {
  95. return Ok(Response::builder()
  96. .status(StatusCode::BAD_REQUEST)
  97. .body(Body::from("X-GitHub-Event header must be UTF-8 encoded"))
  98. .unwrap());
  99. }
  100. };
  101. match ev.parse::<EventName>() {
  102. Ok(v) => v,
  103. Err(_) => unreachable!(),
  104. }
  105. } else {
  106. return Ok(Response::builder()
  107. .status(StatusCode::BAD_REQUEST)
  108. .body(Body::from("X-GitHub-Event header must be set"))
  109. .unwrap());
  110. };
  111. log::debug!("event={}", event);
  112. let signature = if let Some(sig) = req.headers.get("X-Hub-Signature") {
  113. match sig.to_str().ok() {
  114. Some(v) => v,
  115. None => {
  116. return Ok(Response::builder()
  117. .status(StatusCode::BAD_REQUEST)
  118. .body(Body::from("X-Hub-Signature header must be UTF-8 encoded"))
  119. .unwrap());
  120. }
  121. }
  122. } else {
  123. return Ok(Response::builder()
  124. .status(StatusCode::BAD_REQUEST)
  125. .body(Body::from("X-Hub-Signature header must be set"))
  126. .unwrap());
  127. };
  128. log::debug!("signature={}", signature);
  129. let mut c = body_stream;
  130. let mut payload = Vec::new();
  131. while let Some(chunk) = c.next().await {
  132. let chunk = chunk?;
  133. payload.extend_from_slice(&chunk);
  134. }
  135. if let Err(_) = payload::assert_signed(signature, &payload) {
  136. return Ok(Response::builder()
  137. .status(StatusCode::FORBIDDEN)
  138. .body(Body::from("Wrong signature"))
  139. .unwrap());
  140. }
  141. let payload = match String::from_utf8(payload) {
  142. Ok(p) => p,
  143. Err(_) => {
  144. return Ok(Response::builder()
  145. .status(StatusCode::BAD_REQUEST)
  146. .body(Body::from("Payload must be UTF-8"))
  147. .unwrap());
  148. }
  149. };
  150. match triagebot::webhook(event, payload, &ctx).await {
  151. Ok(true) => Ok(Response::new(Body::from("processed request"))),
  152. Ok(false) => Ok(Response::new(Body::from("ignored request"))),
  153. Err(err) => {
  154. log::error!("request failed: {:?}", err);
  155. Ok(Response::builder()
  156. .status(StatusCode::INTERNAL_SERVER_ERROR)
  157. .body(Body::from(format!("request failed: {:?}", err)))
  158. .unwrap())
  159. }
  160. }
  161. }
  162. async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
  163. log::info!("Listening on http://{}", addr);
  164. let db_client = db::make_client()
  165. .await
  166. .context("open database connection")?;
  167. db::run_migrations(&db_client)
  168. .await
  169. .context("database migrations")?;
  170. let client = Client::new();
  171. let gh = github::GithubClient::new_with_default_token(client.clone());
  172. let oc = octocrab::OctocrabBuilder::new()
  173. .personal_token(github::default_token_from_env())
  174. .build()
  175. .expect("Failed to build octograb.");
  176. let ctx = Arc::new(Context {
  177. username: String::from("rustbot"),
  178. db: db_client,
  179. github: gh,
  180. octocrab: oc,
  181. });
  182. let svc = hyper::service::make_service_fn(move |_conn| {
  183. let ctx = ctx.clone();
  184. async move {
  185. let uuid = Uuid::new_v4();
  186. Ok::<_, hyper::Error>(hyper::service::service_fn(move |req| {
  187. logger::LogFuture::new(
  188. uuid,
  189. serve_req(req, ctx.clone()).map(move |mut resp| {
  190. if let Ok(resp) = &mut resp {
  191. resp.headers_mut()
  192. .insert("X-Request-Id", uuid.to_string().parse().unwrap());
  193. }
  194. log::info!("response = {:?}", resp);
  195. resp
  196. }),
  197. )
  198. }))
  199. }
  200. });
  201. let serve_future = Server::bind(&addr).serve(svc);
  202. serve_future.await?;
  203. Ok(())
  204. }
  205. #[tokio::main]
  206. async fn main() {
  207. dotenv::dotenv().ok();
  208. logger::init();
  209. let port = env::var("PORT")
  210. .ok()
  211. .map(|p| p.parse::<u16>().expect("parsed PORT"))
  212. .unwrap_or(8000);
  213. let addr = ([0, 0, 0, 0], port).into();
  214. if let Err(e) = run_server(addr).await {
  215. eprintln!("Failed to run server: {:?}", e);
  216. }
  217. }