Browse Source

aya: refactor program section parsing

This renames aya::obj::ProgramKind to aya::obj::ProgramSection and moves
all the program section parsing to ProgramSection::from_str.
Alessandro Decina 3 years ago
parent
commit
bb595c4e69
2 changed files with 160 additions and 103 deletions
  1. 20 18
      aya/src/bpf.rs
  2. 140 85
      aya/src/obj/mod.rs

+ 20 - 18
aya/src/bpf.rs

@@ -16,7 +16,7 @@ use crate::{
     maps::{Map, MapError, MapLock, MapRef, MapRefMut},
     obj::{
         btf::{Btf, BtfError},
-        Object, ParseError, ProgramKind,
+        Object, ParseError, ProgramSection,
     },
     programs::{
         CgroupSkb, CgroupSkbAttachType, KProbe, LircMode2, ProbeKind, Program, ProgramData,
@@ -148,55 +148,57 @@ impl Bpf {
             .programs
             .drain()
             .map(|(name, obj)| {
-                let kind = obj.kind;
+                let section = obj.section.clone();
                 let data = ProgramData {
                     obj,
                     name: name.clone(),
                     fd: None,
                     links: Vec::new(),
                 };
-                let program = match kind {
-                    ProgramKind::KProbe => Program::KProbe(KProbe {
+                let program = match section {
+                    ProgramSection::KProbe { .. } => Program::KProbe(KProbe {
                         data,
                         kind: ProbeKind::KProbe,
                     }),
-                    ProgramKind::KRetProbe => Program::KProbe(KProbe {
+                    ProgramSection::KRetProbe { .. } => Program::KProbe(KProbe {
                         data,
                         kind: ProbeKind::KRetProbe,
                     }),
-                    ProgramKind::UProbe => Program::UProbe(UProbe {
+                    ProgramSection::UProbe { .. } => Program::UProbe(UProbe {
                         data,
                         kind: ProbeKind::UProbe,
                     }),
-                    ProgramKind::URetProbe => Program::UProbe(UProbe {
+                    ProgramSection::URetProbe { .. } => Program::UProbe(UProbe {
                         data,
                         kind: ProbeKind::URetProbe,
                     }),
-                    ProgramKind::TracePoint => Program::TracePoint(TracePoint { data }),
-                    ProgramKind::SocketFilter => Program::SocketFilter(SocketFilter { data }),
-                    ProgramKind::Xdp => Program::Xdp(Xdp { data }),
-                    ProgramKind::SkMsg => Program::SkMsg(SkMsg { data }),
-                    ProgramKind::SkSkbStreamParser => Program::SkSkb(SkSkb {
+                    ProgramSection::TracePoint { .. } => Program::TracePoint(TracePoint { data }),
+                    ProgramSection::SocketFilter { .. } => {
+                        Program::SocketFilter(SocketFilter { data })
+                    }
+                    ProgramSection::Xdp { .. } => Program::Xdp(Xdp { data }),
+                    ProgramSection::SkMsg { .. } => Program::SkMsg(SkMsg { data }),
+                    ProgramSection::SkSkbStreamParser { .. } => Program::SkSkb(SkSkb {
                         data,
                         kind: SkSkbKind::StreamParser,
                     }),
-                    ProgramKind::SkSkbStreamVerdict => Program::SkSkb(SkSkb {
+                    ProgramSection::SkSkbStreamVerdict { .. } => Program::SkSkb(SkSkb {
                         data,
                         kind: SkSkbKind::StreamVerdict,
                     }),
-                    ProgramKind::SockOps => Program::SockOps(SockOps { data }),
-                    ProgramKind::SchedClassifier => {
+                    ProgramSection::SockOps { .. } => Program::SockOps(SockOps { data }),
+                    ProgramSection::SchedClassifier { .. } => {
                         Program::SchedClassifier(SchedClassifier { data })
                     }
-                    ProgramKind::CgroupSkbIngress => Program::CgroupSkb(CgroupSkb {
+                    ProgramSection::CgroupSkbIngress { .. } => Program::CgroupSkb(CgroupSkb {
                         data,
                         expected_attach_type: Some(CgroupSkbAttachType::Ingress),
                     }),
-                    ProgramKind::CgroupSkbEgress => Program::CgroupSkb(CgroupSkb {
+                    ProgramSection::CgroupSkbEgress { .. } => Program::CgroupSkb(CgroupSkb {
                         data,
                         expected_attach_type: Some(CgroupSkbAttachType::Egress),
                     }),
-                    ProgramKind::LircMode2 => Program::LircMode2(LircMode2 { data }),
+                    ProgramSection::LircMode2 { .. } => Program::LircMode2(LircMode2 { data }),
                 };
 
                 (name, program)

+ 140 - 85
aya/src/obj/mod.rs

@@ -51,7 +51,7 @@ pub struct Map {
 pub(crate) struct Program {
     pub(crate) license: CString,
     pub(crate) kernel_version: KernelVersion,
-    pub(crate) kind: ProgramKind,
+    pub(crate) section: ProgramSection,
     pub(crate) function: Function,
 }
 
@@ -64,49 +64,86 @@ pub(crate) struct Function {
     pub(crate) instructions: Vec<bpf_insn>,
 }
 
-#[derive(Debug, Copy, Clone)]
-pub enum ProgramKind {
-    KProbe,
-    KRetProbe,
-    UProbe,
-    URetProbe,
-    TracePoint,
-    SocketFilter,
-    Xdp,
-    SkMsg,
-    SkSkbStreamParser,
-    SkSkbStreamVerdict,
-    SockOps,
-    SchedClassifier,
-    CgroupSkbIngress,
-    CgroupSkbEgress,
-    LircMode2,
+#[derive(Debug, Clone)]
+pub enum ProgramSection {
+    KRetProbe { name: String },
+    KProbe { name: String },
+    UProbe { name: String },
+    URetProbe { name: String },
+    TracePoint { name: String },
+    SocketFilter { name: String },
+    Xdp { name: String },
+    SkMsg { name: String },
+    SkSkbStreamParser { name: String },
+    SkSkbStreamVerdict { name: String },
+    SockOps { name: String },
+    SchedClassifier { name: String },
+    CgroupSkbIngress { name: String },
+    CgroupSkbEgress { name: String },
+    LircMode2 { name: String },
 }
 
-impl FromStr for ProgramKind {
+impl ProgramSection {
+    fn name(&self) -> &str {
+        match self {
+            ProgramSection::KRetProbe { name } => name,
+            ProgramSection::KProbe { name } => name,
+            ProgramSection::UProbe { name } => name,
+            ProgramSection::URetProbe { name } => name,
+            ProgramSection::TracePoint { name } => name,
+            ProgramSection::SocketFilter { name } => name,
+            ProgramSection::Xdp { name } => name,
+            ProgramSection::SkMsg { name } => name,
+            ProgramSection::SkSkbStreamParser { name } => name,
+            ProgramSection::SkSkbStreamVerdict { name } => name,
+            ProgramSection::SockOps { name } => name,
+            ProgramSection::SchedClassifier { name } => name,
+            ProgramSection::CgroupSkbIngress { name } => name,
+            ProgramSection::CgroupSkbEgress { name } => name,
+            ProgramSection::LircMode2 { name } => name,
+        }
+    }
+}
+
+impl FromStr for ProgramSection {
     type Err = ParseError;
 
-    fn from_str(kind: &str) -> Result<ProgramKind, ParseError> {
-        use ProgramKind::*;
+    fn from_str(section: &str) -> Result<ProgramSection, ParseError> {
+        use ProgramSection::*;
+
+        // parse the common case, eg "xdp/program_name" or
+        // "sk_skb/stream_verdict/program_name"
+        let mut parts = section.rsplitn(2, "/").collect::<Vec<_>>();
+        if parts.len() == 1 {
+            parts.push(parts[0]);
+        }
+        let kind = parts[1];
+        let name = parts[0].to_owned();
+
         Ok(match kind {
-            "kprobe" => KProbe,
-            "kretprobe" => KRetProbe,
-            "uprobe" => UProbe,
-            "uretprobe" => URetProbe,
-            "xdp" => Xdp,
-            "tracepoint" => TracePoint,
-            "socket_filter" => SocketFilter,
-            "sk_msg" => SkMsg,
-            "sk_skb/stream_parser" => SkSkbStreamParser,
-            "sk_skb/stream_verdict" => SkSkbStreamVerdict,
-            "sockops" => SockOps,
-            "classifier" => SchedClassifier,
-            "cgroup_skb/ingress" => CgroupSkbIngress,
-            "cgroup_skb/egress" => CgroupSkbEgress,
-            "lirc_mode2" => LircMode2,
+            "kprobe" => KProbe { name },
+            "kretprobe" => KRetProbe { name },
+            "uprobe" => UProbe { name },
+            "uretprobe" => URetProbe { name },
+            "xdp" => Xdp { name },
+            _ if kind.starts_with("tracepoint") || kind.starts_with("tp") => {
+                // tracepoint sections are named `tracepoint/category/event_name`,
+                // and we want to parse the name as "category/event_name"
+                let name = section.splitn(2, "/").last().unwrap().to_owned();
+                TracePoint { name }
+            }
+            "socket_filter" => SocketFilter { name },
+            "sk_msg" => SkMsg { name },
+            "sk_skb/stream_parser" => SkSkbStreamParser { name },
+            "sk_skb/stream_verdict" => SkSkbStreamVerdict { name },
+            "sockops" => SockOps { name },
+            "classifier" => SchedClassifier { name },
+            "cgroup_skb/ingress" => CgroupSkbIngress { name },
+            "cgroup_skb/egress" => CgroupSkbEgress { name },
+            "lirc_mode2" => LircMode2 { name },
             _ => {
-                return Err(ParseError::InvalidProgramKind {
-                    kind: kind.to_string(),
+                return Err(ParseError::InvalidProgramSection {
+                    section: section.to_owned(),
                 })
             }
         })
@@ -181,18 +218,15 @@ impl Object {
         Ok(())
     }
 
-    fn parse_program(
-        &self,
-        section: &Section,
-        ty: &str,
-        name: &str,
-    ) -> Result<Program, ParseError> {
+    fn parse_program(&self, section: &Section) -> Result<Program, ParseError> {
+        let prog_sec = ProgramSection::from_str(section.name)?;
+        let name = prog_sec.name().to_owned();
         Ok(Program {
             license: self.license.clone(),
             kernel_version: self.kernel_version,
-            kind: ProgramKind::from_str(ty)?,
+            section: prog_sec,
             function: Function {
-                name: name.to_owned(),
+                name,
                 address: section.address,
                 section_index: section.index,
                 section_offset: 0,
@@ -276,38 +310,23 @@ impl Object {
             }
         }
 
-        match parts.as_slice() {
-            &[name]
-                if name == ".bss" || name.starts_with(".data") || name.starts_with(".rodata") =>
-            {
+        match section.name {
+            name if name == ".bss" || name.starts_with(".data") || name.starts_with(".rodata") => {
                 self.maps
                     .insert(name.to_string(), parse_map(&section, name)?);
             }
-            &[name] if name.starts_with(".text") => self.parse_text_section(section)?,
-            &[".BTF"] => self.parse_btf(&section)?,
-            &[".BTF.ext"] => self.parse_btf_ext(&section)?,
-            &["maps", name] => {
+            name if name.starts_with(".text") => self.parse_text_section(section)?,
+            ".BTF" => self.parse_btf(&section)?,
+            ".BTF.ext" => self.parse_btf_ext(&section)?,
+            map if map.starts_with("maps/") => {
+                let name = map.splitn(2, "/").last().unwrap();
                 self.maps
                     .insert(name.to_string(), parse_map(&section, name)?);
             }
-            &[ty @ "kprobe", name]
-            | &[ty @ "kretprobe", name]
-            | &[ty @ "uprobe", name]
-            | &[ty @ "uretprobe", name]
-            | &[ty @ "socket_filter", name]
-            | &[ty @ "xdp", name]
-            | &[ty @ "tracepoint", name]
-            | &[ty @ "sk_msg", name]
-            | &[ty @ "sk_skb/stream_parser", name]
-            | &[ty @ "sk_skb/stream_verdict", name]
-            | &[ty @ "sockops", name]
-            | &[ty @ "classifier", name]
-            | &[ty @ "cgroup_skb/ingress", name]
-            | &[ty @ "cgroup_skb/egress", name]
-            | &[ty @ "cgroup/skb", name]
-            | &[ty @ "lirc_mode2", name] => {
+            name if is_program_section(name) => {
+                let program = self.parse_program(&section)?;
                 self.programs
-                    .insert(name.to_string(), self.parse_program(&section, ty, name)?);
+                    .insert(program.section.name().to_owned(), program);
                 if !section.relocations.is_empty() {
                     self.relocations.insert(
                         section.index,
@@ -351,8 +370,8 @@ pub enum ParseError {
     #[error("unsupported relocation target")]
     UnsupportedRelocationTarget,
 
-    #[error("invalid program kind `{kind}`")]
-    InvalidProgramKind { kind: String },
+    #[error("invalid program section `{section}`")]
+    InvalidProgramSection { section: String },
 
     #[error("invalid program code")]
     InvalidProgramCode,
@@ -516,6 +535,34 @@ fn copy_instructions(data: &[u8]) -> Result<Vec<bpf_insn>, ParseError> {
     Ok(instructions)
 }
 
+fn is_program_section(name: &str) -> bool {
+    for prefix in &[
+        "classifier",
+        "cgroup/skb",
+        "cgroup_skb/egress",
+        "cgroup_skb/ingress",
+        "kprobe",
+        "kretprobe",
+        "lirc_mode2",
+        "sk_msg",
+        "sk_skb/stream_parser",
+        "sk_skb/stream_verdict",
+        "socket_filter",
+        "sockops",
+        "tp",
+        "tracepoint",
+        "uprobe",
+        "uretprobe",
+        "xdp",
+    ] {
+        if name.starts_with(prefix) {
+            return true;
+        }
+    }
+
+    false
+}
+
 #[cfg(test)]
 mod tests {
     use matches::assert_matches;
@@ -723,11 +770,7 @@ mod tests {
         let obj = fake_obj();
 
         assert_matches!(
-            obj.parse_program(
-                &fake_section("kprobe/foo", &42u32.to_ne_bytes(),),
-                "kprobe",
-                "foo"
-            ),
+            obj.parse_program(&fake_section("kprobe/foo", &42u32.to_ne_bytes(),),),
             Err(ParseError::InvalidProgramCode)
         );
     }
@@ -737,11 +780,11 @@ mod tests {
         let obj = fake_obj();
 
         assert_matches!(
-            obj.parse_program(&fake_section("kprobe/foo", bytes_of(&fake_ins())), "kprobe", "foo"),
+            obj.parse_program(&fake_section("kprobe/foo", bytes_of(&fake_ins()))),
             Ok(Program {
                 license,
                 kernel_version: KernelVersion::Any,
-                kind: ProgramKind::KProbe,
+                section: ProgramSection::KProbe { .. },
                 function: Function {
                     name,
                     address: 0,
@@ -820,7 +863,7 @@ mod tests {
         assert_matches!(
             obj.programs.get("foo"),
             Some(Program {
-                kind: ProgramKind::KProbe,
+                section: ProgramSection::KProbe { .. },
                 ..
             })
         );
@@ -837,7 +880,7 @@ mod tests {
         assert_matches!(
             obj.programs.get("foo"),
             Some(Program {
-                kind: ProgramKind::UProbe,
+                section: ProgramSection::UProbe { .. },
                 ..
             })
         );
@@ -854,7 +897,19 @@ mod tests {
         assert_matches!(
             obj.programs.get("foo"),
             Some(Program {
-                kind: ProgramKind::TracePoint,
+                section: ProgramSection::TracePoint { .. },
+                ..
+            })
+        );
+
+        assert_matches!(
+            obj.parse_section(fake_section("tp/foo/bar", bytes_of(&fake_ins()))),
+            Ok(())
+        );
+        assert_matches!(
+            obj.programs.get("foo/bar"),
+            Some(Program {
+                section: ProgramSection::TracePoint { .. },
                 ..
             })
         );
@@ -871,7 +926,7 @@ mod tests {
         assert_matches!(
             obj.programs.get("foo"),
             Some(Program {
-                kind: ProgramKind::SocketFilter,
+                section: ProgramSection::SocketFilter { .. },
                 ..
             })
         );
@@ -888,7 +943,7 @@ mod tests {
         assert_matches!(
             obj.programs.get("foo"),
             Some(Program {
-                kind: ProgramKind::Xdp,
+                section: ProgramSection::Xdp { .. },
                 ..
             })
         );