Implement login/logout flow
continuous-integration/drone/push Build is passing Details

This commit is contained in:
Rob Watson 2022-05-21 22:51:47 +02:00
parent 55ee291390
commit 1816d0be4b
7 changed files with 229 additions and 42 deletions

View File

@ -5,6 +5,8 @@ package httpserver
import ( import (
"context" "context"
"encoding/gob"
"html/template"
"net/http" "net/http"
"time" "time"
@ -14,11 +16,27 @@ import (
"git.netflux.io/rob/elon-eats-my-tweets/twitter" "git.netflux.io/rob/elon-eats-my-tweets/twitter"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/go-chi/chi/middleware" "github.com/go-chi/chi/middleware"
"github.com/google/uuid"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
func init() {
gob.Register(CurrentUser{})
}
type CurrentUser struct {
ID uuid.UUID
TwitterName, TwitterUsername string
}
func (cu CurrentUser) IsNil() bool {
return cu.ID == uuid.Nil
}
type contextKey string
const ( const (
sessionName = "elon_session" sessionName = "elon_session"
sessionKeyState = "state" sessionKeyState = "state"
@ -27,6 +45,19 @@ const (
pkceVerifierLen = 256 pkceVerifierLen = 256
) )
const (
sessionKeyCurrentUser = "current_user"
contextKeyCurrentUser contextKey = sessionKeyCurrentUser
)
func getCurrentUser(r *http.Request, sessionStore sessions.Store) (CurrentUser, bool) {
session, _ := sessionStore.Get(r, sessionName)
if cu, ok := session.Values[sessionKeyCurrentUser]; ok {
return cu.(CurrentUser), true
}
return CurrentUser{}, false
}
type TwitterAPIClient interface { type TwitterAPIClient interface {
GetMe() (*twitter.User, error) GetMe() (*twitter.User, error)
} }
@ -35,6 +66,7 @@ type handler struct {
store twitter.Store store twitter.Store
twitterAPIClientFunc func(c *http.Client) TwitterAPIClient twitterAPIClientFunc func(c *http.Client) TwitterAPIClient
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
renderer *templates.Renderer
sessionStore sessions.Store sessionStore sessions.Store
tokenGenerator TokenGenerator tokenGenerator TokenGenerator
logger *zap.SugaredLogger logger *zap.SugaredLogger
@ -44,7 +76,7 @@ func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc fun
r := chi.NewRouter() r := chi.NewRouter()
r.Use(middleware.RequestID) r.Use(middleware.RequestID)
r.Use(middleware.RealIP) r.Use(middleware.RealIP)
r.Use(loggerMiddleware(logger)) r.Use(loggerMiddleware(logger, sessionStore))
r.Use(middleware.Recoverer) r.Use(middleware.Recoverer)
h := handler{ h := handler{
@ -61,6 +93,7 @@ func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc fun
AuthStyle: oauth2.AuthStyleInHeader, AuthStyle: oauth2.AuthStyleInHeader,
}, },
}, },
renderer: templates.NewRenderer(logger.Sugar()),
sessionStore: sessionStore, sessionStore: sessionStore,
tokenGenerator: tokenGenerator, tokenGenerator: tokenGenerator,
logger: logger.Sugar(), logger: logger.Sugar(),
@ -68,17 +101,26 @@ func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc fun
r.Get("/", h.getIndex) r.Get("/", h.getIndex)
r.Get("/login", h.getLogin) r.Get("/login", h.getLogin)
r.Post("/logout", h.postLogout)
r.Get("/callback", h.getCallback) r.Get("/callback", h.getCallback)
return r return r
} }
func (h *handler) getIndex(w http.ResponseWriter, r *http.Request) { func (h *handler) getIndex(w http.ResponseWriter, r *http.Request) {
if err := templates.Execute(w, "index.html", nil); err != nil { cu, _ := getCurrentUser(r, h.sessionStore)
h.logger.With("err", err).Error("error rendering template") var template *template.Template
http.Error(w, "error rendering template", http.StatusInternalServerError) if cu.IsNil() {
return template = templates.Index
} else {
template = templates.Dashboard
} }
type data struct {
CurrentUser CurrentUser
}
h.renderer.Render(w, template, data{CurrentUser: cu})
} }
func (h *handler) getLogin(w http.ResponseWriter, r *http.Request) { func (h *handler) getLogin(w http.ResponseWriter, r *http.Request) {
@ -153,7 +195,7 @@ func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) {
return return
} }
if _, err := h.store.CreateUser(r.Context(), store.CreateUserParams{ user, err := h.store.CreateUser(r.Context(), store.CreateUserParams{
TwitterID: twitterUser.ID, TwitterID: twitterUser.ID,
Username: twitterUser.Username, Username: twitterUser.Username,
Name: twitterUser.Name, Name: twitterUser.Name,
@ -161,12 +203,44 @@ func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) {
RefreshToken: token.RefreshToken, RefreshToken: token.RefreshToken,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(),
}); err != nil { })
if err != nil {
h.logger.With("err", err).Error("error upserting user") h.logger.With("err", err).Error("error upserting user")
http.Error(w, "error saving user", http.StatusInternalServerError) http.Error(w, "error saving user", http.StatusInternalServerError)
return return
} }
w.WriteHeader(http.StatusOK) // finally, save in the session. To avoid hitting the database on most
w.Write([]byte("ok")) // requests we'll store the Twitter name and username in the session.
session.Values[sessionKeyCurrentUser] = CurrentUser{
ID: user.ID,
TwitterName: twitterUser.Name,
TwitterUsername: twitterUser.Username,
}
if err := session.Save(r, w); err != nil {
h.logger.With("err", err).Error("error saving session")
http.Error(w, "error saving user", http.StatusInternalServerError)
return
}
http.Redirect(w, r, "/", http.StatusFound)
}
func (h *handler) postLogout(w http.ResponseWriter, r *http.Request) {
session, err := h.sessionStore.Get(r, sessionName)
if err != nil {
h.logger.With("err", err).Error("error reading session")
http.Error(w, "error reading session", http.StatusBadRequest)
return
}
delete(session.Values, sessionKeyCurrentUser)
if err := session.Save(r, w); err != nil {
h.logger.With("err", err).Error("error saving session")
http.Error(w, "error saving user", http.StatusInternalServerError)
return
}
http.Redirect(w, r, "/", http.StatusFound)
} }

