github.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604
  1. use anyhow::Context;
  2. use chrono::{FixedOffset, Utc};
  3. use futures::stream::{FuturesUnordered, StreamExt};
  4. use once_cell::sync::OnceCell;
  5. use reqwest::header::{AUTHORIZATION, USER_AGENT};
  6. use reqwest::{Client, RequestBuilder, Response, StatusCode};
  7. use std::fmt;
  8. #[derive(Debug, PartialEq, Eq, serde::Deserialize)]
  9. pub struct User {
  10. pub login: String,
  11. pub id: Option<i64>,
  12. }
  13. impl GithubClient {
  14. async fn _send_req(&self, req: RequestBuilder) -> Result<(Response, String), reqwest::Error> {
  15. log::debug!("_send_req with {:?}", req);
  16. let req = req.build()?;
  17. let req_dbg = format!("{:?}", req);
  18. let resp = self.client.execute(req).await?;
  19. resp.error_for_status_ref()?;
  20. Ok((resp, req_dbg))
  21. }
  22. async fn send_req(&self, req: RequestBuilder) -> anyhow::Result<Vec<u8>> {
  23. let (mut resp, req_dbg) = self._send_req(req).await?;
  24. let mut body = Vec::new();
  25. while let Some(chunk) = resp.chunk().await.transpose() {
  26. let chunk = chunk
  27. .context("reading stream failed")
  28. .map_err(anyhow::Error::from)
  29. .context(req_dbg.clone())?;
  30. body.extend_from_slice(&chunk);
  31. }
  32. Ok(body)
  33. }
  34. pub async fn json<T>(&self, req: RequestBuilder) -> anyhow::Result<T>
  35. where
  36. T: serde::de::DeserializeOwned,
  37. {
  38. let (resp, req_dbg) = self._send_req(req).await?;
  39. Ok(resp.json().await.context(req_dbg)?)
  40. }
  41. }
  42. impl User {
  43. pub async fn current(client: &GithubClient) -> anyhow::Result<Self> {
  44. client.json(client.get("https://api.github.com/user")).await
  45. }
  46. pub async fn is_team_member<'a>(&'a self, client: &'a GithubClient) -> anyhow::Result<bool> {
  47. let url = format!("{}/teams.json", rust_team_data::v1::BASE_URL);
  48. let permission: rust_team_data::v1::Teams = client
  49. .json(client.raw().get(&url))
  50. .await
  51. .context("could not get team data")?;
  52. let map = permission.teams;
  53. let is_triager = map
  54. .get("wg-triage")
  55. .map_or(false, |w| w.members.iter().any(|g| g.github == self.login));
  56. Ok(map["all"].members.iter().any(|g| g.github == self.login) || is_triager)
  57. }
  58. // Returns the ID of the given user, if the user is in the `all` team.
  59. pub async fn get_id<'a>(&'a self, client: &'a GithubClient) -> anyhow::Result<Option<usize>> {
  60. let url = format!("{}/teams.json", rust_team_data::v1::BASE_URL);
  61. let permission: rust_team_data::v1::Teams = client
  62. .json(client.raw().get(&url))
  63. .await
  64. .context("could not get team data")?;
  65. let map = permission.teams;
  66. Ok(map["all"]
  67. .members
  68. .iter()
  69. .find(|g| g.github == self.login)
  70. .map(|u| u.github_id))
  71. }
  72. }
  73. pub async fn get_team(
  74. client: &GithubClient,
  75. team: &str,
  76. ) -> anyhow::Result<Option<rust_team_data::v1::Team>> {
  77. let url = format!("{}/teams.json", rust_team_data::v1::BASE_URL);
  78. let permission: rust_team_data::v1::Teams = client
  79. .json(client.raw().get(&url))
  80. .await
  81. .context("could not get team data")?;
  82. let mut map = permission.teams;
  83. Ok(map.swap_remove(team))
  84. }
  85. #[derive(PartialEq, Eq, Debug, Clone, serde::Deserialize)]
  86. pub struct Label {
  87. pub name: String,
  88. }
  89. impl Label {
  90. async fn exists<'a>(&'a self, repo_api_prefix: &'a str, client: &'a GithubClient) -> bool {
  91. #[allow(clippy::redundant_pattern_matching)]
  92. let url = format!("{}/labels/{}", repo_api_prefix, self.name);
  93. match client.send_req(client.get(&url)).await {
  94. Ok(_) => true,
  95. // XXX: Error handling if the request failed for reasons beyond 'label didn't exist'
  96. Err(_) => false,
  97. }
  98. }
  99. }
  100. #[derive(Debug, serde::Deserialize)]
  101. pub struct PullRequestDetails {
  102. // none for now
  103. }
  104. #[derive(Debug, serde::Deserialize)]
  105. pub struct Issue {
  106. pub number: u64,
  107. pub body: String,
  108. created_at: chrono::DateTime<Utc>,
  109. pub title: String,
  110. html_url: String,
  111. pub user: User,
  112. labels: Vec<Label>,
  113. assignees: Vec<User>,
  114. pull_request: Option<PullRequestDetails>,
  115. // API URL
  116. repository_url: String,
  117. comments_url: String,
  118. #[serde(skip)]
  119. repository: OnceCell<IssueRepository>,
  120. }
  121. #[derive(Debug, serde::Deserialize)]
  122. pub struct Comment {
  123. pub body: String,
  124. pub html_url: String,
  125. pub user: User,
  126. #[serde(alias = "submitted_at")] // for pull request reviews
  127. pub updated_at: chrono::DateTime<Utc>,
  128. }
  129. #[derive(Debug)]
  130. pub enum AssignmentError {
  131. InvalidAssignee,
  132. Http(reqwest::Error),
  133. }
  134. #[derive(Debug)]
  135. pub enum Selection<'a, T> {
  136. All,
  137. One(&'a T),
  138. }
  139. impl fmt::Display for AssignmentError {
  140. fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
  141. match self {
  142. AssignmentError::InvalidAssignee => write!(f, "invalid assignee"),
  143. AssignmentError::Http(e) => write!(f, "cannot assign: {}", e),
  144. }
  145. }
  146. }
  147. impl std::error::Error for AssignmentError {}
  148. #[derive(Debug)]
  149. pub struct IssueRepository {
  150. pub organization: String,
  151. pub repository: String,
  152. }
  153. impl fmt::Display for IssueRepository {
  154. fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
  155. write!(f, "{}/{}", self.organization, self.repository)
  156. }
  157. }
  158. impl Issue {
  159. pub fn repository(&self) -> &IssueRepository {
  160. self.repository.get_or_init(|| {
  161. log::trace!("get repository for {}", self.repository_url);
  162. let url = url::Url::parse(&self.repository_url).unwrap();
  163. let mut segments = url.path_segments().unwrap();
  164. let repository = segments.nth_back(0).unwrap();
  165. let organization = segments.nth_back(1).unwrap();
  166. IssueRepository {
  167. organization: organization.into(),
  168. repository: repository.into(),
  169. }
  170. })
  171. }
  172. pub fn global_id(&self) -> String {
  173. format!("{}#{}", self.repository(), self.number)
  174. }
  175. pub fn is_pr(&self) -> bool {
  176. self.pull_request.is_some()
  177. }
  178. pub async fn get_comment(&self, client: &GithubClient, id: usize) -> anyhow::Result<Comment> {
  179. let comment_url = format!("{}/issues/comments/{}", self.repository_url, id);
  180. let comment = client.json(client.get(&comment_url)).await?;
  181. Ok(comment)
  182. }
  183. pub async fn edit_body(&self, client: &GithubClient, body: &str) -> anyhow::Result<()> {
  184. let edit_url = format!("{}/issues/{}", self.repository_url, self.number);
  185. #[derive(serde::Serialize)]
  186. struct ChangedIssue<'a> {
  187. body: &'a str,
  188. }
  189. client
  190. ._send_req(client.patch(&edit_url).json(&ChangedIssue { body }))
  191. .await
  192. .context("failed to edit issue body")?;
  193. Ok(())
  194. }
  195. pub async fn edit_comment(
  196. &self,
  197. client: &GithubClient,
  198. id: usize,
  199. new_body: &str,
  200. ) -> anyhow::Result<()> {
  201. let comment_url = format!("{}/issues/comments/{}", self.repository_url, id);
  202. #[derive(serde::Serialize)]
  203. struct NewComment<'a> {
  204. body: &'a str,
  205. }
  206. client
  207. ._send_req(
  208. client
  209. .patch(&comment_url)
  210. .json(&NewComment { body: new_body }),
  211. )
  212. .await
  213. .context("failed to edit comment")?;
  214. Ok(())
  215. }
  216. pub async fn post_comment(&self, client: &GithubClient, body: &str) -> anyhow::Result<()> {
  217. #[derive(serde::Serialize)]
  218. struct PostComment<'a> {
  219. body: &'a str,
  220. }
  221. client
  222. ._send_req(client.post(&self.comments_url).json(&PostComment { body }))
  223. .await
  224. .context("failed to post comment")?;
  225. Ok(())
  226. }
  227. pub async fn set_labels(
  228. &self,
  229. client: &GithubClient,
  230. labels: Vec<Label>,
  231. ) -> anyhow::Result<()> {
  232. log::info!("set_labels {} to {:?}", self.global_id(), labels);
  233. // PUT /repos/:owner/:repo/issues/:number/labels
  234. // repo_url = https://api.github.com/repos/Codertocat/Hello-World
  235. let url = format!(
  236. "{repo_url}/issues/{number}/labels",
  237. repo_url = self.repository_url,
  238. number = self.number
  239. );
  240. let mut stream = labels
  241. .into_iter()
  242. .map(|label| async { (label.exists(&self.repository_url, &client).await, label) })
  243. .collect::<FuturesUnordered<_>>();
  244. let mut labels = Vec::new();
  245. while let Some((true, label)) = stream.next().await {
  246. labels.push(label);
  247. }
  248. #[derive(serde::Serialize)]
  249. struct LabelsReq {
  250. labels: Vec<String>,
  251. }
  252. client
  253. ._send_req(client.put(&url).json(&LabelsReq {
  254. labels: labels.iter().map(|l| l.name.clone()).collect(),
  255. }))
  256. .await
  257. .context("failed to set labels")?;
  258. Ok(())
  259. }
  260. pub fn labels(&self) -> &[Label] {
  261. &self.labels
  262. }
  263. pub fn contain_assignee(&self, user: &User) -> bool {
  264. self.assignees.contains(user)
  265. }
  266. pub async fn remove_assignees(
  267. &self,
  268. client: &GithubClient,
  269. selection: Selection<'_, User>,
  270. ) -> Result<(), AssignmentError> {
  271. log::info!("remove {:?} assignees for {}", selection, self.global_id());
  272. let url = format!(
  273. "{repo_url}/issues/{number}/assignees",
  274. repo_url = self.repository_url,
  275. number = self.number
  276. );
  277. let assignees = match selection {
  278. Selection::All => self
  279. .assignees
  280. .iter()
  281. .map(|u| u.login.as_str())
  282. .collect::<Vec<_>>(),
  283. Selection::One(user) => vec![user.login.as_str()],
  284. };
  285. #[derive(serde::Serialize)]
  286. struct AssigneeReq<'a> {
  287. assignees: &'a [&'a str],
  288. }
  289. client
  290. ._send_req(client.delete(&url).json(&AssigneeReq {
  291. assignees: &assignees[..],
  292. }))
  293. .await
  294. .map_err(AssignmentError::Http)?;
  295. Ok(())
  296. }
  297. pub async fn set_assignee(
  298. &self,
  299. client: &GithubClient,
  300. user: &str,
  301. ) -> Result<(), AssignmentError> {
  302. log::info!("set_assignee for {} to {}", self.global_id(), user);
  303. let url = format!(
  304. "{repo_url}/issues/{number}/assignees",
  305. repo_url = self.repository_url,
  306. number = self.number
  307. );
  308. let check_url = format!(
  309. "{repo_url}/assignees/{name}",
  310. repo_url = self.repository_url,
  311. name = user,
  312. );
  313. match client._send_req(client.get(&check_url)).await {
  314. Ok((resp, _)) => {
  315. if resp.status() == reqwest::StatusCode::NO_CONTENT {
  316. // all okay
  317. log::debug!("set_assignee: assignee is valid");
  318. } else {
  319. log::error!(
  320. "unknown status for assignee check, assuming all okay: {:?}",
  321. resp
  322. );
  323. }
  324. }
  325. Err(e) => {
  326. if e.status() == Some(reqwest::StatusCode::NOT_FOUND) {
  327. log::debug!("set_assignee: assignee is invalid, returning");
  328. return Err(AssignmentError::InvalidAssignee);
  329. }
  330. log::debug!("set_assignee: get {} failed, {:?}", check_url, e);
  331. return Err(AssignmentError::Http(e));
  332. }
  333. }
  334. self.remove_assignees(client, Selection::All).await?;
  335. #[derive(serde::Serialize)]
  336. struct AssigneeReq<'a> {
  337. assignees: &'a [&'a str],
  338. }
  339. client
  340. ._send_req(client.post(&url).json(&AssigneeReq { assignees: &[user] }))
  341. .await
  342. .map_err(AssignmentError::Http)?;
  343. Ok(())
  344. }
  345. }
  346. #[derive(PartialEq, Eq, Debug, serde::Deserialize)]
  347. #[serde(rename_all = "lowercase")]
  348. pub enum PullRequestReviewAction {
  349. Submitted,
  350. Edited,
  351. Dismissed,
  352. }
  353. #[derive(Debug, serde::Deserialize)]
  354. pub struct PullRequestReviewEvent {
  355. pub action: PullRequestReviewAction,
  356. pub pull_request: Issue,
  357. pub review: Comment,
  358. pub repository: Repository,
  359. }
  360. #[derive(Debug, serde::Deserialize)]
  361. pub struct PullRequestReviewComment {
  362. pub action: IssueCommentAction,
  363. #[serde(rename = "pull_request")]
  364. pub issue: Issue,
  365. pub comment: Comment,
  366. pub repository: Repository,
  367. }
  368. #[derive(PartialEq, Eq, Debug, serde::Deserialize)]
  369. #[serde(rename_all = "lowercase")]
  370. pub enum IssueCommentAction {
  371. Created,
  372. Edited,
  373. Deleted,
  374. }
  375. #[derive(Debug, serde::Deserialize)]
  376. pub struct IssueCommentEvent {
  377. pub action: IssueCommentAction,
  378. pub issue: Issue,
  379. pub comment: Comment,
  380. pub repository: Repository,
  381. }
  382. #[derive(PartialEq, Eq, Debug, serde::Deserialize)]
  383. #[serde(rename_all = "lowercase")]
  384. pub enum IssuesAction {
  385. Opened,
  386. Edited,
  387. Deleted,
  388. Transferred,
  389. Pinned,
  390. Unpinned,
  391. Closed,
  392. Reopened,
  393. Assigned,
  394. Unassigned,
  395. Labeled,
  396. Unlabeled,
  397. Locked,
  398. Unlocked,
  399. Milestoned,
  400. Demilestoned,
  401. }
  402. #[derive(Debug, serde::Deserialize)]
  403. pub struct IssuesEvent {
  404. pub action: IssuesAction,
  405. pub issue: Issue,
  406. pub repository: Repository,
  407. }
  408. #[derive(Debug, serde::Deserialize)]
  409. pub struct Repository {
  410. pub full_name: String,
  411. }
  412. #[derive(Debug)]
  413. pub enum Event {
  414. IssueComment(IssueCommentEvent),
  415. Issue(IssuesEvent),
  416. }
  417. impl Event {
  418. pub fn repo_name(&self) -> &str {
  419. match self {
  420. Event::IssueComment(event) => &event.repository.full_name,
  421. Event::Issue(event) => &event.repository.full_name,
  422. }
  423. }
  424. pub fn issue(&self) -> Option<&Issue> {
  425. match self {
  426. Event::IssueComment(event) => Some(&event.issue),
  427. Event::Issue(event) => Some(&event.issue),
  428. }
  429. }
  430. /// This will both extract from IssueComment events but also Issue events
  431. pub fn comment_body(&self) -> Option<&str> {
  432. match self {
  433. Event::Issue(e) => Some(&e.issue.body),
  434. Event::IssueComment(e) => Some(&e.comment.body),
  435. }
  436. }
  437. pub fn html_url(&self) -> Option<&str> {
  438. match self {
  439. Event::Issue(e) => Some(&e.issue.html_url),
  440. Event::IssueComment(e) => Some(&e.comment.html_url),
  441. }
  442. }
  443. pub fn user(&self) -> &User {
  444. match self {
  445. Event::Issue(e) => &e.issue.user,
  446. Event::IssueComment(e) => &e.comment.user,
  447. }
  448. }
  449. pub fn time(&self) -> chrono::DateTime<FixedOffset> {
  450. match self {
  451. Event::Issue(e) => e.issue.created_at.into(),
  452. Event::IssueComment(e) => e.comment.updated_at.into(),
  453. }
  454. }
  455. }
  456. trait RequestSend: Sized {
  457. fn configure(self, g: &GithubClient) -> Self;
  458. }
  459. impl RequestSend for RequestBuilder {
  460. fn configure(self, g: &GithubClient) -> RequestBuilder {
  461. self.header(USER_AGENT, "rust-lang-triagebot")
  462. .header(AUTHORIZATION, format!("token {}", g.token))
  463. }
  464. }
  465. #[derive(Clone)]
  466. pub struct GithubClient {
  467. token: String,
  468. client: Client,
  469. }
  470. impl GithubClient {
  471. pub fn new(client: Client, token: String) -> Self {
  472. GithubClient { client, token }
  473. }
  474. pub fn raw(&self) -> &Client {
  475. &self.client
  476. }
  477. pub async fn raw_file(
  478. &self,
  479. repo: &str,
  480. branch: &str,
  481. path: &str,
  482. ) -> anyhow::Result<Option<Vec<u8>>> {
  483. let url = format!(
  484. "https://raw.githubusercontent.com/{}/{}/{}",
  485. repo, branch, path
  486. );
  487. let req = self.get(&url);
  488. let req_dbg = format!("{:?}", req);
  489. let req = req
  490. .build()
  491. .with_context(|| format!("failed to build request {:?}", req_dbg))?;
  492. let mut resp = self.client.execute(req).await.context(req_dbg.clone())?;
  493. let status = resp.status();
  494. match status {
  495. StatusCode::OK => {
  496. let mut buf = Vec::with_capacity(resp.content_length().unwrap_or(4) as usize);
  497. while let Some(chunk) = resp.chunk().await.transpose() {
  498. let chunk = chunk
  499. .context("reading stream failed")
  500. .map_err(anyhow::Error::from)
  501. .context(req_dbg.clone())?;
  502. buf.extend_from_slice(&chunk);
  503. }
  504. Ok(Some(buf))
  505. }
  506. StatusCode::NOT_FOUND => Ok(None),
  507. status => anyhow::bail!("failed to GET {}: {}", url, status),
  508. }
  509. }
  510. fn get(&self, url: &str) -> RequestBuilder {
  511. log::trace!("get {:?}", url);
  512. self.client.get(url).configure(self)
  513. }
  514. fn patch(&self, url: &str) -> RequestBuilder {
  515. log::trace!("patch {:?}", url);
  516. self.client.patch(url).configure(self)
  517. }
  518. fn delete(&self, url: &str) -> RequestBuilder {
  519. log::trace!("delete {:?}", url);
  520. self.client.delete(url).configure(self)
  521. }
  522. fn post(&self, url: &str) -> RequestBuilder {
  523. log::trace!("post {:?}", url);
  524. self.client.post(url).configure(self)
  525. }
  526. fn put(&self, url: &str) -> RequestBuilder {
  527. log::trace!("put {:?}", url);
  528. self.client.put(url).configure(self)
  529. }
  530. }