Browse Source

Merge pull request #66 from larsch/https-fix

Fix HTTPS protocol buffering
Benjamin Sago 3 years ago
parent
commit
848697a9b8
1 changed files with 32 additions and 5 deletions
  1. 32 5
      dns-transport/src/https.rs

+ 32 - 5
dns-transport/src/https.rs

@@ -5,10 +5,9 @@ use std::net::TcpStream;
 
 use log::*;
 
-use dns::{Request, Response};
+use dns::{Request, Response, WireError};
 use super::{Transport, Error};
 
-
 /// The **HTTPS transport**, which sends DNS wire data inside HTTP packets
 /// encrypted with TLS, using TCP.
 pub struct HttpsTransport {
@@ -23,6 +22,15 @@ impl HttpsTransport {
     }
 }
 
+fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
+    haystack.windows(needle.len()).position(|window| window == needle)
+}
+
+fn contains_header(buf: &[u8]) -> bool {
+    let header_end: [u8; 4] = [ 13, 10, 13, 10 ];
+    find_subsequence(buf, &header_end).is_some()
+}
+
 impl Transport for HttpsTransport {
 
     #[cfg(feature = "with_https")]
@@ -53,13 +61,19 @@ impl Transport for HttpsTransport {
 
         info!("Waiting to receive...");
         let mut buf = [0; 4096];
-        let read_len = stream.read(&mut buf)?;
+        let mut read_len = stream.read(&mut buf)?;
+        while !contains_header(&buf[0..read_len]) {
+            if read_len == buf.len() {
+                return Err(Error::WireError(WireError::IO));
+            }
+            read_len += stream.read(&mut buf[read_len..])?;
+        }
+        let mut expected_len = read_len;
         info!("Received {} bytes of data", read_len);
 
         let mut headers = [httparse::EMPTY_HEADER; 16];
         let mut response = httparse::Response::new(&mut headers);
         let index: usize = response.parse(&buf)?.unwrap();
-        let body = &buf[index .. read_len];
 
         if response.code != Some(200) {
             let reason = response.reason.map(str::to_owned);
@@ -67,9 +81,22 @@ impl Transport for HttpsTransport {
         }
 
         for header in response.headers {
-            debug!("Header {:?} -> {:?}", header.name, String::from_utf8_lossy(header.value));
+            let str_value = String::from_utf8_lossy(header.value);
+            debug!("Header {:?} -> {:?}", header.name, str_value);
+            if header.name == "Content-Length" {
+                let content_length: usize = str_value.parse().unwrap();
+                expected_len = index + content_length;
+            }
+        }
+
+        while read_len < expected_len {
+            if read_len == buf.len() {
+                return Err(Error::WireError(WireError::IO));
+            }
+            read_len += stream.read(&mut buf[read_len..])?;
         }
 
+        let body = &buf[index .. read_len];
         debug!("HTTP body has {} bytes", body.len());
         let response = Response::from_bytes(&body)?;
         Ok(response)