From fe230b0110bb02202b0da97151ac24a94b736550 Mon Sep 17 00:00:00 2001 From: SalmaElsoly Date: Thu, 27 Nov 2025 13:07:32 +0200 Subject: [PATCH 01/10] feat: add redis distributed locking mechanism --- backend/cmd/root.go | 7 ++ backend/internal/api/app/app_dependencies.go | 9 ++ backend/internal/config/config.go | 18 +++- .../distributed_locks/distributed_locks.go | 12 +++ .../core/distributed_locks/redis_locker.go | 85 +++++++++++++++++ .../distributed_locks/redis_locker_test.go | 93 +++++++++++++++++++ .../core/services/deployment_service.go | 17 ++++ .../internal/core/services/node_service.go | 8 ++ .../internal/core/services/workers_service.go | 49 +++++++++- .../internal/core/workers/locks_releaser.go | 36 +++++++ 10 files changed, 329 insertions(+), 5 deletions(-) create mode 100644 backend/internal/core/distributed_locks/distributed_locks.go create mode 100644 backend/internal/core/distributed_locks/redis_locker.go create mode 100644 backend/internal/core/distributed_locks/redis_locker_test.go create mode 100644 backend/internal/core/workers/locks_releaser.go diff --git a/backend/cmd/root.go b/backend/cmd/root.go index d3fe1dd1..ba1fe789 100644 --- a/backend/cmd/root.go +++ b/backend/cmd/root.go @@ -210,6 +210,13 @@ func addFlags() error { if err := bindStringFlag(rootCmd, "telemetry.otlp_endpoint", "jaeger:4317", "OpenTelemetry gRPC endpoint"); err != nil { return fmt.Errorf("failed to bind telemetry.otlp_endpoint flag: %w", err) } + // === Locks === + if err := bindIntFlag(rootCmd, "redis.lock_timeout_in_hours", 24, "Redis lock timeout (hours)"); err != nil { + return fmt.Errorf("failed to bind redis.lock_timeout_in_hours flag: %w", err) + } + if err := bindIntFlag(rootCmd, "locks_release_interval_in_minutes", 5, "Locks release interval (minutes)"); err != nil { + return fmt.Errorf("failed to bind locks_release_interval_in_minutes flag: %w", err) + } return nil } diff --git a/backend/internal/api/app/app_dependencies.go 
b/backend/internal/api/app/app_dependencies.go index d42b1497..774a92d4 100644 --- a/backend/internal/api/app/app_dependencies.go +++ b/backend/internal/api/app/app_dependencies.go @@ -7,6 +7,7 @@ import ( "kubecloud/internal/auth" "kubecloud/internal/billing" cfg "kubecloud/internal/config" + distributedlocks "kubecloud/internal/core/distributed_locks" "kubecloud/internal/core/models" corepersistence "kubecloud/internal/core/persistence" "kubecloud/internal/core/queuing" @@ -66,6 +67,7 @@ type appCore struct { db models.DB metrics *metrics.Metrics ewfEngine *ewf.Engine + locker distributedlocks.DistributedLocks tracerProvider *telemetry.TracerProvider } @@ -167,11 +169,14 @@ func createAppCore(ctx context.Context, config cfg.Configuration) (appCore, erro return appCore{}, fmt.Errorf("failed to init workflow engine: %w", err) } + locker := distributedlocks.NewRedisLocker(client, time.Duration(config.Redis.LockTimeoutInHours)*time.Hour) + return appCore{ appCtx: ctx, db: db, metrics: metrics.NewMetrics(), ewfEngine: ewfEngine, + locker: locker, tracerProvider: tp, }, nil } @@ -334,6 +339,7 @@ func (app *App) createHandlers() appHandlers { nodeService := services.NewNodeService( userNodesRepo, userRepo, app.core.appCtx, app.core.ewfEngine, app.infra.gridClient, + app.core.locker, ) invoiceService := services.NewInvoiceService( @@ -342,6 +348,7 @@ func (app *App) createHandlers() appHandlers { deploymentService := services.NewDeploymentService( app.core.appCtx, clusterRepo, userRepo, userNodesRepo, app.core.ewfEngine, + app.core.locker, app.config.Debug, app.security.sshPublicKey, app.config.SSH.PrivateKeyPath, app.config.SystemAccount.Network, ) @@ -394,6 +401,8 @@ func (app *App) createWorkers() workers.Workers { app.config.Invoice, app.config.SystemAccount.Mnemonic, app.config.Currency, app.config.ClusterHealthCheckIntervalInHours, app.config.NodeHealthCheck.ReservedNodeHealthCheckIntervalInHours, app.config.NodeHealthCheck.ReservedNodeHealthCheckTimeoutInMinutes, 
app.config.NodeHealthCheck.ReservedNodeHealthCheckWorkersNum, app.config.MonitorBalanceIntervalInMinutes, app.config.NotifyAdminsForPendingRecordsInHours, app.config.UsersBalanceCheckIntervalInHours, app.config.CheckUserDebtIntervalInHours, + app.config.LocksReleaseIntervalInMinutes, + app.core.locker, ) return workers.NewWorkers(app.core.appCtx, workersService, app.core.metrics, app.core.db) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 1886e210..b535e4a5 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -34,6 +34,7 @@ type Configuration struct { DevMode bool `json:"dev_mode"` // When true, allows empty SendGridKey and uses FakeMailService MonitorBalanceIntervalInMinutes int `json:"monitor_balance_interval_in_minutes" validate:"required,gt=0"` NotifyAdminsForPendingRecordsInHours int `json:"notify_admins_for_pending_records_in_hours" validate:"required,gt=0"` + LocksReleaseIntervalInMinutes int `json:"locks_release_interval_in_minutes" validate:"required,gt=0" default:"5"` ClusterHealthCheckIntervalInHours int `json:"cluster_health_check_interval_in_hours" validate:"gt=0" default:"1"` UsersBalanceCheckIntervalInHours int `json:"users_balance_check_interval_in_hours" validate:"gt=0" default:"6"` CheckUserDebtIntervalInHours int `json:"check_user_debt_interval_in_hours" validate:"gt=0" default:"48"` @@ -50,10 +51,11 @@ type SSHConfig struct { } type RedisConfig struct { - Hostname string `json:"hostname" validate:"hostname|ip|url"` - Port int `json:"port" validate:"min=1,max=65535"` - Password string `json:"password"` - DB int `json:"db" validate:"min=0"` + Hostname string `json:"hostname" validate:"hostname|ip|url"` + Port int `json:"port" validate:"min=1,max=65535"` + Password string `json:"password"` + DB int `json:"db" validate:"min=0"` + LockTimeoutInHours int `json:"lock_timeout_in_hours" validate:"required,min=1" default:"1"` } // Server struct holds server's information @@ 
-391,6 +393,14 @@ func applyDefaultValues(config *Configuration) { config.NotifyAdminsForPendingRecordsInHours = 24 } + if config.Redis.LockTimeoutInHours == 0 { + config.Redis.LockTimeoutInHours = 24 + } + + if config.LocksReleaseIntervalInMinutes == 0 { + config.LocksReleaseIntervalInMinutes = 5 + } + if config.Telemetry.OTLPEndpoint == "" { config.Telemetry.OTLPEndpoint = "jaeger:4317" } diff --git a/backend/internal/core/distributed_locks/distributed_locks.go b/backend/internal/core/distributed_locks/distributed_locks.go new file mode 100644 index 00000000..fd190f77 --- /dev/null +++ b/backend/internal/core/distributed_locks/distributed_locks.go @@ -0,0 +1,12 @@ +package distributedlocks + +import ( + "context" +) + +type DistributedLocks interface { + AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) error + AcquireWorkflowLock(ctx context.Context, nodeID uint32, workflowID string) (bool, error) + ReleaseLock(ctx context.Context, nodeID uint32, workflowID string) error + GetAllWorkflowsLocks(ctx context.Context) ([]string, error) +} diff --git a/backend/internal/core/distributed_locks/redis_locker.go b/backend/internal/core/distributed_locks/redis_locker.go new file mode 100644 index 00000000..c41fe1c6 --- /dev/null +++ b/backend/internal/core/distributed_locks/redis_locker.go @@ -0,0 +1,85 @@ +package distributedlocks + +import ( + "context" + "fmt" + "time" + + "github.com/redis/go-redis/v9" +) + +type RedisLocker struct { + client *redis.Client + lockTimeout time.Duration +} + +// NewRedisLocker creates a new RedisLocker instance. +func NewRedisLocker(client *redis.Client, lockTimeout time.Duration) *RedisLocker { + return &RedisLocker{ + client: client, + lockTimeout: lockTimeout, + } +} + +// AcquireNodesLocks acquires locks for the given node IDs. 
+func (l *RedisLocker) AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) error { + keys := nodeLockKeys(nodeIDs) + locked := make([]string, 0, len(keys)) + + for _, key := range keys { + ok, err := l.client.SetNX(ctx, key, 1, l.lockTimeout).Result() + if err != nil { + err = l.client.Del(ctx, locked...).Err() + if err != nil { + return fmt.Errorf("redis error while rolling back locks: %w", err) + } + return fmt.Errorf("redis error while acquiring lock for key %s: %w", key, err) + } + + if !ok { + err = l.client.Del(ctx, locked...).Err() + if err != nil { + return fmt.Errorf("redis error while rolling back locks: %w", err) + } + return fmt.Errorf("failed to acquire lock for key %s: node already locked", key) + } + + locked = append(locked, key) + } + + return nil +} + +// AcquireWorkflowLock acquires a lock for the given workflow ID. +func (l *RedisLocker) AcquireWorkflowLock(ctx context.Context, nodeID uint32, workflowID string) (bool, error) { + key := workflowLockKey(nodeID, workflowID) + return l.client.SetNX(ctx, key, 1, l.lockTimeout).Result() +} + +// ReleaseLock releases a lock for the given node ID and workflow ID. +func (l *RedisLocker) ReleaseLock(ctx context.Context, nodeID uint32, workflowID string) error { + lockedKey := nodeLockKey(nodeID) + usedKey := workflowLockKey(nodeID, workflowID) + return l.client.Del(ctx, lockedKey, usedKey).Err() +} + +// GetAllWorkflowsLocks gets all workflow locks. 
+func (l *RedisLocker) GetAllWorkflowsLocks(ctx context.Context) ([]string, error) { + return l.client.Keys(ctx, "used:*").Result() +} + +func nodeLockKey(nodeID uint32) string { + return fmt.Sprintf("locked:%d", nodeID) +} + +func nodeLockKeys(nodeIDs []uint32) []string { + keys := make([]string, len(nodeIDs)) + for i, id := range nodeIDs { + keys[i] = nodeLockKey(id) + } + return keys +} + +func workflowLockKey(nodeID uint32, workflowID string) string { + return fmt.Sprintf("used:%d:%s", nodeID, workflowID) +} diff --git a/backend/internal/core/distributed_locks/redis_locker_test.go b/backend/internal/core/distributed_locks/redis_locker_test.go new file mode 100644 index 00000000..4a05cc62 --- /dev/null +++ b/backend/internal/core/distributed_locks/redis_locker_test.go @@ -0,0 +1,93 @@ +package distributedlocks + +import ( + "context" + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func newTestRedisClient(t *testing.T) *redis.Client { + t.Helper() + + client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + }) + + require.NoError(t, client.Ping(context.Background()).Err()) + require.NoError(t, client.FlushDB(context.Background()).Err()) + + t.Cleanup(func() { + require.NoError(t, client.FlushDB(context.Background()).Err()) + require.NoError(t, client.Close()) + }) + + return client +} + +func TestRedisLocker_AcquireNodesLocks_Success(t *testing.T) { + client := newTestRedisClient(t) + locker := &RedisLocker{client: client, lockTimeout: time.Minute} + + err := locker.AcquireNodesLocks(context.Background(), []uint32{1, 2}) + + require.NoError(t, err) + require.Equal(t, int64(1), client.Exists(context.Background(), "locked:1").Val()) + require.Equal(t, int64(1), client.Exists(context.Background(), "locked:2").Val()) +} + +func TestRedisLocker_AcquireNodesLocks_NodeAlreadyLocked(t *testing.T) { + client := newTestRedisClient(t) + locker := &RedisLocker{client: client, lockTimeout: time.Minute} + + 
require.NoError(t, client.Set(context.Background(), "locked:2", 1, 0).Err()) + + err := locker.AcquireNodesLocks(context.Background(), []uint32{1, 2}) + + require.Error(t, err) + require.Contains(t, err.Error(), "failed to acquire lock for key locked:2") + require.Equal(t, int64(0), client.Exists(context.Background(), "locked:1").Val(), "previous locks should be rolled back") +} + +func TestRedisLocker_AcquireWorkflowLock(t *testing.T) { + client := newTestRedisClient(t) + locker := &RedisLocker{client: client, lockTimeout: time.Minute} + + ok, err := locker.AcquireWorkflowLock(context.Background(), 1, "wf-1") + require.NoError(t, err) + require.True(t, ok) + + ok, err = locker.AcquireWorkflowLock(context.Background(), 1, "wf-1") + require.NoError(t, err) + require.False(t, ok) +} + +func TestRedisLocker_ReleaseLock(t *testing.T) { + client := newTestRedisClient(t) + locker := &RedisLocker{client: client, lockTimeout: time.Minute} + + require.NoError(t, client.Set(context.Background(), "locked:1", 1, 0).Err()) + require.NoError(t, client.Set(context.Background(), "used:1:wf-1", 1, 0).Err()) + + err := locker.ReleaseLock(context.Background(), 1, "wf-1") + + require.NoError(t, err) + require.Equal(t, int64(0), client.Exists(context.Background(), "locked:1").Val()) + require.Equal(t, int64(0), client.Exists(context.Background(), "used:1:wf-1").Val()) +} + +func TestRedisLocker_GetAllWorkflowsLocks(t *testing.T) { + client := newTestRedisClient(t) + locker := &RedisLocker{client: client, lockTimeout: time.Minute} + + require.NoError(t, client.Set(context.Background(), "used:1:wf-1", 1, 0).Err()) + require.NoError(t, client.Set(context.Background(), "used:2:wf-2", 1, 0).Err()) + require.NoError(t, client.Set(context.Background(), "locked:99", 1, 0).Err()) + + keys, err := locker.GetAllWorkflowsLocks(context.Background()) + + require.NoError(t, err) + require.ElementsMatch(t, []string{"used:1:wf-1", "used:2:wf-2"}, keys) +} diff --git 
a/backend/internal/core/services/deployment_service.go b/backend/internal/core/services/deployment_service.go index f7d3e90b..d4372bef 100644 --- a/backend/internal/core/services/deployment_service.go +++ b/backend/internal/core/services/deployment_service.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" cfg "kubecloud/internal/config" + distributedlocks "kubecloud/internal/core/distributed_locks" "kubecloud/internal/core/models" "kubecloud/internal/core/persistence" "kubecloud/internal/core/workflows" @@ -27,6 +28,8 @@ type DeploymentService struct { ewfEngine *ewf.Engine tracer *telemetry.ServiceTracer + locker distributedlocks.DistributedLocks + // configs debug bool sshPublicKey string @@ -37,6 +40,7 @@ type DeploymentService struct { func NewDeploymentService(appCtx context.Context, clusterRepo models.ClusterRepository, userRepo models.UserRepository, userNodesRepo models.UserNodesRepository, ewfEngine *ewf.Engine, + locker distributedlocks.DistributedLocks, debug bool, sshPublicKey, sshPrivateKeyPath, systemNetwork string, ) DeploymentService { return DeploymentService{ @@ -47,6 +51,8 @@ func NewDeploymentService(appCtx context.Context, ewfEngine: ewfEngine, tracer: telemetry.NewServiceTracer("deployment_service"), + locker: locker, + debug: debug, sshPublicKey: sshPublicKey, sshPrivateKeyPath: sshPrivateKeyPath, @@ -242,6 +248,13 @@ func (svc *DeploymentService) handleDeploymentAction(userID int, workflowName st } func (svc *DeploymentService) AsyncDeployCluster(config statemanager.ClientConfig, cluster kubedeployer.Cluster) (string, ewf.WorkflowStatus, error) { + nodeIDs := make([]uint32, 0, len(cluster.Nodes)) + for _, node := range cluster.Nodes { + nodeIDs = append(nodeIDs, node.NodeID) + } + if err := svc.locker.AcquireNodesLocks(svc.appCtx, nodeIDs); err != nil { + return "", "", err + } state := ewf.State{ "config": config, @@ -282,6 +295,10 @@ func (svc *DeploymentService) AsyncDeleteAllClusters(config statemanager.ClientC func (svc *DeploymentService) 
AsyncAddNode(config statemanager.ClientConfig, cl kubedeployer.Cluster, node kubedeployer.Node) (string, ewf.WorkflowStatus, error) { + if err := svc.locker.AcquireNodesLocks(svc.appCtx, []uint32{node.NodeID}); err != nil { + return "", "", err + } + state := ewf.State{ "config": config, "cluster": cl, diff --git a/backend/internal/core/services/node_service.go b/backend/internal/core/services/node_service.go index 49ee228e..0dd4729e 100644 --- a/backend/internal/core/services/node_service.go +++ b/backend/internal/core/services/node_service.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" cfg "kubecloud/internal/config" + distributedlocks "kubecloud/internal/core/distributed_locks" "kubecloud/internal/core/models" "kubecloud/internal/core/persistence" "kubecloud/internal/core/workflows" @@ -29,12 +30,14 @@ type NodeService struct { appCtx context.Context ewfEngine *ewf.Engine gridClient gridclient.GridClient + locker distributedlocks.DistributedLocks tracer *telemetry.ServiceTracer } func NewNodeService( userNodesRepo models.UserNodesRepository, userRepo models.UserRepository, appCtx context.Context, ewfEngine *ewf.Engine, gridClient gridclient.GridClient, + locker distributedlocks.DistributedLocks, ) NodeService { return NodeService{ nodesRepo: userNodesRepo, @@ -42,6 +45,7 @@ func NewNodeService( appCtx: appCtx, ewfEngine: ewfEngine, gridClient: gridClient, + locker: locker, tracer: telemetry.NewServiceTracer("node_service"), } } @@ -214,6 +218,10 @@ func (svc *NodeService) GetRentedNodesForUser(ctx context.Context, userID int, h } func (svc *NodeService) AsyncReserveNode(userID int, userMnemonic string, nodeID uint32) (string, error) { + if err := svc.locker.AcquireNodesLocks(svc.appCtx, []uint32{nodeID}); err != nil { + return "", err + } + queueName := fmt.Sprintf("%s:user_%d", cfg.DefaultQueueConfig.Name, userID) displayName := fmt.Sprintf("Reserving node %d", nodeID) metadata := map[string]string{ diff --git a/backend/internal/core/services/workers_service.go 
b/backend/internal/core/services/workers_service.go index e3ebb2fc..817a4c8e 100644 --- a/backend/internal/core/services/workers_service.go +++ b/backend/internal/core/services/workers_service.go @@ -6,6 +6,7 @@ import ( "fmt" "kubecloud/internal/billing" "kubecloud/internal/config" + distributedlocks "kubecloud/internal/core/distributed_locks" "kubecloud/internal/core/models" "kubecloud/internal/core/workflows" "kubecloud/internal/infrastructure/gridclient" @@ -13,6 +14,9 @@ import ( "kubecloud/internal/infrastructure/mailservice" mailsender "kubecloud/internal/infrastructure/mailservice/mail_sender" "kubecloud/internal/infrastructure/notification" + "slices" + "strconv" + "strings" "sync" "time" @@ -49,6 +53,9 @@ type WorkerService struct { reservedNodeHealthCheckWorkersNum int monitorBalanceIntervalInMinutes int notifyAdminsForPendingRecordsInHours int + locksReleaseIntervalInMinutes int + + locker distributedlocks.DistributedLocks usersBalanceCheckIntervalInHours int } @@ -61,7 +68,8 @@ func NewWorkersService( invoiceCompanyData config.InvoiceCompanyData, systemMnemonic, currency string, clusterHealthCheckIntervalInHours, reservedNodeHealthCheckIntervalInHours, reservedNodeHealthCheckTimeoutInMinutes, reservedNodeHealthCheckWorkersNum, - monitorBalanceIntervalInMinutes, notifyAdminsForPendingRecordsInHours int, + monitorBalanceIntervalInMinutes, notifyAdminsForPendingRecordsInHours, locksReleaseIntervalInMinutes int, + locker distributedlocks.DistributedLocks, usersBalanceCheckIntervalInHours int, checkUserDebtIntervalInHours int, ) WorkerService { @@ -80,6 +88,8 @@ func NewWorkersService( firesquidClient: firesquidClient, gridClient: gridClient, + locker: locker, + systemMnemonic: systemMnemonic, invoiceCompanyData: invoiceCompanyData, currency: currency, @@ -91,6 +101,7 @@ func NewWorkersService( reservedNodeHealthCheckWorkersNum: reservedNodeHealthCheckWorkersNum, monitorBalanceIntervalInMinutes: monitorBalanceIntervalInMinutes, 
notifyAdminsForPendingRecordsInHours: notifyAdminsForPendingRecordsInHours, + locksReleaseIntervalInMinutes: locksReleaseIntervalInMinutes, usersBalanceCheckIntervalInHours: usersBalanceCheckIntervalInHours, } } @@ -136,6 +147,10 @@ func (svc WorkerService) GetNotifyAdminsForPendingRecordsInterval() time.Duratio return time.Duration(svc.notifyAdminsForPendingRecordsInHours) * time.Hour } +func (svc WorkerService) GetLocksReleaseInterval() time.Duration { + return time.Duration(svc.locksReleaseIntervalInMinutes) * time.Minute +} + func (svc WorkerService) GetUsersBalanceCheckInterval() time.Duration { return time.Duration(svc.usersBalanceCheckIntervalInHours) * time.Hour } @@ -472,6 +487,38 @@ func (svc WorkerService) AsyncTrackClusterHealth(cluster models.Cluster) error { return svc.ewfEngine.Run(svc.ctx, wf, ewf.WithAsync()) } +func (svc WorkerService) GetAllWorkflowsLocks() ([]string, error) { + return svc.locker.GetAllWorkflowsLocks(svc.ctx) +} + +func (svc WorkerService) ReleaseLocks(key string) error { + keyParts := strings.Split(key, ":") + if len(keyParts) != 3 { + return fmt.Errorf("invalid lock key format: %s", key) + } + + nodeID, err := strconv.ParseUint(keyParts[1], 10, 32) + if err != nil { + return fmt.Errorf("invalid node ID: %w", err) + } + + workflowID := keyParts[2] + workflow, err := svc.ewfEngine.Store().LoadWorkflowByUUID(svc.ctx, workflowID) + if err != nil { + return fmt.Errorf("failed to load workflow by UUID: %w", err) + } + + if !slices.Contains([]ewf.WorkflowStatus{ewf.StatusCompleted, ewf.StatusFailed}, workflow.Status) { + return nil + } + + if err := svc.locker.ReleaseLock(svc.ctx, uint32(nodeID), workflowID); err != nil { + return fmt.Errorf("failed to release nodes locks: %w", err) + } + + return nil +} + func (svc WorkerService) checkUserDebt(user models.User, contractIDs []uint64) error { totalDebt, err := svc.calculateDebt(user.Mnemonic, contractIDs, svc.GetCheckUserDebtInterval()) if err != nil { diff --git 
a/backend/internal/core/workers/locks_releaser.go b/backend/internal/core/workers/locks_releaser.go new file mode 100644 index 00000000..8368f3ff --- /dev/null +++ b/backend/internal/core/workers/locks_releaser.go @@ -0,0 +1,36 @@ +package workers + +import ( + "kubecloud/internal/infrastructure/logger" + "time" +) + +// ReleaseWorkflowLocks periodically scans Redis locks and frees those that belong to finished workflows. +func (w Workers) ReleaseWorkflowLocks() { + log := logger.ForOperation("locks_worker", "release_workflow_locks") + ticker := time.NewTicker(w.svc.GetLocksReleaseInterval()) + defer ticker.Stop() + + for { + select { + case <-w.ctx.Done(): + return + case <-ticker.C: + keys, err := w.svc.GetAllWorkflowsLocks() + if err != nil { + log.Error().Err(err).Msg("failed to list workflow locks") + continue + } + + if len(keys) == 0 { + continue + } + + for _, key := range keys { + if err := w.svc.ReleaseLocks(key); err != nil { + log.Error().Err(err).Str("key", key).Msg("failed to release lock") + } + } + } + } +} From 97e5db0c71074a4a57bbb522f38075afcc3a7fcc Mon Sep 17 00:00:00 2001 From: SalmaElsoly Date: Thu, 27 Nov 2025 14:18:07 +0200 Subject: [PATCH 02/10] refactor: redis_locker implementation with rollbacks --- backend/internal/api/app/app_dependencies.go | 4 +- .../api/handlers/deployment_handler.go | 33 ++++++- backend/internal/api/handlers/node_handler.go | 19 +++- .../distributed_locks/distributed_locks.go | 6 +- .../core/distributed_locks/redis_locker.go | 95 ++++++++++++++----- .../distributed_locks/redis_locker_test.go | 8 +- .../core/services/deployment_service.go | 28 +++--- .../internal/core/services/node_service.go | 8 +- 8 files changed, 143 insertions(+), 58 deletions(-) diff --git a/backend/internal/api/app/app_dependencies.go b/backend/internal/api/app/app_dependencies.go index 774a92d4..ec5d5e47 100644 --- a/backend/internal/api/app/app_dependencies.go +++ b/backend/internal/api/app/app_dependencies.go @@ -367,8 +367,8 @@ func (app 
*App) createHandlers() appHandlers { ) statsHandler := handlers.NewStatsHandler(statsService) notificationHandler := handlers.NewNotificationHandler(notificationAPIService) - nodeHandler := handlers.NewNodeHandler(nodeService) - deploymentHandler := handlers.NewDeploymentHandler(deploymentService) + nodeHandler := handlers.NewNodeHandler(nodeService, app.core.locker) + deploymentHandler := handlers.NewDeploymentHandler(deploymentService, app.core.locker) invoiceHandler := handlers.NewInvoiceHandler(invoiceService) adminHandler := handlers.NewAdminHandler(adminService, app.communication.notificationDispatcher, app.communication.mailService) healthHandler := handlers.NewHealthHandler(app.config.SystemAccount.Network, app.infra.firesquidClient, app.infra.graphql, app.core.db) diff --git a/backend/internal/api/handlers/deployment_handler.go b/backend/internal/api/handlers/deployment_handler.go index c76560ff..f3ec7a10 100644 --- a/backend/internal/api/handlers/deployment_handler.go +++ b/backend/internal/api/handlers/deployment_handler.go @@ -6,18 +6,21 @@ import ( "github.com/gin-gonic/gin" + distributedlocks "kubecloud/internal/core/distributed_locks" "kubecloud/internal/core/models" "kubecloud/internal/core/services" "kubecloud/internal/deployment/kubedeployer" ) type DeploymentHandler struct { - svc services.DeploymentService + svc services.DeploymentService + locker distributedlocks.DistributedLocks } -func NewDeploymentHandler(svc services.DeploymentService) DeploymentHandler { +func NewDeploymentHandler(svc services.DeploymentService, locker distributedlocks.DistributedLocks) DeploymentHandler { return DeploymentHandler{ - svc: svc, + svc: svc, + locker: locker, } } @@ -237,6 +240,21 @@ func (h *DeploymentHandler) HandleDeployCluster(c *gin.Context) { return } + nodeIDs := make([]uint32, len(cluster.Nodes)) + for i, node := range cluster.Nodes { + nodeIDs[i] = node.NodeID + } + + if err = h.locker.AcquireNodesLocks(c.Request.Context(), nodeIDs); err != nil { + 
reqLog.Error().Err(err).Msg("failed to acquire nodes locks") + if errors.Is(err, distributedlocks.ErrNodeLocked) { + Conflict(c, err.Error()) + return + } + InternalServerError(c) + return + } + wfUUID, wfStatus, err := h.svc.AsyncDeployCluster(config, cluster) if err != nil { reqLog.Error().Err(err).Msg("failed to start deployment workflow") @@ -408,6 +426,15 @@ func (h *DeploymentHandler) HandleAddNode(c *gin.Context) { return } } + if err = h.locker.AcquireNodesLocks(c.Request.Context(), []uint32{cluster.Nodes[0].NodeID}); err != nil { + reqLog.Error().Err(err).Msg("failed to acquire nodes locks") + if errors.Is(err, distributedlocks.ErrNodeLocked) { + Conflict(c, err.Error()) + return + } + InternalServerError(c) + return + } wfUUID, wfStatus, err := h.svc.AsyncAddNode(config, cl, cluster.Nodes[0]) if err != nil { diff --git a/backend/internal/api/handlers/node_handler.go b/backend/internal/api/handlers/node_handler.go index 470b1083..897963a5 100644 --- a/backend/internal/api/handlers/node_handler.go +++ b/backend/internal/api/handlers/node_handler.go @@ -3,6 +3,7 @@ package handlers import ( "errors" "fmt" + distributedlocks "kubecloud/internal/core/distributed_locks" "kubecloud/internal/core/models" "math/rand/v2" "net/url" @@ -17,12 +18,14 @@ import ( ) type NodeHandler struct { - svc services.NodeService + svc services.NodeService + locker distributedlocks.DistributedLocks } -func NewNodeHandler(svc services.NodeService) NodeHandler { +func NewNodeHandler(svc services.NodeService, locker distributedlocks.DistributedLocks) NodeHandler { return NodeHandler{ - svc: svc, + svc: svc, + locker: locker, } } @@ -290,6 +293,16 @@ func (h *NodeHandler) ReserveNodeHandler(c *gin.Context) { return } + if err = h.locker.AcquireNodesLocks(c.Request.Context(), []uint32{nodeID}); err != nil { + reqLog.Error().Err(err).Msg("failed to acquire nodes locks") + if errors.Is(err, distributedlocks.ErrNodeLocked) { + Conflict(c, err.Error()) + return + } + InternalServerError(c) 
+ return + } + wfUUID, err := h.svc.AsyncReserveNode(userID, user.Mnemonic, nodeID) if err != nil { reqLog.Error().Err(err).Msg("failed to start workflow to reserve node") diff --git a/backend/internal/core/distributed_locks/distributed_locks.go b/backend/internal/core/distributed_locks/distributed_locks.go index fd190f77..9a9898d0 100644 --- a/backend/internal/core/distributed_locks/distributed_locks.go +++ b/backend/internal/core/distributed_locks/distributed_locks.go @@ -2,11 +2,15 @@ package distributedlocks import ( "context" + "errors" ) +var ErrNodeLocked = errors.New("node is currently locked by another request") + +// DistributedLocks is an interface that defines the methods for distributed locks. type DistributedLocks interface { AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) error - AcquireWorkflowLock(ctx context.Context, nodeID uint32, workflowID string) (bool, error) + AcquireWorkflowLock(ctx context.Context, nodeIDs []uint32, workflowID string) error ReleaseLock(ctx context.Context, nodeID uint32, workflowID string) error GetAllWorkflowsLocks(ctx context.Context) ([]string, error) } diff --git a/backend/internal/core/distributed_locks/redis_locker.go b/backend/internal/core/distributed_locks/redis_locker.go index c41fe1c6..fc6b30cd 100644 --- a/backend/internal/core/distributed_locks/redis_locker.go +++ b/backend/internal/core/distributed_locks/redis_locker.go @@ -3,6 +3,7 @@ package distributedlocks import ( "context" "fmt" + "strconv" "time" "github.com/redis/go-redis/v9" @@ -23,25 +24,62 @@ func NewRedisLocker(client *redis.Client, lockTimeout time.Duration) *RedisLocke // AcquireNodesLocks acquires locks for the given node IDs. func (l *RedisLocker) AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) error { - keys := nodeLockKeys(nodeIDs) + if err := l.acquireKeys(ctx, lockKeys(nodeIDs, nodeLockKey)); err != nil { + return err + } + + return nil +} + +// AcquireWorkflowLock acquires a lock for the given workflow ID. 
+func (l *RedisLocker) AcquireWorkflowLock(ctx context.Context, nodeIDs []uint32, workflowID string) error { + keys := lockKeys(nodeIDs, func(id uint32) string { + return workflowLockKey(id, workflowID) + }) + + if err := l.acquireKeys(ctx, keys); err != nil { + if rollErr := l.rollbackLocks(ctx, keys); rollErr != nil { + return rollErr + } + return err + } + + return nil +} + +func nodeLockKey(nodeID uint32) string { + return fmt.Sprintf("locked:%d", nodeID) +} + +func workflowLockKey(nodeID uint32, workflowID string) string { + return fmt.Sprintf("used:%d:%s", nodeID, workflowID) +} + +func lockKeys(ids []uint32, keyFunc func(uint32) string) []string { + keys := make([]string, len(ids)) + for i, id := range ids { + keys[i] = keyFunc(id) + } + return keys +} + +func (l *RedisLocker) acquireKeys(ctx context.Context, keys []string) error { locked := make([]string, 0, len(keys)) for _, key := range keys { ok, err := l.client.SetNX(ctx, key, 1, l.lockTimeout).Result() if err != nil { - err = l.client.Del(ctx, locked...).Err() - if err != nil { - return fmt.Errorf("redis error while rolling back locks: %w", err) + if rollErr := l.rollbackLocks(ctx, locked); rollErr != nil { + return rollErr } return fmt.Errorf("redis error while acquiring lock for key %s: %w", key, err) } if !ok { - err = l.client.Del(ctx, locked...).Err() - if err != nil { - return fmt.Errorf("redis error while rolling back locks: %w", err) + if rollErr := l.rollbackLocks(ctx, locked); rollErr != nil { + return rollErr } - return fmt.Errorf("failed to acquire lock for key %s: node already locked", key) + return fmt.Errorf("%w: %s", ErrNodeLocked, key) } locked = append(locked, key) @@ -50,13 +88,18 @@ func (l *RedisLocker) AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) e return nil } -// AcquireWorkflowLock acquires a lock for the given workflow ID. 
-func (l *RedisLocker) AcquireWorkflowLock(ctx context.Context, nodeID uint32, workflowID string) (bool, error) { - key := workflowLockKey(nodeID, workflowID) - return l.client.SetNX(ctx, key, 1, l.lockTimeout).Result() +func (l *RedisLocker) rollbackLocks(ctx context.Context, keys []string) error { + if len(keys) == 0 { + return nil + } + + if err := l.client.Del(ctx, keys...).Err(); err != nil { + return fmt.Errorf("redis error while rolling back locks: %w", err) + } + + return nil } -// ReleaseLock releases a lock for the given node ID and workflow ID. func (l *RedisLocker) ReleaseLock(ctx context.Context, nodeID uint32, workflowID string) error { lockedKey := nodeLockKey(nodeID) usedKey := workflowLockKey(nodeID, workflowID) @@ -68,18 +111,18 @@ func (l *RedisLocker) GetAllWorkflowsLocks(ctx context.Context) ([]string, error return l.client.Keys(ctx, "used:*").Result() } -func nodeLockKey(nodeID uint32) string { - return fmt.Sprintf("locked:%d", nodeID) -} - -func nodeLockKeys(nodeIDs []uint32) []string { - keys := make([]string, len(nodeIDs)) - for i, id := range nodeIDs { - keys[i] = nodeLockKey(id) +func (l *RedisLocker) GetLockedNodes(ctx context.Context) ([]uint32, error) { + keys, err := l.client.Keys(ctx, "locked:*").Result() + if err != nil { + return nil, err } - return keys -} - -func workflowLockKey(nodeID uint32, workflowID string) string { - return fmt.Sprintf("used:%d:%s", nodeID, workflowID) + nodes := make([]uint32, len(keys)) + for i, key := range keys { + value, parseErr := strconv.ParseUint(key[len("locked:"):], 10, 32) + if parseErr != nil { + return nil, fmt.Errorf("failed to parse locked node id from %s: %w", key, parseErr) + } + nodes[i] = uint32(value) + } + return nodes, nil } diff --git a/backend/internal/core/distributed_locks/redis_locker_test.go b/backend/internal/core/distributed_locks/redis_locker_test.go index 4a05cc62..aed9f3c5 100644 --- a/backend/internal/core/distributed_locks/redis_locker_test.go +++ 
b/backend/internal/core/distributed_locks/redis_locker_test.go @@ -55,13 +55,11 @@ func TestRedisLocker_AcquireWorkflowLock(t *testing.T) { client := newTestRedisClient(t) locker := &RedisLocker{client: client, lockTimeout: time.Minute} - ok, err := locker.AcquireWorkflowLock(context.Background(), 1, "wf-1") + err := locker.AcquireWorkflowLock(context.Background(), []uint32{1}, "wf-1") require.NoError(t, err) - require.True(t, ok) - ok, err = locker.AcquireWorkflowLock(context.Background(), 1, "wf-1") - require.NoError(t, err) - require.False(t, ok) + err = locker.AcquireWorkflowLock(context.Background(), []uint32{1}, "wf-1") + require.Error(t, err) } func TestRedisLocker_ReleaseLock(t *testing.T) { diff --git a/backend/internal/core/services/deployment_service.go b/backend/internal/core/services/deployment_service.go index d4372bef..8a88305e 100644 --- a/backend/internal/core/services/deployment_service.go +++ b/backend/internal/core/services/deployment_service.go @@ -210,7 +210,7 @@ func (svc *DeploymentService) runWithQueue(queueName string, wf *ewf.Workflow) e return svc.ewfEngine.Run(svc.appCtx, *wf) } -func (svc *DeploymentService) handleDeploymentAction(userID int, workflowName string, state ewf.State, displayName string, metadata map[string]string) (string, ewf.WorkflowStatus, error) { +func (svc *DeploymentService) handleDeploymentAction(userID int, workflowName string, state ewf.State, displayName string, metadata map[string]string, nodeIDs []uint32) (workflowID string, status ewf.WorkflowStatus, err error) { _, span := svc.tracer.StartSpan(context.Background(), "handleDeploymentAction") defer span.End() @@ -235,6 +235,14 @@ func (svc *DeploymentService) handleDeploymentAction(userID int, workflowName st return "", "", err } + if len(nodeIDs) > 0 { + + if err = svc.locker.AcquireWorkflowLock(svc.appCtx, nodeIDs, wf.UUID); err != nil { + return "", "", err + } + + } + if err = svc.runWithQueue(queueName, &wf); err != nil { telemetry.RecordError(span, err) 
return "", "", err @@ -252,10 +260,6 @@ func (svc *DeploymentService) AsyncDeployCluster(config statemanager.ClientConfi for _, node := range cluster.Nodes { nodeIDs = append(nodeIDs, node.NodeID) } - if err := svc.locker.AcquireNodesLocks(svc.appCtx, nodeIDs); err != nil { - return "", "", err - } - state := ewf.State{ "config": config, "cluster": cluster, @@ -266,7 +270,7 @@ func (svc *DeploymentService) AsyncDeployCluster(config statemanager.ClientConfi "cluster_name": cluster.Name, "node_count": strconv.Itoa(len(cluster.Nodes)), } - return svc.handleDeploymentAction(config.UserID, workflows.WorkflowDeployCluster, state, displayName, metadata) + return svc.handleDeploymentAction(config.UserID, workflows.WorkflowDeployCluster, state, displayName, metadata, nodeIDs) } func (svc *DeploymentService) AsyncDeleteCluster(config statemanager.ClientConfig, projectName string) (string, ewf.WorkflowStatus, error) { @@ -280,7 +284,7 @@ func (svc *DeploymentService) AsyncDeleteCluster(config statemanager.ClientConfi metadata := map[string]string{ "project_name": projectName, } - return svc.handleDeploymentAction(config.UserID, workflows.WorkflowDeleteCluster, state, displayName, metadata) + return svc.handleDeploymentAction(config.UserID, workflows.WorkflowDeleteCluster, state, displayName, metadata, nil) } func (svc *DeploymentService) AsyncDeleteAllClusters(config statemanager.ClientConfig) (string, ewf.WorkflowStatus, error) { @@ -290,15 +294,11 @@ func (svc *DeploymentService) AsyncDeleteAllClusters(config statemanager.ClientC } displayName := "Deleting all user clusters" - return svc.handleDeploymentAction(config.UserID, workflows.WorkflowDeleteAllClusters, state, displayName, nil) + return svc.handleDeploymentAction(config.UserID, workflows.WorkflowDeleteAllClusters, state, displayName, nil, nil) } func (svc *DeploymentService) AsyncAddNode(config statemanager.ClientConfig, cl kubedeployer.Cluster, node kubedeployer.Node) (string, ewf.WorkflowStatus, error) { - if err 
:= svc.locker.AcquireNodesLocks(svc.appCtx, []uint32{node.NodeID}); err != nil { - return "", "", err - } - state := ewf.State{ "config": config, "cluster": cl, @@ -309,7 +309,7 @@ func (svc *DeploymentService) AsyncAddNode(config statemanager.ClientConfig, cl "cluster_name": cl.Name, "node_name": node.Name, } - return svc.handleDeploymentAction(config.UserID, workflows.WorkflowAddNode, state, displayName, metadata) + return svc.handleDeploymentAction(config.UserID, workflows.WorkflowAddNode, state, displayName, metadata, []uint32{node.NodeID}) } func (svc *DeploymentService) AsyncRemoveNode(config statemanager.ClientConfig, cl kubedeployer.Cluster, nodeName string) (string, ewf.WorkflowStatus, error) { @@ -325,5 +325,5 @@ func (svc *DeploymentService) AsyncRemoveNode(config statemanager.ClientConfig, "cluster_name": cl.Name, "node_name": nodeName, } - return svc.handleDeploymentAction(config.UserID, workflows.WorkflowRemoveNode, state, displayName, metadata) + return svc.handleDeploymentAction(config.UserID, workflows.WorkflowRemoveNode, state, displayName, metadata, nil) } diff --git a/backend/internal/core/services/node_service.go b/backend/internal/core/services/node_service.go index 0dd4729e..cb0c0e7e 100644 --- a/backend/internal/core/services/node_service.go +++ b/backend/internal/core/services/node_service.go @@ -218,10 +218,6 @@ func (svc *NodeService) GetRentedNodesForUser(ctx context.Context, userID int, h } func (svc *NodeService) AsyncReserveNode(userID int, userMnemonic string, nodeID uint32) (string, error) { - if err := svc.locker.AcquireNodesLocks(svc.appCtx, []uint32{nodeID}); err != nil { - return "", err - } - queueName := fmt.Sprintf("%s:user_%d", cfg.DefaultQueueConfig.Name, userID) displayName := fmt.Sprintf("Reserving node %d", nodeID) metadata := map[string]string{ @@ -246,6 +242,10 @@ func (svc *NodeService) AsyncReserveNode(userID int, userMnemonic string, nodeID return "", err } + if err = svc.locker.AcquireWorkflowLock(svc.appCtx, 
[]uint32{nodeID}, wf.UUID); err != nil { + return "", err + } + if err = svc.runWithQueue(queueName, &wf); err != nil { return "", err } From 7d64b075fbe7b82ca3485f3b2613c794613214fb Mon Sep 17 00:00:00 2001 From: SalmaElsoly Date: Thu, 27 Nov 2025 16:39:13 +0200 Subject: [PATCH 03/10] feat: implement filtering of locked nodes in NodeService and node handlers --- backend/internal/api/handlers/node_handler.go | 31 +++++-- .../distributed_locks/distributed_locks.go | 1 + .../core/distributed_locks/redis_locker.go | 4 +- .../internal/core/services/node_service.go | 19 ++++ .../core/services/node_service_test.go | 92 +++++++++++++++++++ 5 files changed, 136 insertions(+), 11 deletions(-) diff --git a/backend/internal/api/handlers/node_handler.go b/backend/internal/api/handlers/node_handler.go index 897963a5..9db99878 100644 --- a/backend/internal/api/handlers/node_handler.go +++ b/backend/internal/api/handlers/node_handler.go @@ -136,7 +136,7 @@ func (h *NodeHandler) ListNodesHandler(c *gin.Context) { userID := c.GetInt("user_id") reqLog := requestLogger(c, "ListNodesHandler") - rentedNodes, rentedNodesCount, err := h.svc.GetRentedNodesForUser(c.Request.Context(), userID, true) + rentedNodes, _, err := h.svc.GetRentedNodesForUser(c.Request.Context(), userID, true) if err != nil { reqLog.Error().Err(err).Msg("failed to retrieve rented nodes") InternalServerError(c) @@ -175,7 +175,7 @@ func (h *NodeHandler) ListNodesHandler(c *gin.Context) { filter.Healthy = &healthy filter.AvailableFor = &twinID - availableNodes, availableNodesCount, err := h.svc.GetZos3Nodes(c.Request.Context(), filter, limit) + availableNodes, _, err := h.svc.GetZos3Nodes(c.Request.Context(), filter, limit) if err != nil { reqLog.Error().Err(err).Msg("failed to retrieve available nodes") InternalServerError(c) @@ -188,7 +188,6 @@ // Combine all nodes without duplicates var allNodes []proxyTypes.Node - duplicatesCount := 0 seen := 
make(map[int]bool) for _, node := range rentedNodes { @@ -202,14 +201,19 @@ func (h *NodeHandler) ListNodesHandler(c *gin.Context) { if !seen[node.NodeID] { seen[node.NodeID] = true allNodes = append(allNodes, node) - } else { - duplicatesCount++ } } + unlockedNodes, err := h.svc.FilterLockedNodes(c.Request.Context(), allNodes) + if err != nil { + reqLog.Error().Err(err).Msg("failed to filter locked nodes") + InternalServerError(c) + return + } + OK(c, "Nodes retrieved successfully", ListNodesResponse{ - Total: rentedNodesCount + availableNodesCount - duplicatesCount, - Nodes: allNodes, + Total: len(unlockedNodes), + Nodes: unlockedNodes, }) } @@ -338,15 +342,22 @@ func (h *NodeHandler) ListRentableNodesHandler(c *gin.Context) { limit := proxyTypes.DefaultLimit() limit.Randomize = true - nodes, count, err := h.svc.GetZos3Nodes(c.Request.Context(), filter, limit) + nodes, _, err := h.svc.GetZos3Nodes(c.Request.Context(), filter, limit) if err != nil { reqLog.Error().Err(err).Msg("failed to retrieve nodes") InternalServerError(c) return } + unlockedNodes, err := h.svc.FilterLockedNodes(c.Request.Context(), nodes) + if err != nil { + reqLog.Error().Err(err).Msg("failed to filter locked nodes") + InternalServerError(c) + return + } + var nodesWithDiscount []NodesWithDiscount - for _, node := range nodes { + for _, node := range unlockedNodes { nodesWithDiscount = append(nodesWithDiscount, NodesWithDiscount{ Node: node, DiscountPrice: node.PriceUsd * 0.5, @@ -354,7 +365,7 @@ func (h *NodeHandler) ListRentableNodesHandler(c *gin.Context) { } OK(c, "Nodes are retrieved successfully", ListNodesWithDiscountResponse{ - Total: count, + Total: len(nodesWithDiscount), Nodes: nodesWithDiscount, }) } diff --git a/backend/internal/core/distributed_locks/distributed_locks.go b/backend/internal/core/distributed_locks/distributed_locks.go index 9a9898d0..f0023456 100644 --- a/backend/internal/core/distributed_locks/distributed_locks.go +++ 
b/backend/internal/core/distributed_locks/distributed_locks.go @@ -13,4 +13,5 @@ type DistributedLocks interface { AcquireWorkflowLock(ctx context.Context, nodeIDs []uint32, workflowID string) error ReleaseLock(ctx context.Context, nodeID uint32, workflowID string) error GetAllWorkflowsLocks(ctx context.Context) ([]string, error) + GetLockedNodes(ctx context.Context) ([]uint32, error) } diff --git a/backend/internal/core/distributed_locks/redis_locker.go b/backend/internal/core/distributed_locks/redis_locker.go index fc6b30cd..7e13ad94 100644 --- a/backend/internal/core/distributed_locks/redis_locker.go +++ b/backend/internal/core/distributed_locks/redis_locker.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strconv" + "strings" "time" "github.com/redis/go-redis/v9" @@ -118,7 +119,8 @@ func (l *RedisLocker) GetLockedNodes(ctx context.Context) ([]uint32, error) { } nodes := make([]uint32, len(keys)) for i, key := range keys { - value, parseErr := strconv.ParseUint(key[len("locked:"):], 10, 32) + nodeID := strings.Split(key, ":")[1] + value, parseErr := strconv.ParseUint(nodeID, 10, 32) if parseErr != nil { return nil, fmt.Errorf("failed to parse locked node id from %s: %w", key, parseErr) } diff --git a/backend/internal/core/services/node_service.go b/backend/internal/core/services/node_service.go index cb0c0e7e..1e12814b 100644 --- a/backend/internal/core/services/node_service.go +++ b/backend/internal/core/services/node_service.go @@ -303,3 +303,22 @@ func (svc *NodeService) runWithQueue(queueName string, wf *ewf.Workflow) error { return svc.ewfEngine.Run(svc.appCtx, *wf) } + +func (svc *NodeService) FilterLockedNodes(ctx context.Context, nodes []proxyTypes.Node) ([]proxyTypes.Node, error) { + lockedNodes, err := svc.locker.GetLockedNodes(ctx) + if err != nil { + return nil, err + } + lockedSet := make(map[uint32]bool, len(lockedNodes)) + for _, id := range lockedNodes { + lockedSet[id] = true + } + unlockedNodes := make([]proxyTypes.Node, 0, len(nodes)) + for _, 
node := range nodes { + if lockedSet[uint32(node.NodeID)] { + continue + } + unlockedNodes = append(unlockedNodes, node) + } + return unlockedNodes, nil +} diff --git a/backend/internal/core/services/node_service_test.go b/backend/internal/core/services/node_service_test.go index 7eecc344..50df0946 100644 --- a/backend/internal/core/services/node_service_test.go +++ b/backend/internal/core/services/node_service_test.go @@ -1,15 +1,54 @@ package services import ( + "context" "fmt" "testing" "kubecloud/internal/core/models" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + proxyTypes "github.com/threefoldtech/tfgrid-sdk-go/grid-proxy/pkg/types" ) +type MockDistributedLocks struct { + mock.Mock +} + +func (m *MockDistributedLocks) GetLockedNodes(ctx context.Context) ([]uint32, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]uint32), args.Error(1) +} + +func (m *MockDistributedLocks) AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) error { + + args := m.Called(ctx, nodeIDs) + return args.Error(0) +} + +func (m *MockDistributedLocks) AcquireWorkflowLock(ctx context.Context, nodeIDs []uint32, workflowID string) error { + args := m.Called(ctx, nodeIDs, workflowID) + return args.Error(0) +} + +func (m *MockDistributedLocks) ReleaseLock(ctx context.Context, nodeID uint32, workflowID string) error { + args := m.Called(ctx, nodeID, workflowID) + return args.Error(0) +} + +func (m *MockDistributedLocks) GetAllWorkflowsLocks(ctx context.Context) ([]string, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]string), args.Error(1) +} + // Test 1: NodeService - GetUserNodeByNodeID SUCCESS func TestNodeService_GetUserNodeByNodeID_Success(t *testing.T) { mockNodesRepo := new(mockUserNodesRepo) @@ -137,3 +176,56 @@ func TestNodeService_GetUserByID_NotFound(t *testing.T) { 
require.Error(t, err) assert.Contains(t, err.Error(), "user not found") } + +// Test 7: NodeService - FilterLockedNodes SUCCESS +func TestNodeService_FilterLockedNodes_Success(t *testing.T) { + mockLocker := new(MockDistributedLocks) + mockNodesRepo := new(mockUserNodesRepo) + mockUserRepo := new(mockUserRepo) + + nodes := []proxyTypes.Node{ + {NodeID: 100}, + {NodeID: 101}, + {NodeID: 102}, + } + + mockLocker.On("GetLockedNodes", context.Background()).Return([]uint32{101}, nil) + + service := NodeService{ + locker: mockLocker, + nodesRepo: mockNodesRepo, + userRepo: mockUserRepo, + } + + unlockedNodes, err := service.FilterLockedNodes(context.Background(), nodes) + + require.NoError(t, err) + assert.Equal(t, []proxyTypes.Node{nodes[0], nodes[2]}, unlockedNodes) +} + +// Test 8: NodeService - FilterLockedNodes ERROR +func TestNodeService_FilterLockedNodes_Error(t *testing.T) { + mockLocker := new(MockDistributedLocks) + mockNodesRepo := new(mockUserNodesRepo) + mockUserRepo := new(mockUserRepo) + + mockLocker.On("GetLockedNodes", context.Background()).Return(nil, fmt.Errorf("locked nodes error")) + + nodes := []proxyTypes.Node{ + {NodeID: 100}, + {NodeID: 101}, + {NodeID: 102}, + } + + service := NodeService{ + locker: mockLocker, + nodesRepo: mockNodesRepo, + userRepo: mockUserRepo, + } + + unlockedNodes, err := service.FilterLockedNodes(context.Background(), nodes) + + require.Error(t, err) + assert.Contains(t, err.Error(), "locked nodes error") + assert.Empty(t, unlockedNodes) +} From 9cd67981082d79180ed67c4439f59b7604cb9933 Mon Sep 17 00:00:00 2001 From: SalmaElsoly Date: Thu, 27 Nov 2025 16:45:57 +0200 Subject: [PATCH 04/10] refactor: update configuration structure for locks and adjust related bindings --- backend/cmd/root.go | 8 +++---- backend/internal/api/app/app_dependencies.go | 8 +++---- backend/internal/config/config.go | 25 ++++++++++---------- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/backend/cmd/root.go b/backend/cmd/root.go 
index ba1fe789..bd27fcb8 100644 --- a/backend/cmd/root.go +++ b/backend/cmd/root.go @@ -211,11 +211,11 @@ func addFlags() error { return fmt.Errorf("failed to bind telemetry.otlp_endpoint flag: %w", err) } // === Locks === - if err := bindIntFlag(rootCmd, "redis.lock_timeout_in_hours", 24, "Redis lock timeout (hours)"); err != nil { - return fmt.Errorf("failed to bind redis.lock_timeout_in_hours flag: %w", err) + if err := bindIntFlag(rootCmd, "locks.lock_timeout_in_hours", 24, "Redis lock timeout (hours)"); err != nil { + return fmt.Errorf("failed to bind locks.lock_timeout_in_hours flag: %w", err) } - if err := bindIntFlag(rootCmd, "locks_release_interval_in_minutes", 5, "Locks release interval (minutes)"); err != nil { - return fmt.Errorf("failed to bind locks_release_interval_in_minutes flag: %w", err) + if err := bindIntFlag(rootCmd, "locks.locks_release_interval_in_minutes", 5, "Locks release interval (minutes)"); err != nil { + return fmt.Errorf("failed to bind locks.locks_release_interval_in_minutes flag: %w", err) } return nil diff --git a/backend/internal/api/app/app_dependencies.go b/backend/internal/api/app/app_dependencies.go index ec5d5e47..a85bdb6a 100644 --- a/backend/internal/api/app/app_dependencies.go +++ b/backend/internal/api/app/app_dependencies.go @@ -67,7 +67,7 @@ type appCore struct { db models.DB metrics *metrics.Metrics ewfEngine *ewf.Engine - locker distributedlocks.DistributedLocks + locker distributedlocks.DistributedLocks tracerProvider *telemetry.TracerProvider } @@ -169,14 +169,14 @@ func createAppCore(ctx context.Context, config cfg.Configuration) (appCore, erro return appCore{}, fmt.Errorf("failed to init workflow engine: %w", err) } - locker := distributedlocks.NewRedisLocker(client, time.Duration(config.Redis.LockTimeoutInHours)*time.Hour) + locker := distributedlocks.NewRedisLocker(client, time.Duration(config.Locks.LockTimeoutInHours)*time.Hour) return appCore{ appCtx: ctx, db: db, metrics: metrics.NewMetrics(), ewfEngine: 
ewfEngine, - locker: locker, + locker: locker, tracerProvider: tp, }, nil } @@ -401,7 +401,7 @@ func (app *App) createWorkers() workers.Workers { app.config.Invoice, app.config.SystemAccount.Mnemonic, app.config.Currency, app.config.ClusterHealthCheckIntervalInHours, app.config.NodeHealthCheck.ReservedNodeHealthCheckIntervalInHours, app.config.NodeHealthCheck.ReservedNodeHealthCheckTimeoutInMinutes, app.config.NodeHealthCheck.ReservedNodeHealthCheckWorkersNum, app.config.MonitorBalanceIntervalInMinutes, app.config.NotifyAdminsForPendingRecordsInHours, app.config.UsersBalanceCheckIntervalInHours, app.config.CheckUserDebtIntervalInHours, - app.config.LocksReleaseIntervalInMinutes, + app.config.Locks.LocksReleaseIntervalInMinutes, app.core.locker, ) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index b535e4a5..427c28ee 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -34,7 +34,7 @@ type Configuration struct { DevMode bool `json:"dev_mode"` // When true, allows empty SendGridKey and uses FakeMailService MonitorBalanceIntervalInMinutes int `json:"monitor_balance_interval_in_minutes" validate:"required,gt=0"` NotifyAdminsForPendingRecordsInHours int `json:"notify_admins_for_pending_records_in_hours" validate:"required,gt=0"` - LocksReleaseIntervalInMinutes int `json:"locks_release_interval_in_minutes" validate:"required,gt=0" default:"5"` + Locks LocksConfig `json:"locks" validate:"required,dive"` ClusterHealthCheckIntervalInHours int `json:"cluster_health_check_interval_in_hours" validate:"gt=0" default:"1"` UsersBalanceCheckIntervalInHours int `json:"users_balance_check_interval_in_hours" validate:"gt=0" default:"6"` CheckUserDebtIntervalInHours int `json:"check_user_debt_interval_in_hours" validate:"gt=0" default:"48"` @@ -51,11 +51,10 @@ type SSHConfig struct { } type RedisConfig struct { - Hostname string `json:"hostname" validate:"hostname|ip|url"` - Port int `json:"port" 
validate:"min=1,max=65535"` - Password string `json:"password"` - DB int `json:"db" validate:"min=0"` - LockTimeoutInHours int `json:"lock_timeout_in_hours" validate:"required,min=1" default:"1"` + Hostname string `json:"hostname" validate:"hostname|ip|url"` + Port int `json:"port" validate:"min=1,max=65535"` + Password string `json:"password"` + DB int `json:"db" validate:"min=0"` } // Server struct holds server's information @@ -134,6 +133,10 @@ type TelemetryConfig struct { OTLPEndpoint string `json:"otlp_endpoint" default:"jaeger:4317"` // gRPC endpoint for OTLP exporter } +type LocksConfig struct { + LockTimeoutInHours int `json:"lock_timeout_in_hours" validate:"required,gt=0" default:"24"` + LocksReleaseIntervalInMinutes int `json:"locks_release_interval_in_minutes" validate:"required,gt=0" default:"5"` +} type ReservedNodeHealthCheckConfig struct { ReservedNodeHealthCheckIntervalInHours int `json:"reserved_node_health_check_interval_in_hours" validate:"required,gt=0" default:"1"` ReservedNodeHealthCheckTimeoutInMinutes int `json:"reserved_node_health_check_timeout_in_minutes" validate:"required,gt=0" default:"1"` @@ -392,13 +395,11 @@ func applyDefaultValues(config *Configuration) { if config.NotifyAdminsForPendingRecordsInHours == 0 { config.NotifyAdminsForPendingRecordsInHours = 24 } - - if config.Redis.LockTimeoutInHours == 0 { - config.Redis.LockTimeoutInHours = 24 + if config.Locks.LockTimeoutInHours == 0 { + config.Locks.LockTimeoutInHours = 24 } - - if config.LocksReleaseIntervalInMinutes == 0 { - config.LocksReleaseIntervalInMinutes = 5 + if config.Locks.LocksReleaseIntervalInMinutes == 0 { + config.Locks.LocksReleaseIntervalInMinutes = 5 } if config.Telemetry.OTLPEndpoint == "" { From f273ad2926b4f225cf1914cb9254566726b8528e Mon Sep 17 00:00:00 2001 From: SalmaElsoly Date: Thu, 27 Nov 2025 17:44:32 +0200 Subject: [PATCH 05/10] refactor: update lock release worker to handle all nodes of workflow once --- backend/cmd/root.go | 2 +- 
backend/internal/api/app/app.go | 1 + .../api/handlers/deployment_handler.go | 8 +++ backend/internal/api/handlers/node_handler.go | 4 ++ backend/internal/config/config.go | 2 +- .../distributed_locks/distributed_locks.go | 2 +- .../core/distributed_locks/redis_locker.go | 11 ++-- .../distributed_locks/redis_locker_test.go | 2 +- .../core/services/node_service_test.go | 4 +- .../internal/core/services/workers_service.go | 51 +++++++++++-------- .../internal/core/workers/locks_releaser.go | 7 +-- 11 files changed, 57 insertions(+), 37 deletions(-) diff --git a/backend/cmd/root.go b/backend/cmd/root.go index bd27fcb8..2b070685 100644 --- a/backend/cmd/root.go +++ b/backend/cmd/root.go @@ -214,7 +214,7 @@ func addFlags() error { if err := bindIntFlag(rootCmd, "locks.lock_timeout_in_hours", 24, "Redis lock timeout (hours)"); err != nil { return fmt.Errorf("failed to bind locks.lock_timeout_in_hours flag: %w", err) } - if err := bindIntFlag(rootCmd, "locks.locks_release_interval_in_minutes", 5, "Locks release interval (minutes)"); err != nil { + if err := bindIntFlag(rootCmd, "locks.locks_release_interval_in_minutes", 2, "Locks release interval (minutes)"); err != nil { return fmt.Errorf("failed to bind locks.locks_release_interval_in_minutes flag: %w", err) } diff --git a/backend/internal/api/app/app.go b/backend/internal/api/app/app.go index 56bdc34c..5b05d37e 100644 --- a/backend/internal/api/app/app.go +++ b/backend/internal/api/app/app.go @@ -228,6 +228,7 @@ func (app *App) StartBackgroundWorkers() { go app.workers.TrackReservedNodeHealth() go app.workers.CollectGORMMetrics() go app.workers.CollectGoRuntimeMetrics() + go app.workers.ReleaseWorkflowLocks() } // Run starts the server diff --git a/backend/internal/api/handlers/deployment_handler.go b/backend/internal/api/handlers/deployment_handler.go index f3ec7a10..60999188 100644 --- a/backend/internal/api/handlers/deployment_handler.go +++ b/backend/internal/api/handlers/deployment_handler.go @@ -258,6 +258,10 @@ 
func (h *DeploymentHandler) HandleDeployCluster(c *gin.Context) { wfUUID, wfStatus, err := h.svc.AsyncDeployCluster(config, cluster) if err != nil { reqLog.Error().Err(err).Msg("failed to start deployment workflow") + err = h.locker.ReleaseLock(c.Request.Context(), nodeIDs, wfUUID) + if err != nil { + reqLog.Error().Err(err).Msg("failed to release nodes locks") + } InternalServerError(c) return } @@ -439,6 +443,10 @@ func (h *DeploymentHandler) HandleAddNode(c *gin.Context) { wfUUID, wfStatus, err := h.svc.AsyncAddNode(config, cl, cluster.Nodes[0]) if err != nil { reqLog.Error().Err(err).Msg("failed to start add node workflow") + err = h.locker.ReleaseLock(c.Request.Context(), []uint32{cluster.Nodes[0].NodeID}, wfUUID) + if err != nil { + reqLog.Error().Err(err).Msg("failed to release nodes locks") + } InternalServerError(c) return } diff --git a/backend/internal/api/handlers/node_handler.go b/backend/internal/api/handlers/node_handler.go index 9db99878..e53f9acc 100644 --- a/backend/internal/api/handlers/node_handler.go +++ b/backend/internal/api/handlers/node_handler.go @@ -310,6 +310,10 @@ func (h *NodeHandler) ReserveNodeHandler(c *gin.Context) { wfUUID, err := h.svc.AsyncReserveNode(userID, user.Mnemonic, nodeID) if err != nil { reqLog.Error().Err(err).Msg("failed to start workflow to reserve node") + err = h.locker.ReleaseLock(c.Request.Context(), []uint32{nodeID}, wfUUID) + if err != nil { + reqLog.Error().Err(err).Msg("failed to release nodes locks") + } InternalServerError(c) return } diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 427c28ee..7bdc76c5 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -399,7 +399,7 @@ func applyDefaultValues(config *Configuration) { config.Locks.LockTimeoutInHours = 24 } if config.Locks.LocksReleaseIntervalInMinutes == 0 { - config.Locks.LocksReleaseIntervalInMinutes = 5 + config.Locks.LocksReleaseIntervalInMinutes = 2 } if 
config.Telemetry.OTLPEndpoint == "" { diff --git a/backend/internal/core/distributed_locks/distributed_locks.go b/backend/internal/core/distributed_locks/distributed_locks.go index f0023456..fa2b47b5 100644 --- a/backend/internal/core/distributed_locks/distributed_locks.go +++ b/backend/internal/core/distributed_locks/distributed_locks.go @@ -11,7 +11,7 @@ var ErrNodeLocked = errors.New("node is currently locked by another request") type DistributedLocks interface { AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) error AcquireWorkflowLock(ctx context.Context, nodeIDs []uint32, workflowID string) error - ReleaseLock(ctx context.Context, nodeID uint32, workflowID string) error + ReleaseLock(ctx context.Context, nodeIDs []uint32, workflowID string) error GetAllWorkflowsLocks(ctx context.Context) ([]string, error) GetLockedNodes(ctx context.Context) ([]uint32, error) } diff --git a/backend/internal/core/distributed_locks/redis_locker.go b/backend/internal/core/distributed_locks/redis_locker.go index 7e13ad94..8a41ba09 100644 --- a/backend/internal/core/distributed_locks/redis_locker.go +++ b/backend/internal/core/distributed_locks/redis_locker.go @@ -101,10 +101,13 @@ func (l *RedisLocker) rollbackLocks(ctx context.Context, keys []string) error { return nil } -func (l *RedisLocker) ReleaseLock(ctx context.Context, nodeID uint32, workflowID string) error { - lockedKey := nodeLockKey(nodeID) - usedKey := workflowLockKey(nodeID, workflowID) - return l.client.Del(ctx, lockedKey, usedKey).Err() +func (l *RedisLocker) ReleaseLock(ctx context.Context, nodeIDs []uint32, workflowID string) error { + lockedKeys := lockKeys(nodeIDs, nodeLockKey) + usedKeys := lockKeys(nodeIDs, func(id uint32) string { + return workflowLockKey(id, workflowID) + }) + allWorkflowsLocks := append(lockedKeys, usedKeys...) + return l.client.Del(ctx, allWorkflowsLocks...).Err() } // GetAllWorkflowsLocks gets all workflow locks. 
diff --git a/backend/internal/core/distributed_locks/redis_locker_test.go b/backend/internal/core/distributed_locks/redis_locker_test.go index aed9f3c5..58944144 100644 --- a/backend/internal/core/distributed_locks/redis_locker_test.go +++ b/backend/internal/core/distributed_locks/redis_locker_test.go @@ -69,7 +69,7 @@ func TestRedisLocker_ReleaseLock(t *testing.T) { require.NoError(t, client.Set(context.Background(), "locked:1", 1, 0).Err()) require.NoError(t, client.Set(context.Background(), "used:1:wf-1", 1, 0).Err()) - err := locker.ReleaseLock(context.Background(), 1, "wf-1") + err := locker.ReleaseLock(context.Background(), []uint32{1}, "wf-1") require.NoError(t, err) require.Equal(t, int64(0), client.Exists(context.Background(), "locked:1").Val()) diff --git a/backend/internal/core/services/node_service_test.go b/backend/internal/core/services/node_service_test.go index 50df0946..f55c6fa2 100644 --- a/backend/internal/core/services/node_service_test.go +++ b/backend/internal/core/services/node_service_test.go @@ -36,8 +36,8 @@ func (m *MockDistributedLocks) AcquireWorkflowLock(ctx context.Context, nodeIDs return args.Error(0) } -func (m *MockDistributedLocks) ReleaseLock(ctx context.Context, nodeID uint32, workflowID string) error { - args := m.Called(ctx, nodeID, workflowID) +func (m *MockDistributedLocks) ReleaseLock(ctx context.Context, nodeIDs []uint32, workflowID string) error { + args := m.Called(ctx, nodeIDs, workflowID) return args.Error(0) } diff --git a/backend/internal/core/services/workers_service.go b/backend/internal/core/services/workers_service.go index 817a4c8e..64c4e8e0 100644 --- a/backend/internal/core/services/workers_service.go +++ b/backend/internal/core/services/workers_service.go @@ -491,32 +491,41 @@ func (svc WorkerService) GetAllWorkflowsLocks() ([]string, error) { return svc.locker.GetAllWorkflowsLocks(svc.ctx) } -func (svc WorkerService) ReleaseLocks(key string) error { - keyParts := strings.Split(key, ":") - if len(keyParts) != 
3 { - return fmt.Errorf("invalid lock key format: %s", key) - } - - nodeID, err := strconv.ParseUint(keyParts[1], 10, 32) - if err != nil { - return fmt.Errorf("invalid node ID: %w", err) - } - - workflowID := keyParts[2] - workflow, err := svc.ewfEngine.Store().LoadWorkflowByUUID(svc.ctx, workflowID) - if err != nil { - return fmt.Errorf("failed to load workflow by UUID: %w", err) - } +func (svc WorkerService) ReleaseLocks(keys []string) { + log := logger.ForOperation("locks_worker", "release_locks") + workflowsNodes := map[string][]uint32{} + for _, key := range keys { + parts := strings.Split(key, ":") + if len(parts) != 3 { + log.Error().Str("key", key).Msg("invalid lock key format") + continue + } - if !slices.Contains([]ewf.WorkflowStatus{ewf.StatusCompleted, ewf.StatusFailed}, workflow.Status) { - return nil + workflowID := parts[2] + nodeID, err := strconv.ParseUint(parts[1], 10, 32) + if err != nil { + log.Error().Str("key", key).Msg("invalid node ID") + continue + } + workflowsNodes[workflowID] = append(workflowsNodes[workflowID], uint32(nodeID)) } - if err := svc.locker.ReleaseLock(svc.ctx, uint32(nodeID), workflowID); err != nil { - return fmt.Errorf("failed to release nodes locks: %w", err) + for workflowID := range workflowsNodes { + workflow, err := svc.ewfEngine.Store().LoadWorkflowByUUID(svc.ctx, workflowID) + if err != nil { + log.Error().Str("workflow_id", workflowID).Msg("failed to load workflow") + continue + } + if !slices.Contains([]ewf.WorkflowStatus{ewf.StatusCompleted, ewf.StatusFailed}, workflow.Status) { + continue + } + nodeIDs := workflowsNodes[workflowID] + if err := svc.locker.ReleaseLock(svc.ctx, nodeIDs, workflowID); err != nil { + log.Error().Str("workflow_id", workflow.UUID).Msg("failed to release locks") + continue + } } - return nil } func (svc WorkerService) checkUserDebt(user models.User, contractIDs []uint64) error { diff --git a/backend/internal/core/workers/locks_releaser.go 
b/backend/internal/core/workers/locks_releaser.go index 8368f3ff..8d0d6392 100644 --- a/backend/internal/core/workers/locks_releaser.go +++ b/backend/internal/core/workers/locks_releaser.go @@ -25,12 +25,7 @@ func (w Workers) ReleaseWorkflowLocks() { if len(keys) == 0 { continue } - - for _, key := range keys { - if err := w.svc.ReleaseLocks(key); err != nil { - log.Error().Err(err).Str("key", key).Msg("failed to release lock") - } - } + w.svc.ReleaseLocks(keys) } } } From 4e93b60c990dcb2408684a9c98b7a2e868ca47b7 Mon Sep 17 00:00:00 2001 From: SalmaElsoly Date: Thu, 27 Nov 2025 18:08:34 +0200 Subject: [PATCH 06/10] fix: test + late release lock in handler --- backend/internal/api/handlers/deployment_handler.go | 8 -------- backend/internal/api/handlers/node_handler.go | 4 ---- backend/internal/core/distributed_locks/redis_locker.go | 4 +++- .../internal/core/distributed_locks/redis_locker_test.go | 3 ++- backend/internal/core/services/deployment_service.go | 5 +++++ backend/internal/core/services/node_service.go | 3 +++ 6 files changed, 13 insertions(+), 14 deletions(-) diff --git a/backend/internal/api/handlers/deployment_handler.go b/backend/internal/api/handlers/deployment_handler.go index 60999188..f3ec7a10 100644 --- a/backend/internal/api/handlers/deployment_handler.go +++ b/backend/internal/api/handlers/deployment_handler.go @@ -258,10 +258,6 @@ func (h *DeploymentHandler) HandleDeployCluster(c *gin.Context) { wfUUID, wfStatus, err := h.svc.AsyncDeployCluster(config, cluster) if err != nil { reqLog.Error().Err(err).Msg("failed to start deployment workflow") - err = h.locker.ReleaseLock(c.Request.Context(), nodeIDs, wfUUID) - if err != nil { - reqLog.Error().Err(err).Msg("failed to release nodes locks") - } InternalServerError(c) return } @@ -443,10 +439,6 @@ func (h *DeploymentHandler) HandleAddNode(c *gin.Context) { wfUUID, wfStatus, err := h.svc.AsyncAddNode(config, cl, cluster.Nodes[0]) if err != nil { reqLog.Error().Err(err).Msg("failed to start add node 
workflow") - err = h.locker.ReleaseLock(c.Request.Context(), []uint32{cluster.Nodes[0].NodeID}, wfUUID) - if err != nil { - reqLog.Error().Err(err).Msg("failed to release nodes locks") - } InternalServerError(c) return } diff --git a/backend/internal/api/handlers/node_handler.go b/backend/internal/api/handlers/node_handler.go index e53f9acc..9db99878 100644 --- a/backend/internal/api/handlers/node_handler.go +++ b/backend/internal/api/handlers/node_handler.go @@ -310,10 +310,6 @@ func (h *NodeHandler) ReserveNodeHandler(c *gin.Context) { wfUUID, err := h.svc.AsyncReserveNode(userID, user.Mnemonic, nodeID) if err != nil { reqLog.Error().Err(err).Msg("failed to start workflow to reserve node") - err = h.locker.ReleaseLock(c.Request.Context(), []uint32{nodeID}, wfUUID) - if err != nil { - reqLog.Error().Err(err).Msg("failed to release nodes locks") - } InternalServerError(c) return } diff --git a/backend/internal/core/distributed_locks/redis_locker.go b/backend/internal/core/distributed_locks/redis_locker.go index 8a41ba09..4838dc7a 100644 --- a/backend/internal/core/distributed_locks/redis_locker.go +++ b/backend/internal/core/distributed_locks/redis_locker.go @@ -39,7 +39,9 @@ func (l *RedisLocker) AcquireWorkflowLock(ctx context.Context, nodeIDs []uint32, }) if err := l.acquireKeys(ctx, keys); err != nil { - if rollErr := l.rollbackLocks(ctx, keys); rollErr != nil { + //rollback nodes locks + nodeLockKeys := lockKeys(nodeIDs, nodeLockKey) + if rollErr := l.rollbackLocks(ctx, nodeLockKeys); rollErr != nil { return rollErr } return err diff --git a/backend/internal/core/distributed_locks/redis_locker_test.go b/backend/internal/core/distributed_locks/redis_locker_test.go index 58944144..8b6be69a 100644 --- a/backend/internal/core/distributed_locks/redis_locker_test.go +++ b/backend/internal/core/distributed_locks/redis_locker_test.go @@ -47,7 +47,8 @@ func TestRedisLocker_AcquireNodesLocks_NodeAlreadyLocked(t *testing.T) { err := 
locker.AcquireNodesLocks(context.Background(), []uint32{1, 2}) require.Error(t, err) - require.Contains(t, err.Error(), "failed to acquire lock for key locked:2") + require.ErrorIs(t, err, ErrNodeLocked) + require.Contains(t, err.Error(), "locked:2") require.Equal(t, int64(0), client.Exists(context.Background(), "locked:1").Val(), "previous locks should be rolled back") } diff --git a/backend/internal/core/services/deployment_service.go b/backend/internal/core/services/deployment_service.go index 8a88305e..98d9852a 100644 --- a/backend/internal/core/services/deployment_service.go +++ b/backend/internal/core/services/deployment_service.go @@ -244,6 +244,11 @@ func (svc *DeploymentService) handleDeploymentAction(userID int, workflowName st } if err = svc.runWithQueue(queueName, &wf); err != nil { + if len(nodeIDs) > 0 { + if releaseErr := svc.locker.ReleaseLock(svc.appCtx, nodeIDs, wf.UUID); releaseErr != nil { + err = fmt.Errorf("%w: failed to release workflow lock: %v", err, releaseErr) + } + } telemetry.RecordError(span, err) return "", "", err } diff --git a/backend/internal/core/services/node_service.go b/backend/internal/core/services/node_service.go index 1e12814b..6d25cef2 100644 --- a/backend/internal/core/services/node_service.go +++ b/backend/internal/core/services/node_service.go @@ -247,6 +247,9 @@ func (svc *NodeService) AsyncReserveNode(userID int, userMnemonic string, nodeID } if err = svc.runWithQueue(queueName, &wf); err != nil { + if releaseErr := svc.locker.ReleaseLock(svc.appCtx, []uint32{nodeID}, wf.UUID); releaseErr != nil { + err = fmt.Errorf("%w: failed to release workflow lock: %v", err, releaseErr) + } return "", err } From d0761463d958c6cfcacb9ebdeaa91f4cd62ddfb8 Mon Sep 17 00:00:00 2001 From: SalmaElsoly Date: Sun, 14 Dec 2025 17:02:18 +0200 Subject: [PATCH 07/10] refactor: removing redundant lock parameters and updating lock acquisition methods --- backend/cmd/root.go | 10 +- backend/internal/api/app/app.go | 1 - 
backend/internal/api/app/app_dependencies.go | 11 +- .../api/handlers/deployment_handler.go | 33 +---- backend/internal/api/handlers/node_handler.go | 19 +-- backend/internal/config/config.go | 13 +- .../distributed_locks/distributed_locks.go | 6 +- .../core/distributed_locks/redis_locker.go | 116 +++++++++--------- .../distributed_locks/redis_locker_test.go | 89 +++++++++++--- .../core/services/deployment_service.go | 14 ++- .../internal/core/services/node_service.go | 11 +- .../core/services/node_service_test.go | 25 ++-- .../internal/core/services/workers_service.go | 59 +-------- .../internal/core/workers/locks_releaser.go | 31 ----- 14 files changed, 177 insertions(+), 261 deletions(-) delete mode 100644 backend/internal/core/workers/locks_releaser.go diff --git a/backend/cmd/root.go b/backend/cmd/root.go index 2b070685..277fe5fb 100644 --- a/backend/cmd/root.go +++ b/backend/cmd/root.go @@ -210,14 +210,10 @@ func addFlags() error { if err := bindStringFlag(rootCmd, "telemetry.otlp_endpoint", "jaeger:4317", "OpenTelemetry gRPC endpoint"); err != nil { return fmt.Errorf("failed to bind telemetry.otlp_endpoint flag: %w", err) } - // === Locks === - if err := bindIntFlag(rootCmd, "locks.lock_timeout_in_hours", 24, "Redis lock timeout (hours)"); err != nil { - return fmt.Errorf("failed to bind locks.lock_timeout_in_hours flag: %w", err) + // === Lock Timeout In Hours === + if err := bindIntFlag(rootCmd, "lock_timeout_in_hours", 1, "Redis lock timeout (hours)"); err != nil { + return fmt.Errorf("failed to bind lock_timeout_in_hours flag: %w", err) } - if err := bindIntFlag(rootCmd, "locks.locks_release_interval_in_minutes", 2, "Locks release interval (minutes)"); err != nil { - return fmt.Errorf("failed to bind locks.locks_release_interval_in_minutes flag: %w", err) - } - return nil } diff --git a/backend/internal/api/app/app.go b/backend/internal/api/app/app.go index 5b05d37e..56bdc34c 100644 --- a/backend/internal/api/app/app.go +++ 
b/backend/internal/api/app/app.go @@ -228,7 +228,6 @@ func (app *App) StartBackgroundWorkers() { go app.workers.TrackReservedNodeHealth() go app.workers.CollectGORMMetrics() go app.workers.CollectGoRuntimeMetrics() - go app.workers.ReleaseWorkflowLocks() } // Run starts the server diff --git a/backend/internal/api/app/app_dependencies.go b/backend/internal/api/app/app_dependencies.go index a85bdb6a..1d3a0eaa 100644 --- a/backend/internal/api/app/app_dependencies.go +++ b/backend/internal/api/app/app_dependencies.go @@ -169,7 +169,7 @@ func createAppCore(ctx context.Context, config cfg.Configuration) (appCore, erro return appCore{}, fmt.Errorf("failed to init workflow engine: %w", err) } - locker := distributedlocks.NewRedisLocker(client, time.Duration(config.Locks.LockTimeoutInHours)*time.Hour) + locker := distributedlocks.NewRedisLocker(client, time.Duration(config.LockTimeoutInHours)*time.Hour) return appCore{ appCtx: ctx, @@ -367,8 +367,8 @@ func (app *App) createHandlers() appHandlers { ) statsHandler := handlers.NewStatsHandler(statsService) notificationHandler := handlers.NewNotificationHandler(notificationAPIService) - nodeHandler := handlers.NewNodeHandler(nodeService, app.core.locker) - deploymentHandler := handlers.NewDeploymentHandler(deploymentService, app.core.locker) + nodeHandler := handlers.NewNodeHandler(nodeService) + deploymentHandler := handlers.NewDeploymentHandler(deploymentService) invoiceHandler := handlers.NewInvoiceHandler(invoiceService) adminHandler := handlers.NewAdminHandler(adminService, app.communication.notificationDispatcher, app.communication.mailService) healthHandler := handlers.NewHealthHandler(app.config.SystemAccount.Network, app.infra.firesquidClient, app.infra.graphql, app.core.db) @@ -400,9 +400,8 @@ func (app *App) createWorkers() workers.Workers { app.communication.notificationDispatcher, app.infra.graphql, app.infra.firesquidClient, app.config.Invoice, app.config.SystemAccount.Mnemonic, app.config.Currency, 
app.config.ClusterHealthCheckIntervalInHours, - app.config.NodeHealthCheck.ReservedNodeHealthCheckIntervalInHours, app.config.NodeHealthCheck.ReservedNodeHealthCheckTimeoutInMinutes, app.config.NodeHealthCheck.ReservedNodeHealthCheckWorkersNum, app.config.MonitorBalanceIntervalInMinutes, app.config.NotifyAdminsForPendingRecordsInHours, app.config.UsersBalanceCheckIntervalInHours, app.config.CheckUserDebtIntervalInHours, - app.config.Locks.LocksReleaseIntervalInMinutes, - app.core.locker, + app.config.NodeHealthCheck.ReservedNodeHealthCheckIntervalInHours, app.config.NodeHealthCheck.ReservedNodeHealthCheckTimeoutInMinutes, app.config.NodeHealthCheck.ReservedNodeHealthCheckWorkersNum, app.config.MonitorBalanceIntervalInMinutes, app.config.NotifyAdminsForPendingRecordsInHours, app.config.UsersBalanceCheckIntervalInHours, + app.config.CheckUserDebtIntervalInHours, ) return workers.NewWorkers(app.core.appCtx, workersService, app.core.metrics, app.core.db) diff --git a/backend/internal/api/handlers/deployment_handler.go b/backend/internal/api/handlers/deployment_handler.go index f3ec7a10..c76560ff 100644 --- a/backend/internal/api/handlers/deployment_handler.go +++ b/backend/internal/api/handlers/deployment_handler.go @@ -6,21 +6,18 @@ import ( "github.com/gin-gonic/gin" - distributedlocks "kubecloud/internal/core/distributed_locks" "kubecloud/internal/core/models" "kubecloud/internal/core/services" "kubecloud/internal/deployment/kubedeployer" ) type DeploymentHandler struct { - svc services.DeploymentService - locker distributedlocks.DistributedLocks + svc services.DeploymentService } -func NewDeploymentHandler(svc services.DeploymentService, locker distributedlocks.DistributedLocks) DeploymentHandler { +func NewDeploymentHandler(svc services.DeploymentService) DeploymentHandler { return DeploymentHandler{ - svc: svc, - locker: locker, + svc: svc, } } @@ -240,21 +237,6 @@ func (h *DeploymentHandler) HandleDeployCluster(c *gin.Context) { return } - nodeIDs := 
make([]uint32, len(cluster.Nodes)) - for i, node := range cluster.Nodes { - nodeIDs[i] = node.NodeID - } - - if err = h.locker.AcquireNodesLocks(c.Request.Context(), nodeIDs); err != nil { - reqLog.Error().Err(err).Msg("failed to acquire nodes locks") - if errors.Is(err, distributedlocks.ErrNodeLocked) { - Conflict(c, err.Error()) - return - } - InternalServerError(c) - return - } - wfUUID, wfStatus, err := h.svc.AsyncDeployCluster(config, cluster) if err != nil { reqLog.Error().Err(err).Msg("failed to start deployment workflow") @@ -426,15 +408,6 @@ func (h *DeploymentHandler) HandleAddNode(c *gin.Context) { return } } - if err = h.locker.AcquireNodesLocks(c.Request.Context(), []uint32{cluster.Nodes[0].NodeID}); err != nil { - reqLog.Error().Err(err).Msg("failed to acquire nodes locks") - if errors.Is(err, distributedlocks.ErrNodeLocked) { - Conflict(c, err.Error()) - return - } - InternalServerError(c) - return - } wfUUID, wfStatus, err := h.svc.AsyncAddNode(config, cl, cluster.Nodes[0]) if err != nil { diff --git a/backend/internal/api/handlers/node_handler.go b/backend/internal/api/handlers/node_handler.go index 9db99878..4586dd62 100644 --- a/backend/internal/api/handlers/node_handler.go +++ b/backend/internal/api/handlers/node_handler.go @@ -3,7 +3,6 @@ package handlers import ( "errors" "fmt" - distributedlocks "kubecloud/internal/core/distributed_locks" "kubecloud/internal/core/models" "math/rand/v2" "net/url" @@ -18,14 +17,12 @@ import ( ) type NodeHandler struct { - svc services.NodeService - locker distributedlocks.DistributedLocks + svc services.NodeService } -func NewNodeHandler(svc services.NodeService, locker distributedlocks.DistributedLocks) NodeHandler { +func NewNodeHandler(svc services.NodeService) NodeHandler { return NodeHandler{ - svc: svc, - locker: locker, + svc: svc, } } @@ -297,16 +294,6 @@ func (h *NodeHandler) ReserveNodeHandler(c *gin.Context) { return } - if err = h.locker.AcquireNodesLocks(c.Request.Context(), []uint32{nodeID}); err 
!= nil { - reqLog.Error().Err(err).Msg("failed to acquire nodes locks") - if errors.Is(err, distributedlocks.ErrNodeLocked) { - Conflict(c, err.Error()) - return - } - InternalServerError(c) - return - } - wfUUID, err := h.svc.AsyncReserveNode(userID, user.Mnemonic, nodeID) if err != nil { reqLog.Error().Err(err).Msg("failed to start workflow to reserve node") diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 7bdc76c5..c6800f38 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -34,11 +34,11 @@ type Configuration struct { DevMode bool `json:"dev_mode"` // When true, allows empty SendGridKey and uses FakeMailService MonitorBalanceIntervalInMinutes int `json:"monitor_balance_interval_in_minutes" validate:"required,gt=0"` NotifyAdminsForPendingRecordsInHours int `json:"notify_admins_for_pending_records_in_hours" validate:"required,gt=0"` - Locks LocksConfig `json:"locks" validate:"required,dive"` ClusterHealthCheckIntervalInHours int `json:"cluster_health_check_interval_in_hours" validate:"gt=0" default:"1"` UsersBalanceCheckIntervalInHours int `json:"users_balance_check_interval_in_hours" validate:"gt=0" default:"6"` CheckUserDebtIntervalInHours int `json:"check_user_debt_interval_in_hours" validate:"gt=0" default:"48"` NodeHealthCheck ReservedNodeHealthCheckConfig `json:"node_health_check" validate:"required,dive"` + LockTimeoutInHours int `json:"lock_timeout_in_hours" validate:"required,gt=0" default:"1"` Logger LoggerConfig `json:"logger"` Loki LokiConfig `json:"loki"` @@ -133,10 +133,6 @@ type TelemetryConfig struct { OTLPEndpoint string `json:"otlp_endpoint" default:"jaeger:4317"` // gRPC endpoint for OTLP exporter } -type LocksConfig struct { - LockTimeoutInHours int `json:"lock_timeout_in_hours" validate:"required,gt=0" default:"24"` - LocksReleaseIntervalInMinutes int `json:"locks_release_interval_in_minutes" validate:"required,gt=0" default:"5"` -} type 
ReservedNodeHealthCheckConfig struct { ReservedNodeHealthCheckIntervalInHours int `json:"reserved_node_health_check_interval_in_hours" validate:"required,gt=0" default:"1"` ReservedNodeHealthCheckTimeoutInMinutes int `json:"reserved_node_health_check_timeout_in_minutes" validate:"required,gt=0" default:"1"` @@ -395,11 +391,8 @@ func applyDefaultValues(config *Configuration) { if config.NotifyAdminsForPendingRecordsInHours == 0 { config.NotifyAdminsForPendingRecordsInHours = 24 } - if config.Locks.LockTimeoutInHours == 0 { - config.Locks.LockTimeoutInHours = 24 - } - if config.Locks.LocksReleaseIntervalInMinutes == 0 { - config.Locks.LocksReleaseIntervalInMinutes = 2 + if config.LockTimeoutInHours == 0 { + config.LockTimeoutInHours = 24 } if config.Telemetry.OTLPEndpoint == "" { diff --git a/backend/internal/core/distributed_locks/distributed_locks.go b/backend/internal/core/distributed_locks/distributed_locks.go index fa2b47b5..4de6771a 100644 --- a/backend/internal/core/distributed_locks/distributed_locks.go +++ b/backend/internal/core/distributed_locks/distributed_locks.go @@ -9,9 +9,7 @@ var ErrNodeLocked = errors.New("node is currently locked by another request") // DistributedLocks is an interface that defines the methods for distributed locks. 
type DistributedLocks interface { - AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) error - AcquireWorkflowLock(ctx context.Context, nodeIDs []uint32, workflowID string) error - ReleaseLock(ctx context.Context, nodeIDs []uint32, workflowID string) error - GetAllWorkflowsLocks(ctx context.Context) ([]string, error) + AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) (map[string]string, error) + ReleaseLock(ctx context.Context, lockedKeys map[string]string) error GetLockedNodes(ctx context.Context) ([]uint32, error) } diff --git a/backend/internal/core/distributed_locks/redis_locker.go b/backend/internal/core/distributed_locks/redis_locker.go index 4838dc7a..613e9323 100644 --- a/backend/internal/core/distributed_locks/redis_locker.go +++ b/backend/internal/core/distributed_locks/redis_locker.go @@ -7,9 +7,14 @@ import ( "strings" "time" + "github.com/google/uuid" "github.com/redis/go-redis/v9" ) +const ( + nodeLockKey = "locked" +) + type RedisLocker struct { client *redis.Client lockTimeout time.Duration @@ -24,77 +29,56 @@ func NewRedisLocker(client *redis.Client, lockTimeout time.Duration) *RedisLocke } // AcquireNodesLocks acquires locks for the given node IDs. -func (l *RedisLocker) AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) error { - if err := l.acquireKeys(ctx, lockKeys(nodeIDs, nodeLockKey)); err != nil { - return err - } - - return nil -} - -// AcquireWorkflowLock acquires a lock for the given workflow ID. 
-func (l *RedisLocker) AcquireWorkflowLock(ctx context.Context, nodeIDs []uint32, workflowID string) error { - keys := lockKeys(nodeIDs, func(id uint32) string { - return workflowLockKey(id, workflowID) - }) - - if err := l.acquireKeys(ctx, keys); err != nil { - //rollback nodes locks - nodeLockKeys := lockKeys(nodeIDs, nodeLockKey) - if rollErr := l.rollbackLocks(ctx, nodeLockKeys); rollErr != nil { - return rollErr - } - return err +func (l *RedisLocker) AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) (map[string]string, error) { + lockedKeys, err := l.acquireKeys(ctx, nodeLockKeys(nodeIDs)) + if err != nil { + return nil, err } - - return nil -} - -func nodeLockKey(nodeID uint32) string { - return fmt.Sprintf("locked:%d", nodeID) -} - -func workflowLockKey(nodeID uint32, workflowID string) string { - return fmt.Sprintf("used:%d:%s", nodeID, workflowID) + return lockedKeys, nil } -func lockKeys(ids []uint32, keyFunc func(uint32) string) []string { - keys := make([]string, len(ids)) - for i, id := range ids { - keys[i] = keyFunc(id) +func nodeLockKeys(nodeIDs []uint32) []string { + keys := make([]string, len(nodeIDs)) + for i, id := range nodeIDs { + keys[i] = fmt.Sprintf("%s:%d", nodeLockKey, id) } return keys } -func (l *RedisLocker) acquireKeys(ctx context.Context, keys []string) error { - locked := make([]string, 0, len(keys)) +func (l *RedisLocker) acquireKeys(ctx context.Context, keys []string) (map[string]string, error) { + locked := make(map[string]string, len(keys)) for _, key := range keys { - ok, err := l.client.SetNX(ctx, key, 1, l.lockTimeout).Result() + keyValue := uuid.New().String() + ok, err := l.client.SetNX(ctx, key, keyValue, l.lockTimeout).Result() if err != nil { if rollErr := l.rollbackLocks(ctx, locked); rollErr != nil { - return rollErr + return nil, rollErr } - return fmt.Errorf("redis error while acquiring lock for key %s: %w", key, err) + return nil, fmt.Errorf("redis error while acquiring lock for key %s: %w", key, err) } if 
!ok { if rollErr := l.rollbackLocks(ctx, locked); rollErr != nil { - return rollErr + return nil, rollErr } - return fmt.Errorf("%w: %s", ErrNodeLocked, key) + return nil, fmt.Errorf("%w: %s", ErrNodeLocked, key) } - locked = append(locked, key) + locked[key] = keyValue } - return nil + return locked, nil } -func (l *RedisLocker) rollbackLocks(ctx context.Context, keys []string) error { - if len(keys) == 0 { +func (l *RedisLocker) rollbackLocks(ctx context.Context, locked map[string]string) error { + if len(locked) == 0 { return nil } + keys := make([]string, 0, len(locked)) + for k := range locked { + keys = append(keys, k) + } if err := l.client.Del(ctx, keys...).Err(); err != nil { return fmt.Errorf("redis error while rolling back locks: %w", err) @@ -103,18 +87,36 @@ func (l *RedisLocker) rollbackLocks(ctx context.Context, keys []string) error { return nil } -func (l *RedisLocker) ReleaseLock(ctx context.Context, nodeIDs []uint32, workflowID string) error { - lockedKeys := lockKeys(nodeIDs, nodeLockKey) - usedKeys := lockKeys(nodeIDs, func(id uint32) string { - return workflowLockKey(id, workflowID) - }) - allWorkflowsLocks := append(lockedKeys, usedKeys...) - return l.client.Del(ctx, allWorkflowsLocks...).Err() -} +func (l *RedisLocker) ReleaseLock(ctx context.Context, lockedKeys map[string]string) error { + if len(lockedKeys) == 0 { + return nil + } + + var failedKeys []string + for key, expectedValue := range lockedKeys { + storedValue, err := l.client.Get(ctx, key).Result() + if err == redis.Nil { + continue + } + if err != nil { + return fmt.Errorf("failed to get lock value for key %s: %w", key, err) + } -// GetAllWorkflowsLocks gets all workflow locks. 
-func (l *RedisLocker) GetAllWorkflowsLocks(ctx context.Context) ([]string, error) { - return l.client.Keys(ctx, "used:*").Result() + if storedValue != expectedValue { + failedKeys = append(failedKeys, key) + continue + } + + if err := l.client.Del(ctx, key).Err(); err != nil { + return fmt.Errorf("failed to delete lock for key %s: %w", key, err) + } + } + + if len(failedKeys) > 0 { + return fmt.Errorf("lock value mismatch for keys: %v", failedKeys) + } + + return nil } func (l *RedisLocker) GetLockedNodes(ctx context.Context) ([]uint32, error) { diff --git a/backend/internal/core/distributed_locks/redis_locker_test.go b/backend/internal/core/distributed_locks/redis_locker_test.go index 8b6be69a..7eca0583 100644 --- a/backend/internal/core/distributed_locks/redis_locker_test.go +++ b/backend/internal/core/distributed_locks/redis_locker_test.go @@ -31,62 +31,113 @@ func TestRedisLocker_AcquireNodesLocks_Success(t *testing.T) { client := newTestRedisClient(t) locker := &RedisLocker{client: client, lockTimeout: time.Minute} - err := locker.AcquireNodesLocks(context.Background(), []uint32{1, 2}) + lockedKeys, err := locker.AcquireNodesLocks(context.Background(), []uint32{1, 2}) require.NoError(t, err) + require.Len(t, lockedKeys, 2) + require.Contains(t, lockedKeys, "locked:1") + require.Contains(t, lockedKeys, "locked:2") + // Verify UUID values are stored + require.NotEmpty(t, lockedKeys["locked:1"]) + require.NotEmpty(t, lockedKeys["locked:2"]) require.Equal(t, int64(1), client.Exists(context.Background(), "locked:1").Val()) require.Equal(t, int64(1), client.Exists(context.Background(), "locked:2").Val()) + // Verify the stored values match + val1, _ := client.Get(context.Background(), "locked:1").Result() + val2, _ := client.Get(context.Background(), "locked:2").Result() + require.Equal(t, lockedKeys["locked:1"], val1) + require.Equal(t, lockedKeys["locked:2"], val2) } func TestRedisLocker_AcquireNodesLocks_NodeAlreadyLocked(t *testing.T) { client := 
newTestRedisClient(t) locker := &RedisLocker{client: client, lockTimeout: time.Minute} - require.NoError(t, client.Set(context.Background(), "locked:2", 1, 0).Err()) + // Set an existing lock with a UUID value + existingValue := "existing-uuid-value" + require.NoError(t, client.Set(context.Background(), "locked:2", existingValue, 0).Err()) - err := locker.AcquireNodesLocks(context.Background(), []uint32{1, 2}) + lockedKeys, err := locker.AcquireNodesLocks(context.Background(), []uint32{1, 2}) require.Error(t, err) require.ErrorIs(t, err, ErrNodeLocked) require.Contains(t, err.Error(), "locked:2") + require.Nil(t, lockedKeys) require.Equal(t, int64(0), client.Exists(context.Background(), "locked:1").Val(), "previous locks should be rolled back") + // Verify the existing lock is still there + val, _ := client.Get(context.Background(), "locked:2").Result() + require.Equal(t, existingValue, val) } -func TestRedisLocker_AcquireWorkflowLock(t *testing.T) { +func TestRedisLocker_ReleaseLock_Success(t *testing.T) { client := newTestRedisClient(t) locker := &RedisLocker{client: client, lockTimeout: time.Minute} - err := locker.AcquireWorkflowLock(context.Background(), []uint32{1}, "wf-1") + // Set locks with specific UUID values + lockValue1 := "uuid-value-1" + lockValue2 := "uuid-value-2" + require.NoError(t, client.Set(context.Background(), "locked:1", lockValue1, 0).Err()) + require.NoError(t, client.Set(context.Background(), "locked:2", lockValue2, 0).Err()) + + // Release locks with matching values + lockedKeys := map[string]string{ + "locked:1": lockValue1, + "locked:2": lockValue2, + } + err := locker.ReleaseLock(context.Background(), lockedKeys) + require.NoError(t, err) + require.Equal(t, int64(0), client.Exists(context.Background(), "locked:1").Val()) + require.Equal(t, int64(0), client.Exists(context.Background(), "locked:2").Val()) +} + +func TestRedisLocker_ReleaseLock_ValueMismatch(t *testing.T) { + client := newTestRedisClient(t) + locker := 
&RedisLocker{client: client, lockTimeout: time.Minute} + + // Set lock with a specific value + storedValue := "stored-uuid-value" + require.NoError(t, client.Set(context.Background(), "locked:1", storedValue, 0).Err()) + + // Try to release with wrong value + lockedKeys := map[string]string{ + "locked:1": "wrong-uuid-value", + } + err := locker.ReleaseLock(context.Background(), lockedKeys) - err = locker.AcquireWorkflowLock(context.Background(), []uint32{1}, "wf-1") require.Error(t, err) + require.Contains(t, err.Error(), "lock value mismatch") + // Verify the lock is still there + require.Equal(t, int64(1), client.Exists(context.Background(), "locked:1").Val()) + val, _ := client.Get(context.Background(), "locked:1").Result() + require.Equal(t, storedValue, val) } -func TestRedisLocker_ReleaseLock(t *testing.T) { +func TestRedisLocker_ReleaseLock_KeyNotExists(t *testing.T) { client := newTestRedisClient(t) locker := &RedisLocker{client: client, lockTimeout: time.Minute} - require.NoError(t, client.Set(context.Background(), "locked:1", 1, 0).Err()) - require.NoError(t, client.Set(context.Background(), "used:1:wf-1", 1, 0).Err()) - - err := locker.ReleaseLock(context.Background(), []uint32{1}, "wf-1") + // Try to release a non-existent lock + lockedKeys := map[string]string{ + "locked:999": "some-uuid-value", + } + err := locker.ReleaseLock(context.Background(), lockedKeys) + // Should not error if key doesn't exist (just skip it) require.NoError(t, err) - require.Equal(t, int64(0), client.Exists(context.Background(), "locked:1").Val()) - require.Equal(t, int64(0), client.Exists(context.Background(), "used:1:wf-1").Val()) } -func TestRedisLocker_GetAllWorkflowsLocks(t *testing.T) { +func TestRedisLocker_GetLockedNodes(t *testing.T) { client := newTestRedisClient(t) locker := &RedisLocker{client: client, lockTimeout: time.Minute} - require.NoError(t, client.Set(context.Background(), "used:1:wf-1", 1, 0).Err()) - require.NoError(t, client.Set(context.Background(), 
"used:2:wf-2", 1, 0).Err()) - require.NoError(t, client.Set(context.Background(), "locked:99", 1, 0).Err()) + // Set some locks + require.NoError(t, client.Set(context.Background(), "locked:1", "uuid-1", 0).Err()) + require.NoError(t, client.Set(context.Background(), "locked:2", "uuid-2", 0).Err()) + require.NoError(t, client.Set(context.Background(), "locked:99", "uuid-99", 0).Err()) - keys, err := locker.GetAllWorkflowsLocks(context.Background()) + nodes, err := locker.GetLockedNodes(context.Background()) require.NoError(t, err) - require.ElementsMatch(t, []string{"used:1:wf-1", "used:2:wf-2"}, keys) + require.ElementsMatch(t, []uint32{1, 2, 99}, nodes) } diff --git a/backend/internal/core/services/deployment_service.go b/backend/internal/core/services/deployment_service.go index 98d9852a..a6c2a023 100644 --- a/backend/internal/core/services/deployment_service.go +++ b/backend/internal/core/services/deployment_service.go @@ -2,6 +2,7 @@ package services import ( "context" + "encoding/json" "errors" "fmt" cfg "kubecloud/internal/config" @@ -235,17 +236,24 @@ func (svc *DeploymentService) handleDeploymentAction(userID int, workflowName st return "", "", err } + var lockedKeys map[string]string if len(nodeIDs) > 0 { - - if err = svc.locker.AcquireWorkflowLock(svc.appCtx, nodeIDs, wf.UUID); err != nil { + lockedKeys, err = svc.locker.AcquireNodesLocks(svc.appCtx, nodeIDs) + if err != nil { return "", "", err } + lockedKeysJSON, err := json.Marshal(lockedKeys) + if err != nil { + return "", "", fmt.Errorf("failed to marshal locked keys: %w", err) + } + wf.Metadata["locked_keys"] = string(lockedKeysJSON) + } if err = svc.runWithQueue(queueName, &wf); err != nil { if len(nodeIDs) > 0 { - if releaseErr := svc.locker.ReleaseLock(svc.appCtx, nodeIDs, wf.UUID); releaseErr != nil { + if releaseErr := svc.locker.ReleaseLock(svc.appCtx, lockedKeys); releaseErr != nil { err = fmt.Errorf("%w: failed to release workflow lock: %v", err, releaseErr) } } diff --git 
a/backend/internal/core/services/node_service.go b/backend/internal/core/services/node_service.go index 6d25cef2..170a1dc9 100644 --- a/backend/internal/core/services/node_service.go +++ b/backend/internal/core/services/node_service.go @@ -2,6 +2,7 @@ package services import ( "context" + "encoding/json" "errors" "fmt" cfg "kubecloud/internal/config" @@ -242,12 +243,18 @@ func (svc *NodeService) AsyncReserveNode(userID int, userMnemonic string, nodeID return "", err } - if err = svc.locker.AcquireWorkflowLock(svc.appCtx, []uint32{nodeID}, wf.UUID); err != nil { + lockedKeys, err := svc.locker.AcquireNodesLocks(svc.appCtx, []uint32{nodeID}) + if err != nil { return "", err } + lockedKeysJSON, err := json.Marshal(lockedKeys) + if err != nil { + return "", fmt.Errorf("failed to marshal locked keys: %w", err) + } + wf.Metadata["locked_keys"] = string(lockedKeysJSON) if err = svc.runWithQueue(queueName, &wf); err != nil { - if releaseErr := svc.locker.ReleaseLock(svc.appCtx, []uint32{nodeID}, wf.UUID); releaseErr != nil { + if releaseErr := svc.locker.ReleaseLock(svc.appCtx, lockedKeys); releaseErr != nil { err = fmt.Errorf("%w: failed to release workflow lock: %v", err, releaseErr) } return "", err diff --git a/backend/internal/core/services/node_service_test.go b/backend/internal/core/services/node_service_test.go index f55c6fa2..a267af71 100644 --- a/backend/internal/core/services/node_service_test.go +++ b/backend/internal/core/services/node_service_test.go @@ -25,28 +25,17 @@ func (m *MockDistributedLocks) GetLockedNodes(ctx context.Context) ([]uint32, er return args.Get(0).([]uint32), args.Error(1) } -func (m *MockDistributedLocks) AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) error { - +func (m *MockDistributedLocks) AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) (map[string]string, error) { args := m.Called(ctx, nodeIDs) - return args.Error(0) -} - -func (m *MockDistributedLocks) AcquireWorkflowLock(ctx context.Context, nodeIDs []uint32, 
workflowID string) error { - args := m.Called(ctx, nodeIDs, workflowID) - return args.Error(0) -} - -func (m *MockDistributedLocks) ReleaseLock(ctx context.Context, nodeIDs []uint32, workflowID string) error { - args := m.Called(ctx, nodeIDs, workflowID) - return args.Error(0) -} - -func (m *MockDistributedLocks) GetAllWorkflowsLocks(ctx context.Context) ([]string, error) { - args := m.Called(ctx) if args.Get(0) == nil { return nil, args.Error(1) } - return args.Get(0).([]string), args.Error(1) + return args.Get(0).(map[string]string), args.Error(1) +} + +func (m *MockDistributedLocks) ReleaseLock(ctx context.Context, lockedKeys map[string]string) error { + args := m.Called(ctx, lockedKeys) + return args.Error(0) } // Test 1: NodeService - GetUserNodeByNodeID SUCCESS diff --git a/backend/internal/core/services/workers_service.go b/backend/internal/core/services/workers_service.go index 64c4e8e0..6fd9e783 100644 --- a/backend/internal/core/services/workers_service.go +++ b/backend/internal/core/services/workers_service.go @@ -6,7 +6,6 @@ import ( "fmt" "kubecloud/internal/billing" "kubecloud/internal/config" - distributedlocks "kubecloud/internal/core/distributed_locks" "kubecloud/internal/core/models" "kubecloud/internal/core/workflows" "kubecloud/internal/infrastructure/gridclient" @@ -14,9 +13,6 @@ import ( "kubecloud/internal/infrastructure/mailservice" mailsender "kubecloud/internal/infrastructure/mailservice/mail_sender" "kubecloud/internal/infrastructure/notification" - "slices" - "strconv" - "strings" "sync" "time" @@ -53,10 +49,8 @@ type WorkerService struct { reservedNodeHealthCheckWorkersNum int monitorBalanceIntervalInMinutes int notifyAdminsForPendingRecordsInHours int - locksReleaseIntervalInMinutes int - locker distributedlocks.DistributedLocks - usersBalanceCheckIntervalInHours int + usersBalanceCheckIntervalInHours int } func NewWorkersService( @@ -68,8 +62,7 @@ func NewWorkersService( invoiceCompanyData config.InvoiceCompanyData, systemMnemonic, 
currency string, clusterHealthCheckIntervalInHours, reservedNodeHealthCheckIntervalInHours, reservedNodeHealthCheckTimeoutInMinutes, reservedNodeHealthCheckWorkersNum, - monitorBalanceIntervalInMinutes, notifyAdminsForPendingRecordsInHours, locksReleaseIntervalInMinutes int, - locker distributedlocks.DistributedLocks, + monitorBalanceIntervalInMinutes, notifyAdminsForPendingRecordsInHours int, usersBalanceCheckIntervalInHours int, checkUserDebtIntervalInHours int, ) WorkerService { @@ -88,8 +81,6 @@ func NewWorkersService( firesquidClient: firesquidClient, gridClient: gridClient, - locker: locker, - systemMnemonic: systemMnemonic, invoiceCompanyData: invoiceCompanyData, currency: currency, @@ -101,7 +92,6 @@ func NewWorkersService( reservedNodeHealthCheckWorkersNum: reservedNodeHealthCheckWorkersNum, monitorBalanceIntervalInMinutes: monitorBalanceIntervalInMinutes, notifyAdminsForPendingRecordsInHours: notifyAdminsForPendingRecordsInHours, - locksReleaseIntervalInMinutes: locksReleaseIntervalInMinutes, usersBalanceCheckIntervalInHours: usersBalanceCheckIntervalInHours, } } @@ -147,10 +137,6 @@ func (svc WorkerService) GetNotifyAdminsForPendingRecordsInterval() time.Duratio return time.Duration(svc.notifyAdminsForPendingRecordsInHours) * time.Hour } -func (svc WorkerService) GetLocksReleaseInterval() time.Duration { - return time.Duration(svc.locksReleaseIntervalInMinutes) * time.Minute -} - func (svc WorkerService) GetUsersBalanceCheckInterval() time.Duration { return time.Duration(svc.usersBalanceCheckIntervalInHours) * time.Hour } @@ -487,47 +473,6 @@ func (svc WorkerService) AsyncTrackClusterHealth(cluster models.Cluster) error { return svc.ewfEngine.Run(svc.ctx, wf, ewf.WithAsync()) } -func (svc WorkerService) GetAllWorkflowsLocks() ([]string, error) { - return svc.locker.GetAllWorkflowsLocks(svc.ctx) -} - -func (svc WorkerService) ReleaseLocks(keys []string) { - log := logger.ForOperation("locks_worker", "release_locks") - workflowsNodes := 
map[string][]uint32{} - for _, key := range keys { - parts := strings.Split(key, ":") - if len(parts) != 3 { - log.Error().Str("key", key).Msg("invalid lock key format") - continue - } - - workflowID := parts[2] - nodeID, err := strconv.ParseUint(parts[1], 10, 32) - if err != nil { - log.Error().Str("key", key).Msg("invalid node ID") - continue - } - workflowsNodes[workflowID] = append(workflowsNodes[workflowID], uint32(nodeID)) - } - - for workflowID := range workflowsNodes { - workflow, err := svc.ewfEngine.Store().LoadWorkflowByUUID(svc.ctx, workflowID) - if err != nil { - log.Error().Str("workflow_id", workflowID).Msg("failed to load workflow") - continue - } - if !slices.Contains([]ewf.WorkflowStatus{ewf.StatusCompleted, ewf.StatusFailed}, workflow.Status) { - continue - } - nodeIDs := workflowsNodes[workflowID] - if err := svc.locker.ReleaseLock(svc.ctx, nodeIDs, workflowID); err != nil { - log.Error().Str("workflow_id", workflow.UUID).Msg("failed to release locks") - continue - } - } - -} - func (svc WorkerService) checkUserDebt(user models.User, contractIDs []uint64) error { totalDebt, err := svc.calculateDebt(user.Mnemonic, contractIDs, svc.GetCheckUserDebtInterval()) if err != nil { diff --git a/backend/internal/core/workers/locks_releaser.go b/backend/internal/core/workers/locks_releaser.go deleted file mode 100644 index 8d0d6392..00000000 --- a/backend/internal/core/workers/locks_releaser.go +++ /dev/null @@ -1,31 +0,0 @@ -package workers - -import ( - "kubecloud/internal/infrastructure/logger" - "time" -) - -// ReleaseWorkflowLocks periodically scans Redis locks and frees those that belong to finished workflows. 
-func (w Workers) ReleaseWorkflowLocks() { - log := logger.ForOperation("locks_worker", "release_workflow_locks") - ticker := time.NewTicker(w.svc.GetLocksReleaseInterval()) - defer ticker.Stop() - - for { - select { - case <-w.ctx.Done(): - return - case <-ticker.C: - keys, err := w.svc.GetAllWorkflowsLocks() - if err != nil { - log.Error().Err(err).Msg("failed to list workflow locks") - continue - } - - if len(keys) == 0 { - continue - } - w.svc.ReleaseLocks(keys) - } - } -} From 636a386d465f1ce811546b53ce03d95f0c70d947 Mon Sep 17 00:00:00 2001 From: SalmaElsoly Date: Mon, 15 Dec 2025 08:31:00 +0200 Subject: [PATCH 08/10] refactor: add lock release in after workflow hook --- backend/internal/api/app/app.go | 1 + .../api/handlers/deployment_handler.go | 9 ++++++++ backend/internal/api/handlers/node_handler.go | 5 +++++ .../core/workflows/deployer_activities.go | 5 ++++- backend/internal/core/workflows/hooks.go | 22 +++++++++++++++++++ backend/internal/core/workflows/workflow.go | 5 ++++- 6 files changed, 45 insertions(+), 2 deletions(-) diff --git a/backend/internal/api/app/app.go b/backend/internal/api/app/app.go index 56bdc34c..ddc89921 100644 --- a/backend/internal/api/app/app.go +++ b/backend/internal/api/app/app.go @@ -95,6 +95,7 @@ func (app *App) registerEWFWorkflows() { app.core.metrics, app.communication.notificationDispatcher, stripeClient, + app.core.locker, ) } diff --git a/backend/internal/api/handlers/deployment_handler.go b/backend/internal/api/handlers/deployment_handler.go index c76560ff..eace528d 100644 --- a/backend/internal/api/handlers/deployment_handler.go +++ b/backend/internal/api/handlers/deployment_handler.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" + distributedlocks "kubecloud/internal/core/distributed_locks" "kubecloud/internal/core/models" "kubecloud/internal/core/services" "kubecloud/internal/deployment/kubedeployer" @@ -240,6 +241,10 @@ func (h *DeploymentHandler) HandleDeployCluster(c *gin.Context) { wfUUID, wfStatus, 
err := h.svc.AsyncDeployCluster(config, cluster) if err != nil { reqLog.Error().Err(err).Msg("failed to start deployment workflow") + if errors.Is(err, distributedlocks.ErrNodeLocked) { + Conflict(c, "Node is busy serving another request") + return + } InternalServerError(c) return } @@ -412,6 +417,10 @@ func (h *DeploymentHandler) HandleAddNode(c *gin.Context) { wfUUID, wfStatus, err := h.svc.AsyncAddNode(config, cl, cluster.Nodes[0]) if err != nil { reqLog.Error().Err(err).Msg("failed to start add node workflow") + if errors.Is(err, distributedlocks.ErrNodeLocked) { + Conflict(c, "Node is busy serving another request") + return + } InternalServerError(c) return } diff --git a/backend/internal/api/handlers/node_handler.go b/backend/internal/api/handlers/node_handler.go index 4586dd62..f315e7c5 100644 --- a/backend/internal/api/handlers/node_handler.go +++ b/backend/internal/api/handlers/node_handler.go @@ -3,6 +3,7 @@ package handlers import ( "errors" "fmt" + distributedlocks "kubecloud/internal/core/distributed_locks" "kubecloud/internal/core/models" "math/rand/v2" "net/url" @@ -297,6 +298,10 @@ func (h *NodeHandler) ReserveNodeHandler(c *gin.Context) { wfUUID, err := h.svc.AsyncReserveNode(userID, user.Mnemonic, nodeID) if err != nil { reqLog.Error().Err(err).Msg("failed to start workflow to reserve node") + if errors.Is(err, distributedlocks.ErrNodeLocked) { + Conflict(c, "Node is busy serving another request") + return + } InternalServerError(c) return } diff --git a/backend/internal/core/workflows/deployer_activities.go b/backend/internal/core/workflows/deployer_activities.go index 0e8c57ad..7deadf5f 100644 --- a/backend/internal/core/workflows/deployer_activities.go +++ b/backend/internal/core/workflows/deployer_activities.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" cfg "kubecloud/internal/config" + distributedlocks "kubecloud/internal/core/distributed_locks" "kubecloud/internal/core/models" "kubecloud/internal/deployment/kubedeployer" 
"kubecloud/internal/deployment/statemanager" @@ -641,7 +642,7 @@ func createAddNodeWorkflowTemplate(notificationDispatcher *notification.Notifica return template } -func registerDeploymentActivities(engine *ewf.Engine, metrics *metricsLib.Metrics, clusterRepo models.ClusterRepository, notificationDispatcher *notification.NotificationDispatcher, config cfg.Configuration) { +func registerDeploymentActivities(engine *ewf.Engine, metrics *metricsLib.Metrics, clusterRepo models.ClusterRepository, notificationDispatcher *notification.NotificationDispatcher, config cfg.Configuration, locker distributedlocks.DistributedLocks) { engine.Register(StepDeployNetwork, DeployNetworkStep()) engine.Register(StepDeployLeaderNode, DeployLeaderNodeStep()) engine.Register(StepBatchDeployAllNodes, BatchDeployAllNodesStep(metrics)) @@ -670,6 +671,7 @@ func registerDeploymentActivities(engine *ewf.Engine, metrics *metricsLib.Metric deployWFTemplate.AfterStepHooks = []ewf.AfterStepHook{ notifyStepHook(notificationDispatcher), } + deployWFTemplate.AfterWorkflowHooks = append(deployWFTemplate.AfterWorkflowHooks, releaseLocksHook(locker)) engine.RegisterTemplate(WorkflowDeployCluster, &deployWFTemplate) deleteWFTemplate := createDeployerWorkflowTemplate(notificationDispatcher, engine, metrics) @@ -695,6 +697,7 @@ func registerDeploymentActivities(engine *ewf.Engine, metrics *metricsLib.Metric {Name: StepVerifyNewNodes, RetryPolicy: longExponentialRetryPolicy}, {Name: StepStoreDeployment, RetryPolicy: standardRetryPolicy}, } + addNodeWFTemplate.AfterWorkflowHooks = append(addNodeWFTemplate.AfterWorkflowHooks, releaseLocksHook(locker)) engine.RegisterTemplate(WorkflowAddNode, &addNodeWFTemplate) removeNodeWFTemplate := createDeployerWorkflowTemplate(notificationDispatcher, engine, metrics) diff --git a/backend/internal/core/workflows/hooks.go b/backend/internal/core/workflows/hooks.go index 79a547ad..cf22beeb 100644 --- a/backend/internal/core/workflows/hooks.go +++ 
b/backend/internal/core/workflows/hooks.go @@ -2,11 +2,13 @@ package workflows import ( "context" + "encoding/json" "errors" "fmt" "time" + distributedlocks "kubecloud/internal/core/distributed_locks" "kubecloud/internal/deployment/kubedeployer" "kubecloud/internal/deployment/statemanager" "kubecloud/internal/infrastructure/logger" @@ -241,3 +243,23 @@ func metricsFailureHook(metrics *metricsLib.Metrics) ewf.AfterWorkflowHook { } } } + +func releaseLocksHook(locker distributedlocks.DistributedLocks) ewf.AfterWorkflowHook { + return func(ctx context.Context, wf *ewf.Workflow, _ error) { + log := logger.ForOperation("workflow", "release_locks").With().Str("workflow_name", wf.Name).Logger() + lockedKeys, ok := wf.Metadata["locked_keys"] + if !ok { + return + } + var lockedKeysJSON map[string]string + err := json.Unmarshal([]byte(lockedKeys), &lockedKeysJSON) + if err != nil { + log.Error().Err(err).Msg("failed to unmarshal locked keys") + return + } + if err := locker.ReleaseLock(ctx, lockedKeysJSON); err != nil { + log.Error().Err(err).Msg("failed to release locks") + return + } + } +} diff --git a/backend/internal/core/workflows/workflow.go b/backend/internal/core/workflows/workflow.go index 09af2e42..4cffb882 100644 --- a/backend/internal/core/workflows/workflow.go +++ b/backend/internal/core/workflows/workflow.go @@ -3,6 +3,7 @@ package workflows import ( "kubecloud/internal/billing" cfg "kubecloud/internal/config" + distributedlocks "kubecloud/internal/core/distributed_locks" "kubecloud/internal/core/models" "kubecloud/internal/core/persistence" "kubecloud/internal/infrastructure/gridclient" @@ -29,6 +30,7 @@ func RegisterEWFWorkflows( metrics *metrics.Metrics, notificationDispatcher *notification.NotificationDispatcher, stripeClient billing.StripeClient, + locker distributedlocks.DistributedLocks, ) { userRepo := persistence.NewGormUserRepository(db) clusterRepo := persistence.NewGormClusterRepository(db) @@ -125,6 +127,7 @@ func RegisterEWFWorkflows( {Name: 
StepReserveNode, RetryPolicy: &ewf.RetryPolicy{MaxAttempts: 2, BackOff: ewf.ConstantBackoff(2 * time.Second)}}, {Name: StepVerifyNodeState, RetryPolicy: &ewf.RetryPolicy{MaxAttempts: 5, BackOff: ewf.ExponentialBackoff(10*time.Second, 2*time.Minute, 2.0)}}, } + reserveNodeTemplate.AfterWorkflowHooks = append(reserveNodeTemplate.AfterWorkflowHooks, releaseLocksHook(locker)) engine.RegisterTemplate(WorkflowReserveNode, &reserveNodeTemplate) unreserveNodeTemplate := newKubecloudWorkflowTemplate(notificationDispatcher) @@ -144,7 +147,7 @@ func RegisterEWFWorkflows( // trackClusterHealthWFTemplate.BeforeWorkflowHooks = []ewf.BeforeWorkflowHook{hookNotificationWorkflowStarted} engine.RegisterTemplate(WorkflowTrackClusterHealth, &trackClusterHealthWFTemplate) - registerDeploymentActivities(engine, metrics, clusterRepo, notificationDispatcher, config) + registerDeploymentActivities(engine, metrics, clusterRepo, notificationDispatcher, config, locker) // Email-only workflow for guaranteed email delivery with retries emailNotificationTemplate := ewf.WorkflowTemplate{ From 07b539deea21a55ce62e269c59218a8df84eb673 Mon Sep 17 00:00:00 2001 From: SalmaElsoly Date: Mon, 15 Dec 2025 09:00:09 +0200 Subject: [PATCH 09/10] refactor: use scan instead of keys --- .../core/distributed_locks/redis_locker.go | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/backend/internal/core/distributed_locks/redis_locker.go b/backend/internal/core/distributed_locks/redis_locker.go index 613e9323..a3a4142e 100644 --- a/backend/internal/core/distributed_locks/redis_locker.go +++ b/backend/internal/core/distributed_locks/redis_locker.go @@ -87,6 +87,7 @@ func (l *RedisLocker) rollbackLocks(ctx context.Context, locked map[string]strin return nil } +// ReleaseLock releases the locks for the given keys. 
func (l *RedisLocker) ReleaseLock(ctx context.Context, lockedKeys map[string]string) error { if len(lockedKeys) == 0 { return nil @@ -119,19 +120,24 @@ func (l *RedisLocker) ReleaseLock(ctx context.Context, lockedKeys map[string]str return nil } +// GetLockedNodes returns the list of locked nodes. func (l *RedisLocker) GetLockedNodes(ctx context.Context) ([]uint32, error) { - keys, err := l.client.Keys(ctx, "locked:*").Result() - if err != nil { - return nil, err - } - nodes := make([]uint32, len(keys)) - for i, key := range keys { + iter := l.client.Scan(ctx, 0, "locked:*", 0).Iterator() + + nodes := make([]uint32, 0) + for iter.Next(ctx) { + key := iter.Val() nodeID := strings.Split(key, ":")[1] value, parseErr := strconv.ParseUint(nodeID, 10, 32) if parseErr != nil { return nil, fmt.Errorf("failed to parse locked node id from %s: %w", key, parseErr) } - nodes[i] = uint32(value) + nodes = append(nodes, uint32(value)) } + + if err := iter.Err(); err != nil { + return nil, err + } + return nodes, nil } From 6c8bbc84ae42b2adc06bdc385256ed2055ec9006 Mon Sep 17 00:00:00 2001 From: SalmaElsoly Date: Sun, 21 Dec 2025 15:13:24 +0200 Subject: [PATCH 10/10] refactor: make interface more resource type independent --- .../api/handlers/deployment_handler.go | 4 +- backend/internal/api/handlers/node_handler.go | 2 +- .../distributed_locks/distributed_locks.go | 13 +- .../core/distributed_locks/redis_locker.go | 178 +++++++++--------- .../distributed_locks/redis_locker_test.go | 87 +++++---- .../core/services/deployment_service.go | 65 ++++--- .../internal/core/services/node_service.go | 43 +++-- .../core/services/node_service_test.go | 16 +- backend/internal/core/workflows/hooks.go | 14 +- 9 files changed, 224 insertions(+), 198 deletions(-) diff --git a/backend/internal/api/handlers/deployment_handler.go b/backend/internal/api/handlers/deployment_handler.go index eace528d..9c59eda8 100644 --- a/backend/internal/api/handlers/deployment_handler.go +++ 
b/backend/internal/api/handlers/deployment_handler.go @@ -241,7 +241,7 @@ func (h *DeploymentHandler) HandleDeployCluster(c *gin.Context) { wfUUID, wfStatus, err := h.svc.AsyncDeployCluster(config, cluster) if err != nil { reqLog.Error().Err(err).Msg("failed to start deployment workflow") - if errors.Is(err, distributedlocks.ErrNodeLocked) { + if errors.Is(err, distributedlocks.ErrResourceLocked) { Conflict(c, "Node is busy serving another request") return } @@ -417,7 +417,7 @@ func (h *DeploymentHandler) HandleAddNode(c *gin.Context) { wfUUID, wfStatus, err := h.svc.AsyncAddNode(config, cl, cluster.Nodes[0]) if err != nil { reqLog.Error().Err(err).Msg("failed to start add node workflow") - if errors.Is(err, distributedlocks.ErrNodeLocked) { + if errors.Is(err, distributedlocks.ErrResourceLocked) { Conflict(c, "Node is busy serving another request") return } diff --git a/backend/internal/api/handlers/node_handler.go b/backend/internal/api/handlers/node_handler.go index f315e7c5..9c313625 100644 --- a/backend/internal/api/handlers/node_handler.go +++ b/backend/internal/api/handlers/node_handler.go @@ -298,7 +298,7 @@ func (h *NodeHandler) ReserveNodeHandler(c *gin.Context) { wfUUID, err := h.svc.AsyncReserveNode(userID, user.Mnemonic, nodeID) if err != nil { reqLog.Error().Err(err).Msg("failed to start workflow to reserve node") - if errors.Is(err, distributedlocks.ErrNodeLocked) { + if errors.Is(err, distributedlocks.ErrResourceLocked) { Conflict(c, "Node is busy serving another request") return } diff --git a/backend/internal/core/distributed_locks/distributed_locks.go b/backend/internal/core/distributed_locks/distributed_locks.go index 4de6771a..53c52330 100644 --- a/backend/internal/core/distributed_locks/distributed_locks.go +++ b/backend/internal/core/distributed_locks/distributed_locks.go @@ -5,11 +5,14 @@ import ( "errors" ) -var ErrNodeLocked = errors.New("node is currently locked by another request") +var ErrResourceLocked = errors.New("resource is 
currently locked by another request") + +const ( + NodeLockPrefix = "node:" +) -// DistributedLocks is an interface that defines the methods for distributed locks. type DistributedLocks interface { - AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) (map[string]string, error) - ReleaseLock(ctx context.Context, lockedKeys map[string]string) error - GetLockedNodes(ctx context.Context) ([]uint32, error) + AcquireLocks(ctx context.Context, resourceKeys []string) (map[string]string, error) + ReleaseLocks(ctx context.Context, lockedKeys map[string]string) error + GetLockedResources(ctx context.Context, keyPattern string) ([]string, error) } diff --git a/backend/internal/core/distributed_locks/redis_locker.go b/backend/internal/core/distributed_locks/redis_locker.go index a3a4142e..123809c2 100644 --- a/backend/internal/core/distributed_locks/redis_locker.go +++ b/backend/internal/core/distributed_locks/redis_locker.go @@ -3,18 +3,12 @@ package distributedlocks import ( "context" "fmt" - "strconv" - "strings" "time" "github.com/google/uuid" "github.com/redis/go-redis/v9" ) -const ( - nodeLockKey = "locked" -) - type RedisLocker struct { client *redis.Client lockTimeout time.Duration @@ -28,116 +22,128 @@ func NewRedisLocker(client *redis.Client, lockTimeout time.Duration) *RedisLocke } } -// AcquireNodesLocks acquires locks for the given node IDs. 
-func (l *RedisLocker) AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) (map[string]string, error) { - lockedKeys, err := l.acquireKeys(ctx, nodeLockKeys(nodeIDs)) - if err != nil { - return nil, err +func (l *RedisLocker) AcquireLocks(ctx context.Context, resourceKeys []string) (map[string]string, error) { + if len(resourceKeys) == 0 { + return nil, fmt.Errorf("no resource keys provided") } - return lockedKeys, nil -} -func nodeLockKeys(nodeIDs []uint32) []string { - keys := make([]string, len(nodeIDs)) - for i, id := range nodeIDs { - keys[i] = fmt.Sprintf("%s:%d", nodeLockKey, id) - } - return keys -} + expiry := int64(l.lockTimeout / time.Millisecond) -func (l *RedisLocker) acquireKeys(ctx context.Context, keys []string) (map[string]string, error) { - locked := make(map[string]string, len(keys)) - - for _, key := range keys { - keyValue := uuid.New().String() - ok, err := l.client.SetNX(ctx, key, keyValue, l.lockTimeout).Result() - if err != nil { - if rollErr := l.rollbackLocks(ctx, locked); rollErr != nil { - return nil, rollErr - } - return nil, fmt.Errorf("redis error while acquiring lock for key %s: %w", key, err) - } + values := make([]string, len(resourceKeys)) + argv := make([]interface{}, 0, len(resourceKeys)+1) + //expiry of locks + argv = append(argv, expiry) - if !ok { - if rollErr := l.rollbackLocks(ctx, locked); rollErr != nil { - return nil, rollErr - } - return nil, fmt.Errorf("%w: %s", ErrNodeLocked, key) - } - - locked[key] = keyValue + // uuid values for each key + for i := range resourceKeys { + val := uuid.New().String() + values[i] = val + argv = append(argv, val) } - return locked, nil -} + lua := redis.NewScript(` +local expiry = tonumber(ARGV[1]) +local locked = {} + +for i = 1, #KEYS do + local ok = redis.call("SET", KEYS[i], ARGV[i+1], "PX", expiry, "NX") + if not ok then + for j = 1, #locked do + redis.call("DEL", KEYS[j]) + end + return {"LOCKED", KEYS[i]} + end + table.insert(locked, KEYS[i]) +end + +return {"OK"} +`) + + 
res, err := lua.Run(ctx, l.client, resourceKeys, argv...).Result() + if err != nil { + return nil, err + } -func (l *RedisLocker) rollbackLocks(ctx context.Context, locked map[string]string) error { - if len(locked) == 0 { - return nil + out, ok := res.([]interface{}) + if !ok || len(out) == 0 { + return nil, fmt.Errorf("unexpected script output: %v", res) } - keys := make([]string, 0, len(locked)) - for k := range locked { - keys = append(keys, k) + + status, _ := out[0].(string) + if status == "LOCKED" { + conflict := out[1].(string) + return nil, fmt.Errorf("%w: %s", ErrResourceLocked, conflict) } - if err := l.client.Del(ctx, keys...).Err(); err != nil { - return fmt.Errorf("redis error while rolling back locks: %w", err) + locked := map[string]string{} + for i, k := range resourceKeys { + locked[k] = values[i] } - return nil + return locked, nil } -// ReleaseLock releases the locks for the given keys. -func (l *RedisLocker) ReleaseLock(ctx context.Context, lockedKeys map[string]string) error { +// ReleaseLocks releases the locks for the given keys. 
+func (l *RedisLocker) ReleaseLocks(ctx context.Context, lockedKeys map[string]string) error { if len(lockedKeys) == 0 { return nil } + keys := make([]string, 0, len(lockedKeys)) + values := make([]interface{}, 0, len(lockedKeys)) - var failedKeys []string - for key, expectedValue := range lockedKeys { - storedValue, err := l.client.Get(ctx, key).Result() - if err == redis.Nil { - continue - } - if err != nil { - return fmt.Errorf("failed to get lock value for key %s: %w", key, err) - } - - if storedValue != expectedValue { - failedKeys = append(failedKeys, key) - continue - } + for k, v := range lockedKeys { + keys = append(keys, k) + values = append(values, v) + } - if err := l.client.Del(ctx, key).Err(); err != nil { - return fmt.Errorf("failed to delete lock for key %s: %w", key, err) - } + luaScript := redis.NewScript(` +local failed = {} +for i = 1, #KEYS do + local key = KEYS[i] + local expected = ARGV[i] + local actual = redis.call("GET", key) + + if actual ~= false then + if actual ~= expected then + table.insert(failed, key) + else + redis.call("DEL", key) + end + end +end +return failed +`) + + // Run the script + res, err := luaScript.Run(ctx, l.client, keys, values...).Result() + if err != nil { + return err } + failedKeys, _ := res.([]interface{}) if len(failedKeys) > 0 { - return fmt.Errorf("lock value mismatch for keys: %v", failedKeys) + mismatches := make([]string, len(failedKeys)) + for i, v := range failedKeys { + mismatches[i] = v.(string) + } + return fmt.Errorf("lock value mismatch for keys: %v", mismatches) } return nil } -// GetLockedNodes returns the list of locked nodes. -func (l *RedisLocker) GetLockedNodes(ctx context.Context) ([]uint32, error) { - iter := l.client.Scan(ctx, 0, "locked:*", 0).Iterator() - - nodes := make([]uint32, 0) +// GetLockedResources returns all currently locked resource keys matching the given pattern. 
+func (l *RedisLocker) GetLockedResources(ctx context.Context, keyPattern string) ([]string, error) { + if keyPattern == "" { + keyPattern = "*" + } + iter := l.client.Scan(ctx, 0, keyPattern, 0).Iterator() + resources := make([]string, 0) for iter.Next(ctx) { - key := iter.Val() - nodeID := strings.Split(key, ":")[1] - value, parseErr := strconv.ParseUint(nodeID, 10, 32) - if parseErr != nil { - return nil, fmt.Errorf("failed to parse locked node id from %s: %w", key, parseErr) - } - nodes = append(nodes, uint32(value)) + resources = append(resources, iter.Val()) } - if err := iter.Err(); err != nil { return nil, err } - - return nodes, nil + return resources, nil } diff --git a/backend/internal/core/distributed_locks/redis_locker_test.go b/backend/internal/core/distributed_locks/redis_locker_test.go index 7eca0583..e2c47119 100644 --- a/backend/internal/core/distributed_locks/redis_locker_test.go +++ b/backend/internal/core/distributed_locks/redis_locker_test.go @@ -27,45 +27,47 @@ func newTestRedisClient(t *testing.T) *redis.Client { return client } -func TestRedisLocker_AcquireNodesLocks_Success(t *testing.T) { +func TestRedisLocker_AcquireLocks_Success(t *testing.T) { client := newTestRedisClient(t) locker := &RedisLocker{client: client, lockTimeout: time.Minute} - lockedKeys, err := locker.AcquireNodesLocks(context.Background(), []uint32{1, 2}) + resourceKeys := []string{"node:1", "node:2"} + lockedKeys, err := locker.AcquireLocks(context.Background(), resourceKeys) require.NoError(t, err) require.Len(t, lockedKeys, 2) - require.Contains(t, lockedKeys, "locked:1") - require.Contains(t, lockedKeys, "locked:2") + require.Contains(t, lockedKeys, "node:1") + require.Contains(t, lockedKeys, "node:2") // Verify UUID values are stored - require.NotEmpty(t, lockedKeys["locked:1"]) - require.NotEmpty(t, lockedKeys["locked:2"]) - require.Equal(t, int64(1), client.Exists(context.Background(), "locked:1").Val()) - require.Equal(t, int64(1), 
client.Exists(context.Background(), "locked:2").Val()) + require.NotEmpty(t, lockedKeys["node:1"]) + require.NotEmpty(t, lockedKeys["node:2"]) + require.Equal(t, int64(1), client.Exists(context.Background(), "node:1").Val()) + require.Equal(t, int64(1), client.Exists(context.Background(), "node:2").Val()) // Verify the stored values match - val1, _ := client.Get(context.Background(), "locked:1").Result() - val2, _ := client.Get(context.Background(), "locked:2").Result() - require.Equal(t, lockedKeys["locked:1"], val1) - require.Equal(t, lockedKeys["locked:2"], val2) + val1, _ := client.Get(context.Background(), "node:1").Result() + val2, _ := client.Get(context.Background(), "node:2").Result() + require.Equal(t, lockedKeys["node:1"], val1) + require.Equal(t, lockedKeys["node:2"], val2) } -func TestRedisLocker_AcquireNodesLocks_NodeAlreadyLocked(t *testing.T) { +func TestRedisLocker_AcquireLocks_ResourceAlreadyLocked(t *testing.T) { client := newTestRedisClient(t) locker := &RedisLocker{client: client, lockTimeout: time.Minute} // Set an existing lock with a UUID value existingValue := "existing-uuid-value" - require.NoError(t, client.Set(context.Background(), "locked:2", existingValue, 0).Err()) + require.NoError(t, client.Set(context.Background(), "node:2", existingValue, 0).Err()) - lockedKeys, err := locker.AcquireNodesLocks(context.Background(), []uint32{1, 2}) + resourceKeys := []string{"node:1", "node:2"} + lockedKeys, err := locker.AcquireLocks(context.Background(), resourceKeys) require.Error(t, err) - require.ErrorIs(t, err, ErrNodeLocked) - require.Contains(t, err.Error(), "locked:2") + require.ErrorIs(t, err, ErrResourceLocked) + require.Contains(t, err.Error(), "node:2") require.Nil(t, lockedKeys) - require.Equal(t, int64(0), client.Exists(context.Background(), "locked:1").Val(), "previous locks should be rolled back") + require.Equal(t, int64(0), client.Exists(context.Background(), "node:1").Val(), "previous locks should be rolled back") // Verify the 
existing lock is still there - val, _ := client.Get(context.Background(), "locked:2").Result() + val, _ := client.Get(context.Background(), "node:2").Result() require.Equal(t, existingValue, val) } @@ -76,19 +78,19 @@ func TestRedisLocker_ReleaseLock_Success(t *testing.T) { // Set locks with specific UUID values lockValue1 := "uuid-value-1" lockValue2 := "uuid-value-2" - require.NoError(t, client.Set(context.Background(), "locked:1", lockValue1, 0).Err()) - require.NoError(t, client.Set(context.Background(), "locked:2", lockValue2, 0).Err()) + require.NoError(t, client.Set(context.Background(), "node:1", lockValue1, 0).Err()) + require.NoError(t, client.Set(context.Background(), "node:2", lockValue2, 0).Err()) // Release locks with matching values lockedKeys := map[string]string{ - "locked:1": lockValue1, - "locked:2": lockValue2, + "node:1": lockValue1, + "node:2": lockValue2, } - err := locker.ReleaseLock(context.Background(), lockedKeys) + err := locker.ReleaseLocks(context.Background(), lockedKeys) require.NoError(t, err) - require.Equal(t, int64(0), client.Exists(context.Background(), "locked:1").Val()) - require.Equal(t, int64(0), client.Exists(context.Background(), "locked:2").Val()) + require.Equal(t, int64(0), client.Exists(context.Background(), "node:1").Val()) + require.Equal(t, int64(0), client.Exists(context.Background(), "node:2").Val()) } func TestRedisLocker_ReleaseLock_ValueMismatch(t *testing.T) { @@ -97,19 +99,19 @@ func TestRedisLocker_ReleaseLock_ValueMismatch(t *testing.T) { // Set lock with a specific value storedValue := "stored-uuid-value" - require.NoError(t, client.Set(context.Background(), "locked:1", storedValue, 0).Err()) + require.NoError(t, client.Set(context.Background(), "node:1", storedValue, 0).Err()) // Try to release with wrong value lockedKeys := map[string]string{ - "locked:1": "wrong-uuid-value", + "node:1": "wrong-uuid-value", } - err := locker.ReleaseLock(context.Background(), lockedKeys) + err := 
locker.ReleaseLocks(context.Background(), lockedKeys) require.Error(t, err) require.Contains(t, err.Error(), "lock value mismatch") // Verify the lock is still there - require.Equal(t, int64(1), client.Exists(context.Background(), "locked:1").Val()) - val, _ := client.Get(context.Background(), "locked:1").Result() + require.Equal(t, int64(1), client.Exists(context.Background(), "node:1").Val()) + val, _ := client.Get(context.Background(), "node:1").Result() require.Equal(t, storedValue, val) } @@ -119,25 +121,32 @@ func TestRedisLocker_ReleaseLock_KeyNotExists(t *testing.T) { // Try to release a non-existent lock lockedKeys := map[string]string{ - "locked:999": "some-uuid-value", + "node:999": "some-uuid-value", } - err := locker.ReleaseLock(context.Background(), lockedKeys) + err := locker.ReleaseLocks(context.Background(), lockedKeys) // Should not error if key doesn't exist (just skip it) require.NoError(t, err) } -func TestRedisLocker_GetLockedNodes(t *testing.T) { +func TestRedisLocker_GetLockedResources(t *testing.T) { client := newTestRedisClient(t) locker := &RedisLocker{client: client, lockTimeout: time.Minute} // Set some locks - require.NoError(t, client.Set(context.Background(), "locked:1", "uuid-1", 0).Err()) - require.NoError(t, client.Set(context.Background(), "locked:2", "uuid-2", 0).Err()) - require.NoError(t, client.Set(context.Background(), "locked:99", "uuid-99", 0).Err()) + require.NoError(t, client.Set(context.Background(), "node:1", "uuid-1", 0).Err()) + require.NoError(t, client.Set(context.Background(), "node:2", "uuid-2", 0).Err()) + require.NoError(t, client.Set(context.Background(), "node:99", "uuid-99", 0).Err()) + require.NoError(t, client.Set(context.Background(), "cluster:5", "uuid-5", 0).Err()) - nodes, err := locker.GetLockedNodes(context.Background()) + // Get all node resources + resources, err := locker.GetLockedResources(context.Background(), "node:*") require.NoError(t, err) - require.ElementsMatch(t, []uint32{1, 2, 99}, 
nodes) + require.ElementsMatch(t, []string{"node:1", "node:2", "node:99"}, resources) + + // Get all resources + allResources, err := locker.GetLockedResources(context.Background(), "") + require.NoError(t, err) + require.Len(t, allResources, 4) } diff --git a/backend/internal/core/services/deployment_service.go b/backend/internal/core/services/deployment_service.go index a6c2a023..ad39c6c2 100644 --- a/backend/internal/core/services/deployment_service.go +++ b/backend/internal/core/services/deployment_service.go @@ -2,7 +2,6 @@ package services import ( "context" - "encoding/json" "errors" "fmt" cfg "kubecloud/internal/config" @@ -211,7 +210,7 @@ func (svc *DeploymentService) runWithQueue(queueName string, wf *ewf.Workflow) e return svc.ewfEngine.Run(svc.appCtx, *wf) } -func (svc *DeploymentService) handleDeploymentAction(userID int, workflowName string, state ewf.State, displayName string, metadata map[string]string, nodeIDs []uint32) (workflowID string, status ewf.WorkflowStatus, err error) { +func (svc *DeploymentService) handleDeploymentAction(userID int, workflowName string, state ewf.State, displayName string, metadata map[string]string) (workflowID string, status ewf.WorkflowStatus, err error) { _, span := svc.tracer.StartSpan(context.Background(), "handleDeploymentAction") defer span.End() @@ -236,26 +235,19 @@ func (svc *DeploymentService) handleDeploymentAction(userID int, workflowName st return "", "", err } - var lockedKeys map[string]string - if len(nodeIDs) > 0 { - lockedKeys, err = svc.locker.AcquireNodesLocks(svc.appCtx, nodeIDs) - if err != nil { + if err = svc.runWithQueue(queueName, &wf); err != nil { + lockedKeysValue, ok := wf.State["locked_keys"] + if !ok { + telemetry.RecordError(span, err) return "", "", err } - - lockedKeysJSON, err := json.Marshal(lockedKeys) - if err != nil { - return "", "", fmt.Errorf("failed to marshal locked keys: %w", err) + lockedKeys, ok := lockedKeysValue.(map[string]string) + if !ok || len(lockedKeys) == 0 { + 
telemetry.RecordError(span, err) + return "", "", err } - wf.Metadata["locked_keys"] = string(lockedKeysJSON) - - } - - if err = svc.runWithQueue(queueName, &wf); err != nil { - if len(nodeIDs) > 0 { - if releaseErr := svc.locker.ReleaseLock(svc.appCtx, lockedKeys); releaseErr != nil { - err = fmt.Errorf("%w: failed to release workflow lock: %v", err, releaseErr) - } + if releaseErr := svc.locker.ReleaseLocks(svc.appCtx, lockedKeys); releaseErr != nil { + err = fmt.Errorf("%w: failed to release workflow lock: %v", err, releaseErr) } telemetry.RecordError(span, err) return "", "", err @@ -269,13 +261,18 @@ func (svc *DeploymentService) handleDeploymentAction(userID int, workflowName st } func (svc *DeploymentService) AsyncDeployCluster(config statemanager.ClientConfig, cluster kubedeployer.Cluster) (string, ewf.WorkflowStatus, error) { - nodeIDs := make([]uint32, 0, len(cluster.Nodes)) + resourceKeys := make([]string, 0, len(cluster.Nodes)) for _, node := range cluster.Nodes { - nodeIDs = append(nodeIDs, node.NodeID) + resourceKeys = append(resourceKeys, fmt.Sprintf("%s%d", distributedlocks.NodeLockPrefix, node.NodeID)) + } + lockedKeys, err := svc.locker.AcquireLocks(svc.appCtx, resourceKeys) + if err != nil { + return "", "", err } state := ewf.State{ - "config": config, - "cluster": cluster, + "config": config, + "cluster": cluster, + "locked_keys": lockedKeys, } displayName := fmt.Sprintf("Deploying cluster %s", cluster.Name) @@ -283,7 +280,7 @@ func (svc *DeploymentService) AsyncDeployCluster(config statemanager.ClientConfi "cluster_name": cluster.Name, "node_count": strconv.Itoa(len(cluster.Nodes)), } - return svc.handleDeploymentAction(config.UserID, workflows.WorkflowDeployCluster, state, displayName, metadata, nodeIDs) + return svc.handleDeploymentAction(config.UserID, workflows.WorkflowDeployCluster, state, displayName, metadata) } func (svc *DeploymentService) AsyncDeleteCluster(config statemanager.ClientConfig, projectName string) (string, 
ewf.WorkflowStatus, error) { @@ -297,7 +294,7 @@ func (svc *DeploymentService) AsyncDeleteCluster(config statemanager.ClientConfi metadata := map[string]string{ "project_name": projectName, } - return svc.handleDeploymentAction(config.UserID, workflows.WorkflowDeleteCluster, state, displayName, metadata, nil) + return svc.handleDeploymentAction(config.UserID, workflows.WorkflowDeleteCluster, state, displayName, metadata) } func (svc *DeploymentService) AsyncDeleteAllClusters(config statemanager.ClientConfig) (string, ewf.WorkflowStatus, error) { @@ -307,22 +304,28 @@ func (svc *DeploymentService) AsyncDeleteAllClusters(config statemanager.ClientC } displayName := "Deleting all user clusters" - return svc.handleDeploymentAction(config.UserID, workflows.WorkflowDeleteAllClusters, state, displayName, nil, nil) + return svc.handleDeploymentAction(config.UserID, workflows.WorkflowDeleteAllClusters, state, displayName, nil) } func (svc *DeploymentService) AsyncAddNode(config statemanager.ClientConfig, cl kubedeployer.Cluster, node kubedeployer.Node) (string, ewf.WorkflowStatus, error) { + resourceKeys := []string{fmt.Sprintf("%s%d", distributedlocks.NodeLockPrefix, node.NodeID)} + lockedKeys, err := svc.locker.AcquireLocks(svc.appCtx, resourceKeys) + if err != nil { + return "", "", err + } state := ewf.State{ - "config": config, - "cluster": cl, - "node": node, + "config": config, + "cluster": cl, + "node": node, + "locked_keys": lockedKeys, } displayName := fmt.Sprintf("Adding node %s to cluster %s", node.Name, cl.Name) metadata := map[string]string{ "cluster_name": cl.Name, "node_name": node.Name, } - return svc.handleDeploymentAction(config.UserID, workflows.WorkflowAddNode, state, displayName, metadata, []uint32{node.NodeID}) + return svc.handleDeploymentAction(config.UserID, workflows.WorkflowAddNode, state, displayName, metadata) } func (svc *DeploymentService) AsyncRemoveNode(config statemanager.ClientConfig, cl kubedeployer.Cluster, nodeName string) (string, 
ewf.WorkflowStatus, error) { @@ -338,5 +341,5 @@ func (svc *DeploymentService) AsyncRemoveNode(config statemanager.ClientConfig, "cluster_name": cl.Name, "node_name": nodeName, } - return svc.handleDeploymentAction(config.UserID, workflows.WorkflowRemoveNode, state, displayName, metadata, nil) + return svc.handleDeploymentAction(config.UserID, workflows.WorkflowRemoveNode, state, displayName, metadata) } diff --git a/backend/internal/core/services/node_service.go b/backend/internal/core/services/node_service.go index 170a1dc9..e8f8606a 100644 --- a/backend/internal/core/services/node_service.go +++ b/backend/internal/core/services/node_service.go @@ -2,7 +2,6 @@ package services import ( "context" - "encoding/json" "errors" "fmt" cfg "kubecloud/internal/config" @@ -13,6 +12,7 @@ import ( "kubecloud/internal/infrastructure/gridclient" "kubecloud/internal/infrastructure/telemetry" "strconv" + "strings" proxyTypes "github.com/threefoldtech/tfgrid-sdk-go/grid-proxy/pkg/types" "github.com/xmonader/ewf" @@ -31,7 +31,7 @@ type NodeService struct { appCtx context.Context ewfEngine *ewf.Engine gridClient gridclient.GridClient - locker distributedlocks.DistributedLocks + locker distributedlocks.DistributedLocks tracer *telemetry.ServiceTracer } @@ -46,7 +46,7 @@ func NewNodeService( appCtx: appCtx, ewfEngine: ewfEngine, gridClient: gridClient, - locker: locker, + locker: locker, tracer: telemetry.NewServiceTracer("node_service"), } } @@ -229,6 +229,11 @@ func (svc *NodeService) AsyncReserveNode(userID int, userMnemonic string, nodeID if err != nil { return "", err } + resourceKeys := []string{fmt.Sprintf("%s%d", distributedlocks.NodeLockPrefix, nodeID)} + lockedKeys, err := svc.locker.AcquireLocks(svc.appCtx, resourceKeys) + if err != nil { + return "", err + } wf.State = map[string]interface{}{ "node_id": nodeID, @@ -237,24 +242,15 @@ func (svc *NodeService) AsyncReserveNode(userID int, userMnemonic string, nodeID "user_id": userID, "mnemonic": userMnemonic, }, + 
"locked_keys": lockedKeys, } if err = persistence.SetStateUserID(&wf, userID); err != nil { return "", err } - lockedKeys, err := svc.locker.AcquireNodesLocks(svc.appCtx, []uint32{nodeID}) - if err != nil { - return "", err - } - lockedKeysJSON, err := json.Marshal(lockedKeys) - if err != nil { - return "", fmt.Errorf("failed to marshal locked keys: %w", err) - } - wf.Metadata["locked_keys"] = string(lockedKeysJSON) - if err = svc.runWithQueue(queueName, &wf); err != nil { - if releaseErr := svc.locker.ReleaseLock(svc.appCtx, lockedKeys); releaseErr != nil { + if releaseErr := svc.locker.ReleaseLocks(svc.appCtx, lockedKeys); releaseErr != nil { err = fmt.Errorf("%w: failed to release workflow lock: %v", err, releaseErr) } return "", err @@ -315,14 +311,25 @@ func (svc *NodeService) runWithQueue(queueName string, wf *ewf.Workflow) error { } func (svc *NodeService) FilterLockedNodes(ctx context.Context, nodes []proxyTypes.Node) ([]proxyTypes.Node, error) { - lockedNodes, err := svc.locker.GetLockedNodes(ctx) + lockedResources, err := svc.locker.GetLockedResources(ctx, fmt.Sprintf("%s*", distributedlocks.NodeLockPrefix)) if err != nil { return nil, err } - lockedSet := make(map[uint32]bool, len(lockedNodes)) - for _, id := range lockedNodes { - lockedSet[id] = true + + // Convert locked resources to node IDs + lockedSet := make(map[uint32]bool) + for _, resource := range lockedResources { + nodeIDStr := strings.TrimPrefix(resource, distributedlocks.NodeLockPrefix) + if nodeIDStr == resource { + continue + } + nodeID, err := strconv.ParseUint(nodeIDStr, 10, 32) + if err != nil { + continue + } + lockedSet[uint32(nodeID)] = true } + unlockedNodes := make([]proxyTypes.Node, 0, len(nodes)) for _, node := range nodes { if lockedSet[uint32(node.NodeID)] { diff --git a/backend/internal/core/services/node_service_test.go b/backend/internal/core/services/node_service_test.go index a267af71..71d063bb 100644 --- a/backend/internal/core/services/node_service_test.go +++ 
b/backend/internal/core/services/node_service_test.go @@ -17,23 +17,23 @@ type MockDistributedLocks struct { mock.Mock } -func (m *MockDistributedLocks) GetLockedNodes(ctx context.Context) ([]uint32, error) { - args := m.Called(ctx) +func (m *MockDistributedLocks) GetLockedResources(ctx context.Context, keyPattern string) ([]string, error) { + args := m.Called(ctx, keyPattern) if args.Get(0) == nil { return nil, args.Error(1) } - return args.Get(0).([]uint32), args.Error(1) + return args.Get(0).([]string), args.Error(1) } -func (m *MockDistributedLocks) AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) (map[string]string, error) { - args := m.Called(ctx, nodeIDs) +func (m *MockDistributedLocks) AcquireLocks(ctx context.Context, resourceKeys []string) (map[string]string, error) { + args := m.Called(ctx, resourceKeys) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).(map[string]string), args.Error(1) } -func (m *MockDistributedLocks) ReleaseLock(ctx context.Context, lockedKeys map[string]string) error { +func (m *MockDistributedLocks) ReleaseLocks(ctx context.Context, lockedKeys map[string]string) error { args := m.Called(ctx, lockedKeys) return args.Error(0) } @@ -178,7 +178,7 @@ func TestNodeService_FilterLockedNodes_Success(t *testing.T) { {NodeID: 102}, } - mockLocker.On("GetLockedNodes", context.Background()).Return([]uint32{101}, nil) + mockLocker.On("GetLockedResources", context.Background(), "node:*").Return([]string{"node:101"}, nil) service := NodeService{ locker: mockLocker, @@ -198,7 +198,7 @@ func TestNodeService_FilterLockedNodes_Error(t *testing.T) { mockNodesRepo := new(mockUserNodesRepo) mockUserRepo := new(mockUserRepo) - mockLocker.On("GetLockedNodes", context.Background()).Return(nil, fmt.Errorf("locked nodes error")) + mockLocker.On("GetLockedResources", context.Background(), "node:*").Return(nil, fmt.Errorf("locked nodes error")) nodes := []proxyTypes.Node{ {NodeID: 100}, diff --git 
a/backend/internal/core/workflows/hooks.go b/backend/internal/core/workflows/hooks.go index cf22beeb..3d80b5d6 100644 --- a/backend/internal/core/workflows/hooks.go +++ b/backend/internal/core/workflows/hooks.go @@ -2,7 +2,6 @@ package workflows import ( "context" - "encoding/json" "errors" "fmt" @@ -247,19 +246,18 @@ func metricsFailureHook(metrics *metricsLib.Metrics) ewf.AfterWorkflowHook { func releaseLocksHook(locker distributedlocks.DistributedLocks) ewf.AfterWorkflowHook { return func(ctx context.Context, wf *ewf.Workflow, _ error) { log := logger.ForOperation("workflow", "release_locks").With().Str("workflow_name", wf.Name).Logger() - lockedKeys, ok := wf.Metadata["locked_keys"] - if !ok { + lockedKeys, err := getFromState[map[string]string](wf.State, "locked_keys") + if err != nil { + log.Error().Err(err).Msg("failed to get locks from state") return } - var lockedKeysJSON map[string]string - err := json.Unmarshal([]byte(lockedKeys), &lockedKeysJSON) - if err != nil { - log.Error().Err(err).Msg("failed to unmarshal locked keys") + if len(lockedKeys) == 0 { return } - if err := locker.ReleaseLock(ctx, lockedKeysJSON); err != nil { + if err := locker.ReleaseLocks(ctx, lockedKeys); err != nil { log.Error().Err(err).Msg("failed to release locks") return } + log.Info().Int("lock_count", len(lockedKeys)).Msg("successfully released locks") } }