diff --git a/cmd/aggregator/main.go b/cmd/aggregator/main.go index eec18da..aad3c4d 100644 --- a/cmd/aggregator/main.go +++ b/cmd/aggregator/main.go @@ -12,14 +12,17 @@ import ( "github.com/unicitynetwork/bft-go-base/types" "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/events" "github.com/unicitynetwork/aggregator-go/internal/gateway" "github.com/unicitynetwork/aggregator-go/internal/ha" "github.com/unicitynetwork/aggregator-go/internal/ha/state" "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/round" "github.com/unicitynetwork/aggregator-go/internal/service" + "github.com/unicitynetwork/aggregator-go/internal/smt" "github.com/unicitynetwork/aggregator-go/internal/storage" "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" + "github.com/unicitynetwork/aggregator-go/pkg/api" ) // gracefulExit flushes async logger and exits with the given code @@ -140,6 +143,8 @@ func main() { // Create the shared state tracker for block sync height stateTracker := state.NewSyncStateTracker() + eventBus := events.NewEventBus(log) + // Load last committed unicity certificate (can be nil for genesis) var luc *types.UnicityCertificate lastBlock, err := storageInstance.BlockStorage().GetLatest(ctx) @@ -154,8 +159,23 @@ func main() { } } + // Create SMT instance based on sharding mode + var smtInstance *smt.SparseMerkleTree + switch cfg.Sharding.Mode { + case config.ShardingModeStandalone: + smtInstance = smt.NewSparseMerkleTree(api.SHA256, 16+256) + case config.ShardingModeChild: + smtInstance = smt.NewChildSparseMerkleTree(api.SHA256, 16+256, cfg.Sharding.Child.ShardID) + case config.ShardingModeParent: + smtInstance = smt.NewParentSparseMerkleTree(api.SHA256, cfg.Sharding.ShardIDLength) + default: + log.WithComponent("main").Error("Unsupported sharding mode", "mode", cfg.Sharding.Mode) + gracefulExit(asyncLogger, 1) + } + threadSafeSmt := smt.NewThreadSafeSMT(smtInstance) + // Create round manager based on sharding mode - roundManager, err := round.NewManager(ctx, cfg, log, commitmentQueue, storageInstance, stateTracker, luc) + roundManager, err := round.NewManager(ctx, cfg, log, commitmentQueue, storageInstance, stateTracker, luc, eventBus, threadSafeSmt) if err != nil { log.WithComponent("main").Error("Failed to create round manager", "error", err.Error()) gracefulExit(asyncLogger, 1) @@ -167,23 +187,30 @@ func main() { gracefulExit(asyncLogger, 1) } - // Initialize leader selector and HA Manager if enabled + // Initialize leader selector and block syncer if enabled var ls leaderSelector - var haManager *ha.HAManager + var bs *ha.BlockSyncer if cfg.HA.Enabled { log.WithComponent("main").Info("High availability mode enabled") - ls = ha.NewLeaderElection(log, cfg.HA, storageInstance.LeadershipStorage()) + ls = ha.NewLeaderElection(log, cfg.HA, storageInstance.LeadershipStorage(), eventBus) ls.Start(ctx) // Disable block syncing for parent aggregator mode // Parent mode uses state-based SMT (current shard roots) rather than history-based (commitment leaves) - disableBlockSync := cfg.Sharding.Mode == config.ShardingModeParent - if disableBlockSync { + if cfg.Sharding.Mode == config.ShardingModeParent { log.WithComponent("main").Info("Block syncing disabled for parent aggregator mode - SMT will be reconstructed on leadership transition") + } else { + log.WithComponent("main").Info("Starting block syncer") + bs = ha.NewBlockSyncer(log, ls, storageInstance, threadSafeSmt, 
cfg.Sharding.Child.ShardID, cfg.Processing.RoundDuration, stateTracker) + bs.Start(ctx) } - haManager = ha.NewHAManager(log, roundManager, ls, storageInstance, roundManager.GetSMT(), cfg.Sharding.Child.ShardID, stateTracker, cfg.Processing.RoundDuration, disableBlockSync) - haManager.Start(ctx) + // In HA mode, listen for leadership changes to activate/deactivate the round manager + go func() { + if err := startLeaderChangedEventListener(ctx, log, cfg, roundManager, bs, eventBus); err != nil { + log.WithComponent("ha-listener").Error("Fatal error on leader changed event listener", "error", err.Error()) + } + }() } else { log.WithComponent("main").Info("High availability mode is disabled, running as standalone leader") // In non-HA mode, activate the round manager directly @@ -227,9 +254,9 @@ func main() { log.WithComponent("main").Error("Failed to stop server gracefully", "error", err.Error()) } - // Stop HA Manager if it was started - if haManager != nil { - haManager.Stop() + // Stop block syncer if it was started + if bs != nil { + bs.Stop() } // Stop leader selector if it was started @@ -263,6 +290,46 @@ func main() { } } +func startLeaderChangedEventListener(ctx context.Context, log *logger.Logger, cfg *config.Config, roundManager round.Manager, bs *ha.BlockSyncer, eventBus *events.EventBus) error { + log.WithComponent("ha-listener").Info("Subscribing to TopicLeaderChanged") + leaderChangedCh := eventBus.Subscribe(events.TopicLeaderChanged) + defer func() { + if err := eventBus.Unsubscribe(events.TopicLeaderChanged, leaderChangedCh); err != nil { + log.WithComponent("ha-listener").Error("Failed to unsubscribe from TopicLeaderChanged", "error", err) + } + }() + + for { + select { + case <-ctx.Done(): + return nil + case e := <-leaderChangedCh: + evt := e.(*events.LeaderChangedEvent) + log.WithComponent("ha-listener").Info("Received LeaderChangedEvent", "isLeader", evt.IsLeader) + if evt.IsLeader { + // In child and standalone mode, we must sync SMT state before starting to produce blocks + // In parent mode, the Activate call handles SMT reconstruction, no block sync needed. 
+ if cfg.Sharding.Mode != config.ShardingModeParent { + log.WithComponent("ha-listener").Info("Becoming leader, syncing to latest block...") + if err := bs.SyncToLatestBlock(ctx); err != nil { + log.WithComponent("ha-listener").Error("failed to sync to latest block on leadership change", "error", err) + continue + } else { + log.WithComponent("ha-listener").Info("Sync complete.") + } + } + if err := roundManager.Activate(ctx); err != nil { + log.WithComponent("ha-listener").Error("Failed to activate round manager", "error", err) + } + } else { + if err := roundManager.Deactivate(ctx); err != nil { + log.WithComponent("ha-listener").Error("Failed to deactivate round manager", "error", err) + } + } + } + } +} + type leaderSelector interface { IsLeader(ctx context.Context) (bool, error) Start(ctx context.Context) diff --git a/docker-compose.yml b/docker-compose.yml index 50458e6..4bbfec3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -182,7 +182,7 @@ services: BFT_ENABLED: "true" BFT_KEY_CONF_FILE: "/app/bft-config/aggregator/keys.json" BFT_SHARD_CONF_FILE: "/app/bft-config/shard-conf-7_0.json" - BFT_TRUST_BASE_FILE: "/app/bft-config/trust-base.json" + BFT_TRUST_BASE_FILES: "/app/bft-config/trust-base.json" # BFT_BOOTSTRAP_ADDRESSES will be set dynamically by the entrypoint script entrypoint: ["/bin/sh", "-c"] diff --git a/ha-compose.yml b/ha-compose.yml index 016d1d3..352b88b 100644 --- a/ha-compose.yml +++ b/ha-compose.yml @@ -28,7 +28,7 @@ services: ubft trust-base sign --home /genesis/root --trust-base /genesis/trust-base.json fi echo "Starting root node..." && - ubft root-node run --home /genesis/root --address "/ip4/$(hostname -i)/tcp/8000" --trust-base /genesis/trust-base.json --rpc-server-address "$(hostname -i):8002" && + ubft root-node run --home /genesis/root --address "/ip4/$(hostname -i)/tcp/8000" --trust-base /genesis/trust-base.json --rpc-server-address "$(hostname -i):8002" --log-level debug && ls -l /genesis/root echo "Root node started successfully." 
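The startLeaderChangedEventListener added to cmd/aggregator/main.go above follows a plain subscribe/handle/unsubscribe pattern on the new EventBus. The snippet below is a minimal, self-contained sketch of that pattern for reference, in case other components need to react to leadership transitions; watchLeadership, its log wording, and the small main harness are illustrative only, and the code assumes it is compiled inside this repository, since the events and logger packages are internal.

// Sketch only: subscribe to TopicLeaderChanged, handle events, unsubscribe on exit.
// watchLeadership is illustrative and mirrors startLeaderChangedEventListener above.
package main

import (
	"context"
	"time"

	"github.com/unicitynetwork/aggregator-go/internal/events"
	"github.com/unicitynetwork/aggregator-go/internal/logger"
)

func watchLeadership(ctx context.Context, log *logger.Logger, bus *events.EventBus) {
	ch := bus.Subscribe(events.TopicLeaderChanged)
	defer func() {
		if err := bus.Unsubscribe(events.TopicLeaderChanged, ch); err != nil {
			log.Error("failed to unsubscribe from TopicLeaderChanged", "error", err)
		}
	}()
	for {
		select {
		case <-ctx.Done():
			return
		case e := <-ch:
			evt := e.(*events.LeaderChangedEvent)
			log.Info("leadership changed", "isLeader", evt.IsLeader)
		}
	}
}

func main() {
	log, err := logger.New("info", "text", "stdout", false)
	if err != nil {
		panic(err)
	}
	bus := events.NewEventBus(log)

	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
	defer cancel()
	go watchLeadership(ctx, log, bus)

	// Give the subscriber a moment to register; Publish drops the event when no
	// subscriber is registered or a subscriber's one-slot buffer is full.
	time.Sleep(10 * time.Millisecond)
	bus.Publish(events.TopicLeaderChanged, &events.LeaderChangedEvent{IsLeader: true})
	<-ctx.Done()
}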
@@ -100,7 +100,7 @@ services: mongo1: image: mongo:7.0 - container_name: mongo-1 + container_name: mongo1 command: ["--replSet", "rs0", "--bind_ip_all", "--noauth"] user: "${USER_UID:-1001}:${USER_GID:-1001}" networks: @@ -116,7 +116,7 @@ services: mongo2: image: mongo:7.0 - container_name: mongo-2 + container_name: mongo2 command: ["--replSet", "rs0", "--bind_ip_all", "--noauth"] user: "${USER_UID:-1001}:${USER_GID:-1001}" networks: @@ -132,7 +132,7 @@ services: mongo3: image: mongo:7.0 - container_name: mongo-3 + container_name: mongo3 command: ["--replSet", "rs0", "--bind_ip_all", "--noauth"] user: "${USER_UID:-1001}:${USER_GID:-1001}" networks: @@ -225,8 +225,8 @@ services: # Database Configuration MONGODB_URI: "mongodb://mongo1:27017,mongo2:27017,mongo3:27017/aggregator?replicaSet=rs0" MONGODB_DATABASE: "aggregator" - MONGODB_CONNECT_TIMEOUT: "10s" - MONGODB_SERVER_SELECTION_TIMEOUT: "5s" + MONGODB_CONNECT_TIMEOUT: "30s" + MONGODB_SERVER_SELECTION_TIMEOUT: "30s" # Redis Configuration REDIS_HOST: "redis" @@ -261,7 +261,7 @@ services: BFT_ENABLED: "true" BFT_KEY_CONF_FILE: "/app/bft-config/aggregator/keys.json" BFT_SHARD_CONF_FILE: "/app/bft-config/shard-conf-7_0.json" - BFT_TRUST_BASE_FILE: "/app/bft-config/trust-base.json" + BFT_TRUST_BASE_FILES: "/app/bft-config/trust-base.json" # BFT_BOOTSTRAP_ADDRESSES will be set dynamically by the entrypoint script entrypoint: ["/bin/sh", "-c"] command: diff --git a/internal/bft/client.go b/internal/bft/client.go index 5d638a4..296c3f8 100644 --- a/internal/bft/client.go +++ b/internal/bft/client.go @@ -20,6 +20,7 @@ import ( "github.com/unicitynetwork/bft-go-base/types" "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/events" "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/models" "github.com/unicitynetwork/aggregator-go/pkg/api" @@ -43,7 +44,7 @@ const ( normal ) -// BFTRootChainClient handles communication with the BFT root chain via P2P network +// BFTClientImpl handles communication with the BFT root chain via P2P network type ( BFTClientImpl struct { conf *config.BFTConfig @@ -51,6 +52,7 @@ type ( partitionID types.PartitionID shardID types.ShardID logger *logger.Logger + eventBus *events.EventBus // mutex for peer, network, signer TODO: there are readers without mutex mu sync.Mutex @@ -97,7 +99,15 @@ type ( status int ) -func NewBFTClient(conf *config.BFTConfig, roundManager RoundManager, trustBaseStore TrustBaseStore, luc *types.UnicityCertificate, logger *logger.Logger) (*BFTClientImpl, error) { +func NewBFTClient( + ctx context.Context, + conf *config.BFTConfig, + roundManager RoundManager, + trustBaseStore TrustBaseStore, + luc *types.UnicityCertificate, + logger *logger.Logger, + eventBus *events.EventBus, +) (*BFTClientImpl, error) { logger.Info("Creating BFT Client") bftClient := &BFTClientImpl{ logger: logger, @@ -106,6 +116,7 @@ func NewBFTClient(conf *config.BFTConfig, roundManager RoundManager, trustBaseSt roundManager: roundManager, trustBaseStore: trustBaseStore, conf: conf, + eventBus: eventBus, } bftClient.status.Store(idle) bftClient.luc.Store(luc) @@ -136,6 +147,10 @@ func (c *BFTClientImpl) Start(ctx context.Context) error { c.mu.Lock() defer c.mu.Unlock() + if c.status.Load().(status) != idle { + c.logger.WithContext(ctx).Warn("BFT Client is not idle, skipping start") + return nil + } c.status.Store(initializing) peerConf, err := c.conf.PeerConf() @@ -167,6 +182,7 @@ func (c *BFTClientImpl) Start(ctx 
context.Context) error { msgLoopCtx, cancelFn := context.WithCancel(ctx) c.msgLoopCancelFn = cancelFn go func() { + c.logger.WithContext(ctx).Info("BFT client event loop started") if err := c.loop(msgLoopCtx); err != nil { c.logger.Error("BFT event loop thread exited with error", "error", err.Error()) } else { @@ -182,10 +198,16 @@ func (c *BFTClientImpl) Stop() { c.mu.Lock() defer c.mu.Unlock() + if c.status.Load().(status) == idle { + c.logger.Warn("BFT Client is already idle, skipping stop") + return + } + c.status.Store(idle) if c.msgLoopCancelFn != nil { c.msgLoopCancelFn() + c.msgLoopCancelFn = nil } if c.peer != nil { if err := c.peer.Close(); err != nil { diff --git a/internal/events/event_bus.go b/internal/events/event_bus.go new file mode 100644 index 0000000..3f14a5f --- /dev/null +++ b/internal/events/event_bus.go @@ -0,0 +1,78 @@ +package events + +import ( + "fmt" + "slices" + "sync" + + "github.com/unicitynetwork/aggregator-go/internal/logger" +) + +type ( + Event interface{} + + Topic string + + EventBus struct { + logger *logger.Logger + + mu sync.RWMutex + subscribers map[Topic][]chan Event + } +) + +func NewEventBus(log *logger.Logger) *EventBus { + return &EventBus{ + logger: log, + subscribers: make(map[Topic][]chan Event), + } +} + +// Subscribe creates a channel, adds it to the subscribers list, and returns it to the caller. +func (bus *EventBus) Subscribe(topic Topic) <-chan Event { + bus.mu.Lock() + defer bus.mu.Unlock() + + ch := make(chan Event, 1) + bus.subscribers[topic] = append(bus.subscribers[topic], ch) + return ch +} + +// Publish sends the event to all subscribers. +// If the subscriber is busy then the event is dropped. +func (bus *EventBus) Publish(topic Topic, event Event) { + bus.mu.RLock() + defer bus.mu.RUnlock() + + subscribers, found := bus.subscribers[topic] + if !found { + bus.logger.Warn("Event not published, no subscriber found", "topic", topic) + return + } + for _, sub := range subscribers { + select { + case sub <- event: + default: + bus.logger.Warn("Dropped event for a slow subscriber", "topic", topic, "event", event) + } + } +} + +// Unsubscribe removes the subscriber from the subscribers list, +// returns error if the provided topic does not exist or the subscriber was not found. 
+func (bus *EventBus) Unsubscribe(topic Topic, sub <-chan Event) error { + bus.mu.Lock() + defer bus.mu.Unlock() + + subs, found := bus.subscribers[topic] + if !found { + return fmt.Errorf("topic not found: %s", topic) + } + for i, s := range subs { + if s == sub { + bus.subscribers[topic] = slices.Delete(subs, i, i+1) + return nil + } + } + return fmt.Errorf("subscriber not found for topic: %s", topic) +} diff --git a/internal/events/event_bus_test.go b/internal/events/event_bus_test.go new file mode 100644 index 0000000..eca44eb --- /dev/null +++ b/internal/events/event_bus_test.go @@ -0,0 +1,114 @@ +package events + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/unicitynetwork/aggregator-go/internal/logger" +) + +func TestEventBusPubSub(t *testing.T) { + bus := NewEventBus(newLogger(t)) + + // Subscribe to the event topic + ch := bus.Subscribe(TopicTest) + + // Publish an event + event := TestEvent{} + bus.Publish(TopicTest, event) + + // Verify the event was received + select { + case e := <-ch: + require.Equal(t, event, e) + case <-time.After(100 * time.Millisecond): + require.Fail(t, "Event was not received within timeout") + } +} + +func TestEventBusPublishToMultipleSubscribers(t *testing.T) { + bus := NewEventBus(newLogger(t)) + + // Subscribe to the event + ch1 := bus.Subscribe(TopicTest) + ch2 := bus.Subscribe(TopicTest) + + // Create and publish an event + event := TestEvent{} + bus.Publish(TopicTest, event) + + // Verify both subscribers received the event + select { + case e := <-ch1: + require.Equal(t, event, e) + case <-time.After(100 * time.Millisecond): + require.Fail(t, "First subscriber did not receive event within timeout") + } + + select { + case e := <-ch2: + require.Equal(t, event, e) + case <-time.After(100 * time.Millisecond): + require.Fail(t, "Second subscriber did not receive event within timeout") + } +} + +func TestEventBusUnsubscribe(t *testing.T) { + bus := NewEventBus(newLogger(t)) + + // Subscribe to the event topic + ch1 := bus.Subscribe(TopicTest) + ch2 := bus.Subscribe(TopicTest) + ch3 := bus.Subscribe(TopicTest) + + // Unsubscribe ch2 + err := bus.Unsubscribe(TopicTest, ch2) + require.NoError(t, err) + + // Try to unsubscribe ch2 again - should return error + err = bus.Unsubscribe(TopicTest, ch2) + require.ErrorContains(t, err, "subscriber not found") + + // Try to unsubscribe ch2 on wrong topic - should return error + err = bus.Unsubscribe("invalid-topic", ch2) + require.ErrorContains(t, err, "topic not found") + + // Publish an event + event := TestEvent{} + bus.Publish(TopicTest, event) + + // Verify ch1 and ch3 receive the event + select { + case e := <-ch1: + require.Equal(t, event, e) + case <-time.After(100 * time.Millisecond): + require.Fail(t, "ch1 should have received the event") + } + + select { + case e := <-ch3: + require.Equal(t, event, e) + case <-time.After(100 * time.Millisecond): + require.Fail(t, "ch3 should have received the event") + } + + // Verify ch2 did not receive the event + select { + case <-ch2: + require.Fail(t, "ch2 should not have received the event") + default: + } +} + +func newLogger(t *testing.T) *logger.Logger { + testLogger, err := logger.New("info", "text", "stdout", false) + require.NoError(t, err) + return testLogger +} + +const TopicTest Topic = "test_event" + +type TestEvent struct { +} diff --git a/internal/events/events.go b/internal/events/events.go new file mode 100644 index 0000000..4777fcc --- /dev/null +++ b/internal/events/events.go @@ -0,0 +1,11 @@ +package 
events + +const ( + TopicLeaderChanged Topic = "leaderChanged" +) + +// LeaderChangedEvent is published when the node becomes a leader +// or when the node loses leader status and becomes a follower. +type LeaderChangedEvent struct { + IsLeader bool +} diff --git a/internal/ha/block_syncer.go b/internal/ha/block_syncer.go index 228e6fc..21d65a2 100644 --- a/internal/ha/block_syncer.go +++ b/internal/ha/block_syncer.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "math/big" + "sync" + "time" "github.com/unicitynetwork/aggregator-go/internal/ha/state" "github.com/unicitynetwork/aggregator-go/internal/logger" @@ -13,33 +15,99 @@ import ( "github.com/unicitynetwork/aggregator-go/pkg/api" ) -// blockSyncer helper struct to update the SMT with data from commited blocks. -type blockSyncer struct { - logger *logger.Logger - storage interfaces.Storage - smt *smt.ThreadSafeSMT - shardID api.ShardID - stateTracker *state.Tracker +type ( + LeaderSelector interface { + IsLeader(ctx context.Context) (bool, error) + } + + // BlockSyncer updates the node's state tree using the blocks from storage, if in follower mode. + // Needs to be started with the Start method and stopped with the Stop method. + // Should not be started in standalone mode. + BlockSyncer struct { + logger *logger.Logger + leaderSelector LeaderSelector + storage interfaces.Storage + smt *smt.ThreadSafeSMT + shardID api.ShardID + syncInterval time.Duration + stateTracker *state.Tracker + + wg sync.WaitGroup + cancel context.CancelFunc + } +) + +func NewBlockSyncer( + logger *logger.Logger, + leaderSelector LeaderSelector, + storage interfaces.Storage, + smt *smt.ThreadSafeSMT, + shardID api.ShardID, + syncInterval time.Duration, + stateTracker *state.Tracker, +) *BlockSyncer { + return &BlockSyncer{ + logger: logger, + leaderSelector: leaderSelector, + storage: storage, + smt: smt, + shardID: shardID, + syncInterval: syncInterval, + stateTracker: stateTracker, + } +} + +func (bs *BlockSyncer) Start(ctx context.Context) { + ctx, bs.cancel = context.WithCancel(ctx) + bs.wg.Go(func() { + bs.runLoop(ctx) + }) } -func newBlockSyncer(logger *logger.Logger, storage interfaces.Storage, smt *smt.ThreadSafeSMT, shardID api.ShardID, stateTracker *state.Tracker) *blockSyncer { - return &blockSyncer{ - logger: logger, - storage: storage, - smt: smt, - shardID: shardID, - stateTracker: stateTracker, +func (bs *BlockSyncer) Stop() { + if bs.cancel != nil { + bs.cancel() + bs.cancel = nil } + bs.wg.Wait() +} + +func (bs *BlockSyncer) runLoop(ctx context.Context) { + ticker := time.NewTicker(bs.syncInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := bs.onTick(ctx); err != nil { + bs.logger.WithContext(ctx).Error("error on block sync tick", "error", err.Error()) + } + } + } +} + +func (bs *BlockSyncer) onTick(ctx context.Context) error { + isLeader, err := bs.leaderSelector.IsLeader(ctx) + if err != nil { + return fmt.Errorf("failed to query leader status: %w", err) + } + if !isLeader { + if err := bs.SyncToLatestBlock(ctx); err != nil { + return fmt.Errorf("failed to sync smt to latest block: %w", err) + } + } + return nil } -func (bs *blockSyncer) syncToLatestBlock(ctx context.Context) error { +func (bs *BlockSyncer) SyncToLatestBlock(ctx context.Context) error { // fetch last synced smt block number and last stored block number currBlock := bs.stateTracker.GetLastSyncedBlock() endBlock, err := bs.getLastStoredBlockRecordNumber(ctx) if err != nil { return fmt.Errorf("failed to fetch last stored 
block number: %w", err) } - bs.logger.WithContext(ctx).Debug("block sync", "from", currBlock, "to", endBlock) for currBlock.Cmp(endBlock) < 0 { // fetch the next block record b, err := bs.storage.BlockRecordsStorage().GetNextBlock(ctx, api.NewBigInt(currBlock)) @@ -52,7 +120,7 @@ func (bs *blockSyncer) syncToLatestBlock(ctx context.Context) error { // skip empty blocks if len(b.RequestIDs) == 0 { - bs.logger.WithContext(ctx).Debug("skipping block sync (empty block)", "nextBlock", b.BlockNumber.String()) + bs.logger.WithContext(ctx).Debug("skipping block sync (empty block)", "blockNumber", b.BlockNumber.String()) currBlock = b.BlockNumber.Int bs.stateTracker.SetLastSyncedBlock(currBlock) continue @@ -71,7 +139,7 @@ func (bs *blockSyncer) syncToLatestBlock(ctx context.Context) error { return nil } -func (bs *blockSyncer) verifySMTForBlock(ctx context.Context, smtRootHash string, blockNumber *api.BigInt) error { +func (bs *BlockSyncer) verifySMTForBlock(ctx context.Context, smtRootHash string, blockNumber *api.BigInt) error { block, err := bs.storage.BlockStorage().GetByNumber(ctx, blockNumber) if err != nil { return fmt.Errorf("failed to fetch block: %w", err) @@ -87,7 +155,7 @@ func (bs *blockSyncer) verifySMTForBlock(ctx context.Context, smtRootHash string return nil } -func (bs *blockSyncer) updateSMTForBlock(ctx context.Context, blockRecord *models.BlockRecords) error { +func (bs *BlockSyncer) updateSMTForBlock(ctx context.Context, blockRecord *models.BlockRecords) error { // build leaf ids while filtering duplicate blockRecord.RequestIDs uniqueRequestIds := make(map[string]struct{}, len(blockRecord.RequestIDs)) leafIDs := make([]api.HexBytes, 0, len(blockRecord.RequestIDs)) @@ -132,7 +200,7 @@ func (bs *blockSyncer) updateSMTForBlock(ctx context.Context, blockRecord *model return nil } -func (bs *blockSyncer) getLastStoredBlockRecordNumber(ctx context.Context) (*big.Int, error) { +func (bs *BlockSyncer) getLastStoredBlockRecordNumber(ctx context.Context) (*big.Int, error) { // Use BlockStorage which filters on finalized=true // This ensures we only sync up to the latest finalized block latestNumber, err := bs.storage.BlockStorage().GetLatestNumber(ctx) diff --git a/internal/ha/block_syncer_test.go b/internal/ha/block_syncer_test.go index 80b26cc..4e14be6 100644 --- a/internal/ha/block_syncer_test.go +++ b/internal/ha/block_syncer_test.go @@ -3,6 +3,7 @@ package ha import ( "context" "math/big" + "sync/atomic" "testing" "time" @@ -18,7 +19,16 @@ import ( "github.com/unicitynetwork/aggregator-go/pkg/api" ) -func TestBlockSync(t *testing.T) { +type mockLeaderSelector struct { + isLeader atomic.Bool +} + +func (m *mockLeaderSelector) IsLeader(_ context.Context) (bool, error) { + return m.isLeader.Load(), nil +} + +func TestBlockSyncer(t *testing.T) { + ctx := t.Context() storage := testutil.SetupTestStorage(t, config.Config{ Database: config.DatabaseConfig{ Database: "test_block_sync", @@ -31,30 +41,46 @@ func TestBlockSync(t *testing.T) { }, }) - ctx := context.Background() + cfg := &config.Config{ + Processing: config.ProcessingConfig{RoundDuration: 100 * time.Millisecond}, + HA: config.HAConfig{Enabled: true}, + BFT: config.BFTConfig{Enabled: false}, + } testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - // create block syncer - smtInstance := smt.NewSparseMerkleTree(api.SHA256, 16+256) - threadSafeSMT := smt.NewThreadSafeSMT(smtInstance) + // initialize block syncer with isLeader=false + mockLeader := &mockLeaderSelector{} + smtInstance := 
smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) stateTracker := state.NewSyncStateTracker() - syncer := newBlockSyncer(testLogger, storage, threadSafeSMT, 0, stateTracker) + syncer := NewBlockSyncer(testLogger, mockLeader, storage, smtInstance, 0, cfg.Processing.RoundDuration, stateTracker) // simulate leader creating a block - rootHash := createBlock(ctx, t, storage) + rootHash := createBlock(t, storage, 1) - // trigger the sync - err = syncer.syncToLatestBlock(ctx) - require.NoError(t, err) + // start the block syncer + syncer.Start(ctx) + defer syncer.Stop() + + // wait for block syncer to start + time.Sleep(2 * cfg.Processing.RoundDuration) // SMT root hash should match persisted block root hash after block sync - require.Equal(t, rootHash.String(), threadSafeSMT.GetRootHash()) + require.Equal(t, rootHash.String(), smtInstance.GetRootHash()) require.Equal(t, big.NewInt(1), stateTracker.GetLastSyncedBlock()) + + // verify the blocks are not synced if node is leader + mockLeader.isLeader.Store(true) + createBlock(t, storage, 2) + time.Sleep(2 * cfg.Processing.RoundDuration) + require.Equal(t, rootHash.String(), smtInstance.GetRootHash()) + require.Equal(t, big.NewInt(1), stateTracker.GetLastSyncedBlock()) + } -func createBlock(ctx context.Context, t *testing.T, storage *mongodb.Storage) api.HexBytes { - blockNumber := api.NewBigInt(big.NewInt(1)) +func createBlock(t *testing.T, storage *mongodb.Storage, blockNum int64) api.HexBytes { + ctx := t.Context() + blockNumber := api.NewBigInt(big.NewInt(blockNum)) testCommitments := []*models.Commitment{ testutil.CreateTestCommitment(t, "request_1"), testutil.CreateTestCommitment(t, "request_2"), diff --git a/internal/ha/ha_manager.go b/internal/ha/ha_manager.go deleted file mode 100644 index 38598f8..0000000 --- a/internal/ha/ha_manager.go +++ /dev/null @@ -1,143 +0,0 @@ -package ha - -import ( - "context" - "fmt" - "sync" - "time" - - "github.com/unicitynetwork/aggregator-go/internal/ha/state" - "github.com/unicitynetwork/aggregator-go/internal/logger" - "github.com/unicitynetwork/aggregator-go/internal/smt" - "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" - "github.com/unicitynetwork/aggregator-go/pkg/api" -) - -type ( - // Activatable defines a service that can be started and stopped - // based on HA leadership status. - Activatable interface { - // Activate is called when the node becomes the leader. - Activate(ctx context.Context) error - - // Deactivate is called when the node loses leadership. - Deactivate(ctx context.Context) error - } - - LeaderSelector interface { - IsLeader(ctx context.Context) (bool, error) - } - - // HAManager keeps track of node's leadership status, - // calls the provided Activatable callback when leadership status changes, - // and updates the follower node's SMT state using the BlockSyncer. - // Needs to be started with the Start method and stopped with the Stop method. 
- HAManager struct { - logger *logger.Logger - leaderSelector LeaderSelector - blockSyncer *blockSyncer // Optional: nil when block syncing is disabled - activatable Activatable - syncInterval time.Duration - - wg sync.WaitGroup - cancel context.CancelFunc - } -) - -func NewHAManager(logger *logger.Logger, - activatable Activatable, - leaderSelector LeaderSelector, - storage interfaces.Storage, - smt *smt.ThreadSafeSMT, - shardID api.ShardID, - stateTracker *state.Tracker, - syncInterval time.Duration, - disableBlockSync bool, // Set true for parent mode where block syncing is not needed -) *HAManager { - var syncer *blockSyncer - if !disableBlockSync { - syncer = newBlockSyncer(logger, storage, smt, shardID, stateTracker) - } - - return &HAManager{ - logger: logger, - leaderSelector: leaderSelector, - blockSyncer: syncer, - activatable: activatable, - syncInterval: syncInterval, - } -} - -func (ham *HAManager) Start(ctx context.Context) { - ctx, ham.cancel = context.WithCancel(ctx) - ham.wg.Add(1) - go func() { - defer ham.wg.Done() - ham.runLoop(ctx) - }() -} - -func (ham *HAManager) Stop() { - if ham.cancel != nil { - ham.cancel() - ham.cancel = nil - } - ham.wg.Wait() -} - -func (ham *HAManager) runLoop(ctx context.Context) { - ticker := time.NewTicker(ham.syncInterval) - defer ticker.Stop() - - var wasLeader bool - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - ham.logger.WithContext(ctx).Debug("on block sync tick") - isLeader, err := ham.onTick(ctx, wasLeader) - if err != nil { - ham.logger.WithContext(ctx).Warn("failed to sync block", "err", err.Error()) - continue - } - ham.logger.WithContext(ctx).Debug("block sync tick finished") - wasLeader = isLeader - } - } -} - -func (ham *HAManager) onTick(ctx context.Context, wasLeader bool) (bool, error) { - isLeader, err := ham.leaderSelector.IsLeader(ctx) - if err != nil { - return wasLeader, fmt.Errorf("error on leader selection: %w", err) - } - // nothing to do if still leader - if isLeader && wasLeader { - ham.logger.WithContext(ctx).Debug("leader is already being synced") - return isLeader, nil - } - - // Only sync blocks if blockSyncer is enabled (regular aggregator mode) - if ham.blockSyncer != nil { - if err := ham.blockSyncer.syncToLatestBlock(ctx); err != nil { - // Log the error but continue, as we might still need to handle a leadership change. 
- ham.logger.Error("failed to sync smt to latest block", "error", err) - } - } else { - ham.logger.WithContext(ctx).Debug("block syncing disabled (parent mode), skipping SMT sync") - } - - if !wasLeader && isLeader { - ham.logger.Info("Transitioning to LEADER") - if err := ham.activatable.Activate(ctx); err != nil { - return isLeader, fmt.Errorf("failed onActivate transition: %w", err) - } - } else if wasLeader && !isLeader { - ham.logger.Info("Transitioning to FOLLOWER") - if err := ham.activatable.Deactivate(ctx); err != nil { - return isLeader, fmt.Errorf("failed onDeactivate transition: %w", err) - } - } - return isLeader, nil -} diff --git a/internal/ha/ha_manager_test.go b/internal/ha/ha_manager_test.go deleted file mode 100644 index 4573e50..0000000 --- a/internal/ha/ha_manager_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package ha - -import ( - "context" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "github.com/unicitynetwork/aggregator-go/internal/config" - "github.com/unicitynetwork/aggregator-go/internal/ha/state" - "github.com/unicitynetwork/aggregator-go/internal/logger" - "github.com/unicitynetwork/aggregator-go/internal/smt" - "github.com/unicitynetwork/aggregator-go/internal/testutil" - "github.com/unicitynetwork/aggregator-go/pkg/api" -) - -type mockLeaderSelector struct { - isLeader atomic.Bool -} - -func (m *mockLeaderSelector) IsLeader(ctx context.Context) (bool, error) { - return m.isLeader.Load(), nil -} - -type mockActivatable struct { - activateCalled atomic.Int32 - deactivateCalled atomic.Int32 -} - -func newMockActivatable() *mockActivatable { - return &mockActivatable{} -} - -func (m *mockActivatable) Activate(_ context.Context) error { - m.activateCalled.Add(1) - return nil -} - -func (m *mockActivatable) Deactivate(_ context.Context) error { - m.deactivateCalled.Add(1) - return nil -} - -func TestHAManager(t *testing.T) { - storage := testutil.SetupTestStorage(t, config.Config{ - Database: config.DatabaseConfig{ - Database: "test_block_sync", - ConnectTimeout: 30 * time.Second, - ServerSelectionTimeout: 5 * time.Second, - SocketTimeout: 30 * time.Second, - MaxPoolSize: 100, - MinPoolSize: 5, - MaxConnIdleTime: 5 * time.Minute, - }, - }) - - ctx := context.Background() - cfg := &config.Config{ - Processing: config.ProcessingConfig{RoundDuration: 100 * time.Millisecond}, - HA: config.HAConfig{Enabled: true}, - BFT: config.BFTConfig{Enabled: false}, - } - testLogger, err := logger.New("info", "text", "stdout", false) - require.NoError(t, err) - - // initialize HA manger with isLeader=false - mockLeader := &mockLeaderSelector{} - mockLeader.isLeader.Store(false) - callback := newMockActivatable() - smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) - stateTracker := state.NewSyncStateTracker() - disableBlockSync := false - ham := NewHAManager(testLogger, callback, mockLeader, storage, smtInstance, 0, stateTracker, cfg.Processing.RoundDuration, disableBlockSync) - - // verify Activate/Deactivate has not been called initially - require.Equal(t, int32(0), callback.activateCalled.Load(), "Activate should not be called initially") - require.Equal(t, int32(0), callback.deactivateCalled.Load(), "Deactivate should not be called initially") - - // start the HA manager - ham.Start(ctx) - defer ham.Stop() - - // wait for HA manager to start - time.Sleep(2 * cfg.Processing.RoundDuration) - require.Equal(t, int32(0), callback.activateCalled.Load(), "Activate should not be called if not leader after initial ticks") 
- require.Equal(t, int32(0), callback.deactivateCalled.Load(), "Deactivate should not be called if not leader after initial ticks") - - // set IsLeader to true and verify Activate is called - mockLeader.isLeader.Store(true) - require.Eventually(t, func() bool { - return callback.activateCalled.Load() == 1 - }, 5*cfg.Processing.RoundDuration, 100*time.Millisecond, "Activate should be called once when becoming leader") - require.Equal(t, int32(0), callback.deactivateCalled.Load(), "Deactivate count should not change if IsLeader goes false -> true") - - // set IsLeader to false and verify Deactivate is called - mockLeader.isLeader.Store(false) - require.Eventually(t, func() bool { - return callback.deactivateCalled.Load() == 1 - }, 5*cfg.Processing.RoundDuration, 100*time.Millisecond, "Deactivate should be called once when losing leadership") - require.Equal(t, int32(1), callback.activateCalled.Load(), "Activate count should not change if IsLeader goes true -> false") - - // ensure no further unexpected calls - time.Sleep(2 * cfg.Processing.RoundDuration) - require.Equal(t, int32(1), callback.activateCalled.Load(), "Activate should not be called again unexpectedly") - require.Equal(t, int32(1), callback.deactivateCalled.Load(), "Deactivate should not be called again unexpectedly") -} diff --git a/internal/ha/leader_election.go b/internal/ha/leader_election.go index 787a5fd..fd875fa 100644 --- a/internal/ha/leader_election.go +++ b/internal/ha/leader_election.go @@ -4,14 +4,18 @@ import ( "context" "fmt" "sync" + "sync/atomic" "time" "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/events" "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" ) -// LeaderElection manages the HA leader election process +// LeaderElection manages the HA leader election process. +// It polls the db for leadership lock and publishes leadership +// transition events on the provided EventBus. type LeaderElection struct { log *logger.Logger storage interfaces.LeadershipStorage @@ -19,12 +23,15 @@ type LeaderElection struct { serverID string heartbeatInterval time.Duration electionPollingInterval time.Duration + eventBus *events.EventBus wg sync.WaitGroup // election polling thread wg cancel context.CancelFunc // election polling thread cancel signal + + isLeader atomic.Bool // cached leader flag } -func NewLeaderElection(log *logger.Logger, cfg config.HAConfig, storage interfaces.LeadershipStorage) *LeaderElection { +func NewLeaderElection(log *logger.Logger, cfg config.HAConfig, storage interfaces.LeadershipStorage, eventBus *events.EventBus) *LeaderElection { return &LeaderElection{ log: log, storage: storage, @@ -32,14 +39,15 @@ func NewLeaderElection(log *logger.Logger, cfg config.HAConfig, storage interfac serverID: cfg.ServerID, heartbeatInterval: cfg.LeaderHeartbeatInterval, electionPollingInterval: cfg.LeaderElectionPollingInterval, + eventBus: eventBus, } } -func (le *LeaderElection) IsLeader(ctx context.Context) (bool, error) { - return le.storage.IsLeader(ctx, le.lockID, le.serverID) +func (le *LeaderElection) IsLeader(_ context.Context) (bool, error) { + return le.isLeader.Load(), nil } -// Start stars the election polling +// Start starts the election polling. 
func (le *LeaderElection) Start(ctx context.Context) { ctx, cancel := context.WithCancel(ctx) le.cancel = cancel @@ -105,11 +113,17 @@ func (le *LeaderElection) startElectionPolling(ctx context.Context) { continue // keep trying to acquire the lock } if acquired { - le.log.WithComponent("leader-election").Info("Acquired leadership, starting heartbeat", "serverID", le.serverID) + le.log.WithComponent("leader-election").Info("Acquired leadership, publishing LeaderChangedEvent and starting heartbeat", "serverID", le.serverID) + le.isLeader.Store(true) + le.eventBus.Publish(events.TopicLeaderChanged, &events.LeaderChangedEvent{IsLeader: true}) + if err := le.startHeartbeat(ctx); err != nil { le.log.WithComponent("leader-election").Error("Error during heartbeat attempt", "error", err) } - le.log.WithComponent("leader-election").Info("Lost leadership, returning to polling", "serverID", le.serverID) + + le.log.WithComponent("leader-election").Info("Lost leadership, publishing LeaderChangedEvent and returning to polling", "serverID", le.serverID) + le.isLeader.Store(false) + le.eventBus.Publish(events.TopicLeaderChanged, &events.LeaderChangedEvent{IsLeader: false}) } } } diff --git a/internal/ha/leader_election_test.go b/internal/ha/leader_election_test.go index b904378..4fecc11 100644 --- a/internal/ha/leader_election_test.go +++ b/internal/ha/leader_election_test.go @@ -1,14 +1,13 @@ package ha import ( - "context" "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/events" "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/testutil" ) @@ -33,33 +32,41 @@ var conf = config.Config{ } func TestLeaderElection_LockContention(t *testing.T) { + ctx := t.Context() storage := testutil.SetupTestStorage(t, conf) + leadershipStorage := storage.LeadershipStorage() log, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) + eventBus := events.NewEventBus(log) + // setup server 1 le1Config := conf.HA le1Config.ServerID = "server-1" - le1 := NewLeaderElection(log, le1Config, storage.LeadershipStorage()) - defer le1.Stop(context.Background()) + + le1 := NewLeaderElection(log, le1Config, leadershipStorage, eventBus) + defer le1.Stop(ctx) // setup server 2 le2Config := conf.HA le2Config.ServerID = "server-2" - le2 := NewLeaderElection(log, le2Config, storage.LeadershipStorage()) - defer le2.Stop(context.Background()) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + le2 := NewLeaderElection(log, le2Config, leadershipStorage, eventBus) + defer le2.Stop(ctx) // start server 1 and wait for startup + leaderChangedCh := eventBus.Subscribe(events.TopicLeaderChanged) le1.Start(ctx) - assert.Eventually(t, func() bool { - isLeader, err := le1.IsLeader(ctx) - require.NoError(t, err) - return isLeader - }, 2*time.Second, 50*time.Millisecond, "server 1 should become leader") + select { + case e := <-leaderChangedCh: + evt := e.(*events.LeaderChangedEvent) + require.True(t, evt.IsLeader) + case <-time.After(time.Second): + require.Fail(t, "LeaderChangedEvent not received") + } + isLeader, err := le1.IsLeader(ctx) + require.NoError(t, err) + require.True(t, isLeader) // start server 2 le2.Start(ctx) @@ -67,61 +74,84 @@ func TestLeaderElection_LockContention(t *testing.T) { // verify server 1 is still the leader isLeader1, err := le1.IsLeader(ctx) - assert.NoError(t, err) 
- assert.True(t, isLeader1, "server 1 should remain the leader") + require.NoError(t, err) + require.True(t, isLeader1, "server 1 should remain the leader") // verify server 2 is NOT the leader isLeader2, err := le2.IsLeader(ctx) - assert.NoError(t, err) - assert.False(t, isLeader2, "server 2 should not become leader while server 1 is active") + require.NoError(t, err) + require.False(t, isLeader2, "server 2 should not become leader while server 1 is active") } func TestLeaderElection_Failover(t *testing.T) { + ctx := t.Context() storage := testutil.SetupTestStorage(t, conf) + leadershipStorage := storage.LeadershipStorage() log, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) + eventBus := events.NewEventBus(log) + // setup server 1 with slower heartbeat than the TTL le1Config := conf.HA le1Config.ServerID = "server-1" le1Config.LeaderHeartbeatInterval = 2 * time.Second // slower heartbeat that TTL - le1 := NewLeaderElection(log, le1Config, storage.LeadershipStorage()) - defer le1.Stop(context.Background()) + + le1 := NewLeaderElection(log, le1Config, leadershipStorage, eventBus) + defer le1.Stop(ctx) // setup server 2 with normal heartbeat le2Config := conf.HA le2Config.ServerID = "server-2" - le2 := NewLeaderElection(log, le2Config, storage.LeadershipStorage()) - defer le2.Stop(context.Background()) + le2 := NewLeaderElection(log, le2Config, leadershipStorage, eventBus) + defer le2.Stop(ctx) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // start server 1 and wait for it to become leader + // start server 1 + leaderChangedCh := eventBus.Subscribe(events.TopicLeaderChanged) le1.Start(ctx) - assert.Eventually(t, func() bool { - isLeader, err := le1.IsLeader(ctx) - require.NoError(t, err) - return isLeader - }, 2*time.Second, 50*time.Millisecond, "server 1 should become leader") - // start server 2 (initially cannot get the lock) + // wait for server 1 to become leader + select { + case e := <-leaderChangedCh: + evt := e.(*events.LeaderChangedEvent) + require.True(t, evt.IsLeader) + case <-time.After(time.Second): + require.Fail(t, "LeaderChangedEvent not received for server 1 becoming leader") + } + + // verify server 1 cached flag is updated + isLeader, err := le1.IsLeader(ctx) + require.NoError(t, err) + require.True(t, isLeader) + + // start server 2 le2.Start(ctx) - time.Sleep(200 * time.Millisecond) - isLeader2, err := le2.IsLeader(ctx) - assert.NoError(t, err) - assert.False(t, isLeader2, "server 2 should not become leader while server 1 is active") // wait long enough for server 1 lock to expire and server 2 to acquire it - assert.Eventually(t, func() bool { - isLeader2, err := le2.IsLeader(ctx) - require.NoError(t, err) - return isLeader2 - }, 5*time.Second, 100*time.Millisecond, "server 2 should take over leadership after server 1 misses heartbeat") + select { + case e := <-leaderChangedCh: + evt := e.(*events.LeaderChangedEvent) + require.True(t, evt.IsLeader, "server 2 should have become leader") + case <-time.After(2 * time.Second): + require.Fail(t, "did not receive LeaderChangedEvent for server 2 taking over") + } + + isLeader2, err := le2.IsLeader(ctx) + require.NoError(t, err) + require.True(t, isLeader2) + + // wait for server 1 heartbeat to fail and lose leadership (2s in config) + select { + case e := <-leaderChangedCh: + evt := e.(*events.LeaderChangedEvent) + require.False(t, evt.IsLeader, "server 1 should have lost leadership") + case <-time.After(2 * time.Second): + require.Fail(t, "did not receive 
LeaderChangedEvent for server 1 losing leadership") + } // confirm server 1 is no longer leader isLeader1, err := le1.IsLeader(ctx) - assert.NoError(t, err) - assert.False(t, isLeader1, "server 1 should lose leadership after missing heartbeat") + require.NoError(t, err) + require.False(t, isLeader1, "server 1 should lose leadership after missing heartbeat") } diff --git a/internal/round/factory.go b/internal/round/factory.go index dbd046b..5531dff 100644 --- a/internal/round/factory.go +++ b/internal/round/factory.go @@ -7,12 +7,12 @@ import ( "github.com/unicitynetwork/bft-go-base/types" "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/events" "github.com/unicitynetwork/aggregator-go/internal/ha/state" "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/sharding" "github.com/unicitynetwork/aggregator-go/internal/smt" "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" - "github.com/unicitynetwork/aggregator-go/pkg/api" ) // Manager interface for both standalone and parent round managers @@ -26,17 +26,25 @@ type Manager interface { } // NewManager creates the appropriate round manager based on sharding mode -func NewManager(ctx context.Context, cfg *config.Config, logger *logger.Logger, commitmentQueue interfaces.CommitmentQueue, storage interfaces.Storage, stateTracker *state.Tracker, luc *types.UnicityCertificate) (Manager, error) { +func NewManager( + ctx context.Context, + cfg *config.Config, + logger *logger.Logger, + commitmentQueue interfaces.CommitmentQueue, + storage interfaces.Storage, + stateTracker *state.Tracker, + luc *types.UnicityCertificate, + eventBus *events.EventBus, + threadSafeSmt *smt.ThreadSafeSMT, +) (Manager, error) { switch cfg.Sharding.Mode { case config.ShardingModeStandalone: - smtInstance := smt.NewSparseMerkleTree(api.SHA256, 16+256) - return NewRoundManager(ctx, cfg, logger, smtInstance, commitmentQueue, storage, nil, stateTracker, luc) + return NewRoundManager(ctx, cfg, logger, commitmentQueue, storage, nil, stateTracker, luc, eventBus, threadSafeSmt) case config.ShardingModeParent: - return NewParentRoundManager(ctx, cfg, logger, storage, luc) + return NewParentRoundManager(ctx, cfg, logger, storage, luc, eventBus, threadSafeSmt) case config.ShardingModeChild: - smtInstance := smt.NewChildSparseMerkleTree(api.SHA256, 16+256, cfg.Sharding.Child.ShardID) rootAggregatorClient := sharding.NewRootAggregatorClient(cfg.Sharding.Child.ParentRpcAddr) - return NewRoundManager(ctx, cfg, logger, smtInstance, commitmentQueue, storage, rootAggregatorClient, stateTracker, luc) + return NewRoundManager(ctx, cfg, logger, commitmentQueue, storage, rootAggregatorClient, stateTracker, luc, eventBus, threadSafeSmt) default: return nil, fmt.Errorf("unsupported sharding mode: %s", cfg.Sharding.Mode) } diff --git a/internal/round/finalize_duplicate_test.go b/internal/round/finalize_duplicate_test.go index d1a849b..655ebc7 100644 --- a/internal/round/finalize_duplicate_test.go +++ b/internal/round/finalize_duplicate_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/events" "github.com/unicitynetwork/aggregator-go/internal/ha/state" "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/models" @@ -60,7 +61,10 @@ func (s *FinalizeDuplicateTestSuite) 
Test1_DuplicateRecovery() { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - rm, err := NewRoundManager(ctx, s.cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), s.storage.CommitmentQueue(), s.storage, nil, state.NewSyncStateTracker(), nil) + smtInstance := smt.NewSparseMerkleTree(api.SHA256, 16+256) + threadSafeSMT := smt.NewThreadSafeSMT(smtInstance) + rm, err := NewRoundManager(ctx, s.cfg, testLogger, s.storage.CommitmentQueue(), s.storage, nil, + state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), threadSafeSMT) require.NoError(t, err) // Generate test commitments with unique IDs @@ -141,7 +145,10 @@ func (s *FinalizeDuplicateTestSuite) Test2_NoDuplicates() { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - rm, err := NewRoundManager(ctx, s.cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), s.storage.CommitmentQueue(), s.storage, nil, state.NewSyncStateTracker(), nil) + smtInstance := smt.NewSparseMerkleTree(api.SHA256, 16+256) + threadSafeSMT := smt.NewThreadSafeSMT(smtInstance) + rm, err := NewRoundManager(ctx, s.cfg, testLogger, s.storage.CommitmentQueue(), s.storage, nil, + state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), threadSafeSMT) require.NoError(t, err) commitments := testutil.CreateTestCommitments(t, 3, "t2_req") @@ -189,7 +196,10 @@ func (s *FinalizeDuplicateTestSuite) Test3_AllDuplicates() { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - rm, err := NewRoundManager(ctx, s.cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), s.storage.CommitmentQueue(), s.storage, nil, state.NewSyncStateTracker(), nil) + smtInstance := smt.NewSparseMerkleTree(api.SHA256, 16+256) + threadSafeSMT := smt.NewThreadSafeSMT(smtInstance) + rm, err := NewRoundManager(ctx, s.cfg, testLogger, s.storage.CommitmentQueue(), s.storage, nil, + state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), threadSafeSMT) require.NoError(t, err) commitments := testutil.CreateTestCommitments(t, 3, "t3_req") @@ -258,7 +268,8 @@ func (s *FinalizeDuplicateTestSuite) Test4_DuplicateBlock() { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - rm, err := NewRoundManager(ctx, s.cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), s.storage.CommitmentQueue(), s.storage, nil, state.NewSyncStateTracker(), nil) + threadSafeSMT := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) + rm, err := NewRoundManager(ctx, s.cfg, testLogger, s.storage.CommitmentQueue(), s.storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), threadSafeSMT) require.NoError(t, err) commitments := testutil.CreateTestCommitments(t, 3, "t4_req") @@ -338,7 +349,8 @@ func (s *FinalizeDuplicateTestSuite) Test5_DuplicateBlockAlreadyFinalized() { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - rm, err := NewRoundManager(ctx, s.cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), s.storage.CommitmentQueue(), s.storage, nil, state.NewSyncStateTracker(), nil) + threadSafeSMT := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) + rm, err := NewRoundManager(ctx, s.cfg, testLogger, s.storage.CommitmentQueue(), s.storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), threadSafeSMT) require.NoError(t, err) commitments := testutil.CreateTestCommitments(t, 3, "t5_req") diff --git 
a/internal/round/parent_round_manager.go b/internal/round/parent_round_manager.go index a93d1ca..c2860bf 100644 --- a/internal/round/parent_round_manager.go +++ b/internal/round/parent_round_manager.go @@ -12,6 +12,7 @@ import ( "github.com/unicitynetwork/aggregator-go/internal/bft" "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/events" "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/models" "github.com/unicitynetwork/aggregator-go/internal/smt" @@ -46,6 +47,7 @@ type ParentRoundManager struct { storage interfaces.Storage parentSMT *smt.ThreadSafeSMT bftClient bft.BFTClient + eventBus *events.EventBus // Round management currentRound *ParentRound @@ -62,24 +64,29 @@ type ParentRoundManager struct { const parentRoundRetryDelay = 1 * time.Second // NewParentRoundManager creates a new parent round manager -func NewParentRoundManager(ctx context.Context, cfg *config.Config, logger *logger.Logger, storage interfaces.Storage, luc *types.UnicityCertificate) (*ParentRoundManager, error) { - // Initialize parent SMT in parent mode with support for mutable leaves - smtInstance := smt.NewParentSparseMerkleTree(api.SHA256, cfg.Sharding.ShardIDLength) - parentSMT := smt.NewThreadSafeSMT(smtInstance) - +func NewParentRoundManager( + ctx context.Context, + cfg *config.Config, + logger *logger.Logger, + storage interfaces.Storage, + luc *types.UnicityCertificate, + eventBus *events.EventBus, + threadSafeSmt *smt.ThreadSafeSMT, +) (*ParentRoundManager, error) { prm := &ParentRoundManager{ config: cfg, logger: logger, storage: storage, - parentSMT: parentSMT, + parentSMT: threadSafeSmt, stopChan: make(chan struct{}), roundDuration: cfg.Processing.RoundDuration, + eventBus: eventBus, } // Create BFT client (same logic as regular RoundManager) if cfg.BFT.Enabled { var err error - prm.bftClient, err = bft.NewBFTClient(&cfg.BFT, prm, storage.TrustBaseStorage(), luc, logger) + prm.bftClient, err = bft.NewBFTClient(ctx, &cfg.BFT, prm, storage.TrustBaseStorage(), luc, logger, eventBus) if err != nil { return nil, fmt.Errorf("failed to create BFT client: %w", err) } diff --git a/internal/round/parent_round_manager_test.go b/internal/round/parent_round_manager_test.go index e7b4e72..3c98e08 100644 --- a/internal/round/parent_round_manager_test.go +++ b/internal/round/parent_round_manager_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/events" "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/models" "github.com/unicitynetwork/aggregator-go/internal/smt" @@ -20,10 +21,11 @@ import ( // ParentRoundManagerTestSuite is the test suite for parent round manager type ParentRoundManagerTestSuite struct { suite.Suite - cfg *config.Config - logger *logger.Logger - storage *mongodb.Storage - cleanup func() + cfg *config.Config + logger *logger.Logger + storage *mongodb.Storage + eventBus *events.EventBus + cleanup func() } // SetupSuite runs once before all tests - creates one MongoDB container for all tests @@ -32,6 +34,8 @@ func (suite *ParentRoundManagerTestSuite) SetupSuite() { suite.logger, err = logger.New("info", "text", "stdout", false) require.NoError(suite.T(), err, "Should create logger") + suite.eventBus = events.NewEventBus(suite.logger) + suite.cfg = &config.Config{ Sharding: config.ShardingConfig{ Mode: 
config.ShardingModeParent, @@ -84,7 +88,7 @@ func (suite *ParentRoundManagerTestSuite) TestInitialization() { ctx := context.Background() // Create parent round manager (BFT stub will be created automatically when BFT.Enabled = false) - prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil) + prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil, suite.eventBus, smt.NewThreadSafeSMT(smt.NewParentSparseMerkleTree(api.SHA256, suite.cfg.Sharding.ShardIDLength))) suite.Require().NoError(err, "Should create parent round manager successfully") suite.Require().NotNil(prm, "ParentRoundManager should not be nil") @@ -103,7 +107,7 @@ func (suite *ParentRoundManagerTestSuite) TestInitialization() { func (suite *ParentRoundManagerTestSuite) TestBasicRoundLifecycle() { ctx := context.Background() - prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil) + prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil, suite.eventBus, smt.NewThreadSafeSMT(smt.NewParentSparseMerkleTree(api.SHA256, suite.cfg.Sharding.ShardIDLength))) suite.Require().NoError(err) defer prm.Stop(ctx) // Stop round manager before cleanup to avoid disconnection errors @@ -151,7 +155,7 @@ func (suite *ParentRoundManagerTestSuite) TestMultiRoundUpdates() { suite.T().Skip("TODO(SMT): enable once sparse Merkle tree supports updating existing leaves") ctx := context.Background() - prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil) + prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil, suite.eventBus, smt.NewThreadSafeSMT(smt.NewParentSparseMerkleTree(api.SHA256, suite.cfg.Sharding.ShardIDLength))) suite.Require().NoError(err) defer prm.Stop(ctx) // Stop round manager before cleanup to avoid disconnection errors @@ -216,7 +220,7 @@ func (suite *ParentRoundManagerTestSuite) TestMultiRoundUpdates() { func (suite *ParentRoundManagerTestSuite) TestMultipleShards() { ctx := context.Background() - prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil) + prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil, suite.eventBus, smt.NewThreadSafeSMT(smt.NewParentSparseMerkleTree(api.SHA256, suite.cfg.Sharding.ShardIDLength))) suite.Require().NoError(err) defer prm.Stop(ctx) @@ -266,7 +270,7 @@ func (suite *ParentRoundManagerTestSuite) TestMultipleShards() { func (suite *ParentRoundManagerTestSuite) TestEmptyRound() { ctx := context.Background() - prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil) + prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil, suite.eventBus, smt.NewThreadSafeSMT(smt.NewParentSparseMerkleTree(api.SHA256, suite.cfg.Sharding.ShardIDLength))) suite.Require().NoError(err) defer prm.Stop(ctx) @@ -301,7 +305,7 @@ func (suite *ParentRoundManagerTestSuite) TestEmptyRound() { func (suite *ParentRoundManagerTestSuite) TestDuplicateShardUpdate() { ctx := context.Background() - prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil) + prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil, suite.eventBus, smt.NewThreadSafeSMT(smt.NewParentSparseMerkleTree(api.SHA256, suite.cfg.Sharding.ShardIDLength))) suite.Require().NoError(err) defer prm.Stop(ctx) @@ -345,7 +349,7 @@ func (suite *ParentRoundManagerTestSuite) TestSameShardMultipleValues() { suite.T().Skip("TODO(SMT): enable once sparse 
Merkle tree supports updating existing leaves") ctx := context.Background() - prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil) + prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil, suite.eventBus, smt.NewThreadSafeSMT(smt.NewParentSparseMerkleTree(api.SHA256, suite.cfg.Sharding.ShardIDLength))) suite.Require().NoError(err) defer prm.Stop(ctx) @@ -401,7 +405,7 @@ func (suite *ParentRoundManagerTestSuite) TestSameShardMultipleValues() { func (suite *ParentRoundManagerTestSuite) TestBlockRootMatchesSMTRoot() { ctx := context.Background() - prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil) + prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil, suite.eventBus, smt.NewThreadSafeSMT(smt.NewParentSparseMerkleTree(api.SHA256, suite.cfg.Sharding.ShardIDLength))) suite.Require().NoError(err) defer prm.Stop(ctx) diff --git a/internal/round/round_manager.go b/internal/round/round_manager.go index a26e347..e074da0 100644 --- a/internal/round/round_manager.go +++ b/internal/round/round_manager.go @@ -13,6 +13,7 @@ import ( "github.com/unicitynetwork/aggregator-go/internal/bft" "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/events" "github.com/unicitynetwork/aggregator-go/internal/ha/state" "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/models" @@ -85,10 +86,9 @@ type Round struct { // - Start() and Stop(): These methods manage the overall lifecycle of the RoundManager instance. // Start() is called once during application initialization to set up core components // and restore state. Stop() is called once during application shutdown for graceful cleanup. -// - Activate() and Deactivate(): These methods are part of the ha.Activatable interface -// and manage the RoundManager's active participation in block creation based on -// High Availability (HA) leadership status. Activate() is called when the node -// becomes the leader, enabling active block processing. Deactivate() is called +// - Activate() and Deactivate(): these methods manage the RoundManager's active participation +// in block creation based on High Availability (HA) leadership status. Activate() is called +// when the node becomes the leader, enabling active block processing. Deactivate() is called // when the node loses leadership, putting it into a passive state. // A RoundManager can be Activated and Deactivated multiple times throughout its // overall Start-Stop lifecycle as leadership changes. 
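Note on the refactor above: the round managers no longer build their own SMT or event bus; the caller constructs both and injects them. Per the updated doc comment, Activate()/Deactivate() only toggle participation in block production, while Start()/Stop() bracket the whole lifetime. A minimal sketch of the resulting wiring and lifecycle, assuming only the signatures visible in this diff (package-internal like the tests; queue, store and rootClient stand for values the caller already has; luc is nil as in the genesis case):

```go
// Sketch only: wiring and lifecycle implied by the signatures in this diff.
// Imports match round_manager_test.go above.
func runWithLeadershipToggles(ctx context.Context, cfg *config.Config, log *logger.Logger,
	queue interfaces.CommitmentQueue, store interfaces.Storage, rootClient RootAggregatorClient) error {

	// Dependencies are now built by the caller and injected.
	eventBus := events.NewEventBus(log)
	threadSafeSmt := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256))

	rm, err := NewRoundManager(ctx, cfg, log, queue, store, rootClient,
		state.NewSyncStateTracker(), nil, eventBus, threadSafeSmt)
	if err != nil {
		return err
	}

	// Start() runs once per process lifetime ...
	if err := rm.Start(ctx); err != nil {
		return err
	}
	// ... then Activate()/Deactivate() may run any number of times as HA leadership moves.
	if err := rm.Activate(ctx); err != nil { // became leader: produce blocks
		return err
	}
	if err := rm.Deactivate(ctx); err != nil { // lost leadership: go passive
		return err
	}
	// Stop() is called once at shutdown (not shown here).
	return nil
}
```

NewParentRoundManager follows the same pattern, taking the event bus and a pre-built parent SMT instead of constructing them internally.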
@@ -101,6 +101,7 @@ type RoundManager struct { rootClient RootAggregatorClient bftClient bft.BFTClient stateTracker *state.Tracker + eventBus *events.EventBus // Round management currentRound *Round @@ -148,21 +149,23 @@ func NewRoundManager( ctx context.Context, cfg *config.Config, logger *logger.Logger, - smtInstance *smt.SparseMerkleTree, commitmentQueue interfaces.CommitmentQueue, storage interfaces.Storage, rootAggregatorClient RootAggregatorClient, stateTracker *state.Tracker, luc *types.UnicityCertificate, + eventBus *events.EventBus, + threadSafeSmt *smt.ThreadSafeSMT, ) (*RoundManager, error) { rm := &RoundManager{ config: cfg, logger: logger, commitmentQueue: commitmentQueue, storage: storage, - smt: smt.NewThreadSafeSMT(smtInstance), + smt: threadSafeSmt, rootClient: rootAggregatorClient, stateTracker: stateTracker, + eventBus: eventBus, roundDuration: cfg.Processing.RoundDuration, // Configurable round duration (default 1s) commitmentStream: make(chan *models.Commitment, 10000), // Reasonable buffer for streaming avgProcessingRate: 1.0, // Initial estimate: 1 commitment per ms @@ -174,7 +177,7 @@ func NewRoundManager( if cfg.Sharding.Mode == config.ShardingModeStandalone { if cfg.BFT.Enabled { var err error - rm.bftClient, err = bft.NewBFTClient(&cfg.BFT, rm, storage.TrustBaseStorage(), luc, logger) + rm.bftClient, err = bft.NewBFTClient(ctx, &cfg.BFT, rm, storage.TrustBaseStorage(), luc, logger, eventBus) if err != nil { return nil, fmt.Errorf("failed to create BFT client: %w", err) } @@ -816,6 +819,7 @@ func (rm *RoundManager) restoreSmtFromStorage(ctx context.Context) (*api.BigInt, } func (rm *RoundManager) Activate(ctx context.Context) error { + rm.logger.WithContext(ctx).Info("Activating round manager") if rm.config.HA.Enabled { recoveryResult, err := RecoverUnfinalizedBlock(ctx, rm.logger, rm.storage, rm.commitmentQueue) if err != nil { @@ -865,7 +869,8 @@ func (rm *RoundManager) Activate(ctx context.Context) error { return nil } -func (rm *RoundManager) Deactivate(_ context.Context) error { +func (rm *RoundManager) Deactivate(ctx context.Context) error { + rm.logger.WithContext(ctx).Info("Deactivating round manager") rm.stopCommitmentPrefetcher() if rm.bftClient != nil { rm.bftClient.Stop() diff --git a/internal/round/round_manager_test.go b/internal/round/round_manager_test.go index 08087c5..51a2f34 100644 --- a/internal/round/round_manager_test.go +++ b/internal/round/round_manager_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/events" "github.com/unicitynetwork/aggregator-go/internal/ha/state" "github.com/unicitynetwork/aggregator-go/internal/logger" testsharding "github.com/unicitynetwork/aggregator-go/internal/sharding" @@ -43,7 +44,7 @@ func TestParentShardIntegration_GoodCase(t *testing.T) { rootAggregatorClient := testsharding.NewRootAggregatorClientStub() // create round manager - rm, err := NewRoundManager(ctx, &cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, rootAggregatorClient, state.NewSyncStateTracker(), nil) + rm, err := NewRoundManager(ctx, &cfg, testLogger, storage.CommitmentQueue(), storage, rootAggregatorClient, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256))) require.NoError(t, err) // start round manager @@ -92,7 +93,8 @@ func TestParentShardIntegration_RoundProcessingError(t 
*testing.T) { rootAggregatorClient := testsharding.NewRootAggregatorClientStub() rootAggregatorClient.SetSubmissionError(errors.New("some error")) - rm, err := NewRoundManager(ctx, &cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, rootAggregatorClient, state.NewSyncStateTracker(), nil) + // create round manager + rm, err := NewRoundManager(ctx, &cfg, testLogger, storage.CommitmentQueue(), storage, rootAggregatorClient, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256))) require.NoError(t, err) require.NoError(t, rm.Start(ctx)) diff --git a/internal/round/smt_persistence_integration_test.go b/internal/round/smt_persistence_integration_test.go index 89c4073..d299ad3 100644 --- a/internal/round/smt_persistence_integration_test.go +++ b/internal/round/smt_persistence_integration_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/events" "github.com/unicitynetwork/aggregator-go/internal/ha/state" "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/models" @@ -57,7 +58,7 @@ func TestSmtPersistenceAndRestoration(t *testing.T) { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - rm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil) + rm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256))) require.NoError(t, err, "Should create RoundManager") // Test persistence @@ -77,7 +78,7 @@ func TestSmtPersistenceAndRestoration(t *testing.T) { freshHash := freshSmt.GetRootHashHex() // Create RoundManager and call Start() to trigger restoration - restoredRm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil) + restoredRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256))) require.NoError(t, err, "Should create RoundManager") err = restoredRm.Start(ctx) @@ -109,7 +110,7 @@ func TestLargeSmtRestoration(t *testing.T) { RoundDuration: time.Second, }, } - rm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil) + rm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256))) require.NoError(t, err, "Should create RoundManager") const testNodeCount = 2500 // Ensure multiple chunks (chunkSize = 1000 in round_manager.go) @@ -140,7 +141,7 @@ func TestLargeSmtRestoration(t *testing.T) { require.Equal(t, int64(testNodeCount), count, "Should have stored all nodes") // Create new RoundManager and call Start() to restore from storage (uses multiple chunks) - newRm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), 
storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil) + newRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256))) require.NoError(t, err, "Should create new RoundManager") err = newRm.Start(ctx) @@ -171,7 +172,7 @@ func TestCompleteWorkflowWithRestart(t *testing.T) { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - rm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil) + rm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256))) require.NoError(t, err, "Should create RoundManager") rm.currentRound = &Round{ @@ -223,7 +224,7 @@ func TestCompleteWorkflowWithRestart(t *testing.T) { // Simulate service restart with new round manager cfg = &config.Config{Processing: config.ProcessingConfig{RoundDuration: time.Second}} - newRm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil) + newRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256))) require.NoError(t, err, "NewRoundManager should succeed after restart") // Call Start() to trigger SMT restoration @@ -298,7 +299,7 @@ func TestSmtRestorationWithBlockVerification(t *testing.T) { cfg := &config.Config{ Processing: config.ProcessingConfig{RoundDuration: time.Second}, } - rm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil) + rm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256))) require.NoError(t, err, "Should create RoundManager") // Persist SMT nodes to storage @@ -308,7 +309,7 @@ func TestSmtRestorationWithBlockVerification(t *testing.T) { // Test 1: Successful verification (matching root hash) t.Run("SuccessfulVerification", func(t *testing.T) { - successRm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil) + successRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256))) require.NoError(t, err, "Should create RoundManager") err = successRm.Start(ctx) @@ -342,7 +343,7 @@ func TestSmtRestorationWithBlockVerification(t *testing.T) { err = storage.BlockStorage().Store(ctx, wrongBlock) require.NoError(t, err, "Should store wrong test block") - failRm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker(), nil) + failRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, nil, 
state.NewSyncStateTracker(), nil, events.NewEventBus(testLogger), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256))) require.NoError(t, err, "Should create RoundManager") // This should fail because the restored SMT root hash doesn't match the latest block diff --git a/internal/service/parent_service_test.go b/internal/service/parent_service_test.go index 37a12fb..5be8896 100644 --- a/internal/service/parent_service_test.go +++ b/internal/service/parent_service_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/events" "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/round" "github.com/unicitynetwork/aggregator-go/internal/smt" @@ -21,12 +22,13 @@ import ( // ParentServiceTestSuite is the test suite for parent aggregator service type ParentServiceTestSuite struct { suite.Suite - cfg *config.Config - logger *logger.Logger - storage *mongodb.Storage - cleanup func() - service *ParentAggregatorService - prm *round.ParentRoundManager + cfg *config.Config + logger *logger.Logger + storage *mongodb.Storage + cleanup func() + service *ParentAggregatorService + prm *round.ParentRoundManager + eventBus *events.EventBus } type staticLeaderSelector struct { @@ -47,6 +49,8 @@ func (suite *ParentServiceTestSuite) SetupSuite() { suite.logger, err = logger.New("info", "text", "stdout", false) require.NoError(suite.T(), err, "Should create logger") + suite.eventBus = events.NewEventBus(suite.logger) + suite.cfg = &config.Config{ Sharding: config.ShardingConfig{ Mode: config.ShardingModeParent, @@ -80,7 +84,8 @@ func (suite *ParentServiceTestSuite) SetupTest() { // Create parent round manager var err error - suite.prm, err = round.NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil) + parentSMT := smt.NewThreadSafeSMT(smt.NewParentSparseMerkleTree(api.SHA256, suite.cfg.Sharding.ShardIDLength)) + suite.prm, err = round.NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage, nil, suite.eventBus, parentSMT) require.NoError(suite.T(), err, "Should create parent round manager") require.NotNil(suite.T(), suite.prm, "Parent round manager should not be nil") diff --git a/internal/service/service_test.go b/internal/service/service_test.go index d166c27..f68f91c 100644 --- a/internal/service/service_test.go +++ b/internal/service/service_test.go @@ -25,6 +25,7 @@ import ( redisContainer "github.com/testcontainers/testcontainers-go/modules/redis" "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/events" "github.com/unicitynetwork/aggregator-go/internal/gateway" "github.com/unicitynetwork/aggregator-go/internal/ha/state" "github.com/unicitynetwork/aggregator-go/internal/logger" @@ -136,7 +137,7 @@ func setupMongoDBAndAggregator(t *testing.T, ctx context.Context) (string, func( // Initialize round manager rootAggregatorClient := sharding.NewRootAggregatorClientStub() - roundManager, err := round.NewRoundManager(ctx, cfg, log, smt.NewSparseMerkleTree(api.SHA256, 16+256), commitmentQueue, mongoStorage, rootAggregatorClient, state.NewSyncStateTracker(), nil) + roundManager, err := round.NewRoundManager(ctx, cfg, log, commitmentQueue, mongoStorage, rootAggregatorClient, state.NewSyncStateTracker(), nil, events.NewEventBus(log), smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256))) require.NoError(t, err) // Start the 
round manager (restores SMT) diff --git a/internal/storage/interfaces/interfaces.go b/internal/storage/interfaces/interfaces.go index fe38af5..51a6bac 100644 --- a/internal/storage/interfaces/interfaces.go +++ b/internal/storage/interfaces/interfaces.go @@ -169,8 +169,8 @@ type BlockRecordsStorage interface { // If blockNumber is nil then returns the very first block. GetNextBlock(ctx context.Context, blockNumber *api.BigInt) (*models.BlockRecords, error) - // GetLatestBlock retrieves the latest block - GetLatestBlock(ctx context.Context) (*models.BlockRecords, error) + // GetLatestBlockNumber retrieves the latest block number + GetLatestBlockNumber(ctx context.Context) (*api.BigInt, error) } // LeadershipStorage handles high availability leadership state @@ -197,8 +197,6 @@ var ErrTrustBaseAlreadyExists = errors.New("trust base already exists") type TrustBaseStorage interface { Store(ctx context.Context, trustBase types.RootTrustBase) error GetByEpoch(ctx context.Context, epoch uint64) (types.RootTrustBase, error) - GetByRound(ctx context.Context, round uint64) (types.RootTrustBase, error) - GetAll(ctx context.Context) ([]types.RootTrustBase, error) } // Storage handles persistent data storage diff --git a/internal/storage/mongodb/block_records.go b/internal/storage/mongodb/block_records.go index 821357f..4b7a4a0 100644 --- a/internal/storage/mongodb/block_records.go +++ b/internal/storage/mongodb/block_records.go @@ -127,24 +127,28 @@ func (brs *BlockRecordsStorage) GetNextBlock(ctx context.Context, blockNumber *a return blockRecord, nil } -// GetLatestBlock retrieves the latest block -func (brs *BlockRecordsStorage) GetLatestBlock(ctx context.Context) (*models.BlockRecords, error) { - opts := options.FindOne().SetSort(bson.D{{Key: "blockNumber", Value: -1}}) +// GetLatestBlockNumber retrieves the latest block number +func (brs *BlockRecordsStorage) GetLatestBlockNumber(ctx context.Context) (*api.BigInt, error) { + opts := options.FindOne(). + SetProjection(bson.M{"blockNumber": 1}). 
+ SetSort(bson.D{{Key: "blockNumber", Value: -1}}) - var result models.BlockRecordsBSON - err := brs.collection.FindOne(ctx, bson.M{}, opts).Decode(&result) - if err != nil { + var result struct { + BlockNumber primitive.Decimal128 `bson:"blockNumber"` + } + + if err := brs.collection.FindOne(ctx, bson.M{}, opts).Decode(&result); err != nil { if errors.Is(err, mongo.ErrNoDocuments) { return nil, nil } - return nil, fmt.Errorf("failed to get latest block record: %w", err) + return nil, fmt.Errorf("failed to get latest block record number: %w", err) } - blockRecord, err := result.FromBSON() + blockNumber, _, err := result.BlockNumber.BigInt() if err != nil { - return nil, fmt.Errorf("failed to convert from BSON: %w", err) + return nil, fmt.Errorf("failed to parse blockNumber: %w", err) } - return blockRecord, nil + return api.NewBigInt(blockNumber), nil } // CreateIndexes creates necessary indexes for the block records collection diff --git a/internal/storage/mongodb/block_records_test.go b/internal/storage/mongodb/block_records_test.go index 10e9ec2..17feb31 100644 --- a/internal/storage/mongodb/block_records_test.go +++ b/internal/storage/mongodb/block_records_test.go @@ -374,9 +374,9 @@ func TestBlockRecordsStorage_GetLatestBlock(t *testing.T) { } t.Run("should return nil when no block records exist", func(t *testing.T) { - num, err := storage.GetLatestBlock(ctx) - require.NoError(t, err, "GetLatestBlock should not return an error when empty") - assert.Nil(t, num, "GetLatestBlock should return nil when no records exist") + num, err := storage.GetLatestBlockNumber(ctx) + require.NoError(t, err, "GetLatestBlockNumber should not return an error when empty") + assert.Nil(t, num, "GetLatestBlockNumber should return nil when no records exist") }) t.Run("should return latest block with single record", func(t *testing.T) { @@ -387,7 +387,7 @@ func TestBlockRecordsStorage_GetLatestBlock(t *testing.T) { require.NoError(t, err) // Get latest - latestNum, err := storage.GetLatestBlock(ctx) + latestNum, err := storage.GetLatestBlockNumber(ctx) require.NoError(t, err, "GetLatestNumber should not return an error") require.NotNil(t, latestNum, "Latest number should not be nil") @@ -413,12 +413,12 @@ func TestBlockRecordsStorage_GetLatestBlock(t *testing.T) { } // Get latest - should be block number 130 - latestBlock, err := storage.GetLatestBlock(ctx) + latestBlock, err := storage.GetLatestBlockNumber(ctx) require.NoError(t, err, "GetLatestNumber should not return an error") require.NotNil(t, latestBlock, "Latest number should not be nil") expectedLatest := api.NewBigInt(big.NewInt(130)) - assert.Equal(t, 0, expectedLatest.Cmp(latestBlock.BlockNumber.Int), "Should get latest block number") + assert.Equal(t, 0, expectedLatest.Cmp(latestBlock.Int), "Should get latest block number") }) t.Run("should handle decimal128 sorting correctly for large numbers", func(t *testing.T) { @@ -447,11 +447,11 @@ func TestBlockRecordsStorage_GetLatestBlock(t *testing.T) { } // Get latest - latestBlock, err := storage.GetLatestBlock(ctx) + latestBlock, err := storage.GetLatestBlockNumber(ctx) require.NoError(t, err, "GetLatestNumber should not return an error") require.NotNil(t, latestBlock, "Latest number should not be nil") - assert.Equal(t, 0, expectedLatest.Cmp(latestBlock.BlockNumber.Int), "Should get latest block number") + assert.Equal(t, 0, expectedLatest.Cmp(latestBlock.Int), "Should get latest block number") }) } diff --git a/internal/storage/mongodb/cached_trust_base.go 
b/internal/storage/mongodb/cached_trust_base.go index 79f4dba..12b6c87 100644 --- a/internal/storage/mongodb/cached_trust_base.go +++ b/internal/storage/mongodb/cached_trust_base.go @@ -3,25 +3,23 @@ package mongodb import ( "context" "fmt" - "sort" "sync" "github.com/unicitynetwork/bft-go-base/types" - - "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" ) // CachedTrustBaseStorage is a cached decorator of TrustBaseStorage. type CachedTrustBaseStorage struct { storage *TrustBaseStorage - sortedTrustBases []types.RootTrustBase + trustBaseByEpoch map[uint64]types.RootTrustBase mu sync.RWMutex } func NewCachedTrustBaseStorage(storage *TrustBaseStorage) *CachedTrustBaseStorage { return &CachedTrustBaseStorage{ - storage: storage, + storage: storage, + trustBaseByEpoch: make(map[uint64]types.RootTrustBase), } } @@ -30,41 +28,37 @@ func (s *CachedTrustBaseStorage) Store(ctx context.Context, trustBase types.Root if err := s.storage.Store(ctx, trustBase); err != nil { return fmt.Errorf("failed to store trust base: %w", err) } - if err := s.UpdateCache(ctx); err != nil { - return fmt.Errorf("failed to reload cache: %w", err) - } + s.updateCache(trustBase) return nil } // GetByEpoch retrieves a trust base by epoch. -func (s *CachedTrustBaseStorage) GetByEpoch(_ context.Context, epoch uint64) (types.RootTrustBase, error) { - s.mu.RLock() - defer s.mu.RUnlock() - - if epoch < uint64(len(s.sortedTrustBases)) { - return s.sortedTrustBases[epoch], nil +func (s *CachedTrustBaseStorage) GetByEpoch(ctx context.Context, epoch uint64) (types.RootTrustBase, error) { + tbFromCache := s.getByEpoch(epoch) + if tbFromCache != nil { + return tbFromCache, nil } - return nil, interfaces.ErrTrustBaseNotFound -} -// GetByRound retrieves a trust base by epoch start round. -func (s *CachedTrustBaseStorage) GetByRound(_ context.Context, epochStart uint64) (types.RootTrustBase, error) { - s.mu.RLock() - defer s.mu.RUnlock() + // in HA mode another node may have updated the trust base, + // so we must check storage + tbFromStorage, err := s.storage.GetByEpoch(ctx, epoch) + if err != nil { + return nil, fmt.Errorf("failed to fetch trust base from storage: %w", err) + } + s.updateCache(tbFromStorage) - return s.getByEpochStartRound(epochStart) + return tbFromStorage, nil } -// GetAll returns all trust bases in sorted order. -func (s *CachedTrustBaseStorage) GetAll(_ context.Context) ([]types.RootTrustBase, error) { +func (s *CachedTrustBaseStorage) getByEpoch(epoch uint64) types.RootTrustBase { s.mu.RLock() defer s.mu.RUnlock() - return s.sortedTrustBases, nil + return s.trustBaseByEpoch[epoch] } -// UpdateCache updates the cache from storage. -func (s *CachedTrustBaseStorage) UpdateCache(ctx context.Context) error { +// ReloadCache reloads the cache from storage. 
+func (s *CachedTrustBaseStorage) ReloadCache(ctx context.Context) error { trustBases, err := s.storage.GetAll(ctx) if err != nil { return fmt.Errorf("failed to get all trust bases: %w", err) @@ -77,18 +71,17 @@ func (s *CachedTrustBaseStorage) reloadCache(trustBases []types.RootTrustBase) { s.mu.Lock() defer s.mu.Unlock() - // make sure trust bases are in sorted order by epoch - sort.Slice(trustBases, func(i, j int) bool { - return trustBases[i].GetEpoch() < trustBases[j].GetEpoch() - }) - s.sortedTrustBases = trustBases + newCache := make(map[uint64]types.RootTrustBase, len(trustBases)) + for _, tb := range trustBases { + newCache[tb.GetEpoch()] = tb + } + + s.trustBaseByEpoch = newCache } -func (s *CachedTrustBaseStorage) getByEpochStartRound(epochStart uint64) (types.RootTrustBase, error) { - for i := len(s.sortedTrustBases) - 1; i >= 0; i-- { - if s.sortedTrustBases[i].GetEpochStart() <= epochStart { - return s.sortedTrustBases[i], nil - } - } - return nil, interfaces.ErrTrustBaseNotFound +func (s *CachedTrustBaseStorage) updateCache(trustBase types.RootTrustBase) { + s.mu.Lock() + defer s.mu.Unlock() + + s.trustBaseByEpoch[trustBase.GetEpoch()] = trustBase } diff --git a/internal/storage/mongodb/cached_trust_base_test.go b/internal/storage/mongodb/cached_trust_base_test.go index 53f228d..08aa859 100644 --- a/internal/storage/mongodb/cached_trust_base_test.go +++ b/internal/storage/mongodb/cached_trust_base_test.go @@ -34,23 +34,16 @@ func TestCachedTrustBaseStorage(t *testing.T) { t.Run("initially empty", func(t *testing.T) { ctx := t.Context() require.NoError(t, cachedStorage.storage.collection.Drop(ctx)) - require.NoError(t, cachedStorage.UpdateCache(ctx)) + require.NoError(t, cachedStorage.ReloadCache(ctx)) _, err := cachedStorage.GetByEpoch(ctx, 0) require.ErrorIs(t, err, interfaces.ErrTrustBaseNotFound) - - _, err = cachedStorage.GetByRound(ctx, 0) - require.ErrorIs(t, err, interfaces.ErrTrustBaseNotFound) - - all, err := cachedStorage.GetAll(ctx) - require.NoError(t, err) - require.Empty(t, all) }) t.Run("store and retrieve", func(t *testing.T) { ctx := t.Context() require.NoError(t, cachedStorage.storage.collection.Drop(ctx)) - require.NoError(t, cachedStorage.UpdateCache(ctx)) + require.NoError(t, cachedStorage.ReloadCache(ctx)) // store epoch 0 require.NoError(t, cachedStorage.Store(ctx, trustBaseEpoch0)) @@ -60,16 +53,6 @@ func TestCachedTrustBaseStorage(t *testing.T) { require.NoError(t, err) require.Equal(t, trustBaseEpoch0, tb) - // GetByRound (any value) returns epoch 0 - tb, err = cachedStorage.GetByRound(ctx, 9999) - require.NoError(t, err) - require.Equal(t, trustBaseEpoch0, tb) - - // GetAll returns only epoch 0 - all, err := cachedStorage.GetAll(ctx) - require.NoError(t, err) - require.Equal(t, []types.RootTrustBase{trustBaseEpoch0}, all) - // store epoch 1 require.NoError(t, cachedStorage.Store(ctx, trustBaseEpoch1)) @@ -77,20 +60,5 @@ func TestCachedTrustBaseStorage(t *testing.T) { tb, err = cachedStorage.GetByEpoch(ctx, 1) require.NoError(t, err) require.Equal(t, trustBaseEpoch1, tb) - - // GetByRound for round >= 1000 returns epoch 1 - tb, err = cachedStorage.GetByRound(ctx, 1000) - require.NoError(t, err) - require.Equal(t, trustBaseEpoch1, tb) - - // GetByRound for round < 1000 returns epoch 0 - tb, err = cachedStorage.GetByRound(ctx, 999) - require.NoError(t, err) - require.Equal(t, trustBaseEpoch0, tb) - - // GetAll returns both epochs - all, err = cachedStorage.GetAll(ctx) - require.NoError(t, err) - require.Equal(t, []types.RootTrustBase{trustBaseEpoch0, 
trustBaseEpoch1}, all) }) } diff --git a/internal/storage/mongodb/connection.go b/internal/storage/mongodb/connection.go index ac6aedc..b75accf 100644 --- a/internal/storage/mongodb/connection.go +++ b/internal/storage/mongodb/connection.go @@ -77,7 +77,7 @@ func NewStorage(ctx context.Context, config config.Config) (*Storage, error) { storage.cachedTrustBaseStorage = NewCachedTrustBaseStorage(NewTrustBaseStorage(database)) // init trust base store cache - if err := storage.cachedTrustBaseStorage.UpdateCache(ctx); err != nil { + if err := storage.cachedTrustBaseStorage.ReloadCache(ctx); err != nil { return nil, fmt.Errorf("failed to init cached trust base storage cache: %w", err) } diff --git a/sharding-compose.yml b/sharding-compose.yml index 2f60563..54ab9a9 100644 --- a/sharding-compose.yml +++ b/sharding-compose.yml @@ -212,7 +212,7 @@ services: BFT_ENABLED: "true" BFT_KEY_CONF_FILE: "/app/bft-config/aggregator/keys.json" BFT_SHARD_CONF_FILE: "/app/bft-config/shard-conf-7_0.json" - BFT_TRUST_BASE_FILE: "/app/bft-config/trust-base.json" + BFT_TRUST_BASE_FILES: "/app/bft-config/trust-base.json" # BFT_BOOTSTRAP_ADDRESSES will be set dynamically by the entrypoint script # Redis Configuration diff --git a/sharding-ha-compose.yml b/sharding-ha-compose.yml index 93c6f24..49f9a4d 100644 --- a/sharding-ha-compose.yml +++ b/sharding-ha-compose.yml @@ -315,7 +315,7 @@ services: BFT_ENABLED: "false" BFT_KEY_CONF_FILE: "/app/bft-config/aggregator/keys.json" BFT_SHARD_CONF_FILE: "/app/bft-config/shard-conf-7_0.json" - BFT_TRUST_BASE_FILE: "/app/bft-config/trust-base.json" + BFT_TRUST_BASE_FILES: "/app/bft-config/trust-base.json" REDIS_HOST: "redis" REDIS_PORT: "6379" REDIS_PASSWORD: "" diff --git a/test/integration/sharding_e2e_test.go b/test/integration/sharding_e2e_test.go index 54c9a0b..b99752f 100644 --- a/test/integration/sharding_e2e_test.go +++ b/test/integration/sharding_e2e_test.go @@ -17,11 +17,13 @@ import ( redisContainer "github.com/testcontainers/testcontainers-go/modules/redis" "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/events" "github.com/unicitynetwork/aggregator-go/internal/gateway" "github.com/unicitynetwork/aggregator-go/internal/ha/state" "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/round" "github.com/unicitynetwork/aggregator-go/internal/service" + "github.com/unicitynetwork/aggregator-go/internal/smt" "github.com/unicitynetwork/aggregator-go/internal/storage" "github.com/unicitynetwork/aggregator-go/internal/testutil" "github.com/unicitynetwork/aggregator-go/pkg/api" @@ -156,7 +158,19 @@ func startAggregator(t *testing.T, ctx context.Context, name, port, mongoURI, re queue, stor, _ := storage.NewStorage(aggCtx, cfg, log) queue.Initialize(aggCtx) - mgr, _ := round.NewManager(aggCtx, cfg, log, queue, stor, state.NewSyncStateTracker(), nil) + eventBus := events.NewEventBus(log) + + // Create SMT instance based on sharding mode + var smtInstance *smt.SparseMerkleTree + switch cfg.Sharding.Mode { + case config.ShardingModeStandalone, config.ShardingModeChild: + smtInstance = smt.NewSparseMerkleTree(api.SHA256, 16+256) + case config.ShardingModeParent: + smtInstance = smt.NewParentSparseMerkleTree(api.SHA256, cfg.Sharding.ShardIDLength) + } + threadSafeSmt := smt.NewThreadSafeSMT(smtInstance) + + mgr, _ := round.NewManager(aggCtx, cfg, log, queue, stor, state.NewSyncStateTracker(), nil, eventBus, threadSafeSmt) mgr.Start(aggCtx) mgr.Activate(aggCtx)
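A usage note on the BlockRecordsStorage change above: GetLatestBlockNumber now returns only the number, fetched with a blockNumber projection instead of decoding the full record, and nil means "no blocks yet". A hypothetical caller sketched against the interface in this diff; nextBlockNumber and the start-at-1 rule are illustrative, not part of the change (imports: fmt, math/big, plus the interfaces and api packages used above):

```go
// Hypothetical helper on top of the new method; only GetLatestBlockNumber,
// interfaces.BlockRecordsStorage and api.BigInt come from this diff.
func nextBlockNumber(ctx context.Context, brs interfaces.BlockRecordsStorage) (*api.BigInt, error) {
	latest, err := brs.GetLatestBlockNumber(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to read latest block number: %w", err)
	}
	if latest == nil {
		// Empty storage: assume numbering starts at 1 (illustrative assumption).
		return api.NewBigInt(big.NewInt(1)), nil
	}
	return api.NewBigInt(new(big.Int).Add(latest.Int, big.NewInt(1))), nil
}
```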
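A note on the CachedTrustBaseStorage change above: the sorted slice becomes a map keyed by epoch, and GetByEpoch falls through to MongoDB on a cache miss, which matters in HA where another node may have stored a newer epoch. An illustrative test in the spirit of cached_trust_base_test.go; the otherNodeTrustBaseStorage and trustBaseEpoch2 fixtures are hypothetical, only Store and GetByEpoch come from this diff:

```go
// Illustrative only, in the style of cached_trust_base_test.go above.
func TestGetByEpochFallsThroughToStorage(t *testing.T) {
	ctx := t.Context()

	// Another HA node persists epoch 2 straight to MongoDB.
	require.NoError(t, otherNodeTrustBaseStorage.Store(ctx, trustBaseEpoch2))

	// This node's first lookup misses the in-memory map, falls through to
	// storage and caches the result ...
	tb, err := cachedStorage.GetByEpoch(ctx, 2)
	require.NoError(t, err)
	require.Equal(t, trustBaseEpoch2, tb)

	// ... so the second lookup is served from memory.
	tb, err = cachedStorage.GetByEpoch(ctx, 2)
	require.NoError(t, err)
	require.Equal(t, trustBaseEpoch2, tb)
}
```

Dropping the index-by-epoch slice also removes the implicit assumption that epochs are contiguous and start at zero, which the removed lookup relied on.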