1#![deny(warnings)]
6
7use serde::{Deserialize, Serialize};
8use serde_json as json;
9use thiserror::Error;
10use tracing::debug;
11use tracing_subscriber::EnvFilter;
12
13use std::collections::HashSet;
14use std::env;
15use std::ffi::OsStr;
16use std::io::{self, BufRead};
17use std::process;
18
19use regex::Regex;
20
21use clap::{CommandFactory, Parser};
22
/// Regex used for the `--filter` option when none is given on the command line.
/// `scan_diff` anchors it with `^`/`$`, so it selects paths that end in `.rs`.
const DEFAULT_PATTERN: &str = r".*\.rs";
27
28#[derive(Error, Debug)]
29enum FormatDiffError {
30 #[error("{0}")]
31 IncorrectOptions(#[from] getopts::Fail),
32 #[error("{0}")]
33 IncorrectFilter(#[from] regex::Error),
34 #[error("{0}")]
35 IoError(#[from] io::Error),
36}
37
38#[derive(Parser, Debug)]
39#[command(
40 name = "rustfmt-format-diff",
41 disable_version_flag = true,
42 next_line_help = true
43)]
44pub struct Opts {
45 #[arg(
47 short = 'p',
48 long = "skip-prefix",
49 value_name = "NUMBER",
50 default_value = "0"
51 )]
52 skip_prefix: u32,
53
54 #[arg(
56 short = 'f',
57 long = "filter",
58 value_name = "PATTERN",
59 default_value = DEFAULT_PATTERN
60 )]
61 filter: String,
62}
63
64fn main() {
65 tracing_subscriber::fmt()
66 .with_env_filter(EnvFilter::from_env("RUSTFMT_LOG"))
67 .init();
68 let opts = Opts::parse();
69 if let Err(e) = run(opts) {
70 println!("{e}");
71 Opts::command()
72 .print_help()
73 .expect("cannot write to stdout");
74 process::exit(1);
75 }
76}
77
78#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
79struct Range {
80 file: String,
81 range: [u32; 2],
82}
83
84fn run(opts: Opts) -> Result<(), FormatDiffError> {
85 let (files, ranges) = scan_diff(io::stdin(), opts.skip_prefix, &opts.filter)?;
86 run_rustfmt(&files, &ranges)
87}
88
89fn run_rustfmt(files: &HashSet<String>, ranges: &[Range]) -> Result<(), FormatDiffError> {
90 if files.is_empty() || ranges.is_empty() {
91 debug!("No files to format found");
92 return Ok(());
93 }
94
95 let ranges_as_json = json::to_string(ranges).unwrap();
96
97 debug!("Files: {:?}", files);
98 debug!("Ranges: {:?}", ranges);
99
100 let rustfmt_var = env::var_os("RUSTFMT");
101 let rustfmt = match &rustfmt_var {
102 Some(rustfmt) => rustfmt,
103 None => OsStr::new("rustfmt"),
104 };
105 let exit_status = process::Command::new(rustfmt)
106 .args(files)
107 .arg("--file-lines")
108 .arg(ranges_as_json)
109 .status()?;
110
111 if !exit_status.success() {
112 return Err(FormatDiffError::IoError(io::Error::new(
113 io::ErrorKind::Other,
114 format!("rustfmt failed with {exit_status}"),
115 )));
116 }
117 Ok(())
118}
119
120fn scan_diff<R>(
123 from: R,
124 skip_prefix: u32,
125 file_filter: &str,
126) -> Result<(HashSet<String>, Vec<Range>), FormatDiffError>
127where
128 R: io::Read,
129{
130 let diff_pattern = format!(r"^\+\+\+\s(?:.*?/){{{skip_prefix}}}(\S*)");
131 let diff_pattern = Regex::new(&diff_pattern).unwrap();
132
133 let lines_pattern = Regex::new(r"^@@.*\+(\d+)(,(\d+))?").unwrap();
134
135 let file_filter = Regex::new(&format!("^{file_filter}$"))?;
136
137 let mut current_file = None;
138
139 let mut files = HashSet::new();
140 let mut ranges = vec![];
141 for line in io::BufReader::new(from).lines() {
142 let line = line.unwrap();
143
144 if let Some(captures) = diff_pattern.captures(&line) {
145 current_file = Some(captures.get(1).unwrap().as_str().to_owned());
146 }
147
148 let file = match current_file {
149 Some(ref f) => &**f,
150 None => continue,
151 };
152
153 if !file_filter.is_match(file) {
156 continue;
157 }
158
159 let lines_captures = match lines_pattern.captures(&line) {
160 Some(captures) => captures,
161 None => continue,
162 };
163
164 let start_line = lines_captures
165 .get(1)
166 .unwrap()
167 .as_str()
168 .parse::<u32>()
169 .unwrap();
170 let line_count = match lines_captures.get(3) {
171 Some(line_count) => line_count.as_str().parse::<u32>().unwrap(),
172 None => 1,
173 };
174
175 if line_count == 0 {
176 continue;
177 }
178
179 let end_line = start_line + line_count - 1;
180 files.insert(file.to_owned());
181 ranges.push(Range {
182 file: file.to_owned(),
183 range: [start_line, end_line],
184 });
185 }
186
187 Ok((files, ranges))
188}
189
/// Runs `scan_diff` over the bundled bindgen diff and checks both the filter and the
/// extracted line ranges.
#[test]
fn scan_simple_git_diff() {
    const DIFF: &str = include_str!("test/bindgen.diff");
    let (files, ranges) = scan_diff(DIFF.as_bytes(), 1, r".*\.rs").expect("scan_diff failed?");

    let matched = files.contains("src/ir/traversal.rs");
    assert!(matched, "Should've matched the filter");

    let leaked = files.contains("tests/headers/anon_enum.hpp");
    assert!(!leaked, "Shouldn't have matched the filter");

    // (file, [first line, last line]) for every hunk, in the order the diff lists them.
    let expected: Vec<Range> = [
        ("src/ir/item.rs", [148, 158]),
        ("src/ir/item.rs", [160, 170]),
        ("src/ir/traversal.rs", [9, 16]),
        ("src/ir/traversal.rs", [35, 43]),
    ]
    .iter()
    .map(|&(file, range)| Range {
        file: file.to_owned(),
        range,
    })
    .collect();

    assert_eq!(ranges, expected);
}
227
/// Tests for the clap-derived command-line interface of `Opts`.
#[cfg(test)]
mod cmd_line_tests {
    use super::*;

    /// Reports whether clap refuses `args` (the first element is the program name).
    fn is_rejected(args: &[&str]) -> bool {
        Opts::command()
            .try_get_matches_from(args.iter().copied())
            .is_err()
    }

    #[test]
    fn default_options() {
        let o = Opts::parse_from(Vec::<String>::new());
        assert_eq!(DEFAULT_PATTERN, o.filter);
        assert_eq!(0, o.skip_prefix);
    }

    #[test]
    fn good_options() {
        let args = ["test", "-p", "10", "-f", r".*\.hs"];
        let o = Opts::parse_from(args);
        assert_eq!(r".*\.hs", o.filter);
        assert_eq!(10, o.skip_prefix);
    }

    #[test]
    fn unexpected_option() {
        assert!(is_rejected(&["test", "unexpected"]));
    }

    #[test]
    fn unexpected_flag() {
        assert!(is_rejected(&["test", "--flag"]));
    }

    #[test]
    fn overridden_option() {
        assert!(is_rejected(&["test", "-p", "10", "-p", "20"]));
    }

    #[test]
    fn negative_filter() {
        assert!(is_rejected(&["test", "-p", "-1"]));
    }
}