Sfoglia il codice sorgente

aya: tc: clean up netlink code a bit

Alessandro Decina 3 anni fa
parent
commit
9185f32f6f
1 ha cambiato i file con 42 aggiunte e 79 eliminazioni
  1. 42 79
      aya/src/sys/netlink.rs

+ 42 - 79
aya/src/sys/netlink.rs

@@ -31,30 +31,22 @@ pub(crate) unsafe fn netlink_set_xdp_fd(
 ) -> Result<(), io::Error> {
     let sock = NetlinkSocket::open()?;
 
-    let seq = 1;
     // Safety: Request is POD so this is safe
     let mut req = mem::zeroed::<Request>();
 
+    let nlmsg_len = mem::size_of::<nlmsghdr>() + mem::size_of::<ifinfomsg>();
     req.header = nlmsghdr {
-        nlmsg_len: (mem::size_of::<nlmsghdr>() + mem::size_of::<ifinfomsg>()) as u32,
+        nlmsg_len: nlmsg_len as u32,
         nlmsg_flags: (NLM_F_REQUEST | NLM_F_ACK) as u16,
         nlmsg_type: RTM_SETLINK,
         nlmsg_pid: 0,
-        nlmsg_seq: seq,
+        nlmsg_seq: 1,
     };
     req.if_info.ifi_family = AF_UNSPEC as u8;
     req.if_info.ifi_index = if_index;
 
-    let attrs_buf = {
-        let attrs_addr = align_to(
-            &req as *const _ as usize + req.header.nlmsg_len as usize,
-            NLMSG_ALIGNTO as usize,
-        );
-        let attrs_end = &req as *const _ as usize + mem::size_of::<Request>();
-        slice::from_raw_parts_mut(attrs_addr as *mut u8, attrs_end - attrs_addr)
-    };
-
     // write the attrs
+    let attrs_buf = request_attributes(&mut req, nlmsg_len);
     let mut attrs = NestedAttrs::new(attrs_buf, IFLA_XDP);
     attrs.write_attr(IFLA_XDP_FD as u16, fd)?;
 
@@ -69,15 +61,7 @@ pub(crate) unsafe fn netlink_set_xdp_fd(
     let nla_len = attrs.finish()?;
     req.header.nlmsg_len += align_to(nla_len, NLA_ALIGNTO as usize) as u32;
 
-    if send(
-        sock.sock,
-        &req as *const _ as *const _,
-        req.header.nlmsg_len as usize,
-        0,
-    ) < 0
-    {
-        return Err(io::Error::last_os_error());
-    }
+    sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?;
 
     sock.recv()?;
 
@@ -87,16 +71,15 @@ pub(crate) unsafe fn netlink_set_xdp_fd(
 pub(crate) unsafe fn netlink_qdisc_add_clsact(if_index: i32) -> Result<(), io::Error> {
     let sock = NetlinkSocket::open()?;
 
-    let seq = 1;
-    let mut req = mem::zeroed::<QdiscRequest>();
+    let mut req = mem::zeroed::<TcRequest>();
 
-    // prepare the TC rquest
+    let nlmsg_len = mem::size_of::<nlmsghdr>() + mem::size_of::<tcmsg>();
     req.header = nlmsghdr {
-        nlmsg_len: (mem::size_of::<nlmsghdr>() + mem::size_of::<tcmsg>()) as u32,
+        nlmsg_len: nlmsg_len as u32,
         nlmsg_flags: (NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE) as u16,
         nlmsg_type: RTM_NEWQDISC,
         nlmsg_pid: 0,
-        nlmsg_seq: seq,
+        nlmsg_seq: 1,
     };
     req.tc_info.tcm_family = AF_UNSPEC as u8;
     req.tc_info.tcm_ifindex = if_index;
@@ -105,26 +88,11 @@ pub(crate) unsafe fn netlink_qdisc_add_clsact(if_index: i32) -> Result<(), io::E
     req.tc_info.tcm_info = 0;
 
     // add the TCA_KIND attribute
-    let attrs_buf = {
-        let attrs_addr = align_to(
-            &req as *const _ as usize + req.header.nlmsg_len as usize,
-            NLMSG_ALIGNTO as usize,
-        );
-        let attrs_end = &req as *const _ as usize + mem::size_of::<QdiscRequest>();
-        slice::from_raw_parts_mut(attrs_addr as *mut u8, attrs_end - attrs_addr)
-    };
+    let attrs_buf = request_attributes(&mut req, nlmsg_len);
     let attr_len = write_attr_bytes(attrs_buf, 0, TCA_KIND as u16, b"clsact\0")?;
     req.header.nlmsg_len += align_to(attr_len as usize, NLA_ALIGNTO as usize) as u32;
 
-    if send(
-        sock.sock,
-        &req as *const _ as *const _,
-        req.header.nlmsg_len as usize,
-        0,
-    ) < 0
-    {
-        return Err(io::Error::last_os_error());
-    }
+    sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?;
     sock.recv()?;
 
     Ok(())
@@ -137,32 +105,24 @@ pub(crate) unsafe fn netlink_qdisc_attach(
     prog_name: &CStr,
 ) -> Result<u32, io::Error> {
     let sock = NetlinkSocket::open()?;
-    let seq = 1;
     let priority = 0;
-    let mut req = mem::zeroed::<QdiscRequest>();
+    let mut req = mem::zeroed::<TcRequest>();
 
+    let nlmsg_len = mem::size_of::<nlmsghdr>() + mem::size_of::<tcmsg>();
     req.header = nlmsghdr {
-        nlmsg_len: (mem::size_of::<nlmsghdr>() + mem::size_of::<tcmsg>()) as u32,
+        nlmsg_len: nlmsg_len as u32,
         nlmsg_flags: (NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE | NLM_F_ECHO) as u16,
         nlmsg_type: RTM_NEWTFILTER,
         nlmsg_pid: 0,
-        nlmsg_seq: seq,
+        nlmsg_seq: 1,
     };
     req.tc_info.tcm_family = AF_UNSPEC as u8;
     req.tc_info.tcm_handle = 0; // auto-assigned, if not provided
     req.tc_info.tcm_ifindex = if_index;
     req.tc_info.tcm_parent = attach_type.parent();
-
     req.tc_info.tcm_info = tc_handler_make(priority << 16, htons(ETH_P_ALL as u16) as u32);
 
-    let attrs_buf = {
-        let attrs_addr = align_to(
-            &req as *const _ as usize + req.header.nlmsg_len as usize,
-            NLMSG_ALIGNTO as usize,
-        );
-        let attrs_end = &req as *const _ as usize + mem::size_of::<QdiscRequest>();
-        slice::from_raw_parts_mut(attrs_addr as *mut u8, attrs_end - attrs_addr)
-    };
+    let attrs_buf = request_attributes(&mut req, nlmsg_len);
 
     // add TCA_KIND
     let kind_len = write_attr_bytes(attrs_buf, 0, TCA_KIND as u16, b"bpf\0")?;
@@ -176,16 +136,7 @@ pub(crate) unsafe fn netlink_qdisc_attach(
     let options_len = options.finish()?;
 
     req.header.nlmsg_len += align_to(kind_len + options_len as usize, NLA_ALIGNTO as usize) as u32;
-
-    if send(
-        sock.sock,
-        &req as *const _ as *const _,
-        req.header.nlmsg_len as usize,
-        0,
-    ) < 0
-    {
-        return Err(io::Error::last_os_error());
-    }
+    sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?;
 
     // find the RTM_NEWTFILTER reply and read the tcm_info field which we'll
     // need to detach
@@ -218,15 +169,14 @@ pub(crate) unsafe fn netlink_qdisc_detach(
     priority: u32,
 ) -> Result<(), io::Error> {
     let sock = NetlinkSocket::open()?;
-    let seq = 1;
-    let mut req = mem::zeroed::<QdiscRequest>();
+    let mut req = mem::zeroed::<TcRequest>();
 
     req.header = nlmsghdr {
         nlmsg_len: (mem::size_of::<nlmsghdr>() + mem::size_of::<tcmsg>()) as u32,
         nlmsg_flags: (NLM_F_REQUEST | NLM_F_ACK) as u16,
         nlmsg_type: RTM_DELTFILTER,
         nlmsg_pid: 0,
-        nlmsg_seq: seq,
+        nlmsg_seq: 1,
     };
 
     req.tc_info.tcm_family = AF_UNSPEC as u8;
@@ -235,15 +185,7 @@ pub(crate) unsafe fn netlink_qdisc_detach(
     req.tc_info.tcm_parent = attach_type.parent();
     req.tc_info.tcm_ifindex = if_index;
 
-    if send(
-        sock.sock,
-        &req as *const _ as *const _,
-        req.header.nlmsg_len as usize,
-        0,
-    ) < 0
-    {
-        return Err(io::Error::last_os_error());
-    }
+    sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?;
 
     sock.recv()?;
 
@@ -258,7 +200,7 @@ struct Request {
 }
 
 #[repr(C)]
-struct QdiscRequest {
+struct TcRequest {
     header: nlmsghdr,
     tc_info: tcmsg,
     attrs: [u8; 64],
@@ -305,6 +247,13 @@ impl NetlinkSocket {
         })
     }
 
+    fn send(&self, msg: &[u8]) -> Result<(), io::Error> {
+        if unsafe { send(self.sock, msg.as_ptr() as *const _, msg.len(), 0) } < 0 {
+            return Err(io::Error::last_os_error());
+        }
+        Ok(())
+    }
+
     fn recv(&self) -> Result<Vec<NetlinkMessage>, io::Error> {
         let mut buf = [0u8; 4096];
         let mut messages = Vec::new();
@@ -487,6 +436,20 @@ fn write_bytes(buf: &mut [u8], offset: usize, value: &[u8]) -> Result<usize, io:
     Ok(value.len())
 }
 
+
+unsafe fn request_attributes<T>(req: &mut T, msg_len: usize) -> &mut [u8] {
+    let attrs_addr = align_to(
+        req as *const _ as usize + msg_len as usize,
+        NLMSG_ALIGNTO as usize,
+    );
+    let attrs_end = req as *const _ as usize + mem::size_of::<T>();
+    slice::from_raw_parts_mut(attrs_addr as *mut u8, attrs_end - attrs_addr)
+}
+
+fn bytes_of<T>(val: &T) -> &[u8] {
+    let size = mem::size_of::<T>();
+    unsafe { slice::from_raw_parts(slice::from_ref(val).as_ptr().cast(), size) }
+}
 #[cfg(test)]
 mod tests {
     use super::*;