// Files
// ionscale/internal/domain/machine.go
// T
// 2025-02-15 15:17:37 +01:00
//
// 537 lines
// 11 KiB
// Go

package domain
import (
	"context"
	"database/sql/driver"
	"encoding/json"
	"errors"
	"fmt"
	"net/netip"
	"sort"
	"time"

	"gorm.io/gorm"
	"gorm.io/gorm/schema"
	"tailscale.com/tailcfg"
)
// MachineRepository persists and queries Machine records.
//
// The gorm-backed *repository in this file returns (nil, nil) from the
// GetMachine* lookups when no matching row exists.
type MachineRepository interface {
	// Single-machine writes.
	SaveMachine(ctx context.Context, m *Machine) error
	DeleteMachine(ctx context.Context, id uint64) (bool, error)
	SetMachineLastSeen(ctx context.Context, machineID uint64) error

	// Single-machine lookups.
	GetMachine(ctx context.Context, id uint64) (*Machine, error)
	GetMachineByKeyAndUser(ctx context.Context, key string, userID uint64) (*Machine, error)
	GetMachineByKeys(ctx context.Context, machineKey string, nodeKey string) (*Machine, error)
	GetNextMachineNameIndex(ctx context.Context, tailnetID uint64, name string) (uint64, error)

	// Tailnet-wide queries.
	ListMachineByTailnet(ctx context.Context, tailnetID uint64) (Machines, error)
	ListMachinePeers(ctx context.Context, tailnetID uint64, machineID uint64) (Machines, error)
	CountMachineByTailnet(ctx context.Context, tailnetID uint64) (int64, error)
	CountMachinesWithIPv4(ctx context.Context, ip string) (int64, error)

	// Bulk deletes and housekeeping.
	DeleteMachineByTailnet(ctx context.Context, tailnetID uint64) error
	DeleteMachineByUser(ctx context.Context, userID uint64) error
	ListInactiveEphemeralMachines(ctx context.Context, checkpoint time.Time) (Machines, error)
}
// Machine is a node registered in a tailnet, persisted through gorm.
//
// Column types for the custom field types (IP, AllowIPs, HostInfo,
// Endpoints) come from the GormDataType/GormDBDataType methods defined on
// those types.
// NOTE(review): field order presumably drives column order when gorm
// migrates the table, so fields are intentionally left in place.
type Machine struct {
ID uint64 `gorm:"primary_key"`
// Name is the base name; CompleteName appends "-<NameIdx>" when NameIdx is non-zero.
Name string
NameIdx uint64
// Key strings; GetMachineByKeys looks a machine up by MachineKey and NodeKey together.
MachineKey string
NodeKey string
DiscoKey string
// Ephemeral machines are candidates for ListInactiveEphemeralMachines.
Ephemeral bool
// NOTE(review): RegisteredTags presumably holds the tags requested at
// registration — confirm; HasTag/HasTags only consult Tags.
RegisteredTags Tags
Tags Tags
// KeyExpiryDisabled makes IsExpired always report false.
KeyExpiryDisabled bool
Authorized bool
UseOSHostname bool `gorm:"default:true"`
// HostInfo is stored as JSON; its RoutableIPs feed IsAdvertisedExitNode and AdvertisedPrefixes.
HostInfo HostInfo
Endpoints Endpoints
// AllowIPs and AutoAllowIPs are both consulted by IsAllowedIP,
// IsAllowedIPPrefix, AllowedPrefixes and the exit-node checks.
AllowIPs AllowIPs
AutoAllowIPs AllowIPs
// Addresses assigned to the machine; HasIP compares against both.
IPv4 IP
IPv6 IP
CreatedAt time.Time
// ExpiresAt is ignored by IsExpired while it is the zero time.
ExpiresAt time.Time
// LastSeen is nullable; SetMachineLastSeen stores the current UTC time.
LastSeen *time.Time
// Owning user and tailnet; the repository queries preload or join these.
UserID uint64
User User
TailnetID uint64
Tailnet Tailnet
}
// Machines is a list of Machine values, as returned by the repository list queries.
type Machines []Machine
// CompleteName returns the machine's display name: Name on its own when
// NameIdx is zero, otherwise Name suffixed with "-<NameIdx>".
func (m *Machine) CompleteName() string {
	if m.NameIdx == 0 {
		return m.Name
	}
	return fmt.Sprintf("%s-%d", m.Name, m.NameIdx)
}
// IPs returns the machine's addresses as strings, IPv4 first, then IPv6.
// NOTE(review): both addresses are dereferenced, so this assumes neither is unset.
func (m *Machine) IPs() []string {
	v4 := m.IPv4.String()
	v6 := m.IPv6.String()
	return []string{v4, v6}
}
// IsExpired reports whether the machine's key has expired. It never does
// when key expiry is disabled or no expiry time has been set.
func (m *Machine) IsExpired() bool {
	if m.KeyExpiryDisabled || m.ExpiresAt.IsZero() {
		return false
	}
	return m.ExpiresAt.Before(time.Now())
}
// HasIP reports whether v equals the machine's IPv4 or IPv6 address.
//
// An unset address (nil Addr, which IP.Value explicitly supports) never
// matches. Dereferencing it directly would panic with a nil-pointer error.
func (m *Machine) HasIP(v netip.Addr) bool {
	if m.IPv4.Addr != nil && v.Compare(*m.IPv4.Addr) == 0 {
		return true
	}
	return m.IPv6.Addr != nil && v.Compare(*m.IPv6.Addr) == 0
}
// HasTag reports whether tag is one of the machine's Tags.
func (m *Machine) HasTag(tag string) bool {
	for i := range m.Tags {
		if m.Tags[i] == tag {
			return true
		}
	}
	return false
}
// HasUser reports whether the machine's owning user is named loginName.
func (m *Machine) HasUser(loginName string) bool {
	return loginName == m.User.Name
}
// HasTags reports whether the machine carries at least one tag.
func (m *Machine) HasTags() bool {
	return len(m.Tags) > 0
}
// IsAdvertisedExitNode reports whether the machine's host info advertises a
// default route (a /0 prefix) among its routable IPs.
func (m *Machine) IsAdvertisedExitNode() bool {
	routes := m.HostInfo.RoutableIPs
	for i := range routes {
		if routes[i].Bits() == 0 {
			return true
		}
	}
	return false
}
// IsAllowedExitNode reports whether a default route (a /0 prefix) appears in
// either AllowIPs or AutoAllowIPs, i.e. whether exit-node traffic is permitted.
func (m *Machine) IsAllowedExitNode() bool {
	for _, set := range []AllowIPs{m.AllowIPs, m.AutoAllowIPs} {
		for _, prefix := range set {
			if prefix.Bits() == 0 {
				return true
			}
		}
	}
	return false
}
// AdvertisedPrefixes returns the routable prefixes advertised in the host
// info, as strings, excluding default routes (/0). It returns nil when none
// remain.
func (m *Machine) AdvertisedPrefixes() []string {
	var prefixes []string
	for _, route := range m.HostInfo.RoutableIPs {
		if route.Bits() == 0 {
			continue // default routes are covered by IsAdvertisedExitNode
		}
		prefixes = append(prefixes, route.String())
	}
	return prefixes
}
// AllowedPrefixes returns the de-duplicated, non-default (/0 excluded)
// prefixes from AllowIPs and AutoAllowIPs, as produced by StringSet.Items.
func (m *Machine) AllowedPrefixes() []string {
	seen := StringSet{}
	for _, set := range []AllowIPs{m.AllowIPs, m.AutoAllowIPs} {
		for _, prefix := range set {
			if prefix.Bits() != 0 {
				seen.Add(prefix.String())
			}
		}
	}
	return seen.Items()
}
// IsAllowedIP reports whether traffic for address i may be routed to this
// machine: either i is one of its own addresses, or it falls inside a prefix
// from AllowIPs or AutoAllowIPs.
func (m *Machine) IsAllowedIP(i netip.Addr) bool {
	if m.HasIP(i) {
		return true
	}
	for _, set := range []AllowIPs{m.AllowIPs, m.AutoAllowIPs} {
		for _, prefix := range set {
			if prefix.Contains(i) {
				return true
			}
		}
	}
	return false
}
// IsAllowedIPPrefix reports whether prefix i overlaps any prefix in AllowIPs
// or AutoAllowIPs.
func (m *Machine) IsAllowedIPPrefix(i netip.Prefix) bool {
	for _, set := range []AllowIPs{m.AllowIPs, m.AutoAllowIPs} {
		for _, allowed := range set {
			if allowed.Overlaps(i) {
				return true
			}
		}
	}
	return false
}
// IsExitNode reports whether the machine may act as an exit node, i.e. whether
// AllowIPs or AutoAllowIPs contains a default route (a /0 prefix).
//
// It is the same check as IsAllowedExitNode. It delegates instead of
// duplicating the loops, so the two cannot drift apart.
func (m *Machine) IsExitNode() bool {
	return m.IsAllowedExitNode()
}
// IP wraps a *netip.Addr so it can be stored as a text column.
// A nil Addr means "no address" and is written as NULL by Value.
type IP struct {
	*netip.Addr
}

