diff --git a/backend/media/service.go b/backend/media/service.go
index a133ef5..e1fe921 100644
--- a/backend/media/service.go
+++ b/backend/media/service.go
@@ -343,16 +343,7 @@ func (s *MediaSetService) getAudioFromYoutube(ctx context.Context, mediaSet stor
 	}
 
 	s3Key := fmt.Sprintf("media_sets/%s/audio.raw", mediaSet.ID)
-	uploader, err := newMultipartUploadWriter(
-		ctx,
-		s.s3,
-		s3Bucket,
-		s3Key,
-		rawAudioMimeType,
-	)
-	if err != nil {
-		return nil, fmt.Errorf("error creating uploader: %v", err)
-	}
+	uploader := newMultipartUploader(s.s3)
 
 	fetchAudioProgressReader := newGetAudioProgressReader(
 		binary.LittleEndian,
@@ -378,45 +369,14 @@ type getAudioFromYoutubeState struct {
 	*fetchAudioProgressReader
 
 	ffmpegReader    *ffmpegReader
-	uploader        *multipartUploadWriter
+	uploader        *multipartUploader
 	s3Bucket, s3Key string
 	store           Store
 }
 
 func (s *getAudioFromYoutubeState) run(ctx context.Context, mediaSetID uuid.UUID) {
-	mw := io.MultiWriter(s, s.uploader)
-	done := make(chan error)
-	var err error
-
-	go func() {
-		_, copyErr := io.Copy(mw, s.ffmpegReader)
-
-		// At this point, there is no more data to send to the uploader.
-		// We can close it safely, it always returns nil.
-		_ = s.uploader.Close()
-
-		done <- copyErr
-	}()
-
-outer:
-	for {
-		select {
-		case <-ctx.Done():
-			err = ctx.Err()
-			break outer
-		case err = <-done:
-			break outer
-		}
-	}
-
-	var framesUploaded int64
-	if err == nil {
-		if bytesUploaded, uploaderErr := s.uploader.Complete(); uploaderErr != nil {
-			err = uploaderErr
-		} else {
-			framesUploaded = bytesUploaded / int64(s.channels) / SizeOfInt16
-		}
-	}
+	teeReader := io.TeeReader(s.ffmpegReader, s)
+	bytesUploaded, err := s.uploader.Upload(ctx, teeReader, s.s3Bucket, s.s3Key, rawAudioMimeType)
 
 	// If there was an error returned, the underlying ffmpegReader process may
 	// still be active. Kill it.
@@ -426,7 +386,8 @@ outer:
 		}
 	}
 
