Extract S3 code to S3FileStore

Re: #5
This commit is contained in:
Rob Watson 2021-12-07 20:58:11 +01:00
parent c849b8d2e6
commit f2d7af0860
13 changed files with 569 additions and 661 deletions

View File

@ -57,7 +57,7 @@ sqlc generate
Mocks require [mockery](https://github.com/vektra/mockery) to be installed, and can be regenerated with: Mocks require [mockery](https://github.com/vektra/mockery) to be installed, and can be regenerated with:
``` ```
go generate go generate ./...
``` ```
### Migrations ### Migrations

View File

@ -6,18 +6,20 @@ import (
"time" "time"
"git.netflux.io/rob/clipper/config" "git.netflux.io/rob/clipper/config"
"git.netflux.io/rob/clipper/filestore"
"git.netflux.io/rob/clipper/generated/store" "git.netflux.io/rob/clipper/generated/store"
"git.netflux.io/rob/clipper/media"
"git.netflux.io/rob/clipper/server" "git.netflux.io/rob/clipper/server"
awsconfig "github.com/aws/aws-sdk-go-v2/config" awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/jackc/pgx/v4/pgxpool" "github.com/jackc/pgx/v4/pgxpool"
"github.com/kkdai/youtube/v2" "github.com/kkdai/youtube/v2"
"go.uber.org/zap"
) )
const ( const (
DefaultTimeout = 600 * time.Second defaultTimeout = 600 * time.Second
defaultURLExpiry = time.Hour
) )
func main() { func main() {
@ -35,6 +37,9 @@ func main() {
} }
store := store.New(dbConn) store := store.New(dbConn)
// Create a Youtube client
var youtubeClient youtube.Client
// Create an Amazon S3 service s3Client // Create an Amazon S3 service s3Client
cfg, err := awsconfig.LoadDefaultConfig( cfg, err := awsconfig.LoadDefaultConfig(
ctx, ctx,
@ -47,23 +52,39 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
s3Client := s3.NewFromConfig(cfg) s3Client := s3.NewFromConfig(cfg)
// Create an Amazon S3 presign client
s3PresignClient := s3.NewPresignClient(s3Client) s3PresignClient := s3.NewPresignClient(s3Client)
// Create a Youtube client // Create a logger
var youtubeClient youtube.Client logger, err := buildLogger(config)
if err != nil {
log.Fatal(err)
}
defer logger.Sync()
serverOptions := server.Options{ // Create a file store
Config: config, fileStore := filestore.NewS3FileStore(
Timeout: DefaultTimeout, filestore.S3API{
Store: store,
YoutubeClient: &youtubeClient,
S3API: media.S3API{
S3Client: s3Client, S3Client: s3Client,
S3PresignClient: s3PresignClient, S3PresignClient: s3PresignClient,
}, },
} config.S3Bucket,
defaultURLExpiry,
logger.Sugar().Named("filestore"),
)
log.Fatal(server.Start(serverOptions)) log.Fatal(server.Start(server.Options{
Config: config,
Timeout: defaultTimeout,
Store: store,
YoutubeClient: &youtubeClient,
FileStore: fileStore,
Logger: logger,
}))
}
func buildLogger(c config.Config) (*zap.Logger, error) {
if c.Environment == config.EnvDevelopment {
return zap.NewDevelopment()
}
return zap.NewProduction()
} }

280
backend/filestore/s3.go Normal file
View File

@ -0,0 +1,280 @@
package filestore
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"sort"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
signerv4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"go.uber.org/zap"
)
// S3API provides an API to AWS S3.
//
// It bundles the regular S3 client (object retrieval and the multipart-upload
// lifecycle) with the presign client (URL generation) into a single value, so
// that S3FileStore can depend on one field for both concerns.
type S3API struct {
	S3Client
	S3PresignClient
}
// S3Client wraps the AWS S3 service client.
//
// It declares, at the consumer, the subset of the SDK's s3.Client methods
// used by S3FileStore — simple object retrieval plus the multipart-upload
// lifecycle (create, upload part, abort, complete) — allowing the real
// client to be swapped for a mock in tests.
type S3Client interface {
	GetObject(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)
	CreateMultipartUpload(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error)
	UploadPart(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) (*s3.UploadPartOutput, error)
	AbortMultipartUpload(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error)
	CompleteMultipartUpload(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error)
}
// S3PresignClient wraps the AWS S3 Presign client.
//
// Only PresignGetObject is needed (by S3FileStore.GetURL); keeping the
// interface minimal makes it trivial to mock.
type S3PresignClient interface {
	PresignGetObject(context.Context, *s3.GetObjectInput, ...func(*s3.PresignOptions)) (*signerv4.PresignedHTTPRequest, error)
}
// S3FileStore stores files on Amazon S3.
type S3FileStore struct {
	s3        S3API              // S3 service + presign clients
	bucket    string             // bucket in which all objects are stored
	urlExpiry time.Duration      // lifetime of presigned URLs returned by GetURL
	logger    *zap.SugaredLogger // structured logger for upload progress/errors
}
// NewS3FileStore builds a new S3FileStore using the provided configuration.
// Objects are read from and written to bucket, and presigned URLs returned by
// GetURL remain valid for urlExpiry.
func NewS3FileStore(s3API S3API, bucket string, urlExpiry time.Duration, logger *zap.SugaredLogger) *S3FileStore {
	fileStore := S3FileStore{
		s3:        s3API,
		bucket:    bucket,
		urlExpiry: urlExpiry,
		logger:    logger,
	}
	return &fileStore
}
// GetObject returns an io.ReadCloser that streams the object stored under the
// provided key. The caller is responsible for closing the returned reader.
func (s *S3FileStore) GetObject(ctx context.Context, key string) (io.ReadCloser, error) {
	input := s3.GetObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(key),
	}
	output, err := s.s3.GetObject(ctx, &input)
	if err != nil {
		// %w preserves the underlying AWS error so callers can inspect it
		// with errors.Is/errors.As; the message is unchanged.
		return nil, fmt.Errorf("error getting object from s3: %w", err)
	}
	return output.Body, nil
}
// GetObjectWithRange returns an io.ReadCloser that streams the byte range
// [start, end] — inclusive at both ends, per the HTTP Range header semantics
// (RFC 7233) — of the object stored under the provided key. The caller is
// responsible for closing the returned reader.
func (s *S3FileStore) GetObjectWithRange(ctx context.Context, key string, start, end int64) (io.ReadCloser, error) {
	byteRange := fmt.Sprintf("bytes=%d-%d", start, end)
	input := s3.GetObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(key),
		Range:  aws.String(byteRange),
	}
	output, err := s.s3.GetObject(ctx, &input)
	if err != nil {
		// %w preserves the underlying AWS error for errors.Is/errors.As.
		return nil, fmt.Errorf("error getting object from s3: %w", err)
	}
	return output.Body, nil
}
// GetURL returns a presigned URL pointing to the object associated with the
// provided key. The URL is valid for the store's configured urlExpiry.
func (s *S3FileStore) GetURL(ctx context.Context, key string) (string, error) {
	input := s3.GetObjectInput{
		Bucket: aws.String(s.bucket),
		Key:    aws.String(key),
	}
	request, err := s.s3.PresignGetObject(ctx, &input, s3.WithPresignExpires(s.urlExpiry))
	if err != nil {
		// %w preserves the underlying AWS error for errors.Is/errors.As.
		return "", fmt.Errorf("error generating presigned URL: %w", err)
	}
	return request.URL, nil
}
// PutObject uploads an object using multipart upload, returning the number of
// bytes uploaded and any error.
//
// The reader is consumed in readBufferSizeBytes chunks which are accumulated
// into parts of at least targetPartSizeBytes (S3's 5MB minimum part size);
// each completed part is uploaded by its own goroutine. If the upload does
// not complete for any reason, the multipart upload is aborted on exit so
// that S3 does not retain the orphaned parts.
func (s *S3FileStore) PutObject(ctx context.Context, key string, r io.Reader, contentType string) (int64, error) {
	const (
		targetPartSizeBytes = 5 * 1024 * 1024 // 5MB
		readBufferSizeBytes = 32_768          // 32Kb
	)

	var uploaded bool

	input := s3.CreateMultipartUploadInput{
		Bucket:      aws.String(s.bucket),
		Key:         aws.String(key),
		ContentType: aws.String(contentType),
	}
	output, err := s.s3.CreateMultipartUpload(ctx, &input)
	if err != nil {
		return 0, fmt.Errorf("error creating multipart upload: %w", err)
	}

	// abort the upload if possible, logging any errors, on exit.
	defer func() {
		if uploaded {
			return
		}
		input := s3.AbortMultipartUploadInput{
			Bucket:   aws.String(s.bucket),
			Key:      aws.String(key),
			UploadId: output.UploadId,
		}
		// if the context was cancelled, just use the background context.
		ctxToUse := ctx
		if ctxToUse.Err() != nil {
			ctxToUse = context.Background()
		}
		_, deferErr := s.s3.AbortMultipartUpload(ctxToUse, &input)
		if deferErr != nil {
			s.logger.Errorf("uploader: error aborting upload: %v", deferErr)
		} else {
			s.logger.Infof("aborted upload, key = %s", key)
		}
	}()

	// uploadedPart pairs a completed part with its size, so the total byte
	// count can be summed after all parts finish.
	type uploadedPart struct {
		part types.CompletedPart
		size int64
	}
	uploadResultChan := make(chan uploadedPart)
	uploadErrorChan := make(chan error, 1)

	// uploadPart uploads an individual part.
	uploadPart := func(wg *sync.WaitGroup, buf []byte, partNum int32) {
		defer wg.Done()
		partLen := int64(len(buf))
		s.logger.With("key", key, "partNum", partNum, "partLen", partLen).Debug("uploading part")
		input := s3.UploadPartInput{
			Body:          bytes.NewReader(buf),
			Bucket:        aws.String(s.bucket),
			Key:           aws.String(key),
			PartNumber:    partNum,
			UploadId:      output.UploadId,
			ContentLength: partLen,
		}
		// named uploadOutput (not output) to avoid shadowing the
		// CreateMultipartUploadOutput captured by this closure.
		uploadOutput, uploadErr := s.s3.UploadPart(ctx, &input)
		if uploadErr != nil {
			// TODO: retry on failure
			uploadErrorChan <- uploadErr
			return
		}
		s.logger.With("key", key, "partNum", partNum, "partLen", partLen, "etag", *uploadOutput.ETag).Debug("uploaded part")
		uploadResultChan <- uploadedPart{
			part: types.CompletedPart{ETag: uploadOutput.ETag, PartNumber: partNum},
			size: partLen,
		}
	}

	wgDone := make(chan struct{})
	var wg sync.WaitGroup
	wg.Add(1) // done when the reader goroutine returns
	go func() {
		wg.Wait()
		// Close the channel rather than sending on it: a send would block
		// forever (leaking this goroutine) if the select loop below has
		// already returned with an error and will never receive.
		close(wgDone)
	}()

	readChan := make(chan error, 1)
	go func() {
		defer wg.Done()
		var closing bool
		currPart := bytes.NewBuffer(make([]byte, 0, targetPartSizeBytes+readBufferSizeBytes))
		partNum := int32(1)
		buf := make([]byte, readBufferSizeBytes)
		for {
			n, readErr := r.Read(buf)
			if readErr == io.EOF {
				// flush the final (possibly short) part below, then stop.
				closing = true
			} else if readErr != nil {
				readChan <- readErr
				return
			}
			_, _ = currPart.Write(buf[:n])
			if closing || currPart.Len() >= targetPartSizeBytes {
				// copy the accumulated bytes so the buffer can be reused
				// while the part uploads concurrently.
				part := make([]byte, currPart.Len())
				copy(part, currPart.Bytes())
				currPart.Truncate(0)
				wg.Add(1)
				go uploadPart(&wg, part, partNum)
				partNum++
			}
			if closing {
				return
			}
		}
	}()

	results := make([]uploadedPart, 0, 64)
outer:
	for {
		select {
		case readErr := <-readChan:
			// Defensive: the reader goroutine handles io.EOF itself and
			// never forwards it here, but guard against it all the same.
			if readErr != io.EOF {
				return 0, fmt.Errorf("reader error: %w", readErr)
			}
		case uploadResult := <-uploadResultChan:
			results = append(results, uploadResult)
		case uploadErr := <-uploadErrorChan:
			// NOTE(review): returning here can strand in-flight uploadPart
			// goroutines blocked sending on uploadResultChan; a fuller fix
			// would drain the channels until the WaitGroup completes.
			return 0, fmt.Errorf("error while uploading part: %w", uploadErr)
		case <-ctx.Done():
			return 0, ctx.Err()
		case <-wgDone:
			break outer
		}
	}

	if len(results) == 0 {
		return 0, errors.New("no parts available to upload")
	}

	completedParts := make([]types.CompletedPart, 0, 64)
	var uploadedBytes int64
	for _, result := range results {
		completedParts = append(completedParts, result.part)
		uploadedBytes += result.size
	}

	// the parts may be out of order, especially with slow network conditions:
	sort.Slice(completedParts, func(i, j int) bool {
		return completedParts[i].PartNumber < completedParts[j].PartNumber
	})

	completeInput := s3.CompleteMultipartUploadInput{
		Bucket:          aws.String(s.bucket),
		Key:             aws.String(key),
		UploadId:        output.UploadId,
		MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts},
	}
	if _, err = s.s3.CompleteMultipartUpload(ctx, &completeInput); err != nil {
		return 0, fmt.Errorf("error completing upload: %w", err)
	}

	s.logger.With("key", key, "numParts", len(completedParts), "len", uploadedBytes).Debug("completed upload")
	uploaded = true
	return uploadedBytes, nil
}

