package httpserver_test

import (
	"errors"
	"html/template"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"path/filepath"
	"testing"

	"git.netflux.io/rob/elon-eats-my-tweets/config"
	"git.netflux.io/rob/elon-eats-my-tweets/generated/mocks"
	"git.netflux.io/rob/elon-eats-my-tweets/httpserver"
	"github.com/gorilla/sessions"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap"
)

// templates holds every *.html view under ../public/views, parsed once and
// shared by all tests in this package. template.Must panics during package
// initialisation if the glob cannot be parsed.
var templates = template.Must(template.ParseGlob(filepath.Join("..", "public", "views", "*.html")))

// mockTokenGenerator implements httpserver.TokenGenerator.
//
// Each call returns the next entry of tokens, in order; once the slice is
// exhausted it returns the empty string. The requested length is ignored.
type mockTokenGenerator struct {
	i      int
	tokens []string
}

// GenerateToken returns the next queued token, or "" when none remain.
func (g *mockTokenGenerator) GenerateToken(_ int) string {
	i := g.i
	if len(g.tokens) <= i {
		return ""
	}
	g.i++
	return g.tokens[i]
}

// TestGetIndex checks that GET / responds 200 and renders the sign-in link.
func TestGetIndex(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()

	handler := httpserver.NewHandler(
		config.Config{},
		templates,
		&mocks.Store{},
		&mockTokenGenerator{},
		zap.NewNop(),
	)

	handler.ServeHTTP(rec, req)
	res := rec.Result()
	defer res.Body.Close()
	body, err := ioutil.ReadAll(res.Body)
	require.NoError(t, err)

	assert.Equal(t, http.StatusOK, res.StatusCode)
	assert.Contains(t, string(body), "Sign in with Twitter")
}

// TestLogin checks that GET /login redirects to the configured authorize URL
// (with an S256 code_challenge and the generated state), and that a failure
// to save the session results in a 500 with a generic error body.
func TestLogin(t *testing.T) {
	testCases := []struct {
		name             string
		sessionSaveError error
		wantStatusCode   int
		wantLocation     string
		wantRespBody     string
	}{
		{
			name:           "successful login",
			wantStatusCode: http.StatusFound,
			wantLocation:   "https://www.example.com/oauth/authorize?client_id=foo&code_challenge=RdoE4fOeAO8YelxeEEd70qNDoVgzl4844utzxsozlR4&code_challenge_method=S256&redirect_uri=https%3A%2F%2Fwww.example.com%2Fcallback&response_type=code&scope=tweet.read+tweet.write+users.read+offline.access&state=state",
		},
		{
			name:             "error saving session",
			wantStatusCode:   http.StatusInternalServerError,
			sessionSaveError: errors.New("boom"),
			wantRespBody:     "unexpected error",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			var mockStore mocks.Store
			sess := sessions.NewSession(&mockStore, "elon_session")
			mockStore.On("Get", mock.Anything, "elon_session").Return(sess, nil)
			mockStore.On("Save", mock.Anything, mock.Anything, sess).Return(tc.sessionSaveError)

			handler := httpserver.NewHandler(
				config.Config{
					Twitter: config.TwitterConfig{
						ClientID:     "foo",
						ClientSecret: "bar",
						CallbackURL:  "https://www.example.com/callback",
						AuthorizeURL: "https://www.example.com/oauth/authorize",
						TokenURL:     "https://www.example.com/oauth/token",
					},
				},
				templates,
				&mockStore,
				// The "state" token surfaces as the state query parameter in
				// wantLocation; "pkceVerifier" is presumably the PKCE verifier
				// behind the expected code_challenge.
				&mockTokenGenerator{tokens: []string{"state", "pkceVerifier"}},
				zap.NewNop(),
			)

			req := httptest.NewRequest(http.MethodGet, "/login", nil)
			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, req)
			res := rec.Result()
			defer res.Body.Close()

			assert.Equal(t, tc.wantStatusCode, res.StatusCode)
			if tc.wantRespBody == "" {
				assert.Equal(t, tc.wantLocation, res.Header.Get("Location"))
			} else {
				body, err := ioutil.ReadAll(res.Body)
				require.NoError(t, err)
				assert.Contains(t, string(body), tc.wantRespBody)
			}
		})
	}
}

