From 2cde04728a86ec6fd9c77e577e461bdc59dbcd1e Mon Sep 17 00:00:00 2001 From: Rob Watson Date: Thu, 27 Mar 2025 08:18:22 +0100 Subject: [PATCH] feat(mediaserver): use TLS for API endpoints --- internal/app/app.go | 8 +++- internal/mediaserver/actor.go | 70 +++++++++++++++++++++----------- internal/mediaserver/api.go | 34 ++++++++++++++++ internal/mediaserver/config.go | 3 ++ internal/mediaserver/tls.go | 67 ++++++++++++++++++++++++++++++ internal/mediaserver/tls_test.go | 45 ++++++++++++++++++++ 6 files changed, 202 insertions(+), 25 deletions(-) create mode 100644 internal/mediaserver/tls.go create mode 100644 internal/mediaserver/tls_test.go diff --git a/internal/app/app.go b/internal/app/app.go index 22b250e..e5a64a6 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -53,7 +53,8 @@ func Run(ctx context.Context, params RunParams) error { updateUI() // TODO: check for unused networks. - if exists, err := containerClient.ContainerRunning(ctx, container.AllContainers()); err != nil { + var exists bool + if exists, err = containerClient.ContainerRunning(ctx, container.AllContainers()); err != nil { return fmt.Errorf("check existing containers: %w", err) } else if exists { if ui.ShowStartupCheckModal() { @@ -74,11 +75,14 @@ func Run(ctx context.Context, params RunParams) error { return errors.New("config: sources.rtmp.enabled must be set to true") } - srv := mediaserver.StartActor(ctx, mediaserver.StartActorParams{ + srv, err := mediaserver.StartActor(ctx, mediaserver.StartActorParams{ StreamKey: mediaserver.StreamKey(params.Config.Sources.RTMP.StreamKey), ContainerClient: containerClient, Logger: logger.With("component", "mediaserver"), }) + if err != nil { + return fmt.Errorf("start mediaserver: %w", err) + } defer srv.Close() mp := multiplexer.NewActor(ctx, multiplexer.NewActorParams{ diff --git a/internal/mediaserver/actor.go b/internal/mediaserver/actor.go index cc64e4f..fa11d29 100644 --- a/internal/mediaserver/actor.go +++ b/internal/mediaserver/actor.go @@ -53,7 +53,7 @@ type Actor struct { fetchIngressStateInterval time.Duration pass string // password for the media server logger *slog.Logger - httpClient *http.Client + apiClient *http.Client // mutable state state *domain.Source @@ -74,10 +74,26 @@ type StartActorParams struct { // StartActor starts a new media server actor. // // Callers must consume the state channel exposed via [C]. -func StartActor(ctx context.Context, params StartActorParams) *Actor { - chanSize := cmp.Or(params.ChanSize, defaultChanSize) +func StartActor(ctx context.Context, params StartActorParams) (_ *Actor, err error) { ctx, cancel := context.WithCancel(ctx) + defer func() { + // if err is nil, the context should not be cancelled. + if err != nil { + cancel() + } + }() + tlsCert, tlsKey, err := generateTLSCert() + if err != nil { + cancel() + return nil, fmt.Errorf("generate TLS cert: %w", err) + } + apiClient, err := buildAPIClient(tlsCert) + if err != nil { + return nil, fmt.Errorf("build API client: %w", err) + } + + chanSize := cmp.Or(params.ChanSize, defaultChanSize) actor := &Actor{ ctx: ctx, cancel: cancel, @@ -91,7 +107,7 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor { stateC: make(chan domain.Source, chanSize), containerClient: params.ContainerClient, logger: params.Logger, - httpClient: &http.Client{Timeout: httpClientTimeout}, + apiClient: apiClient, } apiPortSpec := nat.Port(strconv.Itoa(actor.apiPort) + ":9997") @@ -104,7 +120,6 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor { LogDestinations: []string{"stdout"}, AuthMethod: "internal", AuthInternalUsers: []User{ - // TODO: TLS { User: "any", IPs: []string{}, // any IP @@ -127,16 +142,17 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor { Permissions: []UserPermission{{Action: "api"}}, }, }, - API: true, + API: true, + APIEncryption: true, + APIServerCert: "/etc/tls.crt", + APIServerKey: "/etc/tls.key", Paths: map[string]Path{ - string(actor.streamKey): { - Source: "publisher", - }, + string(actor.streamKey): {Source: "publisher"}, }, }, ) if err != nil { // should never happen - panic(fmt.Sprintf("failed to marshal config: %v", err)) + return nil, fmt.Errorf("marshal config: %w", err) } containerStateC, errC := params.ContainerClient.RunContainer( @@ -147,15 +163,13 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor { ContainerConfig: &typescontainer.Config{ Image: imageNameMediaMTX, Hostname: "mediaserver", - Env: []string{ - "MTX_LOGLEVEL=info", - "MTX_API=yes", - }, - Labels: map[string]string{ - container.LabelComponent: componentName, - }, + Labels: map[string]string{container.LabelComponent: componentName}, + Env: []string{"TLS_CERT=" + string(tlsCert)}, Healthcheck: &typescontainer.HealthConfig{ - Test: []string{"CMD", "curl", "-f", actor.pathsURL()}, + Test: []string{ + "CMD-SHELL", + `echo "$TLS_CERT" | curl --fail --silent --cacert /dev/stdin ` + actor.pathsURL() + ` || exit 1`, + }, Interval: time.Second * 10, StartPeriod: time.Second * 2, StartInterval: time.Second * 2, @@ -174,6 +188,16 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor { Payload: bytes.NewReader(cfg), Mode: 0600, }, + { + Path: "/etc/tls.crt", + Payload: bytes.NewReader(tlsCert), + Mode: 0600, + }, + { + Path: "/etc/tls.key", + Payload: bytes.NewReader(tlsKey), + Mode: 0600, + }, }, }, ) @@ -183,7 +207,7 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor { go actor.actorLoop(containerStateC, errC) - return actor + return actor, nil } // C returns a channel that will receive the current state of the media server. @@ -260,7 +284,7 @@ func (s *Actor) actorLoop(containerStateC <-chan domain.Container, errC <-chan e sendState() case <-fetchStateT.C: - ingressState, err := fetchIngressState(s.rtmpConnsURL(), s.streamKey, s.httpClient) + ingressState, err := fetchIngressState(s.rtmpConnsURL(), s.streamKey, s.apiClient) if err != nil { s.logger.Error("Error fetching server state", "err", err) continue @@ -285,7 +309,7 @@ func (s *Actor) actorLoop(containerStateC <-chan domain.Container, errC <-chan e continue } - if tracks, err := fetchTracks(s.pathsURL(), s.streamKey, s.httpClient); err != nil { + if tracks, err := fetchTracks(s.pathsURL(), s.streamKey, s.apiClient); err != nil { s.logger.Error("Error fetching tracks", "err", err) resetFetchTracksT(3 * time.Second) } else if len(tracks) == 0 { @@ -334,12 +358,12 @@ func (s *Actor) rtmpInternalURL() string { // rtmpConnsURL returns the URL for fetching RTMP connections, accessible from // the host. func (s *Actor) rtmpConnsURL() string { - return fmt.Sprintf("http://api:%s@localhost:%d/v3/rtmpconns/list", s.pass, s.apiPort) + return fmt.Sprintf("https://api:%s@localhost:%d/v3/rtmpconns/list", s.pass, s.apiPort) } // pathsURL returns the URL for fetching paths, accessible from the host. func (s *Actor) pathsURL() string { - return fmt.Sprintf("http://api:%s@localhost:%d/v3/paths/list", s.pass, s.apiPort) + return fmt.Sprintf("https://api:%s@localhost:%d/v3/paths/list", s.pass, s.apiPort) } // shortID returns the first 12 characters of the given container ID. diff --git a/internal/mediaserver/api.go b/internal/mediaserver/api.go index e216ea1..df7b30e 100644 --- a/internal/mediaserver/api.go +++ b/internal/mediaserver/api.go @@ -1,7 +1,10 @@ package mediaserver import ( + "crypto/tls" + "crypto/x509" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -12,6 +15,35 @@ type httpClient interface { Do(*http.Request) (*http.Response, error) } +func buildAPIClient(certPEM []byte) (*http.Client, error) { + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(certPEM) { + return nil, errors.New("failed to add certificate to pool") + } + + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + cert, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + return fmt.Errorf("parse certificate: %w", err) + } + + if _, err := cert.Verify(x509.VerifyOptions{Roots: certPool}); err != nil { + return fmt.Errorf("TLS verification: %w", err) + } + + return nil + }, + }, + }, + }, nil +} + +const userAgent = "octoplex-client" + type apiResponse[T any] struct { Items []T `json:"items"` } @@ -37,6 +69,7 @@ func fetchIngressState(apiURL string, streamKey StreamKey, httpClient httpClient if err != nil { return state, fmt.Errorf("new request: %w", err) } + req.Header.Set("User-Agent", userAgent) httpResp, err := httpClient.Do(req) if err != nil { @@ -87,6 +120,7 @@ func fetchTracks(apiURL string, streamKey StreamKey, httpClient httpClient) ([]s if err != nil { return nil, fmt.Errorf("new request: %w", err) } + req.Header.Set("User-Agent", userAgent) httpResp, err := httpClient.Do(req) if err != nil { diff --git a/internal/mediaserver/config.go b/internal/mediaserver/config.go index 5a9beec..45713b6 100644 --- a/internal/mediaserver/config.go +++ b/internal/mediaserver/config.go @@ -14,6 +14,9 @@ type Config struct { MetricsAddress string `yaml:"metricsAddress,omitempty"` API bool `yaml:"api,omitempty"` APIAddr bool `yaml:"apiAddress,omitempty"` + APIEncryption bool `yaml:"apiEncryption,omitempty"` + APIServerCert string `yaml:"apiServerCert,omitempty"` + APIServerKey string `yaml:"apiServerKey,omitempty"` RTMP bool `yaml:"rtmp,omitempty"` RTMPAddress string `yaml:"rtmpAddress,omitempty"` HLS bool `yaml:"hls"` diff --git a/internal/mediaserver/tls.go b/internal/mediaserver/tls.go new file mode 100644 index 0000000..bff6139 --- /dev/null +++ b/internal/mediaserver/tls.go @@ -0,0 +1,67 @@ +package mediaserver + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "time" +) + +type ( + tlsCert []byte + tlsKey []byte +) + +// generateTLSCert generates a self-signed TLS certificate and private key. +func generateTLSCert() (tlsCert, tlsKey, error) { + privKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + return nil, nil, err + } + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, nil, err + } + + now := time.Now() + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"octoplex.netflux.io"}, + }, + NotBefore: now, + NotAfter: now.Add(5 * 365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privKey.PublicKey, privKey) + if err != nil { + return nil, nil, err + } + + var certPEM, keyPEM bytes.Buffer + + if err = pem.Encode(&certPEM, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil { + return nil, nil, err + } + + privKeyDER, err := x509.MarshalECPrivateKey(privKey) + if err != nil { + return nil, nil, err + } + + if err := pem.Encode(&keyPEM, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privKeyDER}); err != nil { + return nil, nil, err + } + + return certPEM.Bytes(), keyPEM.Bytes(), nil +} diff --git a/internal/mediaserver/tls_test.go b/internal/mediaserver/tls_test.go new file mode 100644 index 0000000..94c2113 --- /dev/null +++ b/internal/mediaserver/tls_test.go @@ -0,0 +1,45 @@ +package mediaserver + +import ( + "crypto/ecdsa" + "crypto/x509" + "encoding/pem" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenerateTLSCert(t *testing.T) { + certPEM, keyPEM, err := generateTLSCert() + require.NoError(t, err) + require.NotEmpty(t, certPEM) + require.NotEmpty(t, keyPEM) + + block, _ := pem.Decode(certPEM) + require.NotNil(t, block, "failed to decode certificate PEM") + + cert, err := x509.ParseCertificate(block.Bytes) + require.NoError(t, err) + + assert.Equal(t, "octoplex.netflux.io", cert.Subject.Organization[0]) + assert.Greater(t, cert.NotBefore, time.Now().Add(-time.Second), "not before should be in the future") + assert.Greater(t, cert.NotAfter, time.Now().Add(4*365*24*time.Hour), "not after should be a long time in the future") + + // BitLen does not count leading zeroes, so the length will not always be 128 bits: + assert.GreaterOrEqual(t, cert.SerialNumber.BitLen(), 100, "serial number should be around 128 bits") + + assert.True(t, cert.BasicConstraintsValid, "basic constraints should be valid") + assert.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageServerAuth) + assert.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageClientAuth) + + block, _ = pem.Decode(keyPEM) + require.NotNil(t, block, "failed to decode private key PEM") + + privKey, err := x509.ParseECPrivateKey(block.Bytes) + require.NoError(t, err) + assert.IsType(t, &ecdsa.PrivateKey{}, privKey, "expected ECDSA private key") + + assert.True(t, privKey.PublicKey.Equal(cert.PublicKey), "private key should match the certificate's public key") +}