fix: improve session management and update channels, avoiding potential deadlocks

This commit is contained in:
Johan Siebens
2025-03-06 16:14:05 +01:00
parent ed3e1eb54a
commit c1c708269c
4 changed files with 84 additions and 50 deletions
+72 -43
View File
@@ -1,6 +1,7 @@
package core
import (
"github.com/puzpuzpuz/xsync/v3"
"slices"
"sync"
"time"
@@ -9,35 +10,81 @@ import (
type Ping struct{}
type PollMapSessionManager interface {
Register(tailnetID uint64, machineID uint64, ch chan *Ping)
Deregister(tailnetID uint64, machineID uint64)
Register(tailnetID uint64, machineID uint64, ch chan<- *Ping)
Deregister(tailnetID uint64, machineID uint64, ch chan<- *Ping)
HasSession(tailnetID uint64, machineID uint64) bool
NotifyAll(tailnetID uint64, ignoreMachineIDs ...uint64)
}
func NewPollMapSessionManager() PollMapSessionManager {
return &pollMapSessionManager{
data: map[uint64]map[uint64]chan *Ping{},
timers: map[uint64]*time.Timer{},
tailnets: xsync.NewMapOf[uint64, *tailnetSessionManager](),
}
}
type pollMapSessionManager struct {
sync.RWMutex
data map[uint64]map[uint64]chan *Ping
timers map[uint64]*time.Timer
tailnets *xsync.MapOf[uint64, *tailnetSessionManager]
}
func (n *pollMapSessionManager) Register(tailnetID uint64, machineID uint64, ch chan *Ping) {
func (n *pollMapSessionManager) load(tailnetID uint64) *tailnetSessionManager {
m, _ := n.tailnets.LoadOrCompute(tailnetID, func() *tailnetSessionManager {
return &tailnetSessionManager{
targets: make(map[uint64]chan<- *Ping),
timers: make(map[uint64]*time.Timer),
sessions: xsync.NewMapOf[uint64, bool](),
}
})
return m
}
func (n *pollMapSessionManager) Register(tailnetID uint64, machineID uint64, ch chan<- *Ping) {
n.load(tailnetID).Register(machineID, ch)
}
func (n *pollMapSessionManager) Deregister(tailnetID uint64, machineID uint64, ch chan<- *Ping) {
n.load(tailnetID).Deregister(machineID, ch)
}
func (n *pollMapSessionManager) HasSession(tailnetID uint64, machineID uint64) bool {
return n.load(tailnetID).HasSession(machineID)
}
func (n *pollMapSessionManager) NotifyAll(tailnetID uint64, ignoreMachineIDs ...uint64) {
n.load(tailnetID).NotifyAll(ignoreMachineIDs...)
}
type tailnetSessionManager struct {
sync.RWMutex
targets map[uint64]chan<- *Ping
timers map[uint64]*time.Timer
sessions *xsync.MapOf[uint64, bool]
}
func (n *tailnetSessionManager) NotifyAll(ignoreMachineIDs ...uint64) {
n.RLock()
defer n.RUnlock()
for i, p := range n.targets {
if !slices.Contains(ignoreMachineIDs, i) {
select {
case p <- &Ping{}:
default: // ignore, channel has a small buffer, failing to insert means there is already a ping pending
}
}
}
}
func (n *tailnetSessionManager) Register(machineID uint64, ch chan<- *Ping) {
n.Lock()
defer n.Unlock()
if ss := n.data[tailnetID]; ss == nil {
n.data[tailnetID] = map[uint64]chan *Ping{machineID: ch}
} else {
ss[machineID] = ch
if curr, ok := n.targets[machineID]; ok {
close(curr)
}
n.targets[machineID] = ch
n.sessions.Store(machineID, true)
t, ok := n.timers[machineID]
if ok {
t.Stop()
@@ -47,22 +94,25 @@ func (n *pollMapSessionManager) Register(tailnetID uint64, machineID uint64, ch
timer := time.NewTimer(5 * time.Second)
go func() {
<-timer.C
if n.HasSession(tailnetID, machineID) {
n.NotifyAll(tailnetID, machineID)
if n.HasSession(machineID) {
n.NotifyAll(machineID)
}
}()
n.timers[machineID] = timer
}
func (n *pollMapSessionManager) Deregister(tailnetID uint64, machineID uint64) {
func (n *tailnetSessionManager) Deregister(machineID uint64, ch chan<- *Ping) {
n.Lock()
defer n.Unlock()
if ss := n.data[tailnetID]; ss != nil {
delete(ss, machineID)
if curr, ok := n.targets[machineID]; ok && curr != ch {
return
}
delete(n.targets, machineID)
n.sessions.Store(machineID, false)
t, ok := n.timers[machineID]
if ok {
t.Stop()
@@ -72,36 +122,15 @@ func (n *pollMapSessionManager) Deregister(tailnetID uint64, machineID uint64) {
timer := time.NewTimer(10 * time.Second)
go func() {
<-timer.C
if !n.HasSession(tailnetID, machineID) {
n.NotifyAll(tailnetID)
if !n.HasSession(machineID) {
n.NotifyAll()
}
}()
n.timers[machineID] = timer
}
func (n *pollMapSessionManager) HasSession(tailnetID uint64, machineID uint64) bool {
n.RLock()
defer n.RUnlock()
if ss := n.data[tailnetID]; ss != nil {
if _, ok := ss[machineID]; ok {
return true
}
}
return false
}
func (n *pollMapSessionManager) NotifyAll(tailnetID uint64, ignoreMachineIDs ...uint64) {
n.RLock()
defer n.RUnlock()
if ss := n.data[tailnetID]; ss != nil {
for i, p := range ss {
if !slices.Contains(ignoreMachineIDs, i) {
p <- &Ping{}
}
}
}
func (n *tailnetSessionManager) HasSession(machineID uint64) bool {
v, ok := n.sessions.Load(machineID)
return ok && v
}