Remove subscribers/publisher on close (#103)

kixelated 2023-10-20 12:04:55 +09:00 committed by GitHub
parent a30f313439
commit 53817f41e7
15 changed files with 305 additions and 121 deletions

Cargo.lock (generated)

@@ -1068,6 +1068,8 @@ dependencies = [
  "rustls-native-certs",
  "serde_json",
  "tokio",
+ "tracing",
+ "tracing-subscriber",
  "url",
  "webtransport-quinn",
 ]


@@ -4,14 +4,14 @@ set -euo pipefail
 # Change directory to the root of the project
 cd "$(dirname "$0")/.."
 
+# Use debug logging by default
+export RUST_LOG="${RUST_LOG:-debug}"
+
 # Run the API server on port 4442 by default
 HOST="${HOST:-[::]}"
 PORT="${PORT:-4442}"
 LISTEN="${LISTEN:-$HOST:$PORT}"
 
-# Default to info log level
-export RUST_LOG="${RUST_LOG:-info}"
-
 # Check for Podman/Docker and set runtime accordingly
 if command -v podman &> /dev/null; then
 	RUNTIME=podman


@@ -4,7 +4,8 @@ set -euo pipefail
 # Change directory to the root of the project
 cd "$(dirname "$0")/.."
 
-export RUST_LOG="${RUST_LOG:-info}"
+# Use debug logging by default
+export RUST_LOG="${RUST_LOG:-debug}"
 
 # Connect to localhost by default.
 HOST="${HOST:-localhost}"


@@ -4,8 +4,8 @@ set -euo pipefail
 # Change directory to the root of the project
 cd "$(dirname "$0")/.."
 
-# Use info logging by default
-export RUST_LOG="${RUST_LOG:-info}"
+# Use debug logging by default
+export RUST_LOG="${RUST_LOG:-debug}"
 
 # Default to a self-signed certificate
 # TODO automatically generate if it doesn't exist.


@@ -27,10 +27,10 @@ impl Client {
 		Ok(Some(origin))
 	}
 
-	pub async fn set_origin(&mut self, id: &str, origin: Origin) -> Result<(), ApiError> {
+	pub async fn set_origin(&mut self, id: &str, origin: &Origin) -> Result<(), ApiError> {
 		let url = self.url.join("origin/")?.join(id)?;
 
-		let resp = self.client.post(url).json(&origin).send().await?;
+		let resp = self.client.post(url).json(origin).send().await?;
 		resp.error_for_status()?;
 
 		Ok(())
@@ -44,4 +44,13 @@ impl Client {
 
 		Ok(())
 	}
+
+	pub async fn patch_origin(&mut self, id: &str, origin: &Origin) -> Result<(), ApiError> {
+		let url = self.url.join("origin/")?.join(id)?;
+
+		let resp = self.client.patch(url).json(origin).send().await?;
+		resp.error_for_status()?;
+
+		Ok(())
+	}
 }
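
Together with set_origin and delete_origin, the new patch_origin call gives an origin a full keepalive lifecycle against the API: register once, then keep extending the entry before it expires. A minimal sketch of a caller, assuming an already-constructed moq_api::Client and a hypothetical node URL; the 5-minute cadence matches the relay code later in this commit:

```rust
use std::time::Duration;

use moq_api::{ApiError, Client, Origin};
use url::Url;

// Hypothetical keepalive loop for one broadcast ID: register the origin,
// then PATCH it every 5 minutes so the 10-minute Redis TTL never lapses.
async fn announce(mut api: Client, id: &str, node: Url) -> Result<(), ApiError> {
	let origin = Origin { url: node.join(id)? };

	// Fails if another relay already registered this ID (SET ... NX).
	api.set_origin(id, &origin).await?;

	let mut interval = tokio::time::interval(Duration::from_secs(60 * 5));
	loop {
		interval.tick().await;
		// The server verifies the payload still matches before extending it.
		api.patch_origin(id, &origin).await?;
	}
}
```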


@@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize};
 use url::Url;
 
-#[derive(Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, PartialEq, Eq)]
 pub struct Origin {
 	pub url: Url,
 }
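
The new PartialEq/Eq derive exists for the server's patch_origin handler below, which compares the stored Origin against the request body before refreshing it. A small illustration, with a hypothetical URL:

```rust
use moq_api::Origin;

fn main() -> Result<(), Box<dyn std::error::Error>> {
	// patch_origin deserializes the stored JSON and only refreshes the key
	// if it still matches the request body; that comparison needs PartialEq.
	let stored: Origin = serde_json::from_str(r#"{"url":"https://relay.example/demo"}"#)?;
	let incoming = Origin {
		url: "https://relay.example/demo".parse()?,
	};
	assert!(stored == incoming);
	Ok(())
}
```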


@@ -46,7 +46,13 @@ impl Server {
 			.await?;
 
 		let app = Router::new()
-			.route("/origin/:id", get(get_origin).post(set_origin).delete(delete_origin))
+			.route(
+				"/origin/:id",
+				get(get_origin)
+					.post(set_origin)
+					.delete(delete_origin)
+					.patch(patch_origin),
+			)
 			.with_state(redis);
 
 		log::info!("serving requests: bind={}", self.config.listen);
@@ -67,11 +73,8 @@ async fn get_origin(
 	log::debug!("get_origin: id={}", id);
 
-	let payload: String = match redis.get(&key).await? {
-		Some(payload) => payload,
-		None => return Err(AppError::NotFound),
-	};
+	let payload: Option<String> = redis.get(&key).await?;
+	let payload = payload.ok_or(AppError::NotFound)?;
 
 	let origin: Origin = serde_json::from_str(&payload)?;
 	Ok(Json(origin))
@@ -94,7 +97,7 @@ async fn set_origin(
 		.arg(payload)
 		.arg("NX")
 		.arg("EX")
-		.arg(60 * 60 * 24 * 2) // Set the key to expire in 2 days; just in case we forget to remove it.
+		.arg(600) // Set the key to expire in 10 minutes; the origin needs to keep refreshing it.
 		.query_async(&mut redis)
 		.await?;
@@ -113,6 +116,31 @@ async fn delete_origin(Path(id): Path<String>, State(mut redis): State<ConnectionManager>
 	}
 }
 
+// Update the expiration deadline.
+async fn patch_origin(
+	Path(id): Path<String>,
+	State(mut redis): State<ConnectionManager>,
+	Json(origin): Json<Origin>,
+) -> Result<(), AppError> {
+	let key = origin_key(&id);
+
+	// Make sure the contents haven't changed
+	// TODO make a LUA script to do this all in one operation.
+	let payload: Option<String> = redis.get(&key).await?;
+	let payload = payload.ok_or(AppError::NotFound)?;
+
+	let expected: Origin = serde_json::from_str(&payload)?;
+	if expected != origin {
+		return Err(AppError::Duplicate);
+	}
+
+	// Reset the timeout to 10 minutes.
+	match redis.expire(key, 600).await? {
+		0 => Err(AppError::NotFound),
+		_ => Ok(()),
+	}
+}
+
 fn origin_key(id: &str) -> String {
 	format!("origin.{}", id)
 }
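
The GET-compare-EXPIRE sequence in patch_origin takes two Redis commands, which is what the LUA TODO is about: between the GET and the EXPIRE, another request could replace the key. A sketch of what that script might look like through the redis crate (a hypothetical helper, comparing the raw JSON payload instead of deserializing it):

```rust
use redis::aio::ConnectionManager;

// Hypothetical atomic version of patch_origin's check-then-expire:
// the script compares the stored payload and extends the TTL in one step.
async fn refresh_if_unchanged(
	redis: &mut ConnectionManager,
	key: &str,
	payload: &str,
) -> redis::RedisResult<bool> {
	let script = redis::Script::new(
		r#"
		if redis.call('GET', KEYS[1]) == ARGV[1] then
			return redis.call('EXPIRE', KEYS[1], ARGV[2])
		end
		return 0
		"#,
	);

	// EXPIRE returns 1 on success; 0 means the key is missing or changed.
	let refreshed: i64 = script.key(key).arg(payload).arg(600).invoke_async(redis).await?;
	Ok(refreshed == 1)
}
```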


@@ -37,6 +37,8 @@ mp4 = "0.13"
 anyhow = { version = "1", features = ["backtrace"] }
 serde_json = "1"
 rfc6381-codec = "0.1"
+tracing = "0.1"
+tracing-subscriber = "0.3"
 
 [build-dependencies]
 clap = { version = "4", features = ["derive"] }


@@ -15,9 +15,15 @@ use moq_transport::cache::broadcast;
 async fn main() -> anyhow::Result<()> {
 	env_logger::init();
 
+	// Disable tracing so we don't get a bunch of Quinn spam.
+	let tracer = tracing_subscriber::FmtSubscriber::builder()
+		.with_max_level(tracing::Level::WARN)
+		.finish();
+	tracing::subscriber::set_global_default(tracer).unwrap();
+
 	let config = Config::parse();
 
-	let (publisher, subscriber) = broadcast::new();
+	let (publisher, subscriber) = broadcast::new("");
 	let mut media = Media::new(&config, publisher).await?;
 
 	// Ugh, just let me use my native root certs already
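
Quinn emits very chatty spans through tracing, so the subscriber above caps everything at WARN while env_logger keeps serving the crate's own log output. If finer control were wanted, an EnvFilter would be the usual route; a sketch, assuming the tracing-subscriber env-filter feature (which this commit does not enable):

```rust
use tracing_subscriber::EnvFilter;

// Hypothetical alternative: honor RUST_LOG for tracing events too,
// while still pinning the noisy quinn target to WARN.
fn init_tracing() {
	let filter = EnvFilter::try_from_default_env()
		.unwrap_or_else(|_| EnvFilter::new("warn"))
		.add_directive("quinn=warn".parse().expect("valid directive"));

	tracing_subscriber::fmt().with_env_filter(filter).init();
}
```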


@@ -1,25 +1,31 @@
+use std::ops::{Deref, DerefMut};
 use std::{
-	collections::{hash_map, HashMap},
-	sync::{Arc, Mutex},
+	collections::HashMap,
+	sync::{Arc, Mutex, Weak},
 };
 
+use moq_api::ApiError;
 use moq_transport::cache::{broadcast, CacheError};
 use url::Url;
 
+use tokio::time;
+
 use crate::RelayError;
 
 #[derive(Clone)]
 pub struct Origin {
 	// An API client used to get/set broadcasts.
 	// If None then we never use a remote origin.
 	// TODO: Stub this out instead.
 	api: Option<moq_api::Client>,
 
 	// The internal address of our node.
 	// If None then we can never advertise ourselves as an origin.
 	// TODO: Stub this out instead.
 	node: Option<Url>,
 
-	// A map of active broadcasts.
-	lookup: Arc<Mutex<HashMap<String, broadcast::Subscriber>>>,
+	// A map of active broadcasts by ID.
+	cache: Arc<Mutex<HashMap<String, Weak<Subscriber>>>>,
 
 	// A QUIC endpoint we'll use to fetch from other origins.
 	quic: quinn::Endpoint,
@@ -30,48 +36,80 @@ impl Origin {
 		Self {
 			api,
 			node,
-			lookup: Default::default(),
+			cache: Default::default(),
 			quic,
 		}
 	}
 
-	pub async fn create_broadcast(&mut self, id: &str) -> Result<broadcast::Publisher, RelayError> {
-		let (publisher, subscriber) = broadcast::new();
+	/// Create a new broadcast with the given ID.
+	///
+	/// Publisher::run needs to be called to periodically refresh the origin cache.
+	pub async fn publish(&mut self, id: &str) -> Result<Publisher, RelayError> {
+		let (publisher, subscriber) = broadcast::new(id);
 
-		// Check if a broadcast already exists by that id.
-		match self.lookup.lock().unwrap().entry(id.to_string()) {
-			hash_map::Entry::Occupied(_) => return Err(CacheError::Duplicate.into()),
-			hash_map::Entry::Vacant(v) => v.insert(subscriber),
+		let subscriber = {
+			let mut cache = self.cache.lock().unwrap();
+
+			// Check if the broadcast already exists.
+			// TODO This is racy, because a new publisher could be created while existing subscribers are still active.
+			if cache.contains_key(id) {
+				return Err(CacheError::Duplicate.into());
+			}
+
+			// Create subscriber that will remove from the cache when dropped.
+			let subscriber = Arc::new(Subscriber {
+				broadcast: subscriber,
+				origin: self.clone(),
+			});
+
+			cache.insert(id.to_string(), Arc::downgrade(&subscriber));
+
+			subscriber
 		};
 
-		if let Some(ref mut api) = self.api {
+		// Create a publisher that constantly updates itself as the origin in moq-api.
+		// It holds a reference to the subscriber to prevent dropping early.
+		let mut publisher = Publisher {
+			broadcast: publisher,
+			subscriber,
+			api: None,
+		};
+
+		// Insert the publisher into the database.
+		if let Some(api) = self.api.as_mut() {
 			// Make a URL for the broadcast.
 			let url = self.node.as_ref().ok_or(RelayError::MissingNode)?.clone().join(id)?;
+			let origin = moq_api::Origin { url };
+			api.set_origin(id, &origin).await?;
 
-			log::info!("announcing origin: id={} url={}", id, url);
-			let entry = moq_api::Origin { url };
-			if let Err(err) = api.set_origin(id, entry).await {
-				self.lookup.lock().unwrap().remove(id);
-				return Err(err.into());
-			}
+			// Refresh every 5 minutes
+			publisher.api = Some((api.clone(), origin));
 		}
 
 		Ok(publisher)
 	}
 
-	pub fn get_broadcast(&self, id: &str) -> broadcast::Subscriber {
-		let mut lookup = self.lookup.lock().unwrap();
+	pub fn subscribe(&self, id: &str) -> Arc<Subscriber> {
+		let mut cache = self.cache.lock().unwrap();
 
-		if let Some(broadcast) = lookup.get(id) {
-			if broadcast.closed().is_none() {
-				return broadcast.clone();
+		if let Some(broadcast) = cache.get(id) {
+			if let Some(broadcast) = broadcast.upgrade() {
+				log::debug!("returned broadcast from cache: id={}", id);
+				return broadcast;
+			} else {
+				log::debug!("stale broadcast in cache somehow: id={}", id);
 			}
 		}
 
-		let (publisher, subscriber) = broadcast::new();
-		lookup.insert(id.to_string(), subscriber.clone());
+		let (publisher, subscriber) = broadcast::new(id);
+		let subscriber = Arc::new(Subscriber {
+			broadcast: subscriber,
+			origin: self.clone(),
+		});
+		cache.insert(id.to_string(), Arc::downgrade(&subscriber));
+
+		log::debug!("fetching into cache: id={}", id);
 
 		let mut this = self.clone();
 		let id = id.to_string();
@@ -82,63 +120,104 @@ impl Origin {
 		// However, the downside is that we don't return an error immediately.
 		// If that's important, it can be done but it gets a bit racy.
 		tokio::spawn(async move {
-			match this.fetch_broadcast(&id).await {
-				Ok(session) => {
-					if let Err(err) = this.run_broadcast(session, publisher).await {
-						log::warn!("failed to run broadcast: id={} err={:#?}", id, err);
-					}
-				}
-				Err(err) => {
-					log::warn!("failed to fetch broadcast: id={} err={:#?}", id, err);
-					publisher.close(CacheError::NotFound).ok();
-				}
+			if let Err(err) = this.serve(&id, publisher).await {
+				log::warn!("failed to serve remote broadcast: id={} err={}", id, err);
 			}
 		});
 
 		subscriber
 	}
 
-	async fn fetch_broadcast(&mut self, id: &str) -> Result<webtransport_quinn::Session, RelayError> {
+	async fn serve(&mut self, id: &str, publisher: broadcast::Publisher) -> Result<(), RelayError> {
+		log::debug!("finding origin: id={}", id);
+
 		// Fetch the origin from the API.
-		let api = match self.api {
-			Some(ref mut api) => api,
-
-			// We return NotFound here instead of earlier just to simulate an API fetch.
-			None => return Err(CacheError::NotFound.into()),
-		};
-
-		log::info!("fetching origin: id={}", id);
-		let origin = api.get_origin(id).await?.ok_or(CacheError::NotFound)?;
-
-		log::info!("connecting to origin: url={}", origin.url);
+		let origin = self
+			.api
+			.as_mut()
+			.ok_or(CacheError::NotFound)?
+			.get_origin(id)
+			.await?
+			.ok_or(CacheError::NotFound)?;
+
+		log::debug!("fetching from origin: id={} url={}", id, origin.url);
 
 		// Establish the webtransport session.
 		let session = webtransport_quinn::connect(&self.quic, &origin.url).await?;
-
-		Ok(session)
-	}
-
-	async fn run_broadcast(
-		&mut self,
-		session: webtransport_quinn::Session,
-		broadcast: broadcast::Publisher,
-	) -> Result<(), RelayError> {
-		let session = moq_transport::session::Client::subscriber(session, broadcast).await?;
+		let session = moq_transport::session::Client::subscriber(session, publisher).await?;
 
 		session.run().await?;
 
 		Ok(())
 	}
-
-	pub async fn remove_broadcast(&mut self, id: &str) -> Result<(), RelayError> {
-		self.lookup.lock().unwrap().remove(id).ok_or(CacheError::NotFound)?;
-
-		if let Some(ref mut api) = self.api {
-			log::info!("deleting origin: id={}", id);
-			api.delete_origin(id).await?;
-		}
-
-		Ok(())
-	}
 }
 
+pub struct Subscriber {
+	pub broadcast: broadcast::Subscriber,
+
+	origin: Origin,
+}
+
+impl Drop for Subscriber {
+	fn drop(&mut self) {
+		log::debug!("subscriber: removing from cache: id={}", self.id);
+		self.origin.cache.lock().unwrap().remove(&self.broadcast.id);
+	}
+}
+
+impl Deref for Subscriber {
+	type Target = broadcast::Subscriber;
+
+	fn deref(&self) -> &Self::Target {
+		&self.broadcast
+	}
+}
+
+pub struct Publisher {
+	pub broadcast: broadcast::Publisher,
+
+	api: Option<(moq_api::Client, moq_api::Origin)>,
+
+	#[allow(dead_code)]
+	subscriber: Arc<Subscriber>,
+}
+
+impl Publisher {
+	pub async fn run(&mut self) -> Result<(), ApiError> {
+		// Every 5m tell the API we're still alive.
+		// TODO don't hard-code these values
+		let mut interval = time::interval(time::Duration::from_secs(60 * 5));
+
+		loop {
+			if let Some((api, origin)) = self.api.as_mut() {
+				log::debug!("refreshing origin: id={}", self.broadcast.id);
+				api.patch_origin(&self.broadcast.id, origin).await?;
+			}
+
+			// TODO move to start of loop; this is just for testing
+			interval.tick().await;
+		}
+	}
+
+	pub async fn close(&mut self) -> Result<(), ApiError> {
+		if let Some((api, _)) = self.api.as_mut() {
+			api.delete_origin(&self.broadcast.id).await?;
+		}
+
+		Ok(())
+	}
+}
+
+impl Deref for Publisher {
+	type Target = broadcast::Publisher;
+
+	fn deref(&self) -> &Self::Target {
+		&self.broadcast
+	}
+}
+
+impl DerefMut for Publisher {
+	fn deref_mut(&mut self) -> &mut Self::Target {
+		&mut self.broadcast
+	}
+}
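
This cache change is the heart of the commit: the map now stores Weak handles, so the map itself never keeps a broadcast alive, and the Arc-wrapped Subscriber removes its own entry in Drop once the last strong reference (including the one held by Publisher) is gone. The same pattern in isolation, with hypothetical names:

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex, Weak};

// Drop-to-evict in miniature: the map stores Weak handles, and each strong
// handle erases its own entry when the last clone is dropped.
#[derive(Default)]
struct Cache {
	entries: Arc<Mutex<HashMap<String, Weak<Entry>>>>,
}

struct Entry {
	id: String,
	cache: Arc<Mutex<HashMap<String, Weak<Entry>>>>,
}

impl Drop for Entry {
	fn drop(&mut self) {
		// The last strong reference is gone; remove the stale Weak.
		self.cache.lock().unwrap().remove(&self.id);
	}
}

impl Cache {
	fn get_or_insert(&self, id: &str) -> Arc<Entry> {
		let mut entries = self.entries.lock().unwrap();

		// Reuse the live entry if any strong handle still exists.
		if let Some(existing) = entries.get(id).and_then(Weak::upgrade) {
			return existing;
		}

		let entry = Arc::new(Entry {
			id: id.to_string(),
			cache: self.entries.clone(),
		});
		entries.insert(id.to_string(), Arc::downgrade(&entry));
		entry
	}
}
```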


@@ -30,6 +30,7 @@ impl Quic {
 		transport_config.max_idle_timeout(Some(time::Duration::from_secs(10).try_into().unwrap()));
 		transport_config.keep_alive_interval(Some(time::Duration::from_secs(4))); // TODO make this smarter
 		transport_config.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
+		transport_config.mtu_discovery_config(None); // Disable MTU discovery
 		let transport_config = Arc::new(transport_config);
 
 		let mut client_config = quinn::ClientConfig::new(Arc::new(client_config));


@@ -1,6 +1,6 @@
 use anyhow::Context;
-use moq_transport::{cache::broadcast, session::Request, setup::Role, MoqError};
+use moq_transport::{session::Request, setup::Role, MoqError};
 
 use crate::Origin;
@@ -53,8 +53,16 @@ impl Session {
 		let role = request.role();
 
 		match role {
-			Role::Publisher => self.serve_publisher(id, request, &path).await,
-			Role::Subscriber => self.serve_subscriber(id, request, &path).await,
+			Role::Publisher => {
+				if let Err(err) = self.serve_publisher(id, request, &path).await {
+					log::warn!("error serving publisher: id={} path={} err={:#?}", id, path, err);
+				}
+			}
+			Role::Subscriber => {
+				if let Err(err) = self.serve_subscriber(id, request, &path).await {
+					log::warn!("error serving subscriber: id={} path={} err={:#?}", id, path, err);
+				}
+			}
 			Role::Both => {
 				log::warn!("role both not supported: id={}", id);
 				request.reject(300);
@@ -66,44 +74,38 @@ impl Session {
 		Ok(())
 	}
 
-	async fn serve_publisher(&mut self, id: usize, request: Request, path: &str) {
+	async fn serve_publisher(&mut self, id: usize, request: Request, path: &str) -> anyhow::Result<()> {
 		log::info!("serving publisher: id={}, path={}", id, path);
 
-		let broadcast = match self.origin.create_broadcast(path).await {
-			Ok(broadcast) => broadcast,
+		let mut origin = match self.origin.publish(path).await {
+			Ok(origin) => origin,
 			Err(err) => {
-				log::warn!("error accepting publisher: id={} path={} err={:#?}", id, path, err);
-				return request.reject(err.code());
+				request.reject(err.code());
+				return Err(err.into());
 			}
 		};
 
-		if let Err(err) = self.run_publisher(request, broadcast).await {
-			log::warn!("error serving publisher: id={} path={} err={:#?}", id, path, err);
-		}
+		let session = request.subscriber(origin.broadcast.clone()).await?;
 
-		// TODO can we do this on drop? Otherwise we might miss it.
-		self.origin.remove_broadcast(path).await.ok();
-	}
+		tokio::select! {
+			_ = session.run() => origin.close().await?,
+			_ = origin.run() => (), // TODO send error to session
+		};
 
-	async fn run_publisher(&mut self, request: Request, publisher: broadcast::Publisher) -> anyhow::Result<()> {
-		let session = request.subscriber(publisher).await?;
-		session.run().await?;
 		Ok(())
 	}
 
-	async fn serve_subscriber(&mut self, id: usize, request: Request, path: &str) {
+	async fn serve_subscriber(&mut self, id: usize, request: Request, path: &str) -> anyhow::Result<()> {
 		log::info!("serving subscriber: id={} path={}", id, path);
 
-		let broadcast = self.origin.get_broadcast(path);
+		let subscriber = self.origin.subscribe(path);
 
-		if let Err(err) = self.run_subscriber(request, broadcast).await {
-			log::warn!("error serving subscriber: id={} path={} err={:#?}", id, path, err);
-		}
-	}
-
-	async fn run_subscriber(&mut self, request: Request, broadcast: broadcast::Subscriber) -> anyhow::Result<()> {
-		let session = request.publisher(broadcast).await?;
+		let session = request.publisher(subscriber.broadcast.clone()).await?;
 		session.run().await?;
 
+		// Make sure this doesn't get dropped too early
+		drop(subscriber);
+
 		Ok(())
 	}
 }


@@ -13,21 +13,29 @@
 use std::{
 	collections::{hash_map, HashMap, VecDeque},
 	fmt,
+	ops::Deref,
 	sync::Arc,
 };
 
 use super::{track, CacheError, Watch};
 
 /// Create a new broadcast.
-pub fn new() -> (Publisher, Subscriber) {
+pub fn new(id: &str) -> (Publisher, Subscriber) {
 	let state = Watch::new(State::default());
+	let info = Arc::new(Info { id: id.to_string() });
 
-	let publisher = Publisher::new(state.clone());
-	let subscriber = Subscriber::new(state);
+	let publisher = Publisher::new(state.clone(), info.clone());
+	let subscriber = Subscriber::new(state, info);
 
 	(publisher, subscriber)
 }
 
+/// Static information about a broadcast.
+#[derive(Debug)]
+pub struct Info {
+	pub id: String,
+}
+
 /// Dynamic information about the broadcast.
 #[derive(Debug)]
 struct State {
@@ -105,13 +113,14 @@ impl Default for State {
 #[derive(Clone)]
 pub struct Publisher {
 	state: Watch<State>,
+	info: Arc<Info>,
 
 	_dropped: Arc<Dropped>,
 }
 
 impl Publisher {
-	fn new(state: Watch<State>) -> Self {
+	fn new(state: Watch<State>, info: Arc<Info>) -> Self {
 		let _dropped = Arc::new(Dropped::new(state.clone()));
-		Self { state, _dropped }
+		Self { state, info, _dropped }
 	}
 
 	/// Create a new track with the given name, inserting it into the broadcast.
@@ -148,9 +157,20 @@ impl Publisher {
 	}
 }
 
+impl Deref for Publisher {
+	type Target = Info;
+
+	fn deref(&self) -> &Self::Target {
+		&self.info
+	}
+}
+
 impl fmt::Debug for Publisher {
 	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-		f.debug_struct("Publisher").field("state", &self.state).finish()
+		f.debug_struct("Publisher")
+			.field("state", &self.state)
+			.field("info", &self.info)
+			.finish()
 	}
 }
@@ -160,13 +180,14 @@ impl fmt::Debug for Publisher {
 #[derive(Clone)]
 pub struct Subscriber {
 	state: Watch<State>,
+	info: Arc<Info>,
 
 	_dropped: Arc<Dropped>,
 }
 
 impl Subscriber {
-	fn new(state: Watch<State>) -> Self {
+	fn new(state: Watch<State>, info: Arc<Info>) -> Self {
 		let _dropped = Arc::new(Dropped::new(state.clone()));
-		Self { state, _dropped }
+		Self { state, info, _dropped }
 	}
 
 	/// Get a track from the broadcast by name.
@@ -182,15 +203,42 @@ impl Subscriber {
 		state.into_mut().request(name)
 	}
 
-	/// Return if the broadcast is closed, either because the publisher was dropped or called [Publisher::close].
-	pub fn closed(&self) -> Option<CacheError> {
+	/// Check if the broadcast is closed, either because the publisher was dropped or called [Publisher::close].
+	pub fn is_closed(&self) -> Option<CacheError> {
 		self.state.lock().closed.as_ref().err().cloned()
 	}
+
+	/// Wait until the broadcast is closed, either because the publisher was dropped or called [Publisher::close].
+	pub async fn closed(&self) -> CacheError {
+		loop {
+			let notify = {
+				let state = self.state.lock();
+				if let Some(err) = state.closed.as_ref().err() {
+					return err.clone();
+				}
+
+				state.changed()
+			};
+
+			notify.await;
+		}
+	}
 }
 
+impl Deref for Subscriber {
+	type Target = Info;
+
+	fn deref(&self) -> &Self::Target {
+		&self.info
+	}
+}
+
 impl fmt::Debug for Subscriber {
 	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-		f.debug_struct("Subscriber").field("state", &self.state).finish()
+		f.debug_struct("Subscriber")
+			.field("state", &self.state)
+			.field("info", &self.info)
+			.finish()
 	}
 }
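
closed() replaces the old non-blocking probe (now is_closed) with a future that parks on the Watch until the broadcast actually ends, which is what lets the session code below select on "no more broadcasts". The same shape rebuilt on tokio's watch channel, as an analogy rather than the crate's own Watch type:

```rust
use tokio::sync::watch;

// The closed() idiom on a watch channel: inspect the current state, and if
// it isn't closed yet, await a change notification and check again.
async fn closed(rx: &mut watch::Receiver<Option<String>>) -> String {
	loop {
		if let Some(reason) = rx.borrow().clone() {
			return reason;
		}
		// Wakes when the sender publishes a new value (or is dropped).
		if rx.changed().await.is_err() {
			return "publisher dropped".to_string();
		}
	}
}

#[tokio::main]
async fn main() {
	let (tx, mut rx) = watch::channel(None);
	tokio::spawn(async move {
		tx.send(Some("closed: end of broadcast".to_string())).ok();
	});
	println!("{}", closed(&mut rx).await);
}
```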


@@ -64,7 +64,7 @@ impl Publisher {
 				stream?;
 				return Err(SessionError::RoleViolation(VarInt::ZERO));
 			}
 
-			// NOTE: this is not cancel safe, but it's fine since the other branch is a fatal error.
+			// NOTE: this is not cancel safe, but it's fine since the other branches are fatal.
 			msg = self.control.recv() => {
 				let msg = msg?;
@@ -72,7 +72,12 @@
 				if let Err(err) = self.recv_message(&msg).await {
 					log::warn!("message error: {:?} {:?}", err, msg);
 				}
-			}
+			},
+
+			// No more broadcasts are available.
+			err = self.source.closed() => {
+				self.webtransport.close(err.code(), err.reason().as_bytes());
+				return Ok(());
+			},
 		}
 	}
 }
@@ -178,7 +183,7 @@ impl Publisher {
 			expires: segment.expires,
 		};
 
-		log::debug!("serving object: {:?}", object);
+		log::trace!("serving object: {:?}", object);
 
 		let mut stream = self.webtransport.open_uni().await?;
 		stream.set_priority(object.priority).ok();


@@ -135,6 +135,7 @@ impl Subscriber {
 	}
 
 	async fn run_source(mut self) -> Result<(), SessionError> {
+		// NOTE: This returns Closed when the source is closed.
 		while let Some(track) = self.source.next_track().await? {
 			let name = track.name.clone();