Implement login/logout flow
continuous-integration/drone/push Build is passing
Details
continuous-integration/drone/push Build is passing
Details
This commit is contained in:
parent
55ee291390
commit
1816d0be4b
|
@ -5,6 +5,8 @@ package httpserver
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
@ -14,11 +16,27 @@ import (
|
|||
"git.netflux.io/rob/elon-eats-my-tweets/twitter"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gob.Register(CurrentUser{})
|
||||
}
|
||||
|
||||
type CurrentUser struct {
|
||||
ID uuid.UUID
|
||||
TwitterName, TwitterUsername string
|
||||
}
|
||||
|
||||
func (cu CurrentUser) IsNil() bool {
|
||||
return cu.ID == uuid.Nil
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
sessionName = "elon_session"
|
||||
sessionKeyState = "state"
|
||||
|
@ -27,6 +45,19 @@ const (
|
|||
pkceVerifierLen = 256
|
||||
)
|
||||
|
||||
const (
|
||||
sessionKeyCurrentUser = "current_user"
|
||||
contextKeyCurrentUser contextKey = sessionKeyCurrentUser
|
||||
)
|
||||
|
||||
func getCurrentUser(r *http.Request, sessionStore sessions.Store) (CurrentUser, bool) {
|
||||
session, _ := sessionStore.Get(r, sessionName)
|
||||
if cu, ok := session.Values[sessionKeyCurrentUser]; ok {
|
||||
return cu.(CurrentUser), true
|
||||
}
|
||||
return CurrentUser{}, false
|
||||
}
|
||||
|
||||
type TwitterAPIClient interface {
|
||||
GetMe() (*twitter.User, error)
|
||||
}
|
||||
|
@ -35,6 +66,7 @@ type handler struct {
|
|||
store twitter.Store
|
||||
twitterAPIClientFunc func(c *http.Client) TwitterAPIClient
|
||||
oauth2Config *oauth2.Config
|
||||
renderer *templates.Renderer
|
||||
sessionStore sessions.Store
|
||||
tokenGenerator TokenGenerator
|
||||
logger *zap.SugaredLogger
|
||||
|
@ -44,7 +76,7 @@ func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc fun
|
|||
r := chi.NewRouter()
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.RealIP)
|
||||
r.Use(loggerMiddleware(logger))
|
||||
r.Use(loggerMiddleware(logger, sessionStore))
|
||||
r.Use(middleware.Recoverer)
|
||||
|
||||
h := handler{
|
||||
|
@ -61,6 +93,7 @@ func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc fun
|
|||
AuthStyle: oauth2.AuthStyleInHeader,
|
||||
},
|
||||
},
|
||||
renderer: templates.NewRenderer(logger.Sugar()),
|
||||
sessionStore: sessionStore,
|
||||
tokenGenerator: tokenGenerator,
|
||||
logger: logger.Sugar(),
|
||||
|
@ -68,17 +101,26 @@ func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc fun
|
|||
|
||||
r.Get("/", h.getIndex)
|
||||
r.Get("/login", h.getLogin)
|
||||
r.Post("/logout", h.postLogout)
|
||||
r.Get("/callback", h.getCallback)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (h *handler) getIndex(w http.ResponseWriter, r *http.Request) {
|
||||
if err := templates.Execute(w, "index.html", nil); err != nil {
|
||||
h.logger.With("err", err).Error("error rendering template")
|
||||
http.Error(w, "error rendering template", http.StatusInternalServerError)
|
||||
return
|
||||
cu, _ := getCurrentUser(r, h.sessionStore)
|
||||
var template *template.Template
|
||||
if cu.IsNil() {
|
||||
template = templates.Index
|
||||
} else {
|
||||
template = templates.Dashboard
|
||||
}
|
||||
|
||||
type data struct {
|
||||
CurrentUser CurrentUser
|
||||
}
|
||||
|
||||
h.renderer.Render(w, template, data{CurrentUser: cu})
|
||||
}
|
||||
|
||||
func (h *handler) getLogin(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -153,7 +195,7 @@ func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if _, err := h.store.CreateUser(r.Context(), store.CreateUserParams{
|
||||
user, err := h.store.CreateUser(r.Context(), store.CreateUserParams{
|
||||
TwitterID: twitterUser.ID,
|
||||
Username: twitterUser.Username,
|
||||
Name: twitterUser.Name,
|
||||
|
@ -161,12 +203,44 @@ func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) {
|
|||
RefreshToken: token.RefreshToken,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}); err != nil {
|
||||
})
|
||||
if err != nil {
|
||||
h.logger.With("err", err).Error("error upserting user")
|
||||
http.Error(w, "error saving user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("ok"))
|
||||
// finally, save in the session. To avoid hitting the database on most
|
||||
// requests we'll store the Twitter name and username in the session.
|
||||
session.Values[sessionKeyCurrentUser] = CurrentUser{
|
||||
ID: user.ID,
|
||||
TwitterName: twitterUser.Name,
|
||||
TwitterUsername: twitterUser.Username,
|
||||
}
|
||||
if err := session.Save(r, w); err != nil {
|
||||
h.logger.With("err", err).Error("error saving session")
|
||||
http.Error(w, "error saving user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
|
||||
func (h *handler) postLogout(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := h.sessionStore.Get(r, sessionName)
|
||||
if err != nil {
|
||||
h.logger.With("err", err).Error("error reading session")
|
||||
http.Error(w, "error reading session", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
delete(session.Values, sessionKeyCurrentUser)
|
||||
|
||||
if err := session.Save(r, w); err != nil {
|
||||
h.logger.With("err", err).Error("error saving session")
|
||||
http.Error(w, "error saving user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"git.netflux.io/rob/elon-eats-my-tweets/generated/store"
|
||||
"git.netflux.io/rob/elon-eats-my-tweets/httpserver"
|
||||
"git.netflux.io/rob/elon-eats-my-tweets/twitter"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
@ -35,29 +36,56 @@ func (g *mockTokenGenerator) GenerateToken(_ int) string {
|
|||
}
|
||||
|
||||
func TestGetIndex(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
testCases := []struct {
|
||||
name string
|
||||
currentUser httpserver.CurrentUser
|
||||
wantBody string
|
||||
}{
|
||||
{
|
||||
name: "logged out",
|
||||
currentUser: httpserver.CurrentUser{},
|
||||
wantBody: "Sign in with Twitter",
|
||||
},
|
||||
{
|
||||
name: "logged out",
|
||||
currentUser: httpserver.CurrentUser{ID: uuid.New(), TwitterName: "foo", TwitterUsername: "Foo Bar"},
|
||||
wantBody: "Welcome to your dashboard",
|
||||
},
|
||||
}
|
||||
|
||||
handler := httpserver.NewHandler(
|
||||
config.Config{},
|
||||
&mocks.Store{},
|
||||
func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} },
|
||||
&mocks.SessionStore{},
|
||||
&mockTokenGenerator{},
|
||||
zap.NewNop(),
|
||||
)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
var mockSessionStore mocks.SessionStore
|
||||
sess := sessions.NewSession(&mockSessionStore, "elon_session")
|
||||
sess.Values["current_user"] = tc.currentUser
|
||||
mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, nil)
|
||||
mockSessionStore.On("Save", mock.Anything, mock.Anything, sess).Return(nil)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Contains(t, string(body), "Sign in with Twitter")
|
||||
handler := httpserver.NewHandler(
|
||||
config.Config{},
|
||||
&mocks.Store{},
|
||||
func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} },
|
||||
&mockSessionStore,
|
||||
&mockTokenGenerator{},
|
||||
zap.NewNop(),
|
||||
)
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Contains(t, string(body), tc.wantBody)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
func TestGetLogin(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
sessionSaveError error
|
||||
|
@ -122,7 +150,7 @@ func TestLogin(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCallback(t *testing.T) {
|
||||
func TestGetCallback(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
state string
|
||||
|
@ -133,7 +161,9 @@ func TestCallback(t *testing.T) {
|
|||
oauth2StatusCode int
|
||||
getTwitterUserError error
|
||||
createUserError error
|
||||
sessionSaveError error
|
||||
wantStatusCode int
|
||||
wantLocation string
|
||||
wantError string
|
||||
}{
|
||||
{
|
||||
|
@ -205,6 +235,17 @@ func TestCallback(t *testing.T) {
|
|||
wantStatusCode: http.StatusInternalServerError,
|
||||
wantError: "error saving user",
|
||||
},
|
||||
{
|
||||
name: "error saving session",
|
||||
state: "mystate",
|
||||
sessionState: "mystate",
|
||||
code: "mycode",
|
||||
sessionPkceVerifier: "mypkceverifier",
|
||||
oauth2StatusCode: http.StatusOK,
|
||||
sessionSaveError: errors.New("foo"),
|
||||
wantStatusCode: http.StatusInternalServerError,
|
||||
wantError: "error saving user",
|
||||
},
|
||||
{
|
||||
name: "successful exchange",
|
||||
state: "mystate",
|
||||
|
@ -212,7 +253,8 @@ func TestCallback(t *testing.T) {
|
|||
code: "mycode",
|
||||
sessionPkceVerifier: "mypkceverifier",
|
||||
oauth2StatusCode: http.StatusOK,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantStatusCode: http.StatusFound,
|
||||
wantLocation: "/",
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -223,6 +265,7 @@ func TestCallback(t *testing.T) {
|
|||
sess.Values["state"] = tc.sessionState
|
||||
sess.Values["pkce_verifier"] = tc.sessionPkceVerifier
|
||||
mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, tc.sessionReadError)
|
||||
mockSessionStore.On("Save", mock.Anything, mock.Anything, sess).Return(tc.sessionSaveError)
|
||||
|
||||
var mockTwitterClient mocks.TwitterAPIClient
|
||||
mockTwitterClient.On("GetMe").Return(&twitter.User{ID: "1", Name: "foo", Username: "Foo Bar"}, tc.getTwitterUserError)
|
||||
|
@ -283,3 +326,31 @@ func TestCallback(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostLogout(t *testing.T) {
|
||||
var mockSessionStore mocks.SessionStore
|
||||
sess := sessions.NewSession(&mockSessionStore, "elon_session")
|
||||
mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, nil)
|
||||
mockSessionStore.On("Save", mock.Anything, mock.Anything, mock.MatchedBy(func(s *sessions.Session) bool {
|
||||
_, ok := s.Values["current_user"]
|
||||
return !ok
|
||||
})).Return(nil)
|
||||
|
||||
handler := httpserver.NewHandler(
|
||||
config.Config{},
|
||||
&mocks.Store{},
|
||||
func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} },
|
||||
&mockSessionStore,
|
||||
nil,
|
||||
zap.NewNop(),
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/logout", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusFound, res.StatusCode)
|
||||
assert.Equal(t, "/", res.Header.Get("Location"))
|
||||
}
|
||||
|
|
|
@ -5,15 +5,22 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/gorilla/sessions"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func loggerMiddleware(l *zap.Logger) func(next http.Handler) http.Handler {
|
||||
func loggerMiddleware(l *zap.Logger, sessionStore sessions.Store) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
|
||||
|
||||
t1 := time.Now()
|
||||
|
||||
cu, ok := getCurrentUser(r, sessionStore)
|
||||
var userID string
|
||||
if ok {
|
||||
userID = cu.ID.String()
|
||||
}
|
||||
|
||||
defer func() {
|
||||
l.Info("HTTP",
|
||||
zap.String("proto", r.Proto),
|
||||
|
@ -22,7 +29,10 @@ func loggerMiddleware(l *zap.Logger) func(next http.Handler) http.Handler {
|
|||
zap.Duration("dur", time.Since(t1)),
|
||||
zap.Int("status", ww.Status()),
|
||||
zap.Int("size", ww.BytesWritten()),
|
||||
zap.String("reqId", middleware.GetReqID(r.Context())))
|
||||
zap.String("ip", r.RemoteAddr),
|
||||
zap.String("reqId", middleware.GetReqID(r.Context())),
|
||||
zap.String("userID", userID),
|
||||
zap.String("username", cu.TwitterUsername))
|
||||
}()
|
||||
|
||||
next.ServeHTTP(ww, r)
|
|
@ -3,18 +3,36 @@ package templates
|
|||
import (
|
||||
"embed"
|
||||
"html/template"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
//go:embed views/*.html
|
||||
var tmplFS embed.FS
|
||||
|
||||
var templatesMap *template.Template
|
||||
var (
|
||||
Index = parse("views/index.html")
|
||||
Dashboard = parse("views/dashboard.html")
|
||||
)
|
||||
|
||||
func init() {
|
||||
templatesMap = template.Must(template.ParseFS(tmplFS, "views/*.html"))
|
||||
func parse(file string) *template.Template {
|
||||
return template.Must(template.New("base.html").ParseFS(tmplFS, "views/base.html", file))
|
||||
}
|
||||
|
||||
func Execute(w io.Writer, name string, data any) error {
|
||||
return templatesMap.ExecuteTemplate(w, name, data)
|
||||
// Renderer renders HTML templates, and logs any errors.
|
||||
type Renderer struct {
|
||||
logger *zap.SugaredLogger
|
||||
}
|
||||
|
||||
func NewRenderer(logger *zap.SugaredLogger) *Renderer {
|
||||
return &Renderer{logger: logger}
|
||||
}
|
||||
|
||||
// Render renders an HTML template to an http.ResponseWriter.
|
||||
func (r *Renderer) Render(w http.ResponseWriter, template *template.Template, data any) {
|
||||
if err := template.Execute(w, data); err != nil {
|
||||
r.logger.With("err", err).Error("error rendering template")
|
||||
http.Error(w, "error rendering template", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,18 @@
|
|||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Elon Eats My Tweets</title>
|
||||
<title>{{block "title" .}}Elon Eats My Tweets{{end}}</title>
|
||||
</head>
|
||||
<body>
|
||||
{{block "content" .}}{{end}}
|
||||
<h1>Elon eats my tweets</h1>
|
||||
{{if not .CurrentUser.IsNil}}
|
||||
<p>
|
||||
Logged in as @{{.CurrentUser.TwitterUsername}}
|
||||
<form id="logout_form" action="/logout" method="post"></form>
|
||||
<a href="/logout" onclick="document.getElementById('logout_form').submit(); return false;">Log out</a>
|
||||
</p>
|
||||
{{end}}
|
||||
|
||||
{{template "content" .}}
|
||||
</body>
|
||||
</html>
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
{{template "base.html" .}}
|
||||
|
||||
{{define "title"}}Elon Eats My Tweets | Dashboard{{end}}
|
||||
|
||||
{{define "content"}}
|
||||
Welcome to your dashboard.
|
||||
{{end}}
|
|
@ -1,8 +1,6 @@
|
|||
{{template "base.html" .}}
|
||||
|
||||
{{define "content"}}
|
||||
<h1>Elon eats my tweets<h1>
|
||||
|
||||
<p>Sick of Elon? Tired of Twitter? Have Elon Musk delete your tweets for you.</p>
|
||||
|
||||
<a href="/login">Sign in with Twitter</a>
|
||||
|
|
Reference in New Issue