octoplex/container/container.go
2025-01-29 16:33:53 +01:00

345 lines
9.3 KiB
Go

package container
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"strings"
"sync"
"time"
"git.netflux.io/rob/termstream/domain"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/events"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/image"
"github.com/docker/docker/client"
"github.com/google/uuid"
)
// stopTimeout is the timeout for stopping a container. Close uses it as the
// deadline for its whole cleanup pass, and removeContainer converts it to
// whole seconds for the Docker stop request.
var stopTimeout = 10 * time.Second
// Client provides a thin wrapper around the Docker API client, and provides
// additional functionality such as exposing container stats.
type Client struct {
// id is a random identifier generated in NewClient. It is written to the
// "app-id" label of every container started by this client and used to
// filter container listings.
id uuid.UUID
// ctx is derived from the context given to NewClient and is passed to the
// Docker stats and events calls; cancel (invoked by Close) cancels it.
ctx context.Context
cancel context.CancelFunc
// wg tracks the per-container control loops started by RunContainer.
wg sync.WaitGroup
apiClient *client.Client
logger *slog.Logger
}
// NewClient creates a new Client.
//
// The Docker connection is configured from the environment and negotiates the
// API version with the daemon, so that a daemon older than the vendored client
// library does not reject requests as "client version too new". The returned
// Client owns a context derived from ctx; call Close to cancel it and release
// the underlying connection.
func NewClient(ctx context.Context, logger *slog.Logger) (*Client, error) {
	apiClient, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
	if err != nil {
		return nil, fmt.Errorf("new docker client: %w", err)
	}

	ctx, cancel := context.WithCancel(ctx)

	// Return the literal directly rather than via a local named "client",
	// which would shadow the imported client package.
	return &Client{
		id:        uuid.New(),
		ctx:       ctx,
		cancel:    cancel,
		apiClient: apiClient,
		logger:    logger,
	}, nil
}
// stats is a struct to hold container stats.
type stats struct {
// cpuPercent is computed in getStats from the CPU and system usage deltas
// between the current and previous Docker samples.
cpuPercent float64
// memoryUsageBytes is copied from MemoryStats.Usage of the Docker stats
// response.
memoryUsageBytes uint64
}
// getStats returns a channel that will receive container stats. The channel is
// never closed, but the spawned goroutine will exit when the context is
// cancelled, when the stats stream ends, or when a sample cannot be decoded.
//
// Each sample is decoded from the streaming stats response as one complete
// JSON document, and is delivered with a send that is abandoned if the client
// context is cancelled, so the goroutine cannot be left blocked on a receiver
// that has gone away.
func (a *Client) getStats(containerID string) <-chan stats {
	ch := make(chan stats)

	go func() {
		statsReader, err := a.apiClient.ContainerStats(a.ctx, containerID, true)
		if err != nil {
			// TODO: error handling?
			a.logger.Error("Error getting container stats", "err", err, "id", shortID(containerID))
			return
		}
		defer statsReader.Body.Close()

		// A streaming decoder handles samples that are split across reads or
		// that are larger than any fixed-size buffer.
		dec := json.NewDecoder(statsReader.Body)

		for {
			var statsResp container.StatsResponse
			if err := dec.Decode(&statsResp); err != nil {
				// EOF and cancellation are the normal ways for the stream to end.
				if !errors.Is(err, io.EOF) && !errors.Is(err, context.Canceled) {
					a.logger.Error("Error decoding stats", "err", err, "id", shortID(containerID))
				}
				return
			}

			cur, prev := statsResp.CPUStats, statsResp.PreCPUStats

			// https://stackoverflow.com/a/30292327/62871
			//
			// Convert before subtracting so that a counter reset produces a
			// negative delta instead of wrapping around as an unsigned value.
			cpuDelta := float64(cur.CPUUsage.TotalUsage) - float64(prev.CPUUsage.TotalUsage)
			systemDelta := float64(cur.SystemUsage) - float64(prev.SystemUsage)

			// Fall back to the per-CPU usage list when the daemon does not
			// report the online CPU count.
			onlineCPUs := float64(cur.OnlineCPUs)
			if onlineCPUs == 0 {
				onlineCPUs = float64(len(cur.CPUUsage.PercpuUsage))
			}

			// Report 0% rather than NaN/Inf when there is no usable delta.
			var cpuPercent float64
			if systemDelta > 0 && cpuDelta > 0 {
				cpuPercent = (cpuDelta / systemDelta) * onlineCPUs * 100
			}

			select {
			case ch <- stats{
				cpuPercent:       cpuPercent,
				memoryUsageBytes: statsResp.MemoryStats.Usage,
			}:
			case <-a.ctx.Done():
				return
			}
		}
	}()

	return ch
}
// getEvents returns a channel that will receive container events. The channel is
// never closed, but the spawned goroutine will exit when the context is
// cancelled.
//
// If the event stream fails with an error other than io.EOF while the client
// context is still live, the subscription is re-established after a short
// delay. Both the forwarding send and the retry delay are abandoned when the
// client context is cancelled, so the goroutine cannot outlive the client.
func (a *Client) getEvents(containerID string) <-chan events.Message {
	sendC := make(chan events.Message)

	// stream subscribes once and forwards events until the subscription ends.
	// The boolean result reports whether the caller should resubscribe.
	stream := func() (bool, error) {
		recvC, errC := a.apiClient.Events(a.ctx, events.ListOptions{
			Filters: filters.NewArgs(
				filters.Arg("container", containerID),
				filters.Arg("type", "container"),
			),
		})

		for {
			select {
			case <-a.ctx.Done():
				return false, a.ctx.Err()
			case evt := <-recvC:
				// Do not block forever if the receiver has stopped reading.
				select {
				case sendC <- evt:
				case <-a.ctx.Done():
					return false, a.ctx.Err()
				}
			case err := <-errC:
				if a.ctx.Err() != nil || errors.Is(err, io.EOF) {
					return false, err
				}
				return true, err
			}
		}
	}

	go func() {
		const retryDelay = 2 * time.Second

		for {
			shouldRetry, err := stream()
			if !shouldRetry {
				return
			}

			a.logger.Warn("Error receiving Docker events", "err", err, "id", shortID(containerID))

			// Wait before resubscribing, but give up immediately on shutdown.
			select {
			case <-time.After(retryDelay):
			case <-a.ctx.Done():
				return
			}
		}
	}()

	return sendC
}
// RunContainerParams are the parameters for running a container.
type RunContainerParams struct {
// Name is optional; when non-empty, RunContainer uses it as the suffix of
// the generated container name.
Name string
// ContainerConfig is passed to the Docker create call. RunContainer pulls
// ContainerConfig.Image and writes the "app" and "app-id" labels into
// ContainerConfig.Labels.
ContainerConfig *container.Config
// HostConfig is passed to the Docker create call.
HostConfig *container.HostConfig
}
// RunContainer runs a container with the given parameters.
//
// The image named in params.ContainerConfig is pulled, then a container is
// created and started. The "app" and "app-id" labels are written into
// params.ContainerConfig.Labels (the map is allocated if it is nil) so that
// the container can later be found by this client. It returns the ID of the
// created container, or an error if params.ContainerConfig is nil or any of
// the pull, create or start steps fail.
//
// The returned channel will receive the current state of the container, and
// will be closed after the container has stopped.
func (a *Client) RunContainer(ctx context.Context, params RunContainerParams) (string, <-chan domain.ContainerState, error) {
	if params.ContainerConfig == nil {
		return "", nil, errors.New("container config is required")
	}

	pullReader, err := a.apiClient.ImagePull(ctx, params.ContainerConfig.Image, image.PullOptions{})
	if err != nil {
		return "", nil, fmt.Errorf("image pull: %w", err)
	}
	// The pull only progresses while its output is consumed. The progress
	// output itself is not needed, and a failed pull surfaces as an error from
	// ContainerCreate below, so errors here are deliberately ignored.
	_, _ = io.Copy(io.Discard, pullReader)
	_ = pullReader.Close()

	// Writing to a nil map panics, and callers commonly leave Labels unset.
	if params.ContainerConfig.Labels == nil {
		params.ContainerConfig.Labels = make(map[string]string, 2)
	}
	params.ContainerConfig.Labels["app"] = "termstream"
	params.ContainerConfig.Labels["app-id"] = a.id.String()

	// An empty name lets Docker generate one.
	var name string
	if params.Name != "" {
		name = "termstream-" + a.id.String() + "-" + params.Name
	}

	createResp, err := a.apiClient.ContainerCreate(
		ctx,
		params.ContainerConfig,
		params.HostConfig,
		nil,
		nil,
		name,
	)
	if err != nil {
		return "", nil, fmt.Errorf("container create: %w", err)
	}

	if err = a.apiClient.ContainerStart(ctx, createResp.ID, container.StartOptions{}); err != nil {
		return "", nil, fmt.Errorf("container start: %w", err)
	}
	a.logger.Info("Started container", "id", shortID(createResp.ID))

	containerStateC := make(chan domain.ContainerState, 1)
	a.wg.Add(1)
	go func() {
		defer a.wg.Done()
		defer close(containerStateC)

		a.runContainerLoop(ctx, createResp.ID, containerStateC)
	}()

	return createResp.ID, containerStateC, nil
}
// runContainerLoop is the control loop for a single container. It returns
// when the container enters a non-running state, or when waiting on the
// container fails (including cancellation of ctx).
//
// While running, it merges health-status events and stats samples into a
// single domain.ContainerState and publishes a copy on stateCh after every
// change. A publish is abandoned if ctx or the client context is cancelled,
// so a receiver that has stopped reading cannot wedge the loop during
// shutdown.
func (a *Client) runContainerLoop(ctx context.Context, containerID string, stateCh chan<- domain.ContainerState) {
	statsC := a.getStats(containerID)
	eventsC := a.getEvents(containerID)
	respC, errC := a.apiClient.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)

	state := &domain.ContainerState{ID: containerID}
	sendState := func() {
		select {
		case stateCh <- *state:
		case <-ctx.Done():
		case <-a.ctx.Done():
		}
	}
	sendState()

	for {
		select {
		case resp := <-respC:
			a.logger.Info("Container entered non-running state", "exit_code", resp.StatusCode, "id", shortID(containerID))
			return
		case err := <-errC:
			// TODO: error handling?
			// errors.Is also matches a cancellation that has been wrapped.
			if !errors.Is(err, context.Canceled) {
				a.logger.Error("Error setting container wait", "err", err, "id", shortID(containerID))
			}
			return
		case evt := <-eventsC:
			if strings.Contains(string(evt.Action), "health_status") {
				switch evt.Action {
				case events.ActionHealthStatusRunning:
					state.HealthState = "running"
				case events.ActionHealthStatusHealthy:
					state.HealthState = "healthy"
				case events.ActionHealthStatusUnhealthy:
					state.HealthState = "unhealthy"
				default:
					a.logger.Warn("Unknown health status", "status", evt.Action)
					state.HealthState = "unknown"
				}
				sendState()
			}
		case sample := <-statsC:
			state.CPUPercent = sample.cpuPercent
			state.MemoryUsageBytes = sample.memoryUsageBytes
			sendState()
		}
	}
}
// Close closes the client, stopping and removing all running containers.
//
// It cancels the client context, then lists every container carrying this
// client's labels and stops and removes each one, all under a single deadline
// of stopTimeout. Failures to remove an individual container are logged and
// do not abort the cleanup. After waiting for the per-container control loops
// to finish, the Docker API client is closed.
//
// If the containers cannot be listed, the API client is still closed and the
// listing error (joined with any close error) is returned.
func (a *Client) Close() error {
	a.cancel()

	ctx, cancel := context.WithTimeout(context.Background(), stopTimeout)
	defer cancel()

	containerList, err := a.containersMatchingLabels(ctx, nil)
	if err != nil {
		// Do not leak the API client's connections on this early exit.
		return errors.Join(fmt.Errorf("container list: %w", err), a.apiClient.Close())
	}

	for _, c := range containerList {
		if err := a.removeContainer(ctx, c.ID); err != nil {
			a.logger.Error("Error removing container:", "err", err, "id", shortID(c.ID))
		}
	}

	a.wg.Wait()

	return a.apiClient.Close()
}
// removeContainer stops the container with the given ID, allowing it
// stopTimeout (in whole seconds) to exit, and then force-removes it. The
// first failing step is returned wrapped with the name of that step.
func (a *Client) removeContainer(ctx context.Context, id string) error {
	short := shortID(id)
	timeoutSecs := int(stopTimeout.Seconds())

	a.logger.Info("Stopping container", "id", short)
	err := a.apiClient.ContainerStop(ctx, id, container.StopOptions{Timeout: &timeoutSecs})
	if err != nil {
		return fmt.Errorf("container stop: %w", err)
	}

	a.logger.Info("Removing container", "id", short)
	err = a.apiClient.ContainerRemove(ctx, id, container.RemoveOptions{Force: true})
	if err != nil {
		return fmt.Errorf("container remove: %w", err)
	}

	return nil
}
// ContainerRunning reports whether any container owned by this client and
// carrying the given labels is in the "running" or "restarting" state.
func (a *Client) ContainerRunning(ctx context.Context, labels map[string]string) (bool, error) {
	matches, err := a.containersMatchingLabels(ctx, labels)
	if err != nil {
		return false, fmt.Errorf("container list: %w", err)
	}

	for i := range matches {
		switch matches[i].State {
		case "running", "restarting":
			return true, nil
		}
	}

	return false, nil
}
// RemoveContainers stops and removes every container owned by this client
// that carries the given labels. Removal is best-effort: a failure for an
// individual container is logged and the remaining containers are still
// processed. Only a failure to list the containers is returned as an error.
func (a *Client) RemoveContainers(ctx context.Context, labels map[string]string) error {
	matches, err := a.containersMatchingLabels(ctx, labels)
	if err != nil {
		return fmt.Errorf("container list: %w", err)
	}

	for i := range matches {
		id := matches[i].ID

		removeErr := a.removeContainer(ctx, id)
		if removeErr == nil {
			continue
		}
		a.logger.Error("Error removing container:", "err", removeErr, "id", shortID(id))
	}

	return nil
}
// containersMatchingLabels lists all containers, running or not, that carry
// this client's "app" and "app-id" labels as well as every key=value pair in
// labels. A nil or empty labels map matches every container owned by this
// client.
func (a *Client) containersMatchingLabels(ctx context.Context, labels map[string]string) ([]types.Container, error) {
	pairs := make([]filters.KeyValuePair, 0, len(labels)+2)
	pairs = append(pairs,
		filters.Arg("label", "app=termstream"),
		filters.Arg("label", "app-id="+a.id.String()),
	)
	for key, value := range labels {
		pairs = append(pairs, filters.Arg("label", key+"="+value))
	}

	opts := container.ListOptions{
		All:     true,
		Filters: filters.NewArgs(pairs...),
	}

	return a.apiClient.ContainerList(ctx, opts)
}
// shortID returns the first 12 bytes of id, or id unchanged if it is shorter
// than that. It is used to abbreviate container IDs in log output.
func shortID(id string) string {
	const shortLen = 12

	if len(id) >= shortLen {
		return id[:shortLen]
	}

	return id
}