// Copyright 2019-2023 Tauri Programme within The Commons Conservancy // SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: MIT use std::{future::Future, pin::Pin, str::FromStr, sync::Arc, time::Duration}; use http::{header, HeaderMap, HeaderName, HeaderValue, Method, StatusCode}; use reqwest::{redirect::Policy, NoProxy}; use serde::{Deserialize, Serialize}; use tauri::{ async_runtime::Mutex, command, ipc::{Channel, CommandScope, GlobalScope}, Manager, ResourceId, ResourceTable, Runtime, State, Webview, }; use tokio::sync::oneshot::{channel, Receiver, Sender}; use crate::{ scope::{Entry, Scope}, Error, Http, Result, }; const HTTP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),); struct ReqwestResponse(reqwest::Response); impl tauri::Resource for ReqwestResponse {} type CancelableResponseResult = Result; type CancelableResponseFuture = Pin + Send + Sync>>; struct FetchRequest { fut: Mutex, abort_tx_rid: ResourceId, abort_rx_rid: ResourceId, } impl tauri::Resource for FetchRequest {} struct AbortSender(Sender<()>); impl tauri::Resource for AbortRecveiver {} impl AbortSender { fn abort(self) { let _ = self.0.send(()); } } struct AbortRecveiver(Receiver<()>); impl tauri::Resource for AbortSender {} trait AddRequest { fn add_request(&mut self, fut: CancelableResponseFuture) -> ResourceId; } impl AddRequest for ResourceTable { fn add_request(&mut self, fut: CancelableResponseFuture) -> ResourceId { let (tx, rx) = channel::<()>(); let (tx, rx) = (AbortSender(tx), AbortRecveiver(rx)); let req = FetchRequest { fut: Mutex::new(fut), abort_tx_rid: self.add(tx), abort_rx_rid: self.add(rx), }; self.add(req) } } #[derive(Serialize)] #[serde(rename_all = "camelCase")] pub struct FetchResponse { status: u16, status_text: String, headers: Vec<(String, String)>, url: String, rid: ResourceId, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] #[allow(dead_code)] //feature flags shoudln't affect api pub struct DangerousSettings { accept_invalid_certs: bool, accept_invalid_hostnames: bool, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ClientConfig { method: String, url: url::Url, headers: Vec<(String, String)>, data: Option>, connect_timeout: Option, max_redirections: Option, proxy: Option, danger: Option, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Proxy { all: Option, http: Option, https: Option, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] #[serde(untagged)] pub enum UrlOrConfig { Url(String), Config(ProxyConfig), } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ProxyConfig { url: String, basic_auth: Option, no_proxy: Option, } #[derive(Debug, Deserialize)] pub struct BasicAuth { username: String, password: String, } #[inline] fn proxy_creator( url_or_config: UrlOrConfig, proxy_fn: fn(String) -> reqwest::Result, ) -> reqwest::Result { match url_or_config { UrlOrConfig::Url(url) => Ok(proxy_fn(url)?), UrlOrConfig::Config(ProxyConfig { url, basic_auth, no_proxy, }) => { let mut proxy = proxy_fn(url)?; if let Some(basic_auth) = basic_auth { proxy = proxy.basic_auth(&basic_auth.username, &basic_auth.password); } if let Some(no_proxy) = no_proxy { proxy = proxy.no_proxy(NoProxy::from_string(&no_proxy)); } Ok(proxy) } } } fn attach_proxy( proxy: Proxy, mut builder: reqwest::ClientBuilder, ) -> crate::Result { let Proxy { all, http, https } = proxy; if let Some(all) = all { let proxy = proxy_creator(all, reqwest::Proxy::all)?; builder = builder.proxy(proxy); } if let Some(http) = http { let proxy = proxy_creator(http, reqwest::Proxy::http)?; builder = builder.proxy(proxy); } if let Some(https) = https { let proxy = proxy_creator(https, reqwest::Proxy::https)?; builder = builder.proxy(proxy); } Ok(builder) } #[command] pub async fn fetch( webview: Webview, state: State<'_, Http>, client_config: ClientConfig, command_scope: CommandScope, global_scope: GlobalScope, ) -> crate::Result { let ClientConfig { method, url, headers: headers_raw, data, connect_timeout, max_redirections, proxy, danger, } = client_config; let scheme = url.scheme(); let method = Method::from_bytes(method.as_bytes())?; let mut headers = HeaderMap::new(); for (h, v) in headers_raw { let name = HeaderName::from_str(&h)?; #[cfg(not(feature = "unsafe-headers"))] if is_unsafe_header(&name) { #[cfg(debug_assertions)] { eprintln!("[\x1b[33mWARNING\x1b[0m] Skipping {name} header as it is a forbidden header per fetch spec https://fetch.spec.whatwg.org/#terminology-headers"); eprintln!("[\x1b[33mWARNING\x1b[0m] if keeping the header is a desired behavior, you can enable `unsafe-headers` feature flag in your Cargo.toml"); } continue; } headers.append(name, HeaderValue::from_str(&v)?); } match scheme { "http" | "https" => { if Scope::new( command_scope .allows() .iter() .chain(global_scope.allows()) .collect(), command_scope .denies() .iter() .chain(global_scope.denies()) .collect(), ) .is_allowed(&url) { let mut builder = reqwest::ClientBuilder::new(); if let Some(danger_config) = danger { #[cfg(not(feature = "dangerous-settings"))] { #[cfg(debug_assertions)] { eprintln!("[\x1b[33mWARNING\x1b[0m] using dangerous settings requires `dangerous-settings` feature flag in your Cargo.toml"); } let _ = danger_config; return Err(Error::DangerousSettings); } #[cfg(feature = "dangerous-settings")] { builder = builder .danger_accept_invalid_certs(danger_config.accept_invalid_certs) .danger_accept_invalid_hostnames(danger_config.accept_invalid_hostnames) } } if let Some(timeout) = connect_timeout { builder = builder.connect_timeout(Duration::from_millis(timeout)); } if let Some(max_redirections) = max_redirections { builder = builder.redirect(if max_redirections == 0 { Policy::none() } else { Policy::limited(max_redirections) }); } if let Some(proxy_config) = proxy { builder = attach_proxy(proxy_config, builder)?; } #[cfg(feature = "cookies")] { builder = builder.cookie_provider(state.cookies_jar.clone()); } let mut request = builder.build()?.request(method.clone(), url); // POST and PUT requests should always have a 0 length content-length, // if there is no body. https://fetch.spec.whatwg.org/#http-network-or-cache-fetch if data.is_none() && matches!(method, Method::POST | Method::PUT) { headers.append(header::CONTENT_LENGTH, HeaderValue::from_str("0")?); } if headers.contains_key(header::RANGE) { // https://fetch.spec.whatwg.org/#http-network-or-cache-fetch step 18 // If httpRequest’s header list contains `Range`, then append (`Accept-Encoding`, `identity`) headers.append(header::ACCEPT_ENCODING, HeaderValue::from_str("identity")?); } if !headers.contains_key(header::USER_AGENT) { headers.append(header::USER_AGENT, HeaderValue::from_str(HTTP_USER_AGENT)?); } // ensure we have an Origin header set if cfg!(not(feature = "unsafe-headers")) || !headers.contains_key(header::ORIGIN) { if let Ok(url) = webview.url() { headers.append( header::ORIGIN, HeaderValue::from_str(&url.origin().ascii_serialization())?, ); } } // In case empty origin is passed, remove it. Some services do not like Origin header // so this way we can remove it in explicit way. The default behaviour is still to set it if cfg!(feature = "unsafe-headers") && headers.get(header::ORIGIN) == Some(&HeaderValue::from_static("")) { headers.remove(header::ORIGIN); }; if let Some(data) = data { request = request.body(data); } request = request.headers(headers); #[cfg(feature = "tracing")] tracing::trace!("{:?}", request); let fut = async move { request.send().await.map_err(Into::into) }; let mut resources_table = webview.resources_table(); let rid = resources_table.add_request(Box::pin(fut)); Ok(rid) } else { Err(Error::UrlNotAllowed(url)) } } "data" => { let data_url = data_url::DataUrl::process(url.as_str()).map_err(|_| Error::DataUrlError)?; let (body, _) = data_url .decode_to_vec() .map_err(|_| Error::DataUrlDecodeError)?; let response = http::Response::builder() .status(StatusCode::OK) .header(header::CONTENT_TYPE, data_url.mime_type().to_string()) .body(reqwest::Body::from(body))?; #[cfg(feature = "tracing")] tracing::trace!("{:?}", response); let fut = async move { Ok(reqwest::Response::from(response)) }; let mut resources_table = webview.resources_table(); let rid = resources_table.add_request(Box::pin(fut)); Ok(rid) } _ => Err(Error::SchemeNotSupport(scheme.to_string())), } } #[command] pub fn fetch_cancel(webview: Webview, rid: ResourceId) -> crate::Result<()> { let mut resources_table = webview.resources_table(); let req = resources_table.get::(rid)?; let abort_tx = resources_table.take::(req.abort_tx_rid)?; if let Some(abort_tx) = Arc::into_inner(abort_tx) { abort_tx.abort(); } Ok(()) } #[command] pub async fn fetch_send( webview: Webview, rid: ResourceId, ) -> crate::Result { let (req, abort_rx) = { let mut resources_table = webview.resources_table(); let req = resources_table.get::(rid)?; let abort_rx = resources_table.take::(req.abort_rx_rid)?; (req, abort_rx) }; let Some(abort_rx) = Arc::into_inner(abort_rx) else { return Err(Error::RequestCanceled); }; let mut fut = req.fut.lock().await; let res = tokio::select! { res = fut.as_mut() => res?, _ = abort_rx.0 => { let mut resources_table = webview.resources_table(); resources_table.close(rid)?; return Err(Error::RequestCanceled); } }; #[cfg(feature = "tracing")] tracing::trace!("{:?}", res); let status = res.status(); let url = res.url().to_string(); let mut headers = Vec::new(); for (key, val) in res.headers().iter() { headers.push(( key.as_str().into(), String::from_utf8(val.as_bytes().to_vec())?, )); } let mut resources_table = webview.resources_table(); let rid = resources_table.add(ReqwestResponse(res)); Ok(FetchResponse { status: status.as_u16(), status_text: status.canonical_reason().unwrap_or_default().to_string(), headers, url, rid, }) } #[command] pub async fn fetch_read_body( webview: Webview, rid: ResourceId, stream_channel: Channel, ) -> crate::Result<()> { let res = { let mut resources_table = webview.resources_table(); resources_table.take::(rid)? }; let mut res = Arc::into_inner(res).unwrap().0; // send response through IPC channel while let Some(chunk) = res.chunk().await? { let mut chunk = chunk.to_vec(); // append 0 to indicate we are not done yet chunk.push(0); stream_channel.send(tauri::ipc::InvokeResponseBody::Raw(chunk))?; } // send 1 to indicate we are done stream_channel.send(tauri::ipc::InvokeResponseBody::Raw(vec![1]))?; Ok(()) } // forbidden headers per fetch spec https://fetch.spec.whatwg.org/#terminology-headers #[cfg(not(feature = "unsafe-headers"))] fn is_unsafe_header(header: &HeaderName) -> bool { matches!( *header, header::ACCEPT_CHARSET | header::ACCEPT_ENCODING | header::ACCESS_CONTROL_REQUEST_HEADERS | header::ACCESS_CONTROL_REQUEST_METHOD | header::CONNECTION | header::CONTENT_LENGTH | header::COOKIE | header::DATE | header::DNT | header::EXPECT | header::HOST | header::ORIGIN | header::REFERER | header::SET_COOKIE | header::TE | header::TRAILER | header::TRANSFER_ENCODING | header::UPGRADE | header::VIA ) || { let lower = header.as_str().to_lowercase(); lower.starts_with("proxy-") || lower.starts_with("sec-") } }