Эх сурвалжийг харах

Distinguish sockets by debug identifiers (socket set indexes).

whitequark 8 жил өмнө
parent
commit
16826628fe
4 өөрчлөгдсөн 125 нэмэгдсэн , 74 устгасан
  1. 25 14
      src/socket/mod.rs
  2. 5 2
      src/socket/set.rs
  3. 66 48
      src/socket/tcp.rs
  4. 29 10
      src/socket/udp.rs

+ 25 - 14
src/socket/mod.rs

@@ -49,7 +49,30 @@ pub enum Socket<'a, 'b: 'a> {
     __Nonexhaustive
 }
 
+macro_rules! dispatch_socket {
+    ($self_:expr, |$socket:ident [$( $mut_:tt )*]| $code:expr) => ({
+        match $self_ {
+            &$( $mut_ )* Socket::Udp(ref $( $mut_ )* $socket) => $code,
+            &$( $mut_ )* Socket::Tcp(ref $( $mut_ )* $socket) => $code,
+            &$( $mut_ )* Socket::__Nonexhaustive => unreachable!()
+        }
+    })
+}
+
 impl<'a, 'b> Socket<'a, 'b> {
+    /// Return the debug identifier.
+    pub fn debug_id(&self) -> usize {
+        dispatch_socket!(self, |socket []| socket.debug_id())
+    }
+
+    /// Set the debug identifier.
+    ///
+    /// The debug identifier is a number printed in socket trace messages.
+    /// It could as well be used by the user code.
+    pub fn set_debug_id(&mut self, id: usize) {
+        dispatch_socket!(self, |socket [mut]| socket.set_debug_id(id))
+    }
+
     /// Process a packet received from a network interface.
     ///
     /// This function checks if the packet contained in the payload matches the socket endpoint,
@@ -59,13 +82,7 @@ impl<'a, 'b> Socket<'a, 'b> {
     /// This function is used internally by the networking stack.
     pub fn process(&mut self, timestamp: u64, ip_repr: &IpRepr,
                    payload: &[u8]) -> Result<(), Error> {
-        match self {
-            &mut Socket::Udp(ref mut socket) =>
-                socket.process(timestamp, ip_repr, payload),
-            &mut Socket::Tcp(ref mut socket) =>
-                socket.process(timestamp, ip_repr, payload),
-            &mut Socket::__Nonexhaustive => unreachable!()
-        }
+        dispatch_socket!(self, |socket [mut]| socket.process(timestamp, ip_repr, payload))
     }
 
     /// Prepare a packet to be transmitted to a network interface.
@@ -77,13 +94,7 @@ impl<'a, 'b> Socket<'a, 'b> {
     /// This function is used internally by the networking stack.
     pub fn dispatch<F, R>(&mut self, timestamp: u64, emit: &mut F) -> Result<R, Error>
             where F: FnMut(&IpRepr, &IpPayload) -> Result<R, Error> {
-        match self {
-            &mut Socket::Udp(ref mut socket) =>
-                socket.dispatch(timestamp, emit),
-            &mut Socket::Tcp(ref mut socket) =>
-                socket.dispatch(timestamp, emit),
-            &mut Socket::__Nonexhaustive => unreachable!()
-        }
+        dispatch_socket!(self, |socket [mut]| socket.dispatch(timestamp, emit))
     }
 }
 

+ 5 - 2
src/socket/set.rs

@@ -29,9 +29,10 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Set<'a, 'b, 'c> {
     ///
     /// # Panics
     /// This function panics if the storage is fixed-size (not a `Vec`) and is full.
-    pub fn add(&mut self, socket: Socket<'b, 'c>) -> Handle {
+    pub fn add(&mut self, mut socket: Socket<'b, 'c>) -> Handle {
         for (index, slot) in self.sockets.iter_mut().enumerate() {
             if slot.is_none() {
+                socket.set_debug_id(index);
                 *slot = Some(socket);
                 return Handle { index: index }
             }
@@ -42,8 +43,10 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Set<'a, 'b, 'c> {
                 panic!("adding a socket to a full SocketSet")
             }
             ManagedSlice::Owned(ref mut sockets) => {
+                let index = sockets.len();
+                socket.set_debug_id(index);
                 sockets.push(Some(socket));
-                Handle { index: sockets.len() - 1 }
+                Handle { index: index }
             }
         }
     }

+ 66 - 48
src/socket/tcp.rs

@@ -247,7 +247,8 @@ pub struct TcpSocket<'a> {
     remote_win_len:  usize,
     retransmit:      Retransmit,
     rx_buffer:       SocketBuffer<'a>,
-    tx_buffer:       SocketBuffer<'a>
+    tx_buffer:       SocketBuffer<'a>,
+    debug_id:        usize
 }
 
 impl<'a> TcpSocket<'a> {
@@ -272,10 +273,24 @@ impl<'a> TcpSocket<'a> {
             remote_win_len:  0,
             retransmit:      Retransmit::new(),
             tx_buffer:       tx_buffer.into(),
-            rx_buffer:       rx_buffer.into()
+            rx_buffer:       rx_buffer.into(),
+            debug_id:        0
         })
     }
 
