From 54a00732752044fdecdd33ec04ec72e0d56b9db5 Mon Sep 17 00:00:00 2001 From: iximeow Date: Thu, 29 Dec 2022 13:31:22 -0800 Subject: use fixed AsyncWrite impl --- src/io.rs | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) (limited to 'src/io.rs') diff --git a/src/io.rs b/src/io.rs index 245c1ef..219edbf 100644 --- a/src/io.rs +++ b/src/io.rs @@ -3,6 +3,8 @@ use futures_util::StreamExt; use tokio::fs::File; use std::io::Write; use tokio::fs::OpenOptions; +use std::task::{Poll, Context}; +use std::pin::Pin; pub struct ArtifactStream { sender: hyper::body::Sender, @@ -14,6 +16,39 @@ impl ArtifactStream { } } +impl tokio::io::AsyncWrite for ArtifactStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8] + ) -> Poll> { + match self.get_mut().sender.try_send_data(buf.to_vec().into()) { + Ok(()) => { + Poll::Ready(Ok(buf.len())) + }, + _ => { + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context + ) -> Poll> { + Poll::Ready(Ok(())) + } +} + + pub struct ArtifactDescriptor { job_id: u64, artifact_id: u64, @@ -63,6 +98,22 @@ impl ArtifactDescriptor { } } +pub async fn forward_data(source: &mut (impl AsyncRead + Unpin), dest: &mut (impl AsyncWrite + Unpin)) -> Result<(), String> { + let mut buf = vec![0; 1024 * 1024]; + loop { + let n_read = source.read(&mut buf).await + .map_err(|e| format!("failed to read: {:?}", e))?; + + if n_read == 0 { + eprintln!("done reading!"); + return Ok(()); + } + + dest.write_all(&buf[..n_read]).await + .map_err(|e| format!("failed to write: {:?}", e))?; + } +} +/* pub async fn forward_data(source: &mut (impl AsyncRead + Unpin), dest: &mut ArtifactStream) -> Result<(), String> { let mut buf = vec![0; 1024 * 1024]; loop { @@ -78,3 +129,4 @@ pub async fn forward_data(source: &mut (impl AsyncRead + Unpin), dest: &mut Arti .map_err(|e| format!("failed to write: {:?}", e))?; } } +*/ -- cgit v1.1