Implement AsyncRead for segment::Subscriber

Untested.
This commit is contained in:
Luke Curley 2023-09-17 10:53:20 -07:00
parent 639f916b6a
commit ddf22012e0
2 changed files with 99 additions and 17 deletions

View File

@ -1,3 +1,5 @@
use std::io;
use thiserror::Error;
use crate::VarInt;
@ -73,4 +75,19 @@ impl Error {
Self::Write => "write error",
}
}
/// Crudely tries to convert the Error into an io::Error.
pub fn as_io(&self) -> io::Error {
match self {
Self::Closed => io::ErrorKind::ConnectionAborted.into(),
Self::Reset(_) => io::ErrorKind::ConnectionReset.into(),
Self::Stop => io::ErrorKind::ConnectionAborted.into(),
Self::NotFound => io::ErrorKind::NotFound.into(),
Self::Duplicate => io::ErrorKind::AlreadyExists.into(),
Self::Role(_) => io::ErrorKind::PermissionDenied.into(),
Self::Unknown => io::ErrorKind::Other.into(),
Self::Read => io::ErrorKind::BrokenPipe.into(),
Self::Write => io::ErrorKind::BrokenPipe.into(),
}
}
}

View File

@ -9,10 +9,19 @@
//!
//! The segment is closed with [Error::Closed] when all publishers or subscribers are dropped.
use core::fmt;
use std::{ops::Deref, sync::Arc, time};
use std::{
future::{poll_fn, Future},
io,
ops::Deref,
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
time,
};
use crate::{Error, VarInt};
use bytes::Bytes;
use bytes::{Bytes, BytesMut};
use tokio::pin;
use super::Watch;
@ -140,6 +149,9 @@ pub struct Subscriber {
// NOTE: Cloned subscribers inherit this index, but then run in parallel.
index: usize,
// A temporary buffer when using AsyncRead.
buffer: BytesMut,
// Dropped when all Subscribers are dropped.
_dropped: Arc<Dropped>,
}
@ -152,30 +164,83 @@ impl Subscriber {
state,
info,
index: 0,
buffer: BytesMut::new(),
_dropped,
}
}
/// Check if there is a chunk available.
pub fn poll_chunk(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<Bytes>, Error>> {
if !self.buffer.is_empty() {
let chunk = self.buffer.split().freeze();
return Poll::Ready(Ok(Some(chunk)));
}
let state = self.state.lock();
if self.index < state.data.len() {
let chunk = state.data[self.index].clone();
self.index += 1;
return Poll::Ready(Ok(Some(chunk)));
}
let notify = match state.closed {
Err(Error::Closed) => return Poll::Ready(Ok(None)),
Err(err) => return Poll::Ready(Err(err)),
Ok(()) => state.changed(), // Wake up when the state changes
};
// Register our context with the notify waker.
pin!(notify);
let _ = notify.poll(cx);
Poll::Pending
}
/// Block until the next chunk of bytes is available.
pub async fn read_chunk(&mut self) -> Result<Option<Bytes>, Error> {
loop {
let notify = {
let state = self.state.lock();
if self.index < state.data.len() {
let chunk = state.data[self.index].clone();
self.index += 1;
return Ok(Some(chunk));
}
poll_fn(|cx| self.poll_chunk(cx)).await
}
}
match state.closed {
Err(Error::Closed) => return Ok(None),
Err(err) => return Err(err),
Ok(()) => state.changed(),
}
};
impl tokio::io::AsyncRead for Subscriber {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if !self.buffer.is_empty() {
// Read from the existing buffer
let size = std::cmp::min(buf.remaining(), self.buffer.len());
let data = self.buffer.split_to(size).freeze();
buf.put_slice(&data);
notify.await; // Try again when the state changes
return Poll::Ready(Ok(()));
}
// Check if there's a new chunk available
let chunk = ready!(self.poll_chunk(cx));
let chunk = match chunk {
// We'll read as much of it as we can, and buffer the rest.
Ok(Some(chunk)) => chunk,
// No more data.
Ok(None) => return Poll::Ready(Ok(())),
// TODO cast to io::Error
Err(err) => return Poll::Ready(Err(err.as_io())),
};
// Determine how much of the chunk we can return vs buffer.
let size = std::cmp::min(buf.remaining(), chunk.len());
// Return this much.
buf.put_slice(chunk[..size].as_ref());
// Buffer this much.
self.buffer.extend_from_slice(chunk[size..].as_ref());
Poll::Ready(Ok(()))
}
}