Migrate to quinn and async Rust (#21)

I miss quiche, but it was a pain to do anything asynchronous. MoQ is a
pub/sub protocol so it's very important to support subscribers
joining/leaving/stalling. The API is also just significantly better
since quinn doesn't restrict itself to C bindings, which I'm sure will
come back to haunt me when we want OBS support.
This commit is contained in:
kixelated 2023-06-08 00:01:34 -07:00 committed by GitHub
parent 7cfa5faca2
commit c88f0b045a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 826 additions and 1053 deletions

392
Cargo.lock generated
View File

@ -11,15 +11,6 @@ dependencies = [
"memchr",
]
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]]
name = "anstream"
version = "0.3.2"
@ -149,19 +140,6 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e3c5919066adf22df73762e50cffcde3a758f2a848b113b586d1f86728b673b"
dependencies = [
"iana-time-zone",
"num-integer",
"num-traits",
"serde",
"winapi",
]
[[package]]
name = "clap"
version = "4.3.0"
@ -204,27 +182,12 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b"
[[package]]
name = "cmake"
version = "0.1.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130"
dependencies = [
"cc",
]
[[package]]
name = "colorchoice"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
[[package]]
name = "core-foundation-sys"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
[[package]]
name = "cpufeatures"
version = "0.2.7"
@ -244,41 +207,6 @@ dependencies = [
"typenum",
]
[[package]]
name = "darling"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0558d22a7b463ed0241e993f76f09f30b126687447751a8638587b864e4b3944"
dependencies = [
"darling_core",
"darling_macro",
]
[[package]]
name = "darling_core"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab8bfa2e259f8ee1ce5e97824a3c55ec4404a0d772ca7fa96bf19f0752a046eb"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim",
"syn",
]
[[package]]
name = "darling_macro"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29a358ff9f12ec09c3e61fef9b5a9902623a695a46a917b07f269bff1445611a"
dependencies = [
"darling_core",
"quote",
"syn",
]
[[package]]
name = "digest"
version = "0.10.7"
@ -332,6 +260,15 @@ dependencies = [
"libc",
]
[[package]]
name = "fastrand"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be"
dependencies = [
"instant",
]
[[package]]
name = "fnv"
version = "1.0.7"
@ -347,6 +284,21 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.28"
@ -363,6 +315,34 @@ version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c"
[[package]]
name = "futures-executor"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964"
[[package]]
name = "futures-macro"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.28"
@ -381,9 +361,13 @@ version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
@ -429,6 +413,48 @@ dependencies = [
"tracing",
]
[[package]]
name = "h3"
version = "0.0.2"
source = "git+https://github.com/security-union/h3?branch=add-webtransport#db5c723f653911a476bfd8ffcfebf0f8f2eb980d"
dependencies = [
"bytes",
"fastrand",
"futures-util",
"http",
"pin-project-lite",
"tokio",
"tracing",
]
[[package]]
name = "h3-quinn"
version = "0.0.2"
source = "git+https://github.com/security-union/h3?branch=add-webtransport#db5c723f653911a476bfd8ffcfebf0f8f2eb980d"
dependencies = [
"bytes",
"futures",
"h3",
"quinn",
"quinn-proto",
"tokio",
"tokio-util",
]
[[package]]
name = "h3-webtransport"
version = "0.1.0"
source = "git+https://github.com/security-union/h3?branch=add-webtransport#db5c723f653911a476bfd8ffcfebf0f8f2eb980d"
dependencies = [
"bytes",
"futures-util",
"h3",
"http",
"pin-project-lite",
"tokio",
"tracing",
]
[[package]]
name = "hashbrown"
version = "0.12.3"
@ -553,42 +579,13 @@ dependencies = [
"httpdate",
"itoa",
"pin-project-lite",
"socket2",
"socket2 0.4.9",
"tokio",
"tower-service",
"tracing",
"want",
]
[[package]]
name = "iana-time-zone"
version = "0.1.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0722cd7114b7de04316e7ea5456a0bbb20e4adb46fd27a3697adb812cff0f37c"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"windows",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "idna"
version = "0.3.0"
@ -607,7 +604,15 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
dependencies = [
"autocfg",
"hashbrown",
"serde",
]
[[package]]
name = "instant"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c"
dependencies = [
"cfg-if",
]
[[package]]
@ -648,24 +653,12 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "libc"
version = "0.2.144"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b00cc1c228a6782d0f076e7b232802e0c5689d41bb5df366f2a6b6621cfdfe1"
[[package]]
name = "libm"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4"
[[package]]
name = "linux-raw-sys"
version = "0.3.8"
@ -730,14 +723,18 @@ name = "moq"
version = "0.1.0"
dependencies = [
"anyhow",
"bytes",
"clap",
"env_logger",
"futures",
"h3",
"h3-quinn",
"h3-webtransport",
"hex",
"http",
"log",
"mio",
"mp4",
"quiche",
"quinn",
"ring",
"rustls 0.21.1",
"rustls-pemfile",
@ -832,11 +829,6 @@ dependencies = [
"libc",
]
[[package]]
name = "octets"
version = "0.2.0"
source = "git+https://github.com/kixelated/quiche-webtransport.git?branch=master#007a25b35b9509d673466fed8ddc73fd8d9b4184"
[[package]]
name = "once_cell"
version = "1.17.1"
@ -920,33 +912,51 @@ dependencies = [
]
[[package]]
name = "qlog"
version = "0.9.0"
source = "git+https://github.com/kixelated/quiche-webtransport.git?branch=master#007a25b35b9509d673466fed8ddc73fd8d9b4184"
name = "quinn"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21252f1c0fc131f1b69182db8f34837e8a69737b8251dff75636a9be0518c324"
dependencies = [
"serde",
"serde_derive",
"serde_json",
"serde_with",
"smallvec",
"bytes",
"futures-io",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls 0.21.1",
"thiserror",
"tokio",
"tracing",
]
[[package]]
name = "quiche"
version = "0.17.1"
source = "git+https://github.com/kixelated/quiche-webtransport.git?branch=master#007a25b35b9509d673466fed8ddc73fd8d9b4184"
name = "quinn-proto"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85af4ed6ee5a89f26a26086e9089a6643650544c025158449a3626ebf72884b3"
dependencies = [
"cmake",
"lazy_static",
"libc",
"libm",
"log",
"octets",
"qlog",
"bytes",
"rand",
"ring",
"rustc-hash",
"rustls 0.21.1",
"slab",
"smallvec",
"winapi",
"thiserror",
"tinyvec",
"tracing",
]
[[package]]
name = "quinn-udp"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6df19e284d93757a9fb91d63672f7741b129246a669db09d1c0063071debc0c0"
dependencies = [
"bytes",
"libc",
"socket2 0.5.3",
"tracing",
"windows-sys 0.48.0",
]
[[package]]
@ -1029,6 +1039,12 @@ dependencies = [
"winapi",
]
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustix"
version = "0.37.19"
@ -1140,7 +1156,6 @@ version = "1.0.96"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
dependencies = [
"indexmap",
"itoa",
"ryu",
"serde",
@ -1158,34 +1173,6 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_with"
version = "2.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07ff71d2c147a7b57362cead5e22f772cd52f6ab31cfcd9edcd7f6aeb2a0afbe"
dependencies = [
"base64 0.13.1",
"chrono",
"hex",
"indexmap",
"serde",
"serde_json",
"serde_with_macros",
"time",
]
[[package]]
name = "serde_with_macros"
version = "2.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "881b6f881b17d13214e5d494c939ebab463d01264ce1811e9d4ac3a882e7695f"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "sha1"
version = "0.10.5"
@ -1220,9 +1207,6 @@ name = "smallvec"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
dependencies = [
"serde",
]
[[package]]
name = "socket2"
@ -1234,6 +1218,16 @@ dependencies = [
"winapi",
]
[[package]]
name = "socket2"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877"
dependencies = [
"libc",
"windows-sys 0.48.0",
]
[[package]]
name = "spin"
version = "0.5.2"
@ -1292,33 +1286,6 @@ dependencies = [
"syn",
]
[[package]]
name = "time"
version = "0.3.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f3403384eaacbca9923fa06940178ac13e4edb725486d70e8e15881d0c836cc"
dependencies = [
"itoa",
"serde",
"time-core",
"time-macros",
]
[[package]]
name = "time-core"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb"
[[package]]
name = "time-macros"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "372950940a5f07bf38dbe211d7283c9e6d7327df53794992d293e534c733d09b"
dependencies = [
"time-core",
]
[[package]]
name = "tinyvec"
version = "1.6.0"
@ -1348,7 +1315,7 @@ dependencies = [
"parking_lot",
"pin-project-lite",
"signal-hook-registry",
"socket2",
"socket2 0.4.9",
"tokio-macros",
"windows-sys 0.48.0",
]
@ -1427,9 +1394,21 @@ dependencies = [
"cfg-if",
"log",
"pin-project-lite",
"tracing-attributes",
"tracing-core",
]
[[package]]
name = "tracing-attributes"
version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tracing-core"
version = "0.1.31"
@ -1688,15 +1667,6 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f"
dependencies = [
"windows-targets 0.48.0",
]
[[package]]
name = "windows-sys"
version = "0.45.0"

View File

@ -6,22 +6,36 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
# Fork of quiche until they add WebTransport support.
quiche = { git = "https://github.com/kixelated/quiche-webtransport.git", branch = "master", features = [ "qlog" ] }
clap = { version = "4.0", features = [ "derive" ] }
log = { version = "0.4", features = ["std"] }
mio = { version = "0.8", features = ["net", "os-poll"] }
env_logger = "0.9.3"
# Fork of h3 with WebTransport support
h3 = { git = "https://github.com/security-union/h3", branch = "add-webtransport" }
h3-quinn = { git = "https://github.com/security-union/h3", branch = "add-webtransport" }
h3-webtransport = { git = "https://github.com/security-union/h3", branch = "add-webtransport" }
quinn = { version = "0.10", default-features = false, features = ["runtime-tokio", "tls-rustls", "ring"] }
# Crypto dependencies
ring = "0.16"
anyhow = "1.0.70"
rustls = { version = "0.21", features = ["dangerous_configuration"] }
rustls-pemfile = "1.0.2"
# Async stuff
tokio = { version = "1.27", features = ["full"] }
futures = "0.3"
# Media
mp4 = "0.13.0"
# Encoding
bytes = "1"
serde = "1.0.160"
serde_json = "1.0"
# Required to serve the fingerprint over HTTPS
tokio = { version = "1", features = ["full"] }
# Web server to serve the fingerprint
http = "0.2"
warp = { version = "0.3.3", features = ["tls"] }
rustls = "0.21"
rustls-pemfile = "1.0.2"
hex = "0.4.3"
http = "0.2.9"
# Logging and utility
clap = { version = "4.0", features = [ "derive" ] }
log = { version = "0.4", features = ["std"] }
env_logger = "0.9.3"
anyhow = "1.0.70"

BIN
src/.DS_Store vendored

Binary file not shown.

View File

@ -1,6 +1,6 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Default)]
pub struct Message {
pub init: Option<Init>,
pub segment: Option<Segment>,
@ -16,10 +16,7 @@ pub struct Segment {
impl Message {
pub fn new() -> Self {
Message {
init: None,
segment: None,
}
Default::default()
}
pub fn serialize(&self) -> anyhow::Result<Vec<u8>> {

8
src/app/mod.rs Normal file
View File

@ -0,0 +1,8 @@
//! Application layer: accepts WebTransport sessions and serves broadcasts.
mod message;
mod server;
mod session;
pub use server::{Server, ServerConfig};
// Reduce the amount of typing: the concrete WebTransport session type used
// throughout this module (h3 over quinn, with Bytes as the buffer type).
type WebTransportSession = h3_webtransport::server::WebTransportSession<h3_quinn::Connection, bytes::Bytes>;

121
src/app/server.rs Normal file
View File

@ -0,0 +1,121 @@
use super::session::Session;
use crate::media;
use std::{fs, io, net, path, sync, time};
use super::WebTransportSession;
use anyhow::Context;
/// A QUIC/WebTransport server that serves a single media broadcast to
/// every accepted session.
pub struct Server {
    // The QUIC server, yielding new connections and sessions.
    server: quinn::Endpoint,
    // The media source; cloned for each accepted session.
    broadcast: media::Broadcast,
}
/// Configuration required to construct a [`Server`].
pub struct ServerConfig {
    // Address to bind the QUIC endpoint to.
    pub addr: net::SocketAddr,
    // Path to the PEM certificate chain.
    pub cert: path::PathBuf,
    // Path to the PEM (PKCS#8) private key.
    pub key: path::PathBuf,
    // The media broadcast to serve to every session.
    pub broadcast: media::Broadcast,
}
impl Server {
    /// Create a new server: load the TLS certificate and key from disk,
    /// build the rustls + quinn configuration, and bind a QUIC endpoint
    /// to `config.addr`.
    pub fn new(config: ServerConfig) -> anyhow::Result<Self> {
        // Read the PEM certificate chain
        let certs = fs::File::open(config.cert).context("failed to open cert file")?;
        let mut certs = io::BufReader::new(certs);
        let certs = rustls_pemfile::certs(&mut certs)?
            .into_iter()
            .map(rustls::Certificate)
            .collect();

        // Read the PEM private key
        let keys = fs::File::open(config.key).context("failed to open key file")?;
        let mut keys = io::BufReader::new(keys);
        let mut keys = rustls_pemfile::pkcs8_private_keys(&mut keys)?;

        // Exactly one PKCS#8 key is expected; more than one would be ambiguous.
        anyhow::ensure!(keys.len() == 1, "expected a single key");
        let key = rustls::PrivateKey(keys.remove(0));

        // QUIC requires TLS 1.3, so only that protocol version is enabled.
        let mut tls_config = rustls::ServerConfig::builder()
            .with_safe_default_cipher_suites()
            .with_safe_default_kx_groups()
            .with_protocol_versions(&[&rustls::version::TLS13])
            .unwrap()
            .with_no_client_auth()
            .with_single_cert(certs, key)?;

        // Allow the maximum amount of early (0-RTT) data.
        tls_config.max_early_data_size = u32::MAX;

        // Advertise HTTP/3, including older draft versions for client compatibility.
        let alpn: Vec<Vec<u8>> = vec![
            b"h3".to_vec(),
            b"h3-32".to_vec(),
            b"h3-31".to_vec(),
            b"h3-30".to_vec(),
            b"h3-29".to_vec(),
        ];
        tls_config.alpn_protocols = alpn;

        let mut server_config = quinn::ServerConfig::with_crypto(sync::Arc::new(tls_config));

        // Send QUIC keep-alives every 2s so idle sessions aren't torn down.
        let mut transport_config = quinn::TransportConfig::default();
        transport_config.keep_alive_interval(Some(time::Duration::from_secs(2)));
        server_config.transport = sync::Arc::new(transport_config);

        let server = quinn::Endpoint::server(server_config, config.addr)?;
        let broadcast = config.broadcast;

        Ok(Self { server, broadcast })
    }

    /// Accept QUIC connections forever, spawning one task per connection that
    /// performs the WebTransport handshake and then serves the broadcast.
    pub async fn run(&mut self) -> anyhow::Result<()> {
        loop {
            // NOTE(review): Endpoint::accept() yields None once the endpoint is
            // closed; that case currently surfaces as an error here rather than
            // a clean shutdown — confirm this is intended.
            let conn = self.server.accept().await.context("failed to accept connection")?;

            // Each session gets its own clone of the broadcast subscriber state.
            let broadcast = self.broadcast.clone();

            // NOTE(review): the JoinHandle is dropped, so per-session errors are
            // silently discarded; consider logging them.
            tokio::spawn(async move {
                let session = Self::accept_session(conn).await.context("failed to accept session")?;

                // Use a wrapper to run the session.
                let session = Session::new(session);
                session.serve_broadcast(broadcast).await
            });
        }
    }

    /// Perform the HTTP/3 + WebTransport handshake on a raw QUIC connection,
    /// returning the established WebTransport session.
    async fn accept_session(conn: quinn::Connecting) -> anyhow::Result<WebTransportSession> {
        let conn = conn.await.context("failed to accept h3 connection")?;

        let mut conn = h3::server::builder()
            .enable_webtransport(true)
            .enable_connect(true)
            .enable_datagram(true)
            .max_webtransport_sessions(1)
            .send_grease(true)
            .build(h3_quinn::Connection::new(conn))
            .await
            .context("failed to create h3 server")?;

        // Wait for the client's extended CONNECT request.
        let (req, stream) = conn
            .accept()
            .await
            .context("failed to accept h3 session")?
            .context("failed to accept h3 request")?;

        let ext = req.extensions();

        // WebTransport sessions are negotiated via an extended CONNECT request
        // carrying the WebTransport protocol extension.
        anyhow::ensure!(req.method() == http::Method::CONNECT, "expected CONNECT request");
        anyhow::ensure!(
            ext.get::<h3::ext::Protocol>() == Some(&h3::ext::Protocol::WEB_TRANSPORT),
            "expected WebTransport CONNECT"
        );

        let session = WebTransportSession::accept(req, stream, conn)
            .await
            .context("failed to accept WebTransport session")?;

        Ok(session)
    }
}

115
src/app/session.rs Normal file
View File

@ -0,0 +1,115 @@
use crate::media;
use anyhow::Context;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::task::JoinSet;
use super::WebTransportSession;
use super::message;
/// A single client session; cheap to clone so it can be shared across the
/// per-track and per-segment tasks.
#[derive(Clone)]
pub struct Session {
    // The underlying transport session, shared via Arc across spawned tasks.
    transport: Arc<WebTransportSession>,
}
impl Session {
    /// Wrap an established WebTransport session so it can be shared (via `Arc`)
    /// across the per-track and per-segment tasks.
    pub fn new(transport: WebTransportSession) -> Self {
        let transport = Arc::new(transport);
        Self { transport }
    }

    /// Serve every track in the broadcast, spawning one task per track, until
    /// the broadcast ends and all track tasks have finished.
    ///
    /// Returns the first error produced by any track task.
    pub async fn serve_broadcast(&self, mut broadcast: media::Broadcast) -> anyhow::Result<()> {
        let mut tasks = JoinSet::new();
        let mut done = false;
        loop {
            tokio::select! {
                // Accept new tracks added to the broadcast.
                track = broadcast.tracks.next(), if !done => {
                    match track {
                        Some(track) => {
                            let session = self.clone();
                            tasks.spawn(async move {
                                session.serve_track(track).await
                            });
                        },
                        // The producer closed; stop polling for new tracks.
                        None => done = true,
                    }
                },
                // Poll any pending tracks until they exit.
                res = tasks.join_next(), if !tasks.is_empty() => {
                    let res = res.context("no tracks running")?;
                    let res = res.context("failed to run track")?;
                    res.context("failed to serve track")?;
                },
                // No more tracks arriving and none running: we're done.
                else => return Ok(()),
            }
        }
    }

    /// Serve every segment in a track, spawning one task per segment so a slow
    /// segment doesn't stall the rest.
    pub async fn serve_track(&self, mut track: media::Track) -> anyhow::Result<()> {
        let mut tasks = JoinSet::new();
        let mut done = false;
        loop {
            tokio::select! {
                // Accept new segments added to the track.
                segment = track.segments.next(), if !done => {
                    match segment {
                        Some(segment) => {
                            let track = track.clone();
                            let session = self.clone();
                            tasks.spawn(async move {
                                session.serve_segment(track, segment).await
                            });
                        },
                        // The producer closed; stop polling for new segments.
                        None => done = true,
                    }
                },
                // Poll any pending segments until they exit.
                res = tasks.join_next(), if !tasks.is_empty() => {
                    let res = res.context("no tasks running")?;
                    let res = res.context("failed to run segment")?;
                    res.context("failed to serve segment")?;
                },
                // No more segments arriving and none running: we're done.
                else => return Ok(()),
            }
        }
    }

    /// Send a single segment over its own unidirectional stream: a JSON header
    /// identifying the track, followed by the raw fragment data.
    pub async fn serve_segment(&self, track: media::Track, mut segment: media::Segment) -> anyhow::Result<()> {
        let mut stream = self.transport.open_uni(self.transport.session_id()).await?;

        // TODO support priority
        // stream.set_priority(0);

        // Encode a JSON header indicating this is a new segment.
        let mut message = message::Message::new();

        // TODO combine init and segment messages into one.
        // Track ID 0xff is reserved for the init track (see Source::create_init_track).
        if track.id == 0xff {
            message.init = Some(message::Init {});
        } else {
            message.segment = Some(message::Segment { track_id: track.id });
        }

        // Write the JSON header.
        let data = message.serialize()?;
        stream.write_all(data.as_slice()).await?;

        // Write each fragment as it becomes available.
        while let Some(fragment) = segment.fragments.next().await {
            stream.write_all(fragment.as_slice()).await?;
        }

        // NOTE: stream is automatically closed when dropped
        Ok(())
    }
}

View File

@ -1,3 +1,2 @@
pub mod app;
pub mod media;
pub mod session;
pub mod transport;

View File

@ -1,9 +1,7 @@
use std::io::BufReader;
use std::net::SocketAddr;
use std::{fs::File, sync::Arc};
use moq::{session, transport};
use moq::{app, media};
use std::{fs, io, net, path, sync};
use anyhow::Context;
use clap::Parser;
use ring::digest::{digest, SHA256};
use warp::Filter;
@ -13,54 +11,57 @@ use warp::Filter;
struct Cli {
/// Listen on this address
#[arg(short, long, default_value = "[::]:4443")]
addr: String,
addr: net::SocketAddr,
/// Use the certificate file at this path
#[arg(short, long, default_value = "cert/localhost.crt")]
cert: String,
cert: path::PathBuf,
/// Use the private key at this path
#[arg(short, long, default_value = "cert/localhost.key")]
key: String,
key: path::PathBuf,
/// Use the media file at this path
#[arg(short, long, default_value = "media/fragmented.mp4")]
media: String,
media: path::PathBuf,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
env_logger::init();
let moq_args = Cli::parse();
let http_args = moq_args.clone();
let args = Cli::parse();
// TODO return result instead of panicking
tokio::task::spawn(async move { run_transport(moq_args).unwrap() });
// Create a web server to serve the fingerprint
let serve = serve_http(args.clone());
run_http(http_args).await
}
// Create a fake media source from disk.
let mut media = media::Source::new(args.media).context("failed to open fragmented.mp4")?;
// Run the WebTransport server using quiche.
fn run_transport(args: Cli) -> anyhow::Result<()> {
let server_config = transport::Config {
// Create a server to actually serve the media
let config = app::ServerConfig {
addr: args.addr,
cert: args.cert,
key: args.key,
broadcast: media.broadcast(),
};
let mut server = transport::Server::<session::Session>::new(server_config).unwrap();
server.run()
let mut server = app::Server::new(config).context("failed to create server")?;
// Run all of the above
tokio::select! {
res = server.run() => res.context("failed to run server"),
res = media.run() => res.context("failed to run media source"),
res = serve => res.context("failed to run HTTP server"),
}
}
// Run a HTTP server using Warp
// TODO remove this when Chrome adds support for self-signed certificates using WebTransport
async fn run_http(args: Cli) -> anyhow::Result<()> {
let addr: SocketAddr = args.addr.parse()?;
async fn serve_http(args: Cli) -> anyhow::Result<()> {
// Read the PEM certificate file
let crt = File::open(&args.cert)?;
let mut crt = BufReader::new(crt);
let crt = fs::File::open(&args.cert)?;
let mut crt = io::BufReader::new(crt);
// Parse the DER certificate
let certs = rustls_pemfile::certs(&mut crt)?;
@ -69,7 +70,7 @@ async fn run_http(args: Cli) -> anyhow::Result<()> {
// Compute the SHA-256 digest
let fingerprint = digest(&SHA256, cert.as_ref());
let fingerprint = hex::encode(fingerprint.as_ref());
let fingerprint = Arc::new(fingerprint);
let fingerprint = sync::Arc::new(fingerprint);
let cors = warp::cors().allow_any_origin();
@ -83,7 +84,7 @@ async fn run_http(args: Cli) -> anyhow::Result<()> {
.tls()
.cert_path(args.cert)
.key_path(args.key)
.run(addr)
.run(args.addr)
.await;
Ok(())

View File

@ -1,3 +1,8 @@
mod source;
pub use source::Source;
pub use source::{Fragment, Source};
mod model;
pub use model::*;
mod watch;
use watch::{Producer, Subscriber};

28
src/media/model.rs Normal file
View File

@ -0,0 +1,28 @@
use super::Subscriber;
use std::{sync, time};
/// A live broadcast: a subscribable, growing collection of tracks.
#[derive(Clone)]
pub struct Broadcast {
    // New tracks appear here as the producer publishes them.
    pub tracks: Subscriber<Track>,
}
/// A single media track within a broadcast.
#[derive(Clone)]
pub struct Track {
    // The track ID as stored in the MP4
    pub id: u32,
    // A list of segments, which are independently decodable.
    pub segments: Subscriber<Segment>,
}
/// An independently decodable run of media, delivered as ordered fragments.
#[derive(Clone)]
pub struct Segment {
    // The timestamp of the segment, relative to the start of the broadcast.
    pub timestamp: time::Duration,
    // A list of fragments that make up the segment.
    pub fragments: Subscriber<Fragment>,
}
/// A raw chunk of MP4 data.
///
/// Wrapped in `Arc` to avoid cloning the entire payload for each subscriber.
pub type Fragment = sync::Arc<Vec<u8>>;

View File

@ -1,51 +1,32 @@
use std::collections::VecDeque;
use std::io::Read;
use std::{fs, io, time};
use std::{fs, io, path, time};
use anyhow;
use mp4;
use mp4::ReadBox;
use anyhow::Context;
use std::collections::HashMap;
use super::{Broadcast, Fragment, Producer, Segment, Track};
pub struct Source {
// We read the file once, in order, and don't seek backwards.
reader: io::BufReader<fs::File>,
// The timestamp when the broadcast "started", so we can sleep to simulate a live stream.
start: time::Instant,
// The tracks we're producing
broadcast: Broadcast,
// The initialization payload; ftyp + moov boxes.
pub init: Vec<u8>,
// The parsed moov box.
moov: mp4::MoovBox,
// Any fragments parsed and ready to be returned by next().
fragments: VecDeque<Fragment>,
}
pub struct Fragment {
// The track ID for the fragment.
pub track_id: u32,
// The data of the fragment.
pub data: Vec<u8>,
// Whether this fragment is a keyframe.
pub keyframe: bool,
// The number of samples that make up a second (ex. ms = 1000)
pub timescale: u64,
// The timestamp of the fragment, in timescale units, to simulate a live stream.
pub timestamp: u64,
// The tracks we're producing.
tracks: HashMap<u32, SourceTrack>,
}
impl Source {
pub fn new(path: &str) -> anyhow::Result<Self> {
pub fn new(path: path::PathBuf) -> anyhow::Result<Self> {
let f = fs::File::open(path)?;
let mut reader = io::BufReader::new(f);
let start = time::Instant::now();
let ftyp = read_atom(&mut reader)?;
anyhow::ensure!(&ftyp[4..8] == b"ftyp", "expected ftyp atom");
@ -64,28 +45,72 @@ impl Source {
// Parse the moov box so we can detect the timescales for each track.
let moov = mp4::MoovBox::read_box(&mut moov_reader, moov_header.size)?;
// Create a producer to populate the tracks.
let mut tracks = Producer::<Track>::new();
let broadcast = Broadcast {
tracks: tracks.subscribe(),
};
// Create the init track
let init_track = Self::create_init_track(init);
tracks.push(init_track);
// Create a map with the current segment for each track.
// NOTE: We don't add the init track to this, since it's not part of the MP4.
let mut lookup = HashMap::new();
for trak in &moov.traks {
let track_id = trak.tkhd.track_id;
anyhow::ensure!(track_id != 0xff, "track ID 0xff is reserved");
let timescale = track_timescale(&moov, track_id);
let segments = Producer::<Segment>::new();
tracks.push(Track {
id: track_id,
segments: segments.subscribe(),
});
// Store the track publisher in a map so we can update it later.
let track = SourceTrack::new(segments, timescale);
lookup.insert(track_id, track);
}
Ok(Self {
reader,
start,
init,
moov,
fragments: VecDeque::new(),
broadcast,
tracks: lookup,
})
}
pub fn fragment(&mut self) -> anyhow::Result<Option<Fragment>> {
if self.fragments.is_empty() {
self.parse()?;
};
// Create an init track
fn create_init_track(raw: Vec<u8>) -> Track {
// TODO support static producers
let mut fragments = Producer::<Fragment>::new();
let mut segments = Producer::<Segment>::new();
if self.timeout().is_some() {
return Ok(None);
fragments.push(raw.into());
segments.push(Segment {
fragments: fragments.subscribe(),
timestamp: time::Duration::ZERO,
});
Track {
id: 0xff,
segments: segments.subscribe(),
}
Ok(self.fragments.pop_front())
}
fn parse(&mut self) -> anyhow::Result<()> {
pub async fn run(&mut self) -> anyhow::Result<()> {
// The timestamp when the broadcast "started", so we can sleep to simulate a live stream.
let start = tokio::time::Instant::now();
// The ID of the last moof header.
let mut track_id = None;
loop {
let atom = read_atom(&mut self.reader)?;
@ -93,46 +118,33 @@ impl Source {
let header = mp4::BoxHeader::read(&mut reader)?;
match header.name {
mp4::BoxType::FtypBox | mp4::BoxType::MoovBox => {
anyhow::bail!("must call init first")
}
mp4::BoxType::MoofBox => {
let moof = mp4::MoofBox::read_box(&mut reader, header.size)?;
let moof = mp4::MoofBox::read_box(&mut reader, header.size).context("failed to read MP4")?;
if moof.trafs.len() != 1 {
// We can't split the mdat atom, so this is impossible to support
anyhow::bail!("multiple tracks per moof atom")
}
// Process the moof.
let fragment = SourceFragment::new(moof)?;
let track_id = moof.trafs[0].tfhd.track_id;
let timestamp = sample_timestamp(&moof).expect("couldn't find timestamp");
// Get the track for this moof.
let track = self.tracks.get_mut(&fragment.track).context("failed to find track")?;
// Detect if this is a keyframe.
let keyframe = sample_keyframe(&moof);
// Sleep until we should publish this sample.
let timestamp = time::Duration::from_millis(1000 * fragment.timestamp / track.timescale);
tokio::time::sleep_until(start + timestamp).await;
let timescale = track_timescale(&self.moov, track_id);
// Save the track ID for the next iteration, which must be a mdat.
anyhow::ensure!(track_id.is_none(), "multiple moof atoms");
track_id.replace(fragment.track);
self.fragments.push_back(Fragment {
track_id,
data: atom,
keyframe,
timescale,
timestamp,
})
// Publish the moof header, creating a new segment if it's a keyframe.
track.header(atom, fragment).context("failed to publish moof")?;
}
mp4::BoxType::MdatBox => {
let moof = self.fragments.back().expect("no atom before mdat");
// Get the track ID from the previous moof.
let track_id = track_id.take().context("missing moof")?;
let track = self.tracks.get_mut(&track_id).context("failed to find track")?;
self.fragments.push_back(Fragment {
track_id: moof.track_id,
data: atom,
keyframe: false,
timescale: moof.timescale,
timestamp: moof.timestamp,
});
// We have some media data, return so we can start sending it.
return Ok(());
// Publish the mdat atom.
track.data(atom).context("failed to publish mdat")?;
}
_ => {
// Skip unknown atoms
@ -141,19 +153,108 @@ impl Source {
}
}
// Simulate a live stream by sleeping until the next timestamp in the media.
pub fn timeout(&self) -> Option<time::Duration> {
let next = self.fragments.front()?;
pub fn broadcast(&self) -> Broadcast {
self.broadcast.clone()
}
}
let delay = time::Duration::from_millis(1000 * next.timestamp / next.timescale);
let elapsed = self.start.elapsed();
struct SourceTrack {
// The track we're producing
segments: Producer<Segment>,
delay.checked_sub(elapsed)
// The current segment's fragments
fragments: Option<Producer<Fragment>>,
// The number of units per second.
timescale: u64,
}
impl SourceTrack {
    fn new(segments: Producer<Segment>, timescale: u64) -> Self {
        Self {
            segments,
            fragments: None,
            timescale,
        }
    }

    /// Publish a moof header, rolling over to a fresh segment on keyframes.
    pub fn header(&mut self, raw: Vec<u8>, fragment: SourceFragment) -> anyhow::Result<()> {
        // A keyframe starts a new segment; dropping the old producer closes it
        // for any subscribers still reading.
        if fragment.keyframe {
            self.fragments = None;
        }

        if self.fragments.is_none() {
            // Timestamp of the new segment, converted to a wall-clock duration.
            let timestamp = fragment.timestamp(self.timescale);

            // Create the fragment producer and announce the segment to subscribers.
            let producer = Producer::<Fragment>::new();
            self.segments.push(Segment {
                timestamp,
                fragments: producer.subscribe(),
            });

            // Expire any segments that are more than 10s older than this one.
            let expires = timestamp.saturating_sub(time::Duration::from_secs(10));
            self.segments.drain(|segment| segment.timestamp < expires);

            self.fragments = Some(producer);
        }

        // Push the raw atom into the current segment (guaranteed Some above).
        self.fragments
            .as_mut()
            .expect("segment producer was just created")
            .push(raw.into());

        Ok(())
    }

    /// Publish a mdat payload into the current segment.
    pub fn data(&mut self, raw: Vec<u8>) -> anyhow::Result<()> {
        let fragments = self.fragments.as_mut().context("missing keyframe")?;
        fragments.push(raw.into());
        Ok(())
    }
}
/// Timing and keyframe metadata extracted from a single moof atom.
struct SourceFragment {
    // The track for this fragment.
    track: u32,
    // The timestamp of the first sample in this fragment, in timescale units.
    timestamp: u64,
    // True if this fragment is a keyframe.
    keyframe: bool,
}
impl SourceFragment {
    /// Parse a moof atom into the timing/keyframe info needed for pacing.
    ///
    /// # Errors
    /// Fails when the moof contains multiple tracks or no sample timestamp.
    fn new(moof: mp4::MoofBox) -> anyhow::Result<Self> {
        // We can't split the mdat atom, so this is impossible to support
        anyhow::ensure!(moof.trafs.len() == 1, "multiple tracks per moof atom");
        let track = moof.trafs[0].tfhd.track_id;

        // Parse the moof to get some timing information to sleep.
        // A missing timestamp is malformed *input*, not a programming bug,
        // so propagate an error instead of panicking like the old `expect`.
        let timestamp = sample_timestamp(&moof).context("couldn't find timestamp")?;

        // Detect if we should start a new segment.
        let keyframe = sample_keyframe(&moof);

        Ok(Self {
            track,
            timestamp,
            keyframe,
        })
    }

    // Convert from timescale units to a duration.
    // NOTE(review): `1000 * self.timestamp` can overflow u64 for extremely
    // large timestamps; fine for typical media durations — confirm.
    fn timestamp(&self, timescale: u64) -> time::Duration {
        time::Duration::from_millis(1000 * self.timestamp / timescale)
    }
}
// Read a full MP4 atom into a vector.
pub fn read_atom<R: Read>(reader: &mut R) -> anyhow::Result<Vec<u8>> {
fn read_atom<R: Read>(reader: &mut R) -> anyhow::Result<Vec<u8>> {
// Read the 8 bytes for the size + type
let mut buf = [0u8; 8];
reader.read_exact(&mut buf)?;

127
src/media/watch.rs Normal file
View File

@ -0,0 +1,127 @@
use std::collections::VecDeque;
use tokio::sync::watch;
/// Shared queue state behind the watch channel: a growable queue plus a
/// running count of elements removed from the front.
#[derive(Default)]
struct State<T> {
    queue: VecDeque<T>,
    drained: usize,
    closed: bool,
}

impl<T> State<T> {
    pub fn new() -> Self {
        Self {
            queue: VecDeque::new(),
            drained: 0,
            closed: false,
        }
    }

    // Append an element at the tail of the queue.
    fn push(&mut self, value: T) {
        self.queue.push_back(value);
    }

    // Pop elements off the head while the predicate matches them,
    // returning how many were removed this call.
    fn drain<F>(&mut self, predicate: F) -> usize
    where
        F: Fn(&T) -> bool,
    {
        let mut removed = 0;
        while self.queue.front().map_or(false, &predicate) {
            self.queue.pop_front();
            removed += 1;
        }
        // Track the lifetime total so readers can re-base their indices.
        self.drained += removed;
        removed
    }
}
/// Producer half of a broadcast queue built on a tokio watch channel.
/// New subscribers start at the front and replay elements still buffered.
pub struct Producer<T: Clone> {
    // Shared queue state; mutations are broadcast to all subscribers.
    sender: watch::Sender<State<T>>,
}
impl<T: Clone> Producer<T> {
    pub fn new() -> Self {
        let (sender, _initial) = watch::channel(State::new());
        Self { sender }
    }

    // Append an element to the tail, waking every subscriber.
    pub fn push(&mut self, value: T) {
        // Returning true from send_if_modified notifies receivers,
        // exactly like send_modify would.
        self.sender.send_if_modified(|state| {
            state.push(value);
            true
        });
    }

    // Drop matching elements off the head of the queue.
    pub fn drain<F>(&mut self, f: F)
    where
        F: Fn(&T) -> bool,
    {
        // Returning false suppresses the wakeup: draining never unblocks anyone.
        self.sender.send_if_modified(|state| {
            state.drain(f);
            false
        });
    }

    // Create a new consumer starting at the front of the queue.
    pub fn subscribe(&self) -> Subscriber<T> {
        Subscriber::new(self.sender.subscribe())
    }
}
// Equivalent to Producer::new(); lets Producer be used where Default is expected.
impl<T: Clone> Default for Producer<T> {
    fn default() -> Self {
        Self::new()
    }
}
// Mark the queue closed on drop so blocked subscribers wake up and
// observe end-of-stream instead of waiting forever.
impl<T: Clone> Drop for Producer<T> {
    fn drop(&mut self) {
        self.sender.send_modify(|state| state.closed = true);
    }
}
/// Consumer half of the broadcast queue; cheap to clone, and each clone
/// keeps its own independent read position.
#[derive(Clone)]
pub struct Subscriber<T: Clone> {
    // Watch handle onto the shared queue state.
    state: watch::Receiver<State<T>>,
    // Absolute index of the next element to consume (counting drained ones).
    index: usize,
}
impl<T: Clone> Subscriber<T> {
    // Start at absolute index 0, which replays any elements still buffered.
    fn new(state: watch::Receiver<State<T>>) -> Self {
        Self { state, index: 0 }
    }
    /// Wait for and return the next element, or None once the producer has
    /// been dropped and every element has been consumed.
    pub async fn next(&mut self) -> Option<T> {
        // Wait until the queue has a new element or if it's closed.
        let state = self
            .state
            .wait_for(|state| state.closed || self.index < state.drained + state.queue.len())
            .await
            .expect("publisher dropped without close");
        // If our index is smaller than drained, skip past those elements we missed.
        let index = self.index.saturating_sub(state.drained);
        if index < state.queue.len() {
            // Clone the next element in the queue.
            let element = state.queue[index].clone();
            // Increment our index, relative to drained so we can skip ahead if needed.
            self.index = index + state.drained + 1;
            Some(element)
        } else if state.closed {
            // Return None if we've consumed all entries and the queue is closed.
            None
        } else {
            // NOTE(review): this branch is reachable if a drain() ever empties
            // the queue past a lagging subscriber (self.index < drained, queue
            // empty, not closed). The current producer always leaves at least
            // one element after draining, but that invariant is not enforced
            // here — confirm before relying on it.
            panic!("impossible subscriber state")
        }
    }
}

View File

@ -1,168 +0,0 @@
mod message;
use std::collections::hash_map as hmap;
use std::time;
use quiche;
use quiche::h3::webtransport;
use crate::{media, transport};
// Per-connection WebTransport application state for the media server.
#[derive(Default)]
pub struct Session {
    // The media source, configured on CONNECT.
    media: Option<media::Source>,
    // A helper for automatically buffering stream data.
    streams: transport::Streams,
    // Map from track_id to the Track state.
    tracks: hmap::HashMap<u32, Track>,
}
// Per-track streaming state: which QUIC stream carries it and keyframe timing.
pub struct Track {
    // Current stream_id; None when a new stream must be opened.
    stream_id: Option<u64>,
    // The timescale used for this track.
    timescale: u64,
    // The timestamp of the last keyframe, in timescale units.
    keyframe: u64,
}
impl transport::App for Session {
    // Process any updates to a session: drain WebTransport events, flush
    // buffered stream data, then advance the media source.
    fn poll(&mut self, conn: &mut quiche::Connection, session: &mut webtransport::ServerSession) -> anyhow::Result<()> {
        loop {
            let event = match session.poll(conn) {
                // No more events for now; move on to media work.
                Err(webtransport::Error::Done) => break,
                Err(e) => return Err(e.into()),
                Ok(e) => e,
            };
            log::debug!("webtransport event {:?}", event);
            match event {
                webtransport::ServerEvent::ConnectRequest(_req) => {
                    // you can handle request with
                    // req.authority()
                    // req.path()
                    // and you can validate this request with req.origin()
                    session.accept_connect_request(conn, None)?;
                    // TODO
                    let media = media::Source::new("media/fragmented.mp4").expect("failed to open fragmented.mp4");
                    let init = &media.init;
                    // Create a JSON header.
                    let mut message = message::Message::new();
                    message.init = Some(message::Init {});
                    let data = message.serialize()?;
                    // Create a new stream and write the header.
                    let stream_id = session.open_stream(conn, false)?;
                    self.streams.send(conn, stream_id, data.as_slice(), false)?;
                    self.streams.send(conn, stream_id, init.as_slice(), true)?;
                    self.media = Some(media);
                }
                webtransport::ServerEvent::StreamData(stream_id) => {
                    // Inbound stream data is read and discarded.
                    let mut buf = vec![0; 10000];
                    while let Ok(len) = session.recv_stream_data(conn, stream_id, &mut buf) {
                        let _stream_data = &buf[0..len];
                    }
                }
                _ => {}
            }
        }
        // Send any pending stream data.
        // NOTE: This doesn't return an error because it's async, and would be confusing.
        self.streams.poll(conn);
        // Fetch the next media fragment, possibly queuing up stream data.
        self.poll_source(conn, session)?;
        Ok(())
    }
    // Next deadline: when the media source wants to publish its next fragment.
    fn timeout(&self) -> Option<time::Duration> {
        self.media.as_ref().and_then(|m| m.timeout())
    }
}
impl Session {
    // Pull the next fragment from the media source and queue it on a
    // per-track QUIC stream, rotating streams on (spaced-out) keyframes.
    fn poll_source(
        &mut self,
        conn: &mut quiche::Connection,
        session: &mut webtransport::ServerSession,
    ) -> anyhow::Result<()> {
        // Get the media source once the connection is established.
        let media = match &mut self.media {
            Some(m) => m,
            None => return Ok(()),
        };
        // Get the next media fragment.
        let fragment = match media.fragment()? {
            Some(f) => f,
            None => return Ok(()),
        };
        // Get the track state or insert a new entry.
        let track = self.tracks.entry(fragment.track_id).or_insert_with(|| Track {
            stream_id: None,
            timescale: fragment.timescale,
            keyframe: 0,
        });
        if let Some(stream_id) = track.stream_id {
            // Existing stream, check if we should close it.
            // Only rotate on a keyframe at least one timescale's worth of
            // units past the previous keyframe.
            if fragment.keyframe && fragment.timestamp >= track.keyframe + track.timescale {
                // Close the existing stream
                self.streams.send(conn, stream_id, &[], true)?;
                // Unset the stream id so we create a new one.
                track.stream_id = None;
                track.keyframe = fragment.timestamp;
            }
        }
        let stream_id = match track.stream_id {
            Some(stream_id) => stream_id,
            None => {
                // Create a new unidirectional stream.
                let stream_id = session.open_stream(conn, false)?;
                // Set the stream priority to be equal to the timestamp.
                // We subtract from u64::MAX so newer media is sent with higher priority.
                // TODO prioritize audio
                let order = u64::MAX - fragment.timestamp;
                self.streams.send_order(conn, stream_id, order);
                // Encode a JSON header indicating this is a new track.
                let mut message: message::Message = message::Message::new();
                message.segment = Some(message::Segment {
                    track_id: fragment.track_id,
                });
                // Write the header.
                let data = message.serialize()?;
                self.streams.send(conn, stream_id, &data, false)?;
                stream_id
            }
        };
        // Write the current fragment.
        let data = fragment.data.as_slice();
        self.streams.send(conn, stream_id, data, false)?;
        // Save the stream_id for the next fragment.
        track.stream_id = Some(stream_id);
        Ok(())
    }
}

View File

@ -1,8 +0,0 @@
use std::time;
use quiche::h3::webtransport;
// Interface implemented by applications hosted on the QUIC transport server.
pub trait App: Default {
    // Drive the application forward; called by the server on each iteration.
    fn poll(&mut self, conn: &mut quiche::Connection, session: &mut webtransport::ServerSession) -> anyhow::Result<()>;
    // The next wakeup deadline the app needs, if any.
    fn timeout(&self) -> Option<time::Duration>;
}

View File

@ -1,15 +0,0 @@
use quiche;
use quiche::h3::webtransport;
use std::collections::hash_map as hmap;
// Connection IDs are owned ('static) so they can be used as stable map keys.
pub type Id = quiche::ConnectionId<'static>;
use super::app;
// All active connections, keyed by connection ID.
pub type Map<T> = hmap::HashMap<Id, Connection<T>>;
// A QUIC connection bundled with its optional WebTransport session and app state.
pub struct Connection<T: app::App> {
    pub quiche: quiche::Connection,
    // Created lazily once the QUIC handshake is established.
    pub session: Option<webtransport::ServerSession>,
    pub app: T,
}

View File

@ -1,8 +0,0 @@
mod app;
mod connection;
mod server;
mod streams;
pub use app::App;
pub use server::{Config, Server};
pub use streams::Streams;

View File

@ -1,378 +0,0 @@
use std::io;
use quiche::h3::webtransport;
use super::app;
use super::connection;
// Largest UDP payload we send or accept (below the common 1500-byte MTU).
const MAX_DATAGRAM_SIZE: usize = 1350;
// A synchronous QUIC server multiplexing many connections over one UDP socket.
pub struct Server<T: app::App> {
    // IO stuff
    socket: mio::net::UdpSocket,
    poll: mio::Poll,
    events: mio::Events,
    // QUIC stuff
    quic: quiche::Config,
    seed: ring::hmac::Key, // connection ID seed
    conns: connection::Map<T>,
}
// Server startup options.
pub struct Config {
    // UDP socket address to bind.
    pub addr: String,
    // Path to the PEM certificate chain.
    pub cert: String,
    // Path to the PEM private key.
    pub key: String,
}
impl<T: app::App> Server<T> {
// Bind the UDP socket, set up the mio event loop, and build the QUIC config.
// NOTE(review): despite returning io::Result, every failure path here
// panics via unwrap — consider propagating errors instead.
pub fn new(config: Config) -> io::Result<Self> {
    // Listen on the provided socket address
    let addr = config.addr.parse().unwrap();
    let mut socket = mio::net::UdpSocket::bind(addr).unwrap();
    // Setup the event loop.
    let poll = mio::Poll::new().unwrap();
    let events = mio::Events::with_capacity(1024);
    poll.registry()
        .register(&mut socket, mio::Token(0), mio::Interest::READABLE)
        .unwrap();
    // Generate random values for connection IDs.
    let rng = ring::rand::SystemRandom::new();
    let seed = ring::hmac::Key::generate(ring::hmac::HMAC_SHA256, &rng).unwrap();
    // Create the configuration for the QUIC conns.
    let mut quic = quiche::Config::new(quiche::PROTOCOL_VERSION).unwrap();
    quic.load_cert_chain_from_pem_file(&config.cert).unwrap();
    quic.load_priv_key_from_pem_file(&config.key).unwrap();
    quic.set_application_protos(quiche::h3::APPLICATION_PROTOCOL).unwrap();
    quic.set_max_idle_timeout(5000);
    quic.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE);
    quic.set_max_send_udp_payload_size(MAX_DATAGRAM_SIZE);
    quic.set_initial_max_data(10_000_000);
    quic.set_initial_max_stream_data_bidi_local(1_000_000);
    quic.set_initial_max_stream_data_bidi_remote(1_000_000);
    quic.set_initial_max_stream_data_uni(1_000_000);
    quic.set_initial_max_streams_bidi(100);
    quic.set_initial_max_streams_uni(100);
    quic.set_disable_active_migration(true);
    quic.enable_early_data();
    quic.enable_dgram(true, 65536, 65536);
    let conns = Default::default();
    Ok(Server {
        socket,
        poll,
        events,
        quic,
        seed,
        conns,
    })
}
// Main server loop: wait for IO or a timer, then pump each stage in order.
pub fn run(&mut self) -> anyhow::Result<()> {
    log::info!("listening on {}", self.socket.local_addr()?);
    loop {
        self.wait()?;
        self.receive()?;
        self.app()?;
        self.send()?;
        self.cleanup();
    }
}
// Block until the earliest deadline across every connection expires,
// considering both the QUIC transport timer and the application timer.
//
// TODO: use event loop that properly supports timers
pub fn wait(&mut self) -> anyhow::Result<()> {
    // The global minimum over all per-connection deadlines equals the
    // minimum over the union of every individual deadline.
    let timeout = self
        .conns
        .values()
        .flat_map(|c| [c.quiche.timeout(), c.app.timeout()])
        .flatten()
        .min();

    self.poll.poll(&mut self.events, timeout).unwrap();

    // If the event loop reported no events, it means that the timeout
    // has expired, so handle it without attempting to read packets. We
    // will then proceed with the send loop.
    if self.events.is_empty() {
        for conn in self.conns.values_mut() {
            conn.quiche.on_timeout();
        }
    }

    Ok(())
}
// Reads packets from the socket, updating any internal connection state.
// Each datagram falls into one of three cases: an existing connection, a
// new Initial packet (with version negotiation and stateless retry), or
// an unknown packet that is dropped.
fn receive(&mut self) -> anyhow::Result<()> {
    let mut src = [0; MAX_DATAGRAM_SIZE];
    // Try reading any data currently available on the socket.
    loop {
        let (len, from) = match self.socket.recv_from(&mut src) {
            Ok(v) => v,
            // Non-blocking socket: we've drained everything for now.
            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => return Ok(()),
            Err(e) => return Err(e.into()),
        };
        let src = &mut src[..len];
        let info = quiche::RecvInfo {
            to: self.socket.local_addr().unwrap(),
            from,
        };
        // Parse the QUIC packet's header.
        let hdr = quiche::Header::from_slice(src, quiche::MAX_CONN_ID_LEN).unwrap();
        // Derive a deterministic server-chosen connection ID by HMACing the
        // client's DCID with our secret seed.
        let conn_id = ring::hmac::sign(&self.seed, &hdr.dcid);
        let conn_id = &conn_id.as_ref()[..quiche::MAX_CONN_ID_LEN];
        let conn_id = conn_id.to_vec().into();
        // Check if it's an existing connection.
        if let Some(conn) = self.conns.get_mut(&hdr.dcid) {
            conn.quiche.recv(src, info)?;
            // Lazily create the WebTransport session once the handshake completes.
            if conn.session.is_none() && conn.quiche.is_established() {
                conn.session = Some(webtransport::ServerSession::with_transport(&mut conn.quiche)?)
            }
            continue;
        } else if let Some(conn) = self.conns.get_mut(&conn_id) {
            conn.quiche.recv(src, info)?;
            // TODO is this needed here?
            if conn.session.is_none() && conn.quiche.is_established() {
                conn.session = Some(webtransport::ServerSession::with_transport(&mut conn.quiche)?)
            }
            continue;
        }
        // Unknown connection: only an Initial packet may create one.
        if hdr.ty != quiche::Type::Initial {
            log::warn!("unknown connection ID");
            continue;
        }
        let mut dst = [0; MAX_DATAGRAM_SIZE];
        // Unsupported version: reply with a version negotiation packet.
        if !quiche::version_is_supported(hdr.version) {
            let len = quiche::negotiate_version(&hdr.scid, &hdr.dcid, &mut dst).unwrap();
            let dst = &dst[..len];
            self.socket.send_to(dst, from).unwrap();
            continue;
        }
        let mut scid = [0; quiche::MAX_CONN_ID_LEN];
        scid.copy_from_slice(&conn_id);
        let scid = quiche::ConnectionId::from_ref(&scid);
        // Token is always present in Initial packets.
        let token = hdr.token.as_ref().unwrap();
        // Do stateless retry if the client didn't send a token.
        if token.is_empty() {
            let new_token = mint_token(&hdr, &from);
            let len = quiche::retry(&hdr.scid, &hdr.dcid, &scid, &new_token, hdr.version, &mut dst).unwrap();
            let dst = &dst[..len];
            self.socket.send_to(dst, from).unwrap();
            continue;
        }
        let odcid = validate_token(&from, token);
        // The token was not valid, meaning the retry failed, so
        // drop the packet.
        if odcid.is_none() {
            log::warn!("invalid token");
            continue;
        }
        if scid.len() != hdr.dcid.len() {
            log::warn!("invalid connection ID");
            continue;
        }
        // Reuse the source connection ID we sent in the Retry packet,
        // instead of changing it again.
        let conn_id = hdr.dcid.clone();
        let local_addr = self.socket.local_addr().unwrap();
        log::debug!("new connection: dcid={:?} scid={:?}", hdr.dcid, scid);
        let mut conn = quiche::accept(&conn_id, odcid.as_ref(), local_addr, from, &mut self.quic)?;
        // Log each session with QLOG if the ENV var is set.
        if let Some(dir) = std::env::var_os("QLOGDIR") {
            let id = format!("{:?}", &scid);
            let mut path = std::path::PathBuf::from(dir);
            let filename = format!("server-{id}.sqlog");
            path.push(filename);
            let writer = match std::fs::File::create(&path) {
                Ok(f) => std::io::BufWriter::new(f),
                Err(e) => panic!("Error creating qlog file attempted path was {:?}: {}", path, e),
            };
            conn.set_qlog(
                std::boxed::Box::new(writer),
                "warp-server qlog".to_string(),
                format!("{} id={}", "warp-server qlog", id),
            );
        }
        // Process potentially coalesced packets.
        conn.recv(src, info)?;
        let user = connection::Connection {
            quiche: conn,
            session: None,
            app: T::default(),
        };
        self.conns.insert(conn_id, user);
    }
}
pub fn app(&mut self) -> anyhow::Result<()> {
for conn in self.conns.values_mut() {
if conn.quiche.is_closed() {
continue;
}
if let Some(session) = &mut conn.session {
if let Err(e) = conn.app.poll(&mut conn.quiche, session) {
log::debug!("app error: {:?}", e);
// Close the connection on any application error
let reason = format!("app error: {:?}", e);
conn.quiche.close(true, 0xff, reason.as_bytes()).ok();
}
}
}
Ok(())
}
// Generate outgoing QUIC packets for all active connections and send
// them on the UDP socket, until quiche reports that there are no more
// packets to be sent.
pub fn send(&mut self) -> anyhow::Result<()> {
for conn in self.conns.values_mut() {
let conn = &mut conn.quiche;
if let Err(e) = send_conn(&self.socket, conn) {
log::error!("{} send failed: {:?}", conn.trace_id(), e);
conn.close(false, 0x1, b"fail").ok();
}
}
Ok(())
}
pub fn cleanup(&mut self) {
// Garbage collect closed connections.
self.conns.retain(|_, ref mut c| !c.quiche.is_closed());
}
}
// Flush every packet quiche has queued for this connection onto the socket.
fn send_conn(socket: &mio::net::UdpSocket, conn: &mut quiche::Connection) -> anyhow::Result<()> {
    let mut scratch = [0; MAX_DATAGRAM_SIZE];
    loop {
        // Ask quiche for the next datagram to transmit.
        let (len, info) = match conn.send(&mut scratch) {
            // Nothing left to send right now.
            Err(quiche::Error::Done) => return Ok(()),
            Err(e) => return Err(e.into()),
            Ok(v) => v,
        };
        match socket.send_to(&scratch[..len], info.to) {
            Ok(_) => {}
            // A full socket buffer is not fatal; retry on the next poll.
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(()),
            Err(e) => return Err(e.into()),
        }
    }
}
/// Generate a stateless retry token.
///
/// The token is the static string `"quiche"` followed by the client's IP
/// address and the original destination connection ID chosen by the client.
///
/// Note that this function is only an example and performs no cryptographic
/// authentication of the token. *It should not be used in a production system.*
fn mint_token(hdr: &quiche::Header, src: &std::net::SocketAddr) -> Vec<u8> {
    let addr = match src.ip() {
        std::net::IpAddr::V4(a) => a.octets().to_vec(),
        std::net::IpAddr::V6(a) => a.octets().to_vec(),
    };

    // Magic prefix + address + ODCID, sized up front.
    let mut token = Vec::with_capacity(6 + addr.len() + hdr.dcid.len());
    token.extend_from_slice(b"quiche");
    token.extend_from_slice(&addr);
    token.extend_from_slice(&hdr.dcid);
    token
}
/// Validates a stateless retry token.
///
/// Checks that the token starts with the `"quiche"` static string and that
/// the embedded IP address matches the client's source address; whatever
/// remains is returned as the original destination connection ID.
///
/// Note that this function is only an example and performs no cryptographic
/// authentication of the token. *It should not be used in a production system.*
fn validate_token<'a>(src: &std::net::SocketAddr, token: &'a [u8]) -> Option<quiche::ConnectionId<'a>> {
    // Reject anything that doesn't carry the magic prefix.
    let rest = token.strip_prefix(b"quiche")?;

    let addr = match src.ip() {
        std::net::IpAddr::V4(a) => a.octets().to_vec(),
        std::net::IpAddr::V6(a) => a.octets().to_vec(),
    };

    // The token must embed the caller's address; the remainder is the ODCID.
    let rest = rest.strip_prefix(&addr[..])?;
    Some(quiche::ConnectionId::from_ref(rest))
}

View File

@ -1,136 +0,0 @@
use std::collections::VecDeque;
use anyhow;
use quiche;
// Buffers and prioritizes outgoing data across multiple QUIC streams.
#[derive(Default)]
pub struct Streams {
    // Streams sorted ascending by their send order (see insert).
    ordered: Vec<Stream>,
}
// Buffered state for one outgoing QUIC stream.
struct Stream {
    // The QUIC stream ID.
    id: u64,
    // Sort key used to derive the quiche stream priority (see Streams::insert).
    order: u64,
    // Data that couldn't be written yet (e.g. no flow-control credit).
    buffer: VecDeque<u8>,
    // Whether the stream should be finished once the buffer drains.
    fin: bool,
}
impl Streams {
    // Write the data to the given stream, buffering it if needed.
    pub fn send(&mut self, conn: &mut quiche::Connection, id: u64, buf: &[u8], fin: bool) -> anyhow::Result<()> {
        // Nothing to do for an empty write that doesn't finish the stream.
        if buf.is_empty() && !fin {
            return Ok(());
        }
        // Get the index of the stream, or add it to the list of streams.
        let pos = self.ordered.iter().position(|s| s.id == id).unwrap_or_else(|| {
            // Create a new stream
            let stream = Stream {
                id,
                buffer: VecDeque::new(),
                fin: false,
                order: 0, // Default to highest priority until send_order is called.
            };
            self.insert(conn, stream)
        });
        let stream = &mut self.ordered[pos];
        // Check if we've already closed the stream, just in case.
        if stream.fin && !buf.is_empty() {
            anyhow::bail!("stream is already finished");
        }
        // If there's no data buffered, try to write it immediately.
        // Otherwise everything must be buffered to preserve write ordering.
        let size = if stream.buffer.is_empty() {
            match conn.stream_send(id, buf, fin) {
                Ok(size) => size,
                Err(quiche::Error::Done) => 0,
                Err(e) => anyhow::bail!(e),
            }
        } else {
            0
        };
        if size < buf.len() {
            // Short write, save the rest for later.
            stream.buffer.extend(&buf[size..]);
        }
        stream.fin |= fin;
        Ok(())
    }
    // Flush any pending stream data.
    // A stream that finished (or errored) is dropped from the list.
    pub fn poll(&mut self, conn: &mut quiche::Connection) {
        self.ordered.retain_mut(|s| s.poll(conn).is_ok());
    }
    // Set the send order of the stream.
    pub fn send_order(&mut self, conn: &mut quiche::Connection, id: u64, order: u64) {
        let mut stream = match self.ordered.iter().position(|s| s.id == id) {
            // Remove the stream from the existing list.
            Some(pos) => self.ordered.remove(pos),
            // This is a new stream, insert it into the list.
            None => Stream {
                id,
                buffer: VecDeque::new(),
                fin: false,
                order,
            },
        };
        stream.order = order;
        self.insert(conn, stream);
    }
    // Insert the stream into the sorted list and refresh quiche priorities
    // for every stream at or after the insertion point.
    fn insert(&mut self, conn: &mut quiche::Connection, stream: Stream) -> usize {
        // Look for the position to insert the stream.
        let pos = match self.ordered.binary_search_by_key(&stream.order, |s| s.order) {
            Ok(pos) | Err(pos) => pos,
        };
        self.ordered.insert(pos, stream);
        // Reprioritize all later streams.
        // TODO we can avoid this if stream_priorty takes a u64
        // NOTE(review): `(pos + i) as u8` truncates past 255 streams — confirm acceptable.
        for (i, stream) in self.ordered[pos..].iter().enumerate() {
            _ = conn.stream_priority(stream.id, (pos + i) as u8, true);
        }
        pos
    }
}
impl Stream {
    // Flush buffered data; returning Err(Done) signals that the stream is
    // finished and may be removed from the list.
    fn poll(&mut self, conn: &mut quiche::Connection) -> quiche::Result<()> {
        // Drain the buffer until it's empty or the connection backpressures.
        loop {
            // VecDeque is a ring buffer, so only the front slice is contiguous.
            let (front, _) = self.buffer.as_slices();
            if front.is_empty() {
                break;
            }
            let written = conn.stream_send(self.id, front, false)?;
            if written == 0 {
                // No more space available for this stream.
                return Ok(());
            }
            // Drop the bytes that were accepted.
            self.buffer.drain(..written);
        }
        if !self.fin {
            return Ok(());
        }
        // Everything was flushed: signal end-of-stream, then report Done so
        // the caller removes us.
        conn.stream_send(self.id, &[], true)?;
        Err(quiche::Error::Done)
    }
}