diff --git a/Cargo.lock b/Cargo.lock index 37393a7..5838672 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1068,6 +1068,8 @@ dependencies = [ "rustls-native-certs", "serde_json", "tokio", + "tracing", + "tracing-subscriber", "url", "webtransport-quinn", ] diff --git a/dev/api b/dev/api index f59b60d..8a74d7d 100755 --- a/dev/api +++ b/dev/api @@ -4,14 +4,14 @@ set -euo pipefail # Change directory to the root of the project cd "$(dirname "$0")/.." +# Use debug logging by default +export RUST_LOG="${RUST_LOG:-debug}" + # Run the API server on port 4442 by default HOST="${HOST:-[::]}" PORT="${PORT:-4442}" LISTEN="${LISTEN:-$HOST:$PORT}" -# Default to info log level -export RUST_LOG="${RUST_LOG:-info}" - # Check for Podman/Docker and set runtime accordingly if command -v podman &> /dev/null; then RUNTIME=podman diff --git a/dev/pub b/dev/pub index 0da6ac6..c3bf8e0 100755 --- a/dev/pub +++ b/dev/pub @@ -4,7 +4,8 @@ set -euo pipefail # Change directory to the root of the project cd "$(dirname "$0")/.." -export RUST_LOG="${RUST_LOG:-info}" +# Use debug logging by default +export RUST_LOG="${RUST_LOG:-debug}" # Connect to localhost by default. HOST="${HOST:-localhost}" diff --git a/dev/relay b/dev/relay index 29ebad4..51aa7cb 100755 --- a/dev/relay +++ b/dev/relay @@ -4,8 +4,8 @@ set -euo pipefail # Change directory to the root of the project cd "$(dirname "$0")/.." -# Use info logging by default -export RUST_LOG="${RUST_LOG:-info}" +# Use debug logging by default +export RUST_LOG="${RUST_LOG:-debug}" # Default to a self-signed certificate # TODO automatically generate if it doesn't exist. 
diff --git a/moq-api/src/client.rs b/moq-api/src/client.rs index d60a417..5f07d11 100644 --- a/moq-api/src/client.rs +++ b/moq-api/src/client.rs @@ -27,10 +27,10 @@ impl Client { Ok(Some(origin)) } - pub async fn set_origin(&mut self, id: &str, origin: Origin) -> Result<(), ApiError> { + pub async fn set_origin(&mut self, id: &str, origin: &Origin) -> Result<(), ApiError> { let url = self.url.join("origin/")?.join(id)?; - let resp = self.client.post(url).json(&origin).send().await?; + let resp = self.client.post(url).json(origin).send().await?; resp.error_for_status()?; Ok(()) @@ -44,4 +44,13 @@ impl Client { Ok(()) } + + pub async fn patch_origin(&mut self, id: &str, origin: &Origin) -> Result<(), ApiError> { + let url = self.url.join("origin/")?.join(id)?; + + let resp = self.client.patch(url).json(origin).send().await?; + resp.error_for_status()?; + + Ok(()) + } } diff --git a/moq-api/src/model.rs b/moq-api/src/model.rs index 18ab5ed..073e923 100644 --- a/moq-api/src/model.rs +++ b/moq-api/src/model.rs @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize}; use url::Url; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, PartialEq, Eq)] pub struct Origin { pub url: Url, } diff --git a/moq-api/src/server.rs b/moq-api/src/server.rs index 671d83a..d3cf398 100644 --- a/moq-api/src/server.rs +++ b/moq-api/src/server.rs @@ -46,7 +46,13 @@ impl Server { .await?; let app = Router::new() - .route("/origin/:id", get(get_origin).post(set_origin).delete(delete_origin)) + .route( + "/origin/:id", + get(get_origin) + .post(set_origin) + .delete(delete_origin) + .patch(patch_origin), + ) .with_state(redis); log::info!("serving requests: bind={}", self.config.listen); @@ -67,11 +73,8 @@ async fn get_origin( log::debug!("get_origin: id={}", id); - let payload: String = match redis.get(&key).await? 
{ - Some(payload) => payload, - None => return Err(AppError::NotFound), - }; - + let payload: Option<String> = redis.get(&key).await?; + let payload = payload.ok_or(AppError::NotFound)?; let origin: Origin = serde_json::from_str(&payload)?; Ok(Json(origin)) @@ -94,7 +97,7 @@ async fn set_origin( .arg(payload) .arg("NX") .arg("EX") - .arg(60 * 60 * 24 * 2) // Set the key to expire in 2 days; just in case we forget to remove it. + .arg(600) // Set the key to expire in 10 minutes; the origin needs to keep refreshing it. .query_async(&mut redis) .await?; @@ -113,6 +116,31 @@ async fn delete_origin(Path(id): Path<String>, State(mut redis): State<ConnectionManager>) +async fn patch_origin( + Path(id): Path<String>, + State(mut redis): State<ConnectionManager>, + Json(origin): Json<Origin>, +) -> Result<(), AppError> { + let key = origin_key(&id); + + // Make sure the contents haven't changed + // TODO make a LUA script to do this all in one operation. + let payload: Option<String> = redis.get(&key).await?; + let payload = payload.ok_or(AppError::NotFound)?; + let expected: Origin = serde_json::from_str(&payload)?; + + if expected != origin { + return Err(AppError::Duplicate); + } + + // Reset the timeout to 10 minutes. + match redis.expire(key, 600).await? { + 0 => Err(AppError::NotFound), + _ => Ok(()), + } +} + fn origin_key(id: &str) -> String { format!("origin.{}", id) } diff --git a/moq-pub/Cargo.toml b/moq-pub/Cargo.toml index b12c645..2fc3d1d 100644 --- a/moq-pub/Cargo.toml +++ b/moq-pub/Cargo.toml @@ -37,6 +37,8 @@ mp4 = "0.13" anyhow = { version = "1", features = ["backtrace"] } serde_json = "1" rfc6381-codec = "0.1" +tracing = "0.1" +tracing-subscriber = "0.3" [build-dependencies] clap = { version = "4", features = ["derive"] } diff --git a/moq-pub/src/main.rs b/moq-pub/src/main.rs index 2f0d2a8..c753dc5 100644 --- a/moq-pub/src/main.rs +++ b/moq-pub/src/main.rs @@ -15,9 +15,15 @@ use moq_transport::cache::broadcast; async fn main() -> anyhow::Result<()> { env_logger::init(); + // Disable tracing so we don't get a bunch of Quinn spam.
+ let tracer = tracing_subscriber::FmtSubscriber::builder() + .with_max_level(tracing::Level::WARN) + .finish(); + tracing::subscriber::set_global_default(tracer).unwrap(); + let config = Config::parse(); - let (publisher, subscriber) = broadcast::new(); + let (publisher, subscriber) = broadcast::new(""); let mut media = Media::new(&config, publisher).await?; // Ugh, just let me use my native root certs already diff --git a/moq-relay/src/origin.rs b/moq-relay/src/origin.rs index b9f96fb..9c7535c 100644 --- a/moq-relay/src/origin.rs +++ b/moq-relay/src/origin.rs @@ -1,25 +1,31 @@ +use std::ops::{Deref, DerefMut}; use std::{ - collections::{hash_map, HashMap}, - sync::{Arc, Mutex}, + collections::HashMap, + sync::{Arc, Mutex, Weak}, }; +use moq_api::ApiError; use moq_transport::cache::{broadcast, CacheError}; use url::Url; +use tokio::time; + use crate::RelayError; #[derive(Clone)] pub struct Origin { // An API client used to get/set broadcasts. // If None then we never use a remote origin. + // TODO: Stub this out instead. api: Option<moq_api::Client>, // The internal address of our node. // If None then we can never advertise ourselves as an origin. + // TODO: Stub this out instead. node: Option<Url>, - // A map of active broadcasts. - lookup: Arc<Mutex<HashMap<String, broadcast::Subscriber>>>, + // A map of active broadcasts by ID. + cache: Arc<Mutex<HashMap<String, Weak<Subscriber>>>>, // A QUIC endpoint we'll use to fetch from other origins. quic: quinn::Endpoint, @@ -30,48 +36,80 @@ impl Origin { Self { api, node, - lookup: Default::default(), + cache: Default::default(), quic, } } - pub async fn create_broadcast(&mut self, id: &str) -> Result<broadcast::Publisher, RelayError> { - let (publisher, subscriber) = broadcast::new(); + /// Create a new broadcast with the given ID. + /// + /// Publisher::run needs to be called to periodically refresh the origin cache. + pub async fn publish(&mut self, id: &str) -> Result<Publisher, RelayError> { + let (publisher, subscriber) = broadcast::new(id); - // Check if a broadcast already exists by that id.
- match self.lookup.lock().unwrap().entry(id.to_string()) { - hash_map::Entry::Occupied(_) => return Err(CacheError::Duplicate.into()), - hash_map::Entry::Vacant(v) => v.insert(subscriber), + let subscriber = { + let mut cache = self.cache.lock().unwrap(); + + // Check if the broadcast already exists. + // TODO This is racey, because a new publisher could be created while existing subscribers are still active. + if cache.contains_key(id) { + return Err(CacheError::Duplicate.into()); + } + + // Create subscriber that will remove from the cache when dropped. + let subscriber = Arc::new(Subscriber { + broadcast: subscriber, + origin: self.clone(), + }); + + cache.insert(id.to_string(), Arc::downgrade(&subscriber)); + + subscriber }; - if let Some(ref mut api) = self.api { + // Create a publisher that constantly updates itself as the origin in moq-api. + // It holds a reference to the subscriber to prevent dropping early. + let mut publisher = Publisher { + broadcast: publisher, + subscriber, + api: None, + }; + + // Insert the publisher into the database. + if let Some(api) = self.api.as_mut() { // Make a URL for the broadcast. 
let url = self.node.as_ref().ok_or(RelayError::MissingNode)?.clone().join(id)?; + let origin = moq_api::Origin { url }; + api.set_origin(id, &origin).await?; - log::info!("announcing origin: id={} url={}", id, url); - - let entry = moq_api::Origin { url }; - - if let Err(err) = api.set_origin(id, entry).await { - self.lookup.lock().unwrap().remove(id); - return Err(err.into()); - } + // Refresh every 5 minutes + publisher.api = Some((api.clone(), origin)); } Ok(publisher) } - pub fn get_broadcast(&self, id: &str) -> broadcast::Subscriber { - let mut lookup = self.lookup.lock().unwrap(); + pub fn subscribe(&self, id: &str) -> Arc<Subscriber> { + let mut cache = self.cache.lock().unwrap(); - if let Some(broadcast) = lookup.get(id) { - if broadcast.closed().is_none() { - return broadcast.clone(); + if let Some(broadcast) = cache.get(id) { + if let Some(broadcast) = broadcast.upgrade() { + log::debug!("returned broadcast from cache: id={}", id); + return broadcast; + } else { + log::debug!("stale broadcast in cache somehow: id={}", id); } } - let (publisher, subscriber) = broadcast::new(); - lookup.insert(id.to_string(), subscriber.clone()); + let (publisher, subscriber) = broadcast::new(id); + let subscriber = Arc::new(Subscriber { + broadcast: subscriber, + origin: self.clone(), + }); + + cache.insert(id.to_string(), Arc::downgrade(&subscriber)); + + log::debug!("fetching into cache: id={}", id); let mut this = self.clone(); let id = id.to_string(); @@ -82,63 +120,104 @@ impl Origin { // However, the downside is that we don't return an error immediately. // If that's important, it can be done but it gets a bit racey.
tokio::spawn(async move { - match this.fetch_broadcast(&id).await { - Ok(session) => { - if let Err(err) = this.run_broadcast(session, publisher).await { - log::warn!("failed to run broadcast: id={} err={:#?}", id, err); - } - } - Err(err) => { - log::warn!("failed to fetch broadcast: id={} err={:#?}", id, err); - publisher.close(CacheError::NotFound).ok(); - } + if let Err(err) = this.serve(&id, publisher).await { + log::warn!("failed to serve remote broadcast: id={} err={}", id, err); } }); subscriber } - async fn fetch_broadcast(&mut self, id: &str) -> Result<webtransport_quinn::Session, RelayError> { + async fn serve(&mut self, id: &str, publisher: broadcast::Publisher) -> Result<(), RelayError> { + log::debug!("finding origin: id={}", id); + // Fetch the origin from the API. - let api = match self.api { - Some(ref mut api) => api, + let origin = self + .api + .as_mut() + .ok_or(CacheError::NotFound)? + .get_origin(id) + .await? + .ok_or(CacheError::NotFound)?; - // We return NotFound here instead of earlier just to simulate an API fetch. - None => return Err(CacheError::NotFound.into()), - }; - - log::info!("fetching origin: id={}", id); - - let origin = api.get_origin(id).await?.ok_or(CacheError::NotFound)?; - - log::info!("connecting to origin: url={}", origin.url); // Establish the webtransport session.
let session = webtransport_quinn::connect(&self.quic, &origin.url).await?; - - Ok(session) - } - - async fn run_broadcast( - &mut self, - session: webtransport_quinn::Session, - broadcast: broadcast::Publisher, - ) -> Result<(), RelayError> { - let session = moq_transport::session::Client::subscriber(session, broadcast).await?; + let session = moq_transport::session::Client::subscriber(session, publisher).await?; session.run().await?; Ok(()) } +} - pub async fn remove_broadcast(&mut self, id: &str) -> Result<(), RelayError> { - self.lookup.lock().unwrap().remove(id).ok_or(CacheError::NotFound)?; +pub struct Subscriber { + pub broadcast: broadcast::Subscriber, - if let Some(ref mut api) = self.api { - log::info!("deleting origin: id={}", id); - api.delete_origin(id).await?; + origin: Origin, +} + +impl Drop for Subscriber { + fn drop(&mut self) { + log::debug!("subscriber: removing from cache: id={}", self.id); + self.origin.cache.lock().unwrap().remove(&self.broadcast.id); + } +} + +impl Deref for Subscriber { + type Target = broadcast::Subscriber; + + fn deref(&self) -> &Self::Target { + &self.broadcast + } +} + +pub struct Publisher { + pub broadcast: broadcast::Publisher, + + api: Option<(moq_api::Client, moq_api::Origin)>, + + #[allow(dead_code)] + subscriber: Arc<Subscriber>, +} + +impl Publisher { + pub async fn run(&mut self) -> Result<(), ApiError> { + // Every 5m tell the API we're still alive.
+ // TODO don't hard-code these values + let mut interval = time::interval(time::Duration::from_secs(60 * 5)); + + loop { + if let Some((api, origin)) = self.api.as_mut() { + log::debug!("refreshing origin: id={}", self.broadcast.id); + api.patch_origin(&self.broadcast.id, origin).await?; + } + + // TODO move to start of loop; this is just for testing + interval.tick().await; + } + } + + pub async fn close(&mut self) -> Result<(), ApiError> { + if let Some((api, _)) = self.api.as_mut() { + api.delete_origin(&self.broadcast.id).await?; } Ok(()) } } + +impl Deref for Publisher { + type Target = broadcast::Publisher; + + fn deref(&self) -> &Self::Target { + &self.broadcast + } +} + +impl DerefMut for Publisher { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.broadcast + } +} diff --git a/moq-relay/src/quic.rs b/moq-relay/src/quic.rs index 2fae446..c94a81b 100644 --- a/moq-relay/src/quic.rs +++ b/moq-relay/src/quic.rs @@ -30,6 +30,7 @@ impl Quic { transport_config.max_idle_timeout(Some(time::Duration::from_secs(10).try_into().unwrap())); transport_config.keep_alive_interval(Some(time::Duration::from_secs(4))); // TODO make this smarter transport_config.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default())); + transport_config.mtu_discovery_config(None); // Disable MTU discovery let transport_config = Arc::new(transport_config); let mut client_config = quinn::ClientConfig::new(Arc::new(client_config)); diff --git a/moq-relay/src/session.rs b/moq-relay/src/session.rs index 59403cf..a7570f7 100644 --- a/moq-relay/src/session.rs +++ b/moq-relay/src/session.rs @@ -1,6 +1,6 @@ use anyhow::Context; -use moq_transport::{cache::broadcast, session::Request, setup::Role, MoqError}; +use moq_transport::{session::Request, setup::Role, MoqError}; use crate::Origin; @@ -53,8 +53,16 @@ impl Session { let role = request.role(); match role { - Role::Publisher => self.serve_publisher(id, request, &path).await, - Role::Subscriber => 
self.serve_subscriber(id, request, &path).await, + Role::Publisher => { + if let Err(err) = self.serve_publisher(id, request, &path).await { + log::warn!("error serving publisher: id={} path={} err={:#?}", id, path, err); + } + } + Role::Subscriber => { + if let Err(err) = self.serve_subscriber(id, request, &path).await { + log::warn!("error serving subscriber: id={} path={} err={:#?}", id, path, err); + } + } Role::Both => { log::warn!("role both not supported: id={}", id); request.reject(300); @@ -66,44 +74,38 @@ impl Session { Ok(()) } - async fn serve_publisher(&mut self, id: usize, request: Request, path: &str) { + async fn serve_publisher(&mut self, id: usize, request: Request, path: &str) -> anyhow::Result<()> { log::info!("serving publisher: id={}, path={}", id, path); - let broadcast = match self.origin.create_broadcast(path).await { - Ok(broadcast) => broadcast, + let mut origin = match self.origin.publish(path).await { + Ok(origin) => origin, Err(err) => { - log::warn!("error accepting publisher: id={} path={} err={:#?}", id, path, err); - return request.reject(err.code()); + request.reject(err.code()); + return Err(err.into()); } }; - if let Err(err) = self.run_publisher(request, broadcast).await { - log::warn!("error serving publisher: id={} path={} err={:#?}", id, path, err); - } + let session = request.subscriber(origin.broadcast.clone()).await?; - // TODO can we do this on drop? Otherwise we might miss it. - self.origin.remove_broadcast(path).await.ok(); - } + tokio::select! 
{ + _ = session.run() => origin.close().await?, + _ = origin.run() => (), // TODO send error to session + }; - async fn run_publisher(&mut self, request: Request, publisher: broadcast::Publisher) -> anyhow::Result<()> { - let session = request.subscriber(publisher).await?; - session.run().await?; Ok(()) } - async fn serve_subscriber(&mut self, id: usize, request: Request, path: &str) { + async fn serve_subscriber(&mut self, id: usize, request: Request, path: &str) -> anyhow::Result<()> { log::info!("serving subscriber: id={} path={}", id, path); - let broadcast = self.origin.get_broadcast(path); + let subscriber = self.origin.subscribe(path); - if let Err(err) = self.run_subscriber(request, broadcast).await { - log::warn!("error serving subscriber: id={} path={} err={:#?}", id, path, err); - } - } - - async fn run_subscriber(&mut self, request: Request, broadcast: broadcast::Subscriber) -> anyhow::Result<()> { - let session = request.publisher(broadcast).await?; + let session = request.publisher(subscriber.broadcast.clone()).await?; session.run().await?; + + // Make sure this doesn't get dropped too early + drop(subscriber); + Ok(()) } } diff --git a/moq-transport/src/cache/broadcast.rs b/moq-transport/src/cache/broadcast.rs index 58c28fa..d485c30 100644 --- a/moq-transport/src/cache/broadcast.rs +++ b/moq-transport/src/cache/broadcast.rs @@ -13,21 +13,29 @@ use std::{ collections::{hash_map, HashMap, VecDeque}, fmt, + ops::Deref, sync::Arc, }; use super::{track, CacheError, Watch}; /// Create a new broadcast. -pub fn new() -> (Publisher, Subscriber) { +pub fn new(id: &str) -> (Publisher, Subscriber) { let state = Watch::new(State::default()); + let info = Arc::new(Info { id: id.to_string() }); - let publisher = Publisher::new(state.clone()); - let subscriber = Subscriber::new(state); + let publisher = Publisher::new(state.clone(), info.clone()); + let subscriber = Subscriber::new(state, info); (publisher, subscriber) } +/// Static information about a broadcast. 
+#[derive(Debug)] +pub struct Info { + pub id: String, +} + /// Dynamic information about the broadcast. #[derive(Debug)] struct State { @@ -105,13 +113,14 @@ impl Default for State { #[derive(Clone)] pub struct Publisher { state: Watch<State>, + info: Arc<Info>, _dropped: Arc<Dropped>, } impl Publisher { - fn new(state: Watch<State>) -> Self { + fn new(state: Watch<State>, info: Arc<Info>) -> Self { let _dropped = Arc::new(Dropped::new(state.clone())); - Self { state, _dropped } + Self { state, info, _dropped } } /// Create a new track with the given name, inserting it into the broadcast. @@ -148,9 +157,20 @@ impl Publisher { } } +impl Deref for Publisher { + type Target = Info; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + impl fmt::Debug for Publisher { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Publisher").field("state", &self.state).finish() + f.debug_struct("Publisher") + .field("state", &self.state) + .field("info", &self.info) + .finish() } } @@ -160,13 +180,14 @@ impl fmt::Debug for Publisher { #[derive(Clone)] pub struct Subscriber { state: Watch<State>, + info: Arc<Info>, _dropped: Arc<Dropped>, } impl Subscriber { - fn new(state: Watch<State>) -> Self { + fn new(state: Watch<State>, info: Arc<Info>) -> Self { let _dropped = Arc::new(Dropped::new(state.clone())); - Self { state, _dropped } + Self { state, info, _dropped } } /// Get a track from the broadcast by name. @@ -182,15 +203,42 @@ impl Subscriber { state.into_mut().request(name) } - /// Return if the broadcast is closed, either because the publisher was dropped or called [Publisher::close]. - pub fn closed(&self) -> Option<CacheError> { + /// Check if the broadcast is closed, either because the publisher was dropped or called [Publisher::close]. + pub fn is_closed(&self) -> Option<CacheError> { self.state.lock().closed.as_ref().err().cloned() } + + /// Wait until the broadcast is closed, either because the publisher was dropped or called [Publisher::close].
+ pub async fn closed(&self) -> CacheError { + loop { + let notify = { + let state = self.state.lock(); + if let Some(err) = state.closed.as_ref().err() { + return err.clone(); + } + + state.changed() + }; + + notify.await; + } + } +} + +impl Deref for Subscriber { + type Target = Info; + + fn deref(&self) -> &Self::Target { + &self.info + } } impl fmt::Debug for Subscriber { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Subscriber").field("state", &self.state).finish() + f.debug_struct("Subscriber") + .field("state", &self.state) + .field("info", &self.info) + .finish() } } diff --git a/moq-transport/src/session/publisher.rs b/moq-transport/src/session/publisher.rs index 4ae3c60..5c8e761 100644 --- a/moq-transport/src/session/publisher.rs +++ b/moq-transport/src/session/publisher.rs @@ -64,7 +64,7 @@ impl Publisher { stream?; return Err(SessionError::RoleViolation(VarInt::ZERO)); } - // NOTE: this is not cancel safe, but it's fine since the other branch is a fatal error. + // NOTE: this is not cancel safe, but it's fine since the other branches are fatal. msg = self.control.recv() => { let msg = msg?; @@ -72,7 +72,12 @@ impl Publisher { if let Err(err) = self.recv_message(&msg).await { log::warn!("message error: {:?} {:?}", err, msg); } - } + }, + // No more broadcasts are available.
+ err = self.source.closed() => { + self.webtransport.close(err.code(), err.reason().as_bytes()); + return Ok(()); + }, } } } @@ -178,7 +183,7 @@ impl Publisher { expires: segment.expires, }; - log::debug!("serving object: {:?}", object); + log::trace!("serving object: {:?}", object); let mut stream = self.webtransport.open_uni().await?; stream.set_priority(object.priority).ok(); diff --git a/moq-transport/src/session/subscriber.rs b/moq-transport/src/session/subscriber.rs index 3148c2b..700cfa8 100644 --- a/moq-transport/src/session/subscriber.rs +++ b/moq-transport/src/session/subscriber.rs @@ -135,6 +135,7 @@ impl Subscriber { } async fn run_source(mut self) -> Result<(), SessionError> { + // NOTE: This returns Closed when the source is closed. while let Some(track) = self.source.next_track().await? { let name = track.name.clone();