feat: add support for ssh check periods

This commit is contained in:
Johan Siebens
2023-12-27 14:44:44 +01:00
parent d5ca503318
commit e31ce67f84
13 changed files with 123 additions and 29 deletions
@@ -0,0 +1,23 @@
package migration
import (
"github.com/go-gormigrate/gormigrate/v2"
"gorm.io/gorm"
"time"
)
func m202312271200_account_last_authenticated() *gormigrate.Migration {
return &gormigrate.Migration{
ID: "202312271200",
Migrate: func(db *gorm.DB) error {
type Account struct {
LastAuthenticated *time.Time
}
return db.AutoMigrate(
&Account{},
)
},
Rollback: nil,
}
}
@@ -16,6 +16,7 @@ func Migrations() []*gormigrate.Migration {
m202211031100_add_authorized_column(),
m202212201300_add_user_id_column(),
m202212270800_machine_indeces(),
m202312271200_account_last_authenticated(),
}
return migrations
}
+19 -3
View File
@@ -5,12 +5,14 @@ import (
"errors"
"github.com/jsiebens/ionscale/internal/util"
"gorm.io/gorm"
"time"
)
type Account struct {
ID uint64 `gorm:"primary_key"`
ExternalID string
LoginName string
ID uint64 `gorm:"primary_key"`
ExternalID string
LoginName string
LastAuthenticated *time.Time
}
func (r *repository) GetOrCreateAccount(ctx context.Context, externalID, loginName string) (*Account, bool, error) {
@@ -43,3 +45,17 @@ func (r *repository) GetAccount(ctx context.Context, id uint64) (*Account, error
return &account, nil
}
func (r *repository) SetAccountLastAuthenticated(ctx context.Context, accountID uint64) error {
now := time.Now().UTC()
tx := r.withContext(ctx).
Model(Account{}).
Where("id = ?", accountID).
Updates(map[string]interface{}{"last_authenticated": &now})
if tx.Error != nil {
return tx.Error
}
return nil
}
+5 -4
View File
@@ -41,10 +41,11 @@ type ACL struct {
}
type SSHRule struct {
Action string `json:"action"`
Src []string `json:"src"`
Dst []string `json:"dst"`
Users []string `json:"users"`
Action string `json:"action"`
Src []string `json:"src"`
Dst []string `json:"dst"`
Users []string `json:"users"`
CheckPeriod string `json:"checkPeriod,omitempty"`
}
func DefaultACLPolicy() ACLPolicy {
+7 -1
View File
@@ -39,12 +39,18 @@ func (a ACLPolicy) BuildSSHPolicy(srcs []Machine, dst *Machine) *tailcfg.SSHPoli
AllowLocalPortForwarding: true,
}
if rule.Action == "check" {
if rule.Action == "check" && rule.CheckPeriod == "" {
action = &tailcfg.SSHAction{
HoldAndDelegate: "https://unused/machine/ssh/action/$SRC_NODE_ID/to/$DST_NODE_ID",
}
}
if rule.Action == "check" && rule.CheckPeriod != "" {
action = &tailcfg.SSHAction{
HoldAndDelegate: "https://unused/machine/ssh/action/$SRC_NODE_ID/to/$DST_NODE_ID/" + rule.CheckPeriod,
}
}
selfUsers, otherUsers := a.expandSSHDstToSSHUsers(dst, rule)
if len(selfUsers) != 0 {
+9 -7
View File
@@ -357,7 +357,7 @@ func (r *repository) DeleteMachine(ctx context.Context, id uint64) (bool, error)
func (r *repository) GetMachine(ctx context.Context, machineID uint64) (*Machine, error) {
var m Machine
tx := r.withContext(ctx).Preload("Tailnet").Preload("User").Take(&m, machineID)
tx := r.withContext(ctx).Preload("Tailnet").Preload("User").Preload("User.Account").Take(&m, machineID)
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return nil, nil
@@ -458,9 +458,10 @@ func (r *repository) ListMachineByTailnet(ctx context.Context, tailnetID uint64)
tx := r.withContext(ctx).
Preload("Tailnet").
Preload("User").
Where("tailnet_id = ?", tailnetID).
Order("name asc, name_idx asc").
Joins("User").
Joins("User.Account").
Where("machines.tailnet_id = ?", tailnetID).
Order("machines.name asc, machines.name_idx asc").
Find(&machines)
if tx.Error != nil {
@@ -475,9 +476,10 @@ func (r *repository) ListMachinePeers(ctx context.Context, tailnetID uint64, key
tx := r.withContext(ctx).
Preload("Tailnet").
Preload("User").
Where("tailnet_id = ? AND machine_key <> ?", tailnetID, key).
Order("id asc").
Joins("User").
Joins("User.Account").
Where("machines.tailnet_id = ? AND machines.machine_key <> ?", tailnetID, key).
Order("machines.id asc").
Find(&machines)
if tx.Error != nil {
+1
View File
@@ -23,6 +23,7 @@ type Repository interface {
GetAccount(ctx context.Context, accountID uint64) (*Account, error)
GetOrCreateAccount(ctx context.Context, externalID, loginName string) (*Account, bool, error)
SetAccountLastAuthenticated(ctx context.Context, accountID uint64) error
SaveTailnet(ctx context.Context, tailnet *Tailnet) error
GetTailnet(ctx context.Context, id uint64) (*Tailnet, error)
+4
View File
@@ -149,6 +149,10 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
return logError(err)
}
if err := h.repository.SetAccountLastAuthenticated(ctx, account.ID); err != nil {
return logError(err)
}
if state.Flow == "s" {
sshActionReq, err := h.repository.GetSSHActionRequest(ctx, state.Key)
if err != nil || sshActionReq == nil {
+27
View File
@@ -29,6 +29,7 @@ type SSHActionHandlers struct {
type sshActionRequestData struct {
SrcMachineID uint64 `param:"src_machine_id"`
DstMachineID uint64 `param:"dst_machine_id"`
CheckPeriod string `param:"check_period"`
}
func (h *SSHActionHandlers) StartAuth(c echo.Context) error {
@@ -44,6 +45,32 @@ func (h *SSHActionHandlers) StartAuth(c echo.Context) error {
return logError(err)
}
if data.CheckPeriod != "" {
checkPeriod, err := time.ParseDuration(data.CheckPeriod)
if err != nil {
return logError(err)
}
machine, err := h.repository.GetMachine(ctx, data.SrcMachineID)
if err != nil {
return logError(err)
}
if machine.User.Account != nil && machine.User.Account.LastAuthenticated != nil {
sinceLastAuthentication := time.Since(*machine.User.Account.LastAuthenticated)
if sinceLastAuthentication < checkPeriod {
resp := &tailcfg.SSHAction{
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
}
return binder.WriteResponse(c, http.StatusOK, resp)
}
}
}
key := util.RandStringBytes(8)
request := &domain.SSHActionRequest{
Key: key,
+4 -3
View File
@@ -164,9 +164,10 @@ func ToNode(m *domain.Machine, tailnet *domain.Tailnet, taggedDevicesUser *domai
sanitizedTailnetName := domain.SanitizeTailnetName(m.Tailnet.Name)
hostInfo := tailcfg.Hostinfo{
OS: hostinfo.OS,
Hostname: hostinfo.Hostname,
Services: filterServices(hostinfo.Services),
OS: hostinfo.OS,
Hostname: hostinfo.Hostname,
Services: filterServices(hostinfo.Services),
SSH_HostKeys: hostinfo.SSH_HostKeys,
}
n := tailcfg.Node{
+1
View File
@@ -122,6 +122,7 @@ func Start(c *config.Config) error {
e.POST("/machine/set-dns", dnsHandlers.SetDNS)
e.POST("/machine/id-token", idTokenHandlers.FetchToken)
e.GET("/machine/ssh/action/:src_machine_id/to/:dst_machine_id", sshActionHandlers.StartAuth)
e.GET("/machine/ssh/action/:src_machine_id/to/:dst_machine_id/:check_period", sshActionHandlers.StartAuth)
e.GET("/machine/ssh/action/check/:key", sshActionHandlers.CheckAuth)
return e