From 7eff6c60659788079100708c68cf775402a4e9c4 Mon Sep 17 00:00:00 2001 From: Rob Watson Date: Sat, 12 Apr 2025 19:08:17 +0200 Subject: [PATCH] refactor(container): restart handling --- internal/container/container.go | 90 ++++++++++++++++++++++--------- internal/replicator/replicator.go | 18 +++++-- internal/terminal/terminal.go | 8 +-- 3 files changed, 80 insertions(+), 36 deletions(-) diff --git a/internal/container/container.go b/internal/container/container.go index 960a78d..30555f2 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -147,6 +147,11 @@ type CopyFileConfig struct { Mode int64 } +// ShouldRestartFunc is a callback function that is called when a container +// exits. It should return true if the container is to be restarted. If not +// restarting, err can be non-nil. +type ShouldRestartFunc func(exitCode int64, restartCount int, runningTime time.Duration) (bool, error) + // RunContainerParams are the parameters for running a container. type RunContainerParams struct { Name string @@ -156,6 +161,7 @@ type RunContainerParams struct { NetworkingConfig *network.NetworkingConfig NetworkCountConfig NetworkCountConfig CopyFileConfigs []CopyFileConfig + ShouldRestart ShouldRestartFunc } // RunContainer runs a container with the given parameters. @@ -164,13 +170,18 @@ type RunContainerParams struct { // never be closed. The error channel will receive an error if the container // fails to start, and will be closed when the container exits, possibly after // receiving an error. +// +// Panics if ShouldRestart is non-nil and the host config defines a restart +// policy of its own. func (a *Client) RunContainer(ctx context.Context, params RunContainerParams) (<-chan domain.Container, <-chan error) { + if params.ShouldRestart != nil && !params.HostConfig.RestartPolicy.IsNone() { + panic("shouldRestart and restart policy are mutually exclusive") + } + now := time.Now() containerStateC := make(chan domain.Container, cmp.Or(params.ChanSize, defaultChanSize)) errC := make(chan error, 1) - sendError := func(err error) { - errC <- err - } + sendError := func(err error) { errC <- err } a.wg.Add(1) go func() { @@ -229,6 +240,7 @@ func (a *Client) RunContainer(ctx context.Context, params RunContainerParams) (< createResp.ID, params.ContainerConfig.Image, params.NetworkCountConfig, + params.ShouldRestart, containerStateC, errC, ) @@ -314,12 +326,16 @@ func (a *Client) runContainerLoop( containerID string, imageName string, networkCountConfig NetworkCountConfig, + shouldRestart ShouldRestartFunc, stateC chan<- domain.Container, errC chan<- error, ) { type containerWaitResponse struct { container.WaitResponse - restarting bool + + restarting bool + restartCount int + err error } containerRespC := make(chan containerWaitResponse) @@ -333,33 +349,54 @@ func (a *Client) runContainerLoop( // The goroutine exits when a value is received on the error channel, or when // the container exits and is not restarting, or when the context is cancelled. go func() { + const restartDuration = 5 * time.Second + timer := time.NewTimer(restartDuration) + defer timer.Stop() + timer.Stop() + var restartCount int + for { + startedWaitingAt := time.Now() respC, errC := a.apiClient.ContainerWait(ctx, containerID, container.WaitConditionNextExit) select { case resp := <-respC: - var restarting bool - // Check if the container is restarting. If it is not then we don't - // want to wait for it again and can return early. - ctr, err := a.apiClient.ContainerInspect(ctx, containerID) - // Race condition: the container may already have been removed. - if errdefs.IsNotFound(err) { - // ignore error but do not restart - } else if err != nil { - a.logger.Error("Error inspecting container", "err", err, "id", shortID(containerID)) - containerErrC <- err - return - // Race condition: the container may have already restarted. - } else if ctr.State.Status == domain.ContainerStatusRestarting || ctr.State.Status == domain.ContainerStatusRunning { - restarting = true + if shouldRestart != nil { + shouldRestart, err := shouldRestart(resp.StatusCode, restartCount, time.Since(startedWaitingAt)) + if shouldRestart && err != nil { + panic(fmt.Errorf("shouldRestart must return nil error if restarting, but returned: %w", err)) + } + if !shouldRestart { + a.logger.Info("Container exited", "id", shortID(containerID), "should_restart", "false", "exit_code", resp.StatusCode, "restart_count", restartCount) + containerRespC <- containerWaitResponse{WaitResponse: resp, restarting: false, err: err} + return + } } - containerRespC <- containerWaitResponse{WaitResponse: resp, restarting: restarting} - if !restarting { - return + restartCount++ + + // TODO: exponential backoff + a.logger.Info("Container exited", "id", shortID(containerID), "should_restart", "true", "exit_code", resp.StatusCode, "restart_count", restartCount) + timer.Reset(restartDuration) + + containerRespC <- containerWaitResponse{WaitResponse: resp, restarting: true} + case <-timer.C: + a.logger.Info("Container restarting", "id", shortID(containerID), "restart_count", restartCount) + if err := a.apiClient.ContainerStart(ctx, containerID, container.StartOptions{}); err != nil { + containerErrC <- fmt.Errorf("container start: %w", err) } case err := <-errC: - // Otherwise, this is probably unexpected and we need to handle it. - containerErrC <- err + // If this is a not found error, the container has been removed - + // probably by the user. This is a bit hacky, and should be more + // explicit, possibly by signalling to this package that the container + // has been removed by the user instead of just calling + // ContainerRemove. + // TODO: improve this + if errdefs.IsNotFound(err) { + a.logger.Debug("Container not found when setting ContainerWait, ignoring", "id", shortID(containerID)) + containerRespC <- containerWaitResponse{WaitResponse: container.WaitResponse{}, restarting: false} + } else { + containerErrC <- err + } return case <-ctx.Done(): containerErrC <- ctx.Err() @@ -382,20 +419,23 @@ func (a *Client) runContainerLoop( a.logger.Info("Container entered non-running state", "exit_code", resp.StatusCode, "id", shortID(containerID), "restarting", resp.restarting) var containerState string + var containerErr error if resp.restarting { containerState = domain.ContainerStatusRestarting } else { containerState = domain.ContainerStatusExited + containerErr = resp.err } state.Status = containerState + state.Err = containerErr + state.RestartCount = resp.restartCount state.CPUPercent = 0 state.MemoryUsageBytes = 0 state.HealthState = "unhealthy" state.RxRate = 0 state.TxRate = 0 state.RxSince = time.Time{} - state.RestartCount++ if !resp.restarting { exitCode := int(resp.StatusCode) @@ -406,7 +446,7 @@ func (a *Client) runContainerLoop( sendState() case err := <-containerErrC: - // TODO: error handling? + // TODO: verify error handling if err != context.Canceled { a.logger.Error("Error setting container wait", "err", err, "id", shortID(containerID)) } diff --git a/internal/replicator/replicator.go b/internal/replicator/replicator.go index 4904a92..d3ca087 100644 --- a/internal/replicator/replicator.go +++ b/internal/replicator/replicator.go @@ -3,6 +3,7 @@ package replicator import ( "cmp" "context" + "errors" "fmt" "log/slog" "strconv" @@ -105,11 +106,20 @@ func (a *Actor) StartDestination(url string) { container.LabelURL: url, }, }, - HostConfig: &typescontainer.HostConfig{ - NetworkMode: "default", - RestartPolicy: typescontainer.RestartPolicy{Name: "always"}, - }, + HostConfig: &typescontainer.HostConfig{NetworkMode: "default"}, NetworkCountConfig: container.NetworkCountConfig{Rx: "eth1", Tx: "eth0"}, + ShouldRestart: func(_ int64, restartCount int, runningTime time.Duration) (bool, error) { + // Try to infer if the container failed to start. + // + // TODO: this is a bit hacky, we should check the container logs and + // include some details in the error message. + if restartCount == 0 && runningTime < 10*time.Second { + return false, errors.New("container failed to start") + } + + // Otherwise, always restart, regardless of the exit code. + return true, nil + }, }) a.wg.Add(1) diff --git a/internal/terminal/terminal.go b/internal/terminal/terminal.go index b1abca5..cec9ad3 100644 --- a/internal/terminal/terminal.go +++ b/internal/terminal/terminal.go @@ -366,8 +366,6 @@ func (ui *UI) ShowStartupCheckModal() bool { } func (ui *UI) ShowDestinationErrorModal(name string, err error) { - done := make(chan struct{}) - ui.app.QueueUpdateDraw(func() { ui.showModal( pageNameModalDestinationError, @@ -377,13 +375,9 @@ func (ui *UI) ShowDestinationErrorModal(name string, err error) { err, ), []string{"Ok"}, - func(int, string) { - done <- struct{}{} - }, + nil, ) }) - - <-done } // ShowFatalErrorModal displays the provided error. It sends a CommandQuit to the