More refactoring ofc.

This commit is contained in:
Luke Curley 2023-04-24 10:18:55 -07:00
parent 5204dbc19c
commit bb0437a3bb
18 changed files with 1118 additions and 507 deletions

View File

@ -1,18 +1,6 @@
#!/bin/bash #!/bin/bash
ffmpeg -i source.mp4 \ ffmpeg -i source.mp4 \
-f dash -ldash 1 \ -c:v copy \
-c:v libx264 \ -an \
-preset veryfast -tune zerolatency \ -movflags frag_every_frame+empty_moov \
-c:a aac \ fragmented.mp4
-b:a 128k -ac 2 -ar 44100 \
-map v:0 -s:v:0 1280x720 -b:v:0 3M \
-map v:0 -s:v:1 854x480 -b:v:1 1.1M \
-map v:0 -s:v:2 640x360 -b:v:2 365k \
-map 0:a \
-force_key_frames "expr:gte(t,n_forced*2)" \
-sc_threshold 0 \
-streaming 1 \
-use_timeline 0 \
-seg_duration 2 -frag_duration 0.01 \
-frag_type duration \
playlist.mpd

3
server/.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,3 @@
{
"rust-analyzer.showUnlinkedFileNotification": false
}

136
server/Cargo.lock generated
View File

@ -60,6 +60,12 @@ dependencies = [
"windows-sys 0.48.0", "windows-sys 0.48.0",
] ]
[[package]]
name = "anyhow"
version = "1.0.70"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7de8ce5e0f9f8d88245311066a578d72b7af3e7088f32783804676302df237e4"
[[package]] [[package]]
name = "atty" name = "atty"
version = "0.2.14" version = "0.2.14"
@ -89,6 +95,18 @@ version = "3.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535"
[[package]]
name = "byteorder"
version = "1.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
[[package]]
name = "bytes"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.0.79" version = "1.0.79"
@ -242,6 +260,12 @@ dependencies = [
"windows-sys 0.48.0", "windows-sys 0.48.0",
] ]
[[package]]
name = "itoa"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6"
[[package]] [[package]]
name = "js-sys" name = "js-sys"
version = "0.3.61" version = "0.3.61"
@ -302,6 +326,63 @@ dependencies = [
"windows-sys 0.45.0", "windows-sys 0.45.0",
] ]
[[package]]
name = "mp4"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "509348cba250e7b852a875100a2ddce7a36ee3abf881a681c756670c1774264d"
dependencies = [
"byteorder",
"bytes",
"num-rational",
"serde",
"serde_json",
"thiserror",
]
[[package]]
name = "num-bigint"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9"
dependencies = [
"autocfg",
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
dependencies = [
"autocfg",
"num-bigint",
"num-integer",
"num-traits",
"serde",
]
[[package]]
name = "num-traits"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "octets" name = "octets"
version = "0.2.0" version = "0.2.0"
@ -394,11 +475,42 @@ dependencies = [
"windows-sys 0.48.0", "windows-sys 0.48.0",
] ]
[[package]]
name = "ryu"
version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.160" version = "1.0.160"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c" checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.160"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.15",
]
[[package]]
name = "serde_json"
version = "1.0.96"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]] [[package]]
name = "slab" name = "slab"
@ -461,6 +573,26 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "thiserror"
version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.15",
]
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.8" version = "1.0.8"
@ -483,12 +615,16 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
name = "warp" name = "warp"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow",
"clap", "clap",
"env_logger", "env_logger",
"log", "log",
"mio", "mio",
"mp4",
"quiche", "quiche",
"ring", "ring",
"serde",
"serde_json",
] ]
[[package]] [[package]]

View File

@ -11,4 +11,8 @@ clap = { version = "4.0", features = [ "derive" ] }
log = { version = "0.4", features = ["std"] } log = { version = "0.4", features = ["std"] }
mio = { version = "0.8", features = ["net", "os-poll"] } mio = { version = "0.8", features = ["net", "os-poll"] }
env_logger = "0.9.3" env_logger = "0.9.3"
ring = "0.16" ring = "0.16"
anyhow = "1.0.70"
mp4 = "0.13.0"
serde = "1.0.160"
serde_json = "1.0"

View File

@ -1,44 +0,0 @@
use std::io;
use quiche::h3::webtransport;
#[derive(Debug)]
pub enum Error {
Io(io::Error),
Quiche(quiche::Error),
WebTransport(webtransport::Error),
Server(Server),
}
impl From<io::Error> for Error {
fn from(err: io::Error) -> Error {
Error::Io(err)
}
}
impl From<quiche::Error> for Error {
fn from(err: quiche::Error) -> Error {
Error::Quiche(err)
}
}
impl From<webtransport::Error> for Error {
fn from(err: webtransport::Error) -> Error {
Error::WebTransport(err)
}
}
// Custom server error messages.
#[derive(Debug)]
pub enum Server {
InvalidToken,
InvalidConnectionID,
UnknownConnectionID,
}
impl From<Server> for Error {
fn from(err: Server) -> Error {
Error::Server(err)
}
}
pub type Result<T> = std::result::Result<T, Error>;

View File

@ -1,3 +1,2 @@
pub mod error; pub mod transport;
pub mod server; //mod media;
pub mod session;

View File

