general: add unit tests for >80% coverage

Includes a few minor fixes:
- frontend: support setting port for WHOIS server
- proxy: fix handling of very long lines
- proxy: refactor IP allowlist logic, parse allow IP list at startup
This commit is contained in:
Lan Tian 2023-05-06 00:23:28 -07:00
parent ccd14af0c8
commit a0246ccee2
No known key found for this signature in database
GPG Key ID: 04E66B6B25A0862B
24 changed files with 1576 additions and 65 deletions

View File

@ -113,7 +113,7 @@ func apiHandler(w http.ResponseWriter, r *http.Request) {
} else {
handler := apiHandlerMap[request.Type]
if handler == nil {
response = apiErrorHandler(errors.New("Invalid request type"))
response = apiErrorHandler(errors.New("invalid request type"))
} else {
response = handler(request)
}

207
frontend/api_test.go Normal file
View File

@ -0,0 +1,207 @@
package main
import (
"bytes"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/jarcoal/httpmock"
"github.com/magiconair/properties/assert"
)
func TestApiServerListHandler(t *testing.T) {
setting.servers = []string{"alpha", "beta", "gamma"}
response := apiServerListHandler(apiRequest{})
assert.Equal(t, len(response.Result), 3)
assert.Equal(t, response.Result[0].(apiGenericResultPair).Server, "alpha")
assert.Equal(t, response.Result[1].(apiGenericResultPair).Server, "beta")
assert.Equal(t, response.Result[2].(apiGenericResultPair).Server, "gamma")
}
func TestApiGenericHandlerFactory(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpResponse := httpmock.NewStringResponder(200, BirdSummaryData)
httpmock.RegisterResponder("GET", "http://alpha:8000/bird?q="+url.QueryEscape("show protocols"), httpResponse)
setting.servers = []string{"alpha"}
setting.domain = ""
setting.proxyPort = 8000
request := apiRequest{
Servers: setting.servers,
Type: "bird",
Args: "show protocols",
}
handler := apiGenericHandlerFactory("bird")
response := handler(request)
assert.Equal(t, response.Error, "")
result := response.Result[0].(*apiGenericResultPair)
assert.Equal(t, result.Server, "alpha")
assert.Equal(t, result.Data, BirdSummaryData)
}
func TestApiSummaryHandler(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpResponse := httpmock.NewStringResponder(200, BirdSummaryData)
httpmock.RegisterResponder("GET", "http://alpha:8000/bird?q="+url.QueryEscape("show protocols"), httpResponse)
setting.servers = []string{"alpha"}
setting.domain = ""
setting.proxyPort = 8000
request := apiRequest{
Servers: setting.servers,
Type: "summary",
Args: "",
}
response := apiSummaryHandler(request)
assert.Equal(t, response.Error, "")
summary := response.Result[0].(*apiSummaryResultPair)
assert.Equal(t, summary.Server, "alpha")
// Protocol list will be sorted
assert.Equal(t, summary.Data[1].Name, "device1")
assert.Equal(t, summary.Data[1].Proto, "Device")
assert.Equal(t, summary.Data[1].Table, "---")
assert.Equal(t, summary.Data[1].State, "up")
assert.Equal(t, summary.Data[1].Since, "2021-08-27")
assert.Equal(t, summary.Data[1].Info, "")
}
func TestApiSummaryHandlerError(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpResponse := httpmock.NewStringResponder(200, "Mock backend error")
httpmock.RegisterResponder("GET", "http://alpha:8000/bird?q="+url.QueryEscape("show protocols"), httpResponse)
setting.servers = []string{"alpha"}
setting.domain = ""
setting.proxyPort = 8000
request := apiRequest{
Servers: setting.servers,
Type: "summary",
Args: "",
}
response := apiSummaryHandler(request)
assert.Equal(t, response.Error, "Mock backend error")
}
func TestApiWhoisHandler(t *testing.T) {
expectedData := "Mock Data"
server := WhoisServer{
t: t,
expectedQuery: "AS6939",
response: expectedData,
}
server.Listen()
go server.Run()
defer server.Close()
setting.whoisServer = server.server.Addr().String()
request := apiRequest{
Servers: []string{},
Type: "",
Args: "AS6939",
}
response := apiWhoisHandler(request)
assert.Equal(t, response.Error, "")
whoisResult := response.Result[0].(apiGenericResultPair)
assert.Equal(t, whoisResult.Server, "")
assert.Equal(t, whoisResult.Data, expectedData)
}
func TestApiErrorHandler(t *testing.T) {
err := errors.New("Mock Error")
response := apiErrorHandler(err)
assert.Equal(t, response.Error, "Mock Error")
}
func TestApiHandler(t *testing.T) {
setting.servers = []string{"alpha", "beta", "gamma"}
request := apiRequest{
Servers: []string{},
Type: "server_list",
Args: "",
}
requestJson, err := json.Marshal(request)
if err != nil {
t.Error(err)
}
r := httptest.NewRequest(http.MethodGet, "/api", bytes.NewReader(requestJson))
w := httptest.NewRecorder()
apiHandler(w, r)
var response apiResponse
err = json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Error(err)
}
assert.Equal(t, len(response.Result), 3)
// Hard to unmarshal JSON into apiGenericResultPair objects, won't check here
}
func TestApiHandlerBadJSON(t *testing.T) {
setting.servers = []string{"alpha", "beta", "gamma"}
r := httptest.NewRequest(http.MethodGet, "/api", strings.NewReader("{bad json}"))
w := httptest.NewRecorder()
apiHandler(w, r)
var response apiResponse
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Error(err)
}
assert.Equal(t, len(response.Result), 0)
}
func TestApiHandlerInvalidType(t *testing.T) {
setting.servers = []string{"alpha", "beta", "gamma"}
request := apiRequest{
Servers: setting.servers,
Type: "invalid_type",
Args: "",
}
requestJson, err := json.Marshal(request)
if err != nil {
t.Error(err)
}
r := httptest.NewRequest(http.MethodGet, "/api", bytes.NewReader(requestJson))
w := httptest.NewRecorder()
apiHandler(w, r)
var response apiResponse
err = json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Error(err)
}
assert.Equal(t, len(response.Result), 0)
}

View File