// Scan implements sql.Scanner by parsing the textual address written by Value.
//
// It accepts string or []byte column values. A NULL column (nil) resets i to
// the empty IP, mirroring Value, which writes NULL for it. Previously NULL was
// rejected, so a value written by Value could not be read back.
func (i *IP) Scan(destination interface{}) error {
	var text string
	switch value := destination.(type) {
	case nil:
		*i = IP{}
		return nil
	case string:
		text = value
	case []byte:
		text = string(value)
	default:
		return fmt.Errorf("unexpected data type %T", destination)
	}

	addr, err := netip.ParseAddr(text)
	if err != nil {
		return err
	}
	*i = IP{&addr}
	return nil
}

// Value implements driver.Valuer. It returns the address in string form,
// or NULL when Addr is nil.
func (i IP) Value() (driver.Value, error) {
	if i.Addr == nil {
		return nil, nil
	}
	return i.String(), nil
}
// GormDBDataType returns the column type gorm should use for IP: "TEXT" on
// the postgres dialect and the empty string for every other dialect.
func (IP) GormDBDataType(db *gorm.DB, field *schema.Field) string {
	if db.Dialector.Name() == "postgres" {
		return "TEXT"
	}
	return ""
}
// AllowIPs is a list of IP prefixes stored as a JSON array column.
type AllowIPs []netip.Prefix

