fix(mediaserver): handle custom hostname with self-signed certs

This commit is contained in:
Rob Watson 2025-04-20 18:59:27 +02:00
parent 4a863a3212
commit 7afa84505e
4 changed files with 82 additions and 56 deletions

View File

@ -296,6 +296,74 @@ func testIntegration(t *testing.T, mediaServerConfig config.MediaServerSource) {
<-done <-done
} }
func TestIntegrationCustomHost(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Minute)
defer cancel()
logger := testhelpers.NewTestLogger(t).With("component", "integration")
dockerClient, err := dockerclient.NewClientWithOpts(dockerclient.FromEnv, dockerclient.WithAPIVersionNegotiation())
require.NoError(t, err)
configService := setupConfigService(t, config.Config{
Sources: config.Sources{
MediaServer: config.MediaServerSource{
Host: "rtmp.example.com",
RTMP: config.RTMPSource{Enabled: true},
},
},
})
screen, screenCaptureC, getContents := setupSimulationScreen(t)
done := make(chan struct{})
go func() {
defer func() {
done <- struct{}{}
}()
require.NoError(t, app.Run(ctx, buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger)))
}()
time.Sleep(time.Second)
sendKey(t, screen, tcell.KeyF1, ' ')
require.EventuallyWithT(
t,
func(t *assert.CollectT) {
assert.True(t, contentsIncludes(getContents(), "rtmp://rtmp.example.com:1935/live"), "expected to see custom host name")
},
waitTime,
time.Second,
"expected to see custom host name",
)
printScreen(t, getContents, "Ater opening the app with a custom host name")
require.EventuallyWithT(
t,
func(c *assert.CollectT) {
conn, err := tls.Dial("tcp", "localhost:9997", &tls.Config{
InsecureSkipVerify: true,
})
require.NoError(c, err)
require.Nil(
c,
conn.
ConnectionState().
PeerCertificates[0].
VerifyHostname("rtmp.example.com"),
"expected to verify custom host name",
)
},
waitTime,
time.Second,
"expected to connect to API using self-signed TLS cert with custom host name",
)
cancel()
<-done
}
func TestIntegrationCustomTLSCerts(t *testing.T) { func TestIntegrationCustomTLSCerts(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Minute) ctx, cancel := context.WithTimeout(t.Context(), 10*time.Minute)
defer cancel() defer cancel()
@ -335,7 +403,6 @@ func TestIntegrationCustomTLSCerts(t *testing.T) {
block, _ := pem.Decode(certPEM) block, _ := pem.Decode(certPEM)
require.NotNil(c, block, "failed to decode PEM block containing certificate") require.NotNil(c, block, "failed to decode PEM block containing certificate")
require.True(c, block.Type == "CERTIFICATE", "expected PEM block to be a certificate") require.True(c, block.Type == "CERTIFICATE", "expected PEM block to be a certificate")
certDERBytes := block.Bytes
rootCAs := x509.NewCertPool() rootCAs := x509.NewCertPool()
require.True(c, rootCAs.AppendCertsFromPEM(certPEM), "failed to append cert to root CA pool") require.True(c, rootCAs.AppendCertsFromPEM(certPEM), "failed to append cert to root CA pool")
@ -347,12 +414,10 @@ func TestIntegrationCustomTLSCerts(t *testing.T) {
}) })
require.NoError(c, err) require.NoError(c, err)
state := conn.ConnectionState() peerCert := conn.ConnectionState().PeerCertificates[0]
peerCert := state.PeerCertificates[0] wantCert, err := x509.ParseCertificate(block.Bytes)
expectedCert, err := x509.ParseCertificate(certDERBytes)
require.NoError(c, err) require.NoError(c, err)
require.True(c, peerCert.Equal(expectedCert), "expected peer certificate to match the expected certificate") require.True(c, peerCert.Equal(wantCert), "expected peer certificate to match the expected certificate")
}, },
waitTime, waitTime,
time.Second, time.Second,
@ -365,52 +430,6 @@ func TestIntegrationCustomTLSCerts(t *testing.T) {
<-done <-done
} }
func TestIntegrationCustomRTMPURL(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Minute)
defer cancel()
logger := testhelpers.NewTestLogger(t).With("component", "integration")
dockerClient, err := dockerclient.NewClientWithOpts(dockerclient.FromEnv, dockerclient.WithAPIVersionNegotiation())
require.NoError(t, err)
configService := setupConfigService(t, config.Config{
Sources: config.Sources{
MediaServer: config.MediaServerSource{
Host: "rtmp.live.tv",
RTMP: config.RTMPSource{Enabled: true},
},
},
})
screen, screenCaptureC, getContents := setupSimulationScreen(t)
done := make(chan struct{})
go func() {
defer func() {
done <- struct{}{}
}()
require.NoError(t, app.Run(ctx, buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger)))
}()
time.Sleep(time.Second)
sendKey(t, screen, tcell.KeyF1, ' ')
require.EventuallyWithT(
t,
func(t *assert.CollectT) {
assert.True(t, contentsIncludes(getContents(), "rtmp://rtmp.live.tv:1935/live"), "expected to see custom host name")
},
waitTime,
time.Second,
"expected to see custom host name",
)
printScreen(t, getContents, "Ater opening the app with a custom host name")
cancel()
<-done
}
func TestIntegrationRestartDestination(t *testing.T) { func TestIntegrationRestartDestination(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Minute) ctx, cancel := context.WithTimeout(t.Context(), 10*time.Minute)
defer cancel() defer cancel()

View File

@ -98,7 +98,12 @@ type OptionalNetAddr struct {
// //
// Callers must consume the state channel exposed via [C]. // Callers must consume the state channel exposed via [C].
func NewActor(ctx context.Context, params NewActorParams) (_ *Actor, err error) { func NewActor(ctx context.Context, params NewActorParams) (_ *Actor, err error) {
keyPairInternal, err := generateTLSCert() dnsNames := []string{"localhost"}
if params.Host != "" {
dnsNames = append(dnsNames, params.Host)
}
keyPairInternal, err := generateTLSCert(dnsNames...)
if err != nil { if err != nil {
return nil, fmt.Errorf("generate TLS cert: %w", err) return nil, fmt.Errorf("generate TLS cert: %w", err)
} }

View File

@ -15,7 +15,7 @@ import (
) )
// generateTLSCert generates a self-signed TLS certificate and private key. // generateTLSCert generates a self-signed TLS certificate and private key.
func generateTLSCert() (domain.KeyPair, error) { func generateTLSCert(dnsNames ...string) (domain.KeyPair, error) {
privKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) privKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil { if err != nil {
return domain.KeyPair{}, err return domain.KeyPair{}, err
@ -37,7 +37,7 @@ func generateTLSCert() (domain.KeyPair, error) {
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true, BasicConstraintsValid: true,
DNSNames: []string{"localhost"}, DNSNames: dnsNames,
} }
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privKey.PublicKey, privKey) certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privKey.PublicKey, privKey)

View File

@ -12,7 +12,7 @@ import (
) )
func TestGenerateTLSCert(t *testing.T) { func TestGenerateTLSCert(t *testing.T) {
keyPair, err := generateTLSCert() keyPair, err := generateTLSCert("localhost", "rtmp.example.com")
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, keyPair.Cert) require.NotEmpty(t, keyPair.Cert)
require.NotEmpty(t, keyPair.Key) require.NotEmpty(t, keyPair.Key)
@ -33,6 +33,8 @@ func TestGenerateTLSCert(t *testing.T) {
assert.True(t, cert.BasicConstraintsValid, "basic constraints should be valid") assert.True(t, cert.BasicConstraintsValid, "basic constraints should be valid")
assert.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageServerAuth) assert.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageServerAuth)
assert.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageClientAuth) assert.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageClientAuth)
assert.Contains(t, cert.DNSNames, "localhost", "DNS names should include localhost")
assert.Contains(t, cert.DNSNames, "rtmp.example.com", "DNS names should include rtmp.example.com")
block, _ = pem.Decode(keyPair.Key) block, _ = pem.Decode(keyPair.Key)
require.NotNil(t, block, "failed to decode private key PEM") require.NotNil(t, block, "failed to decode private key PEM")