1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
use std::{
    fmt::{self, Debug, Formatter},
    io::Read,
};

/// The error reported when trying to read more bytes than were available.
#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
#[error("Expected at least {expected} bytes, but only {actual} bytes were found")]
pub(crate) struct InvalidSize {
    pub(crate) expected: usize,
    pub(crate) actual: usize,
}

/// A cheaply cloneable, incremental reader over a borrowed byte buffer.
#[derive(Clone, Eq, PartialEq)]
pub(crate) struct Scanner<'buf> {
    /// Bytes that have not been consumed yet.
    rest: &'buf [u8],
    /// Position reported for the first byte of `rest`.
    current_position: usize,
}

impl<'buf> Scanner<'buf> {
    /// Create a [`Scanner`] over `bytes`, starting at position `0`.
    pub(crate) fn new(bytes: &'buf [u8]) -> Self {
        Scanner {
            rest: bytes,
            current_position: 0,
        }
    }

    /// Replace the position this [`Scanner`] reports (including in any
    /// [`InvalidSize`] errors it produces) without touching the un-scanned
    /// bytes.
    pub(crate) fn with_current_position(self, current_position: usize) -> Self {
        Scanner {
            current_position,
            ..self
        }
    }

    /// The position of the next un-scanned byte.
    pub(crate) fn current_position(&self) -> usize {
        self.current_position
    }

    /// The un-scanned bytes.
    pub(crate) fn rest(&self) -> &'buf [u8] {
        self.rest
    }

    /// Returns `true` when every byte has been scanned.
    pub(crate) fn is_empty(&self) -> bool {
        self.rest().is_empty()
    }

    /// Build the error for a failed attempt to read `len` more bytes.
    ///
    /// Saturating arithmetic is used because `len` may come from a length
    /// prefix stored in the buffer itself (e.g. via [`Scanner::read_usize`]);
    /// an absurdly large value must yield an error, not an overflow panic.
    fn invalid_size(&self, len: usize) -> InvalidSize {
        InvalidSize {
            expected: self.current_position.saturating_add(len),
            actual: self.current_position.saturating_add(self.rest.len()),
        }
    }

    /// Take a certain number of bytes from the start of the buffer, advancing
    /// the [`Scanner`] if the read was successful.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidSize`] if fewer than `len` bytes remain. The scanner
    /// is left unchanged in that case.
    pub(crate) fn take(&mut self, len: usize) -> Result<&'buf [u8], InvalidSize> {
        if self.rest.len() < len {
            return Err(self.invalid_size(len));
        }

        let (bytes, rest) = self.rest.split_at(len);
        self.rest = rest;
        self.current_position += len;

        Ok(bytes)
    }

    /// Split off the next `len` bytes into their own [`Scanner`], advancing
    /// the current [`Scanner`] past the bytes.
    ///
    /// The returned scanner starts at this scanner's current position.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidSize`] if fewer than `len` bytes remain. The scanner
    /// is left unchanged in that case.
    pub(crate) fn split_off(&mut self, len: usize) -> Result<Self, InvalidSize> {
        let current_position = self.current_position;
        let head = self.take(len)?;

        Ok(Scanner {
            rest: head,
            current_position,
        })
    }

    /// Get a copy of this [`Scanner`] which can only read up to `len` bytes.
    ///
    /// `self` is not advanced.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidSize`] if fewer than `len` bytes remain.
    pub(crate) fn truncated(&self, len: usize) -> Result<Self, InvalidSize> {
        self.clone().split_off(len)
    }

    /// Read an array from the buffer by value.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidSize`] if fewer than `LEN` bytes remain.
    pub(crate) fn read<const LEN: usize>(&mut self) -> Result<[u8; LEN], InvalidSize> {
        self.read_ref().copied()
    }

    /// Read a little-endian `u64` and convert it to a `usize`.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidSize`] if fewer than 8 bytes remain.
    ///
    /// # Panics
    ///
    /// Panics if the value does not fit in a `usize`, which can only happen
    /// on targets where `usize` is narrower than 64 bits.
    pub(crate) fn read_usize(&mut self) -> Result<usize, InvalidSize> {
        let value = u64::from_le_bytes(self.read()?);
        Ok(value
            .try_into()
            .expect("u64 value does not fit in usize on this target"))
    }

    /// Read an array from the buffer by reference.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidSize`] if fewer than `LEN` bytes remain.
    pub(crate) fn read_ref<const LEN: usize>(&mut self) -> Result<&'buf [u8; LEN], InvalidSize> {
        self.take(LEN)
            // `take` returned exactly `LEN` bytes, so the conversion cannot fail.
            .map(|bytes| bytes.try_into().expect("Already checked"))
    }
}

impl Read for Scanner<'_> {
    /// Copy as many un-scanned bytes as fit into `buf`, advancing the
    /// scanner past them.
    ///
    /// Returns `Ok(0)` once the scanner is exhausted or when `buf` is empty.
    /// This implementation never returns an error.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let bytes_read = std::cmp::min(self.rest().len(), buf.len());
        // `bytes_read` never exceeds the remaining length, so `take` cannot fail.
        let chunk = self.take(bytes_read).expect("unreachable");

        // `buf` may be longer than what is left; only fill its prefix, since
        // `copy_from_slice` panics when the two slice lengths differ.
        buf[..bytes_read].copy_from_slice(chunk);
        Ok(bytes_read)
    }
}

impl Debug for Scanner<'_> {
    /// Debug-print the scanner, showing at most the first 32 un-scanned
    /// bytes so large buffers stay readable.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let preview = TruncatedBuffer {
            buffer: self.rest,
            length: 32,
        };

        let mut builder = f.debug_struct("Scanner");
        builder.field("rest", &preview);
        builder.field("current_position", &self.current_position);
        builder.finish()
    }
}

/// Debug helper that prints at most `length` bytes of `buffer`.
struct TruncatedBuffer<'a> {
    /// The bytes to display.
    buffer: &'a [u8],
    /// Maximum number of bytes to print before eliding the remainder.
    length: usize,
}

impl Debug for TruncatedBuffer<'_> {
    /// Debug-print at most `length` bytes of `buffer`, appending `...` only
    /// when some bytes were actually left out.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let TruncatedBuffer { buffer, length } = *self;

        match buffer.get(..length) {
            // A prefix exists *and* is shorter than the whole buffer: elide.
            Some(shown) if shown.len() < buffer.len() => write!(f, "{shown:?}..."),
            // Buffer fits within `length` (including exactly `length`): print it all.
            _ => write!(f, "{buffer:?}"),
        }
    }
}