Эх сурвалжийг харах

Merge pull request #1749 from ehuss/diff-only-once

Only fetch PR diff once.
Eric Huss 1 жил өмнө
parent
commit
0a68321ca1

+ 73 - 28
src/github.rs

@@ -232,7 +232,30 @@ pub struct Label {
 /// needed at this time (merged_at, diff_url, html_url, patch_url, url).
 #[derive(Debug, serde::Deserialize)]
 pub struct PullRequestDetails {
-    // none for now
+    /// This is a slot to hold the diff for a PR.
+    ///
+    /// This will be filled in only once as an optimization since multiple
+    /// handlers want to see PR changes, and getting the diff can be
+    /// expensive.
+    #[serde(skip)]
+    files_changed: tokio::sync::OnceCell<Vec<FileDiff>>,
+}
+
+/// Representation of a diff to a single file.
+#[derive(Debug)]
+pub struct FileDiff {
+    /// The full path of the file.
+    pub path: String,
+    /// The diff for the file.
+    pub diff: String,
+}
+
+impl PullRequestDetails {
+    pub fn new() -> PullRequestDetails {
+        PullRequestDetails {
+            files_changed: tokio::sync::OnceCell::new(),
+        }
+    }
 }
 
 /// An issue or pull request.
@@ -786,23 +809,33 @@ impl Issue {
     }
 
     /// Returns the diff in this event, for Open and Synchronize events for now.
-    pub async fn diff(&self, client: &GithubClient) -> anyhow::Result<Option<String>> {
+    ///
+    /// Returns `None` if the issue is not a PR.
+    pub async fn diff(&self, client: &GithubClient) -> anyhow::Result<Option<&[FileDiff]>> {
+        let Some(pr) = &self.pull_request else {
+            return Ok(None);
+        };
         let (before, after) = if let (Some(base), Some(head)) = (&self.base, &self.head) {
-            (base.sha.clone(), head.sha.clone())
+            (&base.sha, &head.sha)
         } else {
             return Ok(None);
         };
 
-        let mut req = client.get(&format!(
-            "{}/compare/{}...{}",
-            self.repository().url(),
-            before,
-            after
-        ));
-        req = req.header("Accept", "application/vnd.github.v3.diff");
-        let (diff, _) = client.send_req(req).await?;
-        let body = String::from_utf8_lossy(&diff).to_string();
-        Ok(Some(body))
+        let diff = pr
+            .files_changed
+            .get_or_try_init::<anyhow::Error, _, _>(|| async move {
+                let url = format!("{}/compare/{before}...{after}", self.repository().url());
+                let mut req = client.get(&url);
+                req = req.header("Accept", "application/vnd.github.v3.diff");
+                let (diff, _) = client
+                    .send_req(req)
+                    .await
+                    .with_context(|| format!("failed to fetch diff comparison for {url}"))?;
+                let body = String::from_utf8_lossy(&diff);
+                Ok(parse_diff(&body))
+            })
+            .await?;
+        Ok(Some(diff))
     }
 
     /// Returns the commits from this pull request (no commits are returned if this `Issue` is not
@@ -982,19 +1015,29 @@ pub struct CommitBase {
     pub repo: Repository,
 }
 
-pub fn files_changed(diff: &str) -> Vec<&str> {
-    let mut files = Vec::new();
-    for line in diff.lines() {
-        // mostly copied from highfive
-        if line.starts_with("diff --git ") {
-            files.push(
-                line[line.find(" b/").unwrap()..]
-                    .strip_prefix(" b/")
-                    .unwrap(),
-            );
-        }
-    }
+pub fn parse_diff(diff: &str) -> Vec<FileDiff> {
+    // This does not properly handle filenames with spaces.
+    let re = regex::Regex::new("(?m)^diff --git .* b/(.*)").unwrap();
+    let mut files: Vec<_> = re
+        .captures_iter(diff)
+        .map(|cap| {
+            let start = cap.get(0).unwrap().start();
+            let path = cap.get(1).unwrap().as_str().to_string();
+            (start, path)
+        })
+        .collect();
+    // Break the list up into (start, end) pairs starting from the "diff --git" line.
+    files.push((diff.len(), String::new()));
     files
+        .windows(2)
+        .map(|w| {
+            let (start, end) = (&w[0], &w[1]);
+            FileDiff {
+                path: start.1.clone(),
+                diff: diff[start.0..end.0].to_string(),
+            }
+        })
+        .collect()
 }
 
 #[derive(Debug, serde::Deserialize)]
@@ -1503,7 +1546,7 @@ impl Repository {
                     self.full_name
                 )
             })?;
-        issue.pull_request = Some(PullRequestDetails {});
+        issue.pull_request = Some(PullRequestDetails::new());
         Ok(issue)
     }
 
