mirror of
https://github.com/jsiebens/ionscale.git
synced 2026-03-31 15:07:49 +01:00
fix: improve session management and update channels, avoiding potential deadlocks
This commit is contained in:
@@ -191,6 +191,7 @@ require (
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.55.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/safchain/ethtool v0.3.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
|
||||
@@ -804,6 +804,8 @@ github.com/prometheus/procfs v0.9.0/go.mod h1:+pB4zwohETzFnmlpe6yd2lSc+0/46IYZRB
|
||||
github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -9,35 +10,81 @@ import (
|
||||
type Ping struct{}
|
||||
|
||||
type PollMapSessionManager interface {
|
||||
Register(tailnetID uint64, machineID uint64, ch chan *Ping)
|
||||
Deregister(tailnetID uint64, machineID uint64)
|
||||
Register(tailnetID uint64, machineID uint64, ch chan<- *Ping)
|
||||
Deregister(tailnetID uint64, machineID uint64, ch chan<- *Ping)
|
||||
HasSession(tailnetID uint64, machineID uint64) bool
|
||||
NotifyAll(tailnetID uint64, ignoreMachineIDs ...uint64)
|
||||
}
|
||||
|
||||
func NewPollMapSessionManager() PollMapSessionManager {
|
||||
return &pollMapSessionManager{
|
||||
data: map[uint64]map[uint64]chan *Ping{},
|
||||
timers: map[uint64]*time.Timer{},
|
||||
tailnets: xsync.NewMapOf[uint64, *tailnetSessionManager](),
|
||||
}
|
||||
}
|
||||
|
||||
type pollMapSessionManager struct {
|
||||
sync.RWMutex
|
||||
data map[uint64]map[uint64]chan *Ping
|
||||
timers map[uint64]*time.Timer
|
||||
tailnets *xsync.MapOf[uint64, *tailnetSessionManager]
|
||||
}
|
||||
|
||||
func (n *pollMapSessionManager) Register(tailnetID uint64, machineID uint64, ch chan *Ping) {
|
||||
func (n *pollMapSessionManager) load(tailnetID uint64) *tailnetSessionManager {
|
||||
m, _ := n.tailnets.LoadOrCompute(tailnetID, func() *tailnetSessionManager {
|
||||
return &tailnetSessionManager{
|
||||
targets: make(map[uint64]chan<- *Ping),
|
||||
timers: make(map[uint64]*time.Timer),
|
||||
sessions: xsync.NewMapOf[uint64, bool](),
|
||||
}
|
||||
})
|
||||
return m
|
||||
}
|
||||
|
||||
func (n *pollMapSessionManager) Register(tailnetID uint64, machineID uint64, ch chan<- *Ping) {
|
||||
n.load(tailnetID).Register(machineID, ch)
|
||||
}
|
||||
|
||||
func (n *pollMapSessionManager) Deregister(tailnetID uint64, machineID uint64, ch chan<- *Ping) {
|
||||
n.load(tailnetID).Deregister(machineID, ch)
|
||||
}
|
||||
|
||||
func (n *pollMapSessionManager) HasSession(tailnetID uint64, machineID uint64) bool {
|
||||
return n.load(tailnetID).HasSession(machineID)
|
||||
}
|
||||
|
||||
func (n *pollMapSessionManager) NotifyAll(tailnetID uint64, ignoreMachineIDs ...uint64) {
|
||||
n.load(tailnetID).NotifyAll(ignoreMachineIDs...)
|
||||
}
|
||||
|
||||
type tailnetSessionManager struct {
|
||||
sync.RWMutex
|
||||
targets map[uint64]chan<- *Ping
|
||||
timers map[uint64]*time.Timer
|
||||
sessions *xsync.MapOf[uint64, bool]
|
||||
}
|
||||
|
||||
func (n *tailnetSessionManager) NotifyAll(ignoreMachineIDs ...uint64) {
|
||||
n.RLock()
|
||||
defer n.RUnlock()
|
||||
|
||||
for i, p := range n.targets {
|
||||
if !slices.Contains(ignoreMachineIDs, i) {
|
||||
select {
|
||||
case p <- &Ping{}:
|
||||
default: // ignore, channel has a small buffer, failing to insert means there is already a ping pending
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *tailnetSessionManager) Register(machineID uint64, ch chan<- *Ping) {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
|
||||
if ss := n.data[tailnetID]; ss == nil {
|
||||
n.data[tailnetID] = map[uint64]chan *Ping{machineID: ch}
|
||||
} else {
|
||||
ss[machineID] = ch
|
||||
if curr, ok := n.targets[machineID]; ok {
|
||||
close(curr)
|
||||
}
|
||||
|
||||
n.targets[machineID] = ch
|
||||
n.sessions.Store(machineID, true)
|
||||
|
||||
t, ok := n.timers[machineID]
|
||||
if ok {
|
||||
t.Stop()
|
||||
@@ -47,22 +94,25 @@ func (n *pollMapSessionManager) Register(tailnetID uint64, machineID uint64, ch
|
||||
timer := time.NewTimer(5 * time.Second)
|
||||
go func() {
|
||||
<-timer.C
|
||||
if n.HasSession(tailnetID, machineID) {
|
||||
n.NotifyAll(tailnetID, machineID)
|
||||
if n.HasSession(machineID) {
|
||||
n.NotifyAll(machineID)
|
||||
}
|
||||
}()
|
||||
|
||||
n.timers[machineID] = timer
|
||||
}
|
||||
|
||||
func (n *pollMapSessionManager) Deregister(tailnetID uint64, machineID uint64) {
|
||||
func (n *tailnetSessionManager) Deregister(machineID uint64, ch chan<- *Ping) {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
|
||||
if ss := n.data[tailnetID]; ss != nil {
|
||||
delete(ss, machineID)
|
||||
if curr, ok := n.targets[machineID]; ok && curr != ch {
|
||||
return
|
||||
}
|
||||
|
||||
delete(n.targets, machineID)
|
||||
n.sessions.Store(machineID, false)
|
||||
|
||||
t, ok := n.timers[machineID]
|
||||
if ok {
|
||||
t.Stop()
|
||||
@@ -72,36 +122,15 @@ func (n *pollMapSessionManager) Deregister(tailnetID uint64, machineID uint64) {
|
||||
timer := time.NewTimer(10 * time.Second)
|
||||
go func() {
|
||||
<-timer.C
|
||||
if !n.HasSession(tailnetID, machineID) {
|
||||
n.NotifyAll(tailnetID)
|
||||
if !n.HasSession(machineID) {
|
||||
n.NotifyAll()
|
||||
}
|
||||
}()
|
||||
|
||||
n.timers[machineID] = timer
|
||||
}
|
||||
|
||||
func (n *pollMapSessionManager) HasSession(tailnetID uint64, machineID uint64) bool {
|
||||
n.RLock()
|
||||
defer n.RUnlock()
|
||||
|
||||
if ss := n.data[tailnetID]; ss != nil {
|
||||
if _, ok := ss[machineID]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (n *pollMapSessionManager) NotifyAll(tailnetID uint64, ignoreMachineIDs ...uint64) {
|
||||
n.RLock()
|
||||
defer n.RUnlock()
|
||||
|
||||
if ss := n.data[tailnetID]; ss != nil {
|
||||
for i, p := range ss {
|
||||
if !slices.Contains(ignoreMachineIDs, i) {
|
||||
p <- &Ping{}
|
||||
}
|
||||
}
|
||||
}
|
||||
func (n *tailnetSessionManager) HasSession(machineID uint64) bool {
|
||||
v, ok := n.sessions.Load(machineID)
|
||||
return ok && v
|
||||
}
|
||||
|
||||
@@ -131,19 +131,21 @@ func (h *PollNetMapHandler) handlePollNetMap(c echo.Context, m *domain.Machine,
|
||||
|
||||
defer func() {
|
||||
connectedDevices.WithLabelValues(m.Tailnet.Name).Dec()
|
||||
h.sessionManager.Deregister(m.TailnetID, m.ID)
|
||||
h.sessionManager.Deregister(m.TailnetID, m.ID, updateChan)
|
||||
keepAliveTicker.Stop()
|
||||
syncTicker.Stop()
|
||||
_ = h.repository.SetMachineLastSeen(ctx, machineID)
|
||||
}()
|
||||
|
||||
var latestSync = time.Now()
|
||||
var latestUpdate = latestSync
|
||||
var shouldUpdate bool = false
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-updateChan:
|
||||
latestUpdate = time.Now()
|
||||
case _, ok := <-updateChan:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
shouldUpdate = true
|
||||
case <-keepAliveTicker.C:
|
||||
if mapRequest.KeepAlive {
|
||||
if _, err := c.Response().Write(keepAliveResponse); err != nil {
|
||||
@@ -153,7 +155,7 @@ func (h *PollNetMapHandler) handlePollNetMap(c echo.Context, m *domain.Machine,
|
||||
c.Response().Flush()
|
||||
}
|
||||
case <-syncTicker.C:
|
||||
if latestSync.Before(latestUpdate) {
|
||||
if shouldUpdate {
|
||||
machine, err := h.repository.GetMachine(ctx, machineID)
|
||||
if err != nil {
|
||||
return logError(err)
|
||||
@@ -176,7 +178,7 @@ func (h *PollNetMapHandler) handlePollNetMap(c echo.Context, m *domain.Machine,
|
||||
}
|
||||
c.Response().Flush()
|
||||
|
||||
latestSync = latestUpdate
|
||||
shouldUpdate = false
|
||||
}
|
||||
case <-notify:
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user