diff --git a/internal/app/integration_helpers_test.go b/internal/app/integration_helpers_test.go index b93cd8e..e4427e2 100644 --- a/internal/app/integration_helpers_test.go +++ b/internal/app/integration_helpers_test.go @@ -3,7 +3,10 @@ package app_test import ( + "encoding/json" "fmt" + "io" + "net/http" "os" "strings" "sync" @@ -14,6 +17,7 @@ import ( "git.netflux.io/rob/octoplex/internal/terminal" "github.com/gdamore/tcell/v2" "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" ) func setupSimulationScreen(t *testing.T) (tcell.SimulationScreen, chan<- terminal.ScreenCapture, func() []string) { @@ -128,3 +132,32 @@ func sendBackspaces(screen tcell.SimulationScreen, n int) { } time.Sleep(500 * time.Millisecond) } + +// kickFirstRTMPConn kicks the first RTMP connection from the mediaMTX server. +func kickFirstRTMPConn(t *testing.T, srv testcontainers.Container) { + type conn struct { + ID string `json:"id"` + } + + type apiResponse struct { + Items []conn `json:"items"` + } + + port, err := srv.MappedPort(t.Context(), "9997/tcp") + require.NoError(t, err) + + resp, err := http.Get(fmt.Sprintf("http://localhost:%d/v3/rtmpconns/list", port.Int())) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + var apiResp apiResponse + require.NoError(t, json.Unmarshal(respBody, &apiResp)) + require.True(t, len(apiResp.Items) > 0, "No RTMP connections found") + + resp, err = http.Post(fmt.Sprintf("http://localhost:%d/v3/rtmpconns/kick/%s", port.Int(), apiResp.Items[0].ID), "application/json", nil) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) +} diff --git a/internal/app/integration_test.go b/internal/app/integration_test.go index ddca547..ba7900e 100644 --- a/internal/app/integration_test.go +++ b/internal/app/integration_test.go @@ -21,7 +21,6 @@ import ( "git.netflux.io/rob/octoplex/internal/domain" "git.netflux.io/rob/octoplex/internal/terminal" "git.netflux.io/rob/octoplex/internal/testhelpers" - typescontainer "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/network" dockerclient "github.com/docker/docker/client" "github.com/docker/docker/errdefs" @@ -43,6 +42,12 @@ func TestIntegration(t *testing.T) { }) } +// hostIP is the IP address of the Docker host from within the container. +// +// This probably only works for Linux. +// https://stackoverflow.com/a/60740997/62871 +const hostIP = "172.17.0.1" + func testIntegration(t *testing.T, streamKey string) { ctx, cancel := context.WithTimeout(t.Context(), 10*time.Minute) defer cancel() @@ -68,18 +73,6 @@ func testIntegration(t *testing.T, streamKey string) { dockerClient, err := dockerclient.NewClientWithOpts(dockerclient.FromEnv, dockerclient.WithAPIVersionNegotiation()) require.NoError(t, err) - // List existing containers to debug Github Actions environment.
- containers, err := dockerClient.ContainerList(ctx, typescontainer.ListOptions{}) - require.NoError(t, err) - - if len(containers) == 0 { - logger.Info("No existing containers found") - } else { - for _, ctr := range containers { - logger.Info("Container", "id", ctr.ID, "name", ctr.Names, "image", ctr.Image, "started", ctr.Created, "labels", ctr.Labels) - } - } - screen, screenCaptureC, getContents := setupSimulationScreen(t) // https://stackoverflow.com/a/60740997/62871 @@ -266,7 +259,6 @@ func testIntegration(t *testing.T, streamKey string) { // TODO: // - Source error - // - Destination error // - Additional features (copy URL, etc.) cancel() @@ -274,6 +266,273 @@ func testIntegration(t *testing.T, streamKey string) { <-done } +func TestIntegrationRestartDestination(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Minute) + defer cancel() + + destServer, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: testcontainers.ContainerRequest{ + Image: "bluenviron/mediamtx:latest", + Env: map[string]string{"MTX_LOGLEVEL": "debug"}, + ExposedPorts: []string{"1936/tcp", "9997/tcp"}, + WaitingFor: wait.ForListeningPort("1936/tcp"), + }, + Started: false, + }) + testcontainers.CleanupContainer(t, destServer) + require.NoError(t, err) + + require.NoError(t, destServer.CopyFileToContainer(t.Context(), "testdata/mediamtx.yml", "/mediamtx.yml", 0600)) + require.NoError(t, destServer.Start(ctx)) + + destServerRTMPPort, err := destServer.MappedPort(ctx, "1936/tcp") + require.NoError(t, err) + + logger := testhelpers.NewTestLogger(t).With("component", "integration") + logger.Info("Initialised logger", "debug_level", logger.Enabled(ctx, slog.LevelDebug), "runner_debug", os.Getenv("RUNNER_DEBUG")) + dockerClient, err := dockerclient.NewClientWithOpts(dockerclient.FromEnv, dockerclient.WithAPIVersionNegotiation()) + require.NoError(t, err) + + screen, screenCaptureC, getContents := setupSimulationScreen(t) + + configService := setupConfigService(t, config.Config{ + Sources: config.Sources{RTMP: config.RTMPSource{Enabled: true}}, + Destinations: []config.Destination{{ + Name: "Local server 1", + URL: fmt.Sprintf("rtmp://%s:%d/live", hostIP, destServerRTMPPort.Int()), + }}, + }) + + done := make(chan struct{}) + go func() { + defer func() { + done <- struct{}{} + }() + + err := app.Run(ctx, app.RunParams{ + ConfigService: configService, + DockerClient: dockerClient, + Screen: &terminal.Screen{ + Screen: screen, + Width: 160, + Height: 25, + CaptureC: screenCaptureC, + }, + ClipboardAvailable: false, + BuildInfo: domain.BuildInfo{Version: "0.0.1", GoVersion: "go1.16.3"}, + Logger: logger, + }) + require.NoError(t, err) + }() + + require.EventuallyWithT( + t, + func(t *assert.CollectT) { + contents := getContents() + require.True(t, len(contents) > 2, "expected at least 3 lines of output") + + assert.Contains(t, contents[2], "Status waiting for stream", "expected mediaserver status to be waiting") + }, + 2*time.Minute, + time.Second, + "expected the mediaserver to start", + ) + printScreen(getContents, "After starting the mediaserver") + + // Start streaming a test video to the app: + testhelpers.StreamFLV(t, "rtmp://localhost:1935/live") + + require.EventuallyWithT( + t, + func(t *assert.CollectT) { + contents := getContents() + require.True(t, len(contents) > 3, "expected at least 3 lines of output") + + assert.Contains(t, contents[2], "Status receiving", "expected mediaserver status to be receiving") + }, + time.Minute, + 
time.Second, + "expected to receive an ingress stream", + ) + printScreen(getContents, "After receiving the ingress stream") + + // Start destination: + sendKey(screen, tcell.KeyRune, ' ') + + require.EventuallyWithT( + t, + func(t *assert.CollectT) { + contents := getContents() + require.True(t, len(contents) > 4, "expected at least 5 lines of output") + + assert.Contains(t, contents[2], "Status receiving", "expected mediaserver status to be receiving") + + require.Contains(t, contents[2], "Local server 1", "expected local server 1 to be present") + assert.Contains(t, contents[2], "sending", "expected local server 1 to be sending") + assert.Contains(t, contents[2], "healthy", "expected local server 1 to be healthy") + }, + 2*time.Minute, + time.Second, + "expected to start the destination stream", + ) + printScreen(getContents, "After starting the destination stream") + + // Wait for enough time that the container will be restarted. + // Then, kick the connection to force a restart. + time.Sleep(15 * time.Second) + kickFirstRTMPConn(t, destServer) + + require.EventuallyWithT( + t, + func(t *assert.CollectT) { + contents := getContents() + require.True(t, len(contents) > 3, "expected at least 3 lines of output") + + assert.Contains(t, contents[2], "Status receiving", "expected mediaserver status to be receiving") + + require.Contains(t, contents[2], "Local server 1", "expected local server 1 to be present") + assert.Contains(t, contents[2], "off-air", "expected local server 1 to be off-air") + assert.Contains(t, contents[2], "restarting", "expected local server 1 to be restarting") + }, + 20*time.Second, + time.Second, + "expected to begin restarting", + ) + printScreen(getContents, "After stopping the destination server") + + require.EventuallyWithT( + t, + func(t *assert.CollectT) { + contents := getContents() + require.True(t, len(contents) > 4, "expected at least 4 lines of output") + + assert.Contains(t, contents[2], "Status receiving", "expected mediaserver status to be receiving") + + require.Contains(t, contents[2], "Local server 1", "expected local server 1 to be present") + assert.Contains(t, contents[2], "sending", "expected local server 1 to be sending") + assert.Contains(t, contents[2], "healthy", "expected local server 1 to be healthy") + }, + 2*time.Minute, + time.Second, + "expected to restart the destination stream", + ) + printScreen(getContents, "After restarting the destination stream") + + // Stop destination. 
+ sendKey(screen, tcell.KeyRune, ' ') + + require.EventuallyWithT( + t, + func(t *assert.CollectT) { + contents := getContents() + require.True(t, len(contents) > 4, "expected at least 4 lines of output") + + require.Contains(t, contents[2], "Local server 1", "expected local server 1 to be present") + assert.Contains(t, contents[2], "exited", "expected local server 1 to have exited") + + require.NotContains(t, contents[3], "Local server 2", "expected local server 2 to not be present") + }, + time.Minute, + time.Second, + "expected to stop the destination stream", + ) + + printScreen(getContents, "After stopping the destination") + + cancel() + + <-done +} + +func TestIntegrationStartDestinationFailed(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Minute) + defer cancel() + + logger := testhelpers.NewTestLogger(t).With("component", "integration") + dockerClient, err := dockerclient.NewClientWithOpts(dockerclient.FromEnv, dockerclient.WithAPIVersionNegotiation()) + require.NoError(t, err) + + screen, screenCaptureC, getContents := setupSimulationScreen(t) + + configService := setupConfigService(t, config.Config{ + Sources: config.Sources{RTMP: config.RTMPSource{Enabled: true}}, + Destinations: []config.Destination{{Name: "Example server", URL: "rtmp://rtmp.example.com/live"}}, + }) + + done := make(chan struct{}) + go func() { + defer func() { + done <- struct{}{} + }() + + err := app.Run(ctx, app.RunParams{ + ConfigService: configService, + DockerClient: dockerClient, + Screen: &terminal.Screen{ + Screen: screen, + Width: 160, + Height: 25, + CaptureC: screenCaptureC, + }, + ClipboardAvailable: false, + BuildInfo: domain.BuildInfo{Version: "0.0.1", GoVersion: "go1.16.3"}, + Logger: logger, + }) + require.NoError(t, err) + }() + + require.EventuallyWithT( + t, + func(t *assert.CollectT) { + contents := getContents() + require.True(t, len(contents) > 2, "expected at least 3 lines of output") + + assert.Contains(t, contents[2], "Status waiting for stream", "expected mediaserver status to be waiting") + }, + 2*time.Minute, + time.Second, + "expected the mediaserver to start", + ) + printScreen(getContents, "After starting the mediaserver") + + // Start streaming a test video to the app: + testhelpers.StreamFLV(t, "rtmp://localhost:1935/live") + + require.EventuallyWithT( + t, + func(t *assert.CollectT) { + contents := getContents() + require.True(t, len(contents) > 3, "expected at least 3 lines of output") + + assert.Contains(t, contents[2], "Status receiving", "expected mediaserver status to be receiving") + }, + time.Minute, + time.Second, + "expected to receive an ingress stream", + ) + printScreen(getContents, "After receiving the ingress stream") + + // Start destination: + sendKey(screen, tcell.KeyRune, ' ') + + require.EventuallyWithT( + t, + func(t *assert.CollectT) { + contents := getContents() + assert.True(t, contentsIncludes(contents, "Streaming to Example server failed:"), "expected to see destination error") + assert.True(t, contentsIncludes(contents, "container failed to start"), "expected to see destination error") + }, + time.Minute, + time.Second, + "expected to see the destination start error modal", + ) + printScreen(getContents, "After starting the destination stream fails") + + cancel() + + <-done +} + func TestIntegrationDestinationValidations(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), 10*time.Minute) defer cancel() diff --git a/internal/app/testdata/mediamtx.yml b/internal/app/testdata/mediamtx.yml new file mode 100644 
index 0000000..f0f2386 --- /dev/null +++ b/internal/app/testdata/mediamtx.yml @@ -0,0 +1,12 @@ +rtmp: true +rtmpAddress: :1936 +api: true +authInternalUsers: +- user: any + ips: [] + permissions: + - action: api + - action: read + - action: publish +paths: + live: diff --git a/internal/container/container.go b/internal/container/container.go index 960a78d..fb2b255 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -21,7 +21,6 @@ import ( "github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/image" "github.com/docker/docker/api/types/network" - "github.com/docker/docker/errdefs" ocispec "github.com/opencontainers/image-spec/specs-go/v1" ) @@ -38,7 +37,6 @@ type DockerClient interface { io.Closer ContainerCreate(context.Context, *container.Config, *container.HostConfig, *network.NetworkingConfig, *ocispec.Platform, string) (container.CreateResponse, error) - ContainerInspect(context.Context, string) (container.InspectResponse, error) ContainerList(context.Context, container.ListOptions) ([]container.Summary, error) ContainerRemove(context.Context, string, container.RemoveOptions) error ContainerStart(context.Context, string, container.StartOptions) error @@ -72,6 +70,7 @@ type Client struct { wg sync.WaitGroup apiClient DockerClient networkID string + cancelFuncs map[string]context.CancelFunc pulledImages map[string]struct{} logger *slog.Logger } @@ -99,6 +98,7 @@ func NewClient(ctx context.Context, apiClient DockerClient, logger *slog.Logger) cancel: cancel, apiClient: apiClient, networkID: network.ID, + cancelFuncs: make(map[string]context.CancelFunc), pulledImages: make(map[string]struct{}), logger: logger, } @@ -147,6 +147,15 @@ type CopyFileConfig struct { Mode int64 } +// ShouldRestartFunc is a callback function that is called when a container +// exits. It should return true if the container is to be restarted. If not +// restarting, err may be non-nil. +type ShouldRestartFunc func(exitCode int64, restartCount int, runningTime time.Duration) (bool, error) + +// defaultRestartInterval is the default interval between restarts. +// TODO: exponential backoff +const defaultRestartInterval = 10 * time.Second + // RunContainerParams are the parameters for running a container. type RunContainerParams struct { Name string @@ -156,6 +165,8 @@ type RunContainerParams struct { NetworkingConfig *network.NetworkingConfig NetworkCountConfig NetworkCountConfig CopyFileConfigs []CopyFileConfig + ShouldRestart ShouldRestartFunc + RestartInterval time.Duration // defaults to 10 seconds } // RunContainer runs a container with the given parameters. @@ -164,13 +175,18 @@ type RunContainerParams struct { // never be closed. The error channel will receive an error if the container // fails to start, and will be closed when the container exits, possibly after // receiving an error. +// +// Panics if ShouldRestart is non-nil and the host config defines a restart +// policy of its own. 
func (a *Client) RunContainer(ctx context.Context, params RunContainerParams) (<-chan domain.Container, <-chan error) { + if params.ShouldRestart != nil && !params.HostConfig.RestartPolicy.IsNone() { + panic("shouldRestart and restart policy are mutually exclusive") + } + now := time.Now() containerStateC := make(chan domain.Container, cmp.Or(params.ChanSize, defaultChanSize)) errC := make(chan error, 1) - sendError := func(err error) { - errC <- err - } + sendError := func(err error) { errC <- err } a.wg.Add(1) go func() { @@ -224,11 +240,20 @@ func (a *Client) RunContainer(ctx context.Context, params RunContainerParams) (< containerStateC <- domain.Container{ID: createResp.ID, Status: domain.ContainerStatusRunning} + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + a.mu.Lock() + a.cancelFuncs[createResp.ID] = cancel + a.mu.Unlock() + a.runContainerLoop( ctx, + cancel, createResp.ID, params.ContainerConfig.Image, params.NetworkCountConfig, + params.ShouldRestart, + cmp.Or(params.RestartInterval, defaultRestartInterval), containerStateC, errC, ) @@ -311,15 +336,23 @@ func (a *Client) pullImageIfNeeded(ctx context.Context, imageName string, contai // when the container exits. func (a *Client) runContainerLoop( ctx context.Context, + cancel context.CancelFunc, containerID string, imageName string, networkCountConfig NetworkCountConfig, + shouldRestartFunc ShouldRestartFunc, + restartInterval time.Duration, stateC chan<- domain.Container, errC chan<- error, ) { + defer cancel() + type containerWaitResponse struct { container.WaitResponse - restarting bool + + restarting bool + restartCount int + err error } containerRespC := make(chan containerWaitResponse) @@ -333,36 +366,63 @@ func (a *Client) runContainerLoop( // The goroutine exits when a value is received on the error channel, or when // the container exits and is not restarting, or when the context is cancelled. go func() { + timer := time.NewTimer(restartInterval) + defer timer.Stop() + timer.Stop() + + var restartCount int + for { + startedWaitingAt := time.Now() respC, errC := a.apiClient.ContainerWait(ctx, containerID, container.WaitConditionNextExit) select { case resp := <-respC: - var restarting bool - // Check if the container is restarting. If it is not then we don't - // want to wait for it again and can return early. - ctr, err := a.apiClient.ContainerInspect(ctx, containerID) - // Race condition: the container may already have been removed. - if errdefs.IsNotFound(err) { - // ignore error but do not restart - } else if err != nil { - a.logger.Error("Error inspecting container", "err", err, "id", shortID(containerID)) - containerErrC <- err - return - // Race condition: the container may have already restarted. 
- } else if ctr.State.Status == domain.ContainerStatusRestarting || ctr.State.Status == domain.ContainerStatusRunning { - restarting = true + exit := func(err error) { + a.logger.Info("Container exited", "id", shortID(containerID), "should_restart", "false", "exit_code", resp.StatusCode, "restart_count", restartCount) + containerRespC <- containerWaitResponse{ + WaitResponse: resp, + restarting: false, + restartCount: restartCount, + err: err, + } } - containerRespC <- containerWaitResponse{WaitResponse: resp, restarting: restarting} - if !restarting { + if shouldRestartFunc == nil { + exit(nil) return } + + shouldRestart, err := shouldRestartFunc(resp.StatusCode, restartCount, time.Since(startedWaitingAt)) + if shouldRestart && err != nil { + panic(fmt.Errorf("shouldRestart must return nil error if restarting, but returned: %w", err)) + } + if !shouldRestart { + exit(err) + return + } + + a.logger.Info("Container exited", "id", shortID(containerID), "should_restart", "true", "exit_code", resp.StatusCode, "restart_count", restartCount) + timer.Reset(restartInterval) + + containerRespC <- containerWaitResponse{ + WaitResponse: resp, + restarting: true, + restartCount: restartCount, + } + case <-timer.C: + a.logger.Info("Container restarting", "id", shortID(containerID), "restart_count", restartCount) + restartCount++ + if err := a.apiClient.ContainerStart(ctx, containerID, container.StartOptions{}); err != nil { + containerErrC <- fmt.Errorf("container start: %w", err) + return + } + a.logger.Info("Restarted container", "id", shortID(containerID)) case err := <-errC: - // Otherwise, this is probably unexpected and we need to handle it. containerErrC <- err return case <-ctx.Done(): - containerErrC <- ctx.Err() + // This is probably because the container was stopped. + containerRespC <- containerWaitResponse{WaitResponse: container.WaitResponse{}, restarting: false} return } } @@ -382,20 +442,23 @@ func (a *Client) runContainerLoop( a.logger.Info("Container entered non-running state", "exit_code", resp.StatusCode, "id", shortID(containerID), "restarting", resp.restarting) var containerState string + var containerErr error if resp.restarting { containerState = domain.ContainerStatusRestarting } else { containerState = domain.ContainerStatusExited + containerErr = resp.err } state.Status = containerState + state.Err = containerErr + state.RestartCount = resp.restartCount state.CPUPercent = 0 state.MemoryUsageBytes = 0 state.HealthState = "unhealthy" state.RxRate = 0 state.TxRate = 0 state.RxSince = time.Time{} - state.RestartCount++ if !resp.restarting { exitCode := int(resp.StatusCode) @@ -406,7 +469,7 @@ func (a *Client) runContainerLoop( sendState() case err := <-containerErrC: - // TODO: error handling? + // TODO: verify error handling if err != context.Canceled { a.logger.Error("Error setting container wait", "err", err, "id", shortID(containerID)) } @@ -479,6 +542,24 @@ func (a *Client) Close() error { } func (a *Client) removeContainer(ctx context.Context, id string) error { + a.mu.Lock() + cancel, ok := a.cancelFuncs[id] + if ok { + delete(a.cancelFuncs, id) + } + a.mu.Unlock() + + if ok { + cancel() + } else { + // It is attempted to keep track of cancel functions for each container, + // which allow clean cancellation of container restart logic during + // removal. But there are legitimate occasions where the cancel function + // would not exist (e.g. during startup check) and in general the state of + // the Docker engine is preferred to local state in this package. 
+ a.logger.Debug("removeContainer: cancelFunc not found", "id", shortID(id)) + } + a.logger.Info("Stopping container", "id", shortID(id)) stopTimeout := int(stopTimeout.Seconds()) if err := a.apiClient.ContainerStop(ctx, id, container.StopOptions{Timeout: &stopTimeout}); err != nil { diff --git a/internal/container/container_test.go b/internal/container/container_test.go index c8fe53c..3e54622 100644 --- a/internal/container/container_test.go +++ b/internal/container/container_test.go @@ -44,7 +44,7 @@ func TestClientRunContainer(t *testing.T) { dockerClient. EXPECT(). ImagePull(mock.Anything, "alpine", image.PullOptions{}). - Return(io.NopCloser(bytes.NewReader(nil)), errors.New("error pulling image should not be fatal")) + Return(nil, errors.New("error pulling image should not be fatal")) dockerClient. EXPECT(). ContainerCreate(mock.Anything, mock.Anything, mock.Anything, mock.Anything, (*ocispec.Platform)(nil), mock.Anything). @@ -69,10 +69,6 @@ func TestClientRunContainer(t *testing.T) { EXPECT(). ContainerWait(mock.Anything, "123", dockercontainer.WaitConditionNextExit). Return(containerWaitC, containerErrC) - dockerClient. - EXPECT(). - ContainerInspect(mock.Anything, "123"). - Return(dockercontainer.InspectResponse{ContainerJSONBase: &dockercontainer.ContainerJSONBase{State: &dockercontainer.State{Status: "exited"}}}, nil) dockerClient. EXPECT(). Events(mock.Anything, events.ListOptions{Filters: filters.NewArgs(filters.Arg("container", "123"), filters.Arg("type", "container"))}). @@ -122,7 +118,120 @@ func TestClientRunContainer(t *testing.T) { assert.Equal(t, "unhealthy", state.HealthState) require.NotNil(t, state.ExitCode) assert.Equal(t, 1, *state.ExitCode) + assert.Equal(t, 0, state.RestartCount) + + <-done +} + +func TestClientRunContainerWithRestart(t *testing.T) { + logger := testhelpers.NewTestLogger(t) + + // channels returned by Docker's ContainerWait: + containerWaitC := make(chan dockercontainer.WaitResponse) + containerErrC := make(chan error) + + // channels returned by Docker's Events: + eventsC := make(chan events.Message) + eventsErrC := make(chan error) + + var dockerClient mocks.DockerClient + defer dockerClient.AssertExpectations(t) + + dockerClient. + EXPECT(). + NetworkCreate(mock.Anything, mock.Anything, mock.MatchedBy(func(opts network.CreateOptions) bool { + return opts.Driver == "bridge" && len(opts.Labels) > 0 + })). + Return(network.CreateResponse{ID: "test-network"}, nil) + dockerClient. + EXPECT(). + ImagePull(mock.Anything, "alpine", image.PullOptions{}). + Return(io.NopCloser(bytes.NewReader(nil)), nil) + dockerClient. + EXPECT(). + ContainerCreate(mock.Anything, mock.Anything, mock.Anything, mock.Anything, (*ocispec.Platform)(nil), mock.Anything). + Return(dockercontainer.CreateResponse{ID: "123"}, nil) + dockerClient. + EXPECT(). + NetworkConnect(mock.Anything, "test-network", "123", (*network.EndpointSettings)(nil)). + Return(nil) + dockerClient. + EXPECT(). + ContainerStart(mock.Anything, "123", dockercontainer.StartOptions{}). + Once(). + Return(nil) + dockerClient. + EXPECT(). + ContainerStats(mock.Anything, "123", true). + Return(dockercontainer.StatsResponseReader{Body: io.NopCloser(bytes.NewReader(nil))}, nil) + dockerClient. + EXPECT(). + ContainerWait(mock.Anything, "123", dockercontainer.WaitConditionNextExit). + Return(containerWaitC, containerErrC) + dockerClient. + EXPECT(). + Events(mock.Anything, events.ListOptions{Filters: filters.NewArgs(filters.Arg("container", "123"), filters.Arg("type", "container"))}). 
+ Return(eventsC, eventsErrC) + dockerClient. + EXPECT(). + ContainerStart(mock.Anything, "123", dockercontainer.StartOptions{}). // restart + Return(nil) + + containerClient, err := container.NewClient(t.Context(), &dockerClient, logger) + require.NoError(t, err) + + containerStateC, errC := containerClient.RunContainer(t.Context(), container.RunContainerParams{ + Name: "test-run-container", + ChanSize: 1, + ContainerConfig: &dockercontainer.Config{Image: "alpine"}, + HostConfig: &dockercontainer.HostConfig{}, + ShouldRestart: func(_ int64, restartCount int, _ time.Duration) (bool, error) { + if restartCount == 0 { + return true, nil + } + + return false, errors.New("max restarts reached") + }, + RestartInterval: 10 * time.Millisecond, + }) + + done := make(chan struct{}) + go func() { + defer close(done) + + require.NoError(t, <-errC) + }() + + assert.Equal(t, "pulling", (<-containerStateC).Status) + assert.Equal(t, "created", (<-containerStateC).Status) + assert.Equal(t, "running", (<-containerStateC).Status) + assert.Equal(t, "running", (<-containerStateC).Status) + + // Enough time for the restart to occur: + time.Sleep(100 * time.Millisecond) + + containerWaitC <- dockercontainer.WaitResponse{StatusCode: 1} + + state := <-containerStateC + assert.Equal(t, "restarting", state.Status) + assert.Equal(t, "unhealthy", state.HealthState) + assert.Nil(t, state.ExitCode) + assert.Zero(t, state.RestartCount) // not incremented until the actual restart + + // During the restart, the "running" status is triggered by Docker events + // only. So we don't expect one in unit tests. (Probably the initial startup + // flow should behave the same.) + + time.Sleep(100 * time.Millisecond) + containerWaitC <- dockercontainer.WaitResponse{StatusCode: 1} + + state = <-containerStateC + assert.Equal(t, "exited", state.Status) + assert.Equal(t, "unhealthy", state.HealthState) + require.NotNil(t, state.ExitCode) + assert.Equal(t, 1, *state.ExitCode) assert.Equal(t, 1, state.RestartCount) + assert.Equal(t, "max restarts reached", state.Err.Error()) <-done } diff --git a/internal/container/mocks/dockerclient_mock.go b/internal/container/mocks/dockerclient_mock.go index 454394f..6ecc419 100644 --- a/internal/container/mocks/dockerclient_mock.go +++ b/internal/container/mocks/dockerclient_mock.go @@ -138,63 +138,6 @@ func (_c *DockerClient_ContainerCreate_Call) RunAndReturn(run func(context.Conte return _c } -// ContainerInspect provides a mock function with given fields: _a0, _a1 -func (_m *DockerClient) ContainerInspect(_a0 context.Context, _a1 string) (typescontainer.InspectResponse, error) { - ret := _m.Called(_a0, _a1) - - if len(ret) == 0 { - panic("no return value specified for ContainerInspect") - } - - var r0 typescontainer.InspectResponse - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (typescontainer.InspectResponse, error)); ok { - return rf(_a0, _a1) - } - if rf, ok := ret.Get(0).(func(context.Context, string) typescontainer.InspectResponse); ok { - r0 = rf(_a0, _a1) - } else { - r0 = ret.Get(0).(typescontainer.InspectResponse) - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(_a0, _a1) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// DockerClient_ContainerInspect_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ContainerInspect' -type DockerClient_ContainerInspect_Call struct { - *mock.Call -} - -// ContainerInspect is a helper method to define mock.On call -// - _a0 context.Context 
-// - _a1 string -func (_e *DockerClient_Expecter) ContainerInspect(_a0 interface{}, _a1 interface{}) *DockerClient_ContainerInspect_Call { - return &DockerClient_ContainerInspect_Call{Call: _e.mock.On("ContainerInspect", _a0, _a1)} -} - -func (_c *DockerClient_ContainerInspect_Call) Run(run func(_a0 context.Context, _a1 string)) *DockerClient_ContainerInspect_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string)) - }) - return _c -} - -func (_c *DockerClient_ContainerInspect_Call) Return(_a0 typescontainer.InspectResponse, _a1 error) *DockerClient_ContainerInspect_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *DockerClient_ContainerInspect_Call) RunAndReturn(run func(context.Context, string) (typescontainer.InspectResponse, error)) *DockerClient_ContainerInspect_Call { - _c.Call.Return(run) - return _c -} - // ContainerList provides a mock function with given fields: _a0, _a1 func (_m *DockerClient) ContainerList(_a0 context.Context, _a1 typescontainer.ListOptions) ([]typescontainer.Summary, error) { ret := _m.Called(_a0, _a1) diff --git a/internal/container/pull.json b/internal/container/pull.json deleted file mode 100644 index e69de29..0000000 diff --git a/internal/replicator/replicator.go b/internal/replicator/replicator.go index 4904a92..d3ca087 100644 --- a/internal/replicator/replicator.go +++ b/internal/replicator/replicator.go @@ -3,6 +3,7 @@ package replicator import ( "cmp" "context" + "errors" "fmt" "log/slog" "strconv" @@ -105,11 +106,20 @@ func (a *Actor) StartDestination(url string) { container.LabelURL: url, }, }, - HostConfig: &typescontainer.HostConfig{ - NetworkMode: "default", - RestartPolicy: typescontainer.RestartPolicy{Name: "always"}, - }, + HostConfig: &typescontainer.HostConfig{NetworkMode: "default"}, NetworkCountConfig: container.NetworkCountConfig{Rx: "eth1", Tx: "eth0"}, + ShouldRestart: func(_ int64, restartCount int, runningTime time.Duration) (bool, error) { + // Try to infer if the container failed to start. + // + // TODO: this is a bit hacky, we should check the container logs and + // include some details in the error message. + if restartCount == 0 && runningTime < 10*time.Second { + return false, errors.New("container failed to start") + } + + // Otherwise, always restart, regardless of the exit code. + return true, nil + }, }) a.wg.Add(1) diff --git a/internal/terminal/terminal.go b/internal/terminal/terminal.go index b1abca5..cec9ad3 100644 --- a/internal/terminal/terminal.go +++ b/internal/terminal/terminal.go @@ -366,8 +366,6 @@ func (ui *UI) ShowStartupCheckModal() bool { } func (ui *UI) ShowDestinationErrorModal(name string, err error) { - done := make(chan struct{}) - ui.app.QueueUpdateDraw(func() { ui.showModal( pageNameModalDestinationError, @@ -377,13 +375,9 @@ func (ui *UI) ShowDestinationErrorModal(name string, err error) { err, ), []string{"Ok"}, - func(int, string) { - done <- struct{}{} - }, + nil, ) }) - - <-done } // ShowFatalErrorModal displays the provided error. It sends a CommandQuit to the
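
Note: below is a minimal, hypothetical sketch (not part of the change) of how a caller might wire up the `ShouldRestart` and `RestartInterval` fields added to `RunContainerParams`. It follows the contract stated in the new doc comments and the replicator usage above: return `(true, nil)` to restart after the interval, return `(false, err)` to stop and surface `err` on the exited container state, never return `true` together with a non-nil error, and leave the Docker restart policy unset (both misuses panic). The container name, image and three-restart limit are illustrative assumptions.

```go
package example

import (
	"context"
	"fmt"
	"time"

	"git.netflux.io/rob/octoplex/internal/container"
	"git.netflux.io/rob/octoplex/internal/domain"
	typescontainer "github.com/docker/docker/api/types/container"
)

// runWithRestarts is an illustrative helper (not part of this change) showing
// how RunContainerParams.ShouldRestart and RestartInterval might be used.
func runWithRestarts(ctx context.Context, client *container.Client) (<-chan domain.Container, <-chan error) {
	return client.RunContainer(ctx, container.RunContainerParams{
		Name:            "example-container", // hypothetical name
		ContainerConfig: &typescontainer.Config{Image: "alpine"},
		// Leave HostConfig.RestartPolicy unset: RunContainer panics if a Docker
		// restart policy is combined with ShouldRestart.
		HostConfig: &typescontainer.HostConfig{NetworkMode: "default"},
		ShouldRestart: func(exitCode int64, restartCount int, runningTime time.Duration) (bool, error) {
			// Give up after three attempts; the error is surfaced on the exited
			// container state rather than on the error channel.
			if restartCount >= 3 {
				return false, fmt.Errorf("giving up after %d restarts (last exit code %d, ran for %s)", restartCount, exitCode, runningTime)
			}
			return true, nil // restart after RestartInterval
		},
		RestartInterval: 5 * time.Second, // zero falls back to the 10-second default
	})
}
```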