From 8ce517412e8c6d93f2fda51559858573b7a93620 Mon Sep 17 00:00:00 2001 From: DreamingCodes Date: Tue, 8 Aug 2023 16:26:52 +0200 Subject: [PATCH] feat(websocket): Add custom headers feature --- plugins/websocket/guest-js/index.ts | 1 + plugins/websocket/src/lib.rs | 23 ++++++++++++++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/plugins/websocket/guest-js/index.ts b/plugins/websocket/guest-js/index.ts index 0f748092..ef34e6f3 100644 --- a/plugins/websocket/guest-js/index.ts +++ b/plugins/websocket/guest-js/index.ts @@ -6,6 +6,7 @@ export interface ConnectionConfig { maxMessageSize?: number; maxFrameSize?: number; acceptUnmaskedFrames?: boolean; + headers?: Record; } export interface MessageKind { diff --git a/plugins/websocket/src/lib.rs b/plugins/websocket/src/lib.rs index 3d3878b5..03552aac 100644 --- a/plugins/websocket/src/lib.rs +++ b/plugins/websocket/src/lib.rs @@ -16,6 +16,9 @@ use tokio_tungstenite::{ }; use std::collections::HashMap; +use std::str::FromStr; +use tauri::http::header::{HeaderName, HeaderValue}; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; type Id = u32; type WebSocket = WebSocketStream>; @@ -28,6 +31,10 @@ enum Error { Websocket(#[from] tokio_tungstenite::tungstenite::Error), #[error("connection not found for the given id: {0}")] ConnectionNotFound(Id), + #[error(transparent)] + InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue), + #[error(transparent)] + InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName), } impl Serialize for Error { @@ -51,6 +58,8 @@ pub struct ConnectionConfig { pub max_frame_size: Option, #[serde(default)] pub accept_unmasked_frames: bool, + #[serde(default)] + pub headers: HashMap, } impl From for WebSocketConfig { @@ -94,7 +103,19 @@ async fn connect( config: Option, ) -> Result { let id = rand::random(); - let (ws_stream, _) = connect_async_with_config(url, config.map(Into::into), false).await?; + let mut request = url.into_client_request()?; + + if let Some(ref config) = config { + let config_headers = config.headers.iter().map(|(k, v)| { + let header_name = HeaderName::from_str(k.as_str())?; + let header_value = HeaderValue::from_str(v.as_str())?; + Ok((header_name, header_value)) + }); + + request.headers_mut().extend(config_headers.filter_map(Result::ok)); + } + + let (ws_stream, _) = connect_async_with_config(request, config.map(Into::into), false).await?; tauri::async_runtime::spawn(async move { let (write, read) = ws_stream.split();