View File

@ -0,0 +1,103 @@
// Code generated by mockery v2.9.4. DO NOT EDIT.
package mocks
import (
context "context"
io "io"
mock "github.com/stretchr/testify/mock"
)
// FileStore is an autogenerated mock type for the FileStore type
//
// NOTE(review): this file is generated by mockery (see file header, "DO NOT
// EDIT") — do not modify by hand; regenerate with `go generate ./...` after
// changing the FileStore interface.
type FileStore struct {
	mock.Mock
}

// GetObject provides a mock function with given fields: _a0, _a1
func (_m *FileStore) GetObject(_a0 context.Context, _a1 string) (io.ReadCloser, error) {
	ret := _m.Called(_a0, _a1)

	// If the test registered a function return value, invoke it; otherwise
	// fall back to the statically-configured return value.
	var r0 io.ReadCloser
	if rf, ok := ret.Get(0).(func(context.Context, string) io.ReadCloser); ok {
		r0 = rf(_a0, _a1)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(io.ReadCloser)
		}
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
		r1 = rf(_a0, _a1)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

// GetObjectWithRange provides a mock function with given fields: _a0, _a1, _a2, _a3
func (_m *FileStore) GetObjectWithRange(_a0 context.Context, _a1 string, _a2 int64, _a3 int64) (io.ReadCloser, error) {
	ret := _m.Called(_a0, _a1, _a2, _a3)

	var r0 io.ReadCloser
	if rf, ok := ret.Get(0).(func(context.Context, string, int64, int64) io.ReadCloser); ok {
		r0 = rf(_a0, _a1, _a2, _a3)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(io.ReadCloser)
		}
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context, string, int64, int64) error); ok {
		r1 = rf(_a0, _a1, _a2, _a3)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

// GetURL provides a mock function with given fields: _a0, _a1
func (_m *FileStore) GetURL(_a0 context.Context, _a1 string) (string, error) {
	ret := _m.Called(_a0, _a1)

	var r0 string
	if rf, ok := ret.Get(0).(func(context.Context, string) string); ok {
		r0 = rf(_a0, _a1)
	} else {
		r0 = ret.Get(0).(string)
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
		r1 = rf(_a0, _a1)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

// PutObject provides a mock function with given fields: _a0, _a1, _a2, _a3
func (_m *FileStore) PutObject(_a0 context.Context, _a1 string, _a2 io.Reader, _a3 string) (int64, error) {
	ret := _m.Called(_a0, _a1, _a2, _a3)

	var r0 int64
	if rf, ok := ret.Get(0).(func(context.Context, string, io.Reader, string) int64); ok {
		r0 = rf(_a0, _a1, _a2, _a3)
	} else {
		r0 = ret.Get(0).(int64)
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context, string, io.Reader, string) error); ok {
		r1 = rf(_a0, _a1, _a2, _a3)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

View File

@ -1,166 +0,0 @@
// Code generated by mockery v2.9.4. DO NOT EDIT.
package mocks
import (
context "context"
mock "github.com/stretchr/testify/mock"
s3 "github.com/aws/aws-sdk-go-v2/service/s3"
)
// S3Client is an autogenerated mock type for the S3Client type
//
// NOTE(review): generated by mockery (see file header, "DO NOT EDIT") — do
// not modify by hand; regenerate after changing the S3Client interface.
type S3Client struct {
	mock.Mock
}

// AbortMultipartUpload provides a mock function with given fields: _a0, _a1, _a2
func (_m *S3Client) AbortMultipartUpload(_a0 context.Context, _a1 *s3.AbortMultipartUploadInput, _a2 ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error) {
	// Flatten the variadic options into the argument list recorded by
	// testify so expectations can match on them.
	_va := make([]interface{}, len(_a2))
	for _i := range _a2 {
		_va[_i] = _a2[_i]
	}
	var _ca []interface{}
	_ca = append(_ca, _a0, _a1)
	_ca = append(_ca, _va...)
	ret := _m.Called(_ca...)

	var r0 *s3.AbortMultipartUploadOutput
	if rf, ok := ret.Get(0).(func(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) *s3.AbortMultipartUploadOutput); ok {
		r0 = rf(_a0, _a1, _a2...)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*s3.AbortMultipartUploadOutput)
		}
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) error); ok {
		r1 = rf(_a0, _a1, _a2...)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

// CompleteMultipartUpload provides a mock function with given fields: _a0, _a1, _a2
func (_m *S3Client) CompleteMultipartUpload(_a0 context.Context, _a1 *s3.CompleteMultipartUploadInput, _a2 ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error) {
	_va := make([]interface{}, len(_a2))
	for _i := range _a2 {
		_va[_i] = _a2[_i]
	}
	var _ca []interface{}
	_ca = append(_ca, _a0, _a1)
	_ca = append(_ca, _va...)
	ret := _m.Called(_ca...)

	var r0 *s3.CompleteMultipartUploadOutput
	if rf, ok := ret.Get(0).(func(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) *s3.CompleteMultipartUploadOutput); ok {
		r0 = rf(_a0, _a1, _a2...)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*s3.CompleteMultipartUploadOutput)
		}
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) error); ok {
		r1 = rf(_a0, _a1, _a2...)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

// CreateMultipartUpload provides a mock function with given fields: _a0, _a1, _a2
func (_m *S3Client) CreateMultipartUpload(_a0 context.Context, _a1 *s3.CreateMultipartUploadInput, _a2 ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error) {
	_va := make([]interface{}, len(_a2))
	for _i := range _a2 {
		_va[_i] = _a2[_i]
	}
	var _ca []interface{}
	_ca = append(_ca, _a0, _a1)
	_ca = append(_ca, _va...)
	ret := _m.Called(_ca...)

	var r0 *s3.CreateMultipartUploadOutput
	if rf, ok := ret.Get(0).(func(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) *s3.CreateMultipartUploadOutput); ok {
		r0 = rf(_a0, _a1, _a2...)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*s3.CreateMultipartUploadOutput)
		}
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) error); ok {
		r1 = rf(_a0, _a1, _a2...)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

// GetObject provides a mock function with given fields: _a0, _a1, _a2
func (_m *S3Client) GetObject(_a0 context.Context, _a1 *s3.GetObjectInput, _a2 ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
	_va := make([]interface{}, len(_a2))
	for _i := range _a2 {
		_va[_i] = _a2[_i]
	}
	var _ca []interface{}
	_ca = append(_ca, _a0, _a1)
	_ca = append(_ca, _va...)
	ret := _m.Called(_ca...)

	var r0 *s3.GetObjectOutput
	if rf, ok := ret.Get(0).(func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) *s3.GetObjectOutput); ok {
		r0 = rf(_a0, _a1, _a2...)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*s3.GetObjectOutput)
		}
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) error); ok {
		r1 = rf(_a0, _a1, _a2...)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

// UploadPart provides a mock function with given fields: _a0, _a1, _a2
func (_m *S3Client) UploadPart(_a0 context.Context, _a1 *s3.UploadPartInput, _a2 ...func(*s3.Options)) (*s3.UploadPartOutput, error) {
	_va := make([]interface{}, len(_a2))
	for _i := range _a2 {
		_va[_i] = _a2[_i]
	}
	var _ca []interface{}
	_ca = append(_ca, _a0, _a1)
	_ca = append(_ca, _va...)
	ret := _m.Called(_ca...)

	var r0 *s3.UploadPartOutput
	if rf, ok := ret.Get(0).(func(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) *s3.UploadPartOutput); ok {
		r0 = rf(_a0, _a1, _a2...)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*s3.UploadPartOutput)
		}
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) error); ok {
		r1 = rf(_a0, _a1, _a2...)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

View File

@ -1,48 +0,0 @@
// Code generated by mockery v2.9.4. DO NOT EDIT.
package mocks
import (
context "context"
mock "github.com/stretchr/testify/mock"
s3 "github.com/aws/aws-sdk-go-v2/service/s3"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
)
// S3PresignClient is an autogenerated mock type for the S3PresignClient type
//
// NOTE(review): generated by mockery (see file header, "DO NOT EDIT") — do
// not modify by hand; regenerate after changing the S3PresignClient interface.
type S3PresignClient struct {
	mock.Mock
}

// PresignGetObject provides a mock function with given fields: _a0, _a1, _a2
func (_m *S3PresignClient) PresignGetObject(_a0 context.Context, _a1 *s3.GetObjectInput, _a2 ...func(*s3.PresignOptions)) (*v4.PresignedHTTPRequest, error) {
	// Flatten the variadic options into the argument list recorded by
	// testify so expectations can match on them.
	_va := make([]interface{}, len(_a2))
	for _i := range _a2 {
		_va[_i] = _a2[_i]
	}
	var _ca []interface{}
	_ca = append(_ca, _a0, _a1)
	_ca = append(_ca, _va...)
	ret := _m.Called(_ca...)

	var r0 *v4.PresignedHTTPRequest
	if rf, ok := ret.Get(0).(func(context.Context, *s3.GetObjectInput, ...func(*s3.PresignOptions)) *v4.PresignedHTTPRequest); ok {
		r0 = rf(_a0, _a1, _a2...)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*v4.PresignedHTTPRequest)
		}
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context, *s3.GetObjectInput, ...func(*s3.PresignOptions)) error); ok {
		r1 = rf(_a0, _a1, _a2...)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

View File

@ -13,8 +13,6 @@ import (
"git.netflux.io/rob/clipper/config" "git.netflux.io/rob/clipper/config"
"git.netflux.io/rob/clipper/generated/store" "git.netflux.io/rob/clipper/generated/store"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -31,27 +29,27 @@ type GetAudioProgressReader interface {
// audioGetter manages getting and processing audio from Youtube. // audioGetter manages getting and processing audio from Youtube.
type audioGetter struct { type audioGetter struct {
store Store store Store
youtube YoutubeClient youtube YoutubeClient
s3API S3API fileStore FileStore
config config.Config config config.Config
logger *zap.SugaredLogger logger *zap.SugaredLogger
} }
// newAudioGetter returns a new audioGetter. // newAudioGetter returns a new audioGetter.
func newAudioGetter(store Store, youtube YoutubeClient, s3API S3API, config config.Config, logger *zap.SugaredLogger) *audioGetter { func newAudioGetter(store Store, youtube YoutubeClient, fileStore FileStore, config config.Config, logger *zap.SugaredLogger) *audioGetter {
return &audioGetter{ return &audioGetter{
store: store, store: store,
youtube: youtube, youtube: youtube,
s3API: s3API, fileStore: fileStore,
config: config, config: config,
logger: logger, logger: logger,
} }
} }
// GetAudio gets the audio, processes it and uploads it to S3. It returns a // GetAudio gets the audio, processes it and uploads it to a file store. It
// GetAudioProgressReader that can be used to poll progress reports and audio // returns a GetAudioProgressReader that can be used to poll progress reports
// peaks. // and audio peaks.
// //
// TODO: accept domain object instead // TODO: accept domain object instead
func (g *audioGetter) GetAudio(ctx context.Context, mediaSet store.MediaSet, numBins int) (GetAudioProgressReader, error) { func (g *audioGetter) GetAudio(ctx context.Context, mediaSet store.MediaSet, numBins int) (GetAudioProgressReader, error) {
@ -114,33 +112,28 @@ func (s *audioGetterState) getAudio(ctx context.Context, r io.ReadCloser, mediaS
wg.Add(2) wg.Add(2)
// Upload the encoded audio. // Upload the encoded audio.
// TODO: fix error shadowing in these two goroutines.
go func() { go func() {
defer wg.Done() defer wg.Done()
// TODO: use mediaSet func to fetch s3Key // TODO: use mediaSet func to fetch key
s3Key := fmt.Sprintf("media_sets/%s/audio.opus", mediaSet.ID) key := fmt.Sprintf("media_sets/%s/audio.opus", mediaSet.ID)
uploader := newMultipartUploader(s.s3API, s.logger) _, encErr := s.fileStore.PutObject(ctx, key, pr, "audio/opus")
_, encErr := uploader.Upload(ctx, pr, s.config.S3Bucket, s3Key, "audio/opus")
if encErr != nil { if encErr != nil {
s.CloseWithError(fmt.Errorf("error uploading encoded audio: %v", encErr)) s.CloseWithError(fmt.Errorf("error uploading encoded audio: %v", encErr))
return return
} }
input := s3.GetObjectInput{ presignedAudioURL, err = s.fileStore.GetURL(ctx, key)
Bucket: aws.String(s.config.S3Bucket),
Key: aws.String(s3Key),
}
request, err := s.s3API.PresignGetObject(ctx, &input, s3.WithPresignExpires(getAudioExpiresIn))
if err != nil { if err != nil {
s.CloseWithError(fmt.Errorf("error generating presigned URL: %v", err)) s.CloseWithError(fmt.Errorf("error generating presigned URL: %v", err))
} }
presignedAudioURL = request.URL
if _, err = s.store.SetEncodedAudioUploaded(ctx, store.SetEncodedAudioUploadedParams{ if _, err = s.store.SetEncodedAudioUploaded(ctx, store.SetEncodedAudioUploadedParams{
ID: mediaSet.ID, ID: mediaSet.ID,
AudioEncodedS3Bucket: sqlString(s.config.S3Bucket), AudioEncodedS3Bucket: sqlString(s.config.S3Bucket),
AudioEncodedS3Key: sqlString(s3Key), AudioEncodedS3Key: sqlString(key),
}); err != nil { }); err != nil {
s.CloseWithError(fmt.Errorf("error setting encoded audio uploaded: %v", err)) s.CloseWithError(fmt.Errorf("error setting encoded audio uploaded: %v", err))
} }
@ -150,12 +143,11 @@ func (s *audioGetterState) getAudio(ctx context.Context, r io.ReadCloser, mediaS
go func() { go func() {
defer wg.Done() defer wg.Done()
// TODO: use mediaSet func to fetch s3Key // TODO: use mediaSet func to fetch key
s3Key := fmt.Sprintf("media_sets/%s/audio.raw", mediaSet.ID) key := fmt.Sprintf("media_sets/%s/audio.raw", mediaSet.ID)
teeReader := io.TeeReader(stdout, s) teeReader := io.TeeReader(stdout, s)
uploader := newMultipartUploader(s.s3API, s.logger) bytesUploaded, rawErr := s.fileStore.PutObject(ctx, key, teeReader, rawAudioMimeType)
bytesUploaded, rawErr := uploader.Upload(ctx, teeReader, s.config.S3Bucket, s3Key, rawAudioMimeType)
if rawErr != nil { if rawErr != nil {
s.CloseWithError(fmt.Errorf("error uploading raw audio: %v", rawErr)) s.CloseWithError(fmt.Errorf("error uploading raw audio: %v", rawErr))
return return
@ -164,7 +156,7 @@ func (s *audioGetterState) getAudio(ctx context.Context, r io.ReadCloser, mediaS
if _, err = s.store.SetRawAudioUploaded(ctx, store.SetRawAudioUploadedParams{ if _, err = s.store.SetRawAudioUploaded(ctx, store.SetRawAudioUploadedParams{
ID: mediaSet.ID, ID: mediaSet.ID,
AudioRawS3Bucket: sqlString(s.config.S3Bucket), AudioRawS3Bucket: sqlString(s.config.S3Bucket),
AudioRawS3Key: sqlString(s3Key), AudioRawS3Key: sqlString(key),
AudioFrames: sqlInt64(bytesUploaded / SizeOfInt16 / int64(mediaSet.AudioChannels)), AudioFrames: sqlInt64(bytesUploaded / SizeOfInt16 / int64(mediaSet.AudioChannels)),
}); err != nil { }); err != nil {
s.CloseWithError(fmt.Errorf("error setting raw audio uploaded: %v", err)) s.CloseWithError(fmt.Errorf("error setting raw audio uploaded: %v", err))

View File

@ -6,8 +6,6 @@ import (
"io" "io"
"git.netflux.io/rob/clipper/generated/store" "git.netflux.io/rob/clipper/generated/store"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/google/uuid" "github.com/google/uuid"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -24,9 +22,9 @@ type GetVideoProgressReader interface {
} }
type videoGetter struct { type videoGetter struct {
s3 S3API store Store
store Store fileStore FileStore
logger *zap.SugaredLogger logger *zap.SugaredLogger
} }
type videoGetterState struct { type videoGetterState struct {
@ -41,21 +39,20 @@ type videoGetterState struct {
errorChan chan error errorChan chan error
} }
func newVideoGetter(s3 S3API, store Store, logger *zap.SugaredLogger) *videoGetter { func newVideoGetter(store Store, fileStore FileStore, logger *zap.SugaredLogger) *videoGetter {
return &videoGetter{s3: s3, store: store, logger: logger} return &videoGetter{store: store, fileStore: fileStore, logger: logger}
} }
// GetVideo gets video from Youtube and uploads it to S3 using the specified // GetVideo gets video from Youtube and uploads it to a filestore using the
// bucket, key and content type. The returned reader must have its Next() // specified key and content type. The returned reader must have its Next()
// method called until error = io.EOF, otherwise a deadlock or other resource // method called until error = io.EOF, otherwise a deadlock or other resource
// leakage is likely. // leakage is likely.
func (g *videoGetter) GetVideo(ctx context.Context, r io.Reader, exp int64, mediaSetID uuid.UUID, bucket, key, contentType string) (GetVideoProgressReader, error) { func (g *videoGetter) GetVideo(ctx context.Context, r io.Reader, exp int64, mediaSetID uuid.UUID, key, contentType string) (GetVideoProgressReader, error) {
s := &videoGetterState{ s := &videoGetterState{
videoGetter: g, videoGetter: g,
r: newProgressReader(r, "video", exp, g.logger), r: newProgressReader(r, "video", exp, g.logger),
exp: exp, exp: exp,
mediaSetID: mediaSetID, mediaSetID: mediaSetID,
bucket: bucket,
key: key, key: key,
contentType: contentType, contentType: contentType,
progressChan: make(chan GetVideoProgress), progressChan: make(chan GetVideoProgress),
@ -69,7 +66,7 @@ func (g *videoGetter) GetVideo(ctx context.Context, r io.Reader, exp int64, medi
} }
// Write implements io.Writer. It is copied that same data that is written to // Write implements io.Writer. It is copied that same data that is written to
// S3, to implement progress tracking. // the file store, to implement progress tracking.
func (s *videoGetterState) Write(p []byte) (int, error) { func (s *videoGetterState) Write(p []byte) (int, error) {
s.count += int64(len(p)) s.count += int64(len(p))
pc := (float32(s.count) / float32(s.exp)) * 100 pc := (float32(s.count) / float32(s.exp)) * 100
@ -78,24 +75,18 @@ func (s *videoGetterState) Write(p []byte) (int, error) {
} }
func (s *videoGetterState) getVideo(ctx context.Context) { func (s *videoGetterState) getVideo(ctx context.Context) {
uploader := newMultipartUploader(s.s3, s.logger)
teeReader := io.TeeReader(s.r, s) teeReader := io.TeeReader(s.r, s)
_, err := uploader.Upload(ctx, teeReader, s.bucket, s.key, s.contentType) _, err := s.fileStore.PutObject(ctx, s.key, teeReader, s.contentType)
if err != nil { if err != nil {
s.errorChan <- fmt.Errorf("error uploading to S3: %v", err) s.errorChan <- fmt.Errorf("error uploading to file store: %v", err)
return return
} }
input := s3.GetObjectInput{ s.url, err = s.fileStore.GetURL(ctx, s.key)
Bucket: aws.String(s.bucket),
Key: aws.String(s.key),
}
request, err := s.s3.PresignGetObject(ctx, &input, s3.WithPresignExpires(getVideoExpiresIn))
if err != nil { if err != nil {
s.errorChan <- fmt.Errorf("error generating presigned URL: %v", err) s.errorChan <- fmt.Errorf("error getting object URL: %v", err)
} }
s.url = request.URL
storeParams := store.SetVideoUploadedParams{ storeParams := store.SetVideoUploadedParams{
ID: s.mediaSetID, ID: s.mediaSetID,

View File

@ -14,8 +14,6 @@ import (
"git.netflux.io/rob/clipper/config" "git.netflux.io/rob/clipper/config"
"git.netflux.io/rob/clipper/generated/store" "git.netflux.io/rob/clipper/generated/store"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
youtubev2 "github.com/kkdai/youtube/v2" youtubev2 "github.com/kkdai/youtube/v2"
@ -41,27 +39,27 @@ const (
// MediaSetService exposes logical flows handling MediaSets. // MediaSetService exposes logical flows handling MediaSets.
type MediaSetService struct { type MediaSetService struct {
store Store store Store
youtube YoutubeClient youtube YoutubeClient
s3 S3API fileStore FileStore
config config.Config config config.Config
logger *zap.SugaredLogger logger *zap.SugaredLogger
} }
func NewMediaSetService(store Store, youtubeClient YoutubeClient, s3API S3API, config config.Config, logger *zap.Logger) *MediaSetService { func NewMediaSetService(store Store, youtubeClient YoutubeClient, fileStore FileStore, config config.Config, logger *zap.SugaredLogger) *MediaSetService {
return &MediaSetService{ return &MediaSetService{
store: store, store: store,
youtube: youtubeClient, youtube: youtubeClient,
s3: s3API, fileStore: fileStore,
config: config, config: config,
logger: logger.Sugar(), logger: logger,
} }
} }
// Get fetches the metadata for a given MediaSet source. If it does not exist // Get fetches the metadata for a given MediaSet source. If it does not exist
// in the local DB, it will attempt to create it. After the resource has been // in the local DB, it will attempt to create it. After the resource has been
// created, other endpoints (e.g. GetAudio) can be called to fetch media from // created, other endpoints (e.g. GetAudio) can be called to fetch media from
// Youtube and store it in S3. // Youtube and store it in a file store.
func (s *MediaSetService) Get(ctx context.Context, youtubeID string) (*MediaSet, error) { func (s *MediaSetService) Get(ctx context.Context, youtubeID string) (*MediaSet, error) {
var ( var (
mediaSet *MediaSet mediaSet *MediaSet
@ -220,15 +218,11 @@ func (s *MediaSetService) GetVideo(ctx context.Context, id uuid.UUID) (GetVideoP
} }
if mediaSet.VideoS3UploadedAt.Valid { if mediaSet.VideoS3UploadedAt.Valid {
input := s3.GetObjectInput{ url, err := s.fileStore.GetURL(ctx, mediaSet.VideoS3Key.String)
Bucket: aws.String(s.config.S3Bucket), if err != nil {
Key: aws.String(mediaSet.VideoS3Key.String), return nil, fmt.Errorf("error generating presigned URL: %v", err)
} }
request, signErr := s.s3.PresignGetObject(ctx, &input, s3.WithPresignExpires(getVideoExpiresIn)) videoGetter := videoGetterDownloaded(url)
if signErr != nil {
return nil, fmt.Errorf("error generating presigned URL: %v", signErr)
}
videoGetter := videoGetterDownloaded(request.URL)
return &videoGetter, nil return &videoGetter, nil
} }
@ -247,17 +241,16 @@ func (s *MediaSetService) GetVideo(ctx context.Context, id uuid.UUID) (GetVideoP
return nil, fmt.Errorf("error fetching stream: %v", err) return nil, fmt.Errorf("error fetching stream: %v", err)
} }
// TODO: use mediaSet func to fetch s3Key // TODO: use mediaSet func to fetch videoKey
s3Key := fmt.Sprintf("media_sets/%s/video.mp4", mediaSet.ID) videoKey := fmt.Sprintf("media_sets/%s/video.mp4", mediaSet.ID)
videoGetter := newVideoGetter(s.s3, s.store, s.logger) videoGetter := newVideoGetter(s.store, s.fileStore, s.logger)
return videoGetter.GetVideo( return videoGetter.GetVideo(
ctx, ctx,
stream, stream,
format.ContentLength, format.ContentLength,
mediaSet.ID, mediaSet.ID,
s.config.S3Bucket, videoKey,
s3Key,
format.MimeType, format.MimeType,
) )
} }
@ -273,25 +266,21 @@ func (s *MediaSetService) GetAudio(ctx context.Context, id uuid.UUID, numBins in
// Otherwise, we cannot return both peaks and a presigned URL for use by the // Otherwise, we cannot return both peaks and a presigned URL for use by the
// player. // player.
if mediaSet.AudioRawS3UploadedAt.Valid && mediaSet.AudioEncodedS3UploadedAt.Valid { if mediaSet.AudioRawS3UploadedAt.Valid && mediaSet.AudioEncodedS3UploadedAt.Valid {
return s.getAudioFromS3(ctx, mediaSet, numBins) return s.getAudioFromFileStore(ctx, mediaSet, numBins)
} }
return s.getAudioFromYoutube(ctx, mediaSet, numBins) return s.getAudioFromYoutube(ctx, mediaSet, numBins)
} }
func (s *MediaSetService) getAudioFromYoutube(ctx context.Context, mediaSet store.MediaSet, numBins int) (GetAudioProgressReader, error) { func (s *MediaSetService) getAudioFromYoutube(ctx context.Context, mediaSet store.MediaSet, numBins int) (GetAudioProgressReader, error) {
audioGetter := newAudioGetter(s.store, s.youtube, s.s3, s.config, s.logger) audioGetter := newAudioGetter(s.store, s.youtube, s.fileStore, s.config, s.logger)
return audioGetter.GetAudio(ctx, mediaSet, numBins) return audioGetter.GetAudio(ctx, mediaSet, numBins)
} }
func (s *MediaSetService) getAudioFromS3(ctx context.Context, mediaSet store.MediaSet, numBins int) (GetAudioProgressReader, error) { func (s *MediaSetService) getAudioFromFileStore(ctx context.Context, mediaSet store.MediaSet, numBins int) (GetAudioProgressReader, error) {
input := s3.GetObjectInput{ object, err := s.fileStore.GetObject(ctx, mediaSet.AudioRawS3Key.String)
Bucket: aws.String(mediaSet.AudioRawS3Bucket.String),
Key: aws.String(mediaSet.AudioRawS3Key.String),
}
output, err := s.s3.GetObject(ctx, &input)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting object from s3: %v", err) return nil, fmt.Errorf("error getting object from file store: %v", err)
} }
getAudioProgressReader, err := newGetAudioProgressReader( getAudioProgressReader, err := newGetAudioProgressReader(
@ -303,10 +292,10 @@ func (s *MediaSetService) getAudioFromS3(ctx context.Context, mediaSet store.Med
return nil, fmt.Errorf("error creating audio reader: %v", err) return nil, fmt.Errorf("error creating audio reader: %v", err)
} }
state := getAudioFromS3State{ state := getAudioFromFileStoreState{
getAudioProgressReader: getAudioProgressReader, getAudioProgressReader: getAudioProgressReader,
s3Reader: NewModuloBufReader(output.Body, int(mediaSet.AudioChannels)*SizeOfInt16), reader: NewModuloBufReader(object, int(mediaSet.AudioChannels)*SizeOfInt16),
s3API: s.s3, fileStore: s.fileStore,
config: s.config, config: s.config,
logger: s.logger, logger: s.logger,
} }
@ -315,21 +304,21 @@ func (s *MediaSetService) getAudioFromS3(ctx context.Context, mediaSet store.Med
return &state, nil return &state, nil
} }
type getAudioFromS3State struct { type getAudioFromFileStoreState struct {
*getAudioProgressReader *getAudioProgressReader
s3Reader io.ReadCloser reader io.ReadCloser
s3API S3API fileStore FileStore
config config.Config config config.Config
logger *zap.SugaredLogger logger *zap.SugaredLogger
} }
func (s *getAudioFromS3State) run(ctx context.Context, mediaSet store.MediaSet) { func (s *getAudioFromFileStoreState) run(ctx context.Context, mediaSet store.MediaSet) {
done := make(chan error) done := make(chan error)
var err error var err error
go func() { go func() {
_, copyErr := io.Copy(s, s.s3Reader) _, copyErr := io.Copy(s, s.reader)
done <- copyErr done <- copyErr
}() }()
@ -344,29 +333,25 @@ outer:
} }
} }
if readerErr := s.s3Reader.Close(); readerErr != nil { if readerErr := s.reader.Close(); readerErr != nil {
if err == nil { if err == nil {
err = readerErr err = readerErr
} }
} }
if err != nil { if err != nil {
s.logger.Errorf("getAudioFromS3State: error closing s3 reader: %v", err) s.logger.Errorf("getAudioFromFileStoreState: error closing reader: %v", err)
s.CloseWithError(err) s.CloseWithError(err)
return return
} }
input := s3.GetObjectInput{ url, err := s.fileStore.GetURL(ctx, mediaSet.AudioEncodedS3Key.String)
Bucket: aws.String(s.config.S3Bucket),
Key: aws.String(mediaSet.AudioEncodedS3Key.String),
}
request, err := s.s3API.PresignGetObject(ctx, &input, s3.WithPresignExpires(getAudioExpiresIn))
if err != nil { if err != nil {
s.CloseWithError(fmt.Errorf("error generating presigned URL: %v", err)) s.CloseWithError(fmt.Errorf("error generating object URL: %v", err))
} }
if iterErr := s.Close(request.URL); iterErr != nil { if iterErr := s.Close(url); iterErr != nil {
s.logger.Errorf("getAudioFromS3State: error closing progress iterator: %v", iterErr) s.logger.Errorf("getAudioFromFileStoreState: error closing progress iterator: %v", iterErr)
} }
} }
@ -381,25 +366,20 @@ func (s *MediaSetService) GetAudioSegment(ctx context.Context, id uuid.UUID, sta
return nil, fmt.Errorf("error getting media set: %v", err) return nil, fmt.Errorf("error getting media set: %v", err)
} }
byteRange := fmt.Sprintf( object, err := s.fileStore.GetObjectWithRange(
"bytes=%d-%d", ctx,
mediaSet.AudioRawS3Key.String,
startFrame*int64(mediaSet.AudioChannels)*SizeOfInt16, startFrame*int64(mediaSet.AudioChannels)*SizeOfInt16,
endFrame*int64(mediaSet.AudioChannels)*SizeOfInt16, endFrame*int64(mediaSet.AudioChannels)*SizeOfInt16,
) )
input := s3.GetObjectInput{
Bucket: aws.String(mediaSet.AudioRawS3Bucket.String),
Key: aws.String(mediaSet.AudioRawS3Key.String),
Range: aws.String(byteRange),
}
output, err := s.s3.GetObject(ctx, &input)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting object from s3: %v", err) return nil, fmt.Errorf("error getting object from file store: %v", err)
} }
defer output.Body.Close() defer object.Close()
const readBufSizeBytes = 8_192 const readBufSizeBytes = 8_192
channels := int(mediaSet.AudioChannels) channels := int(mediaSet.AudioChannels)
modReader := NewModuloBufReader(output.Body, channels*SizeOfInt16) modReader := NewModuloBufReader(object, channels*SizeOfInt16)
readBuf := make([]byte, readBufSizeBytes) readBuf := make([]byte, readBufSizeBytes)
peaks := make([]int16, channels*numBins) peaks := make([]int16, channels*numBins)
totalFrames := endFrame - startFrame totalFrames := endFrame - startFrame
@ -454,7 +434,7 @@ func (s *MediaSetService) GetAudioSegment(ctx context.Context, id uuid.UUID, sta
} }
if bytesRead < bytesExpected { if bytesRead < bytesExpected {
s.logger.With("startFrame", startFrame, "endFrame", endFrame, "got", bytesRead, "want", bytesExpected, "key", mediaSet.AudioRawS3Key.String).Warn("short read from S3") s.logger.With("startFrame", startFrame, "endFrame", endFrame, "got", bytesRead, "want", bytesExpected, "key", mediaSet.AudioRawS3Key.String).Warn("short read from file store")
} }
return peaks, nil return peaks, nil
@ -521,26 +501,22 @@ func (s *MediaSetService) GetVideoThumbnail(ctx context.Context, id uuid.UUID) (
} }
if mediaSet.VideoThumbnailS3UploadedAt.Valid { if mediaSet.VideoThumbnailS3UploadedAt.Valid {
return s.getThumbnailFromS3(ctx, mediaSet) return s.getThumbnailFromFileStore(ctx, mediaSet)
} }
return s.getThumbnailFromYoutube(ctx, mediaSet) return s.getThumbnailFromYoutube(ctx, mediaSet)
} }
func (s *MediaSetService) getThumbnailFromS3(ctx context.Context, mediaSet store.MediaSet) (VideoThumbnail, error) { func (s *MediaSetService) getThumbnailFromFileStore(ctx context.Context, mediaSet store.MediaSet) (VideoThumbnail, error) {
input := s3.GetObjectInput{ object, err := s.fileStore.GetObject(ctx, mediaSet.VideoThumbnailS3Key.String)
Bucket: aws.String(mediaSet.VideoThumbnailS3Bucket.String),
Key: aws.String(mediaSet.VideoThumbnailS3Key.String),
}
output, err := s.s3.GetObject(ctx, &input)
if err != nil { if err != nil {
return VideoThumbnail{}, fmt.Errorf("error fetching thumbnail from s3: %v", err) return VideoThumbnail{}, fmt.Errorf("error fetching thumbnail from file store: %v", err)
} }
defer output.Body.Close() defer object.Close()
imageData, err := io.ReadAll(output.Body) imageData, err := io.ReadAll(object)
if err != nil { if err != nil {
return VideoThumbnail{}, fmt.Errorf("error reading thumbnail from s3: %v", err) return VideoThumbnail{}, fmt.Errorf("error reading thumbnail from file store: %v", err)
} }
return VideoThumbnail{ return VideoThumbnail{
@ -575,13 +551,11 @@ func (s *MediaSetService) getThumbnailFromYoutube(ctx context.Context, mediaSet
return VideoThumbnail{}, fmt.Errorf("error reading thumbnail: %v", err) return VideoThumbnail{}, fmt.Errorf("error reading thumbnail: %v", err)
} }
// TODO: use mediaSet func to fetch s3Key // TODO: use mediaSet func to fetch key
s3Key := fmt.Sprintf("media_sets/%s/thumbnail.jpg", mediaSet.ID) thumbnailKey := fmt.Sprintf("media_sets/%s/thumbnail.jpg", mediaSet.ID)
uploader := newMultipartUploader(s.s3, s.logger)
const mimeType = "application/jpeg" const mimeType = "application/jpeg"
_, err = s.fileStore.PutObject(ctx, thumbnailKey, bytes.NewReader(imageData), mimeType)
_, err = uploader.Upload(ctx, bytes.NewReader(imageData), s.config.S3Bucket, s3Key, mimeType)
if err != nil { if err != nil {
return VideoThumbnail{}, fmt.Errorf("error uploading thumbnail: %v", err) return VideoThumbnail{}, fmt.Errorf("error uploading thumbnail: %v", err)
} }
@ -590,7 +564,7 @@ func (s *MediaSetService) getThumbnailFromYoutube(ctx context.Context, mediaSet
ID: mediaSet.ID, ID: mediaSet.ID,
VideoThumbnailMimeType: sqlString(mimeType), VideoThumbnailMimeType: sqlString(mimeType),
VideoThumbnailS3Bucket: sqlString(s.config.S3Bucket), VideoThumbnailS3Bucket: sqlString(s.config.S3Bucket),
VideoThumbnailS3Key: sqlString(s3Key), VideoThumbnailS3Key: sqlString(thumbnailKey),
VideoThumbnailWidth: sqlInt32(int32(thumbnail.Width)), VideoThumbnailWidth: sqlInt32(int32(thumbnail.Width)),
VideoThumbnailHeight: sqlInt32(int32(thumbnail.Height)), VideoThumbnailHeight: sqlInt32(int32(thumbnail.Height)),
} }
@ -614,7 +588,7 @@ func newProgressReader(reader io.Reader, label string, exp int64, logger *zap.Su
return &progressReader{ return &progressReader{
Reader: reader, Reader: reader,
exp: exp, exp: exp,
logger: logger.Named(fmt.Sprintf("ProgressReader %s", label)), logger: logger.Named(label),
} }
} }

View File

@ -1,8 +1,9 @@
package media_test package media_test
import ( import (
"bytes"
"context" "context"
"fmt" "database/sql"
"io" "io"
"os" "os"
"testing" "testing"
@ -11,7 +12,6 @@ import (
"git.netflux.io/rob/clipper/generated/mocks" "git.netflux.io/rob/clipper/generated/mocks"
"git.netflux.io/rob/clipper/generated/store" "git.netflux.io/rob/clipper/generated/store"
"git.netflux.io/rob/clipper/media" "git.netflux.io/rob/clipper/media"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
@ -84,31 +84,34 @@ func TestGetAudioSegment(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
expectedBytes := (tc.endFrame - tc.startFrame) * int64(tc.channels) * media.SizeOfInt16 startByte := tc.startFrame * int64(tc.channels) * media.SizeOfInt16
endByte := tc.endFrame * int64(tc.channels) * media.SizeOfInt16
expectedBytes := endByte - startByte
audioFile, err := os.Open(tc.fixturePath) audioFile, err := os.Open(tc.fixturePath)
require.NoError(t, err) require.NoError(t, err)
defer audioFile.Close() defer audioFile.Close()
audioData := io.NopCloser(io.LimitReader(audioFile, int64(expectedBytes))) audioData := io.NopCloser(io.LimitReader(audioFile, int64(expectedBytes)))
mediaSetID := uuid.New() mediaSet := store.MediaSet{
mediaSet := store.MediaSet{ID: mediaSetID, AudioChannels: tc.channels} ID: uuid.New(),
AudioChannels: tc.channels,
AudioRawS3Key: sql.NullString{String: "foo", Valid: true},
}
// store is passed the mediaSetID and returns a mediaSet // store is passed the mediaSetID and returns a mediaSet
store := &mocks.Store{} store := &mocks.Store{}
store.On("GetMediaSet", mock.Anything, mediaSetID).Return(mediaSet, nil) store.On("GetMediaSet", mock.Anything, mediaSet.ID).Return(mediaSet, nil)
defer store.AssertExpectations(t) defer store.AssertExpectations(t)
// S3 is passed the expected byte range, and returns an io.Reader // fileStore is passed the expected byte range, and returns an io.Reader
s3Client := &mocks.S3Client{} fileStore := &mocks.FileStore{}
s3Client.On("GetObject", mock.Anything, mock.MatchedBy(func(input *s3.GetObjectInput) bool { fileStore.
return *input.Range == fmt.Sprintf("bytes=0-%d", expectedBytes) On("GetObjectWithRange", mock.Anything, "foo", startByte, endByte).
})).Return(&s3.GetObjectOutput{Body: audioData, ContentLength: tc.fixtureLen}, nil) Return(audioData, nil)
defer s3Client.AssertExpectations(t)
s3API := media.S3API{S3Client: s3Client, S3PresignClient: &mocks.S3PresignClient{}}
service := media.NewMediaSetService(store, nil, s3API, config.Config{}, zap.NewNop()) service := media.NewMediaSetService(store, nil, fileStore, config.Config{}, zap.NewNop().Sugar())
peaks, err := service.GetAudioSegment(context.Background(), mediaSetID, tc.startFrame, tc.endFrame, tc.numBins) peaks, err := service.GetAudioSegment(context.Background(), mediaSet.ID, tc.startFrame, tc.endFrame, tc.numBins)
if tc.wantErr == "" { if tc.wantErr == "" {
assert.NoError(t, err) assert.NoError(t, err)
@ -130,35 +133,31 @@ func BenchmarkGetAudioSegment(b *testing.B) {
numBins = 2000 numBins = 2000
) )
audioFile, err := os.Open(fixturePath)
require.NoError(b, err)
audioData, err := io.ReadAll(audioFile)
require.NoError(b, err)
mediaSetID := uuid.New()
mediaSet := store.MediaSet{ID: mediaSetID, AudioChannels: channels}
store := &mocks.Store{}
store.On("GetMediaSet", mock.Anything, mediaSetID).Return(mediaSet, nil)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
// recreate the reader on each iteration
b.StopTimer() b.StopTimer()
readCloser := io.NopCloser(bytes.NewReader(audioData))
fileStore := &mocks.FileStore{}
fileStore.
On("GetObjectWithRange", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(readCloser, nil)
expectedBytes := (endFrame - startFrame) * int64(channels) * media.SizeOfInt16 service := media.NewMediaSetService(store, nil, fileStore, config.Config{}, zap.NewNop().Sugar())
audioFile, err := os.Open(fixturePath)
require.NoError(b, err)
audioData := io.NopCloser(io.LimitReader(audioFile, int64(expectedBytes)))
mediaSetID := uuid.New()
mediaSet := store.MediaSet{ID: mediaSetID, AudioChannels: channels}
store := &mocks.Store{}
store.On("GetMediaSet", mock.Anything, mediaSetID).Return(mediaSet, nil)
s3Client := &mocks.S3Client{}
s3Client.
On("GetObject", mock.Anything, mock.Anything).
Return(&s3.GetObjectOutput{Body: audioData, ContentLength: fixtureLen}, nil)
s3API := media.S3API{S3Client: s3Client, S3PresignClient: &mocks.S3PresignClient{}}
service := media.NewMediaSetService(store, nil, s3API, config.Config{}, zap.NewNop())
b.StartTimer() b.StartTimer()
_, err = service.GetAudioSegment(context.Background(), mediaSetID, startFrame, endFrame, numBins)
b.StopTimer()
_, err = service.GetAudioSegment(context.Background(), mediaSetID, startFrame, endFrame, numBins)
require.NoError(b, err) require.NoError(b, err)
audioFile.Close()
} }
} }

View File

@ -1,8 +1,8 @@
package media package media
//go:generate mockery --recursive --name S3* --output ../generated/mocks
//go:generate mockery --recursive --name Store --output ../generated/mocks //go:generate mockery --recursive --name Store --output ../generated/mocks
//go:generate mockery --recursive --name YoutubeClient --output ../generated/mocks //go:generate mockery --recursive --name YoutubeClient --output ../generated/mocks
//go:generate mockery --recursive --name FileStore --output ../generated/mocks
import ( import (
"context" "context"
@ -10,8 +10,6 @@ import (
"time" "time"
"git.netflux.io/rob/clipper/generated/store" "git.netflux.io/rob/clipper/generated/store"
signerv4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/google/uuid" "github.com/google/uuid"
youtubev2 "github.com/kkdai/youtube/v2" youtubev2 "github.com/kkdai/youtube/v2"
) )
@ -63,28 +61,16 @@ type Store interface {
SetVideoThumbnailUploaded(context.Context, store.SetVideoThumbnailUploadedParams) (store.MediaSet, error) SetVideoThumbnailUploaded(context.Context, store.SetVideoThumbnailUploadedParams) (store.MediaSet, error)
} }
// S3API provides an API to AWS S3. // FileStore wraps a file store.
type S3API struct { type FileStore interface {
S3Client GetObject(ctx context.Context, key string) (io.ReadCloser, error)
S3PresignClient GetObjectWithRange(ctx context.Context, key string, startFrame, endFrame int64) (io.ReadCloser, error)
} GetURL(ctx context.Context, key string) (string, error)
PutObject(ctx context.Context, key string, reader io.Reader, contentType string) (int64, error)
// S3Client wraps the AWS S3 service client.
type S3Client interface {
GetObject(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)
CreateMultipartUpload(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error)
UploadPart(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) (*s3.UploadPartOutput, error)
AbortMultipartUpload(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error)
CompleteMultipartUpload(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error)
}
// S3PresignClient wraps the AWS S3 Presign client.
type S3PresignClient interface {
PresignGetObject(context.Context, *s3.GetObjectInput, ...func(*s3.PresignOptions)) (*signerv4.PresignedHTTPRequest, error)
} }
// YoutubeClient wraps the youtube.Client client. // YoutubeClient wraps the youtube.Client client.
type YoutubeClient interface { type YoutubeClient interface {
GetVideoContext(context.Context, string) (*youtubev2.Video, error) GetVideoContext(ctx context.Context, id string) (*youtubev2.Video, error)
GetStreamContext(context.Context, *youtubev2.Video, *youtubev2.Format) (io.ReadCloser, int64, error) GetStreamContext(ctx context.Context, video *youtubev2.Video, format *youtubev2.Format) (io.ReadCloser, int64, error)
} }

View File

@ -1,212 +0,0 @@
package media
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"sort"
"sync"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"go.uber.org/zap"
)
// multipartUploader uploads a file to S3.
//
// TODO: extract to s3 package
type multipartUploader struct {
s3 S3Client
logger *zap.SugaredLogger
}
type uploadResult struct {
completedPart types.CompletedPart
size int64
}
const (
targetPartSizeBytes = 5 * 1024 * 1024 // 5MB
readBufferSizeBytes = 32_768 // 32Kb
)
func newMultipartUploader(s3Client S3Client, logger *zap.SugaredLogger) *multipartUploader {
return &multipartUploader{s3: s3Client, logger: logger}
}
// Upload uploads to an S3 bucket in 5MB parts. It buffers data internally
// until a part is ready to send over the network. Parts are sent as soon as
// they exceed the minimum part size of 5MB.
//
// TODO: expire after configurable period.
func (u *multipartUploader) Upload(ctx context.Context, r io.Reader, bucket, key, contentType string) (int64, error) {
var uploaded bool
input := s3.CreateMultipartUploadInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
ContentType: aws.String(contentType),
}
output, err := u.s3.CreateMultipartUpload(ctx, &input)
if err != nil {
return 0, fmt.Errorf("error creating multipart upload: %v", err)
}
// abort the upload if possible, logging any errors, on exit.
defer func() {
if uploaded {
return
}
input := s3.AbortMultipartUploadInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
UploadId: output.UploadId,
}
// if the context was cancelled, just use the background context.
ctxToUse := ctx
if ctxToUse.Err() != nil {
ctxToUse = context.Background()
}
_, abortErr := u.s3.AbortMultipartUpload(ctxToUse, &input)
if abortErr != nil {
u.logger.Errorf("uploader: error aborting upload: %v", abortErr)
} else {
u.logger.Infof("aborted upload, key = %s", key)
}
}()
uploadResultChan := make(chan uploadResult)
uploadErrorChan := make(chan error, 1)
// uploadPart uploads an individual part.
uploadPart := func(wg *sync.WaitGroup, buf []byte, partNum int32) {
defer wg.Done()
partLen := int64(len(buf))
u.logger.With("key", key, "partNum", partNum, "partLen", partLen).Debug("uploading part")
input := s3.UploadPartInput{
Body: bytes.NewReader(buf),
Bucket: aws.String(bucket),
Key: aws.String(key),
PartNumber: partNum,
UploadId: output.UploadId,
ContentLength: partLen,
}
output, uploadErr := u.s3.UploadPart(ctx, &input)
if uploadErr != nil {
// TODO: retry on failure
uploadErrorChan <- uploadErr
return
}
u.logger.With("key", key, "partNum", partNum, "partLen", partLen, "etag", *output.ETag).Debug("uploaded part")
uploadResultChan <- uploadResult{
completedPart: types.CompletedPart{ETag: output.ETag, PartNumber: partNum},
size: partLen,
}
}
wgDone := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1) // done when the reader goroutine returns
go func() {
wg.Wait()
wgDone <- struct{}{}
}()
readChan := make(chan error, 1)
go func() {
defer wg.Done()
var closing bool
currPart := bytes.NewBuffer(make([]byte, 0, targetPartSizeBytes+readBufferSizeBytes))
partNum := int32(1)
buf := make([]byte, readBufferSizeBytes)
for {
n, readErr := r.Read(buf)
if readErr == io.EOF {
closing = true
} else if readErr != nil {
readChan <- readErr
return
}
_, _ = currPart.Write(buf[:n])
if closing || currPart.Len() >= targetPartSizeBytes {
part := make([]byte, currPart.Len())
copy(part, currPart.Bytes())
currPart.Truncate(0)
wg.Add(1)
go uploadPart(&wg, part, partNum)
partNum++
}
if closing {
return
}
}
}()
results := make([]uploadResult, 0, 64)
outer:
for {
select {
case readErr := <-readChan:
if readErr != io.EOF {
return 0, fmt.Errorf("reader error: %v", readErr)
}
case uploadResult := <-uploadResultChan:
results = append(results, uploadResult)
case uploadErr := <-uploadErrorChan:
return 0, fmt.Errorf("error while uploading part: %v", uploadErr)
case <-ctx.Done():
return 0, ctx.Err()
case <-wgDone:
break outer
}
}
if len(results) == 0 {
return 0, errors.New("no parts available to upload")
}
completedParts := make([]types.CompletedPart, 0, 64)
var uploadedBytes int64
for _, result := range results {
completedParts = append(completedParts, result.completedPart)
uploadedBytes += result.size
}
// the parts may be out of order, especially with slow network conditions:
sort.Slice(completedParts, func(i, j int) bool {
return completedParts[i].PartNumber < completedParts[j].PartNumber
})
completeInput := s3.CompleteMultipartUploadInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
UploadId: output.UploadId,
MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts},
}
if _, err = u.s3.CompleteMultipartUpload(ctx, &completeInput); err != nil {
return 0, fmt.Errorf("error completing upload: %v", err)
}
u.logger.With("key", key, "numParts", len(completedParts), "len", uploadedBytes).Debug("completed upload")
uploaded = true
return uploadedBytes, nil
}

View File

@ -64,7 +64,8 @@ type Options struct {
Timeout time.Duration Timeout time.Duration
Store media.Store Store media.Store
YoutubeClient media.YoutubeClient YoutubeClient media.YoutubeClient
S3API media.S3API FileStore media.FileStore
Logger *zap.Logger
} }
// mediaSetServiceController implements gRPC controller for MediaSetService // mediaSetServiceController implements gRPC controller for MediaSetService
@ -226,32 +227,26 @@ func (c *mediaSetServiceController) GetVideoThumbnail(ctx context.Context, reque
} }
func Start(options Options) error { func Start(options Options) error {
logger, err := buildLogger(options.Config)
if err != nil {
return fmt.Errorf("error building logger: %v", err)
}
defer logger.Sync()
fetchMediaSetService := media.NewMediaSetService( fetchMediaSetService := media.NewMediaSetService(
options.Store, options.Store,
options.YoutubeClient, options.YoutubeClient,
options.S3API, options.FileStore,
options.Config, options.Config,
logger, options.Logger.Sugar().Named("mediaSetService"),
) )
grpcServer, err := buildGRPCServer(options.Config, logger) grpcServer, err := buildGRPCServer(options.Config, options.Logger)
if err != nil { if err != nil {
return fmt.Errorf("error building server: %v", err) return fmt.Errorf("error building server: %v", err)
} }
mediaSetController := &mediaSetServiceController{mediaSetService: fetchMediaSetService, logger: logger.Sugar().Named("controller")} mediaSetController := &mediaSetServiceController{mediaSetService: fetchMediaSetService, logger: options.Logger.Sugar().Named("controller")}
pbmediaset.RegisterMediaSetServiceServer(grpcServer, mediaSetController) pbmediaset.RegisterMediaSetServiceServer(grpcServer, mediaSetController)
// TODO: configure CORS // TODO: configure CORS
grpcWebServer := grpcweb.WrapServer(grpcServer, grpcweb.WithOriginFunc(func(string) bool { return true })) grpcWebServer := grpcweb.WrapServer(grpcServer, grpcweb.WithOriginFunc(func(string) bool { return true }))
log := logger.Sugar() log := options.Logger.Sugar()
fileHandler := http.NotFoundHandler() fileHandler := http.NotFoundHandler()
if options.Config.AssetsHTTPBasePath != "" { if options.Config.AssetsHTTPBasePath != "" {
log.With("basePath", options.Config.AssetsHTTPBasePath).Info("Configured to serve assets over HTTP") log.With("basePath", options.Config.AssetsHTTPBasePath).Info("Configured to serve assets over HTTP")
@ -280,13 +275,6 @@ func Start(options Options) error {
return httpServer.ListenAndServe() return httpServer.ListenAndServe()
} }
func buildLogger(c config.Config) (*zap.Logger, error) {
if c.Environment == config.EnvProduction {
return zap.NewProduction()
}
return zap.NewDevelopment()
}
func buildGRPCServer(c config.Config, logger *zap.Logger) (*grpc.Server, error) { func buildGRPCServer(c config.Config, logger *zap.Logger) (*grpc.Server, error) {
unaryInterceptors := []grpc.UnaryServerInterceptor{ unaryInterceptors := []grpc.UnaryServerInterceptor{
grpczap.UnaryServerInterceptor(logger), grpczap.UnaryServerInterceptor(logger),