mirror of
https://github.com/jsiebens/ionscale.git
synced 2026-03-31 15:07:49 +01:00
chore: add acl policy as field of tailnet
This commit is contained in:
+37
-9
@@ -1,7 +1,11 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
"inet.af/netaddr"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -20,7 +24,35 @@ type ACL struct {
|
||||
Dst []string `json:"dst"`
|
||||
}
|
||||
|
||||
func defaultPolicy() ACLPolicy {
|
||||
func (i *ACLPolicy) Scan(destination interface{}) error {
|
||||
switch value := destination.(type) {
|
||||
case []byte:
|
||||
return json.Unmarshal(value, i)
|
||||
default:
|
||||
return fmt.Errorf("unexpected data type %T", destination)
|
||||
}
|
||||
}
|
||||
|
||||
func (i ACLPolicy) Value() (driver.Value, error) {
|
||||
bytes, err := json.Marshal(i)
|
||||
return bytes, err
|
||||
}
|
||||
|
||||
// GormDataType gorm common data type
|
||||
func (ACLPolicy) GormDataType() string {
|
||||
return "json"
|
||||
}
|
||||
|
||||
// GormDBDataType gorm db data type
|
||||
func (ACLPolicy) GormDBDataType(db *gorm.DB, field *schema.Field) string {
|
||||
switch db.Dialector.Name() {
|
||||
case "sqlite":
|
||||
return "JSON"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func DefaultPolicy() ACLPolicy {
|
||||
return ACLPolicy{
|
||||
ACLs: []ACL{
|
||||
{
|
||||
@@ -37,17 +69,13 @@ type aclEngine struct {
|
||||
expandedTags map[string][]string
|
||||
}
|
||||
|
||||
func IsValidPeer(policy *ACLPolicy, src *Machine, dest *Machine) bool {
|
||||
f := &aclEngine{
|
||||
policy: policy,
|
||||
}
|
||||
func (p ACLPolicy) IsValidPeer(src *Machine, dest *Machine) bool {
|
||||
f := &aclEngine{policy: &p}
|
||||
return f.isValidPeer(src, dest)
|
||||
}
|
||||
|
||||
func BuildFilterRules(policy *ACLPolicy, dst *Machine, peers []Machine) []tailcfg.FilterRule {
|
||||
f := &aclEngine{
|
||||
policy: policy,
|
||||
}
|
||||
func (p ACLPolicy) BuildFilterRules(dst *Machine, peers []Machine) []tailcfg.FilterRule {
|
||||
f := &aclEngine{policy: &p}
|
||||
return f.build(dst, peers)
|
||||
}
|
||||
|
||||
|
||||
@@ -30,9 +30,11 @@ type Repository interface {
|
||||
GetDNSConfig(ctx context.Context, tailnetID uint64) (*DNSConfig, error)
|
||||
SetDNSConfig(ctx context.Context, tailnetID uint64, config *DNSConfig) error
|
||||
DeleteDNSConfig(ctx context.Context, tailnetID uint64) error
|
||||
GetACLPolicy(ctx context.Context, tailnetID uint64) (*ACLPolicy, error)
|
||||
SetACLPolicy(ctx context.Context, tailnetID uint64, policy *ACLPolicy) error
|
||||
DeleteACLPolicy(ctx context.Context, tailnetID uint64) error
|
||||
/*
|
||||
GetACLPolicy(ctx context.Context, tailnetID uint64) (*ACLPolicy, error)
|
||||
SetACLPolicy(ctx context.Context, tailnetID uint64, policy *ACLPolicy) error
|
||||
DeleteACLPolicy(ctx context.Context, tailnetID uint64) error
|
||||
*/
|
||||
|
||||
SaveApiKey(ctx context.Context, key *ApiKey) error
|
||||
LoadApiKey(ctx context.Context, key string) (*ApiKey, error)
|
||||
|
||||
@@ -11,6 +11,7 @@ type Tailnet struct {
|
||||
ID uint64 `gorm:"primary_key;autoIncrement:false"`
|
||||
Name string `gorm:"type:varchar(64);unique_index"`
|
||||
IAMPolicy IAMPolicy
|
||||
ACLPolicy ACLPolicy
|
||||
}
|
||||
|
||||
func (r *repository) SaveTailnet(ctx context.Context, tailnet *Tailnet) error {
|
||||
@@ -27,7 +28,10 @@ func (r *repository) GetOrCreateTailnet(ctx context.Context, name string) (*Tail
|
||||
tailnet := &Tailnet{}
|
||||
id := util.NextID()
|
||||
|
||||
tx := r.withContext(ctx).Where(Tailnet{Name: name}).Attrs(Tailnet{ID: id}).FirstOrCreate(tailnet)
|
||||
tx := r.withContext(ctx).
|
||||
Where(Tailnet{Name: name}).
|
||||
Attrs(Tailnet{ID: id, ACLPolicy: DefaultPolicy()}).
|
||||
FirstOrCreate(tailnet)
|
||||
|
||||
if tx.Error != nil {
|
||||
return nil, false, tx.Error
|
||||
|
||||
@@ -37,26 +37,6 @@ func (r *repository) DeleteDNSConfig(ctx context.Context, tailnetID uint64) erro
|
||||
return r.deleteConfig(ctx, "dns_config", tailnetID)
|
||||
}
|
||||
|
||||
func (r *repository) SetACLPolicy(ctx context.Context, tailnetID uint64, policy *ACLPolicy) error {
|
||||
if err := r.setConfig(ctx, "acl_policy", tailnetID, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) GetACLPolicy(ctx context.Context, tailnetID uint64) (*ACLPolicy, error) {
|
||||
var p = defaultPolicy()
|
||||
err := r.getConfig(ctx, "acl_policy", tailnetID, &p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
func (r *repository) DeleteACLPolicy(ctx context.Context, tailnetID uint64) error {
|
||||
return r.deleteConfig(ctx, "acl_policy", tailnetID)
|
||||
}
|
||||
|
||||
func (r *repository) getConfig(ctx context.Context, s string, tailnetID uint64, v interface{}) error {
|
||||
var m TailnetConfig
|
||||
tx := r.withContext(ctx).Take(&m, "key = ? AND tailnet_id = ?", s, tailnetID)
|
||||
|
||||
@@ -224,11 +224,12 @@ func (h *PollNetMapHandler) createMapResponse(m *domain.Machine, binder bind.Bin
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
policies, err := h.repository.GetACLPolicy(ctx, m.TailnetID)
|
||||
tailnet, err := h.repository.GetTailnet(ctx, m.TailnetID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
policies := tailnet.ACLPolicy
|
||||
var users = []tailcfg.UserProfile{*user}
|
||||
var changedPeers []*tailcfg.Node
|
||||
var removedPeers []tailcfg.NodeID
|
||||
@@ -245,7 +246,7 @@ func (h *PollNetMapHandler) createMapResponse(m *domain.Machine, binder bind.Bin
|
||||
if peer.IsExpired() {
|
||||
continue
|
||||
}
|
||||
if domain.IsValidPeer(policies, m, &peer) || domain.IsValidPeer(policies, &peer, m) {
|
||||
if policies.IsValidPeer(m, &peer) || policies.IsValidPeer(&peer, m) {
|
||||
n, u, err := mapping.ToNode(&peer, h.brokers(peer.TailnetID).IsConnected(peer.ID))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -275,7 +276,7 @@ func (h *PollNetMapHandler) createMapResponse(m *domain.Machine, binder bind.Bin
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
rules := domain.BuildFilterRules(policies, m, candidatePeers)
|
||||
rules := policies.BuildFilterRules(m, candidatePeers)
|
||||
|
||||
controlTime := time.Now().UTC()
|
||||
var mapResponse *tailcfg.MapResponse
|
||||
|
||||
@@ -24,10 +24,7 @@ func (s *Service) GetACLPolicy(ctx context.Context, req *connect.Request[api.Get
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet does not exist"))
|
||||
}
|
||||
|
||||
policy, err := s.repository.GetACLPolicy(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
policy := tailnet.ACLPolicy
|
||||
|
||||
marshal, err := json.Marshal(policy)
|
||||
if err != nil {
|
||||
@@ -56,7 +53,8 @@ func (s *Service) SetACLPolicy(ctx context.Context, req *connect.Request[api.Set
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.repository.SetACLPolicy(ctx, tailnet.ID, &policy); err != nil {
|
||||
tailnet.ACLPolicy = policy
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -109,10 +109,6 @@ func (s *Service) DeleteTailnet(ctx context.Context, req *connect.Request[api.De
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.DeleteACLPolicy(ctx, req.Msg.TailnetId); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.DeleteDNSConfig(ctx, req.Msg.TailnetId); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user