Eric Huss пре 1 година
родитељ
комит
fa84cebec4
2 измењених фајлова са 219 додато и 44 уклоњено
  1. 152 39
      src/github.rs
  2. 67 5
      src/handlers/milestone_prs.rs

+ 152 - 39
src/github.rs

@@ -426,6 +426,10 @@ impl IssueRepository {
         )
     }
 
+    fn full_repo_name(&self) -> String {
+        format!("{}/{}", self.organization, self.repository)
+    }
+
     async fn has_label(&self, client: &GithubClient, label: &str) -> anyhow::Result<bool> {
         #[allow(clippy::redundant_pattern_matching)]
         let url = format!("{}/labels/{}", self.url(), label);
@@ -745,6 +749,10 @@ impl Issue {
         Ok(())
     }
 
+    /// Sets the milestone of the issue or PR.
+    ///
+    /// This will create the milestone if it does not exist. The new milestone
+    /// will start in the "open" state.
     pub async fn set_milestone(&self, client: &GithubClient, title: &str) -> anyhow::Result<()> {
         log::trace!(
             "Setting milestone for rust-lang/rust#{} to {}",
@@ -752,42 +760,14 @@ impl Issue {
             title
         );
 
-        let create_url = format!("{}/milestones", self.repository().url());
-        let resp = client
-            .send_req(
-                client
-                    .post(&create_url)
-                    .body(serde_json::to_vec(&MilestoneCreateBody { title }).unwrap()),
-            )
-            .await;
-        // Explicitly do *not* try to return Err(...) if this fails -- that's
-        // fine, it just means the milestone was already created.
-        log::trace!("Created milestone: {:?}", resp);
-
-        let list_url = format!("{}/milestones", self.repository().url());
-        let milestone_list: Vec<Milestone> = client.json(client.get(&list_url)).await?;
-        let milestone_no = if let Some(milestone) = milestone_list.iter().find(|v| v.title == title)
-        {
-            milestone.number
-        } else {
-            anyhow::bail!(
-                "Despite just creating milestone {} on {}, it does not exist?",
-                title,
-                self.repository()
-            )
-        };
+        let full_repo_name = self.repository().full_repo_name();
+        let milestone = client
+            .get_or_create_milestone(&full_repo_name, title, "open")
+            .await?;
 
-        #[derive(serde::Serialize)]
-        struct SetMilestone {
-            milestone: u64,
-        }
-        let url = format!("{}/issues/{}", self.repository().url(), self.number);
         client
-            .send_req(client.patch(&url).json(&SetMilestone {
-                milestone: milestone_no,
-            }))
-            .await
-            .context("failed to set milestone")?;
+            .set_milestone(&full_repo_name, &milestone, self.number)
+            .await?;
         Ok(())
     }
 
@@ -886,11 +866,6 @@ pub struct PullRequestFile {
     pub blob_url: String,
 }
 
-#[derive(serde::Serialize)]
-struct MilestoneCreateBody<'a> {
-    title: &'a str,
-}
-
 #[derive(Debug, serde::Deserialize)]
 pub struct Milestone {
     number: u64,
@@ -1246,6 +1221,33 @@ impl Repository {
         )
     }
 
+    /// Returns a list of commits between the SHA ranges of start (exclusive)
+    /// and end (inclusive).
+    pub async fn commits_in_range(
+        &self,
+        client: &GithubClient,
+        start: &str,
+        end: &str,
+    ) -> anyhow::Result<Vec<GithubCommit>> {
+        let mut commits = Vec::new();
+        let mut page = 1;
+        loop {
+            let url = format!("{}/commits?sha={end}&per_page=100&page={page}", self.url());
+            let mut this_page: Vec<GithubCommit> = client
+                .json(client.get(&url))
+                .await
+                .with_context(|| format!("failed to fetch commits for {url}"))?;
+            if let Some(idx) = this_page.iter().position(|commit| commit.sha == start) {
+                this_page.truncate(idx);
+                commits.extend(this_page);
+                return Ok(commits);
+            } else {
+                commits.extend(this_page);
+            }
+            page += 1;
+        }
+    }
+
     /// Retrieves a git commit for the given SHA.
     pub async fn git_commit(&self, client: &GithubClient, sha: &str) -> anyhow::Result<GitCommit> {
         let url = format!("{}/git/commits/{sha}", self.url());
@@ -1616,6 +1618,40 @@ impl Repository {
             })?;
         Ok(())
     }
