|
@@ -70,8 +70,11 @@ impl<'a> SocketBuffer<'a> {
|
|
|
|
|
|
fn clamp_reader(&self, offset: usize, mut size: usize) -> (usize, usize) {
|
|
|
let read_at = (self.read_at + offset) % self.storage.len();
|
|
|
+ // We can't read past the end of the queued data.
|
|
|
+ if offset > self.length { return (read_at, 0) }
|
|
|
// We can't dequeue more than was queued.
|
|
|
- if size > self.length { size = self.length }
|
|
|
+ let clamped_length = self.length - offset;
|
|
|
+ if size > clamped_length { size = clamped_length }
|
|
|
// We can't contiguously dequeue past the end of the storage.
|
|
|
let until_end = self.storage.len() - read_at;
|
|
|
if size > until_end { size = until_end }
|
|
@@ -79,7 +82,6 @@ impl<'a> SocketBuffer<'a> {
|
|
|
(read_at, size)
|
|
|
}
|
|
|
|
|
|
- #[allow(dead_code)] // only used in tests
|
|
|
fn dequeue(&mut self, size: usize) -> &[u8] {
|
|
|
let (read_at, size) = self.clamp_reader(0, size);
|
|
|
self.read_at = (self.read_at + size) % self.storage.len();
|
|
@@ -88,13 +90,14 @@ impl<'a> SocketBuffer<'a> {
|
|
|
}
|
|
|
|
|
|
fn peek(&self, offset: usize, size: usize) -> &[u8] {
|
|
|
- if offset > self.length { panic!("peeking {} octets past free space", offset) }
|
|
|
let (read_at, size) = self.clamp_reader(offset, size);
|
|
|
&self.storage[read_at..read_at + size]
|
|
|
}
|
|
|
|
|
|
fn advance(&mut self, size: usize) {
|
|
|
- if size > self.length { panic!("advancing {} octets into free space", size) }
|
|
|
+ if size > self.length {
|
|
|
+ panic!("advancing {} octets into free space", size - self.length)
|
|
|
+ }
|
|
|
self.read_at = (self.read_at + size) % self.storage.len();
|
|
|
self.length -= size;
|
|
|
}
|
|
@@ -327,10 +330,12 @@ impl<'a> TcpSocket<'a> {
|
|
|
pub fn send(&mut self, size: usize) -> Result<&mut [u8], ()> {
|
|
|
if !self.can_send() { return Err(()) }
|
|
|
|
|
|
+ let old_length = self.tx_buffer.len();
|
|
|
let buffer = self.tx_buffer.enqueue(size);
|
|
|
if buffer.len() > 0 {
|
|
|
- net_trace!("tcp:{}:{}: tx buffer: enqueueing {} octets",
|
|
|
- self.local_endpoint, self.remote_endpoint, buffer.len());
|
|
|
+ net_trace!("tcp:{}:{}: tx buffer: enqueueing {} octets (now {})",
|
|
|
+ self.local_endpoint, self.remote_endpoint,
|
|
|
+ buffer.len(), old_length + buffer.len());
|
|
|
}
|
|
|
Ok(buffer)
|
|
|
}
|
|
@@ -358,11 +363,13 @@ impl<'a> TcpSocket<'a> {
|
|
|
// but until the connection is fully open we refuse to dequeue any data.
|
|
|
if !self.can_recv() { return Err(()) }
|
|
|
|
|
|
+ let old_length = self.rx_buffer.len();
|
|
|
let buffer = self.rx_buffer.dequeue(size);
|
|
|
self.remote_seq_no += buffer.len();
|
|
|
if buffer.len() > 0 {
|
|
|
- net_trace!("tcp:{}:{}: rx buffer: dequeueing {} octets",
|
|
|
- self.local_endpoint, self.remote_endpoint, buffer.len());
|
|
|
+ net_trace!("tcp:{}:{}: rx buffer: dequeueing {} octets (now {})",
|
|
|
+ self.local_endpoint, self.remote_endpoint,
|
|
|
+ buffer.len(), old_length - buffer.len());
|
|
|
}
|
|
|
Ok(buffer)
|
|
|
}
|
|
@@ -434,7 +441,7 @@ impl<'a> TcpSocket<'a> {
|
|
|
return Err(Error::Malformed)
|
|
|
}
|
|
|
(State::Listen, TcpRepr { ack_number: None, .. }) => (),
|
|
|
- // A reset received in response to initial SYN is acceptable if it acknowledges
|
|
|
+ // An RST received in response to initial SYN is acceptable if it acknowledges
|
|
|
// the initial SYN.
|
|
|
(State::SynSent, TcpRepr { control: TcpControl::Rst, ack_number: None, .. }) => {
|
|
|
net_trace!("tcp:{}:{}: unacceptable RST (expecting RST|ACK) \
|
|
@@ -572,18 +579,19 @@ impl<'a> TcpSocket<'a> {
|
|
|
if let Some(ack_number) = repr.ack_number {
|
|
|
let ack_length = ack_number - self.local_seq_no;
|
|
|
if ack_length > 0 {
|
|
|
- net_trace!("tcp:{}:{}: tx buffer: dequeueing {} octets",
|
|
|
+ net_trace!("tcp:{}:{}: tx buffer: dequeueing {} octets (now {})",
|
|
|
self.local_endpoint, self.remote_endpoint,
|
|
|
- ack_length);
|
|
|
+ ack_length, self.tx_buffer.len() - ack_length);
|
|
|
}
|
|
|
- self.tx_buffer.advance(ack_length as usize);
|
|
|
+ self.tx_buffer.advance(ack_length);
|
|
|
self.local_seq_no = ack_number;
|
|
|
}
|
|
|
|
|
|
// Enqueue payload octets, which is guaranteed to be in order, unless we already did.
|
|
|
if repr.payload.len() > 0 {
|
|
|
- net_trace!("tcp:{}:{}: rx buffer: enqueueing {} octets",
|
|
|
- self.local_endpoint, self.remote_endpoint, repr.payload.len());
|
|
|
+ net_trace!("tcp:{}:{}: rx buffer: enqueueing {} octets (now {})",
|
|
|
+ self.local_endpoint, self.remote_endpoint,
|
|
|
+ repr.payload.len(), self.rx_buffer.len() + repr.payload.len());
|
|
|
self.rx_buffer.enqueue_slice(repr.payload)
|
|
|
}
|
|
|
|
|
@@ -637,16 +645,19 @@ impl<'a> TcpSocket<'a> {
|
|
|
|
|
|
if self.tx_buffer.len() > 0 && self.remote_win_len > 0 && may_send {
|
|
|
// We can send something, so let's do that.
|
|
|
- let offset = self.remote_last_seq - self.local_seq_no;
|
|
|
- let mut size = self.remote_win_len;
|
|
|
+ let mut size = self.tx_buffer.len();
|
|
|
+ // Clamp to remote window length.
|
|
|
+ if size > self.remote_win_len { size = self.remote_win_len }
|
|
|
// Clamp to MSS. Currently we only support the default MSS value.
|
|
|
if size > 536 { size = 536 }
|
|
|
// Extract data from the buffer. This may return less than what we want,
|
|
|
// in case it's not possible to extract a contiguous slice.
|
|
|
- let data = self.tx_buffer.peek(offset as usize, size);
|
|
|
+ let offset = self.remote_last_seq - self.local_seq_no;
|
|
|
+ let data = self.tx_buffer.peek(offset, size);
|
|
|
+ assert!(data.len() > 0);
|
|
|
// Send the extracted data.
|
|
|
- net_trace!("tcp:{}:{}: tx buffer: peeking at {} octets",
|
|
|
- self.local_endpoint, self.remote_endpoint, data.len());
|
|
|
+ net_trace!("tcp:{}:{}: tx buffer: peeking at {} octets (from {})",
|
|
|
+ self.local_endpoint, self.remote_endpoint, data.len(), offset);
|
|
|
repr.payload = data;
|
|
|
// Speculatively shrink the remote window. This will get updated the next
|
|
|
// time we receive a packet.
|
|
@@ -715,6 +726,14 @@ mod test {
|
|
|
buffer.enqueue_slice(&b"bazhoge"[..]); // zhobarba
|
|
|
}
|
|
|
|
|
|
+ #[test]
|
|
|
+ fn test_buffer_peek() {
|
|
|
+ let mut buffer = SocketBuffer::new(vec![0; 8]); // ........
|
|
|
+ buffer.enqueue_slice(&b"foobar"[..]); // foobar..
|
|
|
+ assert_eq!(buffer.peek(0, 8), &b"foobar"[..]);
|
|
|
+ assert_eq!(buffer.peek(3, 8), &b"bar"[..]);
|
|
|
+ }
|
|
|
+
|
|
|
const LOCAL_IP: IpAddress = IpAddress::v4(10, 0, 0, 1);
|
|
|
const REMOTE_IP: IpAddress = IpAddress::v4(10, 0, 0, 2);
|
|
|
const LOCAL_PORT: u16 = 80;
|
|
@@ -1012,6 +1031,20 @@ mod test {
|
|
|
assert_eq!(s.tx_buffer.len(), 0);
|
|
|
}
|
|
|
|
|
|
+ #[test]
|
|
|
+ fn test_established_send_buf_gt_win() {
|
|
|
+ let mut s = socket_established();
|
|
|
+ s.remote_win_len = 16;
|
|
|
+ // First roundtrip after establishing.
|
|
|
+ s.tx_buffer.enqueue_slice(&[0; 32][..]);
|
|
|
+ recv!(s, [TcpRepr {
|
|
|
+ seq_number: LOCAL_SEQ + 1,
|
|
|
+ ack_number: Some(REMOTE_SEQ + 1),
|
|
|
+ payload: &[0; 16][..],
|
|
|
+ ..RECV_TEMPL
|
|
|
+ }]);
|
|
|
+ }
|
|
|
+
|
|
|
#[test]
|
|
|
fn test_established_no_ack() {
|
|
|
let mut s = socket_established();
|