feat(websocket): add custom headers feature (#542)

Co-authored-by: Lucas Nogueira <lucas@tauri.studio>
pull/551/head
Lorenzo Rizzotti 2 years ago committed by GitHub
parent ac495b9fb4
commit 7e58dc8502
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -6,6 +6,7 @@ export interface ConnectionConfig {
maxMessageSize?: number;
maxFrameSize?: number;
acceptUnmaskedFrames?: boolean;
headers?: HeadersInit;
}
export interface MessageKind<T, D> {
@ -43,6 +44,10 @@ export default class WebSocket {
listeners.forEach((l) => l(message));
};
if (config?.headers) {
config.headers = Array.from(new Headers(config.headers).entries());
}
return await invoke<number>("plugin:websocket|connect", {
url,
callbackFunction: transformCallback(handler),

@ -16,6 +16,9 @@ use tokio_tungstenite::{
};
use std::collections::HashMap;
use std::str::FromStr;
use tauri::http::header::{HeaderName, HeaderValue};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
type Id = u32;
type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
@ -28,6 +31,10 @@ enum Error {
Websocket(#[from] tokio_tungstenite::tungstenite::Error),
#[error("connection not found for the given id: {0}")]
ConnectionNotFound(Id),
#[error(transparent)]
InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue),
#[error(transparent)]
InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName),
}
impl Serialize for Error {
@ -51,6 +58,7 @@ pub struct ConnectionConfig {
pub max_frame_size: Option<usize>,
#[serde(default)]
pub accept_unmasked_frames: bool,
pub headers: Option<Vec<(String, String)>>,
}
impl From<ConnectionConfig> for WebSocketConfig {
@ -94,7 +102,17 @@ async fn connect<R: Runtime>(
config: Option<ConnectionConfig>,
) -> Result<Id> {
let id = rand::random();
let (ws_stream, _) = connect_async_with_config(url, config.map(Into::into), false).await?;
let mut request = url.into_client_request()?;
if let Some(headers) = config.as_ref().and_then(|c| c.headers.as_ref()) {
for (k, v) in headers {
let header_name = HeaderName::from_str(k.as_str())?;
let header_value = HeaderValue::from_str(v.as_str())?;
request.headers_mut().insert(header_name, header_value);
}
}
let (ws_stream, _) = connect_async_with_config(request, config.map(Into::into), false).await?;
tauri::async_runtime::spawn(async move {
let (write, read) = ws_stream.split();

@ -2,7 +2,7 @@
"compilerOptions": {
"allowSyntheticDefaultImports": true,
"esModuleInterop": true,
"lib": ["ES2019", "ES2020.Promise", "ES2020.String", "DOM"],
"lib": ["ES2019", "ES2020.Promise", "ES2020.String", "DOM", "DOM.Iterable"],
"module": "ESNext",
"moduleResolution": "node",
"noEmit": true,

Loading…
Cancel
Save