View File

@ -12,6 +12,7 @@ import (
"git.netflux.io/rob/elon-eats-my-tweets/generated/store" "git.netflux.io/rob/elon-eats-my-tweets/generated/store"
"git.netflux.io/rob/elon-eats-my-tweets/httpserver" "git.netflux.io/rob/elon-eats-my-tweets/httpserver"
"git.netflux.io/rob/elon-eats-my-tweets/twitter" "git.netflux.io/rob/elon-eats-my-tweets/twitter"
"github.com/google/uuid"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
@ -35,29 +36,56 @@ func (g *mockTokenGenerator) GenerateToken(_ int) string {
} }
func TestGetIndex(t *testing.T) { func TestGetIndex(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil) testCases := []struct {
rec := httptest.NewRecorder() name string
currentUser httpserver.CurrentUser
wantBody string
}{
{
name: "logged out",
currentUser: httpserver.CurrentUser{},
wantBody: "Sign in with Twitter",
},
{
name: "logged out",
currentUser: httpserver.CurrentUser{ID: uuid.New(), TwitterName: "foo", TwitterUsername: "Foo Bar"},
wantBody: "Welcome to your dashboard",
},
}
handler := httpserver.NewHandler( for _, tc := range testCases {
config.Config{}, t.Run(tc.name, func(t *testing.T) {
&mocks.Store{}, req := httptest.NewRequest(http.MethodGet, "/", nil)
func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} }, rec := httptest.NewRecorder()
&mocks.SessionStore{},
&mockTokenGenerator{},
zap.NewNop(),
)
handler.ServeHTTP(rec, req) var mockSessionStore mocks.SessionStore
res := rec.Result() sess := sessions.NewSession(&mockSessionStore, "elon_session")
defer res.Body.Close() sess.Values["current_user"] = tc.currentUser
body, err := ioutil.ReadAll(res.Body) mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, nil)
require.NoError(t, err) mockSessionStore.On("Save", mock.Anything, mock.Anything, sess).Return(nil)
assert.Equal(t, http.StatusOK, res.StatusCode) handler := httpserver.NewHandler(
assert.Contains(t, string(body), "Sign in with Twitter") config.Config{},
&mocks.Store{},
func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} },
&mockSessionStore,
&mockTokenGenerator{},
zap.NewNop(),
)
handler.ServeHTTP(rec, req)
res := rec.Result()
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Contains(t, string(body), tc.wantBody)
})
}
} }
func TestLogin(t *testing.T) { func TestGetLogin(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
sessionSaveError error sessionSaveError error
@ -122,7 +150,7 @@ func TestLogin(t *testing.T) {
} }
} }
func TestCallback(t *testing.T) { func TestGetCallback(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
state string state string
@ -133,7 +161,9 @@ func TestCallback(t *testing.T) {
oauth2StatusCode int oauth2StatusCode int
getTwitterUserError error getTwitterUserError error
createUserError error createUserError error
sessionSaveError error
wantStatusCode int wantStatusCode int
wantLocation string
wantError string wantError string
}{ }{
{ {
@ -205,6 +235,17 @@ func TestCallback(t *testing.T) {
wantStatusCode: http.StatusInternalServerError, wantStatusCode: http.StatusInternalServerError,
wantError: "error saving user", wantError: "error saving user",
}, },
{
name: "error saving session",
state: "mystate",
sessionState: "mystate",
code: "mycode",
sessionPkceVerifier: "mypkceverifier",
oauth2StatusCode: http.StatusOK,
sessionSaveError: errors.New("foo"),
wantStatusCode: http.StatusInternalServerError,
wantError: "error saving user",
},
{ {
name: "successful exchange", name: "successful exchange",
state: "mystate", state: "mystate",
@ -212,7 +253,8 @@ func TestCallback(t *testing.T) {
code: "mycode", code: "mycode",
sessionPkceVerifier: "mypkceverifier", sessionPkceVerifier: "mypkceverifier",
oauth2StatusCode: http.StatusOK, oauth2StatusCode: http.StatusOK,
wantStatusCode: http.StatusOK, wantStatusCode: http.StatusFound,
wantLocation: "/",
}, },
} }
@ -223,6 +265,7 @@ func TestCallback(t *testing.T) {
sess.Values["state"] = tc.sessionState sess.Values["state"] = tc.sessionState
sess.Values["pkce_verifier"] = tc.sessionPkceVerifier sess.Values["pkce_verifier"] = tc.sessionPkceVerifier
mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, tc.sessionReadError) mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, tc.sessionReadError)
mockSessionStore.On("Save", mock.Anything, mock.Anything, sess).Return(tc.sessionSaveError)
var mockTwitterClient mocks.TwitterAPIClient var mockTwitterClient mocks.TwitterAPIClient
mockTwitterClient.On("GetMe").Return(&twitter.User{ID: "1", Name: "foo", Username: "Foo Bar"}, tc.getTwitterUserError) mockTwitterClient.On("GetMe").Return(&twitter.User{ID: "1", Name: "foo", Username: "Foo Bar"}, tc.getTwitterUserError)
@ -283,3 +326,31 @@ func TestCallback(t *testing.T) {
}) })
} }
} }
func TestPostLogout(t *testing.T) {
var mockSessionStore mocks.SessionStore
sess := sessions.NewSession(&mockSessionStore, "elon_session")
mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, nil)
mockSessionStore.On("Save", mock.Anything, mock.Anything, mock.MatchedBy(func(s *sessions.Session) bool {
_, ok := s.Values["current_user"]
return !ok
})).Return(nil)
handler := httpserver.NewHandler(
config.Config{},
&mocks.Store{},
func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} },
&mockSessionStore,
nil,
zap.NewNop(),
)
req := httptest.NewRequest(http.MethodPost, "/logout", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
res := rec.Result()
defer res.Body.Close()
assert.Equal(t, http.StatusFound, res.StatusCode)
assert.Equal(t, "/", res.Header.Get("Location"))
}

