Extract S3 code to S3FileStore

Re: #5
This commit is contained in:
Rob Watson 2021-12-07 20:58:11 +01:00
parent c849b8d2e6
commit f2d7af0860
13 changed files with 569 additions and 661 deletions

View File

@ -57,7 +57,7 @@ sqlc generate
Mocks require [mockery](https://github.com/vektra/mockery) to be installed, and can be regenerated with:
```
go generate
go generate ./...
```
### Migrations

View File

@ -6,18 +6,20 @@ import (
"time"
"git.netflux.io/rob/clipper/config"
"git.netflux.io/rob/clipper/filestore"
"git.netflux.io/rob/clipper/generated/store"
"git.netflux.io/rob/clipper/media"
"git.netflux.io/rob/clipper/server"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/kkdai/youtube/v2"
"go.uber.org/zap"
)
const (
DefaultTimeout = 600 * time.Second
defaultTimeout = 600 * time.Second
defaultURLExpiry = time.Hour
)
func main() {
@ -35,6 +37,9 @@ func main() {
}
store := store.New(dbConn)
// Create a Youtube client
var youtubeClient youtube.Client
// Create an Amazon S3 service s3Client
cfg, err := awsconfig.LoadDefaultConfig(
ctx,
@ -47,23 +52,39 @@ func main() {
log.Fatal(err)
}
s3Client := s3.NewFromConfig(cfg)
// Create an Amazon S3 presign client
s3PresignClient := s3.NewPresignClient(s3Client)
// Create a Youtube client
var youtubeClient youtube.Client
// Create a logger
logger, err := buildLogger(config)
if err != nil {
log.Fatal(err)
}
defer logger.Sync()
serverOptions := server.Options{
Config: config,
Timeout: DefaultTimeout,
Store: store,
YoutubeClient: &youtubeClient,
S3API: media.S3API{
// Create a file store
fileStore := filestore.NewS3FileStore(
filestore.S3API{
S3Client: s3Client,
S3PresignClient: s3PresignClient,
},
}
config.S3Bucket,
defaultURLExpiry,
logger.Sugar().Named("filestore"),
)
log.Fatal(server.Start(serverOptions))
log.Fatal(server.Start(server.Options{
Config: config,
Timeout: defaultTimeout,
Store: store,
YoutubeClient: &youtubeClient,
FileStore: fileStore,
Logger: logger,
}))
}
func buildLogger(c config.Config) (*zap.Logger, error) {
if c.Environment == config.EnvDevelopment {
return zap.NewDevelopment()
}
return zap.NewProduction()
}

280
backend/filestore/s3.go Normal file
View File

@ -0,0 +1,280 @@
package filestore
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"sort"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
signerv4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"go.uber.org/zap"
)
// S3API provides an API to AWS S3.
type S3API struct {
S3Client
S3PresignClient
}
// S3Client wraps the AWS S3 service client.
type S3Client interface {
GetObject(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)
CreateMultipartUpload(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error)
UploadPart(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) (*s3.UploadPartOutput, error)
AbortMultipartUpload(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error)
CompleteMultipartUpload(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error)
}
// S3PresignClient wraps the AWS S3 Presign client.
type S3PresignClient interface {
PresignGetObject(context.Context, *s3.GetObjectInput, ...func(*s3.PresignOptions)) (*signerv4.PresignedHTTPRequest, error)
}
// S3FileStore stores files on Amazon S3.
type S3FileStore struct {
s3 S3API
bucket string
urlExpiry time.Duration
logger *zap.SugaredLogger
}
// NewS3FileStore builds a new S3FileStore using the provided configuration.
func NewS3FileStore(s3API S3API, bucket string, urlExpiry time.Duration, logger *zap.SugaredLogger) *S3FileStore {
return &S3FileStore{
s3: s3API,
bucket: bucket,
urlExpiry: urlExpiry,
logger: logger,
}
}
// GetObject returns an io.Reader that returns an object associated with the
// provided key.
func (s *S3FileStore) GetObject(ctx context.Context, key string) (io.ReadCloser, error) {
input := s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(key),
}
output, err := s.s3.GetObject(ctx, &input)
if err != nil {
return nil, fmt.Errorf("error getting object from s3: %v", err)
}
return output.Body, nil
}
// GetObjectWithRange returns an io.Reader that returns a partial object
// associated with the provided key.
func (s *S3FileStore) GetObjectWithRange(ctx context.Context, key string, start, end int64) (io.ReadCloser, error) {
byteRange := fmt.Sprintf("bytes=%d-%d", start, end)
input := s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(key),
Range: aws.String(byteRange),
}
output, err := s.s3.GetObject(ctx, &input)
if err != nil {
return nil, fmt.Errorf("error getting object from s3: %v", err)
}
return output.Body, nil
}
// GetURL returns a presigned URL pointing to the object associated with the
// provided key.
func (s *S3FileStore) GetURL(ctx context.Context, key string) (string, error) {
input := s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(key),
}
request, err := s.s3.PresignGetObject(ctx, &input, s3.WithPresignExpires(s.urlExpiry))
if err != nil {
return "", fmt.Errorf("error generating presigned URL: %v", err)
}
return request.URL, nil
}
// PutObject uploads an object using multipart upload, returning the number of
// bytes uploaded and any error.
func (s *S3FileStore) PutObject(ctx context.Context, key string, r io.Reader, contentType string) (int64, error) {
const (
targetPartSizeBytes = 5 * 1024 * 1024 // 5MB
readBufferSizeBytes = 32_768 // 32Kb
)
var uploaded bool
input := s3.CreateMultipartUploadInput{
Bucket: aws.String(s.bucket),
Key: aws.String(key),
ContentType: aws.String(contentType),
}
output, err := s.s3.CreateMultipartUpload(ctx, &input)
if err != nil {
return 0, fmt.Errorf("error creating multipart upload: %v", err)
}
// abort the upload if possible, logging any errors, on exit.
defer func() {
if uploaded {
return
}
input := s3.AbortMultipartUploadInput{
Bucket: aws.String(s.bucket),
Key: aws.String(key),
UploadId: output.UploadId,
}
// if the context was cancelled, just use the background context.
ctxToUse := ctx
if ctxToUse.Err() != nil {
ctxToUse = context.Background()
}
_, deferErr := s.s3.AbortMultipartUpload(ctxToUse, &input)
if deferErr != nil {
s.logger.Errorf("uploader: error aborting upload: %v", deferErr)
} else {
s.logger.Infof("aborted upload, key = %s", key)
}
}()
type uploadedPart struct {
part types.CompletedPart
size int64
}
uploadResultChan := make(chan uploadedPart)
uploadErrorChan := make(chan error, 1)
// uploadPart uploads an individual part.
uploadPart := func(wg *sync.WaitGroup, buf []byte, partNum int32) {
defer wg.Done()
partLen := int64(len(buf))
s.logger.With("key", key, "partNum", partNum, "partLen", partLen).Debug("uploading part")
input := s3.UploadPartInput{
Body: bytes.NewReader(buf),
Bucket: aws.String(s.bucket),
Key: aws.String(key),
PartNumber: partNum,
UploadId: output.UploadId,
ContentLength: partLen,
}
output, uploadErr := s.s3.UploadPart(ctx, &input)
if uploadErr != nil {
// TODO: retry on failure
uploadErrorChan <- uploadErr
return
}
s.logger.With("key", key, "partNum", partNum, "partLen", partLen, "etag", *output.ETag).Debug("uploaded part")
uploadResultChan <- uploadedPart{
part: types.CompletedPart{ETag: output.ETag, PartNumber: partNum},
size: partLen,
}
}
wgDone := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1) // done when the reader goroutine returns
go func() {
wg.Wait()
wgDone <- struct{}{}
}()
readChan := make(chan error, 1)
go func() {
defer wg.Done()
var closing bool
currPart := bytes.NewBuffer(make([]byte, 0, targetPartSizeBytes+readBufferSizeBytes))
partNum := int32(1)
buf := make([]byte, readBufferSizeBytes)
for {
n, readErr := r.Read(buf)
if readErr == io.EOF {
closing = true
} else if readErr != nil {
readChan <- readErr
return
}
_, _ = currPart.Write(buf[:n])
if closing || currPart.Len() >= targetPartSizeBytes {
part := make([]byte, currPart.Len())
copy(part, currPart.Bytes())
currPart.Truncate(0)
wg.Add(1)
go uploadPart(&wg, part, partNum)
partNum++
}
if closing {
return
}
}
}()
results := make([]uploadedPart, 0, 64)
outer:
for {
select {
case readErr := <-readChan:
if readErr != io.EOF {
return 0, fmt.Errorf("reader error: %v", readErr)
}
case uploadResult := <-uploadResultChan:
results = append(results, uploadResult)
case uploadErr := <-uploadErrorChan:
return 0, fmt.Errorf("error while uploading part: %v", uploadErr)
case <-ctx.Done():
return 0, ctx.Err()
case <-wgDone:
break outer
}
}
if len(results) == 0 {
return 0, errors.New("no parts available to upload")
}
completedParts := make([]types.CompletedPart, 0, 64)
var uploadedBytes int64
for _, result := range results {
completedParts = append(completedParts, result.part)
uploadedBytes += result.size
}
// the parts may be out of order, especially with slow network conditions:
sort.Slice(completedParts, func(i, j int) bool {
return completedParts[i].PartNumber < completedParts[j].PartNumber
})
completeInput := s3.CompleteMultipartUploadInput{
Bucket: aws.String(s.bucket),
Key: aws.String(key),
UploadId: output.UploadId,
MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts},
}
if _, err = s.s3.CompleteMultipartUpload(ctx, &completeInput); err != nil {
return 0, fmt.Errorf("error completing upload: %v", err)
}
s.logger.With("key", key, "numParts", len(completedParts), "len", uploadedBytes).Debug("completed upload")
uploaded = true
return uploadedBytes, nil
}

