diff --git a/handler/handler.go b/handler/handler.go index 5dd78a9..e682f12 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -8,31 +8,47 @@ import ( //go:embed static/index.html var indexPage string -func New(matrixHostname, matrixBaseURL string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet && r.Method != http.MethodHead { - w.WriteHeader(http.StatusMethodNotAllowed) - w.Write([]byte("405 method not allowed\n")) - return - } +type Params struct { + MatrixHostname string + MatrixBaseURL string + RootPath string +} - switch r.URL.Path { - case "/.well-known/matrix/server": - w.Header().Add("content-type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"m.server": "` + matrixHostname + `"}`)) - case "/.well-known/matrix/client": - w.Header().Add("content-type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"m.homeserver": {"base_url": "` + matrixBaseURL + `"}}`)) - case "/": - w.Header().Add("content-type", "text/html") - w.WriteHeader(http.StatusOK) - if r.Method == http.MethodGet { - w.Write([]byte(indexPage)) - } - default: - http.NotFound(w, r) - } +type handler struct { + *http.ServeMux + params Params +} + +func New(params Params) http.Handler { + h := &handler{ + ServeMux: http.NewServeMux(), + params: params, + } + + h.Handle("GET /.well-known/matrix/server", http.HandlerFunc(h.getMatrixServer)) + h.Handle("GET /.well-known/matrix/client", http.HandlerFunc(h.getMatrixClient)) + h.Handle("GET /{$}", http.HandlerFunc(h.getHomepage)) + h.Handle("GET /", http.FileServer(http.Dir(params.RootPath))) + + return h +} + +func (h *handler) getMatrixServer(w http.ResponseWriter, r *http.Request) { + w.Header().Add("content-type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"m.server": "` + h.params.MatrixHostname + `"}`)) +} + +func (h *handler) getMatrixClient(w http.ResponseWriter, r *http.Request) { + w.Header().Add("content-type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"m.homeserver": {"base_url": "` + h.params.MatrixBaseURL + `"}}`)) +} + +func (h *handler) getHomepage(w http.ResponseWriter, r *http.Request) { + w.Header().Add("content-type", "text/html") + w.WriteHeader(http.StatusOK) + if r.Method == http.MethodGet { + w.Write([]byte(indexPage)) } } diff --git a/handler/handler_test.go b/handler/handler_test.go index e539b00..28aa533 100644 --- a/handler/handler_test.go +++ b/handler/handler_test.go @@ -1,7 +1,7 @@ package handler_test import ( - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -41,6 +41,14 @@ func TestHandler(t *testing.T) { wantStatusCode: http.StatusOK, wantBody: `{"m.homeserver": {"base_url": "https://foo.example.com"}}`, }, + { + name: "GET /test.html", + method: http.MethodGet, + path: "/test.html", + wantContentType: "text/html; charset=utf-8", + wantStatusCode: http.StatusOK, + wantBody: "\n \n Test\n \n\n", + }, { name: "GET /", method: http.MethodGet, @@ -69,7 +77,7 @@ func TestHandler(t *testing.T) { method: http.MethodPost, path: "/", wantStatusCode: http.StatusMethodNotAllowed, - wantBody: "405 method not allowed", + wantBody: "Method Not Allowed\n", }, } @@ -78,7 +86,13 @@ func TestHandler(t *testing.T) { req := httptest.NewRequest(tc.method, tc.path, nil) rec := httptest.NewRecorder() - h := handler.New(matrixHostname, matrixBaseURL) + h := handler.New( + handler.Params{ + MatrixHostname: matrixHostname, + MatrixBaseURL: matrixBaseURL, + RootPath: "testdata/static/", + }, + ) h.ServeHTTP(rec, req) resp := rec.Result() defer resp.Body.Close() @@ -89,7 +103,7 @@ func TestHandler(t *testing.T) { assert.Equal(t, tc.wantContentType, resp.Header.Get("content-type")) } - respBody, err := ioutil.ReadAll(resp.Body) + respBody, err := io.ReadAll(resp.Body) require.NoError(t, err) assert.Contains(t, string(respBody), tc.wantBody) }) diff --git a/handler/testdata/static/test.html b/handler/testdata/static/test.html new file mode 100644 index 0000000..8122974 --- /dev/null +++ b/handler/testdata/static/test.html @@ -0,0 +1,5 @@ + + + Test + + diff --git a/main.go b/main.go index 22fddfd..0f63327 100644 --- a/main.go +++ b/main.go @@ -18,8 +18,9 @@ const ( func main() { matrixHostname := os.Getenv("NETFLUX_MATRIX_HOSTNAME") matrixBaseURL := os.Getenv("NETFLUX_MATRIX_BASE_URL") - if matrixHostname == "" || matrixBaseURL == "" { - log.Fatal("NETFLUX_MATRIX_HOSTNAME and NETFLUX_MATRIX_BASE_URL are both required") + rootPath := os.Getenv("NETFLUX_ROOT_PATH") + if matrixHostname == "" || matrixBaseURL == "" || rootPath == "" { + log.Fatal("NETFLUX_MATRIX_HOSTNAME and NETFLUX_MATRIX_BASE_URL and NETFLUX_ROOT_PATH are all required") } listenAddr := os.Getenv("NETFLUX_LISTEN_ADDR") @@ -28,8 +29,12 @@ func main() { } server := http.Server{ - Addr: listenAddr, - Handler: handler.New(matrixHostname, matrixBaseURL), + Addr: listenAddr, + Handler: handler.New(handler.Params{ + MatrixHostname: matrixHostname, + MatrixBaseURL: matrixBaseURL, + RootPath: rootPath, + }), ReadTimeout: readTimeout, WriteTimeout: writeTimeout, }