diff --git a/plugins/upload/guest-js/index.ts b/plugins/upload/guest-js/index.ts index 1a605633..f13a1627 100644 --- a/plugins/upload/guest-js/index.ts +++ b/plugins/upload/guest-js/index.ts @@ -1,30 +1,38 @@ -import { invoke } from "@tauri-apps/api/tauri"; -import { appWindow } from "tauri-plugin-window-api"; +import { invoke, transformCallback } from "@tauri-apps/api/tauri"; interface ProgressPayload { - id: number; progress: number; total: number; } -type ProgressHandler = (progress: number, total: number) => void; -const handlers: Map = new Map(); -let listening = false; +type ProgressHandler = (progress: ProgressPayload) => void; -async function listenToEventIfNeeded(event: string): Promise { - if (listening) { - return await Promise.resolve(); - } - return await appWindow - .listen(event, ({ payload }) => { - const handler = handlers.get(payload.id); - if (handler != null) { - handler(payload.progress, payload.total); - } - }) - .then(() => { - listening = true; +// TODO: use channel from @tauri-apps/api on v2 +class Channel { + id: number; + // @ts-expect-error field used by the IPC serializer + private readonly __TAURI_CHANNEL_MARKER__ = true; + #onmessage: (response: T) => void = () => { + // no-op + }; + + constructor() { + this.id = transformCallback((response: T) => { + this.#onmessage(response); }); + } + + set onmessage(handler: (response: T) => void) { + this.#onmessage = handler; + } + + get onmessage(): (response: T) => void { + return this.#onmessage; + } + + toJSON(): string { + return `__CHANNEL__:${this.id}`; + } } async function upload( @@ -37,17 +45,17 @@ async function upload( window.crypto.getRandomValues(ids); const id = ids[0]; + const onProgress = new Channel(); if (progressHandler != null) { - handlers.set(id, progressHandler); + onProgress.onmessage = progressHandler; } - await listenToEventIfNeeded("upload://progress"); - await invoke("plugin:upload|upload", { id, url, filePath, headers: headers ?? {}, + onProgress, }); } @@ -65,17 +73,17 @@ async function download( window.crypto.getRandomValues(ids); const id = ids[0]; + const onProgress = new Channel(); if (progressHandler != null) { - handlers.set(id, progressHandler); + onProgress.onmessage = progressHandler; } - await listenToEventIfNeeded("download://progress"); - await invoke("plugin:upload|download", { id, url, filePath, headers: headers ?? {}, + onProgress, }); } diff --git a/plugins/upload/src/lib.rs b/plugins/upload/src/lib.rs index c4a0d8c7..714fd1ca 100644 --- a/plugins/upload/src/lib.rs +++ b/plugins/upload/src/lib.rs @@ -5,16 +5,17 @@ use futures_util::TryStreamExt; use serde::{ser::Serializer, Serialize}; use tauri::{ + api::ipc::Channel, command, plugin::{Builder as PluginBuilder, TauriPlugin}, - Runtime, Window, + Runtime, }; use tokio::{fs::File, io::AsyncWriteExt}; use tokio_util::codec::{BytesCodec, FramedRead}; use read_progress_stream::ReadProgressStream; -use std::{collections::HashMap, sync::Mutex}; +use std::collections::HashMap; type Result = std::result::Result; @@ -39,19 +40,17 @@ impl Serialize for Error { #[derive(Clone, Serialize)] struct ProgressPayload { - id: u32, progress: u64, total: u64, } #[command] async fn download( - window: Window, - id: u32, url: &str, file_path: &str, headers: HashMap, -) -> Result { + on_progress: Channel, +) -> Result<()> { let client = reqwest::Client::new(); let mut request = client.get(url); @@ -69,33 +68,28 @@ async fn download( while let Some(chunk) = stream.try_next().await? { file.write_all(&chunk).await?; - let _ = window.emit( - "download://progress", - ProgressPayload { - id, - progress: chunk.len() as u64, - total, - }, - ); + let _ = on_progress.send(&ProgressPayload { + progress: chunk.len() as u64, + total, + }); } - Ok(id) + Ok(()) } #[command] async fn upload( - window: Window, - id: u32, url: &str, file_path: &str, headers: HashMap, + on_progress: Channel, ) -> Result { // Read the file let file = File::open(file_path).await?; // Create the request and attach the file to the body let client = reqwest::Client::new(); - let mut request = client.post(url).body(file_to_body(id, window, file)); + let mut request = client.post(url).body(file_to_body(on_progress, file)); // Loop trought the headers keys and values // and add them to the request object. @@ -108,20 +102,13 @@ async fn upload( response.json().await.map_err(Into::into) } -fn file_to_body(id: u32, window: Window, file: File) -> reqwest::Body { +fn file_to_body(channel: Channel, file: File) -> reqwest::Body { let stream = FramedRead::new(file, BytesCodec::new()).map_ok(|r| r.freeze()); - let window = Mutex::new(window); + reqwest::Body::wrap_stream(ReadProgressStream::new( stream, Box::new(move |progress, total| { - let _ = window.lock().unwrap().emit( - "upload://progress", - ProgressPayload { - id, - progress, - total, - }, - ); + let _ = channel.send(&ProgressPayload { progress, total }); }), )) }