oauth2: Complete implementation

This commit is contained in:
Rob Watson 2022-05-20 02:26:42 +02:00
parent 27af1333ad
commit 8142349db4
5 changed files with 120 additions and 17 deletions

View File

@ -1,28 +1,37 @@
package config package config
import "os" import (
"errors"
"os"
)
type TwitterCredentials struct { type TwitterCredentials struct {
ClientID, ClientSecret, CallbackURL string ClientID, ClientSecret, CallbackURL string
} }
type Config struct { type Config struct {
PublicPath string PublicPath string
SessionKey string
ListenAddr string ListenAddr string
Twitter TwitterCredentials Twitter TwitterCredentials
} }
func NewFromEnv() Config { func NewFromEnv() (Config, error) {
listenAddr := os.Getenv("ELON_LISTEN_ADDR") listenAddr := os.Getenv("ELON_LISTEN_ADDR")
if listenAddr == "" { if listenAddr == "" {
listenAddr = ":8000" listenAddr = ":8000"
} }
sessionKey := os.Getenv("ELON_SESSION_KEY")
if sessionKey == "" {
return Config{}, errors.New("missing ELON_SESSION_KEY")
}
return Config{ return Config{
PublicPath: os.Getenv("ELON_PUBLIC_PATH"), PublicPath: os.Getenv("ELON_PUBLIC_PATH"),
SessionKey: sessionKey,
ListenAddr: listenAddr, ListenAddr: listenAddr,
Twitter: TwitterCredentials{ Twitter: TwitterCredentials{
ClientID: os.Getenv("ELON_TWITTER_CLIENT_ID"), ClientID: os.Getenv("ELON_TWITTER_CLIENT_ID"),
ClientSecret: os.Getenv("ELON_TWITTER_CLIENT_SECRET"), ClientSecret: os.Getenv("ELON_TWITTER_CLIENT_SECRET"),
CallbackURL: os.Getenv("ELON_TWITTER_CALLBACK_URL"), CallbackURL: os.Getenv("ELON_TWITTER_CALLBACK_URL"),
}, },
} }, nil
} }

2
go.mod
View File

