Files
ionscale/internal/server/server.go
T
2025-05-24 11:01:05 +02:00

391 lines
11 KiB
Go

package server
import (
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/caddyserver/certmagic"
"github.com/hashicorp/go-plugin"
"github.com/jsiebens/ionscale/internal/auth"
"github.com/jsiebens/ionscale/internal/config"
"github.com/jsiebens/ionscale/internal/core"
"github.com/jsiebens/ionscale/internal/database"
"github.com/jsiebens/ionscale/internal/derp"
"github.com/jsiebens/ionscale/internal/dns"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/handlers"
"github.com/jsiebens/ionscale/internal/service"
"github.com/jsiebens/ionscale/internal/stunserver"
"github.com/jsiebens/ionscale/internal/templates"
"github.com/jsiebens/ionscale/internal/util"
"github.com/labstack/echo-contrib/echoprometheus"
"github.com/labstack/echo-contrib/pprof"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
certmagicsql "github.com/travisjeffery/certmagic-sqlstorage"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"golang.org/x/sync/errgroup"
"net"
"net/http"
"os"
"os/signal"
"syscall"
"tailscale.com/types/key"
"time"
)
// Start boots the ionscale server described by c and blocks until it stops.
//
// The incoming ctx is wrapped by contextWithSigterm, so cancelling ctx or
// receiving an interrupt/termination signal triggers shutdown of the web,
// metrics and STUN servers. The return value is whatever the errgroup running
// those three servers reports from Wait, or the first setup error.
//
// Setup errors that occur after logging is configured are logged once through
// logError before being returned; a failure in setupLogging itself is returned
// unlogged because no logger exists yet.
func Start(ctx context.Context, c *config.Config) error {
	ctx = contextWithSigterm(ctx)

	logger, err := setupLogging(c.Logging)
	if err != nil {
		return err
	}

	logger.Info("Starting ionscale server")

	// logError logs a non-nil startup error via the global logger and hands it
	// back unchanged. AddCallerSkip(1) makes the log entry point at the line in
	// Start that called logError instead of at this closure.
	logError := func(err error) error {
		if err != nil {
			zap.L().WithOptions(zap.AddCallerSkip(1)).Error("Unable to start server", zap.Error(err))
		}
		return err
	}

	util.EnsureIDProvider()

	// A DERP source failure is only a warning: whatever map was returned is
	// still installed as the default.
	derpMap, err := derp.LoadDERPSources(c)
	if err != nil {
		logger.Warn("not all derp sources are read successfully", zap.Error(err))
	}
	domain.SetDefaultDERPMap(derpMap)

	httpLogger := logger.Named("http")
	dbLogger := logger.Named("db")

	db, repository, err := database.OpenDB(&c.Database, dbLogger)
	if err != nil {
		return logError(err)
	}

	sessionManager := core.NewPollMapSessionManager()

	defaultControlKeys, err := repository.GetControlKeys(ctx)
	if err != nil {
		return logError(err)
	}

	serverKey, err := c.ReadServerKeys(defaultControlKeys)
	if err != nil {
		return logError(err)
	}

	core.StartWorker(repository, sessionManager)

	// prepare CertMagic
	if c.Tls.AcmeEnabled {
		storage, err := certmagicsql.NewStorage(ctx, db, certmagicsql.Options{})
		if err != nil {
			// Logged like every other startup failure after logging is set up.
			return logError(err)
		}

		certmagicLogger := logger.Named("certmagic")

		// These are package-level certmagic defaults; webListener relies on
		// them when it calls certmagic.NewDefault for the TLS listener.
		certmagic.DefaultACME.Agreed = true
		certmagic.DefaultACME.Email = c.Tls.AcmeEmail
		certmagic.DefaultACME.CA = c.Tls.AcmeCA
		certmagic.DefaultACME.Logger = certmagicLogger
		certmagic.Default.Logger = certmagicLogger
		certmagic.Default.Storage = storage

		cfg := certmagic.NewDefault()
		if err := cfg.ManageAsync(ctx, []string{c.PublicUrl.Hostname()}); err != nil {
			return logError(err)
		}
	}

	authProvider, systemIAMPolicy, err := setupAuthProvider(c.Auth)
	if err != nil {
		// %w keeps the underlying error reachable through errors.Is/As.
		return logError(fmt.Errorf("error configuring OIDC provider: %w", err))
	}

	dnsProvider, err := dns.NewProvider(c.DNS)
	if err != nil {
		return logError(err)
	}

	promMiddleware := echoprometheus.NewMiddleware("http")

	// createPeerHandler builds a dedicated echo router holding the /machine/*
	// endpoints, with every handler bound to the given machine public key. It
	// is handed to the noise handlers below.
	createPeerHandler := func(machinePublicKey key.MachinePublic) http.Handler {
		registrationHandlers := handlers.NewRegistrationHandlers(machinePublicKey, c, sessionManager, repository)
		pollNetMapHandler := handlers.NewPollNetMapHandler(machinePublicKey, sessionManager, repository)
		dnsHandlers := handlers.NewDNSHandlers(machinePublicKey, dnsProvider)
		idTokenHandlers := handlers.NewIDTokenHandlers(machinePublicKey, c, repository)
		sshActionHandlers := handlers.NewSSHActionHandlers(machinePublicKey, c, repository)
		queryFeatureHandlers := handlers.NewQueryFeatureHandlers(machinePublicKey, dnsProvider, repository)
		updateHealthHandlers := handlers.NewUpdateHealthHandlers(machinePublicKey, repository)

		e := echo.New()
		e.Binder = handlers.JsonBinder{}
		e.Use(promMiddleware, EchoLogger(httpLogger), EchoErrorHandler(), EchoRecover())
		e.POST("/machine/register", registrationHandlers.Register)
		e.POST("/machine/map", pollNetMapHandler.PollNetMap)
		e.POST("/machine/set-dns", dnsHandlers.SetDNS)
		e.POST("/machine/id-token", idTokenHandlers.FetchToken)
		e.GET("/machine/ssh/action/:src_machine_id/to/:dst_machine_id", sshActionHandlers.StartAuth)
		e.GET("/machine/ssh/action/:src_machine_id/to/:dst_machine_id/:check_period", sshActionHandlers.StartAuth)
		e.GET("/machine/ssh/action/check/:key", sshActionHandlers.CheckAuth)
		e.POST("/machine/feature/query", queryFeatureHandlers.QueryFeature)
		e.POST("/machine/update-health", updateHealthHandlers.UpdateHealth)

		return e
	}

	noiseHandlers := handlers.NewNoiseHandlers(serverKey.ControlKey, createPeerHandler)
	oidcConfigHandlers := handlers.NewOIDCConfigHandlers(c, repository)

	authenticationHandlers := handlers.NewAuthenticationHandlers(
		c,
		authProvider,
		systemIAMPolicy,
		repository,
	)

	rpcService := service.NewService(c, authProvider, dnsProvider, repository, sessionManager)
	rpcPath, rpcHandler := NewRpcHandler(serverKey.SystemAdminKey, repository, rpcService)

	// Metrics and pprof live on their own router, served on the separate
	// metrics listener.
	metricsMux := echo.New()
	metricsMux.GET("/metrics", echoprometheus.NewHandler())
	pprof.Register(metricsMux)

	webMux := echo.New()
	webMux.Renderer = &templates.Renderer{}
	webMux.Pre(handlers.HttpsRedirect(c.Tls))
	webMux.Use(promMiddleware, EchoLogger(httpLogger), EchoErrorHandler(), EchoRecover())

	webMux.Any("/*", handlers.IndexHandler(http.StatusNotFound))
	webMux.Any("/", handlers.IndexHandler(http.StatusOK))
	webMux.POST(rpcPath+"*", echo.WrapHandler(rpcHandler))
	webMux.GET("/version", handlers.Version)
	webMux.GET("/key", handlers.KeyHandler(serverKey))
	webMux.POST("/ts2021", noiseHandlers.Upgrade)
	webMux.GET("/.well-known/jwks", oidcConfigHandlers.Jwks)
	webMux.GET("/.well-known/openid-configuration", oidcConfigHandlers.OpenIDConfig)

	// The interactive authentication flow is the only part guarded by CSRF;
	// the token is read from the "_csrf" form field.
	csrf := middleware.CSRFWithConfig(middleware.CSRFConfig{TokenLookup: "form:_csrf"})
	webMux.GET("/a/:flow/:key", authenticationHandlers.StartAuth, csrf)
	webMux.POST("/a/:flow/:key", authenticationHandlers.ProcessAuth, csrf)
	webMux.GET("/a/callback", authenticationHandlers.Callback, csrf)
	webMux.POST("/a/callback", authenticationHandlers.EndAuth, csrf)
	webMux.GET("/a/success", authenticationHandlers.Success, csrf)
	webMux.GET("/a/error", authenticationHandlers.Error, csrf)

	if !c.DERP.Server.Disabled {
		derpHandlers := handlers.NewDERPHandler()

		metricsMux.GET("/debug/derp/traffic", derpHandlers.DebugTraffic)
		metricsMux.GET("/debug/derp/check", derpHandlers.DebugCheck)

		webMux.GET("/derp", derpHandlers.Handler)
		webMux.GET("/derp/latency-check", derpHandlers.LatencyCheck)
	}

	webL, err := webListener(c)
	if err != nil {
		return logError(err)
	}

	metricsL, err := metricsListener(c)
	if err != nil {
		return logError(err)
	}

	// stunL is nil when the embedded DERP server is disabled.
	// NOTE(review): stunserver.New/Serve/Shutdown are presumably safe with a
	// nil connection — confirm in the stunserver package.
	stunL, err := stunListener(c)
	if err != nil {
		return logError(err)
	}

	// net/http internal errors are routed into zap at debug level.
	errorLog, err := zap.NewStdLogAt(logger, zap.DebugLevel)
	if err != nil {
		return logError(err)
	}

	// h2c lets the web server speak HTTP/2 over cleartext connections as well.
	webServer := &http.Server{ErrorLog: errorLog, Handler: h2c.NewHandler(webMux, &http2.Server{})}
	metricsServer := &http.Server{ErrorLog: errorLog, Handler: metricsMux}
	stunServer := stunserver.New(stunL)

	g, gCtx := errgroup.WithContext(ctx)

	// gCtx is done when ctx is cancelled or when any server goroutine returns
	// an error; either way all servers are shut down so that g.Wait returns.
	go func() {
		<-gCtx.Done()
		logger.Sugar().Infow("Shutting down ionscale server")
		plugin.CleanupClients()
		shutdownHttpServer(metricsServer)
		shutdownHttpServer(webServer)
		_ = stunServer.Shutdown()
	}()

	g.Go(func() error { return serveHttp(webServer, webL) })
	g.Go(func() error { return serveHttp(metricsServer, metricsL) })
	g.Go(func() error { return stunServer.Serve() })

	fields := []zap.Field{
		zap.String("url", c.PublicUrl.String()),
		zap.String("addr", c.ListenAddr),
		zap.String("metrics_addr", c.MetricsListenAddr),
	}

	if !c.DERP.Server.Disabled {
		fields = append(fields, zap.String("stun_addr", c.StunListenAddr))
	} else {
		logger.Warn("Embedded DERP is disabled")
	}

	// Report the TLS mode using the same precedence as webListener (Disable
	// wins over AcmeEnabled) so the log matches the listener actually opened.
	switch {
	case c.Tls.Disable:
		logger.Warn("TLS is disabled")
	case c.Tls.AcmeEnabled:
		logger.Info("TLS is enabled with ACME", zap.String("domain", c.PublicUrl.Hostname()))
	default:
		logger.Info("TLS is enabled", zap.String("cert", c.Tls.CertFile))
	}
	logger.Info("Server is running", fields...)

	return g.Wait()
}
// serveHttp runs s on listener l until the server stops.
//
// A nil server or nil listener is treated as "nothing to serve" and yields
// nil. http.ErrServerClosed, which Serve reports after a shutdown, is not
// considered a failure and is translated to nil; any other error is returned
// as is.
func serveHttp(s *http.Server, l net.Listener) error {
	if s == nil || l == nil {
		return nil
	}

	serveErr := s.Serve(l)
	if errors.Is(serveErr, http.ErrServerClosed) {
		// Regular termination after Shutdown/Close.
		return nil
	}
	return serveErr
}
// shutdownHttpServer gracefully shuts s down, waiting at most five seconds.
//
// A nil server is a no-op, mirroring serveHttp, which also accepts a nil
// server. The error from Shutdown is deliberately discarded: this runs on the
// best-effort shutdown path, where nothing can act on the failure.
func shutdownHttpServer(s *http.Server) {
	if s == nil {
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	_ = s.Shutdown(ctx)
}
// setupAuthProvider builds the OIDC auth provider and the system admin IAM
// policy from the auth configuration.
//
// When no issuer is configured, no provider is created: the result is a nil
// provider together with an empty policy. Otherwise the provider is created
// from config.Provider and the policy carries the subs, emails and filters of
// config.SystemAdminPolicy. A provider construction error is returned with
// both other results nil.
func setupAuthProvider(config config.Auth) (auth.Provider, *domain.IAMPolicy, error) {
	if len(config.Provider.Issuer) == 0 {
		return nil, &domain.IAMPolicy{}, nil
	}

	provider, err := auth.NewOIDCProvider(&config.Provider)
	if err != nil {
		return nil, nil, err
	}

	adminPolicy := &domain.IAMPolicy{
		Subs:    config.SystemAdminPolicy.Subs,
		Emails:  config.SystemAdminPolicy.Emails,
		Filters: config.SystemAdminPolicy.Filters,
	}
	return provider, adminPolicy, nil
}
// webListener opens the listener for the main web server on
// config.ListenAddr. The mode is chosen in this order:
//
//   - config.Tls.Disable: a plain TCP listener.
//   - config.Tls.AcmeEnabled: a TLS listener using certmagic's default
//     configuration, with "h2" and "http/1.1" placed ahead of the protocols
//     certmagic already advertises.
//   - otherwise: a TLS listener using the certificate and key loaded from
//     config.Tls.CertFile and config.Tls.KeyFile.
//
// Failures to read or parse the certificate material are wrapped with %w, so
// callers can still inspect the underlying error with errors.Is / errors.As.
func webListener(config *config.Config) (net.Listener, error) {
	if config.Tls.Disable {
		return net.Listen("tcp", config.ListenAddr)
	}

	if config.Tls.AcmeEnabled {
		cfg := certmagic.NewDefault()
		tlsConfig := cfg.TLSConfig()
		tlsConfig.NextProtos = append([]string{"h2", "http/1.1"}, tlsConfig.NextProtos...)
		return tls.Listen("tcp", config.ListenAddr, tlsConfig)
	}

	certPEMBlock, err := os.ReadFile(config.Tls.CertFile)
	if err != nil {
		return nil, fmt.Errorf("error reading cert file: %w", err)
	}

	keyPEMBlock, err := os.ReadFile(config.Tls.KeyFile)
	if err != nil {
		return nil, fmt.Errorf("error reading key file: %w", err)
	}

	cer, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
	if err != nil {
		return nil, fmt.Errorf("error reading cert and key file: %w", err)
	}

	// NOTE(review): unlike the ACME branch, NextProtos is not set here, so
	// "h2" is not advertised via ALPN for static certificates — confirm this
	// difference is intended.
	tlsConfig := &tls.Config{Certificates: []tls.Certificate{cer}}
	return tls.Listen("tcp", config.ListenAddr, tlsConfig)
}
// metricsListener opens the plain TCP listener for the metrics server on
// config.MetricsListenAddr.
func metricsListener(config *config.Config) (net.Listener, error) {
	addr := config.MetricsListenAddr
	return net.Listen("tcp", addr)
}
// stunListener opens the UDP socket for the STUN server on
// config.StunListenAddr.
//
// When the embedded DERP server is disabled no socket is opened and the
// result is (nil, nil), so callers must be prepared for a nil connection.
func stunListener(config *config.Config) (*net.UDPConn, error) {
	if config.DERP.Server.Disabled {
		return nil, nil
	}

	udpAddr, resolveErr := net.ResolveUDPAddr("udp", config.StunListenAddr)
	if resolveErr != nil {
		return nil, resolveErr
	}

	return net.ListenUDP("udp", udpAddr)
}
// setupLogging builds the zap logger from the logging configuration, installs
// it as zap's global logger and returns it.
//
// The logger is derived from zap's production config with stack traces
// disabled, capitalised level names and ISO 8601 timestamps. Output goes to
// stdout unless config.File is set; the encoding is "console" unless
// config.Format is "json". The resulting logger is named "ionscale".
//
// An unparsable config.Level or a failing logger build is returned as error.
func setupLogging(config config.Logging) (*zap.Logger, error) {
	atomicLevel, err := zap.ParseAtomicLevel(config.Level)
	if err != nil {
		return nil, err
	}

	// Resolve destination and encoding first, falling back to the defaults.
	outputs := []string{"stdout"}
	if config.File != "" {
		outputs = []string{config.File}
	}

	encoding := "console"
	if config.Format == "json" {
		encoding = "json"
	}

	zapConfig := zap.NewProductionConfig()
	zapConfig.Level = atomicLevel
	zapConfig.DisableStacktrace = true
	zapConfig.OutputPaths = outputs
	zapConfig.Encoding = encoding
	zapConfig.EncoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
	zapConfig.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder

	base, err := zapConfig.Build()
	if err != nil {
		return nil, err
	}

	// Code that logs through zap.L() gets the same named logger.
	named := base.Named("ionscale")
	zap.ReplaceGlobals(named)

	return named, nil
}
// contextWithSigterm returns a context derived from ctx that is additionally
// cancelled when the process receives an interrupt (SIGINT) or SIGTERM.
//
// The signal handler is registered before this function returns, so a signal
// arriving right after the call is already captured. Once the returned
// context is cancelled — by a signal or by ctx itself — the handler is
// removed again with signal.Stop. This restores the default signal behaviour,
// so a second Ctrl-C during a slow shutdown terminates the process.
func contextWithSigterm(ctx context.Context) context.Context {
	ctxWithCancel, cancel := context.WithCancel(ctx)

	// Buffered so the notification is not dropped if the goroutine below is
	// not yet receiving. os.Interrupt already is SIGINT, so it is listed once.
	signalCh := make(chan os.Signal, 1)
	signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM)

	go func() {
		defer cancel()
		// Runs before cancel (LIFO): the handler is gone by the time the
		// goroutine cancels the context.
		defer signal.Stop(signalCh)

		select {
		case <-signalCh:
		case <-ctx.Done():
		}
	}()

	return ctxWithCancel
}