@ -3,6 +3,8 @@ package main
import (
"strings"
"testing"
"github.com/magiconair/properties/assert"
)
func TestGetASNRepresentationDNS(t *testing.T) {
@ -34,7 +36,5 @@ func TestGetASNRepresentationFallback(t *testing.T) {
setting.whoisServer = ""
cache := make(ASNCache)
result := cache.Lookup("6939")
if result != "AS6939" {
t.Errorf("Lookup AS6939 failed, got %s", result)
}
assert.Equal(t, result, "AS6939")
}

View File

@ -8,14 +8,14 @@ import (
"testing"
)
func readDataFile(filename string) string {
func readDataFile(t *testing.T, filename string) string {
_, sourceName, _, _ := runtime.Caller(0)
projectRoot := path.Join(path.Dir(sourceName), "..")
dir := path.Join(projectRoot, filename)
data, err := ioutil.ReadFile(dir)
if err != nil {
panic(err)
t.Fatal(err)
}
return string(data)
}
@ -41,7 +41,7 @@ func TestBirdRouteToGraphvizXSS(t *testing.T) {
func TestBirdRouteToGraph(t *testing.T) {
setting.dnsInterface = ""
input := readDataFile("frontend/test_data/bgpmap_case1.txt")
input := readDataFile(t, "frontend/test_data/bgpmap_case1.txt")
result := birdRouteToGraph([]string{"node"}, []string{input}, "target")
// Source node must exist
@ -71,3 +71,14 @@ func TestBirdRouteToGraph(t *testing.T) {
t.Error("Result doesn't contain edge from 4242423914 to target")
}
}
func TestBirdRouteToGraphviz(t *testing.T) {
setting.dnsInterface = ""
input := readDataFile(t, "frontend/test_data/bgpmap_case1.txt")
result := birdRouteToGraphviz([]string{"node"}, []string{input}, "target")
if !strings.Contains(result, "digraph {") {
t.Error("Response is not Graphviz data")
}
}

View File

@ -65,7 +65,7 @@ func shortenWhoisFilter(whois string) string {
shouldSkip := false
shouldSkip = shouldSkip || len(s) == 0
shouldSkip = shouldSkip || len(s) > 0 && s[0] == '#'
shouldSkip = shouldSkip || strings.Contains(strings.ToUpper(s), "REDACTED FOR PRIVACY")
shouldSkip = shouldSkip || strings.Contains(strings.ToUpper(s), "REDACTED")
if shouldSkip {
skippedLinesLonger++

View File

@ -28,3 +28,79 @@ func TestDN42WhoisFilterUnneeded(t *testing.T) {
t.Errorf("Output doesn't match expected: %s", result)
}
}
func TestShortenWhoisFilterShorterMode(t *testing.T) {
input := `
Information line that will be removed
# Comment that will be removed
Name: Redacted for privacy
Descr: This is a vvvvvvvvvvvvvvvvvvvvvvveeeeeeeeeeeeeeeeeeeerrrrrrrrrrrrrrrrrrrrrrrryyyyyyyyyyyyyyyyyyy long line that will be skipped.
Looooooooooooooooooooooong key: this line will be skipped.
Preserved1: this line isn't removed.
Preserved2: this line isn't removed.
Preserved3: this line isn't removed.
Preserved4: this line isn't removed.
Preserved5: this line isn't removed.
`
result := shortenWhoisFilter(input)
expectedResult := `Preserved1: this line isn't removed.
Preserved2: this line isn't removed.
Preserved3: this line isn't removed.
Preserved4: this line isn't removed.
Preserved5: this line isn't removed.
3 line(s) skipped.
`
if result != expectedResult {
t.Errorf("Output doesn't match expected: %s", result)
}
}
func TestShortenWhoisFilterLongerMode(t *testing.T) {
input := `
Information line that will be removed
# Comment that will be removed
Name: Redacted for privacy
Descr: This is a vvvvvvvvvvvvvvvvvvvvvvveeeeeeeeeeeeeeeeeeeerrrrrrrrrrrrrrrrrrrrrrrryyyyyyyyyyyyyyyyyyy long line that will be skipped.
Looooooooooooooooooooooong key: this line will be skipped.
Preserved1: this line isn't removed.
`
result := shortenWhoisFilter(input)
expectedResult := `Information line that will be removed
Descr: This is a vvvvvvvvvvvvvvvvvvvvvvveeeeeeeeeeeeeeeeeeeerrrrrrrrrrrrrrrrrrrrrrrryyyyyyyyyyyyyyyyyyy long line that will be skipped.
Looooooooooooooooooooooong key: this line will be skipped.
Preserved1: this line isn't removed.
7 line(s) skipped.
`
if result != expectedResult {
t.Errorf("Output doesn't match expected: %s", result)
}
}
func TestShortenWhoisFilterSkipNothing(t *testing.T) {
input := `Preserved1: this line isn't removed.
Preserved2: this line isn't removed.
Preserved3: this line isn't removed.
Preserved4: this line isn't removed.
Preserved5: this line isn't removed.
`
result := shortenWhoisFilter(input)
if result != input {
t.Errorf("Output doesn't match expected: %s", result)
}
}

View File

@ -11,7 +11,9 @@ require (
require (
github.com/felixge/httpsnoop v1.0.3 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/jarcoal/httpmock v1.3.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/pelletier/go-toml/v2 v2.0.6 // indirect

View File

@ -533,6 +533,8 @@ github.com/google/pprof v0.0.0-20210601050228-01bbb1931b22/go.mod h1:kpwsk12EmLe
github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.0.0-20220520183353-fd19c99a87aa/go.mod h1:17drOmN3MwGY7t0e+Ei9b45FFGA3fBs3x36SsCg1hq8=
@ -588,6 +590,8 @@ github.com/hashicorp/memberlist v0.5.0/go.mod h1:yvyXLpo0QaGE59Y7hDTsTzDD25JYBZ4
github.com/hashicorp/serf v0.10.1/go.mod h1:yL2t6BqATOLGc5HF7qbFkTfXoPIY0WZdWHfEvMqbG+4=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/jarcoal/httpmock v1.3.0 h1:2RJ8GP0IIaWwcC9Fp2BmVi8Kog3v2Hn7VXM3fTd+nuc=
github.com/jarcoal/httpmock v1.3.0/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=

View File

@ -56,8 +56,8 @@ func batchRequest(servers []string, endpoint string, command string) []string {
buf := make([]byte, 65536)
n, err := io.ReadFull(response.Body, buf)
if err != nil && err != io.ErrUnexpectedEOF {
ch <- channelData{i, err.Error()}
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
ch <- channelData{i, "request failed: " + err.Error()}
} else {
ch <- channelData{i, string(buf[:n])}
}

163
frontend/lgproxy_test.go Normal file
View File

@ -0,0 +1,163 @@
package main
import (
"errors"
"strings"
"testing"
"github.com/jarcoal/httpmock"
)
func TestBatchRequestIPv4(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpResponse := httpmock.NewStringResponder(200, "Mock Result")
httpmock.RegisterResponder("GET", "http://1.1.1.1:8000/mock?q=cmd", httpResponse)
httpmock.RegisterResponder("GET", "http://2.2.2.2:8000/mock?q=cmd", httpResponse)
httpmock.RegisterResponder("GET", "http://3.3.3.3:8000/mock?q=cmd", httpResponse)
setting.servers = []string{
"1.1.1.1",
"2.2.2.2",
"3.3.3.3",
}
setting.domain = ""
setting.proxyPort = 8000
response := batchRequest(setting.servers, "mock", "cmd")
if len(response) != 3 {
t.Error("Did not get response of all three mock servers")
}
for i := 0; i < len(response); i++ {
if response[i] != "Mock Result" {
t.Error("HTTP response mismatch")
}
}
}
func TestBatchRequestIPv6(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpResponse := httpmock.NewStringResponder(200, "Mock Result")
httpmock.RegisterResponder("GET", "http://[2001:db8::1]:8000/mock?q=cmd", httpResponse)
httpmock.RegisterResponder("GET", "http://[2001:db8::2]:8000/mock?q=cmd", httpResponse)
httpmock.RegisterResponder("GET", "http://[2001:db8::3]:8000/mock?q=cmd", httpResponse)
setting.servers = []string{
"2001:db8::1",
"2001:db8::2",
"2001:db8::3",
}
setting.domain = ""
setting.proxyPort = 8000
response := batchRequest(setting.servers, "mock", "cmd")
if len(response) != 3 {
t.Error("Did not get response of all three mock servers")
}
for i := 0; i < len(response); i++ {
if response[i] != "Mock Result" {
t.Error("HTTP response mismatch")
}
}
}
func TestBatchRequestEmptyResponse(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpResponse := httpmock.NewStringResponder(200, "")
httpmock.RegisterResponder("GET", "http://alpha:8000/mock?q=cmd", httpResponse)
httpmock.RegisterResponder("GET", "http://beta:8000/mock?q=cmd", httpResponse)
httpmock.RegisterResponder("GET", "http://gamma:8000/mock?q=cmd", httpResponse)
setting.servers = []string{
"alpha",
"beta",
"gamma",
}
setting.domain = ""
setting.proxyPort = 8000
response := batchRequest(setting.servers, "mock", "cmd")
if len(response) != 3 {
t.Error("Did not get response of all three mock servers")
}
for i := 0; i < len(response); i++ {
if !strings.Contains(response[i], "node returned empty response") {
t.Error("Did not produce error for empty response")
}
}
}
func TestBatchRequestDomainSuffix(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpResponse := httpmock.NewStringResponder(200, "Mock Result")
httpmock.RegisterResponder("GET", "http://alpha.suffix:8000/mock?q=cmd", httpResponse)
httpmock.RegisterResponder("GET", "http://beta.suffix:8000/mock?q=cmd", httpResponse)
httpmock.RegisterResponder("GET", "http://gamma.suffix:8000/mock?q=cmd", httpResponse)
setting.servers = []string{
"alpha",
"beta",
"gamma",
}
setting.domain = "suffix"
setting.proxyPort = 8000
response := batchRequest(setting.servers, "mock", "cmd")
if len(response) != 3 {
t.Error("Did not get response of all three mock servers")
}
for i := 0; i < len(response); i++ {
if response[i] != "Mock Result" {
t.Error("HTTP response mismatch")
}
}
}
func TestBatchRequestHTTPError(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpError := httpmock.NewErrorResponder(errors.New("Oops!"))
httpmock.RegisterResponder("GET", "http://alpha:8000/mock?q=cmd", httpError)
httpmock.RegisterResponder("GET", "http://beta:8000/mock?q=cmd", httpError)
httpmock.RegisterResponder("GET", "http://gamma:8000/mock?q=cmd", httpError)
setting.servers = []string{
"alpha",
"beta",
"gamma",
}
setting.domain = ""
setting.proxyPort = 8000
response := batchRequest(setting.servers, "mock", "cmd")
if len(response) != 3 {
t.Error("Did not get response of all three mock servers")
}
for i := 0; i < len(response); i++ {
if !strings.Contains(response[i], "request failed") {
t.Error("Did not produce HTTP error")
}
}
}
func TestBatchRequestInvalidServer(t *testing.T) {
setting.servers = []string{}
setting.domain = ""
setting.proxyPort = 8000
response := batchRequest([]string{"invalid"}, "mock", "cmd")
if len(response) != 1 {
t.Error("Did not get response of all mock servers")
}
if !strings.Contains(response[0], "invalid server") {
t.Error("Did not produce invalid server error")
}
}

View File

@ -8,6 +8,17 @@ import (
"testing"
)
const BirdSummaryData = `BIRD 2.0.8 ready.
Name Proto Table State Since Info
static1 Static master4 up 2021-08-27
static2 Static master6 up 2021-08-27
device1 Device --- up 2021-08-27
kernel1 Kernel master6 up 2021-08-27
kernel2 Kernel master4 up 2021-08-27
direct1 Direct --- up 2021-08-27
int_babel Babel --- up 2021-08-27
`
func initSettings() {
setting.servers = []string{"alpha"}
setting.serversDisplay = []string{"alpha"}
@ -101,17 +112,8 @@ func TestSummaryTableXSS(t *testing.T) {
func TestSummaryTableProtocolFilter(t *testing.T) {
initSettings()
setting.protocolFilter = []string{"Static", "Direct", "Babel"}
data := `BIRD 2.0.8 ready.
Name Proto Table State Since Info
static1 Static master4 up 2021-08-27
static2 Static master6 up 2021-08-27
device1 Device --- up 2021-08-27
kernel1 Kernel master6 up 2021-08-27
kernel2 Kernel master4 up 2021-08-27
direct1 Direct --- up 2021-08-27
int_babel Babel --- up 2021-08-27 `
result := string(summaryTable(data, "testserver"))
result := string(summaryTable(BirdSummaryData, "testserver"))
expectedInclude := []string{"static1", "static2", "int_babel", "direct1"}
expectedExclude := []string{"device1", "kernel1", "kernel2"}
@ -134,17 +136,8 @@ int_babel Babel --- up 2021-08-27 `
func TestSummaryTableNameFilter(t *testing.T) {
initSettings()
setting.nameFilter = "^static"
data := `BIRD 2.0.8 ready.
Name Proto Table State Since Info
static1 Static master4 up 2021-08-27
static2 Static master6 up 2021-08-27
device1 Device --- up 2021-08-27
kernel1 Kernel master6 up 2021-08-27
kernel2 Kernel master4 up 2021-08-27
direct1 Direct --- up 2021-08-27
int_babel Babel --- up 2021-08-27 `
result := string(summaryTable(data, "testserver"))
result := string(summaryTable(BirdSummaryData, "testserver"))
expectedInclude := []string{"device1", "kernel1", "kernel2", "direct1", "int_babel"}
expectedExclude := []string{"static1", "static2"}

View File

@ -0,0 +1,8 @@
package main
import "testing"
func TestParseSettings(t *testing.T) {
parseSettings()
// Good as long as it doesn't panic
}

View File

@ -1,12 +1,59 @@
package main
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/jarcoal/httpmock"
"github.com/magiconair/properties/assert"
)
func doTestTelegramIsCommand(t *testing.T, message string, command string, expected bool) {
if telegramIsCommand(message, command) != expected {
t.Errorf("telegramIsCommand(\"%s\", \"%s\") unexpected result", message, command)
result := telegramIsCommand(message, command)
assert.Equal(t, result, expected)
}
func mockTelegramCall(t *testing.T, msg string, raw bool) string {
return mockTelegramEndpointCall(t, "/telegram/", msg, raw)
}
func mockTelegramEndpointCall(t *testing.T, endpoint string, msg string, raw bool) string {
request := tgWebhookRequest{
Message: tgMessage{
MessageID: 123,
Chat: tgChat{
ID: 456,
},
Text: msg,
},
}
requestJson, err := json.Marshal(request)
if err != nil {
t.Fatal(err)
}
requestBody := bytes.NewReader(requestJson)
r := httptest.NewRequest(http.MethodGet, endpoint, requestBody)
w := httptest.NewRecorder()
webHandlerTelegramBot(w, r)
if raw {
return w.Body.String()
} else {
var response tgWebhookResponse
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Error(err)
}
assert.Equal(t, response.ChatID, request.Message.Chat.ID)
assert.Equal(t, response.ReplyToMessageID, request.Message.MessageID)
return response.Text
}
}
@ -41,3 +88,280 @@ func TestTelegramIsCommand(t *testing.T) {
doTestTelegramIsCommand(t, "/trace@test_bot_123 google.com", "trace", false)
doTestTelegramIsCommand(t, "/trace@test google.com", "trace", false)
}
func TestTelegramBatchRequestFormatSingleServer(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpResponse := httpmock.NewStringResponder(200, "Mock")
httpmock.RegisterResponder("GET", "http://alpha:8000/mock?q=cmd", httpResponse)
setting.servers = []string{"alpha"}
setting.domain = ""
setting.proxyPort = 8000
result := telegramBatchRequestFormat(setting.servers, "mock", "cmd", telegramDefaultPostProcess)
expected := "Mock\n\n"
assert.Equal(t, result, expected)
}
func TestTelegramBatchRequestFormatMultipleServers(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpResponse := httpmock.NewStringResponder(200, "Mock")
httpmock.RegisterResponder("GET", "http://alpha:8000/mock?q=cmd", httpResponse)
httpmock.RegisterResponder("GET", "http://beta:8000/mock?q=cmd", httpResponse)
httpmock.RegisterResponder("GET", "http://gamma:8000/mock?q=cmd", httpResponse)
setting.servers = []string{
"alpha",
"beta",
"gamma",
}
setting.domain = ""
setting.proxyPort = 8000
result := telegramBatchRequestFormat(setting.servers, "mock", "cmd", telegramDefaultPostProcess)
expected := "alpha\nMock\n\nbeta\nMock\n\ngamma\nMock\n\n"
assert.Equal(t, result, expected)
}
func TestWebHandlerTelegramBotBadJSON(t *testing.T) {
requestBody := strings.NewReader("{bad json}")
r := httptest.NewRequest(http.MethodGet, "/telegram/", requestBody)
w := httptest.NewRecorder()
webHandlerTelegramBot(w, r)
response := w.Body.String()
assert.Equal(t, response, "")
}
func TestWebHandlerTelegramBotTrace(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpResponse := httpmock.NewStringResponder(200, "Mock Response")
httpmock.RegisterResponder("GET", "http://alpha:8000/traceroute?q=1.1.1.1", httpResponse)
setting.servers = []string{"alpha"}
setting.domain = ""
setting.proxyPort = 8000
response := mockTelegramCall(t, "/trace 1.1.1.1", false)
assert.Equal(t, response, "```\nMock Response\n```")
}
func TestWebHandlerTelegramBotTraceWithServerList(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpResponse := httpmock.NewStringResponder(200, "Mock Response")
httpmock.RegisterResponder("GET", "http://alpha:8000/traceroute?q=1.1.1.1", httpResponse)
setting.servers = []string{"alpha"}
setting.domain = ""
setting.proxyPort = 8000
response := mockTelegramEndpointCall(t, "/telegram/alpha", "/trace 1.1.1.1", false)
assert.Equal(t, response, "```\nMock Response\n```")
}
func TestWebHandlerTelegramBotRoute(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpResponse := httpmock.NewStringResponder(200, "Mock Response")
httpmock.RegisterResponder("GET", "http://alpha:8000/bird?q="+url.QueryEscape("show route for 1.1.1.1 primary"), httpResponse)
setting.servers = []string{"alpha"}
setting.domain = ""
setting.proxyPort = 8000
response := mockTelegramCall(t, "/route 1.1.1.1", false)
assert.Equal(t, response, "```\nMock Response\n```")
}
func TestWebHandlerTelegramBotPath(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpResponse := httpmock.NewStringResponder(200, `
BGP.as_path: 123 456
`)
httpmock.RegisterResponder("GET", "http://alpha:8000/bird?q="+url.QueryEscape("show route for 1.1.1.1 all primary"), httpResponse)
setting.servers = []string{"alpha"}
setting.domain = ""
setting.proxyPort = 8000
response := mockTelegramCall(t, "/path 1.1.1.1", false)
assert.Equal(t, response, "```\n123 456\n```")
}
func TestWebHandlerTelegramBotPathMissing(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpResponse := httpmock.NewStringResponder(200, "No path in this response")
httpmock.RegisterResponder("GET", "http://alpha:8000/bird?q="+url.QueryEscape("show route for 1.1.1.1 all primary"), httpResponse)
setting.servers = []string{"alpha"}
setting.domain = ""
setting.proxyPort = 8000
response := mockTelegramCall(t, "/path 1.1.1.1", false)
assert.Equal(t, response, "```\nempty result\n```")
}
func TestWebHandlerTelegramBotWhois(t *testing.T) {
server := WhoisServer{
t: t,
expectedQuery: "AS6939",
response: AS6939Response,
}
server.Listen()
go server.Run()
defer server.Close()
setting.netSpecificMode = ""
setting.whoisServer = server.server.Addr().String()
response := mockTelegramCall(t, "/whois AS6939", false)
assert.Equal(t, response, "```"+server.response+"```")
}
func TestWebHandlerTelegramBotWhoisDN42Mode(t *testing.T) {
server := WhoisServer{
t: t,
expectedQuery: "AS4242422547",
response: `
Query for AS4242422547
`,
}
server.Listen()
go server.Run()
defer server.Close()
setting.netSpecificMode = "dn42"
setting.whoisServer = server.server.Addr().String()
response := mockTelegramCall(t, "/whois 2547", false)
assert.Equal(t, response, "```"+server.response+"```")
}
func TestWebHandlerTelegramBotWhoisDN42ModeFullASN(t *testing.T) {
server := WhoisServer{
t: t,
expectedQuery: "AS4242422547",
response: `
Query for AS4242422547
`,
}
server.Listen()
go server.Run()
defer server.Close()
setting.netSpecificMode = "dn42"
setting.whoisServer = server.server.Addr().String()
response := mockTelegramCall(t, "/whois 4242422547", false)
assert.Equal(t, response, "```"+server.response+"```")
}
func TestWebHandlerTelegramBotWhoisShortenMode(t *testing.T) {
server := WhoisServer{
t: t,
expectedQuery: "AS6939",
response: `
Information line that will be removed
# Comment that will be removed
Name: Redacted for privacy
Descr: This is a vvvvvvvvvvvvvvvvvvvvvvveeeeeeeeeeeeeeeeeeeerrrrrrrrrrrrrrrrrrrrrrrryyyyyyyyyyyyyyyyyyy long line that will be skipped.
Looooooooooooooooooooooong key: this line will be skipped.
Preserved1: this line isn't removed.
Preserved2: this line isn't removed.
Preserved3: this line isn't removed.
Preserved4: this line isn't removed.
Preserved5: this line isn't removed.
`,
}
expectedResult := `Preserved1: this line isn't removed.
Preserved2: this line isn't removed.
Preserved3: this line isn't removed.
Preserved4: this line isn't removed.
Preserved5: this line isn't removed.
3 line(s) skipped.`
server.Listen()
go server.Run()
defer server.Close()
setting.netSpecificMode = "shorten"
setting.whoisServer = server.server.Addr().String()
response := mockTelegramCall(t, "/whois AS6939", false)
assert.Equal(t, response, "```\n"+expectedResult+"\n```")
}
func TestWebHandlerTelegramBotHelp(t *testing.T) {
response := mockTelegramCall(t, "/help", false)
if !strings.Contains(response, "/trace") {
t.Error("Did not get help message")
}
}
func TestWebHandlerTelegramBotUnknownCommand(t *testing.T) {
response := mockTelegramCall(t, "/nonexistent", true)
assert.Equal(t, response, "")
}
func TestWebHandlerTelegramBotNotCommand(t *testing.T) {
response := mockTelegramCall(t, "random chat message", true)
assert.Equal(t, response, "")
}
func TestWebHandlerTelegramBotEmptyResponse(t *testing.T) {
server := WhoisServer{
t: t,
expectedQuery: "AS6939",
response: "",
}
server.Listen()
go server.Run()
defer server.Close()
setting.netSpecificMode = ""
setting.whoisServer = server.server.Addr().String()
response := mockTelegramCall(t, "/whois AS6939", false)
assert.Equal(t, response, "```\nempty result\n```")
}
func TestWebHandlerTelegramBotTruncateLongResponse(t *testing.T) {
server := WhoisServer{
t: t,
expectedQuery: "AS6939",
response: strings.Repeat("A", 65536),
}
server.Listen()
go server.Run()
defer server.Close()
setting.netSpecificMode = ""
setting.whoisServer = server.server.Addr().String()
response := mockTelegramCall(t, "/whois AS6939", false)
assert.Equal(t, response, "```\n"+strings.Repeat("A", 4096)+"\n```")
}

25
frontend/template_test.go Normal file
View File

@ -0,0 +1,25 @@
package main
import (
"testing"
"github.com/magiconair/properties/assert"
)
func TestSummaryRowDataNameHasPrefix(t *testing.T) {
data := SummaryRowData{
Name: "mock",
}
assert.Equal(t, data.NameHasPrefix("m"), true)
assert.Equal(t, data.NameHasPrefix("n"), false)
}
func TestSummaryRowDataNameContains(t *testing.T) {
data := SummaryRowData{
Name: "mock",
}
assert.Equal(t, data.NameContains("oc"), true)
assert.Equal(t, data.NameContains("no"), false)
}

View File

@ -0,0 +1,89 @@
package main
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/jarcoal/httpmock"
"github.com/magiconair/properties/assert"
)
func TestServerError(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/error", nil)
w := httptest.NewRecorder()
serverError(w, r)
assert.Equal(t, w.Code, http.StatusInternalServerError)
}
func TestWebHandlerWhois(t *testing.T) {
server := WhoisServer{
t: t,
expectedQuery: "AS6939",
response: AS6939Response,
}
server.Listen()
go server.Run()
defer server.Close()
setting.netSpecificMode = ""
setting.whoisServer = server.server.Addr().String()
r := httptest.NewRequest(http.MethodGet, "/whois/AS6939", nil)
w := httptest.NewRecorder()
webHandlerWhois(w, r)
assert.Equal(t, w.Code, http.StatusOK)
if !strings.Contains(w.Body.String(), "HURRICANE") {
t.Error("Body does not contain whois result")
}
}
func TestWebBackendCommunicator(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
input := readDataFile(t, "frontend/test_data/bgpmap_case1.txt")
httpResponse := httpmock.NewStringResponder(200, input)
httpmock.RegisterResponder("GET", "http://alpha:8000/bird?q="+url.QueryEscape("show route for 1.1.1.1 all"), httpResponse)
setting.servers = []string{"alpha"}
setting.domain = ""
setting.proxyPort = 8000
setting.dnsInterface = ""
setting.whoisServer = ""
r := httptest.NewRequest(http.MethodGet, "/route_bgpmap/alpha/1.1.1.1", nil)
w := httptest.NewRecorder()
handler := webBackendCommunicator("bird", "route_all")
handler(w, r)
assert.Equal(t, w.Code, http.StatusOK)
}
func TestWebHandlerBGPMap(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
input := readDataFile(t, "frontend/test_data/bgpmap_case1.txt")
httpResponse := httpmock.NewStringResponder(200, input)
httpmock.RegisterResponder("GET", "http://alpha:8000/bird?q="+url.QueryEscape("show route for 1.1.1.1 all"), httpResponse)
setting.servers = []string{"alpha"}
setting.domain = ""
setting.proxyPort = 8000
setting.dnsInterface = ""
setting.whoisServer = ""
r := httptest.NewRequest(http.MethodGet, "/route_bgpmap/alpha/1.1.1.1", nil)
w := httptest.NewRecorder()
handler := webHandlerBGPMap("bird", "route_bgpmap")
handler(w, r)
assert.Equal(t, w.Code, http.StatusOK)
}

View File

@ -6,6 +6,8 @@ import (
"os/exec"
"strings"
"time"
"github.com/google/shlex"
)
// Send a whois request
@ -15,7 +17,13 @@ func whois(s string) string {
}
if strings.HasPrefix(setting.whoisServer, "/") {
cmd := exec.Command(setting.whoisServer, s)
args, err := shlex.Split(setting.whoisServer)
if err != nil {
return err.Error()
}
args = append(args, s)
cmd := exec.Command(args[0], args[1:]...)
output, err := cmd.CombinedOutput()
if err != nil {
return err.Error()
@ -26,7 +34,13 @@ func whois(s string) string {
return string(output)
} else {
buf := make([]byte, 65536)
conn, err := net.DialTimeout("tcp", setting.whoisServer+":43", 5*time.Second)
whoisServer := setting.whoisServer
if !strings.Contains(whoisServer, ":") {
whoisServer = whoisServer + ":43"
}
conn, err := net.DialTimeout("tcp", whoisServer, 5*time.Second)
if err != nil {
return err.Error()
}
@ -35,7 +49,7 @@ func whois(s string) string {
conn.Write([]byte(s + "\r\n"))
n, err := io.ReadFull(conn, buf)
if err != nil && err != io.ErrUnexpectedEOF {
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
return err.Error()
}
return string(buf[:n])

View File

@ -1,14 +1,78 @@
package main
import (
"bufio"
"net"
"strings"
"testing"
)
func TestWhois(t *testing.T) {
checkNetwork(t)
type WhoisServer struct {
t *testing.T
expectedQuery string
response string
server net.Listener
}
setting.whoisServer = "whois.arin.net"
const AS6939Response = `
ASNumber: 6939
ASName: HURRICANE
ASHandle: AS6939
RegDate: 1996-06-28
Updated: 2003-11-04
Ref: https://rdap.arin.net/registry/autnum/6939
`
func (s *WhoisServer) Listen() {
var err error
s.server, err = net.Listen("tcp", "127.0.0.1:0")
if err != nil {
s.t.Error(err)
}
}
func (s *WhoisServer) Run() {
for {
conn, err := s.server.Accept()
if err != nil {
break
}
if conn == nil {
break
}
reader := bufio.NewReader(conn)
query, err := reader.ReadBytes('\n')
if err != nil {
break
}
if strings.TrimSpace(string(query)) != s.expectedQuery {
s.t.Errorf("Query %s doesn't match expectation %s", string(query), s.expectedQuery)
}
conn.Write([]byte(s.response))
conn.Close()
}
}
func (s *WhoisServer) Close() {
if s.server == nil {
return
}
s.server.Close()
}
func TestWhois(t *testing.T) {
server := WhoisServer{
t: t,
expectedQuery: "AS6939",
response: AS6939Response,
}
server.Listen()
go server.Run()
defer server.Close()
setting.whoisServer = server.server.Addr().String()
result := whois("AS6939")
if !strings.Contains(result, "HURRICANE") {
t.Errorf("Whois AS6939 failed, got %s", result)
@ -22,3 +86,43 @@ func TestWhoisWithoutServer(t *testing.T) {
t.Errorf("Whois AS6939 without server produced output, got %s", result)
}
}
func TestWhoisConnectionError(t *testing.T) {
setting.whoisServer = "127.0.0.1:0"
result := whois("AS6939")
if !strings.Contains(result, "connect: connection refused") {
t.Errorf("Whois AS6939 without server produced output, got %s", result)
}
}
func TestWhoisHostProcess(t *testing.T) {
setting.whoisServer = "/bin/sh -c \"echo Mock Result\""
result := whois("AS6939")
if result != "Mock Result\n" {
t.Errorf("Whois didn't produce expected result, got %s", result)
}
}
func TestWhoisHostProcessMalformedCommand(t *testing.T) {
setting.whoisServer = "/bin/sh -c \"mock"
result := whois("AS6939")
if result != "EOF found when expecting closing quote" {
t.Errorf("Whois didn't produce expected result, got %s", result)
}
}
func TestWhoisHostProcessError(t *testing.T) {
setting.whoisServer = "/nonexistent"
result := whois("AS6939")
if !strings.Contains(result, "no such file or directory") {
t.Errorf("Whois didn't produce expected result, got %s", result)
}
}
func TestWhoisHostProcessVeryLong(t *testing.T) {
setting.whoisServer = "/bin/sh -c \"for i in $(seq 1 131072); do printf 'A'; done\""
result := whois("AS6939")
if len(result) != 65535 {
t.Errorf("Whois result incorrectly truncated, actual len %d", len(result))
}
}

View File

@ -8,19 +8,23 @@ import (
"strings"
)
const MAX_LINE_SIZE = 1024
// Read a line from bird socket, removing preceding status number, output it.
// Returns if there are more lines.
func birdReadln(bird io.Reader, w io.Writer) bool {
// Read from socket byte by byte, until reaching newline character
c := make([]byte, 1024, 1024)
c := make([]byte, MAX_LINE_SIZE)
pos := 0
for {
if pos >= 1024 {
// Leave one byte for newline character
if pos >= MAX_LINE_SIZE-1 {
break
}
_, err := bird.Read(c[pos : pos+1])
if err != nil {
panic(err)
w.Write([]byte(err.Error()))
return false
}
if c[pos] == byte('\n') {
break
@ -29,6 +33,7 @@ func birdReadln(bird io.Reader, w io.Writer) bool {
}
c = c[:pos+1]
c[pos] = '\n'
// print(string(c[:]))
// Remove preceding status number, different situations

213
proxy/bird_test.go Normal file
View File

@ -0,0 +1,213 @@
package main
import (
"bufio"
"bytes"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"path"
"strings"
"testing"
"github.com/magiconair/properties/assert"
)
type BirdServer struct {
t *testing.T
expectedQuery string
response string
server net.Listener
socket string
injectError string
}
func (s *BirdServer) initSocket() {
tmpDir, err := ioutil.TempDir("", "bird-lgproxy-go-mock")
if err != nil {
s.t.Fatal(err)
}
s.socket = path.Join(tmpDir, "mock.socket")
}
func (s *BirdServer) Listen() {
s.initSocket()
var err error
s.server, err = net.Listen("unix", s.socket)
if err != nil {
s.t.Error(err)
}
}
func (s *BirdServer) Run() {
for {
conn, err := s.server.Accept()
if err != nil {
break
}
if conn == nil {
break
}
reader := bufio.NewReader(conn)
conn.Write([]byte("1234 Hello from mock bird\n"))
query, err := reader.ReadBytes('\n')
if err != nil {
break
}
if strings.TrimSpace(string(query)) != "restrict" {
s.t.Errorf("Did not restrict bird permissions")
}
if s.injectError == "restriction" {
conn.Write([]byte("1234 Restriction is disabled!\n"))
} else {
conn.Write([]byte("1234 Access restricted\n"))
}
query, err = reader.ReadBytes('\n')
if err != nil {
break
}
if strings.TrimSpace(string(query)) != s.expectedQuery {
s.t.Errorf("Query %s doesn't match expectation %s", string(query), s.expectedQuery)
}
responseList := strings.Split(s.response, "\n")
for i := range responseList {
if i == len(responseList)-1 {
if s.injectError == "eof" {
conn.Write([]byte("0000 " + responseList[i]))
} else {
conn.Write([]byte("0000 " + responseList[i] + "\n"))
}
} else {
conn.Write([]byte("1234 " + responseList[i] + "\n"))
}
}
conn.Close()
}
}
func (s *BirdServer) Close() {
if s.server == nil {
return
}
s.server.Close()
}
func TestBirdReadln(t *testing.T) {
input := strings.NewReader("1234 Bird Message\n")
var output bytes.Buffer
birdReadln(input, &output)
assert.Equal(t, output.String(), "Bird Message\n")
}
func TestBirdReadlnNoPrefix(t *testing.T) {
input := strings.NewReader(" Message without prefix\n")
var output bytes.Buffer
birdReadln(input, &output)
assert.Equal(t, output.String(), "Message without prefix\n")
}
func TestBirdReadlnVeryLongLine(t *testing.T) {
input := strings.NewReader(strings.Repeat("A", 4096))
var output bytes.Buffer
birdReadln(input, &output)
assert.Equal(t, output.String(), strings.Repeat("A", 1022)+"\n")
}
func TestBirdWriteln(t *testing.T) {
var output bytes.Buffer
birdWriteln(&output, "Test command")
assert.Equal(t, output.String(), "Test command\n")
}
func TestBirdHandlerWithoutQuery(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/bird", nil)
w := httptest.NewRecorder()
birdHandler(w, r)
}
func TestBirdHandlerWithQuery(t *testing.T) {
server := BirdServer{
t: t,
expectedQuery: "show protocols",
response: "Mock Response\nSecond Line",
injectError: "",
}
server.Listen()
go server.Run()
defer server.Close()
setting.birdSocket = server.socket
r := httptest.NewRequest(http.MethodGet, "/bird?q="+url.QueryEscape(server.expectedQuery), nil)
w := httptest.NewRecorder()
birdHandler(w, r)
assert.Equal(t, w.Code, http.StatusOK)
assert.Equal(t, w.Body.String(), server.response+"\n")
}
func TestBirdHandlerWithBadSocket(t *testing.T) {
setting.birdSocket = "/nonexistent.sock"
r := httptest.NewRequest(http.MethodGet, "/bird?q="+url.QueryEscape("mock"), nil)
w := httptest.NewRecorder()
birdHandler(w, r)
assert.Equal(t, w.Code, http.StatusInternalServerError)
}
func TestBirdHandlerWithoutRestriction(t *testing.T) {
server := BirdServer{
t: t,
expectedQuery: "show protocols",
response: "Mock Response",
injectError: "restriction",
}
server.Listen()
go server.Run()
defer server.Close()
setting.birdSocket = server.socket
r := httptest.NewRequest(http.MethodGet, "/bird?q="+url.QueryEscape("mock"), nil)
w := httptest.NewRecorder()
birdHandler(w, r)
assert.Equal(t, w.Code, http.StatusInternalServerError)
}
func TestBirdHandlerEOF(t *testing.T) {
server := BirdServer{
t: t,
expectedQuery: "show protocols",
response: "Mock Response\nSecond Line",
injectError: "eof",
}
server.Listen()
go server.Run()
defer server.Close()
setting.birdSocket = server.socket
r := httptest.NewRequest(http.MethodGet, "/bird?q="+url.QueryEscape("show protocols"), nil)
w := httptest.NewRecorder()
birdHandler(w, r)
assert.Equal(t, w.Code, http.StatusOK)
assert.Equal(t, w.Body.String(), "Mock Response\nEOF")
}

View File

@ -21,39 +21,49 @@ func invalidHandler(httpW http.ResponseWriter, httpR *http.Request) {
httpW.Write([]byte("Invalid Request\n"))
}
func hasAccess(remoteAddr string) bool {
// setting.allowedIPs will always have at least one element because of how it's defined
if len(setting.allowedIPs) == 0 {
return true
}
if !strings.Contains(remoteAddr, ":") {
return false
}
// Remove port from IP and remove brackets that are around IPv6 addresses
remoteAddr = remoteAddr[0:strings.LastIndex(remoteAddr, ":")]
remoteAddr = strings.Trim(remoteAddr, "[]")
ipObject := net.ParseIP(remoteAddr)
if ipObject == nil {
return false
}
for _, allowedIP := range setting.allowedIPs {
if ipObject.Equal(allowedIP) {
return true
}
}
return false
}
// Access handler, check to see if client IP in allowed IPs, continue if it is, send to invalidHandler if not
func accessHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(httpW http.ResponseWriter, httpR *http.Request) {
// setting.allowedIPs will always have at least one element because of how it's defined
if setting.allowedIPs[0] == "" {
if hasAccess(httpR.RemoteAddr) {
next.ServeHTTP(httpW, httpR)
return
}
IPPort := httpR.RemoteAddr
// Remove port from IP and remove brackets that are around IPv6 addresses
requestIp := IPPort[0:strings.LastIndex(IPPort, ":")]
requestIp = strings.Replace(requestIp, "[", "", -1)
requestIp = strings.Replace(requestIp, "]", "", -1)
for _, allowedIP := range setting.allowedIPs {
if requestIp == allowedIP {
next.ServeHTTP(httpW, httpR)
return
}
}
} else {
invalidHandler(httpW, httpR)
return
}
})
}
type settingType struct {
birdSocket string
listen string
allowedIPs []string
allowedIPs []net.IP
tr_bin string
tr_flags []string
tr_raw bool

78
proxy/main_test.go Normal file
View File

@ -0,0 +1,78 @@
package main
import (
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/magiconair/properties/assert"
)
func TestHasAccessNotConfigured(t *testing.T) {
setting.allowedIPs = []net.IP{}
assert.Equal(t, hasAccess("whatever"), true)
}
func TestHasAccessAllowIPv4(t *testing.T) {
setting.allowedIPs = []net.IP{net.ParseIP("1.2.3.4")}
assert.Equal(t, hasAccess("1.2.3.4:4321"), true)
}
func TestHasAccessDenyIPv4(t *testing.T) {
setting.allowedIPs = []net.IP{net.ParseIP("4.3.2.1")}
assert.Equal(t, hasAccess("1.2.3.4:4321"), false)
}
func TestHasAccessAllowIPv6(t *testing.T) {
setting.allowedIPs = []net.IP{net.ParseIP("2001:db8::1")}
assert.Equal(t, hasAccess("[2001:db8::1]:4321"), true)
}
func TestHasAccessAllowIPv6DifferentForm(t *testing.T) {
setting.allowedIPs = []net.IP{net.ParseIP("2001:0db8::1")}
assert.Equal(t, hasAccess("[2001:db8::1]:4321"), true)
}
func TestHasAccessDenyIPv6(t *testing.T) {
setting.allowedIPs = []net.IP{net.ParseIP("2001:db8::2")}
assert.Equal(t, hasAccess("[2001:db8::1]:4321"), false)
}
func TestHasAccessBadClientIP(t *testing.T) {
setting.allowedIPs = []net.IP{net.ParseIP("1.2.3.4")}
assert.Equal(t, hasAccess("not an IP"), false)
}
func TestHasAccessBadClientIPPort(t *testing.T) {
setting.allowedIPs = []net.IP{net.ParseIP("1.2.3.4")}
assert.Equal(t, hasAccess("not an IP:not a port"), false)
}
func TestAccessHandlerAllow(t *testing.T) {
baseHandler := http.NotFoundHandler()
wrappedHandler := accessHandler(baseHandler)
r := httptest.NewRequest(http.MethodGet, "/mock", nil)
r.RemoteAddr = "1.2.3.4:4321"
w := httptest.NewRecorder()
setting.allowedIPs = []net.IP{net.ParseIP("1.2.3.4")}
wrappedHandler.ServeHTTP(w, r)
assert.Equal(t, w.Code, http.StatusNotFound)
}
func TestAccessHandlerDeny(t *testing.T) {
baseHandler := http.NotFoundHandler()
wrappedHandler := accessHandler(baseHandler)
r := httptest.NewRequest(http.MethodGet, "/mock", nil)
r.RemoteAddr = "1.2.3.4:4321"
w := httptest.NewRecorder()
setting.allowedIPs = []net.IP{net.ParseIP("4.3.2.1")}
wrappedHandler.ServeHTTP(w, r)
assert.Equal(t, w.Code, http.StatusInternalServerError)
}

View File

@ -2,6 +2,7 @@ package main
import (
"fmt"
"net"
"strings"
"github.com/google/shlex"
@ -66,9 +67,17 @@ func parseSettings() {
setting.listen = viperSettings.Listen
if viperSettings.AllowedIPs != "" {
setting.allowedIPs = strings.Split(viperSettings.AllowedIPs, ",")
for _, ip := range strings.Split(viperSettings.AllowedIPs, ",") {
ipObject := net.ParseIP(ip)
if ipObject == nil {
fmt.Printf("Parse IP %s failed\n", ip)
continue
}
setting.allowedIPs = append(setting.allowedIPs, ipObject)
}
} else {
setting.allowedIPs = []string{""}
setting.allowedIPs = []net.IP{}
}
var err error

8
proxy/settings_test.go Normal file
View File

@ -0,0 +1,8 @@
package main
import "testing"
func TestParseSettings(t *testing.T) {
parseSettings()
// Good as long as it doesn't panic
}

168
proxy/traceroute_test.go Normal file
View File

@ -0,0 +1,168 @@
package main
import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"github.com/magiconair/properties/assert"
)
func TestTracerouteArgsToString(t *testing.T) {
result := tracerouteArgsToString("traceroute", []string{
"-a",
"-b",
"-c",
}, []string{
"google.com",
})
assert.Equal(t, result, "traceroute -a -b -c google.com")
}
func TestTracerouteTryExecuteSuccess(t *testing.T) {
_, err := tracerouteTryExecute("sh", []string{
"-c",
}, []string{
"true",
})
if err != nil {
t.Error(err)
}
}
func TestTracerouteTryExecuteFail(t *testing.T) {
_, err := tracerouteTryExecute("sh", []string{
"-c",
}, []string{
"false",
})
if err == nil {
t.Error("Should trigger error, not triggered")
}
}
func TestTracerouteDetectSuccess(t *testing.T) {
result := tracerouteDetect("sh", []string{
"-c",
"true",
})
assert.Equal(t, result, true)
}
func TestTracerouteDetectFail(t *testing.T) {
result := tracerouteDetect("sh", []string{
"-c",
"false",
})
assert.Equal(t, result, false)
}
func TestTracerouteAutodetect(t *testing.T) {
pathBackup := os.Getenv("PATH")
os.Setenv("PATH", "")
defer os.Setenv("PATH", pathBackup)
setting.tr_bin = ""
setting.tr_flags = []string{}
tracerouteAutodetect()
// Should not panic
}
func TestTracerouteAutodetectExisting(t *testing.T) {
setting.tr_bin = "mock"
setting.tr_flags = []string{"mock"}
tracerouteAutodetect()
assert.Equal(t, setting.tr_bin, "mock")
assert.Equal(t, setting.tr_flags, []string{"mock"})
}
func TestTracerouteAutodetectFlagsOnly(t *testing.T) {
pathBackup := os.Getenv("PATH")
os.Setenv("PATH", "")
defer os.Setenv("PATH", pathBackup)
setting.tr_bin = "mock"
setting.tr_flags = nil
tracerouteAutodetect()
// Should not panic
}
func TestTracerouteHandlerWithoutQuery(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/traceroute", nil)
w := httptest.NewRecorder()
tracerouteHandler(w, r)
assert.Equal(t, w.Code, http.StatusInternalServerError)
if !strings.Contains(w.Body.String(), "Invalid Request") {
t.Error("Did not get invalid request")
}
}
func TestTracerouteHandlerShlexError(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/traceroute?q="+url.QueryEscape("\"1.1.1.1"), nil)
w := httptest.NewRecorder()
tracerouteHandler(w, r)
assert.Equal(t, w.Code, http.StatusInternalServerError)
if !strings.Contains(w.Body.String(), "parse") {
t.Error("Did not get parsing error message")
}
}
func TestTracerouteHandlerNoTracerouteFound(t *testing.T) {
setting.tr_bin = ""
setting.tr_flags = nil
r := httptest.NewRequest(http.MethodGet, "/traceroute?q="+url.QueryEscape("1.1.1.1"), nil)
w := httptest.NewRecorder()
tracerouteHandler(w, r)
assert.Equal(t, w.Code, http.StatusInternalServerError)
if !strings.Contains(w.Body.String(), "not supported") {
t.Error("Did not get not supported error message")
}
}
func TestTracerouteHandlerExecuteError(t *testing.T) {
setting.tr_bin = "sh"
setting.tr_flags = []string{"-c", "false"}
setting.tr_raw = true
r := httptest.NewRequest(http.MethodGet, "/traceroute?q="+url.QueryEscape("1.1.1.1"), nil)
w := httptest.NewRecorder()
tracerouteHandler(w, r)
assert.Equal(t, w.Code, http.StatusInternalServerError)
if !strings.Contains(w.Body.String(), "Error executing traceroute") {
t.Error("Did not get not execute error message")
}
}
func TestTracerouteHandlerRaw(t *testing.T) {
setting.tr_bin = "sh"
setting.tr_flags = []string{"-c", "echo Mock"}
setting.tr_raw = true
r := httptest.NewRequest(http.MethodGet, "/traceroute?q="+url.QueryEscape("1.1.1.1"), nil)
w := httptest.NewRecorder()
tracerouteHandler(w, r)
assert.Equal(t, w.Code, http.StatusOK)
assert.Equal(t, w.Body.String(), "Mock\n")
}
func TestTracerouteHandlerPostprocess(t *testing.T) {
setting.tr_bin = "sh"
setting.tr_flags = []string{"-c", "echo \"first line\n 2 *\nthird line\""}
setting.tr_raw = false
r := httptest.NewRequest(http.MethodGet, "/traceroute?q="+url.QueryEscape("1.1.1.1"), nil)
w := httptest.NewRecorder()
tracerouteHandler(w, r)
assert.Equal(t, w.Code, http.StatusOK)
assert.Equal(t, w.Body.String(), "first line\nthird line\n\n1 hops not responding.")
}