@ -10,6 +10,8 @@ require (
require ( require (
github.com/golang/protobuf v1.4.2 // indirect github.com/golang/protobuf v1.4.2 // indirect
github.com/google/go-cmp v0.5.7 // indirect github.com/google/go-cmp v0.5.7 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect
google.golang.org/appengine v1.6.6 // indirect google.golang.org/appengine v1.6.6 // indirect
google.golang.org/protobuf v1.25.0 // indirect google.golang.org/protobuf v1.25.0 // indirect

4
go.sum
View File

@ -97,6 +97,10 @@ github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hf
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=

View File

@ -2,8 +2,12 @@ package httpserver
import ( import (
"context" "context"
"crypto/sha256"
"encoding/base64"
"fmt" "fmt"
"io"
"log" "log"
"math/rand"
"net/http" "net/http"
"path/filepath" "path/filepath"
"text/template" "text/template"
@ -11,17 +15,28 @@ import (
"git.netflux.io/rob/elon-eats-my-tweets/config" "git.netflux.io/rob/elon-eats-my-tweets/config"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/go-chi/chi/middleware" "github.com/go-chi/chi/middleware"
"github.com/gorilla/sessions"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
const (
sessionName = "elon_session"
sessionKeyState = "state"
sessionKeyPkceVerifier = "pkce_verifier"
stateLen = 64
pkceVerifierLen = 64
)
type handler struct { type handler struct {
templates *template.Template templates *template.Template
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
sessionStore *sessions.CookieStore
} }
func newHandler(cfg config.Config, templates *template.Template) *handler { func newHandler(cfg config.Config, templates *template.Template) *handler {
return &handler{ return &handler{
templates: templates, templates: templates,
sessionStore: sessions.NewCookieStore([]byte(cfg.SessionKey)),
oauth2Config: &oauth2.Config{ oauth2Config: &oauth2.Config{
ClientID: cfg.Twitter.ClientID, ClientID: cfg.Twitter.ClientID,
ClientSecret: cfg.Twitter.ClientSecret, ClientSecret: cfg.Twitter.ClientSecret,
@ -30,6 +45,7 @@ func newHandler(cfg config.Config, templates *template.Template) *handler {
Endpoint: oauth2.Endpoint{ Endpoint: oauth2.Endpoint{
AuthURL: "https://twitter.com/i/oauth2/authorize", AuthURL: "https://twitter.com/i/oauth2/authorize",
TokenURL: "https://api.twitter.com/2/oauth2/token", TokenURL: "https://api.twitter.com/2/oauth2/token",
AuthStyle: oauth2.AuthStyleInHeader,
}, },
}, },
} }
@ -44,32 +60,81 @@ func (h *handler) getIndex(w http.ResponseWriter, r *http.Request) {
} }
func (h *handler) getLogin(w http.ResponseWriter, r *http.Request) { func (h *handler) getLogin(w http.ResponseWriter, r *http.Request) {
state := randSeq(stateLen)
pkceVerifier := randSeq(pkceVerifierLen)
session, _ := h.sessionStore.Get(r, sessionName)
session.Values[sessionKeyState] = state
session.Values[sessionKeyPkceVerifier] = pkceVerifier
if err := session.Save(r, w); err != nil {
log.Printf("error saving session: %v", err)
http.Error(w, "error saving session", http.StatusBadRequest)
return
}
url := h.oauth2Config.AuthCodeURL( url := h.oauth2Config.AuthCodeURL(
// TODO: implement state and code_challenge tokens state,
"state", oauth2.SetAuthURLParam("code_challenge", encodeSHA256(pkceVerifier)),
oauth2.SetAuthURLParam("code_challenge", "challenge"), oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("code_challenge_method", "plain"),
) )
http.Redirect(w, r, url, http.StatusTemporaryRedirect) http.Redirect(w, r, url, http.StatusTemporaryRedirect)
} }
func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) { func (h *handler) getCallback(w http.ResponseWriter, r *http.Request) {
session, err := h.sessionStore.Get(r, sessionName)
if err != nil {
log.Printf("error reading session: %v", err)
http.Error(w, "error reading session", http.StatusBadRequest)
return
}
state, ok := session.Values[sessionKeyState]
if !ok {
log.Println("empty state", err)
http.Error(w, "error reading session", http.StatusBadRequest)
return
}
if state != r.URL.Query().Get("state") {
http.Error(w, "error validating request", http.StatusBadRequest)
return
}
code := r.URL.Query().Get("code") code := r.URL.Query().Get("code")
if code == "" { if code == "" {
http.Error(w, "empty code", http.StatusBadRequest) http.Error(w, "empty code", http.StatusBadRequest)
return return
} }
_, err := h.oauth2Config.Exchange(context.Background(), code, oauth2.SetAuthURLParam("code_verifier", "challenge")) pkceVerifier, ok := session.Values[sessionKeyPkceVerifier]
if err != nil { if !ok {
log.Printf("error exchanging code: %v", err) log.Println("empty code challenge", err)
http.Error(w, "error exchanging code", http.StatusInternalServerError) http.Error(w, "error reading session", http.StatusBadRequest)
return return
} }
token, err := h.oauth2Config.Exchange(context.Background(), code, oauth2.SetAuthURLParam("code_verifier", pkceVerifier.(string)))
if err != nil {
log.Printf("error exchanging code: %v", err)
http.Error(w, "error exchanging code", http.StatusForbidden)
return
}
client := h.oauth2Config.Client(context.Background(), token)
resp, err := client.Get("https://api.twitter.com/2/users/me")
if err != nil {
log.Printf("error fetching users/me: %v", err)
http.Error(w, "error fetching user", http.StatusInternalServerError)
return
}
defer resp.Body.Close()
// TODO: do something sensible
body, _ := io.ReadAll(resp.Body)
w.Header().Set("content-type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Header().Set("content-type", "text/html") if _, err = w.Write([]byte(body)); err != nil {
if _, err = w.Write([]byte("ok")); err != nil {
log.Printf("error writing response: %v", err) log.Printf("error writing response: %v", err)
} }
} }
@ -93,3 +158,19 @@ func Start(cfg config.Config) error {
return http.ListenAndServe(":8000", r) return http.ListenAndServe(":8000", r)
} }
// https://stackoverflow.com/questions/22892120/how-to-generate-a-random-string-of-a-fixed-length-in-go
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
func randSeq(n int) string {
b := make([]rune, n)
for i := range b {
b[i] = letters[rand.Intn(len(letters))]
}
return string(b)
}
func encodeSHA256(s string) string {
enc := sha256.Sum256([]byte(s))
return base64.RawURLEncoding.EncodeToString(enc[:])
}

View File

@ -2,13 +2,20 @@ package main
import ( import (
"log" "log"
"math/rand"
"time"
"git.netflux.io/rob/elon-eats-my-tweets/config" "git.netflux.io/rob/elon-eats-my-tweets/config"
"git.netflux.io/rob/elon-eats-my-tweets/httpserver" "git.netflux.io/rob/elon-eats-my-tweets/httpserver"
) )
func main() { func main() {
c := config.NewFromEnv() rand.Seed(time.Now().UnixNano())
c, err := config.NewFromEnv()
if err != nil {
log.Fatal(err)
}
log.Fatal(httpserver.Start(c)) log.Fatal(httpserver.Start(c))
} }