mirror of
https://github.com/jsiebens/ionscale.git
synced 2026-03-31 15:07:49 +01:00
feat: add support for oidc providers and users
This commit is contained in:
+10
-9
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/pkg/gen/api"
|
||||
"github.com/muesli/coral"
|
||||
"gopkg.in/yaml.v2"
|
||||
@@ -44,8 +45,14 @@ func getACLConfig() *coral.Command {
|
||||
return err
|
||||
}
|
||||
|
||||
var p domain.ACLPolicy
|
||||
|
||||
if err := json.Unmarshal(resp.Value, &p); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if asJson {
|
||||
marshal, err := json.MarshalIndent(resp.Policy, "", " ")
|
||||
marshal, err := json.MarshalIndent(&p, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -53,7 +60,7 @@ func getACLConfig() *coral.Command {
|
||||
fmt.Println()
|
||||
fmt.Println(string(marshal))
|
||||
} else {
|
||||
marshal, err := yaml.Marshal(resp.Policy)
|
||||
marshal, err := yaml.Marshal(&p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -91,12 +98,6 @@ func setACLConfig() *coral.Command {
|
||||
return err
|
||||
}
|
||||
|
||||
var policy api.Policy
|
||||
|
||||
if err := json.Unmarshal(rawJson, &policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, c, err := target.createGRPCClient()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -108,7 +109,7 @@ func setACLConfig() *coral.Command {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = client.SetACLPolicy(context.Background(), &api.SetACLPolicyRequest{TailnetId: tailnet.Id, Policy: &policy})
|
||||
_, err = client.SetACLPolicy(context.Background(), &api.SetACLPolicyRequest{TailnetId: tailnet.Id, Value: rawJson})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/jsiebens/ionscale/pkg/gen/api"
|
||||
"github.com/muesli/coral"
|
||||
"github.com/rodaine/table"
|
||||
)
|
||||
|
||||
func authMethodsCommand() *coral.Command {
|
||||
command := &coral.Command{
|
||||
Use: "auth-methods",
|
||||
Short: "Manage ionscale auth methods",
|
||||
}
|
||||
|
||||
command.AddCommand(listAuthMethods())
|
||||
command.AddCommand(createAuthMethodCommand())
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func listAuthMethods() *coral.Command {
|
||||
command := &coral.Command{
|
||||
Use: "list",
|
||||
Short: "List auth methods",
|
||||
Long: `List auth methods in this ionscale instance. Example:
|
||||
|
||||
$ ionscale auth-methods list`,
|
||||
}
|
||||
|
||||
var target = Target{}
|
||||
target.prepareCommand(command)
|
||||
|
||||
command.RunE = func(command *coral.Command, args []string) error {
|
||||
|
||||
client, c, err := target.createGRPCClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer safeClose(c)
|
||||
|
||||
resp, err := client.ListAuthMethods(context.Background(), &api.ListAuthMethodsRequest{})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tbl := table.New("ID", "NAME", "TYPE")
|
||||
for _, m := range resp.AuthMethods {
|
||||
tbl.AddRow(m.Id, m.Name, m.Type)
|
||||
}
|
||||
tbl.Print()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func createAuthMethodCommand() *coral.Command {
|
||||
command := &coral.Command{
|
||||
Use: "create",
|
||||
Short: "Create a new auth method",
|
||||
}
|
||||
|
||||
command.AddCommand(createOIDCAuthMethodCommand())
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func createOIDCAuthMethodCommand() *coral.Command {
|
||||
command := &coral.Command{
|
||||
Use: "oidc",
|
||||
Short: "Create a new auth method",
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var methodName string
|
||||
|
||||
var clientId string
|
||||
var clientSecret string
|
||||
var issuer string
|
||||
|
||||
var target = Target{}
|
||||
|
||||
target.prepareCommand(command)
|
||||
command.Flags().StringVarP(&methodName, "name", "n", "", "")
|
||||
command.Flags().StringVar(&clientId, "client-id", "", "")
|
||||
command.Flags().StringVar(&clientSecret, "client-secret", "", "")
|
||||
command.Flags().StringVar(&issuer, "issuer", "", "")
|
||||
|
||||
_ = command.MarkFlagRequired("name")
|
||||
_ = command.MarkFlagRequired("client-id")
|
||||
_ = command.MarkFlagRequired("client-secret")
|
||||
_ = command.MarkFlagRequired("issuer")
|
||||
|
||||
command.RunE = func(command *coral.Command, args []string) error {
|
||||
|
||||
client, c, err := target.createGRPCClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer safeClose(c)
|
||||
|
||||
req := &api.CreateAuthMethodRequest{
|
||||
Type: "oidc",
|
||||
Name: methodName,
|
||||
Issuer: issuer,
|
||||
ClientId: clientId,
|
||||
ClientSecret: clientSecret,
|
||||
}
|
||||
|
||||
resp, err := client.CreateAuthMethod(context.Background(), req)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tbl := table.New("ID", "NAME", "TYPE")
|
||||
tbl.AddRow(resp.AuthMethod.Id, resp.AuthMethod.Name, resp.AuthMethod.Type)
|
||||
tbl.Print()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
@@ -10,6 +10,7 @@ func Command() *coral.Command {
|
||||
rootCmd.AddCommand(derpMapCommand())
|
||||
rootCmd.AddCommand(serverCommand())
|
||||
rootCmd.AddCommand(versionCommand())
|
||||
rootCmd.AddCommand(authMethodsCommand())
|
||||
rootCmd.AddCommand(tailnetCommand())
|
||||
rootCmd.AddCommand(authkeysCommand())
|
||||
rootCmd.AddCommand(machineCommands())
|
||||
|
||||
@@ -45,6 +45,8 @@ func migrate(db *gorm.DB, repository domain.Repository) error {
|
||||
&domain.ServerConfig{},
|
||||
&domain.Tailnet{},
|
||||
&domain.TailnetConfig{},
|
||||
&domain.AuthMethod{},
|
||||
&domain.Account{},
|
||||
&domain.User{},
|
||||
&domain.AuthKey{},
|
||||
&domain.Machine{},
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/jsiebens/ionscale/internal/util"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Account struct {
|
||||
ID uint64 `gorm:"primary_key;autoIncrement:false"`
|
||||
|
||||
ExternalID string
|
||||
LoginName string
|
||||
AuthMethodID uint64
|
||||
}
|
||||
|
||||
func (r *repository) GetOrCreateAccount(ctx context.Context, authMethodID uint64, externalID, loginName string) (*Account, bool, error) {
|
||||
account := &Account{}
|
||||
id := util.NextID()
|
||||
|
||||
tx := r.withContext(ctx).
|
||||
Where(Account{AuthMethodID: authMethodID, ExternalID: externalID}).
|
||||
Attrs(Account{ID: id, LoginName: loginName}).
|
||||
FirstOrCreate(account)
|
||||
|
||||
if tx.Error != nil {
|
||||
return nil, false, tx.Error
|
||||
}
|
||||
|
||||
return account, account.ID == id, nil
|
||||
}
|
||||
|
||||
func (r *repository) GetAccount(ctx context.Context, id uint64) (*Account, error) {
|
||||
var account Account
|
||||
tx := r.withContext(ctx).Take(&account, "id = ?", id)
|
||||
|
||||
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if tx.Error != nil {
|
||||
return nil, tx.Error
|
||||
}
|
||||
|
||||
return &account, nil
|
||||
}
|
||||
+23
-2
@@ -9,8 +9,9 @@ import (
|
||||
)
|
||||
|
||||
type ACLPolicy struct {
|
||||
Hosts map[string]string `json:"hosts,omitempty"`
|
||||
ACLs []ACL `json:"acls"`
|
||||
Groups map[string][]string `json:"groups,omitempty"`
|
||||
Hosts map[string]string `json:"hosts,omitempty"`
|
||||
ACLs []ACL `json:"acls"`
|
||||
}
|
||||
|
||||
type ACL struct {
|
||||
@@ -157,6 +158,26 @@ func (a *aclEngine) expandMachineAlias(m *Machine, alias string, src bool) []str
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(alias, "@") && !m.HasTags() && m.HasUser(alias) {
|
||||
return []string{m.IPv4.String(), m.IPv6.String()}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(alias, "group:") && !m.HasTags() {
|
||||
users, ok := a.policy.Groups[alias]
|
||||
|
||||
if !ok {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if m.HasUser(u) {
|
||||
return []string{m.IPv4.String(), m.IPv6.String()}
|
||||
}
|
||||
}
|
||||
|
||||
return []string{}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(alias, "tag:") && m.HasTag(alias[4:]) {
|
||||
return []string{m.IPv4.String(), m.IPv6.String()}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type AuthMethod struct {
|
||||
ID uint64 `gorm:"primary_key;autoIncrement:false"`
|
||||
Name string `gorm:"type:varchar(64);unique_index"`
|
||||
Type string
|
||||
Issuer string
|
||||
ClientId string
|
||||
ClientSecret string
|
||||
}
|
||||
|
||||
func (r *repository) SaveAuthMethod(ctx context.Context, m *AuthMethod) error {
|
||||
tx := r.withContext(ctx).Save(m)
|
||||
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) ListAuthMethods(ctx context.Context) ([]AuthMethod, error) {
|
||||
var authMethods = []AuthMethod{}
|
||||
tx := r.withContext(ctx).Find(&authMethods)
|
||||
if tx.Error != nil {
|
||||
return nil, tx.Error
|
||||
}
|
||||
return authMethods, nil
|
||||
}
|
||||
|
||||
func (r *repository) GetAuthMethod(ctx context.Context, id uint64) (*AuthMethod, error) {
|
||||
var m AuthMethod
|
||||
tx := r.withContext(ctx).First(&m, "id = ?", id)
|
||||
|
||||
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if tx.Error != nil {
|
||||
return nil, tx.Error
|
||||
}
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
@@ -61,6 +61,10 @@ func (m *Machine) HasTag(tag string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Machine) HasUser(loginName string) bool {
|
||||
return m.User.Name == loginName
|
||||
}
|
||||
|
||||
func (m *Machine) HasTags() bool {
|
||||
return len(m.Tags) != 0
|
||||
}
|
||||
|
||||
@@ -14,6 +14,13 @@ type Repository interface {
|
||||
GetDERPMap(ctx context.Context) (*tailcfg.DERPMap, error)
|
||||
SetDERPMap(ctx context.Context, v *tailcfg.DERPMap) error
|
||||
|
||||
SaveAuthMethod(ctx context.Context, m *AuthMethod) error
|
||||
ListAuthMethods(ctx context.Context) ([]AuthMethod, error)
|
||||
GetAuthMethod(ctx context.Context, id uint64) (*AuthMethod, error)
|
||||
|
||||
GetAccount(ctx context.Context, accountID uint64) (*Account, error)
|
||||
GetOrCreateAccount(ctx context.Context, authMethodID uint64, externalID, loginName string) (*Account, bool, error)
|
||||
|
||||
GetOrCreateTailnet(ctx context.Context, name string) (*Tailnet, bool, error)
|
||||
GetTailnet(ctx context.Context, id uint64) (*Tailnet, error)
|
||||
ListTailnets(ctx context.Context) ([]Tailnet, error)
|
||||
@@ -35,6 +42,7 @@ type Repository interface {
|
||||
|
||||
GetOrCreateServiceUser(ctx context.Context, tailnet *Tailnet) (*User, bool, error)
|
||||
ListUsers(ctx context.Context, tailnetID uint64) (Users, error)
|
||||
GetOrCreateUserWithAccount(ctx context.Context, tailnet *Tailnet, account *Account) (*User, bool, error)
|
||||
DeleteUsersByTailnet(ctx context.Context, tailnetID uint64) error
|
||||
|
||||
SaveMachine(ctx context.Context, m *Machine) error
|
||||
|
||||
@@ -9,6 +9,7 @@ type TailnetRole string
|
||||
|
||||
const (
|
||||
TailnetRoleService TailnetRole = "service"
|
||||
TailnetRoleMember TailnetRole = "member"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
@@ -18,6 +19,9 @@ type User struct {
|
||||
TailnetRole TailnetRole
|
||||
TailnetID uint64
|
||||
Tailnet Tailnet
|
||||
|
||||
AccountID *uint64
|
||||
Account *Account
|
||||
}
|
||||
|
||||
type Users []User
|
||||
@@ -54,3 +58,19 @@ func (r *repository) DeleteUsersByTailnet(ctx context.Context, tailnetID uint64)
|
||||
tx := r.withContext(ctx).Where("tailnet_id = ?", tailnetID).Delete(&User{})
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
func (r *repository) GetOrCreateUserWithAccount(ctx context.Context, tailnet *Tailnet, account *Account) (*User, bool, error) {
|
||||
user := &User{}
|
||||
id := util.NextID()
|
||||
|
||||
query := User{AccountID: &account.ID, TailnetID: tailnet.ID}
|
||||
attrs := User{ID: id, Name: account.LoginName, TailnetID: tailnet.ID, AccountID: &account.ID, TailnetRole: TailnetRoleMember}
|
||||
|
||||
tx := r.withContext(ctx).Where(query).Attrs(attrs).FirstOrCreate(user)
|
||||
|
||||
if tx.Error != nil {
|
||||
return nil, false, tx.Error
|
||||
}
|
||||
|
||||
return user, user.ID == id, nil
|
||||
}
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/jsiebens/ionscale/internal/addr"
|
||||
"github.com/jsiebens/ionscale/internal/provider"
|
||||
"github.com/mr-tron/base58"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/jsiebens/ionscale/internal/config"
|
||||
@@ -21,6 +26,7 @@ func NewAuthenticationHandlers(
|
||||
config: config,
|
||||
repository: repository,
|
||||
pendingMachineRegistrationRequests: pendingMachineRegistrationRequests,
|
||||
pendingOAuthUsers: cache.New(5*time.Minute, 10*time.Minute),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,21 +34,118 @@ type AuthenticationHandlers struct {
|
||||
repository domain.Repository
|
||||
config *config.Config
|
||||
pendingMachineRegistrationRequests *cache.Cache
|
||||
pendingOAuthUsers *cache.Cache
|
||||
}
|
||||
|
||||
type AuthFormData struct {
|
||||
AuthMethods []domain.AuthMethod
|
||||
}
|
||||
|
||||
type TailnetSelectionData struct {
|
||||
Tailnets []domain.Tailnet
|
||||
}
|
||||
|
||||
type oauthState struct {
|
||||
Key string
|
||||
AuthMethod uint64
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) StartAuth(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
key := c.Param("key")
|
||||
|
||||
if _, ok := h.pendingMachineRegistrationRequests.Get(key); !ok {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
}
|
||||
|
||||
methods, err := h.repository.ListAuthMethods(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.Render(http.StatusOK, "auth.html", &AuthFormData{AuthMethods: methods})
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) ProcessAuth(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
|
||||
key := c.Param("key")
|
||||
authKey := c.FormValue("ak")
|
||||
authMethodId := c.FormValue("s")
|
||||
|
||||
if _, ok := h.pendingMachineRegistrationRequests.Get(key); !ok {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
}
|
||||
|
||||
if authKey != "" {
|
||||
return h.endMachineRegistrationFlow(c, key, authKey)
|
||||
return h.endMachineRegistrationFlow(c, &oauthState{Key: key})
|
||||
}
|
||||
|
||||
return c.Render(http.StatusOK, "auth.html", nil)
|
||||
if authMethodId != "" {
|
||||
id, err := strconv.ParseUint(authMethodId, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
method, err := h.repository.GetAuthMethod(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state, err := h.createState(key, method.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
authProvider, err := provider.NewProvider(method)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
redirectUrl := authProvider.GetLoginURL(h.config.CreateUrl("/a/callback"), state)
|
||||
|
||||
return c.Redirect(http.StatusFound, redirectUrl)
|
||||
}
|
||||
|
||||
return c.Redirect(http.StatusFound, "/a/"+key)
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) Callback(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
|
||||
code := c.QueryParam("code")
|
||||
state, err := h.readState(c.QueryParam("state"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := h.exchangeUser(ctx, code, state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tailnets, err := h.repository.ListTailnets(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account, _, err := h.repository.GetOrCreateAccount(ctx, state.AuthMethod, user.ID, user.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.pendingOAuthUsers.Set(state.Key, account, cache.DefaultExpiration)
|
||||
|
||||
return c.Render(http.StatusOK, "tailnets.html", &TailnetSelectionData{Tailnets: tailnets})
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) EndOAuth(c echo.Context) error {
|
||||
state, err := h.readState(c.QueryParam("state"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.endMachineRegistrationFlow(c, state)
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) Success(c echo.Context) error {
|
||||
@@ -58,36 +161,71 @@ func (h *AuthenticationHandlers) Error(c echo.Context) error {
|
||||
return c.Render(http.StatusOK, "error.html", nil)
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, registrationKey, authKeyParam string) error {
|
||||
func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, state *oauthState) error {
|
||||
ctx := c.Request().Context()
|
||||
|
||||
defer h.pendingMachineRegistrationRequests.Delete(registrationKey)
|
||||
defer h.pendingMachineRegistrationRequests.Delete(state.Key)
|
||||
|
||||
preqItem, preqOK := h.pendingMachineRegistrationRequests.Get(registrationKey)
|
||||
preqItem, preqOK := h.pendingMachineRegistrationRequests.Get(state.Key)
|
||||
if !preqOK {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
}
|
||||
|
||||
authKeyParam := c.FormValue("ak")
|
||||
tailnetIDParam := c.FormValue("s")
|
||||
|
||||
preq := preqItem.(*pendingMachineRegistrationRequest)
|
||||
req := preq.request
|
||||
machineKey := preq.machineKey
|
||||
nodeKey := req.NodeKey.String()
|
||||
|
||||
authKey, err := h.repository.LoadAuthKey(ctx, authKeyParam)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var tailnet *domain.Tailnet
|
||||
var user *domain.User
|
||||
var ephemeral bool
|
||||
var tags = []string{}
|
||||
|
||||
if authKey == nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error?e=iak")
|
||||
}
|
||||
if authKeyParam != "" {
|
||||
authKey, err := h.repository.LoadAuthKey(ctx, authKeyParam)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tailnet := authKey.Tailnet
|
||||
user := authKey.User
|
||||
if authKey == nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error?e=iak")
|
||||
}
|
||||
|
||||
tailnet = &authKey.Tailnet
|
||||
user = &authKey.User
|
||||
tags = authKey.Tags
|
||||
ephemeral = authKey.Ephemeral
|
||||
} else {
|
||||
parseUint, err := strconv.ParseUint(tailnetIDParam, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tailnet, err = h.repository.GetTailnet(ctx, parseUint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
item, ok := h.pendingOAuthUsers.Get(state.Key)
|
||||
if !ok {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
}
|
||||
|
||||
oa := item.(*domain.Account)
|
||||
|
||||
user, _, err = h.repository.GetOrCreateUserWithAccount(ctx, tailnet, oa)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ephemeral = false
|
||||
}
|
||||
|
||||
var m *domain.Machine
|
||||
|
||||
m, err = h.repository.GetMachineByKey(ctx, tailnet.ID, machineKey)
|
||||
m, err := h.repository.GetMachineByKey(ctx, tailnet.ID, machineKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -95,10 +233,17 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
|
||||
if m == nil {
|
||||
now := time.Now().UTC()
|
||||
|
||||
registeredTags := authKey.Tags
|
||||
registeredTags := tags
|
||||
advertisedTags := domain.SanitizeTags(req.Hostinfo.RequestTags)
|
||||
tags := append(registeredTags, advertisedTags...)
|
||||
|
||||
if len(tags) != 0 {
|
||||
user, _, err = h.repository.GetOrCreateServiceUser(ctx, tailnet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
sanitizeHostname := dnsname.SanitizeHostname(req.Hostinfo.Hostname)
|
||||
nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnet.ID, sanitizeHostname)
|
||||
if err != nil {
|
||||
@@ -111,13 +256,13 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
|
||||
NameIdx: nameIdx,
|
||||
MachineKey: machineKey,
|
||||
NodeKey: nodeKey,
|
||||
Ephemeral: authKey.Ephemeral,
|
||||
Ephemeral: ephemeral,
|
||||
RegisteredTags: registeredTags,
|
||||
Tags: domain.SanitizeTags(tags),
|
||||
CreatedAt: now,
|
||||
|
||||
User: user,
|
||||
Tailnet: tailnet,
|
||||
User: *user,
|
||||
Tailnet: *tailnet,
|
||||
}
|
||||
|
||||
if !req.Expiry.IsZero() {
|
||||
@@ -131,10 +276,17 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
|
||||
m.IPv4 = domain.IP{IP: ipv4}
|
||||
m.IPv6 = domain.IP{IP: ipv6}
|
||||
} else {
|
||||
registeredTags := authKey.Tags
|
||||
registeredTags := tags
|
||||
advertisedTags := domain.SanitizeTags(req.Hostinfo.RequestTags)
|
||||
tags := append(registeredTags, advertisedTags...)
|
||||
|
||||
if len(tags) != 0 {
|
||||
user, _, err = h.repository.GetOrCreateServiceUser(ctx, tailnet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
sanitizeHostname := dnsname.SanitizeHostname(req.Hostinfo.Hostname)
|
||||
if m.Name != sanitizeHostname {
|
||||
nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnet.ID, sanitizeHostname)
|
||||
@@ -145,13 +297,13 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
|
||||
m.NameIdx = nameIdx
|
||||
}
|
||||
m.NodeKey = nodeKey
|
||||
m.Ephemeral = authKey.Ephemeral
|
||||
m.Ephemeral = ephemeral
|
||||
m.RegisteredTags = registeredTags
|
||||
m.Tags = domain.SanitizeTags(tags)
|
||||
m.UserID = user.ID
|
||||
m.User = user
|
||||
m.User = *user
|
||||
m.TailnetID = tailnet.ID
|
||||
m.Tailnet = tailnet
|
||||
m.Tailnet = *tailnet
|
||||
m.ExpiresAt = nil
|
||||
}
|
||||
|
||||
@@ -161,3 +313,49 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
|
||||
|
||||
return c.Redirect(http.StatusFound, "/a/success")
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) getAuthProvider(ctx context.Context, authMethodId uint64) (provider.AuthProvider, error) {
|
||||
authMethod, err := h.repository.GetAuthMethod(ctx, authMethodId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return provider.NewProvider(authMethod)
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) exchangeUser(ctx context.Context, code string, state *oauthState) (*provider.User, error) {
|
||||
redirectUrl := h.config.CreateUrl("/a/callback")
|
||||
|
||||
authProvider, err := h.getAuthProvider(ctx, state.AuthMethod)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := authProvider.Exchange(redirectUrl, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) createState(key string, authMethodId uint64) (string, error) {
|
||||
stateMap := oauthState{Key: key, AuthMethod: authMethodId}
|
||||
marshal, err := json.Marshal(&stateMap)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base58.FastBase58Encoding(marshal), nil
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) readState(s string) (*oauthState, error) {
|
||||
decodedState, err := base58.FastBase58Decoding(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var state = &oauthState{}
|
||||
if err := json.Unmarshal(decodedState, state); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type OIDCProvider struct {
|
||||
clientID string
|
||||
clientSecret string
|
||||
provider *oidc.Provider
|
||||
verifier *oidc.IDTokenVerifier
|
||||
}
|
||||
|
||||
func NewOIDCProvider(c *domain.AuthMethod) (*OIDCProvider, error) {
|
||||
provider, err := oidc.NewProvider(context.Background(), c.Issuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
verifier := provider.Verifier(&oidc.Config{ClientID: c.ClientId})
|
||||
|
||||
return &OIDCProvider{
|
||||
clientID: c.ClientId,
|
||||
clientSecret: c.ClientSecret,
|
||||
provider: provider,
|
||||
verifier: verifier,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) GetLoginURL(redirectURI, state string) string {
|
||||
oauth2Config := oauth2.Config{
|
||||
ClientID: p.clientID,
|
||||
ClientSecret: p.clientSecret,
|
||||
RedirectURL: redirectURI,
|
||||
Endpoint: p.provider.Endpoint(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
|
||||
return oauth2Config.AuthCodeURL(state, oauth2.ApprovalForce)
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) Exchange(redirectURI, code string) (*User, error) {
|
||||
oauth2Config := oauth2.Config{
|
||||
ClientID: p.clientID,
|
||||
ClientSecret: p.clientSecret,
|
||||
RedirectURL: redirectURI,
|
||||
Endpoint: p.provider.Endpoint(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
|
||||
oauth2Token, err := oauth2Config.Exchange(context.Background(), code)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract the ID Token from OAuth2 token.
|
||||
rawIdToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok || strings.TrimSpace(rawIdToken) == "" {
|
||||
return nil, fmt.Errorf("id_token missing")
|
||||
}
|
||||
|
||||
// Parse and verify ID Token payload.
|
||||
idToken, err := p.verifier.Verify(context.Background(), rawIdToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sub, email, tokenClaims, err := p.getTokenClaims(idToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userInfoClaims, err := p.getUserInfoClaims(oauth2Config, oauth2Token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &User{
|
||||
ID: sub,
|
||||
Name: email,
|
||||
Attr: map[string]interface{}{
|
||||
"token": tokenClaims,
|
||||
"userinfo": userInfoClaims,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) getTokenClaims(idToken *oidc.IDToken) (string, string, map[string]interface{}, error) {
|
||||
var raw = make(map[string]interface{})
|
||||
var claims struct {
|
||||
Sub string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
// Extract default claims.
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return "", "", nil, fmt.Errorf("failed to parse id_token claims: %v", err)
|
||||
}
|
||||
|
||||
// Extract raw claims.
|
||||
if err := idToken.Claims(&raw); err != nil {
|
||||
return "", "", nil, fmt.Errorf("failed to parse id_token claims: %v", err)
|
||||
}
|
||||
|
||||
return claims.Sub, claims.Email, raw, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) getUserInfoClaims(config oauth2.Config, token *oauth2.Token) (map[string]interface{}, error) {
|
||||
var raw = make(map[string]interface{})
|
||||
|
||||
source := config.TokenSource(context.Background(), token)
|
||||
|
||||
info, err := p.provider.UserInfo(context.Background(), source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := info.Claims(&raw); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user info claims: %v", err)
|
||||
}
|
||||
|
||||
return raw, nil
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
)
|
||||
|
||||
type AuthProvider interface {
|
||||
GetLoginURL(redirectURI, state string) string
|
||||
Exchange(redirectURI, code string) (*User, error)
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID string
|
||||
Name string
|
||||
Attr map[string]interface{}
|
||||
}
|
||||
|
||||
func NewProvider(m *domain.AuthMethod) (AuthProvider, error) {
|
||||
switch m.Type {
|
||||
case "oidc":
|
||||
return NewOIDCProvider(m)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown auth method type")
|
||||
}
|
||||
}
|
||||
@@ -130,7 +130,9 @@ func Start(config *config.Config) error {
|
||||
|
||||
auth := tlsAppHandler.Group("/a")
|
||||
auth.GET("/:key", authenticationHandlers.StartAuth)
|
||||
auth.POST("/:key", authenticationHandlers.StartAuth)
|
||||
auth.POST("/:key", authenticationHandlers.ProcessAuth)
|
||||
auth.GET("/callback", authenticationHandlers.Callback)
|
||||
auth.POST("/callback", authenticationHandlers.EndOAuth)
|
||||
auth.GET("/success", authenticationHandlers.Success)
|
||||
auth.GET("/error", authenticationHandlers.Error)
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/mapping"
|
||||
"github.com/jsiebens/ionscale/pkg/gen/api"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
@@ -15,12 +15,12 @@ func (s *Service) GetACLPolicy(ctx context.Context, req *api.GetACLPolicyRequest
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var p api.Policy
|
||||
if err := mapping.CopyViaJson(policy, &p); err != nil {
|
||||
marshal, err := json.Marshal(policy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.GetACLPolicyResponse{Policy: &p}, nil
|
||||
return &api.GetACLPolicyResponse{Value: marshal}, nil
|
||||
}
|
||||
|
||||
func (s *Service) SetACLPolicy(ctx context.Context, req *api.SetACLPolicyRequest) (*api.SetACLPolicyResponse, error) {
|
||||
@@ -33,7 +33,7 @@ func (s *Service) SetACLPolicy(ctx context.Context, req *api.SetACLPolicyRequest
|
||||
}
|
||||
|
||||
var policy domain.ACLPolicy
|
||||
if err := mapping.CopyViaJson(req.Policy, &policy); err != nil {
|
||||
if err := json.Unmarshal(req.Value, &policy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/util"
|
||||
"github.com/jsiebens/ionscale/pkg/gen/api"
|
||||
)
|
||||
|
||||
func (s *Service) CreateAuthMethod(ctx context.Context, req *api.CreateAuthMethodRequest) (*api.CreateAuthMethodResponse, error) {
|
||||
|
||||
authMethod := &domain.AuthMethod{
|
||||
ID: util.NextID(),
|
||||
Name: req.Name,
|
||||
Type: req.Type,
|
||||
Issuer: req.Issuer,
|
||||
ClientId: req.ClientId,
|
||||
ClientSecret: req.ClientSecret,
|
||||
}
|
||||
|
||||
if err := s.repository.SaveAuthMethod(ctx, authMethod); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.CreateAuthMethodResponse{AuthMethod: &api.AuthMethod{
|
||||
Id: authMethod.ID,
|
||||
Type: authMethod.Type,
|
||||
Name: authMethod.Name,
|
||||
Issuer: authMethod.Issuer,
|
||||
ClientId: authMethod.ClientId,
|
||||
}}, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *Service) ListAuthMethods(ctx context.Context, _ *api.ListAuthMethodsRequest) (*api.ListAuthMethodsResponse, error) {
|
||||
methods, err := s.repository.ListAuthMethods(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := &api.ListAuthMethodsResponse{AuthMethods: []*api.AuthMethod{}}
|
||||
for _, m := range methods {
|
||||
response.AuthMethods = append(response.AuthMethods, &api.AuthMethod{
|
||||
Id: m.ID,
|
||||
Name: m.Name,
|
||||
Type: m.Type,
|
||||
})
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
@@ -74,10 +74,28 @@
|
||||
</head>
|
||||
<body>
|
||||
<div class="wrapper">
|
||||
{{if .AuthMethods}}
|
||||
<div style="text-align: left; padding-bottom: 10px">
|
||||
<p><b>Authentication required</b></p>
|
||||
<small>Login with:</small>
|
||||
</div>
|
||||
<form method="post">
|
||||
<ul class="selectionList">
|
||||
{{range .AuthMethods}}
|
||||
<li><button type="submit" name="s" value="{{.ID}}">{{.Name}}</button></li>
|
||||
{{end}}
|
||||
</ul>
|
||||
</form>
|
||||
<div style="text-align: left; padding-bottom: 10px; padding-top: 20px">
|
||||
<small>Or enter an <label for="ak">auth key</label> here:</small>
|
||||
</div>
|
||||
{{end}}
|
||||
{{if not .AuthMethods}}
|
||||
<div style="text-align: left; padding-bottom: 10px">
|
||||
<p><b>Authentication required</b></p>
|
||||
<small>Enter an <label for="ak">auth key</label> here:</small>
|
||||
</div>
|
||||
{{end}}
|
||||
<form method="post" style="text-align: right">
|
||||
<p><input id="ak" name="ak" type="text"/></p>
|
||||
<div style="padding-top: 10px">
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<style>
|
||||
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@200;300;400;500;600;700&display=swap');
|
||||
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
font-family: 'Poppins', sans-serif;
|
||||
}
|
||||
|
||||
body {
|
||||
width: 100%;
|
||||
height: 100vh;
|
||||
padding: 10px;
|
||||
background: #379683;
|
||||
}
|
||||
|
||||
.wrapper {
|
||||
background: #fff;
|
||||
max-width: 400px;
|
||||
width: 100%;
|
||||
margin: 120px auto;
|
||||
padding: 25px;
|
||||
border-radius: 5px;
|
||||
box-shadow: 0 10px 15px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.selectionList li {
|
||||
position: relative;
|
||||
list-style: none;
|
||||
height: 45px;
|
||||
line-height: 45px;
|
||||
margin-bottom: 8px;
|
||||
background: #f2f2f2;
|
||||
border-radius: 3px;
|
||||
overflow: hidden;
|
||||
box-shadow: 0 2px 2px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.selectionList {
|
||||
padding-top: 5px
|
||||
}
|
||||
|
||||
.selectionList li button {
|
||||
margin: 0;
|
||||
display: block;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
border: none;
|
||||
}
|
||||
|
||||
input {
|
||||
display: block;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
padding: 10px;
|
||||
}
|
||||
|
||||
button {
|
||||
padding-top: 10px;
|
||||
padding-bottom: 10px;
|
||||
padding-left: 20px;
|
||||
padding-right: 20px;
|
||||
height: 45px;
|
||||
border: none;
|
||||
}
|
||||
</style>
|
||||
<title>ionscale</title>
|
||||
</head>
|
||||
<body>
|
||||
<div class="wrapper">
|
||||
<div style="text-align: left; padding-bottom: 10px">
|
||||
<p><b>Tailnets</b></p>
|
||||
<small>Select your tailnet:</small>
|
||||
</div>
|
||||
<form method="post">
|
||||
<ul class="selectionList">
|
||||
{{range .Tailnets}}
|
||||
<li><button type="submit" name="s" value="{{.ID}}">{{.Name}}</button></li>
|
||||
{{end}}
|
||||
</ul>
|
||||
</form>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
Reference in New Issue
Block a user