@ -1,6 +1,10 @@
use warp::server::Server; use quiche::h3::webtransport;
use warp::transport;
use std::time;
use clap::Parser; use clap::Parser;
use env_logger;
/// Search for a pattern in a file and display the lines that contain it. /// Search for a pattern in a file and display the lines that contain it.
#[derive(Parser)] #[derive(Parser)]
@ -16,19 +20,48 @@ struct Cli {
/// Use the private key at this path /// Use the private key at this path
#[arg(short, long, default_value = "../cert/localhost.key")] #[arg(short, long, default_value = "../cert/localhost.key")]
key: String, key: String,
/// Use the media file at this path
#[arg(short, long, default_value = "../media/fragmented.mp4")]
media: String,
} }
fn main() { #[derive(Default)]
struct Connection {
webtransport: Option<webtransport::ServerSession>,
}
impl transport::App for Connection {
fn poll(&mut self, conn: &mut quiche::Connection, session: &mut webtransport::ServerSession) -> anyhow::Result<()> {
if !conn.is_established() {
// Wait until the handshake finishes
return Ok(())
}
if self.webtransport.is_none() {
self.webtransport = Some(webtransport::ServerSession::with_transport(conn)?)
}
let webtransport = self.webtransport.as_mut().unwrap();
Ok(())
}
fn timeout(&self) -> Option<time::Duration> {
None
}
}
fn main() -> anyhow::Result<()> {
env_logger::init();
let args = Cli::parse(); let args = Cli::parse();
let server_config = warp::server::Config{ let server_config = transport::Config{
addr: args.addr, addr: args.addr,
cert: args.cert, cert: args.cert,
key: args.key, key: args.key,
}; };
let mut server = Server::new(server_config).unwrap(); let mut server = transport::Server::<Connection>::new(server_config).unwrap();
loop { server.run()
server.poll().unwrap()
}
} }

108
server/src/media.rs Normal file
View File

@ -0,0 +1,108 @@
use std::{io,fs};
use mp4;
use anyhow;
use bytes;
use mp4::ReadBox;
pub struct Source {
pub segments: Vec<Segment>,
}
impl Source {
pub fn new(path: &str) -> anyhow::Result<Self> {
let f = fs::read(path)?;
let mut bytes = bytes::Bytes::from(f);
let mut segments = Vec::new();
let mut current = Segment::new();
while bytes.len() > 0 {
// NOTE: Cloning is cheap, since the underlying bytes are reference counted.
let mut reader = io::Cursor::new(bytes.clone());
let header = mp4::BoxHeader::read(&mut reader)?;
let size: usize = header.size as usize;
assert!(size > 0, "empty box");
let frag = bytes.split_to(size);
let fragment = Fragment{ bytes: frag };
match header.name {
/*
mp4::BoxType::FtypBox => {
}
mp4::BoxType::MoovBox => {
moov = mp4::MoovBox::read_box(&mut reader, size)?
}
mp4::BoxType::EmsgBox => {
let emsg = mp4::EmsgBox::read_box(&mut reader, size)?;
emsgs.push(emsg);
}
mp4::BoxType::MdatBox => {
mp4::skip_box(&mut reader, size)?;
}
*/
mp4::BoxType::MoofBox => {
let moof = mp4::MoofBox::read_box(&mut reader, header.size)?;
if has_keyframe(moof) {
segments.push(current);
current = Segment::new();
}
}
_ => (),
}
current.fragments.push(fragment);
}
segments.push(current);
Ok(Self { segments })
}
}
fn has_keyframe(moof: mp4::MoofBox) -> bool {
for traf in moof.trafs {
// TODO trak default flags if this is None
let default_flags = traf.tfhd.default_sample_flags.unwrap_or_default();
let trun = traf.trun.expect("missing trun box");
for i in 0..trun.sample_count {
let mut flags = match trun.sample_flags.get(i as usize) {
Some(f) => *f,
None => default_flags,
};
if i == 0 && trun.first_sample_flags.is_some() {
flags = trun.first_sample_flags.unwrap();
}
// https://chromium.googlesource.com/chromium/src/media/+/master/formats/mp4/track_run_iterator.cc#177
let keyframe = (flags >> 24) & 0x3 == 0x2; // kSampleDependsOnNoOther
let non_sync = (flags >> 16) & 0x1 == 0x1; // kSampleIsNonSyncSample
if keyframe && non_sync {
return true
}
}
}
false
}
pub struct Segment {
pub fragments: Vec<Fragment>,
}
impl Segment {
fn new() -> Self {
Segment { fragments: Vec::new() }
}
}
pub struct Fragment {
pub bytes: bytes::Bytes,
}

1
server/src/media/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod source;

163
server/src/media/source.rs Normal file
View File

