db.rs 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. use crate::{db::jobs::*, handlers::Context, jobs::jobs};
  2. use anyhow::Context as _;
  3. use chrono::Utc;
  4. use native_tls::{Certificate, TlsConnector};
  5. use postgres_native_tls::MakeTlsConnector;
  6. use std::sync::{Arc, Mutex};
  7. use tokio::sync::{OwnedSemaphorePermit, Semaphore};
  8. use tokio_postgres::Client as DbClient;
  9. pub mod issue_data;
  10. pub mod jobs;
  11. pub mod notifications;
  12. pub mod rustc_commits;
/// URL of the AWS RDS global certificate bundle used to validate TLS
/// connections to RDS-hosted databases.
const CERT_URL: &str = "https://truststore.pki.rds.amazonaws.com/global/global-bundle.pem";

lazy_static::lazy_static! {
    /// Raw PEM bytes of the RDS certificate bundle, downloaded once on first
    /// dereference and cached for the life of the process.
    ///
    /// NOTE(review): this uses a *blocking* reqwest client and panics on any
    /// download failure — presumably only touched during connection setup;
    /// confirm it is never first dereferenced on an async worker thread.
    static ref CERTIFICATE_PEMS: Vec<u8> = {
        let client = reqwest::blocking::Client::new();
        let resp = client
            .get(CERT_URL)
            .send()
            .expect("failed to get RDS cert");
        resp.bytes().expect("failed to get RDS cert body").to_vec()
    };
}
/// A fixed-size pool of Postgres clients.
///
/// At most 16 clients may be checked out at once (enforced by `permits`);
/// idle clients are cached in `connections` and reused by [`ClientPool::get`].
pub struct ClientPool {
    // Idle, previously-created clients available for reuse.
    connections: Arc<Mutex<Vec<tokio_postgres::Client>>>,
    // Bounds the number of simultaneously checked-out clients.
    permits: Arc<Semaphore>,
}
/// A Postgres client checked out from a [`ClientPool`].
///
/// Derefs to `tokio_postgres::Client`. On drop, the client is pushed back
/// into the pool and the semaphore permit is released.
pub struct PooledClient {
    // Always `Some` until `Drop::drop`, which moves the client back to the pool.
    client: Option<tokio_postgres::Client>,
    #[allow(unused)] // only used for drop impl
    permit: OwnedSemaphorePermit,
    // Shared handle to the pool's idle-connection list.
    pool: Arc<Mutex<Vec<tokio_postgres::Client>>>,
}
impl Drop for PooledClient {
    fn drop(&mut self) {
        // Recover from a poisoned lock: the Vec is always left in a valid
        // state, so it is safe to keep using it even if another thread
        // panicked while holding the lock.
        let mut clients = self.pool.lock().unwrap_or_else(|e| e.into_inner());
        // `client` is only ever `None` after this `take`, so the unwrap
        // cannot fail. The client is returned even if its connection has
        // closed; `ClientPool::get` filters out closed connections on reuse.
        clients.push(self.client.take().unwrap());
    }
}
/// Lets a `PooledClient` be used anywhere a `tokio_postgres::Client` is
/// expected.
impl std::ops::Deref for PooledClient {
    type Target = tokio_postgres::Client;
    fn deref(&self) -> &Self::Target {
        // `client` is `Some` for the entire life of the value (it is only
        // taken in `drop`), so the unwrap cannot fail.
        self.client.as_ref().unwrap()
    }
}

