1use super::{BorrowedBuf, BufReader, BufWriter, DEFAULT_BUF_SIZE, Read, Result, Write};
2use crate::alloc::Allocator;
3use crate::cmp;
4use crate::collections::VecDeque;
5use crate::io::IoSlice;
6use crate::mem::MaybeUninit;
7
8#[cfg(test)]
9mod tests;
10
/// Copies data from `reader` into `writer`, yielding a `u64` on success.
///
/// This function is a platform dispatcher and contains no copy loop of its own:
///
/// * on Linux and Android it forwards to `crate::sys::kernel_copy::copy_spec`;
/// * on every other target it forwards to `generic_copy`.
///
/// On the `generic_copy` path the returned value is the number of bytes that
/// were moved from `reader` to `writer`.
/// NOTE(review): `copy_spec` presumably reports the same quantity — confirm in
/// `sys::kernel_copy`, which is not visible from this file.
///
/// # Errors
///
/// The result of the selected implementation is returned as-is, so any error it
/// produces is passed straight through to the caller.
#[stable(feature = "rust1", since = "1.0.0")]
pub fn copy<R: ?Sized, W: ?Sized>(reader: &mut R, writer: &mut W) -> Result<u64>
where
    R: Read,
    W: Write,
{
    cfg_select! {
        any(target_os = "linux", target_os = "android") => {
            crate::sys::kernel_copy::copy_spec(reader, writer)
        }
        _ => {
            generic_copy(reader, writer)
        }
    }
}
75
76pub(crate) fn generic_copy<R: ?Sized, W: ?Sized>(reader: &mut R, writer: &mut W) -> Result<u64>
79where
80    R: Read,
81    W: Write,
82{
83    let read_buf = BufferedReaderSpec::buffer_size(reader);
84    let write_buf = BufferedWriterSpec::buffer_size(writer);
85
86    if read_buf >= DEFAULT_BUF_SIZE && read_buf >= write_buf {
87        return BufferedReaderSpec::copy_to(reader, writer);
88    }
89
90    BufferedWriterSpec::copy_from(writer, reader)
91}
92
/// Reader-side hook consulted by `generic_copy` to decide whether the data
/// source itself should drive the copy.
trait BufferedReaderSpec {
    /// Size that `generic_copy` compares against `DEFAULT_BUF_SIZE` and against
    /// the writer's `buffer_size` when choosing between `copy_to` and the
    /// writer-side `copy_from`.
    fn buffer_size(&self) -> usize;

    /// Writes this reader's data into `to`.
    ///
    /// The non-default implementations in this file return the number of bytes
    /// they wrote; write errors are propagated to the caller.
    fn copy_to(&mut self, to: &mut (impl Write + ?Sized)) -> Result<u64>;
}
101
// Fallback for every `Read` type that has no dedicated implementation.
// Both methods are `default fn`, so more specific impls may override them.
impl<T> BufferedReaderSpec for T
where
    Self: Read,
    T: ?Sized,
{
    /// Reports a size of zero.
    ///
    /// `generic_copy` only picks the reader side when this value is at least
    /// `DEFAULT_BUF_SIZE`, so (assuming that constant is non-zero — it is
    /// defined outside this file) a plain reader is never selected.
    #[inline]
    default fn buffer_size(&self) -> usize {
        0
    }

    /// Never expected to run: panics via `unreachable!` if it is called.
    ///
    /// It exists only so the blanket impl satisfies the trait; callers are
    /// expected to reach `copy_to` solely through an overriding implementation.
    default fn copy_to(&mut self, _to: &mut (impl Write + ?Sized)) -> Result<u64> {
        unreachable!("only called from specializations")
    }
}
116
117impl BufferedReaderSpec for &[u8] {
118    fn buffer_size(&self) -> usize {
119        usize::MAX
122    }
123
124    fn copy_to(&mut self, to: &mut (impl Write + ?Sized)) -> Result<u64> {
125        let len = self.len();
126        to.write_all(self)?;
127        *self = &self[len..];
128        Ok(len as u64)
129    }
130}
131
132impl<A: Allocator> BufferedReaderSpec for VecDeque<u8, A> {
133    fn buffer_size(&self) -> usize {
134        usize::MAX
137    }
138
139    fn copy_to(&mut self, to: &mut (impl Write + ?Sized)) -> Result<u64> {
140        let len = self.len();
141        let (front, back) = self.as_slices();
142        let bufs = &mut [IoSlice::new(front), IoSlice::new(back)];
143        to.write_all_vectored(bufs)?;
144        self.clear();
145        Ok(len as u64)
146    }
147}
148
149impl<I> BufferedReaderSpec for BufReader<I>
150where
151    Self: Read,
152    I: ?Sized,
153{
154    fn buffer_size(&self) -> usize {
155        self.capacity()
156    }
157
158    fn copy_to(&mut self, to: &mut (impl Write + ?Sized)) -> Result<u64> {
159        let mut len = 0;
160
161        loop {
162            match self.read(&mut []) {
167                Ok(_) => {}
168                Err(e) if e.is_interrupted() => continue,
169                Err(e) => return Err(e),
170            }
171            let buf = self.buffer();
172            if self.buffer().len() == 0 {
173                return Ok(len);
174            }
175
176            to.write_all(buf)?;
182            len += buf.len() as u64;
183            self.discard_buffer();
184        }
185    }
186}
187
/// Writer-side hook consulted by `generic_copy`; used whenever the reader side
/// is not chosen to drive the copy.
trait BufferedWriterSpec: Write {
    /// Size that `generic_copy` compares the reader's `buffer_size` against.
    fn buffer_size(&self) -> usize;

