feat: configure ACL policies based on tags and hosts

This commit is contained in:
Johan Siebens
2022-05-12 20:32:01 +02:00
parent 22cccceca9
commit e5c7a118a8
21 changed files with 1249 additions and 89 deletions
+6
View File
@@ -13,6 +13,7 @@ type BrokerPool struct {
type Signal struct {
PeerUpdated *uint64
PeersRemoved []uint64
ACLUpdated bool
}
type Broker interface {
@@ -21,6 +22,7 @@ type Broker interface {
SignalPeerUpdated(id uint64)
SignalPeersRemoved([]uint64)
SignalACLUpdated()
IsConnected(uint64) bool
}
@@ -86,6 +88,10 @@ func (h *broker) SignalPeersRemoved(ids []uint64) {
h.signalChannel <- &Signal{PeersRemoved: ids}
}
func (h *broker) SignalACLUpdated() {
h.signalChannel <- &Signal{ACLUpdated: true}
}
func (h *broker) listen() {
for {
select {
+123
View File
@@ -0,0 +1,123 @@
package cmd
import (
"context"
"encoding/json"
"fmt"
"github.com/jsiebens/ionscale/pkg/gen/api"
"github.com/muesli/coral"
"gopkg.in/yaml.v2"
"io/ioutil"
)
func getACLConfig() *coral.Command {
command := &coral.Command{
Use: "get-acl",
Short: "Get the ACL policy",
SilenceUsage: true,
}
var asJson bool
var tailnetID uint64
var tailnetName string
var target = Target{}
target.prepareCommand(command)
command.Flags().StringVar(&tailnetName, "tailnet", "", "")
command.Flags().Uint64Var(&tailnetID, "tailnet-id", 0, "")
command.Flags().BoolVar(&asJson, "json", false, "")
command.RunE = func(command *coral.Command, args []string) error {
client, c, err := target.createGRPCClient()
if err != nil {
return err
}
defer safeClose(c)
tailnet, err := findTailnet(client, tailnetName, tailnetID)
if err != nil {
return err
}
resp, err := client.GetACLPolicy(context.Background(), &api.GetACLPolicyRequest{TailnetId: tailnet.Id})
if err != nil {
return err
}
if asJson {
marshal, err := json.MarshalIndent(resp.Policy, "", " ")
if err != nil {
return err
}
fmt.Println()
fmt.Println(string(marshal))
} else {
marshal, err := yaml.Marshal(resp.Policy)
if err != nil {
return err
}
fmt.Println()
fmt.Println(string(marshal))
}
return nil
}
return command
}
func setACLConfig() *coral.Command {
command := &coral.Command{
Use: "set-acl",
Short: "Set ACL policy",
SilenceUsage: true,
}
var tailnetID uint64
var tailnetName string
var file string
var target = Target{}
target.prepareCommand(command)
command.Flags().StringVar(&tailnetName, "tailnet", "", "")
command.Flags().Uint64Var(&tailnetID, "tailnet-id", 0, "")
command.Flags().StringVar(&file, "file", "", "")
command.RunE = func(command *coral.Command, args []string) error {
rawJson, err := ioutil.ReadFile(file)
if err != nil {
return err
}
var policy api.Policy
if err := json.Unmarshal(rawJson, &policy); err != nil {
return err
}
client, c, err := target.createGRPCClient()
if err != nil {
return err
}
defer safeClose(c)
tailnet, err := findTailnet(client, tailnetName, tailnetID)
if err != nil {
return err
}
_, err = client.SetACLPolicy(context.Background(), &api.SetACLPolicyRequest{TailnetId: tailnet.Id, Policy: &policy})
if err != nil {
return err
}
fmt.Println()
fmt.Println("ACL policy updated successfully")
return nil
}
return command
}
+2
View File
@@ -16,6 +16,8 @@ func tailnetCommand() *coral.Command {
command.AddCommand(listTailnetsCommand())
command.AddCommand(createTailnetsCommand())
command.AddCommand(getACLConfig())
command.AddCommand(setACLConfig())
return command
}
+1
View File
@@ -43,6 +43,7 @@ func migrate(db *gorm.DB, repository domain.Repository) error {
err := db.AutoMigrate(
&domain.ServerConfig{},
&domain.Tailnet{},
&domain.TailnetConfig{},
&domain.User{},
&domain.AuthKey{},
&domain.Machine{},
+211
View File
@@ -0,0 +1,211 @@
package domain
import (
"fmt"
"inet.af/netaddr"
"strconv"
"strings"
"tailscale.com/tailcfg"
)
type ACLPolicy struct {
Hosts map[string]string `json:"hosts,omitempty"`
ACLs []ACL `json:"acls"`
}
type ACL struct {
Action string `json:"action"`
Src []string `json:"src"`
Dst []string `json:"dst"`
}
func defaultPolicy() ACLPolicy {
return ACLPolicy{
ACLs: []ACL{
{
Action: "accept",
Src: []string{"*"},
Dst: []string{"*:*"},
},
},
}
}
type aclEngine struct {
policy *ACLPolicy
expandedTags map[string][]string
}
func IsValidPeer(policy *ACLPolicy, src *Machine, dest *Machine) bool {
f := &aclEngine{
policy: policy,
}
return f.isValidPeer(src, dest)
}
func BuildFilterRules(policy *ACLPolicy, dst *Machine, peers []Machine) []tailcfg.FilterRule {
f := &aclEngine{
policy: policy,
}
return f.build(dst, peers)
}
func (a *aclEngine) isValidPeer(src *Machine, dest *Machine) bool {
for _, acl := range a.policy.ACLs {
allDestPorts := a.expandMachineToDstPorts(dest, acl.Dst)
if len(allDestPorts) == 0 {
continue
}
for _, alias := range acl.Src {
if len(a.expandMachineAlias(src, alias)) != 0 {
return true
}
}
}
return false
}
func (a *aclEngine) build(dst *Machine, peers []Machine) []tailcfg.FilterRule {
var rules []tailcfg.FilterRule
for _, acl := range a.policy.ACLs {
allDestPorts := a.expandMachineToDstPorts(dst, acl.Dst)
if len(allDestPorts) == 0 {
continue
}
var allSrcIPs []string
for _, src := range acl.Src {
for _, peer := range peers {
srcIPs := a.expandMachineAlias(&peer, src)
allSrcIPs = append(allSrcIPs, srcIPs...)
}
}
if len(allSrcIPs) == 0 {
allSrcIPs = nil
}
rule := tailcfg.FilterRule{
SrcIPs: allSrcIPs,
DstPorts: allDestPorts,
}
rules = append(rules, rule)
}
if len(rules) == 0 {
return []tailcfg.FilterRule{{}}
}
return rules
}
func (a *aclEngine) expandMachineToDstPorts(m *Machine, ports []string) []tailcfg.NetPortRange {
allDestRanges := []tailcfg.NetPortRange{}
for _, d := range ports {
ranges := a.expandMachineDestToNetPortRanges(m, d)
allDestRanges = append(allDestRanges, ranges...)
}
return allDestRanges
}
func (a *aclEngine) expandMachineDestToNetPortRanges(m *Machine, dest string) []tailcfg.NetPortRange {
tokens := strings.Split(dest, ":")
if len(tokens) < 2 || len(tokens) > 3 {
return nil
}
var alias string
if len(tokens) == 2 {
alias = tokens[0]
} else {
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
}
ports, err := a.expandValuePortToPortRange(tokens[len(tokens)-1])
if err != nil {
return nil
}
ips := a.expandMachineAlias(m, alias)
if len(ips) == 0 {
return nil
}
dests := []tailcfg.NetPortRange{}
for _, d := range ips {
for _, p := range ports {
pr := tailcfg.NetPortRange{
IP: d,
Ports: p,
}
dests = append(dests, pr)
}
}
return dests
}
func (a *aclEngine) expandMachineAlias(m *Machine, src string) []string {
if src == "*" {
if src == "*" {
return []string{"*"}
}
}
machineIPs := []string{m.IPv4.String(), m.IPv6.String()}
if strings.HasPrefix(src, "tag:") && m.HasTag(src[4:]) {
return machineIPs
}
if h, ok := a.policy.Hosts[src]; ok {
src = h
}
ip, err := netaddr.ParseIP(src)
if err == nil && m.HasIP(ip) {
return machineIPs
}
return []string{}
}
func (a *aclEngine) expandValuePortToPortRange(s string) ([]tailcfg.PortRange, error) {
if s == "*" {
return []tailcfg.PortRange{{First: 0, Last: 65535}}, nil
}
ports := []tailcfg.PortRange{}
for _, p := range strings.Split(s, ",") {
rang := strings.Split(p, "-")
if len(rang) == 1 {
pi, err := strconv.ParseUint(rang[0], 10, 16)
if err != nil {
return nil, err
}
ports = append(ports, tailcfg.PortRange{
First: uint16(pi),
Last: uint16(pi),
})
} else if len(rang) == 2 {
start, err := strconv.ParseUint(rang[0], 10, 16)
if err != nil {
return nil, err
}
last, err := strconv.ParseUint(rang[1], 10, 16)
if err != nil {
return nil, err
}
ports = append(ports, tailcfg.PortRange{
First: uint16(start),
Last: uint16(last),
})
} else {
return nil, fmt.Errorf("invalid format")
}
}
return ports, nil
}
+41 -2
View File
@@ -8,6 +8,7 @@ import (
"fmt"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"time"
)
@@ -26,8 +27,8 @@ type Machine struct {
HostInfo HostInfo
Endpoints Endpoints
IPv4 string
IPv6 string
IPv4 IP
IPv6 IP
CreatedAt time.Time
ExpiresAt *time.Time
@@ -42,6 +43,44 @@ type Machine struct {
type Machines []Machine
func (m *Machine) HasIP(v netaddr.IP) bool {
return v.Compare(*m.IPv4.IP) == 0 || v.Compare(*m.IPv6.IP) == 0
}
func (m *Machine) HasTag(tag string) bool {
for _, t := range m.Tags {
if t == tag {
return true
}
}
return false
}
type IP struct {
*netaddr.IP
}
func (i *IP) Scan(destination interface{}) error {
switch value := destination.(type) {
case string:
ip, err := netaddr.ParseIP(value)
if err != nil {
return err
}
*i = IP{&ip}
return nil
default:
return fmt.Errorf("unexpected data type %T", destination)
}
}
func (i IP) Value() (driver.Value, error) {
if i.IP == nil {
return nil, nil
}
return i.String(), nil
}
type HostInfo tailcfg.Hostinfo
func (hi *HostInfo) Scan(destination interface{}) error {
+3
View File
@@ -15,6 +15,9 @@ type Repository interface {
GetTailnet(ctx context.Context, id uint64) (*Tailnet, error)
ListTailnets(ctx context.Context) ([]Tailnet, error)
GetACLPolicy(ctx context.Context, tailnetID uint64) (*ACLPolicy, error)
SetACLPolicy(ctx context.Context, tailnetID uint64, policy *ACLPolicy) error
SaveAuthKey(ctx context.Context, key *AuthKey) error
DeleteAuthKey(ctx context.Context, id uint64) (bool, error)
ListAuthKeys(ctx context.Context, tailnetID uint64) ([]AuthKey, error)
+65
View File
@@ -0,0 +1,65 @@
package domain
import (
"context"
"encoding/json"
"errors"
"gorm.io/gorm"
)
type TailnetConfig struct {
Key string `gorm:"primary_key"`
TailnetID uint64 `gorm:"primary_key;autoIncrement:false"`
Value []byte
}
func (r *repository) SetACLPolicy(ctx context.Context, tailnetID uint64, policy *ACLPolicy) error {
if err := r.setConfig(ctx, "acl_policy", tailnetID, policy); err != nil {
return err
}
return nil
}
func (r *repository) GetACLPolicy(ctx context.Context, tailnetID uint64) (*ACLPolicy, error) {
var p = defaultPolicy()
err := r.getConfig(ctx, "acl_policy", tailnetID, &p)
if err != nil {
return nil, err
}
return &p, nil
}
func (r *repository) getConfig(ctx context.Context, s string, tailnetID uint64, v interface{}) error {
var m TailnetConfig
tx := r.withContext(ctx).Take(&m, "key = ? AND tailnet_id = ?", s, tailnetID)
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return nil
}
if tx.Error != nil {
return tx.Error
}
err := json.Unmarshal(m.Value, v)
if err != nil {
return err
}
return nil
}
func (r *repository) setConfig(ctx context.Context, s string, tailnetID uint64, v interface{}) error {
marshal, err := json.Marshal(v)
if err != nil {
return err
}
c := &TailnetConfig{
Key: s,
Value: marshal,
TailnetID: tailnetID,
}
tx := r.withContext(ctx).Save(c)
return tx.Error
}
+2 -2
View File
@@ -128,8 +128,8 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
if err != nil {
return err
}
m.IPv4 = ipv4.String()
m.IPv6 = ipv6.String()
m.IPv4 = domain.IP{IP: ipv4}
m.IPv6 = domain.IP{IP: ipv6}
} else {
registeredTags := authKey.Tags
advertisedTags := domain.SanitizeTags(req.Hostinfo.RequestTags)
+19 -10
View File
@@ -219,12 +219,19 @@ func (h *PollNetMapHandler) createKeepAliveResponse(binder bind.Binder, request
}
func (h *PollNetMapHandler) createMapResponse(m *domain.Machine, binder bind.Binder, request *tailcfg.MapRequest, delta bool, prevSyncedPeerIDs map[uint64]bool) ([]byte, map[uint64]bool, error) {
ctx := context.TODO()
node, err := mapping.ToNode(m, true)
if err != nil {
return nil, nil, err
}
users, err := h.repository.ListUsers(context.TODO(), m.TailnetID)
policies, err := h.repository.GetACLPolicy(ctx, m.TailnetID)
if err != nil {
return nil, nil, err
}
users, err := h.repository.ListUsers(ctx, m.TailnetID)
if err != nil {
return nil, nil, err
}
@@ -232,7 +239,7 @@ func (h *PollNetMapHandler) createMapResponse(m *domain.Machine, binder bind.Bin
var changedPeers []*tailcfg.Node
var removedPeers []tailcfg.NodeID
candidatePeers, err := h.repository.ListMachinePeers(context.TODO(), m.TailnetID, m.MachineKey)
candidatePeers, err := h.repository.ListMachinePeers(ctx, m.TailnetID, m.MachineKey)
if err != nil {
return nil, nil, err
}
@@ -240,25 +247,27 @@ func (h *PollNetMapHandler) createMapResponse(m *domain.Machine, binder bind.Bin
syncedPeerIDs := map[uint64]bool{}
for _, peer := range candidatePeers {
n, err := mapping.ToNode(&peer, h.brokers(peer.TailnetID).IsConnected(peer.ID))
if err != nil {
return nil, nil, err
if domain.IsValidPeer(policies, m, &peer) || domain.IsValidPeer(policies, &peer, m) {
n, err := mapping.ToNode(&peer, h.brokers(peer.TailnetID).IsConnected(peer.ID))
if err != nil {
return nil, nil, err
}
changedPeers = append(changedPeers, n)
syncedPeerIDs[peer.ID] = true
delete(prevSyncedPeerIDs, peer.ID)
}
changedPeers = append(changedPeers, n)
syncedPeerIDs[peer.ID] = true
delete(prevSyncedPeerIDs, peer.ID)
}
for p, _ := range prevSyncedPeerIDs {
removedPeers = append(removedPeers, tailcfg.NodeID(p))
}
derpMap, err := h.repository.GetDERPMap(context.TODO())
derpMap, err := h.repository.GetDERPMap(ctx)
if err != nil {
return nil, nil, err
}
rules := tailcfg.FilterAllowAll
rules := domain.BuildFilterRules(policies, m, candidatePeers)
controlTime := time.Now().UTC()
var mapResponse *tailcfg.MapResponse
+2 -2
View File
@@ -187,8 +187,8 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi
if err != nil {
return err
}
m.IPv4 = ipv4.String()
m.IPv6 = ipv6.String()
m.IPv4 = domain.IP{IP: ipv4}
m.IPv6 = domain.IP{IP: ipv6}
} else {
registeredTags := authKey.Tags
advertisedTags := domain.SanitizeTags(req.Hostinfo.RequestTags)
+18 -4
View File
@@ -1,6 +1,7 @@
package mapping
import (
"encoding/json"
"fmt"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/util"
@@ -13,6 +14,19 @@ import (
const NetworkMagicDNSSuffix = "ionscale.net"
func CopyViaJson[F any, T any](f F, t T) error {
raw, err := json.Marshal(f)
if err != nil {
return err
}
if err := json.Unmarshal(raw, t); err != nil {
return err
}
return nil
}
func ToNode(m *domain.Machine, connected bool) (*tailcfg.Node, error) {
nKey, err := util.ParseNodePublicKey(m.NodeKey)
if err != nil {
@@ -39,8 +53,8 @@ func ToNode(m *domain.Machine, connected bool) (*tailcfg.Node, error) {
var addrs []netaddr.IPPrefix
var allowedIPs []netaddr.IPPrefix
if m.IPv4 != "" {
ipv4, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", m.IPv4))
if !m.IPv4.IsZero() {
ipv4, err := m.IPv4.Prefix(32)
if err != nil {
return nil, err
}
@@ -48,8 +62,8 @@ func ToNode(m *domain.Machine, connected bool) (*tailcfg.Node, error) {
allowedIPs = append(allowedIPs, ipv4)
}
if m.IPv6 != "" {
ipv6, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/128", m.IPv6))
if !m.IPv6.IsZero() {
ipv6, err := m.IPv6.Prefix(128)
if err != nil {
return nil, err
}
+47
View File
@@ -0,0 +1,47 @@
package service
import (
"context"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/mapping"
"github.com/jsiebens/ionscale/pkg/gen/api"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func (s *Service) GetACLPolicy(ctx context.Context, req *api.GetACLPolicyRequest) (*api.GetACLPolicyResponse, error) {
policy, err := s.repository.GetACLPolicy(ctx, req.TailnetId)
if err != nil {
return nil, err
}
var p api.Policy
if err := mapping.CopyViaJson(policy, &p); err != nil {
return nil, err
}
return &api.GetACLPolicyResponse{Policy: &p}, nil
}
func (s *Service) SetACLPolicy(ctx context.Context, req *api.SetACLPolicyRequest) (*api.SetACLPolicyResponse, error) {
tailnet, err := s.repository.GetTailnet(ctx, req.TailnetId)
if err != nil {
return nil, err
}
if tailnet == nil {
return nil, status.Error(codes.NotFound, "tailnet does not exist")
}
var policy domain.ACLPolicy
if err := mapping.CopyViaJson(req.Policy, &policy); err != nil {
return nil, err
}
if err := s.repository.SetACLPolicy(ctx, tailnet.ID, &policy); err != nil {
return nil, err
}
s.brokers(tailnet.ID).SignalACLUpdated()
return &api.SetACLPolicyResponse{}, nil
}
+2 -2
View File
@@ -37,8 +37,8 @@ func (s *Service) ListMachines(ctx context.Context, req *api.ListMachinesRequest
response.Machines = append(response.Machines, &api.Machine{
Id: m.ID,
Name: name,
Ipv4: m.IPv4,
Ipv6: m.IPv6,
Ipv4: m.IPv4.String(),
Ipv6: m.IPv6.String(),
Ephemeral: m.Ephemeral,
Tags: m.Tags,
LastSeen: lastSeen,