// TestCallback exercises GET /callback against a fake OAuth2 token endpoint.
// It covers session read failures, missing/mismatched state, an empty code,
// an empty PKCE verifier, a failed code exchange and a successful exchange.
func TestCallback(t *testing.T) {
	testCases := []struct {
		name                string
		state               string
		sessionState        string
		sessionPkceVerifier string
		code                string
		sessionReadError    error
		oauth2StatusCode    int
		wantStatusCode      int
		wantError           string
	}{
		{
			name:             "unable to read session",
			sessionReadError: errors.New("boom"),
			wantStatusCode:   http.StatusBadRequest,
			wantError:        "error reading session",
		},
		{
			name:           "missing state",
			state:          "mystate",
			sessionState:   "",
			wantStatusCode: http.StatusBadRequest,
			wantError:      "error validating request",
		},
		{
			name:           "unexpected state",
			state:          "mystate",
			sessionState:   "foostate",
			wantStatusCode: http.StatusBadRequest,
			wantError:      "error validating request",
		},
		{
			name:           "empty code",
			state:          "mystate",
			sessionState:   "mystate",
			code:           "",
			wantStatusCode: http.StatusBadRequest,
			wantError:      "invalid code",
		},
		{
			name:                "empty pkce verifier",
			state:               "mystate",
			sessionState:        "mystate",
			code:                "mycode",
			sessionPkceVerifier: "",
			wantStatusCode:      http.StatusBadRequest,
			wantError:           "error reading session",
		},
		{
			name:                "error exchanging code",
			state:               "mystate",
			sessionState:        "mystate",
			code:                "mycode",
			sessionPkceVerifier: "mypkceverifier",
			oauth2StatusCode:    http.StatusInternalServerError,
			wantStatusCode:      http.StatusForbidden,
			wantError:           "error exchanging code",
		},
		{
			name:                "successful exchange",
			state:               "mystate",
			sessionState:        "mystate",
			code:                "mycode",
			sessionPkceVerifier: "mypkceverifier",
			oauth2StatusCode:    http.StatusOK,
			wantStatusCode:      http.StatusOK,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			var mockStore mocks.Store
			sess := sessions.NewSession(&mockStore, "elon_session")
			sess.Values["state"] = tc.sessionState
			sess.Values["pkce_verifier"] = tc.sessionPkceVerifier
			mockStore.On("Get", mock.Anything, "elon_session").Return(sess, tc.sessionReadError)

			const callbackURL = "https://www.example.com/callback"

			// oauthHandler fakes the token endpoint. It runs on the test
			// server's goroutine, so it must only use assert (t.Errorf);
			// require (t.FailNow) is valid only on the test goroutine.
			oauthHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				assert.Equal(t, http.MethodPost, r.Method)
				assert.Equal(t, "/oauth/token", r.URL.Path)
				assert.NoError(t, r.ParseForm())
				assert.Equal(t, tc.code, r.PostFormValue("code"))
				assert.Equal(t, tc.sessionPkceVerifier, r.PostFormValue("code_verifier"))
				assert.Equal(t, "authorization_code", r.PostFormValue("grant_type"))
				assert.Equal(t, callbackURL, r.PostFormValue("redirect_uri"))

				if tc.oauth2StatusCode != http.StatusOK {
					w.WriteHeader(tc.oauth2StatusCode)
					return
				}

				// Form-encoded token response; the Content-Type is set
				// explicitly rather than relying on content sniffing.
				w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
				w.WriteHeader(http.StatusOK)
				_, err := w.Write([]byte("access_token=foo&expires_in=3600&refresh_token=&token_type=bearer"))
				assert.NoError(t, err)
			})
			srv := httptest.NewServer(oauthHandler)
			// Shut the fake token server down when the subtest ends so each
			// case does not leak a listener and its goroutines.
			defer srv.Close()

			handler := httpserver.NewHandler(
				config.Config{
					Twitter: config.TwitterConfig{
						ClientID:     "foo",
						ClientSecret: "bar",
						CallbackURL:  callbackURL,
						AuthorizeURL: srv.URL + "/oauth/authorize",
						TokenURL:     srv.URL + "/oauth/token",
					},
				},
				templates,
				&mockStore,
				nil,
				zap.NewNop(),
			)

			req := httptest.NewRequest(http.MethodGet, "/callback?state="+tc.state+"&code="+tc.code, nil)
			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, req)
			res := rec.Result()
			defer res.Body.Close()

			assert.Equal(t, tc.wantStatusCode, res.StatusCode)
			if tc.wantError != "" {
				body, err := ioutil.ReadAll(res.Body)
				require.NoError(t, err)
				assert.Contains(t, string(body), tc.wantError)
			}
		})
	}
}