db.rs 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. use anyhow::Context as _;
  2. use native_tls::{Certificate, TlsConnector};
  3. use postgres_native_tls::MakeTlsConnector;
  4. use std::sync::{Arc, Mutex};
  5. use tokio::sync::{OwnedSemaphorePermit, Semaphore};
  6. use tokio_postgres::Client as DbClient;
  7. use crate::db::jobs::*;
  8. use crate::handlers::jobs::handle_job;
  9. pub mod jobs;
  10. pub mod issue_data;
  11. pub mod notifications;
  12. pub mod rustc_commits;
  13. const CERT_URL: &str = "https://s3.amazonaws.com/rds-downloads/rds-ca-2019-root.pem";
  14. lazy_static::lazy_static! {
  15. static ref CERTIFICATE_PEM: Vec<u8> = {
  16. let client = reqwest::blocking::Client::new();
  17. let resp = client
  18. .get(CERT_URL)
  19. .send()
  20. .expect("failed to get RDS cert");
  21. resp.bytes().expect("failed to get RDS cert body").to_vec()
  22. };
  23. }
/// A bounded pool of reusable database clients.
///
/// Idle clients wait in `connections`; `permits` limits how many clients can be checked out
/// through `ClientPool::get` at the same time.
pub struct ClientPool {
    // Idle clients available for reuse. The same list is shared with every `PooledClient`,
    // which pushes its client back here when dropped.
    connections: Arc<Mutex<Vec<tokio_postgres::Client>>>,
    // One permit is held by each checked-out `PooledClient`.
    permits: Arc<Semaphore>,
}
/// A database client checked out of a `ClientPool`.
///
/// Dereferences to the underlying `tokio_postgres::Client`. Dropping it returns the client to
/// the pool's idle list and then releases its concurrency permit.
pub struct PooledClient {
    // Set to `Some` when constructed by `ClientPool::get`; the `Drop` impl takes it out to hand
    // the client back to the pool.
    client: Option<tokio_postgres::Client>,
    // Never read; held only so the permit is released when this value is dropped. Fields are
    // dropped after `Drop::drop` runs, so the client is already back in the pool by then.
    #[allow(unused)] // only used for drop impl
    permit: OwnedSemaphorePermit,
    // The idle list of the pool this client was checked out from.
    pool: Arc<Mutex<Vec<tokio_postgres::Client>>>,
}
  34. impl Drop for PooledClient {
  35. fn drop(&mut self) {
  36. let mut clients = self.pool.lock().unwrap_or_else(|e| e.into_inner());
  37. clients.push(self.client.take().unwrap());
  38. }
  39. }
  40. impl std::ops::Deref for PooledClient {
  41. type Target = tokio_postgres::Client;
  42. fn deref(&self) -> &Self::Target {
  43. self.client.as_ref().unwrap()
  44. }
  45. }
  46. impl std::ops::DerefMut for PooledClient {
  47. fn deref_mut(&mut self) -> &mut Self::Target {
  48. self.client.as_mut().unwrap()
  49. }
  50. }
  51. impl ClientPool {
  52. pub fn new() -> ClientPool {
  53. ClientPool {
  54. connections: Arc::new(Mutex::new(Vec::with_capacity(16))),
  55. permits: Arc::new(Semaphore::new(16)),
  56. }
  57. }
  58. pub async fn get(&self) -> PooledClient {
  59. let permit = self.permits.clone().acquire_owned().await.unwrap();
  60. {
  61. let mut slots = self.connections.lock().unwrap_or_else(|e| e.into_inner());
  62. // Pop connections until we hit a non-closed connection (or there are no
  63. // "possibly open" connections left).
  64. while let Some(c) = slots.pop() {
  65. if !c.is_closed() {
  66. return PooledClient {
  67. client: Some(c),
  68. permit,
  69. pool: self.connections.clone(),
  70. };
  71. }
  72. }
  73. }
  74. PooledClient {
  75. client: Some(make_client().await.unwrap()),
  76. permit,
  77. pool: self.connections.clone(),
  78. }
  79. }
  80. }
  81. async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
  82. let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL");
  83. if db_url.contains("rds.amazonaws.com") {
  84. let cert = &CERTIFICATE_PEM[..];
  85. let cert = Certificate::from_pem(&cert).context("made certificate")?;
  86. let connector = TlsConnector::builder()
  87. .add_root_certificate(cert)
  88. .build()
  89. .context("built TlsConnector")?;
  90. let connector = MakeTlsConnector::new(connector);
  91. let (db_client, connection) = match tokio_postgres::connect(&db_url, connector).await {
  92. Ok(v) => v,
  93. Err(e) => {
  94. anyhow::bail!("failed to connect to DB: {}", e);
  95. }
  96. };
  97. tokio::task::spawn(async move {
  98. if let Err(e) = connection.await {
  99. eprintln!("database connection error: {}", e);
  100. }
  101. });
  102. Ok(db_client)
  103. } else {
  104. eprintln!("Warning: Non-TLS connection to non-RDS DB");
  105. let (db_client, connection) =
  106. match tokio_postgres::connect(&db_url, tokio_postgres::NoTls).await {
  107. Ok(v) => v,
  108. Err(e) => {
  109. anyhow::bail!("failed to connect to DB: {}", e);
  110. }
  111. };
  112. tokio::spawn(async move {
  113. if let Err(e) = connection.await {
  114. eprintln!("database connection error: {}", e);
  115. }
  116. });
  117. Ok(db_client)
  118. }
  119. }
  120. pub async fn run_migrations(client: &DbClient) -> anyhow::Result<()> {
  121. client
  122. .execute(
  123. "CREATE TABLE IF NOT EXISTS database_versions (
  124. zero INTEGER PRIMARY KEY,
  125. migration_counter INTEGER
  126. );",
  127. &[],
  128. )
  129. .await
  130. .context("creating database versioning table")?;
  131. client
  132. .execute(
  133. "INSERT INTO database_versions (zero, migration_counter)
  134. VALUES (0, 0)
  135. ON CONFLICT DO NOTHING",
  136. &[],
  137. )
  138. .await
  139. .context("inserting initial database_versions")?;
  140. let migration_idx: i32 = client
  141. .query_one("SELECT migration_counter FROM database_versions", &[])
  142. .await
  143. .context("getting migration counter")?
  144. .get(0);
  145. let migration_idx = migration_idx as usize;
  146. for (idx, migration) in MIGRATIONS.iter().enumerate() {
  147. if idx >= migration_idx {
  148. client
  149. .execute(*migration, &[])
  150. .await
  151. .with_context(|| format!("executing {}th migration", idx))?;
  152. client
  153. .execute(
  154. "UPDATE database_versions SET migration_counter = $1",
  155. &[&(idx as i32 + 1)],
  156. )
  157. .await
  158. .with_context(|| format!("updating migration counter to {}", idx))?;
  159. }
  160. }
  161. Ok(())
  162. }
  163. pub async fn run_scheduled_jobs(db: &DbClient) -> anyhow::Result<()> {
  164. let jobs = get_jobs_to_execute(&db).await.unwrap();
  165. println!("jobs to execute: {:#?}", jobs);
  166. tracing::trace!("jobs to execute: {:#?}", jobs);
  167. for job in jobs.iter() {
  168. update_job_executed_at(&db, &job.id).await?;
  169. match handle_job(&job.name, &job.metadata).await {
  170. Ok(_) => {
  171. println!("job succesfully executed (id={})", job.id);
  172. tracing::trace!("job succesfully executed (id={})", job.id);
  173. delete_job(&db, &job.id).await?;
  174. },
  175. Err(e) => {
  176. println!("job failed on execution (id={:?}, error={:?})", job.id, e);
  177. tracing::trace!("job failed on execution (id={:?}, error={:?})", job.id, e);
  178. update_job_error_message(&db, &job.id, &e.to_string()).await?;
  179. },
  180. }
  181. }
  182. Ok(())
  183. }
