feat: add id token handler

This commit is contained in:
Johan Siebens
2022-09-27 16:13:17 +02:00
parent 7cadcc9085
commit 2e57338b54
8 changed files with 237 additions and 1 deletions
+2 -1
View File
@@ -10,6 +10,7 @@ require (
github.com/coreos/go-oidc/v3 v3.3.0 github.com/coreos/go-oidc/v3 v3.3.0
github.com/glebarez/sqlite v1.4.6 github.com/glebarez/sqlite v1.4.6
github.com/go-gormigrate/gormigrate/v2 v2.0.2 github.com/go-gormigrate/gormigrate/v2 v2.0.2
github.com/golang-jwt/jwt/v4 v4.4.2
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/hashicorp/go-bexpr v0.1.11 github.com/hashicorp/go-bexpr v0.1.11
github.com/hashicorp/go-hclog v1.3.0 github.com/hashicorp/go-hclog v1.3.0
@@ -40,6 +41,7 @@ require (
golang.org/x/oauth2 v0.0.0-20220822191816-0ebed06d0094 golang.org/x/oauth2 v0.0.0-20220822191816-0ebed06d0094
golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde
google.golang.org/protobuf v1.28.1 google.golang.org/protobuf v1.28.1
gopkg.in/square/go-jose.v2 v2.6.0
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.3.9 gorm.io/driver/postgres v1.3.9
@@ -140,7 +142,6 @@ require (
google.golang.org/genproto v0.0.0-20220616135557-88e70c0c3a90 // indirect google.golang.org/genproto v0.0.0-20220616135557-88e70c0c3a90 // indirect
google.golang.org/grpc v1.48.0 // indirect google.golang.org/grpc v1.48.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
modernc.org/libc v1.18.0 // indirect modernc.org/libc v1.18.0 // indirect
modernc.org/mathutil v1.5.0 // indirect modernc.org/mathutil v1.5.0 // indirect
modernc.org/memory v1.3.0 // indirect modernc.org/memory v1.3.0 // indirect
+2
View File
@@ -244,6 +244,8 @@ github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang-jwt/jwt/v4 v4.4.2 h1:rcc4lwaZgFMCZ5jxF9ABolDcIHdBytAFgqFPbSJQAYs=
github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188 h1:+eHOFJl1BaXrQxKX+T06f78590z4qA2ZzBTqahsKSE4= github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188 h1:+eHOFJl1BaXrQxKX+T06f78590z4qA2ZzBTqahsKSE4=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
+28
View File
@@ -8,6 +8,7 @@ import (
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/jsiebens/ionscale/internal/broker" "github.com/jsiebens/ionscale/internal/broker"
"github.com/jsiebens/ionscale/internal/database/migration" "github.com/jsiebens/ionscale/internal/database/migration"
"github.com/jsiebens/ionscale/internal/util"
"tailscale.com/types/key" "tailscale.com/types/key"
"time" "time"
@@ -85,6 +86,10 @@ func migrate(db *gorm.DB) error {
return err return err
} }
if err := createJSONWebKeySet(ctx, repository); err != nil {
return err
}
return nil return nil
} }
@@ -108,6 +113,29 @@ func createServerKey(ctx context.Context, repository domain.Repository) error {
return nil return nil
} }
func createJSONWebKeySet(ctx context.Context, repository domain.Repository) error {
jwks, err := repository.GetJSONWebKeySet(ctx)
if err != nil {
return err
}
if jwks != nil {
return nil
}
privateKey, id, err := util.NewPrivateKey()
if err != nil {
return err
}
jsonWebKey := domain.JSONWebKey{Id: id, PrivateKey: *privateKey}
if err := repository.SetJSONWebKeySet(ctx, &domain.JSONWebKeys{Key: jsonWebKey}); err != nil {
return err
}
return nil
}
type GormLoggerAdapter struct { type GormLoggerAdapter struct {
logger hclog.Logger logger hclog.Logger
} }
+3
View File
@@ -14,6 +14,9 @@ type Repository interface {
GetControlKeys(ctx context.Context) (*ControlKeys, error) GetControlKeys(ctx context.Context) (*ControlKeys, error)
SetControlKeys(ctx context.Context, keys *ControlKeys) error SetControlKeys(ctx context.Context, keys *ControlKeys) error
GetJSONWebKeySet(ctx context.Context) (*JSONWebKeys, error)
SetJSONWebKeySet(ctx context.Context, keys *JSONWebKeys) error
GetDERPMap(ctx context.Context) (*tailcfg.DERPMap, error) GetDERPMap(ctx context.Context) (*tailcfg.DERPMap, error)
SetDERPMap(ctx context.Context, v *tailcfg.DERPMap) error SetDERPMap(ctx context.Context, v *tailcfg.DERPMap) error
+37
View File
@@ -2,11 +2,14 @@ package domain
import ( import (
"context" "context"
"crypto"
"crypto/rsa"
"encoding/json" "encoding/json"
"errors" "errors"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
tkey "tailscale.com/types/key" tkey "tailscale.com/types/key"
"time"
) )
type configKey string type configKey string
@@ -14,8 +17,23 @@ type configKey string
const ( const (
derpMapConfigKey configKey = "derp_map" derpMapConfigKey configKey = "derp_map"
controlKeysConfigKey configKey = "control_keys" controlKeysConfigKey configKey = "control_keys"
jwksConfigKey configKey = "jwks"
) )
type JSONWebKeys struct {
Key JSONWebKey
}
type JSONWebKey struct {
Id string
PrivateKey rsa.PrivateKey
CreatedAt time.Time
}
func (j JSONWebKey) Public() crypto.PublicKey {
return j.PrivateKey.Public()
}
type ServerConfig struct { type ServerConfig struct {
Key configKey `gorm:"primary_key"` Key configKey `gorm:"primary_key"`
Value []byte Value []byte
@@ -45,6 +63,25 @@ func (r *repository) SetControlKeys(ctx context.Context, v *ControlKeys) error {
return r.setServerConfig(ctx, controlKeysConfigKey, v) return r.setServerConfig(ctx, controlKeysConfigKey, v)
} }
func (r *repository) GetJSONWebKeySet(ctx context.Context) (*JSONWebKeys, error) {
var m JSONWebKeys
err := r.getServerConfig(ctx, jwksConfigKey, &m)
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &m, nil
}
func (r *repository) SetJSONWebKeySet(ctx context.Context, v *JSONWebKeys) error {
return r.setServerConfig(ctx, jwksConfigKey, v)
}
func (r *repository) GetDERPMap(ctx context.Context) (*tailcfg.DERPMap, error) { func (r *repository) GetDERPMap(ctx context.Context) (*tailcfg.DERPMap, error) {
var m tailcfg.DERPMap var m tailcfg.DERPMap
+148
View File
@@ -0,0 +1,148 @@
package handlers
import (
"fmt"
"github.com/golang-jwt/jwt/v4"
"github.com/jsiebens/ionscale/internal/bind"
"github.com/jsiebens/ionscale/internal/config"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/util"
"github.com/labstack/echo/v4"
"gopkg.in/square/go-jose.v2"
"net/http"
"tailscale.com/tailcfg"
"time"
)
func NewIDTokenHandlers(createBinder bind.Factory, config *config.Config, repository domain.Repository) *IDTokenHandlers {
return &IDTokenHandlers{
issuer: config.ServerUrl,
jwksUri: config.CreateUrl("/.well-known/jwks"),
createBinder: createBinder,
repository: repository,
}
}
type IDTokenHandlers struct {
issuer string
jwksUri string
createBinder bind.Factory
repository domain.Repository
}
func (h *IDTokenHandlers) OpenIDConfig(c echo.Context) error {
v := map[string]interface{}{}
v["issuer"] = h.issuer
v["jwks_uri"] = h.jwksUri
v["subject_types_supported"] = []string{"public"}
v["response_types_supported"] = []string{"id_token"}
v["scopes_supported"] = []string{"openid"}
v["id_token_signing_alg_values_supported"] = []string{"RS256"}
v["claims_supported"] = []string{
"sub",
"aud",
"exp",
"iat",
"iss",
"jti",
"nbf",
}
return c.JSON(http.StatusOK, v)
}
func (h *IDTokenHandlers) Jwks(c echo.Context) error {
keySet, err := h.repository.GetJSONWebKeySet(c.Request().Context())
if err != nil {
return err
}
pub := jose.JSONWebKey{Key: keySet.Key.Public(), KeyID: keySet.Key.Id, Algorithm: "RS256", Use: "sig"}
set := jose.JSONWebKeySet{Keys: []jose.JSONWebKey{pub}}
return c.JSON(http.StatusOK, set)
}
func (h *IDTokenHandlers) FetchToken(c echo.Context) error {
ctx := c.Request().Context()
keySet, err := h.repository.GetJSONWebKeySet(c.Request().Context())
if err != nil {
return err
}
binder, err := h.createBinder(c)
if err != nil {
return err
}
req := &tailcfg.TokenRequest{}
if err := binder.BindRequest(c, req); err != nil {
return err
}
machineKey := binder.Peer().String()
nodeKey := req.NodeKey.String()
var m *domain.Machine
m, err = h.repository.GetMachineByKeys(ctx, machineKey, nodeKey)
if err != nil {
return err
}
if m == nil {
return echo.NewHTTPError(http.StatusBadRequest)
}
_, tailnetDomain, sub := h.names(m)
now := time.Now()
claims := jwt.MapClaims{
"jit": fmt.Sprintf("%d", util.NextID()),
"iss": h.issuer,
"sub": sub,
"aud": []string{req.Audience},
"exp": jwt.NewNumericDate(now.Add(5 * time.Minute)),
"nbf": jwt.NewNumericDate(now),
"iat": jwt.NewNumericDate(now),
"key": m.NodeKey,
"addresses": []string{m.IPv4.String(), m.IPv6.String()},
"nid": m.ID,
"node": sub,
"domain": tailnetDomain,
}
if m.HasTags() {
tags := []string{}
for _, t := range m.Tags {
tags = append(tags, fmt.Sprintf("%s:%s", tailnetDomain, t))
}
claims["tags"] = tags
} else {
claims["user"] = fmt.Sprintf("%s:%s", tailnetDomain, m.User.Name)
claims["uid"] = m.UserID
}
unsignedToken := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
unsignedToken.Header["kid"] = keySet.Key.Id
jwtB64, err := unsignedToken.SignedString(&keySet.Key.PrivateKey)
if err != nil {
return err
}
resp := tailcfg.TokenResponse{IDToken: jwtB64}
return binder.WriteResponse(c, http.StatusOK, resp)
}
func (h *IDTokenHandlers) names(m *domain.Machine) (string, string, string) {
var name = m.Name
if m.NameIdx != 0 {
name = fmt.Sprintf("%s-%d", m.Name, m.NameIdx)
}
sanitizedTailnetName := domain.SanitizeTailnetName(m.Tailnet.Name)
return name, sanitizedTailnetName, fmt.Sprintf("%s.%s", name, sanitizedTailnetName)
}
+5
View File
@@ -96,6 +96,7 @@ func Start(c *config.Config) error {
registrationHandlers := handlers.NewRegistrationHandlers(bind.DefaultBinder(p), c, brokers, repository) registrationHandlers := handlers.NewRegistrationHandlers(bind.DefaultBinder(p), c, brokers, repository)
pollNetMapHandler := handlers.NewPollNetMapHandler(bind.DefaultBinder(p), brokers, repository, offlineTimers) pollNetMapHandler := handlers.NewPollNetMapHandler(bind.DefaultBinder(p), brokers, repository, offlineTimers)
dnsHandlers := handlers.NewDNSHandlers(bind.DefaultBinder(p), dnsProvider) dnsHandlers := handlers.NewDNSHandlers(bind.DefaultBinder(p), dnsProvider)
idTokenHandlers := handlers.NewIDTokenHandlers(bind.DefaultBinder(p), c, repository)
e := echo.New() e := echo.New()
e.Use(EchoLogger(logger)) e.Use(EchoLogger(logger))
@@ -103,6 +104,7 @@ func Start(c *config.Config) error {
e.POST("/machine/register", registrationHandlers.Register) e.POST("/machine/register", registrationHandlers.Register)
e.POST("/machine/map", pollNetMapHandler.PollNetMap) e.POST("/machine/map", pollNetMapHandler.PollNetMap)
e.POST("/machine/set-dns", dnsHandlers.SetDNS) e.POST("/machine/set-dns", dnsHandlers.SetDNS)
e.POST("/machine/id-token", idTokenHandlers.FetchToken)
return e return e
} }
@@ -111,6 +113,7 @@ func Start(c *config.Config) error {
registrationHandlers := handlers.NewRegistrationHandlers(bind.BoxBinder(serverKey.LegacyControlKey), c, brokers, repository) registrationHandlers := handlers.NewRegistrationHandlers(bind.BoxBinder(serverKey.LegacyControlKey), c, brokers, repository)
pollNetMapHandler := handlers.NewPollNetMapHandler(bind.BoxBinder(serverKey.LegacyControlKey), brokers, repository, offlineTimers) pollNetMapHandler := handlers.NewPollNetMapHandler(bind.BoxBinder(serverKey.LegacyControlKey), brokers, repository, offlineTimers)
dnsHandlers := handlers.NewDNSHandlers(bind.BoxBinder(serverKey.LegacyControlKey), dnsProvider) dnsHandlers := handlers.NewDNSHandlers(bind.BoxBinder(serverKey.LegacyControlKey), dnsProvider)
idTokenHandlers := handlers.NewIDTokenHandlers(bind.BoxBinder(serverKey.LegacyControlKey), c, repository)
authenticationHandlers := handlers.NewAuthenticationHandlers( authenticationHandlers := handlers.NewAuthenticationHandlers(
c, c,
authProvider, authProvider,
@@ -149,6 +152,8 @@ func Start(c *config.Config) error {
tlsAppHandler.POST("/machine/:id", registrationHandlers.Register) tlsAppHandler.POST("/machine/:id", registrationHandlers.Register)
tlsAppHandler.POST("/machine/:id/map", pollNetMapHandler.PollNetMap) tlsAppHandler.POST("/machine/:id/map", pollNetMapHandler.PollNetMap)
tlsAppHandler.POST("/machine/:id/set-dns", dnsHandlers.SetDNS) tlsAppHandler.POST("/machine/:id/set-dns", dnsHandlers.SetDNS)
tlsAppHandler.GET("/.well-known/jwks", idTokenHandlers.Jwks)
tlsAppHandler.GET("/.well-known/openid-configuration", idTokenHandlers.OpenIDConfig)
auth := tlsAppHandler.Group("/a") auth := tlsAppHandler.Group("/a")
auth.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{ auth.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{
+12
View File
@@ -1,6 +1,7 @@
package util package util
import ( import (
"crypto/rsa"
"math/rand" "math/rand"
"time" "time"
) )
@@ -34,3 +35,14 @@ func RandomBytes(size int) ([]byte, error) {
} }
return buf, nil return buf, nil
} }
func NewPrivateKey() (*rsa.PrivateKey, string, error) {
id := RandStringBytes(22)
privateKey, err := rsa.GenerateKey(entropy, 2048)
if err != nil {
return nil, "", err
}
return privateKey, id, nil
}