mirror of
https://github.com/jsiebens/ionscale.git
synced 2026-03-31 15:07:49 +01:00
chore: restructure test setup and add some initial web login flow tests
This commit is contained in:
@@ -45,6 +45,21 @@ type AuthFormData struct {
|
||||
Csrf string
|
||||
}
|
||||
|
||||
type AuthInput struct {
|
||||
Key string `param:"key"`
|
||||
Flow AuthFlow `param:"flow"`
|
||||
AuthKey string `query:"ak" form:"ak"`
|
||||
Oidc bool `query:"oidc" form:"oidc"`
|
||||
}
|
||||
|
||||
type EndAuthForm struct {
|
||||
AccountID uint64 `form:"aid"`
|
||||
TailnetID uint64 `form:"tid"`
|
||||
AsSystemAdmin bool `form:"sad"`
|
||||
AuthKey string `form:"ak"`
|
||||
State string `form:"state"`
|
||||
}
|
||||
|
||||
type TailnetSelectionData struct {
|
||||
AccountID uint64
|
||||
Tailnets []domain.Tailnet
|
||||
@@ -54,34 +69,54 @@ type TailnetSelectionData struct {
|
||||
|
||||
type oauthState struct {
|
||||
Key string
|
||||
Flow string
|
||||
Flow AuthFlow
|
||||
}
|
||||
|
||||
type AuthFlow string
|
||||
|
||||
const (
|
||||
AuthFlowMachineRegistration = "r"
|
||||
AuthFlowClient = "c"
|
||||
AuthFlowSSHCheckFlow = "s"
|
||||
)
|
||||
|
||||
func (h *AuthenticationHandlers) StartAuth(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
flow := c.Param("flow")
|
||||
key := c.Param("key")
|
||||
|
||||
var input AuthInput
|
||||
if err := c.Bind(&input); err != nil {
|
||||
return logError(err)
|
||||
}
|
||||
|
||||
// machine registration auth flow
|
||||
if flow == "r" || flow == "" {
|
||||
if req, err := h.repository.GetRegistrationRequestByKey(ctx, key); err != nil || req == nil {
|
||||
if input.Flow == AuthFlowMachineRegistration {
|
||||
req, err := h.repository.GetRegistrationRequestByKey(ctx, input.Key)
|
||||
if err != nil || req == nil {
|
||||
return logError(err)
|
||||
}
|
||||
|
||||
if input.Oidc && h.authProvider != nil {
|
||||
goto startOidc
|
||||
}
|
||||
|
||||
if input.AuthKey != "" {
|
||||
return h.endMachineRegistrationFlow(c, EndAuthForm{AuthKey: input.AuthKey}, req)
|
||||
}
|
||||
|
||||
csrf := c.Get(middleware.DefaultCSRFConfig.ContextKey).(string)
|
||||
return c.Render(http.StatusOK, "auth.html", &AuthFormData{ProviderAvailable: h.authProvider != nil, Csrf: csrf})
|
||||
}
|
||||
|
||||
// cli auth flow
|
||||
if flow == "c" {
|
||||
if s, err := h.repository.GetAuthenticationRequest(ctx, key); err != nil || s == nil {
|
||||
if input.Flow == AuthFlowClient {
|
||||
if s, err := h.repository.GetAuthenticationRequest(ctx, input.Key); err != nil || s == nil {
|
||||
return logError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// ssh check auth flow
|
||||
if flow == "s" {
|
||||
if s, err := h.repository.GetSSHActionRequest(ctx, key); err != nil || s == nil {
|
||||
if input.Flow == AuthFlowSSHCheckFlow {
|
||||
if s, err := h.repository.GetSSHActionRequest(ctx, input.Key); err != nil || s == nil {
|
||||
return logError(err)
|
||||
}
|
||||
}
|
||||
@@ -90,7 +125,9 @@ func (h *AuthenticationHandlers) StartAuth(c echo.Context) error {
|
||||
return logError(fmt.Errorf("unable to start auth flow as no auth provider is configured"))
|
||||
}
|
||||
|
||||
state, err := h.createState(flow, key)
|
||||
startOidc:
|
||||
|
||||
state, err := h.createState(input.Flow, input.Key)
|
||||
if err != nil {
|
||||
return logError(err)
|
||||
}
|
||||
@@ -103,21 +140,22 @@ func (h *AuthenticationHandlers) StartAuth(c echo.Context) error {
|
||||
func (h *AuthenticationHandlers) ProcessAuth(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
|
||||
key := c.Param("key")
|
||||
authKey := c.FormValue("ak")
|
||||
interactive := c.FormValue("s")
|
||||
var input AuthInput
|
||||
if err := c.Bind(&input); err != nil {
|
||||
return logError(err)
|
||||
}
|
||||
|
||||
req, err := h.repository.GetRegistrationRequestByKey(ctx, key)
|
||||
req, err := h.repository.GetRegistrationRequestByKey(ctx, input.Key)
|
||||
if err != nil || req == nil {
|
||||
return logError(err)
|
||||
}
|
||||
|
||||
if authKey != "" {
|
||||
return h.endMachineRegistrationFlow(c, req, &oauthState{Key: key})
|
||||
if input.AuthKey != "" {
|
||||
return h.endMachineRegistrationFlow(c, EndAuthForm{AuthKey: input.AuthKey}, req)
|
||||
}
|
||||
|
||||
if interactive != "" {
|
||||
state, err := h.createState("r", key)
|
||||
if input.Oidc {
|
||||
state, err := h.createState(input.Flow, input.Key)
|
||||
if err != nil {
|
||||
return logError(err)
|
||||
}
|
||||
@@ -127,7 +165,7 @@ func (h *AuthenticationHandlers) ProcessAuth(c echo.Context) error {
|
||||
return c.Redirect(http.StatusFound, redirectUrl)
|
||||
}
|
||||
|
||||
return c.Redirect(http.StatusFound, "/a/"+key)
|
||||
return c.Redirect(http.StatusFound, fmt.Sprintf("/a/%s/%s", input.Flow, input.Key))
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) Callback(c echo.Context) error {
|
||||
@@ -153,7 +191,7 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
|
||||
return logError(err)
|
||||
}
|
||||
|
||||
if state.Flow == "s" {
|
||||
if state.Flow == AuthFlowSSHCheckFlow {
|
||||
sshActionReq, err := h.repository.GetSSHActionRequest(ctx, state.Key)
|
||||
if err != nil || sshActionReq == nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error?e=ua")
|
||||
@@ -186,7 +224,7 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
|
||||
|
||||
csrf := c.Get(middleware.DefaultCSRFConfig.ContextKey).(string)
|
||||
|
||||
if state.Flow == "r" {
|
||||
if state.Flow == AuthFlowMachineRegistration {
|
||||
if len(tailnets) == 0 {
|
||||
registrationRequest, err := h.repository.GetRegistrationRequestByKey(ctx, state.Key)
|
||||
if err == nil && registrationRequest != nil {
|
||||
@@ -195,6 +233,18 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
|
||||
}
|
||||
return c.Redirect(http.StatusFound, "/a/error?e=ua")
|
||||
}
|
||||
|
||||
if len(tailnets) == 1 {
|
||||
req, err := h.repository.GetRegistrationRequestByKey(ctx, state.Key)
|
||||
if err != nil {
|
||||
return logError(err)
|
||||
}
|
||||
if req == nil {
|
||||
return logError(fmt.Errorf("invalid registration key"))
|
||||
}
|
||||
return h.endMachineRegistrationFlow(c, EndAuthForm{AccountID: account.ID, TailnetID: tailnets[0].ID}, req)
|
||||
}
|
||||
|
||||
return c.Render(http.StatusOK, "tailnets.html", &TailnetSelectionData{
|
||||
Csrf: csrf,
|
||||
Tailnets: tailnets,
|
||||
@@ -203,8 +253,8 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
|
||||
})
|
||||
}
|
||||
|
||||
if state.Flow == "c" {
|
||||
isSystemAdmin, err := h.isSystemAdmin(ctx, user)
|
||||
if state.Flow == AuthFlowClient {
|
||||
isSystemAdmin, err := h.isSystemAdmin(user)
|
||||
if err != nil {
|
||||
return logError(err)
|
||||
}
|
||||
@@ -228,51 +278,38 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusNotFound)
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) isSystemAdmin(ctx context.Context, u *auth.User) (bool, error) {
|
||||
return h.systemIAMPolicy.EvaluatePolicy(&domain.Identity{UserID: u.ID, Email: u.Name, Attr: u.Attr})
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) listAvailableTailnets(ctx context.Context, u *auth.User) ([]domain.Tailnet, error) {
|
||||
var result = []domain.Tailnet{}
|
||||
tailnets, err := h.repository.ListTailnets(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, t := range tailnets {
|
||||
approved, err := t.IAMPolicy.EvaluatePolicy(&domain.Identity{UserID: u.ID, Email: u.Name, Attr: u.Attr})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if approved {
|
||||
result = append(result, t)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) EndOAuth(c echo.Context) error {
|
||||
func (h *AuthenticationHandlers) EndAuth(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
|
||||
state, err := h.readState(c.QueryParam("state"))
|
||||
var form EndAuthForm
|
||||
if err := c.Bind(&form); err != nil {
|
||||
return logError(err)
|
||||
}
|
||||
|
||||
state, err := h.readState(form.State)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "Invalid state parameter")
|
||||
}
|
||||
|
||||
if state.Flow == "r" {
|
||||
if state.Flow == AuthFlowMachineRegistration {
|
||||
req, err := h.repository.GetRegistrationRequestByKey(ctx, state.Key)
|
||||
if err != nil || req == nil {
|
||||
return logError(err)
|
||||
}
|
||||
|
||||
return h.endMachineRegistrationFlow(c, req, state)
|
||||
return h.endMachineRegistrationFlow(c, form, req)
|
||||
}
|
||||
|
||||
req, err := h.repository.GetAuthenticationRequest(ctx, state.Key)
|
||||
if err != nil || req == nil {
|
||||
return logError(err)
|
||||
if state.Flow == AuthFlowClient {
|
||||
req, err := h.repository.GetAuthenticationRequest(ctx, state.Key)
|
||||
if err != nil || req == nil {
|
||||
return logError(err)
|
||||
}
|
||||
|
||||
return h.endCliAuthenticationFlow(c, form, req)
|
||||
}
|
||||
|
||||
return h.endCliAuthenticationFlow(c, req, state)
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "Invalid state parameter")
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) Success(c echo.Context) error {
|
||||
@@ -299,21 +336,9 @@ func (h *AuthenticationHandlers) Error(c echo.Context) error {
|
||||
return c.Render(http.StatusOK, "error.html", nil)
|
||||
}
|
||||
|
||||
type TailnetSelectionForm struct {
|
||||
AccountID uint64 `form:"aid"`
|
||||
TailnetID uint64 `form:"tid"`
|
||||
AsSystemAdmin bool `form:"sad"`
|
||||
AuthKey string `form:"ak"`
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, req *domain.AuthenticationRequest, state *oauthState) error {
|
||||
func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, form EndAuthForm, req *domain.AuthenticationRequest) error {
|
||||
ctx := c.Request().Context()
|
||||
|
||||
var form TailnetSelectionForm
|
||||
if err := c.Bind(&form); err != nil {
|
||||
return logError(err)
|
||||
}
|
||||
|
||||
account, err := h.repository.GetAccount(ctx, form.AccountID)
|
||||
if err != nil {
|
||||
return logError(err)
|
||||
@@ -371,14 +396,9 @@ func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, req *d
|
||||
return c.Redirect(http.StatusFound, "/a/success")
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, registrationRequest *domain.RegistrationRequest, state *oauthState) error {
|
||||
func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, form EndAuthForm, registrationRequest *domain.RegistrationRequest) error {
|
||||
ctx := c.Request().Context()
|
||||
|
||||
var form TailnetSelectionForm
|
||||
if err := c.Bind(&form); err != nil {
|
||||
return logError(err)
|
||||
}
|
||||
|
||||
req := tailcfg.RegisterRequest(registrationRequest.Data)
|
||||
machineKey := registrationRequest.MachineKey
|
||||
nodeKey := req.NodeKey.String()
|
||||
@@ -542,6 +562,28 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) isSystemAdmin(u *auth.User) (bool, error) {
|
||||
return h.systemIAMPolicy.EvaluatePolicy(&domain.Identity{UserID: u.ID, Email: u.Name, Attr: u.Attr})
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) listAvailableTailnets(ctx context.Context, u *auth.User) ([]domain.Tailnet, error) {
|
||||
var result = []domain.Tailnet{}
|
||||
tailnets, err := h.repository.ListTailnets(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, t := range tailnets {
|
||||
approved, err := t.IAMPolicy.EvaluatePolicy(&domain.Identity{UserID: u.ID, Email: u.Name, Attr: u.Attr})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if approved {
|
||||
result = append(result, t)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) exchangeUser(code string) (*auth.User, error) {
|
||||
redirectUrl := h.config.CreateUrl("/a/callback")
|
||||
|
||||
@@ -553,7 +595,7 @@ func (h *AuthenticationHandlers) exchangeUser(code string) (*auth.User, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) createState(flow string, key string) (string, error) {
|
||||
func (h *AuthenticationHandlers) createState(flow AuthFlow, key string) (string, error) {
|
||||
stateMap := oauthState{Key: key, Flow: flow}
|
||||
marshal, err := json.Marshal(&stateMap)
|
||||
if err != nil {
|
||||
|
||||
@@ -174,7 +174,7 @@ func Start(c *config.Config) error {
|
||||
auth.GET("/:flow/:key", authenticationHandlers.StartAuth)
|
||||
auth.POST("/:flow/:key", authenticationHandlers.ProcessAuth)
|
||||
auth.GET("/callback", authenticationHandlers.Callback)
|
||||
auth.POST("/callback", authenticationHandlers.EndOAuth)
|
||||
auth.POST("/callback", authenticationHandlers.EndAuth)
|
||||
auth.GET("/success", authenticationHandlers.Success)
|
||||
auth.GET("/error", authenticationHandlers.Error)
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@
|
||||
<form method="post">
|
||||
<input type="hidden" name="_csrf" value="{{.Csrf}}">
|
||||
<ul class="selectionList">
|
||||
<li><button type="submit" name="s" value="true">OpenID</button></li>
|
||||
<li><button type="submit" name="oidc" value="true">OpenID</button></li>
|
||||
</ul>
|
||||
</form>
|
||||
<div style="text-align: left; padding-bottom: 10px; padding-top: 20px">
|
||||
|
||||
Reference in New Issue
Block a user