+    /// Return the debug identifier.
+    pub fn debug_id(&self) -> usize {
+        self.debug_id
+    }
+
+    /// Set the debug identifier.
+    ///
+    /// The debug identifier is a number printed in socket trace messages.
+    /// It could as well be used by the user code.
+    pub fn set_debug_id(&mut self, id: usize) {
+        self.debug_id = id
+    }
+
     /// Return the local endpoint.
     #[inline]
     pub fn local_endpoint(&self) -> IpEndpoint {
@@ -436,8 +451,8 @@ impl<'a> TcpSocket<'a> {
         let old_length = self.tx_buffer.len();
         let buffer = self.tx_buffer.enqueue(size);
         if buffer.len() > 0 {
-            net_trace!("tcp:{}:{}: tx buffer: enqueueing {} octets (now {})",
-                       self.local_endpoint, self.remote_endpoint,
+            net_trace!("[{}]{}:{}: tx buffer: enqueueing {} octets (now {})",
+                       self.debug_id, self.local_endpoint, self.remote_endpoint,
                        buffer.len(), old_length + buffer.len());
             self.retransmit.reset();
         }
@@ -471,8 +486,8 @@ impl<'a> TcpSocket<'a> {
         let buffer = self.rx_buffer.dequeue(size);
         self.remote_seq_no += buffer.len();
         if buffer.len() > 0 {
-            net_trace!("tcp:{}:{}: rx buffer: dequeueing {} octets (now {})",
-                       self.local_endpoint, self.remote_endpoint,
+            net_trace!("[{}]{}:{}: rx buffer: dequeueing {} octets (now {})",
+                       self.debug_id, self.local_endpoint, self.remote_endpoint,
                        buffer.len(), old_length - buffer.len());
         }
         Ok(buffer)
@@ -501,11 +516,13 @@ impl<'a> TcpSocket<'a> {
     fn set_state(&mut self, state: State) {
         if self.state != state {
             if self.remote_endpoint.addr.is_unspecified() {
-                net_trace!("tcp:{}: state={}=>{}",
-                           self.local_endpoint, self.state, state);
+                net_trace!("[{}]{}: state={}=>{}",
+                           self.debug_id, self.local_endpoint,
+                           self.state, state);
             } else {
-                net_trace!("tcp:{}:{}: state={}=>{}",
-                           self.local_endpoint, self.remote_endpoint, self.state, state);
+                net_trace!("[{}]{}:{}: state={}=>{}",
+                           self.debug_id, self.local_endpoint, self.remote_endpoint,
+                           self.state, state);
             }
         }
         self.state = state
@@ -534,25 +551,25 @@ impl<'a> TcpSocket<'a> {
         match (self.state, repr) {
             // The initial SYN (or whatever) cannot contain an acknowledgement.
             (State::Listen, TcpRepr { ack_number: Some(_), .. }) => {
-                net_trace!("tcp:{}:{}: ACK received by a socket in LISTEN state",
-                           self.local_endpoint, self.remote_endpoint);
+                net_trace!("[{}]{}:{}: ACK received by a socket in LISTEN state",
+                           self.debug_id, self.local_endpoint, self.remote_endpoint);
                 return Err(Error::Malformed)
             }
             (State::Listen, TcpRepr { ack_number: None, .. }) => (),
             // An RST received in response to initial SYN is acceptable if it acknowledges
             // the initial SYN.
             (State::SynSent, TcpRepr { control: TcpControl::Rst, ack_number: None, .. }) => {
-                net_trace!("tcp:{}:{}: unacceptable RST (expecting RST|ACK) \
+                net_trace!("[{}]{}:{}: unacceptable RST (expecting RST|ACK) \
                             in response to initial SYN",
-                           self.local_endpoint, self.remote_endpoint);
+                           self.debug_id, self.local_endpoint, self.remote_endpoint);
                 return Err(Error::Malformed)
             }
             (State::SynSent, TcpRepr {
                 control: TcpControl::Rst, ack_number: Some(ack_number), ..
             }) => {
                 if ack_number != self.local_seq_no {
-                    net_trace!("tcp:{}:{}: unacceptable RST|ACK in response to initial SYN",
-                               self.local_endpoint, self.remote_endpoint);
+                    net_trace!("[{}]{}:{}: unacceptable RST|ACK in response to initial SYN",
+                               self.debug_id, self.local_endpoint, self.remote_endpoint);
                     return Err(Error::Malformed)
                 }
             }
@@ -560,8 +577,8 @@ impl<'a> TcpSocket<'a> {
             (_, TcpRepr { control: TcpControl::Rst, .. }) => (),
             // Every packet after the initial SYN must be an acknowledgement.
             (_, TcpRepr { ack_number: None, .. }) => {
-                net_trace!("tcp:{}:{}: expecting an ACK",
-                           self.local_endpoint, self.remote_endpoint);
+                net_trace!("[{}]{}:{}: expecting an ACK",
+                           self.debug_id, self.local_endpoint, self.remote_endpoint);
                 return Err(Error::Malformed)
             }
             // Every acknowledgement must be for transmitted but unacknowledged data.
@@ -578,8 +595,8 @@ impl<'a> TcpSocket<'a> {
                 let unacknowledged = self.tx_buffer.len() + control_len;
                 if !(ack_number >= self.local_seq_no &&
                      ack_number <= (self.local_seq_no + unacknowledged)) {
-                    net_trace!("tcp:{}:{}: unacceptable ACK ({} not in {}...{})",
-                               self.local_endpoint, self.remote_endpoint,
+                    net_trace!("[{}]{}:{}: unacceptable ACK ({} not in {}...{})",
+                               self.debug_id, self.local_endpoint, self.remote_endpoint,
                                ack_number, self.local_seq_no, self.local_seq_no + unacknowledged);
                     return Err(Error::Dropped)
                 }
@@ -595,13 +612,13 @@ impl<'a> TcpSocket<'a> {
             (_, TcpRepr { seq_number, .. }) => {
                 let next_remote_seq = self.remote_seq_no + self.rx_buffer.len();
                 if seq_number > next_remote_seq {
-                    net_trace!("tcp:{}:{}: unacceptable SEQ ({} not in {}..)",
-                               self.local_endpoint, self.remote_endpoint,
+                    net_trace!("[{}]{}:{}: unacceptable SEQ ({} not in {}..)",
+                               self.debug_id, self.local_endpoint, self.remote_endpoint,
                                seq_number, next_remote_seq);
                     return Err(Error::Dropped)
                 } else if seq_number != next_remote_seq {
-                    net_trace!("tcp:{}:{}: duplicate SEQ ({} in ..{})",
-                               self.local_endpoint, self.remote_endpoint,
+                    net_trace!("[{}]{}:{}: duplicate SEQ ({} in ..{})",
+                               self.debug_id, self.local_endpoint, self.remote_endpoint,
                                seq_number, next_remote_seq);
                     // If we've seen this sequence number already but the remote end is not aware
                     // of that, make sure we send the acknowledgement again.
@@ -620,8 +637,8 @@ impl<'a> TcpSocket<'a> {
 
             // RSTs in SYN-RECEIVED flip the socket back to the LISTEN state.
             (State::SynReceived, TcpRepr { control: TcpControl::Rst, .. }) => {
-                net_trace!("tcp:{}:{}: received RST",
-                           self.local_endpoint, self.remote_endpoint);
+                net_trace!("[{}]{}:{}: received RST",
+                           self.debug_id, self.local_endpoint, self.remote_endpoint);
                 self.local_endpoint.addr = self.listen_address;
                 self.remote_endpoint     = IpEndpoint::default();
                 self.set_state(State::Listen);
@@ -630,8 +647,8 @@ impl<'a> TcpSocket<'a> {
 
             // RSTs in any other state close the socket.
             (_, TcpRepr { control: TcpControl::Rst, .. }) => {
-                net_trace!("tcp:{}:{}: received RST",
-                           self.local_endpoint, self.remote_endpoint);
+                net_trace!("[{}]{}:{}: received RST",
+                           self.debug_id, self.local_endpoint, self.remote_endpoint);
                 self.set_state(State::Closed);
                 self.local_endpoint  = IpEndpoint::default();
                 self.remote_endpoint = IpEndpoint::default();
@@ -642,8 +659,8 @@ impl<'a> TcpSocket<'a> {
             (State::Listen, TcpRepr {
                 src_port, dst_port, control: TcpControl::Syn, seq_number, ack_number: None, ..
             }) => {
-                net_trace!("tcp:{}: received SYN",
-                           self.local_endpoint);
+                net_trace!("[{}]{}: received SYN",
+                           self.debug_id, self.local_endpoint);
                 self.local_endpoint  = IpEndpoint::new(ip_repr.dst_addr(), dst_port);
                 self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr(), src_port);
                 // FIXME: use something more secure here
@@ -710,8 +727,8 @@ impl<'a> TcpSocket<'a> {
             }
 
             _ => {
-                net_trace!("tcp:{}:{}: unexpected packet {}",
-                           self.local_endpoint, self.remote_endpoint, repr);
+                net_trace!("[{}]{}:{}: unexpected packet {}",
+                           self.debug_id, self.local_endpoint, self.remote_endpoint, repr);
                 return Err(Error::Malformed)
             }
         }
@@ -720,8 +737,8 @@ impl<'a> TcpSocket<'a> {
         if let Some(ack_number) = repr.ack_number {
             let ack_length = ack_number - self.local_seq_no;
             if ack_length > 0 {
-                net_trace!("tcp:{}:{}: tx buffer: dequeueing {} octets (now {})",
-                           self.local_endpoint, self.remote_endpoint,
+                net_trace!("[{}]{}:{}: tx buffer: dequeueing {} octets (now {})",
+                           self.debug_id, self.local_endpoint, self.remote_endpoint,
                            ack_length, self.tx_buffer.len() - ack_length);
             }
             self.tx_buffer.advance(ack_length);
@@ -730,8 +747,8 @@ impl<'a> TcpSocket<'a> {
 
         // Enqueue payload octets, which is guaranteed to be in order, unless we already did.
         if repr.payload.len() > 0 {
-            net_trace!("tcp:{}:{}: rx buffer: enqueueing {} octets (now {})",
-                       self.local_endpoint, self.remote_endpoint,
+            net_trace!("[{}]{}:{}: rx buffer: enqueueing {} octets (now {})",
+                       self.debug_id, self.local_endpoint, self.remote_endpoint,
                        repr.payload.len(), self.rx_buffer.len() + repr.payload.len());
             self.rx_buffer.enqueue_slice(repr.payload)
         }
@@ -782,8 +799,8 @@ impl<'a> TcpSocket<'a> {
             // We transmit a SYN|ACK in the SYN-RECEIVED state.
             State::SynReceived => {
                 repr.control = TcpControl::Syn;
-                net_trace!("tcp:{}:{}: sending SYN|ACK",
-                           self.local_endpoint, self.remote_endpoint);
+                net_trace!("[{}]{}:{}: sending SYN|ACK",
+                           self.debug_id, self.local_endpoint, self.remote_endpoint);
                 should_send = true;
             }
 
@@ -791,8 +808,8 @@ impl<'a> TcpSocket<'a> {
             State::SynSent => {
                 repr.control = TcpControl::Syn;
                 repr.ack_number = None;
-                net_trace!("tcp:{}:{}: sending SYN",
-                           self.local_endpoint, self.remote_endpoint);
+                net_trace!("[{}]{}:{}: sending SYN",
+                           self.debug_id, self.local_endpoint, self.remote_endpoint);
                 should_send = true;
             }
 
@@ -816,8 +833,8 @@ impl<'a> TcpSocket<'a> {
                 let data = self.tx_buffer.peek(offset, size);
                 if data.len() > 0 {
                     // Send the extracted data.
-                    net_trace!("tcp:{}:{}: tx buffer: peeking at {} octets (from {})",
-                               self.local_endpoint, self.remote_endpoint,
+                    net_trace!("[{}]{}:{}: tx buffer: peeking at {} octets (from {})",
+                               self.debug_id, self.local_endpoint, self.remote_endpoint,
                                data.len(), offset);
                     repr.seq_number += offset;
                     repr.payload = data;
@@ -832,8 +849,8 @@ impl<'a> TcpSocket<'a> {
                     State::FinWait1 | State::LastAck => {
                         // We should notify the other side that we've closed the transmit half
                         // of the connection.
-                        net_trace!("tcp:{}:{}: sending FIN|ACK",
-                                   self.local_endpoint, self.remote_endpoint);
+                        net_trace!("[{}]{}:{}: sending FIN|ACK",
+                                   self.debug_id, self.local_endpoint, self.remote_endpoint);
                         repr.control = TcpControl::Fin;
                         should_send = true;
                     }
@@ -850,15 +867,16 @@ impl<'a> TcpSocket<'a> {
         let ack_number = self.remote_seq_no + self.rx_buffer.len();
         if !should_send && self.remote_last_ack != ack_number {
             // Acknowledge all data we have received, since it is all in order.
-            net_trace!("tcp:{}:{}: sending ACK",
-                       self.local_endpoint, self.remote_endpoint);
+            net_trace!("[{}]{}:{}: sending ACK",
+                       self.debug_id, self.local_endpoint, self.remote_endpoint);
             should_send = true;
         }
 
         if should_send {
             if self.retransmit.commit(timestamp) {
-                net_trace!("tcp:{}:{}: retransmit after {}ms",
-                           self.local_endpoint, self.remote_endpoint, self.retransmit.delay);
+                net_trace!("[{}]{}:{}: retransmit after {}ms",
+                           self.debug_id, self.local_endpoint, self.remote_endpoint,
+                           self.retransmit.delay);
             }
 
             repr.ack_number = Some(ack_number);

+ 29 - 10
src/socket/udp.rs

@@ -111,7 +111,8 @@ impl<'a, 'b> SocketBuffer<'a, 'b> {
 pub struct UdpSocket<'a, 'b: 'a> {
     endpoint:  IpEndpoint,
     rx_buffer: SocketBuffer<'a, 'b>,
-    tx_buffer: SocketBuffer<'a, 'b>
+    tx_buffer: SocketBuffer<'a, 'b>,
+    debug_id:  usize
 }
 
 impl<'a, 'b> UdpSocket<'a, 'b> {
@@ -121,10 +122,24 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
         Socket::Udp(UdpSocket {
             endpoint:  IpEndpoint::default(),
             rx_buffer: rx_buffer,
-            tx_buffer: tx_buffer
+            tx_buffer: tx_buffer,
+            debug_id:  0
         })
     }
 
+    /// Return the debug identifier.
+    pub fn debug_id(&self) -> usize {
+        self.debug_id
+    }
+
+    /// Set the debug identifier.
+    ///
+    /// The debug identifier is a number printed in socket trace messages.
+    /// It could as well be used by the user code.
+    pub fn set_debug_id(&mut self, id: usize) {
+        self.debug_id = id
+    }
+
     /// Return the bound endpoint.
     #[inline]
     pub fn endpoint(&self) -> IpEndpoint {
@@ -155,8 +170,9 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
         let packet_buf = try!(self.tx_buffer.enqueue());
         packet_buf.endpoint = endpoint;
         packet_buf.size = size;
-        net_trace!("udp:{}:{}: buffer to send {} octets",
-                   self.endpoint, packet_buf.endpoint, packet_buf.size);
+        net_trace!("[{}]{}:{}: buffer to send {} octets",
+                   self.debug_id, self.endpoint,
+                   packet_buf.endpoint, packet_buf.size);
         Ok(&mut packet_buf.as_mut()[..size])
     }
 
@@ -176,8 +192,9 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
     /// This function returns `Err(())` if the receive buffer is empty.
     pub fn recv(&mut self) -> Result<(&[u8], IpEndpoint), ()> {
         let packet_buf = try!(self.rx_buffer.dequeue());
-        net_trace!("udp:{}:{}: receive {} buffered octets",
-                   self.endpoint, packet_buf.endpoint, packet_buf.size);
+        net_trace!("[{}]{}:{}: receive {} buffered octets",
+                   self.debug_id, self.endpoint,
+                   packet_buf.endpoint, packet_buf.size);
         Ok((&packet_buf.as_ref()[..packet_buf.size], packet_buf.endpoint))
     }
 
@@ -208,8 +225,9 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
         packet_buf.endpoint = IpEndpoint { addr: ip_repr.src_addr(), port: repr.src_port };
         packet_buf.size = repr.payload.len();
         packet_buf.as_mut()[..repr.payload.len()].copy_from_slice(repr.payload);
-        net_trace!("udp:{}:{}: receiving {} octets",
-                   self.endpoint, packet_buf.endpoint, packet_buf.size);
+        net_trace!("[{}]{}:{}: receiving {} octets",
+                   self.debug_id, self.endpoint,
+                   packet_buf.endpoint, packet_buf.size);
         Ok(())
     }
 
@@ -217,8 +235,9 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
     pub fn dispatch<F, R>(&mut self, _timestamp: u64, emit: &mut F) -> Result<R, Error>
             where F: FnMut(&IpRepr, &IpPayload) -> Result<R, Error> {
         let packet_buf = try!(self.tx_buffer.dequeue().map_err(|()| Error::Exhausted));
-        net_trace!("udp:{}:{}: sending {} octets",
-                   self.endpoint, packet_buf.endpoint, packet_buf.size);
+        net_trace!("[{}]{}:{}: sending {} octets",
+                   self.debug_id, self.endpoint,
+                   packet_buf.endpoint, packet_buf.size);
         let repr = UdpRepr {
             src_port: self.endpoint.port,
             dst_port: packet_buf.endpoint.port,