Przeglądaj źródła

Use FnOnce, not FnMut, in Socket::dispatch() functions.

There was never any reason to use FnMut and this significantly
simplifies the job of the borrow checker.
whitequark 7 lat temu
rodzic
commit
917f89e14b
5 zmienionych plików z 28 dodań i 35 usunięć
  1. 9 16
      src/iface/ethernet.rs
  2. 2 2
      src/socket/mod.rs
  3. 7 7
      src/socket/raw.rs
  4. 5 5
      src/socket/tcp.rs
  5. 5 5
      src/socket/udp.rs

+ 9 - 16
src/iface/ethernet.rs

@@ -403,30 +403,23 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
     }
 
     fn emit(&mut self, sockets: &mut SocketSet, timestamp: u64) -> Result<bool> {
-        // Borrow checker is being overly careful around closures, so we have
-        // to hack around that.
-        let src_hardware_addr = self.hardware_addr;
-        let src_protocol_addrs = self.protocol_addrs.as_ref();
-        let arp_cache = &mut self.arp_cache;
-        let device = &mut self.device;
-
-        let mut limits = device.limits();
+        let mut limits = self.device.limits();
         limits.max_transmission_unit -= EthernetFrame::<&[u8]>::header_len();
 
         let mut nothing_to_transmit = true;
         for socket in sockets.iter_mut() {
-            let result = socket.dispatch(timestamp, &limits, &mut |repr, payload| {
-                let repr = repr.lower(src_protocol_addrs)?;
+            let result = socket.dispatch(timestamp, &limits, |repr, payload| {
+                let repr = repr.lower(self.protocol_addrs.as_ref())?;
 
-                match arp_cache.lookup(&repr.dst_addr()) {
+                match self.arp_cache.lookup(&repr.dst_addr()) {
                     Some(dst_hardware_addr) => {
                         let tx_len = EthernetFrame::<&[u8]>::buffer_len(repr.buffer_len() +
                                                                         payload.buffer_len());
-                        let mut tx_buffer = device.transmit(timestamp, tx_len)?;
+                        let mut tx_buffer = self.device.transmit(timestamp, tx_len)?;
                         debug_assert!(tx_buffer.as_ref().len() == tx_len);
 
                         let mut frame = EthernetFrame::new(&mut tx_buffer);
-                        frame.set_src_addr(src_hardware_addr);
+                        frame.set_src_addr(self.hardware_addr);
                         frame.set_dst_addr(dst_hardware_addr);
                         frame.set_ethertype(EthernetProtocol::Ipv4);
 
@@ -447,18 +440,18 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
 
                         let payload = ArpRepr::EthernetIpv4 {
                             operation: ArpOperation::Request,
-                            source_hardware_addr: src_hardware_addr,
+                            source_hardware_addr: self.hardware_addr,
                             source_protocol_addr: src_addr,
                             target_hardware_addr: EthernetAddress::default(),
                             target_protocol_addr: dst_addr,
                         };
 
                         let tx_len = EthernetFrame::<&[u8]>::buffer_len(payload.buffer_len());
-                        let mut tx_buffer = device.transmit(timestamp, tx_len)?;
+                        let mut tx_buffer = self.device.transmit(timestamp, tx_len)?;
                         debug_assert!(tx_buffer.as_ref().len() == tx_len);
 
                         let mut frame = EthernetFrame::new(&mut tx_buffer);
-                        frame.set_src_addr(src_hardware_addr);
+                        frame.set_src_addr(self.hardware_addr);
                         frame.set_dst_addr(EthernetAddress([0xff; 6]));
                         frame.set_ethertype(EthernetProtocol::Arp);
 

+ 2 - 2
src/socket/mod.rs

@@ -82,8 +82,8 @@ impl<'a, 'b> Socket<'a, 'b> {
     }
 
     pub(crate) fn dispatch<F, R>(&mut self, timestamp: u64, limits: &DeviceLimits,
-                                 emit: &mut F) -> Result<R>
-            where F: FnMut(&IpRepr, &IpPayload) -> Result<R> {
+                                 emit: F) -> Result<R>
+            where F: FnOnce(&IpRepr, &IpPayload) -> Result<R> {
         dispatch_socket!(self, |socket [mut]| socket.dispatch(timestamp, limits, emit))
     }
 }

+ 7 - 7
src/socket/raw.rs

@@ -184,8 +184,8 @@ impl<'a, 'b> RawSocket<'a, 'b> {
     }
 
     pub(crate) fn dispatch<F, R>(&mut self, _timestamp: u64, _limits: &DeviceLimits,
-                                 emit: &mut F) -> Result<R>
-            where F: FnMut(&IpRepr, &IpPayload) -> Result<R> {
+                                 emit: F) -> Result<R>
+            where F: FnOnce(&IpRepr, &IpPayload) -> Result<R> {
         fn prepare(protocol: IpProtocol, buffer: &mut [u8]) -> Result<(IpRepr, RawRepr)> {
             match IpVersion::of_packet(buffer.as_ref())? {
                 IpVersion::Ipv4 => {
@@ -291,7 +291,7 @@ mod test {
         let mut socket = socket(buffer(0), buffer(1));
 
         assert!(socket.can_send());
-        assert_eq!(socket.dispatch(0, &limits, &mut |_ip_repr, _ip_payload| {
+        assert_eq!(socket.dispatch(0, &limits, |_ip_repr, _ip_payload| {
             unreachable!()
         }), Err(Error::Exhausted) as Result<()>);
 
@@ -307,14 +307,14 @@ mod test {
             }}
         }
 
-        assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
+        assert_eq!(socket.dispatch(0, &limits, |ip_repr, ip_payload| {
             assert_eq!(ip_repr, &HEADER_REPR);
             assert_payload_eq!(ip_repr, ip_payload, PACKET_BYTES);
             Err(Error::Unaddressable)
         }), Err(Error::Unaddressable) as Result<()>);
         /*assert!(!socket.can_send());*/
 
-        assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
+        assert_eq!(socket.dispatch(0, &limits, |ip_repr, ip_payload| {
             assert_eq!(ip_repr, &HEADER_REPR);
             assert_payload_eq!(ip_repr, ip_payload, PACKET_BYTES);
             Ok(())
@@ -332,7 +332,7 @@ mod test {
         Ipv4Packet::new(&mut wrong_version).set_version(5);
 
         assert_eq!(socket.send_slice(&wrong_version[..]), Ok(()));
-        assert_eq!(socket.dispatch(0, &limits, &mut |_ip_repr, _ip_payload| {
+        assert_eq!(socket.dispatch(0, &limits, |_ip_repr, _ip_payload| {
             unreachable!()
         }), Err(Error::Rejected) as Result<()>);
 
@@ -340,7 +340,7 @@ mod test {
         Ipv4Packet::new(&mut wrong_protocol).set_protocol(IpProtocol::Tcp);
 
         assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(()));
-        assert_eq!(socket.dispatch(0, &limits, &mut |_ip_repr, _ip_payload| {
+        assert_eq!(socket.dispatch(0, &limits, |_ip_repr, _ip_payload| {
             unreachable!()
         }), Err(Error::Rejected) as Result<()>);
     }

+ 5 - 5
src/socket/tcp.rs

@@ -1091,8 +1091,8 @@ impl<'a> TcpSocket<'a> {
     }
 
     pub(crate) fn dispatch<F, R>(&mut self, timestamp: u64, limits: &DeviceLimits,
-                                 emit: &mut F) -> Result<R>
-            where F: FnMut(&IpRepr, &IpPayload) -> Result<R> {
+                                 emit: F) -> Result<R>
+            where F: FnOnce(&IpRepr, &IpPayload) -> Result<R> {
         if !self.remote_endpoint.is_specified() { return Err(Error::Exhausted) }
 
         if let Some(retransmit_delta) = self.timer.should_retransmit(timestamp) {
@@ -1392,7 +1392,7 @@ mod test {
         let mut buffer = vec![];
         let mut limits = DeviceLimits::default();
         limits.max_transmission_unit = 1520;
-        let result = socket.dispatch(timestamp, &limits, &mut |ip_repr, payload| {
+        let result = socket.dispatch(timestamp, &limits, |ip_repr, payload| {
             let ip_repr = ip_repr.lower(&[LOCAL_END.addr.into()]).unwrap();
 
             assert_eq!(ip_repr.protocol(), IpProtocol::Tcp);
@@ -2869,7 +2869,7 @@ mod test {
 
         limits.max_burst_size = None;
         s.send_slice(b"abcdef").unwrap();
-        s.dispatch(0, &limits, &mut |ip_repr, payload| {
+        s.dispatch(0, &limits, |ip_repr, payload| {
             let mut buffer = vec![0; payload.buffer_len()];
             payload.emit(&ip_repr, &mut buffer[..]);
             let packet = TcpPacket::new(&buffer[..]);
@@ -2879,7 +2879,7 @@ mod test {
 
         limits.max_burst_size = Some(4);
         s.send_slice(b"abcdef").unwrap();
-        s.dispatch(0, &limits, &mut |ip_repr, payload| {
+        s.dispatch(0, &limits, |ip_repr, payload| {
             let mut buffer = vec![0; payload.buffer_len()];
             payload.emit(&ip_repr, &mut buffer[..]);
             let packet = TcpPacket::new(&buffer[..]);

+ 5 - 5
src/socket/udp.rs

@@ -203,8 +203,8 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
     }
 
     pub(crate) fn dispatch<F, R>(&mut self, _timestamp: u64, _limits: &DeviceLimits,
-                                 emit: &mut F) -> Result<R>
-            where F: FnMut(&IpRepr, &IpPayload) -> Result<R> {
+                                 emit: F) -> Result<R>
+            where F: FnOnce(&IpRepr, &IpPayload) -> Result<R> {
         let packet_buf = self.tx_buffer.dequeue()?;
         net_trace!("[{}]{}:{}: sending {} octets",
                    self.debug_id, self.endpoint,
@@ -320,7 +320,7 @@ mod test {
         assert_eq!(socket.bind(LOCAL_END), Ok(()));
 
         assert!(socket.can_send());
-        assert_eq!(socket.dispatch(0, &limits, &mut |_ip_repr, _ip_payload| {
+        assert_eq!(socket.dispatch(0, &limits, |_ip_repr, _ip_payload| {
             unreachable!()
         }), Err(Error::Exhausted) as Result<()>);
 
@@ -338,14 +338,14 @@ mod test {
             }}
         }
 
-        assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
+        assert_eq!(socket.dispatch(0, &limits, |ip_repr, ip_payload| {
             assert_eq!(ip_repr, &LOCAL_IP_REPR);
             assert_payload_eq!(ip_repr, ip_payload, &LOCAL_UDP_REPR);
             Err(Error::Unaddressable)
         }), Err(Error::Unaddressable) as Result<()>);
         /*assert!(!socket.can_send());*/
 
-        assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
+        assert_eq!(socket.dispatch(0, &limits, |ip_repr, ip_payload| {
             assert_eq!(ip_repr, &LOCAL_IP_REPR);
             assert_payload_eq!(ip_repr, ip_payload, &LOCAL_UDP_REPR);
             Ok(())