فهرست منبع

tcp: add netsim test.

This test simulates a network with a given latency and packet loss, and measures the
throughput between two virtual smoltcp instances.
Dario Nieuwenhuis 5 ماه پیش
والد
کامیت
512aecba51
5فایلهای تغییر یافته به همراه404 افزوده شده و 1 حذف شده
  1. 9 1
      .github/workflows/test.yml
  2. 8 0
      Cargo.toml
  3. 8 0
      ci.sh
  4. 364 0
      tests/netsim.rs
  5. 15 0
      tests/snapshots/netsim__netsim.snap

+ 9 - 1
.github/workflows/test.yml

@@ -7,7 +7,7 @@ name: Test
 jobs:
   tests:
     runs-on: ubuntu-22.04
-    needs: [check-msrv, test-msrv, test-stable, clippy]
+    needs: [check-msrv, test-msrv, test-stable, clippy, test-netsim]
     steps:
       - name: Done
         run: exit 0
@@ -48,6 +48,14 @@ jobs:
       - name: Run Tests nightly
         run: ./ci.sh test nightly
 
+  test-netsim:
+    runs-on: ubuntu-22.04
+    continue-on-error: true
+    steps:
+      - uses: actions/checkout@v4
+      - name: Run network-simulation tests
+        run: ./ci.sh netsim
+
   test-build-16bit:
     runs-on: ubuntu-22.04
     continue-on-error: true

+ 8 - 0
Cargo.toml

@@ -35,6 +35,8 @@ getopts = "0.2"
 rand = "0.8"
 url = "2.0"
 rstest = "0.17"
+insta = "1.41.1"
+rand_chacha = "0.3.1"
 
 [features]
 std = ["managed/std", "alloc"]
