Files
ionscale/internal/domain/machine.go
T
Johan Siebens 5ad89ff02f initial working version
Signed-off-by: Johan Siebens <johan.siebens@gmail.com>
2022-05-09 21:54:06 +02:00

254 lines
5.3 KiB
Go

package domain
import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"tailscale.com/tailcfg"
"time"
)
// Machine is the persisted record of a single machine in a tailnet.
//
// NOTE(review): the per-field notes below are derived from how the queries in
// this file use the columns; semantics beyond that are unconfirmed.
type Machine struct {
	// ID is the primary key. autoIncrement is disabled, so the value is
	// presumably assigned by the application before insert — confirm.
	ID uint64 `gorm:"primary_key;autoIncrement:false"`
	// Name is the machine name. NameIdx disambiguates machines sharing the
	// same Name within one tailnet (see GetNextMachineNameIndex).
	Name    string
	NameIdx uint64
	// MachineKey, NodeKey and DiscoKey are stored as plain strings, presumably
	// the text encoding of the corresponding keys — confirm. Lookups in this
	// file match on machine_key and node_key.
	MachineKey string
	NodeKey    string
	DiscoKey   string
	// Ephemeral selects machines considered by ListInactiveEphemeralMachines.
	Ephemeral bool
	// NOTE(review): the distinction between RegisteredTags and Tags is not
	// visible in this file; the Tags type is declared elsewhere.
	RegisteredTags Tags
	Tags           Tags
	// HostInfo and Endpoints are stored as JSON columns through their
	// sql.Scanner / driver.Valuer implementations.
	HostInfo  HostInfo
	Endpoints Endpoints
	// IPv4 and IPv6 are stored as strings; CountMachinesWithIPv4 matches on ipv4.
	IPv4      string
	IPv6      string
	CreatedAt time.Time
	// ExpiresAt may be nil; what a nil value means is not visible here.
	ExpiresAt *time.Time
	// LastSeen may be nil; SetMachineLastSeen writes the current UTC time.
	LastSeen *time.Time
	// UserID/User and TailnetID/Tailnet are associations; most Get*/List*
	// queries in this file preload User and Tailnet.
	UserID    uint64
	User      User
	TailnetID uint64
	Tailnet   Tailnet
}

// Machines is a slice of Machine values, as returned by the List* queries.
type Machines []Machine
// HostInfo is tailcfg.Hostinfo persisted as a JSON column. It is a defined
// type rather than an alias so it can carry the sql.Scanner, driver.Valuer
// and GORM data-type methods.
type HostInfo tailcfg.Hostinfo

// Scan implements sql.Scanner by decoding a JSON document read from the
// database into hi.
//
// Drivers may deliver the column as []byte or string, so both are accepted.
// A SQL NULL (nil) resets hi to its zero value instead of failing the row.
// Decoding goes into a temporary first, so hi is left untouched when the
// document is invalid.
func (hi *HostInfo) Scan(src interface{}) error {
	var data []byte
	switch value := src.(type) {
	case nil:
		*hi = HostInfo{}
		return nil
	case []byte:
		data = value
	case string:
		data = []byte(value)
	default:
		return fmt.Errorf("unexpected data type %T", src)
	}

	var decoded HostInfo
	if err := json.Unmarshal(data, &decoded); err != nil {
		return err
	}
	*hi = decoded
	return nil
}

// Value implements driver.Valuer by encoding hi as a JSON document ([]byte).
func (hi HostInfo) Value() (driver.Value, error) {
	data, err := json.Marshal(hi)
	if err != nil {
		// Return an untyped nil rather than a nil []byte boxed in an interface.
		return nil, err
	}
	return data, nil
}
// GormDataType returns "json", the generic GORM data type name used for
// HostInfo columns.
func (HostInfo) GormDataType() string {
	const dataType = "json"
	return dataType
}
// GormDBDataType returns the column type to use for HostInfo on the
// connected dialect: "JSON" for sqlite, and an empty string for every other
// dialect.
func (HostInfo) GormDBDataType(db *gorm.DB, field *schema.Field) string {
	if db.Dialector.Name() == "sqlite" {
		return "JSON"
	}
	return ""
}
// Endpoints is a list of endpoint strings for a machine, persisted as a JSON
// array column.
type Endpoints []string