// AllowIPsSet is a de-duplicating collection of prefixes.
// The zero value is ready to use.
type AllowIPsSet struct {
	items map[netip.Prefix]bool
}

// NewAllowIPsSet returns a set seeded with the prefixes in t.
func NewAllowIPsSet(t AllowIPs) *AllowIPsSet {
	s := &AllowIPsSet{}
	return s.Add(t...)
}

// Add inserts the given prefixes and returns s to allow chaining.
func (s *AllowIPsSet) Add(t ...netip.Prefix) *AllowIPsSet {
	if s.items == nil {
		s.items = make(map[netip.Prefix]bool, len(t))
	}
	for _, v := range t {
		s.items[v] = true
	}
	return s
}

// Remove deletes the given prefixes, ignoring any that are absent, and
// returns s to allow chaining.
func (s *AllowIPsSet) Remove(t ...netip.Prefix) *AllowIPsSet {
	for _, v := range t {
		delete(s.items, v) // delete on a nil map is a no-op
	}
	return s
}

// Items returns the prefixes in a deterministic order: by address (IPv4
// before IPv6, per netip.Addr.Compare), then by prefix length.
//
// Map iteration order is random, and the result is typically persisted or
// displayed, so it is sorted to keep output stable. The result is never nil,
// so an empty set JSON-encodes as [] rather than null.
func (s *AllowIPsSet) Items() []netip.Prefix {
	items := make([]netip.Prefix, 0, len(s.items))
	for p := range s.items {
		items = append(items, p)
	}
	sort.Slice(items, func(a, b int) bool {
		if c := items[a].Addr().Compare(items[b].Addr()); c != 0 {
			return c < 0
		}
		return items[a].Bits() < items[b].Bits()
	})
	return items
}

// Scan implements sql.Scanner by decoding the JSON array written by Value.
//
// Some drivers return the column as string and others as []byte; both are
// accepted. A NULL column (nil) yields a nil AllowIPs instead of failing the
// whole row scan.
func (hi *AllowIPs) Scan(destination interface{}) error {
	switch value := destination.(type) {
	case nil:
		*hi = nil
		return nil
	case []byte:
		return json.Unmarshal(value, hi)
	case string:
		return json.Unmarshal([]byte(value), hi)
	default:
		return fmt.Errorf("unexpected data type %T", destination)
	}
}

// Value implements driver.Valuer by encoding the prefixes as a JSON array.
// A nil AllowIPs encodes as JSON null.
func (hi AllowIPs) Value() (driver.Value, error) {
	bytes, err := json.Marshal(hi)
	if err != nil {
		return nil, err
	}
	return bytes, nil
}
// GormDataType reports the generic gorm data type of AllowIPs: "json".
func (AllowIPs) GormDataType() string {
	const dataType = "json"
	return dataType
}

// GormDBDataType reports the dialect-specific column type of AllowIPs:
// "JSON" on sqlite and the empty string for every other dialect.
func (AllowIPs) GormDBDataType(db *gorm.DB, field *schema.Field) string {
	if name := db.Dialector.Name(); name == "sqlite" {
		return "JSON"
	}
	return ""
}
// HostInfo is tailcfg.Hostinfo stored as a JSON document column.
type HostInfo tailcfg.Hostinfo

