|
@@ -25,6 +25,11 @@ impl<'a> SocketBuffer<'a> {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ fn clear(&mut self) {
|
|
|
+ self.read_at = 0;
|
|
|
+ self.length = 0;
|
|
|
+ }
|
|
|
+
|
|
|
fn capacity(&self) -> usize {
|
|
|
self.storage.len()
|
|
|
}
|
|
@@ -253,6 +258,8 @@ pub struct TcpSocket<'a> {
|
|
|
debug_id: usize
|
|
|
}
|
|
|
|
|
|
+const DEFAULT_MSS: usize = 536;
|
|
|
+
|
|
|
impl<'a> TcpSocket<'a> {
|
|
|
/// Create a socket using the given buffers.
|
|
|
pub fn new<T>(rx_buffer: T, tx_buffer: T) -> Socket<'a, 'static>
|
|
@@ -273,7 +280,7 @@ impl<'a> TcpSocket<'a> {
|
|
|
remote_last_seq: TcpSeqNumber(0),
|
|
|
remote_last_ack: TcpSeqNumber(0),
|
|
|
remote_win_len: 0,
|
|
|
- remote_mss: 536,
|
|
|
+ remote_mss: DEFAULT_MSS,
|
|
|
retransmit: Retransmit::new(),
|
|
|
tx_buffer: tx_buffer.into(),
|
|
|
rx_buffer: rx_buffer.into(),
|
|
@@ -311,20 +318,77 @@ impl<'a> TcpSocket<'a> {
|
|
|
self.state
|
|
|
}
|
|
|
|
|
|
+ fn reset(&mut self) {
|
|
|
+ self.listen_address = IpAddress::default();
|
|
|
+ self.local_endpoint = IpEndpoint::default();
|
|
|
+ self.remote_endpoint = IpEndpoint::default();
|
|
|
+ self.local_seq_no = TcpSeqNumber(0);
|
|
|
+ self.remote_seq_no = TcpSeqNumber(0);
|
|
|
+ self.remote_last_seq = TcpSeqNumber(0);
|
|
|
+ self.remote_last_ack = TcpSeqNumber(0);
|
|
|
+ self.remote_win_len = 0;
|
|
|
+ self.remote_win_len = 0;
|
|
|
+ self.remote_mss = DEFAULT_MSS;
|
|
|
+ self.retransmit.reset();
|
|
|
+ self.tx_buffer.clear();
|
|
|
+ self.rx_buffer.clear();
|
|
|
+ }
|
|
|
+
|
|
|
/// Start listening on the given endpoint.
|
|
|
///
|
|
|
/// This function returns an error if the socket was open; see [is_open](#method.is_open).
|
|
|
- pub fn listen<T: Into<IpEndpoint>>(&mut self, endpoint: T) -> Result<(), ()> {
|
|
|
+ /// It also returns an error if the specified port is zero.
|
|
|
+ pub fn listen<T>(&mut self, local_endpoint: T) -> Result<(), ()>
|
|
|
+ where T: Into<IpEndpoint> {
|
|
|
+ let local_endpoint = local_endpoint.into();
|
|
|
+
|
|
|
if self.is_open() { return Err(()) }
|
|
|
+ if local_endpoint.port == 0 { return Err(()) }
|
|
|
|
|
|
- let endpoint = endpoint.into();
|
|
|
- self.listen_address = endpoint.addr;
|
|
|
- self.local_endpoint = endpoint;
|
|
|
+ self.reset();
|
|
|
+ self.listen_address = local_endpoint.addr;
|
|
|
+ self.local_endpoint = local_endpoint;
|
|
|
self.remote_endpoint = IpEndpoint::default();
|
|
|
self.set_state(State::Listen);
|
|
|
Ok(())
|
|
|
}
|
|
|
|
|
|
+ /// Connect to a given endpoint.
|
|
|
+ ///
|
|
|
+ /// The local port must be provided explicitly. Assuming `fn get_ephemeral_port() -> u16`
|
|
|
+ /// allocates a port from the 49152 to 65535 range, a connection may be established as follows:
|
|
|
+ ///
|
|
|
+ /// ```rust,ignore
|
|
|
+ /// socket.connect((IpAddress::v4(10, 0, 0, 1), 80), get_ephemeral_port())
|
|
|
+ /// ```
|
|
|
+ ///
|
|
|
+ /// The local address may optionally be provided.
|
|
|
+ ///
|
|
|
+ /// This function returns an error if the socket was open; see [is_open](#method.is_open).
|
|
|
+ /// It also returns an error if the local or remote port is zero, or if
|
|
|
+ /// the local or remote address is unspecified.
|
|
|
+ pub fn connect<T, U>(&mut self, remote_endpoint: T, local_endpoint: U) -> Result<(), ()>
|
|
|
+ where T: Into<IpEndpoint>, U: Into<IpEndpoint> {
|
|
|
+ let remote_endpoint = remote_endpoint.into();
|
|
|
+ let local_endpoint = local_endpoint.into();
|
|
|
+
|
|
|
+ if self.is_open() { return Err(()) }
|
|
|
+ if remote_endpoint.port == 0 { return Err(()) }
|
|
|
+ if remote_endpoint.addr.is_unspecified() { return Err(()) }
|
|
|
+ if local_endpoint.port == 0 { return Err(()) }
|
|
|
+ if local_endpoint.addr.is_unspecified() { return Err(()) }
|
|
|
+
|
|
|
+ // Carry over the local sequence number.
|
|
|
+ let local_seq_no = self.local_seq_no;
|
|
|
+
|
|
|
+ self.reset();
|
|
|
+ self.local_endpoint = local_endpoint;
|
|
|
+ self.remote_endpoint = remote_endpoint;
|
|
|
+ self.local_seq_no = local_seq_no;
|
|
|
+ self.set_state(State::SynSent);
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+
|
|
|
/// Close the transmit half of the full-duplex connection.
|
|
|
///
|
|
|
/// Note that there is no corresponding function for the receive half of the full-duplex
|
|
@@ -715,6 +779,23 @@ impl<'a> TcpSocket<'a> {
|
|
|
self.retransmit.reset();
|
|
|
}
|
|
|
|
|
|
+ // SYN|ACK packets in the SYN-SENT state change it to ESTABLISHED.
|
|
|
+ (State::SynSent, TcpRepr {
|
|
|
+ control: TcpControl::Syn, seq_number, ack_number: Some(_),
|
|
|
+ max_seg_size, ..
|
|
|
+ }) => {
|
|
|
+ net_trace!("[{}]{}:{}: received SYN|ACK",
|
|
|
+ self.debug_id, self.local_endpoint, self.remote_endpoint);
|
|
|
+ self.remote_last_seq = self.local_seq_no + 1;
|
|
|
+ self.remote_seq_no = seq_number + 1;
|
|
|
+ self.remote_last_ack = seq_number;
|
|
|
+ if let Some(max_seg_size) = max_seg_size {
|
|
|
+ self.remote_mss = max_seg_size as usize;
|
|
|
+ }
|
|
|
+ self.set_state(State::Established);
|
|
|
+ self.retransmit.reset();
|
|
|
+ }
|
|
|
+
|
|
|
// ACK packets in ESTABLISHED state reset the retransmit timer.
|
|
|
(State::Established, TcpRepr { control: TcpControl::None, .. }) => {
|
|
|
self.retransmit.reset()
|
|
@@ -962,8 +1043,10 @@ impl<'a> TcpSocket<'a> {
|
|
|
self.retransmit.delay);
|
|
|
}
|
|
|
|
|
|
- repr.ack_number = Some(ack_number);
|
|
|
- self.remote_last_ack = ack_number;
|
|
|
+ if self.state != State::SynSent {
|
|
|
+ repr.ack_number = Some(ack_number);
|
|
|
+ self.remote_last_ack = ack_number;
|
|
|
+ }
|
|
|
|
|
|
// Remember the header length before enabling the MSS option, since that option
|
|
|
// only affects SYN packets.
|
|
@@ -1249,6 +1332,12 @@ mod test {
|
|
|
sanity!(s, socket_listen());
|
|
|
}
|
|
|
|
|
|
+ #[test]
|
|
|
+ fn test_listen_validation() {
|
|
|
+ let mut s = socket();
|
|
|
+ assert_eq!(s.listen(0), Err(()));
|
|
|
+ }
|
|
|
+
|
|
|
#[test]
|
|
|
fn test_listen_syn() {
|
|
|
let mut s = socket_listen();
|
|
@@ -1358,6 +1447,54 @@ mod test {
|
|
|
s
|
|
|
}
|
|
|
|
|
|
+ #[test]
|
|
|
+ fn test_connect_validation() {
|
|
|
+ let mut s = socket();
|
|
|
+ assert_eq!(s.connect(REMOTE_END, (IpAddress::Unspecified, 80)), Err(()));
|
|
|
+ assert_eq!(s.connect(REMOTE_END, (IpAddress::v4(0, 0, 0, 0), 80)), Err(()));
|
|
|
+ assert_eq!(s.connect(REMOTE_END, (IpAddress::v4(10, 0, 0, 0), 0)), Err(()));
|
|
|
+ assert_eq!(s.connect((IpAddress::Unspecified, 80), LOCAL_END), Err(()));
|
|
|
+ assert_eq!(s.connect((IpAddress::v4(0, 0, 0, 0), 80), LOCAL_END), Err(()));
|
|
|
+ assert_eq!(s.connect((IpAddress::v4(10, 0, 0, 0), 0), LOCAL_END), Err(()));
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn test_syn_sent_sanity() {
|
|
|
+ let mut s = socket();
|
|
|
+ s.local_seq_no = LOCAL_SEQ;
|
|
|
+ s.connect(REMOTE_END, LOCAL_END).unwrap();
|
|
|
+ sanity!(s, socket_syn_sent());
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn test_syn_sent_syn_ack() {
|
|
|
+ let mut s = socket_syn_sent();
|
|
|
+ recv!(s, [TcpRepr {
|
|
|
+ control: TcpControl::Syn,
|
|
|
+ seq_number: LOCAL_SEQ,
|
|
|
+ ack_number: None,
|
|
|
+ max_seg_size: Some(1480),
|
|
|
+ ..RECV_TEMPL
|
|
|
+ }]);
|
|
|
+ send!(s, TcpRepr {
|
|
|
+ control: TcpControl::Syn,
|
|
|
+ seq_number: REMOTE_SEQ,
|
|
|
+ ack_number: Some(LOCAL_SEQ + 1),
|
|
|
+ max_seg_size: Some(1400),
|
|
|
+ ..SEND_TEMPL
|
|
|
+ });
|
|
|
+ recv!(s, [TcpRepr {
|
|
|
+ seq_number: LOCAL_SEQ + 1,
|
|
|
+ ack_number: Some(REMOTE_SEQ + 1),
|
|
|
+ ..RECV_TEMPL
|
|
|
+ }]);
|
|
|
+ assert_eq!(s.state, State::Established);
|
|
|
+ sanity!(s, TcpSocket {
|
|
|
+ retransmit: Retransmit { resend_at: 100, delay: 100 },
|
|
|
+ ..socket_established()
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
#[test]
|
|
|
fn test_syn_sent_rst() {
|
|
|
let mut s = socket_syn_sent();
|