Refactor uploader, remove Writer implementation

This commit is contained in:
Rob Watson 2021-11-12 08:20:34 +01:00
parent 50e7b59442
commit 79be8b7936
2 changed files with 147 additions and 212 deletions

View File

@ -343,16 +343,7 @@ func (s *MediaSetService) getAudioFromYoutube(ctx context.Context, mediaSet stor
} }
s3Key := fmt.Sprintf("media_sets/%s/audio.raw", mediaSet.ID) s3Key := fmt.Sprintf("media_sets/%s/audio.raw", mediaSet.ID)
uploader, err := newMultipartUploadWriter( uploader := newMultipartUploader(s.s3)
ctx,
s.s3,
s3Bucket,
s3Key,
rawAudioMimeType,
)
if err != nil {
return nil, fmt.Errorf("error creating uploader: %v", err)
}
fetchAudioProgressReader := newGetAudioProgressReader( fetchAudioProgressReader := newGetAudioProgressReader(
binary.LittleEndian, binary.LittleEndian,
@ -378,45 +369,14 @@ type getAudioFromYoutubeState struct {
*fetchAudioProgressReader *fetchAudioProgressReader
ffmpegReader *ffmpegReader ffmpegReader *ffmpegReader
uploader *multipartUploadWriter uploader *multipartUploader
s3Bucket, s3Key string s3Bucket, s3Key string
store Store store Store
} }
func (s *getAudioFromYoutubeState) run(ctx context.Context, mediaSetID uuid.UUID) { func (s *getAudioFromYoutubeState) run(ctx context.Context, mediaSetID uuid.UUID) {
mw := io.MultiWriter(s, s.uploader) teeReader := io.TeeReader(s.ffmpegReader, s)
done := make(chan error) bytesUploaded, err := s.uploader.Upload(ctx, teeReader, s.s3Bucket, s.s3Key, rawAudioMimeType)
var err error
go func() {
_, copyErr := io.Copy(mw, s.ffmpegReader)
// At this point, there is no more data to send to the uploader.
// We can close it safely, it always returns nil.
_ = s.uploader.Close()
done <- copyErr
}()
outer:
for {
select {
case <-ctx.Done():
err = ctx.Err()
break outer
case err = <-done:
break outer
}
}
var framesUploaded int64
if err == nil {
if bytesUploaded, uploaderErr := s.uploader.Complete(); uploaderErr != nil {
err = uploaderErr
} else {
framesUploaded = bytesUploaded / int64(s.channels) / SizeOfInt16
}
}
// If there was an error returned, the underlying ffmpegReader process may // If there was an error returned, the underlying ffmpegReader process may
// still be active. Kill it. // still be active. Kill it.
@ -426,7 +386,8 @@ outer:
} }
} }
// Either way, we need to wait for the ffmpegReader process to exit. // Either way, we need to wait for the ffmpegReader process to exit,
// and ensure there is no error.
if readerErr := s.ffmpegReader.Close(); readerErr != nil { if readerErr := s.ffmpegReader.Close(); readerErr != nil {
if err == nil { if err == nil {
err = readerErr err = readerErr
@ -438,7 +399,7 @@ outer:
ID: mediaSetID, ID: mediaSetID,
AudioS3Bucket: sqlString(s.s3Bucket), AudioS3Bucket: sqlString(s.s3Bucket),
AudioS3Key: sqlString(s.s3Key), AudioS3Key: sqlString(s.s3Key),
AudioFrames: sqlInt64(framesUploaded), AudioFrames: sqlInt64(bytesUploaded),
}) })
if updateErr != nil { if updateErr != nil {
@ -449,12 +410,6 @@ outer:
if err != nil { if err != nil {
log.Printf("error uploading asynchronously: %v", err) log.Printf("error uploading asynchronously: %v", err)
newCtx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
if abortUploadErr := s.uploader.Abort(newCtx); abortUploadErr != nil {
log.Printf("error aborting uploader: %v", abortUploadErr)
}
s.Abort(err) s.Abort(err)
return return
} }

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"log" "log"
"sync" "sync"
@ -13,202 +14,181 @@ import (
"github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/aws/aws-sdk-go-v2/service/s3/types"
) )
// multipartUploadWriter is a Writer that uploads transparently to an S3 bucket type multipartUploader struct {
// in 5MB parts. It buffers data internally until a part is ready to send over
// the network. Parts are sent as soon as they exceed the minimum part size of
// 5MB.
//
// The caller must call either Complete() or Abort() after finishing writing.
// Failure to do so will leave S3 in an inconsistent state.
type multipartUploadWriter struct {
ctx context.Context
wg sync.WaitGroup
s3 S3Client s3 S3Client
buf *bytes.Buffer
bucket, key, contentType string
partNum int32
uploadResults chan uploadResult
uploadID string
} }
type uploadResult struct { type uploadResult struct {
completedPart types.CompletedPart completedPart types.CompletedPart
size int64 size int64
}
type readResult struct {
n int
err error err error
} }
const targetPartSizeBytes = 5 * 1024 * 1024 // 5MB const (
targetPartSizeBytes = 5 * 1024 * 1024 // 5MB
bufferOverflowSize = 16_384 // 16Kb
)
func newMultipartUploader(s3Client S3Client) *multipartUploader {
return &multipartUploader{s3: s3Client}
}
// Upload uploads to an S3 bucket in 5MB parts. It buffers data internally
// until a part is ready to send over the network. Parts are sent as soon as
// they exceed the minimum part size of 5MB.
func (u *multipartUploader) Upload(ctx context.Context, r io.Reader, bucket, key, contentType string) (int64, error) {
var uploaded bool
// newMultipartUploadWriter creates a new multipart upload writer, including
// creating the upload on S3. Either Complete or Abort must be called after
// calling this function.
func newMultipartUploadWriter(ctx context.Context, s3Client S3Client, bucket, key, contentType string) (*multipartUploadWriter, error) {
input := s3.CreateMultipartUploadInput{ input := s3.CreateMultipartUploadInput{
Bucket: aws.String(bucket), Bucket: aws.String(bucket),
Key: aws.String(key), Key: aws.String(key),
ContentType: aws.String(contentType), ContentType: aws.String(contentType),
} }
output, err := u.s3.CreateMultipartUpload(ctx, &input)
output, err := s3Client.CreateMultipartUpload(ctx, &input)
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating multipart upload: %v", err) return 0, fmt.Errorf("error creating multipart upload: %v", err)
} }
const bufferOverflowSize = 16_384 // abort the upload if possible, logging any errors, on exit.
b := make([]byte, 0, targetPartSizeBytes+bufferOverflowSize) defer func() {
if uploaded {
return &multipartUploadWriter{ return
ctx: ctx, }
s3: s3Client, input := s3.AbortMultipartUploadInput{
buf: bytes.NewBuffer(b), Bucket: aws.String(bucket),
bucket: bucket, Key: aws.String(key),
key: key, UploadId: output.UploadId,
contentType: contentType,
partNum: 1,
uploadResults: make(chan uploadResult),
uploadID: *output.UploadId,
}, nil
} }
func (u *multipartUploadWriter) Write(p []byte) (int, error) { _, abortErr := u.s3.AbortMultipartUpload(ctx, &input)
n, err := u.buf.Write(p) if abortErr != nil {
if err != nil { log.Printf("error aborting upload: %v", abortErr)
return n, fmt.Errorf("error writing to buffer: %v", err) } else {
log.Printf("aborted upload, key = %s", key)
} }
}()
if u.buf.Len() >= targetPartSizeBytes { uploadResultChan := make(chan uploadResult)
buf := make([]byte, u.buf.Len()) uploadErrorChan := make(chan error, 1)
copy(buf, u.buf.Bytes())
u.buf.Truncate(0)
u.wg.Add(1) // uploadPart uploads an individual part.
go u.uploadPart(buf, u.partNum) uploadPart := func(wg *sync.WaitGroup, buf []byte, partNum int32) {
defer wg.Done()
u.partNum++ partLen := int64(len(buf))
}
return n, err
}
func (u *multipartUploadWriter) uploadPart(buf []byte, partNum int32) {
defer u.wg.Done()
partLen := len(buf)
log.Printf("uploading part num = %d, len = %d", partNum, partLen) log.Printf("uploading part num = %d, len = %d", partNum, partLen)
input := s3.UploadPartInput{ input := s3.UploadPartInput{
Body: bytes.NewReader(buf), Body: bytes.NewReader(buf),
Bucket: aws.String(u.bucket), Bucket: aws.String(bucket),
Key: aws.String(u.key), Key: aws.String(key),
PartNumber: partNum, PartNumber: partNum,
UploadId: aws.String(u.uploadID), UploadId: output.UploadId,
ContentLength: int64(partLen), ContentLength: partLen,
} }
output, uploadErr := u.s3.UploadPart(u.ctx, &input) output, uploadErr := u.s3.UploadPart(ctx, &input)
if uploadErr != nil { if uploadErr != nil {
// TODO: retry on failure // TODO: retry on failure
u.uploadResults <- uploadResult{err: fmt.Errorf("error uploading part: %v", uploadErr)} uploadErrorChan <- uploadErr
return return
} }
log.Printf("uploaded part num = %d, etag = %s, bytes = %d", partNum, *output.ETag, partLen) log.Printf("uploaded part num = %d, etag = %s, bytes = %d", partNum, *output.ETag, partLen)
u.uploadResults <- uploadResult{ uploadResultChan <- uploadResult{
completedPart: types.CompletedPart{ETag: output.ETag, PartNumber: partNum}, completedPart: types.CompletedPart{ETag: output.ETag, PartNumber: partNum},
size: int64(partLen), size: partLen,
} }
} }
// Close signals that no further data will be written to the writer. wgDone := make(chan struct{})
// Always returns nil. var wg sync.WaitGroup
func (u *multipartUploadWriter) Close() error { wg.Add(1) // done when the reader returns EOF
// TODO: trigger Complete() here too?
close(u.uploadResults)
return nil
}
// Complete waits for all currently uploading parts to be uploaded, and
// finalizes the object in S3.
//
// Close() must have been been called first.
func (u *multipartUploadWriter) Complete() (int64, error) {
completedParts := make([]types.CompletedPart, 0, 64)
var uploadedBytes int64
// Write() launches multiple goroutines to upload the parts asynchronously.
// We need a waitgroup to ensure that all parts are complete, and the channel
// has been closed, before we continue.
uploadDone := make(chan struct{})
go func() { go func() {
u.wg.Wait() wg.Wait()
close(u.uploadResults) wgDone <- struct{}{}
uploadDone <- struct{}{}
}() }()
readChan := make(chan readResult)
buf := make([]byte, 32_768)
go func() {
for {
var rr readResult
rr.n, rr.err = r.Read(buf)
readChan <- rr
if rr.err != nil {
return
}
}
}()
var closing bool
currPart := bytes.NewBuffer(make([]byte, 0, targetPartSizeBytes+bufferOverflowSize))
partNum := int32(1)
results := make([]uploadResult, 0, 64)
outer: outer:
for { for {
select { select {
case uploadResult, ok := <-u.uploadResults: case uploadResult := <-uploadResultChan:
if !ok { results = append(results, uploadResult)
case uploadErr := <-uploadErrorChan:
return 0, fmt.Errorf("error while uploading part: %v", uploadErr)
case <-wgDone:
break outer break outer
} case <-ctx.Done():
// if len(completedParts) == 3 { return 0, ctx.Err()
// return 0, errors.New("nope") case readResult := <-readChan:
// } if readResult.err == io.EOF {
if uploadResult.err != nil { wg.Done()
return 0, uploadResult.err closing = true
} else if readResult.err != nil {
return 0, fmt.Errorf("reader error: %v", readResult.err)
} }
log.Println("APPENDING PART, len now", len(completedParts)) _, _ = currPart.Write(buf[:readResult.n])
completedParts = append(completedParts, uploadResult.completedPart) if closing || currPart.Len() >= targetPartSizeBytes {
uploadedBytes += uploadResult.size part := make([]byte, currPart.Len())
case <-uploadDone: copy(part, currPart.Bytes())
break outer currPart.Truncate(0)
case <-u.ctx.Done():
return 0, u.ctx.Err() wg.Add(1)
go uploadPart(&wg, part, partNum)
partNum++
}
} }
} }
if len(completedParts) == 0 { if len(results) == 0 {
return 0, errors.New("no parts available to upload") return 0, errors.New("no parts available to upload")
} }
log.Printf("parts - %+v, bucket - %s, key - %s, id - %s", completedParts, u.bucket, u.key, u.uploadID) completedParts := make([]types.CompletedPart, 0, 64)
log.Printf("len(parts) = %d, cap(parts) = %d", len(completedParts), cap(completedParts)) var uploadedBytes int64
for _, result := range results {
completedParts = append(completedParts, result.completedPart)
uploadedBytes += result.size
}
input := s3.CompleteMultipartUploadInput{ completeInput := s3.CompleteMultipartUploadInput{
Bucket: aws.String(u.bucket), Bucket: aws.String(bucket),
Key: aws.String(u.key), Key: aws.String(key),
UploadId: aws.String(u.uploadID), UploadId: output.UploadId,
MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts}, MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts},
} }
_, err := u.s3.CompleteMultipartUpload(u.ctx, &input) if _, err = u.s3.CompleteMultipartUpload(ctx, &completeInput); err != nil {
if err != nil {
return 0, fmt.Errorf("error completing upload: %v", err) return 0, fmt.Errorf("error completing upload: %v", err)
} }
log.Printf("completed upload, key = %s, bytesUploaded = %d", u.key, uploadedBytes) log.Printf("completed upload, key = %s, bytesUploaded = %d", key, uploadedBytes)
uploaded = true
return uploadedBytes, nil return uploadedBytes, nil
} }
// Abort aborts the upload process, cancelling the upload on S3. It accepts a
// separate context to the associated writer in case it is called during
// cleanup after the original context was killed.
func (u *multipartUploadWriter) Abort(ctx context.Context) error {
input := s3.AbortMultipartUploadInput{
Bucket: aws.String(u.bucket),
Key: aws.String(u.key),
UploadId: aws.String(u.uploadID),
}
_, err := u.s3.AbortMultipartUpload(ctx, &input)
if err != nil {
return fmt.Errorf("error aborting upload: %v", err)
}
log.Printf("aborted upload, key = %s", u.key)
return nil
}