// Package httpserver exposes the HTTP handler that serves the index page and
// drives the Twitter OAuth2 (PKCE) login and callback flow.
package httpserver

//go:generate mockery --recursive --srcpkg github.com/gorilla/sessions --name Store --structname SessionStore --filename SessionStore.go --output ../generated/mocks
//go:generate mockery --recursive --name TwitterAPIClient --filename TwitterAPIClient.go --output ../generated/mocks

import (
	"context"
	"net/http"
	"time"

	"git.netflux.io/rob/elon-eats-my-tweets/config"
	"git.netflux.io/rob/elon-eats-my-tweets/generated/store"
	"git.netflux.io/rob/elon-eats-my-tweets/templates"
	"git.netflux.io/rob/elon-eats-my-tweets/twitter"
	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/gorilla/sessions"
	"go.uber.org/zap"
	"golang.org/x/oauth2"
)

const (
	// sessionName is the name under which the session is looked up in the
	// session store.
	sessionName = "elon_session"
	// sessionKeyState and sessionKeyPkceVerifier are the session value keys
	// written by getLogin and read back by getCallback.
	sessionKeyState        = "state"
	sessionKeyPkceVerifier = "pkce_verifier"
	// stateLen and pkceVerifierLen are the lengths passed to the token
	// generator. NOTE(review): the unit (bytes vs. characters) is defined by
	// the TokenGenerator implementation — confirm there.
	stateLen        = 256
	pkceVerifierLen = 256
)

// TwitterAPIClient is the part of the Twitter API used by the handler.
type TwitterAPIClient interface {
	// GetMe returns a Twitter user; the callback handler stores its ID,
	// username and name. Presumably this is the user that owns the access
	// token of the underlying HTTP client — verify against the implementation.
	GetMe() (*twitter.User, error)
}

// handler bundles the dependencies shared by the route handlers.
type handler struct {
	store                twitter.Store
	twitterAPIClientFunc func(c *http.Client) TwitterAPIClient
	oauth2Config         *oauth2.Config
	sessionStore         sessions.Store
	tokenGenerator       TokenGenerator
	logger               *zap.SugaredLogger
}

// NewHandler builds the application's http.Handler.
//
// It creates a chi router with the request-ID, real-IP, logging and recoverer
// middleware installed, derives the OAuth2 configuration from cfg.Twitter, and
// registers the "/", "/login" and "/callback" GET routes.
// twitterAPIClientFunc is called with the OAuth2-authorized HTTP client to
// obtain a TwitterAPIClient for the user who just logged in.
func NewHandler(cfg config.Config, store twitter.Store, twitterAPIClientFunc func(c *http.Client) TwitterAPIClient, sessionStore sessions.Store, tokenGenerator TokenGenerator, logger *zap.Logger) http.Handler {
	r := chi.NewRouter()
	r.Use(middleware.RequestID)
	r.Use(middleware.RealIP)
	r.Use(loggerMiddleware(logger))
	r.Use(middleware.Recoverer)

	h := handler{
		store:                store,
		twitterAPIClientFunc: twitterAPIClientFunc,
		oauth2Config: &oauth2.Config{
			ClientID:     cfg.Twitter.ClientID,
			ClientSecret: cfg.Twitter.ClientSecret,
			RedirectURL:  cfg.Twitter.CallbackURL,
			Scopes:       []string{"tweet.read", "tweet.write", "users.read", "offline.access"},
			Endpoint: oauth2.Endpoint{
				AuthURL:   cfg.Twitter.AuthorizeURL,
				TokenURL:  cfg.Twitter.TokenURL,
				AuthStyle: oauth2.AuthStyleInHeader,
			},
		},
		sessionStore:   sessionStore,
		tokenGenerator: tokenGenerator,
		logger:         logger.Sugar(),
	}

	r.Get("/", h.getIndex)
	r.Get("/login", h.getLogin)
	r.Get("/callback", h.getCallback)

	return r
}

// getIndex renders the "index.html" template, responding with 500 if
// rendering fails.
func (h *handler) getIndex(w http.ResponseWriter, r *http.Request) {
	if err := templates.Execute(w, "index.html", nil); err != nil {
		h.logger.With("err", err).Error("error rendering template")
		http.Error(w, "error rendering template", http.StatusInternalServerError)
		return
	}
}

