feat(websocket): Add proxy configuration

* Allow usage of http(s) proxy for websocket traffic
pull/1536/head
Jonas Osburg 1 year ago
parent a32008965b
commit d25323b7be
No known key found for this signature in database
GPG Key ID: E0E4E79398D3A172

22
Cargo.lock generated

@ -394,9 +394,9 @@ checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d"
[[package]]
name = "base64"
version = "0.22.0"
version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "base64ct"
@ -1990,9 +1990,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421"
[[package]]
name = "hyper"
version = "1.2.0"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a"
checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05"
dependencies = [
"bytes",
"futures-channel",
@ -2026,9 +2026,9 @@ dependencies = [
[[package]]
name = "hyper-util"
version = "0.1.3"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa"
checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956"
dependencies = [
"bytes",
"futures-channel",
@ -3618,7 +3618,7 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e6cc1e89e689536eb5aeede61520e874df5a4707df811cd5da4aa5fbb2aae19"
dependencies = [
"base64 0.22.0",
"base64 0.22.1",
"bytes",
"encoding_rs",
"futures-core",
@ -3789,7 +3789,7 @@ version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d"
dependencies = [
"base64 0.22.0",
"base64 0.22.1",
"rustls-pki-types",
]
@ -4845,7 +4845,7 @@ name = "tauri-plugin-authenticator"
version = "0.0.0"
dependencies = [
"authenticator",
"base64 0.22.0",
"base64 0.22.1",
"byteorder",
"bytes",
"chrono",
@ -5029,8 +5029,11 @@ dependencies = [
name = "tauri-plugin-websocket"
version = "0.0.0"
dependencies = [
"base64 0.22.1",
"futures-util",
"http 1.0.0",
"hyper",
"hyper-util",
"log",
"rand 0.8.5",
"serde",
@ -5373,7 +5376,6 @@ dependencies = [
"tokio",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]

@ -20,3 +20,6 @@ rand = "0.8"
futures-util = "0.3"
tokio = { version = "1", features = ["net", "sync"] }
tokio-tungstenite = { version = "0.23", features = ["native-tls"] }
hyper = { version = "1.4.1", features = ["client"] }
hyper-util = { version = "0.1.6", features = ["tokio", "http1"] }
base64 = "0.22.1"

@ -1,16 +1,23 @@
use base64::prelude::{Engine, BASE64_STANDARD};
use futures_util::{stream::SplitSink, SinkExt, StreamExt};
use http::header::{HeaderName, HeaderValue};
use http::{
header::{HeaderName, HeaderValue},
Request,
};
use hyper::client::conn;
use hyper_util::rt::TokioIo;
use serde::{ser::Serializer, Deserialize, Serialize};
use tauri::{
api::ipc::{format_callback, CallbackFn},
plugin::{Builder as PluginBuilder, TauriPlugin},
Manager, Runtime, State, Window,
AppHandle, Manager, Runtime, State, Window,
};
use tokio::{net::TcpStream, sync::Mutex};
use tokio_tungstenite::{
connect_async_tls_with_config,
client_async_tls_with_config, connect_async_with_config,
tungstenite::{
client::IntoClientRequest,
error::UrlError,
protocol::{CloseFrame as ProtocolCloseFrame, WebSocketConfig},
Message,
},
@ -22,7 +29,8 @@ use std::str::FromStr;
type Id = u32;
type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
type WebSocketWriter = SplitSink<WebSocket, Message>;
type WebSocketWriter =
SplitSink<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>, Message>;
type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, thiserror::Error)]
@ -35,6 +43,14 @@ enum Error {
InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue),
#[error(transparent)]
InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName),
#[error(transparent)]
ProxyConnectionError(#[from] hyper::Error),
#[error("proxy returned status code: {0}")]
ProxyStatusError(u16),
#[error(transparent)]
ProxyIoError(std::io::Error),
#[error(transparent)]
ProxyHttpError(http::Error),
}
impl Serialize for Error {
@ -50,6 +66,26 @@ impl Serialize for Error {
struct ConnectionManager(Mutex<HashMap<Id, WebSocketWriter>>);
struct TlsConnector(Mutex<Option<Connector>>);
struct ProxyConfigurationInternal(Mutex<Option<ProxyConfiguration>>);
#[derive(Clone)]
pub struct ProxyAuth {
pub username: String,
pub password: String,
}
impl ProxyAuth {
pub fn encode(&self) -> String {
BASE64_STANDARD.encode(format!("{}:{}", self.username, self.password))
}
}
#[derive(Clone)]
pub struct ProxyConfiguration {
pub proxy_url: String,
pub proxy_port: u16,
pub auth: Option<ProxyAuth>,
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
@ -105,10 +141,6 @@ async fn connect<R: Runtime>(
) -> Result<Id> {
let id = rand::random();
let mut request = url.into_client_request()?;
let tls_connector = match window.try_state::<TlsConnector>() {
Some(tls_connector) => tls_connector.0.lock().await.clone(),
None => None,
};
if let Some(headers) = config.as_ref().and_then(|c| c.headers.as_ref()) {
for (k, v) in headers {
@ -118,9 +150,32 @@ async fn connect<R: Runtime>(
}
}
let (ws_stream, _) =
connect_async_tls_with_config(request, config.map(Into::into), false, tls_connector)
.await?;
#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
let tls_connector = match window.try_state::<TlsConnector>() {
Some(tls_connector) => tls_connector.0.lock().await.clone(),
None => None,
};
#[cfg(not(any(feature = "rustls-tls", feature = "native-tls")))]
let tls_connector = None;
let proxy_config = match window.try_state::<ProxyConfigurationInternal>() {
Some(proxy_config) => proxy_config.0.lock().await.clone(),
None => None,
};
let ws_stream = if let Some(proxy_config) = proxy_config {
connect_using_proxy(request, config, proxy_config, tls_connector).await?
} else {
#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
let (ws_stream, _) =
connect_async_tls_with_config(request, config.map(Into::into), false, tls_connector)
.await?;
#[cfg(not(any(feature = "rustls-tls", feature = "native-tls")))]
let (ws_stream, _) =
connect_async_with_config(request, config.map(Into::into), false).await?;
ws_stream
};
tauri::async_runtime::spawn(async move {
let (write, read) = ws_stream.split();
@ -168,6 +223,70 @@ async fn connect<R: Runtime>(
Ok(id)
}
async fn connect_using_proxy(
request: Request<()>,
config: Option<ConnectionConfig>,
proxy_config: ProxyConfiguration,
tls_connector: Option<Connector>,
) -> Result<WebSocket> {
let domain = domain(&request)?;
let port = request
.uri()
.port_u16()
.or_else(|| match request.uri().scheme_str() {
Some("wss") => Some(443),
Some("ws") => Some(80),
_ => None,
})
.ok_or(Error::Websocket(
tokio_tungstenite::tungstenite::Error::Url(UrlError::UnsupportedUrlScheme),
))?;
let tcp = TcpStream::connect(format!(
"{}:{}",
proxy_config.proxy_url, proxy_config.proxy_port
))
.await
.map_err(|original| Error::ProxyIoError(original))?;
let io = TokioIo::new(tcp);
let (mut request_sender, proxy_connection) =
conn::http1::handshake::<TokioIo<tokio::net::TcpStream>, String>(io).await?;
let proxy_connection_task = tokio::spawn(proxy_connection.without_shutdown());
let addr = format!("{domain}:{port}");
let mut req_builder = Request::connect(addr);
if let Some(auth) = proxy_config.auth {
req_builder = req_builder.header("Proxy-Authorization", format!("Basic {}", auth.encode()));
}
let req = req_builder
.body("".to_string())
.map_err(|orig| Error::ProxyHttpError(orig))?;
let res = request_sender.send_request(req).await?;
if res.status().as_u16() < 200 || res.status().as_u16() >= 300 {
return Err(Error::ProxyStatusError(res.status().as_u16()));
}
// expect is fine since it would only rely panics from within the tokio task (or a cancellation which does not happen)
let proxy_connection = proxy_connection_task
.await
.expect("Panic in tokio task during websocket proxy initialization")?;
let proxy_tcp_wrapper = proxy_connection.io;
let proxied_tcp_socket = proxy_tcp_wrapper.into_inner();
let (ws_stream, _) = client_async_tls_with_config(
request,
proxied_tcp_socket,
config.map(Into::into),
tls_connector,
)
.await?;
Ok(ws_stream)
}
#[tauri::command]
async fn send(
manager: State<'_, ConnectionManager>,
@ -200,12 +319,14 @@ pub fn init<R: Runtime>() -> TauriPlugin<R> {
#[derive(Default)]
pub struct Builder {
tls_connector: Option<Connector>,
proxy_configuration: Option<ProxyConfiguration>,
}
impl Builder {
pub fn new() -> Self {
Self {
tls_connector: None,
proxy_configuration: None,
}
}
@ -214,14 +335,60 @@ impl Builder {
self
}
pub fn proxy_configuration(mut self, proxy_configuration: ProxyConfiguration) -> Self {
self.proxy_configuration.replace(proxy_configuration);
self
}
pub fn build<R: Runtime>(self) -> TauriPlugin<R> {
PluginBuilder::new("websocket")
.invoke_handler(tauri::generate_handler![connect, send])
.setup(|app| {
app.manage(ConnectionManager::default());
app.manage(TlsConnector(Mutex::new(self.tls_connector)));
app.manage(ProxyConfigurationInternal(Mutex::new(
self.proxy_configuration,
)));
Ok(())
})
.build()
}
}
pub async fn reconfigure_proxy(app: &AppHandle, proxy_config: Option<ProxyConfiguration>) {
if let Some(state) = app.try_state::<ProxyConfigurationInternal>() {
if let Some(proxy_config) = proxy_config {
state.0.lock().await.replace(proxy_config);
} else {
state.0.lock().await.take();
}
}
}
pub async fn reconfigure_tls_connector(app: &AppHandle, tls_connector: Option<Connector>) {
if let Some(state) = app.try_state::<TlsConnector>() {
if let Some(tls_connector) = tls_connector {
state.0.lock().await.replace(tls_connector);
} else {
state.0.lock().await.take();
}
}
}
// Copied from tokio-tungstenite internal function (tokio-tungstenite/src/lib.rs) with the same name
// Get a domain from an URL.
#[inline]
fn domain(
request: &tokio_tungstenite::tungstenite::handshake::client::Request,
) -> tokio_tungstenite::tungstenite::Result<String, tokio_tungstenite::tungstenite::Error> {
match request.uri().host() {
// rustls expects IPv6 addresses without the surrounding [] brackets
#[cfg(feature = "__rustls-tls")]
Some(d) if d.starts_with('[') && d.ends_with(']') => Ok(d[1..d.len() - 1].to_string()),
Some(d) => Ok(d.to_string()),
None => Err(tokio_tungstenite::tungstenite::Error::Url(
tokio_tungstenite::tungstenite::error::UrlError::NoHostName,
)),
}
}

Loading…
Cancel
Save