Procházet zdrojové kódy

Implement moving notifications by index

Mark Rousskov před 5 roky
rodič
revize
d6d611e432
4 změnil soubory, kde provedl 209 přidání a 68 odebrání
  1. 78 1
      src/db/notifications.rs
  2. 57 0
      src/handlers.rs
  3. 1 54
      src/main.rs
  4. 73 13
      src/zulip.rs

+ 78 - 1
src/db/notifications.rs

@@ -49,6 +49,83 @@ pub struct NotificationData {
     pub time: DateTime<FixedOffset>,
 }
 
+pub async fn move_indices(
+    db: &mut DbClient,
+    user_id: i64,
+    from: usize,
+    to: usize,
+) -> anyhow::Result<()> {
+    loop {
+        let t = db
+            .build_transaction()
+            .isolation_level(tokio_postgres::IsolationLevel::Serializable)
+            .start()
+            .await
+            .context("begin transaction")?;
+
+        let notifications = t
+            .query(
+                "select notification_id, idx, user_id
+        from notifications
+        where user_id = $1
+        order by notifications.idx desc, notifications.time desc;",
+                &[&user_id],
+            )
+            .await
+            .context("failed to get initial ordering")?;
+
+        let mut notifications = notifications
+            .into_iter()
+            .map(|n| n.get(0))
+            .collect::<Vec<i64>>();
+
+        if notifications.get(from).is_none() {
+            anyhow::bail!(
+                "`from` index not present, must be less than {}",
+                notifications.len()
+            );
+        }
+
+        if notifications.get(to).is_none() {
+            anyhow::bail!(
+                "`to` index not present, must be less than {}",
+                notifications.len()
+            );
+        }
+
+        if from < to {
+            notifications[from..=to].rotate_left(1);
+        } else if to < from {
+            notifications[to..=from].rotate_right(1);
+        }
+
+        for (idx, id) in notifications.into_iter().enumerate() {
+            t.execute(
+                "update notifications SET notifications.idx = $2
+                 where notifications.notification_id = $1",
+                &[&id, &(idx as i64)],
+            )
+            .await
+            .context("update notification id")?;
+        }
+
+        if let Err(e) = t.commit().await {
+            if e.code().map_or(false, |c| {
+                *c == tokio_postgres::error::SqlState::T_R_SERIALIZATION_FAILURE
+            }) {
+                log::trace!("serialization failure, restarting index movement");
+                continue;
+            } else {
+                return Err(e).context("transaction commit failure");
+            }
+        } else {
+            break;
+        }
+    }
+
+    Ok(())
+}
+
 pub async fn get_notifications(
     db: &DbClient,
     username: &str,
@@ -60,7 +137,7 @@ pub async fn get_notifications(
         from notifications
         join users on notifications.user_id = users.user_id
         where username = $1
-        order by notifications.idx desc, notifications.time desc;",
+        order by notifications.idx desc nulls last, notifications.time desc;",
             &[&username],
         )
         .await

+ 57 - 0
src/handlers.rs

@@ -1,6 +1,9 @@
 use crate::config::{self, ConfigurationError};
 use crate::github::{Event, GithubClient};
+use anyhow::Context as _;
 use futures::future::BoxFuture;
+use native_tls::{Certificate, TlsConnector};
+use postgres_native_tls::MakeTlsConnector;
 use std::fmt;
 use tokio_postgres::Client as DbClient;
 
@@ -77,6 +80,60 @@ pub struct Context {
     pub username: String,
 }
 
+const CERT_URL: &str = "https://s3.amazonaws.com/rds-downloads/rds-ca-2019-root.pem";
+
+impl Context {
+    pub async fn make_db_client(
+        client: &reqwest::Client,
+    ) -> anyhow::Result<tokio_postgres::Client> {
+        let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL");
+        if db_url.contains("rds.amazonaws.com") {
+            let resp = client
+                .get(CERT_URL)
+                .send()
+                .await
+                .context("failed to get RDS cert")?;
+            let cert = resp.bytes().await.context("faield to get RDS cert body")?;
+            let cert = Certificate::from_pem(&cert).context("made certificate")?;
+            let connector = TlsConnector::builder()
+                .add_root_certificate(cert)
+                .build()
+                .context("built TlsConnector")?;
+            let connector = MakeTlsConnector::new(connector);
+
+            let (db_client, connection) = match tokio_postgres::connect(&db_url, connector).await {
+                Ok(v) => v,
+                Err(e) => {
+                    anyhow::bail!("failed to connect to DB: {}", e);
+                }
+            };
+            tokio::spawn(async move {
+                if let Err(e) = connection.await {
+                    eprintln!("database connection error: {}", e);
+                }
+            });
+
+            Ok(db_client)
+        } else {
+            eprintln!("Warning: Non-TLS connection to non-RDS DB");
+            let (db_client, connection) =
+                match tokio_postgres::connect(&db_url, tokio_postgres::NoTls).await {
+                    Ok(v) => v,
+                    Err(e) => {
+                        anyhow::bail!("failed to connect to DB: {}", e);
+                    }
+                };
+            tokio::spawn(async move {
+                if let Err(e) = connection.await {
+                    eprintln!("database connection error: {}", e);
+                }
+            });
+
+            Ok(db_client)
+        }
+    }
+}
+
 pub trait Handler: Sync + Send {
     type Input;
     type Config;

+ 1 - 54
src/main.rs

@@ -3,8 +3,6 @@
 use anyhow::Context as _;
 use futures::{future::FutureExt, stream::StreamExt};
 use hyper::{header, Body, Request, Response, Server, StatusCode};
-use native_tls::{Certificate, TlsConnector};
-use postgres_native_tls::MakeTlsConnector;
 use reqwest::Client;
 use std::{env, net::SocketAddr, sync::Arc};
 use triagebot::{db, github, handlers::Context, notification_listing, payload, EventName};
@@ -156,64 +154,13 @@ async fn serve_req(req: Request<Body>, ctx: Arc<Context>) -> Result<Response<Bod
     Ok(Response::new(Body::from("processed request")))
 }
 
-const CERT_URL: &str = "https://s3.amazonaws.com/rds-downloads/rds-ca-2019-root.pem";
-
-async fn connect_to_db(client: Client) -> anyhow::Result<tokio_postgres::Client> {
-    let db_url = env::var("DATABASE_URL").expect("needs DATABASE_URL");
-    if db_url.contains("rds.amazonaws.com") {
-        let resp = client
-            .get(CERT_URL)
-            .send()
-            .await
-            .context("failed to get RDS cert")?;
-        let cert = resp.bytes().await.context("faield to get RDS cert body")?;
-        let cert = Certificate::from_pem(&cert).context("made certificate")?;
-        let connector = TlsConnector::builder()
-            .add_root_certificate(cert)
-            .build()
-            .context("built TlsConnector")?;
-        let connector = MakeTlsConnector::new(connector);
-
-        let (db_client, connection) = match tokio_postgres::connect(&db_url, connector).await {
-            Ok(v) => v,
-            Err(e) => {
-                anyhow::bail!("failed to connect to DB: {}", e);
-            }
-        };
-        tokio::spawn(async move {
-            if let Err(e) = connection.await {
-                eprintln!("database connection error: {}", e);
-            }
-        });
-
-        Ok(db_client)
-    } else {
-        eprintln!("Warning: Non-TLS connection to non-RDS DB");
-        let (db_client, connection) =
-            match tokio_postgres::connect(&db_url, tokio_postgres::NoTls).await {
-                Ok(v) => v,
-                Err(e) => {
-                    anyhow::bail!("failed to connect to DB: {}", e);
-                }
-            };
-        tokio::spawn(async move {
-            if let Err(e) = connection.await {
-                eprintln!("database connection error: {}", e);
-            }
-        });
-
-        Ok(db_client)
-    }
-}
-
 async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
     log::info!("Listening on http://{}", addr);
 
     let client = Client::new();
-    let db_client = connect_to_db(client.clone())
+    let db_client = Context::make_db_client(&client)
         .await
         .context("open database connection")?;
-
     db::run_migrations(&db_client)
         .await
         .context("database migrations")?;

+ 73 - 13
src/zulip.rs

@@ -1,4 +1,4 @@
-use crate::db::notifications::delete_ping;
+use crate::db::notifications::{delete_ping, move_indices};
 use crate::github::GithubClient;
 use crate::handlers::Context;
 use anyhow::Context as _;
@@ -64,14 +64,25 @@ pub async fn respond(ctx: &Context, req: Request) -> String {
         }
     };
 
