udp.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. use core::cmp::min;
  2. use managed::Managed;
  3. use {Error, Result};
  4. use phy::DeviceLimits;
  5. use wire::{IpProtocol, IpEndpoint};
  6. use wire::{UdpPacket, UdpRepr};
  7. use socket::{Socket, IpRepr, IpPayload};
  8. use storage::{Resettable, RingBuffer};
  9. /// A buffered UDP packet.
  10. #[derive(Debug)]
  11. pub struct PacketBuffer<'a> {
  12. endpoint: IpEndpoint,
  13. size: usize,
  14. payload: Managed<'a, [u8]>
  15. }
  16. impl<'a> PacketBuffer<'a> {
  17. /// Create a buffered packet.
  18. pub fn new<T>(payload: T) -> PacketBuffer<'a>
  19. where T: Into<Managed<'a, [u8]>> {
  20. PacketBuffer {
  21. endpoint: IpEndpoint::default(),
  22. size: 0,
  23. payload: payload.into()
  24. }
  25. }
  26. fn as_ref<'b>(&'b self) -> &'b [u8] {
  27. &self.payload[..self.size]
  28. }
  29. fn as_mut<'b>(&'b mut self) -> &'b mut [u8] {
  30. &mut self.payload[..self.size]
  31. }
  32. fn resize<'b>(&'b mut self, size: usize) -> Result<&'b mut Self> {
  33. if self.payload.len() >= size {
  34. self.size = size;
  35. Ok(self)
  36. } else {
  37. Err(Error::Truncated)
  38. }
  39. }
  40. }
  41. impl<'a> Resettable for PacketBuffer<'a> {
  42. fn reset(&mut self) {
  43. self.endpoint = Default::default();
  44. self.size = 0;
  45. }
  46. }
  47. /// An UDP packet ring buffer.
  48. pub type SocketBuffer<'a, 'b : 'a> = RingBuffer<'a, PacketBuffer<'b>>;
  49. /// An User Datagram Protocol socket.
  50. ///
  51. /// An UDP socket is bound to a specific endpoint, and owns transmit and receive
  52. /// packet buffers.
  53. #[derive(Debug)]
  54. pub struct UdpSocket<'a, 'b: 'a> {
  55. debug_id: usize,
  56. endpoint: IpEndpoint,
  57. rx_buffer: SocketBuffer<'a, 'b>,
  58. tx_buffer: SocketBuffer<'a, 'b>,
  59. }
  60. impl<'a, 'b> UdpSocket<'a, 'b> {
  61. /// Create an UDP socket with the given buffers.
  62. pub fn new(rx_buffer: SocketBuffer<'a, 'b>,
  63. tx_buffer: SocketBuffer<'a, 'b>) -> Socket<'a, 'b> {
  64. Socket::Udp(UdpSocket {
  65. debug_id: 0,
  66. endpoint: IpEndpoint::default(),
  67. rx_buffer: rx_buffer,
  68. tx_buffer: tx_buffer,
  69. })
  70. }
  71. /// Return the debug identifier.
  72. #[inline]
  73. pub fn debug_id(&self) -> usize {
  74. self.debug_id
  75. }
  76. /// Set the debug identifier.
  77. ///
  78. /// The debug identifier is a number printed in socket trace messages.
  79. /// It could as well be used by the user code.
  80. pub fn set_debug_id(&mut self, id: usize) {
  81. self.debug_id = id
  82. }
  83. /// Return the bound endpoint.
  84. #[inline]
  85. pub fn endpoint(&self) -> IpEndpoint {
  86. self.endpoint
  87. }
  88. /// Bind the socket to the given endpoint.
  89. ///
  90. /// This function returns `Err(Error::Illegal)` if the socket was open
  91. /// (see [is_open](#method.is_open)), and `Err(Error::Unaddressable)`
  92. /// if the port in the given endpoint is zero.
  93. pub fn bind<T: Into<IpEndpoint>>(&mut self, endpoint: T) -> Result<()> {
  94. let endpoint = endpoint.into();
  95. if endpoint.port == 0 { return Err(Error::Unaddressable) }
  96. if self.is_open() { return Err(Error::Illegal) }
  97. self.endpoint = endpoint;
  98. Ok(())
  99. }
  100. /// Check whether the socket is open.
  101. #[inline]
  102. pub fn is_open(&self) -> bool {
  103. self.endpoint.port != 0
  104. }
  105. /// Check whether the transmit buffer is full.
  106. #[inline]
  107. pub fn can_send(&self) -> bool {
  108. !self.tx_buffer.full()
  109. }
  110. /// Check whether the receive buffer is not empty.
  111. #[inline]
  112. pub fn can_recv(&self) -> bool {
  113. !self.rx_buffer.empty()
  114. }
  115. /// Enqueue a packet to be sent to a given remote endpoint, and return a pointer
  116. /// to its payload.
  117. ///
  118. /// This function returns `Err(Error::Exhausted)` if the transmit buffer is full,
  119. /// `Err(Error::Truncated)` if the requested size is larger than the packet buffer
  120. /// size, and `Err(Error::Unaddressable)` if local or remote port, or remote address,
  121. /// are unspecified.
  122. pub fn send(&mut self, size: usize, endpoint: IpEndpoint) -> Result<&mut [u8]> {
  123. if self.endpoint.port == 0 { return Err(Error::Unaddressable) }
  124. if !endpoint.is_specified() { return Err(Error::Unaddressable) }
  125. let packet_buf = self.tx_buffer.try_enqueue(|buf| buf.resize(size))?;
  126. packet_buf.endpoint = endpoint;
  127. net_trace!("[{}]{}:{}: buffer to send {} octets",
  128. self.debug_id, self.endpoint, packet_buf.endpoint, size);
  129. Ok(&mut packet_buf.as_mut()[..size])
  130. }
  131. /// Enqueue a packet to be sent to a given remote endpoint, and fill it from a slice.
  132. ///
  133. /// See also [send](#method.send).
  134. pub fn send_slice(&mut self, data: &[u8], endpoint: IpEndpoint) -> Result<()> {
  135. self.send(data.len(), endpoint)?.copy_from_slice(data);
  136. Ok(())
  137. }
  138. /// Dequeue a packet received from a remote endpoint, and return the endpoint as well
  139. /// as a pointer to the payload.
  140. ///
  141. /// This function returns `Err(Error::Exhausted)` if the receive buffer is empty.
  142. pub fn recv(&mut self) -> Result<(&[u8], IpEndpoint)> {
  143. let packet_buf = self.rx_buffer.dequeue()?;
  144. net_trace!("[{}]{}:{}: receive {} buffered octets",
  145. self.debug_id, self.endpoint,
  146. packet_buf.endpoint, packet_buf.size);
  147. Ok((&packet_buf.as_ref(), packet_buf.endpoint))
  148. }
  149. /// Dequeue a packet received from a remote endpoint, and return the endpoint as well
  150. /// as copy the payload into the given slice.
  151. ///
  152. /// See also [recv](#method.recv).
  153. pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, IpEndpoint)> {
  154. let (buffer, endpoint) = self.recv()?;
  155. let length = min(data.len(), buffer.len());
  156. data[..length].copy_from_slice(&buffer[..length]);
  157. Ok((length, endpoint))
  158. }
  159. pub(crate) fn process(&mut self, _timestamp: u64, ip_repr: &IpRepr,
  160. payload: &[u8]) -> Result<()> {
  161. debug_assert!(ip_repr.protocol() == IpProtocol::Udp);
  162. let packet = UdpPacket::new_checked(&payload[..ip_repr.payload_len()])?;
  163. let repr = UdpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr())?;
  164. // Reject packets with a wrong destination.
  165. if self.endpoint.port != repr.dst_port { return Err(Error::Rejected) }
  166. if !self.endpoint.addr.is_unspecified() &&
  167. self.endpoint.addr != ip_repr.dst_addr() { return Err(Error::Rejected) }
  168. let packet_buf = self.rx_buffer.try_enqueue(|buf| buf.resize(repr.payload.len()))?;
  169. packet_buf.as_mut().copy_from_slice(repr.payload);
  170. packet_buf.endpoint = IpEndpoint { addr: ip_repr.src_addr(), port: repr.src_port };
  171. net_trace!("[{}]{}:{}: receiving {} octets",
  172. self.debug_id, self.endpoint,
  173. packet_buf.endpoint, packet_buf.size);
  174. Ok(())
  175. }
  176. pub(crate) fn dispatch<F, R>(&mut self, _timestamp: u64, _limits: &DeviceLimits,
  177. emit: &mut F) -> Result<R>
  178. where F: FnMut(&IpRepr, &IpPayload) -> Result<R> {
  179. let packet_buf = self.tx_buffer.dequeue()?;
  180. net_trace!("[{}]{}:{}: sending {} octets",
  181. self.debug_id, self.endpoint,
  182. packet_buf.endpoint, packet_buf.size);
  183. let repr = UdpRepr {
  184. src_port: self.endpoint.port,
  185. dst_port: packet_buf.endpoint.port,
  186. payload: &packet_buf.as_ref()[..]
  187. };
  188. let ip_repr = IpRepr::Unspecified {
  189. src_addr: self.endpoint.addr,
  190. dst_addr: packet_buf.endpoint.addr,
  191. protocol: IpProtocol::Udp,
  192. payload_len: repr.buffer_len()
  193. };
  194. emit(&ip_repr, &repr)
  195. }
  196. }
  197. impl<'a> IpPayload for UdpRepr<'a> {
  198. fn buffer_len(&self) -> usize {
  199. self.buffer_len()
  200. }
  201. fn emit(&self, repr: &IpRepr, payload: &mut [u8]) {
  202. let mut packet = UdpPacket::new(payload);
  203. self.emit(&mut packet, &repr.src_addr(), &repr.dst_addr())
  204. }
  205. }
  206. #[cfg(test)]
  207. mod test {
  208. use wire::{IpAddress, Ipv4Address, IpRepr, Ipv4Repr, UdpRepr};
  209. use super::*;
  210. fn buffer(packets: usize) -> SocketBuffer<'static, 'static> {
  211. let mut storage = vec![];
  212. for _ in 0..packets {
  213. storage.push(PacketBuffer::new(vec![0; 16]))
  214. }
  215. SocketBuffer::new(storage)
  216. }
  217. fn socket(rx_buffer: SocketBuffer<'static, 'static>,
  218. tx_buffer: SocketBuffer<'static, 'static>)
  219. -> UdpSocket<'static, 'static> {
  220. match UdpSocket::new(rx_buffer, tx_buffer) {
  221. Socket::Udp(socket) => socket,
  222. _ => unreachable!()
  223. }
  224. }
  225. const LOCAL_IP: IpAddress = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 1]));
  226. const REMOTE_IP: IpAddress = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 2]));
  227. const LOCAL_PORT: u16 = 53;
  228. const REMOTE_PORT: u16 = 49500;
  229. const LOCAL_END: IpEndpoint = IpEndpoint { addr: LOCAL_IP, port: LOCAL_PORT };
  230. const REMOTE_END: IpEndpoint = IpEndpoint { addr: REMOTE_IP, port: REMOTE_PORT };
  231. #[test]
  232. fn test_bind_unaddressable() {
  233. let mut socket = socket(buffer(0), buffer(0));
  234. assert_eq!(socket.bind(0), Err(Error::Unaddressable));
  235. }
  236. #[test]
  237. fn test_bind_twice() {
  238. let mut socket = socket(buffer(0), buffer(0));
  239. assert_eq!(socket.bind(1), Ok(()));
  240. assert_eq!(socket.bind(2), Err(Error::Illegal));
  241. }
  242. const LOCAL_IP_REPR: IpRepr = IpRepr::Unspecified {
  243. src_addr: LOCAL_IP,
  244. dst_addr: REMOTE_IP,
  245. protocol: IpProtocol::Udp,
  246. payload_len: 8 + 6
  247. };
  248. const LOCAL_UDP_REPR: UdpRepr = UdpRepr {
  249. src_port: LOCAL_PORT,
  250. dst_port: REMOTE_PORT,
  251. payload: b"abcdef"
  252. };
  253. #[test]
  254. fn test_send_unaddressable() {
  255. let mut socket = socket(buffer(0), buffer(1));
  256. assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Err(Error::Unaddressable));
  257. assert_eq!(socket.bind(LOCAL_PORT), Ok(()));
  258. assert_eq!(socket.send_slice(b"abcdef",
  259. IpEndpoint { addr: IpAddress::Unspecified, ..REMOTE_END }),
  260. Err(Error::Unaddressable));
  261. assert_eq!(socket.send_slice(b"abcdef",
  262. IpEndpoint { port: 0, ..REMOTE_END }),
  263. Err(Error::Unaddressable));
  264. assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(()));
  265. }
  266. #[test]
  267. fn test_send_truncated() {
  268. let mut socket = socket(buffer(0), buffer(1));
  269. assert_eq!(socket.bind(LOCAL_END), Ok(()));
  270. assert_eq!(socket.send_slice(&[0; 32][..], REMOTE_END), Err(Error::Truncated));
  271. }
  272. #[test]
  273. fn test_send_dispatch() {
  274. let limits = DeviceLimits::default();
  275. let mut socket = socket(buffer(0), buffer(1));
  276. assert_eq!(socket.bind(LOCAL_END), Ok(()));
  277. assert!(socket.can_send());
  278. assert_eq!(socket.dispatch(0, &limits, &mut |_ip_repr, _ip_payload| {
  279. unreachable!()
  280. }), Err(Error::Exhausted) as Result<()>);
  281. assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(()));
  282. assert_eq!(socket.send_slice(b"123456", REMOTE_END), Err(Error::Exhausted));
  283. assert!(!socket.can_send());
  284. macro_rules! assert_payload_eq {
  285. ($ip_repr:expr, $ip_payload:expr, $expected:expr) => {{
  286. let mut buffer = vec![0; $ip_payload.buffer_len()];
  287. $ip_payload.emit($ip_repr, &mut buffer);
  288. let udp_packet = UdpPacket::new_checked(&buffer).unwrap();
  289. let udp_repr = UdpRepr::parse(&udp_packet, &LOCAL_IP, &REMOTE_IP).unwrap();
  290. assert_eq!(&udp_repr, $expected)
  291. }}
  292. }
  293. assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
  294. assert_eq!(ip_repr, &LOCAL_IP_REPR);
  295. assert_payload_eq!(ip_repr, ip_payload, &LOCAL_UDP_REPR);
  296. Err(Error::Unaddressable)
  297. }), Err(Error::Unaddressable) as Result<()>);
  298. /*assert!(!socket.can_send());*/
  299. assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
  300. assert_eq!(ip_repr, &LOCAL_IP_REPR);
  301. assert_payload_eq!(ip_repr, ip_payload, &LOCAL_UDP_REPR);
  302. Ok(())
  303. }), /*Ok(())*/ Err(Error::Exhausted));
  304. assert!(socket.can_send());
  305. }
  306. const REMOTE_IP_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr {
  307. src_addr: Ipv4Address([10, 0, 0, 2]),
  308. dst_addr: Ipv4Address([10, 0, 0, 1]),
  309. protocol: IpProtocol::Udp,
  310. payload_len: 8 + 6
  311. });
  312. const REMOTE_UDP_REPR: UdpRepr = UdpRepr {
  313. src_port: REMOTE_PORT,
  314. dst_port: LOCAL_PORT,
  315. payload: b"abcdef"
  316. };
  317. #[test]
  318. fn test_recv_process() {
  319. let mut socket = socket(buffer(1), buffer(0));
  320. assert_eq!(socket.bind(LOCAL_PORT), Ok(()));
  321. let mut buffer = vec![0; REMOTE_UDP_REPR.buffer_len()];
  322. REMOTE_UDP_REPR.emit(&mut UdpPacket::new(&mut buffer), &LOCAL_IP, &REMOTE_IP);
  323. assert!(!socket.can_recv());
  324. assert_eq!(socket.recv(), Err(Error::Exhausted));
  325. assert_eq!(socket.process(0, &REMOTE_IP_REPR, &buffer),
  326. Ok(()));
  327. assert!(socket.can_recv());
  328. assert_eq!(socket.process(0, &REMOTE_IP_REPR, &buffer),
  329. Err(Error::Exhausted));
  330. assert_eq!(socket.recv(), Ok((&b"abcdef"[..], REMOTE_END)));
  331. assert!(!socket.can_recv());
  332. }
  333. #[test]
  334. fn test_recv_truncated_slice() {
  335. let mut socket = socket(buffer(1), buffer(0));
  336. assert_eq!(socket.bind(LOCAL_PORT), Ok(()));
  337. let mut buffer = vec![0; REMOTE_UDP_REPR.buffer_len()];
  338. REMOTE_UDP_REPR.emit(&mut UdpPacket::new(&mut buffer), &LOCAL_IP, &REMOTE_IP);
  339. assert_eq!(socket.process(0, &REMOTE_IP_REPR, &buffer), Ok(()));
  340. let mut slice = [0; 4];
  341. assert_eq!(socket.recv_slice(&mut slice[..]), Ok((4, REMOTE_END)));
  342. assert_eq!(&slice, b"abcd");
  343. }
  344. #[test]
  345. fn test_recv_truncated_packet() {
  346. let mut socket = socket(buffer(1), buffer(0));
  347. assert_eq!(socket.bind(LOCAL_PORT), Ok(()));
  348. let udp_repr = UdpRepr { payload: &[0; 100][..], ..REMOTE_UDP_REPR };
  349. let mut buffer = vec![0; udp_repr.buffer_len()];
  350. udp_repr.emit(&mut UdpPacket::new(&mut buffer), &LOCAL_IP, &REMOTE_IP);
  351. assert_eq!(socket.process(0, &REMOTE_IP_REPR, &buffer),
  352. Err(Error::Truncated));
  353. }
  354. }