Browse Source

Refactor ELF parser a bit and add more tests

Alessandro Decina 4 years ago
parent
commit
5b8def7b69
2 changed files with 399 additions and 95 deletions
  1. 4 1
      Cargo.toml
  2. 395 94
      src/obj/mod.rs

+ 4 - 1
Cargo.toml

@@ -9,4 +9,7 @@ libc = "0.2"
 thiserror = "1"
 object = "0.23"
 bytes = "1"
-lazy_static = "1"
+lazy_static = "1"
+
+[dev-dependencies]
+matches = "0.1.8"

+ 395 - 94
src/obj/mod.rs

@@ -3,14 +3,14 @@ mod relocation;
 
 use object::{
     pod,
-    read::{Object as ElfObject, ObjectSection, Section},
+    read::{Object as ElfObject, ObjectSection, Section as ObjSection},
     Endianness, ObjectSymbol, ObjectSymbolTable, SectionIndex, SymbolIndex,
 };
 use std::{
     collections::HashMap,
     convert::{TryFrom, TryInto},
     ffi::{CStr, CString},
-    mem,
+    mem, ptr,
     str::FromStr,
 };
 use thiserror::Error;
@@ -93,27 +93,17 @@ impl Object {
         let section = obj
             .section_by_name("license")
             .ok_or(ParseError::MissingLicense)?;
-        let license = parse_license(BPFSection::try_from(&section)?.data)?;
+        let license = parse_license(Section::try_from(&section)?.data)?;
 
         let section = obj
             .section_by_name("version")
             .ok_or(ParseError::MissingKernelVersion)?;
-        let kernel_version = parse_version(BPFSection::try_from(&section)?.data, endianness)?;
+        let kernel_version = parse_version(Section::try_from(&section)?.data, endianness)?;
 
-        let mut bpf_obj = Object {
-            endianness: endianness.into(),
-            license,
-            kernel_version,
-            btf: None,
-            btf_ext: None,
-            maps: HashMap::new(),
-            programs: HashMap::new(),
-            relocations: HashMap::new(),
-            symbol_table: HashMap::new(),
-        };
+        let mut bpf_obj = Object::new(endianness, license, kernel_version);
 
         for s in obj.sections() {
-            parse_section(&mut bpf_obj, BPFSection::try_from(&s)?)?;
+            bpf_obj.parse_section(Section::try_from(&s)?)?;
         }
 
         if let Some(symbol_table) = obj.symbol_table() {
@@ -131,6 +121,90 @@ impl Object {
 
         return Ok(bpf_obj);
     }
+
+    /// Builds an `Object` carrying the given endianness, license and kernel
+    /// version, with no BTF data and empty map, program, relocation and symbol
+    /// tables.
+    fn new(endianness: Endianness, license: CString, kernel_version: KernelVersion) -> Object {
+        let endianness = endianness.into();
+
+        Self {
+            endianness,
+            license,
+            kernel_version,
+            // Nothing has been parsed yet: the collections start out empty.
+            maps: HashMap::new(),
+            programs: HashMap::new(),
+            relocations: HashMap::new(),
+            symbol_table: HashMap::new(),
+            btf: None,
+            btf_ext: None,
+        }
+    }
+
+    /// Decodes the raw bytes of a program section into a [`Program`] of kind `ty`.
+    ///
+    /// The program inherits the license and kernel version of this object.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`ParseError::InvalidProgramCode`] when the section length is not a
+    /// multiple of `size_of::<bpf_insn>()`, and propagates the error produced by
+    /// `ProgramKind::from_str` for `ty`.
+    fn parse_program(&self, section: &Section, ty: &str) -> Result<Program, ParseError> {
+        let insn_size = mem::size_of::<bpf_insn>();
+        if section.data.len() % insn_size != 0 {
+            return Err(ParseError::InvalidProgramCode {
+                name: section.name.to_owned(),
+            });
+        }
+
+        let instructions = section
+            .data
+            .chunks_exact(insn_size)
+            .map(|chunk| {
+                // SAFETY: `chunks_exact` only yields slices of exactly `insn_size`
+                // readable bytes, and `read_unaligned` places no alignment requirement
+                // on the pointer. This relies on `bpf_insn` being plain data for which
+                // every bit pattern is a valid value.
+                unsafe { ptr::read_unaligned(chunk.as_ptr().cast::<bpf_insn>()) }
+            })
+            .collect::<Vec<_>>();
+
+        Ok(Program {
+            section_index: section.index,
+            license: self.license.clone(),
+            kernel_version: self.kernel_version,
+            instructions,
+            kind: ProgramKind::from_str(ty)?,
+        })
+    }
+
+    /// Parses the section data with `Btf::parse` and stores the result in `self.btf`.
+    ///
+    /// On failure the error is returned and `self.btf` is left untouched.
+    fn parse_btf(&mut self, section: &Section) -> Result<(), BtfError> {
+        let btf = Btf::parse(section.data)?;
+        self.btf = Some(btf);
+        Ok(())
+    }
+
+    /// Parses the section data with `BtfExt::parse` and stores the result in
+    /// `self.btf_ext`.
+    ///
+    /// On failure the error is returned and `self.btf_ext` is left untouched.
+    fn parse_btf_ext(&mut self, section: &Section) -> Result<(), BtfError> {
+        let btf_ext = BtfExt::parse(section.data)?;
+        self.btf_ext = Some(btf_ext);
+        Ok(())
+    }
+
+    /// Dispatches a section to the matching parser based on its `/`-separated name.
+    ///
+    /// - `.bss`, `.data*` and `.rodata*` sections, as well as `maps/<name>`
+    ///   sections, are parsed with `parse_map` and stored in `self.maps`.
+    /// - `.BTF` and `.BTF.ext` are handed to the BTF parsers.
+    /// - `kprobe/`, `uprobe/`, `socket_filter/`, `xdp/` and `trace_point/` sections
+    ///   are parsed as programs and stored in `self.programs`; non-empty
+    ///   relocations are recorded under the section index.
+    ///
+    /// Sections with any other name are ignored.
+    fn parse_section(&mut self, section: Section) -> Result<(), ParseError> {
+        let parts: Vec<&str> = section.name.split("/").collect();
+
+        match parts.as_slice() {
+            &[name]
+                if name == ".bss" || name.starts_with(".data") || name.starts_with(".rodata") =>
+            {
+                let map = parse_map(&section, name)?;
+                self.maps.insert(name.to_string(), map);
+            }
+            &[".BTF"] => self.parse_btf(&section)?,
+            &[".BTF.ext"] => self.parse_btf_ext(&section)?,
+            &["maps", name] => {
+                let map = parse_map(&section, name)?;
+                self.maps.insert(name.to_string(), map);
+            }
+            &[ty @ "kprobe", name]
+            | &[ty @ "uprobe", name]
+            | &[ty @ "socket_filter", name]
+            | &[ty @ "xdp", name]
+            | &[ty @ "trace_point", name] => {
+                let program = self.parse_program(&section, ty)?;
+                self.programs.insert(name.to_string(), program);
+
+                // Only sections that actually carry relocations get an entry.
+                if !section.relocations.is_empty() {
+                    self.relocations.insert(section.index, section.relocations);
+                }
+            }
+
+            _ => {}
+        }
+
+        Ok(())
+    }
 }
 
 #[derive(Debug, Clone, Error)]
@@ -179,23 +253,23 @@ pub enum ParseError {
     InvalidMapDefinition { name: String },
 }
 
-struct BPFSection<'s> {
+struct Section<'a> {
     index: SectionIndex,
-    name: &'s str,
-    data: &'s [u8],
+    name: &'a str,
+    data: &'a [u8],
     relocations: Vec<Relocation>,
 }
 
