Bläddra i källkod

Move database client creation to db module

Also, this moves the download for the AWS TLS certificate to a one-off per run
of triagebot, placing it in a lazy_static. This slightly speeds up subsequent
`make_client()` calls and should work out fine, as in practice updates to that
certificate happen very rarely (and if it does, we'd just restart the app and
get a new one).
Mark Rousskov 5 år sedan
förälder
incheckning
95c8efbb2d
6 ändrade filer med 80 tillägg och 87 borttagningar
  1. 11 0
      Cargo.lock
  2. 1 1
      Cargo.toml
  3. 58 0
      src/db.rs
  4. 0 57
      src/handlers.rs
  5. 2 2
      src/main.rs
  6. 8 27
      src/zulip.rs

+ 11 - 0
Cargo.lock

@@ -693,6 +693,15 @@ dependencies = [
  "autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
 ]
 
+[[package]]
+name = "num_cpus"
+version = "1.12.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "hermit-abi 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
+ "libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
 [[package]]
 name = "once_cell"
 version = "1.3.1"
@@ -1226,6 +1235,7 @@ dependencies = [
  "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
  "mio 0.6.21 (registry+https://github.com/rust-lang/crates.io-index)",
  "mio-uds 0.6.7 (registry+https://github.com/rust-lang/crates.io-index)",
+ "num_cpus 1.12.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "pin-project-lite 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
  "slab 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)",
  "tokio-macros 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -1674,6 +1684,7 @@ dependencies = [
 "checksum nom 4.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2ad2a91a8e869eeb30b9cb3119ae87773a8f4ae617f41b1eb9c154b2905f7bd6"
 "checksum num-integer 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "3f6ea62e9d81a77cd3ee9a2a5b9b609447857f3d358704331e4ef39eb247fcba"
 "checksum num-traits 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)" = "c62be47e61d1842b9170f0fdeec8eba98e60e90e5446449a0545e5152acd7096"
+"checksum num_cpus 1.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "46203554f085ff89c235cd12f7075f3233af9b11ed7c9e16dfe2560d03313ce6"
 "checksum once_cell 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b1c601810575c99596d4afc46f78a678c80105117c379eb3650cf99b8a21ce5b"
 "checksum opaque-debug 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2839e79665f131bdb5782e51f2c6c9599c133c6098982a54c794358bf432529c"
 "checksum openssl 0.10.29 (registry+https://github.com/rust-lang/crates.io-index)" = "cee6d85f4cb4c4f59a6a85d5b68a233d280c82e29e822913b9c8b129fbf20bdd"

+ 1 - 1
Cargo.toml

@@ -10,7 +10,7 @@ edition = "2018"
 serde_json = "1"
 openssl = "0.10"
 dotenv = "0.14"
-reqwest = { version = "0.10", features = ["json"] }
+reqwest = { version = "0.10", features = ["json", "blocking"] }
 regex = "1"
 lazy_static = "1"
 log = "0.4"

+ 58 - 0
src/db.rs

@@ -1,8 +1,66 @@
 use anyhow::Context as _;
+use native_tls::{Certificate, TlsConnector};
+use postgres_native_tls::MakeTlsConnector;
 pub use tokio_postgres::Client as DbClient;
 
 pub mod notifications;
 
+const CERT_URL: &str = "https://s3.amazonaws.com/rds-downloads/rds-ca-2019-root.pem";
+
+lazy_static::lazy_static! {
+    static ref CERTIFICATE_PEM: Vec<u8> = {
+        let client = reqwest::blocking::Client::new();
+        let resp = client
+            .get(CERT_URL)
+            .send()
+            .expect("failed to get RDS cert");
+         resp.bytes().expect("failed to get RDS cert body").to_vec()
+    };
+}
+
+pub async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
+    let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL");
+    if db_url.contains("rds.amazonaws.com") {
+        let cert = &CERTIFICATE_PEM[..];
+        let cert = Certificate::from_pem(&cert).context("made certificate")?;
+        let connector = TlsConnector::builder()
+            .add_root_certificate(cert)
+            .build()
+            .context("built TlsConnector")?;
+        let connector = MakeTlsConnector::new(connector);
+
+        let (db_client, connection) = match tokio_postgres::connect(&db_url, connector).await {
+            Ok(v) => v,
+            Err(e) => {
+                anyhow::bail!("failed to connect to DB: {}", e);
+            }
+        };
+        tokio::spawn(async move {
+            if let Err(e) = connection.await {
+                eprintln!("database connection error: {}", e);
+            }
+        });
+
+        Ok(db_client)
+    } else {
+        eprintln!("Warning: Non-TLS connection to non-RDS DB");
+        let (db_client, connection) =
+            match tokio_postgres::connect(&db_url, tokio_postgres::NoTls).await {
+                Ok(v) => v,
+                Err(e) => {
+                    anyhow::bail!("failed to connect to DB: {}", e);
+                }
+            };
+        tokio::spawn(async move {
+            if let Err(e) = connection.await {
+                eprintln!("database connection error: {}", e);
+            }
+        });
+
+        Ok(db_client)
+    }
+}
+
 pub async fn run_migrations(client: &DbClient) -> anyhow::Result<()> {
     client
         .execute(

+ 0 - 57
src/handlers.rs

@@ -1,9 +1,6 @@
 use crate::config::{self, ConfigurationError};
 use crate::github::{Event, GithubClient};
-use anyhow::Context as _;
 use futures::future::BoxFuture;
-use native_tls::{Certificate, TlsConnector};
-use postgres_native_tls::MakeTlsConnector;
 use std::fmt;
 use tokio_postgres::Client as DbClient;
 
@@ -81,60 +78,6 @@ pub struct Context {
     pub username: String,
 }
 
-const CERT_URL: &str = "https://s3.amazonaws.com/rds-downloads/rds-ca-2019-root.pem";
-
-impl Context {
-    pub async fn make_db_client(
-        client: &reqwest::Client,
-    ) -> anyhow::Result<tokio_postgres::Client> {
-        let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL");
-        if db_url.contains("rds.amazonaws.com") {
-            let resp = client
-                .get(CERT_URL)
-                .send()
-                .await
-                .context("failed to get RDS cert")?;
-            let cert = resp.bytes().await.context("faield to get RDS cert body")?;
-            let cert = Certificate::from_pem(&cert).context("made certificate")?;
-            let connector = TlsConnector::builder()
-                .add_root_certificate(cert)
-                .build()
-                .context("built TlsConnector")?;
-            let connector = MakeTlsConnector::new(connector);
-
-            let (db_client, connection) = match tokio_postgres::connect(&db_url, connector).await {
-                Ok(v) => v,
-                Err(e) => {
-                    anyhow::bail!("failed to connect to DB: {}", e);
-                }
-            };
-            tokio::spawn(async move {
-                if let Err(e) = connection.await {
-                    eprintln!("database connection error: {}", e);
-                }
-            });
-
-            Ok(db_client)
-        } else {
-            eprintln!("Warning: Non-TLS connection to non-RDS DB");
-            let (db_client, connection) =
-                match tokio_postgres::connect(&db_url, tokio_postgres::NoTls).await {
-                    Ok(v) => v,
-                    Err(e) => {
-                        anyhow::bail!("failed to connect to DB: {}", e);
-                    }
-                };
-            tokio::spawn(async move {
-                if let Err(e) = connection.await {
-                    eprintln!("database connection error: {}", e);
-                }
-            });
-
-            Ok(db_client)
-        }
-    }
-}
-
 pub trait Handler: Sync + Send {
     type Input;
     type Config;

+ 2 - 2
src/main.rs

@@ -156,14 +156,14 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
 async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
     log::info!("Listening on http://{}", addr);
 
-    let client = Client::new();
-    let db_client = Context::make_db_client(&client)
+    let db_client = db::make_client()
         .await
         .context("open database connection")?;
     db::run_migrations(&db_client)
         .await
         .context("database migrations")?;
 
+    let client = Client::new();
     let gh = github::GithubClient::new(
         client.clone(),
         env::var("GITHUB_API_TOKEN").expect("Missing GITHUB_API_TOKEN"),

+ 8 - 27
src/zulip.rs

@@ -100,7 +100,7 @@ fn handle_command<'a>(
                     ),
                 })
                 .unwrap(),
-            }
+            };
         }
         let gh_id = match gh_id {
             Ok(id) => id,
@@ -108,7 +108,7 @@ fn handle_command<'a>(
         };
 
         match next {
-            Some("acknowledge") | Some("ack") => match acknowledge(&ctx, gh_id, words).await {
+            Some("acknowledge") | Some("ack") => match acknowledge(gh_id, words).await {
                 Ok(r) => r,
                 Err(e) => serde_json::to_string(&Response {
                     content: &format!(
@@ -128,7 +128,7 @@ fn handle_command<'a>(
                 })
                 .unwrap(),
             },
-            Some("move") => match move_notification(&ctx, gh_id, words).await {
+            Some("move") => match move_notification(gh_id, words).await {
                 Ok(r) => r,
                 Err(e) => serde_json::to_string(&Response {
                     content: &format!(
@@ -138,7 +138,7 @@ fn handle_command<'a>(
                 })
                 .unwrap(),
             },
-            Some("meta") => match add_meta_notification(&ctx, gh_id, words).await {
+            Some("meta") => match add_meta_notification(gh_id, words).await {
                 Ok(r) => r,
                 Err(e) => serde_json::to_string(&Response {
                     content: &format!(
@@ -316,11 +316,7 @@ struct MessageApiRequest<'a> {
     content: &'a str,
 }
 
-async fn acknowledge(
-    ctx: &Context,
-    gh_id: i64,
-    mut words: impl Iterator<Item = &str>,
-) -> anyhow::Result<String> {
+async fn acknowledge(gh_id: i64, mut words: impl Iterator<Item = &str>) -> anyhow::Result<String> {
     let url = match words.next() {
         Some(url) => {
             if words.next().is_some() {
@@ -338,13 +334,7 @@ async fn acknowledge(
     } else {
         Identifier::Url(url)
     };
-    match delete_ping(
-        &mut Context::make_db_client(&ctx.github.raw()).await?,
-        gh_id,
-        ident,
-    )
-    .await
-    {
+    match delete_ping(&mut crate::db::make_client().await?, gh_id, ident).await {
         Ok(deleted) => {
             let mut resp = format!("Acknowledged:\n");
             for deleted in deleted {
@@ -414,7 +404,6 @@ async fn add_notification(
 }
 
 async fn add_meta_notification(
-    ctx: &Context,
     gh_id: i64,
     mut words: impl Iterator<Item = &str>,
 ) -> anyhow::Result<String> {
@@ -439,7 +428,7 @@ async fn add_meta_notification(
         Some(description)
     };
     match add_metadata(
-        &mut Context::make_db_client(&ctx.github.raw()).await?,
+        &mut crate::db::make_client().await?,
         gh_id,
         idx,
         description.as_deref(),
@@ -458,7 +447,6 @@ async fn add_meta_notification(
 }
 
 async fn move_notification(
-    ctx: &Context,
     gh_id: i64,
     mut words: impl Iterator<Item = &str>,
 ) -> anyhow::Result<String> {
@@ -480,14 +468,7 @@ async fn move_notification(
         .context("to index")?
         .checked_sub(1)
         .ok_or_else(|| anyhow::anyhow!("1-based indexes"))?;
-    match move_indices(
-        &mut Context::make_db_client(&ctx.github.raw()).await?,
-        gh_id,
-        from,
-        to,
-    )
-    .await
-    {
+    match move_indices(&mut crate::db::make_client().await?, gh_id, from, to).await {
         Ok(()) => Ok(serde_json::to_string(&Response {
             // to 1-base indices
             content: &format!("Moved {} to {}.", from + 1, to + 1),