From 6678489f692da95aadaa132ecc95378b4e924df6 Mon Sep 17 00:00:00 2001
From: Rob Watson
Date: Sat, 1 Feb 2025 03:32:28 +0100
Subject: [PATCH] refactor(container): async start

---
 container/container.go      | 114 ++++++++++++++++++++++--------------
 container/container_test.go |  40 ++++++++-----
 domain/types.go             |  16 ++---
 main.go                     |   7 +--
 mediaserver/actor.go        |  52 ++++++++--------
 mediaserver/actor_test.go   |   6 +-
 terminal/actor.go           |  23 ++++----
 testhelpers/channel.go      |  35 +++++++++++
 testhelpers/logging.go      |  10 +---
 9 files changed, 186 insertions(+), 117 deletions(-)
 create mode 100644 testhelpers/channel.go

diff --git a/container/container.go b/container/container.go
index da64f4c..ffd10ff 100644
--- a/container/container.go
+++ b/container/container.go
@@ -1,6 +1,7 @@
 package container
 
 import (
+	"cmp"
 	"context"
 	"encoding/json"
 	"errors"
@@ -21,8 +22,13 @@ import (
 	"github.com/google/uuid"
 )
 
-// stopTimeout is the timeout for stopping a container.
-var stopTimeout = 10 * time.Second
+const (
+	// stopTimeout is the timeout for stopping a container.
+	stopTimeout = 3 * time.Second
+
+	// defaultChanSize is the default size of asynchronous non-error channels.
+	defaultChanSize = 64
+)
 
 // Client provides a thin wrapper around the Docker API client, and provides
 // additional functionality such as exposing container stats.
@@ -154,80 +160,98 @@ func (a *Client) getEvents(containerID string) <-chan events.Message {
 // RunContainerParams are the parameters for running a container.
 type RunContainerParams struct {
 	Name            string
+	ChanSize        int
 	ContainerConfig *container.Config
 	HostConfig      *container.HostConfig
 }
 
 // RunContainer runs a container with the given parameters.
 //
-// The returned channel will receive the current state of the container, and
-// will be closed after the container has stopped.
-func (a *Client) RunContainer(ctx context.Context, params RunContainerParams) (string, <-chan domain.ContainerState, error) {
-	pullReader, err := a.apiClient.ImagePull(ctx, params.ContainerConfig.Image, image.PullOptions{})
-	if err != nil {
-		return "", nil, fmt.Errorf("image pull: %w", err)
-	}
-	_, _ = io.Copy(io.Discard, pullReader)
-	_ = pullReader.Close()
-
-	params.ContainerConfig.Labels["app"] = "termstream"
-	params.ContainerConfig.Labels["app-id"] = a.id.String()
-
-	var name string
-	if params.Name != "" {
-		name = "termstream-" + a.id.String() + "-" + params.Name
+// The returned state channel receives updates to the container state and is
+// never closed. The error channel receives any error encountered while
+// starting or waiting on the container, and is closed when the container
+// exits, possibly after receiving an error.
+func (a *Client) RunContainer(ctx context.Context, params RunContainerParams) (<-chan domain.Container, <-chan error) {
+	now := time.Now()
+	containerStateC := make(chan domain.Container, cmp.Or(params.ChanSize, defaultChanSize))
+	errC := make(chan error, 1)
+	// sendError reports a startup error to the caller. The channel itself is
+	// closed only by the deferred close in the goroutine below.
+	sendError := func(err error) {
+		errC <- err
 	}
 
-	createResp, err := a.apiClient.ContainerCreate(
-		ctx,
-		params.ContainerConfig,
-		params.HostConfig,
-		nil,
-		nil,
-		name,
-	)
-	if err != nil {
-		return "", nil, fmt.Errorf("container create: %w", err)
-	}
-
-	if err = a.apiClient.ContainerStart(ctx, createResp.ID, container.StartOptions{}); err != nil {
-		return "", nil, fmt.Errorf("container start: %w", err)
-	}
-	a.logger.Info("Started container", "id", shortID(createResp.ID))
-
-	containerStateC := make(chan domain.ContainerState, 1)
+	a.wg.Add(1)
 	go func() {
 		defer a.wg.Done()
-		defer close(containerStateC)
+		defer close(errC)
 
-		a.runContainerLoop(ctx, createResp.ID, containerStateC)
+		containerStateC <- domain.Container{State: "pulling"}
+
+		pullReader, err := a.apiClient.ImagePull(ctx, params.ContainerConfig.Image, image.PullOptions{})
+		if err != nil {
+			sendError(fmt.Errorf("image pull: %w", err))
+			return
+		}
+		_, _ = io.Copy(io.Discard, pullReader)
+		_ = pullReader.Close()
+
+		params.ContainerConfig.Labels["app"] = "termstream"
+		params.ContainerConfig.Labels["app-id"] = a.id.String()
+
+		var name string
+		if params.Name != "" {
+			name = "termstream-" + a.id.String() + "-" + params.Name
+		}
+
+		createResp, err := a.apiClient.ContainerCreate(
+			ctx,
+			params.ContainerConfig,
+			params.HostConfig,
+			nil,
+			nil,
+			name,
+		)
+		if err != nil {
+			sendError(fmt.Errorf("container create: %w", err))
+			return
+		}
+		containerStateC <- domain.Container{ID: createResp.ID, State: "created"}
+
+		if err = a.apiClient.ContainerStart(ctx, createResp.ID, container.StartOptions{}); err != nil {
+			sendError(fmt.Errorf("container start: %w", err))
+			return
+		}
+		a.logger.Info("Started container", "id", shortID(createResp.ID), "duration", time.Since(now))
+		containerStateC <- domain.Container{ID: createResp.ID, State: "running"}
+
+		a.runContainerLoop(ctx, createResp.ID, containerStateC, errC)
 	}()
 
-	return createResp.ID, containerStateC, nil
+	return containerStateC, errC
 }
 
 // runContainerLoop is the control loop for a single container. It returns only
 // when the container exits.
-func (a *Client) runContainerLoop(ctx context.Context, containerID string, stateCh chan<- domain.ContainerState) {
+func (a *Client) runContainerLoop(ctx context.Context, containerID string, stateC chan<- domain.Container, errC chan<- error) {
 	statsC := a.getStats(containerID)
 	eventsC := a.getEvents(containerID)
-	respC, errC := a.apiClient.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)
+	containerRespC, containerErrC := a.apiClient.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)
 
-	state := &domain.ContainerState{ID: containerID}
-	sendState := func() { stateCh <- *state }
+	state := &domain.Container{ID: containerID, State: "running"}
+	sendState := func() { stateC <- *state }
 
 	sendState()
 
 	for {
 		select {
-		case resp := <-respC:
+		case resp := <-containerRespC:
 			a.logger.Info("Container entered non-running state", "exit_code", resp.StatusCode, "id", shortID(containerID))
 			return
-		case err := <-errC:
+		case err := <-containerErrC:
 			// TODO: error handling?
if err != context.Canceled { a.logger.Error("Error setting container wait", "err", err, "id", shortID(containerID)) } + errC <- err return case evt := <-eventsC: if strings.Contains(string(evt.Action), "health_status") { diff --git a/container/container_test.go b/container/container_test.go index 2abbb31..7399aa9 100644 --- a/container/container_test.go +++ b/container/container_test.go @@ -28,8 +28,9 @@ func TestClientStartStop(t *testing.T) { require.NoError(t, err) assert.False(t, running) - containerID, containerStateC, err := client.RunContainer(ctx, container.RunContainerParams{ - Name: containerName, + containerStateC, errC := client.RunContainer(ctx, container.RunContainerParams{ + Name: containerName, + ChanSize: 1, ContainerConfig: &typescontainer.Config{ Image: "netfluxio/mediamtx-alpine:latest", Labels: map[string]string{"component": component}, @@ -38,10 +39,8 @@ func TestClientStartStop(t *testing.T) { NetworkMode: "default", }, }) - require.NoError(t, err) - testhelpers.DiscardChannel(containerStateC) - - assert.NotEmpty(t, containerID) + testhelpers.ChanDiscard(containerStateC) + testhelpers.ChanRequireNoError(t, errC) require.Eventually( t, @@ -49,12 +48,13 @@ func TestClientStartStop(t *testing.T) { running, err = client.ContainerRunning(ctx, map[string]string{"component": component}) return err == nil && running }, - 2*time.Second, + 5*time.Second, 100*time.Millisecond, "container not in RUNNING state", ) client.Close() + require.NoError(t, <-errC) running, err = client.ContainerRunning(ctx, map[string]string{"component": component}) require.NoError(t, err) @@ -72,7 +72,8 @@ func TestClientRemoveContainers(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { client.Close() }) - _, stateC, err := client.RunContainer(ctx, container.RunContainerParams{ + stateC, err1C := client.RunContainer(ctx, container.RunContainerParams{ + ChanSize: 1, ContainerConfig: &typescontainer.Config{ Image: "netfluxio/mediamtx-alpine:latest", Labels: map[string]string{"component": component, "group": "test1"}, @@ -80,9 +81,10 @@ func TestClientRemoveContainers(t *testing.T) { HostConfig: &typescontainer.HostConfig{NetworkMode: "default"}, }) require.NoError(t, err) - testhelpers.DiscardChannel(stateC) + testhelpers.ChanDiscard(stateC) - _, stateC, err = client.RunContainer(ctx, container.RunContainerParams{ + stateC, err2C := client.RunContainer(ctx, container.RunContainerParams{ + ChanSize: 1, ContainerConfig: &typescontainer.Config{ Image: "netfluxio/mediamtx-alpine:latest", Labels: map[string]string{"component": component, "group": "test1"}, @@ -90,9 +92,10 @@ func TestClientRemoveContainers(t *testing.T) { HostConfig: &typescontainer.HostConfig{NetworkMode: "default"}, }) require.NoError(t, err) - testhelpers.DiscardChannel(stateC) + testhelpers.ChanDiscard(stateC) - _, stateC, err = client.RunContainer(ctx, container.RunContainerParams{ + stateC, err3C := client.RunContainer(ctx, container.RunContainerParams{ + ChanSize: 1, ContainerConfig: &typescontainer.Config{ Image: "netfluxio/mediamtx-alpine:latest", Labels: map[string]string{"component": component, "group": "test2"}, @@ -100,7 +103,7 @@ func TestClientRemoveContainers(t *testing.T) { HostConfig: &typescontainer.HostConfig{NetworkMode: "default"}, }) require.NoError(t, err) - testhelpers.DiscardChannel(stateC) + testhelpers.ChanDiscard(stateC) // check all containers in group 1 are running require.Eventually( @@ -109,7 +112,7 @@ func TestClientRemoveContainers(t *testing.T) { running, _ := client.ContainerRunning(ctx, 
map[string]string{"group": "test1"}) return running }, - 2*time.Second, + 5*time.Second, 100*time.Millisecond, "container group 1 not in RUNNING state", ) @@ -145,5 +148,12 @@ func TestClientRemoveContainers(t *testing.T) { // check group 2 is still running running, err := client.ContainerRunning(ctx, map[string]string{"group": "test2"}) require.NoError(t, err) - require.True(t, running) + assert.True(t, running) + + assert.NoError(t, <-err1C) + assert.NoError(t, <-err2C) + + client.Close() + + assert.NoError(t, <-err3C) } diff --git a/domain/types.go b/domain/types.go index 0a9da6d..ece5eb7 100644 --- a/domain/types.go +++ b/domain/types.go @@ -8,23 +8,25 @@ type AppState struct { // Source represents the source, currently always the mediaserver. type Source struct { - ContainerState ContainerState - Live bool - URL string + Container Container + Live bool + URL string } // Destination is a single destination. type Destination struct { - ContainerState ContainerState - URL string + Container Container + Live bool + URL string } -// ContainerState represents the current state of an individual container. +// Container represents the current state of an individual container. // // The source of truth is always the Docker daemon, this struct is used only // for passing asynchronous state. -type ContainerState struct { +type Container struct { ID string + State string HealthState string CPUPercent float64 MemoryUsageBytes uint64 diff --git a/main.go b/main.go index 1cb7138..6145ebf 100644 --- a/main.go +++ b/main.go @@ -48,14 +48,10 @@ func run(ctx context.Context, cfgReader io.Reader) error { } defer containerClient.Close() - srv, err := mediaserver.StartActor(ctx, mediaserver.StartActorParams{ + srv := mediaserver.StartActor(ctx, mediaserver.StartActorParams{ ContainerClient: containerClient, Logger: logger.With("component", "mediaserver"), }) - if err != nil { - return fmt.Errorf("start media server: %w", err) - } - applyServerState(srv.State(), state) ui, err := terminal.StartActor(ctx, terminal.StartActorParams{Logger: logger.With("component", "ui")}) if err != nil { @@ -73,6 +69,7 @@ func run(ctx context.Context, cfgReader io.Reader) error { select { case cmd, ok := <-ui.C(): if !ok { + // TODO: keep UI open until all containers have closed logger.Info("UI closed") return nil } diff --git a/mediaserver/actor.go b/mediaserver/actor.go index ebfe6ed..a1311b7 100644 --- a/mediaserver/actor.go +++ b/mediaserver/actor.go @@ -51,7 +51,9 @@ const ( ) // StartActor starts a new media server actor. -func StartActor(ctx context.Context, params StartActorParams) (*Actor, error) { +// +// Callers must consume the state channel exposed via [C]. 
+func StartActor(ctx context.Context, params StartActorParams) *Actor { chanSize := cmp.Or(params.ChanSize, defaultChanSize) actor := &Actor{ @@ -63,10 +65,11 @@ func StartActor(ctx context.Context, params StartActorParams) (*Actor, error) { httpClient: &http.Client{Timeout: httpClientTimeout}, } - containerID, containerStateC, err := params.ContainerClient.RunContainer( + containerStateC, errC := params.ContainerClient.RunContainer( ctx, container.RunContainerParams{ - Name: "server", + Name: "server", + ChanSize: chanSize, ContainerConfig: &typescontainer.Config{ Image: imageNameMediaMTX, Env: []string{ @@ -89,16 +92,12 @@ func StartActor(ctx context.Context, params StartActorParams) (*Actor, error) { }, }, ) - if err != nil { - return nil, fmt.Errorf("run container: %w", err) - } - actor.state.ContainerState.ID = containerID actor.state.URL = "rtmp://localhost:1935/" + rtmpPath - go actor.actorLoop(containerStateC) + go actor.actorLoop(containerStateC, errC) - return actor, nil + return actor } // C returns a channel that will receive the current state of the media server. @@ -128,7 +127,7 @@ func (s *Actor) Close() error { // actorLoop is the main loop of the media server actor. It only exits when the // actor is closed. -func (s *Actor) actorLoop(containerStateC <-chan domain.ContainerState) { +func (s *Actor) actorLoop(containerStateC <-chan domain.Container, errC <-chan error) { ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() @@ -136,22 +135,22 @@ func (s *Actor) actorLoop(containerStateC <-chan domain.ContainerState) { for { select { - case containerState, ok := <-containerStateC: - if !ok { - ticker.Stop() - - if s.state.Live { - s.state.Live = false - sendState() - } - - continue - } - - s.state.ContainerState = containerState + case containerState := <-containerStateC: + s.state.Container = containerState sendState() continue + case err := <-errC: + if err != nil { + s.logger.Error("Error from container client", "error", err, "id", shortID(s.state.Container.ID)) + } + + ticker.Stop() + + if s.state.Live { + s.state.Live = false + sendState() + } case <-ticker.C: ingressLive, err := s.fetchIngressStateFromServer() if err != nil { @@ -218,3 +217,10 @@ func (s *Actor) fetchIngressStateFromServer() (bool, error) { return false, nil } + +func shortID(id string) string { + if len(id) < 12 { + return id + } + return id[:12] +} diff --git a/mediaserver/actor_test.go b/mediaserver/actor_test.go index 990a9c0..f268979 100644 --- a/mediaserver/actor_test.go +++ b/mediaserver/actor_test.go @@ -29,13 +29,13 @@ func TestMediaServerStartStop(t *testing.T) { require.NoError(t, err) assert.False(t, running) - mediaServer, err := mediaserver.StartActor(ctx, mediaserver.StartActorParams{ + mediaServer := mediaserver.StartActor(ctx, mediaserver.StartActorParams{ ChanSize: 1, ContainerClient: containerClient, Logger: logger, }) require.NoError(t, err) - testhelpers.DiscardChannel(mediaServer.C()) + testhelpers.ChanDiscard(mediaServer.C()) require.Eventually( t, @@ -57,7 +57,7 @@ func TestMediaServerStartStop(t *testing.T) { t, func() bool { currState := mediaServer.State() - return currState.Live && currState.ContainerState.HealthState == "healthy" + return currState.Live && currState.Container.HealthState == "healthy" }, 5*time.Second, 250*time.Millisecond, diff --git a/terminal/actor.go b/terminal/actor.go index 9d70400..98e39d8 100644 --- a/terminal/actor.go +++ b/terminal/actor.go @@ -134,11 +134,12 @@ func (a *Actor) SetState(state domain.AppState) { func (a *Actor) 
redrawFromState(state domain.AppState) { setHeaderRow := func(tableView *tview.Table) { tableView.SetCell(0, 0, tview.NewTableCell("[grey]URL").SetAlign(tview.AlignLeft).SetExpansion(7).SetSelectable(false)) - tableView.SetCell(0, 1, tview.NewTableCell("[grey]Status").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false)) - tableView.SetCell(0, 2, tview.NewTableCell("[grey]Health").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false)) - tableView.SetCell(0, 3, tview.NewTableCell("[grey]CPU %").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false)) - tableView.SetCell(0, 4, tview.NewTableCell("[grey]Mem used (MB)").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false)) - tableView.SetCell(0, 5, tview.NewTableCell("[grey]Actions").SetAlign(tview.AlignLeft).SetExpansion(2).SetSelectable(false)) + tableView.SetCell(0, 1, tview.NewTableCell("[grey]Stream").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false)) + tableView.SetCell(0, 2, tview.NewTableCell("[grey]Container").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false)) + tableView.SetCell(0, 3, tview.NewTableCell("[grey]Health").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false)) + tableView.SetCell(0, 4, tview.NewTableCell("[grey]CPU %").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false)) + tableView.SetCell(0, 5, tview.NewTableCell("[grey]Mem used (MB)").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false)) + tableView.SetCell(0, 6, tview.NewTableCell("[grey]Actions").SetAlign(tview.AlignLeft).SetExpansion(2).SetSelectable(false)) } a.sourceView.Clear() @@ -150,10 +151,11 @@ func (a *Actor) redrawFromState(state domain.AppState) { } else { a.sourceView.SetCell(1, 1, tview.NewTableCell("[yellow]off-air")) } - a.sourceView.SetCell(1, 2, tview.NewTableCell("[white]"+cmp.Or(state.Source.ContainerState.HealthState, "starting"))) - a.sourceView.SetCell(1, 3, tview.NewTableCell("[white]"+fmt.Sprintf("%.1f", state.Source.ContainerState.CPUPercent))) - a.sourceView.SetCell(1, 4, tview.NewTableCell("[white]"+fmt.Sprintf("%.1f", float64(state.Source.ContainerState.MemoryUsageBytes)/1024/1024))) - a.sourceView.SetCell(1, 5, tview.NewTableCell("")) + a.sourceView.SetCell(1, 2, tview.NewTableCell("[white]"+state.Source.Container.State)) + a.sourceView.SetCell(1, 3, tview.NewTableCell("[white]"+cmp.Or(state.Source.Container.HealthState, "starting"))) + a.sourceView.SetCell(1, 4, tview.NewTableCell("[white]"+fmt.Sprintf("%.1f", state.Source.Container.CPUPercent))) + a.sourceView.SetCell(1, 5, tview.NewTableCell("[white]"+fmt.Sprintf("%.1f", float64(state.Source.Container.MemoryUsageBytes)/1024/1024))) + a.sourceView.SetCell(1, 6, tview.NewTableCell("")) a.destView.Clear() setHeaderRow(a.destView) @@ -164,7 +166,8 @@ func (a *Actor) redrawFromState(state domain.AppState) { a.destView.SetCell(i+1, 2, tview.NewTableCell("[white]-")) a.destView.SetCell(i+1, 3, tview.NewTableCell("[white]-")) a.destView.SetCell(i+1, 4, tview.NewTableCell("[white]-")) - a.destView.SetCell(i+1, 5, tview.NewTableCell("[green]Tab to go live")) + a.destView.SetCell(i+1, 5, tview.NewTableCell("[white]-")) + a.destView.SetCell(i+1, 6, tview.NewTableCell("[green]Tab to go live")) } a.app.Draw() diff --git a/testhelpers/channel.go b/testhelpers/channel.go new file mode 100644 index 0000000..9a9ec96 --- /dev/null +++ b/testhelpers/channel.go @@ -0,0 +1,35 @@ +package testhelpers + +import ( + "log/slog" + "testing" + + "github.com/stretchr/testify/require" +) + +// ChanDiscard consumes a 
channel and discards all values. +func ChanDiscard[T any](ch <-chan T) { + go func() { + for range ch { + // no-op + } + }() +} + +// ChanRequireNoError consumes a channel and asserts that no error is received. +func ChanRequireNoError(t testing.TB, ch <-chan error) { + t.Helper() + + go func() { + require.NoError(t, <-ch) + }() +} + +// ChanLog logs a channel's values. +func ChanLog[T any](ch <-chan T, logger *slog.Logger) { + go func() { + for v := range ch { + logger.Info("Channel", "value", v) + } + }() +} diff --git a/testhelpers/logging.go b/testhelpers/logging.go index 0012c8d..c964506 100644 --- a/testhelpers/logging.go +++ b/testhelpers/logging.go @@ -13,15 +13,7 @@ func NewNopLogger() *slog.Logger { return slog.New(slog.NewJSONHandler(io.Discard, nil)) } +// NewTestLogger returns a logger that writes to stderr. func NewTestLogger() *slog.Logger { return slog.New(slog.NewTextHandler(os.Stderr, nil)) } - -// NoopChannel consumes a channel and discards all values. -func DiscardChannel[T any](ch <-chan T) { - go func() { - for range ch { - // no-op - } - }() -}
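
For reference, the consumer side of the new RunContainer contract: the state channel is never closed, so it is read until the error channel signals completion; the error channel delivers at most one error and is closed when the container exits. Below is a minimal sketch of such a consumer. The consumeContainer function and the domain import path are illustrative assumptions, not part of this patch.

package main

import (
	"log"

	"git.netflux.io/rob/termstream/domain" // assumed module path
)

// consumeContainer drains state updates from stateC until errC is closed,
// which under the new contract means the container has exited.
func consumeContainer(stateC <-chan domain.Container, errC <-chan error) error {
	for {
		select {
		case state := <-stateC:
			// stateC is never closed, so every receive is a genuine update:
			// "pulling", "created", "running", then stats and health changes.
			log.Printf("container %s: state=%s health=%s", state.ID, state.State, state.HealthState)
		case err, ok := <-errC:
			if !ok {
				// Closed without an error: the container exited normally.
				return nil
			}
			// Pull, create, start or wait error; errC is closed shortly after.
			return err
		}
	}
}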
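The per-call channel sizing in RunContainer and StartActor relies on the standard library's cmp.Or, which returns the first of its arguments that is not the zero value, so an unset ChanSize falls back to defaultChanSize. A standalone illustration:

package main

import (
	"cmp"
	"fmt"
)

func main() {
	const defaultChanSize = 64

	// cmp.Or returns the first non-zero argument.
	fmt.Println(cmp.Or(0, defaultChanSize))  // 64: unset ChanSize falls back to the default
	fmt.Println(cmp.Or(16, defaultChanSize)) // 16: an explicit ChanSize wins
}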