浏览代码

Pool database connection and test validity

This should mitigate the periodic downtime we experience if the database
connection dies, which previously required an external restart of the service.
Now, if the connection is closed, the next request will automatically either use
a different pooled connection or open a new one.

The pooling implementation is partially extracted from rustc-perf, which has not
encountered the database errors in production that triagebot has.
Mark Rousskov 3 年之前
父节点
当前提交
50f2ff8c8b
共有 7 个文件被更改,包括 98 次插入31 次删除
  1. 70 2
      src/db.rs
  2. 1 2
      src/handlers.rs
  3. 3 2
      src/handlers/notification.rs
  4. 2 1
      src/handlers/rustc_commits.rs
  5. 5 7
      src/main.rs
  6. 1 2
      src/notification_listing.rs
  7. 16 15
      src/zulip.rs

+ 70 - 2
src/db.rs

@@ -1,7 +1,9 @@
 use anyhow::Context as _;
 use anyhow::Context as _;
 use native_tls::{Certificate, TlsConnector};
 use native_tls::{Certificate, TlsConnector};
 use postgres_native_tls::MakeTlsConnector;
 use postgres_native_tls::MakeTlsConnector;
-pub use tokio_postgres::Client as DbClient;
+use std::sync::{Arc, Mutex};
+use tokio::sync::{OwnedSemaphorePermit, Semaphore};
+use tokio_postgres::Client as DbClient;
 
 
 pub mod notifications;
 pub mod notifications;
 pub mod rustc_commits;
 pub mod rustc_commits;
@@ -19,7 +21,73 @@ lazy_static::lazy_static! {
     };
     };
 }
 }
 
 
