udp.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. use core::fmt;
  2. use byteorder::{ByteOrder, NetworkEndian};
  3. use crate::{Error, Result};
  4. use crate::phy::ChecksumCapabilities;
  5. use crate::wire::{IpProtocol, IpAddress};
  6. use crate::wire::ip::checksum;
  /// A read/write wrapper around a User Datagram Protocol packet buffer.
  ///
  /// The wrapper itself performs no validation; accessors index the buffer
  /// directly and may panic on short input. Use [`Packet::new_checked`] (or
  /// [`Packet::check_len`]) before touching untrusted data.
  #[derive(Debug, PartialEq, Clone)]
  #[cfg_attr(feature = "defmt", derive(defmt::Format))]
  pub struct Packet<T: AsRef<[u8]>> {
      buffer: T,
  }
  13. mod field {
  14. #![allow(non_snake_case)]
  15. use crate::wire::field::*;
  16. pub const SRC_PORT: Field = 0..2;
  17. pub const DST_PORT: Field = 2..4;
  18. pub const LENGTH: Field = 4..6;
  19. pub const CHECKSUM: Field = 6..8;
  20. pub fn PAYLOAD(length: u16) -> Field {
  21. CHECKSUM.end..(length as usize)
  22. }
  23. }
  24. #[allow(clippy::len_without_is_empty)]
  25. impl<T: AsRef<[u8]>> Packet<T> {
  26. /// Imbue a raw octet buffer with UDP packet structure.
  27. pub fn new_unchecked(buffer: T) -> Packet<T> {
  28. Packet { buffer }
  29. }
  30. /// Shorthand for a combination of [new_unchecked] and [check_len].
  31. ///
  32. /// [new_unchecked]: #method.new_unchecked
  33. /// [check_len]: #method.check_len
  34. pub fn new_checked(buffer: T) -> Result<Packet<T>> {
  35. let packet = Self::new_unchecked(buffer);
  36. packet.check_len()?;
  37. Ok(packet)
  38. }
  39. /// Ensure that no accessor method will panic if called.
  40. /// Returns `Err(Error::Truncated)` if the buffer is too short.
  41. /// Returns `Err(Error::Malformed)` if the length field has a value smaller
  42. /// than the header length.
  43. ///
  44. /// The result of this check is invalidated by calling [set_len].
  45. ///
  46. /// [set_len]: #method.set_len
  47. pub fn check_len(&self) -> Result<()> {
  48. let buffer_len = self.buffer.as_ref().len();
  49. if buffer_len < field::CHECKSUM.end {
  50. Err(Error::Truncated)
  51. } else {
  52. let field_len = self.len() as usize;
  53. if buffer_len < field_len {
  54. Err(Error::Truncated)
  55. } else if field_len < field::CHECKSUM.end {
  56. Err(Error::Malformed)
  57. } else {
  58. Ok(())
  59. }
  60. }
  61. }
  62. /// Consume the packet, returning the underlying buffer.
  63. pub fn into_inner(self) -> T {
  64. self.buffer
  65. }
  66. /// Return the source port field.
  67. #[inline]
  68. pub fn src_port(&self) -> u16 {
  69. let data = self.buffer.as_ref();
  70. NetworkEndian::read_u16(&data[field::SRC_PORT])
  71. }
  72. /// Return the destination port field.
  73. #[inline]
  74. pub fn dst_port(&self) -> u16 {
  75. let data = self.buffer.as_ref();
  76. NetworkEndian::read_u16(&data[field::DST_PORT])
  77. }
  78. /// Return the length field.
  79. #[inline]
  80. pub fn len(&self) -> u16 {
  81. let data = self.buffer.as_ref();
  82. NetworkEndian::read_u16(&data[field::LENGTH])
  83. }
  84. /// Return the checksum field.
  85. #[inline]
  86. pub fn checksum(&self) -> u16 {
  87. let data = self.buffer.as_ref();
  88. NetworkEndian::read_u16(&data[field::CHECKSUM])
  89. }
  90. /// Validate the packet checksum.
  91. ///
  92. /// # Panics
  93. /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
  94. /// and that family is IPv4 or IPv6.
  95. ///
  96. /// # Fuzzing
  97. /// This function always returns `true` when fuzzing.
  98. pub fn verify_checksum(&self, src_addr: &IpAddress, dst_addr: &IpAddress) -> bool {
  99. if cfg!(fuzzing) { return true }
  100. let data = self.buffer.as_ref();
  101. checksum::combine(&[
  102. checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp,
  103. self.len() as u32),
  104. checksum::data(&data[..self.len() as usize])
  105. ]) == !0
  106. }
  107. }
  108. impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> {
  109. /// Return a pointer to the payload.
  110. #[inline]
  111. pub fn payload(&self) -> &'a [u8] {
  112. let length = self.len();
  113. let data = self.buffer.as_ref();
  114. &data[field::PAYLOAD(length)]
  115. }
  116. }
  117. impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> {
  118. /// Set the source port field.
  119. #[inline]
  120. pub fn set_src_port(&mut self, value: u16) {
  121. let data = self.buffer.as_mut();
  122. NetworkEndian::write_u16(&mut data[field::SRC_PORT], value)
  123. }
  124. /// Set the destination port field.
  125. #[inline]
  126. pub fn set_dst_port(&mut self, value: u16) {
  127. let data = self.buffer.as_mut();
  128. NetworkEndian::write_u16(&mut data[field::DST_PORT], value)
  129. }
  130. /// Set the length field.
  131. #[inline]
  132. pub fn set_len(&mut self, value: u16) {
  133. let data = self.buffer.as_mut();
  134. NetworkEndian::write_u16(&mut data[field::LENGTH], value)
  135. }
  136. /// Set the checksum field.
  137. #[inline]
  138. pub fn set_checksum(&mut self, value: u16) {
  139. let data = self.buffer.as_mut();
  140. NetworkEndian::write_u16(&mut data[field::CHECKSUM], value)
  141. }
  142. /// Compute and fill in the header checksum.
  143. ///
  144. /// # Panics
  145. /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
  146. /// and that family is IPv4 or IPv6.
  147. pub fn fill_checksum(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress) {
  148. self.set_checksum(0);
  149. let checksum = {
  150. let data = self.buffer.as_ref();
  151. !checksum::combine(&[
  152. checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp,
  153. self.len() as u32),
  154. checksum::data(&data[..self.len() as usize])
  155. ])
  156. };
  157. // UDP checksum value of 0 means no checksum; if the checksum really is zero,
  158. // use all-ones, which indicates that the remote end must verify the checksum.
  159. // Arithmetically, RFC 1071 checksums of all-zeroes and all-ones behave identically,
  160. // so no action is necessary on the remote end.
  161. self.set_checksum(if checksum == 0 { 0xffff } else { checksum })
  162. }
  163. /// Return a mutable pointer to the payload.
  164. #[inline]
  165. pub fn payload_mut(&mut self) -> &mut [u8] {
  166. let length = self.len();
  167. let data = self.buffer.as_mut();
  168. &mut data[field::PAYLOAD(length)]
  169. }
  170. }
  171. impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> {
  172. fn as_ref(&self) -> &[u8] {
  173. self.buffer.as_ref()
  174. }
  175. }
  /// A high-level representation of a User Datagram Protocol packet.
  ///
  /// Only the ports are stored; the length and checksum fields are derived
  /// from the payload and addresses when the packet is emitted.
  #[derive(Debug, PartialEq, Eq, Clone, Copy)]
  #[cfg_attr(feature = "defmt", derive(defmt::Format))]
  pub struct Repr {
      /// Source port; `Repr::parse` accepts zero here.
      pub src_port: u16,
      /// Destination port; `Repr::parse` rejects zero.
      pub dst_port: u16,
  }
  183. impl Repr {
  184. /// Parse an User Datagram Protocol packet and return a high-level representation.
  185. pub fn parse<T>(packet: &Packet<&T>, src_addr: &IpAddress, dst_addr: &IpAddress,
  186. checksum_caps: &ChecksumCapabilities) -> Result<Repr>
  187. where T: AsRef<[u8]> + ?Sized {
  188. // Destination port cannot be omitted (but source port can be).
  189. if packet.dst_port() == 0 { return Err(Error::Malformed) }
  190. // Valid checksum is expected...
  191. if checksum_caps.udp.rx() && !packet.verify_checksum(src_addr, dst_addr) {
  192. match (src_addr, dst_addr) {
  193. // ... except on UDP-over-IPv4, where it can be omitted.
  194. #[cfg(feature = "proto-ipv4")]
  195. (&IpAddress::Ipv4(_), &IpAddress::Ipv4(_))
  196. if packet.checksum() == 0 => (),
  197. _ => {
  198. return Err(Error::Checksum)
  199. }
  200. }
  201. }
  202. Ok(Repr {
  203. src_port: packet.src_port(),
  204. dst_port: packet.dst_port(),
  205. })
  206. }
  207. /// Return the length of a packet that will be emitted from this high-level representation.
  208. pub fn header_len(&self) -> usize {
  209. field::CHECKSUM.end
  210. }
  211. /// Emit a high-level representation into an User Datagram Protocol packet.
  212. pub fn emit<T: ?Sized>(&self, packet: &mut Packet<&mut T>,
  213. src_addr: &IpAddress,
  214. dst_addr: &IpAddress,
  215. payload_len: usize,
  216. emit_payload: impl FnOnce(&mut [u8]),
  217. checksum_caps: &ChecksumCapabilities)
  218. where T: AsRef<[u8]> + AsMut<[u8]> {
  219. packet.set_src_port(self.src_port);
  220. packet.set_dst_port(self.dst_port);
  221. packet.set_len((field::CHECKSUM.end + payload_len) as u16);
  222. emit_payload(packet.payload_mut());
  223. if checksum_caps.udp.tx() {
  224. packet.fill_checksum(src_addr, dst_addr)
  225. } else {
  226. // make sure we get a consistently zeroed checksum,
  227. // since implementations might rely on it
  228. packet.set_checksum(0);
  229. }
  230. }
  231. }
  232. impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
  233. fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
  234. // Cannot use Repr::parse because we don't have the IP addresses.
  235. write!(f, "UDP src={} dst={} len={}",
  236. self.src_port(), self.dst_port(), self.payload().len())
  237. }
  238. }
  239. impl fmt::Display for Repr {
  240. fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
  241. write!(f, "UDP src={} dst={}", self.src_port, self.dst_port)
  242. }
  243. }
  244. use crate::wire::pretty_print::{PrettyPrint, PrettyIndent};
  245. impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
  246. fn pretty_print(buffer: &dyn AsRef<[u8]>, f: &mut fmt::Formatter,
  247. indent: &mut PrettyIndent) -> fmt::Result {
  248. match Packet::new_checked(buffer) {
  249. Err(err) => write!(f, "{}({})", indent, err),
  250. Ok(packet) => write!(f, "{}{}", indent, packet)
  251. }
  252. }
  253. }
  #[cfg(test)]
  mod test {
      #[cfg(feature = "proto-ipv4")]
      use crate::wire::Ipv4Address;

      use super::*;

      #[cfg(feature = "proto-ipv4")]
      const SRC_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 1]);
      #[cfg(feature = "proto-ipv4")]
      const DST_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 2]);

      /// 48896 -> 53, length 12, checksum valid for SRC_ADDR -> DST_ADDR.
      #[cfg(feature = "proto-ipv4")]
      static PACKET_BYTES: [u8; 12] =
          [0xbf, 0x00, 0x00, 0x35,
           0x00, 0x0c, 0x12, 0x4d,
           0xaa, 0x00, 0x00, 0xff];

      /// Same datagram with the checksum omitted (zero).
      #[cfg(feature = "proto-ipv4")]
      static NO_CHECKSUM_PACKET: [u8; 12] =
          [0xbf, 0x00, 0x00, 0x35,
           0x00, 0x0c, 0x00, 0x00,
           0xaa, 0x00, 0x00, 0xff];

      #[cfg(feature = "proto-ipv4")]
      static PAYLOAD_BYTES: [u8; 4] =
          [0xaa, 0x00, 0x00, 0xff];

      #[test]
      #[cfg(feature = "proto-ipv4")]
      fn test_deconstruct() {
          let packet = Packet::new_unchecked(&PACKET_BYTES[..]);
          assert_eq!(packet.src_port(), 48896);
          assert_eq!(packet.dst_port(), 53);
          assert_eq!(packet.len(), 12);
          assert_eq!(packet.checksum(), 0x124d);
          assert_eq!(packet.payload(), &PAYLOAD_BYTES[..]);
          assert!(packet.verify_checksum(&SRC_ADDR.into(), &DST_ADDR.into()));
      }

      #[test]
      #[cfg(feature = "proto-ipv4")]
      fn test_construct() {
          let mut bytes = vec![0xa5; 12];
          let mut packet = Packet::new_unchecked(&mut bytes);
          packet.set_src_port(48896);
          packet.set_dst_port(53);
          packet.set_len(12);
          packet.set_checksum(0xffff);
          packet.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]);
          packet.fill_checksum(&SRC_ADDR.into(), &DST_ADDR.into());
          assert_eq!(&packet.into_inner()[..], &PACKET_BYTES[..]);
      }

      #[test]
      fn test_impossible_len() {
          let mut bytes = vec![0; 12];
          let mut packet = Packet::new_unchecked(&mut bytes);
          packet.set_len(4);
          assert_eq!(packet.check_len(), Err(Error::Malformed));
      }

      #[test]
      fn test_check_len_truncated() {
          // Too short to even hold the header.
          let short = [0u8; 4];
          assert_eq!(Packet::new_unchecked(&short[..]).check_len(), Err(Error::Truncated));

          // Header fits, but the length field claims more than the buffer holds.
          let mut bytes = vec![0; 12];
          let mut packet = Packet::new_unchecked(&mut bytes);
          packet.set_len(16);
          assert_eq!(packet.check_len(), Err(Error::Truncated));
      }

      #[test]
      #[cfg(feature = "proto-ipv4")]
      fn test_zero_checksum() {
          let mut bytes = vec![0; 8];
          let mut packet = Packet::new_unchecked(&mut bytes);
          packet.set_src_port(1);
          packet.set_dst_port(31881);
          packet.set_len(8);
          packet.fill_checksum(&SRC_ADDR.into(), &DST_ADDR.into());
          assert_eq!(packet.checksum(), 0xffff);
      }

      #[cfg(feature = "proto-ipv4")]
      fn packet_repr() -> Repr {
          Repr {
              src_port: 48896,
              dst_port: 53,
          }
      }

      #[test]
      #[cfg(feature = "proto-ipv4")]
      fn test_parse() {
          let packet = Packet::new_unchecked(&PACKET_BYTES[..]);
          let repr = Repr::parse(&packet, &SRC_ADDR.into(), &DST_ADDR.into(),
                                 &ChecksumCapabilities::default()).unwrap();
          assert_eq!(repr, packet_repr());
      }

      #[test]
      #[cfg(feature = "proto-ipv4")]
      fn test_parse_zero_dst_port() {
          let mut bytes = PACKET_BYTES;
          bytes[2] = 0;
          bytes[3] = 0;
          let packet = Packet::new_unchecked(&bytes[..]);
          assert_eq!(Repr::parse(&packet, &SRC_ADDR.into(), &DST_ADDR.into(),
                                 &ChecksumCapabilities::default()),
                     Err(Error::Malformed));
      }

      #[test]
      #[cfg(feature = "proto-ipv4")]
      fn test_emit() {
          let repr = packet_repr();
          let mut bytes = vec![0xa5; repr.header_len() + PAYLOAD_BYTES.len()];
          let mut packet = Packet::new_unchecked(&mut bytes);
          repr.emit(&mut packet, &SRC_ADDR.into(), &DST_ADDR.into(),
                    PAYLOAD_BYTES.len(),
                    |payload| payload.copy_from_slice(&PAYLOAD_BYTES),
                    &ChecksumCapabilities::default());
          assert_eq!(&packet.into_inner()[..], &PACKET_BYTES[..]);
      }

      #[test]
      #[cfg(feature = "proto-ipv4")]
      fn test_checksum_omitted() {
          let packet = Packet::new_unchecked(&NO_CHECKSUM_PACKET[..]);
          let repr = Repr::parse(&packet, &SRC_ADDR.into(), &DST_ADDR.into(),
                                 &ChecksumCapabilities::default()).unwrap();
          assert_eq!(repr, packet_repr());
      }

      #[test]
      #[cfg(feature = "proto-ipv4")]
      fn test_display() {
          let packet = Packet::new_unchecked(&PACKET_BYTES[..]);
          // `len=` reports the payload length, not the UDP length field.
          assert_eq!(format!("{}", packet), "UDP src=48896 dst=53 len=4");
          assert_eq!(format!("{}", packet_repr()), "UDP src=48896 dst=53");
      }
  }