From 8e8646b7571582c5e2fedb90d5093de8e39ac1bc Mon Sep 17 00:00:00 2001 From: Johan Siebens Date: Fri, 10 Jun 2022 15:49:07 +0200 Subject: [PATCH] chore: add acl policy as field of tailnet --- internal/domain/acl.go | 46 +++++++++++++++++++++++++------ internal/domain/repository.go | 8 ++++-- internal/domain/tailnet.go | 6 +++- internal/domain/tailnet_config.go | 20 -------------- internal/handlers/poll_net_map.go | 7 +++-- internal/service/acl.go | 8 ++---- internal/service/tailnet.go | 4 --- 7 files changed, 54 insertions(+), 45 deletions(-) diff --git a/internal/domain/acl.go b/internal/domain/acl.go index 3368ec8..aaa9663 100644 --- a/internal/domain/acl.go +++ b/internal/domain/acl.go @@ -1,7 +1,11 @@ package domain import ( + "database/sql/driver" + "encoding/json" "fmt" + "gorm.io/gorm" + "gorm.io/gorm/schema" "inet.af/netaddr" "strconv" "strings" @@ -20,7 +24,35 @@ type ACL struct { Dst []string `json:"dst"` } -func defaultPolicy() ACLPolicy { +func (i *ACLPolicy) Scan(destination interface{}) error { + switch value := destination.(type) { + case []byte: + return json.Unmarshal(value, i) + default: + return fmt.Errorf("unexpected data type %T", destination) + } +} + +func (i ACLPolicy) Value() (driver.Value, error) { + bytes, err := json.Marshal(i) + return bytes, err +} + +// GormDataType gorm common data type +func (ACLPolicy) GormDataType() string { + return "json" +} + +// GormDBDataType gorm db data type +func (ACLPolicy) GormDBDataType(db *gorm.DB, field *schema.Field) string { + switch db.Dialector.Name() { + case "sqlite": + return "JSON" + } + return "" +} + +func DefaultPolicy() ACLPolicy { return ACLPolicy{ ACLs: []ACL{ { @@ -37,17 +69,13 @@ type aclEngine struct { expandedTags map[string][]string } -func IsValidPeer(policy *ACLPolicy, src *Machine, dest *Machine) bool { - f := &aclEngine{ - policy: policy, - } +func (p ACLPolicy) IsValidPeer(src *Machine, dest *Machine) bool { + f := &aclEngine{policy: &p} return f.isValidPeer(src, dest) } -func BuildFilterRules(policy *ACLPolicy, dst *Machine, peers []Machine) []tailcfg.FilterRule { - f := &aclEngine{ - policy: policy, - } +func (p ACLPolicy) BuildFilterRules(dst *Machine, peers []Machine) []tailcfg.FilterRule { + f := &aclEngine{policy: &p} return f.build(dst, peers) } diff --git a/internal/domain/repository.go b/internal/domain/repository.go index 165d773..71bc38c 100644 --- a/internal/domain/repository.go +++ b/internal/domain/repository.go @@ -30,9 +30,11 @@ type Repository interface { GetDNSConfig(ctx context.Context, tailnetID uint64) (*DNSConfig, error) SetDNSConfig(ctx context.Context, tailnetID uint64, config *DNSConfig) error DeleteDNSConfig(ctx context.Context, tailnetID uint64) error - GetACLPolicy(ctx context.Context, tailnetID uint64) (*ACLPolicy, error) - SetACLPolicy(ctx context.Context, tailnetID uint64, policy *ACLPolicy) error - DeleteACLPolicy(ctx context.Context, tailnetID uint64) error + /* + GetACLPolicy(ctx context.Context, tailnetID uint64) (*ACLPolicy, error) + SetACLPolicy(ctx context.Context, tailnetID uint64, policy *ACLPolicy) error + DeleteACLPolicy(ctx context.Context, tailnetID uint64) error + */ SaveApiKey(ctx context.Context, key *ApiKey) error LoadApiKey(ctx context.Context, key string) (*ApiKey, error) diff --git a/internal/domain/tailnet.go b/internal/domain/tailnet.go index 8e5fe50..22800ff 100644 --- a/internal/domain/tailnet.go +++ b/internal/domain/tailnet.go @@ -11,6 +11,7 @@ type Tailnet struct { ID uint64 `gorm:"primary_key;autoIncrement:false"` Name string `gorm:"type:varchar(64);unique_index"` IAMPolicy IAMPolicy + ACLPolicy ACLPolicy } func (r *repository) SaveTailnet(ctx context.Context, tailnet *Tailnet) error { @@ -27,7 +28,10 @@ func (r *repository) GetOrCreateTailnet(ctx context.Context, name string) (*Tail tailnet := &Tailnet{} id := util.NextID() - tx := r.withContext(ctx).Where(Tailnet{Name: name}).Attrs(Tailnet{ID: id}).FirstOrCreate(tailnet) + tx := r.withContext(ctx). + Where(Tailnet{Name: name}). + Attrs(Tailnet{ID: id, ACLPolicy: DefaultPolicy()}). + FirstOrCreate(tailnet) if tx.Error != nil { return nil, false, tx.Error diff --git a/internal/domain/tailnet_config.go b/internal/domain/tailnet_config.go index c72ef02..8f50d33 100644 --- a/internal/domain/tailnet_config.go +++ b/internal/domain/tailnet_config.go @@ -37,26 +37,6 @@ func (r *repository) DeleteDNSConfig(ctx context.Context, tailnetID uint64) erro return r.deleteConfig(ctx, "dns_config", tailnetID) } -func (r *repository) SetACLPolicy(ctx context.Context, tailnetID uint64, policy *ACLPolicy) error { - if err := r.setConfig(ctx, "acl_policy", tailnetID, policy); err != nil { - return err - } - return nil -} - -func (r *repository) GetACLPolicy(ctx context.Context, tailnetID uint64) (*ACLPolicy, error) { - var p = defaultPolicy() - err := r.getConfig(ctx, "acl_policy", tailnetID, &p) - if err != nil { - return nil, err - } - return &p, nil -} - -func (r *repository) DeleteACLPolicy(ctx context.Context, tailnetID uint64) error { - return r.deleteConfig(ctx, "acl_policy", tailnetID) -} - func (r *repository) getConfig(ctx context.Context, s string, tailnetID uint64, v interface{}) error { var m TailnetConfig tx := r.withContext(ctx).Take(&m, "key = ? AND tailnet_id = ?", s, tailnetID) diff --git a/internal/handlers/poll_net_map.go b/internal/handlers/poll_net_map.go index e00bf87..18bbaba 100644 --- a/internal/handlers/poll_net_map.go +++ b/internal/handlers/poll_net_map.go @@ -224,11 +224,12 @@ func (h *PollNetMapHandler) createMapResponse(m *domain.Machine, binder bind.Bin return nil, nil, err } - policies, err := h.repository.GetACLPolicy(ctx, m.TailnetID) + tailnet, err := h.repository.GetTailnet(ctx, m.TailnetID) if err != nil { return nil, nil, err } + policies := tailnet.ACLPolicy var users = []tailcfg.UserProfile{*user} var changedPeers []*tailcfg.Node var removedPeers []tailcfg.NodeID @@ -245,7 +246,7 @@ func (h *PollNetMapHandler) createMapResponse(m *domain.Machine, binder bind.Bin if peer.IsExpired() { continue } - if domain.IsValidPeer(policies, m, &peer) || domain.IsValidPeer(policies, &peer, m) { + if policies.IsValidPeer(m, &peer) || policies.IsValidPeer(&peer, m) { n, u, err := mapping.ToNode(&peer, h.brokers(peer.TailnetID).IsConnected(peer.ID)) if err != nil { return nil, nil, err @@ -275,7 +276,7 @@ func (h *PollNetMapHandler) createMapResponse(m *domain.Machine, binder bind.Bin return nil, nil, err } - rules := domain.BuildFilterRules(policies, m, candidatePeers) + rules := policies.BuildFilterRules(m, candidatePeers) controlTime := time.Now().UTC() var mapResponse *tailcfg.MapResponse diff --git a/internal/service/acl.go b/internal/service/acl.go index 58978f5..690c539 100644 --- a/internal/service/acl.go +++ b/internal/service/acl.go @@ -24,10 +24,7 @@ func (s *Service) GetACLPolicy(ctx context.Context, req *connect.Request[api.Get return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet does not exist")) } - policy, err := s.repository.GetACLPolicy(ctx, req.Msg.TailnetId) - if err != nil { - return nil, err - } + policy := tailnet.ACLPolicy marshal, err := json.Marshal(policy) if err != nil { @@ -56,7 +53,8 @@ func (s *Service) SetACLPolicy(ctx context.Context, req *connect.Request[api.Set return nil, err } - if err := s.repository.SetACLPolicy(ctx, tailnet.ID, &policy); err != nil { + tailnet.ACLPolicy = policy + if err := s.repository.SaveTailnet(ctx, tailnet); err != nil { return nil, err } diff --git a/internal/service/tailnet.go b/internal/service/tailnet.go index 34e2177..7d14b23 100644 --- a/internal/service/tailnet.go +++ b/internal/service/tailnet.go @@ -109,10 +109,6 @@ func (s *Service) DeleteTailnet(ctx context.Context, req *connect.Request[api.De return err } - if err := tx.DeleteACLPolicy(ctx, req.Msg.TailnetId); err != nil { - return err - } - if err := tx.DeleteDNSConfig(ctx, req.Msg.TailnetId); err != nil { return err }