mirror of
https://github.com/jsiebens/ionscale.git
synced 2026-03-31 15:07:49 +01:00
391 lines
8.1 KiB
Go
391 lines
8.1 KiB
Go
package domain
|
|
|
|
import (
|
|
"context"
|
|
"database/sql/driver"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/schema"
|
|
"net/netip"
|
|
"tailscale.com/tailcfg"
|
|
"time"
|
|
)
|
|
|
|
// Machine is the GORM model for a machine (node) that belongs to a tailnet
// and is owned by a user.
type Machine struct {
	// ID is the primary key. Auto-increment is switched off in the GORM tag,
	// so the value has to be provided when the row is created.
	ID uint64 `gorm:"primary_key;autoIncrement:false"`
	// Name together with NameIdx identifies a machine inside a tailnet;
	// GetNextMachineNameIndex hands out the next NameIdx for a given Name.
	Name    string
	NameIdx uint64
	// Keys are stored as strings; MachineKey and NodeKey are used as lookup
	// columns by the repository (machine_key, node_key).
	MachineKey string
	NodeKey    string
	DiscoKey   string
	// Ephemeral machines are selected by ListInactiveEphemeralMachines once
	// their LastSeen is older than a given time.
	Ephemeral      bool
	RegisteredTags Tags
	// Tags is the list consulted by HasTag and HasTags.
	Tags Tags
	// KeyExpiryDisabled makes IsExpired always report false.
	KeyExpiryDisabled bool

	// HostInfo, Endpoints and AllowIPs are persisted as JSON documents.
	HostInfo  HostInfo
	Endpoints Endpoints
	AllowIPs  AllowIPs

	// IPv4 and IPv6 wrap a *netip.Addr; a nil address is stored as NULL.
	IPv4 IP
	IPv6 IP

	CreatedAt time.Time
	// ExpiresAt is ignored by IsExpired when it is the zero time.
	ExpiresAt time.Time
	// LastSeen is a pointer so that it can be absent; it is updated by
	// SetMachineLastSeen.
	LastSeen *time.Time

	// UserID/User: owning user; the repository preloads the association.
	UserID uint64
	User   User

	// TailnetID/Tailnet: owning tailnet; the repository preloads the
	// association.
	TailnetID uint64
	Tailnet   Tailnet
}
|
|
|
|
// Machines is a list of Machine values, as returned by the repository's
// List* queries.
type Machines []Machine
|
|
|
|
func (m *Machine) IsExpired() bool {
|
|
return !m.KeyExpiryDisabled && !m.ExpiresAt.IsZero() && m.ExpiresAt.Before(time.Now())
|
|
}
|
|
|
|
func (m *Machine) HasIP(v netip.Addr) bool {
|
|
return v.Compare(*m.IPv4.Addr) == 0 || v.Compare(*m.IPv6.Addr) == 0
|
|
}
|
|
|
|
func (m *Machine) HasTag(tag string) bool {
|
|
for _, t := range m.Tags {
|
|
if t == tag {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (m *Machine) HasUser(loginName string) bool {
|
|
return m.User.Name == loginName
|
|
}
|
|
|
|
func (m *Machine) HasTags() bool {
|
|
return len(m.Tags) != 0
|
|
}
|
|
|
|
func (m *Machine) IsAllowedIP(i netip.Addr) bool {
|
|
if m.HasIP(i) {
|
|
return true
|
|
}
|
|
for _, t := range m.AllowIPs {
|
|
if t.Contains(i) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (m *Machine) IsAllowedIPPrefix(i netip.Prefix) bool {
|
|
for _, t := range m.AllowIPs {
|
|
if t.Overlaps(i) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// IP wraps an optional netip.Addr so that it can be stored in a database
// column. A nil Addr represents "no address" and is persisted as NULL.
type IP struct {
	*netip.Addr
}

// Scan implements sql.Scanner. It accepts the textual form of an address as
// either a string or a []byte (drivers differ in which one they return for
// text columns), and a nil value, which resets the receiver to the "no
// address" state so that whatever Value wrote can be read back.
//
// It returns an error if the text cannot be parsed as an IP address or if the
// value has any other type.
func (i *IP) Scan(destination interface{}) error {
	var text string

	switch value := destination.(type) {
	case nil:
		// NULL column: mirror what Value produces for a nil Addr.
		*i = IP{}
		return nil
	case string:
		text = value
	case []byte:
		text = string(value)
	default:
		return fmt.Errorf("unexpected data type %T", destination)
	}

	ip, err := netip.ParseAddr(text)
	if err != nil {
		return err
	}
	*i = IP{Addr: &ip}
	return nil
}

// Value implements driver.Valuer. A nil Addr is stored as NULL; otherwise the
// address is stored in its string form.
func (i IP) Value() (driver.Value, error) {
	if i.Addr == nil {
		return nil, nil
	}
	return i.String(), nil
}
|
|
|
|
func (IP) GormDBDataType(db *gorm.DB, field *schema.Field) string {
|
|
switch db.Dialector.Name() {
|
|
case "postgres":
|
|
return "TEXT"
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// AllowIPs is a list of IP prefixes that is persisted as a JSON document.
type AllowIPs []netip.Prefix

// Scan implements sql.Scanner by decoding a JSON document supplied as
// []byte into the receiver. Any other source type is rejected.
func (a *AllowIPs) Scan(destination interface{}) error {
	raw, ok := destination.([]byte)
	if !ok {
		return fmt.Errorf("unexpected data type %T", destination)
	}
	return json.Unmarshal(raw, a)
}

// Value implements driver.Valuer by encoding the list as JSON bytes.
func (a AllowIPs) Value() (driver.Value, error) {
	return json.Marshal(a)
}

// GormDataType returns the generic GORM data type for this column.
func (AllowIPs) GormDataType() string {
	const dataType = "json"
	return dataType
}
|
|
|
|
// GormDBDataType gorm db data type
|
|
func (AllowIPs) GormDBDataType(db *gorm.DB, field *schema.Field) string {
|
|
switch db.Dialector.Name() {
|
|
case "sqlite":
|
|
return "JSON"
|
|
}
|
|
return ""
|
|
}
|
|
|
|
type HostInfo tailcfg.Hostinfo
|
|
|
|
func (hi *HostInfo) Scan(destination interface{}) error {
|
|
switch value := destination.(type) {
|
|
case []byte:
|
|
return json.Unmarshal(value, hi)
|
|
default:
|
|
return fmt.Errorf("unexpected data type %T", destination)
|
|
}
|
|
}
|
|
|
|
func (hi HostInfo) Value() (driver.Value, error) {
|
|
bytes, err := json.Marshal(hi)
|
|
return bytes, err
|
|
}
|
|
|
|
// GormDataType gorm common data type
|
|
func (HostInfo) GormDataType() string {
|
|
return "json"
|
|
}
|
|
|
|
// GormDBDataType gorm db data type
|
|
func (HostInfo) GormDBDataType(db *gorm.DB, field *schema.Field) string {
|
|
switch db.Dialector.Name() {
|
|
case "sqlite":
|
|
return "JSON"
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// Endpoints is a list of endpoint strings that is persisted as a JSON
// document.
type Endpoints []string

// Scan implements sql.Scanner by decoding a JSON document supplied as
// []byte into the receiver. Any other source type is rejected.
func (e *Endpoints) Scan(destination interface{}) error {
	raw, ok := destination.([]byte)
	if !ok {
		return fmt.Errorf("unexpected data type %T", destination)
	}
	return json.Unmarshal(raw, e)
}

// Value implements driver.Valuer by encoding the list as JSON bytes.
func (e Endpoints) Value() (driver.Value, error) {
	return json.Marshal(e)
}

// GormDataType returns the generic GORM data type for this column.
func (Endpoints) GormDataType() string {
	const dataType = "json"
	return dataType
}
|
|
|
|
// GormDBDataType gorm db data type
|
|
func (Endpoints) GormDBDataType(db *gorm.DB, field *schema.Field) string {
|
|
switch db.Dialector.Name() {
|
|
case "sqlite":
|
|
return "JSON"
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (r *repository) SaveMachine(ctx context.Context, machine *Machine) error {
|
|
tx := r.withContext(ctx).Save(machine)
|
|
|
|
if tx.Error != nil {
|
|
return tx.Error
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *repository) DeleteMachine(ctx context.Context, id uint64) (bool, error) {
|
|
tx := r.withContext(ctx).Delete(&Machine{}, id)
|
|
return tx.RowsAffected == 1, tx.Error
|
|
}
|
|
|
|
func (r *repository) GetMachine(ctx context.Context, machineID uint64) (*Machine, error) {
|
|
var m Machine
|
|
tx := r.withContext(ctx).Preload("Tailnet").Preload("User").First(&m, "id = ?", machineID)
|
|
|
|
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
|
|
if tx.Error != nil {
|
|
return nil, tx.Error
|
|
}
|
|
|
|
return &m, nil
|
|
}
|
|
|
|
func (r *repository) GetNextMachineNameIndex(ctx context.Context, tailnetID uint64, name string) (uint64, error) {
|
|
var m Machine
|
|
|
|
tx := r.withContext(ctx).
|
|
Where("name = ? AND tailnet_id = ?", name, tailnetID).
|
|
Order("name_idx desc").
|
|
First(&m)
|
|
|
|
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
|
|
return 0, nil
|
|
}
|
|
|
|
if tx.Error != nil {
|
|
return 0, tx.Error
|
|
}
|
|
|
|
return m.NameIdx + 1, nil
|
|
}
|
|
|
|
func (r *repository) GetMachineByKey(ctx context.Context, tailnetID uint64, machineKey string) (*Machine, error) {
|
|
var m Machine
|
|
tx := r.withContext(ctx).Preload("Tailnet").Preload("User").First(&m, "tailnet_id = ? AND machine_key = ?", tailnetID, machineKey)
|
|
|
|
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
|
|
if tx.Error != nil {
|
|
return nil, tx.Error
|
|
}
|
|
|
|
return &m, nil
|
|
}
|
|
|
|
func (r *repository) GetMachineByKeys(ctx context.Context, machineKey string, nodeKey string) (*Machine, error) {
|
|
var m Machine
|
|
tx := r.withContext(ctx).Preload("Tailnet").Preload("User").First(&m, "machine_key = ? AND node_key = ?", machineKey, nodeKey)
|
|
|
|
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
|
|
if tx.Error != nil {
|
|
return nil, tx.Error
|
|
}
|
|
|
|
return &m, nil
|
|
}
|
|
|
|
func (r *repository) CountMachinesWithIPv4(ctx context.Context, ip string) (int64, error) {
|
|
var count int64
|
|
|
|
tx := r.withContext(ctx).Model(&Machine{}).Where("ipv4 = ?", ip).Count(&count)
|
|
|
|
if tx.Error != nil {
|
|
return 0, tx.Error
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
func (r *repository) CountMachineByTailnet(ctx context.Context, tailnetID uint64) (int64, error) {
|
|
var count int64
|
|
|
|
tx := r.withContext(ctx).Model(&Machine{}).Where("tailnet_id = ?", tailnetID).Count(&count)
|
|
|
|
if tx.Error != nil {
|
|
return 0, tx.Error
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
func (r *repository) DeleteMachineByTailnet(ctx context.Context, tailnetID uint64) error {
|
|
tx := r.withContext(ctx).Model(&Machine{}).Where("tailnet_id = ?", tailnetID).Delete(&Machine{})
|
|
return tx.Error
|
|
}
|
|
|
|
func (r *repository) DeleteMachineByUser(ctx context.Context, userID uint64) error {
|
|
tx := r.withContext(ctx).Model(&Machine{}).Where("user_id = ?", userID).Delete(&Machine{})
|
|
return tx.Error
|
|
}
|
|
|
|
func (r *repository) ListMachineByTailnet(ctx context.Context, tailnetID uint64) (Machines, error) {
|
|
var machines = []Machine{}
|
|
|
|
tx := r.withContext(ctx).
|
|
Preload("Tailnet").
|
|
Preload("User").
|
|
Where("tailnet_id = ?", tailnetID).
|
|
Order("name asc, name_idx asc").
|
|
Find(&machines)
|
|
|
|
if tx.Error != nil {
|
|
return nil, tx.Error
|
|
}
|
|
|
|
return machines, nil
|
|
}
|
|
|
|
func (r *repository) ListMachinePeers(ctx context.Context, tailnetID uint64, key string) (Machines, error) {
|
|
var machines = []Machine{}
|
|
|
|
tx := r.withContext(ctx).
|
|
Preload("Tailnet").
|
|
Preload("User").
|
|
Where("tailnet_id = ? AND machine_key <> ?", tailnetID, key).
|
|
Order("id asc").
|
|
Find(&machines)
|
|
|
|
if tx.Error != nil {
|
|
return nil, tx.Error
|
|
}
|
|
|
|
return machines, nil
|
|
}
|
|
|
|
func (r *repository) ListInactiveEphemeralMachines(ctx context.Context, t time.Time) (Machines, error) {
|
|
var machines = []Machine{}
|
|
|
|
tx := r.withContext(ctx).
|
|
Where("ephemeral = ? AND last_seen < ?", true, t.UTC()).
|
|
Find(&machines)
|
|
|
|
if tx.Error != nil {
|
|
return nil, tx.Error
|
|
}
|
|
|
|
return machines, nil
|
|
}
|
|
|
|
func (r *repository) SetMachineLastSeen(ctx context.Context, machineID uint64) error {
|
|
now := time.Now().UTC()
|
|
tx := r.withContext(ctx).
|
|
Model(Machine{}).
|
|
Where("id = ?", machineID).
|
|
Updates(map[string]interface{}{"last_seen": &now})
|
|
|
|
if tx.Error != nil {
|
|
return tx.Error
|
|
}
|
|
|
|
return nil
|
|
}
|