httpserver: Save Twitter user to store
continuous-integration/drone/push Build is passing
Details
continuous-integration/drone/push Build is passing
Details
This commit is contained in:
parent
6be489c44f
commit
9767ad2331
|
@ -0,0 +1,49 @@
|
|||
// Code generated by mockery v2.12.2. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
testing "testing"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
twitter "git.netflux.io/rob/elon-eats-my-tweets/twitter"
|
||||
)
|
||||
|
||||
// TwitterAPIClient is an autogenerated mock type for the APIClient type
|
||||
type TwitterAPIClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// GetMe provides a mock function with given fields:
|
||||
func (_m *TwitterAPIClient) GetMe() (*twitter.User, error) {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 *twitter.User
|
||||
if rf, ok := ret.Get(0).(func() *twitter.User); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*twitter.User)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = rf()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// NewTwitterAPIClient creates a new instance of TwitterAPIClient. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
func NewTwitterAPIClient(t testing.TB) *TwitterAPIClient {
|
||||
mock := &TwitterAPIClient{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
|
@ -12,7 +12,7 @@ import (
|
|||
|
||||
type User struct {
|
||||
ID uuid.UUID
|
||||
TwitterID int32
|
||||
TwitterID string
|
||||
Username string
|
||||
Name string
|
||||
AccessToken string
|
||||
|
|
|
@ -11,16 +11,20 @@ import (
|
|||
)
|
||||
|
||||
const createUser = `-- name: CreateUser :one
|
||||
INSERT INTO users (twitter_id, username, name, access_token, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
INSERT INTO users (twitter_id, username, name, access_token, refresh_token, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (twitter_id)
|
||||
DO
|
||||
UPDATE SET access_token = EXCLUDED.access_token, refresh_token = EXCLUDED.refresh_token, username = EXCLUDED.username, name = EXCLUDED.name, updated_at = EXCLUDED.updated_at
|
||||
RETURNING id, twitter_id, username, name, access_token, refresh_token, delete_tweets_enabled, delete_tweets_num_per_iteration, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateUserParams struct {
|
||||
TwitterID int32
|
||||
TwitterID string
|
||||
Username string
|
||||
Name string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
@ -31,6 +35,7 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e
|
|||
arg.Username,
|
||||
arg.Name,
|
||||
arg.AccessToken,
|
||||
arg.RefreshToken,
|
||||
arg.CreatedAt,
|
||||
arg.UpdatedAt,
|
||||
)
|
||||
|
@ -49,3 +54,25 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e
|
|||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserByTwitterID = `-- name: GetUserByTwitterID :one
|
||||
SELECT id, twitter_id, username, name, access_token, refresh_token, delete_tweets_enabled, delete_tweets_num_per_iteration, created_at, updated_at FROM users WHERE twitter_id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserByTwitterID(ctx context.Context, twitterID string) (User, error) {
|
||||
row := q.db.QueryRow(ctx, getUserByTwitterID, twitterID)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.TwitterID,
|
||||
&i.Username,
|
||||
&i.Name,
|
||||
&i.AccessToken,
|
||||
&i.RefreshToken,
|
||||
&i.DeleteTweetsEnabled,
|
||||
&i.DeleteTweetsNumPerIteration,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
|
|
@ -5,10 +5,11 @@ package httpserver
|
|||
import (
|
||||
"context"
|
||||
"html/template"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.netflux.io/rob/elon-eats-my-tweets/config"
|
||||
"git.netflux.io/rob/elon-eats-my-tweets/generated/store"
|
||||
"git.netflux.io/rob/elon-eats-my-tweets/twitter"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/chi/middleware"
|
||||
|
@ -25,16 +26,21 @@ const (
|
|||
pkceVerifierLen = 64
|
||||
)
|
||||
|
||||
type TwitterAPIClient interface {
|
||||
GetMe() (*twitter.User, error)
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
templates *template.Template
|
||||
store twitter.Store
|
||||
twitterAPIClientFunc func(c *http.Client) TwitterAPIClient
|
||||
oauth2Config *oauth2.Config
|
||||
sessionStore sessions.Store
|
||||
tokenGenerator TokenGenerator
|
||||
logger *zap.SugaredLogger
|
||||
}
|
||||
|
||||
func NewHandler(cfg config.Config, templates *template.Template, store twitter.Store, sessionStore sessions.Store, tokenGenerator TokenGenerator, logger *zap.Logger) http.Handler {
|
||||
func NewHandler(cfg config.Config, templates *template.Template, store twitter.Store, twitterAPIClientFunc func(c *http.Client) TwitterAPIClient, sessionStore sessions.Store, tokenGenerator TokenGenerator, logger *zap.Logger) http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.RealIP)
|
||||
|
@ -44,6 +50,7 @@ func NewHandler(cfg config.Config, templates *template.Template, store twitter.S
|
|||
h := handler{
|
||||
templates: templates,
|
||||
store: store,
|
||||
twitterAPIClientFunc: twitterAPIClientFunc,
|
||||
oauth2Config: &oauth2.Config{
|
||||
ClientID: cfg.Twitter.ClientID,
|
||||
ClientSecret: cfg.Twitter.ClientSecret,
|
||||
|
@ -139,20 +146,28 @@ func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
client := h.oauth2Config.Client(context.Background(), token)
|
||||
resp, err := client.Get("https://api.twitter.com/2/users/me")
|
||||
twitterClient := h.twitterAPIClientFunc(h.oauth2Config.Client(context.Background(), token))
|
||||
twitterUser, err := twitterClient.GetMe()
|
||||
if err != nil {
|
||||
h.logger.With("err", err).Error("error fetching user")
|
||||
http.Error(w, "error fetching user", http.StatusInternalServerError)
|
||||
h.logger.With("err", err).Error("error fetching user from twitter")
|
||||
http.Error(w, "error validating user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.store.CreateUser(r.Context(), store.CreateUserParams{
|
||||
TwitterID: twitterUser.ID,
|
||||
Username: twitterUser.Username,
|
||||
Name: twitterUser.Name,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}); err != nil {
|
||||
h.logger.With("err", err).Error("error upserting user")
|
||||
http.Error(w, "error saving user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
// TODO: do something sensible
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
w.Header().Set("content-type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err = w.Write([]byte(body)); err != nil {
|
||||
h.logger.With("err", err).Error("error writing response")
|
||||
}
|
||||
w.Write([]byte("ok"))
|
||||
}
|
||||
|
|
|
@ -11,7 +11,9 @@ import (
|
|||
|
||||
"git.netflux.io/rob/elon-eats-my-tweets/config"
|
||||
"git.netflux.io/rob/elon-eats-my-tweets/generated/mocks"
|
||||
"git.netflux.io/rob/elon-eats-my-tweets/generated/store"
|
||||
"git.netflux.io/rob/elon-eats-my-tweets/httpserver"
|
||||
"git.netflux.io/rob/elon-eats-my-tweets/twitter"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
@ -44,6 +46,7 @@ func TestGetIndex(t *testing.T) {
|
|||
config.Config{},
|
||||
templates,
|
||||
&mocks.Store{},
|
||||
func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} },
|
||||
&mocks.SessionStore{},
|
||||
&mockTokenGenerator{},
|
||||
zap.NewNop(),
|
||||
|
@ -99,6 +102,7 @@ func TestLogin(t *testing.T) {
|
|||
},
|
||||
templates,
|
||||
&mocks.Store{},
|
||||
func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} },
|
||||
&mockSessionStore,
|
||||
&mockTokenGenerator{tokens: []string{"state", "pkceVerifier"}},
|
||||
zap.NewNop(),
|
||||
|
@ -133,6 +137,8 @@ func TestCallback(t *testing.T) {
|
|||
code string
|
||||
sessionReadError error
|
||||
oauth2StatusCode int
|
||||
getTwitterUserError error
|
||||
createUserError error
|
||||
wantStatusCode int
|
||||
wantError string
|
||||
}{
|
||||
|
@ -183,6 +189,28 @@ func TestCallback(t *testing.T) {
|
|||
wantStatusCode: http.StatusForbidden,
|
||||
wantError: "error exchanging code",
|
||||
},
|
||||
{
|
||||
name: "error fetching user from twitter",
|
||||
state: "mystate",
|
||||
sessionState: "mystate",
|
||||
code: "mycode",
|
||||
sessionPkceVerifier: "mypkceverifier",
|
||||
oauth2StatusCode: http.StatusOK,
|
||||
getTwitterUserError: errors.New("nothing to see here"),
|
||||
wantStatusCode: http.StatusInternalServerError,
|
||||
wantError: "error validating user",
|
||||
},
|
||||
{
|
||||
name: "error storing user",
|
||||
state: "mystate",
|
||||
sessionState: "mystate",
|
||||
code: "mycode",
|
||||
sessionPkceVerifier: "mypkceverifier",
|
||||
oauth2StatusCode: http.StatusOK,
|
||||
createUserError: errors.New("oh no"),
|
||||
wantStatusCode: http.StatusInternalServerError,
|
||||
wantError: "error saving user",
|
||||
},
|
||||
{
|
||||
name: "successful exchange",
|
||||
state: "mystate",
|
||||
|
@ -196,14 +224,20 @@ func TestCallback(t *testing.T) {
|
|||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var mockStore mocks.Store
|
||||
|
||||
var mockSessionStore mocks.SessionStore
|
||||
sess := sessions.NewSession(&mockSessionStore, "elon_session")
|
||||
sess.Values["state"] = tc.sessionState
|
||||
sess.Values["pkce_verifier"] = tc.sessionPkceVerifier
|
||||
mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, tc.sessionReadError)
|
||||
|
||||
var mockTwitterClient mocks.TwitterAPIClient
|
||||
mockTwitterClient.On("GetMe").Return(&twitter.User{ID: "1", Name: "foo", Username: "Foo Bar"}, tc.getTwitterUserError)
|
||||
|
||||
var mockStore mocks.Store
|
||||
mockStore.On("CreateUser", mock.Anything, mock.MatchedBy(func(params store.CreateUserParams) bool {
|
||||
return params.TwitterID == "1" && params.Name == "foo" && params.Username == "Foo Bar"
|
||||
})).Return(store.User{}, tc.createUserError)
|
||||
|
||||
const callbackURL = "https://www.example.com/callback"
|
||||
oauthHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
|
@ -234,6 +268,7 @@ func TestCallback(t *testing.T) {
|
|||
},
|
||||
templates,
|
||||
&mockStore,
|
||||
func(*http.Client) httpserver.TwitterAPIClient { return &mockTwitterClient },
|
||||
&mockSessionStore,
|
||||
nil,
|
||||
zap.NewNop(),
|
||||
|
|
|
@ -17,6 +17,7 @@ func loggerMiddleware(l *zap.Logger) func(next http.Handler) http.Handler {
|
|||
defer func() {
|
||||
l.Info("HTTP",
|
||||
zap.String("proto", r.Proto),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.Duration("dur", time.Since(t1)),
|
||||
zap.Int("status", ww.Status()),
|
||||
|
|
2
main.go
2
main.go
|
@ -12,6 +12,7 @@ import (
|
|||
"git.netflux.io/rob/elon-eats-my-tweets/config"
|
||||
"git.netflux.io/rob/elon-eats-my-tweets/generated/store"
|
||||
"git.netflux.io/rob/elon-eats-my-tweets/httpserver"
|
||||
"git.netflux.io/rob/elon-eats-my-tweets/twitter"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
"go.uber.org/zap"
|
||||
|
@ -45,6 +46,7 @@ func main() {
|
|||
cfg,
|
||||
templates,
|
||||
store,
|
||||
func(c *http.Client) httpserver.TwitterAPIClient { return twitter.NewAPIClient(c) },
|
||||
sessions.NewCookieStore([]byte(cfg.SessionKey)),
|
||||
httpserver.RandomTokenGenerator{},
|
||||
logger,
|
||||
|
|
|
@ -2,13 +2,15 @@ CREATE EXTENSION IF NOT EXISTS pgcrypto;
|
|||
|
||||
CREATE TABLE users (
|
||||
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
twitter_id int NOT NULL,
|
||||
twitter_id CHARACTER VARYING(256) NOT NULL,
|
||||
username CHARACTER VARYING(255) NOT NULL,
|
||||
name CHARACTER VARYING(255) NOT NULL,
|
||||
access_token CHARACTER VARYING(512) NOT NULL,
|
||||
refresh_token CHARACTER VARYING(512) NOT NULL,
|
||||
delete_tweets_enabled boolean NOT NULL DEFAULT false,
|
||||
delete_tweets_num_per_iteration int NOT NULL DEFAULT 0,
|
||||
delete_tweets_num_per_iteration int NOT NULL DEFAULT 1,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
updated_at TIMESTAMP WITH TIME ZONE NOT NULL
|
||||
)
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX index_users_on_twitter_id ON users (twitter_id);
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
-- name: CreateUser :one
|
||||
INSERT INTO users (twitter_id, username, name, access_token, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
INSERT INTO users (twitter_id, username, name, access_token, refresh_token, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (twitter_id)
|
||||
DO
|
||||
UPDATE SET access_token = EXCLUDED.access_token, refresh_token = EXCLUDED.refresh_token, username = EXCLUDED.username, name = EXCLUDED.name, updated_at = EXCLUDED.updated_at
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetUserByTwitterID :one
|
||||
SELECT * FROM users WHERE twitter_id = $1;
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
package twitter
|
||||
|
||||
//go:generate mockery --recursive --name APIClient --structname TwitterAPIClient --filename TwitterAPIClient.go --output ../generated/mocks
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID string
|
||||
Name string
|
||||
Username string
|
||||
}
|
||||
|
||||
func NewAPIClient(httpclient *http.Client) *APIClient {
|
||||
return &APIClient{httpclient}
|
||||
}
|
||||
|
||||
type APIClient struct {
|
||||
*http.Client
|
||||
}
|
||||
|
||||
func (c *APIClient) GetMe() (*User, error) {
|
||||
type oauthResponse struct {
|
||||
Data struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
resp, err := c.Get("https://api.twitter.com/2/users/me")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error fetching resource: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("error fetching resource: status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var oauthResp oauthResponse
|
||||
if err = json.NewDecoder(resp.Body).Decode(&oauthResp); err != nil {
|
||||
return nil, fmt.Errorf("error decoding resource: %v", err)
|
||||
}
|
||||
|
||||
return &User{
|
||||
ID: oauthResp.Data.ID,
|
||||
Name: oauthResp.Data.Name,
|
||||
Username: oauthResp.Data.Username,
|
||||
}, nil
|
||||
}
|
Reference in New Issue