diff --git a/config/config.go b/config/config.go index 2e9de73..b789880 100644 --- a/config/config.go +++ b/config/config.go @@ -1,28 +1,37 @@ package config -import "os" +import ( + "errors" + "os" +) type TwitterCredentials struct { ClientID, ClientSecret, CallbackURL string } type Config struct { PublicPath string + SessionKey string ListenAddr string Twitter TwitterCredentials } -func NewFromEnv() Config { +func NewFromEnv() (Config, error) { listenAddr := os.Getenv("ELON_LISTEN_ADDR") if listenAddr == "" { listenAddr = ":8000" } + sessionKey := os.Getenv("ELON_SESSION_KEY") + if sessionKey == "" { + return Config{}, errors.New("missing ELON_SESSION_KEY") + } return Config{ PublicPath: os.Getenv("ELON_PUBLIC_PATH"), + SessionKey: sessionKey, ListenAddr: listenAddr, Twitter: TwitterCredentials{ ClientID: os.Getenv("ELON_TWITTER_CLIENT_ID"), ClientSecret: os.Getenv("ELON_TWITTER_CLIENT_SECRET"), CallbackURL: os.Getenv("ELON_TWITTER_CALLBACK_URL"), }, - } + }, nil } diff --git a/go.mod b/go.mod index 316ce52..6cf77a2 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,8 @@ require ( require ( github.com/golang/protobuf v1.4.2 // indirect github.com/google/go-cmp v0.5.7 // indirect + github.com/gorilla/securecookie v1.1.1 // indirect + github.com/gorilla/sessions v1.2.1 // indirect golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect google.golang.org/appengine v1.6.6 // indirect google.golang.org/protobuf v1.25.0 // indirect diff --git a/go.sum b/go.sum index fc8e702..aee8733 100644 --- a/go.sum +++ b/go.sum @@ -97,6 +97,10 @@ github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hf github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= +github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= +github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= diff --git a/httpserver/httpserver.go b/httpserver/httpserver.go index 49efa0c..6ca2d5f 100644 --- a/httpserver/httpserver.go +++ b/httpserver/httpserver.go @@ -2,8 +2,12 @@ package httpserver import ( "context" + "crypto/sha256" + "encoding/base64" "fmt" + "io" "log" + "math/rand" "net/http" "path/filepath" "text/template" @@ -11,25 +15,37 @@ import ( "git.netflux.io/rob/elon-eats-my-tweets/config" "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" + "github.com/gorilla/sessions" "golang.org/x/oauth2" ) +const ( + sessionName = "elon_session" + sessionKeyState = "state" + sessionKeyPkceVerifier = "pkce_verifier" + stateLen = 64 + pkceVerifierLen = 64 +) + type handler struct { templates *template.Template oauth2Config *oauth2.Config + sessionStore *sessions.CookieStore } func newHandler(cfg config.Config, templates *template.Template) *handler { return &handler{ - templates: templates, + templates: templates, + sessionStore: sessions.NewCookieStore([]byte(cfg.SessionKey)), oauth2Config: &oauth2.Config{ ClientID: cfg.Twitter.ClientID, ClientSecret: cfg.Twitter.ClientSecret, RedirectURL: cfg.Twitter.CallbackURL, Scopes: []string{"tweet.read", "tweet.write", "users.read", "offline.access"}, Endpoint: oauth2.Endpoint{ - AuthURL: "https://twitter.com/i/oauth2/authorize", - TokenURL: "https://api.twitter.com/2/oauth2/token", + AuthURL: "https://twitter.com/i/oauth2/authorize", + TokenURL: "https://api.twitter.com/2/oauth2/token", + AuthStyle: oauth2.AuthStyleInHeader, }, }, } @@ -44,32 +60,81 @@ func (h *handler) getIndex(w http.ResponseWriter, r *http.Request) { } func (h *handler) getLogin(w http.ResponseWriter, r *http.Request) { + state := randSeq(stateLen) + pkceVerifier := randSeq(pkceVerifierLen) + + session, _ := h.sessionStore.Get(r, sessionName) + session.Values[sessionKeyState] = state + session.Values[sessionKeyPkceVerifier] = pkceVerifier + if err := session.Save(r, w); err != nil { + log.Printf("error saving session: %v", err) + http.Error(w, "error saving session", http.StatusBadRequest) + return + } + url := h.oauth2Config.AuthCodeURL( - // TODO: implement state and code_challenge tokens - "state", - oauth2.SetAuthURLParam("code_challenge", "challenge"), - oauth2.SetAuthURLParam("code_challenge_method", "plain"), + state, + oauth2.SetAuthURLParam("code_challenge", encodeSHA256(pkceVerifier)), + oauth2.SetAuthURLParam("code_challenge_method", "S256"), ) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) } func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) { + session, err := h.sessionStore.Get(r, sessionName) + if err != nil { + log.Printf("error reading session: %v", err) + http.Error(w, "error reading session", http.StatusBadRequest) + return + } + + state, ok := session.Values[sessionKeyState] + if !ok { + log.Println("empty state", err) + http.Error(w, "error reading session", http.StatusBadRequest) + return + } + + if state != r.URL.Query().Get("state") { + http.Error(w, "error validating request", http.StatusBadRequest) + return + } + code := r.URL.Query().Get("code") if code == "" { http.Error(w, "empty code", http.StatusBadRequest) return } - _, err := h.oauth2Config.Exchange(context.Background(), code, oauth2.SetAuthURLParam("code_verifier", "challenge")) - if err != nil { - log.Printf("error exchanging code: %v", err) - http.Error(w, "error exchanging code", http.StatusInternalServerError) + pkceVerifier, ok := session.Values[sessionKeyPkceVerifier] + if !ok { + log.Println("empty code challenge", err) + http.Error(w, "error reading session", http.StatusBadRequest) return } + token, err := h.oauth2Config.Exchange(context.Background(), code, oauth2.SetAuthURLParam("code_verifier", pkceVerifier.(string))) + if err != nil { + log.Printf("error exchanging code: %v", err) + http.Error(w, "error exchanging code", http.StatusForbidden) + return + } + + client := h.oauth2Config.Client(context.Background(), token) + resp, err := client.Get("https://api.twitter.com/2/users/me") + if err != nil { + log.Printf("error fetching users/me: %v", err) + http.Error(w, "error fetching user", http.StatusInternalServerError) + return + } + defer resp.Body.Close() + // TODO: do something sensible + body, _ := io.ReadAll(resp.Body) + + w.Header().Set("content-type", "application/json") w.WriteHeader(http.StatusOK) - w.Header().Set("content-type", "text/html") - if _, err = w.Write([]byte("ok")); err != nil { + if _, err = w.Write([]byte(body)); err != nil { log.Printf("error writing response: %v", err) } } @@ -93,3 +158,19 @@ func Start(cfg config.Config) error { return http.ListenAndServe(":8000", r) } + +// https://stackoverflow.com/questions/22892120/how-to-generate-a-random-string-of-a-fixed-length-in-go +var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +func randSeq(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) +} + +func encodeSHA256(s string) string { + enc := sha256.Sum256([]byte(s)) + return base64.RawURLEncoding.EncodeToString(enc[:]) +} diff --git a/main.go b/main.go index 381fa7f..abccb02 100644 --- a/main.go +++ b/main.go @@ -2,13 +2,20 @@ package main import ( "log" + "math/rand" + "time" "git.netflux.io/rob/elon-eats-my-tweets/config" "git.netflux.io/rob/elon-eats-my-tweets/httpserver" ) func main() { - c := config.NewFromEnv() + rand.Seed(time.Now().UnixNano()) + + c, err := config.NewFromEnv() + if err != nil { + log.Fatal(err) + } log.Fatal(httpserver.Start(c)) }