/// Schema migrations, applied in order by `run_migrations`.
///
/// `database_versions.migration_counter` records how many entries of this list have been
/// applied, so each migration is identified only by its position. New migrations must be
/// appended at the end; existing entries must never be edited, reordered, or removed.
static MIGRATIONS: &[&str] = &[
    "
CREATE TABLE notifications (
notification_id BIGSERIAL PRIMARY KEY,
user_id BIGINT,
origin_url TEXT NOT NULL,
origin_html TEXT,
time TIMESTAMP WITH TIME ZONE
);
",
    "
CREATE TABLE users (
user_id BIGINT PRIMARY KEY,
username TEXT NOT NULL
);
",
    "ALTER TABLE notifications ADD COLUMN short_description TEXT;",
    "ALTER TABLE notifications ADD COLUMN team_name TEXT;",
    "ALTER TABLE notifications ADD COLUMN idx INTEGER;",
    "ALTER TABLE notifications ADD COLUMN metadata TEXT;",
    "
CREATE TABLE rustc_commits (
sha TEXT PRIMARY KEY,
parent_sha TEXT NOT NULL,
time TIMESTAMP WITH TIME ZONE
);
",
    "ALTER TABLE rustc_commits ADD COLUMN pr INTEGER;",
    "
CREATE TABLE issue_data (
repo TEXT,
issue_number INTEGER,
key TEXT,
data JSONB,
PRIMARY KEY (repo, issue_number, key)
);
",
    // NOTE(review): `gen_random_uuid()` is built into PostgreSQL 13+; older servers need the
    // `pgcrypto` extension — confirm the deployed version.
    "
CREATE TABLE jobs (
id UUID DEFAULT gen_random_uuid() PRIMARY KEY,
name TEXT NOT NULL,
expected_time TIMESTAMP WITH TIME ZONE NOT NULL,
metadata JSONB,
executed_at TIMESTAMP WITH TIME ZONE,
error_message TEXT
);
",
    // At most one job row per (name, expected_time) pair.
    "
CREATE UNIQUE INDEX jobs_name_expected_time_unique_index
ON jobs (
name, expected_time
);
"
];