214 lines
5.8 KiB
Rust
214 lines
5.8 KiB
Rust
use std::marker::PhantomData;
|
|
|
|
use anyhow::Context;
|
|
|
|
use bytes::Buf;
|
|
use moq_generic_transport::{SendStream, SendStreamUnframed, BidiStream, Connection};
|
|
use tokio::io::AsyncWriteExt;
|
|
use tokio::task::JoinSet; // allows locking across await
|
|
|
|
use moq_transport::{Announce, AnnounceError, AnnounceOk, Object, Subscribe, SubscribeError, SubscribeOk, VarInt};
|
|
use moq_transport_trait::SendObjects;
|
|
|
|
use super::{broker, control};
|
|
use crate::model::{segment, track};
|
|
|
|
pub struct Session<S: SendStream + SendStreamUnframed + Send, C: Connection + Send> {
|
|
// Objects are sent to the client
|
|
objects: SendObjects<C>,
|
|
|
|
// Used to send and receive control messages.
|
|
control: control::Component<control::Distribute>,
|
|
|
|
// Globally announced namespaces, which can be subscribed to.
|
|
broker: broker::Broadcasts,
|
|
|
|
// A list of tasks that are currently running.
|
|
run_subscribes: JoinSet<SubscribeError>, // run subscriptions, sending the returned error if they fail
|
|
|
|
_marker: PhantomData<S>,
|
|
}
|
|
|
|
impl<S, C> Session<S, C> where
|
|
S: SendStream + SendStreamUnframed + Send,
|
|
C: Connection<SendStream = S> + Send + 'static {
|
|
pub fn new(
|
|
objects: SendObjects<C>,
|
|
control: control::Component<control::Distribute>,
|
|
broker: broker::Broadcasts,
|
|
) -> Self {
|
|
Self {
|
|
objects,
|
|
control,
|
|
broker,
|
|
run_subscribes: JoinSet::new(),
|
|
_marker: PhantomData,
|
|
}
|
|
}
|
|
|
|
pub async fn run(mut self) -> anyhow::Result<()> {
|
|
// Announce all available tracks and get a stream of updates.
|
|
let (available, mut updates) = self.broker.available();
|
|
for namespace in available {
|
|
self.on_available(broker::Update::Insert(namespace)).await?;
|
|
}
|
|
|
|
loop {
|
|
tokio::select! {
|
|
res = self.run_subscribes.join_next(), if !self.run_subscribes.is_empty() => {
|
|
let res = res.expect("no tasks").expect("task aborted");
|
|
self.control.send(res).await?;
|
|
},
|
|
delta = updates.next() => {
|
|
let delta = delta.expect("no more broadcasts");
|
|
self.on_available(delta).await?;
|
|
},
|
|
msg = self.control.recv() => {
|
|
let msg = msg.context("failed to receive control message")?;
|
|
self.receive_message(msg).await?;
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn receive_message(&mut self, msg: control::Distribute) -> anyhow::Result<()> {
|
|
match msg {
|
|
control::Distribute::AnnounceOk(msg) => self.receive_announce_ok(msg),
|
|
control::Distribute::AnnounceError(msg) => self.receive_announce_error(msg),
|
|
control::Distribute::Subscribe(msg) => self.receive_subscribe(msg).await,
|
|
}
|
|
}
|
|
|
|
fn receive_announce_ok(&mut self, _msg: AnnounceOk) -> anyhow::Result<()> {
|
|
// TODO make sure we sent this announce
|
|
Ok(())
|
|
}
|
|
|
|
fn receive_announce_error(&mut self, msg: AnnounceError) -> anyhow::Result<()> {
|
|
// TODO make sure we sent this announce
|
|
// TODO remove this from the list of subscribable broadcasts.
|
|
anyhow::bail!("received ANNOUNCE_ERROR({:?}): {}", msg.code, msg.reason)
|
|
}
|
|
|
|
async fn receive_subscribe(&mut self, msg: Subscribe) -> anyhow::Result<()> {
|
|
match self.receive_subscribe_inner(&msg).await {
|
|
Ok(()) => {
|
|
self.control
|
|
.send(SubscribeOk {
|
|
track_id: msg.track_id,
|
|
expires: None,
|
|
})
|
|
.await
|
|
}
|
|
Err(e) => {
|
|
self.control
|
|
.send(SubscribeError {
|
|
track_id: msg.track_id,
|
|
code: VarInt::from_u32(1),
|
|
reason: e.to_string(),
|
|
})
|
|
.await
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn receive_subscribe_inner(&mut self, msg: &Subscribe) -> anyhow::Result<()> {
|
|
let track = self
|
|
.broker
|
|
.subscribe(&msg.track_namespace, &msg.track_name)
|
|
.context("could not find broadcast")?;
|
|
|
|
// TODO can we just clone self?
|
|
let objects = self.objects.clone();
|
|
let track_id = msg.track_id;
|
|
|
|
self.run_subscribes
|
|
.spawn(async move { Self::run_subscribe(objects, track_id, track).await });
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn run_subscribe(objects: SendObjects<C>, track_id: VarInt, mut track: track::Subscriber) -> SubscribeError {
|
|
let mut tasks = JoinSet::new();
|
|
let mut result = None;
|
|
|
|
loop {
|
|
tokio::select! {
|
|
// Accept new segments added to the track.
|
|
segment = track.next_segment(), if result.is_none() => {
|
|
match segment {
|
|
Ok(segment) => {
|
|
let objects = objects.clone();
|
|
tasks.spawn(async move { Self::serve_group(objects, track_id, segment).await });
|
|
},
|
|
Err(e) => {
|
|
result = Some(SubscribeError {
|
|
track_id,
|
|
code: e.code,
|
|
reason: e.reason,
|
|
})
|
|
},
|
|
}
|
|
},
|
|
// Poll any pending segments until they exit.
|
|
res = tasks.join_next(), if !tasks.is_empty() => {
|
|
let res = res.expect("no tasks").expect("task aborted");
|
|
if let Err(err) = res {
|
|
log::error!("failed to serve segment: {:?}", err);
|
|
}
|
|
},
|
|
else => return result.unwrap()
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn serve_group(
|
|
mut objects: SendObjects<C>,
|
|
track_id: VarInt,
|
|
mut segment: segment::Subscriber,
|
|
) -> anyhow::Result<()> {
|
|
let object = Object {
|
|
track: track_id,
|
|
group: segment.sequence,
|
|
sequence: VarInt::from_u32(0), // Always zero since we send an entire group as an object
|
|
send_order: segment.send_order,
|
|
};
|
|
|
|
let mut stream = objects.send(object).await?;
|
|
|
|
// Write each fragment as they are available.
|
|
while let Some(fragment) = segment.fragments.next().await {
|
|
let mut buf = bytes::Bytes::copy_from_slice(fragment.as_slice());
|
|
while buf.has_remaining() {
|
|
moq_generic_transport::send(&mut stream, &mut buf).await?;
|
|
}
|
|
// stream.write_all(fragment.as_slice()).await?;
|
|
}
|
|
|
|
// NOTE: stream is automatically closed when dropped
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn on_available(&mut self, delta: broker::Update) -> anyhow::Result<()> {
|
|
match delta {
|
|
broker::Update::Insert(name) => {
|
|
self.control
|
|
.send(Announce {
|
|
track_namespace: name.clone(),
|
|
})
|
|
.await
|
|
}
|
|
broker::Update::Remove(name, error) => {
|
|
self.control
|
|
.send(AnnounceError {
|
|
track_namespace: name,
|
|
code: error.code,
|
|
reason: error.reason,
|
|
})
|
|
.await
|
|
}
|
|
}
|
|
}
|
|
}
|