moq-rs/server/internal/warp/session.go

277 lines
6.2 KiB
Go

package warp
import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"time"
"github.com/adriancable/webtransport-go"
"github.com/kixelated/invoker"
)
// A single WebTransport session
type Session struct {
inner *webtransport.Session
media *Media
socket *Socket
audio *MediaStream
video *MediaStream
streams invoker.Tasks
}
func NewSession(session *webtransport.Session, media *Media, socket *Socket) (s *Session, err error) {
s = new(Session)
s.inner = session
s.media = media
s.socket = socket
return s, nil
}
func (s *Session) Run(ctx context.Context) (err error) {
// TODO validate the session before accepting it
s.inner.AcceptSession()
defer s.inner.CloseSession()
s.audio, s.video, err = s.media.Start()
if err != nil {
return fmt.Errorf("failed to start media: %w", err)
}
// Once we've validated the session, now we can start accessing the streams
return invoker.Run(ctx, s.runAccept, s.runAcceptUni, s.runAudio, s.runVideo, s.streams.Repeat)
}
func (s *Session) runAccept(ctx context.Context) (err error) {
for {
// TODO context support :(
stream, err := s.inner.AcceptStream()
if err != nil {
return fmt.Errorf("failed to accept bidirectional stream: %w", err)
}
// Warp doesn't utilize bidirectional streams so just close them immediately.
// We might use them in the future so don't close the connection with an error.
stream.CancelRead(1)
}
}
func (s *Session) runAcceptUni(ctx context.Context) (err error) {
for {
stream, err := s.inner.AcceptUniStream(ctx)
if err != nil {
return fmt.Errorf("failed to accept unidirectional stream: %w", err)
}
s.streams.Add(func(ctx context.Context) (err error) {
return s.handleStream(ctx, &stream)
})
}
}
func (s *Session) handleStream(ctx context.Context, stream *webtransport.ReceiveStream) (err error) {
defer func() {
if err != nil {
stream.CancelRead(1)
}
}()
var header [8]byte
for {
_, err = io.ReadFull(stream, header[:])
if errors.Is(io.EOF, err) {
return nil
} else if err != nil {
return fmt.Errorf("failed to read atom header: %w", err)
}
size := binary.BigEndian.Uint32(header[0:4])
name := string(header[4:8])
if size < 8 {
return fmt.Errorf("atom size is too small")
} else if size > 42069 { // arbitrary limit
return fmt.Errorf("atom size is too large")
} else if name != "warp" {
return fmt.Errorf("only warp atoms are supported")
}
payload := make([]byte, size-8)
_, err = io.ReadFull(stream, payload)
if err != nil {
return fmt.Errorf("failed to read atom payload: %w", err)
}
msg := Message{}
err = json.Unmarshal(payload, &msg)
if err != nil {
return fmt.Errorf("failed to decode json payload: %w", err)
}
if msg.Throttle != nil {
s.setThrottle(msg.Throttle)
}
}
}
func (s *Session) runAudio(ctx context.Context) (err error) {
init, err := s.audio.Init(ctx)
if err != nil {
return fmt.Errorf("failed to fetch init segment: %w", err)
}
// NOTE: Assumes a single init segment
err = s.writeInit(ctx, init, 1)
if err != nil {
return fmt.Errorf("failed to write init stream: %w", err)
}
for {
segment, err := s.audio.Segment(ctx)
if err != nil {
return fmt.Errorf("failed to get next segment: %w", err)
}
if segment == nil {
return nil
}
err = s.writeSegment(ctx, segment, 1)
if err != nil {
return fmt.Errorf("failed to write segment stream: %w", err)
}
}
}
func (s *Session) runVideo(ctx context.Context) (err error) {
init, err := s.video.Init(ctx)
if err != nil {
return fmt.Errorf("failed to fetch init segment: %w", err)
}
// NOTE: Assumes a single init segment
err = s.writeInit(ctx, init, 2)
if err != nil {
return fmt.Errorf("failed to write init stream: %w", err)
}
for {
segment, err := s.video.Segment(ctx)
if err != nil {
return fmt.Errorf("failed to get next segment: %w", err)
}
if segment == nil {
return nil
}
err = s.writeSegment(ctx, segment, 2)
if err != nil {
return fmt.Errorf("failed to write segment stream: %w", err)
}
}
}
// Create a stream for an INIT segment and write the container.
func (s *Session) writeInit(ctx context.Context, init *MediaInit, id int) (err error) {
temp, err := s.inner.OpenUniStreamSync(ctx)
if err != nil {
return fmt.Errorf("failed to create stream: %w", err)
}
// Wrap the stream in an object that buffers writes instead of blocking.
stream := NewStream(temp)
s.streams.Add(stream.Run)
defer func() {
if err != nil {
stream.WriteCancel(1)
}
}()
stream.SetPriority(math.MaxInt)
err = stream.WriteMessage(Message{
Init: &MessageInit{Id: id},
})
if err != nil {
return fmt.Errorf("failed to write init header: %w", err)
}
_, err = stream.Write(init.Raw)
if err != nil {
return fmt.Errorf("failed to write init data: %w", err)
}
return nil
}
// Create a stream for a segment and write the contents, chunk by chunk.
func (s *Session) writeSegment(ctx context.Context, segment *MediaSegment, init int) (err error) {
temp, err := s.inner.OpenUniStreamSync(ctx)
if err != nil {
return fmt.Errorf("failed to create stream: %w", err)
}
// Wrap the stream in an object that buffers writes instead of blocking.
stream := NewStream(temp)
s.streams.Add(stream.Run)
defer func() {
if err != nil {
stream.WriteCancel(1)
}
}()
ms := int(segment.timestamp / time.Millisecond)
// newer segments take priority
stream.SetPriority(ms)
err = stream.WriteMessage(Message{
Segment: &MessageSegment{
Init: init,
Timestamp: ms,
},
})
if err != nil {
return fmt.Errorf("failed to write segment header: %w", err)
}
for {
// Get the next fragment
buf, err := segment.Read(ctx)
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return fmt.Errorf("failed to read segment data: %w", err)
}
// NOTE: This won't block because of our wrapper
_, err = stream.Write(buf)
if err != nil {
return fmt.Errorf("failed to write segment data: %w", err)
}
}
err = stream.Close()
if err != nil {
return fmt.Errorf("failed to close segemnt stream: %w", err)
}
return nil
}
func (s *Session) setThrottle(msg *MessageThrottle) {
s.socket.SetWriteRate(msg.Rate)
s.socket.SetWriteBuffer(msg.Buffer)
s.socket.SetWriteLoss(msg.Loss)
}