diff --git a/command/aa55.go b/command/aa55.go index ea6b7a3..596b775 100644 --- a/command/aa55.go +++ b/command/aa55.go @@ -41,7 +41,7 @@ func aa55Checksum(payload []byte) []byte { func (cmd AA55Command) String() string { return string(cmd.payload) } -func (cmd AA55Command) validateResponse(p []byte) ([]byte, error) { +func (cmd AA55Command) ValidateResponse(p []byte) ([]byte, error) { if len(p) < 8 { return nil, fmt.Errorf("response truncated") } diff --git a/command/command.go b/command/command.go index b2c6adc..c634038 100644 --- a/command/command.go +++ b/command/command.go @@ -4,29 +4,65 @@ import ( "bufio" "fmt" "io" + "log" + "time" ) -type command interface { +type Command interface { String() string - validateResponse([]byte) ([]byte, error) + ValidateResponse([]byte) ([]byte, error) } +type Conn interface { + io.ReadWriter + SetDeadline(time.Time) error +} + +const ( + maxAttempts = 3 + timeout = time.Second * 3 + readBufferSizeBytes = 4_096 +) + // Send writes the command to the provided Writer, and reads and validates the // response. -// -// TODO: accept a context.Context and enforce deadline/timeout. -func Send(cmd command, conn io.ReadWriter) ([]byte, error) { +func Send(cmd Command, conn Conn) ([]byte, error) { + var ( + resp []byte + err error + attempts int + ) + + for { + if resp, err = tryRequest(cmd, conn); err != nil { + attempts++ + log.Printf("error executing command (attempt %d): %s", attempts, err) + if attempts <= 3 { + continue + } + return nil, fmt.Errorf("error executing command: %s", err) + } + + return resp, nil + } +} + +func tryRequest(cmd Command, conn Conn) ([]byte, error) { + if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + return nil, fmt.Errorf("error setting deadline: %s", err) + } + + p := make([]byte, readBufferSizeBytes) _, err := fmt.Fprint(conn, cmd.String()) if err != nil { return nil, fmt.Errorf("error writing to socket: %s", err) } - p := make([]byte, 4_096) r := bufio.NewReader(conn) n, err := r.Read(p) if err != nil { return nil, fmt.Errorf("error reading from socket: %s", err) } - return cmd.validateResponse(p[:n]) + return cmd.ValidateResponse(p[:n]) } diff --git a/command/command_test.go b/command/command_test.go new file mode 100644 index 0000000..3a624c4 --- /dev/null +++ b/command/command_test.go @@ -0,0 +1,63 @@ +package command_test + +import ( + "errors" + "testing" + "time" + + "git.netflux.io/rob/solar-toolkit/command" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type readResult struct { + p []byte + err error +} + +type mockConn struct { + readResults []readResult +} + +func (c *mockConn) Read(p []byte) (int, error) { + var result readResult + result, c.readResults = c.readResults[0], c.readResults[1:] + return copy(p, result.p), result.err +} + +func (c *mockConn) Write(p []byte) (int, error) { return 0, nil } +func (c *mockConn) SetDeadline(time.Time) error { return nil } + +type mockCommand struct{} + +func (cmd *mockCommand) String() string { return "baz" } +func (cmd *mockCommand) ValidateResponse(p []byte) ([]byte, error) { return p, nil } + +func TestSendWithOneRetry(t *testing.T) { + var cmd mockCommand + conn := mockConn{ + readResults: []readResult{ + {err: errors.New("i/o timeout")}, + {p: []byte("bar"), err: nil}, + }, + } + + resp, err := command.Send(&cmd, &conn) + require.NoError(t, err) + assert.Equal(t, []byte("bar"), resp) +} + +func TestSendFail(t *testing.T) { + var cmd mockCommand + conn := mockConn{ + readResults: []readResult{ + {err: errors.New("i/o timeout 1")}, + {err: errors.New("i/o timeout 2")}, + {err: errors.New("i/o timeout 3")}, + {err: errors.New("i/o timeout 4")}, + }, + } + + _, err := command.Send(&cmd, &conn) + assert.EqualError(t, err, "error executing command: error reading from socket: i/o timeout 4") +} diff --git a/command/modbus.go b/command/modbus.go index 7ca8c54..d6a4515 100644 --- a/command/modbus.go +++ b/command/modbus.go @@ -58,6 +58,6 @@ func modbusChecksum(b []byte) uint16 { func (cmd ModbusCommand) String() string { return string(cmd.payload) } -func (cmd ModbusCommand) validateResponse(p []byte) ([]byte, error) { +func (cmd ModbusCommand) ValidateResponse(p []byte) ([]byte, error) { return p[5 : len(p)-2], nil } diff --git a/inverter/et.go b/inverter/et.go index 97c0109..d363c63 100644 --- a/inverter/et.go +++ b/inverter/et.go @@ -5,7 +5,6 @@ import ( "context" "encoding/binary" "fmt" - "io" "math" "strings" "time" @@ -266,7 +265,7 @@ func (inv ET) DecodeRuntimeData(p []byte) (*ETRuntimeData, error) { } // DEPRECATED -func (inv ET) DeviceInfo(ctx context.Context, conn io.ReadWriter) (*DeviceInfo, error) { +func (inv ET) DeviceInfo(ctx context.Context, conn command.Conn) (*DeviceInfo, error) { resp, err := command.Send(command.NewModbus(command.ModbusCommandTypeRead, 0x88b8, 0x0021), conn) if err != nil { return nil, fmt.Errorf("error sending command: %s", err) @@ -281,7 +280,7 @@ func (inv ET) DeviceInfo(ctx context.Context, conn io.ReadWriter) (*DeviceInfo, } // DEPRECATED -func (inv ET) RuntimeData(ctx context.Context, conn io.ReadWriter) (*ETRuntimeData, error) { +func (inv ET) RuntimeData(ctx context.Context, conn command.Conn) (*ETRuntimeData, error) { deviceInfo, err := inv.DeviceInfo(ctx, conn) if err != nil { return nil, fmt.Errorf("error fetching device info: %s", err)