From f2d7af0860edf291c1069f8140d7a51e94a44267 Mon Sep 17 00:00:00 2001 From: Rob Watson Date: Tue, 7 Dec 2021 20:58:11 +0100 Subject: [PATCH] Extract S3 code to S3FileStore Re: #5 --- README.md | 2 +- backend/cmd/clipper/main.go | 49 ++-- backend/filestore/s3.go | 280 +++++++++++++++++++++ backend/generated/mocks/FileStore.go | 103 ++++++++ backend/generated/mocks/S3Client.go | 166 ------------ backend/generated/mocks/S3PresignClient.go | 48 ---- backend/media/get_audio.go | 56 ++--- backend/media/get_video.go | 35 +-- backend/media/service.go | 146 +++++------ backend/media/service_test.go | 75 +++--- backend/media/types.go | 32 +-- backend/media/uploader.go | 212 ---------------- backend/server/server.go | 26 +- 13 files changed, 569 insertions(+), 661 deletions(-) create mode 100644 backend/filestore/s3.go create mode 100644 backend/generated/mocks/FileStore.go delete mode 100644 backend/generated/mocks/S3Client.go delete mode 100644 backend/generated/mocks/S3PresignClient.go delete mode 100644 backend/media/uploader.go diff --git a/README.md b/README.md index fa5e83a..50abc61 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ sqlc generate Mocks require [mockery](https://github.com/vektra/mockery) to be installed, and can be regenerated with: ``` -go generate +go generate ./... ``` ### Migrations diff --git a/backend/cmd/clipper/main.go b/backend/cmd/clipper/main.go index ae8c074..e5cb688 100644 --- a/backend/cmd/clipper/main.go +++ b/backend/cmd/clipper/main.go @@ -6,18 +6,20 @@ import ( "time" "git.netflux.io/rob/clipper/config" + "git.netflux.io/rob/clipper/filestore" "git.netflux.io/rob/clipper/generated/store" - "git.netflux.io/rob/clipper/media" "git.netflux.io/rob/clipper/server" awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/jackc/pgx/v4/pgxpool" "github.com/kkdai/youtube/v2" + "go.uber.org/zap" ) const ( - DefaultTimeout = 600 * time.Second + defaultTimeout = 600 * time.Second + defaultURLExpiry = time.Hour ) func main() { @@ -35,6 +37,9 @@ func main() { } store := store.New(dbConn) + // Create a Youtube client + var youtubeClient youtube.Client + // Create an Amazon S3 service s3Client cfg, err := awsconfig.LoadDefaultConfig( ctx, @@ -47,23 +52,39 @@ func main() { log.Fatal(err) } s3Client := s3.NewFromConfig(cfg) - - // Create an Amazon S3 presign client s3PresignClient := s3.NewPresignClient(s3Client) - // Create a Youtube client - var youtubeClient youtube.Client + // Create a logger + logger, err := buildLogger(config) + if err != nil { + log.Fatal(err) + } + defer logger.Sync() - serverOptions := server.Options{ - Config: config, - Timeout: DefaultTimeout, - Store: store, - YoutubeClient: &youtubeClient, - S3API: media.S3API{ + // Create a file store + fileStore := filestore.NewS3FileStore( + filestore.S3API{ S3Client: s3Client, S3PresignClient: s3PresignClient, }, - } + config.S3Bucket, + defaultURLExpiry, + logger.Sugar().Named("filestore"), + ) - log.Fatal(server.Start(serverOptions)) + log.Fatal(server.Start(server.Options{ + Config: config, + Timeout: defaultTimeout, + Store: store, + YoutubeClient: &youtubeClient, + FileStore: fileStore, + Logger: logger, + })) +} + +func buildLogger(c config.Config) (*zap.Logger, error) { + if c.Environment == config.EnvDevelopment { + return zap.NewDevelopment() + } + return zap.NewProduction() } diff --git a/backend/filestore/s3.go b/backend/filestore/s3.go new file mode 100644 index 0000000..6af8eeb --- /dev/null +++ 
b/backend/filestore/s3.go
@@ -0,0 +1,280 @@
+package filestore
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"sort"
+	"sync"
+	"time"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	signerv4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
+	"github.com/aws/aws-sdk-go-v2/service/s3"
+	"github.com/aws/aws-sdk-go-v2/service/s3/types"
+	"go.uber.org/zap"
+)
+
+// S3API provides an API to AWS S3.
+type S3API struct {
+	S3Client
+	S3PresignClient
+}
+
+// S3Client wraps the AWS S3 service client.
+type S3Client interface {
+	GetObject(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)
+	CreateMultipartUpload(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error)
+	UploadPart(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) (*s3.UploadPartOutput, error)
+	AbortMultipartUpload(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error)
+	CompleteMultipartUpload(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error)
+}
+
+// S3PresignClient wraps the AWS S3 Presign client.
+type S3PresignClient interface {
+	PresignGetObject(context.Context, *s3.GetObjectInput, ...func(*s3.PresignOptions)) (*signerv4.PresignedHTTPRequest, error)
+}
+
+// S3FileStore stores files on Amazon S3.
+type S3FileStore struct {
+	s3        S3API
+	bucket    string
+	urlExpiry time.Duration
+	logger    *zap.SugaredLogger
+}
+
+// NewS3FileStore builds a new S3FileStore using the provided configuration.
+func NewS3FileStore(s3API S3API, bucket string, urlExpiry time.Duration, logger *zap.SugaredLogger) *S3FileStore {
+	return &S3FileStore{
+		s3:        s3API,
+		bucket:    bucket,
+		urlExpiry: urlExpiry,
+		logger:    logger,
+	}
+}
+
+// GetObject returns an io.ReadCloser for the object associated with the
+// provided key.
+func (s *S3FileStore) GetObject(ctx context.Context, key string) (io.ReadCloser, error) {
+	input := s3.GetObjectInput{
+		Bucket: aws.String(s.bucket),
+		Key:    aws.String(key),
+	}
+	output, err := s.s3.GetObject(ctx, &input)
+	if err != nil {
+		return nil, fmt.Errorf("error getting object from s3: %v", err)
+	}
+	return output.Body, nil
+}
+
+// GetObjectWithRange returns an io.ReadCloser for the inclusive byte range
+// [start, end] of the object associated with the provided key.
+func (s *S3FileStore) GetObjectWithRange(ctx context.Context, key string, start, end int64) (io.ReadCloser, error) {
+	byteRange := fmt.Sprintf("bytes=%d-%d", start, end)
+	input := s3.GetObjectInput{
+		Bucket: aws.String(s.bucket),
+		Key:    aws.String(key),
+		Range:  aws.String(byteRange),
+	}
+	output, err := s.s3.GetObject(ctx, &input)
+	if err != nil {
+		return nil, fmt.Errorf("error getting object from s3: %v", err)
+	}
+	return output.Body, nil
+}
+
+// GetURL returns a presigned URL pointing to the object associated with the
+// provided key.
+func (s *S3FileStore) GetURL(ctx context.Context, key string) (string, error) {
+	input := s3.GetObjectInput{
+		Bucket: aws.String(s.bucket),
+		Key:    aws.String(key),
+	}
+	request, err := s.s3.PresignGetObject(ctx, &input, s3.WithPresignExpires(s.urlExpiry))
+	if err != nil {
+		return "", fmt.Errorf("error generating presigned URL: %v", err)
+	}
+	return request.URL, nil
+}
+
+// PutObject uploads an object using multipart upload, returning the number of
+// bytes uploaded and any error.
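For orientation, this is roughly how the store above is constructed and used, mirroring the wiring in `cmd/clipper/main.go`. A minimal sketch only: the bucket name, key, and no-op logger are illustrative, and credential/region configuration is elided.

```
package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"git.netflux.io/rob/clipper/filestore"
	awsconfig "github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/s3"
	"go.uber.org/zap"
)

func main() {
	ctx := context.Background()

	cfg, err := awsconfig.LoadDefaultConfig(ctx)
	if err != nil {
		log.Fatal(err)
	}
	s3Client := s3.NewFromConfig(cfg)

	fileStore := filestore.NewS3FileStore(
		filestore.S3API{
			S3Client:        s3Client,
			S3PresignClient: s3.NewPresignClient(s3Client),
		},
		"example-bucket",     // illustrative bucket name
		time.Hour,            // presigned URLs expire after an hour
		zap.NewNop().Sugar(), // swap in a real logger in production
	)

	// Generate a presigned GET URL for a hypothetical key.
	url, err := fileStore.GetURL(ctx, "media_sets/example/audio.opus")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(url)
}
```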
+func (s *S3FileStore) PutObject(ctx context.Context, key string, r io.Reader, contentType string) (int64, error) { + const ( + targetPartSizeBytes = 5 * 1024 * 1024 // 5MB + readBufferSizeBytes = 32_768 // 32Kb + ) + + var uploaded bool + + input := s3.CreateMultipartUploadInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(key), + ContentType: aws.String(contentType), + } + output, err := s.s3.CreateMultipartUpload(ctx, &input) + if err != nil { + return 0, fmt.Errorf("error creating multipart upload: %v", err) + } + + // abort the upload if possible, logging any errors, on exit. + defer func() { + if uploaded { + return + } + input := s3.AbortMultipartUploadInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(key), + UploadId: output.UploadId, + } + + // if the context was cancelled, just use the background context. + ctxToUse := ctx + if ctxToUse.Err() != nil { + ctxToUse = context.Background() + } + + _, deferErr := s.s3.AbortMultipartUpload(ctxToUse, &input) + if deferErr != nil { + s.logger.Errorf("uploader: error aborting upload: %v", deferErr) + } else { + s.logger.Infof("aborted upload, key = %s", key) + } + }() + + type uploadedPart struct { + part types.CompletedPart + size int64 + } + uploadResultChan := make(chan uploadedPart) + uploadErrorChan := make(chan error, 1) + + // uploadPart uploads an individual part. + uploadPart := func(wg *sync.WaitGroup, buf []byte, partNum int32) { + defer wg.Done() + + partLen := int64(len(buf)) + s.logger.With("key", key, "partNum", partNum, "partLen", partLen).Debug("uploading part") + + input := s3.UploadPartInput{ + Body: bytes.NewReader(buf), + Bucket: aws.String(s.bucket), + Key: aws.String(key), + PartNumber: partNum, + UploadId: output.UploadId, + ContentLength: partLen, + } + + output, uploadErr := s.s3.UploadPart(ctx, &input) + if uploadErr != nil { + // TODO: retry on failure + uploadErrorChan <- uploadErr + return + } + + s.logger.With("key", key, "partNum", partNum, "partLen", partLen, "etag", *output.ETag).Debug("uploaded part") + + uploadResultChan <- uploadedPart{ + part: types.CompletedPart{ETag: output.ETag, PartNumber: partNum}, + size: partLen, + } + } + + wgDone := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) // done when the reader goroutine returns + go func() { + wg.Wait() + wgDone <- struct{}{} + }() + + readChan := make(chan error, 1) + + go func() { + defer wg.Done() + + var closing bool + currPart := bytes.NewBuffer(make([]byte, 0, targetPartSizeBytes+readBufferSizeBytes)) + partNum := int32(1) + buf := make([]byte, readBufferSizeBytes) + + for { + n, readErr := r.Read(buf) + if readErr == io.EOF { + closing = true + } else if readErr != nil { + readChan <- readErr + return + } + + _, _ = currPart.Write(buf[:n]) + if closing || currPart.Len() >= targetPartSizeBytes { + part := make([]byte, currPart.Len()) + copy(part, currPart.Bytes()) + currPart.Truncate(0) + + wg.Add(1) + go uploadPart(&wg, part, partNum) + partNum++ + } + + if closing { + return + } + } + }() + + results := make([]uploadedPart, 0, 64) + +outer: + for { + select { + case readErr := <-readChan: + if readErr != io.EOF { + return 0, fmt.Errorf("reader error: %v", readErr) + } + case uploadResult := <-uploadResultChan: + results = append(results, uploadResult) + case uploadErr := <-uploadErrorChan: + return 0, fmt.Errorf("error while uploading part: %v", uploadErr) + case <-ctx.Done(): + return 0, ctx.Err() + case <-wgDone: + break outer + } + } + + if len(results) == 0 { + return 0, errors.New("no parts available to upload") + } 
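The goroutine structure above is a fan-out/fan-in: the reader goroutine slices the stream into roughly 5MB parts, each part is uploaded by its own goroutine, and the select loop collects completed parts until the WaitGroup drains. A stripped-down sketch of the same pattern, with the S3 calls removed, in case the shape is easier to see in isolation:

```
package main

import (
	"fmt"
	"sync"
)

func main() {
	results := make(chan int)
	done := make(chan struct{})

	var wg sync.WaitGroup
	for i := 1; i <= 3; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			results <- n * n // stand-in for uploading one part
		}(i)
	}

	// Close the fan-in once every worker has finished sending.
	go func() {
		wg.Wait()
		close(done)
	}()

	var collected []int
	for {
		select {
		case r := <-results:
			collected = append(collected, r)
		case <-done:
			fmt.Println(collected) // completion order is nondeterministic
			return
		}
	}
}
```

As in PutObject, each send happens before the corresponding `wg.Done`, so `done` can only fire after every result has been received.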
+ + completedParts := make([]types.CompletedPart, 0, 64) + var uploadedBytes int64 + for _, result := range results { + completedParts = append(completedParts, result.part) + uploadedBytes += result.size + } + + // the parts may be out of order, especially with slow network conditions: + sort.Slice(completedParts, func(i, j int) bool { + return completedParts[i].PartNumber < completedParts[j].PartNumber + }) + + completeInput := s3.CompleteMultipartUploadInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(key), + UploadId: output.UploadId, + MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts}, + } + + if _, err = s.s3.CompleteMultipartUpload(ctx, &completeInput); err != nil { + return 0, fmt.Errorf("error completing upload: %v", err) + } + + s.logger.With("key", key, "numParts", len(completedParts), "len", uploadedBytes).Debug("completed upload") + uploaded = true + + return uploadedBytes, nil +} diff --git a/backend/generated/mocks/FileStore.go b/backend/generated/mocks/FileStore.go new file mode 100644 index 0000000..12eb08f --- /dev/null +++ b/backend/generated/mocks/FileStore.go @@ -0,0 +1,103 @@ +// Code generated by mockery v2.9.4. DO NOT EDIT. + +package mocks + +import ( + context "context" + io "io" + + mock "github.com/stretchr/testify/mock" +) + +// FileStore is an autogenerated mock type for the FileStore type +type FileStore struct { + mock.Mock +} + +// GetObject provides a mock function with given fields: _a0, _a1 +func (_m *FileStore) GetObject(_a0 context.Context, _a1 string) (io.ReadCloser, error) { + ret := _m.Called(_a0, _a1) + + var r0 io.ReadCloser + if rf, ok := ret.Get(0).(func(context.Context, string) io.ReadCloser); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.ReadCloser) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetObjectWithRange provides a mock function with given fields: _a0, _a1, _a2, _a3 +func (_m *FileStore) GetObjectWithRange(_a0 context.Context, _a1 string, _a2 int64, _a3 int64) (io.ReadCloser, error) { + ret := _m.Called(_a0, _a1, _a2, _a3) + + var r0 io.ReadCloser + if rf, ok := ret.Get(0).(func(context.Context, string, int64, int64) io.ReadCloser); ok { + r0 = rf(_a0, _a1, _a2, _a3) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.ReadCloser) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, int64, int64) error); ok { + r1 = rf(_a0, _a1, _a2, _a3) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetURL provides a mock function with given fields: _a0, _a1 +func (_m *FileStore) GetURL(_a0 context.Context, _a1 string) (string, error) { + ret := _m.Called(_a0, _a1) + + var r0 string + if rf, ok := ret.Get(0).(func(context.Context, string) string); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Get(0).(string) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PutObject provides a mock function with given fields: _a0, _a1, _a2, _a3 +func (_m *FileStore) PutObject(_a0 context.Context, _a1 string, _a2 io.Reader, _a3 string) (int64, error) { + ret := _m.Called(_a0, _a1, _a2, _a3) + + var r0 int64 + if rf, ok := ret.Get(0).(func(context.Context, string, io.Reader, string) int64); ok { + r0 = rf(_a0, _a1, _a2, _a3) + } else { + r0 = ret.Get(0).(int64) + } + + var r1 error + if rf, ok := 
ret.Get(1).(func(context.Context, string, io.Reader, string) error); ok { + r1 = rf(_a0, _a1, _a2, _a3) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/backend/generated/mocks/S3Client.go b/backend/generated/mocks/S3Client.go deleted file mode 100644 index 7bcf177..0000000 --- a/backend/generated/mocks/S3Client.go +++ /dev/null @@ -1,166 +0,0 @@ -// Code generated by mockery v2.9.4. DO NOT EDIT. - -package mocks - -import ( - context "context" - - mock "github.com/stretchr/testify/mock" - - s3 "github.com/aws/aws-sdk-go-v2/service/s3" -) - -// S3Client is an autogenerated mock type for the S3Client type -type S3Client struct { - mock.Mock -} - -// AbortMultipartUpload provides a mock function with given fields: _a0, _a1, _a2 -func (_m *S3Client) AbortMultipartUpload(_a0 context.Context, _a1 *s3.AbortMultipartUploadInput, _a2 ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error) { - _va := make([]interface{}, len(_a2)) - for _i := range _a2 { - _va[_i] = _a2[_i] - } - var _ca []interface{} - _ca = append(_ca, _a0, _a1) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) - - var r0 *s3.AbortMultipartUploadOutput - if rf, ok := ret.Get(0).(func(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) *s3.AbortMultipartUploadOutput); ok { - r0 = rf(_a0, _a1, _a2...) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*s3.AbortMultipartUploadOutput) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) error); ok { - r1 = rf(_a0, _a1, _a2...) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// CompleteMultipartUpload provides a mock function with given fields: _a0, _a1, _a2 -func (_m *S3Client) CompleteMultipartUpload(_a0 context.Context, _a1 *s3.CompleteMultipartUploadInput, _a2 ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error) { - _va := make([]interface{}, len(_a2)) - for _i := range _a2 { - _va[_i] = _a2[_i] - } - var _ca []interface{} - _ca = append(_ca, _a0, _a1) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) - - var r0 *s3.CompleteMultipartUploadOutput - if rf, ok := ret.Get(0).(func(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) *s3.CompleteMultipartUploadOutput); ok { - r0 = rf(_a0, _a1, _a2...) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*s3.CompleteMultipartUploadOutput) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) error); ok { - r1 = rf(_a0, _a1, _a2...) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// CreateMultipartUpload provides a mock function with given fields: _a0, _a1, _a2 -func (_m *S3Client) CreateMultipartUpload(_a0 context.Context, _a1 *s3.CreateMultipartUploadInput, _a2 ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error) { - _va := make([]interface{}, len(_a2)) - for _i := range _a2 { - _va[_i] = _a2[_i] - } - var _ca []interface{} - _ca = append(_ca, _a0, _a1) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) - - var r0 *s3.CreateMultipartUploadOutput - if rf, ok := ret.Get(0).(func(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) *s3.CreateMultipartUploadOutput); ok { - r0 = rf(_a0, _a1, _a2...) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*s3.CreateMultipartUploadOutput) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) error); ok { - r1 = rf(_a0, _a1, _a2...) 
- } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetObject provides a mock function with given fields: _a0, _a1, _a2 -func (_m *S3Client) GetObject(_a0 context.Context, _a1 *s3.GetObjectInput, _a2 ...func(*s3.Options)) (*s3.GetObjectOutput, error) { - _va := make([]interface{}, len(_a2)) - for _i := range _a2 { - _va[_i] = _a2[_i] - } - var _ca []interface{} - _ca = append(_ca, _a0, _a1) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) - - var r0 *s3.GetObjectOutput - if rf, ok := ret.Get(0).(func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) *s3.GetObjectOutput); ok { - r0 = rf(_a0, _a1, _a2...) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*s3.GetObjectOutput) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) error); ok { - r1 = rf(_a0, _a1, _a2...) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// UploadPart provides a mock function with given fields: _a0, _a1, _a2 -func (_m *S3Client) UploadPart(_a0 context.Context, _a1 *s3.UploadPartInput, _a2 ...func(*s3.Options)) (*s3.UploadPartOutput, error) { - _va := make([]interface{}, len(_a2)) - for _i := range _a2 { - _va[_i] = _a2[_i] - } - var _ca []interface{} - _ca = append(_ca, _a0, _a1) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) - - var r0 *s3.UploadPartOutput - if rf, ok := ret.Get(0).(func(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) *s3.UploadPartOutput); ok { - r0 = rf(_a0, _a1, _a2...) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*s3.UploadPartOutput) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) error); ok { - r1 = rf(_a0, _a1, _a2...) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} diff --git a/backend/generated/mocks/S3PresignClient.go b/backend/generated/mocks/S3PresignClient.go deleted file mode 100644 index 2a35f49..0000000 --- a/backend/generated/mocks/S3PresignClient.go +++ /dev/null @@ -1,48 +0,0 @@ -// Code generated by mockery v2.9.4. DO NOT EDIT. - -package mocks - -import ( - context "context" - - mock "github.com/stretchr/testify/mock" - - s3 "github.com/aws/aws-sdk-go-v2/service/s3" - - v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" -) - -// S3PresignClient is an autogenerated mock type for the S3PresignClient type -type S3PresignClient struct { - mock.Mock -} - -// PresignGetObject provides a mock function with given fields: _a0, _a1, _a2 -func (_m *S3PresignClient) PresignGetObject(_a0 context.Context, _a1 *s3.GetObjectInput, _a2 ...func(*s3.PresignOptions)) (*v4.PresignedHTTPRequest, error) { - _va := make([]interface{}, len(_a2)) - for _i := range _a2 { - _va[_i] = _a2[_i] - } - var _ca []interface{} - _ca = append(_ca, _a0, _a1) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) - - var r0 *v4.PresignedHTTPRequest - if rf, ok := ret.Get(0).(func(context.Context, *s3.GetObjectInput, ...func(*s3.PresignOptions)) *v4.PresignedHTTPRequest); ok { - r0 = rf(_a0, _a1, _a2...) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*v4.PresignedHTTPRequest) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *s3.GetObjectInput, ...func(*s3.PresignOptions)) error); ok { - r1 = rf(_a0, _a1, _a2...) 
- } else { - r1 = ret.Error(1) - } - - return r0, r1 -} diff --git a/backend/media/get_audio.go b/backend/media/get_audio.go index d26ea06..7f5ff36 100644 --- a/backend/media/get_audio.go +++ b/backend/media/get_audio.go @@ -13,8 +13,6 @@ import ( "git.netflux.io/rob/clipper/config" "git.netflux.io/rob/clipper/generated/store" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/s3" "go.uber.org/zap" ) @@ -31,27 +29,27 @@ type GetAudioProgressReader interface { // audioGetter manages getting and processing audio from Youtube. type audioGetter struct { - store Store - youtube YoutubeClient - s3API S3API - config config.Config - logger *zap.SugaredLogger + store Store + youtube YoutubeClient + fileStore FileStore + config config.Config + logger *zap.SugaredLogger } // newAudioGetter returns a new audioGetter. -func newAudioGetter(store Store, youtube YoutubeClient, s3API S3API, config config.Config, logger *zap.SugaredLogger) *audioGetter { +func newAudioGetter(store Store, youtube YoutubeClient, fileStore FileStore, config config.Config, logger *zap.SugaredLogger) *audioGetter { return &audioGetter{ - store: store, - youtube: youtube, - s3API: s3API, - config: config, - logger: logger, + store: store, + youtube: youtube, + fileStore: fileStore, + config: config, + logger: logger, } } -// GetAudio gets the audio, processes it and uploads it to S3. It returns a -// GetAudioProgressReader that can be used to poll progress reports and audio -// peaks. +// GetAudio gets the audio, processes it and uploads it to a file store. It +// returns a GetAudioProgressReader that can be used to poll progress reports +// and audio peaks. // // TODO: accept domain object instead func (g *audioGetter) GetAudio(ctx context.Context, mediaSet store.MediaSet, numBins int) (GetAudioProgressReader, error) { @@ -114,33 +112,28 @@ func (s *audioGetterState) getAudio(ctx context.Context, r io.ReadCloser, mediaS wg.Add(2) // Upload the encoded audio. + // TODO: fix error shadowing in these two goroutines. 
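The TODO above refers to the enclosing `err` and `presignedAudioURL` variables being written from concurrently running goroutines. One possible restructuring, sketched here with `golang.org/x/sync/errgroup` and hypothetical upload callbacks (this is not what the patch does), keeps each error in its own closure and cancels the shared context on the first failure:

```
package main

import (
	"context"
	"errors"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// uploadBoth runs two uploads concurrently, returning the first error.
func uploadBoth(ctx context.Context, encoded, raw func(context.Context) error) error {
	g, ctx := errgroup.WithContext(ctx)
	g.Go(func() error { return encoded(ctx) }) // e.g. upload encoded audio
	g.Go(func() error { return raw(ctx) })     // e.g. upload raw audio
	return g.Wait()
}

func main() {
	err := uploadBoth(context.Background(),
		func(context.Context) error { return nil },
		func(context.Context) error { return errors.New("raw upload failed") },
	)
	fmt.Println(err) // raw upload failed
}
```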
go func() { defer wg.Done() - // TODO: use mediaSet func to fetch s3Key - s3Key := fmt.Sprintf("media_sets/%s/audio.opus", mediaSet.ID) + // TODO: use mediaSet func to fetch key + key := fmt.Sprintf("media_sets/%s/audio.opus", mediaSet.ID) - uploader := newMultipartUploader(s.s3API, s.logger) - _, encErr := uploader.Upload(ctx, pr, s.config.S3Bucket, s3Key, "audio/opus") + _, encErr := s.fileStore.PutObject(ctx, key, pr, "audio/opus") if encErr != nil { s.CloseWithError(fmt.Errorf("error uploading encoded audio: %v", encErr)) return } - input := s3.GetObjectInput{ - Bucket: aws.String(s.config.S3Bucket), - Key: aws.String(s3Key), - } - request, err := s.s3API.PresignGetObject(ctx, &input, s3.WithPresignExpires(getAudioExpiresIn)) + presignedAudioURL, err = s.fileStore.GetURL(ctx, key) if err != nil { s.CloseWithError(fmt.Errorf("error generating presigned URL: %v", err)) } - presignedAudioURL = request.URL if _, err = s.store.SetEncodedAudioUploaded(ctx, store.SetEncodedAudioUploadedParams{ ID: mediaSet.ID, AudioEncodedS3Bucket: sqlString(s.config.S3Bucket), - AudioEncodedS3Key: sqlString(s3Key), + AudioEncodedS3Key: sqlString(key), }); err != nil { s.CloseWithError(fmt.Errorf("error setting encoded audio uploaded: %v", err)) } @@ -150,12 +143,11 @@ func (s *audioGetterState) getAudio(ctx context.Context, r io.ReadCloser, mediaS go func() { defer wg.Done() - // TODO: use mediaSet func to fetch s3Key - s3Key := fmt.Sprintf("media_sets/%s/audio.raw", mediaSet.ID) + // TODO: use mediaSet func to fetch key + key := fmt.Sprintf("media_sets/%s/audio.raw", mediaSet.ID) teeReader := io.TeeReader(stdout, s) - uploader := newMultipartUploader(s.s3API, s.logger) - bytesUploaded, rawErr := uploader.Upload(ctx, teeReader, s.config.S3Bucket, s3Key, rawAudioMimeType) + bytesUploaded, rawErr := s.fileStore.PutObject(ctx, key, teeReader, rawAudioMimeType) if rawErr != nil { s.CloseWithError(fmt.Errorf("error uploading raw audio: %v", rawErr)) return @@ -164,7 +156,7 @@ func (s *audioGetterState) getAudio(ctx context.Context, r io.ReadCloser, mediaS if _, err = s.store.SetRawAudioUploaded(ctx, store.SetRawAudioUploadedParams{ ID: mediaSet.ID, AudioRawS3Bucket: sqlString(s.config.S3Bucket), - AudioRawS3Key: sqlString(s3Key), + AudioRawS3Key: sqlString(key), AudioFrames: sqlInt64(bytesUploaded / SizeOfInt16 / int64(mediaSet.AudioChannels)), }); err != nil { s.CloseWithError(fmt.Errorf("error setting raw audio uploaded: %v", err)) diff --git a/backend/media/get_video.go b/backend/media/get_video.go index 7091284..9fb7ff2 100644 --- a/backend/media/get_video.go +++ b/backend/media/get_video.go @@ -6,8 +6,6 @@ import ( "io" "git.netflux.io/rob/clipper/generated/store" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/google/uuid" "go.uber.org/zap" ) @@ -24,9 +22,9 @@ type GetVideoProgressReader interface { } type videoGetter struct { - s3 S3API - store Store - logger *zap.SugaredLogger + store Store + fileStore FileStore + logger *zap.SugaredLogger } type videoGetterState struct { @@ -41,21 +39,20 @@ type videoGetterState struct { errorChan chan error } -func newVideoGetter(s3 S3API, store Store, logger *zap.SugaredLogger) *videoGetter { - return &videoGetter{s3: s3, store: store, logger: logger} +func newVideoGetter(store Store, fileStore FileStore, logger *zap.SugaredLogger) *videoGetter { + return &videoGetter{store: store, fileStore: fileStore, logger: logger} } -// GetVideo gets video from Youtube and uploads it to S3 using the specified -// bucket, key and content 
type. The returned reader must have its Next() +// GetVideo gets video from Youtube and uploads it to a filestore using the +// specified key and content type. The returned reader must have its Next() // method called until error = io.EOF, otherwise a deadlock or other resource // leakage is likely. -func (g *videoGetter) GetVideo(ctx context.Context, r io.Reader, exp int64, mediaSetID uuid.UUID, bucket, key, contentType string) (GetVideoProgressReader, error) { +func (g *videoGetter) GetVideo(ctx context.Context, r io.Reader, exp int64, mediaSetID uuid.UUID, key, contentType string) (GetVideoProgressReader, error) { s := &videoGetterState{ videoGetter: g, r: newProgressReader(r, "video", exp, g.logger), exp: exp, mediaSetID: mediaSetID, - bucket: bucket, key: key, contentType: contentType, progressChan: make(chan GetVideoProgress), @@ -69,7 +66,7 @@ func (g *videoGetter) GetVideo(ctx context.Context, r io.Reader, exp int64, medi } // Write implements io.Writer. It is copied that same data that is written to -// S3, to implement progress tracking. +// the file store, to implement progress tracking. func (s *videoGetterState) Write(p []byte) (int, error) { s.count += int64(len(p)) pc := (float32(s.count) / float32(s.exp)) * 100 @@ -78,24 +75,18 @@ func (s *videoGetterState) Write(p []byte) (int, error) { } func (s *videoGetterState) getVideo(ctx context.Context) { - uploader := newMultipartUploader(s.s3, s.logger) teeReader := io.TeeReader(s.r, s) - _, err := uploader.Upload(ctx, teeReader, s.bucket, s.key, s.contentType) + _, err := s.fileStore.PutObject(ctx, s.key, teeReader, s.contentType) if err != nil { - s.errorChan <- fmt.Errorf("error uploading to S3: %v", err) + s.errorChan <- fmt.Errorf("error uploading to file store: %v", err) return } - input := s3.GetObjectInput{ - Bucket: aws.String(s.bucket), - Key: aws.String(s.key), - } - request, err := s.s3.PresignGetObject(ctx, &input, s3.WithPresignExpires(getVideoExpiresIn)) + s.url, err = s.fileStore.GetURL(ctx, s.key) if err != nil { - s.errorChan <- fmt.Errorf("error generating presigned URL: %v", err) + s.errorChan <- fmt.Errorf("error getting object URL: %v", err) } - s.url = request.URL storeParams := store.SetVideoUploadedParams{ ID: s.mediaSetID, diff --git a/backend/media/service.go b/backend/media/service.go index 5e6e241..c746399 100644 --- a/backend/media/service.go +++ b/backend/media/service.go @@ -14,8 +14,6 @@ import ( "git.netflux.io/rob/clipper/config" "git.netflux.io/rob/clipper/generated/store" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/google/uuid" "github.com/jackc/pgx/v4" youtubev2 "github.com/kkdai/youtube/v2" @@ -41,27 +39,27 @@ const ( // MediaSetService exposes logical flows handling MediaSets. 
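Because MediaSetService below now depends only on the FileStore interface (see types.go later in this patch), any conforming implementation can stand in for S3: the generated mock in tests, or, purely as a hypothetical illustration, an in-memory store such as this sketch (not part of the patch, and not safe for concurrent use):

```
package filestore

import (
	"bytes"
	"context"
	"fmt"
	"io"
)

// MemoryFileStore is a hypothetical in-memory FileStore, shown only to
// illustrate the interface.
type MemoryFileStore struct {
	objects map[string][]byte
}

func NewMemoryFileStore() *MemoryFileStore {
	return &MemoryFileStore{objects: make(map[string][]byte)}
}

func (m *MemoryFileStore) GetObject(_ context.Context, key string) (io.ReadCloser, error) {
	data, ok := m.objects[key]
	if !ok {
		return nil, fmt.Errorf("no such key: %s", key)
	}
	return io.NopCloser(bytes.NewReader(data)), nil
}

// GetObjectWithRange mirrors the inclusive semantics of an HTTP Range header.
func (m *MemoryFileStore) GetObjectWithRange(_ context.Context, key string, start, end int64) (io.ReadCloser, error) {
	data, ok := m.objects[key]
	if !ok {
		return nil, fmt.Errorf("no such key: %s", key)
	}
	if start < 0 || start > end || start >= int64(len(data)) {
		return nil, fmt.Errorf("invalid range %d-%d", start, end)
	}
	if end >= int64(len(data)) {
		end = int64(len(data)) - 1
	}
	return io.NopCloser(bytes.NewReader(data[start : end+1])), nil
}

func (m *MemoryFileStore) GetURL(_ context.Context, key string) (string, error) {
	return "memory://" + key, nil // placeholder, not a real presigned URL
}

func (m *MemoryFileStore) PutObject(_ context.Context, key string, r io.Reader, _ string) (int64, error) {
	data, err := io.ReadAll(r)
	if err != nil {
		return 0, err
	}
	m.objects[key] = data
	return int64(len(data)), nil
}
```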
type MediaSetService struct { - store Store - youtube YoutubeClient - s3 S3API - config config.Config - logger *zap.SugaredLogger + store Store + youtube YoutubeClient + fileStore FileStore + config config.Config + logger *zap.SugaredLogger } -func NewMediaSetService(store Store, youtubeClient YoutubeClient, s3API S3API, config config.Config, logger *zap.Logger) *MediaSetService { +func NewMediaSetService(store Store, youtubeClient YoutubeClient, fileStore FileStore, config config.Config, logger *zap.SugaredLogger) *MediaSetService { return &MediaSetService{ - store: store, - youtube: youtubeClient, - s3: s3API, - config: config, - logger: logger.Sugar(), + store: store, + youtube: youtubeClient, + fileStore: fileStore, + config: config, + logger: logger, } } // Get fetches the metadata for a given MediaSet source. If it does not exist // in the local DB, it will attempt to create it. After the resource has been // created, other endpoints (e.g. GetAudio) can be called to fetch media from -// Youtube and store it in S3. +// Youtube and store it in a file store. func (s *MediaSetService) Get(ctx context.Context, youtubeID string) (*MediaSet, error) { var ( mediaSet *MediaSet @@ -220,15 +218,11 @@ func (s *MediaSetService) GetVideo(ctx context.Context, id uuid.UUID) (GetVideoP } if mediaSet.VideoS3UploadedAt.Valid { - input := s3.GetObjectInput{ - Bucket: aws.String(s.config.S3Bucket), - Key: aws.String(mediaSet.VideoS3Key.String), + url, err := s.fileStore.GetURL(ctx, mediaSet.VideoS3Key.String) + if err != nil { + return nil, fmt.Errorf("error generating presigned URL: %v", err) } - request, signErr := s.s3.PresignGetObject(ctx, &input, s3.WithPresignExpires(getVideoExpiresIn)) - if signErr != nil { - return nil, fmt.Errorf("error generating presigned URL: %v", signErr) - } - videoGetter := videoGetterDownloaded(request.URL) + videoGetter := videoGetterDownloaded(url) return &videoGetter, nil } @@ -247,17 +241,16 @@ func (s *MediaSetService) GetVideo(ctx context.Context, id uuid.UUID) (GetVideoP return nil, fmt.Errorf("error fetching stream: %v", err) } - // TODO: use mediaSet func to fetch s3Key - s3Key := fmt.Sprintf("media_sets/%s/video.mp4", mediaSet.ID) + // TODO: use mediaSet func to fetch videoKey + videoKey := fmt.Sprintf("media_sets/%s/video.mp4", mediaSet.ID) - videoGetter := newVideoGetter(s.s3, s.store, s.logger) + videoGetter := newVideoGetter(s.store, s.fileStore, s.logger) return videoGetter.GetVideo( ctx, stream, format.ContentLength, mediaSet.ID, - s.config.S3Bucket, - s3Key, + videoKey, format.MimeType, ) } @@ -273,25 +266,21 @@ func (s *MediaSetService) GetAudio(ctx context.Context, id uuid.UUID, numBins in // Otherwise, we cannot return both peaks and a presigned URL for use by the // player. 
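As context for the branch below and for GetAudioSegment later in this file: positions in the raw audio are converted between frames and bytes as bytes = frames * channels * SizeOfInt16. A tiny sketch of that arithmetic (the constant and function names are illustrative, not from this repo):

```
package main

import "fmt"

const sizeOfInt16 = 2 // bytes per 16-bit PCM sample

// byteRange converts a frame range into the byte offsets passed to
// FileStore.GetObjectWithRange.
func byteRange(startFrame, endFrame, channels int64) (startByte, endByte int64) {
	return startFrame * channels * sizeOfInt16, endFrame * channels * sizeOfInt16
}

func main() {
	start, end := byteRange(0, 1_000, 2) // 1,000 stereo frames
	fmt.Println(start, end)              // 0 4000
}
```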
if mediaSet.AudioRawS3UploadedAt.Valid && mediaSet.AudioEncodedS3UploadedAt.Valid { - return s.getAudioFromS3(ctx, mediaSet, numBins) + return s.getAudioFromFileStore(ctx, mediaSet, numBins) } return s.getAudioFromYoutube(ctx, mediaSet, numBins) } func (s *MediaSetService) getAudioFromYoutube(ctx context.Context, mediaSet store.MediaSet, numBins int) (GetAudioProgressReader, error) { - audioGetter := newAudioGetter(s.store, s.youtube, s.s3, s.config, s.logger) + audioGetter := newAudioGetter(s.store, s.youtube, s.fileStore, s.config, s.logger) return audioGetter.GetAudio(ctx, mediaSet, numBins) } -func (s *MediaSetService) getAudioFromS3(ctx context.Context, mediaSet store.MediaSet, numBins int) (GetAudioProgressReader, error) { - input := s3.GetObjectInput{ - Bucket: aws.String(mediaSet.AudioRawS3Bucket.String), - Key: aws.String(mediaSet.AudioRawS3Key.String), - } - output, err := s.s3.GetObject(ctx, &input) +func (s *MediaSetService) getAudioFromFileStore(ctx context.Context, mediaSet store.MediaSet, numBins int) (GetAudioProgressReader, error) { + object, err := s.fileStore.GetObject(ctx, mediaSet.AudioRawS3Key.String) if err != nil { - return nil, fmt.Errorf("error getting object from s3: %v", err) + return nil, fmt.Errorf("error getting object from file store: %v", err) } getAudioProgressReader, err := newGetAudioProgressReader( @@ -303,10 +292,10 @@ func (s *MediaSetService) getAudioFromS3(ctx context.Context, mediaSet store.Med return nil, fmt.Errorf("error creating audio reader: %v", err) } - state := getAudioFromS3State{ + state := getAudioFromFileStoreState{ getAudioProgressReader: getAudioProgressReader, - s3Reader: NewModuloBufReader(output.Body, int(mediaSet.AudioChannels)*SizeOfInt16), - s3API: s.s3, + reader: NewModuloBufReader(object, int(mediaSet.AudioChannels)*SizeOfInt16), + fileStore: s.fileStore, config: s.config, logger: s.logger, } @@ -315,21 +304,21 @@ func (s *MediaSetService) getAudioFromS3(ctx context.Context, mediaSet store.Med return &state, nil } -type getAudioFromS3State struct { +type getAudioFromFileStoreState struct { *getAudioProgressReader - s3Reader io.ReadCloser - s3API S3API - config config.Config - logger *zap.SugaredLogger + reader io.ReadCloser + fileStore FileStore + config config.Config + logger *zap.SugaredLogger } -func (s *getAudioFromS3State) run(ctx context.Context, mediaSet store.MediaSet) { +func (s *getAudioFromFileStoreState) run(ctx context.Context, mediaSet store.MediaSet) { done := make(chan error) var err error go func() { - _, copyErr := io.Copy(s, s.s3Reader) + _, copyErr := io.Copy(s, s.reader) done <- copyErr }() @@ -344,29 +333,25 @@ outer: } } - if readerErr := s.s3Reader.Close(); readerErr != nil { + if readerErr := s.reader.Close(); readerErr != nil { if err == nil { err = readerErr } } if err != nil { - s.logger.Errorf("getAudioFromS3State: error closing s3 reader: %v", err) + s.logger.Errorf("getAudioFromFileStoreState: error closing reader: %v", err) s.CloseWithError(err) return } - input := s3.GetObjectInput{ - Bucket: aws.String(s.config.S3Bucket), - Key: aws.String(mediaSet.AudioEncodedS3Key.String), - } - request, err := s.s3API.PresignGetObject(ctx, &input, s3.WithPresignExpires(getAudioExpiresIn)) + url, err := s.fileStore.GetURL(ctx, mediaSet.AudioEncodedS3Key.String) if err != nil { - s.CloseWithError(fmt.Errorf("error generating presigned URL: %v", err)) + s.CloseWithError(fmt.Errorf("error generating object URL: %v", err)) } - if iterErr := s.Close(request.URL); iterErr != nil { - 
s.logger.Errorf("getAudioFromS3State: error closing progress iterator: %v", iterErr) + if iterErr := s.Close(url); iterErr != nil { + s.logger.Errorf("getAudioFromFileStoreState: error closing progress iterator: %v", iterErr) } } @@ -381,25 +366,20 @@ func (s *MediaSetService) GetAudioSegment(ctx context.Context, id uuid.UUID, sta return nil, fmt.Errorf("error getting media set: %v", err) } - byteRange := fmt.Sprintf( - "bytes=%d-%d", + object, err := s.fileStore.GetObjectWithRange( + ctx, + mediaSet.AudioRawS3Key.String, startFrame*int64(mediaSet.AudioChannels)*SizeOfInt16, endFrame*int64(mediaSet.AudioChannels)*SizeOfInt16, ) - input := s3.GetObjectInput{ - Bucket: aws.String(mediaSet.AudioRawS3Bucket.String), - Key: aws.String(mediaSet.AudioRawS3Key.String), - Range: aws.String(byteRange), - } - output, err := s.s3.GetObject(ctx, &input) if err != nil { - return nil, fmt.Errorf("error getting object from s3: %v", err) + return nil, fmt.Errorf("error getting object from file store: %v", err) } - defer output.Body.Close() + defer object.Close() const readBufSizeBytes = 8_192 channels := int(mediaSet.AudioChannels) - modReader := NewModuloBufReader(output.Body, channels*SizeOfInt16) + modReader := NewModuloBufReader(object, channels*SizeOfInt16) readBuf := make([]byte, readBufSizeBytes) peaks := make([]int16, channels*numBins) totalFrames := endFrame - startFrame @@ -454,7 +434,7 @@ func (s *MediaSetService) GetAudioSegment(ctx context.Context, id uuid.UUID, sta } if bytesRead < bytesExpected { - s.logger.With("startFrame", startFrame, "endFrame", endFrame, "got", bytesRead, "want", bytesExpected, "key", mediaSet.AudioRawS3Key.String).Warn("short read from S3") + s.logger.With("startFrame", startFrame, "endFrame", endFrame, "got", bytesRead, "want", bytesExpected, "key", mediaSet.AudioRawS3Key.String).Warn("short read from file store") } return peaks, nil @@ -521,26 +501,22 @@ func (s *MediaSetService) GetVideoThumbnail(ctx context.Context, id uuid.UUID) ( } if mediaSet.VideoThumbnailS3UploadedAt.Valid { - return s.getThumbnailFromS3(ctx, mediaSet) + return s.getThumbnailFromFileStore(ctx, mediaSet) } return s.getThumbnailFromYoutube(ctx, mediaSet) } -func (s *MediaSetService) getThumbnailFromS3(ctx context.Context, mediaSet store.MediaSet) (VideoThumbnail, error) { - input := s3.GetObjectInput{ - Bucket: aws.String(mediaSet.VideoThumbnailS3Bucket.String), - Key: aws.String(mediaSet.VideoThumbnailS3Key.String), - } - output, err := s.s3.GetObject(ctx, &input) +func (s *MediaSetService) getThumbnailFromFileStore(ctx context.Context, mediaSet store.MediaSet) (VideoThumbnail, error) { + object, err := s.fileStore.GetObject(ctx, mediaSet.VideoThumbnailS3Key.String) if err != nil { - return VideoThumbnail{}, fmt.Errorf("error fetching thumbnail from s3: %v", err) + return VideoThumbnail{}, fmt.Errorf("error fetching thumbnail from file store: %v", err) } - defer output.Body.Close() + defer object.Close() - imageData, err := io.ReadAll(output.Body) + imageData, err := io.ReadAll(object) if err != nil { - return VideoThumbnail{}, fmt.Errorf("error reading thumbnail from s3: %v", err) + return VideoThumbnail{}, fmt.Errorf("error reading thumbnail from file store: %v", err) } return VideoThumbnail{ @@ -575,13 +551,11 @@ func (s *MediaSetService) getThumbnailFromYoutube(ctx context.Context, mediaSet return VideoThumbnail{}, fmt.Errorf("error reading thumbnail: %v", err) } - // TODO: use mediaSet func to fetch s3Key - s3Key := fmt.Sprintf("media_sets/%s/thumbnail.jpg", mediaSet.ID) + // TODO: use 
mediaSet func to fetch key + thumbnailKey := fmt.Sprintf("media_sets/%s/thumbnail.jpg", mediaSet.ID) - uploader := newMultipartUploader(s.s3, s.logger) const mimeType = "application/jpeg" - - _, err = uploader.Upload(ctx, bytes.NewReader(imageData), s.config.S3Bucket, s3Key, mimeType) + _, err = s.fileStore.PutObject(ctx, thumbnailKey, bytes.NewReader(imageData), mimeType) if err != nil { return VideoThumbnail{}, fmt.Errorf("error uploading thumbnail: %v", err) } @@ -590,7 +564,7 @@ func (s *MediaSetService) getThumbnailFromYoutube(ctx context.Context, mediaSet ID: mediaSet.ID, VideoThumbnailMimeType: sqlString(mimeType), VideoThumbnailS3Bucket: sqlString(s.config.S3Bucket), - VideoThumbnailS3Key: sqlString(s3Key), + VideoThumbnailS3Key: sqlString(thumbnailKey), VideoThumbnailWidth: sqlInt32(int32(thumbnail.Width)), VideoThumbnailHeight: sqlInt32(int32(thumbnail.Height)), } @@ -614,7 +588,7 @@ func newProgressReader(reader io.Reader, label string, exp int64, logger *zap.Su return &progressReader{ Reader: reader, exp: exp, - logger: logger.Named(fmt.Sprintf("ProgressReader %s", label)), + logger: logger.Named(label), } } diff --git a/backend/media/service_test.go b/backend/media/service_test.go index 89b7033..b496181 100644 --- a/backend/media/service_test.go +++ b/backend/media/service_test.go @@ -1,8 +1,9 @@ package media_test import ( + "bytes" "context" - "fmt" + "database/sql" "io" "os" "testing" @@ -11,7 +12,6 @@ import ( "git.netflux.io/rob/clipper/generated/mocks" "git.netflux.io/rob/clipper/generated/store" "git.netflux.io/rob/clipper/media" - "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -84,31 +84,34 @@ func TestGetAudioSegment(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - expectedBytes := (tc.endFrame - tc.startFrame) * int64(tc.channels) * media.SizeOfInt16 + startByte := tc.startFrame * int64(tc.channels) * media.SizeOfInt16 + endByte := tc.endFrame * int64(tc.channels) * media.SizeOfInt16 + expectedBytes := endByte - startByte audioFile, err := os.Open(tc.fixturePath) require.NoError(t, err) defer audioFile.Close() audioData := io.NopCloser(io.LimitReader(audioFile, int64(expectedBytes))) - mediaSetID := uuid.New() - mediaSet := store.MediaSet{ID: mediaSetID, AudioChannels: tc.channels} + mediaSet := store.MediaSet{ + ID: uuid.New(), + AudioChannels: tc.channels, + AudioRawS3Key: sql.NullString{String: "foo", Valid: true}, + } // store is passed the mediaSetID and returns a mediaSet store := &mocks.Store{} - store.On("GetMediaSet", mock.Anything, mediaSetID).Return(mediaSet, nil) + store.On("GetMediaSet", mock.Anything, mediaSet.ID).Return(mediaSet, nil) defer store.AssertExpectations(t) - // S3 is passed the expected byte range, and returns an io.Reader - s3Client := &mocks.S3Client{} - s3Client.On("GetObject", mock.Anything, mock.MatchedBy(func(input *s3.GetObjectInput) bool { - return *input.Range == fmt.Sprintf("bytes=0-%d", expectedBytes) - })).Return(&s3.GetObjectOutput{Body: audioData, ContentLength: tc.fixtureLen}, nil) - defer s3Client.AssertExpectations(t) - s3API := media.S3API{S3Client: s3Client, S3PresignClient: &mocks.S3PresignClient{}} + // fileStore is passed the expected byte range, and returns an io.Reader + fileStore := &mocks.FileStore{} + fileStore. + On("GetObjectWithRange", mock.Anything, "foo", startByte, endByte). 
+ Return(audioData, nil) - service := media.NewMediaSetService(store, nil, s3API, config.Config{}, zap.NewNop()) - peaks, err := service.GetAudioSegment(context.Background(), mediaSetID, tc.startFrame, tc.endFrame, tc.numBins) + service := media.NewMediaSetService(store, nil, fileStore, config.Config{}, zap.NewNop().Sugar()) + peaks, err := service.GetAudioSegment(context.Background(), mediaSet.ID, tc.startFrame, tc.endFrame, tc.numBins) if tc.wantErr == "" { assert.NoError(t, err) @@ -130,35 +133,31 @@ func BenchmarkGetAudioSegment(b *testing.B) { numBins = 2000 ) + audioFile, err := os.Open(fixturePath) + require.NoError(b, err) + audioData, err := io.ReadAll(audioFile) + require.NoError(b, err) + + mediaSetID := uuid.New() + mediaSet := store.MediaSet{ID: mediaSetID, AudioChannels: channels} + + store := &mocks.Store{} + store.On("GetMediaSet", mock.Anything, mediaSetID).Return(mediaSet, nil) + for n := 0; n < b.N; n++ { + // recreate the reader on each iteration b.StopTimer() + readCloser := io.NopCloser(bytes.NewReader(audioData)) + fileStore := &mocks.FileStore{} + fileStore. + On("GetObjectWithRange", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(readCloser, nil) - expectedBytes := (endFrame - startFrame) * int64(channels) * media.SizeOfInt16 - - audioFile, err := os.Open(fixturePath) - require.NoError(b, err) - audioData := io.NopCloser(io.LimitReader(audioFile, int64(expectedBytes))) - - mediaSetID := uuid.New() - mediaSet := store.MediaSet{ID: mediaSetID, AudioChannels: channels} - - store := &mocks.Store{} - store.On("GetMediaSet", mock.Anything, mediaSetID).Return(mediaSet, nil) - - s3Client := &mocks.S3Client{} - s3Client. - On("GetObject", mock.Anything, mock.Anything). - Return(&s3.GetObjectOutput{Body: audioData, ContentLength: fixtureLen}, nil) - s3API := media.S3API{S3Client: s3Client, S3PresignClient: &mocks.S3PresignClient{}} - - service := media.NewMediaSetService(store, nil, s3API, config.Config{}, zap.NewNop()) - + service := media.NewMediaSetService(store, nil, fileStore, config.Config{}, zap.NewNop().Sugar()) b.StartTimer() - _, err = service.GetAudioSegment(context.Background(), mediaSetID, startFrame, endFrame, numBins) - b.StopTimer() + _, err = service.GetAudioSegment(context.Background(), mediaSetID, startFrame, endFrame, numBins) require.NoError(b, err) - audioFile.Close() } } diff --git a/backend/media/types.go b/backend/media/types.go index 9839eca..429e25b 100644 --- a/backend/media/types.go +++ b/backend/media/types.go @@ -1,8 +1,8 @@ package media -//go:generate mockery --recursive --name S3* --output ../generated/mocks //go:generate mockery --recursive --name Store --output ../generated/mocks //go:generate mockery --recursive --name YoutubeClient --output ../generated/mocks +//go:generate mockery --recursive --name FileStore --output ../generated/mocks import ( "context" @@ -10,8 +10,6 @@ import ( "time" "git.netflux.io/rob/clipper/generated/store" - signerv4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" - "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/google/uuid" youtubev2 "github.com/kkdai/youtube/v2" ) @@ -63,28 +61,16 @@ type Store interface { SetVideoThumbnailUploaded(context.Context, store.SetVideoThumbnailUploadedParams) (store.MediaSet, error) } -// S3API provides an API to AWS S3. -type S3API struct { - S3Client - S3PresignClient -} - -// S3Client wraps the AWS S3 service client. 
-type S3Client interface {
-	GetObject(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)
-	CreateMultipartUpload(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error)
-	UploadPart(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) (*s3.UploadPartOutput, error)
-	AbortMultipartUpload(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error)
-	CompleteMultipartUpload(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error)
-}
-
-// S3PresignClient wraps the AWS S3 Presign client.
-type S3PresignClient interface {
-	PresignGetObject(context.Context, *s3.GetObjectInput, ...func(*s3.PresignOptions)) (*signerv4.PresignedHTTPRequest, error)
+// FileStore provides access to a file store; byte ranges are inclusive.
+type FileStore interface {
+	GetObject(ctx context.Context, key string) (io.ReadCloser, error)
+	GetObjectWithRange(ctx context.Context, key string, start, end int64) (io.ReadCloser, error)
+	GetURL(ctx context.Context, key string) (string, error)
+	PutObject(ctx context.Context, key string, reader io.Reader, contentType string) (int64, error)
 }
 
 // YoutubeClient wraps the youtube.Client client.
 type YoutubeClient interface {
-	GetVideoContext(context.Context, string) (*youtubev2.Video, error)
-	GetStreamContext(context.Context, *youtubev2.Video, *youtubev2.Format) (io.ReadCloser, int64, error)
+	GetVideoContext(ctx context.Context, id string) (*youtubev2.Video, error)
+	GetStreamContext(ctx context.Context, video *youtubev2.Video, format *youtubev2.Format) (io.ReadCloser, int64, error)
 }
diff --git a/backend/media/uploader.go b/backend/media/uploader.go
deleted file mode 100644
index b80ca5e..0000000
--- a/backend/media/uploader.go
+++ /dev/null
@@ -1,212 +0,0 @@
-package media
-
-import (
-	"bytes"
-	"context"
-	"errors"
-	"fmt"
-	"io"
-	"sort"
-	"sync"
-
-	"github.com/aws/aws-sdk-go-v2/aws"
-	"github.com/aws/aws-sdk-go-v2/service/s3"
-	"github.com/aws/aws-sdk-go-v2/service/s3/types"
-	"go.uber.org/zap"
-)
-
-// multipartUploader uploads a file to S3.
-//
-// TODO: extract to s3 package
-type multipartUploader struct {
-	s3     S3Client
-	logger *zap.SugaredLogger
-}
-
-type uploadResult struct {
-	completedPart types.CompletedPart
-	size          int64
-}
-
-const (
-	targetPartSizeBytes = 5 * 1024 * 1024 // 5MB
-	readBufferSizeBytes = 32_768          // 32Kb
-)
-
-func newMultipartUploader(s3Client S3Client, logger *zap.SugaredLogger) *multipartUploader {
-	return &multipartUploader{s3: s3Client, logger: logger}
-}
-
-// Upload uploads to an S3 bucket in 5MB parts. It buffers data internally
-// until a part is ready to send over the network. Parts are sent as soon as
-// they exceed the minimum part size of 5MB.
-//
-// TODO: expire after configurable period.
-func (u *multipartUploader) Upload(ctx context.Context, r io.Reader, bucket, key, contentType string) (int64, error) {
-	var uploaded bool
-
-	input := s3.CreateMultipartUploadInput{
-		Bucket:      aws.String(bucket),
-		Key:         aws.String(key),
-		ContentType: aws.String(contentType),
-	}
-	output, err := u.s3.CreateMultipartUpload(ctx, &input)
-	if err != nil {
-		return 0, fmt.Errorf("error creating multipart upload: %v", err)
-	}
-
-	// abort the upload if possible, logging any errors, on exit.
- defer func() { - if uploaded { - return - } - input := s3.AbortMultipartUploadInput{ - Bucket: aws.String(bucket), - Key: aws.String(key), - UploadId: output.UploadId, - } - - // if the context was cancelled, just use the background context. - ctxToUse := ctx - if ctxToUse.Err() != nil { - ctxToUse = context.Background() - } - - _, abortErr := u.s3.AbortMultipartUpload(ctxToUse, &input) - if abortErr != nil { - u.logger.Errorf("uploader: error aborting upload: %v", abortErr) - } else { - u.logger.Infof("aborted upload, key = %s", key) - } - }() - - uploadResultChan := make(chan uploadResult) - uploadErrorChan := make(chan error, 1) - - // uploadPart uploads an individual part. - uploadPart := func(wg *sync.WaitGroup, buf []byte, partNum int32) { - defer wg.Done() - - partLen := int64(len(buf)) - u.logger.With("key", key, "partNum", partNum, "partLen", partLen).Debug("uploading part") - - input := s3.UploadPartInput{ - Body: bytes.NewReader(buf), - Bucket: aws.String(bucket), - Key: aws.String(key), - PartNumber: partNum, - UploadId: output.UploadId, - ContentLength: partLen, - } - - output, uploadErr := u.s3.UploadPart(ctx, &input) - if uploadErr != nil { - // TODO: retry on failure - uploadErrorChan <- uploadErr - return - } - - u.logger.With("key", key, "partNum", partNum, "partLen", partLen, "etag", *output.ETag).Debug("uploaded part") - - uploadResultChan <- uploadResult{ - completedPart: types.CompletedPart{ETag: output.ETag, PartNumber: partNum}, - size: partLen, - } - } - - wgDone := make(chan struct{}) - var wg sync.WaitGroup - wg.Add(1) // done when the reader goroutine returns - go func() { - wg.Wait() - wgDone <- struct{}{} - }() - - readChan := make(chan error, 1) - - go func() { - defer wg.Done() - - var closing bool - currPart := bytes.NewBuffer(make([]byte, 0, targetPartSizeBytes+readBufferSizeBytes)) - partNum := int32(1) - buf := make([]byte, readBufferSizeBytes) - - for { - n, readErr := r.Read(buf) - if readErr == io.EOF { - closing = true - } else if readErr != nil { - readChan <- readErr - return - } - - _, _ = currPart.Write(buf[:n]) - if closing || currPart.Len() >= targetPartSizeBytes { - part := make([]byte, currPart.Len()) - copy(part, currPart.Bytes()) - currPart.Truncate(0) - - wg.Add(1) - go uploadPart(&wg, part, partNum) - partNum++ - } - - if closing { - return - } - } - }() - - results := make([]uploadResult, 0, 64) - -outer: - for { - select { - case readErr := <-readChan: - if readErr != io.EOF { - return 0, fmt.Errorf("reader error: %v", readErr) - } - case uploadResult := <-uploadResultChan: - results = append(results, uploadResult) - case uploadErr := <-uploadErrorChan: - return 0, fmt.Errorf("error while uploading part: %v", uploadErr) - case <-ctx.Done(): - return 0, ctx.Err() - case <-wgDone: - break outer - } - } - - if len(results) == 0 { - return 0, errors.New("no parts available to upload") - } - - completedParts := make([]types.CompletedPart, 0, 64) - var uploadedBytes int64 - for _, result := range results { - completedParts = append(completedParts, result.completedPart) - uploadedBytes += result.size - } - - // the parts may be out of order, especially with slow network conditions: - sort.Slice(completedParts, func(i, j int) bool { - return completedParts[i].PartNumber < completedParts[j].PartNumber - }) - - completeInput := s3.CompleteMultipartUploadInput{ - Bucket: aws.String(bucket), - Key: aws.String(key), - UploadId: output.UploadId, - MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts}, - } - - if _, err = 
u.s3.CompleteMultipartUpload(ctx, &completeInput); err != nil { - return 0, fmt.Errorf("error completing upload: %v", err) - } - - u.logger.With("key", key, "numParts", len(completedParts), "len", uploadedBytes).Debug("completed upload") - uploaded = true - - return uploadedBytes, nil -} diff --git a/backend/server/server.go b/backend/server/server.go index 59eebfd..86c122c 100644 --- a/backend/server/server.go +++ b/backend/server/server.go @@ -64,7 +64,8 @@ type Options struct { Timeout time.Duration Store media.Store YoutubeClient media.YoutubeClient - S3API media.S3API + FileStore media.FileStore + Logger *zap.Logger } // mediaSetServiceController implements gRPC controller for MediaSetService @@ -226,32 +227,26 @@ func (c *mediaSetServiceController) GetVideoThumbnail(ctx context.Context, reque } func Start(options Options) error { - logger, err := buildLogger(options.Config) - if err != nil { - return fmt.Errorf("error building logger: %v", err) - } - defer logger.Sync() - fetchMediaSetService := media.NewMediaSetService( options.Store, options.YoutubeClient, - options.S3API, + options.FileStore, options.Config, - logger, + options.Logger.Sugar().Named("mediaSetService"), ) - grpcServer, err := buildGRPCServer(options.Config, logger) + grpcServer, err := buildGRPCServer(options.Config, options.Logger) if err != nil { return fmt.Errorf("error building server: %v", err) } - mediaSetController := &mediaSetServiceController{mediaSetService: fetchMediaSetService, logger: logger.Sugar().Named("controller")} + mediaSetController := &mediaSetServiceController{mediaSetService: fetchMediaSetService, logger: options.Logger.Sugar().Named("controller")} pbmediaset.RegisterMediaSetServiceServer(grpcServer, mediaSetController) // TODO: configure CORS grpcWebServer := grpcweb.WrapServer(grpcServer, grpcweb.WithOriginFunc(func(string) bool { return true })) - log := logger.Sugar() + log := options.Logger.Sugar() fileHandler := http.NotFoundHandler() if options.Config.AssetsHTTPBasePath != "" { log.With("basePath", options.Config.AssetsHTTPBasePath).Info("Configured to serve assets over HTTP") @@ -280,13 +275,6 @@ func Start(options Options) error { return httpServer.ListenAndServe() } -func buildLogger(c config.Config) (*zap.Logger, error) { - if c.Environment == config.EnvProduction { - return zap.NewProduction() - } - return zap.NewDevelopment() -} - func buildGRPCServer(c config.Config, logger *zap.Logger) (*grpc.Server, error) { unaryInterceptors := []grpc.UnaryServerInterceptor{ grpczap.UnaryServerInterceptor(logger),
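With the S3Client and S3PresignClient mocks deleted, new tests exercise storage behaviour through the regenerated FileStore mock instead. A hedged sketch of how that might look in a fresh test (the key and contents are made up; only the mock API shown above is assumed):

```
package media_test

import (
	"context"
	"io"
	"strings"
	"testing"

	"git.netflux.io/rob/clipper/generated/mocks"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

func TestFileStoreMock(t *testing.T) {
	fileStore := &mocks.FileStore{}
	fileStore.
		On("GetObject", mock.Anything, "media_sets/example/audio.raw").
		Return(io.NopCloser(strings.NewReader("pcm-bytes")), nil)
	defer fileStore.AssertExpectations(t)

	r, err := fileStore.GetObject(context.Background(), "media_sets/example/audio.raw")
	require.NoError(t, err)

	data, err := io.ReadAll(r)
	require.NoError(t, err)
	require.Equal(t, "pcm-bytes", string(data))
}
```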