feat: add support for oidc providers and users

This commit is contained in:
Johan Siebens
2022-05-25 15:44:21 +02:00
parent 37e94ac915
commit 84a57ea409
29 changed files with 1658 additions and 419 deletions
+10 -9
View File
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/pkg/gen/api"
"github.com/muesli/coral"
"gopkg.in/yaml.v2"
@@ -44,8 +45,14 @@ func getACLConfig() *coral.Command {
return err
}
var p domain.ACLPolicy
if err := json.Unmarshal(resp.Value, &p); err != nil {
return err
}
if asJson {
marshal, err := json.MarshalIndent(resp.Policy, "", " ")
marshal, err := json.MarshalIndent(&p, "", " ")
if err != nil {
return err
}
@@ -53,7 +60,7 @@ func getACLConfig() *coral.Command {
fmt.Println()
fmt.Println(string(marshal))
} else {
marshal, err := yaml.Marshal(resp.Policy)
marshal, err := yaml.Marshal(&p)
if err != nil {
return err
}
@@ -91,12 +98,6 @@ func setACLConfig() *coral.Command {
return err
}
var policy api.Policy
if err := json.Unmarshal(rawJson, &policy); err != nil {
return err
}
client, c, err := target.createGRPCClient()
if err != nil {
return err
@@ -108,7 +109,7 @@ func setACLConfig() *coral.Command {
return err
}
_, err = client.SetACLPolicy(context.Background(), &api.SetACLPolicyRequest{TailnetId: tailnet.Id, Policy: &policy})
_, err = client.SetACLPolicy(context.Background(), &api.SetACLPolicyRequest{TailnetId: tailnet.Id, Value: rawJson})
if err != nil {
return err
}
+127
View File
@@ -0,0 +1,127 @@
package cmd
import (
"context"
"github.com/jsiebens/ionscale/pkg/gen/api"
"github.com/muesli/coral"
"github.com/rodaine/table"
)
func authMethodsCommand() *coral.Command {
command := &coral.Command{
Use: "auth-methods",
Short: "Manage ionscale auth methods",
}
command.AddCommand(listAuthMethods())
command.AddCommand(createAuthMethodCommand())
return command
}
func listAuthMethods() *coral.Command {
command := &coral.Command{
Use: "list",
Short: "List auth methods",
Long: `List auth methods in this ionscale instance. Example:
$ ionscale auth-methods list`,
}
var target = Target{}
target.prepareCommand(command)
command.RunE = func(command *coral.Command, args []string) error {
client, c, err := target.createGRPCClient()
if err != nil {
return err
}
defer safeClose(c)
resp, err := client.ListAuthMethods(context.Background(), &api.ListAuthMethodsRequest{})
if err != nil {
return err
}
tbl := table.New("ID", "NAME", "TYPE")
for _, m := range resp.AuthMethods {
tbl.AddRow(m.Id, m.Name, m.Type)
}
tbl.Print()
return nil
}
return command
}
func createAuthMethodCommand() *coral.Command {
command := &coral.Command{
Use: "create",
Short: "Create a new auth method",
}
command.AddCommand(createOIDCAuthMethodCommand())
return command
}
func createOIDCAuthMethodCommand() *coral.Command {
command := &coral.Command{
Use: "oidc",
Short: "Create a new auth method",
SilenceUsage: true,
}
var methodName string
var clientId string
var clientSecret string
var issuer string
var target = Target{}
target.prepareCommand(command)
command.Flags().StringVarP(&methodName, "name", "n", "", "")
command.Flags().StringVar(&clientId, "client-id", "", "")
command.Flags().StringVar(&clientSecret, "client-secret", "", "")
command.Flags().StringVar(&issuer, "issuer", "", "")
_ = command.MarkFlagRequired("name")
_ = command.MarkFlagRequired("client-id")
_ = command.MarkFlagRequired("client-secret")
_ = command.MarkFlagRequired("issuer")
command.RunE = func(command *coral.Command, args []string) error {
client, c, err := target.createGRPCClient()
if err != nil {
return err
}
defer safeClose(c)
req := &api.CreateAuthMethodRequest{
Type: "oidc",
Name: methodName,
Issuer: issuer,
ClientId: clientId,
ClientSecret: clientSecret,
}
resp, err := client.CreateAuthMethod(context.Background(), req)
if err != nil {
return err
}
tbl := table.New("ID", "NAME", "TYPE")
tbl.AddRow(resp.AuthMethod.Id, resp.AuthMethod.Name, resp.AuthMethod.Type)
tbl.Print()
return nil
}
return command
}
+1
View File
@@ -10,6 +10,7 @@ func Command() *coral.Command {
rootCmd.AddCommand(derpMapCommand())
rootCmd.AddCommand(serverCommand())
rootCmd.AddCommand(versionCommand())
rootCmd.AddCommand(authMethodsCommand())
rootCmd.AddCommand(tailnetCommand())
rootCmd.AddCommand(authkeysCommand())
rootCmd.AddCommand(machineCommands())
+2
View File
@@ -45,6 +45,8 @@ func migrate(db *gorm.DB, repository domain.Repository) error {
&domain.ServerConfig{},
&domain.Tailnet{},
&domain.TailnetConfig{},
&domain.AuthMethod{},
&domain.Account{},
&domain.User{},
&domain.AuthKey{},
&domain.Machine{},
+47
View File
@@ -0,0 +1,47 @@
package domain
import (
"context"
"errors"
"github.com/jsiebens/ionscale/internal/util"
"gorm.io/gorm"
)
type Account struct {
ID uint64 `gorm:"primary_key;autoIncrement:false"`
ExternalID string
LoginName string
AuthMethodID uint64
}
func (r *repository) GetOrCreateAccount(ctx context.Context, authMethodID uint64, externalID, loginName string) (*Account, bool, error) {
account := &Account{}
id := util.NextID()
tx := r.withContext(ctx).
Where(Account{AuthMethodID: authMethodID, ExternalID: externalID}).
Attrs(Account{ID: id, LoginName: loginName}).
FirstOrCreate(account)
if tx.Error != nil {
return nil, false, tx.Error
}
return account, account.ID == id, nil
}
func (r *repository) GetAccount(ctx context.Context, id uint64) (*Account, error) {
var account Account
tx := r.withContext(ctx).Take(&account, "id = ?", id)
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
if tx.Error != nil {
return nil, tx.Error
}
return &account, nil
}
+23 -2
View File
@@ -9,8 +9,9 @@ import (
)
type ACLPolicy struct {
Hosts map[string]string `json:"hosts,omitempty"`
ACLs []ACL `json:"acls"`
Groups map[string][]string `json:"groups,omitempty"`
Hosts map[string]string `json:"hosts,omitempty"`
ACLs []ACL `json:"acls"`
}
type ACL struct {
@@ -157,6 +158,26 @@ func (a *aclEngine) expandMachineAlias(m *Machine, alias string, src bool) []str
}
}
if strings.Contains(alias, "@") && !m.HasTags() && m.HasUser(alias) {
return []string{m.IPv4.String(), m.IPv6.String()}
}
if strings.HasPrefix(alias, "group:") && !m.HasTags() {
users, ok := a.policy.Groups[alias]
if !ok {
return []string{}
}
for _, u := range users {
if m.HasUser(u) {
return []string{m.IPv4.String(), m.IPv6.String()}
}
}
return []string{}
}
if strings.HasPrefix(alias, "tag:") && m.HasTag(alias[4:]) {
return []string{m.IPv4.String(), m.IPv6.String()}
}
+50
View File
@@ -0,0 +1,50 @@
package domain
import (
"context"
"errors"
"gorm.io/gorm"
)
type AuthMethod struct {
ID uint64 `gorm:"primary_key;autoIncrement:false"`
Name string `gorm:"type:varchar(64);unique_index"`
Type string
Issuer string
ClientId string
ClientSecret string
}
func (r *repository) SaveAuthMethod(ctx context.Context, m *AuthMethod) error {
tx := r.withContext(ctx).Save(m)
if tx.Error != nil {
return tx.Error
}
return nil
}
func (r *repository) ListAuthMethods(ctx context.Context) ([]AuthMethod, error) {
var authMethods = []AuthMethod{}
tx := r.withContext(ctx).Find(&authMethods)
if tx.Error != nil {
return nil, tx.Error
}
return authMethods, nil
}
func (r *repository) GetAuthMethod(ctx context.Context, id uint64) (*AuthMethod, error) {
var m AuthMethod
tx := r.withContext(ctx).First(&m, "id = ?", id)
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
if tx.Error != nil {
return nil, tx.Error
}
return &m, nil
}
+4
View File
@@ -61,6 +61,10 @@ func (m *Machine) HasTag(tag string) bool {
return false
}
func (m *Machine) HasUser(loginName string) bool {
return m.User.Name == loginName
}
func (m *Machine) HasTags() bool {
return len(m.Tags) != 0
}
+8
View File
@@ -14,6 +14,13 @@ type Repository interface {
GetDERPMap(ctx context.Context) (*tailcfg.DERPMap, error)
SetDERPMap(ctx context.Context, v *tailcfg.DERPMap) error
SaveAuthMethod(ctx context.Context, m *AuthMethod) error
ListAuthMethods(ctx context.Context) ([]AuthMethod, error)
GetAuthMethod(ctx context.Context, id uint64) (*AuthMethod, error)
GetAccount(ctx context.Context, accountID uint64) (*Account, error)
GetOrCreateAccount(ctx context.Context, authMethodID uint64, externalID, loginName string) (*Account, bool, error)
GetOrCreateTailnet(ctx context.Context, name string) (*Tailnet, bool, error)
GetTailnet(ctx context.Context, id uint64) (*Tailnet, error)
ListTailnets(ctx context.Context) ([]Tailnet, error)
@@ -35,6 +42,7 @@ type Repository interface {
GetOrCreateServiceUser(ctx context.Context, tailnet *Tailnet) (*User, bool, error)
ListUsers(ctx context.Context, tailnetID uint64) (Users, error)
GetOrCreateUserWithAccount(ctx context.Context, tailnet *Tailnet, account *Account) (*User, bool, error)
DeleteUsersByTailnet(ctx context.Context, tailnetID uint64) error
SaveMachine(ctx context.Context, m *Machine) error
+20
View File
@@ -9,6 +9,7 @@ type TailnetRole string
const (
TailnetRoleService TailnetRole = "service"
TailnetRoleMember TailnetRole = "member"
)
type User struct {
@@ -18,6 +19,9 @@ type User struct {
TailnetRole TailnetRole
TailnetID uint64
Tailnet Tailnet
AccountID *uint64
Account *Account
}
type Users []User
@@ -54,3 +58,19 @@ func (r *repository) DeleteUsersByTailnet(ctx context.Context, tailnetID uint64)
tx := r.withContext(ctx).Where("tailnet_id = ?", tailnetID).Delete(&User{})
return tx.Error
}
func (r *repository) GetOrCreateUserWithAccount(ctx context.Context, tailnet *Tailnet, account *Account) (*User, bool, error) {
user := &User{}
id := util.NextID()
query := User{AccountID: &account.ID, TailnetID: tailnet.ID}
attrs := User{ID: id, Name: account.LoginName, TailnetID: tailnet.ID, AccountID: &account.ID, TailnetRole: TailnetRoleMember}
tx := r.withContext(ctx).Where(query).Attrs(attrs).FirstOrCreate(user)
if tx.Error != nil {
return nil, false, tx.Error
}
return user, user.ID == id, nil
}
+221 -23
View File
@@ -1,8 +1,13 @@
package handlers
import (
"context"
"encoding/json"
"github.com/jsiebens/ionscale/internal/addr"
"github.com/jsiebens/ionscale/internal/provider"
"github.com/mr-tron/base58"
"net/http"
"strconv"
"time"
"github.com/jsiebens/ionscale/internal/config"
@@ -21,6 +26,7 @@ func NewAuthenticationHandlers(
config: config,
repository: repository,
pendingMachineRegistrationRequests: pendingMachineRegistrationRequests,
pendingOAuthUsers: cache.New(5*time.Minute, 10*time.Minute),
}
}
@@ -28,21 +34,118 @@ type AuthenticationHandlers struct {
repository domain.Repository
config *config.Config
pendingMachineRegistrationRequests *cache.Cache
pendingOAuthUsers *cache.Cache
}
type AuthFormData struct {
AuthMethods []domain.AuthMethod
}
type TailnetSelectionData struct {
Tailnets []domain.Tailnet
}
type oauthState struct {
Key string
AuthMethod uint64
}
func (h *AuthenticationHandlers) StartAuth(c echo.Context) error {
ctx := c.Request().Context()
key := c.Param("key")
if _, ok := h.pendingMachineRegistrationRequests.Get(key); !ok {
return c.Redirect(http.StatusFound, "/a/error")
}
methods, err := h.repository.ListAuthMethods(ctx)
if err != nil {
return err
}
return c.Render(http.StatusOK, "auth.html", &AuthFormData{AuthMethods: methods})
}
func (h *AuthenticationHandlers) ProcessAuth(c echo.Context) error {
ctx := c.Request().Context()
key := c.Param("key")
authKey := c.FormValue("ak")
authMethodId := c.FormValue("s")
if _, ok := h.pendingMachineRegistrationRequests.Get(key); !ok {
return c.Redirect(http.StatusFound, "/a/error")
}
if authKey != "" {
return h.endMachineRegistrationFlow(c, key, authKey)
return h.endMachineRegistrationFlow(c, &oauthState{Key: key})
}
return c.Render(http.StatusOK, "auth.html", nil)
if authMethodId != "" {
id, err := strconv.ParseUint(authMethodId, 10, 64)
if err != nil {
return err
}
method, err := h.repository.GetAuthMethod(ctx, id)
if err != nil {
return err
}
state, err := h.createState(key, method.ID)
if err != nil {
return err
}
authProvider, err := provider.NewProvider(method)
if err != nil {
return err
}
redirectUrl := authProvider.GetLoginURL(h.config.CreateUrl("/a/callback"), state)
return c.Redirect(http.StatusFound, redirectUrl)
}
return c.Redirect(http.StatusFound, "/a/"+key)
}
func (h *AuthenticationHandlers) Callback(c echo.Context) error {
ctx := c.Request().Context()
code := c.QueryParam("code")
state, err := h.readState(c.QueryParam("state"))
if err != nil {
return err
}
user, err := h.exchangeUser(ctx, code, state)
if err != nil {
return err
}
tailnets, err := h.repository.ListTailnets(ctx)
if err != nil {
return err
}
account, _, err := h.repository.GetOrCreateAccount(ctx, state.AuthMethod, user.ID, user.Name)
if err != nil {
return err
}
h.pendingOAuthUsers.Set(state.Key, account, cache.DefaultExpiration)
return c.Render(http.StatusOK, "tailnets.html", &TailnetSelectionData{Tailnets: tailnets})
}
func (h *AuthenticationHandlers) EndOAuth(c echo.Context) error {
state, err := h.readState(c.QueryParam("state"))
if err != nil {
return err
}
return h.endMachineRegistrationFlow(c, state)
}
func (h *AuthenticationHandlers) Success(c echo.Context) error {
@@ -58,36 +161,71 @@ func (h *AuthenticationHandlers) Error(c echo.Context) error {
return c.Render(http.StatusOK, "error.html", nil)
}
func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, registrationKey, authKeyParam string) error {
func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, state *oauthState) error {
ctx := c.Request().Context()
defer h.pendingMachineRegistrationRequests.Delete(registrationKey)
defer h.pendingMachineRegistrationRequests.Delete(state.Key)
preqItem, preqOK := h.pendingMachineRegistrationRequests.Get(registrationKey)
preqItem, preqOK := h.pendingMachineRegistrationRequests.Get(state.Key)
if !preqOK {
return c.Redirect(http.StatusFound, "/a/error")
}
authKeyParam := c.FormValue("ak")
tailnetIDParam := c.FormValue("s")
preq := preqItem.(*pendingMachineRegistrationRequest)
req := preq.request
machineKey := preq.machineKey
nodeKey := req.NodeKey.String()
authKey, err := h.repository.LoadAuthKey(ctx, authKeyParam)
if err != nil {
return err
}
var tailnet *domain.Tailnet
var user *domain.User
var ephemeral bool
var tags = []string{}
if authKey == nil {
return c.Redirect(http.StatusFound, "/a/error?e=iak")
}
if authKeyParam != "" {
authKey, err := h.repository.LoadAuthKey(ctx, authKeyParam)
if err != nil {
return err
}
tailnet := authKey.Tailnet
user := authKey.User
if authKey == nil {
return c.Redirect(http.StatusFound, "/a/error?e=iak")
}
tailnet = &authKey.Tailnet
user = &authKey.User
tags = authKey.Tags
ephemeral = authKey.Ephemeral
} else {
parseUint, err := strconv.ParseUint(tailnetIDParam, 10, 64)
if err != nil {
return err
}
tailnet, err = h.repository.GetTailnet(ctx, parseUint)
if err != nil {
return err
}
item, ok := h.pendingOAuthUsers.Get(state.Key)
if !ok {
return c.Redirect(http.StatusFound, "/a/error")
}
oa := item.(*domain.Account)
user, _, err = h.repository.GetOrCreateUserWithAccount(ctx, tailnet, oa)
if err != nil {
return err
}
ephemeral = false
}
var m *domain.Machine
m, err = h.repository.GetMachineByKey(ctx, tailnet.ID, machineKey)
m, err := h.repository.GetMachineByKey(ctx, tailnet.ID, machineKey)
if err != nil {
return err
}
@@ -95,10 +233,17 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
if m == nil {
now := time.Now().UTC()
registeredTags := authKey.Tags
registeredTags := tags
advertisedTags := domain.SanitizeTags(req.Hostinfo.RequestTags)
tags := append(registeredTags, advertisedTags...)
if len(tags) != 0 {
user, _, err = h.repository.GetOrCreateServiceUser(ctx, tailnet)
if err != nil {
return err
}
}
sanitizeHostname := dnsname.SanitizeHostname(req.Hostinfo.Hostname)
nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnet.ID, sanitizeHostname)
if err != nil {
@@ -111,13 +256,13 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
NameIdx: nameIdx,
MachineKey: machineKey,
NodeKey: nodeKey,
Ephemeral: authKey.Ephemeral,
Ephemeral: ephemeral,
RegisteredTags: registeredTags,
Tags: domain.SanitizeTags(tags),
CreatedAt: now,
User: user,
Tailnet: tailnet,
User: *user,
Tailnet: *tailnet,
}
if !req.Expiry.IsZero() {
@@ -131,10 +276,17 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
m.IPv4 = domain.IP{IP: ipv4}
m.IPv6 = domain.IP{IP: ipv6}
} else {
registeredTags := authKey.Tags
registeredTags := tags
advertisedTags := domain.SanitizeTags(req.Hostinfo.RequestTags)
tags := append(registeredTags, advertisedTags...)
if len(tags) != 0 {
user, _, err = h.repository.GetOrCreateServiceUser(ctx, tailnet)
if err != nil {
return err
}
}
sanitizeHostname := dnsname.SanitizeHostname(req.Hostinfo.Hostname)
if m.Name != sanitizeHostname {
nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnet.ID, sanitizeHostname)
@@ -145,13 +297,13 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
m.NameIdx = nameIdx
}
m.NodeKey = nodeKey
m.Ephemeral = authKey.Ephemeral
m.Ephemeral = ephemeral
m.RegisteredTags = registeredTags
m.Tags = domain.SanitizeTags(tags)
m.UserID = user.ID
m.User = user
m.User = *user
m.TailnetID = tailnet.ID
m.Tailnet = tailnet
m.Tailnet = *tailnet
m.ExpiresAt = nil
}
@@ -161,3 +313,49 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
return c.Redirect(http.StatusFound, "/a/success")
}
func (h *AuthenticationHandlers) getAuthProvider(ctx context.Context, authMethodId uint64) (provider.AuthProvider, error) {
authMethod, err := h.repository.GetAuthMethod(ctx, authMethodId)
if err != nil {
return nil, err
}
return provider.NewProvider(authMethod)
}
func (h *AuthenticationHandlers) exchangeUser(ctx context.Context, code string, state *oauthState) (*provider.User, error) {
redirectUrl := h.config.CreateUrl("/a/callback")
authProvider, err := h.getAuthProvider(ctx, state.AuthMethod)
if err != nil {
return nil, err
}
user, err := authProvider.Exchange(redirectUrl, code)
if err != nil {
return nil, err
}
return user, nil
}
func (h *AuthenticationHandlers) createState(key string, authMethodId uint64) (string, error) {
stateMap := oauthState{Key: key, AuthMethod: authMethodId}
marshal, err := json.Marshal(&stateMap)
if err != nil {
return "", err
}
return base58.FastBase58Encoding(marshal), nil
}
func (h *AuthenticationHandlers) readState(s string) (*oauthState, error) {
decodedState, err := base58.FastBase58Decoding(s)
if err != nil {
return nil, err
}
var state = &oauthState{}
if err := json.Unmarshal(decodedState, state); err != nil {
return nil, err
}
return state, nil
}
+130
View File
@@ -0,0 +1,130 @@
package provider
import (
"context"
"fmt"
"strings"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/jsiebens/ionscale/internal/domain"
"golang.org/x/oauth2"
)
type OIDCProvider struct {
clientID string
clientSecret string
provider *oidc.Provider
verifier *oidc.IDTokenVerifier
}
func NewOIDCProvider(c *domain.AuthMethod) (*OIDCProvider, error) {
provider, err := oidc.NewProvider(context.Background(), c.Issuer)
if err != nil {
return nil, err
}
verifier := provider.Verifier(&oidc.Config{ClientID: c.ClientId})
return &OIDCProvider{
clientID: c.ClientId,
clientSecret: c.ClientSecret,
provider: provider,
verifier: verifier,
}, nil
}
func (p *OIDCProvider) GetLoginURL(redirectURI, state string) string {
oauth2Config := oauth2.Config{
ClientID: p.clientID,
ClientSecret: p.clientSecret,
RedirectURL: redirectURI,
Endpoint: p.provider.Endpoint(),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
return oauth2Config.AuthCodeURL(state, oauth2.ApprovalForce)
}
func (p *OIDCProvider) Exchange(redirectURI, code string) (*User, error) {
oauth2Config := oauth2.Config{
ClientID: p.clientID,
ClientSecret: p.clientSecret,
RedirectURL: redirectURI,
Endpoint: p.provider.Endpoint(),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
oauth2Token, err := oauth2Config.Exchange(context.Background(), code)
if err != nil {
return nil, err
}
// Extract the ID Token from OAuth2 token.
rawIdToken, ok := oauth2Token.Extra("id_token").(string)
if !ok || strings.TrimSpace(rawIdToken) == "" {
return nil, fmt.Errorf("id_token missing")
}
// Parse and verify ID Token payload.
idToken, err := p.verifier.Verify(context.Background(), rawIdToken)
if err != nil {
return nil, err
}
sub, email, tokenClaims, err := p.getTokenClaims(idToken)
if err != nil {
return nil, err
}
userInfoClaims, err := p.getUserInfoClaims(oauth2Config, oauth2Token)
if err != nil {
return nil, err
}
return &User{
ID: sub,
Name: email,
Attr: map[string]interface{}{
"token": tokenClaims,
"userinfo": userInfoClaims,
},
}, nil
}
func (p *OIDCProvider) getTokenClaims(idToken *oidc.IDToken) (string, string, map[string]interface{}, error) {
var raw = make(map[string]interface{})
var claims struct {
Sub string `json:"sub"`
Email string `json:"email"`
}
// Extract default claims.
if err := idToken.Claims(&claims); err != nil {
return "", "", nil, fmt.Errorf("failed to parse id_token claims: %v", err)
}
// Extract raw claims.
if err := idToken.Claims(&raw); err != nil {
return "", "", nil, fmt.Errorf("failed to parse id_token claims: %v", err)
}
return claims.Sub, claims.Email, raw, nil
}
func (p *OIDCProvider) getUserInfoClaims(config oauth2.Config, token *oauth2.Token) (map[string]interface{}, error) {
var raw = make(map[string]interface{})
source := config.TokenSource(context.Background(), token)
info, err := p.provider.UserInfo(context.Background(), source)
if err != nil {
return nil, err
}
if err := info.Claims(&raw); err != nil {
return nil, fmt.Errorf("failed to parse user info claims: %v", err)
}
return raw, nil
}
+27
View File
@@ -0,0 +1,27 @@
package provider
import (
"fmt"
"github.com/jsiebens/ionscale/internal/domain"
)
type AuthProvider interface {
GetLoginURL(redirectURI, state string) string
Exchange(redirectURI, code string) (*User, error)
}
type User struct {
ID string
Name string
Attr map[string]interface{}
}
func NewProvider(m *domain.AuthMethod) (AuthProvider, error) {
switch m.Type {
case "oidc":
return NewOIDCProvider(m)
default:
return nil, fmt.Errorf("unknown auth method type")
}
}
+3 -1
View File
@@ -130,7 +130,9 @@ func Start(config *config.Config) error {
auth := tlsAppHandler.Group("/a")
auth.GET("/:key", authenticationHandlers.StartAuth)
auth.POST("/:key", authenticationHandlers.StartAuth)
auth.POST("/:key", authenticationHandlers.ProcessAuth)
auth.GET("/callback", authenticationHandlers.Callback)
auth.POST("/callback", authenticationHandlers.EndOAuth)
auth.GET("/success", authenticationHandlers.Success)
auth.GET("/error", authenticationHandlers.Error)
+5 -5
View File
@@ -2,8 +2,8 @@ package service
import (
"context"
"encoding/json"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/mapping"
"github.com/jsiebens/ionscale/pkg/gen/api"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@@ -15,12 +15,12 @@ func (s *Service) GetACLPolicy(ctx context.Context, req *api.GetACLPolicyRequest
return nil, err
}
var p api.Policy
if err := mapping.CopyViaJson(policy, &p); err != nil {
marshal, err := json.Marshal(policy)
if err != nil {
return nil, err
}
return &api.GetACLPolicyResponse{Policy: &p}, nil
return &api.GetACLPolicyResponse{Value: marshal}, nil
}
func (s *Service) SetACLPolicy(ctx context.Context, req *api.SetACLPolicyRequest) (*api.SetACLPolicyResponse, error) {
@@ -33,7 +33,7 @@ func (s *Service) SetACLPolicy(ctx context.Context, req *api.SetACLPolicyRequest
}
var policy domain.ACLPolicy
if err := mapping.CopyViaJson(req.Policy, &policy); err != nil {
if err := json.Unmarshal(req.Value, &policy); err != nil {
return nil, err
}
+51
View File
@@ -0,0 +1,51 @@
package service
import (
"context"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/util"
"github.com/jsiebens/ionscale/pkg/gen/api"
)
func (s *Service) CreateAuthMethod(ctx context.Context, req *api.CreateAuthMethodRequest) (*api.CreateAuthMethodResponse, error) {
authMethod := &domain.AuthMethod{
ID: util.NextID(),
Name: req.Name,
Type: req.Type,
Issuer: req.Issuer,
ClientId: req.ClientId,
ClientSecret: req.ClientSecret,
}
if err := s.repository.SaveAuthMethod(ctx, authMethod); err != nil {
return nil, err
}
return &api.CreateAuthMethodResponse{AuthMethod: &api.AuthMethod{
Id: authMethod.ID,
Type: authMethod.Type,
Name: authMethod.Name,
Issuer: authMethod.Issuer,
ClientId: authMethod.ClientId,
}}, nil
}
func (s *Service) ListAuthMethods(ctx context.Context, _ *api.ListAuthMethodsRequest) (*api.ListAuthMethodsResponse, error) {
methods, err := s.repository.ListAuthMethods(ctx)
if err != nil {
return nil, err
}
response := &api.ListAuthMethodsResponse{AuthMethods: []*api.AuthMethod{}}
for _, m := range methods {
response.AuthMethods = append(response.AuthMethods, &api.AuthMethod{
Id: m.ID,
Name: m.Name,
Type: m.Type,
})
}
return response, nil
}
+18
View File
@@ -74,10 +74,28 @@
</head>
<body>
<div class="wrapper">
{{if .AuthMethods}}
<div style="text-align: left; padding-bottom: 10px">
<p><b>Authentication required</b></p>
<small>Login with:</small>
</div>
<form method="post">
<ul class="selectionList">
{{range .AuthMethods}}
<li><button type="submit" name="s" value="{{.ID}}">{{.Name}}</button></li>
{{end}}
</ul>
</form>
<div style="text-align: left; padding-bottom: 10px; padding-top: 20px">
<small>Or enter an <label for="ak">auth key</label> here:</small>
</div>
{{end}}
{{if not .AuthMethods}}
<div style="text-align: left; padding-bottom: 10px">
<p><b>Authentication required</b></p>
<small>Enter an <label for="ak">auth key</label> here:</small>
</div>
{{end}}
<form method="post" style="text-align: right">
<p><input id="ak" name="ak" type="text"/></p>
<div style="padding-top: 10px">
+90
View File
@@ -0,0 +1,90 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@200;300;400;500;600;700&display=swap');
* {
margin: 0;
padding: 0;
box-sizing: border-box;
font-family: 'Poppins', sans-serif;
}
body {
width: 100%;
height: 100vh;
padding: 10px;
background: #379683;
}
.wrapper {
background: #fff;
max-width: 400px;
width: 100%;
margin: 120px auto;
padding: 25px;
border-radius: 5px;
box-shadow: 0 10px 15px rgba(0, 0, 0, 0.1);
}
.selectionList li {
position: relative;
list-style: none;
height: 45px;
line-height: 45px;
margin-bottom: 8px;
background: #f2f2f2;
border-radius: 3px;
overflow: hidden;
box-shadow: 0 2px 2px rgba(0, 0, 0, 0.1);
}
.selectionList {
padding-top: 5px
}
.selectionList li button {
margin: 0;
display: block;
width: 100%;
height: 100%;
border: none;
}
input {
display: block;
width: 100%;
height: 100%;
padding: 10px;
}
button {
padding-top: 10px;
padding-bottom: 10px;
padding-left: 20px;
padding-right: 20px;
height: 45px;
border: none;
}
</style>
<title>ionscale</title>
</head>
<body>
<div class="wrapper">
<div style="text-align: left; padding-bottom: 10px">
<p><b>Tailnets</b></p>
<small>Select your tailnet:</small>
</div>
<form method="post">
<ul class="selectionList">
{{range .Tailnets}}
<li><button type="submit" name="s" value="{{.ID}}">{{.Name}}</button></li>
{{end}}
</ul>
</form>
</div>
</body>
</html>