From 91c54b9322fbc151c05ccf4b10e070889551459f Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sat, 21 Mar 2026 10:54:42 -0700 Subject: [PATCH 01/15] clean up concurrency config --- docs/spec/map.md | 4 +- .../lambda/durable/CompletionConfig.java | 32 +++++++- .../amazon/lambda/durable/DurableContext.java | 24 ++++++ .../amazon/lambda/durable/MapConfig.java | 25 ++++-- .../amazon/lambda/durable/MapFunction.java | 27 ------- .../amazon/lambda/durable/ParallelConfig.java | 63 +++++---------- .../durable/context/DurableContextImpl.java | 5 +- .../operation/ConcurrencyOperation.java | 52 +++++++++---- .../durable/operation/MapOperation.java | 76 +++++-------------- .../durable/operation/ParallelOperation.java | 54 +++---------- .../amazon/lambda/durable/MapConfigTest.java | 4 +- .../lambda/durable/MapFunctionTest.java | 8 +- .../lambda/durable/ParallelConfigTest.java | 54 +++---------- .../operation/ConcurrencyOperationTest.java | 68 ++++------------- .../operation/ParallelOperationTest.java | 32 ++++---- 15 files changed, 215 insertions(+), 313 deletions(-) delete mode 100644 sdk/src/main/java/software/amazon/lambda/durable/MapFunction.java diff --git a/docs/spec/map.md b/docs/spec/map.md index 87f549e54..51a502f14 100644 --- a/docs/spec/map.md +++ b/docs/spec/map.md @@ -481,7 +481,7 @@ import java.util.List; import java.util.function.Function; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.MapConfig; -import software.amazon.lambda.durable.MapFunction; +import software.amazon.lambda.durable.DurableContext.MapFunction; import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.model.BatchResult; import software.amazon.lambda.durable.model.OperationSubType; @@ -901,7 +901,7 @@ package software.amazon.lambda.durable.operation; import java.util.List; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.MapConfig; -import software.amazon.lambda.durable.MapFunction; +import software.amazon.lambda.durable.DurableContext.MapFunction; import software.amazon.lambda.durable.model.BatchResult; import software.amazon.lambda.durable.model.OperationSubType; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/CompletionConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/CompletionConfig.java index fc52cd1a3..de91f27fd 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/CompletionConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/CompletionConfig.java @@ -2,6 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 package software.amazon.lambda.durable; +import java.util.Objects; + /** * Controls when a concurrent operation (map or parallel) completes. * @@ -13,7 +15,7 @@ public class CompletionConfig { private final Integer toleratedFailureCount; private final Double toleratedFailurePercentage; - private CompletionConfig(Integer minSuccessful, Integer toleratedFailureCount, Double toleratedFailurePercentage) { + CompletionConfig(Integer minSuccessful, Integer toleratedFailureCount, Double toleratedFailurePercentage) { this.minSuccessful = minSuccessful; this.toleratedFailureCount = toleratedFailureCount; this.toleratedFailurePercentage = toleratedFailurePercentage; @@ -73,4 +75,32 @@ public Integer toleratedFailureCount() { public Double toleratedFailurePercentage() { return toleratedFailurePercentage; } + + @Override + public String toString() { + return "CompletionConfig{" + "minSuccessful=" + + minSuccessful + ", toleratedFailureCount=" + + toleratedFailureCount + ", toleratedFailurePercentage=" + + toleratedFailurePercentage + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + CompletionConfig that = (CompletionConfig) o; + + return Objects.equals(minSuccessful, that.minSuccessful) + && Objects.equals(toleratedFailureCount, that.toleratedFailureCount) + && Objects.equals(toleratedFailurePercentage, that.toleratedFailurePercentage); + } + + @Override + public int hashCode() { + int result = minSuccessful != null ? minSuccessful.hashCode() : 0; + result = 31 * result + (toleratedFailureCount != null ? toleratedFailureCount.hashCode() : 0); + result = 31 * result + (toleratedFailurePercentage != null ? toleratedFailurePercentage.hashCode() : 0); + return result; + } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java b/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java index 79ebda106..4d614e229 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java @@ -440,4 +440,28 @@ DurableFuture waitForConditionAsync( BiFunction> checkFunc, T initialState, WaitForConditionConfig config); + + /** + * Function applied to each item in a map operation. + * + *

Each invocation receives its own {@link DurableContext}, allowing the use of durable operations like + * {@code step()} and {@code wait()} within the function body. The index parameter indicates the item's position in + * the input collection. + * + * @param the input item type + * @param the output result type + */ + @FunctionalInterface + interface MapFunction { + + /** + * Applies this function to the given item. + * + * @param item the input item to process + * @param index the zero-based index of the item in the input collection + * @param context the durable context for this item's execution + * @return the result of processing the item + */ + O apply(I item, int index, DurableContext context); + } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/MapConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/MapConfig.java index d4f6b583c..6a68ea98e 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/MapConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/MapConfig.java @@ -15,8 +15,9 @@ public class MapConfig { private final SerDes serDes; private MapConfig(Builder builder) { - this.maxConcurrency = builder.maxConcurrency; - this.completionConfig = builder.completionConfig; + this.maxConcurrency = builder.maxConcurrency == null ? Integer.MAX_VALUE : builder.maxConcurrency; + this.completionConfig = + builder.completionConfig == null ? CompletionConfig.allCompleted() : builder.completionConfig; this.serDes = builder.serDes; } @@ -27,7 +28,7 @@ public Integer maxConcurrency() { /** @return completion criteria, defaults to {@link CompletionConfig#allCompleted()} */ public CompletionConfig completionConfig() { - return completionConfig != null ? completionConfig : CompletionConfig.allCompleted(); + return completionConfig; } /** @return the custom serializer, or null to use the default */ @@ -56,24 +57,36 @@ private Builder(Integer maxConcurrency, CompletionConfig completionConfig, SerDe } public Builder maxConcurrency(Integer maxConcurrency) { + if (maxConcurrency != null && maxConcurrency < 1) { + throw new IllegalArgumentException("maxConcurrency must be at least 1, got: " + maxConcurrency); + } this.maxConcurrency = maxConcurrency; return this; } + /** + * Sets the completion criteria for the map operation. + * + * @param completionConfig the completion configuration (default: {@link CompletionConfig#allCompleted()}) + * @return this builder for method chaining + */ public Builder completionConfig(CompletionConfig completionConfig) { this.completionConfig = completionConfig; return this; } + /** + * Sets the custom serializer to use for serializing map items and results. + * + * @param serDes the serializer to use + * @return this builder for method chaining + */ public Builder serDes(SerDes serDes) { this.serDes = serDes; return this; } public MapConfig build() { - if (maxConcurrency != null && maxConcurrency < 1) { - throw new IllegalArgumentException("maxConcurrency must be at least 1, got: " + maxConcurrency); - } return new MapConfig(this); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/MapFunction.java b/sdk/src/main/java/software/amazon/lambda/durable/MapFunction.java deleted file mode 100644 index 041dccfc9..000000000 --- a/sdk/src/main/java/software/amazon/lambda/durable/MapFunction.java +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; - -/** - * Function applied to each item in a map operation. - * - *

Each invocation receives its own {@link DurableContext}, allowing the use of durable operations like - * {@code step()} and {@code wait()} within the function body. The index parameter indicates the item's position in the - * input collection. - * - * @param the input item type - * @param the output result type - */ -@FunctionalInterface -public interface MapFunction { - - /** - * Applies this function to the given item. - * - * @param item the input item to process - * @param index the zero-based index of the item in the input collection - * @param context the durable context for this item's execution - * @return the result of processing the item - */ - O apply(I item, int index, DurableContext context); -} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/ParallelConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/ParallelConfig.java index 67992801f..a04a93bd4 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/ParallelConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/ParallelConfig.java @@ -10,13 +10,12 @@ */ public class ParallelConfig { private final int maxConcurrency; - private final int minSuccessful; - private final int toleratedFailureCount; + private final CompletionConfig completionConfig; private ParallelConfig(Builder builder) { - this.maxConcurrency = builder.maxConcurrency; - this.minSuccessful = builder.minSuccessful; - this.toleratedFailureCount = builder.toleratedFailureCount; + this.maxConcurrency = builder.maxConcurrency == null ? Integer.MAX_VALUE : builder.maxConcurrency; + this.completionConfig = + builder.completionConfig == null ? CompletionConfig.allCompleted() : builder.completionConfig; } /** @return the maximum number of branches running simultaneously, or -1 for unlimited */ @@ -24,14 +23,8 @@ public int maxConcurrency() { return maxConcurrency; } - /** @return the minimum number of successful branches required, or -1 meaning all must succeed */ - public int minSuccessful() { - return minSuccessful; - } - - /** @return the maximum number of branch failures tolerated before stopping */ - public int toleratedFailureCount() { - return toleratedFailureCount; + public CompletionConfig completionConfig() { + return completionConfig; } /** @@ -45,42 +38,36 @@ public static Builder builder() { /** Builder for creating ParallelConfig instances. */ public static class Builder { - private int maxConcurrency = -1; - private int minSuccessful = -1; - private int toleratedFailureCount = 0; + private Integer maxConcurrency; + private CompletionConfig completionConfig; private Builder() {} /** * Sets the maximum number of branches that can run simultaneously. * - * @param maxConcurrency the concurrency limit, or -1 for unlimited + * @param maxConcurrency the concurrency limit (default: unlimited) * @return this builder for method chaining */ - public Builder maxConcurrency(int maxConcurrency) { + public Builder maxConcurrency(Integer maxConcurrency) { + if (maxConcurrency != null && maxConcurrency < 1) { + throw new IllegalArgumentException("maxConcurrency must be at least 1, got: " + maxConcurrency); + } this.maxConcurrency = maxConcurrency; return this; } /** - * Sets the minimum number of branches that must succeed for the parallel operation to complete successfully. - * - * @param minSuccessful the minimum successful count, or -1 meaning all branches must succeed - * @return this builder for method chaining - */ - public Builder minSuccessful(int minSuccessful) { - this.minSuccessful = minSuccessful; - return this; - } - - /** - * Sets the maximum number of branch failures tolerated before stopping execution. + * Sets the maximum number of branches that can run simultaneously. * - * @param toleratedFailureCount the maximum tolerated failures + * @param completionConfig the completion configuration for the parallel operation * @return this builder for method chaining */ - public Builder toleratedFailureCount(int toleratedFailureCount) { - this.toleratedFailureCount = toleratedFailureCount; + public Builder completionConfig(CompletionConfig completionConfig) { + if (completionConfig != null && completionConfig.toleratedFailurePercentage() != null) { + throw new IllegalArgumentException("ParallelConfig does not support toleratedFailurePercentage"); + } + this.completionConfig = completionConfig; return this; } @@ -91,16 +78,6 @@ public Builder toleratedFailureCount(int toleratedFailureCount) { * @throws IllegalArgumentException if any configuration values are invalid */ public ParallelConfig build() { - if (maxConcurrency != -1 && maxConcurrency <= 0) { - throw new IllegalArgumentException( - "maxConcurrency must be -1 (unlimited) or greater than 0, got: " + maxConcurrency); - } - if (minSuccessful < -1) { - throw new IllegalArgumentException("minSuccessful must be >= -1, got: " + minSuccessful); - } - if (toleratedFailureCount < 0) { - throw new IllegalArgumentException("toleratedFailureCount must be >= 0, got: " + toleratedFailureCount); - } return new ParallelConfig(this); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java b/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java index e171276de..445418efa 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java @@ -20,7 +20,6 @@ import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.InvokeConfig; import software.amazon.lambda.durable.MapConfig; -import software.amazon.lambda.durable.MapFunction; import software.amazon.lambda.durable.ParallelConfig; import software.amazon.lambda.durable.ParallelContext; import software.amazon.lambda.durable.StepConfig; @@ -562,9 +561,7 @@ public ParallelContext parallel(String name, ParallelConfig config) { OperationIdentifier.of(operationId, name, OperationType.CONTEXT, OperationSubType.PARALLEL), getDurableConfig().getSerDes(), this, - config.maxConcurrency(), - config.minSuccessful(), - config.toleratedFailureCount()); + config); parallelOp.execute(); diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java index f46b890b4..42430d56a 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java @@ -45,6 +45,8 @@ public abstract class ConcurrencyOperation extends BaseDurableOperation { private static final Logger logger = LoggerFactory.getLogger(ConcurrencyOperation.class); private final int maxConcurrency; + private final Integer minSuccessful; + private final Integer toleratedFailureCount; private final AtomicInteger succeededCount = new AtomicInteger(0); private final AtomicInteger failedCount = new AtomicInteger(0); private final AtomicInteger runningCount = new AtomicInteger(0); @@ -52,7 +54,7 @@ public abstract class ConcurrencyOperation extends BaseDurableOperation { private final Queue> pendingQueue = new ConcurrentLinkedDeque<>(); private final List> childOperations = Collections.synchronizedList(new ArrayList<>()); private final Set completedOperations = Collections.synchronizedSet(new HashSet()); - private OperationIdGenerator operationIdGenerator; + private final OperationIdGenerator operationIdGenerator; private final DurableContextImpl rootContext; private ConcurrencyCompletionStatus completionStatus; @@ -61,9 +63,13 @@ protected ConcurrencyOperation( TypeToken resultTypeToken, SerDes resultSerDes, DurableContextImpl durableContext, - int maxConcurrency) { + int maxConcurrency, + Integer minSuccessful, + Integer toleratedFailureCount) { super(operationIdentifier, resultTypeToken, resultSerDes, durableContext); this.maxConcurrency = maxConcurrency; + this.minSuccessful = minSuccessful; + this.toleratedFailureCount = toleratedFailureCount; this.operationIdGenerator = new OperationIdGenerator(getOperationId()); this.rootContext = durableContext.createChildContextWithoutSettingThreadContext(getOperationId(), getName()); } @@ -93,9 +99,6 @@ protected abstract ChildContextOperation createItem( /** Called when the concurrency operation succeeds. Subclasses define checkpointing behavior. */ protected abstract void handleSuccess(ConcurrencyCompletionStatus concurrencyCompletionStatus); - /** Called when the concurrency operation fails. Subclasses define checkpointing and exception behavior. */ - protected abstract void handleFailure(ConcurrencyCompletionStatus concurrencyCompletionStatus); - // ========== Concurrency control ========== /** @@ -145,7 +148,7 @@ protected void startPendingItems() { while (true) { synchronized (this) { if (isOperationCompleted()) return; - if (maxConcurrency != -1 && runningCount.get() >= maxConcurrency) return; + if (runningCount.get() >= maxConcurrency) return; var next = pendingQueue.poll(); if (next == null) return; runningCount.incrementAndGet(); @@ -162,7 +165,7 @@ protected void startPendingItems() { private void executeNextItemIfAllowed() { synchronized (this) { if (isOperationCompleted()) return; - if (maxConcurrency != -1 && runningCount.get() >= maxConcurrency) return; + if (runningCount.get() >= maxConcurrency) return; var next = pendingQueue.poll(); if (next == null) return; runningCount.incrementAndGet(); @@ -220,25 +223,46 @@ public void onItemComplete(ChildContextOperation child) { * * @throws IllegalArgumentException if the item count cannot satisfy the criteria */ - protected abstract void validateItemCount(); + protected void validateItemCount() { + if (minSuccessful != null && minSuccessful > getTotalItems()) { + throw new IllegalArgumentException("minSuccessful (" + minSuccessful + + ") exceeds the number of registered items (" + getTotalItems() + ")"); + } + } /** * Checks whether the concurrency operation can be considered complete. * * @return the completion status if the operation is complete, or null if it should continue */ - protected abstract ConcurrencyCompletionStatus canComplete(); + protected ConcurrencyCompletionStatus canComplete() { + int succeeded = getSucceededCount(); + int failed = getFailedCount(); + + // If we've met the minimum successful count, we're done + if (minSuccessful != null && succeeded >= minSuccessful) { + return ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED; + } + + // If we've exceeded the failure tolerance, we're done + if (toleratedFailureCount != null && failed > toleratedFailureCount) { + return ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; + } + + // All items finished — complete + if (isAllItemsFinished()) { + return ConcurrencyCompletionStatus.ALL_COMPLETED; + } + + return null; + } private void handleComplete(ConcurrencyCompletionStatus status) { synchronized (this) { if (isOperationCompleted()) { return; } - if (status.isSucceeded()) { - handleSuccess(status); - } else { - handleFailure(status); - } + handleSuccess(status); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java index d9931a475..830f18e1a 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java @@ -14,7 +14,6 @@ import software.amazon.lambda.durable.CompletionConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.MapConfig; -import software.amazon.lambda.durable.MapFunction; import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; @@ -42,10 +41,9 @@ public class MapOperation extends ConcurrencyOperation> { private static final int LARGE_RESULT_THRESHOLD = 256 * 1024; private final List items; - private final MapFunction function; + private final DurableContext.MapFunction function; private final TypeToken itemResultType; private final SerDes serDes; - private final CompletionConfig completionConfig; private boolean replayFromPayload; private volatile MapResult cachedResult; private ConcurrencyCompletionStatus completionStatus; @@ -53,7 +51,7 @@ public class MapOperation extends ConcurrencyOperation> { public MapOperation( OperationIdentifier operationIdentifier, List items, - MapFunction function, + DurableContext.MapFunction function, TypeToken itemResultType, MapConfig config, DurableContextImpl durableContext) { @@ -62,12 +60,28 @@ public MapOperation( new TypeToken<>() {}, config.serDes(), durableContext, - config.maxConcurrency() != null ? config.maxConcurrency() : -1); + config.maxConcurrency() != null ? config.maxConcurrency() : Integer.MAX_VALUE, + config.completionConfig().minSuccessful(), + getToleratedFailureCount(config.completionConfig(), items.size())); this.items = List.copyOf(items); this.function = function; this.itemResultType = itemResultType; this.serDes = config.serDes(); - this.completionConfig = config.completionConfig(); + } + + private static Integer getToleratedFailureCount(CompletionConfig completionConfig, int totalItems) { + if (completionConfig == null + || completionConfig.toleratedFailureCount() == null + && completionConfig.toleratedFailurePercentage() == null) { + return null; + } + int toleratedFailureCount = completionConfig.toleratedFailureCount() != null + ? completionConfig.toleratedFailureCount() + : Integer.MAX_VALUE; + int toleratedFailureCountFromPercentage = completionConfig.toleratedFailurePercentage() != null + ? (int) Math.ceil(totalItems * completionConfig.toleratedFailurePercentage()) + : Integer.MAX_VALUE; + return Math.min(toleratedFailureCount, toleratedFailureCountFromPercentage); } @Override @@ -139,56 +153,6 @@ protected void handleSuccess(ConcurrencyCompletionStatus concurrencyCompletionSt checkpointMapResult(); } - @Override - protected void handleFailure(ConcurrencyCompletionStatus concurrencyCompletionStatus) { - this.completionStatus = concurrencyCompletionStatus; - checkpointMapResult(); - } - - @Override - protected void validateItemCount() { - if (completionConfig.minSuccessful() != null && completionConfig.minSuccessful() > getTotalItems()) { - throw new IllegalArgumentException("minSuccessful (" + completionConfig.minSuccessful() - + ") exceeds the number of items (" + getTotalItems() + ")"); - } - } - - /** - * Overrides the default completion logic from {@link ConcurrencyOperation} to support Map's - * {@link CompletionConfig} semantics. Unlike Parallel (where {@code minSuccessful == -1} means "all must succeed"), - * Map's default {@code allCompleted()} allows failures without early termination. - */ - @Override - protected ConcurrencyCompletionStatus canComplete() { - int succeeded = getSucceededCount(); - int failed = getFailedCount(); - int totalCompleted = succeeded + failed; - - // Check minSuccessful - if (completionConfig.minSuccessful() != null && succeeded >= completionConfig.minSuccessful()) { - return ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED; - } - - // Check toleratedFailureCount - if (completionConfig.toleratedFailureCount() != null && failed > completionConfig.toleratedFailureCount()) { - return ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; - } - - // Check toleratedFailurePercentage - if (completionConfig.toleratedFailurePercentage() != null - && totalCompleted > 0 - && ((double) failed / totalCompleted) > completionConfig.toleratedFailurePercentage()) { - return ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; - } - - // All items finished (no pending, no running) — complete with ALL_COMPLETED - if (isAllItemsFinished()) { - return ConcurrencyCompletionStatus.ALL_COMPLETED; - } - - return null; - } - private void checkpointMapResult() { var result = aggregateResults(); this.cachedResult = result; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java index 13aaa5d55..d99e4ffe6 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java @@ -9,6 +9,7 @@ import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.awssdk.services.lambda.model.OperationUpdate; import software.amazon.lambda.durable.DurableContext; +import software.amazon.lambda.durable.ParallelConfig; import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.ExecutionManager; @@ -41,20 +42,21 @@ */ public class ParallelOperation extends ConcurrencyOperation { - private final int minSuccessful; - private final int toleratedFailureCount; private boolean skipCheckpoint = false; public ParallelOperation( OperationIdentifier operationIdentifier, SerDes resultSerDes, DurableContextImpl durableContext, - int maxConcurrency, - int minSuccessful, - int toleratedFailureCount) { - super(operationIdentifier, new TypeToken() {}, resultSerDes, durableContext, maxConcurrency); - this.minSuccessful = minSuccessful; - this.toleratedFailureCount = toleratedFailureCount; + ParallelConfig config) { + super( + operationIdentifier, + TypeToken.get(ParallelResult.class), + resultSerDes, + durableContext, + config.maxConcurrency(), + config.completionConfig().minSuccessful(), + config.completionConfig().toleratedFailureCount()); } @Override @@ -87,11 +89,6 @@ protected void handleSuccess(ConcurrencyCompletionStatus concurrencyCompletionSt .contextOptions(ContextOptions.builder().replayChildren(true).build())); } - @Override - protected void handleFailure(ConcurrencyCompletionStatus concurrencyCompletionStatus) { - handleSuccess(concurrencyCompletionStatus); - } - @Override protected void start() { sendOperationUpdateAsync(OperationUpdate.builder() @@ -111,35 +108,4 @@ public ParallelResult get() { join(); return new ParallelResult(getTotalItems(), getSucceededCount(), getFailedCount(), getCompletionStatus()); } - - @Override - protected void validateItemCount() { - if (minSuccessful > getTotalItems()) { - throw new IllegalArgumentException("minSuccessful (" + minSuccessful - + ") exceeds the number of registered items (" + getTotalItems() + ")"); - } - } - - @Override - protected ConcurrencyCompletionStatus canComplete() { - int succeeded = getSucceededCount(); - int failed = getFailedCount(); - - // If we've met the minimum successful count, we're done - if (minSuccessful != -1 && succeeded >= minSuccessful) { - return ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED; - } - - // If we've exceeded the failure tolerance, we're done - if ((minSuccessful == -1 && failed > 0) || failed > toleratedFailureCount) { - return ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; - } - - // All items finished — complete - if (isAllItemsFinished()) { - return ConcurrencyCompletionStatus.ALL_COMPLETED; - } - - return null; - } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/MapConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/MapConfigTest.java index 11c567d8f..fd856fed4 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/MapConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/MapConfigTest.java @@ -13,7 +13,7 @@ class MapConfigTest { void defaultBuilder_hasNullMaxConcurrency() { var config = MapConfig.builder().build(); - assertNull(config.maxConcurrency()); + assertEquals(Integer.MAX_VALUE, config.maxConcurrency()); } @Test @@ -121,6 +121,6 @@ void builderWithNegativeMaxConcurrency_shouldThrow() { @Test void builderWithNullMaxConcurrency_shouldPass() { var config = MapConfig.builder().maxConcurrency(null).build(); - assertNull(config.maxConcurrency()); + assertEquals(Integer.MAX_VALUE, config.maxConcurrency()); } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/MapFunctionTest.java b/sdk/src/test/java/software/amazon/lambda/durable/MapFunctionTest.java index 9bc98c1a1..4cf6c01bc 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/MapFunctionTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/MapFunctionTest.java @@ -10,12 +10,12 @@ class MapFunctionTest { @Test void isFunctionalInterface() { - assertTrue(MapFunction.class.isAnnotationPresent(FunctionalInterface.class)); + assertTrue(DurableContext.MapFunction.class.isAnnotationPresent(FunctionalInterface.class)); } @Test void canBeUsedAsLambda() { - MapFunction fn = (item, index, ctx) -> item.toUpperCase(); + DurableContext.MapFunction fn = (item, index, ctx) -> item.toUpperCase(); var result = fn.apply("hello", 0, null); @@ -24,7 +24,7 @@ void canBeUsedAsLambda() { @Test void receivesCorrectIndex() { - MapFunction fn = (item, index, ctx) -> index; + DurableContext.MapFunction fn = (item, index, ctx) -> index; assertEquals(0, fn.apply("a", 0, null)); assertEquals(5, fn.apply("b", 5, null)); @@ -32,7 +32,7 @@ void receivesCorrectIndex() { @Test void canThrowRuntimeException() { - MapFunction fn = (item, index, ctx) -> { + DurableContext.MapFunction fn = (item, index, ctx) -> { throw new IllegalArgumentException("bad input"); }; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/ParallelConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/ParallelConfigTest.java index 6dc1aa0e8..819ce5bc5 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/ParallelConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/ParallelConfigTest.java @@ -13,22 +13,20 @@ class ParallelConfigTest { void defaultValues() { var config = ParallelConfig.builder().build(); - assertEquals(-1, config.maxConcurrency()); - assertEquals(-1, config.minSuccessful()); - assertEquals(0, config.toleratedFailureCount()); + assertEquals(Integer.MAX_VALUE, config.maxConcurrency()); + assertEquals(CompletionConfig.allCompleted(), config.completionConfig()); } @Test void builderRoundTrip() { + CompletionConfig completionConfig = CompletionConfig.allSuccessful(); var config = ParallelConfig.builder() .maxConcurrency(4) - .minSuccessful(2) - .toleratedFailureCount(3) + .completionConfig(completionConfig) .build(); assertEquals(4, config.maxConcurrency()); - assertEquals(2, config.minSuccessful()); - assertEquals(3, config.toleratedFailureCount()); + assertEquals(completionConfig, config.completionConfig()); } @Test @@ -39,40 +37,12 @@ void maxConcurrencyOfOne() { } @Test - void minSuccessfulOfZero() { - var config = ParallelConfig.builder().minSuccessful(0).build(); - - assertEquals(0, config.minSuccessful()); - } - - @Test - void unlimitedConcurrency() { - var config = ParallelConfig.builder().maxConcurrency(-1).build(); - - assertEquals(-1, config.maxConcurrency()); - } - - @Test - void maxConcurrencyZeroThrows() { - var builder = ParallelConfig.builder().maxConcurrency(0); - assertThrows(IllegalArgumentException.class, builder::build); - } - - @Test - void maxConcurrencyNegativeTwoThrows() { - var builder = ParallelConfig.builder().maxConcurrency(-2); - assertThrows(IllegalArgumentException.class, builder::build); - } - - @Test - void minSuccessfulNegativeTwoThrows() { - var builder = ParallelConfig.builder().minSuccessful(-2); - assertThrows(IllegalArgumentException.class, builder::build); - } - - @Test - void toleratedFailureCountNegativeThrows() { - var builder = ParallelConfig.builder().toleratedFailureCount(-1); - assertThrows(IllegalArgumentException.class, builder::build); + void invalidConcurrency() { + assertThrows( + IllegalArgumentException.class, + () -> ParallelConfig.builder().maxConcurrency(-1).build()); + assertThrows( + IllegalArgumentException.class, + () -> ParallelConfig.builder().maxConcurrency(0).build()); } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java index 67f7c2ebe..d0b816150 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java @@ -19,6 +19,7 @@ import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; +import software.amazon.lambda.durable.CompletionConfig; import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.TypeToken; @@ -86,17 +87,15 @@ void setUp() { when(executionManager.sendOperationUpdate(any())).thenReturn(CompletableFuture.completedFuture(null)); } - private TestConcurrencyOperation createOperation(int maxConcurrency, int minSuccessful, int toleratedFailureCount) - throws Exception { + private TestConcurrencyOperation createOperation(CompletionConfig completionConfig) throws Exception { TestConcurrencyOperation testConcurrencyOperation = new TestConcurrencyOperation( OperationIdentifier.of( OPERATION_ID, "test-concurrency", OperationType.CONTEXT, OperationSubType.PARALLEL), RESULT_TYPE, SER_DES, durableContext, - maxConcurrency, - minSuccessful, - toleratedFailureCount); + Integer.MAX_VALUE, + completionConfig); setOperationIdGenerator(testConcurrencyOperation, mockIdGenerator); return testConcurrencyOperation; } @@ -134,7 +133,7 @@ void allChildrenAlreadySucceed_callsHandleSuccess() throws Exception { .build()); var functionCalled = new AtomicBoolean(false); - var op = createOperation(-1, -1, 0); + var op = createOperation(CompletionConfig.allSuccessful()); op.addItem( "branch-1", ctx -> { @@ -175,7 +174,7 @@ void singleChildAlreadySucceeds_fullCycle() throws Exception { .build()); var functionCalled = new AtomicBoolean(false); - var op = createOperation(-1, 1, 0); + var op = createOperation(CompletionConfig.minSuccessful(1)); op.addItem( "only-branch", ctx -> { @@ -195,7 +194,7 @@ void singleChildAlreadySucceeds_fullCycle() throws Exception { @Test void addItem_usesRootChildContextAsParent() throws Exception { - var op = createOperation(-1, -1, 0); + var op = createOperation(CompletionConfig.allSuccessful()); op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES); @@ -213,8 +212,6 @@ static class TestConcurrencyOperation extends ConcurrencyOperation { private boolean failureHandled = false; private final AtomicInteger executingCount = new AtomicInteger(0); private DurableContextImpl lastParentContext; - private final int minSuccessful; - private final int toleratedFailureCount; TestConcurrencyOperation( OperationIdentifier operationIdentifier, @@ -222,11 +219,15 @@ static class TestConcurrencyOperation extends ConcurrencyOperation { SerDes resultSerDes, DurableContextImpl durableContext, int maxConcurrency, - int minSuccessful, - int toleratedFailureCount) { - super(operationIdentifier, resultTypeToken, resultSerDes, durableContext, maxConcurrency); - this.minSuccessful = minSuccessful; - this.toleratedFailureCount = toleratedFailureCount; + CompletionConfig completionConfig) { + super( + operationIdentifier, + resultTypeToken, + resultSerDes, + durableContext, + maxConcurrency, + completionConfig.minSuccessful(), + completionConfig.toleratedFailureCount()); } @Override @@ -264,49 +265,12 @@ protected void handleSuccess(ConcurrencyCompletionStatus completionStatus) { .build()); } - @Override - protected void handleFailure(ConcurrencyCompletionStatus completionStatus) { - failureHandled = true; - onCheckpointComplete(Operation.builder() - .id(getOperationId()) - .status(OperationStatus.SUCCEEDED) // always success for parallel - .build()); - } - @Override protected void start() {} @Override protected void replay(Operation existing) {} - @Override - protected void validateItemCount() { - if (minSuccessful > getTotalItems() - getFailedCount()) { - throw new IllegalArgumentException("minSuccessful (" + minSuccessful - + ") exceeds the number of registered items (" + getTotalItems() + ")"); - } - } - - @Override - protected ConcurrencyCompletionStatus canComplete() { - int succeeded = getSucceededCount(); - int failed = getFailedCount(); - - if (minSuccessful != -1 && succeeded >= minSuccessful) { - return ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED; - } - - if ((minSuccessful == -1 && failed > 0) || failed > toleratedFailureCount) { - return ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; - } - - if (isAllItemsFinished()) { - return ConcurrencyCompletionStatus.ALL_COMPLETED; - } - - return null; - } - @Override public Void get() { return null; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java index a44e53459..9609c6995 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java @@ -20,7 +20,9 @@ import software.amazon.awssdk.services.lambda.model.OperationAction; import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; +import software.amazon.lambda.durable.CompletionConfig; import software.amazon.lambda.durable.DurableConfig; +import software.amazon.lambda.durable.ParallelConfig; import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.ExecutionManager; @@ -107,14 +109,12 @@ void setUp() { .sendOperationUpdate(any()); } - private ParallelOperation createOperation(int maxConcurrency, int minSuccessful, int toleratedFailureCount) { + private ParallelOperation createOperation(CompletionConfig completionConfig) { return new ParallelOperation( OperationIdentifier.of(OPERATION_ID, "test-parallel", OperationType.CONTEXT, OperationSubType.PARALLEL), SER_DES, durableContext, - maxConcurrency, - minSuccessful, - toleratedFailureCount); + ParallelConfig.builder().completionConfig(completionConfig).build()); } private void setOperationIdGenerator(ConcurrencyOperation op, OperationIdGenerator mockGenerator) @@ -128,7 +128,7 @@ private void setOperationIdGenerator(ConcurrencyOperation op, OperationIdGene @Test void branchCreation_createsBranchWithParallelBranchSubType() throws Exception { - var op = createOperation(-1, -1, 0); + var op = createOperation(CompletionConfig.allSuccessful()); var childOp = op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES); @@ -138,7 +138,7 @@ void branchCreation_createsBranchWithParallelBranchSubType() throws Exception { @Test void branchCreation_multipleBranchesAllCreated() throws Exception { - var op = createOperation(-1, -1, 0); + var op = createOperation(CompletionConfig.allSuccessful()); op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES); @@ -149,7 +149,7 @@ void branchCreation_multipleBranchesAllCreated() throws Exception { @Test void branchCreation_childOperationHasParentReference() throws Exception { - var op = createOperation(-1, -1, 0); + var op = createOperation(CompletionConfig.allSuccessful()); // The child operation should be a ChildContextOperation with this op as parent var childOp = op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES); @@ -184,7 +184,7 @@ void allBranchesSucceed_sendsSucceedCheckpointAndReturnsCorrectResult() throws E ContextDetails.builder().result("\"r2\"").build()) .build()); - var op = createOperation(-1, -1, 0); + var op = createOperation(CompletionConfig.allSuccessful()); setOperationIdGenerator(op, mockIdGenerator); op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES); @@ -214,7 +214,7 @@ void minSuccessful_completesWhenThresholdMetAndReturnsResult() throws Exception ContextDetails.builder().result("\"r1\"").build()) .build()); - var op = createOperation(-1, 1, 0); + var op = createOperation(CompletionConfig.minSuccessful(1)); setOperationIdGenerator(op, mockIdGenerator); op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); @@ -234,7 +234,7 @@ void minSuccessful_completesWhenThresholdMetAndReturnsResult() throws Exception void contextHierarchy_branchesUseParallelContextAsParent() throws Exception { // Verify that branches are created with the parallel operation's context (durableContext) // as their parent — not some other context - var op = createOperation(-1, -1, 0); + var op = createOperation(CompletionConfig.allSuccessful()); var childOp = op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES); @@ -277,7 +277,7 @@ void replay_fromStartedState_sendsSucceedCheckpointAndReturnsResult() throws Exc ContextDetails.builder().result("\"r2\"").build()) .build()); - var op = createOperation(-1, -1, 0); + var op = createOperation(CompletionConfig.allSuccessful()); setOperationIdGenerator(op, mockIdGenerator); op.execute(); op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); @@ -326,7 +326,7 @@ void replay_fromSucceededState_skipsCheckpointAndReturnsResult() throws Exceptio ContextDetails.builder().result("\"r2\"").build()) .build()); - var op = createOperation(-1, -1, 0); + var op = createOperation(CompletionConfig.allSuccessful()); setOperationIdGenerator(op, mockIdGenerator); op.execute(); op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); @@ -357,7 +357,7 @@ void branchFailure_sendsSucceedCheckpointAndReturnsFailureCounts() throws Except .status(OperationStatus.FAILED) .build()); - var op = createOperation(-1, -1, 0); + var op = createOperation(CompletionConfig.allSuccessful()); setOperationIdGenerator(op, mockIdGenerator); op.addItem( "branch-1", @@ -400,7 +400,7 @@ void get_someBranchesFail_returnsCorrectCountsAndFailureStatus() throws Exceptio .build()); // toleratedFailureCount=1 so the operation completes after both branches finish - var op = createOperation(-1, -1, 1); + var op = createOperation(CompletionConfig.toleratedFailureCount(1)); setOperationIdGenerator(op, mockIdGenerator); op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); op.addItem( @@ -417,12 +417,12 @@ void get_someBranchesFail_returnsCorrectCountsAndFailureStatus() throws Exceptio assertEquals(2, result.getTotalBranches()); assertEquals(1, result.getSucceededBranches()); assertEquals(1, result.getFailedBranches()); - assertFalse(result.getCompletionStatus().isSucceeded()); + assertTrue(result.getCompletionStatus().isSucceeded()); } @Test void get_zeroBranches_returnsAllZerosAndAllCompletedStatus() throws Exception { - var op = createOperation(-1, -1, 0); + var op = createOperation(CompletionConfig.allSuccessful()); var result = op.get(); From 5878a41e7f91fa69ab905ae8747212226ae9bc62 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sat, 21 Mar 2026 19:24:12 -0700 Subject: [PATCH 02/15] rename parallel result fields --- .../examples/parallel/ParallelExample.java | 6 +- .../ParallelFailureToleranceExample.java | 16 ++--- .../parallel/ParallelWithWaitExample.java | 2 +- .../ParallelFailureToleranceExampleTest.java | 4 +- .../lambda/durable/MapIntegrationTest.java | 4 +- .../lambda/durable/CompletionConfig.java | 57 +----------------- .../lambda/durable/model/MapResult.java | 3 +- .../lambda/durable/model/MapResultItem.java | 12 ++-- .../lambda/durable/model/ParallelResult.java | 39 +----------- .../durable/operation/MapOperation.java | 22 ++++--- .../lambda/durable/model/MapResultTest.java | 31 +++++----- .../durable/model/ParallelResultTest.java | 8 +-- .../operation/ParallelOperationTest.java | 60 +++++++++---------- 13 files changed, 90 insertions(+), 174 deletions(-) diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java index a13005c7b..f294fb8e4 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java @@ -54,9 +54,9 @@ public Output handleRequest(Input input, DurableContext context) { ParallelResult parallelResult = parallel.get(); logger.info( "Parallel complete: total={}, succeeded={}, failed={}", - parallelResult.getTotalBranches(), - parallelResult.getSucceededBranches(), - parallelResult.getFailedBranches()); + parallelResult.size(), + parallelResult.succeeded(), + parallelResult.failed()); var results = futures.stream().map(DurableFuture::get).toList(); diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java index b498db8de..7a4cec7aa 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java @@ -4,6 +4,7 @@ import java.util.ArrayList; import java.util.List; +import software.amazon.lambda.durable.CompletionConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; @@ -25,7 +26,7 @@ public class ParallelFailureToleranceExample extends DurableHandler { - public record Input(List services, int toleratedFailures, int minSuccessful) {} + public record Input(List services, Integer toleratedFailures, Integer minSuccessful) {} public record Output(int succeeded, int failed) {} @@ -35,8 +36,7 @@ public Output handleRequest(Input input, DurableContext context) { logger.info("Starting parallel execution with toleratedFailureCount={}", input.toleratedFailures()); var config = ParallelConfig.builder() - .minSuccessful(input.minSuccessful()) - .toleratedFailureCount(input.toleratedFailures()) + .completionConfig(new CompletionConfig(input.minSuccessful, input.toleratedFailures, null)) .build(); var futures = new ArrayList>(input.services().size()); @@ -65,12 +65,12 @@ public Output handleRequest(Input input, DurableContext context) { ParallelResult parallelResult = parallel.get(); logger.info( "Parallel complete: succeeded={}, failed={}, status={}", - parallelResult.getSucceededBranches(), - parallelResult.getFailedBranches(), - parallelResult.getCompletionStatus().isSucceeded() ? "succeeded" : "failed"); + parallelResult.succeeded(), + parallelResult.failed(), + parallelResult.completionStatus().isSucceeded() ? "succeeded" : "failed"); - var succeeded = parallelResult.getSucceededBranches(); - var failed = parallelResult.getFailedBranches(); + var succeeded = parallelResult.succeeded(); + var failed = parallelResult.failed(); logger.info("Completed: {} succeeded, {} failed", succeeded, failed); return new Output(succeeded, failed); diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java index 63bbee21d..66a8be461 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java @@ -68,6 +68,6 @@ public Output handleRequest(Input input, DurableContext context) { logger.info("All {} notifications delivered", deliveries.size()); // Test replay context.wait("wait for finalization", Duration.ofSeconds(5)); - return new Output(deliveries, result.getSucceededBranches(), result.getFailedBranches()); + return new Output(deliveries, result.succeeded(), result.failed()); } } diff --git a/examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExampleTest.java b/examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExampleTest.java index 7d4dd72d1..f1518e2ac 100644 --- a/examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExampleTest.java +++ b/examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExampleTest.java @@ -17,7 +17,7 @@ void succeedsWhenFailuresAreWithinTolerance() { var runner = LocalDurableTestRunner.create(ParallelFailureToleranceExample.Input.class, handler); // 2 good services, 1 bad — toleratedFailureCount=1 so the parallel op still succeeds - var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "bad-svc-b", "svc-c"), 1, -1); + var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "bad-svc-b", "svc-c"), 1, null); var result = runner.runUntilComplete(input); assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); @@ -32,7 +32,7 @@ void succeedsWhenAllBranchesSucceed() { var handler = new ParallelFailureToleranceExample(); var runner = LocalDurableTestRunner.create(ParallelFailureToleranceExample.Input.class, handler); - var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "svc-b", "svc-c"), 2, -1); + var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "svc-b", "svc-c"), 2, null); var result = runner.runUntilComplete(input); assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java index 3bbd8c569..a15cffabe 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java @@ -535,8 +535,8 @@ void testMapWithAllSuccessfulCompletionConfig_stopsOnFirstFailure() { assertEquals("OK1", result.getResult(0)); assertNotNull(result.getError(1)); // Items after the failure should be NOT_STARTED - assertEquals(MapResultItem.Status.NOT_STARTED, result.getItem(2).status()); - assertEquals(MapResultItem.Status.NOT_STARTED, result.getItem(3).status()); + assertEquals(MapResultItem.Status.SKIPPED, result.getItem(2).status()); + assertEquals(MapResultItem.Status.SKIPPED, result.getItem(3).status()); return "done"; }); diff --git a/sdk/src/main/java/software/amazon/lambda/durable/CompletionConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/CompletionConfig.java index de91f27fd..49d12e9cd 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/CompletionConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/CompletionConfig.java @@ -2,24 +2,14 @@ // SPDX-License-Identifier: Apache-2.0 package software.amazon.lambda.durable; -import java.util.Objects; - /** * Controls when a concurrent operation (map or parallel) completes. * *

Provides factory methods for common completion strategies and fine-grained control via {@code minSuccessful}, * {@code toleratedFailureCount}, and {@code toleratedFailurePercentage}. */ -public class CompletionConfig { - private final Integer minSuccessful; - private final Integer toleratedFailureCount; - private final Double toleratedFailurePercentage; - - CompletionConfig(Integer minSuccessful, Integer toleratedFailureCount, Double toleratedFailurePercentage) { - this.minSuccessful = minSuccessful; - this.toleratedFailureCount = toleratedFailureCount; - this.toleratedFailurePercentage = toleratedFailurePercentage; - } +public record CompletionConfig( + Integer minSuccessful, Integer toleratedFailureCount, Double toleratedFailurePercentage) { /** All items must succeed. Zero failures tolerated. */ public static CompletionConfig allSuccessful() { @@ -60,47 +50,4 @@ public static CompletionConfig toleratedFailurePercentage(double percentage) { } return new CompletionConfig(null, null, percentage); } - - /** @return minimum number of successful items required, or null if not set */ - public Integer minSuccessful() { - return minSuccessful; - } - - /** @return maximum number of failures tolerated, or null if unlimited */ - public Integer toleratedFailureCount() { - return toleratedFailureCount; - } - - /** @return maximum percentage of failures tolerated (0.0 to 1.0), or null if not set */ - public Double toleratedFailurePercentage() { - return toleratedFailurePercentage; - } - - @Override - public String toString() { - return "CompletionConfig{" + "minSuccessful=" - + minSuccessful + ", toleratedFailureCount=" - + toleratedFailureCount + ", toleratedFailurePercentage=" - + toleratedFailurePercentage + '}'; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - CompletionConfig that = (CompletionConfig) o; - - return Objects.equals(minSuccessful, that.minSuccessful) - && Objects.equals(toleratedFailureCount, that.toleratedFailureCount) - && Objects.equals(toleratedFailurePercentage, that.toleratedFailurePercentage); - } - - @Override - public int hashCode() { - int result = minSuccessful != null ? minSuccessful.hashCode() : 0; - result = 31 * result + (toleratedFailureCount != null ? toleratedFailureCount.hashCode() : 0); - result = 31 * result + (toleratedFailurePercentage != null ? toleratedFailurePercentage.hashCode() : 0); - return result; - } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/model/MapResult.java b/sdk/src/main/java/software/amazon/lambda/durable/model/MapResult.java index a307c4fbf..dce28a3d3 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/model/MapResult.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/model/MapResult.java @@ -59,8 +59,7 @@ public int size() { /** Returns all results as an unmodifiable list (nulls for failed/not-started items). */ public List results() { - return Collections.unmodifiableList( - items.stream().map(MapResultItem::result).toList()); + return items.stream().map(MapResultItem::result).toList(); } /** Returns results from items that succeeded (includes null results from successful items). */ diff --git a/sdk/src/main/java/software/amazon/lambda/durable/model/MapResultItem.java b/sdk/src/main/java/software/amazon/lambda/durable/model/MapResultItem.java index 86cb8ce79..414c40124 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/model/MapResultItem.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/model/MapResultItem.java @@ -22,21 +22,21 @@ public record MapResultItem(Status status, T result, MapError error) { public enum Status { SUCCEEDED, FAILED, - NOT_STARTED + SKIPPED } /** Creates a successful result item. */ - public static MapResultItem success(T result) { + public static MapResultItem succeeded(T result) { return new MapResultItem<>(Status.SUCCEEDED, result, null); } /** Creates a failed result item. */ - public static MapResultItem failure(MapError error) { + public static MapResultItem failed(MapError error) { return new MapResultItem<>(Status.FAILED, null, error); } - /** Creates a not-started result item. */ - public static MapResultItem notStarted() { - return new MapResultItem<>(Status.NOT_STARTED, null, null); + /** Creates a skipped result item. */ + public static MapResultItem skipped() { + return new MapResultItem<>(Status.SKIPPED, null, null); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/model/ParallelResult.java b/sdk/src/main/java/software/amazon/lambda/durable/model/ParallelResult.java index f0e0fcf18..11bd9049e 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/model/ParallelResult.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/model/ParallelResult.java @@ -8,41 +8,4 @@ *

Captures the aggregate outcome of a parallel execution: how many branches were registered, how many succeeded, how * many failed, and why the operation completed. */ -public class ParallelResult { - - private final int totalBranches; - private final int succeededBranches; - private final int failedBranches; - private final ConcurrencyCompletionStatus completionStatus; - - public ParallelResult( - int totalBranches, - int succeededBranches, - int failedBranches, - ConcurrencyCompletionStatus completionStatus) { - this.totalBranches = totalBranches; - this.succeededBranches = succeededBranches; - this.failedBranches = failedBranches; - this.completionStatus = completionStatus; - } - - /** Returns the total number of branches registered before {@code join()} was called. */ - public int getTotalBranches() { - return totalBranches; - } - - /** Returns the number of branches that completed without throwing. */ - public int getSucceededBranches() { - return succeededBranches; - } - - /** Returns the number of branches that threw an exception. */ - public int getFailedBranches() { - return failedBranches; - } - - /** Returns the status indicating why the parallel operation completed. */ - public ConcurrencyCompletionStatus getCompletionStatus() { - return completionStatus; - } -} +public record ParallelResult(int size, int succeeded, int failed, ConcurrencyCompletionStatus completionStatus) {} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java index 830f18e1a..da49ef8c7 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java @@ -60,7 +60,7 @@ public MapOperation( new TypeToken<>() {}, config.serDes(), durableContext, - config.maxConcurrency() != null ? config.maxConcurrency() : Integer.MAX_VALUE, + config.maxConcurrency(), config.completionConfig().minSuccessful(), getToleratedFailureCount(config.completionConfig(), items.size())); this.items = List.copyOf(items); @@ -71,16 +71,20 @@ public MapOperation( private static Integer getToleratedFailureCount(CompletionConfig completionConfig, int totalItems) { if (completionConfig == null - || completionConfig.toleratedFailureCount() == null - && completionConfig.toleratedFailurePercentage() == null) { + || (completionConfig.toleratedFailureCount() == null + && completionConfig.toleratedFailurePercentage() == null)) { + // neither toleratedFailureCount nor toleratedFailurePercentage is specified. return null; } int toleratedFailureCount = completionConfig.toleratedFailureCount() != null ? completionConfig.toleratedFailureCount() : Integer.MAX_VALUE; + + // convert percentage to count int toleratedFailureCountFromPercentage = completionConfig.toleratedFailurePercentage() != null - ? (int) Math.ceil(totalItems * completionConfig.toleratedFailurePercentage()) + ? (int) Math.floor(totalItems * completionConfig.toleratedFailurePercentage()) : Integer.MAX_VALUE; + // minimum of two if both count and percentage is specified return Math.min(toleratedFailureCount, toleratedFailureCountFromPercentage); } @@ -202,19 +206,19 @@ private MapResult aggregateResults() { for (int i = 0; i < children.size(); i++) { var branch = (ChildContextOperation) children.get(i); if (!branch.isOperationCompleted()) { - resultItems.set(i, MapResultItem.notStarted()); + resultItems.set(i, MapResultItem.skipped()); continue; } try { - resultItems.set(i, MapResultItem.success(branch.get())); + resultItems.set(i, MapResultItem.succeeded(branch.get())); } catch (Exception e) { - resultItems.set(i, MapResultItem.failure(buildMapError(e))); + resultItems.set(i, MapResultItem.failed(buildMapError(e))); } } - // Fill any remaining null slots (items beyond children size) with notStarted + // Fill any remaining null slots (items beyond children size) with skipped for (int i = children.size(); i < items.size(); i++) { - resultItems.set(i, MapResultItem.notStarted()); + resultItems.set(i, MapResultItem.skipped()); } return new MapResult<>(resultItems, completionStatus); diff --git a/sdk/src/test/java/software/amazon/lambda/durable/model/MapResultTest.java b/sdk/src/test/java/software/amazon/lambda/durable/model/MapResultTest.java index 09d97e6d2..f144a18b6 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/model/MapResultTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/model/MapResultTest.java @@ -28,7 +28,7 @@ void empty_returnsZeroSizeResult() { @Test void allSucceeded_trueWhenNoErrors() { var result = new MapResult<>( - List.of(MapResultItem.success("a"), MapResultItem.success("b")), + List.of(MapResultItem.succeeded("a"), MapResultItem.succeeded("b")), ConcurrencyCompletionStatus.ALL_COMPLETED); assertTrue(result.allSucceeded()); @@ -43,7 +43,7 @@ void allSucceeded_trueWhenNoErrors() { void allSucceeded_falseWhenAnyError() { var error = testError("fail"); var result = new MapResult<>( - List.of(MapResultItem.success("a"), MapResultItem.failure(error)), + List.of(MapResultItem.succeeded("a"), MapResultItem.failed(error)), ConcurrencyCompletionStatus.ALL_COMPLETED); assertFalse(result.allSucceeded()); @@ -53,7 +53,7 @@ void allSucceeded_falseWhenAnyError() { void getResult_returnsNullForFailedItem() { var error = testError("fail"); var result = new MapResult<>( - List.of(MapResultItem.success("a"), MapResultItem.failure(error)), + List.of(MapResultItem.succeeded("a"), MapResultItem.failed(error)), ConcurrencyCompletionStatus.ALL_COMPLETED); assertEquals("a", result.getResult(0)); @@ -64,7 +64,7 @@ void getResult_returnsNullForFailedItem() { void getError_returnsNullForSucceededItem() { var error = testError("fail"); var result = new MapResult<>( - List.of(MapResultItem.success("a"), MapResultItem.failure(error)), + List.of(MapResultItem.succeeded("a"), MapResultItem.failed(error)), ConcurrencyCompletionStatus.ALL_COMPLETED); assertNull(result.getError(0)); @@ -75,9 +75,9 @@ void getError_returnsNullForSucceededItem() { void succeeded_filtersNullResults() { var result = new MapResult<>( List.of( - MapResultItem.success("a"), - MapResultItem.failure(testError("fail")), - MapResultItem.success("c")), + MapResultItem.succeeded("a"), + MapResultItem.failed(testError("fail")), + MapResultItem.succeeded("c")), ConcurrencyCompletionStatus.ALL_COMPLETED); assertEquals(List.of("a", "c"), result.succeeded()); @@ -87,7 +87,10 @@ void succeeded_filtersNullResults() { void failed_filtersNullErrors() { var error = testError("fail"); var result = new MapResult<>( - List.of(MapResultItem.success("a"), MapResultItem.failure(error), MapResultItem.success("c")), + List.of( + MapResultItem.succeeded("a"), + MapResultItem.failed(error), + MapResultItem.succeeded("c")), ConcurrencyCompletionStatus.ALL_COMPLETED); var failures = result.failed(); @@ -98,22 +101,22 @@ void failed_filtersNullErrors() { @Test void completionReason_preserved() { var result = new MapResult<>( - List.of(MapResultItem.success("a")), ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED); + List.of(MapResultItem.succeeded("a")), ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED); assertEquals(ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED, result.completionReason()); } @Test void items_returnsUnmodifiableList() { - var result = new MapResult<>(List.of(MapResultItem.success("a")), ConcurrencyCompletionStatus.ALL_COMPLETED); + var result = new MapResult<>(List.of(MapResultItem.succeeded("a")), ConcurrencyCompletionStatus.ALL_COMPLETED); - assertThrows(UnsupportedOperationException.class, () -> result.items().add(MapResultItem.success("b"))); + assertThrows(UnsupportedOperationException.class, () -> result.items().add(MapResultItem.succeeded("b"))); } @Test void getItem_returnsMapResultItem() { var result = new MapResult<>( - List.of(MapResultItem.success("a"), MapResultItem.failure(testError("fail"))), + List.of(MapResultItem.succeeded("a"), MapResultItem.failed(testError("fail"))), ConcurrencyCompletionStatus.ALL_COMPLETED); assertEquals(MapResultItem.Status.SUCCEEDED, result.getItem(0).status()); @@ -128,10 +131,10 @@ void getItem_returnsMapResultItem() { @Test void notStartedItems_haveNotStartedStatusAndNullResultAndError() { var result = new MapResult<>( - List.of(MapResultItem.success("a"), MapResultItem.notStarted()), + List.of(MapResultItem.succeeded("a"), MapResultItem.skipped()), ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED); - assertEquals(MapResultItem.Status.NOT_STARTED, result.getItem(1).status()); + assertEquals(MapResultItem.Status.SKIPPED, result.getItem(1).status()); assertNull(result.getResult(1)); assertNull(result.getError(1)); } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/model/ParallelResultTest.java b/sdk/src/test/java/software/amazon/lambda/durable/model/ParallelResultTest.java index c5cfeebc8..328273111 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/model/ParallelResultTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/model/ParallelResultTest.java @@ -12,9 +12,9 @@ class ParallelResultTest { void allBranchesSucceed_countsAreCorrect() { var result = new ParallelResult(3, 3, 0, ConcurrencyCompletionStatus.ALL_COMPLETED); - assertEquals(3, result.getTotalBranches()); - assertEquals(3, result.getSucceededBranches()); - assertEquals(0, result.getFailedBranches()); - assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.getCompletionStatus()); + assertEquals(3, result.size()); + assertEquals(3, result.succeeded()); + assertEquals(0, result.failed()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionStatus()); } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java index 9609c6995..439e8574a 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java @@ -192,11 +192,11 @@ void allBranchesSucceed_sendsSucceedCheckpointAndReturnsCorrectResult() throws E var result = op.get(); verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); - assertEquals(2, result.getTotalBranches()); - assertEquals(2, result.getSucceededBranches()); - assertEquals(0, result.getFailedBranches()); - assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.getCompletionStatus()); - assertTrue(result.getCompletionStatus().isSucceeded()); + assertEquals(2, result.size()); + assertEquals(2, result.succeeded()); + assertEquals(0, result.failed()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionStatus()); + assertTrue(result.completionStatus().isSucceeded()); } // ===== MinSuccessful satisfaction ===== @@ -221,11 +221,11 @@ void minSuccessful_completesWhenThresholdMetAndReturnsResult() throws Exception var result = op.get(); verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); - assertEquals(1, result.getTotalBranches()); - assertEquals(1, result.getSucceededBranches()); - assertEquals(0, result.getFailedBranches()); - assertEquals(ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED, result.getCompletionStatus()); - assertTrue(result.getCompletionStatus().isSucceeded()); + assertEquals(1, result.size()); + assertEquals(1, result.succeeded()); + assertEquals(0, result.failed()); + assertEquals(ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED, result.completionStatus()); + assertTrue(result.completionStatus().isSucceeded()); } // ===== Context hierarchy ===== @@ -289,10 +289,10 @@ void replay_fromStartedState_sendsSucceedCheckpointAndReturnsResult() throws Exc .sendOperationUpdate(argThat(update -> update.action() == OperationAction.START)); verify(executionManager, times(1)) .sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); - assertEquals(2, result.getTotalBranches()); - assertEquals(2, result.getSucceededBranches()); - assertEquals(0, result.getFailedBranches()); - assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.getCompletionStatus()); + assertEquals(2, result.size()); + assertEquals(2, result.succeeded()); + assertEquals(0, result.failed()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionStatus()); } @Test @@ -338,10 +338,10 @@ void replay_fromSucceededState_skipsCheckpointAndReturnsResult() throws Exceptio .sendOperationUpdate(argThat(update -> update.action() == OperationAction.START)); verify(executionManager, never()) .sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); - assertEquals(2, result.getTotalBranches()); - assertEquals(2, result.getSucceededBranches()); - assertEquals(0, result.getFailedBranches()); - assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.getCompletionStatus()); + assertEquals(2, result.size()); + assertEquals(2, result.succeeded()); + assertEquals(0, result.failed()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionStatus()); } // ===== Branch failure sends SUCCEED checkpoint and returns result ===== @@ -372,10 +372,10 @@ void branchFailure_sendsSucceedCheckpointAndReturnsFailureCounts() throws Except verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); verify(executionManager, never()) .sendOperationUpdate(argThat(update -> update.action() == OperationAction.FAIL)); - assertEquals(1, result.getTotalBranches()); - assertEquals(0, result.getSucceededBranches()); - assertEquals(1, result.getFailedBranches()); - assertFalse(result.getCompletionStatus().isSucceeded()); + assertEquals(1, result.size()); + assertEquals(0, result.succeeded()); + assertEquals(1, result.failed()); + assertFalse(result.completionStatus().isSucceeded()); } @Test @@ -414,10 +414,10 @@ void get_someBranchesFail_returnsCorrectCountsAndFailureStatus() throws Exceptio var result = op.get(); verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); - assertEquals(2, result.getTotalBranches()); - assertEquals(1, result.getSucceededBranches()); - assertEquals(1, result.getFailedBranches()); - assertTrue(result.getCompletionStatus().isSucceeded()); + assertEquals(2, result.size()); + assertEquals(1, result.succeeded()); + assertEquals(1, result.failed()); + assertTrue(result.completionStatus().isSucceeded()); } @Test @@ -426,10 +426,10 @@ void get_zeroBranches_returnsAllZerosAndAllCompletedStatus() throws Exception { var result = op.get(); - assertEquals(0, result.getTotalBranches()); - assertEquals(0, result.getSucceededBranches()); - assertEquals(0, result.getFailedBranches()); - assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.getCompletionStatus()); + assertEquals(0, result.size()); + assertEquals(0, result.succeeded()); + assertEquals(0, result.failed()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionStatus()); verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); } } From fbbc85556700c2a4c5d1afa5c70302ae853d7ce0 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sat, 21 Mar 2026 19:40:55 -0700 Subject: [PATCH 03/15] make DurableContext a DurableFuture interface with extra methods --- .../examples/parallel/ParallelExample.java | 5 +- .../amazon/lambda/durable/DurableContext.java | 6 +- .../lambda/durable/ParallelContext.java | 97 ------------------- .../lambda/durable/ParallelDurableFuture.java | 16 +++ .../durable/context/DurableContextImpl.java | 6 +- .../operation/ConcurrencyOperation.java | 2 +- .../durable/operation/ParallelOperation.java | 43 +++++++- 7 files changed, 68 insertions(+), 107 deletions(-) delete mode 100644 sdk/src/main/java/software/amazon/lambda/durable/ParallelContext.java create mode 100644 sdk/src/main/java/software/amazon/lambda/durable/ParallelDurableFuture.java diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java index f294fb8e4..1ae644e26 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java @@ -8,6 +8,7 @@ import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; import software.amazon.lambda.durable.ParallelConfig; +import software.amazon.lambda.durable.ParallelDurableFuture; import software.amazon.lambda.durable.model.ParallelResult; /** @@ -21,8 +22,8 @@ *

  • A final step combines the results into a summary * * - *

    The {@link software.amazon.lambda.durable.ParallelContext} implements {@link AutoCloseable}, so try-with-resources - * guarantees {@code join()} is called even if an exception occurs. + *

    The {@link ParallelDurableFuture} implements {@link AutoCloseable}, so try-with-resources guarantees + * {@code join()} is called even if an exception occurs. */ public class ParallelExample extends DurableHandler { diff --git a/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java b/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java index 4d614e229..d689f730c 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java @@ -282,12 +282,12 @@ DurableFuture> mapAsync( String name, Collection items, TypeToken resultType, MapFunction function, MapConfig config); /** - * Creates a {@link ParallelContext} for executing multiple branches concurrently. + * Creates a {@link ParallelDurableFuture} for executing multiple branches concurrently. * * @param config the parallel execution configuration - * @return a new ParallelContext for registering and executing branches + * @return a new ParallelDurableFuture for registering and executing branches */ - ParallelContext parallel(String name, ParallelConfig config); + ParallelDurableFuture parallel(String name, ParallelConfig config); /** * Executes a submitter function and waits for an external callback, blocking until the callback completes. diff --git a/sdk/src/main/java/software/amazon/lambda/durable/ParallelContext.java b/sdk/src/main/java/software/amazon/lambda/durable/ParallelContext.java deleted file mode 100644 index 6debb9754..000000000 --- a/sdk/src/main/java/software/amazon/lambda/durable/ParallelContext.java +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; - -import java.util.Objects; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; -import software.amazon.lambda.durable.model.ParallelResult; -import software.amazon.lambda.durable.operation.ParallelOperation; - -/** User-facing context for managing parallel branch execution within a durable function. */ -public class ParallelContext implements AutoCloseable, DurableFuture { - - private final ParallelOperation parallelOperation; - private final DurableContext durableContext; - private final AtomicBoolean joined = new AtomicBoolean(false); - - /** - * Creates a new ParallelContext. - * - * @param parallelOperation the underlying parallel operation managing concurrency - * @param durableContext the durable context for creating child operations - */ - public ParallelContext(ParallelOperation parallelOperation, DurableContext durableContext) { - this.parallelOperation = Objects.requireNonNull(parallelOperation, "parallelOperation cannot be null"); - this.durableContext = Objects.requireNonNull(durableContext, "durableContext cannot be null"); - } - - /** - * Registers and immediately starts a branch (respects maxConcurrency). - * - * @param name the branch name - * @param resultType the result type class - * @param func the function to execute in the branch's child context - * @param the result type - * @return a {@link DurableFuture} that will contain the branch result - * @throws IllegalStateException if called after {@link #join()} - */ - public DurableFuture branch(String name, Class resultType, Function func) { - return branch(name, TypeToken.get(resultType), func); - } - - /** - * Registers and immediately starts a branch (respects maxConcurrency). - * - * @param name the branch name - * @param resultType the result type token for generic types - * @param func the function to execute in the branch's child context - * @param the result type - * @return a {@link DurableFuture} that will contain the branch result - * @throws IllegalStateException if called after {@link #join()} - */ - public DurableFuture branch(String name, TypeToken resultType, Function func) { - if (joined.get()) { - throw new IllegalStateException("Cannot add branches after join() has been called"); - } - return parallelOperation.addItem( - name, func, resultType, durableContext.getDurableConfig().getSerDes()); - } - - /** - * Waits for completion based on config rules (minSuccessful, toleratedFailureCount). - * - *

    First validates that the number of registered branches is sufficient to satisfy the completion criteria. Then - * blocks until completion criteria are met or failure threshold exceeded. - * - * @throws IllegalArgumentException if branch count cannot satisfy completion criteria - * @throws software.amazon.lambda.durable.exception.ConcurrencyExecutionException if failure threshold exceeded - */ - public void join() { - if (!joined.compareAndSet(false, true)) { - return; - } - parallelOperation.join(); - } - - /** - * Blocks until the parallel operation completes and returns the {@link ParallelResult}. - * - *

    Calling {@code get()} implicitly calls {@code join()} if it has not been called yet. - * - * @return the {@link ParallelResult} summarising branch outcomes - */ - @Override - public ParallelResult get() { - joined.set(true); - return parallelOperation.get(); - } - - /** - * Calls {@link #join()} if not already called. Guarantees that all branches complete before the context is closed. - */ - @Override - public void close() { - join(); - } -} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/ParallelDurableFuture.java b/sdk/src/main/java/software/amazon/lambda/durable/ParallelDurableFuture.java new file mode 100644 index 000000000..039cfc895 --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/ParallelDurableFuture.java @@ -0,0 +1,16 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable; + +import java.util.function.Function; +import software.amazon.lambda.durable.model.ParallelResult; + +/** User-facing context for managing parallel branch execution within a durable function. */ +public interface ParallelDurableFuture extends AutoCloseable, DurableFuture { + + DurableFuture branch(String name, Class resultType, Function func); + + DurableFuture branch(String name, TypeToken resultType, Function func); + + void close(); +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java b/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java index 445418efa..9190fa543 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java @@ -21,7 +21,7 @@ import software.amazon.lambda.durable.InvokeConfig; import software.amazon.lambda.durable.MapConfig; import software.amazon.lambda.durable.ParallelConfig; -import software.amazon.lambda.durable.ParallelContext; +import software.amazon.lambda.durable.ParallelDurableFuture; import software.amazon.lambda.durable.StepConfig; import software.amazon.lambda.durable.StepContext; import software.amazon.lambda.durable.TypeToken; @@ -553,7 +553,7 @@ public DurableFuture> mapAsync( // ========== parallel methods ========== @Override - public ParallelContext parallel(String name, ParallelConfig config) { + public ParallelDurableFuture parallel(String name, ParallelConfig config) { Objects.requireNonNull(config, "config cannot be null"); var operationId = nextOperationId(); @@ -565,7 +565,7 @@ public ParallelContext parallel(String name, ParallelConfig config) { parallelOp.execute(); - return new ParallelContext(parallelOp, this); + return parallelOp; } // ========= waitForCallback methods ============= diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java index 42430d56a..0df9b2f80 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java @@ -50,7 +50,7 @@ public abstract class ConcurrencyOperation extends BaseDurableOperation { private final AtomicInteger succeededCount = new AtomicInteger(0); private final AtomicInteger failedCount = new AtomicInteger(0); private final AtomicInteger runningCount = new AtomicInteger(0); - private final AtomicBoolean isJoined = new AtomicBoolean(false); + protected final AtomicBoolean isJoined = new AtomicBoolean(false); private final Queue> pendingQueue = new ConcurrentLinkedDeque<>(); private final List> childOperations = Collections.synchronizedList(new ArrayList<>()); private final Set completedOperations = Collections.synchronizedSet(new HashSet()); diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java index d99e4ffe6..39490ea40 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java @@ -9,7 +9,9 @@ import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.awssdk.services.lambda.model.OperationUpdate; import software.amazon.lambda.durable.DurableContext; +import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.ParallelConfig; +import software.amazon.lambda.durable.ParallelDurableFuture; import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.ExecutionManager; @@ -40,7 +42,7 @@ * └── Branch N context (ChildContextOperation with PARALLEL_BRANCH) * */ -public class ParallelOperation extends ConcurrencyOperation { +public class ParallelOperation extends ConcurrencyOperation implements ParallelDurableFuture { private boolean skipCheckpoint = false; @@ -108,4 +110,43 @@ public ParallelResult get() { join(); return new ParallelResult(getTotalItems(), getSucceededCount(), getFailedCount(), getCompletionStatus()); } + + /** + * Calls {@link #join()} if not already called. Guarantees that all branches complete before the context is closed. + */ + @Override + public void close() { + join(); + } + + /** + * Registers and immediately starts a branch (respects maxConcurrency). + * + * @param name the branch name + * @param resultType the result type class + * @param func the function to execute in the branch's child context + * @param the result type + * @return a {@link DurableFuture} that will contain the branch result + * @throws IllegalStateException if called after {@link #join()} + */ + public DurableFuture branch(String name, Class resultType, Function func) { + return branch(name, TypeToken.get(resultType), func); + } + + /** + * Registers and immediately starts a branch (respects maxConcurrency). + * + * @param name the branch name + * @param resultType the result type token for generic types + * @param func the function to execute in the branch's child context + * @param the result type + * @return a {@link DurableFuture} that will contain the branch result + * @throws IllegalStateException if called after {@link #join()} + */ + public DurableFuture branch(String name, TypeToken resultType, Function func) { + if (isJoined.get()) { + throw new IllegalStateException("Cannot add branches after join() has been called"); + } + return addItem(name, func, resultType, getContext().getDurableConfig().getSerDes()); + } } From 425c49b925cc7d03af3f81f4b636ede75f5c7d5d Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sat, 21 Mar 2026 20:04:31 -0700 Subject: [PATCH 04/15] rename and minor fix --- .../amazon/lambda/durable/model/MapError.java | 8 +++++++- .../durable/operation/BaseDurableOperation.java | 7 +++---- .../durable/operation/ChildContextOperation.java | 6 +++--- .../durable/operation/ConcurrencyOperation.java | 16 ++++++++-------- .../lambda/durable/operation/MapOperation.java | 10 ++-------- .../durable/operation/ParallelOperation.java | 3 ++- 6 files changed, 25 insertions(+), 25 deletions(-) diff --git a/sdk/src/main/java/software/amazon/lambda/durable/model/MapError.java b/sdk/src/main/java/software/amazon/lambda/durable/model/MapError.java index 478a48edd..4a49a8516 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/model/MapError.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/model/MapError.java @@ -3,6 +3,7 @@ package software.amazon.lambda.durable.model; import java.util.List; +import software.amazon.lambda.durable.util.ExceptionHelper; /** * Error details for a failed map item. @@ -14,4 +15,9 @@ * @param errorMessage the error message * @param stackTrace the stack trace frames, or null */ -public record MapError(String errorType, String errorMessage, List stackTrace) {} +public record MapError(String errorType, String errorMessage, List stackTrace) { + public static MapError of(Throwable e) { + return new MapError( + e.getClass().getName(), e.getMessage(), ExceptionHelper.serializeStackTrace(e.getStackTrace())); + } +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java index bd951b407..f891f5e9b 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java @@ -144,13 +144,12 @@ protected Operation getOperation() { } /** - * Gets the direct child Operations of a give context operation. + * Gets the direct child Operations of this context operation * - * @param operationId the operation id of the context * @return list of the child Operations */ - protected List getChildOperations(String operationId) { - return executionManager.getChildOperations(operationId); + protected List getChildOperations() { + return executionManager.getChildOperations(getOperationId()); } /** diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java index 4fb9e1beb..d12690019 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java @@ -243,7 +243,7 @@ public T get() { // throw a general failed exception if a user exception is not reconstructed return switch (getSubType()) { - case WAIT_FOR_CALLBACK -> handleWaitForCallbackFailure(op); + case WAIT_FOR_CALLBACK -> handleWaitForCallbackFailure(); case MAP -> throw new ChildContextFailedException(op); case MAP_ITERATION -> throw new ChildContextFailedException(op); case PARALLEL -> throw new ChildContextFailedException(op); @@ -254,8 +254,8 @@ public T get() { } } - private T handleWaitForCallbackFailure(Operation op) { - var childrenOps = getChildOperations(op.id()); + private T handleWaitForCallbackFailure() { + var childrenOps = getChildOperations(); var callbackOp = childrenOps.stream() .filter(o -> o.type() == OperationType.CALLBACK) .findFirst() diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java index 0df9b2f80..221fd55df 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java @@ -52,8 +52,8 @@ public abstract class ConcurrencyOperation extends BaseDurableOperation { private final AtomicInteger runningCount = new AtomicInteger(0); protected final AtomicBoolean isJoined = new AtomicBoolean(false); private final Queue> pendingQueue = new ConcurrentLinkedDeque<>(); - private final List> childOperations = Collections.synchronizedList(new ArrayList<>()); - private final Set completedOperations = Collections.synchronizedSet(new HashSet()); + private final List> branches = Collections.synchronizedList(new ArrayList<>()); + private final Set completedOperations = Collections.synchronizedSet(new HashSet<>()); private final OperationIdGenerator operationIdGenerator; private final DurableContextImpl rootContext; private ConcurrencyCompletionStatus completionStatus; @@ -117,7 +117,7 @@ public ChildContextOperation addItem( if (isOperationCompleted()) throw new IllegalStateException("Cannot add items to a completed operation"); var operationId = this.operationIdGenerator.nextOperationId(); var childOp = createItem(operationId, name, function, resultType, serDes, this.rootContext); - childOperations.add(childOp); + branches.add(childOp); pendingQueue.add(childOp); logger.debug("Item added {}", name); executeNextItemIfAllowed(); @@ -133,7 +133,7 @@ protected ChildContextOperation enqueueItem( String name, Function function, TypeToken resultType, SerDes serDes) { var operationId = this.operationIdGenerator.nextOperationId(); var childOp = createItem(operationId, name, function, resultType, serDes, this.rootContext); - childOperations.add(childOp); + branches.add(childOp); pendingQueue.add(childOp); logger.debug("Item enqueued {}", name); return childOp; @@ -225,7 +225,7 @@ public void onItemComplete(ChildContextOperation child) { */ protected void validateItemCount() { if (minSuccessful != null && minSuccessful > getTotalItems()) { - throw new IllegalArgumentException("minSuccessful (" + minSuccessful + throw new IllegalStateException("minSuccessful (" + minSuccessful + ") exceeds the number of registered items (" + getTotalItems() + ")"); } } @@ -292,15 +292,15 @@ protected int getFailedCount() { } protected int getTotalItems() { - return childOperations.size(); + return branches.size(); } protected ConcurrencyCompletionStatus getCompletionStatus() { return completionStatus; } - protected List> getChildOperations() { - return Collections.unmodifiableList(childOperations); + protected List> getBranches() { + return branches; } /** Returns true if all items have finished (no pending, no running). Used by subclasses to override canComplete. */ diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java index da49ef8c7..4ede1d325 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java @@ -23,7 +23,6 @@ import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.model.OperationSubType; import software.amazon.lambda.durable.serde.SerDes; -import software.amazon.lambda.durable.util.ExceptionHelper; /** * Executes a map operation: applies a function to each item in a collection concurrently, with each item running in its @@ -200,7 +199,7 @@ public MapResult get() { */ @SuppressWarnings("unchecked") private MapResult aggregateResults() { - var children = getChildOperations(); + var children = getBranches(); var resultItems = new ArrayList>(Collections.nCopies(items.size(), null)); for (int i = 0; i < children.size(); i++) { @@ -212,7 +211,7 @@ private MapResult aggregateResults() { try { resultItems.set(i, MapResultItem.succeeded(branch.get())); } catch (Exception e) { - resultItems.set(i, MapResultItem.failed(buildMapError(e))); + resultItems.set(i, MapResultItem.failed(MapError.of(e))); } } @@ -223,9 +222,4 @@ private MapResult aggregateResults() { return new MapResult<>(resultItems, completionStatus); } - - private static MapError buildMapError(Exception e) { - return new MapError( - e.getClass().getName(), e.getMessage(), ExceptionHelper.serializeStackTrace(e.getStackTrace())); - } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java index 39490ea40..7023cb6f8 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java @@ -44,7 +44,8 @@ */ public class ParallelOperation extends ConcurrencyOperation implements ParallelDurableFuture { - private boolean skipCheckpoint = false; + // this field could be written and read in different threads + private volatile boolean skipCheckpoint = false; public ParallelOperation( OperationIdentifier operationIdentifier, From 3f2ab6566941a91de1f68aa62ab47b7edb94185b Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sat, 21 Mar 2026 21:13:31 -0700 Subject: [PATCH 05/15] add config for parallel branch and runInChildContext --- docs/spec/map.md | 6 +- .../durable/examples/CallbackExample.java | 2 +- .../durable/examples/ComplexMapExample.java | 4 +- .../examples/CustomPollingExample.java | 2 +- .../examples/ErrorHandlingExample.java | 4 +- .../examples/GenericInputOutputExample.java | 2 +- .../durable/examples/GenericTypesExample.java | 2 +- .../examples/RetryInProcessExample.java | 2 +- .../durable/examples/SimpleInvokeExample.java | 2 +- .../durable/examples/WaitAtLeastExample.java | 2 +- .../examples/WaitAtLeastInProcessExample.java | 2 +- .../examples/parallel/ParallelExample.java | 2 +- .../ParallelFailureToleranceExample.java | 6 +- .../parallel/ParallelWithWaitExample.java | 2 +- .../durable/CallbackIntegrationTest.java | 1 + .../durable/CustomConfigIntegrationTest.java | 1 + .../durable/CustomSerDesIntegrationTest.java | 1 + .../durable/ExceptionIntegrationTest.java | 2 + .../lambda/durable/MapIntegrationTest.java | 14 +- .../durable/StepSemanticsIntegrationTest.java | 2 + .../WaitForConditionIntegrationTest.java | 1 + .../durable/retry/RetryIntegrationTest.java | 2 +- .../lambda/durable/testing/SkipTimeTest.java | 2 +- .../amazon/lambda/durable/CallbackConfig.java | 107 +------ .../amazon/lambda/durable/DurableConfig.java | 2 +- .../amazon/lambda/durable/DurableContext.java | 37 +++ .../lambda/durable/ParallelDurableFuture.java | 62 +++- .../amazon/lambda/durable/StepConfig.java | 99 +------ .../lambda/durable/WaitForCallbackConfig.java | 72 +---- .../lambda/durable/config/CallbackConfig.java | 110 +++++++ .../{ => config}/CompletionConfig.java | 2 +- .../durable/{ => config}/InvokeConfig.java | 2 +- .../durable/{ => config}/MapConfig.java | 2 +- .../durable/config/ParallelBranchConfig.java | 69 +++++ .../durable/{ => config}/ParallelConfig.java | 2 +- .../config/RunInChildContextConfig.java | 72 +++++ .../lambda/durable/config/StepConfig.java | 112 +++++++ .../durable/{ => config}/StepSemantics.java | 2 +- .../durable/config/WaitForCallbackConfig.java | 75 +++++ .../{ => config}/WaitForConditionConfig.java | 3 +- .../durable/context/DurableContextImpl.java | 278 ++++++++++++++---- .../amazon/lambda/durable/model/MapError.java | 23 -- .../lambda/durable/model/MapResult.java | 57 ++++ .../lambda/durable/model/MapResultItem.java | 42 --- .../durable/operation/CallbackOperation.java | 2 +- .../operation/ChildContextOperation.java | 10 +- .../durable/operation/InvokeOperation.java | 2 +- .../durable/operation/MapOperation.java | 19 +- .../durable/operation/ParallelOperation.java | 40 +-- .../durable/operation/StepOperation.java | 4 +- .../operation/WaitForConditionOperation.java | 2 +- .../lambda/durable/DurableContextTest.java | 5 +- .../DurationValidationIntegrationTest.java | 1 + .../amazon/lambda/durable/TypeTokenTest.java | 6 +- .../{ => config}/CallbackConfigTest.java | 2 +- .../{ => config}/CompletionConfigTest.java | 2 +- .../durable/{ => config}/MapConfigTest.java | 2 +- .../{ => config}/ParallelConfigTest.java | 2 +- .../durable/{ => config}/StepConfigTest.java | 2 +- .../WaitForCallbackConfigTest.java | 2 +- .../WaitForConditionConfigTest.java | 2 +- .../{ => execution}/DurableExecutionTest.java | 5 +- .../DurableExecutionWrapperTest.java | 6 +- .../lambda/durable/model/MapResultTest.java | 44 +-- .../operation/CallbackOperationTest.java | 2 +- .../operation/ChildContextOperationTest.java | 14 +- .../operation/ConcurrencyOperationTest.java | 5 +- .../operation/InvokeOperationTest.java | 2 +- .../operation/ParallelOperationTest.java | 4 +- .../durable/operation/StepOperationTest.java | 2 +- .../WaitForConditionOperationTest.java | 2 +- .../durable/retry/RetryStrategiesTest.java | 2 +- 72 files changed, 963 insertions(+), 526 deletions(-) create mode 100644 sdk/src/main/java/software/amazon/lambda/durable/config/CallbackConfig.java rename sdk/src/main/java/software/amazon/lambda/durable/{ => config}/CompletionConfig.java (97%) rename sdk/src/main/java/software/amazon/lambda/durable/{ => config}/InvokeConfig.java (98%) rename sdk/src/main/java/software/amazon/lambda/durable/{ => config}/MapConfig.java (98%) create mode 100644 sdk/src/main/java/software/amazon/lambda/durable/config/ParallelBranchConfig.java rename sdk/src/main/java/software/amazon/lambda/durable/{ => config}/ParallelConfig.java (98%) create mode 100644 sdk/src/main/java/software/amazon/lambda/durable/config/RunInChildContextConfig.java create mode 100644 sdk/src/main/java/software/amazon/lambda/durable/config/StepConfig.java rename sdk/src/main/java/software/amazon/lambda/durable/{ => config}/StepSemantics.java (94%) create mode 100644 sdk/src/main/java/software/amazon/lambda/durable/config/WaitForCallbackConfig.java rename sdk/src/main/java/software/amazon/lambda/durable/{ => config}/WaitForConditionConfig.java (97%) delete mode 100644 sdk/src/main/java/software/amazon/lambda/durable/model/MapError.java delete mode 100644 sdk/src/main/java/software/amazon/lambda/durable/model/MapResultItem.java rename sdk/src/test/java/software/amazon/lambda/durable/{ => config}/CallbackConfigTest.java (97%) rename sdk/src/test/java/software/amazon/lambda/durable/{ => config}/CompletionConfigTest.java (98%) rename sdk/src/test/java/software/amazon/lambda/durable/{ => config}/MapConfigTest.java (98%) rename sdk/src/test/java/software/amazon/lambda/durable/{ => config}/ParallelConfigTest.java (96%) rename sdk/src/test/java/software/amazon/lambda/durable/{ => config}/StepConfigTest.java (98%) rename sdk/src/test/java/software/amazon/lambda/durable/{ => config}/WaitForCallbackConfigTest.java (99%) rename sdk/src/test/java/software/amazon/lambda/durable/{ => config}/WaitForConditionConfigTest.java (98%) rename sdk/src/test/java/software/amazon/lambda/durable/{ => execution}/DurableExecutionTest.java (98%) rename sdk/src/test/java/software/amazon/lambda/durable/{ => execution}/DurableExecutionWrapperTest.java (96%) diff --git a/docs/spec/map.md b/docs/spec/map.md index 51a502f14..ee63796f7 100644 --- a/docs/spec/map.md +++ b/docs/spec/map.md @@ -290,7 +290,7 @@ import software.amazon.awssdk.services.lambda.model.ContextOptions; import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationAction; import software.amazon.awssdk.services.lambda.model.OperationUpdate; -import software.amazon.lambda.durable.CompletionConfig; +import software.amazon.lambda.durable.config.CompletionConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.model.CompletionReason; @@ -480,7 +480,7 @@ import java.util.Collections; import java.util.List; import java.util.function.Function; import software.amazon.lambda.durable.DurableContext; -import software.amazon.lambda.durable.MapConfig; +import software.amazon.lambda.durable.config.MapConfig; import software.amazon.lambda.durable.DurableContext.MapFunction; import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.model.BatchResult; @@ -900,7 +900,7 @@ package software.amazon.lambda.durable.operation; import java.util.List; import software.amazon.lambda.durable.DurableContext; -import software.amazon.lambda.durable.MapConfig; +import software.amazon.lambda.durable.config.MapConfig; import software.amazon.lambda.durable.DurableContext.MapFunction; import software.amazon.lambda.durable.model.BatchResult; import software.amazon.lambda.durable.model.OperationSubType; diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/CallbackExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/CallbackExample.java index 31c9b784a..29697f277 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/CallbackExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/CallbackExample.java @@ -3,9 +3,9 @@ package software.amazon.lambda.durable.examples; import java.time.Duration; -import software.amazon.lambda.durable.CallbackConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableHandler; +import software.amazon.lambda.durable.config.CallbackConfig; /** * Example demonstrating callback operations for external system integration. diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/ComplexMapExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/ComplexMapExample.java index e7ba2726d..ce010e575 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/ComplexMapExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/ComplexMapExample.java @@ -5,10 +5,10 @@ import java.time.Duration; import java.util.List; import java.util.stream.Collectors; -import software.amazon.lambda.durable.CompletionConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.MapConfig; +import software.amazon.lambda.durable.config.CompletionConfig; +import software.amazon.lambda.durable.config.MapConfig; /** * Example demonstrating advanced map features: wait operations inside branches, error handling, and early termination. diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/CustomPollingExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/CustomPollingExample.java index 9aa0c68ca..aab698e97 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/CustomPollingExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/CustomPollingExample.java @@ -6,7 +6,7 @@ import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.InvokeConfig; +import software.amazon.lambda.durable.config.InvokeConfig; import software.amazon.lambda.durable.retry.JitterStrategy; import software.amazon.lambda.durable.retry.PollingStrategies; diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/ErrorHandlingExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/ErrorHandlingExample.java index e74a60b9b..cc58ef42e 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/ErrorHandlingExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/ErrorHandlingExample.java @@ -6,8 +6,8 @@ import org.slf4j.LoggerFactory; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.StepConfig; -import software.amazon.lambda.durable.StepSemantics; +import software.amazon.lambda.durable.config.StepConfig; +import software.amazon.lambda.durable.config.StepSemantics; import software.amazon.lambda.durable.exception.StepFailedException; import software.amazon.lambda.durable.exception.StepInterruptedException; import software.amazon.lambda.durable.retry.RetryStrategies; diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/GenericInputOutputExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/GenericInputOutputExample.java index 47bacf9c2..e996b4835 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/GenericInputOutputExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/GenericInputOutputExample.java @@ -9,8 +9,8 @@ import org.slf4j.LoggerFactory; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.StepConfig; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.retry.RetryStrategies; /** diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/GenericTypesExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/GenericTypesExample.java index b4edd512e..1b10f5e55 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/GenericTypesExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/GenericTypesExample.java @@ -9,8 +9,8 @@ import org.slf4j.LoggerFactory; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.StepConfig; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.retry.RetryStrategies; /** diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/RetryInProcessExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/RetryInProcessExample.java index 79644bed6..ee5a8e3de 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/RetryInProcessExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/RetryInProcessExample.java @@ -9,7 +9,7 @@ import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.StepConfig; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.retry.JitterStrategy; import software.amazon.lambda.durable.retry.RetryStrategies; diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/SimpleInvokeExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/SimpleInvokeExample.java index 4c976de64..e54f5418b 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/SimpleInvokeExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/SimpleInvokeExample.java @@ -4,7 +4,7 @@ import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.InvokeConfig; +import software.amazon.lambda.durable.config.InvokeConfig; /** * Simple example demonstrating basic invoke execution with the Durable Execution SDK. diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/WaitAtLeastExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/WaitAtLeastExample.java index 22d2201ae..e08b55f99 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/WaitAtLeastExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/WaitAtLeastExample.java @@ -8,7 +8,7 @@ import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.StepConfig; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.retry.RetryStrategies; /** diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/WaitAtLeastInProcessExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/WaitAtLeastInProcessExample.java index a8454eaae..f6103c4c6 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/WaitAtLeastInProcessExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/WaitAtLeastInProcessExample.java @@ -8,7 +8,7 @@ import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.StepConfig; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.retry.RetryStrategies; /** diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java index 1ae644e26..35c91f37c 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java @@ -7,8 +7,8 @@ import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.ParallelConfig; import software.amazon.lambda.durable.ParallelDurableFuture; +import software.amazon.lambda.durable.config.ParallelConfig; import software.amazon.lambda.durable.model.ParallelResult; /** diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java index 7a4cec7aa..8b7c97bba 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java @@ -4,12 +4,12 @@ import java.util.ArrayList; import java.util.List; -import software.amazon.lambda.durable.CompletionConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.ParallelConfig; -import software.amazon.lambda.durable.StepConfig; +import software.amazon.lambda.durable.config.CompletionConfig; +import software.amazon.lambda.durable.config.ParallelConfig; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.model.ParallelResult; import software.amazon.lambda.durable.retry.RetryStrategies; diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java index 66a8be461..a7c0e3ac2 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java @@ -8,7 +8,7 @@ import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.ParallelConfig; +import software.amazon.lambda.durable.config.ParallelConfig; import software.amazon.lambda.durable.model.ParallelResult; /** diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CallbackIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CallbackIntegrationTest.java index 8ab42779f..b253425ee 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CallbackIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CallbackIntegrationTest.java @@ -10,6 +10,7 @@ import software.amazon.awssdk.services.lambda.model.ErrorObject; import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; +import software.amazon.lambda.durable.config.CallbackConfig; import software.amazon.lambda.durable.exception.CallbackFailedException; import software.amazon.lambda.durable.exception.CallbackTimeoutException; import software.amazon.lambda.durable.model.ExecutionStatus; diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CustomConfigIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CustomConfigIntegrationTest.java index a6dd8a78a..e6f02ed75 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CustomConfigIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CustomConfigIntegrationTest.java @@ -11,6 +11,7 @@ import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.lambda.LambdaClient; import software.amazon.lambda.durable.client.LambdaDurableFunctionsClient; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.serde.JacksonSerDes; import software.amazon.lambda.durable.serde.SerDes; import software.amazon.lambda.durable.testing.LocalDurableTestRunner; diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CustomSerDesIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CustomSerDesIntegrationTest.java index 6c966e214..336c81f5c 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CustomSerDesIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CustomSerDesIntegrationTest.java @@ -7,6 +7,7 @@ import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.serde.JacksonSerDes; import software.amazon.lambda.durable.serde.SerDes; import software.amazon.lambda.durable.testing.LocalDurableTestRunner; diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ExceptionIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ExceptionIntegrationTest.java index 8dd811dd9..5445a4f3d 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ExceptionIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ExceptionIntegrationTest.java @@ -8,6 +8,8 @@ import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; +import software.amazon.lambda.durable.config.StepConfig; +import software.amazon.lambda.durable.config.StepSemantics; import software.amazon.lambda.durable.exception.StepInterruptedException; import software.amazon.lambda.durable.model.ExecutionStatus; import software.amazon.lambda.durable.retry.RetryStrategies; diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java index a15cffabe..7788cc4a1 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java @@ -9,9 +9,11 @@ import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; +import software.amazon.lambda.durable.config.CompletionConfig; +import software.amazon.lambda.durable.config.MapConfig; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; import software.amazon.lambda.durable.model.ExecutionStatus; -import software.amazon.lambda.durable.model.MapResultItem; +import software.amazon.lambda.durable.model.MapResult; import software.amazon.lambda.durable.testing.LocalDurableTestRunner; class MapIntegrationTest { @@ -535,8 +537,10 @@ void testMapWithAllSuccessfulCompletionConfig_stopsOnFirstFailure() { assertEquals("OK1", result.getResult(0)); assertNotNull(result.getError(1)); // Items after the failure should be NOT_STARTED - assertEquals(MapResultItem.Status.SKIPPED, result.getItem(2).status()); - assertEquals(MapResultItem.Status.SKIPPED, result.getItem(3).status()); + assertEquals( + MapResult.MapResultItem.Status.SKIPPED, result.getItem(2).status()); + assertEquals( + MapResult.MapResultItem.Status.SKIPPED, result.getItem(3).status()); return "done"; }); @@ -873,7 +877,9 @@ void testMapWithNullResults() { assertTrue(result.allSucceeded()); assertEquals(3, result.size()); for (int i = 0; i < result.size(); i++) { - assertEquals(MapResultItem.Status.SUCCEEDED, result.getItem(i).status()); + assertEquals( + MapResult.MapResultItem.Status.SUCCEEDED, + result.getItem(i).status()); assertNull(result.getResult(i)); assertNull(result.getError(i)); } diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/StepSemanticsIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/StepSemanticsIntegrationTest.java index 34f9f87be..1b3b67409 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/StepSemanticsIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/StepSemanticsIntegrationTest.java @@ -6,6 +6,8 @@ import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; +import software.amazon.lambda.durable.config.StepConfig; +import software.amazon.lambda.durable.config.StepSemantics; import software.amazon.lambda.durable.model.ExecutionStatus; import software.amazon.lambda.durable.retry.RetryStrategies; import software.amazon.lambda.durable.testing.LocalDurableTestRunner; diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/WaitForConditionIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/WaitForConditionIntegrationTest.java index 105dd6a88..689c25b9f 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/WaitForConditionIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/WaitForConditionIntegrationTest.java @@ -10,6 +10,7 @@ import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.RepeatedTest; import org.junit.jupiter.api.Test; +import software.amazon.lambda.durable.config.WaitForConditionConfig; import software.amazon.lambda.durable.model.ExecutionStatus; import software.amazon.lambda.durable.model.WaitForConditionResult; import software.amazon.lambda.durable.retry.JitterStrategy; diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/retry/RetryIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/retry/RetryIntegrationTest.java index 2debf140f..259d8c653 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/retry/RetryIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/retry/RetryIntegrationTest.java @@ -9,7 +9,7 @@ import org.junit.jupiter.api.Test; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.StepConfig; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.model.ExecutionStatus; import software.amazon.lambda.durable.testing.LocalDurableTestRunner; diff --git a/sdk-testing/src/test/java/software/amazon/lambda/durable/testing/SkipTimeTest.java b/sdk-testing/src/test/java/software/amazon/lambda/durable/testing/SkipTimeTest.java index cdaf87f3c..8fa2f591d 100644 --- a/sdk-testing/src/test/java/software/amazon/lambda/durable/testing/SkipTimeTest.java +++ b/sdk-testing/src/test/java/software/amazon/lambda/durable/testing/SkipTimeTest.java @@ -7,7 +7,7 @@ import java.time.Duration; import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; -import software.amazon.lambda.durable.StepConfig; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.model.ExecutionStatus; import software.amazon.lambda.durable.retry.RetryStrategies; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/CallbackConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/CallbackConfig.java index 9cb859aee..158228fd4 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/CallbackConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/CallbackConfig.java @@ -2,109 +2,10 @@ // SPDX-License-Identifier: Apache-2.0 package software.amazon.lambda.durable; -import java.time.Duration; -import software.amazon.lambda.durable.serde.SerDes; -import software.amazon.lambda.durable.util.ParameterValidator; - -/** Configuration for callback operations. */ +/** @deprecated use {@link software.amazon.lambda.durable.config.CallbackConfig} instead. */ public class CallbackConfig { - private final Duration timeout; - private final Duration heartbeatTimeout; - private final SerDes serDes; - - private CallbackConfig(Builder builder) { - this.timeout = builder.timeout; - this.heartbeatTimeout = builder.heartbeatTimeout; - this.serDes = builder.serDes; - } - - /** - * Returns the maximum duration to wait for the callback to complete. - * - * @return the timeout duration, or null if not specified - */ - public Duration timeout() { - return timeout; - } - - /** - * Returns the maximum duration between heartbeats before the callback is considered failed. - * - * @return the heartbeat timeout duration, or null if not specified - */ - public Duration heartbeatTimeout() { - return heartbeatTimeout; - } - - /** Returns the custom serializer for this callback, or null if not specified (uses default SerDes). */ - public SerDes serDes() { - return serDes; - } - - /** Creates a new builder with default values. */ - public static Builder builder() { - return new Builder(null, null, null); - } - - /** Creates a new builder pre-populated with this config's values. */ - public Builder toBuilder() { - return new Builder(timeout, heartbeatTimeout, serDes); - } - - /** Builder for {@link CallbackConfig}. */ - public static class Builder { - private Duration timeout; - private Duration heartbeatTimeout; - private SerDes serDes; - - private Builder(Duration timeout, Duration heartbeatTimeout, SerDes serDes) { - this.timeout = timeout; - this.heartbeatTimeout = heartbeatTimeout; - this.serDes = serDes; - } - - /** - * Sets the maximum duration to wait for the callback to complete before timing out. - * - * @param timeout the timeout duration - * @return this builder for method chaining - */ - public Builder timeout(Duration timeout) { - ParameterValidator.validateOptionalDuration(timeout, "Callback timeout"); - this.timeout = timeout; - return this; - } - - /** - * Sets the maximum duration between heartbeats before the callback is considered failed. - * - * @param heartbeatTimeout the heartbeat timeout duration - * @return this builder for method chaining - */ - public Builder heartbeatTimeout(Duration heartbeatTimeout) { - ParameterValidator.validateOptionalDuration(heartbeatTimeout, "Heartbeat timeout"); - this.heartbeatTimeout = heartbeatTimeout; - return this; - } - - /** - * Sets a custom serializer for the callback. - * - *

    If not specified, the callback will use the default SerDes configured for the handler. This allows - * per-callback customization of serialization behavior, useful for callbacks that need special handling (e.g., - * custom date formats, encryption, compression). - * - * @param serDes the custom serializer to use, or null to use the default - * @return this builder for method chaining - */ - public Builder serDes(SerDes serDes) { - this.serDes = serDes; - return this; - } - - /** Builds the {@link CallbackConfig} instance. */ - public CallbackConfig build() { - return new CallbackConfig(this); - } + /** @deprecated use {@link software.amazon.lambda.durable.config.CallbackConfig#builder()} instead. */ + public static software.amazon.lambda.durable.config.CallbackConfig.Builder builder() { + return new software.amazon.lambda.durable.config.CallbackConfig.Builder(null, null, null); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java index e97843360..b83c7598c 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java @@ -267,7 +267,7 @@ public static final class Builder { private PollingStrategy pollingStrategy; private Duration checkpointDelay; - private Builder() {} + public Builder() {} /** * Sets a custom LambdaClient for production use. Use this method to customize the AWS SDK client with specific diff --git a/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java b/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java index d689f730c..4e9d26adc 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java @@ -8,6 +8,14 @@ import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; +import software.amazon.lambda.durable.config.CallbackConfig; +import software.amazon.lambda.durable.config.InvokeConfig; +import software.amazon.lambda.durable.config.MapConfig; +import software.amazon.lambda.durable.config.ParallelConfig; +import software.amazon.lambda.durable.config.RunInChildContextConfig; +import software.amazon.lambda.durable.config.StepConfig; +import software.amazon.lambda.durable.config.WaitForCallbackConfig; +import software.amazon.lambda.durable.config.WaitForConditionConfig; import software.amazon.lambda.durable.context.BaseContext; import software.amazon.lambda.durable.model.MapResult; import software.amazon.lambda.durable.model.WaitForConditionResult; @@ -259,6 +267,35 @@ DurableFuture invokeAsync( /** Asynchronously runs a function in a child context using a {@link TypeToken} for generic result types. */ DurableFuture runInChildContextAsync(String name, TypeToken resultType, Function func); + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param the result type + * @param name the unique operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @return the child context result + */ + T runInChildContext( + String name, Class resultType, Function func, RunInChildContextConfig config); + + /** + * Runs a function in a child context using a {@link TypeToken} for generic result types, blocking until complete. + */ + T runInChildContext( + String name, TypeToken resultType, Function func, RunInChildContextConfig config); + + /** Asynchronously runs a function in a child context, returning a {@link DurableFuture}. */ + DurableFuture runInChildContextAsync( + String name, Class resultType, Function func, RunInChildContextConfig config); + + /** Asynchronously runs a function in a child context using a {@link TypeToken} for generic result types. */ + DurableFuture runInChildContextAsync( + String name, TypeToken resultType, Function func, RunInChildContextConfig config); + MapResult map(String name, Collection items, Class resultType, MapFunction function); MapResult map( diff --git a/sdk/src/main/java/software/amazon/lambda/durable/ParallelDurableFuture.java b/sdk/src/main/java/software/amazon/lambda/durable/ParallelDurableFuture.java index 039cfc895..b71198d7d 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/ParallelDurableFuture.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/ParallelDurableFuture.java @@ -3,14 +3,72 @@ package software.amazon.lambda.durable; import java.util.function.Function; +import software.amazon.lambda.durable.config.ParallelBranchConfig; import software.amazon.lambda.durable.model.ParallelResult; /** User-facing context for managing parallel branch execution within a durable function. */ public interface ParallelDurableFuture extends AutoCloseable, DurableFuture { - DurableFuture branch(String name, Class resultType, Function func); + /** + * Registers and immediately starts a branch (respects maxConcurrency). + * + * @param name the branch name + * @param resultType the result type token for generic types + * @param func the function to execute in the branch's child context + * @param the result type + * @return a {@link DurableFuture} that will contain the branch result + * @throws IllegalStateException if called after {@link #close()} + */ + default DurableFuture branch(String name, Class resultType, Function func) { + return branch( + name, + TypeToken.get(resultType), + func, + ParallelBranchConfig.builder().build()); + } - DurableFuture branch(String name, TypeToken resultType, Function func); + /** + * Registers and immediately starts a branch (respects maxConcurrency). + * + * @param name the branch name + * @param resultType the result type token for generic types + * @param func the function to execute in the branch's child context + * @param the result type + * @return a {@link DurableFuture} that will contain the branch result + * @throws IllegalStateException if called after {@link #close()} + */ + default DurableFuture branch(String name, TypeToken resultType, Function func) { + return branch(name, resultType, func, ParallelBranchConfig.builder().build()); + } + /** + * Registers and immediately starts a branch (respects maxConcurrency). + * + * @param name the branch name + * @param resultType the result type token for generic types + * @param func the function to execute in the branch's child context + * @param the result type + * @return a {@link DurableFuture} that will contain the branch result + * @throws IllegalStateException if called after {@link #close()} + */ + default DurableFuture branch( + String name, Class resultType, Function func, ParallelBranchConfig config) { + return branch(name, TypeToken.get(resultType), func, config); + } + + /** + * Registers and immediately starts a branch (respects maxConcurrency). + * + * @param name the branch name + * @param resultType the result type token for generic types + * @param func the function to execute in the branch's child context + * @param the result type + * @return a {@link DurableFuture} that will contain the branch result + * @throws IllegalStateException if called after {@link #close()} + */ + DurableFuture branch( + String name, TypeToken resultType, Function func, ParallelBranchConfig config); + + /** Calls {@link #get()} if not already called. Guarantees that the context is closed. */ void close(); } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/StepConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/StepConfig.java index d3e6b10a9..997ce75dd 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/StepConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/StepConfig.java @@ -2,111 +2,22 @@ // SPDX-License-Identifier: Apache-2.0 package software.amazon.lambda.durable; -import software.amazon.lambda.durable.retry.RetryStrategies; -import software.amazon.lambda.durable.retry.RetryStrategy; -import software.amazon.lambda.durable.serde.SerDes; - /** * Configuration options for step operations in durable executions. * *

    This class provides a builder pattern for configuring various aspects of step execution, including retry behavior * and delivery semantics. + * + * @deprecated use {@link software.amazon.lambda.durable.config.StepConfig} */ public class StepConfig { - private final RetryStrategy retryStrategy; - private final StepSemantics semantics; - private final SerDes serDes; - - private StepConfig(Builder builder) { - this.retryStrategy = builder.retryStrategy; - this.semantics = builder.semantics; - this.serDes = builder.serDes; - } - - /** Returns the retry strategy for this step, or the default strategy if not specified. */ - public RetryStrategy retryStrategy() { - return retryStrategy != null ? retryStrategy : RetryStrategies.Presets.DEFAULT; - } - - /** Returns the delivery semantics for this step, defaults to AT_LEAST_ONCE_PER_RETRY if not specified. */ - public StepSemantics semantics() { - return semantics != null ? semantics : StepSemantics.AT_LEAST_ONCE_PER_RETRY; - } - - /** Returns the custom serializer for this step, or null if not specified (uses default SerDes). */ - public SerDes serDes() { - return serDes; - } - - public Builder toBuilder() { - return new Builder(retryStrategy, semantics, serDes); - } - /** * Creates a new builder for StepConfig. * * @return a new Builder instance + * @deprecated use {@link software.amazon.lambda.durable.config.StepConfig#builder} */ - public static Builder builder() { - return new Builder(null, null, null); - } - - /** Builder for creating StepConfig instances. */ - public static class Builder { - private RetryStrategy retryStrategy; - private StepSemantics semantics; - private SerDes serDes; - - private Builder(RetryStrategy retryStrategy, StepSemantics semantics, SerDes serDes) { - this.retryStrategy = retryStrategy; - this.semantics = semantics; - this.serDes = serDes; - } - - /** - * Sets the retry strategy for the step. - * - * @param retryStrategy the retry strategy to use, or null for default behavior - * @return this builder for method chaining - */ - public Builder retryStrategy(RetryStrategy retryStrategy) { - this.retryStrategy = retryStrategy; - return this; - } - - /** - * Sets the delivery semantics for the step. - * - * @param semantics the delivery semantics to use, defaults to AT_LEAST_ONCE_PER_RETRY if not specified - * @return this builder for method chaining - */ - public Builder semantics(StepSemantics semantics) { - this.semantics = semantics; - return this; - } - - /** - * Sets a custom serializer for the step. - * - *

    If not specified, the step will use the default SerDes configured for the handler. This allows per-step - * customization of serialization behavior, useful for steps that need special handling (e.g., custom date - * formats, encryption, compression). - * - * @param serDes the custom serializer to use, or null to use the default - * @return this builder for method chaining - */ - public Builder serDes(SerDes serDes) { - this.serDes = serDes; - return this; - } - - /** - * Builds the StepConfig instance. - * - * @return a new StepConfig with the configured options - */ - public StepConfig build() { - return new StepConfig(this); - } + public static software.amazon.lambda.durable.config.StepConfig.Builder builder() { + return new software.amazon.lambda.durable.config.StepConfig.Builder(null, null, null); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/WaitForCallbackConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/WaitForCallbackConfig.java index 8c6b7e28a..fd7096171 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/WaitForCallbackConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/WaitForCallbackConfig.java @@ -2,74 +2,10 @@ // SPDX-License-Identifier: Apache-2.0 package software.amazon.lambda.durable; -/** - * Configuration for the {@code waitForCallback} composite operation. - * - *

    Combines a {@link StepConfig} (for the step that produces the callback) and a {@link CallbackConfig} (for the - * callback wait itself). - */ +/** @deprecated use {@link software.amazon.lambda.durable.config.WaitForCallbackConfig} instead. */ public class WaitForCallbackConfig { - private final StepConfig stepConfig; - private final CallbackConfig callbackConfig; - - private WaitForCallbackConfig(Builder builder) { - this.stepConfig = builder.stepConfig == null ? StepConfig.builder().build() : builder.stepConfig; - this.callbackConfig = - builder.callbackConfig == null ? CallbackConfig.builder().build() : builder.callbackConfig; - } - - /** Returns the step configuration for the composite operation. */ - public StepConfig stepConfig() { - return stepConfig; - } - - /** Returns the callback configuration for the composite operation. */ - public CallbackConfig callbackConfig() { - return callbackConfig; - } - - /** Creates a new builder. */ - public static Builder builder() { - return new Builder(); - } - - /** Creates a builder pre-populated with this instance's values. */ - public Builder toBuilder() { - return new Builder().stepConfig(this.stepConfig).callbackConfig(this.callbackConfig); - } - - /** Builder for {@link WaitForCallbackConfig}. */ - public static class Builder { - private StepConfig stepConfig; - private CallbackConfig callbackConfig; - - private Builder() {} - - /** - * Sets the step configuration for the composite operation. - * - * @param stepConfig the step configuration - * @return this builder for method chaining - */ - public Builder stepConfig(StepConfig stepConfig) { - this.stepConfig = stepConfig; - return this; - } - - /** - * Sets the callback configuration for the composite operation. - * - * @param callbackConfig the callback configuration - * @return this builder for method chaining - */ - public Builder callbackConfig(CallbackConfig callbackConfig) { - this.callbackConfig = callbackConfig; - return this; - } - - /** Builds the WaitForCallbackConfig instance. */ - public WaitForCallbackConfig build() { - return new WaitForCallbackConfig(this); - } + /** @deprecated use {@link software.amazon.lambda.durable.config.WaitForCallbackConfig#builder()} instead. */ + public static software.amazon.lambda.durable.config.WaitForCallbackConfig.Builder builder() { + return new software.amazon.lambda.durable.config.WaitForCallbackConfig.Builder(); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/config/CallbackConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/CallbackConfig.java new file mode 100644 index 000000000..a71962aff --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/CallbackConfig.java @@ -0,0 +1,110 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.config; + +import java.time.Duration; +import software.amazon.lambda.durable.serde.SerDes; +import software.amazon.lambda.durable.util.ParameterValidator; + +/** Configuration for callback operations. */ +public class CallbackConfig { + private final Duration timeout; + private final Duration heartbeatTimeout; + private final SerDes serDes; + + private CallbackConfig(Builder builder) { + this.timeout = builder.timeout; + this.heartbeatTimeout = builder.heartbeatTimeout; + this.serDes = builder.serDes; + } + + /** + * Returns the maximum duration to wait for the callback to complete. + * + * @return the timeout duration, or null if not specified + */ + public Duration timeout() { + return timeout; + } + + /** + * Returns the maximum duration between heartbeats before the callback is considered failed. + * + * @return the heartbeat timeout duration, or null if not specified + */ + public Duration heartbeatTimeout() { + return heartbeatTimeout; + } + + /** Returns the custom serializer for this callback, or null if not specified (uses default SerDes). */ + public SerDes serDes() { + return serDes; + } + + /** Creates a new builder with default values. */ + public static Builder builder() { + return new Builder(null, null, null); + } + + /** Creates a new builder pre-populated with this config's values. */ + public Builder toBuilder() { + return new Builder(timeout, heartbeatTimeout, serDes); + } + + /** Builder for {@link CallbackConfig}. */ + public static class Builder { + private Duration timeout; + private Duration heartbeatTimeout; + private SerDes serDes; + + public Builder(Duration timeout, Duration heartbeatTimeout, SerDes serDes) { + this.timeout = timeout; + this.heartbeatTimeout = heartbeatTimeout; + this.serDes = serDes; + } + + /** + * Sets the maximum duration to wait for the callback to complete before timing out. + * + * @param timeout the timeout duration + * @return this builder for method chaining + */ + public Builder timeout(Duration timeout) { + ParameterValidator.validateOptionalDuration(timeout, "Callback timeout"); + this.timeout = timeout; + return this; + } + + /** + * Sets the maximum duration between heartbeats before the callback is considered failed. + * + * @param heartbeatTimeout the heartbeat timeout duration + * @return this builder for method chaining + */ + public Builder heartbeatTimeout(Duration heartbeatTimeout) { + ParameterValidator.validateOptionalDuration(heartbeatTimeout, "Heartbeat timeout"); + this.heartbeatTimeout = heartbeatTimeout; + return this; + } + + /** + * Sets a custom serializer for the callback. + * + *

    If not specified, the callback will use the default SerDes configured for the handler. This allows + * per-callback customization of serialization behavior, useful for callbacks that need special handling (e.g., + * custom date formats, encryption, compression). + * + * @param serDes the custom serializer to use, or null to use the default + * @return this builder for method chaining + */ + public Builder serDes(SerDes serDes) { + this.serDes = serDes; + return this; + } + + /** Builds the {@link CallbackConfig} instance. */ + public CallbackConfig build() { + return new CallbackConfig(this); + } + } +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/CompletionConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/CompletionConfig.java similarity index 97% rename from sdk/src/main/java/software/amazon/lambda/durable/CompletionConfig.java rename to sdk/src/main/java/software/amazon/lambda/durable/config/CompletionConfig.java index 49d12e9cd..bca1c4ecf 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/CompletionConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/CompletionConfig.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; /** * Controls when a concurrent operation (map or parallel) completes. diff --git a/sdk/src/main/java/software/amazon/lambda/durable/InvokeConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/InvokeConfig.java similarity index 98% rename from sdk/src/main/java/software/amazon/lambda/durable/InvokeConfig.java rename to sdk/src/main/java/software/amazon/lambda/durable/config/InvokeConfig.java index 0f091e9c7..e9dc7af24 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/InvokeConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/InvokeConfig.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import software.amazon.lambda.durable.serde.SerDes; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/MapConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/MapConfig.java similarity index 98% rename from sdk/src/main/java/software/amazon/lambda/durable/MapConfig.java rename to sdk/src/main/java/software/amazon/lambda/durable/config/MapConfig.java index 6a68ea98e..7ca0e9b27 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/MapConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/MapConfig.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import software.amazon.lambda.durable.serde.SerDes; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/config/ParallelBranchConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/ParallelBranchConfig.java new file mode 100644 index 000000000..689f9aa54 --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/ParallelBranchConfig.java @@ -0,0 +1,69 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.config; + +import software.amazon.lambda.durable.serde.SerDes; + +/** + * Configuration options for parallel branch in durable executions. + * + *

    This class provides a builder pattern for configuring various aspects of parallel branch execution + */ +public class ParallelBranchConfig { + private final SerDes serDes; + + private ParallelBranchConfig(Builder builder) { + this.serDes = builder.serDes; + } + + /** Returns the custom serializer for this step, or null if not specified (uses default SerDes). */ + public SerDes serDes() { + return serDes; + } + + public Builder toBuilder() { + return new Builder(serDes); + } + + /** + * Creates a new builder for ParallelBranchConfig. + * + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(null); + } + + /** Builder for creating StepConfig instances. */ + public static class Builder { + private SerDes serDes; + + public Builder(SerDes serDes) { + this.serDes = serDes; + } + + /** + * Sets a custom serializer for the step. + * + *

    If not specified, the parallel branch will use the default SerDes configured for the handler. This allows + * per-branch customization of serialization behavior, useful for branches that need special handling (e.g., + * custom date formats, encryption, compression). + * + * @param serDes the custom serializer to use, or null to use the default + * @return this builder for method chaining + */ + public Builder serDes(SerDes serDes) { + this.serDes = serDes; + return this; + } + + /** + * Builds the ParallelBranchConfig instance. + * + * @return a new StepConfig with the configured options + */ + public ParallelBranchConfig build() { + return new ParallelBranchConfig(this); + } + } +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/ParallelConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/ParallelConfig.java similarity index 98% rename from sdk/src/main/java/software/amazon/lambda/durable/ParallelConfig.java rename to sdk/src/main/java/software/amazon/lambda/durable/config/ParallelConfig.java index a04a93bd4..3371be21b 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/ParallelConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/ParallelConfig.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; /** * Configuration options for parallel operations in durable executions. diff --git a/sdk/src/main/java/software/amazon/lambda/durable/config/RunInChildContextConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/RunInChildContextConfig.java new file mode 100644 index 000000000..7eb42cbaf --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/RunInChildContextConfig.java @@ -0,0 +1,72 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.config; + +import software.amazon.lambda.durable.serde.SerDes; + +/** + * Configuration options for RunInChildContext operations in durable executions. + * + *

    This class provides a builder pattern for configuring various aspects of RunInChildContext execution. + */ +public class RunInChildContextConfig { + private final SerDes serDes; + + private RunInChildContextConfig(Builder builder) { + this.serDes = builder.serDes; + } + + /** + * Returns the custom serializer for this RunInChildContext operation, or null if not specified (uses default + * SerDes). + */ + public SerDes serDes() { + return serDes; + } + + public Builder toBuilder() { + return new Builder(serDes); + } + + /** + * Creates a new builder for RunInChildContextConfig. + * + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(null); + } + + /** Builder for creating StepConfig instances. */ + public static class Builder { + private SerDes serDes; + + public Builder(SerDes serDes) { + this.serDes = serDes; + } + + /** + * Sets a custom serializer for the step. + * + *

    If not specified, the RunInChildContext operation will use the default SerDes configured for the handler. + * This allows per-operation customization of serialization behavior, useful for operations that need special + * handling (e.g., custom date formats, encryption, compression). + * + * @param serDes the custom serializer to use, or null to use the default + * @return this builder for method chaining + */ + public Builder serDes(SerDes serDes) { + this.serDes = serDes; + return this; + } + + /** + * Builds the RunInChildContextConfig instance. + * + * @return a new StepConfig with the configured options + */ + public RunInChildContextConfig build() { + return new RunInChildContextConfig(this); + } + } +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/config/StepConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/StepConfig.java new file mode 100644 index 000000000..8eada6faf --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/StepConfig.java @@ -0,0 +1,112 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.config; + +import software.amazon.lambda.durable.retry.RetryStrategies; +import software.amazon.lambda.durable.retry.RetryStrategy; +import software.amazon.lambda.durable.serde.SerDes; + +/** + * Configuration options for step operations in durable executions. + * + *

    This class provides a builder pattern for configuring various aspects of step execution, including retry behavior + * and delivery semantics. + */ +public class StepConfig { + private final RetryStrategy retryStrategy; + private final StepSemantics semantics; + private final SerDes serDes; + + private StepConfig(Builder builder) { + this.retryStrategy = builder.retryStrategy; + this.semantics = builder.semantics; + this.serDes = builder.serDes; + } + + /** Returns the retry strategy for this step, or the default strategy if not specified. */ + public RetryStrategy retryStrategy() { + return retryStrategy != null ? retryStrategy : RetryStrategies.Presets.DEFAULT; + } + + /** Returns the delivery semantics for this step, defaults to AT_LEAST_ONCE_PER_RETRY if not specified. */ + public StepSemantics semantics() { + return semantics != null ? semantics : StepSemantics.AT_LEAST_ONCE_PER_RETRY; + } + + /** Returns the custom serializer for this step, or null if not specified (uses default SerDes). */ + public SerDes serDes() { + return serDes; + } + + public Builder toBuilder() { + return new Builder(retryStrategy, semantics, serDes); + } + + /** + * Creates a new builder for StepConfig. + * + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(null, null, null); + } + + /** Builder for creating StepConfig instances. */ + public static class Builder { + private RetryStrategy retryStrategy; + private StepSemantics semantics; + private SerDes serDes; + + public Builder(RetryStrategy retryStrategy, StepSemantics semantics, SerDes serDes) { + this.retryStrategy = retryStrategy; + this.semantics = semantics; + this.serDes = serDes; + } + + /** + * Sets the retry strategy for the step. + * + * @param retryStrategy the retry strategy to use, or null for default behavior + * @return this builder for method chaining + */ + public Builder retryStrategy(RetryStrategy retryStrategy) { + this.retryStrategy = retryStrategy; + return this; + } + + /** + * Sets the delivery semantics for the step. + * + * @param semantics the delivery semantics to use, defaults to AT_LEAST_ONCE_PER_RETRY if not specified + * @return this builder for method chaining + */ + public Builder semantics(StepSemantics semantics) { + this.semantics = semantics; + return this; + } + + /** + * Sets a custom serializer for the step. + * + *

    If not specified, the step will use the default SerDes configured for the handler. This allows per-step + * customization of serialization behavior, useful for steps that need special handling (e.g., custom date + * formats, encryption, compression). + * + * @param serDes the custom serializer to use, or null to use the default + * @return this builder for method chaining + */ + public Builder serDes(SerDes serDes) { + this.serDes = serDes; + return this; + } + + /** + * Builds the StepConfig instance. + * + * @return a new StepConfig with the configured options + */ + public StepConfig build() { + return new StepConfig(this); + } + } +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/StepSemantics.java b/sdk/src/main/java/software/amazon/lambda/durable/config/StepSemantics.java similarity index 94% rename from sdk/src/main/java/software/amazon/lambda/durable/StepSemantics.java rename to sdk/src/main/java/software/amazon/lambda/durable/config/StepSemantics.java index 2028d30be..fc23ab3c7 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/StepSemantics.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/StepSemantics.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; /** * Delivery semantics for step operations. diff --git a/sdk/src/main/java/software/amazon/lambda/durable/config/WaitForCallbackConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/WaitForCallbackConfig.java new file mode 100644 index 000000000..e3bd57f44 --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/WaitForCallbackConfig.java @@ -0,0 +1,75 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.config; + +/** + * Configuration for the {@code waitForCallback} composite operation. + * + *

    Combines a {@link StepConfig} (for the step that produces the callback) and a {@link CallbackConfig} (for the + * callback wait itself). + */ +public class WaitForCallbackConfig { + private final StepConfig stepConfig; + private final CallbackConfig callbackConfig; + + private WaitForCallbackConfig(Builder builder) { + this.stepConfig = builder.stepConfig == null ? StepConfig.builder().build() : builder.stepConfig; + this.callbackConfig = + builder.callbackConfig == null ? CallbackConfig.builder().build() : builder.callbackConfig; + } + + /** Returns the step configuration for the composite operation. */ + public StepConfig stepConfig() { + return stepConfig; + } + + /** Returns the callback configuration for the composite operation. */ + public CallbackConfig callbackConfig() { + return callbackConfig; + } + + /** Creates a new builder. */ + public static Builder builder() { + return new Builder(); + } + + /** Creates a builder pre-populated with this instance's values. */ + public Builder toBuilder() { + return new Builder().stepConfig(this.stepConfig).callbackConfig(this.callbackConfig); + } + + /** Builder for {@link WaitForCallbackConfig}. */ + public static class Builder { + private StepConfig stepConfig; + private CallbackConfig callbackConfig; + + public Builder() {} + + /** + * Sets the step configuration for the composite operation. + * + * @param stepConfig the step configuration + * @return this builder for method chaining + */ + public Builder stepConfig(StepConfig stepConfig) { + this.stepConfig = stepConfig; + return this; + } + + /** + * Sets the callback configuration for the composite operation. + * + * @param callbackConfig the callback configuration + * @return this builder for method chaining + */ + public Builder callbackConfig(CallbackConfig callbackConfig) { + this.callbackConfig = callbackConfig; + return this; + } + + /** Builds the WaitForCallbackConfig instance. */ + public WaitForCallbackConfig build() { + return new WaitForCallbackConfig(this); + } + } +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/WaitForConditionConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/WaitForConditionConfig.java similarity index 97% rename from sdk/src/main/java/software/amazon/lambda/durable/WaitForConditionConfig.java rename to sdk/src/main/java/software/amazon/lambda/durable/config/WaitForConditionConfig.java index 9ec3eb042..6cc54651a 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/WaitForConditionConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/WaitForConditionConfig.java @@ -1,7 +1,8 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; +import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.retry.WaitForConditionWaitStrategy; import software.amazon.lambda.durable.retry.WaitStrategies; import software.amazon.lambda.durable.serde.SerDes; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java b/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java index 9190fa543..40ab21dda 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java @@ -13,20 +13,21 @@ import java.util.function.Supplier; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.lambda.model.OperationType; -import software.amazon.lambda.durable.CallbackConfig; import software.amazon.lambda.durable.DurableCallbackFuture; import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; -import software.amazon.lambda.durable.InvokeConfig; -import software.amazon.lambda.durable.MapConfig; -import software.amazon.lambda.durable.ParallelConfig; import software.amazon.lambda.durable.ParallelDurableFuture; -import software.amazon.lambda.durable.StepConfig; import software.amazon.lambda.durable.StepContext; import software.amazon.lambda.durable.TypeToken; -import software.amazon.lambda.durable.WaitForCallbackConfig; -import software.amazon.lambda.durable.WaitForConditionConfig; +import software.amazon.lambda.durable.config.CallbackConfig; +import software.amazon.lambda.durable.config.InvokeConfig; +import software.amazon.lambda.durable.config.MapConfig; +import software.amazon.lambda.durable.config.ParallelConfig; +import software.amazon.lambda.durable.config.RunInChildContextConfig; +import software.amazon.lambda.durable.config.StepConfig; +import software.amazon.lambda.durable.config.WaitForCallbackConfig; +import software.amazon.lambda.durable.config.WaitForConditionConfig; import software.amazon.lambda.durable.execution.ExecutionManager; import software.amazon.lambda.durable.execution.OperationIdGenerator; import software.amazon.lambda.durable.execution.ThreadType; @@ -163,14 +164,14 @@ public T step(String name, Class resultType, Function fun } @Override - public T step(String name, TypeToken typeToken, Function func) { - return step(name, typeToken, func, StepConfig.builder().build()); + public T step(String name, TypeToken resultType, Function func) { + return step(name, resultType, func, StepConfig.builder().build()); } @Override - public T step(String name, TypeToken typeToken, Function func, StepConfig config) { + public T step(String name, TypeToken resultType, Function func, StepConfig config) { // Simply delegate to stepAsync and block on the result - return stepAsync(name, typeToken, func, config).get(); + return stepAsync(name, resultType, func, config).get(); } @Override @@ -186,15 +187,15 @@ public DurableFuture stepAsync( } @Override - public DurableFuture stepAsync(String name, TypeToken typeToken, Function func) { - return stepAsync(name, typeToken, func, StepConfig.builder().build()); + public DurableFuture stepAsync(String name, TypeToken resultType, Function func) { + return stepAsync(name, resultType, func, StepConfig.builder().build()); } @Override public DurableFuture stepAsync( - String name, TypeToken typeToken, Function func, StepConfig config) { + String name, TypeToken resultType, Function func, StepConfig config) { Objects.requireNonNull(config, "config cannot be null"); - Objects.requireNonNull(typeToken, "typeToken cannot be null"); + Objects.requireNonNull(resultType, "resultType cannot be null"); ParameterValidator.validateOperationName(name); if (config.serDes() == null) { @@ -204,7 +205,7 @@ public DurableFuture stepAsync( // Create and start step operation with TypeToken var operation = new StepOperation<>( - OperationIdentifier.of(operationId, name, OperationType.STEP), func, typeToken, config, this); + OperationIdentifier.of(operationId, name, OperationType.STEP), func, resultType, config, this); operation.execute(); // Start the step (returns immediately) @@ -234,16 +235,16 @@ public T step(String name, Class resultType, Supplier func, StepConfig /** @deprecated use the variants accepting StepContext instead */ @Deprecated @Override - public T step(String name, TypeToken typeToken, Supplier func) { - return stepAsync(name, typeToken, func, StepConfig.builder().build()).get(); + public T step(String name, TypeToken resultType, Supplier func) { + return stepAsync(name, resultType, func, StepConfig.builder().build()).get(); } /** @deprecated use the variants accepting StepContext instead */ @Deprecated @Override - public T step(String name, TypeToken typeToken, Supplier func, StepConfig config) { + public T step(String name, TypeToken resultType, Supplier func, StepConfig config) { // Simply delegate to stepAsync and block on the result - return stepAsync(name, typeToken, func, config).get(); + return stepAsync(name, resultType, func, config).get(); } /** @deprecated use the variants accepting StepContext instead */ @@ -264,15 +265,15 @@ public DurableFuture stepAsync(String name, Class resultType, Supplier /** @deprecated use the variants accepting StepContext instead */ @Deprecated @Override - public DurableFuture stepAsync(String name, TypeToken typeToken, Supplier func) { - return stepAsync(name, typeToken, func, StepConfig.builder().build()); + public DurableFuture stepAsync(String name, TypeToken resultType, Supplier func) { + return stepAsync(name, resultType, func, StepConfig.builder().build()); } /** @deprecated use the variants accepting StepContext instead */ @Deprecated @Override - public DurableFuture stepAsync(String name, TypeToken typeToken, Supplier func, StepConfig config) { - return stepAsync(name, typeToken, stepContext -> func.get(), config); + public DurableFuture stepAsync(String name, TypeToken resultType, Supplier func, StepConfig config) { + return stepAsync(name, resultType, stepContext -> func.get(), config); } // ========== wait methods ========== @@ -318,19 +319,19 @@ public T invoke(String name, String functionName, U payload, Class res } @Override - public T invoke(String name, String functionName, U payload, TypeToken typeToken) { + public T invoke(String name, String functionName, U payload, TypeToken resultType) { return invokeAsync( name, functionName, payload, - typeToken, + resultType, InvokeConfig.builder().build()) .get(); } @Override - public T invoke(String name, String functionName, U payload, TypeToken typeToken, InvokeConfig config) { - return invokeAsync(name, functionName, payload, typeToken, config).get(); + public T invoke(String name, String functionName, U payload, TypeToken resultType, InvokeConfig config) { + return invokeAsync(name, functionName, payload, resultType, config).get(); } /** Asynchronously invokes another Lambda function with custom configuration. */ @@ -358,9 +359,9 @@ public DurableFuture invokeAsync(String name, String functionName, U p @Override public DurableFuture invokeAsync( - String name, String functionName, U payload, TypeToken typeToken, InvokeConfig config) { + String name, String functionName, U payload, TypeToken resultType, InvokeConfig config) { Objects.requireNonNull(config, "config cannot be null"); - Objects.requireNonNull(typeToken, "typeToken cannot be null"); + Objects.requireNonNull(resultType, "resultType cannot be null"); ParameterValidator.validateOperationName(name); if (config.serDes() == null) { @@ -378,7 +379,7 @@ public DurableFuture invokeAsync( OperationIdentifier.of(operationId, name, OperationType.CHAINED_INVOKE), functionName, payload, - typeToken, + resultType, config, this); @@ -394,8 +395,8 @@ public DurableCallbackFuture createCallback(String name, Class resultT } @Override - public DurableCallbackFuture createCallback(String name, TypeToken typeToken) { - return createCallback(name, typeToken, CallbackConfig.builder().build()); + public DurableCallbackFuture createCallback(String name, TypeToken resultType) { + return createCallback(name, resultType, CallbackConfig.builder().build()); } @Override @@ -405,7 +406,7 @@ public DurableCallbackFuture createCallback(String name, Class resultT } @Override - public DurableCallbackFuture createCallback(String name, TypeToken typeToken, CallbackConfig config) { + public DurableCallbackFuture createCallback(String name, TypeToken resultType, CallbackConfig config) { ParameterValidator.validateOperationName(name); if (config.serDes() == null) { config = config.toBuilder().serDes(getDurableConfig().getSerDes()).build(); @@ -413,7 +414,7 @@ public DurableCallbackFuture createCallback(String name, TypeToken typ var operationId = nextOperationId(); var operation = new CallbackOperation<>( - OperationIdentifier.of(operationId, name, OperationType.CALLBACK), typeToken, config, this); + OperationIdentifier.of(operationId, name, OperationType.CALLBACK), resultType, config, this); operation.execute(); return operation; @@ -421,39 +422,185 @@ public DurableCallbackFuture createCallback(String name, TypeToken typ // ========== runInChildContext methods ========== + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @return the child context result + */ @Override public T runInChildContext(String name, Class resultType, Function func) { - return runInChildContextAsync(name, TypeToken.get(resultType), func).get(); + return runInChildContextAsync( + name, + TypeToken.get(resultType), + func, + RunInChildContextConfig.builder().build()) + .get(); } + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @return the child context result + */ @Override - public T runInChildContext(String name, TypeToken typeToken, Function func) { - return runInChildContextAsync(name, typeToken, func).get(); + public T runInChildContext(String name, TypeToken resultType, Function func) { + return runInChildContextAsync( + name, + resultType, + func, + RunInChildContextConfig.builder().build()) + .get(); } + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @return the DurableFuture of the child context result + */ @Override public DurableFuture runInChildContextAsync( String name, Class resultType, Function func) { - return runInChildContextAsync(name, TypeToken.get(resultType), func); + return runInChildContextAsync( + name, + TypeToken.get(resultType), + func, + RunInChildContextConfig.builder().build()); + } + + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @return the DurableFuture of the child context result + */ + @Override + public DurableFuture runInChildContextAsync( + String name, TypeToken resultType, Function func) { + return runInChildContextAsync( + name, + resultType, + func, + RunInChildContextConfig.builder().build(), + OperationSubType.RUN_IN_CHILD_CONTEXT); } + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @param config the configuration for the child context + * @return the child context result + */ + @Override + public T runInChildContext( + String name, Class resultType, Function func, RunInChildContextConfig config) { + return runInChildContextAsync(name, TypeToken.get(resultType), func, config) + .get(); + } + + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @param config the configuration for the child context + * @return the child context result + */ + @Override + public T runInChildContext( + String name, TypeToken resultType, Function func, RunInChildContextConfig config) { + return runInChildContextAsync(name, resultType, func, config).get(); + } + + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @param config the configuration for the child context + * @return the DurableFuture wrapping the child context result + */ + @Override + public DurableFuture runInChildContextAsync( + String name, Class resultType, Function func, RunInChildContextConfig config) { + return runInChildContextAsync(name, TypeToken.get(resultType), func, config); + } + + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @param config the configuration for the child context + * @return the DurableFuture wrapping the child context result + */ @Override public DurableFuture runInChildContextAsync( - String name, TypeToken typeToken, Function func) { - return runInChildContextAsync(name, typeToken, func, OperationSubType.RUN_IN_CHILD_CONTEXT); + String name, TypeToken resultType, Function func, RunInChildContextConfig config) { + return runInChildContextAsync(name, resultType, func, config, OperationSubType.RUN_IN_CHILD_CONTEXT); } private DurableFuture runInChildContextAsync( - String name, TypeToken typeToken, Function func, OperationSubType subType) { - Objects.requireNonNull(typeToken, "typeToken cannot be null"); + String name, + TypeToken resultType, + Function func, + RunInChildContextConfig config, + OperationSubType subType) { + Objects.requireNonNull(resultType, "resultType cannot be null"); + Objects.requireNonNull(config, "RunInChildContextConfig cannot be null"); ParameterValidator.validateOperationName(name); + + if (config.serDes() == null) { + config = config.toBuilder().serDes(getDurableConfig().getSerDes()).build(); + } + var operationId = nextOperationId(); var operation = new ChildContextOperation<>( OperationIdentifier.of(operationId, name, OperationType.CONTEXT, subType), func, - typeToken, - getDurableConfig().getSerDes(), + resultType, + config, this); operation.execute(); @@ -581,9 +728,9 @@ public T waitForCallback(String name, Class resultType, BiConsumer T waitForCallback(String name, TypeToken typeToken, BiConsumer func) { + public T waitForCallback(String name, TypeToken resultType, BiConsumer func) { return waitForCallbackAsync( - name, typeToken, func, WaitForCallbackConfig.builder().build()) + name, resultType, func, WaitForCallbackConfig.builder().build()) .get(); } @@ -600,10 +747,10 @@ public T waitForCallback( @Override public T waitForCallback( String name, - TypeToken typeToken, + TypeToken resultType, BiConsumer func, WaitForCallbackConfig waitForCallbackConfig) { - return waitForCallbackAsync(name, typeToken, func, waitForCallbackConfig) + return waitForCallbackAsync(name, resultType, func, waitForCallbackConfig) .get(); } @@ -619,9 +766,9 @@ public DurableFuture waitForCallbackAsync( @Override public DurableFuture waitForCallbackAsync( - String name, TypeToken typeToken, BiConsumer func) { + String name, TypeToken resultType, BiConsumer func) { return waitForCallbackAsync( - name, typeToken, func, WaitForCallbackConfig.builder().build()); + name, resultType, func, WaitForCallbackConfig.builder().build()); } @Override @@ -636,10 +783,10 @@ public DurableFuture waitForCallbackAsync( @Override public DurableFuture waitForCallbackAsync( String name, - TypeToken typeToken, + TypeToken resultType, BiConsumer func, WaitForCallbackConfig waitForCallbackConfig) { - Objects.requireNonNull(typeToken, "typeToken cannot be null"); + Objects.requireNonNull(resultType, "resultType cannot be null"); Objects.requireNonNull(waitForCallbackConfig, "waitForCallbackConfig cannot be null"); // waitForCallback adds a suffix for the callback operation name and the submitter operation name so // the length restriction of waitForCallback name is different from the other operations. @@ -655,11 +802,11 @@ public DurableFuture waitForCallbackAsync( return runInChildContextAsync( name, - typeToken, + resultType, childCtx -> { var callback = childCtx.createCallback( name + WAIT_FOR_CALLBACK_CALLBACK_SUFFIX, - typeToken, + resultType, finalWaitForCallbackConfig.callbackConfig()); childCtx.step( name + WAIT_FOR_CALLBACK_SUBMITTER_SUFFIX, @@ -671,6 +818,9 @@ public DurableFuture waitForCallbackAsync( finalWaitForCallbackConfig.stepConfig()); return callback.get(); }, + RunInChildContextConfig.builder() + .serDes(finalWaitForCallbackConfig.stepConfig().serDes()) + .build(), OperationSubType.WAIT_FOR_CALLBACK); } @@ -704,12 +854,12 @@ public T waitForCondition( @Override public T waitForCondition( String name, - TypeToken typeToken, + TypeToken resultType, BiFunction> checkFunc, T initialState) { return waitForConditionAsync( name, - typeToken, + resultType, checkFunc, initialState, WaitForConditionConfig.builder().build()) @@ -719,11 +869,11 @@ public T waitForCondition( @Override public T waitForCondition( String name, - TypeToken typeToken, + TypeToken resultType, BiFunction> checkFunc, T initialState, WaitForConditionConfig config) { - return waitForConditionAsync(name, typeToken, checkFunc, initialState, config) + return waitForConditionAsync(name, resultType, checkFunc, initialState, config) .get(); } @@ -754,12 +904,12 @@ public DurableFuture waitForConditionAsync( @Override public DurableFuture waitForConditionAsync( String name, - TypeToken typeToken, + TypeToken resultType, BiFunction> checkFunc, T initialState) { return waitForConditionAsync( name, - typeToken, + resultType, checkFunc, initialState, WaitForConditionConfig.builder().build()); @@ -768,12 +918,12 @@ public DurableFuture waitForConditionAsync( @Override public DurableFuture waitForConditionAsync( String name, - TypeToken typeToken, + TypeToken resultType, BiFunction> checkFunc, T initialState, WaitForConditionConfig config) { Objects.requireNonNull(config, "config cannot be null"); - Objects.requireNonNull(typeToken, "typeToken cannot be null"); + Objects.requireNonNull(resultType, "resultType cannot be null"); Objects.requireNonNull(checkFunc, "checkFunc cannot be null"); Objects.requireNonNull(initialState, "initialState cannot be null"); ParameterValidator.validateOperationName(name); @@ -784,7 +934,7 @@ public DurableFuture waitForConditionAsync( var operationId = nextOperationId(); var operation = - new WaitForConditionOperation<>(operationId, name, checkFunc, typeToken, initialState, config, this); + new WaitForConditionOperation<>(operationId, name, checkFunc, resultType, initialState, config, this); operation.execute(); diff --git a/sdk/src/main/java/software/amazon/lambda/durable/model/MapError.java b/sdk/src/main/java/software/amazon/lambda/durable/model/MapError.java deleted file mode 100644 index 4a49a8516..000000000 --- a/sdk/src/main/java/software/amazon/lambda/durable/model/MapError.java +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable.model; - -import java.util.List; -import software.amazon.lambda.durable.util.ExceptionHelper; - -/** - * Error details for a failed map item. - * - *

    Stores error information as plain strings so that {@link MapResult} can be serialized through the user's SerDes - * without requiring AWS SDK-specific Jackson modules. - * - * @param errorType the fully qualified exception class name - * @param errorMessage the error message - * @param stackTrace the stack trace frames, or null - */ -public record MapError(String errorType, String errorMessage, List stackTrace) { - public static MapError of(Throwable e) { - return new MapError( - e.getClass().getName(), e.getMessage(), ExceptionHelper.serializeStackTrace(e.getStackTrace())); - } -} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/model/MapResult.java b/sdk/src/main/java/software/amazon/lambda/durable/model/MapResult.java index dce28a3d3..9fbbbcb09 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/model/MapResult.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/model/MapResult.java @@ -4,6 +4,7 @@ import java.util.Collections; import java.util.List; +import software.amazon.lambda.durable.util.ExceptionHelper; /** * Result container for map operations. @@ -77,4 +78,60 @@ public List failed() { .map(MapResultItem::error) .toList(); } + + /** + * Represents the outcome of a single item in a map operation. + * + *

    Each item either succeeds with a result, fails with an error, or was never started. The status field indicates + * which case applies. + * + *

    Errors are stored as {@link MapError} (plain strings) rather than raw Throwable, so they survive serialization + * across checkpoint-and-replay cycles without requiring AWS SDK-specific Jackson modules. + * + * @param status the status of this item + * @param result the result value, or null if failed/not started + * @param error the error details, or null if succeeded/not started + * @param the result type + */ + public record MapResultItem(Status status, T result, MapError error) { + + /** Status of an individual map item. */ + public enum Status { + SUCCEEDED, + FAILED, + SKIPPED + } + + /** Creates a successful result item. */ + public static MapResultItem succeeded(T result) { + return new MapResultItem<>(Status.SUCCEEDED, result, null); + } + + /** Creates a failed result item. */ + public static MapResultItem failed(MapError error) { + return new MapResultItem<>(Status.FAILED, null, error); + } + + /** Creates a skipped result item. */ + public static MapResultItem skipped() { + return new MapResultItem<>(Status.SKIPPED, null, null); + } + } + + /** + * Error details for a failed map item. + * + *

    Stores error information as plain strings so that {@link MapResult} can be serialized through the user's + * SerDes without requiring AWS SDK-specific Jackson modules. + * + * @param errorType the fully qualified exception class name + * @param errorMessage the error message + * @param stackTrace the stack trace frames, or null + */ + public record MapError(String errorType, String errorMessage, List stackTrace) { + public static MapError of(Throwable e) { + return new MapError( + e.getClass().getName(), e.getMessage(), ExceptionHelper.serializeStackTrace(e.getStackTrace())); + } + } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/model/MapResultItem.java b/sdk/src/main/java/software/amazon/lambda/durable/model/MapResultItem.java deleted file mode 100644 index 414c40124..000000000 --- a/sdk/src/main/java/software/amazon/lambda/durable/model/MapResultItem.java +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable.model; - -/** - * Represents the outcome of a single item in a map operation. - * - *

    Each item either succeeds with a result, fails with an error, or was never started. The status field indicates - * which case applies. - * - *

    Errors are stored as {@link MapError} (plain strings) rather than raw Throwable, so they survive serialization - * across checkpoint-and-replay cycles without requiring AWS SDK-specific Jackson modules. - * - * @param status the status of this item - * @param result the result value, or null if failed/not started - * @param error the error details, or null if succeeded/not started - * @param the result type - */ -public record MapResultItem(Status status, T result, MapError error) { - - /** Status of an individual map item. */ - public enum Status { - SUCCEEDED, - FAILED, - SKIPPED - } - - /** Creates a successful result item. */ - public static MapResultItem succeeded(T result) { - return new MapResultItem<>(Status.SUCCEEDED, result, null); - } - - /** Creates a failed result item. */ - public static MapResultItem failed(MapError error) { - return new MapResultItem<>(Status.FAILED, null, error); - } - - /** Creates a skipped result item. */ - public static MapResultItem skipped() { - return new MapResultItem<>(Status.SKIPPED, null, null); - } -} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java index 9e1f63fc2..a9b8d06a0 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java @@ -6,9 +6,9 @@ import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationAction; import software.amazon.awssdk.services.lambda.model.OperationUpdate; -import software.amazon.lambda.durable.CallbackConfig; import software.amazon.lambda.durable.DurableCallbackFuture; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.CallbackConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.CallbackFailedException; import software.amazon.lambda.durable.exception.CallbackTimeoutException; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java index d12690019..b25173598 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java @@ -17,6 +17,7 @@ import software.amazon.awssdk.services.lambda.model.OperationUpdate; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.RunInChildContextConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.CallbackFailedException; import software.amazon.lambda.durable.exception.CallbackSubmitterException; @@ -28,7 +29,6 @@ import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException; import software.amazon.lambda.durable.execution.SuspendExecutionException; import software.amazon.lambda.durable.model.OperationIdentifier; -import software.amazon.lambda.durable.serde.SerDes; import software.amazon.lambda.durable.util.ExceptionHelper; /** @@ -55,19 +55,19 @@ public ChildContextOperation( OperationIdentifier operationIdentifier, Function function, TypeToken resultTypeToken, - SerDes resultSerDes, + RunInChildContextConfig config, DurableContextImpl durableContext) { - this(operationIdentifier, function, resultTypeToken, resultSerDes, durableContext, null); + this(operationIdentifier, function, resultTypeToken, config, durableContext, null); } public ChildContextOperation( OperationIdentifier operationIdentifier, Function function, TypeToken resultTypeToken, - SerDes resultSerDes, + RunInChildContextConfig config, DurableContextImpl durableContext, ConcurrencyOperation parentOperation) { - super(operationIdentifier, resultTypeToken, resultSerDes, durableContext); + super(operationIdentifier, resultTypeToken, config.serDes(), durableContext); this.function = function; this.userExecutor = getContext().getDurableConfig().getExecutorService(); this.parentOperation = parentOperation; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/InvokeOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/InvokeOperation.java index be94fd0b0..05bedec6f 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/InvokeOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/InvokeOperation.java @@ -6,8 +6,8 @@ import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationAction; import software.amazon.awssdk.services.lambda.model.OperationUpdate; -import software.amazon.lambda.durable.InvokeConfig; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.InvokeConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.InvokeException; import software.amazon.lambda.durable.exception.InvokeFailedException; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java index 4ede1d325..146e9695c 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java @@ -11,15 +11,14 @@ import software.amazon.awssdk.services.lambda.model.OperationAction; import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.awssdk.services.lambda.model.OperationUpdate; -import software.amazon.lambda.durable.CompletionConfig; import software.amazon.lambda.durable.DurableContext; -import software.amazon.lambda.durable.MapConfig; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.CompletionConfig; +import software.amazon.lambda.durable.config.MapConfig; +import software.amazon.lambda.durable.config.RunInChildContextConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; -import software.amazon.lambda.durable.model.MapError; import software.amazon.lambda.durable.model.MapResult; -import software.amazon.lambda.durable.model.MapResultItem; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.model.OperationSubType; import software.amazon.lambda.durable.serde.SerDes; @@ -99,7 +98,7 @@ protected ChildContextOperation createItem( OperationIdentifier.of(operationId, name, OperationType.CONTEXT, OperationSubType.MAP_ITERATION), function, resultType, - serDes, + RunInChildContextConfig.builder().serDes(serDes).build(), parentContext, this); } @@ -200,24 +199,24 @@ public MapResult get() { @SuppressWarnings("unchecked") private MapResult aggregateResults() { var children = getBranches(); - var resultItems = new ArrayList>(Collections.nCopies(items.size(), null)); + var resultItems = new ArrayList>(Collections.nCopies(items.size(), null)); for (int i = 0; i < children.size(); i++) { var branch = (ChildContextOperation) children.get(i); if (!branch.isOperationCompleted()) { - resultItems.set(i, MapResultItem.skipped()); + resultItems.set(i, MapResult.MapResultItem.skipped()); continue; } try { - resultItems.set(i, MapResultItem.succeeded(branch.get())); + resultItems.set(i, MapResult.MapResultItem.succeeded(branch.get())); } catch (Exception e) { - resultItems.set(i, MapResultItem.failed(MapError.of(e))); + resultItems.set(i, MapResult.MapResultItem.failed(MapResult.MapError.of(e))); } } // Fill any remaining null slots (items beyond children size) with skipped for (int i = children.size(); i < items.size(); i++) { - resultItems.set(i, MapResultItem.skipped()); + resultItems.set(i, MapResult.MapResultItem.skipped()); } return new MapResult<>(resultItems, completionStatus); diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java index 7023cb6f8..48880138c 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java @@ -10,9 +10,11 @@ import software.amazon.awssdk.services.lambda.model.OperationUpdate; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; -import software.amazon.lambda.durable.ParallelConfig; import software.amazon.lambda.durable.ParallelDurableFuture; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.ParallelBranchConfig; +import software.amazon.lambda.durable.config.ParallelConfig; +import software.amazon.lambda.durable.config.RunInChildContextConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.ExecutionManager; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; @@ -74,7 +76,7 @@ protected ChildContextOperation createItem( OperationIdentifier.of(operationId, name, OperationType.CONTEXT, OperationSubType.PARALLEL_BRANCH), function, resultType, - serDes, + RunInChildContextConfig.builder().serDes(serDes).build(), parentContext, this); } @@ -112,42 +114,18 @@ public ParallelResult get() { return new ParallelResult(getTotalItems(), getSucceededCount(), getFailedCount(), getCompletionStatus()); } - /** - * Calls {@link #join()} if not already called. Guarantees that all branches complete before the context is closed. - */ + /** Calls {@link #get()} if not already called. Guarantees that the context is closed. */ @Override public void close() { join(); } - /** - * Registers and immediately starts a branch (respects maxConcurrency). - * - * @param name the branch name - * @param resultType the result type class - * @param func the function to execute in the branch's child context - * @param the result type - * @return a {@link DurableFuture} that will contain the branch result - * @throws IllegalStateException if called after {@link #join()} - */ - public DurableFuture branch(String name, Class resultType, Function func) { - return branch(name, TypeToken.get(resultType), func); - } - - /** - * Registers and immediately starts a branch (respects maxConcurrency). - * - * @param name the branch name - * @param resultType the result type token for generic types - * @param func the function to execute in the branch's child context - * @param the result type - * @return a {@link DurableFuture} that will contain the branch result - * @throws IllegalStateException if called after {@link #join()} - */ - public DurableFuture branch(String name, TypeToken resultType, Function func) { + public DurableFuture branch( + String name, TypeToken resultType, Function func, ParallelBranchConfig config) { if (isJoined.get()) { throw new IllegalStateException("Cannot add branches after join() has been called"); } - return addItem(name, func, resultType, getContext().getDurableConfig().getSerDes()); + var serDes = config.serDes() == null ? getContext().getDurableConfig().getSerDes() : config.serDes(); + return addItem(name, func, resultType, serDes); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/StepOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/StepOperation.java index 4570f8637..901766a32 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/StepOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/StepOperation.java @@ -13,10 +13,10 @@ import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationUpdate; import software.amazon.awssdk.services.lambda.model.StepOptions; -import software.amazon.lambda.durable.StepConfig; import software.amazon.lambda.durable.StepContext; -import software.amazon.lambda.durable.StepSemantics; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.StepConfig; +import software.amazon.lambda.durable.config.StepSemantics; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.DurableOperationException; import software.amazon.lambda.durable.exception.StepFailedException; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java index 700fe5f19..2645ac2e6 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java @@ -14,7 +14,7 @@ import software.amazon.awssdk.services.lambda.model.StepOptions; import software.amazon.lambda.durable.StepContext; import software.amazon.lambda.durable.TypeToken; -import software.amazon.lambda.durable.WaitForConditionConfig; +import software.amazon.lambda.durable.config.WaitForConditionConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.DurableOperationException; import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/DurableContextTest.java b/sdk/src/test/java/software/amazon/lambda/durable/DurableContextTest.java index 269eacf98..1d50cc1c8 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/DurableContextTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/DurableContextTest.java @@ -9,6 +9,7 @@ import java.util.List; import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.lambda.model.*; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.ExecutionManager; import software.amazon.lambda.durable.execution.SuspendExecutionException; @@ -47,7 +48,9 @@ private DurableContext createTestContext(List initialOperations) { new DurableExecutionInput(EXECUTION_ARN, "test-token", initialExecutionState), DurableConfig.builder().withDurableExecutionClient(client).build()); var root = DurableContextImpl.createRootContext( - executionManager, DurableConfig.builder().build(), null); + executionManager, + software.amazon.lambda.durable.DurableConfig.builder().build(), + null); executionManager.registerActiveThread(null); executionManager.setCurrentThreadContext(new ThreadContext(null, ThreadType.CONTEXT)); return root; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/DurationValidationIntegrationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/DurationValidationIntegrationTest.java index 55a72a607..2deda7107 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/DurationValidationIntegrationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/DurationValidationIntegrationTest.java @@ -6,6 +6,7 @@ import java.time.Duration; import org.junit.jupiter.api.Test; +import software.amazon.lambda.durable.config.CallbackConfig; class DurationValidationIntegrationTest { diff --git a/sdk/src/test/java/software/amazon/lambda/durable/TypeTokenTest.java b/sdk/src/test/java/software/amazon/lambda/durable/TypeTokenTest.java index fe5ebe98f..c3029ee6a 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/TypeTokenTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/TypeTokenTest.java @@ -17,7 +17,7 @@ void testSimpleGenericType() { var token = new TypeToken>() {}; Type type = token.getType(); - assertTrue(type instanceof ParameterizedType); + assertInstanceOf(ParameterizedType.class, type); ParameterizedType paramType = (ParameterizedType) type; assertEquals(List.class, paramType.getRawType()); assertEquals(String.class, paramType.getActualTypeArguments()[0]); @@ -28,13 +28,13 @@ void testNestedGenericType() { var token = new TypeToken>>() {}; Type type = token.getType(); - assertTrue(type instanceof ParameterizedType); + assertInstanceOf(ParameterizedType.class, type); ParameterizedType paramType = (ParameterizedType) type; assertEquals(Map.class, paramType.getRawType()); assertEquals(String.class, paramType.getActualTypeArguments()[0]); Type valueType = paramType.getActualTypeArguments()[1]; - assertTrue(valueType instanceof ParameterizedType); + assertInstanceOf(ParameterizedType.class, valueType); ParameterizedType valueParamType = (ParameterizedType) valueType; assertEquals(List.class, valueParamType.getRawType()); assertEquals(Integer.class, valueParamType.getActualTypeArguments()[0]); diff --git a/sdk/src/test/java/software/amazon/lambda/durable/CallbackConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/config/CallbackConfigTest.java similarity index 97% rename from sdk/src/test/java/software/amazon/lambda/durable/CallbackConfigTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/config/CallbackConfigTest.java index 5f5853f1a..45798109d 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/CallbackConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/config/CallbackConfigTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/CompletionConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/config/CompletionConfigTest.java similarity index 98% rename from sdk/src/test/java/software/amazon/lambda/durable/CompletionConfigTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/config/CompletionConfigTest.java index 710bf9b1f..b9dbe7ecc 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/CompletionConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/config/CompletionConfigTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import static org.junit.jupiter.api.Assertions.*; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/MapConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/config/MapConfigTest.java similarity index 98% rename from sdk/src/test/java/software/amazon/lambda/durable/MapConfigTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/config/MapConfigTest.java index fd856fed4..fc8962de3 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/MapConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/config/MapConfigTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import static org.junit.jupiter.api.Assertions.*; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/ParallelConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/config/ParallelConfigTest.java similarity index 96% rename from sdk/src/test/java/software/amazon/lambda/durable/ParallelConfigTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/config/ParallelConfigTest.java index 819ce5bc5..923a9d221 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/ParallelConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/config/ParallelConfigTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/StepConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/config/StepConfigTest.java similarity index 98% rename from sdk/src/test/java/software/amazon/lambda/durable/StepConfigTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/config/StepConfigTest.java index ec9437c03..033e0b4d5 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/StepConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/config/StepConfigTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/WaitForCallbackConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/config/WaitForCallbackConfigTest.java similarity index 99% rename from sdk/src/test/java/software/amazon/lambda/durable/WaitForCallbackConfigTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/config/WaitForCallbackConfigTest.java index 2ca85e143..6f12f5cd8 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/WaitForCallbackConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/config/WaitForCallbackConfigTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/WaitForConditionConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/config/WaitForConditionConfigTest.java similarity index 98% rename from sdk/src/test/java/software/amazon/lambda/durable/WaitForConditionConfigTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/config/WaitForConditionConfigTest.java index 82a431874..eacb1f704 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/WaitForConditionConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/config/WaitForConditionConfigTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/DurableExecutionTest.java b/sdk/src/test/java/software/amazon/lambda/durable/execution/DurableExecutionTest.java similarity index 98% rename from sdk/src/test/java/software/amazon/lambda/durable/DurableExecutionTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/execution/DurableExecutionTest.java index 3b966cdbd..a4ad86d3b 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/DurableExecutionTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/execution/DurableExecutionTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.execution; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -19,7 +19,8 @@ import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.awssdk.services.lambda.model.StepDetails; -import software.amazon.lambda.durable.execution.DurableExecutor; +import software.amazon.lambda.durable.DurableConfig; +import software.amazon.lambda.durable.TestUtils; import software.amazon.lambda.durable.model.DurableExecutionInput; import software.amazon.lambda.durable.model.ExecutionStatus; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/DurableExecutionWrapperTest.java b/sdk/src/test/java/software/amazon/lambda/durable/execution/DurableExecutionWrapperTest.java similarity index 96% rename from sdk/src/test/java/software/amazon/lambda/durable/DurableExecutionWrapperTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/execution/DurableExecutionWrapperTest.java index ac9bfdd7a..a64a19232 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/DurableExecutionWrapperTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/execution/DurableExecutionWrapperTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.execution; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -14,8 +14,10 @@ import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; +import software.amazon.lambda.durable.DurableConfig; +import software.amazon.lambda.durable.DurableContext; +import software.amazon.lambda.durable.TestUtils; import software.amazon.lambda.durable.client.DurableExecutionClient; -import software.amazon.lambda.durable.execution.DurableExecutor; import software.amazon.lambda.durable.model.DurableExecutionInput; import software.amazon.lambda.durable.model.DurableExecutionOutput; import software.amazon.lambda.durable.model.ExecutionStatus; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/model/MapResultTest.java b/sdk/src/test/java/software/amazon/lambda/durable/model/MapResultTest.java index f144a18b6..ea8e49665 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/model/MapResultTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/model/MapResultTest.java @@ -9,8 +9,8 @@ class MapResultTest { - private static MapError testError(String message) { - return new MapError("java.lang.RuntimeException", message, null); + private static MapResult.MapError testError(String message) { + return new MapResult.MapError("java.lang.RuntimeException", message, null); } @Test @@ -28,7 +28,7 @@ void empty_returnsZeroSizeResult() { @Test void allSucceeded_trueWhenNoErrors() { var result = new MapResult<>( - List.of(MapResultItem.succeeded("a"), MapResultItem.succeeded("b")), + List.of(MapResult.MapResultItem.succeeded("a"), MapResult.MapResultItem.succeeded("b")), ConcurrencyCompletionStatus.ALL_COMPLETED); assertTrue(result.allSucceeded()); @@ -43,7 +43,7 @@ void allSucceeded_trueWhenNoErrors() { void allSucceeded_falseWhenAnyError() { var error = testError("fail"); var result = new MapResult<>( - List.of(MapResultItem.succeeded("a"), MapResultItem.failed(error)), + List.of(MapResult.MapResultItem.succeeded("a"), MapResult.MapResultItem.failed(error)), ConcurrencyCompletionStatus.ALL_COMPLETED); assertFalse(result.allSucceeded()); @@ -53,7 +53,7 @@ void allSucceeded_falseWhenAnyError() { void getResult_returnsNullForFailedItem() { var error = testError("fail"); var result = new MapResult<>( - List.of(MapResultItem.succeeded("a"), MapResultItem.failed(error)), + List.of(MapResult.MapResultItem.succeeded("a"), MapResult.MapResultItem.failed(error)), ConcurrencyCompletionStatus.ALL_COMPLETED); assertEquals("a", result.getResult(0)); @@ -64,7 +64,7 @@ void getResult_returnsNullForFailedItem() { void getError_returnsNullForSucceededItem() { var error = testError("fail"); var result = new MapResult<>( - List.of(MapResultItem.succeeded("a"), MapResultItem.failed(error)), + List.of(MapResult.MapResultItem.succeeded("a"), MapResult.MapResultItem.failed(error)), ConcurrencyCompletionStatus.ALL_COMPLETED); assertNull(result.getError(0)); @@ -75,9 +75,9 @@ void getError_returnsNullForSucceededItem() { void succeeded_filtersNullResults() { var result = new MapResult<>( List.of( - MapResultItem.succeeded("a"), - MapResultItem.failed(testError("fail")), - MapResultItem.succeeded("c")), + MapResult.MapResultItem.succeeded("a"), + MapResult.MapResultItem.failed(testError("fail")), + MapResult.MapResultItem.succeeded("c")), ConcurrencyCompletionStatus.ALL_COMPLETED); assertEquals(List.of("a", "c"), result.succeeded()); @@ -88,9 +88,9 @@ void failed_filtersNullErrors() { var error = testError("fail"); var result = new MapResult<>( List.of( - MapResultItem.succeeded("a"), - MapResultItem.failed(error), - MapResultItem.succeeded("c")), + MapResult.MapResultItem.succeeded("a"), + MapResult.MapResultItem.failed(error), + MapResult.MapResultItem.succeeded("c")), ConcurrencyCompletionStatus.ALL_COMPLETED); var failures = result.failed(); @@ -101,29 +101,33 @@ void failed_filtersNullErrors() { @Test void completionReason_preserved() { var result = new MapResult<>( - List.of(MapResultItem.succeeded("a")), ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED); + List.of(MapResult.MapResultItem.succeeded("a")), ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED); assertEquals(ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED, result.completionReason()); } @Test void items_returnsUnmodifiableList() { - var result = new MapResult<>(List.of(MapResultItem.succeeded("a")), ConcurrencyCompletionStatus.ALL_COMPLETED); + var result = new MapResult<>( + List.of(MapResult.MapResultItem.succeeded("a")), ConcurrencyCompletionStatus.ALL_COMPLETED); - assertThrows(UnsupportedOperationException.class, () -> result.items().add(MapResultItem.succeeded("b"))); + assertThrows( + UnsupportedOperationException.class, () -> result.items().add(MapResult.MapResultItem.succeeded("b"))); } @Test void getItem_returnsMapResultItem() { var result = new MapResult<>( - List.of(MapResultItem.succeeded("a"), MapResultItem.failed(testError("fail"))), + List.of( + MapResult.MapResultItem.succeeded("a"), + MapResult.MapResultItem.failed(testError("fail"))), ConcurrencyCompletionStatus.ALL_COMPLETED); - assertEquals(MapResultItem.Status.SUCCEEDED, result.getItem(0).status()); + assertEquals(MapResult.MapResultItem.Status.SUCCEEDED, result.getItem(0).status()); assertEquals("a", result.getItem(0).result()); assertNull(result.getItem(0).error()); - assertEquals(MapResultItem.Status.FAILED, result.getItem(1).status()); + assertEquals(MapResult.MapResultItem.Status.FAILED, result.getItem(1).status()); assertNull(result.getItem(1).result()); assertNotNull(result.getItem(1).error()); } @@ -131,10 +135,10 @@ void getItem_returnsMapResultItem() { @Test void notStartedItems_haveNotStartedStatusAndNullResultAndError() { var result = new MapResult<>( - List.of(MapResultItem.succeeded("a"), MapResultItem.skipped()), + List.of(MapResult.MapResultItem.succeeded("a"), MapResult.MapResultItem.skipped()), ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED); - assertEquals(MapResultItem.Status.SKIPPED, result.getItem(1).status()); + assertEquals(MapResult.MapResultItem.Status.SKIPPED, result.getItem(1).status()); assertNull(result.getResult(1)); assertNull(result.getError(1)); } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/CallbackOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/CallbackOperationTest.java index 8f9f8775a..ebe2d7e99 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/CallbackOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/CallbackOperationTest.java @@ -13,10 +13,10 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.lambda.model.*; -import software.amazon.lambda.durable.CallbackConfig; import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.TestUtils; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.CallbackConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.CallbackFailedException; import software.amazon.lambda.durable.exception.CallbackTimeoutException; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java index 0e6158ed1..60ae3025f 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java @@ -20,6 +20,7 @@ import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.RunInChildContextConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.ChildContextFailedException; import software.amazon.lambda.durable.exception.NonDeterministicExecutionException; @@ -58,13 +59,22 @@ private DurableConfig createConfig() { private ChildContextOperation createOperation(Function func) { return new ChildContextOperation<>( - OPERATION_IDENTIFIER, func, TypeToken.get(String.class), SERDES, durableContext); + OPERATION_IDENTIFIER, + func, + TypeToken.get(String.class), + RunInChildContextConfig.builder().serDes(SERDES).build(), + durableContext); } private ChildContextOperation createOperationWithParent( Function func, ConcurrencyOperation parent) { return new ChildContextOperation<>( - OPERATION_IDENTIFIER, func, TypeToken.get(String.class), SERDES, durableContext, parent); + OPERATION_IDENTIFIER, + func, + TypeToken.get(String.class), + RunInChildContextConfig.builder().serDes(SERDES).build(), + durableContext, + parent); } // ===== SUCCEEDED replay ===== diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java index d0b816150..57770ea60 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java @@ -19,10 +19,11 @@ import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; -import software.amazon.lambda.durable.CompletionConfig; import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.CompletionConfig; +import software.amazon.lambda.durable.config.RunInChildContextConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.ExecutionManager; import software.amazon.lambda.durable.execution.OperationIdGenerator; @@ -243,7 +244,7 @@ protected ChildContextOperation createItem( OperationIdentifier.of(operationId, name, OperationType.CONTEXT, OperationSubType.PARALLEL_BRANCH), function, resultType, - serDes, + RunInChildContextConfig.builder().serDes(serDes).build(), parentContext, this) { @Override diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/InvokeOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/InvokeOperationTest.java index f370805c9..daa75056c 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/InvokeOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/InvokeOperationTest.java @@ -14,8 +14,8 @@ import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; -import software.amazon.lambda.durable.InvokeConfig; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.InvokeConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.InvokeException; import software.amazon.lambda.durable.exception.InvokeFailedException; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java index 439e8574a..138e98c31 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java @@ -20,10 +20,10 @@ import software.amazon.awssdk.services.lambda.model.OperationAction; import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; -import software.amazon.lambda.durable.CompletionConfig; import software.amazon.lambda.durable.DurableConfig; -import software.amazon.lambda.durable.ParallelConfig; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.CompletionConfig; +import software.amazon.lambda.durable.config.ParallelConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.ExecutionManager; import software.amazon.lambda.durable.execution.OperationIdGenerator; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/StepOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/StepOperationTest.java index 65cb64a7d..2eae32f0c 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/StepOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/StepOperationTest.java @@ -15,8 +15,8 @@ import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.awssdk.services.lambda.model.StepDetails; import software.amazon.lambda.durable.DurableConfig; -import software.amazon.lambda.durable.StepConfig; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.StepFailedException; import software.amazon.lambda.durable.exception.StepInterruptedException; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/WaitForConditionOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/WaitForConditionOperationTest.java index 32f175ec1..2747e6440 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/WaitForConditionOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/WaitForConditionOperationTest.java @@ -18,7 +18,7 @@ import software.amazon.awssdk.services.lambda.model.StepDetails; import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.TypeToken; -import software.amazon.lambda.durable.WaitForConditionConfig; +import software.amazon.lambda.durable.config.WaitForConditionConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.IllegalDurableOperationException; import software.amazon.lambda.durable.exception.NonDeterministicExecutionException; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/retry/RetryStrategiesTest.java b/sdk/src/test/java/software/amazon/lambda/durable/retry/RetryStrategiesTest.java index 74f57966b..0578a8bf8 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/retry/RetryStrategiesTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/retry/RetryStrategiesTest.java @@ -6,7 +6,7 @@ import java.time.Duration; import org.junit.jupiter.api.Test; -import software.amazon.lambda.durable.StepConfig; +import software.amazon.lambda.durable.config.StepConfig; class RetryStrategiesTest { From 195db49274da03c0d0d95f1ccb7b9df3374a6458 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sat, 21 Mar 2026 22:16:15 -0700 Subject: [PATCH 06/15] move proxy methods to DurableFuture --- .../amazon/lambda/durable/DurableContext.java | 425 +++++++++++--- .../durable/context/DurableContextImpl.java | 546 ------------------ 2 files changed, 346 insertions(+), 625 deletions(-) diff --git a/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java b/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java index 4e9d26adc..2f652232f 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java @@ -33,7 +33,9 @@ public interface DurableContext extends BaseContext { * @param func the function to execute, receiving a {@link StepContext} * @return the step result */ - T step(String name, Class resultType, Function func); + default T step(String name, Class resultType, Function func) { + return step(name, TypeToken.get(resultType), func, StepConfig.builder().build()); + } /** * Executes a durable step with the given name and configuration, blocking until it completes. @@ -45,7 +47,9 @@ public interface DurableContext extends BaseContext { * @param config the step configuration (retry strategy, semantics, custom SerDes) * @return the step result */ - T step(String name, Class resultType, Function func, StepConfig config); + default T step(String name, Class resultType, Function func, StepConfig config) { + return stepAsync(name, resultType, func, config).get(); + } /** * Executes a durable step using a {@link TypeToken} for generic result types, blocking until it completes. @@ -56,7 +60,9 @@ public interface DurableContext extends BaseContext { * @param func the function to execute, receiving a {@link StepContext} * @return the step result */ - T step(String name, TypeToken resultType, Function func); + default T step(String name, TypeToken resultType, Function func) { + return step(name, resultType, func, StepConfig.builder().build()); + } /** * Executes a durable step using a {@link TypeToken} and configuration, blocking until it completes. @@ -68,7 +74,9 @@ public interface DurableContext extends BaseContext { * @param config the step configuration (retry strategy, semantics, custom SerDes) * @return the step result */ - T step(String name, TypeToken resultType, Function func, StepConfig config); + default T step(String name, TypeToken resultType, Function func, StepConfig config) { + return stepAsync(name, resultType, func, config).get(); + } /** * Asynchronously executes a durable step, returning a {@link DurableFuture} that can be composed or blocked on. @@ -79,7 +87,10 @@ public interface DurableContext extends BaseContext { * @param func the function to execute, receiving a {@link StepContext} * @return a future representing the step result */ - DurableFuture stepAsync(String name, Class resultType, Function func); + default DurableFuture stepAsync(String name, Class resultType, Function func) { + return stepAsync( + name, TypeToken.get(resultType), func, StepConfig.builder().build()); + } /** * Asynchronously executes a durable step using custom configuration. @@ -92,7 +103,10 @@ public interface DurableContext extends BaseContext { * @param config the step configuration (retry strategy, semantics, custom SerDes) * @return a future representing the step result */ - DurableFuture stepAsync(String name, Class resultType, Function func, StepConfig config); + default DurableFuture stepAsync( + String name, Class resultType, Function func, StepConfig config) { + return stepAsync(name, TypeToken.get(resultType), func, config); + } /** * Asynchronously executes a durable step using a {@link TypeToken} for generic result types. @@ -105,7 +119,9 @@ public interface DurableContext extends BaseContext { * @param func the function to execute, receiving a {@link StepContext} * @return a future representing the step result */ - DurableFuture stepAsync(String name, TypeToken resultType, Function func); + default DurableFuture stepAsync(String name, TypeToken resultType, Function func) { + return stepAsync(name, resultType, func, StepConfig.builder().build()); + } /** * Asynchronously executes a durable step using a {@link TypeToken} and custom configuration. @@ -122,29 +138,60 @@ public interface DurableContext extends BaseContext { DurableFuture stepAsync( String name, TypeToken resultType, Function func, StepConfig config); + /** @deprecated use the variants accepting StepContext instead */ @Deprecated - T step(String name, Class resultType, Supplier func); + default T step(String name, Class resultType, Supplier func) { + return stepAsync( + name, + TypeToken.get(resultType), + func, + StepConfig.builder().build()) + .get(); + } + /** @deprecated use the variants accepting StepContext instead */ @Deprecated - T step(String name, Class resultType, Supplier func, StepConfig config); + default T step(String name, Class resultType, Supplier func, StepConfig config) { + // Simply delegate to stepAsync and block on the result + return stepAsync(name, TypeToken.get(resultType), func, config).get(); + } + /** @deprecated use the variants accepting StepContext instead */ @Deprecated - T step(String name, TypeToken resultType, Supplier func); + default T step(String name, TypeToken resultType, Supplier func) { + return stepAsync(name, resultType, func, StepConfig.builder().build()).get(); + } + /** @deprecated use the variants accepting StepContext instead */ @Deprecated - T step(String name, TypeToken resultType, Supplier func, StepConfig config); + default T step(String name, TypeToken resultType, Supplier func, StepConfig config) { + return stepAsync(name, resultType, func, config).get(); + } + /** @deprecated use the variants accepting StepContext instead */ @Deprecated - DurableFuture stepAsync(String name, Class resultType, Supplier func); + default DurableFuture stepAsync(String name, Class resultType, Supplier func) { + return stepAsync( + name, TypeToken.get(resultType), func, StepConfig.builder().build()); + } + /** @deprecated use the variants accepting StepContext instead */ @Deprecated - DurableFuture stepAsync(String name, Class resultType, Supplier func, StepConfig config); + default DurableFuture stepAsync(String name, Class resultType, Supplier func, StepConfig config) { + return stepAsync(name, TypeToken.get(resultType), func, config); + } + /** @deprecated use the variants accepting StepContext instead */ @Deprecated - DurableFuture stepAsync(String name, TypeToken resultType, Supplier func); + default DurableFuture stepAsync(String name, TypeToken resultType, Supplier func) { + return stepAsync(name, resultType, func, StepConfig.builder().build()); + } + /** @deprecated use the variants accepting StepContext instead */ @Deprecated - DurableFuture stepAsync(String name, TypeToken resultType, Supplier func, StepConfig config); + default DurableFuture stepAsync(String name, TypeToken resultType, Supplier func, StepConfig config) { + return stepAsync(name, resultType, stepContext -> func.get(), config); + } /** * Suspends execution for the specified duration without consuming compute resources. @@ -156,7 +203,9 @@ DurableFuture stepAsync( * @param duration the duration to wait * @return always {@code null} */ - Void wait(String name, Duration duration); + default Void wait(String name, Duration duration) { + return waitAsync(name, duration).get(); + } /** * Asynchronously suspends execution for the specified duration. @@ -181,26 +230,59 @@ DurableFuture stepAsync( * @param resultType the result class for deserialization * @return the invocation result */ - T invoke(String name, String functionName, U payload, Class resultType); + default T invoke(String name, String functionName, U payload, Class resultType) { + return invokeAsync( + name, + functionName, + payload, + TypeToken.get(resultType), + InvokeConfig.builder().build()) + .get(); + } /** Invokes another Lambda function with custom configuration, blocking until the result is available. */ - T invoke(String name, String functionName, U payload, Class resultType, InvokeConfig config); + default T invoke(String name, String functionName, U payload, Class resultType, InvokeConfig config) { + return invokeAsync(name, functionName, payload, TypeToken.get(resultType), config) + .get(); + } /** Invokes another Lambda function using a {@link TypeToken} for generic result types, blocking until complete. */ - T invoke(String name, String functionName, U payload, TypeToken resultType); + default T invoke(String name, String functionName, U payload, TypeToken resultType) { + return invokeAsync( + name, + functionName, + payload, + resultType, + InvokeConfig.builder().build()) + .get(); + } /** Invokes another Lambda function using a {@link TypeToken} and custom configuration, blocking until complete. */ - T invoke(String name, String functionName, U payload, TypeToken resultType, InvokeConfig config); + default T invoke(String name, String functionName, U payload, TypeToken resultType, InvokeConfig config) { + return invokeAsync(name, functionName, payload, resultType, config).get(); + } /** Invokes another Lambda function using a {@link TypeToken} and custom configuration, blocking until complete. */ - DurableFuture invokeAsync( - String name, String functionName, U payload, Class resultType, InvokeConfig config); + default DurableFuture invokeAsync( + String name, String functionName, U payload, Class resultType, InvokeConfig config) { + return invokeAsync(name, functionName, payload, TypeToken.get(resultType), config); + } /** Asynchronously invokes another Lambda function, returning a {@link DurableFuture}. */ - DurableFuture invokeAsync(String name, String functionName, U payload, Class resultType); + default DurableFuture invokeAsync(String name, String functionName, U payload, Class resultType) { + return invokeAsync( + name, + functionName, + payload, + TypeToken.get(resultType), + InvokeConfig.builder().build()); + } /** Asynchronously invokes another Lambda function using a {@link TypeToken} for generic result types. */ - DurableFuture invokeAsync(String name, String functionName, U payload, TypeToken resultType); + default DurableFuture invokeAsync(String name, String functionName, U payload, TypeToken resultType) { + return invokeAsync( + name, functionName, payload, resultType, InvokeConfig.builder().build()); + } /** * Asynchronously invokes another Lambda function using a {@link TypeToken} and custom configuration. @@ -220,13 +302,20 @@ DurableFuture invokeAsync( String name, String functionName, U payload, TypeToken resultType, InvokeConfig config); /** Creates a callback with custom configuration. */ - DurableCallbackFuture createCallback(String name, Class resultType, CallbackConfig config); + default DurableCallbackFuture createCallback(String name, Class resultType, CallbackConfig config) { + return createCallback(name, TypeToken.get(resultType), config); + } /** Creates a callback using a {@link TypeToken} for generic result types. */ - DurableCallbackFuture createCallback(String name, TypeToken resultType); + default DurableCallbackFuture createCallback(String name, TypeToken resultType) { + return createCallback(name, resultType, CallbackConfig.builder().build()); + } /** Creates a callback with default configuration. */ - DurableCallbackFuture createCallback(String name, Class resultType); + default DurableCallbackFuture createCallback(String name, Class resultType) { + return createCallback( + name, TypeToken.get(resultType), CallbackConfig.builder().build()); + } /** * Creates a callback operation that suspends execution until an external system completes it. @@ -254,18 +343,63 @@ DurableFuture invokeAsync( * @param func the function to execute, receiving a child {@link DurableContext} * @return the child context result */ - T runInChildContext(String name, Class resultType, Function func); + default T runInChildContext(String name, Class resultType, Function func) { + return runInChildContextAsync( + name, + TypeToken.get(resultType), + func, + RunInChildContextConfig.builder().build()) + .get(); + } /** * Runs a function in a child context using a {@link TypeToken} for generic result types, blocking until complete. */ - T runInChildContext(String name, TypeToken resultType, Function func); + default T runInChildContext(String name, TypeToken resultType, Function func) { + return runInChildContextAsync( + name, + resultType, + func, + RunInChildContextConfig.builder().build()) + .get(); + } - /** Asynchronously runs a function in a child context, returning a {@link DurableFuture}. */ - DurableFuture runInChildContextAsync(String name, Class resultType, Function func); + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @return the DurableFuture of the child context result + */ + default DurableFuture runInChildContextAsync( + String name, Class resultType, Function func) { + return runInChildContextAsync( + name, + TypeToken.get(resultType), + func, + RunInChildContextConfig.builder().build()); + } - /** Asynchronously runs a function in a child context using a {@link TypeToken} for generic result types. */ - DurableFuture runInChildContextAsync(String name, TypeToken resultType, Function func); + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @return the DurableFuture of the child context result + */ + default DurableFuture runInChildContextAsync( + String name, TypeToken resultType, Function func) { + return runInChildContextAsync( + name, resultType, func, RunInChildContextConfig.builder().build()); + } /** * Runs a function in a child context, blocking until it completes. @@ -279,41 +413,107 @@ DurableFuture invokeAsync( * @param func the function to execute, receiving a child {@link DurableContext} * @return the child context result */ - T runInChildContext( - String name, Class resultType, Function func, RunInChildContextConfig config); + default T runInChildContext( + String name, Class resultType, Function func, RunInChildContextConfig config) { + return runInChildContextAsync(name, TypeToken.get(resultType), func, config) + .get(); + } /** - * Runs a function in a child context using a {@link TypeToken} for generic result types, blocking until complete. + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @param config the configuration for the child context + * @return the child context result */ - T runInChildContext( - String name, TypeToken resultType, Function func, RunInChildContextConfig config); + default T runInChildContext( + String name, TypeToken resultType, Function func, RunInChildContextConfig config) { + return runInChildContextAsync(name, resultType, func, config).get(); + } - /** Asynchronously runs a function in a child context, returning a {@link DurableFuture}. */ - DurableFuture runInChildContextAsync( - String name, Class resultType, Function func, RunInChildContextConfig config); + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @param config the configuration for the child context + * @return the DurableFuture wrapping the child context result + */ + default DurableFuture runInChildContextAsync( + String name, Class resultType, Function func, RunInChildContextConfig config) { + return runInChildContextAsync(name, TypeToken.get(resultType), func, config); + } - /** Asynchronously runs a function in a child context using a {@link TypeToken} for generic result types. */ + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @param config the configuration for the child context + * @return the DurableFuture wrapping the child context result + */ DurableFuture runInChildContextAsync( String name, TypeToken resultType, Function func, RunInChildContextConfig config); - MapResult map(String name, Collection items, Class resultType, MapFunction function); + default MapResult map(String name, Collection items, Class resultType, MapFunction function) { + return mapAsync( + name, + items, + TypeToken.get(resultType), + function, + MapConfig.builder().build()) + .get(); + } - MapResult map( - String name, Collection items, Class resultType, MapFunction function, MapConfig config); + default MapResult map( + String name, Collection items, Class resultType, MapFunction function, MapConfig config) { + return mapAsync(name, items, TypeToken.get(resultType), function, config) + .get(); + } - MapResult map(String name, Collection items, TypeToken resultType, MapFunction function); + default MapResult map( + String name, Collection items, TypeToken resultType, MapFunction function) { + return mapAsync(name, items, resultType, function, MapConfig.builder().build()) + .get(); + } - MapResult map( - String name, Collection items, TypeToken resultType, MapFunction function, MapConfig config); + default MapResult map( + String name, Collection items, TypeToken resultType, MapFunction function, MapConfig config) { + return mapAsync(name, items, resultType, function, config).get(); + } - DurableFuture> mapAsync( - String name, Collection items, Class resultType, MapFunction function); + default DurableFuture> mapAsync( + String name, Collection items, Class resultType, MapFunction function) { + return mapAsync( + name, + items, + TypeToken.get(resultType), + function, + MapConfig.builder().build()); + } - DurableFuture> mapAsync( - String name, Collection items, Class resultType, MapFunction function, MapConfig config); + default DurableFuture> mapAsync( + String name, Collection items, Class resultType, MapFunction function, MapConfig config) { + return mapAsync(name, items, TypeToken.get(resultType), function, config); + } - DurableFuture> mapAsync( - String name, Collection items, TypeToken resultType, MapFunction function); + default DurableFuture> mapAsync( + String name, Collection items, TypeToken resultType, MapFunction function) { + return mapAsync(name, items, resultType, function, MapConfig.builder().build()); + } DurableFuture> mapAsync( String name, Collection items, TypeToken resultType, MapFunction function, MapConfig config); @@ -338,38 +538,67 @@ DurableFuture> mapAsync( * @param func the submitter function, receiving the callback ID and a {@link StepContext} * @return the callback result */ - T waitForCallback(String name, Class resultType, BiConsumer func); + default T waitForCallback(String name, Class resultType, BiConsumer func) { + return waitForCallbackAsync( + name, + TypeToken.get(resultType), + func, + WaitForCallbackConfig.builder().build()) + .get(); + } /** Executes a submitter and waits for an external callback using a {@link TypeToken}, blocking until complete. */ - T waitForCallback(String name, TypeToken resultType, BiConsumer func); + default T waitForCallback(String name, TypeToken resultType, BiConsumer func) { + return waitForCallbackAsync( + name, resultType, func, WaitForCallbackConfig.builder().build()) + .get(); + } /** Executes a submitter and waits for an external callback with custom configuration, blocking until complete. */ - T waitForCallback( + default T waitForCallback( String name, Class resultType, BiConsumer func, - WaitForCallbackConfig waitForCallbackConfig); + WaitForCallbackConfig waitForCallbackConfig) { + return waitForCallbackAsync(name, TypeToken.get(resultType), func, waitForCallbackConfig) + .get(); + } /** Executes a submitter and waits for an external callback using a {@link TypeToken} and custom configuration. */ - T waitForCallback( + default T waitForCallback( String name, TypeToken resultType, BiConsumer func, - WaitForCallbackConfig waitForCallbackConfig); + WaitForCallbackConfig waitForCallbackConfig) { + return waitForCallbackAsync(name, resultType, func, waitForCallbackConfig) + .get(); + } /** Asynchronously executes a submitter and waits for an external callback. */ - DurableFuture waitForCallbackAsync(String name, Class resultType, BiConsumer func); + default DurableFuture waitForCallbackAsync( + String name, Class resultType, BiConsumer func) { + return waitForCallbackAsync( + name, + TypeToken.get(resultType), + func, + WaitForCallbackConfig.builder().build()); + } /** Asynchronously executes a submitter and waits for an external callback using a {@link TypeToken}. */ - DurableFuture waitForCallbackAsync( - String name, TypeToken resultType, BiConsumer func); + default DurableFuture waitForCallbackAsync( + String name, TypeToken resultType, BiConsumer func) { + return waitForCallbackAsync( + name, resultType, func, WaitForCallbackConfig.builder().build()); + } /** Asynchronously executes a submitter and waits for an external callback with custom configuration. */ - DurableFuture waitForCallbackAsync( + default DurableFuture waitForCallbackAsync( String name, Class resultType, BiConsumer func, - WaitForCallbackConfig waitForCallbackConfig); + WaitForCallbackConfig waitForCallbackConfig) { + return waitForCallbackAsync(name, TypeToken.get(resultType), func, waitForCallbackConfig); + } /** * Asynchronously executes a submitter and waits for an external callback using a {@link TypeToken} and custom @@ -402,59 +631,97 @@ DurableFuture waitForCallbackAsync( * @param initialState the initial state passed to the first check invocation * @return the final state value when the condition is met */ - T waitForCondition( + default T waitForCondition( String name, Class resultType, BiFunction> checkFunc, - T initialState); + T initialState) { + return waitForConditionAsync( + name, + TypeToken.get(resultType), + checkFunc, + initialState, + WaitForConditionConfig.builder().build()) + .get(); + } /** Polls a condition function until it signals done, using a custom configuration, blocking until complete. */ - T waitForCondition( + default T waitForCondition( String name, Class resultType, BiFunction> checkFunc, T initialState, - WaitForConditionConfig config); + WaitForConditionConfig config) { + return waitForConditionAsync(name, resultType, checkFunc, initialState, config) + .get(); + } /** Polls a condition function until it signals done, using a {@link TypeToken}, blocking until complete. */ - T waitForCondition( + default T waitForCondition( String name, TypeToken resultType, BiFunction> checkFunc, - T initialState); + T initialState) { + return waitForConditionAsync( + name, + resultType, + checkFunc, + initialState, + WaitForConditionConfig.builder().build()) + .get(); + } /** * Polls a condition function until it signals done, using a {@link TypeToken} and custom configuration, blocking * until complete. */ - T waitForCondition( + default T waitForCondition( String name, TypeToken resultType, BiFunction> checkFunc, T initialState, - WaitForConditionConfig config); + WaitForConditionConfig config) { + return waitForConditionAsync(name, resultType, checkFunc, initialState, config) + .get(); + } /** Asynchronously polls a condition function until it signals done. */ - DurableFuture waitForConditionAsync( + default DurableFuture waitForConditionAsync( String name, Class resultType, BiFunction> checkFunc, - T initialState); + T initialState) { + return waitForConditionAsync( + name, + TypeToken.get(resultType), + checkFunc, + initialState, + WaitForConditionConfig.builder().build()); + } /** Asynchronously polls a condition function until it signals done, using custom configuration. */ - DurableFuture waitForConditionAsync( + default DurableFuture waitForConditionAsync( String name, Class resultType, BiFunction> checkFunc, T initialState, - WaitForConditionConfig config); + WaitForConditionConfig config) { + return waitForConditionAsync(name, TypeToken.get(resultType), checkFunc, initialState, config); + } /** Asynchronously polls a condition function until it signals done, using a {@link TypeToken}. */ - DurableFuture waitForConditionAsync( + default DurableFuture waitForConditionAsync( String name, TypeToken resultType, BiFunction> checkFunc, - T initialState); + T initialState) { + return waitForConditionAsync( + name, + resultType, + checkFunc, + initialState, + WaitForConditionConfig.builder().build()); + } /** * Asynchronously polls a condition function until it signals done, using a {@link TypeToken} and custom diff --git a/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java b/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java index 40ab21dda..3e8e29c5c 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java @@ -10,7 +10,6 @@ import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Function; -import java.util.function.Supplier; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.lambda.durable.DurableCallbackFuture; @@ -150,47 +149,6 @@ public StepContextImpl createStepContext(String stepOperationId, String stepOper attempt); } - // ========== step methods ========== - - @Override - public T step(String name, Class resultType, Function func) { - return step(name, TypeToken.get(resultType), func, StepConfig.builder().build()); - } - - @Override - public T step(String name, Class resultType, Function func, StepConfig config) { - // Simply delegate to stepAsync and block on the result - return stepAsync(name, resultType, func, config).get(); - } - - @Override - public T step(String name, TypeToken resultType, Function func) { - return step(name, resultType, func, StepConfig.builder().build()); - } - - @Override - public T step(String name, TypeToken resultType, Function func, StepConfig config) { - // Simply delegate to stepAsync and block on the result - return stepAsync(name, resultType, func, config).get(); - } - - @Override - public DurableFuture stepAsync(String name, Class resultType, Function func) { - return stepAsync( - name, TypeToken.get(resultType), func, StepConfig.builder().build()); - } - - @Override - public DurableFuture stepAsync( - String name, Class resultType, Function func, StepConfig config) { - return stepAsync(name, TypeToken.get(resultType), func, config); - } - - @Override - public DurableFuture stepAsync(String name, TypeToken resultType, Function func) { - return stepAsync(name, resultType, func, StepConfig.builder().build()); - } - @Override public DurableFuture stepAsync( String name, TypeToken resultType, Function func, StepConfig config) { @@ -212,78 +170,6 @@ public DurableFuture stepAsync( return operation; } - /** @deprecated use the variants accepting StepContext instead */ - @Deprecated - @Override - public T step(String name, Class resultType, Supplier func) { - return stepAsync( - name, - TypeToken.get(resultType), - func, - StepConfig.builder().build()) - .get(); - } - - /** @deprecated use the variants accepting StepContext instead */ - @Deprecated - @Override - public T step(String name, Class resultType, Supplier func, StepConfig config) { - // Simply delegate to stepAsync and block on the result - return stepAsync(name, TypeToken.get(resultType), func, config).get(); - } - - /** @deprecated use the variants accepting StepContext instead */ - @Deprecated - @Override - public T step(String name, TypeToken resultType, Supplier func) { - return stepAsync(name, resultType, func, StepConfig.builder().build()).get(); - } - - /** @deprecated use the variants accepting StepContext instead */ - @Deprecated - @Override - public T step(String name, TypeToken resultType, Supplier func, StepConfig config) { - // Simply delegate to stepAsync and block on the result - return stepAsync(name, resultType, func, config).get(); - } - - /** @deprecated use the variants accepting StepContext instead */ - @Deprecated - @Override - public DurableFuture stepAsync(String name, Class resultType, Supplier func) { - return stepAsync( - name, TypeToken.get(resultType), func, StepConfig.builder().build()); - } - - /** @deprecated use the variants accepting StepContext instead */ - @Deprecated - @Override - public DurableFuture stepAsync(String name, Class resultType, Supplier func, StepConfig config) { - return stepAsync(name, TypeToken.get(resultType), func, config); - } - - /** @deprecated use the variants accepting StepContext instead */ - @Deprecated - @Override - public DurableFuture stepAsync(String name, TypeToken resultType, Supplier func) { - return stepAsync(name, resultType, func, StepConfig.builder().build()); - } - - /** @deprecated use the variants accepting StepContext instead */ - @Deprecated - @Override - public DurableFuture stepAsync(String name, TypeToken resultType, Supplier func, StepConfig config) { - return stepAsync(name, resultType, stepContext -> func.get(), config); - } - - // ========== wait methods ========== - - @Override - public Void wait(String name, Duration duration) { - // Block (will throw SuspendExecutionException if there is no active thread) - return waitAsync(name, duration).get(); - } - @Override public DurableFuture waitAsync(String name, Duration duration) { ParameterValidator.validateDuration(duration, "Wait duration"); @@ -299,64 +185,6 @@ public DurableFuture waitAsync(String name, Duration duration) { return operation; } - // ========== chained invoke methods ========== - - @Override - public T invoke(String name, String functionName, U payload, Class resultType) { - return invokeAsync( - name, - functionName, - payload, - TypeToken.get(resultType), - InvokeConfig.builder().build()) - .get(); - } - - @Override - public T invoke(String name, String functionName, U payload, Class resultType, InvokeConfig config) { - return invokeAsync(name, functionName, payload, TypeToken.get(resultType), config) - .get(); - } - - @Override - public T invoke(String name, String functionName, U payload, TypeToken resultType) { - return invokeAsync( - name, - functionName, - payload, - resultType, - InvokeConfig.builder().build()) - .get(); - } - - @Override - public T invoke(String name, String functionName, U payload, TypeToken resultType, InvokeConfig config) { - return invokeAsync(name, functionName, payload, resultType, config).get(); - } - - /** Asynchronously invokes another Lambda function with custom configuration. */ - @Override - public DurableFuture invokeAsync( - String name, String functionName, U payload, Class resultType, InvokeConfig config) { - return invokeAsync(name, functionName, payload, TypeToken.get(resultType), config); - } - - @Override - public DurableFuture invokeAsync(String name, String functionName, U payload, Class resultType) { - return invokeAsync( - name, - functionName, - payload, - TypeToken.get(resultType), - InvokeConfig.builder().build()); - } - - @Override - public DurableFuture invokeAsync(String name, String functionName, U payload, TypeToken resultType) { - return invokeAsync( - name, functionName, payload, resultType, InvokeConfig.builder().build()); - } - @Override public DurableFuture invokeAsync( String name, String functionName, U payload, TypeToken resultType, InvokeConfig config) { @@ -387,24 +215,6 @@ public DurableFuture invokeAsync( return operation; // Block (will throw SuspendExecutionException if needed) } - // ========== createCallback methods ========== - - @Override - public DurableCallbackFuture createCallback(String name, Class resultType, CallbackConfig config) { - return createCallback(name, TypeToken.get(resultType), config); - } - - @Override - public DurableCallbackFuture createCallback(String name, TypeToken resultType) { - return createCallback(name, resultType, CallbackConfig.builder().build()); - } - - @Override - public DurableCallbackFuture createCallback(String name, Class resultType) { - return createCallback( - name, TypeToken.get(resultType), CallbackConfig.builder().build()); - } - @Override public DurableCallbackFuture createCallback(String name, TypeToken resultType, CallbackConfig config) { ParameterValidator.validateOperationName(name); @@ -420,148 +230,6 @@ public DurableCallbackFuture createCallback(String name, TypeToken res return operation; } - // ========== runInChildContext methods ========== - - /** - * Runs a function in a child context, blocking until it completes. - * - *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID - * collisions. On replay, the child context's operations are replayed independently. - * - * @param name the operation name within this context - * @param resultType the result class for deserialization - * @param func the function to execute, receiving a child {@link DurableContext} - * @return the child context result - */ - @Override - public T runInChildContext(String name, Class resultType, Function func) { - return runInChildContextAsync( - name, - TypeToken.get(resultType), - func, - RunInChildContextConfig.builder().build()) - .get(); - } - - /** - * Runs a function in a child context, blocking until it completes. - * - *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID - * collisions. On replay, the child context's operations are replayed independently. - * - * @param name the operation name within this context - * @param resultType the result class for deserialization - * @param func the function to execute, receiving a child {@link DurableContext} - * @return the child context result - */ - @Override - public T runInChildContext(String name, TypeToken resultType, Function func) { - return runInChildContextAsync( - name, - resultType, - func, - RunInChildContextConfig.builder().build()) - .get(); - } - - /** - * Runs a function in a child context, blocking until it completes. - * - *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID - * collisions. On replay, the child context's operations are replayed independently. - * - * @param name the operation name within this context - * @param resultType the result class for deserialization - * @param func the function to execute, receiving a child {@link DurableContext} - * @return the DurableFuture of the child context result - */ - @Override - public DurableFuture runInChildContextAsync( - String name, Class resultType, Function func) { - return runInChildContextAsync( - name, - TypeToken.get(resultType), - func, - RunInChildContextConfig.builder().build()); - } - - /** - * Runs a function in a child context, blocking until it completes. - * - *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID - * collisions. On replay, the child context's operations are replayed independently. - * - * @param name the operation name within this context - * @param resultType the result class for deserialization - * @param func the function to execute, receiving a child {@link DurableContext} - * @return the DurableFuture of the child context result - */ - @Override - public DurableFuture runInChildContextAsync( - String name, TypeToken resultType, Function func) { - return runInChildContextAsync( - name, - resultType, - func, - RunInChildContextConfig.builder().build(), - OperationSubType.RUN_IN_CHILD_CONTEXT); - } - - /** - * Runs a function in a child context, blocking until it completes. - * - *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID - * collisions. On replay, the child context's operations are replayed independently. - * - * @param name the operation name within this context - * @param resultType the result class for deserialization - * @param func the function to execute, receiving a child {@link DurableContext} - * @param config the configuration for the child context - * @return the child context result - */ - @Override - public T runInChildContext( - String name, Class resultType, Function func, RunInChildContextConfig config) { - return runInChildContextAsync(name, TypeToken.get(resultType), func, config) - .get(); - } - - /** - * Runs a function in a child context, blocking until it completes. - * - *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID - * collisions. On replay, the child context's operations are replayed independently. - * - * @param name the operation name within this context - * @param resultType the result class for deserialization - * @param func the function to execute, receiving a child {@link DurableContext} - * @param config the configuration for the child context - * @return the child context result - */ - @Override - public T runInChildContext( - String name, TypeToken resultType, Function func, RunInChildContextConfig config) { - return runInChildContextAsync(name, resultType, func, config).get(); - } - - /** - * Runs a function in a child context, blocking until it completes. - * - *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID - * collisions. On replay, the child context's operations are replayed independently. - * - * @param name the operation name within this context - * @param resultType the result class for deserialization - * @param func the function to execute, receiving a child {@link DurableContext} - * @param config the configuration for the child context - * @return the DurableFuture wrapping the child context result - */ - @Override - public DurableFuture runInChildContextAsync( - String name, Class resultType, Function func, RunInChildContextConfig config) { - return runInChildContextAsync(name, TypeToken.get(resultType), func, config); - } - /** * Runs a function in a child context, blocking until it completes. * @@ -607,62 +275,6 @@ private DurableFuture runInChildContextAsync( return operation; } - // ========== map methods ========== - - @Override - public MapResult map(String name, Collection items, Class resultType, MapFunction function) { - return mapAsync( - name, - items, - TypeToken.get(resultType), - function, - MapConfig.builder().build()) - .get(); - } - - @Override - public MapResult map( - String name, Collection items, Class resultType, MapFunction function, MapConfig config) { - return mapAsync(name, items, TypeToken.get(resultType), function, config) - .get(); - } - - @Override - public MapResult map( - String name, Collection items, TypeToken resultType, MapFunction function) { - return mapAsync(name, items, resultType, function, MapConfig.builder().build()) - .get(); - } - - @Override - public MapResult map( - String name, Collection items, TypeToken resultType, MapFunction function, MapConfig config) { - return mapAsync(name, items, resultType, function, config).get(); - } - - @Override - public DurableFuture> mapAsync( - String name, Collection items, Class resultType, MapFunction function) { - return mapAsync( - name, - items, - TypeToken.get(resultType), - function, - MapConfig.builder().build()); - } - - @Override - public DurableFuture> mapAsync( - String name, Collection items, Class resultType, MapFunction function, MapConfig config) { - return mapAsync(name, items, TypeToken.get(resultType), function, config); - } - - @Override - public DurableFuture> mapAsync( - String name, Collection items, TypeToken resultType, MapFunction function) { - return mapAsync(name, items, resultType, function, MapConfig.builder().build()); - } - @Override public DurableFuture> mapAsync( String name, Collection items, TypeToken resultType, MapFunction function, MapConfig config) { @@ -697,8 +309,6 @@ public DurableFuture> mapAsync( return operation; } - // ========== parallel methods ========== - @Override public ParallelDurableFuture parallel(String name, ParallelConfig config) { Objects.requireNonNull(config, "config cannot be null"); @@ -715,71 +325,6 @@ public ParallelDurableFuture parallel(String name, ParallelConfig config) { return parallelOp; } - // ========= waitForCallback methods ============= - - @Override - public T waitForCallback(String name, Class resultType, BiConsumer func) { - return waitForCallbackAsync( - name, - TypeToken.get(resultType), - func, - WaitForCallbackConfig.builder().build()) - .get(); - } - - @Override - public T waitForCallback(String name, TypeToken resultType, BiConsumer func) { - return waitForCallbackAsync( - name, resultType, func, WaitForCallbackConfig.builder().build()) - .get(); - } - - @Override - public T waitForCallback( - String name, - Class resultType, - BiConsumer func, - WaitForCallbackConfig waitForCallbackConfig) { - return waitForCallbackAsync(name, TypeToken.get(resultType), func, waitForCallbackConfig) - .get(); - } - - @Override - public T waitForCallback( - String name, - TypeToken resultType, - BiConsumer func, - WaitForCallbackConfig waitForCallbackConfig) { - return waitForCallbackAsync(name, resultType, func, waitForCallbackConfig) - .get(); - } - - @Override - public DurableFuture waitForCallbackAsync( - String name, Class resultType, BiConsumer func) { - return waitForCallbackAsync( - name, - TypeToken.get(resultType), - func, - WaitForCallbackConfig.builder().build()); - } - - @Override - public DurableFuture waitForCallbackAsync( - String name, TypeToken resultType, BiConsumer func) { - return waitForCallbackAsync( - name, resultType, func, WaitForCallbackConfig.builder().build()); - } - - @Override - public DurableFuture waitForCallbackAsync( - String name, - Class resultType, - BiConsumer func, - WaitForCallbackConfig waitForCallbackConfig) { - return waitForCallbackAsync(name, TypeToken.get(resultType), func, waitForCallbackConfig); - } - @Override public DurableFuture waitForCallbackAsync( String name, @@ -824,97 +369,6 @@ public DurableFuture waitForCallbackAsync( OperationSubType.WAIT_FOR_CALLBACK); } - // ========== waitForCondition methods ========== - @Override - public T waitForCondition( - String name, - Class resultType, - BiFunction> checkFunc, - T initialState) { - return waitForConditionAsync( - name, - TypeToken.get(resultType), - checkFunc, - initialState, - WaitForConditionConfig.builder().build()) - .get(); - } - - @Override - public T waitForCondition( - String name, - Class resultType, - BiFunction> checkFunc, - T initialState, - WaitForConditionConfig config) { - return waitForConditionAsync(name, resultType, checkFunc, initialState, config) - .get(); - } - - @Override - public T waitForCondition( - String name, - TypeToken resultType, - BiFunction> checkFunc, - T initialState) { - return waitForConditionAsync( - name, - resultType, - checkFunc, - initialState, - WaitForConditionConfig.builder().build()) - .get(); - } - - @Override - public T waitForCondition( - String name, - TypeToken resultType, - BiFunction> checkFunc, - T initialState, - WaitForConditionConfig config) { - return waitForConditionAsync(name, resultType, checkFunc, initialState, config) - .get(); - } - - @Override - public DurableFuture waitForConditionAsync( - String name, - Class resultType, - BiFunction> checkFunc, - T initialState) { - return waitForConditionAsync( - name, - TypeToken.get(resultType), - checkFunc, - initialState, - WaitForConditionConfig.builder().build()); - } - - @Override - public DurableFuture waitForConditionAsync( - String name, - Class resultType, - BiFunction> checkFunc, - T initialState, - WaitForConditionConfig config) { - return waitForConditionAsync(name, TypeToken.get(resultType), checkFunc, initialState, config); - } - - @Override - public DurableFuture waitForConditionAsync( - String name, - TypeToken resultType, - BiFunction> checkFunc, - T initialState) { - return waitForConditionAsync( - name, - resultType, - checkFunc, - initialState, - WaitForConditionConfig.builder().build()); - } - @Override public DurableFuture waitForConditionAsync( String name, From ce479879a4b632764feeeb7c22af543504adff53 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sat, 21 Mar 2026 22:57:26 -0700 Subject: [PATCH 07/15] remove duplicate concurrency code --- .../operation/ConcurrencyOperation.java | 61 +++++++++---------- .../durable/operation/MapOperation.java | 36 ++++------- .../durable/operation/ParallelOperation.java | 21 +------ .../operation/ConcurrencyOperationTest.java | 48 ++------------- .../operation/ParallelOperationTest.java | 36 +++++------ 5 files changed, 64 insertions(+), 138 deletions(-) diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java index 221fd55df..f78ffae3e 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java @@ -14,12 +14,15 @@ import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.RunInChildContextConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.OperationIdGenerator; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; import software.amazon.lambda.durable.model.OperationIdentifier; +import software.amazon.lambda.durable.model.OperationSubType; import software.amazon.lambda.durable.serde.SerDes; /** @@ -88,13 +91,22 @@ protected ConcurrencyOperation( * @param the result type of the child operation * @return a new ChildContextOperation */ - protected abstract ChildContextOperation createItem( + protected ChildContextOperation createItem( String operationId, String name, Function function, TypeToken resultType, SerDes serDes, - DurableContextImpl parentContext); + OperationSubType branchSubType, + DurableContextImpl parentContext) { + return new ChildContextOperation<>( + OperationIdentifier.of(operationId, name, OperationType.CONTEXT, branchSubType), + function, + resultType, + RunInChildContextConfig.builder().serDes(serDes).build(), + parentContext, + this); + } /** Called when the concurrency operation succeeds. Subclasses define checkpointing behavior. */ protected abstract void handleSuccess(ConcurrencyCompletionStatus concurrencyCompletionStatus); @@ -113,26 +125,25 @@ protected abstract ChildContextOperation createItem( * @return the created ChildContextOperation */ public ChildContextOperation addItem( - String name, Function function, TypeToken resultType, SerDes serDes) { - if (isOperationCompleted()) throw new IllegalStateException("Cannot add items to a completed operation"); - var operationId = this.operationIdGenerator.nextOperationId(); - var childOp = createItem(operationId, name, function, resultType, serDes, this.rootContext); - branches.add(childOp); - pendingQueue.add(childOp); - logger.debug("Item added {}", name); + String name, Function function, TypeToken resultType, SerDes serDes, OperationSubType branchSubType) { + var childOp = enqueueItem(name, function, resultType, serDes, branchSubType); executeNextItemIfAllowed(); return childOp; } /** - * Creates and enqueues an item without starting execution. Use {@link #startPendingItems()} to begin execution - * after all items have been enqueued. This prevents early termination from blocking item creation when all items - * are known upfront (e.g., map operations). + * Creates and enqueues an item without starting execution. Use {@link #executeNextItemIfAllowed()} to begin + * execution after all items have been enqueued. This prevents early termination from blocking item creation when + * all items are known upfront (e.g., map operations). */ protected ChildContextOperation enqueueItem( - String name, Function function, TypeToken resultType, SerDes serDes) { + String name, + Function function, + TypeToken resultType, + SerDes serDes, + OperationSubType branchSubType) { var operationId = this.operationIdGenerator.nextOperationId(); - var childOp = createItem(operationId, name, function, resultType, serDes, this.rootContext); + var childOp = createItem(operationId, name, function, resultType, serDes, branchSubType, this.rootContext); branches.add(childOp); pendingQueue.add(childOp); logger.debug("Item enqueued {}", name); @@ -140,10 +151,10 @@ protected ChildContextOperation enqueueItem( } /** - * Starts executing enqueued items up to maxConcurrency. Called after all items have been enqueued via - * {@link #enqueueItem}. + * Starts the queued items if the running count is below maxConcurrency and the operation hasn't completed yet. Must + * be called within {@code synchronized (pendingQueue)}. */ - protected void startPendingItems() { + protected void executeNextItemIfAllowed() { // Start as many items as concurrency allows while (true) { synchronized (this) { @@ -158,22 +169,6 @@ protected void startPendingItems() { } } - /** - * Starts the next queued item if the running count is below maxConcurrency and the operation hasn't completed yet. - * Must be called within {@code synchronized (pendingQueue)}. - */ - private void executeNextItemIfAllowed() { - synchronized (this) { - if (isOperationCompleted()) return; - if (runningCount.get() >= maxConcurrency) return; - var next = pendingQueue.poll(); - if (next == null) return; - runningCount.incrementAndGet(); - logger.debug("Executing operation {}", next.getName()); - next.execute(); - } - } - /** * Called by a ChildContextOperation BEFORE it closes its child context. Updates counters, checks completion * criteria, and either triggers the next queued item or completes the operation. diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java index 146e9695c..69db319b7 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java @@ -5,17 +5,14 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.function.Function; import software.amazon.awssdk.services.lambda.model.ContextOptions; import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationAction; -import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.awssdk.services.lambda.model.OperationUpdate; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.config.CompletionConfig; import software.amazon.lambda.durable.config.MapConfig; -import software.amazon.lambda.durable.config.RunInChildContextConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; import software.amazon.lambda.durable.model.MapResult; @@ -65,6 +62,8 @@ public MapOperation( this.function = function; this.itemResultType = itemResultType; this.serDes = config.serDes(); + + addAllItems(); } private static Integer getToleratedFailureCount(CompletionConfig completionConfig, int totalItems) { @@ -86,29 +85,12 @@ private static Integer getToleratedFailureCount(CompletionConfig completionConfi return Math.min(toleratedFailureCount, toleratedFailureCountFromPercentage); } - @Override - protected ChildContextOperation createItem( - String operationId, - String name, - Function function, - TypeToken resultType, - SerDes serDes, - DurableContextImpl parentContext) { - return new ChildContextOperation<>( - OperationIdentifier.of(operationId, name, OperationType.CONTEXT, OperationSubType.MAP_ITERATION), - function, - resultType, - RunInChildContextConfig.builder().serDes(serDes).build(), - parentContext, - this); - } - @Override protected void start() { sendOperationUpdateAsync(OperationUpdate.builder() .action(OperationAction.START) .subType(getSubType().getValue())); - addAllItems(); + executeNextItemIfAllowed(); } @Override @@ -118,7 +100,7 @@ protected void replay(Operation existing) { if (existing.contextDetails() != null && Boolean.TRUE.equals(existing.contextDetails().replayChildren())) { // Large result: re-execute children to reconstruct MapResult - addAllItems(); + executeNextItemIfAllowed(); } else { // Small result: MapResult is in the payload, skip child replay replayFromPayload = true; @@ -128,7 +110,7 @@ protected void replay(Operation existing) { case STARTED -> { // Map was in progress when interrupted — re-create children without sending // another START (the backend rejects duplicate START for existing operations) - addAllItems(); + executeNextItemIfAllowed(); } default -> terminateExecutionWithIllegalDurableOperationException( @@ -140,13 +122,17 @@ private void addAllItems() { // Enqueue all items first, then start execution. This prevents early termination // criteria (e.g., minSuccessful) from completing the operation mid-loop on replay, // which would cause subsequent enqueue calls to fail with "completed operation". + var branchPrefix = getName() == null ? "map-iteration-" : getName() + "-iteration-"; for (int i = 0; i < items.size(); i++) { var index = i; var item = items.get(i); enqueueItem( - "map-iteration-" + i, childCtx -> function.apply(item, index, childCtx), itemResultType, serDes); + branchPrefix + i, + childCtx -> function.apply(item, index, childCtx), + itemResultType, + serDes, + OperationSubType.MAP_ITERATION); } - startPendingItems(); } @Override diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java index 48880138c..d1cb21682 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java @@ -6,7 +6,6 @@ import software.amazon.awssdk.services.lambda.model.ContextOptions; import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationAction; -import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.awssdk.services.lambda.model.OperationUpdate; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; @@ -14,7 +13,6 @@ import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.config.ParallelBranchConfig; import software.amazon.lambda.durable.config.ParallelConfig; -import software.amazon.lambda.durable.config.RunInChildContextConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.ExecutionManager; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; @@ -64,23 +62,6 @@ public ParallelOperation( config.completionConfig().toleratedFailureCount()); } - @Override - protected ChildContextOperation createItem( - String operationId, - String name, - Function function, - TypeToken resultType, - SerDes serDes, - DurableContextImpl parentContext) { - return new ChildContextOperation<>( - OperationIdentifier.of(operationId, name, OperationType.CONTEXT, OperationSubType.PARALLEL_BRANCH), - function, - resultType, - RunInChildContextConfig.builder().serDes(serDes).build(), - parentContext, - this); - } - @Override protected void handleSuccess(ConcurrencyCompletionStatus concurrencyCompletionStatus) { if (skipCheckpoint) { @@ -126,6 +107,6 @@ public DurableFuture branch( throw new IllegalStateException("Cannot add branches after join() has been called"); } var serDes = config.serDes() == null ? getContext().getDurableConfig().getSerDes() : config.serDes(); - return addItem(name, func, resultType, serDes); + return addItem(name, func, resultType, serDes, OperationSubType.PARALLEL_BRANCH); } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java index 57770ea60..9674f1659 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java @@ -12,7 +12,6 @@ import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Function; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.lambda.model.ContextDetails; @@ -20,10 +19,8 @@ import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.lambda.durable.DurableConfig; -import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.config.CompletionConfig; -import software.amazon.lambda.durable.config.RunInChildContextConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.ExecutionManager; import software.amazon.lambda.durable.execution.OperationIdGenerator; @@ -142,7 +139,8 @@ void allChildrenAlreadySucceed_callsHandleSuccess() throws Exception { return "result-1"; }, TypeToken.get(String.class), - SER_DES); + SER_DES, + OperationSubType.PARALLEL_BRANCH); op.addItem( "branch-2", ctx -> { @@ -150,7 +148,8 @@ void allChildrenAlreadySucceed_callsHandleSuccess() throws Exception { return "result-2"; }, TypeToken.get(String.class), - SER_DES); + SER_DES, + OperationSubType.PARALLEL_BRANCH); op.exposedJoin(); @@ -183,7 +182,8 @@ void singleChildAlreadySucceeds_fullCycle() throws Exception { return "done"; }, TypeToken.get(String.class), - SER_DES); + SER_DES, + OperationSubType.PARALLEL_BRANCH); op.exposedJoin(); @@ -193,18 +193,6 @@ void singleChildAlreadySucceeds_fullCycle() throws Exception { assertFalse(functionCalled.get(), "Function should not be called during SUCCEEDED replay"); } - @Test - void addItem_usesRootChildContextAsParent() throws Exception { - var op = createOperation(CompletionConfig.allSuccessful()); - - op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES); - - // rootContext is created via durableContext.createChildContext(...) in the constructor, - // so the parentContext passed to createItem must be that child context, not durableContext itself - assertNotSame(durableContext, op.getLastParentContext()); - assertSame(childContext, op.getLastParentContext()); - } - // ===== Test subclass ===== static class TestConcurrencyOperation extends ConcurrencyOperation { @@ -231,30 +219,6 @@ static class TestConcurrencyOperation extends ConcurrencyOperation { completionConfig.toleratedFailureCount()); } - @Override - protected ChildContextOperation createItem( - String operationId, - String name, - Function function, - TypeToken resultType, - SerDes serDes, - DurableContextImpl parentContext) { - lastParentContext = parentContext; - return new ChildContextOperation( - OperationIdentifier.of(operationId, name, OperationType.CONTEXT, OperationSubType.PARALLEL_BRANCH), - function, - resultType, - RunInChildContextConfig.builder().serDes(serDes).build(), - parentContext, - this) { - @Override - public void execute() { - executingCount.incrementAndGet(); - super.execute(); - } - }; - } - @Override protected void handleSuccess(ConcurrencyCompletionStatus completionStatus) { successHandled = true; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java index 138e98c31..0003d83a8 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java @@ -127,22 +127,22 @@ private void setOperationIdGenerator(ConcurrencyOperation op, OperationIdGene // ===== Branch creation delegates to ConcurrencyOperation ===== @Test - void branchCreation_createsBranchWithParallelBranchSubType() throws Exception { + void branchCreation_createsBranchWithParallelBranchSubType() { var op = createOperation(CompletionConfig.allSuccessful()); - var childOp = op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES); + var childOp = op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); assertNotNull(childOp); assertEquals(OperationSubType.PARALLEL_BRANCH, childOp.getSubType()); } @Test - void branchCreation_multipleBranchesAllCreated() throws Exception { + void branchCreation_multipleBranchesAllCreated() { var op = createOperation(CompletionConfig.allSuccessful()); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); - op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES); - op.addItem("branch-3", ctx -> "r3", TypeToken.get(String.class), SER_DES); + op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.addItem("branch-3", ctx -> "r3", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); assertEquals(3, op.getTotalItems()); } @@ -152,7 +152,7 @@ void branchCreation_childOperationHasParentReference() throws Exception { var op = createOperation(CompletionConfig.allSuccessful()); // The child operation should be a ChildContextOperation with this op as parent - var childOp = op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES); + var childOp = op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); assertNotNull(childOp); // Verify it's a ChildContextOperation (the concrete type returned by createItem) @@ -186,8 +186,8 @@ void allBranchesSucceed_sendsSucceedCheckpointAndReturnsCorrectResult() throws E var op = createOperation(CompletionConfig.allSuccessful()); setOperationIdGenerator(op, mockIdGenerator); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); - op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES); + op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); var result = op.get(); @@ -216,7 +216,7 @@ void minSuccessful_completesWhenThresholdMetAndReturnsResult() throws Exception var op = createOperation(CompletionConfig.minSuccessful(1)); setOperationIdGenerator(op, mockIdGenerator); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); + op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); var result = op.get(); @@ -236,7 +236,7 @@ void contextHierarchy_branchesUseParallelContextAsParent() throws Exception { // as their parent — not some other context var op = createOperation(CompletionConfig.allSuccessful()); - var childOp = op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES); + var childOp = op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); // The child operation should be registered in the execution manager // (BaseDurableOperation constructor calls executionManager.registerOperation) @@ -280,8 +280,8 @@ void replay_fromStartedState_sendsSucceedCheckpointAndReturnsResult() throws Exc var op = createOperation(CompletionConfig.allSuccessful()); setOperationIdGenerator(op, mockIdGenerator); op.execute(); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); - op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES); + op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); var result = op.get(); @@ -329,8 +329,8 @@ void replay_fromSucceededState_skipsCheckpointAndReturnsResult() throws Exceptio var op = createOperation(CompletionConfig.allSuccessful()); setOperationIdGenerator(op, mockIdGenerator); op.execute(); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); - op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES); + op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); var result = op.get(); @@ -365,7 +365,7 @@ void branchFailure_sendsSucceedCheckpointAndReturnsFailureCounts() throws Except throw new RuntimeException("branch failed"); }, TypeToken.get(String.class), - SER_DES); + SER_DES, OperationSubType.PARALLEL_BRANCH); var result = assertDoesNotThrow(() -> op.get()); @@ -402,14 +402,14 @@ void get_someBranchesFail_returnsCorrectCountsAndFailureStatus() throws Exceptio // toleratedFailureCount=1 so the operation completes after both branches finish var op = createOperation(CompletionConfig.toleratedFailureCount(1)); setOperationIdGenerator(op, mockIdGenerator); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); + op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); op.addItem( "branch-2", ctx -> { throw new RuntimeException("branch failed"); }, TypeToken.get(String.class), - SER_DES); + SER_DES, OperationSubType.PARALLEL_BRANCH); var result = op.get(); From b50318a3ff2c0457f0e724f8eaa0ccd436bb2b65 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sat, 21 Mar 2026 23:43:00 -0700 Subject: [PATCH 08/15] cleanup map operation --- .../operation/ConcurrencyOperation.java | 6 +- .../durable/operation/MapOperation.java | 61 ++++++------------- .../operation/ParallelOperationTest.java | 15 +++-- 3 files changed, 35 insertions(+), 47 deletions(-) diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java index f78ffae3e..9b6974684 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java @@ -125,7 +125,11 @@ protected ChildContextOperation createItem( * @return the created ChildContextOperation */ public ChildContextOperation addItem( - String name, Function function, TypeToken resultType, SerDes serDes, OperationSubType branchSubType) { + String name, + Function function, + TypeToken resultType, + SerDes serDes, + OperationSubType branchSubType) { var childOp = enqueueItem(name, function, resultType, serDes, branchSubType); executeNextItemIfAllowed(); return childOp; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java index 69db319b7..78460f5f2 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java @@ -41,7 +41,6 @@ public class MapOperation extends ConcurrencyOperation> { private final SerDes serDes; private boolean replayFromPayload; private volatile MapResult cachedResult; - private ConcurrencyCompletionStatus completionStatus; public MapOperation( OperationIdentifier operationIdentifier, @@ -90,6 +89,7 @@ protected void start() { sendOperationUpdateAsync(OperationUpdate.builder() .action(OperationAction.START) .subType(getSubType().getValue())); + executeNextItemIfAllowed(); } @@ -135,16 +135,27 @@ private void addAllItems() { } } + @SuppressWarnings("unchecked") @Override protected void handleSuccess(ConcurrencyCompletionStatus concurrencyCompletionStatus) { - this.completionStatus = concurrencyCompletionStatus; - checkpointMapResult(); - } + var children = getBranches(); + var resultItems = new ArrayList>(Collections.nCopies(items.size(), null)); - private void checkpointMapResult() { - var result = aggregateResults(); - this.cachedResult = result; - var serialized = serializeResult(result); + for (int i = 0; i < children.size(); i++) { + var branch = (ChildContextOperation) children.get(i); + if (!branch.isOperationCompleted()) { + resultItems.set(i, MapResult.MapResultItem.skipped()); + } else { + try { + resultItems.set(i, MapResult.MapResultItem.succeeded(branch.get())); + } catch (Exception e) { + resultItems.set(i, MapResult.MapResultItem.failed(MapResult.MapError.of(e))); + } + } + } + + this.cachedResult = new MapResult<>(resultItems, concurrencyCompletionStatus); + var serialized = serializeResult(cachedResult); var serializedBytes = serialized.getBytes(java.nio.charset.StandardCharsets.UTF_8); if (serializedBytes.length < LARGE_RESULT_THRESHOLD) { @@ -173,38 +184,6 @@ public MapResult get() { } // First execution or large result replay: wait for children, then aggregate join(); - return cachedResult != null ? cachedResult : aggregateResults(); - } - - /** - * Aggregates results from completed branches into a {@code MapResult}. - * - *

    Called after all branches have completed. At this point every branch's {@code completionFuture} is already - * done, so {@code branch.get()} returns immediately without blocking. - */ - @SuppressWarnings("unchecked") - private MapResult aggregateResults() { - var children = getBranches(); - var resultItems = new ArrayList>(Collections.nCopies(items.size(), null)); - - for (int i = 0; i < children.size(); i++) { - var branch = (ChildContextOperation) children.get(i); - if (!branch.isOperationCompleted()) { - resultItems.set(i, MapResult.MapResultItem.skipped()); - continue; - } - try { - resultItems.set(i, MapResult.MapResultItem.succeeded(branch.get())); - } catch (Exception e) { - resultItems.set(i, MapResult.MapResultItem.failed(MapResult.MapError.of(e))); - } - } - - // Fill any remaining null slots (items beyond children size) with skipped - for (int i = children.size(); i < items.size(); i++) { - resultItems.set(i, MapResult.MapResultItem.skipped()); - } - - return new MapResult<>(resultItems, completionStatus); + return cachedResult; } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java index 0003d83a8..7222fd7f9 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java @@ -130,7 +130,8 @@ private void setOperationIdGenerator(ConcurrencyOperation op, OperationIdGene void branchCreation_createsBranchWithParallelBranchSubType() { var op = createOperation(CompletionConfig.allSuccessful()); - var childOp = op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + var childOp = op.addItem( + "branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); assertNotNull(childOp); assertEquals(OperationSubType.PARALLEL_BRANCH, childOp.getSubType()); @@ -152,7 +153,8 @@ void branchCreation_childOperationHasParentReference() throws Exception { var op = createOperation(CompletionConfig.allSuccessful()); // The child operation should be a ChildContextOperation with this op as parent - var childOp = op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + var childOp = op.addItem( + "branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); assertNotNull(childOp); // Verify it's a ChildContextOperation (the concrete type returned by createItem) @@ -236,7 +238,8 @@ void contextHierarchy_branchesUseParallelContextAsParent() throws Exception { // as their parent — not some other context var op = createOperation(CompletionConfig.allSuccessful()); - var childOp = op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + var childOp = op.addItem( + "branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); // The child operation should be registered in the execution manager // (BaseDurableOperation constructor calls executionManager.registerOperation) @@ -365,7 +368,8 @@ void branchFailure_sendsSucceedCheckpointAndReturnsFailureCounts() throws Except throw new RuntimeException("branch failed"); }, TypeToken.get(String.class), - SER_DES, OperationSubType.PARALLEL_BRANCH); + SER_DES, + OperationSubType.PARALLEL_BRANCH); var result = assertDoesNotThrow(() -> op.get()); @@ -409,7 +413,8 @@ void get_someBranchesFail_returnsCorrectCountsAndFailureStatus() throws Exceptio throw new RuntimeException("branch failed"); }, TypeToken.get(String.class), - SER_DES, OperationSubType.PARALLEL_BRANCH); + SER_DES, + OperationSubType.PARALLEL_BRANCH); var result = op.get(); From 878a71d9443fe0ef439b74540371ff8b45ba00c7 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sun, 22 Mar 2026 08:56:46 -0700 Subject: [PATCH 09/15] minor cleanups --- .../operation/ConcurrencyOperation.java | 41 ++++++------------- .../durable/operation/MapOperation.java | 5 +++ .../durable/operation/ParallelOperation.java | 12 +++++- .../operation/ConcurrencyOperationTest.java | 10 +++-- .../operation/ParallelOperationTest.java | 2 +- 5 files changed, 36 insertions(+), 34 deletions(-) diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java index 9b6974684..bef8c3a48 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java @@ -59,7 +59,6 @@ public abstract class ConcurrencyOperation extends BaseDurableOperation { private final Set completedOperations = Collections.synchronizedSet(new HashSet<>()); private final OperationIdGenerator operationIdGenerator; private final DurableContextImpl rootContext; - private ConcurrencyCompletionStatus completionStatus; protected ConcurrencyOperation( OperationIdentifier operationIdentifier, @@ -86,7 +85,7 @@ protected ConcurrencyOperation( * @param name the name of this item * @param function the user function to execute * @param resultType the result type token - * @param serDes the serializer/deserializer + * @param branchSubType the sub-type of the branch operation * @param parentContext the parent durable context * @param the result type of the child operation * @return a new ChildContextOperation @@ -206,9 +205,9 @@ public void onItemComplete(ChildContextOperation child) { } runningCount.decrementAndGet(); - this.completionStatus = canComplete(); - if (this.completionStatus != null) { - handleComplete(this.completionStatus); + var completionStatus = canComplete(); + if (completionStatus != null) { + handleComplete(completionStatus); } else { executeNextItemIfAllowed(); } @@ -223,9 +222,9 @@ public void onItemComplete(ChildContextOperation child) { * @throws IllegalArgumentException if the item count cannot satisfy the criteria */ protected void validateItemCount() { - if (minSuccessful != null && minSuccessful > getTotalItems()) { + if (minSuccessful != null && minSuccessful > branches.size()) { throw new IllegalStateException("minSuccessful (" + minSuccessful - + ") exceeds the number of registered items (" + getTotalItems() + ")"); + + ") exceeds the number of registered items (" + branches.size() + ")"); } } @@ -235,8 +234,8 @@ protected void validateItemCount() { * @return the completion status if the operation is complete, or null if it should continue */ protected ConcurrencyCompletionStatus canComplete() { - int succeeded = getSucceededCount(); - int failed = getFailedCount(); + int succeeded = succeededCount.get(); + int failed = failedCount.get(); // If we've met the minimum successful count, we're done if (minSuccessful != null && succeeded >= minSuccessful) { @@ -273,31 +272,17 @@ public void join() { validateItemCount(); isJoined.set(true); synchronized (this) { - this.completionStatus = canComplete(); - if (this.completionStatus != null) { - handleComplete(this.completionStatus); + if (!isOperationCompleted()) { + var completionStatus = canComplete(); + if (completionStatus != null) { + handleComplete(completionStatus); + } } } waitForOperationCompletion(); } - protected int getSucceededCount() { - return succeededCount.get(); - } - - protected int getFailedCount() { - return failedCount.get(); - } - - protected int getTotalItems() { - return branches.size(); - } - - protected ConcurrencyCompletionStatus getCompletionStatus() { - return completionStatus; - } - protected List> getBranches() { return branches; } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java index 78460f5f2..c5188b88d 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java @@ -57,6 +57,11 @@ public MapOperation( config.maxConcurrency(), config.completionConfig().minSuccessful(), getToleratedFailureCount(config.completionConfig(), items.size())); + if (config.completionConfig().minSuccessful() != null + && config.completionConfig().minSuccessful() > items.size()) { + throw new IllegalArgumentException("minSuccessful cannot be greater than total items: " + + config.completionConfig().minSuccessful() + " > " + items.size()); + } this.items = List.copyOf(items); this.function = function; this.itemResultType = itemResultType; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java index d1cb21682..7b28766ad 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java @@ -6,6 +6,7 @@ import software.amazon.awssdk.services.lambda.model.ContextOptions; import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationAction; +import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationUpdate; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; @@ -46,6 +47,7 @@ public class ParallelOperation extends ConcurrencyOperation impl // this field could be written and read in different threads private volatile boolean skipCheckpoint = false; + private volatile ParallelResult cachedResult; public ParallelOperation( OperationIdentifier operationIdentifier, @@ -64,6 +66,14 @@ public ParallelOperation( @Override protected void handleSuccess(ConcurrencyCompletionStatus concurrencyCompletionStatus) { + var items = getBranches(); + int succeededCount = Math.toIntExact(items.stream() + .filter(item -> item.getOperation().status() == OperationStatus.SUCCEEDED) + .count()); + int failedCount = Math.toIntExact(items.stream() + .filter(item -> item.getOperation().status() != OperationStatus.SUCCEEDED) + .count()); + this.cachedResult = new ParallelResult(items.size(), succeededCount, failedCount, concurrencyCompletionStatus); if (skipCheckpoint) { // Do not send checkpoint during replay markAlreadyCompleted(); @@ -92,7 +102,7 @@ protected void replay(Operation existing) { @Override public ParallelResult get() { join(); - return new ParallelResult(getTotalItems(), getSucceededCount(), getFailedCount(), getCompletionStatus()); + return cachedResult; } /** Calls {@link #get()} if not already called. Guarantees that the context is closed. */ diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java index 9674f1659..d06d23bb6 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java @@ -155,8 +155,9 @@ void allChildrenAlreadySucceed_callsHandleSuccess() throws Exception { assertTrue(op.isSuccessHandled()); assertFalse(op.isFailureHandled()); - assertEquals(2, op.getSucceededCount()); - assertEquals(0, op.getFailedCount()); + var items = op.getBranches(); + assertEquals(2, items.size()); + assertTrue(items.stream().allMatch(b -> b.getOperation().status().equals(OperationStatus.SUCCEEDED))); assertFalse(functionCalled.get(), "Functions should not be called during SUCCEEDED replay"); } @@ -188,8 +189,9 @@ void singleChildAlreadySucceeds_fullCycle() throws Exception { op.exposedJoin(); assertTrue(op.isSuccessHandled()); - assertEquals(1, op.getSucceededCount()); - assertEquals(0, op.getFailedCount()); + var items = op.getBranches(); + assertEquals(1, items.size()); + assertEquals(OperationStatus.SUCCEEDED, items.get(0).getOperation().status()); assertFalse(functionCalled.get(), "Function should not be called during SUCCEEDED replay"); } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java index 7222fd7f9..9f52cf4c1 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java @@ -145,7 +145,7 @@ void branchCreation_multipleBranchesAllCreated() { op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); op.addItem("branch-3", ctx -> "r3", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); - assertEquals(3, op.getTotalItems()); + assertEquals(3, op.getBranches().size()); } @Test From af82eeab81644c6e687dff06ef76361dd26d4254 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sun, 22 Mar 2026 11:15:17 -0700 Subject: [PATCH 10/15] create a thread for concurrency operation --- .../amazon/lambda/durable/CallbackConfig.java | 2 + .../amazon/lambda/durable/StepConfig.java | 3 +- .../lambda/durable/WaitForCallbackConfig.java | 2 + .../durable/execution/ExecutionManager.java | 5 +- .../operation/BaseDurableOperation.java | 114 ++--------- .../durable/operation/CallbackOperation.java | 5 +- .../operation/ChildContextOperation.java | 11 +- .../operation/ConcurrencyOperation.java | 125 +++++++----- .../durable/operation/InvokeOperation.java | 2 +- .../durable/operation/MapOperation.java | 42 ++-- .../durable/operation/ParallelOperation.java | 5 +- .../SerializableDurableOperation.java | 183 ++++++++++++++++++ .../durable/operation/StepOperation.java | 2 +- .../operation/WaitForConditionOperation.java | 2 +- .../durable/operation/WaitOperation.java | 10 +- .../lambda/durable/DurableFutureTest.java | 8 +- .../operation/ConcurrencyOperationTest.java | 8 +- .../operation/ParallelOperationTest.java | 42 ++-- ... => SerializableDurableOperationTest.java} | 54 +++--- 19 files changed, 378 insertions(+), 247 deletions(-) create mode 100644 sdk/src/main/java/software/amazon/lambda/durable/operation/SerializableDurableOperation.java rename sdk/src/test/java/software/amazon/lambda/durable/operation/{BaseDurableOperationTest.java => SerializableDurableOperationTest.java} (85%) diff --git a/sdk/src/main/java/software/amazon/lambda/durable/CallbackConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/CallbackConfig.java index 158228fd4..0d63a5c29 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/CallbackConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/CallbackConfig.java @@ -3,8 +3,10 @@ package software.amazon.lambda.durable; /** @deprecated use {@link software.amazon.lambda.durable.config.CallbackConfig} instead. */ +@Deprecated public class CallbackConfig { /** @deprecated use {@link software.amazon.lambda.durable.config.CallbackConfig#builder()} instead. */ + @Deprecated public static software.amazon.lambda.durable.config.CallbackConfig.Builder builder() { return new software.amazon.lambda.durable.config.CallbackConfig.Builder(null, null, null); } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/StepConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/StepConfig.java index 997ce75dd..fff17bf90 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/StepConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/StepConfig.java @@ -10,13 +10,14 @@ * * @deprecated use {@link software.amazon.lambda.durable.config.StepConfig} */ +@Deprecated public class StepConfig { /** * Creates a new builder for StepConfig. * - * @return a new Builder instance * @deprecated use {@link software.amazon.lambda.durable.config.StepConfig#builder} */ + @Deprecated public static software.amazon.lambda.durable.config.StepConfig.Builder builder() { return new software.amazon.lambda.durable.config.StepConfig.Builder(null, null, null); } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/WaitForCallbackConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/WaitForCallbackConfig.java index fd7096171..16c9a7da3 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/WaitForCallbackConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/WaitForCallbackConfig.java @@ -3,8 +3,10 @@ package software.amazon.lambda.durable; /** @deprecated use {@link software.amazon.lambda.durable.config.WaitForCallbackConfig} instead. */ +@Deprecated public class WaitForCallbackConfig { /** @deprecated use {@link software.amazon.lambda.durable.config.WaitForCallbackConfig#builder()} instead. */ + @Deprecated public static software.amazon.lambda.durable.config.WaitForCallbackConfig.Builder builder() { return new software.amazon.lambda.durable.config.WaitForCallbackConfig.Builder(); } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java b/sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java index 774cfa022..97d516ee1 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java @@ -56,8 +56,7 @@ public class ExecutionManager implements AutoCloseable { private final AtomicReference executionMode; // ===== Thread Coordination ===== - private final Map> registeredOperations = - Collections.synchronizedMap(new HashMap<>()); + private final Map registeredOperations = Collections.synchronizedMap(new HashMap<>()); private final Set activeThreads = Collections.synchronizedSet(new HashSet<>()); private static final ThreadLocal currentThreadContext = new ThreadLocal<>(); private final CompletableFuture executionExceptionFuture = new CompletableFuture<>(); @@ -107,7 +106,7 @@ public boolean isReplaying() { } /** Registers an operation so it can receive checkpoint completion notifications. */ - public void registerOperation(BaseDurableOperation operation) { + public void registerOperation(BaseDurableOperation operation) { registeredOperations.put(operation.getOperationId(), operation); } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java index f891f5e9b..21ac06349 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java @@ -8,24 +8,18 @@ import java.util.concurrent.CompletableFuture; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import software.amazon.awssdk.services.lambda.model.ErrorObject; import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.awssdk.services.lambda.model.OperationUpdate; -import software.amazon.lambda.durable.DurableFuture; -import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.IllegalDurableOperationException; import software.amazon.lambda.durable.exception.NonDeterministicExecutionException; -import software.amazon.lambda.durable.exception.SerDesException; import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException; import software.amazon.lambda.durable.execution.ExecutionManager; import software.amazon.lambda.durable.execution.ThreadContext; import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.model.OperationSubType; -import software.amazon.lambda.durable.serde.SerDes; -import software.amazon.lambda.durable.util.ExceptionHelper; /** * Base class for all durable operations (STEP, WAIT, etc.). @@ -45,34 +39,24 @@ *

  • Proper thread coordination via future * */ -public abstract class BaseDurableOperation implements DurableFuture { +public abstract class BaseDurableOperation { private static final Logger logger = LoggerFactory.getLogger(BaseDurableOperation.class); private final OperationIdentifier operationIdentifier; - private final ExecutionManager executionManager; - private final TypeToken resultTypeToken; - private final SerDes resultSerDes; - protected final CompletableFuture completionFuture; + protected final ExecutionManager executionManager; + protected final CompletableFuture completionFuture; private final DurableContextImpl durableContext; /** * Constructs a new durable operation. * * @param operationIdentifier the unique identifier for this operation - * @param resultTypeToken the type token for deserializing the result - * @param resultSerDes the serializer/deserializer for the result * @param durableContext the parent context this operation belongs to */ - protected BaseDurableOperation( - OperationIdentifier operationIdentifier, - TypeToken resultTypeToken, - SerDes resultSerDes, - DurableContextImpl durableContext) { + protected BaseDurableOperation(OperationIdentifier operationIdentifier, DurableContextImpl durableContext) { this.operationIdentifier = operationIdentifier; this.durableContext = durableContext; this.executionManager = durableContext.getExecutionManager(); - this.resultTypeToken = resultTypeToken; - this.resultSerDes = resultSerDes; this.completionFuture = new CompletableFuture<>(); @@ -80,6 +64,10 @@ protected BaseDurableOperation( executionManager.registerOperation(this); } + protected CompletableFuture getCompletionFuture() { + return completionFuture; + } + /** Gets the operation sub-type (e.g. RUN_IN_CHILD_CONTEXT, WAIT_FOR_CALLBACK). */ public OperationSubType getSubType() { return operationIdentifier.subType(); @@ -214,7 +202,7 @@ protected Operation waitForOperationCompletion() { // Get result based on status var op = getOperation(); if (op == null) { - terminateExecutionWithIllegalDurableOperationException( + throw terminateExecutionWithIllegalDurableOperationException( String.format("%s operation not found: %s", getType(), getOperationId())); } return op; @@ -238,7 +226,7 @@ public void onCheckpointComplete(Operation operation) { // to the future. In our case, other contexts may have attached a function to reactivate themselves, // so they will definitely have a chance to reactivate before we finish completing and deactivating // whatever operations were just checkpointed. - completionFuture.complete(null); + completionFuture.complete(this); } } } @@ -252,7 +240,7 @@ protected void markAlreadyCompleted() { // It's important that we synchronize access to the future, otherwise the processing could happen // on someone else's thread and cause a race condition. synchronized (completionFuture) { - completionFuture.complete(null); + completionFuture.complete(this); } } @@ -262,7 +250,7 @@ protected void markAlreadyCompleted() { * @param exception the unrecoverable exception * @return never returns normally; always throws */ - protected T terminateExecution(UnrecoverableDurableExecutionException exception) { + protected RuntimeException terminateExecution(UnrecoverableDurableExecutionException exception) { executionManager.terminateExecution(exception); // Exception is already thrown from above. Keep the throw statement below to make tests happy throw exception; @@ -274,7 +262,7 @@ protected T terminateExecution(UnrecoverableDurableExecutionException exception) * @param message the error message * @return never returns normally; always throws */ - protected T terminateExecutionWithIllegalDurableOperationException(String message) { + protected RuntimeException terminateExecutionWithIllegalDurableOperationException(String message) { return terminateExecution(new IllegalDurableOperationException(message)); } @@ -322,82 +310,6 @@ protected CompletableFuture sendOperationUpdateAsync(OperationUpdate.Build return executionManager.sendOperationUpdate(updateBuilder.build()); } - /** - * Deserializes a result string into the operation's result type. - * - * @param result the serialized result string - * @return the deserialized result - * @throws SerDesException if deserialization fails - */ - protected T deserializeResult(String result) { - try { - return resultSerDes.deserialize(result, resultTypeToken); - } catch (SerDesException e) { - logger.warn( - "Failed to deserialize {} result for operation name '{}'. Ensure the result is properly encoded.", - getType(), - getName()); - throw e; - } - } - - /** - * Serializes the result to a string. - * - * @param result the result to serialize - * @return the serialized string - */ - protected String serializeResult(T result) { - return resultSerDes.serialize(result); - } - - /** - * Serializes a throwable into an {@link ErrorObject} for checkpointing. - * - * @param throwable the exception to serialize - * @return the serialized error object - */ - protected ErrorObject serializeException(Throwable throwable) { - return ExceptionHelper.buildErrorObject(throwable, resultSerDes); - } - - /** - * Deserializes an {@link ErrorObject} back into a throwable, reconstructing the original exception type and stack - * trace when possible. Falls back to null if the exception class is not found or deserialization fails. - * - * @param errorObject the serialized error object - * @return the reconstructed throwable, or null if reconstruction is not possible - */ - protected Throwable deserializeException(ErrorObject errorObject) { - Throwable original = null; - if (errorObject == null) { - return original; - } - var errorType = errorObject.errorType(); - var errorData = errorObject.errorData(); - - if (errorType == null) { - return original; - } - try { - - Class exceptionClass = Class.forName(errorType); - if (Throwable.class.isAssignableFrom(exceptionClass)) { - original = - resultSerDes.deserialize(errorData, TypeToken.get(exceptionClass.asSubclass(Throwable.class))); - - if (original != null) { - original.setStackTrace(ExceptionHelper.deserializeStackTrace(errorObject.stackTrace())); - } - } - } catch (ClassNotFoundException e) { - logger.warn("Cannot re-construct original exception type. Falling back to generic StepFailedException."); - } catch (SerDesException e) { - logger.warn("Cannot deserialize original exception data. Falling back to generic StepFailedException.", e); - } - return original; - } - /** Validates that current operation matches checkpointed operation during replay. */ protected void validateReplay(Operation checkpointed) { if (checkpointed == null || checkpointed.type() == null) { diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java index a9b8d06a0..40f921320 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java @@ -15,7 +15,7 @@ import software.amazon.lambda.durable.model.OperationIdentifier; /** Durable operation for creating and waiting on external callbacks. */ -public class CallbackOperation extends BaseDurableOperation implements DurableCallbackFuture { +public class CallbackOperation extends SerializableDurableOperation implements DurableCallbackFuture { private final CallbackConfig config; @@ -80,7 +80,8 @@ public T get() { case FAILED -> throw new CallbackFailedException(op); case TIMED_OUT -> throw new CallbackTimeoutException(op); default -> - terminateExecutionWithIllegalDurableOperationException("Unexpected callback status: " + op.status()); + throw terminateExecutionWithIllegalDurableOperationException( + "Unexpected callback status: " + op.status()); }; } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java index b25173598..e03c3709c 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java @@ -41,7 +41,7 @@ * on completion via {@code onItemComplete()} BEFORE closing its own child context. It also skips checkpointing if the * parent operation has already succeeded. */ -public class ChildContextOperation extends BaseDurableOperation { +public class ChildContextOperation extends SerializableDurableOperation { private static final int LARGE_RESULT_THRESHOLD = 256 * 1024; @@ -98,7 +98,7 @@ protected void replay(Operation existing) { case FAILED -> markAlreadyCompleted(); case STARTED -> executeChildContext(); default -> - terminateExecutionWithIllegalDurableOperationException( + throw terminateExecutionWithIllegalDurableOperationException( "Unexpected child context status: " + existing.status()); } } @@ -106,9 +106,6 @@ protected void replay(Operation existing) { @Override protected void markAlreadyCompleted() { super.markAlreadyCompleted(); - if (parentOperation != null) { - parentOperation.onItemComplete(this); - } } private void executeChildContext() { @@ -142,10 +139,6 @@ private void executeChildContext() { handleChildContextSuccess(result); } catch (Throwable e) { handleChildContextFailure(e); - } finally { - if (parentOperation != null) { - parentOperation.onItemComplete(this); - } } } }; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java index bef8c3a48..57ce20a67 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java @@ -6,8 +6,11 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Queue; import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -43,7 +46,7 @@ * * @param the result type of this operation */ -public abstract class ConcurrencyOperation extends BaseDurableOperation { +public abstract class ConcurrencyOperation extends SerializableDurableOperation { private static final Logger logger = LoggerFactory.getLogger(ConcurrencyOperation.class); @@ -52,13 +55,21 @@ public abstract class ConcurrencyOperation extends BaseDurableOperation { private final Integer toleratedFailureCount; private final AtomicInteger succeededCount = new AtomicInteger(0); private final AtomicInteger failedCount = new AtomicInteger(0); - private final AtomicInteger runningCount = new AtomicInteger(0); protected final AtomicBoolean isJoined = new AtomicBoolean(false); private final Queue> pendingQueue = new ConcurrentLinkedDeque<>(); private final List> branches = Collections.synchronizedList(new ArrayList<>()); + private final Map, Boolean> runningChildren = new ConcurrentHashMap<>(); private final Set completedOperations = Collections.synchronizedSet(new HashSet<>()); private final OperationIdGenerator operationIdGenerator; private final DurableContextImpl rootContext; + private volatile CompletableFuture vacancyListener; + + private record NewBranchItem( + String name, + Function function, + TypeToken resultType, + SerDes serDes, + OperationSubType branchSubType) {} protected ConcurrencyOperation( OperationIdentifier operationIdentifier, @@ -113,31 +124,9 @@ protected ChildContextOperation createItem( // ========== Concurrency control ========== /** - * Adds a new item to this concurrency operation. Creates the child operation and either starts it immediately or - * enqueues it if maxConcurrency is reached. - * - * @param name the name of the item - * @param function the user function to execute - * @param resultType the result type token - * @param serDes the serializer/deserializer - * @param the result type of the child operation - * @return the created ChildContextOperation - */ - public ChildContextOperation addItem( - String name, - Function function, - TypeToken resultType, - SerDes serDes, - OperationSubType branchSubType) { - var childOp = enqueueItem(name, function, resultType, serDes, branchSubType); - executeNextItemIfAllowed(); - return childOp; - } - - /** - * Creates and enqueues an item without starting execution. Use {@link #executeNextItemIfAllowed()} to begin - * execution after all items have been enqueued. This prevents early termination from blocking item creation when - * all items are known upfront (e.g., map operations). + * Creates and enqueues an item without starting execution. Use {@link #executeItems()} to begin execution after all + * items have been enqueued. This prevents early termination from blocking item creation when all items are known + * upfront (e.g., map operations). */ protected ChildContextOperation enqueueItem( String name, @@ -150,26 +139,69 @@ protected ChildContextOperation enqueueItem( branches.add(childOp); pendingQueue.add(childOp); logger.debug("Item enqueued {}", name); + if (vacancyListener != null) { + vacancyListener.complete(null); + } return childOp; } - /** - * Starts the queued items if the running count is below maxConcurrency and the operation hasn't completed yet. Must - * be called within {@code synchronized (pendingQueue)}. - */ - protected void executeNextItemIfAllowed() { + protected void executeItems() { // Start as many items as concurrency allows - while (true) { - synchronized (this) { - if (isOperationCompleted()) return; - if (runningCount.get() >= maxConcurrency) return; - var next = pendingQueue.poll(); - if (next == null) return; - runningCount.incrementAndGet(); - logger.debug("Executing operation {}", next.getName()); - next.execute(); + var contextId = getOperationId(); + registerActiveThread(contextId); + + Runnable handler = () -> { + try (var context = getContext().createChildContext(contextId, getName())) { + while (true) { + if (isOperationCompleted()) { + return; + } + while (runningChildren.size() < maxConcurrency) { + if (vacancyListener != null && vacancyListener.isDone()) { + vacancyListener = null; + } + var next = pendingQueue.poll(); + if (next == null) { + break; + } + runningChildren.put(next, true); + logger.debug("Executing operation {}", next.getName()); + next.execute(); + } + var child = waitForChildCompletion(); + if (runningChildren.containsKey(child)) { + onItemComplete((ChildContextOperation) child); + } + } + } + }; + CompletableFuture.runAsync(handler, getContext().getDurableConfig().getExecutorService()); + } + + private BaseDurableOperation waitForChildCompletion() { + var threadContext = getCurrentThreadContext(); + CompletableFuture future; + + synchronized (this) { + ArrayList> futures; + futures = new ArrayList<>(runningChildren.keySet().stream() + .map(BaseDurableOperation::getCompletionFuture) + .toList()); + if (futures.size() < maxConcurrency) { + vacancyListener = new CompletableFuture<>(); + futures.add(vacancyListener); + } + + // future will be completed immediately if any future of the list is already completed + future = CompletableFuture.anyOf(futures.toArray(CompletableFuture[]::new)); + // skip deregistering the current thread if there is more completed future to process + if (!future.isDone()) { + future.thenRun(() -> registerActiveThread(threadContext.threadId())); + // Deregister the current thread to allow suspension + executionManager.deregisterActiveThread(threadContext.threadId()); } } + return future.thenApply(o -> (BaseDurableOperation) o).join(); } /** @@ -180,7 +212,7 @@ protected void executeNextItemIfAllowed() { */ public void onItemComplete(ChildContextOperation child) { if (!completedOperations.add(child.getOperationId())) { - return; + throw new IllegalStateException("Child operation " + child.getOperationId() + " completed twice"); } // Evaluate child result outside the lock — child.get() may block waiting for a checkpoint response. @@ -203,13 +235,14 @@ public void onItemComplete(ChildContextOperation child) { } else { failedCount.incrementAndGet(); } - runningCount.decrementAndGet(); + if (!runningChildren.containsKey(child)) { + throw new IllegalStateException("Child operation " + child.getOperationId() + " completed twice"); + } + runningChildren.remove(child); var completionStatus = canComplete(); if (completionStatus != null) { handleComplete(completionStatus); - } else { - executeNextItemIfAllowed(); } } } @@ -289,6 +322,6 @@ protected List> getBranches() { /** Returns true if all items have finished (no pending, no running). Used by subclasses to override canComplete. */ protected boolean isAllItemsFinished() { - return isJoined.get() && pendingQueue.isEmpty() && runningCount.get() == 0; + return isJoined.get() && pendingQueue.isEmpty() && runningChildren.isEmpty(); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/InvokeOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/InvokeOperation.java index 05bedec6f..ac53e0cce 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/InvokeOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/InvokeOperation.java @@ -22,7 +22,7 @@ * @param the result type from the invoked function * @param the payload type sent to the invoked function */ -public class InvokeOperation extends BaseDurableOperation { +public class InvokeOperation extends SerializableDurableOperation { private final String functionName; private final I payload; private final InvokeConfig invokeConfig; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java index c5188b88d..474570cae 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java @@ -70,6 +70,23 @@ public MapOperation( addAllItems(); } + private void addAllItems() { + // Enqueue all items first, then start execution. This prevents early termination + // criteria (e.g., minSuccessful) from completing the operation mid-loop on replay, + // which would cause subsequent enqueue calls to fail with "completed operation". + var branchPrefix = getName() == null ? "map-iteration-" : getName() + "-iteration-"; + for (int i = 0; i < items.size(); i++) { + var index = i; + var item = items.get(i); + enqueueItem( + branchPrefix + i, + childCtx -> function.apply(item, index, childCtx), + itemResultType, + serDes, + OperationSubType.MAP_ITERATION); + } + } + private static Integer getToleratedFailureCount(CompletionConfig completionConfig, int totalItems) { if (completionConfig == null || (completionConfig.toleratedFailureCount() == null @@ -95,7 +112,7 @@ protected void start() { .action(OperationAction.START) .subType(getSubType().getValue())); - executeNextItemIfAllowed(); + executeItems(); } @Override @@ -105,7 +122,7 @@ protected void replay(Operation existing) { if (existing.contextDetails() != null && Boolean.TRUE.equals(existing.contextDetails().replayChildren())) { // Large result: re-execute children to reconstruct MapResult - executeNextItemIfAllowed(); + executeItems(); } else { // Small result: MapResult is in the payload, skip child replay replayFromPayload = true; @@ -115,31 +132,14 @@ protected void replay(Operation existing) { case STARTED -> { // Map was in progress when interrupted — re-create children without sending // another START (the backend rejects duplicate START for existing operations) - executeNextItemIfAllowed(); + executeItems(); } default -> - terminateExecutionWithIllegalDurableOperationException( + throw terminateExecutionWithIllegalDurableOperationException( "Unexpected map operation status: " + existing.status()); } } - private void addAllItems() { - // Enqueue all items first, then start execution. This prevents early termination - // criteria (e.g., minSuccessful) from completing the operation mid-loop on replay, - // which would cause subsequent enqueue calls to fail with "completed operation". - var branchPrefix = getName() == null ? "map-iteration-" : getName() + "-iteration-"; - for (int i = 0; i < items.size(); i++) { - var index = i; - var item = items.get(i); - enqueueItem( - branchPrefix + i, - childCtx -> function.apply(item, index, childCtx), - itemResultType, - serDes, - OperationSubType.MAP_ITERATION); - } - } - @SuppressWarnings("unchecked") @Override protected void handleSuccess(ConcurrencyCompletionStatus concurrencyCompletionStatus) { diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java index 7b28766ad..c47d5d695 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java @@ -90,6 +90,8 @@ protected void start() { sendOperationUpdateAsync(OperationUpdate.builder() .action(OperationAction.START) .subType(getSubType().getValue())); + + executeItems(); } @Override @@ -97,6 +99,7 @@ protected void replay(Operation existing) { // No-op: child branches handle their own replay via ChildContextOperation.replay(). // Set replaying=true so handleSuccess() skips re-checkpointing the already-completed parallel context. skipCheckpoint = ExecutionManager.isTerminalStatus(existing.status()); + executeItems(); } @Override @@ -117,6 +120,6 @@ public DurableFuture branch( throw new IllegalStateException("Cannot add branches after join() has been called"); } var serDes = config.serDes() == null ? getContext().getDurableConfig().getSerDes() : config.serDes(); - return addItem(name, func, resultType, serDes, OperationSubType.PARALLEL_BRANCH); + return enqueueItem(name, func, resultType, serDes, OperationSubType.PARALLEL_BRANCH); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/SerializableDurableOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/SerializableDurableOperation.java new file mode 100644 index 000000000..bcb964691 --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/SerializableDurableOperation.java @@ -0,0 +1,183 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.operation; + +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.lambda.model.ErrorObject; +import software.amazon.awssdk.services.lambda.model.Operation; +import software.amazon.lambda.durable.DurableFuture; +import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.context.DurableContextImpl; +import software.amazon.lambda.durable.exception.IllegalDurableOperationException; +import software.amazon.lambda.durable.exception.NonDeterministicExecutionException; +import software.amazon.lambda.durable.exception.SerDesException; +import software.amazon.lambda.durable.execution.ThreadType; +import software.amazon.lambda.durable.model.OperationIdentifier; +import software.amazon.lambda.durable.serde.SerDes; +import software.amazon.lambda.durable.util.ExceptionHelper; + +/** + * Base class for all durable operations (STEP, WAIT, etc.). + * + *

    Key methods: + * + *

      + *
    • {@code execute()} starts the operation (returns immediately) + *
    • {@code get()} blocks until complete and returns the result + *
    + * + *

    The separation allows: + * + *

      + *
    • Starting multiple async operations quickly + *
    • Blocking on results later when needed + *
    • Proper thread coordination via future + *
    + */ +public abstract class SerializableDurableOperation extends BaseDurableOperation implements DurableFuture { + private static final Logger logger = LoggerFactory.getLogger(SerializableDurableOperation.class); + + private final TypeToken resultTypeToken; + private final SerDes resultSerDes; + + /** + * Constructs a new durable operation. + * + * @param operationIdentifier the unique identifier for this operation + * @param resultTypeToken the type token for deserializing the result + * @param resultSerDes the serializer/deserializer for the result + * @param durableContext the parent context this operation belongs to + */ + protected SerializableDurableOperation( + OperationIdentifier operationIdentifier, + TypeToken resultTypeToken, + SerDes resultSerDes, + DurableContextImpl durableContext) { + super(operationIdentifier, durableContext); + this.resultTypeToken = resultTypeToken; + this.resultSerDes = resultSerDes; + } + + /** + * Checks if it's called from a Step. + * + * @throws IllegalDurableOperationException if it's in a step + */ + private void validateCurrentThreadType() { + ThreadType current = getCurrentThreadContext().threadType(); + if (current == ThreadType.STEP) { + var message = String.format( + "Nested %s operation is not supported on %s from within a %s execution.", + getType(), getName(), current); + // terminate execution and throw the exception + terminateExecutionWithIllegalDurableOperationException(message); + } + } + + /** + * Deserializes a result string into the operation's result type. + * + * @param result the serialized result string + * @return the deserialized result + * @throws SerDesException if deserialization fails + */ + protected T deserializeResult(String result) { + try { + return resultSerDes.deserialize(result, resultTypeToken); + } catch (SerDesException e) { + logger.warn( + "Failed to deserialize {} result for operation name '{}'. Ensure the result is properly encoded.", + getType(), + getName()); + throw e; + } + } + + /** + * Serializes the result to a string. + * + * @param result the result to serialize + * @return the serialized string + */ + protected String serializeResult(T result) { + return resultSerDes.serialize(result); + } + + /** + * Serializes a throwable into an {@link ErrorObject} for checkpointing. + * + * @param throwable the exception to serialize + * @return the serialized error object + */ + protected ErrorObject serializeException(Throwable throwable) { + return ExceptionHelper.buildErrorObject(throwable, resultSerDes); + } + + /** + * Deserializes an {@link ErrorObject} back into a throwable, reconstructing the original exception type and stack + * trace when possible. Falls back to null if the exception class is not found or deserialization fails. + * + * @param errorObject the serialized error object + * @return the reconstructed throwable, or null if reconstruction is not possible + */ + protected Throwable deserializeException(ErrorObject errorObject) { + Throwable original = null; + if (errorObject == null) { + return original; + } + var errorType = errorObject.errorType(); + var errorData = errorObject.errorData(); + + if (errorType == null) { + return original; + } + try { + + Class exceptionClass = Class.forName(errorType); + if (Throwable.class.isAssignableFrom(exceptionClass)) { + original = + resultSerDes.deserialize(errorData, TypeToken.get(exceptionClass.asSubclass(Throwable.class))); + + if (original != null) { + original.setStackTrace(ExceptionHelper.deserializeStackTrace(errorObject.stackTrace())); + } + } + } catch (ClassNotFoundException e) { + logger.warn("Cannot re-construct original exception type. Falling back to generic StepFailedException."); + } catch (SerDesException e) { + logger.warn("Cannot deserialize original exception data. Falling back to generic StepFailedException.", e); + } + return original; + } + + /** Validates that current operation matches checkpointed operation during replay. */ + protected void validateReplay(Operation checkpointed) { + if (checkpointed == null || checkpointed.type() == null) { + return; // First execution, no validation needed + } + + if (!checkpointed.type().equals(getType())) { + terminateExecution(new NonDeterministicExecutionException(String.format( + "Operation type mismatch for \"%s\". Expected %s, got %s", + getOperationId(), checkpointed.type(), getType()))); + } + + if (!Objects.equals(checkpointed.name(), getName())) { + terminateExecution(new NonDeterministicExecutionException(String.format( + "Operation name mismatch for \"%s\". Expected \"%s\", got \"%s\"", + getOperationId(), checkpointed.name(), getName()))); + } + + if ((getSubType() == null && checkpointed.subType() != null) + || getSubType() != null + && !Objects.equals(checkpointed.subType(), getSubType().getValue())) { + terminateExecution(new NonDeterministicExecutionException(String.format( + "Operation subType mismatch for \"%s\". Expected \"%s\", got \"%s\"", + getOperationId(), checkpointed.subType(), getSubType()))); + } + } + + public abstract T get(); +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/StepOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/StepOperation.java index 901766a32..dbac63256 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/StepOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/StepOperation.java @@ -34,7 +34,7 @@ * * @param the result type of the step function */ -public class StepOperation extends BaseDurableOperation { +public class StepOperation extends SerializableDurableOperation { private static final Integer FIRST_ATTEMPT = 0; private final Function function; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java index 2645ac2e6..0ac30a88c 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java @@ -34,7 +34,7 @@ * * @param the type of state being polled */ -public class WaitForConditionOperation extends BaseDurableOperation { +public class WaitForConditionOperation extends SerializableDurableOperation { private final BiFunction> checkFunc; private final WaitForConditionConfig config; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitOperation.java index e9493ba31..a51521071 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitOperation.java @@ -11,11 +11,9 @@ import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationUpdate; import software.amazon.awssdk.services.lambda.model.WaitOptions; -import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.model.OperationIdentifier; -import software.amazon.lambda.durable.serde.NoopSerDes; -import software.amazon.lambda.durable.serde.SerDes; /** * Durable operation that suspends execution for a specified duration without consuming compute. @@ -23,16 +21,15 @@ *

    The wait is checkpointed and the Lambda is suspended. On re-invocation after the wait period, execution resumes * from where it left off. */ -public class WaitOperation extends BaseDurableOperation { +public class WaitOperation extends BaseDurableOperation implements DurableFuture { private static final Logger logger = LoggerFactory.getLogger(WaitOperation.class); - private static final SerDes NOOP_SER_DES = new NoopSerDes(); private final Duration duration; public WaitOperation( OperationIdentifier operationIdentifier, Duration duration, DurableContextImpl durableContext) { - super(operationIdentifier, TypeToken.get(Void.class), NOOP_SER_DES, durableContext); + super(operationIdentifier, durableContext); this.duration = duration; } @@ -80,7 +77,6 @@ private void pollForWaitExpiration() { pollForOperationUpdates(remainingWaitTime); } - @Override public Void get() { waitForOperationCompletion(); diff --git a/sdk/src/test/java/software/amazon/lambda/durable/DurableFutureTest.java b/sdk/src/test/java/software/amazon/lambda/durable/DurableFutureTest.java index 7292269e4..91338f9b1 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/DurableFutureTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/DurableFutureTest.java @@ -7,7 +7,7 @@ import java.util.List; import org.junit.jupiter.api.Test; -import software.amazon.lambda.durable.operation.BaseDurableOperation; +import software.amazon.lambda.durable.operation.SerializableDurableOperation; class DurableFutureTest { @@ -63,15 +63,15 @@ void allOfSingleFutureReturnsSingleResult() { void allOfPropagatesException() { var op1 = mockOperation("first"); @SuppressWarnings("unchecked") - BaseDurableOperation op2 = mock(BaseDurableOperation.class); + SerializableDurableOperation op2 = mock(SerializableDurableOperation.class); when(op2.get()).thenThrow(new RuntimeException("Step failed")); assertThrows(RuntimeException.class, () -> DurableFuture.allOf(op1, op2)); } @SuppressWarnings("unchecked") - private BaseDurableOperation mockOperation(T result) { - BaseDurableOperation op = mock(BaseDurableOperation.class); + private SerializableDurableOperation mockOperation(T result) { + SerializableDurableOperation op = mock(SerializableDurableOperation.class); when(op.get()).thenReturn(result); return op; } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java index d06d23bb6..75395e2a9 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java @@ -132,16 +132,16 @@ void allChildrenAlreadySucceed_callsHandleSuccess() throws Exception { var functionCalled = new AtomicBoolean(false); var op = createOperation(CompletionConfig.allSuccessful()); - op.addItem( + op.enqueueItem( "branch-1", - ctx -> { + ctx1 -> { functionCalled.set(true); return "result-1"; }, TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); - op.addItem( + op.enqueueItem( "branch-2", ctx -> { functionCalled.set(true); @@ -176,7 +176,7 @@ void singleChildAlreadySucceeds_fullCycle() throws Exception { var functionCalled = new AtomicBoolean(false); var op = createOperation(CompletionConfig.minSuccessful(1)); - op.addItem( + op.enqueueItem( "only-branch", ctx -> { functionCalled.set(true); diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java index 9f52cf4c1..b30171a1a 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java @@ -73,9 +73,9 @@ void setUp() { when(executionManager.getOperationAndUpdateReplayState(anyString())).thenReturn(null); // Capture registered operations so we can drive onCheckpointComplete callbacks. - var registeredOps = new ConcurrentHashMap>(); + var registeredOps = new ConcurrentHashMap(); doAnswer(inv -> { - BaseDurableOperation op = inv.getArgument(0); + BaseDurableOperation op = inv.getArgument(0); registeredOps.put(op.getOperationId(), op); return null; }) @@ -130,7 +130,7 @@ private void setOperationIdGenerator(ConcurrencyOperation op, OperationIdGene void branchCreation_createsBranchWithParallelBranchSubType() { var op = createOperation(CompletionConfig.allSuccessful()); - var childOp = op.addItem( + var childOp = op.enqueueItem( "branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); assertNotNull(childOp); @@ -141,9 +141,11 @@ void branchCreation_createsBranchWithParallelBranchSubType() { void branchCreation_multipleBranchesAllCreated() { var op = createOperation(CompletionConfig.allSuccessful()); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); - op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); - op.addItem("branch-3", ctx -> "r3", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem( + "branch-1", ctx2 -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem( + "branch-2", ctx1 -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem("branch-3", ctx -> "r3", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); assertEquals(3, op.getBranches().size()); } @@ -153,7 +155,7 @@ void branchCreation_childOperationHasParentReference() throws Exception { var op = createOperation(CompletionConfig.allSuccessful()); // The child operation should be a ChildContextOperation with this op as parent - var childOp = op.addItem( + var childOp = op.enqueueItem( "branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); assertNotNull(childOp); @@ -188,8 +190,9 @@ void allBranchesSucceed_sendsSucceedCheckpointAndReturnsCorrectResult() throws E var op = createOperation(CompletionConfig.allSuccessful()); setOperationIdGenerator(op, mockIdGenerator); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); - op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem( + "branch-1", ctx1 -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); var result = op.get(); @@ -218,7 +221,7 @@ void minSuccessful_completesWhenThresholdMetAndReturnsResult() throws Exception var op = createOperation(CompletionConfig.minSuccessful(1)); setOperationIdGenerator(op, mockIdGenerator); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); var result = op.get(); @@ -238,7 +241,7 @@ void contextHierarchy_branchesUseParallelContextAsParent() throws Exception { // as their parent — not some other context var op = createOperation(CompletionConfig.allSuccessful()); - var childOp = op.addItem( + var childOp = op.enqueueItem( "branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); // The child operation should be registered in the execution manager @@ -283,8 +286,9 @@ void replay_fromStartedState_sendsSucceedCheckpointAndReturnsResult() throws Exc var op = createOperation(CompletionConfig.allSuccessful()); setOperationIdGenerator(op, mockIdGenerator); op.execute(); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); - op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem( + "branch-1", ctx1 -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); var result = op.get(); @@ -332,8 +336,9 @@ void replay_fromSucceededState_skipsCheckpointAndReturnsResult() throws Exceptio var op = createOperation(CompletionConfig.allSuccessful()); setOperationIdGenerator(op, mockIdGenerator); op.execute(); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); - op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem( + "branch-1", ctx1 -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); var result = op.get(); @@ -362,7 +367,7 @@ void branchFailure_sendsSucceedCheckpointAndReturnsFailureCounts() throws Except var op = createOperation(CompletionConfig.allSuccessful()); setOperationIdGenerator(op, mockIdGenerator); - op.addItem( + op.enqueueItem( "branch-1", ctx -> { throw new RuntimeException("branch failed"); @@ -406,8 +411,9 @@ void get_someBranchesFail_returnsCorrectCountsAndFailureStatus() throws Exceptio // toleratedFailureCount=1 so the operation completes after both branches finish var op = createOperation(CompletionConfig.toleratedFailureCount(1)); setOperationIdGenerator(op, mockIdGenerator); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); - op.addItem( + op.enqueueItem( + "branch-1", ctx1 -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem( "branch-2", ctx -> { throw new RuntimeException("branch failed"); diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/BaseDurableOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/SerializableDurableOperationTest.java similarity index 85% rename from sdk/src/test/java/software/amazon/lambda/durable/operation/BaseDurableOperationTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/operation/SerializableDurableOperationTest.java index 8df56e796..6ae8cc890 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/BaseDurableOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/SerializableDurableOperationTest.java @@ -39,7 +39,7 @@ import software.amazon.lambda.durable.serde.JacksonSerDes; import software.amazon.lambda.durable.serde.SerDes; -class BaseDurableOperationTest { +class SerializableDurableOperationTest { private static final String OPERATION_ID = "1"; private static final String CONTEXT_ID = "1-step"; @@ -67,8 +67,8 @@ void setUp() { @Test void getOperation() { - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() {} @@ -91,8 +91,8 @@ public String get() { @Test void waitForOperationCompletionThrowsIfOperationMissing() { when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)).thenReturn(null); - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() { markAlreadyCompleted(); @@ -115,8 +115,8 @@ public String get() { @Test void waitForOperationCompletionWhenRunningAndReadyToComplete() throws InterruptedException, ExecutionException, TimeoutException { - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() {} @@ -147,8 +147,8 @@ public String get() { @Test void waitForOperationCompletionWhenAlreadyCompleted() { - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() { markAlreadyCompleted(); @@ -171,8 +171,8 @@ public String get() { @Test void markAlreadyCompleted() { - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() { markAlreadyCompleted(); @@ -198,8 +198,8 @@ void validateReplayThrowsWhenTypeMismatch() { .thenReturn( Operation.builder().type(OperationType.CHAINED_INVOKE).build()); - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() { validateReplay(getOperation()); @@ -225,8 +225,8 @@ void validateReplayThrowsWhenNameMismatch() { .type(OPERATION_TYPE) .build()); - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() { validateReplay(getOperation()); @@ -248,8 +248,8 @@ public String get() { void validateReplayDoesNotThrowWhenNoOperation() { when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)).thenReturn(null); - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() { validateReplay(getOperation()); @@ -274,8 +274,8 @@ void validateReplayDoesNotThrowWhenNameAndTypeMatch() { .type(OPERATION_TYPE) .build()); - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() { validateReplay(getOperation()); @@ -294,8 +294,8 @@ public String get() { @Test void deserializeResult() { - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() {} @@ -315,8 +315,8 @@ public String get() { @Test void deserializeException() { - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() {} @@ -341,8 +341,8 @@ public String get() { @Test void polling() { - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() {} @@ -365,8 +365,8 @@ public String get() { void sendOperationUpdate() { var update = OperationUpdate.builder(); - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() {} From c3316dd03800e434eae62675ce82473e7783d929 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sun, 22 Mar 2026 16:34:04 -0700 Subject: [PATCH 11/15] fix race conditions in concurrency operation --- .../durable/context/BaseContextImpl.java | 48 ----- .../durable/context/DurableContextImpl.java | 36 +--- .../durable/context/StepContextImpl.java | 1 - .../durable/execution/DurableExecutor.java | 1 + .../operation/BaseDurableOperation.java | 55 ++++- .../durable/operation/CallbackOperation.java | 2 +- .../operation/ChildContextOperation.java | 17 +- .../operation/ConcurrencyOperation.java | 193 ++++++++---------- .../durable/operation/InvokeOperation.java | 2 +- .../SerializableDurableOperation.java | 32 +-- .../durable/operation/StepOperation.java | 15 +- .../operation/WaitForConditionOperation.java | 106 +++++----- .../durable/context/BaseContextImplTest.java | 50 ----- .../context/DurableContextImplTest.java | 41 ---- .../operation/ChildContextOperationTest.java | 134 ------------ .../operation/ConcurrencyOperationTest.java | 36 ++-- .../operation/ParallelOperationTest.java | 74 +++---- 17 files changed, 253 insertions(+), 590 deletions(-) diff --git a/sdk/src/main/java/software/amazon/lambda/durable/context/BaseContextImpl.java b/sdk/src/main/java/software/amazon/lambda/durable/context/BaseContextImpl.java index 2cee25d09..9920366f4 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/context/BaseContextImpl.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/context/BaseContextImpl.java @@ -5,8 +5,6 @@ import com.amazonaws.services.lambda.runtime.Context; import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.execution.ExecutionManager; -import software.amazon.lambda.durable.execution.SuspendExecutionException; -import software.amazon.lambda.durable.execution.ThreadContext; import software.amazon.lambda.durable.execution.ThreadType; public abstract class BaseContextImpl implements AutoCloseable, BaseContext { @@ -36,28 +34,6 @@ protected BaseContextImpl( String contextId, String contextName, ThreadType threadType) { - this(executionManager, durableConfig, lambdaContext, contextId, contextName, threadType, true); - } - - /** - * Creates a new BaseContext instance. - * - * @param executionManager the execution manager for thread coordination and state management - * @param durableConfig the durable execution configuration - * @param lambdaContext the AWS Lambda runtime context - * @param contextId the context ID, null for root context, set for child contexts - * @param contextName the human-readable name for this context - * @param threadType the type of thread this context runs on - * @param setCurrentThreadContext whether to call setCurrentThreadContext on the execution manager - */ - protected BaseContextImpl( - ExecutionManager executionManager, - DurableConfig durableConfig, - Context lambdaContext, - String contextId, - String contextName, - ThreadType threadType, - boolean setCurrentThreadContext) { this.executionManager = executionManager; this.durableConfig = durableConfig; this.lambdaContext = lambdaContext; @@ -65,11 +41,6 @@ protected BaseContextImpl( this.contextName = contextName; this.isReplaying = executionManager.hasOperationsForContext(contextId); this.threadType = threadType; - - if (setCurrentThreadContext) { - // write the thread id and type to thread local - executionManager.setCurrentThreadContext(new ThreadContext(contextId, threadType)); - } } // =============== accessors ================ @@ -138,23 +109,4 @@ public boolean isReplaying() { public void setExecutionMode() { this.isReplaying = false; } - - @Override - public void close() { - // this is called in the user thread, after the context's user code has completed - if (getContextId() != null) { - // if this is a child context or a step context, we need to - // deregister the context's thread from the execution manager - try { - executionManager.deregisterActiveThread(getContextId()); - } catch (SuspendExecutionException e) { - // Expected when this is the last active thread. Must catch here because: - // 1/ This runs in a worker thread detached from handlerFuture - // 2/ Uncaught exception would prevent stepAsync().get() from resume - // Suspension/Termination is already signaled via - // suspendExecutionFuture/terminateExecutionFuture - // before the throw. - } - } - } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java b/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java index 3e8e29c5c..ac2eb6af6 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java @@ -67,24 +67,7 @@ private DurableContextImpl( Context lambdaContext, String contextId, String contextName) { - this(executionManager, durableConfig, lambdaContext, contextId, contextName, true); - } - - private DurableContextImpl( - ExecutionManager executionManager, - DurableConfig durableConfig, - Context lambdaContext, - String contextId, - String contextName, - boolean setCurrentThreadContext) { - super( - executionManager, - durableConfig, - lambdaContext, - contextId, - contextName, - ThreadType.CONTEXT, - setCurrentThreadContext); + super(executionManager, durableConfig, lambdaContext, contextId, contextName, ThreadType.CONTEXT); operationIdGenerator = new OperationIdGenerator(contextId); } @@ -115,22 +98,6 @@ public DurableContextImpl createChildContext(String childContextId, String child getExecutionManager(), getDurableConfig(), getLambdaContext(), childContextId, childContextName); } - /** - * Creates a child context without setting the current thread context. - * - *

    Use this when the child context is being created on a thread that should not have its thread-local context - * overwritten (e.g. when constructing the context ahead of running it on a separate thread). - * - * @param childContextId the child context's ID (the CONTEXT operation's operation ID) - * @param childContextName the name of the child context - * @return a new DurableContext for the child context - */ - public DurableContextImpl createChildContextWithoutSettingThreadContext( - String childContextId, String childContextName) { - return new DurableContextImpl( - getExecutionManager(), getDurableConfig(), getLambdaContext(), childContextId, childContextName, false); - } - /** * Creates a step context for executing step operations. * @@ -418,7 +385,6 @@ public void close() { if (logger != null) { logger.close(); } - super.close(); } /** diff --git a/sdk/src/main/java/software/amazon/lambda/durable/context/StepContextImpl.java b/sdk/src/main/java/software/amazon/lambda/durable/context/StepContextImpl.java index af5c9222a..dcf5af66b 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/context/StepContextImpl.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/context/StepContextImpl.java @@ -66,6 +66,5 @@ public void close() { if (logger != null) { logger.close(); } - super.close(); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/execution/DurableExecutor.java b/sdk/src/main/java/software/amazon/lambda/durable/execution/DurableExecutor.java index c7d0c490c..28608e8dd 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/execution/DurableExecutor.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/execution/DurableExecutor.java @@ -51,6 +51,7 @@ public static DurableExecutionOutput execute( executionManager.registerActiveThread(null); var handlerFuture = CompletableFuture.supplyAsync( () -> { + executionManager.setCurrentThreadContext(new ThreadContext(null, ThreadType.CONTEXT)); var userInput = extractUserInput( executionManager.getExecutionOperation(), config.getSerDes(), inputType); // use try-with-resources to clear logger properties diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java index 21ac06349..69f81f9e0 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java @@ -6,6 +6,7 @@ import java.util.List; import java.util.Objects; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicReference; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.lambda.model.Operation; @@ -16,6 +17,7 @@ import software.amazon.lambda.durable.exception.NonDeterministicExecutionException; import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException; import software.amazon.lambda.durable.execution.ExecutionManager; +import software.amazon.lambda.durable.execution.SuspendExecutionException; import software.amazon.lambda.durable.execution.ThreadContext; import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.OperationIdentifier; @@ -46,6 +48,7 @@ public abstract class BaseDurableOperation { protected final ExecutionManager executionManager; protected final CompletableFuture completionFuture; private final DurableContextImpl durableContext; + private final AtomicReference> runningUserHandler = new AtomicReference<>(null); /** * Constructs a new durable operation. @@ -152,7 +155,7 @@ private void validateCurrentThreadType() { "Nested %s operation is not supported on %s from within a %s execution.", getType(), getName(), current); // terminate execution and throw the exception - terminateExecutionWithIllegalDurableOperationException(message); + throw terminateExecutionWithIllegalDurableOperationException(message); } } @@ -208,6 +211,50 @@ protected Operation waitForOperationCompletion() { return op; } + protected void runUserHandler(Runnable runnable, String contextId, ThreadType threadType) { + Runnable wrapped = () -> { + executionManager.setCurrentThreadContext(new ThreadContext(contextId, threadType)); + try { + runnable.run(); + } finally { + if (contextId != null) { + try { + // if this is a child context or a step context, we need to + // deregister the context's thread from the execution manager + executionManager.deregisterActiveThread(contextId); + } catch (SuspendExecutionException e) { + // Expected when this is the last active thread. Must catch here because: + // 1/ This runs in a worker thread detached from handlerFuture + // 2/ Uncaught exception would prevent stepAsync().get() from resume + // Suspension/Termination is already signaled via + // suspendExecutionFuture/terminateExecutionFuture + // before the throw. + } + } + } + }; + + // runUserHandler is used to ensure that only one user handler is running at a time + if (runningUserHandler.get() != null) { + throw new IllegalStateException("User handler already running"); + } + + // Thread registration is intentionally split across two threads: + // 1. registerActiveThread on the PARENT thread — ensures the child is tracked before the + // parent can deregister and trigger suspension (race prevention). + // 2. setCurrentContext on the CHILD thread — sets the ThreadLocal so operations inside + // the child context know which context they belong to. + // registerActiveThread is idempotent (no-op if already registered). + registerActiveThread(contextId); + + if (!runningUserHandler.compareAndSet( + null, + CompletableFuture.runAsync( + wrapped, getContext().getDurableConfig().getExecutorService()))) { + throw new IllegalStateException("User handler already running"); + } + } + /** * Receives operation updates from ExecutionManager. Completes the internal future when the operation reaches a * terminal status, unblocking any threads waiting on this operation. @@ -317,13 +364,13 @@ protected void validateReplay(Operation checkpointed) { } if (!checkpointed.type().equals(getType())) { - terminateExecution(new NonDeterministicExecutionException(String.format( + throw terminateExecution(new NonDeterministicExecutionException(String.format( "Operation type mismatch for \"%s\". Expected %s, got %s", getOperationId(), checkpointed.type(), getType()))); } if (!Objects.equals(checkpointed.name(), getName())) { - terminateExecution(new NonDeterministicExecutionException(String.format( + throw terminateExecution(new NonDeterministicExecutionException(String.format( "Operation name mismatch for \"%s\". Expected \"%s\", got \"%s\"", getOperationId(), checkpointed.name(), getName()))); } @@ -331,7 +378,7 @@ protected void validateReplay(Operation checkpointed) { if ((getSubType() == null && checkpointed.subType() != null) || getSubType() != null && !Objects.equals(checkpointed.subType(), getSubType().getValue())) { - terminateExecution(new NonDeterministicExecutionException(String.format( + throw terminateExecution(new NonDeterministicExecutionException(String.format( "Operation subType mismatch for \"%s\". Expected \"%s\", got \"%s\"", getOperationId(), checkpointed.subType(), getSubType()))); } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java index 40f921320..9d9481fb9 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java @@ -65,7 +65,7 @@ protected void replay(Operation existing) { // Still waiting - continue to polling } default -> - terminateExecutionWithIllegalDurableOperationException( + throw terminateExecutionWithIllegalDurableOperationException( "Unexpected callback status: " + existing.status()); } pollForOperationUpdates(); diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java index e03c3709c..ed93c8515 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java @@ -5,8 +5,6 @@ import static software.amazon.lambda.durable.execution.ExecutionManager.isTerminalStatus; import java.nio.charset.StandardCharsets; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutorService; import java.util.function.Function; import software.amazon.awssdk.services.lambda.model.ContextOptions; import software.amazon.awssdk.services.lambda.model.ErrorObject; @@ -28,6 +26,7 @@ import software.amazon.lambda.durable.exception.StepInterruptedException; import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException; import software.amazon.lambda.durable.execution.SuspendExecutionException; +import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.util.ExceptionHelper; @@ -46,7 +45,6 @@ public class ChildContextOperation extends SerializableDurableOperation { private static final int LARGE_RESULT_THRESHOLD = 256 * 1024; private final Function function; - private final ExecutorService userExecutor; private final ConcurrencyOperation parentOperation; private boolean replayChildContext; private T reconstructedResult; @@ -69,7 +67,6 @@ public ChildContextOperation( ConcurrencyOperation parentOperation) { super(operationIdentifier, resultTypeToken, config.serDes(), durableContext); this.function = function; - this.userExecutor = getContext().getDurableConfig().getExecutorService(); this.parentOperation = parentOperation; } @@ -116,14 +113,6 @@ private void executeChildContext() { // third level child context "hash(hash(hash(1)-2)-1)". var contextId = getOperationId(); - // Thread registration is intentionally split across two threads: - // 1. registerActiveThread on the PARENT thread — ensures the child is tracked before the - // parent can deregister and trigger suspension (race prevention). - // 2. setCurrentContext on the CHILD thread — sets the ThreadLocal so operations inside - // the child context know which context they belong to. - // registerActiveThread is idempotent (no-op if already registered). - registerActiveThread(contextId); - Runnable userHandler = () -> { // use a try-with-resources to // - add thread id/type to thread local when the step starts @@ -144,7 +133,7 @@ private void executeChildContext() { }; // Execute user provided child context code in user-configured executor - CompletableFuture.runAsync(userHandler, userExecutor); + runUserHandler(userHandler, contextId, ThreadType.CONTEXT); } private void handleChildContextSuccess(T result) { @@ -192,7 +181,7 @@ private void handleChildContextFailure(Throwable exception) { } if (exception instanceof UnrecoverableDurableExecutionException unrecoverableDurableExecutionException) { // terminate the execution and throw the exception if it's not recoverable - terminateExecution(unrecoverableDurableExecutionException); + throw terminateExecution(unrecoverableDurableExecutionException); } // Skip checkpointing if parent ConcurrencyOperation has already completed — diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java index 57ce20a67..d02dfa438 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java @@ -6,14 +6,13 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Queue; import java.util.Set; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -23,6 +22,7 @@ import software.amazon.lambda.durable.config.RunInChildContextConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.OperationIdGenerator; +import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.model.OperationSubType; @@ -53,23 +53,20 @@ public abstract class ConcurrencyOperation extends SerializableDurableOperati private final int maxConcurrency; private final Integer minSuccessful; private final Integer toleratedFailureCount; - private final AtomicInteger succeededCount = new AtomicInteger(0); - private final AtomicInteger failedCount = new AtomicInteger(0); - protected final AtomicBoolean isJoined = new AtomicBoolean(false); - private final Queue> pendingQueue = new ConcurrentLinkedDeque<>(); - private final List> branches = Collections.synchronizedList(new ArrayList<>()); - private final Map, Boolean> runningChildren = new ConcurrentHashMap<>(); - private final Set completedOperations = Collections.synchronizedSet(new HashSet<>()); private final OperationIdGenerator operationIdGenerator; private final DurableContextImpl rootContext; - private volatile CompletableFuture vacancyListener; - private record NewBranchItem( - String name, - Function function, - TypeToken resultType, - SerDes serDes, - OperationSubType branchSubType) {} + // access by context thread only + private final List> branches = Collections.synchronizedList(new ArrayList<>()); + + // put only by context thread and consume only by consumer thread + private final Queue> pendingQueue = new ConcurrentLinkedDeque<>(); + + // set by context thread and used by consumer thread + protected final AtomicBoolean isJoined = new AtomicBoolean(false); + + // used to wake up consumer thread for either new items or checking completion condition (isJoined changed) + private final AtomicReference> consumerThreadListener; protected ConcurrencyOperation( OperationIdentifier operationIdentifier, @@ -84,7 +81,8 @@ protected ConcurrencyOperation( this.minSuccessful = minSuccessful; this.toleratedFailureCount = toleratedFailureCount; this.operationIdGenerator = new OperationIdGenerator(getOperationId()); - this.rootContext = durableContext.createChildContextWithoutSettingThreadContext(getOperationId(), getName()); + this.rootContext = durableContext.createChildContext(getOperationId(), getName()); + this.consumerThreadListener = new AtomicReference<>(null); } // ========== Template methods for subclasses ========== @@ -139,57 +137,84 @@ protected ChildContextOperation enqueueItem( branches.add(childOp); pendingQueue.add(childOp); logger.debug("Item enqueued {}", name); - if (vacancyListener != null) { - vacancyListener.complete(null); - } + // notify the consumer thread a new item is available + completeVacancyListenerIfSet(); return childOp; } + private void completeVacancyListenerIfSet() { + synchronized (this) { + if (consumerThreadListener.get() != null) { + consumerThreadListener.get().complete(null); + } + } + } + + /** Starts execution of all enqueued items. */ protected void executeItems() { - // Start as many items as concurrency allows - var contextId = getOperationId(); - registerActiveThread(contextId); + // variables accessed only by the consumer thread. Put them here to avoid accidentally used by other threads + Set runningChildren = new HashSet<>(); + AtomicInteger succeededCount = new AtomicInteger(0); + AtomicInteger failedCount = new AtomicInteger(0); - Runnable handler = () -> { - try (var context = getContext().createChildContext(contextId, getName())) { - while (true) { - if (isOperationCompleted()) { - return; - } - while (runningChildren.size() < maxConcurrency) { - if (vacancyListener != null && vacancyListener.isDone()) { - vacancyListener = null; - } - var next = pendingQueue.poll(); - if (next == null) { - break; - } - runningChildren.put(next, true); - logger.debug("Executing operation {}", next.getName()); - next.execute(); + Runnable consumer = () -> { + while (true) { + if (isOperationCompleted()) { + return; + } + var completionStatus = canComplete(succeededCount, failedCount, runningChildren); + if (completionStatus != null) { + handleComplete(completionStatus); + return; + } + while (runningChildren.size() < maxConcurrency && !pendingQueue.isEmpty()) { + var next = pendingQueue.poll(); + runningChildren.add(next); + logger.debug("Executing operation {}", next.getName()); + next.execute(); + } + var child = waitForChildCompletion(succeededCount, failedCount, runningChildren); + if (child != null) { + if (runningChildren.contains(child)) { + runningChildren.remove(child); + onItemComplete(succeededCount, failedCount, (ChildContextOperation) child); + } else { + throw new IllegalStateException("Unexpected completion: " + child); } - var child = waitForChildCompletion(); - if (runningChildren.containsKey(child)) { - onItemComplete((ChildContextOperation) child); + } + synchronized (this) { + if (consumerThreadListener.get() != null + && consumerThreadListener.get().isDone()) { + consumerThreadListener.set(null); } } } }; - CompletableFuture.runAsync(handler, getContext().getDurableConfig().getExecutorService()); + // run consumer in the user thread pool, although it's not a real user thread + runUserHandler(consumer, getOperationId(), ThreadType.CONTEXT); } - private BaseDurableOperation waitForChildCompletion() { + private BaseDurableOperation waitForChildCompletion( + AtomicInteger succeededCount, AtomicInteger failedCount, Set runningChildren) { var threadContext = getCurrentThreadContext(); CompletableFuture future; synchronized (this) { + // check again in synchronized block to prevent race conditions + if (isOperationCompleted()) { + return null; + } + var completionStatus = canComplete(succeededCount, failedCount, runningChildren); + if (completionStatus != null) { + return null; + } ArrayList> futures; - futures = new ArrayList<>(runningChildren.keySet().stream() + futures = new ArrayList<>(runningChildren.stream() .map(BaseDurableOperation::getCompletionFuture) .toList()); if (futures.size() < maxConcurrency) { - vacancyListener = new CompletableFuture<>(); - futures.add(vacancyListener); + consumerThreadListener.compareAndSet(null, new CompletableFuture<>()); + futures.add(consumerThreadListener.get()); } // future will be completed immediately if any future of the list is already completed @@ -210,63 +235,28 @@ private BaseDurableOperation waitForChildCompletion() { * * @param child the child operation that completed */ - public void onItemComplete(ChildContextOperation child) { - if (!completedOperations.add(child.getOperationId())) { - throw new IllegalStateException("Child operation " + child.getOperationId() + " completed twice"); - } - + private void onItemComplete( + AtomicInteger succeededCount, AtomicInteger failedCount, ChildContextOperation child) { // Evaluate child result outside the lock — child.get() may block waiting for a checkpoint response. logger.debug("OnItemComplete called by {}, Id: {}", child.getName(), child.getOperationId()); - boolean succeeded; try { child.get(); logger.debug("Result succeeded - {}", child.getName()); - succeeded = true; + succeededCount.incrementAndGet(); } catch (Throwable e) { logger.debug("Child operation {} failed: {}", child.getOperationId(), e.getMessage()); - succeeded = false; - } - - // Counter updates, completion check, and next-item dispatch must be atomic to prevent - // the main thread's join() from seeing runningCount==0 with incomplete counters. - synchronized (this) { - if (succeeded) { - succeededCount.incrementAndGet(); - } else { - failedCount.incrementAndGet(); - } - if (!runningChildren.containsKey(child)) { - throw new IllegalStateException("Child operation " + child.getOperationId() + " completed twice"); - } - runningChildren.remove(child); - - var completionStatus = canComplete(); - if (completionStatus != null) { - handleComplete(completionStatus); - } + failedCount.incrementAndGet(); } } // ========== Completion logic ========== - /** - * Validates that the number of registered items is sufficient to satisfy the completion criteria. Called at join() - * time because branches are registered incrementally and the total count is only known once the user calls join(). - * - * @throws IllegalArgumentException if the item count cannot satisfy the criteria - */ - protected void validateItemCount() { - if (minSuccessful != null && minSuccessful > branches.size()) { - throw new IllegalStateException("minSuccessful (" + minSuccessful - + ") exceeds the number of registered items (" + branches.size() + ")"); - } - } - /** * Checks whether the concurrency operation can be considered complete. * * @return the completion status if the operation is complete, or null if it should continue */ - protected ConcurrencyCompletionStatus canComplete() { + private ConcurrencyCompletionStatus canComplete( + AtomicInteger succeededCount, AtomicInteger failedCount, Set runningChildren) { int succeeded = succeededCount.get(); int failed = failedCount.get(); @@ -281,7 +271,7 @@ protected ConcurrencyCompletionStatus canComplete() { } // All items finished — complete - if (isAllItemsFinished()) { + if (isJoined.get() && pendingQueue.isEmpty() && runningChildren.isEmpty()) { return ConcurrencyCompletionStatus.ALL_COMPLETED; } @@ -301,27 +291,18 @@ private void handleComplete(ConcurrencyCompletionStatus status) { * Blocks the calling thread until the concurrency operation reaches a terminal state. Validates item count, handles * zero-branch case, then delegates to {@code waitForOperationCompletion()} from BaseDurableOperation. */ - public void join() { - validateItemCount(); - isJoined.set(true); - synchronized (this) { - if (!isOperationCompleted()) { - var completionStatus = canComplete(); - if (completionStatus != null) { - handleComplete(completionStatus); - } - } + protected void join() { + if (minSuccessful != null && minSuccessful > branches.size()) { + throw new IllegalStateException("minSuccessful (" + minSuccessful + + ") exceeds the number of registered items (" + branches.size() + ")"); } - + isJoined.set(true); + // notify the execution thread this concurrency operation is joined + completeVacancyListenerIfSet(); waitForOperationCompletion(); } protected List> getBranches() { return branches; } - - /** Returns true if all items have finished (no pending, no running). Used by subclasses to override canComplete. */ - protected boolean isAllItemsFinished() { - return isJoined.get() && pendingQueue.isEmpty() && runningChildren.isEmpty(); - } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/InvokeOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/InvokeOperation.java index ac53e0cce..9e2c54ace 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/InvokeOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/InvokeOperation.java @@ -58,7 +58,7 @@ protected void replay(Operation existing) { case STARTED -> pollForOperationUpdates(); case SUCCEEDED, FAILED, TIMED_OUT, STOPPED -> markAlreadyCompleted(); default -> - terminateExecutionWithIllegalDurableOperationException( + throw terminateExecutionWithIllegalDurableOperationException( "Unexpected invoke status: " + existing.statusAsString()); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/SerializableDurableOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/SerializableDurableOperation.java index bcb964691..c86d6c263 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/SerializableDurableOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/SerializableDurableOperation.java @@ -2,16 +2,13 @@ // SPDX-License-Identifier: Apache-2.0 package software.amazon.lambda.durable.operation; -import java.util.Objects; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.lambda.model.ErrorObject; -import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.IllegalDurableOperationException; -import software.amazon.lambda.durable.exception.NonDeterministicExecutionException; import software.amazon.lambda.durable.exception.SerDesException; import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.OperationIdentifier; @@ -72,7 +69,7 @@ private void validateCurrentThreadType() { "Nested %s operation is not supported on %s from within a %s execution.", getType(), getName(), current); // terminate execution and throw the exception - terminateExecutionWithIllegalDurableOperationException(message); + throw terminateExecutionWithIllegalDurableOperationException(message); } } @@ -152,32 +149,5 @@ protected Throwable deserializeException(ErrorObject errorObject) { return original; } - /** Validates that current operation matches checkpointed operation during replay. */ - protected void validateReplay(Operation checkpointed) { - if (checkpointed == null || checkpointed.type() == null) { - return; // First execution, no validation needed - } - - if (!checkpointed.type().equals(getType())) { - terminateExecution(new NonDeterministicExecutionException(String.format( - "Operation type mismatch for \"%s\". Expected %s, got %s", - getOperationId(), checkpointed.type(), getType()))); - } - - if (!Objects.equals(checkpointed.name(), getName())) { - terminateExecution(new NonDeterministicExecutionException(String.format( - "Operation name mismatch for \"%s\". Expected \"%s\", got \"%s\"", - getOperationId(), checkpointed.name(), getName()))); - } - - if ((getSubType() == null && checkpointed.subType() != null) - || getSubType() != null - && !Objects.equals(checkpointed.subType(), getSubType().getValue())) { - terminateExecution(new NonDeterministicExecutionException(String.format( - "Operation subType mismatch for \"%s\". Expected \"%s\", got \"%s\"", - getOperationId(), checkpointed.subType(), getSubType()))); - } - } - public abstract T get(); } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/StepOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/StepOperation.java index dbac63256..7ff4b9ae0 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/StepOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/StepOperation.java @@ -5,7 +5,6 @@ import java.time.Duration; import java.time.Instant; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutorService; import java.util.function.Function; import software.amazon.awssdk.services.lambda.model.ErrorObject; import software.amazon.awssdk.services.lambda.model.Operation; @@ -23,6 +22,7 @@ import software.amazon.lambda.durable.exception.StepInterruptedException; import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException; import software.amazon.lambda.durable.execution.SuspendExecutionException; +import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.util.ExceptionHelper; @@ -39,7 +39,6 @@ public class StepOperation extends SerializableDurableOperation { private final Function function; private final StepConfig config; - private final ExecutorService userExecutor; public StepOperation( OperationIdentifier operationIdentifier, @@ -51,7 +50,6 @@ public StepOperation( this.function = function; this.config = config; - this.userExecutor = durableContext.getDurableConfig().getExecutorService(); } /** Starts the operation. */ @@ -82,7 +80,8 @@ protected void replay(Operation existing) { // Execute with current attempt case READY -> executeStepLogic(attempt); default -> - terminateExecutionWithIllegalDurableOperationException("Unexpected step status: " + existing.status()); + throw terminateExecutionWithIllegalDurableOperationException( + "Unexpected step status: " + existing.status()); } } @@ -96,10 +95,6 @@ private CompletableFuture pollReadyAndExecuteStepLogic(Operation existing, } private void executeStepLogic(int attempt) { - // Register step thread as active BEFORE executor runs (prevents suspension when handler deregisters). - // The thread local ThreadContext is set inside the executor since that's where the step actually runs - registerActiveThread(getOperationId()); - Runnable userHandler = () -> { // use a try-with-resources to // - add thread id/type to thread local when the step starts @@ -119,7 +114,7 @@ private void executeStepLogic(int attempt) { }; // Execute user provided step code in user-configured executor - CompletableFuture.runAsync(userHandler, userExecutor); + runUserHandler(userHandler, getOperationId(), ThreadType.STEP); } private void checkpointStarted() { @@ -155,7 +150,7 @@ private void handleStepFailure(Throwable exception, int attempt) { } if (exception instanceof UnrecoverableDurableExecutionException unrecoverableDurableExecutionException) { // terminate the execution and throw the exception if it's not recoverable - terminateExecution(unrecoverableDurableExecutionException); + throw terminateExecution(unrecoverableDurableExecutionException); } final ErrorObject errorObject; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java index 0ac30a88c..ac3258077 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java @@ -20,6 +20,7 @@ import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException; import software.amazon.lambda.durable.exception.WaitForConditionException; import software.amazon.lambda.durable.execution.SuspendExecutionException; +import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.model.OperationSubType; import software.amazon.lambda.durable.model.WaitForConditionResult; @@ -73,7 +74,7 @@ protected void replay(Operation existing) { case PENDING -> pollReadyAndResumeCheckLoop(existing); // Check if pending retry case STARTED, READY -> resumeCheckLoop(existing); default -> - terminateExecutionWithIllegalDurableOperationException( + throw terminateExecutionWithIllegalDurableOperationException( "Unexpected waitForCondition status: " + existing.status()); } } @@ -121,59 +122,56 @@ private CompletableFuture pollReadyAndResumeCheckLoop(Operation existing) } private void executeCheckLogic(T currentState, int attempt) { - // Register thread as active BEFORE executor runs - registerActiveThread(getOperationId()); - - CompletableFuture.runAsync( - () -> { - try (var stepContext = getContext().createStepContext(getOperationId(), getName(), attempt)) { - try { - // Checkpoint START if not already started - var existing = getOperation(); - if (existing == null || existing.status() != OperationStatus.STARTED) { - var startUpdate = OperationUpdate.builder().action(OperationAction.START); - sendOperationUpdateAsync(startUpdate); - } - - // Execute check function in user executor - WaitForConditionResult result = checkFunc.apply(currentState, stepContext); - - // Serialize/deserialize round-trip on the value to ensure state is checkpoint-safe - var serializedState = serializeResult(result.value()); - T deserializedValue = deserializeResult(serializedState); - - if (result.isDone()) { - // Condition met — checkpoint SUCCEED - var successUpdate = OperationUpdate.builder() - .action(OperationAction.SUCCEED) - .payload(serializedState); - sendOperationUpdate(successUpdate); - } else { - // Compute delay from strategy - Duration delay = config.waitStrategy().evaluate(deserializedValue, attempt); - - // Checkpoint RETRY with delay - var retryUpdate = OperationUpdate.builder() - .action(OperationAction.RETRY) - .payload(serializedState) - .stepOptions(StepOptions.builder() - .nextAttemptDelaySeconds(Math.toIntExact(delay.toSeconds())) - .build()); - sendOperationUpdate(retryUpdate); - - // Poll for READY, then continue the loop - pollForOperationUpdates() - .thenCompose(op -> op.status() == OperationStatus.READY - ? CompletableFuture.completedFuture(op) - : pollForOperationUpdates()) - .thenRun(() -> executeCheckLogic(deserializedValue, attempt + 1)); - } - } catch (Throwable e) { - handleCheckFailure(e); - } + Runnable userHandler = () -> { + try (var stepContext = getContext().createStepContext(getOperationId(), getName(), attempt)) { + try { + // Checkpoint START if not already started + var existing = getOperation(); + if (existing == null || existing.status() != OperationStatus.STARTED) { + var startUpdate = OperationUpdate.builder().action(OperationAction.START); + sendOperationUpdateAsync(startUpdate); } - }, - userExecutor); + + // Execute check function in user executor + WaitForConditionResult result = checkFunc.apply(currentState, stepContext); + + // Serialize/deserialize round-trip on the value to ensure state is checkpoint-safe + var serializedState = serializeResult(result.value()); + T deserializedValue = deserializeResult(serializedState); + + if (result.isDone()) { + // Condition met — checkpoint SUCCEED + var successUpdate = OperationUpdate.builder() + .action(OperationAction.SUCCEED) + .payload(serializedState); + sendOperationUpdate(successUpdate); + } else { + // Compute delay from strategy + Duration delay = config.waitStrategy().evaluate(deserializedValue, attempt); + + // Checkpoint RETRY with delay + var retryUpdate = OperationUpdate.builder() + .action(OperationAction.RETRY) + .payload(serializedState) + .stepOptions(StepOptions.builder() + .nextAttemptDelaySeconds(Math.toIntExact(delay.toSeconds())) + .build()); + sendOperationUpdate(retryUpdate); + + // Poll for READY, then continue the loop + pollForOperationUpdates() + .thenCompose(op -> op.status() == OperationStatus.READY + ? CompletableFuture.completedFuture(op) + : pollForOperationUpdates()) + .thenRun(() -> executeCheckLogic(deserializedValue, attempt + 1)); + } + } catch (Throwable e) { + handleCheckFailure(e); + } + } + }; + + runUserHandler(userHandler, getOperationId(), ThreadType.STEP); } private void handleCheckFailure(Throwable exception) { @@ -182,7 +180,7 @@ private void handleCheckFailure(Throwable exception) { throw suspendExecutionException; } if (exception instanceof UnrecoverableDurableExecutionException unrecoverable) { - terminateExecution(unrecoverable); + throw terminateExecution(unrecoverable); } final var errorObject = (exception instanceof DurableOperationException durableOpEx) diff --git a/sdk/src/test/java/software/amazon/lambda/durable/context/BaseContextImplTest.java b/sdk/src/test/java/software/amazon/lambda/durable/context/BaseContextImplTest.java index c0d321ed5..76b59e663 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/context/BaseContextImplTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/context/BaseContextImplTest.java @@ -15,8 +15,6 @@ import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.TestUtils; import software.amazon.lambda.durable.execution.ExecutionManager; -import software.amazon.lambda.durable.execution.ThreadContext; -import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.DurableExecutionInput; class BaseContextImplTest { @@ -59,11 +57,6 @@ void defaultConstructor_setsCurrentThreadContext() { // Creating a root context with the default constructor should set the thread context DurableContextImpl.createRootContext( executionManager, DurableConfig.builder().build(), null); - - var threadContext = executionManager.getCurrentThreadContext(); - assertNotNull(threadContext); - assertEquals(ThreadType.CONTEXT, threadContext.threadType()); - assertNull(threadContext.threadId()); } @Test @@ -73,51 +66,8 @@ void constructorWithSetCurrentThreadContextTrue_setsCurrentThreadContext() { // createRootContext sets thread context to root (threadId=null) var rootContext = DurableContextImpl.createRootContext( executionManager, DurableConfig.builder().build(), null); - assertEquals( - ThreadType.CONTEXT, executionManager.getCurrentThreadContext().threadType()); - assertNull(executionManager.getCurrentThreadContext().threadId()); // createChildContext (setCurrentThreadContext=true) should overwrite with child's context rootContext.createChildContext("child-id", "child-name"); - - var threadContext = executionManager.getCurrentThreadContext(); - assertNotNull(threadContext); - assertEquals(ThreadType.CONTEXT, threadContext.threadType()); - assertEquals("child-id", threadContext.threadId()); - } - - @Test - void constructorWithSetCurrentThreadContextFalse_doesNotOverwriteThreadContext() { - var executionManager = createExecutionManager(); - - // Create root context first (it will set thread context to null/root) - var rootContext = DurableContextImpl.createRootContext( - executionManager, DurableConfig.builder().build(), null); - - // Now set a sentinel — simulating a caller thread that already has context established - var sentinel = new ThreadContext("original-context", ThreadType.CONTEXT); - executionManager.setCurrentThreadContext(sentinel); - - // createChildContextWithoutSettingThreadContext should NOT overwrite the sentinel - rootContext.createChildContextWithoutSettingThreadContext("child-id", "child-name"); - - // Thread context should still be the sentinel, not the child's context - var threadContext = executionManager.getCurrentThreadContext(); - assertNotNull(threadContext); - assertEquals("original-context", threadContext.threadId()); - } - - @Test - void createChildContextWithoutSettingThreadContext_returnsValidChildContext() { - var executionManager = createExecutionManager(); - executionManager.setCurrentThreadContext(new ThreadContext(null, ThreadType.CONTEXT)); - var rootContext = DurableContextImpl.createRootContext( - executionManager, DurableConfig.builder().build(), null); - - var childContext = rootContext.createChildContextWithoutSettingThreadContext("child-id", "child-name"); - - assertNotNull(childContext); - assertEquals("child-id", childContext.getContextId()); - assertEquals("child-name", childContext.getContextName()); } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/context/DurableContextImplTest.java b/sdk/src/test/java/software/amazon/lambda/durable/context/DurableContextImplTest.java index e3a5f732a..2d6f6c4e3 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/context/DurableContextImplTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/context/DurableContextImplTest.java @@ -7,7 +7,6 @@ import java.util.ArrayList; import java.util.List; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.lambda.model.CheckpointUpdatedExecutionState; import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationStatus; @@ -50,44 +49,4 @@ void setUp() { rootContext = DurableContextImpl.createRootContext( executionManager, DurableConfig.builder().build(), null); } - - @Test - void createChildContext_setsThreadContextToChild() { - rootContext.createChildContext("child-1", "my-child"); - - var threadContext = executionManager.getCurrentThreadContext(); - assertNotNull(threadContext); - assertEquals("child-1", threadContext.threadId()); - assertEquals(ThreadType.CONTEXT, threadContext.threadType()); - } - - @Test - void createChildContextWithoutSettingThreadContext_preservesCallerThreadContext() { - var callerContext = new ThreadContext("caller-thread", ThreadType.CONTEXT); - executionManager.setCurrentThreadContext(callerContext); - - rootContext.createChildContextWithoutSettingThreadContext("child-1", "my-child"); - - // Thread context must remain unchanged - var threadContext = executionManager.getCurrentThreadContext(); - assertEquals("caller-thread", threadContext.threadId()); - } - - @Test - void createChildContextWithoutSettingThreadContext_returnsCorrectChildMetadata() { - var child = rootContext.createChildContextWithoutSettingThreadContext("child-42", "child-name"); - - assertEquals("child-42", child.getContextId()); - assertEquals("child-name", child.getContextName()); - } - - @Test - void createChildContextWithoutSettingThreadContext_whenNoThreadContextSet_leavesItNull() { - // Clear any existing thread context - executionManager.setCurrentThreadContext(null); - - rootContext.createChildContextWithoutSettingThreadContext("child-1", "my-child"); - - assertNull(executionManager.getCurrentThreadContext()); - } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java index 60ae3025f..92be2e3ef 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java @@ -269,40 +269,6 @@ void replayWithNameMismatchTerminatesExecution() { // ===== Parent ConcurrencyOperation support ===== - /** Parent's onItemComplete() is called when child succeeds. */ - @Test - void parentOnItemCompleteCalledOnChildSuccess() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("1")).thenReturn(null); - - var parent = mock(ConcurrencyOperation.class); - when(parent.isOperationCompleted()).thenReturn(false); - - var operation = createOperationWithParent(ctx -> "success", parent); - operation.execute(); - Thread.sleep(200); - - verify(parent).onItemComplete(operation); - } - - /** Parent's onItemComplete() is called when child fails. */ - @Test - void parentOnItemCompleteCalledOnChildFailure() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("1")).thenReturn(null); - - var parent = mock(ConcurrencyOperation.class); - when(parent.isOperationCompleted()).thenReturn(false); - - var operation = createOperationWithParent( - ctx -> { - throw new RuntimeException("branch failed"); - }, - parent); - operation.execute(); - Thread.sleep(200); - - verify(parent).onItemComplete(operation); - } - /** Child skips success checkpoint when parent operation has already completed. */ @Test void childSkipsSuccessCheckpointWhenParentAlreadyCompleted() throws Exception { @@ -340,104 +306,4 @@ void childSkipsFailureCheckpointWhenParentAlreadyCompleted() throws Exception { verify(executionManager, never()) .sendOperationUpdate(argThat(update -> update.action() == OperationAction.FAIL)); } - - // ===== onItemComplete called during replay ===== - - /** SUCCEEDED replay (terminal) — onItemComplete is called via markAlreadyCompleted(). */ - @Test - void replaySucceeded_callsParentOnItemComplete() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("1")) - .thenReturn(Operation.builder() - .id("1") - .name("test-context") - .type(OperationType.CONTEXT) - .subType(OperationSubType.RUN_IN_CHILD_CONTEXT.getValue()) - .status(OperationStatus.SUCCEEDED) - .contextDetails( - ContextDetails.builder().result("\"cached\"").build()) - .build()); - - var parent = mock(ConcurrencyOperation.class); - when(parent.isOperationCompleted()).thenReturn(false); - - var operation = createOperationWithParent(ctx -> "unused", parent); - operation.execute(); - - verify(parent).onItemComplete(operation); - } - - /** FAILED replay (terminal) — onItemComplete is called via markAlreadyCompleted(). */ - @Test - void replayFailed_callsParentOnItemComplete() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("1")) - .thenReturn(Operation.builder() - .id("1") - .name("test-context") - .type(OperationType.CONTEXT) - .subType(OperationSubType.RUN_IN_CHILD_CONTEXT.getValue()) - .status(OperationStatus.FAILED) - .contextDetails(ContextDetails.builder() - .error(ErrorObject.builder() - .errorType("java.lang.RuntimeException") - .errorMessage("original failure") - .build()) - .build()) - .build()); - - var parent = mock(ConcurrencyOperation.class); - when(parent.isOperationCompleted()).thenReturn(false); - - var operation = createOperationWithParent(ctx -> "unused", parent); - operation.execute(); - - verify(parent).onItemComplete(operation); - } - - /** STARTED replay — child re-executes and onItemComplete is called from the finally block. */ - @Test - void replayStarted_callsParentOnItemComplete() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("1")) - .thenReturn(Operation.builder() - .id("1") - .name("test-context") - .type(OperationType.CONTEXT) - .subType(OperationSubType.RUN_IN_CHILD_CONTEXT.getValue()) - .status(OperationStatus.STARTED) - .build()); - when(executionManager.hasOperationsForContext("1")).thenReturn(false); - - var parent = mock(ConcurrencyOperation.class); - when(parent.isOperationCompleted()).thenReturn(false); - - var operation = createOperationWithParent(ctx -> "re-executed", parent); - operation.execute(); - Thread.sleep(200); - - verify(parent).onItemComplete(operation); - } - - /** replayChildren=true — child re-executes and onItemComplete is called from the finally block. */ - @Test - void replayChildren_callsParentOnItemComplete() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("1")) - .thenReturn(Operation.builder() - .id("1") - .name("test-context") - .type(OperationType.CONTEXT) - .subType(OperationSubType.RUN_IN_CHILD_CONTEXT.getValue()) - .status(OperationStatus.SUCCEEDED) - .contextDetails( - ContextDetails.builder().replayChildren(true).build()) - .build()); - when(executionManager.hasOperationsForContext("1")).thenReturn(false); - - var parent = mock(ConcurrencyOperation.class); - when(parent.isOperationCompleted()).thenReturn(false); - - var operation = createOperationWithParent(ctx -> "reconstructed", parent); - operation.execute(); - Thread.sleep(200); - - verify(parent, atLeast(1)).onItemComplete(operation); - } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java index 75395e2a9..a453ceae9 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java @@ -19,6 +19,7 @@ import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.lambda.durable.DurableConfig; +import software.amazon.lambda.durable.TestUtils; import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.config.CompletionConfig; import software.amazon.lambda.durable.context.DurableContextImpl; @@ -36,19 +37,18 @@ class ConcurrencyOperationTest { private static final SerDes SER_DES = new JacksonSerDes(); private static final String OPERATION_ID = "op-1"; + private static final String CHILD_OP_1 = TestUtils.hashOperationId(OPERATION_ID + "-1"); + private static final String CHILD_OP_2 = TestUtils.hashOperationId(OPERATION_ID + "-2"); private static final TypeToken RESULT_TYPE = TypeToken.get(Void.class); private DurableContextImpl durableContext; private DurableContextImpl childContext; private ExecutionManager executionManager; - private AtomicInteger operationIdCounter; - private OperationIdGenerator mockIdGenerator; @BeforeEach void setUp() { durableContext = mock(DurableContextImpl.class); executionManager = mock(ExecutionManager.class); - operationIdCounter = new AtomicInteger(0); var childContext = mock(DurableContextImpl.class); this.childContext = childContext; @@ -64,11 +64,7 @@ void setUp() { .withExecutorService(Executors.newCachedThreadPool()) .build()); when(durableContext.createChildContext(anyString(), anyString())).thenReturn(childContext); - when(durableContext.createChildContextWithoutSettingThreadContext(anyString(), anyString())) - .thenReturn(childContext); when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext("Root", ThreadType.CONTEXT)); - mockIdGenerator = mock(OperationIdGenerator.class); - when(mockIdGenerator.nextOperationId()).thenAnswer(inv -> "child-" + operationIdCounter.incrementAndGet()); // All child operations are NOT in replay when(executionManager.getOperationAndUpdateReplayState(anyString())).thenReturn(null); // Simulate the real backend: the parent concurrency operation is available in storage after completion @@ -86,7 +82,7 @@ void setUp() { } private TestConcurrencyOperation createOperation(CompletionConfig completionConfig) throws Exception { - TestConcurrencyOperation testConcurrencyOperation = new TestConcurrencyOperation( + return new TestConcurrencyOperation( OperationIdentifier.of( OPERATION_ID, "test-concurrency", OperationType.CONTEXT, OperationSubType.PARALLEL), RESULT_TYPE, @@ -94,8 +90,6 @@ private TestConcurrencyOperation createOperation(CompletionConfig completionConf durableContext, Integer.MAX_VALUE, completionConfig); - setOperationIdGenerator(testConcurrencyOperation, mockIdGenerator); - return testConcurrencyOperation; } private void setOperationIdGenerator(ConcurrencyOperation op, OperationIdGenerator mockGenerator) @@ -109,9 +103,9 @@ private void setOperationIdGenerator(ConcurrencyOperation op, OperationIdGene @Test void allChildrenAlreadySucceed_callsHandleSuccess() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("child-1")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_1)) .thenReturn(Operation.builder() - .id("child-1") + .id(CHILD_OP_1) .name("branch-1") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -119,9 +113,9 @@ void allChildrenAlreadySucceed_callsHandleSuccess() throws Exception { .contextDetails( ContextDetails.builder().result("\"result-1\"").build()) .build()); - when(executionManager.getOperationAndUpdateReplayState("child-2")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_2)) .thenReturn(Operation.builder() - .id("child-2") + .id(CHILD_OP_2) .name("branch-2") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -132,6 +126,7 @@ void allChildrenAlreadySucceed_callsHandleSuccess() throws Exception { var functionCalled = new AtomicBoolean(false); var op = createOperation(CompletionConfig.allSuccessful()); + op.execute(); op.enqueueItem( "branch-1", ctx1 -> { @@ -163,9 +158,9 @@ void allChildrenAlreadySucceed_callsHandleSuccess() throws Exception { @Test void singleChildAlreadySucceeds_fullCycle() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("child-1")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_1)) .thenReturn(Operation.builder() - .id("child-1") + .id(CHILD_OP_1) .name("only-branch") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -186,6 +181,7 @@ void singleChildAlreadySucceeds_fullCycle() throws Exception { SER_DES, OperationSubType.PARALLEL_BRANCH); + op.execute(); op.exposedJoin(); assertTrue(op.isSuccessHandled()); @@ -233,10 +229,14 @@ protected void handleSuccess(ConcurrencyCompletionStatus completionStatus) { } @Override - protected void start() {} + protected void start() { + executeItems(); + } @Override - protected void replay(Operation existing) {} + protected void replay(Operation existing) { + executeItems(); + } @Override public Void get() { diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java index b30171a1a..a7693591b 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java @@ -12,7 +12,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executors; -import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.lambda.model.ContextDetails; @@ -21,6 +20,7 @@ import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.lambda.durable.DurableConfig; +import software.amazon.lambda.durable.TestUtils; import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.config.CompletionConfig; import software.amazon.lambda.durable.config.ParallelConfig; @@ -39,17 +39,19 @@ class ParallelOperationTest { private static final SerDes SER_DES = new JacksonSerDes(); private static final String OPERATION_ID = "parallel-op-1"; + private static final String CHILD_OP_1 = TestUtils.hashOperationId(OPERATION_ID + "-1"); + private static final String CHILD_OP_2 = TestUtils.hashOperationId(OPERATION_ID + "-2"); private DurableContextImpl durableContext; private ExecutionManager executionManager; - private AtomicInteger operationIdCounter; - private OperationIdGenerator mockIdGenerator; @BeforeEach void setUp() { durableContext = mock(DurableContextImpl.class); executionManager = mock(ExecutionManager.class); - operationIdCounter = new AtomicInteger(0); + + when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext(null, ThreadType.CONTEXT)); + when(executionManager.getOperationAndUpdateReplayState(anyString())).thenReturn(null); var childContext = mock(DurableContextImpl.class); when(childContext.getExecutionManager()).thenReturn(executionManager); @@ -64,13 +66,6 @@ void setUp() { .withExecutorService(Executors.newCachedThreadPool()) .build()); when(durableContext.createChildContext(anyString(), anyString())).thenReturn(childContext); - when(durableContext.createChildContextWithoutSettingThreadContext(anyString(), anyString())) - .thenReturn(childContext); - when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext("Root", ThreadType.CONTEXT)); - // Default: no existing operations (fresh execution) - mockIdGenerator = mock(OperationIdGenerator.class); - when(mockIdGenerator.nextOperationId()).thenAnswer(inv -> "child-" + operationIdCounter.incrementAndGet()); - when(executionManager.getOperationAndUpdateReplayState(anyString())).thenReturn(null); // Capture registered operations so we can drive onCheckpointComplete callbacks. var registeredOps = new ConcurrentHashMap(); @@ -110,11 +105,14 @@ void setUp() { } private ParallelOperation createOperation(CompletionConfig completionConfig) { - return new ParallelOperation( + var op = new ParallelOperation( OperationIdentifier.of(OPERATION_ID, "test-parallel", OperationType.CONTEXT, OperationSubType.PARALLEL), SER_DES, durableContext, ParallelConfig.builder().completionConfig(completionConfig).build()); + + op.execute(); + return op; } private void setOperationIdGenerator(ConcurrencyOperation op, OperationIdGenerator mockGenerator) @@ -167,9 +165,9 @@ void branchCreation_childOperationHasParentReference() throws Exception { @Test void allBranchesSucceed_sendsSucceedCheckpointAndReturnsCorrectResult() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("child-1")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_1)) .thenReturn(Operation.builder() - .id("child-1") + .id(CHILD_OP_1) .name("branch-1") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -177,9 +175,9 @@ void allBranchesSucceed_sendsSucceedCheckpointAndReturnsCorrectResult() throws E .contextDetails( ContextDetails.builder().result("\"r1\"").build()) .build()); - when(executionManager.getOperationAndUpdateReplayState("child-2")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_2)) .thenReturn(Operation.builder() - .id("child-2") + .id(CHILD_OP_2) .name("branch-2") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -189,7 +187,6 @@ void allBranchesSucceed_sendsSucceedCheckpointAndReturnsCorrectResult() throws E .build()); var op = createOperation(CompletionConfig.allSuccessful()); - setOperationIdGenerator(op, mockIdGenerator); op.enqueueItem( "branch-1", ctx1 -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); op.enqueueItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); @@ -208,9 +205,9 @@ void allBranchesSucceed_sendsSucceedCheckpointAndReturnsCorrectResult() throws E @Test void minSuccessful_completesWhenThresholdMetAndReturnsResult() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("child-1")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_1)) .thenReturn(Operation.builder() - .id("child-1") + .id(CHILD_OP_1) .name("branch-1") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -220,7 +217,6 @@ void minSuccessful_completesWhenThresholdMetAndReturnsResult() throws Exception .build()); var op = createOperation(CompletionConfig.minSuccessful(1)); - setOperationIdGenerator(op, mockIdGenerator); op.enqueueItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); var result = op.get(); @@ -262,9 +258,9 @@ void replay_fromStartedState_sendsSucceedCheckpointAndReturnsResult() throws Exc .subType(OperationSubType.PARALLEL.getValue()) .status(OperationStatus.STARTED) .build()); - when(executionManager.getOperationAndUpdateReplayState("child-1")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_1)) .thenReturn(Operation.builder() - .id("child-1") + .id(CHILD_OP_1) .name("branch-1") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -272,9 +268,9 @@ void replay_fromStartedState_sendsSucceedCheckpointAndReturnsResult() throws Exc .contextDetails( ContextDetails.builder().result("\"r1\"").build()) .build()); - when(executionManager.getOperationAndUpdateReplayState("child-2")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_2)) .thenReturn(Operation.builder() - .id("child-2") + .id(CHILD_OP_2) .name("branch-2") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -284,8 +280,6 @@ void replay_fromStartedState_sendsSucceedCheckpointAndReturnsResult() throws Exc .build()); var op = createOperation(CompletionConfig.allSuccessful()); - setOperationIdGenerator(op, mockIdGenerator); - op.execute(); op.enqueueItem( "branch-1", ctx1 -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); op.enqueueItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); @@ -312,9 +306,9 @@ void replay_fromSucceededState_skipsCheckpointAndReturnsResult() throws Exceptio .subType(OperationSubType.PARALLEL.getValue()) .status(OperationStatus.SUCCEEDED) .build()); - when(executionManager.getOperationAndUpdateReplayState("child-1")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_1)) .thenReturn(Operation.builder() - .id("child-1") + .id(CHILD_OP_1) .name("branch-1") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -322,9 +316,9 @@ void replay_fromSucceededState_skipsCheckpointAndReturnsResult() throws Exceptio .contextDetails( ContextDetails.builder().result("\"r1\"").build()) .build()); - when(executionManager.getOperationAndUpdateReplayState("child-2")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_2)) .thenReturn(Operation.builder() - .id("child-2") + .id(CHILD_OP_2) .name("branch-2") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -334,8 +328,6 @@ void replay_fromSucceededState_skipsCheckpointAndReturnsResult() throws Exceptio .build()); var op = createOperation(CompletionConfig.allSuccessful()); - setOperationIdGenerator(op, mockIdGenerator); - op.execute(); op.enqueueItem( "branch-1", ctx1 -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); op.enqueueItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); @@ -355,10 +347,10 @@ void replay_fromSucceededState_skipsCheckpointAndReturnsResult() throws Exceptio // ===== Branch failure sends SUCCEED checkpoint and returns result ===== @Test - void branchFailure_sendsSucceedCheckpointAndReturnsFailureCounts() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("child-1")) + void branchFailure_sendsSucceedCheckpointAndReturnsFailureCounts() { + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_1)) .thenReturn(Operation.builder() - .id("child-1") + .id(CHILD_OP_1) .name("branch-1") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -366,7 +358,6 @@ void branchFailure_sendsSucceedCheckpointAndReturnsFailureCounts() throws Except .build()); var op = createOperation(CompletionConfig.allSuccessful()); - setOperationIdGenerator(op, mockIdGenerator); op.enqueueItem( "branch-1", ctx -> { @@ -376,7 +367,7 @@ void branchFailure_sendsSucceedCheckpointAndReturnsFailureCounts() throws Except SER_DES, OperationSubType.PARALLEL_BRANCH); - var result = assertDoesNotThrow(() -> op.get()); + var result = assertDoesNotThrow(op::get); verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); verify(executionManager, never()) @@ -389,9 +380,9 @@ void branchFailure_sendsSucceedCheckpointAndReturnsFailureCounts() throws Except @Test void get_someBranchesFail_returnsCorrectCountsAndFailureStatus() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("child-1")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_1)) .thenReturn(Operation.builder() - .id("child-1") + .id(CHILD_OP_1) .name("branch-1") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -399,9 +390,9 @@ void get_someBranchesFail_returnsCorrectCountsAndFailureStatus() throws Exceptio .contextDetails( ContextDetails.builder().result("\"r1\"").build()) .build()); - when(executionManager.getOperationAndUpdateReplayState("child-2")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_2)) .thenReturn(Operation.builder() - .id("child-2") + .id(CHILD_OP_2) .name("branch-2") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -410,7 +401,6 @@ void get_someBranchesFail_returnsCorrectCountsAndFailureStatus() throws Exceptio // toleratedFailureCount=1 so the operation completes after both branches finish var op = createOperation(CompletionConfig.toleratedFailureCount(1)); - setOperationIdGenerator(op, mockIdGenerator); op.enqueueItem( "branch-1", ctx1 -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); op.enqueueItem( From aefcd0e13d918a842b254781222da0b48ef39927 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sun, 22 Mar 2026 18:27:22 -0700 Subject: [PATCH 12/15] avoid marking branch done when parent is done --- .../operation/BaseDurableOperation.java | 18 +++++++++--------- .../operation/ChildContextOperation.java | 15 ++++----------- .../operation/ConcurrencyOperation.java | 2 ++ 3 files changed, 15 insertions(+), 20 deletions(-) diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java index 69f81f9e0..a99f54fd2 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java @@ -266,15 +266,8 @@ public void onCheckpointComplete(Operation operation) { // This method handles only terminal status updates. Override this method if a DurableOperation needs to // handle other updates. logger.trace("In onCheckpointComplete, completing operation {} ({})", getOperationId(), completionFuture); - // It's important that we synchronize access to the future, otherwise the processing could happen - // on someone else's thread and cause a race condition. - synchronized (completionFuture) { - // Completing the future here will also run any other completion stages that have been attached - // to the future. In our case, other contexts may have attached a function to reactivate themselves, - // so they will definitely have a chance to reactivate before we finish completing and deactivating - // whatever operations were just checkpointed. - completionFuture.complete(this); - } + + markCompletionFutureCompleted(); } } @@ -283,10 +276,17 @@ protected void markAlreadyCompleted() { // When the operation is already completed in a replay, we complete completionFuture immediately // so that the `get` method will be unblocked and the context thread will be registered logger.trace("In markAlreadyCompleted, completing operation: {} ({}).", getOperationId(), completionFuture); + markCompletionFutureCompleted(); + } + private void markCompletionFutureCompleted() { // It's important that we synchronize access to the future, otherwise the processing could happen // on someone else's thread and cause a race condition. synchronized (completionFuture) { + // Completing the future here will also run any other completion stages that have been attached + // to the future. In our case, other contexts may have attached a function to reactivate themselves, + // so they will definitely have a chance to reactivate before we finish completing and deactivating + // whatever operations were just checkpointed. completionFuture.complete(this); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java index ed93c8515..09a287726 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java @@ -5,6 +5,7 @@ import static software.amazon.lambda.durable.execution.ExecutionManager.isTerminalStatus; import java.nio.charset.StandardCharsets; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import software.amazon.awssdk.services.lambda.model.ContextOptions; import software.amazon.awssdk.services.lambda.model.ErrorObject; @@ -46,7 +47,7 @@ public class ChildContextOperation extends SerializableDurableOperation { private final Function function; private final ConcurrencyOperation parentOperation; - private boolean replayChildContext; + private final AtomicBoolean replayChildren = new AtomicBoolean(false); private T reconstructedResult; public ChildContextOperation( @@ -86,7 +87,7 @@ protected void replay(Operation existing) { if (existing.contextDetails() != null && Boolean.TRUE.equals(existing.contextDetails().replayChildren())) { // Large result: re-execute child context to reconstruct result - replayChildContext = true; + replayChildren.set(true); executeChildContext(); } else { markAlreadyCompleted(); @@ -100,11 +101,6 @@ protected void replay(Operation existing) { } } - @Override - protected void markAlreadyCompleted() { - super.markAlreadyCompleted(); - } - private void executeChildContext() { // The operationId is already globally unique (prefixed by parent context path via // DurableContext.nextOperationId), so we use it directly as the contextId. @@ -137,7 +133,7 @@ private void executeChildContext() { } private void handleChildContextSuccess(T result) { - if (replayChildContext) { + if (replayChildren.get()) { // Replaying a SUCCEEDED child with replayChildren=true — skip checkpointing. // Mark the completableFuture completed so get() doesn't block waiting for a checkpoint response. this.reconstructedResult = result; @@ -151,8 +147,6 @@ private void checkpointSuccess(T result) { // Skip checkpointing if parent ConcurrencyOperation has already completed — // prevents race conditions where a child finishes after the parent has already completed. if (parentOperation != null && parentOperation.isOperationCompleted()) { - this.reconstructedResult = result; - markAlreadyCompleted(); return; } @@ -187,7 +181,6 @@ private void handleChildContextFailure(Throwable exception) { // Skip checkpointing if parent ConcurrencyOperation has already completed — // prevents race conditions where a child finishes after the parent has already succeeded. if (parentOperation != null && parentOperation.isOperationCompleted()) { - markAlreadyCompleted(); return; } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java index d02dfa438..db6592515 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java @@ -174,6 +174,7 @@ protected void executeItems() { next.execute(); } var child = waitForChildCompletion(succeededCount, failedCount, runningChildren); + // child may be null if the consumer thread is woken up due to a new item being added if (child != null) { if (runningChildren.contains(child)) { runningChildren.remove(child); @@ -213,6 +214,7 @@ private BaseDurableOperation waitForChildCompletion( .map(BaseDurableOperation::getCompletionFuture) .toList()); if (futures.size() < maxConcurrency) { + // add a future to listen to the new items if there is a vacancy consumerThreadListener.compareAndSet(null, new CompletableFuture<>()); futures.add(consumerThreadListener.get()); } From a3cfd56b9be8528e063fed5d880fa66e904e6033 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sun, 22 Mar 2026 19:59:12 -0700 Subject: [PATCH 13/15] remove Noop serdes --- .../durable/operation/WaitOperation.java | 1 + .../lambda/durable/serde/NoopSerDes.java | 18 ------------------ 2 files changed, 1 insertion(+), 18 deletions(-) delete mode 100644 sdk/src/main/java/software/amazon/lambda/durable/serde/NoopSerDes.java diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitOperation.java index a51521071..f4e3f4b57 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitOperation.java @@ -77,6 +77,7 @@ private void pollForWaitExpiration() { pollForOperationUpdates(remainingWaitTime); } + @Override public Void get() { waitForOperationCompletion(); diff --git a/sdk/src/main/java/software/amazon/lambda/durable/serde/NoopSerDes.java b/sdk/src/main/java/software/amazon/lambda/durable/serde/NoopSerDes.java deleted file mode 100644 index 245fed925..000000000 --- a/sdk/src/main/java/software/amazon/lambda/durable/serde/NoopSerDes.java +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable.serde; - -import software.amazon.lambda.durable.TypeToken; - -/** A {@link SerDes} implementation that does nothing. Used as a placeholder when no serialization is required. */ -public class NoopSerDes implements SerDes { - @Override - public String serialize(Object value) { - return ""; - } - - @Override - public T deserialize(String data, TypeToken typeToken) { - return null; - } -} From 426dabc2ffe49460eec812d49e2ef4ba30fabfd5 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sun, 22 Mar 2026 20:22:44 -0700 Subject: [PATCH 14/15] add anyOf --- .../lambda/durable/examples/WaitExample.java | 19 ++++++++++++++----- .../amazon/lambda/durable/DurableFuture.java | 15 +++++++++++++++ .../operation/BaseDurableOperation.java | 2 +- 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/WaitExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/WaitExample.java index 7385fd48a..b4060f29a 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/WaitExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/WaitExample.java @@ -4,6 +4,7 @@ import java.time.Duration; import software.amazon.lambda.durable.DurableContext; +import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; /** @@ -31,11 +32,19 @@ public String handleRequest(GreetingRequest input, DurableContext context) { context.wait(null, Duration.ofSeconds(10)); // Step 2: Continue processing - var continued = - context.step("continue-processing", String.class, stepCtx -> started + " - continued after 10s"); - - // Wait 5 seconds - context.wait(null, Duration.ofSeconds(5)); + var continued = context.stepAsync("continue-processing", String.class, stepCtx -> { + try { + Thread.sleep(10000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return started + " - continued after 10s"; + }); + + // Wait at most seconds + var wait5seconds = context.waitAsync(null, Duration.ofSeconds(5)); + + DurableFuture.anyOf(continued, wait5seconds); // Step 3: Complete var result = diff --git a/sdk/src/main/java/software/amazon/lambda/durable/DurableFuture.java b/sdk/src/main/java/software/amazon/lambda/durable/DurableFuture.java index 392c123b2..ba70d5fd1 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/DurableFuture.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/DurableFuture.java @@ -4,6 +4,8 @@ import java.util.Arrays; import java.util.List; +import java.util.concurrent.CompletableFuture; +import software.amazon.lambda.durable.operation.BaseDurableOperation; /** * A future representing the result of an asynchronous durable operation. @@ -52,4 +54,17 @@ static List allOf(DurableFuture... futures) { static List allOf(List> futures) { return futures.stream().map(DurableFuture::get).toList(); } + + /** + * Waits for any of the provided futures to complete and returns its result. + * + * @param futures the futures to wait for + * @return the result of the first future to complete + */ + static Object anyOf(DurableFuture... futures) { + return CompletableFuture.anyOf(Arrays.stream(futures) + .map(f -> ((BaseDurableOperation) f).getCompletionFuture()) + .toArray(CompletableFuture[]::new)) + .join(); + } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java index a99f54fd2..30cb67611 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java @@ -67,7 +67,7 @@ protected BaseDurableOperation(OperationIdentifier operationIdentifier, DurableC executionManager.registerOperation(this); } - protected CompletableFuture getCompletionFuture() { + public CompletableFuture getCompletionFuture() { return completionFuture; } From c3dc73ddc694b4901ae93a0ea9636bf9b660fef7 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sun, 22 Mar 2026 20:58:56 -0700 Subject: [PATCH 15/15] add anyOf tests --- .../amazon/lambda/durable/examples/WaitExample.java | 11 +++++++---- .../durable/examples/CloudBasedIntegrationTest.java | 3 ++- .../software/amazon/lambda/durable/DurableFuture.java | 4 +++- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/WaitExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/WaitExample.java index b4060f29a..988262f4b 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/WaitExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/WaitExample.java @@ -42,13 +42,16 @@ public String handleRequest(GreetingRequest input, DurableContext context) { }); // Wait at most seconds - var wait5seconds = context.waitAsync(null, Duration.ofSeconds(5)); + var wait5seconds = context.runInChildContextAsync("wait-5-seconds", String.class, ctx -> { + ctx.wait("wait-5-seconds", Duration.ofSeconds(5)); - DurableFuture.anyOf(continued, wait5seconds); + return started + " - waited 5 seconds"; + }); + + var step2 = DurableFuture.anyOf(continued, wait5seconds); // Step 3: Complete - var result = - context.step("complete-processing", String.class, stepCtx -> continued + " - completed after 5s more"); + var result = context.step("complete-processing", String.class, stepCtx -> step2 + " - completed after 5s more"); return result; } diff --git a/examples/src/test/java/software/amazon/lambda/durable/examples/CloudBasedIntegrationTest.java b/examples/src/test/java/software/amazon/lambda/durable/examples/CloudBasedIntegrationTest.java index c0e703cc4..639fa23cf 100644 --- a/examples/src/test/java/software/amazon/lambda/durable/examples/CloudBasedIntegrationTest.java +++ b/examples/src/test/java/software/amazon/lambda/durable/examples/CloudBasedIntegrationTest.java @@ -171,7 +171,8 @@ void testWaitExample() { var finalResult = result.getResult(String.class); assertNotNull(finalResult); assertTrue(finalResult.contains("Started processing for TestUser")); - assertTrue(finalResult.contains("continued after 10s")); + assertFalse(finalResult.contains("continued after 10s")); + assertTrue(finalResult.contains("waited 5 seconds")); assertTrue(finalResult.contains("completed after 5s more")); assertNotNull(runner.getOperation("start-processing")); diff --git a/sdk/src/main/java/software/amazon/lambda/durable/DurableFuture.java b/sdk/src/main/java/software/amazon/lambda/durable/DurableFuture.java index ba70d5fd1..51e7163ef 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/DurableFuture.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/DurableFuture.java @@ -65,6 +65,8 @@ static Object anyOf(DurableFuture... futures) { return CompletableFuture.anyOf(Arrays.stream(futures) .map(f -> ((BaseDurableOperation) f).getCompletionFuture()) .toArray(CompletableFuture[]::new)) - .join(); + .thenApply(o -> (DurableFuture) o) + .join() + .get(); } }