package filestore

//go:generate mockery --recursive --name S3Client --output ../generated/mocks
//go:generate mockery --recursive --name S3PresignClient --output ../generated/mocks

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"sort"
	"sync"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	signerv4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
	"github.com/aws/aws-sdk-go-v2/service/s3"
	"github.com/aws/aws-sdk-go-v2/service/s3/types"
	"go.uber.org/zap"
)

// S3API provides an API to AWS S3.
type S3API struct {
	S3Client
	S3PresignClient
}

// S3Client wraps the AWS S3 service client.
type S3Client interface {
	GetObject(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)
	CreateMultipartUpload(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error)
	UploadPart(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) (*s3.UploadPartOutput, error)
	AbortMultipartUpload(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error)
	CompleteMultipartUpload(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error)
}

// S3PresignClient wraps the AWS S3 presign client.
type S3PresignClient interface {
	PresignGetObject(context.Context, *s3.GetObjectInput, ...func(*s3.PresignOptions)) (*signerv4.PresignedHTTPRequest, error)
}

// S3FileStore stores files in an Amazon S3 bucket.
type S3FileStore struct {
	s3        S3API
	bucket    string
	urlExpiry time.Duration
	logger    *zap.SugaredLogger
}

// NewS3FileStore builds a new S3FileStore using the provided configuration.
func NewS3FileStore(s3API S3API, bucket string, urlExpiry time.Duration, logger *zap.SugaredLogger) *S3FileStore {
	return &S3FileStore{
		s3:        s3API,
		bucket:    bucket,
		urlExpiry: urlExpiry,
		logger:    logger,
	}
}

// GetObject returns an io.ReadCloser that streams the object stored under the
// provided key. The caller is responsible for closing it.
func (s *S3FileStore) GetObject(ctx context.Context, key string) (io.ReadCloser, error) {
	input := s3.GetObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(key),
	}
	output, err := s.s3.GetObject(ctx, &input)
	if err != nil {
		return nil, fmt.Errorf("error getting object from s3: %v", err)
	}
	return output.Body, nil
}

// GetObjectWithRange returns an io.ReadCloser that streams the byte range
// [start, end] of the object stored under the provided key.
func (s *S3FileStore) GetObjectWithRange(ctx context.Context, key string, start, end int64) (io.ReadCloser, error) {
	byteRange := fmt.Sprintf("bytes=%d-%d", start, end)
	input := s3.GetObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(key),
		Range:  aws.String(byteRange),
	}
	output, err := s.s3.GetObject(ctx, &input)
	if err != nil {
		return nil, fmt.Errorf("error getting object from s3: %v", err)
	}
	return output.Body, nil
}

// GetURL returns a presigned URL pointing to the object stored under the
// provided key. The URL expires after the store's configured urlExpiry.
func (s *S3FileStore) GetURL(ctx context.Context, key string) (string, error) {
	input := s3.GetObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(key),
	}
	request, err := s.s3.PresignGetObject(ctx, &input, s3.WithPresignExpires(s.urlExpiry))
	if err != nil {
		return "", fmt.Errorf("error generating presigned URL: %v", err)
	}
	return request.URL, nil
}
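
// Wiring the store against real AWS clients is expected to look roughly like
// the sketch below. This is an illustrative assumption based on the
// aws-sdk-go-v2 API (config.LoadDefaultConfig from
// github.com/aws/aws-sdk-go-v2/config, s3.NewFromConfig and
// s3.NewPresignClient); the bucket name, expiry, and logger are placeholders
// and not part of this package:
//
//	cfg, err := config.LoadDefaultConfig(ctx)
//	if err != nil {
//		return err
//	}
//	client := s3.NewFromConfig(cfg)
//	api := S3API{
//		S3Client:        client,
//		S3PresignClient: s3.NewPresignClient(client),
//	}
//	store := NewS3FileStore(api, "my-bucket", 15*time.Minute, logger) // logger: an existing *zap.SugaredLogger
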
// PutObject uploads the data read from r under the provided key using S3
// multipart upload, returning the number of bytes uploaded and any error.
func (s *S3FileStore) PutObject(ctx context.Context, key string, r io.Reader, contentType string) (int64, error) {
	const (
		targetPartSizeBytes = 5 * 1024 * 1024 // 5 MiB, the S3 minimum part size (for all but the last part)
		readBufferSizeBytes = 32_768          // 32 KiB
	)

	var uploaded bool // set to true once CompleteMultipartUpload succeeds

	input := s3.CreateMultipartUploadInput{
		Bucket:      aws.String(s.bucket),
		Key:         aws.String(key),
		ContentType: aws.String(contentType),
	}
	output, err := s.s3.CreateMultipartUpload(ctx, &input)
	if err != nil {
		return 0, fmt.Errorf("error creating multipart upload: %v", err)
	}

	// On exit, abort the upload unless it completed successfully, logging any
	// errors from the abort itself.
	defer func() {
		if uploaded {
			return
		}
		input := s3.AbortMultipartUploadInput{
			Bucket:   aws.String(s.bucket),
			Key:      aws.String(key),
			UploadId: output.UploadId,
		}
		// If the context was cancelled, fall back to the background context so
		// the abort still runs.
		ctxToUse := ctx
		if ctxToUse.Err() != nil {
			ctxToUse = context.Background()
		}
		_, deferErr := s.s3.AbortMultipartUpload(ctxToUse, &input)
		if deferErr != nil {
			s.logger.Errorf("uploader: error aborting upload: %v", deferErr)
		} else {
			s.logger.Infof("aborted upload, key = %s", key)
		}
	}()

	type uploadedPart struct {
		part types.CompletedPart
		size int64
	}
	// uploadResultChan is unbuffered: an early return from this method can
	// leave in-flight uploadPart goroutines blocked on it; the deferred abort
	// above still releases the multipart upload on the S3 side.
	uploadResultChan := make(chan uploadedPart)
	uploadErrorChan := make(chan error, 1)

	// uploadPart uploads an individual part.
	uploadPart := func(wg *sync.WaitGroup, buf []byte, partNum int32) {
		defer wg.Done()
		partLen := int64(len(buf))
		s.logger.With("key", key, "partNum", partNum, "partLen", partLen).Debug("uploading part")
		input := s3.UploadPartInput{
			Body:          bytes.NewReader(buf),
			Bucket:        aws.String(s.bucket),
			Key:           aws.String(key),
			PartNumber:    partNum,
			UploadId:      output.UploadId,
			ContentLength: partLen,
		}
		partOutput, uploadErr := s.s3.UploadPart(ctx, &input)
		if uploadErr != nil {
			// TODO: retry on failure instead of failing the whole upload.
			uploadErrorChan <- uploadErr
			return
		}
		s.logger.With("key", key, "partNum", partNum, "partLen", partLen, "etag", *partOutput.ETag).Debug("uploaded part")
		uploadResultChan <- uploadedPart{
			part: types.CompletedPart{ETag: partOutput.ETag, PartNumber: partNum},
			size: partLen,
		}
	}

	wgDone := make(chan struct{})
	var wg sync.WaitGroup
	wg.Add(1) // done when the reader goroutine returns
	go func() {
		wg.Wait()
		wgDone <- struct{}{}
	}()

	readChan := make(chan error, 1)
	go func() {
		defer wg.Done()
		var closing bool
		currPart := bytes.NewBuffer(make([]byte, 0, targetPartSizeBytes+readBufferSizeBytes))
		partNum := int32(1)
		buf := make([]byte, readBufferSizeBytes)
		for {
			n, readErr := r.Read(buf)
			if readErr == io.EOF {
				closing = true
			} else if readErr != nil {
				readChan <- readErr
				return
			}
			_, _ = currPart.Write(buf[:n])
			// Flush a part when the buffer reaches the target size, or on EOF
			// when leftover data remains. For an empty reader (partNum == 1,
			// no data), still upload a single empty part so the multipart
			// upload can be completed.
			if currPart.Len() >= targetPartSizeBytes || (closing && (currPart.Len() > 0 || partNum == 1)) {
				part := make([]byte, currPart.Len())
				copy(part, currPart.Bytes())
				currPart.Truncate(0)
				wg.Add(1)
				go uploadPart(&wg, part, partNum)
				partNum++
			}
			if closing {
				return
			}
		}
	}()

	results := make([]uploadedPart, 0, 64)
outer:
	for {
		select {
		case readErr := <-readChan:
			// only non-EOF errors are sent on readChan.
			if readErr != io.EOF {
				return 0, fmt.Errorf("reader error: %v", readErr)
			}
		case uploadResult := <-uploadResultChan:
			results = append(results, uploadResult)
		case uploadErr := <-uploadErrorChan:
			return 0, fmt.Errorf("error while uploading part: %v", uploadErr)
		case <-ctx.Done():
			return 0, ctx.Err()
		case <-wgDone:
			break outer
		}
	}

	if len(results) == 0 {
		return 0, errors.New("no parts available to upload")
	}

	completedParts := make([]types.CompletedPart, 0, len(results))
	var uploadedBytes int64
	for _, result := range results {
		completedParts = append(completedParts, result.part)
		uploadedBytes += result.size
	}

	// The parts complete in arbitrary order because uploads run concurrently;
	// CompleteMultipartUpload requires them in ascending part-number order.
	sort.Slice(completedParts, func(i, j int) bool {
		return completedParts[i].PartNumber < completedParts[j].PartNumber
	})

	completeInput := s3.CompleteMultipartUploadInput{
		Bucket:          aws.String(s.bucket),
		Key:             aws.String(key),
		UploadId:        output.UploadId,
		MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts},
	}
	if _, err = s.s3.CompleteMultipartUpload(ctx, &completeInput); err != nil {
		return 0, fmt.Errorf("error completing upload: %v", err)
	}
	s.logger.With("key", key, "numParts", len(completedParts), "len", uploadedBytes).Debug("completed upload")

	uploaded = true

	return uploadedBytes, nil
}
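
// A minimal caller sketch for uploading and then linking a file, assuming an
// already constructed S3API value named api and a *zap.SugaredLogger named
// logger; the bucket, key, file name, and content type are illustrative
// placeholders, not values defined by this package:
//
//	store := NewS3FileStore(api, "uploads", 15*time.Minute, logger)
//	f, err := os.Open("report.csv")
//	if err != nil {
//		return err
//	}
//	defer f.Close()
//	n, err := store.PutObject(ctx, "reports/report.csv", f, "text/csv")
//	if err != nil {
//		return err
//	}
//	url, err := store.GetURL(ctx, "reports/report.csv")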