Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions coordinator/internal/controller/api/get_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ import (

// GetTaskController the get prover task api controller
type GetTaskController struct {
proverTasks map[message.ProofType]provertask.ProverTask
proverTasks map[message.ProofType]provertask.ProverTask
proverTaskManager *provertask.ProverTaskManager

getTaskAccessCounter *prometheus.CounterVec

Expand All @@ -32,12 +33,15 @@ type GetTaskController struct {

// NewGetTaskController create a get prover task controller
func NewGetTaskController(cfg *config.Config, chainCfg *params.ChainConfig, db *gorm.DB, verifier *verifier.Verifier, reg prometheus.Registerer) *GetTaskController {
proverTaskManager := provertask.NewProverTaskManager(db)

chunkProverTask := provertask.NewChunkProverTask(cfg, chainCfg, db, verifier.ChunkVk, reg)
batchProverTask := provertask.NewBatchProverTask(cfg, chainCfg, db, verifier.BatchVk, reg)
bundleProverTask := provertask.NewBundleProverTask(cfg, chainCfg, db, verifier.BundleVk, reg)

ptc := &GetTaskController{
proverTasks: make(map[message.ProofType]provertask.ProverTask),
proverTasks: make(map[message.ProofType]provertask.ProverTask),
proverTaskManager: proverTaskManager,
getTaskAccessCounter: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{
Name: "coordinator_get_task_access_count",
Help: "Multi dimensions get task counter.",
Expand Down Expand Up @@ -99,7 +103,19 @@ func (ptc *GetTaskController) GetTasks(ctx *gin.Context) {
}
}

proofType := ptc.proofType(&getTaskParameter)
assigned, err := ptc.proverTaskManager.CheckParameter(ctx)
if err != nil {
nerr := fmt.Errorf("check prover task parameter failed, error:%w", err)
types.RenderFailure(ctx, types.ErrCoordinatorGetTaskFailure, nerr)
return
}

var proofType message.ProofType
if assigned != nil {
proofType = message.ProofType(assigned.TaskType)
} else {
proofType = ptc.proofType(&getTaskParameter)
}
proverTask, isExist := ptc.proverTasks[proofType]
if !isExist {
nerr := fmt.Errorf("parameter wrong proof type:%v", proofType)
Expand Down
23 changes: 11 additions & 12 deletions coordinator/internal/logic/provertask/batch_prover_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,14 @@ type BatchProverTask struct {
func NewBatchProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *gorm.DB, expectedVk map[string][]byte, reg prometheus.Registerer) *BatchProverTask {
bp := &BatchProverTask{
BaseProverTask: BaseProverTask{
db: db,
cfg: cfg,
chainCfg: chainCfg,
expectedVk: expectedVk,
blockOrm: orm.NewL2Block(db),
chunkOrm: orm.NewChunk(db),
batchOrm: orm.NewBatch(db),
proverTaskOrm: orm.NewProverTask(db),
proverBlockListOrm: orm.NewProverBlockList(db),
db: db,
cfg: cfg,
chainCfg: chainCfg,
expectedVk: expectedVk,
blockOrm: orm.NewL2Block(db),
chunkOrm: orm.NewChunk(db),
batchOrm: orm.NewBatch(db),
proverTaskOrm: orm.NewProverTask(db),
},
batchTaskGetTaskTotal: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{
Name: "coordinator_batch_get_task_total",
Expand All @@ -60,9 +59,9 @@ func NewBatchProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *go

// Assign load and assign batch tasks
func (bp *BatchProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) {
taskCtx, err := bp.checkParameter(ctx)
if err != nil || taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter failed, error:%w", err)
taskCtx := bp.checkParameter(ctx)
if taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter missed")
}
Comment on lines +62 to 65
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add logging and improve error message for parameter validation failures.

The current implementation returns a generic error without logging or context when parameter validation fails. This significantly reduces observability and makes debugging difficult in production.

Consider adding a log statement and a more descriptive error message:

 func (bp *BatchProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) {
 	taskCtx := bp.checkParameter(ctx)
 	if taskCtx == nil {
-		return nil, fmt.Errorf("check prover task parameter missed")
+		log.Warn("parameter validation failed for batch task assignment", "height", getTaskParameter.ProverHeight)
+		return nil, fmt.Errorf("failed to validate prover task parameters")
 	}

Note: If checkParameter stores error details in the context, consider retrieving and logging them here for better diagnostics.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
taskCtx := bp.checkParameter(ctx)
if taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter missed")
}
taskCtx := bp.checkParameter(ctx)
if taskCtx == nil {
log.Warn("parameter validation failed for batch task assignment", "height", getTaskParameter.ProverHeight)
return nil, fmt.Errorf("failed to validate prover task parameters")
}
🤖 Prompt for AI Agents
In coordinator/internal/logic/provertask/batch_prover_task.go around lines 62 to
65, the code returns a generic error when parameter validation fails; update
this to log a descriptive message and return a more informative error. Add a
process/scoped logger call (e.g., bp.logger or ctx logger) before returning,
include contextual details such as which parameters failed and any error stored
in the context by checkParameter (retrieve from ctx if checkParameter puts an
error there), and change the returned error text to something like "invalid
prover task parameters: <brief detail>" so both logs and the error carry
actionable information.


maxActiveAttempts := bp.cfg.ProverManager.ProversPerSession
Expand Down
25 changes: 12 additions & 13 deletions coordinator/internal/logic/provertask/bundle_prover_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,15 @@ type BundleProverTask struct {
func NewBundleProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *gorm.DB, expectedVk map[string][]byte, reg prometheus.Registerer) *BundleProverTask {
bp := &BundleProverTask{
BaseProverTask: BaseProverTask{
db: db,
chainCfg: chainCfg,
cfg: cfg,
expectedVk: expectedVk,
blockOrm: orm.NewL2Block(db),
chunkOrm: orm.NewChunk(db),
batchOrm: orm.NewBatch(db),
bundleOrm: orm.NewBundle(db),
proverTaskOrm: orm.NewProverTask(db),
proverBlockListOrm: orm.NewProverBlockList(db),
db: db,
chainCfg: chainCfg,
cfg: cfg,
expectedVk: expectedVk,
blockOrm: orm.NewL2Block(db),
chunkOrm: orm.NewChunk(db),
batchOrm: orm.NewBatch(db),
bundleOrm: orm.NewBundle(db),
proverTaskOrm: orm.NewProverTask(db),
},
bundleTaskGetTaskTotal: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{
Name: "coordinator_bundle_get_task_total",
Expand All @@ -58,9 +57,9 @@ func NewBundleProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *g

// Assign load and assign batch tasks
func (bp *BundleProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) {
taskCtx, err := bp.checkParameter(ctx)
if err != nil || taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter failed, error:%w", err)
taskCtx := bp.checkParameter(ctx)
if taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter missed")
}

maxActiveAttempts := bp.cfg.ProverManager.ProversPerSession
Expand Down
21 changes: 10 additions & 11 deletions coordinator/internal/logic/provertask/chunk_prover_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,13 @@ type ChunkProverTask struct {
func NewChunkProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *gorm.DB, expectedVk map[string][]byte, reg prometheus.Registerer) *ChunkProverTask {
cp := &ChunkProverTask{
BaseProverTask: BaseProverTask{
db: db,
cfg: cfg,
chainCfg: chainCfg,
expectedVk: expectedVk,
chunkOrm: orm.NewChunk(db),
blockOrm: orm.NewL2Block(db),
proverTaskOrm: orm.NewProverTask(db),
proverBlockListOrm: orm.NewProverBlockList(db),
db: db,
cfg: cfg,
chainCfg: chainCfg,
expectedVk: expectedVk,
chunkOrm: orm.NewChunk(db),
blockOrm: orm.NewL2Block(db),
proverTaskOrm: orm.NewProverTask(db),
},
chunkTaskGetTaskTotal: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{
Name: "coordinator_chunk_get_task_total",
Expand All @@ -56,9 +55,9 @@ func NewChunkProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *go

// Assign the chunk proof which need to prove
func (cp *ChunkProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) {
taskCtx, err := cp.checkParameter(ctx)
if err != nil || taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter failed, error:%w", err)
taskCtx := cp.checkParameter(ctx)
if taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter missed")
}

maxActiveAttempts := cp.cfg.ProverManager.ProversPerSession
Expand Down
143 changes: 86 additions & 57 deletions coordinator/internal/logic/provertask/prover_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,94 @@ type ProverTask interface {
Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error)
}

// ProverTaskManager manage task which has been assigned
type ProverTaskManager struct {
proverTaskOrm *orm.ProverTask
proverBlockListOrm *orm.ProverBlockList
}

const proverTaskCtxKey = "prover_task_context_key"

// NewProverTaskManager new a prover task manager
func NewProverTaskManager(db *gorm.DB) *ProverTaskManager {
return &ProverTaskManager{
proverTaskOrm: orm.NewProverTask(db),
proverBlockListOrm: orm.NewProverBlockList(db),
}
}

// checkParameter check the prover task parameter illegal
func (b *ProverTaskManager) CheckParameter(ctx *gin.Context) (*orm.ProverTask, error) {
var ptc proverTaskContext
ptc.HardForkNames = make(map[string]struct{})

publicKey, publicKeyExist := ctx.Get(coordinatorType.PublicKey)
if !publicKeyExist {
return nil, errors.New("get public key from context failed")
}
ptc.PublicKey = publicKey.(string)

proverName, proverNameExist := ctx.Get(coordinatorType.ProverName)
if !proverNameExist {
return nil, errors.New("get prover name from context failed")
}
ptc.ProverName = proverName.(string)

proverVersion, proverVersionExist := ctx.Get(coordinatorType.ProverVersion)
if !proverVersionExist {
return nil, errors.New("get prover version from context failed")
}
ptc.ProverVersion = proverVersion.(string)

ProverProviderType, ProverProviderTypeExist := ctx.Get(coordinatorType.ProverProviderTypeKey)
if !ProverProviderTypeExist {
// for backward compatibility, set ProverProviderType as internal
ProverProviderType = float64(coordinatorType.ProverProviderTypeInternal)
}
ptc.ProverProviderType = uint8(ProverProviderType.(float64))

hardForkNamesStr, hardForkNameExist := ctx.Get(coordinatorType.HardForkName)
if !hardForkNameExist {
return nil, errors.New("get hard fork name from context failed")
}
hardForkNames := strings.Split(hardForkNamesStr.(string), ",")
for _, hardForkName := range hardForkNames {
ptc.HardForkNames[hardForkName] = struct{}{}
}

isBlocked, err := b.proverBlockListOrm.IsPublicKeyBlocked(ctx.Copy(), publicKey.(string))
if err != nil {
return nil, fmt.Errorf("failed to check whether the public key %s is blocked before assigning a chunk task, err: %w, proverName: %s, proverVersion: %s", publicKey, err, proverName, proverVersion)
}
if isBlocked {
return nil, fmt.Errorf("public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", publicKey, proverName, proverVersion)
}

assigned, err := b.proverTaskOrm.IsProverAssigned(ctx.Copy(), publicKey.(string))
if err != nil {
return nil, fmt.Errorf("failed to check if prover %s is assigned a task, err: %w", publicKey.(string), err)
}

ptc.hasAssignedTask = assigned

ctx.Set(proverTaskCtxKey, &ptc)

return assigned, nil
}
Comment on lines +57 to +113
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Address type assertion safety concerns.

The method performs multiple type assertions without checking the underlying type (lines 65, 71, 77, 84, 90), which could panic if the context values have unexpected types. While these values are typically set by middleware, consider either:

  1. Using type assertion with the comma-ok idiom: value, ok := publicKey.(string)
  2. Adding defensive checks to prevent panics in case of misconfiguration

Additionally, line 84 has a compound operation that's particularly risky:

ptc.ProverProviderType = uint8(ProverProviderType.(float64))

If ProverProviderType is not a float64, this will panic.

Apply this pattern for safer type assertions:

-	ptc.PublicKey = publicKey.(string)
+	publicKeyStr, ok := publicKey.(string)
+	if !ok {
+		return nil, fmt.Errorf("public key has invalid type: expected string")
+	}
+	ptc.PublicKey = publicKeyStr

Consider applying similar checks for other type assertions, especially line 84 which converts float64 to uint8.

🤖 Prompt for AI Agents
coordinator/internal/logic/provertask/prover_task.go lines 57-113: several
context value type assertions (publicKey, proverName, proverVersion,
ProverProviderType, hardForkNamesStr) use direct assertions and can panic;
change each to use the comma-ok form and return a clear error if the stored type
is unexpected (e.g., assert string for
publicKey/proverName/proverVersion/hardForkNamesStr), handle ProverProviderType
with a type switch (accept float64 or numeric types and convert safely to uint8,
or fall back to the internal default), and ensure any conversion errors return
descriptive errors rather than causing panics; update ptc fields only after
successful assertions and keep ctx.Set(proverTaskCtxKey, &ptc) as the final
step.


// BaseProverTask a base prover task which contain series functions
type BaseProverTask struct {
cfg *config.Config
chainCfg *params.ChainConfig
db *gorm.DB
expectedVk map[string][]byte

batchOrm *orm.Batch
chunkOrm *orm.Chunk
bundleOrm *orm.Bundle
blockOrm *orm.L2Block
proverTaskOrm *orm.ProverTask
proverBlockListOrm *orm.ProverBlockList
batchOrm *orm.Batch
chunkOrm *orm.Chunk
bundleOrm *orm.Bundle
blockOrm *orm.L2Block

proverTaskOrm *orm.ProverTask
}

type proverTaskContext struct {
Expand Down Expand Up @@ -132,59 +207,13 @@ func (b *BaseProverTask) hardForkSanityCheck(ctx *gin.Context, taskCtx *proverTa
}

// checkParameter check the prover task parameter illegal
func (b *BaseProverTask) checkParameter(ctx *gin.Context) (*proverTaskContext, error) {
var ptc proverTaskContext
ptc.HardForkNames = make(map[string]struct{})

publicKey, publicKeyExist := ctx.Get(coordinatorType.PublicKey)
if !publicKeyExist {
return nil, errors.New("get public key from context failed")
}
ptc.PublicKey = publicKey.(string)

proverName, proverNameExist := ctx.Get(coordinatorType.ProverName)
if !proverNameExist {
return nil, errors.New("get prover name from context failed")
}
ptc.ProverName = proverName.(string)

proverVersion, proverVersionExist := ctx.Get(coordinatorType.ProverVersion)
if !proverVersionExist {
return nil, errors.New("get prover version from context failed")
}
ptc.ProverVersion = proverVersion.(string)

ProverProviderType, ProverProviderTypeExist := ctx.Get(coordinatorType.ProverProviderTypeKey)
if !ProverProviderTypeExist {
// for backward compatibility, set ProverProviderType as internal
ProverProviderType = float64(coordinatorType.ProverProviderTypeInternal)
}
ptc.ProverProviderType = uint8(ProverProviderType.(float64))

hardForkNamesStr, hardForkNameExist := ctx.Get(coordinatorType.HardForkName)
if !hardForkNameExist {
return nil, errors.New("get hard fork name from context failed")
}
hardForkNames := strings.Split(hardForkNamesStr.(string), ",")
for _, hardForkName := range hardForkNames {
ptc.HardForkNames[hardForkName] = struct{}{}
func (b *BaseProverTask) checkParameter(ctx *gin.Context) *proverTaskContext {
pctx, exist := ctx.Get(proverTaskCtxKey)
if !exist {
return nil
}

isBlocked, err := b.proverBlockListOrm.IsPublicKeyBlocked(ctx.Copy(), publicKey.(string))
if err != nil {
return nil, fmt.Errorf("failed to check whether the public key %s is blocked before assigning a chunk task, err: %w, proverName: %s, proverVersion: %s", publicKey, err, proverName, proverVersion)
}
if isBlocked {
return nil, fmt.Errorf("public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", publicKey, proverName, proverVersion)
}

assigned, err := b.proverTaskOrm.IsProverAssigned(ctx.Copy(), publicKey.(string))
if err != nil {
return nil, fmt.Errorf("failed to check if prover %s is assigned a task, err: %w", publicKey.(string), err)
}

ptc.hasAssignedTask = assigned
return &ptc, nil
return pctx.(*proverTaskContext)
}

func (b *BaseProverTask) applyUniversal(schema *coordinatorType.GetTaskSchema) (*coordinatorType.GetTaskSchema, []byte, error) {
Expand Down
4 changes: 2 additions & 2 deletions coordinator/test/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func testGetTaskBlocked(t *testing.T) {
err := proverBlockListOrm.InsertProverPublicKey(context.Background(), chunkProver.proverName, chunkProver.publicKey())
assert.NoError(t, err)

expectedErr := fmt.Errorf("return prover task err:check prover task parameter failed, error:public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", chunkProver.publicKey(), chunkProver.proverName, chunkProver.proverVersion)
expectedErr := fmt.Errorf("check prover task parameter failed, error:public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", chunkProver.publicKey(), chunkProver.proverName, chunkProver.proverVersion)
code, errMsg := chunkProver.tryGetProverTask(t, message.ProofTypeChunk)
assert.Equal(t, types.ErrCoordinatorGetTaskFailure, code)
assert.Equal(t, expectedErr, errors.New(errMsg))
Expand All @@ -255,7 +255,7 @@ func testGetTaskBlocked(t *testing.T) {
assert.Equal(t, types.ErrCoordinatorEmptyProofData, code)
assert.Equal(t, expectedErr, errors.New(errMsg))

expectedErr = fmt.Errorf("return prover task err:check prover task parameter failed, error:public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", batchProver.publicKey(), batchProver.proverName, batchProver.proverVersion)
expectedErr = fmt.Errorf("check prover task parameter failed, error:public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", batchProver.publicKey(), batchProver.proverName, batchProver.proverVersion)
code, errMsg = batchProver.tryGetProverTask(t, message.ProofTypeBatch)
assert.Equal(t, types.ErrCoordinatorGetTaskFailure, code)
assert.Equal(t, expectedErr, errors.New(errMsg))
Expand Down
Loading