Files

248 lines
5.8 KiB
Go

package handlers
import (
"context"
"encoding/binary"
"encoding/json"
"github.com/jsiebens/ionscale/internal/config"
"github.com/jsiebens/ionscale/internal/core"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/mapping"
"github.com/klauspost/compress/zstd"
"github.com/labstack/echo/v4"
"net/http"
"slices"
"sync"
"tailscale.com/smallzstd"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/util/dnsname"
"time"
)
// NewPollNetMapHandler returns a PollNetMapHandler wired with the given
// machine key, session manager and repository.
func NewPollNetMapHandler(
	machineKey key.MachinePublic,
	sessionManager core.PollMapSessionManager,
	repository domain.Repository) *PollNetMapHandler {
	return &PollNetMapHandler{
		machineKey:     machineKey,
		repository:     repository,
		sessionManager: sessionManager,
	}
}
// PollNetMapHandler handles poll-netmap requests. It looks machines up and
// persists them through repository, and registers and notifies sessions
// through sessionManager.
type PollNetMapHandler struct {
// machineKey is the machine public key the handler was created with; its
// string form is used, together with the request's node key, to look up
// the machine.
machineKey key.MachinePublic
// repository is used to load and save machines.
repository domain.Repository
// sessionManager is used to register/deregister update channels and to
// notify a tailnet after a machine has been saved.
sessionManager core.PollMapSessionManager
}
// PollNetMap is the echo handler for map requests.
//
// It binds the request body to a tailcfg.MapRequest, rejects requests whose
// Version is below SupportedCapabilityVersion with
// UnsupportedClientVersionError, and looks the machine up by the handler's
// machine key and the request's node key. When no machine is found it answers
// with HTTP 404; otherwise the request is passed on to handlePollNetMap.
func (h *PollNetMapHandler) PollNetMap(c echo.Context) error {
	mapRequest := new(tailcfg.MapRequest)
	if bindErr := c.Bind(mapRequest); bindErr != nil {
		return logError(bindErr)
	}

	if mapRequest.Version < SupportedCapabilityVersion {
		return UnsupportedClientVersionError
	}

	machine, lookupErr := h.repository.GetMachineByKeys(
		c.Request().Context(),
		h.machineKey.String(),
		mapRequest.NodeKey.String(),
	)
	switch {
	case lookupErr != nil:
		return logError(lookupErr)
	case machine == nil:
		return echo.NewHTTPError(http.StatusNotFound)
	}

	return h.handlePollNetMap(c, machine, mapRequest)
}
// handlePollNetMap serves mapRequest for machine m.
//
// A full map response is always built first. What happens next depends on
// mapRequest.Stream:
//
//   - Stream == false: the host info, disco key, endpoints and last-seen time
//     from the request are stored on m (renaming the machine when it follows
//     the OS hostname and that hostname changed), m is saved,
//     sessionManager.NotifyAll is called for the tailnet, and the map response
//     is returned as the complete HTTP body. A request without Hostinfo is
//     answered with HTTP 400.
//
//   - Stream == true: an update channel is registered with the session manager
//     and the HTTP response is kept open. The full map response is written
//     first; afterwards keep-alive messages (only if mapRequest.KeepAlive is
//     set) and delta map responses are written until the request context is
//     done, the update channel is closed, the machine can no longer be loaded,
//     or an error occurs.
//
// Errors are passed through logError before being returned.
func (h *PollNetMapHandler) handlePollNetMap(c echo.Context, m *domain.Machine, mapRequest *tailcfg.MapRequest) error {
	ctx := c.Request().Context()
	tailnetID := m.TailnetID
	machineID := m.ID

	mapper := mapping.NewPollNetMapper(mapRequest, machineID, h.repository, h.sessionManager)

	response, err := h.createMapResponse(mapper, false, mapRequest.Compress)
	if err != nil {
		return logError(err)
	}

	if !mapRequest.Stream {
		// Hostinfo is a pointer that is dereferenced below; reject a request
		// that omits it instead of panicking on a nil pointer.
		if mapRequest.Hostinfo == nil {
			return echo.NewHTTPError(http.StatusBadRequest)
		}

		now := time.Now().UTC()

		// Only recompute the auto-approved IPs when the advertised routes changed.
		if !slices.Equal(m.HostInfo.RoutableIPs, mapRequest.Hostinfo.RoutableIPs) {
			m.AutoAllowIPs = m.Tailnet.ACLPolicy.Get().FindAutoApprovedIPs(mapRequest.Hostinfo.RoutableIPs, m.Tags, &m.User)
		}

		m.HostInfo = domain.HostInfo(*mapRequest.Hostinfo)
		m.DiscoKey = mapRequest.DiscoKey.String()
		m.Endpoints = mapRequest.Endpoints
		m.LastSeen = &now

		// Follow the OS hostname: when it changed, take the new name together
		// with a fresh name index from the repository.
		sanitizeHostname := dnsname.SanitizeHostname(m.HostInfo.Hostname)
		if m.UseOSHostname && m.Name != sanitizeHostname {
			nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnetID, sanitizeHostname)
			if err != nil {
				return logError(err)
			}
			m.Name = sanitizeHostname
			m.NameIdx = nameIdx
		}

		if err := h.repository.SaveMachine(ctx, m); err != nil {
			return logError(err)
		}

		h.sessionManager.NotifyAll(tailnetID)

		return c.JSONBlob(http.StatusOK, response)
	}

	updateChan := make(chan *core.Ping, 20)
	h.sessionManager.Register(tailnetID, machineID, updateChan)
	// Deregister on every exit path from here on, including the early error
	// returns below, so a failed initial write does not leave the channel
	// registered.
	defer h.sessionManager.Deregister(tailnetID, machineID, updateChan)

	keepAliveResponse, err := h.createKeepAliveResponse(mapRequest)
	if err != nil {
		return logError(err)
	}

	c.Response().WriteHeader(http.StatusOK)
	if _, err := c.Response().Write(response); err != nil {
		return logError(err)
	}
	c.Response().Flush()

	connectedDevices.WithLabelValues(m.Tailnet.Name).Inc()

	keepAliveTicker := time.NewTicker(config.KeepAliveInterval())
	syncTicker := time.NewTicker(5 * time.Second)

	defer func() {
		connectedDevices.WithLabelValues(m.Tailnet.Name).Dec()
		keepAliveTicker.Stop()
		syncTicker.Stop()
		// On the ctx.Done() exit path the request context is already
		// cancelled, so detach from its cancellation to avoid handing the
		// repository a cancelled context. Best effort: the error is ignored
		// because the response is already finished.
		_ = h.repository.SetMachineLastSeen(context.WithoutCancel(ctx), machineID)
	}()

	// shouldUpdate records that at least one ping arrived since the last delta
	// was sent; pings are coalesced and answered on the next sync tick.
	shouldUpdate := false

	for {
		select {
		case _, ok := <-updateChan:
			if !ok {
				// The update channel was closed: end the stream.
				return nil
			}
			shouldUpdate = true

		case <-keepAliveTicker.C:
			if mapRequest.KeepAlive {
				if _, err := c.Response().Write(keepAliveResponse); err != nil {
					return logError(err)
				}
				// Best effort: a failed last-seen update must not end the stream.
				_ = h.repository.SetMachineLastSeen(ctx, machineID)
				c.Response().Flush()
			}

		case <-syncTicker.C:
			if shouldUpdate {
				machine, err := h.repository.GetMachine(ctx, machineID)
				if err != nil {
					return logError(err)
				}
				if machine == nil {
					// The machine no longer exists: end the stream.
					return nil
				}

				payload, err := h.createMapResponse(mapper, true, mapRequest.Compress)
				if err != nil {
					return logError(err)
				}

				if _, err := c.Response().Write(payload); err != nil {
					return logError(err)
				}
				c.Response().Flush()

				shouldUpdate = false
			}

		case <-ctx.Done():
			// The request context is done (e.g. the connection was closed).
			return nil
		}
	}
}
// createKeepAliveResponse returns the wire form of a map response that only
// has KeepAlive set, encoded with the compression named in request.Compress.
func (h *PollNetMapHandler) createKeepAliveResponse(request *tailcfg.MapRequest) ([]byte, error) {
	return h.marshalResponse(request.Compress, &tailcfg.MapResponse{KeepAlive: true})
}
// createMapResponse obtains a map response from m, passing delta through to
// the mapper, and returns it in wire form via marshalResponse using the given
// compression. The mapper is called with context.Background(), not with a
// request context.
func (h *PollNetMapHandler) createMapResponse(m *mapping.PollNetMapper, delta bool, compress string) ([]byte, error) {
	mapResponse, mapErr := m.CreateMapResponse(context.Background(), delta)
	if mapErr == nil {
		return h.marshalResponse(compress, mapResponse)
	}
	return nil, mapErr
}
// marshalResponse encodes v as JSON, compresses the JSON with zstd when
// compress is "zstd" (any other value leaves it uncompressed), and returns the
// result prefixed with its length as a 4-byte little-endian unsigned integer.
//
// It returns an error only when JSON encoding fails.
func (h *PollNetMapHandler) marshalResponse(compress string, v any) ([]byte, error) {
	payload, err := json.Marshal(v)
	if err != nil {
		return nil, err
	}

	if compress == "zstd" {
		payload = zstdEncode(payload)
	}

	// Allocate the frame once at its final size (4-byte length header plus
	// payload) so the append below does not have to grow and copy it.
	data := make([]byte, 4, 4+len(payload))
	binary.LittleEndian.PutUint32(data, uint32(len(payload)))
	return append(data, payload...), nil
}
// zstdEncode returns in compressed with a zstd encoder taken from
// zstdEncoderPool. The encoder is closed and handed back to the pool before
// the compressed bytes are returned.
func zstdEncode(in []byte) []byte {
	enc := zstdEncoderPool.Get().(*zstd.Encoder)
	compressed := enc.EncodeAll(in, nil)
	// The error from Close is intentionally discarded.
	_ = enc.Close()
	zstdEncoderPool.Put(enc)
	return compressed
}
// zstdEncoderPool is a pool of zstd encoders used by zstdEncode. New encoders
// are created without a destination writer and with the SpeedFastest encoder
// level; failing to create one panics.
var zstdEncoderPool = &sync.Pool{
	New: func() any {
		enc, newErr := smallzstd.NewEncoder(nil, zstd.WithEncoderLevel(zstd.SpeedFastest))
		if newErr != nil {
			panic(newErr)
		}
		return enc
	},
}