impl std::ops::DerefMut for PooledClient {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // Same invariant as `deref` above.
        self.client.as_mut().unwrap()
    }
}
  51. impl ClientPool {
  52. pub fn new() -> ClientPool {
  53. ClientPool {
  54. connections: Arc::new(Mutex::new(Vec::with_capacity(16))),
  55. permits: Arc::new(Semaphore::new(16)),
  56. }
  57. }
  58. pub async fn get(&self) -> PooledClient {
  59. let permit = self.permits.clone().acquire_owned().await.unwrap();
  60. {
  61. let mut slots = self.connections.lock().unwrap_or_else(|e| e.into_inner());
  62. // Pop connections until we hit a non-closed connection (or there are no
  63. // "possibly open" connections left).
  64. while let Some(c) = slots.pop() {
  65. if !c.is_closed() {
  66. return PooledClient {
  67. client: Some(c),
  68. permit,
  69. pool: self.connections.clone(),
  70. };
  71. }
  72. }
  73. }
  74. PooledClient {
  75. client: Some(make_client().await.unwrap()),
  76. permit,
  77. pool: self.connections.clone(),
  78. }
  79. }
  80. }
  81. async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
  82. let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL");
  83. if db_url.contains("rds.amazonaws.com") {
  84. let mut builder = TlsConnector::builder();
  85. for cert in make_certificates() {
  86. builder.add_root_certificate(cert);
  87. }
  88. let connector = builder.build().context("built TlsConnector")?;
  89. let connector = MakeTlsConnector::new(connector);
  90. let (db_client, connection) = match tokio_postgres::connect(&db_url, connector).await {
  91. Ok(v) => v,
  92. Err(e) => {
  93. anyhow::bail!("failed to connect to DB: {}", e);
  94. }
  95. };
  96. tokio::task::spawn(async move {
  97. if let Err(e) = connection.await {
  98. eprintln!("database connection error: {}", e);
  99. }
  100. });
  101. Ok(db_client)
  102. } else {
  103. eprintln!("Warning: Non-TLS connection to non-RDS DB");
  104. let (db_client, connection) =
  105. match tokio_postgres::connect(&db_url, tokio_postgres::NoTls).await {
  106. Ok(v) => v,
  107. Err(e) => {
  108. anyhow::bail!("failed to connect to DB: {}", e);
  109. }
  110. };
  111. tokio::spawn(async move {
  112. if let Err(e) = connection.await {
  113. eprintln!("database connection error: {}", e);
  114. }
  115. });
  116. Ok(db_client)
  117. }
  118. }
  119. fn make_certificates() -> Vec<Certificate> {
  120. use x509_cert::der::pem::LineEnding;
  121. use x509_cert::der::EncodePem;
  122. let certs = x509_cert::Certificate::load_pem_chain(&CERTIFICATE_PEMS[..]).unwrap();
  123. certs
  124. .into_iter()
  125. .map(|cert| Certificate::from_pem(cert.to_pem(LineEnding::LF).unwrap().as_bytes()).unwrap())
  126. .collect()
  127. }
// Makes sure we successfully parse the RDS certificates and load them into native-tls compatible
// format.
#[test]
fn cert() {
    // Downloads the live bundle and panics (failing the test) if it cannot
    // be fetched, parsed, or converted.
    make_certificates();
}
  134. pub async fn run_migrations(client: &DbClient) -> anyhow::Result<()> {
  135. client
  136. .execute(
  137. "CREATE TABLE IF NOT EXISTS database_versions (
  138. zero INTEGER PRIMARY KEY,
  139. migration_counter INTEGER
  140. );",
  141. &[],
  142. )
  143. .await
  144. .context("creating database versioning table")?;
  145. client
  146. .execute(
  147. "INSERT INTO database_versions (zero, migration_counter)
  148. VALUES (0, 0)
  149. ON CONFLICT DO NOTHING",
  150. &[],
  151. )
  152. .await
  153. .context("inserting initial database_versions")?;
  154. let migration_idx: i32 = client
  155. .query_one("SELECT migration_counter FROM database_versions", &[])
  156. .await
  157. .context("getting migration counter")?
  158. .get(0);
  159. let migration_idx = migration_idx as usize;
  160. for (idx, migration) in MIGRATIONS.iter().enumerate() {
  161. if idx >= migration_idx {
  162. client
  163. .execute(*migration, &[])
  164. .await
  165. .with_context(|| format!("executing {}th migration", idx))?;
  166. client
  167. .execute(
  168. "UPDATE database_versions SET migration_counter = $1",
  169. &[&(idx as i32 + 1)],
  170. )
  171. .await
  172. .with_context(|| format!("updating migration counter to {}", idx))?;
  173. }
  174. }
  175. Ok(())
  176. }
  177. pub async fn schedule_jobs(db: &DbClient, jobs: Vec<JobSchedule>) -> anyhow::Result<()> {
  178. for job in jobs {
  179. let mut upcoming = job.schedule.upcoming(Utc).take(1);
  180. if let Some(scheduled_at) = upcoming.next() {
  181. schedule_job(db, job.name, job.metadata, scheduled_at).await?;
  182. }
  183. }
  184. Ok(())
  185. }
  186. pub async fn schedule_job(
  187. db: &DbClient,
  188. job_name: &str,
  189. job_metadata: serde_json::Value,
  190. when: chrono::DateTime<Utc>,
  191. ) -> anyhow::Result<()> {
  192. let all_jobs = jobs();
  193. if !all_jobs.iter().any(|j| j.name() == job_name) {
  194. anyhow::bail!("Job {} does not exist in the current job list.", job_name);
  195. }
  196. if let Err(_) = get_job_by_name_and_scheduled_at(&db, job_name, &when).await {
  197. // mean there's no job already in the db with that name and scheduled_at
  198. insert_job(&db, job_name, &when, &job_metadata).await?;
  199. }
  200. Ok(())
  201. }
  202. pub async fn run_scheduled_jobs(ctx: &Context, db: &DbClient) -> anyhow::Result<()> {
  203. let jobs = get_jobs_to_execute(&db).await.unwrap();
  204. tracing::trace!("jobs to execute: {:#?}", jobs);
  205. for job in jobs.iter() {
  206. update_job_executed_at(&db, &job.id).await?;
  207. match handle_job(&ctx, &job.name, &job.metadata).await {
  208. Ok(_) => {
  209. tracing::trace!("job successfully executed (id={})", job.id);
  210. delete_job(&db, &job.id).await?;
  211. }
  212. Err(e) => {
  213. tracing::error!("job failed on execution (id={:?}, error={:?})", job.id, e);
  214. update_job_error_message(&db, &job.id, &e.to_string()).await?;
  215. }
  216. }
  217. }
  218. Ok(())
  219. }
  220. // Try to handle a specific job
  221. async fn handle_job(
  222. ctx: &Context,
  223. name: &String,
  224. metadata: &serde_json::Value,
  225. ) -> anyhow::Result<()> {
  226. for job in jobs() {
  227. if &job.name() == &name {
  228. return job.run(ctx, metadata).await;
  229. }
  230. }
  231. tracing::trace!(
  232. "handle_job fell into default case: (name={:?}, metadata={:?})",
  233. name,
  234. metadata
  235. );
  236. Ok(())
  237. }
