rustfmt_format_diff/
main.rs

1// Inspired by Clang's clang-format-diff:
2//
3// https://github.com/llvm-mirror/clang/blob/master/tools/clang-format/clang-format-diff.py
4
5#![deny(warnings)]
6
7use serde::{Deserialize, Serialize};
8use serde_json as json;
9use thiserror::Error;
10use tracing::debug;
11use tracing_subscriber::EnvFilter;
12
13use std::collections::HashSet;
14use std::env;
15use std::ffi::OsStr;
16use std::io::{self, BufRead};
17use std::process;
18
19use regex::Regex;
20
21use clap::{CommandFactory, Parser};
22
/// The default pattern of files to format.
///
/// We only want to format Rust files by default. `scan_diff` anchors the
/// pattern at both ends, so it is matched against the whole file path taken
/// from the diff (after the requested prefix has been skipped).
const DEFAULT_PATTERN: &str = r".*\.rs";
27
28#[derive(Error, Debug)]
29enum FormatDiffError {
30    #[error("{0}")]
31    IncorrectOptions(#[from] getopts::Fail),
32    #[error("{0}")]
33    IncorrectFilter(#[from] regex::Error),
34    #[error("{0}")]
35    IoError(#[from] io::Error),
36}
37
38#[derive(Parser, Debug)]
39#[command(
40    name = "rustfmt-format-diff",
41    disable_version_flag = true,
42    next_line_help = true
43)]
44pub struct Opts {
45    /// Skip the smallest prefix containing NUMBER slashes
46    #[arg(
47        short = 'p',
48        long = "skip-prefix",
49        value_name = "NUMBER",
50        default_value = "0"
51    )]
52    skip_prefix: u32,
53
54    /// Custom pattern selecting file paths to reformat
55    #[arg(
56        short = 'f',
57        long = "filter",
58        value_name = "PATTERN",
59        default_value = DEFAULT_PATTERN
60    )]
61    filter: String,
62}
63
64fn main() {
65    tracing_subscriber::fmt()
66        .with_env_filter(EnvFilter::from_env("RUSTFMT_LOG"))
67        .init();
68    let opts = Opts::parse();
69    if let Err(e) = run(opts) {
70        println!("{e}");
71        Opts::command()
72            .print_help()
73            .expect("cannot write to stdout");
74        process::exit(1);
75    }
76}
77
78#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
79struct Range {
80    file: String,
81    range: [u32; 2],
82}
83
84fn run(opts: Opts) -> Result<(), FormatDiffError> {
85    let (files, ranges) = scan_diff(io::stdin(), opts.skip_prefix, &opts.filter)?;
86    run_rustfmt(&files, &ranges)
87}
88
89fn run_rustfmt(files: &HashSet<String>, ranges: &[Range]) -> Result<(), FormatDiffError> {
90    if files.is_empty() || ranges.is_empty() {
91        debug!("No files to format found");
92        return Ok(());
93    }
94
95    let ranges_as_json = json::to_string(ranges).unwrap();
96
97    debug!("Files: {:?}", files);
98    debug!("Ranges: {:?}", ranges);
99
100    let rustfmt_var = env::var_os("RUSTFMT");
101    let rustfmt = match &rustfmt_var {
102        Some(rustfmt) => rustfmt,
103        None => OsStr::new("rustfmt"),
104    };
105    let exit_status = process::Command::new(rustfmt)
106        .args(files)
107        .arg("--file-lines")
108        .arg(ranges_as_json)
109        .status()?;
110
111    if !exit_status.success() {
112        return Err(FormatDiffError::IoError(io::Error::new(
113            io::ErrorKind::Other,
114            format!("rustfmt failed with {exit_status}"),
115        )));
116    }
117    Ok(())
118}
119
120/// Scans a diff from `from`, and returns the set of files found, and the ranges
121/// in those files.
122fn scan_diff<R>(
123    from: R,
124    skip_prefix: u32,
125    file_filter: &str,
126) -> Result<(HashSet<String>, Vec<Range>), FormatDiffError>
127where
128    R: io::Read,
129{
130    let diff_pattern = format!(r"^\+\+\+\s(?:.*?/){{{skip_prefix}}}(\S*)");
131    let diff_pattern = Regex::new(&diff_pattern).unwrap();
132
133    let lines_pattern = Regex::new(r"^@@.*\+(\d+)(,(\d+))?").unwrap();
134
135    let file_filter = Regex::new(&format!("^{file_filter}$"))?;
136
137    let mut current_file = None;
138
139    let mut files = HashSet::new();
140    let mut ranges = vec![];
141    for line in io::BufReader::new(from).lines() {
142        let line = line.unwrap();
143
144        if let Some(captures) = diff_pattern.captures(&line) {
145            current_file = Some(captures.get(1).unwrap().as_str().to_owned());
146        }
147
148        let file = match current_file {
149            Some(ref f) => &**f,
150            None => continue,
151        };
152
153        // FIXME(emilio): We could avoid this most of the time if needed, but
154        // it's not clear it's worth it.
155        if !file_filter.is_match(file) {
156            continue;
157        }
158
159        let lines_captures = match lines_pattern.captures(&line) {
160            Some(captures) => captures,
161            None => continue,
162        };
163
164        let start_line = lines_captures
165            .get(1)
166            .unwrap()
167            .as_str()
168            .parse::<u32>()
169            .unwrap();
170        let line_count = match lines_captures.get(3) {
171            Some(line_count) => line_count.as_str().parse::<u32>().unwrap(),
172            None => 1,
173        };
174
175        if line_count == 0 {
176            continue;
177        }
178
179        let end_line = start_line + line_count - 1;
180        files.insert(file.to_owned());
181        ranges.push(Range {
182            file: file.to_owned(),
183            range: [start_line, end_line],
184        });
185    }
186
187    Ok((files, ranges))
188}
189
/// Scans the bundled bindgen diff with one prefix component stripped and the
/// default-style `.rs` filter, then checks both the file set and the ranges.
#[test]
fn scan_simple_git_diff() {
    const DIFF: &str = include_str!("test/bindgen.diff");
    let (files, ranges) = scan_diff(DIFF.as_bytes(), 1, r".*\.rs").expect("scan_diff failed?");

    let has_rust_file = files.contains("src/ir/traversal.rs");
    assert!(has_rust_file, "Should've matched the filter");

    let has_header_file = files.contains("tests/headers/anon_enum.hpp");
    assert!(!has_header_file, "Shouldn't have matched the filter");

    // Expected hunks, in the order they appear in the diff.
    let expected: Vec<Range> = [
        ("src/ir/item.rs", [148, 158]),
        ("src/ir/item.rs", [160, 170]),
        ("src/ir/traversal.rs", [9, 16]),
        ("src/ir/traversal.rs", [35, 43]),
    ]
    .iter()
    .map(|&(file, range)| Range {
        file: file.to_owned(),
        range,
    })
    .collect();

    assert_eq!(ranges, expected);
}
227
/// Tests for the command-line interface defined by `Opts`.
#[cfg(test)]
mod cmd_line_tests {
    use super::*;

    /// Returns `true` when clap refuses to parse `args`; the first element
    /// plays the role of the program name.
    fn is_rejected(args: &[&str]) -> bool {
        Opts::command()
            .try_get_matches_from(args.iter().copied())
            .is_err()
    }

    #[test]
    fn default_options() {
        let opts = Opts::parse_from(Vec::<String>::new());
        assert_eq!(DEFAULT_PATTERN, opts.filter);
        assert_eq!(0, opts.skip_prefix);
    }

    #[test]
    fn good_options() {
        let args = ["test", "-p", "10", "-f", r".*\.hs"];
        let opts = Opts::parse_from(args);
        assert_eq!(r".*\.hs", opts.filter);
        assert_eq!(10, opts.skip_prefix);
    }

    #[test]
    fn unexpected_option() {
        assert!(is_rejected(&["test", "unexpected"]));
    }

    #[test]
    fn unexpected_flag() {
        assert!(is_rejected(&["test", "--flag"]));
    }

    #[test]
    fn overridden_option() {
        assert!(is_rejected(&["test", "-p", "10", "-p", "20"]));
    }

    #[test]
    fn negative_filter() {
        assert!(is_rejected(&["test", "-p", "-1"]));
    }
}