You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
tauri-plugins-workspace/plugins/http/src/commands.rs

470 lines
15 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: MIT
use std::{future::Future, pin::Pin, str::FromStr, sync::Arc, time::Duration};
use http::{header, HeaderMap, HeaderName, HeaderValue, Method, StatusCode};
use reqwest::{redirect::Policy, NoProxy};
use serde::{Deserialize, Serialize};
use tauri::{
async_runtime::Mutex,
command,
ipc::{Channel, CommandScope, GlobalScope},
Manager, ResourceId, ResourceTable, Runtime, State, Webview,
};
use tokio::sync::oneshot::{channel, Receiver, Sender};
use crate::{
scope::{Entry, Scope},
Error, Http, Result,
};
const HTTP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
struct ReqwestResponse(reqwest::Response);
impl tauri::Resource for ReqwestResponse {}
type CancelableResponseResult = Result<reqwest::Response>;
type CancelableResponseFuture =
Pin<Box<dyn Future<Output = CancelableResponseResult> + Send + Sync>>;
struct FetchRequest {
fut: Mutex<CancelableResponseFuture>,
abort_tx_rid: ResourceId,
abort_rx_rid: ResourceId,
}
impl tauri::Resource for FetchRequest {}
struct AbortSender(Sender<()>);
impl tauri::Resource for AbortRecveiver {}
impl AbortSender {
fn abort(self) {
let _ = self.0.send(());
}
}
struct AbortRecveiver(Receiver<()>);
impl tauri::Resource for AbortSender {}
trait AddRequest {
fn add_request(&mut self, fut: CancelableResponseFuture) -> ResourceId;
}
impl AddRequest for ResourceTable {
fn add_request(&mut self, fut: CancelableResponseFuture) -> ResourceId {
let (tx, rx) = channel::<()>();
let (tx, rx) = (AbortSender(tx), AbortRecveiver(rx));
let req = FetchRequest {
fut: Mutex::new(fut),
abort_tx_rid: self.add(tx),
abort_rx_rid: self.add(rx),
};
self.add(req)
}
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct FetchResponse {
status: u16,
status_text: String,
headers: Vec<(String, String)>,
url: String,
rid: ResourceId,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)] //feature flags shoudln't affect api
pub struct DangerousSettings {
accept_invalid_certs: bool,
accept_invalid_hostnames: bool,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClientConfig {
method: String,
url: url::Url,
headers: Vec<(String, String)>,
data: Option<Vec<u8>>,
connect_timeout: Option<u64>,
max_redirections: Option<usize>,
proxy: Option<Proxy>,
danger: Option<DangerousSettings>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Proxy {
all: Option<UrlOrConfig>,
http: Option<UrlOrConfig>,
https: Option<UrlOrConfig>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(untagged)]
pub enum UrlOrConfig {
Url(String),
Config(ProxyConfig),
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ProxyConfig {
url: String,
basic_auth: Option<BasicAuth>,
no_proxy: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct BasicAuth {
username: String,
password: String,
}
#[inline]
fn proxy_creator(
url_or_config: UrlOrConfig,
proxy_fn: fn(String) -> reqwest::Result<reqwest::Proxy>,
) -> reqwest::Result<reqwest::Proxy> {
match url_or_config {
UrlOrConfig::Url(url) => Ok(proxy_fn(url)?),
UrlOrConfig::Config(ProxyConfig {
url,
basic_auth,
no_proxy,
}) => {
let mut proxy = proxy_fn(url)?;
if let Some(basic_auth) = basic_auth {
proxy = proxy.basic_auth(&basic_auth.username, &basic_auth.password);
}
if let Some(no_proxy) = no_proxy {
proxy = proxy.no_proxy(NoProxy::from_string(&no_proxy));
}
Ok(proxy)
}
}
}
fn attach_proxy(
proxy: Proxy,
mut builder: reqwest::ClientBuilder,
) -> crate::Result<reqwest::ClientBuilder> {
let Proxy { all, http, https } = proxy;
if let Some(all) = all {
let proxy = proxy_creator(all, reqwest::Proxy::all)?;
builder = builder.proxy(proxy);
}
if let Some(http) = http {
let proxy = proxy_creator(http, reqwest::Proxy::http)?;
builder = builder.proxy(proxy);
}
if let Some(https) = https {
let proxy = proxy_creator(https, reqwest::Proxy::https)?;
builder = builder.proxy(proxy);
}
Ok(builder)
}
#[command]
pub async fn fetch<R: Runtime>(
webview: Webview<R>,
state: State<'_, Http>,
client_config: ClientConfig,
command_scope: CommandScope<Entry>,
global_scope: GlobalScope<Entry>,
) -> crate::Result<ResourceId> {
let ClientConfig {
method,
url,
headers: headers_raw,
data,
connect_timeout,
max_redirections,
proxy,
danger,
} = client_config;
let scheme = url.scheme();
let method = Method::from_bytes(method.as_bytes())?;
let mut headers = HeaderMap::new();
for (h, v) in headers_raw {
let name = HeaderName::from_str(&h)?;
#[cfg(not(feature = "unsafe-headers"))]
if is_unsafe_header(&name) {
#[cfg(debug_assertions)]
{
eprintln!("[\x1b[33mWARNING\x1b[0m] Skipping {name} header as it is a forbidden header per fetch spec https://fetch.spec.whatwg.org/#terminology-headers");
eprintln!("[\x1b[33mWARNING\x1b[0m] if keeping the header is a desired behavior, you can enable `unsafe-headers` feature flag in your Cargo.toml");
}
continue;
}
headers.append(name, HeaderValue::from_str(&v)?);
}
match scheme {
"http" | "https" => {
if Scope::new(
command_scope
.allows()
.iter()
.chain(global_scope.allows())
.collect(),
command_scope
.denies()
.iter()
.chain(global_scope.denies())
.collect(),
)
.is_allowed(&url)
{
let mut builder = reqwest::ClientBuilder::new();
if let Some(danger_config) = danger {
#[cfg(not(feature = "dangerous-settings"))]
{
#[cfg(debug_assertions)]
{
eprintln!("[\x1b[33mWARNING\x1b[0m] using dangerous settings requires `dangerous-settings` feature flag in your Cargo.toml");
}
let _ = danger_config;
return Err(Error::DangerousSettings);
}
#[cfg(feature = "dangerous-settings")]
{
builder = builder
.danger_accept_invalid_certs(danger_config.accept_invalid_certs)
.danger_accept_invalid_hostnames(danger_config.accept_invalid_hostnames)
}
}
if let Some(timeout) = connect_timeout {
builder = builder.connect_timeout(Duration::from_millis(timeout));
}
if let Some(max_redirections) = max_redirections {
builder = builder.redirect(if max_redirections == 0 {
Policy::none()
} else {
Policy::limited(max_redirections)
});
}
if let Some(proxy_config) = proxy {
builder = attach_proxy(proxy_config, builder)?;
}
#[cfg(feature = "cookies")]
{
builder = builder.cookie_provider(state.cookies_jar.clone());
}
let mut request = builder.build()?.request(method.clone(), url);
// POST and PUT requests should always have a 0 length content-length,
// if there is no body. https://fetch.spec.whatwg.org/#http-network-or-cache-fetch
if data.is_none() && matches!(method, Method::POST | Method::PUT) {
headers.append(header::CONTENT_LENGTH, HeaderValue::from_str("0")?);
}
if headers.contains_key(header::RANGE) {
// https://fetch.spec.whatwg.org/#http-network-or-cache-fetch step 18
// If httpRequests header list contains `Range`, then append (`Accept-Encoding`, `identity`)
headers.append(header::ACCEPT_ENCODING, HeaderValue::from_str("identity")?);
}
if !headers.contains_key(header::USER_AGENT) {
headers.append(header::USER_AGENT, HeaderValue::from_str(HTTP_USER_AGENT)?);
}
// ensure we have an Origin header set
if cfg!(not(feature = "unsafe-headers")) || !headers.contains_key(header::ORIGIN) {
if let Ok(url) = webview.url() {
headers.append(
header::ORIGIN,
HeaderValue::from_str(&url.origin().ascii_serialization())?,
);
}
}
// In case empty origin is passed, remove it. Some services do not like Origin header
// so this way we can remove it in explicit way. The default behaviour is still to set it
if cfg!(feature = "unsafe-headers")
&& headers.get(header::ORIGIN) == Some(&HeaderValue::from_static(""))
{
headers.remove(header::ORIGIN);
};
if let Some(data) = data {
request = request.body(data);
}
request = request.headers(headers);
#[cfg(feature = "tracing")]
tracing::trace!("{:?}", request);
let fut = async move { request.send().await.map_err(Into::into) };
let mut resources_table = webview.resources_table();
let rid = resources_table.add_request(Box::pin(fut));
Ok(rid)
} else {
Err(Error::UrlNotAllowed(url))
}
}
"data" => {
let data_url =
data_url::DataUrl::process(url.as_str()).map_err(|_| Error::DataUrlError)?;
let (body, _) = data_url
.decode_to_vec()
.map_err(|_| Error::DataUrlDecodeError)?;
let response = http::Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, data_url.mime_type().to_string())
.body(reqwest::Body::from(body))?;
#[cfg(feature = "tracing")]
tracing::trace!("{:?}", response);
let fut = async move { Ok(reqwest::Response::from(response)) };
let mut resources_table = webview.resources_table();
let rid = resources_table.add_request(Box::pin(fut));
Ok(rid)
}
_ => Err(Error::SchemeNotSupport(scheme.to_string())),
}
}
#[command]
pub fn fetch_cancel<R: Runtime>(webview: Webview<R>, rid: ResourceId) -> crate::Result<()> {
let mut resources_table = webview.resources_table();
let req = resources_table.get::<FetchRequest>(rid)?;
let abort_tx = resources_table.take::<AbortSender>(req.abort_tx_rid)?;
if let Some(abort_tx) = Arc::into_inner(abort_tx) {
abort_tx.abort();
}
Ok(())
}
#[command]
pub async fn fetch_send<R: Runtime>(
webview: Webview<R>,
rid: ResourceId,
) -> crate::Result<FetchResponse> {
let (req, abort_rx) = {
let mut resources_table = webview.resources_table();
let req = resources_table.get::<FetchRequest>(rid)?;
let abort_rx = resources_table.take::<AbortRecveiver>(req.abort_rx_rid)?;
(req, abort_rx)
};
let Some(abort_rx) = Arc::into_inner(abort_rx) else {
return Err(Error::RequestCanceled);
};
let mut fut = req.fut.lock().await;
let res = tokio::select! {
res = fut.as_mut() => res?,
_ = abort_rx.0 => {
let mut resources_table = webview.resources_table();
resources_table.close(rid)?;
return Err(Error::RequestCanceled);
}
};
#[cfg(feature = "tracing")]
tracing::trace!("{:?}", res);
let status = res.status();
let url = res.url().to_string();
let mut headers = Vec::new();
for (key, val) in res.headers().iter() {
headers.push((
key.as_str().into(),
String::from_utf8(val.as_bytes().to_vec())?,
));
}
let mut resources_table = webview.resources_table();
let rid = resources_table.add(ReqwestResponse(res));
Ok(FetchResponse {
status: status.as_u16(),
status_text: status.canonical_reason().unwrap_or_default().to_string(),
headers,
url,
rid,
})
}
#[command]
pub async fn fetch_read_body<R: Runtime>(
webview: Webview<R>,
rid: ResourceId,
stream_channel: Channel<tauri::ipc::InvokeResponseBody>,
) -> crate::Result<()> {
let res = {
let mut resources_table = webview.resources_table();
resources_table.take::<ReqwestResponse>(rid)?
};
let mut res = Arc::into_inner(res).unwrap().0;
// send response through IPC channel
while let Some(chunk) = res.chunk().await? {
let mut chunk = chunk.to_vec();
// append 0 to indicate we are not done yet
chunk.push(0);
stream_channel.send(tauri::ipc::InvokeResponseBody::Raw(chunk))?;
}
// send 1 to indicate we are done
stream_channel.send(tauri::ipc::InvokeResponseBody::Raw(vec![1]))?;
Ok(())
}
// forbidden headers per fetch spec https://fetch.spec.whatwg.org/#terminology-headers
#[cfg(not(feature = "unsafe-headers"))]
fn is_unsafe_header(header: &HeaderName) -> bool {
matches!(
*header,
header::ACCEPT_CHARSET
| header::ACCEPT_ENCODING
| header::ACCESS_CONTROL_REQUEST_HEADERS
| header::ACCESS_CONTROL_REQUEST_METHOD
| header::CONNECTION
| header::CONTENT_LENGTH
| header::COOKIE
| header::DATE
| header::DNT
| header::EXPECT
| header::HOST
| header::ORIGIN
| header::REFERER
| header::SET_COOKIE
| header::TE
| header::TRAILER
| header::TRANSFER_ENCODING
| header::UPGRADE
| header::VIA
) || {
let lower = header.as_str().to_lowercase();
lower.starts_with("proxy-") || lower.starts_with("sec-")
}
}