diff --git a/backend/media/uploader.go b/backend/media/uploader.go
index 5ba2738..1ac6e38 100644
--- a/backend/media/uploader.go
+++ b/backend/media/uploader.go
@@ -30,7 +30,7 @@ type readResult struct {
 const (
 	targetPartSizeBytes = 5 * 1024 * 1024 // 5MB
-	bufferOverflowSize  = 16_384          // 16Kb
+	readBufferSizeBytes = 32_768          // 32Kb
 )
 
 func newMultipartUploader(s3Client S3Client) *multipartUploader {
@@ -64,7 +64,13 @@ func (u *multipartUploader) Upload(ctx context.Context, r io.Reader, bucket, key
 		UploadId: output.UploadId,
 	}
 
-	_, abortErr := u.s3.AbortMultipartUpload(ctx, &input)
+	// if the context was cancelled, just use the background context.
+	ctxToUse := ctx
+	if ctxToUse.Err() != nil {
+		ctxToUse = context.Background()
+	}
+
+	_, abortErr := u.s3.AbortMultipartUpload(ctxToUse, &input)
 	if abortErr != nil {
 		log.Printf("error aborting upload: %v", abortErr)
 	} else {
@@ -108,51 +114,32 @@ func (u *multipartUploader) Upload(ctx context.Context, r io.Reader, bucket, key
 
 	wgDone := make(chan struct{})
 	var wg sync.WaitGroup
-	wg.Add(1) // done when the reader returns EOF
+	wg.Add(1) // done when the reader goroutine returns
 	go func() {
 		wg.Wait()
 		wgDone <- struct{}{}
 	}()
 
-	readChan := make(chan readResult)
-	buf := make([]byte, 32_768)
-	go func() {
-		for {
-			var rr readResult
-			rr.n, rr.err = r.Read(buf)
-			readChan <- rr
+	readChan := make(chan error, 1)
 
-			if rr.err != nil {
+	go func() {
+		defer wg.Done()
+
+		var closing bool
+		currPart := bytes.NewBuffer(make([]byte, 0, targetPartSizeBytes+readBufferSizeBytes))
+		partNum := int32(1)
+		buf := make([]byte, readBufferSizeBytes)
+
+		for {
+			n, readErr := r.Read(buf)
+			if readErr == io.EOF {
+				closing = true
+			} else if readErr != nil {
+				readChan <- readErr
 				return
 			}
-		}
-	}()
 
-	var closing bool
-	currPart := bytes.NewBuffer(make([]byte, 0, targetPartSizeBytes+bufferOverflowSize))
-	partNum := int32(1)
-	results := make([]uploadResult, 0, 64)
-
-outer:
-	for {
-		select {
-		case uploadResult := <-uploadResultChan:
-			results = append(results, uploadResult)
-		case uploadErr := <-uploadErrorChan:
-			return 0, fmt.Errorf("error while uploading part: %v", uploadErr)
-		case <-wgDone:
-			break outer
-		case <-ctx.Done():
-			return 0, ctx.Err()
-		case readResult := <-readChan:
-			if readResult.err == io.EOF {
-				wg.Done()
-				closing = true
-			} else if readResult.err != nil {
-				return 0, fmt.Errorf("reader error: %v", readResult.err)
-			}
-
-			_, _ = currPart.Write(buf[:readResult.n])
+			_, _ = currPart.Write(buf[:n])
 			if closing || currPart.Len() >= targetPartSizeBytes {
 				part := make([]byte, currPart.Len())
 				copy(part, currPart.Bytes())
@@ -162,6 +149,30 @@ outer:
 			go uploadPart(&wg, part, partNum)
 			partNum++
 		}
+
+			if closing {
+				return
+			}
+		}
+	}()
+
+	results := make([]uploadResult, 0, 64)
+
+outer:
+	for {
+		select {
+		case readErr := <-readChan:
+			if readErr != io.EOF {
+				return 0, fmt.Errorf("reader error: %v", readErr)
+			}
+		case uploadResult := <-uploadResultChan:
+			results = append(results, uploadResult)
+		case uploadErr := <-uploadErrorChan:
+			return 0, fmt.Errorf("error while uploading part: %v", uploadErr)
+		case <-ctx.Done():
+			return 0, ctx.Err()
+		case <-wgDone:
+			break outer
 		}
 	}
 