diff --git a/internal/container/integration_test.go b/internal/container/integration_test.go index f03374d..e812e3c 100644 --- a/internal/container/integration_test.go +++ b/internal/container/integration_test.go @@ -4,7 +4,7 @@ package container_test import ( "context" - "errors" + "fmt" "testing" "time" @@ -167,7 +167,9 @@ func TestIntegrationClientRemoveContainers(t *testing.T) { assert.NoError(t, <-err3C) } -func TestContainerRestart(t *testing.T) { +func TestIntegrationContainerRestart(t *testing.T) { + const wantRestartCount = 3 + ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) @@ -185,41 +187,45 @@ func TestContainerRestart(t *testing.T) { Name: containerName, ChanSize: 1, ContainerConfig: &typescontainer.Config{ - Image: "alpine:latest", - Cmd: []string{"sleep", "1"}, - Labels: map[string]string{container.LabelComponent: component}, + Image: "alpine:3.18", + Entrypoint: []string{"sleep", "1"}, + Cmd: []string{"1"}, // 1 second + Labels: map[string]string{container.LabelComponent: component}, }, HostConfig: &typescontainer.HostConfig{ - NetworkMode: "default", - RestartPolicy: typescontainer.RestartPolicy{Name: "always"}, + NetworkMode: "default", }, + ShouldRestart: func(_ int64, restartCount int, _ [][]byte, _ time.Duration) (bool, error) { + return restartCount < wantRestartCount, nil + }, + RestartInterval: 1 * time.Second, }) testhelpers.ChanRequireNoError(t, errC) - containerState := <-containerStateC - assert.Equal(t, "pulling", containerState.Status) - containerState = <-containerStateC - assert.Equal(t, "created", containerState.Status) - containerState = <-containerStateC - assert.Equal(t, "running", containerState.Status) +outer: + for { + select { + case containerState := <-containerStateC: + if containerState.Status == "running" { + break outer + } + case <-time.After(5 * time.Second): + require.Fail(t, "timeout waiting for container") + } + } err = nil // reset error done := make(chan struct{}) go func() { defer close(done) - var count int for { - containerState = <-containerStateC - if containerState.Status == domain.ContainerStatusRestarting { + containerState := <-containerStateC + if containerState.Status == domain.ContainerStatusExited { + if containerState.RestartCount != wantRestartCount { + err = fmt.Errorf("expected %d restarts, got %d", wantRestartCount, containerState.RestartCount) + } break - } else if containerState.Status == domain.ContainerStatusExited { - err = errors.New("container exited unexpectedly") - } else if count >= 5 { - err = errors.New("container did not enter restarting state") - } else { - // wait for a few state changes - count++ } } }()