Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 51 additions & 3 deletions universalClient/chains/chains.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ func (c *Chains) fetchAndUpdate(parent context.Context) error {

switch action {
case chainActionSkip:
// Disabled or no change - skip
continue

case chainActionAdd:
Expand All @@ -191,6 +190,11 @@ func (c *Chains) fetchAndUpdate(parent context.Context) error {
if err := c.addChain(parent, cfg); err != nil {
c.logger.Error().Err(err).Str("chain", chainID).Msg("failed to add updated chain")
}

case chainActionRemove:
if err := c.removeChain(chainID); err != nil {
c.logger.Error().Err(err).Str("chain", chainID).Msg("failed to remove disabled chain")
}
}
}

Expand Down Expand Up @@ -228,17 +232,29 @@ const (
func (c *Chains) determineChainAction(cfg *uregistrytypes.ChainConfig) chainAction {
chainID := cfg.Chain

// Check if chain exists
// Check if chain is fully disabled (both flags off)
bothDisabled := cfg.Enabled == nil ||
(!cfg.Enabled.IsInboundEnabled && !cfg.Enabled.IsOutboundEnabled)

c.chainsMu.RLock()
_, exists := c.chains[chainID]
existingConfig := c.chainConfigs[chainID]
c.chainsMu.RUnlock()

if bothDisabled {
if exists {
c.logger.Info().Str("chain", chainID).Msg("chain fully disabled (inbound+outbound off), removing")
return chainActionRemove
}
c.logger.Debug().Str("chain", chainID).Msg("chain fully disabled, skipping")
return chainActionSkip
}

if !exists {
return chainActionAdd
}

// Check if config changed
// Check if config changed (includes enabled flag changes)
if existingConfig != nil && !configsEqual(existingConfig, cfg) {
return chainActionUpdate
}
Expand Down Expand Up @@ -364,6 +380,22 @@ func (c *Chains) IsEVMChain(chainID string) bool {
return cfg != nil && cfg.VmType == uregistrytypes.VmType_EVM
}

// IsChainInboundEnabled returns whether inbound is enabled for the given chain
func (c *Chains) IsChainInboundEnabled(chainID string) bool {
c.chainsMu.RLock()
cfg := c.chainConfigs[chainID]
c.chainsMu.RUnlock()
return cfg != nil && cfg.Enabled != nil && cfg.Enabled.IsInboundEnabled
}

// IsChainOutboundEnabled returns whether outbound is enabled for the given chain
func (c *Chains) IsChainOutboundEnabled(chainID string) bool {
c.chainsMu.RLock()
cfg := c.chainConfigs[chainID]
c.chainsMu.RUnlock()
return cfg != nil && cfg.Enabled != nil && cfg.Enabled.IsOutboundEnabled
}

// GetStandardConfirmations returns the chain's standard block confirmations from registry config (BlockConfirmation.StandardInbound). Used for outbound tx completion. Returns 12 if not set.
func (c *Chains) GetStandardConfirmations(chainID string) uint64 {
c.chainsMu.RLock()
Expand Down Expand Up @@ -506,6 +538,11 @@ func configsEqual(a, b *uregistrytypes.ChainConfig) bool {
return false
}

// Compare enabled flags
if !chainEnabledEqual(a.Enabled, b.Enabled) {
return false
}

return true
}

Expand Down Expand Up @@ -539,6 +576,17 @@ func vaultMethodsEqual(a, b []*uregistrytypes.VaultMethods) bool {
return true
}

func chainEnabledEqual(a, b *uregistrytypes.ChainEnabled) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return a.IsInboundEnabled == b.IsInboundEnabled &&
a.IsOutboundEnabled == b.IsOutboundEnabled
}

func blockConfirmationEqual(a, b *uregistrytypes.BlockConfirmation) bool {
if a == nil && b == nil {
return true
Expand Down
76 changes: 74 additions & 2 deletions universalClient/chains/chains_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,21 +259,64 @@ func TestDetermineChainAction(t *testing.T) {
}
chains := NewChains(nil, nil, cfg, logger)

enabled := &uregistrytypes.ChainEnabled{IsInboundEnabled: true, IsOutboundEnabled: true}

t.Run("new chain returns add", func(t *testing.T) {
chainCfg := &uregistrytypes.ChainConfig{
Chain: "eip155:1",
VmType: uregistrytypes.VmType_EVM,
Chain: "eip155:1",
VmType: uregistrytypes.VmType_EVM,
Enabled: enabled,
}

action := chains.determineChainAction(chainCfg)
assert.Equal(t, chainActionAdd, action)
})

t.Run("both flags off returns skip for new chain", func(t *testing.T) {
chainCfg := &uregistrytypes.ChainConfig{
Chain: "eip155:99",
VmType: uregistrytypes.VmType_EVM,
Enabled: &uregistrytypes.ChainEnabled{IsInboundEnabled: false, IsOutboundEnabled: false},
}
action := chains.determineChainAction(chainCfg)
assert.Equal(t, chainActionSkip, action)
})

t.Run("both flags off returns remove for existing chain", func(t *testing.T) {
chains.chainsMu.Lock()
chains.chains["eip155:99"] = nil
chains.chainConfigs["eip155:99"] = &uregistrytypes.ChainConfig{Chain: "eip155:99", Enabled: enabled}
chains.chainsMu.Unlock()

chainCfg := &uregistrytypes.ChainConfig{
Chain: "eip155:99",
VmType: uregistrytypes.VmType_EVM,
Enabled: &uregistrytypes.ChainEnabled{IsInboundEnabled: false, IsOutboundEnabled: false},
}
action := chains.determineChainAction(chainCfg)
assert.Equal(t, chainActionRemove, action)

chains.chainsMu.Lock()
delete(chains.chains, "eip155:99")
delete(chains.chainConfigs, "eip155:99")
chains.chainsMu.Unlock()
})

t.Run("nil enabled returns skip for new chain", func(t *testing.T) {
chainCfg := &uregistrytypes.ChainConfig{
Chain: "eip155:98",
VmType: uregistrytypes.VmType_EVM,
}
action := chains.determineChainAction(chainCfg)
assert.Equal(t, chainActionSkip, action)
})

t.Run("existing chain with same config returns skip", func(t *testing.T) {
chainCfg := &uregistrytypes.ChainConfig{
Chain: "eip155:1",
VmType: uregistrytypes.VmType_EVM,
GatewayAddress: "0x123",
Enabled: enabled,
}

// Add the chain first
Expand All @@ -297,12 +340,14 @@ func TestDetermineChainAction(t *testing.T) {
Chain: "eip155:1",
VmType: uregistrytypes.VmType_EVM,
GatewayAddress: "0x123",
Enabled: enabled,
}

newCfg := &uregistrytypes.ChainConfig{
Chain: "eip155:1",
VmType: uregistrytypes.VmType_EVM,
GatewayAddress: "0x456", // Different address
Enabled: enabled,
}

// Add the chain first
Expand All @@ -320,6 +365,33 @@ func TestDetermineChainAction(t *testing.T) {
delete(chains.chainConfigs, "eip155:1")
chains.chainsMu.Unlock()
})

t.Run("enabled flag change triggers update", func(t *testing.T) {
oldCfg := &uregistrytypes.ChainConfig{
Chain: "eip155:1",
VmType: uregistrytypes.VmType_EVM,
Enabled: enabled,
}

newCfg := &uregistrytypes.ChainConfig{
Chain: "eip155:1",
VmType: uregistrytypes.VmType_EVM,
Enabled: &uregistrytypes.ChainEnabled{IsInboundEnabled: true, IsOutboundEnabled: false},
}

chains.chainsMu.Lock()
chains.chains["eip155:1"] = nil
chains.chainConfigs["eip155:1"] = oldCfg
chains.chainsMu.Unlock()

action := chains.determineChainAction(newCfg)
assert.Equal(t, chainActionUpdate, action)

chains.chainsMu.Lock()
delete(chains.chains, "eip155:1")
delete(chains.chainConfigs, "eip155:1")
chains.chainsMu.Unlock()
})
}

func TestPerSyncTimeout(t *testing.T) {
Expand Down
38 changes: 26 additions & 12 deletions universalClient/chains/common/event_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,34 @@ import (

// EventProcessor processes events from the chain's database and votes on them
type EventProcessor struct {
signer *pushsigner.Signer
chainStore *ChainStore
logger zerolog.Logger
chainID string
running bool
stopCh chan struct{}
wg sync.WaitGroup
signer *pushsigner.Signer
chainStore *ChainStore
logger zerolog.Logger
chainID string
inboundEnabled bool
outboundEnabled bool
running bool
stopCh chan struct{}
wg sync.WaitGroup
}

// NewEventProcessor creates a new event processor
func NewEventProcessor(
signer *pushsigner.Signer,
database *db.DB,
chainID string,
inboundEnabled bool,
outboundEnabled bool,
logger zerolog.Logger,
) *EventProcessor {
return &EventProcessor{
signer: signer,
chainStore: NewChainStore(database),
chainID: chainID,
logger: logger.With().Str("component", "event_processor").Str("chain", chainID).Logger(),
stopCh: make(chan struct{}),
signer: signer,
chainStore: NewChainStore(database),
chainID: chainID,
inboundEnabled: inboundEnabled,
outboundEnabled: outboundEnabled,
logger: logger.With().Str("component", "event_processor").Str("chain", chainID).Logger(),
stopCh: make(chan struct{}),
}
}

Expand Down Expand Up @@ -114,6 +120,10 @@ func (ep *EventProcessor) processConfirmedEvents(ctx context.Context) error {

for _, event := range events {
if event.Type == EventTypeInbound {
if !ep.inboundEnabled {
ep.logger.Warn().Str("event_id", event.EventID).Msg("inbound disabled, skipping inbound event processing")
continue
}
if err := ep.processInboundEvent(ctx, &event); err != nil {
ep.logger.Error().
Err(err).
Expand All @@ -122,6 +132,10 @@ func (ep *EventProcessor) processConfirmedEvents(ctx context.Context) error {
continue
}
} else if event.Type == EventTypeOutbound {
if !ep.outboundEnabled {
ep.logger.Warn().Str("event_id", event.EventID).Msg("outbound disabled, skipping outbound event processing")
continue
}
if err := ep.processOutboundEvent(ctx, &event); err != nil {
ep.logger.Error().
Err(err).
Expand Down
Loading
Loading