ソースを参照

maps: fail new() for high level wrappers if the underlying map hasn't been created

Alessandro Decina 4 年 前
コミット
ba992a2414
3 ファイル変更39 行追加64 行削除
  1. 27 53
      aya/src/maps/hash_map.rs
  2. 1 0
      aya/src/maps/perf_map/perf_map.rs
  3. 11 11
      aya/src/maps/program_array.rs

+ 27 - 53
aya/src/maps/hash_map.rs

@@ -23,25 +23,29 @@ pub struct HashMap<T: Deref<Target = Map>, K, V> {
 
 impl<T: Deref<Target = Map>, K: Pod, V: Pod> HashMap<T, K, V> {
     pub fn new(map: T) -> Result<HashMap<T, K, V>, MapError> {
-        let inner = map.deref();
-        let map_type = inner.obj.def.map_type;
+        let map_type = map.obj.def.map_type;
+
+        // validate the map definition
         if map_type != BPF_MAP_TYPE_HASH {
             return Err(MapError::InvalidMapType {
                 map_type: map_type as u32,
             })?;
         }
         let size = mem::size_of::<K>();
-        let expected = inner.obj.def.key_size as usize;
+        let expected = map.obj.def.key_size as usize;
         if size != expected {
             return Err(MapError::InvalidKeySize { size, expected });
         }
 
         let size = mem::size_of::<V>();
-        let expected = inner.obj.def.value_size as usize;
+        let expected = map.obj.def.value_size as usize;
         if size != expected {
             return Err(MapError::InvalidValueSize { size, expected });
         }
 
+        // make sure the map has been created
+        let _fd = map.fd_or_err()?;
+
         Ok(HashMap {
             inner: map,
             _k: PhantomData,
@@ -220,26 +224,35 @@ mod tests {
     }
 
     #[test]
-    fn test_try_from_ok() {
-        let map = Map {
+    fn test_new_not_created() {
+        let mut map = Map {
             obj: new_obj_map("TEST"),
             fd: None,
         };
-        assert!(HashMap::<_, u32, u32>::try_from(&map).is_ok())
+
+        assert!(matches!(
+            HashMap::<_, u32, u32>::new(&mut map),
+            Err(MapError::NotCreated { .. })
+        ));
     }
 
     #[test]
-    fn test_insert_not_created() {
+    fn test_new_ok() {
         let mut map = Map {
             obj: new_obj_map("TEST"),
-            fd: None,
+            fd: Some(42),
         };
-        let mut hm = HashMap::<_, u32, u32>::new(&mut map).unwrap();
 
-        assert!(matches!(
-            hm.insert(1, 42, 0),
-            Err(MapError::NotCreated { .. })
-        ));
+        assert!(HashMap::<_, u32, u32>::new(&mut map).is_ok());
+    }
+
+    #[test]
+    fn test_try_from_ok() {
+        let map = Map {
+            obj: new_obj_map("TEST"),
+            fd: Some(42),
+        };
+        assert!(HashMap::<_, u32, u32>::try_from(&map).is_ok())
     }
 
     #[test]
@@ -277,17 +290,6 @@ mod tests {
         assert!(hm.insert(1, 42, 0).is_ok());
     }
 
-    #[test]
-    fn test_remove_not_created() {
-        let mut map = Map {
-            obj: new_obj_map("TEST"),
-            fd: None,
-        };
-        let mut hm = HashMap::<_, u32, u32>::new(&mut map).unwrap();
-
-        assert!(matches!(hm.remove(&1), Err(MapError::NotCreated { .. })));
-    }
-
     #[test]
     fn test_remove_syscall_error() {
         override_syscall(|_| sys_error(EFAULT));
@@ -323,20 +325,6 @@ mod tests {
         assert!(hm.remove(&1).is_ok());
     }
 
-    #[test]
-    fn test_get_not_created() {
-        let map = Map {
-            obj: new_obj_map("TEST"),
-            fd: None,
-        };
-        let hm = HashMap::<_, u32, u32>::new(&map).unwrap();
-
-        assert!(matches!(
-            unsafe { hm.get(&1, 0) },
-            Err(MapError::NotCreated { .. })
-        ));
-    }
-
     #[test]
     fn test_get_syscall_error() {
         override_syscall(|_| sys_error(EFAULT));
@@ -370,20 +358,6 @@ mod tests {
         assert!(matches!(unsafe { hm.get(&1, 0) }, Ok(None)));
     }
 
-    #[test]
-    fn test_pop_not_created() {
-        let mut map = Map {
-            obj: new_obj_map("TEST"),
-            fd: None,
-        };
-        let mut hm = HashMap::<_, u32, u32>::new(&mut map).unwrap();
-
-        assert!(matches!(
-            unsafe { hm.pop(&1) },
-            Err(MapError::NotCreated { .. })
-        ));
-    }
-
     #[test]
     fn test_pop_syscall_error() {
         override_syscall(|_| sys_error(EFAULT));

+ 1 - 0
aya/src/maps/perf_map/perf_map.rs

@@ -73,6 +73,7 @@ impl<T: DerefMut<Target = Map>> PerfMap<T> {
                 map_type: map_type as u32,
             })?;
         }
+        let _fd = map.fd_or_err()?;
 
         Ok(PerfMap {
             map: Arc::new(map),

+ 11 - 11
aya/src/maps/program_array.rs

@@ -21,30 +21,30 @@ pub struct ProgramArray<T: Deref<Target = Map>> {
 
 impl<T: Deref<Target = Map>> ProgramArray<T> {
     pub fn new(map: T) -> Result<ProgramArray<T>, MapError> {
-        let inner = map.deref();
-        let map_type = inner.obj.def.map_type;
+        let map_type = map.obj.def.map_type;
         if map_type != BPF_MAP_TYPE_PROG_ARRAY {
             return Err(MapError::InvalidMapType {
                 map_type: map_type as u32,
             })?;
         }
         let expected = mem::size_of::<RawFd>();
-        let size = inner.obj.def.key_size as usize;
+        let size = map.obj.def.key_size as usize;
         if size != expected {
             return Err(MapError::InvalidKeySize { size, expected });
         }
 
         let expected = mem::size_of::<RawFd>();
-        let size = inner.obj.def.value_size as usize;
+        let size = map.obj.def.value_size as usize;
         if size != expected {
             return Err(MapError::InvalidValueSize { size, expected });
         }
+        let _fd = map.fd_or_err()?;
 
         Ok(ProgramArray { inner: map })
     }
 
     pub unsafe fn get(&self, key: &u32, flags: u64) -> Result<Option<RawFd>, MapError> {
-        let fd = self.inner.deref().fd_or_err()?;
+        let fd = self.inner.fd_or_err()?;
         let fd = bpf_map_lookup_elem(fd, key, flags)
             .map_err(|(code, io_error)| MapError::LookupElementError { code, io_error })?;
         Ok(fd)
@@ -59,8 +59,8 @@ impl<T: Deref<Target = Map>> ProgramArray<T> {
     }
 
     fn check_bounds(&self, index: u32) -> Result<(), MapError> {
-        let max_entries = self.inner.deref().obj.def.max_entries;
-        if index >= self.inner.deref().obj.def.max_entries {
+        let max_entries = self.inner.obj.def.max_entries;
+        if index >= self.inner.obj.def.max_entries {
             Err(MapError::OutOfBounds { index, max_entries })
         } else {
             Ok(())
@@ -75,7 +75,7 @@ impl<T: Deref<Target = Map> + DerefMut<Target = Map>> ProgramArray<T> {
         program: &dyn ProgramFd,
         flags: u64,
     ) -> Result<(), MapError> {
-        let fd = self.inner.deref().fd_or_err()?;
+        let fd = self.inner.fd_or_err()?;
         self.check_bounds(index)?;
         let prog_fd = program.fd().ok_or(MapError::ProgramNotLoaded)?;
 
@@ -85,14 +85,14 @@ impl<T: Deref<Target = Map> + DerefMut<Target = Map>> ProgramArray<T> {
     }
 
     pub unsafe fn pop(&mut self, index: &u32) -> Result<Option<RawFd>, MapError> {
-        let fd = self.inner.deref().fd_or_err()?;
+        let fd = self.inner.fd_or_err()?;
         self.check_bounds(*index)?;
         bpf_map_lookup_and_delete_elem(fd, index)
             .map_err(|(code, io_error)| MapError::LookupAndDeleteElementError { code, io_error })
     }
 
     pub fn remove(&mut self, index: &u32) -> Result<(), MapError> {
-        let fd = self.inner.deref().fd_or_err()?;
+        let fd = self.inner.fd_or_err()?;
         self.check_bounds(*index)?;
         bpf_map_delete_elem(fd, index)
             .map(|_| ())
@@ -102,7 +102,7 @@ impl<T: Deref<Target = Map> + DerefMut<Target = Map>> ProgramArray<T> {
 
 impl<T: Deref<Target = Map>> IterableMap<u32, RawFd> for ProgramArray<T> {
     fn fd(&self) -> Result<RawFd, MapError> {
-        self.inner.deref().fd_or_err()
+        self.inner.fd_or_err()
     }
 
     unsafe fn get(&self, index: &u32) -> Result<Option<RawFd>, MapError> {