2022-06-29 16:17:02 +00:00
|
|
|
package warp
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
2022-11-07 09:55:04 +00:00
|
|
|
"crypto/tls"
|
2022-06-29 16:17:02 +00:00
|
|
|
"encoding/hex"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"log"
|
|
|
|
"net/http"
|
|
|
|
"os"
|
|
|
|
"path/filepath"
|
|
|
|
|
|
|
|
"github.com/kixelated/invoker"
|
2022-12-06 00:26:35 +00:00
|
|
|
"github.com/kixelated/quic-go"
|
|
|
|
"github.com/kixelated/quic-go/http3"
|
|
|
|
"github.com/kixelated/quic-go/logging"
|
|
|
|
"github.com/kixelated/quic-go/qlog"
|
|
|
|
"github.com/kixelated/webtransport-go"
|
2022-06-29 16:17:02 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
type Server struct {
|
2022-11-07 09:55:04 +00:00
|
|
|
inner *webtransport.Server
|
|
|
|
media *Media
|
2022-06-29 16:17:02 +00:00
|
|
|
|
|
|
|
sessions invoker.Tasks
|
|
|
|
}
|
|
|
|
|
|
|
|
type ServerConfig struct {
|
2022-11-07 09:55:04 +00:00
|
|
|
Addr string
|
|
|
|
Cert *tls.Certificate
|
|
|
|
LogDir string
|
2022-06-29 16:17:02 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
func NewServer(config ServerConfig, media *Media) (s *Server, err error) {
|
|
|
|
s = new(Server)
|
|
|
|
|
|
|
|
quicConfig := &quic.Config{}
|
|
|
|
|
|
|
|
if config.LogDir != "" {
|
|
|
|
quicConfig.Tracer = qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser {
|
|
|
|
path := fmt.Sprintf("%s-%s.qlog", p, hex.EncodeToString(connectionID))
|
|
|
|
|
|
|
|
f, err := os.Create(filepath.Join(config.LogDir, path))
|
|
|
|
if err != nil {
|
|
|
|
// lame
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return f
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
2022-11-07 09:55:04 +00:00
|
|
|
tlsConfig := &tls.Config{
|
|
|
|
Certificates: []tls.Certificate{*config.Cert},
|
|
|
|
}
|
|
|
|
|
|
|
|
mux := http.NewServeMux()
|
|
|
|
|
2022-06-29 16:17:02 +00:00
|
|
|
s.inner = &webtransport.Server{
|
2022-11-07 09:55:04 +00:00
|
|
|
H3: http3.Server{
|
|
|
|
TLSConfig: tlsConfig,
|
|
|
|
QuicConfig: quicConfig,
|
|
|
|
Addr: config.Addr,
|
|
|
|
Handler: mux,
|
|
|
|
},
|
|
|
|
CheckOrigin: func(r *http.Request) bool { return true },
|
2022-06-29 16:17:02 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
s.media = media
|
|
|
|
|
2022-11-07 09:55:04 +00:00
|
|
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
2022-11-18 23:13:35 +00:00
|
|
|
hijacker, ok := w.(http3.Hijacker)
|
|
|
|
if !ok {
|
2022-12-06 00:26:35 +00:00
|
|
|
panic("unable to hijack connection: must use kixelated/quic-go")
|
2022-11-18 23:13:35 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
conn := hijacker.Connection()
|
|
|
|
|
2022-12-06 00:26:35 +00:00
|
|
|
sess, err := s.inner.Upgrade(w, r)
|
2022-06-29 16:17:02 +00:00
|
|
|
if err != nil {
|
2022-12-06 00:26:35 +00:00
|
|
|
http.Error(w, "failed to upgrade session", 500)
|
2022-06-29 16:17:02 +00:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2022-12-06 00:26:35 +00:00
|
|
|
err = s.serve(r.Context(), conn, sess)
|
2022-11-18 23:13:35 +00:00
|
|
|
if err != nil {
|
2022-12-06 00:26:35 +00:00
|
|
|
log.Println(err)
|
2022-11-18 23:13:35 +00:00
|
|
|
}
|
2022-06-29 16:17:02 +00:00
|
|
|
})
|
|
|
|
|
|
|
|
return s, nil
|
|
|
|
}
|
|
|
|
|
2022-11-07 09:55:04 +00:00
|
|
|
func (s *Server) runServe(ctx context.Context) (err error) {
|
|
|
|
return s.inner.ListenAndServe()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Server) runShutdown(ctx context.Context) (err error) {
|
|
|
|
<-ctx.Done()
|
|
|
|
s.inner.Close()
|
|
|
|
return ctx.Err()
|
|
|
|
}
|
|
|
|
|
2022-06-29 16:17:02 +00:00
|
|
|
func (s *Server) Run(ctx context.Context) (err error) {
|
2022-11-07 09:55:04 +00:00
|
|
|
return invoker.Run(ctx, s.runServe, s.runShutdown, s.sessions.Repeat)
|
2022-06-29 16:17:02 +00:00
|
|
|
}
|
2022-12-06 00:26:35 +00:00
|
|
|
|
|
|
|
func (s *Server) serve(ctx context.Context, conn quic.Connection, sess *webtransport.Session) (err error) {
|
|
|
|
defer func() {
|
|
|
|
if err != nil {
|
|
|
|
sess.CloseWithError(1, err.Error())
|
|
|
|
} else {
|
|
|
|
sess.CloseWithError(0, "end of broadcast")
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
|
|
|
ss, err := NewSession(conn, sess, s.media)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed to create session: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
err = ss.Run(ctx)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("terminated session: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|