diff --git a/internal/app/app.go b/internal/app/app.go index e5d4649..4545147 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -182,6 +182,8 @@ func (a *App) Run(ctx context.Context) error { for { select { + case <-ctx.Done(): + return ctx.Err() case <-startMediaServerC: if err = srv.Start(ctx); err != nil { return fmt.Errorf("start mediaserver: %w", err) diff --git a/internal/app/integration_test.go b/internal/app/integration_test.go index a54a760..f39f7ca 100644 --- a/internal/app/integration_test.go +++ b/internal/app/integration_test.go @@ -132,7 +132,7 @@ func testIntegration(t *testing.T, mediaServerConfig config.MediaServerSource) { done <- struct{}{} }() - require.NoError(t, app.New(app.Params{ + require.Equal(t, context.Canceled, app.New(app.Params{ ConfigService: configService, DockerClient: dockerClient, Screen: &terminal.Screen{ @@ -319,7 +319,7 @@ func TestIntegrationCustomHost(t *testing.T) { done <- struct{}{} }() - require.NoError(t, app.New(buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger)).Run(ctx)) + require.Equal(t, context.Canceled, app.New(buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger)).Run(ctx)) }() time.Sleep(time.Second) @@ -390,7 +390,7 @@ func TestIntegrationCustomTLSCerts(t *testing.T) { done <- struct{}{} }() - require.NoError(t, app.New(buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger)).Run(ctx)) + require.Equal(t, context.Canceled, app.New(buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger)).Run(ctx)) }() require.EventuallyWithT( @@ -471,7 +471,7 @@ func TestIntegrationRestartDestination(t *testing.T) { done <- struct{}{} }() - require.NoError(t, app.New(buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger)).Run(ctx)) + require.Equal(t, context.Canceled, app.New(buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger)).Run(ctx)) }() require.EventuallyWithT( @@ -608,7 +608,7 @@ func TestIntegrationStartDestinationFailed(t *testing.T) { done <- struct{}{} }() - require.NoError(t, app.New(buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger)).Run(ctx)) + require.Equal(t, context.Canceled, app.New(buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger)).Run(ctx)) }() require.EventuallyWithT( @@ -681,7 +681,7 @@ func TestIntegrationDestinationValidations(t *testing.T) { done <- struct{}{} }() - require.NoError(t, app.New(buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger)).Run(ctx)) + require.Equal(t, context.Canceled, app.New(buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger)).Run(ctx)) }() require.EventuallyWithT( @@ -823,7 +823,7 @@ func TestIntegrationStartupCheck(t *testing.T) { done <- struct{}{} }() - require.NoError(t, app.New(buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger)).Run(ctx)) + require.Equal(t, context.Canceled, app.New(buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger)).Run(ctx)) }() require.EventuallyWithT( @@ -1069,7 +1069,7 @@ func TestIntegrationCopyURLs(t *testing.T) { done <- struct{}{} }() - require.NoError(t, app.New(buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger)).Run(ctx)) + require.Equal(t, context.Canceled, app.New(buildAppParams(t, configService, dockerClient, screen, screenCaptureC, logger)).Run(ctx)) }() time.Sleep(3 * time.Second) diff --git a/main.go b/main.go index e2f9a4e..f611a04 100644 --- a/main.go +++ b/main.go @@ -3,11 +3,14 @@ package main import ( "cmp" "context" + "errors" "flag" "fmt" + "io" "log/slog" "os" "os/exec" + "os/signal" "runtime/debug" "syscall" @@ -27,16 +30,25 @@ var ( date string ) -func main() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() +var errShutdown = errors.New("shutdown") - if err := run(ctx); err != nil { +func main() { + var exitStatus int + + if err := run(); errors.Is(err, errShutdown) { + exitStatus = 130 + } else if err != nil { + exitStatus = 1 _, _ = os.Stderr.WriteString("Error: " + err.Error() + "\n") } + + os.Exit(exitStatus) } -func run(ctx context.Context) error { +func run() error { + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(nil) + configService, err := config.NewDefaultService() if err != nil { return fmt.Errorf("build config service: %w", err) @@ -72,11 +84,24 @@ func run(ctx context.Context) error { if err != nil { return fmt.Errorf("read or create config: %w", err) } - logger, err := buildLogger(cfg.LogFile) + + headless := os.Getenv("OCTO_HEADLESS") != "" + logger, err := buildLogger(cfg.LogFile, headless) if err != nil { return fmt.Errorf("build logger: %w", err) } + if headless { + // When running in headless mode tview doesn't handle SIGINT for us. + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-ch + logger.Info("Received interrupt signal, exiting") + cancel(errShutdown) + }() + } + var clipboardAvailable bool if err = clipboard.Init(); err != nil { logger.Warn("Clipboard not available", "err", err) @@ -100,7 +125,7 @@ func run(ctx context.Context) error { app := app.New(app.Params{ ConfigService: configService, DockerClient: dockerClient, - Headless: os.Getenv("OCTO_HEADLESS") != "", + Headless: headless, ClipboardAvailable: clipboardAvailable, ConfigFilePath: configService.Path(), BuildInfo: domain.BuildInfo{ @@ -166,7 +191,20 @@ func printUsage() { } // buildLogger builds the logger, which may be a no-op logger. -func buildLogger(cfg config.LogFile) (*slog.Logger, error) { +func buildLogger(cfg config.LogFile, headless bool) (*slog.Logger, error) { + build := func(w io.Writer) *slog.Logger { + var handlerOpts slog.HandlerOptions + if os.Getenv("OCTO_DEBUG") != "" { + handlerOpts.Level = slog.LevelDebug + } + return slog.New(slog.NewTextHandler(w, &handlerOpts)) + } + + // In headless mode, always log to stderr. + if headless { + return build(os.Stderr), nil + } + if !cfg.Enabled { return slog.New(slog.DiscardHandler), nil } @@ -176,9 +214,5 @@ func buildLogger(cfg config.LogFile) (*slog.Logger, error) { return nil, fmt.Errorf("error opening log file: %w", err) } - var handlerOpts slog.HandlerOptions - if os.Getenv("OCTO_DEBUG") != "" { - handlerOpts.Level = slog.LevelDebug - } - return slog.New(slog.NewTextHandler(fptr, &handlerOpts)), nil + return build(fptr), nil }