fixup! refactor(container): restart handling

This commit is contained in:
Rob Watson 2025-04-13 05:55:53 +02:00
parent b1713b22b2
commit 10ca605db0
2 changed files with 140 additions and 8 deletions

View File

@ -152,6 +152,10 @@ type CopyFileConfig struct {
// restarting, err can be non-nil. // restarting, err can be non-nil.
type ShouldRestartFunc func(exitCode int64, restartCount int, runningTime time.Duration) (bool, error) type ShouldRestartFunc func(exitCode int64, restartCount int, runningTime time.Duration) (bool, error)
// defaultRestartInterval is the default interval between restarts.
// TODO: exponential backoff
const defaultRestartInterval = 10 * time.Second
// RunContainerParams are the parameters for running a container. // RunContainerParams are the parameters for running a container.
type RunContainerParams struct { type RunContainerParams struct {
Name string Name string
@ -162,6 +166,7 @@ type RunContainerParams struct {
NetworkCountConfig NetworkCountConfig NetworkCountConfig NetworkCountConfig
CopyFileConfigs []CopyFileConfig CopyFileConfigs []CopyFileConfig
ShouldRestart ShouldRestartFunc ShouldRestart ShouldRestartFunc
RestartInterval time.Duration // defaults to 10 seconds
} }
// RunContainer runs a container with the given parameters. // RunContainer runs a container with the given parameters.
@ -248,6 +253,7 @@ func (a *Client) RunContainer(ctx context.Context, params RunContainerParams) (<
params.ContainerConfig.Image, params.ContainerConfig.Image,
params.NetworkCountConfig, params.NetworkCountConfig,
params.ShouldRestart, params.ShouldRestart,
cmp.Or(params.RestartInterval, defaultRestartInterval),
containerStateC, containerStateC,
errC, errC,
) )
@ -335,6 +341,7 @@ func (a *Client) runContainerLoop(
imageName string, imageName string,
networkCountConfig NetworkCountConfig, networkCountConfig NetworkCountConfig,
shouldRestartFunc ShouldRestartFunc, shouldRestartFunc ShouldRestartFunc,
restartInterval time.Duration,
stateC chan<- domain.Container, stateC chan<- domain.Container,
errC chan<- error, errC chan<- error,
) { ) {
@ -359,10 +366,10 @@ func (a *Client) runContainerLoop(
// The goroutine exits when a value is received on the error channel, or when // The goroutine exits when a value is received on the error channel, or when
// the container exits and is not restarting, or when the context is cancelled. // the container exits and is not restarting, or when the context is cancelled.
go func() { go func() {
const restartDuration = 5 * time.Second timer := time.NewTimer(restartInterval)
timer := time.NewTimer(restartDuration)
defer timer.Stop() defer timer.Stop()
timer.Stop() timer.Stop()
var restartCount int var restartCount int
for { for {
@ -377,28 +384,39 @@ func (a *Client) runContainerLoop(
} }
if !shouldRestart { if !shouldRestart {
a.logger.Info("Container exited", "id", shortID(containerID), "should_restart", "false", "exit_code", resp.StatusCode, "restart_count", restartCount) a.logger.Info("Container exited", "id", shortID(containerID), "should_restart", "false", "exit_code", resp.StatusCode, "restart_count", restartCount)
containerRespC <- containerWaitResponse{WaitResponse: resp, restarting: false, err: err} containerRespC <- containerWaitResponse{
WaitResponse: resp,
restarting: false,
restartCount: restartCount,
err: err,
}
return return
} }
} }
restartCount++
// TODO: exponential backoff
a.logger.Info("Container exited", "id", shortID(containerID), "should_restart", "true", "exit_code", resp.StatusCode, "restart_count", restartCount) a.logger.Info("Container exited", "id", shortID(containerID), "should_restart", "true", "exit_code", resp.StatusCode, "restart_count", restartCount)
timer.Reset(restartDuration) timer.Reset(restartInterval)
containerRespC <- containerWaitResponse{WaitResponse: resp, restarting: true} containerRespC <- containerWaitResponse{
WaitResponse: resp,
restarting: true,
restartCount: restartCount,
}
case <-timer.C: case <-timer.C:
fmt.Println("ContainerWait timer expired")
a.logger.Info("Container restarting", "id", shortID(containerID), "restart_count", restartCount) a.logger.Info("Container restarting", "id", shortID(containerID), "restart_count", restartCount)
restartCount++
if err := a.apiClient.ContainerStart(ctx, containerID, container.StartOptions{}); err != nil { if err := a.apiClient.ContainerStart(ctx, containerID, container.StartOptions{}); err != nil {
containerErrC <- fmt.Errorf("container start: %w", err) containerErrC <- fmt.Errorf("container start: %w", err)
return return
} }
a.logger.Info("Restarted container", "id", shortID(containerID))
case err := <-errC: case err := <-errC:
fmt.Println("ContainerWait error", err)
containerErrC <- err containerErrC <- err
return return
case <-ctx.Done(): case <-ctx.Done():
fmt.Println("ContainerWait cancelled")
// This is probably because the container was stopped. // This is probably because the container was stopped.
containerRespC <- containerWaitResponse{WaitResponse: container.WaitResponse{}, restarting: false} containerRespC <- containerWaitResponse{WaitResponse: container.WaitResponse{}, restarting: false}
return return

View File

@ -3,6 +3,7 @@ package container_test
import ( import (
"bytes" "bytes"
"errors" "errors"
"fmt"
"io" "io"
"testing" "testing"
"time" "time"
@ -126,6 +127,119 @@ func TestClientRunContainer(t *testing.T) {
<-done <-done
} }
func TestClientRunContainerWithRestart(t *testing.T) {
logger := testhelpers.NewTestLogger(t)
// channels returned by Docker's ContainerWait:
containerWaitC := make(chan dockercontainer.WaitResponse)
containerErrC := make(chan error)
// channels returned by Docker's Events:
eventsC := make(chan events.Message)
eventsErrC := make(chan error)
var dockerClient mocks.DockerClient
defer dockerClient.AssertExpectations(t)
dockerClient.
EXPECT().
NetworkCreate(mock.Anything, mock.Anything, mock.MatchedBy(func(opts network.CreateOptions) bool {
return opts.Driver == "bridge" && len(opts.Labels) > 0
})).
Return(network.CreateResponse{ID: "test-network"}, nil)
dockerClient.
EXPECT().
ImagePull(mock.Anything, "alpine", image.PullOptions{}).
Return(io.NopCloser(bytes.NewReader(nil)), nil)
dockerClient.
EXPECT().
ContainerCreate(mock.Anything, mock.Anything, mock.Anything, mock.Anything, (*ocispec.Platform)(nil), mock.Anything).
Return(dockercontainer.CreateResponse{ID: "123"}, nil)
dockerClient.
EXPECT().
NetworkConnect(mock.Anything, "test-network", "123", (*network.EndpointSettings)(nil)).
Return(nil)
dockerClient.
EXPECT().
ContainerStart(mock.Anything, "123", dockercontainer.StartOptions{}).
Once().
Return(nil)
dockerClient.
EXPECT().
ContainerStats(mock.Anything, "123", true).
Return(dockercontainer.StatsResponseReader{Body: io.NopCloser(bytes.NewReader(nil))}, nil)
dockerClient.
EXPECT().
ContainerWait(mock.Anything, "123", dockercontainer.WaitConditionNextExit).
Return(containerWaitC, containerErrC)
dockerClient.
EXPECT().
Events(mock.Anything, events.ListOptions{Filters: filters.NewArgs(filters.Arg("container", "123"), filters.Arg("type", "container"))}).
Return(eventsC, eventsErrC)
dockerClient.
EXPECT().
ContainerStart(mock.Anything, "123", dockercontainer.StartOptions{}). // restart
Return(nil)
containerClient, err := container.NewClient(t.Context(), &dockerClient, logger)
require.NoError(t, err)
containerStateC, errC := containerClient.RunContainer(t.Context(), container.RunContainerParams{
Name: "test-run-container",
ChanSize: 1,
ContainerConfig: &dockercontainer.Config{Image: "alpine"},
HostConfig: &dockercontainer.HostConfig{},
ShouldRestart: func(_ int64, restartCount int, _ time.Duration) (bool, error) {
fmt.Println("ShouldRestart", restartCount)
if restartCount == 0 {
return true, nil
}
return false, errors.New("max restarts reached")
},
RestartInterval: 10 * time.Millisecond,
})
done := make(chan struct{})
go func() {
defer close(done)
require.NoError(t, <-errC)
}()
assert.Equal(t, "pulling", (<-containerStateC).Status)
assert.Equal(t, "created", (<-containerStateC).Status)
assert.Equal(t, "running", (<-containerStateC).Status)
assert.Equal(t, "running", (<-containerStateC).Status)
// Enough time for the restart to occur:
time.Sleep(100 * time.Millisecond)
containerWaitC <- dockercontainer.WaitResponse{StatusCode: 1}
state := <-containerStateC
assert.Equal(t, "restarting", state.Status)
assert.Equal(t, "unhealthy", state.HealthState)
assert.Nil(t, state.ExitCode)
assert.Equal(t, 0, state.RestartCount) // not incremented until the actual restart
// During the restart, the "running" status is triggered by Docker events
// only. So we don't expect one in unit tests. (Probably the initial startup
// flow should behave the same.)
time.Sleep(100 * time.Millisecond)
containerWaitC <- dockercontainer.WaitResponse{StatusCode: 1}
state = <-containerStateC
assert.Equal(t, "exited", state.Status)
assert.Equal(t, "unhealthy", state.HealthState)
require.NotNil(t, state.ExitCode)
assert.Equal(t, 1, *state.ExitCode)
assert.Equal(t, 1, state.RestartCount)
<-done
}
func TestClientRunContainerErrorStartingContainer(t *testing.T) { func TestClientRunContainerErrorStartingContainer(t *testing.T) {
logger := testhelpers.NewTestLogger(t) logger := testhelpers.NewTestLogger(t)