From 7e58dc8502f654b99d51c087421f84ccc0e03119 Mon Sep 17 00:00:00 2001 From: Lorenzo Rizzotti Date: Tue, 8 Aug 2023 18:18:20 +0200 Subject: [PATCH] feat(websocket): add custom headers feature (#542) Co-authored-by: Lucas Nogueira --- plugins/websocket/guest-js/index.ts | 5 +++++ plugins/websocket/src/lib.rs | 20 +++++++++++++++++++- tsconfig.base.json | 2 +- 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/plugins/websocket/guest-js/index.ts b/plugins/websocket/guest-js/index.ts index 0f748092..4967b837 100644 --- a/plugins/websocket/guest-js/index.ts +++ b/plugins/websocket/guest-js/index.ts @@ -6,6 +6,7 @@ export interface ConnectionConfig { maxMessageSize?: number; maxFrameSize?: number; acceptUnmaskedFrames?: boolean; + headers?: HeadersInit; } export interface MessageKind { @@ -43,6 +44,10 @@ export default class WebSocket { listeners.forEach((l) => l(message)); }; + if (config?.headers) { + config.headers = Array.from(new Headers(config.headers).entries()); + } + return await invoke("plugin:websocket|connect", { url, callbackFunction: transformCallback(handler), diff --git a/plugins/websocket/src/lib.rs b/plugins/websocket/src/lib.rs index 3d3878b5..ad692d60 100644 --- a/plugins/websocket/src/lib.rs +++ b/plugins/websocket/src/lib.rs @@ -16,6 +16,9 @@ use tokio_tungstenite::{ }; use std::collections::HashMap; +use std::str::FromStr; +use tauri::http::header::{HeaderName, HeaderValue}; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; type Id = u32; type WebSocket = WebSocketStream>; @@ -28,6 +31,10 @@ enum Error { Websocket(#[from] tokio_tungstenite::tungstenite::Error), #[error("connection not found for the given id: {0}")] ConnectionNotFound(Id), + #[error(transparent)] + InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue), + #[error(transparent)] + InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName), } impl Serialize for Error { @@ -51,6 +58,7 @@ pub struct ConnectionConfig { pub max_frame_size: Option, #[serde(default)] pub accept_unmasked_frames: bool, + pub headers: Option>, } impl From for WebSocketConfig { @@ -94,7 +102,17 @@ async fn connect( config: Option, ) -> Result { let id = rand::random(); - let (ws_stream, _) = connect_async_with_config(url, config.map(Into::into), false).await?; + let mut request = url.into_client_request()?; + + if let Some(headers) = config.as_ref().and_then(|c| c.headers.as_ref()) { + for (k, v) in headers { + let header_name = HeaderName::from_str(k.as_str())?; + let header_value = HeaderValue::from_str(v.as_str())?; + request.headers_mut().insert(header_name, header_value); + } + } + + let (ws_stream, _) = connect_async_with_config(request, config.map(Into::into), false).await?; tauri::async_runtime::spawn(async move { let (write, read) = ws_stream.split(); diff --git a/tsconfig.base.json b/tsconfig.base.json index 1eebbeb6..629a7c96 100644 --- a/tsconfig.base.json +++ b/tsconfig.base.json @@ -2,7 +2,7 @@ "compilerOptions": { "allowSyntheticDefaultImports": true, "esModuleInterop": true, - "lib": ["ES2019", "ES2020.Promise", "ES2020.String", "DOM"], + "lib": ["ES2019", "ES2020.Promise", "ES2020.String", "DOM", "DOM.Iterable"], "module": "ESNext", "moduleResolution": "node", "noEmit": true,