tls.rs 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. use async_trait::async_trait;
  2. use log::*;
  3. use native_tls::TlsConnector;
  4. use tokio::io::{AsyncReadExt, AsyncWriteExt};
  5. use tokio::net::TcpStream;
  6. use dns::{Request, Response};
  7. use super::{Transport, Error};
  8. /// The **TLS transport**, which uses Tokio.
  9. ///
  10. /// # Examples
  11. ///
  12. /// ```no_run
  13. /// use dns_transport::{Transport, TlsTransport};
  14. /// use dns::{Request, Flags, Query, Labels, QClass, qtype, record::SRV};
  15. ///
  16. /// let query = Query {
  17. /// qname: Labels::encode("dns.lookup.dog").unwrap(),
  18. /// qclass: QClass::IN,
  19. /// qtype: qtype!(SRV),
  20. /// };
  21. ///
  22. /// let request = Request {
  23. /// transaction_id: 0xABCD,
  24. /// flags: Flags::query(),
  25. /// query: query,
  26. /// additional: None,
  27. /// };
  28. ///
  29. /// let transport = TlsTransport::new("dns.google");
  30. /// transport.send(&request);
  31. /// ```
  32. #[derive(Debug)]
  33. pub struct TlsTransport {
  34. addr: String,
  35. }
  36. impl TlsTransport {
  37. /// Creates a new TLS transport that connects to the given host.
  38. pub fn new(sa: impl Into<String>) -> Self {
  39. let addr = sa.into();
  40. Self { addr }
  41. }
  42. }
  43. #[async_trait]
  44. impl Transport for TlsTransport {
  45. async fn send(&self, request: &Request) -> Result<Response, Error> {
  46. let connector = TlsConnector::new()?;
  47. let connector = tokio_tls::TlsConnector::from(connector);
  48. info!("Opening TLS socket");
  49. let stream =
  50. if self.addr.contains(':') {
  51. TcpStream::connect(&*self.addr).await?
  52. }
  53. else {
  54. TcpStream::connect((&*self.addr, 853)).await?
  55. };
  56. info!("Connecting");
  57. let mut stream = connector.connect(self.sni_domain(), stream).await?;
  58. // As with TCP, we need to prepend the message with its length.
  59. let mut bytes = request.to_bytes().expect("failed to serialise request");
  60. let len_bytes = (bytes.len() as u16).to_be_bytes();
  61. bytes.insert(0, len_bytes[0]);
  62. bytes.insert(1, len_bytes[1]);
  63. info!("Sending {} bytes of data to {}", bytes.len(), self.addr);
  64. stream.write_all(&bytes).await?;
  65. debug!("Sent");
  66. info!("Waiting to receive...");
  67. let mut buf = [0; 4096];
  68. let len = stream.read(&mut buf).await?;
  69. // Remember to deal with the length again.
  70. info!("Received {} bytes of data", buf.len());
  71. let response = Response::from_bytes(&buf[2..len])?;
  72. Ok(response)
  73. }
  74. }
  75. impl TlsTransport {
  76. fn sni_domain(&self) -> &str {
  77. if let Some(colon_index) = self.addr.find(':') {
  78. &self.addr[.. colon_index]
  79. }
  80. else {
  81. &self.addr[..]
  82. }
  83. }
  84. }