cubic.rs 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. use crate::time::Instant;
  2. use super::Controller;
  // Constants for the Cubic congestion control algorithm.
  // See RFC 8312.

  /// Multiplicative decrease factor (`beta_cubic` in RFC 8312). In this file it
  /// is used in the computation of `K`, the time the cubic curve takes to climb
  /// back to `w_max`.
  const BETA_CUBIC: f64 = 0.7;
  /// Scaling constant of the cubic window-growth function (`C` in RFC 8312).
  const C: f64 = 0.4;
  /// State of the CUBIC congestion controller (RFC 8312).
  ///
  /// All window fields share the units of the `len` passed to `on_ack` and the
  /// `mss` passed to `set_mss` (presumably bytes — confirm against callers).
  #[derive(Debug)]
  #[cfg_attr(feature = "defmt", derive(defmt::Format))]
  pub struct Cubic {
      /// Congestion window; the value reported by `window()`.
      cwnd: usize,
      /// The minimum size of the congestion window (set to the MSS by `set_mss`).
      min_cwnd: usize,
      /// Window size just before the most recent congestion event
      /// (retransmission or duplicate ACK).
      w_max: usize,
      /// Start of the current cubic growth curve: set by a congestion event, or
      /// by the first `pre_transmit` call if none has happened yet.
      recovery_start: Option<Instant>,
      /// Remote window; never decreases (see `set_remote_window`) and bounds
      /// `cwnd` from above.
      rwnd: usize,
      /// When `pre_transmit` last recomputed `cwnd`; used to limit those
      /// recomputations to one per 100 ms.
      last_update: Instant,
      /// Slow-start threshold: `on_ack` only grows `cwnd` while it is below this
      /// value. `usize::MAX` until the first congestion event.
      ssthresh: usize,
  }
  18. impl Cubic {
  19. pub fn new() -> Cubic {
  20. Cubic {
  21. cwnd: 1024 * 2,
  22. min_cwnd: 1024 * 2,
  23. w_max: 1024 * 2,
  24. recovery_start: None,
  25. rwnd: 64 * 1024,
  26. last_update: Instant::from_millis(0),
  27. ssthresh: usize::MAX,
  28. }
  29. }
  30. }
  31. impl Controller for Cubic {
  32. fn window(&self) -> usize {
  33. self.cwnd
  34. }
  35. fn on_retransmit(&mut self, now: Instant) {
  36. self.w_max = self.cwnd;
  37. self.ssthresh = self.cwnd >> 1;
  38. self.recovery_start = Some(now);
  39. }
  40. fn on_duplicate_ack(&mut self, now: Instant) {
  41. self.w_max = self.cwnd;
  42. self.ssthresh = self.cwnd >> 1;
  43. self.recovery_start = Some(now);
  44. }
  45. fn set_remote_window(&mut self, remote_window: usize) {
  46. if self.rwnd < remote_window {
  47. self.rwnd = remote_window;
  48. }
  49. }
  50. fn on_ack(&mut self, _now: Instant, len: usize, _rtt: &crate::socket::tcp::RttEstimator) {
  51. // Slow start.
  52. if self.cwnd < self.ssthresh {
  53. self.cwnd = self
  54. .cwnd
  55. .saturating_add(len)
  56. .min(self.rwnd)
  57. .max(self.min_cwnd);
  58. }
  59. }
  60. fn pre_transmit(&mut self, now: Instant) {
  61. let Some(recovery_start) = self.recovery_start else {
  62. self.recovery_start = Some(now);
  63. return;
  64. };
  65. let now_millis = now.total_millis();
  66. // If the last update was less than 100ms ago, don't update the congestion window.
  67. if self.last_update > recovery_start && now_millis - self.last_update.total_millis() < 100 {
  68. return;
  69. }
  70. // Elapsed time since the start of the recovery phase.
  71. let t = now_millis - recovery_start.total_millis();
  72. if t < 0 {
  73. return;
  74. }
  75. // K = (w_max * (1 - beta) / C)^(1/3)
  76. let k3 = ((self.w_max as f64) * (1.0 - BETA_CUBIC)) / C;
  77. let k = if let Some(k) = cube_root(k3) {
  78. k
  79. } else {
  80. return;
  81. };
  82. // cwnd = C(T - K)^3 + w_max
  83. let s = t as f64 / 1000.0 - k;
  84. let s = s * s * s;
  85. let cwnd = C * s + self.w_max as f64;
  86. self.last_update = now;
  87. self.cwnd = (cwnd as usize).max(self.min_cwnd).min(self.rwnd);
  88. }
  89. fn set_mss(&mut self, mss: usize) {
  90. self.min_cwnd = mss;
  91. }
  92. }
  /// Absolute value of `a`, written out by hand.
  ///
  /// Only strictly negative inputs are negated, so `-0.0` and NaN come back
  /// unchanged.
  #[inline]
  fn abs(a: f64) -> f64 {
      let negative = a < 0.0;
      if negative {
          -a
      } else {
          a
      }
  }
  101. /// Calculate cube root by using the Newton-Raphson method.
  102. fn cube_root(a: f64) -> Option<f64> {
  103. if a <= 0.0 {
  104. return None;
  105. }
  106. let (tolerance, init) = if a < 1_000.0 {
  107. (1.0, 8.879040017426005) // cube_root(700.0)
  108. } else if a < 1_000_000.0 {
  109. (5.0, 88.79040017426004) // cube_root(700_000.0)
  110. } else if a < 1_000_000_000.0 {
  111. (50.0, 887.9040017426004) // cube_root(700_000_000.0)
  112. } else if a < 1_000_000_000_000.0 {
  113. (500.0, 8879.040017426003) // cube_root(700_000_000_000.0)
  114. } else if a < 1_000_000_000_000_000.0 {
  115. (5000.0, 88790.40017426001) // cube_root(700_000_000_000.0)
  116. } else {
  117. (50000.0, 887904.0017426) // cube_root(700_000_000_000_000.0)
  118. };
  119. let mut x = init; // initial value
  120. let mut n = 20; // The maximum iteration
  121. loop {
  122. let next_x = (2.0 * x + a / (x * x)) / 3.0;
  123. if abs(next_x - x) < tolerance {
  124. return Some(next_x);
  125. }
  126. x = next_x;
  127. if n == 0 {
  128. return Some(next_x);
  129. }
  130. n -= 1;
  131. }
  132. }
  #[cfg(test)]
  mod test {
      use crate::{socket::tcp::RttEstimator, time::Instant};

      use super::*;

      /// A congestion event followed by `pre_transmit` at a wide range of
      /// elapsed times must keep the window within `[min_cwnd, remote_window]`.
      #[test]
      fn test_cubic() {
          let remote_window = 64 * 1024 * 1024;
          let now = Instant::from_millis(0);

          for i in 0..10 {
              for j in 0..9 {
                  let mut cubic = Cubic::new();
                  // Set remote window.
                  cubic.set_remote_window(remote_window);

                  cubic.set_mss(1480);

                  if i & 1 == 0 {
                      cubic.on_retransmit(now);
                  } else {
                      cubic.on_duplicate_ack(now);
                  }

                  cubic.pre_transmit(now);

                  let mut n = i;
                  for _ in 0..j {
                      n *= i;
                  }

                  let elapsed = Instant::from_millis(n);
                  cubic.pre_transmit(elapsed);

                  let cwnd = cubic.window();
                  println!("Cubic: elapsed = {}, cwnd = {}", elapsed, cwnd);

                  assert!(cwnd >= cubic.min_cwnd);
                  assert!(cubic.window() <= remote_window);
              }
          }
      }

      /// A `now` earlier than the congestion event must not corrupt the window.
      #[test]
      fn cubic_time_inversion() {
          let mut cubic = Cubic::new();

          let t1 = Instant::from_micros(0);
          let t2 = Instant::from_micros(i64::MAX);

          cubic.on_retransmit(t2);
          cubic.pre_transmit(t1);

          let cwnd = cubic.window();
          println!("Cubic:time_inversion: cwnd: {}, cubic: {cubic:?}", cwnd);

          assert!(cwnd >= cubic.min_cwnd);
          assert!(cwnd <= cubic.rwnd);
      }

      /// An enormous elapsed time must saturate at `rwnd` instead of overflowing.
      #[test]
      fn cubic_long_elapsed_time() {
          let mut cubic = Cubic::new();

          let t1 = Instant::from_millis(0);
          let t2 = Instant::from_micros(i64::MAX);

          cubic.on_retransmit(t1);
          cubic.pre_transmit(t2);

          let cwnd = cubic.window();
          println!("Cubic:long_elapsed_time: cwnd: {}", cwnd);

          assert!(cwnd >= cubic.min_cwnd);
          assert!(cwnd <= cubic.rwnd);
      }

      /// Recomputations closer than 100 ms apart are skipped; later ones apply.
      #[test]
      fn cubic_last_update() {
          let mut cubic = Cubic::new();

          let t1 = Instant::from_millis(0);
          let t2 = Instant::from_millis(100);
          let t3 = Instant::from_millis(199);
          let t4 = Instant::from_millis(20000);

          cubic.on_retransmit(t1);

          cubic.pre_transmit(t2);
          let cwnd2 = cubic.window();

          cubic.pre_transmit(t3);
          let cwnd3 = cubic.window();

          cubic.pre_transmit(t4);
          let cwnd4 = cubic.window();

          println!(
              "Cubic:last_update: cwnd2: {}, cwnd3: {}, cwnd4: {}",
              cwnd2, cwnd3, cwnd4
          );

          assert_eq!(cwnd2, cwnd3);
          assert_ne!(cwnd2, cwnd4);
      }

      /// ACKs grow the window in slow start (bounded by `rwnd`), and a
      /// retransmission halves the threshold.
      #[test]
      fn cubic_slow_start() {
          let mut cubic = Cubic::new();
          let t1 = Instant::from_micros(0);

          let cwnd = cubic.window();
          let ack_len = 1024;

          cubic.on_ack(t1, ack_len, &RttEstimator::default());

          assert!(cubic.window() > cwnd);

          for i in 1..1000 {
              let t2 = Instant::from_micros(i);
              cubic.on_ack(t2, ack_len * 100, &RttEstimator::default());
              assert!(cubic.window() <= cubic.rwnd);
          }

          let t3 = Instant::from_micros(2000);

          let cwnd = cubic.window();
          cubic.on_retransmit(t3);
          assert_eq!(cwnd >> 1, cubic.ssthresh);
      }

      /// `pre_transmit` on a fresh controller must not panic.
      #[test]
      fn cubic_pre_transmit() {
          let mut cubic = Cubic::new();
          cubic.pre_transmit(Instant::from_micros(2000));
      }

      /// Every perfect cube in the sampled range yields a finite, positive root.
      #[test]
      fn test_cube_root() {
          for n in (1..1000000).step_by(99) {
              let a = n as f64;
              let a = a * a * a;
              let result = cube_root(a).expect("a positive input must have a cube root");
              println!("cube_root({a}) = {}", result);
              assert!(result.is_finite());
              assert!(result > 0.0);
          }
      }

      /// Non-positive inputs are rejected with `None` rather than a panic.
      #[test]
      fn cube_root_zero() {
          assert_eq!(cube_root(0.0), None);
          assert_eq!(cube_root(-8.0), None);
      }
  }