// Scan implements sql.Scanner by decoding a JSON array read from the
// database into e.
//
// Drivers may deliver the column as []byte or string, so both are accepted.
// A SQL NULL (nil) sets e to a nil slice instead of failing the row.
// Decoding goes into a temporary first, so e is left untouched when the
// document is invalid (a plain json.Unmarshal into e would already have
// overwritten some elements before reporting a type error).
func (e *Endpoints) Scan(src interface{}) error {
	var data []byte
	switch value := src.(type) {
	case nil:
		*e = nil
		return nil
	case []byte:
		data = value
	case string:
		data = []byte(value)
	default:
		return fmt.Errorf("unexpected data type %T", src)
	}

	var decoded Endpoints
	if err := json.Unmarshal(data, &decoded); err != nil {
		return err
	}
	*e = decoded
	return nil
}

// Value implements driver.Valuer by encoding e as a JSON array ([]byte).
// A nil Endpoints encodes as the JSON literal null.
func (e Endpoints) Value() (driver.Value, error) {
	data, err := json.Marshal(e)
	if err != nil {
		// Return an untyped nil rather than a nil []byte boxed in an interface.
		return nil, err
	}
	return data, nil
}
// GormDataType returns "json", the generic GORM data type name used for
// Endpoints columns.
func (Endpoints) GormDataType() string {
	const dataType = "json"
	return dataType
}
// GormDBDataType returns the column type to use for Endpoints on the
// connected dialect: "JSON" for sqlite, and an empty string for every other
// dialect.
func (Endpoints) GormDBDataType(db *gorm.DB, field *schema.Field) string {
	if db.Dialector.Name() == "sqlite" {
		return "JSON"
	}
	return ""
}
// SaveMachine persists machine using GORM's Save and returns the database
// error, if any.
func (r *repository) SaveMachine(ctx context.Context, machine *Machine) error {
	return r.withContext(ctx).Save(machine).Error
}
// DeleteMachine deletes the machine with the given primary key. The boolean
// result reports whether exactly one row was affected; it is returned
// together with any database error.
func (r *repository) DeleteMachine(ctx context.Context, id uint64) (bool, error) {
	result := r.withContext(ctx).Delete(&Machine{}, id)
	deleted := result.RowsAffected == 1
	return deleted, result.Error
}
// GetMachine loads the machine with the given ID, preloading its Tailnet and
// User associations. A missing machine is reported as (nil, nil), not as an
// error.
func (r *repository) GetMachine(ctx context.Context, machineID uint64) (*Machine, error) {
	var machine Machine
	err := r.withContext(ctx).
		Preload("Tailnet").
		Preload("User").
		First(&machine, "id = ?", machineID).
		Error
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, nil
	case err != nil:
		return nil, err
	}
	return &machine, nil
}
// GetNextMachineNameIndex returns the NameIdx to give a new machine called
// name in the given tailnet. This is one more than the highest NameIdx
// already stored for that name, or 0 when no machine in the tailnet uses the
// name yet.
//
// NOTE(review): nothing in this function serializes concurrent callers, so
// two registrations of the same name may receive the same index.
func (r *repository) GetNextMachineNameIndex(ctx context.Context, tailnetID uint64, name string) (uint64, error) {
	var highest Machine
	result := r.withContext(ctx).
		Where("name = ? AND tailnet_id = ?", name, tailnetID).
		Order("name_idx desc").
		First(&highest)

	switch err := result.Error; {
	case errors.Is(err, gorm.ErrRecordNotFound):
		// The name is unused in this tailnet; start counting at zero.
		return 0, nil
	case err != nil:
		return 0, err
	default:
		return highest.NameIdx + 1, nil
	}
}
// GetMachineByKey looks up a machine by its machine key within one tailnet,
// preloading Tailnet and User. A missing machine is reported as (nil, nil).
func (r *repository) GetMachineByKey(ctx context.Context, tailnetID uint64, machineKey string) (*Machine, error) {
	found := new(Machine)
	err := r.withContext(ctx).
		Preload("Tailnet").
		Preload("User").
		First(found, "tailnet_id = ? AND machine_key = ?", tailnetID, machineKey).
		Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return found, nil
}
// GetMachineByKeys looks up a machine by the pair of machine key and node key
// (across all tailnets), preloading Tailnet and User. A missing machine is
// reported as (nil, nil).
func (r *repository) GetMachineByKeys(ctx context.Context, machineKey string, nodeKey string) (*Machine, error) {
	found := new(Machine)
	err := r.withContext(ctx).
		Preload("Tailnet").
		Preload("User").
		First(found, "machine_key = ? AND node_key = ?", machineKey, nodeKey).
		Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return found, nil
}
// CountMachinesWithIPv4 returns how many machines have ip stored in their
// ipv4 column. On a database error the count is 0.
func (r *repository) CountMachinesWithIPv4(ctx context.Context, ip string) (int64, error) {
	var total int64
	if err := r.withContext(ctx).Model(&Machine{}).Where("ipv4 = ?", ip).Count(&total).Error; err != nil {
		return 0, err
	}
	return total, nil
}
// ListMachineByTailnet returns every machine in the tailnet ordered by name
// and then NameIdx, with Tailnet and User preloaded. When nothing matches,
// the result is an empty, non-nil slice.
func (r *repository) ListMachineByTailnet(ctx context.Context, tailnetID uint64) (Machines, error) {
	machines := []Machine{}
	query := r.withContext(ctx).
		Preload("Tailnet").
		Preload("User").
		Where("tailnet_id = ?", tailnetID).
		Order("name asc, name_idx asc")
	if err := query.Find(&machines).Error; err != nil {
		return nil, err
	}
	return machines, nil
}
// ListMachinePeers returns the machines in the tailnet whose machine key
// differs from key, ordered by ID, with Tailnet and User preloaded. When
// nothing matches, the result is an empty, non-nil slice.
func (r *repository) ListMachinePeers(ctx context.Context, tailnetID uint64, key string) (Machines, error) {
	peers := []Machine{}
	query := r.withContext(ctx).
		Preload("Tailnet").
		Preload("User").
		Where("tailnet_id = ? AND machine_key <> ?", tailnetID, key).
		Order("id asc")
	if err := query.Find(&peers).Error; err != nil {
		return nil, err
	}
	return peers, nil
}
// ListInactiveEphemeralMachines returns the ephemeral machines whose
// last_seen is strictly before t (compared in UTC). Associations are not
// preloaded. When nothing matches, the result is an empty, non-nil slice.
//
// NOTE(review): rows with a NULL last_seen never satisfy "last_seen < ?", so
// ephemeral machines that were never marked as seen are not returned here —
// confirm this is intended.
func (r *repository) ListInactiveEphemeralMachines(ctx context.Context, t time.Time) (Machines, error) {
	inactive := []Machine{}
	cutoff := t.UTC()
	query := r.withContext(ctx).Where("ephemeral = ? AND last_seen < ?", true, cutoff)
	if err := query.Find(&inactive).Error; err != nil {
		return nil, err
	}
	return inactive, nil
}
// SetMachineLastSeen sets last_seen to the current UTC time on the machine
// with the given ID. Updating a non-existent ID is not reported as an error.
func (r *repository) SetMachineLastSeen(ctx context.Context, machineID uint64) error {
	now := time.Now().UTC()
	// Model takes a pointer, as documented by GORM; a bare Machine{} value is
	// unaddressable, so GORM cannot write anything back to it. The model's ID
	// is zero, so GORM adds no primary-key condition; the explicit Where clause
	// scopes the update.
	return r.withContext(ctx).
		Model(&Machine{}).
		Where("id = ?", machineID).
		Update("last_seen", now).
		Error
}