udp.rs 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. use core::fmt;
  2. use byteorder::{ByteOrder, NetworkEndian};
  3. use Error;
  4. use super::{InternetProtocolType, InternetAddress};
  5. use super::ip::checksum;
  /// A read/write wrapper around a User Datagram Protocol packet buffer.
  ///
  /// The buffer is stored as-is; header fields are decoded from it each time
  /// an accessor is called.
  #[derive(Debug)]
  pub struct Packet<T>
      where T: AsRef<[u8]>
  {
      buffer: T
  }
  11. mod field {
  12. #![allow(non_snake_case)]
  13. use wire::field::*;
  14. pub const SRC_PORT: Field = 0..2;
  15. pub const DST_PORT: Field = 2..4;
  16. pub const LENGTH: Field = 4..6;
  17. pub const CHECKSUM: Field = 6..8;
  18. pub fn PAYLOAD(length: u16) -> Field {
  19. CHECKSUM.end..(length as usize)
  20. }
  21. }
  22. impl<T: AsRef<[u8]>> Packet<T> {
  23. /// Wrap a buffer with an UDP packet. Returns an error if the buffer
  24. /// is too small to contain one.
  25. pub fn new(buffer: T) -> Result<Packet<T>, Error> {
  26. let len = buffer.as_ref().len();
  27. if len < field::CHECKSUM.end {
  28. Err(Error::Truncated)
  29. } else {
  30. let packet = Packet { buffer: buffer };
  31. if len < packet.len() as usize {
  32. Err(Error::Truncated)
  33. } else {
  34. Ok(packet)
  35. }
  36. }
  37. }
  38. /// Consumes the packet, returning the underlying buffer.
  39. pub fn into_inner(self) -> T {
  40. self.buffer
  41. }
  42. /// Return the source port field.
  43. #[inline(always)]
  44. pub fn src_port(&self) -> u16 {
  45. let data = self.buffer.as_ref();
  46. NetworkEndian::read_u16(&data[field::SRC_PORT])
  47. }
  48. /// Return the destination port field.
  49. #[inline(always)]
  50. pub fn dst_port(&self) -> u16 {
  51. let data = self.buffer.as_ref();
  52. NetworkEndian::read_u16(&data[field::DST_PORT])
  53. }
  54. /// Return the length field.
  55. #[inline(always)]
  56. pub fn len(&self) -> u16 {
  57. let data = self.buffer.as_ref();
  58. NetworkEndian::read_u16(&data[field::LENGTH])
  59. }
  60. /// Return the checksum field.
  61. #[inline(always)]
  62. pub fn checksum(&self) -> u16 {
  63. let data = self.buffer.as_ref();
  64. NetworkEndian::read_u16(&data[field::CHECKSUM])
  65. }
  66. /// Validate the packet checksum.
  67. ///
  68. /// # Panics
  69. /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
  70. /// and that family is IPv4 or IPv6.
  71. pub fn verify_checksum(&self, src_addr: &InternetAddress, dst_addr: &InternetAddress) -> bool {
  72. let data = self.buffer.as_ref();
  73. checksum::combine(&[
  74. checksum::pseudo_header(src_addr, dst_addr, InternetProtocolType::Udp,
  75. self.len() as u32),
  76. checksum::data(&data[..self.len() as usize])
  77. ]) == !0
  78. }
  79. }
  80. impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> {
  81. /// Return a pointer to the payload.
  82. #[inline(always)]
  83. pub fn payload(&self) -> &'a [u8] {
  84. let length = self.len();
  85. let data = self.buffer.as_ref();
  86. &data[field::PAYLOAD(length)]
  87. }
  88. }
  89. impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> {
  90. /// Set the source port field.
  91. #[inline(always)]
  92. pub fn set_src_port(&mut self, value: u16) {
  93. let mut data = self.buffer.as_mut();
  94. NetworkEndian::write_u16(&mut data[field::SRC_PORT], value)
  95. }
  96. /// Set the destination port field.
  97. #[inline(always)]
  98. pub fn set_dst_port(&mut self, value: u16) {
  99. let mut data = self.buffer.as_mut();
  100. NetworkEndian::write_u16(&mut data[field::DST_PORT], value)
  101. }
  102. /// Set the length field.
  103. #[inline(always)]
  104. pub fn set_len(&mut self, value: u16) {
  105. let mut data = self.buffer.as_mut();
  106. NetworkEndian::write_u16(&mut data[field::LENGTH], value)
  107. }
  108. /// Set the checksum field.
  109. #[inline(always)]
  110. pub fn set_checksum(&mut self, value: u16) {
  111. let mut data = self.buffer.as_mut();
  112. NetworkEndian::write_u16(&mut data[field::CHECKSUM], value)
  113. }
  114. /// Compute and fill in the header checksum.
  115. ///
  116. /// # Panics
  117. /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
  118. /// and that family is IPv4 or IPv6.
  119. pub fn fill_checksum(&mut self, src_addr: &InternetAddress, dst_addr: &InternetAddress) {
  120. self.set_checksum(0);
  121. let checksum = {
  122. let data = self.buffer.as_ref();
  123. !checksum::combine(&[
  124. checksum::pseudo_header(src_addr, dst_addr, InternetProtocolType::Udp,
  125. self.len() as u32),
  126. checksum::data(&data[..self.len() as usize])
  127. ])
  128. };
  129. self.set_checksum(checksum)
  130. }
  131. }
  132. impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Packet<&'a mut T> {
  133. /// Return a mutable pointer to the payload.
  134. #[inline(always)]
  135. pub fn payload_mut(&mut self) -> &mut [u8] {
  136. let length = self.len();
  137. let mut data = self.buffer.as_mut();
  138. &mut data[field::PAYLOAD(length)]
  139. }
  140. }
  /// A high-level representation of a User Datagram Protocol packet.
  #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  pub struct Repr<'a> {
      /// Source port.
      pub src_port: u16,
      /// Destination port.
      pub dst_port: u16,
      /// Payload bytes, borrowed for `'a`.
      pub payload:  &'a [u8]
  }
  148. impl<'a> Repr<'a> {
  149. /// Parse an User Datagram Protocol packet and return a high-level representation.
  150. pub fn parse<T: ?Sized>(packet: &Packet<&'a T>,
  151. src_addr: &InternetAddress,
  152. dst_addr: &InternetAddress) -> Result<Repr<'a>, Error>
  153. where T: AsRef<[u8]> {
  154. // Destination port cannot be omitted (but source port can be).
  155. if packet.dst_port() == 0 { return Err(Error::Malformed) }
  156. // Valid checksum is expected...
  157. if !packet.verify_checksum(src_addr, dst_addr) {
  158. match (src_addr, dst_addr) {
  159. (&InternetAddress::Ipv4(_), &InternetAddress::Ipv4(_))
  160. if packet.checksum() != 0 => {
  161. // ... except on UDP-over-IPv4, where it can be omitted.
  162. return Err(Error::Checksum)
  163. },
  164. _ => {
  165. return Err(Error::Checksum)
  166. }
  167. }
  168. }
  169. Ok(Repr {
  170. src_port: packet.src_port(),
  171. dst_port: packet.dst_port(),
  172. payload: packet.payload()
  173. })
  174. }
  175. /// Return the length of a packet that will be emitted from this high-level representation.
  176. pub fn len(&self) -> usize {
  177. field::CHECKSUM.end + self.payload.len()
  178. }
  179. /// Emit a high-level representation into an User Datagram Protocol packet.
  180. pub fn emit<T: ?Sized>(&self, packet: &mut Packet<&mut T>,
  181. src_addr: &InternetAddress,
  182. dst_addr: &InternetAddress)
  183. where T: AsRef<[u8]> + AsMut<[u8]> {
  184. packet.set_src_port(self.src_port);
  185. packet.set_dst_port(self.dst_port);
  186. packet.set_len((field::CHECKSUM.end + self.payload.len()) as u16);
  187. packet.payload_mut().copy_from_slice(self.payload);
  188. packet.fill_checksum(src_addr, dst_addr)
  189. }
  190. }
  191. impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
  192. fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
  193. // Cannot use Repr::parse because we don't have the IP addresses.
  194. write!(f, "UDP src={} dst={} len={}",
  195. self.src_port(), self.dst_port(), self.payload().len())
  196. }
  197. }
  198. impl<'a> fmt::Display for Repr<'a> {
  199. fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
  200. write!(f, "UDP src={} dst={} len={}",
  201. self.src_port, self.dst_port, self.payload.len())
  202. }
  203. }
  204. use super::pretty_print::{PrettyPrint, PrettyIndent};
  205. impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
  206. fn pretty_print(buffer: &AsRef<[u8]>, f: &mut fmt::Formatter,
  207. indent: &mut PrettyIndent) -> fmt::Result {
  208. match Packet::new(buffer) {
  209. Err(err) => write!(f, "{}({})\n", indent, err),
  210. Ok(packet) => write!(f, "{}{}\n", indent, packet)
  211. }
  212. }
  213. }
  #[cfg(test)]
  mod test {
      use wire::Ipv4Address;
      use super::*;

      const SRC_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 1]);
      const DST_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 2]);

      // A complete datagram: 8 header bytes followed by the 4 bytes of `PAYLOAD_BYTES`.
      static PACKET_BYTES: [u8; 12] = [
          0xbf, 0x00, 0x00, 0x35,
          0x00, 0x0c, 0x12, 0x4d,
          0xaa, 0x00, 0x00, 0xff
      ];

      static PAYLOAD_BYTES: [u8; 4] = [0xaa, 0x00, 0x00, 0xff];

      /// Every header field and the payload can be read back from raw bytes.
      #[test]
      fn test_deconstruct() {
          let packet = Packet::new(&PACKET_BYTES[..]).unwrap();
          assert_eq!(packet.src_port(), 48896);
          assert_eq!(packet.dst_port(), 53);
          assert_eq!(packet.len(), 12);
          assert_eq!(packet.checksum(), 0x124d);
          assert_eq!(packet.payload(), &PAYLOAD_BYTES[..]);
          assert!(packet.verify_checksum(&SRC_ADDR.into(), &DST_ADDR.into()));
      }

      /// Setting each field by hand reproduces the reference bytes.
      #[test]
      fn test_construct() {
          let mut bytes = vec![0; 12];
          let mut packet = Packet::new(&mut bytes).unwrap();
          packet.set_src_port(48896);
          packet.set_dst_port(53);
          packet.set_len(12);
          // A bogus checksum, to show that `fill_checksum` overwrites it.
          packet.set_checksum(0xffff);
          packet.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]);
          packet.fill_checksum(&SRC_ADDR.into(), &DST_ADDR.into());
          assert_eq!(&packet.into_inner()[..], &PACKET_BYTES[..]);
      }

      /// The high-level representation matching `PACKET_BYTES`.
      fn packet_repr() -> Repr<'static> {
          Repr {
              src_port: 48896,
              dst_port: 53,
              payload:  &PAYLOAD_BYTES
          }
      }

      /// Parsing the reference bytes yields the reference representation.
      #[test]
      fn test_parse() {
          let packet = Packet::new(&PACKET_BYTES[..]).unwrap();
          let parsed = Repr::parse(&packet, &SRC_ADDR.into(), &DST_ADDR.into()).unwrap();
          assert_eq!(parsed, packet_repr());
      }

      /// Emitting the reference representation yields the reference bytes.
      #[test]
      fn test_emit() {
          let mut bytes = vec![0; 12];
          let mut packet = Packet::new(&mut bytes).unwrap();
          packet_repr().emit(&mut packet, &SRC_ADDR.into(), &DST_ADDR.into());
          assert_eq!(&packet.into_inner()[..], &PACKET_BYTES[..]);
      }
  }