mirror of
https://github.com/jsiebens/ionscale.git
synced 2026-03-31 15:07:49 +01:00
391 lines
11 KiB
Go
391 lines
11 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/caddyserver/certmagic"
|
|
"github.com/hashicorp/go-plugin"
|
|
"github.com/jsiebens/ionscale/internal/auth"
|
|
"github.com/jsiebens/ionscale/internal/config"
|
|
"github.com/jsiebens/ionscale/internal/core"
|
|
"github.com/jsiebens/ionscale/internal/database"
|
|
"github.com/jsiebens/ionscale/internal/derp"
|
|
"github.com/jsiebens/ionscale/internal/dns"
|
|
"github.com/jsiebens/ionscale/internal/domain"
|
|
"github.com/jsiebens/ionscale/internal/handlers"
|
|
"github.com/jsiebens/ionscale/internal/service"
|
|
"github.com/jsiebens/ionscale/internal/stunserver"
|
|
"github.com/jsiebens/ionscale/internal/templates"
|
|
"github.com/jsiebens/ionscale/internal/util"
|
|
"github.com/labstack/echo-contrib/echoprometheus"
|
|
"github.com/labstack/echo-contrib/pprof"
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/labstack/echo/v4/middleware"
|
|
certmagicsql "github.com/travisjeffery/certmagic-sqlstorage"
|
|
"go.uber.org/zap"
|
|
"go.uber.org/zap/zapcore"
|
|
"golang.org/x/net/http2"
|
|
"golang.org/x/net/http2/h2c"
|
|
"golang.org/x/sync/errgroup"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"syscall"
|
|
"tailscale.com/types/key"
|
|
"time"
|
|
)
|
|
|
|
func Start(ctx context.Context, c *config.Config) error {
|
|
ctx = contextWithSigterm(ctx)
|
|
|
|
logger, err := setupLogging(c.Logging)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
logger.Info("Starting ionscale server")
|
|
|
|
logError := func(err error) error {
|
|
if err != nil {
|
|
zap.L().WithOptions(zap.AddCallerSkip(1)).Error("Unable to start server", zap.Error(err))
|
|
}
|
|
return err
|
|
}
|
|
|
|
util.EnsureIDProvider()
|
|
|
|
derpMap, err := derp.LoadDERPSources(c)
|
|
if err != nil {
|
|
logger.Warn("not all derp sources are read successfully", zap.Error(err))
|
|
}
|
|
|
|
domain.SetDefaultDERPMap(derpMap)
|
|
|
|
httpLogger := logger.Named("http")
|
|
dbLogger := logger.Named("db")
|
|
|
|
db, repository, err := database.OpenDB(&c.Database, dbLogger)
|
|
if err != nil {
|
|
return logError(err)
|
|
}
|
|
|
|
sessionManager := core.NewPollMapSessionManager()
|
|
|
|
defaultControlKeys, err := repository.GetControlKeys(ctx)
|
|
if err != nil {
|
|
return logError(err)
|
|
}
|
|
|
|
serverKey, err := c.ReadServerKeys(defaultControlKeys)
|
|
if err != nil {
|
|
return logError(err)
|
|
}
|
|
|
|
core.StartWorker(repository, sessionManager)
|
|
|
|
// prepare CertMagic
|
|
if c.Tls.AcmeEnabled {
|
|
storage, err := certmagicsql.NewStorage(ctx, db, certmagicsql.Options{})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
certmagicLogger := logger.Named("certmagic")
|
|
certmagic.DefaultACME.Agreed = true
|
|
certmagic.DefaultACME.Email = c.Tls.AcmeEmail
|
|
certmagic.DefaultACME.CA = c.Tls.AcmeCA
|
|
certmagic.DefaultACME.Logger = certmagicLogger
|
|
certmagic.Default.Logger = certmagicLogger
|
|
certmagic.Default.Storage = storage
|
|
|
|
cfg := certmagic.NewDefault()
|
|
if err := cfg.ManageAsync(ctx, []string{c.PublicUrl.Hostname()}); err != nil {
|
|
return logError(err)
|
|
}
|
|
}
|
|
|
|
authProvider, systemIAMPolicy, err := setupAuthProvider(c.Auth)
|
|
if err != nil {
|
|
return logError(fmt.Errorf("error configuring OIDC provider: %v", err))
|
|
}
|
|
|
|
dnsProvider, err := dns.NewProvider(c.DNS)
|
|
if err != nil {
|
|
return logError(err)
|
|
}
|
|
|
|
promMiddleware := echoprometheus.NewMiddleware("http")
|
|
|
|
createPeerHandler := func(machinePublicKey key.MachinePublic) http.Handler {
|
|
registrationHandlers := handlers.NewRegistrationHandlers(machinePublicKey, c, sessionManager, repository)
|
|
pollNetMapHandler := handlers.NewPollNetMapHandler(machinePublicKey, sessionManager, repository)
|
|
dnsHandlers := handlers.NewDNSHandlers(machinePublicKey, dnsProvider)
|
|
idTokenHandlers := handlers.NewIDTokenHandlers(machinePublicKey, c, repository)
|
|
sshActionHandlers := handlers.NewSSHActionHandlers(machinePublicKey, c, repository)
|
|
queryFeatureHandlers := handlers.NewQueryFeatureHandlers(machinePublicKey, dnsProvider, repository)
|
|
updateHealthHandlers := handlers.NewUpdateHealthHandlers(machinePublicKey, repository)
|
|
|
|
e := echo.New()
|
|
e.Binder = handlers.JsonBinder{}
|
|
e.Use(promMiddleware, EchoLogger(httpLogger), EchoErrorHandler(), EchoRecover())
|
|
e.POST("/machine/register", registrationHandlers.Register)
|
|
e.POST("/machine/map", pollNetMapHandler.PollNetMap)
|
|
e.POST("/machine/set-dns", dnsHandlers.SetDNS)
|
|
e.POST("/machine/id-token", idTokenHandlers.FetchToken)
|
|
e.GET("/machine/ssh/action/:src_machine_id/to/:dst_machine_id", sshActionHandlers.StartAuth)
|
|
e.GET("/machine/ssh/action/:src_machine_id/to/:dst_machine_id/:check_period", sshActionHandlers.StartAuth)
|
|
e.GET("/machine/ssh/action/check/:key", sshActionHandlers.CheckAuth)
|
|
e.POST("/machine/feature/query", queryFeatureHandlers.QueryFeature)
|
|
e.POST("/machine/update-health", updateHealthHandlers.UpdateHealth)
|
|
|
|
return e
|
|
}
|
|
|
|
noiseHandlers := handlers.NewNoiseHandlers(serverKey.ControlKey, createPeerHandler)
|
|
oidcConfigHandlers := handlers.NewOIDCConfigHandlers(c, repository)
|
|
|
|
authenticationHandlers := handlers.NewAuthenticationHandlers(
|
|
c,
|
|
authProvider,
|
|
systemIAMPolicy,
|
|
repository,
|
|
)
|
|
|
|
rpcService := service.NewService(c, authProvider, dnsProvider, repository, sessionManager)
|
|
rpcPath, rpcHandler := NewRpcHandler(serverKey.SystemAdminKey, repository, rpcService)
|
|
|
|
metricsMux := echo.New()
|
|
metricsMux.GET("/metrics", echoprometheus.NewHandler())
|
|
pprof.Register(metricsMux)
|
|
|
|
webMux := echo.New()
|
|
webMux.Renderer = &templates.Renderer{}
|
|
webMux.Pre(handlers.HttpsRedirect(c.Tls))
|
|
webMux.Use(promMiddleware, EchoLogger(httpLogger), EchoErrorHandler(), EchoRecover())
|
|
|
|
webMux.Any("/*", handlers.IndexHandler(http.StatusNotFound))
|
|
webMux.Any("/", handlers.IndexHandler(http.StatusOK))
|
|
webMux.POST(rpcPath+"*", echo.WrapHandler(rpcHandler))
|
|
webMux.GET("/version", handlers.Version)
|
|
webMux.GET("/key", handlers.KeyHandler(serverKey))
|
|
webMux.POST("/ts2021", noiseHandlers.Upgrade)
|
|
webMux.GET("/.well-known/jwks", oidcConfigHandlers.Jwks)
|
|
webMux.GET("/.well-known/openid-configuration", oidcConfigHandlers.OpenIDConfig)
|
|
|
|
csrf := middleware.CSRFWithConfig(middleware.CSRFConfig{TokenLookup: "form:_csrf"})
|
|
webMux.GET("/a/:flow/:key", authenticationHandlers.StartAuth, csrf)
|
|
webMux.POST("/a/:flow/:key", authenticationHandlers.ProcessAuth, csrf)
|
|
webMux.GET("/a/callback", authenticationHandlers.Callback, csrf)
|
|
webMux.POST("/a/callback", authenticationHandlers.EndAuth, csrf)
|
|
webMux.GET("/a/success", authenticationHandlers.Success, csrf)
|
|
webMux.GET("/a/error", authenticationHandlers.Error, csrf)
|
|
|
|
if !c.DERP.Server.Disabled {
|
|
derpHandlers := handlers.NewDERPHandler()
|
|
|
|
metricsMux.GET("/debug/derp/traffic", derpHandlers.DebugTraffic)
|
|
metricsMux.GET("/debug/derp/check", derpHandlers.DebugCheck)
|
|
|
|
webMux.GET("/derp", derpHandlers.Handler)
|
|
webMux.GET("/derp/latency-check", derpHandlers.LatencyCheck)
|
|
}
|
|
|
|
webL, err := webListener(c)
|
|
if err != nil {
|
|
return logError(err)
|
|
}
|
|
|
|
metricsL, err := metricsListener(c)
|
|
if err != nil {
|
|
return logError(err)
|
|
}
|
|
|
|
stunL, err := stunListener(c)
|
|
if err != nil {
|
|
return logError(err)
|
|
}
|
|
|
|
errorLog, err := zap.NewStdLogAt(logger, zap.DebugLevel)
|
|
if err != nil {
|
|
return logError(err)
|
|
}
|
|
|
|
webServer := &http.Server{ErrorLog: errorLog, Handler: h2c.NewHandler(webMux, &http2.Server{})}
|
|
metricsServer := &http.Server{ErrorLog: errorLog, Handler: metricsMux}
|
|
stunServer := stunserver.New(stunL)
|
|
|
|
g, gCtx := errgroup.WithContext(ctx)
|
|
|
|
go func() {
|
|
<-gCtx.Done()
|
|
logger.Sugar().Infow("Shutting down ionscale server")
|
|
plugin.CleanupClients()
|
|
shutdownHttpServer(metricsServer)
|
|
shutdownHttpServer(webServer)
|
|
_ = stunServer.Shutdown()
|
|
}()
|
|
|
|
g.Go(func() error { return serveHttp(webServer, webL) })
|
|
g.Go(func() error { return serveHttp(metricsServer, metricsL) })
|
|
g.Go(func() error { return stunServer.Serve() })
|
|
|
|
fields := []zap.Field{
|
|
zap.String("url", c.PublicUrl.String()),
|
|
zap.String("addr", c.ListenAddr),
|
|
zap.String("metrics_addr", c.MetricsListenAddr),
|
|
}
|
|
|
|
if !c.DERP.Server.Disabled {
|
|
fields = append(fields, zap.String("stun_addr", c.StunListenAddr))
|
|
} else {
|
|
logger.Warn("Embedded DERP is disabled")
|
|
}
|
|
|
|
if c.Tls.AcmeEnabled {
|
|
logger.Info("TLS is enabled with ACME", zap.String("domain", c.PublicUrl.Hostname()))
|
|
logger.Info("Server is running", fields...)
|
|
} else if !c.Tls.Disable {
|
|
logger.Info("TLS is enabled", zap.String("cert", c.Tls.CertFile))
|
|
logger.Info("Server is running", fields...)
|
|
} else {
|
|
logger.Warn("TLS is disabled")
|
|
logger.Info("Server is running", fields...)
|
|
}
|
|
|
|
return g.Wait()
|
|
}
|
|
|
|
func serveHttp(s *http.Server, l net.Listener) error {
|
|
if l == nil || s == nil {
|
|
return nil
|
|
}
|
|
if err := s.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func shutdownHttpServer(s *http.Server) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
_ = s.Shutdown(ctx)
|
|
}
|
|
|
|
func setupAuthProvider(config config.Auth) (auth.Provider, *domain.IAMPolicy, error) {
|
|
if len(config.Provider.Issuer) == 0 {
|
|
return nil, &domain.IAMPolicy{}, nil
|
|
}
|
|
|
|
authProvider, err := auth.NewOIDCProvider(&config.Provider)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
return authProvider, &domain.IAMPolicy{
|
|
Subs: config.SystemAdminPolicy.Subs,
|
|
Emails: config.SystemAdminPolicy.Emails,
|
|
Filters: config.SystemAdminPolicy.Filters,
|
|
}, nil
|
|
}
|
|
|
|
func webListener(config *config.Config) (net.Listener, error) {
|
|
if config.Tls.Disable {
|
|
return net.Listen("tcp", config.ListenAddr)
|
|
}
|
|
|
|
if config.Tls.AcmeEnabled {
|
|
cfg := certmagic.NewDefault()
|
|
tlsConfig := cfg.TLSConfig()
|
|
tlsConfig.NextProtos = append([]string{"h2", "http/1.1"}, tlsConfig.NextProtos...)
|
|
return tls.Listen("tcp", config.ListenAddr, tlsConfig)
|
|
}
|
|
|
|
certPEMBlock, err := os.ReadFile(config.Tls.CertFile)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading cert file: %v", err)
|
|
}
|
|
keyPEMBlock, err := os.ReadFile(config.Tls.KeyFile)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading key file: %v", err)
|
|
}
|
|
|
|
cer, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading cert and key file: %v", err)
|
|
}
|
|
|
|
tlsConfig := &tls.Config{Certificates: []tls.Certificate{cer}}
|
|
|
|
return tls.Listen("tcp", config.ListenAddr, tlsConfig)
|
|
}
|
|
|
|
func metricsListener(config *config.Config) (net.Listener, error) {
|
|
return net.Listen("tcp", config.MetricsListenAddr)
|
|
}
|
|
|
|
func stunListener(config *config.Config) (*net.UDPConn, error) {
|
|
if config.DERP.Server.Disabled {
|
|
return nil, nil
|
|
}
|
|
|
|
addr, err := net.ResolveUDPAddr("udp", config.StunListenAddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return net.ListenUDP("udp", addr)
|
|
}
|
|
|
|
func setupLogging(config config.Logging) (*zap.Logger, error) {
|
|
level, err := zap.ParseAtomicLevel(config.Level)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
pc := zap.NewProductionConfig()
|
|
pc.Level = level
|
|
pc.DisableStacktrace = true
|
|
pc.OutputPaths = []string{"stdout"}
|
|
pc.Encoding = "console"
|
|
pc.EncoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
|
|
pc.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
|
|
|
|
if config.File != "" {
|
|
pc.OutputPaths = []string{config.File}
|
|
}
|
|
|
|
if config.Format == "json" {
|
|
pc.Encoding = "json"
|
|
}
|
|
|
|
logger, err := pc.Build()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
globalLogger := logger.Named("ionscale")
|
|
zap.ReplaceGlobals(globalLogger)
|
|
|
|
return globalLogger, nil
|
|
}
|
|
|
|
func contextWithSigterm(ctx context.Context) context.Context {
|
|
ctxWithCancel, cancel := context.WithCancel(ctx)
|
|
go func() {
|
|
defer cancel()
|
|
|
|
signalCh := make(chan os.Signal, 1)
|
|
signal.Notify(signalCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
select {
|
|
case <-signalCh:
|
|
case <-ctx.Done():
|
|
}
|
|
}()
|
|
|
|
return ctxWithCancel
|
|
}
|