Pool database connection and test validity

This should mitigate the periodic downtime we experience if the database
connection dies, which previously required an external restart of the service.
Now, if the connection is closed, the next request will automatically either use
a different pooled connection or open a new one.

The pooling implementation is partially extracted from rustc-perf, which has not
encountered the database errors in production that triagebot has.
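
As a usage sketch (editor's illustration, not part of this commit; it assumes the Context type from src/handlers.rs and an async caller), handlers now check a connection out of the pool per request instead of sharing one long-lived DbClient:

    // Hypothetical handler body; `ctx.db` is the ClientPool added in src/db.rs below.
    async fn example_handler(ctx: &Context) -> anyhow::Result<()> {
        // Borrow a connection: reuses an open pooled connection or opens a new one.
        let db = ctx.db.get().await;
        // PooledClient derefs to tokio_postgres::Client, so helpers that take a
        // `&tokio_postgres::Client` keep working when passed `&db`.
        let _rows = db.query("SELECT 1", &[]).await?;
        Ok(())
        // Dropping `db` at the end of scope returns the connection to the pool.
    }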
Mark Rousskov, 3 years ago
Parent commit: 50f2ff8c8b
7 changed files with 98 additions and 31 deletions
  1. src/db.rs (+70, -2)
  2. src/handlers.rs (+1, -2)
  3. src/handlers/notification.rs (+3, -2)
  4. src/handlers/rustc_commits.rs (+2, -1)
  5. src/main.rs (+5, -7)
  6. src/notification_listing.rs (+1, -2)
  7. src/zulip.rs (+16, -15)

src/db.rs (+70, -2)

@@ -1,7 +1,9 @@
 use anyhow::Context as _;
 use native_tls::{Certificate, TlsConnector};
 use postgres_native_tls::MakeTlsConnector;
-pub use tokio_postgres::Client as DbClient;
+use std::sync::{Arc, Mutex};
+use tokio::sync::{OwnedSemaphorePermit, Semaphore};
+use tokio_postgres::Client as DbClient;
 
 pub mod notifications;
 pub mod rustc_commits;
@@ -19,7 +21,73 @@ lazy_static::lazy_static! {
     };
 }
 
-pub async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
+pub struct ClientPool {
+    connections: Arc<Mutex<Vec<tokio_postgres::Client>>>,
+    permits: Arc<Semaphore>,
+}
+
+pub struct PooledClient {
+    client: Option<tokio_postgres::Client>,
+    #[allow(unused)] // only used for drop impl
+    permit: OwnedSemaphorePermit,
+    pool: Arc<Mutex<Vec<tokio_postgres::Client>>>,
+}
+
+impl Drop for PooledClient {
+    fn drop(&mut self) {
+        let mut clients = self.pool.lock().unwrap_or_else(|e| e.into_inner());
+        clients.push(self.client.take().unwrap());
+    }
+}
+
+impl std::ops::Deref for PooledClient {
+    type Target = tokio_postgres::Client;
+
+    fn deref(&self) -> &Self::Target {
+        self.client.as_ref().unwrap()
+    }
+}
+
+impl std::ops::DerefMut for PooledClient {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.client.as_mut().unwrap()
+    }
+}
+
+impl ClientPool {
+    pub fn new() -> ClientPool {
+        ClientPool {
+            connections: Arc::new(Mutex::new(Vec::with_capacity(16))),
+            permits: Arc::new(Semaphore::new(16)),
+        }
+    }
+
+    pub async fn get(&self) -> PooledClient {
+        let permit = self.permits.clone().acquire_owned().await.unwrap();
+        {
+            let mut slots = self.connections.lock().unwrap_or_else(|e| e.into_inner());
+            // Pop connections until we hit a non-closed connection (or there are no
+            // "possibly open" connections left).
+            while let Some(c) = slots.pop() {
+                if !c.is_closed() {
+                    return PooledClient {
+                        client: Some(c),
+                        permit,
+                        pool: self.connections.clone(),
+                    };
+                }
+            }
+        }
+
+        PooledClient {
+            client: Some(make_client().await.unwrap()),
+            permit,
+            pool: self.connections.clone(),
+        }
+    }
+}
+
+async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
     let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL");
     if db_url.contains("rds.amazonaws.com") {
         let cert = &CERTIFICATE_PEM[..];
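
A minimal sketch of the checkout/return cycle the new ClientPool implements (editor's illustration, not part of the commit; the tokio::main wrapper and the SELECT 1 query are assumptions, and DATABASE_URL must be set for make_client to succeed):

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        let pool = ClientPool::new(); // at most 16 connections checked out at once
        {
            // get() acquires a semaphore permit, then pops pooled connections until
            // it finds one that is still open; otherwise it calls make_client().
            let client = pool.get().await;
            let row = client.query_one("SELECT 1::INT4", &[]).await?;
            assert_eq!(row.get::<_, i32>(0), 1);
        } // Drop pushes the connection back into the pool and releases the permit.

        // A later checkout reuses that connection unless it has since closed, in
        // which case it is discarded and a fresh connection is opened instead.
        let _client = pool.get().await;
        Ok(())
    }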

src/handlers.rs (+1, -2)

@@ -4,7 +4,6 @@ use octocrab::Octocrab;
 use parser::command::{Command, Input};
 use std::fmt;
 use std::sync::Arc;
-use tokio_postgres::Client as DbClient;
 
 #[derive(Debug)]
 pub enum HandlerError {
@@ -247,7 +246,7 @@ command_handlers! {
 
 pub struct Context {
     pub github: GithubClient,
-    pub db: DbClient,
+    pub db: crate::db::ClientPool,
     pub username: String,
     pub octocrab: Octocrab,
 }

src/handlers/notification.rs (+3, -2)

@@ -92,13 +92,14 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
             }
         };
 
+        let client = ctx.db.get().await;
         for user in users {
             if !users_notified.insert(user.id.unwrap()) {
                 // Skip users already associated with this event.
                 continue;
             }
 
-            if let Err(err) = notifications::record_username(&ctx.db, user.id.unwrap(), user.login)
+            if let Err(err) = notifications::record_username(&client, user.id.unwrap(), user.login)
                 .await
                 .context("failed to record username")
             {
@@ -106,7 +107,7 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
             }
 
             if let Err(err) = notifications::record_ping(
-                &ctx.db,
+                &client,
                 &notifications::Notification {
                     user_id: user.id.unwrap(),
                     origin_url: event.html_url().unwrap().to_owned(),

src/handlers/rustc_commits.rs (+2, -1)

@@ -72,6 +72,7 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
     let mut sha = bors.merge_sha;
     let mut pr = Some(event.issue.number.try_into().unwrap());
 
+    let db = ctx.db.get().await;
     loop {
         // FIXME: ideally we would pull in all the commits here, but unfortunately
         // in rust-lang/rust's case there's bors-authored commits that aren't
@@ -101,7 +102,7 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
         };
 
         let res = rustc_commits::record_commit(
-            &ctx.db,
+            &db,
             rustc_commits::Commit {
                 sha: gc.sha,
                 parent_sha: parent_sha.clone(),

src/main.rs (+5, -7)

@@ -34,7 +34,7 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
             .unwrap());
     }
     if req.uri.path() == "/bors-commit-list" {
-        let res = db::rustc_commits::get_commits_with_artifacts(&ctx.db).await;
+        let res = db::rustc_commits::get_commits_with_artifacts(&*ctx.db.get().await).await;
         let res = match res {
             Ok(r) => r,
             Err(e) => {
@@ -57,7 +57,7 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
                 return Ok(Response::builder()
                     .status(StatusCode::OK)
                     .body(Body::from(
-                        notification_listing::render(&ctx.db, &*name).await,
+                        notification_listing::render(&ctx.db.get().await, &*name).await,
                     ))
                     .unwrap());
             }
@@ -187,10 +187,8 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
 async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
     log::info!("Listening on http://{}", addr);
 
-    let db_client = db::make_client()
-        .await
-        .context("open database connection")?;
-    db::run_migrations(&db_client)
+    let pool = db::ClientPool::new();
+    db::run_migrations(&*pool.get().await)
         .await
         .context("database migrations")?;
 
@@ -202,7 +200,7 @@ async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
         .expect("Failed to build octograb.");
     let ctx = Arc::new(Context {
         username: String::from("rustbot"),
-        db: db_client,
+        db: pool,
         github: gh,
         octocrab: oc,
     });

src/notification_listing.rs (+1, -2)

@@ -1,7 +1,6 @@
 use crate::db::notifications::get_notifications;
-use crate::db::DbClient;
 
-pub async fn render(db: &DbClient, user: &str) -> String {
+pub async fn render(db: &crate::db::PooledClient, user: &str) -> String {
     let notifications = match get_notifications(db, user).await {
         Ok(n) => n,
         Err(e) => {

src/zulip.rs (+16, -15)

@@ -119,7 +119,7 @@ fn handle_command<'a>(
         };
 
         match next {
-            Some("acknowledge") | Some("ack") => match acknowledge(gh_id, words).await {
+            Some("acknowledge") | Some("ack") => match acknowledge(&ctx, gh_id, words).await {
                 Ok(r) => r,
                 Err(e) => serde_json::to_string(&Response {
                     content: &format!(
@@ -139,7 +139,7 @@ fn handle_command<'a>(
                 })
                 .unwrap(),
             },
-            Some("move") => match move_notification(gh_id, words).await {
+            Some("move") => match move_notification(&ctx, gh_id, words).await {
                 Ok(r) => r,
                 Err(e) => serde_json::to_string(&Response {
                     content: &format!(
@@ -149,7 +149,7 @@ fn handle_command<'a>(
                 })
                 .unwrap(),
             },
-            Some("meta") => match add_meta_notification(gh_id, words).await {
+            Some("meta") => match add_meta_notification(&ctx, gh_id, words).await {
                 Ok(r) => r,
                 Err(e) => serde_json::to_string(&Response {
                     content: &format!(
@@ -513,7 +513,11 @@ impl<'a> UpdateMessageApiRequest<'a> {
     }
 }
 
-async fn acknowledge(gh_id: i64, mut words: impl Iterator<Item = &str>) -> anyhow::Result<String> {
+async fn acknowledge(
+    ctx: &Context,
+    gh_id: i64,
+    mut words: impl Iterator<Item = &str>,
+) -> anyhow::Result<String> {
     let filter = match words.next() {
         Some(filter) => {
             if words.next().is_some() {
@@ -533,7 +537,8 @@ async fn acknowledge(gh_id: i64, mut words: impl Iterator<Item = &str>) -> anyho
     } else {
         Identifier::Url(filter)
     };
-    match delete_ping(&mut crate::db::make_client().await?, gh_id, ident).await {
+    let mut db = ctx.db.get().await;
+    match delete_ping(&mut *db, gh_id, ident).await {
         Ok(deleted) => {
             let resp = if deleted.is_empty() {
                 format!(
@@ -588,7 +593,7 @@ async fn add_notification(
         Some(description)
     };
     match record_ping(
-        &ctx.db,
+        &*ctx.db.get().await,
         &notifications::Notification {
             user_id: gh_id,
             origin_url: url.to_owned(),
@@ -612,6 +617,7 @@ async fn add_notification(
 }
 
 async fn add_meta_notification(
+    ctx: &Context,
     gh_id: i64,
     mut words: impl Iterator<Item = &str>,
 ) -> anyhow::Result<String> {
@@ -635,14 +641,8 @@ async fn add_meta_notification(
         assert_eq!(description.pop(), Some(' ')); // pop trailing space
         Some(description)
     };
-    match add_metadata(
-        &mut crate::db::make_client().await?,
-        gh_id,
-        idx,
-        description.as_deref(),
-    )
-    .await
-    {
+    let mut db = ctx.db.get().await;
+    match add_metadata(&mut db, gh_id, idx, description.as_deref()).await {
         Ok(()) => Ok(serde_json::to_string(&Response {
             content: "Added metadata!",
         })
@@ -655,6 +655,7 @@ async fn add_meta_notification(
 }
 
 async fn move_notification(
+    ctx: &Context,
     gh_id: i64,
     mut words: impl Iterator<Item = &str>,
 ) -> anyhow::Result<String> {
@@ -676,7 +677,7 @@ async fn move_notification(
         .context("to index")?
         .checked_sub(1)
         .ok_or_else(|| anyhow::anyhow!("1-based indexes"))?;
-    match move_indices(&mut crate::db::make_client().await?, gh_id, from, to).await {
+    match move_indices(&mut *ctx.db.get().await, gh_id, from, to).await {
         Ok(()) => Ok(serde_json::to_string(&Response {
             // to 1-base indices
             content: &format!("Moved {} to {}.", from + 1, to + 1),