123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
- use anyhow::Context as _;
- use native_tls::{Certificate, TlsConnector};
- use postgres_native_tls::MakeTlsConnector;
- pub use tokio_postgres::Client as DbClient;
- pub mod notifications;
- pub mod rustc_commits;
- const CERT_URL: &str = "https://s3.amazonaws.com/rds-downloads/rds-ca-2019-root.pem";
- lazy_static::lazy_static! {
- static ref CERTIFICATE_PEM: Vec<u8> = {
- let client = reqwest::blocking::Client::new();
- let resp = client
- .get(CERT_URL)
- .send()
- .expect("failed to get RDS cert");
- resp.bytes().expect("failed to get RDS cert body").to_vec()
- };
- }
- pub async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
- let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL");
- if db_url.contains("rds.amazonaws.com") {
- let cert = &CERTIFICATE_PEM[..];
- let cert = Certificate::from_pem(&cert).context("made certificate")?;
- let connector = TlsConnector::builder()
- .add_root_certificate(cert)
- .build()
- .context("built TlsConnector")?;
- let connector = MakeTlsConnector::new(connector);
- let (db_client, connection) = match tokio_postgres::connect(&db_url, connector).await {
- Ok(v) => v,
- Err(e) => {
- anyhow::bail!("failed to connect to DB: {}", e);
- }
- };
- tokio::spawn(async move {
- if let Err(e) = connection.await {
- eprintln!("database connection error: {}", e);
- }
- });
- Ok(db_client)
- } else {
- eprintln!("Warning: Non-TLS connection to non-RDS DB");
- let (db_client, connection) =
- match tokio_postgres::connect(&db_url, tokio_postgres::NoTls).await {
- Ok(v) => v,
- Err(e) => {
- anyhow::bail!("failed to connect to DB: {}", e);
- }
- };
- tokio::spawn(async move {
- if let Err(e) = connection.await {
- eprintln!("database connection error: {}", e);
- }
- });
- Ok(db_client)
- }
- }
- pub async fn run_migrations(client: &DbClient) -> anyhow::Result<()> {
- client
- .execute(
- "CREATE TABLE IF NOT EXISTS database_versions (
- zero INTEGER PRIMARY KEY,
- migration_counter INTEGER
- );",
- &[],
- )
- .await
- .context("creating database versioning table")?;
- client
- .execute(
- "INSERT INTO database_versions (zero, migration_counter)
- VALUES (0, 0)
- ON CONFLICT DO NOTHING",
- &[],
- )
- .await
- .context("inserting initial database_versions")?;
- let migration_idx: i32 = client
- .query_one("SELECT migration_counter FROM database_versions", &[])
- .await
- .context("getting migration counter")?
- .get(0);
- let migration_idx = migration_idx as usize;
- for (idx, migration) in MIGRATIONS.iter().enumerate() {
- if idx >= migration_idx {
- client
- .execute(*migration, &[])
- .await
- .with_context(|| format!("executing {}th migration", idx))?;
- client
- .execute(
- "UPDATE database_versions SET migration_counter = $1",
- &[&(idx as i32 + 1)],
- )
- .await
- .with_context(|| format!("updating migration counter to {}", idx))?;
- }
- }
- Ok(())
- }
/// Ordered list of SQL statements that make up the database schema.
///
/// `run_migrations` executes these in order and stores `index + 1` of the
/// last executed entry in `database_versions.migration_counter`; on later
/// runs, entries whose index is below the stored counter are skipped.
/// Consequently, editing or reordering an existing entry has no effect on a
/// database that has already applied it: add new migrations at the end.
- static MIGRATIONS: &[&str] = &[
- "
- CREATE TABLE notifications (
- notification_id BIGSERIAL PRIMARY KEY,
- user_id BIGINT,
- origin_url TEXT NOT NULL,
- origin_html TEXT,
- time TIMESTAMP WITH TIME ZONE
- );
- ",
- "
- CREATE TABLE users (
- user_id BIGINT PRIMARY KEY,
- username TEXT NOT NULL
- );
- ",
- "ALTER TABLE notifications ADD COLUMN short_description TEXT;",
- "ALTER TABLE notifications ADD COLUMN team_name TEXT;",
- "ALTER TABLE notifications ADD COLUMN idx INTEGER;",
- "ALTER TABLE notifications ADD COLUMN metadata TEXT;",
- "
- CREATE TABLE rustc_commits (
- sha TEXT PRIMARY KEY,
- parent_sha TEXT NOT NULL,
- time TIMESTAMP WITH TIME ZONE
- );
- ",
- ];
|