package media

import (
	"bytes"
	"context"
	"fmt"
	"log"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/s3"
	"github.com/aws/aws-sdk-go-v2/service/s3/types"
)

// multipartUploadWriter is a Writer that uploads transparently to an S3 bucket
// in parts of at least 5 MiB. It buffers data internally until a part is ready
// to send over the network. A part is sent as soon as the buffered data
// reaches the minimum part size of 5 MiB.
//
// The caller must call either Complete() or Abort() after finishing writing.
// Failure to do so will leave S3 in an inconsistent state.
type multipartUploadWriter struct {
	// ctx is the context used for every part upload and for Complete.
	ctx context.Context

	// s3 is the client used to talk to S3.
	s3 S3Client

	// buf accumulates written bytes until a full part can be uploaded.
	buf *bytes.Buffer

	// bucket, key and contentType describe the destination object.
	bucket, key, contentType string

	// uploadID identifies the multipart upload created by the constructor.
	uploadID string

	// completedParts records the ETag and number of every part uploaded so
	// far, in upload order. Its length also determines the next part number.
	completedParts []types.CompletedPart

	// bytesUploaded is the total size of all successfully uploaded parts.
	bytesUploaded int64
}

// targetPartSizeBytes is the buffered size at which Write uploads a part.
const targetPartSizeBytes = 5 * 1024 * 1024 // 5MB

// newMultipartUploadWriter creates a new multipart upload writer, including
// creating the upload on S3. Either Complete or Abort must be called after
// calling this function.
//
// It returns an error if the upload cannot be created, or if S3 responds
// without an upload ID (without one the upload can be neither continued nor
// aborted).
func newMultipartUploadWriter(ctx context.Context, s3Client S3Client, bucket, key, contentType string) (*multipartUploadWriter, error) {
	input := s3.CreateMultipartUploadInput{
		Bucket:      aws.String(bucket),
		Key:         aws.String(key),
		ContentType: aws.String(contentType),
	}

	output, err := s3Client.CreateMultipartUpload(ctx, &input)
	if err != nil {
		return nil, fmt.Errorf("error creating multipart upload: %w", err)
	}
	// Guard against a nil pointer instead of panicking on dereference.
	if output == nil || output.UploadId == nil {
		return nil, fmt.Errorf("error creating multipart upload: response contained no upload ID")
	}

	// Writes rarely land exactly on the part boundary, so reserve a little
	// extra capacity to avoid regrowing the buffer when a write overshoots.
	const bufferOverflowSize = 16_384
	b := make([]byte, 0, targetPartSizeBytes+bufferOverflowSize)

	return &multipartUploadWriter{
		ctx:         ctx,
		s3:          s3Client,
		buf:         bytes.NewBuffer(b),
		bucket:      bucket,
		key:         key,
		contentType: contentType,
		uploadID:    *output.UploadId,
	}, nil
}

// Write appends p to the internal buffer and, once at least
// targetPartSizeBytes are buffered, uploads the whole buffer as the next part.
//
// It returns the number of bytes accepted into the buffer. If the part upload
// fails, the buffered data is kept so that nothing is lost; the returned error
// wraps the underlying upload error.
func (u *multipartUploadWriter) Write(p []byte) (int, error) {
	n, err := u.buf.Write(p)
	if err != nil {
		return n, fmt.Errorf("error writing to buffer: %w", err)
	}

	if u.buf.Len() < targetPartSizeBytes {
		return n, nil
	}

	partNum := u.partNum()
	partLen := u.buf.Len()

	log.Printf("uploading part num = %d, len = %d", partNum, partLen)

	input := s3.UploadPartInput{
		// Send a reader over the buffered bytes rather than the buffer itself:
		// reading from the buffer would consume it, losing the data if the
		// upload fails part-way. A bytes.Reader is also seekable, so the body
		// can be rewound if the request has to be re-sent.
		Body:          bytes.NewReader(u.buf.Bytes()),
		Bucket:        aws.String(u.bucket),
		Key:           aws.String(u.key),
		PartNumber:    partNum,
		UploadId:      aws.String(u.uploadID),
		ContentLength: int64(partLen),
	}

	output, uploadErr := u.s3.UploadPart(u.ctx, &input)
	if uploadErr != nil {
		// TODO: retry on failure
		return n, fmt.Errorf("error uploading part: %w", uploadErr)
	}

	// The part is safely stored; release the buffered bytes for reuse.
	u.buf.Reset()

	log.Printf("uploaded part num = %d, etag = %s, bytes = %d", partNum, aws.ToString(output.ETag), partLen)

	u.completedParts = append(u.completedParts, types.CompletedPart{ETag: output.ETag, PartNumber: partNum})
	u.bytesUploaded += int64(partLen)

	return n, nil
}

// partNum returns the 1-based number of the next part to upload, derived from
// the number of parts completed so far.
func (u *multipartUploadWriter) partNum() int32 {
	return int32(len(u.completedParts) + 1)
}

// Abort aborts the upload process, cancelling the upload on S3. It accepts a
// separate context to the associated writer in case it is called during
// cleanup after the original context was killed.
func (u *multipartUploadWriter) Abort(ctx context.Context) error {
	input := s3.AbortMultipartUploadInput{
		Bucket:   aws.String(u.bucket),
		Key:      aws.String(u.key),
		UploadId: aws.String(u.uploadID),
	}

	if _, err := u.s3.AbortMultipartUpload(ctx, &input); err != nil {
		return fmt.Errorf("error aborting upload: %w", err)
	}

	log.Printf("aborted upload, key = %s, bytesUploaded = %d", u.key, u.bytesUploaded)

	return nil
}

// Complete completes the upload process, finalizing the upload on S3.
// If no parts have been successfully uploaded, then Abort() will be called
// transparently.
func (u *multipartUploadWriter) Complete() (int64, error) { if len(u.completedParts) == 0 { return 0, u.Abort(u.ctx) } input := s3.CompleteMultipartUploadInput{ Bucket: aws.String(u.bucket), Key: aws.String(u.key), UploadId: aws.String(u.uploadID), MultipartUpload: &types.CompletedMultipartUpload{ Parts: u.completedParts, }, } _, err := u.s3.CompleteMultipartUpload(u.ctx, &input) if err != nil { return 0, fmt.Errorf("error completing upload: %v", err) } log.Printf("completed upload, key = %s, bytesUploaded = %d", u.key, u.bytesUploaded) return u.bytesUploaded, nil }