#![allow(unused)] use core::fmt; use managed::{ManagedMap, ManagedSlice}; use crate::storage::Assembler; use crate::time::{Duration, Instant}; // TODO: make configurable. const BUFFER_SIZE: usize = 1500; #[cfg(feature = "alloc")] type Buffer = alloc::vec::Vec; #[cfg(not(feature = "alloc"))] type Buffer = [u8; BUFFER_SIZE]; const PACKET_ASSEMBLER_COUNT: usize = 4; /// Problem when assembling: something was out of bounds. #[derive(Copy, Clone, PartialEq, Eq, Debug)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct AssemblerError; /// Packet assembler is full #[derive(Copy, Clone, PartialEq, Eq, Debug)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct AssemblerFullError; /// Holds different fragments of one packet, used for assembling fragmented packets. /// /// The buffer used for the `PacketAssembler` should either be dynamically sized (ex: Vec) /// or should be statically allocated based upon the MTU of the type of packet being /// assembled (ex: 1280 for a IPv6 frame). #[derive(Debug)] pub struct PacketAssembler { key: Option, buffer: Buffer, assembler: Assembler, total_size: Option, expires_at: Instant, offset_correction: isize, } impl PacketAssembler { /// Create a new empty buffer for fragments. pub fn new() -> Self { Self { key: None, #[cfg(feature = "alloc")] buffer: Buffer::new(), #[cfg(not(feature = "alloc"))] buffer: [0u8; BUFFER_SIZE], assembler: Assembler::new(), total_size: None, expires_at: Instant::ZERO, offset_correction: 0, } } pub(crate) fn reset(&mut self) { self.key = None; self.assembler.clear(); self.total_size = None; self.expires_at = Instant::ZERO; self.offset_correction = 0; } pub(crate) fn set_offset_correction(&mut self, correction: isize) { self.offset_correction = correction; } /// Set the total size of the packet assembler. pub(crate) fn set_total_size(&mut self, size: usize) -> Result<(), AssemblerError> { if let Some(old_size) = self.total_size { if old_size != size { return Err(AssemblerError); } } #[cfg(not(feature = "alloc"))] if self.buffer.len() < size { return Err(AssemblerError); } #[cfg(feature = "alloc")] if self.buffer.len() < size { self.buffer.resize(size, 0); } self.total_size = Some(size); Ok(()) } /// Return the instant when the assembler expires. pub(crate) fn expires_at(&self) -> Instant { self.expires_at } pub(crate) fn add_with( &mut self, offset: usize, f: impl Fn(&mut [u8]) -> Result, ) -> Result<(), AssemblerError> { if self.buffer.len() < offset { return Err(AssemblerError); } let len = f(&mut self.buffer[offset..])?; assert!(offset + len <= self.buffer.len()); net_debug!( "frag assembler: receiving {} octets at offset {}", len, offset ); match self.assembler.add(offset, len) { Ok(()) => { net_debug!("assembler: {}", self.assembler); Ok(()) } Err(_) => { net_debug!("packet assembler: too many holes, dropping."); Err(AssemblerError) } } } /// Add a fragment into the packet that is being reassembled. /// /// # Errors /// /// - Returns [`Error::PacketAssemblerBufferTooSmall`] when trying to add data into the buffer at a non-existing /// place. pub(crate) fn add(&mut self, data: &[u8], offset: usize) -> Result<(), AssemblerError> { let offset = offset as isize + self.offset_correction; let offset = if offset <= 0 { 0 } else { offset as usize }; #[cfg(not(feature = "alloc"))] if self.buffer.len() < offset + data.len() { return Err(AssemblerError); } #[cfg(feature = "alloc")] if self.buffer.len() < offset + data.len() { self.buffer.resize(offset + data.len(), 0); } let len = data.len(); self.buffer[offset..][..len].copy_from_slice(data); net_debug!( "frag assembler: receiving {} octets at offset {}", len, offset ); match self.assembler.add(offset, data.len()) { Ok(()) => { net_debug!("assembler: {}", self.assembler); Ok(()) } Err(_) => { net_debug!("packet assembler: too many holes, dropping."); Err(AssemblerError) } } } /// Get an immutable slice of the underlying packet data, if reassembly complete. /// This will mark the assembler as empty, so that it can be reused. pub(crate) fn assemble(&mut self) -> Option<&'_ [u8]> { if !self.is_complete() { return None; } // NOTE: we can unwrap because `is_complete` already checks this. let total_size = self.total_size.unwrap(); self.reset(); Some(&self.buffer[..total_size]) } /// Returns `true` when all fragments have been received, otherwise `false`. pub(crate) fn is_complete(&self) -> bool { self.total_size == Some(self.assembler.peek_front()) } /// Returns `true` when the packet assembler is free to use. fn is_free(&self) -> bool { self.key.is_none() } } /// Set holding multiple [`PacketAssembler`]. #[derive(Debug)] pub struct PacketAssemblerSet { assemblers: [PacketAssembler; PACKET_ASSEMBLER_COUNT], } impl PacketAssemblerSet { /// Create a new set of packet assemblers. pub fn new() -> Self { Self { // TODO: support any PACKET_ASSEMBLER_COUNT assemblers: [ PacketAssembler::new(), PacketAssembler::new(), PacketAssembler::new(), PacketAssembler::new(), ], } } /// Get a [`PacketAssembler`] for a specific key. /// /// If it doesn't exist, it is created, with the `expires_at` timestamp. /// /// If the assembler set is full, in which case an error is returned. pub(crate) fn get( &mut self, key: &K, expires_at: Instant, ) -> Result<&mut PacketAssembler, AssemblerFullError> { let mut empty_slot = None; for slot in &mut self.assemblers { if slot.key.as_ref() == Some(key) { return Ok(slot); } if slot.is_free() { empty_slot = Some(slot) } } let slot = empty_slot.ok_or(AssemblerFullError)?; slot.key = Some(*key); slot.expires_at = expires_at; Ok(slot) } /// Remove all [`PacketAssembler`]s that are expired. pub fn remove_expired(&mut self, timestamp: Instant) { for frag in &mut self.assemblers { if !frag.is_free() && frag.expires_at < timestamp { frag.reset(); } } } } #[cfg(test)] mod tests { use super::*; #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] struct Key { id: usize, } #[test] fn packet_assembler_overlap() { let mut p_assembler = PacketAssembler::::new(); p_assembler.set_total_size(5).unwrap(); let data = b"Rust"; p_assembler.add(&data[..], 0); p_assembler.add(&data[..], 1); assert_eq!(p_assembler.assemble(), Some(&b"RRust"[..])) } #[test] fn packet_assembler_assemble() { let mut p_assembler = PacketAssembler::::new(); let data = b"Hello World!"; p_assembler.set_total_size(data.len()).unwrap(); p_assembler.add(b"Hello ", 0).unwrap(); assert_eq!(p_assembler.assemble(), None); p_assembler.add(b"World!", b"Hello ".len()).unwrap(); assert_eq!(p_assembler.assemble(), Some(&b"Hello World!"[..])); } #[test] fn packet_assembler_out_of_order_assemble() { let mut p_assembler = PacketAssembler::::new(); let data = b"Hello World!"; p_assembler.set_total_size(data.len()).unwrap(); p_assembler.add(b"World!", b"Hello ".len()).unwrap(); assert_eq!(p_assembler.assemble(), None); p_assembler.add(b"Hello ", 0).unwrap(); assert_eq!(p_assembler.assemble(), Some(&b"Hello World!"[..])); } #[test] fn packet_assembler_set() { let key = Key { id: 1 }; let mut set = PacketAssemblerSet::new(); assert!(set.get(&key, Instant::ZERO).is_ok()); } #[test] fn packet_assembler_set_full() { let mut set = PacketAssemblerSet::new(); set.get(&Key { id: 0 }, Instant::ZERO).unwrap(); set.get(&Key { id: 1 }, Instant::ZERO).unwrap(); set.get(&Key { id: 2 }, Instant::ZERO).unwrap(); set.get(&Key { id: 3 }, Instant::ZERO).unwrap(); assert!(set.get(&Key { id: 4 }, Instant::ZERO).is_err()); } #[test] fn packet_assembler_set_assembling_many() { let mut set = PacketAssemblerSet::new(); let key = Key { id: 0 }; let assr = set.get(&key, Instant::ZERO).unwrap(); assert_eq!(assr.assemble(), None); assr.set_total_size(0).unwrap(); assr.assemble().unwrap(); // Test that `.assemble()` effectively deletes it. let assr = set.get(&key, Instant::ZERO).unwrap(); assert_eq!(assr.assemble(), None); assr.set_total_size(0).unwrap(); assr.assemble().unwrap(); let key = Key { id: 1 }; let assr = set.get(&key, Instant::ZERO).unwrap(); assr.set_total_size(0).unwrap(); assr.assemble().unwrap(); let key = Key { id: 2 }; let assr = set.get(&key, Instant::ZERO).unwrap(); assr.set_total_size(0).unwrap(); assr.assemble().unwrap(); let key = Key { id: 2 }; let assr = set.get(&key, Instant::ZERO).unwrap(); assr.set_total_size(2).unwrap(); assr.add(&[0x00], 0).unwrap(); assert_eq!(assr.assemble(), None); let assr = set.get(&key, Instant::ZERO).unwrap(); assr.add(&[0x01], 1).unwrap(); assert_eq!(assr.assemble(), Some(&[0x00, 0x01][..])); } }