db.rs 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. use crate::handlers::jobs::handle_job;
  2. use crate::{db::jobs::*, handlers::Context};
  3. use anyhow::Context as _;
  4. use chrono::Utc;
  5. use native_tls::{Certificate, TlsConnector};
  6. use postgres_native_tls::MakeTlsConnector;
  7. use std::sync::{Arc, Mutex};
  8. use tokio::sync::{OwnedSemaphorePermit, Semaphore};
  9. use tokio_postgres::Client as DbClient;
  10. pub mod issue_data;
  11. pub mod jobs;
  12. pub mod notifications;
  13. pub mod rustc_commits;
/// AWS-published root certificate used to verify TLS connections to RDS.
// NOTE(review): this is the 2019 root CA, which has a fixed validity
// window — confirm it is still valid / whether the newer RDS CA bundle
// should be used instead.
const CERT_URL: &str = "https://s3.amazonaws.com/rds-downloads/rds-ca-2019-root.pem";
  15. lazy_static::lazy_static! {
  16. static ref CERTIFICATE_PEM: Vec<u8> = {
  17. let client = reqwest::blocking::Client::new();
  18. let resp = client
  19. .get(CERT_URL)
  20. .send()
  21. .expect("failed to get RDS cert");
  22. resp.bytes().expect("failed to get RDS cert body").to_vec()
  23. };
  24. }
/// A fixed-size pool of Postgres clients.
///
/// Idle clients live in `connections`; `permits` bounds how many clients
/// may be checked out at once (capacity is set in `ClientPool::new`).
pub struct ClientPool {
    // Idle connections available for reuse. Checked-out clients are
    // returned here by `PooledClient`'s `Drop` impl.
    connections: Arc<Mutex<Vec<tokio_postgres::Client>>>,
    // One permit per concurrent checkout; acquired in `get`, released
    // when the `PooledClient`'s `OwnedSemaphorePermit` is dropped.
    permits: Arc<Semaphore>,
}
/// A Postgres client checked out of a `ClientPool`.
///
/// Dereferences to `tokio_postgres::Client`. On drop, the client is
/// handed back to the pool and the semaphore permit is released.
pub struct PooledClient {
    // Invariant: always `Some` while the guard is alive; only taken in
    // the `Drop` impl, which is why `Deref` can unwrap it.
    client: Option<tokio_postgres::Client>,
    #[allow(unused)] // only used for drop impl
    permit: OwnedSemaphorePermit,
    // Shared handle to the pool's idle list, so `Drop` can return the
    // client without holding a reference to the `ClientPool` itself.
    pool: Arc<Mutex<Vec<tokio_postgres::Client>>>,
}
  35. impl Drop for PooledClient {
  36. fn drop(&mut self) {
  37. let mut clients = self.pool.lock().unwrap_or_else(|e| e.into_inner());
  38. clients.push(self.client.take().unwrap());
  39. }
  40. }
  41. impl std::ops::Deref for PooledClient {
  42. type Target = tokio_postgres::Client;
  43. fn deref(&self) -> &Self::Target {
  44. self.client.as_ref().unwrap()
  45. }
  46. }
  47. impl std::ops::DerefMut for PooledClient {
  48. fn deref_mut(&mut self) -> &mut Self::Target {
  49. self.client.as_mut().unwrap()
  50. }
  51. }
  52. impl ClientPool {
  53. pub fn new() -> ClientPool {
  54. ClientPool {
  55. connections: Arc::new(Mutex::new(Vec::with_capacity(16))),
  56. permits: Arc::new(Semaphore::new(16)),
  57. }
  58. }
  59. pub async fn get(&self) -> PooledClient {
  60. let permit = self.permits.clone().acquire_owned().await.unwrap();
  61. {
  62. let mut slots = self.connections.lock().unwrap_or_else(|e| e.into_inner());
  63. // Pop connections until we hit a non-closed connection (or there are no
  64. // "possibly open" connections left).
  65. while let Some(c) = slots.pop() {
  66. if !c.is_closed() {
  67. return PooledClient {
  68. client: Some(c),
  69. permit,
  70. pool: self.connections.clone(),
  71. };
  72. }
  73. }
  74. }
  75. PooledClient {
  76. client: Some(make_client().await.unwrap()),
  77. permit,
  78. pool: self.connections.clone(),
  79. }
  80. }
  81. }
  82. async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
  83. let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL");
  84. if db_url.contains("rds.amazonaws.com") {
  85. let cert = &CERTIFICATE_PEM[..];
  86. let cert = Certificate::from_pem(&cert).context("made certificate")?;
  87. let connector = TlsConnector::builder()
  88. .add_root_certificate(cert)
  89. .build()
  90. .context("built TlsConnector")?;
  91. let connector = MakeTlsConnector::new(connector);
  92. let (db_client, connection) = match tokio_postgres::connect(&db_url, connector).await {
  93. Ok(v) => v,
  94. Err(e) => {
  95. anyhow::bail!("failed to connect to DB: {}", e);
  96. }
  97. };
  98. tokio::task::spawn(async move {
  99. if let Err(e) = connection.await {
  100. eprintln!("database connection error: {}", e);
  101. }
  102. });
  103. Ok(db_client)
  104. } else {
  105. eprintln!("Warning: Non-TLS connection to non-RDS DB");
  106. let (db_client, connection) =
  107. match tokio_postgres::connect(&db_url, tokio_postgres::NoTls).await {
  108. Ok(v) => v,
  109. Err(e) => {
  110. anyhow::bail!("failed to connect to DB: {}", e);
  111. }
  112. };
  113. tokio::spawn(async move {
  114. if let Err(e) = connection.await {
  115. eprintln!("database connection error: {}", e);
  116. }
  117. });
  118. Ok(db_client)
  119. }
  120. }
  121. pub async fn run_migrations(client: &DbClient) -> anyhow::Result<()> {
  122. client
  123. .execute(
  124. "CREATE TABLE IF NOT EXISTS database_versions (
  125. zero INTEGER PRIMARY KEY,
  126. migration_counter INTEGER
  127. );",
  128. &[],
  129. )
  130. .await
  131. .context("creating database versioning table")?;
  132. client
  133. .execute(
  134. "INSERT INTO database_versions (zero, migration_counter)
  135. VALUES (0, 0)
  136. ON CONFLICT DO NOTHING",
  137. &[],
  138. )
  139. .await
  140. .context("inserting initial database_versions")?;
  141. let migration_idx: i32 = client
  142. .query_one("SELECT migration_counter FROM database_versions", &[])
  143. .await
  144. .context("getting migration counter")?
  145. .get(0);
  146. let migration_idx = migration_idx as usize;
  147. for (idx, migration) in MIGRATIONS.iter().enumerate() {
  148. if idx >= migration_idx {
  149. client
  150. .execute(*migration, &[])
  151. .await
  152. .with_context(|| format!("executing {}th migration", idx))?;
  153. client
  154. .execute(
  155. "UPDATE database_versions SET migration_counter = $1",
  156. &[&(idx as i32 + 1)],
  157. )
  158. .await
  159. .with_context(|| format!("updating migration counter to {}", idx))?;
  160. }
  161. }
  162. Ok(())
  163. }
  164. pub async fn schedule_jobs(db: &DbClient, jobs: Vec<JobSchedule>) -> anyhow::Result<()> {
  165. for job in jobs {
  166. let mut upcoming = job.schedule.upcoming(Utc).take(1);
  167. if let Some(scheduled_at) = upcoming.next() {
  168. if let Err(_) = get_job_by_name_and_scheduled_at(&db, &job.name, &scheduled_at).await {
  169. // mean there's no job already in the db with that name and scheduled_at
  170. insert_job(&db, &job.name, &scheduled_at, &job.metadata).await?;
  171. }
  172. }
  173. }
  174. Ok(())
  175. }
  176. pub async fn run_scheduled_jobs(ctx: &Context, db: &DbClient) -> anyhow::Result<()> {
  177. let jobs = get_jobs_to_execute(&db).await.unwrap();
  178. tracing::trace!("jobs to execute: {:#?}", jobs);
  179. for job in jobs.iter() {
  180. update_job_executed_at(&db, &job.id).await?;
  181. match handle_job(&ctx, &job.name, &job.metadata).await {
  182. Ok(_) => {
  183. tracing::trace!("job successfully executed (id={})", job.id);
  184. delete_job(&db, &job.id).await?;
  185. }
  186. Err(e) => {
  187. tracing::error!("job failed on execution (id={:?}, error={:?})", job.id, e);
  188. update_job_error_message(&db, &job.id, &e.to_string()).await?;
  189. }
  190. }
  191. }
  192. Ok(())
  193. }
