Files
ionscale/internal/handlers/authentication.go
T
2022-11-04 14:13:19 +01:00

571 lines
15 KiB
Go

package handlers
import (
"context"
"encoding/json"
"github.com/jsiebens/ionscale/internal/addr"
"github.com/jsiebens/ionscale/internal/auth"
"github.com/labstack/echo/v4/middleware"
"github.com/mr-tron/base58"
"net/http"
"tailscale.com/tailcfg"
"time"
"github.com/jsiebens/ionscale/internal/config"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/util"
"github.com/labstack/echo/v4"
"tailscale.com/util/dnsname"
)
// NewAuthenticationHandlers builds an AuthenticationHandlers from its
// collaborators. The values are stored as given; authProvider may be nil, in
// which case the handlers treat the external login provider as unavailable.
func NewAuthenticationHandlers(
	config *config.Config,
	authProvider auth.Provider,
	systemIAMPolicy *domain.IAMPolicy,
	repository domain.Repository) *AuthenticationHandlers {

	handlers := new(AuthenticationHandlers)
	handlers.repository = repository
	handlers.authProvider = authProvider
	handlers.config = config
	handlers.systemIAMPolicy = systemIAMPolicy
	return handlers
}
// AuthenticationHandlers groups the HTTP handlers that drive the web based
// authentication flows (machine registration, CLI login and SSH check).
type AuthenticationHandlers struct {
	// repository is the persistence layer used to load and store requests,
	// accounts, tailnets, machines and API keys.
	repository domain.Repository
	// authProvider is the external login provider; the handlers check it
	// against nil before redirecting to it.
	authProvider auth.Provider
	// config is used to build absolute URLs such as the OAuth callback URL.
	config *config.Config
	// systemIAMPolicy is evaluated to decide whether a user is a system admin.
	systemIAMPolicy *domain.IAMPolicy
}
// AuthFormData is the view model passed to the "auth.html" template.
type AuthFormData struct {
	// ProviderAvailable reports whether an auth provider is configured.
	ProviderAvailable bool
	// Csrf is the CSRF token taken from the request context.
	Csrf string
}
// TailnetSelectionData is the view model passed to the "tailnets.html"
// template after a successful provider login.
type TailnetSelectionData struct {
	// AccountID is the ID of the account resolved for the logged-in user.
	AccountID uint64
	// Tailnets lists the tailnets whose IAM policy approved the user.
	Tailnets []domain.Tailnet
	// SystemAdmin is true when the system IAM policy approved the user
	// (only evaluated in the CLI flow; always false for registration).
	SystemAdmin bool
	// Csrf is the CSRF token taken from the request context.
	Csrf string
}
// oauthState is the payload carried through the provider round trip in the
// "state" parameter; it is JSON-encoded and base58-encoded by createState and
// decoded again by readState.
type oauthState struct {
	// Key identifies the pending request (registration, CLI or SSH check).
	Key string
	// Flow is the flow identifier: "r" (registration), "c" (CLI) or "s" (SSH check).
	Flow string
}
// StartAuth begins an authentication flow selected by the "flow" path
// parameter for the pending request identified by the "key" path parameter.
//
//   - "r" or empty: machine registration. The registration request must exist;
//     the "auth.html" form is rendered and no redirect to the provider happens
//     here.
//   - "c": CLI login. The authentication request must exist.
//   - "s": SSH check. The SSH action request must exist.
//
// For every flow other than registration the client is redirected to the auth
// provider's login URL with the flow and key encoded in the state parameter.
// A missing request or a missing provider redirects to "/a/error".
func (h *AuthenticationHandlers) StartAuth(c echo.Context) error {
	ctx := c.Request().Context()
	flow, key := c.Param("flow"), c.Param("key")

	switch flow {
	case "r", "":
		// machine registration auth flow
		pending, err := h.repository.GetRegistrationRequestByKey(ctx, key)
		if err != nil || pending == nil {
			return c.Redirect(http.StatusFound, "/a/error")
		}
		token := c.Get(middleware.DefaultCSRFConfig.ContextKey).(string)
		form := &AuthFormData{ProviderAvailable: h.authProvider != nil, Csrf: token}
		return c.Render(http.StatusOK, "auth.html", form)

	case "c":
		// cli auth flow
		pending, err := h.repository.GetAuthenticationRequest(ctx, key)
		if err != nil || pending == nil {
			return c.Redirect(http.StatusFound, "/a/error")
		}

	case "s":
		// ssh check auth flow
		pending, err := h.repository.GetSSHActionRequest(ctx, key)
		if err != nil || pending == nil {
			return c.Redirect(http.StatusFound, "/a/error")
		}
	}

	// Everything below needs an external provider to log in with.
	if h.authProvider == nil {
		return c.Redirect(http.StatusFound, "/a/error")
	}

	encoded, err := h.createState(flow, key)
	if err != nil {
		return err
	}

	callbackURL := h.config.CreateUrl("/a/callback")
	return c.Redirect(http.StatusFound, h.authProvider.GetLoginURL(callbackURL, encoded))
}
// ProcessAuth handles the submission of the machine registration form for the
// registration request identified by the "key" path parameter.
//
// Form values:
//   - "ak": an auth key; when non-empty the registration is completed
//     directly through endMachineRegistrationFlow.
//   - "s": when non-empty, the user asked for an interactive login and is
//     redirected to the auth provider.
//
// When neither value is present the client is sent back to "/a/<key>". An
// unknown registration request, or an interactive login without a configured
// provider, redirects to "/a/error".
func (h *AuthenticationHandlers) ProcessAuth(c echo.Context) error {
	ctx := c.Request().Context()

	key := c.Param("key")
	authKey := c.FormValue("ak")
	interactive := c.FormValue("s")

	req, err := h.repository.GetRegistrationRequestByKey(ctx, key)
	if err != nil || req == nil {
		return c.Redirect(http.StatusFound, "/a/error")
	}

	if authKey != "" {
		return h.endMachineRegistrationFlow(c, req, &oauthState{Key: key})
	}

	if interactive != "" {
		// The provider is optional (StartAuth renders the form even without
		// one), so guard against a nil provider instead of panicking.
		if h.authProvider == nil {
			return c.Redirect(http.StatusFound, "/a/error")
		}

		state, err := h.createState("r", key)
		if err != nil {
			return err
		}

		redirectUrl := h.authProvider.GetLoginURL(h.config.CreateUrl("/a/callback"), state)
		return c.Redirect(http.StatusFound, redirectUrl)
	}

	return c.Redirect(http.StatusFound, "/a/"+key)
}
// Callback is the endpoint the auth provider redirects back to. It decodes
// the "state" query parameter, exchanges the "code" query parameter for a
// user, resolves (or creates) the matching account and then continues
// according to the flow stored in the state:
//
//   - "s" (SSH check): the SSH action request is accepted only when the source
//     machine has no tags and belongs to the logged-in account; otherwise it
//     is rejected and the client is sent to "/a/error?e=nmo".
//   - "r" (registration): renders the tailnet selection page, or marks the
//     registration request as unauthorized when no tailnet approves the user.
//   - "c" (CLI): like registration, but a system admin may continue even
//     without any approved tailnet.
//
// Any other flow redirects to "/a/error". Errors from decoding the state,
// the code exchange, account lookup or policy evaluation are returned as-is.
func (h *AuthenticationHandlers) Callback(c echo.Context) error {
	ctx := c.Request().Context()

	code := c.QueryParam("code")
	state, err := h.readState(c.QueryParam("state"))
	if err != nil {
		return err
	}

	// The provider is optional; without one there is nothing to exchange the
	// code with, and calling into a nil provider would panic.
	if h.authProvider == nil {
		return c.Redirect(http.StatusFound, "/a/error")
	}

	user, err := h.exchangeUser(code)
	if err != nil {
		return err
	}

	account, _, err := h.repository.GetOrCreateAccount(ctx, user.ID, user.Name)
	if err != nil {
		return err
	}

	if state.Flow == "s" {
		sshActionReq, err := h.repository.GetSSHActionRequest(ctx, state.Key)
		if err != nil || sshActionReq == nil {
			return c.Redirect(http.StatusFound, "/a/error?e=ua")
		}

		// Check the machine itself for nil: it is dereferenced right below.
		machine, err := h.repository.GetMachine(ctx, sshActionReq.SrcMachineID)
		if err != nil || machine == nil {
			return c.Redirect(http.StatusFound, "/a/error")
		}

		// Only the owning account of an untagged machine may pass the check.
		if !machine.HasTags() && machine.User.AccountID != nil && *machine.User.AccountID == account.ID {
			sshActionReq.Action = "accept"
			if err := h.repository.SaveSSHActionRequest(ctx, sshActionReq); err != nil {
				return c.Redirect(http.StatusFound, "/a/error")
			}
			return c.Redirect(http.StatusFound, "/a/success")
		}

		sshActionReq.Action = "reject"
		if err := h.repository.SaveSSHActionRequest(ctx, sshActionReq); err != nil {
			return c.Redirect(http.StatusFound, "/a/error")
		}
		return c.Redirect(http.StatusFound, "/a/error?e=nmo")
	}

	tailnets, err := h.listAvailableTailnets(ctx, user)
	if err != nil {
		return err
	}

	csrf := c.Get(middleware.DefaultCSRFConfig.ContextKey).(string)

	if state.Flow == "r" {
		if len(tailnets) == 0 {
			// Best effort: record the failure on the request so the waiting
			// client can see it; the user gets the error page either way.
			registrationRequest, err := h.repository.GetRegistrationRequestByKey(ctx, state.Key)
			if err == nil && registrationRequest != nil {
				registrationRequest.Error = "unauthorized"
				_ = h.repository.SaveRegistrationRequest(ctx, registrationRequest)
			}
			return c.Redirect(http.StatusFound, "/a/error?e=ua")
		}
		return c.Render(http.StatusOK, "tailnets.html", &TailnetSelectionData{
			Csrf:        csrf,
			Tailnets:    tailnets,
			SystemAdmin: false,
			AccountID:   account.ID,
		})
	}

	if state.Flow == "c" {
		isSystemAdmin, err := h.isSystemAdmin(ctx, user)
		if err != nil {
			return err
		}

		if !isSystemAdmin && len(tailnets) == 0 {
			// Best effort, same as in the registration flow above.
			req, err := h.repository.GetAuthenticationRequest(ctx, state.Key)
			if err == nil && req != nil {
				req.Error = "unauthorized"
				_ = h.repository.SaveAuthenticationRequest(ctx, req)
			}
			return c.Redirect(http.StatusFound, "/a/error?e=ua")
		}
		return c.Render(http.StatusOK, "tailnets.html", &TailnetSelectionData{
			Csrf:        csrf,
			Tailnets:    tailnets,
			SystemAdmin: isSystemAdmin,
			AccountID:   account.ID,
		})
	}

	return c.Redirect(http.StatusFound, "/a/error")
}
// isSystemAdmin evaluates the system IAM policy for the given user and
// returns its verdict. ctx is accepted but not used by the evaluation.
func (h *AuthenticationHandlers) isSystemAdmin(ctx context.Context, u *auth.User) (bool, error) {
	identity := &domain.Identity{
		UserID: u.ID,
		Email:  u.Name,
		Attr:   u.Attr,
	}
	return h.systemIAMPolicy.EvaluatePolicy(identity)
}
// listAvailableTailnets returns every tailnet whose IAM policy approves the
// given user. The result is a non-nil (possibly empty) slice on success; the
// first listing or policy evaluation error aborts and is returned.
func (h *AuthenticationHandlers) listAvailableTailnets(ctx context.Context, u *auth.User) ([]domain.Tailnet, error) {
	available := make([]domain.Tailnet, 0)

	all, err := h.repository.ListTailnets(ctx)
	if err != nil {
		return nil, err
	}

	for _, candidate := range all {
		identity := &domain.Identity{UserID: u.ID, Email: u.Name, Attr: u.Attr}
		ok, err := candidate.IAMPolicy.EvaluatePolicy(identity)
		if err != nil {
			return nil, err
		}
		if !ok {
			continue
		}
		available = append(available, candidate)
	}

	return available, nil
}
// EndOAuth completes a flow after the tailnet selection form was submitted.
// The "state" query parameter decides what is finished: flow "r" completes a
// machine registration, every other flow value is handled as a CLI
// authentication. An undecodable state or an unknown request redirects to
// "/a/error".
func (h *AuthenticationHandlers) EndOAuth(c echo.Context) error {
	ctx := c.Request().Context()

	state, err := h.readState(c.QueryParam("state"))
	if err != nil {
		return c.Redirect(http.StatusFound, "/a/error")
	}

	switch state.Flow {
	case "r":
		registration, err := h.repository.GetRegistrationRequestByKey(ctx, state.Key)
		if err != nil || registration == nil {
			return c.Redirect(http.StatusFound, "/a/error")
		}
		return h.endMachineRegistrationFlow(c, registration, state)

	default:
		authentication, err := h.repository.GetAuthenticationRequest(ctx, state.Key)
		if err != nil || authentication == nil {
			return c.Redirect(http.StatusFound, "/a/error")
		}
		return h.endCliAuthenticationFlow(c, authentication, state)
	}
}
// Success renders the success page. When the "s" query parameter equals
// "nma" the "newmachine.html" page is rendered instead of "success.html".
func (h *AuthenticationHandlers) Success(c echo.Context) error {
	if c.QueryParam("s") == "nma" {
		return c.Render(http.StatusOK, "newmachine.html", nil)
	}
	return c.Render(http.StatusOK, "success.html", nil)
}
// Error renders an error page chosen by the "e" query parameter. The known
// codes ("iak", "ua", "nto", "nmo") map to dedicated pages served with
// 403 Forbidden; any other value renders the generic "error.html" with 200 OK.
func (h *AuthenticationHandlers) Error(c echo.Context) error {
	status, page := http.StatusOK, "error.html"

	switch c.QueryParam("e") {
	case "iak":
		status, page = http.StatusForbidden, "invalidauthkey.html"
	case "ua":
		status, page = http.StatusForbidden, "unauthorized.html"
	case "nto":
		status, page = http.StatusForbidden, "notagowner.html"
	case "nmo":
		status, page = http.StatusForbidden, "notmachineowner.html"
	}

	return c.Render(status, page, nil)
}
// TailnetSelectionForm holds the values bound from the submitted form when a
// flow is completed.
type TailnetSelectionForm struct {
	// AccountID is the account to act as (form field "aid").
	AccountID uint64 `form:"aid"`
	// TailnetID is the selected tailnet (form field "tid").
	TailnetID uint64 `form:"tid"`
	// AsSystemAdmin requests a system API key instead of a tailnet scoped one
	// (form field "sad").
	AsSystemAdmin bool `form:"sad"`
	// AuthKey is an auth key used for machine registration (form field "ak").
	AuthKey string `form:"ak"`
}
// endCliAuthenticationFlow finishes a CLI login: it binds the submitted
// TailnetSelectionForm, creates an API key valid for 24 hours and stores the
// key together with the updated authentication request in one transaction.
//
// When the form asks to continue as system admin a system API key is created;
// otherwise a key scoped to the selected tailnet and the account's user in
// that tailnet is created. A form that cannot be bound, or an account or
// tailnet that cannot be found, redirects to "/a/error"; repository errors
// for the tailnet, user and transaction are returned as-is. state is accepted
// for symmetry with the registration flow and is not used.
//
// NOTE(review): the "aid", "tid" and "sad" form values are used as submitted;
// this function does not itself re-check that the account is a system admin
// or is approved for the tailnet — confirm that is enforced elsewhere.
func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, req *domain.AuthenticationRequest, state *oauthState) error {
	ctx := c.Request().Context()

	var form TailnetSelectionForm
	if err := c.Bind(&form); err != nil {
		return c.Redirect(http.StatusFound, "/a/error")
	}

	// Lookups elsewhere in this file treat a nil result without an error as
	// "not found"; do the same here instead of dereferencing a nil account.
	account, err := h.repository.GetAccount(ctx, form.AccountID)
	if err != nil || account == nil {
		return c.Redirect(http.StatusFound, "/a/error")
	}

	// continue as system admin?
	if form.AsSystemAdmin {
		expiresAt := time.Now().Add(24 * time.Hour)
		token, apiKey := domain.CreateSystemApiKey(account, &expiresAt)
		req.Token = token

		err := h.repository.Transaction(func(rp domain.Repository) error {
			if err := rp.SaveSystemApiKey(ctx, apiKey); err != nil {
				return err
			}
			if err := rp.SaveAuthenticationRequest(ctx, req); err != nil {
				return err
			}
			return nil
		})
		if err != nil {
			return err
		}

		return c.Redirect(http.StatusFound, "/a/success")
	}

	tailnet, err := h.repository.GetTailnet(ctx, form.TailnetID)
	if err != nil {
		return err
	}
	if tailnet == nil {
		// Unknown tailnet ID: tailnet.ID below would panic.
		return c.Redirect(http.StatusFound, "/a/error")
	}

	user, _, err := h.repository.GetOrCreateUserWithAccount(ctx, tailnet, account)
	if err != nil {
		return err
	}

	expiresAt := time.Now().Add(24 * time.Hour)
	token, apiKey := domain.CreateApiKey(tailnet, user, &expiresAt)
	req.Token = token
	req.TailnetID = &tailnet.ID

	err = h.repository.Transaction(func(rp domain.Repository) error {
		if err := rp.SaveApiKey(ctx, apiKey); err != nil {
			return err
		}
		if err := rp.SaveAuthenticationRequest(ctx, req); err != nil {
			return err
		}
		return nil
	})
	if err != nil {
		return err
	}

	return c.Redirect(http.StatusFound, "/a/success")
}
// endMachineRegistrationFlow finishes a machine registration for the given
// registration request.
//
// The tailnet and user are taken either from the auth key submitted in the
// form ("ak") or, when no auth key is given, from the selected tailnet and
// account ("tid"/"aid"). The requested tags are checked against the tailnet's
// ACL tag owners. The machine identified by the request's machine key is then
// created (with a fresh ID, name index and IP addresses) or, if it already
// exists in the tailnet, updated in place. Machine and registration request
// are saved in one transaction.
//
// Redirects:
//   - "/a/error" when the form cannot be bound, the tailnet or account cannot
//     be found, or saving a failed request fails;
//   - "/a/error?e=iak" for an unknown auth key;
//   - "/a/error?e=nto" when the tag owner check fails;
//   - "/a/success" when the machine is authorized, "/a/success?s=nma"
//     otherwise.
//
// Other repository errors are returned as-is. state is not used.
func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, registrationRequest *domain.RegistrationRequest, state *oauthState) error {
	ctx := c.Request().Context()

	var form TailnetSelectionForm
	if err := c.Bind(&form); err != nil {
		return c.Redirect(http.StatusFound, "/a/error")
	}

	req := tailcfg.RegisterRequest(registrationRequest.Data)
	machineKey := registrationRequest.MachineKey
	nodeKey := req.NodeKey.String()

	var tailnet *domain.Tailnet
	var user *domain.User
	var ephemeral bool
	var tags = []string{}
	var authorized = false

	if form.AuthKey != "" {
		authKey, err := h.repository.LoadAuthKey(ctx, form.AuthKey)
		if err != nil {
			return err
		}

		if authKey == nil {
			registrationRequest.Authenticated = false
			registrationRequest.Error = "invalid auth key"

			if err := h.repository.SaveRegistrationRequest(ctx, registrationRequest); err != nil {
				return c.Redirect(http.StatusFound, "/a/error")
			}

			return c.Redirect(http.StatusFound, "/a/error?e=iak")
		}

		tailnet = &authKey.Tailnet
		user = &authKey.User
		tags = authKey.Tags
		ephemeral = authKey.Ephemeral
		authorized = authKey.PreAuthorized
	} else {
		selectedTailnet, err := h.repository.GetTailnet(ctx, form.TailnetID)
		if err != nil {
			return err
		}
		if selectedTailnet == nil {
			// Unknown tailnet ID: tailnet.ACLPolicy below would panic.
			return c.Redirect(http.StatusFound, "/a/error")
		}

		account, err := h.repository.GetAccount(ctx, form.AccountID)
		if err != nil || account == nil {
			return c.Redirect(http.StatusFound, "/a/error")
		}

		selectedUser, _, err := h.repository.GetOrCreateUserWithAccount(ctx, selectedTailnet, account)
		if err != nil {
			return err
		}

		user = selectedUser
		tailnet = selectedTailnet
		ephemeral = false
	}

	if err := tailnet.ACLPolicy.CheckTagOwners(registrationRequest.Data.Hostinfo.RequestTags, user); err != nil {
		registrationRequest.Authenticated = false
		registrationRequest.Error = err.Error()

		if err := h.repository.SaveRegistrationRequest(ctx, registrationRequest); err != nil {
			return c.Redirect(http.StatusFound, "/a/error")
		}

		return c.Redirect(http.StatusFound, "/a/error?e=nto")
	}

	autoAllowIPs := tailnet.ACLPolicy.FindAutoApprovedIPs(req.Hostinfo.RoutableIPs, tags, user)

	m, err := h.repository.GetMachineByKey(ctx, tailnet.ID, machineKey)
	if err != nil {
		return err
	}

	now := time.Now().UTC()

	// Values shared by the "new machine" and "existing machine" branches.
	registeredTags := tags
	advertisedTags := domain.SanitizeTags(req.Hostinfo.RequestTags)
	// The full slice expression caps the capacity so append never writes into
	// the backing array of the auth key's tag slice.
	allTags := append(registeredTags[:len(registeredTags):len(registeredTags)], advertisedTags...)
	sanitizeHostname := dnsname.SanitizeHostname(req.Hostinfo.Hostname)

	if m == nil {
		nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnet.ID, sanitizeHostname)
		if err != nil {
			return err
		}

		m = &domain.Machine{
			ID:             util.NextID(),
			Name:           sanitizeHostname,
			NameIdx:        nameIdx,
			MachineKey:     machineKey,
			NodeKey:        nodeKey,
			Ephemeral:      ephemeral || req.Ephemeral,
			RegisteredTags: registeredTags,
			Tags:           domain.SanitizeTags(allTags),
			AutoAllowIPs:   autoAllowIPs,
			CreatedAt:      now,
			ExpiresAt:      now.Add(180 * 24 * time.Hour).UTC(),
			// Machines that carry any tag get key expiry disabled.
			KeyExpiryDisabled: len(allTags) != 0,
			Authorized:        !tailnet.MachineAuthorizationEnabled || authorized,

			User:    *user,
			Tailnet: *tailnet,
		}

		ipv4, ipv6, err := addr.SelectIP(checkIP(ctx, h.repository.CountMachinesWithIPv4))
		if err != nil {
			return err
		}
		m.IPv4 = domain.IP{Addr: ipv4}
		m.IPv6 = domain.IP{Addr: ipv6}
	} else {
		// Only allocate a new name index when the hostname actually changed.
		if m.Name != sanitizeHostname {
			nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnet.ID, sanitizeHostname)
			if err != nil {
				return err
			}
			m.Name = sanitizeHostname
			m.NameIdx = nameIdx
		}
		m.NodeKey = nodeKey
		m.Ephemeral = ephemeral || req.Ephemeral
		m.RegisteredTags = registeredTags
		m.Tags = domain.SanitizeTags(allTags)
		m.AutoAllowIPs = autoAllowIPs
		m.UserID = user.ID
		m.User = *user
		m.TailnetID = tailnet.ID
		m.Tailnet = *tailnet
		m.ExpiresAt = now.Add(180 * 24 * time.Hour).UTC()
	}

	err = h.repository.Transaction(func(rp domain.Repository) error {
		registrationRequest.Authenticated = true
		registrationRequest.Error = ""

		if err := rp.SaveMachine(ctx, m); err != nil {
			return err
		}

		if err := rp.SaveRegistrationRequest(ctx, registrationRequest); err != nil {
			return err
		}

		return nil
	})
	if err != nil {
		return err
	}

	if m.Authorized {
		return c.Redirect(http.StatusFound, "/a/success")
	}
	return c.Redirect(http.StatusFound, "/a/success?s=nma")
}
// exchangeUser trades the authorization code received on the callback for
// the provider's user, using the "/a/callback" URL as redirect URL. On
// failure a nil user and the provider's error are returned.
func (h *AuthenticationHandlers) exchangeUser(code string) (*auth.User, error) {
	exchanged, err := h.authProvider.Exchange(h.config.CreateUrl("/a/callback"), code)
	if err != nil {
		return nil, err
	}
	return exchanged, nil
}
// createState packs the flow identifier and request key into an oauthState,
// serializes it as JSON and returns the base58 encoding of that JSON.
func (h *AuthenticationHandlers) createState(flow string, key string) (string, error) {
	payload, err := json.Marshal(&oauthState{Flow: flow, Key: key})
	if err != nil {
		return "", err
	}

	encoded := base58.FastBase58Encoding(payload)
	return encoded, nil
}
// readState is the inverse of createState: it base58-decodes s and
// unmarshals the resulting JSON into an oauthState. A decoding or
// unmarshalling failure yields a nil state and the error.
func (h *AuthenticationHandlers) readState(s string) (*oauthState, error) {
	raw, err := base58.FastBase58Decoding(s)
	if err != nil {
		return nil, err
	}

	parsed := new(oauthState)
	if err = json.Unmarshal(raw, parsed); err != nil {
		return nil, err
	}

	return parsed, nil
}