Refactor code
This commit is contained in:
parent
7d876fb290
commit
2b18c35dba
@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"log"
|
||||
"os"
|
||||
"pndpd/pndp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@ -33,7 +34,7 @@ func readConfig(dest string) {
|
||||
}
|
||||
if strings.HasPrefix(line, "debug") {
|
||||
if strings.Contains(line, "off") {
|
||||
GlobalDebug = false
|
||||
pndp.GlobalDebug = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
@ -54,7 +55,7 @@ func readConfig(dest string) {
|
||||
break
|
||||
}
|
||||
}
|
||||
go simpleRespond(obj.Iface, parseFilter(obj.Filter))
|
||||
pndp.SimpleRespond(obj.Iface, pndp.ParseFilter(obj.Filter))
|
||||
}
|
||||
if strings.HasPrefix(line, "proxy") {
|
||||
obj := configProxy{}
|
||||
@ -71,7 +72,7 @@ func readConfig(dest string) {
|
||||
break
|
||||
}
|
||||
}
|
||||
go proxy(obj.Iface1, obj.Iface2)
|
||||
pndp.Proxy(obj.Iface1, obj.Iface2)
|
||||
}
|
||||
}
|
||||
|
||||
|
9
main.go
9
main.go
@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"pndpd/pndp"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@ -16,19 +17,19 @@ func main() {
|
||||
switch os.Args[1] {
|
||||
case "respond":
|
||||
if len(os.Args) == 4 {
|
||||
go simpleRespond(os.Args[2], parseFilter(os.Args[3]))
|
||||
pndp.SimpleRespond(os.Args[2], pndp.ParseFilter(os.Args[3]))
|
||||
} else {
|
||||
go simpleRespond(os.Args[2], nil)
|
||||
pndp.SimpleRespond(os.Args[2], nil)
|
||||
}
|
||||
case "proxy":
|
||||
go proxy(os.Args[2], os.Args[3])
|
||||
pndp.Proxy(os.Args[2], os.Args[3])
|
||||
case "readconfig":
|
||||
readConfig(os.Args[2])
|
||||
default:
|
||||
printUsage()
|
||||
return
|
||||
}
|
||||
waitForSignal()
|
||||
pndp.WaitForSignal()
|
||||
}
|
||||
|
||||
func printUsage() {
|
||||
|
@ -1,14 +1,14 @@
|
||||
package main
|
||||
package pndp
|
||||
|
||||
type NDPType int
|
||||
type ndpType int
|
||||
|
||||
const (
|
||||
NDP_ADV NDPType = 0
|
||||
NDP_SOL NDPType = 1
|
||||
ndp_ADV ndpType = 0
|
||||
ndp_SOL ndpType = 1
|
||||
)
|
||||
|
||||
type NDRequest struct {
|
||||
requestType NDPType
|
||||
type ndpRequest struct {
|
||||
requestType ndpType
|
||||
srcIP []byte
|
||||
answeringForIP []byte
|
||||
dstIP []byte
|
@ -1,4 +1,4 @@
|
||||
package main
|
||||
package pndp
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
@ -9,11 +9,11 @@ import (
|
||||
|
||||
var emptyIpv6 = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
|
||||
type Payload interface {
|
||||
type payload interface {
|
||||
constructPacket() ([]byte, int)
|
||||
}
|
||||
|
||||
type IPv6Header struct {
|
||||
type ipv6Header struct {
|
||||
protocol byte
|
||||
srcIP []byte
|
||||
dstIP []byte
|
||||
@ -21,14 +21,14 @@ type IPv6Header struct {
|
||||
payload []byte
|
||||
}
|
||||
|
||||
func newIpv6Header(srcIp []byte, dstIp []byte) (*IPv6Header, error) {
|
||||
func newIpv6Header(srcIp []byte, dstIp []byte) (*ipv6Header, error) {
|
||||
if len(dstIp) != 16 || len(srcIp) != 16 {
|
||||
return nil, errors.New("malformed IP")
|
||||
}
|
||||
return &IPv6Header{dstIP: dstIp, srcIP: srcIp, protocol: 0x3a}, nil
|
||||
return &ipv6Header{dstIP: dstIp, srcIP: srcIp, protocol: 0x3a}, nil
|
||||
}
|
||||
|
||||
func (h *IPv6Header) addPayload(payload Payload) {
|
||||
func (h *ipv6Header) addPayload(payload payload) {
|
||||
bPayload, checksumPos := payload.constructPacket()
|
||||
bPayloadLen := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(bPayloadLen, uint16(len(bPayload)))
|
||||
@ -44,14 +44,14 @@ func (h *IPv6Header) addPayload(payload Payload) {
|
||||
h.payload = bPayload
|
||||
}
|
||||
|
||||
func (h *IPv6Header) constructPacket() []byte {
|
||||
func (h *ipv6Header) constructPacket() []byte {
|
||||
header := []byte{
|
||||
0x60, // v6
|
||||
0, // qos
|
||||
0, // qos
|
||||
0, // qos
|
||||
h.payloadLen[0], // Payload Length
|
||||
h.payloadLen[1], // Payload Length
|
||||
h.payloadLen[0], // payload Length
|
||||
h.payloadLen[1], // payload Length
|
||||
h.protocol, // Protocol next header
|
||||
0xff, // Hop limit
|
||||
}
|
||||
@ -61,28 +61,28 @@ func (h *IPv6Header) constructPacket() []byte {
|
||||
return final
|
||||
}
|
||||
|
||||
type NdpPayload struct {
|
||||
packetType NDPType
|
||||
type ndpPayload struct {
|
||||
packetType ndpType
|
||||
answeringForIP []byte
|
||||
mac []byte
|
||||
}
|
||||
|
||||
func newNdpPacket(answeringForIP []byte, mac []byte, packetType NDPType) (*NdpPayload, error) {
|
||||
func newNdpPacket(answeringForIP []byte, mac []byte, packetType ndpType) (*ndpPayload, error) {
|
||||
if len(answeringForIP) != 16 || len(mac) != 6 {
|
||||
return nil, errors.New("malformed IP")
|
||||
}
|
||||
return &NdpPayload{
|
||||
return &ndpPayload{
|
||||
packetType: packetType,
|
||||
answeringForIP: answeringForIP,
|
||||
mac: mac,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *NdpPayload) constructPacket() ([]byte, int) {
|
||||
func (p *ndpPayload) constructPacket() ([]byte, int) {
|
||||
var protocol byte
|
||||
var flags byte
|
||||
var linkType byte
|
||||
if p.packetType == NDP_SOL {
|
||||
if p.packetType == ndp_SOL {
|
||||
protocol = 0x87
|
||||
flags = 0x0
|
||||
linkType = 0x01
|
||||
@ -92,7 +92,7 @@ func (p *NdpPayload) constructPacket() ([]byte, int) {
|
||||
linkType = 0x02
|
||||
}
|
||||
header := []byte{
|
||||
protocol, // Type: NDPType
|
||||
protocol, // Type: ndpType
|
||||
0x0, // Code
|
||||
0x0, // Checksum filled in later
|
||||
0x0, // Checksum filled in later
|
||||
@ -113,7 +113,7 @@ func (p *NdpPayload) constructPacket() ([]byte, int) {
|
||||
return final, 2
|
||||
}
|
||||
|
||||
func calculateChecksum(h *IPv6Header, payload []byte) uint16 {
|
||||
func calculateChecksum(h *ipv6Header, payload []byte) uint16 {
|
||||
sumPseudoHeader := checksumAddition(h.srcIP) + checksumAddition(h.dstIP) + checksumAddition([]byte{0x00, h.protocol}) + checksumAddition(h.payloadLen)
|
||||
sumPayload := checksumAddition(payload)
|
||||
sumTotal := sumPayload + sumPseudoHeader
|
||||
@ -133,7 +133,7 @@ func checksumAddition(b []byte) uint32 {
|
||||
return sum
|
||||
}
|
||||
|
||||
func IsIPv6(ip string) bool {
|
||||
func isIpv6(ip string) bool {
|
||||
rip := net.ParseIP(ip)
|
||||
return rip != nil && strings.Contains(ip, ":")
|
||||
}
|
125
pndp/process.go
Normal file
125
pndp/process.go
Normal file
@ -0,0 +1,125 @@
|
||||
package pndp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
var GlobalDebug = false
|
||||
|
||||
// Items needed for graceful shutdown
|
||||
var stop = make(chan struct{})
|
||||
var stopWg sync.WaitGroup
|
||||
var sigCh = make(chan os.Signal)
|
||||
|
||||
// WaitForSignal Waits (blocking) for the program to be interrupted by the OS and then gracefully shuts down releasing all resources
|
||||
func WaitForSignal() {
|
||||
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
Shutdown()
|
||||
}
|
||||
|
||||
// Shutdown Exits the program gracefully and releases all resources
|
||||
//
|
||||
//Do not use with WaitForSignal
|
||||
func Shutdown() {
|
||||
fmt.Println("Shutting down...")
|
||||
close(stop)
|
||||
if wgWaitTimout(&stopWg, 10*time.Second) {
|
||||
fmt.Println("Done")
|
||||
} else {
|
||||
fmt.Println("Aborting shutdown, since it is taking too long")
|
||||
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func wgWaitTimout(wg *sync.WaitGroup, timeout time.Duration) bool {
|
||||
t := make(chan struct{})
|
||||
go func() {
|
||||
defer close(t)
|
||||
wg.Wait()
|
||||
}()
|
||||
select {
|
||||
case <-t:
|
||||
return true
|
||||
case <-time.After(timeout):
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// SimpleRespond
|
||||
//
|
||||
// iface - The interface to listen to and respond from
|
||||
//
|
||||
// filter - Optional (can be nil) list of CIDRs to whitelist
|
||||
//
|
||||
// Non blocking
|
||||
func SimpleRespond(iface string, filter []*net.IPNet) {
|
||||
go simpleRespond(iface, filter)
|
||||
}
|
||||
|
||||
func simpleRespond(iface string, filter []*net.IPNet) {
|
||||
defer stopWg.Done()
|
||||
stopWg.Add(3) // This function, 2x goroutines
|
||||
requests := make(chan *ndpRequest, 100)
|
||||
defer close(requests)
|
||||
go respond(iface, requests, ndp_ADV, filter)
|
||||
go listen(iface, requests, ndp_SOL)
|
||||
<-stop
|
||||
}
|
||||
|
||||
// Proxy NDP between interfaces iface1 and iface2
|
||||
//
|
||||
// Non blocking
|
||||
func Proxy(iface1, iface2 string) {
|
||||
go proxy(iface1, iface2)
|
||||
}
|
||||
|
||||
func proxy(iface1, iface2 string) {
|
||||
defer stopWg.Done()
|
||||
stopWg.Add(9) // This function, 8x goroutines
|
||||
|
||||
req_iface1_sol_iface2 := make(chan *ndpRequest, 100)
|
||||
defer close(req_iface1_sol_iface2)
|
||||
go listen(iface1, req_iface1_sol_iface2, ndp_SOL)
|
||||
go respond(iface2, req_iface1_sol_iface2, ndp_SOL, nil)
|
||||
|
||||
req_iface2_sol_iface1 := make(chan *ndpRequest, 100)
|
||||
defer close(req_iface2_sol_iface1)
|
||||
go listen(iface2, req_iface2_sol_iface1, ndp_SOL)
|
||||
go respond(iface1, req_iface2_sol_iface1, ndp_SOL, nil)
|
||||
|
||||
req_iface1_adv_iface2 := make(chan *ndpRequest, 100)
|
||||
defer close(req_iface1_adv_iface2)
|
||||
go listen(iface1, req_iface1_adv_iface2, ndp_ADV)
|
||||
go respond(iface2, req_iface1_adv_iface2, ndp_ADV, nil)
|
||||
|
||||
req_iface2_adv_iface1 := make(chan *ndpRequest, 100)
|
||||
defer close(req_iface2_adv_iface1)
|
||||
go listen(iface2, req_iface2_adv_iface1, ndp_ADV)
|
||||
go respond(iface1, req_iface2_adv_iface1, ndp_ADV, nil)
|
||||
<-stop
|
||||
}
|
||||
|
||||
// ParseFilter Helper Function to Parse a string of CIDRs separated by a semicolon as a Whitelist for SimpleRespond
|
||||
func ParseFilter(f string) []*net.IPNet {
|
||||
s := strings.Split(f, ";")
|
||||
result := make([]*net.IPNet, len(s))
|
||||
for i, n := range s {
|
||||
_, cidr, err := net.ParseCIDR(n)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
result[i] = cidr
|
||||
}
|
||||
return result
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package main
|
||||
package pndp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@ -9,11 +9,11 @@ import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Filter represents a classic BPF filter program that can be applied to a socket
|
||||
type Filter []bpf.Instruction
|
||||
// bpfFilter represents a classic BPF filter program that can be applied to a socket
|
||||
type bpfFilter []bpf.Instruction
|
||||
|
||||
// ApplyTo applies the current filter onto the provided file descriptor
|
||||
func (filter Filter) ApplyTo(fd int) (err error) {
|
||||
func (filter bpfFilter) ApplyTo(fd int) (err error) {
|
||||
var assembled []bpf.RawInstruction
|
||||
if assembled, err = bpf.Assemble(filter); err != nil {
|
||||
return err
|
||||
@ -40,7 +40,7 @@ func htons(v uint16) int {
|
||||
}
|
||||
func htons16(v uint16) uint16 { return v<<8 | v>>8 }
|
||||
|
||||
func listen(iface string, responder chan *NDRequest, requestType NDPType) {
|
||||
func listen(iface string, responder chan *ndpRequest, requestType ndpType) {
|
||||
niface, err := net.InterfaceByName(iface)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
@ -71,7 +71,7 @@ func listen(iface string, responder chan *NDRequest, requestType NDPType) {
|
||||
}
|
||||
|
||||
var protocolNo uint32
|
||||
if requestType == NDP_SOL {
|
||||
if requestType == ndp_SOL {
|
||||
//Neighbor Solicitation
|
||||
protocolNo = 0x87
|
||||
} else {
|
||||
@ -79,7 +79,7 @@ func listen(iface string, responder chan *NDRequest, requestType NDPType) {
|
||||
protocolNo = 0x88
|
||||
}
|
||||
|
||||
var f Filter = []bpf.Instruction{
|
||||
var f bpfFilter = []bpf.Instruction{
|
||||
// Load "EtherType" field from the ethernet header.
|
||||
bpf.LoadAbsolute{Off: 12, Size: 2},
|
||||
// Jump to the drop packet instruction if EtherType is not IPv6.
|
||||
@ -120,7 +120,7 @@ func listen(iface string, responder chan *NDRequest, requestType NDPType) {
|
||||
fmt.Printf("% X\n", buf[:numRead][80:86])
|
||||
fmt.Println()
|
||||
}
|
||||
responder <- &NDRequest{
|
||||
responder <- &ndpRequest{
|
||||
requestType: requestType,
|
||||
srcIP: buf[:numRead][22:38],
|
||||
dstIP: buf[:numRead][38:54],
|
@ -1,4 +1,4 @@
|
||||
package main
|
||||
package pndp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@ -9,7 +9,7 @@ import (
|
||||
|
||||
var globalFd int
|
||||
|
||||
func respond(iface string, requests chan *NDRequest, respondType NDPType, filter []*net.IPNet) {
|
||||
func respond(iface string, requests chan *ndpRequest, respondType ndpType, filter []*net.IPNet) {
|
||||
defer stopWg.Done()
|
||||
fd, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
@ -35,7 +35,7 @@ func respond(iface string, requests chan *NDRequest, respondType NDPType, filter
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if IsIPv6(tip.String()) {
|
||||
if isIpv6(tip.String()) {
|
||||
if tip.IsGlobalUnicast() {
|
||||
result = tip
|
||||
_, tnet, _ := net.ParseCIDR("fc00::/7")
|
||||
@ -47,7 +47,7 @@ func respond(iface string, requests chan *NDRequest, respondType NDPType, filter
|
||||
}
|
||||
|
||||
for {
|
||||
var n *NDRequest
|
||||
var n *ndpRequest
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
@ -78,7 +78,7 @@ func respond(iface string, requests chan *NDRequest, respondType NDPType, filter
|
||||
}
|
||||
}
|
||||
|
||||
func pkt(ownIP []byte, dstIP []byte, tgtip []byte, mac []byte, respondType NDPType) {
|
||||
func pkt(ownIP []byte, dstIP []byte, tgtip []byte, mac []byte, respondType ndpType) {
|
||||
v6, err := newIpv6Header(ownIP, dstIP)
|
||||
if err != nil {
|
||||
return
|
98
process.go
98
process.go
@ -1,98 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
var GlobalDebug = false
|
||||
|
||||
// Items needed for graceful shutdown
|
||||
var stop = make(chan struct{})
|
||||
var stopWg sync.WaitGroup
|
||||
var sigCh = make(chan os.Signal)
|
||||
|
||||
func waitForSignal() {
|
||||
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
fmt.Println("Shutting down...")
|
||||
close(stop)
|
||||
if wgWaitTimout(&stopWg, 10*time.Second) {
|
||||
fmt.Println("Done")
|
||||
} else {
|
||||
fmt.Println("Aborting shutdown, since it is taking too long")
|
||||
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func wgWaitTimout(wg *sync.WaitGroup, timeout time.Duration) bool {
|
||||
t := make(chan struct{})
|
||||
go func() {
|
||||
defer close(t)
|
||||
wg.Wait()
|
||||
}()
|
||||
select {
|
||||
case <-t:
|
||||
return true
|
||||
case <-time.After(timeout):
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func simpleRespond(iface string, filter []*net.IPNet) {
|
||||
defer stopWg.Done()
|
||||
stopWg.Add(3) // This function, 2x goroutines
|
||||
requests := make(chan *NDRequest, 100)
|
||||
defer close(requests)
|
||||
go respond(iface, requests, NDP_ADV, filter)
|
||||
go listen(iface, requests, NDP_SOL)
|
||||
<-stop
|
||||
}
|
||||
|
||||
func proxy(iface1, iface2 string) {
|
||||
defer stopWg.Done()
|
||||
stopWg.Add(9) // This function, 8x goroutines
|
||||
|
||||
req_iface1_sol_iface2 := make(chan *NDRequest, 100)
|
||||
defer close(req_iface1_sol_iface2)
|
||||
go listen(iface1, req_iface1_sol_iface2, NDP_SOL)
|
||||
go respond(iface2, req_iface1_sol_iface2, NDP_SOL, nil)
|
||||
|
||||
req_iface2_sol_iface1 := make(chan *NDRequest, 100)
|
||||
defer close(req_iface2_sol_iface1)
|
||||
go listen(iface2, req_iface2_sol_iface1, NDP_SOL)
|
||||
go respond(iface1, req_iface2_sol_iface1, NDP_SOL, nil)
|
||||
|
||||
req_iface1_adv_iface2 := make(chan *NDRequest, 100)
|
||||
defer close(req_iface1_adv_iface2)
|
||||
go listen(iface1, req_iface1_adv_iface2, NDP_ADV)
|
||||
go respond(iface2, req_iface1_adv_iface2, NDP_ADV, nil)
|
||||
|
||||
req_iface2_adv_iface1 := make(chan *NDRequest, 100)
|
||||
defer close(req_iface2_adv_iface1)
|
||||
go listen(iface2, req_iface2_adv_iface1, NDP_ADV)
|
||||
go respond(iface1, req_iface2_adv_iface1, NDP_ADV, nil)
|
||||
<-stop
|
||||
}
|
||||
|
||||
func parseFilter(f string) []*net.IPNet {
|
||||
s := strings.Split(f, ";")
|
||||
result := make([]*net.IPNet, len(s))
|
||||
for i, n := range s {
|
||||
_, cidr, err := net.ParseCIDR(n)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
result[i] = cidr
|
||||
}
|
||||
return result
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user