Remove subscribers/publisher on close (#103)
This commit is contained in:
parent
a30f313439
commit
53817f41e7
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -1068,6 +1068,8 @@ dependencies = [
|
||||
"rustls-native-certs",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"url",
|
||||
"webtransport-quinn",
|
||||
]
|
||||
|
6
dev/api
6
dev/api
@ -4,14 +4,14 @@ set -euo pipefail
|
||||
# Change directory to the root of the project
|
||||
cd "$(dirname "$0")/.."
|
||||
|
||||
# Use debug logging by default
|
||||
export RUST_LOG="${RUST_LOG:-debug}"
|
||||
|
||||
# Run the API server on port 4442 by default
|
||||
HOST="${HOST:-[::]}"
|
||||
PORT="${PORT:-4442}"
|
||||
LISTEN="${LISTEN:-$HOST:$PORT}"
|
||||
|
||||
# Default to info log level
|
||||
export RUST_LOG="${RUST_LOG:-info}"
|
||||
|
||||
# Check for Podman/Docker and set runtime accordingly
|
||||
if command -v podman &> /dev/null; then
|
||||
RUNTIME=podman
|
||||
|
3
dev/pub
3
dev/pub
@ -4,7 +4,8 @@ set -euo pipefail
|
||||
# Change directory to the root of the project
|
||||
cd "$(dirname "$0")/.."
|
||||
|
||||
export RUST_LOG="${RUST_LOG:-info}"
|
||||
# Use debug logging by default
|
||||
export RUST_LOG="${RUST_LOG:-debug}"
|
||||
|
||||
# Connect to localhost by default.
|
||||
HOST="${HOST:-localhost}"
|
||||
|
@ -4,8 +4,8 @@ set -euo pipefail
|
||||
# Change directory to the root of the project
|
||||
cd "$(dirname "$0")/.."
|
||||
|
||||
# Use info logging by default
|
||||
export RUST_LOG="${RUST_LOG:-info}"
|
||||
# Use debug logging by default
|
||||
export RUST_LOG="${RUST_LOG:-debug}"
|
||||
|
||||
# Default to a self-signed certificate
|
||||
# TODO automatically generate if it doesn't exist.
|
||||
|
@ -27,10 +27,10 @@ impl Client {
|
||||
Ok(Some(origin))
|
||||
}
|
||||
|
||||
pub async fn set_origin(&mut self, id: &str, origin: Origin) -> Result<(), ApiError> {
|
||||
pub async fn set_origin(&mut self, id: &str, origin: &Origin) -> Result<(), ApiError> {
|
||||
let url = self.url.join("origin/")?.join(id)?;
|
||||
|
||||
let resp = self.client.post(url).json(&origin).send().await?;
|
||||
let resp = self.client.post(url).json(origin).send().await?;
|
||||
resp.error_for_status()?;
|
||||
|
||||
Ok(())
|
||||
@ -44,4 +44,13 @@ impl Client {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn patch_origin(&mut self, id: &str, origin: &Origin) -> Result<(), ApiError> {
|
||||
let url = self.url.join("origin/")?.join(id)?;
|
||||
|
||||
let resp = self.client.patch(url).json(origin).send().await?;
|
||||
resp.error_for_status()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use url::Url;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct Origin {
|
||||
pub url: Url,
|
||||
}
|
||||
|
@ -46,7 +46,13 @@ impl Server {
|
||||
.await?;
|
||||
|
||||
let app = Router::new()
|
||||
.route("/origin/:id", get(get_origin).post(set_origin).delete(delete_origin))
|
||||
.route(
|
||||
"/origin/:id",
|
||||
get(get_origin)
|
||||
.post(set_origin)
|
||||
.delete(delete_origin)
|
||||
.patch(patch_origin),
|
||||
)
|
||||
.with_state(redis);
|
||||
|
||||
log::info!("serving requests: bind={}", self.config.listen);
|
||||
@ -67,11 +73,8 @@ async fn get_origin(
|
||||
|
||||
log::debug!("get_origin: id={}", id);
|
||||
|
||||
let payload: String = match redis.get(&key).await? {
|
||||
Some(payload) => payload,
|
||||
None => return Err(AppError::NotFound),
|
||||
};
|
||||
|
||||
let payload: Option<String> = redis.get(&key).await?;
|
||||
let payload = payload.ok_or(AppError::NotFound)?;
|
||||
let origin: Origin = serde_json::from_str(&payload)?;
|
||||
|
||||
Ok(Json(origin))
|
||||
@ -94,7 +97,7 @@ async fn set_origin(
|
||||
.arg(payload)
|
||||
.arg("NX")
|
||||
.arg("EX")
|
||||
.arg(60 * 60 * 24 * 2) // Set the key to expire in 2 days; just in case we forget to remove it.
|
||||
.arg(600) // Set the key to expire in 10 minutes; the origin needs to keep refreshing it.
|
||||
.query_async(&mut redis)
|
||||
.await?;
|
||||
|
||||
@ -113,6 +116,31 @@ async fn delete_origin(Path(id): Path<String>, State(mut redis): State<Connectio
|
||||
}
|
||||
}
|
||||
|
||||
// Update the expiration deadline.
|
||||
async fn patch_origin(
|
||||
Path(id): Path<String>,
|
||||
State(mut redis): State<ConnectionManager>,
|
||||
Json(origin): Json<Origin>,
|
||||
) -> Result<(), AppError> {
|
||||
let key = origin_key(&id);
|
||||
|
||||
// Make sure the contents haven't changed
|
||||
// TODO make a LUA script to do this all in one operation.
|
||||
let payload: Option<String> = redis.get(&key).await?;
|
||||
let payload = payload.ok_or(AppError::NotFound)?;
|
||||
let expected: Origin = serde_json::from_str(&payload)?;
|
||||
|
||||
if expected != origin {
|
||||
return Err(AppError::Duplicate);
|
||||
}
|
||||
|
||||
// Reset the timeout to 10 minutes.
|
||||
match redis.expire(key, 600).await? {
|
||||
0 => Err(AppError::NotFound),
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
fn origin_key(id: &str) -> String {
|
||||
format!("origin.{}", id)
|
||||
}
|
||||
|
@ -37,6 +37,8 @@ mp4 = "0.13"
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
serde_json = "1"
|
||||
rfc6381-codec = "0.1"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = "0.3"
|
||||
|
||||
[build-dependencies]
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
|
@ -15,9 +15,15 @@ use moq_transport::cache::broadcast;
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
env_logger::init();
|
||||
|
||||
// Disable tracing so we don't get a bunch of Quinn spam.
|
||||
let tracer = tracing_subscriber::FmtSubscriber::builder()
|
||||
.with_max_level(tracing::Level::WARN)
|
||||
.finish();
|
||||
tracing::subscriber::set_global_default(tracer).unwrap();
|
||||
|
||||
let config = Config::parse();
|
||||
|
||||
let (publisher, subscriber) = broadcast::new();
|
||||
let (publisher, subscriber) = broadcast::new("");
|
||||
let mut media = Media::new(&config, publisher).await?;
|
||||
|
||||
// Ugh, just let me use my native root certs already
|
||||
|
@ -1,25 +1,31 @@
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::{
|
||||
collections::{hash_map, HashMap},
|
||||
sync::{Arc, Mutex},
|
||||
collections::HashMap,
|
||||
sync::{Arc, Mutex, Weak},
|
||||
};
|
||||
|
||||
use moq_api::ApiError;
|
||||
use moq_transport::cache::{broadcast, CacheError};
|
||||
use url::Url;
|
||||
|
||||
use tokio::time;
|
||||
|
||||
use crate::RelayError;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Origin {
|
||||
// An API client used to get/set broadcasts.
|
||||
// If None then we never use a remote origin.
|
||||
// TODO: Stub this out instead.
|
||||
api: Option<moq_api::Client>,
|
||||
|
||||
// The internal address of our node.
|
||||
// If None then we can never advertise ourselves as an origin.
|
||||
// TODO: Stub this out instead.
|
||||
node: Option<Url>,
|
||||
|
||||
// A map of active broadcasts.
|
||||
lookup: Arc<Mutex<HashMap<String, broadcast::Subscriber>>>,
|
||||
// A map of active broadcasts by ID.
|
||||
cache: Arc<Mutex<HashMap<String, Weak<Subscriber>>>>,
|
||||
|
||||
// A QUIC endpoint we'll use to fetch from other origins.
|
||||
quic: quinn::Endpoint,
|
||||
@ -30,48 +36,80 @@ impl Origin {
|
||||
Self {
|
||||
api,
|
||||
node,
|
||||
lookup: Default::default(),
|
||||
cache: Default::default(),
|
||||
quic,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_broadcast(&mut self, id: &str) -> Result<broadcast::Publisher, RelayError> {
|
||||
let (publisher, subscriber) = broadcast::new();
|
||||
/// Create a new broadcast with the given ID.
|
||||
///
|
||||
/// Publisher::run needs to be called to periodically refresh the origin cache.
|
||||
pub async fn publish(&mut self, id: &str) -> Result<Publisher, RelayError> {
|
||||
let (publisher, subscriber) = broadcast::new(id);
|
||||
|
||||
// Check if a broadcast already exists by that id.
|
||||
match self.lookup.lock().unwrap().entry(id.to_string()) {
|
||||
hash_map::Entry::Occupied(_) => return Err(CacheError::Duplicate.into()),
|
||||
hash_map::Entry::Vacant(v) => v.insert(subscriber),
|
||||
let subscriber = {
|
||||
let mut cache = self.cache.lock().unwrap();
|
||||
|
||||
// Check if the broadcast already exists.
|
||||
// TODO This is racey, because a new publisher could be created while existing subscribers are still active.
|
||||
if cache.contains_key(id) {
|
||||
return Err(CacheError::Duplicate.into());
|
||||
}
|
||||
|
||||
// Create subscriber that will remove from the cache when dropped.
|
||||
let subscriber = Arc::new(Subscriber {
|
||||
broadcast: subscriber,
|
||||
origin: self.clone(),
|
||||
});
|
||||
|
||||
cache.insert(id.to_string(), Arc::downgrade(&subscriber));
|
||||
|
||||
subscriber
|
||||
};
|
||||
|
||||
if let Some(ref mut api) = self.api {
|
||||
// Create a publisher that constantly updates itself as the origin in moq-api.
|
||||
// It holds a reference to the subscriber to prevent dropping early.
|
||||
let mut publisher = Publisher {
|
||||
broadcast: publisher,
|
||||
subscriber,
|
||||
api: None,
|
||||
};
|
||||
|
||||
// Insert the publisher into the database.
|
||||
if let Some(api) = self.api.as_mut() {
|
||||
// Make a URL for the broadcast.
|
||||
let url = self.node.as_ref().ok_or(RelayError::MissingNode)?.clone().join(id)?;
|
||||
let origin = moq_api::Origin { url };
|
||||
api.set_origin(id, &origin).await?;
|
||||
|
||||
log::info!("announcing origin: id={} url={}", id, url);
|
||||
|
||||
let entry = moq_api::Origin { url };
|
||||
|
||||
if let Err(err) = api.set_origin(id, entry).await {
|
||||
self.lookup.lock().unwrap().remove(id);
|
||||
return Err(err.into());
|
||||
}
|
||||
// Refresh every 5 minutes
|
||||
publisher.api = Some((api.clone(), origin));
|
||||
}
|
||||
|
||||
Ok(publisher)
|
||||
}
|
||||
|
||||
pub fn get_broadcast(&self, id: &str) -> broadcast::Subscriber {
|
||||
let mut lookup = self.lookup.lock().unwrap();
|
||||
pub fn subscribe(&self, id: &str) -> Arc<Subscriber> {
|
||||
let mut cache = self.cache.lock().unwrap();
|
||||
|
||||
if let Some(broadcast) = lookup.get(id) {
|
||||
if broadcast.closed().is_none() {
|
||||
return broadcast.clone();
|
||||
if let Some(broadcast) = cache.get(id) {
|
||||
if let Some(broadcast) = broadcast.upgrade() {
|
||||
log::debug!("returned broadcast from cache: id={}", id);
|
||||
return broadcast;
|
||||
} else {
|
||||
log::debug!("stale broadcast in cache somehow: id={}", id);
|
||||
}
|
||||
}
|
||||
|
||||
let (publisher, subscriber) = broadcast::new();
|
||||
lookup.insert(id.to_string(), subscriber.clone());
|
||||
let (publisher, subscriber) = broadcast::new(id);
|
||||
let subscriber = Arc::new(Subscriber {
|
||||
broadcast: subscriber,
|
||||
origin: self.clone(),
|
||||
});
|
||||
|
||||
cache.insert(id.to_string(), Arc::downgrade(&subscriber));
|
||||
|
||||
log::debug!("fetching into cache: id={}", id);
|
||||
|
||||
let mut this = self.clone();
|
||||
let id = id.to_string();
|
||||
@ -82,63 +120,104 @@ impl Origin {
|
||||
// However, the downside is that we don't return an error immediately.
|
||||
// If that's important, it can be done but it gets a bit racey.
|
||||
tokio::spawn(async move {
|
||||
match this.fetch_broadcast(&id).await {
|
||||
Ok(session) => {
|
||||
if let Err(err) = this.run_broadcast(session, publisher).await {
|
||||
log::warn!("failed to run broadcast: id={} err={:#?}", id, err);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
log::warn!("failed to fetch broadcast: id={} err={:#?}", id, err);
|
||||
publisher.close(CacheError::NotFound).ok();
|
||||
}
|
||||
if let Err(err) = this.serve(&id, publisher).await {
|
||||
log::warn!("failed to serve remote broadcast: id={} err={}", id, err);
|
||||
}
|
||||
});
|
||||
|
||||
subscriber
|
||||
}
|
||||
|
||||
async fn fetch_broadcast(&mut self, id: &str) -> Result<webtransport_quinn::Session, RelayError> {
|
||||
async fn serve(&mut self, id: &str, publisher: broadcast::Publisher) -> Result<(), RelayError> {
|
||||
log::debug!("finding origin: id={}", id);
|
||||
|
||||
// Fetch the origin from the API.
|
||||
let api = match self.api {
|
||||
Some(ref mut api) => api,
|
||||
let origin = self
|
||||
.api
|
||||
.as_mut()
|
||||
.ok_or(CacheError::NotFound)?
|
||||
.get_origin(id)
|
||||
.await?
|
||||
.ok_or(CacheError::NotFound)?;
|
||||
|
||||
// We return NotFound here instead of earlier just to simulate an API fetch.
|
||||
None => return Err(CacheError::NotFound.into()),
|
||||
};
|
||||
|
||||
log::info!("fetching origin: id={}", id);
|
||||
|
||||
let origin = api.get_origin(id).await?.ok_or(CacheError::NotFound)?;
|
||||
|
||||
log::info!("connecting to origin: url={}", origin.url);
|
||||
log::debug!("fetching from origin: id={} url={}", id, origin.url);
|
||||
|
||||
// Establish the webtransport session.
|
||||
let session = webtransport_quinn::connect(&self.quic, &origin.url).await?;
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
async fn run_broadcast(
|
||||
&mut self,
|
||||
session: webtransport_quinn::Session,
|
||||
broadcast: broadcast::Publisher,
|
||||
) -> Result<(), RelayError> {
|
||||
let session = moq_transport::session::Client::subscriber(session, broadcast).await?;
|
||||
let session = moq_transport::session::Client::subscriber(session, publisher).await?;
|
||||
|
||||
session.run().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn remove_broadcast(&mut self, id: &str) -> Result<(), RelayError> {
|
||||
self.lookup.lock().unwrap().remove(id).ok_or(CacheError::NotFound)?;
|
||||
pub struct Subscriber {
|
||||
pub broadcast: broadcast::Subscriber,
|
||||
|
||||
if let Some(ref mut api) = self.api {
|
||||
log::info!("deleting origin: id={}", id);
|
||||
api.delete_origin(id).await?;
|
||||
origin: Origin,
|
||||
}
|
||||
|
||||
impl Drop for Subscriber {
|
||||
fn drop(&mut self) {
|
||||
log::debug!("subscriber: removing from cache: id={}", self.id);
|
||||
self.origin.cache.lock().unwrap().remove(&self.broadcast.id);
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Subscriber {
|
||||
type Target = broadcast::Subscriber;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.broadcast
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Publisher {
|
||||
pub broadcast: broadcast::Publisher,
|
||||
|
||||
api: Option<(moq_api::Client, moq_api::Origin)>,
|
||||
|
||||
#[allow(dead_code)]
|
||||
subscriber: Arc<Subscriber>,
|
||||
}
|
||||
|
||||
impl Publisher {
|
||||
pub async fn run(&mut self) -> Result<(), ApiError> {
|
||||
// Every 5m tell the API we're still alive.
|
||||
// TODO don't hard-code these values
|
||||
let mut interval = time::interval(time::Duration::from_secs(60 * 5));
|
||||
|
||||
loop {
|
||||
if let Some((api, origin)) = self.api.as_mut() {
|
||||
log::debug!("refreshing origin: id={}", self.broadcast.id);
|
||||
api.patch_origin(&self.broadcast.id, origin).await?;
|
||||
}
|
||||
|
||||
// TODO move to start of loop; this is just for testing
|
||||
interval.tick().await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn close(&mut self) -> Result<(), ApiError> {
|
||||
if let Some((api, _)) = self.api.as_mut() {
|
||||
api.delete_origin(&self.broadcast.id).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Publisher {
|
||||
type Target = broadcast::Publisher;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.broadcast
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for Publisher {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.broadcast
|
||||
}
|
||||
}
|
||||
|
@ -30,6 +30,7 @@ impl Quic {
|
||||
transport_config.max_idle_timeout(Some(time::Duration::from_secs(10).try_into().unwrap()));
|
||||
transport_config.keep_alive_interval(Some(time::Duration::from_secs(4))); // TODO make this smarter
|
||||
transport_config.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
|
||||
transport_config.mtu_discovery_config(None); // Disable MTU discovery
|
||||
let transport_config = Arc::new(transport_config);
|
||||
|
||||
let mut client_config = quinn::ClientConfig::new(Arc::new(client_config));
|
||||
|
@ -1,6 +1,6 @@
|
||||
use anyhow::Context;
|
||||
|
||||
use moq_transport::{cache::broadcast, session::Request, setup::Role, MoqError};
|
||||
use moq_transport::{session::Request, setup::Role, MoqError};
|
||||
|
||||
use crate::Origin;
|
||||
|
||||
@ -53,8 +53,16 @@ impl Session {
|
||||
let role = request.role();
|
||||
|
||||
match role {
|
||||
Role::Publisher => self.serve_publisher(id, request, &path).await,
|
||||
Role::Subscriber => self.serve_subscriber(id, request, &path).await,
|
||||
Role::Publisher => {
|
||||
if let Err(err) = self.serve_publisher(id, request, &path).await {
|
||||
log::warn!("error serving publisher: id={} path={} err={:#?}", id, path, err);
|
||||
}
|
||||
}
|
||||
Role::Subscriber => {
|
||||
if let Err(err) = self.serve_subscriber(id, request, &path).await {
|
||||
log::warn!("error serving subscriber: id={} path={} err={:#?}", id, path, err);
|
||||
}
|
||||
}
|
||||
Role::Both => {
|
||||
log::warn!("role both not supported: id={}", id);
|
||||
request.reject(300);
|
||||
@ -66,44 +74,38 @@ impl Session {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn serve_publisher(&mut self, id: usize, request: Request, path: &str) {
|
||||
async fn serve_publisher(&mut self, id: usize, request: Request, path: &str) -> anyhow::Result<()> {
|
||||
log::info!("serving publisher: id={}, path={}", id, path);
|
||||
|
||||
let broadcast = match self.origin.create_broadcast(path).await {
|
||||
Ok(broadcast) => broadcast,
|
||||
let mut origin = match self.origin.publish(path).await {
|
||||
Ok(origin) => origin,
|
||||
Err(err) => {
|
||||
log::warn!("error accepting publisher: id={} path={} err={:#?}", id, path, err);
|
||||
return request.reject(err.code());
|
||||
request.reject(err.code());
|
||||
return Err(err.into());
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = self.run_publisher(request, broadcast).await {
|
||||
log::warn!("error serving publisher: id={} path={} err={:#?}", id, path, err);
|
||||
}
|
||||
let session = request.subscriber(origin.broadcast.clone()).await?;
|
||||
|
||||
// TODO can we do this on drop? Otherwise we might miss it.
|
||||
self.origin.remove_broadcast(path).await.ok();
|
||||
}
|
||||
tokio::select! {
|
||||
_ = session.run() => origin.close().await?,
|
||||
_ = origin.run() => (), // TODO send error to session
|
||||
};
|
||||
|
||||
async fn run_publisher(&mut self, request: Request, publisher: broadcast::Publisher) -> anyhow::Result<()> {
|
||||
let session = request.subscriber(publisher).await?;
|
||||
session.run().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn serve_subscriber(&mut self, id: usize, request: Request, path: &str) {
|
||||
async fn serve_subscriber(&mut self, id: usize, request: Request, path: &str) -> anyhow::Result<()> {
|
||||
log::info!("serving subscriber: id={} path={}", id, path);
|
||||
|
||||
let broadcast = self.origin.get_broadcast(path);
|
||||
let subscriber = self.origin.subscribe(path);
|
||||
|
||||
if let Err(err) = self.run_subscriber(request, broadcast).await {
|
||||
log::warn!("error serving subscriber: id={} path={} err={:#?}", id, path, err);
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_subscriber(&mut self, request: Request, broadcast: broadcast::Subscriber) -> anyhow::Result<()> {
|
||||
let session = request.publisher(broadcast).await?;
|
||||
let session = request.publisher(subscriber.broadcast.clone()).await?;
|
||||
session.run().await?;
|
||||
|
||||
// Make sure this doesn't get dropped too early
|
||||
drop(subscriber);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
70
moq-transport/src/cache/broadcast.rs
vendored
70
moq-transport/src/cache/broadcast.rs
vendored
@ -13,21 +13,29 @@
|
||||
use std::{
|
||||
collections::{hash_map, HashMap, VecDeque},
|
||||
fmt,
|
||||
ops::Deref,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use super::{track, CacheError, Watch};
|
||||
|
||||
/// Create a new broadcast.
|
||||
pub fn new() -> (Publisher, Subscriber) {
|
||||
pub fn new(id: &str) -> (Publisher, Subscriber) {
|
||||
let state = Watch::new(State::default());
|
||||
let info = Arc::new(Info { id: id.to_string() });
|
||||
|
||||
let publisher = Publisher::new(state.clone());
|
||||
let subscriber = Subscriber::new(state);
|
||||
let publisher = Publisher::new(state.clone(), info.clone());
|
||||
let subscriber = Subscriber::new(state, info);
|
||||
|
||||
(publisher, subscriber)
|
||||
}
|
||||
|
||||
/// Static information about a broadcast.
|
||||
#[derive(Debug)]
|
||||
pub struct Info {
|
||||
pub id: String,
|
||||
}
|
||||
|
||||
/// Dynamic information about the broadcast.
|
||||
#[derive(Debug)]
|
||||
struct State {
|
||||
@ -105,13 +113,14 @@ impl Default for State {
|
||||
#[derive(Clone)]
|
||||
pub struct Publisher {
|
||||
state: Watch<State>,
|
||||
info: Arc<Info>,
|
||||
_dropped: Arc<Dropped>,
|
||||
}
|
||||
|
||||
impl Publisher {
|
||||
fn new(state: Watch<State>) -> Self {
|
||||
fn new(state: Watch<State>, info: Arc<Info>) -> Self {
|
||||
let _dropped = Arc::new(Dropped::new(state.clone()));
|
||||
Self { state, _dropped }
|
||||
Self { state, info, _dropped }
|
||||
}
|
||||
|
||||
/// Create a new track with the given name, inserting it into the broadcast.
|
||||
@ -148,9 +157,20 @@ impl Publisher {
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Publisher {
|
||||
type Target = Info;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.info
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Publisher {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Publisher").field("state", &self.state).finish()
|
||||
f.debug_struct("Publisher")
|
||||
.field("state", &self.state)
|
||||
.field("info", &self.info)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
@ -160,13 +180,14 @@ impl fmt::Debug for Publisher {
|
||||
#[derive(Clone)]
|
||||
pub struct Subscriber {
|
||||
state: Watch<State>,
|
||||
info: Arc<Info>,
|
||||
_dropped: Arc<Dropped>,
|
||||
}
|
||||
|
||||
impl Subscriber {
|
||||
fn new(state: Watch<State>) -> Self {
|
||||
fn new(state: Watch<State>, info: Arc<Info>) -> Self {
|
||||
let _dropped = Arc::new(Dropped::new(state.clone()));
|
||||
Self { state, _dropped }
|
||||
Self { state, info, _dropped }
|
||||
}
|
||||
|
||||
/// Get a track from the broadcast by name.
|
||||
@ -182,15 +203,42 @@ impl Subscriber {
|
||||
state.into_mut().request(name)
|
||||
}
|
||||
|
||||
/// Return if the broadcast is closed, either because the publisher was dropped or called [Publisher::close].
|
||||
pub fn closed(&self) -> Option<CacheError> {
|
||||
/// Check if the broadcast is closed, either because the publisher was dropped or called [Publisher::close].
|
||||
pub fn is_closed(&self) -> Option<CacheError> {
|
||||
self.state.lock().closed.as_ref().err().cloned()
|
||||
}
|
||||
|
||||
/// Wait until if the broadcast is closed, either because the publisher was dropped or called [Publisher::close].
|
||||
pub async fn closed(&self) -> CacheError {
|
||||
loop {
|
||||
let notify = {
|
||||
let state = self.state.lock();
|
||||
if let Some(err) = state.closed.as_ref().err() {
|
||||
return err.clone();
|
||||
}
|
||||
|
||||
state.changed()
|
||||
};
|
||||
|
||||
notify.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Subscriber {
|
||||
type Target = Info;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.info
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Subscriber {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Subscriber").field("state", &self.state).finish()
|
||||
f.debug_struct("Subscriber")
|
||||
.field("state", &self.state)
|
||||
.field("info", &self.info)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -64,7 +64,7 @@ impl Publisher {
|
||||
stream?;
|
||||
return Err(SessionError::RoleViolation(VarInt::ZERO));
|
||||
}
|
||||
// NOTE: this is not cancel safe, but it's fine since the other branch is a fatal error.
|
||||
// NOTE: this is not cancel safe, but it's fine since the other branchs are fatal.
|
||||
msg = self.control.recv() => {
|
||||
let msg = msg?;
|
||||
|
||||
@ -72,7 +72,12 @@ impl Publisher {
|
||||
if let Err(err) = self.recv_message(&msg).await {
|
||||
log::warn!("message error: {:?} {:?}", err, msg);
|
||||
}
|
||||
}
|
||||
},
|
||||
// No more broadcasts are available.
|
||||
err = self.source.closed() => {
|
||||
self.webtransport.close(err.code(), err.reason().as_bytes());
|
||||
return Ok(());
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -178,7 +183,7 @@ impl Publisher {
|
||||
expires: segment.expires,
|
||||
};
|
||||
|
||||
log::debug!("serving object: {:?}", object);
|
||||
log::trace!("serving object: {:?}", object);
|
||||
|
||||
let mut stream = self.webtransport.open_uni().await?;
|
||||
stream.set_priority(object.priority).ok();
|
||||
|
@ -135,6 +135,7 @@ impl Subscriber {
|
||||
}
|
||||
|
||||
async fn run_source(mut self) -> Result<(), SessionError> {
|
||||
// NOTE: This returns Closed when the source is closed.
|
||||
while let Some(track) = self.source.next_track().await? {
|
||||
let name = track.name.clone();
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user