|
@@ -166,9 +166,15 @@ impl<'a, 'b> RawSocket<'a, 'b> {
|
|
|
Ok(length)
|
|
|
}
|
|
|
|
|
|
+ pub(crate) fn accepts(&self, ip_repr: &IpRepr) -> bool {
|
|
|
+ if ip_repr.version() != self.ip_version { return false }
|
|
|
+ if ip_repr.protocol() != self.ip_protocol { return false }
|
|
|
+
|
|
|
+ true
|
|
|
+ }
|
|
|
+
|
|
|
pub(crate) fn process(&mut self, ip_repr: &IpRepr, payload: &[u8]) -> Result<()> {
|
|
|
- if ip_repr.version() != self.ip_version { return Err(Error::Rejected) }
|
|
|
- if ip_repr.protocol() != self.ip_protocol { return Err(Error::Rejected) }
|
|
|
+ debug_assert!(self.accepts(ip_repr));
|
|
|
|
|
|
let header_len = ip_repr.buffer_len();
|
|
|
let total_len = header_len + payload.len();
|
|
@@ -246,17 +252,18 @@ mod test {
|
|
|
fn socket(rx_buffer: SocketBuffer<'static, 'static>,
|
|
|
tx_buffer: SocketBuffer<'static, 'static>)
|
|
|
-> RawSocket<'static, 'static> {
|
|
|
- match RawSocket::new(IpVersion::Ipv4, IpProtocol::Unknown(63),
|
|
|
+ match RawSocket::new(IpVersion::Ipv4, IpProtocol::Unknown(IP_PROTO),
|
|
|
rx_buffer, tx_buffer) {
|
|
|
Socket::Raw(socket) => socket,
|
|
|
_ => unreachable!()
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ const IP_PROTO: u8 = 63;
|
|
|
const HEADER_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr {
|
|
|
src_addr: Ipv4Address([10, 0, 0, 1]),
|
|
|
dst_addr: Ipv4Address([10, 0, 0, 2]),
|
|
|
- protocol: IpProtocol::Unknown(63),
|
|
|
+ protocol: IpProtocol::Unknown(IP_PROTO),
|
|
|
payload_len: 4
|
|
|
});
|
|
|
const PACKET_BYTES: [u8; 24] = [
|
|
@@ -332,10 +339,12 @@ mod test {
|
|
|
Ipv4Packet::new(&mut cksumd_packet).fill_checksum();
|
|
|
|
|
|
assert_eq!(socket.recv(), Err(Error::Exhausted));
|
|
|
+ assert!(socket.accepts(&HEADER_REPR));
|
|
|
assert_eq!(socket.process(&HEADER_REPR, &PACKET_PAYLOAD),
|
|
|
Ok(()));
|
|
|
assert!(socket.can_recv());
|
|
|
|
|
|
+ assert!(socket.accepts(&HEADER_REPR));
|
|
|
assert_eq!(socket.process(&HEADER_REPR, &PACKET_PAYLOAD),
|
|
|
Err(Error::Exhausted));
|
|
|
assert_eq!(socket.recv(), Ok(&cksumd_packet[..]));
|
|
@@ -346,6 +355,7 @@ mod test {
|
|
|
fn test_recv_truncated_slice() {
|
|
|
let mut socket = socket(buffer(1), buffer(0));
|
|
|
|
|
|
+ assert!(socket.accepts(&HEADER_REPR));
|
|
|
assert_eq!(socket.process(&HEADER_REPR, &PACKET_PAYLOAD),
|
|
|
Ok(()));
|
|
|
|
|
@@ -361,7 +371,19 @@ mod test {
|
|
|
let mut buffer = vec![0; 128];
|
|
|
buffer[..PACKET_BYTES.len()].copy_from_slice(&PACKET_BYTES[..]);
|
|
|
|
|
|
+ assert!(socket.accepts(&HEADER_REPR));
|
|
|
assert_eq!(socket.process(&HEADER_REPR, &buffer),
|
|
|
Err(Error::Truncated));
|
|
|
}
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn test_doesnt_accept_wrong_proto() {
|
|
|
+ let socket = match RawSocket::new(IpVersion::Ipv4,
|
|
|
+ IpProtocol::Unknown(IP_PROTO+1),
|
|
|
+ buffer(1), buffer(1)) {
|
|
|
+ Socket::Raw(socket) => socket,
|
|
|
+ _ => unreachable!()
|
|
|
+ };
|
|
|
+ assert!(!socket.accepts(&HEADER_REPR));
|
|
|
+ }
|
|
|
}
|