View File

@ -0,0 +1,103 @@
// Code generated by mockery v2.9.4. DO NOT EDIT.
package mocks
import (
context "context"
io "io"
mock "github.com/stretchr/testify/mock"
)
// FileStore is an autogenerated mock type for the FileStore type
type FileStore struct {
mock.Mock
}
// GetObject provides a mock function with given fields: _a0, _a1
func (_m *FileStore) GetObject(_a0 context.Context, _a1 string) (io.ReadCloser, error) {
ret := _m.Called(_a0, _a1)
var r0 io.ReadCloser
if rf, ok := ret.Get(0).(func(context.Context, string) io.ReadCloser); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(io.ReadCloser)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetObjectWithRange provides a mock function with given fields: _a0, _a1, _a2, _a3
func (_m *FileStore) GetObjectWithRange(_a0 context.Context, _a1 string, _a2 int64, _a3 int64) (io.ReadCloser, error) {
ret := _m.Called(_a0, _a1, _a2, _a3)
var r0 io.ReadCloser
if rf, ok := ret.Get(0).(func(context.Context, string, int64, int64) io.ReadCloser); ok {
r0 = rf(_a0, _a1, _a2, _a3)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(io.ReadCloser)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, int64, int64) error); ok {
r1 = rf(_a0, _a1, _a2, _a3)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetURL provides a mock function with given fields: _a0, _a1
func (_m *FileStore) GetURL(_a0 context.Context, _a1 string) (string, error) {
ret := _m.Called(_a0, _a1)
var r0 string
if rf, ok := ret.Get(0).(func(context.Context, string) string); ok {
r0 = rf(_a0, _a1)
} else {
r0 = ret.Get(0).(string)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// PutObject provides a mock function with given fields: _a0, _a1, _a2, _a3
func (_m *FileStore) PutObject(_a0 context.Context, _a1 string, _a2 io.Reader, _a3 string) (int64, error) {
ret := _m.Called(_a0, _a1, _a2, _a3)
var r0 int64
if rf, ok := ret.Get(0).(func(context.Context, string, io.Reader, string) int64); ok {
r0 = rf(_a0, _a1, _a2, _a3)
} else {
r0 = ret.Get(0).(int64)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, io.Reader, string) error); ok {
r1 = rf(_a0, _a1, _a2, _a3)
} else {
r1 = ret.Error(1)
}
return r0, r1
}

View File

@ -1,166 +0,0 @@
// Code generated by mockery v2.9.4. DO NOT EDIT.
package mocks
import (
context "context"
mock "github.com/stretchr/testify/mock"
s3 "github.com/aws/aws-sdk-go-v2/service/s3"
)
// S3Client is an autogenerated mock type for the S3Client type
type S3Client struct {
mock.Mock
}
// AbortMultipartUpload provides a mock function with given fields: _a0, _a1, _a2
func (_m *S3Client) AbortMultipartUpload(_a0 context.Context, _a1 *s3.AbortMultipartUploadInput, _a2 ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error) {
_va := make([]interface{}, len(_a2))
for _i := range _a2 {
_va[_i] = _a2[_i]
}
var _ca []interface{}
_ca = append(_ca, _a0, _a1)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *s3.AbortMultipartUploadOutput
if rf, ok := ret.Get(0).(func(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) *s3.AbortMultipartUploadOutput); ok {
r0 = rf(_a0, _a1, _a2...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*s3.AbortMultipartUploadOutput)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) error); ok {
r1 = rf(_a0, _a1, _a2...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CompleteMultipartUpload provides a mock function with given fields: _a0, _a1, _a2
func (_m *S3Client) CompleteMultipartUpload(_a0 context.Context, _a1 *s3.CompleteMultipartUploadInput, _a2 ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error) {
_va := make([]interface{}, len(_a2))
for _i := range _a2 {
_va[_i] = _a2[_i]
}
var _ca []interface{}
_ca = append(_ca, _a0, _a1)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *s3.CompleteMultipartUploadOutput
if rf, ok := ret.Get(0).(func(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) *s3.CompleteMultipartUploadOutput); ok {
r0 = rf(_a0, _a1, _a2...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*s3.CompleteMultipartUploadOutput)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) error); ok {
r1 = rf(_a0, _a1, _a2...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CreateMultipartUpload provides a mock function with given fields: _a0, _a1, _a2
func (_m *S3Client) CreateMultipartUpload(_a0 context.Context, _a1 *s3.CreateMultipartUploadInput, _a2 ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error) {
_va := make([]interface{}, len(_a2))
for _i := range _a2 {
_va[_i] = _a2[_i]
}
var _ca []interface{}
_ca = append(_ca, _a0, _a1)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *s3.CreateMultipartUploadOutput
if rf, ok := ret.Get(0).(func(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) *s3.CreateMultipartUploadOutput); ok {
r0 = rf(_a0, _a1, _a2...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*s3.CreateMultipartUploadOutput)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) error); ok {
r1 = rf(_a0, _a1, _a2...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetObject provides a mock function with given fields: _a0, _a1, _a2
func (_m *S3Client) GetObject(_a0 context.Context, _a1 *s3.GetObjectInput, _a2 ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
_va := make([]interface{}, len(_a2))
for _i := range _a2 {
_va[_i] = _a2[_i]
}
var _ca []interface{}
_ca = append(_ca, _a0, _a1)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *s3.GetObjectOutput
if rf, ok := ret.Get(0).(func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) *s3.GetObjectOutput); ok {
r0 = rf(_a0, _a1, _a2...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*s3.GetObjectOutput)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) error); ok {
r1 = rf(_a0, _a1, _a2...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UploadPart provides a mock function with given fields: _a0, _a1, _a2
func (_m *S3Client) UploadPart(_a0 context.Context, _a1 *s3.UploadPartInput, _a2 ...func(*s3.Options)) (*s3.UploadPartOutput, error) {
_va := make([]interface{}, len(_a2))
for _i := range _a2 {
_va[_i] = _a2[_i]
}
var _ca []interface{}
_ca = append(_ca, _a0, _a1)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *s3.UploadPartOutput
if rf, ok := ret.Get(0).(func(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) *s3.UploadPartOutput); ok {
r0 = rf(_a0, _a1, _a2...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*s3.UploadPartOutput)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) error); ok {
r1 = rf(_a0, _a1, _a2...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}

View File

@ -1,48 +0,0 @@
// Code generated by mockery v2.9.4. DO NOT EDIT.
package mocks
import (
context "context"
mock "github.com/stretchr/testify/mock"
s3 "github.com/aws/aws-sdk-go-v2/service/s3"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
)
// S3PresignClient is an autogenerated mock type for the S3PresignClient type
type S3PresignClient struct {
mock.Mock
}
// PresignGetObject provides a mock function with given fields: _a0, _a1, _a2
func (_m *S3PresignClient) PresignGetObject(_a0 context.Context, _a1 *s3.GetObjectInput, _a2 ...func(*s3.PresignOptions)) (*v4.PresignedHTTPRequest, error) {
_va := make([]interface{}, len(_a2))
for _i := range _a2 {
_va[_i] = _a2[_i]
}
var _ca []interface{}
_ca = append(_ca, _a0, _a1)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *v4.PresignedHTTPRequest
if rf, ok := ret.Get(0).(func(context.Context, *s3.GetObjectInput, ...func(*s3.PresignOptions)) *v4.PresignedHTTPRequest); ok {
r0 = rf(_a0, _a1, _a2...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*v4.PresignedHTTPRequest)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *s3.GetObjectInput, ...func(*s3.PresignOptions)) error); ok {
r1 = rf(_a0, _a1, _a2...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}

View File

@ -13,8 +13,6 @@ import (
"git.netflux.io/rob/clipper/config"
"git.netflux.io/rob/clipper/generated/store"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"go.uber.org/zap"
)
@ -33,25 +31,25 @@ type GetAudioProgressReader interface {
type audioGetter struct {
store Store
youtube YoutubeClient
s3API S3API
fileStore FileStore
config config.Config
logger *zap.SugaredLogger
}
// newAudioGetter returns a new audioGetter.
func newAudioGetter(store Store, youtube YoutubeClient, s3API S3API, config config.Config, logger *zap.SugaredLogger) *audioGetter {
func newAudioGetter(store Store, youtube YoutubeClient, fileStore FileStore, config config.Config, logger *zap.SugaredLogger) *audioGetter {
return &audioGetter{
store: store,
youtube: youtube,
s3API: s3API,
fileStore: fileStore,
config: config,
logger: logger,
}
}
// GetAudio gets the audio, processes it and uploads it to S3. It returns a
// GetAudioProgressReader that can be used to poll progress reports and audio
// peaks.
// GetAudio gets the audio, processes it and uploads it to a file store. It
// returns a GetAudioProgressReader that can be used to poll progress reports
// and audio peaks.
//
// TODO: accept domain object instead
func (g *audioGetter) GetAudio(ctx context.Context, mediaSet store.MediaSet, numBins int) (GetAudioProgressReader, error) {
@ -114,33 +112,28 @@ func (s *audioGetterState) getAudio(ctx context.Context, r io.ReadCloser, mediaS
wg.Add(2)
// Upload the encoded audio.
// TODO: fix error shadowing in these two goroutines.
go func() {
defer wg.Done()
// TODO: use mediaSet func to fetch s3Key
s3Key := fmt.Sprintf("media_sets/%s/audio.opus", mediaSet.ID)
// TODO: use mediaSet func to fetch key
key := fmt.Sprintf("media_sets/%s/audio.opus", mediaSet.ID)
uploader := newMultipartUploader(s.s3API, s.logger)
_, encErr := uploader.Upload(ctx, pr, s.config.S3Bucket, s3Key, "audio/opus")
_, encErr := s.fileStore.PutObject(ctx, key, pr, "audio/opus")
if encErr != nil {
s.CloseWithError(fmt.Errorf("error uploading encoded audio: %v", encErr))
return
}
input := s3.GetObjectInput{
Bucket: aws.String(s.config.S3Bucket),
Key: aws.String(s3Key),
}
request, err := s.s3API.PresignGetObject(ctx, &input, s3.WithPresignExpires(getAudioExpiresIn))
presignedAudioURL, err = s.fileStore.GetURL(ctx, key)
if err != nil {
s.CloseWithError(fmt.Errorf("error generating presigned URL: %v", err))
}
presignedAudioURL = request.URL
if _, err = s.store.SetEncodedAudioUploaded(ctx, store.SetEncodedAudioUploadedParams{
ID: mediaSet.ID,
AudioEncodedS3Bucket: sqlString(s.config.S3Bucket),
AudioEncodedS3Key: sqlString(s3Key),
AudioEncodedS3Key: sqlString(key),
}); err != nil {
s.CloseWithError(fmt.Errorf("error setting encoded audio uploaded: %v", err))
}
@ -150,12 +143,11 @@ func (s *audioGetterState) getAudio(ctx context.Context, r io.ReadCloser, mediaS
go func() {
defer wg.Done()
// TODO: use mediaSet func to fetch s3Key
s3Key := fmt.Sprintf("media_sets/%s/audio.raw", mediaSet.ID)
// TODO: use mediaSet func to fetch key
key := fmt.Sprintf("media_sets/%s/audio.raw", mediaSet.ID)
teeReader := io.TeeReader(stdout, s)
uploader := newMultipartUploader(s.s3API, s.logger)
bytesUploaded, rawErr := uploader.Upload(ctx, teeReader, s.config.S3Bucket, s3Key, rawAudioMimeType)
bytesUploaded, rawErr := s.fileStore.PutObject(ctx, key, teeReader, rawAudioMimeType)
if rawErr != nil {
s.CloseWithError(fmt.Errorf("error uploading raw audio: %v", rawErr))
return
@ -164,7 +156,7 @@ func (s *audioGetterState) getAudio(ctx context.Context, r io.ReadCloser, mediaS
if _, err = s.store.SetRawAudioUploaded(ctx, store.SetRawAudioUploadedParams{
ID: mediaSet.ID,
AudioRawS3Bucket: sqlString(s.config.S3Bucket),
AudioRawS3Key: sqlString(s3Key),
AudioRawS3Key: sqlString(key),
AudioFrames: sqlInt64(bytesUploaded / SizeOfInt16 / int64(mediaSet.AudioChannels)),
}); err != nil {
s.CloseWithError(fmt.Errorf("error setting raw audio uploaded: %v", err))

View File

@ -6,8 +6,6 @@ import (
"io"
"git.netflux.io/rob/clipper/generated/store"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/google/uuid"
"go.uber.org/zap"
)
@ -24,8 +22,8 @@ type GetVideoProgressReader interface {
}
type videoGetter struct {
s3 S3API
store Store
fileStore FileStore
logger *zap.SugaredLogger
}
@ -41,21 +39,20 @@ type videoGetterState struct {
errorChan chan error
}
func newVideoGetter(s3 S3API, store Store, logger *zap.SugaredLogger) *videoGetter {
return &videoGetter{s3: s3, store: store, logger: logger}
func newVideoGetter(store Store, fileStore FileStore, logger *zap.SugaredLogger) *videoGetter {
return &videoGetter{store: store, fileStore: fileStore, logger: logger}
}
// GetVideo gets video from Youtube and uploads it to S3 using the specified
// bucket, key and content type. The returned reader must have its Next()
// GetVideo gets video from Youtube and uploads it to a filestore using the
// specified key and content type. The returned reader must have its Next()
// method called until error = io.EOF, otherwise a deadlock or other resource
// leakage is likely.
func (g *videoGetter) GetVideo(ctx context.Context, r io.Reader, exp int64, mediaSetID uuid.UUID, bucket, key, contentType string) (GetVideoProgressReader, error) {
func (g *videoGetter) GetVideo(ctx context.Context, r io.Reader, exp int64, mediaSetID uuid.UUID, key, contentType string) (GetVideoProgressReader, error) {
s := &videoGetterState{
videoGetter: g,
r: newProgressReader(r, "video", exp, g.logger),
exp: exp,
mediaSetID: mediaSetID,
bucket: bucket,
key: key,
contentType: contentType,
progressChan: make(chan GetVideoProgress),
@ -69,7 +66,7 @@ func (g *videoGetter) GetVideo(ctx context.Context, r io.Reader, exp int64, medi
}
// Write implements io.Writer. It is copied that same data that is written to
// S3, to implement progress tracking.
// the file store, to implement progress tracking.
func (s *videoGetterState) Write(p []byte) (int, error) {
s.count += int64(len(p))
pc := (float32(s.count) / float32(s.exp)) * 100
@ -78,24 +75,18 @@ func (s *videoGetterState) Write(p []byte) (int, error) {
}
func (s *videoGetterState) getVideo(ctx context.Context) {
uploader := newMultipartUploader(s.s3, s.logger)
teeReader := io.TeeReader(s.r, s)
_, err := uploader.Upload(ctx, teeReader, s.bucket, s.key, s.contentType)
_, err := s.fileStore.PutObject(ctx, s.key, teeReader, s.contentType)
if err != nil {
s.errorChan <- fmt.Errorf("error uploading to S3: %v", err)
s.errorChan <- fmt.Errorf("error uploading to file store: %v", err)
return
}
input := s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(s.key),
}
request, err := s.s3.PresignGetObject(ctx, &input, s3.WithPresignExpires(getVideoExpiresIn))
s.url, err = s.fileStore.GetURL(ctx, s.key)
if err != nil {
s.errorChan <- fmt.Errorf("error generating presigned URL: %v", err)
s.errorChan <- fmt.Errorf("error getting object URL: %v", err)
}
s.url = request.URL
storeParams := store.SetVideoUploadedParams{
ID: s.mediaSetID,

View File

@ -14,8 +14,6 @@ import (
"git.netflux.io/rob/clipper/config"
"git.netflux.io/rob/clipper/generated/store"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/google/uuid"
"github.com/jackc/pgx/v4"
youtubev2 "github.com/kkdai/youtube/v2"
@ -43,25 +41,25 @@ const (
type MediaSetService struct {
store Store
youtube YoutubeClient
s3 S3API
fileStore FileStore
config config.Config
logger *zap.SugaredLogger
}
func NewMediaSetService(store Store, youtubeClient YoutubeClient, s3API S3API, config config.Config, logger *zap.Logger) *MediaSetService {
func NewMediaSetService(store Store, youtubeClient YoutubeClient, fileStore FileStore, config config.Config, logger *zap.SugaredLogger) *MediaSetService {
return &MediaSetService{
store: store,
youtube: youtubeClient,
s3: s3API,
fileStore: fileStore,
config: config,
logger: logger.Sugar(),
logger: logger,
}
}
// Get fetches the metadata for a given MediaSet source. If it does not exist
// in the local DB, it will attempt to create it. After the resource has been
// created, other endpoints (e.g. GetAudio) can be called to fetch media from
// Youtube and store it in S3.
// Youtube and store it in a file store.
func (s *MediaSetService) Get(ctx context.Context, youtubeID string) (*MediaSet, error) {
var (
mediaSet *MediaSet
@ -220,15 +218,11 @@ func (s *MediaSetService) GetVideo(ctx context.Context, id uuid.UUID) (GetVideoP
}
if mediaSet.VideoS3UploadedAt.Valid {
input := s3.GetObjectInput{
Bucket: aws.String(s.config.S3Bucket),
Key: aws.String(mediaSet.VideoS3Key.String),
url, err := s.fileStore.GetURL(ctx, mediaSet.VideoS3Key.String)
if err != nil {
return nil, fmt.Errorf("error generating presigned URL: %v", err)
}
request, signErr := s.s3.PresignGetObject(ctx, &input, s3.WithPresignExpires(getVideoExpiresIn))
if signErr != nil {
return nil, fmt.Errorf("error generating presigned URL: %v", signErr)
}
videoGetter := videoGetterDownloaded(request.URL)
videoGetter := videoGetterDownloaded(url)
return &videoGetter, nil
}
@ -247,17 +241,16 @@ func (s *MediaSetService) GetVideo(ctx context.Context, id uuid.UUID) (GetVideoP
return nil, fmt.Errorf("error fetching stream: %v", err)
}
// TODO: use mediaSet func to fetch s3Key
s3Key := fmt.Sprintf("media_sets/%s/video.mp4", mediaSet.ID)
// TODO: use mediaSet func to fetch videoKey
videoKey := fmt.Sprintf("media_sets/%s/video.mp4", mediaSet.ID)
videoGetter := newVideoGetter(s.s3, s.store, s.logger)
videoGetter := newVideoGetter(s.store, s.fileStore, s.logger)
return videoGetter.GetVideo(
ctx,
stream,
format.ContentLength,
mediaSet.ID,
s.config.S3Bucket,
s3Key,
videoKey,
format.MimeType,
)
}
@ -273,25 +266,21 @@ func (s *MediaSetService) GetAudio(ctx context.Context, id uuid.UUID, numBins in
// Otherwise, we cannot return both peaks and a presigned URL for use by the
// player.
if mediaSet.AudioRawS3UploadedAt.Valid && mediaSet.AudioEncodedS3UploadedAt.Valid {
return s.getAudioFromS3(ctx, mediaSet, numBins)
return s.getAudioFromFileStore(ctx, mediaSet, numBins)
}
return s.getAudioFromYoutube(ctx, mediaSet, numBins)
}
func (s *MediaSetService) getAudioFromYoutube(ctx context.Context, mediaSet store.MediaSet, numBins int) (GetAudioProgressReader, error) {
audioGetter := newAudioGetter(s.store, s.youtube, s.s3, s.config, s.logger)
audioGetter := newAudioGetter(s.store, s.youtube, s.fileStore, s.config, s.logger)
return audioGetter.GetAudio(ctx, mediaSet, numBins)
}
func (s *MediaSetService) getAudioFromS3(ctx context.Context, mediaSet store.MediaSet, numBins int) (GetAudioProgressReader, error) {
input := s3.GetObjectInput{
Bucket: aws.String(mediaSet.AudioRawS3Bucket.String),
Key: aws.String(mediaSet.AudioRawS3Key.String),
}
output, err := s.s3.GetObject(ctx, &input)
func (s *MediaSetService) getAudioFromFileStore(ctx context.Context, mediaSet store.MediaSet, numBins int) (GetAudioProgressReader, error) {
object, err := s.fileStore.GetObject(ctx, mediaSet.AudioRawS3Key.String)
if err != nil {
return nil, fmt.Errorf("error getting object from s3: %v", err)
return nil, fmt.Errorf("error getting object from file store: %v", err)
}
getAudioProgressReader, err := newGetAudioProgressReader(
@ -303,10 +292,10 @@ func (s *MediaSetService) getAudioFromS3(ctx context.Context, mediaSet store.Med
return nil, fmt.Errorf("error creating audio reader: %v", err)
}
state := getAudioFromS3State{
state := getAudioFromFileStoreState{
getAudioProgressReader: getAudioProgressReader,
s3Reader: NewModuloBufReader(output.Body, int(mediaSet.AudioChannels)*SizeOfInt16),
s3API: s.s3,
reader: NewModuloBufReader(object, int(mediaSet.AudioChannels)*SizeOfInt16),
fileStore: s.fileStore,
config: s.config,
logger: s.logger,
}
@ -315,21 +304,21 @@ func (s *MediaSetService) getAudioFromS3(ctx context.Context, mediaSet store.Med
return &state, nil
}
type getAudioFromS3State struct {
type getAudioFromFileStoreState struct {
*getAudioProgressReader
s3Reader io.ReadCloser
s3API S3API
reader io.ReadCloser
fileStore FileStore
config config.Config
logger *zap.SugaredLogger
}
func (s *getAudioFromS3State) run(ctx context.Context, mediaSet store.MediaSet) {
func (s *getAudioFromFileStoreState) run(ctx context.Context, mediaSet store.MediaSet) {
done := make(chan error)
var err error
go func() {
_, copyErr := io.Copy(s, s.s3Reader)
_, copyErr := io.Copy(s, s.reader)
done <- copyErr
}()
@ -344,29 +333,25 @@ outer:
}
}
if readerErr := s.s3Reader.Close(); readerErr != nil {
if readerErr := s.reader.Close(); readerErr != nil {
if err == nil {
err = readerErr
}
}
if err != nil {
s.logger.Errorf("getAudioFromS3State: error closing s3 reader: %v", err)
s.logger.Errorf("getAudioFromFileStoreState: error closing reader: %v", err)
s.CloseWithError(err)
return
}
input := s3.GetObjectInput{
Bucket: aws.String(s.config.S3Bucket),
Key: aws.String(mediaSet.AudioEncodedS3Key.String),
}
request, err := s.s3API.PresignGetObject(ctx, &input, s3.WithPresignExpires(getAudioExpiresIn))
url, err := s.fileStore.GetURL(ctx, mediaSet.AudioEncodedS3Key.String)
if err != nil {
s.CloseWithError(fmt.Errorf("error generating presigned URL: %v", err))
s.CloseWithError(fmt.Errorf("error generating object URL: %v", err))
}
if iterErr := s.Close(request.URL); iterErr != nil {
s.logger.Errorf("getAudioFromS3State: error closing progress iterator: %v", iterErr)
if iterErr := s.Close(url); iterErr != nil {
s.logger.Errorf("getAudioFromFileStoreState: error closing progress iterator: %v", iterErr)
}
}
@ -381,25 +366,20 @@ func (s *MediaSetService) GetAudioSegment(ctx context.Context, id uuid.UUID, sta
return nil, fmt.Errorf("error getting media set: %v", err)
}
byteRange := fmt.Sprintf(
"bytes=%d-%d",
object, err := s.fileStore.GetObjectWithRange(
ctx,
mediaSet.AudioRawS3Key.String,
startFrame*int64(mediaSet.AudioChannels)*SizeOfInt16,
endFrame*int64(mediaSet.AudioChannels)*SizeOfInt16,
)
input := s3.GetObjectInput{
Bucket: aws.String(mediaSet.AudioRawS3Bucket.String),
Key: aws.String(mediaSet.AudioRawS3Key.String),
Range: aws.String(byteRange),
}
output, err := s.s3.GetObject(ctx, &input)
if err != nil {
return nil, fmt.Errorf("error getting object from s3: %v", err)
return nil, fmt.Errorf("error getting object from file store: %v", err)
}
defer output.Body.Close()
defer object.Close()
const readBufSizeBytes = 8_192
channels := int(mediaSet.AudioChannels)
modReader := NewModuloBufReader(output.Body, channels*SizeOfInt16)
modReader := NewModuloBufReader(object, channels*SizeOfInt16)
readBuf := make([]byte, readBufSizeBytes)
peaks := make([]int16, channels*numBins)
totalFrames := endFrame - startFrame
@ -454,7 +434,7 @@ func (s *MediaSetService) GetAudioSegment(ctx context.Context, id uuid.UUID, sta
}
if bytesRead < bytesExpected {
s.logger.With("startFrame", startFrame, "endFrame", endFrame, "got", bytesRead, "want", bytesExpected, "key", mediaSet.AudioRawS3Key.String).Warn("short read from S3")
s.logger.With("startFrame", startFrame, "endFrame", endFrame, "got", bytesRead, "want", bytesExpected, "key", mediaSet.AudioRawS3Key.String).Warn("short read from file store")
}
return peaks, nil
@ -521,26 +501,22 @@ func (s *MediaSetService) GetVideoThumbnail(ctx context.Context, id uuid.UUID) (
}
if mediaSet.VideoThumbnailS3UploadedAt.Valid {
return s.getThumbnailFromS3(ctx, mediaSet)
return s.getThumbnailFromFileStore(ctx, mediaSet)
}
return s.getThumbnailFromYoutube(ctx, mediaSet)
}
func (s *MediaSetService) getThumbnailFromS3(ctx context.Context, mediaSet store.MediaSet) (VideoThumbnail, error) {
input := s3.GetObjectInput{
Bucket: aws.String(mediaSet.VideoThumbnailS3Bucket.String),
Key: aws.String(mediaSet.VideoThumbnailS3Key.String),
}
output, err := s.s3.GetObject(ctx, &input)
func (s *MediaSetService) getThumbnailFromFileStore(ctx context.Context, mediaSet store.MediaSet) (VideoThumbnail, error) {
object, err := s.fileStore.GetObject(ctx, mediaSet.VideoThumbnailS3Key.String)
if err != nil {
return VideoThumbnail{}, fmt.Errorf("error fetching thumbnail from s3: %v", err)
return VideoThumbnail{}, fmt.Errorf("error fetching thumbnail from file store: %v", err)
}
defer output.Body.Close()
defer object.Close()
imageData, err := io.ReadAll(output.Body)
imageData, err := io.ReadAll(object)
if err != nil {
return VideoThumbnail{}, fmt.Errorf("error reading thumbnail from s3: %v", err)
return VideoThumbnail{}, fmt.Errorf("error reading thumbnail from file store: %v", err)
}
return VideoThumbnail{
@ -575,13 +551,11 @@ func (s *MediaSetService) getThumbnailFromYoutube(ctx context.Context, mediaSet
return VideoThumbnail{}, fmt.Errorf("error reading thumbnail: %v", err)
}
// TODO: use mediaSet func to fetch s3Key
s3Key := fmt.Sprintf("media_sets/%s/thumbnail.jpg", mediaSet.ID)
// TODO: use mediaSet func to fetch key
thumbnailKey := fmt.Sprintf("media_sets/%s/thumbnail.jpg", mediaSet.ID)
uploader := newMultipartUploader(s.s3, s.logger)
const mimeType = "application/jpeg"
_, err = uploader.Upload(ctx, bytes.NewReader(imageData), s.config.S3Bucket, s3Key, mimeType)
_, err = s.fileStore.PutObject(ctx, thumbnailKey, bytes.NewReader(imageData), mimeType)
if err != nil {
return VideoThumbnail{}, fmt.Errorf("error uploading thumbnail: %v", err)
}
@ -590,7 +564,7 @@ func (s *MediaSetService) getThumbnailFromYoutube(ctx context.Context, mediaSet
ID: mediaSet.ID,
VideoThumbnailMimeType: sqlString(mimeType),
VideoThumbnailS3Bucket: sqlString(s.config.S3Bucket),
VideoThumbnailS3Key: sqlString(s3Key),
VideoThumbnailS3Key: sqlString(thumbnailKey),
VideoThumbnailWidth: sqlInt32(int32(thumbnail.Width)),
VideoThumbnailHeight: sqlInt32(int32(thumbnail.Height)),
}
@ -614,7 +588,7 @@ func newProgressReader(reader io.Reader, label string, exp int64, logger *zap.Su
return &progressReader{
Reader: reader,
exp: exp,
logger: logger.Named(fmt.Sprintf("ProgressReader %s", label)),
logger: logger.Named(label),
}
}

View File

@ -1,8 +1,9 @@
package media_test
import (
"bytes"
"context"
"fmt"
"database/sql"
"io"
"os"
"testing"
@ -11,7 +12,6 @@ import (
"git.netflux.io/rob/clipper/generated/mocks"
"git.netflux.io/rob/clipper/generated/store"
"git.netflux.io/rob/clipper/media"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
@ -84,31 +84,34 @@ func TestGetAudioSegment(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
expectedBytes := (tc.endFrame - tc.startFrame) * int64(tc.channels) * media.SizeOfInt16
startByte := tc.startFrame * int64(tc.channels) * media.SizeOfInt16
endByte := tc.endFrame * int64(tc.channels) * media.SizeOfInt16
expectedBytes := endByte - startByte
audioFile, err := os.Open(tc.fixturePath)
require.NoError(t, err)
defer audioFile.Close()
audioData := io.NopCloser(io.LimitReader(audioFile, int64(expectedBytes)))
mediaSetID := uuid.New()
mediaSet := store.MediaSet{ID: mediaSetID, AudioChannels: tc.channels}
mediaSet := store.MediaSet{
ID: uuid.New(),
AudioChannels: tc.channels,
AudioRawS3Key: sql.NullString{String: "foo", Valid: true},
}
// store is passed the mediaSetID and returns a mediaSet
store := &mocks.Store{}
store.On("GetMediaSet", mock.Anything, mediaSetID).Return(mediaSet, nil)
store.On("GetMediaSet", mock.Anything, mediaSet.ID).Return(mediaSet, nil)
defer store.AssertExpectations(t)
// S3 is passed the expected byte range, and returns an io.Reader
s3Client := &mocks.S3Client{}
s3Client.On("GetObject", mock.Anything, mock.MatchedBy(func(input *s3.GetObjectInput) bool {
return *input.Range == fmt.Sprintf("bytes=0-%d", expectedBytes)
})).Return(&s3.GetObjectOutput{Body: audioData, ContentLength: tc.fixtureLen}, nil)
defer s3Client.AssertExpectations(t)
s3API := media.S3API{S3Client: s3Client, S3PresignClient: &mocks.S3PresignClient{}}
// fileStore is passed the expected byte range, and returns an io.Reader
fileStore := &mocks.FileStore{}
fileStore.
On("GetObjectWithRange", mock.Anything, "foo", startByte, endByte).
Return(audioData, nil)
service := media.NewMediaSetService(store, nil, s3API, config.Config{}, zap.NewNop())
peaks, err := service.GetAudioSegment(context.Background(), mediaSetID, tc.startFrame, tc.endFrame, tc.numBins)
service := media.NewMediaSetService(store, nil, fileStore, config.Config{}, zap.NewNop().Sugar())
peaks, err := service.GetAudioSegment(context.Background(), mediaSet.ID, tc.startFrame, tc.endFrame, tc.numBins)
if tc.wantErr == "" {
assert.NoError(t, err)
@ -130,14 +133,10 @@ func BenchmarkGetAudioSegment(b *testing.B) {
numBins = 2000
)
for n := 0; n < b.N; n++ {
b.StopTimer()
expectedBytes := (endFrame - startFrame) * int64(channels) * media.SizeOfInt16
audioFile, err := os.Open(fixturePath)
require.NoError(b, err)
audioData := io.NopCloser(io.LimitReader(audioFile, int64(expectedBytes)))
audioData, err := io.ReadAll(audioFile)
require.NoError(b, err)
mediaSetID := uuid.New()
mediaSet := store.MediaSet{ID: mediaSetID, AudioChannels: channels}
@ -145,20 +144,20 @@ func BenchmarkGetAudioSegment(b *testing.B) {
store := &mocks.Store{}
store.On("GetMediaSet", mock.Anything, mediaSetID).Return(mediaSet, nil)
s3Client := &mocks.S3Client{}
s3Client.
On("GetObject", mock.Anything, mock.Anything).
Return(&s3.GetObjectOutput{Body: audioData, ContentLength: fixtureLen}, nil)
s3API := media.S3API{S3Client: s3Client, S3PresignClient: &mocks.S3PresignClient{}}
service := media.NewMediaSetService(store, nil, s3API, config.Config{}, zap.NewNop())
b.StartTimer()
_, err = service.GetAudioSegment(context.Background(), mediaSetID, startFrame, endFrame, numBins)
for n := 0; n < b.N; n++ {
// recreate the reader on each iteration
b.StopTimer()
readCloser := io.NopCloser(bytes.NewReader(audioData))
fileStore := &mocks.FileStore{}
fileStore.
On("GetObjectWithRange", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(readCloser, nil)
service := media.NewMediaSetService(store, nil, fileStore, config.Config{}, zap.NewNop().Sugar())
b.StartTimer()
_, err = service.GetAudioSegment(context.Background(), mediaSetID, startFrame, endFrame, numBins)
require.NoError(b, err)
audioFile.Close()
}
}

View File

@ -1,8 +1,8 @@
package media
//go:generate mockery --recursive --name S3* --output ../generated/mocks
//go:generate mockery --recursive --name Store --output ../generated/mocks
//go:generate mockery --recursive --name YoutubeClient --output ../generated/mocks
//go:generate mockery --recursive --name FileStore --output ../generated/mocks
import (
"context"
@ -10,8 +10,6 @@ import (
"time"
"git.netflux.io/rob/clipper/generated/store"
signerv4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/google/uuid"
youtubev2 "github.com/kkdai/youtube/v2"
)
@ -63,28 +61,16 @@ type Store interface {
SetVideoThumbnailUploaded(context.Context, store.SetVideoThumbnailUploadedParams) (store.MediaSet, error)
}
// S3API provides an API to AWS S3.
type S3API struct {
S3Client
S3PresignClient
}
// S3Client wraps the AWS S3 service client.
type S3Client interface {
GetObject(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)
CreateMultipartUpload(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error)
UploadPart(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) (*s3.UploadPartOutput, error)
AbortMultipartUpload(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error)
CompleteMultipartUpload(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error)
}
// S3PresignClient wraps the AWS S3 Presign client.
type S3PresignClient interface {
PresignGetObject(context.Context, *s3.GetObjectInput, ...func(*s3.PresignOptions)) (*signerv4.PresignedHTTPRequest, error)
// FileStore wraps a file store.
type FileStore interface {
GetObject(ctx context.Context, key string) (io.ReadCloser, error)
GetObjectWithRange(ctx context.Context, key string, startFrame, endFrame int64) (io.ReadCloser, error)
GetURL(ctx context.Context, key string) (string, error)
PutObject(ctx context.Context, key string, reader io.Reader, contentType string) (int64, error)
}
// YoutubeClient wraps the youtube.Client client.
type YoutubeClient interface {
GetVideoContext(context.Context, string) (*youtubev2.Video, error)
GetStreamContext(context.Context, *youtubev2.Video, *youtubev2.Format) (io.ReadCloser, int64, error)
GetVideoContext(ctx context.Context, id string) (*youtubev2.Video, error)
GetStreamContext(ctx context.Context, video *youtubev2.Video, format *youtubev2.Format) (io.ReadCloser, int64, error)
}

View File

@ -1,212 +0,0 @@
package media
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"sort"
"sync"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"go.uber.org/zap"
)
// multipartUploader uploads a file to S3.
//
// TODO: extract to s3 package
type multipartUploader struct {
s3 S3Client
logger *zap.SugaredLogger
}
type uploadResult struct {
completedPart types.CompletedPart
size int64
}
const (
targetPartSizeBytes = 5 * 1024 * 1024 // 5MB
readBufferSizeBytes = 32_768 // 32Kb
)
func newMultipartUploader(s3Client S3Client, logger *zap.SugaredLogger) *multipartUploader {
return &multipartUploader{s3: s3Client, logger: logger}
}
// Upload uploads to an S3 bucket in 5MB parts. It buffers data internally
// until a part is ready to send over the network. Parts are sent as soon as
// they exceed the minimum part size of 5MB.
//
// TODO: expire after configurable period.
func (u *multipartUploader) Upload(ctx context.Context, r io.Reader, bucket, key, contentType string) (int64, error) {
var uploaded bool
input := s3.CreateMultipartUploadInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
ContentType: aws.String(contentType),
}
output, err := u.s3.CreateMultipartUpload(ctx, &input)
if err != nil {
return 0, fmt.Errorf("error creating multipart upload: %v", err)
}
// abort the upload if possible, logging any errors, on exit.
defer func() {
if uploaded {
return
}
input := s3.AbortMultipartUploadInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
UploadId: output.UploadId,
}
// if the context was cancelled, just use the background context.
ctxToUse := ctx
if ctxToUse.Err() != nil {
ctxToUse = context.Background()
}
_, abortErr := u.s3.AbortMultipartUpload(ctxToUse, &input)
if abortErr != nil {
u.logger.Errorf("uploader: error aborting upload: %v", abortErr)
} else {
u.logger.Infof("aborted upload, key = %s", key)
}
}()
uploadResultChan := make(chan uploadResult)
uploadErrorChan := make(chan error, 1)
// uploadPart uploads an individual part.
uploadPart := func(wg *sync.WaitGroup, buf []byte, partNum int32) {
defer wg.Done()
partLen := int64(len(buf))
u.logger.With("key", key, "partNum", partNum, "partLen", partLen).Debug("uploading part")
input := s3.UploadPartInput{
Body: bytes.NewReader(buf),
Bucket: aws.String(bucket),
Key: aws.String(key),
PartNumber: partNum,
UploadId: output.UploadId,
ContentLength: partLen,
}
output, uploadErr := u.s3.UploadPart(ctx, &input)
if uploadErr != nil {
// TODO: retry on failure
uploadErrorChan <- uploadErr
return
}
u.logger.With("key", key, "partNum", partNum, "partLen", partLen, "etag", *output.ETag).Debug("uploaded part")
uploadResultChan <- uploadResult{
completedPart: types.CompletedPart{ETag: output.ETag, PartNumber: partNum},
size: partLen,
}
}
wgDone := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1) // done when the reader goroutine returns
go func() {
wg.Wait()
wgDone <- struct{}{}
}()
readChan := make(chan error, 1)
go func() {
defer wg.Done()
var closing bool
currPart := bytes.NewBuffer(make([]byte, 0, targetPartSizeBytes+readBufferSizeBytes))
partNum := int32(1)
buf := make([]byte, readBufferSizeBytes)
for {
n, readErr := r.Read(buf)
if readErr == io.EOF {
closing = true
} else if readErr != nil {
readChan <- readErr
return
}
_, _ = currPart.Write(buf[:n])
if closing || currPart.Len() >= targetPartSizeBytes {
part := make([]byte, currPart.Len())
copy(part, currPart.Bytes())
currPart.Truncate(0)
wg.Add(1)
go uploadPart(&wg, part, partNum)
partNum++
}
if closing {
return
}
}
}()
results := make([]uploadResult, 0, 64)
outer:
for {
select {
case readErr := <-readChan:
if readErr != io.EOF {
return 0, fmt.Errorf("reader error: %v", readErr)
}
case uploadResult := <-uploadResultChan:
results = append(results, uploadResult)
case uploadErr := <-uploadErrorChan:
return 0, fmt.Errorf("error while uploading part: %v", uploadErr)
case <-ctx.Done():
return 0, ctx.Err()
case <-wgDone:
break outer
}
}
if len(results) == 0 {
return 0, errors.New("no parts available to upload")
}
completedParts := make([]types.CompletedPart, 0, 64)
var uploadedBytes int64
for _, result := range results {
completedParts = append(completedParts, result.completedPart)
uploadedBytes += result.size
}
// the parts may be out of order, especially with slow network conditions:
sort.Slice(completedParts, func(i, j int) bool {
return completedParts[i].PartNumber < completedParts[j].PartNumber
})
completeInput := s3.CompleteMultipartUploadInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
UploadId: output.UploadId,
MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts},
}
if _, err = u.s3.CompleteMultipartUpload(ctx, &completeInput); err != nil {
return 0, fmt.Errorf("error completing upload: %v", err)
}
u.logger.With("key", key, "numParts", len(completedParts), "len", uploadedBytes).Debug("completed upload")
uploaded = true
return uploadedBytes, nil
}

View File

@ -64,7 +64,8 @@ type Options struct {
Timeout time.Duration
Store media.Store
YoutubeClient media.YoutubeClient
S3API media.S3API
FileStore media.FileStore
Logger *zap.Logger
}
// mediaSetServiceController implements gRPC controller for MediaSetService
@ -226,32 +227,26 @@ func (c *mediaSetServiceController) GetVideoThumbnail(ctx context.Context, reque
}
func Start(options Options) error {
logger, err := buildLogger(options.Config)
if err != nil {
return fmt.Errorf("error building logger: %v", err)
}
defer logger.Sync()
fetchMediaSetService := media.NewMediaSetService(
options.Store,
options.YoutubeClient,
options.S3API,
options.FileStore,
options.Config,
logger,
options.Logger.Sugar().Named("mediaSetService"),
)
grpcServer, err := buildGRPCServer(options.Config, logger)
grpcServer, err := buildGRPCServer(options.Config, options.Logger)
if err != nil {
return fmt.Errorf("error building server: %v", err)
}
mediaSetController := &mediaSetServiceController{mediaSetService: fetchMediaSetService, logger: logger.Sugar().Named("controller")}
mediaSetController := &mediaSetServiceController{mediaSetService: fetchMediaSetService, logger: options.Logger.Sugar().Named("controller")}
pbmediaset.RegisterMediaSetServiceServer(grpcServer, mediaSetController)
// TODO: configure CORS
grpcWebServer := grpcweb.WrapServer(grpcServer, grpcweb.WithOriginFunc(func(string) bool { return true }))
log := logger.Sugar()
log := options.Logger.Sugar()
fileHandler := http.NotFoundHandler()
if options.Config.AssetsHTTPBasePath != "" {
log.With("basePath", options.Config.AssetsHTTPBasePath).Info("Configured to serve assets over HTTP")
@ -280,13 +275,6 @@ func Start(options Options) error {
return httpServer.ListenAndServe()
}
func buildLogger(c config.Config) (*zap.Logger, error) {
if c.Environment == config.EnvProduction {
return zap.NewProduction()
}
return zap.NewDevelopment()
}
func buildGRPCServer(c config.Config, logger *zap.Logger) (*grpc.Server, error) {
unaryInterceptors := []grpc.UnaryServerInterceptor{
grpczap.UnaryServerInterceptor(logger),