Quellcode durchsuchen

Merge pull request #1481 from Mark-Simulacrum/pool

Pool database connection and test validity
Mark Rousskov vor 3 Jahren
Ursprung
Commit
b5c6c17c42
7 geänderte Dateien mit 98 neuen und 31 gelöschten Zeilen
  1. 70 2
      src/db.rs
  2. 1 2
      src/handlers.rs
  3. 3 2
      src/handlers/notification.rs
  4. 2 1
      src/handlers/rustc_commits.rs
  5. 5 7
      src/main.rs
  6. 1 2
      src/notification_listing.rs
  7. 16 15
      src/zulip.rs

+ 70 - 2
src/db.rs

@@ -1,7 +1,9 @@
 use anyhow::Context as _;
 use native_tls::{Certificate, TlsConnector};
 use postgres_native_tls::MakeTlsConnector;
-pub use tokio_postgres::Client as DbClient;
+use std::sync::{Arc, Mutex};
+use tokio::sync::{OwnedSemaphorePermit, Semaphore};
+use tokio_postgres::Client as DbClient;
 
 pub mod notifications;
 pub mod rustc_commits;
@@ -19,7 +21,73 @@ lazy_static::lazy_static! {
     };
 }
 
-pub async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
+pub struct ClientPool {
+    connections: Arc<Mutex<Vec<tokio_postgres::Client>>>,
+    permits: Arc<Semaphore>,
+}
+
+pub struct PooledClient {
+    client: Option<tokio_postgres::Client>,
+    #[allow(unused)] // only used for drop impl
+    permit: OwnedSemaphorePermit,
+    pool: Arc<Mutex<Vec<tokio_postgres::Client>>>,
+}
+
+impl Drop for PooledClient {
+    fn drop(&mut self) {
+        let mut clients = self.pool.lock().unwrap_or_else(|e| e.into_inner());
+        clients.push(self.client.take().unwrap());
+    }
+}
+
+impl std::ops::Deref for PooledClient {
+    type Target = tokio_postgres::Client;
+
+    fn deref(&self) -> &Self::Target {
+        self.client.as_ref().unwrap()
+    }
+}
+
+impl std::ops::DerefMut for PooledClient {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.client.as_mut().unwrap()
+    }
+}
+
+impl ClientPool {
+    pub fn new() -> ClientPool {
+        ClientPool {
+            connections: Arc::new(Mutex::new(Vec::with_capacity(16))),
+            permits: Arc::new(Semaphore::new(16)),
+        }
+    }
+
+    pub async fn get(&self) -> PooledClient {
+        let permit = self.permits.clone().acquire_owned().await.unwrap();
+        {
+            let mut slots = self.connections.lock().unwrap_or_else(|e| e.into_inner());
+            // Pop connections until we hit a non-closed connection (or there are no
+            // "possibly open" connections left).
+            while let Some(c) = slots.pop() {
+                if !c.is_closed() {
+                    return PooledClient {
+                        client: Some(c),
+                        permit,
+                        pool: self.connections.clone(),
+                    };
+                }
+            }
+        }
+
+        PooledClient {
+            client: Some(make_client().await.unwrap()),
+            permit,
+            pool: self.connections.clone(),
+        }
+    }
+}
+
+async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
     let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL");
     if db_url.contains("rds.amazonaws.com") {
         let cert = &CERTIFICATE_PEM[..];

+ 1 - 2
src/handlers.rs

@@ -4,7 +4,6 @@ use octocrab::Octocrab;
 use parser::command::{Command, Input};
 use std::fmt;
 use std::sync::Arc;
-use tokio_postgres::Client as DbClient;
 
 #[derive(Debug)]
 pub enum HandlerError {
@@ -247,7 +246,7 @@ command_handlers! {
 
 pub struct Context {
     pub github: GithubClient,
-    pub db: DbClient,
+    pub db: crate::db::ClientPool,
     pub username: String,
     pub octocrab: Octocrab,
 }

+ 3 - 2
src/handlers/notification.rs

@@ -92,13 +92,14 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
             }
         };
 
+        let client = ctx.db.get().await;
         for user in users {
             if !users_notified.insert(user.id.unwrap()) {
                 // Skip users already associated with this event.
                 continue;
             }
 
-            if let Err(err) = notifications::record_username(&ctx.db, user.id.unwrap(), user.login)
+            if let Err(err) = notifications::record_username(&client, user.id.unwrap(), user.login)
                 .await
                 .context("failed to record username")
             {
@@ -106,7 +107,7 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
             }
 
             if let Err(err) = notifications::record_ping(
-                &ctx.db,
+                &client,
                 &notifications::Notification {
                     user_id: user.id.unwrap(),
                     origin_url: event.html_url().unwrap().to_owned(),

+ 2 - 1
src/handlers/rustc_commits.rs

@@ -72,6 +72,7 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
     let mut sha = bors.merge_sha;
     let mut pr = Some(event.issue.number.try_into().unwrap());
 
+    let db = ctx.db.get().await;
     loop {
         // FIXME: ideally we would pull in all the commits here, but unfortunately
         // in rust-lang/rust's case there's bors-authored commits that aren't
@@ -101,7 +102,7 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
         };
 
         let res = rustc_commits::record_commit(
-            &ctx.db,
+            &db,
             rustc_commits::Commit {
                 sha: gc.sha,
                 parent_sha: parent_sha.clone(),

+ 5 - 7
src/main.rs

@@ -34,7 +34,7 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
             .unwrap());
     }
     if req.uri.path() == "/bors-commit-list" {
-        let res = db::rustc_commits::get_commits_with_artifacts(&ctx.db).await;
+        let res = db::rustc_commits::get_commits_with_artifacts(&*ctx.db.get().await).await;
         let res = match res {
             Ok(r) => r,
             Err(e) => {
@@ -57,7 +57,7 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
                 return Ok(Response::builder()
                     .status(StatusCode::OK)
                     .body(Body::from(
-                        notification_listing::render(&ctx.db, &*name).await,
+                        notification_listing::render(&ctx.db.get().await, &*name).await,
                     ))
                     .unwrap());
             }
@@ -187,10 +187,8 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
 async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
     log::info!("Listening on http://{}", addr);
 
-    let db_client = db::make_client()
-        .await
-        .context("open database connection")?;
-    db::run_migrations(&db_client)
+    let pool = db::ClientPool::new();
+    db::run_migrations(&*pool.get().await)
         .await
         .context("database migrations")?;
 
@@ -202,7 +200,7 @@ async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
         .expect("Failed to build octograb.");
     let ctx = Arc::new(Context {
         username: String::from("rustbot"),
-        db: db_client,
+        db: pool,
         github: gh,
         octocrab: oc,
     });

+ 1 - 2
src/notification_listing.rs

@@ -1,7 +1,6 @@
 use crate::db::notifications::get_notifications;
-use crate::db::DbClient;
 
-pub async fn render(db: &DbClient, user: &str) -> String {
+pub async fn render(db: &crate::db::PooledClient, user: &str) -> String {
     let notifications = match get_notifications(db, user).await {
         Ok(n) => n,
         Err(e) => {

+ 16 - 15
src/zulip.rs

@@ -119,7 +119,7 @@ fn handle_command<'a>(
         };
 
         match next {
-            Some("acknowledge") | Some("ack") => match acknowledge(gh_id, words).await {
+            Some("acknowledge") | Some("ack") => match acknowledge(&ctx, gh_id, words).await {
                 Ok(r) => r,
                 Err(e) => serde_json::to_string(&Response {
                     content: &format!(
@@ -139,7 +139,7 @@ fn handle_command<'a>(
                 })
                 .unwrap(),
             },
-            Some("move") => match move_notification(gh_id, words).await {
+            Some("move") => match move_notification(&ctx, gh_id, words).await {
                 Ok(r) => r,
                 Err(e) => serde_json::to_string(&Response {
                     content: &format!(
@@ -149,7 +149,7 @@ fn handle_command<'a>(
                 })
                 .unwrap(),
             },
-            Some("meta") => match add_meta_notification(gh_id, words).await {
+            Some("meta") => match add_meta_notification(&ctx, gh_id, words).await {
                 Ok(r) => r,
                 Err(e) => serde_json::to_string(&Response {
                     content: &format!(
@@ -513,7 +513,11 @@ impl<'a> UpdateMessageApiRequest<'a> {
     }
 }
 
-async fn acknowledge(gh_id: i64, mut words: impl Iterator<Item = &str>) -> anyhow::Result<String> {
+async fn acknowledge(
+    ctx: &Context,
+    gh_id: i64,
+    mut words: impl Iterator<Item = &str>,
+) -> anyhow::Result<String> {
     let filter = match words.next() {
         Some(filter) => {
             if words.next().is_some() {
@@ -533,7 +537,8 @@ async fn acknowledge(gh_id: i64, mut words: impl Iterator<Item = &str>) -> anyho
     } else {
         Identifier::Url(filter)
     };
-    match delete_ping(&mut crate::db::make_client().await?, gh_id, ident).await {
+    let mut db = ctx.db.get().await;
+    match delete_ping(&mut *db, gh_id, ident).await {
         Ok(deleted) => {
             let resp = if deleted.is_empty() {
                 format!(
@@ -588,7 +593,7 @@ async fn add_notification(
         Some(description)
     };
     match record_ping(
-        &ctx.db,
+        &*ctx.db.get().await,
         &notifications::Notification {
             user_id: gh_id,
             origin_url: url.to_owned(),
@@ -612,6 +617,7 @@ async fn add_notification(
 }
 
 async fn add_meta_notification(
+    ctx: &Context,
     gh_id: i64,
     mut words: impl Iterator<Item = &str>,
 ) -> anyhow::Result<String> {
@@ -635,14 +641,8 @@ async fn add_meta_notification(
         assert_eq!(description.pop(), Some(' ')); // pop trailing space
         Some(description)
     };
-    match add_metadata(
-        &mut crate::db::make_client().await?,
-        gh_id,
-        idx,
-        description.as_deref(),
-    )
-    .await
-    {
+    let mut db = ctx.db.get().await;
+    match add_metadata(&mut db, gh_id, idx, description.as_deref()).await {
         Ok(()) => Ok(serde_json::to_string(&Response {
             content: "Added metadata!",
         })
@@ -655,6 +655,7 @@ async fn add_meta_notification(
 }
 
 async fn move_notification(
+    ctx: &Context,
     gh_id: i64,
     mut words: impl Iterator<Item = &str>,
 ) -> anyhow::Result<String> {
@@ -676,7 +677,7 @@ async fn move_notification(
         .context("to index")?
         .checked_sub(1)
         .ok_or_else(|| anyhow::anyhow!("1-based indexes"))?;
-    match move_indices(&mut crate::db::make_client().await?, gh_id, from, to).await {
+    match move_indices(&mut *ctx.db.get().await, gh_id, from, to).await {
         Ok(()) => Ok(serde_json::to_string(&Response {
             // to 1-base indices
             content: &format!("Moved {} to {}.", from + 1, to + 1),