main.rs 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. #![allow(clippy::new_without_default)]
  2. use anyhow::Context as _;
  3. use futures::{future::FutureExt, stream::StreamExt};
  4. use hyper::{header, Body, Request, Response, Server, StatusCode};
  5. use native_tls::{Certificate, TlsConnector};
  6. use postgres_native_tls::MakeTlsConnector;
  7. use reqwest::Client;
  8. use std::{env, net::SocketAddr, sync::Arc};
  9. use triagebot::{github, handlers::Context, payload, EventName};
  10. use uuid::Uuid;
  11. mod db;
  12. mod logger;
  13. async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Body>, hyper::Error> {
  14. log::info!("request = {:?}", req);
  15. let (req, body_stream) = req.into_parts();
  16. if req.uri.path() == "/" {
  17. return Ok(Response::builder()
  18. .status(StatusCode::OK)
  19. .body(Body::from("Triagebot is awaiting triage."))
  20. .unwrap());
  21. }
  22. if req.uri.path() != "/github-hook" {
  23. return Ok(Response::builder()
  24. .status(StatusCode::NOT_FOUND)
  25. .body(Body::empty())
  26. .unwrap());
  27. }
  28. if req.method != hyper::Method::POST {
  29. return Ok(Response::builder()
  30. .status(StatusCode::METHOD_NOT_ALLOWED)
  31. .header(header::ALLOW, "POST")
  32. .body(Body::empty())
  33. .unwrap());
  34. }
  35. let event = if let Some(ev) = req.headers.get("X-GitHub-Event") {
  36. let ev = match ev.to_str().ok() {
  37. Some(v) => v,
  38. None => {
  39. return Ok(Response::builder()
  40. .status(StatusCode::BAD_REQUEST)
  41. .body(Body::from("X-GitHub-Event header must be UTF-8 encoded"))
  42. .unwrap());
  43. }
  44. };
  45. match ev.parse::<EventName>() {
  46. Ok(v) => v,
  47. Err(_) => unreachable!(),
  48. }
  49. } else {
  50. return Ok(Response::builder()
  51. .status(StatusCode::BAD_REQUEST)
  52. .body(Body::from("X-GitHub-Event header must be set"))
  53. .unwrap());
  54. };
  55. log::debug!("event={}", event);
  56. let signature = if let Some(sig) = req.headers.get("X-Hub-Signature") {
  57. match sig.to_str().ok() {
  58. Some(v) => v,
  59. None => {
  60. return Ok(Response::builder()
  61. .status(StatusCode::BAD_REQUEST)
  62. .body(Body::from("X-Hub-Signature header must be UTF-8 encoded"))
  63. .unwrap());
  64. }
  65. }
  66. } else {
  67. return Ok(Response::builder()
  68. .status(StatusCode::BAD_REQUEST)
  69. .body(Body::from("X-Hub-Signature header must be set"))
  70. .unwrap());
  71. };
  72. log::debug!("signature={}", signature);
  73. let mut c = body_stream;
  74. let mut payload = Vec::new();
  75. while let Some(chunk) = c.next().await {
  76. let chunk = chunk?;
  77. payload.extend_from_slice(&chunk);
  78. }
  79. if let Err(_) = payload::assert_signed(signature, &payload) {
  80. return Ok(Response::builder()
  81. .status(StatusCode::FORBIDDEN)
  82. .body(Body::from("Wrong signature"))
  83. .unwrap());
  84. }
  85. let payload = match String::from_utf8(payload) {
  86. Ok(p) => p,
  87. Err(_) => {
  88. return Ok(Response::builder()
  89. .status(StatusCode::BAD_REQUEST)
  90. .body(Body::from("Payload must be UTF-8"))
  91. .unwrap());
  92. }
  93. };
  94. match triagebot::webhook(event, payload, &ctx).await {
  95. Ok(()) => {}
  96. Err(err) => {
  97. log::error!("request failed: {:?}", err);
  98. return Ok(Response::builder()
  99. .status(StatusCode::INTERNAL_SERVER_ERROR)
  100. .body(Body::from("request failed"))
  101. .unwrap());
  102. }
  103. }
  104. Ok(Response::new(Body::from("processed request")))
  105. }
  106. const CERT_URL: &str = "https://s3.amazonaws.com/rds-downloads/rds-ca-2019-root.pem";
  107. async fn connect_to_db(client: Client) -> anyhow::Result<tokio_postgres::Client> {
  108. let db_url = env::var("DATABASE_URL").expect("needs DATABASE_URL");
  109. if db_url.contains("rds.amazonaws.com") {
  110. let resp = client
  111. .get(CERT_URL)
  112. .send()
  113. .await
  114. .context("failed to get RDS cert")?;
  115. let cert = resp.bytes().await.context("faield to get RDS cert body")?;
  116. let cert = Certificate::from_pem(&cert).context("made certificate")?;
  117. let connector = TlsConnector::builder()
  118. .add_root_certificate(cert)
  119. .build()
  120. .context("built TlsConnector")?;
  121. let connector = MakeTlsConnector::new(connector);
  122. let (db_client, connection) = match tokio_postgres::connect(&db_url, connector).await {
  123. Ok(v) => v,
  124. Err(e) => {
  125. anyhow::bail!("failed to connect to DB: {}", e);
  126. }
  127. };
  128. tokio::spawn(async move {
  129. if let Err(e) = connection.await {
  130. eprintln!("database connection error: {}", e);
  131. }
  132. });
  133. Ok(db_client)
  134. } else {
  135. eprintln!("Warning: Non-TLS connection to non-RDS DB");
  136. let (db_client, connection) =
  137. match tokio_postgres::connect(&db_url, tokio_postgres::NoTls).await {
  138. Ok(v) => v,
  139. Err(e) => {
  140. anyhow::bail!("failed to connect to DB: {}", e);
  141. }
  142. };
  143. tokio::spawn(async move {
  144. if let Err(e) = connection.await {
  145. eprintln!("database connection error: {}", e);
  146. }
  147. });
  148. Ok(db_client)
  149. }
  150. }
  151. async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
  152. log::info!("Listening on http://{}", addr);
  153. let client = Client::new();
  154. let db_client = connect_to_db(client.clone())
  155. .await
  156. .context("open database connection")?;
  157. db::run_migrations(&db_client)
  158. .await
  159. .context("database migrations")?;
  160. let gh = github::GithubClient::new(
  161. client.clone(),
  162. env::var("GITHUB_API_TOKEN").expect("Missing GITHUB_API_TOKEN"),
  163. );
  164. let ctx = Arc::new(Context {
  165. username: github::User::current(&gh).await.unwrap().login,
  166. db: db_client,
  167. github: gh,
  168. });
  169. let svc = hyper::service::make_service_fn(move |_conn| {
  170. let ctx = ctx.clone();
  171. async move {
  172. let uuid = Uuid::new_v4();
  173. Ok::<_, hyper::Error>(hyper::service::service_fn(move |req| {
  174. logger::LogFuture::new(
  175. uuid,
  176. serve_req(req, ctx.clone()).map(move |mut resp| {
  177. if let Ok(resp) = &mut resp {
  178. resp.headers_mut()
  179. .insert("X-Request-Id", uuid.to_string().parse().unwrap());
  180. }
  181. log::info!("response = {:?}", resp);
  182. resp
  183. }),
  184. )
  185. }))
  186. }
  187. });
  188. let serve_future = Server::bind(&addr).serve(svc);
  189. serve_future.await?;
  190. Ok(())
  191. }
  192. #[tokio::main]
  193. async fn main() {
  194. dotenv::dotenv().ok();
  195. logger::init();
  196. let port = env::var("PORT")
  197. .ok()
  198. .map(|p| p.parse::<u16>().expect("parsed PORT"))
  199. .unwrap_or(8000);
  200. let addr = ([0, 0, 0, 0], port).into();
  201. if let Err(e) = run_server(addr).await {
  202. eprintln!("Failed to run server: {}", e);
  203. }
  204. }