package httpserver //go:generate mockery --recursive --srcpkg github.com/gorilla/sessions --name Store --structname SessionStore --filename SessionStore.go --output ../generated/mocks //go:generate mockery --recursive --name TwitterAPIClient --filename TwitterAPIClient.go --output ../generated/mocks import ( "context" "encoding/gob" "html/template" "net/http" "time" "git.netflux.io/rob/elon-eats-my-tweets/config" "git.netflux.io/rob/elon-eats-my-tweets/generated/store" "git.netflux.io/rob/elon-eats-my-tweets/templates" "git.netflux.io/rob/elon-eats-my-tweets/twitter" "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" "github.com/google/uuid" "github.com/gorilla/sessions" "go.uber.org/zap" "golang.org/x/oauth2" ) func init() { gob.Register(CurrentUser{}) } type CurrentUser struct { ID uuid.UUID TwitterName, TwitterUsername string } func (cu CurrentUser) IsNil() bool { return cu.ID == uuid.Nil } type contextKey string const ( sessionName = "elon_session" sessionKeyState = "state" sessionKeyPkceVerifier = "pkce_verifier" stateLen = 256 pkceVerifierLen = 256 ) const ( sessionKeyCurrentUser = "current_user" contextKeyCurrentUser contextKey = sessionKeyCurrentUser ) func getCurrentUser(r *http.Request, sessionStore sessions.Store) (CurrentUser, bool) { session, _ := sessionStore.Get(r, sessionName) if cu, ok := session.Values[sessionKeyCurrentUser]; ok { return cu.(CurrentUser), true } return CurrentUser{}, false } type TwitterAPIClient interface { GetMe() (*twitter.User, error) } type handler struct { store twitter.Store twitterAPIClientFunc func(c *http.Client) TwitterAPIClient oauth2Config *oauth2.Config renderer *templates.Renderer sessionStore sessions.Store tokenGenerator TokenGenerator logger *zap.SugaredLogger } func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc func(c *http.Client) TwitterAPIClient, sessionStore sessions.Store, tokenGenerator TokenGenerator, logger *zap.Logger) http.Handler { r := chi.NewRouter() r.Use(middleware.RequestID) r.Use(middleware.RealIP) r.Use(loggerMiddleware(logger, sessionStore)) r.Use(middleware.Recoverer) h := handler{ store: store, twitterAPIClientFunc: twitterAPIClientFunc, oauth2Config: &oauth2.Config{ ClientID: cfg.Twitter.ClientID, ClientSecret: cfg.Twitter.ClientSecret, RedirectURL: cfg.Twitter.CallbackURL, Scopes: []string{"tweet.read", "tweet.write", "users.read", "offline.access"}, Endpoint: oauth2.Endpoint{ AuthURL: cfg.Twitter.AuthorizeURL, TokenURL: cfg.Twitter.TokenURL, AuthStyle: oauth2.AuthStyleInHeader, }, }, renderer: templates.NewRenderer(logger.Sugar()), sessionStore: sessionStore, tokenGenerator: tokenGenerator, logger: logger.Sugar(), } r.Get("/", h.getIndex) r.Get("/login", h.getLogin) r.Post("/logout", h.postLogout) r.Get("/callback", h.getCallback) return r } func (h *handler) getIndex(w http.ResponseWriter, r *http.Request) { cu, _ := getCurrentUser(r, h.sessionStore) var template *template.Template if cu.IsNil() { template = templates.Index } else { template = templates.Dashboard } type data struct { CurrentUser CurrentUser } h.renderer.Render(w, template, data{CurrentUser: cu}) } func (h *handler) getLogin(w http.ResponseWriter, r *http.Request) { state := h.tokenGenerator.GenerateToken(stateLen) pkceVerifier := h.tokenGenerator.GenerateToken(pkceVerifierLen) session, _ := h.sessionStore.Get(r, sessionName) session.Values[sessionKeyState] = state session.Values[sessionKeyPkceVerifier] = pkceVerifier if err := session.Save(r, w); err != nil { h.logger.With("err", err).Error("error saving session") http.Error(w, "unexpected error", http.StatusInternalServerError) return } url := h.oauth2Config.AuthCodeURL( state, oauth2.SetAuthURLParam("code_challenge", encodeSHA256(pkceVerifier)), oauth2.SetAuthURLParam("code_challenge_method", "S256"), ) http.Redirect(w, r, url, http.StatusFound) } func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) { session, err := h.sessionStore.Get(r, sessionName) if err != nil { h.logger.With("err", err).Error("error reading session") http.Error(w, "error reading session", http.StatusBadRequest) return } state, ok := session.Values[sessionKeyState] if !ok || state == "" { h.logger.Error("empty state parameter in oauth2 request") http.Error(w, "error validating request", http.StatusBadRequest) return } if state != r.URL.Query().Get("state") { h.logger.Error("unexpected state in oauth2 request") http.Error(w, "error validating request", http.StatusBadRequest) return } code := r.URL.Query().Get("code") if code == "" { h.logger.Error("empty code in oauth2 request") http.Error(w, "invalid code", http.StatusBadRequest) return } pkceVerifier, ok := session.Values[sessionKeyPkceVerifier] if !ok || pkceVerifier == "" { h.logger.Error("no pkce verifier found in session") http.Error(w, "error reading session", http.StatusBadRequest) return } token, err := h.oauth2Config.Exchange(context.Background(), code, oauth2.SetAuthURLParam("code_verifier", pkceVerifier.(string))) if err != nil { h.logger.With("err", err).Error("error exchanging code for access token") http.Error(w, "error exchanging code", http.StatusForbidden) return } twitterClient := h.twitterAPIClientFunc(h.oauth2Config.Client(context.Background(), token)) twitterUser, err := twitterClient.GetMe() if err != nil { h.logger.With("err", err).Error("error fetching user from twitter") http.Error(w, "error validating user", http.StatusInternalServerError) return } user, err := h.store.CreateUser(r.Context(), store.CreateUserParams{ TwitterID: twitterUser.ID, Username: twitterUser.Username, Name: twitterUser.Name, AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(), }) if err != nil { h.logger.With("err", err).Error("error upserting user") http.Error(w, "error saving user", http.StatusInternalServerError) return } // finally, save in the session. To avoid hitting the database on most // requests we'll store the Twitter name and username in the session. session.Values[sessionKeyCurrentUser] = CurrentUser{ ID: user.ID, TwitterName: twitterUser.Name, TwitterUsername: twitterUser.Username, } if err := session.Save(r, w); err != nil { h.logger.With("err", err).Error("error saving session") http.Error(w, "error saving user", http.StatusInternalServerError) return } http.Redirect(w, r, "/", http.StatusFound) } func (h *handler) postLogout(w http.ResponseWriter, r *http.Request) { session, err := h.sessionStore.Get(r, sessionName) if err != nil { h.logger.With("err", err).Error("error reading session") http.Error(w, "error reading session", http.StatusBadRequest) return } delete(session.Values, sessionKeyCurrentUser) if err := session.Save(r, w); err != nil { h.logger.With("err", err).Error("error saving session") http.Error(w, "error saving user", http.StatusInternalServerError) return } http.Redirect(w, r, "/", http.StatusFound) }