-	// Either way, we need to wait for the ffmpegReader process to exit.
+	// Either way, we need to wait for the ffmpegReader process to exit,
+	// and ensure there is no error.
 	if readerErr := s.ffmpegReader.Close(); readerErr != nil {
 		if err == nil {
 			err = readerErr
@@ -438,7 +399,7 @@ outer:
 		ID:            mediaSetID,
 		AudioS3Bucket: sqlString(s.s3Bucket),
 		AudioS3Key:    sqlString(s.s3Key),
-		AudioFrames:   sqlInt64(framesUploaded),
+		AudioFrames:   sqlInt64(bytesUploaded / int64(s.channels) / SizeOfInt16),
 	})
 
 	if updateErr != nil {
@@ -449,12 +410,6 @@ outer:
 
 	if err != nil {
 		log.Printf("error uploading asynchronously: %v", err)
-		newCtx, cancel := context.WithTimeout(context.Background(), time.Second*5)
-		defer cancel()
-
-		if abortUploadErr := s.uploader.Abort(newCtx); abortUploadErr != nil {
-			log.Printf("error aborting uploader: %v", abortUploadErr)
-		}
 		s.Abort(err)
 		return
 	}
diff --git a/backend/media/uploader.go b/backend/media/uploader.go
index ebba02a..5ba2738 100644
--- a/backend/media/uploader.go
+++ b/backend/media/uploader.go
@@ -5,6 +5,8 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"io"
 	"log"
+	"sort"
 	"sync"
 
@@ -13,202 +15,188 @@ import (
 	"github.com/aws/aws-sdk-go-v2/service/s3/types"
 )
 
-// multipartUploadWriter is a Writer that uploads transparently to an S3 bucket
-// in 5MB parts. It buffers data internally until a part is ready to send over
-// the network. Parts are sent as soon as they exceed the minimum part size of
-// 5MB.
-//
-// The caller must call either Complete() or Abort() after finishing writing.
-// Failure to do so will leave S3 in an inconsistent state.
-type multipartUploadWriter struct {
-	ctx                      context.Context
-	wg                       sync.WaitGroup
-	s3                       S3Client
-	buf                      *bytes.Buffer
-	bucket, key, contentType string
-	partNum                  int32
-	uploadResults            chan uploadResult
-	uploadID                 string
+type multipartUploader struct {
+	s3 S3Client
 }
 
 type uploadResult struct {
 	completedPart types.CompletedPart
 	size          int64
-	err           error
 }
 
-const targetPartSizeBytes = 5 * 1024 * 1024 // 5MB
+type readResult struct {
+	buf []byte
+	err error
+}
+
+const (
+	targetPartSizeBytes = 5 * 1024 * 1024 // 5MB
+	bufferOverflowSize  = 16_384          // 16KB
+)
+
+func newMultipartUploader(s3Client S3Client) *multipartUploader {
+	return &multipartUploader{s3: s3Client}
+}
+
+// Upload reads from r and uploads its contents to an S3 bucket in 5MB parts.
+// It buffers data internally until a part is ready to send over the network.
+// Parts are sent as soon as they exceed the minimum part size of 5MB.
+func (u *multipartUploader) Upload(ctx context.Context, r io.Reader, bucket, key, contentType string) (int64, error) {
+	var uploaded bool
 
-// newMultipartUploadWriter creates a new multipart upload writer, including
-// creating the upload on S3. Either Complete or Abort must be called after
-// calling this function.
-func newMultipartUploadWriter(ctx context.Context, s3Client S3Client, bucket, key, contentType string) (*multipartUploadWriter, error) {
 	input := s3.CreateMultipartUploadInput{
 		Bucket:      aws.String(bucket),
 		Key:         aws.String(key),
 		ContentType: aws.String(contentType),
 	}
-
-	output, err := s3Client.CreateMultipartUpload(ctx, &input)
+	output, err := u.s3.CreateMultipartUpload(ctx, &input)
 	if err != nil {
-		return nil, fmt.Errorf("error creating multipart upload: %v", err)
+		return 0, fmt.Errorf("error creating multipart upload: %v", err)
 	}
 
-	const bufferOverflowSize = 16_384
-	b := make([]byte, 0, targetPartSizeBytes+bufferOverflowSize)
+	// On exit, abort the upload if it was not completed, logging any errors.
+	defer func() {
+		if uploaded {
+			return
+		}
+
+		input := s3.AbortMultipartUploadInput{
+			Bucket:   aws.String(bucket),
+			Key:      aws.String(key),
+			UploadId: output.UploadId,
+		}
 
-	return &multipartUploadWriter{
-		ctx:           ctx,
-		s3:            s3Client,
-		buf:           bytes.NewBuffer(b),
-		bucket:        bucket,
-		key:           key,
-		contentType:   contentType,
-		partNum:       1,
-		uploadResults: make(chan uploadResult),
-		uploadID:      *output.UploadId,
-	}, nil
-}
-
-func (u *multipartUploadWriter) Write(p []byte) (int, error) {
-	n, err := u.buf.Write(p)
-	if err != nil {
-		return n, fmt.Errorf("error writing to buffer: %v", err)
-	}
-
-	if u.buf.Len() >= targetPartSizeBytes {
-		buf := make([]byte, u.buf.Len())
-		copy(buf, u.buf.Bytes())
-		u.buf.Truncate(0)
-
-		u.wg.Add(1)
-		go u.uploadPart(buf, u.partNum)
-
-		u.partNum++
-	}
-
-	return n, err
-}
-
-func (u *multipartUploadWriter) uploadPart(buf []byte, partNum int32) {
-	defer u.wg.Done()
-
-	partLen := len(buf)
-	log.Printf("uploading part num = %d, len = %d", partNum, partLen)
-
-	input := s3.UploadPartInput{
-		Body:          bytes.NewReader(buf),
-		Bucket:        aws.String(u.bucket),
-		Key:           aws.String(u.key),
-		PartNumber:    partNum,
-		UploadId:      aws.String(u.uploadID),
-		ContentLength: int64(partLen),
-	}
-
-	output, uploadErr := u.s3.UploadPart(u.ctx, &input)
-	if uploadErr != nil {
-		// TODO: retry on failure
-		u.uploadResults <- uploadResult{err: fmt.Errorf("error uploading part: %v", uploadErr)}
-		return
-	}
-
-	log.Printf("uploaded part num = %d, etag = %s, bytes = %d", partNum, *output.ETag, partLen)
-
-	u.uploadResults <- uploadResult{
-		completedPart: types.CompletedPart{ETag: output.ETag, PartNumber: partNum},
-		size:          int64(partLen),
-	}
-}
-
-// Close signals that no further data will be written to the writer.
-// Always returns nil.
-func (u *multipartUploadWriter) Close() error {
-	// TODO: trigger Complete() here too?
-	close(u.uploadResults)
-	return nil
-}
-
-// Complete waits for all currently uploading parts to be uploaded, and
-// finalizes the object in S3.
-//
-// Close() must have been been called first.
-func (u *multipartUploadWriter) Complete() (int64, error) {
-	completedParts := make([]types.CompletedPart, 0, 64)
-	var uploadedBytes int64
-
-	// Write() launches multiple goroutines to upload the parts asynchronously.
-	// We need a waitgroup to ensure that all parts are complete, and the channel
-	// has been closed, before we continue.
-	uploadDone := make(chan struct{})
-	go func() {
-		u.wg.Wait()
-		close(u.uploadResults)
-		uploadDone <- struct{}{}
+		_, abortErr := u.s3.AbortMultipartUpload(ctx, &input)
+		if abortErr != nil {
+			log.Printf("error aborting upload: %v", abortErr)
+		} else {
+			log.Printf("aborted upload, key = %s", key)
+		}
 	}()
+
+	uploadResultChan := make(chan uploadResult)
+	uploadErrorChan := make(chan error, 1)
+
+	// uploadPart uploads an individual part.
+	uploadPart := func(wg *sync.WaitGroup, buf []byte, partNum int32) {
+		defer wg.Done()
+
+		partLen := int64(len(buf))
+		log.Printf("uploading part num = %d, len = %d", partNum, partLen)
+
+		input := s3.UploadPartInput{
+			Body:          bytes.NewReader(buf),
+			Bucket:        aws.String(bucket),
+			Key:           aws.String(key),
+			PartNumber:    partNum,
+			UploadId:      output.UploadId,
+			ContentLength: partLen,
+		}
+
+		output, uploadErr := u.s3.UploadPart(ctx, &input)
+		if uploadErr != nil {
+			// TODO: retry on failure
+			uploadErrorChan <- uploadErr
+			return
+		}
+
+		log.Printf("uploaded part num = %d, etag = %s, bytes = %d", partNum, *output.ETag, partLen)
+
+		uploadResultChan <- uploadResult{
+			completedPart: types.CompletedPart{ETag: output.ETag, PartNumber: partNum},
+			size:          partLen,
+		}
+	}
+
+	wgDone := make(chan struct{})
+	var wg sync.WaitGroup
+	wg.Add(1) // done when the reader returns EOF
+	go func() {
+		wg.Wait()
+		wgDone <- struct{}{}
+	}()
+
+	readChan := make(chan readResult)
+	go func() {
+		for {
+			// Allocate a fresh buffer per read: the previous buffer may
+			// still be in use by the part-building loop below.
+			buf := make([]byte, 32_768)
+			n, err := r.Read(buf)
+			readChan <- readResult{buf: buf[:n], err: err}
+
+			if err != nil {
+				return
+			}
+		}
+	}()
+
+	var closing bool
+	currPart := bytes.NewBuffer(make([]byte, 0, targetPartSizeBytes+bufferOverflowSize))
+	partNum := int32(1)
+	results := make([]uploadResult, 0, 64)
+
 outer:
 	for {
 		select {
-		case uploadResult, ok := <-u.uploadResults:
-			if !ok {
-				break outer
-			}
-			// if len(completedParts) == 3 {
-			// 	return 0, errors.New("nope")
-			// }
-			if uploadResult.err != nil {
-				return 0, uploadResult.err
+		case uploadResult := <-uploadResultChan:
+			results = append(results, uploadResult)
+		case uploadErr := <-uploadErrorChan:
+			return 0, fmt.Errorf("error while uploading part: %v", uploadErr)
+		case <-wgDone:
+			break outer
+		case <-ctx.Done():
+			return 0, ctx.Err()
+		case readResult := <-readChan:
+			if readResult.err == io.EOF {
+				wg.Done()
+				closing = true
+			} else if readResult.err != nil {
+				return 0, fmt.Errorf("reader error: %v", readResult.err)
 			}
-			log.Println("APPENDING PART, len now", len(completedParts))
-			completedParts = append(completedParts, uploadResult.completedPart)
-			uploadedBytes += uploadResult.size
-		case <-uploadDone:
-			break outer
-		case <-u.ctx.Done():
-			return 0, u.ctx.Err()
+
+			_, _ = currPart.Write(readResult.buf)
+			if closing || currPart.Len() >= targetPartSizeBytes {
+				part := make([]byte, currPart.Len())
+				copy(part, currPart.Bytes())
+				currPart.Truncate(0)
+
+				wg.Add(1)
+				go uploadPart(&wg, part, partNum)
+				partNum++
+			}
 		}
 	}
 
-	if len(completedParts) == 0 {
+	if len(results) == 0 {
 		return 0, errors.New("no parts available to upload")
 	}
 
-	log.Printf("parts - %+v, bucket - %s, key - %s, id - %s", completedParts, u.bucket, u.key, u.uploadID)
-	log.Printf("len(parts) = %d, cap(parts) = %d", len(completedParts), cap(completedParts))
+	// CompleteMultipartUpload requires parts in ascending order by part
+	// number, but they arrive in completion order. Sort them first.
+	sort.Slice(results, func(i, j int) bool {
+		return results[i].completedPart.PartNumber < results[j].completedPart.PartNumber
+	})
+
+	completedParts := make([]types.CompletedPart, 0, 64)
+	var uploadedBytes int64
+	for _, result := range results {
+		completedParts = append(completedParts, result.completedPart)
+		uploadedBytes += result.size
+	}
 
-	input := s3.CompleteMultipartUploadInput{
-		Bucket:          aws.String(u.bucket),
-		Key:             aws.String(u.key),
-		UploadId:        aws.String(u.uploadID),
+	completeInput := s3.CompleteMultipartUploadInput{
+		Bucket:          aws.String(bucket),
+		Key:             aws.String(key),
+		UploadId:        output.UploadId,
 		MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts},
 	}
 
-	_, err := u.s3.CompleteMultipartUpload(u.ctx, &input)
-	if err != nil {
+	if _, err = u.s3.CompleteMultipartUpload(ctx, &completeInput); err != nil {
 		return 0, fmt.Errorf("error completing upload: %v", err)
 	}
 
-	log.Printf("completed upload, key = %s, bytesUploaded = %d", u.key, uploadedBytes)
+	log.Printf("completed upload, key = %s, bytesUploaded = %d", key, uploadedBytes)
+	uploaded = true
+
 	return uploadedBytes, nil
 }
-
-// Abort aborts the upload process, cancelling the upload on S3. It accepts a
-// separate context to the associated writer in case it is called during
-// cleanup after the original context was killed.
-func (u *multipartUploadWriter) Abort(ctx context.Context) error {
-	input := s3.AbortMultipartUploadInput{
-		Bucket:   aws.String(u.bucket),
-		Key:      aws.String(u.key),
-		UploadId: aws.String(u.uploadID),
-	}
-
-	_, err := u.s3.AbortMultipartUpload(ctx, &input)
-	if err != nil {
-		return fmt.Errorf("error aborting upload: %v", err)
-	}
-
-	log.Printf("aborted upload, key = %s", u.key)
-
-	return nil
-}