@ -0,0 +1,163 @@
use std::{io,fs,time};
use io::Read;
use mp4;
use anyhow;
use mp4::ReadBox;
pub struct Source {
reader: io::BufReader<fs::File>,
start: time::Instant,
pending: Option<Fragment>,
sequence: u64,
}
pub struct Fragment {
pub data: Vec<u8>,
pub segment_id: u64,
pub timestamp: u64,
}
impl Source {
pub fn new(path: &str) -> io::Result<Self> {
let f = fs::File::open(path)?;
let reader = io::BufReader::new(f);
let start = time::Instant::now();
Ok(Self{
reader,
start,
pending: None,
sequence: 0,
})
}
pub fn next(&mut self) -> anyhow::Result<Option<Fragment>> {
let pending = match self.pending.take() {
Some(f) => f,
None => self.next_inner()?,
};
if pending.timestamp > 0 && pending.timestamp < self.start.elapsed().as_millis() as u64 {
self.pending = Some(pending);
return Ok(None)
}
Ok(Some(pending))
}
fn next_inner(&mut self) -> anyhow::Result<Fragment> {
// Read the next full atom.
let atom = read_box(&mut self.reader)?;
let mut timestamp = 0;
// Before we return it, let's do some simple parsing.
let mut reader = io::Cursor::new(&atom);
let header = mp4::BoxHeader::read(&mut reader)?;
match header.name {
mp4::BoxType::MoofBox => {
let moof = mp4::MoofBox::read_box(&mut reader, header.size)?;
if has_keyframe(&moof) {
self.sequence += 1
}
timestamp = first_timestamp(&moof);
}
_ => (),
}
Ok(Fragment {
data: atom,
segment_id: self.sequence,
timestamp: timestamp,
})
}
}
// Read a full MP4 atom into a vector.
fn read_box<R: io::Read>(reader: &mut R) -> anyhow::Result<Vec<u8>> {
// Read the 8 bytes for the size + type
let mut buf = [0u8; 8];
reader.read_exact(&mut buf)?;
// Convert the first 4 bytes into the size.
let size = u32::from_be_bytes(buf[0..4].try_into()?) as u64;
let mut out = buf.to_vec();
let mut limit = match size {
// Runs until the end of the file.
0 => reader.take(u64::MAX),
// The next 8 bytes are the extended size to be used instead.
1 => {
reader.read_exact(&mut buf)?;
let size_large = u64::from_be_bytes(buf);
anyhow::ensure!(size_large >= 16, "impossible extended box size: {}", size_large);
reader.take(size_large - 16)
},
2..=7 => {
anyhow::bail!("impossible box size: {}", size)
}
// Otherwise read based on the size.
size => reader.take(size - 8)
};
// Append to the vector and return it.
limit.read_to_end(&mut out)?;
Ok(out)
}
fn has_keyframe(moof: &mp4::MoofBox) -> bool {
for traf in &moof.trafs {
// TODO trak default flags if this is None
let default_flags = traf.tfhd.default_sample_flags.unwrap_or_default();
let trun = match &traf.trun {
Some(t) => t,
None => return false,
};
for i in 0..trun.sample_count {
let mut flags = match trun.sample_flags.get(i as usize) {
Some(f) => *f,
None => default_flags,
};
if i == 0 && trun.first_sample_flags.is_some() {
flags = trun.first_sample_flags.unwrap();
}
// https://chromium.googlesource.com/chromium/src/media/+/master/formats/mp4/track_run_iterator.cc#177
let keyframe = (flags >> 24) & 0x3 == 0x2; // kSampleDependsOnNoOther
let non_sync = (flags >> 16) & 0x1 == 0x1; // kSampleIsNonSyncSample
if keyframe && non_sync {
return true
}
}
}
false
}
fn first_timestamp(moof: &mp4::MoofBox) -> u64 {
let traf = match moof.trafs.first() {
Some(t) => t,
None => return 0,
};
let tfdt = match &traf.tfdt {
Some(t) => t,
None => return 0,
};
tfdt.base_media_decode_time
}

40
server/src/message.rs Normal file
View File

@ -0,0 +1,40 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
pub struct Message {
pub init: Option<Init>,
pub segment: Option<Segment>,
}
#[derive(Serialize, Deserialize)]
pub struct Init {
pub id: String,
}
#[derive(Serialize, Deserialize)]
pub struct Segment {
pub init: String,
pub timestamp: u64,
}
impl Message {
pub fn new() -> Self {
Message {
init: None,
segment: None,
}
}
pub fn serialize(&self) -> anyhow::Result<Vec<u8>> {
let str = serde_json::to_string(self)?;
let bytes = str.as_bytes();
let size = bytes.len() + 8;
let mut out = Vec::with_capacity(size);
out.extend_from_slice(b"warp");
out.extend_from_slice(&size.to_be_bytes());
out.extend_from_slice(bytes);
Ok(out)
}
}

View File

