refactor(container): async start

Rob Watson 2025-02-01 03:32:28 +01:00
parent 9c7989018b
commit 6678489f69
9 changed files with 186 additions and 117 deletions

View File

@@ -1,6 +1,7 @@
package container
import (
"cmp"
"context"
"encoding/json"
"errors"
@@ -21,8 +22,13 @@ import (
"github.com/google/uuid"
)
const (
// stopTimeout is the timeout for stopping a container.
var stopTimeout = 10 * time.Second
stopTimeout = 3 * time.Second
// defaultChanSize is the default size of asynchronous non-error channels.
defaultChanSize = 64
)
// Client provides a thin wrapper around the Docker API client, and provides
// additional functionality such as exposing container stats.
@@ -154,18 +160,37 @@ func (a *Client) getEvents(containerID string) <-chan events.Message {
// RunContainerParams are the parameters for running a container.
type RunContainerParams struct {
Name string
ChanSize int
ContainerConfig *container.Config
HostConfig *container.HostConfig
}
// RunContainer runs a container with the given parameters.
//
// The returned channel will receive the current state of the container, and
// will be closed after the container has stopped.
func (a *Client) RunContainer(ctx context.Context, params RunContainerParams) (string, <-chan domain.ContainerState, error) {
// The returned state channel will receive the state of the container and will
// never be closed. The error channel will receive an error if the container
// fails to start, and will be closed when the container exits, possibly after
// receiving an error.
func (a *Client) RunContainer(ctx context.Context, params RunContainerParams) (<-chan domain.Container, <-chan error) {
now := time.Now()
containerStateC := make(chan domain.Container, cmp.Or(params.ChanSize, defaultChanSize))
errC := make(chan error, 1)
sendError := func(err error) {
errC <- err
}
a.wg.Add(1)
go func() {
defer a.wg.Done()
defer close(errC)
containerStateC <- domain.Container{State: "pulling"}
pullReader, err := a.apiClient.ImagePull(ctx, params.ContainerConfig.Image, image.PullOptions{})
if err != nil {
return "", nil, fmt.Errorf("image pull: %w", err)
closeWithError(fmt.Errorf("image pull: %w", err))
return
}
_, _ = io.Copy(io.Discard, pullReader)
_ = pullReader.Close()
@@ -187,47 +212,46 @@ func (a *Client) RunContainer(ctx context.Context, params RunContainerParams) (s
name,
)
if err != nil {
return "", nil, fmt.Errorf("container create: %w", err)
closeWithError(fmt.Errorf("container create: %w", err))
return
}
containerStateC <- domain.Container{ID: createResp.ID, State: "created"}
if err = a.apiClient.ContainerStart(ctx, createResp.ID, container.StartOptions{}); err != nil {
return "", nil, fmt.Errorf("container start: %w", err)
closeWithError(fmt.Errorf("container start: %w", err))
return
}
a.logger.Info("Started container", "id", shortID(createResp.ID))
a.logger.Info("Started container", "id", shortID(createResp.ID), "duration", time.Since(now))
containerStateC <- domain.Container{ID: createResp.ID, State: "running"}
containerStateC := make(chan domain.ContainerState, 1)
a.wg.Add(1)
go func() {
defer a.wg.Done()
defer close(containerStateC)
a.runContainerLoop(ctx, createResp.ID, containerStateC)
a.runContainerLoop(ctx, createResp.ID, containerStateC, errC)
}()
return createResp.ID, containerStateC, nil
return containerStateC, errC
}
// runContainerLoop is the control loop for a single container. It returns only
// when the container exits.
func (a *Client) runContainerLoop(ctx context.Context, containerID string, stateCh chan<- domain.ContainerState) {
func (a *Client) runContainerLoop(ctx context.Context, containerID string, stateC chan<- domain.Container, errC chan<- error) {
statsC := a.getStats(containerID)
eventsC := a.getEvents(containerID)
respC, errC := a.apiClient.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)
containerRespC, containerErrC := a.apiClient.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)
state := &domain.ContainerState{ID: containerID}
sendState := func() { stateCh <- *state }
state := &domain.Container{ID: containerID, State: "running"}
sendState := func() { stateC <- *state }
sendState()
for {
select {
case resp := <-respC:
case resp := <-containerRespC:
a.logger.Info("Container entered non-running state", "exit_code", resp.StatusCode, "id", shortID(containerID))
return
case err := <-errC:
case err := <-containerErrC:
// TODO: error handling?
if err != context.Canceled {
a.logger.Error("Error setting container wait", "err", err, "id", shortID(containerID))
}
errC <- err
return
case evt := <-eventsC:
if strings.Contains(string(evt.Action), "health_status") {
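
Under the new contract the caller owns the container lifecycle: the state channel is never closed, so the error channel is the only signal that the container is done. A minimal consumer sketch (imports elided; `typescontainer` aliases `github.com/docker/docker/api/types/container` as in the tests below, and all parameter values are placeholders rather than values from this repository):

```go
func watchContainer(ctx context.Context, client *container.Client) error {
	stateC, errC := client.RunContainer(ctx, container.RunContainerParams{
		Name:            "example", // placeholder
		ContainerConfig: &typescontainer.Config{Image: "alpine:latest"},
		HostConfig:      &typescontainer.HostConfig{NetworkMode: "default"},
	})

	for {
		select {
		case state := <-stateC:
			// stateC is never closed; rely on errC to know when to stop reading.
			log.Printf("container %s: %s", state.ID, state.State)
		case err, ok := <-errC:
			if !ok {
				return nil // errC closed without an error: the container exited
			}
			if err != nil {
				return fmt.Errorf("run container: %w", err) // e.g. pull or start failure
			}
		}
	}
}
```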

View File

@@ -28,8 +28,9 @@ func TestClientStartStop(t *testing.T) {
require.NoError(t, err)
assert.False(t, running)
containerID, containerStateC, err := client.RunContainer(ctx, container.RunContainerParams{
containerStateC, errC := client.RunContainer(ctx, container.RunContainerParams{
Name: containerName,
ChanSize: 1,
ContainerConfig: &typescontainer.Config{
Image: "netfluxio/mediamtx-alpine:latest",
Labels: map[string]string{"component": component},
@@ -38,10 +39,8 @@ func TestClientStartStop(t *testing.T) {
NetworkMode: "default",
},
})
require.NoError(t, err)
testhelpers.DiscardChannel(containerStateC)
assert.NotEmpty(t, containerID)
testhelpers.ChanDiscard(containerStateC)
testhelpers.ChanRequireNoError(t, errC)
require.Eventually(
t,
@@ -49,12 +48,13 @@ func TestClientStartStop(t *testing.T) {
running, err = client.ContainerRunning(ctx, map[string]string{"component": component})
return err == nil && running
},
2*time.Second,
5*time.Second,
100*time.Millisecond,
"container not in RUNNING state",
)
client.Close()
require.NoError(t, <-errC)
running, err = client.ContainerRunning(ctx, map[string]string{"component": component})
require.NoError(t, err)
@@ -72,7 +72,8 @@ func TestClientRemoveContainers(t *testing.T) {
require.NoError(t, err)
t.Cleanup(func() { client.Close() })
_, stateC, err := client.RunContainer(ctx, container.RunContainerParams{
stateC, err1C := client.RunContainer(ctx, container.RunContainerParams{
ChanSize: 1,
ContainerConfig: &typescontainer.Config{
Image: "netfluxio/mediamtx-alpine:latest",
Labels: map[string]string{"component": component, "group": "test1"},
@@ -80,9 +81,10 @@ func TestClientRemoveContainers(t *testing.T) {
HostConfig: &typescontainer.HostConfig{NetworkMode: "default"},
})
require.NoError(t, err)
testhelpers.DiscardChannel(stateC)
testhelpers.ChanDiscard(stateC)
_, stateC, err = client.RunContainer(ctx, container.RunContainerParams{
stateC, err2C := client.RunContainer(ctx, container.RunContainerParams{
ChanSize: 1,
ContainerConfig: &typescontainer.Config{
Image: "netfluxio/mediamtx-alpine:latest",
Labels: map[string]string{"component": component, "group": "test1"},
@@ -90,9 +92,10 @@ func TestClientRemoveContainers(t *testing.T) {
HostConfig: &typescontainer.HostConfig{NetworkMode: "default"},
})
require.NoError(t, err)
testhelpers.DiscardChannel(stateC)
testhelpers.ChanDiscard(stateC)
_, stateC, err = client.RunContainer(ctx, container.RunContainerParams{
stateC, err3C := client.RunContainer(ctx, container.RunContainerParams{
ChanSize: 1,
ContainerConfig: &typescontainer.Config{
Image: "netfluxio/mediamtx-alpine:latest",
Labels: map[string]string{"component": component, "group": "test2"},
@@ -100,7 +103,7 @@ func TestClientRemoveContainers(t *testing.T) {
HostConfig: &typescontainer.HostConfig{NetworkMode: "default"},
})
require.NoError(t, err)
testhelpers.DiscardChannel(stateC)
testhelpers.ChanDiscard(stateC)
// check all containers in group 1 are running
require.Eventually(
@@ -109,7 +112,7 @@ func TestClientRemoveContainers(t *testing.T) {
running, _ := client.ContainerRunning(ctx, map[string]string{"group": "test1"})
return running
},
2*time.Second,
5*time.Second,
100*time.Millisecond,
"container group 1 not in RUNNING state",
)
@@ -145,5 +148,12 @@ func TestClientRemoveContainers(t *testing.T) {
// check group 2 is still running
running, err := client.ContainerRunning(ctx, map[string]string{"group": "test2"})
require.NoError(t, err)
require.True(t, running)
assert.True(t, running)
assert.NoError(t, <-err1C)
assert.NoError(t, <-err2C)
client.Close()
assert.NoError(t, <-err3C)
}

View File

@@ -8,23 +8,25 @@ type AppState struct {
// Source represents the source, currently always the mediaserver.
type Source struct {
ContainerState ContainerState
Container Container
Live bool
URL string
}
// Destination is a single destination.
type Destination struct {
ContainerState ContainerState
Container Container
Live bool
URL string
}
// ContainerState represents the current state of an individual container.
// Container represents the current state of an individual container.
//
// The source of truth is always the Docker daemon; this struct is used only
// for passing asynchronous state.
type ContainerState struct {
type Container struct {
ID string
State string
HealthState string
CPUPercent float64
MemoryUsageBytes uint64
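
The renamed struct also carries stats fields. Their computation is outside this diff (see getStats in the container client above), but for context this is roughly how CPUPercent and MemoryUsageBytes are conventionally derived from the Docker stats API. A hedged sketch, not this repository's implementation: the stats type is called StatsResponse in recent Docker SDK versions (StatsJSON in older ones), and the formula is the one `docker stats` uses.

```go
// statsToDomain is a hypothetical helper, not part of this commit.
func statsToDomain(id string, s container.StatsResponse) domain.Container {
	cpuDelta := float64(s.CPUStats.CPUUsage.TotalUsage) - float64(s.PreCPUStats.CPUUsage.TotalUsage)
	sysDelta := float64(s.CPUStats.SystemUsage) - float64(s.PreCPUStats.SystemUsage)

	var cpuPercent float64
	if cpuDelta > 0 && sysDelta > 0 {
		cpuPercent = (cpuDelta / sysDelta) * float64(s.CPUStats.OnlineCPUs) * 100
	}

	return domain.Container{
		ID:               id,
		State:            "running",
		CPUPercent:       cpuPercent,
		MemoryUsageBytes: s.MemoryStats.Usage,
	}
}
```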

View File

@@ -48,14 +48,10 @@ func run(ctx context.Context, cfgReader io.Reader) error {
}
defer containerClient.Close()
srv, err := mediaserver.StartActor(ctx, mediaserver.StartActorParams{
srv := mediaserver.StartActor(ctx, mediaserver.StartActorParams{
ContainerClient: containerClient,
Logger: logger.With("component", "mediaserver"),
})
if err != nil {
return fmt.Errorf("start media server: %w", err)
}
applyServerState(srv.State(), state)
ui, err := terminal.StartActor(ctx, terminal.StartActorParams{Logger: logger.With("component", "ui")})
if err != nil {
@@ -73,6 +69,7 @@ func run(ctx context.Context, cfgReader io.Reader) error {
select {
case cmd, ok := <-ui.C():
if !ok {
// TODO: keep UI open until all containers have closed
logger.Info("UI closed")
return nil
}

View File

@@ -51,7 +51,9 @@ const (
)
// StartActor starts a new media server actor.
func StartActor(ctx context.Context, params StartActorParams) (*Actor, error) {
//
// Callers must consume the state channel exposed via [C].
func StartActor(ctx context.Context, params StartActorParams) *Actor {
chanSize := cmp.Or(params.ChanSize, defaultChanSize)
actor := &Actor{
@@ -63,10 +65,11 @@ func StartActor(ctx context.Context, params StartActorParams) (*Actor, error) {
httpClient: &http.Client{Timeout: httpClientTimeout},
}
containerID, containerStateC, err := params.ContainerClient.RunContainer(
containerStateC, errC := params.ContainerClient.RunContainer(
ctx,
container.RunContainerParams{
Name: "server",
ChanSize: chanSize,
ContainerConfig: &typescontainer.Config{
Image: imageNameMediaMTX,
Env: []string{
@@ -89,16 +92,12 @@ func StartActor(ctx context.Context, params StartActorParams) (*Actor, error) {
},
},
)
if err != nil {
return nil, fmt.Errorf("run container: %w", err)
}
actor.state.ContainerState.ID = containerID
actor.state.URL = "rtmp://localhost:1935/" + rtmpPath
go actor.actorLoop(containerStateC)
go actor.actorLoop(containerStateC, errC)
return actor, nil
return actor
}
// C returns a channel that will receive the current state of the media server.
@@ -128,7 +127,7 @@ func (s *Actor) Close() error {
// actorLoop is the main loop of the media server actor. It only exits when the
// actor is closed.
func (s *Actor) actorLoop(containerStateC <-chan domain.ContainerState) {
func (s *Actor) actorLoop(containerStateC <-chan domain.Container, errC <-chan error) {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
@@ -136,22 +135,22 @@ func (s *Actor) actorLoop(containerStateC <-chan domain.ContainerState) {
for {
select {
case containerState, ok := <-containerStateC:
if !ok {
case containerState := <-containerStateC:
s.state.Container = containerState
sendState()
continue
case err := <-errC:
if err != nil {
s.logger.Error("Error from container client", "error", err, "id", shortID(s.state.Container.ID))
}
ticker.Stop()
if s.state.Live {
s.state.Live = false
sendState()
}
continue
}
s.state.ContainerState = containerState
sendState()
continue
case <-ticker.C:
ingressLive, err := s.fetchIngressStateFromServer()
if err != nil {
@@ -218,3 +217,10 @@ func (s *Actor) fetchIngressStateFromServer() (bool, error) {
return false, nil
}
func shortID(id string) string {
if len(id) < 12 {
return id
}
return id[:12]
}
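
StartActor no longer returns an error: startup failures now surface asynchronously through the actor's state. Per the new doc comment, callers must drain the channel returned by [C] or the actor will block on send. A minimal caller sketch, modeled on run() in the main package above:

```go
srv := mediaserver.StartActor(ctx, mediaserver.StartActorParams{
	ContainerClient: containerClient,
	Logger:          logger.With("component", "mediaserver"),
})
defer srv.Close()

// Consume the state channel so the actor never blocks.
go func() {
	for s := range srv.C() {
		applyServerState(s, state) // state is the shared app state, as in run()
	}
}()
```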

View File

@@ -29,13 +29,13 @@ func TestMediaServerStartStop(t *testing.T) {
require.NoError(t, err)
assert.False(t, running)
mediaServer, err := mediaserver.StartActor(ctx, mediaserver.StartActorParams{
mediaServer := mediaserver.StartActor(ctx, mediaserver.StartActorParams{
ChanSize: 1,
ContainerClient: containerClient,
Logger: logger,
})
require.NoError(t, err)
testhelpers.DiscardChannel(mediaServer.C())
testhelpers.ChanDiscard(mediaServer.C())
require.Eventually(
t,
@@ -57,7 +57,7 @@ func TestMediaServerStartStop(t *testing.T) {
t,
func() bool {
currState := mediaServer.State()
return currState.Live && currState.ContainerState.HealthState == "healthy"
return currState.Live && currState.Container.HealthState == "healthy"
},
5*time.Second,
250*time.Millisecond,

View File

@@ -134,11 +134,12 @@ func (a *Actor) SetState(state domain.AppState) {
func (a *Actor) redrawFromState(state domain.AppState) {
setHeaderRow := func(tableView *tview.Table) {
tableView.SetCell(0, 0, tview.NewTableCell("[grey]URL").SetAlign(tview.AlignLeft).SetExpansion(7).SetSelectable(false))
tableView.SetCell(0, 1, tview.NewTableCell("[grey]Status").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false))
tableView.SetCell(0, 2, tview.NewTableCell("[grey]Health").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false))
tableView.SetCell(0, 3, tview.NewTableCell("[grey]CPU %").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false))
tableView.SetCell(0, 4, tview.NewTableCell("[grey]Mem used (MB)").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false))
tableView.SetCell(0, 5, tview.NewTableCell("[grey]Actions").SetAlign(tview.AlignLeft).SetExpansion(2).SetSelectable(false))
tableView.SetCell(0, 1, tview.NewTableCell("[grey]Stream").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false))
tableView.SetCell(0, 2, tview.NewTableCell("[grey]Container").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false))
tableView.SetCell(0, 3, tview.NewTableCell("[grey]Health").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false))
tableView.SetCell(0, 4, tview.NewTableCell("[grey]CPU %").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false))
tableView.SetCell(0, 5, tview.NewTableCell("[grey]Mem used (MB)").SetAlign(tview.AlignLeft).SetExpansion(1).SetSelectable(false))
tableView.SetCell(0, 6, tview.NewTableCell("[grey]Actions").SetAlign(tview.AlignLeft).SetExpansion(2).SetSelectable(false))
}
a.sourceView.Clear()
@@ -150,10 +151,11 @@ func (a *Actor) redrawFromState(state domain.AppState) {
} else {
a.sourceView.SetCell(1, 1, tview.NewTableCell("[yellow]off-air"))
}
a.sourceView.SetCell(1, 2, tview.NewTableCell("[white]"+cmp.Or(state.Source.ContainerState.HealthState, "starting")))
a.sourceView.SetCell(1, 3, tview.NewTableCell("[white]"+fmt.Sprintf("%.1f", state.Source.ContainerState.CPUPercent)))
a.sourceView.SetCell(1, 4, tview.NewTableCell("[white]"+fmt.Sprintf("%.1f", float64(state.Source.ContainerState.MemoryUsageBytes)/1024/1024)))
a.sourceView.SetCell(1, 5, tview.NewTableCell(""))
a.sourceView.SetCell(1, 2, tview.NewTableCell("[white]"+state.Source.Container.State))
a.sourceView.SetCell(1, 3, tview.NewTableCell("[white]"+cmp.Or(state.Source.Container.HealthState, "starting")))
a.sourceView.SetCell(1, 4, tview.NewTableCell("[white]"+fmt.Sprintf("%.1f", state.Source.Container.CPUPercent)))
a.sourceView.SetCell(1, 5, tview.NewTableCell("[white]"+fmt.Sprintf("%.1f", float64(state.Source.Container.MemoryUsageBytes)/1024/1024)))
a.sourceView.SetCell(1, 6, tview.NewTableCell(""))
a.destView.Clear()
setHeaderRow(a.destView)
@@ -164,7 +166,8 @@ func (a *Actor) redrawFromState(state domain.AppState) {
a.destView.SetCell(i+1, 2, tview.NewTableCell("[white]-"))
a.destView.SetCell(i+1, 3, tview.NewTableCell("[white]-"))
a.destView.SetCell(i+1, 4, tview.NewTableCell("[white]-"))
a.destView.SetCell(i+1, 5, tview.NewTableCell("[green]Tab to go live"))
a.destView.SetCell(i+1, 5, tview.NewTableCell("[white]-"))
a.destView.SetCell(i+1, 6, tview.NewTableCell("[green]Tab to go live"))
}
a.app.Draw()

testhelpers/channel.go (new file, +35 lines)
View File

@@ -0,0 +1,35 @@
package testhelpers
import (
"log/slog"
"testing"
"github.com/stretchr/testify/require"
)
// ChanDiscard consumes a channel and discards all values.
func ChanDiscard[T any](ch <-chan T) {
go func() {
for range ch {
// no-op
}
}()
}
// ChanRequireNoError consumes a channel and asserts that no error is received.
func ChanRequireNoError(t testing.TB, ch <-chan error) {
t.Helper()
go func() {
require.NoError(t, <-ch)
}()
}
// ChanLog logs a channel's values.
func ChanLog[T any](ch <-chan T, logger *slog.Logger) {
go func() {
for v := range ch {
logger.Info("Channel", "value", v)
}
}()
}
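
These helpers keep the new channel-based API from blocking tests. Typical usage, as in TestClientStartStop above:

```go
stateC, errC := client.RunContainer(ctx, params)
testhelpers.ChanDiscard(stateC)         // drain state updates for the duration of the test
testhelpers.ChanRequireNoError(t, errC) // fail if the container surfaces an error
```

One caveat: ChanRequireNoError calls require.NoError on a non-test goroutine, so testify records the failure but cannot stop the test from there; a failed assertion only takes effect when the test next synchronizes (for example at the direct `<-errC` reads in the tests above).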

View File

@@ -13,15 +13,7 @@ func NewNopLogger() *slog.Logger {
return slog.New(slog.NewJSONHandler(io.Discard, nil))
}
// NewTestLogger returns a logger that writes to stderr.
func NewTestLogger() *slog.Logger {
return slog.New(slog.NewTextHandler(os.Stderr, nil))
}
// NoopChannel consumes a channel and discards all values.
func DiscardChannel[T any](ch <-chan T) {
go func() {
for range ch {
// no-op
}
}()
}