diff --git a/.drone.yml b/.drone.yml index e3ffbe6..b576c00 100644 --- a/.drone.yml +++ b/.drone.yml @@ -23,21 +23,14 @@ steps: - go vet - go build - - name: stage - image: alpine - commands: - - mkdir artifacts - - mv frontend/frontend artifacts/ - - mv proxy/proxy artifacts/ - - tar -cvzf bird-lg-go.bdn42.tar.gz -C artifacts . - - name: upload artifacts image: git.burble.dn42/burble.dn42/drone-gitea-pkg-plugin:latest settings: token: from_secret: TOKEN version: RELEASE - artifact: bird-lg-go.bdn42.tar.gz + artifact: proxy + filename: proxy/proxy package: bird-lg-go owner: burble.dn42 diff --git a/README.md b/README.md index 6efb99a..62b5041 100644 --- a/README.md +++ b/README.md @@ -127,7 +127,7 @@ Configuration is handled by [viper](https://github.com/spf13/viper), any config | Config Key | Parameter | Environment Variable | Description | | ---------- | --------- | -------------------- | ----------- | -| allowed_ips | --allowed | ALLOWED_IPS | IPs allowed to access this proxy, separated by commas. Don't set to allow all IPs. (default "") | +| allowed_ips | --allowed | ALLOWED_IPS | IPs or networks allowed to access this proxy, separated by commas. Don't set to allow all IPs. (default "") | | bird_socket | --bird | BIRD_SOCKET | socket file for bird, set either in parameter or environment variable BIRD_SOCKET (default "/var/run/bird/bird.ctl") | | listen | --listen | BIRDLG_PROXY_PORT | listen address, set either in parameter or environment variable BIRDLG_PROXY_PORT(default "8000") | | traceroute_bin | --traceroute_bin | BIRDLG_TRACEROUTE_BIN | traceroute binary file, set either in parameter or environment variable BIRDLG_TRACEROUTE_BIN | diff --git a/proxy/main.go b/proxy/main.go index 07ddc65..05b8ea1 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -22,8 +22,8 @@ func invalidHandler(httpW http.ResponseWriter, httpR *http.Request) { } func hasAccess(remoteAddr string) bool { - // setting.allowedIPs will always have at least one element because of how it's defined - if len(setting.allowedIPs) == 0 { + // setting.allowedNets will always have at least one element because of how it's defined + if len(setting.allowedNets) == 0 { return true } @@ -40,8 +40,8 @@ func hasAccess(remoteAddr string) bool { return false } - for _, allowedIP := range setting.allowedIPs { - if ipObject.Equal(allowedIP) { + for _, net := range setting.allowedNets { + if net.Contains(ipObject) { return true } } @@ -49,7 +49,7 @@ func hasAccess(remoteAddr string) bool { return false } -// Access handler, check to see if client IP in allowed IPs, continue if it is, send to invalidHandler if not +// Access handler, check to see if client IP in allowed nets, continue if it is, send to invalidHandler if not func accessHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(httpW http.ResponseWriter, httpR *http.Request) { if hasAccess(httpR.RemoteAddr) { @@ -61,12 +61,12 @@ func accessHandler(next http.Handler) http.Handler { } type settingType struct { - birdSocket string - listen string - allowedIPs []net.IP - tr_bin string - tr_flags []string - tr_raw bool + birdSocket string + listen string + allowedNets []*net.IPNet + tr_bin string + tr_flags []string + tr_raw bool } var setting settingType diff --git a/proxy/main_test.go b/proxy/main_test.go index 1f33449..52429bd 100644 --- a/proxy/main_test.go +++ b/proxy/main_test.go @@ -10,42 +10,61 @@ import ( ) func TestHasAccessNotConfigured(t *testing.T) { - setting.allowedIPs = []net.IP{} + setting.allowedNets = []*net.IPNet{} assert.Equal(t, hasAccess("whatever"), true) } func TestHasAccessAllowIPv4(t *testing.T) { - setting.allowedIPs = []net.IP{net.ParseIP("1.2.3.4")} + _, netip, _ := net.ParseCIDR("1.2.3.4/32") + setting.allowedNets = []*net.IPNet{netip} + assert.Equal(t, hasAccess("1.2.3.4:4321"), true) +} + +func TestHasAccessAllowIPv4Net(t *testing.T) { + _, netip, _ := net.ParseCIDR("1.2.3.0/24") + setting.allowedNets = []*net.IPNet{netip} assert.Equal(t, hasAccess("1.2.3.4:4321"), true) } func TestHasAccessDenyIPv4(t *testing.T) { - setting.allowedIPs = []net.IP{net.ParseIP("4.3.2.1")} + _, netip, _ := net.ParseCIDR("4.3.2.1/32") + setting.allowedNets = []*net.IPNet{netip} assert.Equal(t, hasAccess("1.2.3.4:4321"), false) } func TestHasAccessAllowIPv6(t *testing.T) { - setting.allowedIPs = []net.IP{net.ParseIP("2001:db8::1")} + _, netip, _ := net.ParseCIDR("2001:db8::1/128") + setting.allowedNets = []*net.IPNet{netip} + assert.Equal(t, hasAccess("[2001:db8::1]:4321"), true) +} + +func TestHasAccessAllowIPv6Net(t *testing.T) { + _, netip, _ := net.ParseCIDR("2001:db8::/64") + setting.allowedNets = []*net.IPNet{netip} assert.Equal(t, hasAccess("[2001:db8::1]:4321"), true) } func TestHasAccessAllowIPv6DifferentForm(t *testing.T) { - setting.allowedIPs = []net.IP{net.ParseIP("2001:0db8::1")} + _, netip, _ := net.ParseCIDR("2001:db8::1/128") + setting.allowedNets = []*net.IPNet{netip} assert.Equal(t, hasAccess("[2001:db8::1]:4321"), true) } func TestHasAccessDenyIPv6(t *testing.T) { - setting.allowedIPs = []net.IP{net.ParseIP("2001:db8::2")} + _, netip, _ := net.ParseCIDR("2001:db8::2/128") + setting.allowedNets = []*net.IPNet{netip} assert.Equal(t, hasAccess("[2001:db8::1]:4321"), false) } func TestHasAccessBadClientIP(t *testing.T) { - setting.allowedIPs = []net.IP{net.ParseIP("1.2.3.4")} + _, netip, _ := net.ParseCIDR("1.2.3.4/32") + setting.allowedNets = []*net.IPNet{netip} assert.Equal(t, hasAccess("not an IP"), false) } func TestHasAccessBadClientIPPort(t *testing.T) { - setting.allowedIPs = []net.IP{net.ParseIP("1.2.3.4")} + _, netip, _ := net.ParseCIDR("1.2.3.4/32") + setting.allowedNets = []*net.IPNet{netip} assert.Equal(t, hasAccess("not an IP:not a port"), false) } @@ -57,7 +76,8 @@ func TestAccessHandlerAllow(t *testing.T) { r.RemoteAddr = "1.2.3.4:4321" w := httptest.NewRecorder() - setting.allowedIPs = []net.IP{net.ParseIP("1.2.3.4")} + _, netip, _ := net.ParseCIDR("1.2.3.4/32") + setting.allowedNets = []*net.IPNet{netip} wrappedHandler.ServeHTTP(w, r) assert.Equal(t, w.Code, http.StatusNotFound) @@ -71,7 +91,8 @@ func TestAccessHandlerDeny(t *testing.T) { r.RemoteAddr = "1.2.3.4:4321" w := httptest.NewRecorder() - setting.allowedIPs = []net.IP{net.ParseIP("4.3.2.1")} + _, netip, _ := net.ParseCIDR("4.3.2.1/32") + setting.allowedNets = []*net.IPNet{netip} wrappedHandler.ServeHTTP(w, r) assert.Equal(t, w.Code, http.StatusInternalServerError) diff --git a/proxy/settings.go b/proxy/settings.go index 098f4a5..93ac1cb 100644 --- a/proxy/settings.go +++ b/proxy/settings.go @@ -13,7 +13,7 @@ import ( type viperSettingType struct { BirdSocket string `mapstructure:"bird_socket"` Listen string `mapstructure:"listen"` - AllowedIPs string `mapstructure:"allowed_ips"` + AllowedNets string `mapstructure:"allowed_ips"` TracerouteBin string `mapstructure:"traceroute_bin"` TracerouteFlags string `mapstructure:"traceroute_flags"` TracerouteRaw bool `mapstructure:"traceroute_raw"` @@ -40,7 +40,7 @@ func parseSettings() { pflag.String("listen", "8000", "listen address, set either in parameter or environment variable BIRDLG_PROXY_PORT") viper.BindPFlag("listen", pflag.Lookup("listen")) - pflag.String("allowed", "", "IPs allowed to access this proxy, separated by commas. Don't set to allow all IPs.") + pflag.String("allowed", "", "IPs or networks allowed to access this proxy, separated by commas. Don't set to allow all IPs.") viper.BindPFlag("allowed_ips", pflag.Lookup("allowed")) pflag.String("traceroute_bin", "", "traceroute binary file, set either in parameter or environment variable BIRDLG_TRACEROUTE_BIN") @@ -66,18 +66,31 @@ func parseSettings() { setting.birdSocket = viperSettings.BirdSocket setting.listen = viperSettings.Listen - if viperSettings.AllowedIPs != "" { - for _, ip := range strings.Split(viperSettings.AllowedIPs, ",") { - ipObject := net.ParseIP(ip) - if ipObject == nil { - fmt.Printf("Parse IP %s failed\n", ip) - continue + if viperSettings.AllowedNets != "" { + for _, arg := range strings.Split(viperSettings.AllowedNets, ",") { + + // if argument is an IP address, convert to CIDR by adding a suitable mask + if !strings.Contains(arg, "/") { + if strings.Contains(arg, ":") { + // IPv6 address with /128 mask + arg += "/128" + } else { + // IPv4 address with /32 mask + arg += "/32" + } } - setting.allowedIPs = append(setting.allowedIPs, ipObject) + // parse the network + _, netip, err := net.ParseCIDR(arg) + if err != nil { + fmt.Printf("Failed to parse CIDR %s: %s\n", arg, err.Error()) + continue + } + setting.allowedNets = append(setting.allowedNets, netip) + } } else { - setting.allowedIPs = []net.IP{} + setting.allowedNets = []*net.IPNet{} } var err error