package httpserver_test
import (
"errors"
"html/template"
"io/ioutil"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"git.netflux.io/rob/elon-eats-my-tweets/config"
"git.netflux.io/rob/elon-eats-my-tweets/generated/mocks"
"git.netflux.io/rob/elon-eats-my-tweets/httpserver"
"github.com/gorilla/sessions"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
// templates holds every HTML view found under ../public/views, parsed once and
// shared by all handler tests in this package. The relative path resolves
// against the package directory, which `go test` uses as the working
// directory. Parsing failures abort the test binary at init via template.Must.
var templates = template.Must(template.ParseGlob(
	filepath.Join("..", "public", "views", "*.html"),
))
// mockTokenGenerator is a test double for httpserver.TokenGenerator. It hands
// out its preconfigured tokens in order, one per call, and returns the empty
// string once every token has been consumed.
type mockTokenGenerator struct {
	i      int      // index of the next token to hand out
	tokens []string // tokens to return, in call order
}

// GenerateToken returns the next queued token. The requested length is
// ignored. When the queue is exhausted it returns "" and does not advance.
func (g *mockTokenGenerator) GenerateToken(_ int) string {
	if g.i >= len(g.tokens) {
		return ""
	}
	next := g.tokens[g.i]
	g.i++
	return next
}
// TestGetIndex checks that GET / answers 200 OK and that the rendered page
// contains the "Sign in with Twitter" text.
func TestGetIndex(t *testing.T) {
	h := httpserver.NewHandler(
		config.Config{},
		templates,
		new(mocks.Store),
		new(mocks.SessionStore),
		new(mockTokenGenerator),
		zap.NewNop(),
	)

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	resp := rec.Result()
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	require.NoError(t, err)

	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Contains(t, string(body), "Sign in with Twitter")
}
func TestLogin(t *testing.T) {
testCases := []struct {
name string
sessionSaveError error
wantStatusCode int
wantLocation string
wantRespBody string
}{
{
name: "successful login",
wantStatusCode: http.StatusFound,
wantLocation: "https://www.example.com/oauth/authorize?client_id=foo&code_challenge=RdoE4fOeAO8YelxeEEd70qNDoVgzl4844utzxsozlR4&code_challenge_method=S256&redirect_uri=https%3A%2F%2Fwww.example.com%2Fcallback&response_type=code&scope=tweet.read+tweet.write+users.read+offline.access&state=state",
},
{
name: "error saving session",
wantStatusCode: http.StatusInternalServerError,
sessionSaveError: errors.New("boom"),
wantRespBody: "unexpected error",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var mockSessionStore mocks.SessionStore
sess := sessions.NewSession(&mockSessionStore, "elon_session")
mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, nil)
mockSessionStore.On("Save", mock.Anything, mock.Anything, sess).Return(tc.sessionSaveError)
handler := httpserver.NewHandler(
config.Config{
Twitter: config.TwitterConfig{
ClientID: "foo",
ClientSecret: "bar",
CallbackURL: "https://www.example.com/callback",
AuthorizeURL: "https://www.example.com/oauth/authorize",
TokenURL: "https://www.example.com/oauth/token",
},
},
templates,
&mocks.Store{},
&mockSessionStore,
&mockTokenGenerator{tokens: []string{"state", "pkceVerifier"}},
zap.NewNop(),
)
req := httptest.NewRequest(http.MethodGet, "/login", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
res := rec.Result()
defer res.Body.Close()
assert.Equal(t, tc.wantStatusCode, res.StatusCode)
if tc.wantRespBody == "" {
assert.Equal(t, tc.wantLocation, res.Header.Get("Location"))
}
if tc.wantRespBody != "" {
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
assert.Contains(t, string(body), tc.wantRespBody)
}
})
}
}
// TestCallback exercises GET /callback, the OAuth2 redirect target. Each case
// does the following:
//   - seeds the session with a stored state and PKCE verifier (or makes
//     reading the session fail);
//   - points the handler at a fake token endpoint served by httptest;
//   - checks the response status and, when wantError is set, that the
//     response body contains it.
func TestCallback(t *testing.T) {
	testCases := []struct {
		name                string
		state               string
		sessionState        string
		sessionPkceVerifier string
		code                string
		sessionReadError    error
		oauth2StatusCode    int
		wantStatusCode      int
		wantError           string
	}{
		{
			name:             "unable to read session",
			sessionReadError: errors.New("boom"),
			wantStatusCode:   http.StatusBadRequest,
			wantError:        "error reading session",
		},
		{
			name:           "missing state",
			state:          "mystate",
			sessionState:   "",
			wantStatusCode: http.StatusBadRequest,
			wantError:      "error validating request",
		},
		{
			name:           "unexpected state",
			state:          "mystate",
			sessionState:   "foostate",
			wantStatusCode: http.StatusBadRequest,
			wantError:      "error validating request",
		},
		{
			name:           "empty code",
			state:          "mystate",
			sessionState:   "mystate",
			code:           "",
			wantStatusCode: http.StatusBadRequest,
			wantError:      "invalid code",
		},
		{
			name:                "empty pkce verifier",
			state:               "mystate",
			sessionState:        "mystate",
			code:                "mycode",
			sessionPkceVerifier: "",
			wantStatusCode:      http.StatusBadRequest,
			wantError:           "error reading session",
		},
		{
			name:                "error exchanging code",
			state:               "mystate",
			sessionState:        "mystate",
			code:                "mycode",
			sessionPkceVerifier: "mypkceverifier",
			oauth2StatusCode:    http.StatusInternalServerError,
			wantStatusCode:      http.StatusForbidden,
			wantError:           "error exchanging code",
		},
		{
			name:                "successful exchange",
			state:               "mystate",
			sessionState:        "mystate",
			code:                "mycode",
			sessionPkceVerifier: "mypkceverifier",
			oauth2StatusCode:    http.StatusOK,
			wantStatusCode:      http.StatusOK,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			var mockStore mocks.Store
			var mockSessionStore mocks.SessionStore
			sess := sessions.NewSession(&mockSessionStore, "elon_session")
			sess.Values["state"] = tc.sessionState
			sess.Values["pkce_verifier"] = tc.sessionPkceVerifier
			mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, tc.sessionReadError)

			const callbackURL = "https://www.example.com/callback"

			// Fake token endpoint. It runs on the httptest server's goroutine, so it
			// must only use assert (t.Errorf). require would call t.FailNow, which
			// is only valid on the goroutine running the test.
			oauthHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				assert.Equal(t, http.MethodPost, r.Method)
				assert.Equal(t, "/oauth/token", r.URL.Path)
				if !assert.NoError(t, r.ParseForm()) {
					http.Error(w, "invalid form", http.StatusBadRequest)
					return
				}
				assert.Equal(t, tc.code, r.PostFormValue("code"))
				assert.Equal(t, tc.sessionPkceVerifier, r.PostFormValue("code_verifier"))
				assert.Equal(t, "authorization_code", r.PostFormValue("grant_type"))
				assert.Equal(t, callbackURL, r.PostFormValue("redirect_uri"))

				w.WriteHeader(tc.oauth2StatusCode)
				if tc.oauth2StatusCode == http.StatusOK {
					_, err := w.Write([]byte("access_token=foo&expires_in=3600&refresh_token=&token_type=bearer"))
					assert.NoError(t, err)
				}
			})
			srv := httptest.NewServer(oauthHandler)
			// Release the listener and serving goroutines when the subtest ends.
			defer srv.Close()

			handler := httpserver.NewHandler(
				config.Config{
					Twitter: config.TwitterConfig{
						ClientID:     "foo",
						ClientSecret: "bar",
						CallbackURL:  callbackURL,
						AuthorizeURL: srv.URL + "/oauth/authorize",
						TokenURL:     srv.URL + "/oauth/token",
					},
				},
				templates,
				&mockStore,
				&mockSessionStore,
				nil,
				zap.NewNop(),
			)

			req := httptest.NewRequest(http.MethodGet, "/callback?state="+tc.state+"&code="+tc.code, nil)
			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, req)

			res := rec.Result()
			defer res.Body.Close()

			assert.Equal(t, tc.wantStatusCode, res.StatusCode)
			if tc.wantError != "" {
				body, err := ioutil.ReadAll(res.Body)
				require.NoError(t, err)
				assert.Contains(t, string(body), tc.wantError)
			}
		})
	}
}