mirror of
https://github.com/jsiebens/ionscale.git
synced 2026-03-31 15:07:49 +01:00
571 lines
15 KiB
Go
571 lines
15 KiB
Go
package handlers
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"github.com/jsiebens/ionscale/internal/addr"
|
|
"github.com/jsiebens/ionscale/internal/auth"
|
|
"github.com/labstack/echo/v4/middleware"
|
|
"github.com/mr-tron/base58"
|
|
"net/http"
|
|
"tailscale.com/tailcfg"
|
|
"time"
|
|
|
|
"github.com/jsiebens/ionscale/internal/config"
|
|
"github.com/jsiebens/ionscale/internal/domain"
|
|
"github.com/jsiebens/ionscale/internal/util"
|
|
"github.com/labstack/echo/v4"
|
|
"tailscale.com/util/dnsname"
|
|
)
|
|
|
|
func NewAuthenticationHandlers(
|
|
config *config.Config,
|
|
authProvider auth.Provider,
|
|
systemIAMPolicy *domain.IAMPolicy,
|
|
repository domain.Repository) *AuthenticationHandlers {
|
|
|
|
return &AuthenticationHandlers{
|
|
config: config,
|
|
authProvider: authProvider,
|
|
repository: repository,
|
|
systemIAMPolicy: systemIAMPolicy,
|
|
}
|
|
}
|
|
|
|
type AuthenticationHandlers struct {
|
|
repository domain.Repository
|
|
authProvider auth.Provider
|
|
config *config.Config
|
|
systemIAMPolicy *domain.IAMPolicy
|
|
}
|
|
|
|
// AuthFormData is the data passed to the "auth.html" template by StartAuth.
type AuthFormData struct {
	// ProviderAvailable is true when an auth provider is configured.
	ProviderAvailable bool
	// Csrf is the CSRF token read from the request context.
	Csrf string
}
|
|
|
|
type TailnetSelectionData struct {
|
|
AccountID uint64
|
|
Tailnets []domain.Tailnet
|
|
SystemAdmin bool
|
|
Csrf string
|
|
}
|
|
|
|
// oauthState is the value carried through the provider round trip in the
// "state" parameter. createState serializes it as base58-encoded JSON and
// readState decodes it again.
type oauthState struct {
	// Key identifies the pending request the flow belongs to.
	Key string
	// Flow is the flow code: "r" (machine registration), "c" (cli) or
	// "s" (ssh check).
	Flow string
}
|
|
|
|
func (h *AuthenticationHandlers) StartAuth(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
flow := c.Param("flow")
|
|
key := c.Param("key")
|
|
|
|
// machine registration auth flow
|
|
if flow == "r" || flow == "" {
|
|
if req, err := h.repository.GetRegistrationRequestByKey(ctx, key); err != nil || req == nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
|
|
csrf := c.Get(middleware.DefaultCSRFConfig.ContextKey).(string)
|
|
return c.Render(http.StatusOK, "auth.html", &AuthFormData{ProviderAvailable: h.authProvider != nil, Csrf: csrf})
|
|
}
|
|
|
|
// cli auth flow
|
|
if flow == "c" {
|
|
if s, err := h.repository.GetAuthenticationRequest(ctx, key); err != nil || s == nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
}
|
|
|
|
// ssh check auth flow
|
|
if flow == "s" {
|
|
if s, err := h.repository.GetSSHActionRequest(ctx, key); err != nil || s == nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
}
|
|
|
|
if h.authProvider == nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
|
|
state, err := h.createState(flow, key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
redirectUrl := h.authProvider.GetLoginURL(h.config.CreateUrl("/a/callback"), state)
|
|
|
|
return c.Redirect(http.StatusFound, redirectUrl)
|
|
}
|
|
|
|
func (h *AuthenticationHandlers) ProcessAuth(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
|
|
key := c.Param("key")
|
|
authKey := c.FormValue("ak")
|
|
interactive := c.FormValue("s")
|
|
|
|
req, err := h.repository.GetRegistrationRequestByKey(ctx, key)
|
|
if err != nil || req == nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
|
|
if authKey != "" {
|
|
return h.endMachineRegistrationFlow(c, req, &oauthState{Key: key})
|
|
}
|
|
|
|
if interactive != "" {
|
|
state, err := h.createState("r", key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
redirectUrl := h.authProvider.GetLoginURL(h.config.CreateUrl("/a/callback"), state)
|
|
|
|
return c.Redirect(http.StatusFound, redirectUrl)
|
|
}
|
|
|
|
return c.Redirect(http.StatusFound, "/a/"+key)
|
|
}
|
|
|
|
func (h *AuthenticationHandlers) Callback(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
|
|
code := c.QueryParam("code")
|
|
state, err := h.readState(c.QueryParam("state"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
user, err := h.exchangeUser(code)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
account, _, err := h.repository.GetOrCreateAccount(ctx, user.ID, user.Name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if state.Flow == "s" {
|
|
sshActionReq, err := h.repository.GetSSHActionRequest(ctx, state.Key)
|
|
if err != nil || sshActionReq == nil {
|
|
return c.Redirect(http.StatusFound, "/a/error?e=ua")
|
|
}
|
|
|
|
machine, err := h.repository.GetMachine(ctx, sshActionReq.SrcMachineID)
|
|
if err != nil || sshActionReq == nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
|
|
if !machine.HasTags() && machine.User.AccountID != nil && *machine.User.AccountID == account.ID {
|
|
sshActionReq.Action = "accept"
|
|
if err := h.repository.SaveSSHActionRequest(ctx, sshActionReq); err != nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
return c.Redirect(http.StatusFound, "/a/success")
|
|
}
|
|
|
|
sshActionReq.Action = "reject"
|
|
if err := h.repository.SaveSSHActionRequest(ctx, sshActionReq); err != nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
return c.Redirect(http.StatusFound, "/a/error?e=nmo")
|
|
}
|
|
|
|
tailnets, err := h.listAvailableTailnets(ctx, user)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
csrf := c.Get(middleware.DefaultCSRFConfig.ContextKey).(string)
|
|
|
|
if state.Flow == "r" {
|
|
if len(tailnets) == 0 {
|
|
registrationRequest, err := h.repository.GetRegistrationRequestByKey(ctx, state.Key)
|
|
if err == nil && registrationRequest != nil {
|
|
registrationRequest.Error = "unauthorized"
|
|
_ = h.repository.SaveRegistrationRequest(ctx, registrationRequest)
|
|
}
|
|
return c.Redirect(http.StatusFound, "/a/error?e=ua")
|
|
}
|
|
return c.Render(http.StatusOK, "tailnets.html", &TailnetSelectionData{
|
|
Csrf: csrf,
|
|
Tailnets: tailnets,
|
|
SystemAdmin: false,
|
|
AccountID: account.ID,
|
|
})
|
|
}
|
|
|
|
if state.Flow == "c" {
|
|
isSystemAdmin, err := h.isSystemAdmin(ctx, user)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !isSystemAdmin && len(tailnets) == 0 {
|
|
req, err := h.repository.GetAuthenticationRequest(ctx, state.Key)
|
|
if err == nil && req != nil {
|
|
req.Error = "unauthorized"
|
|
_ = h.repository.SaveAuthenticationRequest(ctx, req)
|
|
}
|
|
return c.Redirect(http.StatusFound, "/a/error?e=ua")
|
|
}
|
|
return c.Render(http.StatusOK, "tailnets.html", &TailnetSelectionData{
|
|
Csrf: csrf,
|
|
Tailnets: tailnets,
|
|
SystemAdmin: isSystemAdmin,
|
|
AccountID: account.ID,
|
|
})
|
|
}
|
|
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
|
|
func (h *AuthenticationHandlers) isSystemAdmin(ctx context.Context, u *auth.User) (bool, error) {
|
|
return h.systemIAMPolicy.EvaluatePolicy(&domain.Identity{UserID: u.ID, Email: u.Name, Attr: u.Attr})
|
|
}
|
|
|
|
func (h *AuthenticationHandlers) listAvailableTailnets(ctx context.Context, u *auth.User) ([]domain.Tailnet, error) {
|
|
var result = []domain.Tailnet{}
|
|
tailnets, err := h.repository.ListTailnets(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, t := range tailnets {
|
|
approved, err := t.IAMPolicy.EvaluatePolicy(&domain.Identity{UserID: u.ID, Email: u.Name, Attr: u.Attr})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if approved {
|
|
result = append(result, t)
|
|
}
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (h *AuthenticationHandlers) EndOAuth(c echo.Context) error {
|
|
ctx := c.Request().Context()
|
|
|
|
state, err := h.readState(c.QueryParam("state"))
|
|
if err != nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
|
|
if state.Flow == "r" {
|
|
req, err := h.repository.GetRegistrationRequestByKey(ctx, state.Key)
|
|
if err != nil || req == nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
|
|
return h.endMachineRegistrationFlow(c, req, state)
|
|
}
|
|
|
|
req, err := h.repository.GetAuthenticationRequest(ctx, state.Key)
|
|
if err != nil || req == nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
|
|
return h.endCliAuthenticationFlow(c, req, state)
|
|
}
|
|
|
|
func (h *AuthenticationHandlers) Success(c echo.Context) error {
|
|
s := c.QueryParam("s")
|
|
switch s {
|
|
case "nma":
|
|
return c.Render(http.StatusOK, "newmachine.html", nil)
|
|
}
|
|
return c.Render(http.StatusOK, "success.html", nil)
|
|
}
|
|
|
|
func (h *AuthenticationHandlers) Error(c echo.Context) error {
|
|
e := c.QueryParam("e")
|
|
switch e {
|
|
case "iak":
|
|
return c.Render(http.StatusForbidden, "invalidauthkey.html", nil)
|
|
case "ua":
|
|
return c.Render(http.StatusForbidden, "unauthorized.html", nil)
|
|
case "nto":
|
|
return c.Render(http.StatusForbidden, "notagowner.html", nil)
|
|
case "nmo":
|
|
return c.Render(http.StatusForbidden, "notmachineowner.html", nil)
|
|
}
|
|
return c.Render(http.StatusOK, "error.html", nil)
|
|
}
|
|
|
|
// TailnetSelectionForm is the form bound from the request when a flow is
// completed (see endCliAuthenticationFlow and endMachineRegistrationFlow).
type TailnetSelectionForm struct {
	// AccountID is the account to act for (form field "aid").
	AccountID uint64 `form:"aid"`
	// TailnetID is the selected tailnet (form field "tid").
	TailnetID uint64 `form:"tid"`
	// AsSystemAdmin requests a system API key instead of a tailnet-scoped
	// one (form field "sad").
	AsSystemAdmin bool `form:"sad"`
	// AuthKey is the auth key used for non-interactive machine registration
	// (form field "ak").
	AuthKey string `form:"ak"`
}
|
|
|
|
func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, req *domain.AuthenticationRequest, state *oauthState) error {
|
|
ctx := c.Request().Context()
|
|
|
|
var form TailnetSelectionForm
|
|
if err := c.Bind(&form); err != nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
|
|
account, err := h.repository.GetAccount(ctx, form.AccountID)
|
|
if err != nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
|
|
// continue as system admin?
|
|
if form.AsSystemAdmin {
|
|
expiresAt := time.Now().Add(24 * time.Hour)
|
|
token, apiKey := domain.CreateSystemApiKey(account, &expiresAt)
|
|
req.Token = token
|
|
|
|
err := h.repository.Transaction(func(rp domain.Repository) error {
|
|
if err := rp.SaveSystemApiKey(ctx, apiKey); err != nil {
|
|
return err
|
|
}
|
|
if err := rp.SaveAuthenticationRequest(ctx, req); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return c.Redirect(http.StatusFound, "/a/success")
|
|
}
|
|
|
|
tailnet, err := h.repository.GetTailnet(ctx, form.TailnetID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
user, _, err := h.repository.GetOrCreateUserWithAccount(ctx, tailnet, account)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
expiresAt := time.Now().Add(24 * time.Hour)
|
|
token, apiKey := domain.CreateApiKey(tailnet, user, &expiresAt)
|
|
req.Token = token
|
|
req.TailnetID = &tailnet.ID
|
|
|
|
err = h.repository.Transaction(func(rp domain.Repository) error {
|
|
if err := rp.SaveApiKey(ctx, apiKey); err != nil {
|
|
return err
|
|
}
|
|
if err := rp.SaveAuthenticationRequest(ctx, req); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return c.Redirect(http.StatusFound, "/a/success")
|
|
}
|
|
|
|
func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, registrationRequest *domain.RegistrationRequest, state *oauthState) error {
|
|
ctx := c.Request().Context()
|
|
|
|
var form TailnetSelectionForm
|
|
if err := c.Bind(&form); err != nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
|
|
req := tailcfg.RegisterRequest(registrationRequest.Data)
|
|
machineKey := registrationRequest.MachineKey
|
|
nodeKey := req.NodeKey.String()
|
|
|
|
var tailnet *domain.Tailnet
|
|
var user *domain.User
|
|
var ephemeral bool
|
|
var tags = []string{}
|
|
var authorized = false
|
|
|
|
if form.AuthKey != "" {
|
|
authKey, err := h.repository.LoadAuthKey(ctx, form.AuthKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if authKey == nil {
|
|
|
|
registrationRequest.Authenticated = false
|
|
registrationRequest.Error = "invalid auth key"
|
|
|
|
if err := h.repository.SaveRegistrationRequest(ctx, registrationRequest); err != nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
|
|
return c.Redirect(http.StatusFound, "/a/error?e=iak")
|
|
}
|
|
|
|
tailnet = &authKey.Tailnet
|
|
user = &authKey.User
|
|
tags = authKey.Tags
|
|
ephemeral = authKey.Ephemeral
|
|
authorized = authKey.PreAuthorized
|
|
} else {
|
|
selectedTailnet, err := h.repository.GetTailnet(ctx, form.TailnetID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
account, err := h.repository.GetAccount(ctx, form.AccountID)
|
|
if err != nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
|
|
selectedUser, _, err := h.repository.GetOrCreateUserWithAccount(ctx, selectedTailnet, account)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
user = selectedUser
|
|
tailnet = selectedTailnet
|
|
ephemeral = false
|
|
}
|
|
|
|
if err := tailnet.ACLPolicy.CheckTagOwners(registrationRequest.Data.Hostinfo.RequestTags, user); err != nil {
|
|
registrationRequest.Authenticated = false
|
|
registrationRequest.Error = err.Error()
|
|
if err := h.repository.SaveRegistrationRequest(ctx, registrationRequest); err != nil {
|
|
return c.Redirect(http.StatusFound, "/a/error")
|
|
}
|
|
return c.Redirect(http.StatusFound, "/a/error?e=nto")
|
|
}
|
|
|
|
autoAllowIPs := tailnet.ACLPolicy.FindAutoApprovedIPs(req.Hostinfo.RoutableIPs, tags, user)
|
|
|
|
var m *domain.Machine
|
|
|
|
m, err := h.repository.GetMachineByKey(ctx, tailnet.ID, machineKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
|
|
if m == nil {
|
|
registeredTags := tags
|
|
advertisedTags := domain.SanitizeTags(req.Hostinfo.RequestTags)
|
|
tags := append(registeredTags, advertisedTags...)
|
|
|
|
sanitizeHostname := dnsname.SanitizeHostname(req.Hostinfo.Hostname)
|
|
nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnet.ID, sanitizeHostname)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
m = &domain.Machine{
|
|
ID: util.NextID(),
|
|
Name: sanitizeHostname,
|
|
NameIdx: nameIdx,
|
|
MachineKey: machineKey,
|
|
NodeKey: nodeKey,
|
|
Ephemeral: ephemeral || req.Ephemeral,
|
|
RegisteredTags: registeredTags,
|
|
Tags: domain.SanitizeTags(tags),
|
|
AutoAllowIPs: autoAllowIPs,
|
|
CreatedAt: now,
|
|
ExpiresAt: now.Add(180 * 24 * time.Hour).UTC(),
|
|
KeyExpiryDisabled: len(tags) != 0,
|
|
Authorized: !tailnet.MachineAuthorizationEnabled || authorized,
|
|
|
|
User: *user,
|
|
Tailnet: *tailnet,
|
|
}
|
|
|
|
ipv4, ipv6, err := addr.SelectIP(checkIP(ctx, h.repository.CountMachinesWithIPv4))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
m.IPv4 = domain.IP{Addr: ipv4}
|
|
m.IPv6 = domain.IP{Addr: ipv6}
|
|
} else {
|
|
registeredTags := tags
|
|
advertisedTags := domain.SanitizeTags(req.Hostinfo.RequestTags)
|
|
tags := append(registeredTags, advertisedTags...)
|
|
|
|
sanitizeHostname := dnsname.SanitizeHostname(req.Hostinfo.Hostname)
|
|
if m.Name != sanitizeHostname {
|
|
nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnet.ID, sanitizeHostname)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
m.Name = sanitizeHostname
|
|
m.NameIdx = nameIdx
|
|
}
|
|
m.NodeKey = nodeKey
|
|
m.Ephemeral = ephemeral || req.Ephemeral
|
|
m.RegisteredTags = registeredTags
|
|
m.Tags = domain.SanitizeTags(tags)
|
|
m.AutoAllowIPs = autoAllowIPs
|
|
m.UserID = user.ID
|
|
m.User = *user
|
|
m.TailnetID = tailnet.ID
|
|
m.Tailnet = *tailnet
|
|
m.ExpiresAt = now.Add(180 * 24 * time.Hour).UTC()
|
|
}
|
|
|
|
err = h.repository.Transaction(func(rp domain.Repository) error {
|
|
registrationRequest.Authenticated = true
|
|
registrationRequest.Error = ""
|
|
|
|
if err := rp.SaveMachine(ctx, m); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := rp.SaveRegistrationRequest(ctx, registrationRequest); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if m.Authorized {
|
|
return c.Redirect(http.StatusFound, "/a/success")
|
|
} else {
|
|
return c.Redirect(http.StatusFound, "/a/success?s=nma")
|
|
}
|
|
}
|
|
|
|
func (h *AuthenticationHandlers) exchangeUser(code string) (*auth.User, error) {
|
|
redirectUrl := h.config.CreateUrl("/a/callback")
|
|
|
|
user, err := h.authProvider.Exchange(redirectUrl, code)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (h *AuthenticationHandlers) createState(flow string, key string) (string, error) {
|
|
stateMap := oauthState{Key: key, Flow: flow}
|
|
marshal, err := json.Marshal(&stateMap)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return base58.FastBase58Encoding(marshal), nil
|
|
}
|
|
|
|
func (h *AuthenticationHandlers) readState(s string) (*oauthState, error) {
|
|
decodedState, err := base58.FastBase58Decoding(s)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var state = &oauthState{}
|
|
if err := json.Unmarshal(decodedState, state); err != nil {
|
|
return nil, err
|
|
}
|
|
return state, nil
|
|
}
|