瀏覽代碼

Add support for IPv4 default gateway.

Egor Karavaev 7 年之前
父節點
當前提交
331dc10780
共有 8 個文件被更改,包括 95 次插入49 次删除
  1. 7 0
      README.md
  2. 6 5
      examples/client.rs
  3. 3 3
      examples/loopback.rs
  4. 10 8
      examples/ping.rs
  5. 3 3
      examples/server.rs
  6. 54 18
      src/iface/ethernet.rs
  7. 2 2
      src/socket/tcp.rs
  8. 10 10
      src/wire/ip.rs

+ 7 - 0
README.md

@@ -150,6 +150,13 @@ sudo ip link set tap0 up
 sudo ip addr add 192.168.69.100/24 dev tap0
 ```
 
+It's possible to let _smoltcp_ access Internet by enabling routing for the tap interface:
+
+```sh
+sudo iptables -t nat -A POSTROUTING -s 192.168.69.0/24 -j MASQUERADE
+sudo sysctl net.ipv4.ip_forward=1
+```
+
 ### Fault injection
 
 In order to demonstrate the response of _smoltcp_ to adverse network conditions, all examples

+ 6 - 5
examples/client.rs

@@ -10,7 +10,7 @@ use std::str::{self, FromStr};
 use std::time::Instant;
 use std::os::unix::io::AsRawFd;
 use smoltcp::phy::wait as phy_wait;
-use smoltcp::wire::{EthernetAddress, IpAddress};
+use smoltcp::wire::{EthernetAddress, Ipv4Address, IpAddress, IpCidr};
 use smoltcp::iface::{ArpCache, SliceArpCache, EthernetInterface};
 use smoltcp::socket::{AsSocket, SocketSet};
 use smoltcp::socket::{TcpSocket, TcpSocketBuffer};
@@ -40,17 +40,18 @@ fn main() {
     let tcp_socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
 
     let hardware_addr  = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x02]);
-    let protocol_addr  = IpAddress::v4(192, 168, 69, 2);
-    let mut iface      = EthernetInterface::new(
+    let protocol_addrs = [IpCidr::new(IpAddress::v4(192, 168, 69, 2), 24)];
+    let default_v4_gw  = Ipv4Address::new(192, 168, 69, 100);
+    let mut iface = EthernetInterface::new(
         Box::new(device), Box::new(arp_cache) as Box<ArpCache>,
-        hardware_addr, [protocol_addr]);
+        hardware_addr, protocol_addrs, Some(default_v4_gw));
 
     let mut sockets = SocketSet::new(vec![]);
     let tcp_handle = sockets.add(tcp_socket);
 
     {
         let socket: &mut TcpSocket = sockets.get_mut(tcp_handle).as_socket();
-        socket.connect((address, port), (protocol_addr, 49500)).unwrap();
+        socket.connect((address, port), 49500).unwrap();
     }
 
     let mut tcp_active = false;

+ 3 - 3
examples/loopback.rs

@@ -17,7 +17,7 @@ mod utils;
 
 use core::str;
 use smoltcp::phy::Loopback;
-use smoltcp::wire::{EthernetAddress, IpAddress};
+use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr};
 use smoltcp::iface::{ArpCache, SliceArpCache, EthernetInterface};
 use smoltcp::socket::{AsSocket, SocketSet};
 use smoltcp::socket::{TcpSocket, TcpSocketBuffer};
@@ -90,10 +90,10 @@ fn main() {
     let mut arp_cache = SliceArpCache::new(&mut arp_cache_entries[..]);
 
     let     hardware_addr  = EthernetAddress::default();
-    let mut protocol_addrs = [IpAddress::v4(127, 0, 0, 1)];
+    let mut protocol_addrs = [IpCidr::new(IpAddress::v4(127, 0, 0, 1), 24)];
     let mut iface = EthernetInterface::new(
         &mut device, &mut arp_cache as &mut ArpCache,
-        hardware_addr, &mut protocol_addrs[..]);
+        hardware_addr, &mut protocol_addrs[..], None);
 
     let server_socket = {
         // It is not strictly necessary to use a `static mut` and unsafe code here, but

+ 10 - 8
examples/ping.rs

@@ -12,7 +12,7 @@ use std::time::Instant;
 use std::os::unix::io::AsRawFd;
 use smoltcp::phy::Device;
 use smoltcp::phy::wait as phy_wait;
-use smoltcp::wire::{EthernetAddress, IpVersion, IpProtocol, IpAddress,
+use smoltcp::wire::{EthernetAddress, IpVersion, IpProtocol, IpAddress, IpCidr,
                     Ipv4Address, Ipv4Packet, Ipv4Repr,
                     Icmpv4Repr, Icmpv4Packet};
 use smoltcp::iface::{ArpCache, SliceArpCache, EthernetInterface};
@@ -39,6 +39,7 @@ fn main() {
     let device = utils::parse_tap_options(&mut matches);
     let fd = device.as_raw_fd();
     let device = utils::parse_middleware_options(&mut matches, device, /*loopback=*/false);
+    let device_caps = device.capabilities();
     let address  = Ipv4Address::from_str(&matches.free[0]).expect("invalid address format");
     let count    = matches.opt_str("count").map(|s| usize::from_str(&s).unwrap()).unwrap_or(4);
     let interval = matches.opt_str("interval").map(|s| u64::from_str(&s).unwrap()).unwrap_or(1);
@@ -56,11 +57,12 @@ fn main() {
     let raw_socket = RawSocket::new(IpVersion::Ipv4, IpProtocol::Icmp,
                                     raw_rx_buffer, raw_tx_buffer);
 
-    let hardware_addr  = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x02]);
-    let caps = device.capabilities();
+    let hardware_addr = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x02]);
+    let protocol_addr = IpCidr::new(IpAddress::from(local_addr), 24);
+    let default_v4_gw = Ipv4Address::new(192, 168, 69, 100);
     let mut iface = EthernetInterface::new(
         Box::new(device), Box::new(arp_cache) as Box<ArpCache>,
-        hardware_addr, [IpAddress::from(local_addr)]);
+        hardware_addr, [protocol_addr], Some(default_v4_gw));
 
     let mut sockets = SocketSet::new(vec![]);
     let raw_handle = sockets.add(raw_socket);
@@ -100,9 +102,9 @@ fn main() {
                     .unwrap();
 
                 let mut ipv4_packet = Ipv4Packet::new(raw_payload);
-                ipv4_repr.emit(&mut ipv4_packet, &caps.checksum);
+                ipv4_repr.emit(&mut ipv4_packet, &device_caps.checksum);
                 let mut icmp_packet = Icmpv4Packet::new(ipv4_packet.payload_mut());
-                icmp_repr.emit(&mut icmp_packet, &caps.checksum);
+                icmp_repr.emit(&mut icmp_packet, &device_caps.checksum);
 
                 waiting_queue.insert(seq_no, timestamp);
                 seq_no += 1;
@@ -112,11 +114,11 @@ fn main() {
             if socket.can_recv() {
                 let payload = socket.recv().unwrap();
                 let ipv4_packet = Ipv4Packet::new(payload);
-                let ipv4_repr = Ipv4Repr::parse(&ipv4_packet, &caps.checksum).unwrap();
+                let ipv4_repr = Ipv4Repr::parse(&ipv4_packet, &device_caps.checksum).unwrap();
 
                 if ipv4_repr.src_addr == remote_addr && ipv4_repr.dst_addr == local_addr {
                     let icmp_packet = Icmpv4Packet::new(ipv4_packet.payload());
-                    let icmp_repr = Icmpv4Repr::parse(&icmp_packet, &caps.checksum);
+                    let icmp_repr = Icmpv4Repr::parse(&icmp_packet, &device_caps.checksum);
 
                     if let Ok(Icmpv4Repr::EchoReply { seq_no, data, .. }) = icmp_repr {
                         if let Some(_) = waiting_queue.get(&seq_no) {

+ 3 - 3
examples/server.rs

@@ -11,7 +11,7 @@ use std::fmt::Write;
 use std::time::Instant;
 use std::os::unix::io::AsRawFd;
 use smoltcp::phy::wait as phy_wait;
-use smoltcp::wire::{EthernetAddress, IpAddress};
+use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr};
 use smoltcp::iface::{ArpCache, SliceArpCache, EthernetInterface};
 use smoltcp::socket::{AsSocket, SocketSet};
 use smoltcp::socket::{UdpSocket, UdpSocketBuffer, UdpPacketBuffer};
@@ -54,10 +54,10 @@ fn main() {
     let tcp4_socket = TcpSocket::new(tcp4_rx_buffer, tcp4_tx_buffer);
 
     let hardware_addr  = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]);
-    let protocol_addrs = [IpAddress::v4(192, 168, 69, 1)];
+    let protocol_addrs = [IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)];
     let mut iface      = EthernetInterface::new(
         Box::new(device), Box::new(arp_cache) as Box<ArpCache>,
-        hardware_addr, protocol_addrs);
+        hardware_addr, protocol_addrs, None);
 
     let mut sockets = SocketSet::new(vec![]);
     let udp_handle  = sockets.add(udp_socket);

+ 54 - 18
src/iface/ethernet.rs

@@ -6,10 +6,11 @@ use managed::{Managed, ManagedSlice};
 use {Error, Result};
 use phy::Device;
 use wire::{EthernetAddress, EthernetProtocol, EthernetFrame};
+use wire::{Ipv4Address};
+use wire::{IpAddress, IpProtocol, IpRepr, IpCidr};
 use wire::{ArpPacket, ArpRepr, ArpOperation};
 use wire::{Ipv4Packet, Ipv4Repr};
 use wire::{Icmpv4Packet, Icmpv4Repr, Icmpv4DstUnreachable};
-use wire::{IpAddress, IpProtocol, IpRepr};
 #[cfg(feature = "socket-udp")] use wire::{UdpPacket, UdpRepr};
 #[cfg(feature = "socket-tcp")] use wire::{TcpPacket, TcpRepr, TcpControl};
 use socket::{Socket, SocketSet, AsSocket};
@@ -27,7 +28,8 @@ pub struct Interface<'a, 'b, 'c, DeviceT: Device + 'a> {
     device:         Managed<'a, DeviceT>,
     arp_cache:      Managed<'b, ArpCache>,
     hardware_addr:  EthernetAddress,
-    protocol_addrs: ManagedSlice<'c, IpAddress>,
+    protocol_addrs: ManagedSlice<'c, IpCidr>,
+    ipv4_gateway:   Option<Ipv4Address>,
 }
 
 enum Packet<'a> {
@@ -48,24 +50,29 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
     /// # Panics
     /// See the restrictions on [set_hardware_addr](#method.set_hardware_addr)
     /// and [set_protocol_addrs](#method.set_protocol_addrs) functions.
-    pub fn new<DeviceMT, ArpCacheMT, ProtocolAddrsMT>
+    pub fn new<DeviceMT, ArpCacheMT, ProtocolAddrsMT, Ipv4GatewayAddrT>
               (device: DeviceMT, arp_cache: ArpCacheMT,
-               hardware_addr: EthernetAddress, protocol_addrs: ProtocolAddrsMT) ->
+               hardware_addr: EthernetAddress,
+               protocol_addrs: ProtocolAddrsMT,
+               ipv4_gateway: Ipv4GatewayAddrT) ->
               Interface<'a, 'b, 'c, DeviceT>
             where DeviceMT: Into<Managed<'a, DeviceT>>,
                   ArpCacheMT: Into<Managed<'b, ArpCache>>,
-                  ProtocolAddrsMT: Into<ManagedSlice<'c, IpAddress>>, {
+                  ProtocolAddrsMT: Into<ManagedSlice<'c, IpCidr>>,
+                  Ipv4GatewayAddrT: Into<Option<Ipv4Address>>, {
         let device = device.into();
         let arp_cache = arp_cache.into();
         let protocol_addrs = protocol_addrs.into();
+        let ipv4_gateway = ipv4_gateway.into();
 
         Self::check_hardware_addr(&hardware_addr);
         Self::check_protocol_addrs(&protocol_addrs);
         Interface {
-            device:         device,
-            arp_cache:      arp_cache,
-            hardware_addr:  hardware_addr,
-            protocol_addrs: protocol_addrs,
+            device,
+            arp_cache,
+            hardware_addr,
+            protocol_addrs,
+            ipv4_gateway,
         }
     }
 
@@ -89,16 +96,16 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
         Self::check_hardware_addr(&self.hardware_addr);
     }
 
-    fn check_protocol_addrs(addrs: &[IpAddress]) {
-        for addr in addrs {
-            if !addr.is_unicast() {
-                panic!("protocol address {} is not unicast", addr)
+    fn check_protocol_addrs(addrs: &[IpCidr]) {
+        for cidr in addrs {
+            if !cidr.address().is_unicast() {
+                panic!("protocol address {} is not unicast", cidr.address())
             }
         }
     }
 
     /// Get the protocol addresses of the interface.
-    pub fn protocol_addrs(&self) -> &[IpAddress] {
+    pub fn protocol_addrs(&self) -> &[IpCidr] {
         self.protocol_addrs.as_ref()
     }
 
@@ -106,7 +113,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
     ///
     /// # Panics
     /// This function panics if any of the addresses is not unicast.
-    pub fn update_protocol_addrs<F: FnOnce(&mut ManagedSlice<'c, IpAddress>)>(&mut self, f: F) {
+    pub fn update_protocol_addrs<F: FnOnce(&mut ManagedSlice<'c, IpCidr>)>(&mut self, f: F) {
         f(&mut self.protocol_addrs);
         Self::check_protocol_addrs(&self.protocol_addrs)
     }
@@ -114,7 +121,18 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
     /// Check whether the interface has the given protocol address assigned.
     pub fn has_protocol_addr<T: Into<IpAddress>>(&self, addr: T) -> bool {
         let addr = addr.into();
-        self.protocol_addrs.iter().any(|&probe| probe == addr)
+        self.protocol_addrs.iter().any(|probe| probe.address() == addr)
+    }
+
+    /// Get the IPv4 gateway of the interface.
+    pub fn ipv4_gateway(&self) -> Option<Ipv4Address> {
+        self.ipv4_gateway
+    }
+
+    /// Set the IPv4 gateway of the interface.
+    pub fn set_ipv4_gateway<GatewayAddrT>(&mut self, gateway: GatewayAddrT)
+            where GatewayAddrT: Into<Option<Ipv4Address>> {
+        self.ipv4_gateway = gateway.into();
     }
 
     /// Transmit packets queued in the given sockets, and receive packets queued
@@ -541,10 +559,28 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
         Ok(())
     }
 
+    fn route_address(&self, addr: &IpAddress) -> Result<IpAddress> {
+        self.protocol_addrs
+            .iter()
+            .find(|cidr| cidr.contains_addr(&addr))
+            .map(|_cidr| Ok(addr.clone())) // route directly
+            .unwrap_or_else(|| {
+                match (addr, self.ipv4_gateway) {
+                    // route via a gateway
+                    (&IpAddress::Ipv4(_), Some(gateway)) =>
+                        Ok(gateway.into()),
+                    // unroutable
+                    _ => Err(Error::Unaddressable)
+                }
+            })
+    }
+
     fn lookup_hardware_addr(&mut self, timestamp: u64,
                             src_addr: &IpAddress, dst_addr: &IpAddress) ->
                            Result<EthernetAddress> {
-        if let Some(hardware_addr) = self.arp_cache.lookup(dst_addr) {
+        let dst_addr = self.route_address(dst_addr)?;
+
+        if let Some(hardware_addr) = self.arp_cache.lookup(&dst_addr) {
             return Ok(hardware_addr)
         }
 
@@ -553,7 +589,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
         }
 
         match (src_addr, dst_addr) {
-            (&IpAddress::Ipv4(src_addr), &IpAddress::Ipv4(dst_addr)) => {
+            (&IpAddress::Ipv4(src_addr), IpAddress::Ipv4(dst_addr)) => {
                 net_debug!("address {} not in ARP cache, sending request",
                            dst_addr);
 

+ 2 - 2
src/socket/tcp.rs

@@ -1450,7 +1450,7 @@ impl<'a> fmt::Write for TcpSocket<'a> {
 
 #[cfg(test)]
 mod test {
-    use wire::{IpAddress, Ipv4Address};
+    use wire::{IpAddress, Ipv4Address, IpCidr};
     use super::*;
 
     #[test]
@@ -1529,7 +1529,7 @@ mod test {
         let mut caps = DeviceCapabilities::default();
         caps.max_transmission_unit = 1520;
         let result = socket.dispatch(timestamp, &caps, |(ip_repr, tcp_repr)| {
-            let ip_repr = ip_repr.lower(&[LOCAL_END.addr.into()]).unwrap();
+            let ip_repr = ip_repr.lower(&[IpCidr::new(LOCAL_END.addr, 24)]).unwrap();
 
             assert_eq!(ip_repr.protocol(), IpProtocol::Tcp);
             assert_eq!(ip_repr.src_addr(), LOCAL_IP);

+ 10 - 10
src/wire/ip.rs

@@ -342,7 +342,7 @@ impl IpRepr {
     /// # Panics
     /// This function panics if source and destination addresses belong to different families,
     /// or the destination address is unspecified, since this indicates a logic error.
-    pub fn lower(&self, fallback_src_addrs: &[Address]) -> Result<IpRepr> {
+    pub fn lower(&self, fallback_src_addrs: &[Cidr]) -> Result<IpRepr> {
         match self {
             &IpRepr::Unspecified {
                 src_addr: Address::Ipv4(src_addr),
@@ -363,9 +363,9 @@ impl IpRepr {
                 protocol, payload_len
             } => {
                 let mut src_addr = None;
-                for addr in fallback_src_addrs {
-                    match addr {
-                        &Address::Ipv4(addr) => {
+                for cidr in fallback_src_addrs {
+                    match cidr.address() {
+                        Address::Ipv4(addr) => {
                             src_addr = Some(addr);
                             break
                         }
@@ -388,9 +388,9 @@ impl IpRepr {
 
             &IpRepr::Ipv4(mut repr) => {
                 if repr.src_addr.is_unspecified() {
-                    for addr in fallback_src_addrs {
-                        match addr {
-                            &Address::Ipv4(addr) => {
+                    for cidr in fallback_src_addrs {
+                        match cidr.address() {
+                            Address::Ipv4(addr) => {
                                 repr.src_addr = addr;
                                 return Ok(IpRepr::Ipv4(repr));
                             }
@@ -521,7 +521,7 @@ pub mod checksum {
 #[cfg(test)]
 mod test {
     use super::*;
-    use wire::{Ipv4Address, IpProtocol, IpAddress, Ipv4Repr};
+    use wire::{Ipv4Address, IpProtocol, IpAddress, Ipv4Repr, IpCidr};
     #[test]
     fn ip_repr_lower() {
         let ip_addr_a = Ipv4Address::new(1, 2, 3, 4);
@@ -560,7 +560,7 @@ mod test {
                 dst_addr: IpAddress::Ipv4(ip_addr_b),
                 protocol: proto,
                 payload_len
-            }.lower(&[IpAddress::Ipv4(ip_addr_a)]),
+            }.lower(&[IpCidr::new(IpAddress::Ipv4(ip_addr_a), 24)]),
             Ok(IpRepr::Ipv4(Ipv4Repr{
                 src_addr: ip_addr_a,
                 dst_addr: ip_addr_b,
@@ -600,7 +600,7 @@ mod test {
                 dst_addr: ip_addr_b,
                 protocol: proto,
                 payload_len
-            }).lower(&[IpAddress::Ipv4(ip_addr_a)]),
+            }).lower(&[IpCidr::new(IpAddress::Ipv4(ip_addr_a), 24)]),
             Ok(IpRepr::Ipv4(Ipv4Repr{
                 src_addr: ip_addr_a,
                 dst_addr: ip_addr_b,