fix: only expire machines from tailnet of the auth filter

This commit is contained in:
Johan Siebens
2022-06-01 16:16:09 +02:00
parent b9b42d8342
commit 687fcd16d1
3 changed files with 4 additions and 4 deletions
+2 -2
View File
@@ -372,14 +372,14 @@ func (r *repository) SetMachineLastSeen(ctx context.Context, machineID uint64) e
return nil return nil
} }
func (r *repository) ExpireMachineByAuthMethod(ctx context.Context, authMethodID uint64) (int64, error) { func (r *repository) ExpireMachineByAuthMethod(ctx context.Context, tailnetID, authMethodID uint64) (int64, error) {
now := time.Now().UTC() now := time.Now().UTC()
subQuery := r.withContext(ctx). subQuery := r.withContext(ctx).
Select("machines.id"). Select("machines.id").
Table("machines"). Table("machines").
Joins("JOIN users u on u.id = machines.user_id JOIN accounts a on a.id = u.account_id"). Joins("JOIN users u on u.id = machines.user_id JOIN accounts a on a.id = u.account_id").
Where("a.auth_method_id = ?", authMethodID) Where("machines.tailnet_id = ? AND a.auth_method_id = ?", tailnetID, authMethodID)
tx := r.withContext(ctx). tx := r.withContext(ctx).
Table("machines"). Table("machines").
+1 -1
View File
@@ -65,7 +65,7 @@ type Repository interface {
ListMachinePeers(ctx context.Context, tailnetID uint64, key string) (Machines, error) ListMachinePeers(ctx context.Context, tailnetID uint64, key string) (Machines, error)
ListInactiveEphemeralMachines(ctx context.Context, checkpoint time.Time) (Machines, error) ListInactiveEphemeralMachines(ctx context.Context, checkpoint time.Time) (Machines, error)
SetMachineLastSeen(ctx context.Context, machineID uint64) error SetMachineLastSeen(ctx context.Context, machineID uint64) error
ExpireMachineByAuthMethod(ctx context.Context, authMethodID uint64) (int64, error) ExpireMachineByAuthMethod(ctx context.Context, tailnetID, authMethodID uint64) (int64, error)
SaveRegistrationRequest(ctx context.Context, request *RegistrationRequest) error SaveRegistrationRequest(ctx context.Context, request *RegistrationRequest) error
GetRegistrationRequestByKey(ctx context.Context, key string) (*RegistrationRequest, error) GetRegistrationRequestByKey(ctx context.Context, key string) (*RegistrationRequest, error)
+1 -1
View File
@@ -96,7 +96,7 @@ func (s *Service) DeleteAuthFilter(ctx context.Context, req *api.DeleteAuthFilte
return status.Error(codes.NotFound, "auth filter not found") return status.Error(codes.NotFound, "auth filter not found")
} }
c, err := rp.ExpireMachineByAuthMethod(ctx, filter.AuthMethodID) c, err := rp.ExpireMachineByAuthMethod(ctx, *filter.TailnetID, filter.AuthMethodID)
if err != nil { if err != nil {
return err return err
} }