Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/gateway/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (s *Server) setupJSONRPCHandlers() {
// Add middleware
s.rpcServer.AddMiddleware(jsonrpc.RequestIDMiddleware())
s.rpcServer.AddMiddleware(jsonrpc.LoggingMiddleware(s.logger))
s.rpcServer.AddMiddleware(jsonrpc.TimeoutMiddleware(30 * time.Second))
s.rpcServer.AddMiddleware(jsonrpc.TimeoutMiddleware(30*time.Second, s.logger))

// Register handlers based on mode
if s.config.Sharding.Mode.IsParent() {
Expand Down
11 changes: 10 additions & 1 deletion internal/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,16 @@ func (as *AggregatorService) GetInclusionProof(ctx context.Context, req *api.Get
if err != nil {
return nil, fmt.Errorf("failed to get path for request ID %s: %w", req.RequestID, err)
}
merkleTreePath, err := as.roundManager.GetSMT().GetPath(path)

smtInstance := as.roundManager.GetSMT()
if smtInstance == nil {
return nil, fmt.Errorf("merkle tree not initialized")
}
if keyLen := smtInstance.GetKeyLength(); path.BitLen()-1 != keyLen {
return nil, fmt.Errorf("request path length %d does not match SMT key length %d", path.BitLen()-1, keyLen)
}

merkleTreePath, err := smtInstance.GetPath(path)
if err != nil {
return nil, fmt.Errorf("failed to get inclusion proof for request ID %s: %w", req.RequestID, err)
}
Expand Down
96 changes: 95 additions & 1 deletion internal/service/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/http"
"os"
"strconv"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -242,7 +243,7 @@ func validateInclusionProof(t *testing.T, proof *api.InclusionProof, requestID a
assert.NotNil(t, proof.MerkleTreePath, "Should have merkle tree path")

// Validate unicity certificate field
if proof.UnicityCertificate != nil && len(proof.UnicityCertificate) > 0 {
if len(proof.UnicityCertificate) > 0 {
assert.NotEmpty(t, proof.UnicityCertificate, "Unicity certificate should not be empty")
// Verify it's valid hex-encoded data
_, err := hex.DecodeString(string(proof.UnicityCertificate))
Expand Down Expand Up @@ -326,6 +327,68 @@ func (suite *AggregatorTestSuite) TestInclusionProofMissingRecord() {
suite.NotEmpty(inclusionProof.InclusionProof.UnicityCertificate, "UnicityCertificate should not be empty")
}

func TestGetInclusionProofShardMismatch(t *testing.T) {
shardingCfg := config.ShardingConfig{
Mode: config.ShardingModeChild,
Child: config.ChildConfig{
ShardID: 4,
},
}
tree := smt.NewChildSparseMerkleTree(api.SHA256, 16+256, shardingCfg.Child.ShardID)
service := newAggregatorServiceForTest(t, shardingCfg, tree)

invalidShardID := api.RequestID(strings.Repeat("00", 33) + "01")
_, err := service.GetInclusionProof(context.Background(), &api.GetInclusionProofRequest{RequestID: invalidShardID})
require.Error(t, err)
assert.Contains(t, err.Error(), "request ID validation failed")
}

func TestGetInclusionProofInvalidRequestFormat(t *testing.T) {
shardingCfg := config.ShardingConfig{
Mode: config.ShardingModeChild,
Child: config.ChildConfig{
ShardID: 4,
},
}
tree := smt.NewChildSparseMerkleTree(api.SHA256, 16+256, shardingCfg.Child.ShardID)
service := newAggregatorServiceForTest(t, shardingCfg, tree)

_, err := service.GetInclusionProof(context.Background(), &api.GetInclusionProofRequest{RequestID: api.RequestID("zz")})
require.Error(t, err)
assert.Contains(t, err.Error(), "request ID validation failed")
}

func TestGetInclusionProofSMTUnavailable(t *testing.T) {
shardingCfg := config.ShardingConfig{
Mode: config.ShardingModeChild,
Child: config.ChildConfig{
ShardID: 4,
},
}
service := newAggregatorServiceForTest(t, shardingCfg, nil)

validID := api.RequestID(strings.Repeat("00", 34))
_, err := service.GetInclusionProof(context.Background(), &api.GetInclusionProofRequest{RequestID: validID})
require.Error(t, err)
assert.Contains(t, err.Error(), "merkle tree not initialized")
}

func TestInclusionProofInvalidPathLength(t *testing.T) {
shardingCfg := config.ShardingConfig{
Mode: config.ShardingModeStandalone,
}
tree := smt.NewSparseMerkleTree(api.SHA256, 16+256)
service := newAggregatorServiceForTest(t, shardingCfg, tree)

validID := createTestCommitments(t, 1)[0].RequestID.String()
require.Greater(t, len(validID), 2)
badID := api.RequestID(validID[2:])

_, err := service.GetInclusionProof(context.Background(), &api.GetInclusionProofRequest{RequestID: badID})
require.Error(t, err)
assert.Contains(t, err.Error(), "path length")
}

// TestInclusionProof tests the complete inclusion proof workflow
func (suite *AggregatorTestSuite) TestInclusionProof() {

Expand Down Expand Up @@ -445,3 +508,34 @@ func createTestCommitments(t *testing.T, count int) []*api.SubmitCommitmentReque

return commitments
}

type stubRoundManager struct {
smt *smt.ThreadSafeSMT
}

func (s *stubRoundManager) Start(context.Context) error { return nil }
func (s *stubRoundManager) Stop(context.Context) error { return nil }
func (s *stubRoundManager) Activate(context.Context) error { return nil }
func (s *stubRoundManager) Deactivate(context.Context) error {
return nil
}
func (s *stubRoundManager) GetSMT() *smt.ThreadSafeSMT { return s.smt }
func (s *stubRoundManager) CheckParentHealth(context.Context) error { return nil }

func newAggregatorServiceForTest(t *testing.T, shardingCfg config.ShardingConfig, baseTree *smt.SparseMerkleTree) *AggregatorService {
t.Helper()
log, err := logger.New("error", "text", "stdout", false)
require.NoError(t, err)

var tsmt *smt.ThreadSafeSMT
if baseTree != nil {
tsmt = smt.NewThreadSafeSMT(baseTree)
}

return &AggregatorService{
config: &config.Config{Sharding: shardingCfg},
logger: log,
roundManager: &stubRoundManager{smt: tsmt},
commitmentValidator: signing.NewCommitmentValidator(shardingCfg),
}
}
7 changes: 7 additions & 0 deletions internal/smt/thread_safe_smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ func (ts *ThreadSafeSMT) GetPath(path *big.Int) (*api.MerkleTreePath, error) {
return ts.smt.GetPath(path)
}

// GetKeyLength exposes the configured SMT key length.
func (ts *ThreadSafeSMT) GetKeyLength() int {
ts.rwMux.RLock()
defer ts.rwMux.RUnlock()
return ts.smt.keyLength
}

// GetStats returns statistics about the SMT
// This is a read operation that can be performed concurrently
func (ts *ThreadSafeSMT) GetStats() map[string]interface{} {
Expand Down
19 changes: 17 additions & 2 deletions pkg/jsonrpc/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,30 @@ func LoggingMiddleware(logger *logger.Logger) MiddlewareFunc {
}

// TimeoutMiddleware adds timeout to requests
func TimeoutMiddleware(timeout time.Duration) MiddlewareFunc {
func TimeoutMiddleware(timeout time.Duration, log *logger.Logger) MiddlewareFunc {
return func(ctx context.Context, req *Request, next func(context.Context, *Request) *Response) *Response {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

done := make(chan *Response, 1)

go func() {
done <- next(ctx, req)
defer func() {
if r := recover(); r != nil {
log.WithContext(ctx).Error("panic in JSON-RPC handler", "panic", r, "method", req.Method)
errResp := NewErrorResponse(NewError(InternalErrorCode, "Internal server error", nil), req.ID)
select {
case done <- errResp:
default:
}
}
}()

resp := next(ctx, req)
select {
case done <- resp:
default:
}
}()

select {
Expand Down
26 changes: 26 additions & 0 deletions pkg/jsonrpc/handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package jsonrpc

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/unicitynetwork/aggregator-go/internal/logger"
)

func TestTimeoutMiddlewareRecoversFromPanic(t *testing.T) {
testLogger, _ := logger.New("error", "text", "stdout", false)
mw := TimeoutMiddleware(50*time.Millisecond, testLogger)

req := &Request{ID: 1}

resp := mw(context.Background(), req, func(ctx context.Context, r *Request) *Response {
panic("boom")
})

require.NotNil(t, resp, "middleware should return a response even on panic")
require.NotNil(t, resp.Error, "response should contain error information")
assert.Equal(t, InternalErrorCode, resp.Error.Code)
}