diff --git a/plugins/websocket/guest-js/index.ts b/plugins/websocket/guest-js/index.ts index bd6ca752..30da3672 100644 --- a/plugins/websocket/guest-js/index.ts +++ b/plugins/websocket/guest-js/index.ts @@ -37,7 +37,7 @@ export default class WebSocket { this.listeners = listeners; } - static async connect(url: string, options?: unknown): Promise { + static async connect(url: string, config?: unknown): Promise { const listeners: Array<(arg: Message) => void> = []; const handler = (message: Message): void => { listeners.forEach((l) => l(message)); @@ -47,7 +47,7 @@ export default class WebSocket { .__TAURI_INVOKE__("plugin:websocket|connect", { url, callbackFunction: window.__TAURI__.transformCallback(handler), - options, + config, }) .then((id) => new WebSocket(id, listeners)); } diff --git a/plugins/websocket/src/api-iife.js b/plugins/websocket/src/api-iife.js index cea3c830..5305d078 100644 --- a/plugins/websocket/src/api-iife.js +++ b/plugins/websocket/src/api-iife.js @@ -1 +1 @@ -if("__TAURI__"in window){var __TAURI_WEBSOCKET__=function(){"use strict";class e{constructor(e,t){this.id=e,this.listeners=t}static async connect(t,n){const i=[];return await window.__TAURI_INVOKE__("plugin:websocket|connect",{url:t,callbackFunction:window.__TAURI__.transformCallback((e=>{i.forEach((t=>t(e)))})),options:n}).then((t=>new e(t,i)))}addListener(e){this.listeners.push(e)}async send(e){let t;if("string"==typeof e)t={type:"Text",data:e};else if("object"==typeof e&&"type"in e)t=e;else{if(!Array.isArray(e))throw new Error("invalid `message` type, expected a `{ type: string, data: any }` object, a string or a numeric array");t={type:"Binary",data:e}}return await window.__TAURI_INVOKE__("plugin:websocket|send",{id:this.id,message:t})}async disconnect(){return await this.send({type:"Close",data:{code:1e3,reason:"Disconnected by client"}})}}return e}();Object.defineProperty(window.__TAURI__,"websocket",{value:__TAURI_WEBSOCKET__})} +if("__TAURI__"in window){var __TAURI_WEBSOCKET__=function(){"use strict";class e{constructor(e,t){this.id=e,this.listeners=t}static async connect(t,n){const i=[];return await window.__TAURI_INVOKE__("plugin:websocket|connect",{url:t,callbackFunction:window.__TAURI__.transformCallback((e=>{i.forEach((t=>t(e)))})),config:n}).then((t=>new e(t,i)))}addListener(e){this.listeners.push(e)}async send(e){let t;if("string"==typeof e)t={type:"Text",data:e};else if("object"==typeof e&&"type"in e)t=e;else{if(!Array.isArray(e))throw new Error("invalid `message` type, expected a `{ type: string, data: any }` object, a string or a numeric array");t={type:"Binary",data:e}}return await window.__TAURI_INVOKE__("plugin:websocket|send",{id:this.id,message:t})}async disconnect(){return await this.send({type:"Close",data:{code:1e3,reason:"Disconnected by client"}})}}return e}();Object.defineProperty(window.__TAURI__,"websocket",{value:__TAURI_WEBSOCKET__})} diff --git a/plugins/websocket/src/lib.rs b/plugins/websocket/src/lib.rs index 31d744c3..105d2c90 100644 --- a/plugins/websocket/src/lib.rs +++ b/plugins/websocket/src/lib.rs @@ -29,6 +29,9 @@ use tokio_tungstenite::{ }; use std::collections::HashMap; +use std::str::FromStr; +use tauri::http::header::{HeaderName, HeaderValue}; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; type Id = u32; type WebSocket = WebSocketStream>; @@ -41,6 +44,10 @@ enum Error { Websocket(#[from] tokio_tungstenite::tungstenite::Error), #[error("connection not found for the given id: {0}")] ConnectionNotFound(Id), + #[error(transparent)] + InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue), + #[error(transparent)] + InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName), } impl Serialize for Error { @@ -55,13 +62,14 @@ impl Serialize for Error { #[derive(Default)] struct ConnectionManager(Mutex>); -#[derive(Default, Deserialize)] +#[derive(Deserialize)] #[serde(rename_all = "camelCase")] pub struct ConnectionConfig { pub max_send_queue: Option, pub max_message_size: Option, pub max_frame_size: Option, - pub accept_unmasked_frames: bool, + pub accept_unmasked_frames: Option, + pub headers: Option>, } impl From for WebSocketConfig { @@ -70,7 +78,7 @@ impl From for WebSocketConfig { max_send_queue: config.max_send_queue, max_message_size: config.max_message_size, max_frame_size: config.max_frame_size, - accept_unmasked_frames: config.accept_unmasked_frames, + accept_unmasked_frames: config.accept_unmasked_frames.unwrap_or_default(), } } } @@ -99,7 +107,21 @@ async fn connect( config: Option, ) -> Result { let id = rand::random(); - let (ws_stream, _) = connect_async_with_config(url, config.map(Into::into), false).await?; + let mut request = url.into_client_request()?; + + if let Some(ref config) = config { + if let Some(headers) = &config.headers { + let config_headers = headers.iter().map(|(k, v)| { + let header_name = HeaderName::from_str(k.as_str())?; + let header_value = HeaderValue::from_str(v.as_str())?; + Ok((header_name, header_value)) + }); + + request.headers_mut().extend(config_headers.filter_map(Result::ok)); + } + } + + let (ws_stream, _) = connect_async_with_config(request, config.map(Into::into), false).await?; tauri::async_runtime::spawn(async move { let (write, read) = ws_stream.split();