247 lines
7.2 KiB
Go
247 lines
7.2 KiB
Go
package httpserver
|
|
|
|
//go:generate mockery --recursive --srcpkg github.com/gorilla/sessions --name Store --structname SessionStore --filename SessionStore.go --output ../generated/mocks
|
|
//go:generate mockery --recursive --name TwitterAPIClient --filename TwitterAPIClient.go --output ../generated/mocks
|
|
|
|
import (
|
|
"context"
|
|
"encoding/gob"
|
|
"html/template"
|
|
"net/http"
|
|
"time"
|
|
|
|
"git.netflux.io/rob/elon-eats-my-tweets/config"
|
|
"git.netflux.io/rob/elon-eats-my-tweets/generated/store"
|
|
"git.netflux.io/rob/elon-eats-my-tweets/templates"
|
|
"git.netflux.io/rob/elon-eats-my-tweets/twitter"
|
|
"github.com/go-chi/chi"
|
|
"github.com/go-chi/chi/middleware"
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/sessions"
|
|
"go.uber.org/zap"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
func init() {
|
|
gob.Register(CurrentUser{})
|
|
}
|
|
|
|
type CurrentUser struct {
|
|
ID uuid.UUID
|
|
TwitterName, TwitterUsername string
|
|
}
|
|
|
|
func (cu CurrentUser) IsNil() bool {
|
|
return cu.ID == uuid.Nil
|
|
}
|
|
|
|
// contextKey is a package-private type for request context keys, so keys set
// by this package cannot collide with plain-string keys from other packages.
type contextKey string

// Session and OAuth2 login parameters.
const (
	// sessionName is the name of the session used by every handler.
	sessionName = "elon_session"

	// Keys for values stored in the session.
	sessionKeyState        = "state"
	sessionKeyPkceVerifier = "pkce_verifier"
	sessionKeyCurrentUser  = "current_user"

	// Lengths passed to TokenGenerator.GenerateToken for the OAuth2 state
	// and the PKCE verifier.
	// NOTE(review): RFC 7636 §4.1 requires a code_verifier of 43–128
	// characters; confirm GenerateToken(pkceVerifierLen) stays in range.
	stateLen        = 256
	pkceVerifierLen = 256

	// contextKeyCurrentUser is the request-context key for the current user.
	contextKeyCurrentUser contextKey = sessionKeyCurrentUser
)
|
|
|
|
func getCurrentUser(r *http.Request, sessionStore sessions.Store) (CurrentUser, bool) {
|
|
session, _ := sessionStore.Get(r, sessionName)
|
|
if cu, ok := session.Values[sessionKeyCurrentUser]; ok {
|
|
return cu.(CurrentUser), true
|
|
}
|
|
return CurrentUser{}, false
|
|
}
|
|
|
|
type TwitterAPIClient interface {
|
|
GetMe() (*twitter.User, error)
|
|
}
|
|
|
|
type handler struct {
|
|
store twitter.Store
|
|
twitterAPIClientFunc func(c *http.Client) TwitterAPIClient
|
|
oauth2Config *oauth2.Config
|
|
renderer *templates.Renderer
|
|
sessionStore sessions.Store
|
|
tokenGenerator TokenGenerator
|
|
logger *zap.SugaredLogger
|
|
}
|
|
|
|
func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc func(c *http.Client) TwitterAPIClient, sessionStore sessions.Store, tokenGenerator TokenGenerator, logger *zap.Logger) http.Handler {
|
|
r := chi.NewRouter()
|
|
r.Use(middleware.RequestID)
|
|
r.Use(middleware.RealIP)
|
|
r.Use(loggerMiddleware(logger, sessionStore))
|
|
r.Use(middleware.Recoverer)
|
|
|
|
h := handler{
|
|
store: store,
|
|
twitterAPIClientFunc: twitterAPIClientFunc,
|
|
oauth2Config: &oauth2.Config{
|
|
ClientID: cfg.Twitter.ClientID,
|
|
ClientSecret: cfg.Twitter.ClientSecret,
|
|
RedirectURL: cfg.Twitter.CallbackURL,
|
|
Scopes: []string{"tweet.read", "tweet.write", "users.read", "offline.access"},
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: cfg.Twitter.AuthorizeURL,
|
|
TokenURL: cfg.Twitter.TokenURL,
|
|
AuthStyle: oauth2.AuthStyleInHeader,
|
|
},
|
|
},
|
|
renderer: templates.NewRenderer(logger.Sugar()),
|
|
sessionStore: sessionStore,
|
|
tokenGenerator: tokenGenerator,
|
|
logger: logger.Sugar(),
|
|
}
|
|
|
|
r.Get("/", h.getIndex)
|
|
r.Get("/login", h.getLogin)
|
|
r.Post("/logout", h.postLogout)
|
|
r.Get("/callback", h.getCallback)
|
|
|
|
return r
|
|
}
|
|
|
|
func (h *handler) getIndex(w http.ResponseWriter, r *http.Request) {
|
|
cu, _ := getCurrentUser(r, h.sessionStore)
|
|
var template *template.Template
|
|
if cu.IsNil() {
|
|
template = templates.Index
|
|
} else {
|
|
template = templates.Dashboard
|
|
}
|
|
|
|
type data struct {
|
|
CurrentUser CurrentUser
|
|
}
|
|
|
|
h.renderer.Render(w, template, data{CurrentUser: cu})
|
|
}
|
|
|
|
func (h *handler) getLogin(w http.ResponseWriter, r *http.Request) {
|
|
state := h.tokenGenerator.GenerateToken(stateLen)
|
|
pkceVerifier := h.tokenGenerator.GenerateToken(pkceVerifierLen)
|
|
|
|
session, _ := h.sessionStore.Get(r, sessionName)
|
|
session.Values[sessionKeyState] = state
|
|
session.Values[sessionKeyPkceVerifier] = pkceVerifier
|
|
if err := session.Save(r, w); err != nil {
|
|
h.logger.With("err", err).Error("error saving session")
|
|
http.Error(w, "unexpected error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
url := h.oauth2Config.AuthCodeURL(
|
|
state,
|
|
oauth2.SetAuthURLParam("code_challenge", encodeSHA256(pkceVerifier)),
|
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
|
)
|
|
|
|
http.Redirect(w, r, url, http.StatusFound)
|
|
}
|
|
|
|
func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) {
|
|
session, err := h.sessionStore.Get(r, sessionName)
|
|
if err != nil {
|
|
h.logger.With("err", err).Error("error reading session")
|
|
http.Error(w, "error reading session", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
state, ok := session.Values[sessionKeyState]
|
|
if !ok || state == "" {
|
|
h.logger.Error("empty state parameter in oauth2 request")
|
|
http.Error(w, "error validating request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if state != r.URL.Query().Get("state") {
|
|
h.logger.Error("unexpected state in oauth2 request")
|
|
http.Error(w, "error validating request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
code := r.URL.Query().Get("code")
|
|
if code == "" {
|
|
h.logger.Error("empty code in oauth2 request")
|
|
http.Error(w, "invalid code", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
pkceVerifier, ok := session.Values[sessionKeyPkceVerifier]
|
|
if !ok || pkceVerifier == "" {
|
|
h.logger.Error("no pkce verifier found in session")
|
|
http.Error(w, "error reading session", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
token, err := h.oauth2Config.Exchange(context.Background(), code, oauth2.SetAuthURLParam("code_verifier", pkceVerifier.(string)))
|
|
if err != nil {
|
|
h.logger.With("err", err).Error("error exchanging code for access token")
|
|
http.Error(w, "error exchanging code", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
twitterClient := h.twitterAPIClientFunc(h.oauth2Config.Client(context.Background(), token))
|
|
twitterUser, err := twitterClient.GetMe()
|
|
if err != nil {
|
|
h.logger.With("err", err).Error("error fetching user from twitter")
|
|
http.Error(w, "error validating user", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
user, err := h.store.CreateUser(r.Context(), store.CreateUserParams{
|
|
TwitterID: twitterUser.ID,
|
|
Username: twitterUser.Username,
|
|
Name: twitterUser.Name,
|
|
AccessToken: token.AccessToken,
|
|
RefreshToken: token.RefreshToken,
|
|
CreatedAt: time.Now().UTC(),
|
|
UpdatedAt: time.Now().UTC(),
|
|
})
|
|
if err != nil {
|
|
h.logger.With("err", err).Error("error upserting user")
|
|
http.Error(w, "error saving user", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// finally, save in the session. To avoid hitting the database on most
|
|
// requests we'll store the Twitter name and username in the session.
|
|
session.Values[sessionKeyCurrentUser] = CurrentUser{
|
|
ID: user.ID,
|
|
TwitterName: twitterUser.Name,
|
|
TwitterUsername: twitterUser.Username,
|
|
}
|
|
if err := session.Save(r, w); err != nil {
|
|
h.logger.With("err", err).Error("error saving session")
|
|
http.Error(w, "error saving user", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
http.Redirect(w, r, "/", http.StatusFound)
|
|
}
|
|
|
|
func (h *handler) postLogout(w http.ResponseWriter, r *http.Request) {
|
|
session, err := h.sessionStore.Get(r, sessionName)
|
|
if err != nil {
|
|
h.logger.With("err", err).Error("error reading session")
|
|
http.Error(w, "error reading session", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
delete(session.Values, sessionKeyCurrentUser)
|
|
|
|
if err := session.Save(r, w); err != nil {
|
|
h.logger.With("err", err).Error("error saving session")
|
|
http.Error(w, "error saving user", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
http.Redirect(w, r, "/", http.StatusFound)
|
|
}
|