mirror of
https://github.com/jsiebens/ionscale.git
synced 2026-03-31 15:07:49 +01:00
feat: remove auth-filter in favor of a new IAM Policy setup
This commit is contained in:
+2
-2
@@ -12,7 +12,7 @@ import (
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
func getACLConfig() *coral.Command {
|
||||
func getACLConfigCommand() *coral.Command {
|
||||
command := &coral.Command{
|
||||
Use: "get-acl",
|
||||
Short: "Get the ACL policy",
|
||||
@@ -76,7 +76,7 @@ func getACLConfig() *coral.Command {
|
||||
return command
|
||||
}
|
||||
|
||||
func setACLConfig() *coral.Command {
|
||||
func setACLConfigCommand() *coral.Command {
|
||||
command := &coral.Command{
|
||||
Use: "set-acl",
|
||||
Short: "Set ACL policy",
|
||||
|
||||
@@ -1,186 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/bufbuild/connect-go"
|
||||
"github.com/hashicorp/go-bexpr"
|
||||
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
|
||||
"github.com/muesli/coral"
|
||||
"github.com/rodaine/table"
|
||||
)
|
||||
|
||||
func authFilterCommand() *coral.Command {
|
||||
command := &coral.Command{
|
||||
Use: "auth-filters",
|
||||
Short: "Manage ionscale auth filters",
|
||||
}
|
||||
|
||||
command.AddCommand(createAuthFilterCommand())
|
||||
command.AddCommand(listAuthFilterCommand())
|
||||
command.AddCommand(deleteAuthFilterCommand())
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func listAuthFilterCommand() *coral.Command {
|
||||
command := &coral.Command{
|
||||
Use: "list",
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var authMethodID uint64
|
||||
var authMethodName string
|
||||
|
||||
var target = Target{}
|
||||
target.prepareCommand(command)
|
||||
|
||||
command.Flags().StringVar(&authMethodName, "auth-method", "", "Auth Method name. Mutually exclusive with --auth-method-id")
|
||||
command.Flags().Uint64Var(&authMethodID, "auth-method-id", 0, "Auth Method ID. Mutually exclusive with --auth-method")
|
||||
|
||||
command.PreRunE = checkOptionalAuthMethodAndAuthMethodIdFlags
|
||||
command.RunE = func(cmd *coral.Command, args []string) error {
|
||||
client, err := target.createGRPCClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := &api.ListAuthFiltersRequest{}
|
||||
|
||||
if cmd.Flags().Changed("auth-method") || cmd.Flags().Changed("auth-method-id") {
|
||||
method, err := findAuthMethod(client, authMethodName, authMethodID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.AuthMethodId = &method.Id
|
||||
}
|
||||
|
||||
resp, err := client.ListAuthFilters(context.Background(), connect.NewRequest(req))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tbl := table.New("ID", "AUTH_METHOD", "TAILNET", "EXPR")
|
||||
for _, filter := range resp.Msg.AuthFilters {
|
||||
if filter.Tailnet != nil {
|
||||
tbl.AddRow(filter.Id, filter.AuthMethod.Name, filter.Tailnet.Name, filter.Expr)
|
||||
} else {
|
||||
tbl.AddRow(filter.Id, filter.AuthMethod.Name, "", filter.Expr)
|
||||
}
|
||||
}
|
||||
tbl.Print()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func createAuthFilterCommand() *coral.Command {
|
||||
command := &coral.Command{
|
||||
Use: "create",
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var expr string
|
||||
var tailnetID uint64
|
||||
var tailnetName string
|
||||
var authMethodID uint64
|
||||
var authMethodName string
|
||||
|
||||
var target = Target{}
|
||||
target.prepareCommand(command)
|
||||
|
||||
command.Flags().StringVar(&expr, "expr", "*", "")
|
||||
command.Flags().StringVar(&tailnetName, "tailnet", "", "Tailnet name. Mutually exclusive with --tailnet-id.")
|
||||
command.Flags().Uint64Var(&tailnetID, "tailnet-id", 0, "Tailnet ID. Mutually exclusive with --tailnet.")
|
||||
command.Flags().StringVar(&authMethodName, "auth-method", "", "Auth Method name. Mutually exclusive with --auth-method-id")
|
||||
command.Flags().Uint64Var(&authMethodID, "auth-method-id", 0, "Auth Method ID. Mutually exclusive with --auth-method")
|
||||
|
||||
command.PreRunE = checkAll(
|
||||
checkRequiredTailnetAndTailnetIdFlags,
|
||||
checkRequiredAuthMethodAndAuthMethodIdFlags,
|
||||
)
|
||||
command.RunE = func(command *coral.Command, args []string) error {
|
||||
if expr != "*" {
|
||||
if _, err := bexpr.CreateEvaluator(expr); err != nil {
|
||||
return fmt.Errorf("invalid expression: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
client, err := target.createGRPCClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tailnet, err := findTailnet(client, tailnetName, tailnetID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
authMethod, err := findAuthMethod(client, authMethodName, authMethodID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := &api.CreateAuthFilterRequest{
|
||||
AuthMethodId: authMethod.Id,
|
||||
TailnetId: tailnet.Id,
|
||||
Expr: expr,
|
||||
}
|
||||
|
||||
resp, err := client.CreateAuthFilter(context.Background(), connect.NewRequest(req))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tbl := table.New("ID", "AUTH_METHOD", "TAILNET", "EXPR")
|
||||
if resp.Msg.AuthFilter.Tailnet != nil {
|
||||
tbl.AddRow(resp.Msg.AuthFilter.Id, resp.Msg.AuthFilter.AuthMethod.Name, resp.Msg.AuthFilter.Tailnet.Name, resp.Msg.AuthFilter.Expr)
|
||||
} else {
|
||||
tbl.AddRow(resp.Msg.AuthFilter.Id, resp.Msg.AuthFilter.AuthMethod.Name, "", resp.Msg.AuthFilter.Expr)
|
||||
}
|
||||
tbl.Print()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func deleteAuthFilterCommand() *coral.Command {
|
||||
command := &coral.Command{
|
||||
Use: "delete",
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var authFilterID uint64
|
||||
|
||||
var target = Target{}
|
||||
target.prepareCommand(command)
|
||||
|
||||
command.Flags().Uint64Var(&authFilterID, "auth-filter-id", 0, "")
|
||||
|
||||
command.RunE = func(command *coral.Command, args []string) error {
|
||||
client, err := target.createGRPCClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := &api.DeleteAuthFilterRequest{
|
||||
AuthFilterId: authFilterID,
|
||||
}
|
||||
|
||||
_, err = client.DeleteAuthFilter(context.Background(), connect.NewRequest(req))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
+2
-2
@@ -9,7 +9,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getDNSConfig() *coral.Command {
|
||||
func getDNSConfigCommand() *coral.Command {
|
||||
command := &coral.Command{
|
||||
Use: "get-dns",
|
||||
Short: "Get DNS configuration",
|
||||
@@ -62,7 +62,7 @@ func getDNSConfig() *coral.Command {
|
||||
return command
|
||||
}
|
||||
|
||||
func setDNSConfig() *coral.Command {
|
||||
func setDNSConfigCommand() *coral.Command {
|
||||
command := &coral.Command{
|
||||
Use: "set-dns",
|
||||
Short: "Set DNS config",
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/bufbuild/connect-go"
|
||||
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
|
||||
"github.com/muesli/coral"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
func getIAMPolicyCommand() *coral.Command {
|
||||
command := &coral.Command{
|
||||
Use: "get-iam-policy",
|
||||
Short: "Get the IAM policy",
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var tailnetID uint64
|
||||
var tailnetName string
|
||||
var target = Target{}
|
||||
|
||||
target.prepareCommand(command)
|
||||
command.Flags().StringVar(&tailnetName, "tailnet", "", "Tailnet name. Mutually exclusive with --tailnet-id.")
|
||||
command.Flags().Uint64Var(&tailnetID, "tailnet-id", 0, "Tailnet ID. Mutually exclusive with --tailnet.")
|
||||
|
||||
command.PreRunE = checkRequiredTailnetAndTailnetIdFlags
|
||||
command.RunE = func(cmd *coral.Command, args []string) error {
|
||||
client, err := target.createGRPCClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tailnet, err := findTailnet(client, tailnetName, tailnetID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.GetIAMPolicy(context.Background(), connect.NewRequest(&api.GetIAMPolicyRequest{TailnetId: tailnet.Id}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
marshal, err := json.MarshalIndent(resp.Msg.Policy, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println(string(marshal))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func setIAMPolicyCommand() *coral.Command {
|
||||
command := &coral.Command{
|
||||
Use: "set-iam-policy",
|
||||
Short: "Set IAM policy",
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var tailnetID uint64
|
||||
var tailnetName string
|
||||
var file string
|
||||
var target = Target{}
|
||||
|
||||
target.prepareCommand(command)
|
||||
command.Flags().StringVar(&tailnetName, "tailnet", "", "Tailnet name. Mutually exclusive with --tailnet-id.")
|
||||
command.Flags().Uint64Var(&tailnetID, "tailnet-id", 0, "Tailnet ID. Mutually exclusive with --tailnet.")
|
||||
command.Flags().StringVar(&file, "file", "", "Path to json file with the acl configuration")
|
||||
|
||||
command.PreRunE = checkRequiredTailnetAndTailnetIdFlags
|
||||
command.RunE = func(cmd *coral.Command, args []string) error {
|
||||
rawJson, err := ioutil.ReadFile(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var policy = &api.IAMPolicy{}
|
||||
if err := json.Unmarshal(rawJson, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := target.createGRPCClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tailnet, err := findTailnet(client, tailnetName, tailnetID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = client.SetIAMPolicy(context.Background(), connect.NewRequest(&api.SetIAMPolicyRequest{TailnetId: tailnet.Id, Policy: policy}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("IAM policy updated successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
@@ -12,7 +12,6 @@ func Command() *coral.Command {
|
||||
rootCmd.AddCommand(serverCommand())
|
||||
rootCmd.AddCommand(versionCommand())
|
||||
rootCmd.AddCommand(authMethodsCommand())
|
||||
rootCmd.AddCommand(authFilterCommand())
|
||||
rootCmd.AddCommand(tailnetCommand())
|
||||
rootCmd.AddCommand(authkeysCommand())
|
||||
rootCmd.AddCommand(machineCommands())
|
||||
|
||||
@@ -18,10 +18,12 @@ func tailnetCommand() *coral.Command {
|
||||
command.AddCommand(listTailnetsCommand())
|
||||
command.AddCommand(createTailnetsCommand())
|
||||
command.AddCommand(deleteTailnetCommand())
|
||||
command.AddCommand(getDNSConfig())
|
||||
command.AddCommand(setDNSConfig())
|
||||
command.AddCommand(getACLConfig())
|
||||
command.AddCommand(setACLConfig())
|
||||
command.AddCommand(getDNSConfigCommand())
|
||||
command.AddCommand(setDNSConfigCommand())
|
||||
command.AddCommand(getACLConfigCommand())
|
||||
command.AddCommand(setACLConfigCommand())
|
||||
command.AddCommand(getIAMPolicyCommand())
|
||||
command.AddCommand(setIAMPolicyCommand())
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
@@ -46,7 +46,6 @@ func migrate(db *gorm.DB, repository domain.Repository) error {
|
||||
&domain.Tailnet{},
|
||||
&domain.TailnetConfig{},
|
||||
&domain.AuthMethod{},
|
||||
&domain.AuthFilter{},
|
||||
&domain.Account{},
|
||||
&domain.User{},
|
||||
&domain.ApiKey{},
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/hashicorp/go-bexpr"
|
||||
"github.com/mitchellh/pointerstructure"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type AuthFilter struct {
|
||||
ID uint64 `gorm:"primary_key;autoIncrement:false"`
|
||||
Expr string
|
||||
AuthMethodID uint64
|
||||
AuthMethod AuthMethod
|
||||
TailnetID *uint64
|
||||
Tailnet *Tailnet
|
||||
}
|
||||
|
||||
type AuthFilters []AuthFilter
|
||||
|
||||
func (f *AuthFilter) Evaluate(v interface{}) (bool, error) {
|
||||
if f.Expr == "*" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
eval, err := bexpr.CreateEvaluator(f.Expr)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
result, err := eval.Evaluate(v)
|
||||
if err != nil && !errors.Is(err, pointerstructure.ErrNotFound) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (fs AuthFilters) Evaluate(v interface{}) []Tailnet {
|
||||
var tailnetIDMap = make(map[uint64]bool)
|
||||
var tailnets []Tailnet
|
||||
|
||||
for _, f := range fs {
|
||||
approved, err := f.Evaluate(v)
|
||||
if err == nil && approved {
|
||||
if f.TailnetID != nil {
|
||||
_, alreadyApproved := tailnetIDMap[*f.TailnetID]
|
||||
if !alreadyApproved {
|
||||
tailnetIDMap[*f.TailnetID] = true
|
||||
tailnets = append(tailnets, *f.Tailnet)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tailnets
|
||||
}
|
||||
|
||||
func (r *repository) GetAuthFilter(ctx context.Context, id uint64) (*AuthFilter, error) {
|
||||
var t AuthFilter
|
||||
tx := r.withContext(ctx).Take(&t, "id = ?", id)
|
||||
|
||||
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if tx.Error != nil {
|
||||
return nil, tx.Error
|
||||
}
|
||||
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
func (r *repository) SaveAuthFilter(ctx context.Context, m *AuthFilter) error {
|
||||
tx := r.withContext(ctx).Save(m)
|
||||
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) ListAuthFilters(ctx context.Context) (AuthFilters, error) {
|
||||
var filters = []AuthFilter{}
|
||||
|
||||
tx := r.withContext(ctx).
|
||||
Preload("Tailnet").
|
||||
Preload("AuthMethod").
|
||||
Find(&filters)
|
||||
|
||||
if tx.Error != nil {
|
||||
return nil, tx.Error
|
||||
}
|
||||
|
||||
return filters, nil
|
||||
}
|
||||
|
||||
func (r *repository) ListAuthFiltersByAuthMethod(ctx context.Context, authMethodID uint64) (AuthFilters, error) {
|
||||
var filters = []AuthFilter{}
|
||||
|
||||
tx := r.withContext(ctx).
|
||||
Preload("Tailnet").
|
||||
Preload("AuthMethod").
|
||||
Where("auth_method_id = ?", authMethodID).Find(&filters)
|
||||
|
||||
if tx.Error != nil {
|
||||
return nil, tx.Error
|
||||
}
|
||||
|
||||
return filters, nil
|
||||
}
|
||||
|
||||
func (r *repository) DeleteAuthFilter(ctx context.Context, id uint64) error {
|
||||
tx := r.withContext(ctx).Delete(&AuthFilter{ID: id})
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
func (r *repository) DeleteAuthFiltersByTailnet(ctx context.Context, tailnetID uint64) error {
|
||||
tx := r.withContext(ctx).Where("tailnet_id = ?", tailnetID).Delete(&AuthFilter{})
|
||||
return tx.Error
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/hashicorp/go-bexpr"
|
||||
"github.com/mitchellh/pointerstructure"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
type Identity struct {
|
||||
UserID string
|
||||
Username string
|
||||
Email string
|
||||
Attr map[string]interface{}
|
||||
}
|
||||
|
||||
type IAMPolicy struct {
|
||||
Subs []string `json:"groups,omitempty"`
|
||||
Emails []string `json:"emails,omitempty"`
|
||||
Filters []string `json:"filters,omitempty"`
|
||||
}
|
||||
|
||||
func (i *IAMPolicy) EvaluatePolicy(identity *Identity) (bool, error) {
|
||||
for _, sub := range i.Subs {
|
||||
if identity.UserID == sub {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
for _, email := range i.Emails {
|
||||
if identity.Email == email {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
for _, f := range i.Filters {
|
||||
if f == "*" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
evaluator, err := bexpr.CreateEvaluator(f)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
result, err := evaluator.Evaluate(identity.Attr)
|
||||
if err != nil && !errors.Is(err, pointerstructure.ErrNotFound) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if result {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (i *IAMPolicy) Scan(destination interface{}) error {
|
||||
switch value := destination.(type) {
|
||||
case []byte:
|
||||
return json.Unmarshal(value, i)
|
||||
default:
|
||||
return fmt.Errorf("unexpected data type %T", destination)
|
||||
}
|
||||
}
|
||||
|
||||
func (i IAMPolicy) Value() (driver.Value, error) {
|
||||
bytes, err := json.Marshal(i)
|
||||
return bytes, err
|
||||
}
|
||||
|
||||
// GormDataType gorm common data type
|
||||
func (IAMPolicy) GormDataType() string {
|
||||
return "json"
|
||||
}
|
||||
|
||||
// GormDBDataType gorm db data type
|
||||
func (IAMPolicy) GormDBDataType(db *gorm.DB, field *schema.Field) string {
|
||||
switch db.Dialector.Name() {
|
||||
case "sqlite":
|
||||
return "JSON"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -18,16 +18,10 @@ type Repository interface {
|
||||
ListAuthMethods(ctx context.Context) ([]AuthMethod, error)
|
||||
GetAuthMethod(ctx context.Context, id uint64) (*AuthMethod, error)
|
||||
|
||||
GetAuthFilter(ctx context.Context, id uint64) (*AuthFilter, error)
|
||||
SaveAuthFilter(ctx context.Context, m *AuthFilter) error
|
||||
ListAuthFilters(ctx context.Context) (AuthFilters, error)
|
||||
ListAuthFiltersByAuthMethod(ctx context.Context, authMethodID uint64) (AuthFilters, error)
|
||||
DeleteAuthFilter(ctx context.Context, id uint64) error
|
||||
DeleteAuthFiltersByTailnet(ctx context.Context, tailnetID uint64) error
|
||||
|
||||
GetAccount(ctx context.Context, accountID uint64) (*Account, error)
|
||||
GetOrCreateAccount(ctx context.Context, authMethodID uint64, externalID, loginName string) (*Account, bool, error)
|
||||
|
||||
SaveTailnet(ctx context.Context, tailnet *Tailnet) error
|
||||
GetOrCreateTailnet(ctx context.Context, name string) (*Tailnet, bool, error)
|
||||
GetTailnet(ctx context.Context, id uint64) (*Tailnet, error)
|
||||
ListTailnets(ctx context.Context) ([]Tailnet, error)
|
||||
|
||||
@@ -8,8 +8,19 @@ import (
|
||||
)
|
||||
|
||||
type Tailnet struct {
|
||||
ID uint64 `gorm:"primary_key;autoIncrement:false"`
|
||||
Name string `gorm:"type:varchar(64);unique_index"`
|
||||
ID uint64 `gorm:"primary_key;autoIncrement:false"`
|
||||
Name string `gorm:"type:varchar(64);unique_index"`
|
||||
IAMPolicy IAMPolicy
|
||||
}
|
||||
|
||||
func (r *repository) SaveTailnet(ctx context.Context, tailnet *Tailnet) error {
|
||||
tx := r.withContext(ctx).Save(tailnet)
|
||||
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) GetOrCreateTailnet(ctx context.Context, name string) (*Tailnet, bool, error) {
|
||||
|
||||
+11
-13
@@ -16,21 +16,19 @@ func (s SystemRole) IsAdmin() bool {
|
||||
return s == SystemRoleAdmin
|
||||
}
|
||||
|
||||
type TailnetRole string
|
||||
type UserType string
|
||||
|
||||
const (
|
||||
TailnetRoleService TailnetRole = "service"
|
||||
TailnetRoleMember TailnetRole = "member"
|
||||
UserTypeService UserType = "service"
|
||||
UserTypePerson UserType = "person"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID uint64 `gorm:"primary_key;autoIncrement:false"`
|
||||
Name string
|
||||
|
||||
TailnetRole TailnetRole
|
||||
TailnetID uint64
|
||||
Tailnet Tailnet
|
||||
|
||||
ID uint64 `gorm:"primary_key;autoIncrement:false"`
|
||||
Name string
|
||||
UserType UserType
|
||||
TailnetID uint64
|
||||
Tailnet Tailnet
|
||||
AccountID *uint64
|
||||
Account *Account
|
||||
}
|
||||
@@ -41,8 +39,8 @@ func (r *repository) GetOrCreateServiceUser(ctx context.Context, tailnet *Tailne
|
||||
user := &User{}
|
||||
id := util.NextID()
|
||||
|
||||
query := User{Name: tailnet.Name, TailnetID: tailnet.ID, TailnetRole: TailnetRoleService}
|
||||
attrs := User{ID: id, Name: tailnet.Name, TailnetID: tailnet.ID, TailnetRole: TailnetRoleService}
|
||||
query := User{Name: tailnet.Name, TailnetID: tailnet.ID, UserType: UserTypeService}
|
||||
attrs := User{ID: id, Name: tailnet.Name, TailnetID: tailnet.ID, UserType: UserTypeService}
|
||||
|
||||
tx := r.withContext(ctx).Where(query).Attrs(attrs).FirstOrCreate(user)
|
||||
|
||||
@@ -75,7 +73,7 @@ func (r *repository) GetOrCreateUserWithAccount(ctx context.Context, tailnet *Ta
|
||||
id := util.NextID()
|
||||
|
||||
query := User{AccountID: &account.ID, TailnetID: tailnet.ID}
|
||||
attrs := User{ID: id, Name: account.LoginName, TailnetID: tailnet.ID, AccountID: &account.ID, TailnetRole: TailnetRoleMember}
|
||||
attrs := User{ID: id, Name: account.LoginName, TailnetID: tailnet.ID, AccountID: &account.ID, UserType: UserTypePerson}
|
||||
|
||||
tx := r.withContext(ctx).Where(query).Attrs(attrs).FirstOrCreate(user)
|
||||
|
||||
|
||||
@@ -71,8 +71,8 @@ func (h *AuthenticationHandlers) ProcessCliAuth(c echo.Context) error {
|
||||
key := c.Param("key")
|
||||
authMethodId := c.FormValue("s")
|
||||
|
||||
session, err := h.repository.GetAuthenticationRequest(ctx, key)
|
||||
if err != nil || session == nil {
|
||||
req, err := h.repository.GetAuthenticationRequest(ctx, key)
|
||||
if err != nil || req == nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
}
|
||||
|
||||
@@ -180,14 +180,25 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
filters, err := h.repository.ListAuthFiltersByAuthMethod(ctx, state.AuthMethod)
|
||||
tailnets, err := h.listAvailableTailnets(ctx, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tailnets := filters.Evaluate(user.Attr)
|
||||
|
||||
if len(tailnets) == 0 {
|
||||
if state.Flow == "r" {
|
||||
req, err := h.repository.GetRegistrationRequestByKey(ctx, state.Key)
|
||||
if err == nil && req != nil {
|
||||
req.Error = "unauthorized"
|
||||
_ = h.repository.SaveRegistrationRequest(ctx, req)
|
||||
}
|
||||
} else {
|
||||
req, err := h.repository.GetAuthenticationRequest(ctx, state.Key)
|
||||
if err == nil && req != nil {
|
||||
req.Error = "unauthorized"
|
||||
_ = h.repository.SaveAuthenticationRequest(ctx, req)
|
||||
}
|
||||
}
|
||||
return c.Redirect(http.StatusFound, "/a/error?e=ua")
|
||||
}
|
||||
|
||||
@@ -201,6 +212,24 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
|
||||
return c.Render(http.StatusOK, "tailnets.html", &TailnetSelectionData{Tailnets: tailnets})
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) listAvailableTailnets(ctx context.Context, u *provider.User) ([]domain.Tailnet, error) {
|
||||
var result = []domain.Tailnet{}
|
||||
tailnets, err := h.repository.ListTailnets(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, t := range tailnets {
|
||||
approved, err := t.IAMPolicy.EvaluatePolicy(&domain.Identity{UserID: u.ID, Email: u.Name, Attr: u.Attr})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if approved {
|
||||
result = append(result, t)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) EndOAuth(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
|
||||
@@ -218,12 +247,12 @@ func (h *AuthenticationHandlers) EndOAuth(c echo.Context) error {
|
||||
return h.endMachineRegistrationFlow(c, req, state)
|
||||
}
|
||||
|
||||
session, err := h.repository.GetAuthenticationRequest(ctx, state.Key)
|
||||
if err != nil || session == nil {
|
||||
req, err := h.repository.GetAuthenticationRequest(ctx, state.Key)
|
||||
if err != nil || req == nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
}
|
||||
|
||||
return h.endCliAuthenticationFlow(c, session, state)
|
||||
return h.endCliAuthenticationFlow(c, req, state)
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) Success(c echo.Context) error {
|
||||
@@ -241,7 +270,7 @@ func (h *AuthenticationHandlers) Error(c echo.Context) error {
|
||||
return c.Render(http.StatusOK, "error.html", nil)
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, session *domain.AuthenticationRequest, state *oauthState) error {
|
||||
func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, req *domain.AuthenticationRequest, state *oauthState) error {
|
||||
ctx := c.Request().Context()
|
||||
|
||||
tailnetIDParam := c.FormValue("s")
|
||||
@@ -269,13 +298,13 @@ func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, sessio
|
||||
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
token, apiKey := domain.CreateApiKey(tailnet, user, &expiresAt)
|
||||
session.Token = token
|
||||
req.Token = token
|
||||
|
||||
err = h.repository.Transaction(func(rp domain.Repository) error {
|
||||
if err := rp.SaveApiKey(ctx, apiKey); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := rp.SaveAuthenticationRequest(ctx, session); err != nil {
|
||||
if err := rp.SaveAuthenticationRequest(ctx, req); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -481,7 +510,7 @@ func (h *AuthenticationHandlers) exchangeUser(ctx context.Context, code string,
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) createState(flow string, key string, authMethodId uint64) (string, error) {
|
||||
stateMap := oauthState{Key: key, AuthMethod: authMethodId}
|
||||
stateMap := oauthState{Key: key, AuthMethod: authMethodId, Flow: flow}
|
||||
marshal, err := json.Marshal(&stateMap)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bufbuild/connect-go"
|
||||
"github.com/hashicorp/go-bexpr"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/util"
|
||||
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
|
||||
)
|
||||
|
||||
func (s *Service) ListAuthFilters(ctx context.Context, req *connect.Request[api.ListAuthFiltersRequest]) (*connect.Response[api.ListAuthFiltersResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
}
|
||||
|
||||
response := &api.ListAuthFiltersResponse{AuthFilters: []*api.AuthFilter{}}
|
||||
|
||||
if req.Msg.AuthMethodId == nil {
|
||||
filters, err := s.repository.ListAuthFilters(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, filter := range filters {
|
||||
response.AuthFilters = append(response.AuthFilters, s.mapToApi(&filter.AuthMethod, filter))
|
||||
}
|
||||
} else {
|
||||
authMethod, err := s.repository.GetAuthMethod(ctx, *req.Msg.AuthMethodId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if authMethod == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("invalid auth method id"))
|
||||
}
|
||||
|
||||
filters, err := s.repository.ListAuthFiltersByAuthMethod(ctx, authMethod.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, filter := range filters {
|
||||
response.AuthFilters = append(response.AuthFilters, s.mapToApi(&filter.AuthMethod, filter))
|
||||
}
|
||||
}
|
||||
|
||||
return connect.NewResponse[api.ListAuthFiltersResponse](response), nil
|
||||
}
|
||||
|
||||
func (s *Service) CreateAuthFilter(ctx context.Context, req *connect.Request[api.CreateAuthFilterRequest]) (*connect.Response[api.CreateAuthFilterResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
}
|
||||
|
||||
authMethod, err := s.repository.GetAuthMethod(ctx, req.Msg.AuthMethodId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if authMethod == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("invalid auth method id"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("invalid tailnet id"))
|
||||
}
|
||||
|
||||
if req.Msg.Expr != "*" {
|
||||
if _, err := bexpr.CreateEvaluator(req.Msg.Expr); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("invalid expression: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
authFilter := &domain.AuthFilter{
|
||||
ID: util.NextID(),
|
||||
Expr: req.Msg.Expr,
|
||||
AuthMethod: *authMethod,
|
||||
Tailnet: tailnet,
|
||||
}
|
||||
|
||||
if err := s.repository.SaveAuthFilter(ctx, authFilter); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := api.CreateAuthFilterResponse{AuthFilter: s.mapToApi(authMethod, *authFilter)}
|
||||
|
||||
return connect.NewResponse[api.CreateAuthFilterResponse](&response), nil
|
||||
}
|
||||
|
||||
func (s *Service) DeleteAuthFilter(ctx context.Context, req *connect.Request[api.DeleteAuthFilterRequest]) (*connect.Response[api.DeleteAuthFilterResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
}
|
||||
|
||||
err := s.repository.Transaction(func(rp domain.Repository) error {
|
||||
|
||||
filter, err := rp.GetAuthFilter(ctx, req.Msg.AuthFilterId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if filter == nil {
|
||||
return connect.NewError(connect.CodeNotFound, fmt.Errorf("auth filter not found"))
|
||||
}
|
||||
|
||||
c, err := rp.ExpireMachineByAuthMethod(ctx, *filter.TailnetID, filter.AuthMethodID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := rp.DeleteAuthFilter(ctx, filter.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c != 0 {
|
||||
s.brokers(*filter.TailnetID).SignalUpdate()
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := api.DeleteAuthFilterResponse{}
|
||||
|
||||
return connect.NewResponse[api.DeleteAuthFilterResponse](&response), nil
|
||||
}
|
||||
|
||||
func (s *Service) mapToApi(authMethod *domain.AuthMethod, filter domain.AuthFilter) *api.AuthFilter {
|
||||
result := api.AuthFilter{
|
||||
Id: filter.ID,
|
||||
Expr: filter.Expr,
|
||||
AuthMethod: &api.Ref{
|
||||
Id: authMethod.ID,
|
||||
Name: authMethod.Name,
|
||||
},
|
||||
}
|
||||
|
||||
if filter.Tailnet != nil {
|
||||
id := filter.Tailnet.ID
|
||||
name := filter.Tailnet.Name
|
||||
|
||||
result.Tailnet = &api.Ref{
|
||||
Id: id,
|
||||
Name: name,
|
||||
}
|
||||
}
|
||||
|
||||
return &result
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bufbuild/connect-go"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
|
||||
)
|
||||
|
||||
func (s *Service) GetIAMPolicy(ctx context.Context, req *connect.Request[api.GetIAMPolicyRequest]) (*connect.Response[api.GetIAMPolicyResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.TailnetMatches(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet does not exist"))
|
||||
}
|
||||
|
||||
policy := &api.IAMPolicy{
|
||||
Subs: tailnet.IAMPolicy.Subs,
|
||||
Emails: tailnet.IAMPolicy.Emails,
|
||||
Filters: tailnet.IAMPolicy.Filters,
|
||||
}
|
||||
|
||||
return connect.NewResponse(&api.GetIAMPolicyResponse{Policy: policy}), nil
|
||||
}
|
||||
|
||||
func (s *Service) SetIAMPolicy(ctx context.Context, req *connect.Request[api.SetIAMPolicyRequest]) (*connect.Response[api.SetIAMPolicyResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet does not exist"))
|
||||
}
|
||||
|
||||
tailnet.IAMPolicy = domain.IAMPolicy{
|
||||
Subs: req.Msg.Policy.Subs,
|
||||
Emails: req.Msg.Policy.Emails,
|
||||
Filters: req.Msg.Policy.Filters,
|
||||
}
|
||||
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return connect.NewResponse(&api.SetIAMPolicyResponse{}), nil
|
||||
}
|
||||
@@ -109,10 +109,6 @@ func (s *Service) DeleteTailnet(ctx context.Context, req *connect.Request[api.De
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.DeleteAuthFiltersByTailnet(ctx, req.Msg.TailnetId); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.DeleteACLPolicy(ctx, req.Msg.TailnetId); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user