diff --git a/internal/handlers/poll_net_map.go b/internal/handlers/poll_net_map.go index 207d1ac..4d338e0 100644 --- a/internal/handlers/poll_net_map.go +++ b/internal/handlers/poll_net_map.go @@ -11,6 +11,7 @@ import ( "github.com/klauspost/compress/zstd" "github.com/labstack/echo/v4" "net/http" + "slices" "sync" "tailscale.com/smallzstd" "tailscale.com/tailcfg" @@ -82,6 +83,10 @@ func (h *PollNetMapHandler) handlePollNetMap(c echo.Context, m *domain.Machine, } if !mapRequest.Stream { + if !slices.Equal(m.HostInfo.RoutableIPs, mapRequest.Hostinfo.RoutableIPs) { + m.AutoAllowIPs = m.Tailnet.ACLPolicy.Get().FindAutoApprovedIPs(mapRequest.Hostinfo.RoutableIPs, m.Tags, &m.User) + } + m.HostInfo = domain.HostInfo(*mapRequest.Hostinfo) m.DiscoKey = mapRequest.DiscoKey.String() m.Endpoints = mapRequest.Endpoints diff --git a/tests/auto_approvers_test.go b/tests/auto_approvers_test.go new file mode 100644 index 0000000..9c3ba4e --- /dev/null +++ b/tests/auto_approvers_test.go @@ -0,0 +1,136 @@ +package tests + +import ( + "github.com/jsiebens/ionscale/pkg/client/ionscale" + "github.com/jsiebens/ionscale/pkg/defaults" + "github.com/jsiebens/ionscale/tests/sc" + "github.com/jsiebens/ionscale/tests/tsn" + "github.com/stretchr/testify/require" + "net/netip" + "testing" +) + +func TestAdvertiseRoutesAutoApprovedOnNewNode(t *testing.T) { + route1 := netip.MustParsePrefix("10.1.0.0/24") + route2 := netip.MustParsePrefix("10.2.0.0/24") + + sc.Run(t, func(s *sc.Scenario) { + aclPolicy := defaults.DefaultACLPolicy() + aclPolicy.AutoApprovers = &ionscale.ACLAutoApprovers{ + Routes: map[string][]string{ + route1.String(): {"tag:test-route"}, + }, + } + + tailnet := s.CreateTailnet() + s.SetACLPolicy(tailnet.Id, aclPolicy) + + testNode := s.NewTailscaleNode() + require.NoError(t, testNode.Up( + s.CreateAuthKey(tailnet.Id, true, "tag:test-route"), + tsn.WithAdvertiseTags("tag:test-route"), + tsn.WithAdvertiseRoutes([]string{ + route1.String(), + route2.String()}, + ), + )) + + require.NoError(t, testNode.WaitFor(tsn.HasTailnet(tailnet.Name))) + + mid, err := s.FindMachine(tailnet.Id, testNode.Hostname()) + require.NoError(t, err) + + machineRoutes := s.GetMachineRoutes(mid) + require.NoError(t, err) + + require.Equal(t, []string{route1.String(), route2.String()}, machineRoutes.AdvertisedRoutes) + require.Equal(t, []string{route1.String()}, machineRoutes.EnabledRoutes) + + require.NoError(t, testNode.Check(tsn.HasAllowedIP(route1))) + require.NoError(t, testNode.Check(tsn.IsMissingAllowedIP(route2))) + }) +} + +func TestAdvertiseRoutesAutoApprovedOnExistingNode(t *testing.T) { + route1 := netip.MustParsePrefix("10.1.0.0/24") + route2 := netip.MustParsePrefix("10.2.0.0/24") + route3 := netip.MustParsePrefix("10.3.0.0/24") + + sc.Run(t, func(s *sc.Scenario) { + aclPolicy := defaults.DefaultACLPolicy() + aclPolicy.AutoApprovers = &ionscale.ACLAutoApprovers{ + Routes: map[string][]string{ + route1.String(): {"tag:test-route"}, + route3.String(): {"tag:test-route"}, + }, + } + + tailnet := s.CreateTailnet() + s.SetACLPolicy(tailnet.Id, aclPolicy) + + testNode := s.NewTailscaleNode() + require.NoError(t, testNode.Up( + s.CreateAuthKey(tailnet.Id, true, "tag:test-route"), + tsn.WithAdvertiseTags("tag:test-route"), + )) + + require.NoError(t, testNode.Check(tsn.HasTailnet(tailnet.Name))) + + testNode.Set(tsn.WithAdvertiseRoutes([]string{ + route3.String(), + route1.String(), + route2.String()}, + )) + + require.NoError(t, testNode.WaitFor(tsn.HasAllowedIP(route1))) + require.NoError(t, testNode.WaitFor(tsn.HasAllowedIP(route3))) + require.NoError(t, testNode.WaitFor(tsn.IsMissingAllowedIP(route2))) + + mid, err := s.FindMachine(tailnet.Id, testNode.Hostname()) + require.NoError(t, err) + + machineRoutes := s.GetMachineRoutes(mid) + require.NoError(t, err) + + require.Equal(t, []string{route1.String(), route2.String(), route3.String()}, machineRoutes.AdvertisedRoutes) + require.Equal(t, []string{route1.String(), route3.String()}, machineRoutes.EnabledRoutes) + }) +} + +func TestAdvertiseRemoveRoutesAutoApprovedOnExistingNode(t *testing.T) { + route1 := netip.MustParsePrefix("10.1.0.0/24") + route2 := netip.MustParsePrefix("10.2.0.0/24") + + sc.Run(t, func(s *sc.Scenario) { + aclPolicy := defaults.DefaultACLPolicy() + aclPolicy.AutoApprovers = &ionscale.ACLAutoApprovers{ + Routes: map[string][]string{ + route1.String(): {"tag:test-route"}, + route2.String(): {"tag:test-route"}, + }, + } + + tailnet := s.CreateTailnet() + s.SetACLPolicy(tailnet.Id, aclPolicy) + + testNode := s.NewTailscaleNode() + require.NoError(t, testNode.Up( + s.CreateAuthKey(tailnet.Id, true, "tag:test-route"), + tsn.WithAdvertiseTags("tag:test-route"), + tsn.WithAdvertiseRoutes([]string{ + route1.String(), + route2.String()}, + ), + )) + + require.NoError(t, testNode.WaitFor(tsn.HasTailnet(tailnet.Name))) + require.NoError(t, testNode.Check(tsn.HasAllowedIP(route1))) + require.NoError(t, testNode.Check(tsn.HasAllowedIP(route2))) + + testNode.Set(tsn.WithAdvertiseRoutes([]string{ + route1.String(), + })) + + require.NoError(t, testNode.WaitFor(tsn.IsMissingAllowedIP(route2))) + }) +} diff --git a/tests/sc/scenario.go b/tests/sc/scenario.go index 5269408..7d54d56 100644 --- a/tests/sc/scenario.go +++ b/tests/sc/scenario.go @@ -117,6 +117,12 @@ func (s *Scenario) EnableMachineAutorization(tailnetID uint64) { require.NoError(s.t, err) } +func (s *Scenario) GetMachineRoutes(machineID uint64) *api.MachineRoutes { + routes, err := s.ionscaleClient.GetMachineRoutes(context.Background(), connect.NewRequest(&api.GetMachineRoutesRequest{MachineId: machineID})) + require.NoError(s.t, err) + return routes.Msg.Routes +} + func (s *Scenario) PushOIDCUser(sub, email, preferredUsername string) { _, err := s.mockoidcClient.PushUser(context.Background(), connect.NewRequest(&mockoidcv1.PushUserRequest{Subject: sub, Email: email, PreferredUsername: preferredUsername})) require.NoError(s.t, err) diff --git a/tests/tsn/conditions.go b/tests/tsn/conditions.go index 0c69a02..0ce5f17 100644 --- a/tests/tsn/conditions.go +++ b/tests/tsn/conditions.go @@ -1,6 +1,7 @@ package tsn import ( + "net/netip" "slices" "strings" "tailscale.com/ipn/ipnstate" @@ -60,6 +61,24 @@ func HasUser(email string) Condition { } } +func HasAllowedIP(route netip.Prefix) Condition { + return func(status *ipnstate.Status) bool { + if status.Self == nil || status.Self.AllowedIPs.Len() == 0 { + return false + } + return slices.Contains(status.Self.AllowedIPs.AsSlice(), route) + } +} + +func IsMissingAllowedIP(route netip.Prefix) Condition { + return func(status *ipnstate.Status) bool { + if status.Self == nil || status.Self.AllowedIPs.Len() == 0 { + return true + } + return !slices.Contains(status.Self.AllowedIPs.AsSlice(), route) + } +} + func PeerCount(expected int) Condition { return func(status *ipnstate.Status) bool { return len(status.Peers()) == expected diff --git a/tests/tsn/node.go b/tests/tsn/node.go index ff77539..ed3339a 100644 --- a/tests/tsn/node.go +++ b/tests/tsn/node.go @@ -38,11 +38,25 @@ func (t *TailscaleNode) Hostname() string { return t.hostname } -func (t *TailscaleNode) Up(authkey string) error { - t.mustExecTailscaleCmd("up", "--login-server", t.loginServer, "--authkey", authkey) +func (t *TailscaleNode) Up(authkey string, flags ...UpFlag) error { + cmd := []string{"up", "--login-server", t.loginServer, "--authkey", authkey} + for _, f := range flags { + cmd = append(cmd, f...) + } + + t.mustExecTailscaleCmd(cmd...) return t.WaitFor(Connected()) } +func (t *TailscaleNode) Set(flags ...UpFlag) string { + cmd := []string{"set"} + for _, f := range flags { + cmd = append(cmd, f...) + } + + return t.mustExecTailscaleCmd(cmd...) +} + func (t *TailscaleNode) LoginWithOidc(flags ...UpFlag) (int, error) { check := func(stdout, stderr string) bool { return strings.Contains(stderr, "To authenticate, visit:") diff --git a/tests/tsn/opts.go b/tests/tsn/opts.go index 29da94b..81f0dbb 100644 --- a/tests/tsn/opts.go +++ b/tests/tsn/opts.go @@ -1,7 +1,13 @@ package tsn +import "strings" + type UpFlag = []string func WithAdvertiseTags(tags string) UpFlag { return []string{"--advertise-tags", tags} } + +func WithAdvertiseRoutes(routes []string) UpFlag { + return []string{"--advertise-routes", strings.Join(routes, ",")} +}