mirror of
https://github.com/jsiebens/ionscale.git
synced 2026-03-31 15:07:49 +01:00
feat: add id token handler
This commit is contained in:
@@ -10,6 +10,7 @@ require (
|
||||
github.com/coreos/go-oidc/v3 v3.3.0
|
||||
github.com/glebarez/sqlite v1.4.6
|
||||
github.com/go-gormigrate/gormigrate/v2 v2.0.2
|
||||
github.com/golang-jwt/jwt/v4 v4.4.2
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/hashicorp/go-bexpr v0.1.11
|
||||
github.com/hashicorp/go-hclog v1.3.0
|
||||
@@ -40,6 +41,7 @@ require (
|
||||
golang.org/x/oauth2 v0.0.0-20220822191816-0ebed06d0094
|
||||
golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde
|
||||
google.golang.org/protobuf v1.28.1
|
||||
gopkg.in/square/go-jose.v2 v2.6.0
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/postgres v1.3.9
|
||||
@@ -140,7 +142,6 @@ require (
|
||||
google.golang.org/genproto v0.0.0-20220616135557-88e70c0c3a90 // indirect
|
||||
google.golang.org/grpc v1.48.0 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||
modernc.org/libc v1.18.0 // indirect
|
||||
modernc.org/mathutil v1.5.0 // indirect
|
||||
modernc.org/memory v1.3.0 // indirect
|
||||
|
||||
@@ -244,6 +244,8 @@ github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx
|
||||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/golang-jwt/jwt/v4 v4.4.2 h1:rcc4lwaZgFMCZ5jxF9ABolDcIHdBytAFgqFPbSJQAYs=
|
||||
github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
|
||||
github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188 h1:+eHOFJl1BaXrQxKX+T06f78590z4qA2ZzBTqahsKSE4=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/jsiebens/ionscale/internal/broker"
|
||||
"github.com/jsiebens/ionscale/internal/database/migration"
|
||||
"github.com/jsiebens/ionscale/internal/util"
|
||||
"tailscale.com/types/key"
|
||||
"time"
|
||||
|
||||
@@ -85,6 +86,10 @@ func migrate(db *gorm.DB) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := createJSONWebKeySet(ctx, repository); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -108,6 +113,29 @@ func createServerKey(ctx context.Context, repository domain.Repository) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func createJSONWebKeySet(ctx context.Context, repository domain.Repository) error {
|
||||
jwks, err := repository.GetJSONWebKeySet(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if jwks != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
privateKey, id, err := util.NewPrivateKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
jsonWebKey := domain.JSONWebKey{Id: id, PrivateKey: *privateKey}
|
||||
|
||||
if err := repository.SetJSONWebKeySet(ctx, &domain.JSONWebKeys{Key: jsonWebKey}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type GormLoggerAdapter struct {
|
||||
logger hclog.Logger
|
||||
}
|
||||
|
||||
@@ -14,6 +14,9 @@ type Repository interface {
|
||||
GetControlKeys(ctx context.Context) (*ControlKeys, error)
|
||||
SetControlKeys(ctx context.Context, keys *ControlKeys) error
|
||||
|
||||
GetJSONWebKeySet(ctx context.Context) (*JSONWebKeys, error)
|
||||
SetJSONWebKeySet(ctx context.Context, keys *JSONWebKeys) error
|
||||
|
||||
GetDERPMap(ctx context.Context) (*tailcfg.DERPMap, error)
|
||||
SetDERPMap(ctx context.Context, v *tailcfg.DERPMap) error
|
||||
|
||||
|
||||
@@ -2,11 +2,14 @@ package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
tkey "tailscale.com/types/key"
|
||||
"time"
|
||||
)
|
||||
|
||||
type configKey string
|
||||
@@ -14,8 +17,23 @@ type configKey string
|
||||
const (
|
||||
derpMapConfigKey configKey = "derp_map"
|
||||
controlKeysConfigKey configKey = "control_keys"
|
||||
jwksConfigKey configKey = "jwks"
|
||||
)
|
||||
|
||||
type JSONWebKeys struct {
|
||||
Key JSONWebKey
|
||||
}
|
||||
|
||||
type JSONWebKey struct {
|
||||
Id string
|
||||
PrivateKey rsa.PrivateKey
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
func (j JSONWebKey) Public() crypto.PublicKey {
|
||||
return j.PrivateKey.Public()
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Key configKey `gorm:"primary_key"`
|
||||
Value []byte
|
||||
@@ -45,6 +63,25 @@ func (r *repository) SetControlKeys(ctx context.Context, v *ControlKeys) error {
|
||||
return r.setServerConfig(ctx, controlKeysConfigKey, v)
|
||||
}
|
||||
|
||||
func (r *repository) GetJSONWebKeySet(ctx context.Context) (*JSONWebKeys, error) {
|
||||
var m JSONWebKeys
|
||||
err := r.getServerConfig(ctx, jwksConfigKey, &m)
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func (r *repository) SetJSONWebKeySet(ctx context.Context, v *JSONWebKeys) error {
|
||||
return r.setServerConfig(ctx, jwksConfigKey, v)
|
||||
}
|
||||
|
||||
func (r *repository) GetDERPMap(ctx context.Context) (*tailcfg.DERPMap, error) {
|
||||
var m tailcfg.DERPMap
|
||||
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/jsiebens/ionscale/internal/bind"
|
||||
"github.com/jsiebens/ionscale/internal/config"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/util"
|
||||
"github.com/labstack/echo/v4"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
"net/http"
|
||||
"tailscale.com/tailcfg"
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewIDTokenHandlers(createBinder bind.Factory, config *config.Config, repository domain.Repository) *IDTokenHandlers {
|
||||
return &IDTokenHandlers{
|
||||
issuer: config.ServerUrl,
|
||||
jwksUri: config.CreateUrl("/.well-known/jwks"),
|
||||
createBinder: createBinder,
|
||||
repository: repository,
|
||||
}
|
||||
}
|
||||
|
||||
type IDTokenHandlers struct {
|
||||
issuer string
|
||||
jwksUri string
|
||||
createBinder bind.Factory
|
||||
repository domain.Repository
|
||||
}
|
||||
|
||||
func (h *IDTokenHandlers) OpenIDConfig(c echo.Context) error {
|
||||
v := map[string]interface{}{}
|
||||
|
||||
v["issuer"] = h.issuer
|
||||
v["jwks_uri"] = h.jwksUri
|
||||
v["subject_types_supported"] = []string{"public"}
|
||||
v["response_types_supported"] = []string{"id_token"}
|
||||
v["scopes_supported"] = []string{"openid"}
|
||||
v["id_token_signing_alg_values_supported"] = []string{"RS256"}
|
||||
v["claims_supported"] = []string{
|
||||
"sub",
|
||||
"aud",
|
||||
"exp",
|
||||
"iat",
|
||||
"iss",
|
||||
"jti",
|
||||
"nbf",
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, v)
|
||||
}
|
||||
|
||||
func (h *IDTokenHandlers) Jwks(c echo.Context) error {
|
||||
keySet, err := h.repository.GetJSONWebKeySet(c.Request().Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pub := jose.JSONWebKey{Key: keySet.Key.Public(), KeyID: keySet.Key.Id, Algorithm: "RS256", Use: "sig"}
|
||||
set := jose.JSONWebKeySet{Keys: []jose.JSONWebKey{pub}}
|
||||
return c.JSON(http.StatusOK, set)
|
||||
}
|
||||
|
||||
func (h *IDTokenHandlers) FetchToken(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
|
||||
keySet, err := h.repository.GetJSONWebKeySet(c.Request().Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
binder, err := h.createBinder(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := &tailcfg.TokenRequest{}
|
||||
if err := binder.BindRequest(c, req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
machineKey := binder.Peer().String()
|
||||
nodeKey := req.NodeKey.String()
|
||||
|
||||
var m *domain.Machine
|
||||
m, err = h.repository.GetMachineByKeys(ctx, machineKey, nodeKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
_, tailnetDomain, sub := h.names(m)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"jit": fmt.Sprintf("%d", util.NextID()),
|
||||
"iss": h.issuer,
|
||||
"sub": sub,
|
||||
"aud": []string{req.Audience},
|
||||
"exp": jwt.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
"nbf": jwt.NewNumericDate(now),
|
||||
"iat": jwt.NewNumericDate(now),
|
||||
|
||||
"key": m.NodeKey,
|
||||
"addresses": []string{m.IPv4.String(), m.IPv6.String()},
|
||||
"nid": m.ID,
|
||||
"node": sub,
|
||||
"domain": tailnetDomain,
|
||||
}
|
||||
|
||||
if m.HasTags() {
|
||||
tags := []string{}
|
||||
for _, t := range m.Tags {
|
||||
tags = append(tags, fmt.Sprintf("%s:%s", tailnetDomain, t))
|
||||
}
|
||||
claims["tags"] = tags
|
||||
} else {
|
||||
claims["user"] = fmt.Sprintf("%s:%s", tailnetDomain, m.User.Name)
|
||||
claims["uid"] = m.UserID
|
||||
}
|
||||
|
||||
unsignedToken := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
unsignedToken.Header["kid"] = keySet.Key.Id
|
||||
|
||||
jwtB64, err := unsignedToken.SignedString(&keySet.Key.PrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp := tailcfg.TokenResponse{IDToken: jwtB64}
|
||||
return binder.WriteResponse(c, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (h *IDTokenHandlers) names(m *domain.Machine) (string, string, string) {
|
||||
var name = m.Name
|
||||
if m.NameIdx != 0 {
|
||||
name = fmt.Sprintf("%s-%d", m.Name, m.NameIdx)
|
||||
}
|
||||
|
||||
sanitizedTailnetName := domain.SanitizeTailnetName(m.Tailnet.Name)
|
||||
return name, sanitizedTailnetName, fmt.Sprintf("%s.%s", name, sanitizedTailnetName)
|
||||
}
|
||||
@@ -96,6 +96,7 @@ func Start(c *config.Config) error {
|
||||
registrationHandlers := handlers.NewRegistrationHandlers(bind.DefaultBinder(p), c, brokers, repository)
|
||||
pollNetMapHandler := handlers.NewPollNetMapHandler(bind.DefaultBinder(p), brokers, repository, offlineTimers)
|
||||
dnsHandlers := handlers.NewDNSHandlers(bind.DefaultBinder(p), dnsProvider)
|
||||
idTokenHandlers := handlers.NewIDTokenHandlers(bind.DefaultBinder(p), c, repository)
|
||||
|
||||
e := echo.New()
|
||||
e.Use(EchoLogger(logger))
|
||||
@@ -103,6 +104,7 @@ func Start(c *config.Config) error {
|
||||
e.POST("/machine/register", registrationHandlers.Register)
|
||||
e.POST("/machine/map", pollNetMapHandler.PollNetMap)
|
||||
e.POST("/machine/set-dns", dnsHandlers.SetDNS)
|
||||
e.POST("/machine/id-token", idTokenHandlers.FetchToken)
|
||||
|
||||
return e
|
||||
}
|
||||
@@ -111,6 +113,7 @@ func Start(c *config.Config) error {
|
||||
registrationHandlers := handlers.NewRegistrationHandlers(bind.BoxBinder(serverKey.LegacyControlKey), c, brokers, repository)
|
||||
pollNetMapHandler := handlers.NewPollNetMapHandler(bind.BoxBinder(serverKey.LegacyControlKey), brokers, repository, offlineTimers)
|
||||
dnsHandlers := handlers.NewDNSHandlers(bind.BoxBinder(serverKey.LegacyControlKey), dnsProvider)
|
||||
idTokenHandlers := handlers.NewIDTokenHandlers(bind.BoxBinder(serverKey.LegacyControlKey), c, repository)
|
||||
authenticationHandlers := handlers.NewAuthenticationHandlers(
|
||||
c,
|
||||
authProvider,
|
||||
@@ -149,6 +152,8 @@ func Start(c *config.Config) error {
|
||||
tlsAppHandler.POST("/machine/:id", registrationHandlers.Register)
|
||||
tlsAppHandler.POST("/machine/:id/map", pollNetMapHandler.PollNetMap)
|
||||
tlsAppHandler.POST("/machine/:id/set-dns", dnsHandlers.SetDNS)
|
||||
tlsAppHandler.GET("/.well-known/jwks", idTokenHandlers.Jwks)
|
||||
tlsAppHandler.GET("/.well-known/openid-configuration", idTokenHandlers.OpenIDConfig)
|
||||
|
||||
auth := tlsAppHandler.Group("/a")
|
||||
auth.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
@@ -34,3 +35,14 @@ func RandomBytes(size int) ([]byte, error) {
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func NewPrivateKey() (*rsa.PrivateKey, string, error) {
|
||||
id := RandStringBytes(22)
|
||||
|
||||
privateKey, err := rsa.GenerateKey(entropy, 2048)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return privateKey, id, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user