// Scan implements sql.Scanner by decoding the JSON document written by Value.
//
// Some drivers return the column as string and others as []byte; both are
// accepted. A NULL column (nil) resets hi to the zero HostInfo instead of
// failing the whole row scan.
func (hi *HostInfo) Scan(destination interface{}) error {
	switch value := destination.(type) {
	case nil:
		*hi = HostInfo{}
		return nil
	case []byte:
		return json.Unmarshal(value, hi)
	case string:
		return json.Unmarshal([]byte(value), hi)
	default:
		return fmt.Errorf("unexpected data type %T", destination)
	}
}

// Value implements driver.Valuer by encoding hi as JSON.
func (hi HostInfo) Value() (driver.Value, error) {
	bytes, err := json.Marshal(hi)
	if err != nil {
		return nil, err
	}
	return bytes, nil
}
// GormDataType reports the generic gorm data type of HostInfo: "json".
func (HostInfo) GormDataType() string {
	const dataType = "json"
	return dataType
}

// GormDBDataType reports the dialect-specific column type of HostInfo:
// "JSON" on sqlite and the empty string for every other dialect.
func (HostInfo) GormDBDataType(db *gorm.DB, field *schema.Field) string {
	if name := db.Dialector.Name(); name == "sqlite" {
		return "JSON"
	}
	return ""
}
// Endpoints is a list of address/port pairs stored as a JSON array column.
type Endpoints []netip.AddrPort

// Scan implements sql.Scanner by decoding the JSON array written by Value.
//
// Some drivers return the column as string and others as []byte; both are
// accepted. A NULL column (nil) yields nil Endpoints instead of failing the
// whole row scan.
func (hi *Endpoints) Scan(destination interface{}) error {
	switch value := destination.(type) {
	case nil:
		*hi = nil
		return nil
	case []byte:
		return json.Unmarshal(value, hi)
	case string:
		return json.Unmarshal([]byte(value), hi)
	default:
		return fmt.Errorf("unexpected data type %T", destination)
	}
}

// Value implements driver.Valuer by encoding the endpoints as a JSON array.
// nil Endpoints encode as JSON null.
func (hi Endpoints) Value() (driver.Value, error) {
	bytes, err := json.Marshal(hi)
	if err != nil {
		return nil, err
	}
	return bytes, nil
}
// GormDataType reports the generic gorm data type of Endpoints: "json".
func (Endpoints) GormDataType() string {
	const dataType = "json"
	return dataType
}