+
+    /// Get or create a [`Milestone`].
+    ///
+    /// This will not change the state if it already exists.
+    pub async fn get_or_create_milestone(
+        &self,
+        client: &GithubClient,
+        title: &str,
+        state: &str,
+    ) -> anyhow::Result<Milestone> {
+        client
+            .get_or_create_milestone(&self.full_name, title, state)
+            .await
+    }
+
+    /// Set the milestone of an issue or PR.
+    pub async fn set_milestone(
+        &self,
+        client: &GithubClient,
+        milestone: &Milestone,
+        issue_num: u64,
+    ) -> anyhow::Result<()> {
+        client
+            .set_milestone(&self.full_name, milestone, issue_num)
+            .await
+    }
+
+    pub async fn get_issue(&self, client: &GithubClient, issue_num: u64) -> anyhow::Result<Issue> {
+        let url = format!("{}/pulls/{issue_num}", self.url());
+        client
+            .json(client.get(&url))
+            .await
+            .with_context(|| format!("{} failed to get issue {issue_num}", self.full_name))
+    }
 }
 
 pub struct Query<'a> {
@@ -2126,6 +2162,83 @@ impl GithubClient {
             .await
             .with_context(|| format!("{} failed to get repo", full_name))
     }
+
+    /// Get or create a [`Milestone`].
+    ///
+    /// This will not change the state if it already exists.
+    async fn get_or_create_milestone(
+        &self,
+        full_repo_name: &str,
+        title: &str,
+        state: &str,
+    ) -> anyhow::Result<Milestone> {
+        let url = format!(
+            "{}/repos/{full_repo_name}/milestones",
+            Repository::GITHUB_API_URL
+        );
+        let resp = self
+            .send_req(self.post(&url).json(&serde_json::json!({
+                "title": title,
+                "state": state,
+            })))
+            .await;
+        match resp {
+            Ok((body, _dbg)) => {
+                let milestone = serde_json::from_slice(&body)?;
+                log::trace!("Created milestone: {milestone:?}");
+                return Ok(milestone);
+            }
+            Err(e) => {
+                if e.downcast_ref::<reqwest::Error>().map_or(false, |e| {
+                    matches!(e.status(), Some(StatusCode::UNPROCESSABLE_ENTITY))
+                }) {
+                    // fall-through, it already exists
+                } else {
+                    return Err(e.context(format!(
+                        "failed to create milestone {url} with title {title}"
+                    )));
+                }
+            }
+        }
+        // In the case where it already exists, we need to search for its number.
+        let mut page = 1;
+        loop {
+            let url = format!(
+                "{}/repos/{full_repo_name}/milestones?page={page}&state=all",
+                Repository::GITHUB_API_URL
+            );
+            let milestones: Vec<Milestone> = self
+                .json(self.get(&url))
+                .await
+                .with_context(|| format!("failed to get milestones {url} searching for {title}"))?;
+            if milestones.is_empty() {
+                anyhow::bail!("expected to find milestone with title {title}");
+            }
+            if let Some(milestone) = milestones.into_iter().find(|m| m.title == title) {
+                return Ok(milestone);
+            }
+            page += 1;
+        }
+    }
+
+    /// Set the milestone of an issue or PR.
+    async fn set_milestone(
+        &self,
+        full_repo_name: &str,
+        milestone: &Milestone,
+        issue_num: u64,
+    ) -> anyhow::Result<()> {
+        let url = format!(
+            "{}/repos/{full_repo_name}/issues/{issue_num}",
+            Repository::GITHUB_API_URL
+        );
+        self.send_req(self.patch(&url).json(&serde_json::json!({
+            "milestone": milestone.number
+        })))
+        .await
+        .with_context(|| format!("failed to set milestone for {url} to milestone {milestone:?}"))?;
+        Ok(())
+    }
 }
 
 #[derive(Debug, serde::Deserialize)]

