|
@@ -666,25 +666,31 @@ impl<'a> TcpSocket<'a> {
|
|
|
(ip_reply_repr, reply_repr)
|
|
|
}
|
|
|
|
|
|
- pub(crate) fn process(&mut self, timestamp: u64, ip_repr: &IpRepr, repr: &TcpRepr) ->
|
|
|
- Result<Option<(IpRepr, TcpRepr<'static>)>> {
|
|
|
- if self.state == State::Closed { return Err(Error::Rejected) }
|
|
|
+ pub(crate) fn accepts(&self, ip_repr: &IpRepr, repr: &TcpRepr) -> bool {
|
|
|
+ if self.state == State::Closed { return false }
|
|
|
|
|
|
// If we're still listening for SYNs and the packet has an ACK, it cannot
|
|
|
// be destined to this socket, but another one may well listen on the same
|
|
|
// local endpoint.
|
|
|
- if self.state == State::Listen && repr.ack_number.is_some() { return Err(Error::Rejected) }
|
|
|
+ if self.state == State::Listen && repr.ack_number.is_some() { return false }
|
|
|
|
|
|
// Reject packets with a wrong destination.
|
|
|
- if self.local_endpoint.port != repr.dst_port { return Err(Error::Rejected) }
|
|
|
+ if self.local_endpoint.port != repr.dst_port { return false }
|
|
|
if !self.local_endpoint.addr.is_unspecified() &&
|
|
|
- self.local_endpoint.addr != ip_repr.dst_addr() { return Err(Error::Rejected) }
|
|
|
+ self.local_endpoint.addr != ip_repr.dst_addr() { return false }
|
|
|
|
|
|
// Reject packets from a source to which we aren't connected.
|
|
|
if self.remote_endpoint.port != 0 &&
|
|
|
- self.remote_endpoint.port != repr.src_port { return Err(Error::Rejected) }
|
|
|
+ self.remote_endpoint.port != repr.src_port { return false }
|
|
|
if !self.remote_endpoint.addr.is_unspecified() &&
|
|
|
- self.remote_endpoint.addr != ip_repr.src_addr() { return Err(Error::Rejected) }
|
|
|
+ self.remote_endpoint.addr != ip_repr.src_addr() { return false }
|
|
|
+
|
|
|
+ true
|
|
|
+ }
|
|
|
+
|
|
|
+ pub(crate) fn process(&mut self, timestamp: u64, ip_repr: &IpRepr, repr: &TcpRepr) ->
|
|
|
+ Result<Option<(IpRepr, TcpRepr<'static>)>> {
|
|
|
+ debug_assert!(self.accepts(ip_repr, repr));
|
|
|
|
|
|
// Consider how much the sequence number space differs from the transmit buffer space.
|
|
|
let (sent_syn, sent_fin) = match self.state {
|
|
@@ -1241,6 +1247,7 @@ mod test {
|
|
|
|
|
|
const LOCAL_IP: IpAddress = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 1]));
|
|
|
const REMOTE_IP: IpAddress = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 2]));
|
|
|
+ const THIRD_PARTY_IP: IpAddress = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 3]));
|
|
|
const LOCAL_PORT: u16 = 80;
|
|
|
const REMOTE_PORT: u16 = 49500;
|
|
|
const LOCAL_END: IpEndpoint = IpEndpoint { addr: LOCAL_IP, port: LOCAL_PORT };
|
|
@@ -1272,6 +1279,11 @@ mod test {
|
|
|
protocol: IpProtocol::Tcp,
|
|
|
payload_len: repr.buffer_len()
|
|
|
};
|
|
|
+
|
|
|
+ if !socket.accepts(&ip_repr, repr) {
|
|
|
+ return Err(Error::Rejected);
|
|
|
+ }
|
|
|
+
|
|
|
match socket.process(timestamp, &ip_repr, repr) {
|
|
|
Ok(Some((_ip_repr, repr))) => {
|
|
|
trace!("recv: {}", repr);
|
|
@@ -2930,4 +2942,66 @@ mod test {
|
|
|
..RECV_TEMPL
|
|
|
}]);
|
|
|
}
|
|
|
+
|
|
|
+ // =========================================================================================//
|
|
|
+ // Tests for packet filtering
|
|
|
+ // =========================================================================================//
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn test_doesnt_accept_wrong_port() {
|
|
|
+ let mut s = socket_established();
|
|
|
+ s.rx_buffer = SocketBuffer::new(vec![0; 6]);
|
|
|
+
|
|
|
+ send!(s, TcpRepr {
|
|
|
+ seq_number: REMOTE_SEQ + 1,
|
|
|
+ ack_number: Some(LOCAL_SEQ + 1),
|
|
|
+ payload: &b"abcdef"[..],
|
|
|
+ dst_port: LOCAL_PORT + 1,
|
|
|
+ ..SEND_TEMPL
|
|
|
+ }, Err(Error::Rejected));
|
|
|
+
|
|
|
+ send!(s, TcpRepr {
|
|
|
+ seq_number: REMOTE_SEQ + 1,
|
|
|
+ ack_number: Some(LOCAL_SEQ + 1),
|
|
|
+ payload: &b"abcdef"[..],
|
|
|
+ src_port: REMOTE_PORT + 1,
|
|
|
+ ..SEND_TEMPL
|
|
|
+ }, Err(Error::Rejected));
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn test_doesnt_accept_wrong_ip() {
|
|
|
+ let s = socket_established();
|
|
|
+
|
|
|
+ let tcp_repr = TcpRepr {
|
|
|
+ seq_number: REMOTE_SEQ + 1,
|
|
|
+ ack_number: Some(LOCAL_SEQ + 1),
|
|
|
+ payload: &b"abcdef"[..],
|
|
|
+ ..SEND_TEMPL
|
|
|
+ };
|
|
|
+
|
|
|
+ let ip_repr = IpRepr::Unspecified {
|
|
|
+ src_addr: REMOTE_IP,
|
|
|
+ dst_addr: LOCAL_IP,
|
|
|
+ protocol: IpProtocol::Tcp,
|
|
|
+ payload_len: tcp_repr.buffer_len()
|
|
|
+ };
|
|
|
+ assert!(s.accepts(&ip_repr, &tcp_repr));
|
|
|
+
|
|
|
+ let ip_repr_wrong_src = IpRepr::Unspecified {
|
|
|
+ src_addr: THIRD_PARTY_IP,
|
|
|
+ dst_addr: LOCAL_IP,
|
|
|
+ protocol: IpProtocol::Tcp,
|
|
|
+ payload_len: tcp_repr.buffer_len()
|
|
|
+ };
|
|
|
+ assert!(!s.accepts(&ip_repr_wrong_src, &tcp_repr));
|
|
|
+
|
|
|
+ let ip_repr_wrong_dst = IpRepr::Unspecified {
|
|
|
+ src_addr: REMOTE_IP,
|
|
|
+ dst_addr: THIRD_PARTY_IP,
|
|
|
+ protocol: IpProtocol::Tcp,
|
|
|
+ payload_len: tcp_repr.buffer_len()
|
|
|
+ };
|
|
|
+ assert!(!s.accepts(&ip_repr_wrong_dst, &tcp_repr));
|
|
|
+ }
|
|
|
}
|