// GormDBDataType reports the dialect-specific column type of Endpoints:
// "JSON" on sqlite and the empty string for every other dialect.
func (Endpoints) GormDBDataType(db *gorm.DB, field *schema.Field) string {
	if name := db.Dialector.Name(); name == "sqlite" {
		return "JSON"
	}
	return ""
}
// SaveMachine persists machine using gorm's Save and returns any database error.
func (r *repository) SaveMachine(ctx context.Context, machine *Machine) error {
	return r.withContext(ctx).Save(machine).Error
}
// DeleteMachine deletes the machine with the given primary key. The boolean
// reports whether exactly one row was affected.
func (r *repository) DeleteMachine(ctx context.Context, id uint64) (bool, error) {
	result := r.withContext(ctx).Delete(&Machine{}, id)
	deleted := result.RowsAffected == 1
	return deleted, result.Error
}
// GetMachine loads a machine by primary key, preloading its Tailnet, User and
// the user's Account. It returns (nil, nil) when no such machine exists.
func (r *repository) GetMachine(ctx context.Context, machineID uint64) (*Machine, error) {
	var m Machine
	err := r.withContext(ctx).
		Preload("Tailnet").
		Preload("User").
		Preload("User.Account").
		Take(&m, machineID).Error
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, nil
	case err != nil:
		return nil, err
	}
	return &m, nil
}
// GetNextMachineNameIndex returns the NameIdx to give a new machine called
// name in the tailnet. It returns 0 when no machine with that name exists yet,
// otherwise one more than the highest existing NameIdx.
func (r *repository) GetNextMachineNameIndex(ctx context.Context, tailnetID uint64, name string) (uint64, error) {
	var last Machine
	err := r.withContext(ctx).
		Where("name = ? AND tailnet_id = ?", name, tailnetID).
		Order("name_idx desc").
		Take(&last).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return 0, nil
	}
	if err != nil {
		return 0, err
	}
	return last.NameIdx + 1, nil
}
// GetMachineByKeyAndUser finds the machine with the given machine key owned by
// userID, preloading Tailnet and User. It returns (nil, nil) when none matches.
func (r *repository) GetMachineByKeyAndUser(ctx context.Context, machineKey string, userID uint64) (*Machine, error) {
	var m Machine
	err := r.withContext(ctx).
		Preload("Tailnet").
		Preload("User").
		Take(&m, "machine_key = ? AND user_id = ?", machineKey, userID).Error
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, nil
	case err != nil:
		return nil, err
	}
	return &m, nil
}
// GetMachineByKeys finds the machine matching both machineKey and nodeKey,
// preloading Tailnet and User. It returns (nil, nil) when none matches.
func (r *repository) GetMachineByKeys(ctx context.Context, machineKey string, nodeKey string) (*Machine, error) {
	var m Machine
	err := r.withContext(ctx).
		Preload("Tailnet").
		Preload("User").
		Take(&m, "machine_key = ? AND node_key = ?", machineKey, nodeKey).Error
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, nil
	case err != nil:
		return nil, err
	}
	return &m, nil
}
// CountMachinesWithIPv4 returns how many machines have ip stored in the ipv4 column.
func (r *repository) CountMachinesWithIPv4(ctx context.Context, ip string) (int64, error) {
	var n int64
	if err := r.withContext(ctx).Model(&Machine{}).Where("ipv4 = ?", ip).Count(&n).Error; err != nil {
		return 0, err
	}
	return n, nil
}
// CountMachineByTailnet returns the number of machines in the given tailnet.
func (r *repository) CountMachineByTailnet(ctx context.Context, tailnetID uint64) (int64, error) {
	var total int64
	err := r.withContext(ctx).Model(&Machine{}).Where("tailnet_id = ?", tailnetID).Count(&total).Error
	if err != nil {
		return 0, err
	}
	return total, nil
}
// DeleteMachineByTailnet deletes every machine belonging to the given tailnet.
func (r *repository) DeleteMachineByTailnet(ctx context.Context, tailnetID uint64) error {
	return r.withContext(ctx).
		Model(&Machine{}).
		Where("tailnet_id = ?", tailnetID).
		Delete(&Machine{}).
		Error
}
// DeleteMachineByUser deletes every machine owned by the given user.
func (r *repository) DeleteMachineByUser(ctx context.Context, userID uint64) error {
	return r.withContext(ctx).
		Model(&Machine{}).
		Where("user_id = ?", userID).
		Delete(&Machine{}).
		Error
}
// ListMachineByTailnet returns all machines in the tailnet, ordered by name and
// then name index. Tailnet is preloaded; User and User.Account are joined.
func (r *repository) ListMachineByTailnet(ctx context.Context, tailnetID uint64) (Machines, error) {
	machines := []Machine{}
	query := r.withContext(ctx).
		Preload("Tailnet").
		Joins("User").
		Joins("User.Account").
		Where("machines.tailnet_id = ?", tailnetID).
		Order("machines.name asc, machines.name_idx asc")
	if err := query.Find(&machines).Error; err != nil {
		return nil, err
	}
	return machines, nil
}
// ListMachinePeers returns every other machine in the tailnet, i.e. all except
// machineID, ordered by id. Tailnet is preloaded; User and User.Account are joined.
func (r *repository) ListMachinePeers(ctx context.Context, tailnetID uint64, machineID uint64) (Machines, error) {
	var peers []Machine
	query := r.withContext(ctx).
		Preload("Tailnet").
		Joins("User").
		Joins("User.Account").
		Where("machines.tailnet_id = ? AND machines.id <> ?", tailnetID, machineID).
		Order("machines.id asc")
	if err := query.Find(&peers).Error; err != nil {
		return nil, err
	}
	return peers, nil
}
// ListInactiveEphemeralMachines returns ephemeral machines whose last_seen is
// strictly before t, compared in UTC.
//
// In SQL, a NULL last_seen does not satisfy "<", so machines whose last_seen
// is NULL are not returned.
func (r *repository) ListInactiveEphemeralMachines(ctx context.Context, t time.Time) (Machines, error) {
	machines := []Machine{}
	cutoff := t.UTC()
	err := r.withContext(ctx).
		Where("ephemeral = ? AND last_seen < ?", true, cutoff).
		Find(&machines).Error
	if err != nil {
		return nil, err
	}
	return machines, nil
}
// SetMachineLastSeen sets last_seen of the given machine to the current UTC time.
func (r *repository) SetMachineLastSeen(ctx context.Context, machineID uint64) error {
	now := time.Now().UTC()
	values := map[string]interface{}{"last_seen": &now}
	return r.withContext(ctx).
		Model(Machine{}).
		Where("id = ?", machineID).
		Updates(values).
		Error
}