package httpserver_test import ( "errors" "io/ioutil" "net/http" "net/http/httptest" "testing" "git.netflux.io/rob/elon-eats-my-tweets/config" "git.netflux.io/rob/elon-eats-my-tweets/generated/mocks" "git.netflux.io/rob/elon-eats-my-tweets/generated/store" "git.netflux.io/rob/elon-eats-my-tweets/httpserver" "git.netflux.io/rob/elon-eats-my-tweets/twitter" "github.com/google/uuid" "github.com/gorilla/sessions" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.uber.org/zap" ) // mockTokenGenerator implements httpserver.TokenGenerator. type mockTokenGenerator struct { i int tokens []string } func (g *mockTokenGenerator) GenerateToken(_ int) string { i := g.i if len(g.tokens) <= i { return "" } g.i++ return g.tokens[i] } func TestGetIndex(t *testing.T) { testCases := []struct { name string currentUser httpserver.CurrentUser wantBody string }{ { name: "logged out", currentUser: httpserver.CurrentUser{}, wantBody: "Sign in with Twitter", }, { name: "logged out", currentUser: httpserver.CurrentUser{ID: uuid.New(), TwitterName: "foo", TwitterUsername: "Foo Bar"}, wantBody: "Welcome to your dashboard", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() var mockSessionStore mocks.SessionStore sess := sessions.NewSession(&mockSessionStore, "elon_session") sess.Values["current_user"] = tc.currentUser mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, nil) mockSessionStore.On("Save", mock.Anything, mock.Anything, sess).Return(nil) handler := httpserver.NewHandler( config.Config{}, &mocks.Store{}, func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} }, &mockSessionStore, &mockTokenGenerator{}, zap.NewNop(), ) handler.ServeHTTP(rec, req) res := rec.Result() defer res.Body.Close() body, err := ioutil.ReadAll(res.Body) require.NoError(t, err) assert.Equal(t, http.StatusOK, res.StatusCode) assert.Contains(t, string(body), tc.wantBody) }) } } func TestGetLogin(t *testing.T) { testCases := []struct { name string sessionSaveError error wantStatusCode int wantLocation string wantRespBody string }{ { name: "successful login", wantStatusCode: http.StatusFound, wantLocation: "https://www.example.com/oauth/authorize?client_id=foo&code_challenge=RdoE4fOeAO8YelxeEEd70qNDoVgzl4844utzxsozlR4&code_challenge_method=S256&redirect_uri=https%3A%2F%2Fwww.example.com%2Fcallback&response_type=code&scope=tweet.read+tweet.write+users.read+offline.access&state=state", }, { name: "error saving session", wantStatusCode: http.StatusInternalServerError, sessionSaveError: errors.New("boom"), wantRespBody: "unexpected error", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var mockSessionStore mocks.SessionStore sess := sessions.NewSession(&mockSessionStore, "elon_session") mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, nil) mockSessionStore.On("Save", mock.Anything, mock.Anything, sess).Return(tc.sessionSaveError) handler := httpserver.NewHandler( config.Config{ Twitter: config.TwitterConfig{ ClientID: "foo", ClientSecret: "bar", CallbackURL: "https://www.example.com/callback", AuthorizeURL: "https://www.example.com/oauth/authorize", TokenURL: "https://www.example.com/oauth/token", }, }, &mocks.Store{}, func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} }, &mockSessionStore, &mockTokenGenerator{tokens: []string{"state", "pkceVerifier"}}, zap.NewNop(), ) req := httptest.NewRequest(http.MethodGet, "/login", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) res := rec.Result() defer res.Body.Close() assert.Equal(t, tc.wantStatusCode, res.StatusCode) if tc.wantRespBody == "" { assert.Equal(t, tc.wantLocation, res.Header.Get("Location")) } if tc.wantRespBody != "" { body, err := ioutil.ReadAll(res.Body) require.NoError(t, err) assert.Contains(t, string(body), tc.wantRespBody) } }) } } func TestGetCallback(t *testing.T) { testCases := []struct { name string state string sessionState string sessionPkceVerifier string code string sessionReadError error oauth2StatusCode int getTwitterUserError error createUserError error sessionSaveError error wantStatusCode int wantLocation string wantError string }{ { name: "unable to read session", sessionReadError: errors.New("boom"), wantStatusCode: http.StatusBadRequest, wantError: "error reading session", }, { name: "missing state", state: "mystate", sessionState: "", wantStatusCode: http.StatusBadRequest, wantError: "error validating request", }, { name: "unexpected state", state: "mystate", sessionState: "foostate", wantStatusCode: http.StatusBadRequest, wantError: "error validating request", }, { name: "empty code", state: "mystate", sessionState: "mystate", code: "", wantStatusCode: http.StatusBadRequest, wantError: "invalid code", }, { name: "empty pkce verifier", state: "mystate", sessionState: "mystate", code: "mycode", sessionPkceVerifier: "", wantStatusCode: http.StatusBadRequest, wantError: "error reading session", }, { name: "error exchanging code", state: "mystate", sessionState: "mystate", code: "mycode", sessionPkceVerifier: "mypkceverifier", oauth2StatusCode: http.StatusInternalServerError, wantStatusCode: http.StatusForbidden, wantError: "error exchanging code", }, { name: "error fetching user from twitter", state: "mystate", sessionState: "mystate", code: "mycode", sessionPkceVerifier: "mypkceverifier", oauth2StatusCode: http.StatusOK, getTwitterUserError: errors.New("nothing to see here"), wantStatusCode: http.StatusInternalServerError, wantError: "error validating user", }, { name: "error storing user", state: "mystate", sessionState: "mystate", code: "mycode", sessionPkceVerifier: "mypkceverifier", oauth2StatusCode: http.StatusOK, createUserError: errors.New("oh no"), wantStatusCode: http.StatusInternalServerError, wantError: "error saving user", }, { name: "error saving session", state: "mystate", sessionState: "mystate", code: "mycode", sessionPkceVerifier: "mypkceverifier", oauth2StatusCode: http.StatusOK, sessionSaveError: errors.New("foo"), wantStatusCode: http.StatusInternalServerError, wantError: "error saving user", }, { name: "successful exchange", state: "mystate", sessionState: "mystate", code: "mycode", sessionPkceVerifier: "mypkceverifier", oauth2StatusCode: http.StatusOK, wantStatusCode: http.StatusFound, wantLocation: "/", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var mockSessionStore mocks.SessionStore sess := sessions.NewSession(&mockSessionStore, "elon_session") sess.Values["state"] = tc.sessionState sess.Values["pkce_verifier"] = tc.sessionPkceVerifier mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, tc.sessionReadError) mockSessionStore.On("Save", mock.Anything, mock.Anything, sess).Return(tc.sessionSaveError) var mockTwitterClient mocks.TwitterAPIClient mockTwitterClient.On("GetMe").Return(&twitter.User{ID: "1", Name: "foo", Username: "Foo Bar"}, tc.getTwitterUserError) var mockStore mocks.Store mockStore.On("CreateUser", mock.Anything, mock.MatchedBy(func(params store.CreateUserParams) bool { return params.TwitterID == "1" && params.Name == "foo" && params.Username == "Foo Bar" })).Return(store.User{}, tc.createUserError) const callbackURL = "https://www.example.com/callback" oauthHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, http.MethodPost, r.Method) assert.Equal(t, "/oauth/token", r.URL.Path) require.NoError(t, r.ParseForm()) assert.Equal(t, tc.code, r.PostFormValue("code")) assert.Equal(t, tc.sessionPkceVerifier, r.PostFormValue("code_verifier")) assert.Equal(t, "authorization_code", r.PostFormValue("grant_type")) assert.Equal(t, callbackURL, r.PostFormValue("redirect_uri")) w.WriteHeader(tc.oauth2StatusCode) if tc.oauth2StatusCode == http.StatusOK { w.Write([]byte(`access_token=foo&expires_in=3600&refresh_token=&token_type=bearer"}`)) } }) srv := httptest.NewServer(oauthHandler) handler := httpserver.NewHandler( config.Config{ Twitter: config.TwitterConfig{ ClientID: "foo", ClientSecret: "bar", CallbackURL: callbackURL, AuthorizeURL: srv.URL + "/oauth/authorize", TokenURL: srv.URL + "/oauth/token", }, }, &mockStore, func(*http.Client) httpserver.TwitterAPIClient { return &mockTwitterClient }, &mockSessionStore, nil, zap.NewNop(), ) req := httptest.NewRequest(http.MethodGet, "/callback?state="+tc.state+"&code="+tc.code, nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) res := rec.Result() defer res.Body.Close() assert.Equal(t, tc.wantStatusCode, res.StatusCode) if tc.wantError != "" { body, err := ioutil.ReadAll(res.Body) require.NoError(t, err) assert.Contains(t, string(body), tc.wantError) } }) } } func TestPostLogout(t *testing.T) { var mockSessionStore mocks.SessionStore sess := sessions.NewSession(&mockSessionStore, "elon_session") mockSessionStore.On("Get", mock.Anything, "elon_session").Return(sess, nil) mockSessionStore.On("Save", mock.Anything, mock.Anything, mock.MatchedBy(func(s *sessions.Session) bool { _, ok := s.Values["current_user"] return !ok })).Return(nil) handler := httpserver.NewHandler( config.Config{}, &mocks.Store{}, func(*http.Client) httpserver.TwitterAPIClient { return &mocks.TwitterAPIClient{} }, &mockSessionStore, nil, zap.NewNop(), ) req := httptest.NewRequest(http.MethodPost, "/logout", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) res := rec.Result() defer res.Body.Close() assert.Equal(t, http.StatusFound, res.StatusCode) assert.Equal(t, "/", res.Header.Get("Location")) }