From b019cc6c1fea7156e6f4a1be21ada7ad03c0c97f Mon Sep 17 00:00:00 2001 From: Rodney Osodo <28790446+rodneyosodo@users.noreply.github.com> Date: Tue, 14 May 2024 13:57:41 +0300 Subject: [PATCH] feat(auth): add policy caching Cache policies when the first authorization request happens and delete them when removing policies Signed-off-by: Rodney Osodo <28790446+rodneyosodo@users.noreply.github.com> --- auth/README.md | 70 ++++---- auth/cache/doc.go | 6 + auth/cache/policies.go | 87 +++++++++ auth/cache/policies_test.go | 345 ++++++++++++++++++++++++++++++++++++ auth/cache/setup_test.go | 75 ++++++++ auth/policies.go | 41 +++++ auth/spicedb/policies.go | 25 ++- cmd/auth/main.go | 21 ++- docker/.env | 2 + docker/docker-compose.yml | 14 ++ 10 files changed, 648 insertions(+), 38 deletions(-) create mode 100644 auth/cache/doc.go create mode 100644 auth/cache/policies.go create mode 100644 auth/cache/policies_test.go create mode 100644 auth/cache/setup_test.go diff --git a/auth/README.md b/auth/README.md index 4a991e0fb16..ee118b11ad1 100644 --- a/auth/README.md +++ b/auth/README.md @@ -59,40 +59,42 @@ Domain consists of the following fields: The service is configured using the environment variables presented in the following table. Note that any unset variables will be replaced with their default values. -| Variable | Description | Default | -| ------------------------------ | ----------------------------------------------------------------------- | ------------------------------- | -| MG_AUTH_LOG_LEVEL | Log level for the Auth service (debug, info, warn, error) | info | -| MG_AUTH_DB_HOST | Database host address | localhost | -| MG_AUTH_DB_PORT | Database host port | 5432 | -| MG_AUTH_DB_USER | Database user | magistrala | -| MG_AUTH_DB_PASSWORD | Database password | magistrala | -| MG_AUTH_DB_NAME | Name of the database used by the service | auth | -| MG_AUTH_DB_SSL_MODE | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | -| MG_AUTH_DB_SSL_CERT | Path to the PEM encoded certificate file | "" | -| MG_AUTH_DB_SSL_KEY | Path to the PEM encoded key file | "" | -| MG_AUTH_DB_SSL_ROOT_CERT | Path to the PEM encoded root certificate file | "" | -| MG_AUTH_HTTP_HOST | Auth service HTTP host | "" | -| MG_AUTH_HTTP_PORT | Auth service HTTP port | 8189 | -| MG_AUTH_HTTP_SERVER_CERT | Path to the PEM encoded HTTP server certificate file | "" | -| MG_AUTH_HTTP_SERVER_KEY | Path to the PEM encoded HTTP server key file | "" | -| MG_AUTH_GRPC_HOST | Auth service gRPC host | "" | -| MG_AUTH_GRPC_PORT | Auth service gRPC port | 8181 | -| MG_AUTH_GRPC_SERVER_CERT | Path to the PEM encoded gRPC server certificate file | "" | -| MG_AUTH_GRPC_SERVER_KEY | Path to the PEM encoded gRPC server key file | "" | -| MG_AUTH_GRPC_SERVER_CA_CERTS | Path to the PEM encoded gRPC server CA certificate file | "" | -| MG_AUTH_GRPC_CLIENT_CA_CERTS | Path to the PEM encoded gRPC client CA certificate file | "" | -| MG_AUTH_SECRET_KEY | String used for signing tokens | secret | -| MG_AUTH_ACCESS_TOKEN_DURATION | The access token expiration period | 1h | -| MG_AUTH_REFRESH_TOKEN_DURATION | The refresh token expiration period | 24h | -| MG_AUTH_INVITATION_DURATION | The invitation token expiration period | 168h | -| MG_SPICEDB_HOST | SpiceDB host address | localhost | -| MG_SPICEDB_PORT | SpiceDB host port | 50051 | -| MG_SPICEDB_PRE_SHARED_KEY | SpiceDB pre-shared key | 12345678 | -| MG_SPICEDB_SCHEMA_FILE | Path to SpiceDB schema file | ./docker/spicedb/schema.zed | +| Variable | Description | Default | +| ------------------------------ | ----------------------------------------------------------------------- | ------------------------------ | +| MG_AUTH_LOG_LEVEL | Log level for the Auth service (debug, info, warn, error) | info | +| MG_AUTH_DB_HOST | Database host address | localhost | +| MG_AUTH_DB_PORT | Database host port | 5432 | +| MG_AUTH_DB_USER | Database user | magistrala | +| MG_AUTH_DB_PASSWORD | Database password | magistrala | +| MG_AUTH_DB_NAME | Name of the database used by the service | auth | +| MG_AUTH_DB_SSL_MODE | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | +| MG_AUTH_DB_SSL_CERT | Path to the PEM encoded certificate file | "" | +| MG_AUTH_DB_SSL_KEY | Path to the PEM encoded key file | "" | +| MG_AUTH_DB_SSL_ROOT_CERT | Path to the PEM encoded root certificate file | "" | +| MG_AUTH_HTTP_HOST | Auth service HTTP host | "" | +| MG_AUTH_HTTP_PORT | Auth service HTTP port | 8189 | +| MG_AUTH_HTTP_SERVER_CERT | Path to the PEM encoded HTTP server certificate file | "" | +| MG_AUTH_HTTP_SERVER_KEY | Path to the PEM encoded HTTP server key file | "" | +| MG_AUTH_GRPC_HOST | Auth service gRPC host | "" | +| MG_AUTH_GRPC_PORT | Auth service gRPC port | 8181 | +| MG_AUTH_GRPC_SERVER_CERT | Path to the PEM encoded gRPC server certificate file | "" | +| MG_AUTH_GRPC_SERVER_KEY | Path to the PEM encoded gRPC server key file | "" | +| MG_AUTH_GRPC_SERVER_CA_CERTS | Path to the PEM encoded gRPC server CA certificate file | "" | +| MG_AUTH_GRPC_CLIENT_CA_CERTS | Path to the PEM encoded gRPC client CA certificate file | "" | +| MG_AUTH_SECRET_KEY | String used for signing tokens | secret | +| MG_AUTH_ACCESS_TOKEN_DURATION | The access token expiration period | 1h | +| MG_AUTH_REFRESH_TOKEN_DURATION | The refresh token expiration period | 24h | +| MG_AUTH_INVITATION_DURATION | The invitation token expiration period | 168h | +| MG_SPICEDB_HOST | SpiceDB host address | localhost | +| MG_SPICEDB_PORT | SpiceDB host port | 50051 | +| MG_SPICEDB_PRE_SHARED_KEY | SpiceDB pre-shared key | 12345678 | +| MG_SPICEDB_SCHEMA_FILE | Path to SpiceDB schema file | ./docker/spicedb/schema.zed | +| MG_AUTH_CACHE_URL | Cache server URL | "redis://localhost:6379/0" | +| MG_AUTH_CACHE_KEY_DURATION | Cache key expiration period | "1h" | | MG_JAEGER_URL | Jaeger server URL | | -| MG_JAEGER_TRACE_RATIO | Jaeger sampling ratio | 1.0 | -| MG_SEND_TELEMETRY | Send telemetry to magistrala call home server | true | -| MG_AUTH_ADAPTER_INSTANCE_ID | Adapter instance ID | "" | +| MG_JAEGER_TRACE_RATIO | Jaeger sampling ratio | 1.0 | +| MG_SEND_TELEMETRY | Send telemetry to magistrala call home server | true | +| MG_AUTH_ADAPTER_INSTANCE_ID | Adapter instance ID | "" | ## Deployment @@ -142,6 +144,8 @@ MG_SPICEDB_HOST=localhost \ MG_SPICEDB_PORT=50051 \ MG_SPICEDB_PRE_SHARED_KEY=12345678 \ MG_SPICEDB_SCHEMA_FILE=./docker/spicedb/schema.zed \ +MG_AUTH_CACHE_URL=redis://localhost:6379/0 \ +MG_AUTH_CACHE_KEY_DURATION=1h \ MG_JAEGER_URL=http://localhost:14268/api/traces \ MG_JAEGER_TRACE_RATIO=1.0 \ MG_SEND_TELEMETRY=true \ diff --git a/auth/cache/doc.go b/auth/cache/doc.go new file mode 100644 index 00000000000..6bf2be2e393 --- /dev/null +++ b/auth/cache/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package cache contains the domain concept definitions needed to +// support Magistrala auth cache service functionality. +package cache diff --git a/auth/cache/policies.go b/auth/cache/policies.go new file mode 100644 index 00000000000..fbf6176f340 --- /dev/null +++ b/auth/cache/policies.go @@ -0,0 +1,87 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache + +import ( + "context" + "strings" + "time" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" + "github.com/go-redis/redis/v8" +) + +const defLimit = 100 + +var _ auth.Cache = (*policiesCache)(nil) + +type policiesCache struct { + client *redis.Client + keyDuration time.Duration +} + +// NewPoliciesCache returns redis auth cache implementation. +func NewPoliciesCache(client *redis.Client, duration time.Duration) auth.Cache { + return &policiesCache{ + client: client, + keyDuration: duration, + } +} + +func (pc *policiesCache) Save(ctx context.Context, key, value string) error { + if err := pc.client.Set(ctx, key, value, pc.keyDuration).Err(); err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + + return nil +} + +func (pc *policiesCache) Contains(ctx context.Context, key, value string) bool { + rval, err := pc.client.Get(ctx, key).Result() + if err != nil { + return false + } + if rval == value { + return true + } + + return false +} + +func (pc *policiesCache) Remove(ctx context.Context, key string) error { + if strings.Contains(key, "*") { + return pc.delete(ctx, key) + } + + if err := pc.client.Del(ctx, key).Err(); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + return nil +} + +func (pc *policiesCache) delete(ctx context.Context, key string) error { + keys, cursor, err := pc.client.Scan(ctx, 0, key, defLimit).Result() + if err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + for cursor != 0 { + var newKeys []string + newKeys, cursor, err = pc.client.Scan(ctx, cursor, key, defLimit).Result() + if err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + keys = append(keys, newKeys...) + } + + for _, key := range keys { + if err := pc.client.Del(ctx, key).Err(); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + } + + return nil +} diff --git a/auth/cache/policies_test.go b/auth/cache/policies_test.go new file mode 100644 index 00000000000..28529f3a447 --- /dev/null +++ b/auth/cache/policies_test.go @@ -0,0 +1,345 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache_test + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/auth/cache" + "github.com/absmach/magistrala/internal/testsutil" + "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" + "github.com/go-redis/redis/v8" + "github.com/stretchr/testify/assert" +) + +var policy = auth.PolicyReq{ + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + Permission: auth.ViewPermission, +} + +func setupRedisClient(t *testing.T) auth.Cache { + opts, err := redis.ParseURL(redisURL) + assert.Nil(t, err, fmt.Sprintf("got unexpected error on parsing redis URL: %s", err)) + redisClient := redis.NewClient(opts) + return cache.NewPoliciesCache(redisClient, 10*time.Minute) +} + +func TestSave(t *testing.T) { + authCache := setupRedisClient(t) + + cases := []struct { + desc string + policy auth.PolicyReq + err error + }{ + { + desc: "Save policy", + policy: policy, + err: nil, + }, + { + desc: "Save already cached policy", + policy: policy, + err: nil, + }, + { + desc: "Save another policy", + policy: auth.PolicyReq{ + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + Permission: auth.ViewPermission, + }, + err: nil, + }, + { + desc: "Save another policy with domain", + policy: auth.PolicyReq{ + Domain: testsutil.GenerateUUID(&testing.T{}), + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + Permission: auth.ViewPermission, + }, + err: nil, + }, + { + desc: "Save policy with long key", + policy: auth.PolicyReq{ + SubjectType: strings.Repeat("a", 513*1024*1024), + Subject: strings.Repeat("a", 513*1024*1024), + ObjectType: strings.Repeat("a", 513*1024*1024), + Object: strings.Repeat("a", 513*1024*1024), + Permission: auth.ViewPermission, + }, + err: repoerr.ErrCreateEntity, + }, + { + desc: "Save policy with long value", + policy: auth.PolicyReq{ + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + Permission: strings.Repeat("a", 513*1024*1024), + }, + err: repoerr.ErrCreateEntity, + }, + { + desc: "Save policy with empty key", + policy: auth.PolicyReq{ + Permission: auth.ViewPermission, + }, + err: nil, + }, + { + desc: "Save policy with empty subject", + policy: auth.PolicyReq{ + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + Permission: auth.ViewPermission, + }, + err: nil, + }, + { + desc: "Save policy with empty object", + policy: auth.PolicyReq{ + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + Permission: auth.ViewPermission, + }, + err: nil, + }, + { + desc: "Save policy with empty value", + policy: auth.PolicyReq{ + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + }, + err: nil, + }, + { + desc: "Save policy with empty key and id", + policy: auth.PolicyReq{}, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + key, val := tc.policy.KV() + err := authCache.Save(context.Background(), key, val) + if err == nil { + ok := authCache.Contains(context.Background(), key, val) + assert.True(t, ok) + } + assert.True(t, errors.Contains(err, tc.err)) + }) + } +} + +func TestContains(t *testing.T) { + authCache := setupRedisClient(t) + + key, val := policy.KV() + err := authCache.Save(context.Background(), key, val) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + + cases := []struct { + desc string + policy auth.PolicyReq + ok bool + }{ + { + desc: "Contains existing policy", + policy: policy, + ok: true, + }, + { + desc: "Contains invalid policy", + policy: auth.PolicyReq{ + SubjectType: policy.SubjectType, + Subject: policy.Subject, + ObjectType: policy.ObjectType, + Object: policy.Object, + Permission: auth.EditPermission, + }, + }, + { + desc: "Contains non existing policy", + policy: auth.PolicyReq{ + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + Permission: auth.ViewPermission, + }, + }, + { + desc: "Contains non existing policy with domain", + policy: auth.PolicyReq{ + Domain: testsutil.GenerateUUID(&testing.T{}), + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + Permission: auth.ViewPermission, + }, + }, + { + desc: "Contains policy with empty key", + policy: auth.PolicyReq{ + Permission: auth.ViewPermission, + }, + }, + { + desc: "Contains policy with long key", + policy: auth.PolicyReq{ + SubjectType: strings.Repeat("a", 513*1024*1024), + Subject: strings.Repeat("a", 513*1024*1024), + ObjectType: strings.Repeat("a", 513*1024*1024), + Object: strings.Repeat("a", 513*1024*1024), + Permission: auth.ViewPermission, + }, + }, + { + desc: "Contains policy with empty value", + policy: auth.PolicyReq{ + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + }, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + key, val := tc.policy.KV() + ok := authCache.Contains(context.Background(), key, val) + assert.Equal(t, tc.ok, ok) + }) + } +} + +func TestRemove(t *testing.T) { + authCache := setupRedisClient(t) + + subject := policy.Subject + object := policy.Object + + num := 200 + var policies []auth.PolicyReq + for i := 0; i < num; i++ { + policy.Subject = fmt.Sprintf("%s-%d", policy.Subject, i) + policy.Object = fmt.Sprintf("%s-%d", policy.Object, i) + key, val := policy.KV() + err := authCache.Save(context.Background(), key, val) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + policies = append(policies, policy) + } + + cases := []struct { + desc string + multiple bool + policy auth.PolicyReq + err error + }{ + { + desc: "Remove an existing policy from cache", + policy: policies[0], + err: nil, + }, + { + desc: "Remove multiple existing policies from cache with subject", + multiple: true, + policy: auth.PolicyReq{ + Subject: subject, + }, + err: nil, + }, + { + desc: "Remove multiple existing policies from cache with object", + multiple: true, + policy: auth.PolicyReq{ + Object: object, + }, + err: nil, + }, + { + desc: "Remove non existing policy from cache", + policy: auth.PolicyReq{ + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + Permission: auth.ViewPermission, + }, + err: nil, + }, + { + desc: "Remove policy with empty key from cache", + policy: auth.PolicyReq{ + Permission: auth.ViewPermission, + }, + err: nil, + }, + { + desc: "Remove policy with long key from cache", + policy: auth.PolicyReq{ + SubjectType: strings.Repeat("a", 513*1024*1024), + Subject: strings.Repeat("a", 513*1024*1024), + ObjectType: strings.Repeat("a", 513*1024*1024), + Object: strings.Repeat("a", 513*1024*1024), + Permission: auth.ViewPermission, + }, + err: repoerr.ErrRemoveEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := authCache.Remove(context.Background(), tc.policy.KeyForRemoval()) + assert.True(t, errors.Contains(err, tc.err)) + if err == nil { + key, val := tc.policy.KV() + ok := authCache.Contains(context.Background(), key, val) + assert.False(t, ok) + if tc.multiple { + switch { + case tc.policy.Subject != "": + for _, p := range policies { + if strings.HasPrefix(p.Subject, subject) { + key, val := p.KV() + ok := authCache.Contains(context.Background(), key, val) + assert.False(t, ok) + } + } + case tc.policy.Object != "": + for _, p := range policies { + if strings.HasPrefix(p.Object, object) { + key, val := p.KV() + ok := authCache.Contains(context.Background(), key, val) + assert.False(t, ok) + } + } + } + } + } + }) + } +} diff --git a/auth/cache/setup_test.go b/auth/cache/setup_test.go new file mode 100644 index 00000000000..76f0fb146b1 --- /dev/null +++ b/auth/cache/setup_test.go @@ -0,0 +1,75 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache_test + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + "testing" + + "github.com/go-redis/redis/v8" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" +) + +var redisURL string + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "redis", + Tag: "7.2.4-alpine", + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + handleInterrupt(pool, container) + + redisURL = fmt.Sprintf("redis://localhost:%s/0", container.GetPort("6379/tcp")) + opts, err := redis.ParseURL(redisURL) + if err != nil { + log.Fatalf("Could not parse redis URL: %s", err) + } + + if err := pool.Retry(func() error { + redisClient := redis.NewClient(opts) + + return redisClient.Ping(context.Background()).Err() + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + code := m.Run() + + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) +} + +func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + + go func() { + <-c + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + os.Exit(0) + }() +} diff --git a/auth/policies.go b/auth/policies.go index e2e416aed28..d8d0fd706a4 100644 --- a/auth/policies.go +++ b/auth/policies.go @@ -104,6 +104,34 @@ func (pr PolicyReq) String() string { return string(data) } +// KV returns the key-value pair for the given PolicyReq. +func (pr PolicyReq) KV() (string, string) { + var key, val string + switch pr.Domain { + case "": + key = pr.SubjectType + ":" + pr.Subject + ":" + pr.ObjectType + ":" + pr.Object + default: + key = pr.Domain + ":" + pr.SubjectType + ":" + pr.Subject + ":" + pr.ObjectType + ":" + pr.Object + } + val = pr.Permission + + return key, val +} + +// KeyForRemoval returns the key for the given PolicyReq. It is used +// to remove a key from the cache. +func (pr PolicyReq) KeyForRemoval() string { + switch { + case pr.Subject != "" && pr.Object == "": + return "*" + pr.Subject + "*" + case pr.Object != "" && pr.Subject == "": + return "*" + pr.Object + "*" + default: + key, _ := pr.KV() + return key + } +} + type PolicyRes struct { Namespace string Subject string @@ -221,3 +249,16 @@ type PolicyAgent interface { // (ctx context.Context, pr PolicyReq, filterPermissions []string) ([]PolicyReq, error) RetrievePermissions(ctx context.Context, pr PolicyReq, filterPermission []string) (Permissions, error) } + +// Cache represents a cache repository. It exposes functionalities +// through `auth` to perform caching. +type Cache interface { + // Save saves the key-value pair in the cache. + Save(ctx context.Context, key, value string) error + + // Contains checks if the key-value pair exists in the cache. + Contains(ctx context.Context, key, value string) bool + + // Remove removes the key from the cache. + Remove(ctx context.Context, key string) error +} diff --git a/auth/spicedb/policies.go b/auth/spicedb/policies.go index 7ac2ba4a2c3..cc83b5c639a 100644 --- a/auth/spicedb/policies.go +++ b/auth/spicedb/policies.go @@ -35,17 +35,30 @@ type policyAgent struct { client *authzed.ClientWithExperimental permissionClient v1.PermissionsServiceClient logger *slog.Logger + cache auth.Cache } -func NewPolicyAgent(client *authzed.ClientWithExperimental, logger *slog.Logger) auth.PolicyAgent { +func NewPolicyAgent(client *authzed.ClientWithExperimental, logger *slog.Logger, cache auth.Cache) auth.PolicyAgent { return &policyAgent{ client: client, permissionClient: client.PermissionsServiceClient, logger: logger, + cache: cache, } } -func (pa *policyAgent) CheckPolicy(ctx context.Context, pr auth.PolicyReq) error { +func (pa *policyAgent) CheckPolicy(ctx context.Context, pr auth.PolicyReq) (err error) { + key, val := pr.KV() + if pa.cache.Contains(ctx, key, val) { + return nil + } + defer func() { + if err == nil { + cacheErr := pa.cache.Save(ctx, key, val) + err = errors.Wrap(err, cacheErr) + } + }() + checkReq := v1.CheckPermissionRequest{ // FullyConsistent means little caching will be available, which means performance will suffer. // Only use if a ZedToken is not available or absolutely latest information is required. @@ -134,6 +147,10 @@ func (pa *policyAgent) AddPolicy(ctx context.Context, pr auth.PolicyReq) error { func (pa *policyAgent) DeletePolicies(ctx context.Context, prs []auth.PolicyReq) error { updates := []*v1.RelationshipUpdate{} for _, pr := range prs { + if err := pa.cache.Remove(ctx, pr.KeyForRemoval()); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + updates = append(updates, &v1.RelationshipUpdate{ Operation: v1.RelationshipUpdate_OPERATION_DELETE, Relationship: &v1.Relationship{ @@ -154,6 +171,10 @@ func (pa *policyAgent) DeletePolicies(ctx context.Context, prs []auth.PolicyReq) } func (pa *policyAgent) DeletePolicyFilter(ctx context.Context, pr auth.PolicyReq) error { + if err := pa.cache.Remove(ctx, pr.KeyForRemoval()); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + req := &v1.DeleteRelationshipsRequest{ RelationshipFilter: &v1.RelationshipFilter{ ResourceType: pr.ObjectType, diff --git a/cmd/auth/main.go b/cmd/auth/main.go index a2b11b027e8..b1d4fe1b749 100644 --- a/cmd/auth/main.go +++ b/cmd/auth/main.go @@ -18,11 +18,13 @@ import ( api "github.com/absmach/magistrala/auth/api" grpcapi "github.com/absmach/magistrala/auth/api/grpc" httpapi "github.com/absmach/magistrala/auth/api/http" + "github.com/absmach/magistrala/auth/cache" "github.com/absmach/magistrala/auth/events" "github.com/absmach/magistrala/auth/jwt" apostgres "github.com/absmach/magistrala/auth/postgres" "github.com/absmach/magistrala/auth/spicedb" "github.com/absmach/magistrala/auth/tracing" + redisclient "github.com/absmach/magistrala/internal/clients/redis" mglog "github.com/absmach/magistrala/logger" "github.com/absmach/magistrala/pkg/jaeger" "github.com/absmach/magistrala/pkg/postgres" @@ -36,6 +38,7 @@ import ( "github.com/authzed/authzed-go/v1" "github.com/authzed/grpcutil" "github.com/caarlos0/env/v10" + "github.com/go-redis/redis/v8" "github.com/jmoiron/sqlx" "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" @@ -67,6 +70,8 @@ type config struct { SpicedbPort string `env:"MG_SPICEDB_PORT" envDefault:"50051"` SpicedbSchemaFile string `env:"MG_SPICEDB_SCHEMA_FILE" envDefault:"./docker/spicedb/schema.zed"` SpicedbPreSharedKey string `env:"MG_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` + CacheURL string `env:"MG_AUTH_CACHE_URL" envDefault:"redis://localhost:6379/0"` + CacheKeyDuration time.Duration `env:"MG_AUTH_CACHE_KEY_DURATION" envDefault:"1h"` TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` ESURL string `env:"MG_ES_URL" envDefault:"nats://localhost:4222"` } @@ -122,6 +127,14 @@ func main() { }() tracer := tp.Tracer(svcName) + cacheclient, err := redisclient.Connect(cfg.CacheURL) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer cacheclient.Close() + spicedbclient, err := initSpiceDB(ctx, cfg) if err != nil { logger.Error(fmt.Sprintf("failed to init spicedb grpc client : %s\n", err.Error())) @@ -129,7 +142,7 @@ func main() { return } - svc := newService(ctx, db, tracer, cfg, dbConfig, logger, spicedbclient) + svc := newService(ctx, db, tracer, cfg, dbConfig, cacheclient, cfg.CacheKeyDuration, logger, spicedbclient) httpServerConfig := server.Config{Port: defSvcHTTPPort} if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { @@ -203,11 +216,13 @@ func initSchema(ctx context.Context, client *authzed.ClientWithExperimental, sch return nil } -func newService(ctx context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental) auth.Service { +func newService(ctx context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, cacheClient *redis.Client, keyDuration time.Duration, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental) auth.Service { database := postgres.NewDatabase(db, dbConfig, tracer) keysRepo := apostgres.New(database) domainsRepo := apostgres.NewDomainRepository(database) - pa := spicedb.NewPolicyAgent(spicedbClient, logger) + policiesCache := cache.NewPoliciesCache(cacheClient, keyDuration) + + pa := spicedb.NewPolicyAgent(spicedbClient, logger, policiesCache) idProvider := uuid.New() t := jwt.New([]byte(cfg.SecretKey)) diff --git a/docker/.env b/docker/.env index 29a8588cf9a..6bdff0225cd 100644 --- a/docker/.env +++ b/docker/.env @@ -93,6 +93,8 @@ MG_AUTH_DB_SSL_MODE=disable MG_AUTH_DB_SSL_CERT= MG_AUTH_DB_SSL_KEY= MG_AUTH_DB_SSL_ROOT_CERT= +MG_AUTH_CACHE_URL=redis://auth-redis:${MG_REDIS_TCP_PORT}/0 +MG_AUTH_CACHE_KEY_DURATION="1h" MG_AUTH_SECRET_KEY=HyE2D4RUt9nnKG6v8zKEqAp6g6ka8hhZsqUpzgKvnwpXrNVQSH MG_AUTH_ACCESS_TOKEN_DURATION="1h" MG_AUTH_REFRESH_TOKEN_DURATION="24h" diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index c206e7529c7..32e5d7dfb6a 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -17,6 +17,7 @@ volumes: magistrala-auth-db-volume: magistrala-invitations-db-volume: magistrala-ui-db-volume: + magistrala-auth-redis-volume: services: spicedb: @@ -64,6 +65,15 @@ services: volumes: - magistrala-spicedb-db-volume:/var/lib/postgresql/data + auth-redis: + image: redis:7.2.4-alpine + container_name: magistrala-auth-redis + restart: on-failure + networks: + - magistrala-base-net + volumes: + - magistrala-auth-redis-volume:/data + auth-db: image: postgres:16.2-alpine container_name: magistrala-auth-db @@ -83,6 +93,7 @@ services: image: magistrala/auth:${MG_RELEASE_TAG} container_name: magistrala-auth depends_on: + - auth-redis - auth-db - spicedb expose: @@ -120,6 +131,8 @@ services: MG_AUTH_DB_SSL_CERT: ${MG_AUTH_DB_SSL_CERT} MG_AUTH_DB_SSL_KEY: ${MG_AUTH_DB_SSL_KEY} MG_AUTH_DB_SSL_ROOT_CERT: ${MG_AUTH_DB_SSL_ROOT_CERT} + MG_AUTH_CACHE_URL: ${MG_AUTH_CACHE_URL} + MG_AUTH_CACHE_KEY_DURATION: ${MG_AUTH_CACHE_KEY_DURATION} MG_JAEGER_URL: ${MG_JAEGER_URL} MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} @@ -295,6 +308,7 @@ services: image: magistrala/things:${MG_RELEASE_TAG} container_name: magistrala-things depends_on: + - things-redis - things-db - users - auth