// getLogin starts the OAuth2 flow: it generates a state token and a PKCE
// verifier, stores both in the session, and redirects the client to the
// authorization URL with an S256 code challenge derived from the verifier.
func (h *handler) getLogin(w http.ResponseWriter, r *http.Request) {
	state := h.tokenGenerator.GenerateToken(stateLen)
	pkceVerifier := h.tokenGenerator.GenerateToken(pkceVerifierLen)

	// A failure to read an existing session is not fatal here because the
	// values are overwritten below; log it and carry on with whatever session
	// the store handed back. A nil session cannot be used at all.
	session, err := h.sessionStore.Get(r, sessionName)
	if err != nil {
		h.logger.With("err", err).Warn("error reading existing session")
	}
	if session == nil {
		h.logger.Error("session store returned no session")
		http.Error(w, "unexpected error", http.StatusInternalServerError)
		return
	}

	session.Values[sessionKeyState] = state
	session.Values[sessionKeyPkceVerifier] = pkceVerifier
	if err := session.Save(r, w); err != nil {
		h.logger.With("err", err).Error("error saving session")
		http.Error(w, "unexpected error", http.StatusInternalServerError)
		return
	}

	url := h.oauth2Config.AuthCodeURL(
		state,
		oauth2.SetAuthURLParam("code_challenge", encodeSHA256(pkceVerifier)),
		oauth2.SetAuthURLParam("code_challenge_method", "S256"),
	)
	http.Redirect(w, r, url, http.StatusFound)
}

// getCallback completes the OAuth2 flow. It validates the state query
// parameter against the session, exchanges the authorization code (together
// with the PKCE verifier from the session) for a token, fetches the user via
// the Twitter API client and persists the user along with its tokens.
//
// Validation failures respond with 400, a failed code exchange with 403, and
// Twitter API or storage failures with 500. On success the body is "ok".
func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) {
	// Outbound calls are tied to the inbound request so they are abandoned
	// when the client goes away.
	ctx := r.Context()

	session, err := h.sessionStore.Get(r, sessionName)
	if err != nil {
		h.logger.With("err", err).Error("error reading session")
		http.Error(w, "error reading session", http.StatusBadRequest)
		return
	}

	// Session values are untyped; a missing or non-string value is treated
	// the same as an empty one rather than being allowed to panic later.
	state, _ := session.Values[sessionKeyState].(string)
	if state == "" {
		h.logger.Error("empty state parameter in oauth2 request")
		http.Error(w, "error validating request", http.StatusBadRequest)
		return
	}
	if state != r.URL.Query().Get("state") {
		h.logger.Error("unexpected state in oauth2 request")
		http.Error(w, "error validating request", http.StatusBadRequest)
		return
	}

	code := r.URL.Query().Get("code")
	if code == "" {
		h.logger.Error("empty code in oauth2 request")
		http.Error(w, "invalid code", http.StatusBadRequest)
		return
	}

	pkceVerifier, _ := session.Values[sessionKeyPkceVerifier].(string)
	if pkceVerifier == "" {
		h.logger.Error("no pkce verifier found in session")
		http.Error(w, "error reading session", http.StatusBadRequest)
		return
	}

	token, err := h.oauth2Config.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", pkceVerifier))
	if err != nil {
		h.logger.With("err", err).Error("error exchanging code for access token")
		http.Error(w, "error exchanging code", http.StatusForbidden)
		return
	}

	twitterClient := h.twitterAPIClientFunc(h.oauth2Config.Client(ctx, token))
	twitterUser, err := twitterClient.GetMe()
	if err != nil {
		h.logger.With("err", err).Error("error fetching user from twitter")
		http.Error(w, "error validating user", http.StatusInternalServerError)
		return
	}

	// Use a single timestamp so CreatedAt and UpdatedAt are identical.
	now := time.Now().UTC()
	if _, err := h.store.CreateUser(ctx, store.CreateUserParams{
		TwitterID:    twitterUser.ID,
		Username:     twitterUser.Username,
		Name:         twitterUser.Name,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
		CreatedAt:    now,
		UpdatedAt:    now,
	}); err != nil {
		h.logger.With("err", err).Error("error upserting user")
		http.Error(w, "error saving user", http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusOK)
	if _, err := w.Write([]byte("ok")); err != nil {
		h.logger.With("err", err).Error("error writing response")
	}
}