Luke Curley 975d6b2580 Fix the buffering used for parsing.
fill_buf didn't work like I expected. This code is much better anyway.
2023-08-02 11:25:41 -07:00

151 lines
4.2 KiB
Rust

use std::{collections::BinaryHeap, io::Cursor, sync::Arc};
use anyhow::Context;
use bytes::{Buf, BytesMut};
use moq_transport::{Decode, DecodeError, Encode, Object};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use webtransport_quinn::Session;
use crate::{RecvStream, SendStream, SendStreamOrder};
/// Handle used to open outgoing object streams on a session.
///
/// Allow this to be cloned so we can have multiple senders; every clone shares
/// the same inner state through the `Arc`.
#[derive(Clone)]
pub struct SendObjects {
    // This is a tokio mutex since we need to lock across await boundaries.
    inner: Arc<Mutex<SendObjectsInner>>,
}
impl SendObjects {
    /// Wraps `session` in shared, mutex-guarded state so clones of the
    /// returned handle all operate on the same stream bookkeeping.
    pub fn new(session: Session) -> Self {
        Self {
            inner: Arc::new(Mutex::new(SendObjectsInner::new(session))),
        }
    }

    /// Opens a new outgoing stream for `header`.
    ///
    /// The shared state stays locked until the inner `open` call completes,
    /// so concurrent callers are served one at a time.
    ///
    /// # Errors
    /// Returns whatever error the inner `open` call reports.
    pub async fn open(&mut self, header: Object) -> anyhow::Result<SendStream> {
        let mut guard = self.inner.lock().await;
        guard.open(header).await
    }
}
/// Mutable state behind the `SendObjects` mutex.
struct SendObjectsInner {
    // Session on which the unidirectional streams are opened.
    session: Session,

    // Quinn supports an i32 for priority, but the wire format is a u64.
    // Our workaround is to keep a list of streams in priority order and use the index as the priority.
    // This involves more work, so TODO either increase the Quinn size or reduce the wire size.
    ordered: BinaryHeap<SendStreamOrder>,

    // Scratch heap that is swapped with `ordered` after each re-ranking pass;
    // reuse memory to avoid allocations.
    ordered_swap: BinaryHeap<SendStreamOrder>,

    // A reusable buffer for encoding headers.
    // TODO figure out how to use BufMut on the stack and remove this.
    buf: BytesMut,
}
impl SendObjectsInner {
    /// Creates the state for `session` with empty priority heaps and an empty
    /// header buffer.
    fn new(session: Session) -> Self {
        Self {
            session,
            ordered: BinaryHeap::new(),
            ordered_swap: BinaryHeap::new(),
            buf: BytesMut::new(),
        }
    }

    /// Opens a unidirectional stream, re-ranks every tracked stream, and
    /// writes the encoded `header` as the first bytes of the new stream.
    ///
    /// # Errors
    /// Fails if the stream cannot be opened, the header cannot be encoded, or
    /// the header cannot be written to the stream.
    pub async fn open(&mut self, header: Object) -> anyhow::Result<SendStream> {
        let stream = self.session.open_uni().await.context("failed to open uni stream")?;
        let (mut stream, priority) = SendStream::with_order(stream, header.send_order.into_inner());

        // Add the priority to our existing list.
        self.ordered.push(priority);

        // Drain the heap (greatest element first) and hand each entry its
        // position as the new priority. Entries whose update fails are dropped
        // from the list; the rest land in the spare heap.
        let mut index = 0;
        while let Some(entry) = self.ordered.pop() {
            if entry.update(index).is_ok() {
                self.ordered_swap.push(entry);
                index += 1;
            }
        }

        // Swap the lists so we can reuse the memory; `ordered_swap` is now the
        // drained (empty) heap, ready for the next call.
        std::mem::swap(&mut self.ordered, &mut self.ordered_swap);

        // Encode and write the stream header.
        // TODO do this in SendStream so we don't hold the lock.
        // Otherwise every other sender waits on the mutex until this write completes.
        self.buf.clear();
        header
            .encode(&mut self.buf)
            // Report the failure to the caller instead of panicking mid-call.
            .map_err(|err| anyhow::anyhow!("failed to encode header: {:?}", err))?;
        stream.write_all(&self.buf).await.context("failed to write header")?;

        Ok(stream)
    }
}
/// Accepts incoming unidirectional streams on a session and reads their headers.
///
/// Not clone, so we don't accidentally have two listeners.
pub struct RecvObjects {
    // Session on which incoming unidirectional streams are accepted.
    session: Session,

    // Streams that we've accepted but haven't read the header from yet.
    streams: JoinSet<anyhow::Result<(Object, RecvStream)>>,
}
impl RecvObjects {
    /// Creates a receiver for `session` with no header reads in flight.
    pub fn new(session: Session) -> Self {
        Self {
            session,
            streams: JoinSet::new(),
        }
    }

    /// Waits for the next stream whose header has been parsed.
    ///
    /// Every accepted stream gets its own task for reading the header, so the
    /// result returned is whichever header read finishes first.
    ///
    /// # Errors
    /// Fails if accepting a stream fails, if a header task cannot be joined,
    /// or if a header task itself returns an error.
    pub async fn recv(&mut self) -> anyhow::Result<(Object, RecvStream)> {
        loop {
            tokio::select! {
                res = self.session.accept_uni() => {
                    let stream = res.context("failed to accept stream")?;
                    self.streams.spawn(Self::read(stream));
                },
                // Only poll the join set while it holds tasks, which is also
                // what makes the `expect` below an invariant.
                res = self.streams.join_next(), if !self.streams.is_empty() => {
                    return res
                        .expect("join_next is only polled while the join set is non-empty")
                        .context("failed to run join set")?;
                }
            }
        }
    }

    /// Reads from `stream` until a complete `Object` header can be decoded.
    ///
    /// Returns the header together with a `RecvStream` built from the bytes
    /// that were read past the header and the underlying stream.
    ///
    /// # Errors
    /// Fails if reading fails, if the header is malformed, or if the stream
    /// ends before a complete header has arrived.
    async fn read(mut stream: webtransport_quinn::RecvStream) -> anyhow::Result<(Object, RecvStream)> {
        let mut buf = BytesMut::new();

        loop {
            // Read more data into the buffer.
            let buffered = buf.len();
            stream
                .read_buf(&mut buf)
                .await
                .context("failed to read stream header")?;

            // A read that appends nothing means the stream is finished. Bail
            // out rather than retrying the decode forever on the same bytes.
            if buf.len() == buffered {
                anyhow::bail!("stream closed before the header was fully received");
            }

            // Use a cursor to read the buffer and remember how much we read.
            let mut read = Cursor::new(&mut buf);

            let header = match Object::decode(&mut read) {
                Ok(header) => header,
                // Not enough bytes yet: keep what we have and read more.
                Err(DecodeError::UnexpectedEnd) => continue,
                Err(err) => return Err(err.into()),
            };

            // We parsed a full header, advance the buffer.
            // The cursor position never exceeds the buffer length, so it fits in a usize.
            let size = read.position() as usize;
            buf.advance(size);

            // Whatever remains was read past the header; hand it to the stream wrapper.
            let buf = buf.freeze();
            let stream = RecvStream::new(buf, stream);

            return Ok((header, stream));
        }
    }
}