You've already forked ionscale
mirror of
https://github.com/jsiebens/ionscale.git
synced 2026-04-05 12:32:58 +01:00
feat: improve error handling/logging a little bit
This commit is contained in:
@@ -157,21 +157,23 @@ func (g *GormLoggerAdapter) Error(ctx context.Context, s string, i ...interface{
|
||||
}
|
||||
|
||||
func (g *GormLoggerAdapter) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
elapsed := time.Since(begin)
|
||||
switch {
|
||||
case err != nil && !errors.Is(err, gorm.ErrRecordNotFound):
|
||||
sql, rows := fc()
|
||||
if rows == -1 {
|
||||
g.logger.Error("Error executing query", "sql", sql, "start_time", begin.Format(time.RFC3339), "duration", elapsed, "err", err)
|
||||
} else {
|
||||
g.logger.Error("Error executing query", "sql", sql, "start_time", begin.Format(time.RFC3339), "duration", elapsed, "rows", rows, "err", err)
|
||||
}
|
||||
case g.logger.IsTrace():
|
||||
sql, rows := fc()
|
||||
if rows == -1 {
|
||||
g.logger.Trace("Statement executed", "sql", sql, "start_time", begin.Format(time.RFC3339), "duration", elapsed)
|
||||
} else {
|
||||
g.logger.Trace("Statement executed", "sql", sql, "start_time", begin.Format(time.RFC3339), "duration", elapsed, "rows", rows)
|
||||
if g.logger.IsTrace() {
|
||||
elapsed := time.Since(begin)
|
||||
switch {
|
||||
case err != nil && !errors.Is(err, gorm.ErrRecordNotFound):
|
||||
sql, rows := fc()
|
||||
if rows == -1 {
|
||||
g.logger.Trace("Error executing query", "sql", sql, "start_time", begin.Format(time.RFC3339), "duration", elapsed, "err", err)
|
||||
} else {
|
||||
g.logger.Trace("Error executing query", "sql", sql, "start_time", begin.Format(time.RFC3339), "duration", elapsed, "rows", rows, "err", err)
|
||||
}
|
||||
default:
|
||||
sql, rows := fc()
|
||||
if rows == -1 {
|
||||
g.logger.Trace("Statement executed", "sql", sql, "start_time", begin.Format(time.RFC3339), "duration", elapsed)
|
||||
} else {
|
||||
g.logger.Trace("Statement executed", "sql", sql, "start_time", begin.Format(time.RFC3339), "duration", elapsed, "rows", rows)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
Cause error
|
||||
Location string
|
||||
}
|
||||
|
||||
func Wrap(err error, skip int) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := &Error{
|
||||
Cause: err,
|
||||
Location: getLocation(skip),
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (w *Error) Error() string {
|
||||
return w.Cause.Error()
|
||||
}
|
||||
|
||||
func (f *Error) Unwrap() error {
|
||||
return f.Cause
|
||||
}
|
||||
|
||||
func (f *Error) Format(s fmt.State, verb rune) {
|
||||
fmt.Fprintf(s, "%s\n", f.Cause.Error())
|
||||
fmt.Fprintf(s, "\t%s\n", f.Location)
|
||||
}
|
||||
|
||||
func getLocation(skip int) string {
|
||||
_, file, line, _ := runtime.Caller(2 + skip)
|
||||
return fmt.Sprintf("%s:%d", file, line)
|
||||
}
|
||||
@@ -3,8 +3,10 @@ package handlers
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/jsiebens/ionscale/internal/addr"
|
||||
"github.com/jsiebens/ionscale/internal/auth"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
"github.com/mr-tron/base58"
|
||||
"net/http"
|
||||
@@ -64,7 +66,7 @@ func (h *AuthenticationHandlers) StartAuth(c echo.Context) error {
|
||||
// machine registration auth flow
|
||||
if flow == "r" || flow == "" {
|
||||
if req, err := h.repository.GetRegistrationRequestByKey(ctx, key); err != nil || req == nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
csrf := c.Get(middleware.DefaultCSRFConfig.ContextKey).(string)
|
||||
@@ -74,24 +76,24 @@ func (h *AuthenticationHandlers) StartAuth(c echo.Context) error {
|
||||
// cli auth flow
|
||||
if flow == "c" {
|
||||
if s, err := h.repository.GetAuthenticationRequest(ctx, key); err != nil || s == nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// ssh check auth flow
|
||||
if flow == "s" {
|
||||
if s, err := h.repository.GetSSHActionRequest(ctx, key); err != nil || s == nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
}
|
||||
|
||||
if h.authProvider == nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return errors.Wrap(fmt.Errorf("unable to start auth flow as no auth provider is configured"), 0)
|
||||
}
|
||||
|
||||
state, err := h.createState(flow, key)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
redirectUrl := h.authProvider.GetLoginURL(h.config.CreateUrl("/a/callback"), state)
|
||||
@@ -108,7 +110,7 @@ func (h *AuthenticationHandlers) ProcessAuth(c echo.Context) error {
|
||||
|
||||
req, err := h.repository.GetRegistrationRequestByKey(ctx, key)
|
||||
if err != nil || req == nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if authKey != "" {
|
||||
@@ -118,7 +120,7 @@ func (h *AuthenticationHandlers) ProcessAuth(c echo.Context) error {
|
||||
if interactive != "" {
|
||||
state, err := h.createState("r", key)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
redirectUrl := h.authProvider.GetLoginURL(h.config.CreateUrl("/a/callback"), state)
|
||||
@@ -135,17 +137,17 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
|
||||
code := c.QueryParam("code")
|
||||
state, err := h.readState(c.QueryParam("state"))
|
||||
if err != nil {
|
||||
return err
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "Invalid state parameter")
|
||||
}
|
||||
|
||||
user, err := h.exchangeUser(code)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
account, _, err := h.repository.GetOrCreateAccount(ctx, user.ID, user.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if state.Flow == "s" {
|
||||
@@ -156,27 +158,27 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
|
||||
|
||||
machine, err := h.repository.GetMachine(ctx, sshActionReq.SrcMachineID)
|
||||
if err != nil || sshActionReq == nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if !machine.HasTags() && machine.User.AccountID != nil && *machine.User.AccountID == account.ID {
|
||||
sshActionReq.Action = "accept"
|
||||
if err := h.repository.SaveSSHActionRequest(ctx, sshActionReq); err != nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
return c.Redirect(http.StatusFound, "/a/success")
|
||||
}
|
||||
|
||||
sshActionReq.Action = "reject"
|
||||
if err := h.repository.SaveSSHActionRequest(ctx, sshActionReq); err != nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
return c.Redirect(http.StatusFound, "/a/error?e=nmo")
|
||||
}
|
||||
|
||||
tailnets, err := h.listAvailableTailnets(ctx, user)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
csrf := c.Get(middleware.DefaultCSRFConfig.ContextKey).(string)
|
||||
@@ -201,7 +203,7 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
|
||||
if state.Flow == "c" {
|
||||
isSystemAdmin, err := h.isSystemAdmin(ctx, user)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if !isSystemAdmin && len(tailnets) == 0 {
|
||||
@@ -220,7 +222,7 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
|
||||
})
|
||||
}
|
||||
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return echo.NewHTTPError(http.StatusNotFound)
|
||||
}
|
||||
|
||||
func (h *AuthenticationHandlers) isSystemAdmin(ctx context.Context, u *auth.User) (bool, error) {
|
||||
@@ -250,13 +252,13 @@ func (h *AuthenticationHandlers) EndOAuth(c echo.Context) error {
|
||||
|
||||
state, err := h.readState(c.QueryParam("state"))
|
||||
if err != nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "Invalid state parameter")
|
||||
}
|
||||
|
||||
if state.Flow == "r" {
|
||||
req, err := h.repository.GetRegistrationRequestByKey(ctx, state.Key)
|
||||
if err != nil || req == nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
return h.endMachineRegistrationFlow(c, req, state)
|
||||
@@ -264,7 +266,7 @@ func (h *AuthenticationHandlers) EndOAuth(c echo.Context) error {
|
||||
|
||||
req, err := h.repository.GetAuthenticationRequest(ctx, state.Key)
|
||||
if err != nil || req == nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
return h.endCliAuthenticationFlow(c, req, state)
|
||||
@@ -306,12 +308,12 @@ func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, req *d
|
||||
|
||||
var form TailnetSelectionForm
|
||||
if err := c.Bind(&form); err != nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
account, err := h.repository.GetAccount(ctx, form.AccountID)
|
||||
if err != nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
// continue as system admin?
|
||||
@@ -322,27 +324,27 @@ func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, req *d
|
||||
|
||||
err := h.repository.Transaction(func(rp domain.Repository) error {
|
||||
if err := rp.SaveSystemApiKey(ctx, apiKey); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
if err := rp.SaveAuthenticationRequest(ctx, req); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
return c.Redirect(http.StatusFound, "/a/success")
|
||||
}
|
||||
|
||||
tailnet, err := h.repository.GetTailnet(ctx, form.TailnetID)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
user, _, err := h.repository.GetOrCreateUserWithAccount(ctx, tailnet, account)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
@@ -360,7 +362,7 @@ func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, req *d
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
return c.Redirect(http.StatusFound, "/a/success")
|
||||
@@ -371,7 +373,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
|
||||
|
||||
var form TailnetSelectionForm
|
||||
if err := c.Bind(&form); err != nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
req := tailcfg.RegisterRequest(registrationRequest.Data)
|
||||
@@ -387,7 +389,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
|
||||
if form.AuthKey != "" {
|
||||
authKey, err := h.repository.LoadAuthKey(ctx, form.AuthKey)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if authKey == nil {
|
||||
@@ -396,7 +398,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
|
||||
registrationRequest.Error = "invalid auth key"
|
||||
|
||||
if err := h.repository.SaveRegistrationRequest(ctx, registrationRequest); err != nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
return c.Redirect(http.StatusFound, "/a/error?e=iak")
|
||||
@@ -410,17 +412,17 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
|
||||
} else {
|
||||
selectedTailnet, err := h.repository.GetTailnet(ctx, form.TailnetID)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
account, err := h.repository.GetAccount(ctx, form.AccountID)
|
||||
if err != nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
selectedUser, _, err := h.repository.GetOrCreateUserWithAccount(ctx, selectedTailnet, account)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
user = selectedUser
|
||||
@@ -432,7 +434,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
|
||||
registrationRequest.Authenticated = false
|
||||
registrationRequest.Error = err.Error()
|
||||
if err := h.repository.SaveRegistrationRequest(ctx, registrationRequest); err != nil {
|
||||
return c.Redirect(http.StatusFound, "/a/error")
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
return c.Redirect(http.StatusFound, "/a/error?e=nto")
|
||||
}
|
||||
@@ -443,7 +445,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
|
||||
|
||||
m, err := h.repository.GetMachineByKey(ctx, tailnet.ID, machineKey)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
@@ -456,7 +458,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
|
||||
sanitizeHostname := dnsname.SanitizeHostname(req.Hostinfo.Hostname)
|
||||
nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnet.ID, sanitizeHostname)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
m = &domain.Machine{
|
||||
@@ -480,7 +482,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
|
||||
|
||||
ipv4, ipv6, err := addr.SelectIP(checkIP(ctx, h.repository.CountMachinesWithIPv4))
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
m.IPv4 = domain.IP{Addr: ipv4}
|
||||
m.IPv6 = domain.IP{Addr: ipv6}
|
||||
@@ -493,7 +495,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
|
||||
if m.Name != sanitizeHostname {
|
||||
nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnet.ID, sanitizeHostname)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
m.Name = sanitizeHostname
|
||||
m.NameIdx = nameIdx
|
||||
@@ -526,7 +528,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if m.Authorized {
|
||||
|
||||
@@ -3,6 +3,7 @@ package handlers
|
||||
import (
|
||||
"github.com/jsiebens/ionscale/internal/bind"
|
||||
"github.com/jsiebens/ionscale/internal/dns"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
"github.com/labstack/echo/v4"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -28,12 +29,12 @@ func (h *DNSHandlers) SetDNS(c echo.Context) error {
|
||||
|
||||
binder, err := h.createBinder(c)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
req := &tailcfg.SetDNSRequest{}
|
||||
if err := binder.BindRequest(c, req); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if h.provider == nil {
|
||||
@@ -41,7 +42,7 @@ func (h *DNSHandlers) SetDNS(c echo.Context) error {
|
||||
}
|
||||
|
||||
if err := h.provider.SetRecord(ctx, req.Type, req.Name, req.Value); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(req.Name, "_acme-challenge") && req.Type == "TXT" {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/jsiebens/ionscale/internal/bind"
|
||||
"github.com/jsiebens/ionscale/internal/config"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
"github.com/jsiebens/ionscale/internal/util"
|
||||
"github.com/labstack/echo/v4"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
@@ -55,7 +56,7 @@ func (h *IDTokenHandlers) OpenIDConfig(c echo.Context) error {
|
||||
func (h *IDTokenHandlers) Jwks(c echo.Context) error {
|
||||
keySet, err := h.repository.GetJSONWebKeySet(c.Request().Context())
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
pub := jose.JSONWebKey{Key: keySet.Key.Public(), KeyID: keySet.Key.Id, Algorithm: "RS256", Use: "sig"}
|
||||
@@ -68,17 +69,17 @@ func (h *IDTokenHandlers) FetchToken(c echo.Context) error {
|
||||
|
||||
keySet, err := h.repository.GetJSONWebKeySet(c.Request().Context())
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
binder, err := h.createBinder(c)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
req := &tailcfg.TokenRequest{}
|
||||
if err := binder.BindRequest(c, req); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
machineKey := binder.Peer().String()
|
||||
@@ -87,7 +88,7 @@ func (h *IDTokenHandlers) FetchToken(c echo.Context) error {
|
||||
var m *domain.Machine
|
||||
m, err = h.repository.GetMachineByKeys(ctx, machineKey, nodeKey)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
@@ -130,7 +131,7 @@ func (h *IDTokenHandlers) FetchToken(c echo.Context) error {
|
||||
|
||||
jwtB64, err := unsignedToken.SignedString(&keySet.Key.PrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
resp := tailcfg.TokenResponse{IDToken: jwtB64}
|
||||
|
||||
@@ -22,7 +22,7 @@ func KeyHandler(keys *config.ServerKeys) echo.HandlerFunc {
|
||||
if v != "" {
|
||||
clientCapabilityVersion, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return c.String(http.StatusBadRequest, "Invalid version")
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "Invalid version")
|
||||
}
|
||||
|
||||
if clientCapabilityVersion >= NoiseCapabilityVersion {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
"github.com/labstack/echo/v4"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
@@ -27,7 +28,7 @@ func NewNoiseHandlers(controlKey key.MachinePrivate, createPeerHandler CreatePee
|
||||
func (h *NoiseHandlers) Upgrade(c echo.Context) error {
|
||||
conn, err := controlhttp.AcceptHTTP(c.Request().Context(), c.Response(), c.Request(), h.controlKey)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
handler := h.createPeerHandler(conn.Peer())
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/jsiebens/ionscale/internal/broker"
|
||||
"github.com/jsiebens/ionscale/internal/config"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
"github.com/jsiebens/ionscale/internal/mapping"
|
||||
"github.com/labstack/echo/v4"
|
||||
"net/http"
|
||||
@@ -41,12 +42,12 @@ func (h *PollNetMapHandler) PollNetMap(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
binder, err := h.createBinder(c)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
req := &tailcfg.MapRequest{}
|
||||
if err := binder.BindRequest(c, req); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
machineKey := binder.Peer().String()
|
||||
@@ -55,7 +56,7 @@ func (h *PollNetMapHandler) PollNetMap(c echo.Context) error {
|
||||
var m *domain.Machine
|
||||
m, err = h.repository.GetMachineByKeys(ctx, machineKey, nodeKey)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
@@ -80,7 +81,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
|
||||
m.LastSeen = &now
|
||||
|
||||
if err := h.repository.SaveMachine(ctx, m); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
tailnetID := m.TailnetID
|
||||
@@ -97,14 +98,14 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
|
||||
|
||||
response, syncedPeers, derpMapChecksum, err := h.createMapResponse(m, binder, mapRequest, false, make(map[uint64]bool), derpMapChecksum)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
updateChan := make(chan *broker.Signal, 20)
|
||||
|
||||
unsubscribe, err := h.brokers.Subscribe(tailnetID, updateChan)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
h.cancelOfflineMessage(machineID)
|
||||
|
||||
@@ -113,7 +114,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
|
||||
|
||||
keepAliveResponse, err := h.createKeepAliveResponse(binder, mapRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
keepAliveTicker := time.NewTicker(config.KeepAliveInterval())
|
||||
syncTicker := time.NewTicker(5 * time.Second)
|
||||
@@ -121,7 +122,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
|
||||
c.Response().WriteHeader(http.StatusOK)
|
||||
|
||||
if _, err := c.Response().Write(response); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
c.Response().Flush()
|
||||
|
||||
@@ -146,7 +147,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
|
||||
case <-keepAliveTicker.C:
|
||||
if mapRequest.KeepAlive {
|
||||
if _, err := c.Response().Write(keepAliveResponse); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
_ = h.repository.SetMachineLastSeen(ctx, machineID)
|
||||
c.Response().Flush()
|
||||
@@ -155,7 +156,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
|
||||
if latestSync.Before(latestUpdate) {
|
||||
machine, err := h.repository.GetMachine(ctx, machineID)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
if machine == nil {
|
||||
return nil
|
||||
@@ -171,7 +172,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
|
||||
}
|
||||
|
||||
if _, err := c.Response().Write(payload); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
c.Response().Flush()
|
||||
|
||||
@@ -190,16 +191,16 @@ func (h *PollNetMapHandler) handleReadOnly(c echo.Context, binder bind.Binder, m
|
||||
m.DiscoKey = request.DiscoKey.String()
|
||||
|
||||
if err := h.repository.SaveMachine(ctx, m); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
response, _, _, err := h.createMapResponse(m, binder, request, false, map[uint64]bool{}, "")
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
_, err = c.Response().Write(response)
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
func (h *PollNetMapHandler) scheduleOfflineMessage(tailnetID, machineID uint64) {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/jsiebens/ionscale/internal/broker"
|
||||
"github.com/jsiebens/ionscale/internal/config"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
"github.com/jsiebens/ionscale/internal/util"
|
||||
"github.com/labstack/echo/v4"
|
||||
"net/http"
|
||||
@@ -41,12 +42,12 @@ func (h *RegistrationHandlers) Register(c echo.Context) error {
|
||||
|
||||
binder, err := h.createBinder(c)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
req := &tailcfg.RegisterRequest{}
|
||||
if err := binder.BindRequest(c, req); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
machineKey := binder.Peer().String()
|
||||
@@ -56,7 +57,7 @@ func (h *RegistrationHandlers) Register(c echo.Context) error {
|
||||
m, err = h.repository.GetMachineByKeys(ctx, machineKey, nodeKey)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if m != nil {
|
||||
@@ -70,12 +71,12 @@ func (h *RegistrationHandlers) Register(c echo.Context) error {
|
||||
|
||||
if m.Ephemeral {
|
||||
if _, err := h.repository.DeleteMachine(ctx, m.ID); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
h.pubsub.Publish(m.TailnetID, &broker.Signal{PeersRemoved: []uint64{m.ID}})
|
||||
} else {
|
||||
if err := h.repository.SaveMachine(ctx, m); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
h.pubsub.Publish(m.TailnetID, &broker.Signal{PeerUpdated: &m.ID})
|
||||
}
|
||||
@@ -88,7 +89,7 @@ func (h *RegistrationHandlers) Register(c echo.Context) error {
|
||||
if m.Name != sanitizeHostname {
|
||||
nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, m.TailnetID, sanitizeHostname)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
m.Name = sanitizeHostname
|
||||
m.NameIdx = nameIdx
|
||||
@@ -99,7 +100,7 @@ func (h *RegistrationHandlers) Register(c echo.Context) error {
|
||||
m.Tags = append(m.RegisteredTags, advertisedTags...)
|
||||
|
||||
if err := h.repository.SaveMachine(ctx, m); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
response := tailcfg.RegisterResponse{MachineAuthorized: true}
|
||||
@@ -146,7 +147,7 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi
|
||||
|
||||
authKey, err := h.repository.LoadAuthKey(ctx, req.Auth.AuthKey)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if authKey == nil {
|
||||
@@ -172,7 +173,7 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi
|
||||
|
||||
m, err = h.repository.GetMachineByKey(ctx, tailnet.ID, machineKey)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
@@ -181,7 +182,7 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi
|
||||
sanitizeHostname := dnsname.SanitizeHostname(req.Hostinfo.Hostname)
|
||||
nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnet.ID, sanitizeHostname)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
m = &domain.Machine{
|
||||
@@ -209,7 +210,7 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi
|
||||
|
||||
ipv4, ipv6, err := addr.SelectIP(checkIP(ctx, h.repository.CountMachinesWithIPv4))
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
m.IPv4 = domain.IP{Addr: ipv4}
|
||||
m.IPv6 = domain.IP{Addr: ipv6}
|
||||
@@ -218,7 +219,7 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi
|
||||
if m.Name != sanitizeHostname {
|
||||
nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnet.ID, sanitizeHostname)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
m.Name = sanitizeHostname
|
||||
m.NameIdx = nameIdx
|
||||
@@ -236,7 +237,7 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi
|
||||
}
|
||||
|
||||
if err := h.repository.SaveMachine(ctx, m); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
response := tailcfg.RegisterResponse{MachineAuthorized: true}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"github.com/jsiebens/ionscale/internal/bind"
|
||||
"github.com/jsiebens/ionscale/internal/config"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
"github.com/jsiebens/ionscale/internal/util"
|
||||
"github.com/labstack/echo/v4"
|
||||
"net/http"
|
||||
@@ -36,12 +37,12 @@ func (h *SSHActionHandlers) StartAuth(c echo.Context) error {
|
||||
|
||||
binder, err := h.createBinder(c)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
data := new(sshActionRequestData)
|
||||
if err = c.Bind(data); err != nil {
|
||||
return c.String(http.StatusBadRequest, "bad request")
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
key := util.RandStringBytes(8)
|
||||
@@ -55,7 +56,7 @@ func (h *SSHActionHandlers) StartAuth(c echo.Context) error {
|
||||
authUrl := h.config.CreateUrl("/a/s/%s", key)
|
||||
|
||||
if err := h.repository.SaveSSHActionRequest(ctx, request); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
resp := &tailcfg.SSHAction{
|
||||
@@ -73,7 +74,7 @@ func (h *SSHActionHandlers) CheckAuth(c echo.Context) error {
|
||||
|
||||
binder, err := h.createBinder(c)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
tick := time.NewTicker(2 * time.Second)
|
||||
|
||||
@@ -3,11 +3,48 @@ package server
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
"github.com/labstack/echo-contrib/prometheus"
|
||||
"github.com/labstack/echo/v4"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func EchoErrorHandler(logger hclog.Logger) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
request := c.Request()
|
||||
|
||||
if err := next(c); err != nil {
|
||||
switch t := err.(type) {
|
||||
case *echo.HTTPError:
|
||||
return err
|
||||
case *errors.Error:
|
||||
logger.Error("error processing request",
|
||||
"err", t.Cause,
|
||||
"location", t.Location,
|
||||
"http.method", request.Method,
|
||||
"http.uri", request.RequestURI,
|
||||
)
|
||||
default:
|
||||
logger.Error("error processing request",
|
||||
"err", err,
|
||||
"http.method", request.Method,
|
||||
"http.uri", request.RequestURI,
|
||||
)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(request.RequestURI, "/a/") {
|
||||
return c.Render(http.StatusInternalServerError, "error.html", nil)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func EchoLogger(logger hclog.Logger) echo.MiddlewareFunc {
|
||||
httpLogger := logger.Named("http")
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
@@ -55,6 +92,15 @@ func EchoRecover(logger hclog.Logger) echo.MiddlewareFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func ErrorRedirect() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Set("redirect_on_error", true)
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func EchoMetrics(p *prometheus.Prometheus) echo.MiddlewareFunc {
|
||||
return p.HandlerFunc
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"github.com/bufbuild/connect-go"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/key"
|
||||
"github.com/jsiebens/ionscale/internal/service"
|
||||
@@ -9,7 +10,7 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func NewRpcHandler(systemAdminKey *key.ServerPrivate, repository domain.Repository, handler apiconnect.IonscaleServiceHandler) (string, http.Handler) {
|
||||
interceptors := connect.WithInterceptors(service.AuthenticationInterceptor(systemAdminKey, repository))
|
||||
func NewRpcHandler(systemAdminKey *key.ServerPrivate, repository domain.Repository, logger hclog.Logger, handler apiconnect.IonscaleServiceHandler) (string, http.Handler) {
|
||||
interceptors := connect.WithInterceptors(service.NewErrorInterceptor(logger), service.AuthenticationInterceptor(systemAdminKey, repository))
|
||||
return apiconnect.NewIonscaleServiceHandler(handler, interceptors)
|
||||
}
|
||||
|
||||
+17
-16
@@ -92,16 +92,22 @@ func Start(c *config.Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
createPeerHandler := func(p key.MachinePublic) http.Handler {
|
||||
registrationHandlers := handlers.NewRegistrationHandlers(bind.DefaultBinder(p), c, brokers, repository)
|
||||
pollNetMapHandler := handlers.NewPollNetMapHandler(bind.DefaultBinder(p), brokers, repository, offlineTimers)
|
||||
dnsHandlers := handlers.NewDNSHandlers(bind.DefaultBinder(p), dnsProvider)
|
||||
idTokenHandlers := handlers.NewIDTokenHandlers(bind.DefaultBinder(p), c, repository)
|
||||
sshActionHandlers := handlers.NewSSHActionHandlers(bind.DefaultBinder(p), c, repository)
|
||||
p := echo_prometheus.NewPrometheus("http", nil)
|
||||
|
||||
metricsHandler := echo.New()
|
||||
p.SetMetricsPath(metricsHandler)
|
||||
|
||||
createPeerHandler := func(machinePublicKey key.MachinePublic) http.Handler {
|
||||
binder := bind.DefaultBinder(machinePublicKey)
|
||||
|
||||
registrationHandlers := handlers.NewRegistrationHandlers(binder, c, brokers, repository)
|
||||
pollNetMapHandler := handlers.NewPollNetMapHandler(binder, brokers, repository, offlineTimers)
|
||||
dnsHandlers := handlers.NewDNSHandlers(binder, dnsProvider)
|
||||
idTokenHandlers := handlers.NewIDTokenHandlers(binder, c, repository)
|
||||
sshActionHandlers := handlers.NewSSHActionHandlers(binder, c, repository)
|
||||
|
||||
e := echo.New()
|
||||
e.Use(EchoLogger(logger))
|
||||
e.Use(EchoRecover(logger))
|
||||
e.Use(EchoMetrics(p), EchoLogger(logger), EchoErrorHandler(logger), EchoRecover(logger))
|
||||
e.POST("/machine/register", registrationHandlers.Register)
|
||||
e.POST("/machine/map", pollNetMapHandler.PollNetMap)
|
||||
e.POST("/machine/set-dns", dnsHandlers.SetDNS)
|
||||
@@ -125,22 +131,17 @@ func Start(c *config.Config) error {
|
||||
)
|
||||
|
||||
rpcService := service.NewService(c, authProvider, repository, brokers)
|
||||
rpcPath, rpcHandler := NewRpcHandler(serverKey.SystemAdminKey, repository, rpcService)
|
||||
|
||||
p := echo_prometheus.NewPrometheus("http", nil)
|
||||
|
||||
metricsHandler := echo.New()
|
||||
p.SetMetricsPath(metricsHandler)
|
||||
rpcPath, rpcHandler := NewRpcHandler(serverKey.SystemAdminKey, repository, logger, rpcService)
|
||||
|
||||
nonTlsAppHandler := echo.New()
|
||||
nonTlsAppHandler.Use(EchoMetrics(p), EchoLogger(logger), EchoRecover(logger))
|
||||
nonTlsAppHandler.Use(EchoMetrics(p), EchoLogger(logger), EchoErrorHandler(logger), EchoRecover(logger))
|
||||
nonTlsAppHandler.POST("/ts2021", noiseHandlers.Upgrade)
|
||||
nonTlsAppHandler.Any("/*", handlers.HttpRedirectHandler(c.Tls))
|
||||
|
||||
tlsAppHandler := echo.New()
|
||||
tlsAppHandler.Renderer = templates.NewTemplates()
|
||||
tlsAppHandler.Pre(handlers.HttpsRedirect(c.Tls))
|
||||
tlsAppHandler.Use(EchoMetrics(p), EchoLogger(logger), EchoRecover(logger))
|
||||
tlsAppHandler.Use(EchoMetrics(p), EchoLogger(logger), EchoErrorHandler(logger), EchoRecover(logger))
|
||||
|
||||
tlsAppHandler.Any("/*", handlers.IndexHandler(http.StatusNotFound))
|
||||
tlsAppHandler.Any("/", handlers.IndexHandler(http.StatusOK))
|
||||
|
||||
@@ -2,11 +2,11 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bufbuild/connect-go"
|
||||
"github.com/jsiebens/ionscale/internal/broker"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
"github.com/jsiebens/ionscale/internal/mapping"
|
||||
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
|
||||
)
|
||||
@@ -14,12 +14,12 @@ import (
|
||||
func (s *Service) GetACLPolicy(ctx context.Context, req *connect.Request[api.GetACLPolicyRequest]) (*connect.Response[api.GetACLPolicyResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet does not exist"))
|
||||
@@ -27,7 +27,7 @@ func (s *Service) GetACLPolicy(ctx context.Context, req *connect.Request[api.Get
|
||||
|
||||
var policy api.ACLPolicy
|
||||
if err := mapping.CopyViaJson(&tailnet.ACLPolicy, &policy); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&api.GetACLPolicyResponse{Policy: &policy}), nil
|
||||
@@ -36,12 +36,12 @@ func (s *Service) GetACLPolicy(ctx context.Context, req *connect.Request[api.Get
|
||||
func (s *Service) SetACLPolicy(ctx context.Context, req *connect.Request[api.SetACLPolicyRequest]) (*connect.Response[api.SetACLPolicyResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet does not exist"))
|
||||
@@ -49,12 +49,12 @@ func (s *Service) SetACLPolicy(ctx context.Context, req *connect.Request[api.Set
|
||||
|
||||
var policy domain.ACLPolicy
|
||||
if err := mapping.CopyViaJson(req.Msg.Policy, &policy); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
tailnet.ACLPolicy = policy
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(tailnet.ID, &broker.Signal{ACLUpdated: true})
|
||||
|
||||
@@ -2,9 +2,10 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bufbuild/connect-go"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
"github.com/jsiebens/ionscale/internal/util"
|
||||
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
|
||||
"time"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
|
||||
func (s *Service) Authenticate(ctx context.Context, req *connect.Request[api.AuthenticateRequest], stream *connect.ServerStream[api.AuthenticateResponse]) error {
|
||||
if s.authProvider == nil {
|
||||
return connect.NewError(connect.CodeFailedPrecondition, errors.New("no authentication method available, contact your ionscale administrator for more information"))
|
||||
return connect.NewError(connect.CodeFailedPrecondition, fmt.Errorf("no authentication method available, contact your ionscale administrator for more information"))
|
||||
}
|
||||
|
||||
key := util.RandStringBytes(8)
|
||||
@@ -24,11 +25,11 @@ func (s *Service) Authenticate(ctx context.Context, req *connect.Request[api.Aut
|
||||
}
|
||||
|
||||
if err := s.repository.SaveAuthenticationRequest(ctx, session); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if err := stream.Send(&api.AuthenticateResponse{AuthUrl: authUrl}); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
notify := ctx.Done()
|
||||
@@ -43,24 +44,27 @@ func (s *Service) Authenticate(ctx context.Context, req *connect.Request[api.Aut
|
||||
select {
|
||||
case <-tick.C:
|
||||
m, err := s.repository.GetAuthenticationRequest(ctx, key)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if err != nil || m == nil {
|
||||
return connect.NewError(connect.CodeInternal, errors.New("something went wrong"))
|
||||
if m == nil {
|
||||
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("invalid authentication request"))
|
||||
}
|
||||
|
||||
if len(m.Token) != 0 {
|
||||
if err := stream.Send(&api.AuthenticateResponse{Token: m.Token, TailnetId: m.TailnetID}); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(m.Error) != 0 {
|
||||
return connect.NewError(connect.CodePermissionDenied, errors.New(m.Error))
|
||||
return connect.NewError(connect.CodePermissionDenied, fmt.Errorf(m.Error))
|
||||
}
|
||||
|
||||
if err := stream.Send(&api.AuthenticateResponse{AuthUrl: authUrl}); err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
case <-notify:
|
||||
|
||||
@@ -2,9 +2,10 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bufbuild/connect-go"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"time"
|
||||
@@ -15,15 +16,15 @@ func (s *Service) GetAuthKey(ctx context.Context, req *connect.Request[api.GetAu
|
||||
|
||||
key, err := s.repository.GetAuthKey(ctx, req.Msg.AuthKeyId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if key == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("auth key not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("auth key not found"))
|
||||
}
|
||||
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(key.TailnetID) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
var expiresAt *timestamppb.Timestamp
|
||||
@@ -74,16 +75,16 @@ func mapAuthKeysToApi(authKeys []domain.AuthKey) []*api.AuthKey {
|
||||
func (s *Service) ListAuthKeys(ctx context.Context, req *connect.Request[api.ListAuthKeysRequest]) (*connect.Response[api.ListAuthKeysResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
response := api.ListAuthKeysResponse{}
|
||||
@@ -91,7 +92,7 @@ func (s *Service) ListAuthKeys(ctx context.Context, req *connect.Request[api.Lis
|
||||
if principal.IsSystemAdmin() {
|
||||
authKeys, err := s.repository.ListAuthKeys(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
response.AuthKeys = mapAuthKeysToApi(authKeys)
|
||||
@@ -101,7 +102,7 @@ func (s *Service) ListAuthKeys(ctx context.Context, req *connect.Request[api.Lis
|
||||
if principal.User != nil {
|
||||
authKeys, err := s.repository.ListAuthKeysByTailnetAndUser(ctx, req.Msg.TailnetId, principal.User.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
response.AuthKeys = mapAuthKeysToApi(authKeys)
|
||||
@@ -114,11 +115,11 @@ func (s *Service) ListAuthKeys(ctx context.Context, req *connect.Request[api.Lis
|
||||
func (s *Service) CreateAuthKey(ctx context.Context, req *connect.Request[api.CreateAuthKeyRequest]) (*connect.Response[api.CreateAuthKeyResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
if principal.User == nil && len(req.Msg.Tags) == 0 {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("at least one tag is required when creating an auth key"))
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("at least one tag is required when creating an auth key"))
|
||||
}
|
||||
|
||||
if err := domain.CheckTags(req.Msg.Tags); err != nil {
|
||||
@@ -127,11 +128,11 @@ func (s *Service) CreateAuthKey(ctx context.Context, req *connect.Request[api.Cr
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
if !principal.IsSystemAdmin() {
|
||||
@@ -154,7 +155,7 @@ func (s *Service) CreateAuthKey(ctx context.Context, req *connect.Request[api.Cr
|
||||
if user == nil {
|
||||
u, _, err := s.repository.GetOrCreateServiceUser(ctx, tailnet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
user = u
|
||||
}
|
||||
@@ -164,7 +165,7 @@ func (s *Service) CreateAuthKey(ctx context.Context, req *connect.Request[api.Cr
|
||||
v, authKey := domain.CreateAuthKey(tailnet, user, req.Msg.Ephemeral, req.Msg.PreAuthorized, tags, expiresAt)
|
||||
|
||||
if err := s.repository.SaveAuthKey(ctx, authKey); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
response := api.CreateAuthKeyResponse{
|
||||
@@ -190,19 +191,19 @@ func (s *Service) DeleteAuthKey(ctx context.Context, req *connect.Request[api.De
|
||||
|
||||
key, err := s.repository.GetAuthKey(ctx, req.Msg.AuthKeyId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if key == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("auth key not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("auth key not found"))
|
||||
}
|
||||
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(key.UserID) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
if _, err := s.repository.DeleteAuthKey(ctx, req.Msg.AuthKeyId); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
return connect.NewResponse(&api.DeleteAuthKeyResponse{}), nil
|
||||
}
|
||||
|
||||
@@ -3,10 +3,11 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bufbuild/connect-go"
|
||||
"github.com/jsiebens/ionscale/internal/broker"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
"github.com/jsiebens/ionscale/internal/util"
|
||||
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -15,17 +16,17 @@ import (
|
||||
func (s *Service) GetDefaultDERPMap(ctx context.Context, _ *connect.Request[api.GetDefaultDERPMapRequest]) (*connect.Response[api.GetDefaultDERPMapResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
dm, err := s.repository.GetDERPMap(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(dm.DERPMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&api.GetDefaultDERPMapResponse{Value: raw}), nil
|
||||
@@ -34,12 +35,12 @@ func (s *Service) GetDefaultDERPMap(ctx context.Context, _ *connect.Request[api.
|
||||
func (s *Service) SetDefaultDERPMap(ctx context.Context, req *connect.Request[api.SetDefaultDERPMapRequest]) (*connect.Response[api.SetDefaultDERPMapResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
var derpMap tailcfg.DERPMap
|
||||
if err := json.Unmarshal(req.Msg.Value, &derpMap); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
dp := domain.DERPMap{
|
||||
@@ -48,12 +49,12 @@ func (s *Service) SetDefaultDERPMap(ctx context.Context, req *connect.Request[ap
|
||||
}
|
||||
|
||||
if err := s.repository.SetDERPMap(ctx, &dp); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
tailnets, err := s.repository.ListTailnets(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
for _, t := range tailnets {
|
||||
@@ -66,18 +67,18 @@ func (s *Service) SetDefaultDERPMap(ctx context.Context, req *connect.Request[ap
|
||||
func (s *Service) ResetDefaultDERPMap(ctx context.Context, req *connect.Request[api.ResetDefaultDERPMapRequest]) (*connect.Response[api.ResetDefaultDERPMapResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
dp := domain.DERPMap{}
|
||||
|
||||
if err := s.repository.SetDERPMap(ctx, &dp); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
tailnets, err := s.repository.ListTailnets(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
for _, t := range tailnets {
|
||||
|
||||
+17
-17
@@ -2,27 +2,27 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bufbuild/connect-go"
|
||||
"github.com/jsiebens/ionscale/internal/broker"
|
||||
"github.com/jsiebens/ionscale/internal/config"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
|
||||
)
|
||||
|
||||
func (s *Service) GetDNSConfig(ctx context.Context, req *connect.Request[api.GetDNSConfigRequest]) (*connect.Response[api.GetDNSConfigResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
dnsConfig := tailnet.DNSConfig
|
||||
@@ -44,17 +44,17 @@ func (s *Service) GetDNSConfig(ctx context.Context, req *connect.Request[api.Get
|
||||
func (s *Service) SetDNSConfig(ctx context.Context, req *connect.Request[api.SetDNSConfigRequest]) (*connect.Response[api.SetDNSConfigResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
dnsConfig := req.Msg.Config
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
tailnet.DNSConfig = domain.DNSConfig{
|
||||
@@ -65,7 +65,7 @@ func (s *Service) SetDNSConfig(ctx context.Context, req *connect.Request[api.Set
|
||||
}
|
||||
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(tailnet.ID, &broker.Signal{DNSUpdated: true})
|
||||
@@ -78,24 +78,24 @@ func (s *Service) SetDNSConfig(ctx context.Context, req *connect.Request[api.Set
|
||||
func (s *Service) EnableHttpsCertificates(ctx context.Context, req *connect.Request[api.EnableHttpsCertificatesRequest]) (*connect.Response[api.EnableHttpsCertificatesResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
if !tailnet.DNSConfig.MagicDNS {
|
||||
return nil, connect.NewError(connect.CodeFailedPrecondition, errors.New("MagicDNS must be enabled for this tailnet"))
|
||||
return nil, connect.NewError(connect.CodeFailedPrecondition, fmt.Errorf("MagicDNS must be enabled for this tailnet"))
|
||||
}
|
||||
|
||||
tailnet.DNSConfig.HttpsCertsEnabled = true
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(tailnet.ID, &broker.Signal{DNSUpdated: true})
|
||||
@@ -106,21 +106,21 @@ func (s *Service) EnableHttpsCertificates(ctx context.Context, req *connect.Requ
|
||||
func (s *Service) DisableHttpsCertificates(ctx context.Context, req *connect.Request[api.DisableHttpsCertificatesRequest]) (*connect.Response[api.DisableHttpsCertificatesResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
tailnet.DNSConfig.HttpsCertsEnabled = false
|
||||
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(tailnet.ID, &broker.Signal{DNSUpdated: true})
|
||||
|
||||
@@ -2,22 +2,22 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bufbuild/connect-go"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
|
||||
)
|
||||
|
||||
func (s *Service) GetIAMPolicy(ctx context.Context, req *connect.Request[api.GetIAMPolicyRequest]) (*connect.Response[api.GetIAMPolicyResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet does not exist"))
|
||||
@@ -36,12 +36,12 @@ func (s *Service) GetIAMPolicy(ctx context.Context, req *connect.Request[api.Get
|
||||
func (s *Service) SetIAMPolicy(ctx context.Context, req *connect.Request[api.SetIAMPolicyRequest]) (*connect.Response[api.SetIAMPolicyResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet does not exist"))
|
||||
@@ -55,7 +55,7 @@ func (s *Service) SetIAMPolicy(ctx context.Context, req *connect.Request[api.Set
|
||||
}
|
||||
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&api.SetIAMPolicyResponse{}), nil
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/bufbuild/connect-go"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
"github.com/jsiebens/ionscale/internal/key"
|
||||
"github.com/jsiebens/ionscale/internal/token"
|
||||
"strings"
|
||||
@@ -75,3 +77,55 @@ func exchangeToken(ctx context.Context, systemAdminKey *key.ServerPrivate, repos
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewErrorInterceptor(logger hclog.Logger) *ErrorInterceptor {
|
||||
return &ErrorInterceptor{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
type ErrorInterceptor struct {
|
||||
logger hclog.Logger
|
||||
}
|
||||
|
||||
func (e *ErrorInterceptor) handleError(err error) error {
|
||||
if err == nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch t := err.(type) {
|
||||
case *connect.Error:
|
||||
return err
|
||||
case *errors.Error:
|
||||
e.logger.Error("error processing grpc request",
|
||||
"err", t.Cause,
|
||||
"location", t.Location,
|
||||
)
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("internal server error"))
|
||||
default:
|
||||
e.logger.Error("error processing grpc request",
|
||||
"err", err,
|
||||
)
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("internal server error"))
|
||||
}
|
||||
|
||||
}
|
||||
func (e *ErrorInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
response, err := next(ctx, request)
|
||||
return response, e.handleError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ErrorInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
|
||||
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
|
||||
return next(ctx, spec)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ErrorInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
|
||||
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
|
||||
err := next(ctx, conn)
|
||||
return e.handleError(err)
|
||||
}
|
||||
}
|
||||
|
||||
+47
-47
@@ -2,12 +2,12 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bufbuild/connect-go"
|
||||
"github.com/jsiebens/ionscale/internal/broker"
|
||||
"github.com/jsiebens/ionscale/internal/config"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"net/netip"
|
||||
@@ -64,20 +64,20 @@ func (s *Service) machineToApi(m *domain.Machine) *api.Machine {
|
||||
func (s *Service) ListMachines(ctx context.Context, req *connect.Request[api.ListMachinesRequest]) (*connect.Response[api.ListMachinesResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
machines, err := s.repository.ListMachineByTailnet(ctx, tailnet.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
response := &api.ListMachinesResponse{}
|
||||
@@ -93,15 +93,15 @@ func (s *Service) GetMachine(ctx context.Context, req *connect.Request[api.GetMa
|
||||
|
||||
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
|
||||
}
|
||||
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&api.GetMachineResponse{Machine: s.machineToApi(m)}), nil
|
||||
@@ -112,19 +112,19 @@ func (s *Service) DeleteMachine(ctx context.Context, req *connect.Request[api.De
|
||||
|
||||
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
|
||||
}
|
||||
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
if _, err := s.repository.DeleteMachine(ctx, req.Msg.MachineId); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(m.TailnetID, &broker.Signal{PeersRemoved: []uint64{m.ID}})
|
||||
@@ -137,15 +137,15 @@ func (s *Service) ExpireMachine(ctx context.Context, req *connect.Request[api.Ex
|
||||
|
||||
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
|
||||
}
|
||||
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
timestamp := time.Unix(123, 0)
|
||||
@@ -153,7 +153,7 @@ func (s *Service) ExpireMachine(ctx context.Context, req *connect.Request[api.Ex
|
||||
m.KeyExpiryDisabled = false
|
||||
|
||||
if err := s.repository.SaveMachine(ctx, m); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(m.TailnetID, &broker.Signal{PeerUpdated: &m.ID})
|
||||
@@ -166,21 +166,21 @@ func (s *Service) AuthorizeMachine(ctx context.Context, req *connect.Request[api
|
||||
|
||||
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
|
||||
}
|
||||
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
if !m.Authorized {
|
||||
m.Authorized = true
|
||||
if err := s.repository.SaveMachine(ctx, m); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,15 +194,15 @@ func (s *Service) GetMachineRoutes(ctx context.Context, req *connect.Request[api
|
||||
|
||||
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
|
||||
}
|
||||
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
response := api.GetMachineRoutesResponse{
|
||||
@@ -223,15 +223,15 @@ func (s *Service) EnableMachineRoutes(ctx context.Context, req *connect.Request[
|
||||
|
||||
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
|
||||
}
|
||||
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
var allowIPs = domain.NewAllowIPsSet(m.AllowIPs)
|
||||
@@ -245,7 +245,7 @@ func (s *Service) EnableMachineRoutes(ctx context.Context, req *connect.Request[
|
||||
for _, r := range req.Msg.Routes {
|
||||
prefix, err := netip.ParsePrefix(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
allowIPs.Add(prefix)
|
||||
}
|
||||
@@ -253,7 +253,7 @@ func (s *Service) EnableMachineRoutes(ctx context.Context, req *connect.Request[
|
||||
m.AllowIPs = allowIPs.Items()
|
||||
m.AutoAllowIPs = autoAllowIPs.Items()
|
||||
if err := s.repository.SaveMachine(ctx, m); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(m.TailnetID, &broker.Signal{PeerUpdated: &m.ID})
|
||||
@@ -276,15 +276,15 @@ func (s *Service) DisableMachineRoutes(ctx context.Context, req *connect.Request
|
||||
|
||||
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
|
||||
}
|
||||
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
allowIPs := domain.NewAllowIPsSet(m.AllowIPs)
|
||||
@@ -293,7 +293,7 @@ func (s *Service) DisableMachineRoutes(ctx context.Context, req *connect.Request
|
||||
for _, r := range req.Msg.Routes {
|
||||
prefix, err := netip.ParsePrefix(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
allowIPs.Remove(prefix)
|
||||
autoAllowIPs.Remove(prefix)
|
||||
@@ -302,7 +302,7 @@ func (s *Service) DisableMachineRoutes(ctx context.Context, req *connect.Request
|
||||
m.AllowIPs = allowIPs.Items()
|
||||
m.AutoAllowIPs = autoAllowIPs.Items()
|
||||
if err := s.repository.SaveMachine(ctx, m); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(m.TailnetID, &broker.Signal{PeerUpdated: &m.ID})
|
||||
@@ -325,19 +325,19 @@ func (s *Service) EnableExitNode(ctx context.Context, req *connect.Request[api.E
|
||||
|
||||
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
|
||||
}
|
||||
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
if !m.IsAdvertisedExitNode() {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("machine is not a valid exit node"))
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("machine is not a valid exit node"))
|
||||
}
|
||||
|
||||
prefix4 := netip.MustParsePrefix("0.0.0.0/0")
|
||||
@@ -349,7 +349,7 @@ func (s *Service) EnableExitNode(ctx context.Context, req *connect.Request[api.E
|
||||
m.AllowIPs = allowIPs.Items()
|
||||
|
||||
if err := s.repository.SaveMachine(ctx, m); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(m.TailnetID, &broker.Signal{PeerUpdated: &m.ID})
|
||||
@@ -372,19 +372,19 @@ func (s *Service) DisableExitNode(ctx context.Context, req *connect.Request[api.
|
||||
|
||||
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
|
||||
}
|
||||
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
if !m.IsAdvertisedExitNode() {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("machine is not a valid exit node"))
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("machine is not a valid exit node"))
|
||||
}
|
||||
|
||||
prefix4 := netip.MustParsePrefix("0.0.0.0/0")
|
||||
@@ -400,7 +400,7 @@ func (s *Service) DisableExitNode(ctx context.Context, req *connect.Request[api.
|
||||
m.AutoAllowIPs = autoAllowIPs.Items()
|
||||
|
||||
if err := s.repository.SaveMachine(ctx, m); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(m.TailnetID, &broker.Signal{PeerUpdated: &m.ID})
|
||||
@@ -423,21 +423,21 @@ func (s *Service) SetMachineKeyExpiry(ctx context.Context, req *connect.Request[
|
||||
|
||||
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
|
||||
}
|
||||
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
m.KeyExpiryDisabled = req.Msg.Disabled
|
||||
|
||||
if err := s.repository.SaveMachine(ctx, m); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(m.TailnetID, &broker.Signal{PeerUpdated: &m.ID})
|
||||
|
||||
+58
-58
@@ -3,11 +3,11 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bufbuild/connect-go"
|
||||
"github.com/jsiebens/ionscale/internal/broker"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
"github.com/jsiebens/ionscale/internal/util"
|
||||
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
func (s *Service) CreateTailnet(ctx context.Context, req *connect.Request[api.CreateTailnetRequest]) (*connect.Response[api.CreateTailnetResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
name := req.Msg.Name
|
||||
@@ -38,7 +38,7 @@ func (s *Service) CreateTailnet(ctx context.Context, req *connect.Request[api.Cr
|
||||
}
|
||||
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
resp := &api.CreateTailnetResponse{Tailnet: &api.Tailnet{
|
||||
@@ -52,16 +52,16 @@ func (s *Service) CreateTailnet(ctx context.Context, req *connect.Request[api.Cr
|
||||
func (s *Service) GetTailnet(ctx context.Context, req *connect.Request[api.GetTailnetRequest]) (*connect.Response[api.GetTailnetResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.Id) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&api.GetTailnetResponse{Tailnet: &api.Tailnet{
|
||||
@@ -78,7 +78,7 @@ func (s *Service) ListTailnets(ctx context.Context, req *connect.Request[api.Lis
|
||||
if principal.IsSystemAdmin() {
|
||||
tailnets, err := s.repository.ListTailnets(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
for _, t := range tailnets {
|
||||
gt := api.Tailnet{Id: t.ID, Name: t.Name}
|
||||
@@ -89,7 +89,7 @@ func (s *Service) ListTailnets(ctx context.Context, req *connect.Request[api.Lis
|
||||
if principal.User != nil {
|
||||
tailnet, err := s.repository.GetTailnet(ctx, principal.User.TailnetID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
gt := api.Tailnet{Id: tailnet.ID, Name: tailnet.Name}
|
||||
resp.Tailnet = append(resp.Tailnet, >)
|
||||
@@ -101,12 +101,12 @@ func (s *Service) ListTailnets(ctx context.Context, req *connect.Request[api.Lis
|
||||
func (s *Service) DeleteTailnet(ctx context.Context, req *connect.Request[api.DeleteTailnetRequest]) (*connect.Response[api.DeleteTailnetResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
count, err := s.repository.CountMachineByTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if !req.Msg.Force && count > 0 {
|
||||
@@ -138,7 +138,7 @@ func (s *Service) DeleteTailnet(ctx context.Context, req *connect.Request[api.De
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(req.Msg.TailnetId, &broker.Signal{})
|
||||
@@ -149,20 +149,20 @@ func (s *Service) DeleteTailnet(ctx context.Context, req *connect.Request[api.De
|
||||
func (s *Service) SetDERPMap(ctx context.Context, req *connect.Request[api.SetDERPMapRequest]) (*connect.Response[api.SetDERPMapResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
derpMap := tailcfg.DERPMap{}
|
||||
if err := json.Unmarshal(req.Msg.Value, &derpMap); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
tailnet.DERPMap = domain.DERPMap{
|
||||
@@ -171,14 +171,14 @@ func (s *Service) SetDERPMap(ctx context.Context, req *connect.Request[api.SetDE
|
||||
}
|
||||
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(tailnet.ID, &broker.Signal{})
|
||||
|
||||
raw, err := json.Marshal(derpMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&api.SetDERPMapResponse{Value: raw}), nil
|
||||
@@ -187,21 +187,21 @@ func (s *Service) SetDERPMap(ctx context.Context, req *connect.Request[api.SetDE
|
||||
func (s *Service) ResetDERPMap(ctx context.Context, req *connect.Request[api.ResetDERPMapRequest]) (*connect.Response[api.ResetDERPMapResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
tailnet.DERPMap = domain.DERPMap{}
|
||||
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(tailnet.ID, &broker.Signal{})
|
||||
@@ -212,25 +212,25 @@ func (s *Service) ResetDERPMap(ctx context.Context, req *connect.Request[api.Res
|
||||
func (s *Service) GetDERPMap(ctx context.Context, req *connect.Request[api.GetDERPMapRequest]) (*connect.Response[api.GetDERPMapResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
derpMap, err := tailnet.GetDERPMap(ctx, s.repository)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(derpMap.DERPMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&api.GetDERPMapResponse{Value: raw}), nil
|
||||
@@ -239,21 +239,21 @@ func (s *Service) GetDERPMap(ctx context.Context, req *connect.Request[api.GetDE
|
||||
func (s *Service) EnableFileSharing(ctx context.Context, req *connect.Request[api.EnableFileSharingRequest]) (*connect.Response[api.EnableFileSharingResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
if !tailnet.FileSharingEnabled {
|
||||
tailnet.FileSharingEnabled = true
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(tailnet.ID, &broker.Signal{})
|
||||
@@ -265,21 +265,21 @@ func (s *Service) EnableFileSharing(ctx context.Context, req *connect.Request[ap
|
||||
func (s *Service) DisableFileSharing(ctx context.Context, req *connect.Request[api.DisableFileSharingRequest]) (*connect.Response[api.DisableFileSharingResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
if tailnet.FileSharingEnabled {
|
||||
tailnet.FileSharingEnabled = false
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(tailnet.ID, &broker.Signal{})
|
||||
@@ -291,21 +291,21 @@ func (s *Service) DisableFileSharing(ctx context.Context, req *connect.Request[a
|
||||
func (s *Service) EnableServiceCollection(ctx context.Context, req *connect.Request[api.EnableServiceCollectionRequest]) (*connect.Response[api.EnableServiceCollectionResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
if !tailnet.ServiceCollectionEnabled {
|
||||
tailnet.ServiceCollectionEnabled = true
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(tailnet.ID, &broker.Signal{})
|
||||
@@ -317,21 +317,21 @@ func (s *Service) EnableServiceCollection(ctx context.Context, req *connect.Requ
|
||||
func (s *Service) DisableServiceCollection(ctx context.Context, req *connect.Request[api.DisableServiceCollectionRequest]) (*connect.Response[api.DisableServiceCollectionResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
if tailnet.ServiceCollectionEnabled {
|
||||
tailnet.ServiceCollectionEnabled = false
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(tailnet.ID, &broker.Signal{})
|
||||
@@ -343,21 +343,21 @@ func (s *Service) DisableServiceCollection(ctx context.Context, req *connect.Req
|
||||
func (s *Service) EnableSSH(ctx context.Context, req *connect.Request[api.EnableSSHRequest]) (*connect.Response[api.EnableSSHResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
if !tailnet.SSHEnabled {
|
||||
tailnet.SSHEnabled = true
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(tailnet.ID, &broker.Signal{})
|
||||
@@ -369,21 +369,21 @@ func (s *Service) EnableSSH(ctx context.Context, req *connect.Request[api.Enable
|
||||
func (s *Service) DisableSSH(ctx context.Context, req *connect.Request[api.DisableSSHRequest]) (*connect.Response[api.DisableSSHResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
if tailnet.SSHEnabled {
|
||||
tailnet.SSHEnabled = false
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(tailnet.ID, &broker.Signal{})
|
||||
@@ -395,21 +395,21 @@ func (s *Service) DisableSSH(ctx context.Context, req *connect.Request[api.Disab
|
||||
func (s *Service) EnableMachineAuthorization(ctx context.Context, req *connect.Request[api.EnableMachineAuthorizationRequest]) (*connect.Response[api.EnableMachineAuthorizationResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
if !tailnet.MachineAuthorizationEnabled {
|
||||
tailnet.MachineAuthorizationEnabled = true
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -419,21 +419,21 @@ func (s *Service) EnableMachineAuthorization(ctx context.Context, req *connect.R
|
||||
func (s *Service) DisableMachineAuthorization(ctx context.Context, req *connect.Request[api.DisableMachineAuthorizationRequest]) (*connect.Response[api.DisableMachineAuthorizationResponse], error) {
|
||||
principal := CurrentPrincipal(ctx)
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
if tailnet.MachineAuthorizationEnabled {
|
||||
tailnet.MachineAuthorizationEnabled = false
|
||||
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+11
-10
@@ -2,10 +2,11 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bufbuild/connect-go"
|
||||
"github.com/jsiebens/ionscale/internal/broker"
|
||||
"github.com/jsiebens/ionscale/internal/domain"
|
||||
"github.com/jsiebens/ionscale/internal/errors"
|
||||
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
|
||||
)
|
||||
|
||||
@@ -14,20 +15,20 @@ func (s *Service) ListUsers(ctx context.Context, req *connect.Request[api.ListUs
|
||||
|
||||
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if tailnet == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
|
||||
}
|
||||
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(tailnet.ID) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
users, err := s.repository.ListUsers(ctx, tailnet.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
resp := &api.ListUsersResponse{}
|
||||
@@ -46,20 +47,20 @@ func (s *Service) DeleteUser(ctx context.Context, req *connect.Request[api.Delet
|
||||
principal := CurrentPrincipal(ctx)
|
||||
|
||||
if !principal.IsSystemAdmin() && principal.UserMatches(req.Msg.UserId) {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("unable delete yourself"))
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("unable delete yourself"))
|
||||
}
|
||||
|
||||
user, err := s.repository.GetUser(ctx, req.Msg.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("user not found"))
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("user not found"))
|
||||
}
|
||||
|
||||
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(user.TailnetID) {
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
|
||||
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
|
||||
}
|
||||
|
||||
err = s.repository.Transaction(func(tx domain.Repository) error {
|
||||
@@ -83,7 +84,7 @@ func (s *Service) DeleteUser(ctx context.Context, req *connect.Request[api.Delet
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
|
||||
s.pubsub.Publish(user.TailnetID, &broker.Signal{})
|
||||
|
||||
Reference in New Issue
Block a user