package httpserver

//go:generate mockery --recursive --srcpkg github.com/gorilla/sessions --name Store --output ../generated/mocks

import (
	"bytes"
	"context"
	"fmt"
	"html/template"
	"io"
	"net/http"

	"git.netflux.io/rob/elon-eats-my-tweets/config"
	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/gorilla/sessions"
	"go.uber.org/zap"
	"golang.org/x/oauth2"
)

// Session keys and token lengths used by the OAuth2 authorization code + PKCE
// login flow implemented by getLogin and getCallback.
const (
	sessionName            = "elon_session"
	sessionKeyState        = "state"
	sessionKeyPkceVerifier = "pkce_verifier"

	stateLen        = 64
	pkceVerifierLen = 64
)

// handler holds the dependencies shared by the HTTP handlers in this file.
type handler struct {
	templates      *template.Template
	oauth2Config   *oauth2.Config
	sessionStore   sessions.Store
	tokenGenerator TokenGenerator
	logger         *zap.SugaredLogger
}

// NewHandler returns an http.Handler that serves:
//
//   - GET /         renders the "index" template.
//   - GET /login    starts the Twitter OAuth2 authorization code flow with PKCE.
//   - GET /callback completes that flow and responds with the JSON body
//     returned by the Twitter users/me endpoint.
//
// Every request passes through chi's RequestID, RealIP and Recoverer
// middleware plus the package's logging middleware.
func NewHandler(cfg config.Config, templates *template.Template, sessionStore sessions.Store, tokenGenerator TokenGenerator, logger *zap.Logger) http.Handler {
	r := chi.NewRouter()
	r.Use(middleware.RequestID)
	r.Use(middleware.RealIP)
	r.Use(loggerMiddleware(logger))
	r.Use(middleware.Recoverer)

	h := handler{
		templates: templates,
		oauth2Config: &oauth2.Config{
			ClientID:     cfg.Twitter.ClientID,
			ClientSecret: cfg.Twitter.ClientSecret,
			RedirectURL:  cfg.Twitter.CallbackURL,
			Scopes:       []string{"tweet.read", "tweet.write", "users.read", "offline.access"},
			Endpoint: oauth2.Endpoint{
				AuthURL:   cfg.Twitter.AuthorizeURL,
				TokenURL:  cfg.Twitter.TokenURL,
				AuthStyle: oauth2.AuthStyleInHeader,
			},
		},
		sessionStore:   sessionStore,
		tokenGenerator: tokenGenerator,
		logger:         logger.Sugar(),
	}

	r.Get("/", h.getIndex)
	r.Get("/login", h.getLogin)
	r.Get("/callback", h.getCallback)

	return r
}

// getIndex renders the "index" template.
//
// The template is executed into a buffer first so that a rendering failure
// yields a clean 500 response rather than a partially written page followed
// by a superfluous WriteHeader call.
func (h *handler) getIndex(w http.ResponseWriter, r *http.Request) {
	var buf bytes.Buffer
	if err := h.templates.ExecuteTemplate(&buf, "index", nil); err != nil {
		h.logger.With("err", err).Error("error rendering template")
		http.Error(w, "error rendering template", http.StatusInternalServerError)
		return
	}

	if _, err := buf.WriteTo(w); err != nil {
		h.logger.With("err", err).Error("error writing response")
	}
}

