diff --git a/generated/mocks/TwitterAPIClient.go b/generated/mocks/TwitterAPIClient.go new file mode 100644 index 0000000..f4d0759 --- /dev/null +++ b/generated/mocks/TwitterAPIClient.go @@ -0,0 +1,49 @@ +// Code generated by mockery v2.12.2. DO NOT EDIT. + +package mocks + +import ( + testing "testing" + + mock "github.com/stretchr/testify/mock" + + twitter "git.netflux.io/rob/elon-eats-my-tweets/twitter" +) + +// TwitterAPIClient is an autogenerated mock type for the APIClient type +type TwitterAPIClient struct { + mock.Mock +} + +// GetMe provides a mock function with given fields: +func (_m *TwitterAPIClient) GetMe() (*twitter.User, error) { + ret := _m.Called() + + var r0 *twitter.User + if rf, ok := ret.Get(0).(func() *twitter.User); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*twitter.User) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewTwitterAPIClient creates a new instance of TwitterAPIClient. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations. +func NewTwitterAPIClient(t testing.TB) *TwitterAPIClient { + mock := &TwitterAPIClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/generated/store/models.go b/generated/store/models.go index b5d11e5..84e0275 100644 --- a/generated/store/models.go +++ b/generated/store/models.go @@ -12,7 +12,7 @@ import ( type User struct { ID uuid.UUID - TwitterID int32 + TwitterID string Username string Name string AccessToken string diff --git a/generated/store/queries.sql.go b/generated/store/queries.sql.go index aafa8ac..bc9dc23 100644 --- a/generated/store/queries.sql.go +++ b/generated/store/queries.sql.go @@ -11,18 +11,22 @@ import ( ) const createUser = `-- name: CreateUser :one -INSERT INTO users (twitter_id, username, name, access_token, created_at, updated_at) - VALUES ($1, $2, $3, $4, $5, $6) +INSERT INTO users (twitter_id, username, name, access_token, refresh_token, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (twitter_id) + DO + UPDATE SET access_token = EXCLUDED.access_token, refresh_token = EXCLUDED.refresh_token, username = EXCLUDED.username, name = EXCLUDED.name, updated_at = EXCLUDED.updated_at RETURNING id, twitter_id, username, name, access_token, refresh_token, delete_tweets_enabled, delete_tweets_num_per_iteration, created_at, updated_at ` type CreateUserParams struct { - TwitterID int32 - Username string - Name string - AccessToken string - CreatedAt time.Time - UpdatedAt time.Time + TwitterID string + Username string + Name string + AccessToken string + RefreshToken string + CreatedAt time.Time + UpdatedAt time.Time } func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { @@ -31,6 +35,7 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e arg.Username, arg.Name, arg.AccessToken, + arg.RefreshToken, arg.CreatedAt, arg.UpdatedAt, ) @@ -49,3 +54,25 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e ) return i, err } + +const getUserByTwitterID = `-- name: GetUserByTwitterID :one +SELECT id, twitter_id, username, name, access_token, refresh_token, delete_tweets_enabled, delete_tweets_num_per_iteration, created_at, updated_at FROM users WHERE twitter_id = $1 +` + +func (q *Queries) GetUserByTwitterID(ctx context.Context, twitterID string) (User, error) { + row := q.db.QueryRow(ctx, getUserByTwitterID, twitterID) + var i User + err := row.Scan( + &i.ID, + &i.TwitterID, + &i.Username, + &i.Name, + &i.AccessToken, + &i.RefreshToken, + &i.DeleteTweetsEnabled, + &i.DeleteTweetsNumPerIteration, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/httpserver/handler.go b/httpserver/handler.go index 1242b15..0fb5f32 100644 --- a/httpserver/handler.go +++ b/httpserver/handler.go @@ -5,10 +5,11 @@ package httpserver import ( "context" "html/template" - "io" "net/http" + "time" "git.netflux.io/rob/elon-eats-my-tweets/config" + "git.netflux.io/rob/elon-eats-my-tweets/generated/store" "git.netflux.io/rob/elon-eats-my-tweets/twitter" "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" @@ -25,16 +26,21 @@ const ( pkceVerifierLen = 64 ) -type handler struct { - templates *template.Template - store twitter.Store - oauth2Config *oauth2.Config - sessionStore sessions.Store - tokenGenerator TokenGenerator - logger *zap.SugaredLogger +type TwitterAPIClient interface { + GetMe() (*twitter.User, error) } -func NewHandler(cfg config.Config, templates *template.Template, store twitter.Store, sessionStore sessions.Store, tokenGenerator TokenGenerator, logger *zap.Logger) http.Handler { +type handler struct { + templates *template.Template + store twitter.Store + twitterAPIClientFunc func(c *http.Client) TwitterAPIClient + oauth2Config *oauth2.Config + sessionStore sessions.Store + tokenGenerator TokenGenerator + logger *zap.SugaredLogger +} + +func NewHandler(cfg config.Config, templates *template.Template, store twitter.Store, twitterAPIClientFunc func(c *http.Client) TwitterAPIClient, sessionStore sessions.Store, tokenGenerator TokenGenerator, logger *zap.Logger) http.Handler { r := chi.NewRouter() r.Use(middleware.RequestID) r.Use(middleware.RealIP) @@ -42,8 +48,9 @@ func NewHandler(cfg config.Config, templates *template.Template, store twitter.S r.Use(middleware.Recoverer) h := handler{ - templates: templates, - store: store, + templates: templates, + store: store, + twitterAPIClientFunc: twitterAPIClientFunc, oauth2Config: &oauth2.Config{ ClientID: cfg.Twitter.ClientID, ClientSecret: cfg.Twitter.ClientSecret, @@ -139,20 +146,28 @@ func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) { return } - client := h.oauth2Config.Client(context.Background(), token) - resp, err := client.Get("https://api.twitter.com/2/users/me") + twitterClient := h.twitterAPIClientFunc(h.oauth2Config.Client(context.Background(), token)) + twitterUser, err := twitterClient.GetMe() if err != nil { - h.logger.With("err", err).Error("error fetching user") - http.Error(w, "error fetching user", http.StatusInternalServerError) + h.logger.With("err", err).Error("error fetching user from twitter") + http.Error(w, "error validating user", http.StatusInternalServerError) return } - defer resp.Body.Close() - // TODO: do something sensible - body, _ := io.ReadAll(resp.Body) - w.Header().Set("content-type", "application/json") - w.WriteHeader(http.StatusOK) - if _, err = w.Write([]byte(body)); err != nil { - h.logger.With("err", err).Error("error writing response") + if _, err := h.store.CreateUser(r.Context(), store.CreateUserParams{ + TwitterID: twitterUser.ID, + Username: twitterUser.Username, + Name: twitterUser.Name, + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + }); err != nil { + h.logger.With("err", err).Error("error upserting user") + http.Error(w, "error saving user", http.StatusInternalServerError) + return } + + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) } diff --git a/httpserver/handler_test.go b/httpserver/handler_test.go index cf59b20..473c083 100644 --- a/httpserver/handler_test.go +++ b/httpserver/handler_test.go @@ -11,7 +11,9 @@ import ( "git.netflux.io/rob/elon-eats-my-tweets/config" "git.netflux.io/rob/elon-eats-my-tweets/generated/mocks" + "git.netflux.io/rob/elon-eats-my-tweets/generated/store" "git.netflux.io/rob/elon-eats-my-tweets/httpserver" + "git.netflux.io/rob/elon-eats-my-tweets/twitter" "github.com/gorilla/sessions" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -44,6 +46,7 @@ func TestGetIndex(t *testing.T) { config.Config{}, templates, &mocks.Store{}, + func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} }, &mocks.SessionStore{}, &mockTokenGenerator{}, zap.NewNop(), @@ -99,6 +102,7 @@ func TestLogin(t *testing.T) { }, templates, &mocks.Store{}, + func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} }, &mockSessionStore, &mockTokenGenerator{tokens: []string{"state", "pkceVerifier"}}, zap.NewNop(), @@ -133,6 +137,8 @@ func TestCallback(t *testing.T) { code string sessionReadError error oauth2StatusCode int + getTwitterUserError error + createUserError error wantStatusCode int wantError string }{ @@ -183,6 +189,28 @@ func TestCallback(t *testing.T) { wantStatusCode: http.StatusForbidden, wantError: "error exchanging code", }, + { + name: "error fetching user from twitter", + state: "mystate", + sessionState: "mystate", + code: "mycode", + sessionPkceVerifier: "mypkceverifier", + oauth2StatusCode: http.StatusOK, + getTwitterUserError: errors.New("nothing to see here"), + wantStatusCode: http.StatusInternalServerError, + wantError: "error validating user", + }, + { + name: "error storing user", + state: "mystate", + sessionState: "mystate", + code: "mycode", + sessionPkceVerifier: "mypkceverifier", + oauth2StatusCode: http.StatusOK, + createUserError: errors.New("oh no"), + wantStatusCode: http.StatusInternalServerError, + wantError: "error saving user", + }, { name: "successful exchange", state: "mystate", @@ -196,14 +224,20 @@ func TestCallback(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - var mockStore mocks.Store - var mockSessionStore mocks.SessionStore sess := sessions.NewSession(&mockSessionStore, "elon_session") sess.Values["state"] = tc.sessionState sess.Values["pkce_verifier"] = tc.sessionPkceVerifier mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, tc.sessionReadError) + var mockTwitterClient mocks.TwitterAPIClient + mockTwitterClient.On("GetMe").Return(&twitter.User{ID: "1", Name: "foo", Username: "Foo Bar"}, tc.getTwitterUserError) + + var mockStore mocks.Store + mockStore.On("CreateUser", mock.Anything, mock.MatchedBy(func(params store.CreateUserParams) bool { + return params.TwitterID == "1" && params.Name == "foo" && params.Username == "Foo Bar" + })).Return(store.User{}, tc.createUserError) + const callbackURL = "https://www.example.com/callback" oauthHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, http.MethodPost, r.Method) @@ -234,6 +268,7 @@ func TestCallback(t *testing.T) { }, templates, &mockStore, + func(*http.Client) httpserver.TwitterAPIClient { return &mockTwitterClient }, &mockSessionStore, nil, zap.NewNop(), diff --git a/httpserver/logger.go b/httpserver/logger.go index f6a52fe..4aed649 100644 --- a/httpserver/logger.go +++ b/httpserver/logger.go @@ -17,6 +17,7 @@ func loggerMiddleware(l *zap.Logger) func(next http.Handler) http.Handler { defer func() { l.Info("HTTP", zap.String("proto", r.Proto), + zap.String("method", r.Method), zap.String("path", r.URL.Path), zap.Duration("dur", time.Since(t1)), zap.Int("status", ww.Status()), diff --git a/main.go b/main.go index 5df681c..45c347b 100644 --- a/main.go +++ b/main.go @@ -12,6 +12,7 @@ import ( "git.netflux.io/rob/elon-eats-my-tweets/config" "git.netflux.io/rob/elon-eats-my-tweets/generated/store" "git.netflux.io/rob/elon-eats-my-tweets/httpserver" + "git.netflux.io/rob/elon-eats-my-tweets/twitter" "github.com/gorilla/sessions" "github.com/jackc/pgx/v4/pgxpool" "go.uber.org/zap" @@ -45,6 +46,7 @@ func main() { cfg, templates, store, + func(c *http.Client) httpserver.TwitterAPIClient { return twitter.NewAPIClient(c) }, sessions.NewCookieStore([]byte(cfg.SessionKey)), httpserver.RandomTokenGenerator{}, logger, diff --git a/sql/migrations/20220520201518_add_users_table.up.sql b/sql/migrations/20220520201518_add_users_table.up.sql index 2282931..9ecec86 100644 --- a/sql/migrations/20220520201518_add_users_table.up.sql +++ b/sql/migrations/20220520201518_add_users_table.up.sql @@ -2,13 +2,15 @@ CREATE EXTENSION IF NOT EXISTS pgcrypto; CREATE TABLE users ( id uuid PRIMARY KEY DEFAULT gen_random_uuid(), - twitter_id int NOT NULL, + twitter_id CHARACTER VARYING(256) NOT NULL, username CHARACTER VARYING(255) NOT NULL, name CHARACTER VARYING(255) NOT NULL, access_token CHARACTER VARYING(512) NOT NULL, refresh_token CHARACTER VARYING(512) NOT NULL, delete_tweets_enabled boolean NOT NULL DEFAULT false, - delete_tweets_num_per_iteration int NOT NULL DEFAULT 0, + delete_tweets_num_per_iteration int NOT NULL DEFAULT 1, created_at TIMESTAMP WITH TIME ZONE NOT NULL, updated_at TIMESTAMP WITH TIME ZONE NOT NULL -) +); + +CREATE UNIQUE INDEX index_users_on_twitter_id ON users (twitter_id); diff --git a/sql/queries.sql b/sql/queries.sql index 0d53ba9..8ffeacf 100644 --- a/sql/queries.sql +++ b/sql/queries.sql @@ -1,4 +1,10 @@ -- name: CreateUser :one -INSERT INTO users (twitter_id, username, name, access_token, created_at, updated_at) - VALUES ($1, $2, $3, $4, $5, $6) +INSERT INTO users (twitter_id, username, name, access_token, refresh_token, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (twitter_id) + DO + UPDATE SET access_token = EXCLUDED.access_token, refresh_token = EXCLUDED.refresh_token, username = EXCLUDED.username, name = EXCLUDED.name, updated_at = EXCLUDED.updated_at RETURNING *; + +-- name: GetUserByTwitterID :one +SELECT * FROM users WHERE twitter_id = $1; diff --git a/twitter/api.go b/twitter/api.go new file mode 100644 index 0000000..b647f8b --- /dev/null +++ b/twitter/api.go @@ -0,0 +1,53 @@ +package twitter + +//go:generate mockery --recursive --name APIClient --structname TwitterAPIClient --filename TwitterAPIClient.go --output ../generated/mocks + +import ( + "encoding/json" + "fmt" + "net/http" +) + +type User struct { + ID string + Name string + Username string +} + +func NewAPIClient(httpclient *http.Client) *APIClient { + return &APIClient{httpclient} +} + +type APIClient struct { + *http.Client +} + +func (c *APIClient) GetMe() (*User, error) { + type oauthResponse struct { + Data struct { + ID string `json:"id"` + Name string `json:"name"` + Username string `json:"username"` + } `json:"data"` + } + + resp, err := c.Get("https://api.twitter.com/2/users/me") + if err != nil { + return nil, fmt.Errorf("error fetching resource: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("error fetching resource: status code %d", resp.StatusCode) + } + + var oauthResp oauthResponse + if err = json.NewDecoder(resp.Body).Decode(&oauthResp); err != nil { + return nil, fmt.Errorf("error decoding resource: %v", err) + } + + return &User{ + ID: oauthResp.Data.ID, + Name: oauthResp.Data.Name, + Username: oauthResp.Data.Username, + }, nil +}