Implement login/logout flow
continuous-integration/drone/push Build is passing Details

This commit is contained in:
Rob Watson 2022-05-21 22:51:47 +02:00
parent 55ee291390
commit 1816d0be4b
7 changed files with 229 additions and 42 deletions

View File

@ -5,6 +5,8 @@ package httpserver
import (
"context"
"encoding/gob"
"html/template"
"net/http"
"time"
@ -14,11 +16,27 @@ import (
"git.netflux.io/rob/elon-eats-my-tweets/twitter"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/google/uuid"
"github.com/gorilla/sessions"
"go.uber.org/zap"
"golang.org/x/oauth2"
)
func init() {
gob.Register(CurrentUser{})
}
type CurrentUser struct {
ID uuid.UUID
TwitterName, TwitterUsername string
}
func (cu CurrentUser) IsNil() bool {
return cu.ID == uuid.Nil
}
type contextKey string
const (
sessionName = "elon_session"
sessionKeyState = "state"
@ -27,6 +45,19 @@ const (
pkceVerifierLen = 256
)
const (
sessionKeyCurrentUser = "current_user"
contextKeyCurrentUser contextKey = sessionKeyCurrentUser
)
func getCurrentUser(r *http.Request, sessionStore sessions.Store) (CurrentUser, bool) {
session, _ := sessionStore.Get(r, sessionName)
if cu, ok := session.Values[sessionKeyCurrentUser]; ok {
return cu.(CurrentUser), true
}
return CurrentUser{}, false
}
type TwitterAPIClient interface {
GetMe() (*twitter.User, error)
}
@ -35,6 +66,7 @@ type handler struct {
store twitter.Store
twitterAPIClientFunc func(c *http.Client) TwitterAPIClient
oauth2Config *oauth2.Config
renderer *templates.Renderer
sessionStore sessions.Store
tokenGenerator TokenGenerator
logger *zap.SugaredLogger
@ -44,7 +76,7 @@ func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc fun
r := chi.NewRouter()
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
r.Use(loggerMiddleware(logger))
r.Use(loggerMiddleware(logger, sessionStore))
r.Use(middleware.Recoverer)
h := handler{
@ -61,6 +93,7 @@ func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc fun
AuthStyle: oauth2.AuthStyleInHeader,
},
},
renderer: templates.NewRenderer(logger.Sugar()),
sessionStore: sessionStore,
tokenGenerator: tokenGenerator,
logger: logger.Sugar(),
@ -68,17 +101,26 @@ func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc fun
r.Get("/", h.getIndex)
r.Get("/login", h.getLogin)
r.Post("/logout", h.postLogout)
r.Get("/callback", h.getCallback)
return r
}
func (h *handler) getIndex(w http.ResponseWriter, r *http.Request) {
if err := templates.Execute(w, "index.html", nil); err != nil {
h.logger.With("err", err).Error("error rendering template")
http.Error(w, "error rendering template", http.StatusInternalServerError)
return
cu, _ := getCurrentUser(r, h.sessionStore)
var template *template.Template
if cu.IsNil() {
template = templates.Index
} else {
template = templates.Dashboard
}
type data struct {
CurrentUser CurrentUser
}
h.renderer.Render(w, template, data{CurrentUser: cu})
}
func (h *handler) getLogin(w http.ResponseWriter, r *http.Request) {
@ -153,7 +195,7 @@ func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) {
return
}
if _, err := h.store.CreateUser(r.Context(), store.CreateUserParams{
user, err := h.store.CreateUser(r.Context(), store.CreateUserParams{
TwitterID: twitterUser.ID,
Username: twitterUser.Username,
Name: twitterUser.Name,
@ -161,12 +203,44 @@ func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) {
RefreshToken: token.RefreshToken,
CreatedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}); err != nil {
})
if err != nil {
h.logger.With("err", err).Error("error upserting user")
http.Error(w, "error saving user", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
// finally, save in the session. To avoid hitting the database on most
// requests we'll store the Twitter name and username in the session.
session.Values[sessionKeyCurrentUser] = CurrentUser{
ID: user.ID,
TwitterName: twitterUser.Name,
TwitterUsername: twitterUser.Username,
}
if err := session.Save(r, w); err != nil {
h.logger.With("err", err).Error("error saving session")
http.Error(w, "error saving user", http.StatusInternalServerError)
return
}
http.Redirect(w, r, "/", http.StatusFound)
}
func (h *handler) postLogout(w http.ResponseWriter, r *http.Request) {
session, err := h.sessionStore.Get(r, sessionName)
if err != nil {
h.logger.With("err", err).Error("error reading session")
http.Error(w, "error reading session", http.StatusBadRequest)
return
}
delete(session.Values, sessionKeyCurrentUser)
if err := session.Save(r, w); err != nil {
h.logger.With("err", err).Error("error saving session")
http.Error(w, "error saving user", http.StatusInternalServerError)
return
}
http.Redirect(w, r, "/", http.StatusFound)
}

View File

@ -12,6 +12,7 @@ import (
"git.netflux.io/rob/elon-eats-my-tweets/generated/store"
"git.netflux.io/rob/elon-eats-my-tweets/httpserver"
"git.netflux.io/rob/elon-eats-my-tweets/twitter"
"github.com/google/uuid"
"github.com/gorilla/sessions"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
@ -35,29 +36,56 @@ func (g *mockTokenGenerator) GenerateToken(_ int) string {
}
func TestGetIndex(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
testCases := []struct {
name string
currentUser httpserver.CurrentUser
wantBody string
}{
{
name: "logged out",
currentUser: httpserver.CurrentUser{},
wantBody: "Sign in with Twitter",
},
{
name: "logged out",
currentUser: httpserver.CurrentUser{ID: uuid.New(), TwitterName: "foo", TwitterUsername: "Foo Bar"},
wantBody: "Welcome to your dashboard",
},
}
handler := httpserver.NewHandler(
config.Config{},
&mocks.Store{},
func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} },
&mocks.SessionStore{},
&mockTokenGenerator{},
zap.NewNop(),
)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
res := rec.Result()
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
var mockSessionStore mocks.SessionStore
sess := sessions.NewSession(&mockSessionStore, "elon_session")
sess.Values["current_user"] = tc.currentUser
mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, nil)
mockSessionStore.On("Save", mock.Anything, mock.Anything, sess).Return(nil)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Contains(t, string(body), "Sign in with Twitter")
handler := httpserver.NewHandler(
config.Config{},
&mocks.Store{},
func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} },
&mockSessionStore,
&mockTokenGenerator{},
zap.NewNop(),
)
handler.ServeHTTP(rec, req)
res := rec.Result()
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Contains(t, string(body), tc.wantBody)
})
}
}
func TestLogin(t *testing.T) {
func TestGetLogin(t *testing.T) {
testCases := []struct {
name string
sessionSaveError error
@ -122,7 +150,7 @@ func TestLogin(t *testing.T) {
}
}
func TestCallback(t *testing.T) {
func TestGetCallback(t *testing.T) {
testCases := []struct {
name string
state string
@ -133,7 +161,9 @@ func TestCallback(t *testing.T) {
oauth2StatusCode int
getTwitterUserError error
createUserError error
sessionSaveError error
wantStatusCode int
wantLocation string
wantError string
}{
{
@ -205,6 +235,17 @@ func TestCallback(t *testing.T) {
wantStatusCode: http.StatusInternalServerError,
wantError: "error saving user",
},
{
name: "error saving session",
state: "mystate",
sessionState: "mystate",
code: "mycode",
sessionPkceVerifier: "mypkceverifier",
oauth2StatusCode: http.StatusOK,
sessionSaveError: errors.New("foo"),
wantStatusCode: http.StatusInternalServerError,
wantError: "error saving user",
},
{
name: "successful exchange",
state: "mystate",
@ -212,7 +253,8 @@ func TestCallback(t *testing.T) {
code: "mycode",
sessionPkceVerifier: "mypkceverifier",
oauth2StatusCode: http.StatusOK,
wantStatusCode: http.StatusOK,
wantStatusCode: http.StatusFound,
wantLocation: "/",
},
}
@ -223,6 +265,7 @@ func TestCallback(t *testing.T) {
sess.Values["state"] = tc.sessionState
sess.Values["pkce_verifier"] = tc.sessionPkceVerifier
mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, tc.sessionReadError)
mockSessionStore.On("Save", mock.Anything, mock.Anything, sess).Return(tc.sessionSaveError)
var mockTwitterClient mocks.TwitterAPIClient
mockTwitterClient.On("GetMe").Return(&twitter.User{ID: "1", Name: "foo", Username: "Foo Bar"}, tc.getTwitterUserError)
@ -283,3 +326,31 @@ func TestCallback(t *testing.T) {
})
}
}
func TestPostLogout(t *testing.T) {
var mockSessionStore mocks.SessionStore
sess := sessions.NewSession(&mockSessionStore, "elon_session")
mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, nil)
mockSessionStore.On("Save", mock.Anything, mock.Anything, mock.MatchedBy(func(s *sessions.Session) bool {
_, ok := s.Values["current_user"]
return !ok
})).Return(nil)
handler := httpserver.NewHandler(
config.Config{},
&mocks.Store{},
func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} },
&mockSessionStore,
nil,
zap.NewNop(),
)
req := httptest.NewRequest(http.MethodPost, "/logout", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
res := rec.Result()
defer res.Body.Close()
assert.Equal(t, http.StatusFound, res.StatusCode)
assert.Equal(t, "/", res.Header.Get("Location"))
}

View File

@ -5,15 +5,22 @@ import (
"time"
"github.com/go-chi/chi/middleware"
"github.com/gorilla/sessions"
"go.uber.org/zap"
)
func loggerMiddleware(l *zap.Logger) func(next http.Handler) http.Handler {
func loggerMiddleware(l *zap.Logger, sessionStore sessions.Store) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
t1 := time.Now()
cu, ok := getCurrentUser(r, sessionStore)
var userID string
if ok {
userID = cu.ID.String()
}
defer func() {
l.Info("HTTP",
zap.String("proto", r.Proto),
@ -22,7 +29,10 @@ func loggerMiddleware(l *zap.Logger) func(next http.Handler) http.Handler {
zap.Duration("dur", time.Since(t1)),
zap.Int("status", ww.Status()),
zap.Int("size", ww.BytesWritten()),
zap.String("reqId", middleware.GetReqID(r.Context())))
zap.String("ip", r.RemoteAddr),
zap.String("reqId", middleware.GetReqID(r.Context())),
zap.String("userID", userID),
zap.String("username", cu.TwitterUsername))
}()
next.ServeHTTP(ww, r)

View File

@ -3,18 +3,36 @@ package templates
import (
"embed"
"html/template"
"io"
"net/http"
"go.uber.org/zap"
)
//go:embed views/*.html
var tmplFS embed.FS
var templatesMap *template.Template
var (
Index = parse("views/index.html")
Dashboard = parse("views/dashboard.html")
)
func init() {
templatesMap = template.Must(template.ParseFS(tmplFS, "views/*.html"))
func parse(file string) *template.Template {
return template.Must(template.New("base.html").ParseFS(tmplFS, "views/base.html", file))
}
func Execute(w io.Writer, name string, data any) error {
return templatesMap.ExecuteTemplate(w, name, data)
// Renderer renders HTML templates, and logs any errors.
type Renderer struct {
logger *zap.SugaredLogger
}
func NewRenderer(logger *zap.SugaredLogger) *Renderer {
return &Renderer{logger: logger}
}
// Render renders an HTML template to an http.ResponseWriter.
func (r *Renderer) Render(w http.ResponseWriter, template *template.Template, data any) {
if err := template.Execute(w, data); err != nil {
r.logger.With("err", err).Error("error rendering template")
http.Error(w, "error rendering template", http.StatusInternalServerError)
}
}

View File

@ -1,9 +1,18 @@
<!DOCTYPE html>
<html>
<head>
<title>Elon Eats My Tweets</title>
<title>{{block "title" .}}Elon Eats My Tweets{{end}}</title>
</head>
<body>
{{block "content" .}}{{end}}
<h1>Elon eats my tweets</h1>
{{if not .CurrentUser.IsNil}}
<p>
Logged in as @{{.CurrentUser.TwitterUsername}}
<form id="logout_form" action="/logout" method="post"></form>
<a href="/logout" onclick="document.getElementById('logout_form').submit(); return false;">Log out</a>
</p>
{{end}}
{{template "content" .}}
</body>
</html>

View File

@ -0,0 +1,7 @@
{{template "base.html" .}}
{{define "title"}}Elon Eats My Tweets | Dashboard{{end}}
{{define "content"}}
Welcome to your dashboard.
{{end}}

View File

@ -1,8 +1,6 @@
{{template "base.html" .}}
{{define "content"}}
<h1>Elon eats my tweets<h1>
<p>Sick of Elon? Tired of Twitter? Have Elon Musk delete your tweets for you.</p>
<a href="/login">Sign in with Twitter</a>