feat: replace grpc with buf connect

This commit is contained in:
Johan Siebens
2022-06-03 09:37:34 +02:00
parent 687fcd16d1
commit da71a43990
58 changed files with 2217 additions and 3138 deletions
+5 -4
View File
@@ -4,8 +4,9 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/pkg/gen/api"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"github.com/muesli/coral"
"gopkg.in/yaml.v2"
"io/ioutil"
@@ -40,14 +41,14 @@ func getACLConfig() *coral.Command {
return err
}
resp, err := client.GetACLPolicy(context.Background(), &api.GetACLPolicyRequest{TailnetId: tailnet.Id})
resp, err := client.GetACLPolicy(context.Background(), connect.NewRequest(&api.GetACLPolicyRequest{TailnetId: tailnet.Id}))
if err != nil {
return err
}
var p domain.ACLPolicy
if err := json.Unmarshal(resp.Value, &p); err != nil {
if err := json.Unmarshal(resp.Msg.Value, &p); err != nil {
return err
}
@@ -109,7 +110,7 @@ func setACLConfig() *coral.Command {
return err
}
_, err = client.SetACLPolicy(context.Background(), &api.SetACLPolicyRequest{TailnetId: tailnet.Id, Value: rawJson})
_, err = client.SetACLPolicy(context.Background(), connect.NewRequest(&api.SetACLPolicyRequest{TailnetId: tailnet.Id, Value: rawJson}))
if err != nil {
return err
}
+9 -8
View File
@@ -3,8 +3,9 @@ package cmd
import (
"context"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/hashicorp/go-bexpr"
"github.com/jsiebens/ionscale/pkg/gen/api"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"github.com/muesli/coral"
"github.com/rodaine/table"
)
@@ -51,14 +52,14 @@ func listAuthFilterCommand() *coral.Command {
req.AuthMethodId = &authMethodID
}
resp, err := client.ListAuthFilters(context.Background(), req)
resp, err := client.ListAuthFilters(context.Background(), connect.NewRequest(req))
if err != nil {
return err
}
tbl := table.New("ID", "AUTH_METHOD", "TAILNET", "EXPR")
for _, filter := range resp.AuthFilters {
for _, filter := range resp.Msg.AuthFilters {
if filter.Tailnet != nil {
tbl.AddRow(filter.Id, filter.AuthMethod.Name, filter.Tailnet.Name, filter.Expr)
} else {
@@ -123,17 +124,17 @@ func createAuthFilterCommand() *coral.Command {
Expr: expr,
}
resp, err := client.CreateAuthFilter(context.Background(), req)
resp, err := client.CreateAuthFilter(context.Background(), connect.NewRequest(req))
if err != nil {
return err
}
tbl := table.New("ID", "AUTH_METHOD", "TAILNET", "EXPR")
if resp.AuthFilter.Tailnet != nil {
tbl.AddRow(resp.AuthFilter.Id, resp.AuthFilter.AuthMethod.Name, resp.AuthFilter.Tailnet.Name, resp.AuthFilter.Expr)
if resp.Msg.AuthFilter.Tailnet != nil {
tbl.AddRow(resp.Msg.AuthFilter.Id, resp.Msg.AuthFilter.AuthMethod.Name, resp.Msg.AuthFilter.Tailnet.Name, resp.Msg.AuthFilter.Expr)
} else {
tbl.AddRow(resp.AuthFilter.Id, resp.AuthFilter.AuthMethod.Name, "", resp.AuthFilter.Expr)
tbl.AddRow(resp.Msg.AuthFilter.Id, resp.Msg.AuthFilter.AuthMethod.Name, "", resp.Msg.AuthFilter.Expr)
}
tbl.Print()
@@ -167,7 +168,7 @@ func deleteAuthFilterCommand() *coral.Command {
AuthFilterId: authFilterID,
}
_, err = client.DeleteAuthFilter(context.Background(), req)
_, err = client.DeleteAuthFilter(context.Background(), connect.NewRequest(req))
if err != nil {
return err
+7 -6
View File
@@ -3,7 +3,8 @@ package cmd
import (
"context"
"fmt"
"github.com/jsiebens/ionscale/pkg/gen/api"
"github.com/bufbuild/connect-go"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"github.com/muesli/coral"
"github.com/rodaine/table"
str2dur "github.com/xhit/go-str2duration/v2"
@@ -72,7 +73,7 @@ func createAuthkeysCommand() *coral.Command {
Tags: tags,
Expiry: expiryDur,
}
resp, err := client.CreateAuthKey(context.Background(), req)
resp, err := client.CreateAuthKey(context.Background(), connect.NewRequest(req))
if err != nil {
return err
@@ -82,7 +83,7 @@ func createAuthkeysCommand() *coral.Command {
fmt.Println("Generated new auth key")
fmt.Println("Be sure to copy your new key below. It won't be shown in full again.")
fmt.Println("")
fmt.Printf(" %s\n", resp.Value)
fmt.Printf(" %s\n", resp.Msg.Value)
fmt.Println("")
return nil
@@ -110,7 +111,7 @@ func deleteAuthKeyCommand() *coral.Command {
defer safeClose(c)
req := api.DeleteAuthKeyRequest{AuthKeyId: authKeyId}
if _, err := grpcClient.DeleteAuthKey(context.Background(), &req); err != nil {
if _, err := grpcClient.DeleteAuthKey(context.Background(), connect.NewRequest(&req)); err != nil {
return err
}
@@ -149,13 +150,13 @@ func listAuthkeysCommand() *coral.Command {
}
req := &api.ListAuthKeysRequest{TailnetId: tailnet.Id}
resp, err := client.ListAuthKeys(context.Background(), req)
resp, err := client.ListAuthKeys(context.Background(), connect.NewRequest(req))
if err != nil {
return err
}
printAuthKeyTable(resp.AuthKeys...)
printAuthKeyTable(resp.Msg.AuthKeys...)
return nil
}
+6 -5
View File
@@ -2,7 +2,8 @@ package cmd
import (
"context"
"github.com/jsiebens/ionscale/pkg/gen/api"
"github.com/bufbuild/connect-go"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"github.com/muesli/coral"
"github.com/rodaine/table"
)
@@ -39,14 +40,14 @@ func listAuthMethods() *coral.Command {
}
defer safeClose(c)
resp, err := client.ListAuthMethods(context.Background(), &api.ListAuthMethodsRequest{})
resp, err := client.ListAuthMethods(context.Background(), connect.NewRequest(&api.ListAuthMethodsRequest{}))
if err != nil {
return err
}
tbl := table.New("ID", "NAME", "TYPE")
for _, m := range resp.AuthMethods {
for _, m := range resp.Msg.AuthMethods {
tbl.AddRow(m.Id, m.Name, m.Type)
}
tbl.Print()
@@ -110,14 +111,14 @@ func createOIDCAuthMethodCommand() *coral.Command {
ClientSecret: clientSecret,
}
resp, err := client.CreateAuthMethod(context.Background(), req)
resp, err := client.CreateAuthMethod(context.Background(), connect.NewRequest(req))
if err != nil {
return err
}
tbl := table.New("ID", "NAME", "TYPE")
tbl.AddRow(resp.AuthMethod.Id, resp.AuthMethod.Name, resp.AuthMethod.Type)
tbl.AddRow(resp.Msg.AuthMethod.Id, resp.Msg.AuthMethod.Name, resp.Msg.AuthMethod.Type)
tbl.Print()
return nil
+6 -5
View File
@@ -4,7 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/jsiebens/ionscale/pkg/gen/api"
"github.com/bufbuild/connect-go"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"github.com/muesli/coral"
"gopkg.in/yaml.v2"
"io/ioutil"
@@ -43,7 +44,7 @@ func getDERPMap() *coral.Command {
}
defer safeClose(c)
resp, err := client.GetDERPMap(context.Background(), &api.GetDERPMapRequest{})
resp, err := client.GetDERPMap(context.Background(), connect.NewRequest(&api.GetDERPMapRequest{}))
if err != nil {
return err
@@ -53,7 +54,7 @@ func getDERPMap() *coral.Command {
Regions map[int]*tailcfg.DERPRegion
}
if err := json.Unmarshal(resp.Value, &derpMap); err != nil {
if err := json.Unmarshal(resp.Msg.Value, &derpMap); err != nil {
return err
}
@@ -105,13 +106,13 @@ func setDERPMap() *coral.Command {
return err
}
resp, err := grpcClient.SetDERPMap(context.Background(), &api.SetDERPMapRequest{Value: rawJson})
resp, err := grpcClient.SetDERPMap(context.Background(), connect.NewRequest(&api.SetDERPMapRequest{Value: rawJson}))
if err != nil {
return err
}
var derpMap tailcfg.DERPMap
if err := json.Unmarshal(resp.Value, &derpMap); err != nil {
if err := json.Unmarshal(resp.Msg.Value, &derpMap); err != nil {
return err
}
+6 -5
View File
@@ -3,7 +3,8 @@ package cmd
import (
"context"
"fmt"
"github.com/jsiebens/ionscale/pkg/gen/api"
"github.com/bufbuild/connect-go"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"github.com/muesli/coral"
"strings"
)
@@ -36,12 +37,12 @@ func getDNSConfig() *coral.Command {
}
req := api.GetDNSConfigRequest{TailnetId: tailnet.Id}
resp, err := client.GetDNSConfig(context.Background(), &req)
resp, err := client.GetDNSConfig(context.Background(), connect.NewRequest(&req))
if err != nil {
return err
}
config := resp.Config
config := resp.Msg.Config
var allNameservers = config.Nameservers
@@ -120,13 +121,13 @@ func setDNSConfig() *coral.Command {
Routes: routes,
},
}
resp, err := client.SetDNSConfig(context.Background(), &req)
resp, err := client.SetDNSConfig(context.Background(), connect.NewRequest(&req))
if err != nil {
return err
}
config := resp.Config
config := resp.Msg.Config
var allNameservers = config.Nameservers
+9 -7
View File
@@ -3,21 +3,23 @@ package cmd
import (
"context"
"fmt"
"github.com/jsiebens/ionscale/pkg/gen/api"
"github.com/bufbuild/connect-go"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
apiconnect "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1/ionscalev1connect"
"io"
)
func findTailnet(client api.IonscaleClient, tailnet string, tailnetID uint64) (*api.Tailnet, error) {
func findTailnet(client apiconnect.IonscaleServiceClient, tailnet string, tailnetID uint64) (*api.Tailnet, error) {
if tailnetID == 0 && tailnet == "" {
return nil, fmt.Errorf("requested tailnet not found or you are not authorized for this tailnet")
}
tailnets, err := client.ListTailnets(context.Background(), &api.ListTailnetRequest{})
tailnets, err := client.ListTailnets(context.Background(), connect.NewRequest(&api.ListTailnetRequest{}))
if err != nil {
return nil, err
}
for _, t := range tailnets.Tailnet {
for _, t := range tailnets.Msg.Tailnet {
if t.Id == tailnetID || t.Name == tailnet {
return t, nil
}
@@ -26,17 +28,17 @@ func findTailnet(client api.IonscaleClient, tailnet string, tailnetID uint64) (*
return nil, fmt.Errorf("requested tailnet not found or you are not authorized for this tailnet")
}
func findAuthMethod(client api.IonscaleClient, authMethod string, authMethodID uint64) (*api.AuthMethod, error) {
func findAuthMethod(client apiconnect.IonscaleServiceClient, authMethod string, authMethodID uint64) (*api.AuthMethod, error) {
if authMethodID == 0 && authMethod == "" {
return nil, fmt.Errorf("requested auth method not found or you are not authorized for this tailnet")
}
resp, err := client.ListAuthMethods(context.Background(), &api.ListAuthMethodsRequest{})
resp, err := client.ListAuthMethods(context.Background(), connect.NewRequest(&api.ListAuthMethodsRequest{}))
if err != nil {
return nil, err
}
for _, t := range resp.AuthMethods {
for _, t := range resp.Msg.AuthMethods {
if t.Id == authMethodID || t.Name == authMethod {
return t, nil
}
+10 -9
View File
@@ -3,7 +3,8 @@ package cmd
import (
"context"
"fmt"
"github.com/jsiebens/ionscale/pkg/gen/api"
"github.com/bufbuild/connect-go"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"github.com/muesli/coral"
"github.com/nleeper/goment"
"github.com/rodaine/table"
@@ -47,7 +48,7 @@ func deleteMachineCommand() *coral.Command {
defer safeClose(c)
req := api.DeleteMachineRequest{MachineId: machineID}
if _, err := client.DeleteMachine(context.Background(), &req); err != nil {
if _, err := client.DeleteMachine(context.Background(), connect.NewRequest(&req)); err != nil {
return err
}
@@ -79,7 +80,7 @@ func expireMachineCommand() *coral.Command {
defer safeClose(c)
req := api.ExpireMachineRequest{MachineId: machineID}
if _, err := client.ExpireMachine(context.Background(), &req); err != nil {
if _, err := client.ExpireMachine(context.Background(), connect.NewRequest(&req)); err != nil {
return err
}
@@ -119,14 +120,14 @@ func listMachinesCommand() *coral.Command {
}
req := api.ListMachinesRequest{TailnetId: tailnet.Id}
resp, err := client.ListMachines(context.Background(), &req)
resp, err := client.ListMachines(context.Background(), connect.NewRequest(&req))
if err != nil {
return err
}
tbl := table.New("ID", "TAILNET", "NAME", "IPv4", "IPv6", "EPHEMERAL", "LAST_SEEN", "TAGS")
for _, m := range resp.Machines {
for _, m := range resp.Msg.Machines {
var lastSeen = "N/A"
if m.Connected {
lastSeen = "Connected"
@@ -166,13 +167,13 @@ func getMachineRoutesCommand() *coral.Command {
defer safeClose(c)
req := api.GetMachineRoutesRequest{MachineId: machineID}
resp, err := grpcClient.GetMachineRoutes(context.Background(), &req)
resp, err := grpcClient.GetMachineRoutes(context.Background(), connect.NewRequest(&req))
if err != nil {
return err
}
tbl := table.New("ROUTE", "ALLOWED")
for _, r := range resp.Routes {
for _, r := range resp.Msg.Routes {
tbl.AddRow(r.Advertised, r.Allowed)
}
tbl.Print()
@@ -214,13 +215,13 @@ func setMachineRoutesCommand() *coral.Command {
}
req := api.SetMachineRoutesRequest{MachineId: machineID, AllowedIps: allowedIps}
resp, err := client.SetMachineRoutes(context.Background(), &req)
resp, err := client.SetMachineRoutes(context.Background(), connect.NewRequest(&req))
if err != nil {
return err
}
tbl := table.New("ROUTE", "ALLOWED")
for _, r := range resp.Routes {
for _, r := range resp.Msg.Routes {
tbl.AddRow(r.Advertised, r.Allowed)
}
tbl.Print()
+7 -6
View File
@@ -3,7 +3,8 @@ package cmd
import (
"context"
"fmt"
"github.com/jsiebens/ionscale/pkg/gen/api"
"github.com/bufbuild/connect-go"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"github.com/muesli/coral"
"github.com/rodaine/table"
)
@@ -45,14 +46,14 @@ func listTailnetsCommand() *coral.Command {
}
defer safeClose(c)
resp, err := client.ListTailnets(context.Background(), &api.ListTailnetRequest{})
resp, err := client.ListTailnets(context.Background(), connect.NewRequest(&api.ListTailnetRequest{}))
if err != nil {
return err
}
tbl := table.New("ID", "NAME")
for _, tailnet := range resp.Tailnet {
for _, tailnet := range resp.Msg.Tailnet {
tbl.AddRow(tailnet.Id, tailnet.Name)
}
tbl.Print()
@@ -85,14 +86,14 @@ func createTailnetsCommand() *coral.Command {
}
defer safeClose(c)
resp, err := client.CreateTailnet(context.Background(), &api.CreateTailnetRequest{Name: name})
resp, err := client.CreateTailnet(context.Background(), connect.NewRequest(&api.CreateTailnetRequest{Name: name}))
if err != nil {
return err
}
tbl := table.New("ID", "NAME")
tbl.AddRow(resp.Tailnet.Id, resp.Tailnet.Name)
tbl.AddRow(resp.Msg.Tailnet.Id, resp.Msg.Tailnet.Name)
tbl.Print()
return nil
@@ -131,7 +132,7 @@ func deleteTailnetCommand() *coral.Command {
return err
}
_, err = client.DeleteTailnet(context.Background(), &api.DeleteTailnetRequest{TailnetId: tailnet.Id, Force: force})
_, err = client.DeleteTailnet(context.Background(), connect.NewRequest(&api.DeleteTailnetRequest{TailnetId: tailnet.Id, Force: force}))
if err != nil {
return err
+3 -14
View File
@@ -3,7 +3,7 @@ package cmd
import (
"github.com/jsiebens/ionscale/internal/config"
"github.com/jsiebens/ionscale/pkg/client/ionscale"
"github.com/jsiebens/ionscale/pkg/gen/api"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1/ionscalev1connect"
"github.com/muesli/coral"
"io"
)
@@ -12,12 +12,10 @@ const (
ionscaleSystemAdminKey = "IONSCALE_ADMIN_KEY"
ionscaleAddr = "IONSCALE_ADDR"
ionscaleInsecureSkipVerify = "IONSCALE_SKIP_VERIFY"
ionscaleUseGrpcWeb = "IONSCALE_GRPC_WEB"
)
type Target struct {
addr string
useGrpcWeb bool
insecureSkipVerify bool
systemAdminKey string
}
@@ -25,13 +23,11 @@ type Target struct {
func (t *Target) prepareCommand(cmd *coral.Command) {
cmd.Flags().StringVar(&t.addr, "addr", "", "Addr of the ionscale server, as a complete URL")
cmd.Flags().BoolVar(&t.insecureSkipVerify, "tls-skip-verify", false, "Disable verification of TLS certificates")
cmd.Flags().BoolVar(&t.useGrpcWeb, "grpc-web", false, "Enables gRPC-web protocol. Useful if ionscale server is behind proxy which does not support GRPC")
cmd.Flags().StringVar(&t.systemAdminKey, "admin-key", "", "If specified, the given value will be used as the key to generate a Bearer token for the call. This can also be specified via the IONSCALE_ADMIN_KEY environment variable.")
}
func (t *Target) createGRPCClient() (api.IonscaleClient, io.Closer, error) {
func (t *Target) createGRPCClient() (api.IonscaleServiceClient, io.Closer, error) {
addr := t.getAddr()
useGrpcWeb := t.getUseGrpcWeb()
skipVerify := t.getInsecureSkipVerify()
systemAdminKey := t.getSystemAdminKey()
@@ -40,7 +36,7 @@ func (t *Target) createGRPCClient() (api.IonscaleClient, io.Closer, error) {
return nil, nil, err
}
return ionscale.NewClient(auth, addr, skipVerify, useGrpcWeb)
return ionscale.NewClient(auth, addr, skipVerify)
}
func (t *Target) getAddr() string {
@@ -57,13 +53,6 @@ func (t *Target) getInsecureSkipVerify() bool {
return config.GetBool(ionscaleInsecureSkipVerify, false)
}
func (t *Target) getUseGrpcWeb() bool {
if t.useGrpcWeb {
return true
}
return config.GetBool(ionscaleUseGrpcWeb, false)
}
func (t *Target) getSystemAdminKey() string {
if len(t.systemAdminKey) != 0 {
return t.systemAdminKey
+4 -3
View File
@@ -3,8 +3,9 @@ package cmd
import (
"context"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/version"
"github.com/jsiebens/ionscale/pkg/gen/api"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"github.com/muesli/coral"
)
@@ -36,7 +37,7 @@ Server:
}
defer safeClose(c)
resp, err := client.GetVersion(context.Background(), &api.GetVersionRequest{})
resp, err := client.GetVersion(context.Background(), connect.NewRequest(&api.GetVersionRequest{}))
if err != nil {
fmt.Printf(`
Server:
@@ -50,7 +51,7 @@ Server:
Addr: %s
Version: %s
Git Revision: %s
`, target.getAddr(), resp.Version, resp.Revision)
`, target.getAddr(), resp.Msg.Version, resp.Msg.Revision)
}
-54
View File
@@ -1,54 +0,0 @@
package server
import (
"github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/recovery"
"github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/hashicorp/go-hclog"
"github.com/jsiebens/ionscale/internal/key"
"github.com/jsiebens/ionscale/internal/service"
"google.golang.org/grpc"
)
func init() {
grpc_prometheus.EnableHandlingTimeHistogram()
}
func NewGrpcServer(logger hclog.Logger, systemAdminKey key.ServerPrivate) *grpc.Server {
return grpc.NewServer(
middleware.WithUnaryServerChain(
logging.UnaryServerInterceptor(
&grpcLogger{logger.Named("grpc")},
logging.WithDurationField(logging.DurationToDurationField),
),
grpc_prometheus.UnaryServerInterceptor,
recovery.UnaryServerInterceptor(),
service.UnaryServerTokenAuth(systemAdminKey),
),
)
}
type grpcLogger struct {
log hclog.Logger
}
func (l *grpcLogger) Log(lvl logging.Level, msg string) {
switch lvl {
case logging.ERROR:
l.log.Error(msg)
default:
l.log.Debug(msg)
}
}
func (l *grpcLogger) With(fields ...string) logging.Logger {
if len(fields) == 0 {
return l
}
vals := make([]interface{}, 0, len(fields))
for i := 0; i < len(fields); i++ {
vals = append(vals, fields[i])
}
return &grpcLogger{log: l.log.With(vals...)}
}
+58
View File
@@ -0,0 +1,58 @@
package server
import (
"context"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/key"
"github.com/jsiebens/ionscale/internal/token"
apiconnect "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1/ionscalev1connect"
"net/http"
"strings"
)
var (
errInvalidToken = connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("invalid token"))
)
func NewRpcHandler(systemAdminKey key.ServerPrivate, handler apiconnect.IonscaleServiceHandler) (string, http.Handler) {
interceptors := connect.WithInterceptors(authenticationInterceptor(systemAdminKey))
return apiconnect.NewIonscaleServiceHandler(handler, interceptors)
}
func authenticationInterceptor(systemAdminKey key.ServerPrivate) connect.UnaryInterceptorFunc {
return func(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
name := req.Spec().Procedure
if strings.HasSuffix(name, "/GetVersion") {
return next(ctx, req)
}
authorizationHeader := req.Header().Get("Authorization")
valid := validateAuthorizationToken(systemAdminKey, authorizationHeader)
if valid {
return next(ctx, req)
}
return nil, errInvalidToken
}
}
}
func validateAuthorizationToken(systemAdminKey key.ServerPrivate, authorization string) bool {
if len(authorization) == 0 {
return false
}
bearerToken := strings.TrimPrefix(authorization, "Bearer ")
if token.IsSystemAdminToken(bearerToken) {
_, err := token.ParseSystemAdminToken(systemAdminKey, bearerToken)
return err == nil
}
return false
}
+5 -19
View File
@@ -6,7 +6,6 @@ import (
"fmt"
"github.com/caddyserver/certmagic"
"github.com/hashicorp/go-hclog"
"github.com/improbable-eng/grpc-web/go/grpcweb"
"github.com/jsiebens/ionscale/internal/bind"
"github.com/jsiebens/ionscale/internal/broker"
"github.com/jsiebens/ionscale/internal/config"
@@ -14,10 +13,8 @@ import (
"github.com/jsiebens/ionscale/internal/handlers"
"github.com/jsiebens/ionscale/internal/service"
"github.com/jsiebens/ionscale/internal/templates"
"github.com/jsiebens/ionscale/pkg/gen/api"
echo_prometheus "github.com/labstack/echo-contrib/prometheus"
"github.com/labstack/echo/v4"
"github.com/soheilhy/cmux"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"golang.org/x/sync/errgroup"
@@ -98,6 +95,9 @@ func Start(config *config.Config) error {
repository,
)
rpcService := service.NewService(repository, brokers)
rpcPath, rpcHandler := NewRpcHandler(serverKey.SystemAdminKey, rpcService)
p := echo_prometheus.NewPrometheus("http", nil)
metricsHandler := echo.New()
@@ -118,6 +118,7 @@ func Start(config *config.Config) error {
tlsAppHandler.Any("/*", handlers.IndexHandler(http.StatusNotFound))
tlsAppHandler.Any("/", handlers.IndexHandler(http.StatusOK))
tlsAppHandler.POST(rpcPath+"*", echo.WrapHandler(rpcHandler))
tlsAppHandler.GET("/version", handlers.Version)
tlsAppHandler.GET("/key", handlers.KeyHandler(controlKeys))
tlsAppHandler.POST("/ts2021", noiseHandlers.Upgrade)
@@ -132,10 +133,6 @@ func Start(config *config.Config) error {
auth.GET("/success", authenticationHandlers.Success)
auth.GET("/error", authenticationHandlers.Error)
grpcService := service.NewService(repository, brokers)
grpcServer := NewGrpcServer(logger, serverKey.SystemAdminKey)
api.RegisterIonscaleServer(grpcServer, grpcService)
tlsL, err := tlsListener(config)
if err != nil {
return err
@@ -151,23 +148,12 @@ func Start(config *config.Config) error {
return err
}
mux := cmux.New(selectListener(tlsL, nonTlsL))
grpcL := mux.MatchWithWriters(
cmux.HTTP2MatchHeaderFieldPrefixSendSettings("content-type", "application/grpc"),
cmux.HTTP2MatchHeaderFieldPrefixSendSettings("content-type", "application/grpc+proto"),
)
grpcWebL := mux.Match(cmux.HTTP1HeaderFieldPrefix("content-type", "application/grpc-web"))
httpL := mux.Match(cmux.Any())
grpcWebHandler := grpcweb.WrapServer(grpcServer)
httpL := selectListener(tlsL, nonTlsL)
http2Server := &http2.Server{}
g := new(errgroup.Group)
g.Go(func() error { return grpcServer.Serve(grpcL) })
g.Go(func() error { return http.Serve(grpcWebL, h2c.NewHandler(grpcWebHandler, http2Server)) })
g.Go(func() error { return http.Serve(httpL, h2c.NewHandler(tlsAppHandler, http2Server)) })
g.Go(func() error { return http.Serve(metricsL, metricsHandler) })
g.Go(func() error { return mux.Serve() })
if tlsL != nil {
g.Go(func() error { return http.Serve(nonTlsL, nonTlsAppHandler) })
+11 -11
View File
@@ -3,14 +3,14 @@ package service
import (
"context"
"encoding/json"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/pkg/gen/api"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
)
func (s *Service) GetACLPolicy(ctx context.Context, req *api.GetACLPolicyRequest) (*api.GetACLPolicyResponse, error) {
policy, err := s.repository.GetACLPolicy(ctx, req.TailnetId)
func (s *Service) GetACLPolicy(ctx context.Context, req *connect.Request[api.GetACLPolicyRequest]) (*connect.Response[api.GetACLPolicyResponse], error) {
policy, err := s.repository.GetACLPolicy(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
}
@@ -20,20 +20,20 @@ func (s *Service) GetACLPolicy(ctx context.Context, req *api.GetACLPolicyRequest
return nil, err
}
return &api.GetACLPolicyResponse{Value: marshal}, nil
return connect.NewResponse(&api.GetACLPolicyResponse{Value: marshal}), nil
}
func (s *Service) SetACLPolicy(ctx context.Context, req *api.SetACLPolicyRequest) (*api.SetACLPolicyResponse, error) {
tailnet, err := s.repository.GetTailnet(ctx, req.TailnetId)
func (s *Service) SetACLPolicy(ctx context.Context, req *connect.Request[api.SetACLPolicyRequest]) (*connect.Response[api.SetACLPolicyResponse], error) {
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
}
if tailnet == nil {
return nil, status.Error(codes.NotFound, "tailnet does not exist")
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("tailnet does not exist"))
}
var policy domain.ACLPolicy
if err := json.Unmarshal(req.Value, &policy); err != nil {
if err := json.Unmarshal(req.Msg.Value, &policy); err != nil {
return nil, err
}
@@ -43,5 +43,5 @@ func (s *Service) SetACLPolicy(ctx context.Context, req *api.SetACLPolicyRequest
s.brokers(tailnet.ID).SignalACLUpdated()
return &api.SetACLPolicyResponse{}, nil
return connect.NewResponse(&api.SetACLPolicyResponse{}), nil
}
+21 -22
View File
@@ -3,18 +3,17 @@ package service
import (
"context"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/hashicorp/go-bexpr"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/util"
"github.com/jsiebens/ionscale/pkg/gen/api"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
)
func (s *Service) ListAuthFilters(ctx context.Context, req *api.ListAuthFiltersRequest) (*api.ListAuthFiltersResponse, error) {
func (s *Service) ListAuthFilters(ctx context.Context, req *connect.Request[api.ListAuthFiltersRequest]) (*connect.Response[api.ListAuthFiltersResponse], error) {
response := &api.ListAuthFiltersResponse{AuthFilters: []*api.AuthFilter{}}
if req.AuthMethodId == nil {
if req.Msg.AuthMethodId == nil {
filters, err := s.repository.ListAuthFilters(ctx)
if err != nil {
return nil, err
@@ -23,12 +22,12 @@ func (s *Service) ListAuthFilters(ctx context.Context, req *api.ListAuthFiltersR
response.AuthFilters = append(response.AuthFilters, s.mapToApi(&filter.AuthMethod, filter))
}
} else {
authMethod, err := s.repository.GetAuthMethod(ctx, *req.AuthMethodId)
authMethod, err := s.repository.GetAuthMethod(ctx, *req.Msg.AuthMethodId)
if err != nil {
return nil, err
}
if authMethod == nil {
return nil, status.Error(codes.NotFound, "invalid auth method id")
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("invalid auth method id"))
}
filters, err := s.repository.ListAuthFiltersByAuthMethod(ctx, authMethod.ID)
@@ -40,36 +39,36 @@ func (s *Service) ListAuthFilters(ctx context.Context, req *api.ListAuthFiltersR
}
}
return response, nil
return connect.NewResponse[api.ListAuthFiltersResponse](response), nil
}
func (s *Service) CreateAuthFilter(ctx context.Context, req *api.CreateAuthFilterRequest) (*api.CreateAuthFilterResponse, error) {
authMethod, err := s.repository.GetAuthMethod(ctx, req.AuthMethodId)
func (s *Service) CreateAuthFilter(ctx context.Context, req *connect.Request[api.CreateAuthFilterRequest]) (*connect.Response[api.CreateAuthFilterResponse], error) {
authMethod, err := s.repository.GetAuthMethod(ctx, req.Msg.AuthMethodId)
if err != nil {
return nil, err
}
if authMethod == nil {
return nil, status.Error(codes.NotFound, "invalid auth method id")
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("invalid auth method id"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.TailnetId)
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
}
if tailnet == nil {
return nil, status.Error(codes.NotFound, "invalid tailnet id")
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("invalid tailnet id"))
}
if req.Expr != "*" {
if _, err := bexpr.CreateEvaluator(req.Expr); err != nil {
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid expression: %v", err))
if req.Msg.Expr != "*" {
if _, err := bexpr.CreateEvaluator(req.Msg.Expr); err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("invalid expression: %v", err))
}
}
authFilter := &domain.AuthFilter{
ID: util.NextID(),
Expr: req.Expr,
Expr: req.Msg.Expr,
AuthMethod: *authMethod,
Tailnet: tailnet,
}
@@ -80,20 +79,20 @@ func (s *Service) CreateAuthFilter(ctx context.Context, req *api.CreateAuthFilte
response := api.CreateAuthFilterResponse{AuthFilter: s.mapToApi(authMethod, *authFilter)}
return &response, nil
return connect.NewResponse[api.CreateAuthFilterResponse](&response), nil
}
func (s *Service) DeleteAuthFilter(ctx context.Context, req *api.DeleteAuthFilterRequest) (*api.DeleteAuthFilterResponse, error) {
func (s *Service) DeleteAuthFilter(ctx context.Context, req *connect.Request[api.DeleteAuthFilterRequest]) (*connect.Response[api.DeleteAuthFilterResponse], error) {
err := s.repository.Transaction(func(rp domain.Repository) error {
filter, err := rp.GetAuthFilter(ctx, req.AuthFilterId)
filter, err := rp.GetAuthFilter(ctx, req.Msg.AuthFilterId)
if err != nil {
return err
}
if filter == nil {
return status.Error(codes.NotFound, "auth filter not found")
return connect.NewError(connect.CodeNotFound, fmt.Errorf("auth filter not found"))
}
c, err := rp.ExpireMachineByAuthMethod(ctx, *filter.TailnetID, filter.AuthMethodID)
@@ -118,7 +117,7 @@ func (s *Service) DeleteAuthFilter(ctx context.Context, req *api.DeleteAuthFilte
response := api.DeleteAuthFilterResponse{}
return &response, nil
return connect.NewResponse[api.DeleteAuthFilterResponse](&response), nil
}
func (s *Service) mapToApi(authMethod *domain.AuthMethod, filter domain.AuthFilter) *api.AuthFilter {
+26 -26
View File
@@ -2,22 +2,22 @@ package service
import (
"context"
"errors"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/pkg/gen/api"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"google.golang.org/protobuf/types/known/timestamppb"
"time"
)
func (s *Service) GetAuthKey(ctx context.Context, req *api.GetAuthKeyRequest) (*api.GetAuthKeyResponse, error) {
key, err := s.repository.GetAuthKey(ctx, req.AuthKeyId)
func (s *Service) GetAuthKey(ctx context.Context, req *connect.Request[api.GetAuthKeyRequest]) (*connect.Response[api.GetAuthKeyResponse], error) {
key, err := s.repository.GetAuthKey(ctx, req.Msg.AuthKeyId)
if err != nil {
return nil, err
}
if key == nil {
return nil, status.Error(codes.NotFound, "")
return nil, connect.NewError(connect.CodeNotFound, errors.New("auth key not found"))
}
var expiresAt *timestamppb.Timestamp
@@ -25,7 +25,7 @@ func (s *Service) GetAuthKey(ctx context.Context, req *api.GetAuthKeyRequest) (*
expiresAt = timestamppb.New(*key.ExpiresAt)
}
return &api.GetAuthKeyResponse{AuthKey: &api.AuthKey{
return connect.NewResponse(&api.GetAuthKeyResponse{AuthKey: &api.AuthKey{
Id: key.ID,
Key: key.Key,
Ephemeral: key.Ephemeral,
@@ -36,20 +36,20 @@ func (s *Service) GetAuthKey(ctx context.Context, req *api.GetAuthKeyRequest) (*
Id: key.Tailnet.ID,
Name: key.Tailnet.Name,
},
}}, nil
}}), nil
}
func (s *Service) ListAuthKeys(ctx context.Context, req *api.ListAuthKeysRequest) (*api.ListAuthKeysResponse, error) {
tailnet, err := s.repository.GetTailnet(ctx, req.TailnetId)
func (s *Service) ListAuthKeys(ctx context.Context, req *connect.Request[api.ListAuthKeysRequest]) (*connect.Response[api.ListAuthKeysResponse], error) {
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
}
if tailnet == nil {
return nil, status.Error(codes.NotFound, "")
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
}
authKeys, err := s.repository.ListAuthKeys(ctx, req.TailnetId)
authKeys, err := s.repository.ListAuthKeys(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
}
@@ -76,28 +76,28 @@ func (s *Service) ListAuthKeys(ctx context.Context, req *api.ListAuthKeysRequest
})
}
return &response, nil
return connect.NewResponse(&response), nil
}
func (s *Service) CreateAuthKey(ctx context.Context, req *api.CreateAuthKeyRequest) (*api.CreateAuthKeyResponse, error) {
if len(req.Tags) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "at least one tag is required when creating an auth key")
func (s *Service) CreateAuthKey(ctx context.Context, req *connect.Request[api.CreateAuthKeyRequest]) (*connect.Response[api.CreateAuthKeyResponse], error) {
if len(req.Msg.Tags) == 0 {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("at least one tag is required when creating an auth key"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.TailnetId)
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
}
if tailnet == nil {
return nil, status.Error(codes.NotFound, "")
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
}
var expiresAt *time.Time
var expiresAtPb *timestamppb.Timestamp
if req.Expiry != nil {
duration := req.Expiry.AsDuration()
if req.Msg.Expiry != nil {
duration := req.Msg.Expiry.AsDuration()
e := time.Now().UTC().Add(duration)
expiresAt = &e
expiresAtPb = timestamppb.New(*expiresAt)
@@ -108,9 +108,9 @@ func (s *Service) CreateAuthKey(ctx context.Context, req *api.CreateAuthKeyReque
return nil, err
}
tags := domain.SanitizeTags(req.Tags)
tags := domain.SanitizeTags(req.Msg.Tags)
v, authKey := domain.CreateAuthKey(tailnet, user, req.Ephemeral, tags, expiresAt)
v, authKey := domain.CreateAuthKey(tailnet, user, req.Msg.Ephemeral, tags, expiresAt)
if err := s.repository.SaveAuthKey(ctx, authKey); err != nil {
return nil, err
@@ -131,12 +131,12 @@ func (s *Service) CreateAuthKey(ctx context.Context, req *api.CreateAuthKeyReque
},
}}
return &response, nil
return connect.NewResponse(&response), nil
}
func (s *Service) DeleteAuthKey(ctx context.Context, req *api.DeleteAuthKeyRequest) (*api.DeleteAuthKeyResponse, error) {
if _, err := s.repository.DeleteAuthKey(ctx, req.AuthKeyId); err != nil {
func (s *Service) DeleteAuthKey(ctx context.Context, req *connect.Request[api.DeleteAuthKeyRequest]) (*connect.Response[api.DeleteAuthKeyResponse], error) {
if _, err := s.repository.DeleteAuthKey(ctx, req.Msg.AuthKeyId); err != nil {
return nil, err
}
return &api.DeleteAuthKeyResponse{}, nil
return connect.NewResponse(&api.DeleteAuthKeyResponse{}), nil
}
+12 -11
View File
@@ -2,37 +2,38 @@ package service
import (
"context"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/util"
"github.com/jsiebens/ionscale/pkg/gen/api"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
)
func (s *Service) CreateAuthMethod(ctx context.Context, req *api.CreateAuthMethodRequest) (*api.CreateAuthMethodResponse, error) {
func (s *Service) CreateAuthMethod(ctx context.Context, req *connect.Request[api.CreateAuthMethodRequest]) (*connect.Response[api.CreateAuthMethodResponse], error) {
authMethod := &domain.AuthMethod{
ID: util.NextID(),
Name: req.Name,
Type: req.Type,
Issuer: req.Issuer,
ClientId: req.ClientId,
ClientSecret: req.ClientSecret,
Name: req.Msg.Name,
Type: req.Msg.Type,
Issuer: req.Msg.Issuer,
ClientId: req.Msg.ClientId,
ClientSecret: req.Msg.ClientSecret,
}
if err := s.repository.SaveAuthMethod(ctx, authMethod); err != nil {
return nil, err
}
return &api.CreateAuthMethodResponse{AuthMethod: &api.AuthMethod{
return connect.NewResponse(&api.CreateAuthMethodResponse{AuthMethod: &api.AuthMethod{
Id: authMethod.ID,
Type: authMethod.Type,
Name: authMethod.Name,
Issuer: authMethod.Issuer,
ClientId: authMethod.ClientId,
}}, nil
}}), nil
}
func (s *Service) ListAuthMethods(ctx context.Context, _ *api.ListAuthMethodsRequest) (*api.ListAuthMethodsResponse, error) {
func (s *Service) ListAuthMethods(ctx context.Context, _ *connect.Request[api.ListAuthMethodsRequest]) (*connect.Response[api.ListAuthMethodsResponse], error) {
methods, err := s.repository.ListAuthMethods(ctx)
if err != nil {
return nil, err
@@ -47,5 +48,5 @@ func (s *Service) ListAuthMethods(ctx context.Context, _ *api.ListAuthMethodsReq
})
}
return response, nil
return connect.NewResponse(response), nil
}
+7 -6
View File
@@ -3,11 +3,12 @@ package service
import (
"context"
"encoding/json"
"github.com/jsiebens/ionscale/pkg/gen/api"
"github.com/bufbuild/connect-go"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"tailscale.com/tailcfg"
)
func (s *Service) GetDERPMap(ctx context.Context, req *api.GetDERPMapRequest) (*api.GetDERPMapResponse, error) {
func (s *Service) GetDERPMap(ctx context.Context, req *connect.Request[api.GetDERPMapRequest]) (*connect.Response[api.GetDERPMapResponse], error) {
derpMap, err := s.repository.GetDERPMap(ctx)
if err != nil {
return nil, err
@@ -18,12 +19,12 @@ func (s *Service) GetDERPMap(ctx context.Context, req *api.GetDERPMapRequest) (*
return nil, err
}
return &api.GetDERPMapResponse{Value: raw}, nil
return connect.NewResponse(&api.GetDERPMapResponse{Value: raw}), nil
}
func (s *Service) SetDERPMap(ctx context.Context, req *api.SetDERPMapRequest) (*api.SetDERPMapResponse, error) {
func (s *Service) SetDERPMap(ctx context.Context, req *connect.Request[api.SetDERPMapRequest]) (*connect.Response[api.SetDERPMapResponse], error) {
var derpMap tailcfg.DERPMap
err := json.Unmarshal(req.Value, &derpMap)
err := json.Unmarshal(req.Msg.Value, &derpMap)
if err != nil {
return nil, err
}
@@ -34,5 +35,5 @@ func (s *Service) SetDERPMap(ctx context.Context, req *api.SetDERPMapRequest) (*
s.brokerPool.SignalDERPMapUpdated(&derpMap)
return &api.SetDERPMapResponse{Value: req.Value}, nil
return connect.NewResponse(&api.SetDERPMapResponse{Value: req.Msg.Value}), nil
}
+13 -13
View File
@@ -2,19 +2,19 @@ package service
import (
"context"
"errors"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/pkg/gen/api"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
)
func (s *Service) GetDNSConfig(ctx context.Context, req *api.GetDNSConfigRequest) (*api.GetDNSConfigResponse, error) {
tailnet, err := s.repository.GetTailnet(ctx, req.TailnetId)
func (s *Service) GetDNSConfig(ctx context.Context, req *connect.Request[api.GetDNSConfigRequest]) (*connect.Response[api.GetDNSConfigResponse], error) {
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
}
if tailnet == nil {
return nil, status.Error(codes.NotFound, "tailnet does not exist")
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
}
config, err := s.repository.GetDNSConfig(ctx, tailnet.ID)
@@ -31,22 +31,22 @@ func (s *Service) GetDNSConfig(ctx context.Context, req *api.GetDNSConfigRequest
},
}
return resp, nil
return connect.NewResponse(resp), nil
}
func (s *Service) SetDNSConfig(ctx context.Context, req *api.SetDNSConfigRequest) (*api.SetDNSConfigResponse, error) {
dnsConfig := req.Config
func (s *Service) SetDNSConfig(ctx context.Context, req *connect.Request[api.SetDNSConfigRequest]) (*connect.Response[api.SetDNSConfigResponse], error) {
dnsConfig := req.Msg.Config
if dnsConfig.MagicDns && len(dnsConfig.Nameservers) == 0 {
return nil, status.Error(codes.InvalidArgument, "at least one global nameserver is required when enabling magic dns")
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("at least one global nameserver is required when enabling magic dns"))
}
tailnet, err := s.repository.GetTailnet(ctx, req.TailnetId)
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
}
if tailnet == nil {
return nil, status.Error(codes.NotFound, "tailnet does not exist")
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
}
config := domain.DNSConfig{
@@ -64,7 +64,7 @@ func (s *Service) SetDNSConfig(ctx context.Context, req *api.SetDNSConfigRequest
resp := &api.SetDNSConfigResponse{Config: dnsConfig}
return resp, nil
return connect.NewResponse(resp), nil
}
func domainRoutesToApiRoutes(routes map[string][]string) map[string]*api.Routes {
+25 -33
View File
@@ -2,22 +2,22 @@ package service
import (
"context"
"errors"
"fmt"
"github.com/jsiebens/ionscale/pkg/gen/api"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/bufbuild/connect-go"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
"google.golang.org/protobuf/types/known/timestamppb"
"inet.af/netaddr"
"time"
)
func (s *Service) ListMachines(ctx context.Context, req *api.ListMachinesRequest) (*api.ListMachinesResponse, error) {
tailnet, err := s.repository.GetTailnet(ctx, req.TailnetId)
func (s *Service) ListMachines(ctx context.Context, req *connect.Request[api.ListMachinesRequest]) (*connect.Response[api.ListMachinesResponse], error) {
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
}
if tailnet == nil {
return nil, status.Error(codes.NotFound, "tailnet does not exist")
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
}
machines, err := s.repository.ListMachineByTailnet(ctx, tailnet.ID)
@@ -56,36 +56,36 @@ func (s *Service) ListMachines(ctx context.Context, req *api.ListMachinesRequest
})
}
return response, nil
return connect.NewResponse(response), nil
}
func (s *Service) DeleteMachine(ctx context.Context, req *api.DeleteMachineRequest) (*api.DeleteMachineResponse, error) {
m, err := s.repository.GetMachine(ctx, req.MachineId)
func (s *Service) DeleteMachine(ctx context.Context, req *connect.Request[api.DeleteMachineRequest]) (*connect.Response[api.DeleteMachineResponse], error) {
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
if err != nil {
return nil, err
}
if m == nil {
return nil, status.Error(codes.NotFound, "machine does not exist")
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
}
if _, err := s.repository.DeleteMachine(ctx, req.MachineId); err != nil {
if _, err := s.repository.DeleteMachine(ctx, req.Msg.MachineId); err != nil {
return nil, err
}
s.brokers(m.TailnetID).SignalPeersRemoved([]uint64{m.ID})
return &api.DeleteMachineResponse{}, nil
return connect.NewResponse(&api.DeleteMachineResponse{}), nil
}
func (s *Service) ExpireMachine(ctx context.Context, req *api.ExpireMachineRequest) (*api.ExpireMachineResponse, error) {
m, err := s.repository.GetMachine(ctx, req.MachineId)
func (s *Service) ExpireMachine(ctx context.Context, req *connect.Request[api.ExpireMachineRequest]) (*connect.Response[api.ExpireMachineResponse], error) {
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
if err != nil {
return nil, err
}
if m == nil {
return nil, status.Error(codes.NotFound, "machine does not exist")
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
}
timestamp := time.Unix(123, 0)
@@ -97,18 +97,18 @@ func (s *Service) ExpireMachine(ctx context.Context, req *api.ExpireMachineReque
s.brokers(m.TailnetID).SignalPeerUpdated(m.ID)
return &api.ExpireMachineResponse{}, nil
return connect.NewResponse(&api.ExpireMachineResponse{}), nil
}
func (s *Service) GetMachineRoutes(ctx context.Context, req *api.GetMachineRoutesRequest) (*api.GetMachineRoutesResponse, error) {
func (s *Service) GetMachineRoutes(ctx context.Context, req *connect.Request[api.GetMachineRoutesRequest]) (*connect.Response[api.GetMachineRoutesResponse], error) {
m, err := s.repository.GetMachine(ctx, req.MachineId)
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
if err != nil {
return nil, err
}
if m == nil {
return nil, status.Error(codes.NotFound, "machine does not exist")
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
}
var routes []*api.RoutableIP
@@ -123,21 +123,21 @@ func (s *Service) GetMachineRoutes(ctx context.Context, req *api.GetMachineRoute
Routes: routes,
}
return &response, nil
return connect.NewResponse(&response), nil
}
func (s *Service) SetMachineRoutes(ctx context.Context, req *api.SetMachineRoutesRequest) (*api.GetMachineRoutesResponse, error) {
m, err := s.repository.GetMachine(ctx, req.MachineId)
func (s *Service) SetMachineRoutes(ctx context.Context, req *connect.Request[api.SetMachineRoutesRequest]) (*connect.Response[api.GetMachineRoutesResponse], error) {
m, err := s.repository.GetMachine(ctx, req.Msg.MachineId)
if err != nil {
return nil, err
}
if m == nil {
return nil, status.Error(codes.NotFound, "machine does not exist")
return nil, connect.NewError(connect.CodeNotFound, errors.New("machine not found"))
}
var allowedIps []netaddr.IPPrefix
for _, r := range req.AllowedIps {
for _, r := range req.Msg.AllowedIps {
prefix, err := netaddr.ParseIPPrefix(r)
if err != nil {
return nil, err
@@ -164,13 +164,5 @@ func (s *Service) SetMachineRoutes(ctx context.Context, req *api.SetMachineRoute
Routes: routes,
}
return &response, nil
}
func mapIp(ip []netaddr.IPPrefix) []string {
var x = []string{}
for _, i := range ip {
x = append(x, i.String())
}
return x
return connect.NewResponse(&response), nil
}
+5 -55
View File
@@ -2,22 +2,11 @@ package service
import (
"context"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/broker"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/key"
"github.com/jsiebens/ionscale/internal/token"
"github.com/jsiebens/ionscale/internal/version"
"github.com/jsiebens/ionscale/pkg/gen/api"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"strings"
)
var (
errMissingMetadata = status.Error(codes.InvalidArgument, "missing metadata")
errInvalidToken = status.Error(codes.Unauthenticated, "invalid token")
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
)
func NewService(repository domain.Repository, brokerPool *broker.BrokerPool) *Service {
@@ -36,49 +25,10 @@ func (s *Service) brokers(tailnetID uint64) broker.Broker {
return s.brokerPool.Get(tailnetID)
}
func (s *Service) GetVersion(ctx context.Context, req *api.GetVersionRequest) (*api.GetVersionResponse, error) {
func (s *Service) GetVersion(_ context.Context, _ *connect.Request[api.GetVersionRequest]) (*connect.Response[api.GetVersionResponse], error) {
v, revision := version.GetReleaseInfo()
return &api.GetVersionResponse{
return connect.NewResponse(&api.GetVersionResponse{
Version: v,
Revision: revision,
}, nil
}
func UnaryServerTokenAuth(systemAdminKey key.ServerPrivate) func(context.Context, interface{}, *grpc.UnaryServerInfo, grpc.UnaryHandler) (interface{}, error) {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if strings.HasSuffix(info.FullMethod, "/GetVersion") {
return handler(ctx, req)
}
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, errMissingMetadata
}
// The keys within metadata.MD are normalized to lowercase.
// See: https://godoc.org/google.golang.org/grpc/metadata#New
valid := validateAuthorizationToken(systemAdminKey, md["authorization"])
if valid {
return handler(ctx, req)
}
return nil, errInvalidToken
}
}
func validateAuthorizationToken(systemAdminKey key.ServerPrivate, authorization []string) bool {
if len(authorization) != 1 {
return false
}
bearerToken := strings.TrimPrefix(authorization[0], "Bearer ")
if token.IsSystemAdminToken(bearerToken) {
_, err := token.ParseSystemAdminToken(systemAdminKey, bearerToken)
return err == nil
}
return false
}), nil
}
+27 -27
View File
@@ -2,21 +2,21 @@ package service
import (
"context"
"errors"
"fmt"
"github.com/bufbuild/connect-go"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/pkg/gen/api"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1"
)
func (s *Service) CreateTailnet(ctx context.Context, req *api.CreateTailnetRequest) (*api.CreateTailnetResponse, error) {
tailnet, created, err := s.repository.GetOrCreateTailnet(ctx, req.Name)
func (s *Service) CreateTailnet(ctx context.Context, req *connect.Request[api.CreateTailnetRequest]) (*connect.Response[api.CreateTailnetResponse], error) {
tailnet, created, err := s.repository.GetOrCreateTailnet(ctx, req.Msg.Name)
if err != nil {
return nil, err
}
if !created {
return nil, fmt.Errorf("tailnet already exists")
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("tailnet already exists"))
}
resp := &api.CreateTailnetResponse{Tailnet: &api.Tailnet{
@@ -24,26 +24,26 @@ func (s *Service) CreateTailnet(ctx context.Context, req *api.CreateTailnetReque
Name: tailnet.Name,
}}
return resp, nil
return connect.NewResponse(resp), nil
}
func (s *Service) GetTailnet(ctx context.Context, req *api.GetTailnetRequest) (*api.GetTailnetResponse, error) {
tailnet, err := s.repository.GetTailnet(ctx, req.Id)
func (s *Service) GetTailnet(ctx context.Context, req *connect.Request[api.GetTailnetRequest]) (*connect.Response[api.GetTailnetResponse], error) {
tailnet, err := s.repository.GetTailnet(ctx, req.Msg.Id)
if err != nil {
return nil, err
}
if tailnet == nil {
return nil, status.Error(codes.NotFound, "")
return nil, connect.NewError(connect.CodeNotFound, errors.New("tailnet not found"))
}
return &api.GetTailnetResponse{Tailnet: &api.Tailnet{
return connect.NewResponse(&api.GetTailnetResponse{Tailnet: &api.Tailnet{
Id: tailnet.ID,
Name: tailnet.Name,
}}, nil
}}), nil
}
func (s *Service) ListTailnets(ctx context.Context, _ *api.ListTailnetRequest) (*api.ListTailnetResponse, error) {
func (s *Service) ListTailnets(ctx context.Context, _ *connect.Request[api.ListTailnetRequest]) (*connect.Response[api.ListTailnetResponse], error) {
resp := &api.ListTailnetResponse{}
tailnets, err := s.repository.ListTailnets(ctx)
@@ -54,46 +54,46 @@ func (s *Service) ListTailnets(ctx context.Context, _ *api.ListTailnetRequest) (
gt := api.Tailnet{Id: t.ID, Name: t.Name}
resp.Tailnet = append(resp.Tailnet, &gt)
}
return resp, nil
return connect.NewResponse(resp), nil
}
func (s *Service) DeleteTailnet(ctx context.Context, req *api.DeleteTailnetRequest) (*api.DeleteTailnetResponse, error) {
func (s *Service) DeleteTailnet(ctx context.Context, req *connect.Request[api.DeleteTailnetRequest]) (*connect.Response[api.DeleteTailnetResponse], error) {
count, err := s.repository.CountMachineByTailnet(ctx, req.TailnetId)
count, err := s.repository.CountMachineByTailnet(ctx, req.Msg.TailnetId)
if err != nil {
return nil, err
}
if !req.Force && count > 0 {
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("tailnet is not empty, number of machines: %d", count))
if !req.Msg.Force && count > 0 {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("tailnet is not empty, number of machines: %d", count))
}
err = s.repository.Transaction(func(tx domain.Repository) error {
if err := tx.DeleteMachineByTailnet(ctx, req.TailnetId); err != nil {
if err := tx.DeleteMachineByTailnet(ctx, req.Msg.TailnetId); err != nil {
return err
}
if err := tx.DeleteAuthKeysByTailnet(ctx, req.TailnetId); err != nil {
if err := tx.DeleteAuthKeysByTailnet(ctx, req.Msg.TailnetId); err != nil {
return err
}
if err := tx.DeleteUsersByTailnet(ctx, req.TailnetId); err != nil {
if err := tx.DeleteUsersByTailnet(ctx, req.Msg.TailnetId); err != nil {
return err
}
if err := tx.DeleteAuthFiltersByTailnet(ctx, req.TailnetId); err != nil {
if err := tx.DeleteAuthFiltersByTailnet(ctx, req.Msg.TailnetId); err != nil {
return err
}
if err := tx.DeleteACLPolicy(ctx, req.TailnetId); err != nil {
if err := tx.DeleteACLPolicy(ctx, req.Msg.TailnetId); err != nil {
return err
}
if err := tx.DeleteDNSConfig(ctx, req.TailnetId); err != nil {
if err := tx.DeleteDNSConfig(ctx, req.Msg.TailnetId); err != nil {
return err
}
if err := tx.DeleteTailnet(ctx, req.TailnetId); err != nil {
if err := tx.DeleteTailnet(ctx, req.Msg.TailnetId); err != nil {
return err
}
@@ -104,7 +104,7 @@ func (s *Service) DeleteTailnet(ctx context.Context, req *api.DeleteTailnetReque
return nil, err
}
s.brokers(req.TailnetId).SignalUpdate()
s.brokers(req.Msg.TailnetId).SignalUpdate()
return &api.DeleteTailnetResponse{}, nil
return connect.NewResponse(&api.DeleteTailnetResponse{}), nil
}