main.rs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. #![allow(clippy::new_without_default)]
  2. use anyhow::Context as _;
  3. use futures::future::FutureExt;
  4. use futures::StreamExt;
  5. use hyper::{header, Body, Request, Response, Server, StatusCode};
  6. use route_recognizer::Router;
  7. use std::{env, net::SocketAddr, sync::Arc};
  8. use tokio::{task, time};
  9. // use tower::{Service, ServiceExt};
  10. use tracing as log;
  11. use tracing::Instrument;
  12. use triagebot::handlers::pull_requests_assignment_update::PullRequestAssignmentUpdate;
  13. use triagebot::jobs::{
  14. default_jobs, Job, JOB_PROCESSING_CADENCE_IN_SECS, JOB_SCHEDULING_CADENCE_IN_SECS,
  15. };
  16. use triagebot::{db, github, handlers::Context, notification_listing, payload, EventName};
  17. #[allow(dead_code)]
  18. async fn handle_agenda_request(req: String) -> anyhow::Result<String> {
  19. if req == "/agenda/lang/triage" {
  20. return triagebot::agenda::lang().call().await;
  21. }
  22. if req == "/agenda/lang/planning" {
  23. return triagebot::agenda::lang_planning().call().await;
  24. }
  25. if req == "/agenda/types/planning" {
  26. return triagebot::agenda::types_planning().call().await;
  27. }
  28. anyhow::bail!("Unknown agenda; see /agenda for index.")
  29. }
  30. async fn serve_req(
  31. req: Request<Body>,
  32. ctx: Arc<Context>,
  33. // mut agenda: impl Service<String, Response = String, Error = tower::BoxError>,
  34. ) -> Result<Response<Body>, hyper::Error> {
  35. log::info!("request = {:?}", req);
  36. let mut router = Router::new();
  37. router.add("/triage", "index".to_string());
  38. router.add("/triage/:owner/:repo", "pulls".to_string());
  39. let (req, body_stream) = req.into_parts();
  40. if let Ok(matcher) = router.recognize(req.uri.path()) {
  41. if matcher.handler().as_str() == "pulls" {
  42. let params = matcher.params();
  43. let owner = params.find("owner");
  44. let repo = params.find("repo");
  45. return triagebot::triage::pulls(ctx, owner.unwrap(), repo.unwrap()).await;
  46. } else {
  47. return triagebot::triage::index();
  48. }
  49. }
  50. // if req.uri.path() == "/agenda" {
  51. // return Ok(Response::builder()
  52. // .status(StatusCode::OK)
  53. // .body(Body::from(triagebot::agenda::INDEX))
  54. // .unwrap());
  55. // }
  56. // if req.uri.path() == "/agenda/lang/triage"
  57. // || req.uri.path() == "/agenda/lang/planning"
  58. // || req.uri.path() == "/agenda/types/planning"
  59. // {
  60. // match agenda
  61. // .ready()
  62. // .await
  63. // .expect("agenda keeps running")
  64. // .call(req.uri.path().to_owned())
  65. // .await
  66. // {
  67. // Ok(agenda) => {
  68. // return Ok(Response::builder()
  69. // .status(StatusCode::OK)
  70. // .body(Body::from(agenda))
  71. // .unwrap())
  72. // }
  73. // Err(err) => {
  74. // return Ok(Response::builder()
  75. // .status(StatusCode::INTERNAL_SERVER_ERROR)
  76. // .body(Body::from(err.to_string()))
  77. // .unwrap())
  78. // }
  79. // }
  80. // }
  81. if req.uri.path() == "/" {
  82. return Ok(Response::builder()
  83. .status(StatusCode::OK)
  84. .body(Body::from("Triagebot is awaiting triage."))
  85. .unwrap());
  86. }
  87. if req.uri.path() == "/bors-commit-list" {
  88. let res = db::rustc_commits::get_commits_with_artifacts(&*ctx.db.get().await).await;
  89. let res = match res {
  90. Ok(r) => r,
  91. Err(e) => {
  92. return Ok(Response::builder()
  93. .status(StatusCode::INTERNAL_SERVER_ERROR)
  94. .body(Body::from(format!("{:?}", e)))
  95. .unwrap());
  96. }
  97. };
  98. return Ok(Response::builder()
  99. .status(StatusCode::OK)
  100. .header("Content-Type", "application/json")
  101. .body(Body::from(serde_json::to_string(&res).unwrap()))
  102. .unwrap());
  103. }
  104. if req.uri.path() == "/notifications" {
  105. if let Some(query) = req.uri.query() {
  106. let user = url::form_urlencoded::parse(query.as_bytes()).find(|(k, _)| k == "user");
  107. if let Some((_, name)) = user {
  108. return Ok(Response::builder()
  109. .status(StatusCode::OK)
  110. .body(Body::from(
  111. notification_listing::render(&ctx.db.get().await, &*name).await,
  112. ))
  113. .unwrap());
  114. }
  115. }
  116. return Ok(Response::builder()
  117. .status(StatusCode::OK)
  118. .body(Body::from(String::from(
  119. "Please provide `?user=<username>` query param on URL.",
  120. )))
  121. .unwrap());
  122. }
  123. if req.uri.path() == "/zulip-hook" {
  124. let mut c = body_stream;
  125. let mut payload = Vec::new();
  126. while let Some(chunk) = c.next().await {
  127. let chunk = chunk?;
  128. payload.extend_from_slice(&chunk);
  129. }
  130. let req = match serde_json::from_slice(&payload) {
  131. Ok(r) => r,
  132. Err(e) => {
  133. return Ok(Response::builder()
  134. .status(StatusCode::BAD_REQUEST)
  135. .body(Body::from(format!(
  136. "Did not send valid JSON request: {}",
  137. e
  138. )))
  139. .unwrap());
  140. }
  141. };
  142. return Ok(Response::builder()
  143. .status(StatusCode::OK)
  144. .body(Body::from(triagebot::zulip::respond(&ctx, req).await))
  145. .unwrap());
  146. }
  147. if req.uri.path() != "/github-hook" {
  148. return Ok(Response::builder()
  149. .status(StatusCode::NOT_FOUND)
  150. .body(Body::empty())
  151. .unwrap());
  152. }
  153. if req.method != hyper::Method::POST {
  154. return Ok(Response::builder()
  155. .status(StatusCode::METHOD_NOT_ALLOWED)
  156. .header(header::ALLOW, "POST")
  157. .body(Body::empty())
  158. .unwrap());
  159. }
  160. let event = if let Some(ev) = req.headers.get("X-GitHub-Event") {
  161. let ev = match ev.to_str().ok() {
  162. Some(v) => v,
  163. None => {
  164. return Ok(Response::builder()
  165. .status(StatusCode::BAD_REQUEST)
  166. .body(Body::from("X-GitHub-Event header must be UTF-8 encoded"))
  167. .unwrap());
  168. }
  169. };
  170. match ev.parse::<EventName>() {
  171. Ok(v) => v,
  172. Err(_) => unreachable!(),
  173. }
  174. } else {
  175. return Ok(Response::builder()
  176. .status(StatusCode::BAD_REQUEST)
  177. .body(Body::from("X-GitHub-Event header must be set"))
  178. .unwrap());
  179. };
  180. log::debug!("event={}", event);
  181. let signature = if let Some(sig) = req.headers.get("X-Hub-Signature") {
  182. match sig.to_str().ok() {
  183. Some(v) => v,
  184. None => {
  185. return Ok(Response::builder()
  186. .status(StatusCode::BAD_REQUEST)
  187. .body(Body::from("X-Hub-Signature header must be UTF-8 encoded"))
  188. .unwrap());
  189. }
  190. }
  191. } else {
  192. return Ok(Response::builder()
  193. .status(StatusCode::BAD_REQUEST)
  194. .body(Body::from("X-Hub-Signature header must be set"))
  195. .unwrap());
  196. };
  197. log::debug!("signature={}", signature);
  198. let mut c = body_stream;
  199. let mut payload = Vec::new();
  200. while let Some(chunk) = c.next().await {
  201. let chunk = chunk?;
  202. payload.extend_from_slice(&chunk);
  203. }
  204. if let Err(_) = payload::assert_signed(signature, &payload) {
  205. return Ok(Response::builder()
  206. .status(StatusCode::FORBIDDEN)
  207. .body(Body::from("Wrong signature"))
  208. .unwrap());
  209. }
  210. let payload = match String::from_utf8(payload) {
  211. Ok(p) => p,
  212. Err(_) => {
  213. return Ok(Response::builder()
  214. .status(StatusCode::BAD_REQUEST)
  215. .body(Body::from("Payload must be UTF-8"))
  216. .unwrap());
  217. }
  218. };
  219. match triagebot::webhook(event, payload, &ctx).await {
  220. Ok(true) => Ok(Response::new(Body::from("processed request"))),
  221. Ok(false) => Ok(Response::new(Body::from("ignored request"))),
  222. Err(err) => {
  223. log::error!("request failed: {:?}", err);
  224. Ok(Response::builder()
  225. .status(StatusCode::INTERNAL_SERVER_ERROR)
  226. .body(Body::from(format!("request failed: {:?}", err)))
  227. .unwrap())
  228. }
  229. }
  230. }
  231. async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
  232. let pool = db::ClientPool::new();
  233. db::run_migrations(&*pool.get().await)
  234. .await
  235. .context("database migrations")?;
  236. let gh = github::GithubClient::new_from_env();
  237. let oc = octocrab::OctocrabBuilder::new()
  238. .personal_token(github::default_token_from_env())
  239. .build()
  240. .expect("Failed to build octograb.");
  241. let ctx = Arc::new(Context {
  242. username: std::env::var("TRIAGEBOT_USERNAME").or_else(|err| match err {
  243. std::env::VarError::NotPresent => Ok("dragonosbot".to_owned()),
  244. err => Err(err),
  245. })?,
  246. db: pool,
  247. github: gh,
  248. octocrab: oc,
  249. });
  250. // Run all jobs that don't have a schedule (one-off jobs)
  251. // TODO: Ideally JobSchedule.schedule should become an `Option<Schedule>`
  252. // and here we run all those with schedule=None
  253. if !is_scheduled_jobs_disabled() {
  254. spawn_job_oneoffs(ctx.clone()).await;
  255. }
  256. // Run all jobs that have a schedule (recurring jobs)
  257. if !is_scheduled_jobs_disabled() {
  258. spawn_job_scheduler();
  259. spawn_job_runner(ctx.clone());
  260. }
  261. // let agenda = tower::ServiceBuilder::new()
  262. // .buffer(10)
  263. // .layer_fn(|input| {
  264. // tower::util::MapErr::new(
  265. // tower::load_shed::LoadShed::new(tower::limit::RateLimit::new(
  266. // input,
  267. // tower::limit::rate::Rate::new(2, std::time::Duration::from_secs(60)),
  268. // )),
  269. // |e| {
  270. // tracing::error!("agenda request failed: {:?}", e);
  271. // anyhow::anyhow!("Rate limit of 2 request / 60 seconds exceeded")
  272. // },
  273. // )
  274. // })
  275. // .service_fn(handle_agenda_request);
  276. let svc = hyper::service::make_service_fn(move |_conn| {
  277. let ctx = ctx.clone();
  278. // let agenda = agenda.clone();
  279. async move {
  280. Ok::<_, hyper::Error>(hyper::service::service_fn(move |req| {
  281. let uuid = uuid::Uuid::new_v4();
  282. let span = tracing::span!(tracing::Level::INFO, "request", ?uuid);
  283. serve_req(req, ctx.clone())
  284. .map(move |mut resp| {
  285. if let Ok(resp) = &mut resp {
  286. resp.headers_mut()
  287. .insert("X-Request-Id", uuid.to_string().parse().unwrap());
  288. }
  289. log::info!("response = {:?}", resp);
  290. resp
  291. })
  292. .instrument(span)
  293. }))
  294. }
  295. });
  296. log::info!("Listening on http://{}", addr);
  297. let serve_future = Server::bind(&addr).serve(svc);
  298. serve_future.await?;
  299. Ok(())
  300. }
  301. /// Spawns a background tokio task which runs all jobs having no schedule
  302. /// i.e. manually executed at the end of the triagebot startup
  303. // - jobs are not guaranteed to start in sequence (care is to be taken to ensure thet are completely independent one from the other)
  304. // - the delay between jobs start is not guaranteed to be precise
  305. async fn spawn_job_oneoffs(ctx: Arc<Context>) {
  306. let jobs: Vec<Box<dyn Job + Send + Sync>> = vec![Box::new(PullRequestAssignmentUpdate)];
  307. for (idx, job) in jobs.into_iter().enumerate() {
  308. let ctx = ctx.clone();
  309. task::spawn(async move {
  310. // Allow some spacing between starting jobs
  311. let delay = idx as u64 * 2;
  312. time::sleep(time::Duration::from_secs(delay)).await;
  313. match job.run(&ctx, &serde_json::Value::Null).await {
  314. Ok(_) => {
  315. log::trace!("job successfully executed (name={})", &job.name());
  316. }
  317. Err(e) => {
  318. log::error!(
  319. "job failed on execution (name={:?}, error={:?})",
  320. job.name(),
  321. e
  322. );
  323. }
  324. }
  325. });
  326. }
  327. }
  328. /// Spawns a background tokio task which runs continuously to queue up jobs
  329. /// to be run by the job runner.
  330. ///
  331. /// The scheduler wakes up every `JOB_SCHEDULING_CADENCE_IN_SECS` seconds to
  332. /// check if there are any jobs ready to run. Jobs get inserted into the the
  333. /// database which acts as a queue.
  334. fn spawn_job_scheduler() {
  335. task::spawn(async move {
  336. loop {
  337. let res = task::spawn(async move {
  338. let pool = db::ClientPool::new();
  339. let mut interval =
  340. time::interval(time::Duration::from_secs(JOB_SCHEDULING_CADENCE_IN_SECS));
  341. loop {
  342. interval.tick().await;
  343. db::schedule_jobs(&*pool.get().await, default_jobs())
  344. .await
  345. .context("database schedule jobs")
  346. .unwrap();
  347. }
  348. });
  349. match res.await {
  350. Err(err) if err.is_panic() => {
  351. /* handle panic in above task, re-launching */
  352. tracing::error!("schedule_jobs task died (error={err})");
  353. tokio::time::sleep(std::time::Duration::new(5, 0)).await;
  354. }
  355. _ => unreachable!(),
  356. }
  357. }
  358. });
  359. }
  360. /// Spawns a background tokio task which runs continuously to run scheduled
  361. /// jobs.
  362. ///
  363. /// The runner wakes up every `JOB_PROCESSING_CADENCE_IN_SECS` seconds to
  364. /// check if any jobs have been put into the queue by the scheduler. They
  365. /// will get popped off the queue and run if any are found.
  366. fn spawn_job_runner(ctx: Arc<Context>) {
  367. task::spawn(async move {
  368. loop {
  369. let ctx = ctx.clone();
  370. let res = task::spawn(async move {
  371. let pool = db::ClientPool::new();
  372. let mut interval =
  373. time::interval(time::Duration::from_secs(JOB_PROCESSING_CADENCE_IN_SECS));
  374. loop {
  375. interval.tick().await;
  376. db::run_scheduled_jobs(&ctx, &*pool.get().await)
  377. .await
  378. .context("run database scheduled jobs")
  379. .unwrap();
  380. }
  381. });
  382. match res.await {
  383. Err(err) if err.is_panic() => {
  384. /* handle panic in above task, re-launching */
  385. tracing::error!("run_scheduled_jobs task died (error={err})");
  386. tokio::time::sleep(std::time::Duration::new(5, 0)).await;
  387. }
  388. _ => unreachable!(),
  389. }
  390. }
  391. });
  392. }
  /// Reports whether background scheduled jobs are turned off for testing.
  ///
  /// Jobs are disabled whenever `TRIAGEBOT_TEST_DISABLE_JOBS` is present in the
  /// environment, whatever its value (even empty). This keeps random jobs from running
  /// while other things are being tested.
  fn is_scheduled_jobs_disabled() -> bool {
      matches!(env::var_os("TRIAGEBOT_TEST_DISABLE_JOBS"), Some(_))
  }
  400. #[tokio::main(flavor = "current_thread")]
  401. async fn main() {
  402. dotenv::dotenv().ok();
  403. tracing_subscriber::fmt::Subscriber::builder()
  404. .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
  405. .with_ansi(std::env::var_os("DISABLE_COLOR").is_none())
  406. .try_init()
  407. .unwrap();
  408. let port = env::var("PORT")
  409. .ok()
  410. .map(|p| p.parse::<u16>().expect("parsed PORT"))
  411. .unwrap_or(8000);
  412. let addr = ([0, 0, 0, 0], port).into();
  413. if let Err(e) = run_server(addr).await {
  414. eprintln!("Failed to run server: {:?}", e);
  415. }
  416. }