Refactor uploader, remove Writer implementation
parent 50e7b59442
commit 79be8b7936
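This replaces the io.Writer-based multipartUploadWriter with a multipartUploader whose single Upload method drives the whole multipart lifecycle (create, upload parts, then complete, or abort on failure) by pulling from an io.Reader. In rough outline the call shape changes as below; this is a sketch assembled from the signatures in this diff, not code from either revision:

    // Before: the caller owned the lifecycle.
    w, err := newMultipartUploadWriter(ctx, s3Client, bucket, key, contentType)
    // ... io.Copy(w, src), then w.Close() and w.Complete(), or w.Abort(ctx) ...

    // After: the uploader owns it.
    u := newMultipartUploader(s3Client)
    bytesUploaded, err := u.Upload(ctx, src, bucket, key, contentType)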
@@ -343,16 +343,7 @@ func (s *MediaSetService) getAudioFromYoutube(ctx context.Context, mediaSet stor
     }

     s3Key := fmt.Sprintf("media_sets/%s/audio.raw", mediaSet.ID)
-    uploader, err := newMultipartUploadWriter(
-        ctx,
-        s.s3,
-        s3Bucket,
-        s3Key,
-        rawAudioMimeType,
-    )
-    if err != nil {
-        return nil, fmt.Errorf("error creating uploader: %v", err)
-    }
+    uploader := newMultipartUploader(s.s3)

     fetchAudioProgressReader := newGetAudioProgressReader(
         binary.LittleEndian,
@@ -378,45 +369,14 @@ type getAudioFromYoutubeState struct {
     *fetchAudioProgressReader

     ffmpegReader    *ffmpegReader
-    uploader        *multipartUploadWriter
+    uploader        *multipartUploader
     s3Bucket, s3Key string
     store           Store
 }

 func (s *getAudioFromYoutubeState) run(ctx context.Context, mediaSetID uuid.UUID) {
-    mw := io.MultiWriter(s, s.uploader)
-    done := make(chan error)
-    var err error
-
-    go func() {
-        _, copyErr := io.Copy(mw, s.ffmpegReader)
-
-        // At this point, there is no more data to send to the uploader.
-        // We can close it safely, it always returns nil.
-        _ = s.uploader.Close()
-
-        done <- copyErr
-    }()
-
-outer:
-    for {
-        select {
-        case <-ctx.Done():
-            err = ctx.Err()
-            break outer
-        case err = <-done:
-            break outer
-        }
-    }
-
-    var framesUploaded int64
-    if err == nil {
-        if bytesUploaded, uploaderErr := s.uploader.Complete(); uploaderErr != nil {
-            err = uploaderErr
-        } else {
-            framesUploaded = bytesUploaded / int64(s.channels) / SizeOfInt16
-        }
-    }
+    teeReader := io.TeeReader(s.ffmpegReader, s)
+    bytesUploaded, err := s.uploader.Upload(ctx, teeReader, s.s3Bucket, s.s3Key, rawAudioMimeType)

     // If there was an error returned, the underlying ffmpegReader process may
     // still be active. Kill it.
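The push-style io.MultiWriter plus background io.Copy is gone: Upload now pulls from a TeeReader, so every byte it reads from the ffmpeg output is mirrored into the progress reader (s) as a side effect. A minimal, self-contained sketch of that mechanism, with a hypothetical countingWriter standing in for the progress reader:

    package main

    import (
        "fmt"
        "io"
        "strings"
    )

    // countingWriter stands in for the progress reader: it just records
    // how many bytes have flowed past.
    type countingWriter struct{ n int64 }

    func (w *countingWriter) Write(p []byte) (int, error) {
        w.n += int64(len(p))
        return len(p), nil
    }

    func main() {
        src := strings.NewReader("raw pcm bytes")
        var progress countingWriter

        // Mirrors teeReader := io.TeeReader(s.ffmpegReader, s): everything
        // read from tee is also written to &progress.
        tee := io.TeeReader(src, &progress)

        n, _ := io.Copy(io.Discard, tee) // the uploader consumes here
        fmt.Println(n, progress.n)       // 13 13
    }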
@@ -426,7 +386,8 @@ outer:
         }
     }

-    // Either way, we need to wait for the ffmpegReader process to exit.
+    // Either way, we need to wait for the ffmpegReader process to exit,
+    // and ensure there is no error.
     if readerErr := s.ffmpegReader.Close(); readerErr != nil {
         if err == nil {
             err = readerErr
@@ -438,7 +399,7 @@ outer:
         ID:            mediaSetID,
         AudioS3Bucket: sqlString(s.s3Bucket),
         AudioS3Key:    sqlString(s.s3Key),
-        AudioFrames:   sqlInt64(framesUploaded),
+        AudioFrames:   sqlInt64(bytesUploaded),
     })

     if updateErr != nil {
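Note the unit change above: the removed Complete() path converted bytes to frames (frames = bytes / channels / bytes-per-sample, so for example 1,048,576 bytes of stereo int16 audio is 1,048,576 / 2 / 2 = 262,144 frames), whereas Upload returns a raw byte count that is now stored directly in AudioFrames.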
@@ -449,12 +410,6 @@ outer:
     if err != nil {
         log.Printf("error uploading asynchronously: %v", err)

-        newCtx, cancel := context.WithTimeout(context.Background(), time.Second*5)
-        defer cancel()
-
-        if abortUploadErr := s.uploader.Abort(newCtx); abortUploadErr != nil {
-            log.Printf("error aborting uploader: %v", abortUploadErr)
-        }
         s.Abort(err)
         return
     }
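The manual abort with a fresh five-second context disappears here because cleanup moved into the uploader itself: in the second file of this diff, below, Upload installs a deferred AbortMultipartUpload that runs whenever the upload did not complete.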
@@ -5,6 +5,7 @@ import (
     "context"
     "errors"
     "fmt"
+    "io"
     "log"
     "sync"

@@ -13,202 +14,181 @@ import (
     "github.com/aws/aws-sdk-go-v2/service/s3/types"
 )

-// multipartUploadWriter is a Writer that uploads transparently to an S3 bucket
-// in 5MB parts. It buffers data internally until a part is ready to send over
-// the network. Parts are sent as soon as they exceed the minimum part size of
-// 5MB.
-//
-// The caller must call either Complete() or Abort() after finishing writing.
-// Failure to do so will leave S3 in an inconsistent state.
-type multipartUploadWriter struct {
-    ctx context.Context
-    wg  sync.WaitGroup
-    s3  S3Client
-    buf *bytes.Buffer
-
-    bucket, key, contentType string
-    partNum                  int32
-    uploadResults            chan uploadResult
-    uploadID                 string
+type multipartUploader struct {
+    s3 S3Client
 }

 type uploadResult struct {
     completedPart types.CompletedPart
     size          int64
-    err           error
 }

-const targetPartSizeBytes = 5 * 1024 * 1024 // 5MB
-
-// newMultipartUploadWriter creates a new multipart upload writer, including
-// creating the upload on S3. Either Complete or Abort must be called after
-// calling this function.
-func newMultipartUploadWriter(ctx context.Context, s3Client S3Client, bucket, key, contentType string) (*multipartUploadWriter, error) {
+type readResult struct {
+    n   int
+    err error
+}
+
+const (
+    targetPartSizeBytes = 5 * 1024 * 1024 // 5MB
+    bufferOverflowSize  = 16_384          // 16Kb
+)
+
+func newMultipartUploader(s3Client S3Client) *multipartUploader {
+    return &multipartUploader{s3: s3Client}
+}
+
+// Upload uploads to an S3 bucket in 5MB parts. It buffers data internally
+// until a part is ready to send over the network. Parts are sent as soon as
+// they exceed the minimum part size of 5MB.
+func (u *multipartUploader) Upload(ctx context.Context, r io.Reader, bucket, key, contentType string) (int64, error) {
+    var uploaded bool
+
     input := s3.CreateMultipartUploadInput{
         Bucket:      aws.String(bucket),
         Key:         aws.String(key),
         ContentType: aws.String(contentType),
     }
-    output, err := s3Client.CreateMultipartUpload(ctx, &input)
+    output, err := u.s3.CreateMultipartUpload(ctx, &input)
     if err != nil {
-        return nil, fmt.Errorf("error creating multipart upload: %v", err)
+        return 0, fmt.Errorf("error creating multipart upload: %v", err)
     }

-    const bufferOverflowSize = 16_384
-    b := make([]byte, 0, targetPartSizeBytes+bufferOverflowSize)
-
-    return &multipartUploadWriter{
-        ctx:           ctx,
-        s3:            s3Client,
-        buf:           bytes.NewBuffer(b),
-        bucket:        bucket,
-        key:           key,
-        contentType:   contentType,
-        partNum:       1,
-        uploadResults: make(chan uploadResult),
-        uploadID:      *output.UploadId,
-    }, nil
-}
-
-func (u *multipartUploadWriter) Write(p []byte) (int, error) {
-    n, err := u.buf.Write(p)
-    if err != nil {
-        return n, fmt.Errorf("error writing to buffer: %v", err)
-    }
-
-    if u.buf.Len() >= targetPartSizeBytes {
-        buf := make([]byte, u.buf.Len())
-        copy(buf, u.buf.Bytes())
-        u.buf.Truncate(0)
-
-        u.wg.Add(1)
-        go u.uploadPart(buf, u.partNum)
-
-        u.partNum++
-    }
-
-    return n, err
-}
-
-func (u *multipartUploadWriter) uploadPart(buf []byte, partNum int32) {
-    defer u.wg.Done()
-
-    partLen := len(buf)
-    log.Printf("uploading part num = %d, len = %d", partNum, partLen)
-
-    input := s3.UploadPartInput{
-        Body:          bytes.NewReader(buf),
-        Bucket:        aws.String(u.bucket),
-        Key:           aws.String(u.key),
-        PartNumber:    partNum,
-        UploadId:      aws.String(u.uploadID),
-        ContentLength: int64(partLen),
-    }
-
-    output, uploadErr := u.s3.UploadPart(u.ctx, &input)
-    if uploadErr != nil {
-        // TODO: retry on failure
-        u.uploadResults <- uploadResult{err: fmt.Errorf("error uploading part: %v", uploadErr)}
-        return
-    }
-
-    log.Printf("uploaded part num = %d, etag = %s, bytes = %d", partNum, *output.ETag, partLen)
-
-    u.uploadResults <- uploadResult{
-        completedPart: types.CompletedPart{ETag: output.ETag, PartNumber: partNum},
-        size:          int64(partLen),
-    }
-}
-
-// Close signals that no further data will be written to the writer.
-// Always returns nil.
-func (u *multipartUploadWriter) Close() error {
-    // TODO: trigger Complete() here too?
-    close(u.uploadResults)
-    return nil
-}
-
-// Complete waits for all currently uploading parts to be uploaded, and
-// finalizes the object in S3.
-//
-// Close() must have been been called first.
-func (u *multipartUploadWriter) Complete() (int64, error) {
-    completedParts := make([]types.CompletedPart, 0, 64)
-    var uploadedBytes int64
-
-    // Write() launches multiple goroutines to upload the parts asynchronously.
-    // We need a waitgroup to ensure that all parts are complete, and the channel
-    // has been closed, before we continue.
-    uploadDone := make(chan struct{})
-    go func() {
-        u.wg.Wait()
-        close(u.uploadResults)
-        uploadDone <- struct{}{}
+    // abort the upload if possible, logging any errors, on exit.
+    defer func() {
+        if uploaded {
+            return
+        }
+        input := s3.AbortMultipartUploadInput{
+            Bucket:   aws.String(bucket),
+            Key:      aws.String(key),
+            UploadId: output.UploadId,
+        }
+
+        _, abortErr := u.s3.AbortMultipartUpload(ctx, &input)
+        if abortErr != nil {
+            log.Printf("error aborting upload: %v", abortErr)
+        } else {
+            log.Printf("aborted upload, key = %s", key)
+        }
     }()

+    uploadResultChan := make(chan uploadResult)
+    uploadErrorChan := make(chan error, 1)
+
+    // uploadPart uploads an individual part.
+    uploadPart := func(wg *sync.WaitGroup, buf []byte, partNum int32) {
+        defer wg.Done()
+
+        partLen := int64(len(buf))
+        log.Printf("uploading part num = %d, len = %d", partNum, partLen)
+
+        input := s3.UploadPartInput{
+            Body:          bytes.NewReader(buf),
+            Bucket:        aws.String(bucket),
+            Key:           aws.String(key),
+            PartNumber:    partNum,
+            UploadId:      output.UploadId,
+            ContentLength: partLen,
+        }
+
+        output, uploadErr := u.s3.UploadPart(ctx, &input)
+        if uploadErr != nil {
+            // TODO: retry on failure
+            uploadErrorChan <- uploadErr
+            return
+        }
+
+        log.Printf("uploaded part num = %d, etag = %s, bytes = %d", partNum, *output.ETag, partLen)
+
+        uploadResultChan <- uploadResult{
+            completedPart: types.CompletedPart{ETag: output.ETag, PartNumber: partNum},
+            size:          partLen,
+        }
+    }
+
+    wgDone := make(chan struct{})
+    var wg sync.WaitGroup
+    wg.Add(1) // done when the reader returns EOF
+    go func() {
+        wg.Wait()
+        wgDone <- struct{}{}
+    }()
+
+    readChan := make(chan readResult)
+    buf := make([]byte, 32_768)
+    go func() {
+        for {
+            var rr readResult
+            rr.n, rr.err = r.Read(buf)
+            readChan <- rr
+
+            if rr.err != nil {
+                return
+            }
+        }
+    }()
+
+    var closing bool
+    currPart := bytes.NewBuffer(make([]byte, 0, targetPartSizeBytes+bufferOverflowSize))
+    partNum := int32(1)
+    results := make([]uploadResult, 0, 64)
+
 outer:
     for {
         select {
-        case uploadResult, ok := <-u.uploadResults:
-            if !ok {
-                break outer
-            }
-            // if len(completedParts) == 3 {
-            // 	return 0, errors.New("nope")
-            // }
-            if uploadResult.err != nil {
-                return 0, uploadResult.err
+        case uploadResult := <-uploadResultChan:
+            results = append(results, uploadResult)
+        case uploadErr := <-uploadErrorChan:
+            return 0, fmt.Errorf("error while uploading part: %v", uploadErr)
+        case <-wgDone:
+            break outer
+        case <-ctx.Done():
+            return 0, ctx.Err()
+        case readResult := <-readChan:
+            if readResult.err == io.EOF {
+                wg.Done()
+                closing = true
+            } else if readResult.err != nil {
+                return 0, fmt.Errorf("reader error: %v", readResult.err)
             }

-            log.Println("APPENDING PART, len now", len(completedParts))
-            completedParts = append(completedParts, uploadResult.completedPart)
-            uploadedBytes += uploadResult.size
-        case <-uploadDone:
-            break outer
-        case <-u.ctx.Done():
-            return 0, u.ctx.Err()
+            _, _ = currPart.Write(buf[:readResult.n])
+            if closing || currPart.Len() >= targetPartSizeBytes {
+                part := make([]byte, currPart.Len())
+                copy(part, currPart.Bytes())
+                currPart.Truncate(0)
+
+                wg.Add(1)
+                go uploadPart(&wg, part, partNum)
+                partNum++
+            }
         }
     }

-    if len(completedParts) == 0 {
+    if len(results) == 0 {
         return 0, errors.New("no parts available to upload")
     }

-    log.Printf("parts - %+v, bucket - %s, key - %s, id - %s", completedParts, u.bucket, u.key, u.uploadID)
-    log.Printf("len(parts) = %d, cap(parts) = %d", len(completedParts), cap(completedParts))
+    completedParts := make([]types.CompletedPart, 0, 64)
+    var uploadedBytes int64
+    for _, result := range results {
+        completedParts = append(completedParts, result.completedPart)
+        uploadedBytes += result.size
+    }

-    input := s3.CompleteMultipartUploadInput{
-        Bucket:          aws.String(u.bucket),
-        Key:             aws.String(u.key),
-        UploadId:        aws.String(u.uploadID),
+    completeInput := s3.CompleteMultipartUploadInput{
+        Bucket:          aws.String(bucket),
+        Key:             aws.String(key),
+        UploadId:        output.UploadId,
         MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts},
     }

-    _, err := u.s3.CompleteMultipartUpload(u.ctx, &input)
-    if err != nil {
+    if _, err = u.s3.CompleteMultipartUpload(ctx, &completeInput); err != nil {
         return 0, fmt.Errorf("error completing upload: %v", err)
     }

-    log.Printf("completed upload, key = %s, bytesUploaded = %d", u.key, uploadedBytes)
+    log.Printf("completed upload, key = %s, bytesUploaded = %d", key, uploadedBytes)
+    uploaded = true

     return uploadedBytes, nil
 }
-
-// Abort aborts the upload process, cancelling the upload on S3. It accepts a
-// separate context to the associated writer in case it is called during
-// cleanup after the original context was killed.
-func (u *multipartUploadWriter) Abort(ctx context.Context) error {
-    input := s3.AbortMultipartUploadInput{
-        Bucket:   aws.String(u.bucket),
-        Key:      aws.String(u.key),
-        UploadId: aws.String(u.uploadID),
-    }
-
-    _, err := u.s3.AbortMultipartUpload(ctx, &input)
-    if err != nil {
-        return fmt.Errorf("error aborting upload: %v", err)
-    }
-
-    log.Printf("aborted upload, key = %s", u.key)
-
-    return nil
-}
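For context on the constants: S3 requires every part of a multipart upload except the last to be at least 5 MB, which is what targetPartSizeBytes enforces; data accumulates in currPart until it crosses that threshold (or the reader hits EOF), then ships as a numbered part on its own goroutine. A minimal usage sketch of the new type, assuming S3Client wraps the aws-sdk-go-v2 S3 client (it must cover the four calls used above: CreateMultipartUpload, UploadPart, CompleteMultipartUpload, AbortMultipartUpload) and with placeholder bucket, key, and content-type values:

    uploader := newMultipartUploader(s3Client)

    bytesUploaded, err := uploader.Upload(ctx, audioReader,
        "some-bucket", "media_sets/<id>/audio.raw", "application/octet-stream")
    if err != nil {
        // the deferred AbortMultipartUpload has already cleaned up on S3
        log.Printf("upload failed: %v", err)
        return
    }
    log.Printf("uploaded %d bytes", bytesUploaded)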