|
@@ -17,37 +17,11 @@ use crate::Result;
|
|
|
#[derive(Debug)]
|
|
|
pub struct PacketAssembler<'a> {
|
|
|
buffer: ManagedSlice<'a, u8>,
|
|
|
- assembler: AssemblerState,
|
|
|
-}
|
|
|
-
|
|
|
-/// Holds the state of the assembling of one packet.
|
|
|
-#[cfg_attr(feature = "defmt", derive(defmt::Format))]
|
|
|
-#[derive(Debug, PartialEq)]
|
|
|
-enum AssemblerState {
|
|
|
- NotInit,
|
|
|
- Assembling {
|
|
|
- assembler: Assembler,
|
|
|
- total_size: Option<usize>,
|
|
|
- expires_at: Instant,
|
|
|
- offset_correction: isize,
|
|
|
- },
|
|
|
-}
|
|
|
|
|
|
-impl fmt::Display for AssemblerState {
|
|
|
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
|
- match self {
|
|
|
- AssemblerState::NotInit => write!(f, "Not init")?,
|
|
|
- AssemblerState::Assembling {
|
|
|
- assembler,
|
|
|
- total_size,
|
|
|
- expires_at,
|
|
|
- offset_correction,
|
|
|
- } => {
|
|
|
- write!(f, "{assembler} expires at {expires_at}")?;
|
|
|
- }
|
|
|
- }
|
|
|
- Ok(())
|
|
|
- }
|
|
|
+ assembler: Assembler,
|
|
|
+ total_size: Option<usize>,
|
|
|
+ expires_at: Instant,
|
|
|
+ offset_correction: isize,
|
|
|
}
|
|
|
|
|
|
impl<'a> PacketAssembler<'a> {
|
|
@@ -59,10 +33,21 @@ impl<'a> PacketAssembler<'a> {
|
|
|
let s = storage.into();
|
|
|
PacketAssembler {
|
|
|
buffer: s,
|
|
|
- assembler: AssemblerState::NotInit,
|
|
|
+
|
|
|
+ assembler: Assembler::new(),
|
|
|
+ total_size: None,
|
|
|
+ expires_at: Instant::ZERO,
|
|
|
+ offset_correction: 0,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ pub(crate) fn reset(&mut self) {
|
|
|
+ self.assembler = Assembler::new();
|
|
|
+ self.total_size = None;
|
|
|
+ self.expires_at = Instant::ZERO;
|
|
|
+ self.offset_correction = 0;
|
|
|
+ }
|
|
|
+
|
|
|
/// Start with saving fragments.
|
|
|
/// We initialize the assembler with the total size of the final packet.
|
|
|
///
|
|
@@ -76,143 +61,102 @@ impl<'a> PacketAssembler<'a> {
|
|
|
expires_at: Instant,
|
|
|
offset_correction: isize,
|
|
|
) -> Result<()> {
|
|
|
- match &mut self.buffer {
|
|
|
- ManagedSlice::Borrowed(b) => {
|
|
|
- if let Some(total_size) = total_size {
|
|
|
- if b.len() < total_size {
|
|
|
- return Err(Error::PacketAssemblerBufferTooSmall);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- #[cfg(feature = "alloc")]
|
|
|
- ManagedSlice::Owned(b) => {
|
|
|
- if let Some(total_size) = total_size {
|
|
|
- b.resize(total_size, 0);
|
|
|
- }
|
|
|
- }
|
|
|
+ self.reset();
|
|
|
+ if let Some(total_size) = total_size {
|
|
|
+ self.set_total_size(total_size)?;
|
|
|
}
|
|
|
-
|
|
|
- self.assembler = AssemblerState::Assembling {
|
|
|
- assembler: Assembler::new(),
|
|
|
- total_size,
|
|
|
- expires_at,
|
|
|
- offset_correction,
|
|
|
- };
|
|
|
-
|
|
|
+ self.expires_at = expires_at;
|
|
|
+ self.offset_correction = offset_correction;
|
|
|
Ok(())
|
|
|
}
|
|
|
|
|
|
/// Set the total size of the packet assembler.
|
|
|
- ///
|
|
|
- /// # Errors
|
|
|
- ///
|
|
|
- /// - Returns [`Error::PacketAssemblerNotInit`] when the assembler was not initialized (try initializing the
|
|
|
- /// assembler with [Self::start]).
|
|
|
pub(crate) fn set_total_size(&mut self, size: usize) -> Result<()> {
|
|
|
- match self.assembler {
|
|
|
- AssemblerState::NotInit => Err(Error::PacketAssemblerNotInit),
|
|
|
- AssemblerState::Assembling {
|
|
|
- ref mut total_size, ..
|
|
|
- } => {
|
|
|
- *total_size = Some(size);
|
|
|
- Ok(())
|
|
|
+ if let Some(old_size) = self.total_size {
|
|
|
+ if old_size != size {
|
|
|
+ return Err(Error::Malformed);
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ match &mut self.buffer {
|
|
|
+ ManagedSlice::Borrowed(b) => {
|
|
|
+ if b.len() < size {
|
|
|
+ return Err(Error::PacketAssemblerBufferTooSmall);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ #[cfg(feature = "alloc")]
|
|
|
+ ManagedSlice::Owned(b) => b.resize(size, 0),
|
|
|
+ }
|
|
|
+
|
|
|
+ self.total_size = Some(size);
|
|
|
+ Ok(())
|
|
|
}
|
|
|
|
|
|
/// Return the instant when the assembler expires.
|
|
|
- ///
|
|
|
- /// # Errors
|
|
|
- ///
|
|
|
- /// - Returns [`Error::PacketAssemblerNotInit`] when the assembler was not initialized (try initializing the
|
|
|
- /// assembler with [Self::start]).
|
|
|
pub(crate) fn expires_at(&self) -> Result<Instant> {
|
|
|
- match self.assembler {
|
|
|
- AssemblerState::NotInit => Err(Error::PacketAssemblerNotInit),
|
|
|
- AssemblerState::Assembling { expires_at, .. } => Ok(expires_at),
|
|
|
- }
|
|
|
+ Ok(self.expires_at)
|
|
|
}
|
|
|
|
|
|
/// Add a fragment into the packet that is being reassembled.
|
|
|
///
|
|
|
/// # Errors
|
|
|
///
|
|
|
- /// - Returns [`Error::PacketAssemblerNotInit`] when the assembler was not initialized (try initializing the
|
|
|
- /// assembler with [Self::start]).
|
|
|
/// - Returns [`Error::PacketAssemblerBufferTooSmall`] when trying to add data into the buffer at a non-existing
|
|
|
/// place.
|
|
|
/// - Returns [`Error::PacketAssemblerOverlap`] when there was an overlap when adding data.
|
|
|
pub(crate) fn add(&mut self, data: &[u8], offset: usize) -> Result<bool> {
|
|
|
- match self.assembler {
|
|
|
- AssemblerState::NotInit => Err(Error::PacketAssemblerNotInit),
|
|
|
- AssemblerState::Assembling {
|
|
|
- ref mut assembler,
|
|
|
- total_size,
|
|
|
- offset_correction,
|
|
|
- ..
|
|
|
- } => {
|
|
|
- let offset = offset as isize + offset_correction;
|
|
|
- let offset = if offset <= 0 { 0 } else { offset as usize };
|
|
|
-
|
|
|
- match &mut self.buffer {
|
|
|
- ManagedSlice::Borrowed(b) => {
|
|
|
- if offset + data.len() > b.len() {
|
|
|
- return Err(Error::PacketAssemblerBufferTooSmall);
|
|
|
- }
|
|
|
- }
|
|
|
- #[cfg(feature = "alloc")]
|
|
|
- ManagedSlice::Owned(b) => {
|
|
|
- if offset + data.len() > b.len() {
|
|
|
- b.resize(offset + data.len(), 0);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
+ let offset = offset as isize + self.offset_correction;
|
|
|
+ let offset = if offset <= 0 { 0 } else { offset as usize };
|
|
|
|
|
|
- let len = data.len();
|
|
|
- self.buffer[offset..][..len].copy_from_slice(data);
|
|
|
-
|
|
|
- net_debug!(
|
|
|
- "frag assembler: receiving {} octests at offset {}",
|
|
|
- len,
|
|
|
- offset
|
|
|
- );
|
|
|
-
|
|
|
- match assembler.add(offset, data.len()) {
|
|
|
- Ok(()) => {
|
|
|
- net_debug!("assembler: {}", self.assembler);
|
|
|
- self.is_complete()
|
|
|
- }
|
|
|
- // NOTE(thvdveld): hopefully we wont get too many holes errors I guess?
|
|
|
- Err(_) => Err(Error::PacketAssemblerTooManyHoles),
|
|
|
+ match &mut self.buffer {
|
|
|
+ ManagedSlice::Borrowed(b) => {
|
|
|
+ if offset + data.len() > b.len() {
|
|
|
+ return Err(Error::PacketAssemblerBufferTooSmall);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ #[cfg(feature = "alloc")]
|
|
|
+ ManagedSlice::Owned(b) => {
|
|
|
+ if offset + data.len() > b.len() {
|
|
|
+ b.resize(offset + data.len(), 0);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ let len = data.len();
|
|
|
+ self.buffer[offset..][..len].copy_from_slice(data);
|
|
|
+
|
|
|
+ net_debug!(
|
|
|
+ "frag assembler: receiving {} octests at offset {}",
|
|
|
+ len,
|
|
|
+ offset
|
|
|
+ );
|
|
|
+
|
|
|
+ match self.assembler.add(offset, data.len()) {
|
|
|
+ Ok(()) => {
|
|
|
+ net_debug!("assembler: {}", self.assembler);
|
|
|
+ self.is_complete()
|
|
|
+ }
|
|
|
+ // NOTE(thvdveld): hopefully we wont get too many holes errors I guess?
|
|
|
+ Err(_) => Err(Error::PacketAssemblerTooManyHoles),
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/// Get an immutable slice of the underlying packet data.
|
|
|
- /// This will mark the assembler state as [`AssemblerState::NotInit`] such that it can be reused.
|
|
|
+ /// This will mark the assembler as empty, so that it can be reused.
|
|
|
///
|
|
|
/// # Errors
|
|
|
///
|
|
|
- /// - Returns [`Error::PacketAssemblerNotInit`] when the assembler was not initialized (try initializing the
|
|
|
- /// assembler with [`Self::start`]).
|
|
|
/// - Returns [`Error::PacketAssemblerIncomplete`] when not all the fragments have been collected.
|
|
|
pub(crate) fn assemble(&mut self) -> Result<&'_ [u8]> {
|
|
|
- let b = match self.assembler {
|
|
|
- AssemblerState::NotInit => return Err(Error::PacketAssemblerNotInit),
|
|
|
- AssemblerState::Assembling { total_size, .. } => {
|
|
|
- if self.is_complete()? {
|
|
|
- // NOTE: we can unwrap because `is_complete` already checks this.
|
|
|
- let total_size = total_size.unwrap();
|
|
|
- let a = &self.buffer[..total_size];
|
|
|
- self.assembler = AssemblerState::NotInit;
|
|
|
- a
|
|
|
- } else {
|
|
|
- return Err(Error::PacketAssemblerIncomplete);
|
|
|
- }
|
|
|
- }
|
|
|
- };
|
|
|
- Ok(b)
|
|
|
+ if !self.is_complete()? {
|
|
|
+ return Err(Error::PacketAssemblerIncomplete);
|
|
|
+ }
|
|
|
+
|
|
|
+ // NOTE: we can unwrap because `is_complete` already checks this.
|
|
|
+ let total_size = self.total_size.unwrap();
|
|
|
+ let a = &self.buffer[..total_size];
|
|
|
+
|
|
|
+ Ok(a)
|
|
|
}
|
|
|
|
|
|
/// Returns `true` when all fragments have been received, otherwise `false`.
|
|
@@ -222,33 +166,15 @@ impl<'a> PacketAssembler<'a> {
|
|
|
/// - Returns [`Error::PacketAssemblerNotInit`] when the assembler was not initialized (try initializing the
|
|
|
/// assembler with [`Self::start`]).
|
|
|
pub(crate) fn is_complete(&self) -> Result<bool> {
|
|
|
- match &self.assembler {
|
|
|
- AssemblerState::NotInit => Err(Error::PacketAssemblerNotInit),
|
|
|
- AssemblerState::Assembling {
|
|
|
- assembler,
|
|
|
- total_size,
|
|
|
- ..
|
|
|
- } => match (total_size, assembler.peek_front()) {
|
|
|
- (Some(total_size), front) => Ok(front == *total_size),
|
|
|
- _ => Ok(false),
|
|
|
- },
|
|
|
+ match (self.total_size, self.assembler.peek_front()) {
|
|
|
+ (Some(total_size), front) => Ok(front == total_size),
|
|
|
+ _ => Ok(false),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- /// Returns `true` when the packet assembler is free to use.
|
|
|
- fn is_free(&self) -> bool {
|
|
|
- self.assembler == AssemblerState::NotInit
|
|
|
- }
|
|
|
-
|
|
|
- /// Mark this assembler as [`AssemblerState::NotInit`].
|
|
|
- /// This is then cleaned up by the [`PacketAssemblerSet`].
|
|
|
- pub fn mark_discarded(&mut self) {
|
|
|
- self.assembler = AssemblerState::NotInit;
|
|
|
- }
|
|
|
-
|
|
|
- /// Returns `true` when the [`AssemblerState`] is discarded.
|
|
|
- pub fn is_discarded(&self) -> bool {
|
|
|
- matches!(self.assembler, AssemblerState::NotInit)
|
|
|
+ /// Returns `true` when the packet assembler is empty (free to use).
|
|
|
+ fn is_empty(&self) -> bool {
|
|
|
+ self.assembler.is_empty()
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -349,7 +275,7 @@ impl<'a, K: Eq + Ord + Clone + Copy> PacketAssemblerSet<'a, K> {
|
|
|
self.packet_buffer
|
|
|
.iter()
|
|
|
.enumerate()
|
|
|
- .find(|(_, b)| b.is_free())
|
|
|
+ .find(|(_, b)| b.is_empty())
|
|
|
.map(|(i, _)| i)
|
|
|
}
|
|
|
|
|
@@ -388,7 +314,7 @@ impl<'a, K: Eq + Ord + Clone + Copy> PacketAssemblerSet<'a, K> {
|
|
|
loop {
|
|
|
let mut key = None;
|
|
|
for (k, i) in self.index_buffer.iter() {
|
|
|
- if matches!(self.packet_buffer[*i].assembler, AssemblerState::NotInit) {
|
|
|
+ if self.packet_buffer[*i].is_empty() {
|
|
|
key = Some(*k);
|
|
|
break;
|
|
|
}
|
|
@@ -411,7 +337,7 @@ impl<'a, K: Eq + Ord + Clone + Copy> PacketAssemblerSet<'a, K> {
|
|
|
for (_, i) in &mut self.index_buffer.iter() {
|
|
|
let frag = &mut self.packet_buffer[*i];
|
|
|
if f(frag)? {
|
|
|
- frag.mark_discarded();
|
|
|
+ frag.reset();
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -439,22 +365,6 @@ mod tests {
|
|
|
id: usize,
|
|
|
}
|
|
|
|
|
|
- #[test]
|
|
|
- fn packet_assembler_not_init() {
|
|
|
- let mut p_assembler = PacketAssembler::new(vec![]);
|
|
|
- let data = b"Hello World!";
|
|
|
- assert_eq!(
|
|
|
- p_assembler.add(&data[..], data.len()),
|
|
|
- Err(Error::PacketAssemblerNotInit)
|
|
|
- );
|
|
|
-
|
|
|
- assert_eq!(
|
|
|
- p_assembler.is_complete(),
|
|
|
- Err(Error::PacketAssemblerNotInit)
|
|
|
- );
|
|
|
- assert_eq!(p_assembler.assemble(), Err(Error::PacketAssemblerNotInit));
|
|
|
- }
|
|
|
-
|
|
|
#[test]
|
|
|
fn packet_assembler_buffer_too_small() {
|
|
|
let mut storage = [0u8; 1];
|