package media

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"sort"
	"sync"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/s3"
	"github.com/aws/aws-sdk-go-v2/service/s3/types"
)

// multipartUploadWriter is a Writer that transparently uploads to an S3
// bucket in 5MB parts. It buffers data internally until a part is ready to
// send over the network. A part is sent as soon as the buffer reaches the
// minimum part size of 5MB.
//
// The caller must call either Complete() or Abort() after finishing writing.
// Failure to do so will leave an incomplete multipart upload dangling on S3,
// which continues to incur storage costs until it is aborted.
type multipartUploadWriter struct {
	ctx context.Context
	wg  sync.WaitGroup
	s3  S3Client

	buf                      *bytes.Buffer
	bucket, key, contentType string
	partNum                  int32
	uploadResults            chan uploadResult
	uploadID                 string
}

type uploadResult struct {
	completedPart types.CompletedPart
	size          int64
	err           error
}

const targetPartSizeBytes = 5 * 1024 * 1024 // 5MB

// newMultipartUploadWriter creates a new multipart upload writer, including
// creating the upload on S3. Either Complete or Abort must be called after
// calling this function.
func newMultipartUploadWriter(ctx context.Context, s3Client S3Client, bucket, key, contentType string) (*multipartUploadWriter, error) {
	input := s3.CreateMultipartUploadInput{
		Bucket:      aws.String(bucket),
		Key:         aws.String(key),
		ContentType: aws.String(contentType),
	}

	output, err := s3Client.CreateMultipartUpload(ctx, &input)
	if err != nil {
		return nil, fmt.Errorf("error creating multipart upload: %w", err)
	}

	// Allocate some headroom beyond the target part size, so that a Write
	// which pushes the buffer past the threshold does not usually force a
	// reallocation.
	const bufferOverflowSize = 16_384
	b := make([]byte, 0, targetPartSizeBytes+bufferOverflowSize)

	return &multipartUploadWriter{
		ctx:           ctx,
		s3:            s3Client,
		buf:           bytes.NewBuffer(b),
		bucket:        bucket,
		key:           key,
		contentType:   contentType,
		partNum:       1,
		uploadResults: make(chan uploadResult),
		uploadID:      *output.UploadId,
	}, nil
}

// Write buffers p, and kicks off an asynchronous part upload once the buffer
// has reached the target part size.
func (u *multipartUploadWriter) Write(p []byte) (int, error) {
	n, err := u.buf.Write(p)
	if err != nil {
		return n, fmt.Errorf("error writing to buffer: %w", err)
	}

	if u.buf.Len() >= targetPartSizeBytes {
		// Copy the buffered bytes so the upload goroutine owns its own
		// slice, then reset the buffer for the next part.
		buf := make([]byte, u.buf.Len())
		copy(buf, u.buf.Bytes())
		u.buf.Truncate(0)

		u.wg.Add(1)
		go u.uploadPart(buf, u.partNum)
		u.partNum++
	}

	return n, nil
}

// uploadPart uploads a single part to S3, sending the result (or error) over
// the uploadResults channel. It is launched by Write in its own goroutine.
func (u *multipartUploadWriter) uploadPart(buf []byte, partNum int32) {
	defer u.wg.Done()

	partLen := len(buf)
	log.Printf("uploading part num = %d, len = %d", partNum, partLen)

	input := s3.UploadPartInput{
		Body:          bytes.NewReader(buf),
		Bucket:        aws.String(u.bucket),
		Key:           aws.String(u.key),
		PartNumber:    partNum,
		UploadId:      aws.String(u.uploadID),
		ContentLength: int64(partLen),
	}

	output, uploadErr := u.s3.UploadPart(u.ctx, &input)
	if uploadErr != nil {
		// TODO: retry on failure
		u.uploadResults <- uploadResult{err: fmt.Errorf("error uploading part: %w", uploadErr)}
		return
	}

	log.Printf("uploaded part num = %d, etag = %s, bytes = %d", partNum, *output.ETag, partLen)

	u.uploadResults <- uploadResult{
		completedPart: types.CompletedPart{ETag: output.ETag, PartNumber: partNum},
		size:          int64(partLen),
	}
}
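
// The TODO in uploadPart leaves retries unimplemented. As a minimal sketch of
// what that could look like, the hypothetical helper below (not part of the
// original design) wraps the UploadPart call in a bounded retry loop with a
// simple linear backoff, bailing out early if the writer's context is
// cancelled. uploadPart could call this in place of u.s3.UploadPart.
func (u *multipartUploadWriter) uploadPartWithRetry(input *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
	const maxAttempts = 3

	var lastErr error
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		output, err := u.s3.UploadPart(u.ctx, input)
		if err == nil {
			return output, nil
		}
		lastErr = err
		log.Printf("upload part attempt %d/%d failed: %v", attempt, maxAttempts, err)

		// Back off linearly between attempts, aborting early on context
		// cancellation.
		select {
		case <-time.After(time.Duration(attempt) * time.Second):
		case <-u.ctx.Done():
			return nil, u.ctx.Err()
		}
	}

	return nil, fmt.Errorf("upload part failed after %d attempts: %w", maxAttempts, lastErr)
}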

// Complete flushes any remaining buffered data as a final part, waits for all
// in-flight parts to finish uploading, and finalizes the object in S3. The
// writer must not be written to after Complete is called.
func (u *multipartUploadWriter) Complete() (int64, error) {
	// Flush whatever is left in the buffer as the final part. S3 permits
	// only the last part of a multipart upload to be smaller than the 5MB
	// minimum.
	if u.buf.Len() > 0 {
		buf := make([]byte, u.buf.Len())
		copy(buf, u.buf.Bytes())
		u.buf.Truncate(0)

		u.wg.Add(1)
		go u.uploadPart(buf, u.partNum)
		u.partNum++
	}

	// Close the results channel once every in-flight part has reported a
	// result. The closed channel unblocks the collection loop below.
	go func() {
		u.wg.Wait()
		close(u.uploadResults)
	}()

	var completedParts []types.CompletedPart
	var uploadedBytes int64

outer:
	for {
		select {
		case result, ok := <-u.uploadResults:
			if !ok {
				// Channel closed: all parts have completed.
				break outer
			}
			if result.err != nil {
				// Bail out on the first failed part. The caller
				// is expected to Abort the upload.
				return 0, result.err
			}
			completedParts = append(completedParts, result.completedPart)
			uploadedBytes += result.size
		case <-u.ctx.Done():
			return 0, u.ctx.Err()
		}
	}

	if len(completedParts) == 0 {
		return 0, errors.New("no parts available to upload")
	}

	// CompleteMultipartUpload requires parts to be listed in ascending
	// part number order, but results arrive in completion order.
	sort.Slice(completedParts, func(i, j int) bool {
		return completedParts[i].PartNumber < completedParts[j].PartNumber
	})

	input := s3.CompleteMultipartUploadInput{
		Bucket:          aws.String(u.bucket),
		Key:             aws.String(u.key),
		UploadId:        aws.String(u.uploadID),
		MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts},
	}

	_, err := u.s3.CompleteMultipartUpload(u.ctx, &input)
	if err != nil {
		return 0, fmt.Errorf("error completing upload: %w", err)
	}

	log.Printf("completed upload, key = %s, bytesUploaded = %d", u.key, uploadedBytes)
	return uploadedBytes, nil
}

// Abort aborts the upload process, cancelling the upload on S3. It accepts a
// context separate from the writer's own, in case it is called during cleanup
// after the original context was cancelled.
func (u *multipartUploadWriter) Abort(ctx context.Context) error {
	input := s3.AbortMultipartUploadInput{
		Bucket:   aws.String(u.bucket),
		Key:      aws.String(u.key),
		UploadId: aws.String(u.uploadID),
	}

	_, err := u.s3.AbortMultipartUpload(ctx, &input)
	if err != nil {
		return fmt.Errorf("error aborting upload: %w", err)
	}

	log.Printf("aborted upload, key = %s", u.key)
	return nil
}
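
// A hypothetical usage sketch (not part of the original code): the writer is
// fed via io.Copy, finalized with Complete on success, and cleaned up with
// Abort on failure. The bucket, key, and content type are placeholders, and
// s3Client is assumed to satisfy the S3Client interface.
func uploadStream(ctx context.Context, s3Client S3Client, src io.Reader) (int64, error) {
	w, err := newMultipartUploadWriter(ctx, s3Client, "my-bucket", "my-key", "video/mp4")
	if err != nil {
		return 0, fmt.Errorf("error creating upload writer: %w", err)
	}

	if _, err := io.Copy(w, src); err != nil {
		// Abort with a fresh context, in case ctx was already cancelled.
		if abortErr := w.Abort(context.Background()); abortErr != nil {
			log.Printf("error aborting upload: %v", abortErr)
		}
		return 0, fmt.Errorf("error copying stream: %w", err)
	}

	return w.Complete()
}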