From c88f0b045a0fcab25aa563ba02daa08ac404a105 Mon Sep 17 00:00:00 2001 From: kixelated Date: Thu, 8 Jun 2023 00:01:34 -0700 Subject: [PATCH] Migrate to quinn and async Rust (#21) I miss quiche, but it was a pain to do anything asynchronous. MoQ is a pub/sub protocol so it's very important to support subscribers joining/leaving/stalling. The API is also just significantly better since quinn doesn't restrict itself to C bindings, which I'm sure will come back to haunt me when we want OBS support. --- Cargo.lock | 392 +++++++++++++++----------------- Cargo.toml | 38 +++- src/.DS_Store | Bin 6148 -> 0 bytes src/{session => app}/message.rs | 7 +- src/app/mod.rs | 8 + src/app/server.rs | 121 ++++++++++ src/app/session.rs | 115 ++++++++++ src/lib.rs | 3 +- src/main.rs | 55 ++--- src/media/mod.rs | 7 +- src/media/model.rs | 28 +++ src/media/source.rs | 265 ++++++++++++++------- src/media/watch.rs | 127 +++++++++++ src/session/mod.rs | 168 -------------- src/transport/app.rs | 8 - src/transport/connection.rs | 15 -- src/transport/mod.rs | 8 - src/transport/server.rs | 378 ------------------------------ src/transport/streams.rs | 136 ----------- 19 files changed, 826 insertions(+), 1053 deletions(-) delete mode 100644 src/.DS_Store rename src/{session => app}/message.rs (88%) create mode 100644 src/app/mod.rs create mode 100644 src/app/server.rs create mode 100644 src/app/session.rs create mode 100644 src/media/model.rs create mode 100644 src/media/watch.rs delete mode 100644 src/session/mod.rs delete mode 100644 src/transport/app.rs delete mode 100644 src/transport/connection.rs delete mode 100644 src/transport/mod.rs delete mode 100644 src/transport/server.rs delete mode 100644 src/transport/streams.rs diff --git a/Cargo.lock b/Cargo.lock index b5a168c..9803bf8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,15 +11,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "android_system_properties" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" -dependencies = [ - "libc", -] - [[package]] name = "anstream" version = "0.3.2" @@ -149,19 +140,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" -[[package]] -name = "chrono" -version = "0.4.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e3c5919066adf22df73762e50cffcde3a758f2a848b113b586d1f86728b673b" -dependencies = [ - "iana-time-zone", - "num-integer", - "num-traits", - "serde", - "winapi", -] - [[package]] name = "clap" version = "4.3.0" @@ -204,27 +182,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" -[[package]] -name = "cmake" -version = "0.1.50" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130" -dependencies = [ - "cc", -] - [[package]] name = "colorchoice" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" -[[package]] -name = "core-foundation-sys" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" - [[package]] name = "cpufeatures" version = "0.2.7" @@ -244,41 +207,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "darling" -version = "0.20.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0558d22a7b463ed0241e993f76f09f30b126687447751a8638587b864e4b3944" -dependencies = [ - "darling_core", - "darling_macro", -] - -[[package]] -name = "darling_core" -version = "0.20.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab8bfa2e259f8ee1ce5e97824a3c55ec4404a0d772ca7fa96bf19f0752a046eb" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "strsim", - "syn", -] - -[[package]] -name = "darling_macro" -version = "0.20.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29a358ff9f12ec09c3e61fef9b5a9902623a695a46a917b07f269bff1445611a" -dependencies = [ - "darling_core", - "quote", - "syn", -] - [[package]] name = "digest" version = "0.10.7" @@ -332,6 +260,15 @@ dependencies = [ "libc", ] +[[package]] +name = "fastrand" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" +dependencies = [ + "instant", +] + [[package]] name = "fnv" version = "1.0.7" @@ -347,6 +284,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.28" @@ -363,6 +315,34 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +[[package]] +name = "futures-executor" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" + +[[package]] +name = "futures-macro" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.28" @@ -381,9 +361,13 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-macro", "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", "slab", @@ -429,6 +413,48 @@ dependencies = [ "tracing", ] +[[package]] +name = "h3" +version = "0.0.2" +source = "git+https://github.com/security-union/h3?branch=add-webtransport#db5c723f653911a476bfd8ffcfebf0f8f2eb980d" +dependencies = [ + "bytes", + "fastrand", + "futures-util", + "http", + "pin-project-lite", + "tokio", + "tracing", +] + +[[package]] +name = "h3-quinn" +version = "0.0.2" +source = "git+https://github.com/security-union/h3?branch=add-webtransport#db5c723f653911a476bfd8ffcfebf0f8f2eb980d" +dependencies = [ + "bytes", + "futures", + "h3", + "quinn", + "quinn-proto", + "tokio", + "tokio-util", +] + +[[package]] +name = "h3-webtransport" +version = "0.1.0" +source = "git+https://github.com/security-union/h3?branch=add-webtransport#db5c723f653911a476bfd8ffcfebf0f8f2eb980d" +dependencies = [ + "bytes", + "futures-util", + "h3", + "http", + "pin-project-lite", + "tokio", + "tracing", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -553,42 +579,13 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2", + "socket2 0.4.9", "tokio", "tower-service", "tracing", "want", ] -[[package]] -name = "iana-time-zone" -version = "0.1.56" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0722cd7114b7de04316e7ea5456a0bbb20e4adb46fd27a3697adb812cff0f37c" -dependencies = [ - "android_system_properties", - "core-foundation-sys", - "iana-time-zone-haiku", - "js-sys", - "wasm-bindgen", - "windows", -] - -[[package]] -name = "iana-time-zone-haiku" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" -dependencies = [ - "cc", -] - -[[package]] -name = "ident_case" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" - [[package]] name = "idna" version = "0.3.0" @@ -607,7 +604,15 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", "hashbrown", - "serde", +] + +[[package]] +name = "instant" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +dependencies = [ + "cfg-if", ] [[package]] @@ -648,24 +653,12 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "lazy_static" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" - [[package]] name = "libc" version = "0.2.144" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b00cc1c228a6782d0f076e7b232802e0c5689d41bb5df366f2a6b6621cfdfe1" -[[package]] -name = "libm" -version = "0.2.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" - [[package]] name = "linux-raw-sys" version = "0.3.8" @@ -730,14 +723,18 @@ name = "moq" version = "0.1.0" dependencies = [ "anyhow", + "bytes", "clap", "env_logger", + "futures", + "h3", + "h3-quinn", + "h3-webtransport", "hex", "http", "log", - "mio", "mp4", - "quiche", + "quinn", "ring", "rustls 0.21.1", "rustls-pemfile", @@ -832,11 +829,6 @@ dependencies = [ "libc", ] -[[package]] -name = "octets" -version = "0.2.0" -source = "git+https://github.com/kixelated/quiche-webtransport.git?branch=master#007a25b35b9509d673466fed8ddc73fd8d9b4184" - [[package]] name = "once_cell" version = "1.17.1" @@ -920,33 +912,51 @@ dependencies = [ ] [[package]] -name = "qlog" -version = "0.9.0" -source = "git+https://github.com/kixelated/quiche-webtransport.git?branch=master#007a25b35b9509d673466fed8ddc73fd8d9b4184" +name = "quinn" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21252f1c0fc131f1b69182db8f34837e8a69737b8251dff75636a9be0518c324" dependencies = [ - "serde", - "serde_derive", - "serde_json", - "serde_with", - "smallvec", + "bytes", + "futures-io", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls 0.21.1", + "thiserror", + "tokio", + "tracing", ] [[package]] -name = "quiche" -version = "0.17.1" -source = "git+https://github.com/kixelated/quiche-webtransport.git?branch=master#007a25b35b9509d673466fed8ddc73fd8d9b4184" +name = "quinn-proto" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85af4ed6ee5a89f26a26086e9089a6643650544c025158449a3626ebf72884b3" dependencies = [ - "cmake", - "lazy_static", - "libc", - "libm", - "log", - "octets", - "qlog", + "bytes", + "rand", "ring", + "rustc-hash", + "rustls 0.21.1", "slab", - "smallvec", - "winapi", + "thiserror", + "tinyvec", + "tracing", +] + +[[package]] +name = "quinn-udp" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6df19e284d93757a9fb91d63672f7741b129246a669db09d1c0063071debc0c0" +dependencies = [ + "bytes", + "libc", + "socket2 0.5.3", + "tracing", + "windows-sys 0.48.0", ] [[package]] @@ -1029,6 +1039,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustix" version = "0.37.19" @@ -1140,7 +1156,6 @@ version = "1.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" dependencies = [ - "indexmap", "itoa", "ryu", "serde", @@ -1158,34 +1173,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_with" -version = "2.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07ff71d2c147a7b57362cead5e22f772cd52f6ab31cfcd9edcd7f6aeb2a0afbe" -dependencies = [ - "base64 0.13.1", - "chrono", - "hex", - "indexmap", - "serde", - "serde_json", - "serde_with_macros", - "time", -] - -[[package]] -name = "serde_with_macros" -version = "2.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "881b6f881b17d13214e5d494c939ebab463d01264ce1811e9d4ac3a882e7695f" -dependencies = [ - "darling", - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "sha1" version = "0.10.5" @@ -1220,9 +1207,6 @@ name = "smallvec" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" -dependencies = [ - "serde", -] [[package]] name = "socket2" @@ -1234,6 +1218,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "socket2" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877" +dependencies = [ + "libc", + "windows-sys 0.48.0", +] + [[package]] name = "spin" version = "0.5.2" @@ -1292,33 +1286,6 @@ dependencies = [ "syn", ] -[[package]] -name = "time" -version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3403384eaacbca9923fa06940178ac13e4edb725486d70e8e15881d0c836cc" -dependencies = [ - "itoa", - "serde", - "time-core", - "time-macros", -] - -[[package]] -name = "time-core" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" - -[[package]] -name = "time-macros" -version = "0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "372950940a5f07bf38dbe211d7283c9e6d7327df53794992d293e534c733d09b" -dependencies = [ - "time-core", -] - [[package]] name = "tinyvec" version = "1.6.0" @@ -1348,7 +1315,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "socket2 0.4.9", "tokio-macros", "windows-sys 0.48.0", ] @@ -1427,9 +1394,21 @@ dependencies = [ "cfg-if", "log", "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tracing-core" version = "0.1.31" @@ -1688,15 +1667,6 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" -[[package]] -name = "windows" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" -dependencies = [ - "windows-targets 0.48.0", -] - [[package]] name = "windows-sys" version = "0.45.0" diff --git a/Cargo.toml b/Cargo.toml index bb847a7..dfd4eb4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,22 +6,36 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -# Fork of quiche until they add WebTransport support. -quiche = { git = "https://github.com/kixelated/quiche-webtransport.git", branch = "master", features = [ "qlog" ] } -clap = { version = "4.0", features = [ "derive" ] } -log = { version = "0.4", features = ["std"] } -mio = { version = "0.8", features = ["net", "os-poll"] } -env_logger = "0.9.3" +# Fork of h3 with WebTransport support +h3 = { git = "https://github.com/security-union/h3", branch = "add-webtransport" } +h3-quinn = { git = "https://github.com/security-union/h3", branch = "add-webtransport" } +h3-webtransport = { git = "https://github.com/security-union/h3", branch = "add-webtransport" } +quinn = { version = "0.10", default-features = false, features = ["runtime-tokio", "tls-rustls", "ring"] } + +# Crypto dependencies ring = "0.16" -anyhow = "1.0.70" +rustls = { version = "0.21", features = ["dangerous_configuration"] } +rustls-pemfile = "1.0.2" + +# Async stuff +tokio = { version = "1.27", features = ["full"] } +futures = "0.3" + +# Media mp4 = "0.13.0" + +# Encoding +bytes = "1" serde = "1.0.160" serde_json = "1.0" -# Required to serve the fingerprint over HTTPS -tokio = { version = "1", features = ["full"] } +# Web server to serve the fingerprint +http = "0.2" warp = { version = "0.3.3", features = ["tls"] } -rustls = "0.21" -rustls-pemfile = "1.0.2" hex = "0.4.3" -http = "0.2.9" + +# Logging and utility +clap = { version = "4.0", features = [ "derive" ] } +log = { version = "0.4", features = ["std"] } +env_logger = "0.9.3" +anyhow = "1.0.70" diff --git a/src/.DS_Store b/src/.DS_Store deleted file mode 100644 index 5111b4f214ed3f24caee1b3757e393cfab637a8f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKJ5Iwu5SFE0$hU|&~XbcLdTok zk>g*WKtc%3M7z(Po!K3~U3)S_WQOH^+E>I~YyV zBx~5J;%_p*d$&so-B3(3s_owiZ+mGpNuy{oPbbg@j}Py!SBG!sb&L2Li|np^LQdgM zd$hv(EBt1Z(xQ5XjxL)P8q)$1Wr%1=>uOAADPn4| zD#!KmT*TFCxryy-%W8c3V)21dQ3jNOA2WcQ%@XW;)Lt1-29$w<0scN%G{#Uc_vk(y z=v)Z^>_Y7X*IY}mj};6BbB`E-81D-7u0~1>Y@|XlZgtFW zojIA%qxQ;xGEik;$KUoj|6gpc|Eo@Vrwk|q|B3W#CsC_yn}2a!~*P diff --git a/src/session/message.rs b/src/app/message.rs similarity index 88% rename from src/session/message.rs rename to src/app/message.rs index 63faf73..d8ea6b8 100644 --- a/src/session/message.rs +++ b/src/app/message.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Default)] pub struct Message { pub init: Option, pub segment: Option, @@ -16,10 +16,7 @@ pub struct Segment { impl Message { pub fn new() -> Self { - Message { - init: None, - segment: None, - } + Default::default() } pub fn serialize(&self) -> anyhow::Result> { diff --git a/src/app/mod.rs b/src/app/mod.rs new file mode 100644 index 0000000..07f563a --- /dev/null +++ b/src/app/mod.rs @@ -0,0 +1,8 @@ +mod message; +mod server; +mod session; + +pub use server::{Server, ServerConfig}; + +// Reduce the amount of typing +type WebTransportSession = h3_webtransport::server::WebTransportSession; diff --git a/src/app/server.rs b/src/app/server.rs new file mode 100644 index 0000000..6ee228d --- /dev/null +++ b/src/app/server.rs @@ -0,0 +1,121 @@ +use super::session::Session; +use crate::media; + +use std::{fs, io, net, path, sync, time}; + +use super::WebTransportSession; + +use anyhow::Context; + +pub struct Server { + // The QUIC server, yielding new connections and sessions. + server: quinn::Endpoint, + + // The media source + broadcast: media::Broadcast, +} + +pub struct ServerConfig { + pub addr: net::SocketAddr, + pub cert: path::PathBuf, + pub key: path::PathBuf, + + pub broadcast: media::Broadcast, +} + +impl Server { + // Create a new server + pub fn new(config: ServerConfig) -> anyhow::Result { + // Read the PEM certificate chain + let certs = fs::File::open(config.cert).context("failed to open cert file")?; + let mut certs = io::BufReader::new(certs); + let certs = rustls_pemfile::certs(&mut certs)? + .into_iter() + .map(rustls::Certificate) + .collect(); + + // Read the PEM private key + let keys = fs::File::open(config.key).context("failed to open key file")?; + let mut keys = io::BufReader::new(keys); + let mut keys = rustls_pemfile::pkcs8_private_keys(&mut keys)?; + + anyhow::ensure!(keys.len() == 1, "expected a single key"); + let key = rustls::PrivateKey(keys.remove(0)); + + let mut tls_config = rustls::ServerConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&rustls::version::TLS13]) + .unwrap() + .with_no_client_auth() + .with_single_cert(certs, key)?; + + tls_config.max_early_data_size = u32::MAX; + let alpn: Vec> = vec![ + b"h3".to_vec(), + b"h3-32".to_vec(), + b"h3-31".to_vec(), + b"h3-30".to_vec(), + b"h3-29".to_vec(), + ]; + tls_config.alpn_protocols = alpn; + + let mut server_config = quinn::ServerConfig::with_crypto(sync::Arc::new(tls_config)); + let mut transport_config = quinn::TransportConfig::default(); + transport_config.keep_alive_interval(Some(time::Duration::from_secs(2))); + server_config.transport = sync::Arc::new(transport_config); + + let server = quinn::Endpoint::server(server_config, config.addr)?; + let broadcast = config.broadcast; + + Ok(Self { server, broadcast }) + } + + pub async fn run(&mut self) -> anyhow::Result<()> { + loop { + let conn = self.server.accept().await.context("failed to accept connection")?; + let broadcast = self.broadcast.clone(); + + tokio::spawn(async move { + let session = Self::accept_session(conn).await.context("failed to accept session")?; + + // Use a wrapper run the session. + let session = Session::new(session); + session.serve_broadcast(broadcast).await + }); + } + } + + async fn accept_session(conn: quinn::Connecting) -> anyhow::Result { + let conn = conn.await.context("failed to accept h3 connection")?; + + let mut conn = h3::server::builder() + .enable_webtransport(true) + .enable_connect(true) + .enable_datagram(true) + .max_webtransport_sessions(1) + .send_grease(true) + .build(h3_quinn::Connection::new(conn)) + .await + .context("failed to create h3 server")?; + + let (req, stream) = conn + .accept() + .await + .context("failed to accept h3 session")? + .context("failed to accept h3 request")?; + + let ext = req.extensions(); + anyhow::ensure!(req.method() == http::Method::CONNECT, "expected CONNECT request"); + anyhow::ensure!( + ext.get::() == Some(&h3::ext::Protocol::WEB_TRANSPORT), + "expected WebTransport CONNECT" + ); + + let session = WebTransportSession::accept(req, stream, conn) + .await + .context("failed to accept WebTransport session")?; + + Ok(session) + } +} diff --git a/src/app/session.rs b/src/app/session.rs new file mode 100644 index 0000000..08ecdb7 --- /dev/null +++ b/src/app/session.rs @@ -0,0 +1,115 @@ +use crate::media; + +use anyhow::Context; + +use std::sync::Arc; +use tokio::io::AsyncWriteExt; +use tokio::task::JoinSet; + +use super::WebTransportSession; + +use super::message; + +#[derive(Clone)] +pub struct Session { + // The underlying transport session + transport: Arc, +} + +impl Session { + pub fn new(transport: WebTransportSession) -> Self { + let transport = Arc::new(transport); + Self { transport } + } + + pub async fn serve_broadcast(&self, mut broadcast: media::Broadcast) -> anyhow::Result<()> { + let mut tasks = JoinSet::new(); + let mut done = false; + + loop { + tokio::select! { + // Accept new tracks added to the broadcast. + track = broadcast.tracks.next(), if !done => { + match track { + Some(track) => { + let session = self.clone(); + + tasks.spawn(async move { + session.serve_track(track).await + }); + }, + None => done = true, + } + }, + // Poll any pending tracks until they exit. + res = tasks.join_next(), if !tasks.is_empty() => { + let res = res.context("no tracks running")?; + let res = res.context("failed to run track")?; + res.context("failed to serve track")?; + }, + else => return Ok(()), + } + } + } + + pub async fn serve_track(&self, mut track: media::Track) -> anyhow::Result<()> { + let mut tasks = JoinSet::new(); + let mut done = false; + + loop { + tokio::select! { + // Accept new tracks added to the broadcast. + segment = track.segments.next(), if !done => { + match segment { + Some(segment) => { + let track = track.clone(); + let session = self.clone(); + + tasks.spawn(async move { + session.serve_segment(track, segment).await + }); + }, + None => done = true, + } + }, + // Poll any pending segments until they exit. + res = tasks.join_next(), if !tasks.is_empty() => { + let res = res.context("no tasks running")?; + let res = res.context("failed to run segment")?; + res.context("failed serve segment")? + }, + else => return Ok(()), + } + } + } + + pub async fn serve_segment(&self, track: media::Track, mut segment: media::Segment) -> anyhow::Result<()> { + let mut stream = self.transport.open_uni(self.transport.session_id()).await?; + + // TODO support prioirty + // stream.set_priority(0); + + // Encode a JSON header indicating this is a new segment. + let mut message: message::Message = message::Message::new(); + + // TODO combine init and segment messages into one. + if track.id == 0xff { + message.init = Some(message::Init {}); + } else { + message.segment = Some(message::Segment { track_id: track.id }); + } + + // Write the JSON header. + let data = message.serialize()?; + stream.write_all(data.as_slice()).await?; + + // Write each fragment as they are available. + while let Some(fragment) = segment.fragments.next().await { + stream.write_all(fragment.as_slice()).await?; + } + + // NOTE: stream is automatically closed when dropped + + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 6715979..3dccb66 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,2 @@ +pub mod app; pub mod media; -pub mod session; -pub mod transport; diff --git a/src/main.rs b/src/main.rs index d427dfd..7aaa38e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,7 @@ -use std::io::BufReader; -use std::net::SocketAddr; -use std::{fs::File, sync::Arc}; - -use moq::{session, transport}; +use moq::{app, media}; +use std::{fs, io, net, path, sync}; +use anyhow::Context; use clap::Parser; use ring::digest::{digest, SHA256}; use warp::Filter; @@ -13,54 +11,57 @@ use warp::Filter; struct Cli { /// Listen on this address #[arg(short, long, default_value = "[::]:4443")] - addr: String, + addr: net::SocketAddr, /// Use the certificate file at this path #[arg(short, long, default_value = "cert/localhost.crt")] - cert: String, + cert: path::PathBuf, /// Use the private key at this path #[arg(short, long, default_value = "cert/localhost.key")] - key: String, + key: path::PathBuf, /// Use the media file at this path #[arg(short, long, default_value = "media/fragmented.mp4")] - media: String, + media: path::PathBuf, } #[tokio::main] async fn main() -> anyhow::Result<()> { env_logger::init(); - let moq_args = Cli::parse(); - let http_args = moq_args.clone(); + let args = Cli::parse(); - // TODO return result instead of panicing - tokio::task::spawn(async move { run_transport(moq_args).unwrap() }); + // Create a web server to serve the fingerprint + let serve = serve_http(args.clone()); - run_http(http_args).await -} + // Create a fake media source from disk. + let mut media = media::Source::new(args.media).context("failed to open fragmented.mp4")?; -// Run the WebTransport server using quiche. -fn run_transport(args: Cli) -> anyhow::Result<()> { - let server_config = transport::Config { + // Create a server to actually serve the media + let config = app::ServerConfig { addr: args.addr, cert: args.cert, key: args.key, + broadcast: media.broadcast(), }; - let mut server = transport::Server::::new(server_config).unwrap(); - server.run() + let mut server = app::Server::new(config).context("failed to create server")?; + + // Run all of the above + tokio::select! { + res = server.run() => res.context("failed to run server"), + res = media.run() => res.context("failed to run media source"), + res = serve => res.context("failed to run HTTP server"), + } } // Run a HTTP server using Warp // TODO remove this when Chrome adds support for self-signed certificates using WebTransport -async fn run_http(args: Cli) -> anyhow::Result<()> { - let addr: SocketAddr = args.addr.parse()?; - +async fn serve_http(args: Cli) -> anyhow::Result<()> { // Read the PEM certificate file - let crt = File::open(&args.cert)?; - let mut crt = BufReader::new(crt); + let crt = fs::File::open(&args.cert)?; + let mut crt = io::BufReader::new(crt); // Parse the DER certificate let certs = rustls_pemfile::certs(&mut crt)?; @@ -69,7 +70,7 @@ async fn run_http(args: Cli) -> anyhow::Result<()> { // Compute the SHA-256 digest let fingerprint = digest(&SHA256, cert.as_ref()); let fingerprint = hex::encode(fingerprint.as_ref()); - let fingerprint = Arc::new(fingerprint); + let fingerprint = sync::Arc::new(fingerprint); let cors = warp::cors().allow_any_origin(); @@ -83,7 +84,7 @@ async fn run_http(args: Cli) -> anyhow::Result<()> { .tls() .cert_path(args.cert) .key_path(args.key) - .run(addr) + .run(args.addr) .await; Ok(()) diff --git a/src/media/mod.rs b/src/media/mod.rs index e7ec5eb..923cde4 100644 --- a/src/media/mod.rs +++ b/src/media/mod.rs @@ -1,3 +1,8 @@ mod source; +pub use source::Source; -pub use source::{Fragment, Source}; +mod model; +pub use model::*; + +mod watch; +use watch::{Producer, Subscriber}; diff --git a/src/media/model.rs b/src/media/model.rs new file mode 100644 index 0000000..bab22bd --- /dev/null +++ b/src/media/model.rs @@ -0,0 +1,28 @@ +use super::Subscriber; +use std::{sync, time}; + +#[derive(Clone)] +pub struct Broadcast { + pub tracks: Subscriber, +} + +#[derive(Clone)] +pub struct Track { + // The track ID as stored in the MP4 + pub id: u32, + + // A list of segments, which are independently decodable. + pub segments: Subscriber, +} + +#[derive(Clone)] +pub struct Segment { + // The timestamp of the segment. + pub timestamp: time::Duration, + + // A list of fragments that make up the segment. + pub fragments: Subscriber, +} + +// Use Arc to avoid cloning the entire MP4 data for each subscriber. +pub type Fragment = sync::Arc>; diff --git a/src/media/source.rs b/src/media/source.rs index 9a87c0a..e0f8e29 100644 --- a/src/media/source.rs +++ b/src/media/source.rs @@ -1,51 +1,32 @@ -use std::collections::VecDeque; use std::io::Read; -use std::{fs, io, time}; + +use std::{fs, io, path, time}; use anyhow; use mp4; use mp4::ReadBox; +use anyhow::Context; +use std::collections::HashMap; + +use super::{Broadcast, Fragment, Producer, Segment, Track}; + pub struct Source { // We read the file once, in order, and don't seek backwards. reader: io::BufReader, - // The timestamp when the broadcast "started", so we can sleep to simulate a live stream. - start: time::Instant, + // The tracks we're producing + broadcast: Broadcast, - // The initialization payload; ftyp + moov boxes. - pub init: Vec, - - // The parsed moov box. - moov: mp4::MoovBox, - - // Any fragments parsed and ready to be returned by next(). - fragments: VecDeque, -} - -pub struct Fragment { - // The track ID for the fragment. - pub track_id: u32, - - // The data of the fragment. - pub data: Vec, - - // Whether this fragment is a keyframe. - pub keyframe: bool, - - // The number of samples that make up a second (ex. ms = 1000) - pub timescale: u64, - - // The timestamp of the fragment, in timescale units, to simulate a live stream. - pub timestamp: u64, + // The tracks we're producing. + tracks: HashMap, } impl Source { - pub fn new(path: &str) -> anyhow::Result { + pub fn new(path: path::PathBuf) -> anyhow::Result { let f = fs::File::open(path)?; let mut reader = io::BufReader::new(f); - let start = time::Instant::now(); let ftyp = read_atom(&mut reader)?; anyhow::ensure!(&ftyp[4..8] == b"ftyp", "expected ftyp atom"); @@ -64,28 +45,72 @@ impl Source { // Parse the moov box so we can detect the timescales for each track. let moov = mp4::MoovBox::read_box(&mut moov_reader, moov_header.size)?; + // Create a producer to populate the tracks. + let mut tracks = Producer::::new(); + + let broadcast = Broadcast { + tracks: tracks.subscribe(), + }; + + // Create the init track + let init_track = Self::create_init_track(init); + tracks.push(init_track); + + // Create a map with the current segment for each track. + // NOTE: We don't add the init track to this, since it's not part of the MP4. + let mut lookup = HashMap::new(); + + for trak in &moov.traks { + let track_id = trak.tkhd.track_id; + anyhow::ensure!(track_id != 0xff, "track ID 0xff is reserved"); + + let timescale = track_timescale(&moov, track_id); + + let segments = Producer::::new(); + + tracks.push(Track { + id: track_id, + segments: segments.subscribe(), + }); + + // Store the track publisher in a map so we can update it later. + let track = SourceTrack::new(segments, timescale); + lookup.insert(track_id, track); + } + Ok(Self { reader, - start, - init, - moov, - fragments: VecDeque::new(), + broadcast, + tracks: lookup, }) } - pub fn fragment(&mut self) -> anyhow::Result> { - if self.fragments.is_empty() { - self.parse()?; - }; + // Create an init track + fn create_init_track(raw: Vec) -> Track { + // TODO support static producers + let mut fragments = Producer::::new(); + let mut segments = Producer::::new(); - if self.timeout().is_some() { - return Ok(None); + fragments.push(raw.into()); + + segments.push(Segment { + fragments: fragments.subscribe(), + timestamp: time::Duration::ZERO, + }); + + Track { + id: 0xff, + segments: segments.subscribe(), } - - Ok(self.fragments.pop_front()) } - fn parse(&mut self) -> anyhow::Result<()> { + pub async fn run(&mut self) -> anyhow::Result<()> { + // The timestamp when the broadcast "started", so we can sleep to simulate a live stream. + let start = tokio::time::Instant::now(); + + // The ID of the last moof header. + let mut track_id = None; + loop { let atom = read_atom(&mut self.reader)?; @@ -93,46 +118,33 @@ impl Source { let header = mp4::BoxHeader::read(&mut reader)?; match header.name { - mp4::BoxType::FtypBox | mp4::BoxType::MoovBox => { - anyhow::bail!("must call init first") - } mp4::BoxType::MoofBox => { - let moof = mp4::MoofBox::read_box(&mut reader, header.size)?; + let moof = mp4::MoofBox::read_box(&mut reader, header.size).context("failed to read MP4")?; - if moof.trafs.len() != 1 { - // We can't split the mdat atom, so this is impossible to support - anyhow::bail!("multiple tracks per moof atom") - } + // Process the moof. + let fragment = SourceFragment::new(moof)?; - let track_id = moof.trafs[0].tfhd.track_id; - let timestamp = sample_timestamp(&moof).expect("couldn't find timestamp"); + // Get the track for this moof. + let track = self.tracks.get_mut(&fragment.track).context("failed to find track")?; - // Detect if this is a keyframe. - let keyframe = sample_keyframe(&moof); + // Sleep until we should publish this sample. + let timestamp = time::Duration::from_millis(1000 * fragment.timestamp / track.timescale); + tokio::time::sleep_until(start + timestamp).await; - let timescale = track_timescale(&self.moov, track_id); + // Save the track ID for the next iteration, which must be a mdat. + anyhow::ensure!(track_id.is_none(), "multiple moof atoms"); + track_id.replace(fragment.track); - self.fragments.push_back(Fragment { - track_id, - data: atom, - keyframe, - timescale, - timestamp, - }) + // Publish the moof header, creating a new segment if it's a keyframe. + track.header(atom, fragment).context("failed to publish moof")?; } mp4::BoxType::MdatBox => { - let moof = self.fragments.back().expect("no atom before mdat"); + // Get the track ID from the previous moof. + let track_id = track_id.take().context("missing moof")?; + let track = self.tracks.get_mut(&track_id).context("failed to find track")?; - self.fragments.push_back(Fragment { - track_id: moof.track_id, - data: atom, - keyframe: false, - timescale: moof.timescale, - timestamp: moof.timestamp, - }); - - // We have some media data, return so we can start sending it. - return Ok(()); + // Publish the mdat atom. + track.data(atom).context("failed to publish mdat")?; } _ => { // Skip unknown atoms @@ -141,19 +153,108 @@ impl Source { } } - // Simulate a live stream by sleeping until the next timestamp in the media. - pub fn timeout(&self) -> Option { - let next = self.fragments.front()?; + pub fn broadcast(&self) -> Broadcast { + self.broadcast.clone() + } +} - let delay = time::Duration::from_millis(1000 * next.timestamp / next.timescale); - let elapsed = self.start.elapsed(); +struct SourceTrack { + // The track we're producing + segments: Producer, - delay.checked_sub(elapsed) + // The current segment's fragments + fragments: Option>, + + // The number of units per second. + timescale: u64, +} + +impl SourceTrack { + fn new(segments: Producer, timescale: u64) -> Self { + Self { + segments, + fragments: None, + timescale, + } + } + + pub fn header(&mut self, raw: Vec, fragment: SourceFragment) -> anyhow::Result<()> { + // Close the current segment if we have a new keyframe. + if fragment.keyframe { + self.fragments.take(); + } + + // Get or create the current segment. + let fragments = self.fragments.get_or_insert_with(|| { + // Compute the timestamp in seconds. + let timestamp = fragment.timestamp(self.timescale); + + // Create a new segment, and save the fragments producer so we can push to it. + let fragments = Producer::::new(); + self.segments.push(Segment { + timestamp, + fragments: fragments.subscribe(), + }); + + // Remove any segments older than 10s. + let expires = timestamp.saturating_sub(time::Duration::from_secs(10)); + self.segments.drain(|segment| segment.timestamp < expires); + + fragments + }); + + // Insert the raw atom into the segment. + fragments.push(raw.into()); + + Ok(()) + } + + pub fn data(&mut self, raw: Vec) -> anyhow::Result<()> { + let fragments = self.fragments.as_mut().context("missing keyframe")?; + fragments.push(raw.into()); + + Ok(()) + } +} + +struct SourceFragment { + // The track for this fragment. + track: u32, + + // The timestamp of the first sample in this fragment, in timescale units. + timestamp: u64, + + // True if this fragment is a keyframe. + keyframe: bool, +} + +impl SourceFragment { + fn new(moof: mp4::MoofBox) -> anyhow::Result { + // We can't split the mdat atom, so this is impossible to support + anyhow::ensure!(moof.trafs.len() == 1, "multiple tracks per moof atom"); + let track = moof.trafs[0].tfhd.track_id; + + // Parse the moof to get some timing information to sleep. + let timestamp = sample_timestamp(&moof).expect("couldn't find timestamp"); + + // Detect if we should start a new segment. + let keyframe = sample_keyframe(&moof); + + Ok(Self { + track, + timestamp, + keyframe, + }) + } + + // Convert from timescale units to a duration. + fn timestamp(&self, timescale: u64) -> time::Duration { + time::Duration::from_millis(1000 * self.timestamp / timescale) } } // Read a full MP4 atom into a vector. -pub fn read_atom(reader: &mut R) -> anyhow::Result> { +fn read_atom(reader: &mut R) -> anyhow::Result> { // Read the 8 bytes for the size + type let mut buf = [0u8; 8]; reader.read_exact(&mut buf)?; diff --git a/src/media/watch.rs b/src/media/watch.rs new file mode 100644 index 0000000..6b98f45 --- /dev/null +++ b/src/media/watch.rs @@ -0,0 +1,127 @@ +use std::collections::VecDeque; +use tokio::sync::watch; + +#[derive(Default)] +struct State { + queue: VecDeque, + drained: usize, + closed: bool, +} + +impl State { + pub fn new() -> Self { + Self { + queue: VecDeque::new(), + drained: 0, + closed: false, + } + } + + // Add a new element to the end of the queue. + fn push(&mut self, t: T) { + self.queue.push_back(t) + } + + // Remove elements from the head of the queue if they match the conditional. + fn drain(&mut self, f: F) -> usize + where + F: Fn(&T) -> bool, + { + let prior = self.drained; + + while let Some(first) = self.queue.front() { + if !f(first) { + break; + } + + self.queue.pop_front(); + self.drained += 1; + } + + self.drained - prior + } +} + +pub struct Producer { + sender: watch::Sender>, +} + +impl Producer { + pub fn new() -> Self { + let state = State::new(); + let (sender, _) = watch::channel(state); + Self { sender } + } + + // Push a new element to the end of the queue. + pub fn push(&mut self, value: T) { + self.sender.send_modify(|state| state.push(value)); + } + + // Remove any elements from the front of the queue that match the condition. + pub fn drain(&mut self, f: F) + where + F: Fn(&T) -> bool, + { + // Use send_if_modified to never notify with the updated state. + self.sender.send_if_modified(|state| { + state.drain(f); + false + }); + } + + pub fn subscribe(&self) -> Subscriber { + Subscriber::new(self.sender.subscribe()) + } +} + +impl Default for Producer { + fn default() -> Self { + Self::new() + } +} + +impl Drop for Producer { + fn drop(&mut self) { + self.sender.send_modify(|state| state.closed = true); + } +} + +#[derive(Clone)] +pub struct Subscriber { + state: watch::Receiver>, + index: usize, +} + +impl Subscriber { + fn new(state: watch::Receiver>) -> Self { + Self { state, index: 0 } + } + + pub async fn next(&mut self) -> Option { + // Wait until the queue has a new element or if it's closed. + let state = self + .state + .wait_for(|state| state.closed || self.index < state.drained + state.queue.len()) + .await + .expect("publisher dropped without close"); + + // If our index is smaller than drained, skip past those elements we missed. + let index = self.index.saturating_sub(state.drained); + + if index < state.queue.len() { + // Clone the next element in the queue. + let element = state.queue[index].clone(); + + // Increment our index, relative to drained so we can skip ahead if needed. + self.index = index + state.drained + 1; + + Some(element) + } else if state.closed { + // Return None if we've consumed all entries and the queue is closed. + None + } else { + panic!("impossible subscriber state") + } + } +} diff --git a/src/session/mod.rs b/src/session/mod.rs deleted file mode 100644 index eca556b..0000000 --- a/src/session/mod.rs +++ /dev/null @@ -1,168 +0,0 @@ -mod message; - -use std::collections::hash_map as hmap; -use std::time; - -use quiche; -use quiche::h3::webtransport; - -use crate::{media, transport}; - -#[derive(Default)] -pub struct Session { - // The media source, configured on CONNECT. - media: Option, - - // A helper for automatically buffering stream data. - streams: transport::Streams, - - // Map from track_id to the the Track state. - tracks: hmap::HashMap, -} - -pub struct Track { - // Current stream_id - stream_id: Option, - - // The timescale used for this track. - timescale: u64, - - // The timestamp of the last keyframe. - keyframe: u64, -} - -impl transport::App for Session { - // Process any updates to a session. - fn poll(&mut self, conn: &mut quiche::Connection, session: &mut webtransport::ServerSession) -> anyhow::Result<()> { - loop { - let event = match session.poll(conn) { - Err(webtransport::Error::Done) => break, - Err(e) => return Err(e.into()), - Ok(e) => e, - }; - - log::debug!("webtransport event {:?}", event); - - match event { - webtransport::ServerEvent::ConnectRequest(_req) => { - // you can handle request with - // req.authority() - // req.path() - // and you can validate this request with req.origin() - session.accept_connect_request(conn, None)?; - - // TODO - let media = media::Source::new("media/fragmented.mp4").expect("failed to open fragmented.mp4"); - let init = &media.init; - - // Create a JSON header. - let mut message = message::Message::new(); - message.init = Some(message::Init {}); - let data = message.serialize()?; - - // Create a new stream and write the header. - let stream_id = session.open_stream(conn, false)?; - self.streams.send(conn, stream_id, data.as_slice(), false)?; - self.streams.send(conn, stream_id, init.as_slice(), true)?; - - self.media = Some(media); - } - webtransport::ServerEvent::StreamData(stream_id) => { - let mut buf = vec![0; 10000]; - while let Ok(len) = session.recv_stream_data(conn, stream_id, &mut buf) { - let _stream_data = &buf[0..len]; - } - } - - _ => {} - } - } - - // Send any pending stream data. - // NOTE: This doesn't return an error because it's async, and would be confusing. - self.streams.poll(conn); - - // Fetch the next media fragment, possibly queuing up stream data. - self.poll_source(conn, session)?; - - Ok(()) - } - - fn timeout(&self) -> Option { - self.media.as_ref().and_then(|m| m.timeout()) - } -} - -impl Session { - fn poll_source( - &mut self, - conn: &mut quiche::Connection, - session: &mut webtransport::ServerSession, - ) -> anyhow::Result<()> { - // Get the media source once the connection is established. - let media = match &mut self.media { - Some(m) => m, - None => return Ok(()), - }; - - // Get the next media fragment. - let fragment = match media.fragment()? { - Some(f) => f, - None => return Ok(()), - }; - - // Get the track state or insert a new entry. - let track = self.tracks.entry(fragment.track_id).or_insert_with(|| Track { - stream_id: None, - timescale: fragment.timescale, - keyframe: 0, - }); - - if let Some(stream_id) = track.stream_id { - // Existing stream, check if we should close it. - if fragment.keyframe && fragment.timestamp >= track.keyframe + track.timescale { - // Close the existing stream - self.streams.send(conn, stream_id, &[], true)?; - - // Unset the stream id so we create a new one. - track.stream_id = None; - track.keyframe = fragment.timestamp; - } - } - - let stream_id = match track.stream_id { - Some(stream_id) => stream_id, - None => { - // Create a new unidirectional stream. - let stream_id = session.open_stream(conn, false)?; - - // Set the stream priority to be equal to the timestamp. - // We subtract from u64::MAX so newer media is sent important. - // TODO prioritize audio - let order = u64::MAX - fragment.timestamp; - self.streams.send_order(conn, stream_id, order); - - // Encode a JSON header indicating this is a new track. - let mut message: message::Message = message::Message::new(); - message.segment = Some(message::Segment { - track_id: fragment.track_id, - }); - - // Write the header. - let data = message.serialize()?; - self.streams.send(conn, stream_id, &data, false)?; - - stream_id - } - }; - - // Write the current fragment. - let data = fragment.data.as_slice(); - self.streams.send(conn, stream_id, data, false)?; - - // Save the stream_id for the next fragment. - track.stream_id = Some(stream_id); - - Ok(()) - } -} diff --git a/src/transport/app.rs b/src/transport/app.rs deleted file mode 100644 index 2c5ae3f..0000000 --- a/src/transport/app.rs +++ /dev/null @@ -1,8 +0,0 @@ -use std::time; - -use quiche::h3::webtransport; - -pub trait App: Default { - fn poll(&mut self, conn: &mut quiche::Connection, session: &mut webtransport::ServerSession) -> anyhow::Result<()>; - fn timeout(&self) -> Option; -} diff --git a/src/transport/connection.rs b/src/transport/connection.rs deleted file mode 100644 index 5d806f5..0000000 --- a/src/transport/connection.rs +++ /dev/null @@ -1,15 +0,0 @@ -use quiche; -use quiche::h3::webtransport; - -use std::collections::hash_map as hmap; - -pub type Id = quiche::ConnectionId<'static>; - -use super::app; - -pub type Map = hmap::HashMap>; -pub struct Connection { - pub quiche: quiche::Connection, - pub session: Option, - pub app: T, -} diff --git a/src/transport/mod.rs b/src/transport/mod.rs deleted file mode 100644 index 60ef79d..0000000 --- a/src/transport/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod app; -mod connection; -mod server; -mod streams; - -pub use app::App; -pub use server::{Config, Server}; -pub use streams::Streams; diff --git a/src/transport/server.rs b/src/transport/server.rs deleted file mode 100644 index cb6854c..0000000 --- a/src/transport/server.rs +++ /dev/null @@ -1,378 +0,0 @@ -use std::io; - -use quiche::h3::webtransport; - -use super::app; -use super::connection; - -const MAX_DATAGRAM_SIZE: usize = 1350; - -pub struct Server { - // IO stuff - socket: mio::net::UdpSocket, - poll: mio::Poll, - events: mio::Events, - - // QUIC stuff - quic: quiche::Config, - seed: ring::hmac::Key, // connection ID seed - - conns: connection::Map, -} - -pub struct Config { - pub addr: String, - pub cert: String, - pub key: String, -} - -impl Server { - pub fn new(config: Config) -> io::Result { - // Listen on the provided socket address - let addr = config.addr.parse().unwrap(); - let mut socket = mio::net::UdpSocket::bind(addr).unwrap(); - - // Setup the event loop. - let poll = mio::Poll::new().unwrap(); - let events = mio::Events::with_capacity(1024); - - poll.registry() - .register(&mut socket, mio::Token(0), mio::Interest::READABLE) - .unwrap(); - - // Generate random values for connection IDs. - let rng = ring::rand::SystemRandom::new(); - let seed = ring::hmac::Key::generate(ring::hmac::HMAC_SHA256, &rng).unwrap(); - - // Create the configuration for the QUIC conns. - let mut quic = quiche::Config::new(quiche::PROTOCOL_VERSION).unwrap(); - quic.load_cert_chain_from_pem_file(&config.cert).unwrap(); - quic.load_priv_key_from_pem_file(&config.key).unwrap(); - quic.set_application_protos(quiche::h3::APPLICATION_PROTOCOL).unwrap(); - quic.set_max_idle_timeout(5000); - quic.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE); - quic.set_max_send_udp_payload_size(MAX_DATAGRAM_SIZE); - quic.set_initial_max_data(10_000_000); - quic.set_initial_max_stream_data_bidi_local(1_000_000); - quic.set_initial_max_stream_data_bidi_remote(1_000_000); - quic.set_initial_max_stream_data_uni(1_000_000); - quic.set_initial_max_streams_bidi(100); - quic.set_initial_max_streams_uni(100); - quic.set_disable_active_migration(true); - quic.enable_early_data(); - quic.enable_dgram(true, 65536, 65536); - - let conns = Default::default(); - - Ok(Server { - socket, - poll, - events, - - quic, - seed, - - conns, - }) - } - - pub fn run(&mut self) -> anyhow::Result<()> { - log::info!("listening on {}", self.socket.local_addr()?); - - loop { - self.wait()?; - self.receive()?; - self.app()?; - self.send()?; - self.cleanup(); - } - } - - pub fn wait(&mut self) -> anyhow::Result<()> { - // Find the shorter timeout from all the active connections. - // - // TODO: use event loop that properly supports timers - let timeout = self - .conns - .values() - .filter_map(|c| { - let timeout = c.quiche.timeout(); - let expires = c.app.timeout(); - - match (timeout, expires) { - (Some(a), Some(b)) => Some(a.min(b)), - (Some(a), None) => Some(a), - (None, Some(b)) => Some(b), - (None, None) => None, - } - }) - .min(); - - self.poll.poll(&mut self.events, timeout).unwrap(); - - // If the event loop reported no events, it means that the timeout - // has expired, so handle it without attempting to read packets. We - // will then proceed with the send loop. - if self.events.is_empty() { - for conn in self.conns.values_mut() { - conn.quiche.on_timeout(); - } - } - - Ok(()) - } - - // Reads packets from the socket, updating any internal connection state. - fn receive(&mut self) -> anyhow::Result<()> { - let mut src = [0; MAX_DATAGRAM_SIZE]; - - // Try reading any data currently available on the socket. - loop { - let (len, from) = match self.socket.recv_from(&mut src) { - Ok(v) => v, - Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => return Ok(()), - Err(e) => return Err(e.into()), - }; - - let src = &mut src[..len]; - - let info = quiche::RecvInfo { - to: self.socket.local_addr().unwrap(), - from, - }; - - // Parse the QUIC packet's header. - let hdr = quiche::Header::from_slice(src, quiche::MAX_CONN_ID_LEN).unwrap(); - - let conn_id = ring::hmac::sign(&self.seed, &hdr.dcid); - let conn_id = &conn_id.as_ref()[..quiche::MAX_CONN_ID_LEN]; - let conn_id = conn_id.to_vec().into(); - - // Check if it's an existing connection. - if let Some(conn) = self.conns.get_mut(&hdr.dcid) { - conn.quiche.recv(src, info)?; - - if conn.session.is_none() && conn.quiche.is_established() { - conn.session = Some(webtransport::ServerSession::with_transport(&mut conn.quiche)?) - } - - continue; - } else if let Some(conn) = self.conns.get_mut(&conn_id) { - conn.quiche.recv(src, info)?; - - // TODO is this needed here? - if conn.session.is_none() && conn.quiche.is_established() { - conn.session = Some(webtransport::ServerSession::with_transport(&mut conn.quiche)?) - } - - continue; - } - - if hdr.ty != quiche::Type::Initial { - log::warn!("unknown connection ID"); - continue; - } - - let mut dst = [0; MAX_DATAGRAM_SIZE]; - - if !quiche::version_is_supported(hdr.version) { - let len = quiche::negotiate_version(&hdr.scid, &hdr.dcid, &mut dst).unwrap(); - let dst = &dst[..len]; - - self.socket.send_to(dst, from).unwrap(); - continue; - } - - let mut scid = [0; quiche::MAX_CONN_ID_LEN]; - scid.copy_from_slice(&conn_id); - - let scid = quiche::ConnectionId::from_ref(&scid); - - // Token is always present in Initial packets. - let token = hdr.token.as_ref().unwrap(); - - // Do stateless retry if the client didn't send a token. - if token.is_empty() { - let new_token = mint_token(&hdr, &from); - - let len = quiche::retry(&hdr.scid, &hdr.dcid, &scid, &new_token, hdr.version, &mut dst).unwrap(); - - let dst = &dst[..len]; - - self.socket.send_to(dst, from).unwrap(); - continue; - } - - let odcid = validate_token(&from, token); - - // The token was not valid, meaning the retry failed, so - // drop the packet. - if odcid.is_none() { - log::warn!("invalid token"); - continue; - } - - if scid.len() != hdr.dcid.len() { - log::warn!("invalid connection ID"); - continue; - } - - // Reuse the source connection ID we sent in the Retry packet, - // instead of changing it again. - let conn_id = hdr.dcid.clone(); - let local_addr = self.socket.local_addr().unwrap(); - - log::debug!("new connection: dcid={:?} scid={:?}", hdr.dcid, scid); - - let mut conn = quiche::accept(&conn_id, odcid.as_ref(), local_addr, from, &mut self.quic)?; - - // Log each session with QLOG if the ENV var is set. - if let Some(dir) = std::env::var_os("QLOGDIR") { - let id = format!("{:?}", &scid); - - let mut path = std::path::PathBuf::from(dir); - let filename = format!("server-{id}.sqlog"); - path.push(filename); - - let writer = match std::fs::File::create(&path) { - Ok(f) => std::io::BufWriter::new(f), - - Err(e) => panic!("Error creating qlog file attempted path was {:?}: {}", path, e), - }; - - conn.set_qlog( - std::boxed::Box::new(writer), - "warp-server qlog".to_string(), - format!("{} id={}", "warp-server qlog", id), - ); - } - - // Process potentially coalesced packets. - conn.recv(src, info)?; - - let user = connection::Connection { - quiche: conn, - session: None, - app: T::default(), - }; - - self.conns.insert(conn_id, user); - } - } - - pub fn app(&mut self) -> anyhow::Result<()> { - for conn in self.conns.values_mut() { - if conn.quiche.is_closed() { - continue; - } - - if let Some(session) = &mut conn.session { - if let Err(e) = conn.app.poll(&mut conn.quiche, session) { - log::debug!("app error: {:?}", e); - - // Close the connection on any application error - let reason = format!("app error: {:?}", e); - conn.quiche.close(true, 0xff, reason.as_bytes()).ok(); - } - } - } - - Ok(()) - } - - // Generate outgoing QUIC packets for all active connections and send - // them on the UDP socket, until quiche reports that there are no more - // packets to be sent. - pub fn send(&mut self) -> anyhow::Result<()> { - for conn in self.conns.values_mut() { - let conn = &mut conn.quiche; - - if let Err(e) = send_conn(&self.socket, conn) { - log::error!("{} send failed: {:?}", conn.trace_id(), e); - conn.close(false, 0x1, b"fail").ok(); - } - } - - Ok(()) - } - - pub fn cleanup(&mut self) { - // Garbage collect closed connections. - self.conns.retain(|_, ref mut c| !c.quiche.is_closed()); - } -} - -// Send any pending packets for the connection over the socket. -fn send_conn(socket: &mio::net::UdpSocket, conn: &mut quiche::Connection) -> anyhow::Result<()> { - let mut pkt = [0; MAX_DATAGRAM_SIZE]; - - loop { - let (size, info) = match conn.send(&mut pkt) { - Ok(v) => v, - Err(quiche::Error::Done) => return Ok(()), - Err(e) => return Err(e.into()), - }; - - let pkt = &pkt[..size]; - - match socket.send_to(pkt, info.to) { - Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(()), - Err(e) => return Err(e.into()), - Ok(_) => (), - } - } -} - -/// Generate a stateless retry token. -/// -/// The token includes the static string `"quiche"` followed by the IP address -/// of the client and by the original destination connection ID generated by the -/// client. -/// -/// Note that this function is only an example and doesn't do any cryptographic -/// authenticate of the token. *It should not be used in production system*. -fn mint_token(hdr: &quiche::Header, src: &std::net::SocketAddr) -> Vec { - let mut token = Vec::new(); - - token.extend_from_slice(b"quiche"); - - let addr = match src.ip() { - std::net::IpAddr::V4(a) => a.octets().to_vec(), - std::net::IpAddr::V6(a) => a.octets().to_vec(), - }; - - token.extend_from_slice(&addr); - token.extend_from_slice(&hdr.dcid); - - token -} - -/// Validates a stateless retry token. -/// -/// This checks that the ticket includes the `"quiche"` static string, and that -/// the client IP address matches the address stored in the ticket. -/// -/// Note that this function is only an example and doesn't do any cryptographic -/// authenticate of the token. *It should not be used in production system*. -fn validate_token<'a>(src: &std::net::SocketAddr, token: &'a [u8]) -> Option> { - if token.len() < 6 { - return None; - } - - if &token[..6] != b"quiche" { - return None; - } - - let token = &token[6..]; - - let addr = match src.ip() { - std::net::IpAddr::V4(a) => a.octets().to_vec(), - std::net::IpAddr::V6(a) => a.octets().to_vec(), - }; - - if token.len() < addr.len() || &token[..addr.len()] != addr.as_slice() { - return None; - } - - Some(quiche::ConnectionId::from_ref(&token[addr.len()..])) -} diff --git a/src/transport/streams.rs b/src/transport/streams.rs deleted file mode 100644 index 862da0a..0000000 --- a/src/transport/streams.rs +++ /dev/null @@ -1,136 +0,0 @@ -use std::collections::VecDeque; - -use anyhow; -use quiche; - -#[derive(Default)] -pub struct Streams { - ordered: Vec, -} - -struct Stream { - id: u64, - order: u64, - - buffer: VecDeque, - fin: bool, -} - -impl Streams { - // Write the data to the given stream, buffering it if needed. - pub fn send(&mut self, conn: &mut quiche::Connection, id: u64, buf: &[u8], fin: bool) -> anyhow::Result<()> { - if buf.is_empty() && !fin { - return Ok(()); - } - - // Get the index of the stream, or add it to the list of streams. - let pos = self.ordered.iter().position(|s| s.id == id).unwrap_or_else(|| { - // Create a new stream - let stream = Stream { - id, - buffer: VecDeque::new(), - fin: false, - order: 0, // Default to highest priority until send_order is called. - }; - - self.insert(conn, stream) - }); - - let stream = &mut self.ordered[pos]; - - // Check if we've already closed the stream, just in case. - if stream.fin && !buf.is_empty() { - anyhow::bail!("stream is already finished"); - } - - // If there's no data buffered, try to write it immediately. - let size = if stream.buffer.is_empty() { - match conn.stream_send(id, buf, fin) { - Ok(size) => size, - Err(quiche::Error::Done) => 0, - Err(e) => anyhow::bail!(e), - } - } else { - 0 - }; - - if size < buf.len() { - // Short write, save the rest for later. - stream.buffer.extend(&buf[size..]); - } - - stream.fin |= fin; - - Ok(()) - } - - // Flush any pending stream data. - pub fn poll(&mut self, conn: &mut quiche::Connection) { - self.ordered.retain_mut(|s| s.poll(conn).is_ok()); - } - - // Set the send order of the stream. - pub fn send_order(&mut self, conn: &mut quiche::Connection, id: u64, order: u64) { - let mut stream = match self.ordered.iter().position(|s| s.id == id) { - // Remove the stream from the existing list. - Some(pos) => self.ordered.remove(pos), - - // This is a new stream, insert it into the list. - None => Stream { - id, - buffer: VecDeque::new(), - fin: false, - order, - }, - }; - - stream.order = order; - - self.insert(conn, stream); - } - - fn insert(&mut self, conn: &mut quiche::Connection, stream: Stream) -> usize { - // Look for the position to insert the stream. - let pos = match self.ordered.binary_search_by_key(&stream.order, |s| s.order) { - Ok(pos) | Err(pos) => pos, - }; - - self.ordered.insert(pos, stream); - - // Reprioritize all later streams. - // TODO we can avoid this if stream_priorty takes a u64 - for (i, stream) in self.ordered[pos..].iter().enumerate() { - _ = conn.stream_priority(stream.id, (pos + i) as u8, true); - } - - pos - } -} - -impl Stream { - fn poll(&mut self, conn: &mut quiche::Connection) -> quiche::Result<()> { - // Keep reading from the buffer until it's empty. - while !self.buffer.is_empty() { - // VecDeque is a ring buffer, so we can't write the whole thing at once. - let parts = self.buffer.as_slices(); - - let size = conn.stream_send(self.id, parts.0, false)?; - if size == 0 { - // No more space available for this stream. - return Ok(()); - } - - // Remove the bytes that were written. - self.buffer.drain(..size); - } - - if self.fin { - // Write the stream done signal. - conn.stream_send(self.id, &[], true)?; - - Err(quiche::Error::Done) - } else { - Ok(()) - } - } -}