@ -1,332 +0,0 @@
use crate::session;
use crate::error;
use session::Session;
use error::{Error, Result};
use std::{io, net};
use log;
use quiche::h3::webtransport;
const MAX_DATAGRAM_SIZE: usize = 1350;
pub struct Server {
// IO stuff
socket: mio::net::UdpSocket,
poll: mio::Poll,
events: mio::Events,
// QUIC stuff
quic: quiche::Config,
sessions: session::Map,
seed: ring::hmac::Key, // connection ID seed
}
pub struct Config {
pub addr: String,
pub cert: String,
pub key: String,
}
impl Server {
pub fn new(config: Config) -> io::Result<Server> {
// Listen on the provided socket address
let addr = config.addr.parse().unwrap();
let mut socket = mio::net::UdpSocket::bind(addr).unwrap();
// Setup the event loop.
let poll = mio::Poll::new().unwrap();
let events = mio::Events::with_capacity(1024);
let sessions = session::Map::new();
poll.registry().register(
&mut socket,
mio::Token(0),
mio::Interest::READABLE,
).unwrap();
// Generate random values for connection IDs.
let rng = ring::rand::SystemRandom::new();
let seed = ring::hmac::Key::generate(ring::hmac::HMAC_SHA256, &rng).unwrap();
// Create the configuration for the QUIC connections.
let mut quic = quiche::Config::new(quiche::PROTOCOL_VERSION).unwrap();
quic.load_cert_chain_from_pem_file(&config.cert).unwrap();
quic.load_priv_key_from_pem_file(&config.key).unwrap();
quic.set_application_protos(quiche::h3::APPLICATION_PROTOCOL).unwrap();
quic.set_max_idle_timeout(5000);
quic.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE);
quic.set_max_send_udp_payload_size(MAX_DATAGRAM_SIZE);
quic.set_initial_max_data(10_000_000);
quic.set_initial_max_stream_data_bidi_local(1_000_000);
quic.set_initial_max_stream_data_bidi_remote(1_000_000);
quic.set_initial_max_stream_data_uni(1_000_000);
quic.set_initial_max_streams_bidi(100);
quic.set_initial_max_streams_uni(100);
quic.set_disable_active_migration(true);
quic.enable_early_data();
quic.enable_dgram(true, 65536, 65536);
Ok(Server {
socket,
poll,
events,
quic,
sessions,
seed
})
}
pub fn poll(&mut self) -> io::Result<()> {
self.receive().unwrap();
self.send().unwrap();
self.cleanup().unwrap();
Ok(())
}
fn receive(&mut self) -> io::Result<()> {
// Find the shorter timeout from all the active connections.
//
// TODO: use event loop that properly supports timers
let timeout = self.sessions.values().filter_map(|c| c.conn.timeout()).min();
self.poll.poll(&mut self.events, timeout).unwrap();
// If the event loop reported no events, it means that the timeout
// has expired, so handle it without attempting to read packets. We
// will then proceed with the send loop.
if self.events.is_empty() {
self.sessions.values_mut().for_each(|session| {
session.conn.on_timeout()
});
return Ok(())
}
// Read incoming UDP packets from the socket and feed them to quiche,
// until there are no more packets to read.
loop {
match self.receive_once() {
Err(Error::Io(e)) if e.kind() == io::ErrorKind::WouldBlock => return Ok(()),
Err(e) => log::error!("{:?}", e),
Ok(_) => (),
}
}
}
fn receive_once(&mut self) -> Result<()> {
let mut src= [0; MAX_DATAGRAM_SIZE];
let (len, from) = self.socket.recv_from(&mut src).unwrap();
let src = &mut src[..len];
let info = quiche::RecvInfo {
to: self.socket.local_addr().unwrap(),
from,
};
// Lookup a connection based on the packet's connection ID. If there
// is no connection matching, create a new one.
let pair = match self.accept(src, from).unwrap() {
Some(v) => v,
None => return Ok(()),
};
let conn = &mut pair.conn;
// Process potentially coalesced packets.
conn.recv(src, info).unwrap();
// Create a new HTTP/3 connection as soon as the QUIC connection
// is established.
if (conn.is_in_early_data() || conn.is_established()) && pair.session.is_none() {
let session = webtransport::ServerSession::with_transport(conn).unwrap();
pair.session = Some(session);
}
// The `poll` can pull out the events that occurred according to the data passed here.
for (_, session) in self.sessions.iter_mut() {
session.poll().unwrap();
}
Ok(())
}
fn accept(&mut self, src: &mut [u8], from: net::SocketAddr) -> error::Result<Option<&mut Session>> {
// Parse the QUIC packet's header.
let hdr = quiche::Header::from_slice(src, quiche::MAX_CONN_ID_LEN).unwrap();
let conn_id = ring::hmac::sign(&self.seed, &hdr.dcid);
let conn_id = &conn_id.as_ref()[..quiche::MAX_CONN_ID_LEN];
let conn_id = conn_id.to_vec().into();
if self.sessions.contains_key(&hdr.dcid) {
let pair = self.sessions.get_mut(&hdr.dcid).unwrap();
return Ok(Some(pair))
} else if self.sessions.contains_key(&conn_id) {
let pair = self.sessions.get_mut(&conn_id).unwrap();
return Ok(Some(pair));
}
if hdr.ty != quiche::Type::Initial {
return Err(error::Server::UnknownConnectionID.into())
}
let mut dst = [0; MAX_DATAGRAM_SIZE];
if !quiche::version_is_supported(hdr.version) {
let len = quiche::negotiate_version(&hdr.scid, &hdr.dcid, &mut dst).unwrap();
let dst= &dst[..len];
self.socket.send_to(dst, from).unwrap();
return Ok(None)
}
let mut scid = [0; quiche::MAX_CONN_ID_LEN];
scid.copy_from_slice(&conn_id);
let scid = quiche::ConnectionId::from_ref(&scid);
// Token is always present in Initial packets.
let token = hdr.token.as_ref().unwrap();
// Do stateless retry if the client didn't send a token.
if token.is_empty() {
let new_token = mint_token(&hdr, &from);
let len = quiche::retry(
&hdr.scid,
&hdr.dcid,
&scid,
&new_token,
hdr.version,
&mut dst,
)
.unwrap();
let dst= &dst[..len];
self.socket.send_to(dst, from).unwrap();
return Ok(None)
}
let odcid = validate_token(&from, token);
// The token was not valid, meaning the retry failed, so
// drop the packet.
if odcid.is_none() {
return Err(error::Server::InvalidToken.into())
}
if scid.len() != hdr.dcid.len() {
return Err(error::Server::InvalidConnectionID.into())
}
// Reuse the source connection ID we sent in the Retry packet,
// instead of changing it again.
let conn_id= hdr.dcid.clone();
let local_addr = self.socket.local_addr().unwrap();
let conn =
quiche::accept(&conn_id, odcid.as_ref(), local_addr, from, &mut self.quic)
.unwrap();
self.sessions.insert(
conn_id.clone(),
Session {
conn,
session: None,
},
);
let pair = self.sessions.get_mut(&conn_id).unwrap();
Ok(Some(pair))
}
fn send(&mut self) -> io::Result<()> {
let mut pkt = [0; MAX_DATAGRAM_SIZE];
// Generate outgoing QUIC packets for all active connections and send
// them on the UDP socket, until quiche reports that there are no more
// packets to be sent.
for session in self.sessions.values_mut() {
loop {
let (size , info) = session.conn.send(&mut pkt).unwrap();
let pkt = &pkt[..size];
match self.socket.send_to(&pkt, info.to) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => break,
Err(err) => return Err(err),
Ok(_) => (),
}
}
}
Ok(())
}
fn cleanup(&mut self) -> io::Result<()> {
// Garbage collect closed connections.
self.sessions.retain(|_, session| !session.conn.is_closed() );
Ok(())
}
}
/// Generate a stateless retry token.
///
/// The token includes the static string `"quiche"` followed by the IP address
/// of the client and by the original destination connection ID generated by the
/// client.
///
/// Note that this function is only an example and doesn't do any cryptographic
/// authenticate of the token. *It should not be used in production system*.
fn mint_token(hdr: &quiche::Header, src: &std::net::SocketAddr) -> Vec<u8> {
let mut token = Vec::new();
token.extend_from_slice(b"quiche");
let addr = match src.ip() {
std::net::IpAddr::V4(a) => a.octets().to_vec(),
std::net::IpAddr::V6(a) => a.octets().to_vec(),
};
token.extend_from_slice(&addr);
token.extend_from_slice(&hdr.dcid);
token
}
/// Validates a stateless retry token.
///
/// This checks that the ticket includes the `"quiche"` static string, and that
/// the client IP address matches the address stored in the ticket.
///
/// Note that this function is only an example and doesn't do any cryptographic
/// authenticate of the token. *It should not be used in production system*.
fn validate_token<'a>(
src: &std::net::SocketAddr, token: &'a [u8],
) -> Option<quiche::ConnectionId<'a>> {
if token.len() < 6 {
return None;
}
if &token[..6] != b"quiche" {
return None;
}
let token = &token[6..];
let addr = match src.ip() {
std::net::IpAddr::V4(a) => a.octets().to_vec(),
std::net::IpAddr::V6(a) => a.octets().to_vec(),
};
if token.len() < addr.len() || &token[..addr.len()] != addr.as_slice() {
return None;
}
Some(quiche::ConnectionId::from_ref(&token[addr.len()..]))
}

