From 1816d0be4b273f995c1c10fca14fab80089f0586 Mon Sep 17 00:00:00 2001
From: Rob Watson
Date: Sat, 21 May 2022 22:51:47 +0200
Subject: [PATCH] Implement login/logout flow
---
httpserver/handler.go | 92 ++++++++++++++++++--
httpserver/handler_test.go | 111 +++++++++++++++++++-----
httpserver/{logger.go => middleware.go} | 16 +++-
templates/templates.go | 30 +++++--
templates/views/base.html | 13 ++-
templates/views/dashboard.html | 7 ++
templates/views/index.html | 2 -
7 files changed, 229 insertions(+), 42 deletions(-)
rename httpserver/{logger.go => middleware.go} (59%)
create mode 100644 templates/views/dashboard.html
diff --git a/httpserver/handler.go b/httpserver/handler.go
index 1803b24..aa33d92 100644
--- a/httpserver/handler.go
+++ b/httpserver/handler.go
@@ -5,6 +5,8 @@ package httpserver
import (
"context"
+ "encoding/gob"
+ "html/template"
"net/http"
"time"
@@ -14,11 +16,27 @@ import (
"git.netflux.io/rob/elon-eats-my-tweets/twitter"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
+ "github.com/google/uuid"
"github.com/gorilla/sessions"
"go.uber.org/zap"
"golang.org/x/oauth2"
)
+func init() {
+ gob.Register(CurrentUser{})
+}
+
+type CurrentUser struct {
+ ID uuid.UUID
+ TwitterName, TwitterUsername string
+}
+
+func (cu CurrentUser) IsNil() bool {
+ return cu.ID == uuid.Nil
+}
+
+type contextKey string
+
const (
sessionName = "elon_session"
sessionKeyState = "state"
@@ -27,6 +45,19 @@ const (
pkceVerifierLen = 256
)
+const (
+ sessionKeyCurrentUser = "current_user"
+ contextKeyCurrentUser contextKey = sessionKeyCurrentUser
+)
+
+func getCurrentUser(r *http.Request, sessionStore sessions.Store) (CurrentUser, bool) {
+ session, _ := sessionStore.Get(r, sessionName)
+ if cu, ok := session.Values[sessionKeyCurrentUser]; ok {
+ return cu.(CurrentUser), true
+ }
+ return CurrentUser{}, false
+}
+
type TwitterAPIClient interface {
GetMe() (*twitter.User, error)
}
@@ -35,6 +66,7 @@ type handler struct {
store twitter.Store
twitterAPIClientFunc func(c *http.Client) TwitterAPIClient
oauth2Config *oauth2.Config
+ renderer *templates.Renderer
sessionStore sessions.Store
tokenGenerator TokenGenerator
logger *zap.SugaredLogger
@@ -44,7 +76,7 @@ func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc fun
r := chi.NewRouter()
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
- r.Use(loggerMiddleware(logger))
+ r.Use(loggerMiddleware(logger, sessionStore))
r.Use(middleware.Recoverer)
h := handler{
@@ -61,6 +93,7 @@ func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc fun
AuthStyle: oauth2.AuthStyleInHeader,
},
},
+ renderer: templates.NewRenderer(logger.Sugar()),
sessionStore: sessionStore,
tokenGenerator: tokenGenerator,
logger: logger.Sugar(),
@@ -68,17 +101,26 @@ func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc fun
r.Get("/", h.getIndex)
r.Get("/login", h.getLogin)
+ r.Post("/logout", h.postLogout)
r.Get("/callback", h.getCallback)
return r
}
func (h *handler) getIndex(w http.ResponseWriter, r *http.Request) {
- if err := templates.Execute(w, "index.html", nil); err != nil {
- h.logger.With("err", err).Error("error rendering template")
- http.Error(w, "error rendering template", http.StatusInternalServerError)
- return
+ cu, _ := getCurrentUser(r, h.sessionStore)
+ var template *template.Template
+ if cu.IsNil() {
+ template = templates.Index
+ } else {
+ template = templates.Dashboard
}
+
+ type data struct {
+ CurrentUser CurrentUser
+ }
+
+ h.renderer.Render(w, template, data{CurrentUser: cu})
}
func (h *handler) getLogin(w http.ResponseWriter, r *http.Request) {
@@ -153,7 +195,7 @@ func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) {
return
}
- if _, err := h.store.CreateUser(r.Context(), store.CreateUserParams{
+ user, err := h.store.CreateUser(r.Context(), store.CreateUserParams{
TwitterID: twitterUser.ID,
Username: twitterUser.Username,
Name: twitterUser.Name,
@@ -161,12 +203,44 @@ func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) {
RefreshToken: token.RefreshToken,
CreatedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
- }); err != nil {
+ })
+ if err != nil {
h.logger.With("err", err).Error("error upserting user")
http.Error(w, "error saving user", http.StatusInternalServerError)
return
}
- w.WriteHeader(http.StatusOK)
- w.Write([]byte("ok"))
+ // finally, save in the session. To avoid hitting the database on most
+ // requests we'll store the Twitter name and username in the session.
+ session.Values[sessionKeyCurrentUser] = CurrentUser{
+ ID: user.ID,
+ TwitterName: twitterUser.Name,
+ TwitterUsername: twitterUser.Username,
+ }
+ if err := session.Save(r, w); err != nil {
+ h.logger.With("err", err).Error("error saving session")
+ http.Error(w, "error saving user", http.StatusInternalServerError)
+ return
+ }
+
+ http.Redirect(w, r, "/", http.StatusFound)
+}
+
+func (h *handler) postLogout(w http.ResponseWriter, r *http.Request) {
+ session, err := h.sessionStore.Get(r, sessionName)
+ if err != nil {
+ h.logger.With("err", err).Error("error reading session")
+ http.Error(w, "error reading session", http.StatusBadRequest)
+ return
+ }
+
+ delete(session.Values, sessionKeyCurrentUser)
+
+ if err := session.Save(r, w); err != nil {
+ h.logger.With("err", err).Error("error saving session")
+ http.Error(w, "error saving user", http.StatusInternalServerError)
+ return
+ }
+
+ http.Redirect(w, r, "/", http.StatusFound)
}
diff --git a/httpserver/handler_test.go b/httpserver/handler_test.go
index fb59b9d..3335b55 100644
--- a/httpserver/handler_test.go
+++ b/httpserver/handler_test.go
@@ -12,6 +12,7 @@ import (
"git.netflux.io/rob/elon-eats-my-tweets/generated/store"
"git.netflux.io/rob/elon-eats-my-tweets/httpserver"
"git.netflux.io/rob/elon-eats-my-tweets/twitter"
+ "github.com/google/uuid"
"github.com/gorilla/sessions"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
@@ -35,29 +36,56 @@ func (g *mockTokenGenerator) GenerateToken(_ int) string {
}
func TestGetIndex(t *testing.T) {
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
+ testCases := []struct {
+ name string
+ currentUser httpserver.CurrentUser
+ wantBody string
+ }{
+ {
+ name: "logged out",
+ currentUser: httpserver.CurrentUser{},
+ wantBody: "Sign in with Twitter",
+ },
+ {
+ name: "logged out",
+ currentUser: httpserver.CurrentUser{ID: uuid.New(), TwitterName: "foo", TwitterUsername: "Foo Bar"},
+ wantBody: "Welcome to your dashboard",
+ },
+ }
- handler := httpserver.NewHandler(
- config.Config{},
- &mocks.Store{},
- func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} },
- &mocks.SessionStore{},
- &mockTokenGenerator{},
- zap.NewNop(),
- )
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
- handler.ServeHTTP(rec, req)
- res := rec.Result()
- defer res.Body.Close()
- body, err := ioutil.ReadAll(res.Body)
- require.NoError(t, err)
+ var mockSessionStore mocks.SessionStore
+ sess := sessions.NewSession(&mockSessionStore, "elon_session")
+ sess.Values["current_user"] = tc.currentUser
+ mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, nil)
+ mockSessionStore.On("Save", mock.Anything, mock.Anything, sess).Return(nil)
- assert.Equal(t, http.StatusOK, res.StatusCode)
- assert.Contains(t, string(body), "Sign in with Twitter")
+ handler := httpserver.NewHandler(
+ config.Config{},
+ &mocks.Store{},
+ func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} },
+ &mockSessionStore,
+ &mockTokenGenerator{},
+ zap.NewNop(),
+ )
+
+ handler.ServeHTTP(rec, req)
+ res := rec.Result()
+ defer res.Body.Close()
+ body, err := ioutil.ReadAll(res.Body)
+ require.NoError(t, err)
+
+ assert.Equal(t, http.StatusOK, res.StatusCode)
+ assert.Contains(t, string(body), tc.wantBody)
+ })
+ }
}
-func TestLogin(t *testing.T) {
+func TestGetLogin(t *testing.T) {
testCases := []struct {
name string
sessionSaveError error
@@ -122,7 +150,7 @@ func TestLogin(t *testing.T) {
}
}
-func TestCallback(t *testing.T) {
+func TestGetCallback(t *testing.T) {
testCases := []struct {
name string
state string
@@ -133,7 +161,9 @@ func TestCallback(t *testing.T) {
oauth2StatusCode int
getTwitterUserError error
createUserError error
+ sessionSaveError error
wantStatusCode int
+ wantLocation string
wantError string
}{
{
@@ -205,6 +235,17 @@ func TestCallback(t *testing.T) {
wantStatusCode: http.StatusInternalServerError,
wantError: "error saving user",
},
+ {
+ name: "error saving session",
+ state: "mystate",
+ sessionState: "mystate",
+ code: "mycode",
+ sessionPkceVerifier: "mypkceverifier",
+ oauth2StatusCode: http.StatusOK,
+ sessionSaveError: errors.New("foo"),
+ wantStatusCode: http.StatusInternalServerError,
+ wantError: "error saving user",
+ },
{
name: "successful exchange",
state: "mystate",
@@ -212,7 +253,8 @@ func TestCallback(t *testing.T) {
code: "mycode",
sessionPkceVerifier: "mypkceverifier",
oauth2StatusCode: http.StatusOK,
- wantStatusCode: http.StatusOK,
+ wantStatusCode: http.StatusFound,
+ wantLocation: "/",
},
}
@@ -223,6 +265,7 @@ func TestCallback(t *testing.T) {
sess.Values["state"] = tc.sessionState
sess.Values["pkce_verifier"] = tc.sessionPkceVerifier
mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, tc.sessionReadError)
+ mockSessionStore.On("Save", mock.Anything, mock.Anything, sess).Return(tc.sessionSaveError)
var mockTwitterClient mocks.TwitterAPIClient
mockTwitterClient.On("GetMe").Return(&twitter.User{ID: "1", Name: "foo", Username: "Foo Bar"}, tc.getTwitterUserError)
@@ -283,3 +326,31 @@ func TestCallback(t *testing.T) {
})
}
}
+
+func TestPostLogout(t *testing.T) {
+ var mockSessionStore mocks.SessionStore
+ sess := sessions.NewSession(&mockSessionStore, "elon_session")
+ mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, nil)
+ mockSessionStore.On("Save", mock.Anything, mock.Anything, mock.MatchedBy(func(s *sessions.Session) bool {
+ _, ok := s.Values["current_user"]
+ return !ok
+ })).Return(nil)
+
+ handler := httpserver.NewHandler(
+ config.Config{},
+ &mocks.Store{},
+ func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} },
+ &mockSessionStore,
+ nil,
+ zap.NewNop(),
+ )
+
+ req := httptest.NewRequest(http.MethodPost, "/logout", nil)
+ rec := httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+ res := rec.Result()
+ defer res.Body.Close()
+
+ assert.Equal(t, http.StatusFound, res.StatusCode)
+ assert.Equal(t, "/", res.Header.Get("Location"))
+}
diff --git a/httpserver/logger.go b/httpserver/middleware.go
similarity index 59%
rename from httpserver/logger.go
rename to httpserver/middleware.go
index 4aed649..217cac0 100644
--- a/httpserver/logger.go
+++ b/httpserver/middleware.go
@@ -5,15 +5,22 @@ import (
"time"
"github.com/go-chi/chi/middleware"
+ "github.com/gorilla/sessions"
"go.uber.org/zap"
)
-func loggerMiddleware(l *zap.Logger) func(next http.Handler) http.Handler {
+func loggerMiddleware(l *zap.Logger, sessionStore sessions.Store) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
-
t1 := time.Now()
+
+ cu, ok := getCurrentUser(r, sessionStore)
+ var userID string
+ if ok {
+ userID = cu.ID.String()
+ }
+
defer func() {
l.Info("HTTP",
zap.String("proto", r.Proto),
@@ -22,7 +29,10 @@ func loggerMiddleware(l *zap.Logger) func(next http.Handler) http.Handler {
zap.Duration("dur", time.Since(t1)),
zap.Int("status", ww.Status()),
zap.Int("size", ww.BytesWritten()),
- zap.String("reqId", middleware.GetReqID(r.Context())))
+ zap.String("ip", r.RemoteAddr),
+ zap.String("reqId", middleware.GetReqID(r.Context())),
+ zap.String("userID", userID),
+ zap.String("username", cu.TwitterUsername))
}()
next.ServeHTTP(ww, r)
diff --git a/templates/templates.go b/templates/templates.go
index 97be628..43da288 100644
--- a/templates/templates.go
+++ b/templates/templates.go
@@ -3,18 +3,36 @@ package templates
import (
"embed"
"html/template"
- "io"
+ "net/http"
+
+ "go.uber.org/zap"
)
//go:embed views/*.html
var tmplFS embed.FS
-var templatesMap *template.Template
+var (
+ Index = parse("views/index.html")
+ Dashboard = parse("views/dashboard.html")
+)
-func init() {
- templatesMap = template.Must(template.ParseFS(tmplFS, "views/*.html"))
+func parse(file string) *template.Template {
+ return template.Must(template.New("base.html").ParseFS(tmplFS, "views/base.html", file))
}
-func Execute(w io.Writer, name string, data any) error {
- return templatesMap.ExecuteTemplate(w, name, data)
+// Renderer renders HTML templates, and logs any errors.
+type Renderer struct {
+ logger *zap.SugaredLogger
+}
+
+func NewRenderer(logger *zap.SugaredLogger) *Renderer {
+ return &Renderer{logger: logger}
+}
+
+// Render renders an HTML template to an http.ResponseWriter.
+func (r *Renderer) Render(w http.ResponseWriter, template *template.Template, data any) {
+ if err := template.Execute(w, data); err != nil {
+ r.logger.With("err", err).Error("error rendering template")
+ http.Error(w, "error rendering template", http.StatusInternalServerError)
+ }
}
diff --git a/templates/views/base.html b/templates/views/base.html
index dc019a5..c9ed717 100644
--- a/templates/views/base.html
+++ b/templates/views/base.html
@@ -1,9 +1,18 @@
- Elon Eats My Tweets
+ {{block "title" .}}Elon Eats My Tweets{{end}}
- {{block "content" .}}{{end}}
+ Elon eats my tweets
+ {{if not .CurrentUser.IsNil}}
+
+ Logged in as @{{.CurrentUser.TwitterUsername}}
+
+ Log out
+
+ {{end}}
+
+ {{template "content" .}}
diff --git a/templates/views/dashboard.html b/templates/views/dashboard.html
new file mode 100644
index 0000000..3d8c8f0
--- /dev/null
+++ b/templates/views/dashboard.html
@@ -0,0 +1,7 @@
+{{template "base.html" .}}
+
+{{define "title"}}Elon Eats My Tweets | Dashboard{{end}}
+
+{{define "content"}}
+Welcome to your dashboard.
+{{end}}
diff --git a/templates/views/index.html b/templates/views/index.html
index e247300..5965ab7 100644
--- a/templates/views/index.html
+++ b/templates/views/index.html
@@ -1,8 +1,6 @@
{{template "base.html" .}}
{{define "content"}}
-Elon eats my tweets
-
Sick of Elon? Tired of Twitter? Have Elon Musk delete your tweets for you.
Sign in with Twitter