+ 67 - 5
src/handlers/milestone_prs.rs

@@ -1,8 +1,9 @@
 use crate::{
-    github::{Event, IssuesAction},
+    github::{Event, GithubClient, IssuesAction},
     handlers::Context,
 };
 use anyhow::Context as _;
+use regex::Regex;
 use reqwest::StatusCode;
 use tracing as log;
 
@@ -42,7 +43,7 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
     };
 
     // Fetch the version from the upstream repository.
-    let version = if let Some(version) = get_version_standalone(ctx, merge_sha).await? {
+    let version = if let Some(version) = get_version_standalone(&ctx.github, merge_sha).await? {
         version
     } else {
         log::error!("could not find the version of {:?}", merge_sha);
@@ -62,12 +63,21 @@ pub async fn handle(ctx: &Context, event: &Event) -> anyhow::Result<()> {
     // eventually automate it separately.
     e.issue.set_milestone(&ctx.github, &version).await?;
 
+    let files = e.issue.diff(&ctx.github).await?;
+    if let Some(files) = files {
+        if let Some(cargo) = files.iter().find(|fd| fd.path == "src/tools/cargo") {
+            milestone_cargo(&ctx.github, &version, &cargo.diff).await?;
+        }
+    }
+
     Ok(())
 }
 
-async fn get_version_standalone(ctx: &Context, merge_sha: &str) -> anyhow::Result<Option<String>> {
-    let resp = ctx
-        .github
+async fn get_version_standalone(
+    gh: &GithubClient,
+    merge_sha: &str,
+) -> anyhow::Result<Option<String>> {
+    let resp = gh
         .raw()
         .get(&format!(
             "https://raw.githubusercontent.com/rust-lang/rust/{}/src/version",
@@ -96,3 +106,55 @@ async fn get_version_standalone(ctx: &Context, merge_sha: &str) -> anyhow::Resul
             .to_string(),
     ))
 }
+
+/// Milestones all PRs in the cargo repo when the submodule is synced in
+/// rust-lang/rust.
+async fn milestone_cargo(
+    gh: &GithubClient,
+    release_version: &str,
+    submodule_diff: &str,
+) -> anyhow::Result<()> {
+    // Determine the start/end range of commits in this submodule update by
+    // looking at the diff content which indicates the old and new hash.
+    let subproject_re = Regex::new("Subproject commit ([0-9a-f]+)").unwrap();
+    let mut caps = subproject_re.captures_iter(submodule_diff);
+    let cargo_start_hash = &caps.next().unwrap()[1];
+    let cargo_end_hash = &caps.next().unwrap()[1];
+    assert!(caps.next().is_none());
+
+    // Get all of the git commits in the cargo repo.
+    let cargo_repo = gh.repository("rust-lang/cargo").await?;
+    let commits = cargo_repo
+        .commits_in_range(gh, cargo_start_hash, cargo_end_hash)
+        .await?;
+
+    // For each commit, look for a message from bors that indicates which
+    // PR was merged.
+    //
+    // GitHub has a specific API for this at
+    // /repos/{owner}/{repo}/commits/{commit_sha}/pulls
+    // <https://docs.github.com/en/rest/commits/commits?apiVersion=2022-11-28#list-pull-requests-associated-with-a-commit>,
+    // but it is a little awkward to use, only works on the default branch,
+    // and this is a bit simpler/faster. However, it is sensitive to the
+    // specific messages generated by bors, and won't catch things merged
+    // without bors.
+    let merge_re = Regex::new("(?:Auto merge of|Merge pull request) #([0-9]+)").unwrap();
+
+    let pr_nums = commits.iter().filter_map(|commit| {
+        merge_re.captures(&commit.commit.message).map(|cap| {
+            cap.get(1)
+                .unwrap()
+                .as_str()
+                .parse::<u64>()
+                .expect("digits only")
+        })
+    });
+    let milestone = cargo_repo
+        .get_or_create_milestone(gh, release_version, "closed")
+        .await?;
+    for pr_num in pr_nums {
+        cargo_repo.set_milestone(gh, &milestone, pr_num).await?;
+    }
+
+    Ok(())
+}