// getLogin starts the OAuth2 authorization code flow. It generates a random
// state and PKCE verifier, stores both in the session, and redirects the user
// agent to the provider's authorization URL with an S256 code challenge
// derived from the verifier.
func (h *handler) getLogin(w http.ResponseWriter, r *http.Request) {
	state := h.tokenGenerator.GenerateToken(stateLen)
	pkceVerifier := h.tokenGenerator.GenerateToken(pkceVerifierLen)

	// gorilla/sessions stores return a fresh session alongside a decode error
	// (e.g. for a stale or tampered cookie). That is acceptable here because
	// the values we need are about to be overwritten, so the error is only
	// logged. A nil session, however, cannot be used at all.
	session, err := h.sessionStore.Get(r, sessionName)
	if session == nil {
		h.logger.With("err", err).Error("error reading session")
		http.Error(w, "unexpected error", http.StatusInternalServerError)
		return
	}
	if err != nil {
		h.logger.With("err", err).Warn("discarding unreadable session")
	}

	session.Values[sessionKeyState] = state
	session.Values[sessionKeyPkceVerifier] = pkceVerifier
	if err := session.Save(r, w); err != nil {
		h.logger.With("err", err).Error("error saving session")
		http.Error(w, "unexpected error", http.StatusInternalServerError)
		return
	}

	authURL := h.oauth2Config.AuthCodeURL(
		state,
		oauth2.SetAuthURLParam("code_challenge", encodeSHA256(pkceVerifier)),
		oauth2.SetAuthURLParam("code_challenge_method", "S256"),
	)

	http.Redirect(w, r, authURL, http.StatusFound)
}

// getCallback completes the OAuth2 authorization code flow started by
// getLogin. It verifies the returned state against the session, exchanges the
// authorization code (together with the session's PKCE verifier) for a token,
// then fetches the authenticated user and writes the upstream JSON body to
// the response.
//
// All upstream calls use the request's context, so they are cancelled if the
// client goes away.
//
// NOTE(review): the state and PKCE verifier are not removed from the session
// after a successful exchange.
func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	session, err := h.sessionStore.Get(r, sessionName)
	if err != nil {
		h.logger.With("err", err).Error("error reading session")
		http.Error(w, "error reading session", http.StatusBadRequest)
		return
	}

	query := r.URL.Query()

	// A missing or non-string value is treated the same as an empty one.
	state, _ := session.Values[sessionKeyState].(string)
	if state == "" {
		h.logger.Error("empty state parameter in oauth2 request")
		http.Error(w, "error validating request", http.StatusBadRequest)
		return
	}

	if state != query.Get("state") {
		h.logger.Error("unexpected state in oauth2 request")
		http.Error(w, "error validating request", http.StatusBadRequest)
		return
	}

	code := query.Get("code")
	if code == "" {
		h.logger.Error("empty code in oauth2 request")
		http.Error(w, "invalid code", http.StatusBadRequest)
		return
	}

	// Checked assertion: a non-string value must not panic the handler.
	pkceVerifier, _ := session.Values[sessionKeyPkceVerifier].(string)
	if pkceVerifier == "" {
		h.logger.Error("no pkce verifier found in session")
		http.Error(w, "error reading session", http.StatusBadRequest)
		return
	}

	token, err := h.oauth2Config.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", pkceVerifier))
	if err != nil {
		h.logger.With("err", err).Error("error exchanging code for access token")
		http.Error(w, "error exchanging code", http.StatusForbidden)
		return
	}

	body, err := h.fetchCurrentUser(ctx, token)
	if err != nil {
		h.logger.With("err", err).Error("error fetching user")
		http.Error(w, "error fetching user", http.StatusInternalServerError)
		return
	}

	// TODO: do something sensible with the user rather than echoing it back.
	w.Header().Set("content-type", "application/json")
	w.WriteHeader(http.StatusOK)
	if _, err := w.Write(body); err != nil {
		h.logger.With("err", err).Error("error writing response")
	}
}

// fetchCurrentUser calls the Twitter users/me endpoint authorized with token
// and returns the raw response body. It returns an error if the request
// cannot be sent, the response status is not 200 OK, or the body cannot be
// read in full.
func (h *handler) fetchCurrentUser(ctx context.Context, token *oauth2.Token) ([]byte, error) {
	const userURL = "https://api.twitter.com/2/users/me"

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, userURL, nil)
	if err != nil {
		return nil, fmt.Errorf("building request: %w", err)
	}

	// The oauth2 client attaches the token to the request (refreshing it as
	// necessary, per the oauth2 package documentation).
	resp, err := h.oauth2Config.Client(ctx, token).Do(req)
	if err != nil {
		return nil, fmt.Errorf("sending request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("reading response body: %w", err)
	}

	return body, nil
}