Refactor code

This commit is contained in:
Kioubit 2021-12-22 07:01:30 -05:00
parent 7d876fb290
commit 2b18c35dba
8 changed files with 171 additions and 142 deletions

View File

@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"log" "log"
"os" "os"
"pndpd/pndp"
"strings" "strings"
) )
@ -33,7 +34,7 @@ func readConfig(dest string) {
} }
if strings.HasPrefix(line, "debug") { if strings.HasPrefix(line, "debug") {
if strings.Contains(line, "off") { if strings.Contains(line, "off") {
GlobalDebug = false pndp.GlobalDebug = false
} }
continue continue
} }
@ -54,7 +55,7 @@ func readConfig(dest string) {
break break
} }
} }
go simpleRespond(obj.Iface, parseFilter(obj.Filter)) pndp.SimpleRespond(obj.Iface, pndp.ParseFilter(obj.Filter))
} }
if strings.HasPrefix(line, "proxy") { if strings.HasPrefix(line, "proxy") {
obj := configProxy{} obj := configProxy{}
@ -71,7 +72,7 @@ func readConfig(dest string) {
break break
} }
} }
go proxy(obj.Iface1, obj.Iface2) pndp.Proxy(obj.Iface1, obj.Iface2)
} }
} }

View File

@ -3,6 +3,7 @@ package main
import ( import (
"fmt" "fmt"
"os" "os"
"pndpd/pndp"
) )
func main() { func main() {
@ -16,19 +17,19 @@ func main() {
switch os.Args[1] { switch os.Args[1] {
case "respond": case "respond":
if len(os.Args) == 4 { if len(os.Args) == 4 {
go simpleRespond(os.Args[2], parseFilter(os.Args[3])) pndp.SimpleRespond(os.Args[2], pndp.ParseFilter(os.Args[3]))
} else { } else {
go simpleRespond(os.Args[2], nil) pndp.SimpleRespond(os.Args[2], nil)
} }
case "proxy": case "proxy":
go proxy(os.Args[2], os.Args[3]) pndp.Proxy(os.Args[2], os.Args[3])
case "readconfig": case "readconfig":
readConfig(os.Args[2]) readConfig(os.Args[2])
default: default:
printUsage() printUsage()
return return
} }
waitForSignal() pndp.WaitForSignal()
} }
func printUsage() { func printUsage() {

View File

@ -1,14 +1,14 @@
package main package pndp
type NDPType int type ndpType int
const ( const (
NDP_ADV NDPType = 0 ndp_ADV ndpType = 0
NDP_SOL NDPType = 1 ndp_SOL ndpType = 1
) )
type NDRequest struct { type ndpRequest struct {
requestType NDPType requestType ndpType
srcIP []byte srcIP []byte
answeringForIP []byte answeringForIP []byte
dstIP []byte dstIP []byte

View File

@ -1,4 +1,4 @@
package main package pndp
import ( import (
"encoding/binary" "encoding/binary"
@ -9,11 +9,11 @@ import (
var emptyIpv6 = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} var emptyIpv6 = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
type Payload interface { type payload interface {
constructPacket() ([]byte, int) constructPacket() ([]byte, int)
} }
type IPv6Header struct { type ipv6Header struct {
protocol byte protocol byte
srcIP []byte srcIP []byte
dstIP []byte dstIP []byte
@ -21,14 +21,14 @@ type IPv6Header struct {
payload []byte payload []byte
} }
func newIpv6Header(srcIp []byte, dstIp []byte) (*IPv6Header, error) { func newIpv6Header(srcIp []byte, dstIp []byte) (*ipv6Header, error) {
if len(dstIp) != 16 || len(srcIp) != 16 { if len(dstIp) != 16 || len(srcIp) != 16 {
return nil, errors.New("malformed IP") return nil, errors.New("malformed IP")
} }
return &IPv6Header{dstIP: dstIp, srcIP: srcIp, protocol: 0x3a}, nil return &ipv6Header{dstIP: dstIp, srcIP: srcIp, protocol: 0x3a}, nil
} }
func (h *IPv6Header) addPayload(payload Payload) { func (h *ipv6Header) addPayload(payload payload) {
bPayload, checksumPos := payload.constructPacket() bPayload, checksumPos := payload.constructPacket()
bPayloadLen := make([]byte, 2) bPayloadLen := make([]byte, 2)
binary.BigEndian.PutUint16(bPayloadLen, uint16(len(bPayload))) binary.BigEndian.PutUint16(bPayloadLen, uint16(len(bPayload)))
@ -44,14 +44,14 @@ func (h *IPv6Header) addPayload(payload Payload) {
h.payload = bPayload h.payload = bPayload
} }
func (h *IPv6Header) constructPacket() []byte { func (h *ipv6Header) constructPacket() []byte {
header := []byte{ header := []byte{
0x60, // v6 0x60, // v6
0, // qos 0, // qos
0, // qos 0, // qos
0, // qos 0, // qos
h.payloadLen[0], // Payload Length h.payloadLen[0], // payload Length
h.payloadLen[1], // Payload Length h.payloadLen[1], // payload Length
h.protocol, // Protocol next header h.protocol, // Protocol next header
0xff, // Hop limit 0xff, // Hop limit
} }
@ -61,28 +61,28 @@ func (h *IPv6Header) constructPacket() []byte {
return final return final
} }
type NdpPayload struct { type ndpPayload struct {
packetType NDPType packetType ndpType
answeringForIP []byte answeringForIP []byte
mac []byte mac []byte
} }
func newNdpPacket(answeringForIP []byte, mac []byte, packetType NDPType) (*NdpPayload, error) { func newNdpPacket(answeringForIP []byte, mac []byte, packetType ndpType) (*ndpPayload, error) {
if len(answeringForIP) != 16 || len(mac) != 6 { if len(answeringForIP) != 16 || len(mac) != 6 {
return nil, errors.New("malformed IP") return nil, errors.New("malformed IP")
} }
return &NdpPayload{ return &ndpPayload{
packetType: packetType, packetType: packetType,
answeringForIP: answeringForIP, answeringForIP: answeringForIP,
mac: mac, mac: mac,
}, nil }, nil
} }
func (p *NdpPayload) constructPacket() ([]byte, int) { func (p *ndpPayload) constructPacket() ([]byte, int) {
var protocol byte var protocol byte
var flags byte var flags byte
var linkType byte var linkType byte
if p.packetType == NDP_SOL { if p.packetType == ndp_SOL {
protocol = 0x87 protocol = 0x87
flags = 0x0 flags = 0x0
linkType = 0x01 linkType = 0x01
@ -92,7 +92,7 @@ func (p *NdpPayload) constructPacket() ([]byte, int) {
linkType = 0x02 linkType = 0x02
} }
header := []byte{ header := []byte{
protocol, // Type: NDPType protocol, // Type: ndpType
0x0, // Code 0x0, // Code
0x0, // Checksum filled in later 0x0, // Checksum filled in later
0x0, // Checksum filled in later 0x0, // Checksum filled in later
@ -113,7 +113,7 @@ func (p *NdpPayload) constructPacket() ([]byte, int) {
return final, 2 return final, 2
} }
func calculateChecksum(h *IPv6Header, payload []byte) uint16 { func calculateChecksum(h *ipv6Header, payload []byte) uint16 {
sumPseudoHeader := checksumAddition(h.srcIP) + checksumAddition(h.dstIP) + checksumAddition([]byte{0x00, h.protocol}) + checksumAddition(h.payloadLen) sumPseudoHeader := checksumAddition(h.srcIP) + checksumAddition(h.dstIP) + checksumAddition([]byte{0x00, h.protocol}) + checksumAddition(h.payloadLen)
sumPayload := checksumAddition(payload) sumPayload := checksumAddition(payload)
sumTotal := sumPayload + sumPseudoHeader sumTotal := sumPayload + sumPseudoHeader
@ -133,7 +133,7 @@ func checksumAddition(b []byte) uint32 {
return sum return sum
} }
func IsIPv6(ip string) bool { func isIpv6(ip string) bool {
rip := net.ParseIP(ip) rip := net.ParseIP(ip)
return rip != nil && strings.Contains(ip, ":") return rip != nil && strings.Contains(ip, ":")
} }

125
pndp/process.go Normal file
View File

@ -0,0 +1,125 @@
package pndp
import (
"fmt"
"net"
"os"
"os/signal"
"runtime/pprof"
"strings"
"sync"
"syscall"
"time"
)
var GlobalDebug = false
// Items needed for graceful shutdown
var stop = make(chan struct{})
var stopWg sync.WaitGroup
var sigCh = make(chan os.Signal)
// WaitForSignal Waits (blocking) for the program to be interrupted by the OS and then gracefully shuts down releasing all resources
func WaitForSignal() {
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
<-sigCh
Shutdown()
}
// Shutdown Exits the program gracefully and releases all resources
//
//Do not use with WaitForSignal
func Shutdown() {
fmt.Println("Shutting down...")
close(stop)
if wgWaitTimout(&stopWg, 10*time.Second) {
fmt.Println("Done")
} else {
fmt.Println("Aborting shutdown, since it is taking too long")
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
}
os.Exit(0)
}
func wgWaitTimout(wg *sync.WaitGroup, timeout time.Duration) bool {
t := make(chan struct{})
go func() {
defer close(t)
wg.Wait()
}()
select {
case <-t:
return true
case <-time.After(timeout):
return false
}
}
// SimpleRespond
//
// iface - The interface to listen to and respond from
//
// filter - Optional (can be nil) list of CIDRs to whitelist
//
// Non blocking
func SimpleRespond(iface string, filter []*net.IPNet) {
go simpleRespond(iface, filter)
}
func simpleRespond(iface string, filter []*net.IPNet) {
defer stopWg.Done()
stopWg.Add(3) // This function, 2x goroutines
requests := make(chan *ndpRequest, 100)
defer close(requests)
go respond(iface, requests, ndp_ADV, filter)
go listen(iface, requests, ndp_SOL)
<-stop
}
// Proxy NDP between interfaces iface1 and iface2
//
// Non blocking
func Proxy(iface1, iface2 string) {
go proxy(iface1, iface2)
}
func proxy(iface1, iface2 string) {
defer stopWg.Done()
stopWg.Add(9) // This function, 8x goroutines
req_iface1_sol_iface2 := make(chan *ndpRequest, 100)
defer close(req_iface1_sol_iface2)
go listen(iface1, req_iface1_sol_iface2, ndp_SOL)
go respond(iface2, req_iface1_sol_iface2, ndp_SOL, nil)
req_iface2_sol_iface1 := make(chan *ndpRequest, 100)
defer close(req_iface2_sol_iface1)
go listen(iface2, req_iface2_sol_iface1, ndp_SOL)
go respond(iface1, req_iface2_sol_iface1, ndp_SOL, nil)
req_iface1_adv_iface2 := make(chan *ndpRequest, 100)
defer close(req_iface1_adv_iface2)
go listen(iface1, req_iface1_adv_iface2, ndp_ADV)
go respond(iface2, req_iface1_adv_iface2, ndp_ADV, nil)
req_iface2_adv_iface1 := make(chan *ndpRequest, 100)
defer close(req_iface2_adv_iface1)
go listen(iface2, req_iface2_adv_iface1, ndp_ADV)
go respond(iface1, req_iface2_adv_iface1, ndp_ADV, nil)
<-stop
}
// ParseFilter Helper Function to Parse a string of CIDRs separated by a semicolon as a Whitelist for SimpleRespond
func ParseFilter(f string) []*net.IPNet {
s := strings.Split(f, ";")
result := make([]*net.IPNet, len(s))
for i, n := range s {
_, cidr, err := net.ParseCIDR(n)
if err != nil {
panic(err)
}
result[i] = cidr
}
return result
}

View File

@ -1,4 +1,4 @@
package main package pndp
import ( import (
"fmt" "fmt"
@ -9,11 +9,11 @@ import (
"unsafe" "unsafe"
) )
// Filter represents a classic BPF filter program that can be applied to a socket // bpfFilter represents a classic BPF filter program that can be applied to a socket
type Filter []bpf.Instruction type bpfFilter []bpf.Instruction
// ApplyTo applies the current filter onto the provided file descriptor // ApplyTo applies the current filter onto the provided file descriptor
func (filter Filter) ApplyTo(fd int) (err error) { func (filter bpfFilter) ApplyTo(fd int) (err error) {
var assembled []bpf.RawInstruction var assembled []bpf.RawInstruction
if assembled, err = bpf.Assemble(filter); err != nil { if assembled, err = bpf.Assemble(filter); err != nil {
return err return err
@ -40,7 +40,7 @@ func htons(v uint16) int {
} }
func htons16(v uint16) uint16 { return v<<8 | v>>8 } func htons16(v uint16) uint16 { return v<<8 | v>>8 }
func listen(iface string, responder chan *NDRequest, requestType NDPType) { func listen(iface string, responder chan *ndpRequest, requestType ndpType) {
niface, err := net.InterfaceByName(iface) niface, err := net.InterfaceByName(iface)
if err != nil { if err != nil {
panic(err.Error()) panic(err.Error())
@ -71,7 +71,7 @@ func listen(iface string, responder chan *NDRequest, requestType NDPType) {
} }
var protocolNo uint32 var protocolNo uint32
if requestType == NDP_SOL { if requestType == ndp_SOL {
//Neighbor Solicitation //Neighbor Solicitation
protocolNo = 0x87 protocolNo = 0x87
} else { } else {
@ -79,7 +79,7 @@ func listen(iface string, responder chan *NDRequest, requestType NDPType) {
protocolNo = 0x88 protocolNo = 0x88
} }
var f Filter = []bpf.Instruction{ var f bpfFilter = []bpf.Instruction{
// Load "EtherType" field from the ethernet header. // Load "EtherType" field from the ethernet header.
bpf.LoadAbsolute{Off: 12, Size: 2}, bpf.LoadAbsolute{Off: 12, Size: 2},
// Jump to the drop packet instruction if EtherType is not IPv6. // Jump to the drop packet instruction if EtherType is not IPv6.
@ -120,7 +120,7 @@ func listen(iface string, responder chan *NDRequest, requestType NDPType) {
fmt.Printf("% X\n", buf[:numRead][80:86]) fmt.Printf("% X\n", buf[:numRead][80:86])
fmt.Println() fmt.Println()
} }
responder <- &NDRequest{ responder <- &ndpRequest{
requestType: requestType, requestType: requestType,
srcIP: buf[:numRead][22:38], srcIP: buf[:numRead][22:38],
dstIP: buf[:numRead][38:54], dstIP: buf[:numRead][38:54],

View File

@ -1,4 +1,4 @@
package main package pndp
import ( import (
"bytes" "bytes"
@ -9,7 +9,7 @@ import (
var globalFd int var globalFd int
func respond(iface string, requests chan *NDRequest, respondType NDPType, filter []*net.IPNet) { func respond(iface string, requests chan *ndpRequest, respondType ndpType, filter []*net.IPNet) {
defer stopWg.Done() defer stopWg.Done()
fd, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_RAW, syscall.IPPROTO_RAW) fd, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil { if err != nil {
@ -35,7 +35,7 @@ func respond(iface string, requests chan *NDRequest, respondType NDPType, filter
if err != nil { if err != nil {
break break
} }
if IsIPv6(tip.String()) { if isIpv6(tip.String()) {
if tip.IsGlobalUnicast() { if tip.IsGlobalUnicast() {
result = tip result = tip
_, tnet, _ := net.ParseCIDR("fc00::/7") _, tnet, _ := net.ParseCIDR("fc00::/7")
@ -47,7 +47,7 @@ func respond(iface string, requests chan *NDRequest, respondType NDPType, filter
} }
for { for {
var n *NDRequest var n *ndpRequest
select { select {
case <-stop: case <-stop:
return return
@ -78,7 +78,7 @@ func respond(iface string, requests chan *NDRequest, respondType NDPType, filter
} }
} }
func pkt(ownIP []byte, dstIP []byte, tgtip []byte, mac []byte, respondType NDPType) { func pkt(ownIP []byte, dstIP []byte, tgtip []byte, mac []byte, respondType ndpType) {
v6, err := newIpv6Header(ownIP, dstIP) v6, err := newIpv6Header(ownIP, dstIP)
if err != nil { if err != nil {
return return

View File

@ -1,98 +0,0 @@
package main
import (
"fmt"
"net"
"os"
"os/signal"
"runtime/pprof"
"strings"
"sync"
"syscall"
"time"
)
var GlobalDebug = false
// Items needed for graceful shutdown
var stop = make(chan struct{})
var stopWg sync.WaitGroup
var sigCh = make(chan os.Signal)
func waitForSignal() {
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
<-sigCh
fmt.Println("Shutting down...")
close(stop)
if wgWaitTimout(&stopWg, 10*time.Second) {
fmt.Println("Done")
} else {
fmt.Println("Aborting shutdown, since it is taking too long")
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
}
os.Exit(0)
}
func wgWaitTimout(wg *sync.WaitGroup, timeout time.Duration) bool {
t := make(chan struct{})
go func() {
defer close(t)
wg.Wait()
}()
select {
case <-t:
return true
case <-time.After(timeout):
return false
}
}
func simpleRespond(iface string, filter []*net.IPNet) {
defer stopWg.Done()
stopWg.Add(3) // This function, 2x goroutines
requests := make(chan *NDRequest, 100)
defer close(requests)
go respond(iface, requests, NDP_ADV, filter)
go listen(iface, requests, NDP_SOL)
<-stop
}
func proxy(iface1, iface2 string) {
defer stopWg.Done()
stopWg.Add(9) // This function, 8x goroutines
req_iface1_sol_iface2 := make(chan *NDRequest, 100)
defer close(req_iface1_sol_iface2)
go listen(iface1, req_iface1_sol_iface2, NDP_SOL)
go respond(iface2, req_iface1_sol_iface2, NDP_SOL, nil)
req_iface2_sol_iface1 := make(chan *NDRequest, 100)
defer close(req_iface2_sol_iface1)
go listen(iface2, req_iface2_sol_iface1, NDP_SOL)
go respond(iface1, req_iface2_sol_iface1, NDP_SOL, nil)
req_iface1_adv_iface2 := make(chan *NDRequest, 100)
defer close(req_iface1_adv_iface2)
go listen(iface1, req_iface1_adv_iface2, NDP_ADV)
go respond(iface2, req_iface1_adv_iface2, NDP_ADV, nil)
req_iface2_adv_iface1 := make(chan *NDRequest, 100)
defer close(req_iface2_adv_iface1)
go listen(iface2, req_iface2_adv_iface1, NDP_ADV)
go respond(iface1, req_iface2_adv_iface1, NDP_ADV, nil)
<-stop
}
func parseFilter(f string) []*net.IPNet {
s := strings.Split(f, ";")
result := make([]*net.IPNet, len(s))
for i, n := range s {
_, cidr, err := net.ParseCIDR(n)
if err != nil {
panic(err)
}
result[i] = cidr
}
return result
}