feat: improve error handling/logging a little bit

This commit is contained in:
Johan Siebens
2022-11-23 11:03:42 +01:00
parent c8b040fcd6
commit 2345f0b1de
23 changed files with 460 additions and 299 deletions
+17 -15
View File
@@ -157,21 +157,23 @@ func (g *GormLoggerAdapter) Error(ctx context.Context, s string, i ...interface{
}
func (g *GormLoggerAdapter) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
elapsed := time.Since(begin)
switch {
case err != nil && !errors.Is(err, gorm.ErrRecordNotFound):
sql, rows := fc()
if rows == -1 {
g.logger.Error("Error executing query", "sql", sql, "start_time", begin.Format(time.RFC3339), "duration", elapsed, "err", err)
} else {
g.logger.Error("Error executing query", "sql", sql, "start_time", begin.Format(time.RFC3339), "duration", elapsed, "rows", rows, "err", err)
}
case g.logger.IsTrace():
sql, rows := fc()
if rows == -1 {
g.logger.Trace("Statement executed", "sql", sql, "start_time", begin.Format(time.RFC3339), "duration", elapsed)
} else {
g.logger.Trace("Statement executed", "sql", sql, "start_time", begin.Format(time.RFC3339), "duration", elapsed, "rows", rows)
if g.logger.IsTrace() {
elapsed := time.Since(begin)
switch {
case err != nil && !errors.Is(err, gorm.ErrRecordNotFound):
sql, rows := fc()
if rows == -1 {
g.logger.Trace("Error executing query", "sql", sql, "start_time", begin.Format(time.RFC3339), "duration", elapsed, "err", err)
} else {
g.logger.Trace("Error executing query", "sql", sql, "start_time", begin.Format(time.RFC3339), "duration", elapsed, "rows", rows, "err", err)
}
default:
sql, rows := fc()
if rows == -1 {
g.logger.Trace("Statement executed", "sql", sql, "start_time", begin.Format(time.RFC3339), "duration", elapsed)
} else {
g.logger.Trace("Statement executed", "sql", sql, "start_time", begin.Format(time.RFC3339), "duration", elapsed, "rows", rows)
}
}
}
}
+42
View File
@@ -0,0 +1,42 @@
package errors
import (
"fmt"
"runtime"
)
type Error struct {
Cause error
Location string
}
func Wrap(err error, skip int) error {
if err == nil {
return nil
}
c := &Error{
Cause: err,
Location: getLocation(skip),
}
return c
}
func (w *Error) Error() string {
return w.Cause.Error()
}
func (f *Error) Unwrap() error {
return f.Cause
}
func (f *Error) Format(s fmt.State, verb rune) {
fmt.Fprintf(s, "%s\n", f.Cause.Error())
fmt.Fprintf(s, "\t%s\n", f.Location)
}
func getLocation(skip int) string {
_, file, line, _ := runtime.Caller(2 + skip)
return fmt.Sprintf("%s:%d", file, line)
}
+41 -39
View File
@@ -3,8 +3,10 @@ package handlers
import (
"context"
"encoding/json"
"fmt"
"github.com/jsiebens/ionscale/internal/addr"
"github.com/jsiebens/ionscale/internal/auth"
"github.com/jsiebens/ionscale/internal/errors"
"github.com/labstack/echo/v4/middleware"
"github.com/mr-tron/base58"
"net/http"
@@ -64,7 +66,7 @@ func (h *AuthenticationHandlers) StartAuth(c echo.Context) error {
// machine registration auth flow
if flow == "r" || flow == "" {
if req, err := h.repository.GetRegistrationRequestByKey(ctx, key); err != nil || req == nil {
return c.Redirect(http.StatusFound, "/a/error")
return errors.Wrap(err, 0)
}
csrf := c.Get(middleware.DefaultCSRFConfig.ContextKey).(string)
@@ -74,24 +76,24 @@ func (h *AuthenticationHandlers) StartAuth(c echo.Context) error {
// cli auth flow
if flow == "c" {
if s, err := h.repository.GetAuthenticationRequest(ctx, key); err != nil || s == nil {
return c.Redirect(http.StatusFound, "/a/error")
return errors.Wrap(err, 0)
}
}
// ssh check auth flow
if flow == "s" {
if s, err := h.repository.GetSSHActionRequest(ctx, key); err != nil || s == nil {
return c.Redirect(http.StatusFound, "/a/error")
return errors.Wrap(err, 0)
}
}
if h.authProvider == nil {
return c.Redirect(http.StatusFound, "/a/error")
return errors.Wrap(fmt.Errorf("unable to start auth flow as no auth provider is configured"), 0)
}
state, err := h.createState(flow, key)
if err != nil {
return err
return errors.Wrap(err, 0)
}
redirectUrl := h.authProvider.GetLoginURL(h.config.CreateUrl("/a/callback"), state)
@@ -108,7 +110,7 @@ func (h *AuthenticationHandlers) ProcessAuth(c echo.Context) error {
req, err := h.repository.GetRegistrationRequestByKey(ctx, key)
if err != nil || req == nil {
return c.Redirect(http.StatusFound, "/a/error")
return errors.Wrap(err, 0)
}
if authKey != "" {
@@ -118,7 +120,7 @@ func (h *AuthenticationHandlers) ProcessAuth(c echo.Context) error {
if interactive != "" {
state, err := h.createState("r", key)
if err != nil {
return err
return errors.Wrap(err, 0)
}
redirectUrl := h.authProvider.GetLoginURL(h.config.CreateUrl("/a/callback"), state)
@@ -135,17 +137,17 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
code := c.QueryParam("code")
state, err := h.readState(c.QueryParam("state"))
if err != nil {
return err
return echo.NewHTTPError(http.StatusBadRequest, "Invalid state parameter")
}
user, err := h.exchangeUser(code)
if err != nil {
return err
return errors.Wrap(err, 0)
}
account, _, err := h.repository.GetOrCreateAccount(ctx, user.ID, user.Name)
if err != nil {
return err
return errors.Wrap(err, 0)
}
if state.Flow == "s" {
@@ -156,27 +158,27 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
machine, err := h.repository.GetMachine(ctx, sshActionReq.SrcMachineID)
if err != nil || sshActionReq == nil {
return c.Redirect(http.StatusFound, "/a/error")
return errors.Wrap(err, 0)
}
if !machine.HasTags() && machine.User.AccountID != nil && *machine.User.AccountID == account.ID {
sshActionReq.Action = "accept"
if err := h.repository.SaveSSHActionRequest(ctx, sshActionReq); err != nil {
return c.Redirect(http.StatusFound, "/a/error")
return errors.Wrap(err, 0)
}
return c.Redirect(http.StatusFound, "/a/success")
}
sshActionReq.Action = "reject"
if err := h.repository.SaveSSHActionRequest(ctx, sshActionReq); err != nil {
return c.Redirect(http.StatusFound, "/a/error")
return errors.Wrap(err, 0)
}
return c.Redirect(http.StatusFound, "/a/error?e=nmo")
}
tailnets, err := h.listAvailableTailnets(ctx, user)
if err != nil {
return err
return errors.Wrap(err, 0)
}
csrf := c.Get(middleware.DefaultCSRFConfig.ContextKey).(string)
@@ -201,7 +203,7 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
if state.Flow == "c" {
isSystemAdmin, err := h.isSystemAdmin(ctx, user)
if err != nil {
return err
return errors.Wrap(err, 0)
}
if !isSystemAdmin && len(tailnets) == 0 {
@@ -220,7 +222,7 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
})
}
return c.Redirect(http.StatusFound, "/a/error")
return echo.NewHTTPError(http.StatusNotFound)
}
func (h *AuthenticationHandlers) isSystemAdmin(ctx context.Context, u *auth.User) (bool, error) {
@@ -250,13 +252,13 @@ func (h *AuthenticationHandlers) EndOAuth(c echo.Context) error {
state, err := h.readState(c.QueryParam("state"))
if err != nil {
return c.Redirect(http.StatusFound, "/a/error")
return echo.NewHTTPError(http.StatusBadRequest, "Invalid state parameter")
}
if state.Flow == "r" {
req, err := h.repository.GetRegistrationRequestByKey(ctx, state.Key)
if err != nil || req == nil {
return c.Redirect(http.StatusFound, "/a/error")
return errors.Wrap(err, 0)
}
return h.endMachineRegistrationFlow(c, req, state)
@@ -264,7 +266,7 @@ func (h *AuthenticationHandlers) EndOAuth(c echo.Context) error {
req, err := h.repository.GetAuthenticationRequest(ctx, state.Key)
if err != nil || req == nil {
return c.Redirect(http.StatusFound, "/a/error")
return errors.Wrap(err, 0)
}
return h.endCliAuthenticationFlow(c, req, state)
@@ -306,12 +308,12 @@ func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, req *d
var form TailnetSelectionForm
if err := c.Bind(&form); err != nil {
return c.Redirect(http.StatusFound, "/a/error")
return errors.Wrap(err, 0)
}
account, err := h.repository.GetAccount(ctx, form.AccountID)
if err != nil {
return c.Redirect(http.StatusFound, "/a/error")
return errors.Wrap(err, 0)
}
// continue as system admin?
@@ -322,27 +324,27 @@ func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, req *d
err := h.repository.Transaction(func(rp domain.Repository) error {
if err := rp.SaveSystemApiKey(ctx, apiKey); err != nil {
return err
return errors.Wrap(err, 0)
}
if err := rp.SaveAuthenticationRequest(ctx, req); err != nil {
return err
return errors.Wrap(err, 0)
}
return nil
})
if err != nil {
return err
return errors.Wrap(err, 0)
}
return c.Redirect(http.StatusFound, "/a/success")
}
tailnet, err := h.repository.GetTailnet(ctx, form.TailnetID)
if err != nil {
return err
return errors.Wrap(err, 0)
}
user, _, err := h.repository.GetOrCreateUserWithAccount(ctx, tailnet, account)
if err != nil {
return err
return errors.Wrap(err, 0)
}
expiresAt := time.Now().Add(24 * time.Hour)
@@ -360,7 +362,7 @@ func (h *AuthenticationHandlers) endCliAuthenticationFlow(c echo.Context, req *d
return nil
})
if err != nil {
return err
return errors.Wrap(err, 0)
}
return c.Redirect(http.StatusFound, "/a/success")
@@ -371,7 +373,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
var form TailnetSelectionForm
if err := c.Bind(&form); err != nil {
return c.Redirect(http.StatusFound, "/a/error")
return errors.Wrap(err, 0)
}
req := tailcfg.RegisterRequest(registrationRequest.Data)
@@ -387,7 +389,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
if form.AuthKey != "" {
authKey, err := h.repository.LoadAuthKey(ctx, form.AuthKey)
if err != nil {
return err
return errors.Wrap(err, 0)
}
if authKey == nil {
@@ -396,7 +398,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
registrationRequest.Error = "invalid auth key"
if err := h.repository.SaveRegistrationRequest(ctx, registrationRequest); err != nil {
return c.Redirect(http.StatusFound, "/a/error")
return errors.Wrap(err, 0)
}
return c.Redirect(http.StatusFound, "/a/error?e=iak")
@@ -410,17 +412,17 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
} else {
selectedTailnet, err := h.repository.GetTailnet(ctx, form.TailnetID)
if err != nil {
return err
return errors.Wrap(err, 0)
}
account, err := h.repository.GetAccount(ctx, form.AccountID)
if err != nil {
return c.Redirect(http.StatusFound, "/a/error")
return errors.Wrap(err, 0)
}
selectedUser, _, err := h.repository.GetOrCreateUserWithAccount(ctx, selectedTailnet, account)
if err != nil {
return err
return errors.Wrap(err, 0)
}
user = selectedUser
@@ -432,7 +434,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
registrationRequest.Authenticated = false
registrationRequest.Error = err.Error()
if err := h.repository.SaveRegistrationRequest(ctx, registrationRequest); err != nil {
return c.Redirect(http.StatusFound, "/a/error")
return errors.Wrap(err, 0)
}
return c.Redirect(http.StatusFound, "/a/error?e=nto")
}
@@ -443,7 +445,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
m, err := h.repository.GetMachineByKey(ctx, tailnet.ID, machineKey)
if err != nil {
return err
return errors.Wrap(err, 0)
}
now := time.Now().UTC()
@@ -456,7 +458,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
sanitizeHostname := dnsname.SanitizeHostname(req.Hostinfo.Hostname)
nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnet.ID, sanitizeHostname)
if err != nil {
return err
return errors.Wrap(err, 0)
}
m = &domain.Machine{
@@ -480,7 +482,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
ipv4, ipv6, err := addr.SelectIP(checkIP(ctx, h.repository.CountMachinesWithIPv4))
if err != nil {
return err
return errors.Wrap(err, 0)
}
m.IPv4 = domain.IP{Addr: ipv4}
m.IPv6 = domain.IP{Addr: ipv6}
@@ -493,7 +495,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
if m.Name != sanitizeHostname {
nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnet.ID, sanitizeHostname)
if err != nil {
return err
return errors.Wrap(err, 0)
}
m.Name = sanitizeHostname
m.NameIdx = nameIdx
@@ -526,7 +528,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, regi
})
if err != nil {
return err
return errors.Wrap(err, 0)
}
if m.Authorized {
+4 -3
View File
@@ -3,6 +3,7 @@ package handlers
import (
"github.com/jsiebens/ionscale/internal/bind"
"github.com/jsiebens/ionscale/internal/dns"
"github.com/jsiebens/ionscale/internal/errors"
"github.com/labstack/echo/v4"
"net"
"net/http"
@@ -28,12 +29,12 @@ func (h *DNSHandlers) SetDNS(c echo.Context) error {
binder, err := h.createBinder(c)
if err != nil {
return err
return errors.Wrap(err, 0)
}
req := &tailcfg.SetDNSRequest{}
if err := binder.BindRequest(c, req); err != nil {
return err
return errors.Wrap(err, 0)
}
if h.provider == nil {
@@ -41,7 +42,7 @@ func (h *DNSHandlers) SetDNS(c echo.Context) error {
}
if err := h.provider.SetRecord(ctx, req.Type, req.Name, req.Value); err != nil {
return err
return errors.Wrap(err, 0)
}
if strings.HasPrefix(req.Name, "_acme-challenge") && req.Type == "TXT" {
+7 -6
View File
@@ -6,6 +6,7 @@ import (
"github.com/jsiebens/ionscale/internal/bind"
"github.com/jsiebens/ionscale/internal/config"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/errors"
"github.com/jsiebens/ionscale/internal/util"
"github.com/labstack/echo/v4"
"gopkg.in/square/go-jose.v2"
@@ -55,7 +56,7 @@ func (h *IDTokenHandlers) OpenIDConfig(c echo.Context) error {
func (h *IDTokenHandlers) Jwks(c echo.Context) error {
keySet, err := h.repository.GetJSONWebKeySet(c.Request().Context())
if err != nil {
return err
return errors.Wrap(err, 0)
}
pub := jose.JSONWebKey{Key: keySet.Key.Public(), KeyID: keySet.Key.Id, Algorithm: "RS256", Use: "sig"}
@@ -68,17 +69,17 @@ func (h *IDTokenHandlers) FetchToken(c echo.Context) error {
keySet, err := h.repository.GetJSONWebKeySet(c.Request().Context())
if err != nil {
return err
return errors.Wrap(err, 0)
}
binder, err := h.createBinder(c)
if err != nil {
return err
return errors.Wrap(err, 0)
}
req := &tailcfg.TokenRequest{}
if err := binder.BindRequest(c, req); err != nil {
return err
return errors.Wrap(err, 0)
}
machineKey := binder.Peer().String()
@@ -87,7 +88,7 @@ func (h *IDTokenHandlers) FetchToken(c echo.Context) error {
var m *domain.Machine
m, err = h.repository.GetMachineByKeys(ctx, machineKey, nodeKey)
if err != nil {
return err
return errors.Wrap(err, 0)
}
if m == nil {
@@ -130,7 +131,7 @@ func (h *IDTokenHandlers) FetchToken(c echo.Context) error {
jwtB64, err := unsignedToken.SignedString(&keySet.Key.PrivateKey)
if err != nil {
return err
return errors.Wrap(err, 0)
}
resp := tailcfg.TokenResponse{IDToken: jwtB64}
+1 -1
View File
@@ -22,7 +22,7 @@ func KeyHandler(keys *config.ServerKeys) echo.HandlerFunc {
if v != "" {
clientCapabilityVersion, err := strconv.Atoi(v)
if err != nil {
return c.String(http.StatusBadRequest, "Invalid version")
return echo.NewHTTPError(http.StatusBadRequest, "Invalid version")
}
if clientCapabilityVersion >= NoiseCapabilityVersion {
+2 -1
View File
@@ -1,6 +1,7 @@
package handlers
import (
"github.com/jsiebens/ionscale/internal/errors"
"github.com/labstack/echo/v4"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
@@ -27,7 +28,7 @@ func NewNoiseHandlers(controlKey key.MachinePrivate, createPeerHandler CreatePee
func (h *NoiseHandlers) Upgrade(c echo.Context) error {
conn, err := controlhttp.AcceptHTTP(c.Request().Context(), c.Response(), c.Request(), h.controlKey)
if err != nil {
return err
return errors.Wrap(err, 0)
}
handler := h.createPeerHandler(conn.Peer())
+15 -14
View File
@@ -6,6 +6,7 @@ import (
"github.com/jsiebens/ionscale/internal/broker"
"github.com/jsiebens/ionscale/internal/config"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/errors"
"github.com/jsiebens/ionscale/internal/mapping"
"github.com/labstack/echo/v4"
"net/http"
@@ -41,12 +42,12 @@ func (h *PollNetMapHandler) PollNetMap(c echo.Context) error {
ctx := c.Request().Context()
binder, err := h.createBinder(c)
if err != nil {
return err
return errors.Wrap(err, 0)
}
req := &tailcfg.MapRequest{}
if err := binder.BindRequest(c, req); err != nil {
return err
return errors.Wrap(err, 0)
}
machineKey := binder.Peer().String()
@@ -55,7 +56,7 @@ func (h *PollNetMapHandler) PollNetMap(c echo.Context) error {
var m *domain.Machine
m, err = h.repository.GetMachineByKeys(ctx, machineKey, nodeKey)
if err != nil {
return err
return errors.Wrap(err, 0)
}
if m == nil {
@@ -80,7 +81,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
m.LastSeen = &now
if err := h.repository.SaveMachine(ctx, m); err != nil {
return err
return errors.Wrap(err, 0)
}
tailnetID := m.TailnetID
@@ -97,14 +98,14 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
response, syncedPeers, derpMapChecksum, err := h.createMapResponse(m, binder, mapRequest, false, make(map[uint64]bool), derpMapChecksum)
if err != nil {
return err
return errors.Wrap(err, 0)
}
updateChan := make(chan *broker.Signal, 20)
unsubscribe, err := h.brokers.Subscribe(tailnetID, updateChan)
if err != nil {
return err
return errors.Wrap(err, 0)
}
h.cancelOfflineMessage(machineID)
@@ -113,7 +114,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
keepAliveResponse, err := h.createKeepAliveResponse(binder, mapRequest)
if err != nil {
return err
return errors.Wrap(err, 0)
}
keepAliveTicker := time.NewTicker(config.KeepAliveInterval())
syncTicker := time.NewTicker(5 * time.Second)
@@ -121,7 +122,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
c.Response().WriteHeader(http.StatusOK)
if _, err := c.Response().Write(response); err != nil {
return err
return errors.Wrap(err, 0)
}
c.Response().Flush()
@@ -146,7 +147,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
case <-keepAliveTicker.C:
if mapRequest.KeepAlive {
if _, err := c.Response().Write(keepAliveResponse); err != nil {
return err
return errors.Wrap(err, 0)
}
_ = h.repository.SetMachineLastSeen(ctx, machineID)
c.Response().Flush()
@@ -155,7 +156,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
if latestSync.Before(latestUpdate) {
machine, err := h.repository.GetMachine(ctx, machineID)
if err != nil {
return err
return errors.Wrap(err, 0)
}
if machine == nil {
return nil
@@ -171,7 +172,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
}
if _, err := c.Response().Write(payload); err != nil {
return err
return errors.Wrap(err, 0)
}
c.Response().Flush()
@@ -190,16 +191,16 @@ func (h *PollNetMapHandler) handleReadOnly(c echo.Context, binder bind.Binder, m
m.DiscoKey = request.DiscoKey.String()
if err := h.repository.SaveMachine(ctx, m); err != nil {
return err
return errors.Wrap(err, 0)
}
response, _, _, err := h.createMapResponse(m, binder, request, false, map[uint64]bool{}, "")
if err != nil {
return err
return errors.Wrap(err, 0)
}
_, err = c.Response().Write(response)
return err
return errors.Wrap(err, 0)
}
func (h *PollNetMapHandler) scheduleOfflineMessage(tailnetID, machineID uint64) {
+14 -13
View File
@@ -7,6 +7,7 @@ import (
"github.com/jsiebens/ionscale/internal/broker"
"github.com/jsiebens/ionscale/internal/config"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/errors"
"github.com/jsiebens/ionscale/internal/util"
"github.com/labstack/echo/v4"
"net/http"
@@ -41,12 +42,12 @@ func (h *RegistrationHandlers) Register(c echo.Context) error {
binder, err := h.createBinder(c)
if err != nil {
return err
return errors.Wrap(err, 0)
}
req := &tailcfg.RegisterRequest{}
if err := binder.BindRequest(c, req); err != nil {
return err
return errors.Wrap(err, 0)
}
machineKey := binder.Peer().String()
@@ -56,7 +57,7 @@ func (h *RegistrationHandlers) Register(c echo.Context) error {
m, err = h.repository.GetMachineByKeys(ctx, machineKey, nodeKey)
if err != nil {
return err
return errors.Wrap(err, 0)
}
if m != nil {
@@ -70,12 +71,12 @@ func (h *RegistrationHandlers) Register(c echo.Context) error {
if m.Ephemeral {
if _, err := h.repository.DeleteMachine(ctx, m.ID); err != nil {
return err
return errors.Wrap(err, 0)
}
h.pubsub.Publish(m.TailnetID, &broker.Signal{PeersRemoved: []uint64{m.ID}})
} else {
if err := h.repository.SaveMachine(ctx, m); err != nil {
return err
return errors.Wrap(err, 0)
}
h.pubsub.Publish(m.TailnetID, &broker.Signal{PeerUpdated: &m.ID})
}
@@ -88,7 +89,7 @@ func (h *RegistrationHandlers) Register(c echo.Context) error {
if m.Name != sanitizeHostname {
nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, m.TailnetID, sanitizeHostname)
if err != nil {
return err
return errors.Wrap(err, 0)
}
m.Name = sanitizeHostname
m.NameIdx = nameIdx
@@ -99,7 +100,7 @@ func (h *RegistrationHandlers) Register(c echo.Context) error {
m.Tags = append(m.RegisteredTags, advertisedTags...)
if err := h.repository.SaveMachine(ctx, m); err != nil {
return err
return errors.Wrap(err, 0)
}
response := tailcfg.RegisterResponse{MachineAuthorized: true}
@@ -146,7 +147,7 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi
authKey, err := h.repository.LoadAuthKey(ctx, req.Auth.AuthKey)
if err != nil {
return err
return errors.Wrap(err, 0)
}
if authKey == nil {
@@ -172,7 +173,7 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi
m, err = h.repository.GetMachineByKey(ctx, tailnet.ID, machineKey)
if err != nil {
return err
return errors.Wrap(err, 0)
}
now := time.Now().UTC()
@@ -181,7 +182,7 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi
sanitizeHostname := dnsname.SanitizeHostname(req.Hostinfo.Hostname)
nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnet.ID, sanitizeHostname)
if err != nil {
return err
return errors.Wrap(err, 0)
}
m = &domain.Machine{
@@ -209,7 +210,7 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi
ipv4, ipv6, err := addr.SelectIP(checkIP(ctx, h.repository.CountMachinesWithIPv4))
if err != nil {
return err
return errors.Wrap(err, 0)
}
m.IPv4 = domain.IP{Addr: ipv4}
m.IPv6 = domain.IP{Addr: ipv6}
@@ -218,7 +219,7 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi
if m.Name != sanitizeHostname {
nameIdx, err := h.repository.GetNextMachineNameIndex(ctx, tailnet.ID, sanitizeHostname)
if err != nil {
return err
return errors.Wrap(err, 0)
}
m.Name = sanitizeHostname
m.NameIdx = nameIdx
@@ -236,7 +237,7 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi
}
if err := h.repository.SaveMachine(ctx, m); err != nil {
return err
return errors.Wrap(err, 0)
}
response := tailcfg.RegisterResponse{MachineAuthorized: true}
+5 -4
View File
@@ -5,6 +5,7 @@ import (
"github.com/jsiebens/ionscale/internal/bind"
"github.com/jsiebens/ionscale/internal/config"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/errors"
"github.com/jsiebens/ionscale/internal/util"
"github.com/labstack/echo/v4"
"net/http"
@@ -36,12 +37,12 @@ func (h *SSHActionHandlers) StartAuth(c echo.Context) error {
binder, err := h.createBinder(c)
if err != nil {
return err
return errors.Wrap(err, 0)
}
data := new(sshActionRequestData)
if err = c.Bind(data); err != nil {
return c.String(http.StatusBadRequest, "bad request")
return errors.Wrap(err, 0)
}
key := util.RandStringBytes(8)
@@ -55,7 +56,7 @@ func (h *SSHActionHandlers) StartAuth(c echo.Context) error {
authUrl := h.config.CreateUrl("/a/s/%s", key)
if err := h.repository.SaveSSHActionRequest(ctx, request); err != nil {
return err
return errors.Wrap(err, 0)
}
resp := &tailcfg.SSHAction{
@@ -73,7 +74,7 @@ func (h *SSHActionHandlers) CheckAuth(c echo.Context) error {
binder, err := h.createBinder(c)
if err != nil {
return err
return errors.Wrap(err, 0)
}
tick := time.NewTicker(2 * time.Second)
+46
View File
@@ -3,11 +3,48 @@ package server
import (
"fmt"
"github.com/hashicorp/go-hclog"
"github.com/jsiebens/ionscale/internal/errors"
"github.com/labstack/echo-contrib/prometheus"
"github.com/labstack/echo/v4"
"net/http"
"strings"
"time"
)
func EchoErrorHandler(logger hclog.Logger) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
request := c.Request()
if err := next(c); err != nil {
switch t := err.(type) {
case *echo.HTTPError:
return err
case *errors.Error:
logger.Error("error processing request",
"err", t.Cause,
"location", t.Location,
"http.method", request.Method,
"http.uri", request.RequestURI,
)
default:
logger.Error("error processing request",
"err", err,
"http.method", request.Method,
"http.uri", request.RequestURI,
)
}
if strings.HasPrefix(request.RequestURI, "/a/") {
return c.Render(http.StatusInternalServerError, "error.html", nil)
}
}
return nil
}
}
}
func EchoLogger(logger hclog.Logger) echo.MiddlewareFunc {
httpLogger := logger.Named("http")
return func(next echo.HandlerFunc) echo.HandlerFunc {
@@ -55,6 +92,15 @@ func EchoRecover(logger hclog.Logger) echo.MiddlewareFunc {
}
}
func ErrorRedirect() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Set("redirect_on_error", true)
return next(c)
}
}
}
func EchoMetrics(p *prometheus.Prometheus) echo.MiddlewareFunc {
return p.HandlerFunc
}
+3 -2
View File
@@ -2,6 +2,7 @@ package server
import (
"github.com/bufbuild/connect-go"
"github.com/hashicorp/go-hclog"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/key"
"github.com/jsiebens/ionscale/internal/service"
@@ -9,7 +10,7 @@ import (
"net/http"
)
func NewRpcHandler(systemAdminKey *key.ServerPrivate, repository domain.Repository, handler apiconnect.IonscaleServiceHandler) (string, http.Handler) {
interceptors := connect.WithInterceptors(service.AuthenticationInterceptor(systemAdminKey, repository))
func NewRpcHandler(systemAdminKey *key.ServerPrivate, repository domain.Repository, logger hclog.Logger, handler apiconnect.IonscaleServiceHandler) (string, http.Handler) {
interceptors := connect.WithInterceptors(service.NewErrorInterceptor(logger), service.AuthenticationInterceptor(systemAdminKey, repository))
return apiconnect.NewIonscaleServiceHandler(handler, interceptors)
}
+17 -16
View File
@@ -92,16 +92,22 @@ func Start(c *config.Config) error {
return err
}
createPeerHandler := func(p key.MachinePublic) http.Handler {
registrationHandlers := handlers.NewRegistrationHandlers(bind.DefaultBinder(p), c, brokers, repository)
pollNetMapHandler := handlers.NewPollNetMapHandler(bind.DefaultBinder(p), brokers, repository, offlineTimers)
dnsHandlers := handlers.NewDNSHandlers(bind.DefaultBinder(p), dnsProvider)
idTokenHandlers := handlers.NewIDTokenHandlers(bind.DefaultBinder(p), c, repository)
sshActionHandlers := handlers.NewSSHActionHandlers(bind.DefaultBinder(p), c, repository)
p := echo_prometheus.NewPrometheus("http", nil)
metricsHandler := echo.New()
p.SetMetricsPath(metricsHandler)
createPeerHandler := func(machinePublicKey key.MachinePublic) http.Handler {
binder := bind.DefaultBinder(machinePublicKey)
registrationHandlers := handlers.NewRegistrationHandlers(binder, c, brokers, repository)
pollNetMapHandler := handlers.NewPollNetMapHandler(binder, brokers, repository, offlineTimers)
dnsHandlers := handlers.NewDNSHandlers(binder, dnsProvider)
idTokenHandlers := handlers.NewIDTokenHandlers(binder, c, repository)
sshActionHandlers := handlers.NewSSHActionHandlers(binder, c, repository)
e := echo.New()
e.Use(EchoLogger(logger))
e.Use(EchoRecover(logger))
e.Use(EchoMetrics(p), EchoLogger(logger), EchoErrorHandler(logger), EchoRecover(logger))
e.POST("/machine/register", registrationHandlers.Register)
e.POST("/machine/map", pollNetMapHandler.PollNetMap)
e.POST("/machine/set-dns", dnsHandlers.SetDNS)
@@ -125,22 +131,17 @@ func Start(c *config.Config) error {
)
rpcService := service.NewService(c, authProvider, repository, brokers)
rpcPath, rpcHandler := NewRpcHandler(serverKey.SystemAdminKey, repository, rpcService)
p := echo_prometheus.NewPrometheus("http", nil)
metricsHandler := echo.New()
p.SetMetricsPath(metricsHandler)
rpcPath, rpcHandler := NewRpcHandler(serverKey.SystemAdminKey, repository, logger, rpcService)
nonTlsAppHandler := echo.New()
nonTlsAppHandler.Use(EchoMetrics(p), EchoLogger(logger), EchoRecover(logger))
nonTlsAppHandler.Use(EchoMetrics(p), EchoLogger(logger), EchoErrorHandler(logger), EchoRecover(logger))
nonTlsAppHandler.POST("/ts2021", noiseHandlers.Upgrade)
nonTlsAppHandler.Any("/*", handlers.HttpRedirectHandler(c.Tls))
tlsAppHandler := echo.New()
tlsAppHandler.Renderer = templates.NewTemplates()
tlsAppHandler.Pre(handlers.HttpsRedirect(c.Tls))
tlsAppHandler.Use(EchoMetrics(p), EchoLogger(logger), EchoRecover(logger))
tlsAppHandler.Use(EchoMetrics(p), EchoLogger(logger), EchoErrorHandler(logger), EchoRecover(logger))
tlsAppHandler.Any("/*", handlers.IndexHandler(http.StatusNotFound))
tlsAppHandler.Any("/", handlers.IndexHandler(http.StatusOK))
+8 -8
View File
@@ -2,11 +2,11 @@ package service
import (
"context"
"errors"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/broker"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/errors"
"github.com/jsiebens/ionscale/internal/mapping"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
)
@@ -14,12 +14,12 @@ import (
func (s *Service) GetACLPolicy(ctx context.Context, req *connect.Request[api.GetACLPolicyRequest]) (*connect.Response[api.GetACLPolicyResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet does not exist"))
@@ -27,7 +27,7 @@ func (s *Service) GetACLPolicy(ctx context.Context, req *connect.Request[api.Get
var policy api.ACLPolicy
if err := mapping.CopyViaJson(&tailnet.ACLPolicy, &policy); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
return connect.NewResponse(&api.GetACLPolicyResponse{Policy: &policy}), nil
@@ -36,12 +36,12 @@ func (s *Service) GetACLPolicy(ctx context.Context, req *connect.Request[api.Get
func (s *Service) SetACLPolicy(ctx context.Context, req *connect.Request[api.SetACLPolicyRequest]) (*connect.Response[api.SetACLPolicyResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet does not exist"))
@@ -49,12 +49,12 @@ func (s *Service) SetACLPolicy(ctx context.Context, req *connect.Request[api.Set
var policy domain.ACLPolicy
if err := mapping.CopyViaJson(req.Msg.Policy, &policy); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
tailnet.ACLPolicy = policy
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(tailnet.ID, &broker.Signal{ACLUpdated: true})
+13 -9
View File
@@ -2,9 +2,10 @@ package service
import (
"context"
"errors"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/errors"
"github.com/jsiebens/ionscale/internal/util"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"time"
@@ -12,7 +13,7 @@ import (
func (s *Service) Authenticate(ctx context.Context, req *connect.Request[api.AuthenticateRequest], stream *connect.ServerStream[api.AuthenticateResponse]) error {
if s.authProvider == nil {
return connect.NewError(connect.CodeFailedPrecondition, errors.New("no authentication method available, contact your ionscale administrator for more information"))
return connect.NewError(connect.CodeFailedPrecondition, fmt.Errorf("no authentication method available, contact your ionscale administrator for more information"))
}
key := util.RandStringBytes(8)
@@ -24,11 +25,11 @@ func (s *Service) Authenticate(ctx context.Context, req *connect.Request[api.Aut
}
if err := s.repository.SaveAuthenticationRequest(ctx, session); err != nil {
return err
return errors.Wrap(err, 0)
}
if err := stream.Send(&api.AuthenticateResponse{AuthUrl: authUrl}); err != nil {
return err
return errors.Wrap(err, 0)
}
notify := ctx.Done()
@@ -43,24 +44,27 @@ func (s *Service) Authenticate(ctx context.Context, req *connect.Request[api.Aut
select {
case <-tick.C:
m, err := s.repository.GetAuthenticationRequest(ctx, key)
if err != nil {
return errors.Wrap(err, 0)
}
if err != nil || m == nil {
return connect.NewError(connect.CodeInternal, errors.New("something went wrong"))
if m == nil {
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("invalid authentication request"))
}
if len(m.Token) != 0 {
if err := stream.Send(&api.AuthenticateResponse{Token: m.Token, TailnetId: m.TailnetID}); err != nil {
return err
return errors.Wrap(err, 0)
}
return nil
}
if len(m.Error) != 0 {
return connect.NewError(connect.CodePermissionDenied, errors.New(m.Error))
return connect.NewError(connect.CodePermissionDenied, fmt.Errorf(m.Error))
}
if err := stream.Send(&api.AuthenticateResponse{AuthUrl: authUrl}); err != nil {
return err
return errors.Wrap(err, 0)
}
case <-notify:
+20 -19
View File
@@ -2,9 +2,10 @@ package service
import (
"context"
"errors"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/errors"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"google.golang.org/protobuf/types/known/timestamppb"
"time"
@@ -15,15 +16,15 @@ func (s *Service) GetAuthKey(ctx context.Context, req *connect.Request[api.GetAu
key, err := s.repository.GetAuthKey(ctx, req.Msg.AuthKeyId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if key == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("auth key not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("auth key not found"))
}
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(key.TailnetID) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
var expiresAt *timestamppb.Timestamp
@@ -74,16 +75,16 @@ func mapAuthKeysToApi(authKeys []domain.AuthKey) []*api.AuthKey {
func (s *Service) ListAuthKeys(ctx context.Context, req *connect.Request[api.ListAuthKeysRequest]) (*connect.Response[api.ListAuthKeysResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
response := api.ListAuthKeysResponse{}
@@ -91,7 +92,7 @@ func (s *Service) ListAuthKeys(ctx context.Context, req *connect.Request[api.Lis
if principal.IsSystemAdmin() {
authKeys, err := s.repository.ListAuthKeys(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
response.AuthKeys = mapAuthKeysToApi(authKeys)
@@ -101,7 +102,7 @@ func (s *Service) ListAuthKeys(ctx context.Context, req *connect.Request[api.Lis
if principal.User != nil {
authKeys, err := s.repository.ListAuthKeysByTailnetAndUser(ctx, req.Msg.TailnetId, principal.User.ID)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
response.AuthKeys = mapAuthKeysToApi(authKeys)
@@ -114,11 +115,11 @@ func (s *Service) ListAuthKeys(ctx context.Context, req *connect.Request[api.Lis
func (s *Service) CreateAuthKey(ctx context.Context, req *connect.Request[api.CreateAuthKeyRequest]) (*connect.Response[api.CreateAuthKeyResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
if principal.User == nil && len(req.Msg.Tags) == 0 {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("at least one tag is required when creating an auth key"))
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("at least one tag is required when creating an auth key"))
}
if err := domain.CheckTags(req.Msg.Tags); err != nil {
@@ -127,11 +128,11 @@ func (s *Service) CreateAuthKey(ctx context.Context, req *connect.Request[api.Cr
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
if !principal.IsSystemAdmin() {
@@ -154,7 +155,7 @@ func (s *Service) CreateAuthKey(ctx context.Context, req *connect.Request[api.Cr
if user == nil {
u, _, err := s.repository.GetOrCreateServiceUser(ctx, tailnet)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
user = u
}
@@ -164,7 +165,7 @@ func (s *Service) CreateAuthKey(ctx context.Context, req *connect.Request[api.Cr
v, authKey := domain.CreateAuthKey(tailnet, user, req.Msg.Ephemeral, req.Msg.PreAuthorized, tags, expiresAt)
if err := s.repository.SaveAuthKey(ctx, authKey); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
response := api.CreateAuthKeyResponse{
@@ -190,19 +191,19 @@ func (s *Service) DeleteAuthKey(ctx context.Context, req *connect.Request[api.De
key, err := s.repository.GetAuthKey(ctx, req.Msg.AuthKeyId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if key == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("auth key not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("auth key not found"))
}
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(key.UserID) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
if _, err := s.repository.DeleteAuthKey(ctx, req.Msg.AuthKeyId); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
return connect.NewResponse(&api.DeleteAuthKeyResponse{}), nil
}
+12 -11
View File
@@ -3,10 +3,11 @@ package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/broker"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/errors"
"github.com/jsiebens/ionscale/internal/util"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"tailscale.com/tailcfg"
@@ -15,17 +16,17 @@ import (
func (s *Service) GetDefaultDERPMap(ctx context.Context, _ *connect.Request[api.GetDefaultDERPMapRequest]) (*connect.Response[api.GetDefaultDERPMapResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
dm, err := s.repository.GetDERPMap(ctx)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
raw, err := json.Marshal(dm.DERPMap)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
return connect.NewResponse(&api.GetDefaultDERPMapResponse{Value: raw}), nil
@@ -34,12 +35,12 @@ func (s *Service) GetDefaultDERPMap(ctx context.Context, _ *connect.Request[api.
func (s *Service) SetDefaultDERPMap(ctx context.Context, req *connect.Request[api.SetDefaultDERPMapRequest]) (*connect.Response[api.SetDefaultDERPMapResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
var derpMap tailcfg.DERPMap
if err := json.Unmarshal(req.Msg.Value, &derpMap); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
dp := domain.DERPMap{
@@ -48,12 +49,12 @@ func (s *Service) SetDefaultDERPMap(ctx context.Context, req *connect.Request[ap
}
if err := s.repository.SetDERPMap(ctx, &dp); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
tailnets, err := s.repository.ListTailnets(ctx)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
for _, t := range tailnets {
@@ -66,18 +67,18 @@ func (s *Service) SetDefaultDERPMap(ctx context.Context, req *connect.Request[ap
func (s *Service) ResetDefaultDERPMap(ctx context.Context, req *connect.Request[api.ResetDefaultDERPMapRequest]) (*connect.Response[api.ResetDefaultDERPMapResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
dp := domain.DERPMap{}
if err := s.repository.SetDERPMap(ctx, &dp); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
tailnets, err := s.repository.ListTailnets(ctx)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
for _, t := range tailnets {
+17 -17
View File
@@ -2,27 +2,27 @@ package service
import (
"context"
"errors"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/broker"
"github.com/jsiebens/ionscale/internal/config"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/errors"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
)
func (s *Service) GetDNSConfig(ctx context.Context, req *connect.Request[api.GetDNSConfigRequest]) (*connect.Response[api.GetDNSConfigResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
dnsConfig := tailnet.DNSConfig
@@ -44,17 +44,17 @@ func (s *Service) GetDNSConfig(ctx context.Context, req *connect.Request[api.Get
func (s *Service) SetDNSConfig(ctx context.Context, req *connect.Request[api.SetDNSConfigRequest]) (*connect.Response[api.SetDNSConfigResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
dnsConfig := req.Msg.Config
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
tailnet.DNSConfig = domain.DNSConfig{
@@ -65,7 +65,7 @@ func (s *Service) SetDNSConfig(ctx context.Context, req *connect.Request[api.Set
}
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(tailnet.ID, &broker.Signal{DNSUpdated: true})
@@ -78,24 +78,24 @@ func (s *Service) SetDNSConfig(ctx context.Context, req *connect.Request[api.Set
func (s *Service) EnableHttpsCertificates(ctx context.Context, req *connect.Request[api.EnableHttpsCertificatesRequest]) (*connect.Response[api.EnableHttpsCertificatesResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
if !tailnet.DNSConfig.MagicDNS {
return nil, connect.NewError(connect.CodeFailedPrecondition, errors.New("MagicDNS must be enabled for this tailnet"))
return nil, connect.NewError(connect.CodeFailedPrecondition, fmt.Errorf("MagicDNS must be enabled for this tailnet"))
}
tailnet.DNSConfig.HttpsCertsEnabled = true
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(tailnet.ID, &broker.Signal{DNSUpdated: true})
@@ -106,21 +106,21 @@ func (s *Service) EnableHttpsCertificates(ctx context.Context, req *connect.Requ
func (s *Service) DisableHttpsCertificates(ctx context.Context, req *connect.Request[api.DisableHttpsCertificatesRequest]) (*connect.Response[api.DisableHttpsCertificatesResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
tailnet.DNSConfig.HttpsCertsEnabled = false
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(tailnet.ID, &broker.Signal{DNSUpdated: true})
+6 -6
View File
@@ -2,22 +2,22 @@ package service
import (
"context"
"errors"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/errors"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
)
func (s *Service) GetIAMPolicy(ctx context.Context, req *connect.Request[api.GetIAMPolicyRequest]) (*connect.Response[api.GetIAMPolicyResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet does not exist"))
@@ -36,12 +36,12 @@ func (s *Service) GetIAMPolicy(ctx context.Context, req *connect.Request[api.Get
func (s *Service) SetIAMPolicy(ctx context.Context, req *connect.Request[api.SetIAMPolicyRequest]) (*connect.Response[api.SetIAMPolicyResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet does not exist"))
@@ -55,7 +55,7 @@ func (s *Service) SetIAMPolicy(ctx context.Context, req *connect.Request[api.Set
}
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
return connect.NewResponse(&api.SetIAMPolicyResponse{}), nil
+54
View File
@@ -4,7 +4,9 @@ import (
"context"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/hashicorp/go-hclog"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/errors"
"github.com/jsiebens/ionscale/internal/key"
"github.com/jsiebens/ionscale/internal/token"
"strings"
@@ -75,3 +77,55 @@ func exchangeToken(ctx context.Context, systemAdminKey *key.ServerPrivate, repos
return nil
}
func NewErrorInterceptor(logger hclog.Logger) *ErrorInterceptor {
return &ErrorInterceptor{
logger: logger,
}
}
type ErrorInterceptor struct {
logger hclog.Logger
}
func (e *ErrorInterceptor) handleError(err error) error {
if err == nil {
return err
}
switch t := err.(type) {
case *connect.Error:
return err
case *errors.Error:
e.logger.Error("error processing grpc request",
"err", t.Cause,
"location", t.Location,
)
return connect.NewError(connect.CodeInternal, fmt.Errorf("internal server error"))
default:
e.logger.Error("error processing grpc request",
"err", err,
)
return connect.NewError(connect.CodeInternal, fmt.Errorf("internal server error"))
}
}
func (e *ErrorInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) {
response, err := next(ctx, request)
return response, e.handleError(err)
}
}
func (e *ErrorInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
return next(ctx, spec)
}
}
func (e *ErrorInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
err := next(ctx, conn)
return e.handleError(err)
}
}
+47 -47
View File
@@ -2,12 +2,12 @@ package service
import (
"context"
"errors"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/broker"
"github.com/jsiebens/ionscale/internal/config"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/errors"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"google.golang.org/protobuf/types/known/timestamppb"
"net/netip"
@@ -64,20 +64,20 @@ func (s *Service) machineToApi(m *domain.Machine) *api.Machine {
func (s *Service) ListMachines(ctx context.Context, req *connect.Request[api.ListMachinesRequest]) (*connect.Response[api.ListMachinesResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
machines, err := s.repository.ListMachineByTailnet(ctx, tailnet.ID)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
response := &api.ListMachinesResponse{}
@@ -93,15 +93,15 @@ func (s *Service) GetMachine(ctx context.Context, req *connect.Request[api.GetMa
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if m == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
}
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
return connect.NewResponse(&api.GetMachineResponse{Machine: s.machineToApi(m)}), nil
@@ -112,19 +112,19 @@ func (s *Service) DeleteMachine(ctx context.Context, req *connect.Request[api.De
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if m == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
}
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
if _, err := s.repository.DeleteMachine(ctx, req.Msg.MachineId); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(m.TailnetID, &broker.Signal{PeersRemoved: []uint64{m.ID}})
@@ -137,15 +137,15 @@ func (s *Service) ExpireMachine(ctx context.Context, req *connect.Request[api.Ex
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if m == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
}
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
timestamp := time.Unix(123, 0)
@@ -153,7 +153,7 @@ func (s *Service) ExpireMachine(ctx context.Context, req *connect.Request[api.Ex
m.KeyExpiryDisabled = false
if err := s.repository.SaveMachine(ctx, m); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(m.TailnetID, &broker.Signal{PeerUpdated: &m.ID})
@@ -166,21 +166,21 @@ func (s *Service) AuthorizeMachine(ctx context.Context, req *connect.Request[api
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if m == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
}
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
if !m.Authorized {
m.Authorized = true
if err := s.repository.SaveMachine(ctx, m); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
}
@@ -194,15 +194,15 @@ func (s *Service) GetMachineRoutes(ctx context.Context, req *connect.Request[api
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if m == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
}
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
response := api.GetMachineRoutesResponse{
@@ -223,15 +223,15 @@ func (s *Service) EnableMachineRoutes(ctx context.Context, req *connect.Request[
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if m == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
}
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
var allowIPs = domain.NewAllowIPsSet(m.AllowIPs)
@@ -245,7 +245,7 @@ func (s *Service) EnableMachineRoutes(ctx context.Context, req *connect.Request[
for _, r := range req.Msg.Routes {
prefix, err := netip.ParsePrefix(r)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
allowIPs.Add(prefix)
}
@@ -253,7 +253,7 @@ func (s *Service) EnableMachineRoutes(ctx context.Context, req *connect.Request[
m.AllowIPs = allowIPs.Items()
m.AutoAllowIPs = autoAllowIPs.Items()
if err := s.repository.SaveMachine(ctx, m); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(m.TailnetID, &broker.Signal{PeerUpdated: &m.ID})
@@ -276,15 +276,15 @@ func (s *Service) DisableMachineRoutes(ctx context.Context, req *connect.Request
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if m == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
}
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
allowIPs := domain.NewAllowIPsSet(m.AllowIPs)
@@ -293,7 +293,7 @@ func (s *Service) DisableMachineRoutes(ctx context.Context, req *connect.Request
for _, r := range req.Msg.Routes {
prefix, err := netip.ParsePrefix(r)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
allowIPs.Remove(prefix)
autoAllowIPs.Remove(prefix)
@@ -302,7 +302,7 @@ func (s *Service) DisableMachineRoutes(ctx context.Context, req *connect.Request
m.AllowIPs = allowIPs.Items()
m.AutoAllowIPs = autoAllowIPs.Items()
if err := s.repository.SaveMachine(ctx, m); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(m.TailnetID, &broker.Signal{PeerUpdated: &m.ID})
@@ -325,19 +325,19 @@ func (s *Service) EnableExitNode(ctx context.Context, req *connect.Request[api.E
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if m == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
}
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
if !m.IsAdvertisedExitNode() {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("machine is not a valid exit node"))
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("machine is not a valid exit node"))
}
prefix4 := netip.MustParsePrefix("0.0.0.0/0")
@@ -349,7 +349,7 @@ func (s *Service) EnableExitNode(ctx context.Context, req *connect.Request[api.E
m.AllowIPs = allowIPs.Items()
if err := s.repository.SaveMachine(ctx, m); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(m.TailnetID, &broker.Signal{PeerUpdated: &m.ID})
@@ -372,19 +372,19 @@ func (s *Service) DisableExitNode(ctx context.Context, req *connect.Request[api.
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if m == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
}
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
if !m.IsAdvertisedExitNode() {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("machine is not a valid exit node"))
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("machine is not a valid exit node"))
}
prefix4 := netip.MustParsePrefix("0.0.0.0/0")
@@ -400,7 +400,7 @@ func (s *Service) DisableExitNode(ctx context.Context, req *connect.Request[api.
m.AutoAllowIPs = autoAllowIPs.Items()
if err := s.repository.SaveMachine(ctx, m); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(m.TailnetID, &broker.Signal{PeerUpdated: &m.ID})
@@ -423,21 +423,21 @@ func (s *Service) SetMachineKeyExpiry(ctx context.Context, req *connect.Request[
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if m == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("machine not found"))
}
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(m.TailnetID) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
m.KeyExpiryDisabled = req.Msg.Disabled
if err := s.repository.SaveMachine(ctx, m); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(m.TailnetID, &broker.Signal{PeerUpdated: &m.ID})
+58 -58
View File
@@ -3,11 +3,11 @@ package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/broker"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/errors"
"github.com/jsiebens/ionscale/internal/util"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"tailscale.com/tailcfg"
@@ -16,7 +16,7 @@ import (
func (s *Service) CreateTailnet(ctx context.Context, req *connect.Request[api.CreateTailnetRequest]) (*connect.Response[api.CreateTailnetResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
name := req.Msg.Name
@@ -38,7 +38,7 @@ func (s *Service) CreateTailnet(ctx context.Context, req *connect.Request[api.Cr
}
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
resp := &api.CreateTailnetResponse{Tailnet: &api.Tailnet{
@@ -52,16 +52,16 @@ func (s *Service) CreateTailnet(ctx context.Context, req *connect.Request[api.Cr
func (s *Service) GetTailnet(ctx context.Context, req *connect.Request[api.GetTailnetRequest]) (*connect.Response[api.GetTailnetResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.Id) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.Id)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
return connect.NewResponse(&api.GetTailnetResponse{Tailnet: &api.Tailnet{
@@ -78,7 +78,7 @@ func (s *Service) ListTailnets(ctx context.Context, req *connect.Request[api.Lis
if principal.IsSystemAdmin() {
tailnets, err := s.repository.ListTailnets(ctx)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
for _, t := range tailnets {
gt := api.Tailnet{Id: t.ID, Name: t.Name}
@@ -89,7 +89,7 @@ func (s *Service) ListTailnets(ctx context.Context, req *connect.Request[api.Lis
if principal.User != nil {
tailnet, err := s.repository.GetTailnet(ctx, principal.User.TailnetID)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
gt := api.Tailnet{Id: tailnet.ID, Name: tailnet.Name}
resp.Tailnet = append(resp.Tailnet, &gt)
@@ -101,12 +101,12 @@ func (s *Service) ListTailnets(ctx context.Context, req *connect.Request[api.Lis
func (s *Service) DeleteTailnet(ctx context.Context, req *connect.Request[api.DeleteTailnetRequest]) (*connect.Response[api.DeleteTailnetResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
count, err := s.repository.CountMachineByTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if !req.Msg.Force && count > 0 {
@@ -138,7 +138,7 @@ func (s *Service) DeleteTailnet(ctx context.Context, req *connect.Request[api.De
})
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(req.Msg.TailnetId, &broker.Signal{})
@@ -149,20 +149,20 @@ func (s *Service) DeleteTailnet(ctx context.Context, req *connect.Request[api.De
func (s *Service) SetDERPMap(ctx context.Context, req *connect.Request[api.SetDERPMapRequest]) (*connect.Response[api.SetDERPMapResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
derpMap := tailcfg.DERPMap{}
if err := json.Unmarshal(req.Msg.Value, &derpMap); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
tailnet.DERPMap = domain.DERPMap{
@@ -171,14 +171,14 @@ func (s *Service) SetDERPMap(ctx context.Context, req *connect.Request[api.SetDE
}
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(tailnet.ID, &broker.Signal{})
raw, err := json.Marshal(derpMap)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
return connect.NewResponse(&api.SetDERPMapResponse{Value: raw}), nil
@@ -187,21 +187,21 @@ func (s *Service) SetDERPMap(ctx context.Context, req *connect.Request[api.SetDE
func (s *Service) ResetDERPMap(ctx context.Context, req *connect.Request[api.ResetDERPMapRequest]) (*connect.Response[api.ResetDERPMapResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
tailnet.DERPMap = domain.DERPMap{}
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(tailnet.ID, &broker.Signal{})
@@ -212,25 +212,25 @@ func (s *Service) ResetDERPMap(ctx context.Context, req *connect.Request[api.Res
func (s *Service) GetDERPMap(ctx context.Context, req *connect.Request[api.GetDERPMapRequest]) (*connect.Response[api.GetDERPMapResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
derpMap, err := tailnet.GetDERPMap(ctx, s.repository)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
raw, err := json.Marshal(derpMap.DERPMap)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
return connect.NewResponse(&api.GetDERPMapResponse{Value: raw}), nil
@@ -239,21 +239,21 @@ func (s *Service) GetDERPMap(ctx context.Context, req *connect.Request[api.GetDE
func (s *Service) EnableFileSharing(ctx context.Context, req *connect.Request[api.EnableFileSharingRequest]) (*connect.Response[api.EnableFileSharingResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
if !tailnet.FileSharingEnabled {
tailnet.FileSharingEnabled = true
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(tailnet.ID, &broker.Signal{})
@@ -265,21 +265,21 @@ func (s *Service) EnableFileSharing(ctx context.Context, req *connect.Request[ap
func (s *Service) DisableFileSharing(ctx context.Context, req *connect.Request[api.DisableFileSharingRequest]) (*connect.Response[api.DisableFileSharingResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
if tailnet.FileSharingEnabled {
tailnet.FileSharingEnabled = false
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(tailnet.ID, &broker.Signal{})
@@ -291,21 +291,21 @@ func (s *Service) DisableFileSharing(ctx context.Context, req *connect.Request[a
func (s *Service) EnableServiceCollection(ctx context.Context, req *connect.Request[api.EnableServiceCollectionRequest]) (*connect.Response[api.EnableServiceCollectionResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
if !tailnet.ServiceCollectionEnabled {
tailnet.ServiceCollectionEnabled = true
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(tailnet.ID, &broker.Signal{})
@@ -317,21 +317,21 @@ func (s *Service) EnableServiceCollection(ctx context.Context, req *connect.Requ
func (s *Service) DisableServiceCollection(ctx context.Context, req *connect.Request[api.DisableServiceCollectionRequest]) (*connect.Response[api.DisableServiceCollectionResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
if tailnet.ServiceCollectionEnabled {
tailnet.ServiceCollectionEnabled = false
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(tailnet.ID, &broker.Signal{})
@@ -343,21 +343,21 @@ func (s *Service) DisableServiceCollection(ctx context.Context, req *connect.Req
func (s *Service) EnableSSH(ctx context.Context, req *connect.Request[api.EnableSSHRequest]) (*connect.Response[api.EnableSSHResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
if !tailnet.SSHEnabled {
tailnet.SSHEnabled = true
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(tailnet.ID, &broker.Signal{})
@@ -369,21 +369,21 @@ func (s *Service) EnableSSH(ctx context.Context, req *connect.Request[api.Enable
func (s *Service) DisableSSH(ctx context.Context, req *connect.Request[api.DisableSSHRequest]) (*connect.Response[api.DisableSSHResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
if tailnet.SSHEnabled {
tailnet.SSHEnabled = false
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(tailnet.ID, &broker.Signal{})
@@ -395,21 +395,21 @@ func (s *Service) DisableSSH(ctx context.Context, req *connect.Request[api.Disab
func (s *Service) EnableMachineAuthorization(ctx context.Context, req *connect.Request[api.EnableMachineAuthorizationRequest]) (*connect.Response[api.EnableMachineAuthorizationResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
if !tailnet.MachineAuthorizationEnabled {
tailnet.MachineAuthorizationEnabled = true
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
}
@@ -419,21 +419,21 @@ func (s *Service) EnableMachineAuthorization(ctx context.Context, req *connect.R
func (s *Service) DisableMachineAuthorization(ctx context.Context, req *connect.Request[api.DisableMachineAuthorizationRequest]) (*connect.Response[api.DisableMachineAuthorizationResponse], error) {
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(req.Msg.TailnetId) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
if tailnet.MachineAuthorizationEnabled {
tailnet.MachineAuthorizationEnabled = false
if err := s.repository.SaveTailnet(ctx, tailnet); err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
}
+11 -10
View File
@@ -2,10 +2,11 @@ package service
import (
"context"
"errors"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/broker"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/errors"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
)
@@ -14,20 +15,20 @@ func (s *Service) ListUsers(ctx context.Context, req *connect.Request[api.ListUs
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if tailnet == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet not found"))
}
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(tailnet.ID) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
users, err := s.repository.ListUsers(ctx, tailnet.ID)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
resp := &api.ListUsersResponse{}
@@ -46,20 +47,20 @@ func (s *Service) DeleteUser(ctx context.Context, req *connect.Request[api.Delet
principal := CurrentPrincipal(ctx)
if !principal.IsSystemAdmin() && principal.UserMatches(req.Msg.UserId) {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("unable delete yourself"))
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("unable delete yourself"))
}
user, err := s.repository.GetUser(ctx, req.Msg.UserId)
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
if user == nil {
return nil, connect.NewError(connect.CodeNotFound, errors.New("user not found"))
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("user not found"))
}
if !principal.IsSystemAdmin() && !principal.IsTailnetAdmin(user.TailnetID) {
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("permission denied"))
}
err = s.repository.Transaction(func(tx domain.Repository) error {
@@ -83,7 +84,7 @@ func (s *Service) DeleteUser(ctx context.Context, req *connect.Request[api.Delet
})
if err != nil {
return nil, err
return nil, errors.Wrap(err, 0)
}
s.pubsub.Publish(user.TailnetID, &broker.Signal{})