/// Database schema migrations, applied in order by [`run_migrations`].
///
/// IMPORTANT: this list is append-only. Never reorder, edit, or remove
/// entries — `database_versions.migration_counter` records how many have
/// already run, so changing earlier entries would silently desynchronize
/// existing databases. Add new migrations at the end.
static MIGRATIONS: &[&str] = &[
    // 0: base notifications table.
    "
CREATE TABLE notifications (
notification_id BIGSERIAL PRIMARY KEY,
user_id BIGINT,
origin_url TEXT NOT NULL,
origin_html TEXT,
time TIMESTAMP WITH TIME ZONE
);
",
    // 1: user id -> username mapping.
    "
CREATE TABLE users (
user_id BIGINT PRIMARY KEY,
username TEXT NOT NULL
);
",
    // 2-5: columns added to notifications after initial deployment.
    "ALTER TABLE notifications ADD COLUMN short_description TEXT;",
    "ALTER TABLE notifications ADD COLUMN team_name TEXT;",
    "ALTER TABLE notifications ADD COLUMN idx INTEGER;",
    "ALTER TABLE notifications ADD COLUMN metadata TEXT;",
    // 6: rustc commit graph cache.
    "
CREATE TABLE rustc_commits (
sha TEXT PRIMARY KEY,
parent_sha TEXT NOT NULL,
time TIMESTAMP WITH TIME ZONE
);
",
    // 7: associate commits with the PR that merged them.
    "ALTER TABLE rustc_commits ADD COLUMN pr INTEGER;",
    // 8: arbitrary per-issue JSON state keyed by (repo, issue, key).
    "
CREATE TABLE issue_data (
repo TEXT,
issue_number INTEGER,
key TEXT,
data JSONB,
PRIMARY KEY (repo, issue_number, key)
);
",
    // 9: scheduled-jobs queue consumed by run_scheduled_jobs.
    "
CREATE TABLE jobs (
id UUID DEFAULT gen_random_uuid() PRIMARY KEY,
name TEXT NOT NULL,
scheduled_at TIMESTAMP WITH TIME ZONE NOT NULL,
metadata JSONB,
executed_at TIMESTAMP WITH TIME ZONE,
error_message TEXT
);
",
    // 10: at most one pending job per (name, scheduled_at) pair.
    "
CREATE UNIQUE INDEX jobs_name_scheduled_at_unique_index
ON jobs (
name, scheduled_at
);
",
    // 11: per-user review preferences.
    "
CREATE table review_prefs (
id UUID DEFAULT gen_random_uuid() PRIMARY KEY,
user_id BIGINT REFERENCES users(user_id),
assigned_prs INT[] NOT NULL DEFAULT array[]::INT[]
);",
    // 12: intarray extension for assigned_prs ops; one prefs row per user.
    "
CREATE EXTENSION intarray;
CREATE UNIQUE INDEX review_prefs_user_id ON review_prefs(user_id);
",
];