@@ -109,6 +111,8 @@ default = [
 
 "_proto-fragmentation" = []
 
+"_netsim" = []
+
 # BEGIN AUTOGENERATED CONFIG FEATURES
 # Generated by gen_config.py. DO NOT EDIT.
 iface-max-addr-count-1 = []
@@ -267,6 +271,10 @@ rpl-parents-buffer-count-32 = []
 
 # END AUTOGENERATED CONFIG FEATURES
 
+[[test]]
+name = "netsim"
+required-features = ["_netsim"]
+
 [[example]]
 name = "packet2pcap"
 path = "utils/packet2pcap.rs"

+ 8 - 0
ci.sh

@@ -60,6 +60,10 @@ test() {
     fi
 }
 
+netsim() {
+    cargo test --release --features _netsim netsim
+}
+
 check() {
     local version=$1
     rustup toolchain install $version
@@ -139,3 +143,7 @@ fi
 if [[ $1 == "coverage" || $1 == "all" ]]; then
     coverage
 fi
+
+if [[ $1 == "netsim" || $1 == "all" ]]; then
+    netsim
+fi

+ 364 - 0
tests/netsim.rs

@@ -0,0 +1,364 @@
+use std::cell::RefCell;
+use std::collections::BinaryHeap;
+use std::fmt::Write as _;
+use std::io::Write as _;
+use std::sync::Mutex;
+
+use rand::{Rng, SeedableRng};
+use rand_chacha::ChaCha20Rng;
+use smoltcp::iface::{Config, Interface, SocketSet};
+use smoltcp::phy::Tracer;
+use smoltcp::phy::{self, ChecksumCapabilities, Device, DeviceCapabilities, Medium};
+use smoltcp::socket::tcp;
+use smoltcp::time::{Duration, Instant};
+use smoltcp::wire::{EthernetAddress, HardwareAddress, IpAddress, IpCidr};
+
+const MAC_A: HardwareAddress = HardwareAddress::Ethernet(EthernetAddress([2, 0, 0, 0, 0, 1]));
+const MAC_B: HardwareAddress = HardwareAddress::Ethernet(EthernetAddress([2, 0, 0, 0, 0, 2]));
+const IP_A: IpAddress = IpAddress::v4(10, 0, 0, 1);
+const IP_B: IpAddress = IpAddress::v4(10, 0, 0, 2);
+
+const BYTES: usize = 10 * 1024 * 1024;
+
+static CLOCK: Mutex<(Instant, char)> = Mutex::new((Instant::ZERO, ' '));
+
+#[test]
+fn netsim() {
+    setup_logging();
+
+    let buffers = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768];
+    let losses = [0.0, 0.001, 0.01, 0.02, 0.05, 0.10, 0.20, 0.30];
+
+    let mut s = String::new();
+
+    write!(&mut s, "buf\\loss").unwrap();
+    for loss in losses {
+        write!(&mut s, "{loss:9.3} ").unwrap();
+    }
+    writeln!(&mut s).unwrap();
+
+    for buffer in buffers {
+        write!(&mut s, "{buffer:7}").unwrap();
+        for loss in losses {
+            let r = run_test(TestCase {
+                rtt: Duration::from_millis(100),
+                buffer,
+                loss,
+            });
+            write!(&mut s, " {r:9.2}").unwrap();
+        }
+        writeln!(&mut s).unwrap();
+    }
+
+    insta::assert_snapshot!(s);
+}
+
+struct TestCase {
+    rtt: Duration,
+    loss: f64,
+    buffer: usize,
+}
+
+fn run_test(case: TestCase) -> f64 {
+    let mut time = Instant::ZERO;
+
+    let params = QueueParams {
+        latency: case.rtt / 2,
+        loss: case.loss,
+    };
+    let queue_a_to_b = RefCell::new(PacketQueue::new(params.clone(), 0));
+    let queue_b_to_a = RefCell::new(PacketQueue::new(params.clone(), 1));
+    let device_a = QueueDevice::new(&queue_a_to_b, &queue_b_to_a, Medium::Ethernet);
+    let device_b = QueueDevice::new(&queue_b_to_a, &queue_a_to_b, Medium::Ethernet);
+
+    let mut device_a = Tracer::new(device_a, |_timestamp, _printer| log::trace!("{}", _printer));
+    let mut device_b = Tracer::new(device_b, |_timestamp, _printer| log::trace!("{}", _printer));
+
+    let mut iface_a = Interface::new(Config::new(MAC_A), &mut device_a, time);
+    iface_a.update_ip_addrs(|a| a.push(IpCidr::new(IP_A, 8)).unwrap());
+    let mut iface_b = Interface::new(Config::new(MAC_B), &mut device_b, time);
+    iface_b.update_ip_addrs(|a| a.push(IpCidr::new(IP_B, 8)).unwrap());
+
+    // Create sockets
+    let socket_a = {
+        let tcp_rx_buffer = tcp::SocketBuffer::new(vec![0; case.buffer]);
+        let tcp_tx_buffer = tcp::SocketBuffer::new(vec![0; case.buffer]);
+        tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer)
+    };
+
+    let socket_b = {
+        let tcp_rx_buffer = tcp::SocketBuffer::new(vec![0; case.buffer]);
+        let tcp_tx_buffer = tcp::SocketBuffer::new(vec![0; case.buffer]);
+        tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer)
+    };
+
+    let mut sockets_a: [_; 2] = Default::default();
+    let mut sockets_a = SocketSet::new(&mut sockets_a[..]);
+    let socket_a_handle = sockets_a.add(socket_a);
+
+    let mut sockets_b: [_; 2] = Default::default();
+    let mut sockets_b = SocketSet::new(&mut sockets_b[..]);
+    let socket_b_handle = sockets_b.add(socket_b);
+
+    let mut did_listen = false;
+    let mut did_connect = false;
+    let mut processed = 0;
+    while processed < BYTES {
+        *CLOCK.lock().unwrap() = (time, ' ');
+        log::info!("loop");
+        //println!("t = {}", time);
+
+        *CLOCK.lock().unwrap() = (time, 'A');
+
+        iface_a.poll(time, &mut device_a, &mut sockets_a);
+
+        let socket = sockets_a.get_mut::<tcp::Socket>(socket_a_handle);
+        if !socket.is_active() && !socket.is_listening() && !did_listen {
+            //println!("listening");
+            socket.listen(1234).unwrap();
+            did_listen = true;
+        }
+
+        while socket.can_recv() {
+            let received = socket.recv(|buffer| (buffer.len(), buffer.len())).unwrap();
+            //println!("got {:?}", received,);
+            processed += received;
+        }
+
+        *CLOCK.lock().unwrap() = (time, 'B');
+        iface_b.poll(time, &mut device_b, &mut sockets_b);
+        let socket = sockets_b.get_mut::<tcp::Socket>(socket_b_handle);
+        let cx = iface_b.context();
+        if !socket.is_open() && !did_connect {
+            //println!("connecting");
+            socket.connect(cx, (IP_A, 1234), 65000).unwrap();
+            did_connect = true;
+        }
+
+        while socket.can_send() {
+            //println!("sending");
+            socket.send(|buffer| (buffer.len(), ())).unwrap();
+        }
+
+        *CLOCK.lock().unwrap() = (time, ' ');
+
+        let mut next_time = queue_a_to_b.borrow_mut().next_expiration();
+        next_time = next_time.min(queue_b_to_a.borrow_mut().next_expiration());
+        if let Some(t) = iface_a.poll_at(time, &sockets_a) {
+            next_time = next_time.min(t);
+        }
+        if let Some(t) = iface_b.poll_at(time, &sockets_b) {
+            next_time = next_time.min(t);
+        }
+        assert!(next_time.total_micros() != i64::MAX);
+        time = time.max(next_time);
+    }
+
+    let duration = time - Instant::ZERO;
+    processed as f64 / duration.total_micros() as f64 * 1e6
+}
+
+struct Packet {
+    timestamp: Instant,
+    id: u64,
+    data: Vec<u8>,
+}
+
+impl PartialEq for Packet {
+    fn eq(&self, other: &Self) -> bool {
+        (other.timestamp, other.id) == (self.timestamp, self.id)
+    }
+}
+
+impl Eq for Packet {}
+
+impl PartialOrd for Packet {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl Ord for Packet {
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        (other.timestamp, other.id).cmp(&(self.timestamp, self.id))
+    }
+}
+
+#[derive(Clone)]
+struct QueueParams {
+    latency: Duration,
+    loss: f64,
+}
+
+struct PacketQueue {
+    queue: BinaryHeap<Packet>,
+    next_id: u64,
+    params: QueueParams,
+    rng: ChaCha20Rng,
+}
+
+impl PacketQueue {
+    pub fn new(params: QueueParams, seed: u64) -> Self {
+        Self {
+            queue: BinaryHeap::new(),
+            next_id: 0,
+            params,
+            rng: ChaCha20Rng::seed_from_u64(seed),
+        }
+    }
+
+    pub fn next_expiration(&self) -> Instant {
+        self.queue
+            .peek()
+            .map(|p| p.timestamp)
+            .unwrap_or(Instant::from_micros(i64::MAX))
+    }
+
+    pub fn push(&mut self, data: Vec<u8>, timestamp: Instant) {
+        if self.rng.gen::<f64>() < self.params.loss {
+            log::info!("PACKET LOST!");
+            return;
+        }
+
+        self.queue.push(Packet {
+            data,
+            id: self.next_id,
+            timestamp: timestamp + self.params.latency,
+        });
+        self.next_id += 1;
+    }
+
+    pub fn pop(&mut self, timestamp: Instant) -> Option<Vec<u8>> {
+        let p = self.queue.peek()?;
+        if p.timestamp > timestamp {
+            return None;
+        }
+        Some(self.queue.pop().unwrap().data)
+    }
+}
+
+pub struct QueueDevice<'a> {
+    tx_queue: &'a RefCell<PacketQueue>,
+    rx_queue: &'a RefCell<PacketQueue>,
+    medium: Medium,
+}
+
+impl<'a> QueueDevice<'a> {
+    fn new(
+        tx_queue: &'a RefCell<PacketQueue>,
+        rx_queue: &'a RefCell<PacketQueue>,
+        medium: Medium,
+    ) -> Self {
+        Self {
+            tx_queue,
+            rx_queue,
+            medium,
+        }
+    }
+}
+
+impl Device for QueueDevice<'_> {
+    type RxToken<'a>
+        = RxToken
+    where
+        Self: 'a;
+    type TxToken<'a>
+        = TxToken<'a>
+    where
+        Self: 'a;
+
+    fn capabilities(&self) -> DeviceCapabilities {
+        let mut caps = DeviceCapabilities::default();
+        caps.max_transmission_unit = 1514;
+        caps.medium = self.medium;
+        caps.checksum = ChecksumCapabilities::ignored();
+        caps
+    }
+
+    fn receive(&mut self, timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
+        self.rx_queue
+            .borrow_mut()
+            .pop(timestamp)
+            .map(move |buffer| {
+                let rx = RxToken { buffer };
+                let tx = TxToken {
+                    queue: self.tx_queue,
+                    timestamp,
+                };
+                (rx, tx)
+            })
+    }
+
+    fn transmit(&mut self, timestamp: Instant) -> Option<Self::TxToken<'_>> {
+        Some(TxToken {
+            queue: self.tx_queue,
+            timestamp,
+        })
+    }
+}
+
+pub struct RxToken {
+    buffer: Vec<u8>,
+}
+
+impl phy::RxToken for RxToken {
+    fn consume<R, F>(self, f: F) -> R
+    where
+        F: FnOnce(&[u8]) -> R,
+    {
+        f(&self.buffer)
+    }
+}
+
+pub struct TxToken<'a> {
+    queue: &'a RefCell<PacketQueue>,
+    timestamp: Instant,
+}
+
+impl<'a> phy::TxToken for TxToken<'a> {
+    fn consume<R, F>(self, len: usize, f: F) -> R
+    where
+        F: FnOnce(&mut [u8]) -> R,
+    {
+        let mut buffer = vec![0; len];
+        let result = f(&mut buffer);
+        self.queue.borrow_mut().push(buffer, self.timestamp);
+        result
+    }
+}
+
+pub fn setup_logging() {
+    env_logger::Builder::new()
+        .format(move |buf, record| {
+            let (elapsed, side) = *CLOCK.lock().unwrap();
+
+            let timestamp = format!("[{elapsed} {side}]");
+            if record.target().starts_with("smoltcp::") {
+                writeln!(
+                    buf,
+                    "{} ({}): {}",
+                    timestamp,
+                    record.target().replace("smoltcp::", ""),
+                    record.args()
+                )
+            } else if record.level() == log::Level::Trace {
+                let message = format!("{}", record.args());
+                writeln!(
+                    buf,
+                    "{} {}",
+                    timestamp,
+                    message.replace('\n', "\n             ")
+                )
+            } else {
+                writeln!(
+                    buf,
+                    "{} ({}): {}",
+                    timestamp,
+                    record.target(),
+                    record.args()
+                )
+            }
+        })
+        .parse_env("RUST_LOG")
+        .init();
+}

+ 15 - 0
tests/snapshots/netsim__netsim.snap

@@ -0,0 +1,15 @@
+---
+source: tests/netsim.rs
+expression: s
+snapshot_kind: text
+---
+buf\loss    0.000     0.001     0.010     0.020     0.050     0.100     0.200     0.300 
+    128   1279.98   1255.76   1054.15    886.36    538.66    227.84     33.99      7.18
+    256   2559.91   2507.27   2100.03   1770.30   1070.71    468.24     66.71     14.35
+    512   5119.63   5011.95   4172.36   3531.57   2098.73    942.38    144.73     29.45
+   1024  10238.50  10023.19   8340.90   7084.25   4003.34   1869.94    290.74     60.92
+   2048  17535.11  17171.82  14093.50  12063.90   7205.27   3379.12    824.76    131.54
+   4096  35062.41  33852.31  27011.08  22073.09  13680.70   7631.11   1617.81    302.65
+   8192  77374.28  72409.99  58428.68  48310.75  29123.30  14314.36   2880.39    551.60
+  16384 161842.28 159448.56 141467.31 127073.06  78239.08  38637.20   7565.64   1112.31
+  32768 322944.88 314313.90 266384.37 245985.29 138762.29  83162.99  10739.10   1951.95