chore: add acl policy as field of tailnet

This commit is contained in:
Johan Siebens
2022-06-10 15:49:07 +02:00
parent a94e0ce9b8
commit 8e8646b757
7 changed files with 54 additions and 45 deletions
+37 -9
View File
@@ -1,7 +1,11 @@
package domain
import (
"database/sql/driver"
"encoding/json"
"fmt"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"inet.af/netaddr"
"strconv"
"strings"
@@ -20,7 +24,35 @@ type ACL struct {
Dst []string `json:"dst"`
}
func defaultPolicy() ACLPolicy {
func (i *ACLPolicy) Scan(destination interface{}) error {
switch value := destination.(type) {
case []byte:
return json.Unmarshal(value, i)
default:
return fmt.Errorf("unexpected data type %T", destination)
}
}
func (i ACLPolicy) Value() (driver.Value, error) {
bytes, err := json.Marshal(i)
return bytes, err
}
// GormDataType gorm common data type
func (ACLPolicy) GormDataType() string {
return "json"
}
// GormDBDataType gorm db data type
func (ACLPolicy) GormDBDataType(db *gorm.DB, field *schema.Field) string {
switch db.Dialector.Name() {
case "sqlite":
return "JSON"
}
return ""
}
func DefaultPolicy() ACLPolicy {
return ACLPolicy{
ACLs: []ACL{
{
@@ -37,17 +69,13 @@ type aclEngine struct {
expandedTags map[string][]string
}
func IsValidPeer(policy *ACLPolicy, src *Machine, dest *Machine) bool {
f := &aclEngine{
policy: policy,
}
func (p ACLPolicy) IsValidPeer(src *Machine, dest *Machine) bool {
f := &aclEngine{policy: &p}
return f.isValidPeer(src, dest)
}
func BuildFilterRules(policy *ACLPolicy, dst *Machine, peers []Machine) []tailcfg.FilterRule {
f := &aclEngine{
policy: policy,
}
func (p ACLPolicy) BuildFilterRules(dst *Machine, peers []Machine) []tailcfg.FilterRule {
f := &aclEngine{policy: &p}
return f.build(dst, peers)
}
+5 -3
View File
@@ -30,9 +30,11 @@ type Repository interface {
GetDNSConfig(ctx context.Context, tailnetID uint64) (*DNSConfig, error)
SetDNSConfig(ctx context.Context, tailnetID uint64, config *DNSConfig) error
DeleteDNSConfig(ctx context.Context, tailnetID uint64) error
GetACLPolicy(ctx context.Context, tailnetID uint64) (*ACLPolicy, error)
SetACLPolicy(ctx context.Context, tailnetID uint64, policy *ACLPolicy) error
DeleteACLPolicy(ctx context.Context, tailnetID uint64) error
/*
GetACLPolicy(ctx context.Context, tailnetID uint64) (*ACLPolicy, error)
SetACLPolicy(ctx context.Context, tailnetID uint64, policy *ACLPolicy) error
DeleteACLPolicy(ctx context.Context, tailnetID uint64) error
*/
SaveApiKey(ctx context.Context, key *ApiKey) error
LoadApiKey(ctx context.Context, key string) (*ApiKey, error)
+5 -1
View File
@@ -11,6 +11,7 @@ type Tailnet struct {
ID uint64 `gorm:"primary_key;autoIncrement:false"`
Name string `gorm:"type:varchar(64);unique_index"`
IAMPolicy IAMPolicy
ACLPolicy ACLPolicy
}
func (r *repository) SaveTailnet(ctx context.Context, tailnet *Tailnet) error {
@@ -27,7 +28,10 @@ func (r *repository) GetOrCreateTailnet(ctx context.Context, name string) (*Tail
tailnet := &Tailnet{}
id := util.NextID()
tx := r.withContext(ctx).Where(Tailnet{Name: name}).Attrs(Tailnet{ID: id}).FirstOrCreate(tailnet)
tx := r.withContext(ctx).
Where(Tailnet{Name: name}).
Attrs(Tailnet{ID: id, ACLPolicy: DefaultPolicy()}).
FirstOrCreate(tailnet)
if tx.Error != nil {
return nil, false, tx.Error
-20
View File
@@ -37,26 +37,6 @@ func (r *repository) DeleteDNSConfig(ctx context.Context, tailnetID uint64) erro
return r.deleteConfig(ctx, "dns_config", tailnetID)
}
func (r *repository) SetACLPolicy(ctx context.Context, tailnetID uint64, policy *ACLPolicy) error {
if err := r.setConfig(ctx, "acl_policy", tailnetID, policy); err != nil {
return err
}
return nil
}
func (r *repository) GetACLPolicy(ctx context.Context, tailnetID uint64) (*ACLPolicy, error) {
var p = defaultPolicy()
err := r.getConfig(ctx, "acl_policy", tailnetID, &p)
if err != nil {
return nil, err
}
return &p, nil
}
func (r *repository) DeleteACLPolicy(ctx context.Context, tailnetID uint64) error {
return r.deleteConfig(ctx, "acl_policy", tailnetID)
}
func (r *repository) getConfig(ctx context.Context, s string, tailnetID uint64, v interface{}) error {
var m TailnetConfig
tx := r.withContext(ctx).Take(&m, "key = ? AND tailnet_id = ?", s, tailnetID)
+4 -3
View File
@@ -224,11 +224,12 @@ func (h *PollNetMapHandler) createMapResponse(m *domain.Machine, binder bind.Bin
return nil, nil, err
}
policies, err := h.repository.GetACLPolicy(ctx, m.TailnetID)
tailnet, err := h.repository.GetTailnet(ctx, m.TailnetID)
if err != nil {
return nil, nil, err
}
policies := tailnet.ACLPolicy
var users = []tailcfg.UserProfile{*user}
var changedPeers []*tailcfg.Node
var removedPeers []tailcfg.NodeID
@@ -245,7 +246,7 @@ func (h *PollNetMapHandler) createMapResponse(m *domain.Machine, binder bind.Bin
if peer.IsExpired() {
continue
}
if domain.IsValidPeer(policies, m, &peer) || domain.IsValidPeer(policies, &peer, m) {
if policies.IsValidPeer(m, &peer) || policies.IsValidPeer(&peer, m) {
n, u, err := mapping.ToNode(&peer, h.brokers(peer.TailnetID).IsConnected(peer.ID))
if err != nil {
return nil, nil, err
@@ -275,7 +276,7 @@ func (h *PollNetMapHandler) createMapResponse(m *domain.Machine, binder bind.Bin
return nil, nil, err
}
rules := domain.BuildFilterRules(policies, m, candidatePeers)
rules := policies.BuildFilterRules(m, candidatePeers)
controlTime := time.Now().UTC()
var mapResponse *tailcfg.MapResponse
+3 -5
View File
@@ -24,10 +24,7 @@ func (s *Service) GetACLPolicy(ctx context.Context, req *connect.Request[api.Get
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet does not exist"))
}
policy, err := s.repository.GetACLPolicy(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
}
policy := tailnet.ACLPolicy
marshal, err := json.Marshal(policy)
if err != nil {
@@ -56,7 +53,8 @@ func (s *Service) SetACLPolicy(ctx context.Context, req *connect.Request[api.Set
return nil, err
}
if err := s.repository.SetACLPolicy(ctx, tailnet.ID, &policy); err != nil {
tailnet.ACLPolicy = policy
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
}
-4
View File
@@ -109,10 +109,6 @@ func (s *Service) DeleteTailnet(ctx context.Context, req *connect.Request[api.De
return err
}
if err := tx.DeleteACLPolicy(ctx, req.Msg.TailnetId); err != nil {
return err
}
if err := tx.DeleteDNSConfig(ctx, req.Msg.TailnetId); err != nil {
return err
}