View File

@ -1,104 +0,0 @@
use crate::error;
use error::Result;
use std::collections::HashMap;
use quiche::h3::webtransport;
pub struct Session {
pub conn: quiche::Connection,
pub session: Option<webtransport::ServerSession>,
}
pub type Map = HashMap<quiche::ConnectionId<'static>, Session>;
impl Session {
// Process any updates to a session.
pub fn poll(&mut self) -> Result<()> {
let session = match &mut self.session {
Some(s) => s,
None => return Ok(()),
};
loop {
let event = match session.poll(&mut self.conn) {
Err(webtransport::Error::Done) => return Ok(()),
Err(e) => return Err(e.into()),
Ok(e) => e,
};
match event {
webtransport::ServerEvent::ConnectRequest(_req) => {
// you can handle request with
// req.authority()
// req.path()
// and you can validate this request with req.origin()
session.accept_connect_request(&mut self.conn, None).unwrap();
},
webtransport::ServerEvent::StreamData(stream_id) => {
let mut buf = vec![0; 10000];
while let Ok(len) =
session.recv_stream_data(&mut self.conn, stream_id, &mut buf)
{
let stream_data = &buf[0..len];
// handle stream_data
if (stream_id & 0x2) == 0 {
// bidirectional stream
// you can send data through this stream.
session
.send_stream_data(&mut self.conn, stream_id, stream_data)
.unwrap();
} else {
// you cannot send data through client-initiated-unidirectional-stream.
// so, open new server-initiated-unidirectional-stream, and send data
// through it.
let new_stream_id =
session.open_stream(&mut self.conn, false).unwrap();
session
.send_stream_data(&mut self.conn, new_stream_id, stream_data)
.unwrap();
}
}
}
webtransport::ServerEvent::StreamFinished(_stream_id) => {
// A WebTrnasport stream finished, handle it.
}
webtransport::ServerEvent::Datagram => {
let mut buf = vec![0; 1500];
while let Ok((in_session, offset, total)) =
session.recv_dgram(&mut self.conn, &mut buf)
{
if in_session {
let dgram = &buf[offset..total];
dbg!(std::string::String::from_utf8_lossy(dgram));
// handle this dgram
// for instance, you can write echo-server like following
session.send_dgram(&mut self.conn, dgram).unwrap();
} else {
// this dgram is not related to current WebTransport session. ignore.
}
}
}
webtransport::ServerEvent::SessionReset(_e) => {
// Peer reset session stream, handle it.
}
webtransport::ServerEvent::SessionFinished => {
// Peer finish session stream, handle it.
}
webtransport::ServerEvent::SessionGoAway => {
// Peer signalled it is going away, handle it.
}
webtransport::ServerEvent::Other(_stream_id, _event) => {
// Original h3::Event which is not related to WebTransport.
}
}
}
}
}

View File

@ -0,0 +1,8 @@
use std::time;
use quiche::h3::webtransport;
pub trait App: Default {
fn poll(&mut self, conn: &mut quiche::Connection, session: &mut webtransport::ServerSession) -> anyhow::Result<()>;
fn timeout(&self) -> Option<time::Duration>;
}

