diff --git a/httpserver/handler.go b/httpserver/handler.go index 1803b24..aa33d92 100644 --- a/httpserver/handler.go +++ b/httpserver/handler.go @@ -5,6 +5,8 @@ package httpserver import ( "context" + "encoding/gob" + "html/template" "net/http" "time" @@ -14,11 +16,27 @@ import ( "git.netflux.io/rob/elon-eats-my-tweets/twitter" "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" + "github.com/google/uuid" "github.com/gorilla/sessions" "go.uber.org/zap" "golang.org/x/oauth2" ) +func init() { + gob.Register(CurrentUser{}) +} + +type CurrentUser struct { + ID uuid.UUID + TwitterName, TwitterUsername string +} + +func (cu CurrentUser) IsNil() bool { + return cu.ID == uuid.Nil +} + +type contextKey string + const ( sessionName = "elon_session" sessionKeyState = "state" @@ -27,6 +45,19 @@ const ( pkceVerifierLen = 256 ) +const ( + sessionKeyCurrentUser = "current_user" + contextKeyCurrentUser contextKey = sessionKeyCurrentUser +) + +func getCurrentUser(r *http.Request, sessionStore sessions.Store) (CurrentUser, bool) { + session, _ := sessionStore.Get(r, sessionName) + if cu, ok := session.Values[sessionKeyCurrentUser]; ok { + return cu.(CurrentUser), true + } + return CurrentUser{}, false +} + type TwitterAPIClient interface { GetMe() (*twitter.User, error) } @@ -35,6 +66,7 @@ type handler struct { store twitter.Store twitterAPIClientFunc func(c *http.Client) TwitterAPIClient oauth2Config *oauth2.Config + renderer *templates.Renderer sessionStore sessions.Store tokenGenerator TokenGenerator logger *zap.SugaredLogger @@ -44,7 +76,7 @@ func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc fun r := chi.NewRouter() r.Use(middleware.RequestID) r.Use(middleware.RealIP) - r.Use(loggerMiddleware(logger)) + r.Use(loggerMiddleware(logger, sessionStore)) r.Use(middleware.Recoverer) h := handler{ @@ -61,6 +93,7 @@ func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc fun AuthStyle: oauth2.AuthStyleInHeader, }, }, + renderer: templates.NewRenderer(logger.Sugar()), sessionStore: sessionStore, tokenGenerator: tokenGenerator, logger: logger.Sugar(), @@ -68,17 +101,26 @@ func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc fun r.Get("/", h.getIndex) r.Get("/login", h.getLogin) + r.Post("/logout", h.postLogout) r.Get("/callback", h.getCallback) return r } func (h *handler) getIndex(w http.ResponseWriter, r *http.Request) { - if err := templates.Execute(w, "index.html", nil); err != nil { - h.logger.With("err", err).Error("error rendering template") - http.Error(w, "error rendering template", http.StatusInternalServerError) - return + cu, _ := getCurrentUser(r, h.sessionStore) + var template *template.Template + if cu.IsNil() { + template = templates.Index + } else { + template = templates.Dashboard } + + type data struct { + CurrentUser CurrentUser + } + + h.renderer.Render(w, template, data{CurrentUser: cu}) } func (h *handler) getLogin(w http.ResponseWriter, r *http.Request) { @@ -153,7 +195,7 @@ func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) { return } - if _, err := h.store.CreateUser(r.Context(), store.CreateUserParams{ + user, err := h.store.CreateUser(r.Context(), store.CreateUserParams{ TwitterID: twitterUser.ID, Username: twitterUser.Username, Name: twitterUser.Name, @@ -161,12 +203,44 @@ func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) { RefreshToken: token.RefreshToken, CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(), - }); err != nil { + }) + if err != nil { h.logger.With("err", err).Error("error upserting user") http.Error(w, "error saving user", http.StatusInternalServerError) return } - w.WriteHeader(http.StatusOK) - w.Write([]byte("ok")) + // finally, save in the session. To avoid hitting the database on most + // requests we'll store the Twitter name and username in the session. + session.Values[sessionKeyCurrentUser] = CurrentUser{ + ID: user.ID, + TwitterName: twitterUser.Name, + TwitterUsername: twitterUser.Username, + } + if err := session.Save(r, w); err != nil { + h.logger.With("err", err).Error("error saving session") + http.Error(w, "error saving user", http.StatusInternalServerError) + return + } + + http.Redirect(w, r, "/", http.StatusFound) +} + +func (h *handler) postLogout(w http.ResponseWriter, r *http.Request) { + session, err := h.sessionStore.Get(r, sessionName) + if err != nil { + h.logger.With("err", err).Error("error reading session") + http.Error(w, "error reading session", http.StatusBadRequest) + return + } + + delete(session.Values, sessionKeyCurrentUser) + + if err := session.Save(r, w); err != nil { + h.logger.With("err", err).Error("error saving session") + http.Error(w, "error saving user", http.StatusInternalServerError) + return + } + + http.Redirect(w, r, "/", http.StatusFound) } diff --git a/httpserver/handler_test.go b/httpserver/handler_test.go index fb59b9d..3335b55 100644 --- a/httpserver/handler_test.go +++ b/httpserver/handler_test.go @@ -12,6 +12,7 @@ import ( "git.netflux.io/rob/elon-eats-my-tweets/generated/store" "git.netflux.io/rob/elon-eats-my-tweets/httpserver" "git.netflux.io/rob/elon-eats-my-tweets/twitter" + "github.com/google/uuid" "github.com/gorilla/sessions" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -35,29 +36,56 @@ func (g *mockTokenGenerator) GenerateToken(_ int) string { } func TestGetIndex(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() + testCases := []struct { + name string + currentUser httpserver.CurrentUser + wantBody string + }{ + { + name: "logged out", + currentUser: httpserver.CurrentUser{}, + wantBody: "Sign in with Twitter", + }, + { + name: "logged out", + currentUser: httpserver.CurrentUser{ID: uuid.New(), TwitterName: "foo", TwitterUsername: "Foo Bar"}, + wantBody: "Welcome to your dashboard", + }, + } - handler := httpserver.NewHandler( - config.Config{}, - &mocks.Store{}, - func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} }, - &mocks.SessionStore{}, - &mockTokenGenerator{}, - zap.NewNop(), - ) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() - handler.ServeHTTP(rec, req) - res := rec.Result() - defer res.Body.Close() - body, err := ioutil.ReadAll(res.Body) - require.NoError(t, err) + var mockSessionStore mocks.SessionStore + sess := sessions.NewSession(&mockSessionStore, "elon_session") + sess.Values["current_user"] = tc.currentUser + mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, nil) + mockSessionStore.On("Save", mock.Anything, mock.Anything, sess).Return(nil) - assert.Equal(t, http.StatusOK, res.StatusCode) - assert.Contains(t, string(body), "Sign in with Twitter") + handler := httpserver.NewHandler( + config.Config{}, + &mocks.Store{}, + func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} }, + &mockSessionStore, + &mockTokenGenerator{}, + zap.NewNop(), + ) + + handler.ServeHTTP(rec, req) + res := rec.Result() + defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Contains(t, string(body), tc.wantBody) + }) + } } -func TestLogin(t *testing.T) { +func TestGetLogin(t *testing.T) { testCases := []struct { name string sessionSaveError error @@ -122,7 +150,7 @@ func TestLogin(t *testing.T) { } } -func TestCallback(t *testing.T) { +func TestGetCallback(t *testing.T) { testCases := []struct { name string state string @@ -133,7 +161,9 @@ func TestCallback(t *testing.T) { oauth2StatusCode int getTwitterUserError error createUserError error + sessionSaveError error wantStatusCode int + wantLocation string wantError string }{ { @@ -205,6 +235,17 @@ func TestCallback(t *testing.T) { wantStatusCode: http.StatusInternalServerError, wantError: "error saving user", }, + { + name: "error saving session", + state: "mystate", + sessionState: "mystate", + code: "mycode", + sessionPkceVerifier: "mypkceverifier", + oauth2StatusCode: http.StatusOK, + sessionSaveError: errors.New("foo"), + wantStatusCode: http.StatusInternalServerError, + wantError: "error saving user", + }, { name: "successful exchange", state: "mystate", @@ -212,7 +253,8 @@ func TestCallback(t *testing.T) { code: "mycode", sessionPkceVerifier: "mypkceverifier", oauth2StatusCode: http.StatusOK, - wantStatusCode: http.StatusOK, + wantStatusCode: http.StatusFound, + wantLocation: "/", }, } @@ -223,6 +265,7 @@ func TestCallback(t *testing.T) { sess.Values["state"] = tc.sessionState sess.Values["pkce_verifier"] = tc.sessionPkceVerifier mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, tc.sessionReadError) + mockSessionStore.On("Save", mock.Anything, mock.Anything, sess).Return(tc.sessionSaveError) var mockTwitterClient mocks.TwitterAPIClient mockTwitterClient.On("GetMe").Return(&twitter.User{ID: "1", Name: "foo", Username: "Foo Bar"}, tc.getTwitterUserError) @@ -283,3 +326,31 @@ func TestCallback(t *testing.T) { }) } } + +func TestPostLogout(t *testing.T) { + var mockSessionStore mocks.SessionStore + sess := sessions.NewSession(&mockSessionStore, "elon_session") + mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, nil) + mockSessionStore.On("Save", mock.Anything, mock.Anything, mock.MatchedBy(func(s *sessions.Session) bool { + _, ok := s.Values["current_user"] + return !ok + })).Return(nil) + + handler := httpserver.NewHandler( + config.Config{}, + &mocks.Store{}, + func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} }, + &mockSessionStore, + nil, + zap.NewNop(), + ) + + req := httptest.NewRequest(http.MethodPost, "/logout", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + res := rec.Result() + defer res.Body.Close() + + assert.Equal(t, http.StatusFound, res.StatusCode) + assert.Equal(t, "/", res.Header.Get("Location")) +} diff --git a/httpserver/logger.go b/httpserver/middleware.go similarity index 59% rename from httpserver/logger.go rename to httpserver/middleware.go index 4aed649..217cac0 100644 --- a/httpserver/logger.go +++ b/httpserver/middleware.go @@ -5,15 +5,22 @@ import ( "time" "github.com/go-chi/chi/middleware" + "github.com/gorilla/sessions" "go.uber.org/zap" ) -func loggerMiddleware(l *zap.Logger) func(next http.Handler) http.Handler { +func loggerMiddleware(l *zap.Logger, sessionStore sessions.Store) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) - t1 := time.Now() + + cu, ok := getCurrentUser(r, sessionStore) + var userID string + if ok { + userID = cu.ID.String() + } + defer func() { l.Info("HTTP", zap.String("proto", r.Proto), @@ -22,7 +29,10 @@ func loggerMiddleware(l *zap.Logger) func(next http.Handler) http.Handler { zap.Duration("dur", time.Since(t1)), zap.Int("status", ww.Status()), zap.Int("size", ww.BytesWritten()), - zap.String("reqId", middleware.GetReqID(r.Context()))) + zap.String("ip", r.RemoteAddr), + zap.String("reqId", middleware.GetReqID(r.Context())), + zap.String("userID", userID), + zap.String("username", cu.TwitterUsername)) }() next.ServeHTTP(ww, r) diff --git a/templates/templates.go b/templates/templates.go index 97be628..43da288 100644 --- a/templates/templates.go +++ b/templates/templates.go @@ -3,18 +3,36 @@ package templates import ( "embed" "html/template" - "io" + "net/http" + + "go.uber.org/zap" ) //go:embed views/*.html var tmplFS embed.FS -var templatesMap *template.Template +var ( + Index = parse("views/index.html") + Dashboard = parse("views/dashboard.html") +) -func init() { - templatesMap = template.Must(template.ParseFS(tmplFS, "views/*.html")) +func parse(file string) *template.Template { + return template.Must(template.New("base.html").ParseFS(tmplFS, "views/base.html", file)) } -func Execute(w io.Writer, name string, data any) error { - return templatesMap.ExecuteTemplate(w, name, data) +// Renderer renders HTML templates, and logs any errors. +type Renderer struct { + logger *zap.SugaredLogger +} + +func NewRenderer(logger *zap.SugaredLogger) *Renderer { + return &Renderer{logger: logger} +} + +// Render renders an HTML template to an http.ResponseWriter. +func (r *Renderer) Render(w http.ResponseWriter, template *template.Template, data any) { + if err := template.Execute(w, data); err != nil { + r.logger.With("err", err).Error("error rendering template") + http.Error(w, "error rendering template", http.StatusInternalServerError) + } } diff --git a/templates/views/base.html b/templates/views/base.html index dc019a5..c9ed717 100644 --- a/templates/views/base.html +++ b/templates/views/base.html @@ -1,9 +1,18 @@ - Elon Eats My Tweets + {{block "title" .}}Elon Eats My Tweets{{end}} - {{block "content" .}}{{end}} +

Elon eats my tweets

+ {{if not .CurrentUser.IsNil}} +

+ Logged in as @{{.CurrentUser.TwitterUsername}} +

+ Log out +

+ {{end}} + + {{template "content" .}} diff --git a/templates/views/dashboard.html b/templates/views/dashboard.html new file mode 100644 index 0000000..3d8c8f0 --- /dev/null +++ b/templates/views/dashboard.html @@ -0,0 +1,7 @@ +{{template "base.html" .}} + +{{define "title"}}Elon Eats My Tweets | Dashboard{{end}} + +{{define "content"}} +Welcome to your dashboard. +{{end}} diff --git a/templates/views/index.html b/templates/views/index.html index e247300..5965ab7 100644 --- a/templates/views/index.html +++ b/templates/views/index.html @@ -1,8 +1,6 @@ {{template "base.html" .}} {{define "content"}} -

Elon eats my tweets

-

Sick of Elon? Tired of Twitter? Have Elon Musk delete your tweets for you.

Sign in with Twitter