feat: remove auth-filter in favor of a new IAM Policy setup

This commit is contained in:
Johan Siebens
2022-06-10 15:32:36 +02:00
parent eefa150738
commit a94e0ce9b8
22 changed files with 1005 additions and 812 deletions
+2 -2
View File
@@ -12,7 +12,7 @@ import (
"io/ioutil"
)
func getACLConfig() *coral.Command {
func getACLConfigCommand() *coral.Command {
command := &coral.Command{
Use: "get-acl",
Short: "Get the ACL policy",
@@ -76,7 +76,7 @@ func getACLConfig() *coral.Command {
return command
}
func setACLConfig() *coral.Command {
func setACLConfigCommand() *coral.Command {
command := &coral.Command{
Use: "set-acl",
Short: "Set ACL policy",
-186
View File
@@ -1,186 +0,0 @@
package cmd
import (
"context"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/hashicorp/go-bexpr"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"github.com/muesli/coral"
"github.com/rodaine/table"
)
func authFilterCommand() *coral.Command {
command := &coral.Command{
Use: "auth-filters",
Short: "Manage ionscale auth filters",
}
command.AddCommand(createAuthFilterCommand())
command.AddCommand(listAuthFilterCommand())
command.AddCommand(deleteAuthFilterCommand())
return command
}
func listAuthFilterCommand() *coral.Command {
command := &coral.Command{
Use: "list",
SilenceUsage: true,
}
var authMethodID uint64
var authMethodName string
var target = Target{}
target.prepareCommand(command)
command.Flags().StringVar(&authMethodName, "auth-method", "", "Auth Method name. Mutually exclusive with --auth-method-id")
command.Flags().Uint64Var(&authMethodID, "auth-method-id", 0, "Auth Method ID. Mutually exclusive with --auth-method")
command.PreRunE = checkOptionalAuthMethodAndAuthMethodIdFlags
command.RunE = func(cmd *coral.Command, args []string) error {
client, err := target.createGRPCClient()
if err != nil {
return err
}
req := &api.ListAuthFiltersRequest{}
if cmd.Flags().Changed("auth-method") || cmd.Flags().Changed("auth-method-id") {
method, err := findAuthMethod(client, authMethodName, authMethodID)
if err != nil {
return err
}
req.AuthMethodId = &method.Id
}
resp, err := client.ListAuthFilters(context.Background(), connect.NewRequest(req))
if err != nil {
return err
}
tbl := table.New("ID", "AUTH_METHOD", "TAILNET", "EXPR")
for _, filter := range resp.Msg.AuthFilters {
if filter.Tailnet != nil {
tbl.AddRow(filter.Id, filter.AuthMethod.Name, filter.Tailnet.Name, filter.Expr)
} else {
tbl.AddRow(filter.Id, filter.AuthMethod.Name, "", filter.Expr)
}
}
tbl.Print()
return nil
}
return command
}
func createAuthFilterCommand() *coral.Command {
command := &coral.Command{
Use: "create",
SilenceUsage: true,
}
var expr string
var tailnetID uint64
var tailnetName string
var authMethodID uint64
var authMethodName string
var target = Target{}
target.prepareCommand(command)
command.Flags().StringVar(&expr, "expr", "*", "")
command.Flags().StringVar(&tailnetName, "tailnet", "", "Tailnet name. Mutually exclusive with --tailnet-id.")
command.Flags().Uint64Var(&tailnetID, "tailnet-id", 0, "Tailnet ID. Mutually exclusive with --tailnet.")
command.Flags().StringVar(&authMethodName, "auth-method", "", "Auth Method name. Mutually exclusive with --auth-method-id")
command.Flags().Uint64Var(&authMethodID, "auth-method-id", 0, "Auth Method ID. Mutually exclusive with --auth-method")
command.PreRunE = checkAll(
checkRequiredTailnetAndTailnetIdFlags,
checkRequiredAuthMethodAndAuthMethodIdFlags,
)
command.RunE = func(command *coral.Command, args []string) error {
if expr != "*" {
if _, err := bexpr.CreateEvaluator(expr); err != nil {
return fmt.Errorf("invalid expression: %v", err)
}
}
client, err := target.createGRPCClient()
if err != nil {
return err
}
tailnet, err := findTailnet(client, tailnetName, tailnetID)
if err != nil {
return err
}
authMethod, err := findAuthMethod(client, authMethodName, authMethodID)
if err != nil {
return err
}
req := &api.CreateAuthFilterRequest{
AuthMethodId: authMethod.Id,
TailnetId: tailnet.Id,
Expr: expr,
}
resp, err := client.CreateAuthFilter(context.Background(), connect.NewRequest(req))
if err != nil {
return err
}
tbl := table.New("ID", "AUTH_METHOD", "TAILNET", "EXPR")
if resp.Msg.AuthFilter.Tailnet != nil {
tbl.AddRow(resp.Msg.AuthFilter.Id, resp.Msg.AuthFilter.AuthMethod.Name, resp.Msg.AuthFilter.Tailnet.Name, resp.Msg.AuthFilter.Expr)
} else {
tbl.AddRow(resp.Msg.AuthFilter.Id, resp.Msg.AuthFilter.AuthMethod.Name, "", resp.Msg.AuthFilter.Expr)
}
tbl.Print()
return nil
}
return command
}
func deleteAuthFilterCommand() *coral.Command {
command := &coral.Command{
Use: "delete",
SilenceUsage: true,
}
var authFilterID uint64
var target = Target{}
target.prepareCommand(command)
command.Flags().Uint64Var(&authFilterID, "auth-filter-id", 0, "")
command.RunE = func(command *coral.Command, args []string) error {
client, err := target.createGRPCClient()
if err != nil {
return err
}
req := &api.DeleteAuthFilterRequest{
AuthFilterId: authFilterID,
}
_, err = client.DeleteAuthFilter(context.Background(), connect.NewRequest(req))
if err != nil {
return err
}
return nil
}
return command
}
+2 -2
View File
@@ -9,7 +9,7 @@ import (
"strings"
)
func getDNSConfig() *coral.Command {
func getDNSConfigCommand() *coral.Command {
command := &coral.Command{
Use: "get-dns",
Short: "Get DNS configuration",
@@ -62,7 +62,7 @@ func getDNSConfig() *coral.Command {
return command
}
func setDNSConfig() *coral.Command {
func setDNSConfigCommand() *coral.Command {
command := &coral.Command{
Use: "set-dns",
Short: "Set DNS config",
+109
View File
@@ -0,0 +1,109 @@
package cmd
import (
"context"
"encoding/json"
"fmt"
"github.com/bufbuild/connect-go"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"github.com/muesli/coral"
"io/ioutil"
)
func getIAMPolicyCommand() *coral.Command {
command := &coral.Command{
Use: "get-iam-policy",
Short: "Get the IAM policy",
SilenceUsage: true,
}
var tailnetID uint64
var tailnetName string
var target = Target{}
target.prepareCommand(command)
command.Flags().StringVar(&tailnetName, "tailnet", "", "Tailnet name. Mutually exclusive with --tailnet-id.")
command.Flags().Uint64Var(&tailnetID, "tailnet-id", 0, "Tailnet ID. Mutually exclusive with --tailnet.")
command.PreRunE = checkRequiredTailnetAndTailnetIdFlags
command.RunE = func(cmd *coral.Command, args []string) error {
client, err := target.createGRPCClient()
if err != nil {
return err
}
tailnet, err := findTailnet(client, tailnetName, tailnetID)
if err != nil {
return err
}
resp, err := client.GetIAMPolicy(context.Background(), connect.NewRequest(&api.GetIAMPolicyRequest{TailnetId: tailnet.Id}))
if err != nil {
return err
}
marshal, err := json.MarshalIndent(resp.Msg.Policy, "", " ")
if err != nil {
return err
}
fmt.Println(string(marshal))
return nil
}
return command
}
func setIAMPolicyCommand() *coral.Command {
command := &coral.Command{
Use: "set-iam-policy",
Short: "Set IAM policy",
SilenceUsage: true,
}
var tailnetID uint64
var tailnetName string
var file string
var target = Target{}
target.prepareCommand(command)
command.Flags().StringVar(&tailnetName, "tailnet", "", "Tailnet name. Mutually exclusive with --tailnet-id.")
command.Flags().Uint64Var(&tailnetID, "tailnet-id", 0, "Tailnet ID. Mutually exclusive with --tailnet.")
command.Flags().StringVar(&file, "file", "", "Path to json file with the acl configuration")
command.PreRunE = checkRequiredTailnetAndTailnetIdFlags
command.RunE = func(cmd *coral.Command, args []string) error {
rawJson, err := ioutil.ReadFile(file)
if err != nil {
return err
}
var policy = &api.IAMPolicy{}
if err := json.Unmarshal(rawJson, policy); err != nil {
return err
}
client, err := target.createGRPCClient()
if err != nil {
return err
}
tailnet, err := findTailnet(client, tailnetName, tailnetID)
if err != nil {
return err
}
_, err = client.SetIAMPolicy(context.Background(), connect.NewRequest(&api.SetIAMPolicyRequest{TailnetId: tailnet.Id, Policy: policy}))
if err != nil {
return err
}
fmt.Println()
fmt.Println("IAM policy updated successfully")
return nil
}
return command
}
-1
View File
@@ -12,7 +12,6 @@ func Command() *coral.Command {
rootCmd.AddCommand(serverCommand())
rootCmd.AddCommand(versionCommand())
rootCmd.AddCommand(authMethodsCommand())
rootCmd.AddCommand(authFilterCommand())
rootCmd.AddCommand(tailnetCommand())
rootCmd.AddCommand(authkeysCommand())
rootCmd.AddCommand(machineCommands())
+6 -4
View File
@@ -18,10 +18,12 @@ func tailnetCommand() *coral.Command {
command.AddCommand(listTailnetsCommand())
command.AddCommand(createTailnetsCommand())
command.AddCommand(deleteTailnetCommand())
command.AddCommand(getDNSConfig())
command.AddCommand(setDNSConfig())
command.AddCommand(getACLConfig())
command.AddCommand(setACLConfig())
command.AddCommand(getDNSConfigCommand())
command.AddCommand(setDNSConfigCommand())
command.AddCommand(getACLConfigCommand())
command.AddCommand(setACLConfigCommand())
command.AddCommand(getIAMPolicyCommand())
command.AddCommand(setIAMPolicyCommand())
return command
}
-1
View File
@@ -46,7 +46,6 @@ func migrate(db *gorm.DB, repository domain.Repository) error {
&domain.Tailnet{},
&domain.TailnetConfig{},
&domain.AuthMethod{},
&domain.AuthFilter{},
&domain.Account{},
&domain.User{},
&domain.ApiKey{},
-123
View File
@@ -1,123 +0,0 @@
package domain
import (
"context"
"errors"
"github.com/hashicorp/go-bexpr"
"github.com/mitchellh/pointerstructure"
"gorm.io/gorm"
)
type AuthFilter struct {
ID uint64 `gorm:"primary_key;autoIncrement:false"`
Expr string
AuthMethodID uint64
AuthMethod AuthMethod
TailnetID *uint64
Tailnet *Tailnet
}
type AuthFilters []AuthFilter
func (f *AuthFilter) Evaluate(v interface{}) (bool, error) {
if f.Expr == "*" {
return true, nil
}
eval, err := bexpr.CreateEvaluator(f.Expr)
if err != nil {
return false, err
}
result, err := eval.Evaluate(v)
if err != nil && !errors.Is(err, pointerstructure.ErrNotFound) {
return false, err
}
return result, err
}
func (fs AuthFilters) Evaluate(v interface{}) []Tailnet {
var tailnetIDMap = make(map[uint64]bool)
var tailnets []Tailnet
for _, f := range fs {
approved, err := f.Evaluate(v)
if err == nil && approved {
if f.TailnetID != nil {
_, alreadyApproved := tailnetIDMap[*f.TailnetID]
if !alreadyApproved {
tailnetIDMap[*f.TailnetID] = true
tailnets = append(tailnets, *f.Tailnet)
}
}
}
}
return tailnets
}
func (r *repository) GetAuthFilter(ctx context.Context, id uint64) (*AuthFilter, error) {
var t AuthFilter
tx := r.withContext(ctx).Take(&t, "id = ?", id)
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
if tx.Error != nil {
return nil, tx.Error
}
return &t, nil
}
func (r *repository) SaveAuthFilter(ctx context.Context, m *AuthFilter) error {
tx := r.withContext(ctx).Save(m)
if tx.Error != nil {
return tx.Error
}
return nil
}
func (r *repository) ListAuthFilters(ctx context.Context) (AuthFilters, error) {
var filters = []AuthFilter{}
tx := r.withContext(ctx).
Preload("Tailnet").
Preload("AuthMethod").
Find(&filters)
if tx.Error != nil {
return nil, tx.Error
}
return filters, nil
}
func (r *repository) ListAuthFiltersByAuthMethod(ctx context.Context, authMethodID uint64) (AuthFilters, error) {
var filters = []AuthFilter{}
tx := r.withContext(ctx).
Preload("Tailnet").
Preload("AuthMethod").
Where("auth_method_id = ?", authMethodID).Find(&filters)
if tx.Error != nil {
return nil, tx.Error
}
return filters, nil
}
func (r *repository) DeleteAuthFilter(ctx context.Context, id uint64) error {
tx := r.withContext(ctx).Delete(&AuthFilter{ID: id})
return tx.Error
}
func (r *repository) DeleteAuthFiltersByTailnet(ctx context.Context, tailnetID uint64) error {
tx := r.withContext(ctx).Where("tailnet_id = ?", tailnetID).Delete(&AuthFilter{})
return tx.Error
}
+89
View File
@@ -0,0 +1,89 @@
package domain
import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"github.com/hashicorp/go-bexpr"
"github.com/mitchellh/pointerstructure"
"gorm.io/gorm"
"gorm.io/gorm/schema"
)
type Identity struct {
UserID string
Username string
Email string
Attr map[string]interface{}
}
type IAMPolicy struct {
Subs []string `json:"groups,omitempty"`
Emails []string `json:"emails,omitempty"`
Filters []string `json:"filters,omitempty"`
}
func (i *IAMPolicy) EvaluatePolicy(identity *Identity) (bool, error) {
for _, sub := range i.Subs {
if identity.UserID == sub {
return true, nil
}
}
for _, email := range i.Emails {
if identity.Email == email {
return true, nil
}
}
for _, f := range i.Filters {
if f == "*" {
return true, nil
}
evaluator, err := bexpr.CreateEvaluator(f)
if err != nil {
return false, err
}
result, err := evaluator.Evaluate(identity.Attr)
if err != nil && !errors.Is(err, pointerstructure.ErrNotFound) {
return false, err
}
if result {
return true, nil
}
}
return false, nil
}
func (i *IAMPolicy) Scan(destination interface{}) error {
switch value := destination.(type) {
case []byte:
return json.Unmarshal(value, i)
default:
return fmt.Errorf("unexpected data type %T", destination)
}
}
func (i IAMPolicy) Value() (driver.Value, error) {
bytes, err := json.Marshal(i)
return bytes, err
}
// GormDataType gorm common data type
func (IAMPolicy) GormDataType() string {
return "json"
}
// GormDBDataType gorm db data type
func (IAMPolicy) GormDBDataType(db *gorm.DB, field *schema.Field) string {
switch db.Dialector.Name() {
case "sqlite":
return "JSON"
}
return ""
}
+1 -7
View File
@@ -18,16 +18,10 @@ type Repository interface {
ListAuthMethods(ctx context.Context) ([]AuthMethod, error)
GetAuthMethod(ctx context.Context, id uint64) (*AuthMethod, error)
GetAuthFilter(ctx context.Context, id uint64) (*AuthFilter, error)
SaveAuthFilter(ctx context.Context, m *AuthFilter) error
ListAuthFilters(ctx context.Context) (AuthFilters, error)
ListAuthFiltersByAuthMethod(ctx context.Context, authMethodID uint64) (AuthFilters, error)
DeleteAuthFilter(ctx context.Context, id uint64) error
DeleteAuthFiltersByTailnet(ctx context.Context, tailnetID uint64) error
GetAccount(ctx context.Context, accountID uint64) (*Account, error)
GetOrCreateAccount(ctx context.Context, authMethodID uint64, externalID, loginName string) (*Account, bool, error)
SaveTailnet(ctx context.Context, tailnet *Tailnet) error
GetOrCreateTailnet(ctx context.Context, name string) (*Tailnet, bool, error)
GetTailnet(ctx context.Context, id uint64) (*Tailnet, error)
ListTailnets(ctx context.Context) ([]Tailnet, error)
+13 -2
View File
@@ -8,8 +8,19 @@ import (
)
type Tailnet struct {
ID uint64 `gorm:"primary_key;autoIncrement:false"`
Name string `gorm:"type:varchar(64);unique_index"`
ID uint64 `gorm:"primary_key;autoIncrement:false"`
Name string `gorm:"type:varchar(64);unique_index"`
IAMPolicy IAMPolicy
}
func (r *repository) SaveTailnet(ctx context.Context, tailnet *Tailnet) error {
tx := r.withContext(ctx).Save(tailnet)
if tx.Error != nil {
return tx.Error
}
return nil
}
func (r *repository) GetOrCreateTailnet(ctx context.Context, name string) (*Tailnet, bool, error) {
+11 -13
View File
@@ -16,21 +16,19 @@ func (s SystemRole) IsAdmin() bool {
return s == SystemRoleAdmin
}
type TailnetRole string
type UserType string
const (
TailnetRoleService TailnetRole = "service"
TailnetRoleMember TailnetRole = "member"
UserTypeService UserType = "service"
UserTypePerson UserType = "person"
)
type User struct {
ID uint64 `gorm:"primary_key;autoIncrement:false"`
Name string
TailnetRole TailnetRole
TailnetID uint64
Tailnet Tailnet
ID uint64 `gorm:"primary_key;autoIncrement:false"`
Name string
UserType UserType
TailnetID uint64
Tailnet Tailnet
AccountID *uint64
Account *Account
}
@@ -41,8 +39,8 @@ func (r *repository) GetOrCreateServiceUser(ctx context.Context, tailnet *Tailne
user := &User{}
id := util.NextID()
query := User{Name: tailnet.Name, TailnetID: tailnet.ID, TailnetRole: TailnetRoleService}
attrs := User{ID: id, Name: tailnet.Name, TailnetID: tailnet.ID, TailnetRole: TailnetRoleService}
query := User{Name: tailnet.Name, TailnetID: tailnet.ID, UserType: UserTypeService}
attrs := User{ID: id, Name: tailnet.Name, TailnetID: tailnet.ID, UserType: UserTypeService}
tx := r.withContext(ctx).Where(query).Attrs(attrs).FirstOrCreate(user)
@@ -75,7 +73,7 @@ func (r *repository) GetOrCreateUserWithAccount(ctx context.Context, tailnet *Ta
id := util.NextID()
query := User{AccountID: &account.ID, TailnetID: tailnet.ID}
attrs := User{ID: id, Name: account.LoginName, TailnetID: tailnet.ID, AccountID: &account.ID, TailnetRole: TailnetRoleMember}
attrs := User{ID: id, Name: account.LoginName, TailnetID: tailnet.ID, AccountID: &account.ID, UserType: UserTypePerson}
tx := r.withContext(ctx).Where(query).Attrs(attrs).FirstOrCreate(user)
+41 -12
View File
@@ -71,8 +71,8 @@ func (h *AuthenticationHandlers) ProcessCliAuth(c echo.Context) error {
key := c.Param("key")
authMethodId := c.FormValue("s")
session, err := h.repository.GetAuthenticationRequest(ctx, key)
if err != nil || session == nil {
req, err := h.repository.GetAuthenticationRequest(ctx, key)
if err != nil || req == nil {
return c.Redirect(http.StatusFound, "/a/error")
}
@@ -180,14 +180,25 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
return err
}
filters, err := h.repository.ListAuthFiltersByAuthMethod(ctx, state.AuthMethod)
tailnets, err := h.listAvailableTailnets(ctx, user)
if err != nil {
return err
}
tailnets := filters.Evaluate(user.Attr)
if len(tailnets) == 0 {
if state.Flow == "r" {
req, err := h.repository.GetRegistrationRequestByKey(ctx, state.Key)
if err == nil && req != nil {
req.Error = "unauthorized"
_ = h.repository.SaveRegistrationRequest(ctx, req)
}
} else {
req, err := h.repository.GetAuthenticationRequest(ctx, state.Key)
if err == nil && req != nil {
req.Error = "unauthorized"
_ = h.repository.SaveAuthenticationRequest(ctx, req)
}
}
return c.Redirect(http.StatusFound, "/a/error?e=ua")
}
@@ -201,6 +212,24 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
return c.Render(http.StatusOK, "tailnets.html", &TailnetSelectionData{Tailnets: tailnets})
}
func (h *AuthenticationHandlers) listAvailableTailnets(ctx context.Context, u *provider.User) ([]domain.Tailnet, error) {
var result = []domain.Tailnet{}
tailnets, err := h.repository.ListTailnets(ctx)
if err != nil {
return nil, err
}
for _, t := range tailnets {
approved, err := t.IAMPolicy.EvaluatePolicy(&domain.Identity{UserID: u.ID, Email: u.Name, Attr: u.Attr})
if err != nil {
return nil, err
}
if approved {
result = append(result, t)
}
}
return result, nil
}
func (h *AuthenticationHandlers) EndOAuth(c echo.Context) error {
ctx := c.Request().Context()
@@ -218,12 +247,12 @@ func (h *AuthenticationHandlers) EndOAuth(c echo.Context) error {
return h.endMachineRegistrationFlow(c, req, state)
}
session, err := h.repository.GetAuthenticationRequest(ctx, state.Key)
if err != nil || session == nil {
req, err := h.repository.GetAuthenticationRequest(ctx, state.Key)
if err != nil || req == nil {
return c.Redirect(http.StatusFound, "/a/error")
}
return h.endCliAuthenticationFlow(c, session, state)
return h.endCliAuthenticationFlow(c, req, state)
}
func (h *AuthenticationHandlers) Success(c echo.Context) error {
@@ -241,7 +270,7 @@ func (h *AuthenticationHandlers) Error(c echo.Context) error {
return c.Render(http.StatusOK, "error.html", nil)
}
func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, session *domain.AuthenticationRequest, state *oauthState) error {
func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, req *domain.AuthenticationRequest, state *oauthState) error {
ctx := c.Request().Context()
tailnetIDParam := c.FormValue("s")
@@ -269,13 +298,13 @@ func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, sessio
expiresAt := time.Now().Add(24 * time.Hour)
token, apiKey := domain.CreateApiKey(tailnet, user, &expiresAt)
session.Token = token
req.Token = token
err = h.repository.Transaction(func(rp domain.Repository) error {
if err := rp.SaveApiKey(ctx, apiKey); err != nil {
return err
}
if err := rp.SaveAuthenticationRequest(ctx, session); err != nil {
if err := rp.SaveAuthenticationRequest(ctx, req); err != nil {
return err
}
return nil
@@ -481,7 +510,7 @@ func (h *AuthenticationHandlers) exchangeUser(ctx context.Context, code string,
}
func (h *AuthenticationHandlers) createState(flow string, key string, authMethodId uint64) (string, error) {
stateMap := oauthState{Key: key, AuthMethod: authMethodId}
stateMap := oauthState{Key: key, AuthMethod: authMethodId, Flow: flow}
marshal, err := json.Marshal(&stateMap)
if err != nil {
return "", err
-159
View File
@@ -1,159 +0,0 @@
package service
import (
"context"
"errors"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/hashicorp/go-bexpr"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/util"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
)
func (s *Service) ListAuthFilters(ctx context.Context, req *connect.Request[api.ListAuthFiltersRequest]) (*connect.Response[api.ListAuthFiltersResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
}
response := &api.ListAuthFiltersResponse{AuthFilters: []*api.AuthFilter{}}
if req.Msg.AuthMethodId == nil {
filters, err := s.repository.ListAuthFilters(ctx)
if err != nil {
return nil, err
}
for _, filter := range filters {
response.AuthFilters = append(response.AuthFilters, s.mapToApi(&filter.AuthMethod, filter))
}
} else {
authMethod, err := s.repository.GetAuthMethod(ctx, *req.Msg.AuthMethodId)
if err != nil {
return nil, err
}
if authMethod == nil {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("invalid auth method id"))
}
filters, err := s.repository.ListAuthFiltersByAuthMethod(ctx, authMethod.ID)
if err != nil {
return nil, err
}
for _, filter := range filters {
response.AuthFilters = append(response.AuthFilters, s.mapToApi(&filter.AuthMethod, filter))
}
}
return connect.NewResponse[api.ListAuthFiltersResponse](response), nil
}
func (s *Service) CreateAuthFilter(ctx context.Context, req *connect.Request[api.CreateAuthFilterRequest]) (*connect.Response[api.CreateAuthFilterResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
}
authMethod, err := s.repository.GetAuthMethod(ctx, req.Msg.AuthMethodId)
if err != nil {
return nil, err
}
if authMethod == nil {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("invalid auth method id"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("invalid tailnet id"))
}
if req.Msg.Expr != "*" {
if _, err := bexpr.CreateEvaluator(req.Msg.Expr); err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("invalid expression: %v", err))
}
}
authFilter := &domain.AuthFilter{
ID: util.NextID(),
Expr: req.Msg.Expr,
AuthMethod: *authMethod,
Tailnet: tailnet,
}
if err := s.repository.SaveAuthFilter(ctx, authFilter); err != nil {
return nil, err
}
response := api.CreateAuthFilterResponse{AuthFilter: s.mapToApi(authMethod, *authFilter)}
return connect.NewResponse[api.CreateAuthFilterResponse](&response), nil
}
func (s *Service) DeleteAuthFilter(ctx context.Context, req *connect.Request[api.DeleteAuthFilterRequest]) (*connect.Response[api.DeleteAuthFilterResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
}
err := s.repository.Transaction(func(rp domain.Repository) error {
filter, err := rp.GetAuthFilter(ctx, req.Msg.AuthFilterId)
if err != nil {
return err
}
if filter == nil {
return connect.NewError(connect.CodeNotFound, fmt.Errorf("auth filter not found"))
}
c, err := rp.ExpireMachineByAuthMethod(ctx, *filter.TailnetID, filter.AuthMethodID)
if err != nil {
return err
}
if err := rp.DeleteAuthFilter(ctx, filter.ID); err != nil {
return err
}
if c != 0 {
s.brokers(*filter.TailnetID).SignalUpdate()
}
return nil
})
if err != nil {
return nil, err
}
response := api.DeleteAuthFilterResponse{}
return connect.NewResponse[api.DeleteAuthFilterResponse](&response), nil
}
func (s *Service) mapToApi(authMethod *domain.AuthMethod, filter domain.AuthFilter) *api.AuthFilter {
result := api.AuthFilter{
Id: filter.ID,
Expr: filter.Expr,
AuthMethod: &api.Ref{
Id: authMethod.ID,
Name: authMethod.Name,
},
}
if filter.Tailnet != nil {
id := filter.Tailnet.ID
name := filter.Tailnet.Name
result.Tailnet = &api.Ref{
Id: id,
Name: name,
}
}
return &result
}
+60
View File
@@ -0,0 +1,60 @@
package service
import (
"context"
"errors"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/domain"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
)
func (s *Service) GetIAMPolicy(ctx context.Context, req *connect.Request[api.GetIAMPolicyRequest]) (*connect.Response[api.GetIAMPolicyResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.TailnetMatches(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet does not exist"))
}
policy := &api.IAMPolicy{
Subs: tailnet.IAMPolicy.Subs,
Emails: tailnet.IAMPolicy.Emails,
Filters: tailnet.IAMPolicy.Filters,
}
return connect.NewResponse(&api.GetIAMPolicyResponse{Policy: policy}), nil
}
func (s *Service) SetIAMPolicy(ctx context.Context, req *connect.Request[api.SetIAMPolicyRequest]) (*connect.Response[api.SetIAMPolicyResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet does not exist"))
}
tailnet.IAMPolicy = domain.IAMPolicy{
Subs: req.Msg.Policy.Subs,
Emails: req.Msg.Policy.Emails,
Filters: req.Msg.Policy.Filters,
}
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
}
return connect.NewResponse(&api.SetIAMPolicyResponse{}), nil
}
-4
View File
@@ -109,10 +109,6 @@ func (s *Service) DeleteTailnet(ctx context.Context, req *connect.Request[api.De
return err
}
if err := tx.DeleteAuthFiltersByTailnet(ctx, req.Msg.TailnetId); err != nil {
return err
}
if err := tx.DeleteACLPolicy(ctx, req.Msg.TailnetId); err != nil {
return err
}