diff --git a/Cargo.lock b/Cargo.lock index 995c672..aa74566 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -782,8 +782,14 @@ name = "moq-warp" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", + "log", "moq-transport", "mp4", + "quinn", + "ring", + "rustls 0.21.2", + "rustls-pemfile", "tokio", ] diff --git a/moq-demo/src/main.rs b/moq-demo/src/main.rs index 7bac619..7f9c7bd 100644 --- a/moq-demo/src/main.rs +++ b/moq-demo/src/main.rs @@ -1,19 +1,11 @@ -mod server; -mod session; - -use server::*; - use std::{fs, io, net, path, sync}; -use std::collections::HashMap; -use std::sync::Arc; - use anyhow::Context; use clap::Parser; use ring::digest::{digest, SHA256}; use warp::Filter; -use moq_warp::Source; +use moq_warp::{relay, source}; /// Search for a pattern in a file and display the lines that contain it. #[derive(Parser, Clone)] @@ -45,20 +37,22 @@ async fn main() -> anyhow::Result<()> { let serve = serve_http(args.clone()); // Create a fake media source from disk. - let mut media = Source::new(args.media).context("failed to open fragmented.mp4")?; + let media = source::File::new(args.media).context("failed to open file source")?; - let mut broadcasts = HashMap::new(); - broadcasts.insert("demo".to_string(), media.broadcast()); + let broker = relay::broker::Broadcasts::new(); + broker + .announce("demo", media.source()) + .context("failed to announce file source")?; // Create a server to actually serve the media - let config = ServerConfig { + let config = relay::ServerConfig { addr: args.addr, cert: args.cert, key: args.key, - broadcasts: Arc::new(broadcasts), + broker, }; - let mut server = Server::new(config).context("failed to create server")?; + let server = relay::Server::new(config).context("failed to create server")?; // Run all of the above tokio::select! { diff --git a/moq-demo/src/session.rs b/moq-demo/src/session.rs deleted file mode 100644 index e726cbb..0000000 --- a/moq-demo/src/session.rs +++ /dev/null @@ -1,226 +0,0 @@ -use anyhow::Context; - -use tokio::io::AsyncWriteExt; -use tokio::task::JoinSet; - -use std::sync::Arc; - -use moq_transport::coding::VarInt; -use moq_transport::{control, data, server, setup}; -use moq_warp::{Broadcasts, Segment, Track}; - -pub struct Session { - // Used to send/receive data streams. - transport: Arc, - - // Used to send/receive control messages. - control: control::Stream, - - // The list of available broadcasts for the session. - media: Broadcasts, - - // The list of active subscriptions. - tasks: JoinSet>, -} - -impl Session { - pub async fn accept(session: server::Accept, media: Broadcasts) -> anyhow::Result { - // Accep the WebTransport session. - // OPTIONAL validate the conn.uri() otherwise call conn.reject() - let session = session - .accept() - .await - .context("failed to accept WebTransport session")?; - - session - .setup() - .versions - .iter() - .find(|v| **v == setup::Version::DRAFT_00) - .context("failed to find supported version")?; - - match session.setup().role { - setup::Role::Subscriber => {} - _ => anyhow::bail!("TODO publishing not yet supported"), - } - - let setup = setup::Server { - version: setup::Version::DRAFT_00, - role: setup::Role::Publisher, - }; - - let (transport, control) = session.accept(setup).await?; - - let session = Self { - transport: Arc::new(transport), - control, - media, - tasks: JoinSet::new(), - }; - - Ok(session) - } - - pub async fn serve(mut self) -> anyhow::Result<()> { - // TODO fix lazy: make a copy of the strings to avoid the borrow checker on self. - let broadcasts: Vec = self.media.keys().cloned().collect(); - - // Announce each available broadcast immediately. - for name in broadcasts { - self.send_message(control::Announce { - track_namespace: name.clone(), - }) - .await?; - } - - loop { - tokio::select! { - msg = self.control.recv() => { - let msg = msg.context("failed to receive control message")?; - self.handle_message(msg).await?; - }, - res = self.tasks.join_next(), if !self.tasks.is_empty() => { - let res = res.expect("no tasks").expect("task aborted"); - if let Err(err) = res { - log::error!("failed to serve subscription: {:?}", err); - } - } - } - } - } - - async fn handle_message(&mut self, msg: control::Message) -> anyhow::Result<()> { - log::info!("received message: {:?}", msg); - - // TODO implement publish and subscribe halves of the protocol. - match msg { - control::Message::Announce(_) => anyhow::bail!("ANNOUNCE not supported"), - control::Message::AnnounceOk(ref _ok) => Ok(()), // noop - control::Message::AnnounceError(ref err) => { - anyhow::bail!("received ANNOUNCE_ERROR({:?}): {}", err.code, err.reason) - } - control::Message::Subscribe(ref sub) => self.receive_subscribe(sub).await, - control::Message::SubscribeOk(_) => anyhow::bail!("SUBSCRIBE OK not supported"), - control::Message::SubscribeError(_) => anyhow::bail!("SUBSCRIBE ERROR not supported"), - control::Message::GoAway(_) => anyhow::bail!("goaway not supported"), - } - } - - async fn send_message>(&mut self, msg: T) -> anyhow::Result<()> { - let msg = msg.into(); - log::info!("sending message: {:?}", msg); - self.control.send(msg).await - } - - async fn receive_subscribe(&mut self, sub: &control::Subscribe) -> anyhow::Result<()> { - match self.subscribe(sub) { - Ok(()) => { - self.send_message(control::SubscribeOk { - track_id: sub.track_id, - expires: None, - }) - .await - } - Err(e) => { - self.send_message(control::SubscribeError { - track_id: sub.track_id, - code: VarInt::from_u32(1), - reason: e.to_string(), - }) - .await - } - } - } - - fn subscribe(&mut self, sub: &control::Subscribe) -> anyhow::Result<()> { - let broadcast = self - .media - .get(&sub.track_namespace) - .context("unknown track namespace")?; - - let track = broadcast - .tracks - .get(&sub.track_name) - .context("unknown track name")? - .clone(); - - let track_id = sub.track_id; - - let sub = Subscription { - track, - track_id, - transport: self.transport.clone(), - }; - - self.tasks.spawn(async move { sub.serve().await }); - - Ok(()) - } -} - -pub struct Subscription { - transport: Arc, - track_id: VarInt, - track: Track, -} - -impl Subscription { - pub async fn serve(mut self) -> anyhow::Result<()> { - let mut tasks = JoinSet::new(); - let mut done = false; - - loop { - tokio::select! { - // Accept new tracks added to the broadcast. - segment = self.track.segments.next(), if !done => { - match segment { - Some(segment) => { - let group = Group { - segment, - transport: self.transport.clone(), - track_id: self.track_id, - }; - - tasks.spawn(async move { group.serve().await }); - }, - None => done = true, // no more segments in the track - } - }, - // Poll any pending segments until they exit. - res = tasks.join_next(), if !tasks.is_empty() => { - let res = res.expect("no tasks").expect("task aborted"); - res.context("failed serve segment")? - }, - else => return Ok(()), // all segments received and finished serving - } - } - } -} - -struct Group { - transport: Arc, - track_id: VarInt, - segment: Segment, -} - -impl Group { - pub async fn serve(mut self) -> anyhow::Result<()> { - let header = data::Header { - track_id: self.track_id, - group_sequence: self.segment.sequence, - object_sequence: VarInt::from_u32(0), // Always zero since we send an entire group as an object - send_order: self.segment.send_order, - }; - - let mut stream = self.transport.send(header).await?; - - // Write each fragment as they are available. - while let Some(fragment) = self.segment.fragments.next().await { - stream.write_all(fragment.as_slice()).await?; - } - - // NOTE: stream is automatically closed when dropped - - Ok(()) - } -} diff --git a/moq-transport/src/control/stream.rs b/moq-transport/src/control/stream.rs index 6d5eb57..170a667 100644 --- a/moq-transport/src/control/stream.rs +++ b/moq-transport/src/control/stream.rs @@ -4,6 +4,8 @@ use crate::control::Message; use bytes::Bytes; use h3::quic::BidiStream; +use std::sync::Arc; +use tokio::sync::Mutex; pub struct Stream { sender: SendStream, @@ -13,8 +15,8 @@ pub struct Stream { impl Stream { pub(crate) fn new(stream: h3_webtransport::stream::BidiStream, Bytes>) -> Self { let (sender, recver) = stream.split(); - let sender = SendStream::new(sender); - let recver = RecvStream::new(recver); + let sender = SendStream { stream: sender }; + let recver = RecvStream { stream: recver }; Self { sender, recver } } @@ -37,12 +39,30 @@ pub struct SendStream { } impl SendStream { - pub(crate) fn new(stream: h3_webtransport::stream::SendStream, Bytes>) -> Self { - Self { stream } + pub async fn send>(&mut self, msg: T) -> anyhow::Result<()> { + let msg = msg.into(); + log::info!("sending message: {:?}", msg); + msg.encode(&mut self.stream).await } - pub async fn send(&mut self, msg: Message) -> anyhow::Result<()> { - msg.encode(&mut self.stream).await + // Helper that lets multiple threads send control messages. + pub fn share(self) -> SendShared { + SendShared { + stream: Arc::new(Mutex::new(self)), + } + } +} + +// Helper that allows multiple threads to send control messages. +#[derive(Clone)] +pub struct SendShared { + stream: Arc>, +} + +impl SendShared { + pub async fn send>(&mut self, msg: T) -> anyhow::Result<()> { + let mut stream = self.stream.lock().await; + stream.send(msg).await } } @@ -51,11 +71,9 @@ pub struct RecvStream { } impl RecvStream { - pub(crate) fn new(stream: h3_webtransport::stream::RecvStream) -> Self { - Self { stream } - } - pub async fn recv(&mut self) -> anyhow::Result { - Message::decode(&mut self.stream).await + let msg = Message::decode(&mut self.stream).await?; + log::info!("received message: {:?}", msg); + Ok(msg) } } diff --git a/moq-transport/src/lib.rs b/moq-transport/src/lib.rs index 5519519..eb66a07 100644 --- a/moq-transport/src/lib.rs +++ b/moq-transport/src/lib.rs @@ -1,6 +1,6 @@ pub mod coding; pub mod control; -pub mod data; +pub mod object; pub mod server; pub mod setup; diff --git a/moq-transport/src/data/header.rs b/moq-transport/src/object/header.rs similarity index 100% rename from moq-transport/src/data/header.rs rename to moq-transport/src/object/header.rs diff --git a/moq-transport/src/data/mod.rs b/moq-transport/src/object/mod.rs similarity index 100% rename from moq-transport/src/data/mod.rs rename to moq-transport/src/object/mod.rs diff --git a/moq-transport/src/object/recv.rs b/moq-transport/src/object/recv.rs new file mode 100644 index 0000000..f9526e8 --- /dev/null +++ b/moq-transport/src/object/recv.rs @@ -0,0 +1,26 @@ +use super::Header; + +use std::sync::Arc; + +// Reduce some typing for implementors. +pub type RecvStream = h3_webtransport::stream::RecvStream; + +// Not clone, so we don't accidentally have two listners. +pub struct Receiver { + transport: Arc, +} + +impl Receiver { + pub async fn recv(&mut self) -> anyhow::Result<(Header, RecvStream)> { + let (_session_id, mut stream) = self + .transport + .accept_uni() + .await + .context("failed to accept uni stream")? + .context("no uni stream")?; + + let header = Header::decode(&mut stream).await?; + + Ok((header, stream)) + } +} diff --git a/moq-transport/src/object/send.rs b/moq-transport/src/object/send.rs new file mode 100644 index 0000000..9f1bc87 --- /dev/null +++ b/moq-transport/src/object/send.rs @@ -0,0 +1,24 @@ +use super::{Header, SendStream, WebTransportSession}; + +pub type SendStream = h3_webtransport::stream::SendStream, Bytes>; + +#[derive(Clone)] +pub struct Sender { + transport: Arc, +} + +impl Sender { + pub async fn send(&self, header: Header) -> anyhow::Result { + let mut stream = self + .transport + .open_uni(self.transport.session_id()) + .await + .context("failed to open uni stream")?; + + // TODO set send_order based on header + + header.encode(&mut stream).await?; + + Ok(stream) + } +} diff --git a/moq-transport/src/object/session.rs b/moq-transport/src/object/session.rs new file mode 100644 index 0000000..61a7dff --- /dev/null +++ b/moq-transport/src/object/session.rs @@ -0,0 +1,35 @@ +use super::{Header, Receiver, RecvStream, SendStream, Sender}; + +use anyhow::Context; +use bytes::Bytes; + +use crate::coding::{Decode, Encode}; + +use std::sync::Arc; + +// TODO support clients +type WebTransportSession = h3_webtransport::server::WebTransportSession; + +pub struct Session { + pub send: Sender, + pub recv: Receiver, +} + +impl Session { + pub fn new(transport: WebTransportSession) -> Self { + let shared = Arc::new(transport); + + Self { + send: Sender::new(shared.clone()), + recv: Sender::new(shared), + } + } + + pub async fn recv(&mut self) -> anyhow::Result<(Header, RecvStream)> { + self.recv.recv().await + } + + pub async fn send(&self, header: Header) -> anyhow::Result { + self.send.send(header).await + } +} diff --git a/moq-transport/src/data/transport.rs b/moq-transport/src/object/transport.rs similarity index 88% rename from moq-transport/src/data/transport.rs rename to moq-transport/src/object/transport.rs index 832cd82..f59c2cf 100644 --- a/moq-transport/src/data/transport.rs +++ b/moq-transport/src/object/transport.rs @@ -20,6 +20,7 @@ impl Transport { Self { transport } } + // TODO This should be &mut self to prevent multiple threads trying to read objects pub async fn recv(&self) -> anyhow::Result<(Header, RecvStream)> { let (_session_id, mut stream) = self .transport @@ -33,6 +34,7 @@ impl Transport { Ok((header, stream)) } + // This can be &self since threads can create streams in parallel. pub async fn send(&self, header: Header) -> anyhow::Result { let mut stream = self .transport diff --git a/moq-transport/src/server/handshake.rs b/moq-transport/src/server/handshake.rs index da5816f..f4416c7 100644 --- a/moq-transport/src/server/handshake.rs +++ b/moq-transport/src/server/handshake.rs @@ -1,5 +1,5 @@ use super::setup::{RecvSetup, SendSetup}; -use crate::{control, data, setup}; +use crate::{control, object, setup}; use anyhow::Context; use bytes::Bytes; @@ -71,7 +71,7 @@ impl Accept { .context("failed to accept bidi stream")? .unwrap(); - let transport = data::Transport::new(transport); + let transport = object::Transport::new(transport); let stream = match stream { h3_webtransport::server::AcceptedBi::BidiStream(_session_id, stream) => stream, @@ -92,7 +92,7 @@ impl Accept { pub struct Setup { setup: SendSetup, - transport: data::Transport, + transport: object::Transport, } impl Setup { @@ -102,7 +102,7 @@ impl Setup { } // Accept the session with our own setup message. - pub async fn accept(self, setup: setup::Server) -> anyhow::Result<(data::Transport, control::Stream)> { + pub async fn accept(self, setup: setup::Server) -> anyhow::Result<(object::Transport, control::Stream)> { let control = self.setup.send(setup).await?; Ok((self.transport, control)) } diff --git a/moq-warp/Cargo.toml b/moq-warp/Cargo.toml index 906218a..48f3ab9 100644 --- a/moq-warp/Cargo.toml +++ b/moq-warp/Cargo.toml @@ -20,3 +20,12 @@ moq-transport = { path = "../moq-transport" } tokio = "1.27" mp4 = "0.13.0" anyhow = "1.0.70" +log = "0.4" # TODO remove + +# QUIC stuff +quinn = "0.10" +ring = "0.16.20" +rustls = "0.21.2" +rustls-pemfile = "1.0.2" + +async-trait = "0.1" diff --git a/moq-warp/src/lib.rs b/moq-warp/src/lib.rs index 923cde4..c025c8e 100644 --- a/moq-warp/src/lib.rs +++ b/moq-warp/src/lib.rs @@ -1,8 +1,3 @@ -mod source; -pub use source::Source; - -mod model; -pub use model::*; - -mod watch; -use watch::{Producer, Subscriber}; +pub mod model; +pub mod relay; +pub mod source; diff --git a/moq-warp/src/model.rs b/moq-warp/src/model.rs deleted file mode 100644 index 7734403..0000000 --- a/moq-warp/src/model.rs +++ /dev/null @@ -1,40 +0,0 @@ -use super::Subscriber; -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Instant; - -use moq_transport::VarInt; - -// Map from track namespace to broadcast. -// TODO support updates -pub type Broadcasts = Arc>; - -#[derive(Clone)] -pub struct Broadcast { - // TODO support updates. - pub tracks: Arc>, -} - -#[derive(Clone)] -pub struct Track { - // A list of segments, which are independently decodable. - pub segments: Subscriber, -} - -#[derive(Clone)] -pub struct Segment { - // The sequence number of the segment within the track. - pub sequence: VarInt, - - // The priority of the segment within the BROADCAST. - pub send_order: VarInt, - - // The time at which the segment expires for cache purposes. - pub expires: Option, - - // A list of fragments that make up the segment. - pub fragments: Subscriber, -} - -// Use Arc to avoid cloning the entire MP4 data for each subscriber. -pub type Fragment = Arc>; diff --git a/moq-warp/src/model/broadcast.rs b/moq-warp/src/model/broadcast.rs new file mode 100644 index 0000000..3bd0dee --- /dev/null +++ b/moq-warp/src/model/broadcast.rs @@ -0,0 +1,64 @@ +use std::{error, fmt}; + +use moq_transport::VarInt; + +// TODO generialize broker::Broadcasts and source::Source into this module. + +/* +pub struct Publisher { + pub namespace: String, + + pub tracks: watch::Publisher, +} + +impl Publisher { + pub fn new(namespace: &str) -> Self { + Self { + namespace: namespace.to_string(), + tracks: watch::Publisher::new(), + } + } + + pub fn subscribe(&self) -> Subscriber { + Subscriber { + namespace: self.namespace.clone(), + tracks: self.tracks.subscribe(), + } + } +} + +#[derive(Clone)] +pub struct Subscriber { + pub namespace: String, + + pub tracks: watch::Subscriber, +} +*/ + +#[derive(Clone)] +pub struct Error { + pub code: VarInt, + pub reason: String, +} + +impl error::Error for Error {} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if !self.reason.is_empty() { + write!(f, "broadcast error ({}): {}", self.code, self.reason) + } else { + write!(f, "broadcast error ({})", self.code) + } + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if !self.reason.is_empty() { + write!(f, "broadcast error ({}): {}", self.code, self.reason) + } else { + write!(f, "broadcast error ({})", self.code) + } + } +} diff --git a/moq-warp/src/model/fragment.rs b/moq-warp/src/model/fragment.rs new file mode 100644 index 0000000..9fdcf74 --- /dev/null +++ b/moq-warp/src/model/fragment.rs @@ -0,0 +1,10 @@ +use super::watch; +use std::sync::Arc; + +// Use Arc to avoid cloning the entire MP4 data for each subscriber. +pub type Shared = Arc>; + +// TODO combine fragments into the same buffer, instead of separate buffers. + +pub type Publisher = watch::Publisher; +pub type Subscriber = watch::Subscriber; diff --git a/moq-warp/src/model/mod.rs b/moq-warp/src/model/mod.rs new file mode 100644 index 0000000..610d7b9 --- /dev/null +++ b/moq-warp/src/model/mod.rs @@ -0,0 +1,5 @@ +pub mod broadcast; +pub mod fragment; +pub mod segment; +pub mod track; +pub(crate) mod watch; diff --git a/moq-warp/src/model/segment.rs b/moq-warp/src/model/segment.rs new file mode 100644 index 0000000..7fb6533 --- /dev/null +++ b/moq-warp/src/model/segment.rs @@ -0,0 +1,65 @@ +use super::{fragment, watch}; + +use moq_transport::VarInt; +use std::ops::Deref; +use std::sync::Arc; +use std::time; + +#[derive(Clone, Debug)] +pub struct Info { + // The sequence number of the segment within the track. + pub sequence: VarInt, + + // The priority of the segment within the BROADCAST. + pub send_order: VarInt, + + // The time at which the segment expires for cache purposes. + pub expires: Option, +} + +pub struct Publisher { + pub info: Arc, + + // A list of fragments that make up the segment. + pub fragments: watch::Publisher, +} + +impl Publisher { + pub fn new(info: Info) -> Self { + Self { + info: Arc::new(info), + fragments: watch::Publisher::new(), + } + } + + pub fn subscribe(&self) -> Subscriber { + Subscriber { + info: self.info.clone(), + fragments: self.fragments.subscribe(), + } + } +} + +impl Deref for Publisher { + type Target = Info; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +#[derive(Clone)] +pub struct Subscriber { + pub info: Arc, + + // A list of fragments that make up the segment. + pub fragments: watch::Subscriber, +} + +impl Deref for Subscriber { + type Target = Info; + + fn deref(&self) -> &Self::Target { + &self.info + } +} diff --git a/moq-warp/src/model/track.rs b/moq-warp/src/model/track.rs new file mode 100644 index 0000000..4cc38ce --- /dev/null +++ b/moq-warp/src/model/track.rs @@ -0,0 +1,107 @@ +use super::{segment, watch}; +use std::{error, fmt, time}; + +use moq_transport::VarInt; + +pub struct Publisher { + pub name: String, + + segments: watch::Publisher>, +} + +impl Publisher { + pub fn new(name: &str) -> Publisher { + Self { + name: name.to_string(), + segments: watch::Publisher::new(), + } + } + + pub fn push_segment(&mut self, segment: segment::Subscriber) { + self.segments.push(Ok(segment)) + } + + pub fn drain_segments(&mut self, before: time::Instant) { + self.segments.drain(|segment| { + if let Ok(segment) = segment { + if let Some(expires) = segment.expires { + return expires < before; + } + } + + false + }) + } + + pub fn close(mut self, err: Error) { + self.segments.push(Err(err)) + } + + pub fn subscribe(&self) -> Subscriber { + Subscriber { + name: self.name.clone(), + segments: self.segments.subscribe(), + } + } +} + +impl fmt::Debug for Publisher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "track publisher: {:?}", self.name) + } +} + +#[derive(Clone)] +pub struct Subscriber { + pub name: String, + + // A list of segments, which are independently decodable. + segments: watch::Subscriber>, +} + +impl Subscriber { + pub async fn next_segment(&mut self) -> Result { + let res = self.segments.next().await; + match res { + None => Err(Error { + code: VarInt::from_u32(0), + reason: String::from("closed"), + }), + Some(res) => res, + } + } +} + +impl fmt::Debug for Subscriber { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "track subscriber: {:?}", self.name) + } +} + +#[derive(Clone)] +pub struct Error { + pub code: VarInt, + pub reason: String, +} + +impl error::Error for Error {} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if !self.reason.is_empty() { + write!(f, "track error ({}): {}", self.code, self.reason) + } else { + write!(f, "track error ({})", self.code) + } + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if !self.reason.is_empty() { + write!(f, "track error ({}): {}", self.code, self.reason) + } else { + write!(f, "track error ({})", self.code) + } + } +} diff --git a/moq-warp/src/watch.rs b/moq-warp/src/model/watch.rs similarity index 76% rename from moq-warp/src/watch.rs rename to moq-warp/src/model/watch.rs index a68ef7a..8c64576 100644 --- a/moq-warp/src/watch.rs +++ b/moq-warp/src/model/watch.rs @@ -5,7 +5,6 @@ use tokio::sync::watch; struct State { queue: VecDeque, drained: usize, - closed: bool, } impl State { @@ -13,7 +12,6 @@ impl State { Self { queue: VecDeque::new(), drained: 0, - closed: false, } } @@ -42,11 +40,11 @@ impl State { } } -pub struct Producer { +pub struct Publisher { sender: watch::Sender>, } -impl Producer { +impl Publisher { pub fn new() -> Self { let state = State::new(); let (sender, _) = watch::channel(state); @@ -70,23 +68,23 @@ impl Producer { }); } + // Subscribe for all NEW updates. pub fn subscribe(&self) -> Subscriber { - Subscriber::new(self.sender.subscribe()) + let index = self.sender.borrow().queue.len(); + + Subscriber { + state: self.sender.subscribe(), + index, + } } } -impl Default for Producer { +impl Default for Publisher { fn default() -> Self { Self::new() } } -impl Drop for Producer { - fn drop(&mut self) { - self.sender.send_modify(|state| state.closed = true); - } -} - #[derive(Clone)] pub struct Subscriber { state: watch::Receiver>, @@ -94,17 +92,17 @@ pub struct Subscriber { } impl Subscriber { - fn new(state: watch::Receiver>) -> Self { - Self { state, index: 0 } - } - pub async fn next(&mut self) -> Option { // Wait until the queue has a new element or if it's closed. let state = self .state - .wait_for(|state| state.closed || self.index < state.drained + state.queue.len()) - .await - .expect("publisher dropped without close"); + .wait_for(|state| self.index < state.drained + state.queue.len()) + .await; + + let state = match state { + Ok(state) => state, + Err(_) => return None, // publisher was dropped + }; // If our index is smaller than drained, skip past those elements we missed. let index = self.index.saturating_sub(state.drained); @@ -117,9 +115,6 @@ impl Subscriber { self.index = index + state.drained + 1; Some(element) - } else if state.closed { - // Return None if we've consumed all entries and the queue is closed. - None } else { unreachable!("impossible subscriber state") } diff --git a/moq-warp/src/relay/broker.rs b/moq-warp/src/relay/broker.rs new file mode 100644 index 0000000..104c795 --- /dev/null +++ b/moq-warp/src/relay/broker.rs @@ -0,0 +1,77 @@ +use crate::model::{broadcast, track, watch}; +use crate::source::Source; + +use std::collections::hash_map::HashMap; +use std::sync::{Arc, Mutex}; + +use anyhow::Context; + +#[derive(Clone, Default)] +pub struct Broadcasts { + // Operate on the inner struct so we can share/clone the outer struct. + inner: Arc>, +} + +#[derive(Default)] +struct BroadcastsInner { + // TODO Automatically reclaim dropped sources. + lookup: HashMap>, + updates: watch::Publisher, +} + +#[derive(Clone)] +pub enum Update { + // Broadcast was announced + Insert(String), // TODO include source? + + // Broadcast was unannounced + Remove(String, broadcast::Error), +} + +impl Broadcasts { + pub fn new() -> Self { + Default::default() + } + + // Return the list of available broadcasts, and a subscriber that will return updates (add/remove). + pub fn available(&self) -> (Vec, watch::Subscriber) { + // Grab the lock. + let this = self.inner.lock().unwrap(); + + // Get the list of all available tracks. + let keys = this.lookup.keys().cloned().collect(); + + // Get a subscriber that will return future updates. + let updates = this.updates.subscribe(); + + (keys, updates) + } + + pub fn announce(&self, namespace: &str, source: Arc) -> anyhow::Result<()> { + let mut this = self.inner.lock().unwrap(); + + if let Some(_existing) = this.lookup.get(namespace) { + anyhow::bail!("namespace already registered"); + } + + this.lookup.insert(namespace.to_string(), source); + this.updates.push(Update::Insert(namespace.to_string())); + + Ok(()) + } + + pub fn unannounce(&self, namespace: &str, error: broadcast::Error) -> anyhow::Result<()> { + let mut this = self.inner.lock().unwrap(); + + this.lookup.remove(namespace).context("namespace was not published")?; + this.updates.push(Update::Remove(namespace.to_string(), error)); + + Ok(()) + } + + pub fn subscribe(&self, namespace: &str, name: &str) -> Option { + let this = self.inner.lock().unwrap(); + + this.lookup.get(namespace).and_then(|v| v.subscribe(name)) + } +} diff --git a/moq-warp/src/relay/contribute.rs b/moq-warp/src/relay/contribute.rs new file mode 100644 index 0000000..d6dfd79 --- /dev/null +++ b/moq-warp/src/relay/contribute.rs @@ -0,0 +1,298 @@ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time; + +use tokio::io::AsyncReadExt; +use tokio::sync::mpsc; +use tokio::task::JoinSet; // lock across await boundaries + +use moq_transport::coding::VarInt; +use moq_transport::object; + +use anyhow::Context; + +use super::{broker, control}; +use crate::model::{broadcast, segment, track}; +use crate::source::Source; + +// TODO experiment with making this Clone, so every task can have its own copy. +pub struct Session { + // Used to receive objects. + // TODO split into send/receive halves. + transport: Arc, + + // Used to send and receive control messages. + control: control::Component, + + // Globally announced namespaces, which we can add ourselves to. + broker: broker::Broadcasts, + + // The names of active broadcasts being produced. + broadcasts: HashMap>, + + // Active tracks being produced by this session. + publishers: Publishers, + + // Tasks we are currently serving. + run_segments: JoinSet>, // receiving objects +} + +impl Session { + pub fn new( + transport: Arc, + control: control::Component, + broker: broker::Broadcasts, + ) -> Self { + Self { + transport, + control, + broker, + broadcasts: HashMap::new(), + publishers: Publishers::new(), + run_segments: JoinSet::new(), + } + } + + pub async fn run(mut self) -> anyhow::Result<()> { + loop { + tokio::select! { + res = self.run_segments.join_next(), if !self.run_segments.is_empty() => { + let res = res.expect("no tasks").expect("task aborted"); + if let Err(err) = res { + log::error!("failed to produce segment: {:?}", err); + } + }, + object = self.transport.recv() => { + let (header, stream )= object.context("failed to receive object")?; + let res = self.receive_object(header, stream).await; + if let Err(err) = res { + log::error!("failed to receive object: {:?}", err); + } + }, + subscribe = self.publishers.incoming() => { + let msg = subscribe.context("failed to receive subscription")?; + self.control.send(msg).await?; + }, + msg = self.control.recv() => { + let msg = msg.context("failed to receive control message")?; + self.receive_message(msg).await?; + }, + } + } + } + + async fn receive_message(&mut self, msg: control::Contribute) -> anyhow::Result<()> { + match msg { + control::Contribute::Announce(msg) => self.receive_announce(msg).await, + control::Contribute::SubscribeOk(msg) => self.receive_subscribe_ok(msg), + control::Contribute::SubscribeError(msg) => self.receive_subscribe_error(msg), + } + } + + async fn receive_object(&mut self, header: object::Header, stream: object::RecvStream) -> anyhow::Result<()> { + let id = header.track_id; + + let segment = segment::Info { + sequence: header.object_sequence, + send_order: header.send_order, + expires: Some(time::Instant::now() + time::Duration::from_secs(10)), + }; + + let segment = segment::Publisher::new(segment); + + self.publishers + .push_segment(id, segment.subscribe()) + .context("failed to publish segment")?; + + // TODO implement a timeout + + self.run_segments + .spawn(async move { Self::run_segment(segment, stream).await }); + + Ok(()) + } + + async fn run_segment(mut segment: segment::Publisher, mut stream: object::RecvStream) -> anyhow::Result<()> { + let mut buf = [0u8; 32 * 1024]; + loop { + let size = stream.read(&mut buf).await.context("failed to read from stream")?; + if size == 0 { + return Ok(()); + } + + let chunk = buf[..size].to_vec(); + segment.fragments.push(chunk.into()) + } + } + + async fn receive_announce(&mut self, msg: control::Announce) -> anyhow::Result<()> { + match self.receive_announce_inner(&msg).await { + Ok(()) => { + let msg = control::AnnounceOk { + track_namespace: msg.track_namespace, + }; + self.control.send(msg).await + } + Err(e) => { + let msg = control::AnnounceError { + track_namespace: msg.track_namespace, + code: VarInt::from_u32(1), + reason: e.to_string(), + }; + self.control.send(msg).await + } + } + } + + async fn receive_announce_inner(&mut self, msg: &control::Announce) -> anyhow::Result<()> { + // Create a broadcast and announce it. + // We don't actually start producing the broadcast until we receive a subscription. + let broadcast = Arc::new(Broadcast::new(&msg.track_namespace, &self.publishers)); + + self.broker.announce(&msg.track_namespace, broadcast.clone())?; + self.broadcasts.insert(msg.track_namespace.clone(), broadcast); + + Ok(()) + } + + fn receive_subscribe_ok(&mut self, _msg: control::SubscribeOk) -> anyhow::Result<()> { + // TODO make sure this is for a track we are subscribed to + Ok(()) + } + + fn receive_subscribe_error(&mut self, msg: control::SubscribeError) -> anyhow::Result<()> { + let error = track::Error { + code: msg.code, + reason: format!("upstream error: {}", msg.reason), + }; + + // Stop producing the track. + self.publishers + .close(msg.track_id, error) + .context("failed to close track")?; + + Ok(()) + } +} + +impl Drop for Session { + fn drop(&mut self) { + // Unannounce all broadcasts we have announced. + // TODO make this automatic so we can't screw up? + // TOOD Implement UNANNOUNCE so we can return good errors. + for broadcast in self.broadcasts.keys() { + let error = broadcast::Error { + code: VarInt::from_u32(1), + reason: "connection closed".to_string(), + }; + + self.broker.unannounce(broadcast, error).unwrap(); + } + } +} + +// A list of subscriptions for a broadcast. +#[derive(Clone)] +pub struct Broadcast { + // Our namespace + namespace: String, + + // A lookup from name to a subscription (duplicate subscribers) + subscriptions: Arc>>, + + // Issue a SUBSCRIBE message for a new subscription (new subscriber) + queue: mpsc::UnboundedSender<(String, track::Publisher)>, +} + +impl Broadcast { + pub fn new(namespace: &str, publishers: &Publishers) -> Self { + Self { + namespace: namespace.to_string(), + subscriptions: Default::default(), + queue: publishers.sender.clone(), + } + } +} + +impl Source for Broadcast { + fn subscribe(&self, name: &str) -> Option { + let mut subscriptions = self.subscriptions.lock().unwrap(); + + // Check if there's an existing subscription. + if let Some(subscriber) = subscriptions.get(name).cloned() { + return Some(subscriber); + } + + // Otherwise, make a new track and tell the publisher to fufill it. + let track = track::Publisher::new(name); + let subscriber = track.subscribe(); + + // Save the subscriber for duplication. + subscriptions.insert(name.to_string(), subscriber.clone()); + + // Send the publisher to another thread to actually subscribe. + self.queue.send((self.namespace.clone(), track)).unwrap(); + + // Return the subscriber we created. + Some(subscriber) + } +} + +pub struct Publishers { + // A lookup from subscription ID to a track being produced, or none if it's been closed. + tracks: HashMap>, + + // The next subscription ID + next: u64, + + // A queue of subscriptions that we need to fulfill + receiver: mpsc::UnboundedReceiver<(String, track::Publisher)>, + + // A clonable queue, so other threads can issue subscriptions. + sender: mpsc::UnboundedSender<(String, track::Publisher)>, +} + +impl Publishers { + pub fn new() -> Self { + let (sender, receiver) = mpsc::unbounded_channel(); + + Self { + tracks: Default::default(), + next: 0, + sender, + receiver, + } + } + + pub fn push_segment(&mut self, id: VarInt, segment: segment::Subscriber) -> anyhow::Result<()> { + let track = self.tracks.get_mut(&id).context("no track with that ID")?; + let track = track.as_mut().context("track closed")?; // TODO don't make fatal + + track.push_segment(segment); + + Ok(()) + } + + pub fn close(&mut self, id: VarInt, err: track::Error) -> anyhow::Result<()> { + let track = self.tracks.get_mut(&id).context("no track with that ID")?; + let track = track.take().context("track closed")?; + track.close(err); + + Ok(()) + } + + // Returns the next subscribe message we need to issue. + pub async fn incoming(&mut self) -> anyhow::Result { + let (namespace, track) = self.receiver.recv().await.context("no more subscriptions")?; + + let msg = control::Subscribe { + track_id: VarInt::try_from(self.next)?, + track_namespace: namespace, + track_name: track.name, + }; + + self.next += 1; + + Ok(msg) + } +} diff --git a/moq-warp/src/relay/control.rs b/moq-warp/src/relay/control.rs new file mode 100644 index 0000000..9339e6a --- /dev/null +++ b/moq-warp/src/relay/control.rs @@ -0,0 +1,119 @@ +use moq_transport::control; +use tokio::sync::mpsc; + +pub use control::*; + +pub struct Main { + control: control::Stream, + outgoing: mpsc::Receiver, + + contribute: mpsc::Sender, + distribute: mpsc::Sender, +} + +impl Main { + pub async fn run(mut self) -> anyhow::Result<()> { + loop { + tokio::select! { + Some(msg) = self.outgoing.recv() => self.control.send(msg).await?, + Ok(msg) = self.control.recv() => self.handle(msg).await?, + } + } + } + + pub async fn handle(&mut self, msg: control::Message) -> anyhow::Result<()> { + match msg.try_into() { + Ok(msg) => self.contribute.send(msg).await?, + Err(msg) => match msg.try_into() { + Ok(msg) => self.distribute.send(msg).await?, + Err(msg) => anyhow::bail!("unsupported control message: {:?}", msg), + }, + } + + Ok(()) + } +} + +pub struct Component { + incoming: mpsc::Receiver, + outgoing: mpsc::Sender, +} + +impl Component { + pub async fn send>(&mut self, msg: M) -> anyhow::Result<()> { + self.outgoing.send(msg.into()).await?; + Ok(()) + } + + pub async fn recv(&mut self) -> Option { + self.incoming.recv().await + } +} + +// Splits a control stream into two components, based on if it's a message for contribution or distribution. +pub fn split(control: control::Stream) -> (Main, Component, Component) { + let (outgoing_tx, outgoing_rx) = mpsc::channel(1); + let (contribute_tx, contribute_rx) = mpsc::channel(1); + let (distribute_tx, distribute_rx) = mpsc::channel(1); + + let control = Main { + control, + outgoing: outgoing_rx, + contribute: contribute_tx, + distribute: distribute_tx, + }; + + let contribute = Component { + incoming: contribute_rx, + outgoing: outgoing_tx.clone(), + }; + + let distribute = Component { + incoming: distribute_rx, + outgoing: outgoing_tx, + }; + + (control, contribute, distribute) +} + +// Messages we expect to receive from the client for contribution. +#[derive(Debug)] +pub enum Contribute { + Announce(control::Announce), + SubscribeOk(control::SubscribeOk), + SubscribeError(control::SubscribeError), +} + +impl TryFrom for Contribute { + type Error = control::Message; + + fn try_from(msg: control::Message) -> Result { + match msg { + control::Message::Announce(msg) => Ok(Self::Announce(msg)), + control::Message::SubscribeOk(msg) => Ok(Self::SubscribeOk(msg)), + control::Message::SubscribeError(msg) => Ok(Self::SubscribeError(msg)), + _ => Err(msg), + } + } +} + +// Messages we expect to receive from the client for distribution. +#[derive(Debug)] +pub enum Distribute { + AnnounceOk(control::AnnounceOk), + AnnounceError(control::AnnounceError), + Subscribe(control::Subscribe), +} + +impl TryFrom for Distribute { + type Error = control::Message; + + fn try_from(value: control::Message) -> Result { + match value { + control::Message::AnnounceOk(msg) => Ok(Self::AnnounceOk(msg)), + control::Message::AnnounceError(msg) => Ok(Self::AnnounceError(msg)), + control::Message::Subscribe(msg) => Ok(Self::Subscribe(msg)), + _ => Err(value), + } + } +} diff --git a/moq-warp/src/relay/distribute.rs b/moq-warp/src/relay/distribute.rs new file mode 100644 index 0000000..152ca90 --- /dev/null +++ b/moq-warp/src/relay/distribute.rs @@ -0,0 +1,206 @@ +use anyhow::Context; + +use tokio::io::AsyncWriteExt; +use tokio::task::JoinSet; // allows locking across await + +use std::sync::Arc; + +use moq_transport::coding::VarInt; +use moq_transport::object; + +use super::{broker, control}; +use crate::model::{segment, track}; + +pub struct Session { + // Objects are sent to the client using this transport. + transport: Arc, + + // Used to send and receive control messages. + control: control::Component, + + // Globally announced namespaces, which can be subscribed to. + broker: broker::Broadcasts, + + // A list of tasks that are currently running. + run_subscribes: JoinSet, // run subscriptions, sending the returned error if they fail +} + +impl Session { + pub fn new( + transport: Arc, + control: control::Component, + broker: broker::Broadcasts, + ) -> Self { + Self { + transport, + control, + broker, + run_subscribes: JoinSet::new(), + } + } + + pub async fn run(mut self) -> anyhow::Result<()> { + // Announce all available tracks and get a stream of updates. + let (available, mut updates) = self.broker.available(); + for namespace in available { + self.on_available(broker::Update::Insert(namespace)).await?; + } + + loop { + tokio::select! { + res = self.run_subscribes.join_next(), if !self.run_subscribes.is_empty() => { + let res = res.expect("no tasks").expect("task aborted"); + self.control.send(res).await?; + }, + delta = updates.next() => { + let delta = delta.expect("no more broadcasts"); + self.on_available(delta).await?; + }, + msg = self.control.recv() => { + let msg = msg.context("failed to receive control message")?; + self.receive_message(msg).await?; + }, + } + } + } + + async fn receive_message(&mut self, msg: control::Distribute) -> anyhow::Result<()> { + match msg { + control::Distribute::AnnounceOk(msg) => self.receive_announce_ok(msg), + control::Distribute::AnnounceError(msg) => self.receive_announce_error(msg), + control::Distribute::Subscribe(msg) => self.receive_subscribe(msg).await, + } + } + + fn receive_announce_ok(&mut self, _msg: control::AnnounceOk) -> anyhow::Result<()> { + // TODO make sure we sent this announce + Ok(()) + } + + fn receive_announce_error(&mut self, msg: control::AnnounceError) -> anyhow::Result<()> { + // TODO make sure we sent this announce + // TODO remove this from the list of subscribable broadcasts. + anyhow::bail!("received ANNOUNCE_ERROR({:?}): {}", msg.code, msg.reason) + } + + async fn receive_subscribe(&mut self, msg: control::Subscribe) -> anyhow::Result<()> { + match self.receive_subscribe_inner(&msg).await { + Ok(()) => { + self.control + .send(control::SubscribeOk { + track_id: msg.track_id, + expires: None, + }) + .await + } + Err(e) => { + self.control + .send(control::SubscribeError { + track_id: msg.track_id, + code: VarInt::from_u32(1), + reason: e.to_string(), + }) + .await + } + } + } + + async fn receive_subscribe_inner(&mut self, msg: &control::Subscribe) -> anyhow::Result<()> { + let track = self + .broker + .subscribe(&msg.track_namespace, &msg.track_name) + .context("could not find broadcast")?; + + // TODO can we just clone self? + let transport = self.transport.clone(); + let track_id = msg.track_id; + + self.run_subscribes + .spawn(async move { Self::run_subscribe(transport, track_id, track).await }); + + Ok(()) + } + + async fn run_subscribe( + transport: Arc, + track_id: VarInt, + mut track: track::Subscriber, + ) -> control::SubscribeError { + let mut tasks = JoinSet::new(); + let mut result = None; + + loop { + tokio::select! { + // Accept new segments added to the track. + segment = track.next_segment(), if result.is_none() => { + match segment { + Ok(segment) => { + let transport = transport.clone(); + tasks.spawn(async move { Self::serve_group(transport, track_id, segment).await }); + }, + Err(e) => { + result = Some(control::SubscribeError { + track_id, + code: e.code, + reason: e.reason, + }) + }, + } + }, + // Poll any pending segments until they exit. + res = tasks.join_next(), if !tasks.is_empty() => { + let res = res.expect("no tasks").expect("task aborted"); + if let Err(err) = res { + log::error!("failed to serve segment: {:?}", err); + } + }, + else => return result.unwrap() + } + } + } + + async fn serve_group( + transport: Arc, + track_id: VarInt, + mut segment: segment::Subscriber, + ) -> anyhow::Result<()> { + let header = object::Header { + track_id, + group_sequence: segment.sequence, + object_sequence: VarInt::from_u32(0), // Always zero since we send an entire group as an object + send_order: segment.send_order, + }; + + let mut stream = transport.send(header).await?; + + // Write each fragment as they are available. + while let Some(fragment) = segment.fragments.next().await { + stream.write_all(fragment.as_slice()).await?; + } + + // NOTE: stream is automatically closed when dropped + + Ok(()) + } + + async fn on_available(&mut self, delta: broker::Update) -> anyhow::Result<()> { + match delta { + broker::Update::Insert(name) => { + self.control + .send(control::Announce { + track_namespace: name.clone(), + }) + .await + } + broker::Update::Remove(name, error) => { + self.control + .send(control::AnnounceError { + track_namespace: name, + code: error.code, + reason: error.reason, + }) + .await + } + } + } +} diff --git a/moq-warp/src/relay/mod.rs b/moq-warp/src/relay/mod.rs new file mode 100644 index 0000000..485adfa --- /dev/null +++ b/moq-warp/src/relay/mod.rs @@ -0,0 +1,10 @@ +pub mod broker; + +mod contribute; +mod control; +mod distribute; +mod server; +mod session; + +pub use server::*; +pub use session::*; diff --git a/moq-demo/src/server.rs b/moq-warp/src/relay/server.rs similarity index 79% rename from moq-demo/src/server.rs rename to moq-warp/src/relay/server.rs index 1dcd740..e39c66b 100644 --- a/moq-demo/src/server.rs +++ b/moq-warp/src/relay/server.rs @@ -1,7 +1,6 @@ -use super::session::Session; +use super::{broker, Session}; use moq_transport::server::Endpoint; -use moq_warp::Broadcasts; use std::{fs, io, net, path, sync, time}; @@ -13,11 +12,11 @@ pub struct Server { // The MoQ transport server. server: Endpoint, - // The media source. - broadcasts: Broadcasts, + // The media sources. + broker: broker::Broadcasts, // Sessions actively being run. - sessions: JoinSet>, + tasks: JoinSet>, } pub struct ServerConfig { @@ -25,7 +24,7 @@ pub struct ServerConfig { pub cert: path::PathBuf, pub key: path::PathBuf, - pub broadcasts: Broadcasts, + pub broker: broker::Broadcasts, } impl Server { @@ -75,31 +74,27 @@ impl Server { server_config.transport = sync::Arc::new(transport_config); let server = quinn::Endpoint::server(server_config, config.addr)?; - let broadcasts = config.broadcasts; + let broker = config.broker; let server = Endpoint::new(server); - let sessions = JoinSet::new(); + let tasks = JoinSet::new(); - Ok(Self { - server, - broadcasts, - sessions, - }) + Ok(Self { server, broker, tasks }) } - pub async fn run(&mut self) -> anyhow::Result<()> { + pub async fn run(mut self) -> anyhow::Result<()> { loop { tokio::select! { res = self.server.accept() => { let session = res.context("failed to accept connection")?; - let broadcasts = self.broadcasts.clone(); + let broker = self.broker.clone(); - self.sessions.spawn(async move { - let session: Session = Session::accept(session, broadcasts).await?; - session.serve().await + self.tasks.spawn(async move { + let session: Session = Session::accept(session, broker).await?; + session.run().await }); }, - res = self.sessions.join_next(), if !self.sessions.is_empty() => { + res = self.tasks.join_next(), if !self.tasks.is_empty() => { let res = res.expect("no tasks").expect("task aborted"); if let Err(err) = res { diff --git a/moq-warp/src/relay/session.rs b/moq-warp/src/relay/session.rs new file mode 100644 index 0000000..0c1e9e4 --- /dev/null +++ b/moq-warp/src/relay/session.rs @@ -0,0 +1,70 @@ +use anyhow::Context; + +use std::sync::Arc; + +use moq_transport::{server, setup}; + +use super::{broker, contribute, control, distribute}; + +pub struct Session { + // Split logic into contribution/distribution to reduce the problem space. + contribute: contribute::Session, + distribute: distribute::Session, + + // Used to receive control messages and forward to contribute/distribute. + control: control::Main, +} + +impl Session { + pub async fn accept(session: server::Accept, broker: broker::Broadcasts) -> anyhow::Result { + // Accep the WebTransport session. + // OPTIONAL validate the conn.uri() otherwise call conn.reject() + let session = session + .accept() + .await + .context(": server::Setupfailed to accept WebTransport session")?; + + session + .setup() + .versions + .iter() + .find(|v| **v == setup::Version::DRAFT_00) + .context("failed to find supported version")?; + + match session.setup().role { + setup::Role::Subscriber => {} + _ => anyhow::bail!("TODO publishing not yet supported"), + } + + let setup = setup::Server { + version: setup::Version::DRAFT_00, + role: setup::Role::Publisher, + }; + + let (transport, control) = session.accept(setup).await?; + let transport = Arc::new(transport); + + let (control, contribute, distribute) = control::split(control); + + let contribute = contribute::Session::new(transport.clone(), contribute, broker.clone()); + let distribute = distribute::Session::new(transport, distribute, broker); + + let session = Self { + control, + contribute, + distribute, + }; + + Ok(session) + } + + pub async fn run(self) -> anyhow::Result<()> { + let control = self.control.run(); + let contribute = self.contribute.run(); + let distribute = self.distribute.run(); + + tokio::try_join!(control, contribute, distribute)?; + + Ok(()) + } +} diff --git a/moq-warp/src/source.rs b/moq-warp/src/source/file.rs similarity index 72% rename from moq-warp/src/source.rs rename to moq-warp/src/source/file.rs index 6dc87db..9993eaf 100644 --- a/moq-warp/src/source.rs +++ b/moq-warp/src/source/file.rs @@ -10,20 +10,24 @@ use std::sync::Arc; use moq_transport::VarInt; -use super::{Broadcast, Fragment, Producer, Segment, Track}; +use super::MapSource; +use crate::model::{segment, track}; -pub struct Source { +pub struct File { // We read the file once, in order, and don't seek backwards. reader: io::BufReader, - // The subscribable broadcast. - broadcast: Broadcast, + // The catalog for the broadcast, held just so it's closed only when the broadcast is over. + _catalog: track::Publisher, // The tracks we're producing. - tracks: HashMap, + tracks: HashMap, + + // A subscribable source. + source: Arc, } -impl Source { +impl File { pub fn new(path: path::PathBuf) -> anyhow::Result { let f = fs::File::open(path)?; let mut reader = io::BufReader::new(f); @@ -45,75 +49,65 @@ impl Source { // Parse the moov box so we can detect the timescales for each track. let moov = mp4::MoovBox::read_box(&mut moov_reader, moov_header.size)?; + // Create a source that can be subscribed to. + let mut source = HashMap::default(); + + // Create the catalog track + let (_catalog, subscriber) = Self::create_catalog(init); + source.insert("catalog".to_string(), subscriber); + let mut tracks = HashMap::new(); - // Create the init track - let init_track = Self::create_init_track(init); - tracks.insert("catalog".to_string(), init_track); - - // Create a map with the current segment for each track. - // NOTE: We don't add the init track to this, since it's not part of the MP4. - let mut sources = HashMap::new(); - for trak in &moov.traks { - let track_id = trak.tkhd.track_id; - anyhow::ensure!(track_id != 0xff, "track ID 0xff is reserved"); + let id = trak.tkhd.track_id; + let name = id.to_string(); - let timescale = track_timescale(&moov, track_id); - - let segments = Producer::::new(); - - // Insert the subscribable track for consumerts. - // The track_name is just the integer track ID. - let track_name = track_id.to_string(); - tracks.insert( - track_name, - Track { - segments: segments.subscribe(), - }, - ); + let timescale = track_timescale(&moov, id); // Store the track publisher in a map so we can update it later. - let source = SourceTrack::new(segments, timescale); - sources.insert(track_id, source); + let track = Track::new(&name, timescale); + source.insert(name.to_string(), track.subscribe()); + + tracks.insert(name, track); } - let broadcast = Broadcast { - tracks: Arc::new(tracks), - }; + let source = Arc::new(MapSource(source)); Ok(Self { reader, - broadcast, - tracks: sources, + _catalog, + tracks, + source, }) } - // Create an init track - fn create_init_track(raw: Vec) -> Track { - let mut fragments = Producer::new(); - let mut segments = Producer::new(); + fn create_catalog(raw: Vec) -> (track::Publisher, track::Subscriber) { + // Create a track with a single segment containing the init data. + let mut catalog = track::Publisher::new("catalog"); - fragments.push(raw.into()); + // Subscribe to the catalog before we push the segment. + let subscriber = catalog.subscribe(); - segments.push(Segment { + let mut segment = segment::Publisher::new(segment::Info { sequence: VarInt::from_u32(0), // first and only segment send_order: VarInt::from_u32(0), // highest priority expires: None, // never delete from the cache - fragments: fragments.subscribe(), }); - Track { - segments: segments.subscribe(), - } + // Add the segment and add the fragment. + catalog.push_segment(segment.subscribe()); + segment.fragments.push(raw.into()); + + // Return the catalog + (catalog, subscriber) } - pub async fn run(&mut self) -> anyhow::Result<()> { + pub async fn run(mut self) -> anyhow::Result<()> { // The timestamp when the broadcast "started", so we can sleep to simulate a live stream. let start = tokio::time::Instant::now(); - // The ID of the last moof header. - let mut track_id = None; + // The current track name + let mut track_name = None; loop { let atom = read_atom(&mut self.reader)?; @@ -126,26 +120,27 @@ impl Source { let moof = mp4::MoofBox::read_box(&mut reader, header.size).context("failed to read MP4")?; // Process the moof. - let fragment = SourceFragment::new(moof)?; + let fragment = Fragment::new(moof)?; + let name = fragment.track.to_string(); // Get the track for this moof. - let track = self.tracks.get_mut(&fragment.track).context("failed to find track")?; + let track = self.tracks.get_mut(&name).context("failed to find track")?; // Sleep until we should publish this sample. let timestamp = time::Duration::from_millis(1000 * fragment.timestamp / track.timescale); tokio::time::sleep_until(start + timestamp).await; // Save the track ID for the next iteration, which must be a mdat. - anyhow::ensure!(track_id.is_none(), "multiple moof atoms"); - track_id.replace(fragment.track); + anyhow::ensure!(track_name.is_none(), "multiple moof atoms"); + track_name.replace(name); // Publish the moof header, creating a new segment if it's a keyframe. track.header(atom, fragment).context("failed to publish moof")?; } mp4::BoxType::MdatBox => { // Get the track ID from the previous moof. - let track_id = track_id.take().context("missing moof")?; - let track = self.tracks.get_mut(&track_id).context("failed to find track")?; + let name = track_name.take().context("missing moof")?; + let track = self.tracks.get_mut(&name).context("failed to find track")?; // Publish the mdat atom. track.data(atom).context("failed to publish mdat")?; @@ -157,17 +152,17 @@ impl Source { } } - pub fn broadcast(&self) -> Broadcast { - self.broadcast.clone() + pub fn source(&self) -> Arc { + self.source.clone() } } -struct SourceTrack { +struct Track { // The track we're producing - segments: Producer, + track: track::Publisher, - // The current segment's fragments - fragments: Option>, + // The current segment + segment: Option, // The number of units per second. timescale: u64, @@ -176,21 +171,23 @@ struct SourceTrack { sequence: u64, } -impl SourceTrack { - fn new(segments: Producer, timescale: u64) -> Self { +impl Track { + fn new(name: &str, timescale: u64) -> Self { + let track = track::Publisher::new(name); + Self { - segments, + track, sequence: 0, - fragments: None, + segment: None, timescale, } } - pub fn header(&mut self, raw: Vec, fragment: SourceFragment) -> anyhow::Result<()> { - if let Some(fragments) = self.fragments.as_mut() { + pub fn header(&mut self, raw: Vec, fragment: Fragment) -> anyhow::Result<()> { + if let Some(segment) = self.segment.as_mut() { if !fragment.keyframe { // Use the existing segment - fragments.push(raw.into()); + segment.fragments.push(raw.into()); return Ok(()); } } @@ -221,37 +218,42 @@ impl SourceTrack { self.sequence += 1; - // Create a new segment, and save the fragments producer so we can push to it. - let mut fragments = Producer::::new(); - self.segments.push(Segment { + // Create a new segment. + let segment = segment::Info { sequence, expires, send_order, - fragments: fragments.subscribe(), - }); + }; + + let mut segment = segment::Publisher::new(segment); + self.track.push_segment(segment.subscribe()); + + // Insert the raw atom into the segment. + segment.fragments.push(raw.into()); + + // Save for the next iteration + self.segment = Some(segment); // Remove any segments older than 10s. // TODO This can only drain from the FRONT of the queue, so don't get clever with expirations. - self.segments.drain(|segment| segment.expires.unwrap() < now); - - // Insert the raw atom into the segment. - fragments.push(raw.into()); - - // Save for the next iteration - self.fragments = Some(fragments); + self.track.drain_segments(now); Ok(()) } pub fn data(&mut self, raw: Vec) -> anyhow::Result<()> { - let fragments = self.fragments.as_mut().context("missing keyframe")?; - fragments.push(raw.into()); + let segment = self.segment.as_mut().context("missing segment")?; + segment.fragments.push(raw.into()); Ok(()) } + + pub fn subscribe(&self) -> track::Subscriber { + self.track.subscribe() + } } -struct SourceFragment { +struct Fragment { // The track for this fragment. track: u32, @@ -262,7 +264,7 @@ struct SourceFragment { keyframe: bool, } -impl SourceFragment { +impl Fragment { fn new(moof: mp4::MoofBox) -> anyhow::Result { // We can't split the mdat atom, so this is impossible to support anyhow::ensure!(moof.trafs.len() == 1, "multiple tracks per moof atom"); diff --git a/moq-warp/src/source/mod.rs b/moq-warp/src/source/mod.rs new file mode 100644 index 0000000..c69c0d2 --- /dev/null +++ b/moq-warp/src/source/mod.rs @@ -0,0 +1,20 @@ +mod file; +pub use file::*; + +use crate::model::track; + +use std::collections::HashMap; + +// TODO move to model::Broadcast? +pub trait Source { + fn subscribe(&self, name: &str) -> Option; +} + +#[derive(Clone, Default)] +pub struct MapSource(pub HashMap); + +impl Source for MapSource { + fn subscribe(&self, name: &str) -> Option { + self.0.get(name).cloned() + } +}