package media

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"log"
	"sync"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/s3"
	"github.com/aws/aws-sdk-go-v2/service/s3/types"
)

// multipartUploadWriter is a Writer that uploads transparently to an S3 bucket
// in 5MB parts. It buffers data internally until a part is ready to send over
// the network. Parts are sent as soon as they exceed the minimum part size of
// 5MB.
//
// The caller must call either Complete() or Abort() after finishing writing.
// Failure to do so will leave S3 in an inconsistent state.
type multipartUploadWriter struct {
	ctx context.Context
	// wg tracks the uploadPart goroutines launched by Write, so Complete can
	// wait for every in-flight part before finalizing the upload.
	wg sync.WaitGroup
	s3 S3Client
	// buf accumulates written bytes until at least targetPartSizeBytes are
	// available to send as a single part.
	buf                      *bytes.Buffer
	bucket, key, contentType string
	// partNum is the next S3 part number to assign; S3 part numbers start at 1.
	partNum int32
	// uploadResults carries the outcome of each uploadPart goroutine back to
	// Complete. It is unbuffered, so senders block until Complete receives.
	uploadResults chan uploadResult
	// uploadID identifies the multipart upload created on S3.
	uploadID string
}

// uploadResult describes the outcome of uploading a single part: either the
// completed part metadata plus its size in bytes, or a non-nil err.
type uploadResult struct {
	completedPart types.CompletedPart
	size          int64
	err           error
}

// targetPartSizeBytes is the minimum part size S3 accepts for all but the
// final part of a multipart upload.
const targetPartSizeBytes = 5 * 1024 * 1024 // 5MB

// newMultipartUploadWriter creates a new multipart upload writer, including
// creating the upload on S3. Either Complete or Abort must be called after
// calling this function.
func newMultipartUploadWriter(ctx context.Context, s3Client S3Client, bucket, key, contentType string) (*multipartUploadWriter, error) { input := s3.CreateMultipartUploadInput{ Bucket: aws.String(bucket), Key: aws.String(key), ContentType: aws.String(contentType), } output, err := s3Client.CreateMultipartUpload(ctx, &input) if err != nil { return nil, fmt.Errorf("error creating multipart upload: %v", err) } const bufferOverflowSize = 16_384 b := make([]byte, 0, targetPartSizeBytes+bufferOverflowSize) return &multipartUploadWriter{ ctx: ctx, s3: s3Client, buf: bytes.NewBuffer(b), bucket: bucket, key: key, contentType: contentType, partNum: 1, uploadResults: make(chan uploadResult), uploadID: *output.UploadId, }, nil } func (u *multipartUploadWriter) Write(p []byte) (int, error) { n, err := u.buf.Write(p) if err != nil { return n, fmt.Errorf("error writing to buffer: %v", err) } if u.buf.Len() >= targetPartSizeBytes { buf := make([]byte, u.buf.Len()) copy(buf, u.buf.Bytes()) u.buf.Truncate(0) u.wg.Add(1) go u.uploadPart(buf, u.partNum) u.partNum++ } return n, err } func (u *multipartUploadWriter) uploadPart(buf []byte, partNum int32) { defer u.wg.Done() partLen := len(buf) log.Printf("uploading part num = %d, len = %d", partNum, partLen) input := s3.UploadPartInput{ Body: bytes.NewReader(buf), Bucket: aws.String(u.bucket), Key: aws.String(u.key), PartNumber: partNum, UploadId: aws.String(u.uploadID), ContentLength: int64(partLen), } output, uploadErr := u.s3.UploadPart(u.ctx, &input) if uploadErr != nil { // TODO: retry on failure u.uploadResults <- uploadResult{err: fmt.Errorf("error uploading part: %v", uploadErr)} return } log.Printf("uploaded part num = %d, etag = %s, bytes = %d", partNum, *output.ETag, partLen) u.uploadResults <- uploadResult{ completedPart: types.CompletedPart{ETag: output.ETag, PartNumber: partNum}, size: int64(partLen), } } // Close signals that no further data will be written to the writer. // Always returns nil. 
func (u *multipartUploadWriter) Close() error { // TODO: trigger Complete() here too? close(u.uploadResults) return nil } // Complete waits for all currently uploading parts to be uploaded, and // finalizes the object in S3. // // Close() must have been been called first. func (u *multipartUploadWriter) Complete() (int64, error) { completedParts := make([]types.CompletedPart, 0, 64) var uploadedBytes int64 // Write() launches multiple goroutines to upload the parts asynchronously. // We need a waitgroup to ensure that all parts are complete, and the channel // has been closed, before we continue. uploadDone := make(chan struct{}) go func() { u.wg.Wait() close(u.uploadResults) uploadDone <- struct{}{} }() outer: for { select { case uploadResult, ok := <-u.uploadResults: if !ok { break outer } // if len(completedParts) == 3 { // return 0, errors.New("nope") // } if uploadResult.err != nil { return 0, uploadResult.err } log.Println("APPENDING PART, len now", len(completedParts)) completedParts = append(completedParts, uploadResult.completedPart) uploadedBytes += uploadResult.size case <-uploadDone: break outer case <-u.ctx.Done(): return 0, u.ctx.Err() } } if len(completedParts) == 0 { return 0, errors.New("no parts available to upload") } log.Printf("parts - %+v, bucket - %s, key - %s, id - %s", completedParts, u.bucket, u.key, u.uploadID) log.Printf("len(parts) = %d, cap(parts) = %d", len(completedParts), cap(completedParts)) input := s3.CompleteMultipartUploadInput{ Bucket: aws.String(u.bucket), Key: aws.String(u.key), UploadId: aws.String(u.uploadID), MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts}, } _, err := u.s3.CompleteMultipartUpload(u.ctx, &input) if err != nil { return 0, fmt.Errorf("error completing upload: %v", err) } log.Printf("completed upload, key = %s, bytesUploaded = %d", u.key, uploadedBytes) return uploadedBytes, nil } // Abort aborts the upload process, cancelling the upload on S3. 
It accepts a // separate context to the associated writer in case it is called during // cleanup after the original context was killed. func (u *multipartUploadWriter) Abort(ctx context.Context) error { input := s3.AbortMultipartUploadInput{ Bucket: aws.String(u.bucket), Key: aws.String(u.key), UploadId: aws.String(u.uploadID), } _, err := u.s3.AbortMultipartUpload(ctx, &input) if err != nil { return fmt.Errorf("error aborting upload: %v", err) } log.Printf("aborted upload, key = %s", u.key) return nil }