-impl<'data, 'file, 's> TryFrom<&'s Section<'data, 'file>> for BPFSection<'s> {
+impl<'data, 'file, 'a> TryFrom<&'a ObjSection<'data, 'file>> for Section<'a> {
     type Error = ParseError;
 
-    fn try_from(section: &'s Section) -> Result<BPFSection<'s>, ParseError> {
+    fn try_from(section: &'a ObjSection) -> Result<Section<'a>, ParseError> {
         let index = section.index();
         let map_err = |source| ParseError::SectionError {
             index: index.0,
             source,
         };
-        Ok(BPFSection {
+        Ok(Section {
             index,
             name: section.name().map_err(map_err)?,
             data: section.data().map_err(map_err)?,
@@ -269,7 +343,7 @@ impl From<KernelVersion> for u32 {
     }
 }
 
-fn parse_map(section: &BPFSection, name: &str) -> Result<Map, ParseError> {
+fn parse_map(section: &Section, name: &str) -> Result<Map, ParseError> {
     let (def, data) = if name == ".bss" || name.starts_with(".data") || name.starts_with(".rodata")
     {
         let def = bpf_map_def {
@@ -293,89 +367,45 @@ fn parse_map(section: &BPFSection, name: &str) -> Result<Map, ParseError> {
 }
 
 fn parse_map_def(name: &str, data: &[u8]) -> Result<bpf_map_def, ParseError> {
-    let (def, rest) =
-        pod::from_bytes::<bpf_map_def>(data).map_err(|_| ParseError::InvalidMapDefinition {
-            name: name.to_string(),
-        })?;
-    if !rest.is_empty() {
+    if mem::size_of::<bpf_map_def>() > data.len() {
         return Err(ParseError::InvalidMapDefinition {
-            name: name.to_string(),
-        });
-    }
-
-    Ok(*def)
-}
-
-fn parse_program(bpf: &Object, section: &BPFSection, ty: &str) -> Result<Program, ParseError> {
-    let (code, rest) = pod::slice_from_bytes::<bpf_insn>(
-        section.data,
-        section.data.len() / mem::size_of::<bpf_insn>(),
-    )
-    .map_err(|_| ParseError::InvalidProgramCode {
-        name: section.name.to_string(),
-    })?;
-
-    if !rest.is_empty() {
-        return Err(ParseError::InvalidProgramCode {
-            name: section.name.to_string(),
+            name: name.to_owned(),
         });
     }
 
-    Ok(Program {
-        section_index: section.index,
-        license: bpf.license.clone(),
-        kernel_version: bpf.kernel_version,
-        instructions: code.to_vec(),
-        kind: ProgramKind::from_str(ty)?,
-    })
+    Ok(unsafe { ptr::read_unaligned(data.as_ptr() as *const bpf_map_def) })
 }
 
-fn parse_btf(obj: &mut Object, section: &BPFSection) -> Result<(), BtfError> {
-    obj.btf = Some(Btf::parse(section.data)?);
-
-    Ok(())
-}
-
-fn parse_btf_ext(obj: &mut Object, section: &BPFSection) -> Result<(), BtfError> {
-    obj.btf_ext = Some(BtfExt::parse(section.data)?);
-    Ok(())
-}
+#[cfg(test)]
+mod tests {
+    use matches::assert_matches;
+    use object::Endianness;
+    use std::slice;
 
-fn parse_section(bpf: &mut Object, section: BPFSection) -> Result<(), ParseError> {
-    let parts = section.name.split("/").collect::<Vec<_>>();
+    use super::*;
 
-    match parts.as_slice() {
-        &[name] if name == ".bss" || name.starts_with(".data") || name.starts_with(".rodata") => {
-            bpf.maps
-                .insert(name.to_string(), parse_map(&section, name)?);
-        }
-        &[".BTF"] => parse_btf(bpf, &section)?,
-        &[".BTF.ext"] => parse_btf_ext(bpf, &section)?,
-        &["maps", name] => {
-            bpf.maps
-                .insert(name.to_string(), parse_map(&section, name)?);
-        }
-        &[ty @ "kprobe", name]
-        | &[ty @ "uprobe", name]
-        | &[ty @ "xdp", name]
-        | &[ty @ "trace_point", name] => {
-            bpf.programs
-                .insert(name.to_string(), parse_program(bpf, &section, ty)?);
-            if !section.relocations.is_empty() {
-                bpf.relocations.insert(section.index, section.relocations);
-            }
+    fn fake_section<'a>(name: &'a str, data: &'a [u8]) -> Section<'a> {
+        Section {
+            index: SectionIndex(0),
+            name,
+            data,
+            relocations: Vec::new(),
         }
-
-        _ => {}
     }
 
-    Ok(())
-}
+    /// Returns a `bpf_insn` whose fields are all built from zeroes.
+    fn fake_ins() -> bpf_insn {
+        let _bitfield_1 = bpf_insn::new_bitfield_1(0, 0);
+
+        bpf_insn {
+            code: 0,
+            _bitfield_1,
+            off: 0,
+            imm: 0,
+        }
+    }
 
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use object::Endianness;
+    /// Views `val` as a byte slice of length `size_of::<T>()`.
+    fn bytes_of<T>(val: &T) -> &[u8] {
+        let start = (val as *const T).cast::<u8>();
+        // SAFETY: `start` is derived from a live reference, so it is valid for reads
+        // of `size_of::<T>()` bytes for as long as the returned borrow of `val` lasts.
+        // NOTE(review): assumes `T` has no padding bytes — confirm for new callers.
+        unsafe { slice::from_raw_parts(start, mem::size_of::<T>()) }
+    }
 
     #[test]
     fn test_parse_generic_error() {
@@ -435,4 +465,275 @@ mod tests {
             KernelVersion::Version(1234)
         );
     }
+
+    #[test]
+    fn test_parse_map_def() {
+        // An empty buffer must be rejected.
+        let from_empty = parse_map_def("foo", &[]);
+        assert!(matches!(
+            from_empty,
+            Err(ParseError::InvalidMapDefinition { .. })
+        ));
+
+        // The bytes of a definition must parse back to the same field values.
+        let def = bpf_map_def {
+            map_type: 1,
+            key_size: 2,
+            value_size: 3,
+            max_entries: 4,
+            map_flags: 5,
+        };
+        let parsed = parse_map_def("foo", bytes_of(&def));
+        assert!(matches!(
+            parsed,
+            Ok(bpf_map_def {
+                map_type: 1,
+                key_size: 2,
+                value_size: 3,
+                max_entries: 4,
+                map_flags: 5
+            })
+        ));
+    }
+
+    #[test]
+    fn test_parse_map_error() {
+        // A `maps/` section without any data cannot contain a map definition.
+        let section = fake_section("maps/foo", &[]);
+        let result = parse_map(&section, "foo");
+
+        assert!(matches!(
+            result,
+            Err(ParseError::InvalidMapDefinition { .. })
+        ));
+    }
+
+    #[test]
+    fn test_parse_map() {
+        let def = bpf_map_def {
+            map_type: 1,
+            key_size: 2,
+            value_size: 3,
+            max_entries: 4,
+            map_flags: 5,
+        };
+        let section = fake_section("maps/foo", bytes_of(&def));
+        let result = parse_map(&section, "foo");
+
+        // The definition is carried over as-is and no initial data is attached.
+        assert!(matches!(
+            result,
+            Ok(Map {
+                section_index: 0,
+                name,
+                def: bpf_map_def {
+                    map_type: 1,
+                    key_size: 2,
+                    value_size: 3,
+                    max_entries: 4,
+                    map_flags: 5,
+                },
+                data
+            }) if name == "foo" && data.is_empty()
+        ));
+    }
+
+    #[test]
+    fn test_parse_map_data() {
+        let map_data = b"map data";
+        let section = fake_section(".bss", map_data);
+        let result = parse_map(&section, ".bss");
+
+        // A data section becomes a single-entry array map whose value is the
+        // section contents.
+        assert!(matches!(
+            result,
+            Ok(Map {
+                section_index: 0,
+                name,
+                def: bpf_map_def {
+                    map_type: BPF_MAP_TYPE_ARRAY,
+                    key_size: 4,
+                    value_size,
+                    max_entries: 1,
+                    map_flags: 0,
+                },
+                data
+            }) if name == ".bss" && data == map_data && value_size == map_data.len() as u32
+        ));
+    }
+
+    /// Returns a little-endian `Object` licensed as "GPL" that accepts any kernel
+    /// version.
+    fn fake_obj() -> Object {
+        let license = CString::new("GPL").unwrap();
+        Object::new(Endianness::Little, license, KernelVersion::Any)
+    }
+
+    #[test]
+    fn test_parse_program_error() {
+        let obj = fake_obj();
+
+        // Four bytes: expected not to be a whole number of instructions.
+        let code = 42u32.to_ne_bytes();
+        let section = fake_section("kprobe/foo", &code);
+        let result = obj.parse_program(&section, "kprobe");
+
+        assert_matches!(
+            result,
+            Err(ParseError::InvalidProgramCode { name }) if name == "kprobe/foo"
+        );
+    }
+
+    #[test]
+    fn test_parse_program() {
+        let obj = fake_obj();
+
+        // A section holding exactly one instruction.
+        let ins = fake_ins();
+        let section = fake_section("kprobe/foo", bytes_of(&ins));
+        let result = obj.parse_program(&section, "kprobe");
+
+        assert_matches!(
+            result,
+            Ok(Program {
+                license,
+                kernel_version,
+                kind: ProgramKind::KProbe,
+                section_index: SectionIndex(0),
+                instructions
+            }) if license.to_string_lossy() == "GPL" && kernel_version == KernelVersion::Any && instructions.len() == 1
+        );
+    }
+
+    #[test]
+    fn test_parse_section_map() {
+        let mut obj = fake_obj();
+
+        let def = bpf_map_def {
+            map_type: 1,
+            key_size: 2,
+            value_size: 3,
+            max_entries: 4,
+            map_flags: 5,
+        };
+        let section = fake_section("maps/foo", bytes_of(&def));
+
+        assert_matches!(obj.parse_section(section), Ok(()));
+        // The map is registered under the part of the name after `maps/`.
+        assert!(obj.maps.contains_key("foo"));
+    }
+
+    #[test]
+    fn test_parse_section_data() {
+        let mut obj = fake_obj();
+
+        // Every data-like section must be registered as a map under its own name.
+        let section_names = [".bss", ".rodata", ".rodata.boo", ".data", ".data.boo"];
+        for name in section_names.iter() {
+            assert_matches!(
+                obj.parse_section(fake_section(*name, b"map data")),
+                Ok(())
+            );
+            assert!(obj.maps.contains_key(*name));
+        }
+    }
+
+    #[test]
+    fn test_parse_section_kprobe() {
+        let mut obj = fake_obj();
+
+        let ins = fake_ins();
+        let section = fake_section("kprobe/foo", bytes_of(&ins));
+        assert_matches!(obj.parse_section(section), Ok(()));
+
+        // The program is stored under the name that follows the `kprobe/` prefix.
+        let program = obj.programs.get("foo");
+        assert_matches!(
+            program,
+            Some(Program {
+                kind: ProgramKind::KProbe,
+                ..
+            })
+        );
+    }
+
+    #[test]
+    fn test_parse_section_uprobe() {
+        let mut obj = fake_obj();
+
+        let ins = fake_ins();
+        let section = fake_section("uprobe/foo", bytes_of(&ins));
+        assert_matches!(obj.parse_section(section), Ok(()));
+
+        // The program is stored under the name that follows the `uprobe/` prefix.
+        let program = obj.programs.get("foo");
+        assert_matches!(
+            program,
+            Some(Program {
+                kind: ProgramKind::UProbe,
+                ..
+            })
+        );
+    }
+
+    #[test]
+    fn test_parse_section_trace_point() {
+        let mut obj = fake_obj();
+
+        let ins = fake_ins();
+        let section = fake_section("trace_point/foo", bytes_of(&ins));
+        assert_matches!(obj.parse_section(section), Ok(()));
+
+        // The program is stored under the name that follows the `trace_point/` prefix.
+        let program = obj.programs.get("foo");
+        assert_matches!(
+            program,
+            Some(Program {
+                kind: ProgramKind::TracePoint,
+                ..
+            })
+        );
+    }
+
+    #[test]
+    fn test_parse_section_socket_filter() {
+        let mut obj = fake_obj();
+
+        let ins = fake_ins();
+        let section = fake_section("socket_filter/foo", bytes_of(&ins));
+        assert_matches!(obj.parse_section(section), Ok(()));
+
+        // The program is stored under the name that follows the `socket_filter/` prefix.
+        let program = obj.programs.get("foo");
+        assert_matches!(
+            program,
+            Some(Program {
+                kind: ProgramKind::SocketFilter,
+                ..
+            })
+        );
+    }
+
+    #[test]
+    fn test_parse_section_xdp() {
+        let mut obj = fake_obj();
+
+        let ins = fake_ins();
+        let section = fake_section("xdp/foo", bytes_of(&ins));
+        assert_matches!(obj.parse_section(section), Ok(()));
+
+        // The program is stored under the name that follows the `xdp/` prefix.
+        let program = obj.programs.get("foo");
+        assert_matches!(
+            program,
+            Some(Program {
+                kind: ProgramKind::Xdp,
+                ..
+            })
+        );
+    }
 }