    /// Pulls data from `reader` into this writer.
    ///
    /// The implementations in this file return the number of bytes read from
    /// `reader`; read and write errors are propagated to the caller.
    fn copy_from<R: Read + ?Sized>(&mut self, reader: &mut R) -> Result<u64>;
}
195
// Fallback for every `Write` type that has no dedicated implementation.
// Both methods are `default fn`, so more specific impls may override them.
impl<W: Write + ?Sized> BufferedWriterSpec for W {
    /// Reports a size of zero: a plain writer contributes no buffer of its own.
    #[inline]
    default fn buffer_size(&self) -> usize {
        0
    }

    /// Copies through a temporary buffer by delegating to `stack_buffer_copy`.
    default fn copy_from<R: Read + ?Sized>(&mut self, reader: &mut R) -> Result<u64> {
        stack_buffer_copy(reader, self)
    }
}
206
// Copies by letting the reader fill the unused tail of the `BufWriter`'s own
// buffer directly, flushing whenever that tail becomes too small.
impl<I: Write + ?Sized> BufferedWriterSpec for BufWriter<I> {
    /// Reports the capacity of the writer's internal buffer.
    fn buffer_size(&self) -> usize {
        self.capacity()
    }

    /// Reads from `reader` straight into this writer's internal buffer.
    ///
    /// Returns the total number of bytes read once `reader` produces a
    /// zero-length read. Note that this function does not flush on that path:
    /// bytes read since the last `flush_buf` are still held in the internal
    /// buffer when it returns.
    ///
    /// Interrupted reads are retried; any other read error, and any error from
    /// `flush_buf`, is returned immediately.
    fn copy_from<R: Read + ?Sized>(&mut self, reader: &mut R) -> Result<u64> {
        // A buffer smaller than the default gains nothing over the generic
        // temporary-buffer loop, so use that instead.
        if self.capacity() < DEFAULT_BUF_SIZE {
            return stack_buffer_copy(reader, self);
        }

        let mut len = 0;
        // Number of bytes at the start of the spare capacity that are carried
        // over as "already initialized" into the next iteration's `BorrowedBuf`.
        let mut init = 0;

        loop {
            let buf = self.buffer_mut();
            // View only the unused tail of the internal buffer as the read target.
            let mut read_buf: BorrowedBuf<'_> = buf.spare_capacity_mut().into();

            unsafe {
                // SAFETY: `init` starts at 0 and is afterwards only ever set from
                // `read_buf.init_len()` of an earlier iteration (minus the bytes
                // that were moved into `buf`'s length) or grown by `buf.len()`
                // immediately before a flush.
                // NOTE(review): the second case relies on `flush_buf` leaving the
                // flushed bytes in place as spare capacity — confirm against
                // `BufWriter::flush_buf`.
                read_buf.set_init(init);
            }

            if read_buf.capacity() >= DEFAULT_BUF_SIZE {
                let mut cursor = read_buf.unfilled();
                match reader.read_buf(cursor.reborrow()) {
                    Ok(()) => {
                        let bytes_read = cursor.written();

                        // Nothing was written into the cursor: treat as end of input.
                        if bytes_read == 0 {
                            return Ok(len);
                        }

                        // The bytes just read are about to become part of `buf`'s
                        // length, so they no longer count towards the spare
                        // capacity's initialized prefix.
                        init = read_buf.init_len() - bytes_read;
                        len += bytes_read as u64;

                        // SAFETY: the reader reported `bytes_read` bytes written at
                        // the start of the spare capacity, so the region covered by
                        // the new length has been filled.
                        // NOTE(review): `buf` is presumably a `Vec<u8>` — its type
                        // is declared by `buffer_mut`, outside this block.
                        unsafe { buf.set_len(buf.len() + bytes_read) };

                        // Fall through to the next loop iteration, which reads
                        // again if enough spare capacity remains.
                    }
                    // Interrupted: do nothing here and let the loop retry.
                    Err(ref e) if e.is_interrupted() => {}
                    Err(e) => return Err(e),
                }
            } else {
                // Count the bytes currently held in the buffer as initialized
                // for the next iteration before they are flushed.
                init += buf.len();

                self.flush_buf()?;
            }
        }
    }
}
261
262impl BufferedWriterSpec for Vec<u8> {
263    fn buffer_size(&self) -> usize {
264        cmp::max(DEFAULT_BUF_SIZE, self.capacity() - self.len())
265    }
266
267    fn copy_from<R: Read + ?Sized>(&mut self, reader: &mut R) -> Result<u64> {
268        reader.read_to_end(self).map(|bytes| u64::try_from(bytes).expect("usize overflowed u64"))
269    }
270}
271
272pub fn stack_buffer_copy<R: Read + ?Sized, W: Write + ?Sized>(
273    reader: &mut R,
274    writer: &mut W,
275) -> Result<u64> {
276    let buf: &mut [_] = &mut [MaybeUninit::uninit(); DEFAULT_BUF_SIZE];
277    let mut buf: BorrowedBuf<'_> = buf.into();
278
279    let mut len = 0;
280
281    loop {
282        match reader.read_buf(buf.unfilled()) {
283            Ok(()) => {}
284            Err(e) if e.is_interrupted() => continue,
285            Err(e) => return Err(e),
286        };
287
288        if buf.filled().is_empty() {
289            break;
290        }
291
292        len += buf.filled().len() as u64;
293        writer.write_all(buf.filled())?;
294        buf.clear();
295    }
296
297    Ok(len)
298}