main.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. #![allow(clippy::new_without_default)]
  2. use anyhow::Context as _;
  3. use futures::future::FutureExt;
  4. use futures::StreamExt;
  5. use hyper::{header, Body, Request, Response, Server, StatusCode};
  6. use reqwest::Client;
  7. use route_recognizer::Router;
  8. use std::{env, net::SocketAddr, sync::Arc, time::Duration};
  9. use tokio::{task, time::sleep};
  10. use tower::{Service, ServiceExt};
  11. use tracing as log;
  12. use tracing::Instrument;
  13. use triagebot::{db, github, handlers::Context, notification_listing, payload, EventName};
  /// Number of seconds the background scheduled-jobs task sleeps between
  /// consecutive `db::run_scheduled_jobs` runs (see `run_server`).
  14. const JOB_PROCESSING_CADENCE_IN_SECS: u64 = 60;
  15. async fn handle_agenda_request(req: String) -> anyhow::Result<String> {
  16. if req == "/agenda/lang/triage" {
  17. return triagebot::agenda::lang().call().await;
  18. }
  19. if req == "/agenda/lang/planning" {
  20. return triagebot::agenda::lang_planning().call().await;
  21. }
  22. anyhow::bail!("Unknown agenda; see /agenda for index.")
  23. }
  24. async fn serve_req(
  25. req: Request<Body>,
  26. ctx: Arc<Context>,
  27. mut agenda: impl Service<String, Response = String, Error = tower::BoxError>,
  28. ) -> Result<Response<Body>, hyper::Error> {
  29. log::info!("request = {:?}", req);
  30. let mut router = Router::new();
  31. router.add("/triage", "index".to_string());
  32. router.add("/triage/:owner/:repo", "pulls".to_string());
  33. let (req, body_stream) = req.into_parts();
  34. if let Ok(matcher) = router.recognize(req.uri.path()) {
  35. if matcher.handler().as_str() == "pulls" {
  36. let params = matcher.params();
  37. let owner = params.find("owner");
  38. let repo = params.find("repo");
  39. return triagebot::triage::pulls(ctx, owner.unwrap(), repo.unwrap()).await;
  40. } else {
  41. return triagebot::triage::index();
  42. }
  43. }
  44. if req.uri.path() == "/agenda" {
  45. return Ok(Response::builder()
  46. .status(StatusCode::OK)
  47. .body(Body::from(triagebot::agenda::INDEX))
  48. .unwrap());
  49. }
  50. if req.uri.path() == "/agenda/lang/triage" || req.uri.path() == "/agenda/lang/planning" {
  51. match agenda
  52. .ready()
  53. .await
  54. .expect("agenda keeps running")
  55. .call(req.uri.path().to_owned())
  56. .await
  57. {
  58. Ok(agenda) => {
  59. return Ok(Response::builder()
  60. .status(StatusCode::OK)
  61. .body(Body::from(agenda))
  62. .unwrap())
  63. }
  64. Err(err) => {
  65. return Ok(Response::builder()
  66. .status(StatusCode::INTERNAL_SERVER_ERROR)
  67. .body(Body::from(err.to_string()))
  68. .unwrap())
  69. }
  70. }
  71. }
  72. if req.uri.path() == "/" {
  73. return Ok(Response::builder()
  74. .status(StatusCode::OK)
  75. .body(Body::from("Triagebot is awaiting triage."))
  76. .unwrap());
  77. }
  78. if req.uri.path() == "/bors-commit-list" {
  79. let res = db::rustc_commits::get_commits_with_artifacts(&*ctx.db.get().await).await;
  80. let res = match res {
  81. Ok(r) => r,
  82. Err(e) => {
  83. return Ok(Response::builder()
  84. .status(StatusCode::INTERNAL_SERVER_ERROR)
  85. .body(Body::from(format!("{:?}", e)))
  86. .unwrap());
  87. }
  88. };
  89. return Ok(Response::builder()
  90. .status(StatusCode::OK)
  91. .header("Content-Type", "application/json")
  92. .body(Body::from(serde_json::to_string(&res).unwrap()))
  93. .unwrap());
  94. }
  95. if req.uri.path() == "/notifications" {
  96. if let Some(query) = req.uri.query() {
  97. let user = url::form_urlencoded::parse(query.as_bytes()).find(|(k, _)| k == "user");
  98. if let Some((_, name)) = user {
  99. return Ok(Response::builder()
  100. .status(StatusCode::OK)
  101. .body(Body::from(
  102. notification_listing::render(&ctx.db.get().await, &*name).await,
  103. ))
  104. .unwrap());
  105. }
  106. }
  107. return Ok(Response::builder()
  108. .status(StatusCode::OK)
  109. .body(Body::from(String::from(
  110. "Please provide `?user=<username>` query param on URL.",
  111. )))
  112. .unwrap());
  113. }
  114. if req.uri.path() == "/zulip-hook" {
  115. let mut c = body_stream;
  116. let mut payload = Vec::new();
  117. while let Some(chunk) = c.next().await {
  118. let chunk = chunk?;
  119. payload.extend_from_slice(&chunk);
  120. }
  121. let req = match serde_json::from_slice(&payload) {
  122. Ok(r) => r,
  123. Err(e) => {
  124. return Ok(Response::builder()
  125. .status(StatusCode::BAD_REQUEST)
  126. .body(Body::from(format!(
  127. "Did not send valid JSON request: {}",
  128. e
  129. )))
  130. .unwrap());
  131. }
  132. };
  133. return Ok(Response::builder()
  134. .status(StatusCode::OK)
  135. .body(Body::from(triagebot::zulip::respond(&ctx, req).await))
  136. .unwrap());
  137. }
  138. if req.uri.path() != "/github-hook" {
  139. return Ok(Response::builder()
  140. .status(StatusCode::NOT_FOUND)
  141. .body(Body::empty())
  142. .unwrap());
  143. }
  144. if req.method != hyper::Method::POST {
  145. return Ok(Response::builder()
  146. .status(StatusCode::METHOD_NOT_ALLOWED)
  147. .header(header::ALLOW, "POST")
  148. .body(Body::empty())
  149. .unwrap());
  150. }
  151. let event = if let Some(ev) = req.headers.get("X-GitHub-Event") {
  152. let ev = match ev.to_str().ok() {
  153. Some(v) => v,
  154. None => {
  155. return Ok(Response::builder()
  156. .status(StatusCode::BAD_REQUEST)
  157. .body(Body::from("X-GitHub-Event header must be UTF-8 encoded"))
  158. .unwrap());
  159. }
  160. };
  161. match ev.parse::<EventName>() {
  162. Ok(v) => v,
  163. Err(_) => unreachable!(),
  164. }
  165. } else {
  166. return Ok(Response::builder()
  167. .status(StatusCode::BAD_REQUEST)
  168. .body(Body::from("X-GitHub-Event header must be set"))
  169. .unwrap());
  170. };
  171. log::debug!("event={}", event);
  172. let signature = if let Some(sig) = req.headers.get("X-Hub-Signature") {
  173. match sig.to_str().ok() {
  174. Some(v) => v,
  175. None => {
  176. return Ok(Response::builder()
  177. .status(StatusCode::BAD_REQUEST)
  178. .body(Body::from("X-Hub-Signature header must be UTF-8 encoded"))
  179. .unwrap());
  180. }
  181. }
  182. } else {
  183. return Ok(Response::builder()
  184. .status(StatusCode::BAD_REQUEST)
  185. .body(Body::from("X-Hub-Signature header must be set"))
  186. .unwrap());
  187. };
  188. log::debug!("signature={}", signature);
  189. let mut c = body_stream;
  190. let mut payload = Vec::new();
  191. while let Some(chunk) = c.next().await {
  192. let chunk = chunk?;
  193. payload.extend_from_slice(&chunk);
  194. }
  195. if let Err(_) = payload::assert_signed(signature, &payload) {
  196. return Ok(Response::builder()
  197. .status(StatusCode::FORBIDDEN)
  198. .body(Body::from("Wrong signature"))
  199. .unwrap());
  200. }
  201. let payload = match String::from_utf8(payload) {
  202. Ok(p) => p,
  203. Err(_) => {
  204. return Ok(Response::builder()
  205. .status(StatusCode::BAD_REQUEST)
  206. .body(Body::from("Payload must be UTF-8"))
  207. .unwrap());
  208. }
  209. };
  210. match triagebot::webhook(event, payload, &ctx).await {
  211. Ok(true) => Ok(Response::new(Body::from("processed request"))),
  212. Ok(false) => Ok(Response::new(Body::from("ignored request"))),
  213. Err(err) => {
  214. log::error!("request failed: {:?}", err);
  215. Ok(Response::builder()
  216. .status(StatusCode::INTERNAL_SERVER_ERROR)
  217. .body(Body::from(format!("request failed: {:?}", err)))
  218. .unwrap())
  219. }
  220. }
  221. }
  222. async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
  223. let pool = db::ClientPool::new();
  224. db::run_migrations(&*pool.get().await)
  225. .await
  226. .context("database migrations")?;
  227. // spawning a background task that will run the scheduled jobs
  228. // every JOB_PROCESSING_CADENCE_IN_SECS
  229. task::spawn(async move {
  230. loop {
  231. let res = task::spawn(async move {
  232. let pool = db::ClientPool::new();
  233. loop {
  234. db::run_scheduled_jobs(&*pool.get().await)
  235. .await
  236. .context("run database scheduled jobs")
  237. .unwrap();
  238. sleep(Duration::from_secs(JOB_PROCESSING_CADENCE_IN_SECS)).await;
  239. }
  240. });
  241. match res.await {
  242. Err(err) if err.is_panic() => {
  243. /* handle panic in above task, re-launching */
  244. tracing::trace!("run_scheduled_jobs task died (error={})", err);
  245. }
  246. _ => unreachable!()
  247. }
  248. }
  249. });
  250. let client = Client::new();
  251. let gh = github::GithubClient::new_with_default_token(client.clone());
  252. let oc = octocrab::OctocrabBuilder::new()
  253. .personal_token(github::default_token_from_env())
  254. .build()
  255. .expect("Failed to build octograb.");
  256. let ctx = Arc::new(Context {
  257. username: String::from("rustbot"),
  258. db: pool,
  259. github: gh,
  260. octocrab: oc,
  261. });
  262. let agenda = tower::ServiceBuilder::new()
  263. .buffer(10)
  264. .layer_fn(|input| {
  265. tower::util::MapErr::new(
  266. tower::load_shed::LoadShed::new(tower::limit::RateLimit::new(
  267. input,
  268. tower::limit::rate::Rate::new(2, std::time::Duration::from_secs(60)),
  269. )),
  270. |_| anyhow::anyhow!("Rate limit of 2 request / 60 seconds exceeded"),
  271. )
  272. })
  273. .service_fn(handle_agenda_request);
  274. let svc = hyper::service::make_service_fn(move |_conn| {
  275. let ctx = ctx.clone();
  276. let agenda = agenda.clone();
  277. async move {
  278. Ok::<_, hyper::Error>(hyper::service::service_fn(move |req| {
  279. let uuid = uuid::Uuid::new_v4();
  280. let span = tracing::span!(tracing::Level::INFO, "request", ?uuid);
  281. serve_req(req, ctx.clone(), agenda.clone())
  282. .map(move |mut resp| {
  283. if let Ok(resp) = &mut resp {
  284. resp.headers_mut()
  285. .insert("X-Request-Id", uuid.to_string().parse().unwrap());
  286. }
  287. log::info!("response = {:?}", resp);
  288. resp
  289. })
  290. .instrument(span)
  291. }))
  292. }
  293. });
  294. log::info!("Listening on http://{}", addr);
  295. let serve_future = Server::bind(&addr).serve(svc);
  296. serve_future.await?;
  297. Ok(())
  298. }
  299. #[tokio::main(flavor = "current_thread")]
  300. async fn main() {
  301. dotenv::dotenv().ok();
  302. tracing_subscriber::fmt::Subscriber::builder()
  303. .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
  304. .with_ansi(std::env::var_os("DISABLE_COLOR").is_none())
  305. .try_init()
  306. .unwrap();
  307. let port = env::var("PORT")
  308. .ok()
  309. .map(|p| p.parse::<u16>().expect("parsed PORT"))
  310. .unwrap_or(8000);
  311. let addr = ([0, 0, 0, 0], port).into();
  312. if let Err(e) = run_server(addr).await {
  313. eprintln!("Failed to run server: {:?}", e);
  314. }
  315. }