diff --git a/Cargo.lock b/Cargo.lock index 7bac3eb0..9d136107 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -376,9 +376,9 @@ checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" [[package]] name = "base64" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "base64ct" @@ -1989,9 +1989,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" [[package]] name = "hyper" -version = "1.2.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a" +checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" dependencies = [ "bytes", "futures-channel", @@ -2025,9 +2025,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.3" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa" +checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" dependencies = [ "bytes", "futures-channel", @@ -3644,7 +3644,7 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e6cc1e89e689536eb5aeede61520e874df5a4707df811cd5da4aa5fbb2aae19" dependencies = [ - "base64 0.22.0", + "base64 0.22.1", "bytes", "encoding_rs", "futures-core", @@ -3797,13 +3797,26 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" dependencies = [ - "base64 0.22.0", + "base64 0.22.1", "rustls-pki-types", ] @@ -4367,7 +4380,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "936cac0ab331b14cb3921c62156d913e4c15b74fb6ec0f3146bd4ef6e4fb3c12" dependencies = [ "atoi", - "base64 0.22.0", + "base64 0.22.1", "bitflags 2.4.1", "byteorder", "bytes", @@ -4410,7 +4423,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9734dbce698c67ecf67c442f768a5e90a49b2a4d61a9f1d59f73874bd4cf0710" dependencies = [ "atoi", - "base64 0.22.0", + "base64 0.22.1", "bitflags 2.4.1", "byteorder", "crc", @@ -5021,8 +5034,11 @@ dependencies = [ name = "tauri-plugin-websocket" version = "0.0.0" dependencies = [ + "base64 0.22.1", "futures-util", "http 1.0.0", + "hyper", + "hyper-util", "log", "rand 0.8.5", "serde", @@ -5294,6 +5310,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +dependencies = [ + "rustls", + "rustls-pki-types", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.14" @@ -5313,10 +5340,13 @@ checksum = "489a59b6730eda1b0171fcfda8b121f4bee2b35cba8645ca35c5f7ba3eb736c1" dependencies = [ "futures-util", "log", - "native-tls", + "rustls", + "rustls-native-certs", + "rustls-pki-types", "tokio", - "tokio-native-tls", + "tokio-rustls", "tungstenite", + "webpki-roots", ] [[package]] @@ -5389,7 +5419,6 @@ dependencies = [ "tokio", "tower-layer", "tower-service", - "tracing", ] [[package]] @@ -5484,8 +5513,9 @@ dependencies = [ "http 1.0.0", "httparse", "log", - "native-tls", "rand 0.9.1", + "rustls", + "rustls-pki-types", "sha1", "thiserror 2.0.9", "utf-8", diff --git a/plugins/websocket/Cargo.toml b/plugins/websocket/Cargo.toml index 1391a032..7fa8c9ff 100644 --- a/plugins/websocket/Cargo.toml +++ b/plugins/websocket/Cargo.toml @@ -19,4 +19,10 @@ http = "1" rand = "0.8" futures-util = "0.3" tokio = { version = "1", features = ["net", "sync"] } -tokio-tungstenite = { version = "0.27", features = ["native-tls"] } +tokio-tungstenite = { version = "0.27", features = ["rustls-tls-webpki-roots"] } +hyper = { version = "1", features = ["client"] } +hyper-util = { version = "0.1", features = ["tokio", "http1"] } +base64 = "0.22" + +[features] +rustls-tls-native-roots = ["tokio-tungstenite/rustls-tls-native-roots"] \ No newline at end of file diff --git a/plugins/websocket/src/lib.rs b/plugins/websocket/src/lib.rs index cbf4e112..2beaf472 100644 --- a/plugins/websocket/src/lib.rs +++ b/plugins/websocket/src/lib.rs @@ -1,16 +1,23 @@ +use base64::prelude::{Engine, BASE64_STANDARD}; use futures_util::{stream::SplitSink, SinkExt, StreamExt}; -use http::header::{HeaderName, HeaderValue}; +use http::{ + header::{HeaderName, HeaderValue}, + Request, +}; +use hyper::client::conn; +use hyper_util::rt::TokioIo; use serde::{ser::Serializer, Deserialize, Serialize}; use tauri::{ api::ipc::{format_callback, CallbackFn}, plugin::{Builder as PluginBuilder, TauriPlugin}, - Manager, Runtime, State, Window, + AppHandle, Manager, Runtime, State, Window, }; use tokio::{net::TcpStream, sync::Mutex}; use tokio_tungstenite::{ - connect_async_tls_with_config, + client_async_tls_with_config, connect_async_tls_with_config, tungstenite::{ client::IntoClientRequest, + error::UrlError, protocol::{CloseFrame as ProtocolCloseFrame, WebSocketConfig}, Message, }, @@ -19,10 +26,12 @@ use tokio_tungstenite::{ use std::collections::HashMap; use std::str::FromStr; +use std::sync::Mutex as StdMutex; type Id = u32; type WebSocket = WebSocketStream>; -type WebSocketWriter = SplitSink; +type WebSocketWriter = + SplitSink>, Message>; type Result = std::result::Result; #[derive(Debug, thiserror::Error)] @@ -35,6 +44,16 @@ enum Error { InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue), #[error(transparent)] InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName), + #[error(transparent)] + ProxyConnection(#[from] hyper::Error), + #[error("proxy returned status code: {0}")] + ProxyStatus(u16), + #[error(transparent)] + ProxyIo(#[from] std::io::Error), + #[error(transparent)] + ProxyHttp(#[from] http::Error), + #[error(transparent)] + ProxyJoinHandle(#[from] tokio::task::JoinError), } impl Serialize for Error { @@ -49,7 +68,27 @@ impl Serialize for Error { #[derive(Default)] struct ConnectionManager(Mutex>); -struct TlsConnector(Mutex>); +struct TlsConnector(StdMutex>); +struct ProxyConfigurationInternal(StdMutex>); + +#[derive(Clone)] +pub struct ProxyAuth { + pub username: String, + pub password: String, +} + +impl ProxyAuth { + pub fn encode(&self) -> String { + BASE64_STANDARD.encode(format!("{}:{}", self.username, self.password)) + } +} + +#[derive(Clone)] +pub struct ProxyConfiguration { + pub proxy_url: String, + pub proxy_port: u16, + pub auth: Option, +} #[derive(Deserialize)] #[serde(untagged, rename_all = "camelCase")] @@ -133,10 +172,6 @@ async fn connect( ) -> Result { let id = rand::random(); let mut request = url.into_client_request()?; - let tls_connector = match window.try_state::() { - Some(tls_connector) => tls_connector.0.lock().await.clone(), - None => None, - }; if let Some(headers) = config.as_ref().and_then(|c| c.headers.as_ref()) { for (k, v) in headers { @@ -146,9 +181,21 @@ async fn connect( } } - let (ws_stream, _) = + let tls_connector = window + .try_state::() + .and_then(|c| c.0.lock().unwrap().clone()); + + let proxy_config = window + .try_state::() + .and_then(|c| c.0.lock().unwrap().clone()); + + let ws_stream = if let Some(proxy_config) = proxy_config { + connect_using_proxy(request, config, proxy_config, tls_connector).await? + } else { connect_async_tls_with_config(request, config.map(Into::into), false, tls_connector) - .await?; + .await? + .0 + }; tauri::async_runtime::spawn(async move { let (write, read) = ws_stream.split(); @@ -182,7 +229,7 @@ async fn connect( }))) .unwrap() } - Ok(Message::Frame(_)) => serde_json::Value::Null, // This value can't be recieved. + Ok(Message::Frame(_)) => serde_json::Value::Null, // This value can't be received. Err(e) => serde_json::to_value(Error::from(e)).unwrap(), }; let js = format_callback(callback_function, &response) @@ -196,6 +243,62 @@ async fn connect( Ok(id) } +async fn connect_using_proxy( + request: Request<()>, + config: Option, + proxy_config: ProxyConfiguration, + tls_connector: Option, +) -> Result { + let domain = domain(&request)?; + let port = request + .uri() + .port_u16() + .or_else(|| match request.uri().scheme_str() { + Some("wss") => Some(443), + Some("ws") => Some(80), + _ => None, + }) + .ok_or(Error::Websocket( + tokio_tungstenite::tungstenite::Error::Url(UrlError::UnsupportedUrlScheme), + ))?; + + let tcp = TcpStream::connect(format!( + "{}:{}", + proxy_config.proxy_url, proxy_config.proxy_port + )) + .await?; + let io = TokioIo::new(tcp); + + let (mut request_sender, proxy_connection) = + conn::http1::handshake::, String>(io).await?; + let proxy_connection_task = tokio::spawn(proxy_connection.without_shutdown()); + + let addr = format!("{domain}:{port}"); + let mut req_builder = Request::connect(addr); + + if let Some(auth) = proxy_config.auth { + req_builder = req_builder.header("Proxy-Authorization", format!("Basic {}", auth.encode())); + } + + // TODO: This looks super fishy + let req = req_builder.body("".to_string())?; + let res = request_sender.send_request(req).await?; + if !res.status().is_success() { + return Err(Error::ProxyStatus(res.status().as_u16())); + } + + let proxied_tcp_socket = proxy_connection_task.await??.io.into_inner(); + let (ws_stream, _) = client_async_tls_with_config( + request, + proxied_tcp_socket, + config.map(Into::into), + tls_connector, + ) + .await?; + + Ok(ws_stream) +} + #[tauri::command] async fn send( manager: State<'_, ConnectionManager>, @@ -228,12 +331,14 @@ pub fn init() -> TauriPlugin { #[derive(Default)] pub struct Builder { tls_connector: Option, + proxy_configuration: Option, } impl Builder { pub fn new() -> Self { Self { tls_connector: None, + proxy_configuration: None, } } @@ -242,14 +347,52 @@ impl Builder { self } + pub fn proxy_configuration(mut self, proxy_configuration: ProxyConfiguration) -> Self { + self.proxy_configuration.replace(proxy_configuration); + self + } + pub fn build(self) -> TauriPlugin { PluginBuilder::new("websocket") .invoke_handler(tauri::generate_handler![connect, send]) .setup(|app| { app.manage(ConnectionManager::default()); - app.manage(TlsConnector(Mutex::new(self.tls_connector))); + app.manage(TlsConnector(StdMutex::new(self.tls_connector))); + app.manage(ProxyConfigurationInternal(StdMutex::new( + self.proxy_configuration, + ))); + Ok(()) }) .build() } } + +pub async fn reconfigure_proxy(app: &AppHandle, proxy_config: Option) { + if let Some(state) = app.try_state::() { + *state.0.lock().unwrap() = proxy_config; + } +} + +pub async fn reconfigure_tls_connector(app: &AppHandle, tls_connector: Option) { + if let Some(state) = app.try_state::() { + *state.0.lock().unwrap() = tls_connector; + } +} + +// Copied from tokio-tungstenite internal function (tokio-tungstenite/src/lib.rs) with the same name +// Get a domain from an URL. +#[allow(clippy::result_large_err)] +#[inline] +fn domain( + request: &tokio_tungstenite::tungstenite::handshake::client::Request, +) -> tokio_tungstenite::tungstenite::Result { + match request.uri().host() { + // rustls expects IPv6 addresses without the surrounding [] brackets + Some(d) if d.starts_with('[') && d.ends_with(']') => Ok(d[1..d.len() - 1].to_string()), + Some(d) => Ok(d.to_string()), + None => Err(tokio_tungstenite::tungstenite::Error::Url( + tokio_tungstenite::tungstenite::error::UrlError::NoHostName, + )), + } +}