diff --git a/client.go b/client.go index 5662362..8de12fd 100644 --- a/client.go +++ b/client.go @@ -13,15 +13,14 @@ import ( // Client is the client type Client struct { - c *server.BgpServer - ips *[]IPNet - ipv6Plen int - community string - wg *sync.WaitGroup + c *server.BgpServer + ips *[]IPNet + ipv6Plen int + wg *sync.WaitGroup } // NewClient instantiates a new client connection -func NewClient(c string, ips *[]IPNet) (*Client, error) { +func NewClient(ips *[]IPNet) (*Client, error) { maxSize := 256 << 20 grpcOpts := []grpc.ServerOption{grpc.MaxRecvMsgSize(maxSize), grpc.MaxSendMsgSize(maxSize)} @@ -51,11 +50,10 @@ func NewClient(c string, ips *[]IPNet) (*Client, error) { } return &Client{ - c: cl, - ips: ips, - ipv6Plen: 64, - community: c, - wg: wg, + c: cl, + ips: ips, + ipv6Plen: 64, + wg: wg, }, nil } @@ -141,10 +139,10 @@ func (c *Client) AddStaticRoute(nh string, p IPNet, cm string) error { // AddRoutes adds a static route for all IPs monitored func (c *Client) AddRoutes() error { for _, ip := range *c.ips { - if err := c.AddStaticRoute("", ip, c.community); err != nil { + if err := c.AddStaticRoute("", ip, ip.community); err != nil { return err } - log.WithFields(log.Fields{"Topic": "Route", "Route": ip, "Community": c.community}).Info("added route") + log.WithFields(log.Fields{"Topic": "Route", "Route": ip, "Community": ip.community}).Info("added route") } return nil } diff --git a/helpers.go b/helpers.go index 7092c33..8ba39ca 100644 --- a/helpers.go +++ b/helpers.go @@ -13,16 +13,19 @@ import ( ) // IPNet is a extension of net.IPNet with some addons -type IPNet net.IPNet +type IPNet struct { + ip net.IPNet + community string +} func (i IPNet) String() string { - j, _ := i.Mask.Size() - return fmt.Sprintf("%v/%v", i.IP, j) + j, _ := i.ip.Mask.Size() + return fmt.Sprintf("%v/%v", i.ip.IP, j) } // Plen returns the prefix len as uint func (i IPNet) Plen() uint32 { - j, _ := i.Mask.Size() + j, _ := i.ip.Mask.Size() return uint32(j) } @@ -34,8 +37,7 @@ func IPNetFromAddr(a net.Addr) (*IPNet, error) { } return &IPNet{ - IP: p.IP, - Mask: p.Mask, + ip: *p, }, nil } @@ -70,12 +72,12 @@ func getPath(p IPNet, nh string, myCom string) (*api.Path, error) { } nlri, _ := ptypes.MarshalAny(&api.IPAddressPrefix{ - Prefix: p.IP.String(), + Prefix: p.ip.IP.String(), PrefixLen: p.Plen(), }) var family *api.Family - if p.IP.To4() == nil { + if p.ip.IP.To4() == nil { family = &api.Family{ Afi: api.Family_AFI_IP6, Safi: api.Family_SAFI_UNICAST, @@ -111,7 +113,7 @@ func getPath(p IPNet, nh string, myCom string) (*api.Path, error) { }, nil } -//get all local IPs elegible to be elastic IP +// get all local IPs elegible to be elastic IP func getIPs(v6Mask int, allIfs bool) (*[]IPNet, error) { var addrs []net.Addr var err error @@ -135,54 +137,56 @@ func getIPs(v6Mask int, allIfs bool) (*[]IPNet, error) { ips := make(map[string]*IPNet) for _, addr := range addrs { - ip, err := IPNetFromAddr(addr) + p, err := IPNetFromAddr(addr) if err != nil { log.WithFields(log.Fields{"Topic": "Helper", "Route": addr, "Error": "invalid IP"}).Warn("invalid IP") continue } // ignore loopback IPs - if ip.IP.IsLoopback() { - log.WithFields(log.Fields{"Topic": "Helper", "Route": ip, "Warn": "not acceptable elastic IP"}). + if p.ip.IP.IsLoopback() { + log.WithFields(log.Fields{"Topic": "Helper", "Route": p, "Warn": "not acceptable elastic IP"}). Trace("ignoring loopback IPs") continue } // ignore link local IPs - if ip.IP.IsLinkLocalUnicast() { - log.WithFields(log.Fields{"Topic": "Helper", "Route": ip, "Warn": "not acceptable elastic IP"}). + if p.ip.IP.IsLinkLocalUnicast() { + log.WithFields(log.Fields{"Topic": "Helper", "Route": p, "Warn": "not acceptable elastic IP"}). Trace("ignoring linklocal IPs") continue } // for ipv4 only a /32 is acceptable - if ip.IP.To4() != nil && ip.Plen() != 32 { - log.WithFields(log.Fields{"Topic": "Helper", "Route": ip, "Warn": "not accepted prefix length"}). + if p.ip.IP.To4() != nil && p.Plen() != 32 { + log.WithFields(log.Fields{"Topic": "Helper", "Route": p, "Warn": "not accepted prefix length"}). Warn("not accepted prefix length") continue } // for ipv6 lets find the greater subnet we're part of, make it a /64 (or if asked a /56) and advertise that - if ip.IP.To4() == nil { - if ip.Plen() != 64 && ip.Plen() != 56 { - log.WithFields(log.Fields{"Topic": "Helper", "Route": ip, "Warn": "fixing prefix length"}). + if p.ip.IP.To4() == nil { + if p.Plen() != 64 && p.Plen() != 56 { + log.WithFields(log.Fields{"Topic": "Helper", "Route": p, "Warn": "fixing prefix length"}). Warnf("fixing prefix lenth length to /%d", v6Mask) - ip.Mask = sendMask + p.ip.Mask = sendMask } - _, ipNew, err := net.ParseCIDR(ip.String()) + _, ipNew, err := net.ParseCIDR(p.String()) if err != nil { - log.WithFields(log.Fields{"Topic": "Helper", "Route": ip, "Error": "invalid IP"}). + log.WithFields(log.Fields{"Topic": "Helper", "Route": p, "Error": "invalid IP"}). Warnf("unable to supernet") continue } - ip = &IPNet{ - IP: ipNew.IP, - Mask: ipNew.Mask, + p = &IPNet{ + ip: net.IPNet{ + IP: ipNew.IP, + Mask: ipNew.Mask, + }, } } - ips[ip.String()] = ip - log.WithFields(log.Fields{"Topic": "Helper", "Route": ip}).Debug("handling prefix") + ips[p.String()] = p + log.WithFields(log.Fields{"Topic": "Helper", "Route": p}).Debug("handling prefix") } var uniqIPs []IPNet diff --git a/main.go b/main.go index 6097305..0b0ab75 100644 --- a/main.go +++ b/main.go @@ -3,8 +3,8 @@ package main import ( "flag" "fmt" - log "github.com/sirupsen/logrus" + "net" ) const ( @@ -16,9 +16,23 @@ const ( communitySecondary = "65000:2" ) +type arrayFlags []string + +func (i *arrayFlags) String() string { + return fmt.Sprintf("%v", []string(*i)) +} + +func (i *arrayFlags) Set(value string) error { + *i = append(*i, value) + return nil +} + +var primaryIps arrayFlags +var secondaryIps arrayFlags + func main() { - primary := flag.Bool("primary", false, "advertise as primary") - secondary := flag.Bool("secondary", false, "advertise as secondary") + flag.Var(&primaryIps, "primary", "Advertise as primary for a specific IP. Must contain CIDR notation") + flag.Var(&secondaryIps, "secondary", "Advertise as secondary for a specific IP. Must contain CIDR notation") loglevel := flag.String("loglevel", "info", "set log level: trace, debug, info or warn") logjson := flag.Bool("logjson", false, "set log format to json") dcid := flag.Int("dcid", 0, "dcid for your DC") @@ -46,9 +60,21 @@ func main() { log.WithFields(log.Fields{"Topic": "Main"}).Fatal("dcid not provided, I need this info") } - if !*primary && !*secondary { + if primaryIps == nil && secondaryIps == nil { flag.Usage() - log.WithFields(log.Fields{"Topic": "Main"}).Fatal("use either primary or secondary flag") + log.WithFields(log.Fields{"Topic": "Main"}).Fatal("primary and/or secondary must be provided") + } + + for _, ip := range primaryIps { + if _, _, err := net.ParseCIDR(ip); err != nil { + log.WithFields(log.Fields{"Topic": "Main"}).Fatalf("invalid primary ip: %s. Must be in CIDR notation", ip) + } + } + + for _, ip := range secondaryIps { + if _, _, err := net.ParseCIDR(ip); err != nil { + log.WithFields(log.Fields{"Topic": "Main"}).Fatalf("invalid secondary ip: %s. Must be in CIDR notation", ip) + } } switch *loglevel { @@ -66,29 +92,60 @@ func main() { log.SetLevel(log.InfoLevel) } - var myCommunity string + var communityMap map[string]string + communityMap = make(map[string]string) - switch { - case *primary: - myCommunity = communityPrimary - case *secondary: - myCommunity = communitySecondary - default: - log.WithFields(log.Fields{"Topic": "Main"}).Fatal("use either primary or secondary flag") + if len(primaryIps) > 0 { + for _, ip := range primaryIps { + communityMap[ip] = communityPrimary + } + } + + if len(secondaryIps) > 0 { + for _, ip := range secondaryIps { + communityMap[ip] = communitySecondary + } } - //ips v6Mask := 64 if *send56 { v6Mask = 56 } - ips, err := getIPs(v6Mask, *allIfs) + allIps, err := getIPs(v6Mask, *allIfs) if err != nil { log.WithFields(log.Fields{"Topic": "Main"}).Fatalf("unable to detect IPs: %v", err) } - c, err := NewClient(myCommunity, ips) + var ips []IPNet + + // Filter the list of IPs to only include the ones we want to advertise + if len(communityMap) > 0 { + for _, ipData := range *allIps { + ipString := ipData.String() + if _, ok := communityMap[ipString]; ok { + log.WithFields(log.Fields{"Topic": "Main"}).Infof("advertising IP %s with community %s", ipString, communityMap[ipString]) + ipData.community = communityMap[ipString] + ips = append(ips, ipData) + } else { + log.WithFields(log.Fields{"Topic": "Main", "IP": ipString}).Warnf("not advetising IP %s as it was not specified", ipString) + } + } + } else { + log.WithFields(log.Fields{"Topic": "Main"}).Info("no IPs specified, advertising all IPs") + ips = *allIps + } + + if len(ips) != len(communityMap) { + log.WithFields(log.Fields{ + "Topic": "Main", + "Detected": ips, + "Requested Primaries": primaryIps, + "Requested Secondaries": secondaryIps, + }).Fatal("Unable to detect all IPs specified. Check the IP addresses assigned to 'lo' or alternatively try the -allifs flag") + } + + c, err := NewClient(&ips) if err != nil { log.WithFields(log.Fields{"Topic": "Main"}).Fatal("failed to initiate the client: ", err) }