feat: delete auth filters

This commit is contained in:
Johan Siebens
2022-05-26 16:43:38 +02:00
parent 198b6795b1
commit 2b5439bd60
13 changed files with 488 additions and 186 deletions
+2 -2
View File
@@ -23,7 +23,7 @@ type Broker interface {
AddClient(*Client)
RemoveClient(uint64)
SignalTailnedDeleted()
SignalUpdate()
SignalPeerUpdated(id uint64)
SignalPeersRemoved([]uint64)
SignalDNSUpdated()
@@ -94,7 +94,7 @@ func (h *broker) RemoveClient(id uint64) {
h.closingClients <- id
}
func (h *broker) SignalTailnedDeleted() {
func (h *broker) SignalUpdate() {
h.signalChannel <- &Signal{}
}
+37
View File
@@ -20,6 +20,7 @@ func authFilterCommand() *coral.Command {
command.AddCommand(createAuthFilterCommand())
command.AddCommand(listAuthFilterCommand())
command.AddCommand(deleteAuthFilterCommand())
return command
}
@@ -141,3 +142,39 @@ func createAuthFilterCommand() *coral.Command {
return command
}
func deleteAuthFilterCommand() *coral.Command {
command := &coral.Command{
Use: "delete",
SilenceUsage: true,
}
var authFilterID uint64
var target = Target{}
target.prepareCommand(command)
command.Flags().Uint64Var(&authFilterID, "auth-filter-id", 0, "")
command.RunE = func(command *coral.Command, args []string) error {
client, c, err := target.createGRPCClient()
if err != nil {
return err
}
defer safeClose(c)
req := &api.DeleteAuthFilterRequest{
AuthFilterId: authFilterID,
}
_, err = client.DeleteAuthFilter(context.Background(), req)
if err != nil {
return err
}
return nil
}
return command
}
+26
View File
@@ -5,6 +5,7 @@ import (
"errors"
"github.com/hashicorp/go-bexpr"
"github.com/mitchellh/pointerstructure"
"gorm.io/gorm"
)
type AuthFilter struct {
@@ -56,6 +57,21 @@ func (fs AuthFilters) Evaluate(v interface{}) []Tailnet {
return tailnets
}
func (r *repository) GetAuthFilter(ctx context.Context, id uint64) (*AuthFilter, error) {
var t AuthFilter
tx := r.withContext(ctx).Take(&t, "id = ?", id)
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
if tx.Error != nil {
return nil, tx.Error
}
return &t, nil
}
func (r *repository) SaveAuthFilter(ctx context.Context, m *AuthFilter) error {
tx := r.withContext(ctx).Save(m)
@@ -95,3 +111,13 @@ func (r *repository) ListAuthFiltersByAuthMethod(ctx context.Context, authMethod
return filters, nil
}
func (r *repository) DeleteAuthFilter(ctx context.Context, id uint64) error {
tx := r.withContext(ctx).Delete(&AuthFilter{ID: id})
return tx.Error
}
func (r *repository) DeleteAuthFiltersByTailnet(ctx context.Context, tailnetID uint64) error {
tx := r.withContext(ctx).Where("tailnet_id = ?", tailnetID).Delete(&AuthFilter{})
return tx.Error
}
+21
View File
@@ -371,3 +371,24 @@ func (r *repository) SetMachineLastSeen(ctx context.Context, machineID uint64) e
return nil
}
func (r *repository) ExpireMachineByAuthMethod(ctx context.Context, authMethodID uint64) (int64, error) {
now := time.Now().UTC()
subQuery := r.withContext(ctx).
Select("machines.id").
Table("machines").
Joins("JOIN users u on u.id = machines.user_id JOIN accounts a on a.id = u.account_id").
Where("a.auth_method_id = ?", authMethodID)
tx := r.withContext(ctx).
Table("machines").
Where("tags = '' AND (expires_at is null or expires_at > ?) AND id in (?)", &now, subQuery).
Updates(map[string]interface{}{"expires_at": &now})
if tx.Error != nil {
return 0, tx.Error
}
return tx.RowsAffected, nil
}
+4
View File
@@ -18,9 +18,12 @@ type Repository interface {
ListAuthMethods(ctx context.Context) ([]AuthMethod, error)
GetAuthMethod(ctx context.Context, id uint64) (*AuthMethod, error)
GetAuthFilter(ctx context.Context, id uint64) (*AuthFilter, error)
SaveAuthFilter(ctx context.Context, m *AuthFilter) error
ListAuthFilters(ctx context.Context) (AuthFilters, error)
ListAuthFiltersByAuthMethod(ctx context.Context, authMethodID uint64) (AuthFilters, error)
DeleteAuthFilter(ctx context.Context, id uint64) error
DeleteAuthFiltersByTailnet(ctx context.Context, tailnetID uint64) error
GetAccount(ctx context.Context, accountID uint64) (*Account, error)
GetOrCreateAccount(ctx context.Context, authMethodID uint64, externalID, loginName string) (*Account, bool, error)
@@ -62,6 +65,7 @@ type Repository interface {
ListMachinePeers(ctx context.Context, tailnetID uint64, key string) (Machines, error)
ListInactiveEphemeralMachines(ctx context.Context, checkpoint time.Time) (Machines, error)
SetMachineLastSeen(ctx context.Context, machineID uint64) error
ExpireMachineByAuthMethod(ctx context.Context, authMethodID uint64) (int64, error)
Transaction(func(rp Repository) error) error
}
+3
View File
@@ -24,6 +24,9 @@ func (i *Tags) Scan(destination interface{}) error {
}
func (i Tags) Value() (driver.Value, error) {
if len(i) == 0 {
return "", nil
}
v := "|" + strings.Join(i, "|") + "|"
return v, nil
}
+38
View File
@@ -83,6 +83,44 @@ func (s *Service) CreateAuthFilter(ctx context.Context, req *api.CreateAuthFilte
return &response, nil
}
func (s *Service) DeleteAuthFilter(ctx context.Context, req *api.DeleteAuthFilterRequest) (*api.DeleteAuthFilterResponse, error) {
err := s.repository.Transaction(func(rp domain.Repository) error {
filter, err := rp.GetAuthFilter(ctx, req.AuthFilterId)
if err != nil {
return err
}
if filter == nil {
return status.Error(codes.NotFound, "auth filter not found")
}
c, err := rp.ExpireMachineByAuthMethod(ctx, filter.AuthMethodID)
if err != nil {
return err
}
if err := rp.DeleteAuthFilter(ctx, filter.ID); err != nil {
return err
}
if c != 0 {
s.brokers(*filter.TailnetID).SignalUpdate()
}
return nil
})
if err != nil {
return nil, err
}
response := api.DeleteAuthFilterResponse{}
return &response, nil
}
func (s *Service) mapToApi(authMethod *domain.AuthMethod, filter domain.AuthFilter) *api.AuthFilter {
result := api.AuthFilter{
Id: filter.ID,
+5 -1
View File
@@ -81,6 +81,10 @@ func (s *Service) DeleteTailnet(ctx context.Context, req *api.DeleteTailnetReque
return err
}
if err := tx.DeleteAuthFiltersByTailnet(ctx, req.TailnetId); err != nil {
return err
}
if err := tx.DeleteACLPolicy(ctx, req.TailnetId); err != nil {
return err
}
@@ -100,7 +104,7 @@ func (s *Service) DeleteTailnet(ctx context.Context, req *api.DeleteTailnetReque
return nil, err
}
s.brokers(req.TailnetId).SignalTailnedDeleted()
s.brokers(req.TailnetId).SignalUpdate()
return &api.DeleteTailnetResponse{}, nil
}