From bb0437a3bbc237e8c0ccad454f5dadd33167cd34 Mon Sep 17 00:00:00 2001 From: Luke Curley Date: Mon, 24 Apr 2023 10:18:55 -0700 Subject: [PATCH] More refactoring ofc. --- media/generate | 20 +- server/.vscode/settings.json | 3 + server/Cargo.lock | 136 ++++++++++++ server/Cargo.toml | 6 +- server/src/error.rs | 44 ---- server/src/lib.rs | 5 +- server/src/main.rs | 47 +++- server/src/media.rs | 108 ++++++++++ server/src/media/mod.rs | 1 + server/src/media/source.rs | 163 ++++++++++++++ server/src/message.rs | 40 ++++ server/src/server.rs | 332 ---------------------------- server/src/session.rs | 104 --------- server/src/transport/app.rs | 8 + server/src/transport/connection.rs | 15 ++ server/src/transport/mod.rs | 7 + server/src/transport/server.rs | 334 +++++++++++++++++++++++++++++ server/src/transport/session.rs | 252 ++++++++++++++++++++++ 18 files changed, 1118 insertions(+), 507 deletions(-) create mode 100644 server/.vscode/settings.json delete mode 100644 server/src/error.rs create mode 100644 server/src/media.rs create mode 100644 server/src/media/mod.rs create mode 100644 server/src/media/source.rs create mode 100644 server/src/message.rs delete mode 100644 server/src/server.rs delete mode 100644 server/src/session.rs create mode 100644 server/src/transport/app.rs create mode 100644 server/src/transport/connection.rs create mode 100644 server/src/transport/mod.rs create mode 100644 server/src/transport/server.rs create mode 100644 server/src/transport/session.rs diff --git a/media/generate b/media/generate index 110f401..b387ac9 100755 --- a/media/generate +++ b/media/generate @@ -1,18 +1,6 @@ #!/bin/bash ffmpeg -i source.mp4 \ - -f dash -ldash 1 \ - -c:v libx264 \ - -preset veryfast -tune zerolatency \ - -c:a aac \ - -b:a 128k -ac 2 -ar 44100 \ - -map v:0 -s:v:0 1280x720 -b:v:0 3M \ - -map v:0 -s:v:1 854x480 -b:v:1 1.1M \ - -map v:0 -s:v:2 640x360 -b:v:2 365k \ - -map 0:a \ - -force_key_frames "expr:gte(t,n_forced*2)" \ - -sc_threshold 0 \ - -streaming 1 \ - -use_timeline 0 \ - -seg_duration 2 -frag_duration 0.01 \ - -frag_type duration \ - playlist.mpd + -c:v copy \ + -an \ + -movflags frag_every_frame+empty_moov \ + fragmented.mp4 diff --git a/server/.vscode/settings.json b/server/.vscode/settings.json new file mode 100644 index 0000000..4d9636b --- /dev/null +++ b/server/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "rust-analyzer.showUnlinkedFileNotification": false +} \ No newline at end of file diff --git a/server/Cargo.lock b/server/Cargo.lock index 035e890..00d51d5 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -60,6 +60,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "anyhow" +version = "1.0.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7de8ce5e0f9f8d88245311066a578d72b7af3e7088f32783804676302df237e4" + [[package]] name = "atty" version = "0.2.14" @@ -89,6 +95,18 @@ version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + +[[package]] +name = "bytes" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" + [[package]] name = "cc" version = "1.0.79" @@ -242,6 +260,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "itoa" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" + [[package]] name = "js-sys" version = "0.3.61" @@ -302,6 +326,63 @@ dependencies = [ "windows-sys 0.45.0", ] +[[package]] +name = "mp4" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "509348cba250e7b852a875100a2ddce7a36ee3abf881a681c756670c1774264d" +dependencies = [ + "byteorder", + "bytes", + "num-rational", + "serde", + "serde_json", + "thiserror", +] + +[[package]] +name = "num-bigint" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +dependencies = [ + "autocfg", + "num-bigint", + "num-integer", + "num-traits", + "serde", +] + +[[package]] +name = "num-traits" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +dependencies = [ + "autocfg", +] + [[package]] name = "octets" version = "0.2.0" @@ -394,11 +475,42 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "ryu" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" + [[package]] name = "serde" version = "1.0.160" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.160" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.15", +] + +[[package]] +name = "serde_json" +version = "1.0.96" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" +dependencies = [ + "itoa", + "ryu", + "serde", +] [[package]] name = "slab" @@ -461,6 +573,26 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "thiserror" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.15", +] + [[package]] name = "unicode-ident" version = "1.0.8" @@ -483,12 +615,16 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" name = "warp" version = "0.1.0" dependencies = [ + "anyhow", "clap", "env_logger", "log", "mio", + "mp4", "quiche", "ring", + "serde", + "serde_json", ] [[package]] diff --git a/server/Cargo.toml b/server/Cargo.toml index c5f4afd..b397086 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -11,4 +11,8 @@ clap = { version = "4.0", features = [ "derive" ] } log = { version = "0.4", features = ["std"] } mio = { version = "0.8", features = ["net", "os-poll"] } env_logger = "0.9.3" -ring = "0.16" \ No newline at end of file +ring = "0.16" +anyhow = "1.0.70" +mp4 = "0.13.0" +serde = "1.0.160" +serde_json = "1.0" \ No newline at end of file diff --git a/server/src/error.rs b/server/src/error.rs deleted file mode 100644 index ef10de2..0000000 --- a/server/src/error.rs +++ /dev/null @@ -1,44 +0,0 @@ -use std::io; -use quiche::h3::webtransport; - -#[derive(Debug)] -pub enum Error { - Io(io::Error), - Quiche(quiche::Error), - WebTransport(webtransport::Error), - Server(Server), -} - -impl From for Error { - fn from(err: io::Error) -> Error { - Error::Io(err) - } -} - -impl From for Error { - fn from(err: quiche::Error) -> Error { - Error::Quiche(err) - } -} - -impl From for Error { - fn from(err: webtransport::Error) -> Error { - Error::WebTransport(err) - } -} - -// Custom server error messages. -#[derive(Debug)] -pub enum Server { - InvalidToken, - InvalidConnectionID, - UnknownConnectionID, -} - -impl From for Error { - fn from(err: Server) -> Error { - Error::Server(err) - } -} - -pub type Result = std::result::Result; \ No newline at end of file diff --git a/server/src/lib.rs b/server/src/lib.rs index 241d22a..915e3b1 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -1,3 +1,2 @@ -pub mod error; -pub mod server; -pub mod session; \ No newline at end of file +pub mod transport; +//mod media; \ No newline at end of file diff --git a/server/src/main.rs b/server/src/main.rs index 233d2cc..9c44d0f 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,6 +1,10 @@ -use warp::server::Server; +use quiche::h3::webtransport; +use warp::transport; + +use std::time; use clap::Parser; +use env_logger; /// Search for a pattern in a file and display the lines that contain it. #[derive(Parser)] @@ -16,19 +20,48 @@ struct Cli { /// Use the private key at this path #[arg(short, long, default_value = "../cert/localhost.key")] key: String, + + /// Use the media file at this path + #[arg(short, long, default_value = "../media/fragmented.mp4")] + media: String, } -fn main() { +#[derive(Default)] +struct Connection { + webtransport: Option, +} + +impl transport::App for Connection { + fn poll(&mut self, conn: &mut quiche::Connection, session: &mut webtransport::ServerSession) -> anyhow::Result<()> { + if !conn.is_established() { + // Wait until the handshake finishes + return Ok(()) + } + + if self.webtransport.is_none() { + self.webtransport = Some(webtransport::ServerSession::with_transport(conn)?) + } + + let webtransport = self.webtransport.as_mut().unwrap(); + + Ok(()) + } + + fn timeout(&self) -> Option { + None + } +} +fn main() -> anyhow::Result<()> { + env_logger::init(); + let args = Cli::parse(); - let server_config = warp::server::Config{ + let server_config = transport::Config{ addr: args.addr, cert: args.cert, key: args.key, }; - let mut server = Server::new(server_config).unwrap(); - loop { - server.poll().unwrap() - } + let mut server = transport::Server::::new(server_config).unwrap(); + server.run() } \ No newline at end of file diff --git a/server/src/media.rs b/server/src/media.rs new file mode 100644 index 0000000..5e86adb --- /dev/null +++ b/server/src/media.rs @@ -0,0 +1,108 @@ +use std::{io,fs}; + +use mp4; +use anyhow; +use bytes; + +use mp4::ReadBox; + +pub struct Source { + pub segments: Vec, +} + +impl Source { + pub fn new(path: &str) -> anyhow::Result { + let f = fs::read(path)?; + let mut bytes = bytes::Bytes::from(f); + + let mut segments = Vec::new(); + let mut current = Segment::new(); + + while bytes.len() > 0 { + // NOTE: Cloning is cheap, since the underlying bytes are reference counted. + let mut reader = io::Cursor::new(bytes.clone()); + + let header = mp4::BoxHeader::read(&mut reader)?; + let size: usize = header.size as usize; + + assert!(size > 0, "empty box"); + + let frag = bytes.split_to(size); + let fragment = Fragment{ bytes: frag }; + + match header.name { + /* + mp4::BoxType::FtypBox => { + } + mp4::BoxType::MoovBox => { + moov = mp4::MoovBox::read_box(&mut reader, size)? + } + mp4::BoxType::EmsgBox => { + let emsg = mp4::EmsgBox::read_box(&mut reader, size)?; + emsgs.push(emsg); + } + mp4::BoxType::MdatBox => { + mp4::skip_box(&mut reader, size)?; + } + */ + mp4::BoxType::MoofBox => { + let moof = mp4::MoofBox::read_box(&mut reader, header.size)?; + if has_keyframe(moof) { + segments.push(current); + current = Segment::new(); + } + } + _ => (), + } + + current.fragments.push(fragment); + } + + segments.push(current); + + Ok(Self { segments }) + } +} + +fn has_keyframe(moof: mp4::MoofBox) -> bool { + for traf in moof.trafs { + // TODO trak default flags if this is None + let default_flags = traf.tfhd.default_sample_flags.unwrap_or_default(); + let trun = traf.trun.expect("missing trun box"); + + for i in 0..trun.sample_count { + let mut flags = match trun.sample_flags.get(i as usize) { + Some(f) => *f, + None => default_flags, + }; + + if i == 0 && trun.first_sample_flags.is_some() { + flags = trun.first_sample_flags.unwrap(); + } + + // https://chromium.googlesource.com/chromium/src/media/+/master/formats/mp4/track_run_iterator.cc#177 + let keyframe = (flags >> 24) & 0x3 == 0x2; // kSampleDependsOnNoOther + let non_sync = (flags >> 16) & 0x1 == 0x1; // kSampleIsNonSyncSample + + if keyframe && non_sync { + return true + } + } + } + + false +} + +pub struct Segment { + pub fragments: Vec, +} + +impl Segment { + fn new() -> Self { + Segment { fragments: Vec::new() } + } +} + +pub struct Fragment { + pub bytes: bytes::Bytes, +} diff --git a/server/src/media/mod.rs b/server/src/media/mod.rs new file mode 100644 index 0000000..b5cb700 --- /dev/null +++ b/server/src/media/mod.rs @@ -0,0 +1 @@ +pub mod source; \ No newline at end of file diff --git a/server/src/media/source.rs b/server/src/media/source.rs new file mode 100644 index 0000000..7d83b5b --- /dev/null +++ b/server/src/media/source.rs @@ -0,0 +1,163 @@ +use std::{io,fs,time}; +use io::Read; + +use mp4; +use anyhow; + +use mp4::ReadBox; + +pub struct Source { + reader: io::BufReader, + start: time::Instant, + + pending: Option, + sequence: u64, +} + +pub struct Fragment { + pub data: Vec, + pub segment_id: u64, + pub timestamp: u64, +} + +impl Source { + pub fn new(path: &str) -> io::Result { + let f = fs::File::open(path)?; + let reader = io::BufReader::new(f); + let start = time::Instant::now(); + + Ok(Self{ + reader, + start, + pending: None, + sequence: 0, + }) + } + + pub fn next(&mut self) -> anyhow::Result> { + let pending = match self.pending.take() { + Some(f) => f, + None => self.next_inner()?, + }; + + if pending.timestamp > 0 && pending.timestamp < self.start.elapsed().as_millis() as u64 { + self.pending = Some(pending); + return Ok(None) + } + + Ok(Some(pending)) + } + + fn next_inner(&mut self) -> anyhow::Result { + // Read the next full atom. + let atom = read_box(&mut self.reader)?; + let mut timestamp = 0; + + // Before we return it, let's do some simple parsing. + let mut reader = io::Cursor::new(&atom); + let header = mp4::BoxHeader::read(&mut reader)?; + + + match header.name { + mp4::BoxType::MoofBox => { + let moof = mp4::MoofBox::read_box(&mut reader, header.size)?; + + if has_keyframe(&moof) { + self.sequence += 1 + } + + timestamp = first_timestamp(&moof); + } + _ => (), + } + + Ok(Fragment { + data: atom, + segment_id: self.sequence, + timestamp: timestamp, + }) + } +} + +// Read a full MP4 atom into a vector. +fn read_box(reader: &mut R) -> anyhow::Result> { + // Read the 8 bytes for the size + type + let mut buf = [0u8; 8]; + reader.read_exact(&mut buf)?; + + // Convert the first 4 bytes into the size. + let size = u32::from_be_bytes(buf[0..4].try_into()?) as u64; + let mut out = buf.to_vec(); + + let mut limit = match size { + // Runs until the end of the file. + 0 => reader.take(u64::MAX), + + // The next 8 bytes are the extended size to be used instead. + 1 => { + reader.read_exact(&mut buf)?; + let size_large = u64::from_be_bytes(buf); + anyhow::ensure!(size_large >= 16, "impossible extended box size: {}", size_large); + + reader.take(size_large - 16) + }, + + 2..=7 => { + anyhow::bail!("impossible box size: {}", size) + } + + // Otherwise read based on the size. + size => reader.take(size - 8) + }; + + // Append to the vector and return it. + limit.read_to_end(&mut out)?; + + Ok(out) +} + +fn has_keyframe(moof: &mp4::MoofBox) -> bool { + for traf in &moof.trafs { + // TODO trak default flags if this is None + let default_flags = traf.tfhd.default_sample_flags.unwrap_or_default(); + let trun = match &traf.trun { + Some(t) => t, + None => return false, + }; + + for i in 0..trun.sample_count { + let mut flags = match trun.sample_flags.get(i as usize) { + Some(f) => *f, + None => default_flags, + }; + + if i == 0 && trun.first_sample_flags.is_some() { + flags = trun.first_sample_flags.unwrap(); + } + + // https://chromium.googlesource.com/chromium/src/media/+/master/formats/mp4/track_run_iterator.cc#177 + let keyframe = (flags >> 24) & 0x3 == 0x2; // kSampleDependsOnNoOther + let non_sync = (flags >> 16) & 0x1 == 0x1; // kSampleIsNonSyncSample + + if keyframe && non_sync { + return true + } + } + } + + false +} + +fn first_timestamp(moof: &mp4::MoofBox) -> u64 { + let traf = match moof.trafs.first() { + Some(t) => t, + None => return 0, + }; + + let tfdt = match &traf.tfdt { + Some(t) => t, + None => return 0, + }; + + tfdt.base_media_decode_time +} diff --git a/server/src/message.rs b/server/src/message.rs new file mode 100644 index 0000000..1421e91 --- /dev/null +++ b/server/src/message.rs @@ -0,0 +1,40 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +pub struct Message { + pub init: Option, + pub segment: Option, +} + +#[derive(Serialize, Deserialize)] +pub struct Init { + pub id: String, +} + +#[derive(Serialize, Deserialize)] +pub struct Segment { + pub init: String, + pub timestamp: u64, +} + +impl Message { + pub fn new() -> Self { + Message { + init: None, + segment: None, + } + } + + pub fn serialize(&self) -> anyhow::Result> { + let str = serde_json::to_string(self)?; + let bytes = str.as_bytes(); + let size = bytes.len() + 8; + + let mut out = Vec::with_capacity(size); + out.extend_from_slice(b"warp"); + out.extend_from_slice(&size.to_be_bytes()); + out.extend_from_slice(bytes); + + Ok(out) + } +} \ No newline at end of file diff --git a/server/src/server.rs b/server/src/server.rs deleted file mode 100644 index b3f2359..0000000 --- a/server/src/server.rs +++ /dev/null @@ -1,332 +0,0 @@ -use crate::session; -use crate::error; - -use session::Session; -use error::{Error, Result}; - -use std::{io, net}; -use log; - -use quiche::h3::webtransport; - -const MAX_DATAGRAM_SIZE: usize = 1350; - -pub struct Server { - // IO stuff - socket: mio::net::UdpSocket, - poll: mio::Poll, - events: mio::Events, - - // QUIC stuff - quic: quiche::Config, - sessions: session::Map, - seed: ring::hmac::Key, // connection ID seed -} - -pub struct Config { - pub addr: String, - pub cert: String, - pub key: String, -} - -impl Server { - pub fn new(config: Config) -> io::Result { - // Listen on the provided socket address - let addr = config.addr.parse().unwrap(); - let mut socket = mio::net::UdpSocket::bind(addr).unwrap(); - - // Setup the event loop. - let poll = mio::Poll::new().unwrap(); - let events = mio::Events::with_capacity(1024); - let sessions = session::Map::new(); - - poll.registry().register( - &mut socket, - mio::Token(0), - mio::Interest::READABLE, - ).unwrap(); - - // Generate random values for connection IDs. - let rng = ring::rand::SystemRandom::new(); - let seed = ring::hmac::Key::generate(ring::hmac::HMAC_SHA256, &rng).unwrap(); - - // Create the configuration for the QUIC connections. - let mut quic = quiche::Config::new(quiche::PROTOCOL_VERSION).unwrap(); - quic.load_cert_chain_from_pem_file(&config.cert).unwrap(); - quic.load_priv_key_from_pem_file(&config.key).unwrap(); - quic.set_application_protos(quiche::h3::APPLICATION_PROTOCOL).unwrap(); - quic.set_max_idle_timeout(5000); - quic.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE); - quic.set_max_send_udp_payload_size(MAX_DATAGRAM_SIZE); - quic.set_initial_max_data(10_000_000); - quic.set_initial_max_stream_data_bidi_local(1_000_000); - quic.set_initial_max_stream_data_bidi_remote(1_000_000); - quic.set_initial_max_stream_data_uni(1_000_000); - quic.set_initial_max_streams_bidi(100); - quic.set_initial_max_streams_uni(100); - quic.set_disable_active_migration(true); - quic.enable_early_data(); - quic.enable_dgram(true, 65536, 65536); - - Ok(Server { - socket, - poll, - events, - - quic, - sessions, - seed - }) - } - - pub fn poll(&mut self) -> io::Result<()> { - self.receive().unwrap(); - self.send().unwrap(); - self.cleanup().unwrap(); - - Ok(()) - } - - fn receive(&mut self) -> io::Result<()> { - // Find the shorter timeout from all the active connections. - // - // TODO: use event loop that properly supports timers - let timeout = self.sessions.values().filter_map(|c| c.conn.timeout()).min(); - - self.poll.poll(&mut self.events, timeout).unwrap(); - - // If the event loop reported no events, it means that the timeout - // has expired, so handle it without attempting to read packets. We - // will then proceed with the send loop. - if self.events.is_empty() { - self.sessions.values_mut().for_each(|session| { - session.conn.on_timeout() - }); - - return Ok(()) - } - - // Read incoming UDP packets from the socket and feed them to quiche, - // until there are no more packets to read. - loop { - match self.receive_once() { - Err(Error::Io(e)) if e.kind() == io::ErrorKind::WouldBlock => return Ok(()), - Err(e) => log::error!("{:?}", e), - Ok(_) => (), - } - } - } - - fn receive_once(&mut self) -> Result<()> { - let mut src= [0; MAX_DATAGRAM_SIZE]; - - let (len, from) = self.socket.recv_from(&mut src).unwrap(); - let src = &mut src[..len]; - - let info = quiche::RecvInfo { - to: self.socket.local_addr().unwrap(), - from, - }; - - // Lookup a connection based on the packet's connection ID. If there - // is no connection matching, create a new one. - let pair = match self.accept(src, from).unwrap() { - Some(v) => v, - None => return Ok(()), - }; - - let conn = &mut pair.conn; - - // Process potentially coalesced packets. - conn.recv(src, info).unwrap(); - - // Create a new HTTP/3 connection as soon as the QUIC connection - // is established. - if (conn.is_in_early_data() || conn.is_established()) && pair.session.is_none() { - let session = webtransport::ServerSession::with_transport(conn).unwrap(); - pair.session = Some(session); - } - - // The `poll` can pull out the events that occurred according to the data passed here. - for (_, session) in self.sessions.iter_mut() { - session.poll().unwrap(); - } - - Ok(()) - } - - fn accept(&mut self, src: &mut [u8], from: net::SocketAddr) -> error::Result> { - // Parse the QUIC packet's header. - let hdr = quiche::Header::from_slice(src, quiche::MAX_CONN_ID_LEN).unwrap(); - - let conn_id = ring::hmac::sign(&self.seed, &hdr.dcid); - let conn_id = &conn_id.as_ref()[..quiche::MAX_CONN_ID_LEN]; - let conn_id = conn_id.to_vec().into(); - - if self.sessions.contains_key(&hdr.dcid) { - let pair = self.sessions.get_mut(&hdr.dcid).unwrap(); - return Ok(Some(pair)) - } else if self.sessions.contains_key(&conn_id) { - let pair = self.sessions.get_mut(&conn_id).unwrap(); - return Ok(Some(pair)); - } - - if hdr.ty != quiche::Type::Initial { - return Err(error::Server::UnknownConnectionID.into()) - } - - let mut dst = [0; MAX_DATAGRAM_SIZE]; - - if !quiche::version_is_supported(hdr.version) { - let len = quiche::negotiate_version(&hdr.scid, &hdr.dcid, &mut dst).unwrap(); - let dst= &dst[..len]; - - self.socket.send_to(dst, from).unwrap(); - return Ok(None) - } - - let mut scid = [0; quiche::MAX_CONN_ID_LEN]; - scid.copy_from_slice(&conn_id); - - let scid = quiche::ConnectionId::from_ref(&scid); - - // Token is always present in Initial packets. - let token = hdr.token.as_ref().unwrap(); - - // Do stateless retry if the client didn't send a token. - if token.is_empty() { - let new_token = mint_token(&hdr, &from); - - let len = quiche::retry( - &hdr.scid, - &hdr.dcid, - &scid, - &new_token, - hdr.version, - &mut dst, - ) - .unwrap(); - - let dst= &dst[..len]; - - self.socket.send_to(dst, from).unwrap(); - return Ok(None) - } - - let odcid = validate_token(&from, token); - - // The token was not valid, meaning the retry failed, so - // drop the packet. - if odcid.is_none() { - return Err(error::Server::InvalidToken.into()) - } - - if scid.len() != hdr.dcid.len() { - return Err(error::Server::InvalidConnectionID.into()) - } - - // Reuse the source connection ID we sent in the Retry packet, - // instead of changing it again. - let conn_id= hdr.dcid.clone(); - let local_addr = self.socket.local_addr().unwrap(); - - let conn = - quiche::accept(&conn_id, odcid.as_ref(), local_addr, from, &mut self.quic) - .unwrap(); - - self.sessions.insert( - conn_id.clone(), - Session { - conn, - session: None, - }, - ); - - let pair = self.sessions.get_mut(&conn_id).unwrap(); - Ok(Some(pair)) - } - - fn send(&mut self) -> io::Result<()> { - let mut pkt = [0; MAX_DATAGRAM_SIZE]; - - // Generate outgoing QUIC packets for all active connections and send - // them on the UDP socket, until quiche reports that there are no more - // packets to be sent. - for session in self.sessions.values_mut() { - loop { - let (size , info) = session.conn.send(&mut pkt).unwrap(); - let pkt = &pkt[..size]; - - match self.socket.send_to(&pkt, info.to) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => break, - Err(err) => return Err(err), - Ok(_) => (), - } - } - } - - Ok(()) - } - - fn cleanup(&mut self) -> io::Result<()> { - // Garbage collect closed connections. - self.sessions.retain(|_, session| !session.conn.is_closed() ); - Ok(()) - } -} - -/// Generate a stateless retry token. -/// -/// The token includes the static string `"quiche"` followed by the IP address -/// of the client and by the original destination connection ID generated by the -/// client. -/// -/// Note that this function is only an example and doesn't do any cryptographic -/// authenticate of the token. *It should not be used in production system*. -fn mint_token(hdr: &quiche::Header, src: &std::net::SocketAddr) -> Vec { - let mut token = Vec::new(); - - token.extend_from_slice(b"quiche"); - - let addr = match src.ip() { - std::net::IpAddr::V4(a) => a.octets().to_vec(), - std::net::IpAddr::V6(a) => a.octets().to_vec(), - }; - - token.extend_from_slice(&addr); - token.extend_from_slice(&hdr.dcid); - - token -} - -/// Validates a stateless retry token. -/// -/// This checks that the ticket includes the `"quiche"` static string, and that -/// the client IP address matches the address stored in the ticket. -/// -/// Note that this function is only an example and doesn't do any cryptographic -/// authenticate of the token. *It should not be used in production system*. -fn validate_token<'a>( - src: &std::net::SocketAddr, token: &'a [u8], -) -> Option> { - if token.len() < 6 { - return None; - } - - if &token[..6] != b"quiche" { - return None; - } - - let token = &token[6..]; - - let addr = match src.ip() { - std::net::IpAddr::V4(a) => a.octets().to_vec(), - std::net::IpAddr::V6(a) => a.octets().to_vec(), - }; - - if token.len() < addr.len() || &token[..addr.len()] != addr.as_slice() { - return None; - } - - Some(quiche::ConnectionId::from_ref(&token[addr.len()..])) -} \ No newline at end of file diff --git a/server/src/session.rs b/server/src/session.rs deleted file mode 100644 index e06bcfa..0000000 --- a/server/src/session.rs +++ /dev/null @@ -1,104 +0,0 @@ -use crate::error; -use error::Result; - -use std::collections::HashMap; -use quiche::h3::webtransport; - -pub struct Session { - pub conn: quiche::Connection, - pub session: Option, -} - -pub type Map = HashMap, Session>; - -impl Session { - // Process any updates to a session. - pub fn poll(&mut self) -> Result<()> { - let session = match &mut self.session { - Some(s) => s, - None => return Ok(()), - }; - - loop { - let event = match session.poll(&mut self.conn) { - Err(webtransport::Error::Done) => return Ok(()), - Err(e) => return Err(e.into()), - Ok(e) => e, - }; - - match event { - webtransport::ServerEvent::ConnectRequest(_req) => { - // you can handle request with - // req.authority() - // req.path() - // and you can validate this request with req.origin() - session.accept_connect_request(&mut self.conn, None).unwrap(); - }, - webtransport::ServerEvent::StreamData(stream_id) => { - let mut buf = vec![0; 10000]; - while let Ok(len) = - session.recv_stream_data(&mut self.conn, stream_id, &mut buf) - { - let stream_data = &buf[0..len]; - - // handle stream_data - if (stream_id & 0x2) == 0 { - // bidirectional stream - // you can send data through this stream. - session - .send_stream_data(&mut self.conn, stream_id, stream_data) - .unwrap(); - } else { - // you cannot send data through client-initiated-unidirectional-stream. - // so, open new server-initiated-unidirectional-stream, and send data - // through it. - let new_stream_id = - session.open_stream(&mut self.conn, false).unwrap(); - session - .send_stream_data(&mut self.conn, new_stream_id, stream_data) - .unwrap(); - } - } - } - - webtransport::ServerEvent::StreamFinished(_stream_id) => { - // A WebTrnasport stream finished, handle it. - } - - webtransport::ServerEvent::Datagram => { - let mut buf = vec![0; 1500]; - while let Ok((in_session, offset, total)) = - session.recv_dgram(&mut self.conn, &mut buf) - { - if in_session { - let dgram = &buf[offset..total]; - dbg!(std::string::String::from_utf8_lossy(dgram)); - // handle this dgram - - // for instance, you can write echo-server like following - session.send_dgram(&mut self.conn, dgram).unwrap(); - } else { - // this dgram is not related to current WebTransport session. ignore. - } - } - } - - webtransport::ServerEvent::SessionReset(_e) => { - // Peer reset session stream, handle it. - } - - webtransport::ServerEvent::SessionFinished => { - // Peer finish session stream, handle it. - } - - webtransport::ServerEvent::SessionGoAway => { - // Peer signalled it is going away, handle it. - } - - webtransport::ServerEvent::Other(_stream_id, _event) => { - // Original h3::Event which is not related to WebTransport. - } - } - } - } -} \ No newline at end of file diff --git a/server/src/transport/app.rs b/server/src/transport/app.rs new file mode 100644 index 0000000..b55ce53 --- /dev/null +++ b/server/src/transport/app.rs @@ -0,0 +1,8 @@ +use std::time; + +use quiche::h3::webtransport; + +pub trait App: Default { + fn poll(&mut self, conn: &mut quiche::Connection, session: &mut webtransport::ServerSession) -> anyhow::Result<()>; + fn timeout(&self) -> Option; +} diff --git a/server/src/transport/connection.rs b/server/src/transport/connection.rs new file mode 100644 index 0000000..2fb12d9 --- /dev/null +++ b/server/src/transport/connection.rs @@ -0,0 +1,15 @@ +use quiche; +use quiche::h3::webtransport; + +use std::collections::hash_map as hmap; + +pub type Id = quiche::ConnectionId<'static>; + +use super::app; + +pub type Map = hmap::HashMap>; +pub struct Connection { + pub quiche: quiche::Connection, + pub session: Option, + pub app: T, +} \ No newline at end of file diff --git a/server/src/transport/mod.rs b/server/src/transport/mod.rs new file mode 100644 index 0000000..c003fa3 --- /dev/null +++ b/server/src/transport/mod.rs @@ -0,0 +1,7 @@ +mod server; +mod session; +mod connection; +mod app; + +pub use app::App; +pub use server::{Config, Server}; \ No newline at end of file diff --git a/server/src/transport/server.rs b/server/src/transport/server.rs new file mode 100644 index 0000000..bbc2a94 --- /dev/null +++ b/server/src/transport/server.rs @@ -0,0 +1,334 @@ +use std::io; + +use quiche::h3::webtransport; + +use super::connection; +use super::app; + +const MAX_DATAGRAM_SIZE: usize = 1350; + +pub struct Server { + // IO stuff + socket: mio::net::UdpSocket, + poll: mio::Poll, + events: mio::Events, + + // QUIC stuff + quic: quiche::Config, + seed: ring::hmac::Key, // connection ID seed + + conns: connection::Map, +} + +pub struct Config { + pub addr: String, + pub cert: String, + pub key: String, +} + +impl Server { + pub fn new(config: Config) -> io::Result { + // Listen on the provided socket address + let addr = config.addr.parse().unwrap(); + let mut socket = mio::net::UdpSocket::bind(addr).unwrap(); + + // Setup the event loop. + let poll = mio::Poll::new().unwrap(); + let events = mio::Events::with_capacity(1024); + + poll.registry().register( + &mut socket, + mio::Token(0), + mio::Interest::READABLE, + ).unwrap(); + + // Generate random values for connection IDs. + let rng = ring::rand::SystemRandom::new(); + let seed = ring::hmac::Key::generate(ring::hmac::HMAC_SHA256, &rng).unwrap(); + + // Create the configuration for the QUIC conns. + let mut quic = quiche::Config::new(quiche::PROTOCOL_VERSION).unwrap(); + quic.load_cert_chain_from_pem_file(&config.cert).unwrap(); + quic.load_priv_key_from_pem_file(&config.key).unwrap(); + quic.set_application_protos(quiche::h3::APPLICATION_PROTOCOL).unwrap(); + quic.set_max_idle_timeout(5000); + quic.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE); + quic.set_max_send_udp_payload_size(MAX_DATAGRAM_SIZE); + quic.set_initial_max_data(10_000_000); + quic.set_initial_max_stream_data_bidi_local(1_000_000); + quic.set_initial_max_stream_data_bidi_remote(1_000_000); + quic.set_initial_max_stream_data_uni(1_000_000); + quic.set_initial_max_streams_bidi(100); + quic.set_initial_max_streams_uni(100); + quic.set_disable_active_migration(true); + quic.enable_early_data(); + quic.enable_dgram(true, 65536, 65536); + + let conns = Default::default(); + + Ok(Server { + socket, + poll, + events, + + quic, + seed, + + conns, + }) + } + + pub fn run(&mut self) -> anyhow::Result<()> { + loop { + self.wait()?; + self.receive()?; + self.app()?; + self.send()?; + } + } + + pub fn wait(&mut self) -> anyhow::Result<()> { + // Find the shorter timeout from all the active connections. + // + // TODO: use event loop that properly supports timers + let timeout = self.conns.values().filter_map(|c| { + let timeout = c.quiche.timeout(); + let expires = c.app.timeout(); + + match (timeout, expires) { + (Some(a), Some(b)) => Some(a.min(b)), + (Some(a), None) => Some(a), + (None, Some(b)) => Some(b), + (None, None) => None, + } + }).min(); + + self.poll.poll(&mut self.events, timeout).unwrap(); + + // If the event loop reported no events, it means that the timeout + // has expired, so handle it without attempting to read packets. We + // will then proceed with the send loop. + if self.events.is_empty() { + for conn in self.conns.values_mut() { + conn.quiche.on_timeout(); + } + } + + Ok(()) + } + + // Reads packets from the socket, updating any internal connection state. + fn receive(&mut self) -> anyhow::Result<()> { + let mut src= [0; MAX_DATAGRAM_SIZE]; + + // Try reading any data currently available on the socket. + loop { + let (len, from) = match self.socket.recv_from(&mut src) { + Ok(v) => v, + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => return Ok(()), + Err(e) => return Err(e.into()), + }; + + let src = &mut src[..len]; + + let info = quiche::RecvInfo { + to: self.socket.local_addr().unwrap(), + from, + }; + + // Parse the QUIC packet's header. + let hdr = quiche::Header::from_slice(src, quiche::MAX_CONN_ID_LEN).unwrap(); + + let conn_id = ring::hmac::sign(&self.seed, &hdr.dcid); + let conn_id = &conn_id.as_ref()[..quiche::MAX_CONN_ID_LEN]; + let conn_id = conn_id.to_vec().into(); + + // Check if it's an existing connection. + if let Some(conn) = self.conns.get_mut(&hdr.dcid) { + // initial or handshake traffic. + conn.quiche.recv(src, info)?; + + if conn.session.is_none() && conn.quiche.is_established() { + conn.session = Some(webtransport::ServerSession::with_transport(&mut conn.quiche)?) + } + + continue + } else if let Some(conn) = self.conns.get_mut(&conn_id) { + // 1-RTT traffic. + conn.quiche.recv(src, info)?; + + // TODO is this needed here? + if conn.session.is_none() && conn.quiche.is_established() { + conn.session = Some(webtransport::ServerSession::with_transport(&mut conn.quiche)?) + } + + continue + } + + if hdr.ty != quiche::Type::Initial { + anyhow::bail!("unknown connection ID"); + } + + let mut dst = [0; MAX_DATAGRAM_SIZE]; + + if !quiche::version_is_supported(hdr.version) { + let len = quiche::negotiate_version(&hdr.scid, &hdr.dcid, &mut dst).unwrap(); + let dst= &dst[..len]; + + self.socket.send_to(dst, from).unwrap(); + continue + } + + let mut scid = [0; quiche::MAX_CONN_ID_LEN]; + scid.copy_from_slice(&conn_id); + + let scid = quiche::ConnectionId::from_ref(&scid); + + // Token is always present in Initial packets. + let token = hdr.token.as_ref().unwrap(); + + // Do stateless retry if the client didn't send a token. + if token.is_empty() { + let new_token = mint_token(&hdr, &from); + + let len = quiche::retry( + &hdr.scid, + &hdr.dcid, + &scid, + &new_token, + hdr.version, + &mut dst, + ) + .unwrap(); + + let dst= &dst[..len]; + + self.socket.send_to(dst, from).unwrap(); + continue + } + + let odcid = validate_token(&from, token); + + // The token was not valid, meaning the retry failed, so + // drop the packet. + if odcid.is_none() { + anyhow::bail!("invalid token"); + } + + if scid.len() != hdr.dcid.len() { + anyhow::bail!("invalid connection ID"); + } + + // Reuse the source connection ID we sent in the Retry packet, + // instead of changing it again. + let conn_id= hdr.dcid.clone(); + let local_addr = self.socket.local_addr().unwrap(); + + let mut conn = quiche::accept(&conn_id, odcid.as_ref(), local_addr, from, &mut self.quic)?; + + // Process potentially coalesced packets. + conn.recv(src, info)?; + + let user = connection::Connection{ + quiche: conn, + session: None, + app: T::default(), + }; + + self.conns.insert(user.quiche.source_id().into_owned(), user); + } + } + + pub fn app(&mut self) -> anyhow::Result<()> { + for (_, conn) in &mut self.conns { + if let Some(session) = &mut conn.session { + conn.app.poll(&mut conn.quiche, session)?; + } + } + + Ok(()) + } + + // Generate outgoing QUIC packets for all active connections and send + // them on the UDP socket, until quiche reports that there are no more + // packets to be sent. + pub fn send(&mut self) -> anyhow::Result<()> { + let mut pkt = [0; MAX_DATAGRAM_SIZE]; + + for conn in self.conns.values_mut() { + loop { + let (size , info) = match conn.quiche.send(&mut pkt) { + Ok(v) => v, + Err(quiche::Error::Done) => return Ok(()), + Err(e) => return Err(e.into()), + }; + + let pkt = &pkt[..size]; + + match self.socket.send_to(&pkt, info.to) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => break, + Err(err) => return Err(err.into()), + Ok(_) => (), + } + } + } + + Ok(()) + } +} + +/// Generate a stateless retry token. +/// +/// The token includes the static string `"quiche"` followed by the IP address +/// of the client and by the original destination connection ID generated by the +/// client. +/// +/// Note that this function is only an example and doesn't do any cryptographic +/// authenticate of the token. *It should not be used in production system*. +fn mint_token(hdr: &quiche::Header, src: &std::net::SocketAddr) -> Vec { + let mut token = Vec::new(); + + token.extend_from_slice(b"quiche"); + + let addr = match src.ip() { + std::net::IpAddr::V4(a) => a.octets().to_vec(), + std::net::IpAddr::V6(a) => a.octets().to_vec(), + }; + + token.extend_from_slice(&addr); + token.extend_from_slice(&hdr.dcid); + + token +} + +/// Validates a stateless retry token. +/// +/// This checks that the ticket includes the `"quiche"` static string, and that +/// the client IP address matches the address stored in the ticket. +/// +/// Note that this function is only an example and doesn't do any cryptographic +/// authenticate of the token. *It should not be used in production system*. +fn validate_token<'a>( + src: &std::net::SocketAddr, token: &'a [u8], +) -> Option> { + if token.len() < 6 { + return None; + } + + if &token[..6] != b"quiche" { + return None; + } + + let token = &token[6..]; + + let addr = match src.ip() { + std::net::IpAddr::V4(a) => a.octets().to_vec(), + std::net::IpAddr::V6(a) => a.octets().to_vec(), + }; + + if token.len() < addr.len() || &token[..addr.len()] != addr.as_slice() { + return None; + } + + Some(quiche::ConnectionId::from_ref(&token[addr.len()..])) +} \ No newline at end of file diff --git a/server/src/transport/session.rs b/server/src/transport/session.rs new file mode 100644 index 0000000..1334704 --- /dev/null +++ b/server/src/transport/session.rs @@ -0,0 +1,252 @@ +use std::collections::hash_map as hmap; +use quiche::h3::webtransport; + +type Session = webtransport::ServerSession; +type Map = hmap::HashMap, Session>; + +/* +impl Session { + pub fn with_transport(conn: &mut quiche::Connection) -> anyhow::Result { + let session = webtransport::ServerSession::with_transport(conn)?; + + Ok(Self{ + session + }) + } + + // Process any updates to a session. + pub fn poll(&mut self) -> anyhow::Result<()> { + log::debug!("poll conn"); + while self.poll_once()? {} + + log::debug!("poll streams"); + self.poll_streams()?; + + Ok(()) + } + + // Process any updates to a session. + pub fn poll_once(&mut self) -> anyhow::Result { + let session = match &mut self.session { + Some(s) => s, + None => return Ok(false), + }; + + let event = match session.poll(&mut self.conn) { + Err(webtransport::Error::Done) => return Ok(false), + Err(e) => return Err(e.into()), + Ok(e) => e, + }; + + match event { + webtransport::ServerEvent::ConnectRequest(req) => { + log::debug!("new connect {:?}", req); + // you can handle request with + // req.authority() + // req.path() + // and you can validate this request with req.origin() + session.accept_connect_request(&mut self.conn, None).unwrap(); + }, + webtransport::ServerEvent::StreamData(stream_id) => { + log::debug!("on stream data {}", stream_id); + + let mut buf = vec![0; 10000]; + while let Ok(len) = + session.recv_stream_data(&mut self.conn, stream_id, &mut buf) + { + let stream_data = &buf[0..len]; + log::debug!("stream data {:?}", stream_data); + +/* + // handle stream_data + if (stream_id & 0x2) == 0 { + // bidirectional stream + // you can send data through this stream. + session + .send_stream_data(&mut self.conn, stream_id, stream_data) + .unwrap(); + } else { + // you cannot send data through client-initiated-unidirectional-stream. + // so, open new server-initiated-unidirectional-stream, and send data + // through it. + let new_stream_id = + session.open_stream(&mut self.conn, false).unwrap(); + session + .send_stream_data(&mut self.conn, new_stream_id, stream_data) + .unwrap(); + } + */ + } + } + + webtransport::ServerEvent::StreamFinished(stream_id) => { + // A WebTrnasport stream finished, handle it. + log::debug!("stream finished {}", stream_id); + } + + webtransport::ServerEvent::Datagram => { + log::debug!("datagram"); + } + + webtransport::ServerEvent::SessionReset(e) => { + log::debug!("session reset {}", e); + // Peer reset session stream, handle it. + } + + webtransport::ServerEvent::SessionFinished => { + log::debug!("session finished"); + // Peer finish session stream, handle it. + } + + webtransport::ServerEvent::SessionGoAway => { + log::debug!("session go away"); + // Peer signalled it is going away, handle it. + } + + webtransport::ServerEvent::Other(stream_id, event) => { + log::debug!("session other: {} {:?}", stream_id, event); + // Original h3::Event which is not related to WebTransport. + } + } + + Ok(true) + } + +/* + fn poll_source(&mut self) -> anyhow::Result<()> { + let media = match &mut self.media { + Some(m) => m, + None => return Ok(()), + }; + + let fragment = match media.next()? { + Some(f) => f, + None => return Ok(()), + }; + + // Get or create a new stream for each unique segment ID. + let stream_id = match self.segments.entry(fragment.segment_id) { + map::Entry::Occupied(e) => e.into_mut(), + map::Entry::Vacant(e) => { + let stream_id = self.start_stream(&fragment)?; + e.insert(stream_id) + }, + }; + + // Get or create a buffered object for each unique stream ID. + let buffered = match self.streams.entry(*stream_id) { + map::Entry::Occupied(e) => e.into_mut(), + map::Entry::Vacant(e) => e.insert(Buffered::new()), + }; + + let session = match &mut self.session { + Some(s) => s, + None => return Ok(()), + }; + + let data = fragment.data.as_slice(); + + match self.conn.stream_writable(*stream_id, data.len()) { + Ok(true) if buffered.len() == 0 => { + session.send_stream_data(&mut self.conn, *stream_id, data)?; + }, + Ok(_) => buffered.push_back(fragment.data), + Err(quiche::Error::Done) => {}, // stream closed? + Err(e) => anyhow::bail!(e), + }; + + Ok(()) + } + + fn start_stream(&mut self, fragment: &source::Fragment) -> anyhow::Result { + let conn = &mut self.conn; + let session = self.session.as_mut().unwrap(); + + let stream_id = session.open_stream(conn, false)?; + + // TODO: conn.stream_priority(stream_id, urgency, incremental) + + let mut message = message::Message::new(); + if fragment.segment_id == 0 { + message.init = Some(message::Init{ + id: "video".to_string(), + }); + } else { + message.segment = Some(message::Segment{ + init: "video".to_string(), + timestamp: fragment.timestamp, + }); + } + + let data= message.serialize()?; + match conn.stream_writable(stream_id, data.len()) { + Ok(true) => { + session.send_stream_data(conn, stream_id, data.as_slice())?; + }, + Ok(false) => { + let mut buffered = Buffered::new(); + buffered.push_back(data); + + self.streams.insert(stream_id, buffered); + }, + Err(quiche::Error::Done) => {}, + Err(e) => anyhow::bail!(e), + }; + + Ok(stream_id) + } +*/ + + fn poll_streams(&mut self) -> anyhow::Result<()> { + // TODO make sure this loops in priority order + for stream_id in self.conn.writable() { + self.poll_stream(stream_id)?; + } + + // Remove any entry buffered values. + self.streams.retain(|_, buffered| buffered.len() > 0 ); + + Ok(()) + } + + pub fn poll_stream(&mut self, stream_id: u64) -> anyhow::Result<()> { + let buffered = match self.streams.get_mut(&stream_id) { + Some(b) => b, + None => return Ok(()), + }; + + let conn = &mut self.conn; + + let session = match &mut self.session { + Some(s) => s, + None => return Ok(()), + }; + + while let Some(data) = buffered.pop_front() { + match conn.stream_writable(stream_id, data.len()) { + Ok(true) => { + session.send_stream_data(conn, stream_id, data.as_slice())?; + }, + Ok(false) => { + buffered.push_front(data); + return Ok(()); + }, + Err(quiche::Error::Done) => {}, + Err(e) => anyhow::bail!(e), + }; + } + + Ok(()) + } + + pub fn timeout(&self) -> Option { + self.conn.timeout() + } + + pub fn on_timeout(&mut self) { + self.conn.on_timeout() + + // custom stuff here + } +} +*/ \ No newline at end of file