View File

@ -0,0 +1,15 @@
use quiche;
use quiche::h3::webtransport;
use std::collections::hash_map as hmap;
pub type Id = quiche::ConnectionId<'static>;
use super::app;
pub type Map<T> = hmap::HashMap<Id, Connection<T>>;
pub struct Connection<T: app::App> {
pub quiche: quiche::Connection,
pub session: Option<webtransport::ServerSession>,
pub app: T,
}

View File

@ -0,0 +1,7 @@
mod server;
mod session;
mod connection;
mod app;
pub use app::App;
pub use server::{Config, Server};

View File

@ -0,0 +1,334 @@
use std::io;
use quiche::h3::webtransport;
use super::connection;
use super::app;
const MAX_DATAGRAM_SIZE: usize = 1350;
pub struct Server<T: app::App> {
// IO stuff
socket: mio::net::UdpSocket,
poll: mio::Poll,
events: mio::Events,
// QUIC stuff
quic: quiche::Config,
seed: ring::hmac::Key, // connection ID seed
conns: connection::Map<T>,
}
pub struct Config {
pub addr: String,
pub cert: String,
pub key: String,
}
impl<T: app::App> Server<T> {
pub fn new(config: Config) -> io::Result<Self> {
// Listen on the provided socket address
let addr = config.addr.parse().unwrap();
let mut socket = mio::net::UdpSocket::bind(addr).unwrap();
// Setup the event loop.
let poll = mio::Poll::new().unwrap();
let events = mio::Events::with_capacity(1024);
poll.registry().register(
&mut socket,
mio::Token(0),
mio::Interest::READABLE,
).unwrap();
// Generate random values for connection IDs.
let rng = ring::rand::SystemRandom::new();
let seed = ring::hmac::Key::generate(ring::hmac::HMAC_SHA256, &rng).unwrap();
// Create the configuration for the QUIC conns.
let mut quic = quiche::Config::new(quiche::PROTOCOL_VERSION).unwrap();
quic.load_cert_chain_from_pem_file(&config.cert).unwrap();
quic.load_priv_key_from_pem_file(&config.key).unwrap();
quic.set_application_protos(quiche::h3::APPLICATION_PROTOCOL).unwrap();
quic.set_max_idle_timeout(5000);
quic.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE);
quic.set_max_send_udp_payload_size(MAX_DATAGRAM_SIZE);
quic.set_initial_max_data(10_000_000);
quic.set_initial_max_stream_data_bidi_local(1_000_000);
quic.set_initial_max_stream_data_bidi_remote(1_000_000);
quic.set_initial_max_stream_data_uni(1_000_000);
quic.set_initial_max_streams_bidi(100);
quic.set_initial_max_streams_uni(100);
quic.set_disable_active_migration(true);
quic.enable_early_data();
quic.enable_dgram(true, 65536, 65536);
let conns = Default::default();
Ok(Server {
socket,
poll,
events,
quic,
seed,
conns,
})
}
pub fn run(&mut self) -> anyhow::Result<()> {
loop {
self.wait()?;
self.receive()?;
self.app()?;
self.send()?;
}
}
pub fn wait(&mut self) -> anyhow::Result<()> {
// Find the shorter timeout from all the active connections.
//
// TODO: use event loop that properly supports timers
let timeout = self.conns.values().filter_map(|c| {
let timeout = c.quiche.timeout();
let expires = c.app.timeout();
match (timeout, expires) {
(Some(a), Some(b)) => Some(a.min(b)),
(Some(a), None) => Some(a),
(None, Some(b)) => Some(b),
(None, None) => None,
}
}).min();
self.poll.poll(&mut self.events, timeout).unwrap();
// If the event loop reported no events, it means that the timeout
// has expired, so handle it without attempting to read packets. We
// will then proceed with the send loop.
if self.events.is_empty() {
for conn in self.conns.values_mut() {
conn.quiche.on_timeout();
}
}
Ok(())
}
// Reads packets from the socket, updating any internal connection state.
fn receive(&mut self) -> anyhow::Result<()> {
let mut src= [0; MAX_DATAGRAM_SIZE];
// Try reading any data currently available on the socket.
loop {
let (len, from) = match self.socket.recv_from(&mut src) {
Ok(v) => v,
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => return Ok(()),
Err(e) => return Err(e.into()),
};
let src = &mut src[..len];
let info = quiche::RecvInfo {
to: self.socket.local_addr().unwrap(),
from,
};
// Parse the QUIC packet's header.
let hdr = quiche::Header::from_slice(src, quiche::MAX_CONN_ID_LEN).unwrap();
let conn_id = ring::hmac::sign(&self.seed, &hdr.dcid);
let conn_id = &conn_id.as_ref()[..quiche::MAX_CONN_ID_LEN];
let conn_id = conn_id.to_vec().into();
// Check if it's an existing connection.
if let Some(conn) = self.conns.get_mut(&hdr.dcid) {
// initial or handshake traffic.
conn.quiche.recv(src, info)?;
if conn.session.is_none() && conn.quiche.is_established() {
conn.session = Some(webtransport::ServerSession::with_transport(&mut conn.quiche)?)
}
continue
} else if let Some(conn) = self.conns.get_mut(&conn_id) {
// 1-RTT traffic.
conn.quiche.recv(src, info)?;
// TODO is this needed here?
if conn.session.is_none() && conn.quiche.is_established() {
conn.session = Some(webtransport::ServerSession::with_transport(&mut conn.quiche)?)
}
continue
}
if hdr.ty != quiche::Type::Initial {
anyhow::bail!("unknown connection ID");
}
let mut dst = [0; MAX_DATAGRAM_SIZE];
if !quiche::version_is_supported(hdr.version) {
let len = quiche::negotiate_version(&hdr.scid, &hdr.dcid, &mut dst).unwrap();
let dst= &dst[..len];
self.socket.send_to(dst, from).unwrap();
continue
}
let mut scid = [0; quiche::MAX_CONN_ID_LEN];
scid.copy_from_slice(&conn_id);
let scid = quiche::ConnectionId::from_ref(&scid);
// Token is always present in Initial packets.
let token = hdr.token.as_ref().unwrap();
// Do stateless retry if the client didn't send a token.
if token.is_empty() {
let new_token = mint_token(&hdr, &from);
let len = quiche::retry(
&hdr.scid,
&hdr.dcid,
&scid,
&new_token,
hdr.version,
&mut dst,
)
.unwrap();
let dst= &dst[..len];
self.socket.send_to(dst, from).unwrap();
continue
}
let odcid = validate_token(&from, token);
// The token was not valid, meaning the retry failed, so
// drop the packet.
if odcid.is_none() {
anyhow::bail!("invalid token");
}
if scid.len() != hdr.dcid.len() {
anyhow::bail!("invalid connection ID");
}
// Reuse the source connection ID we sent in the Retry packet,
// instead of changing it again.
let conn_id= hdr.dcid.clone();
let local_addr = self.socket.local_addr().unwrap();
let mut conn = quiche::accept(&conn_id, odcid.as_ref(), local_addr, from, &mut self.quic)?;
// Process potentially coalesced packets.
conn.recv(src, info)?;
let user = connection::Connection{
quiche: conn,
session: None,
app: T::default(),
};
self.conns.insert(user.quiche.source_id().into_owned(), user);
}
}
pub fn app(&mut self) -> anyhow::Result<()> {
for (_, conn) in &mut self.conns {
if let Some(session) = &mut conn.session {
conn.app.poll(&mut conn.quiche, session)?;
}
}
Ok(())
}
// Generate outgoing QUIC packets for all active connections and send
// them on the UDP socket, until quiche reports that there are no more
// packets to be sent.
pub fn send(&mut self) -> anyhow::Result<()> {
let mut pkt = [0; MAX_DATAGRAM_SIZE];
for conn in self.conns.values_mut() {
loop {
let (size , info) = match conn.quiche.send(&mut pkt) {
Ok(v) => v,
Err(quiche::Error::Done) => return Ok(()),
Err(e) => return Err(e.into()),
};
let pkt = &pkt[..size];
match self.socket.send_to(&pkt, info.to) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => break,
Err(err) => return Err(err.into()),
Ok(_) => (),
}
}
}
Ok(())
}
}
/// Generate a stateless retry token.
///
/// The token includes the static string `"quiche"` followed by the IP address
/// of the client and by the original destination connection ID generated by the
/// client.
///
/// Note that this function is only an example and doesn't do any cryptographic
/// authenticate of the token. *It should not be used in production system*.
fn mint_token(hdr: &quiche::Header, src: &std::net::SocketAddr) -> Vec<u8> {
let mut token = Vec::new();
token.extend_from_slice(b"quiche");
let addr = match src.ip() {
std::net::IpAddr::V4(a) => a.octets().to_vec(),
std::net::IpAddr::V6(a) => a.octets().to_vec(),
};
token.extend_from_slice(&addr);
token.extend_from_slice(&hdr.dcid);
token
}
/// Validates a stateless retry token.
///
/// This checks that the ticket includes the `"quiche"` static string, and that
/// the client IP address matches the address stored in the ticket.
///
/// Note that this function is only an example and doesn't do any cryptographic
/// authenticate of the token. *It should not be used in production system*.
fn validate_token<'a>(
src: &std::net::SocketAddr, token: &'a [u8],
) -> Option<quiche::ConnectionId<'a>> {
if token.len() < 6 {
return None;
}
if &token[..6] != b"quiche" {
return None;
}
let token = &token[6..];
let addr = match src.ip() {
std::net::IpAddr::V4(a) => a.octets().to_vec(),
std::net::IpAddr::V6(a) => a.octets().to_vec(),
};
if token.len() < addr.len() || &token[..addr.len()] != addr.as_slice() {
return None;
}
Some(quiche::ConnectionId::from_ref(&token[addr.len()..]))
}

