fix: log in with different use should create new machine entry

This commit is contained in:
Johan Siebens
2024-02-10 10:04:44 +01:00
parent 46cce89e0e
commit b098562988
7 changed files with 56 additions and 6 deletions
+2 -2
View File
@@ -389,9 +389,9 @@ func (r *repository) GetNextMachineNameIndex(ctx context.Context, tailnetID uint
return m.NameIdx + 1, nil
}
func (r *repository) GetMachineByKey(ctx context.Context, tailnetID uint64, machineKey string) (*Machine, error) {
func (r *repository) GetMachineByKeyAndUser(ctx context.Context, machineKey string, userID uint64) (*Machine, error) {
var m Machine
tx := r.withContext(ctx).Preload("Tailnet").Preload("User").Take(&m, "tailnet_id = ? AND machine_key = ?", tailnetID, machineKey)
tx := r.withContext(ctx).Preload("Tailnet").Preload("User").Take(&m, "machine_key = ? AND user_id = ?", machineKey, userID)
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return nil, nil
+1 -1
View File
@@ -58,7 +58,7 @@ type Repository interface {
SaveMachine(ctx context.Context, m *Machine) error
DeleteMachine(ctx context.Context, id uint64) (bool, error)
GetMachine(ctx context.Context, id uint64) (*Machine, error)
GetMachineByKey(ctx context.Context, tailnetID uint64, key string) (*Machine, error)
GetMachineByKeyAndUser(ctx context.Context, key string, userID uint64) (*Machine, error)
GetMachineByKeys(ctx context.Context, machineKey string, nodeKey string) (*Machine, error)
CountMachinesWithIPv4(ctx context.Context, ip string) (int64, error)
GetNextMachineNameIndex(ctx context.Context, tailnetID uint64, name string) (uint64, error)
+1 -1
View File
@@ -446,7 +446,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, form
var m *domain.Machine
m, err := h.repository.GetMachineByKey(ctx, tailnet.ID, machineKey)
m, err := h.repository.GetMachineByKeyAndUser(ctx, machineKey, user.ID)
if err != nil {
return logError(err)
}
+1 -1
View File
@@ -173,7 +173,7 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, ma
var m *domain.Machine
m, err = h.repository.GetMachineByKey(ctx, tailnet.ID, machineKey)
m, err = h.repository.GetMachineByKeyAndUser(ctx, machineKey, user.ID)
if err != nil {
return logError(err)
}
+43
View File
@@ -0,0 +1,43 @@
package tests
import (
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"github.com/jsiebens/ionscale/tests/sc"
"github.com/jsiebens/ionscale/tests/tsn"
"github.com/stretchr/testify/require"
"net/http"
"testing"
)
func TestSwitchAccounts(t *testing.T) {
sc.Run(t, func(s *sc.Scenario) {
s.PushOIDCUser("123", "john@localtest.me", "john")
s.PushOIDCUser("124", "jane@localtest.me", "jane")
tailnet := s.CreateTailnet()
s.SetIAMPolicy(tailnet.Id, &api.IAMPolicy{Filters: []string{"domain == localtest.me"}})
node := s.NewTailscaleNode(sc.WithName("switch"))
code, err := node.LoginWithOidc()
require.NoError(t, err)
require.Equal(t, http.StatusOK, code)
require.NoError(t, node.WaitFor(tsn.Connected()))
require.NoError(t, node.Check(tsn.HasUser("john@localtest.me")))
require.NoError(t, node.Check(tsn.HasName("switch")))
code, err = node.LoginWithOidc()
require.NoError(t, err)
require.Equal(t, http.StatusOK, code)
require.NoError(t, node.WaitFor(tsn.Connected()))
require.NoError(t, node.Check(tsn.HasUser("jane@localtest.me")))
require.NoError(t, node.Check(tsn.HasName("switch-1")))
machines := s.ListMachines(tailnet.Id)
require.Equal(t, 2, len(machines))
require.Equal(t, "switch", machines[0].Name)
require.Equal(t, "switch-1", machines[1].Name)
})
}
+7
View File
@@ -2,6 +2,7 @@ package tsn
import (
"slices"
"strings"
"tailscale.com/ipn/ipnstate"
"tailscale.com/tailcfg"
"tailscale.com/types/views"
@@ -27,6 +28,12 @@ func HasTag(tag string) Condition {
}
}
func HasName(name string) Condition {
return func(status *ipnstate.Status) bool {
return status.Self != nil && strings.HasPrefix(status.Self.DNSName, name)
}
}
func NeedsMachineAuth() Condition {
return func(status *ipnstate.Status) bool {
return status.BackendState == "NeedsMachineAuth"
+1 -1
View File
@@ -47,7 +47,7 @@ func (t *TailscaleNode) LoginWithOidc(flags ...UpFlag) (int, error) {
return strings.Contains(stderr, "To authenticate, visit:")
}
cmd := []string{"up", "--login-server", t.loginServer}
cmd := []string{"login", "--login-server", t.loginServer}
for _, f := range flags {
cmd = append(cmd, f...)
}