//! Parallel execution helpers (`rustc_data_structures/sync/parallel.rs`).

use std::any::Any;
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};

use parking_lot::Mutex;

use crate::FatalErrorMarker;
use crate::sync::{DynSend, DynSync, FromDyn, IntoDynSyncSend, mode};
/// Collects panics raised by a batch of parallel tasks so that every task
/// gets to run to completion before a single panic payload is re-thrown
/// (see [`parallel_guard`]).
pub struct ParallelGuard {
    /// The panic payload to resume, if any. Populated by [`ParallelGuard::run`];
    /// a regular panic is preferred over a `FatalErrorMarker` payload, which
    /// is only stored when the slot is still empty.
    panic: Mutex<Option<IntoDynSyncSend<Box<dyn Any + Send + 'static>>>>,
}
20
21impl ParallelGuard {
22 pub fn run<R>(&self, f: impl FnOnce() -> R) -> Option<R> {
23 catch_unwind(AssertUnwindSafe(f))
24 .map_err(|err| {
25 let mut panic = self.panic.lock();
26 if panic.is_none() || !(*err).is::<FatalErrorMarker>() {
27 *panic = Some(IntoDynSyncSend(err));
28 }
29 })
30 .ok()
31 }
32}
33
34#[inline]
37pub fn parallel_guard<R>(f: impl FnOnce(&ParallelGuard) -> R) -> R {
38 let guard = ParallelGuard { panic: Mutex::new(None) };
39 let ret = f(&guard);
40 if let Some(IntoDynSyncSend(panic)) = guard.panic.into_inner() {
41 resume_unwind(panic);
42 }
43 ret
44}
45
46fn serial_join<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
47where
48 A: FnOnce() -> RA,
49 B: FnOnce() -> RB,
50{
51 let (a, b) = parallel_guard(|guard| {
52 let a = guard.run(oper_a);
53 let b = guard.run(oper_b);
54 (a, b)
55 });
56 (a.unwrap(), b.unwrap())
57}
58
59pub fn spawn(func: impl FnOnce() + DynSend + 'static) {
60 if let Some(proof) = mode::check_dyn_thread_safe() {
61 let func = proof.derive(func);
62 rustc_thread_pool::spawn(|| {
63 (func.into_inner())();
64 });
65 } else {
66 func()
67 }
68}
69
/// Runs every closure in `funcs` under a [`ParallelGuard`]: in parallel on
/// the thread pool when dynamic thread safety is enabled, serially otherwise.
/// All closures get to run even if some panic; one captured panic is resumed
/// afterwards by `parallel_guard`.
pub fn par_fns(funcs: &mut [&mut (dyn FnMut() + DynSend)]) {
    parallel_guard(|guard: &ParallelGuard| {
        if let Some(proof) = mode::check_dyn_thread_safe() {
            let funcs = proof.derive(funcs);
            rustc_thread_pool::scope(|s| {
                // Keep the first closure for the current thread; an empty
                // slice has nothing to run.
                let Some((first, rest)) = funcs.into_inner().split_at_mut_checked(1) else {
                    return;
                };

                // NOTE(review): spawned in reverse — presumably so the pool
                // picks the tasks up in their original order; confirm against
                // `rustc_thread_pool::scope` scheduling.
                for f in rest.iter_mut().rev() {
                    let f = proof.derive(f);
                    s.spawn(|_| {
                        guard.run(|| (f.into_inner())());
                    });
                }

                // Run the first closure inline rather than paying to spawn it.
                guard.run(|| first[0]());
            });
        } else {
            // Single-threaded mode: just run each closure in order.
            for f in funcs {
                guard.run(|| f());
            }
        }
    });
}
104
/// Joins `oper_a` and `oper_b`: runs them in parallel on the thread pool when
/// dynamic thread safety is enabled, serially otherwise. Both closures always
/// run; a panic from either is resumed only after both have finished.
#[inline]
pub fn par_join<A, B, RA: DynSend, RB: DynSend>(oper_a: A, oper_b: B) -> (RA, RB)
where
    A: FnOnce() -> RA + DynSend,
    B: FnOnce() -> RB + DynSend,
{
    if let Some(proof) = mode::check_dyn_thread_safe() {
        // Wrap the closures — and, inside the joined tasks, their results —
        // so the `DynSend` bounds can cross threads.
        let oper_a = proof.derive(oper_a);
        let oper_b = proof.derive(oper_b);
        let (a, b) = parallel_guard(|guard| {
            rustc_thread_pool::join(
                move || guard.run(move || proof.derive(oper_a.into_inner()())),
                move || guard.run(move || proof.derive(oper_b.into_inner()())),
            )
        });
        // `parallel_guard` has already resumed any captured panic, so both
        // results must be `Some` here.
        (a.unwrap().into_inner(), b.unwrap().into_inner())
    } else {
        serial_join(oper_a, oper_b)
    }
}
125
/// Applies `for_each` to every element of `items` on the thread pool,
/// splitting the slice into at most `MAX_GROUP_COUNT` chunks. The first
/// chunk is processed on the current thread while the rest are spawned as
/// scope tasks. Panics are captured by `guard` instead of propagating.
fn par_slice<I: DynSend>(
    items: &mut [I],
    guard: &ParallelGuard,
    for_each: impl Fn(&mut I) + DynSync + DynSend,
    proof: FromDyn<()>,
) {
    // Fast paths: nothing to do, or a single element that is cheaper to
    // process inline than to set up a thread-pool scope for.
    match items {
        [] => return,
        [item] => {
            guard.run(|| for_each(item));
            return;
        }
        _ => (),
    }

    // Wrap the closure and the slice so their `DynSend`/`DynSync` bounds can
    // be treated as real `Send`/`Sync` inside the scope.
    let for_each = proof.derive(for_each);
    let mut items = for_each.derive(items);
    rustc_thread_pool::scope(|s| {
        let proof = items.derive(());

        // Bound the number of spawned tasks regardless of slice length.
        const MAX_GROUP_COUNT: usize = 128;
        let group_size = items.len().div_ceil(MAX_GROUP_COUNT);
        let mut groups = items.chunks_mut(group_size);

        let Some(first_group) = groups.next() else { return };

        // NOTE(review): remaining chunks are spawned in reverse — presumably
        // so the pool dequeues them in original order; confirm against
        // `rustc_thread_pool::scope` scheduling.
        for group in groups.rev() {
            let group = proof.derive(group);
            s.spawn(|_| {
                let mut group = group;
                for i in group.iter_mut() {
                    guard.run(|| for_each(i));
                }
            });
        }

        // Process the first chunk on the current thread while the spawned
        // tasks run.
        for i in first_group.iter_mut() {
            guard.run(|| for_each(i));
        }
    });
}
171
172pub fn par_for_each_in<I: DynSend, T: IntoIterator<Item = I>>(
173 t: T,
174 for_each: impl Fn(&I) + DynSync + DynSend,
175) {
176 parallel_guard(|guard| {
177 if let Some(proof) = mode::check_dyn_thread_safe() {
178 let mut items: Vec<_> = t.into_iter().collect();
179 par_slice(&mut items, guard, |i| for_each(&*i), proof)
180 } else {
181 t.into_iter().for_each(|i| {
182 guard.run(|| for_each(&i));
183 });
184 }
185 });
186}
187
/// Fallible variant of [`par_for_each_in`]: applies `for_each` to every item
/// of `t` and returns `Err` if any call failed. All items are visited even
/// after a failure.
///
/// NOTE(review): the two branches pick different errors when several calls
/// fail — the parallel branch keeps whichever error was written to the mutex
/// last, while the serial branch's `fold(Ok(()), Result::and)` keeps the
/// first. Callers should not rely on *which* error is returned.
pub fn try_par_for_each_in<T: IntoIterator, E: DynSend>(
    t: T,
    for_each: impl Fn(&<T as IntoIterator>::Item) -> Result<(), E> + DynSync + DynSend,
) -> Result<(), E>
where
    <T as IntoIterator>::Item: DynSend,
{
    parallel_guard(|guard| {
        if let Some(proof) = mode::check_dyn_thread_safe() {
            // Materialize the iterator so the items can be chunked across
            // threads by `par_slice`.
            let mut items: Vec<_> = t.into_iter().collect();

            // Shared slot for an error produced by any of the parallel calls.
            let error = Mutex::new(None);

            par_slice(
                &mut items,
                guard,
                |i| {
                    if let Err(err) = for_each(&*i) {
                        *error.lock() = Some(err);
                    }
                },
                proof,
            );

            if let Some(err) = error.into_inner() { Err(err) } else { Ok(()) }
        } else {
            // `guard.run` yields `None` for calls that panicked; the panic is
            // resumed by `parallel_guard` afterwards.
            t.into_iter().filter_map(|i| guard.run(|| for_each(&i))).fold(Ok(()), Result::and)
        }
    })
}
222
/// Maps every item of `t` through `map` — in parallel when dynamic thread
/// safety is enabled — and collects the results into `C`. Items whose `map`
/// call panicked produce no output element; the captured panic is resumed by
/// `parallel_guard` before the collected value could be observed anyway.
pub fn par_map<I: DynSend, T: IntoIterator<Item = I>, R: DynSend, C: FromIterator<R>>(
    t: T,
    map: impl Fn(I) -> R + DynSync + DynSend,
) -> C {
    parallel_guard(|guard| {
        if let Some(proof) = mode::check_dyn_thread_safe() {
            let map = proof.derive(map);

            // Pair each input with an output slot so `par_slice` (which only
            // hands out `&mut` references) can move the item out, apply `map`,
            // and store the result in place.
            let mut items: Vec<(Option<I>, Option<R>)> =
                t.into_iter().map(|i| (Some(i), None)).collect();

            par_slice(
                &mut items,
                guard,
                |i| {
                    i.1 = Some(map(i.0.take().unwrap()));
                },
                proof,
            );

            // Collect the filled output slots; slots left `None` belong to
            // calls that panicked.
            items.into_iter().filter_map(|i| i.1).collect()
        } else {
            // Serial fallback: `guard.run` yields `None` for panicked calls.
            t.into_iter().filter_map(|i| guard.run(|| map(i))).collect()
        }
    })
}
249
250pub fn broadcast<R: DynSend>(op: impl Fn(usize) -> R + DynSync) -> Vec<R> {
251 if let Some(proof) = mode::check_dyn_thread_safe() {
252 let op = proof.derive(op);
253 let results = rustc_thread_pool::broadcast(|context| op.derive(op(context.index())));
254 results.into_iter().map(|r| r.into_inner()).collect()
255 } else {
256 ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
[op(0)]))vec![op(0)]
257 }
258}