feat(mediaserver): use TLS for API endpoints
This commit is contained in:
parent
bdb77cb6bb
commit
2cde04728a
@ -53,7 +53,8 @@ func Run(ctx context.Context, params RunParams) error {
|
||||
updateUI()
|
||||
|
||||
// TODO: check for unused networks.
|
||||
if exists, err := containerClient.ContainerRunning(ctx, container.AllContainers()); err != nil {
|
||||
var exists bool
|
||||
if exists, err = containerClient.ContainerRunning(ctx, container.AllContainers()); err != nil {
|
||||
return fmt.Errorf("check existing containers: %w", err)
|
||||
} else if exists {
|
||||
if ui.ShowStartupCheckModal() {
|
||||
@ -74,11 +75,14 @@ func Run(ctx context.Context, params RunParams) error {
|
||||
return errors.New("config: sources.rtmp.enabled must be set to true")
|
||||
}
|
||||
|
||||
srv := mediaserver.StartActor(ctx, mediaserver.StartActorParams{
|
||||
srv, err := mediaserver.StartActor(ctx, mediaserver.StartActorParams{
|
||||
StreamKey: mediaserver.StreamKey(params.Config.Sources.RTMP.StreamKey),
|
||||
ContainerClient: containerClient,
|
||||
Logger: logger.With("component", "mediaserver"),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("start mediaserver: %w", err)
|
||||
}
|
||||
defer srv.Close()
|
||||
|
||||
mp := multiplexer.NewActor(ctx, multiplexer.NewActorParams{
|
||||
|
@ -53,7 +53,7 @@ type Actor struct {
|
||||
fetchIngressStateInterval time.Duration
|
||||
pass string // password for the media server
|
||||
logger *slog.Logger
|
||||
httpClient *http.Client
|
||||
apiClient *http.Client
|
||||
|
||||
// mutable state
|
||||
state *domain.Source
|
||||
@ -74,10 +74,26 @@ type StartActorParams struct {
|
||||
// StartActor starts a new media server actor.
|
||||
//
|
||||
// Callers must consume the state channel exposed via [C].
|
||||
func StartActor(ctx context.Context, params StartActorParams) *Actor {
|
||||
chanSize := cmp.Or(params.ChanSize, defaultChanSize)
|
||||
func StartActor(ctx context.Context, params StartActorParams) (_ *Actor, err error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
// if err is nil, the context should not be cancelled.
|
||||
if err != nil {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
tlsCert, tlsKey, err := generateTLSCert()
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("generate TLS cert: %w", err)
|
||||
}
|
||||
apiClient, err := buildAPIClient(tlsCert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build API client: %w", err)
|
||||
}
|
||||
|
||||
chanSize := cmp.Or(params.ChanSize, defaultChanSize)
|
||||
actor := &Actor{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
@ -91,7 +107,7 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor {
|
||||
stateC: make(chan domain.Source, chanSize),
|
||||
containerClient: params.ContainerClient,
|
||||
logger: params.Logger,
|
||||
httpClient: &http.Client{Timeout: httpClientTimeout},
|
||||
apiClient: apiClient,
|
||||
}
|
||||
|
||||
apiPortSpec := nat.Port(strconv.Itoa(actor.apiPort) + ":9997")
|
||||
@ -104,7 +120,6 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor {
|
||||
LogDestinations: []string{"stdout"},
|
||||
AuthMethod: "internal",
|
||||
AuthInternalUsers: []User{
|
||||
// TODO: TLS
|
||||
{
|
||||
User: "any",
|
||||
IPs: []string{}, // any IP
|
||||
@ -127,16 +142,17 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor {
|
||||
Permissions: []UserPermission{{Action: "api"}},
|
||||
},
|
||||
},
|
||||
API: true,
|
||||
API: true,
|
||||
APIEncryption: true,
|
||||
APIServerCert: "/etc/tls.crt",
|
||||
APIServerKey: "/etc/tls.key",
|
||||
Paths: map[string]Path{
|
||||
string(actor.streamKey): {
|
||||
Source: "publisher",
|
||||
},
|
||||
string(actor.streamKey): {Source: "publisher"},
|
||||
},
|
||||
},
|
||||
)
|
||||
if err != nil { // should never happen
|
||||
panic(fmt.Sprintf("failed to marshal config: %v", err))
|
||||
return nil, fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
|
||||
containerStateC, errC := params.ContainerClient.RunContainer(
|
||||
@ -147,15 +163,13 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor {
|
||||
ContainerConfig: &typescontainer.Config{
|
||||
Image: imageNameMediaMTX,
|
||||
Hostname: "mediaserver",
|
||||
Env: []string{
|
||||
"MTX_LOGLEVEL=info",
|
||||
"MTX_API=yes",
|
||||
},
|
||||
Labels: map[string]string{
|
||||
container.LabelComponent: componentName,
|
||||
},
|
||||
Labels: map[string]string{container.LabelComponent: componentName},
|
||||
Env: []string{"TLS_CERT=" + string(tlsCert)},
|
||||
Healthcheck: &typescontainer.HealthConfig{
|
||||
Test: []string{"CMD", "curl", "-f", actor.pathsURL()},
|
||||
Test: []string{
|
||||
"CMD-SHELL",
|
||||
`echo "$TLS_CERT" | curl --fail --silent --cacert /dev/stdin ` + actor.pathsURL() + ` || exit 1`,
|
||||
},
|
||||
Interval: time.Second * 10,
|
||||
StartPeriod: time.Second * 2,
|
||||
StartInterval: time.Second * 2,
|
||||
@ -174,6 +188,16 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor {
|
||||
Payload: bytes.NewReader(cfg),
|
||||
Mode: 0600,
|
||||
},
|
||||
{
|
||||
Path: "/etc/tls.crt",
|
||||
Payload: bytes.NewReader(tlsCert),
|
||||
Mode: 0600,
|
||||
},
|
||||
{
|
||||
Path: "/etc/tls.key",
|
||||
Payload: bytes.NewReader(tlsKey),
|
||||
Mode: 0600,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
@ -183,7 +207,7 @@ func StartActor(ctx context.Context, params StartActorParams) *Actor {
|
||||
|
||||
go actor.actorLoop(containerStateC, errC)
|
||||
|
||||
return actor
|
||||
return actor, nil
|
||||
}
|
||||
|
||||
// C returns a channel that will receive the current state of the media server.
|
||||
@ -260,7 +284,7 @@ func (s *Actor) actorLoop(containerStateC <-chan domain.Container, errC <-chan e
|
||||
|
||||
sendState()
|
||||
case <-fetchStateT.C:
|
||||
ingressState, err := fetchIngressState(s.rtmpConnsURL(), s.streamKey, s.httpClient)
|
||||
ingressState, err := fetchIngressState(s.rtmpConnsURL(), s.streamKey, s.apiClient)
|
||||
if err != nil {
|
||||
s.logger.Error("Error fetching server state", "err", err)
|
||||
continue
|
||||
@ -285,7 +309,7 @@ func (s *Actor) actorLoop(containerStateC <-chan domain.Container, errC <-chan e
|
||||
continue
|
||||
}
|
||||
|
||||
if tracks, err := fetchTracks(s.pathsURL(), s.streamKey, s.httpClient); err != nil {
|
||||
if tracks, err := fetchTracks(s.pathsURL(), s.streamKey, s.apiClient); err != nil {
|
||||
s.logger.Error("Error fetching tracks", "err", err)
|
||||
resetFetchTracksT(3 * time.Second)
|
||||
} else if len(tracks) == 0 {
|
||||
@ -334,12 +358,12 @@ func (s *Actor) rtmpInternalURL() string {
|
||||
// rtmpConnsURL returns the URL for fetching RTMP connections, accessible from
|
||||
// the host.
|
||||
func (s *Actor) rtmpConnsURL() string {
|
||||
return fmt.Sprintf("http://api:%s@localhost:%d/v3/rtmpconns/list", s.pass, s.apiPort)
|
||||
return fmt.Sprintf("https://api:%s@localhost:%d/v3/rtmpconns/list", s.pass, s.apiPort)
|
||||
}
|
||||
|
||||
// pathsURL returns the URL for fetching paths, accessible from the host.
|
||||
func (s *Actor) pathsURL() string {
|
||||
return fmt.Sprintf("http://api:%s@localhost:%d/v3/paths/list", s.pass, s.apiPort)
|
||||
return fmt.Sprintf("https://api:%s@localhost:%d/v3/paths/list", s.pass, s.apiPort)
|
||||
}
|
||||
|
||||
// shortID returns the first 12 characters of the given container ID.
|
||||
|
@ -1,7 +1,10 @@
|
||||
package mediaserver
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -12,6 +15,35 @@ type httpClient interface {
|
||||
Do(*http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func buildAPIClient(certPEM []byte) (*http.Client, error) {
|
||||
certPool := x509.NewCertPool()
|
||||
if !certPool.AppendCertsFromPEM(certPEM) {
|
||||
return nil, errors.New("failed to add certificate to pool")
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
cert, err := x509.ParseCertificate(rawCerts[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse certificate: %w", err)
|
||||
}
|
||||
|
||||
if _, err := cert.Verify(x509.VerifyOptions{Roots: certPool}); err != nil {
|
||||
return fmt.Errorf("TLS verification: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
const userAgent = "octoplex-client"
|
||||
|
||||
type apiResponse[T any] struct {
|
||||
Items []T `json:"items"`
|
||||
}
|
||||
@ -37,6 +69,7 @@ func fetchIngressState(apiURL string, streamKey StreamKey, httpClient httpClient
|
||||
if err != nil {
|
||||
return state, fmt.Errorf("new request: %w", err)
|
||||
}
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
httpResp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
@ -87,6 +120,7 @@ func fetchTracks(apiURL string, streamKey StreamKey, httpClient httpClient) ([]s
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request: %w", err)
|
||||
}
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
httpResp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
|
@ -14,6 +14,9 @@ type Config struct {
|
||||
MetricsAddress string `yaml:"metricsAddress,omitempty"`
|
||||
API bool `yaml:"api,omitempty"`
|
||||
APIAddr bool `yaml:"apiAddress,omitempty"`
|
||||
APIEncryption bool `yaml:"apiEncryption,omitempty"`
|
||||
APIServerCert string `yaml:"apiServerCert,omitempty"`
|
||||
APIServerKey string `yaml:"apiServerKey,omitempty"`
|
||||
RTMP bool `yaml:"rtmp,omitempty"`
|
||||
RTMPAddress string `yaml:"rtmpAddress,omitempty"`
|
||||
HLS bool `yaml:"hls"`
|
||||
|
67
internal/mediaserver/tls.go
Normal file
67
internal/mediaserver/tls.go
Normal file
@ -0,0 +1,67 @@
|
||||
package mediaserver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"time"
|
||||
)
|
||||
|
||||
type (
|
||||
tlsCert []byte
|
||||
tlsKey []byte
|
||||
)
|
||||
|
||||
// generateTLSCert generates a self-signed TLS certificate and private key.
|
||||
func generateTLSCert() (tlsCert, tlsKey, error) {
|
||||
privKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"octoplex.netflux.io"},
|
||||
},
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(5 * 365 * 24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privKey.PublicKey, privKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var certPEM, keyPEM bytes.Buffer
|
||||
|
||||
if err = pem.Encode(&certPEM, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
privKeyDER, err := x509.MarshalECPrivateKey(privKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := pem.Encode(&keyPEM, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privKeyDER}); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return certPEM.Bytes(), keyPEM.Bytes(), nil
|
||||
}
|
45
internal/mediaserver/tls_test.go
Normal file
45
internal/mediaserver/tls_test.go
Normal file
@ -0,0 +1,45 @@
|
||||
package mediaserver
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGenerateTLSCert(t *testing.T) {
|
||||
certPEM, keyPEM, err := generateTLSCert()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, certPEM)
|
||||
require.NotEmpty(t, keyPEM)
|
||||
|
||||
block, _ := pem.Decode(certPEM)
|
||||
require.NotNil(t, block, "failed to decode certificate PEM")
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "octoplex.netflux.io", cert.Subject.Organization[0])
|
||||
assert.Greater(t, cert.NotBefore, time.Now().Add(-time.Second), "not before should be in the future")
|
||||
assert.Greater(t, cert.NotAfter, time.Now().Add(4*365*24*time.Hour), "not after should be a long time in the future")
|
||||
|
||||
// BitLen does not count leading zeroes, so the length will not always be 128 bits:
|
||||
assert.GreaterOrEqual(t, cert.SerialNumber.BitLen(), 100, "serial number should be around 128 bits")
|
||||
|
||||
assert.True(t, cert.BasicConstraintsValid, "basic constraints should be valid")
|
||||
assert.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageServerAuth)
|
||||
assert.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageClientAuth)
|
||||
|
||||
block, _ = pem.Decode(keyPEM)
|
||||
require.NotNil(t, block, "failed to decode private key PEM")
|
||||
|
||||
privKey, err := x509.ParseECPrivateKey(block.Bytes)
|
||||
require.NoError(t, err)
|
||||
assert.IsType(t, &ecdsa.PrivateKey{}, privKey, "expected ECDSA private key")
|
||||
|
||||
assert.True(t, privKey.PublicKey.Equal(cert.PublicKey), "private key should match the certificate's public key")
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user