View File

@ -5,15 +5,22 @@ import (
"time" "time"
"github.com/go-chi/chi/middleware" "github.com/go-chi/chi/middleware"
"github.com/gorilla/sessions"
"go.uber.org/zap" "go.uber.org/zap"
) )
func loggerMiddleware(l *zap.Logger) func(next http.Handler) http.Handler { func loggerMiddleware(l *zap.Logger, sessionStore sessions.Store) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) { fn := func(w http.ResponseWriter, r *http.Request) {
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
t1 := time.Now() t1 := time.Now()
cu, ok := getCurrentUser(r, sessionStore)
var userID string
if ok {
userID = cu.ID.String()
}
defer func() { defer func() {
l.Info("HTTP", l.Info("HTTP",
zap.String("proto", r.Proto), zap.String("proto", r.Proto),
@ -22,7 +29,10 @@ func loggerMiddleware(l *zap.Logger) func(next http.Handler) http.Handler {
zap.Duration("dur", time.Since(t1)), zap.Duration("dur", time.Since(t1)),
zap.Int("status", ww.Status()), zap.Int("status", ww.Status()),
zap.Int("size", ww.BytesWritten()), zap.Int("size", ww.BytesWritten()),
zap.String("reqId", middleware.GetReqID(r.Context()))) zap.String("ip", r.RemoteAddr),
zap.String("reqId", middleware.GetReqID(r.Context())),
zap.String("userID", userID),
zap.String("username", cu.TwitterUsername))
}() }()
next.ServeHTTP(ww, r) next.ServeHTTP(ww, r)

View File

@ -3,18 +3,36 @@ package templates
import ( import (
"embed" "embed"
"html/template" "html/template"
"io" "net/http"
"go.uber.org/zap"
) )
//go:embed views/*.html //go:embed views/*.html
var tmplFS embed.FS var tmplFS embed.FS
var templatesMap *template.Template var (
Index = parse("views/index.html")
Dashboard = parse("views/dashboard.html")
)
func init() { func parse(file string) *template.Template {
templatesMap = template.Must(template.ParseFS(tmplFS, "views/*.html")) return template.Must(template.New("base.html").ParseFS(tmplFS, "views/base.html", file))
} }
func Execute(w io.Writer, name string, data any) error { // Renderer renders HTML templates, and logs any errors.
return templatesMap.ExecuteTemplate(w, name, data) type Renderer struct {
logger *zap.SugaredLogger
}
func NewRenderer(logger *zap.SugaredLogger) *Renderer {
return &Renderer{logger: logger}
}
// Render renders an HTML template to an http.ResponseWriter.
func (r *Renderer) Render(w http.ResponseWriter, template *template.Template, data any) {
if err := template.Execute(w, data); err != nil {
r.logger.With("err", err).Error("error rendering template")
http.Error(w, "error rendering template", http.StatusInternalServerError)
}
} }

View File

@ -1,9 +1,18 @@
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
<head> <head>
<title>Elon Eats My Tweets</title> <title>{{block "title" .}}Elon Eats My Tweets{{end}}</title>
</head> </head>
<body> <body>
{{block "content" .}}{{end}} <h1>Elon eats my tweets</h1>
{{if not .CurrentUser.IsNil}}
<p>
Logged in as @{{.CurrentUser.TwitterUsername}}
<form id="logout_form" action="/logout" method="post"></form>
<a href="/logout" onclick="document.getElementById('logout_form').submit(); return false;">Log out</a>
</p>
{{end}}
{{template "content" .}}
</body> </body>
</html> </html>

View File

@ -0,0 +1,7 @@
{{template "base.html" .}}
{{define "title"}}Elon Eats My Tweets | Dashboard{{end}}
{{define "content"}}
Welcome to your dashboard.
{{end}}

View File

@ -1,8 +1,6 @@
{{template "base.html" .}} {{template "base.html" .}}
{{define "content"}} {{define "content"}}
<h1>Elon eats my tweets<h1>
<p>Sick of Elon? Tired of Twitter? Have Elon Musk delete your tweets for you.</p> <p>Sick of Elon? Tired of Twitter? Have Elon Musk delete your tweets for you.</p>
<a href="/login">Sign in with Twitter</a> <a href="/login">Sign in with Twitter</a>