152 lines
3.5 KiB
Go
152 lines
3.5 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"sync"
|
|
"time"
|
|
|
|
"git.netflux.io/rob/octoplex/internal/event"
|
|
pb "git.netflux.io/rob/octoplex/internal/generated/grpc"
|
|
"git.netflux.io/rob/octoplex/internal/protocol"
|
|
"golang.org/x/sync/errgroup"
|
|
)
|
|
|
|
// APIClientCountDeltaFunc is a function that is called when the number of
// API clients increments or decrements.
//
// NOTE(review): nothing in the visible part of this file references this
// type — confirm it is still used elsewhere in the package.
type APIClientCountDeltaFunc func(delta int)
|
|
|
|
// Server is the gRPC server that handles incoming commands and outgoing
// events.
type Server struct {
	pb.UnimplementedInternalAPIServer

	// dispatcher is called with each command received from a client stream.
	dispatcher func(event.Command)
	// bus is the event source: every client stream registers a consumer on it
	// and forwards the events it receives to the client.
	bus *event.Bus
	// logger is the logger passed to newServer, tagged with component=server.
	logger *slog.Logger

	// mu guards clientCount.
	mu sync.Mutex
	// clientCount is the number of currently connected clients, as reported by
	// GetClientCount.
	clientCount int
	// clientC is signalled (with a non-blocking send) when a client completes
	// the handshake; WaitForClient receives from it. newServer creates it with
	// a buffer of one.
	clientC chan struct{}
}
|
|
|
|
// newServer creates a new gRPC server.
|
|
func newServer(
|
|
dispatcher func(event.Command),
|
|
bus *event.Bus,
|
|
logger *slog.Logger,
|
|
) *Server {
|
|
return &Server{
|
|
dispatcher: dispatcher,
|
|
bus: bus,
|
|
clientC: make(chan struct{}, 1),
|
|
logger: logger.With("component", "server"),
|
|
}
|
|
}
|
|
|
|
func (s *Server) Communicate(stream pb.InternalAPI_CommunicateServer) error {
|
|
g, ctx := errgroup.WithContext(stream.Context())
|
|
|
|
// perform handshake:
|
|
if err := stream.Send(&pb.Envelope{Payload: &pb.Envelope_Event{Event: &pb.Event{EventType: &pb.Event_InternalApiReady{}}}}); err != nil {
|
|
return fmt.Errorf("send ready event: %w", err)
|
|
}
|
|
startStreamCmd, err := stream.Recv()
|
|
if err != nil {
|
|
return fmt.Errorf("receive start stream command: %w", err)
|
|
}
|
|
if startStreamCmd.GetCommand() == nil || startStreamCmd.GetCommand().GetStartInternalStream() == nil {
|
|
return fmt.Errorf("expected start stream command but got: %T", startStreamCmd)
|
|
}
|
|
|
|
// Notify that a client has connected and completed the handshake.
|
|
select {
|
|
case s.clientC <- struct{}{}:
|
|
default:
|
|
}
|
|
|
|
g.Go(func() error {
|
|
eventsC := s.bus.Register()
|
|
defer s.bus.Deregister(eventsC)
|
|
|
|
for {
|
|
select {
|
|
case evt := <-eventsC:
|
|
if err := stream.Send(&pb.Envelope{Payload: &pb.Envelope_Event{Event: protocol.EventToProto(evt)}}); err != nil {
|
|
return fmt.Errorf("send event: %w", err)
|
|
}
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
})
|
|
|
|
g.Go(func() error {
|
|
s.mu.Lock()
|
|
s.clientCount++
|
|
s.mu.Unlock()
|
|
|
|
defer func() {
|
|
s.mu.Lock()
|
|
s.clientCount--
|
|
s.mu.Unlock()
|
|
}()
|
|
|
|
for {
|
|
in, err := stream.Recv()
|
|
if err == io.EOF {
|
|
s.logger.Info("Client disconnected")
|
|
return err
|
|
}
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("receive message: %w", err)
|
|
}
|
|
|
|
switch pbCmd := in.Payload.(type) {
|
|
case *pb.Envelope_Command:
|
|
cmd := protocol.CommandFromProto(pbCmd.Command)
|
|
s.logger.Info("Received command", "command", cmd.Name())
|
|
s.dispatcher(cmd)
|
|
default:
|
|
return fmt.Errorf("expected command but got: %T", pbCmd)
|
|
}
|
|
}
|
|
})
|
|
|
|
if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, context.Canceled) {
|
|
s.logger.Error("Client stream closed with error", "err", err)
|
|
return fmt.Errorf("errgroup.Wait: %w", err)
|
|
}
|
|
|
|
s.logger.Info("Client stream closed")
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetClientCount returns the number of connected clients.
|
|
func (s *Server) GetClientCount() int {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
return s.clientCount
|
|
}
|
|
|
|
const waitForClientTimeout = 10 * time.Second
|
|
|
|
// WaitForClient waits for _any_ client to connect and complete the handshake.
|
|
// It times out if no client has connected after 10 seconds.
|
|
func (s *Server) WaitForClient(ctx context.Context) error {
|
|
select {
|
|
case <-s.clientC:
|
|
return nil
|
|
case <-time.After(waitForClientTimeout):
|
|
return errors.New("timeout")
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|