use anyhow::Context as _; use native_tls::{Certificate, TlsConnector}; use postgres_native_tls::MakeTlsConnector; pub use tokio_postgres::Client as DbClient; pub mod notifications; pub mod rustc_commits; const CERT_URL: &str = "https://s3.amazonaws.com/rds-downloads/rds-ca-2019-root.pem"; lazy_static::lazy_static! { static ref CERTIFICATE_PEM: Vec = { let client = reqwest::blocking::Client::new(); let resp = client .get(CERT_URL) .send() .expect("failed to get RDS cert"); resp.bytes().expect("failed to get RDS cert body").to_vec() }; } pub async fn make_client() -> anyhow::Result { let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL"); if db_url.contains("rds.amazonaws.com") { let cert = &CERTIFICATE_PEM[..]; let cert = Certificate::from_pem(&cert).context("made certificate")?; let connector = TlsConnector::builder() .add_root_certificate(cert) .build() .context("built TlsConnector")?; let connector = MakeTlsConnector::new(connector); let (db_client, connection) = match tokio_postgres::connect(&db_url, connector).await { Ok(v) => v, Err(e) => { anyhow::bail!("failed to connect to DB: {}", e); } }; tokio::spawn(async move { if let Err(e) = connection.await { eprintln!("database connection error: {}", e); } }); Ok(db_client) } else { eprintln!("Warning: Non-TLS connection to non-RDS DB"); let (db_client, connection) = match tokio_postgres::connect(&db_url, tokio_postgres::NoTls).await { Ok(v) => v, Err(e) => { anyhow::bail!("failed to connect to DB: {}", e); } }; tokio::spawn(async move { if let Err(e) = connection.await { eprintln!("database connection error: {}", e); } }); Ok(db_client) } } pub async fn run_migrations(client: &DbClient) -> anyhow::Result<()> { client .execute( "CREATE TABLE IF NOT EXISTS database_versions ( zero INTEGER PRIMARY KEY, migration_counter INTEGER );", &[], ) .await .context("creating database versioning table")?; client .execute( "INSERT INTO database_versions (zero, migration_counter) VALUES (0, 0) ON CONFLICT DO NOTHING", &[], ) .await .context("inserting initial database_versions")?; let migration_idx: i32 = client .query_one("SELECT migration_counter FROM database_versions", &[]) .await .context("getting migration counter")? .get(0); let migration_idx = migration_idx as usize; for (idx, migration) in MIGRATIONS.iter().enumerate() { if idx >= migration_idx { client .execute(*migration, &[]) .await .with_context(|| format!("executing {}th migration", idx))?; client .execute( "UPDATE database_versions SET migration_counter = $1", &[&(idx as i32 + 1)], ) .await .with_context(|| format!("updating migration counter to {}", idx))?; } } Ok(()) } static MIGRATIONS: &[&str] = &[ " CREATE TABLE notifications ( notification_id BIGSERIAL PRIMARY KEY, user_id BIGINT, origin_url TEXT NOT NULL, origin_html TEXT, time TIMESTAMP WITH TIME ZONE ); ", " CREATE TABLE users ( user_id BIGINT PRIMARY KEY, username TEXT NOT NULL ); ", "ALTER TABLE notifications ADD COLUMN short_description TEXT;", "ALTER TABLE notifications ADD COLUMN team_name TEXT;", "ALTER TABLE notifications ADD COLUMN idx INTEGER;", "ALTER TABLE notifications ADD COLUMN metadata TEXT;", " CREATE TABLE rustc_commits ( sha TEXT PRIMARY KEY, parent_sha TEXT NOT NULL, time TIMESTAMP WITH TIME ZONE ); ", ];