From d0aa71f8a882ce74c432d81b99d985941768ec07 Mon Sep 17 00:00:00 2001 From: ldornele Date: Sat, 7 Mar 2026 20:30:07 -0300 Subject: [PATCH 1/6] HYPERFLEET-618: add PostgreSQL advisory locks for migration coordination --- cmd/hyperfleet-api/migrate/cmd.go | 4 +- docs/database.md | 40 +++ pkg/db/advisory_locks.go | 107 ++++++++ pkg/db/context.go | 94 +++++++ pkg/db/migrations.go | 21 ++ test/integration/advisory_locks_test.go | 321 ++++++++++++++++++++++++ 6 files changed, 585 insertions(+), 2 deletions(-) create mode 100644 pkg/db/advisory_locks.go create mode 100644 test/integration/advisory_locks_test.go diff --git a/cmd/hyperfleet-api/migrate/cmd.go b/cmd/hyperfleet-api/migrate/cmd.go index 89b8d34..0ad9a9e 100755 --- a/cmd/hyperfleet-api/migrate/cmd.go +++ b/cmd/hyperfleet-api/migrate/cmd.go @@ -53,11 +53,11 @@ func runMigrateWithError() error { } }() - if err := db.Migrate(connection.New(ctx)); err != nil { + // Use MigrateWithLock to prevent concurrent migrations from multiple pods + if err := db.MigrateWithLock(ctx, connection); err != nil { logger.WithError(ctx, err).Error("Migration failed") return err } - logger.Info(ctx, "Migration completed successfully") return nil } diff --git a/docs/database.md b/docs/database.md index e5b3d6a..f6f5e8b 100644 --- a/docs/database.md +++ b/docs/database.md @@ -61,6 +61,46 @@ Uses GORM AutoMigrate: - Additive (creates missing tables, columns, indexes) - Run via `./bin/hyperfleet-api migrate` +### Migration Coordination + +**Problem:** During rolling deployments, multiple pods attempt to run migrations simultaneously, causing race conditions and deployment failures. + +**Solution:** PostgreSQL advisory locks ensure exclusive migration execution. + +#### How It Works + +```go +// Only one pod/process acquires the lock and runs migrations +// Others wait until the lock is released +db.MigrateWithLock(ctx, factory) +``` + +**Implementation:** +1. Pod acquires advisory lock via `pg_advisory_xact_lock(hash("migrations"), hash("Migrations"))` +2. Lock holder runs migrations exclusively +3. Other pods block until lock is released +4. Lock automatically released on transaction commit + +**Key Features:** +- **Zero infrastructure overhead** - Uses native PostgreSQL locks +- **Automatic cleanup** - Locks released on transaction end or pod crash +- **Nested lock support** - Same lock can be acquired in nested contexts without deadlock +- **UUID-based ownership** - Only original acquirer can unlock + +#### Testing Concurrent Migrations + +Integration tests validate concurrent behavior: + +```bash +make test-integration # Runs TestConcurrentMigrations +``` + +**Test coverage:** +- `TestConcurrentMigrations` - Multiple pods running migrations simultaneously +- `TestAdvisoryLocksConcurrently` - Lock serialization under race conditions +- `TestAdvisoryLocksWithTransactions` - Lock + transaction interaction +- `TestAdvisoryLockBlocking` - Lock blocking behavior + ## Database Setup ```bash diff --git a/pkg/db/advisory_locks.go b/pkg/db/advisory_locks.go new file mode 100644 index 0000000..4fd4c8e --- /dev/null +++ b/pkg/db/advisory_locks.go @@ -0,0 +1,107 @@ +package db + +import ( + "context" + "errors" + "hash/fnv" + "time" + + "gorm.io/gorm" +) + +// LockType represents the type of advisory lock +type LockType string + +const ( + // Migrations lock type for database migrations + Migrations LockType = "Migrations" +) + +// AdvisoryLock represents a postgres advisory lock +// +// begin # start a Tx +// select pg_advisory_xact_lock(id, lockType) # obtain the lock (blocking) +// end # end the Tx and release the lock +// +// ownerUUID is a way to own the lock. Only the very first +// service call that owns the lock will have the correct ownerUUID. This is necessary +// to allow functions to call other service functions as part of the same lock (id, lockType). +type AdvisoryLock struct { + g2 *gorm.DB + txid int64 + ownerUUID *string + id *string + lockType *LockType + startTime time.Time +} + +// newAdvisoryLock constructs a new AdvisoryLock object. +func newAdvisoryLock(ctx context.Context, connection SessionFactory, ownerUUID *string, id *string, locktype *LockType) (*AdvisoryLock, error) { + if connection == nil { + return nil, errors.New("AdvisoryLock: connection factory is missing") + } + + // it requires a new DB session to start the advisory lock. + g2 := connection.New(ctx) + + // start a Tx to ensure gorm will obtain/release the lock using a same connection. + tx := g2.Begin() + if tx.Error != nil { + return nil, tx.Error + } + + // current transaction ID set by postgres. these are *not* distinct across time + // and do get reset after postgres performs "vacuuming" to reclaim used IDs. + var txid struct{ ID int64 } + err := tx.Raw("select txid_current() as id").Scan(&txid).Error + + return &AdvisoryLock{ + txid: txid.ID, + ownerUUID: ownerUUID, + id: id, + lockType: locktype, + g2: tx, + startTime: time.Now(), + }, err +} + +// lock calls select pg_advisory_xact_lock(id, lockType) to obtain the lock defined by (id, lockType). +// it is blocked if some other thread currently is holding the same lock (id, lockType). +// if blocked, it can be unblocked or timed out when overloaded. +func (l *AdvisoryLock) lock() error { + if l.g2 == nil { + return errors.New("AdvisoryLock: transaction is missing") + } + if l.id == nil { + return errors.New("AdvisoryLock: id is missing") + } + if l.lockType == nil { + return errors.New("AdvisoryLock: lockType is missing") + } + + idAsInt := hash(*l.id) + typeAsInt := hash(string(*l.lockType)) + err := l.g2.Exec("select pg_advisory_xact_lock(?, ?)", idAsInt, typeAsInt).Error + return err +} + +func (l *AdvisoryLock) unlock() error { + if l.g2 == nil { + return errors.New("AdvisoryLock: transaction is missing") + } + + // it ends the Tx and implicitly releases the lock. + err := l.g2.Commit().Error + l.g2 = nil + return err +} + +// hash string to int32 (postgres integer) +// https://pkg.go.dev/math#pkg-constants +// https://www.postgresql.org/docs/12/datatype-numeric.html +func hash(s string) int32 { + h := fnv.New32a() + h.Write([]byte(s)) + // Sum32() returns uint32. needs conversion. + return int32(h.Sum32()) +} diff --git a/pkg/db/context.go b/pkg/db/context.go index 06c5114..8727839 100755 --- a/pkg/db/context.go +++ b/pkg/db/context.go @@ -3,10 +3,33 @@ package db import ( "context" + "github.com/google/uuid" + dbContext "github.com/openshift-hyperfleet/hyperfleet-api/pkg/db/db_context" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/logger" ) +type advisoryLockKey string + +const ( + advisoryLock advisoryLockKey = "advisoryLock" +) + +type advisoryLockMap map[string]*AdvisoryLock + +func (m advisoryLockMap) key(id string, lockType LockType) string { + return id + ":" + string(lockType) +} + +func (m advisoryLockMap) get(id string, lockType LockType) (*AdvisoryLock, bool) { + lock, ok := m[m.key(id, lockType)] + return lock, ok +} + +func (m advisoryLockMap) set(id string, lockType LockType, lock *AdvisoryLock) { + m[m.key(id, lockType)] = lock +} + // NewContext returns a new context with transaction stored in it. // Upon error, the original context is still returned along with an error func NewContext(ctx context.Context, connection SessionFactory) (context.Context, error) { @@ -53,3 +76,74 @@ func MarkForRollback(ctx context.Context, err error) { transaction.SetRollbackFlag(true) logger.WithError(ctx, err).Info("Marked transaction for rollback") } + +// NewAdvisoryLockContext returns a new context with AdvisoryLock stored in it. +// Upon error, the original context is still returned along with an error +func NewAdvisoryLockContext(ctx context.Context, connection SessionFactory, id string, lockType LockType) (context.Context, string, error) { + // lockOwnerID will be different for every service function that attempts to start a lock. + // only the initial call in the stack must unlock. + // Unlock() will compare UUIDs and ensure only the top level call succeeds. + lockOwnerID := uuid.New().String() + + locks, found := ctx.Value(advisoryLock).(advisoryLockMap) + if found { + if _, ok := locks.get(id, lockType); ok { + return ctx, lockOwnerID, nil + } + } else { + locks = make(advisoryLockMap) + } + + lock, err := newAdvisoryLock(ctx, connection, &lockOwnerID, &id, &lockType) + if err != nil { + logger.WithError(ctx, err).Error("Failed to create advisory lock") + return ctx, lockOwnerID, err + } + + // obtain the advisory lock (blocking) + err = lock.lock() + if err != nil { + logger.WithError(ctx, err).Error("Failed to acquire advisory lock") + return ctx, lockOwnerID, err + } + + locks.set(id, lockType, lock) + + ctx = context.WithValue(ctx, advisoryLock, locks) + logger.With(ctx, "lock_id", id, "lock_type", lockType).Info("Acquired advisory lock") + + return ctx, lockOwnerID, nil +} + +// Unlock searches current locks and unlocks the one matching its owner id. +func Unlock(ctx context.Context, callerUUID string) context.Context { + locks, ok := ctx.Value(advisoryLock).(advisoryLockMap) + if !ok { + logger.Error(ctx, "Could not retrieve locks from context") + return ctx + } + + for k, lock := range locks { + if lock.ownerUUID == nil { + logger.With(ctx, "lock_id", lock.id).Warn("lockOwnerID could not be found in AdvisoryLock") + } else if *lock.ownerUUID == callerUUID { + lockID := "" + lockType := *lock.lockType + if lock.id != nil { + lockID = *lock.id + } + + if err := lock.unlock(); err != nil { + logger.With(ctx, "lock_id", lockID, "lock_type", lockType).WithError(err).Error("Could not unlock lock") + } else { + logger.With(ctx, "lock_id", lockID, "lock_type", lockType).Info("Unlocked lock") + } + delete(locks, k) + } else { + // the resolving UUID belongs to a service call that did *not* initiate the lock. + // it is ignored. + } + } + + return ctx +} diff --git a/pkg/db/migrations.go b/pkg/db/migrations.go index 63fa6e0..008b24c 100755 --- a/pkg/db/migrations.go +++ b/pkg/db/migrations.go @@ -24,6 +24,27 @@ func Migrate(g2 *gorm.DB) error { return nil } +// MigrateWithLock runs migrations with an advisory lock to prevent concurrent migrations +func MigrateWithLock(ctx context.Context, factory SessionFactory) error { + // Acquire advisory lock for migrations + ctx, lockOwnerID, err := NewAdvisoryLockContext(ctx, factory, "migrations", Migrations) + if err != nil { + logger.WithError(ctx, err).Error("Could not lock migrations") + return err + } + defer Unlock(ctx, lockOwnerID) + + // Run migrations with the locked context + g2 := factory.New(ctx) + if err := Migrate(g2); err != nil { + logger.WithError(ctx, err).Error("Could not migrate") + return err + } + + logger.Info(ctx, "Migration completed successfully") + return nil +} + // MigrateTo a specific migration will not seed the database, seeds are up to date with the latest // schema based on the most recent migration // This should be for testing purposes mainly diff --git a/test/integration/advisory_locks_test.go b/test/integration/advisory_locks_test.go new file mode 100644 index 0000000..07e6d64 --- /dev/null +++ b/test/integration/advisory_locks_test.go @@ -0,0 +1,321 @@ +package integration + +import ( + "context" + "math/rand" + "sync" + "testing" + "time" + + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/db" + "github.com/openshift-hyperfleet/hyperfleet-api/test" +) + +// TestAdvisoryLocksConcurrently validates that advisory locks properly serialize +// concurrent access to shared resources. This simulates a race condition where +// multiple threads try to access and modify the same variable. +func TestAdvisoryLocksConcurrently(t *testing.T) { + helper := test.NewHelper(t) + + total := 10 + var waiter sync.WaitGroup + waiter.Add(total) + + // Simulate a race condition where multiple threads are trying to access and modify the "total" var. + // The acquireLock func uses an advisory lock so the accesses to "total" should be properly serialized. + for i := 0; i < total; i++ { + go acquireLock(helper, &total, &waiter) + } + + // Wait for all goroutines to complete + waiter.Wait() + + // All goroutines should have decremented total by 1, resulting in 0 + if total != 0 { + t.Errorf("Expected total to be 0, got %d", total) + } +} + +func acquireLock(helper *test.Helper, total *int, waiter *sync.WaitGroup) { + ctx := context.Background() + + // Acquire advisory lock + ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "test-resource", db.Migrations) + if err != nil { + helper.T.Errorf("Failed to acquire lock: %v", err) + waiter.Done() + return + } + defer db.Unlock(ctx, lockOwnerID) + + // Pretend loading "total" from DB + initTotal := *total + + // Some slow work to increase the likelihood of race conditions + time.Sleep(20 * time.Millisecond) + + // Pretend saving "total" to DB + finalTotal := initTotal - 1 + *total = finalTotal + + waiter.Done() +} + +// TestAdvisoryLocksWithTransactions validates that advisory locks work correctly +// when combined with database transactions in various orders +func TestAdvisoryLocksWithTransactions(t *testing.T) { + helper := test.NewHelper(t) + + total := 10 + var waiter sync.WaitGroup + waiter.Add(total) + + for i := 0; i < total; i++ { + go acquireLockWithTransaction(helper, &total, &waiter) + } + + waiter.Wait() + + if total != 0 { + t.Errorf("Expected total to be 0, got %d", total) + } +} + +func acquireLockWithTransaction(helper *test.Helper, total *int, waiter *sync.WaitGroup) { + ctx := context.Background() + + // Lock and Tx can be stored within the same context. They should be independent of each other. + // It doesn't matter if a Tx coexists or not, nor does it matter if it occurs before or after the lock + r := rand.Intn(3) // no Tx if r == 2 + txBeforeLock := r == 0 + txAfterLock := r == 1 + + var dberr error + + // Randomly add Tx before lock to demonstrate it works + if txBeforeLock { + ctx, dberr = db.NewContext(ctx, helper.DBFactory) + if dberr != nil { + helper.T.Errorf("Failed to create transaction context: %v", dberr) + waiter.Done() + return + } + defer db.Resolve(ctx) + } + + // Acquire advisory lock + ctx, lockOwnerID, dberr := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "test-resource-tx", db.Migrations) + if dberr != nil { + helper.T.Errorf("Failed to acquire lock: %v", dberr) + waiter.Done() + return + } + defer db.Unlock(ctx, lockOwnerID) + + // Randomly add Tx after lock to demonstrate it works + if txAfterLock { + ctx, dberr = db.NewContext(ctx, helper.DBFactory) + if dberr != nil { + helper.T.Errorf("Failed to create transaction context: %v", dberr) + waiter.Done() + return + } + defer db.Resolve(ctx) + } + + // Pretend loading "total" from DB + initTotal := *total + + // Some slow work + time.Sleep(20 * time.Millisecond) + + // Pretend saving "total" to DB + finalTotal := initTotal - 1 + *total = finalTotal + + waiter.Done() +} + +// TestLocksAndExpectedWaits validates the behavior of advisory locks: +// - Nested locks with the same (id, lockType) should not create additional locks +// - Different (id, lockType) combinations should create separate locks +// - Unlocking should only affect the lock matching the owner ID +func TestLocksAndExpectedWaits(t *testing.T) { + helper := test.NewHelper(t) + + // Start lock + ctx := context.Background() + ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "system", db.Migrations) + if err != nil { + t.Fatalf("Failed to acquire lock: %v", err) + } + + // It should have 1 lock + g2 := helper.DBFactory.New(ctx) + var pgLocks []struct{ Granted bool } + g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) + if len(pgLocks) != 1 { + t.Errorf("Expected 1 lock, got %d", len(pgLocks)) + } + + // Successive locking should have no effect (nested lock with same id/type) + // Pretend this runs in a nested func + ctx, lockOwnerID2, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "system", db.Migrations) + if err != nil { + t.Fatalf("Failed to acquire nested lock: %v", err) + } + // It should still have 1 lock + pgLocks = nil + g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) + if len(pgLocks) != 1 { + t.Errorf("Expected 1 lock after nested acquire, got %d", len(pgLocks)) + } + + // Unlock should have no effect either (unlocking nested lock) + // Pretend this runs in the nested func + db.Unlock(ctx, lockOwnerID2) + // It should still have 1 lock + pgLocks = nil + g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) + if len(pgLocks) != 1 { + t.Errorf("Expected 1 lock after nested unlock, got %d", len(pgLocks)) + } + + // Lock on a different (id, lockType) should work + // Pretend this runs in a nested func + ctx, lockOwnerID3, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "diff_system", db.Migrations) + if err != nil { + t.Fatalf("Failed to acquire different lock: %v", err) + } + // It should have 2 locks + pgLocks = nil + g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) + if len(pgLocks) != 2 { + t.Errorf("Expected 2 locks, got %d", len(pgLocks)) + } + + // Pretend it releases the new lock in the nested func + db.Unlock(ctx, lockOwnerID3) + // It should have 1 lock + pgLocks = nil + g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) + if len(pgLocks) != 1 { + t.Errorf("Expected 1 lock after releasing different lock, got %d", len(pgLocks)) + } + + // Unlock the topmost lock + // Pretend it returns back to the parent func + db.Unlock(ctx, lockOwnerID) + // The lock should be gone + pgLocks = nil + g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) + if len(pgLocks) != 0 { + t.Errorf("Expected 0 locks after final unlock, got %d", len(pgLocks)) + } +} + +// TestConcurrentMigrations validates that the MigrateWithLock function +// properly serializes concurrent migration attempts, ensuring only one +// instance actually runs migrations at a time. +func TestConcurrentMigrations(t *testing.T) { + helper := test.NewHelper(t) + + // First, reset the database to a clean state + if err := helper.ResetDB(); err != nil { + t.Fatalf("Failed to reset database: %v", err) + } + + total := 5 + var waiter sync.WaitGroup + waiter.Add(total) + + // Track which goroutines successfully acquired the lock + var successCount int + var mu sync.Mutex + errors := make([]error, 0) + + // Simulate multiple pods trying to run migrations concurrently + for i := 0; i < total; i++ { + go func(id int) { + defer waiter.Done() + + ctx := context.Background() + err := db.MigrateWithLock(ctx, helper.DBFactory) + + mu.Lock() + defer mu.Unlock() + + if err != nil { + errors = append(errors, err) + } else { + successCount++ + } + }(i) + } + + waiter.Wait() + + // All migrations should succeed (they're idempotent) + if len(errors) > 0 { + t.Errorf("Expected no errors, but got %d: %v", len(errors), errors) + } + + // All goroutines should complete successfully + if successCount != total { + t.Errorf("Expected %d successful migrations, got %d", total, successCount) + } +} + +// TestAdvisoryLockBlocking validates that a second goroutine trying to acquire +// the same lock will block until the first goroutine releases it. +func TestAdvisoryLockBlocking(t *testing.T) { + helper := test.NewHelper(t) + + ctx := context.Background() + + // First goroutine acquires the lock + ctx1, lockOwnerID1, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "blocking-test", db.Migrations) + if err != nil { + t.Fatalf("Failed to acquire first lock: %v", err) + } + + // Track when the second goroutine acquires the lock + acquired := make(chan bool, 1) + released := make(chan bool, 1) + + // Second goroutine tries to acquire the same lock + go func() { + ctx2, lockOwnerID2, err := db.NewAdvisoryLockContext(context.Background(), helper.DBFactory, "blocking-test", db.Migrations) + if err != nil { + t.Errorf("Failed to acquire second lock: %v", err) + return + } + defer db.Unlock(ctx2, lockOwnerID2) + + acquired <- true + <-released // Wait for signal to release + }() + + // Give the second goroutine time to start waiting + time.Sleep(100 * time.Millisecond) + + // The second goroutine should still be blocked + select { + case <-acquired: + t.Error("Second goroutine acquired lock while first still holds it") + default: + // Expected: second goroutine is still blocked + } + + // Release the first lock + db.Unlock(ctx1, lockOwnerID1) + + // Now the second goroutine should acquire the lock + select { + case <-acquired: + // Expected: second goroutine acquired the lock + released <- true + case <-time.After(5 * time.Second): + t.Error("Second goroutine did not acquire lock after first was released") + } +} From eda930a4b0ea6f773422f25a7c81e02cb4751843 Mon Sep 17 00:00:00 2001 From: ldornele Date: Wed, 11 Mar 2026 01:34:45 -0300 Subject: [PATCH 2/6] HYPERFLEET-618: refactoring after code review feedback --- docs/database.md | 10 +- pkg/config/db.go | 31 ++- pkg/db/advisory_locks.go | 57 +++-- pkg/db/context.go | 29 ++- pkg/db/db_session/default.go | 4 + pkg/db/db_session/test.go | 4 + pkg/db/db_session/testcontainer.go | 4 + pkg/db/migrations.go | 2 +- pkg/db/mocks/session_factory.go | 4 + pkg/db/session.go | 1 + pkg/logger/fields.go | 3 + test/integration/advisory_locks_test.go | 279 +++++++++++++++++++++--- 12 files changed, 349 insertions(+), 79 deletions(-) diff --git a/docs/database.md b/docs/database.md index f6f5e8b..fa7f56e 100644 --- a/docs/database.md +++ b/docs/database.md @@ -76,14 +76,16 @@ db.MigrateWithLock(ctx, factory) ``` **Implementation:** -1. Pod acquires advisory lock via `pg_advisory_xact_lock(hash("migrations"), hash("Migrations"))` -2. Lock holder runs migrations exclusively -3. Other pods block until lock is released -4. Lock automatically released on transaction commit +1. Pod sets statement timeout (5 minutes) to prevent indefinite blocking +2. Pod acquires advisory lock via `pg_advisory_xact_lock(hash("migrations"), hash("Migrations"))` +3. Lock holder runs migrations exclusively +4. Other pods block until lock is released or timeout is reached +5. Lock automatically released on transaction commit **Key Features:** - **Zero infrastructure overhead** - Uses native PostgreSQL locks - **Automatic cleanup** - Locks released on transaction end or pod crash +- **Timeout protection** - 5-minute timeout prevents indefinite blocking if a pod hangs - **Nested lock support** - Same lock can be acquired in nested contexts without deadlock - **UUID-based ownership** - Only original acquirer can unlock diff --git a/pkg/config/db.go b/pkg/config/db.go index d6c9de9..0a5749b 100755 --- a/pkg/config/db.go +++ b/pkg/config/db.go @@ -11,10 +11,11 @@ import ( ) type DatabaseConfig struct { - Dialect string `json:"dialect"` - SSLMode string `json:"sslmode"` - Debug bool `json:"debug"` - MaxOpenConnections int `json:"max_connections"` + Dialect string `json:"dialect"` + SSLMode string `json:"sslmode"` + Debug bool `json:"debug"` + MaxOpenConnections int `json:"max_connections"` + AdvisoryLockTimeoutSeconds int `json:"advisory_lock_timeout_seconds"` Host string `json:"host"` Port int `json:"port"` @@ -32,10 +33,11 @@ type DatabaseConfig struct { func NewDatabaseConfig() *DatabaseConfig { return &DatabaseConfig{ - Dialect: "postgres", - SSLMode: "disable", - Debug: false, - MaxOpenConnections: 50, + Dialect: "postgres", + SSLMode: "disable", + Debug: false, + MaxOpenConnections: 50, + AdvisoryLockTimeoutSeconds: 300, // 5 minutes - prevents indefinite blocking on migrations HostFile: "secrets/db.host", PortFile: "secrets/db.port", @@ -59,6 +61,10 @@ func (c *DatabaseConfig) AddFlags(fs *pflag.FlagSet) { &c.MaxOpenConnections, "db-max-open-connections", c.MaxOpenConnections, "Maximum open DB connections for this instance", ) + fs.IntVar( + &c.AdvisoryLockTimeoutSeconds, "db-advisory-lock-timeout", c.AdvisoryLockTimeoutSeconds, + "Advisory lock timeout in seconds (prevents indefinite blocking during migrations)", + ) } // BindEnv reads configuration from environment variables @@ -72,6 +78,15 @@ func (c *DatabaseConfig) BindEnv(fs *pflag.FlagSet) { } } } + + if val := os.Getenv("DB_ADVISORY_LOCK_TIMEOUT"); val != "" { + if fs == nil || !fs.Changed("db-advisory-lock-timeout") { + timeout, err := strconv.Atoi(val) + if err == nil && timeout > 0 { + c.AdvisoryLockTimeoutSeconds = timeout + } + } + } } func (c *DatabaseConfig) ReadFiles() error { diff --git a/pkg/db/advisory_locks.go b/pkg/db/advisory_locks.go index 4fd4c8e..d43262c 100644 --- a/pkg/db/advisory_locks.go +++ b/pkg/db/advisory_locks.go @@ -3,9 +3,11 @@ package db import ( "context" "errors" + "fmt" "hash/fnv" "time" + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/logger" "gorm.io/gorm" ) @@ -15,6 +17,9 @@ type LockType string const ( // Migrations lock type for database migrations Migrations LockType = "Migrations" + + // MigrationsLockID is the advisory lock ID used for migration coordination + MigrationsLockID = "migrations" ) // AdvisoryLock represents a postgres advisory lock @@ -27,12 +32,12 @@ const ( // service call that owns the lock will have the correct ownerUUID. This is necessary // to allow functions to call other service functions as part of the same lock (id, lockType). type AdvisoryLock struct { - g2 *gorm.DB - txid int64 - ownerUUID *string - id *string - lockType *LockType - startTime time.Time + g2 *gorm.DB + ownerUUID *string + id *string + lockType *LockType + timeoutSeconds int + startTime time.Time } // newAdvisoryLock constructs a new AdvisoryLock object. @@ -50,24 +55,19 @@ func newAdvisoryLock(ctx context.Context, connection SessionFactory, ownerUUID * return nil, tx.Error } - // current transaction ID set by postgres. these are *not* distinct across time - // and do get reset after postgres performs "vacuuming" to reclaim used IDs. - var txid struct{ ID int64 } - err := tx.Raw("select txid_current() as id").Scan(&txid).Error - return &AdvisoryLock{ - txid: txid.ID, - ownerUUID: ownerUUID, - id: id, - lockType: locktype, - g2: tx, - startTime: time.Now(), - }, err + ownerUUID: ownerUUID, + id: id, + lockType: locktype, + timeoutSeconds: connection.GetAdvisoryLockTimeout(), + g2: tx, + startTime: time.Now(), + }, nil } // lock calls select pg_advisory_xact_lock(id, lockType) to obtain the lock defined by (id, lockType). -// it is blocked if some other thread currently is holding the same lock (id, lockType). -// if blocked, it can be unblocked or timed out when overloaded. +// It blocks until the lock is acquired or the statement timeout is reached. +// The timeout prevents indefinite blocking if a pod hangs while holding the lock. func (l *AdvisoryLock) lock() error { if l.g2 == nil { return errors.New("AdvisoryLock: transaction is missing") @@ -79,20 +79,35 @@ func (l *AdvisoryLock) lock() error { return errors.New("AdvisoryLock: lockType is missing") } + // Set statement timeout to prevent indefinite blocking. + // This is transaction-scoped (SET LOCAL), so it only affects this lock acquisition. + // Note: We cannot use parameter binding (?) for SET commands in PostgreSQL + timeoutMs := l.timeoutSeconds * 1000 + if err := l.g2.Exec(fmt.Sprintf("SET LOCAL statement_timeout = %d", timeoutMs)).Error; err != nil { + return err + } + idAsInt := hash(*l.id) typeAsInt := hash(string(*l.lockType)) err := l.g2.Exec("select pg_advisory_xact_lock(?, ?)", idAsInt, typeAsInt).Error return err } -func (l *AdvisoryLock) unlock() error { +func (l *AdvisoryLock) unlock(ctx context.Context) error { if l.g2 == nil { return errors.New("AdvisoryLock: transaction is missing") } + duration := time.Since(l.startTime) + // it ends the Tx and implicitly releases the lock. err := l.g2.Commit().Error l.g2 = nil + + if err == nil { + logger.With(ctx, logger.FieldLockDurationMs, duration.Milliseconds()).Info("Released advisory lock") + } + return err } diff --git a/pkg/db/context.go b/pkg/db/context.go index 8727839..2dec22f 100755 --- a/pkg/db/context.go +++ b/pkg/db/context.go @@ -78,7 +78,11 @@ func MarkForRollback(ctx context.Context, err error) { } // NewAdvisoryLockContext returns a new context with AdvisoryLock stored in it. -// Upon error, the original context is still returned along with an error +// Upon error, the original context is still returned along with an error. +// +// CONCURRENCY: The returned context must not be shared across goroutines that call +// NewAdvisoryLockContext or Unlock concurrently, as the internal lock map is not +// protected by a mutex. Each goroutine should derive its own context chain. func NewAdvisoryLockContext(ctx context.Context, connection SessionFactory, id string, lockType LockType) (context.Context, string, error) { // lockOwnerID will be different for every service function that attempts to start a lock. // only the initial call in the stack must unlock. @@ -104,39 +108,44 @@ func NewAdvisoryLockContext(ctx context.Context, connection SessionFactory, id s err = lock.lock() if err != nil { logger.WithError(ctx, err).Error("Failed to acquire advisory lock") + lock.g2.Rollback() // clean up the open transaction return ctx, lockOwnerID, err } locks.set(id, lockType, lock) ctx = context.WithValue(ctx, advisoryLock, locks) - logger.With(ctx, "lock_id", id, "lock_type", lockType).Info("Acquired advisory lock") + logger.With(ctx, logger.FieldLockID, id, logger.FieldLockType, lockType).Info("Acquired advisory lock") return ctx, lockOwnerID, nil } // Unlock searches current locks and unlocks the one matching its owner id. -func Unlock(ctx context.Context, callerUUID string) context.Context { +func Unlock(ctx context.Context, callerUUID string) { locks, ok := ctx.Value(advisoryLock).(advisoryLockMap) if !ok { logger.Error(ctx, "Could not retrieve locks from context") - return ctx + return } for k, lock := range locks { if lock.ownerUUID == nil { - logger.With(ctx, "lock_id", lock.id).Warn("lockOwnerID could not be found in AdvisoryLock") + logger.With(ctx, logger.FieldLockID, lock.id).Warn("lockOwnerID could not be found in AdvisoryLock") } else if *lock.ownerUUID == callerUUID { lockID := "" - lockType := *lock.lockType + lockType := LockType("") + if lock.id != nil { lockID = *lock.id } + if lock.lockType != nil { + lockType = *lock.lockType + } - if err := lock.unlock(); err != nil { - logger.With(ctx, "lock_id", lockID, "lock_type", lockType).WithError(err).Error("Could not unlock lock") + if err := lock.unlock(ctx); err != nil { + logger.With(ctx, logger.FieldLockID, lockID, logger.FieldLockType, lockType).WithError(err).Error("Could not unlock lock") } else { - logger.With(ctx, "lock_id", lockID, "lock_type", lockType).Info("Unlocked lock") + logger.With(ctx, logger.FieldLockID, lockID, logger.FieldLockType, lockType).Info("Unlocked lock") } delete(locks, k) } else { @@ -144,6 +153,4 @@ func Unlock(ctx context.Context, callerUUID string) context.Context { // it is ignored. } } - - return ctx } diff --git a/pkg/db/db_session/default.go b/pkg/db/db_session/default.go index f6b1d41..c00a6ff 100755 --- a/pkg/db/db_session/default.go +++ b/pkg/db/db_session/default.go @@ -188,3 +188,7 @@ func (f *Default) ReconfigureLogger(level gormlogger.LogLevel) { newLogger := logger.NewGormLogger(level, slowQueryThreshold) f.g2.Logger = newLogger } + +func (f *Default) GetAdvisoryLockTimeout() int { + return f.config.AdvisoryLockTimeoutSeconds +} diff --git a/pkg/db/db_session/test.go b/pkg/db/db_session/test.go index 9e17694..a5c9f87 100755 --- a/pkg/db/db_session/test.go +++ b/pkg/db/db_session/test.go @@ -232,3 +232,7 @@ func (f *Test) ReconfigureLogger(level gormlogger.LogLevel) { newLogger := logger.NewGormLogger(level, slowQueryThreshold) f.g2.Logger = newLogger } + +func (f *Test) GetAdvisoryLockTimeout() int { + return f.config.AdvisoryLockTimeoutSeconds +} diff --git a/pkg/db/db_session/testcontainer.go b/pkg/db/db_session/testcontainer.go index 36f4b91..f6ad2cd 100755 --- a/pkg/db/db_session/testcontainer.go +++ b/pkg/db/db_session/testcontainer.go @@ -224,3 +224,7 @@ func (f *Testcontainer) ReconfigureLogger(level gormlogger.LogLevel) { newLogger := logger.NewGormLogger(level, slowQueryThreshold) f.g2.Logger = newLogger } + +func (f *Testcontainer) GetAdvisoryLockTimeout() int { + return f.config.AdvisoryLockTimeoutSeconds +} diff --git a/pkg/db/migrations.go b/pkg/db/migrations.go index 008b24c..5d647a7 100755 --- a/pkg/db/migrations.go +++ b/pkg/db/migrations.go @@ -27,7 +27,7 @@ func Migrate(g2 *gorm.DB) error { // MigrateWithLock runs migrations with an advisory lock to prevent concurrent migrations func MigrateWithLock(ctx context.Context, factory SessionFactory) error { // Acquire advisory lock for migrations - ctx, lockOwnerID, err := NewAdvisoryLockContext(ctx, factory, "migrations", Migrations) + ctx, lockOwnerID, err := NewAdvisoryLockContext(ctx, factory, MigrationsLockID, Migrations) if err != nil { logger.WithError(ctx, err).Error("Could not lock migrations") return err diff --git a/pkg/db/mocks/session_factory.go b/pkg/db/mocks/session_factory.go index 18f8396..7bfddb3 100755 --- a/pkg/db/mocks/session_factory.go +++ b/pkg/db/mocks/session_factory.go @@ -77,3 +77,7 @@ func (m *MockSessionFactory) ResetDB() { func (m *MockSessionFactory) NewListener(ctx context.Context, channel string, callback func(id string)) { // Mock implementation - does nothing } + +func (m *MockSessionFactory) GetAdvisoryLockTimeout() int { + return 300 // 5 minutes default +} diff --git a/pkg/db/session.go b/pkg/db/session.go index a124751..5d9dbbb 100755 --- a/pkg/db/session.go +++ b/pkg/db/session.go @@ -17,4 +17,5 @@ type SessionFactory interface { Close() error ResetDB() NewListener(ctx context.Context, channel string, callback func(id string)) + GetAdvisoryLockTimeout() int } diff --git a/pkg/logger/fields.go b/pkg/logger/fields.go index 8c5e5a0..8c0d45f 100644 --- a/pkg/logger/fields.go +++ b/pkg/logger/fields.go @@ -31,6 +31,9 @@ const ( FieldConnectionString = "connection_string" FieldTable = "table" FieldChannel = "channel" + FieldLockID = "lock_id" + FieldLockType = "lock_type" + FieldLockDurationMs = "lock_duration_ms" // Note: transaction_id is a context field (see context.go) ) diff --git a/test/integration/advisory_locks_test.go b/test/integration/advisory_locks_test.go index 07e6d64..a6bb06e 100644 --- a/test/integration/advisory_locks_test.go +++ b/test/integration/advisory_locks_test.go @@ -2,86 +2,129 @@ package integration import ( "context" + "fmt" "math/rand" "sync" "testing" "time" + "gorm.io/gorm" + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/db" "github.com/openshift-hyperfleet/hyperfleet-api/test" ) // TestAdvisoryLocksConcurrently validates that advisory locks properly serialize -// concurrent access to shared resources. This simulates a race condition where -// multiple threads try to access and modify the same variable. +// concurrent access to shared resources. This test uses actual database operations +// to prove the lock prevents race conditions at the database level. func TestAdvisoryLocksConcurrently(t *testing.T) { helper := test.NewHelper(t) + // Create a counter table and initialize to 0 + g2 := helper.DBFactory.New(context.Background()) + if err := g2.Exec("CREATE TABLE IF NOT EXISTS lock_test_counter (id INTEGER PRIMARY KEY, value INTEGER)").Error; err != nil { + t.Fatalf("Failed to create counter table: %v", err) + } + if err := g2.Exec("INSERT INTO lock_test_counter (id, value) VALUES (1, 0)").Error; err != nil { + t.Fatalf("Failed to initialize counter: %v", err) + } + defer g2.Exec("DROP TABLE IF EXISTS lock_test_counter") + total := 10 var waiter sync.WaitGroup waiter.Add(total) - // Simulate a race condition where multiple threads are trying to access and modify the "total" var. - // The acquireLock func uses an advisory lock so the accesses to "total" should be properly serialized. + // Simulate a race condition where multiple threads are trying to access and modify the counter. + // The acquireLock func uses an advisory lock so the accesses should be properly serialized. for i := 0; i < total; i++ { - go acquireLock(helper, &total, &waiter) + go acquireLock(helper, &waiter) } // Wait for all goroutines to complete waiter.Wait() - // All goroutines should have decremented total by 1, resulting in 0 - if total != 0 { - t.Errorf("Expected total to be 0, got %d", total) + // All goroutines should have incremented the counter by 1, resulting in 10 + var finalValue int + if err := g2.Raw("SELECT value FROM lock_test_counter WHERE id = 1").Scan(&finalValue).Error; err != nil { + t.Fatalf("Failed to read final counter value: %v", err) + } + if finalValue != total { + t.Errorf("Expected counter to be %d, got %d", total, finalValue) } } -func acquireLock(helper *test.Helper, total *int, waiter *sync.WaitGroup) { +func acquireLock(helper *test.Helper, waiter *sync.WaitGroup) { + defer waiter.Done() + ctx := context.Background() // Acquire advisory lock ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "test-resource", db.Migrations) if err != nil { helper.T.Errorf("Failed to acquire lock: %v", err) - waiter.Done() return } defer db.Unlock(ctx, lockOwnerID) - // Pretend loading "total" from DB - initTotal := *total + g2 := helper.DBFactory.New(ctx) + + // Read current value from database + var currentValue int + if err := g2.Raw("SELECT value FROM lock_test_counter WHERE id = 1").Scan(¤tValue).Error; err != nil { + helper.T.Errorf("Failed to read counter: %v", err) + return + } // Some slow work to increase the likelihood of race conditions time.Sleep(20 * time.Millisecond) - // Pretend saving "total" to DB - finalTotal := initTotal - 1 - *total = finalTotal - - waiter.Done() + // Increment and save to database + newValue := currentValue + 1 + if err := g2.Exec("UPDATE lock_test_counter SET value = ? WHERE id = 1", newValue).Error; err != nil { + helper.T.Errorf("Failed to update counter: %v", err) + return + } } // TestAdvisoryLocksWithTransactions validates that advisory locks work correctly -// when combined with database transactions in various orders +// when combined with database transactions in various orders. Uses actual database +// operations to prove serialization. func TestAdvisoryLocksWithTransactions(t *testing.T) { helper := test.NewHelper(t) + // Create a counter table and initialize to 0 + g2 := helper.DBFactory.New(context.Background()) + if err := g2.Exec("CREATE TABLE IF NOT EXISTS lock_test_counter_tx (id INTEGER PRIMARY KEY, value INTEGER)").Error; err != nil { + t.Fatalf("Failed to create counter table: %v", err) + } + if err := g2.Exec("INSERT INTO lock_test_counter_tx (id, value) VALUES (1, 0)").Error; err != nil { + t.Fatalf("Failed to initialize counter: %v", err) + } + defer g2.Exec("DROP TABLE IF EXISTS lock_test_counter_tx") + total := 10 var waiter sync.WaitGroup waiter.Add(total) for i := 0; i < total; i++ { - go acquireLockWithTransaction(helper, &total, &waiter) + go acquireLockWithTransaction(helper, &waiter) } waiter.Wait() - if total != 0 { - t.Errorf("Expected total to be 0, got %d", total) + // All goroutines should have incremented the counter by 1, resulting in 10 + var finalValue int + if err := g2.Raw("SELECT value FROM lock_test_counter_tx WHERE id = 1").Scan(&finalValue).Error; err != nil { + t.Fatalf("Failed to read final counter value: %v", err) + } + if finalValue != total { + t.Errorf("Expected counter to be %d, got %d", total, finalValue) } } -func acquireLockWithTransaction(helper *test.Helper, total *int, waiter *sync.WaitGroup) { +func acquireLockWithTransaction(helper *test.Helper, waiter *sync.WaitGroup) { + defer waiter.Done() + ctx := context.Background() // Lock and Tx can be stored within the same context. They should be independent of each other. @@ -97,7 +140,6 @@ func acquireLockWithTransaction(helper *test.Helper, total *int, waiter *sync.Wa ctx, dberr = db.NewContext(ctx, helper.DBFactory) if dberr != nil { helper.T.Errorf("Failed to create transaction context: %v", dberr) - waiter.Done() return } defer db.Resolve(ctx) @@ -107,7 +149,6 @@ func acquireLockWithTransaction(helper *test.Helper, total *int, waiter *sync.Wa ctx, lockOwnerID, dberr := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "test-resource-tx", db.Migrations) if dberr != nil { helper.T.Errorf("Failed to acquire lock: %v", dberr) - waiter.Done() return } defer db.Unlock(ctx, lockOwnerID) @@ -117,23 +158,29 @@ func acquireLockWithTransaction(helper *test.Helper, total *int, waiter *sync.Wa ctx, dberr = db.NewContext(ctx, helper.DBFactory) if dberr != nil { helper.T.Errorf("Failed to create transaction context: %v", dberr) - waiter.Done() return } defer db.Resolve(ctx) } - // Pretend loading "total" from DB - initTotal := *total + g2 := helper.DBFactory.New(ctx) + + // Read current value from database + var currentValue int + if err := g2.Raw("SELECT value FROM lock_test_counter_tx WHERE id = 1").Scan(¤tValue).Error; err != nil { + helper.T.Errorf("Failed to read counter: %v", err) + return + } // Some slow work time.Sleep(20 * time.Millisecond) - // Pretend saving "total" to DB - finalTotal := initTotal - 1 - *total = finalTotal - - waiter.Done() + // Increment and save to database + newValue := currentValue + 1 + if err := g2.Exec("UPDATE lock_test_counter_tx SET value = ? WHERE id = 1", newValue).Error; err != nil { + helper.T.Errorf("Failed to update counter: %v", err) + return + } } // TestLocksAndExpectedWaits validates the behavior of advisory locks: @@ -296,8 +343,27 @@ func TestAdvisoryLockBlocking(t *testing.T) { <-released // Wait for signal to release }() - // Give the second goroutine time to start waiting - time.Sleep(100 * time.Millisecond) + // Wait for the second goroutine to be actively waiting on the lock + // by polling pg_locks for a non-granted advisory lock. + // This is more reliable than sleep, especially in slow CI environments. + g2 := helper.DBFactory.New(ctx) + waitingForLock := false + for i := 0; i < 50; i++ { // Poll for up to 5 seconds (50 * 100ms) + var waitingLocks []struct{ Granted bool } + if err := g2.Raw("SELECT granted FROM pg_locks WHERE locktype = 'advisory' AND granted = false").Scan(&waitingLocks).Error; err != nil { + t.Errorf("Failed to query pg_locks: %v", err) + break + } + if len(waitingLocks) > 0 { + waitingForLock = true + break + } + time.Sleep(100 * time.Millisecond) + } + + if !waitingForLock { + t.Fatal("Second goroutine did not reach the lock waiting state within timeout") + } // The second goroutine should still be blocked select { @@ -319,3 +385,148 @@ func TestAdvisoryLockBlocking(t *testing.T) { t.Error("Second goroutine did not acquire lock after first was released") } } + +// TestAdvisoryLockContextCancellation verifies that context cancellation properly +// terminates a waiting advisory lock acquisition. The context is passed through +// connection.New(ctx) and affects the blocking pg_advisory_xact_lock SQL call. +func TestAdvisoryLockContextCancellation(t *testing.T) { + helper := test.NewHelper(t) + + ctx := context.Background() + + // First goroutine acquires the lock + ctx1, lockOwnerID1, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "cancel-test", db.Migrations) + if err != nil { + t.Fatalf("Failed to acquire first lock: %v", err) + } + defer db.Unlock(ctx1, lockOwnerID1) + + // Track when the second goroutine gets cancelled + gotCancelError := make(chan bool, 1) + + // Create a cancellable context for the second goroutine + ctx2, cancel := context.WithCancel(context.Background()) + + // Second goroutine tries to acquire the same lock with cancellable context + go func() { + _, _, err := db.NewAdvisoryLockContext(ctx2, helper.DBFactory, "cancel-test", db.Migrations) + if err != nil { + // Expected: context cancellation causes "canceling statement due to user request" + t.Logf("Second goroutine got error (expected): %v", err) + gotCancelError <- true + return + } + t.Error("Second goroutine acquired lock despite context cancellation (unexpected)") + }() + + // Wait for the second goroutine to be actively waiting on the lock + g2 := helper.DBFactory.New(ctx) + waitingForLock := false + for i := 0; i < 50; i++ { + var waitingLocks []struct{ Granted bool } + if err := g2.Raw("SELECT granted FROM pg_locks WHERE locktype = 'advisory' AND granted = false").Scan(&waitingLocks).Error; err != nil { + t.Errorf("Failed to query pg_locks: %v", err) + break + } + if len(waitingLocks) > 0 { + waitingForLock = true + break + } + time.Sleep(100 * time.Millisecond) + } + + if !waitingForLock { + t.Fatal("Second goroutine did not reach the lock waiting state within timeout") + } + + // Cancel the context while the second goroutine is waiting + cancel() + + // The second goroutine should exit with a cancellation error + select { + case <-gotCancelError: + // Expected: context cancellation terminates the lock acquisition + t.Log("Confirmed: context cancellation properly terminates waiting advisory lock") + case <-time.After(2 * time.Second): + t.Error("Second goroutine did not exit after context cancellation within timeout") + } +} + +// TestMigrationFailureUnderLock validates that when a migration fails while holding +// the advisory lock, the lock is properly released via defer, allowing other waiters +// to proceed. This tests the error path and cleanup behavior. +func TestMigrationFailureUnderLock(t *testing.T) { + helper := test.NewHelper(t) + + // Reset database to clean state + if err := helper.ResetDB(); err != nil { + t.Fatalf("Failed to reset database: %v", err) + } + + // Track results + var mu sync.Mutex + successCount := 0 + failureCount := 0 + var wg sync.WaitGroup + + // Create a failing migration function + failingMigration := func(g2 *gorm.DB) error { + return fmt.Errorf("simulated migration failure") + } + + // First goroutine: acquire lock and fail migration + wg.Add(1) + go func() { + defer wg.Done() + + ctx := context.Background() + ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "migration-fail-test", db.Migrations) + if err != nil { + t.Errorf("Failed to acquire lock: %v", err) + return + } + defer db.Unlock(ctx, lockOwnerID) + + // Simulate migration failure + if err := failingMigration(helper.DBFactory.New(ctx)); err != nil { + mu.Lock() + failureCount++ + mu.Unlock() + } + // Lock should be released via defer even though migration failed + }() + + // Give first goroutine time to acquire lock and fail + time.Sleep(100 * time.Millisecond) + + // Second goroutine: should be able to acquire lock after first fails + wg.Add(1) + go func() { + defer wg.Done() + + ctx := context.Background() + ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "migration-fail-test", db.Migrations) + if err != nil { + t.Errorf("Failed to acquire lock after failure: %v", err) + return + } + defer db.Unlock(ctx, lockOwnerID) + + // This one succeeds + mu.Lock() + successCount++ + mu.Unlock() + }() + + wg.Wait() + + // Verify both completed + if failureCount != 1 { + t.Errorf("Expected 1 failure, got %d", failureCount) + } + if successCount != 1 { + t.Errorf("Expected 1 success, got %d", successCount) + } + + t.Log("Confirmed: lock properly released after migration failure, allowing subsequent operations") +} From 64c865ae6768e0ef901223e3965f92e207233f95 Mon Sep 17 00:00:00 2001 From: ldornele Date: Wed, 11 Mar 2026 02:39:52 -0300 Subject: [PATCH 3/6] HYPERFLEET-618: refactoring suggested by coderabbitai --- pkg/db/context.go | 4 +- test/integration/advisory_locks_test.go | 281 +++++++++--------------- 2 files changed, 111 insertions(+), 174 deletions(-) diff --git a/pkg/db/context.go b/pkg/db/context.go index 2dec22f..8094a9f 100755 --- a/pkg/db/context.go +++ b/pkg/db/context.go @@ -144,9 +144,9 @@ func Unlock(ctx context.Context, callerUUID string) { if err := lock.unlock(ctx); err != nil { logger.With(ctx, logger.FieldLockID, lockID, logger.FieldLockType, lockType).WithError(err).Error("Could not unlock lock") - } else { - logger.With(ctx, logger.FieldLockID, lockID, logger.FieldLockType, lockType).Info("Unlocked lock") + continue } + logger.With(ctx, logger.FieldLockID, lockID, logger.FieldLockType, lockType).Info("Unlocked lock") delete(locks, k) } else { // the resolving UUID belongs to a service call that did *not* initiate the lock. diff --git a/test/integration/advisory_locks_test.go b/test/integration/advisory_locks_test.go index a6bb06e..fcf4723 100644 --- a/test/integration/advisory_locks_test.go +++ b/test/integration/advisory_locks_test.go @@ -2,12 +2,15 @@ package integration import ( "context" + "errors" "fmt" "math/rand" + "strings" "sync" "testing" "time" + . "github.com/onsi/gomega" "gorm.io/gorm" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/db" @@ -18,16 +21,15 @@ import ( // concurrent access to shared resources. This test uses actual database operations // to prove the lock prevents race conditions at the database level. func TestAdvisoryLocksConcurrently(t *testing.T) { - helper := test.NewHelper(t) + h, _ := test.RegisterIntegration(t) // Create a counter table and initialize to 0 - g2 := helper.DBFactory.New(context.Background()) - if err := g2.Exec("CREATE TABLE IF NOT EXISTS lock_test_counter (id INTEGER PRIMARY KEY, value INTEGER)").Error; err != nil { - t.Fatalf("Failed to create counter table: %v", err) - } - if err := g2.Exec("INSERT INTO lock_test_counter (id, value) VALUES (1, 0)").Error; err != nil { - t.Fatalf("Failed to initialize counter: %v", err) - } + g2 := h.DBFactory.New(context.Background()) + err := g2.Exec("CREATE TABLE IF NOT EXISTS lock_test_counter (id INTEGER PRIMARY KEY, value INTEGER)").Error + Expect(err).NotTo(HaveOccurred(), "Failed to create counter table") + + err = g2.Exec("INSERT INTO lock_test_counter (id, value) VALUES (1, 0)").Error + Expect(err).NotTo(HaveOccurred(), "Failed to initialize counter") defer g2.Exec("DROP TABLE IF EXISTS lock_test_counter") total := 10 @@ -37,7 +39,7 @@ func TestAdvisoryLocksConcurrently(t *testing.T) { // Simulate a race condition where multiple threads are trying to access and modify the counter. // The acquireLock func uses an advisory lock so the accesses should be properly serialized. for i := 0; i < total; i++ { - go acquireLock(helper, &waiter) + go acquireLock(h, &waiter) } // Wait for all goroutines to complete @@ -45,61 +47,50 @@ func TestAdvisoryLocksConcurrently(t *testing.T) { // All goroutines should have incremented the counter by 1, resulting in 10 var finalValue int - if err := g2.Raw("SELECT value FROM lock_test_counter WHERE id = 1").Scan(&finalValue).Error; err != nil { - t.Fatalf("Failed to read final counter value: %v", err) - } - if finalValue != total { - t.Errorf("Expected counter to be %d, got %d", total, finalValue) - } + err = g2.Raw("SELECT value FROM lock_test_counter WHERE id = 1").Scan(&finalValue).Error + Expect(err).NotTo(HaveOccurred(), "Failed to read final counter value") + Expect(finalValue).To(Equal(total), "Counter should equal total") } -func acquireLock(helper *test.Helper, waiter *sync.WaitGroup) { +func acquireLock(h *test.Helper, waiter *sync.WaitGroup) { defer waiter.Done() ctx := context.Background() // Acquire advisory lock - ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "test-resource", db.Migrations) - if err != nil { - helper.T.Errorf("Failed to acquire lock: %v", err) - return - } + ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, h.DBFactory, "test-resource", db.Migrations) + Expect(err).NotTo(HaveOccurred(), "Failed to acquire lock") defer db.Unlock(ctx, lockOwnerID) - g2 := helper.DBFactory.New(ctx) + g2 := h.DBFactory.New(ctx) // Read current value from database var currentValue int - if err := g2.Raw("SELECT value FROM lock_test_counter WHERE id = 1").Scan(¤tValue).Error; err != nil { - helper.T.Errorf("Failed to read counter: %v", err) - return - } + err = g2.Raw("SELECT value FROM lock_test_counter WHERE id = 1").Scan(¤tValue).Error + Expect(err).NotTo(HaveOccurred(), "Failed to read counter") // Some slow work to increase the likelihood of race conditions time.Sleep(20 * time.Millisecond) // Increment and save to database newValue := currentValue + 1 - if err := g2.Exec("UPDATE lock_test_counter SET value = ? WHERE id = 1", newValue).Error; err != nil { - helper.T.Errorf("Failed to update counter: %v", err) - return - } + err = g2.Exec("UPDATE lock_test_counter SET value = ? WHERE id = 1", newValue).Error + Expect(err).NotTo(HaveOccurred(), "Failed to update counter") } // TestAdvisoryLocksWithTransactions validates that advisory locks work correctly // when combined with database transactions in various orders. Uses actual database // operations to prove serialization. func TestAdvisoryLocksWithTransactions(t *testing.T) { - helper := test.NewHelper(t) + h, _ := test.RegisterIntegration(t) // Create a counter table and initialize to 0 - g2 := helper.DBFactory.New(context.Background()) - if err := g2.Exec("CREATE TABLE IF NOT EXISTS lock_test_counter_tx (id INTEGER PRIMARY KEY, value INTEGER)").Error; err != nil { - t.Fatalf("Failed to create counter table: %v", err) - } - if err := g2.Exec("INSERT INTO lock_test_counter_tx (id, value) VALUES (1, 0)").Error; err != nil { - t.Fatalf("Failed to initialize counter: %v", err) - } + g2 := h.DBFactory.New(context.Background()) + err := g2.Exec("CREATE TABLE IF NOT EXISTS lock_test_counter_tx (id INTEGER PRIMARY KEY, value INTEGER)").Error + Expect(err).NotTo(HaveOccurred(), "Failed to create counter table") + + err = g2.Exec("INSERT INTO lock_test_counter_tx (id, value) VALUES (1, 0)").Error + Expect(err).NotTo(HaveOccurred(), "Failed to initialize counter") defer g2.Exec("DROP TABLE IF EXISTS lock_test_counter_tx") total := 10 @@ -107,22 +98,19 @@ func TestAdvisoryLocksWithTransactions(t *testing.T) { waiter.Add(total) for i := 0; i < total; i++ { - go acquireLockWithTransaction(helper, &waiter) + go acquireLockWithTransaction(h, &waiter) } waiter.Wait() // All goroutines should have incremented the counter by 1, resulting in 10 var finalValue int - if err := g2.Raw("SELECT value FROM lock_test_counter_tx WHERE id = 1").Scan(&finalValue).Error; err != nil { - t.Fatalf("Failed to read final counter value: %v", err) - } - if finalValue != total { - t.Errorf("Expected counter to be %d, got %d", total, finalValue) - } + err = g2.Raw("SELECT value FROM lock_test_counter_tx WHERE id = 1").Scan(&finalValue).Error + Expect(err).NotTo(HaveOccurred(), "Failed to read final counter value") + Expect(finalValue).To(Equal(total), "Counter should equal total") } -func acquireLockWithTransaction(helper *test.Helper, waiter *sync.WaitGroup) { +func acquireLockWithTransaction(h *test.Helper, waiter *sync.WaitGroup) { defer waiter.Done() ctx := context.Background() @@ -137,50 +125,37 @@ func acquireLockWithTransaction(helper *test.Helper, waiter *sync.WaitGroup) { // Randomly add Tx before lock to demonstrate it works if txBeforeLock { - ctx, dberr = db.NewContext(ctx, helper.DBFactory) - if dberr != nil { - helper.T.Errorf("Failed to create transaction context: %v", dberr) - return - } + ctx, dberr = db.NewContext(ctx, h.DBFactory) + Expect(dberr).NotTo(HaveOccurred(), "Failed to create transaction context") defer db.Resolve(ctx) } // Acquire advisory lock - ctx, lockOwnerID, dberr := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "test-resource-tx", db.Migrations) - if dberr != nil { - helper.T.Errorf("Failed to acquire lock: %v", dberr) - return - } + ctx, lockOwnerID, dberr := db.NewAdvisoryLockContext(ctx, h.DBFactory, "test-resource-tx", db.Migrations) + Expect(dberr).NotTo(HaveOccurred(), "Failed to acquire lock") defer db.Unlock(ctx, lockOwnerID) // Randomly add Tx after lock to demonstrate it works if txAfterLock { - ctx, dberr = db.NewContext(ctx, helper.DBFactory) - if dberr != nil { - helper.T.Errorf("Failed to create transaction context: %v", dberr) - return - } + ctx, dberr = db.NewContext(ctx, h.DBFactory) + Expect(dberr).NotTo(HaveOccurred(), "Failed to create transaction context") defer db.Resolve(ctx) } - g2 := helper.DBFactory.New(ctx) + g2 := h.DBFactory.New(ctx) // Read current value from database var currentValue int - if err := g2.Raw("SELECT value FROM lock_test_counter_tx WHERE id = 1").Scan(¤tValue).Error; err != nil { - helper.T.Errorf("Failed to read counter: %v", err) - return - } + err := g2.Raw("SELECT value FROM lock_test_counter_tx WHERE id = 1").Scan(¤tValue).Error + Expect(err).NotTo(HaveOccurred(), "Failed to read counter") // Some slow work time.Sleep(20 * time.Millisecond) // Increment and save to database newValue := currentValue + 1 - if err := g2.Exec("UPDATE lock_test_counter_tx SET value = ? WHERE id = 1", newValue).Error; err != nil { - helper.T.Errorf("Failed to update counter: %v", err) - return - } + err = g2.Exec("UPDATE lock_test_counter_tx SET value = ? WHERE id = 1", newValue).Error + Expect(err).NotTo(HaveOccurred(), "Failed to update counter") } // TestLocksAndExpectedWaits validates the behavior of advisory locks: @@ -188,35 +163,30 @@ func acquireLockWithTransaction(helper *test.Helper, waiter *sync.WaitGroup) { // - Different (id, lockType) combinations should create separate locks // - Unlocking should only affect the lock matching the owner ID func TestLocksAndExpectedWaits(t *testing.T) { - helper := test.NewHelper(t) + h, _ := test.RegisterIntegration(t) // Start lock ctx := context.Background() - ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "system", db.Migrations) - if err != nil { - t.Fatalf("Failed to acquire lock: %v", err) - } + ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, h.DBFactory, "system", db.Migrations) + Expect(err).NotTo(HaveOccurred(), "Failed to acquire lock") + defer db.Unlock(ctx, lockOwnerID) // Ensure lock is released on test exit // It should have 1 lock - g2 := helper.DBFactory.New(ctx) + g2 := h.DBFactory.New(ctx) var pgLocks []struct{ Granted bool } g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) - if len(pgLocks) != 1 { - t.Errorf("Expected 1 lock, got %d", len(pgLocks)) - } + Expect(len(pgLocks)).To(Equal(1), "Expected 1 lock") // Successive locking should have no effect (nested lock with same id/type) // Pretend this runs in a nested func - ctx, lockOwnerID2, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "system", db.Migrations) - if err != nil { - t.Fatalf("Failed to acquire nested lock: %v", err) - } + ctx, lockOwnerID2, err := db.NewAdvisoryLockContext(ctx, h.DBFactory, "system", db.Migrations) + Expect(err).NotTo(HaveOccurred(), "Failed to acquire nested lock") + defer db.Unlock(ctx, lockOwnerID2) // Ensure lock is released on test exit + // It should still have 1 lock pgLocks = nil g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) - if len(pgLocks) != 1 { - t.Errorf("Expected 1 lock after nested acquire, got %d", len(pgLocks)) - } + Expect(len(pgLocks)).To(Equal(1), "Expected 1 lock after nested acquire") // Unlock should have no effect either (unlocking nested lock) // Pretend this runs in the nested func @@ -224,31 +194,25 @@ func TestLocksAndExpectedWaits(t *testing.T) { // It should still have 1 lock pgLocks = nil g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) - if len(pgLocks) != 1 { - t.Errorf("Expected 1 lock after nested unlock, got %d", len(pgLocks)) - } + Expect(len(pgLocks)).To(Equal(1), "Expected 1 lock after nested unlock") // Lock on a different (id, lockType) should work // Pretend this runs in a nested func - ctx, lockOwnerID3, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "diff_system", db.Migrations) - if err != nil { - t.Fatalf("Failed to acquire different lock: %v", err) - } + ctx, lockOwnerID3, err := db.NewAdvisoryLockContext(ctx, h.DBFactory, "diff_system", db.Migrations) + Expect(err).NotTo(HaveOccurred(), "Failed to acquire different lock") + defer db.Unlock(ctx, lockOwnerID3) // Ensure lock is released on test exit + // It should have 2 locks pgLocks = nil g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) - if len(pgLocks) != 2 { - t.Errorf("Expected 2 locks, got %d", len(pgLocks)) - } + Expect(len(pgLocks)).To(Equal(2), "Expected 2 locks") // Pretend it releases the new lock in the nested func db.Unlock(ctx, lockOwnerID3) // It should have 1 lock pgLocks = nil g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) - if len(pgLocks) != 1 { - t.Errorf("Expected 1 lock after releasing different lock, got %d", len(pgLocks)) - } + Expect(len(pgLocks)).To(Equal(1), "Expected 1 lock after releasing different lock") // Unlock the topmost lock // Pretend it returns back to the parent func @@ -256,21 +220,18 @@ func TestLocksAndExpectedWaits(t *testing.T) { // The lock should be gone pgLocks = nil g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) - if len(pgLocks) != 0 { - t.Errorf("Expected 0 locks after final unlock, got %d", len(pgLocks)) - } + Expect(len(pgLocks)).To(Equal(0), "Expected 0 locks after final unlock") } // TestConcurrentMigrations validates that the MigrateWithLock function // properly serializes concurrent migration attempts, ensuring only one // instance actually runs migrations at a time. func TestConcurrentMigrations(t *testing.T) { - helper := test.NewHelper(t) + h, _ := test.RegisterIntegration(t) // First, reset the database to a clean state - if err := helper.ResetDB(); err != nil { - t.Fatalf("Failed to reset database: %v", err) - } + err := h.ResetDB() + Expect(err).NotTo(HaveOccurred(), "Failed to reset database") total := 5 var waiter sync.WaitGroup @@ -287,7 +248,7 @@ func TestConcurrentMigrations(t *testing.T) { defer waiter.Done() ctx := context.Background() - err := db.MigrateWithLock(ctx, helper.DBFactory) + err := db.MigrateWithLock(ctx, h.DBFactory) mu.Lock() defer mu.Unlock() @@ -303,28 +264,23 @@ func TestConcurrentMigrations(t *testing.T) { waiter.Wait() // All migrations should succeed (they're idempotent) - if len(errors) > 0 { - t.Errorf("Expected no errors, but got %d: %v", len(errors), errors) - } + Expect(errors).To(BeEmpty(), "Expected no errors during concurrent migrations") // All goroutines should complete successfully - if successCount != total { - t.Errorf("Expected %d successful migrations, got %d", total, successCount) - } + Expect(successCount).To(Equal(total), "All migrations should succeed") } // TestAdvisoryLockBlocking validates that a second goroutine trying to acquire // the same lock will block until the first goroutine releases it. func TestAdvisoryLockBlocking(t *testing.T) { - helper := test.NewHelper(t) + h, _ := test.RegisterIntegration(t) ctx := context.Background() // First goroutine acquires the lock - ctx1, lockOwnerID1, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "blocking-test", db.Migrations) - if err != nil { - t.Fatalf("Failed to acquire first lock: %v", err) - } + ctx1, lockOwnerID1, err := db.NewAdvisoryLockContext(ctx, h.DBFactory, "blocking-test", db.Migrations) + Expect(err).NotTo(HaveOccurred(), "Failed to acquire first lock") + defer db.Unlock(ctx1, lockOwnerID1) // Ensure lock is released on test exit // Track when the second goroutine acquires the lock acquired := make(chan bool, 1) @@ -332,11 +288,8 @@ func TestAdvisoryLockBlocking(t *testing.T) { // Second goroutine tries to acquire the same lock go func() { - ctx2, lockOwnerID2, err := db.NewAdvisoryLockContext(context.Background(), helper.DBFactory, "blocking-test", db.Migrations) - if err != nil { - t.Errorf("Failed to acquire second lock: %v", err) - return - } + ctx2, lockOwnerID2, err := db.NewAdvisoryLockContext(context.Background(), h.DBFactory, "blocking-test", db.Migrations) + Expect(err).NotTo(HaveOccurred(), "Failed to acquire second lock") defer db.Unlock(ctx2, lockOwnerID2) acquired <- true @@ -346,14 +299,13 @@ func TestAdvisoryLockBlocking(t *testing.T) { // Wait for the second goroutine to be actively waiting on the lock // by polling pg_locks for a non-granted advisory lock. // This is more reliable than sleep, especially in slow CI environments. - g2 := helper.DBFactory.New(ctx) + g2 := h.DBFactory.New(ctx) waitingForLock := false for i := 0; i < 50; i++ { // Poll for up to 5 seconds (50 * 100ms) var waitingLocks []struct{ Granted bool } - if err := g2.Raw("SELECT granted FROM pg_locks WHERE locktype = 'advisory' AND granted = false").Scan(&waitingLocks).Error; err != nil { - t.Errorf("Failed to query pg_locks: %v", err) - break - } + err := g2.Raw("SELECT granted FROM pg_locks WHERE locktype = 'advisory' AND granted = false").Scan(&waitingLocks).Error + Expect(err).NotTo(HaveOccurred(), "Failed to query pg_locks") + if len(waitingLocks) > 0 { waitingForLock = true break @@ -361,9 +313,7 @@ func TestAdvisoryLockBlocking(t *testing.T) { time.Sleep(100 * time.Millisecond) } - if !waitingForLock { - t.Fatal("Second goroutine did not reach the lock waiting state within timeout") - } + Expect(waitingForLock).To(BeTrue(), "Second goroutine should be waiting for lock") // The second goroutine should still be blocked select { @@ -390,15 +340,13 @@ func TestAdvisoryLockBlocking(t *testing.T) { // terminates a waiting advisory lock acquisition. The context is passed through // connection.New(ctx) and affects the blocking pg_advisory_xact_lock SQL call. func TestAdvisoryLockContextCancellation(t *testing.T) { - helper := test.NewHelper(t) + h, _ := test.RegisterIntegration(t) ctx := context.Background() // First goroutine acquires the lock - ctx1, lockOwnerID1, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "cancel-test", db.Migrations) - if err != nil { - t.Fatalf("Failed to acquire first lock: %v", err) - } + ctx1, lockOwnerID1, err := db.NewAdvisoryLockContext(ctx, h.DBFactory, "cancel-test", db.Migrations) + Expect(err).NotTo(HaveOccurred(), "Failed to acquire first lock") defer db.Unlock(ctx1, lockOwnerID1) // Track when the second goroutine gets cancelled @@ -409,25 +357,30 @@ func TestAdvisoryLockContextCancellation(t *testing.T) { // Second goroutine tries to acquire the same lock with cancellable context go func() { - _, _, err := db.NewAdvisoryLockContext(ctx2, helper.DBFactory, "cancel-test", db.Migrations) + _, _, err := db.NewAdvisoryLockContext(ctx2, h.DBFactory, "cancel-test", db.Migrations) if err != nil { - // Expected: context cancellation causes "canceling statement due to user request" - t.Logf("Second goroutine got error (expected): %v", err) - gotCancelError <- true + // Check if this is a cancellation-type error + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || + strings.Contains(err.Error(), "canceling statement due to user request") { + // Expected: context cancellation causes proper cancellation error + gotCancelError <- true + return + } + // Unexpected error - fail the test + t.Errorf("Unexpected error from lock acquisition: %v", err) return } t.Error("Second goroutine acquired lock despite context cancellation (unexpected)") }() // Wait for the second goroutine to be actively waiting on the lock - g2 := helper.DBFactory.New(ctx) + g2 := h.DBFactory.New(ctx) waitingForLock := false for i := 0; i < 50; i++ { var waitingLocks []struct{ Granted bool } - if err := g2.Raw("SELECT granted FROM pg_locks WHERE locktype = 'advisory' AND granted = false").Scan(&waitingLocks).Error; err != nil { - t.Errorf("Failed to query pg_locks: %v", err) - break - } + err := g2.Raw("SELECT granted FROM pg_locks WHERE locktype = 'advisory' AND granted = false").Scan(&waitingLocks).Error + Expect(err).NotTo(HaveOccurred(), "Failed to query pg_locks") + if len(waitingLocks) > 0 { waitingForLock = true break @@ -435,9 +388,7 @@ func TestAdvisoryLockContextCancellation(t *testing.T) { time.Sleep(100 * time.Millisecond) } - if !waitingForLock { - t.Fatal("Second goroutine did not reach the lock waiting state within timeout") - } + Expect(waitingForLock).To(BeTrue(), "Second goroutine should be waiting for lock") // Cancel the context while the second goroutine is waiting cancel() @@ -446,7 +397,6 @@ func TestAdvisoryLockContextCancellation(t *testing.T) { select { case <-gotCancelError: // Expected: context cancellation terminates the lock acquisition - t.Log("Confirmed: context cancellation properly terminates waiting advisory lock") case <-time.After(2 * time.Second): t.Error("Second goroutine did not exit after context cancellation within timeout") } @@ -456,12 +406,11 @@ func TestAdvisoryLockContextCancellation(t *testing.T) { // the advisory lock, the lock is properly released via defer, allowing other waiters // to proceed. This tests the error path and cleanup behavior. func TestMigrationFailureUnderLock(t *testing.T) { - helper := test.NewHelper(t) + h, _ := test.RegisterIntegration(t) // Reset database to clean state - if err := helper.ResetDB(); err != nil { - t.Fatalf("Failed to reset database: %v", err) - } + err := h.ResetDB() + Expect(err).NotTo(HaveOccurred(), "Failed to reset database") // Track results var mu sync.Mutex @@ -480,15 +429,12 @@ func TestMigrationFailureUnderLock(t *testing.T) { defer wg.Done() ctx := context.Background() - ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "migration-fail-test", db.Migrations) - if err != nil { - t.Errorf("Failed to acquire lock: %v", err) - return - } + ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, h.DBFactory, "migration-fail-test", db.Migrations) + Expect(err).NotTo(HaveOccurred(), "Failed to acquire lock") defer db.Unlock(ctx, lockOwnerID) // Simulate migration failure - if err := failingMigration(helper.DBFactory.New(ctx)); err != nil { + if err := failingMigration(h.DBFactory.New(ctx)); err != nil { mu.Lock() failureCount++ mu.Unlock() @@ -505,11 +451,8 @@ func TestMigrationFailureUnderLock(t *testing.T) { defer wg.Done() ctx := context.Background() - ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, helper.DBFactory, "migration-fail-test", db.Migrations) - if err != nil { - t.Errorf("Failed to acquire lock after failure: %v", err) - return - } + ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, h.DBFactory, "migration-fail-test", db.Migrations) + Expect(err).NotTo(HaveOccurred(), "Failed to acquire lock after failure") defer db.Unlock(ctx, lockOwnerID) // This one succeeds @@ -521,12 +464,6 @@ func TestMigrationFailureUnderLock(t *testing.T) { wg.Wait() // Verify both completed - if failureCount != 1 { - t.Errorf("Expected 1 failure, got %d", failureCount) - } - if successCount != 1 { - t.Errorf("Expected 1 success, got %d", successCount) - } - - t.Log("Confirmed: lock properly released after migration failure, allowing subsequent operations") + Expect(failureCount).To(Equal(1), "Expected 1 failure") + Expect(successCount).To(Equal(1), "Expected 1 success") } From 601e80a33994f30325a210816a1e17850ce1556c Mon Sep 17 00:00:00 2001 From: Leonardo Dorneles <149205964+ldornele@users.noreply.github.com> Date: Wed, 11 Mar 2026 18:47:36 -0300 Subject: [PATCH 4/6] Update test/integration/advisory_locks_test.go Co-authored-by: Rafael Benevides --- test/integration/advisory_locks_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/test/integration/advisory_locks_test.go b/test/integration/advisory_locks_test.go index fcf4723..2fb4821 100644 --- a/test/integration/advisory_locks_test.go +++ b/test/integration/advisory_locks_test.go @@ -285,6 +285,7 @@ func TestAdvisoryLockBlocking(t *testing.T) { // Track when the second goroutine acquires the lock acquired := make(chan bool, 1) released := make(chan bool, 1) + defer close(released) // ensure goroutine exits even on timeout // Second goroutine tries to acquire the same lock go func() { From a9514543e536718728950000e21f532a6050ad17 Mon Sep 17 00:00:00 2001 From: ldornele Date: Thu, 12 Mar 2026 00:27:06 -0300 Subject: [PATCH 5/6] HYPERFLEET-618 - test: add DatabaseConfig test coverage and improve advisory lock tests --- pkg/config/db_test.go | 433 ++++++++++++++++++++++++ pkg/db/mocks/session_factory.go | 2 +- test/integration/advisory_locks_test.go | 161 ++++++--- 3 files changed, 551 insertions(+), 45 deletions(-) create mode 100644 pkg/config/db_test.go diff --git a/pkg/config/db_test.go b/pkg/config/db_test.go new file mode 100644 index 0000000..d6f0342 --- /dev/null +++ b/pkg/config/db_test.go @@ -0,0 +1,433 @@ +package config + +import ( + "os" + "testing" + + "github.com/spf13/pflag" +) + +const ( + testAdvisoryLockTimeout = 600 +) + +// TestNewDatabaseConfig_Defaults tests default configuration values +func TestNewDatabaseConfig_Defaults(t *testing.T) { + cfg := NewDatabaseConfig() + + tests := []struct { + name string + got interface{} + expected interface{} + }{ + {"Dialect", cfg.Dialect, "postgres"}, + {"SSLMode", cfg.SSLMode, "disable"}, + {"Debug", cfg.Debug, false}, + {"MaxOpenConnections", cfg.MaxOpenConnections, 50}, + {"AdvisoryLockTimeoutSeconds", cfg.AdvisoryLockTimeoutSeconds, 300}, + {"MaxIdleConnections", cfg.MaxIdleConnections, 10}, + {"ConnRetryAttempts", cfg.ConnRetryAttempts, 10}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.got != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, tt.got) + } + }) + } +} + +// TestDatabaseConfig_AddFlags tests CLI flag registration +func TestDatabaseConfig_AddFlags(t *testing.T) { + cfg := NewDatabaseConfig() + fs := pflag.NewFlagSet("test", pflag.ContinueOnError) + + cfg.AddFlags(fs) + + // Verify flags are registered + flags := []string{ + "db-host-file", + "db-port-file", + "db-user-file", + "db-password-file", + "db-name-file", + "db-sslmode", + "enable-db-debug", + "db-max-open-connections", + "db-advisory-lock-timeout", + } + for _, flagName := range flags { + t.Run("flag_"+flagName, func(t *testing.T) { + if fs.Lookup(flagName) == nil { + t.Errorf("expected %s flag to be registered", flagName) + } + }) + } + + // Test flag parsing for advisory lock timeout + tests := []struct { + name string + args []string + expected int + }{ + { + name: "default advisory lock timeout", + args: []string{}, + expected: 300, + }, + { + name: "custom advisory lock timeout", + args: []string{"--db-advisory-lock-timeout=600"}, + expected: 600, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := NewDatabaseConfig() + fs := pflag.NewFlagSet("test", pflag.ContinueOnError) + cfg.AddFlags(fs) + + if err := fs.Parse(tt.args); err != nil { + t.Fatalf("failed to parse flags: %v", err) + } + + if cfg.AdvisoryLockTimeoutSeconds != tt.expected { + t.Errorf("expected AdvisoryLockTimeoutSeconds %d, got %d", tt.expected, cfg.AdvisoryLockTimeoutSeconds) + } + }) + } +} + +// TestDatabaseConfig_BindEnv tests environment variable binding +func TestDatabaseConfig_BindEnv(t *testing.T) { + tests := []struct { + name string + envVars map[string]string + validate func(*testing.T, *DatabaseConfig) + }{ + { + name: "valid advisory lock timeout", + envVars: map[string]string{ + "DB_ADVISORY_LOCK_TIMEOUT": "600", + }, + validate: func(t *testing.T, cfg *DatabaseConfig) { + if cfg.AdvisoryLockTimeoutSeconds != 600 { + t.Errorf("expected AdvisoryLockTimeoutSeconds 600, got %d", cfg.AdvisoryLockTimeoutSeconds) + } + }, + }, + { + name: "valid db debug true", + envVars: map[string]string{ + "DB_DEBUG": "true", + }, + validate: func(t *testing.T, cfg *DatabaseConfig) { + if cfg.Debug != true { + t.Errorf("expected Debug true, got %t", cfg.Debug) + } + }, + }, + { + name: "valid db debug false", + envVars: map[string]string{ + "DB_DEBUG": "false", + }, + validate: func(t *testing.T, cfg *DatabaseConfig) { + if cfg.Debug != false { + t.Errorf("expected Debug false, got %t", cfg.Debug) + } + }, + }, + { + name: "zero timeout keeps default", + envVars: map[string]string{ + "DB_ADVISORY_LOCK_TIMEOUT": "0", + }, + validate: func(t *testing.T, cfg *DatabaseConfig) { + if cfg.AdvisoryLockTimeoutSeconds != 300 { + t.Errorf("expected AdvisoryLockTimeoutSeconds to keep default (300), got %d", cfg.AdvisoryLockTimeoutSeconds) + } + }, + }, + { + name: "negative timeout keeps default", + envVars: map[string]string{ + "DB_ADVISORY_LOCK_TIMEOUT": "-1", + }, + validate: func(t *testing.T, cfg *DatabaseConfig) { + if cfg.AdvisoryLockTimeoutSeconds != 300 { + t.Errorf("expected AdvisoryLockTimeoutSeconds to keep default (300), got %d", cfg.AdvisoryLockTimeoutSeconds) + } + }, + }, + { + name: "invalid timeout string keeps default", + envVars: map[string]string{ + "DB_ADVISORY_LOCK_TIMEOUT": "abc", + }, + validate: func(t *testing.T, cfg *DatabaseConfig) { + if cfg.AdvisoryLockTimeoutSeconds != 300 { + t.Errorf("expected AdvisoryLockTimeoutSeconds to keep default (300), got %d", cfg.AdvisoryLockTimeoutSeconds) + } + }, + }, + { + name: "invalid bool value keeps default", + envVars: map[string]string{ + "DB_DEBUG": "not-a-bool", + }, + validate: func(t *testing.T, cfg *DatabaseConfig) { + if cfg.Debug != false { + t.Errorf("expected Debug to keep default (false), got %t", cfg.Debug) + } + }, + }, + { + name: "empty timeout keeps default", + envVars: map[string]string{ + "DB_ADVISORY_LOCK_TIMEOUT": "", + }, + validate: func(t *testing.T, cfg *DatabaseConfig) { + if cfg.AdvisoryLockTimeoutSeconds != 300 { + t.Errorf("expected AdvisoryLockTimeoutSeconds to keep default (300), got %d", cfg.AdvisoryLockTimeoutSeconds) + } + }, + }, + { + name: "empty debug keeps default", + envVars: map[string]string{ + "DB_DEBUG": "", + }, + validate: func(t *testing.T, cfg *DatabaseConfig) { + if cfg.Debug != false { + t.Errorf("expected Debug to keep default (false), got %t", cfg.Debug) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save and restore env vars + oldEnvs := make(map[string]string) + for key := range tt.envVars { + oldEnvs[key] = os.Getenv(key) + } + defer func() { + for key, val := range oldEnvs { + if val == "" { + _ = os.Unsetenv(key) + } else { + _ = os.Setenv(key, val) + } + } + }() + + // Set env vars + for key, val := range tt.envVars { + if val != "" { + if err := os.Setenv(key, val); err != nil { + t.Fatalf("failed to set env var %s: %v", key, err) + } + } else { + _ = os.Unsetenv(key) + } + } + + cfg := NewDatabaseConfig() + cfg.BindEnv(nil) + + tt.validate(t, cfg) + }) + } +} + +// TestDatabaseConfig_FlagsOverrideEnv tests that CLI flags override environment variables +func TestDatabaseConfig_FlagsOverrideEnv(t *testing.T) { + // Save and restore env var + oldTimeout := os.Getenv("DB_ADVISORY_LOCK_TIMEOUT") + defer func() { + if oldTimeout == "" { + _ = os.Unsetenv("DB_ADVISORY_LOCK_TIMEOUT") + } else { + _ = os.Setenv("DB_ADVISORY_LOCK_TIMEOUT", oldTimeout) + } + }() + + // Set env var to "600" + if err := os.Setenv("DB_ADVISORY_LOCK_TIMEOUT", "600"); err != nil { + t.Fatalf("failed to set DB_ADVISORY_LOCK_TIMEOUT: %v", err) + } + + cfg := NewDatabaseConfig() + fs := pflag.NewFlagSet("test", pflag.ContinueOnError) + cfg.AddFlags(fs) + + // Parse flags with different value + args := []string{"--db-advisory-lock-timeout=120"} + if err := fs.Parse(args); err != nil { + t.Fatalf("failed to parse flags: %v", err) + } + + // Before BindEnv, should have flag value + if cfg.AdvisoryLockTimeoutSeconds != 120 { + t.Errorf("expected AdvisoryLockTimeoutSeconds 120 from flag, got %d", cfg.AdvisoryLockTimeoutSeconds) + } + + // After BindEnv, flag should take priority over env var + cfg.BindEnv(fs) + if cfg.AdvisoryLockTimeoutSeconds != 120 { + t.Errorf("expected AdvisoryLockTimeoutSeconds 120 (flag > env), got %d", cfg.AdvisoryLockTimeoutSeconds) + } +} + +// TestDatabaseConfig_EnvOverridesDefaults tests that env vars override defaults when no flag is set +func TestDatabaseConfig_EnvOverridesDefaults(t *testing.T) { + // Save and restore env var + oldTimeout := os.Getenv("DB_ADVISORY_LOCK_TIMEOUT") + defer func() { + if oldTimeout == "" { + _ = os.Unsetenv("DB_ADVISORY_LOCK_TIMEOUT") + } else { + _ = os.Setenv("DB_ADVISORY_LOCK_TIMEOUT", oldTimeout) + } + }() + + // Set env var + if err := os.Setenv("DB_ADVISORY_LOCK_TIMEOUT", "450"); err != nil { + t.Fatalf("failed to set DB_ADVISORY_LOCK_TIMEOUT: %v", err) + } + + cfg := NewDatabaseConfig() + fs := pflag.NewFlagSet("test", pflag.ContinueOnError) + cfg.AddFlags(fs) + + // Parse empty args (no flags set) + if err := fs.Parse([]string{}); err != nil { + t.Fatalf("failed to parse flags: %v", err) + } + + // Before BindEnv, should have default value + if cfg.AdvisoryLockTimeoutSeconds != 300 { + t.Errorf("expected AdvisoryLockTimeoutSeconds 300 (default), got %d", cfg.AdvisoryLockTimeoutSeconds) + } + + // After BindEnv, env var should override default + cfg.BindEnv(fs) + if cfg.AdvisoryLockTimeoutSeconds != 450 { + t.Errorf("expected AdvisoryLockTimeoutSeconds 450 (env > default), got %d", cfg.AdvisoryLockTimeoutSeconds) + } +} + +// TestDatabaseConfig_PriorityMixed tests priority with multiple fields and mixed sources +func TestDatabaseConfig_PriorityMixed(t *testing.T) { + // Save and restore env vars + envVars := map[string]string{ + "DB_ADVISORY_LOCK_TIMEOUT": os.Getenv("DB_ADVISORY_LOCK_TIMEOUT"), + "DB_DEBUG": os.Getenv("DB_DEBUG"), + } + defer func() { + for key, val := range envVars { + if val == "" { + _ = os.Unsetenv(key) + } else { + _ = os.Setenv(key, val) + } + } + }() + + // Set env vars for both fields + _ = os.Setenv("DB_ADVISORY_LOCK_TIMEOUT", "600") + _ = os.Setenv("DB_DEBUG", "true") + + cfg := NewDatabaseConfig() + fs := pflag.NewFlagSet("test", pflag.ContinueOnError) + cfg.AddFlags(fs) + + // Only set flag for advisory lock timeout + if err := fs.Parse([]string{"--db-advisory-lock-timeout=240"}); err != nil { + t.Fatalf("failed to parse flags: %v", err) + } + + cfg.BindEnv(fs) + + // advisory lock timeout: flag wins over env + if cfg.AdvisoryLockTimeoutSeconds != 240 { + t.Errorf("expected AdvisoryLockTimeoutSeconds 240 (flag > env), got %d", cfg.AdvisoryLockTimeoutSeconds) + } + // db debug: env wins over default + if cfg.Debug != true { + t.Errorf("expected Debug true (env > default), got %t", cfg.Debug) + } +} + +// TestDatabaseConfig_InvalidEnvHandling documents that invalid env values are silently ignored +func TestDatabaseConfig_InvalidEnvHandling(t *testing.T) { + tests := []struct { + name string + envVar string + envValue string + description string + }{ + { + name: "zero timeout", + envVar: "DB_ADVISORY_LOCK_TIMEOUT", + envValue: "0", + description: "Zero timeout is rejected by validation (timeout > 0), keeps default", + }, + { + name: "negative timeout", + envVar: "DB_ADVISORY_LOCK_TIMEOUT", + envValue: "-1", + description: "Negative timeout is rejected by validation (timeout > 0), keeps default", + }, + { + name: "non-numeric timeout", + envVar: "DB_ADVISORY_LOCK_TIMEOUT", + envValue: "abc", + description: "Non-numeric value fails strconv.Atoi, keeps default", + }, + { + name: "invalid bool", + envVar: "DB_DEBUG", + envValue: "not-a-bool", + description: "Invalid bool fails strconv.ParseBool, keeps default", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save and restore env var + oldVal := os.Getenv(tt.envVar) + defer func() { + if oldVal == "" { + _ = os.Unsetenv(tt.envVar) + } else { + _ = os.Setenv(tt.envVar, oldVal) + } + }() + + if err := os.Setenv(tt.envVar, tt.envValue); err != nil { + t.Fatalf("failed to set %s: %v", tt.envVar, err) + } + + cfg := NewDatabaseConfig() + cfg.BindEnv(nil) + + // Document the behavior: invalid values are silently ignored + if tt.envVar == "DB_ADVISORY_LOCK_TIMEOUT" { + if cfg.AdvisoryLockTimeoutSeconds != 300 { + t.Errorf("expected default AdvisoryLockTimeoutSeconds (300) after invalid env, got %d", cfg.AdvisoryLockTimeoutSeconds) + } + t.Logf("INFO: %s - invalid value silently ignored, kept default 300", tt.description) + } else if tt.envVar == "DB_DEBUG" { + if cfg.Debug != false { + t.Errorf("expected default Debug (false) after invalid env, got %t", cfg.Debug) + } + t.Logf("INFO: %s - invalid value silently ignored, kept default false", tt.description) + } + }) + } +} diff --git a/pkg/db/mocks/session_factory.go b/pkg/db/mocks/session_factory.go index 7bfddb3..0657f7c 100755 --- a/pkg/db/mocks/session_factory.go +++ b/pkg/db/mocks/session_factory.go @@ -79,5 +79,5 @@ func (m *MockSessionFactory) NewListener(ctx context.Context, channel string, ca } func (m *MockSessionFactory) GetAdvisoryLockTimeout() int { - return 300 // 5 minutes default + return config.NewDatabaseConfig().AdvisoryLockTimeoutSeconds } diff --git a/test/integration/advisory_locks_test.go b/test/integration/advisory_locks_test.go index 2fb4821..0a24efd 100644 --- a/test/integration/advisory_locks_test.go +++ b/test/integration/advisory_locks_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "math/rand" "strings" "sync" "testing" @@ -93,37 +92,60 @@ func TestAdvisoryLocksWithTransactions(t *testing.T) { Expect(err).NotTo(HaveOccurred(), "Failed to initialize counter") defer g2.Exec("DROP TABLE IF EXISTS lock_test_counter_tx") - total := 10 - var waiter sync.WaitGroup - waiter.Add(total) - - for i := 0; i < total; i++ { - go acquireLockWithTransaction(h, &waiter) + // Test all three transaction ordering scenarios deterministically + testCases := []struct { + name string + txBeforeLock bool + txAfterLock bool + }{ + { + name: "tx_before_lock", + txBeforeLock: true, + txAfterLock: false, + }, + { + name: "tx_after_lock", + txBeforeLock: false, + txAfterLock: true, + }, + { + name: "no_tx", + txBeforeLock: false, + txAfterLock: false, + }, } - waiter.Wait() + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Run multiple goroutines for each scenario to test concurrency + goroutines := 3 + var waiter sync.WaitGroup + waiter.Add(goroutines) - // All goroutines should have incremented the counter by 1, resulting in 10 + for i := 0; i < goroutines; i++ { + go acquireLockWithTransaction(h, &waiter, tc.txBeforeLock, tc.txAfterLock) + } + + waiter.Wait() + }) + } + + // All test cases combined should have incremented the counter by 9 (3 scenarios × 3 goroutines) + expectedTotal := 9 var finalValue int err = g2.Raw("SELECT value FROM lock_test_counter_tx WHERE id = 1").Scan(&finalValue).Error Expect(err).NotTo(HaveOccurred(), "Failed to read final counter value") - Expect(finalValue).To(Equal(total), "Counter should equal total") + Expect(finalValue).To(Equal(expectedTotal), "Counter should equal total") } -func acquireLockWithTransaction(h *test.Helper, waiter *sync.WaitGroup) { +func acquireLockWithTransaction(h *test.Helper, waiter *sync.WaitGroup, txBeforeLock bool, txAfterLock bool) { defer waiter.Done() ctx := context.Background() - // Lock and Tx can be stored within the same context. They should be independent of each other. - // It doesn't matter if a Tx coexists or not, nor does it matter if it occurs before or after the lock - r := rand.Intn(3) // no Tx if r == 2 - txBeforeLock := r == 0 - txAfterLock := r == 1 - var dberr error - // Randomly add Tx before lock to demonstrate it works + // Add Tx before lock if requested if txBeforeLock { ctx, dberr = db.NewContext(ctx, h.DBFactory) Expect(dberr).NotTo(HaveOccurred(), "Failed to create transaction context") @@ -135,7 +157,7 @@ func acquireLockWithTransaction(h *test.Helper, waiter *sync.WaitGroup) { Expect(dberr).NotTo(HaveOccurred(), "Failed to acquire lock") defer db.Unlock(ctx, lockOwnerID) - // Randomly add Tx after lock to demonstrate it works + // Add Tx after lock if requested if txAfterLock { ctx, dberr = db.NewContext(ctx, h.DBFactory) Expect(dberr).NotTo(HaveOccurred(), "Failed to create transaction context") @@ -174,7 +196,7 @@ func TestLocksAndExpectedWaits(t *testing.T) { // It should have 1 lock g2 := h.DBFactory.New(ctx) var pgLocks []struct{ Granted bool } - g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) + err = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error Expect(len(pgLocks)).To(Equal(1), "Expected 1 lock") // Successive locking should have no effect (nested lock with same id/type) @@ -185,7 +207,7 @@ func TestLocksAndExpectedWaits(t *testing.T) { // It should still have 1 lock pgLocks = nil - g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) + err = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error Expect(len(pgLocks)).To(Equal(1), "Expected 1 lock after nested acquire") // Unlock should have no effect either (unlocking nested lock) @@ -193,7 +215,7 @@ func TestLocksAndExpectedWaits(t *testing.T) { db.Unlock(ctx, lockOwnerID2) // It should still have 1 lock pgLocks = nil - g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) + err = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error Expect(len(pgLocks)).To(Equal(1), "Expected 1 lock after nested unlock") // Lock on a different (id, lockType) should work @@ -204,14 +226,14 @@ func TestLocksAndExpectedWaits(t *testing.T) { // It should have 2 locks pgLocks = nil - g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) + err = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error Expect(len(pgLocks)).To(Equal(2), "Expected 2 locks") // Pretend it releases the new lock in the nested func db.Unlock(ctx, lockOwnerID3) // It should have 1 lock pgLocks = nil - g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) + err = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error Expect(len(pgLocks)).To(Equal(1), "Expected 1 lock after releasing different lock") // Unlock the topmost lock @@ -219,7 +241,7 @@ func TestLocksAndExpectedWaits(t *testing.T) { db.Unlock(ctx, lockOwnerID) // The lock should be gone pgLocks = nil - g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks) + err = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error Expect(len(pgLocks)).To(Equal(0), "Expected 0 locks after final unlock") } @@ -356,8 +378,13 @@ func TestAdvisoryLockContextCancellation(t *testing.T) { // Create a cancellable context for the second goroutine ctx2, cancel := context.WithCancel(context.Background()) + // Use WaitGroup to ensure goroutine exits before test cleanup + var wg sync.WaitGroup + wg.Add(1) + // Second goroutine tries to acquire the same lock with cancellable context go func() { + defer wg.Done() _, _, err := db.NewAdvisoryLockContext(ctx2, h.DBFactory, "cancel-test", db.Migrations) if err != nil { // Check if this is a cancellation-type error @@ -401,11 +428,33 @@ func TestAdvisoryLockContextCancellation(t *testing.T) { case <-time.After(2 * time.Second): t.Error("Second goroutine did not exit after context cancellation within timeout") } + + // Ensure goroutine exits before test cleanup + wg.Wait() +} + +// migrateWithLockAndCustomMigration mimics db.MigrateWithLock but accepts a custom migration function +// This allows testing the lock acquisition/release pattern with controlled success/failure +func migrateWithLockAndCustomMigration(ctx context.Context, factory db.SessionFactory, migrationFunc func(*gorm.DB) error) error { + // Acquire advisory lock for migrations (same pattern as production MigrateWithLock) + ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, factory, db.MigrationsLockID, db.Migrations) + if err != nil { + return err + } + defer db.Unlock(ctx, lockOwnerID) + + // Run custom migration with the locked context + g2 := factory.New(ctx) + if err := migrationFunc(g2); err != nil { + return err + } + + return nil } // TestMigrationFailureUnderLock validates that when a migration fails while holding // the advisory lock, the lock is properly released via defer, allowing other waiters -// to proceed. This tests the error path and cleanup behavior. +// to proceed. This tests the error path and cleanup behavior of the MigrateWithLock pattern. func TestMigrationFailureUnderLock(t *testing.T) { h, _ := test.RegisterIntegration(t) @@ -413,58 +462,82 @@ func TestMigrationFailureUnderLock(t *testing.T) { err := h.ResetDB() Expect(err).NotTo(HaveOccurred(), "Failed to reset database") + // Channels to coordinate goroutines + firstLockAcquired := make(chan bool, 1) + firstMigrationFailed := make(chan bool, 1) + secondCanProceed := make(chan bool, 1) + // Track results var mu sync.Mutex successCount := 0 failureCount := 0 var wg sync.WaitGroup - // Create a failing migration function + // Create a failing migration function that signals when it acquires lock and fails failingMigration := func(g2 *gorm.DB) error { + firstLockAcquired <- true + // Wait a bit to ensure second goroutine tries to acquire + time.Sleep(50 * time.Millisecond) return fmt.Errorf("simulated migration failure") } - // First goroutine: acquire lock and fail migration + // Create a successful migration function + successfulMigration := func(g2 *gorm.DB) error { + return nil + } + + // First goroutine: acquire lock and fail migration using production code path wg.Add(1) go func() { defer wg.Done() ctx := context.Background() - ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, h.DBFactory, "migration-fail-test", db.Migrations) - Expect(err).NotTo(HaveOccurred(), "Failed to acquire lock") - defer db.Unlock(ctx, lockOwnerID) + err := migrateWithLockAndCustomMigration(ctx, h.DBFactory, failingMigration) - // Simulate migration failure - if err := failingMigration(h.DBFactory.New(ctx)); err != nil { - mu.Lock() + mu.Lock() + if err != nil { failureCount++ - mu.Unlock() } + mu.Unlock() + + firstMigrationFailed <- true // Lock should be released via defer even though migration failed }() - // Give first goroutine time to acquire lock and fail - time.Sleep(100 * time.Millisecond) + // Wait for first goroutine to acquire lock + <-firstLockAcquired - // Second goroutine: should be able to acquire lock after first fails + // Second goroutine: should block until first releases lock, then succeed wg.Add(1) go func() { defer wg.Done() ctx := context.Background() - ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, h.DBFactory, "migration-fail-test", db.Migrations) - Expect(err).NotTo(HaveOccurred(), "Failed to acquire lock after failure") - defer db.Unlock(ctx, lockOwnerID) + err := migrateWithLockAndCustomMigration(ctx, h.DBFactory, successfulMigration) - // This one succeeds mu.Lock() - successCount++ + if err == nil { + successCount++ + } mu.Unlock() + + secondCanProceed <- true }() + // Wait for first migration to fail and release lock + <-firstMigrationFailed + + // Wait for second migration to complete + select { + case <-secondCanProceed: + // Expected: second goroutine acquired lock after first released it + case <-time.After(3 * time.Second): + t.Error("Second goroutine did not acquire lock after first failed") + } + wg.Wait() - // Verify both completed + // Verify both completed as expected Expect(failureCount).To(Equal(1), "Expected 1 failure") Expect(successCount).To(Equal(1), "Expected 1 success") } From 2d44ef618adefd7c96e65668a28d88745cbf3ae5 Mon Sep 17 00:00:00 2001 From: ldornele Date: Thu, 12 Mar 2026 17:53:44 -0300 Subject: [PATCH 6/6] HYPERFLEET-618 - Fix linter issues: break long lines, use tagged switch, remove unused parameters, mark intentionally ignored errors with underscore --- pkg/config/db_test.go | 13 +++++---- pkg/db/context.go | 8 +++--- test/integration/advisory_locks_test.go | 35 +++++++++++++++---------- 3 files changed, 31 insertions(+), 25 deletions(-) diff --git a/pkg/config/db_test.go b/pkg/config/db_test.go index d6f0342..097b1d7 100644 --- a/pkg/config/db_test.go +++ b/pkg/config/db_test.go @@ -7,10 +7,6 @@ import ( "github.com/spf13/pflag" ) -const ( - testAdvisoryLockTimeout = 600 -) - // TestNewDatabaseConfig_Defaults tests default configuration values func TestNewDatabaseConfig_Defaults(t *testing.T) { cfg := NewDatabaseConfig() @@ -417,12 +413,15 @@ func TestDatabaseConfig_InvalidEnvHandling(t *testing.T) { cfg.BindEnv(nil) // Document the behavior: invalid values are silently ignored - if tt.envVar == "DB_ADVISORY_LOCK_TIMEOUT" { + switch tt.envVar { + case "DB_ADVISORY_LOCK_TIMEOUT": if cfg.AdvisoryLockTimeoutSeconds != 300 { - t.Errorf("expected default AdvisoryLockTimeoutSeconds (300) after invalid env, got %d", cfg.AdvisoryLockTimeoutSeconds) + t.Errorf( + "expected default AdvisoryLockTimeoutSeconds (300) after invalid env, got %d", + cfg.AdvisoryLockTimeoutSeconds) } t.Logf("INFO: %s - invalid value silently ignored, kept default 300", tt.description) - } else if tt.envVar == "DB_DEBUG" { + case "DB_DEBUG": if cfg.Debug != false { t.Errorf("expected default Debug (false) after invalid env, got %t", cfg.Debug) } diff --git a/pkg/db/context.go b/pkg/db/context.go index 8094a9f..c262bbf 100755 --- a/pkg/db/context.go +++ b/pkg/db/context.go @@ -143,14 +143,14 @@ func Unlock(ctx context.Context, callerUUID string) { } if err := lock.unlock(ctx); err != nil { - logger.With(ctx, logger.FieldLockID, lockID, logger.FieldLockType, lockType).WithError(err).Error("Could not unlock lock") + logger.With(ctx, logger.FieldLockID, lockID, logger.FieldLockType, lockType). + WithError(err).Error("Could not unlock lock") continue } logger.With(ctx, logger.FieldLockID, lockID, logger.FieldLockType, lockType).Info("Unlocked lock") delete(locks, k) - } else { - // the resolving UUID belongs to a service call that did *not* initiate the lock. - // it is ignored. } + // Note: if ownerUUID doesn't match callerUUID, the lock belongs to a different + // service call and is intentionally not unlocked here } } diff --git a/test/integration/advisory_locks_test.go b/test/integration/advisory_locks_test.go index 0a24efd..bf85538 100644 --- a/test/integration/advisory_locks_test.go +++ b/test/integration/advisory_locks_test.go @@ -196,7 +196,7 @@ func TestLocksAndExpectedWaits(t *testing.T) { // It should have 1 lock g2 := h.DBFactory.New(ctx) var pgLocks []struct{ Granted bool } - err = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error + _ = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error Expect(len(pgLocks)).To(Equal(1), "Expected 1 lock") // Successive locking should have no effect (nested lock with same id/type) @@ -207,7 +207,7 @@ func TestLocksAndExpectedWaits(t *testing.T) { // It should still have 1 lock pgLocks = nil - err = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error + _ = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error Expect(len(pgLocks)).To(Equal(1), "Expected 1 lock after nested acquire") // Unlock should have no effect either (unlocking nested lock) @@ -215,7 +215,7 @@ func TestLocksAndExpectedWaits(t *testing.T) { db.Unlock(ctx, lockOwnerID2) // It should still have 1 lock pgLocks = nil - err = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error + _ = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error Expect(len(pgLocks)).To(Equal(1), "Expected 1 lock after nested unlock") // Lock on a different (id, lockType) should work @@ -226,14 +226,14 @@ func TestLocksAndExpectedWaits(t *testing.T) { // It should have 2 locks pgLocks = nil - err = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error + _ = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error Expect(len(pgLocks)).To(Equal(2), "Expected 2 locks") // Pretend it releases the new lock in the nested func db.Unlock(ctx, lockOwnerID3) // It should have 1 lock pgLocks = nil - err = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error + _ = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error Expect(len(pgLocks)).To(Equal(1), "Expected 1 lock after releasing different lock") // Unlock the topmost lock @@ -241,7 +241,7 @@ func TestLocksAndExpectedWaits(t *testing.T) { db.Unlock(ctx, lockOwnerID) // The lock should be gone pgLocks = nil - err = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error + _ = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error Expect(len(pgLocks)).To(Equal(0), "Expected 0 locks after final unlock") } @@ -266,7 +266,7 @@ func TestConcurrentMigrations(t *testing.T) { // Simulate multiple pods trying to run migrations concurrently for i := 0; i < total; i++ { - go func(id int) { + go func() { defer waiter.Done() ctx := context.Background() @@ -280,7 +280,7 @@ func TestConcurrentMigrations(t *testing.T) { } else { successCount++ } - }(i) + }() } waiter.Wait() @@ -311,7 +311,8 @@ func TestAdvisoryLockBlocking(t *testing.T) { // Second goroutine tries to acquire the same lock go func() { - ctx2, lockOwnerID2, err := db.NewAdvisoryLockContext(context.Background(), h.DBFactory, "blocking-test", db.Migrations) + ctx2, lockOwnerID2, err := db.NewAdvisoryLockContext( + context.Background(), h.DBFactory, "blocking-test", db.Migrations) Expect(err).NotTo(HaveOccurred(), "Failed to acquire second lock") defer db.Unlock(ctx2, lockOwnerID2) @@ -326,7 +327,8 @@ func TestAdvisoryLockBlocking(t *testing.T) { waitingForLock := false for i := 0; i < 50; i++ { // Poll for up to 5 seconds (50 * 100ms) var waitingLocks []struct{ Granted bool } - err := g2.Raw("SELECT granted FROM pg_locks WHERE locktype = 'advisory' AND granted = false").Scan(&waitingLocks).Error + query := "SELECT granted FROM pg_locks WHERE locktype = 'advisory' AND granted = false" + err := g2.Raw(query).Scan(&waitingLocks).Error Expect(err).NotTo(HaveOccurred(), "Failed to query pg_locks") if len(waitingLocks) > 0 { @@ -406,7 +408,8 @@ func TestAdvisoryLockContextCancellation(t *testing.T) { waitingForLock := false for i := 0; i < 50; i++ { var waitingLocks []struct{ Granted bool } - err := g2.Raw("SELECT granted FROM pg_locks WHERE locktype = 'advisory' AND granted = false").Scan(&waitingLocks).Error + query := "SELECT granted FROM pg_locks WHERE locktype = 'advisory' AND granted = false" + err := g2.Raw(query).Scan(&waitingLocks).Error Expect(err).NotTo(HaveOccurred(), "Failed to query pg_locks") if len(waitingLocks) > 0 { @@ -435,7 +438,11 @@ func TestAdvisoryLockContextCancellation(t *testing.T) { // migrateWithLockAndCustomMigration mimics db.MigrateWithLock but accepts a custom migration function // This allows testing the lock acquisition/release pattern with controlled success/failure -func migrateWithLockAndCustomMigration(ctx context.Context, factory db.SessionFactory, migrationFunc func(*gorm.DB) error) error { +func migrateWithLockAndCustomMigration( + ctx context.Context, + factory db.SessionFactory, + migrationFunc func(*gorm.DB) error, +) error { // Acquire advisory lock for migrations (same pattern as production MigrateWithLock) ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, factory, db.MigrationsLockID, db.Migrations) if err != nil { @@ -474,7 +481,7 @@ func TestMigrationFailureUnderLock(t *testing.T) { var wg sync.WaitGroup // Create a failing migration function that signals when it acquires lock and fails - failingMigration := func(g2 *gorm.DB) error { + failingMigration := func(_ *gorm.DB) error { firstLockAcquired <- true // Wait a bit to ensure second goroutine tries to acquire time.Sleep(50 * time.Millisecond) @@ -482,7 +489,7 @@ func TestMigrationFailureUnderLock(t *testing.T) { } // Create a successful migration function - successfulMigration := func(g2 *gorm.DB) error { + successfulMigration := func(_ *gorm.DB) error { return nil }