Varint as an explicit type (#35)

This commit is contained in:
kixelated 2023-06-16 19:52:52 -07:00 committed by GitHub
parent d7872ef77d
commit 4c04fbf2b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 177 additions and 130 deletions

View File

@ -20,11 +20,11 @@ impl Decode for Bytes {
#[async_trait]
impl Decode for Vec<u8> {
async fn decode<R: AsyncRead + Unpin + Send>(r: &mut R) -> anyhow::Result<Self> {
let size = u64::decode(r).await?;
let size = VarInt::decode(r).await?;
// NOTE: we don't use with_capacity since size is from an untrusted source
let mut buf = Vec::new();
r.take(size).read_to_end(&mut buf).await?;
r.take(size.into()).read_to_end(&mut buf).await?;
Ok(buf)
}
@ -38,17 +38,3 @@ impl Decode for String {
Ok(s)
}
}
#[async_trait]
impl Decode for u64 {
async fn decode<R: AsyncRead + Unpin + Send>(r: &mut R) -> anyhow::Result<Self> {
VarInt::decode(r).await.map(Into::into)
}
}
#[async_trait]
impl Decode for usize {
async fn decode<R: AsyncRead + Unpin + Send>(r: &mut R) -> anyhow::Result<Self> {
VarInt::decode(r).await.map(Into::into)
}
}

View File

@ -1,5 +1,6 @@
use super::{Decode, Encode};
use crate::coding::VarInt;
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite};
@ -9,7 +10,7 @@ use std::time::Duration;
impl Encode for Duration {
async fn encode<W: AsyncWrite + Unpin + Send>(&self, w: &mut W) -> anyhow::Result<()> {
let ms = self.as_millis();
let ms = u64::try_from(ms)?;
let ms = VarInt::try_from(ms)?;
ms.encode(w).await
}
}
@ -17,8 +18,7 @@ impl Encode for Duration {
#[async_trait]
impl Decode for Duration {
async fn decode<R: AsyncRead + Unpin + Send>(r: &mut R) -> anyhow::Result<Self> {
let ms = u64::decode(r).await?;
let ms = ms;
Ok(Self::from_millis(ms))
let ms = VarInt::decode(r).await?;
Ok(Self::from_millis(ms.into()))
}
}

View File

@ -12,25 +12,22 @@ pub trait Encode: Sized {
#[async_trait]
impl Encode for Bytes {
async fn encode<W: AsyncWrite + Unpin + Send>(&self, w: &mut W) -> anyhow::Result<()> {
self.len().encode(w).await?;
w.write_all(self).await?;
Ok(())
self.as_ref().encode(w).await
}
}
#[async_trait]
impl Encode for Vec<u8> {
async fn encode<W: AsyncWrite + Unpin + Send>(&self, w: &mut W) -> anyhow::Result<()> {
self.len().encode(w).await?;
w.write_all(self).await?;
Ok(())
self.as_slice().encode(w).await
}
}
#[async_trait]
impl Encode for &[u8] {
async fn encode<W: AsyncWrite + Unpin + Send>(&self, w: &mut W) -> anyhow::Result<()> {
self.len().encode(w).await?;
let size: VarInt = self.len().try_into()?;
size.encode(w).await?;
w.write_all(self).await?;
Ok(())
}
@ -42,17 +39,3 @@ impl Encode for String {
self.as_bytes().encode(w).await
}
}
#[async_trait]
impl Encode for u64 {
async fn encode<W: AsyncWrite + Unpin + Send>(&self, w: &mut W) -> anyhow::Result<()> {
VarInt::try_from(*self)?.encode(w).await
}
}
#[async_trait]
impl Encode for usize {
async fn encode<W: AsyncWrite + Unpin + Send>(&self, w: &mut W) -> anyhow::Result<()> {
VarInt::try_from(*self)?.encode(w).await
}
}

View File

@ -21,7 +21,22 @@ pub struct BoundsExceeded;
// It would be neat if we could express to Rust that the top two bits are available for use as enum
// discriminants
#[derive(Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub(crate) struct VarInt(u64);
pub struct VarInt(u64);
impl VarInt {
pub const MAX: Self = Self((1 << 62) - 1);
/// Construct a `VarInt` infallibly using the largest available type.
/// Larger values need to use `try_from` instead.
pub const fn from_u32(x: u32) -> Self {
Self(x as u64)
}
/// Extract the integer value
pub const fn into_inner(self) -> u64 {
self.0
}
}
impl From<VarInt> for u64 {
fn from(x: VarInt) -> Self {
@ -35,6 +50,12 @@ impl From<VarInt> for usize {
}
}
impl From<VarInt> for u128 {
fn from(x: VarInt) -> Self {
x.0 as u128
}
}
impl From<u8> for VarInt {
fn from(x: u8) -> Self {
Self(x.into())
@ -58,7 +79,7 @@ impl TryFrom<u64> for VarInt {
/// Succeeds iff `x` < 2^62
fn try_from(x: u64) -> Result<Self, BoundsExceeded> {
if x < 2u64.pow(62) {
if x <= Self::MAX.into_inner() {
Ok(Self(x))
} else {
Err(BoundsExceeded)
@ -66,6 +87,19 @@ impl TryFrom<u64> for VarInt {
}
}
impl TryFrom<u128> for VarInt {
type Error = BoundsExceeded;
/// Succeeds iff `x` < 2^62
fn try_from(x: u128) -> Result<Self, BoundsExceeded> {
if x <= Self::MAX.into() {
Ok(Self(x as u64))
} else {
Err(BoundsExceeded)
}
}
}
impl TryFrom<usize> for VarInt {
type Error = BoundsExceeded;
/// Succeeds iff `x` < 2^62

View File

@ -1,4 +1,4 @@
use crate::coding::{Decode, Encode};
use crate::coding::{Decode, Encode, VarInt};
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite};
@ -10,7 +10,7 @@ pub struct AnnounceError {
pub track_namespace: String,
// An error code.
pub code: u64,
pub code: VarInt,
// An optional, human-readable reason.
pub reason: String,
@ -20,7 +20,7 @@ pub struct AnnounceError {
impl Decode for AnnounceError {
async fn decode<R: AsyncRead + Unpin + Send>(r: &mut R) -> anyhow::Result<Self> {
let track_namespace = String::decode(r).await?;
let code = u64::decode(r).await?;
let code = VarInt::decode(r).await?;
let reason = String::decode(r).await?;
Ok(Self {

View File

@ -16,7 +16,7 @@ pub use subscribe::*;
pub use subscribe_error::*;
pub use subscribe_ok::*;
use crate::coding::{Decode, Encode};
use crate::coding::{Decode, Encode, VarInt};
use async_trait::async_trait;
use std::fmt;
@ -35,9 +35,9 @@ macro_rules! message_types {
#[async_trait]
impl Decode for Message {
async fn decode<R: AsyncRead + Unpin + Send>(r: &mut R) -> anyhow::Result<Self> {
let t = u64::decode(r).await.context("failed to decode type")?;
let t = VarInt::decode(r).await.context("failed to decode type")?;
Ok(match u64::from(t) {
Ok(match t.into_inner() {
$($val => {
let msg = $name::decode(r).await.context(concat!("failed to decode ", stringify!($name)))?;
Self::$name(msg)
@ -52,8 +52,7 @@ macro_rules! message_types {
async fn encode<W: AsyncWrite + Unpin + Send>(&self, w: &mut W) -> anyhow::Result<()> {
match self {
$(Self::$name(ref m) => {
let id: u64 = $val; // tell the compiler this is a u64
id.encode(w).await.context("failed to encode type")?;
VarInt::from_u32($val).encode(w).await.context("failed to encode type")?;
m.encode(w).await.context("failed to encode message")
},)*
}

View File

@ -1,4 +1,4 @@
use crate::coding::{Decode, Encode};
use crate::coding::{Decode, Encode, VarInt};
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite};
@ -7,7 +7,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
pub struct Subscribe {
// An ID we choose so we can map to the track_name.
// Proposal: https://github.com/moq-wg/moq-transport/issues/209
pub track_id: u64,
pub track_id: VarInt,
// The track namespace.
pub track_namespace: String,
@ -19,7 +19,7 @@ pub struct Subscribe {
#[async_trait]
impl Decode for Subscribe {
async fn decode<R: AsyncRead + Unpin + Send>(r: &mut R) -> anyhow::Result<Self> {
let track_id = u64::decode(r).await?;
let track_id = VarInt::decode(r).await?;
let track_namespace = String::decode(r).await?;
let track_name = String::decode(r).await?;

View File

@ -1,4 +1,4 @@
use crate::coding::{Decode, Encode};
use crate::coding::{Decode, Encode, VarInt};
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite};
@ -8,10 +8,10 @@ pub struct SubscribeError {
// NOTE: No full track name because of this proposal: https://github.com/moq-wg/moq-transport/issues/209
// The ID for this track.
pub track_id: u64,
pub track_id: VarInt,
// An error code.
pub code: u64,
pub code: VarInt,
// An optional, human-readable reason.
pub reason: String,
@ -20,8 +20,8 @@ pub struct SubscribeError {
#[async_trait]
impl Decode for SubscribeError {
async fn decode<R: AsyncRead + Unpin + Send>(r: &mut R) -> anyhow::Result<Self> {
let track_id = u64::decode(r).await?;
let code = u64::decode(r).await?;
let track_id = VarInt::decode(r).await?;
let code = VarInt::decode(r).await?;
let reason = String::decode(r).await?;
Ok(Self { track_id, code, reason })

View File

@ -1,4 +1,4 @@
use crate::coding::{Decode, Encode};
use crate::coding::{Decode, Encode, VarInt};
use std::time::Duration;
@ -10,7 +10,7 @@ pub struct SubscribeOk {
// NOTE: No full track name because of this proposal: https://github.com/moq-wg/moq-transport/issues/209
// The ID for this track.
pub track_id: u64,
pub track_id: VarInt,
// The subscription will end after this duration has elapsed.
// A value of zero is invalid.
@ -20,7 +20,7 @@ pub struct SubscribeOk {
#[async_trait]
impl Decode for SubscribeOk {
async fn decode<R: AsyncRead + Unpin + Send>(r: &mut R) -> anyhow::Result<Self> {
let track_id = u64::decode(r).await?;
let track_id = VarInt::decode(r).await?;
let expires = Duration::decode(r).await?;
let expires = if expires == Duration::ZERO { None } else { Some(expires) };

View File

@ -1,4 +1,4 @@
use crate::coding::{Decode, Encode};
use crate::coding::{Decode, Encode, VarInt};
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite};
@ -8,30 +8,30 @@ use tokio::io::{AsyncRead, AsyncWrite};
pub struct Header {
// An ID for this track.
// Proposal: https://github.com/moq-wg/moq-transport/issues/209
pub track_id: u64,
pub track_id: VarInt,
// The group sequence number.
pub group_sequence: u64,
pub group_sequence: VarInt,
// The object sequence number.
pub object_sequence: u64,
pub object_sequence: VarInt,
// The priority/send order.
pub send_order: u64,
pub send_order: VarInt,
}
#[async_trait]
impl Decode for Header {
async fn decode<R: AsyncRead + Unpin + Send>(r: &mut R) -> anyhow::Result<Self> {
let typ = u64::decode(r).await?;
anyhow::ensure!(typ == 0, "OBJECT type must be 0");
let typ = VarInt::decode(r).await?;
anyhow::ensure!(u64::from(typ) == 0, "OBJECT type must be 0");
// NOTE: size has been omitted
let track_id = u64::decode(r).await?;
let group_sequence = u64::decode(r).await?;
let object_sequence = u64::decode(r).await?;
let send_order = u64::decode(r).await?;
let track_id = VarInt::decode(r).await?;
let group_sequence = VarInt::decode(r).await?;
let object_sequence = VarInt::decode(r).await?;
let send_order = VarInt::decode(r).await?;
Ok(Self {
track_id,
@ -45,7 +45,7 @@ impl Decode for Header {
#[async_trait]
impl Encode for Header {
async fn encode<W: AsyncWrite + Unpin + Send>(&self, w: &mut W) -> anyhow::Result<()> {
0u64.encode(w).await?;
VarInt::from_u32(0).encode(w).await?;
self.track_id.encode(w).await?;
self.group_sequence.encode(w).await?;
self.object_sequence.encode(w).await?;

View File

@ -1,5 +1,5 @@
use super::{Role, Versions};
use crate::coding::{Decode, Encode};
use crate::coding::{Decode, Encode, VarInt};
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite};
@ -26,8 +26,8 @@ pub struct Client {
#[async_trait]
impl Decode for Client {
async fn decode<R: AsyncRead + Unpin + Send>(r: &mut R) -> anyhow::Result<Self> {
let typ = u64::decode(r).await.context("failed to read type")?;
anyhow::ensure!(typ == 1, "client SETUP must be type 1");
let typ = VarInt::decode(r).await.context("failed to read type")?;
anyhow::ensure!(typ.into_inner() == 1, "client SETUP must be type 1");
let versions = Versions::decode(r).await.context("failed to read supported versions")?;
anyhow::ensure!(!versions.is_empty(), "client must support at least one version");
@ -42,7 +42,7 @@ impl Decode for Client {
#[async_trait]
impl Encode for Client {
async fn encode<W: AsyncWrite + Unpin + Send>(&self, w: &mut W) -> anyhow::Result<()> {
1u64.encode(w).await?;
VarInt::from_u32(1).encode(w).await?;
anyhow::ensure!(!self.versions.is_empty(), "client must support at least one version");
self.versions.encode(w).await?;

View File

@ -1,7 +1,7 @@
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::coding::{Decode, Encode};
use crate::coding::{Decode, Encode, VarInt};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
@ -26,21 +26,21 @@ impl Role {
}
}
impl From<Role> for u64 {
impl From<Role> for VarInt {
fn from(r: Role) -> Self {
match r {
VarInt::from_u32(match r {
Role::Publisher => 0x0,
Role::Subscriber => 0x1,
Role::Both => 0x2,
}
})
}
}
impl TryFrom<u64> for Role {
impl TryFrom<VarInt> for Role {
type Error = anyhow::Error;
fn try_from(v: u64) -> Result<Self, Self::Error> {
Ok(match v {
fn try_from(v: VarInt) -> Result<Self, Self::Error> {
Ok(match v.into_inner() {
0x0 => Self::Publisher,
0x1 => Self::Subscriber,
0x2 => Self::Both,
@ -52,7 +52,7 @@ impl TryFrom<u64> for Role {
#[async_trait]
impl Decode for Role {
async fn decode<R: AsyncRead + Unpin + Send>(r: &mut R) -> anyhow::Result<Self> {
let v = u64::decode(r).await?;
let v = VarInt::decode(r).await?;
v.try_into()
}
}
@ -60,6 +60,6 @@ impl Decode for Role {
#[async_trait]
impl Encode for Role {
async fn encode<W: AsyncWrite + Unpin + Send>(&self, w: &mut W) -> anyhow::Result<()> {
u64::from(*self).encode(w).await
VarInt::from(*self).encode(w).await
}
}

View File

@ -1,5 +1,5 @@
use super::{Role, Version};
use crate::coding::{Decode, Encode};
use crate::coding::{Decode, Encode, VarInt};
use anyhow::Context;
use async_trait::async_trait;
@ -21,8 +21,8 @@ pub struct Server {
#[async_trait]
impl Decode for Server {
async fn decode<R: AsyncRead + Unpin + Send>(r: &mut R) -> anyhow::Result<Self> {
let typ = u64::decode(r).await.context("failed to read type")?;
anyhow::ensure!(typ == 2, "server SETUP must be type 2");
let typ = VarInt::decode(r).await.context("failed to read type")?;
anyhow::ensure!(typ.into_inner() == 2, "server SETUP must be type 2");
let version = Version::decode(r).await.context("failed to read version")?;
let role = Role::decode(r).await.context("failed to read role")?;
@ -34,7 +34,7 @@ impl Decode for Server {
#[async_trait]
impl Encode for Server {
async fn encode<W: AsyncWrite + Unpin + Send>(&self, w: &mut W) -> anyhow::Result<()> {
2u64.encode(w).await?; // setup type
VarInt::from_u32(2).encode(w).await?; // setup type
self.version.encode(w).await?;
self.role.encode(w).await?;

View File

@ -1,4 +1,4 @@
use crate::coding::{Decode, Encode};
use crate::coding::{Decode, Encode, VarInt};
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite};
@ -6,19 +6,19 @@ use tokio::io::{AsyncRead, AsyncWrite};
use std::ops::Deref;
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Version(pub u64);
pub struct Version(pub VarInt);
impl Version {
pub const DRAFT_00: Version = Version(0xff00);
pub const DRAFT_00: Version = Version(VarInt::from_u32(0xff00));
}
impl From<u64> for Version {
fn from(v: u64) -> Self {
impl From<VarInt> for Version {
fn from(v: VarInt) -> Self {
Self(v)
}
}
impl From<Version> for u64 {
impl From<Version> for VarInt {
fn from(v: Version) -> Self {
v.0
}
@ -27,7 +27,7 @@ impl From<Version> for u64 {
#[async_trait]
impl Decode for Version {
async fn decode<R: AsyncRead + Unpin + Send>(r: &mut R) -> anyhow::Result<Self> {
let v = u64::decode(r).await?;
let v = VarInt::decode(r).await?;
Ok(Self(v))
}
}
@ -45,7 +45,7 @@ pub struct Versions(pub Vec<Version>);
#[async_trait]
impl Decode for Versions {
async fn decode<R: AsyncRead + Unpin + Send>(r: &mut R) -> anyhow::Result<Self> {
let count = u64::decode(r).await?;
let count = VarInt::decode(r).await?.into_inner();
let mut vs = Vec::new();
for _ in 0..count {
@ -60,7 +60,8 @@ impl Decode for Versions {
#[async_trait]
impl Encode for Versions {
async fn encode<W: AsyncWrite + Unpin + Send>(&self, w: &mut W) -> anyhow::Result<()> {
self.0.len().encode(w).await?;
let size: VarInt = self.0.len().try_into()?;
size.encode(w).await?;
for v in &self.0 {
v.encode(w).await?;
}

View File

@ -7,6 +7,7 @@ use tokio::task::JoinSet;
use std::sync::Arc;
use moq_transport::coding::VarInt;
use moq_transport::{control, data, server, setup};
pub struct Session {
@ -82,7 +83,7 @@ impl Session {
res = self.tasks.join_next(), if !self.tasks.is_empty() => {
let res = res.expect("no tasks").expect("task aborted");
if let Err(err) = res {
log::warn!("failed to serve subscription: {:?}", err);
log::error!("failed to serve subscription: {:?}", err);
}
}
}
@ -124,7 +125,7 @@ impl Session {
Err(e) => {
self.send_message(control::SubscribeError {
track_id: sub.track_id,
code: 1,
code: VarInt::from_u32(1),
reason: e.to_string(),
})
.await
@ -144,9 +145,11 @@ impl Session {
.context("unknown track name")?
.clone();
let track_id = sub.track_id;
let sub = Subscription {
track,
track_id: sub.track_id,
track_id,
transport: self.transport.clone(),
};
@ -158,7 +161,7 @@ impl Session {
pub struct Subscription {
transport: Arc<data::Transport>,
track_id: u64,
track_id: VarInt,
track: media::Track,
}
@ -197,18 +200,17 @@ impl Subscription {
struct Group {
transport: Arc<data::Transport>,
track_id: u64,
track_id: VarInt,
segment: media::Segment,
}
impl Group {
pub async fn serve(mut self) -> anyhow::Result<()> {
// TODO proper values
let header = moq_transport::data::Header {
track_id: self.track_id,
group_sequence: 0, // TODO
object_sequence: 0, // Always zero since we send an entire group as an object
send_order: 0, // TODO
group_sequence: self.segment.sequence,
object_sequence: VarInt::from_u32(0), // Always zero since we send an entire group as an object
send_order: self.segment.send_order,
};
let mut stream = self.transport.send(header).await?;

View File

@ -1,7 +1,9 @@
use super::Subscriber;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use moq_transport::coding::VarInt;
// Map from track namespace to broadcast.
// TODO support updates
@ -21,8 +23,14 @@ pub struct Track {
#[derive(Clone)]
pub struct Segment {
// The timestamp of the segment.
pub timestamp: Duration,
// The sequence number of the segment within the track.
pub sequence: VarInt,
// The priority of the segment within the BROADCAST.
pub send_order: VarInt,
// The time at which the segment expires for cache purposes.
pub expires: Option<Instant>,
// A list of fragments that make up the segment.
pub fragments: Subscriber<Fragment>,

View File

@ -11,6 +11,8 @@ use anyhow::Context;
use std::collections::HashMap;
use std::sync::Arc;
use moq_transport::coding::VarInt;
use super::{Broadcast, Fragment, Producer, Segment, Track};
pub struct Source {
@ -98,8 +100,10 @@ impl Source {
fragments.push(raw.into());
segments.push(Segment {
sequence: VarInt::from_u32(0), // first and only segment
send_order: VarInt::from_u32(0), // highest priority
expires: None, // never delete from the cache
fragments: fragments.subscribe(),
timestamp: time::Duration::ZERO,
});
Track {
@ -170,45 +174,75 @@ struct SourceTrack {
// The number of units per second.
timescale: u64,
// The number of segments produced.
sequence: u64,
}
impl SourceTrack {
fn new(segments: Producer<Segment>, timescale: u64) -> Self {
Self {
segments,
sequence: 0,
fragments: None,
timescale,
}
}
pub fn header(&mut self, raw: Vec<u8>, fragment: SourceFragment) -> anyhow::Result<()> {
// Close the current segment if we have a new keyframe.
if fragment.keyframe {
self.fragments.take();
if let Some(fragments) = self.fragments.as_mut() {
if !fragment.keyframe {
// Use the existing segment
fragments.push(raw.into());
return Ok(());
}
}
// Get or create the current segment.
let fragments = self.fragments.get_or_insert_with(|| {
// Compute the timestamp in seconds.
let timestamp = fragment.timestamp(self.timescale);
// Otherwise make a new segment
let now = time::Instant::now();
// Compute the timestamp in milliseconds.
// Overflows after 583 million years, so we're fine.
let timestamp = fragment
.timestamp(self.timescale)
.as_millis()
.try_into()
.context("timestamp too large")?;
// The send order is simple; newer timestamps are higher priority.
// TODO give audio a boost?
let send_order = VarInt::MAX
.into_inner()
.checked_sub(timestamp)
.context("timestamp too large")?
.try_into()
.unwrap();
// Delete segments after 10s.
let expires = Some(now + time::Duration::from_secs(10));
let sequence = self.sequence.try_into().context("sequence too large")?;
self.sequence += 1;
// Create a new segment, and save the fragments producer so we can push to it.
let fragments = Producer::<Fragment>::new();
let mut fragments = Producer::<Fragment>::new();
self.segments.push(Segment {
timestamp,
sequence,
expires,
send_order,
fragments: fragments.subscribe(),
});
// Remove any segments older than 10s.
let expires = timestamp.saturating_sub(time::Duration::from_secs(10));
self.segments.drain(|segment| segment.timestamp < expires);
fragments
});
// TODO This can only drain from the FRONT of the queue, so don't get clever with expirations.
self.segments.drain(|segment| segment.expires.unwrap() < now);
// Insert the raw atom into the segment.
fragments.push(raw.into());
// Save for the next iteration
self.fragments = Some(fragments);
Ok(())
}