View File

@ -0,0 +1,252 @@
use std::collections::hash_map as hmap;
use quiche::h3::webtransport;
type Session = webtransport::ServerSession;
type Map = hmap::HashMap<quiche::ConnectionId<'static>, Session>;
/*
impl Session {
pub fn with_transport(conn: &mut quiche::Connection) -> anyhow::Result<Self> {
let session = webtransport::ServerSession::with_transport(conn)?;
Ok(Self{
session
})
}
// Process any updates to a session.
pub fn poll(&mut self) -> anyhow::Result<()> {
log::debug!("poll conn");
while self.poll_once()? {}
log::debug!("poll streams");
self.poll_streams()?;
Ok(())
}
// Process any updates to a session.
pub fn poll_once(&mut self) -> anyhow::Result<bool> {
let session = match &mut self.session {
Some(s) => s,
None => return Ok(false),
};
let event = match session.poll(&mut self.conn) {
Err(webtransport::Error::Done) => return Ok(false),
Err(e) => return Err(e.into()),
Ok(e) => e,
};
match event {
webtransport::ServerEvent::ConnectRequest(req) => {
log::debug!("new connect {:?}", req);
// you can handle request with
// req.authority()
// req.path()
// and you can validate this request with req.origin()
session.accept_connect_request(&mut self.conn, None).unwrap();
},
webtransport::ServerEvent::StreamData(stream_id) => {
log::debug!("on stream data {}", stream_id);
let mut buf = vec![0; 10000];
while let Ok(len) =
session.recv_stream_data(&mut self.conn, stream_id, &mut buf)
{
let stream_data = &buf[0..len];
log::debug!("stream data {:?}", stream_data);
/*
// handle stream_data
if (stream_id & 0x2) == 0 {
// bidirectional stream
// you can send data through this stream.
session
.send_stream_data(&mut self.conn, stream_id, stream_data)
.unwrap();
} else {
// you cannot send data through client-initiated-unidirectional-stream.
// so, open new server-initiated-unidirectional-stream, and send data
// through it.
let new_stream_id =
session.open_stream(&mut self.conn, false).unwrap();
session
.send_stream_data(&mut self.conn, new_stream_id, stream_data)
.unwrap();
}
*/
}
}
webtransport::ServerEvent::StreamFinished(stream_id) => {
// A WebTrnasport stream finished, handle it.
log::debug!("stream finished {}", stream_id);
}
webtransport::ServerEvent::Datagram => {
log::debug!("datagram");
}
webtransport::ServerEvent::SessionReset(e) => {
log::debug!("session reset {}", e);
// Peer reset session stream, handle it.
}
webtransport::ServerEvent::SessionFinished => {
log::debug!("session finished");
// Peer finish session stream, handle it.
}
webtransport::ServerEvent::SessionGoAway => {
log::debug!("session go away");
// Peer signalled it is going away, handle it.
}
webtransport::ServerEvent::Other(stream_id, event) => {
log::debug!("session other: {} {:?}", stream_id, event);
// Original h3::Event which is not related to WebTransport.
}
}
Ok(true)
}
/*
fn poll_source(&mut self) -> anyhow::Result<()> {
let media = match &mut self.media {
Some(m) => m,
None => return Ok(()),
};
let fragment = match media.next()? {
Some(f) => f,
None => return Ok(()),
};
// Get or create a new stream for each unique segment ID.
let stream_id = match self.segments.entry(fragment.segment_id) {
map::Entry::Occupied(e) => e.into_mut(),
map::Entry::Vacant(e) => {
let stream_id = self.start_stream(&fragment)?;
e.insert(stream_id)
},
};
// Get or create a buffered object for each unique stream ID.
let buffered = match self.streams.entry(*stream_id) {
map::Entry::Occupied(e) => e.into_mut(),
map::Entry::Vacant(e) => e.insert(Buffered::new()),
};
let session = match &mut self.session {
Some(s) => s,
None => return Ok(()),
};
let data = fragment.data.as_slice();
match self.conn.stream_writable(*stream_id, data.len()) {
Ok(true) if buffered.len() == 0 => {
session.send_stream_data(&mut self.conn, *stream_id, data)?;
},
Ok(_) => buffered.push_back(fragment.data),
Err(quiche::Error::Done) => {}, // stream closed?
Err(e) => anyhow::bail!(e),
};
Ok(())
}
fn start_stream(&mut self, fragment: &source::Fragment) -> anyhow::Result<u64> {
let conn = &mut self.conn;
let session = self.session.as_mut().unwrap();
let stream_id = session.open_stream(conn, false)?;
// TODO: conn.stream_priority(stream_id, urgency, incremental)
let mut message = message::Message::new();
if fragment.segment_id == 0 {
message.init = Some(message::Init{
id: "video".to_string(),
});
} else {
message.segment = Some(message::Segment{
init: "video".to_string(),
timestamp: fragment.timestamp,
});
}
let data= message.serialize()?;
match conn.stream_writable(stream_id, data.len()) {
Ok(true) => {
session.send_stream_data(conn, stream_id, data.as_slice())?;
},
Ok(false) => {
let mut buffered = Buffered::new();
buffered.push_back(data);
self.streams.insert(stream_id, buffered);
},
Err(quiche::Error::Done) => {},
Err(e) => anyhow::bail!(e),
};
Ok(stream_id)
}
*/
fn poll_streams(&mut self) -> anyhow::Result<()> {
// TODO make sure this loops in priority order
for stream_id in self.conn.writable() {
self.poll_stream(stream_id)?;
}
// Remove any entry buffered values.
self.streams.retain(|_, buffered| buffered.len() > 0 );
Ok(())
}
pub fn poll_stream(&mut self, stream_id: u64) -> anyhow::Result<()> {
let buffered = match self.streams.get_mut(&stream_id) {
Some(b) => b,
None => return Ok(()),
};
let conn = &mut self.conn;
let session = match &mut self.session {
Some(s) => s,
None => return Ok(()),
};
while let Some(data) = buffered.pop_front() {
match conn.stream_writable(stream_id, data.len()) {
Ok(true) => {
session.send_stream_data(conn, stream_id, data.as_slice())?;
},
Ok(false) => {
buffered.push_front(data);
return Ok(());
},
Err(quiche::Error::Done) => {},
Err(e) => anyhow::bail!(e),
};
}
Ok(())
}
pub fn timeout(&self) -> Option<time::Duration> {
self.conn.timeout()
}
pub fn on_timeout(&mut self) {
self.conn.on_timeout()
// custom stuff here
}
}
*/