feat(mediaserver): use TLS for API endpoints

This commit is contained in:
Rob Watson 2025-03-27 08:18:22 +01:00
parent bdb77cb6bb
commit 2cde04728a
6 changed files with 202 additions and 25 deletions

View File

@ -53,7 +53,8 @@ func Run(ctx context.Context, params RunParams) error {
updateUI() updateUI()
// TODO: check for unused networks. // TODO: check for unused networks.
if exists, err := containerClient.ContainerRunning(ctx, container.AllContainers()); err != nil { var exists bool
if exists, err = containerClient.ContainerRunning(ctx, container.AllContainers()); err != nil {
return fmt.Errorf("check existing containers: %w", err) return fmt.Errorf("check existing containers: %w", err)
} else if exists { } else if exists {
if ui.ShowStartupCheckModal() { if ui.ShowStartupCheckModal() {
@ -74,11 +75,14 @@ func Run(ctx context.Context, params RunParams) error {
return errors.New("config: sources.rtmp.enabled must be set to true") return errors.New("config: sources.rtmp.enabled must be set to true")
} }
srv := mediaserver.StartActor(ctx, mediaserver.StartActorParams{ srv, err := mediaserver.StartActor(ctx, mediaserver.StartActorParams{
StreamKey: mediaserver.StreamKey(params.Config.Sources.RTMP.StreamKey), StreamKey: mediaserver.StreamKey(params.Config.Sources.RTMP.StreamKey),
ContainerClient: containerClient, ContainerClient: containerClient,
Logger: logger.With("component", "mediaserver"), Logger: logger.With("component", "mediaserver"),
}) })
if err != nil {
return fmt.Errorf("start mediaserver: %w", err)
}
defer srv.Close() defer srv.Close()
mp := multiplexer.NewActor(ctx, multiplexer.NewActorParams{ mp := multiplexer.NewActor(ctx, multiplexer.NewActorParams{

View File

@ -53,7 +53,7 @@ type Actor struct {
fetchIngressStateInterval time.Duration fetchIngressStateInterval time.Duration
pass string // password for the media server pass string // password for the media server
logger *slog.Logger logger *slog.Logger
httpClient *http.Client apiClient *http.Client
// mutable state // mutable state
state *domain.Source state *domain.Source
@ -74,10 +74,26 @@ type StartActorParams struct {
// StartActor starts a new media server actor. // StartActor starts a new media server actor.
// //
// Callers must consume the state channel exposed via [C]. // Callers must consume the state channel exposed via [C].
func StartActor(ctx context.Context, params StartActorParams) *Actor { func StartActor(ctx context.Context, params StartActorParams) (_ *Actor, err error) {
chanSize := cmp.Or(params.ChanSize, defaultChanSize)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer func() {
// if err is nil, the context should not be cancelled.
if err != nil {
cancel()
}
}()
tlsCert, tlsKey, err := generateTLSCert()
if err != nil {
cancel()
return nil, fmt.Errorf("generate TLS cert: %w", err)
}
apiClient, err := buildAPIClient(tlsCert)
if err != nil {
return nil, fmt.Errorf("build API client: %w", err)
}
chanSize := cmp.Or(params.ChanSize, defaultChanSize)
actor := &Actor{ actor := &Actor{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
@ -91,7 +107,7 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor {
stateC: make(chan domain.Source, chanSize), stateC: make(chan domain.Source, chanSize),
containerClient: params.ContainerClient, containerClient: params.ContainerClient,
logger: params.Logger, logger: params.Logger,
httpClient: &http.Client{Timeout: httpClientTimeout}, apiClient: apiClient,
} }
apiPortSpec := nat.Port(strconv.Itoa(actor.apiPort) + ":9997") apiPortSpec := nat.Port(strconv.Itoa(actor.apiPort) + ":9997")
@ -104,7 +120,6 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor {
LogDestinations: []string{"stdout"}, LogDestinations: []string{"stdout"},
AuthMethod: "internal", AuthMethod: "internal",
AuthInternalUsers: []User{ AuthInternalUsers: []User{
// TODO: TLS
{ {
User: "any", User: "any",
IPs: []string{}, // any IP IPs: []string{}, // any IP
@ -127,16 +142,17 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor {
Permissions: []UserPermission{{Action: "api"}}, Permissions: []UserPermission{{Action: "api"}},
}, },
}, },
API: true, API: true,
APIEncryption: true,
APIServerCert: "/etc/tls.crt",
APIServerKey: "/etc/tls.key",
Paths: map[string]Path{ Paths: map[string]Path{
string(actor.streamKey): { string(actor.streamKey): {Source: "publisher"},
Source: "publisher",
},
}, },
}, },
) )
if err != nil { // should never happen if err != nil { // should never happen
panic(fmt.Sprintf("failed to marshal config: %v", err)) return nil, fmt.Errorf("marshal config: %w", err)
} }
containerStateC, errC := params.ContainerClient.RunContainer( containerStateC, errC := params.ContainerClient.RunContainer(
@ -147,15 +163,13 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor {
ContainerConfig: &typescontainer.Config{ ContainerConfig: &typescontainer.Config{
Image: imageNameMediaMTX, Image: imageNameMediaMTX,
Hostname: "mediaserver", Hostname: "mediaserver",
Env: []string{ Labels: map[string]string{container.LabelComponent: componentName},
"MTX_LOGLEVEL=info", Env: []string{"TLS_CERT=" + string(tlsCert)},
"MTX_API=yes",
},
Labels: map[string]string{
container.LabelComponent: componentName,
},
Healthcheck: &typescontainer.HealthConfig{ Healthcheck: &typescontainer.HealthConfig{
Test: []string{"CMD", "curl", "-f", actor.pathsURL()}, Test: []string{
"CMD-SHELL",
`echo "$TLS_CERT" | curl --fail --silent --cacert /dev/stdin ` + actor.pathsURL() + ` || exit 1`,
},
Interval: time.Second * 10, Interval: time.Second * 10,
StartPeriod: time.Second * 2, StartPeriod: time.Second * 2,
StartInterval: time.Second * 2, StartInterval: time.Second * 2,
@ -174,6 +188,16 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor {
Payload: bytes.NewReader(cfg), Payload: bytes.NewReader(cfg),
Mode: 0600, Mode: 0600,
}, },
{
Path: "/etc/tls.crt",
Payload: bytes.NewReader(tlsCert),
Mode: 0600,
},
{
Path: "/etc/tls.key",
Payload: bytes.NewReader(tlsKey),
Mode: 0600,
},
}, },
}, },
) )
@ -183,7 +207,7 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor {
go actor.actorLoop(containerStateC, errC) go actor.actorLoop(containerStateC, errC)
return actor return actor, nil
} }
// C returns a channel that will receive the current state of the media server. // C returns a channel that will receive the current state of the media server.
@ -260,7 +284,7 @@ func (s *Actor) actorLoop(containerStateC <-chan domain.Container, errC <-chan e
sendState() sendState()
case <-fetchStateT.C: case <-fetchStateT.C:
ingressState, err := fetchIngressState(s.rtmpConnsURL(), s.streamKey, s.httpClient) ingressState, err := fetchIngressState(s.rtmpConnsURL(), s.streamKey, s.apiClient)
if err != nil { if err != nil {
s.logger.Error("Error fetching server state", "err", err) s.logger.Error("Error fetching server state", "err", err)
continue continue
@ -285,7 +309,7 @@ func (s *Actor) actorLoop(containerStateC <-chan domain.Container, errC <-chan e
continue continue
} }
if tracks, err := fetchTracks(s.pathsURL(), s.streamKey, s.httpClient); err != nil { if tracks, err := fetchTracks(s.pathsURL(), s.streamKey, s.apiClient); err != nil {
s.logger.Error("Error fetching tracks", "err", err) s.logger.Error("Error fetching tracks", "err", err)
resetFetchTracksT(3 * time.Second) resetFetchTracksT(3 * time.Second)
} else if len(tracks) == 0 { } else if len(tracks) == 0 {
@ -334,12 +358,12 @@ func (s *Actor) rtmpInternalURL() string {
// rtmpConnsURL returns the URL for fetching RTMP connections, accessible from // rtmpConnsURL returns the URL for fetching RTMP connections, accessible from
// the host. // the host.
func (s *Actor) rtmpConnsURL() string { func (s *Actor) rtmpConnsURL() string {
return fmt.Sprintf("http://api:%s@localhost:%d/v3/rtmpconns/list", s.pass, s.apiPort) return fmt.Sprintf("https://api:%s@localhost:%d/v3/rtmpconns/list", s.pass, s.apiPort)
} }
// pathsURL returns the URL for fetching paths, accessible from the host. // pathsURL returns the URL for fetching paths, accessible from the host.
func (s *Actor) pathsURL() string { func (s *Actor) pathsURL() string {
return fmt.Sprintf("http://api:%s@localhost:%d/v3/paths/list", s.pass, s.apiPort) return fmt.Sprintf("https://api:%s@localhost:%d/v3/paths/list", s.pass, s.apiPort)
} }
// shortID returns the first 12 characters of the given container ID. // shortID returns the first 12 characters of the given container ID.

View File

@ -1,7 +1,10 @@
package mediaserver package mediaserver
import ( import (
"crypto/tls"
"crypto/x509"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -12,6 +15,35 @@ type httpClient interface {
Do(*http.Request) (*http.Response, error) Do(*http.Request) (*http.Response, error)
} }
func buildAPIClient(certPEM []byte) (*http.Client, error) {
certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM(certPEM) {
return nil, errors.New("failed to add certificate to pool")
}
return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return fmt.Errorf("parse certificate: %w", err)
}
if _, err := cert.Verify(x509.VerifyOptions{Roots: certPool}); err != nil {
return fmt.Errorf("TLS verification: %w", err)
}
return nil
},
},
},
}, nil
}
const userAgent = "octoplex-client"
type apiResponse[T any] struct { type apiResponse[T any] struct {
Items []T `json:"items"` Items []T `json:"items"`
} }
@ -37,6 +69,7 @@ func fetchIngressState(apiURL string, streamKey StreamKey, httpClient httpClient
if err != nil { if err != nil {
return state, fmt.Errorf("new request: %w", err) return state, fmt.Errorf("new request: %w", err)
} }
req.Header.Set("User-Agent", userAgent)
httpResp, err := httpClient.Do(req) httpResp, err := httpClient.Do(req)
if err != nil { if err != nil {
@ -87,6 +120,7 @@ func fetchTracks(apiURL string, streamKey StreamKey, httpClient httpClient) ([]s
if err != nil { if err != nil {
return nil, fmt.Errorf("new request: %w", err) return nil, fmt.Errorf("new request: %w", err)
} }
req.Header.Set("User-Agent", userAgent)
httpResp, err := httpClient.Do(req) httpResp, err := httpClient.Do(req)
if err != nil { if err != nil {

View File

@ -14,6 +14,9 @@ type Config struct {
MetricsAddress string `yaml:"metricsAddress,omitempty"` MetricsAddress string `yaml:"metricsAddress,omitempty"`
API bool `yaml:"api,omitempty"` API bool `yaml:"api,omitempty"`
APIAddr bool `yaml:"apiAddress,omitempty"` APIAddr bool `yaml:"apiAddress,omitempty"`
APIEncryption bool `yaml:"apiEncryption,omitempty"`
APIServerCert string `yaml:"apiServerCert,omitempty"`
APIServerKey string `yaml:"apiServerKey,omitempty"`
RTMP bool `yaml:"rtmp,omitempty"` RTMP bool `yaml:"rtmp,omitempty"`
RTMPAddress string `yaml:"rtmpAddress,omitempty"` RTMPAddress string `yaml:"rtmpAddress,omitempty"`
HLS bool `yaml:"hls"` HLS bool `yaml:"hls"`

View File

@ -0,0 +1,67 @@
package mediaserver
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"time"
)
type (
tlsCert []byte
tlsKey []byte
)
// generateTLSCert generates a self-signed TLS certificate and private key.
func generateTLSCert() (tlsCert, tlsKey, error) {
privKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
return nil, nil, err
}
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, nil, err
}
now := time.Now()
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"octoplex.netflux.io"},
},
NotBefore: now,
NotAfter: now.Add(5 * 365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privKey.PublicKey, privKey)
if err != nil {
return nil, nil, err
}
var certPEM, keyPEM bytes.Buffer
if err = pem.Encode(&certPEM, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
return nil, nil, err
}
privKeyDER, err := x509.MarshalECPrivateKey(privKey)
if err != nil {
return nil, nil, err
}
if err := pem.Encode(&keyPEM, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privKeyDER}); err != nil {
return nil, nil, err
}
return certPEM.Bytes(), keyPEM.Bytes(), nil
}

View File

@ -0,0 +1,45 @@
package mediaserver
import (
"crypto/ecdsa"
"crypto/x509"
"encoding/pem"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGenerateTLSCert(t *testing.T) {
certPEM, keyPEM, err := generateTLSCert()
require.NoError(t, err)
require.NotEmpty(t, certPEM)
require.NotEmpty(t, keyPEM)
block, _ := pem.Decode(certPEM)
require.NotNil(t, block, "failed to decode certificate PEM")
cert, err := x509.ParseCertificate(block.Bytes)
require.NoError(t, err)
assert.Equal(t, "octoplex.netflux.io", cert.Subject.Organization[0])
assert.Greater(t, cert.NotBefore, time.Now().Add(-time.Second), "not before should be in the future")
assert.Greater(t, cert.NotAfter, time.Now().Add(4*365*24*time.Hour), "not after should be a long time in the future")
// BitLen does not count leading zeroes, so the length will not always be 128 bits:
assert.GreaterOrEqual(t, cert.SerialNumber.BitLen(), 100, "serial number should be around 128 bits")
assert.True(t, cert.BasicConstraintsValid, "basic constraints should be valid")
assert.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageServerAuth)
assert.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageClientAuth)
block, _ = pem.Decode(keyPEM)
require.NotNil(t, block, "failed to decode private key PEM")
privKey, err := x509.ParseECPrivateKey(block.Bytes)
require.NoError(t, err)
assert.IsType(t, &ecdsa.PrivateKey{}, privKey, "expected ECDSA private key")
assert.True(t, privKey.PublicKey.Equal(cert.PublicKey), "private key should match the certificate's public key")
}