265 lines
5.7 KiB
Go
265 lines
5.7 KiB
Go
package warp
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"math/rand"
|
|
"net"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/kixelated/invoker"
|
|
)
|
|
|
|
// Perform network simulation in-process to make a simpler demo.
|
|
// You should not use this in production; there are much better ways to throttle a network.
|
|
type Socket struct {
|
|
inner *net.UDPConn
|
|
|
|
writeRate int // bytes per second
|
|
writeErr error // return this error on all future writes
|
|
|
|
writeQueue []packet // packets ready to be sent
|
|
writeQueueSize int // number of bytes in the queue
|
|
writeQueueMax int // number of bytes allowed in the queue
|
|
|
|
writeLastTime time.Time
|
|
writeLastSize int
|
|
|
|
writeLoss float64 // packet loss percentage
|
|
|
|
writeNotify chan struct{} // closed when rate or queue is changed
|
|
writeMutex sync.Mutex
|
|
}
|
|
|
|
type packet struct {
|
|
Addr net.Addr
|
|
Data []byte
|
|
}
|
|
|
|
func NewSocket(addr string) (s *Socket, err error) {
|
|
s = new(Socket)
|
|
|
|
uaddr, err := net.ResolveUDPAddr("udp", addr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to resolve addr: %w", err)
|
|
}
|
|
|
|
s.inner, err = net.ListenUDP("udp", uaddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to listen: %w", err)
|
|
}
|
|
|
|
s.writeNotify = make(chan struct{})
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (s *Socket) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|
// TODO throttle reads?
|
|
return s.inner.ReadFrom(p)
|
|
}
|
|
|
|
// Queue up packets to be sent
|
|
func (s *Socket) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|
s.writeMutex.Lock()
|
|
defer s.writeMutex.Unlock()
|
|
|
|
if s.writeErr != nil {
|
|
return 0, s.writeErr
|
|
}
|
|
|
|
if s.writeQueueMax > 0 && s.writeQueueSize+len(p) > s.writeQueueMax {
|
|
// Gotta drop the packet
|
|
return len(p), nil
|
|
}
|
|
|
|
if len(s.writeQueue) == 0 && s.writeRate == 0 {
|
|
// If there's no queue and no throttling, write directly
|
|
if s.writeLoss == 0 || rand.Float64() >= s.writeLoss {
|
|
_, err = s.inner.WriteTo(p, addr)
|
|
if err != nil {
|
|
s.writeErr = err
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
return len(p), nil
|
|
}
|
|
|
|
// Make a copy of the packet
|
|
pc := packet{
|
|
Addr: addr,
|
|
Data: append([]byte{}, p...),
|
|
}
|
|
|
|
if len(s.writeQueue) == 0 {
|
|
// Wakeup the writer goroutine.
|
|
close(s.writeNotify)
|
|
s.writeNotify = make(chan struct{})
|
|
}
|
|
|
|
s.writeQueue = append(s.writeQueue, pc)
|
|
s.writeQueueSize += len(p)
|
|
|
|
return len(p), nil
|
|
}
|
|
|
|
// Perform the writing in another goroutine.
|
|
func (s *Socket) runWrite(ctx context.Context) (err error) {
|
|
timer := time.NewTimer(time.Second)
|
|
timer.Stop()
|
|
|
|
s.writeMutex.Lock()
|
|
defer s.writeMutex.Unlock()
|
|
|
|
for {
|
|
// Lock is held at the start of the loop
|
|
|
|
lastTime := s.writeLastTime
|
|
lastSize := s.writeLastSize
|
|
rate := s.writeRate
|
|
notify := s.writeNotify
|
|
ready := len(s.writeQueue) > 0
|
|
|
|
if !ready {
|
|
// Unlock while we wait for changes.
|
|
s.writeMutex.Unlock()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
s.writeMutex.Lock() // gotta lock again just for the defer...
|
|
return ctx.Err()
|
|
case <-notify:
|
|
// Something changed, try again
|
|
s.writeMutex.Lock()
|
|
continue
|
|
}
|
|
}
|
|
|
|
now := time.Now()
|
|
|
|
if lastSize > 0 && rate > 0 {
|
|
// Compute the amount of time it should take to send lastSize bytes
|
|
delay := time.Second * time.Duration(lastSize) / time.Duration(rate)
|
|
next := lastTime.Add(delay)
|
|
|
|
delay = next.Sub(now)
|
|
if delay > 0 {
|
|
// Unlock while we sleep.
|
|
s.writeMutex.Unlock()
|
|
|
|
// Reuse the timer instance
|
|
// No need to drain the timer beforehand
|
|
timer.Reset(delay)
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
s.writeMutex.Lock() // gotta lock again just for the defer...
|
|
return ctx.Err()
|
|
case <-timer.C:
|
|
now = next
|
|
s.writeMutex.Lock()
|
|
case <-notify:
|
|
// Something changed, try again
|
|
if !timer.Stop() {
|
|
// Drain the timer
|
|
<-timer.C
|
|
}
|
|
|
|
s.writeMutex.Lock()
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
// Send the first packet in the queue
|
|
p := s.writeQueue[0]
|
|
s.writeQueue = s.writeQueue[1:]
|
|
s.writeQueueSize -= len(p.Data)
|
|
s.writeLastTime = now
|
|
s.writeLastSize = len(p.Data)
|
|
|
|
loss := s.writeLoss
|
|
|
|
if loss > 0 || rand.Float64() >= loss {
|
|
_, err = s.inner.WriteTo(p.Data, p.Addr)
|
|
if err != nil {
|
|
s.writeErr = err
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Set the number of *bytes* that can be written within a second, or -1 for unlimited.
|
|
// Defaults to unlimited.
|
|
func (s *Socket) SetWriteRate(rate int) {
|
|
s.writeMutex.Lock()
|
|
defer s.writeMutex.Unlock()
|
|
|
|
s.writeRate = rate
|
|
close(s.writeNotify)
|
|
s.writeNotify = make(chan struct{})
|
|
}
|
|
|
|
// Set the maximum number of bytes to queue before we drop packets.
|
|
// Defaults to unlimited (!)
|
|
func (s *Socket) SetWriteBuffer(size int) {
|
|
s.writeMutex.Lock()
|
|
defer s.writeMutex.Unlock()
|
|
|
|
s.writeQueueMax = size
|
|
if s.writeQueueMax > 0 {
|
|
// Remove from the queue until the limit has been met
|
|
for s.writeQueueSize > s.writeQueueMax {
|
|
last := s.writeQueue[len(s.writeQueue)-1]
|
|
s.writeQueue = s.writeQueue[:len(s.writeQueue)-1]
|
|
s.writeQueueSize -= len(last.Data)
|
|
}
|
|
}
|
|
|
|
close(s.writeNotify)
|
|
s.writeNotify = make(chan struct{})
|
|
}
|
|
|
|
func (s *Socket) SetWriteLoss(percent float64) {
|
|
s.writeMutex.Lock()
|
|
defer s.writeMutex.Unlock()
|
|
|
|
s.writeLoss = percent
|
|
}
|
|
|
|
func (s *Socket) Close() (err error) {
|
|
return s.inner.Close()
|
|
}
|
|
|
|
func (s *Socket) LocalAddr() net.Addr {
|
|
return s.inner.LocalAddr()
|
|
}
|
|
|
|
func (s *Socket) SetDeadline(t time.Time) error {
|
|
return s.inner.SetDeadline(t)
|
|
}
|
|
|
|
func (s *Socket) SetReadDeadline(t time.Time) error {
|
|
return s.inner.SetReadDeadline(t)
|
|
}
|
|
|
|
func (s *Socket) SetReadBuffer(size int) error {
|
|
return s.inner.SetReadBuffer(size)
|
|
}
|
|
|
|
func (s *Socket) SetWriteDeadline(t time.Time) error {
|
|
return s.inner.SetWriteDeadline(t)
|
|
}
|
|
|
|
func (s *Socket) SyscallConn() (syscall.RawConn, error) {
|
|
return s.inner.SyscallConn()
|
|
}
|
|
|
|
func (s *Socket) Run(ctx context.Context) (err error) {
|
|
return invoker.Run(ctx /*s.runRead, */, s.runWrite)
|
|
}
|