diff --git a/internal/gateway/server.go b/internal/gateway/server.go index ae67fd7..0f437bc 100644 --- a/internal/gateway/server.go +++ b/internal/gateway/server.go @@ -137,7 +137,7 @@ func (s *Server) setupJSONRPCHandlers() { // Add middleware s.rpcServer.AddMiddleware(jsonrpc.RequestIDMiddleware()) s.rpcServer.AddMiddleware(jsonrpc.LoggingMiddleware(s.logger)) - s.rpcServer.AddMiddleware(jsonrpc.TimeoutMiddleware(30 * time.Second)) + s.rpcServer.AddMiddleware(jsonrpc.TimeoutMiddleware(30*time.Second, s.logger)) // Register handlers based on mode if s.config.Sharding.Mode.IsParent() { diff --git a/internal/service/service.go b/internal/service/service.go index 142e0e9..f8edc82 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -234,7 +234,16 @@ func (as *AggregatorService) GetInclusionProof(ctx context.Context, req *api.Get if err != nil { return nil, fmt.Errorf("failed to get path for request ID %s: %w", req.RequestID, err) } - merkleTreePath, err := as.roundManager.GetSMT().GetPath(path) + + smtInstance := as.roundManager.GetSMT() + if smtInstance == nil { + return nil, fmt.Errorf("merkle tree not initialized") + } + if keyLen := smtInstance.GetKeyLength(); path.BitLen()-1 != keyLen { + return nil, fmt.Errorf("request path length %d does not match SMT key length %d", path.BitLen()-1, keyLen) + } + + merkleTreePath, err := smtInstance.GetPath(path) if err != nil { return nil, fmt.Errorf("failed to get inclusion proof for request ID %s: %w", req.RequestID, err) } diff --git a/internal/service/service_test.go b/internal/service/service_test.go index ba1d7e9..d166c27 100644 --- a/internal/service/service_test.go +++ b/internal/service/service_test.go @@ -10,6 +10,7 @@ import ( "net/http" "os" "strconv" + "strings" "testing" "time" @@ -242,7 +243,7 @@ func validateInclusionProof(t *testing.T, proof *api.InclusionProof, requestID a assert.NotNil(t, proof.MerkleTreePath, "Should have merkle tree path") // Validate unicity certificate field - if proof.UnicityCertificate != nil && len(proof.UnicityCertificate) > 0 { + if len(proof.UnicityCertificate) > 0 { assert.NotEmpty(t, proof.UnicityCertificate, "Unicity certificate should not be empty") // Verify it's valid hex-encoded data _, err := hex.DecodeString(string(proof.UnicityCertificate)) @@ -326,6 +327,68 @@ func (suite *AggregatorTestSuite) TestInclusionProofMissingRecord() { suite.NotEmpty(inclusionProof.InclusionProof.UnicityCertificate, "UnicityCertificate should not be empty") } +func TestGetInclusionProofShardMismatch(t *testing.T) { + shardingCfg := config.ShardingConfig{ + Mode: config.ShardingModeChild, + Child: config.ChildConfig{ + ShardID: 4, + }, + } + tree := smt.NewChildSparseMerkleTree(api.SHA256, 16+256, shardingCfg.Child.ShardID) + service := newAggregatorServiceForTest(t, shardingCfg, tree) + + invalidShardID := api.RequestID(strings.Repeat("00", 33) + "01") + _, err := service.GetInclusionProof(context.Background(), &api.GetInclusionProofRequest{RequestID: invalidShardID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "request ID validation failed") +} + +func TestGetInclusionProofInvalidRequestFormat(t *testing.T) { + shardingCfg := config.ShardingConfig{ + Mode: config.ShardingModeChild, + Child: config.ChildConfig{ + ShardID: 4, + }, + } + tree := smt.NewChildSparseMerkleTree(api.SHA256, 16+256, shardingCfg.Child.ShardID) + service := newAggregatorServiceForTest(t, shardingCfg, tree) + + _, err := service.GetInclusionProof(context.Background(), &api.GetInclusionProofRequest{RequestID: api.RequestID("zz")}) + require.Error(t, err) + assert.Contains(t, err.Error(), "request ID validation failed") +} + +func TestGetInclusionProofSMTUnavailable(t *testing.T) { + shardingCfg := config.ShardingConfig{ + Mode: config.ShardingModeChild, + Child: config.ChildConfig{ + ShardID: 4, + }, + } + service := newAggregatorServiceForTest(t, shardingCfg, nil) + + validID := api.RequestID(strings.Repeat("00", 34)) + _, err := service.GetInclusionProof(context.Background(), &api.GetInclusionProofRequest{RequestID: validID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "merkle tree not initialized") +} + +func TestInclusionProofInvalidPathLength(t *testing.T) { + shardingCfg := config.ShardingConfig{ + Mode: config.ShardingModeStandalone, + } + tree := smt.NewSparseMerkleTree(api.SHA256, 16+256) + service := newAggregatorServiceForTest(t, shardingCfg, tree) + + validID := createTestCommitments(t, 1)[0].RequestID.String() + require.Greater(t, len(validID), 2) + badID := api.RequestID(validID[2:]) + + _, err := service.GetInclusionProof(context.Background(), &api.GetInclusionProofRequest{RequestID: badID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "path length") +} + // TestInclusionProof tests the complete inclusion proof workflow func (suite *AggregatorTestSuite) TestInclusionProof() { @@ -445,3 +508,34 @@ func createTestCommitments(t *testing.T, count int) []*api.SubmitCommitmentReque return commitments } + +type stubRoundManager struct { + smt *smt.ThreadSafeSMT +} + +func (s *stubRoundManager) Start(context.Context) error { return nil } +func (s *stubRoundManager) Stop(context.Context) error { return nil } +func (s *stubRoundManager) Activate(context.Context) error { return nil } +func (s *stubRoundManager) Deactivate(context.Context) error { + return nil +} +func (s *stubRoundManager) GetSMT() *smt.ThreadSafeSMT { return s.smt } +func (s *stubRoundManager) CheckParentHealth(context.Context) error { return nil } + +func newAggregatorServiceForTest(t *testing.T, shardingCfg config.ShardingConfig, baseTree *smt.SparseMerkleTree) *AggregatorService { + t.Helper() + log, err := logger.New("error", "text", "stdout", false) + require.NoError(t, err) + + var tsmt *smt.ThreadSafeSMT + if baseTree != nil { + tsmt = smt.NewThreadSafeSMT(baseTree) + } + + return &AggregatorService{ + config: &config.Config{Sharding: shardingCfg}, + logger: log, + roundManager: &stubRoundManager{smt: tsmt}, + commitmentValidator: signing.NewCommitmentValidator(shardingCfg), + } +} diff --git a/internal/smt/thread_safe_smt.go b/internal/smt/thread_safe_smt.go index 6efee2a..fa5da1f 100644 --- a/internal/smt/thread_safe_smt.go +++ b/internal/smt/thread_safe_smt.go @@ -82,6 +82,13 @@ func (ts *ThreadSafeSMT) GetPath(path *big.Int) (*api.MerkleTreePath, error) { return ts.smt.GetPath(path) } +// GetKeyLength exposes the configured SMT key length. +func (ts *ThreadSafeSMT) GetKeyLength() int { + ts.rwMux.RLock() + defer ts.rwMux.RUnlock() + return ts.smt.keyLength +} + // GetStats returns statistics about the SMT // This is a read operation that can be performed concurrently func (ts *ThreadSafeSMT) GetStats() map[string]interface{} { diff --git a/pkg/jsonrpc/handler.go b/pkg/jsonrpc/handler.go index 291bbc5..00bebcb 100644 --- a/pkg/jsonrpc/handler.go +++ b/pkg/jsonrpc/handler.go @@ -207,7 +207,7 @@ func LoggingMiddleware(logger *logger.Logger) MiddlewareFunc { } // TimeoutMiddleware adds timeout to requests -func TimeoutMiddleware(timeout time.Duration) MiddlewareFunc { +func TimeoutMiddleware(timeout time.Duration, log *logger.Logger) MiddlewareFunc { return func(ctx context.Context, req *Request, next func(context.Context, *Request) *Response) *Response { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() @@ -215,7 +215,22 @@ func TimeoutMiddleware(timeout time.Duration) MiddlewareFunc { done := make(chan *Response, 1) go func() { - done <- next(ctx, req) + defer func() { + if r := recover(); r != nil { + log.WithContext(ctx).Error("panic in JSON-RPC handler", "panic", r, "method", req.Method) + errResp := NewErrorResponse(NewError(InternalErrorCode, "Internal server error", nil), req.ID) + select { + case done <- errResp: + default: + } + } + }() + + resp := next(ctx, req) + select { + case done <- resp: + default: + } }() select { diff --git a/pkg/jsonrpc/handler_test.go b/pkg/jsonrpc/handler_test.go new file mode 100644 index 0000000..fc51df1 --- /dev/null +++ b/pkg/jsonrpc/handler_test.go @@ -0,0 +1,26 @@ +package jsonrpc + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/unicitynetwork/aggregator-go/internal/logger" +) + +func TestTimeoutMiddlewareRecoversFromPanic(t *testing.T) { + testLogger, _ := logger.New("error", "text", "stdout", false) + mw := TimeoutMiddleware(50*time.Millisecond, testLogger) + + req := &Request{ID: 1} + + resp := mw(context.Background(), req, func(ctx context.Context, r *Request) *Response { + panic("boom") + }) + + require.NotNil(t, resp, "middleware should return a response even on panic") + require.NotNil(t, resp.Error, "response should contain error information") + assert.Equal(t, InternalErrorCode, resp.Error.Code) +}