4
0

utils.rs 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. //! Utilities to run tests
  2. use std::{
  3. ffi::CString,
  4. io, process,
  5. sync::atomic::{AtomicU64, Ordering},
  6. };
  7. use aya::netlink_set_link_up;
  8. use libc::if_nametoindex;
  9. use netns_rs::{get_from_current_thread, NetNs};
  10. pub struct NetNsGuard {
  11. name: String,
  12. old_ns: NetNs,
  13. ns: Option<NetNs>,
  14. }
  15. impl NetNsGuard {
  16. pub fn new() -> Self {
  17. let old_ns = get_from_current_thread().expect("Failed to get current netns");
  18. static COUNTER: AtomicU64 = AtomicU64::new(0);
  19. let pid = process::id();
  20. let name = format!("aya-test-{pid}-{}", COUNTER.fetch_add(1, Ordering::Relaxed));
  21. // Create and enter netns
  22. let ns = NetNs::new(&name).unwrap_or_else(|e| panic!("Failed to create netns {name}: {e}"));
  23. let netns = Self {
  24. old_ns,
  25. ns: Some(ns),
  26. name,
  27. };
  28. let ns = netns.ns.as_ref().unwrap();
  29. ns.enter()
  30. .unwrap_or_else(|e| panic!("Failed to enter network namespace {}: {e}", netns.name));
  31. println!("Entered network namespace {}", netns.name);
  32. // By default, the loopback in a new netns is down. Set it up.
  33. let lo = CString::new("lo").unwrap();
  34. unsafe {
  35. let idx = if_nametoindex(lo.as_ptr());
  36. if idx == 0 {
  37. panic!(
  38. "Interface `lo` not found in netns {}: {}",
  39. netns.name,
  40. io::Error::last_os_error()
  41. );
  42. }
  43. netlink_set_link_up(idx as i32)
  44. .unwrap_or_else(|e| panic!("Failed to set `lo` up in netns {}: {e}", netns.name));
  45. }
  46. netns
  47. }
  48. }
  49. impl Drop for NetNsGuard {
  50. fn drop(&mut self) {
  51. // Avoid panic in panic
  52. if let Err(e) = self.old_ns.enter() {
  53. eprintln!("Failed to return to original netns: {e}");
  54. }
  55. if let Some(ns) = self.ns.take() {
  56. if let Err(e) = ns.remove() {
  57. eprintln!("Failed to remove netns {}: {e}", self.name);
  58. }
  59. }
  60. println!("Exited network namespace {}", self.name);
  61. }
  62. }