@@ -2484,7 +2527,8 @@ index fb9cee43b2d..b484c25ea51 100644
     zulip_stream = 245100 # #t-compiler/wg-prioritization/alerts
     topic = "#{number} {title}"
          "##;
-        assert_eq!(files_changed(input), vec!["triagebot.toml".to_string()]);
+        let files: Vec<_> = parse_diff(input).into_iter().map(|d| d.path).collect();
+        assert_eq!(files, vec!["triagebot.toml".to_string()]);
     }
 
     #[test]
@@ -2516,8 +2560,9 @@ index c58310947d2..3b0854d4a9b 100644
 }
 +
 "##;
+        let files: Vec<_> = parse_diff(input).into_iter().map(|d| d.path).collect();
         assert_eq!(
-            files_changed(input),
+            files,
             vec![
                 "library/stdarch".to_string(),
                 "src/librustdoc/clean/types.rs".to_string(),

+ 68 - 70
src/handlers/assign.rs

@@ -19,7 +19,7 @@
 
 use crate::{
     config::AssignConfig,
-    github::{self, Event, Issue, IssuesAction, Selection},
+    github::{self, Event, FileDiff, Issue, IssuesAction, Selection},
     handlers::{Context, GithubClient, IssuesEvent},
     interactions::EditIssueBody,
 };
@@ -80,13 +80,11 @@ struct AssignData {
 }
 
 /// Input for auto-assignment when a PR is created.
-pub(super) struct AssignInput {
-    git_diff: String,
-}
+pub(super) struct AssignInput {}
 
 /// Prepares the input when a new PR is opened.
 pub(super) async fn parse_input(
-    ctx: &Context,
+    _ctx: &Context,
     event: &IssuesEvent,
     config: Option<&AssignConfig>,
 ) -> Result<Option<AssignInput>, String> {
@@ -100,15 +98,7 @@ pub(super) async fn parse_input(
     {
         return Ok(None);
     }
-    let git_diff = match event.issue.diff(&ctx.github).await {
-        Ok(None) => return Ok(None),
-        Err(e) => {
-            log::error!("failed to fetch diff: {:?}", e);
-            return Ok(None);
-        }
-        Ok(Some(diff)) => diff,
-    };
-    Ok(Some(AssignInput { git_diff }))
+    Ok(Some(AssignInput {}))
 }
 
 /// Handles the work of setting an assignment for a new PR and posting a
@@ -117,11 +107,18 @@ pub(super) async fn handle_input(
     ctx: &Context,
     config: &AssignConfig,
     event: &IssuesEvent,
-    input: AssignInput,
+    _input: AssignInput,
 ) -> anyhow::Result<()> {
+    let Some(diff) = event.issue.diff(&ctx.github).await? else {
+        bail!(
+            "expected issue {} to be a PR, but the diff could not be determined",
+            event.issue.number
+        )
+    };
+
     // Don't auto-assign or welcome if the user manually set the assignee when opening.
     if event.issue.assignees.is_empty() {
-        let (assignee, from_comment) = determine_assignee(ctx, event, config, &input).await?;
+        let (assignee, from_comment) = determine_assignee(ctx, event, config, &diff).await?;
         if assignee.as_deref() == Some("ghost") {
             // "ghost" is GitHub's placeholder account for deleted accounts.
             // It is used here as a convenient way to prevent assignment. This
@@ -180,7 +177,7 @@ pub(super) async fn handle_input(
     if config.warn_non_default_branch {
         warnings.extend(non_default_branch(event));
     }
-    warnings.extend(modifies_submodule(&input.git_diff));
+    warnings.extend(modifies_submodule(diff));
     if !warnings.is_empty() {
         let warnings: Vec<_> = warnings
             .iter()
@@ -222,9 +219,9 @@ fn non_default_branch(event: &IssuesEvent) -> Option<String> {
 }
 
 /// Returns a message if the PR modifies a git submodule.
-fn modifies_submodule(diff: &str) -> Option<String> {
+fn modifies_submodule(diff: &[FileDiff]) -> Option<String> {
     let re = regex::Regex::new(r"\+Subproject\scommit\s").unwrap();
-    if re.is_match(diff) {
+    if diff.iter().any(|fd| re.is_match(&fd.diff)) {
         Some(SUBMODULE_WARNING_MSG.to_string())
     } else {
         None
@@ -278,7 +275,7 @@ async fn determine_assignee(
     ctx: &Context,
     event: &IssuesEvent,
     config: &AssignConfig,
-    input: &AssignInput,
+    diff: &[FileDiff],
 ) -> anyhow::Result<(Option<String>, bool)> {
     let teams = crate::team_data::teams(&ctx.github).await?;
     if let Some(name) = find_assign_command(ctx, event) {
@@ -298,7 +295,7 @@ async fn determine_assignee(
         }
     }
     // Errors fall-through to try fallback group.
-    match find_reviewers_from_diff(config, &input.git_diff) {
+    match find_reviewers_from_diff(config, diff) {
         Ok(candidates) if !candidates.is_empty() => {
             match find_reviewer_from_names(&teams, config, &event.issue, &candidates) {
                 Ok(assignee) => return Ok((Some(assignee), false)),
@@ -346,60 +343,61 @@ async fn determine_assignee(
 /// May return an error if the owners map is misconfigured.
 ///
 /// Beware this may return an empty list if nothing matches.
-fn find_reviewers_from_diff(config: &AssignConfig, diff: &str) -> anyhow::Result<Vec<String>> {
+fn find_reviewers_from_diff(
+    config: &AssignConfig,
+    diff: &[FileDiff],
+) -> anyhow::Result<Vec<String>> {
     // Map of `owners` path to the number of changes found in that path.
     // This weights the reviewer choice towards places where the most edits are done.
     let mut counts: HashMap<&str, u32> = HashMap::new();
-    // List of the longest `owners` patterns that match the current path. This
-    // prefers choosing reviewers from deeply nested paths over those defined
-    // for top-level paths, under the assumption that they are more
-    // specialized.
-    //
-    // This is a list to handle the situation if multiple paths of the same
-    // length match.
-    let mut longest_owner_patterns = Vec::new();
-    // Iterate over the diff, finding the start of each file. After each file
-    // is found, it counts the number of modified lines in that file, and
-    // tracks those in the `counts` map.
-    for line in diff.split('\n') {
-        if line.starts_with("diff --git ") {
-            // Start of a new file.
-            longest_owner_patterns.clear();
-            let path = line[line.find(" b/").unwrap()..]
-                .strip_prefix(" b/")
-                .unwrap();
-            // Find the longest `owners` entries that match this path.
-            let mut longest = HashMap::new();
-            for owner_pattern in config.owners.keys() {
-                let ignore = ignore::gitignore::GitignoreBuilder::new("/")
-                    .add_line(None, owner_pattern)
-                    .with_context(|| format!("owner file pattern `{owner_pattern}` is not valid"))?
-                    .build()?;
-                if ignore.matched_path_or_any_parents(path, false).is_ignore() {
-                    let owner_len = owner_pattern.split('/').count();
-                    longest.insert(owner_pattern, owner_len);
-                }
-            }
-            let max_count = longest.values().copied().max().unwrap_or(0);
-            longest_owner_patterns.extend(
-                longest
-                    .iter()
-                    .filter(|(_, count)| **count == max_count)
-                    .map(|x| *x.0),
-            );
-            // Give some weight to these patterns to start. This helps with
-            // files modified without any lines changed.
-            for owner_pattern in &longest_owner_patterns {
-                *counts.entry(owner_pattern).or_default() += 1;
+    // Iterate over the diff, counting the number of modified lines in each
+    // file, and tracks those in the `counts` map.
+    for file_diff in diff {
+        // List of the longest `owners` patterns that match the current path. This
+        // prefers choosing reviewers from deeply nested paths over those defined
+        // for top-level paths, under the assumption that they are more
+        // specialized.
+        //
+        // This is a list to handle the situation if multiple paths of the same
+        // length match.
+        let mut longest_owner_patterns = Vec::new();
+
+        // Find the longest `owners` entries that match this path.
+        let mut longest = HashMap::new();
+        for owner_pattern in config.owners.keys() {
+            let ignore = ignore::gitignore::GitignoreBuilder::new("/")
+                .add_line(None, owner_pattern)
+                .with_context(|| format!("owner file pattern `{owner_pattern}` is not valid"))?
+                .build()?;
+            if ignore
+                .matched_path_or_any_parents(&file_diff.path, false)
+                .is_ignore()
+            {
+                let owner_len = owner_pattern.split('/').count();
+                longest.insert(owner_pattern, owner_len);
             }
-            continue;
         }
-        // Check for a modified line.
-        if (!line.starts_with("+++") && line.starts_with('+'))
-            || (!line.starts_with("---") && line.starts_with('-'))
-        {
-            for owner_path in &longest_owner_patterns {
-                *counts.entry(owner_path).or_default() += 1;
+        let max_count = longest.values().copied().max().unwrap_or(0);
+        longest_owner_patterns.extend(
+            longest
+                .iter()
+                .filter(|(_, count)| **count == max_count)
+                .map(|x| *x.0),
+        );
+        // Give some weight to these patterns to start. This helps with
+        // files modified without any lines changed.
+        for owner_pattern in &longest_owner_patterns {
+            *counts.entry(owner_pattern).or_default() += 1;
+        }
+
+        // Count the modified lines.
+        for line in file_diff.diff.lines() {
+            if (!line.starts_with("+++") && line.starts_with('+'))
+                || (!line.starts_with("---") && line.starts_with('-'))
+            {
+                for owner_path in &longest_owner_patterns {
+                    *counts.entry(owner_path).or_default() += 1;
+                }
             }
         }
     }

+ 3 - 1
src/handlers/assign/tests/tests_from_diff.rs

@@ -2,12 +2,14 @@
 
 use super::super::*;
 use crate::config::AssignConfig;
+use crate::github::parse_diff;
 use std::fmt::Write;
 
 fn test_from_diff(diff: &str, config: toml::Value, expected: &[&str]) {
+    let files = parse_diff(diff);
     let aconfig: AssignConfig = config.try_into().unwrap();
     assert_eq!(
-        find_reviewers_from_diff(&aconfig, diff).unwrap(),
+        find_reviewers_from_diff(&aconfig, &files).unwrap(),
         expected.iter().map(|x| x.to_string()).collect::<Vec<_>>()
     );
 }

+ 3 - 4
src/handlers/autolabel.rs

@@ -1,6 +1,6 @@
 use crate::{
     config::AutolabelConfig,
-    github::{files_changed, IssuesAction, IssuesEvent, Label},
+    github::{IssuesAction, IssuesEvent, Label},
     handlers::Context,
 };
 use anyhow::Context as _;
@@ -27,7 +27,7 @@ pub(super) async fn parse_input(
     // remove. Not much can be done about that currently; the before/after on
     // synchronize may be straddling a rebase, which will break diff generation.
     if event.action == IssuesAction::Opened || event.action == IssuesAction::Synchronize {
-        let diff = event
+        let files = event
             .issue
             .diff(&ctx.github)
             .await
@@ -35,7 +35,6 @@ pub(super) async fn parse_input(
                 log::error!("failed to fetch diff: {:?}", e);
             })
             .unwrap_or_default();
-        let files = diff.as_deref().map(files_changed);
         let mut autolabels = Vec::new();
 
         'outer: for (label, cfg) in config.labels.iter() {
@@ -64,7 +63,7 @@ pub(super) async fn parse_input(
                 if cfg
                     .trigger_files
                     .iter()
-                    .any(|f| files.iter().any(|diff_file| diff_file.starts_with(f)))
+                    .any(|f| files.iter().any(|file_diff| file_diff.path.starts_with(f)))
                 {
                     autolabels.push(Label {
                         name: label.to_owned(),

+ 3 - 4
src/handlers/mentions.rs

@@ -5,7 +5,7 @@
 use crate::{
     config::{MentionsConfig, MentionsPathConfig},
     db::issue_data::IssueData,
-    github::{files_changed, IssuesAction, IssuesEvent},
+    github::{IssuesAction, IssuesEvent},
     handlers::Context,
 };
 use anyhow::Context as _;
@@ -50,7 +50,7 @@ pub(super) async fn parse_input(
         return Ok(None);
     }
 
-    if let Some(diff) = event
+    if let Some(files) = event
         .issue
         .diff(&ctx.github)
         .await
@@ -59,8 +59,7 @@ pub(super) async fn parse_input(
         })
         .unwrap_or_default()
     {
-        let files = files_changed(&diff);
-        let file_paths: Vec<_> = files.iter().map(|p| Path::new(p)).collect();
+        let file_paths: Vec<_> = files.iter().map(|fd| Path::new(&fd.path)).collect();
         let to_mention: Vec<_> = config
             .paths
             .iter()

+ 3 - 3
src/lib.rs

@@ -146,7 +146,7 @@ pub async fn webhook(
                 .map_err(anyhow::Error::from)?;
 
             log::info!("handling pull request review comment {:?}", payload);
-            payload.pull_request.pull_request = Some(PullRequestDetails {});
+            payload.pull_request.pull_request = Some(PullRequestDetails::new());
 
             // Treat pull request review comments exactly like pull request
             // review comments.
@@ -171,7 +171,7 @@ pub async fn webhook(
                 .context("PullRequestReview(Comment) failed to deserialize")
                 .map_err(anyhow::Error::from)?;
 
-            payload.issue.pull_request = Some(PullRequestDetails {});
+            payload.issue.pull_request = Some(PullRequestDetails::new());
 
             log::info!("handling pull request review comment {:?}", payload);
 
@@ -200,7 +200,7 @@ pub async fn webhook(
                 .map_err(anyhow::Error::from)?;
 
             if matches!(event, EventName::PullRequest) {
-                payload.issue.pull_request = Some(PullRequestDetails {});
+                payload.issue.pull_request = Some(PullRequestDetails::new());
             }
 
             log::info!("handling issue event {:?}", payload);