diff --git a/internal/app/integration_test.go b/internal/app/integration_test.go index a6eaf05..9b47d3f 100644 --- a/internal/app/integration_test.go +++ b/internal/app/integration_test.go @@ -296,6 +296,74 @@ func testIntegration(t *testing.T, mediaServerConfig config.MediaServerSource) { <-done } +func TestIntegrationCustomHost(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Minute) + defer cancel() + + logger := testhelpers.NewTestLogger(t).With("component", "integration") + dockerClient, err := dockerclient.NewClientWithOpts(dockerclient.FromEnv, dockerclient.WithAPIVersionNegotiation()) + require.NoError(t, err) + + configService := setupConfigService(t, config.Config{ + Sources: config.Sources{ + MediaServer: config.MediaServerSource{ + Host: "rtmp.example.com", + RTMP: config.RTMPSource{Enabled: true}, + }, + }, + }) + screen, screenCaptureC, getContents := setupSimulationScreen(t) + + done := make(chan struct{}) + go func() { + defer func() { + done <- struct{}{} + }() + + require.NoError(t, app.Run(ctx, buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger))) + }() + + time.Sleep(time.Second) + sendKey(t, screen, tcell.KeyF1, ' ') + + require.EventuallyWithT( + t, + func(t *assert.CollectT) { + assert.True(t, contentsIncludes(getContents(), "rtmp://rtmp.example.com:1935/live"), "expected to see custom host name") + }, + waitTime, + time.Second, + "expected to see custom host name", + ) + printScreen(t, getContents, "Ater opening the app with a custom host name") + + require.EventuallyWithT( + t, + func(c *assert.CollectT) { + conn, err := tls.Dial("tcp", "localhost:9997", &tls.Config{ + InsecureSkipVerify: true, + }) + require.NoError(c, err) + + require.Nil( + c, + conn. + ConnectionState(). + PeerCertificates[0]. + VerifyHostname("rtmp.example.com"), + "expected to verify custom host name", + ) + }, + waitTime, + time.Second, + "expected to connect to API using self-signed TLS cert with custom host name", + ) + + cancel() + + <-done +} + func TestIntegrationCustomTLSCerts(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), 10*time.Minute) defer cancel() @@ -335,7 +403,6 @@ func TestIntegrationCustomTLSCerts(t *testing.T) { block, _ := pem.Decode(certPEM) require.NotNil(c, block, "failed to decode PEM block containing certificate") require.True(c, block.Type == "CERTIFICATE", "expected PEM block to be a certificate") - certDERBytes := block.Bytes rootCAs := x509.NewCertPool() require.True(c, rootCAs.AppendCertsFromPEM(certPEM), "failed to append cert to root CA pool") @@ -347,12 +414,10 @@ func TestIntegrationCustomTLSCerts(t *testing.T) { }) require.NoError(c, err) - state := conn.ConnectionState() - peerCert := state.PeerCertificates[0] - - expectedCert, err := x509.ParseCertificate(certDERBytes) + peerCert := conn.ConnectionState().PeerCertificates[0] + wantCert, err := x509.ParseCertificate(block.Bytes) require.NoError(c, err) - require.True(c, peerCert.Equal(expectedCert), "expected peer certificate to match the expected certificate") + require.True(c, peerCert.Equal(wantCert), "expected peer certificate to match the expected certificate") }, waitTime, time.Second, @@ -365,52 +430,6 @@ func TestIntegrationCustomTLSCerts(t *testing.T) { <-done } -func TestIntegrationCustomRTMPURL(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), 10*time.Minute) - defer cancel() - - logger := testhelpers.NewTestLogger(t).With("component", "integration") - dockerClient, err := dockerclient.NewClientWithOpts(dockerclient.FromEnv, dockerclient.WithAPIVersionNegotiation()) - require.NoError(t, err) - - configService := setupConfigService(t, config.Config{ - Sources: config.Sources{ - MediaServer: config.MediaServerSource{ - Host: "rtmp.live.tv", - RTMP: config.RTMPSource{Enabled: true}, - }, - }, - }) - screen, screenCaptureC, getContents := setupSimulationScreen(t) - - done := make(chan struct{}) - go func() { - defer func() { - done <- struct{}{} - }() - - require.NoError(t, app.Run(ctx, buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger))) - }() - - time.Sleep(time.Second) - sendKey(t, screen, tcell.KeyF1, ' ') - - require.EventuallyWithT( - t, - func(t *assert.CollectT) { - assert.True(t, contentsIncludes(getContents(), "rtmp://rtmp.live.tv:1935/live"), "expected to see custom host name") - }, - waitTime, - time.Second, - "expected to see custom host name", - ) - printScreen(t, getContents, "Ater opening the app with a custom host name") - - cancel() - - <-done -} - func TestIntegrationRestartDestination(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), 10*time.Minute) defer cancel() diff --git a/internal/mediaserver/actor.go b/internal/mediaserver/actor.go index 3e2f3f5..7674298 100644 --- a/internal/mediaserver/actor.go +++ b/internal/mediaserver/actor.go @@ -98,7 +98,12 @@ type OptionalNetAddr struct { // // Callers must consume the state channel exposed via [C]. func NewActor(ctx context.Context, params NewActorParams) (_ *Actor, err error) { - keyPairInternal, err := generateTLSCert() + dnsNames := []string{"localhost"} + if params.Host != "" { + dnsNames = append(dnsNames, params.Host) + } + + keyPairInternal, err := generateTLSCert(dnsNames...) if err != nil { return nil, fmt.Errorf("generate TLS cert: %w", err) } diff --git a/internal/mediaserver/tls.go b/internal/mediaserver/tls.go index 194ce91..a0ad771 100644 --- a/internal/mediaserver/tls.go +++ b/internal/mediaserver/tls.go @@ -15,7 +15,7 @@ import ( ) // generateTLSCert generates a self-signed TLS certificate and private key. -func generateTLSCert() (domain.KeyPair, error) { +func generateTLSCert(dnsNames ...string) (domain.KeyPair, error) { privKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) if err != nil { return domain.KeyPair{}, err @@ -37,7 +37,7 @@ func generateTLSCert() (domain.KeyPair, error) { KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, BasicConstraintsValid: true, - DNSNames: []string{"localhost"}, + DNSNames: dnsNames, } certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privKey.PublicKey, privKey) diff --git a/internal/mediaserver/tls_test.go b/internal/mediaserver/tls_test.go index 3291800..0a6bef3 100644 --- a/internal/mediaserver/tls_test.go +++ b/internal/mediaserver/tls_test.go @@ -12,7 +12,7 @@ import ( ) func TestGenerateTLSCert(t *testing.T) { - keyPair, err := generateTLSCert() + keyPair, err := generateTLSCert("localhost", "rtmp.example.com") require.NoError(t, err) require.NotEmpty(t, keyPair.Cert) require.NotEmpty(t, keyPair.Key) @@ -33,6 +33,8 @@ func TestGenerateTLSCert(t *testing.T) { assert.True(t, cert.BasicConstraintsValid, "basic constraints should be valid") assert.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageServerAuth) assert.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageClientAuth) + assert.Contains(t, cert.DNSNames, "localhost", "DNS names should include localhost") + assert.Contains(t, cert.DNSNames, "rtmp.example.com", "DNS names should include rtmp.example.com") block, _ = pem.Decode(keyPair.Key) require.NotNil(t, block, "failed to decode private key PEM")