udp.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. use core::cmp::min;
  2. use managed::Managed;
  3. use {Error, Result};
  4. use wire::{IpProtocol, IpEndpoint, UdpRepr};
  5. use socket::{Socket, IpRepr};
  6. use storage::{Resettable, RingBuffer};
  7. /// A buffered UDP packet.
  8. #[derive(Debug)]
  9. pub struct PacketBuffer<'a> {
  10. endpoint: IpEndpoint,
  11. size: usize,
  12. payload: Managed<'a, [u8]>
  13. }
  14. impl<'a> PacketBuffer<'a> {
  15. /// Create a buffered packet.
  16. pub fn new<T>(payload: T) -> PacketBuffer<'a>
  17. where T: Into<Managed<'a, [u8]>> {
  18. PacketBuffer {
  19. endpoint: IpEndpoint::default(),
  20. size: 0,
  21. payload: payload.into()
  22. }
  23. }
  24. fn as_ref<'b>(&'b self) -> &'b [u8] {
  25. &self.payload[..self.size]
  26. }
  27. fn as_mut<'b>(&'b mut self) -> &'b mut [u8] {
  28. &mut self.payload[..self.size]
  29. }
  30. fn resize<'b>(&'b mut self, size: usize) -> Result<&'b mut Self> {
  31. if self.payload.len() >= size {
  32. self.size = size;
  33. Ok(self)
  34. } else {
  35. Err(Error::Truncated)
  36. }
  37. }
  38. }
  39. impl<'a> Resettable for PacketBuffer<'a> {
  40. fn reset(&mut self) {
  41. self.endpoint = Default::default();
  42. self.size = 0;
  43. }
  44. }
  45. /// An UDP packet ring buffer.
  46. pub type SocketBuffer<'a, 'b : 'a> = RingBuffer<'a, PacketBuffer<'b>>;
  47. /// An User Datagram Protocol socket.
  48. ///
  49. /// An UDP socket is bound to a specific endpoint, and owns transmit and receive
  50. /// packet buffers.
  51. #[derive(Debug)]
  52. pub struct UdpSocket<'a, 'b: 'a> {
  53. debug_id: usize,
  54. endpoint: IpEndpoint,
  55. rx_buffer: SocketBuffer<'a, 'b>,
  56. tx_buffer: SocketBuffer<'a, 'b>,
  57. }
  58. impl<'a, 'b> UdpSocket<'a, 'b> {
  59. /// Create an UDP socket with the given buffers.
  60. pub fn new(rx_buffer: SocketBuffer<'a, 'b>,
  61. tx_buffer: SocketBuffer<'a, 'b>) -> Socket<'a, 'b> {
  62. Socket::Udp(UdpSocket {
  63. debug_id: 0,
  64. endpoint: IpEndpoint::default(),
  65. rx_buffer: rx_buffer,
  66. tx_buffer: tx_buffer,
  67. })
  68. }
  69. /// Return the debug identifier.
  70. #[inline]
  71. pub fn debug_id(&self) -> usize {
  72. self.debug_id
  73. }
  74. /// Set the debug identifier.
  75. ///
  76. /// The debug identifier is a number printed in socket trace messages.
  77. /// It could as well be used by the user code.
  78. pub fn set_debug_id(&mut self, id: usize) {
  79. self.debug_id = id
  80. }
  81. /// Return the bound endpoint.
  82. #[inline]
  83. pub fn endpoint(&self) -> IpEndpoint {
  84. self.endpoint
  85. }
  86. /// Bind the socket to the given endpoint.
  87. ///
  88. /// This function returns `Err(Error::Illegal)` if the socket was open
  89. /// (see [is_open](#method.is_open)), and `Err(Error::Unaddressable)`
  90. /// if the port in the given endpoint is zero.
  91. pub fn bind<T: Into<IpEndpoint>>(&mut self, endpoint: T) -> Result<()> {
  92. let endpoint = endpoint.into();
  93. if endpoint.port == 0 { return Err(Error::Unaddressable) }
  94. if self.is_open() { return Err(Error::Illegal) }
  95. self.endpoint = endpoint;
  96. Ok(())
  97. }
  98. /// Check whether the socket is open.
  99. #[inline]
  100. pub fn is_open(&self) -> bool {
  101. self.endpoint.port != 0
  102. }
  103. /// Check whether the transmit buffer is full.
  104. #[inline]
  105. pub fn can_send(&self) -> bool {
  106. !self.tx_buffer.full()
  107. }
  108. /// Check whether the receive buffer is not empty.
  109. #[inline]
  110. pub fn can_recv(&self) -> bool {
  111. !self.rx_buffer.empty()
  112. }
  113. /// Enqueue a packet to be sent to a given remote endpoint, and return a pointer
  114. /// to its payload.
  115. ///
  116. /// This function returns `Err(Error::Exhausted)` if the transmit buffer is full,
  117. /// `Err(Error::Truncated)` if the requested size is larger than the packet buffer
  118. /// size, and `Err(Error::Unaddressable)` if local or remote port, or remote address,
  119. /// are unspecified.
  120. pub fn send(&mut self, size: usize, endpoint: IpEndpoint) -> Result<&mut [u8]> {
  121. if self.endpoint.port == 0 { return Err(Error::Unaddressable) }
  122. if !endpoint.is_specified() { return Err(Error::Unaddressable) }
  123. let packet_buf = self.tx_buffer.enqueue_one_with(|buf| buf.resize(size))?;
  124. packet_buf.endpoint = endpoint;
  125. net_trace!("[{}]{}:{}: buffer to send {} octets",
  126. self.debug_id, self.endpoint, packet_buf.endpoint, size);
  127. Ok(&mut packet_buf.as_mut()[..size])
  128. }
  129. /// Enqueue a packet to be sent to a given remote endpoint, and fill it from a slice.
  130. ///
  131. /// See also [send](#method.send).
  132. pub fn send_slice(&mut self, data: &[u8], endpoint: IpEndpoint) -> Result<()> {
  133. self.send(data.len(), endpoint)?.copy_from_slice(data);
  134. Ok(())
  135. }
  136. /// Dequeue a packet received from a remote endpoint, and return the endpoint as well
  137. /// as a pointer to the payload.
  138. ///
  139. /// This function returns `Err(Error::Exhausted)` if the receive buffer is empty.
  140. pub fn recv(&mut self) -> Result<(&[u8], IpEndpoint)> {
  141. let packet_buf = self.rx_buffer.dequeue_one()?;
  142. net_trace!("[{}]{}:{}: receive {} buffered octets",
  143. self.debug_id, self.endpoint,
  144. packet_buf.endpoint, packet_buf.size);
  145. Ok((&packet_buf.as_ref(), packet_buf.endpoint))
  146. }
  147. /// Dequeue a packet received from a remote endpoint, and return the endpoint as well
  148. /// as copy the payload into the given slice.
  149. ///
  150. /// See also [recv](#method.recv).
  151. pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, IpEndpoint)> {
  152. let (buffer, endpoint) = self.recv()?;
  153. let length = min(data.len(), buffer.len());
  154. data[..length].copy_from_slice(&buffer[..length]);
  155. Ok((length, endpoint))
  156. }
  157. pub(crate) fn accepts(&self, ip_repr: &IpRepr, repr: &UdpRepr) -> bool {
  158. if self.endpoint.port != repr.dst_port { return false }
  159. if !self.endpoint.addr.is_unspecified() &&
  160. self.endpoint.addr != ip_repr.dst_addr() { return false }
  161. true
  162. }
  163. pub(crate) fn process(&mut self, ip_repr: &IpRepr, repr: &UdpRepr) -> Result<()> {
  164. debug_assert!(self.accepts(ip_repr, repr));
  165. let packet_buf = self.rx_buffer.enqueue_one_with(|buf| buf.resize(repr.payload.len()))?;
  166. packet_buf.as_mut().copy_from_slice(repr.payload);
  167. packet_buf.endpoint = IpEndpoint { addr: ip_repr.src_addr(), port: repr.src_port };
  168. net_trace!("[{}]{}:{}: receiving {} octets",
  169. self.debug_id, self.endpoint,
  170. packet_buf.endpoint, packet_buf.size);
  171. Ok(())
  172. }
  173. pub(crate) fn dispatch<F>(&mut self, emit: F) -> Result<()>
  174. where F: FnOnce((IpRepr, UdpRepr)) -> Result<()> {
  175. let debug_id = self.debug_id;
  176. let endpoint = self.endpoint;
  177. self.tx_buffer.dequeue_one_with(|packet_buf| {
  178. net_trace!("[{}]{}:{}: sending {} octets",
  179. debug_id, endpoint,
  180. packet_buf.endpoint, packet_buf.size);
  181. let repr = UdpRepr {
  182. src_port: endpoint.port,
  183. dst_port: packet_buf.endpoint.port,
  184. payload: &packet_buf.as_ref()[..]
  185. };
  186. let ip_repr = IpRepr::Unspecified {
  187. src_addr: endpoint.addr,
  188. dst_addr: packet_buf.endpoint.addr,
  189. protocol: IpProtocol::Udp,
  190. payload_len: repr.buffer_len()
  191. };
  192. emit((ip_repr, repr))
  193. })
  194. }
  195. pub(crate) fn poll_at(&self) -> Option<u64> {
  196. if self.tx_buffer.empty() {
  197. None
  198. } else {
  199. Some(0)
  200. }
  201. }
  202. }
  #[cfg(test)]
  mod test {
      use wire::{IpAddress, Ipv4Address, IpRepr, Ipv4Repr, UdpRepr};
      use super::*;

      // ---- Fixtures shared by the tests below ----

      const LOCAL_IP:    IpAddress  = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 1]));
      const REMOTE_IP:   IpAddress  = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 2]));
      const LOCAL_PORT:  u16        = 53;
      const REMOTE_PORT: u16        = 49500;
      const LOCAL_END:   IpEndpoint = IpEndpoint { addr: LOCAL_IP,  port: LOCAL_PORT  };
      const REMOTE_END:  IpEndpoint = IpEndpoint { addr: REMOTE_IP, port: REMOTE_PORT };

      // What the socket is expected to emit when sending "abcdef" to the remote end.
      const LOCAL_IP_REPR: IpRepr = IpRepr::Unspecified {
          src_addr: LOCAL_IP,
          dst_addr: REMOTE_IP,
          protocol: IpProtocol::Udp,
          payload_len: 8 + 6
      };

      const LOCAL_UDP_REPR: UdpRepr = UdpRepr {
          src_port: LOCAL_PORT,
          dst_port: REMOTE_PORT,
          payload: b"abcdef"
      };

      // A datagram carrying "abcdef" from the remote end to the local end.
      const REMOTE_IP_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr {
          src_addr: Ipv4Address([10, 0, 0, 2]),
          dst_addr: Ipv4Address([10, 0, 0, 1]),
          protocol: IpProtocol::Udp,
          payload_len: 8 + 6
      });

      const REMOTE_UDP_REPR: UdpRepr = UdpRepr {
          src_port: REMOTE_PORT,
          dst_port: LOCAL_PORT,
          payload: b"abcdef"
      };

      /// Build a ring buffer holding `packets` packet buffers of 16 octets each.
      fn buffer(packets: usize) -> SocketBuffer<'static, 'static> {
          let mut storage = vec![];
          storage.extend((0..packets).map(|_| PacketBuffer::new(vec![0; 16])));
          SocketBuffer::new(storage)
      }

      /// Build a UDP socket and unwrap it from the `Socket` enum.
      fn socket(rx_buffer: SocketBuffer<'static, 'static>,
                tx_buffer: SocketBuffer<'static, 'static>)
              -> UdpSocket<'static, 'static> {
          match UdpSocket::new(rx_buffer, tx_buffer) {
              Socket::Udp(socket) => socket,
              _ => unreachable!()
          }
      }

      // ---- Binding ----

      #[test]
      fn test_bind_unaddressable() {
          let mut sock = socket(buffer(0), buffer(0));
          assert_eq!(sock.bind(0), Err(Error::Unaddressable));
      }

      #[test]
      fn test_bind_twice() {
          let mut sock = socket(buffer(0), buffer(0));
          assert_eq!(sock.bind(1), Ok(()));
          assert_eq!(sock.bind(2), Err(Error::Illegal));
      }

      // ---- Sending ----

      #[test]
      fn test_send_unaddressable() {
          let mut sock = socket(buffer(0), buffer(1));

          // Not bound yet.
          assert_eq!(sock.send_slice(b"abcdef", REMOTE_END), Err(Error::Unaddressable));

          assert_eq!(sock.bind(LOCAL_PORT), Ok(()));

          // Bound, but the remote address or port is missing.
          assert_eq!(sock.send_slice(b"abcdef",
                                     IpEndpoint { addr: IpAddress::Unspecified, ..REMOTE_END }),
                     Err(Error::Unaddressable));
          assert_eq!(sock.send_slice(b"abcdef",
                                     IpEndpoint { port: 0, ..REMOTE_END }),
                     Err(Error::Unaddressable));

          assert_eq!(sock.send_slice(b"abcdef", REMOTE_END), Ok(()));
      }

      #[test]
      fn test_send_truncated() {
          let mut sock = socket(buffer(0), buffer(1));
          assert_eq!(sock.bind(LOCAL_END), Ok(()));

          // 32 octets do not fit into a 16-octet packet buffer.
          assert_eq!(sock.send_slice(&[0; 32][..], REMOTE_END), Err(Error::Truncated));
      }

      #[test]
      fn test_send_dispatch() {
          let mut sock = socket(buffer(0), buffer(1));
          assert_eq!(sock.bind(LOCAL_END), Ok(()));

          // Nothing queued: the emit callback must not run.
          assert!(sock.can_send());
          assert_eq!(sock.dispatch(|_| unreachable!()),
                     Err(Error::Exhausted));

          // The single transmit slot gets used up.
          assert_eq!(sock.send_slice(b"abcdef", REMOTE_END), Ok(()));
          assert_eq!(sock.send_slice(b"123456", REMOTE_END), Err(Error::Exhausted));
          assert!(!sock.can_send());

          // A failing emit leaves the packet queued.
          assert_eq!(sock.dispatch(|(ip_repr, udp_repr)| {
              assert_eq!(ip_repr, LOCAL_IP_REPR);
              assert_eq!(udp_repr, LOCAL_UDP_REPR);
              Err(Error::Unaddressable)
          }), Err(Error::Unaddressable));
          assert!(!sock.can_send());

          // A successful emit frees the slot.
          assert_eq!(sock.dispatch(|(ip_repr, udp_repr)| {
              assert_eq!(ip_repr, LOCAL_IP_REPR);
              assert_eq!(udp_repr, LOCAL_UDP_REPR);
              Ok(())
          }), Ok(()));
          assert!(sock.can_send());
      }

      // ---- Receiving ----

      #[test]
      fn test_recv_process() {
          let mut sock = socket(buffer(1), buffer(0));
          assert_eq!(sock.bind(LOCAL_PORT), Ok(()));

          assert!(!sock.can_recv());
          assert_eq!(sock.recv(), Err(Error::Exhausted));

          assert!(sock.accepts(&REMOTE_IP_REPR, &REMOTE_UDP_REPR));
          assert_eq!(sock.process(&REMOTE_IP_REPR, &REMOTE_UDP_REPR),
                     Ok(()));
          assert!(sock.can_recv());

          // The single receive slot is taken now.
          assert!(sock.accepts(&REMOTE_IP_REPR, &REMOTE_UDP_REPR));
          assert_eq!(sock.process(&REMOTE_IP_REPR, &REMOTE_UDP_REPR),
                     Err(Error::Exhausted));

          assert_eq!(sock.recv(), Ok((&b"abcdef"[..], REMOTE_END)));
          assert!(!sock.can_recv());
      }

      #[test]
      fn test_recv_truncated_slice() {
          let mut sock = socket(buffer(1), buffer(0));
          assert_eq!(sock.bind(LOCAL_PORT), Ok(()));

          assert!(sock.accepts(&REMOTE_IP_REPR, &REMOTE_UDP_REPR));
          assert_eq!(sock.process(&REMOTE_IP_REPR, &REMOTE_UDP_REPR),
                     Ok(()));

          // Only the first four octets fit into the destination.
          let mut dest = [0; 4];
          assert_eq!(sock.recv_slice(&mut dest[..]), Ok((4, REMOTE_END)));
          assert_eq!(&dest, b"abcd");
      }

      #[test]
      fn test_recv_truncated_packet() {
          let mut sock = socket(buffer(1), buffer(0));
          assert_eq!(sock.bind(LOCAL_PORT), Ok(()));

          // 100 octets do not fit into a 16-octet packet buffer.
          let oversized = [0u8; 100];
          let udp_repr = UdpRepr { payload: &oversized[..], ..REMOTE_UDP_REPR };
          assert!(sock.accepts(&REMOTE_IP_REPR, &udp_repr));
          assert_eq!(sock.process(&REMOTE_IP_REPR, &udp_repr),
                     Err(Error::Truncated));
      }

      // ---- Filtering ----

      #[test]
      fn test_doesnt_accept_wrong_port() {
          let mut sock = socket(buffer(1), buffer(0));
          assert_eq!(sock.bind(LOCAL_PORT), Ok(()));

          assert!(sock.accepts(&REMOTE_IP_REPR, &REMOTE_UDP_REPR));

          let wrong_port = UdpRepr {
              dst_port: REMOTE_UDP_REPR.dst_port + 1,
              ..REMOTE_UDP_REPR
          };
          assert!(!sock.accepts(&REMOTE_IP_REPR, &wrong_port));
      }

      #[test]
      fn test_doesnt_accept_wrong_ip() {
          // Addressed to 10.0.0.10 rather than to LOCAL_IP.
          let ip_repr = IpRepr::Ipv4(Ipv4Repr {
              src_addr: Ipv4Address([10, 0, 0, 2]),
              dst_addr: Ipv4Address([10, 0, 0, 10]),
              protocol: IpProtocol::Udp,
              payload_len: 8 + 6
          });

          // Bound to a port only: any destination address is accepted.
          let mut port_bound = socket(buffer(1), buffer(0));
          assert_eq!(port_bound.bind(LOCAL_PORT), Ok(()));
          assert!(port_bound.accepts(&ip_repr, &REMOTE_UDP_REPR));

          // Bound to an address as well: the mismatch is rejected.
          let mut ip_bound = socket(buffer(1), buffer(0));
          assert_eq!(ip_bound.bind(LOCAL_END), Ok(()));
          assert!(!ip_bound.accepts(&ip_repr, &REMOTE_UDP_REPR));
      }
  }
  363. }