diff --git a/frontend/api.go b/frontend/api.go index 9d2380e..3cd5e9e 100644 --- a/frontend/api.go +++ b/frontend/api.go @@ -113,7 +113,7 @@ func apiHandler(w http.ResponseWriter, r *http.Request) { } else { handler := apiHandlerMap[request.Type] if handler == nil { - response = apiErrorHandler(errors.New("Invalid request type")) + response = apiErrorHandler(errors.New("invalid request type")) } else { response = handler(request) } diff --git a/frontend/api_test.go b/frontend/api_test.go new file mode 100644 index 0000000..d985226 --- /dev/null +++ b/frontend/api_test.go @@ -0,0 +1,207 @@ +package main + +import ( + "bytes" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/jarcoal/httpmock" + "github.com/magiconair/properties/assert" +) + +func TestApiServerListHandler(t *testing.T) { + setting.servers = []string{"alpha", "beta", "gamma"} + response := apiServerListHandler(apiRequest{}) + + assert.Equal(t, len(response.Result), 3) + assert.Equal(t, response.Result[0].(apiGenericResultPair).Server, "alpha") + assert.Equal(t, response.Result[1].(apiGenericResultPair).Server, "beta") + assert.Equal(t, response.Result[2].(apiGenericResultPair).Server, "gamma") +} + +func TestApiGenericHandlerFactory(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpResponse := httpmock.NewStringResponder(200, BirdSummaryData) + httpmock.RegisterResponder("GET", "http://alpha:8000/bird?q="+url.QueryEscape("show protocols"), httpResponse) + + setting.servers = []string{"alpha"} + setting.domain = "" + setting.proxyPort = 8000 + + request := apiRequest{ + Servers: setting.servers, + Type: "bird", + Args: "show protocols", + } + + handler := apiGenericHandlerFactory("bird") + response := handler(request) + + assert.Equal(t, response.Error, "") + + result := response.Result[0].(*apiGenericResultPair) + assert.Equal(t, result.Server, "alpha") + assert.Equal(t, result.Data, BirdSummaryData) +} + +func TestApiSummaryHandler(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpResponse := httpmock.NewStringResponder(200, BirdSummaryData) + httpmock.RegisterResponder("GET", "http://alpha:8000/bird?q="+url.QueryEscape("show protocols"), httpResponse) + + setting.servers = []string{"alpha"} + setting.domain = "" + setting.proxyPort = 8000 + + request := apiRequest{ + Servers: setting.servers, + Type: "summary", + Args: "", + } + response := apiSummaryHandler(request) + + assert.Equal(t, response.Error, "") + + summary := response.Result[0].(*apiSummaryResultPair) + assert.Equal(t, summary.Server, "alpha") + // Protocol list will be sorted + assert.Equal(t, summary.Data[1].Name, "device1") + assert.Equal(t, summary.Data[1].Proto, "Device") + assert.Equal(t, summary.Data[1].Table, "---") + assert.Equal(t, summary.Data[1].State, "up") + assert.Equal(t, summary.Data[1].Since, "2021-08-27") + assert.Equal(t, summary.Data[1].Info, "") +} + +func TestApiSummaryHandlerError(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpResponse := httpmock.NewStringResponder(200, "Mock backend error") + httpmock.RegisterResponder("GET", "http://alpha:8000/bird?q="+url.QueryEscape("show protocols"), httpResponse) + + setting.servers = []string{"alpha"} + setting.domain = "" + setting.proxyPort = 8000 + + request := apiRequest{ + Servers: setting.servers, + Type: "summary", + Args: "", + } + response := apiSummaryHandler(request) + + assert.Equal(t, response.Error, "Mock backend error") +} + +func TestApiWhoisHandler(t *testing.T) { + expectedData := "Mock Data" + server := WhoisServer{ + t: t, + expectedQuery: "AS6939", + response: expectedData, + } + + server.Listen() + go server.Run() + defer server.Close() + + setting.whoisServer = server.server.Addr().String() + + request := apiRequest{ + Servers: []string{}, + Type: "", + Args: "AS6939", + } + response := apiWhoisHandler(request) + + assert.Equal(t, response.Error, "") + + whoisResult := response.Result[0].(apiGenericResultPair) + assert.Equal(t, whoisResult.Server, "") + assert.Equal(t, whoisResult.Data, expectedData) +} + +func TestApiErrorHandler(t *testing.T) { + err := errors.New("Mock Error") + response := apiErrorHandler(err) + assert.Equal(t, response.Error, "Mock Error") +} + +func TestApiHandler(t *testing.T) { + setting.servers = []string{"alpha", "beta", "gamma"} + + request := apiRequest{ + Servers: []string{}, + Type: "server_list", + Args: "", + } + requestJson, err := json.Marshal(request) + if err != nil { + t.Error(err) + } + + r := httptest.NewRequest(http.MethodGet, "/api", bytes.NewReader(requestJson)) + w := httptest.NewRecorder() + apiHandler(w, r) + + var response apiResponse + err = json.Unmarshal(w.Body.Bytes(), &response) + if err != nil { + t.Error(err) + } + + assert.Equal(t, len(response.Result), 3) + // Hard to unmarshal JSON into apiGenericResultPair objects, won't check here +} + +func TestApiHandlerBadJSON(t *testing.T) { + setting.servers = []string{"alpha", "beta", "gamma"} + + r := httptest.NewRequest(http.MethodGet, "/api", strings.NewReader("{bad json}")) + w := httptest.NewRecorder() + apiHandler(w, r) + + var response apiResponse + err := json.Unmarshal(w.Body.Bytes(), &response) + if err != nil { + t.Error(err) + } + + assert.Equal(t, len(response.Result), 0) +} + +func TestApiHandlerInvalidType(t *testing.T) { + setting.servers = []string{"alpha", "beta", "gamma"} + + request := apiRequest{ + Servers: setting.servers, + Type: "invalid_type", + Args: "", + } + requestJson, err := json.Marshal(request) + if err != nil { + t.Error(err) + } + + r := httptest.NewRequest(http.MethodGet, "/api", bytes.NewReader(requestJson)) + w := httptest.NewRecorder() + apiHandler(w, r) + + var response apiResponse + err = json.Unmarshal(w.Body.Bytes(), &response) + if err != nil { + t.Error(err) + } + + assert.Equal(t, len(response.Result), 0) +} diff --git a/frontend/asn_cache_test.go b/frontend/asn_cache_test.go index 26af449..2ba914d 100644 --- a/frontend/asn_cache_test.go +++ b/frontend/asn_cache_test.go @@ -3,6 +3,8 @@ package main import ( "strings" "testing" + + "github.com/magiconair/properties/assert" ) func TestGetASNRepresentationDNS(t *testing.T) { @@ -34,7 +36,5 @@ func TestGetASNRepresentationFallback(t *testing.T) { setting.whoisServer = "" cache := make(ASNCache) result := cache.Lookup("6939") - if result != "AS6939" { - t.Errorf("Lookup AS6939 failed, got %s", result) - } + assert.Equal(t, result, "AS6939") } diff --git a/frontend/bgpmap_test.go b/frontend/bgpmap_test.go index 81db0ff..3450410 100644 --- a/frontend/bgpmap_test.go +++ b/frontend/bgpmap_test.go @@ -8,14 +8,14 @@ import ( "testing" ) -func readDataFile(filename string) string { +func readDataFile(t *testing.T, filename string) string { _, sourceName, _, _ := runtime.Caller(0) projectRoot := path.Join(path.Dir(sourceName), "..") dir := path.Join(projectRoot, filename) data, err := ioutil.ReadFile(dir) if err != nil { - panic(err) + t.Fatal(err) } return string(data) } @@ -41,7 +41,7 @@ func TestBirdRouteToGraphvizXSS(t *testing.T) { func TestBirdRouteToGraph(t *testing.T) { setting.dnsInterface = "" - input := readDataFile("frontend/test_data/bgpmap_case1.txt") + input := readDataFile(t, "frontend/test_data/bgpmap_case1.txt") result := birdRouteToGraph([]string{"node"}, []string{input}, "target") // Source node must exist @@ -71,3 +71,14 @@ func TestBirdRouteToGraph(t *testing.T) { t.Error("Result doesn't contain edge from 4242423914 to target") } } + +func TestBirdRouteToGraphviz(t *testing.T) { + setting.dnsInterface = "" + + input := readDataFile(t, "frontend/test_data/bgpmap_case1.txt") + result := birdRouteToGraphviz([]string{"node"}, []string{input}, "target") + + if !strings.Contains(result, "digraph {") { + t.Error("Response is not Graphviz data") + } +} diff --git a/frontend/dn42.go b/frontend/dn42.go index d5aeaba..5995b94 100644 --- a/frontend/dn42.go +++ b/frontend/dn42.go @@ -65,7 +65,7 @@ func shortenWhoisFilter(whois string) string { shouldSkip := false shouldSkip = shouldSkip || len(s) == 0 shouldSkip = shouldSkip || len(s) > 0 && s[0] == '#' - shouldSkip = shouldSkip || strings.Contains(strings.ToUpper(s), "REDACTED FOR PRIVACY") + shouldSkip = shouldSkip || strings.Contains(strings.ToUpper(s), "REDACTED") if shouldSkip { skippedLinesLonger++ diff --git a/frontend/dn42_test.go b/frontend/dn42_test.go index c0b253e..a258538 100644 --- a/frontend/dn42_test.go +++ b/frontend/dn42_test.go @@ -28,3 +28,79 @@ func TestDN42WhoisFilterUnneeded(t *testing.T) { t.Errorf("Output doesn't match expected: %s", result) } } + +func TestShortenWhoisFilterShorterMode(t *testing.T) { + input := ` +Information line that will be removed + +# Comment that will be removed +Name: Redacted for privacy +Descr: This is a vvvvvvvvvvvvvvvvvvvvvvveeeeeeeeeeeeeeeeeeeerrrrrrrrrrrrrrrrrrrrrrrryyyyyyyyyyyyyyyyyyy long line that will be skipped. +Looooooooooooooooooooooong key: this line will be skipped. + +Preserved1: this line isn't removed. +Preserved2: this line isn't removed. +Preserved3: this line isn't removed. +Preserved4: this line isn't removed. +Preserved5: this line isn't removed. + +` + + result := shortenWhoisFilter(input) + + expectedResult := `Preserved1: this line isn't removed. +Preserved2: this line isn't removed. +Preserved3: this line isn't removed. +Preserved4: this line isn't removed. +Preserved5: this line isn't removed. + +3 line(s) skipped. +` + + if result != expectedResult { + t.Errorf("Output doesn't match expected: %s", result) + } +} + +func TestShortenWhoisFilterLongerMode(t *testing.T) { + input := ` +Information line that will be removed + +# Comment that will be removed +Name: Redacted for privacy +Descr: This is a vvvvvvvvvvvvvvvvvvvvvvveeeeeeeeeeeeeeeeeeeerrrrrrrrrrrrrrrrrrrrrrrryyyyyyyyyyyyyyyyyyy long line that will be skipped. +Looooooooooooooooooooooong key: this line will be skipped. + +Preserved1: this line isn't removed. + +` + + result := shortenWhoisFilter(input) + + expectedResult := `Information line that will be removed +Descr: This is a vvvvvvvvvvvvvvvvvvvvvvveeeeeeeeeeeeeeeeeeeerrrrrrrrrrrrrrrrrrrrrrrryyyyyyyyyyyyyyyyyyy long line that will be skipped. +Looooooooooooooooooooooong key: this line will be skipped. +Preserved1: this line isn't removed. + +7 line(s) skipped. +` + + if result != expectedResult { + t.Errorf("Output doesn't match expected: %s", result) + } +} + +func TestShortenWhoisFilterSkipNothing(t *testing.T) { + input := `Preserved1: this line isn't removed. +Preserved2: this line isn't removed. +Preserved3: this line isn't removed. +Preserved4: this line isn't removed. +Preserved5: this line isn't removed. +` + + result := shortenWhoisFilter(input) + + if result != input { + t.Errorf("Output doesn't match expected: %s", result) + } +} diff --git a/frontend/go.mod b/frontend/go.mod index 860db2d..8686a63 100644 --- a/frontend/go.mod +++ b/frontend/go.mod @@ -11,7 +11,9 @@ require ( require ( github.com/felixge/httpsnoop v1.0.3 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/hashicorp/hcl v1.0.0 // indirect + github.com/jarcoal/httpmock v1.3.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.0.6 // indirect diff --git a/frontend/go.sum b/frontend/go.sum index ce71e67..9dfd31c 100644 --- a/frontend/go.sum +++ b/frontend/go.sum @@ -533,6 +533,8 @@ github.com/google/pprof v0.0.0-20210601050228-01bbb1931b22/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.0.0-20220520183353-fd19c99a87aa/go.mod h1:17drOmN3MwGY7t0e+Ei9b45FFGA3fBs3x36SsCg1hq8= @@ -588,6 +590,8 @@ github.com/hashicorp/memberlist v0.5.0/go.mod h1:yvyXLpo0QaGE59Y7hDTsTzDD25JYBZ4 github.com/hashicorp/serf v0.10.1/go.mod h1:yL2t6BqATOLGc5HF7qbFkTfXoPIY0WZdWHfEvMqbG+4= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/jarcoal/httpmock v1.3.0 h1:2RJ8GP0IIaWwcC9Fp2BmVi8Kog3v2Hn7VXM3fTd+nuc= +github.com/jarcoal/httpmock v1.3.0/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= diff --git a/frontend/lgproxy.go b/frontend/lgproxy.go index e973f85..e7c08e7 100644 --- a/frontend/lgproxy.go +++ b/frontend/lgproxy.go @@ -56,8 +56,8 @@ func batchRequest(servers []string, endpoint string, command string) []string { buf := make([]byte, 65536) n, err := io.ReadFull(response.Body, buf) - if err != nil && err != io.ErrUnexpectedEOF { - ch <- channelData{i, err.Error()} + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { + ch <- channelData{i, "request failed: " + err.Error()} } else { ch <- channelData{i, string(buf[:n])} } diff --git a/frontend/lgproxy_test.go b/frontend/lgproxy_test.go new file mode 100644 index 0000000..7a2e098 --- /dev/null +++ b/frontend/lgproxy_test.go @@ -0,0 +1,163 @@ +package main + +import ( + "errors" + "strings" + "testing" + + "github.com/jarcoal/httpmock" +) + +func TestBatchRequestIPv4(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpResponse := httpmock.NewStringResponder(200, "Mock Result") + httpmock.RegisterResponder("GET", "http://1.1.1.1:8000/mock?q=cmd", httpResponse) + httpmock.RegisterResponder("GET", "http://2.2.2.2:8000/mock?q=cmd", httpResponse) + httpmock.RegisterResponder("GET", "http://3.3.3.3:8000/mock?q=cmd", httpResponse) + + setting.servers = []string{ + "1.1.1.1", + "2.2.2.2", + "3.3.3.3", + } + setting.domain = "" + setting.proxyPort = 8000 + response := batchRequest(setting.servers, "mock", "cmd") + + if len(response) != 3 { + t.Error("Did not get response of all three mock servers") + } + for i := 0; i < len(response); i++ { + if response[i] != "Mock Result" { + t.Error("HTTP response mismatch") + } + } +} + +func TestBatchRequestIPv6(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpResponse := httpmock.NewStringResponder(200, "Mock Result") + httpmock.RegisterResponder("GET", "http://[2001:db8::1]:8000/mock?q=cmd", httpResponse) + httpmock.RegisterResponder("GET", "http://[2001:db8::2]:8000/mock?q=cmd", httpResponse) + httpmock.RegisterResponder("GET", "http://[2001:db8::3]:8000/mock?q=cmd", httpResponse) + + setting.servers = []string{ + "2001:db8::1", + "2001:db8::2", + "2001:db8::3", + } + setting.domain = "" + setting.proxyPort = 8000 + response := batchRequest(setting.servers, "mock", "cmd") + + if len(response) != 3 { + t.Error("Did not get response of all three mock servers") + } + for i := 0; i < len(response); i++ { + if response[i] != "Mock Result" { + t.Error("HTTP response mismatch") + } + } +} + +func TestBatchRequestEmptyResponse(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpResponse := httpmock.NewStringResponder(200, "") + httpmock.RegisterResponder("GET", "http://alpha:8000/mock?q=cmd", httpResponse) + httpmock.RegisterResponder("GET", "http://beta:8000/mock?q=cmd", httpResponse) + httpmock.RegisterResponder("GET", "http://gamma:8000/mock?q=cmd", httpResponse) + + setting.servers = []string{ + "alpha", + "beta", + "gamma", + } + setting.domain = "" + setting.proxyPort = 8000 + response := batchRequest(setting.servers, "mock", "cmd") + + if len(response) != 3 { + t.Error("Did not get response of all three mock servers") + } + for i := 0; i < len(response); i++ { + if !strings.Contains(response[i], "node returned empty response") { + t.Error("Did not produce error for empty response") + } + } +} + +func TestBatchRequestDomainSuffix(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpResponse := httpmock.NewStringResponder(200, "Mock Result") + httpmock.RegisterResponder("GET", "http://alpha.suffix:8000/mock?q=cmd", httpResponse) + httpmock.RegisterResponder("GET", "http://beta.suffix:8000/mock?q=cmd", httpResponse) + httpmock.RegisterResponder("GET", "http://gamma.suffix:8000/mock?q=cmd", httpResponse) + + setting.servers = []string{ + "alpha", + "beta", + "gamma", + } + setting.domain = "suffix" + setting.proxyPort = 8000 + response := batchRequest(setting.servers, "mock", "cmd") + + if len(response) != 3 { + t.Error("Did not get response of all three mock servers") + } + for i := 0; i < len(response); i++ { + if response[i] != "Mock Result" { + t.Error("HTTP response mismatch") + } + } +} + +func TestBatchRequestHTTPError(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpError := httpmock.NewErrorResponder(errors.New("Oops!")) + httpmock.RegisterResponder("GET", "http://alpha:8000/mock?q=cmd", httpError) + httpmock.RegisterResponder("GET", "http://beta:8000/mock?q=cmd", httpError) + httpmock.RegisterResponder("GET", "http://gamma:8000/mock?q=cmd", httpError) + + setting.servers = []string{ + "alpha", + "beta", + "gamma", + } + setting.domain = "" + setting.proxyPort = 8000 + response := batchRequest(setting.servers, "mock", "cmd") + + if len(response) != 3 { + t.Error("Did not get response of all three mock servers") + } + for i := 0; i < len(response); i++ { + if !strings.Contains(response[i], "request failed") { + t.Error("Did not produce HTTP error") + } + } +} + +func TestBatchRequestInvalidServer(t *testing.T) { + setting.servers = []string{} + setting.domain = "" + setting.proxyPort = 8000 + response := batchRequest([]string{"invalid"}, "mock", "cmd") + + if len(response) != 1 { + t.Error("Did not get response of all mock servers") + } + if !strings.Contains(response[0], "invalid server") { + t.Error("Did not produce invalid server error") + } +} diff --git a/frontend/render_test.go b/frontend/render_test.go index e3a8cc7..a923963 100644 --- a/frontend/render_test.go +++ b/frontend/render_test.go @@ -8,6 +8,17 @@ import ( "testing" ) +const BirdSummaryData = `BIRD 2.0.8 ready. +Name Proto Table State Since Info +static1 Static master4 up 2021-08-27 +static2 Static master6 up 2021-08-27 +device1 Device --- up 2021-08-27 +kernel1 Kernel master6 up 2021-08-27 +kernel2 Kernel master4 up 2021-08-27 +direct1 Direct --- up 2021-08-27 +int_babel Babel --- up 2021-08-27 +` + func initSettings() { setting.servers = []string{"alpha"} setting.serversDisplay = []string{"alpha"} @@ -101,17 +112,8 @@ func TestSummaryTableXSS(t *testing.T) { func TestSummaryTableProtocolFilter(t *testing.T) { initSettings() setting.protocolFilter = []string{"Static", "Direct", "Babel"} - data := `BIRD 2.0.8 ready. -Name Proto Table State Since Info -static1 Static master4 up 2021-08-27 -static2 Static master6 up 2021-08-27 -device1 Device --- up 2021-08-27 -kernel1 Kernel master6 up 2021-08-27 -kernel2 Kernel master4 up 2021-08-27 -direct1 Direct --- up 2021-08-27 -int_babel Babel --- up 2021-08-27 ` - result := string(summaryTable(data, "testserver")) + result := string(summaryTable(BirdSummaryData, "testserver")) expectedInclude := []string{"static1", "static2", "int_babel", "direct1"} expectedExclude := []string{"device1", "kernel1", "kernel2"} @@ -134,17 +136,8 @@ int_babel Babel --- up 2021-08-27 ` func TestSummaryTableNameFilter(t *testing.T) { initSettings() setting.nameFilter = "^static" - data := `BIRD 2.0.8 ready. -Name Proto Table State Since Info -static1 Static master4 up 2021-08-27 -static2 Static master6 up 2021-08-27 -device1 Device --- up 2021-08-27 -kernel1 Kernel master6 up 2021-08-27 -kernel2 Kernel master4 up 2021-08-27 -direct1 Direct --- up 2021-08-27 -int_babel Babel --- up 2021-08-27 ` - result := string(summaryTable(data, "testserver")) + result := string(summaryTable(BirdSummaryData, "testserver")) expectedInclude := []string{"device1", "kernel1", "kernel2", "direct1", "int_babel"} expectedExclude := []string{"static1", "static2"} diff --git a/frontend/settings_test.go b/frontend/settings_test.go new file mode 100644 index 0000000..e42b6a4 --- /dev/null +++ b/frontend/settings_test.go @@ -0,0 +1,8 @@ +package main + +import "testing" + +func TestParseSettings(t *testing.T) { + parseSettings() + // Good as long as it doesn't panic +} diff --git a/frontend/telegram_bot_test.go b/frontend/telegram_bot_test.go index 2612435..31fec35 100644 --- a/frontend/telegram_bot_test.go +++ b/frontend/telegram_bot_test.go @@ -1,12 +1,59 @@ package main import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" "testing" + + "github.com/jarcoal/httpmock" + "github.com/magiconair/properties/assert" ) func doTestTelegramIsCommand(t *testing.T, message string, command string, expected bool) { - if telegramIsCommand(message, command) != expected { - t.Errorf("telegramIsCommand(\"%s\", \"%s\") unexpected result", message, command) + result := telegramIsCommand(message, command) + assert.Equal(t, result, expected) +} + +func mockTelegramCall(t *testing.T, msg string, raw bool) string { + return mockTelegramEndpointCall(t, "/telegram/", msg, raw) +} + +func mockTelegramEndpointCall(t *testing.T, endpoint string, msg string, raw bool) string { + request := tgWebhookRequest{ + Message: tgMessage{ + MessageID: 123, + Chat: tgChat{ + ID: 456, + }, + Text: msg, + }, + } + requestJson, err := json.Marshal(request) + if err != nil { + t.Fatal(err) + } + + requestBody := bytes.NewReader(requestJson) + + r := httptest.NewRequest(http.MethodGet, endpoint, requestBody) + w := httptest.NewRecorder() + webHandlerTelegramBot(w, r) + + if raw { + return w.Body.String() + } else { + var response tgWebhookResponse + if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { + t.Error(err) + } + + assert.Equal(t, response.ChatID, request.Message.Chat.ID) + assert.Equal(t, response.ReplyToMessageID, request.Message.MessageID) + return response.Text } } @@ -41,3 +88,280 @@ func TestTelegramIsCommand(t *testing.T) { doTestTelegramIsCommand(t, "/trace@test_bot_123 google.com", "trace", false) doTestTelegramIsCommand(t, "/trace@test google.com", "trace", false) } + +func TestTelegramBatchRequestFormatSingleServer(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpResponse := httpmock.NewStringResponder(200, "Mock") + httpmock.RegisterResponder("GET", "http://alpha:8000/mock?q=cmd", httpResponse) + + setting.servers = []string{"alpha"} + setting.domain = "" + setting.proxyPort = 8000 + + result := telegramBatchRequestFormat(setting.servers, "mock", "cmd", telegramDefaultPostProcess) + expected := "Mock\n\n" + assert.Equal(t, result, expected) +} + +func TestTelegramBatchRequestFormatMultipleServers(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpResponse := httpmock.NewStringResponder(200, "Mock") + httpmock.RegisterResponder("GET", "http://alpha:8000/mock?q=cmd", httpResponse) + httpmock.RegisterResponder("GET", "http://beta:8000/mock?q=cmd", httpResponse) + httpmock.RegisterResponder("GET", "http://gamma:8000/mock?q=cmd", httpResponse) + + setting.servers = []string{ + "alpha", + "beta", + "gamma", + } + setting.domain = "" + setting.proxyPort = 8000 + + result := telegramBatchRequestFormat(setting.servers, "mock", "cmd", telegramDefaultPostProcess) + expected := "alpha\nMock\n\nbeta\nMock\n\ngamma\nMock\n\n" + assert.Equal(t, result, expected) +} + +func TestWebHandlerTelegramBotBadJSON(t *testing.T) { + requestBody := strings.NewReader("{bad json}") + + r := httptest.NewRequest(http.MethodGet, "/telegram/", requestBody) + w := httptest.NewRecorder() + webHandlerTelegramBot(w, r) + + response := w.Body.String() + assert.Equal(t, response, "") +} + +func TestWebHandlerTelegramBotTrace(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpResponse := httpmock.NewStringResponder(200, "Mock Response") + httpmock.RegisterResponder("GET", "http://alpha:8000/traceroute?q=1.1.1.1", httpResponse) + + setting.servers = []string{"alpha"} + setting.domain = "" + setting.proxyPort = 8000 + + response := mockTelegramCall(t, "/trace 1.1.1.1", false) + assert.Equal(t, response, "```\nMock Response\n```") +} + +func TestWebHandlerTelegramBotTraceWithServerList(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpResponse := httpmock.NewStringResponder(200, "Mock Response") + httpmock.RegisterResponder("GET", "http://alpha:8000/traceroute?q=1.1.1.1", httpResponse) + + setting.servers = []string{"alpha"} + setting.domain = "" + setting.proxyPort = 8000 + + response := mockTelegramEndpointCall(t, "/telegram/alpha", "/trace 1.1.1.1", false) + assert.Equal(t, response, "```\nMock Response\n```") +} + +func TestWebHandlerTelegramBotRoute(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpResponse := httpmock.NewStringResponder(200, "Mock Response") + httpmock.RegisterResponder("GET", "http://alpha:8000/bird?q="+url.QueryEscape("show route for 1.1.1.1 primary"), httpResponse) + + setting.servers = []string{"alpha"} + setting.domain = "" + setting.proxyPort = 8000 + + response := mockTelegramCall(t, "/route 1.1.1.1", false) + assert.Equal(t, response, "```\nMock Response\n```") +} + +func TestWebHandlerTelegramBotPath(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpResponse := httpmock.NewStringResponder(200, ` +BGP.as_path: 123 456 +`) + httpmock.RegisterResponder("GET", "http://alpha:8000/bird?q="+url.QueryEscape("show route for 1.1.1.1 all primary"), httpResponse) + + setting.servers = []string{"alpha"} + setting.domain = "" + setting.proxyPort = 8000 + + response := mockTelegramCall(t, "/path 1.1.1.1", false) + assert.Equal(t, response, "```\n123 456\n```") +} + +func TestWebHandlerTelegramBotPathMissing(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpResponse := httpmock.NewStringResponder(200, "No path in this response") + httpmock.RegisterResponder("GET", "http://alpha:8000/bird?q="+url.QueryEscape("show route for 1.1.1.1 all primary"), httpResponse) + + setting.servers = []string{"alpha"} + setting.domain = "" + setting.proxyPort = 8000 + + response := mockTelegramCall(t, "/path 1.1.1.1", false) + assert.Equal(t, response, "```\nempty result\n```") +} + +func TestWebHandlerTelegramBotWhois(t *testing.T) { + server := WhoisServer{ + t: t, + expectedQuery: "AS6939", + response: AS6939Response, + } + + server.Listen() + go server.Run() + defer server.Close() + + setting.netSpecificMode = "" + setting.whoisServer = server.server.Addr().String() + + response := mockTelegramCall(t, "/whois AS6939", false) + assert.Equal(t, response, "```"+server.response+"```") +} + +func TestWebHandlerTelegramBotWhoisDN42Mode(t *testing.T) { + server := WhoisServer{ + t: t, + expectedQuery: "AS4242422547", + response: ` +Query for AS4242422547 +`, + } + + server.Listen() + go server.Run() + defer server.Close() + + setting.netSpecificMode = "dn42" + setting.whoisServer = server.server.Addr().String() + + response := mockTelegramCall(t, "/whois 2547", false) + assert.Equal(t, response, "```"+server.response+"```") +} + +func TestWebHandlerTelegramBotWhoisDN42ModeFullASN(t *testing.T) { + server := WhoisServer{ + t: t, + expectedQuery: "AS4242422547", + response: ` +Query for AS4242422547 +`, + } + + server.Listen() + go server.Run() + defer server.Close() + + setting.netSpecificMode = "dn42" + setting.whoisServer = server.server.Addr().String() + + response := mockTelegramCall(t, "/whois 4242422547", false) + assert.Equal(t, response, "```"+server.response+"```") +} + +func TestWebHandlerTelegramBotWhoisShortenMode(t *testing.T) { + server := WhoisServer{ + t: t, + expectedQuery: "AS6939", + response: ` +Information line that will be removed + +# Comment that will be removed +Name: Redacted for privacy +Descr: This is a vvvvvvvvvvvvvvvvvvvvvvveeeeeeeeeeeeeeeeeeeerrrrrrrrrrrrrrrrrrrrrrrryyyyyyyyyyyyyyyyyyy long line that will be skipped. +Looooooooooooooooooooooong key: this line will be skipped. + +Preserved1: this line isn't removed. +Preserved2: this line isn't removed. +Preserved3: this line isn't removed. +Preserved4: this line isn't removed. +Preserved5: this line isn't removed. + +`, + } + + expectedResult := `Preserved1: this line isn't removed. +Preserved2: this line isn't removed. +Preserved3: this line isn't removed. +Preserved4: this line isn't removed. +Preserved5: this line isn't removed. + +3 line(s) skipped.` + + server.Listen() + go server.Run() + defer server.Close() + + setting.netSpecificMode = "shorten" + setting.whoisServer = server.server.Addr().String() + + response := mockTelegramCall(t, "/whois AS6939", false) + assert.Equal(t, response, "```\n"+expectedResult+"\n```") +} + +func TestWebHandlerTelegramBotHelp(t *testing.T) { + response := mockTelegramCall(t, "/help", false) + if !strings.Contains(response, "/trace") { + t.Error("Did not get help message") + } +} + +func TestWebHandlerTelegramBotUnknownCommand(t *testing.T) { + response := mockTelegramCall(t, "/nonexistent", true) + assert.Equal(t, response, "") +} + +func TestWebHandlerTelegramBotNotCommand(t *testing.T) { + response := mockTelegramCall(t, "random chat message", true) + assert.Equal(t, response, "") +} + +func TestWebHandlerTelegramBotEmptyResponse(t *testing.T) { + server := WhoisServer{ + t: t, + expectedQuery: "AS6939", + response: "", + } + + server.Listen() + go server.Run() + defer server.Close() + + setting.netSpecificMode = "" + setting.whoisServer = server.server.Addr().String() + + response := mockTelegramCall(t, "/whois AS6939", false) + assert.Equal(t, response, "```\nempty result\n```") +} + +func TestWebHandlerTelegramBotTruncateLongResponse(t *testing.T) { + server := WhoisServer{ + t: t, + expectedQuery: "AS6939", + response: strings.Repeat("A", 65536), + } + + server.Listen() + go server.Run() + defer server.Close() + + setting.netSpecificMode = "" + setting.whoisServer = server.server.Addr().String() + + response := mockTelegramCall(t, "/whois AS6939", false) + assert.Equal(t, response, "```\n"+strings.Repeat("A", 4096)+"\n```") +} diff --git a/frontend/template_test.go b/frontend/template_test.go new file mode 100644 index 0000000..87f9725 --- /dev/null +++ b/frontend/template_test.go @@ -0,0 +1,25 @@ +package main + +import ( + "testing" + + "github.com/magiconair/properties/assert" +) + +func TestSummaryRowDataNameHasPrefix(t *testing.T) { + data := SummaryRowData{ + Name: "mock", + } + + assert.Equal(t, data.NameHasPrefix("m"), true) + assert.Equal(t, data.NameHasPrefix("n"), false) +} + +func TestSummaryRowDataNameContains(t *testing.T) { + data := SummaryRowData{ + Name: "mock", + } + + assert.Equal(t, data.NameContains("oc"), true) + assert.Equal(t, data.NameContains("no"), false) +} diff --git a/frontend/webserver_test.go b/frontend/webserver_test.go new file mode 100644 index 0000000..1e9862a --- /dev/null +++ b/frontend/webserver_test.go @@ -0,0 +1,89 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/jarcoal/httpmock" + "github.com/magiconair/properties/assert" +) + +func TestServerError(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/error", nil) + w := httptest.NewRecorder() + serverError(w, r) + assert.Equal(t, w.Code, http.StatusInternalServerError) +} + +func TestWebHandlerWhois(t *testing.T) { + server := WhoisServer{ + t: t, + expectedQuery: "AS6939", + response: AS6939Response, + } + + server.Listen() + go server.Run() + defer server.Close() + + setting.netSpecificMode = "" + setting.whoisServer = server.server.Addr().String() + + r := httptest.NewRequest(http.MethodGet, "/whois/AS6939", nil) + w := httptest.NewRecorder() + webHandlerWhois(w, r) + + assert.Equal(t, w.Code, http.StatusOK) + if !strings.Contains(w.Body.String(), "HURRICANE") { + t.Error("Body does not contain whois result") + } +} + +func TestWebBackendCommunicator(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + input := readDataFile(t, "frontend/test_data/bgpmap_case1.txt") + httpResponse := httpmock.NewStringResponder(200, input) + httpmock.RegisterResponder("GET", "http://alpha:8000/bird?q="+url.QueryEscape("show route for 1.1.1.1 all"), httpResponse) + + setting.servers = []string{"alpha"} + setting.domain = "" + setting.proxyPort = 8000 + setting.dnsInterface = "" + setting.whoisServer = "" + + r := httptest.NewRequest(http.MethodGet, "/route_bgpmap/alpha/1.1.1.1", nil) + w := httptest.NewRecorder() + + handler := webBackendCommunicator("bird", "route_all") + handler(w, r) + + assert.Equal(t, w.Code, http.StatusOK) +} + +func TestWebHandlerBGPMap(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + input := readDataFile(t, "frontend/test_data/bgpmap_case1.txt") + httpResponse := httpmock.NewStringResponder(200, input) + httpmock.RegisterResponder("GET", "http://alpha:8000/bird?q="+url.QueryEscape("show route for 1.1.1.1 all"), httpResponse) + + setting.servers = []string{"alpha"} + setting.domain = "" + setting.proxyPort = 8000 + setting.dnsInterface = "" + setting.whoisServer = "" + + r := httptest.NewRequest(http.MethodGet, "/route_bgpmap/alpha/1.1.1.1", nil) + w := httptest.NewRecorder() + + handler := webHandlerBGPMap("bird", "route_bgpmap") + handler(w, r) + + assert.Equal(t, w.Code, http.StatusOK) +} diff --git a/frontend/whois.go b/frontend/whois.go index 76735a6..1bce5b7 100644 --- a/frontend/whois.go +++ b/frontend/whois.go @@ -6,6 +6,8 @@ import ( "os/exec" "strings" "time" + + "github.com/google/shlex" ) // Send a whois request @@ -15,7 +17,13 @@ func whois(s string) string { } if strings.HasPrefix(setting.whoisServer, "/") { - cmd := exec.Command(setting.whoisServer, s) + args, err := shlex.Split(setting.whoisServer) + if err != nil { + return err.Error() + } + args = append(args, s) + + cmd := exec.Command(args[0], args[1:]...) output, err := cmd.CombinedOutput() if err != nil { return err.Error() @@ -26,7 +34,13 @@ func whois(s string) string { return string(output) } else { buf := make([]byte, 65536) - conn, err := net.DialTimeout("tcp", setting.whoisServer+":43", 5*time.Second) + + whoisServer := setting.whoisServer + if !strings.Contains(whoisServer, ":") { + whoisServer = whoisServer + ":43" + } + + conn, err := net.DialTimeout("tcp", whoisServer, 5*time.Second) if err != nil { return err.Error() } @@ -35,7 +49,7 @@ func whois(s string) string { conn.Write([]byte(s + "\r\n")) n, err := io.ReadFull(conn, buf) - if err != nil && err != io.ErrUnexpectedEOF { + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { return err.Error() } return string(buf[:n]) diff --git a/frontend/whois_test.go b/frontend/whois_test.go index 3a85a6c..d6c0f88 100644 --- a/frontend/whois_test.go +++ b/frontend/whois_test.go @@ -1,14 +1,78 @@ package main import ( + "bufio" + "net" "strings" "testing" ) -func TestWhois(t *testing.T) { - checkNetwork(t) +type WhoisServer struct { + t *testing.T + expectedQuery string + response string + server net.Listener +} - setting.whoisServer = "whois.arin.net" +const AS6939Response = ` +ASNumber: 6939 +ASName: HURRICANE +ASHandle: AS6939 +RegDate: 1996-06-28 +Updated: 2003-11-04 +Ref: https://rdap.arin.net/registry/autnum/6939 +` + +func (s *WhoisServer) Listen() { + var err error + s.server, err = net.Listen("tcp", "127.0.0.1:0") + if err != nil { + s.t.Error(err) + } +} + +func (s *WhoisServer) Run() { + for { + conn, err := s.server.Accept() + if err != nil { + break + } + if conn == nil { + break + } + + reader := bufio.NewReader(conn) + query, err := reader.ReadBytes('\n') + if err != nil { + break + } + if strings.TrimSpace(string(query)) != s.expectedQuery { + s.t.Errorf("Query %s doesn't match expectation %s", string(query), s.expectedQuery) + } + conn.Write([]byte(s.response)) + conn.Close() + } +} + +func (s *WhoisServer) Close() { + if s.server == nil { + return + } + s.server.Close() +} + +func TestWhois(t *testing.T) { + server := WhoisServer{ + t: t, + expectedQuery: "AS6939", + response: AS6939Response, + } + + server.Listen() + go server.Run() + defer server.Close() + + setting.whoisServer = server.server.Addr().String() result := whois("AS6939") if !strings.Contains(result, "HURRICANE") { t.Errorf("Whois AS6939 failed, got %s", result) @@ -22,3 +86,43 @@ func TestWhoisWithoutServer(t *testing.T) { t.Errorf("Whois AS6939 without server produced output, got %s", result) } } + +func TestWhoisConnectionError(t *testing.T) { + setting.whoisServer = "127.0.0.1:0" + result := whois("AS6939") + if !strings.Contains(result, "connect: connection refused") { + t.Errorf("Whois AS6939 without server produced output, got %s", result) + } +} + +func TestWhoisHostProcess(t *testing.T) { + setting.whoisServer = "/bin/sh -c \"echo Mock Result\"" + result := whois("AS6939") + if result != "Mock Result\n" { + t.Errorf("Whois didn't produce expected result, got %s", result) + } +} + +func TestWhoisHostProcessMalformedCommand(t *testing.T) { + setting.whoisServer = "/bin/sh -c \"mock" + result := whois("AS6939") + if result != "EOF found when expecting closing quote" { + t.Errorf("Whois didn't produce expected result, got %s", result) + } +} + +func TestWhoisHostProcessError(t *testing.T) { + setting.whoisServer = "/nonexistent" + result := whois("AS6939") + if !strings.Contains(result, "no such file or directory") { + t.Errorf("Whois didn't produce expected result, got %s", result) + } +} + +func TestWhoisHostProcessVeryLong(t *testing.T) { + setting.whoisServer = "/bin/sh -c \"for i in $(seq 1 131072); do printf 'A'; done\"" + result := whois("AS6939") + if len(result) != 65535 { + t.Errorf("Whois result incorrectly truncated, actual len %d", len(result)) + } +} diff --git a/proxy/bird.go b/proxy/bird.go index 70a16d7..6276fa9 100644 --- a/proxy/bird.go +++ b/proxy/bird.go @@ -8,19 +8,23 @@ import ( "strings" ) +const MAX_LINE_SIZE = 1024 + // Read a line from bird socket, removing preceding status number, output it. // Returns if there are more lines. func birdReadln(bird io.Reader, w io.Writer) bool { // Read from socket byte by byte, until reaching newline character - c := make([]byte, 1024, 1024) + c := make([]byte, MAX_LINE_SIZE) pos := 0 for { - if pos >= 1024 { + // Leave one byte for newline character + if pos >= MAX_LINE_SIZE-1 { break } _, err := bird.Read(c[pos : pos+1]) if err != nil { - panic(err) + w.Write([]byte(err.Error())) + return false } if c[pos] == byte('\n') { break @@ -29,6 +33,7 @@ func birdReadln(bird io.Reader, w io.Writer) bool { } c = c[:pos+1] + c[pos] = '\n' // print(string(c[:])) // Remove preceding status number, different situations diff --git a/proxy/bird_test.go b/proxy/bird_test.go new file mode 100644 index 0000000..9dfeaeb --- /dev/null +++ b/proxy/bird_test.go @@ -0,0 +1,213 @@ +package main + +import ( + "bufio" + "bytes" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "net/url" + "path" + "strings" + "testing" + + "github.com/magiconair/properties/assert" +) + +type BirdServer struct { + t *testing.T + expectedQuery string + response string + server net.Listener + socket string + injectError string +} + +func (s *BirdServer) initSocket() { + tmpDir, err := ioutil.TempDir("", "bird-lgproxy-go-mock") + if err != nil { + s.t.Fatal(err) + } + s.socket = path.Join(tmpDir, "mock.socket") +} + +func (s *BirdServer) Listen() { + s.initSocket() + + var err error + s.server, err = net.Listen("unix", s.socket) + if err != nil { + s.t.Error(err) + } +} + +func (s *BirdServer) Run() { + for { + conn, err := s.server.Accept() + if err != nil { + break + } + if conn == nil { + break + } + + reader := bufio.NewReader(conn) + + conn.Write([]byte("1234 Hello from mock bird\n")) + + query, err := reader.ReadBytes('\n') + if err != nil { + break + } + if strings.TrimSpace(string(query)) != "restrict" { + s.t.Errorf("Did not restrict bird permissions") + } + if s.injectError == "restriction" { + conn.Write([]byte("1234 Restriction is disabled!\n")) + } else { + conn.Write([]byte("1234 Access restricted\n")) + } + + query, err = reader.ReadBytes('\n') + if err != nil { + break + } + if strings.TrimSpace(string(query)) != s.expectedQuery { + s.t.Errorf("Query %s doesn't match expectation %s", string(query), s.expectedQuery) + } + + responseList := strings.Split(s.response, "\n") + for i := range responseList { + if i == len(responseList)-1 { + if s.injectError == "eof" { + conn.Write([]byte("0000 " + responseList[i])) + } else { + conn.Write([]byte("0000 " + responseList[i] + "\n")) + } + } else { + conn.Write([]byte("1234 " + responseList[i] + "\n")) + } + } + + conn.Close() + } +} + +func (s *BirdServer) Close() { + if s.server == nil { + return + } + s.server.Close() +} + +func TestBirdReadln(t *testing.T) { + input := strings.NewReader("1234 Bird Message\n") + var output bytes.Buffer + birdReadln(input, &output) + + assert.Equal(t, output.String(), "Bird Message\n") +} + +func TestBirdReadlnNoPrefix(t *testing.T) { + input := strings.NewReader(" Message without prefix\n") + var output bytes.Buffer + birdReadln(input, &output) + + assert.Equal(t, output.String(), "Message without prefix\n") +} + +func TestBirdReadlnVeryLongLine(t *testing.T) { + input := strings.NewReader(strings.Repeat("A", 4096)) + var output bytes.Buffer + birdReadln(input, &output) + + assert.Equal(t, output.String(), strings.Repeat("A", 1022)+"\n") +} + +func TestBirdWriteln(t *testing.T) { + var output bytes.Buffer + birdWriteln(&output, "Test command") + assert.Equal(t, output.String(), "Test command\n") +} + +func TestBirdHandlerWithoutQuery(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/bird", nil) + w := httptest.NewRecorder() + birdHandler(w, r) +} + +func TestBirdHandlerWithQuery(t *testing.T) { + server := BirdServer{ + t: t, + expectedQuery: "show protocols", + response: "Mock Response\nSecond Line", + injectError: "", + } + + server.Listen() + go server.Run() + defer server.Close() + + setting.birdSocket = server.socket + + r := httptest.NewRequest(http.MethodGet, "/bird?q="+url.QueryEscape(server.expectedQuery), nil) + w := httptest.NewRecorder() + birdHandler(w, r) + + assert.Equal(t, w.Code, http.StatusOK) + assert.Equal(t, w.Body.String(), server.response+"\n") +} + +func TestBirdHandlerWithBadSocket(t *testing.T) { + setting.birdSocket = "/nonexistent.sock" + + r := httptest.NewRequest(http.MethodGet, "/bird?q="+url.QueryEscape("mock"), nil) + w := httptest.NewRecorder() + birdHandler(w, r) + + assert.Equal(t, w.Code, http.StatusInternalServerError) +} + +func TestBirdHandlerWithoutRestriction(t *testing.T) { + server := BirdServer{ + t: t, + expectedQuery: "show protocols", + response: "Mock Response", + injectError: "restriction", + } + + server.Listen() + go server.Run() + defer server.Close() + + setting.birdSocket = server.socket + + r := httptest.NewRequest(http.MethodGet, "/bird?q="+url.QueryEscape("mock"), nil) + w := httptest.NewRecorder() + birdHandler(w, r) + + assert.Equal(t, w.Code, http.StatusInternalServerError) +} + +func TestBirdHandlerEOF(t *testing.T) { + server := BirdServer{ + t: t, + expectedQuery: "show protocols", + response: "Mock Response\nSecond Line", + injectError: "eof", + } + + server.Listen() + go server.Run() + defer server.Close() + + setting.birdSocket = server.socket + + r := httptest.NewRequest(http.MethodGet, "/bird?q="+url.QueryEscape("show protocols"), nil) + w := httptest.NewRecorder() + birdHandler(w, r) + + assert.Equal(t, w.Code, http.StatusOK) + assert.Equal(t, w.Body.String(), "Mock Response\nEOF") +} diff --git a/proxy/main.go b/proxy/main.go index 24a8b96..07ddc65 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -21,39 +21,49 @@ func invalidHandler(httpW http.ResponseWriter, httpR *http.Request) { httpW.Write([]byte("Invalid Request\n")) } +func hasAccess(remoteAddr string) bool { + // setting.allowedIPs will always have at least one element because of how it's defined + if len(setting.allowedIPs) == 0 { + return true + } + + if !strings.Contains(remoteAddr, ":") { + return false + } + + // Remove port from IP and remove brackets that are around IPv6 addresses + remoteAddr = remoteAddr[0:strings.LastIndex(remoteAddr, ":")] + remoteAddr = strings.Trim(remoteAddr, "[]") + + ipObject := net.ParseIP(remoteAddr) + if ipObject == nil { + return false + } + + for _, allowedIP := range setting.allowedIPs { + if ipObject.Equal(allowedIP) { + return true + } + } + + return false +} + // Access handler, check to see if client IP in allowed IPs, continue if it is, send to invalidHandler if not func accessHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(httpW http.ResponseWriter, httpR *http.Request) { - - // setting.allowedIPs will always have at least one element because of how it's defined - if setting.allowedIPs[0] == "" { + if hasAccess(httpR.RemoteAddr) { next.ServeHTTP(httpW, httpR) - return + } else { + invalidHandler(httpW, httpR) } - - IPPort := httpR.RemoteAddr - - // Remove port from IP and remove brackets that are around IPv6 addresses - requestIp := IPPort[0:strings.LastIndex(IPPort, ":")] - requestIp = strings.Replace(requestIp, "[", "", -1) - requestIp = strings.Replace(requestIp, "]", "", -1) - - for _, allowedIP := range setting.allowedIPs { - if requestIp == allowedIP { - next.ServeHTTP(httpW, httpR) - return - } - } - - invalidHandler(httpW, httpR) - return }) } type settingType struct { birdSocket string listen string - allowedIPs []string + allowedIPs []net.IP tr_bin string tr_flags []string tr_raw bool diff --git a/proxy/main_test.go b/proxy/main_test.go new file mode 100644 index 0000000..1f33449 --- /dev/null +++ b/proxy/main_test.go @@ -0,0 +1,78 @@ +package main + +import ( + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/magiconair/properties/assert" +) + +func TestHasAccessNotConfigured(t *testing.T) { + setting.allowedIPs = []net.IP{} + assert.Equal(t, hasAccess("whatever"), true) +} + +func TestHasAccessAllowIPv4(t *testing.T) { + setting.allowedIPs = []net.IP{net.ParseIP("1.2.3.4")} + assert.Equal(t, hasAccess("1.2.3.4:4321"), true) +} + +func TestHasAccessDenyIPv4(t *testing.T) { + setting.allowedIPs = []net.IP{net.ParseIP("4.3.2.1")} + assert.Equal(t, hasAccess("1.2.3.4:4321"), false) +} + +func TestHasAccessAllowIPv6(t *testing.T) { + setting.allowedIPs = []net.IP{net.ParseIP("2001:db8::1")} + assert.Equal(t, hasAccess("[2001:db8::1]:4321"), true) +} + +func TestHasAccessAllowIPv6DifferentForm(t *testing.T) { + setting.allowedIPs = []net.IP{net.ParseIP("2001:0db8::1")} + assert.Equal(t, hasAccess("[2001:db8::1]:4321"), true) +} + +func TestHasAccessDenyIPv6(t *testing.T) { + setting.allowedIPs = []net.IP{net.ParseIP("2001:db8::2")} + assert.Equal(t, hasAccess("[2001:db8::1]:4321"), false) +} + +func TestHasAccessBadClientIP(t *testing.T) { + setting.allowedIPs = []net.IP{net.ParseIP("1.2.3.4")} + assert.Equal(t, hasAccess("not an IP"), false) +} + +func TestHasAccessBadClientIPPort(t *testing.T) { + setting.allowedIPs = []net.IP{net.ParseIP("1.2.3.4")} + assert.Equal(t, hasAccess("not an IP:not a port"), false) +} + +func TestAccessHandlerAllow(t *testing.T) { + baseHandler := http.NotFoundHandler() + wrappedHandler := accessHandler(baseHandler) + + r := httptest.NewRequest(http.MethodGet, "/mock", nil) + r.RemoteAddr = "1.2.3.4:4321" + w := httptest.NewRecorder() + + setting.allowedIPs = []net.IP{net.ParseIP("1.2.3.4")} + + wrappedHandler.ServeHTTP(w, r) + assert.Equal(t, w.Code, http.StatusNotFound) +} + +func TestAccessHandlerDeny(t *testing.T) { + baseHandler := http.NotFoundHandler() + wrappedHandler := accessHandler(baseHandler) + + r := httptest.NewRequest(http.MethodGet, "/mock", nil) + r.RemoteAddr = "1.2.3.4:4321" + w := httptest.NewRecorder() + + setting.allowedIPs = []net.IP{net.ParseIP("4.3.2.1")} + + wrappedHandler.ServeHTTP(w, r) + assert.Equal(t, w.Code, http.StatusInternalServerError) +} diff --git a/proxy/settings.go b/proxy/settings.go index 0129cc0..098f4a5 100644 --- a/proxy/settings.go +++ b/proxy/settings.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "net" "strings" "github.com/google/shlex" @@ -66,9 +67,17 @@ func parseSettings() { setting.listen = viperSettings.Listen if viperSettings.AllowedIPs != "" { - setting.allowedIPs = strings.Split(viperSettings.AllowedIPs, ",") + for _, ip := range strings.Split(viperSettings.AllowedIPs, ",") { + ipObject := net.ParseIP(ip) + if ipObject == nil { + fmt.Printf("Parse IP %s failed\n", ip) + continue + } + + setting.allowedIPs = append(setting.allowedIPs, ipObject) + } } else { - setting.allowedIPs = []string{""} + setting.allowedIPs = []net.IP{} } var err error diff --git a/proxy/settings_test.go b/proxy/settings_test.go new file mode 100644 index 0000000..e42b6a4 --- /dev/null +++ b/proxy/settings_test.go @@ -0,0 +1,8 @@ +package main + +import "testing" + +func TestParseSettings(t *testing.T) { + parseSettings() + // Good as long as it doesn't panic +} diff --git a/proxy/traceroute_test.go b/proxy/traceroute_test.go new file mode 100644 index 0000000..d245248 --- /dev/null +++ b/proxy/traceroute_test.go @@ -0,0 +1,168 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" + + "github.com/magiconair/properties/assert" +) + +func TestTracerouteArgsToString(t *testing.T) { + result := tracerouteArgsToString("traceroute", []string{ + "-a", + "-b", + "-c", + }, []string{ + "google.com", + }) + + assert.Equal(t, result, "traceroute -a -b -c google.com") +} + +func TestTracerouteTryExecuteSuccess(t *testing.T) { + _, err := tracerouteTryExecute("sh", []string{ + "-c", + }, []string{ + "true", + }) + + if err != nil { + t.Error(err) + } +} + +func TestTracerouteTryExecuteFail(t *testing.T) { + _, err := tracerouteTryExecute("sh", []string{ + "-c", + }, []string{ + "false", + }) + + if err == nil { + t.Error("Should trigger error, not triggered") + } +} + +func TestTracerouteDetectSuccess(t *testing.T) { + result := tracerouteDetect("sh", []string{ + "-c", + "true", + }) + + assert.Equal(t, result, true) +} + +func TestTracerouteDetectFail(t *testing.T) { + result := tracerouteDetect("sh", []string{ + "-c", + "false", + }) + + assert.Equal(t, result, false) +} + +func TestTracerouteAutodetect(t *testing.T) { + pathBackup := os.Getenv("PATH") + os.Setenv("PATH", "") + defer os.Setenv("PATH", pathBackup) + + setting.tr_bin = "" + setting.tr_flags = []string{} + tracerouteAutodetect() + // Should not panic +} + +func TestTracerouteAutodetectExisting(t *testing.T) { + setting.tr_bin = "mock" + setting.tr_flags = []string{"mock"} + tracerouteAutodetect() + assert.Equal(t, setting.tr_bin, "mock") + assert.Equal(t, setting.tr_flags, []string{"mock"}) +} + +func TestTracerouteAutodetectFlagsOnly(t *testing.T) { + pathBackup := os.Getenv("PATH") + os.Setenv("PATH", "") + defer os.Setenv("PATH", pathBackup) + + setting.tr_bin = "mock" + setting.tr_flags = nil + tracerouteAutodetect() + + // Should not panic +} + +func TestTracerouteHandlerWithoutQuery(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/traceroute", nil) + w := httptest.NewRecorder() + tracerouteHandler(w, r) + assert.Equal(t, w.Code, http.StatusInternalServerError) + if !strings.Contains(w.Body.String(), "Invalid Request") { + t.Error("Did not get invalid request") + } +} + +func TestTracerouteHandlerShlexError(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/traceroute?q="+url.QueryEscape("\"1.1.1.1"), nil) + w := httptest.NewRecorder() + tracerouteHandler(w, r) + assert.Equal(t, w.Code, http.StatusInternalServerError) + if !strings.Contains(w.Body.String(), "parse") { + t.Error("Did not get parsing error message") + } +} + +func TestTracerouteHandlerNoTracerouteFound(t *testing.T) { + setting.tr_bin = "" + setting.tr_flags = nil + + r := httptest.NewRequest(http.MethodGet, "/traceroute?q="+url.QueryEscape("1.1.1.1"), nil) + w := httptest.NewRecorder() + tracerouteHandler(w, r) + assert.Equal(t, w.Code, http.StatusInternalServerError) + if !strings.Contains(w.Body.String(), "not supported") { + t.Error("Did not get not supported error message") + } +} + +func TestTracerouteHandlerExecuteError(t *testing.T) { + setting.tr_bin = "sh" + setting.tr_flags = []string{"-c", "false"} + setting.tr_raw = true + + r := httptest.NewRequest(http.MethodGet, "/traceroute?q="+url.QueryEscape("1.1.1.1"), nil) + w := httptest.NewRecorder() + tracerouteHandler(w, r) + assert.Equal(t, w.Code, http.StatusInternalServerError) + if !strings.Contains(w.Body.String(), "Error executing traceroute") { + t.Error("Did not get not execute error message") + } +} + +func TestTracerouteHandlerRaw(t *testing.T) { + setting.tr_bin = "sh" + setting.tr_flags = []string{"-c", "echo Mock"} + setting.tr_raw = true + + r := httptest.NewRequest(http.MethodGet, "/traceroute?q="+url.QueryEscape("1.1.1.1"), nil) + w := httptest.NewRecorder() + tracerouteHandler(w, r) + assert.Equal(t, w.Code, http.StatusOK) + assert.Equal(t, w.Body.String(), "Mock\n") +} + +func TestTracerouteHandlerPostprocess(t *testing.T) { + setting.tr_bin = "sh" + setting.tr_flags = []string{"-c", "echo \"first line\n 2 *\nthird line\""} + setting.tr_raw = false + + r := httptest.NewRequest(http.MethodGet, "/traceroute?q="+url.QueryEscape("1.1.1.1"), nil) + w := httptest.NewRecorder() + tracerouteHandler(w, r) + assert.Equal(t, w.Code, http.StatusOK) + assert.Equal(t, w.Body.String(), "first line\nthird line\n\n1 hops not responding.") +}