cargo fmt

This commit is contained in:
Luke Curley 2023-05-02 11:05:21 -07:00
parent e578b757e5
commit b5b7ffedfa
12 changed files with 126 additions and 87 deletions

View File

@ -1,3 +1,3 @@
pub mod transport;
pub mod session;
pub mod media; pub mod media;
pub mod session;
pub mod transport;

View File

@ -1,4 +1,4 @@
use warp::{session,transport}; use warp::{session, transport};
use clap::Parser; use clap::Parser;
use env_logger; use env_logger;
@ -28,7 +28,7 @@ fn main() -> anyhow::Result<()> {
let args = Cli::parse(); let args = Cli::parse();
let server_config = transport::Config{ let server_config = transport::Config {
addr: args.addr, addr: args.addr,
cert: args.cert, cert: args.cert,
key: args.key, key: args.key,

View File

@ -1,3 +1,3 @@
mod source; mod source;
pub use source::{Fragment,Source}; pub use source::{Fragment, Source};

View File

@ -1,13 +1,13 @@
use std::{io,fs,time};
use io::Read; use io::Read;
use std::collections::{VecDeque}; use std::collections::VecDeque;
use std::{fs, io, time};
use std::io::Write; use std::io::Write;
use mp4;
use anyhow; use anyhow;
use mp4;
use mp4::{ReadBox,WriteBox}; use mp4::{ReadBox, WriteBox};
pub struct Source { pub struct Source {
// We read the file once, in order, and don't seek backwards. // We read the file once, in order, and don't seek backwards.
@ -40,7 +40,7 @@ pub struct Fragment {
pub keyframe: bool, pub keyframe: bool,
// The timestamp of the fragment, in milliseconds, to simulate a live stream. // The timestamp of the fragment, in milliseconds, to simulate a live stream.
pub timestamp: Option<u64> pub timestamp: Option<u64>,
} }
impl Source { impl Source {
@ -49,7 +49,7 @@ impl Source {
let reader = io::BufReader::new(f); let reader = io::BufReader::new(f);
let start = time::Instant::now(); let start = time::Instant::now();
Ok(Self{ Ok(Self {
reader, reader,
start, start,
fragments: VecDeque::new(), fragments: VecDeque::new(),
@ -64,7 +64,7 @@ impl Source {
}; };
if self.timeout().is_some() { if self.timeout().is_some() {
return Ok(None) return Ok(None);
} }
Ok(self.fragments.pop_front()) Ok(self.fragments.pop_front())
@ -84,7 +84,7 @@ impl Source {
// Don't return anything until we know the total number of tracks. // Don't return anything until we know the total number of tracks.
// To be honest, I didn't expect the borrow checker to allow this, but it does! // To be honest, I didn't expect the borrow checker to allow this, but it does!
self.ftyp = atom; self.ftyp = atom;
}, }
mp4::BoxType::MoovBox => { mp4::BoxType::MoovBox => {
// We need to split the moov based on the tracks. // We need to split the moov based on the tracks.
let moov = mp4::MoovBox::read_box(&mut reader, header.size)?; let moov = mp4::MoovBox::read_box(&mut reader, header.size)?;
@ -105,7 +105,11 @@ impl Source {
// We remove every box for other track IDs. // We remove every box for other track IDs.
let mut toov = moov.clone(); let mut toov = moov.clone();
toov.traks.retain(|t| t.tkhd.track_id == track_id); toov.traks.retain(|t| t.tkhd.track_id == track_id);
toov.mvex.as_mut().expect("missing mvex").trexs.retain(|f| f.track_id == track_id); toov.mvex
.as_mut()
.expect("missing mvex")
.trexs
.retain(|f| f.track_id == track_id);
// Marshal the box. // Marshal the box.
let mut toov_data = Vec::new(); let mut toov_data = Vec::new();
@ -124,7 +128,7 @@ impl Source {
} }
self.moov = Some(moov); self.moov = Some(moov);
}, }
mp4::BoxType::MoofBox => { mp4::BoxType::MoofBox => {
let moof = mp4::MoofBox::read_box(&mut reader, header.size)?; let moof = mp4::MoofBox::read_box(&mut reader, header.size)?;
@ -133,19 +137,19 @@ impl Source {
anyhow::bail!("multiple tracks per moof atom") anyhow::bail!("multiple tracks per moof atom")
} }
self.fragments.push_back(Fragment{ self.fragments.push_back(Fragment {
track: moof.trafs[0].tfhd.track_id, track: moof.trafs[0].tfhd.track_id,
typ: mp4::BoxType::MoofBox, typ: mp4::BoxType::MoofBox,
data: atom, data: atom,
keyframe: has_keyframe(&moof), keyframe: has_keyframe(&moof),
timestamp: first_timestamp(&moof), timestamp: first_timestamp(&moof),
}) })
}, }
mp4::BoxType::MdatBox => { mp4::BoxType::MdatBox => {
let moof = self.fragments.back().expect("no atom before mdat"); let moof = self.fragments.back().expect("no atom before mdat");
assert!(moof.typ == mp4::BoxType::MoofBox, "no moof before mdat"); assert!(moof.typ == mp4::BoxType::MoofBox, "no moof before mdat");
self.fragments.push_back(Fragment{ self.fragments.push_back(Fragment {
track: moof.track, track: moof.track,
typ: mp4::BoxType::MoofBox, typ: mp4::BoxType::MoofBox,
data: atom, data: atom,
@ -154,8 +158,8 @@ impl Source {
}); });
// We have some media data, return so we can start sending it. // We have some media data, return so we can start sending it.
return Ok(()) return Ok(());
}, }
_ => anyhow::bail!("unknown top-level atom: {:?}", header.name), _ => anyhow::bail!("unknown top-level atom: {:?}", header.name),
} }
} }
@ -167,7 +171,12 @@ impl Source {
let timestamp = next.timestamp?; let timestamp = next.timestamp?;
// Find the timescale for the track. // Find the timescale for the track.
let track = self.moov.as_ref()?.traks.iter().find(|t| t.tkhd.track_id == next.track)?; let track = self
.moov
.as_ref()?
.traks
.iter()
.find(|t| t.tkhd.track_id == next.track)?;
let timescale = track.mdia.mdhd.timescale as u64; let timescale = track.mdia.mdhd.timescale as u64;
let delay = time::Duration::from_millis(1000 * timestamp / timescale); let delay = time::Duration::from_millis(1000 * timestamp / timescale);
@ -195,17 +204,21 @@ fn read_box<R: io::Read>(reader: &mut R) -> anyhow::Result<Vec<u8>> {
1 => { 1 => {
reader.read_exact(&mut buf)?; reader.read_exact(&mut buf)?;
let size_large = u64::from_be_bytes(buf); let size_large = u64::from_be_bytes(buf);
anyhow::ensure!(size_large >= 16, "impossible extended box size: {}", size_large); anyhow::ensure!(
size_large >= 16,
"impossible extended box size: {}",
size_large
);
reader.take(size_large - 16) reader.take(size_large - 16)
}, }
2..=7 => { 2..=7 => {
anyhow::bail!("impossible box size: {}", size) anyhow::bail!("impossible box size: {}", size)
} }
// Otherwise read based on the size. // Otherwise read based on the size.
size => reader.take(size - 8) size => reader.take(size - 8),
}; };
// Append to the vector and return it. // Append to the vector and return it.
@ -238,7 +251,7 @@ fn has_keyframe(moof: &mp4::MoofBox) -> bool {
let non_sync = (flags >> 16) & 0x1 == 0x1; // kSampleIsNonSyncSample let non_sync = (flags >> 16) & 0x1 == 0x1; // kSampleIsNonSyncSample
if keyframe && !non_sync { if keyframe && !non_sync {
return true return true;
} }
} }
} }

View File

@ -1,4 +1,4 @@
mod session;
mod message; mod message;
mod session;
pub use session::Session; pub use session::Session;

View File

@ -3,8 +3,8 @@ use std::time;
use quiche; use quiche;
use quiche::h3::webtransport; use quiche::h3::webtransport;
use crate::{media,transport};
use super::message; use super::message;
use crate::{media, transport};
#[derive(Default)] #[derive(Default)]
pub struct Session { pub struct Session {
@ -16,7 +16,11 @@ pub struct Session {
impl transport::App for Session { impl transport::App for Session {
// Process any updates to a session. // Process any updates to a session.
fn poll(&mut self, conn: &mut quiche::Connection, session: &mut webtransport::ServerSession) -> anyhow::Result<()> { fn poll(
&mut self,
conn: &mut quiche::Connection,
session: &mut webtransport::ServerSession,
) -> anyhow::Result<()> {
loop { loop {
let event = match session.poll(conn) { let event = match session.poll(conn) {
Err(webtransport::Error::Done) => break, Err(webtransport::Error::Done) => break,
@ -39,18 +43,16 @@ impl transport::App for Session {
self.media = Some(media); self.media = Some(media);
session.accept_connect_request(conn, None).unwrap(); session.accept_connect_request(conn, None).unwrap();
}, }
webtransport::ServerEvent::StreamData(stream_id) => { webtransport::ServerEvent::StreamData(stream_id) => {
let mut buf = vec![0; 10000]; let mut buf = vec![0; 10000];
while let Ok(len) = while let Ok(len) = session.recv_stream_data(conn, stream_id, &mut buf) {
session.recv_stream_data(conn, stream_id, &mut buf)
{
let stream_data = &buf[0..len]; let stream_data = &buf[0..len];
log::debug!("stream data {:?}", stream_data); log::debug!("stream data {:?}", stream_data);
} }
} }
_ => {}, _ => {}
} }
} }
@ -69,7 +71,11 @@ impl transport::App for Session {
} }
impl Session { impl Session {
fn poll_source(&mut self, conn: &mut quiche::Connection, session: &mut webtransport::ServerSession) -> anyhow::Result<()> { fn poll_source(
&mut self,
conn: &mut quiche::Connection,
session: &mut webtransport::ServerSession,
) -> anyhow::Result<()> {
// Get the media source once the connection is established. // Get the media source once the connection is established.
let media = match &mut self.media { let media = match &mut self.media {
Some(m) => m, Some(m) => m,
@ -92,7 +98,7 @@ impl Session {
// Encode a JSON header indicating this is the video track. // Encode a JSON header indicating this is the video track.
let mut message = message::Message::new(); let mut message = message::Message::new();
message.segment = Some(message::Segment{ message.segment = Some(message::Segment {
init: "video".to_string(), init: "video".to_string(),
}); });
@ -105,13 +111,13 @@ impl Session {
self.streams.send(conn, stream_id, &data, false)?; self.streams.send(conn, stream_id, &data, false)?;
stream_id stream_id
}, }
None => { None => {
// This is the start of an init segment. // This is the start of an init segment.
// Create a JSON header. // Create a JSON header.
let mut message = message::Message::new(); let mut message = message::Message::new();
message.init = Some(message::Init{ message.init = Some(message::Init {
id: "video".to_string(), id: "video".to_string(),
}); });

View File

@ -3,6 +3,10 @@ use std::time;
use quiche::h3::webtransport; use quiche::h3::webtransport;
pub trait App: Default { pub trait App: Default {
fn poll(&mut self, conn: &mut quiche::Connection, session: &mut webtransport::ServerSession) -> anyhow::Result<()>; fn poll(
&mut self,
conn: &mut quiche::Connection,
session: &mut webtransport::ServerSession,
) -> anyhow::Result<()>;
fn timeout(&self) -> Option<time::Duration>; fn timeout(&self) -> Option<time::Duration>;
} }

View File

@ -1,6 +1,6 @@
mod server;
mod connection;
mod app; mod app;
mod connection;
mod server;
mod streams; mod streams;
pub use app::App; pub use app::App;

View File

@ -2,8 +2,8 @@ use std::io;
use quiche::h3::webtransport; use quiche::h3::webtransport;
use super::connection;
use super::app; use super::app;
use super::connection;
const MAX_DATAGRAM_SIZE: usize = 1350; const MAX_DATAGRAM_SIZE: usize = 1350;
@ -36,11 +36,9 @@ impl<T: app::App> Server<T> {
let poll = mio::Poll::new().unwrap(); let poll = mio::Poll::new().unwrap();
let events = mio::Events::with_capacity(1024); let events = mio::Events::with_capacity(1024);
poll.registry().register( poll.registry()
&mut socket, .register(&mut socket, mio::Token(0), mio::Interest::READABLE)
mio::Token(0), .unwrap();
mio::Interest::READABLE,
).unwrap();
// Generate random values for connection IDs. // Generate random values for connection IDs.
let rng = ring::rand::SystemRandom::new(); let rng = ring::rand::SystemRandom::new();
@ -50,7 +48,8 @@ impl<T: app::App> Server<T> {
let mut quic = quiche::Config::new(quiche::PROTOCOL_VERSION).unwrap(); let mut quic = quiche::Config::new(quiche::PROTOCOL_VERSION).unwrap();
quic.load_cert_chain_from_pem_file(&config.cert).unwrap(); quic.load_cert_chain_from_pem_file(&config.cert).unwrap();
quic.load_priv_key_from_pem_file(&config.key).unwrap(); quic.load_priv_key_from_pem_file(&config.key).unwrap();
quic.set_application_protos(quiche::h3::APPLICATION_PROTOCOL).unwrap(); quic.set_application_protos(quiche::h3::APPLICATION_PROTOCOL)
.unwrap();
quic.set_max_idle_timeout(5000); quic.set_max_idle_timeout(5000);
quic.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE); quic.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE);
quic.set_max_send_udp_payload_size(MAX_DATAGRAM_SIZE); quic.set_max_send_udp_payload_size(MAX_DATAGRAM_SIZE);
@ -92,7 +91,10 @@ impl<T: app::App> Server<T> {
// Find the shorter timeout from all the active connections. // Find the shorter timeout from all the active connections.
// //
// TODO: use event loop that properly supports timers // TODO: use event loop that properly supports timers
let timeout = self.conns.values().filter_map(|c| { let timeout = self
.conns
.values()
.filter_map(|c| {
let timeout = c.quiche.timeout(); let timeout = c.quiche.timeout();
let expires = c.app.timeout(); let expires = c.app.timeout();
@ -102,7 +104,8 @@ impl<T: app::App> Server<T> {
(None, Some(b)) => Some(b), (None, Some(b)) => Some(b),
(None, None) => None, (None, None) => None,
} }
}).min(); })
.min();
self.poll.poll(&mut self.events, timeout).unwrap(); self.poll.poll(&mut self.events, timeout).unwrap();
@ -120,7 +123,7 @@ impl<T: app::App> Server<T> {
// Reads packets from the socket, updating any internal connection state. // Reads packets from the socket, updating any internal connection state.
fn receive(&mut self) -> anyhow::Result<()> { fn receive(&mut self) -> anyhow::Result<()> {
let mut src= [0; MAX_DATAGRAM_SIZE]; let mut src = [0; MAX_DATAGRAM_SIZE];
// Try reading any data currently available on the socket. // Try reading any data currently available on the socket.
loop { loop {
@ -150,20 +153,24 @@ impl<T: app::App> Server<T> {
conn.quiche.recv(src, info)?; conn.quiche.recv(src, info)?;
if conn.session.is_none() && conn.quiche.is_established() { if conn.session.is_none() && conn.quiche.is_established() {
conn.session = Some(webtransport::ServerSession::with_transport(&mut conn.quiche)?) conn.session = Some(webtransport::ServerSession::with_transport(
&mut conn.quiche,
)?)
} }
continue continue;
} else if let Some(conn) = self.conns.get_mut(&conn_id) { } else if let Some(conn) = self.conns.get_mut(&conn_id) {
// 1-RTT traffic. // 1-RTT traffic.
conn.quiche.recv(src, info)?; conn.quiche.recv(src, info)?;
// TODO is this needed here? // TODO is this needed here?
if conn.session.is_none() && conn.quiche.is_established() { if conn.session.is_none() && conn.quiche.is_established() {
conn.session = Some(webtransport::ServerSession::with_transport(&mut conn.quiche)?) conn.session = Some(webtransport::ServerSession::with_transport(
&mut conn.quiche,
)?)
} }
continue continue;
} }
if hdr.ty != quiche::Type::Initial { if hdr.ty != quiche::Type::Initial {
@ -174,10 +181,10 @@ impl<T: app::App> Server<T> {
if !quiche::version_is_supported(hdr.version) { if !quiche::version_is_supported(hdr.version) {
let len = quiche::negotiate_version(&hdr.scid, &hdr.dcid, &mut dst).unwrap(); let len = quiche::negotiate_version(&hdr.scid, &hdr.dcid, &mut dst).unwrap();
let dst= &dst[..len]; let dst = &dst[..len];
self.socket.send_to(dst, from).unwrap(); self.socket.send_to(dst, from).unwrap();
continue continue;
} }
let mut scid = [0; quiche::MAX_CONN_ID_LEN]; let mut scid = [0; quiche::MAX_CONN_ID_LEN];
@ -202,10 +209,10 @@ impl<T: app::App> Server<T> {
) )
.unwrap(); .unwrap();
let dst= &dst[..len]; let dst = &dst[..len];
self.socket.send_to(dst, from).unwrap(); self.socket.send_to(dst, from).unwrap();
continue continue;
} }
let odcid = validate_token(&from, token); let odcid = validate_token(&from, token);
@ -222,21 +229,23 @@ impl<T: app::App> Server<T> {
// Reuse the source connection ID we sent in the Retry packet, // Reuse the source connection ID we sent in the Retry packet,
// instead of changing it again. // instead of changing it again.
let conn_id= hdr.dcid.clone(); let conn_id = hdr.dcid.clone();
let local_addr = self.socket.local_addr().unwrap(); let local_addr = self.socket.local_addr().unwrap();
let mut conn = quiche::accept(&conn_id, odcid.as_ref(), local_addr, from, &mut self.quic)?; let mut conn =
quiche::accept(&conn_id, odcid.as_ref(), local_addr, from, &mut self.quic)?;
// Process potentially coalesced packets. // Process potentially coalesced packets.
conn.recv(src, info)?; conn.recv(src, info)?;
let user = connection::Connection{ let user = connection::Connection {
quiche: conn, quiche: conn,
session: None, session: None,
app: T::default(), app: T::default(),
}; };
self.conns.insert(user.quiche.source_id().into_owned(), user); self.conns
.insert(user.quiche.source_id().into_owned(), user);
} }
} }
@ -262,7 +271,7 @@ impl<T: app::App> Server<T> {
for conn in self.conns.values_mut() { for conn in self.conns.values_mut() {
loop { loop {
let (size , info) = match conn.quiche.send(&mut pkt) { let (size, info) = match conn.quiche.send(&mut pkt) {
Ok(v) => v, Ok(v) => v,
Err(quiche::Error::Done) => return Ok(()), Err(quiche::Error::Done) => return Ok(()),
Err(e) => return Err(e.into()), Err(e) => return Err(e.into()),
@ -283,7 +292,7 @@ impl<T: app::App> Server<T> {
pub fn cleanup(&mut self) { pub fn cleanup(&mut self) {
// Garbage collect closed connections. // Garbage collect closed connections.
self.conns.retain(|_, ref mut c| !c.quiche.is_closed() ); self.conns.retain(|_, ref mut c| !c.quiche.is_closed());
} }
} }
@ -319,7 +328,8 @@ fn mint_token(hdr: &quiche::Header, src: &std::net::SocketAddr) -> Vec<u8> {
/// Note that this function is only an example and doesn't do any cryptographic /// Note that this function is only an example and doesn't do any cryptographic
/// authenticate of the token. *It should not be used in production system*. /// authenticate of the token. *It should not be used in production system*.
fn validate_token<'a>( fn validate_token<'a>(
src: &std::net::SocketAddr, token: &'a [u8], src: &std::net::SocketAddr,
token: &'a [u8],
) -> Option<quiche::ConnectionId<'a>> { ) -> Option<quiche::ConnectionId<'a>> {
if token.len() < 6 { if token.len() < 6 {
return None; return None;

View File

@ -1,8 +1,8 @@
use std::collections::hash_map as hmap; use std::collections::hash_map as hmap;
use std::collections::VecDeque; use std::collections::VecDeque;
use quiche;
use anyhow; use anyhow;
use quiche;
#[derive(Default)] #[derive(Default)]
pub struct Streams { pub struct Streams {
@ -16,14 +16,20 @@ struct State {
} }
impl Streams { impl Streams {
pub fn send(&mut self, conn: &mut quiche::Connection, id: u64, buf: &[u8], fin: bool) -> anyhow::Result<()> { pub fn send(
&mut self,
conn: &mut quiche::Connection,
id: u64,
buf: &[u8],
fin: bool,
) -> anyhow::Result<()> {
match self.lookup.entry(id) { match self.lookup.entry(id) {
hmap::Entry::Occupied(mut entry) => { hmap::Entry::Occupied(mut entry) => {
// Add to the existing buffer. // Add to the existing buffer.
let state = entry.get_mut(); let state = entry.get_mut();
state.buffer.extend(buf); state.buffer.extend(buf);
state.fin |= fin; state.fin |= fin;
}, }
hmap::Entry::Vacant(entry) => { hmap::Entry::Vacant(entry) => {
let size = conn.stream_send(id, buf, fin)?; let size = conn.stream_send(id, buf, fin)?;
@ -32,9 +38,9 @@ impl Streams {
let mut buffer = VecDeque::with_capacity(buf.len()); let mut buffer = VecDeque::with_capacity(buf.len());
buffer.extend(&buf[size..]); buffer.extend(&buf[size..]);
entry.insert(State{buffer, fin}); entry.insert(State { buffer, fin });
}
} }
},
}; };
Ok(()) Ok(())
@ -58,7 +64,7 @@ impl Streams {
let size = conn.stream_send(id, parts.0, false)?; let size = conn.stream_send(id, parts.0, false)?;
if size == 0 { if size == 0 {
// No more space available for this stream. // No more space available for this stream.
continue 'outer continue 'outer;
} }
// Remove the bytes that were written. // Remove the bytes that were written.