feat: add auth filters

This commit is contained in:
Johan Siebens
2022-05-26 09:07:54 +02:00
parent 84a57ea409
commit 198b6795b1
16 changed files with 1227 additions and 187 deletions
+143
View File
@@ -0,0 +1,143 @@
package cmd
import (
"context"
"fmt"
"github.com/hashicorp/go-bexpr"
"github.com/jsiebens/ionscale/pkg/gen/api"
"github.com/muesli/coral"
"github.com/rodaine/table"
)
func authFilterCommand() *coral.Command {
command := &coral.Command{
Use: "auth-filters",
Short: "Manage ionscale auth filters",
Long: `This command allows operations on ionscale auth filter resources. Example:
$ ionscale auth-filter create`,
}
command.AddCommand(createAuthFilterCommand())
command.AddCommand(listAuthFilterCommand())
return command
}
func listAuthFilterCommand() *coral.Command {
command := &coral.Command{
Use: "list",
SilenceUsage: true,
}
var authMethodID uint64
var target = Target{}
target.prepareCommand(command)
command.Flags().Uint64Var(&authMethodID, "auth-method-id", 0, "")
command.RunE = func(command *coral.Command, args []string) error {
client, c, err := target.createGRPCClient()
if err != nil {
return err
}
defer safeClose(c)
req := &api.ListAuthFiltersRequest{}
if authMethodID != 0 {
req.AuthMethodId = &authMethodID
}
resp, err := client.ListAuthFilters(context.Background(), req)
if err != nil {
return err
}
tbl := table.New("ID", "AUTH_METHOD", "TAILNET", "EXPR")
for _, filter := range resp.AuthFilters {
if filter.Tailnet != nil {
tbl.AddRow(filter.Id, filter.AuthMethod.Name, filter.Tailnet.Name, filter.Expr)
} else {
tbl.AddRow(filter.Id, filter.AuthMethod.Name, "", filter.Expr)
}
}
tbl.Print()
return nil
}
return command
}
func createAuthFilterCommand() *coral.Command {
command := &coral.Command{
Use: "create",
SilenceUsage: true,
}
var expr string
var tailnetID uint64
var tailnetName string
var authMethodID uint64
var authMethodName string
var target = Target{}
target.prepareCommand(command)
command.Flags().StringVar(&expr, "expr", "*", "")
command.Flags().StringVar(&tailnetName, "tailnet", "", "")
command.Flags().Uint64Var(&tailnetID, "tailnet-id", 0, "")
command.Flags().StringVar(&authMethodName, "auth-method", "", "")
command.Flags().Uint64Var(&authMethodID, "auth-method-id", 0, "")
command.RunE = func(command *coral.Command, args []string) error {
if expr != "*" {
if _, err := bexpr.CreateEvaluator(expr); err != nil {
return fmt.Errorf("invalid expression: %v", err)
}
}
client, c, err := target.createGRPCClient()
if err != nil {
return err
}
defer safeClose(c)
tailnet, err := findTailnet(client, tailnetName, tailnetID)
if err != nil {
return err
}
authMethod, err := findAuthMethod(client, authMethodName, authMethodID)
if err != nil {
return err
}
req := &api.CreateAuthFilterRequest{
AuthMethodId: authMethod.Id,
TailnetId: tailnet.Id,
Expr: expr,
}
resp, err := client.CreateAuthFilter(context.Background(), req)
if err != nil {
return err
}
tbl := table.New("ID", "AUTH_METHOD", "TAILNET", "EXPR")
if resp.AuthFilter.Tailnet != nil {
tbl.AddRow(resp.AuthFilter.Id, resp.AuthFilter.AuthMethod.Name, resp.AuthFilter.Tailnet.Name, resp.AuthFilter.Expr)
} else {
tbl.AddRow(resp.AuthFilter.Id, resp.AuthFilter.AuthMethod.Name, "", resp.AuthFilter.Expr)
}
tbl.Print()
return nil
}
return command
}
+19
View File
@@ -26,6 +26,25 @@ func findTailnet(client api.IonscaleClient, tailnet string, tailnetID uint64) (*
return nil, fmt.Errorf("requested tailnet not found or you are not authorized for this tailnet")
}
func findAuthMethod(client api.IonscaleClient, authMethod string, authMethodID uint64) (*api.AuthMethod, error) {
if authMethodID == 0 && authMethod == "" {
return nil, fmt.Errorf("requested auth method not found or you are not authorized for this tailnet")
}
resp, err := client.ListAuthMethods(context.Background(), &api.ListAuthMethodsRequest{})
if err != nil {
return nil, err
}
for _, t := range resp.AuthMethods {
if t.Id == authMethodID || t.Name == authMethod {
return t, nil
}
}
return nil, fmt.Errorf("requested auth method not found or you are not authorized for this tailnet")
}
func safeClose(c io.Closer) {
if c != nil {
_ = c.Close()
+1
View File
@@ -11,6 +11,7 @@ func Command() *coral.Command {
rootCmd.AddCommand(serverCommand())
rootCmd.AddCommand(versionCommand())
rootCmd.AddCommand(authMethodsCommand())
rootCmd.AddCommand(authFilterCommand())
rootCmd.AddCommand(tailnetCommand())
rootCmd.AddCommand(authkeysCommand())
rootCmd.AddCommand(machineCommands())
+1
View File
@@ -46,6 +46,7 @@ func migrate(db *gorm.DB, repository domain.Repository) error {
&domain.Tailnet{},
&domain.TailnetConfig{},
&domain.AuthMethod{},
&domain.AuthFilter{},
&domain.Account{},
&domain.User{},
&domain.AuthKey{},
+97
View File
@@ -0,0 +1,97 @@
package domain
import (
"context"
"errors"
"github.com/hashicorp/go-bexpr"
"github.com/mitchellh/pointerstructure"
)
type AuthFilter struct {
ID uint64 `gorm:"primary_key;autoIncrement:false"`
Expr string
AuthMethodID uint64
AuthMethod AuthMethod
TailnetID *uint64
Tailnet *Tailnet
}
type AuthFilters []AuthFilter
func (f *AuthFilter) Evaluate(v interface{}) (bool, error) {
if f.Expr == "*" {
return true, nil
}
eval, err := bexpr.CreateEvaluator(f.Expr)
if err != nil {
return false, err
}
result, err := eval.Evaluate(v)
if err != nil && !errors.Is(err, pointerstructure.ErrNotFound) {
return false, err
}
return result, err
}
func (fs AuthFilters) Evaluate(v interface{}) []Tailnet {
var tailnetIDMap = make(map[uint64]bool)
var tailnets []Tailnet
for _, f := range fs {
approved, err := f.Evaluate(v)
if err == nil && approved {
if f.TailnetID != nil {
_, alreadyApproved := tailnetIDMap[*f.TailnetID]
if !alreadyApproved {
tailnetIDMap[*f.TailnetID] = true
tailnets = append(tailnets, *f.Tailnet)
}
}
}
}
return tailnets
}
func (r *repository) SaveAuthFilter(ctx context.Context, m *AuthFilter) error {
tx := r.withContext(ctx).Save(m)
if tx.Error != nil {
return tx.Error
}
return nil
}
func (r *repository) ListAuthFilters(ctx context.Context) (AuthFilters, error) {
var filters = []AuthFilter{}
tx := r.withContext(ctx).
Preload("Tailnet").
Preload("AuthMethod").
Find(&filters)
if tx.Error != nil {
return nil, tx.Error
}
return filters, nil
}
func (r *repository) ListAuthFiltersByAuthMethod(ctx context.Context, authMethodID uint64) (AuthFilters, error) {
var filters = []AuthFilter{}
tx := r.withContext(ctx).
Preload("Tailnet").
Preload("AuthMethod").
Where("auth_method_id = ?", authMethodID).Find(&filters)
if tx.Error != nil {
return nil, tx.Error
}
return filters, nil
}
+4
View File
@@ -18,6 +18,10 @@ type Repository interface {
ListAuthMethods(ctx context.Context) ([]AuthMethod, error)
GetAuthMethod(ctx context.Context, id uint64) (*AuthMethod, error)
SaveAuthFilter(ctx context.Context, m *AuthFilter) error
ListAuthFilters(ctx context.Context) (AuthFilters, error)
ListAuthFiltersByAuthMethod(ctx context.Context, authMethodID uint64) (AuthFilters, error)
GetAccount(ctx context.Context, accountID uint64) (*Account, error)
GetOrCreateAccount(ctx context.Context, authMethodID uint64, externalID, loginName string) (*Account, bool, error)
+9 -1
View File
@@ -124,11 +124,17 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
return err
}
tailnets, err := h.repository.ListTailnets(ctx)
filters, err := h.repository.ListAuthFiltersByAuthMethod(ctx, state.AuthMethod)
if err != nil {
return err
}
tailnets := filters.Evaluate(user.Attr)
if len(tailnets) == 0 {
return c.Redirect(http.StatusFound, "/a/error?e=ua")
}
account, _, err := h.repository.GetOrCreateAccount(ctx, state.AuthMethod, user.ID, user.Name)
if err != nil {
return err
@@ -157,6 +163,8 @@ func (h *AuthenticationHandlers) Error(c echo.Context) error {
switch e {
case "iak":
return c.Render(http.StatusForbidden, "invalidauthkey.html", nil)
case "ua":
return c.Render(http.StatusForbidden, "unauthorized.html", nil)
}
return c.Render(http.StatusOK, "error.html", nil)
}
+107
View File
@@ -0,0 +1,107 @@
package service
import (
"context"
"fmt"
"github.com/hashicorp/go-bexpr"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/util"
"github.com/jsiebens/ionscale/pkg/gen/api"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func (s *Service) ListAuthFilters(ctx context.Context, req *api.ListAuthFiltersRequest) (*api.ListAuthFiltersResponse, error) {
response := &api.ListAuthFiltersResponse{AuthFilters: []*api.AuthFilter{}}
if req.AuthMethodId == nil {
filters, err := s.repository.ListAuthFilters(ctx)
if err != nil {
return nil, err
}
for _, filter := range filters {
response.AuthFilters = append(response.AuthFilters, s.mapToApi(&filter.AuthMethod, filter))
}
} else {
authMethod, err := s.repository.GetAuthMethod(ctx, *req.AuthMethodId)
if err != nil {
return nil, err
}
if authMethod == nil {
return nil, status.Error(codes.NotFound, "invalid auth method id")
}
filters, err := s.repository.ListAuthFiltersByAuthMethod(ctx, authMethod.ID)
if err != nil {
return nil, err
}
for _, filter := range filters {
response.AuthFilters = append(response.AuthFilters, s.mapToApi(&filter.AuthMethod, filter))
}
}
return response, nil
}
func (s *Service) CreateAuthFilter(ctx context.Context, req *api.CreateAuthFilterRequest) (*api.CreateAuthFilterResponse, error) {
authMethod, err := s.repository.GetAuthMethod(ctx, req.AuthMethodId)
if err != nil {
return nil, err
}
if authMethod == nil {
return nil, status.Error(codes.NotFound, "invalid auth method id")
}
tailnet, err := s.repository.GetTailnet(ctx, req.TailnetId)
if err != nil {
return nil, err
}
if tailnet == nil {
return nil, status.Error(codes.NotFound, "invalid tailnet id")
}
if req.Expr != "*" {
if _, err := bexpr.CreateEvaluator(req.Expr); err != nil {
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid expression: %v", err))
}
}
authFilter := &domain.AuthFilter{
ID: util.NextID(),
Expr: req.Expr,
AuthMethod: *authMethod,
Tailnet: tailnet,
}
if err := s.repository.SaveAuthFilter(ctx, authFilter); err != nil {
return nil, err
}
response := api.CreateAuthFilterResponse{AuthFilter: s.mapToApi(authMethod, *authFilter)}
return &response, nil
}
func (s *Service) mapToApi(authMethod *domain.AuthMethod, filter domain.AuthFilter) *api.AuthFilter {
result := api.AuthFilter{
Id: filter.ID,
Expr: filter.Expr,
AuthMethod: &api.Ref{
Id: authMethod.ID,
Name: authMethod.Name,
},
}
if filter.Tailnet != nil {
id := filter.Tailnet.ID
name := filter.Tailnet.Name
result.Tailnet = &api.Ref{
Id: id,
Name: name,
}
}
return &result
}
+63
View File
@@ -0,0 +1,63 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@200;300;400;500;600;700&display=swap');
* {
margin: 0;
padding: 0;
box-sizing: border-box;
font-family: 'Poppins', sans-serif;
}
body {
width: 100%;
height: 100vh;
padding: 10px;
background: #379683;
}
.wrapper {
background: #fff;
max-width: 400px;
width: 100%;
margin: 120px auto;
padding: 25px;
border-radius: 5px;
box-shadow: 0 10px 15px rgba(0, 0, 0, 0.1);
}
.selectionList li {
position: relative;
list-style: none;
height: 45px;
line-height: 45px;
margin-bottom: 8px;
background: #f2f2f2;
border-radius: 3px;
overflow: hidden;
box-shadow: 0 2px 2px rgba(0, 0, 0, 0.1);
}
.selectionList li button {
margin: 0;
display: block;
width: 100%;
height: 100%;
border: none;
}
</style>
<title>ionscale</title>
</head>
<body>
<div class="wrapper">
<div style="text-align: center">
<p><b>Authentication successful</b></p>
<small>but you're <b style="color: red">not</b> authorized to use any network</small>
</div>
</div>
</body>
</html>