diff --git a/backend/filestore/fs.go b/backend/filestore/fs.go index 337988f..63b38b5 100644 --- a/backend/filestore/fs.go +++ b/backend/filestore/fs.go @@ -74,8 +74,9 @@ func (s *FileSystemStore) GetURL(ctx context.Context, key string) (string, error return url.String(), nil } -// PutObject writes an object to the local filesystem. -func (s *FileSystemStore) PutObject(ctx context.Context, key string, r io.Reader, _ string) (int64, error) { +// PutObject writes an object to the local filesystem. It will close r after +// consuming it. +func (s *FileSystemStore) PutObject(ctx context.Context, key string, r io.ReadCloser, _ string) (int64, error) { path := filepath.Join(s.rootPath, key) if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil { return 0, fmt.Errorf("error creating directories: %v", err) @@ -89,5 +90,10 @@ func (s *FileSystemStore) PutObject(ctx context.Context, key string, r io.Reader if err != nil { return n, fmt.Errorf("error writing file: %v", err) } + + if err := r.Close(); err != nil { + return n, fmt.Errorf("error closing reader: %v", err) + } + return n, nil } diff --git a/backend/filestore/fs_test.go b/backend/filestore/fs_test.go index 14f248c..f73ed2d 100644 --- a/backend/filestore/fs_test.go +++ b/backend/filestore/fs_test.go @@ -2,6 +2,7 @@ package filestore_test import ( "context" + "io" "io/ioutil" "os" "path" @@ -152,7 +153,7 @@ func TestFileStorePutObject(t *testing.T) { store, err := filestore.NewFileSystemStore(rootPath, "/") require.NoError(t, err) - n, err := store.PutObject(context.Background(), tc.key, strings.NewReader(tc.content), "text/plain") + n, err := store.PutObject(context.Background(), tc.key, io.NopCloser(strings.NewReader(tc.content)), "text/plain") require.NoError(t, err) content, err := os.ReadFile(path.Join(rootPath, tc.key)) diff --git a/backend/filestore/s3.go b/backend/filestore/s3.go index bac5ddd..bca3772 100644 --- a/backend/filestore/s3.go +++ b/backend/filestore/s3.go @@ -104,7 +104,7 @@ func (s *S3FileStore) GetURL(ctx context.Context, key string) (string, error) { // PutObject uploads an object using multipart upload, returning the number of // bytes uploaded and any error. -func (s *S3FileStore) PutObject(ctx context.Context, key string, r io.Reader, contentType string) (int64, error) { +func (s *S3FileStore) PutObject(ctx context.Context, key string, r io.ReadCloser, contentType string) (int64, error) { const ( targetPartSizeBytes = 5 * 1024 * 1024 // 5MB readBufferSizeBytes = 32_768 // 32Kb @@ -249,6 +249,10 @@ outer: } } + if err = r.Close(); err != nil { + return 0, fmt.Errorf("error closing reader: %v", err) + } + if len(results) == 0 { return 0, errors.New("no parts available to upload") } diff --git a/backend/filestore/s3_test.go b/backend/filestore/s3_test.go index 3434c87..b2ab3e5 100644 --- a/backend/filestore/s3_test.go +++ b/backend/filestore/s3_test.go @@ -90,6 +90,7 @@ func TestS3GetURL(t *testing.T) { type testReader struct { count, exp int + closed bool } func (r *testReader) Read(p []byte) (int, error) { @@ -102,6 +103,11 @@ func (r *testReader) Read(p []byte) (int, error) { return len(p), nil } +func (r *testReader) Close() error { + r.closed = true + return nil +} + func TestS3PutObject(t *testing.T) { const ( bucket = "some-bucket" @@ -135,10 +141,12 @@ func TestS3PutObject(t *testing.T) { store := filestore.NewS3FileStore(filestore.S3API{S3Client: s3Client}, bucket, time.Hour, zap.NewNop().Sugar()) - n, err := store.PutObject(context.Background(), key, &testReader{exp: contentLength}, contentType) + reader := &testReader{exp: contentLength} + n, err := store.PutObject(context.Background(), key, reader, contentType) require.NoError(t, err) assert.Equal(t, int64(contentLength), n) assert.ElementsMatch(t, []int64{5_242_880, 5_242_880, 5_242_880, 4_271_360}, partLengths) + assert.True(t, reader.closed) }) t.Run("NOK,UploadPartFailure", func(t *testing.T) { diff --git a/backend/generated/mocks/FileStore.go b/backend/generated/mocks/FileStore.go index 3c6c2c3..bc4ca90 100644 --- a/backend/generated/mocks/FileStore.go +++ b/backend/generated/mocks/FileStore.go @@ -82,18 +82,18 @@ func (_m *FileStore) GetURL(ctx context.Context, key string) (string, error) { } // PutObject provides a mock function with given fields: ctx, key, reader, contentType -func (_m *FileStore) PutObject(ctx context.Context, key string, reader io.Reader, contentType string) (int64, error) { +func (_m *FileStore) PutObject(ctx context.Context, key string, reader io.ReadCloser, contentType string) (int64, error) { ret := _m.Called(ctx, key, reader, contentType) var r0 int64 - if rf, ok := ret.Get(0).(func(context.Context, string, io.Reader, string) int64); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, io.ReadCloser, string) int64); ok { r0 = rf(ctx, key, reader, contentType) } else { r0 = ret.Get(0).(int64) } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, io.Reader, string) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, string, io.ReadCloser, string) error); ok { r1 = rf(ctx, key, reader, contentType) } else { r1 = ret.Error(1) diff --git a/backend/media/get_audio.go b/backend/media/get_audio.go index fb0a43b..d14a9bd 100644 --- a/backend/media/get_audio.go +++ b/backend/media/get_audio.go @@ -145,8 +145,8 @@ func (s *audioGetterState) getAudio(ctx context.Context, r io.ReadCloser, mediaS // TODO: use mediaSet func to fetch key key := fmt.Sprintf("media_sets/%s/audio.raw", mediaSet.ID) - teeReader := io.TeeReader(stdout, s) - bytesUploaded, rawErr := s.fileStore.PutObject(ctx, key, teeReader, rawAudioMimeType) + bytesUploaded, rawErr := s.fileStore.PutObject(ctx, key, readCloser{io.TeeReader(stdout, s), stdout}, rawAudioMimeType) + if rawErr != nil { s.CloseWithError(fmt.Errorf("error uploading raw audio: %v", rawErr)) return diff --git a/backend/media/get_video.go b/backend/media/get_video.go index f3ef609..9ad86a0 100644 --- a/backend/media/get_video.go +++ b/backend/media/get_video.go @@ -30,7 +30,7 @@ type videoGetter struct { type videoGetterState struct { *videoGetter - r io.Reader + r io.ReadCloser count, exp int64 mediaSetID uuid.UUID key, contentType string @@ -47,10 +47,10 @@ func newVideoGetter(store Store, fileStore FileStore, logger *zap.SugaredLogger) // specified key and content type. The returned reader must have its Next() // method called until error = io.EOF, otherwise a deadlock or other resource // leakage is likely. -func (g *videoGetter) GetVideo(ctx context.Context, r io.Reader, exp int64, mediaSetID uuid.UUID, key, contentType string) (GetVideoProgressReader, error) { +func (g *videoGetter) GetVideo(ctx context.Context, r io.ReadCloser, exp int64, mediaSetID uuid.UUID, key, contentType string) (GetVideoProgressReader, error) { s := &videoGetterState{ videoGetter: g, - r: newLogProgressReader(r, "video", exp, g.logger), + r: r, exp: exp, mediaSetID: mediaSetID, key: key, @@ -75,8 +75,8 @@ func (s *videoGetterState) Write(p []byte) (int, error) { } func (s *videoGetterState) getVideo(ctx context.Context) { - teeReader := io.TeeReader(s.r, s) - + progressReader := newLogProgressReader(s.r, "video", s.exp, s.logger) + teeReader := readCloser{io.TeeReader(progressReader, s), s.r} _, err := s.fileStore.PutObject(ctx, s.key, teeReader, s.contentType) if err != nil { s.errorChan <- fmt.Errorf("error uploading to file store: %v", err) diff --git a/backend/media/thumbnail.go b/backend/media/thumbnail.go index 021946b..c586657 100644 --- a/backend/media/thumbnail.go +++ b/backend/media/thumbnail.go @@ -78,7 +78,7 @@ func (s *MediaSetService) getThumbnailFromYoutube(ctx context.Context, mediaSet thumbnailKey := fmt.Sprintf("media_sets/%s/thumbnail.jpg", mediaSet.ID) const mimeType = "application/jpeg" - _, err = s.fileStore.PutObject(ctx, thumbnailKey, bytes.NewReader(imageData), mimeType) + _, err = s.fileStore.PutObject(ctx, thumbnailKey, io.NopCloser(bytes.NewReader(imageData)), mimeType) if err != nil { return VideoThumbnail{}, fmt.Errorf("error uploading thumbnail: %v", err) } diff --git a/backend/media/types.go b/backend/media/types.go index 429e25b..be0a08e 100644 --- a/backend/media/types.go +++ b/backend/media/types.go @@ -66,7 +66,7 @@ type FileStore interface { GetObject(ctx context.Context, key string) (io.ReadCloser, error) GetObjectWithRange(ctx context.Context, key string, startFrame, endFrame int64) (io.ReadCloser, error) GetURL(ctx context.Context, key string) (string, error) - PutObject(ctx context.Context, key string, reader io.Reader, contentType string) (int64, error) + PutObject(ctx context.Context, key string, reader io.ReadCloser, contentType string) (int64, error) } // YoutubeClient wraps the youtube.Client client. @@ -74,3 +74,8 @@ type YoutubeClient interface { GetVideoContext(ctx context.Context, id string) (*youtubev2.Video, error) GetStreamContext(ctx context.Context, video *youtubev2.Video, format *youtubev2.Format) (io.ReadCloser, int64, error) } + +type readCloser struct { + io.Reader + io.Closer +}