diff --git a/internal/container/container.go b/internal/container/container.go
index 0e240a7..1053db7 100644
--- a/internal/container/container.go
+++ b/internal/container/container.go
@@ -152,6 +152,10 @@ type CopyFileConfig struct {
 // restarting, err can be non-nil.
 type ShouldRestartFunc func(exitCode int64, restartCount int, runningTime time.Duration) (bool, error)
 
+// defaultRestartInterval is the default interval between restarts.
+// TODO: exponential backoff
+const defaultRestartInterval = 10 * time.Second
+
 // RunContainerParams are the parameters for running a container.
 type RunContainerParams struct {
 	Name               string
@@ -162,6 +166,7 @@ type RunContainerParams struct {
 	NetworkCountConfig NetworkCountConfig
 	CopyFileConfigs    []CopyFileConfig
 	ShouldRestart      ShouldRestartFunc
+	RestartInterval    time.Duration // defaults to 10 seconds
 }
 
 // RunContainer runs a container with the given parameters.
@@ -248,6 +253,7 @@ func (a *Client) RunContainer(ctx context.Context, params RunContainerParams) (<
 			params.ContainerConfig.Image,
 			params.NetworkCountConfig,
 			params.ShouldRestart,
+			cmp.Or(params.RestartInterval, defaultRestartInterval),
 			containerStateC,
 			errC,
 		)
@@ -335,6 +341,7 @@ func (a *Client) runContainerLoop(
 	imageName string,
 	networkCountConfig NetworkCountConfig,
 	shouldRestartFunc ShouldRestartFunc,
+	restartInterval time.Duration,
 	stateC chan<- domain.Container,
 	errC chan<- error,
 ) {
@@ -359,10 +366,10 @@ func (a *Client) runContainerLoop(
 	// The goroutine exits when a value is received on the error channel, or when
 	// the container exits and is not restarting, or when the context is cancelled.
 	go func() {
-		const restartDuration = 5 * time.Second
-		timer := time.NewTimer(restartDuration)
+		timer := time.NewTimer(restartInterval)
 		defer timer.Stop()
 		timer.Stop()
+
 		var restartCount int
 
 		for {
@@ -377,28 +384,39 @@ func (a *Client) runContainerLoop(
 				}
 				if !shouldRestart {
 					a.logger.Info("Container exited", "id", shortID(containerID), "should_restart", "false", "exit_code", resp.StatusCode, "restart_count", restartCount)
-					containerRespC <- containerWaitResponse{WaitResponse: resp, restarting: false, err: err}
+					containerRespC <- containerWaitResponse{
+						WaitResponse: resp,
+						restarting:   false,
+						restartCount: restartCount,
+						err:          err,
+					}
 					return
 				}
 			}
 
-			restartCount++
-
-			// TODO: exponential backoff
 			a.logger.Info("Container exited", "id", shortID(containerID), "should_restart", "true", "exit_code", resp.StatusCode, "restart_count", restartCount)
-			timer.Reset(restartDuration)
+			timer.Reset(restartInterval)
 
-			containerRespC <- containerWaitResponse{WaitResponse: resp, restarting: true}
+			containerRespC <- containerWaitResponse{
+				WaitResponse: resp,
+				restarting:   true,
+				restartCount: restartCount,
+			}
 		case <-timer.C:
+			fmt.Println("ContainerWait timer expired")
 			a.logger.Info("Container restarting", "id", shortID(containerID), "restart_count", restartCount)
+			restartCount++
 			if err := a.apiClient.ContainerStart(ctx, containerID, container.StartOptions{}); err != nil {
 				containerErrC <- fmt.Errorf("container start: %w", err)
 				return
 			}
+			a.logger.Info("Restarted container", "id", shortID(containerID))
 		case err := <-errC:
+			fmt.Println("ContainerWait error", err)
 			containerErrC <- err
 			return
 		case <-ctx.Done():
+			fmt.Println("ContainerWait cancelled")
 			// This is probably because the container was stopped.
 			containerRespC <- containerWaitResponse{WaitResponse: container.WaitResponse{}, restarting: false}
 			return
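Note on usage, not part of the diff itself: ShouldRestart decides whether a given exit should trigger a restart, and RestartInterval (falling back to defaultRestartInterval) sets the delay before ContainerStart is retried. Below is a minimal, self-contained sketch of a caller-side policy matching the ShouldRestartFunc signature; the helper name restartUpToThree and the cap of three restarts are illustrative assumptions, not code from this change.

package main

import (
	"errors"
	"fmt"
	"time"
)

// restartUpToThree has the same shape as ShouldRestartFunc: restart on any
// non-zero exit, at most three times. The name and the cap are illustrative only.
func restartUpToThree(exitCode int64, restartCount int, runningTime time.Duration) (bool, error) {
	if exitCode == 0 {
		return false, nil // clean exit: no restart needed
	}
	if restartCount >= 3 {
		return false, errors.New("max restarts reached")
	}
	return true, nil
}

func main() {
	// Exercise the policy the way the restart loop would, once per container exit.
	for count := 0; count <= 3; count++ {
		restart, err := restartUpToThree(1, count, 30*time.Second)
		fmt.Printf("restartCount=%d restart=%v err=%v\n", count, restart, err)
	}
}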
diff --git a/internal/container/container_test.go b/internal/container/container_test.go
index 7248b2a..a050002 100644
--- a/internal/container/container_test.go
+++ b/internal/container/container_test.go
@@ -3,6 +3,7 @@ package container_test
 import (
 	"bytes"
 	"errors"
+	"fmt"
 	"io"
 	"testing"
 	"time"
@@ -126,6 +127,119 @@ func TestClientRunContainer(t *testing.T) {
 	<-done
 }
 
+func TestClientRunContainerWithRestart(t *testing.T) {
+	logger := testhelpers.NewTestLogger(t)
+
+	// channels returned by Docker's ContainerWait:
+	containerWaitC := make(chan dockercontainer.WaitResponse)
+	containerErrC := make(chan error)
+
+	// channels returned by Docker's Events:
+	eventsC := make(chan events.Message)
+	eventsErrC := make(chan error)
+
+	var dockerClient mocks.DockerClient
+	defer dockerClient.AssertExpectations(t)
+
+	dockerClient.
+		EXPECT().
+		NetworkCreate(mock.Anything, mock.Anything, mock.MatchedBy(func(opts network.CreateOptions) bool {
+			return opts.Driver == "bridge" && len(opts.Labels) > 0
+		})).
+		Return(network.CreateResponse{ID: "test-network"}, nil)
+	dockerClient.
+		EXPECT().
+		ImagePull(mock.Anything, "alpine", image.PullOptions{}).
+		Return(io.NopCloser(bytes.NewReader(nil)), nil)
+	dockerClient.
+		EXPECT().
+		ContainerCreate(mock.Anything, mock.Anything, mock.Anything, mock.Anything, (*ocispec.Platform)(nil), mock.Anything).
+		Return(dockercontainer.CreateResponse{ID: "123"}, nil)
+	dockerClient.
+		EXPECT().
+		NetworkConnect(mock.Anything, "test-network", "123", (*network.EndpointSettings)(nil)).
+		Return(nil)
+	dockerClient.
+		EXPECT().
+		ContainerStart(mock.Anything, "123", dockercontainer.StartOptions{}).
+		Once().
+		Return(nil)
+	dockerClient.
+		EXPECT().
+		ContainerStats(mock.Anything, "123", true).
+		Return(dockercontainer.StatsResponseReader{Body: io.NopCloser(bytes.NewReader(nil))}, nil)
+	dockerClient.
+		EXPECT().
+		ContainerWait(mock.Anything, "123", dockercontainer.WaitConditionNextExit).
+		Return(containerWaitC, containerErrC)
+	dockerClient.
+		EXPECT().
+		Events(mock.Anything, events.ListOptions{Filters: filters.NewArgs(filters.Arg("container", "123"), filters.Arg("type", "container"))}).
+		Return(eventsC, eventsErrC)
+	dockerClient.
+		EXPECT().
+		ContainerStart(mock.Anything, "123", dockercontainer.StartOptions{}). // restart
+		Return(nil)
+
+	containerClient, err := container.NewClient(t.Context(), &dockerClient, logger)
+	require.NoError(t, err)
+
+	containerStateC, errC := containerClient.RunContainer(t.Context(), container.RunContainerParams{
+		Name:            "test-run-container",
+		ChanSize:        1,
+		ContainerConfig: &dockercontainer.Config{Image: "alpine"},
+		HostConfig:      &dockercontainer.HostConfig{},
+		ShouldRestart: func(_ int64, restartCount int, _ time.Duration) (bool, error) {
+			fmt.Println("ShouldRestart", restartCount)
+			if restartCount == 0 {
+				return true, nil
+			}
+
+			return false, errors.New("max restarts reached")
+		},
+		RestartInterval: 10 * time.Millisecond,
+	})
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+
+		require.NoError(t, <-errC)
+	}()
+
+	assert.Equal(t, "pulling", (<-containerStateC).Status)
+	assert.Equal(t, "created", (<-containerStateC).Status)
+	assert.Equal(t, "running", (<-containerStateC).Status)
+	assert.Equal(t, "running", (<-containerStateC).Status)
+
+	// Enough time for the restart to occur:
+	time.Sleep(100 * time.Millisecond)
+
+	containerWaitC <- dockercontainer.WaitResponse{StatusCode: 1}
+
+	state := <-containerStateC
+	assert.Equal(t, "restarting", state.Status)
+	assert.Equal(t, "unhealthy", state.HealthState)
+	assert.Nil(t, state.ExitCode)
+	assert.Equal(t, 0, state.RestartCount) // not incremented until the actual restart
+
+	// During the restart, the "running" status is triggered by Docker events
+	// only. So we don't expect one in unit tests. (Probably the initial startup
+	// flow should behave the same.)
+
+	time.Sleep(100 * time.Millisecond)
+	containerWaitC <- dockercontainer.WaitResponse{StatusCode: 1}
+
+	state = <-containerStateC
+	assert.Equal(t, "exited", state.Status)
+	assert.Equal(t, "unhealthy", state.HealthState)
+	require.NotNil(t, state.ExitCode)
+	assert.Equal(t, 1, *state.ExitCode)
+	assert.Equal(t, 1, state.RestartCount)
+
+	<-done
+}
+
 func TestClientRunContainerErrorStartingContainer(t *testing.T) {
 	logger := testhelpers.NewTestLogger(t)
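The TODO in container.go defers exponential backoff and keeps a fixed restartInterval. As a sketch only of how that interval could later grow per restart: backoffInterval, the doubling factor, and the two-minute cap are assumptions, not part of this change.

package main

import (
	"fmt"
	"time"
)

// backoffInterval doubles the base interval for each prior restart, capped at
// max. Illustrative only; the change above still waits a fixed interval.
func backoffInterval(base, max time.Duration, restartCount int) time.Duration {
	interval := base
	for i := 0; i < restartCount; i++ {
		interval *= 2
		if interval >= max {
			return max
		}
	}
	return interval
}

func main() {
	for count := 0; count < 6; count++ {
		fmt.Printf("restart %d: wait %s\n", count, backoffInterval(10*time.Second, 2*time.Minute, count))
	}
}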