-    match two_words(&req.data) {
-        Some(["acknowledge", url]) => match delete_ping(&ctx.db, gh_id, url).await {
-            Ok(()) => serde_json::to_string(&Response {
-                content: &format!("Acknowledged {}.", url),
+    let mut words = req.data.split_whitespace();
+    match words.next() {
+        Some("acknowledge") => match acknowledge(&ctx, gh_id, words).await {
+            Ok(r) => r,
+            Err(e) => serde_json::to_string(&Response {
+                content: &format!(
+                    "Failed to parse acknowledgement, expected `acknowledge <url>`: {:?}.",
+                    e
+                ),
             })
             .unwrap(),
+        },
+        Some("move") => match move_notification(&ctx, gh_id, words).await {
+            Ok(r) => r,
             Err(e) => serde_json::to_string(&Response {
-                content: &format!("Failed to acknowledge {}: {:?}.", url, e),
+                content: &format!(
+                    "Failed to parse movement, expected `move <from> <to>`: {:?}.",
+                    e
+                ),
             })
             .unwrap(),
         },
@@ -82,13 +93,62 @@ pub async fn respond(ctx: &Context, req: Request) -> String {
     }
 }
 
-fn two_words(s: &str) -> Option<[&str; 2]> {
-    let mut iter = s.split_whitespace();
-    let first = iter.next()?;
-    let second = iter.next()?;
-    if iter.next().is_some() {
-        return None;
+async fn acknowledge(
+    ctx: &Context,
+    gh_id: i64,
+    mut words: impl Iterator<Item = &str>,
+) -> anyhow::Result<String> {
+    let url = match words.next() {
+        Some(url) => {
+            if words.next().is_some() {
+                anyhow::bail!("too many words");
+            }
+            url
+        }
+        None => anyhow::bail!("not enough words"),
+    };
+    match delete_ping(&ctx.db, gh_id, url).await {
+        Ok(()) => Ok(serde_json::to_string(&Response {
+            content: &format!("Acknowledged {}.", url),
+        })
+        .unwrap()),
+        Err(e) => Ok(serde_json::to_string(&Response {
+            content: &format!("Failed to acknowledge {}: {:?}.", url, e),
+        })
+        .unwrap()),
     }
+}
 
-    return Some([first, second]);
+async fn move_notification(
+    ctx: &Context,
+    gh_id: i64,
+    mut words: impl Iterator<Item = &str>,
+) -> anyhow::Result<String> {
+    let from = match words.next() {
+        Some(idx) => idx,
+        None => anyhow::bail!("from idx not present"),
+    };
+    let to = match words.next() {
+        Some(idx) => idx,
+        None => anyhow::bail!("from idx not present"),
+    };
+    let from = from.parse::<usize>().context("from index")?;
+    let to = to.parse::<usize>().context("to index")?;
+    match move_indices(
+        &mut Context::make_db_client(&ctx.github.raw()).await?,
+        gh_id,
+        from,
+        to,
+    )
+    .await
+    {
+        Ok(()) => Ok(serde_json::to_string(&Response {
+            content: &format!("Moved {} to {}.", from, to),
+        })
+        .unwrap()),
+        Err(e) => Ok(serde_json::to_string(&Response {
+            content: &format!("Failed to move: {:?}.", e),
+        })
+        .unwrap()),
+    }
 }