281 lines
7.7 KiB
Go
281 lines
7.7 KiB
Go
package filestore
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"sort"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/aws/aws-sdk-go-v2/aws"
|
|
signerv4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
|
"github.com/aws/aws-sdk-go-v2/service/s3"
|
|
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// S3API provides an API to AWS S3.
//
// It bundles the two SDK clients that S3FileStore needs: S3Client for object
// downloads and multipart uploads, and S3PresignClient for generating
// presigned GET URLs. Both are embedded, so their methods are promoted onto
// S3API itself.
type S3API struct {
	S3Client
	S3PresignClient
}
|
|
|
|
// S3Client wraps the AWS S3 service client.
|
|
type S3Client interface {
|
|
GetObject(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)
|
|
CreateMultipartUpload(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error)
|
|
UploadPart(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) (*s3.UploadPartOutput, error)
|
|
AbortMultipartUpload(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error)
|
|
CompleteMultipartUpload(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error)
|
|
}
|
|
|
|
// S3PresignClient wraps the AWS S3 Presign client.
|
|
type S3PresignClient interface {
|
|
PresignGetObject(context.Context, *s3.GetObjectInput, ...func(*s3.PresignOptions)) (*signerv4.PresignedHTTPRequest, error)
|
|
}
|
|
|
|
// S3FileStore stores files on Amazon S3.
//
// Every object lives in a single bucket and is addressed by its key. Use
// NewS3FileStore to construct one.
type S3FileStore struct {
	// s3 performs all object reads, multipart uploads and URL presigning.
	s3 S3API
	// bucket is the S3 bucket every key is resolved against.
	bucket string
	// urlExpiry is the validity period applied to URLs returned by GetURL.
	urlExpiry time.Duration
	// logger receives upload debug output and abort diagnostics.
	logger *zap.SugaredLogger
}
|
|
|
|
// NewS3FileStore builds a new S3FileStore using the provided configuration.
|
|
func NewS3FileStore(s3API S3API, bucket string, urlExpiry time.Duration, logger *zap.SugaredLogger) *S3FileStore {
|
|
return &S3FileStore{
|
|
s3: s3API,
|
|
bucket: bucket,
|
|
urlExpiry: urlExpiry,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// GetObject returns an io.Reader that returns an object associated with the
|
|
// provided key.
|
|
func (s *S3FileStore) GetObject(ctx context.Context, key string) (io.ReadCloser, error) {
|
|
input := s3.GetObjectInput{
|
|
Bucket: aws.String(s.bucket),
|
|
Key: aws.String(key),
|
|
}
|
|
output, err := s.s3.GetObject(ctx, &input)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting object from s3: %v", err)
|
|
}
|
|
return output.Body, nil
|
|
}
|
|
|
|
// GetObjectWithRange returns an io.Reader that returns a partial object
|
|
// associated with the provided key.
|
|
func (s *S3FileStore) GetObjectWithRange(ctx context.Context, key string, start, end int64) (io.ReadCloser, error) {
|
|
byteRange := fmt.Sprintf("bytes=%d-%d", start, end)
|
|
input := s3.GetObjectInput{
|
|
Bucket: aws.String(s.bucket),
|
|
Key: aws.String(key),
|
|
Range: aws.String(byteRange),
|
|
}
|
|
output, err := s.s3.GetObject(ctx, &input)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting object from s3: %v", err)
|
|
}
|
|
return output.Body, nil
|
|
}
|
|
|
|
// GetURL returns a presigned URL pointing to the object associated with the
|
|
// provided key.
|
|
func (s *S3FileStore) GetURL(ctx context.Context, key string) (string, error) {
|
|
input := s3.GetObjectInput{
|
|
Bucket: aws.String(s.bucket),
|
|
Key: aws.String(key),
|
|
}
|
|
request, err := s.s3.PresignGetObject(ctx, &input, s3.WithPresignExpires(s.urlExpiry))
|
|
if err != nil {
|
|
return "", fmt.Errorf("error generating presigned URL: %v", err)
|
|
}
|
|
return request.URL, nil
|
|
}
|
|
|
|
// PutObject uploads an object using multipart upload, returning the number of
|
|
// bytes uploaded and any error.
|
|
func (s *S3FileStore) PutObject(ctx context.Context, key string, r io.Reader, contentType string) (int64, error) {
|
|
const (
|
|
targetPartSizeBytes = 5 * 1024 * 1024 // 5MB
|
|
readBufferSizeBytes = 32_768 // 32Kb
|
|
)
|
|
|
|
var uploaded bool
|
|
|
|
input := s3.CreateMultipartUploadInput{
|
|
Bucket: aws.String(s.bucket),
|
|
Key: aws.String(key),
|
|
ContentType: aws.String(contentType),
|
|
}
|
|
output, err := s.s3.CreateMultipartUpload(ctx, &input)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("error creating multipart upload: %v", err)
|
|
}
|
|
|
|
// abort the upload if possible, logging any errors, on exit.
|
|
defer func() {
|
|
if uploaded {
|
|
return
|
|
}
|
|
input := s3.AbortMultipartUploadInput{
|
|
Bucket: aws.String(s.bucket),
|
|
Key: aws.String(key),
|
|
UploadId: output.UploadId,
|
|
}
|
|
|
|
// if the context was cancelled, just use the background context.
|
|
ctxToUse := ctx
|
|
if ctxToUse.Err() != nil {
|
|
ctxToUse = context.Background()
|
|
}
|
|
|
|
_, deferErr := s.s3.AbortMultipartUpload(ctxToUse, &input)
|
|
if deferErr != nil {
|
|
s.logger.Errorf("uploader: error aborting upload: %v", deferErr)
|
|
} else {
|
|
s.logger.Infof("aborted upload, key = %s", key)
|
|
}
|
|
}()
|
|
|
|
type uploadedPart struct {
|
|
part types.CompletedPart
|
|
size int64
|
|
}
|
|
uploadResultChan := make(chan uploadedPart)
|
|
uploadErrorChan := make(chan error, 1)
|
|
|
|
// uploadPart uploads an individual part.
|
|
uploadPart := func(wg *sync.WaitGroup, buf []byte, partNum int32) {
|
|
defer wg.Done()
|
|
|
|
partLen := int64(len(buf))
|
|
s.logger.With("key", key, "partNum", partNum, "partLen", partLen).Debug("uploading part")
|
|
|
|
input := s3.UploadPartInput{
|
|
Body: bytes.NewReader(buf),
|
|
Bucket: aws.String(s.bucket),
|
|
Key: aws.String(key),
|
|
PartNumber: partNum,
|
|
UploadId: output.UploadId,
|
|
ContentLength: partLen,
|
|
}
|
|
|
|
output, uploadErr := s.s3.UploadPart(ctx, &input)
|
|
if uploadErr != nil {
|
|
// TODO: retry on failure
|
|
uploadErrorChan <- uploadErr
|
|
return
|
|
}
|
|
|
|
s.logger.With("key", key, "partNum", partNum, "partLen", partLen, "etag", *output.ETag).Debug("uploaded part")
|
|
|
|
uploadResultChan <- uploadedPart{
|
|
part: types.CompletedPart{ETag: output.ETag, PartNumber: partNum},
|
|
size: partLen,
|
|
}
|
|
}
|
|
|
|
wgDone := make(chan struct{})
|
|
var wg sync.WaitGroup
|
|
wg.Add(1) // done when the reader goroutine returns
|
|
go func() {
|
|
wg.Wait()
|
|
wgDone <- struct{}{}
|
|
}()
|
|
|
|
readChan := make(chan error, 1)
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
var closing bool
|
|
currPart := bytes.NewBuffer(make([]byte, 0, targetPartSizeBytes+readBufferSizeBytes))
|
|
partNum := int32(1)
|
|
buf := make([]byte, readBufferSizeBytes)
|
|
|
|
for {
|
|
n, readErr := r.Read(buf)
|
|
if readErr == io.EOF {
|
|
closing = true
|
|
} else if readErr != nil {
|
|
readChan <- readErr
|
|
return
|
|
}
|
|
|
|
_, _ = currPart.Write(buf[:n])
|
|
if closing || currPart.Len() >= targetPartSizeBytes {
|
|
part := make([]byte, currPart.Len())
|
|
copy(part, currPart.Bytes())
|
|
currPart.Truncate(0)
|
|
|
|
wg.Add(1)
|
|
go uploadPart(&wg, part, partNum)
|
|
partNum++
|
|
}
|
|
|
|
if closing {
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
results := make([]uploadedPart, 0, 64)
|
|
|
|
outer:
|
|
for {
|
|
select {
|
|
case readErr := <-readChan:
|
|
if readErr != io.EOF {
|
|
return 0, fmt.Errorf("reader error: %v", readErr)
|
|
}
|
|
case uploadResult := <-uploadResultChan:
|
|
results = append(results, uploadResult)
|
|
case uploadErr := <-uploadErrorChan:
|
|
return 0, fmt.Errorf("error while uploading part: %v", uploadErr)
|
|
case <-ctx.Done():
|
|
return 0, ctx.Err()
|
|
case <-wgDone:
|
|
break outer
|
|
}
|
|
}
|
|
|
|
if len(results) == 0 {
|
|
return 0, errors.New("no parts available to upload")
|
|
}
|
|
|
|
completedParts := make([]types.CompletedPart, 0, 64)
|
|
var uploadedBytes int64
|
|
for _, result := range results {
|
|
completedParts = append(completedParts, result.part)
|
|
uploadedBytes += result.size
|
|
}
|
|
|
|
// the parts may be out of order, especially with slow network conditions:
|
|
sort.Slice(completedParts, func(i, j int) bool {
|
|
return completedParts[i].PartNumber < completedParts[j].PartNumber
|
|
})
|
|
|
|
completeInput := s3.CompleteMultipartUploadInput{
|
|
Bucket: aws.String(s.bucket),
|
|
Key: aws.String(key),
|
|
UploadId: output.UploadId,
|
|
MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts},
|
|
}
|
|
|
|
if _, err = s.s3.CompleteMultipartUpload(ctx, &completeInput); err != nil {
|
|
return 0, fmt.Errorf("error completing upload: %v", err)
|
|
}
|
|
|
|
s.logger.With("key", key, "numParts", len(completedParts), "len", uploadedBytes).Debug("completed upload")
|
|
uploaded = true
|
|
|
|
return uploadedBytes, nil
|
|
}
|