Browse Source

netlink: use OwnedFd

Updates #612.
Tamir Duberstein 1 year ago
parent
commit
cee0265b52
2 changed files with 41 additions and 18 deletions
  1. 6 2
      aya/src/programs/xdp.rs
  2. 35 16
      aya/src/sys/netlink.rs

+ 6 - 2
aya/src/programs/xdp.rs

@@ -11,7 +11,7 @@ use std::{
     ffi::CString,
     hash::Hash,
     io,
-    os::fd::{AsFd as _, AsRawFd as _, RawFd},
+    os::fd::{AsFd as _, AsRawFd as _, BorrowedFd, RawFd},
 };
 use thiserror::Error;
 
@@ -201,6 +201,8 @@ impl Xdp {
             XdpLinkInner::NlLink(nl_link) => {
                 let if_index = nl_link.if_index;
                 let old_prog_fd = nl_link.prog_fd;
+                // SAFETY: TODO(https://github.com/aya-rs/aya/issues/612): make this safe by not holding `RawFd`s.
+                let old_prog_fd = unsafe { BorrowedFd::borrow_raw(old_prog_fd) };
                 let flags = nl_link.flags;
                 let replace_flags = flags | XdpFlags::REPLACE;
                 unsafe {
@@ -246,7 +248,9 @@ impl Link for NlLink {
         } else {
             self.flags.bits()
         };
-        let _ = unsafe { netlink_set_xdp_fd(self.if_index, None, Some(self.prog_fd), flags) };
+        // SAFETY: TODO(https://github.com/aya-rs/aya/issues/612): make this safe by not holding `RawFd`s.
+        let prog_fd = unsafe { BorrowedFd::borrow_raw(self.prog_fd) };
+        let _ = unsafe { netlink_set_xdp_fd(self.if_index, None, Some(prog_fd), flags) };
         Ok(())
     }
 }

+ 35 - 16
aya/src/sys/netlink.rs

@@ -2,13 +2,13 @@ use std::{
     collections::HashMap,
     ffi::CStr,
     io, mem,
-    os::fd::{AsRawFd as _, BorrowedFd, RawFd},
+    os::fd::{AsRawFd as _, BorrowedFd, FromRawFd as _, OwnedFd},
     ptr, slice,
 };
 use thiserror::Error;
 
 use libc::{
-    close, getsockname, nlattr, nlmsgerr, nlmsghdr, recv, send, setsockopt, sockaddr_nl, socket,
+    getsockname, nlattr, nlmsgerr, nlmsghdr, recv, send, setsockopt, sockaddr_nl, socket,
     AF_NETLINK, AF_UNSPEC, ETH_P_ALL, IFF_UP, IFLA_XDP, NETLINK_EXT_ACK, NETLINK_ROUTE,
     NLA_ALIGNTO, NLA_F_NESTED, NLA_TYPE_MASK, NLMSG_DONE, NLMSG_ERROR, NLM_F_ACK, NLM_F_CREATE,
     NLM_F_DUMP, NLM_F_ECHO, NLM_F_EXCL, NLM_F_MULTI, NLM_F_REQUEST, RTM_DELTFILTER, RTM_GETTFILTER,
@@ -32,7 +32,7 @@ const NLA_HDR_LEN: usize = align_to(mem::size_of::<nlattr>(), NLA_ALIGNTO as usi
 pub(crate) unsafe fn netlink_set_xdp_fd(
     if_index: i32,
     fd: Option<BorrowedFd<'_>>,
-    old_fd: Option<RawFd>,
+    old_fd: Option<BorrowedFd<'_>>,
     flags: u32,
 ) -> Result<(), io::Error> {
     let sock = NetlinkSocket::open()?;
@@ -64,7 +64,10 @@ pub(crate) unsafe fn netlink_set_xdp_fd(
     }
 
     if flags & XDP_FLAGS_REPLACE != 0 {
-        attrs.write_attr(IFLA_XDP_EXPECTED_FD as u16, old_fd.unwrap())?;
+        attrs.write_attr(
+            IFLA_XDP_EXPECTED_FD as u16,
+            old_fd.map(|fd| fd.as_raw_fd()).unwrap(),
+        )?;
     }
 
     let nla_len = attrs.finish()?;
@@ -290,7 +293,7 @@ struct TcRequest {
 }
 
 struct NetlinkSocket {
-    sock: RawFd,
+    sock: OwnedFd,
     _nl_pid: u32,
 }
 
@@ -301,12 +304,14 @@ impl NetlinkSocket {
         if sock < 0 {
             return Err(io::Error::last_os_error());
         }
+        // SAFETY: `socket` returns a file descriptor.
+        let sock = unsafe { OwnedFd::from_raw_fd(sock) };
 
         let enable = 1i32;
         // Safety: libc wrapper
         unsafe {
             setsockopt(
-                sock,
+                sock.as_raw_fd(),
                 SOL_NETLINK,
                 NETLINK_EXT_ACK,
                 &enable as *const _ as *const _,
@@ -319,7 +324,13 @@ impl NetlinkSocket {
         addr.nl_family = AF_NETLINK as u16;
         let mut addr_len = mem::size_of::<sockaddr_nl>() as u32;
         // Safety: libc wrapper
-        if unsafe { getsockname(sock, &mut addr as *mut _ as *mut _, &mut addr_len as *mut _) } < 0
+        if unsafe {
+            getsockname(
+                sock.as_raw_fd(),
+                &mut addr as *mut _ as *mut _,
+                &mut addr_len as *mut _,
+            )
+        } < 0
         {
             return Err(io::Error::last_os_error());
         }
@@ -331,7 +342,15 @@ impl NetlinkSocket {
     }
 
     fn send(&self, msg: &[u8]) -> Result<(), io::Error> {
-        if unsafe { send(self.sock, msg.as_ptr() as *const _, msg.len(), 0) } < 0 {
+        if unsafe {
+            send(
+                self.sock.as_raw_fd(),
+                msg.as_ptr() as *const _,
+                msg.len(),
+                0,
+            )
+        } < 0
+        {
             return Err(io::Error::last_os_error());
         }
         Ok(())
@@ -344,7 +363,14 @@ impl NetlinkSocket {
         'out: while multipart {
             multipart = false;
             // Safety: libc wrapper
-            let len = unsafe { recv(self.sock, buf.as_mut_ptr() as *mut _, buf.len(), 0) };
+            let len = unsafe {
+                recv(
+                    self.sock.as_raw_fd(),
+                    buf.as_mut_ptr() as *mut _,
+                    buf.len(),
+                    0,
+                )
+            };
             if len < 0 {
                 return Err(io::Error::last_os_error());
             }
@@ -430,13 +456,6 @@ impl NetlinkMessage {
     }
 }
 
-impl Drop for NetlinkSocket {
-    fn drop(&mut self) {
-        // Safety: libc wrapper
-        unsafe { close(self.sock) };
-    }
-}
-
 const fn align_to(v: usize, align: usize) -> usize {
     (v + (align - 1)) & !(align - 1)
 }