1use std::io::{self, Read, Take};
2
/// Reader adapter that caps how many bytes may be pulled from an inner reader.
///
/// Unlike a bare [`Take`], running out of budget is surfaced to the caller as an
/// [`io::Error`] by this type's [`Read`] impl instead of looking like end-of-file.
#[derive(Debug)]
pub struct LimitErrorReader<R> {
    /// The wrapped reader; [`Take`] tracks how many bytes of budget remain.
    inner: Take<R>,
}

impl<R: Read> LimitErrorReader<R> {
    /// Wraps `r` so that no more than `limit` bytes can be read through it.
    pub fn new(r: R, limit: u64) -> LimitErrorReader<R> {
        let inner = Read::take(r, limit);
        Self { inner }
    }
}
15
16impl<R: Read> Read for LimitErrorReader<R> {
17 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
18 match self.inner.read(buf) {
19 Ok(0) if self.inner.limit() == 0 => Err(io::Error::new(
20 io::ErrorKind::Other,
21 "maximum limit reached when reading",
22 )),
23 e => e,
24 }
25 }
26}
27
#[cfg(test)]
mod tests {
    use super::LimitErrorReader;

    use std::io::{ErrorKind, Read};

    #[test]
    fn under_the_limit() {
        let buf = &[1; 7][..];
        let mut r = LimitErrorReader::new(buf, 8);
        let mut out = Vec::new();
        assert!(matches!(r.read_to_end(&mut out), Ok(7)));
        assert_eq!(buf, out.as_slice());
    }

    /// Drains `r` and asserts the read fails with the limit error.
    ///
    /// Checks the error kind and message directly, rather than relying on
    /// `#[should_panic]`, which would accept any panic containing the text.
    fn assert_limit_error<R: Read>(mut r: LimitErrorReader<R>) {
        let mut out = Vec::new();
        let err = r
            .read_to_end(&mut out)
            .expect_err("reading past the limit must fail");
        assert_eq!(err.kind(), ErrorKind::Other);
        assert_eq!(err.to_string(), "maximum limit reached when reading");
    }

    #[test]
    fn at_the_limit() {
        // Exactly `limit` bytes is rejected: the reader can't distinguish EOF at the
        // limit from further data without reading past it.
        assert_limit_error(LimitErrorReader::new(&[1u8; 8][..], 8));
    }

    #[test]
    fn over_the_limit() {
        assert_limit_error(LimitErrorReader::new(&[1u8; 9][..], 8));
    }

    #[test]
    fn zero_limit() {
        assert_limit_error(LimitErrorReader::new(&[0u8; 0][..], 0));
    }
}