293 lines
8.9 KiB
Go
293 lines
8.9 KiB
Go
package httpserver_test
|
|
|
|
import (
|
|
"errors"
|
|
"html/template"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"git.netflux.io/rob/elon-eats-my-tweets/config"
|
|
"git.netflux.io/rob/elon-eats-my-tweets/generated/mocks"
|
|
"git.netflux.io/rob/elon-eats-my-tweets/generated/store"
|
|
"git.netflux.io/rob/elon-eats-my-tweets/httpserver"
|
|
"git.netflux.io/rob/elon-eats-my-tweets/twitter"
|
|
"github.com/gorilla/sessions"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// templates holds every HTML view matching ../public/views/*.html (relative
// to the working directory, which is the package directory under `go test`),
// parsed once at package initialisation so each test can pass the same set to
// httpserver.NewHandler. template.Must panics, aborting the test binary, if
// the glob matches no files or any view fails to parse.
var templates = template.Must(
	template.ParseGlob(
		filepath.Join("..", "public", "views", "*.html"),
	),
)
|
|
|
|
// mockTokenGenerator implements httpserver.TokenGenerator for tests. It hands
// out the entries of tokens in order, one per call, and returns the empty
// string once the list is exhausted.
type mockTokenGenerator struct {
	i      int      // index of the next token to hand out
	tokens []string // tokens returned by successive GenerateToken calls
}

// GenerateToken ignores the requested length and returns the next queued
// token, or "" when no tokens remain.
func (g *mockTokenGenerator) GenerateToken(_ int) string {
	if g.i >= len(g.tokens) {
		return ""
	}
	next := g.tokens[g.i]
	g.i++
	return next
}
|
|
|
|
func TestGetIndex(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
handler := httpserver.NewHandler(
|
|
config.Config{},
|
|
templates,
|
|
&mocks.Store{},
|
|
func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} },
|
|
&mocks.SessionStore{},
|
|
&mockTokenGenerator{},
|
|
zap.NewNop(),
|
|
)
|
|
|
|
handler.ServeHTTP(rec, req)
|
|
res := rec.Result()
|
|
defer res.Body.Close()
|
|
body, err := ioutil.ReadAll(res.Body)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, http.StatusOK, res.StatusCode)
|
|
assert.Contains(t, string(body), "Sign in with Twitter")
|
|
}
|
|
|
|
func TestLogin(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
sessionSaveError error
|
|
wantStatusCode int
|
|
wantLocation string
|
|
wantRespBody string
|
|
}{
|
|
{
|
|
name: "successful login",
|
|
wantStatusCode: http.StatusFound,
|
|
wantLocation: "https://www.example.com/oauth/authorize?client_id=foo&code_challenge=RdoE4fOeAO8YelxeEEd70qNDoVgzl4844utzxsozlR4&code_challenge_method=S256&redirect_uri=https%3A%2F%2Fwww.example.com%2Fcallback&response_type=code&scope=tweet.read+tweet.write+users.read+offline.access&state=state",
|
|
},
|
|
{
|
|
name: "error saving session",
|
|
wantStatusCode: http.StatusInternalServerError,
|
|
sessionSaveError: errors.New("boom"),
|
|
wantRespBody: "unexpected error",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
var mockSessionStore mocks.SessionStore
|
|
sess := sessions.NewSession(&mockSessionStore, "elon_session")
|
|
mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, nil)
|
|
mockSessionStore.On("Save", mock.Anything, mock.Anything, sess).Return(tc.sessionSaveError)
|
|
|
|
handler := httpserver.NewHandler(
|
|
config.Config{
|
|
Twitter: config.TwitterConfig{
|
|
ClientID: "foo",
|
|
ClientSecret: "bar",
|
|
CallbackURL: "https://www.example.com/callback",
|
|
AuthorizeURL: "https://www.example.com/oauth/authorize",
|
|
TokenURL: "https://www.example.com/oauth/token",
|
|
},
|
|
},
|
|
templates,
|
|
&mocks.Store{},
|
|
func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} },
|
|
&mockSessionStore,
|
|
&mockTokenGenerator{tokens: []string{"state", "pkceVerifier"}},
|
|
zap.NewNop(),
|
|
)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/login", nil)
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
res := rec.Result()
|
|
defer res.Body.Close()
|
|
|
|
assert.Equal(t, tc.wantStatusCode, res.StatusCode)
|
|
if tc.wantRespBody == "" {
|
|
assert.Equal(t, tc.wantLocation, res.Header.Get("Location"))
|
|
}
|
|
|
|
if tc.wantRespBody != "" {
|
|
body, err := ioutil.ReadAll(res.Body)
|
|
require.NoError(t, err)
|
|
assert.Contains(t, string(body), tc.wantRespBody)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCallback(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
state string
|
|
sessionState string
|
|
sessionPkceVerifier string
|
|
code string
|
|
sessionReadError error
|
|
oauth2StatusCode int
|
|
getTwitterUserError error
|
|
createUserError error
|
|
wantStatusCode int
|
|
wantError string
|
|
}{
|
|
{
|
|
name: "unable to read session",
|
|
sessionReadError: errors.New("boom"),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
wantError: "error reading session",
|
|
},
|
|
{
|
|
name: "missing state",
|
|
state: "mystate",
|
|
sessionState: "",
|
|
wantStatusCode: http.StatusBadRequest,
|
|
wantError: "error validating request",
|
|
},
|
|
{
|
|
name: "unexpected state",
|
|
state: "mystate",
|
|
sessionState: "foostate",
|
|
wantStatusCode: http.StatusBadRequest,
|
|
wantError: "error validating request",
|
|
},
|
|
{
|
|
name: "empty code",
|
|
state: "mystate",
|
|
sessionState: "mystate",
|
|
code: "",
|
|
wantStatusCode: http.StatusBadRequest,
|
|
wantError: "invalid code",
|
|
},
|
|
{
|
|
name: "empty pkce verifier",
|
|
state: "mystate",
|
|
sessionState: "mystate",
|
|
code: "mycode",
|
|
sessionPkceVerifier: "",
|
|
wantStatusCode: http.StatusBadRequest,
|
|
wantError: "error reading session",
|
|
},
|
|
{
|
|
name: "error exchanging code",
|
|
state: "mystate",
|
|
sessionState: "mystate",
|
|
code: "mycode",
|
|
sessionPkceVerifier: "mypkceverifier",
|
|
oauth2StatusCode: http.StatusInternalServerError,
|
|
wantStatusCode: http.StatusForbidden,
|
|
wantError: "error exchanging code",
|
|
},
|
|
{
|
|
name: "error fetching user from twitter",
|
|
state: "mystate",
|
|
sessionState: "mystate",
|
|
code: "mycode",
|
|
sessionPkceVerifier: "mypkceverifier",
|
|
oauth2StatusCode: http.StatusOK,
|
|
getTwitterUserError: errors.New("nothing to see here"),
|
|
wantStatusCode: http.StatusInternalServerError,
|
|
wantError: "error validating user",
|
|
},
|
|
{
|
|
name: "error storing user",
|
|
state: "mystate",
|
|
sessionState: "mystate",
|
|
code: "mycode",
|
|
sessionPkceVerifier: "mypkceverifier",
|
|
oauth2StatusCode: http.StatusOK,
|
|
createUserError: errors.New("oh no"),
|
|
wantStatusCode: http.StatusInternalServerError,
|
|
wantError: "error saving user",
|
|
},
|
|
{
|
|
name: "successful exchange",
|
|
state: "mystate",
|
|
sessionState: "mystate",
|
|
code: "mycode",
|
|
sessionPkceVerifier: "mypkceverifier",
|
|
oauth2StatusCode: http.StatusOK,
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
var mockSessionStore mocks.SessionStore
|
|
sess := sessions.NewSession(&mockSessionStore, "elon_session")
|
|
sess.Values["state"] = tc.sessionState
|
|
sess.Values["pkce_verifier"] = tc.sessionPkceVerifier
|
|
mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, tc.sessionReadError)
|
|
|
|
var mockTwitterClient mocks.TwitterAPIClient
|
|
mockTwitterClient.On("GetMe").Return(&twitter.User{ID: "1", Name: "foo", Username: "Foo Bar"}, tc.getTwitterUserError)
|
|
|
|
var mockStore mocks.Store
|
|
mockStore.On("CreateUser", mock.Anything, mock.MatchedBy(func(params store.CreateUserParams) bool {
|
|
return params.TwitterID == "1" && params.Name == "foo" && params.Username == "Foo Bar"
|
|
})).Return(store.User{}, tc.createUserError)
|
|
|
|
const callbackURL = "https://www.example.com/callback"
|
|
oauthHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
assert.Equal(t, http.MethodPost, r.Method)
|
|
assert.Equal(t, "/oauth/token", r.URL.Path)
|
|
|
|
require.NoError(t, r.ParseForm())
|
|
assert.Equal(t, tc.code, r.PostFormValue("code"))
|
|
assert.Equal(t, tc.sessionPkceVerifier, r.PostFormValue("code_verifier"))
|
|
assert.Equal(t, "authorization_code", r.PostFormValue("grant_type"))
|
|
assert.Equal(t, callbackURL, r.PostFormValue("redirect_uri"))
|
|
|
|
w.WriteHeader(tc.oauth2StatusCode)
|
|
if tc.oauth2StatusCode == http.StatusOK {
|
|
w.Write([]byte(`access_token=foo&expires_in=3600&refresh_token=&token_type=bearer"}`))
|
|
}
|
|
})
|
|
srv := httptest.NewServer(oauthHandler)
|
|
|
|
handler := httpserver.NewHandler(
|
|
config.Config{
|
|
Twitter: config.TwitterConfig{
|
|
ClientID: "foo",
|
|
ClientSecret: "bar",
|
|
CallbackURL: callbackURL,
|
|
AuthorizeURL: srv.URL + "/oauth/authorize",
|
|
TokenURL: srv.URL + "/oauth/token",
|
|
},
|
|
},
|
|
templates,
|
|
&mockStore,
|
|
func(*http.Client) httpserver.TwitterAPIClient { return &mockTwitterClient },
|
|
&mockSessionStore,
|
|
nil,
|
|
zap.NewNop(),
|
|
)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/callback?state="+tc.state+"&code="+tc.code, nil)
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
res := rec.Result()
|
|
defer res.Body.Close()
|
|
|
|
assert.Equal(t, tc.wantStatusCode, res.StatusCode)
|
|
|
|
if tc.wantError != "" {
|
|
body, err := ioutil.ReadAll(res.Body)
|
|
require.NoError(t, err)
|
|
assert.Contains(t, string(body), tc.wantError)
|
|
}
|
|
})
|
|
}
|
|
}
|