diff --git a/moq-transport/src/error.rs b/moq-transport/src/error.rs index e147e23..7804de9 100644 --- a/moq-transport/src/error.rs +++ b/moq-transport/src/error.rs @@ -1,3 +1,5 @@ +use std::io; + use thiserror::Error; use crate::VarInt; @@ -73,4 +75,19 @@ impl Error { Self::Write => "write error", } } + + /// Crudely tries to convert the Error into an io::Error. + pub fn as_io(&self) -> io::Error { + match self { + Self::Closed => io::ErrorKind::ConnectionAborted.into(), + Self::Reset(_) => io::ErrorKind::ConnectionReset.into(), + Self::Stop => io::ErrorKind::ConnectionAborted.into(), + Self::NotFound => io::ErrorKind::NotFound.into(), + Self::Duplicate => io::ErrorKind::AlreadyExists.into(), + Self::Role(_) => io::ErrorKind::PermissionDenied.into(), + Self::Unknown => io::ErrorKind::Other.into(), + Self::Read => io::ErrorKind::BrokenPipe.into(), + Self::Write => io::ErrorKind::BrokenPipe.into(), + } + } } diff --git a/moq-transport/src/model/segment.rs b/moq-transport/src/model/segment.rs index d2db43a..9f21448 100644 --- a/moq-transport/src/model/segment.rs +++ b/moq-transport/src/model/segment.rs @@ -9,10 +9,19 @@ //! //! The segment is closed with [Error::Closed] when all publishers or subscribers are dropped. use core::fmt; -use std::{ops::Deref, sync::Arc, time}; +use std::{ + future::{poll_fn, Future}, + io, + ops::Deref, + pin::Pin, + sync::Arc, + task::{ready, Context, Poll}, + time, +}; use crate::{Error, VarInt}; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; +use tokio::pin; use super::Watch; @@ -140,6 +149,9 @@ pub struct Subscriber { // NOTE: Cloned subscribers inherit this index, but then run in parallel. index: usize, + // A temporary buffer when using AsyncRead. + buffer: BytesMut, + // Dropped when all Subscribers are dropped. _dropped: Arc, } @@ -152,30 +164,83 @@ impl Subscriber { state, info, index: 0, + buffer: BytesMut::new(), _dropped, } } + /// Check if there is a chunk available. + pub fn poll_chunk(&mut self, cx: &mut Context<'_>) -> Poll, Error>> { + if !self.buffer.is_empty() { + let chunk = self.buffer.split().freeze(); + return Poll::Ready(Ok(Some(chunk))); + } + + let state = self.state.lock(); + + if self.index < state.data.len() { + let chunk = state.data[self.index].clone(); + self.index += 1; + return Poll::Ready(Ok(Some(chunk))); + } + + let notify = match state.closed { + Err(Error::Closed) => return Poll::Ready(Ok(None)), + Err(err) => return Poll::Ready(Err(err)), + Ok(()) => state.changed(), // Wake up when the state changes + }; + + // Register our context with the notify waker. + pin!(notify); + let _ = notify.poll(cx); + + Poll::Pending + } + /// Block until the next chunk of bytes is available. pub async fn read_chunk(&mut self) -> Result, Error> { - loop { - let notify = { - let state = self.state.lock(); - if self.index < state.data.len() { - let chunk = state.data[self.index].clone(); - self.index += 1; - return Ok(Some(chunk)); - } + poll_fn(|cx| self.poll_chunk(cx)).await + } +} - match state.closed { - Err(Error::Closed) => return Ok(None), - Err(err) => return Err(err), - Ok(()) => state.changed(), - } - }; +impl tokio::io::AsyncRead for Subscriber { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + if !self.buffer.is_empty() { + // Read from the existing buffer + let size = std::cmp::min(buf.remaining(), self.buffer.len()); + let data = self.buffer.split_to(size).freeze(); + buf.put_slice(&data); - notify.await; // Try again when the state changes + return Poll::Ready(Ok(())); } + + // Check if there's a new chunk available + let chunk = ready!(self.poll_chunk(cx)); + let chunk = match chunk { + // We'll read as much of it as we can, and buffer the rest. + Ok(Some(chunk)) => chunk, + + // No more data. + Ok(None) => return Poll::Ready(Ok(())), + + // TODO cast to io::Error + Err(err) => return Poll::Ready(Err(err.as_io())), + }; + + // Determine how much of the chunk we can return vs buffer. + let size = std::cmp::min(buf.remaining(), chunk.len()); + + // Return this much. + buf.put_slice(chunk[..size].as_ref()); + + // Buffer this much. + self.buffer.extend_from_slice(chunk[size..].as_ref()); + + Poll::Ready(Ok(())) } }