db.rs 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. use anyhow::Context as _;
  2. use native_tls::{Certificate, TlsConnector};
  3. use postgres_native_tls::MakeTlsConnector;
  4. pub use tokio_postgres::Client as DbClient;
  5. pub mod notifications;
  6. pub mod rustc_commits;
  7. const CERT_URL: &str = "https://s3.amazonaws.com/rds-downloads/rds-ca-2019-root.pem";
  8. lazy_static::lazy_static! {
  9. static ref CERTIFICATE_PEM: Vec<u8> = {
  10. let client = reqwest::blocking::Client::new();
  11. let resp = client
  12. .get(CERT_URL)
  13. .send()
  14. .expect("failed to get RDS cert");
  15. resp.bytes().expect("failed to get RDS cert body").to_vec()
  16. };
  17. }
  18. pub async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
  19. let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL");
  20. if db_url.contains("rds.amazonaws.com") {
  21. let cert = &CERTIFICATE_PEM[..];
  22. let cert = Certificate::from_pem(&cert).context("made certificate")?;
  23. let connector = TlsConnector::builder()
  24. .add_root_certificate(cert)
  25. .build()
  26. .context("built TlsConnector")?;
  27. let connector = MakeTlsConnector::new(connector);
  28. let (db_client, connection) = match tokio_postgres::connect(&db_url, connector).await {
  29. Ok(v) => v,
  30. Err(e) => {
  31. anyhow::bail!("failed to connect to DB: {}", e);
  32. }
  33. };
  34. tokio::spawn(async move {
  35. if let Err(e) = connection.await {
  36. eprintln!("database connection error: {}", e);
  37. }
  38. });
  39. Ok(db_client)
  40. } else {
  41. eprintln!("Warning: Non-TLS connection to non-RDS DB");
  42. let (db_client, connection) =
  43. match tokio_postgres::connect(&db_url, tokio_postgres::NoTls).await {
  44. Ok(v) => v,
  45. Err(e) => {
  46. anyhow::bail!("failed to connect to DB: {}", e);
  47. }
  48. };
  49. tokio::spawn(async move {
  50. if let Err(e) = connection.await {
  51. eprintln!("database connection error: {}", e);
  52. }
  53. });
  54. Ok(db_client)
  55. }
  56. }
  57. pub async fn run_migrations(client: &DbClient) -> anyhow::Result<()> {
  58. client
  59. .execute(
  60. "CREATE TABLE IF NOT EXISTS database_versions (
  61. zero INTEGER PRIMARY KEY,
  62. migration_counter INTEGER
  63. );",
  64. &[],
  65. )
  66. .await
  67. .context("creating database versioning table")?;
  68. client
  69. .execute(
  70. "INSERT INTO database_versions (zero, migration_counter)
  71. VALUES (0, 0)
  72. ON CONFLICT DO NOTHING",
  73. &[],
  74. )
  75. .await
  76. .context("inserting initial database_versions")?;
  77. let migration_idx: i32 = client
  78. .query_one("SELECT migration_counter FROM database_versions", &[])
  79. .await
  80. .context("getting migration counter")?
  81. .get(0);
  82. let migration_idx = migration_idx as usize;
  83. for (idx, migration) in MIGRATIONS.iter().enumerate() {
  84. if idx >= migration_idx {
  85. client
  86. .execute(*migration, &[])
  87. .await
  88. .with_context(|| format!("executing {}th migration", idx))?;
  89. client
  90. .execute(
  91. "UPDATE database_versions SET migration_counter = $1",
  92. &[&(idx as i32 + 1)],
  93. )
  94. .await
  95. .with_context(|| format!("updating migration counter to {}", idx))?;
  96. }
  97. }
  98. Ok(())
  99. }
  /// Ordered schema migrations applied by `run_migrations`.
  ///
  /// Progress is tracked only by position: `database_versions.migration_counter` records how
  /// many leading entries have run. New migrations must therefore be appended. Existing entries
  /// must never be reordered, removed, or edited, or databases that already applied them will
  /// drift from the code.
  100. static MIGRATIONS: &[&str] = &[
  // 0: create the `notifications` table.
  101. "
  102. CREATE TABLE notifications (
  103. notification_id BIGSERIAL PRIMARY KEY,
  104. user_id BIGINT,
  105. origin_url TEXT NOT NULL,
  106. origin_html TEXT,
  107. time TIMESTAMP WITH TIME ZONE
  108. );
  109. ",
  // 1: create the `users` table.
  110. "
  111. CREATE TABLE users (
  112. user_id BIGINT PRIMARY KEY,
  113. username TEXT NOT NULL
  114. );
  115. ",
  // 2-5: add optional columns to `notifications`.
  116. "ALTER TABLE notifications ADD COLUMN short_description TEXT;",
  117. "ALTER TABLE notifications ADD COLUMN team_name TEXT;",
  118. "ALTER TABLE notifications ADD COLUMN idx INTEGER;",
  119. "ALTER TABLE notifications ADD COLUMN metadata TEXT;",
  // 6: create the `rustc_commits` table.
  120. "
  121. CREATE TABLE rustc_commits (
  122. sha TEXT PRIMARY KEY,
  123. parent_sha TEXT NOT NULL,
  124. time TIMESTAMP WITH TIME ZONE
  125. );
  126. ",
  127. ];