277 lines
6.2 KiB
Go
277 lines
6.2 KiB
Go
package warp
|
|
|
|
import (
|
|
"context"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
"time"
|
|
|
|
"github.com/adriancable/webtransport-go"
|
|
"github.com/kixelated/invoker"
|
|
)
|
|
|
|
// A single WebTransport session
|
|
type Session struct {
|
|
inner *webtransport.Session
|
|
media *Media
|
|
socket *Socket
|
|
audio *MediaStream
|
|
video *MediaStream
|
|
|
|
streams invoker.Tasks
|
|
}
|
|
|
|
func NewSession(session *webtransport.Session, media *Media, socket *Socket) (s *Session, err error) {
|
|
s = new(Session)
|
|
s.inner = session
|
|
s.media = media
|
|
s.socket = socket
|
|
return s, nil
|
|
}
|
|
|
|
func (s *Session) Run(ctx context.Context) (err error) {
|
|
// TODO validate the session before accepting it
|
|
s.inner.AcceptSession()
|
|
defer s.inner.CloseSession()
|
|
|
|
s.audio, s.video, err = s.media.Start()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to start media: %w", err)
|
|
}
|
|
|
|
// Once we've validated the session, now we can start accessing the streams
|
|
return invoker.Run(ctx, s.runAccept, s.runAcceptUni, s.runAudio, s.runVideo, s.streams.Repeat)
|
|
}
|
|
|
|
func (s *Session) runAccept(ctx context.Context) (err error) {
|
|
for {
|
|
// TODO context support :(
|
|
stream, err := s.inner.AcceptStream()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to accept bidirectional stream: %w", err)
|
|
}
|
|
|
|
// Warp doesn't utilize bidirectional streams so just close them immediately.
|
|
// We might use them in the future so don't close the connection with an error.
|
|
stream.CancelRead(1)
|
|
}
|
|
}
|
|
|
|
func (s *Session) runAcceptUni(ctx context.Context) (err error) {
|
|
for {
|
|
stream, err := s.inner.AcceptUniStream(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to accept unidirectional stream: %w", err)
|
|
}
|
|
|
|
s.streams.Add(func(ctx context.Context) (err error) {
|
|
return s.handleStream(ctx, &stream)
|
|
})
|
|
}
|
|
}
|
|
|
|
func (s *Session) handleStream(ctx context.Context, stream *webtransport.ReceiveStream) (err error) {
|
|
defer func() {
|
|
if err != nil {
|
|
stream.CancelRead(1)
|
|
}
|
|
}()
|
|
|
|
var header [8]byte
|
|
for {
|
|
_, err = io.ReadFull(stream, header[:])
|
|
if errors.Is(io.EOF, err) {
|
|
return nil
|
|
} else if err != nil {
|
|
return fmt.Errorf("failed to read atom header: %w", err)
|
|
}
|
|
|
|
size := binary.BigEndian.Uint32(header[0:4])
|
|
name := string(header[4:8])
|
|
|
|
if size < 8 {
|
|
return fmt.Errorf("atom size is too small")
|
|
} else if size > 42069 { // arbitrary limit
|
|
return fmt.Errorf("atom size is too large")
|
|
} else if name != "warp" {
|
|
return fmt.Errorf("only warp atoms are supported")
|
|
}
|
|
|
|
payload := make([]byte, size-8)
|
|
|
|
_, err = io.ReadFull(stream, payload)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read atom payload: %w", err)
|
|
}
|
|
|
|
msg := Message{}
|
|
|
|
err = json.Unmarshal(payload, &msg)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to decode json payload: %w", err)
|
|
}
|
|
|
|
if msg.Throttle != nil {
|
|
//s.setThrottle(msg.Throttle)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Session) runAudio(ctx context.Context) (err error) {
|
|
init, err := s.audio.Init(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to fetch init segment: %w", err)
|
|
}
|
|
|
|
// NOTE: Assumes a single init segment
|
|
err = s.writeInit(ctx, init, 1)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write init stream: %w", err)
|
|
}
|
|
|
|
for {
|
|
segment, err := s.audio.Segment(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get next segment: %w", err)
|
|
}
|
|
|
|
if segment == nil {
|
|
return nil
|
|
}
|
|
|
|
err = s.writeSegment(ctx, segment, 1)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write segment stream: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Session) runVideo(ctx context.Context) (err error) {
|
|
init, err := s.video.Init(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to fetch init segment: %w", err)
|
|
}
|
|
|
|
// NOTE: Assumes a single init segment
|
|
err = s.writeInit(ctx, init, 2)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write init stream: %w", err)
|
|
}
|
|
|
|
for {
|
|
segment, err := s.video.Segment(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get next segment: %w", err)
|
|
}
|
|
|
|
if segment == nil {
|
|
return nil
|
|
}
|
|
|
|
err = s.writeSegment(ctx, segment, 2)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write segment stream: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create a stream for an INIT segment and write the container.
|
|
func (s *Session) writeInit(ctx context.Context, init *MediaInit, id int) (err error) {
|
|
temp, err := s.inner.OpenUniStreamSync(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create stream: %w", err)
|
|
}
|
|
|
|
// Wrap the stream in an object that buffers writes instead of blocking.
|
|
stream := NewStream(temp)
|
|
s.streams.Add(stream.Run)
|
|
|
|
defer func() {
|
|
if err != nil {
|
|
stream.WriteCancel(1)
|
|
}
|
|
}()
|
|
|
|
stream.SetPriority(math.MaxInt)
|
|
|
|
err = stream.WriteMessage(Message{
|
|
Init: &MessageInit{Id: id},
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write init header: %w", err)
|
|
}
|
|
|
|
_, err = stream.Write(init.Raw)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write init data: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Create a stream for a segment and write the contents, chunk by chunk.
|
|
func (s *Session) writeSegment(ctx context.Context, segment *MediaSegment, init int) (err error) {
|
|
temp, err := s.inner.OpenUniStreamSync(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create stream: %w", err)
|
|
}
|
|
|
|
// Wrap the stream in an object that buffers writes instead of blocking.
|
|
stream := NewStream(temp)
|
|
s.streams.Add(stream.Run)
|
|
|
|
defer func() {
|
|
if err != nil {
|
|
stream.WriteCancel(1)
|
|
}
|
|
}()
|
|
|
|
ms := int(segment.timestamp / time.Millisecond)
|
|
|
|
// newer segments take priority
|
|
stream.SetPriority(ms)
|
|
|
|
err = stream.WriteMessage(Message{
|
|
Segment: &MessageSegment{
|
|
Init: init,
|
|
Timestamp: ms,
|
|
},
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write segment header: %w", err)
|
|
}
|
|
|
|
for {
|
|
// Get the next fragment
|
|
buf, err := segment.Read(ctx)
|
|
if errors.Is(err, io.EOF) {
|
|
break
|
|
} else if err != nil {
|
|
return fmt.Errorf("failed to read segment data: %w", err)
|
|
}
|
|
|
|
// NOTE: This won't block because of our wrapper
|
|
_, err = stream.Write(buf)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write segment data: %w", err)
|
|
}
|
|
}
|
|
|
|
err = stream.Close()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to close segemnt stream: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Session) setThrottle(msg *MessageThrottle) {
|
|
s.socket.SetWriteRate(msg.Rate)
|
|
s.socket.SetWriteBuffer(msg.Buffer)
|
|
s.socket.SetWriteLoss(msg.Loss)
|
|
}
|