// Database schema migrations, applied in order by `run_migrations`.
//
// APPEND ONLY: `database_versions.migration_counter` records how many of
// these have been applied, so existing entries must never be edited,
// removed, or reordered — add new statements at the end.
static MIGRATIONS: &[&str] = &[
    "
CREATE TABLE notifications (
notification_id BIGSERIAL PRIMARY KEY,
user_id BIGINT,
origin_url TEXT NOT NULL,
origin_html TEXT,
time TIMESTAMP WITH TIME ZONE
);
",
    "
CREATE TABLE users (
user_id BIGINT PRIMARY KEY,
username TEXT NOT NULL
);
",
    "ALTER TABLE notifications ADD COLUMN short_description TEXT;",
    "ALTER TABLE notifications ADD COLUMN team_name TEXT;",
    "ALTER TABLE notifications ADD COLUMN idx INTEGER;",
    "ALTER TABLE notifications ADD COLUMN metadata TEXT;",
    "
CREATE TABLE rustc_commits (
sha TEXT PRIMARY KEY,
parent_sha TEXT NOT NULL,
time TIMESTAMP WITH TIME ZONE
);
",
    "ALTER TABLE rustc_commits ADD COLUMN pr INTEGER;",
    "
CREATE TABLE issue_data (
repo TEXT,
issue_number INTEGER,
key TEXT,
data JSONB,
PRIMARY KEY (repo, issue_number, key)
);
",
    "
CREATE TABLE jobs (
id UUID DEFAULT gen_random_uuid() PRIMARY KEY,
name TEXT NOT NULL,
scheduled_at TIMESTAMP WITH TIME ZONE NOT NULL,
metadata JSONB,
executed_at TIMESTAMP WITH TIME ZONE,
error_message TEXT
);
",
    // Backs `schedule_jobs`: at most one pending row per (name, scheduled_at).
    "
CREATE UNIQUE INDEX jobs_name_scheduled_at_unique_index
ON jobs (
name, scheduled_at
);
",
];