-pub async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
+pub struct ClientPool {
+    connections: Arc<Mutex<Vec<tokio_postgres::Client>>>,
+    permits: Arc<Semaphore>,
+}
+
+pub struct PooledClient {
+    client: Option<tokio_postgres::Client>,
+    #[allow(unused)] // only used for drop impl
+    permit: OwnedSemaphorePermit,
+    pool: Arc<Mutex<Vec<tokio_postgres::Client>>>,
+}
+
+impl Drop for PooledClient {
+    fn drop(&mut self) {
+        let mut clients = self.pool.lock().unwrap_or_else(|e| e.into_inner());
+        clients.push(self.client.take().unwrap());
+    }
+}
+
+impl std::ops::Deref for PooledClient {
+    type Target = tokio_postgres::Client;
+
+    fn deref(&self) -> &Self::Target {
+        self.client.as_ref().unwrap()
+    }
+}
+
+impl std::ops::DerefMut for PooledClient {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.client.as_mut().unwrap()
+    }
+}
+
+impl ClientPool {
+    pub fn new() -> ClientPool {
+        ClientPool {
+            connections: Arc::new(Mutex::new(Vec::with_capacity(16))),
+            permits: Arc::new(Semaphore::new(16)),
+        }
+    }
+
+    pub async fn get(&self) -> PooledClient {
+        let permit = self.permits.clone().acquire_owned().await.unwrap();
+        {
+            let mut slots = self.connections.lock().unwrap_or_else(|e| e.into_inner());
+            // Pop connections until we hit a non-closed connection (or there are no
+            // "possibly open" connections left).
+            while let Some(c) = slots.pop() {
+                if !c.is_closed() {
+                    return PooledClient {
+                        client: Some(c),
+                        permit,
+                        pool: self.connections.clone(),
+                    };
+                }
+            }
+        }
+
+        PooledClient {
+            client: Some(make_client().await.unwrap()),
+            permit,
+            pool: self.connections.clone(),
+        }
+    }
+}
+
+async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
     let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL");
     let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL");
     if db_url.contains("rds.amazonaws.com") {
     if db_url.contains("rds.amazonaws.com") {
         let cert = &CERTIFICATE_PEM[..];
         let cert = &CERTIFICATE_PEM[..];

+ 1 - 2
src/handlers.rs

@@ -4,7 +4,6 @@ use octocrab::Octocrab;
 use parser::command::{Command, Input};
 use parser::command::{Command, Input};
 use std::fmt;
 use std::fmt;
 use std::sync::Arc;
 use std::sync::Arc;
-use tokio_postgres::Client as DbClient;
 
 
 #[derive(Debug)]
 #[derive(Debug)]
 pub enum HandlerError {
 pub enum HandlerError {
@@ -247,7 +246,7 @@ command_handlers! {
 
 
 pub struct Context {
 pub struct Context {
     pub github: GithubClient,
     pub github: GithubClient,
-    pub db: DbClient,
+    pub db: crate::db::ClientPool,
     pub username: String,
     pub username: String,
     pub octocrab: Octocrab,
     pub octocrab: Octocrab,
 }
 }

+ 3 - 2
src/handlers/notification.rs

@@ -92,13 +92,14 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
             }
             }
         };
         };
 
 
+        let client = ctx.db.get().await;
         for user in users {
         for user in users {
             if !users_notified.insert(user.id.unwrap()) {
             if !users_notified.insert(user.id.unwrap()) {
                 // Skip users already associated with this event.
                 // Skip users already associated with this event.
                 continue;
                 continue;
             }
             }
 
 
-            if let Err(err) = notifications::record_username(&ctx.db, user.id.unwrap(), user.login)
+            if let Err(err) = notifications::record_username(&client, user.id.unwrap(), user.login)
                 .await
                 .await
                 .context("failed to record username")
                 .context("failed to record username")
             {
             {
@@ -106,7 +107,7 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
             }
             }
 
 
             if let Err(err) = notifications::record_ping(
             if let Err(err) = notifications::record_ping(
-                &ctx.db,
+                &client,
                 &notifications::Notification {
                 &notifications::Notification {
                     user_id: user.id.unwrap(),
                     user_id: user.id.unwrap(),
                     origin_url: event.html_url().unwrap().to_owned(),
                     origin_url: event.html_url().unwrap().to_owned(),

+ 2 - 1
src/handlers/rustc_commits.rs

@@ -72,6 +72,7 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
     let mut sha = bors.merge_sha;
     let mut sha = bors.merge_sha;
     let mut pr = Some(event.issue.number.try_into().unwrap());
     let mut pr = Some(event.issue.number.try_into().unwrap());
 
 
+    let db = ctx.db.get().await;
     loop {
     loop {
         // FIXME: ideally we would pull in all the commits here, but unfortunately
         // FIXME: ideally we would pull in all the commits here, but unfortunately
         // in rust-lang/rust's case there's bors-authored commits that aren't
         // in rust-lang/rust's case there's bors-authored commits that aren't
@@ -101,7 +102,7 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
         };
         };
 
 
         let res = rustc_commits::record_commit(
         let res = rustc_commits::record_commit(
-            &ctx.db,
+            &db,
             rustc_commits::Commit {
             rustc_commits::Commit {
                 sha: gc.sha,
                 sha: gc.sha,
                 parent_sha: parent_sha.clone(),
                 parent_sha: parent_sha.clone(),

+ 5 - 7
src/main.rs

@@ -34,7 +34,7 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
             .unwrap());
             .unwrap());
     }
     }
     if req.uri.path() == "/bors-commit-list" {
     if req.uri.path() == "/bors-commit-list" {
-        let res = db::rustc_commits::get_commits_with_artifacts(&ctx.db).await;
+        let res = db::rustc_commits::get_commits_with_artifacts(&*ctx.db.get().await).await;
         let res = match res {
         let res = match res {
             Ok(r) => r,
             Ok(r) => r,
             Err(e) => {
             Err(e) => {
@@ -57,7 +57,7 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
                 return Ok(Response::builder()
                 return Ok(Response::builder()
                     .status(StatusCode::OK)
                     .status(StatusCode::OK)
                     .body(Body::from(
                     .body(Body::from(
-                        notification_listing::render(&ctx.db, &*name).await,
+                        notification_listing::render(&ctx.db.get().await, &*name).await,
                     ))
                     ))
                     .unwrap());
                     .unwrap());
             }
             }
@@ -187,10 +187,8 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
 async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
 async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
     log::info!("Listening on http://{}", addr);
     log::info!("Listening on http://{}", addr);
 
 
-    let db_client = db::make_client()
-        .await
-        .context("open database connection")?;
-    db::run_migrations(&db_client)
+    let pool = db::ClientPool::new();
+    db::run_migrations(&*pool.get().await)
         .await
         .await
         .context("database migrations")?;
         .context("database migrations")?;
 
 
@@ -202,7 +200,7 @@ async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
         .expect("Failed to build octograb.");
         .expect("Failed to build octograb.");
     let ctx = Arc::new(Context {
     let ctx = Arc::new(Context {
         username: String::from("rustbot"),
         username: String::from("rustbot"),
-        db: db_client,
+        db: pool,
         github: gh,
         github: gh,
         octocrab: oc,
         octocrab: oc,
     });
     });

+ 1 - 2
src/notification_listing.rs

@@ -1,7 +1,6 @@
 use crate::db::notifications::get_notifications;
 use crate::db::notifications::get_notifications;
-use crate::db::DbClient;
 
 
-pub async fn render(db: &DbClient, user: &str) -> String {
+pub async fn render(db: &crate::db::PooledClient, user: &str) -> String {
     let notifications = match get_notifications(db, user).await {
     let notifications = match get_notifications(db, user).await {
         Ok(n) => n,
         Ok(n) => n,
         Err(e) => {
         Err(e) => {

+ 16 - 15
src/zulip.rs

@@ -119,7 +119,7 @@ fn handle_command<'a>(
         };
         };
 
 
         match next {
         match next {
-            Some("acknowledge") | Some("ack") => match acknowledge(gh_id, words).await {
+            Some("acknowledge") | Some("ack") => match acknowledge(&ctx, gh_id, words).await {
                 Ok(r) => r,
                 Ok(r) => r,
                 Err(e) => serde_json::to_string(&Response {
                 Err(e) => serde_json::to_string(&Response {
                     content: &format!(
                     content: &format!(
@@ -139,7 +139,7 @@ fn handle_command<'a>(
                 })
                 })
                 .unwrap(),
                 .unwrap(),
             },
             },
-            Some("move") => match move_notification(gh_id, words).await {
+            Some("move") => match move_notification(&ctx, gh_id, words).await {
                 Ok(r) => r,
                 Ok(r) => r,
                 Err(e) => serde_json::to_string(&Response {
                 Err(e) => serde_json::to_string(&Response {
                     content: &format!(
                     content: &format!(
@@ -149,7 +149,7 @@ fn handle_command<'a>(
                 })
                 })
                 .unwrap(),
                 .unwrap(),
             },
             },
-            Some("meta") => match add_meta_notification(gh_id, words).await {
+            Some("meta") => match add_meta_notification(&ctx, gh_id, words).await {
                 Ok(r) => r,
                 Ok(r) => r,
                 Err(e) => serde_json::to_string(&Response {
                 Err(e) => serde_json::to_string(&Response {
                     content: &format!(
                     content: &format!(
@@ -513,7 +513,11 @@ impl<'a> UpdateMessageApiRequest<'a> {
     }
     }
 }
 }
 
 
-async fn acknowledge(gh_id: i64, mut words: impl Iterator<Item = &str>) -> anyhow::Result<String> {
+async fn acknowledge(
+    ctx: &Context,
+    gh_id: i64,
+    mut words: impl Iterator<Item = &str>,
+) -> anyhow::Result<String> {
     let filter = match words.next() {
     let filter = match words.next() {
         Some(filter) => {
         Some(filter) => {
             if words.next().is_some() {
             if words.next().is_some() {
@@ -533,7 +537,8 @@ async fn acknowledge(gh_id: i64, mut words: impl Iterator<Item = &str>) -> anyho
     } else {
     } else {
         Identifier::Url(filter)
         Identifier::Url(filter)
     };
     };
-    match delete_ping(&mut crate::db::make_client().await?, gh_id, ident).await {
+    let mut db = ctx.db.get().await;
+    match delete_ping(&mut *db, gh_id, ident).await {
         Ok(deleted) => {
         Ok(deleted) => {
             let resp = if deleted.is_empty() {
             let resp = if deleted.is_empty() {
                 format!(
                 format!(
@@ -588,7 +593,7 @@ async fn add_notification(
         Some(description)
         Some(description)
     };
     };
     match record_ping(
     match record_ping(
-        &ctx.db,
+        &*ctx.db.get().await,
         &notifications::Notification {
         &notifications::Notification {
             user_id: gh_id,
             user_id: gh_id,
             origin_url: url.to_owned(),
             origin_url: url.to_owned(),
@@ -612,6 +617,7 @@ async fn add_notification(
 }
 }
 
 
 async fn add_meta_notification(
 async fn add_meta_notification(
+    ctx: &Context,
     gh_id: i64,
     gh_id: i64,
     mut words: impl Iterator<Item = &str>,
     mut words: impl Iterator<Item = &str>,
 ) -> anyhow::Result<String> {
 ) -> anyhow::Result<String> {
@@ -635,14 +641,8 @@ async fn add_meta_notification(
         assert_eq!(description.pop(), Some(' ')); // pop trailing space
         assert_eq!(description.pop(), Some(' ')); // pop trailing space
         Some(description)
         Some(description)
     };
     };
-    match add_metadata(
-        &mut crate::db::make_client().await?,
-        gh_id,
-        idx,
-        description.as_deref(),
-    )
-    .await
-    {
+    let mut db = ctx.db.get().await;
+    match add_metadata(&mut db, gh_id, idx, description.as_deref()).await {
         Ok(()) => Ok(serde_json::to_string(&Response {
         Ok(()) => Ok(serde_json::to_string(&Response {
             content: "Added metadata!",
             content: "Added metadata!",
         })
         })
@@ -655,6 +655,7 @@ async fn add_meta_notification(
 }
 }
 
 
 async fn move_notification(
 async fn move_notification(
+    ctx: &Context,
     gh_id: i64,
     gh_id: i64,
     mut words: impl Iterator<Item = &str>,
     mut words: impl Iterator<Item = &str>,
 ) -> anyhow::Result<String> {
 ) -> anyhow::Result<String> {
@@ -676,7 +677,7 @@ async fn move_notification(
         .context("to index")?
         .context("to index")?
         .checked_sub(1)
         .checked_sub(1)
         .ok_or_else(|| anyhow::anyhow!("1-based indexes"))?;
         .ok_or_else(|| anyhow::anyhow!("1-based indexes"))?;
-    match move_indices(&mut crate::db::make_client().await?, gh_id, from, to).await {
+    match move_indices(&mut *ctx.db.get().await, gh_id, from, to).await {
         Ok(()) => Ok(serde_json::to_string(&Response {
         Ok(()) => Ok(serde_json::to_string(&Response {
             // to 1-base indices
             // to 1-base indices
             content: &format!("Moved {} to {}.", from + 1, to + 1),
             content: &format!("Moved {} to {}.", from + 1, to + 1),