// udp.rs (13 KB)
  1. use core::fmt;
  2. use byteorder::{ByteOrder, NetworkEndian};
  3. use crate::{Error, Result};
  4. use crate::phy::ChecksumCapabilities;
  5. use crate::wire::{IpProtocol, IpAddress};
  6. use crate::wire::ip::checksum;
  7. /// A read/write wrapper around an User Datagram Protocol packet buffer.
  8. #[derive(Debug, PartialEq, Clone)]
  9. #[cfg_attr(feature = "defmt", derive(defmt::Format))]
  10. pub struct Packet<T: AsRef<[u8]>> {
  11. buffer: T
  12. }
  13. mod field {
  14. #![allow(non_snake_case)]
  15. use crate::wire::field::*;
  16. pub const SRC_PORT: Field = 0..2;
  17. pub const DST_PORT: Field = 2..4;
  18. pub const LENGTH: Field = 4..6;
  19. pub const CHECKSUM: Field = 6..8;
  20. pub fn PAYLOAD(length: u16) -> Field {
  21. CHECKSUM.end..(length as usize)
  22. }
  23. }
  24. #[allow(clippy::len_without_is_empty)]
  25. impl<T: AsRef<[u8]>> Packet<T> {
  26. /// Imbue a raw octet buffer with UDP packet structure.
  27. pub fn new_unchecked(buffer: T) -> Packet<T> {
  28. Packet { buffer }
  29. }
  30. /// Shorthand for a combination of [new_unchecked] and [check_len].
  31. ///
  32. /// [new_unchecked]: #method.new_unchecked
  33. /// [check_len]: #method.check_len
  34. pub fn new_checked(buffer: T) -> Result<Packet<T>> {
  35. let packet = Self::new_unchecked(buffer);
  36. packet.check_len()?;
  37. Ok(packet)
  38. }
  39. /// Ensure that no accessor method will panic if called.
  40. /// Returns `Err(Error::Truncated)` if the buffer is too short.
  41. /// Returns `Err(Error::Malformed)` if the length field has a value smaller
  42. /// than the header length.
  43. ///
  44. /// The result of this check is invalidated by calling [set_len].
  45. ///
  46. /// [set_len]: #method.set_len
  47. pub fn check_len(&self) -> Result<()> {
  48. let buffer_len = self.buffer.as_ref().len();
  49. if buffer_len < field::CHECKSUM.end {
  50. Err(Error::Truncated)
  51. } else {
  52. let field_len = self.len() as usize;
  53. if buffer_len < field_len {
  54. Err(Error::Truncated)
  55. } else if field_len < field::CHECKSUM.end {
  56. Err(Error::Malformed)
  57. } else {
  58. Ok(())
  59. }
  60. }
  61. }
  62. /// Consume the packet, returning the underlying buffer.
  63. pub fn into_inner(self) -> T {
  64. self.buffer
  65. }
  66. /// Return the source port field.
  67. #[inline]
  68. pub fn src_port(&self) -> u16 {
  69. let data = self.buffer.as_ref();
  70. NetworkEndian::read_u16(&data[field::SRC_PORT])
  71. }
  72. /// Return the destination port field.
  73. #[inline]
  74. pub fn dst_port(&self) -> u16 {
  75. let data = self.buffer.as_ref();
  76. NetworkEndian::read_u16(&data[field::DST_PORT])
  77. }
  78. /// Return the length field.
  79. #[inline]
  80. pub fn len(&self) -> u16 {
  81. let data = self.buffer.as_ref();
  82. NetworkEndian::read_u16(&data[field::LENGTH])
  83. }
  84. /// Return the checksum field.
  85. #[inline]
  86. pub fn checksum(&self) -> u16 {
  87. let data = self.buffer.as_ref();
  88. NetworkEndian::read_u16(&data[field::CHECKSUM])
  89. }
  90. /// Validate the packet checksum.
  91. ///
  92. /// # Panics
  93. /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
  94. /// and that family is IPv4 or IPv6.
  95. ///
  96. /// # Fuzzing
  97. /// This function always returns `true` when fuzzing.
  98. pub fn verify_checksum(&self, src_addr: &IpAddress, dst_addr: &IpAddress) -> bool {
  99. if cfg!(fuzzing) { return true }
  100. let data = self.buffer.as_ref();
  101. checksum::combine(&[
  102. checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp,
  103. self.len() as u32),
  104. checksum::data(&data[..self.len() as usize])
  105. ]) == !0
  106. }
  107. }
  108. impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> {
  109. /// Return a pointer to the payload.
  110. #[inline]
  111. pub fn payload(&self) -> &'a [u8] {
  112. let length = self.len();
  113. let data = self.buffer.as_ref();
  114. &data[field::PAYLOAD(length)]
  115. }
  116. }
  117. impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> {
  118. /// Set the source port field.
  119. #[inline]
  120. pub fn set_src_port(&mut self, value: u16) {
  121. let data = self.buffer.as_mut();
  122. NetworkEndian::write_u16(&mut data[field::SRC_PORT], value)
  123. }
  124. /// Set the destination port field.
  125. #[inline]
  126. pub fn set_dst_port(&mut self, value: u16) {
  127. let data = self.buffer.as_mut();
  128. NetworkEndian::write_u16(&mut data[field::DST_PORT], value)
  129. }
  130. /// Set the length field.
  131. #[inline]
  132. pub fn set_len(&mut self, value: u16) {
  133. let data = self.buffer.as_mut();
  134. NetworkEndian::write_u16(&mut data[field::LENGTH], value)
  135. }
  136. /// Set the checksum field.
  137. #[inline]
  138. pub fn set_checksum(&mut self, value: u16) {
  139. let data = self.buffer.as_mut();
  140. NetworkEndian::write_u16(&mut data[field::CHECKSUM], value)
  141. }
  142. /// Compute and fill in the header checksum.
  143. ///
  144. /// # Panics
  145. /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
  146. /// and that family is IPv4 or IPv6.
  147. pub fn fill_checksum(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress) {
  148. self.set_checksum(0);
  149. let checksum = {
  150. let data = self.buffer.as_ref();
  151. !checksum::combine(&[
  152. checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp,
  153. self.len() as u32),
  154. checksum::data(&data[..self.len() as usize])
  155. ])
  156. };
  157. // UDP checksum value of 0 means no checksum; if the checksum really is zero,
  158. // use all-ones, which indicates that the remote end must verify the checksum.
  159. // Arithmetically, RFC 1071 checksums of all-zeroes and all-ones behave identically,
  160. // so no action is necessary on the remote end.
  161. self.set_checksum(if checksum == 0 { 0xffff } else { checksum })
  162. }
  163. /// Return a mutable pointer to the payload.
  164. #[inline]
  165. pub fn payload_mut(&mut self) -> &mut [u8] {
  166. let length = self.len();
  167. let data = self.buffer.as_mut();
  168. &mut data[field::PAYLOAD(length)]
  169. }
  170. }
  171. impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> {
  172. fn as_ref(&self) -> &[u8] {
  173. self.buffer.as_ref()
  174. }
  175. }
  176. /// A high-level representation of an User Datagram Protocol packet.
  177. #[derive(Debug, PartialEq, Eq, Clone, Copy)]
  178. #[cfg_attr(feature = "defmt", derive(defmt::Format))]
  179. pub struct Repr<'a> {
  180. pub src_port: u16,
  181. pub dst_port: u16,
  182. pub payload: &'a [u8]
  183. }
  184. impl<'a> Repr<'a> {
  185. /// Parse an User Datagram Protocol packet and return a high-level representation.
  186. pub fn parse<T>(packet: &Packet<&'a T>, src_addr: &IpAddress, dst_addr: &IpAddress,
  187. checksum_caps: &ChecksumCapabilities) -> Result<Repr<'a>>
  188. where T: AsRef<[u8]> + ?Sized {
  189. // Destination port cannot be omitted (but source port can be).
  190. if packet.dst_port() == 0 { return Err(Error::Malformed) }
  191. // Valid checksum is expected...
  192. if checksum_caps.udp.rx() && !packet.verify_checksum(src_addr, dst_addr) {
  193. match (src_addr, dst_addr) {
  194. // ... except on UDP-over-IPv4, where it can be omitted.
  195. #[cfg(feature = "proto-ipv4")]
  196. (&IpAddress::Ipv4(_), &IpAddress::Ipv4(_))
  197. if packet.checksum() == 0 => (),
  198. _ => {
  199. return Err(Error::Checksum)
  200. }
  201. }
  202. }
  203. Ok(Repr {
  204. src_port: packet.src_port(),
  205. dst_port: packet.dst_port(),
  206. payload: packet.payload()
  207. })
  208. }
  209. /// Return the length of a packet that will be emitted from this high-level representation.
  210. pub fn buffer_len(&self) -> usize {
  211. field::CHECKSUM.end + self.payload.len()
  212. }
  213. /// Emit a high-level representation into an User Datagram Protocol packet.
  214. pub fn emit<T: ?Sized>(&self, packet: &mut Packet<&mut T>,
  215. src_addr: &IpAddress,
  216. dst_addr: &IpAddress,
  217. checksum_caps: &ChecksumCapabilities)
  218. where T: AsRef<[u8]> + AsMut<[u8]> {
  219. packet.set_src_port(self.src_port);
  220. packet.set_dst_port(self.dst_port);
  221. packet.set_len((field::CHECKSUM.end + self.payload.len()) as u16);
  222. packet.payload_mut().copy_from_slice(self.payload);
  223. if checksum_caps.udp.tx() {
  224. packet.fill_checksum(src_addr, dst_addr)
  225. } else {
  226. // make sure we get a consistently zeroed checksum,
  227. // since implementations might rely on it
  228. packet.set_checksum(0);
  229. }
  230. }
  231. }
  232. impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
  233. fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
  234. // Cannot use Repr::parse because we don't have the IP addresses.
  235. write!(f, "UDP src={} dst={} len={}",
  236. self.src_port(), self.dst_port(), self.payload().len())
  237. }
  238. }
  239. impl<'a> fmt::Display for Repr<'a> {
  240. fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
  241. write!(f, "UDP src={} dst={} len={}",
  242. self.src_port, self.dst_port, self.payload.len())
  243. }
  244. }
  245. use crate::wire::pretty_print::{PrettyPrint, PrettyIndent};
  246. impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
  247. fn pretty_print(buffer: &dyn AsRef<[u8]>, f: &mut fmt::Formatter,
  248. indent: &mut PrettyIndent) -> fmt::Result {
  249. match Packet::new_checked(buffer) {
  250. Err(err) => write!(f, "{}({})", indent, err),
  251. Ok(packet) => write!(f, "{}{}", indent, packet)
  252. }
  253. }
  254. }
  255. #[cfg(test)]
  256. mod test {
  257. #[cfg(feature = "proto-ipv4")]
  258. use crate::wire::Ipv4Address;
  259. use super::*;
  260. #[cfg(feature = "proto-ipv4")]
  261. const SRC_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 1]);
  262. #[cfg(feature = "proto-ipv4")]
  263. const DST_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 2]);
  264. #[cfg(feature = "proto-ipv4")]
  265. static PACKET_BYTES: [u8; 12] =
  266. [0xbf, 0x00, 0x00, 0x35,
  267. 0x00, 0x0c, 0x12, 0x4d,
  268. 0xaa, 0x00, 0x00, 0xff];
  269. #[cfg(feature = "proto-ipv4")]
  270. static NO_CHECKSUM_PACKET: [u8; 12] =
  271. [0xbf, 0x00, 0x00, 0x35,
  272. 0x00, 0x0c, 0x00, 0x00,
  273. 0xaa, 0x00, 0x00, 0xff];
  274. #[cfg(feature = "proto-ipv4")]
  275. static PAYLOAD_BYTES: [u8; 4] =
  276. [0xaa, 0x00, 0x00, 0xff];
  277. #[test]
  278. #[cfg(feature = "proto-ipv4")]
  279. fn test_deconstruct() {
  280. let packet = Packet::new_unchecked(&PACKET_BYTES[..]);
  281. assert_eq!(packet.src_port(), 48896);
  282. assert_eq!(packet.dst_port(), 53);
  283. assert_eq!(packet.len(), 12);
  284. assert_eq!(packet.checksum(), 0x124d);
  285. assert_eq!(packet.payload(), &PAYLOAD_BYTES[..]);
  286. assert_eq!(packet.verify_checksum(&SRC_ADDR.into(), &DST_ADDR.into()), true);
  287. }
  288. #[test]
  289. #[cfg(feature = "proto-ipv4")]
  290. fn test_construct() {
  291. let mut bytes = vec![0xa5; 12];
  292. let mut packet = Packet::new_unchecked(&mut bytes);
  293. packet.set_src_port(48896);
  294. packet.set_dst_port(53);
  295. packet.set_len(12);
  296. packet.set_checksum(0xffff);
  297. packet.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]);
  298. packet.fill_checksum(&SRC_ADDR.into(), &DST_ADDR.into());
  299. assert_eq!(&packet.into_inner()[..], &PACKET_BYTES[..]);
  300. }
  301. #[test]
  302. fn test_impossible_len() {
  303. let mut bytes = vec![0; 12];
  304. let mut packet = Packet::new_unchecked(&mut bytes);
  305. packet.set_len(4);
  306. assert_eq!(packet.check_len(), Err(Error::Malformed));
  307. }
  308. #[test]
  309. #[cfg(feature = "proto-ipv4")]
  310. fn test_zero_checksum() {
  311. let mut bytes = vec![0; 8];
  312. let mut packet = Packet::new_unchecked(&mut bytes);
  313. packet.set_src_port(1);
  314. packet.set_dst_port(31881);
  315. packet.set_len(8);
  316. packet.fill_checksum(&SRC_ADDR.into(), &DST_ADDR.into());
  317. assert_eq!(packet.checksum(), 0xffff);
  318. }
  319. #[cfg(feature = "proto-ipv4")]
  320. fn packet_repr() -> Repr<'static> {
  321. Repr {
  322. src_port: 48896,
  323. dst_port: 53,
  324. payload: &PAYLOAD_BYTES
  325. }
  326. }
  327. #[test]
  328. #[cfg(feature = "proto-ipv4")]
  329. fn test_parse() {
  330. let packet = Packet::new_unchecked(&PACKET_BYTES[..]);
  331. let repr = Repr::parse(&packet, &SRC_ADDR.into(), &DST_ADDR.into(),
  332. &ChecksumCapabilities::default()).unwrap();
  333. assert_eq!(repr, packet_repr());
  334. }
  335. #[test]
  336. #[cfg(feature = "proto-ipv4")]
  337. fn test_emit() {
  338. let repr = packet_repr();
  339. let mut bytes = vec![0xa5; repr.buffer_len()];
  340. let mut packet = Packet::new_unchecked(&mut bytes);
  341. repr.emit(&mut packet, &SRC_ADDR.into(), &DST_ADDR.into(),
  342. &ChecksumCapabilities::default());
  343. assert_eq!(&packet.into_inner()[..], &PACKET_BYTES[..]);
  344. }
  345. #[test]
  346. #[cfg(feature = "proto-ipv4")]
  347. fn test_checksum_omitted() {
  348. let packet = Packet::new_unchecked(&NO_CHECKSUM_PACKET[..]);
  349. let repr = Repr::parse(&packet, &SRC_ADDR.into(), &DST_ADDR.into(),
  350. &ChecksumCapabilities::default()).unwrap();
  351. assert_eq!(repr, packet_repr());
  352. }
  353. }