package media

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"sort"
	"sync"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/s3"
	"github.com/aws/aws-sdk-go-v2/service/s3/types"
)

// multipartUploader streams data to S3 using the multipart upload API.
type multipartUploader struct {
	s3 S3Client
}

// uploadResult records one successfully uploaded part and its size in bytes.
type uploadResult struct {
	completedPart types.CompletedPart
	size          int64
}

// readResult is no longer used by Upload; it is kept in case other files in
// package media reference it.
type readResult struct {
	n   int
	err error
}

// targetPartSizeBytes is the size of every part except the last (the minimum
// S3 part size). bufferOverflowSize is no longer used by Upload; it is kept in
// case other files in package media reference it.
const (
	targetPartSizeBytes = 5 * 1024 * 1024 // 5MiB
	bufferOverflowSize  = 16_384          // 16KiB
)

// newMultipartUploader returns a multipartUploader that sends its requests
// through s3Client.
func newMultipartUploader(s3Client S3Client) *multipartUploader {
	return &multipartUploader{s3: s3Client}
}

// Upload streams r to the S3 object bucket/key via a multipart upload and
// returns the number of bytes uploaded.
//
// Data is read from r into parts of targetPartSizeBytes (5MiB); each full part
// is uploaded in its own goroutine as soon as it has been read, so reading and
// uploading overlap. At most maxInFlightParts uploads run at once, which also
// bounds how much data is buffered. When r returns io.EOF the remaining data
// is sent as the final part; an empty input is sent as a single empty part.
// Parts are passed to CompleteMultipartUpload in ascending part-number order.
//
// On any failure (reader error, part upload error, CompleteMultipartUpload
// error, or ctx being done) the multipart upload is aborted on a best-effort
// basis (abort failures are only logged) and an error is returned. Underlying
// errors are wrapped with %w so callers can use errors.Is / errors.As.
//
// A Read on r that never returns cannot be interrupted: the internal reader
// goroutine exits only once r.Read returns, even if Upload has returned.
func (u *multipartUploader) Upload(ctx context.Context, r io.Reader, bucket, key, contentType string) (int64, error) {
	// maxInFlightParts caps concurrent UploadPart calls; together with the part
	// the reader is filling, at most (maxInFlightParts+1) parts are buffered.
	const maxInFlightParts = 8
	// abortTimeout bounds the best-effort AbortMultipartUpload call.
	const abortTimeout = 30 * time.Second

	createInput := s3.CreateMultipartUploadInput{
		Bucket:      aws.String(bucket),
		Key:         aws.String(key),
		ContentType: aws.String(contentType),
	}
	output, err := u.s3.CreateMultipartUpload(ctx, &createInput)
	if err != nil {
		return 0, fmt.Errorf("error creating multipart upload: %w", err)
	}
	uploadID := output.UploadId

	// workCtx is cancelled when Upload returns. That stops in-flight part
	// uploads and releases goroutines waiting to hand over their results, so
	// none of them leak on an early return.
	workCtx, cancel := context.WithCancel(ctx)
	var uploaders sync.WaitGroup
	var uploaded bool
	defer func() {
		cancel()
		// Wait for in-flight UploadPart calls before aborting: S3 documents that
		// parts still uploading during an abort may succeed afterwards.
		uploaders.Wait()
		if uploaded {
			return
		}
		// ctx may be the very reason we are aborting (cancelled or expired), and
		// a request on a done context fails immediately, so use a fresh one.
		abortCtx, abortCancel := context.WithTimeout(context.Background(), abortTimeout)
		defer abortCancel()
		abortInput := s3.AbortMultipartUploadInput{
			Bucket:   aws.String(bucket),
			Key:      aws.String(key),
			UploadId: uploadID,
		}
		if _, abortErr := u.s3.AbortMultipartUpload(abortCtx, &abortInput); abortErr != nil {
			log.Printf("error aborting upload: %v", abortErr)
			return
		}
		log.Printf("aborted upload, key = %s", key)
	}()

	// The reader goroutine fills a freshly allocated buffer for each part and
	// hands it over on partChan. It never touches a buffer after sending it, so
	// no memory is shared with the main loop or the uploaders. A non-EOF read
	// error is stored in readErr before the channel is closed; the close makes
	// the write visible to the main loop.
	var readErr error
	partChan := make(chan []byte)
	go func(out chan<- []byte) {
		defer close(out)
		for partIndex := 0; ; partIndex++ {
			buf := make([]byte, targetPartSizeBytes)
			var n int
			var err error
			for n < len(buf) && err == nil {
				var m int
				m, err = r.Read(buf[n:])
				n += m
			}
			if err != nil && err != io.EOF {
				readErr = err
				return
			}
			// At EOF, drop an empty trailing part, but always send at least one
			// part so that an empty input can still be completed.
			if n > 0 || partIndex == 0 {
				select {
				case out <- buf[:n]:
				case <-workCtx.Done():
					return
				}
			}
			if err == io.EOF {
				return
			}
		}
	}(partChan)

	resultChan := make(chan uploadResult)
	uploadErrChan := make(chan error)

	// uploadPart uploads a single part and reports the outcome to the main
	// loop. If Upload has already returned (workCtx done), the report is
	// dropped instead of blocking forever.
	uploadPart := func(part []byte, partNum int32) {
		defer uploaders.Done()
		partLen := int64(len(part))
		log.Printf("uploading part num = %d, len = %d", partNum, partLen)
		partInput := s3.UploadPartInput{
			Body:          bytes.NewReader(part),
			Bucket:        aws.String(bucket),
			Key:           aws.String(key),
			PartNumber:    partNum,
			UploadId:      uploadID,
			ContentLength: partLen,
		}
		partOutput, uploadErr := u.s3.UploadPart(workCtx, &partInput)
		if uploadErr != nil {
			// TODO: retry on failure
			select {
			case uploadErrChan <- fmt.Errorf("error while uploading part %d: %w", partNum, uploadErr):
			case <-workCtx.Done():
			}
			return
		}
		log.Printf("uploaded part num = %d, etag = %s, bytes = %d", partNum, aws.ToString(partOutput.ETag), partLen)
		result := uploadResult{
			completedPart: types.CompletedPart{ETag: partOutput.ETag, PartNumber: partNum},
			size:          partLen,
		}
		select {
		case resultChan <- result:
		case <-workCtx.Done():
		}
	}

	results := make([]uploadResult, 0, 64)
	inFlight := 0
	partNum := int32(1)
	// Loop until the reader is finished (partChan == nil) and every started
	// upload has reported back.
	for partChan != nil || inFlight > 0 {
		// Receiving from a nil channel never proceeds, so this stops taking new
		// parts (and thus throttles the reader) while the upload limit is hit.
		nextPart := partChan
		if inFlight >= maxInFlightParts {
			nextPart = nil
		}
		select {
		case part, ok := <-nextPart:
			if ok {
				uploaders.Add(1)
				inFlight++
				go uploadPart(part, partNum)
				partNum++
			} else if readErr != nil {
				return 0, fmt.Errorf("reader error: %w", readErr)
			} else {
				partChan = nil // every part has been read
			}
		case result := <-resultChan:
			results = append(results, result)
			inFlight--
		case uploadErr := <-uploadErrChan:
			return 0, uploadErr
		case <-ctx.Done():
			return 0, ctx.Err()
		}
	}

	if len(results) == 0 {
		return 0, errors.New("no parts available to upload")
	}

	// Parts finish in arbitrary order, but CompleteMultipartUpload requires the
	// part list in ascending part-number order.
	sort.Slice(results, func(i, j int) bool {
		return results[i].completedPart.PartNumber < results[j].completedPart.PartNumber
	})
	completedParts := make([]types.CompletedPart, 0, len(results))
	var uploadedBytes int64
	for _, result := range results {
		completedParts = append(completedParts, result.completedPart)
		uploadedBytes += result.size
	}

	completeInput := s3.CompleteMultipartUploadInput{
		Bucket:          aws.String(bucket),
		Key:             aws.String(key),
		UploadId:        uploadID,
		MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts},
	}
	if _, err = u.s3.CompleteMultipartUpload(ctx, &completeInput); err != nil {
		return 0, fmt.Errorf("error completing upload: %w", err)
	}
	log.Printf("completed upload, key = %s, bytesUploaded = %d", key, uploadedBytes)
	uploaded = true
	return uploadedBytes, nil
}