moq-rs/server/internal/warp/server.go

130 lines
2.5 KiB
Go

package warp
import (
"context"
"crypto/tls"
"encoding/hex"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"github.com/kixelated/invoker"
"github.com/kixelated/quic-go"
"github.com/kixelated/quic-go/http3"
"github.com/kixelated/quic-go/logging"
"github.com/kixelated/quic-go/qlog"
"github.com/kixelated/webtransport-go"
)
type Server struct {
inner *webtransport.Server
media *Media
sessions invoker.Tasks
}
type ServerConfig struct {
Addr string
Cert *tls.Certificate
LogDir string
}
func NewServer(config ServerConfig, media *Media) (s *Server, err error) {
s = new(Server)
quicConfig := &quic.Config{}
if config.LogDir != "" {
quicConfig.Tracer = qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser {
path := fmt.Sprintf("%s-%s.qlog", p, hex.EncodeToString(connectionID))
f, err := os.Create(filepath.Join(config.LogDir, path))
if err != nil {
// lame
panic(err)
}
return f
})
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{*config.Cert},
}
mux := http.NewServeMux()
s.inner = &webtransport.Server{
H3: http3.Server{
TLSConfig: tlsConfig,
QuicConfig: quicConfig,
Addr: config.Addr,
Handler: mux,
},
CheckOrigin: func(r *http.Request) bool { return true },
}
s.media = media
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
hijacker, ok := w.(http3.Hijacker)
if !ok {
panic("unable to hijack connection: must use kixelated/quic-go")
}
conn := hijacker.Connection()
sess, err := s.inner.Upgrade(w, r)
if err != nil {
http.Error(w, "failed to upgrade session", 500)
return
}
err = s.serve(r.Context(), conn, sess)
if err != nil {
log.Println(err)
}
})
return s, nil
}
func (s *Server) runServe(ctx context.Context) (err error) {
return s.inner.ListenAndServe()
}
func (s *Server) runShutdown(ctx context.Context) (err error) {
<-ctx.Done()
s.inner.Close()
return ctx.Err()
}
func (s *Server) Run(ctx context.Context) (err error) {
return invoker.Run(ctx, s.runServe, s.runShutdown, s.sessions.Repeat)
}
func (s *Server) serve(ctx context.Context, conn quic.Connection, sess *webtransport.Session) (err error) {
defer func() {
if err != nil {
sess.CloseWithError(1, err.Error())
} else {
sess.CloseWithError(0, "end of broadcast")
}
}()
ss, err := NewSession(conn, sess, s.media)
if err != nil {
return fmt.Errorf("failed to create session: %w", err)
}
err = ss.Run(ctx)
if err != nil {
return fmt.Errorf("terminated session: %w", err)
}
return nil
}