From a973c41a96b285cd9bdd0c002e64c7c67ee00b42 Mon Sep 17 00:00:00 2001 From: DreamingCodes Date: Mon, 7 Aug 2023 22:20:32 +0200 Subject: [PATCH 1/2] feat(websocket): add custom headers support --- plugins/websocket/src/lib.rs | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/plugins/websocket/src/lib.rs b/plugins/websocket/src/lib.rs index 31d744c3..a7d63620 100644 --- a/plugins/websocket/src/lib.rs +++ b/plugins/websocket/src/lib.rs @@ -29,6 +29,9 @@ use tokio_tungstenite::{ }; use std::collections::HashMap; +use std::str::FromStr; +use tauri::http::header::{HeaderName, HeaderValue}; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; type Id = u32; type WebSocket = WebSocketStream>; @@ -41,6 +44,10 @@ enum Error { Websocket(#[from] tokio_tungstenite::tungstenite::Error), #[error("connection not found for the given id: {0}")] ConnectionNotFound(Id), + #[error(transparent)] + InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue), + #[error(transparent)] + InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName), } impl Serialize for Error { @@ -62,6 +69,7 @@ pub struct ConnectionConfig { pub max_message_size: Option, pub max_frame_size: Option, pub accept_unmasked_frames: bool, + pub headers: Vec<(String, String)>, } impl From for WebSocketConfig { @@ -99,7 +107,19 @@ async fn connect( config: Option, ) -> Result { let id = rand::random(); - let (ws_stream, _) = connect_async_with_config(url, config.map(Into::into), false).await?; + let mut request = url.into_client_request()?; + + if let Some(config) = config.as_ref() { + let config_headers = config.headers.iter().map(|(k, v)| { + let header_name = HeaderName::from_str(k.as_str())?; + let header_value = HeaderValue::from_str(v.as_str())?; + Ok((header_name, header_value)) + }); + + request.headers_mut().extend(config_headers.filter_map(Result::ok)); + } + + let (ws_stream, _) = connect_async_with_config(request, config.map(Into::into), false).await?; tauri::async_runtime::spawn(async move { let (write, read) = ws_stream.split(); From 8d8f8f3c658090ef44290526cad807b2d4424978 Mon Sep 17 00:00:00 2001 From: DreamingCodes Date: Mon, 7 Aug 2023 23:33:40 +0200 Subject: [PATCH 2/2] fix(websocket): fix config --- plugins/websocket/guest-js/index.ts | 4 ++-- plugins/websocket/src/api-iife.js | 2 +- plugins/websocket/src/lib.rs | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/plugins/websocket/guest-js/index.ts b/plugins/websocket/guest-js/index.ts index bd6ca752..30da3672 100644 --- a/plugins/websocket/guest-js/index.ts +++ b/plugins/websocket/guest-js/index.ts @@ -37,7 +37,7 @@ export default class WebSocket { this.listeners = listeners; } - static async connect(url: string, options?: unknown): Promise { + static async connect(url: string, config?: unknown): Promise { const listeners: Array<(arg: Message) => void> = []; const handler = (message: Message): void => { listeners.forEach((l) => l(message)); @@ -47,7 +47,7 @@ export default class WebSocket { .__TAURI_INVOKE__("plugin:websocket|connect", { url, callbackFunction: window.__TAURI__.transformCallback(handler), - options, + config, }) .then((id) => new WebSocket(id, listeners)); } diff --git a/plugins/websocket/src/api-iife.js b/plugins/websocket/src/api-iife.js index cea3c830..5305d078 100644 --- a/plugins/websocket/src/api-iife.js +++ b/plugins/websocket/src/api-iife.js @@ -1 +1 @@ -if("__TAURI__"in window){var __TAURI_WEBSOCKET__=function(){"use strict";class e{constructor(e,t){this.id=e,this.listeners=t}static async connect(t,n){const i=[];return await window.__TAURI_INVOKE__("plugin:websocket|connect",{url:t,callbackFunction:window.__TAURI__.transformCallback((e=>{i.forEach((t=>t(e)))})),options:n}).then((t=>new e(t,i)))}addListener(e){this.listeners.push(e)}async send(e){let t;if("string"==typeof e)t={type:"Text",data:e};else if("object"==typeof e&&"type"in e)t=e;else{if(!Array.isArray(e))throw new Error("invalid `message` type, expected a `{ type: string, data: any }` object, a string or a numeric array");t={type:"Binary",data:e}}return await window.__TAURI_INVOKE__("plugin:websocket|send",{id:this.id,message:t})}async disconnect(){return await this.send({type:"Close",data:{code:1e3,reason:"Disconnected by client"}})}}return e}();Object.defineProperty(window.__TAURI__,"websocket",{value:__TAURI_WEBSOCKET__})} +if("__TAURI__"in window){var __TAURI_WEBSOCKET__=function(){"use strict";class e{constructor(e,t){this.id=e,this.listeners=t}static async connect(t,n){const i=[];return await window.__TAURI_INVOKE__("plugin:websocket|connect",{url:t,callbackFunction:window.__TAURI__.transformCallback((e=>{i.forEach((t=>t(e)))})),config:n}).then((t=>new e(t,i)))}addListener(e){this.listeners.push(e)}async send(e){let t;if("string"==typeof e)t={type:"Text",data:e};else if("object"==typeof e&&"type"in e)t=e;else{if(!Array.isArray(e))throw new Error("invalid `message` type, expected a `{ type: string, data: any }` object, a string or a numeric array");t={type:"Binary",data:e}}return await window.__TAURI_INVOKE__("plugin:websocket|send",{id:this.id,message:t})}async disconnect(){return await this.send({type:"Close",data:{code:1e3,reason:"Disconnected by client"}})}}return e}();Object.defineProperty(window.__TAURI__,"websocket",{value:__TAURI_WEBSOCKET__})} diff --git a/plugins/websocket/src/lib.rs b/plugins/websocket/src/lib.rs index 31d744c3..3a753d07 100644 --- a/plugins/websocket/src/lib.rs +++ b/plugins/websocket/src/lib.rs @@ -55,13 +55,13 @@ impl Serialize for Error { #[derive(Default)] struct ConnectionManager(Mutex>); -#[derive(Default, Deserialize)] +#[derive(Deserialize)] #[serde(rename_all = "camelCase")] pub struct ConnectionConfig { pub max_send_queue: Option, pub max_message_size: Option, pub max_frame_size: Option, - pub accept_unmasked_frames: bool, + pub accept_unmasked_frames: Option, } impl From for WebSocketConfig { @@ -70,7 +70,7 @@ impl From for WebSocketConfig { max_send_queue: config.max_send_queue, max_message_size: config.max_message_size, max_frame_size: config.max_frame_size, - accept_unmasked_frames: config.accept_unmasked_frames, + accept_unmasked_frames: config.accept_unmasked_frames.unwrap_or_default(), } } }