diff --git a/docs/spec/map.md b/docs/spec/map.md index 87f549e54..ee63796f7 100644 --- a/docs/spec/map.md +++ b/docs/spec/map.md @@ -290,7 +290,7 @@ import software.amazon.awssdk.services.lambda.model.ContextOptions; import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationAction; import software.amazon.awssdk.services.lambda.model.OperationUpdate; -import software.amazon.lambda.durable.CompletionConfig; +import software.amazon.lambda.durable.config.CompletionConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.model.CompletionReason; @@ -480,8 +480,8 @@ import java.util.Collections; import java.util.List; import java.util.function.Function; import software.amazon.lambda.durable.DurableContext; -import software.amazon.lambda.durable.MapConfig; -import software.amazon.lambda.durable.MapFunction; +import software.amazon.lambda.durable.config.MapConfig; +import software.amazon.lambda.durable.DurableContext.MapFunction; import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.model.BatchResult; import software.amazon.lambda.durable.model.OperationSubType; @@ -900,8 +900,8 @@ package software.amazon.lambda.durable.operation; import java.util.List; import software.amazon.lambda.durable.DurableContext; -import software.amazon.lambda.durable.MapConfig; -import software.amazon.lambda.durable.MapFunction; +import software.amazon.lambda.durable.config.MapConfig; +import software.amazon.lambda.durable.DurableContext.MapFunction; import software.amazon.lambda.durable.model.BatchResult; import software.amazon.lambda.durable.model.OperationSubType; diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/CallbackExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/CallbackExample.java index 31c9b784a..29697f277 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/CallbackExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/CallbackExample.java @@ -3,9 +3,9 @@ package software.amazon.lambda.durable.examples; import java.time.Duration; -import software.amazon.lambda.durable.CallbackConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableHandler; +import software.amazon.lambda.durable.config.CallbackConfig; /** * Example demonstrating callback operations for external system integration. diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/ComplexMapExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/ComplexMapExample.java index e7ba2726d..ce010e575 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/ComplexMapExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/ComplexMapExample.java @@ -5,10 +5,10 @@ import java.time.Duration; import java.util.List; import java.util.stream.Collectors; -import software.amazon.lambda.durable.CompletionConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.MapConfig; +import software.amazon.lambda.durable.config.CompletionConfig; +import software.amazon.lambda.durable.config.MapConfig; /** * Example demonstrating advanced map features: wait operations inside branches, error handling, and early termination. diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/CustomPollingExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/CustomPollingExample.java index 9aa0c68ca..aab698e97 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/CustomPollingExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/CustomPollingExample.java @@ -6,7 +6,7 @@ import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.InvokeConfig; +import software.amazon.lambda.durable.config.InvokeConfig; import software.amazon.lambda.durable.retry.JitterStrategy; import software.amazon.lambda.durable.retry.PollingStrategies; diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/ErrorHandlingExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/ErrorHandlingExample.java index e74a60b9b..cc58ef42e 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/ErrorHandlingExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/ErrorHandlingExample.java @@ -6,8 +6,8 @@ import org.slf4j.LoggerFactory; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.StepConfig; -import software.amazon.lambda.durable.StepSemantics; +import software.amazon.lambda.durable.config.StepConfig; +import software.amazon.lambda.durable.config.StepSemantics; import software.amazon.lambda.durable.exception.StepFailedException; import software.amazon.lambda.durable.exception.StepInterruptedException; import software.amazon.lambda.durable.retry.RetryStrategies; diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/GenericInputOutputExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/GenericInputOutputExample.java index 47bacf9c2..e996b4835 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/GenericInputOutputExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/GenericInputOutputExample.java @@ -9,8 +9,8 @@ import org.slf4j.LoggerFactory; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.StepConfig; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.retry.RetryStrategies; /** diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/GenericTypesExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/GenericTypesExample.java index b4edd512e..1b10f5e55 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/GenericTypesExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/GenericTypesExample.java @@ -9,8 +9,8 @@ import org.slf4j.LoggerFactory; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.StepConfig; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.retry.RetryStrategies; /** diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/RetryInProcessExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/RetryInProcessExample.java index 79644bed6..ee5a8e3de 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/RetryInProcessExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/RetryInProcessExample.java @@ -9,7 +9,7 @@ import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.StepConfig; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.retry.JitterStrategy; import software.amazon.lambda.durable.retry.RetryStrategies; diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/SimpleInvokeExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/SimpleInvokeExample.java index 4c976de64..e54f5418b 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/SimpleInvokeExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/SimpleInvokeExample.java @@ -4,7 +4,7 @@ import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.InvokeConfig; +import software.amazon.lambda.durable.config.InvokeConfig; /** * Simple example demonstrating basic invoke execution with the Durable Execution SDK. diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/WaitAtLeastExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/WaitAtLeastExample.java index 22d2201ae..e08b55f99 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/WaitAtLeastExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/WaitAtLeastExample.java @@ -8,7 +8,7 @@ import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.StepConfig; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.retry.RetryStrategies; /** diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/WaitAtLeastInProcessExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/WaitAtLeastInProcessExample.java index a8454eaae..f6103c4c6 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/WaitAtLeastInProcessExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/WaitAtLeastInProcessExample.java @@ -8,7 +8,7 @@ import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.StepConfig; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.retry.RetryStrategies; /** diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/WaitExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/WaitExample.java index 7385fd48a..988262f4b 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/WaitExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/WaitExample.java @@ -4,6 +4,7 @@ import java.time.Duration; import software.amazon.lambda.durable.DurableContext; +import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; /** @@ -31,15 +32,26 @@ public String handleRequest(GreetingRequest input, DurableContext context) { context.wait(null, Duration.ofSeconds(10)); // Step 2: Continue processing - var continued = - context.step("continue-processing", String.class, stepCtx -> started + " - continued after 10s"); + var continued = context.stepAsync("continue-processing", String.class, stepCtx -> { + try { + Thread.sleep(10000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return started + " - continued after 10s"; + }); - // Wait 5 seconds - context.wait(null, Duration.ofSeconds(5)); + // Wait at most seconds + var wait5seconds = context.runInChildContextAsync("wait-5-seconds", String.class, ctx -> { + ctx.wait("wait-5-seconds", Duration.ofSeconds(5)); + + return started + " - waited 5 seconds"; + }); + + var step2 = DurableFuture.anyOf(continued, wait5seconds); // Step 3: Complete - var result = - context.step("complete-processing", String.class, stepCtx -> continued + " - completed after 5s more"); + var result = context.step("complete-processing", String.class, stepCtx -> step2 + " - completed after 5s more"); return result; } diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java index a13005c7b..35c91f37c 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java @@ -7,7 +7,8 @@ import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.ParallelConfig; +import software.amazon.lambda.durable.ParallelDurableFuture; +import software.amazon.lambda.durable.config.ParallelConfig; import software.amazon.lambda.durable.model.ParallelResult; /** @@ -21,8 +22,8 @@ *
  • A final step combines the results into a summary * * - *

    The {@link software.amazon.lambda.durable.ParallelContext} implements {@link AutoCloseable}, so try-with-resources - * guarantees {@code join()} is called even if an exception occurs. + *

    The {@link ParallelDurableFuture} implements {@link AutoCloseable}, so try-with-resources guarantees + * {@code join()} is called even if an exception occurs. */ public class ParallelExample extends DurableHandler { @@ -54,9 +55,9 @@ public Output handleRequest(Input input, DurableContext context) { ParallelResult parallelResult = parallel.get(); logger.info( "Parallel complete: total={}, succeeded={}, failed={}", - parallelResult.getTotalBranches(), - parallelResult.getSucceededBranches(), - parallelResult.getFailedBranches()); + parallelResult.size(), + parallelResult.succeeded(), + parallelResult.failed()); var results = futures.stream().map(DurableFuture::get).toList(); diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java index b498db8de..8b7c97bba 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java @@ -7,8 +7,9 @@ import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.ParallelConfig; -import software.amazon.lambda.durable.StepConfig; +import software.amazon.lambda.durable.config.CompletionConfig; +import software.amazon.lambda.durable.config.ParallelConfig; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.model.ParallelResult; import software.amazon.lambda.durable.retry.RetryStrategies; @@ -25,7 +26,7 @@ public class ParallelFailureToleranceExample extends DurableHandler { - public record Input(List services, int toleratedFailures, int minSuccessful) {} + public record Input(List services, Integer toleratedFailures, Integer minSuccessful) {} public record Output(int succeeded, int failed) {} @@ -35,8 +36,7 @@ public Output handleRequest(Input input, DurableContext context) { logger.info("Starting parallel execution with toleratedFailureCount={}", input.toleratedFailures()); var config = ParallelConfig.builder() - .minSuccessful(input.minSuccessful()) - .toleratedFailureCount(input.toleratedFailures()) + .completionConfig(new CompletionConfig(input.minSuccessful, input.toleratedFailures, null)) .build(); var futures = new ArrayList>(input.services().size()); @@ -65,12 +65,12 @@ public Output handleRequest(Input input, DurableContext context) { ParallelResult parallelResult = parallel.get(); logger.info( "Parallel complete: succeeded={}, failed={}, status={}", - parallelResult.getSucceededBranches(), - parallelResult.getFailedBranches(), - parallelResult.getCompletionStatus().isSucceeded() ? "succeeded" : "failed"); + parallelResult.succeeded(), + parallelResult.failed(), + parallelResult.completionStatus().isSucceeded() ? "succeeded" : "failed"); - var succeeded = parallelResult.getSucceededBranches(); - var failed = parallelResult.getFailedBranches(); + var succeeded = parallelResult.succeeded(); + var failed = parallelResult.failed(); logger.info("Completed: {} succeeded, {} failed", succeeded, failed); return new Output(succeeded, failed); diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java index 63bbee21d..a7c0e3ac2 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java @@ -8,7 +8,7 @@ import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.ParallelConfig; +import software.amazon.lambda.durable.config.ParallelConfig; import software.amazon.lambda.durable.model.ParallelResult; /** @@ -68,6 +68,6 @@ public Output handleRequest(Input input, DurableContext context) { logger.info("All {} notifications delivered", deliveries.size()); // Test replay context.wait("wait for finalization", Duration.ofSeconds(5)); - return new Output(deliveries, result.getSucceededBranches(), result.getFailedBranches()); + return new Output(deliveries, result.succeeded(), result.failed()); } } diff --git a/examples/src/test/java/software/amazon/lambda/durable/examples/CloudBasedIntegrationTest.java b/examples/src/test/java/software/amazon/lambda/durable/examples/CloudBasedIntegrationTest.java index c0e703cc4..639fa23cf 100644 --- a/examples/src/test/java/software/amazon/lambda/durable/examples/CloudBasedIntegrationTest.java +++ b/examples/src/test/java/software/amazon/lambda/durable/examples/CloudBasedIntegrationTest.java @@ -171,7 +171,8 @@ void testWaitExample() { var finalResult = result.getResult(String.class); assertNotNull(finalResult); assertTrue(finalResult.contains("Started processing for TestUser")); - assertTrue(finalResult.contains("continued after 10s")); + assertFalse(finalResult.contains("continued after 10s")); + assertTrue(finalResult.contains("waited 5 seconds")); assertTrue(finalResult.contains("completed after 5s more")); assertNotNull(runner.getOperation("start-processing")); diff --git a/examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExampleTest.java b/examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExampleTest.java index 7d4dd72d1..f1518e2ac 100644 --- a/examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExampleTest.java +++ b/examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExampleTest.java @@ -17,7 +17,7 @@ void succeedsWhenFailuresAreWithinTolerance() { var runner = LocalDurableTestRunner.create(ParallelFailureToleranceExample.Input.class, handler); // 2 good services, 1 bad — toleratedFailureCount=1 so the parallel op still succeeds - var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "bad-svc-b", "svc-c"), 1, -1); + var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "bad-svc-b", "svc-c"), 1, null); var result = runner.runUntilComplete(input); assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); @@ -32,7 +32,7 @@ void succeedsWhenAllBranchesSucceed() { var handler = new ParallelFailureToleranceExample(); var runner = LocalDurableTestRunner.create(ParallelFailureToleranceExample.Input.class, handler); - var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "svc-b", "svc-c"), 2, -1); + var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "svc-b", "svc-c"), 2, null); var result = runner.runUntilComplete(input); assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CallbackIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CallbackIntegrationTest.java index 8ab42779f..b253425ee 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CallbackIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CallbackIntegrationTest.java @@ -10,6 +10,7 @@ import software.amazon.awssdk.services.lambda.model.ErrorObject; import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; +import software.amazon.lambda.durable.config.CallbackConfig; import software.amazon.lambda.durable.exception.CallbackFailedException; import software.amazon.lambda.durable.exception.CallbackTimeoutException; import software.amazon.lambda.durable.model.ExecutionStatus; diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CustomConfigIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CustomConfigIntegrationTest.java index a6dd8a78a..e6f02ed75 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CustomConfigIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CustomConfigIntegrationTest.java @@ -11,6 +11,7 @@ import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.lambda.LambdaClient; import software.amazon.lambda.durable.client.LambdaDurableFunctionsClient; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.serde.JacksonSerDes; import software.amazon.lambda.durable.serde.SerDes; import software.amazon.lambda.durable.testing.LocalDurableTestRunner; diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CustomSerDesIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CustomSerDesIntegrationTest.java index 6c966e214..336c81f5c 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CustomSerDesIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/CustomSerDesIntegrationTest.java @@ -7,6 +7,7 @@ import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.serde.JacksonSerDes; import software.amazon.lambda.durable.serde.SerDes; import software.amazon.lambda.durable.testing.LocalDurableTestRunner; diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ExceptionIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ExceptionIntegrationTest.java index 8dd811dd9..5445a4f3d 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ExceptionIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ExceptionIntegrationTest.java @@ -8,6 +8,8 @@ import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; +import software.amazon.lambda.durable.config.StepConfig; +import software.amazon.lambda.durable.config.StepSemantics; import software.amazon.lambda.durable.exception.StepInterruptedException; import software.amazon.lambda.durable.model.ExecutionStatus; import software.amazon.lambda.durable.retry.RetryStrategies; diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java index 3bbd8c569..7788cc4a1 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java @@ -9,9 +9,11 @@ import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; +import software.amazon.lambda.durable.config.CompletionConfig; +import software.amazon.lambda.durable.config.MapConfig; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; import software.amazon.lambda.durable.model.ExecutionStatus; -import software.amazon.lambda.durable.model.MapResultItem; +import software.amazon.lambda.durable.model.MapResult; import software.amazon.lambda.durable.testing.LocalDurableTestRunner; class MapIntegrationTest { @@ -535,8 +537,10 @@ void testMapWithAllSuccessfulCompletionConfig_stopsOnFirstFailure() { assertEquals("OK1", result.getResult(0)); assertNotNull(result.getError(1)); // Items after the failure should be NOT_STARTED - assertEquals(MapResultItem.Status.NOT_STARTED, result.getItem(2).status()); - assertEquals(MapResultItem.Status.NOT_STARTED, result.getItem(3).status()); + assertEquals( + MapResult.MapResultItem.Status.SKIPPED, result.getItem(2).status()); + assertEquals( + MapResult.MapResultItem.Status.SKIPPED, result.getItem(3).status()); return "done"; }); @@ -873,7 +877,9 @@ void testMapWithNullResults() { assertTrue(result.allSucceeded()); assertEquals(3, result.size()); for (int i = 0; i < result.size(); i++) { - assertEquals(MapResultItem.Status.SUCCEEDED, result.getItem(i).status()); + assertEquals( + MapResult.MapResultItem.Status.SUCCEEDED, + result.getItem(i).status()); assertNull(result.getResult(i)); assertNull(result.getError(i)); } diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/StepSemanticsIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/StepSemanticsIntegrationTest.java index 34f9f87be..1b3b67409 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/StepSemanticsIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/StepSemanticsIntegrationTest.java @@ -6,6 +6,8 @@ import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; +import software.amazon.lambda.durable.config.StepConfig; +import software.amazon.lambda.durable.config.StepSemantics; import software.amazon.lambda.durable.model.ExecutionStatus; import software.amazon.lambda.durable.retry.RetryStrategies; import software.amazon.lambda.durable.testing.LocalDurableTestRunner; diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/WaitForConditionIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/WaitForConditionIntegrationTest.java index 105dd6a88..689c25b9f 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/WaitForConditionIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/WaitForConditionIntegrationTest.java @@ -10,6 +10,7 @@ import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.RepeatedTest; import org.junit.jupiter.api.Test; +import software.amazon.lambda.durable.config.WaitForConditionConfig; import software.amazon.lambda.durable.model.ExecutionStatus; import software.amazon.lambda.durable.model.WaitForConditionResult; import software.amazon.lambda.durable.retry.JitterStrategy; diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/retry/RetryIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/retry/RetryIntegrationTest.java index 2debf140f..259d8c653 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/retry/RetryIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/retry/RetryIntegrationTest.java @@ -9,7 +9,7 @@ import org.junit.jupiter.api.Test; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableHandler; -import software.amazon.lambda.durable.StepConfig; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.model.ExecutionStatus; import software.amazon.lambda.durable.testing.LocalDurableTestRunner; diff --git a/sdk-testing/src/test/java/software/amazon/lambda/durable/testing/SkipTimeTest.java b/sdk-testing/src/test/java/software/amazon/lambda/durable/testing/SkipTimeTest.java index cdaf87f3c..8fa2f591d 100644 --- a/sdk-testing/src/test/java/software/amazon/lambda/durable/testing/SkipTimeTest.java +++ b/sdk-testing/src/test/java/software/amazon/lambda/durable/testing/SkipTimeTest.java @@ -7,7 +7,7 @@ import java.time.Duration; import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; -import software.amazon.lambda.durable.StepConfig; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.model.ExecutionStatus; import software.amazon.lambda.durable.retry.RetryStrategies; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/CallbackConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/CallbackConfig.java index 9cb859aee..0d63a5c29 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/CallbackConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/CallbackConfig.java @@ -2,109 +2,12 @@ // SPDX-License-Identifier: Apache-2.0 package software.amazon.lambda.durable; -import java.time.Duration; -import software.amazon.lambda.durable.serde.SerDes; -import software.amazon.lambda.durable.util.ParameterValidator; - -/** Configuration for callback operations. */ +/** @deprecated use {@link software.amazon.lambda.durable.config.CallbackConfig} instead. */ +@Deprecated public class CallbackConfig { - private final Duration timeout; - private final Duration heartbeatTimeout; - private final SerDes serDes; - - private CallbackConfig(Builder builder) { - this.timeout = builder.timeout; - this.heartbeatTimeout = builder.heartbeatTimeout; - this.serDes = builder.serDes; - } - - /** - * Returns the maximum duration to wait for the callback to complete. - * - * @return the timeout duration, or null if not specified - */ - public Duration timeout() { - return timeout; - } - - /** - * Returns the maximum duration between heartbeats before the callback is considered failed. - * - * @return the heartbeat timeout duration, or null if not specified - */ - public Duration heartbeatTimeout() { - return heartbeatTimeout; - } - - /** Returns the custom serializer for this callback, or null if not specified (uses default SerDes). */ - public SerDes serDes() { - return serDes; - } - - /** Creates a new builder with default values. */ - public static Builder builder() { - return new Builder(null, null, null); - } - - /** Creates a new builder pre-populated with this config's values. */ - public Builder toBuilder() { - return new Builder(timeout, heartbeatTimeout, serDes); - } - - /** Builder for {@link CallbackConfig}. */ - public static class Builder { - private Duration timeout; - private Duration heartbeatTimeout; - private SerDes serDes; - - private Builder(Duration timeout, Duration heartbeatTimeout, SerDes serDes) { - this.timeout = timeout; - this.heartbeatTimeout = heartbeatTimeout; - this.serDes = serDes; - } - - /** - * Sets the maximum duration to wait for the callback to complete before timing out. - * - * @param timeout the timeout duration - * @return this builder for method chaining - */ - public Builder timeout(Duration timeout) { - ParameterValidator.validateOptionalDuration(timeout, "Callback timeout"); - this.timeout = timeout; - return this; - } - - /** - * Sets the maximum duration between heartbeats before the callback is considered failed. - * - * @param heartbeatTimeout the heartbeat timeout duration - * @return this builder for method chaining - */ - public Builder heartbeatTimeout(Duration heartbeatTimeout) { - ParameterValidator.validateOptionalDuration(heartbeatTimeout, "Heartbeat timeout"); - this.heartbeatTimeout = heartbeatTimeout; - return this; - } - - /** - * Sets a custom serializer for the callback. - * - *

    If not specified, the callback will use the default SerDes configured for the handler. This allows - * per-callback customization of serialization behavior, useful for callbacks that need special handling (e.g., - * custom date formats, encryption, compression). - * - * @param serDes the custom serializer to use, or null to use the default - * @return this builder for method chaining - */ - public Builder serDes(SerDes serDes) { - this.serDes = serDes; - return this; - } - - /** Builds the {@link CallbackConfig} instance. */ - public CallbackConfig build() { - return new CallbackConfig(this); - } + /** @deprecated use {@link software.amazon.lambda.durable.config.CallbackConfig#builder()} instead. */ + @Deprecated + public static software.amazon.lambda.durable.config.CallbackConfig.Builder builder() { + return new software.amazon.lambda.durable.config.CallbackConfig.Builder(null, null, null); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java index e97843360..b83c7598c 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java @@ -267,7 +267,7 @@ public static final class Builder { private PollingStrategy pollingStrategy; private Duration checkpointDelay; - private Builder() {} + public Builder() {} /** * Sets a custom LambdaClient for production use. Use this method to customize the AWS SDK client with specific diff --git a/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java b/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java index 79ebda106..2f652232f 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java @@ -8,6 +8,14 @@ import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; +import software.amazon.lambda.durable.config.CallbackConfig; +import software.amazon.lambda.durable.config.InvokeConfig; +import software.amazon.lambda.durable.config.MapConfig; +import software.amazon.lambda.durable.config.ParallelConfig; +import software.amazon.lambda.durable.config.RunInChildContextConfig; +import software.amazon.lambda.durable.config.StepConfig; +import software.amazon.lambda.durable.config.WaitForCallbackConfig; +import software.amazon.lambda.durable.config.WaitForConditionConfig; import software.amazon.lambda.durable.context.BaseContext; import software.amazon.lambda.durable.model.MapResult; import software.amazon.lambda.durable.model.WaitForConditionResult; @@ -25,7 +33,9 @@ public interface DurableContext extends BaseContext { * @param func the function to execute, receiving a {@link StepContext} * @return the step result */ - T step(String name, Class resultType, Function func); + default T step(String name, Class resultType, Function func) { + return step(name, TypeToken.get(resultType), func, StepConfig.builder().build()); + } /** * Executes a durable step with the given name and configuration, blocking until it completes. @@ -37,7 +47,9 @@ public interface DurableContext extends BaseContext { * @param config the step configuration (retry strategy, semantics, custom SerDes) * @return the step result */ - T step(String name, Class resultType, Function func, StepConfig config); + default T step(String name, Class resultType, Function func, StepConfig config) { + return stepAsync(name, resultType, func, config).get(); + } /** * Executes a durable step using a {@link TypeToken} for generic result types, blocking until it completes. @@ -48,7 +60,9 @@ public interface DurableContext extends BaseContext { * @param func the function to execute, receiving a {@link StepContext} * @return the step result */ - T step(String name, TypeToken resultType, Function func); + default T step(String name, TypeToken resultType, Function func) { + return step(name, resultType, func, StepConfig.builder().build()); + } /** * Executes a durable step using a {@link TypeToken} and configuration, blocking until it completes. @@ -60,7 +74,9 @@ public interface DurableContext extends BaseContext { * @param config the step configuration (retry strategy, semantics, custom SerDes) * @return the step result */ - T step(String name, TypeToken resultType, Function func, StepConfig config); + default T step(String name, TypeToken resultType, Function func, StepConfig config) { + return stepAsync(name, resultType, func, config).get(); + } /** * Asynchronously executes a durable step, returning a {@link DurableFuture} that can be composed or blocked on. @@ -71,7 +87,10 @@ public interface DurableContext extends BaseContext { * @param func the function to execute, receiving a {@link StepContext} * @return a future representing the step result */ - DurableFuture stepAsync(String name, Class resultType, Function func); + default DurableFuture stepAsync(String name, Class resultType, Function func) { + return stepAsync( + name, TypeToken.get(resultType), func, StepConfig.builder().build()); + } /** * Asynchronously executes a durable step using custom configuration. @@ -84,7 +103,10 @@ public interface DurableContext extends BaseContext { * @param config the step configuration (retry strategy, semantics, custom SerDes) * @return a future representing the step result */ - DurableFuture stepAsync(String name, Class resultType, Function func, StepConfig config); + default DurableFuture stepAsync( + String name, Class resultType, Function func, StepConfig config) { + return stepAsync(name, TypeToken.get(resultType), func, config); + } /** * Asynchronously executes a durable step using a {@link TypeToken} for generic result types. @@ -97,7 +119,9 @@ public interface DurableContext extends BaseContext { * @param func the function to execute, receiving a {@link StepContext} * @return a future representing the step result */ - DurableFuture stepAsync(String name, TypeToken resultType, Function func); + default DurableFuture stepAsync(String name, TypeToken resultType, Function func) { + return stepAsync(name, resultType, func, StepConfig.builder().build()); + } /** * Asynchronously executes a durable step using a {@link TypeToken} and custom configuration. @@ -114,29 +138,60 @@ public interface DurableContext extends BaseContext { DurableFuture stepAsync( String name, TypeToken resultType, Function func, StepConfig config); + /** @deprecated use the variants accepting StepContext instead */ @Deprecated - T step(String name, Class resultType, Supplier func); - + default T step(String name, Class resultType, Supplier func) { + return stepAsync( + name, + TypeToken.get(resultType), + func, + StepConfig.builder().build()) + .get(); + } + + /** @deprecated use the variants accepting StepContext instead */ @Deprecated - T step(String name, Class resultType, Supplier func, StepConfig config); + default T step(String name, Class resultType, Supplier func, StepConfig config) { + // Simply delegate to stepAsync and block on the result + return stepAsync(name, TypeToken.get(resultType), func, config).get(); + } + /** @deprecated use the variants accepting StepContext instead */ @Deprecated - T step(String name, TypeToken resultType, Supplier func); + default T step(String name, TypeToken resultType, Supplier func) { + return stepAsync(name, resultType, func, StepConfig.builder().build()).get(); + } + /** @deprecated use the variants accepting StepContext instead */ @Deprecated - T step(String name, TypeToken resultType, Supplier func, StepConfig config); + default T step(String name, TypeToken resultType, Supplier func, StepConfig config) { + return stepAsync(name, resultType, func, config).get(); + } + /** @deprecated use the variants accepting StepContext instead */ @Deprecated - DurableFuture stepAsync(String name, Class resultType, Supplier func); + default DurableFuture stepAsync(String name, Class resultType, Supplier func) { + return stepAsync( + name, TypeToken.get(resultType), func, StepConfig.builder().build()); + } + /** @deprecated use the variants accepting StepContext instead */ @Deprecated - DurableFuture stepAsync(String name, Class resultType, Supplier func, StepConfig config); + default DurableFuture stepAsync(String name, Class resultType, Supplier func, StepConfig config) { + return stepAsync(name, TypeToken.get(resultType), func, config); + } + /** @deprecated use the variants accepting StepContext instead */ @Deprecated - DurableFuture stepAsync(String name, TypeToken resultType, Supplier func); + default DurableFuture stepAsync(String name, TypeToken resultType, Supplier func) { + return stepAsync(name, resultType, func, StepConfig.builder().build()); + } + /** @deprecated use the variants accepting StepContext instead */ @Deprecated - DurableFuture stepAsync(String name, TypeToken resultType, Supplier func, StepConfig config); + default DurableFuture stepAsync(String name, TypeToken resultType, Supplier func, StepConfig config) { + return stepAsync(name, resultType, stepContext -> func.get(), config); + } /** * Suspends execution for the specified duration without consuming compute resources. @@ -148,7 +203,9 @@ DurableFuture stepAsync( * @param duration the duration to wait * @return always {@code null} */ - Void wait(String name, Duration duration); + default Void wait(String name, Duration duration) { + return waitAsync(name, duration).get(); + } /** * Asynchronously suspends execution for the specified duration. @@ -173,26 +230,59 @@ DurableFuture stepAsync( * @param resultType the result class for deserialization * @return the invocation result */ - T invoke(String name, String functionName, U payload, Class resultType); + default T invoke(String name, String functionName, U payload, Class resultType) { + return invokeAsync( + name, + functionName, + payload, + TypeToken.get(resultType), + InvokeConfig.builder().build()) + .get(); + } /** Invokes another Lambda function with custom configuration, blocking until the result is available. */ - T invoke(String name, String functionName, U payload, Class resultType, InvokeConfig config); + default T invoke(String name, String functionName, U payload, Class resultType, InvokeConfig config) { + return invokeAsync(name, functionName, payload, TypeToken.get(resultType), config) + .get(); + } /** Invokes another Lambda function using a {@link TypeToken} for generic result types, blocking until complete. */ - T invoke(String name, String functionName, U payload, TypeToken resultType); + default T invoke(String name, String functionName, U payload, TypeToken resultType) { + return invokeAsync( + name, + functionName, + payload, + resultType, + InvokeConfig.builder().build()) + .get(); + } /** Invokes another Lambda function using a {@link TypeToken} and custom configuration, blocking until complete. */ - T invoke(String name, String functionName, U payload, TypeToken resultType, InvokeConfig config); + default T invoke(String name, String functionName, U payload, TypeToken resultType, InvokeConfig config) { + return invokeAsync(name, functionName, payload, resultType, config).get(); + } /** Invokes another Lambda function using a {@link TypeToken} and custom configuration, blocking until complete. */ - DurableFuture invokeAsync( - String name, String functionName, U payload, Class resultType, InvokeConfig config); + default DurableFuture invokeAsync( + String name, String functionName, U payload, Class resultType, InvokeConfig config) { + return invokeAsync(name, functionName, payload, TypeToken.get(resultType), config); + } /** Asynchronously invokes another Lambda function, returning a {@link DurableFuture}. */ - DurableFuture invokeAsync(String name, String functionName, U payload, Class resultType); + default DurableFuture invokeAsync(String name, String functionName, U payload, Class resultType) { + return invokeAsync( + name, + functionName, + payload, + TypeToken.get(resultType), + InvokeConfig.builder().build()); + } /** Asynchronously invokes another Lambda function using a {@link TypeToken} for generic result types. */ - DurableFuture invokeAsync(String name, String functionName, U payload, TypeToken resultType); + default DurableFuture invokeAsync(String name, String functionName, U payload, TypeToken resultType) { + return invokeAsync( + name, functionName, payload, resultType, InvokeConfig.builder().build()); + } /** * Asynchronously invokes another Lambda function using a {@link TypeToken} and custom configuration. @@ -212,13 +302,20 @@ DurableFuture invokeAsync( String name, String functionName, U payload, TypeToken resultType, InvokeConfig config); /** Creates a callback with custom configuration. */ - DurableCallbackFuture createCallback(String name, Class resultType, CallbackConfig config); + default DurableCallbackFuture createCallback(String name, Class resultType, CallbackConfig config) { + return createCallback(name, TypeToken.get(resultType), config); + } /** Creates a callback using a {@link TypeToken} for generic result types. */ - DurableCallbackFuture createCallback(String name, TypeToken resultType); + default DurableCallbackFuture createCallback(String name, TypeToken resultType) { + return createCallback(name, resultType, CallbackConfig.builder().build()); + } /** Creates a callback with default configuration. */ - DurableCallbackFuture createCallback(String name, Class resultType); + default DurableCallbackFuture createCallback(String name, Class resultType) { + return createCallback( + name, TypeToken.get(resultType), CallbackConfig.builder().build()); + } /** * Creates a callback operation that suspends execution until an external system completes it. @@ -246,48 +343,188 @@ DurableFuture invokeAsync( * @param func the function to execute, receiving a child {@link DurableContext} * @return the child context result */ - T runInChildContext(String name, Class resultType, Function func); + default T runInChildContext(String name, Class resultType, Function func) { + return runInChildContextAsync( + name, + TypeToken.get(resultType), + func, + RunInChildContextConfig.builder().build()) + .get(); + } /** * Runs a function in a child context using a {@link TypeToken} for generic result types, blocking until complete. */ - T runInChildContext(String name, TypeToken resultType, Function func); - - /** Asynchronously runs a function in a child context, returning a {@link DurableFuture}. */ - DurableFuture runInChildContextAsync(String name, Class resultType, Function func); + default T runInChildContext(String name, TypeToken resultType, Function func) { + return runInChildContextAsync( + name, + resultType, + func, + RunInChildContextConfig.builder().build()) + .get(); + } - /** Asynchronously runs a function in a child context using a {@link TypeToken} for generic result types. */ - DurableFuture runInChildContextAsync(String name, TypeToken resultType, Function func); - - MapResult map(String name, Collection items, Class resultType, MapFunction function); - - MapResult map( - String name, Collection items, Class resultType, MapFunction function, MapConfig config); + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @return the DurableFuture of the child context result + */ + default DurableFuture runInChildContextAsync( + String name, Class resultType, Function func) { + return runInChildContextAsync( + name, + TypeToken.get(resultType), + func, + RunInChildContextConfig.builder().build()); + } - MapResult map(String name, Collection items, TypeToken resultType, MapFunction function); + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @return the DurableFuture of the child context result + */ + default DurableFuture runInChildContextAsync( + String name, TypeToken resultType, Function func) { + return runInChildContextAsync( + name, resultType, func, RunInChildContextConfig.builder().build()); + } - MapResult map( - String name, Collection items, TypeToken resultType, MapFunction function, MapConfig config); + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param the result type + * @param name the unique operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @return the child context result + */ + default T runInChildContext( + String name, Class resultType, Function func, RunInChildContextConfig config) { + return runInChildContextAsync(name, TypeToken.get(resultType), func, config) + .get(); + } - DurableFuture> mapAsync( - String name, Collection items, Class resultType, MapFunction function); + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @param config the configuration for the child context + * @return the child context result + */ + default T runInChildContext( + String name, TypeToken resultType, Function func, RunInChildContextConfig config) { + return runInChildContextAsync(name, resultType, func, config).get(); + } - DurableFuture> mapAsync( - String name, Collection items, Class resultType, MapFunction function, MapConfig config); + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @param config the configuration for the child context + * @return the DurableFuture wrapping the child context result + */ + default DurableFuture runInChildContextAsync( + String name, Class resultType, Function func, RunInChildContextConfig config) { + return runInChildContextAsync(name, TypeToken.get(resultType), func, config); + } - DurableFuture> mapAsync( - String name, Collection items, TypeToken resultType, MapFunction function); + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @param config the configuration for the child context + * @return the DurableFuture wrapping the child context result + */ + DurableFuture runInChildContextAsync( + String name, TypeToken resultType, Function func, RunInChildContextConfig config); + + default MapResult map(String name, Collection items, Class resultType, MapFunction function) { + return mapAsync( + name, + items, + TypeToken.get(resultType), + function, + MapConfig.builder().build()) + .get(); + } + + default MapResult map( + String name, Collection items, Class resultType, MapFunction function, MapConfig config) { + return mapAsync(name, items, TypeToken.get(resultType), function, config) + .get(); + } + + default MapResult map( + String name, Collection items, TypeToken resultType, MapFunction function) { + return mapAsync(name, items, resultType, function, MapConfig.builder().build()) + .get(); + } + + default MapResult map( + String name, Collection items, TypeToken resultType, MapFunction function, MapConfig config) { + return mapAsync(name, items, resultType, function, config).get(); + } + + default DurableFuture> mapAsync( + String name, Collection items, Class resultType, MapFunction function) { + return mapAsync( + name, + items, + TypeToken.get(resultType), + function, + MapConfig.builder().build()); + } + + default DurableFuture> mapAsync( + String name, Collection items, Class resultType, MapFunction function, MapConfig config) { + return mapAsync(name, items, TypeToken.get(resultType), function, config); + } + + default DurableFuture> mapAsync( + String name, Collection items, TypeToken resultType, MapFunction function) { + return mapAsync(name, items, resultType, function, MapConfig.builder().build()); + } DurableFuture> mapAsync( String name, Collection items, TypeToken resultType, MapFunction function, MapConfig config); /** - * Creates a {@link ParallelContext} for executing multiple branches concurrently. + * Creates a {@link ParallelDurableFuture} for executing multiple branches concurrently. * * @param config the parallel execution configuration - * @return a new ParallelContext for registering and executing branches + * @return a new ParallelDurableFuture for registering and executing branches */ - ParallelContext parallel(String name, ParallelConfig config); + ParallelDurableFuture parallel(String name, ParallelConfig config); /** * Executes a submitter function and waits for an external callback, blocking until the callback completes. @@ -301,38 +538,67 @@ DurableFuture> mapAsync( * @param func the submitter function, receiving the callback ID and a {@link StepContext} * @return the callback result */ - T waitForCallback(String name, Class resultType, BiConsumer func); + default T waitForCallback(String name, Class resultType, BiConsumer func) { + return waitForCallbackAsync( + name, + TypeToken.get(resultType), + func, + WaitForCallbackConfig.builder().build()) + .get(); + } /** Executes a submitter and waits for an external callback using a {@link TypeToken}, blocking until complete. */ - T waitForCallback(String name, TypeToken resultType, BiConsumer func); + default T waitForCallback(String name, TypeToken resultType, BiConsumer func) { + return waitForCallbackAsync( + name, resultType, func, WaitForCallbackConfig.builder().build()) + .get(); + } /** Executes a submitter and waits for an external callback with custom configuration, blocking until complete. */ - T waitForCallback( + default T waitForCallback( String name, Class resultType, BiConsumer func, - WaitForCallbackConfig waitForCallbackConfig); + WaitForCallbackConfig waitForCallbackConfig) { + return waitForCallbackAsync(name, TypeToken.get(resultType), func, waitForCallbackConfig) + .get(); + } /** Executes a submitter and waits for an external callback using a {@link TypeToken} and custom configuration. */ - T waitForCallback( + default T waitForCallback( String name, TypeToken resultType, BiConsumer func, - WaitForCallbackConfig waitForCallbackConfig); + WaitForCallbackConfig waitForCallbackConfig) { + return waitForCallbackAsync(name, resultType, func, waitForCallbackConfig) + .get(); + } /** Asynchronously executes a submitter and waits for an external callback. */ - DurableFuture waitForCallbackAsync(String name, Class resultType, BiConsumer func); + default DurableFuture waitForCallbackAsync( + String name, Class resultType, BiConsumer func) { + return waitForCallbackAsync( + name, + TypeToken.get(resultType), + func, + WaitForCallbackConfig.builder().build()); + } /** Asynchronously executes a submitter and waits for an external callback using a {@link TypeToken}. */ - DurableFuture waitForCallbackAsync( - String name, TypeToken resultType, BiConsumer func); + default DurableFuture waitForCallbackAsync( + String name, TypeToken resultType, BiConsumer func) { + return waitForCallbackAsync( + name, resultType, func, WaitForCallbackConfig.builder().build()); + } /** Asynchronously executes a submitter and waits for an external callback with custom configuration. */ - DurableFuture waitForCallbackAsync( + default DurableFuture waitForCallbackAsync( String name, Class resultType, BiConsumer func, - WaitForCallbackConfig waitForCallbackConfig); + WaitForCallbackConfig waitForCallbackConfig) { + return waitForCallbackAsync(name, TypeToken.get(resultType), func, waitForCallbackConfig); + } /** * Asynchronously executes a submitter and waits for an external callback using a {@link TypeToken} and custom @@ -365,59 +631,97 @@ DurableFuture waitForCallbackAsync( * @param initialState the initial state passed to the first check invocation * @return the final state value when the condition is met */ - T waitForCondition( + default T waitForCondition( String name, Class resultType, BiFunction> checkFunc, - T initialState); + T initialState) { + return waitForConditionAsync( + name, + TypeToken.get(resultType), + checkFunc, + initialState, + WaitForConditionConfig.builder().build()) + .get(); + } /** Polls a condition function until it signals done, using a custom configuration, blocking until complete. */ - T waitForCondition( + default T waitForCondition( String name, Class resultType, BiFunction> checkFunc, T initialState, - WaitForConditionConfig config); + WaitForConditionConfig config) { + return waitForConditionAsync(name, resultType, checkFunc, initialState, config) + .get(); + } /** Polls a condition function until it signals done, using a {@link TypeToken}, blocking until complete. */ - T waitForCondition( + default T waitForCondition( String name, TypeToken resultType, BiFunction> checkFunc, - T initialState); + T initialState) { + return waitForConditionAsync( + name, + resultType, + checkFunc, + initialState, + WaitForConditionConfig.builder().build()) + .get(); + } /** * Polls a condition function until it signals done, using a {@link TypeToken} and custom configuration, blocking * until complete. */ - T waitForCondition( + default T waitForCondition( String name, TypeToken resultType, BiFunction> checkFunc, T initialState, - WaitForConditionConfig config); + WaitForConditionConfig config) { + return waitForConditionAsync(name, resultType, checkFunc, initialState, config) + .get(); + } /** Asynchronously polls a condition function until it signals done. */ - DurableFuture waitForConditionAsync( + default DurableFuture waitForConditionAsync( String name, Class resultType, BiFunction> checkFunc, - T initialState); + T initialState) { + return waitForConditionAsync( + name, + TypeToken.get(resultType), + checkFunc, + initialState, + WaitForConditionConfig.builder().build()); + } /** Asynchronously polls a condition function until it signals done, using custom configuration. */ - DurableFuture waitForConditionAsync( + default DurableFuture waitForConditionAsync( String name, Class resultType, BiFunction> checkFunc, T initialState, - WaitForConditionConfig config); + WaitForConditionConfig config) { + return waitForConditionAsync(name, TypeToken.get(resultType), checkFunc, initialState, config); + } /** Asynchronously polls a condition function until it signals done, using a {@link TypeToken}. */ - DurableFuture waitForConditionAsync( + default DurableFuture waitForConditionAsync( String name, TypeToken resultType, BiFunction> checkFunc, - T initialState); + T initialState) { + return waitForConditionAsync( + name, + resultType, + checkFunc, + initialState, + WaitForConditionConfig.builder().build()); + } /** * Asynchronously polls a condition function until it signals done, using a {@link TypeToken} and custom @@ -440,4 +744,28 @@ DurableFuture waitForConditionAsync( BiFunction> checkFunc, T initialState, WaitForConditionConfig config); + + /** + * Function applied to each item in a map operation. + * + *

    Each invocation receives its own {@link DurableContext}, allowing the use of durable operations like + * {@code step()} and {@code wait()} within the function body. The index parameter indicates the item's position in + * the input collection. + * + * @param the input item type + * @param the output result type + */ + @FunctionalInterface + interface MapFunction { + + /** + * Applies this function to the given item. + * + * @param item the input item to process + * @param index the zero-based index of the item in the input collection + * @param context the durable context for this item's execution + * @return the result of processing the item + */ + O apply(I item, int index, DurableContext context); + } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/DurableFuture.java b/sdk/src/main/java/software/amazon/lambda/durable/DurableFuture.java index 392c123b2..51e7163ef 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/DurableFuture.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/DurableFuture.java @@ -4,6 +4,8 @@ import java.util.Arrays; import java.util.List; +import java.util.concurrent.CompletableFuture; +import software.amazon.lambda.durable.operation.BaseDurableOperation; /** * A future representing the result of an asynchronous durable operation. @@ -52,4 +54,19 @@ static List allOf(DurableFuture... futures) { static List allOf(List> futures) { return futures.stream().map(DurableFuture::get).toList(); } + + /** + * Waits for any of the provided futures to complete and returns its result. + * + * @param futures the futures to wait for + * @return the result of the first future to complete + */ + static Object anyOf(DurableFuture... futures) { + return CompletableFuture.anyOf(Arrays.stream(futures) + .map(f -> ((BaseDurableOperation) f).getCompletionFuture()) + .toArray(CompletableFuture[]::new)) + .thenApply(o -> (DurableFuture) o) + .join() + .get(); + } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/MapFunction.java b/sdk/src/main/java/software/amazon/lambda/durable/MapFunction.java deleted file mode 100644 index 041dccfc9..000000000 --- a/sdk/src/main/java/software/amazon/lambda/durable/MapFunction.java +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; - -/** - * Function applied to each item in a map operation. - * - *

    Each invocation receives its own {@link DurableContext}, allowing the use of durable operations like - * {@code step()} and {@code wait()} within the function body. The index parameter indicates the item's position in the - * input collection. - * - * @param the input item type - * @param the output result type - */ -@FunctionalInterface -public interface MapFunction { - - /** - * Applies this function to the given item. - * - * @param item the input item to process - * @param index the zero-based index of the item in the input collection - * @param context the durable context for this item's execution - * @return the result of processing the item - */ - O apply(I item, int index, DurableContext context); -} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/ParallelConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/ParallelConfig.java deleted file mode 100644 index 67992801f..000000000 --- a/sdk/src/main/java/software/amazon/lambda/durable/ParallelConfig.java +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; - -/** - * Configuration options for parallel operations in durable executions. - * - *

    This class provides a builder pattern for configuring concurrency limits and completion semantics for parallel - * branch execution. - */ -public class ParallelConfig { - private final int maxConcurrency; - private final int minSuccessful; - private final int toleratedFailureCount; - - private ParallelConfig(Builder builder) { - this.maxConcurrency = builder.maxConcurrency; - this.minSuccessful = builder.minSuccessful; - this.toleratedFailureCount = builder.toleratedFailureCount; - } - - /** @return the maximum number of branches running simultaneously, or -1 for unlimited */ - public int maxConcurrency() { - return maxConcurrency; - } - - /** @return the minimum number of successful branches required, or -1 meaning all must succeed */ - public int minSuccessful() { - return minSuccessful; - } - - /** @return the maximum number of branch failures tolerated before stopping */ - public int toleratedFailureCount() { - return toleratedFailureCount; - } - - /** - * Creates a new builder for ParallelConfig. - * - * @return a new Builder instance - */ - public static Builder builder() { - return new Builder(); - } - - /** Builder for creating ParallelConfig instances. */ - public static class Builder { - private int maxConcurrency = -1; - private int minSuccessful = -1; - private int toleratedFailureCount = 0; - - private Builder() {} - - /** - * Sets the maximum number of branches that can run simultaneously. - * - * @param maxConcurrency the concurrency limit, or -1 for unlimited - * @return this builder for method chaining - */ - public Builder maxConcurrency(int maxConcurrency) { - this.maxConcurrency = maxConcurrency; - return this; - } - - /** - * Sets the minimum number of branches that must succeed for the parallel operation to complete successfully. - * - * @param minSuccessful the minimum successful count, or -1 meaning all branches must succeed - * @return this builder for method chaining - */ - public Builder minSuccessful(int minSuccessful) { - this.minSuccessful = minSuccessful; - return this; - } - - /** - * Sets the maximum number of branch failures tolerated before stopping execution. - * - * @param toleratedFailureCount the maximum tolerated failures - * @return this builder for method chaining - */ - public Builder toleratedFailureCount(int toleratedFailureCount) { - this.toleratedFailureCount = toleratedFailureCount; - return this; - } - - /** - * Builds the ParallelConfig instance. - * - * @return a new ParallelConfig with the configured options - * @throws IllegalArgumentException if any configuration values are invalid - */ - public ParallelConfig build() { - if (maxConcurrency != -1 && maxConcurrency <= 0) { - throw new IllegalArgumentException( - "maxConcurrency must be -1 (unlimited) or greater than 0, got: " + maxConcurrency); - } - if (minSuccessful < -1) { - throw new IllegalArgumentException("minSuccessful must be >= -1, got: " + minSuccessful); - } - if (toleratedFailureCount < 0) { - throw new IllegalArgumentException("toleratedFailureCount must be >= 0, got: " + toleratedFailureCount); - } - return new ParallelConfig(this); - } - } -} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/ParallelContext.java b/sdk/src/main/java/software/amazon/lambda/durable/ParallelContext.java deleted file mode 100644 index 6debb9754..000000000 --- a/sdk/src/main/java/software/amazon/lambda/durable/ParallelContext.java +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; - -import java.util.Objects; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; -import software.amazon.lambda.durable.model.ParallelResult; -import software.amazon.lambda.durable.operation.ParallelOperation; - -/** User-facing context for managing parallel branch execution within a durable function. */ -public class ParallelContext implements AutoCloseable, DurableFuture { - - private final ParallelOperation parallelOperation; - private final DurableContext durableContext; - private final AtomicBoolean joined = new AtomicBoolean(false); - - /** - * Creates a new ParallelContext. - * - * @param parallelOperation the underlying parallel operation managing concurrency - * @param durableContext the durable context for creating child operations - */ - public ParallelContext(ParallelOperation parallelOperation, DurableContext durableContext) { - this.parallelOperation = Objects.requireNonNull(parallelOperation, "parallelOperation cannot be null"); - this.durableContext = Objects.requireNonNull(durableContext, "durableContext cannot be null"); - } - - /** - * Registers and immediately starts a branch (respects maxConcurrency). - * - * @param name the branch name - * @param resultType the result type class - * @param func the function to execute in the branch's child context - * @param the result type - * @return a {@link DurableFuture} that will contain the branch result - * @throws IllegalStateException if called after {@link #join()} - */ - public DurableFuture branch(String name, Class resultType, Function func) { - return branch(name, TypeToken.get(resultType), func); - } - - /** - * Registers and immediately starts a branch (respects maxConcurrency). - * - * @param name the branch name - * @param resultType the result type token for generic types - * @param func the function to execute in the branch's child context - * @param the result type - * @return a {@link DurableFuture} that will contain the branch result - * @throws IllegalStateException if called after {@link #join()} - */ - public DurableFuture branch(String name, TypeToken resultType, Function func) { - if (joined.get()) { - throw new IllegalStateException("Cannot add branches after join() has been called"); - } - return parallelOperation.addItem( - name, func, resultType, durableContext.getDurableConfig().getSerDes()); - } - - /** - * Waits for completion based on config rules (minSuccessful, toleratedFailureCount). - * - *

    First validates that the number of registered branches is sufficient to satisfy the completion criteria. Then - * blocks until completion criteria are met or failure threshold exceeded. - * - * @throws IllegalArgumentException if branch count cannot satisfy completion criteria - * @throws software.amazon.lambda.durable.exception.ConcurrencyExecutionException if failure threshold exceeded - */ - public void join() { - if (!joined.compareAndSet(false, true)) { - return; - } - parallelOperation.join(); - } - - /** - * Blocks until the parallel operation completes and returns the {@link ParallelResult}. - * - *

    Calling {@code get()} implicitly calls {@code join()} if it has not been called yet. - * - * @return the {@link ParallelResult} summarising branch outcomes - */ - @Override - public ParallelResult get() { - joined.set(true); - return parallelOperation.get(); - } - - /** - * Calls {@link #join()} if not already called. Guarantees that all branches complete before the context is closed. - */ - @Override - public void close() { - join(); - } -} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/ParallelDurableFuture.java b/sdk/src/main/java/software/amazon/lambda/durable/ParallelDurableFuture.java new file mode 100644 index 000000000..b71198d7d --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/ParallelDurableFuture.java @@ -0,0 +1,74 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable; + +import java.util.function.Function; +import software.amazon.lambda.durable.config.ParallelBranchConfig; +import software.amazon.lambda.durable.model.ParallelResult; + +/** User-facing context for managing parallel branch execution within a durable function. */ +public interface ParallelDurableFuture extends AutoCloseable, DurableFuture { + + /** + * Registers and immediately starts a branch (respects maxConcurrency). + * + * @param name the branch name + * @param resultType the result type token for generic types + * @param func the function to execute in the branch's child context + * @param the result type + * @return a {@link DurableFuture} that will contain the branch result + * @throws IllegalStateException if called after {@link #close()} + */ + default DurableFuture branch(String name, Class resultType, Function func) { + return branch( + name, + TypeToken.get(resultType), + func, + ParallelBranchConfig.builder().build()); + } + + /** + * Registers and immediately starts a branch (respects maxConcurrency). + * + * @param name the branch name + * @param resultType the result type token for generic types + * @param func the function to execute in the branch's child context + * @param the result type + * @return a {@link DurableFuture} that will contain the branch result + * @throws IllegalStateException if called after {@link #close()} + */ + default DurableFuture branch(String name, TypeToken resultType, Function func) { + return branch(name, resultType, func, ParallelBranchConfig.builder().build()); + } + + /** + * Registers and immediately starts a branch (respects maxConcurrency). + * + * @param name the branch name + * @param resultType the result type token for generic types + * @param func the function to execute in the branch's child context + * @param the result type + * @return a {@link DurableFuture} that will contain the branch result + * @throws IllegalStateException if called after {@link #close()} + */ + default DurableFuture branch( + String name, Class resultType, Function func, ParallelBranchConfig config) { + return branch(name, TypeToken.get(resultType), func, config); + } + + /** + * Registers and immediately starts a branch (respects maxConcurrency). + * + * @param name the branch name + * @param resultType the result type token for generic types + * @param func the function to execute in the branch's child context + * @param the result type + * @return a {@link DurableFuture} that will contain the branch result + * @throws IllegalStateException if called after {@link #close()} + */ + DurableFuture branch( + String name, TypeToken resultType, Function func, ParallelBranchConfig config); + + /** Calls {@link #get()} if not already called. Guarantees that the context is closed. */ + void close(); +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/StepConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/StepConfig.java index d3e6b10a9..fff17bf90 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/StepConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/StepConfig.java @@ -2,111 +2,23 @@ // SPDX-License-Identifier: Apache-2.0 package software.amazon.lambda.durable; -import software.amazon.lambda.durable.retry.RetryStrategies; -import software.amazon.lambda.durable.retry.RetryStrategy; -import software.amazon.lambda.durable.serde.SerDes; - /** * Configuration options for step operations in durable executions. * *

    This class provides a builder pattern for configuring various aspects of step execution, including retry behavior * and delivery semantics. + * + * @deprecated use {@link software.amazon.lambda.durable.config.StepConfig} */ +@Deprecated public class StepConfig { - private final RetryStrategy retryStrategy; - private final StepSemantics semantics; - private final SerDes serDes; - - private StepConfig(Builder builder) { - this.retryStrategy = builder.retryStrategy; - this.semantics = builder.semantics; - this.serDes = builder.serDes; - } - - /** Returns the retry strategy for this step, or the default strategy if not specified. */ - public RetryStrategy retryStrategy() { - return retryStrategy != null ? retryStrategy : RetryStrategies.Presets.DEFAULT; - } - - /** Returns the delivery semantics for this step, defaults to AT_LEAST_ONCE_PER_RETRY if not specified. */ - public StepSemantics semantics() { - return semantics != null ? semantics : StepSemantics.AT_LEAST_ONCE_PER_RETRY; - } - - /** Returns the custom serializer for this step, or null if not specified (uses default SerDes). */ - public SerDes serDes() { - return serDes; - } - - public Builder toBuilder() { - return new Builder(retryStrategy, semantics, serDes); - } - /** * Creates a new builder for StepConfig. * - * @return a new Builder instance + * @deprecated use {@link software.amazon.lambda.durable.config.StepConfig#builder} */ - public static Builder builder() { - return new Builder(null, null, null); - } - - /** Builder for creating StepConfig instances. */ - public static class Builder { - private RetryStrategy retryStrategy; - private StepSemantics semantics; - private SerDes serDes; - - private Builder(RetryStrategy retryStrategy, StepSemantics semantics, SerDes serDes) { - this.retryStrategy = retryStrategy; - this.semantics = semantics; - this.serDes = serDes; - } - - /** - * Sets the retry strategy for the step. - * - * @param retryStrategy the retry strategy to use, or null for default behavior - * @return this builder for method chaining - */ - public Builder retryStrategy(RetryStrategy retryStrategy) { - this.retryStrategy = retryStrategy; - return this; - } - - /** - * Sets the delivery semantics for the step. - * - * @param semantics the delivery semantics to use, defaults to AT_LEAST_ONCE_PER_RETRY if not specified - * @return this builder for method chaining - */ - public Builder semantics(StepSemantics semantics) { - this.semantics = semantics; - return this; - } - - /** - * Sets a custom serializer for the step. - * - *

    If not specified, the step will use the default SerDes configured for the handler. This allows per-step - * customization of serialization behavior, useful for steps that need special handling (e.g., custom date - * formats, encryption, compression). - * - * @param serDes the custom serializer to use, or null to use the default - * @return this builder for method chaining - */ - public Builder serDes(SerDes serDes) { - this.serDes = serDes; - return this; - } - - /** - * Builds the StepConfig instance. - * - * @return a new StepConfig with the configured options - */ - public StepConfig build() { - return new StepConfig(this); - } + @Deprecated + public static software.amazon.lambda.durable.config.StepConfig.Builder builder() { + return new software.amazon.lambda.durable.config.StepConfig.Builder(null, null, null); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/WaitForCallbackConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/WaitForCallbackConfig.java index 8c6b7e28a..16c9a7da3 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/WaitForCallbackConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/WaitForCallbackConfig.java @@ -2,74 +2,12 @@ // SPDX-License-Identifier: Apache-2.0 package software.amazon.lambda.durable; -/** - * Configuration for the {@code waitForCallback} composite operation. - * - *

    Combines a {@link StepConfig} (for the step that produces the callback) and a {@link CallbackConfig} (for the - * callback wait itself). - */ +/** @deprecated use {@link software.amazon.lambda.durable.config.WaitForCallbackConfig} instead. */ +@Deprecated public class WaitForCallbackConfig { - private final StepConfig stepConfig; - private final CallbackConfig callbackConfig; - - private WaitForCallbackConfig(Builder builder) { - this.stepConfig = builder.stepConfig == null ? StepConfig.builder().build() : builder.stepConfig; - this.callbackConfig = - builder.callbackConfig == null ? CallbackConfig.builder().build() : builder.callbackConfig; - } - - /** Returns the step configuration for the composite operation. */ - public StepConfig stepConfig() { - return stepConfig; - } - - /** Returns the callback configuration for the composite operation. */ - public CallbackConfig callbackConfig() { - return callbackConfig; - } - - /** Creates a new builder. */ - public static Builder builder() { - return new Builder(); - } - - /** Creates a builder pre-populated with this instance's values. */ - public Builder toBuilder() { - return new Builder().stepConfig(this.stepConfig).callbackConfig(this.callbackConfig); - } - - /** Builder for {@link WaitForCallbackConfig}. */ - public static class Builder { - private StepConfig stepConfig; - private CallbackConfig callbackConfig; - - private Builder() {} - - /** - * Sets the step configuration for the composite operation. - * - * @param stepConfig the step configuration - * @return this builder for method chaining - */ - public Builder stepConfig(StepConfig stepConfig) { - this.stepConfig = stepConfig; - return this; - } - - /** - * Sets the callback configuration for the composite operation. - * - * @param callbackConfig the callback configuration - * @return this builder for method chaining - */ - public Builder callbackConfig(CallbackConfig callbackConfig) { - this.callbackConfig = callbackConfig; - return this; - } - - /** Builds the WaitForCallbackConfig instance. */ - public WaitForCallbackConfig build() { - return new WaitForCallbackConfig(this); - } + /** @deprecated use {@link software.amazon.lambda.durable.config.WaitForCallbackConfig#builder()} instead. */ + @Deprecated + public static software.amazon.lambda.durable.config.WaitForCallbackConfig.Builder builder() { + return new software.amazon.lambda.durable.config.WaitForCallbackConfig.Builder(); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/config/CallbackConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/CallbackConfig.java new file mode 100644 index 000000000..a71962aff --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/CallbackConfig.java @@ -0,0 +1,110 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.config; + +import java.time.Duration; +import software.amazon.lambda.durable.serde.SerDes; +import software.amazon.lambda.durable.util.ParameterValidator; + +/** Configuration for callback operations. */ +public class CallbackConfig { + private final Duration timeout; + private final Duration heartbeatTimeout; + private final SerDes serDes; + + private CallbackConfig(Builder builder) { + this.timeout = builder.timeout; + this.heartbeatTimeout = builder.heartbeatTimeout; + this.serDes = builder.serDes; + } + + /** + * Returns the maximum duration to wait for the callback to complete. + * + * @return the timeout duration, or null if not specified + */ + public Duration timeout() { + return timeout; + } + + /** + * Returns the maximum duration between heartbeats before the callback is considered failed. + * + * @return the heartbeat timeout duration, or null if not specified + */ + public Duration heartbeatTimeout() { + return heartbeatTimeout; + } + + /** Returns the custom serializer for this callback, or null if not specified (uses default SerDes). */ + public SerDes serDes() { + return serDes; + } + + /** Creates a new builder with default values. */ + public static Builder builder() { + return new Builder(null, null, null); + } + + /** Creates a new builder pre-populated with this config's values. */ + public Builder toBuilder() { + return new Builder(timeout, heartbeatTimeout, serDes); + } + + /** Builder for {@link CallbackConfig}. */ + public static class Builder { + private Duration timeout; + private Duration heartbeatTimeout; + private SerDes serDes; + + public Builder(Duration timeout, Duration heartbeatTimeout, SerDes serDes) { + this.timeout = timeout; + this.heartbeatTimeout = heartbeatTimeout; + this.serDes = serDes; + } + + /** + * Sets the maximum duration to wait for the callback to complete before timing out. + * + * @param timeout the timeout duration + * @return this builder for method chaining + */ + public Builder timeout(Duration timeout) { + ParameterValidator.validateOptionalDuration(timeout, "Callback timeout"); + this.timeout = timeout; + return this; + } + + /** + * Sets the maximum duration between heartbeats before the callback is considered failed. + * + * @param heartbeatTimeout the heartbeat timeout duration + * @return this builder for method chaining + */ + public Builder heartbeatTimeout(Duration heartbeatTimeout) { + ParameterValidator.validateOptionalDuration(heartbeatTimeout, "Heartbeat timeout"); + this.heartbeatTimeout = heartbeatTimeout; + return this; + } + + /** + * Sets a custom serializer for the callback. + * + *

    If not specified, the callback will use the default SerDes configured for the handler. This allows + * per-callback customization of serialization behavior, useful for callbacks that need special handling (e.g., + * custom date formats, encryption, compression). + * + * @param serDes the custom serializer to use, or null to use the default + * @return this builder for method chaining + */ + public Builder serDes(SerDes serDes) { + this.serDes = serDes; + return this; + } + + /** Builds the {@link CallbackConfig} instance. */ + public CallbackConfig build() { + return new CallbackConfig(this); + } + } +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/CompletionConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/CompletionConfig.java similarity index 66% rename from sdk/src/main/java/software/amazon/lambda/durable/CompletionConfig.java rename to sdk/src/main/java/software/amazon/lambda/durable/config/CompletionConfig.java index fc52cd1a3..bca1c4ecf 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/CompletionConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/CompletionConfig.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; /** * Controls when a concurrent operation (map or parallel) completes. @@ -8,16 +8,8 @@ *

    Provides factory methods for common completion strategies and fine-grained control via {@code minSuccessful}, * {@code toleratedFailureCount}, and {@code toleratedFailurePercentage}. */ -public class CompletionConfig { - private final Integer minSuccessful; - private final Integer toleratedFailureCount; - private final Double toleratedFailurePercentage; - - private CompletionConfig(Integer minSuccessful, Integer toleratedFailureCount, Double toleratedFailurePercentage) { - this.minSuccessful = minSuccessful; - this.toleratedFailureCount = toleratedFailureCount; - this.toleratedFailurePercentage = toleratedFailurePercentage; - } +public record CompletionConfig( + Integer minSuccessful, Integer toleratedFailureCount, Double toleratedFailurePercentage) { /** All items must succeed. Zero failures tolerated. */ public static CompletionConfig allSuccessful() { @@ -58,19 +50,4 @@ public static CompletionConfig toleratedFailurePercentage(double percentage) { } return new CompletionConfig(null, null, percentage); } - - /** @return minimum number of successful items required, or null if not set */ - public Integer minSuccessful() { - return minSuccessful; - } - - /** @return maximum number of failures tolerated, or null if unlimited */ - public Integer toleratedFailureCount() { - return toleratedFailureCount; - } - - /** @return maximum percentage of failures tolerated (0.0 to 1.0), or null if not set */ - public Double toleratedFailurePercentage() { - return toleratedFailurePercentage; - } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/InvokeConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/InvokeConfig.java similarity index 98% rename from sdk/src/main/java/software/amazon/lambda/durable/InvokeConfig.java rename to sdk/src/main/java/software/amazon/lambda/durable/config/InvokeConfig.java index 0f091e9c7..e9dc7af24 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/InvokeConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/InvokeConfig.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import software.amazon.lambda.durable.serde.SerDes; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/MapConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/MapConfig.java similarity index 74% rename from sdk/src/main/java/software/amazon/lambda/durable/MapConfig.java rename to sdk/src/main/java/software/amazon/lambda/durable/config/MapConfig.java index d4f6b583c..7ca0e9b27 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/MapConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/MapConfig.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import software.amazon.lambda.durable.serde.SerDes; @@ -15,8 +15,9 @@ public class MapConfig { private final SerDes serDes; private MapConfig(Builder builder) { - this.maxConcurrency = builder.maxConcurrency; - this.completionConfig = builder.completionConfig; + this.maxConcurrency = builder.maxConcurrency == null ? Integer.MAX_VALUE : builder.maxConcurrency; + this.completionConfig = + builder.completionConfig == null ? CompletionConfig.allCompleted() : builder.completionConfig; this.serDes = builder.serDes; } @@ -27,7 +28,7 @@ public Integer maxConcurrency() { /** @return completion criteria, defaults to {@link CompletionConfig#allCompleted()} */ public CompletionConfig completionConfig() { - return completionConfig != null ? completionConfig : CompletionConfig.allCompleted(); + return completionConfig; } /** @return the custom serializer, or null to use the default */ @@ -56,24 +57,36 @@ private Builder(Integer maxConcurrency, CompletionConfig completionConfig, SerDe } public Builder maxConcurrency(Integer maxConcurrency) { + if (maxConcurrency != null && maxConcurrency < 1) { + throw new IllegalArgumentException("maxConcurrency must be at least 1, got: " + maxConcurrency); + } this.maxConcurrency = maxConcurrency; return this; } + /** + * Sets the completion criteria for the map operation. + * + * @param completionConfig the completion configuration (default: {@link CompletionConfig#allCompleted()}) + * @return this builder for method chaining + */ public Builder completionConfig(CompletionConfig completionConfig) { this.completionConfig = completionConfig; return this; } + /** + * Sets the custom serializer to use for serializing map items and results. + * + * @param serDes the serializer to use + * @return this builder for method chaining + */ public Builder serDes(SerDes serDes) { this.serDes = serDes; return this; } public MapConfig build() { - if (maxConcurrency != null && maxConcurrency < 1) { - throw new IllegalArgumentException("maxConcurrency must be at least 1, got: " + maxConcurrency); - } return new MapConfig(this); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/config/ParallelBranchConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/ParallelBranchConfig.java new file mode 100644 index 000000000..689f9aa54 --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/ParallelBranchConfig.java @@ -0,0 +1,69 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.config; + +import software.amazon.lambda.durable.serde.SerDes; + +/** + * Configuration options for parallel branch in durable executions. + * + *

    This class provides a builder pattern for configuring various aspects of parallel branch execution + */ +public class ParallelBranchConfig { + private final SerDes serDes; + + private ParallelBranchConfig(Builder builder) { + this.serDes = builder.serDes; + } + + /** Returns the custom serializer for this step, or null if not specified (uses default SerDes). */ + public SerDes serDes() { + return serDes; + } + + public Builder toBuilder() { + return new Builder(serDes); + } + + /** + * Creates a new builder for ParallelBranchConfig. + * + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(null); + } + + /** Builder for creating StepConfig instances. */ + public static class Builder { + private SerDes serDes; + + public Builder(SerDes serDes) { + this.serDes = serDes; + } + + /** + * Sets a custom serializer for the step. + * + *

    If not specified, the parallel branch will use the default SerDes configured for the handler. This allows + * per-branch customization of serialization behavior, useful for branches that need special handling (e.g., + * custom date formats, encryption, compression). + * + * @param serDes the custom serializer to use, or null to use the default + * @return this builder for method chaining + */ + public Builder serDes(SerDes serDes) { + this.serDes = serDes; + return this; + } + + /** + * Builds the ParallelBranchConfig instance. + * + * @return a new StepConfig with the configured options + */ + public ParallelBranchConfig build() { + return new ParallelBranchConfig(this); + } + } +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/config/ParallelConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/ParallelConfig.java new file mode 100644 index 000000000..3371be21b --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/ParallelConfig.java @@ -0,0 +1,84 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.config; + +/** + * Configuration options for parallel operations in durable executions. + * + *

    This class provides a builder pattern for configuring concurrency limits and completion semantics for parallel + * branch execution. + */ +public class ParallelConfig { + private final int maxConcurrency; + private final CompletionConfig completionConfig; + + private ParallelConfig(Builder builder) { + this.maxConcurrency = builder.maxConcurrency == null ? Integer.MAX_VALUE : builder.maxConcurrency; + this.completionConfig = + builder.completionConfig == null ? CompletionConfig.allCompleted() : builder.completionConfig; + } + + /** @return the maximum number of branches running simultaneously, or -1 for unlimited */ + public int maxConcurrency() { + return maxConcurrency; + } + + public CompletionConfig completionConfig() { + return completionConfig; + } + + /** + * Creates a new builder for ParallelConfig. + * + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** Builder for creating ParallelConfig instances. */ + public static class Builder { + private Integer maxConcurrency; + private CompletionConfig completionConfig; + + private Builder() {} + + /** + * Sets the maximum number of branches that can run simultaneously. + * + * @param maxConcurrency the concurrency limit (default: unlimited) + * @return this builder for method chaining + */ + public Builder maxConcurrency(Integer maxConcurrency) { + if (maxConcurrency != null && maxConcurrency < 1) { + throw new IllegalArgumentException("maxConcurrency must be at least 1, got: " + maxConcurrency); + } + this.maxConcurrency = maxConcurrency; + return this; + } + + /** + * Sets the maximum number of branches that can run simultaneously. + * + * @param completionConfig the completion configuration for the parallel operation + * @return this builder for method chaining + */ + public Builder completionConfig(CompletionConfig completionConfig) { + if (completionConfig != null && completionConfig.toleratedFailurePercentage() != null) { + throw new IllegalArgumentException("ParallelConfig does not support toleratedFailurePercentage"); + } + this.completionConfig = completionConfig; + return this; + } + + /** + * Builds the ParallelConfig instance. + * + * @return a new ParallelConfig with the configured options + * @throws IllegalArgumentException if any configuration values are invalid + */ + public ParallelConfig build() { + return new ParallelConfig(this); + } + } +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/config/RunInChildContextConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/RunInChildContextConfig.java new file mode 100644 index 000000000..7eb42cbaf --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/RunInChildContextConfig.java @@ -0,0 +1,72 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.config; + +import software.amazon.lambda.durable.serde.SerDes; + +/** + * Configuration options for RunInChildContext operations in durable executions. + * + *

    This class provides a builder pattern for configuring various aspects of RunInChildContext execution. + */ +public class RunInChildContextConfig { + private final SerDes serDes; + + private RunInChildContextConfig(Builder builder) { + this.serDes = builder.serDes; + } + + /** + * Returns the custom serializer for this RunInChildContext operation, or null if not specified (uses default + * SerDes). + */ + public SerDes serDes() { + return serDes; + } + + public Builder toBuilder() { + return new Builder(serDes); + } + + /** + * Creates a new builder for RunInChildContextConfig. + * + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(null); + } + + /** Builder for creating StepConfig instances. */ + public static class Builder { + private SerDes serDes; + + public Builder(SerDes serDes) { + this.serDes = serDes; + } + + /** + * Sets a custom serializer for the step. + * + *

    If not specified, the RunInChildContext operation will use the default SerDes configured for the handler. + * This allows per-operation customization of serialization behavior, useful for operations that need special + * handling (e.g., custom date formats, encryption, compression). + * + * @param serDes the custom serializer to use, or null to use the default + * @return this builder for method chaining + */ + public Builder serDes(SerDes serDes) { + this.serDes = serDes; + return this; + } + + /** + * Builds the RunInChildContextConfig instance. + * + * @return a new StepConfig with the configured options + */ + public RunInChildContextConfig build() { + return new RunInChildContextConfig(this); + } + } +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/config/StepConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/StepConfig.java new file mode 100644 index 000000000..8eada6faf --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/StepConfig.java @@ -0,0 +1,112 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.config; + +import software.amazon.lambda.durable.retry.RetryStrategies; +import software.amazon.lambda.durable.retry.RetryStrategy; +import software.amazon.lambda.durable.serde.SerDes; + +/** + * Configuration options for step operations in durable executions. + * + *

    This class provides a builder pattern for configuring various aspects of step execution, including retry behavior + * and delivery semantics. + */ +public class StepConfig { + private final RetryStrategy retryStrategy; + private final StepSemantics semantics; + private final SerDes serDes; + + private StepConfig(Builder builder) { + this.retryStrategy = builder.retryStrategy; + this.semantics = builder.semantics; + this.serDes = builder.serDes; + } + + /** Returns the retry strategy for this step, or the default strategy if not specified. */ + public RetryStrategy retryStrategy() { + return retryStrategy != null ? retryStrategy : RetryStrategies.Presets.DEFAULT; + } + + /** Returns the delivery semantics for this step, defaults to AT_LEAST_ONCE_PER_RETRY if not specified. */ + public StepSemantics semantics() { + return semantics != null ? semantics : StepSemantics.AT_LEAST_ONCE_PER_RETRY; + } + + /** Returns the custom serializer for this step, or null if not specified (uses default SerDes). */ + public SerDes serDes() { + return serDes; + } + + public Builder toBuilder() { + return new Builder(retryStrategy, semantics, serDes); + } + + /** + * Creates a new builder for StepConfig. + * + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(null, null, null); + } + + /** Builder for creating StepConfig instances. */ + public static class Builder { + private RetryStrategy retryStrategy; + private StepSemantics semantics; + private SerDes serDes; + + public Builder(RetryStrategy retryStrategy, StepSemantics semantics, SerDes serDes) { + this.retryStrategy = retryStrategy; + this.semantics = semantics; + this.serDes = serDes; + } + + /** + * Sets the retry strategy for the step. + * + * @param retryStrategy the retry strategy to use, or null for default behavior + * @return this builder for method chaining + */ + public Builder retryStrategy(RetryStrategy retryStrategy) { + this.retryStrategy = retryStrategy; + return this; + } + + /** + * Sets the delivery semantics for the step. + * + * @param semantics the delivery semantics to use, defaults to AT_LEAST_ONCE_PER_RETRY if not specified + * @return this builder for method chaining + */ + public Builder semantics(StepSemantics semantics) { + this.semantics = semantics; + return this; + } + + /** + * Sets a custom serializer for the step. + * + *

    If not specified, the step will use the default SerDes configured for the handler. This allows per-step + * customization of serialization behavior, useful for steps that need special handling (e.g., custom date + * formats, encryption, compression). + * + * @param serDes the custom serializer to use, or null to use the default + * @return this builder for method chaining + */ + public Builder serDes(SerDes serDes) { + this.serDes = serDes; + return this; + } + + /** + * Builds the StepConfig instance. + * + * @return a new StepConfig with the configured options + */ + public StepConfig build() { + return new StepConfig(this); + } + } +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/StepSemantics.java b/sdk/src/main/java/software/amazon/lambda/durable/config/StepSemantics.java similarity index 94% rename from sdk/src/main/java/software/amazon/lambda/durable/StepSemantics.java rename to sdk/src/main/java/software/amazon/lambda/durable/config/StepSemantics.java index 2028d30be..fc23ab3c7 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/StepSemantics.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/StepSemantics.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; /** * Delivery semantics for step operations. diff --git a/sdk/src/main/java/software/amazon/lambda/durable/config/WaitForCallbackConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/WaitForCallbackConfig.java new file mode 100644 index 000000000..e3bd57f44 --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/WaitForCallbackConfig.java @@ -0,0 +1,75 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.config; + +/** + * Configuration for the {@code waitForCallback} composite operation. + * + *

    Combines a {@link StepConfig} (for the step that produces the callback) and a {@link CallbackConfig} (for the + * callback wait itself). + */ +public class WaitForCallbackConfig { + private final StepConfig stepConfig; + private final CallbackConfig callbackConfig; + + private WaitForCallbackConfig(Builder builder) { + this.stepConfig = builder.stepConfig == null ? StepConfig.builder().build() : builder.stepConfig; + this.callbackConfig = + builder.callbackConfig == null ? CallbackConfig.builder().build() : builder.callbackConfig; + } + + /** Returns the step configuration for the composite operation. */ + public StepConfig stepConfig() { + return stepConfig; + } + + /** Returns the callback configuration for the composite operation. */ + public CallbackConfig callbackConfig() { + return callbackConfig; + } + + /** Creates a new builder. */ + public static Builder builder() { + return new Builder(); + } + + /** Creates a builder pre-populated with this instance's values. */ + public Builder toBuilder() { + return new Builder().stepConfig(this.stepConfig).callbackConfig(this.callbackConfig); + } + + /** Builder for {@link WaitForCallbackConfig}. */ + public static class Builder { + private StepConfig stepConfig; + private CallbackConfig callbackConfig; + + public Builder() {} + + /** + * Sets the step configuration for the composite operation. + * + * @param stepConfig the step configuration + * @return this builder for method chaining + */ + public Builder stepConfig(StepConfig stepConfig) { + this.stepConfig = stepConfig; + return this; + } + + /** + * Sets the callback configuration for the composite operation. + * + * @param callbackConfig the callback configuration + * @return this builder for method chaining + */ + public Builder callbackConfig(CallbackConfig callbackConfig) { + this.callbackConfig = callbackConfig; + return this; + } + + /** Builds the WaitForCallbackConfig instance. */ + public WaitForCallbackConfig build() { + return new WaitForCallbackConfig(this); + } + } +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/WaitForConditionConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/config/WaitForConditionConfig.java similarity index 97% rename from sdk/src/main/java/software/amazon/lambda/durable/WaitForConditionConfig.java rename to sdk/src/main/java/software/amazon/lambda/durable/config/WaitForConditionConfig.java index 9ec3eb042..6cc54651a 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/WaitForConditionConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/config/WaitForConditionConfig.java @@ -1,7 +1,8 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; +import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.retry.WaitForConditionWaitStrategy; import software.amazon.lambda.durable.retry.WaitStrategies; import software.amazon.lambda.durable.serde.SerDes; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/context/BaseContextImpl.java b/sdk/src/main/java/software/amazon/lambda/durable/context/BaseContextImpl.java index 2cee25d09..9920366f4 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/context/BaseContextImpl.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/context/BaseContextImpl.java @@ -5,8 +5,6 @@ import com.amazonaws.services.lambda.runtime.Context; import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.execution.ExecutionManager; -import software.amazon.lambda.durable.execution.SuspendExecutionException; -import software.amazon.lambda.durable.execution.ThreadContext; import software.amazon.lambda.durable.execution.ThreadType; public abstract class BaseContextImpl implements AutoCloseable, BaseContext { @@ -36,28 +34,6 @@ protected BaseContextImpl( String contextId, String contextName, ThreadType threadType) { - this(executionManager, durableConfig, lambdaContext, contextId, contextName, threadType, true); - } - - /** - * Creates a new BaseContext instance. - * - * @param executionManager the execution manager for thread coordination and state management - * @param durableConfig the durable execution configuration - * @param lambdaContext the AWS Lambda runtime context - * @param contextId the context ID, null for root context, set for child contexts - * @param contextName the human-readable name for this context - * @param threadType the type of thread this context runs on - * @param setCurrentThreadContext whether to call setCurrentThreadContext on the execution manager - */ - protected BaseContextImpl( - ExecutionManager executionManager, - DurableConfig durableConfig, - Context lambdaContext, - String contextId, - String contextName, - ThreadType threadType, - boolean setCurrentThreadContext) { this.executionManager = executionManager; this.durableConfig = durableConfig; this.lambdaContext = lambdaContext; @@ -65,11 +41,6 @@ protected BaseContextImpl( this.contextName = contextName; this.isReplaying = executionManager.hasOperationsForContext(contextId); this.threadType = threadType; - - if (setCurrentThreadContext) { - // write the thread id and type to thread local - executionManager.setCurrentThreadContext(new ThreadContext(contextId, threadType)); - } } // =============== accessors ================ @@ -138,23 +109,4 @@ public boolean isReplaying() { public void setExecutionMode() { this.isReplaying = false; } - - @Override - public void close() { - // this is called in the user thread, after the context's user code has completed - if (getContextId() != null) { - // if this is a child context or a step context, we need to - // deregister the context's thread from the execution manager - try { - executionManager.deregisterActiveThread(getContextId()); - } catch (SuspendExecutionException e) { - // Expected when this is the last active thread. Must catch here because: - // 1/ This runs in a worker thread detached from handlerFuture - // 2/ Uncaught exception would prevent stepAsync().get() from resume - // Suspension/Termination is already signaled via - // suspendExecutionFuture/terminateExecutionFuture - // before the throw. - } - } - } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java b/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java index e171276de..ac2eb6af6 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java @@ -10,24 +10,23 @@ import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Function; -import java.util.function.Supplier; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.lambda.model.OperationType; -import software.amazon.lambda.durable.CallbackConfig; import software.amazon.lambda.durable.DurableCallbackFuture; import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.DurableFuture; -import software.amazon.lambda.durable.InvokeConfig; -import software.amazon.lambda.durable.MapConfig; -import software.amazon.lambda.durable.MapFunction; -import software.amazon.lambda.durable.ParallelConfig; -import software.amazon.lambda.durable.ParallelContext; -import software.amazon.lambda.durable.StepConfig; +import software.amazon.lambda.durable.ParallelDurableFuture; import software.amazon.lambda.durable.StepContext; import software.amazon.lambda.durable.TypeToken; -import software.amazon.lambda.durable.WaitForCallbackConfig; -import software.amazon.lambda.durable.WaitForConditionConfig; +import software.amazon.lambda.durable.config.CallbackConfig; +import software.amazon.lambda.durable.config.InvokeConfig; +import software.amazon.lambda.durable.config.MapConfig; +import software.amazon.lambda.durable.config.ParallelConfig; +import software.amazon.lambda.durable.config.RunInChildContextConfig; +import software.amazon.lambda.durable.config.StepConfig; +import software.amazon.lambda.durable.config.WaitForCallbackConfig; +import software.amazon.lambda.durable.config.WaitForConditionConfig; import software.amazon.lambda.durable.execution.ExecutionManager; import software.amazon.lambda.durable.execution.OperationIdGenerator; import software.amazon.lambda.durable.execution.ThreadType; @@ -68,24 +67,7 @@ private DurableContextImpl( Context lambdaContext, String contextId, String contextName) { - this(executionManager, durableConfig, lambdaContext, contextId, contextName, true); - } - - private DurableContextImpl( - ExecutionManager executionManager, - DurableConfig durableConfig, - Context lambdaContext, - String contextId, - String contextName, - boolean setCurrentThreadContext) { - super( - executionManager, - durableConfig, - lambdaContext, - contextId, - contextName, - ThreadType.CONTEXT, - setCurrentThreadContext); + super(executionManager, durableConfig, lambdaContext, contextId, contextName, ThreadType.CONTEXT); operationIdGenerator = new OperationIdGenerator(contextId); } @@ -116,22 +98,6 @@ public DurableContextImpl createChildContext(String childContextId, String child getExecutionManager(), getDurableConfig(), getLambdaContext(), childContextId, childContextName); } - /** - * Creates a child context without setting the current thread context. - * - *

    Use this when the child context is being created on a thread that should not have its thread-local context - * overwritten (e.g. when constructing the context ahead of running it on a separate thread). - * - * @param childContextId the child context's ID (the CONTEXT operation's operation ID) - * @param childContextName the name of the child context - * @return a new DurableContext for the child context - */ - public DurableContextImpl createChildContextWithoutSettingThreadContext( - String childContextId, String childContextName) { - return new DurableContextImpl( - getExecutionManager(), getDurableConfig(), getLambdaContext(), childContextId, childContextName, false); - } - /** * Creates a step context for executing step operations. * @@ -150,52 +116,11 @@ public StepContextImpl createStepContext(String stepOperationId, String stepOper attempt); } - // ========== step methods ========== - - @Override - public T step(String name, Class resultType, Function func) { - return step(name, TypeToken.get(resultType), func, StepConfig.builder().build()); - } - - @Override - public T step(String name, Class resultType, Function func, StepConfig config) { - // Simply delegate to stepAsync and block on the result - return stepAsync(name, resultType, func, config).get(); - } - - @Override - public T step(String name, TypeToken typeToken, Function func) { - return step(name, typeToken, func, StepConfig.builder().build()); - } - - @Override - public T step(String name, TypeToken typeToken, Function func, StepConfig config) { - // Simply delegate to stepAsync and block on the result - return stepAsync(name, typeToken, func, config).get(); - } - - @Override - public DurableFuture stepAsync(String name, Class resultType, Function func) { - return stepAsync( - name, TypeToken.get(resultType), func, StepConfig.builder().build()); - } - - @Override - public DurableFuture stepAsync( - String name, Class resultType, Function func, StepConfig config) { - return stepAsync(name, TypeToken.get(resultType), func, config); - } - - @Override - public DurableFuture stepAsync(String name, TypeToken typeToken, Function func) { - return stepAsync(name, typeToken, func, StepConfig.builder().build()); - } - @Override public DurableFuture stepAsync( - String name, TypeToken typeToken, Function func, StepConfig config) { + String name, TypeToken resultType, Function func, StepConfig config) { Objects.requireNonNull(config, "config cannot be null"); - Objects.requireNonNull(typeToken, "typeToken cannot be null"); + Objects.requireNonNull(resultType, "resultType cannot be null"); ParameterValidator.validateOperationName(name); if (config.serDes() == null) { @@ -205,85 +130,13 @@ public DurableFuture stepAsync( // Create and start step operation with TypeToken var operation = new StepOperation<>( - OperationIdentifier.of(operationId, name, OperationType.STEP), func, typeToken, config, this); + OperationIdentifier.of(operationId, name, OperationType.STEP), func, resultType, config, this); operation.execute(); // Start the step (returns immediately) return operation; } - /** @deprecated use the variants accepting StepContext instead */ - @Deprecated - @Override - public T step(String name, Class resultType, Supplier func) { - return stepAsync( - name, - TypeToken.get(resultType), - func, - StepConfig.builder().build()) - .get(); - } - - /** @deprecated use the variants accepting StepContext instead */ - @Deprecated - @Override - public T step(String name, Class resultType, Supplier func, StepConfig config) { - // Simply delegate to stepAsync and block on the result - return stepAsync(name, TypeToken.get(resultType), func, config).get(); - } - - /** @deprecated use the variants accepting StepContext instead */ - @Deprecated - @Override - public T step(String name, TypeToken typeToken, Supplier func) { - return stepAsync(name, typeToken, func, StepConfig.builder().build()).get(); - } - - /** @deprecated use the variants accepting StepContext instead */ - @Deprecated - @Override - public T step(String name, TypeToken typeToken, Supplier func, StepConfig config) { - // Simply delegate to stepAsync and block on the result - return stepAsync(name, typeToken, func, config).get(); - } - - /** @deprecated use the variants accepting StepContext instead */ - @Deprecated - @Override - public DurableFuture stepAsync(String name, Class resultType, Supplier func) { - return stepAsync( - name, TypeToken.get(resultType), func, StepConfig.builder().build()); - } - - /** @deprecated use the variants accepting StepContext instead */ - @Deprecated - @Override - public DurableFuture stepAsync(String name, Class resultType, Supplier func, StepConfig config) { - return stepAsync(name, TypeToken.get(resultType), func, config); - } - - /** @deprecated use the variants accepting StepContext instead */ - @Deprecated - @Override - public DurableFuture stepAsync(String name, TypeToken typeToken, Supplier func) { - return stepAsync(name, typeToken, func, StepConfig.builder().build()); - } - - /** @deprecated use the variants accepting StepContext instead */ - @Deprecated - @Override - public DurableFuture stepAsync(String name, TypeToken typeToken, Supplier func, StepConfig config) { - return stepAsync(name, typeToken, stepContext -> func.get(), config); - } - - // ========== wait methods ========== - - @Override - public Void wait(String name, Duration duration) { - // Block (will throw SuspendExecutionException if there is no active thread) - return waitAsync(name, duration).get(); - } - @Override public DurableFuture waitAsync(String name, Duration duration) { ParameterValidator.validateDuration(duration, "Wait duration"); @@ -299,69 +152,11 @@ public DurableFuture waitAsync(String name, Duration duration) { return operation; } - // ========== chained invoke methods ========== - - @Override - public T invoke(String name, String functionName, U payload, Class resultType) { - return invokeAsync( - name, - functionName, - payload, - TypeToken.get(resultType), - InvokeConfig.builder().build()) - .get(); - } - - @Override - public T invoke(String name, String functionName, U payload, Class resultType, InvokeConfig config) { - return invokeAsync(name, functionName, payload, TypeToken.get(resultType), config) - .get(); - } - - @Override - public T invoke(String name, String functionName, U payload, TypeToken typeToken) { - return invokeAsync( - name, - functionName, - payload, - typeToken, - InvokeConfig.builder().build()) - .get(); - } - - @Override - public T invoke(String name, String functionName, U payload, TypeToken typeToken, InvokeConfig config) { - return invokeAsync(name, functionName, payload, typeToken, config).get(); - } - - /** Asynchronously invokes another Lambda function with custom configuration. */ @Override public DurableFuture invokeAsync( - String name, String functionName, U payload, Class resultType, InvokeConfig config) { - return invokeAsync(name, functionName, payload, TypeToken.get(resultType), config); - } - - @Override - public DurableFuture invokeAsync(String name, String functionName, U payload, Class resultType) { - return invokeAsync( - name, - functionName, - payload, - TypeToken.get(resultType), - InvokeConfig.builder().build()); - } - - @Override - public DurableFuture invokeAsync(String name, String functionName, U payload, TypeToken resultType) { - return invokeAsync( - name, functionName, payload, resultType, InvokeConfig.builder().build()); - } - - @Override - public DurableFuture invokeAsync( - String name, String functionName, U payload, TypeToken typeToken, InvokeConfig config) { + String name, String functionName, U payload, TypeToken resultType, InvokeConfig config) { Objects.requireNonNull(config, "config cannot be null"); - Objects.requireNonNull(typeToken, "typeToken cannot be null"); + Objects.requireNonNull(resultType, "resultType cannot be null"); ParameterValidator.validateOperationName(name); if (config.serDes() == null) { @@ -379,7 +174,7 @@ public DurableFuture invokeAsync( OperationIdentifier.of(operationId, name, OperationType.CHAINED_INVOKE), functionName, payload, - typeToken, + resultType, config, this); @@ -387,26 +182,8 @@ public DurableFuture invokeAsync( return operation; // Block (will throw SuspendExecutionException if needed) } - // ========== createCallback methods ========== - - @Override - public DurableCallbackFuture createCallback(String name, Class resultType, CallbackConfig config) { - return createCallback(name, TypeToken.get(resultType), config); - } - @Override - public DurableCallbackFuture createCallback(String name, TypeToken typeToken) { - return createCallback(name, typeToken, CallbackConfig.builder().build()); - } - - @Override - public DurableCallbackFuture createCallback(String name, Class resultType) { - return createCallback( - name, TypeToken.get(resultType), CallbackConfig.builder().build()); - } - - @Override - public DurableCallbackFuture createCallback(String name, TypeToken typeToken, CallbackConfig config) { + public DurableCallbackFuture createCallback(String name, TypeToken resultType, CallbackConfig config) { ParameterValidator.validateOperationName(name); if (config.serDes() == null) { config = config.toBuilder().serDes(getDurableConfig().getSerDes()).build(); @@ -414,109 +191,57 @@ public DurableCallbackFuture createCallback(String name, TypeToken typ var operationId = nextOperationId(); var operation = new CallbackOperation<>( - OperationIdentifier.of(operationId, name, OperationType.CALLBACK), typeToken, config, this); + OperationIdentifier.of(operationId, name, OperationType.CALLBACK), resultType, config, this); operation.execute(); return operation; } - // ========== runInChildContext methods ========== - - @Override - public T runInChildContext(String name, Class resultType, Function func) { - return runInChildContextAsync(name, TypeToken.get(resultType), func).get(); - } - - @Override - public T runInChildContext(String name, TypeToken typeToken, Function func) { - return runInChildContextAsync(name, typeToken, func).get(); - } - - @Override - public DurableFuture runInChildContextAsync( - String name, Class resultType, Function func) { - return runInChildContextAsync(name, TypeToken.get(resultType), func); - } - + /** + * Runs a function in a child context, blocking until it completes. + * + *

    Child contexts provide isolated operation ID namespaces, allowing nested workflows to be composed without ID + * collisions. On replay, the child context's operations are replayed independently. + * + * @param name the operation name within this context + * @param resultType the result class for deserialization + * @param func the function to execute, receiving a child {@link DurableContext} + * @param config the configuration for the child context + * @return the DurableFuture wrapping the child context result + */ @Override public DurableFuture runInChildContextAsync( - String name, TypeToken typeToken, Function func) { - return runInChildContextAsync(name, typeToken, func, OperationSubType.RUN_IN_CHILD_CONTEXT); + String name, TypeToken resultType, Function func, RunInChildContextConfig config) { + return runInChildContextAsync(name, resultType, func, config, OperationSubType.RUN_IN_CHILD_CONTEXT); } private DurableFuture runInChildContextAsync( - String name, TypeToken typeToken, Function func, OperationSubType subType) { - Objects.requireNonNull(typeToken, "typeToken cannot be null"); + String name, + TypeToken resultType, + Function func, + RunInChildContextConfig config, + OperationSubType subType) { + Objects.requireNonNull(resultType, "resultType cannot be null"); + Objects.requireNonNull(config, "RunInChildContextConfig cannot be null"); ParameterValidator.validateOperationName(name); + + if (config.serDes() == null) { + config = config.toBuilder().serDes(getDurableConfig().getSerDes()).build(); + } + var operationId = nextOperationId(); var operation = new ChildContextOperation<>( OperationIdentifier.of(operationId, name, OperationType.CONTEXT, subType), func, - typeToken, - getDurableConfig().getSerDes(), + resultType, + config, this); operation.execute(); return operation; } - // ========== map methods ========== - - @Override - public MapResult map(String name, Collection items, Class resultType, MapFunction function) { - return mapAsync( - name, - items, - TypeToken.get(resultType), - function, - MapConfig.builder().build()) - .get(); - } - - @Override - public MapResult map( - String name, Collection items, Class resultType, MapFunction function, MapConfig config) { - return mapAsync(name, items, TypeToken.get(resultType), function, config) - .get(); - } - - @Override - public MapResult map( - String name, Collection items, TypeToken resultType, MapFunction function) { - return mapAsync(name, items, resultType, function, MapConfig.builder().build()) - .get(); - } - - @Override - public MapResult map( - String name, Collection items, TypeToken resultType, MapFunction function, MapConfig config) { - return mapAsync(name, items, resultType, function, config).get(); - } - - @Override - public DurableFuture> mapAsync( - String name, Collection items, Class resultType, MapFunction function) { - return mapAsync( - name, - items, - TypeToken.get(resultType), - function, - MapConfig.builder().build()); - } - - @Override - public DurableFuture> mapAsync( - String name, Collection items, Class resultType, MapFunction function, MapConfig config) { - return mapAsync(name, items, TypeToken.get(resultType), function, config); - } - - @Override - public DurableFuture> mapAsync( - String name, Collection items, TypeToken resultType, MapFunction function) { - return mapAsync(name, items, resultType, function, MapConfig.builder().build()); - } - @Override public DurableFuture> mapAsync( String name, Collection items, TypeToken resultType, MapFunction function, MapConfig config) { @@ -551,10 +276,8 @@ public DurableFuture> mapAsync( return operation; } - // ========== parallel methods ========== - @Override - public ParallelContext parallel(String name, ParallelConfig config) { + public ParallelDurableFuture parallel(String name, ParallelConfig config) { Objects.requireNonNull(config, "config cannot be null"); var operationId = nextOperationId(); @@ -562,87 +285,20 @@ public ParallelContext parallel(String name, ParallelConfig config) { OperationIdentifier.of(operationId, name, OperationType.CONTEXT, OperationSubType.PARALLEL), getDurableConfig().getSerDes(), this, - config.maxConcurrency(), - config.minSuccessful(), - config.toleratedFailureCount()); + config); parallelOp.execute(); - return new ParallelContext(parallelOp, this); - } - - // ========= waitForCallback methods ============= - - @Override - public T waitForCallback(String name, Class resultType, BiConsumer func) { - return waitForCallbackAsync( - name, - TypeToken.get(resultType), - func, - WaitForCallbackConfig.builder().build()) - .get(); - } - - @Override - public T waitForCallback(String name, TypeToken typeToken, BiConsumer func) { - return waitForCallbackAsync( - name, typeToken, func, WaitForCallbackConfig.builder().build()) - .get(); - } - - @Override - public T waitForCallback( - String name, - Class resultType, - BiConsumer func, - WaitForCallbackConfig waitForCallbackConfig) { - return waitForCallbackAsync(name, TypeToken.get(resultType), func, waitForCallbackConfig) - .get(); - } - - @Override - public T waitForCallback( - String name, - TypeToken typeToken, - BiConsumer func, - WaitForCallbackConfig waitForCallbackConfig) { - return waitForCallbackAsync(name, typeToken, func, waitForCallbackConfig) - .get(); - } - - @Override - public DurableFuture waitForCallbackAsync( - String name, Class resultType, BiConsumer func) { - return waitForCallbackAsync( - name, - TypeToken.get(resultType), - func, - WaitForCallbackConfig.builder().build()); - } - - @Override - public DurableFuture waitForCallbackAsync( - String name, TypeToken typeToken, BiConsumer func) { - return waitForCallbackAsync( - name, typeToken, func, WaitForCallbackConfig.builder().build()); - } - - @Override - public DurableFuture waitForCallbackAsync( - String name, - Class resultType, - BiConsumer func, - WaitForCallbackConfig waitForCallbackConfig) { - return waitForCallbackAsync(name, TypeToken.get(resultType), func, waitForCallbackConfig); + return parallelOp; } @Override public DurableFuture waitForCallbackAsync( String name, - TypeToken typeToken, + TypeToken resultType, BiConsumer func, WaitForCallbackConfig waitForCallbackConfig) { - Objects.requireNonNull(typeToken, "typeToken cannot be null"); + Objects.requireNonNull(resultType, "resultType cannot be null"); Objects.requireNonNull(waitForCallbackConfig, "waitForCallbackConfig cannot be null"); // waitForCallback adds a suffix for the callback operation name and the submitter operation name so // the length restriction of waitForCallback name is different from the other operations. @@ -658,11 +314,11 @@ public DurableFuture waitForCallbackAsync( return runInChildContextAsync( name, - typeToken, + resultType, childCtx -> { var callback = childCtx.createCallback( name + WAIT_FOR_CALLBACK_CALLBACK_SUFFIX, - typeToken, + resultType, finalWaitForCallbackConfig.callbackConfig()); childCtx.step( name + WAIT_FOR_CALLBACK_SUBMITTER_SUFFIX, @@ -674,109 +330,21 @@ public DurableFuture waitForCallbackAsync( finalWaitForCallbackConfig.stepConfig()); return callback.get(); }, + RunInChildContextConfig.builder() + .serDes(finalWaitForCallbackConfig.stepConfig().serDes()) + .build(), OperationSubType.WAIT_FOR_CALLBACK); } - // ========== waitForCondition methods ========== - @Override - public T waitForCondition( - String name, - Class resultType, - BiFunction> checkFunc, - T initialState) { - return waitForConditionAsync( - name, - TypeToken.get(resultType), - checkFunc, - initialState, - WaitForConditionConfig.builder().build()) - .get(); - } - - @Override - public T waitForCondition( - String name, - Class resultType, - BiFunction> checkFunc, - T initialState, - WaitForConditionConfig config) { - return waitForConditionAsync(name, resultType, checkFunc, initialState, config) - .get(); - } - - @Override - public T waitForCondition( - String name, - TypeToken typeToken, - BiFunction> checkFunc, - T initialState) { - return waitForConditionAsync( - name, - typeToken, - checkFunc, - initialState, - WaitForConditionConfig.builder().build()) - .get(); - } - - @Override - public T waitForCondition( - String name, - TypeToken typeToken, - BiFunction> checkFunc, - T initialState, - WaitForConditionConfig config) { - return waitForConditionAsync(name, typeToken, checkFunc, initialState, config) - .get(); - } - - @Override - public DurableFuture waitForConditionAsync( - String name, - Class resultType, - BiFunction> checkFunc, - T initialState) { - return waitForConditionAsync( - name, - TypeToken.get(resultType), - checkFunc, - initialState, - WaitForConditionConfig.builder().build()); - } - - @Override - public DurableFuture waitForConditionAsync( - String name, - Class resultType, - BiFunction> checkFunc, - T initialState, - WaitForConditionConfig config) { - return waitForConditionAsync(name, TypeToken.get(resultType), checkFunc, initialState, config); - } - @Override public DurableFuture waitForConditionAsync( String name, - TypeToken typeToken, - BiFunction> checkFunc, - T initialState) { - return waitForConditionAsync( - name, - typeToken, - checkFunc, - initialState, - WaitForConditionConfig.builder().build()); - } - - @Override - public DurableFuture waitForConditionAsync( - String name, - TypeToken typeToken, + TypeToken resultType, BiFunction> checkFunc, T initialState, WaitForConditionConfig config) { Objects.requireNonNull(config, "config cannot be null"); - Objects.requireNonNull(typeToken, "typeToken cannot be null"); + Objects.requireNonNull(resultType, "resultType cannot be null"); Objects.requireNonNull(checkFunc, "checkFunc cannot be null"); Objects.requireNonNull(initialState, "initialState cannot be null"); ParameterValidator.validateOperationName(name); @@ -787,7 +355,7 @@ public DurableFuture waitForConditionAsync( var operationId = nextOperationId(); var operation = - new WaitForConditionOperation<>(operationId, name, checkFunc, typeToken, initialState, config, this); + new WaitForConditionOperation<>(operationId, name, checkFunc, resultType, initialState, config, this); operation.execute(); @@ -817,7 +385,6 @@ public void close() { if (logger != null) { logger.close(); } - super.close(); } /** diff --git a/sdk/src/main/java/software/amazon/lambda/durable/context/StepContextImpl.java b/sdk/src/main/java/software/amazon/lambda/durable/context/StepContextImpl.java index af5c9222a..dcf5af66b 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/context/StepContextImpl.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/context/StepContextImpl.java @@ -66,6 +66,5 @@ public void close() { if (logger != null) { logger.close(); } - super.close(); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/execution/DurableExecutor.java b/sdk/src/main/java/software/amazon/lambda/durable/execution/DurableExecutor.java index c7d0c490c..28608e8dd 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/execution/DurableExecutor.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/execution/DurableExecutor.java @@ -51,6 +51,7 @@ public static DurableExecutionOutput execute( executionManager.registerActiveThread(null); var handlerFuture = CompletableFuture.supplyAsync( () -> { + executionManager.setCurrentThreadContext(new ThreadContext(null, ThreadType.CONTEXT)); var userInput = extractUserInput( executionManager.getExecutionOperation(), config.getSerDes(), inputType); // use try-with-resources to clear logger properties diff --git a/sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java b/sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java index 774cfa022..97d516ee1 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java @@ -56,8 +56,7 @@ public class ExecutionManager implements AutoCloseable { private final AtomicReference executionMode; // ===== Thread Coordination ===== - private final Map> registeredOperations = - Collections.synchronizedMap(new HashMap<>()); + private final Map registeredOperations = Collections.synchronizedMap(new HashMap<>()); private final Set activeThreads = Collections.synchronizedSet(new HashSet<>()); private static final ThreadLocal currentThreadContext = new ThreadLocal<>(); private final CompletableFuture executionExceptionFuture = new CompletableFuture<>(); @@ -107,7 +106,7 @@ public boolean isReplaying() { } /** Registers an operation so it can receive checkpoint completion notifications. */ - public void registerOperation(BaseDurableOperation operation) { + public void registerOperation(BaseDurableOperation operation) { registeredOperations.put(operation.getOperationId(), operation); } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/model/MapError.java b/sdk/src/main/java/software/amazon/lambda/durable/model/MapError.java deleted file mode 100644 index 478a48edd..000000000 --- a/sdk/src/main/java/software/amazon/lambda/durable/model/MapError.java +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable.model; - -import java.util.List; - -/** - * Error details for a failed map item. - * - *

    Stores error information as plain strings so that {@link MapResult} can be serialized through the user's SerDes - * without requiring AWS SDK-specific Jackson modules. - * - * @param errorType the fully qualified exception class name - * @param errorMessage the error message - * @param stackTrace the stack trace frames, or null - */ -public record MapError(String errorType, String errorMessage, List stackTrace) {} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/model/MapResult.java b/sdk/src/main/java/software/amazon/lambda/durable/model/MapResult.java index a307c4fbf..9fbbbcb09 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/model/MapResult.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/model/MapResult.java @@ -4,6 +4,7 @@ import java.util.Collections; import java.util.List; +import software.amazon.lambda.durable.util.ExceptionHelper; /** * Result container for map operations. @@ -59,8 +60,7 @@ public int size() { /** Returns all results as an unmodifiable list (nulls for failed/not-started items). */ public List results() { - return Collections.unmodifiableList( - items.stream().map(MapResultItem::result).toList()); + return items.stream().map(MapResultItem::result).toList(); } /** Returns results from items that succeeded (includes null results from successful items). */ @@ -78,4 +78,60 @@ public List failed() { .map(MapResultItem::error) .toList(); } + + /** + * Represents the outcome of a single item in a map operation. + * + *

    Each item either succeeds with a result, fails with an error, or was never started. The status field indicates + * which case applies. + * + *

    Errors are stored as {@link MapError} (plain strings) rather than raw Throwable, so they survive serialization + * across checkpoint-and-replay cycles without requiring AWS SDK-specific Jackson modules. + * + * @param status the status of this item + * @param result the result value, or null if failed/not started + * @param error the error details, or null if succeeded/not started + * @param the result type + */ + public record MapResultItem(Status status, T result, MapError error) { + + /** Status of an individual map item. */ + public enum Status { + SUCCEEDED, + FAILED, + SKIPPED + } + + /** Creates a successful result item. */ + public static MapResultItem succeeded(T result) { + return new MapResultItem<>(Status.SUCCEEDED, result, null); + } + + /** Creates a failed result item. */ + public static MapResultItem failed(MapError error) { + return new MapResultItem<>(Status.FAILED, null, error); + } + + /** Creates a skipped result item. */ + public static MapResultItem skipped() { + return new MapResultItem<>(Status.SKIPPED, null, null); + } + } + + /** + * Error details for a failed map item. + * + *

    Stores error information as plain strings so that {@link MapResult} can be serialized through the user's + * SerDes without requiring AWS SDK-specific Jackson modules. + * + * @param errorType the fully qualified exception class name + * @param errorMessage the error message + * @param stackTrace the stack trace frames, or null + */ + public record MapError(String errorType, String errorMessage, List stackTrace) { + public static MapError of(Throwable e) { + return new MapError( + e.getClass().getName(), e.getMessage(), ExceptionHelper.serializeStackTrace(e.getStackTrace())); + } + } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/model/MapResultItem.java b/sdk/src/main/java/software/amazon/lambda/durable/model/MapResultItem.java deleted file mode 100644 index 86cb8ce79..000000000 --- a/sdk/src/main/java/software/amazon/lambda/durable/model/MapResultItem.java +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable.model; - -/** - * Represents the outcome of a single item in a map operation. - * - *

    Each item either succeeds with a result, fails with an error, or was never started. The status field indicates - * which case applies. - * - *

    Errors are stored as {@link MapError} (plain strings) rather than raw Throwable, so they survive serialization - * across checkpoint-and-replay cycles without requiring AWS SDK-specific Jackson modules. - * - * @param status the status of this item - * @param result the result value, or null if failed/not started - * @param error the error details, or null if succeeded/not started - * @param the result type - */ -public record MapResultItem(Status status, T result, MapError error) { - - /** Status of an individual map item. */ - public enum Status { - SUCCEEDED, - FAILED, - NOT_STARTED - } - - /** Creates a successful result item. */ - public static MapResultItem success(T result) { - return new MapResultItem<>(Status.SUCCEEDED, result, null); - } - - /** Creates a failed result item. */ - public static MapResultItem failure(MapError error) { - return new MapResultItem<>(Status.FAILED, null, error); - } - - /** Creates a not-started result item. */ - public static MapResultItem notStarted() { - return new MapResultItem<>(Status.NOT_STARTED, null, null); - } -} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/model/ParallelResult.java b/sdk/src/main/java/software/amazon/lambda/durable/model/ParallelResult.java index f0e0fcf18..11bd9049e 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/model/ParallelResult.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/model/ParallelResult.java @@ -8,41 +8,4 @@ *

    Captures the aggregate outcome of a parallel execution: how many branches were registered, how many succeeded, how * many failed, and why the operation completed. */ -public class ParallelResult { - - private final int totalBranches; - private final int succeededBranches; - private final int failedBranches; - private final ConcurrencyCompletionStatus completionStatus; - - public ParallelResult( - int totalBranches, - int succeededBranches, - int failedBranches, - ConcurrencyCompletionStatus completionStatus) { - this.totalBranches = totalBranches; - this.succeededBranches = succeededBranches; - this.failedBranches = failedBranches; - this.completionStatus = completionStatus; - } - - /** Returns the total number of branches registered before {@code join()} was called. */ - public int getTotalBranches() { - return totalBranches; - } - - /** Returns the number of branches that completed without throwing. */ - public int getSucceededBranches() { - return succeededBranches; - } - - /** Returns the number of branches that threw an exception. */ - public int getFailedBranches() { - return failedBranches; - } - - /** Returns the status indicating why the parallel operation completed. */ - public ConcurrencyCompletionStatus getCompletionStatus() { - return completionStatus; - } -} +public record ParallelResult(int size, int succeeded, int failed, ConcurrencyCompletionStatus completionStatus) {} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java index bd951b407..30cb67611 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java @@ -6,26 +6,22 @@ import java.util.List; import java.util.Objects; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicReference; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import software.amazon.awssdk.services.lambda.model.ErrorObject; import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.awssdk.services.lambda.model.OperationUpdate; -import software.amazon.lambda.durable.DurableFuture; -import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.IllegalDurableOperationException; import software.amazon.lambda.durable.exception.NonDeterministicExecutionException; -import software.amazon.lambda.durable.exception.SerDesException; import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException; import software.amazon.lambda.durable.execution.ExecutionManager; +import software.amazon.lambda.durable.execution.SuspendExecutionException; import software.amazon.lambda.durable.execution.ThreadContext; import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.model.OperationSubType; -import software.amazon.lambda.durable.serde.SerDes; -import software.amazon.lambda.durable.util.ExceptionHelper; /** * Base class for all durable operations (STEP, WAIT, etc.). @@ -45,34 +41,25 @@ *

  • Proper thread coordination via future * */ -public abstract class BaseDurableOperation implements DurableFuture { +public abstract class BaseDurableOperation { private static final Logger logger = LoggerFactory.getLogger(BaseDurableOperation.class); private final OperationIdentifier operationIdentifier; - private final ExecutionManager executionManager; - private final TypeToken resultTypeToken; - private final SerDes resultSerDes; - protected final CompletableFuture completionFuture; + protected final ExecutionManager executionManager; + protected final CompletableFuture completionFuture; private final DurableContextImpl durableContext; + private final AtomicReference> runningUserHandler = new AtomicReference<>(null); /** * Constructs a new durable operation. * * @param operationIdentifier the unique identifier for this operation - * @param resultTypeToken the type token for deserializing the result - * @param resultSerDes the serializer/deserializer for the result * @param durableContext the parent context this operation belongs to */ - protected BaseDurableOperation( - OperationIdentifier operationIdentifier, - TypeToken resultTypeToken, - SerDes resultSerDes, - DurableContextImpl durableContext) { + protected BaseDurableOperation(OperationIdentifier operationIdentifier, DurableContextImpl durableContext) { this.operationIdentifier = operationIdentifier; this.durableContext = durableContext; this.executionManager = durableContext.getExecutionManager(); - this.resultTypeToken = resultTypeToken; - this.resultSerDes = resultSerDes; this.completionFuture = new CompletableFuture<>(); @@ -80,6 +67,10 @@ protected BaseDurableOperation( executionManager.registerOperation(this); } + public CompletableFuture getCompletionFuture() { + return completionFuture; + } + /** Gets the operation sub-type (e.g. RUN_IN_CHILD_CONTEXT, WAIT_FOR_CALLBACK). */ public OperationSubType getSubType() { return operationIdentifier.subType(); @@ -144,13 +135,12 @@ protected Operation getOperation() { } /** - * Gets the direct child Operations of a give context operation. + * Gets the direct child Operations of this context operation * - * @param operationId the operation id of the context * @return list of the child Operations */ - protected List getChildOperations(String operationId) { - return executionManager.getChildOperations(operationId); + protected List getChildOperations() { + return executionManager.getChildOperations(getOperationId()); } /** @@ -165,7 +155,7 @@ private void validateCurrentThreadType() { "Nested %s operation is not supported on %s from within a %s execution.", getType(), getName(), current); // terminate execution and throw the exception - terminateExecutionWithIllegalDurableOperationException(message); + throw terminateExecutionWithIllegalDurableOperationException(message); } } @@ -215,12 +205,56 @@ protected Operation waitForOperationCompletion() { // Get result based on status var op = getOperation(); if (op == null) { - terminateExecutionWithIllegalDurableOperationException( + throw terminateExecutionWithIllegalDurableOperationException( String.format("%s operation not found: %s", getType(), getOperationId())); } return op; } + protected void runUserHandler(Runnable runnable, String contextId, ThreadType threadType) { + Runnable wrapped = () -> { + executionManager.setCurrentThreadContext(new ThreadContext(contextId, threadType)); + try { + runnable.run(); + } finally { + if (contextId != null) { + try { + // if this is a child context or a step context, we need to + // deregister the context's thread from the execution manager + executionManager.deregisterActiveThread(contextId); + } catch (SuspendExecutionException e) { + // Expected when this is the last active thread. Must catch here because: + // 1/ This runs in a worker thread detached from handlerFuture + // 2/ Uncaught exception would prevent stepAsync().get() from resume + // Suspension/Termination is already signaled via + // suspendExecutionFuture/terminateExecutionFuture + // before the throw. + } + } + } + }; + + // runUserHandler is used to ensure that only one user handler is running at a time + if (runningUserHandler.get() != null) { + throw new IllegalStateException("User handler already running"); + } + + // Thread registration is intentionally split across two threads: + // 1. registerActiveThread on the PARENT thread — ensures the child is tracked before the + // parent can deregister and trigger suspension (race prevention). + // 2. setCurrentContext on the CHILD thread — sets the ThreadLocal so operations inside + // the child context know which context they belong to. + // registerActiveThread is idempotent (no-op if already registered). + registerActiveThread(contextId); + + if (!runningUserHandler.compareAndSet( + null, + CompletableFuture.runAsync( + wrapped, getContext().getDurableConfig().getExecutorService()))) { + throw new IllegalStateException("User handler already running"); + } + } + /** * Receives operation updates from ExecutionManager. Completes the internal future when the operation reaches a * terminal status, unblocking any threads waiting on this operation. @@ -232,15 +266,8 @@ public void onCheckpointComplete(Operation operation) { // This method handles only terminal status updates. Override this method if a DurableOperation needs to // handle other updates. logger.trace("In onCheckpointComplete, completing operation {} ({})", getOperationId(), completionFuture); - // It's important that we synchronize access to the future, otherwise the processing could happen - // on someone else's thread and cause a race condition. - synchronized (completionFuture) { - // Completing the future here will also run any other completion stages that have been attached - // to the future. In our case, other contexts may have attached a function to reactivate themselves, - // so they will definitely have a chance to reactivate before we finish completing and deactivating - // whatever operations were just checkpointed. - completionFuture.complete(null); - } + + markCompletionFutureCompleted(); } } @@ -249,11 +276,18 @@ protected void markAlreadyCompleted() { // When the operation is already completed in a replay, we complete completionFuture immediately // so that the `get` method will be unblocked and the context thread will be registered logger.trace("In markAlreadyCompleted, completing operation: {} ({}).", getOperationId(), completionFuture); + markCompletionFutureCompleted(); + } + private void markCompletionFutureCompleted() { // It's important that we synchronize access to the future, otherwise the processing could happen // on someone else's thread and cause a race condition. synchronized (completionFuture) { - completionFuture.complete(null); + // Completing the future here will also run any other completion stages that have been attached + // to the future. In our case, other contexts may have attached a function to reactivate themselves, + // so they will definitely have a chance to reactivate before we finish completing and deactivating + // whatever operations were just checkpointed. + completionFuture.complete(this); } } @@ -263,7 +297,7 @@ protected void markAlreadyCompleted() { * @param exception the unrecoverable exception * @return never returns normally; always throws */ - protected T terminateExecution(UnrecoverableDurableExecutionException exception) { + protected RuntimeException terminateExecution(UnrecoverableDurableExecutionException exception) { executionManager.terminateExecution(exception); // Exception is already thrown from above. Keep the throw statement below to make tests happy throw exception; @@ -275,7 +309,7 @@ protected T terminateExecution(UnrecoverableDurableExecutionException exception) * @param message the error message * @return never returns normally; always throws */ - protected T terminateExecutionWithIllegalDurableOperationException(String message) { + protected RuntimeException terminateExecutionWithIllegalDurableOperationException(String message) { return terminateExecution(new IllegalDurableOperationException(message)); } @@ -323,82 +357,6 @@ protected CompletableFuture sendOperationUpdateAsync(OperationUpdate.Build return executionManager.sendOperationUpdate(updateBuilder.build()); } - /** - * Deserializes a result string into the operation's result type. - * - * @param result the serialized result string - * @return the deserialized result - * @throws SerDesException if deserialization fails - */ - protected T deserializeResult(String result) { - try { - return resultSerDes.deserialize(result, resultTypeToken); - } catch (SerDesException e) { - logger.warn( - "Failed to deserialize {} result for operation name '{}'. Ensure the result is properly encoded.", - getType(), - getName()); - throw e; - } - } - - /** - * Serializes the result to a string. - * - * @param result the result to serialize - * @return the serialized string - */ - protected String serializeResult(T result) { - return resultSerDes.serialize(result); - } - - /** - * Serializes a throwable into an {@link ErrorObject} for checkpointing. - * - * @param throwable the exception to serialize - * @return the serialized error object - */ - protected ErrorObject serializeException(Throwable throwable) { - return ExceptionHelper.buildErrorObject(throwable, resultSerDes); - } - - /** - * Deserializes an {@link ErrorObject} back into a throwable, reconstructing the original exception type and stack - * trace when possible. Falls back to null if the exception class is not found or deserialization fails. - * - * @param errorObject the serialized error object - * @return the reconstructed throwable, or null if reconstruction is not possible - */ - protected Throwable deserializeException(ErrorObject errorObject) { - Throwable original = null; - if (errorObject == null) { - return original; - } - var errorType = errorObject.errorType(); - var errorData = errorObject.errorData(); - - if (errorType == null) { - return original; - } - try { - - Class exceptionClass = Class.forName(errorType); - if (Throwable.class.isAssignableFrom(exceptionClass)) { - original = - resultSerDes.deserialize(errorData, TypeToken.get(exceptionClass.asSubclass(Throwable.class))); - - if (original != null) { - original.setStackTrace(ExceptionHelper.deserializeStackTrace(errorObject.stackTrace())); - } - } - } catch (ClassNotFoundException e) { - logger.warn("Cannot re-construct original exception type. Falling back to generic StepFailedException."); - } catch (SerDesException e) { - logger.warn("Cannot deserialize original exception data. Falling back to generic StepFailedException.", e); - } - return original; - } - /** Validates that current operation matches checkpointed operation during replay. */ protected void validateReplay(Operation checkpointed) { if (checkpointed == null || checkpointed.type() == null) { @@ -406,13 +364,13 @@ protected void validateReplay(Operation checkpointed) { } if (!checkpointed.type().equals(getType())) { - terminateExecution(new NonDeterministicExecutionException(String.format( + throw terminateExecution(new NonDeterministicExecutionException(String.format( "Operation type mismatch for \"%s\". Expected %s, got %s", getOperationId(), checkpointed.type(), getType()))); } if (!Objects.equals(checkpointed.name(), getName())) { - terminateExecution(new NonDeterministicExecutionException(String.format( + throw terminateExecution(new NonDeterministicExecutionException(String.format( "Operation name mismatch for \"%s\". Expected \"%s\", got \"%s\"", getOperationId(), checkpointed.name(), getName()))); } @@ -420,7 +378,7 @@ protected void validateReplay(Operation checkpointed) { if ((getSubType() == null && checkpointed.subType() != null) || getSubType() != null && !Objects.equals(checkpointed.subType(), getSubType().getValue())) { - terminateExecution(new NonDeterministicExecutionException(String.format( + throw terminateExecution(new NonDeterministicExecutionException(String.format( "Operation subType mismatch for \"%s\". Expected \"%s\", got \"%s\"", getOperationId(), checkpointed.subType(), getSubType()))); } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java index 9e1f63fc2..9d9481fb9 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/CallbackOperation.java @@ -6,16 +6,16 @@ import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationAction; import software.amazon.awssdk.services.lambda.model.OperationUpdate; -import software.amazon.lambda.durable.CallbackConfig; import software.amazon.lambda.durable.DurableCallbackFuture; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.CallbackConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.CallbackFailedException; import software.amazon.lambda.durable.exception.CallbackTimeoutException; import software.amazon.lambda.durable.model.OperationIdentifier; /** Durable operation for creating and waiting on external callbacks. */ -public class CallbackOperation extends BaseDurableOperation implements DurableCallbackFuture { +public class CallbackOperation extends SerializableDurableOperation implements DurableCallbackFuture { private final CallbackConfig config; @@ -65,7 +65,7 @@ protected void replay(Operation existing) { // Still waiting - continue to polling } default -> - terminateExecutionWithIllegalDurableOperationException( + throw terminateExecutionWithIllegalDurableOperationException( "Unexpected callback status: " + existing.status()); } pollForOperationUpdates(); @@ -80,7 +80,8 @@ public T get() { case FAILED -> throw new CallbackFailedException(op); case TIMED_OUT -> throw new CallbackTimeoutException(op); default -> - terminateExecutionWithIllegalDurableOperationException("Unexpected callback status: " + op.status()); + throw terminateExecutionWithIllegalDurableOperationException( + "Unexpected callback status: " + op.status()); }; } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java index 4fb9e1beb..09a287726 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java @@ -5,8 +5,7 @@ import static software.amazon.lambda.durable.execution.ExecutionManager.isTerminalStatus; import java.nio.charset.StandardCharsets; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import software.amazon.awssdk.services.lambda.model.ContextOptions; import software.amazon.awssdk.services.lambda.model.ErrorObject; @@ -17,6 +16,7 @@ import software.amazon.awssdk.services.lambda.model.OperationUpdate; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.RunInChildContextConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.CallbackFailedException; import software.amazon.lambda.durable.exception.CallbackSubmitterException; @@ -27,8 +27,8 @@ import software.amazon.lambda.durable.exception.StepInterruptedException; import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException; import software.amazon.lambda.durable.execution.SuspendExecutionException; +import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.OperationIdentifier; -import software.amazon.lambda.durable.serde.SerDes; import software.amazon.lambda.durable.util.ExceptionHelper; /** @@ -41,35 +41,33 @@ * on completion via {@code onItemComplete()} BEFORE closing its own child context. It also skips checkpointing if the * parent operation has already succeeded. */ -public class ChildContextOperation extends BaseDurableOperation { +public class ChildContextOperation extends SerializableDurableOperation { private static final int LARGE_RESULT_THRESHOLD = 256 * 1024; private final Function function; - private final ExecutorService userExecutor; private final ConcurrencyOperation parentOperation; - private boolean replayChildContext; + private final AtomicBoolean replayChildren = new AtomicBoolean(false); private T reconstructedResult; public ChildContextOperation( OperationIdentifier operationIdentifier, Function function, TypeToken resultTypeToken, - SerDes resultSerDes, + RunInChildContextConfig config, DurableContextImpl durableContext) { - this(operationIdentifier, function, resultTypeToken, resultSerDes, durableContext, null); + this(operationIdentifier, function, resultTypeToken, config, durableContext, null); } public ChildContextOperation( OperationIdentifier operationIdentifier, Function function, TypeToken resultTypeToken, - SerDes resultSerDes, + RunInChildContextConfig config, DurableContextImpl durableContext, ConcurrencyOperation parentOperation) { - super(operationIdentifier, resultTypeToken, resultSerDes, durableContext); + super(operationIdentifier, resultTypeToken, config.serDes(), durableContext); this.function = function; - this.userExecutor = getContext().getDurableConfig().getExecutorService(); this.parentOperation = parentOperation; } @@ -89,7 +87,7 @@ protected void replay(Operation existing) { if (existing.contextDetails() != null && Boolean.TRUE.equals(existing.contextDetails().replayChildren())) { // Large result: re-execute child context to reconstruct result - replayChildContext = true; + replayChildren.set(true); executeChildContext(); } else { markAlreadyCompleted(); @@ -98,19 +96,11 @@ protected void replay(Operation existing) { case FAILED -> markAlreadyCompleted(); case STARTED -> executeChildContext(); default -> - terminateExecutionWithIllegalDurableOperationException( + throw terminateExecutionWithIllegalDurableOperationException( "Unexpected child context status: " + existing.status()); } } - @Override - protected void markAlreadyCompleted() { - super.markAlreadyCompleted(); - if (parentOperation != null) { - parentOperation.onItemComplete(this); - } - } - private void executeChildContext() { // The operationId is already globally unique (prefixed by parent context path via // DurableContext.nextOperationId), so we use it directly as the contextId. @@ -119,14 +109,6 @@ private void executeChildContext() { // third level child context "hash(hash(hash(1)-2)-1)". var contextId = getOperationId(); - // Thread registration is intentionally split across two threads: - // 1. registerActiveThread on the PARENT thread — ensures the child is tracked before the - // parent can deregister and trigger suspension (race prevention). - // 2. setCurrentContext on the CHILD thread — sets the ThreadLocal so operations inside - // the child context know which context they belong to. - // registerActiveThread is idempotent (no-op if already registered). - registerActiveThread(contextId); - Runnable userHandler = () -> { // use a try-with-resources to // - add thread id/type to thread local when the step starts @@ -142,20 +124,16 @@ private void executeChildContext() { handleChildContextSuccess(result); } catch (Throwable e) { handleChildContextFailure(e); - } finally { - if (parentOperation != null) { - parentOperation.onItemComplete(this); - } } } }; // Execute user provided child context code in user-configured executor - CompletableFuture.runAsync(userHandler, userExecutor); + runUserHandler(userHandler, contextId, ThreadType.CONTEXT); } private void handleChildContextSuccess(T result) { - if (replayChildContext) { + if (replayChildren.get()) { // Replaying a SUCCEEDED child with replayChildren=true — skip checkpointing. // Mark the completableFuture completed so get() doesn't block waiting for a checkpoint response. this.reconstructedResult = result; @@ -169,8 +147,6 @@ private void checkpointSuccess(T result) { // Skip checkpointing if parent ConcurrencyOperation has already completed — // prevents race conditions where a child finishes after the parent has already completed. if (parentOperation != null && parentOperation.isOperationCompleted()) { - this.reconstructedResult = result; - markAlreadyCompleted(); return; } @@ -199,13 +175,12 @@ private void handleChildContextFailure(Throwable exception) { } if (exception instanceof UnrecoverableDurableExecutionException unrecoverableDurableExecutionException) { // terminate the execution and throw the exception if it's not recoverable - terminateExecution(unrecoverableDurableExecutionException); + throw terminateExecution(unrecoverableDurableExecutionException); } // Skip checkpointing if parent ConcurrencyOperation has already completed — // prevents race conditions where a child finishes after the parent has already succeeded. if (parentOperation != null && parentOperation.isOperationCompleted()) { - markAlreadyCompleted(); return; } @@ -243,7 +218,7 @@ public T get() { // throw a general failed exception if a user exception is not reconstructed return switch (getSubType()) { - case WAIT_FOR_CALLBACK -> handleWaitForCallbackFailure(op); + case WAIT_FOR_CALLBACK -> handleWaitForCallbackFailure(); case MAP -> throw new ChildContextFailedException(op); case MAP_ITERATION -> throw new ChildContextFailedException(op); case PARALLEL -> throw new ChildContextFailedException(op); @@ -254,8 +229,8 @@ public T get() { } } - private T handleWaitForCallbackFailure(Operation op) { - var childrenOps = getChildOperations(op.id()); + private T handleWaitForCallbackFailure() { + var childrenOps = getChildOperations(); var callbackOp = childrenOps.stream() .filter(o -> o.type() == OperationType.CALLBACK) .findFirst() diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java index f46b890b4..db6592515 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java @@ -8,18 +8,24 @@ import java.util.List; import java.util.Queue; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.RunInChildContextConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.OperationIdGenerator; +import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; import software.amazon.lambda.durable.model.OperationIdentifier; +import software.amazon.lambda.durable.model.OperationSubType; import software.amazon.lambda.durable.serde.SerDes; /** @@ -40,32 +46,43 @@ * * @param the result type of this operation */ -public abstract class ConcurrencyOperation extends BaseDurableOperation { +public abstract class ConcurrencyOperation extends SerializableDurableOperation { private static final Logger logger = LoggerFactory.getLogger(ConcurrencyOperation.class); private final int maxConcurrency; - private final AtomicInteger succeededCount = new AtomicInteger(0); - private final AtomicInteger failedCount = new AtomicInteger(0); - private final AtomicInteger runningCount = new AtomicInteger(0); - private final AtomicBoolean isJoined = new AtomicBoolean(false); - private final Queue> pendingQueue = new ConcurrentLinkedDeque<>(); - private final List> childOperations = Collections.synchronizedList(new ArrayList<>()); - private final Set completedOperations = Collections.synchronizedSet(new HashSet()); - private OperationIdGenerator operationIdGenerator; + private final Integer minSuccessful; + private final Integer toleratedFailureCount; + private final OperationIdGenerator operationIdGenerator; private final DurableContextImpl rootContext; - private ConcurrencyCompletionStatus completionStatus; + + // access by context thread only + private final List> branches = Collections.synchronizedList(new ArrayList<>()); + + // put only by context thread and consume only by consumer thread + private final Queue> pendingQueue = new ConcurrentLinkedDeque<>(); + + // set by context thread and used by consumer thread + protected final AtomicBoolean isJoined = new AtomicBoolean(false); + + // used to wake up consumer thread for either new items or checking completion condition (isJoined changed) + private final AtomicReference> consumerThreadListener; protected ConcurrencyOperation( OperationIdentifier operationIdentifier, TypeToken resultTypeToken, SerDes resultSerDes, DurableContextImpl durableContext, - int maxConcurrency) { + int maxConcurrency, + Integer minSuccessful, + Integer toleratedFailureCount) { super(operationIdentifier, resultTypeToken, resultSerDes, durableContext); this.maxConcurrency = maxConcurrency; + this.minSuccessful = minSuccessful; + this.toleratedFailureCount = toleratedFailureCount; this.operationIdGenerator = new OperationIdGenerator(getOperationId()); - this.rootContext = durableContext.createChildContextWithoutSettingThreadContext(getOperationId(), getName()); + this.rootContext = durableContext.createChildContext(getOperationId(), getName()); + this.consumerThreadListener = new AtomicReference<>(null); } // ========== Template methods for subclasses ========== @@ -77,98 +94,141 @@ protected ConcurrencyOperation( * @param name the name of this item * @param function the user function to execute * @param resultType the result type token - * @param serDes the serializer/deserializer + * @param branchSubType the sub-type of the branch operation * @param parentContext the parent durable context * @param the result type of the child operation * @return a new ChildContextOperation */ - protected abstract ChildContextOperation createItem( + protected ChildContextOperation createItem( String operationId, String name, Function function, TypeToken resultType, SerDes serDes, - DurableContextImpl parentContext); + OperationSubType branchSubType, + DurableContextImpl parentContext) { + return new ChildContextOperation<>( + OperationIdentifier.of(operationId, name, OperationType.CONTEXT, branchSubType), + function, + resultType, + RunInChildContextConfig.builder().serDes(serDes).build(), + parentContext, + this); + } /** Called when the concurrency operation succeeds. Subclasses define checkpointing behavior. */ protected abstract void handleSuccess(ConcurrencyCompletionStatus concurrencyCompletionStatus); - /** Called when the concurrency operation fails. Subclasses define checkpointing and exception behavior. */ - protected abstract void handleFailure(ConcurrencyCompletionStatus concurrencyCompletionStatus); - // ========== Concurrency control ========== /** - * Adds a new item to this concurrency operation. Creates the child operation and either starts it immediately or - * enqueues it if maxConcurrency is reached. - * - * @param name the name of the item - * @param function the user function to execute - * @param resultType the result type token - * @param serDes the serializer/deserializer - * @param the result type of the child operation - * @return the created ChildContextOperation - */ - public ChildContextOperation addItem( - String name, Function function, TypeToken resultType, SerDes serDes) { - if (isOperationCompleted()) throw new IllegalStateException("Cannot add items to a completed operation"); - var operationId = this.operationIdGenerator.nextOperationId(); - var childOp = createItem(operationId, name, function, resultType, serDes, this.rootContext); - childOperations.add(childOp); - pendingQueue.add(childOp); - logger.debug("Item added {}", name); - executeNextItemIfAllowed(); - return childOp; - } - - /** - * Creates and enqueues an item without starting execution. Use {@link #startPendingItems()} to begin execution - * after all items have been enqueued. This prevents early termination from blocking item creation when all items - * are known upfront (e.g., map operations). + * Creates and enqueues an item without starting execution. Use {@link #executeItems()} to begin execution after all + * items have been enqueued. This prevents early termination from blocking item creation when all items are known + * upfront (e.g., map operations). */ protected ChildContextOperation enqueueItem( - String name, Function function, TypeToken resultType, SerDes serDes) { + String name, + Function function, + TypeToken resultType, + SerDes serDes, + OperationSubType branchSubType) { var operationId = this.operationIdGenerator.nextOperationId(); - var childOp = createItem(operationId, name, function, resultType, serDes, this.rootContext); - childOperations.add(childOp); + var childOp = createItem(operationId, name, function, resultType, serDes, branchSubType, this.rootContext); + branches.add(childOp); pendingQueue.add(childOp); logger.debug("Item enqueued {}", name); + // notify the consumer thread a new item is available + completeVacancyListenerIfSet(); return childOp; } - /** - * Starts executing enqueued items up to maxConcurrency. Called after all items have been enqueued via - * {@link #enqueueItem}. - */ - protected void startPendingItems() { - // Start as many items as concurrency allows - while (true) { - synchronized (this) { - if (isOperationCompleted()) return; - if (maxConcurrency != -1 && runningCount.get() >= maxConcurrency) return; - var next = pendingQueue.poll(); - if (next == null) return; - runningCount.incrementAndGet(); - logger.debug("Executing operation {}", next.getName()); - next.execute(); + private void completeVacancyListenerIfSet() { + synchronized (this) { + if (consumerThreadListener.get() != null) { + consumerThreadListener.get().complete(null); } } } - /** - * Starts the next queued item if the running count is below maxConcurrency and the operation hasn't completed yet. - * Must be called within {@code synchronized (pendingQueue)}. - */ - private void executeNextItemIfAllowed() { + /** Starts execution of all enqueued items. */ + protected void executeItems() { + // variables accessed only by the consumer thread. Put them here to avoid accidentally used by other threads + Set runningChildren = new HashSet<>(); + AtomicInteger succeededCount = new AtomicInteger(0); + AtomicInteger failedCount = new AtomicInteger(0); + + Runnable consumer = () -> { + while (true) { + if (isOperationCompleted()) { + return; + } + var completionStatus = canComplete(succeededCount, failedCount, runningChildren); + if (completionStatus != null) { + handleComplete(completionStatus); + return; + } + while (runningChildren.size() < maxConcurrency && !pendingQueue.isEmpty()) { + var next = pendingQueue.poll(); + runningChildren.add(next); + logger.debug("Executing operation {}", next.getName()); + next.execute(); + } + var child = waitForChildCompletion(succeededCount, failedCount, runningChildren); + // child may be null if the consumer thread is woken up due to a new item being added + if (child != null) { + if (runningChildren.contains(child)) { + runningChildren.remove(child); + onItemComplete(succeededCount, failedCount, (ChildContextOperation) child); + } else { + throw new IllegalStateException("Unexpected completion: " + child); + } + } + synchronized (this) { + if (consumerThreadListener.get() != null + && consumerThreadListener.get().isDone()) { + consumerThreadListener.set(null); + } + } + } + }; + // run consumer in the user thread pool, although it's not a real user thread + runUserHandler(consumer, getOperationId(), ThreadType.CONTEXT); + } + + private BaseDurableOperation waitForChildCompletion( + AtomicInteger succeededCount, AtomicInteger failedCount, Set runningChildren) { + var threadContext = getCurrentThreadContext(); + CompletableFuture future; + synchronized (this) { - if (isOperationCompleted()) return; - if (maxConcurrency != -1 && runningCount.get() >= maxConcurrency) return; - var next = pendingQueue.poll(); - if (next == null) return; - runningCount.incrementAndGet(); - logger.debug("Executing operation {}", next.getName()); - next.execute(); + // check again in synchronized block to prevent race conditions + if (isOperationCompleted()) { + return null; + } + var completionStatus = canComplete(succeededCount, failedCount, runningChildren); + if (completionStatus != null) { + return null; + } + ArrayList> futures; + futures = new ArrayList<>(runningChildren.stream() + .map(BaseDurableOperation::getCompletionFuture) + .toList()); + if (futures.size() < maxConcurrency) { + // add a future to listen to the new items if there is a vacancy + consumerThreadListener.compareAndSet(null, new CompletableFuture<>()); + futures.add(consumerThreadListener.get()); + } + + // future will be completed immediately if any future of the list is already completed + future = CompletableFuture.anyOf(futures.toArray(CompletableFuture[]::new)); + // skip deregistering the current thread if there is more completed future to process + if (!future.isDone()) { + future.thenRun(() -> registerActiveThread(threadContext.threadId())); + // Deregister the current thread to allow suspension + executionManager.deregisterActiveThread(threadContext.threadId()); + } } + return future.thenApply(o -> (BaseDurableOperation) o).join(); } /** @@ -177,68 +237,55 @@ private void executeNextItemIfAllowed() { * * @param child the child operation that completed */ - public void onItemComplete(ChildContextOperation child) { - if (!completedOperations.add(child.getOperationId())) { - return; - } - + private void onItemComplete( + AtomicInteger succeededCount, AtomicInteger failedCount, ChildContextOperation child) { // Evaluate child result outside the lock — child.get() may block waiting for a checkpoint response. logger.debug("OnItemComplete called by {}, Id: {}", child.getName(), child.getOperationId()); - boolean succeeded; try { child.get(); logger.debug("Result succeeded - {}", child.getName()); - succeeded = true; + succeededCount.incrementAndGet(); } catch (Throwable e) { logger.debug("Child operation {} failed: {}", child.getOperationId(), e.getMessage()); - succeeded = false; - } - - // Counter updates, completion check, and next-item dispatch must be atomic to prevent - // the main thread's join() from seeing runningCount==0 with incomplete counters. - synchronized (this) { - if (succeeded) { - succeededCount.incrementAndGet(); - } else { - failedCount.incrementAndGet(); - } - runningCount.decrementAndGet(); - - this.completionStatus = canComplete(); - if (this.completionStatus != null) { - handleComplete(this.completionStatus); - } else { - executeNextItemIfAllowed(); - } + failedCount.incrementAndGet(); } } // ========== Completion logic ========== - /** - * Validates that the number of registered items is sufficient to satisfy the completion criteria. Called at join() - * time because branches are registered incrementally and the total count is only known once the user calls join(). - * - * @throws IllegalArgumentException if the item count cannot satisfy the criteria - */ - protected abstract void validateItemCount(); - /** * Checks whether the concurrency operation can be considered complete. * * @return the completion status if the operation is complete, or null if it should continue */ - protected abstract ConcurrencyCompletionStatus canComplete(); + private ConcurrencyCompletionStatus canComplete( + AtomicInteger succeededCount, AtomicInteger failedCount, Set runningChildren) { + int succeeded = succeededCount.get(); + int failed = failedCount.get(); + + // If we've met the minimum successful count, we're done + if (minSuccessful != null && succeeded >= minSuccessful) { + return ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED; + } + + // If we've exceeded the failure tolerance, we're done + if (toleratedFailureCount != null && failed > toleratedFailureCount) { + return ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; + } + + // All items finished — complete + if (isJoined.get() && pendingQueue.isEmpty() && runningChildren.isEmpty()) { + return ConcurrencyCompletionStatus.ALL_COMPLETED; + } + + return null; + } private void handleComplete(ConcurrencyCompletionStatus status) { synchronized (this) { if (isOperationCompleted()) { return; } - if (status.isSucceeded()) { - handleSuccess(status); - } else { - handleFailure(status); - } + handleSuccess(status); } } @@ -246,41 +293,18 @@ private void handleComplete(ConcurrencyCompletionStatus status) { * Blocks the calling thread until the concurrency operation reaches a terminal state. Validates item count, handles * zero-branch case, then delegates to {@code waitForOperationCompletion()} from BaseDurableOperation. */ - public void join() { - validateItemCount(); - isJoined.set(true); - synchronized (this) { - this.completionStatus = canComplete(); - if (this.completionStatus != null) { - handleComplete(this.completionStatus); - } + protected void join() { + if (minSuccessful != null && minSuccessful > branches.size()) { + throw new IllegalStateException("minSuccessful (" + minSuccessful + + ") exceeds the number of registered items (" + branches.size() + ")"); } - + isJoined.set(true); + // notify the execution thread this concurrency operation is joined + completeVacancyListenerIfSet(); waitForOperationCompletion(); } - protected int getSucceededCount() { - return succeededCount.get(); - } - - protected int getFailedCount() { - return failedCount.get(); - } - - protected int getTotalItems() { - return childOperations.size(); - } - - protected ConcurrencyCompletionStatus getCompletionStatus() { - return completionStatus; - } - - protected List> getChildOperations() { - return Collections.unmodifiableList(childOperations); - } - - /** Returns true if all items have finished (no pending, no running). Used by subclasses to override canComplete. */ - protected boolean isAllItemsFinished() { - return isJoined.get() && pendingQueue.isEmpty() && runningCount.get() == 0; + protected List> getBranches() { + return branches; } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/InvokeOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/InvokeOperation.java index be94fd0b0..9e2c54ace 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/InvokeOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/InvokeOperation.java @@ -6,8 +6,8 @@ import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationAction; import software.amazon.awssdk.services.lambda.model.OperationUpdate; -import software.amazon.lambda.durable.InvokeConfig; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.InvokeConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.InvokeException; import software.amazon.lambda.durable.exception.InvokeFailedException; @@ -22,7 +22,7 @@ * @param the result type from the invoked function * @param the payload type sent to the invoked function */ -public class InvokeOperation extends BaseDurableOperation { +public class InvokeOperation extends SerializableDurableOperation { private final String functionName; private final I payload; private final InvokeConfig invokeConfig; @@ -58,7 +58,7 @@ protected void replay(Operation existing) { case STARTED -> pollForOperationUpdates(); case SUCCEEDED, FAILED, TIMED_OUT, STOPPED -> markAlreadyCompleted(); default -> - terminateExecutionWithIllegalDurableOperationException( + throw terminateExecutionWithIllegalDurableOperationException( "Unexpected invoke status: " + existing.statusAsString()); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java index d9931a475..474570cae 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java @@ -5,26 +5,20 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.function.Function; import software.amazon.awssdk.services.lambda.model.ContextOptions; import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationAction; -import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.awssdk.services.lambda.model.OperationUpdate; -import software.amazon.lambda.durable.CompletionConfig; import software.amazon.lambda.durable.DurableContext; -import software.amazon.lambda.durable.MapConfig; -import software.amazon.lambda.durable.MapFunction; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.CompletionConfig; +import software.amazon.lambda.durable.config.MapConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; -import software.amazon.lambda.durable.model.MapError; import software.amazon.lambda.durable.model.MapResult; -import software.amazon.lambda.durable.model.MapResultItem; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.model.OperationSubType; import software.amazon.lambda.durable.serde.SerDes; -import software.amazon.lambda.durable.util.ExceptionHelper; /** * Executes a map operation: applies a function to each item in a collection concurrently, with each item running in its @@ -42,18 +36,16 @@ public class MapOperation extends ConcurrencyOperation> { private static final int LARGE_RESULT_THRESHOLD = 256 * 1024; private final List items; - private final MapFunction function; + private final DurableContext.MapFunction function; private final TypeToken itemResultType; private final SerDes serDes; - private final CompletionConfig completionConfig; private boolean replayFromPayload; private volatile MapResult cachedResult; - private ConcurrencyCompletionStatus completionStatus; public MapOperation( OperationIdentifier operationIdentifier, List items, - MapFunction function, + DurableContext.MapFunction function, TypeToken itemResultType, MapConfig config, DurableContextImpl durableContext) { @@ -62,29 +54,56 @@ public MapOperation( new TypeToken<>() {}, config.serDes(), durableContext, - config.maxConcurrency() != null ? config.maxConcurrency() : -1); + config.maxConcurrency(), + config.completionConfig().minSuccessful(), + getToleratedFailureCount(config.completionConfig(), items.size())); + if (config.completionConfig().minSuccessful() != null + && config.completionConfig().minSuccessful() > items.size()) { + throw new IllegalArgumentException("minSuccessful cannot be greater than total items: " + + config.completionConfig().minSuccessful() + " > " + items.size()); + } this.items = List.copyOf(items); this.function = function; this.itemResultType = itemResultType; this.serDes = config.serDes(); - this.completionConfig = config.completionConfig(); + + addAllItems(); } - @Override - protected ChildContextOperation createItem( - String operationId, - String name, - Function function, - TypeToken resultType, - SerDes serDes, - DurableContextImpl parentContext) { - return new ChildContextOperation<>( - OperationIdentifier.of(operationId, name, OperationType.CONTEXT, OperationSubType.MAP_ITERATION), - function, - resultType, - serDes, - parentContext, - this); + private void addAllItems() { + // Enqueue all items first, then start execution. This prevents early termination + // criteria (e.g., minSuccessful) from completing the operation mid-loop on replay, + // which would cause subsequent enqueue calls to fail with "completed operation". + var branchPrefix = getName() == null ? "map-iteration-" : getName() + "-iteration-"; + for (int i = 0; i < items.size(); i++) { + var index = i; + var item = items.get(i); + enqueueItem( + branchPrefix + i, + childCtx -> function.apply(item, index, childCtx), + itemResultType, + serDes, + OperationSubType.MAP_ITERATION); + } + } + + private static Integer getToleratedFailureCount(CompletionConfig completionConfig, int totalItems) { + if (completionConfig == null + || (completionConfig.toleratedFailureCount() == null + && completionConfig.toleratedFailurePercentage() == null)) { + // neither toleratedFailureCount nor toleratedFailurePercentage is specified. + return null; + } + int toleratedFailureCount = completionConfig.toleratedFailureCount() != null + ? completionConfig.toleratedFailureCount() + : Integer.MAX_VALUE; + + // convert percentage to count + int toleratedFailureCountFromPercentage = completionConfig.toleratedFailurePercentage() != null + ? (int) Math.floor(totalItems * completionConfig.toleratedFailurePercentage()) + : Integer.MAX_VALUE; + // minimum of two if both count and percentage is specified + return Math.min(toleratedFailureCount, toleratedFailureCountFromPercentage); } @Override @@ -92,7 +111,8 @@ protected void start() { sendOperationUpdateAsync(OperationUpdate.builder() .action(OperationAction.START) .subType(getSubType().getValue())); - addAllItems(); + + executeItems(); } @Override @@ -102,7 +122,7 @@ protected void replay(Operation existing) { if (existing.contextDetails() != null && Boolean.TRUE.equals(existing.contextDetails().replayChildren())) { // Large result: re-execute children to reconstruct MapResult - addAllItems(); + executeItems(); } else { // Small result: MapResult is in the payload, skip child replay replayFromPayload = true; @@ -112,87 +132,35 @@ protected void replay(Operation existing) { case STARTED -> { // Map was in progress when interrupted — re-create children without sending // another START (the backend rejects duplicate START for existing operations) - addAllItems(); + executeItems(); } default -> - terminateExecutionWithIllegalDurableOperationException( + throw terminateExecutionWithIllegalDurableOperationException( "Unexpected map operation status: " + existing.status()); } } - private void addAllItems() { - // Enqueue all items first, then start execution. This prevents early termination - // criteria (e.g., minSuccessful) from completing the operation mid-loop on replay, - // which would cause subsequent enqueue calls to fail with "completed operation". - for (int i = 0; i < items.size(); i++) { - var index = i; - var item = items.get(i); - enqueueItem( - "map-iteration-" + i, childCtx -> function.apply(item, index, childCtx), itemResultType, serDes); - } - startPendingItems(); - } - + @SuppressWarnings("unchecked") @Override protected void handleSuccess(ConcurrencyCompletionStatus concurrencyCompletionStatus) { - this.completionStatus = concurrencyCompletionStatus; - checkpointMapResult(); - } - - @Override - protected void handleFailure(ConcurrencyCompletionStatus concurrencyCompletionStatus) { - this.completionStatus = concurrencyCompletionStatus; - checkpointMapResult(); - } + var children = getBranches(); + var resultItems = new ArrayList>(Collections.nCopies(items.size(), null)); - @Override - protected void validateItemCount() { - if (completionConfig.minSuccessful() != null && completionConfig.minSuccessful() > getTotalItems()) { - throw new IllegalArgumentException("minSuccessful (" + completionConfig.minSuccessful() - + ") exceeds the number of items (" + getTotalItems() + ")"); - } - } - - /** - * Overrides the default completion logic from {@link ConcurrencyOperation} to support Map's - * {@link CompletionConfig} semantics. Unlike Parallel (where {@code minSuccessful == -1} means "all must succeed"), - * Map's default {@code allCompleted()} allows failures without early termination. - */ - @Override - protected ConcurrencyCompletionStatus canComplete() { - int succeeded = getSucceededCount(); - int failed = getFailedCount(); - int totalCompleted = succeeded + failed; - - // Check minSuccessful - if (completionConfig.minSuccessful() != null && succeeded >= completionConfig.minSuccessful()) { - return ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED; - } - - // Check toleratedFailureCount - if (completionConfig.toleratedFailureCount() != null && failed > completionConfig.toleratedFailureCount()) { - return ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; - } - - // Check toleratedFailurePercentage - if (completionConfig.toleratedFailurePercentage() != null - && totalCompleted > 0 - && ((double) failed / totalCompleted) > completionConfig.toleratedFailurePercentage()) { - return ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; - } - - // All items finished (no pending, no running) — complete with ALL_COMPLETED - if (isAllItemsFinished()) { - return ConcurrencyCompletionStatus.ALL_COMPLETED; + for (int i = 0; i < children.size(); i++) { + var branch = (ChildContextOperation) children.get(i); + if (!branch.isOperationCompleted()) { + resultItems.set(i, MapResult.MapResultItem.skipped()); + } else { + try { + resultItems.set(i, MapResult.MapResultItem.succeeded(branch.get())); + } catch (Exception e) { + resultItems.set(i, MapResult.MapResultItem.failed(MapResult.MapError.of(e))); + } + } } - return null; - } - - private void checkpointMapResult() { - var result = aggregateResults(); - this.cachedResult = result; - var serialized = serializeResult(result); + this.cachedResult = new MapResult<>(resultItems, concurrencyCompletionStatus); + var serialized = serializeResult(cachedResult); var serializedBytes = serialized.getBytes(java.nio.charset.StandardCharsets.UTF_8); if (serializedBytes.length < LARGE_RESULT_THRESHOLD) { @@ -221,43 +189,6 @@ public MapResult get() { } // First execution or large result replay: wait for children, then aggregate join(); - return cachedResult != null ? cachedResult : aggregateResults(); - } - - /** - * Aggregates results from completed branches into a {@code MapResult}. - * - *

    Called after all branches have completed. At this point every branch's {@code completionFuture} is already - * done, so {@code branch.get()} returns immediately without blocking. - */ - @SuppressWarnings("unchecked") - private MapResult aggregateResults() { - var children = getChildOperations(); - var resultItems = new ArrayList>(Collections.nCopies(items.size(), null)); - - for (int i = 0; i < children.size(); i++) { - var branch = (ChildContextOperation) children.get(i); - if (!branch.isOperationCompleted()) { - resultItems.set(i, MapResultItem.notStarted()); - continue; - } - try { - resultItems.set(i, MapResultItem.success(branch.get())); - } catch (Exception e) { - resultItems.set(i, MapResultItem.failure(buildMapError(e))); - } - } - - // Fill any remaining null slots (items beyond children size) with notStarted - for (int i = children.size(); i < items.size(); i++) { - resultItems.set(i, MapResultItem.notStarted()); - } - - return new MapResult<>(resultItems, completionStatus); - } - - private static MapError buildMapError(Exception e) { - return new MapError( - e.getClass().getName(), e.getMessage(), ExceptionHelper.serializeStackTrace(e.getStackTrace())); + return cachedResult; } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java index 13aaa5d55..c47d5d695 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java @@ -6,10 +6,14 @@ import software.amazon.awssdk.services.lambda.model.ContextOptions; import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationAction; -import software.amazon.awssdk.services.lambda.model.OperationType; +import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationUpdate; import software.amazon.lambda.durable.DurableContext; +import software.amazon.lambda.durable.DurableFuture; +import software.amazon.lambda.durable.ParallelDurableFuture; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.ParallelBranchConfig; +import software.amazon.lambda.durable.config.ParallelConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.ExecutionManager; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; @@ -39,43 +43,37 @@ * └── Branch N context (ChildContextOperation with PARALLEL_BRANCH) * */ -public class ParallelOperation extends ConcurrencyOperation { +public class ParallelOperation extends ConcurrencyOperation implements ParallelDurableFuture { - private final int minSuccessful; - private final int toleratedFailureCount; - private boolean skipCheckpoint = false; + // this field could be written and read in different threads + private volatile boolean skipCheckpoint = false; + private volatile ParallelResult cachedResult; public ParallelOperation( OperationIdentifier operationIdentifier, SerDes resultSerDes, DurableContextImpl durableContext, - int maxConcurrency, - int minSuccessful, - int toleratedFailureCount) { - super(operationIdentifier, new TypeToken() {}, resultSerDes, durableContext, maxConcurrency); - this.minSuccessful = minSuccessful; - this.toleratedFailureCount = toleratedFailureCount; - } - - @Override - protected ChildContextOperation createItem( - String operationId, - String name, - Function function, - TypeToken resultType, - SerDes serDes, - DurableContextImpl parentContext) { - return new ChildContextOperation<>( - OperationIdentifier.of(operationId, name, OperationType.CONTEXT, OperationSubType.PARALLEL_BRANCH), - function, - resultType, - serDes, - parentContext, - this); + ParallelConfig config) { + super( + operationIdentifier, + TypeToken.get(ParallelResult.class), + resultSerDes, + durableContext, + config.maxConcurrency(), + config.completionConfig().minSuccessful(), + config.completionConfig().toleratedFailureCount()); } @Override protected void handleSuccess(ConcurrencyCompletionStatus concurrencyCompletionStatus) { + var items = getBranches(); + int succeededCount = Math.toIntExact(items.stream() + .filter(item -> item.getOperation().status() == OperationStatus.SUCCEEDED) + .count()); + int failedCount = Math.toIntExact(items.stream() + .filter(item -> item.getOperation().status() != OperationStatus.SUCCEEDED) + .count()); + this.cachedResult = new ParallelResult(items.size(), succeededCount, failedCount, concurrencyCompletionStatus); if (skipCheckpoint) { // Do not send checkpoint during replay markAlreadyCompleted(); @@ -87,16 +85,13 @@ protected void handleSuccess(ConcurrencyCompletionStatus concurrencyCompletionSt .contextOptions(ContextOptions.builder().replayChildren(true).build())); } - @Override - protected void handleFailure(ConcurrencyCompletionStatus concurrencyCompletionStatus) { - handleSuccess(concurrencyCompletionStatus); - } - @Override protected void start() { sendOperationUpdateAsync(OperationUpdate.builder() .action(OperationAction.START) .subType(getSubType().getValue())); + + executeItems(); } @Override @@ -104,42 +99,27 @@ protected void replay(Operation existing) { // No-op: child branches handle their own replay via ChildContextOperation.replay(). // Set replaying=true so handleSuccess() skips re-checkpointing the already-completed parallel context. skipCheckpoint = ExecutionManager.isTerminalStatus(existing.status()); + executeItems(); } @Override public ParallelResult get() { join(); - return new ParallelResult(getTotalItems(), getSucceededCount(), getFailedCount(), getCompletionStatus()); + return cachedResult; } + /** Calls {@link #get()} if not already called. Guarantees that the context is closed. */ @Override - protected void validateItemCount() { - if (minSuccessful > getTotalItems()) { - throw new IllegalArgumentException("minSuccessful (" + minSuccessful - + ") exceeds the number of registered items (" + getTotalItems() + ")"); - } + public void close() { + join(); } - @Override - protected ConcurrencyCompletionStatus canComplete() { - int succeeded = getSucceededCount(); - int failed = getFailedCount(); - - // If we've met the minimum successful count, we're done - if (minSuccessful != -1 && succeeded >= minSuccessful) { - return ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED; - } - - // If we've exceeded the failure tolerance, we're done - if ((minSuccessful == -1 && failed > 0) || failed > toleratedFailureCount) { - return ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; - } - - // All items finished — complete - if (isAllItemsFinished()) { - return ConcurrencyCompletionStatus.ALL_COMPLETED; + public DurableFuture branch( + String name, TypeToken resultType, Function func, ParallelBranchConfig config) { + if (isJoined.get()) { + throw new IllegalStateException("Cannot add branches after join() has been called"); } - - return null; + var serDes = config.serDes() == null ? getContext().getDurableConfig().getSerDes() : config.serDes(); + return enqueueItem(name, func, resultType, serDes, OperationSubType.PARALLEL_BRANCH); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/SerializableDurableOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/SerializableDurableOperation.java new file mode 100644 index 000000000..c86d6c263 --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/SerializableDurableOperation.java @@ -0,0 +1,153 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.operation; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.lambda.model.ErrorObject; +import software.amazon.lambda.durable.DurableFuture; +import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.context.DurableContextImpl; +import software.amazon.lambda.durable.exception.IllegalDurableOperationException; +import software.amazon.lambda.durable.exception.SerDesException; +import software.amazon.lambda.durable.execution.ThreadType; +import software.amazon.lambda.durable.model.OperationIdentifier; +import software.amazon.lambda.durable.serde.SerDes; +import software.amazon.lambda.durable.util.ExceptionHelper; + +/** + * Base class for all durable operations (STEP, WAIT, etc.). + * + *

    Key methods: + * + *

      + *
    • {@code execute()} starts the operation (returns immediately) + *
    • {@code get()} blocks until complete and returns the result + *
    + * + *

    The separation allows: + * + *

      + *
    • Starting multiple async operations quickly + *
    • Blocking on results later when needed + *
    • Proper thread coordination via future + *
    + */ +public abstract class SerializableDurableOperation extends BaseDurableOperation implements DurableFuture { + private static final Logger logger = LoggerFactory.getLogger(SerializableDurableOperation.class); + + private final TypeToken resultTypeToken; + private final SerDes resultSerDes; + + /** + * Constructs a new durable operation. + * + * @param operationIdentifier the unique identifier for this operation + * @param resultTypeToken the type token for deserializing the result + * @param resultSerDes the serializer/deserializer for the result + * @param durableContext the parent context this operation belongs to + */ + protected SerializableDurableOperation( + OperationIdentifier operationIdentifier, + TypeToken resultTypeToken, + SerDes resultSerDes, + DurableContextImpl durableContext) { + super(operationIdentifier, durableContext); + this.resultTypeToken = resultTypeToken; + this.resultSerDes = resultSerDes; + } + + /** + * Checks if it's called from a Step. + * + * @throws IllegalDurableOperationException if it's in a step + */ + private void validateCurrentThreadType() { + ThreadType current = getCurrentThreadContext().threadType(); + if (current == ThreadType.STEP) { + var message = String.format( + "Nested %s operation is not supported on %s from within a %s execution.", + getType(), getName(), current); + // terminate execution and throw the exception + throw terminateExecutionWithIllegalDurableOperationException(message); + } + } + + /** + * Deserializes a result string into the operation's result type. + * + * @param result the serialized result string + * @return the deserialized result + * @throws SerDesException if deserialization fails + */ + protected T deserializeResult(String result) { + try { + return resultSerDes.deserialize(result, resultTypeToken); + } catch (SerDesException e) { + logger.warn( + "Failed to deserialize {} result for operation name '{}'. Ensure the result is properly encoded.", + getType(), + getName()); + throw e; + } + } + + /** + * Serializes the result to a string. + * + * @param result the result to serialize + * @return the serialized string + */ + protected String serializeResult(T result) { + return resultSerDes.serialize(result); + } + + /** + * Serializes a throwable into an {@link ErrorObject} for checkpointing. + * + * @param throwable the exception to serialize + * @return the serialized error object + */ + protected ErrorObject serializeException(Throwable throwable) { + return ExceptionHelper.buildErrorObject(throwable, resultSerDes); + } + + /** + * Deserializes an {@link ErrorObject} back into a throwable, reconstructing the original exception type and stack + * trace when possible. Falls back to null if the exception class is not found or deserialization fails. + * + * @param errorObject the serialized error object + * @return the reconstructed throwable, or null if reconstruction is not possible + */ + protected Throwable deserializeException(ErrorObject errorObject) { + Throwable original = null; + if (errorObject == null) { + return original; + } + var errorType = errorObject.errorType(); + var errorData = errorObject.errorData(); + + if (errorType == null) { + return original; + } + try { + + Class exceptionClass = Class.forName(errorType); + if (Throwable.class.isAssignableFrom(exceptionClass)) { + original = + resultSerDes.deserialize(errorData, TypeToken.get(exceptionClass.asSubclass(Throwable.class))); + + if (original != null) { + original.setStackTrace(ExceptionHelper.deserializeStackTrace(errorObject.stackTrace())); + } + } + } catch (ClassNotFoundException e) { + logger.warn("Cannot re-construct original exception type. Falling back to generic StepFailedException."); + } catch (SerDesException e) { + logger.warn("Cannot deserialize original exception data. Falling back to generic StepFailedException.", e); + } + return original; + } + + public abstract T get(); +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/StepOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/StepOperation.java index 4570f8637..7ff4b9ae0 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/StepOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/StepOperation.java @@ -5,7 +5,6 @@ import java.time.Duration; import java.time.Instant; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutorService; import java.util.function.Function; import software.amazon.awssdk.services.lambda.model.ErrorObject; import software.amazon.awssdk.services.lambda.model.Operation; @@ -13,16 +12,17 @@ import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationUpdate; import software.amazon.awssdk.services.lambda.model.StepOptions; -import software.amazon.lambda.durable.StepConfig; import software.amazon.lambda.durable.StepContext; -import software.amazon.lambda.durable.StepSemantics; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.StepConfig; +import software.amazon.lambda.durable.config.StepSemantics; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.DurableOperationException; import software.amazon.lambda.durable.exception.StepFailedException; import software.amazon.lambda.durable.exception.StepInterruptedException; import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException; import software.amazon.lambda.durable.execution.SuspendExecutionException; +import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.util.ExceptionHelper; @@ -34,12 +34,11 @@ * * @param the result type of the step function */ -public class StepOperation extends BaseDurableOperation { +public class StepOperation extends SerializableDurableOperation { private static final Integer FIRST_ATTEMPT = 0; private final Function function; private final StepConfig config; - private final ExecutorService userExecutor; public StepOperation( OperationIdentifier operationIdentifier, @@ -51,7 +50,6 @@ public StepOperation( this.function = function; this.config = config; - this.userExecutor = durableContext.getDurableConfig().getExecutorService(); } /** Starts the operation. */ @@ -82,7 +80,8 @@ protected void replay(Operation existing) { // Execute with current attempt case READY -> executeStepLogic(attempt); default -> - terminateExecutionWithIllegalDurableOperationException("Unexpected step status: " + existing.status()); + throw terminateExecutionWithIllegalDurableOperationException( + "Unexpected step status: " + existing.status()); } } @@ -96,10 +95,6 @@ private CompletableFuture pollReadyAndExecuteStepLogic(Operation existing, } private void executeStepLogic(int attempt) { - // Register step thread as active BEFORE executor runs (prevents suspension when handler deregisters). - // The thread local ThreadContext is set inside the executor since that's where the step actually runs - registerActiveThread(getOperationId()); - Runnable userHandler = () -> { // use a try-with-resources to // - add thread id/type to thread local when the step starts @@ -119,7 +114,7 @@ private void executeStepLogic(int attempt) { }; // Execute user provided step code in user-configured executor - CompletableFuture.runAsync(userHandler, userExecutor); + runUserHandler(userHandler, getOperationId(), ThreadType.STEP); } private void checkpointStarted() { @@ -155,7 +150,7 @@ private void handleStepFailure(Throwable exception, int attempt) { } if (exception instanceof UnrecoverableDurableExecutionException unrecoverableDurableExecutionException) { // terminate the execution and throw the exception if it's not recoverable - terminateExecution(unrecoverableDurableExecutionException); + throw terminateExecution(unrecoverableDurableExecutionException); } final ErrorObject errorObject; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java index 700fe5f19..ac3258077 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java @@ -14,12 +14,13 @@ import software.amazon.awssdk.services.lambda.model.StepOptions; import software.amazon.lambda.durable.StepContext; import software.amazon.lambda.durable.TypeToken; -import software.amazon.lambda.durable.WaitForConditionConfig; +import software.amazon.lambda.durable.config.WaitForConditionConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.DurableOperationException; import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException; import software.amazon.lambda.durable.exception.WaitForConditionException; import software.amazon.lambda.durable.execution.SuspendExecutionException; +import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.model.OperationSubType; import software.amazon.lambda.durable.model.WaitForConditionResult; @@ -34,7 +35,7 @@ * * @param the type of state being polled */ -public class WaitForConditionOperation extends BaseDurableOperation { +public class WaitForConditionOperation extends SerializableDurableOperation { private final BiFunction> checkFunc; private final WaitForConditionConfig config; @@ -73,7 +74,7 @@ protected void replay(Operation existing) { case PENDING -> pollReadyAndResumeCheckLoop(existing); // Check if pending retry case STARTED, READY -> resumeCheckLoop(existing); default -> - terminateExecutionWithIllegalDurableOperationException( + throw terminateExecutionWithIllegalDurableOperationException( "Unexpected waitForCondition status: " + existing.status()); } } @@ -121,59 +122,56 @@ private CompletableFuture pollReadyAndResumeCheckLoop(Operation existing) } private void executeCheckLogic(T currentState, int attempt) { - // Register thread as active BEFORE executor runs - registerActiveThread(getOperationId()); - - CompletableFuture.runAsync( - () -> { - try (var stepContext = getContext().createStepContext(getOperationId(), getName(), attempt)) { - try { - // Checkpoint START if not already started - var existing = getOperation(); - if (existing == null || existing.status() != OperationStatus.STARTED) { - var startUpdate = OperationUpdate.builder().action(OperationAction.START); - sendOperationUpdateAsync(startUpdate); - } - - // Execute check function in user executor - WaitForConditionResult result = checkFunc.apply(currentState, stepContext); - - // Serialize/deserialize round-trip on the value to ensure state is checkpoint-safe - var serializedState = serializeResult(result.value()); - T deserializedValue = deserializeResult(serializedState); - - if (result.isDone()) { - // Condition met — checkpoint SUCCEED - var successUpdate = OperationUpdate.builder() - .action(OperationAction.SUCCEED) - .payload(serializedState); - sendOperationUpdate(successUpdate); - } else { - // Compute delay from strategy - Duration delay = config.waitStrategy().evaluate(deserializedValue, attempt); - - // Checkpoint RETRY with delay - var retryUpdate = OperationUpdate.builder() - .action(OperationAction.RETRY) - .payload(serializedState) - .stepOptions(StepOptions.builder() - .nextAttemptDelaySeconds(Math.toIntExact(delay.toSeconds())) - .build()); - sendOperationUpdate(retryUpdate); - - // Poll for READY, then continue the loop - pollForOperationUpdates() - .thenCompose(op -> op.status() == OperationStatus.READY - ? CompletableFuture.completedFuture(op) - : pollForOperationUpdates()) - .thenRun(() -> executeCheckLogic(deserializedValue, attempt + 1)); - } - } catch (Throwable e) { - handleCheckFailure(e); - } + Runnable userHandler = () -> { + try (var stepContext = getContext().createStepContext(getOperationId(), getName(), attempt)) { + try { + // Checkpoint START if not already started + var existing = getOperation(); + if (existing == null || existing.status() != OperationStatus.STARTED) { + var startUpdate = OperationUpdate.builder().action(OperationAction.START); + sendOperationUpdateAsync(startUpdate); } - }, - userExecutor); + + // Execute check function in user executor + WaitForConditionResult result = checkFunc.apply(currentState, stepContext); + + // Serialize/deserialize round-trip on the value to ensure state is checkpoint-safe + var serializedState = serializeResult(result.value()); + T deserializedValue = deserializeResult(serializedState); + + if (result.isDone()) { + // Condition met — checkpoint SUCCEED + var successUpdate = OperationUpdate.builder() + .action(OperationAction.SUCCEED) + .payload(serializedState); + sendOperationUpdate(successUpdate); + } else { + // Compute delay from strategy + Duration delay = config.waitStrategy().evaluate(deserializedValue, attempt); + + // Checkpoint RETRY with delay + var retryUpdate = OperationUpdate.builder() + .action(OperationAction.RETRY) + .payload(serializedState) + .stepOptions(StepOptions.builder() + .nextAttemptDelaySeconds(Math.toIntExact(delay.toSeconds())) + .build()); + sendOperationUpdate(retryUpdate); + + // Poll for READY, then continue the loop + pollForOperationUpdates() + .thenCompose(op -> op.status() == OperationStatus.READY + ? CompletableFuture.completedFuture(op) + : pollForOperationUpdates()) + .thenRun(() -> executeCheckLogic(deserializedValue, attempt + 1)); + } + } catch (Throwable e) { + handleCheckFailure(e); + } + } + }; + + runUserHandler(userHandler, getOperationId(), ThreadType.STEP); } private void handleCheckFailure(Throwable exception) { @@ -182,7 +180,7 @@ private void handleCheckFailure(Throwable exception) { throw suspendExecutionException; } if (exception instanceof UnrecoverableDurableExecutionException unrecoverable) { - terminateExecution(unrecoverable); + throw terminateExecution(unrecoverable); } final var errorObject = (exception instanceof DurableOperationException durableOpEx) diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitOperation.java index e9493ba31..f4e3f4b57 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitOperation.java @@ -11,11 +11,9 @@ import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationUpdate; import software.amazon.awssdk.services.lambda.model.WaitOptions; -import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.model.OperationIdentifier; -import software.amazon.lambda.durable.serde.NoopSerDes; -import software.amazon.lambda.durable.serde.SerDes; /** * Durable operation that suspends execution for a specified duration without consuming compute. @@ -23,16 +21,15 @@ *

    The wait is checkpointed and the Lambda is suspended. On re-invocation after the wait period, execution resumes * from where it left off. */ -public class WaitOperation extends BaseDurableOperation { +public class WaitOperation extends BaseDurableOperation implements DurableFuture { private static final Logger logger = LoggerFactory.getLogger(WaitOperation.class); - private static final SerDes NOOP_SER_DES = new NoopSerDes(); private final Duration duration; public WaitOperation( OperationIdentifier operationIdentifier, Duration duration, DurableContextImpl durableContext) { - super(operationIdentifier, TypeToken.get(Void.class), NOOP_SER_DES, durableContext); + super(operationIdentifier, durableContext); this.duration = duration; } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/serde/NoopSerDes.java b/sdk/src/main/java/software/amazon/lambda/durable/serde/NoopSerDes.java deleted file mode 100644 index 245fed925..000000000 --- a/sdk/src/main/java/software/amazon/lambda/durable/serde/NoopSerDes.java +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable.serde; - -import software.amazon.lambda.durable.TypeToken; - -/** A {@link SerDes} implementation that does nothing. Used as a placeholder when no serialization is required. */ -public class NoopSerDes implements SerDes { - @Override - public String serialize(Object value) { - return ""; - } - - @Override - public T deserialize(String data, TypeToken typeToken) { - return null; - } -} diff --git a/sdk/src/test/java/software/amazon/lambda/durable/DurableContextTest.java b/sdk/src/test/java/software/amazon/lambda/durable/DurableContextTest.java index 269eacf98..1d50cc1c8 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/DurableContextTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/DurableContextTest.java @@ -9,6 +9,7 @@ import java.util.List; import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.lambda.model.*; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.ExecutionManager; import software.amazon.lambda.durable.execution.SuspendExecutionException; @@ -47,7 +48,9 @@ private DurableContext createTestContext(List initialOperations) { new DurableExecutionInput(EXECUTION_ARN, "test-token", initialExecutionState), DurableConfig.builder().withDurableExecutionClient(client).build()); var root = DurableContextImpl.createRootContext( - executionManager, DurableConfig.builder().build(), null); + executionManager, + software.amazon.lambda.durable.DurableConfig.builder().build(), + null); executionManager.registerActiveThread(null); executionManager.setCurrentThreadContext(new ThreadContext(null, ThreadType.CONTEXT)); return root; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/DurableFutureTest.java b/sdk/src/test/java/software/amazon/lambda/durable/DurableFutureTest.java index 7292269e4..91338f9b1 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/DurableFutureTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/DurableFutureTest.java @@ -7,7 +7,7 @@ import java.util.List; import org.junit.jupiter.api.Test; -import software.amazon.lambda.durable.operation.BaseDurableOperation; +import software.amazon.lambda.durable.operation.SerializableDurableOperation; class DurableFutureTest { @@ -63,15 +63,15 @@ void allOfSingleFutureReturnsSingleResult() { void allOfPropagatesException() { var op1 = mockOperation("first"); @SuppressWarnings("unchecked") - BaseDurableOperation op2 = mock(BaseDurableOperation.class); + SerializableDurableOperation op2 = mock(SerializableDurableOperation.class); when(op2.get()).thenThrow(new RuntimeException("Step failed")); assertThrows(RuntimeException.class, () -> DurableFuture.allOf(op1, op2)); } @SuppressWarnings("unchecked") - private BaseDurableOperation mockOperation(T result) { - BaseDurableOperation op = mock(BaseDurableOperation.class); + private SerializableDurableOperation mockOperation(T result) { + SerializableDurableOperation op = mock(SerializableDurableOperation.class); when(op.get()).thenReturn(result); return op; } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/DurationValidationIntegrationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/DurationValidationIntegrationTest.java index 55a72a607..2deda7107 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/DurationValidationIntegrationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/DurationValidationIntegrationTest.java @@ -6,6 +6,7 @@ import java.time.Duration; import org.junit.jupiter.api.Test; +import software.amazon.lambda.durable.config.CallbackConfig; class DurationValidationIntegrationTest { diff --git a/sdk/src/test/java/software/amazon/lambda/durable/MapFunctionTest.java b/sdk/src/test/java/software/amazon/lambda/durable/MapFunctionTest.java index 9bc98c1a1..4cf6c01bc 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/MapFunctionTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/MapFunctionTest.java @@ -10,12 +10,12 @@ class MapFunctionTest { @Test void isFunctionalInterface() { - assertTrue(MapFunction.class.isAnnotationPresent(FunctionalInterface.class)); + assertTrue(DurableContext.MapFunction.class.isAnnotationPresent(FunctionalInterface.class)); } @Test void canBeUsedAsLambda() { - MapFunction fn = (item, index, ctx) -> item.toUpperCase(); + DurableContext.MapFunction fn = (item, index, ctx) -> item.toUpperCase(); var result = fn.apply("hello", 0, null); @@ -24,7 +24,7 @@ void canBeUsedAsLambda() { @Test void receivesCorrectIndex() { - MapFunction fn = (item, index, ctx) -> index; + DurableContext.MapFunction fn = (item, index, ctx) -> index; assertEquals(0, fn.apply("a", 0, null)); assertEquals(5, fn.apply("b", 5, null)); @@ -32,7 +32,7 @@ void receivesCorrectIndex() { @Test void canThrowRuntimeException() { - MapFunction fn = (item, index, ctx) -> { + DurableContext.MapFunction fn = (item, index, ctx) -> { throw new IllegalArgumentException("bad input"); }; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/ParallelConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/ParallelConfigTest.java deleted file mode 100644 index 6dc1aa0e8..000000000 --- a/sdk/src/test/java/software/amazon/lambda/durable/ParallelConfigTest.java +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; - -import org.junit.jupiter.api.Test; - -class ParallelConfigTest { - - @Test - void defaultValues() { - var config = ParallelConfig.builder().build(); - - assertEquals(-1, config.maxConcurrency()); - assertEquals(-1, config.minSuccessful()); - assertEquals(0, config.toleratedFailureCount()); - } - - @Test - void builderRoundTrip() { - var config = ParallelConfig.builder() - .maxConcurrency(4) - .minSuccessful(2) - .toleratedFailureCount(3) - .build(); - - assertEquals(4, config.maxConcurrency()); - assertEquals(2, config.minSuccessful()); - assertEquals(3, config.toleratedFailureCount()); - } - - @Test - void maxConcurrencyOfOne() { - var config = ParallelConfig.builder().maxConcurrency(1).build(); - - assertEquals(1, config.maxConcurrency()); - } - - @Test - void minSuccessfulOfZero() { - var config = ParallelConfig.builder().minSuccessful(0).build(); - - assertEquals(0, config.minSuccessful()); - } - - @Test - void unlimitedConcurrency() { - var config = ParallelConfig.builder().maxConcurrency(-1).build(); - - assertEquals(-1, config.maxConcurrency()); - } - - @Test - void maxConcurrencyZeroThrows() { - var builder = ParallelConfig.builder().maxConcurrency(0); - assertThrows(IllegalArgumentException.class, builder::build); - } - - @Test - void maxConcurrencyNegativeTwoThrows() { - var builder = ParallelConfig.builder().maxConcurrency(-2); - assertThrows(IllegalArgumentException.class, builder::build); - } - - @Test - void minSuccessfulNegativeTwoThrows() { - var builder = ParallelConfig.builder().minSuccessful(-2); - assertThrows(IllegalArgumentException.class, builder::build); - } - - @Test - void toleratedFailureCountNegativeThrows() { - var builder = ParallelConfig.builder().toleratedFailureCount(-1); - assertThrows(IllegalArgumentException.class, builder::build); - } -} diff --git a/sdk/src/test/java/software/amazon/lambda/durable/TypeTokenTest.java b/sdk/src/test/java/software/amazon/lambda/durable/TypeTokenTest.java index fe5ebe98f..c3029ee6a 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/TypeTokenTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/TypeTokenTest.java @@ -17,7 +17,7 @@ void testSimpleGenericType() { var token = new TypeToken>() {}; Type type = token.getType(); - assertTrue(type instanceof ParameterizedType); + assertInstanceOf(ParameterizedType.class, type); ParameterizedType paramType = (ParameterizedType) type; assertEquals(List.class, paramType.getRawType()); assertEquals(String.class, paramType.getActualTypeArguments()[0]); @@ -28,13 +28,13 @@ void testNestedGenericType() { var token = new TypeToken>>() {}; Type type = token.getType(); - assertTrue(type instanceof ParameterizedType); + assertInstanceOf(ParameterizedType.class, type); ParameterizedType paramType = (ParameterizedType) type; assertEquals(Map.class, paramType.getRawType()); assertEquals(String.class, paramType.getActualTypeArguments()[0]); Type valueType = paramType.getActualTypeArguments()[1]; - assertTrue(valueType instanceof ParameterizedType); + assertInstanceOf(ParameterizedType.class, valueType); ParameterizedType valueParamType = (ParameterizedType) valueType; assertEquals(List.class, valueParamType.getRawType()); assertEquals(Integer.class, valueParamType.getActualTypeArguments()[0]); diff --git a/sdk/src/test/java/software/amazon/lambda/durable/CallbackConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/config/CallbackConfigTest.java similarity index 97% rename from sdk/src/test/java/software/amazon/lambda/durable/CallbackConfigTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/config/CallbackConfigTest.java index 5f5853f1a..45798109d 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/CallbackConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/config/CallbackConfigTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/CompletionConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/config/CompletionConfigTest.java similarity index 98% rename from sdk/src/test/java/software/amazon/lambda/durable/CompletionConfigTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/config/CompletionConfigTest.java index 710bf9b1f..b9dbe7ecc 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/CompletionConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/config/CompletionConfigTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import static org.junit.jupiter.api.Assertions.*; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/MapConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/config/MapConfigTest.java similarity index 95% rename from sdk/src/test/java/software/amazon/lambda/durable/MapConfigTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/config/MapConfigTest.java index 11c567d8f..fc8962de3 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/MapConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/config/MapConfigTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import static org.junit.jupiter.api.Assertions.*; @@ -13,7 +13,7 @@ class MapConfigTest { void defaultBuilder_hasNullMaxConcurrency() { var config = MapConfig.builder().build(); - assertNull(config.maxConcurrency()); + assertEquals(Integer.MAX_VALUE, config.maxConcurrency()); } @Test @@ -121,6 +121,6 @@ void builderWithNegativeMaxConcurrency_shouldThrow() { @Test void builderWithNullMaxConcurrency_shouldPass() { var config = MapConfig.builder().maxConcurrency(null).build(); - assertNull(config.maxConcurrency()); + assertEquals(Integer.MAX_VALUE, config.maxConcurrency()); } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/config/ParallelConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/config/ParallelConfigTest.java new file mode 100644 index 000000000..923a9d221 --- /dev/null +++ b/sdk/src/test/java/software/amazon/lambda/durable/config/ParallelConfigTest.java @@ -0,0 +1,48 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.config; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.junit.jupiter.api.Test; + +class ParallelConfigTest { + + @Test + void defaultValues() { + var config = ParallelConfig.builder().build(); + + assertEquals(Integer.MAX_VALUE, config.maxConcurrency()); + assertEquals(CompletionConfig.allCompleted(), config.completionConfig()); + } + + @Test + void builderRoundTrip() { + CompletionConfig completionConfig = CompletionConfig.allSuccessful(); + var config = ParallelConfig.builder() + .maxConcurrency(4) + .completionConfig(completionConfig) + .build(); + + assertEquals(4, config.maxConcurrency()); + assertEquals(completionConfig, config.completionConfig()); + } + + @Test + void maxConcurrencyOfOne() { + var config = ParallelConfig.builder().maxConcurrency(1).build(); + + assertEquals(1, config.maxConcurrency()); + } + + @Test + void invalidConcurrency() { + assertThrows( + IllegalArgumentException.class, + () -> ParallelConfig.builder().maxConcurrency(-1).build()); + assertThrows( + IllegalArgumentException.class, + () -> ParallelConfig.builder().maxConcurrency(0).build()); + } +} diff --git a/sdk/src/test/java/software/amazon/lambda/durable/StepConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/config/StepConfigTest.java similarity index 98% rename from sdk/src/test/java/software/amazon/lambda/durable/StepConfigTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/config/StepConfigTest.java index ec9437c03..033e0b4d5 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/StepConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/config/StepConfigTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/WaitForCallbackConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/config/WaitForCallbackConfigTest.java similarity index 99% rename from sdk/src/test/java/software/amazon/lambda/durable/WaitForCallbackConfigTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/config/WaitForCallbackConfigTest.java index 2ca85e143..6f12f5cd8 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/WaitForCallbackConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/config/WaitForCallbackConfigTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/WaitForConditionConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/config/WaitForConditionConfigTest.java similarity index 98% rename from sdk/src/test/java/software/amazon/lambda/durable/WaitForConditionConfigTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/config/WaitForConditionConfigTest.java index 82a431874..eacb1f704 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/WaitForConditionConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/config/WaitForConditionConfigTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.config; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/context/BaseContextImplTest.java b/sdk/src/test/java/software/amazon/lambda/durable/context/BaseContextImplTest.java index c0d321ed5..76b59e663 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/context/BaseContextImplTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/context/BaseContextImplTest.java @@ -15,8 +15,6 @@ import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.TestUtils; import software.amazon.lambda.durable.execution.ExecutionManager; -import software.amazon.lambda.durable.execution.ThreadContext; -import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.DurableExecutionInput; class BaseContextImplTest { @@ -59,11 +57,6 @@ void defaultConstructor_setsCurrentThreadContext() { // Creating a root context with the default constructor should set the thread context DurableContextImpl.createRootContext( executionManager, DurableConfig.builder().build(), null); - - var threadContext = executionManager.getCurrentThreadContext(); - assertNotNull(threadContext); - assertEquals(ThreadType.CONTEXT, threadContext.threadType()); - assertNull(threadContext.threadId()); } @Test @@ -73,51 +66,8 @@ void constructorWithSetCurrentThreadContextTrue_setsCurrentThreadContext() { // createRootContext sets thread context to root (threadId=null) var rootContext = DurableContextImpl.createRootContext( executionManager, DurableConfig.builder().build(), null); - assertEquals( - ThreadType.CONTEXT, executionManager.getCurrentThreadContext().threadType()); - assertNull(executionManager.getCurrentThreadContext().threadId()); // createChildContext (setCurrentThreadContext=true) should overwrite with child's context rootContext.createChildContext("child-id", "child-name"); - - var threadContext = executionManager.getCurrentThreadContext(); - assertNotNull(threadContext); - assertEquals(ThreadType.CONTEXT, threadContext.threadType()); - assertEquals("child-id", threadContext.threadId()); - } - - @Test - void constructorWithSetCurrentThreadContextFalse_doesNotOverwriteThreadContext() { - var executionManager = createExecutionManager(); - - // Create root context first (it will set thread context to null/root) - var rootContext = DurableContextImpl.createRootContext( - executionManager, DurableConfig.builder().build(), null); - - // Now set a sentinel — simulating a caller thread that already has context established - var sentinel = new ThreadContext("original-context", ThreadType.CONTEXT); - executionManager.setCurrentThreadContext(sentinel); - - // createChildContextWithoutSettingThreadContext should NOT overwrite the sentinel - rootContext.createChildContextWithoutSettingThreadContext("child-id", "child-name"); - - // Thread context should still be the sentinel, not the child's context - var threadContext = executionManager.getCurrentThreadContext(); - assertNotNull(threadContext); - assertEquals("original-context", threadContext.threadId()); - } - - @Test - void createChildContextWithoutSettingThreadContext_returnsValidChildContext() { - var executionManager = createExecutionManager(); - executionManager.setCurrentThreadContext(new ThreadContext(null, ThreadType.CONTEXT)); - var rootContext = DurableContextImpl.createRootContext( - executionManager, DurableConfig.builder().build(), null); - - var childContext = rootContext.createChildContextWithoutSettingThreadContext("child-id", "child-name"); - - assertNotNull(childContext); - assertEquals("child-id", childContext.getContextId()); - assertEquals("child-name", childContext.getContextName()); } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/context/DurableContextImplTest.java b/sdk/src/test/java/software/amazon/lambda/durable/context/DurableContextImplTest.java index e3a5f732a..2d6f6c4e3 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/context/DurableContextImplTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/context/DurableContextImplTest.java @@ -7,7 +7,6 @@ import java.util.ArrayList; import java.util.List; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.lambda.model.CheckpointUpdatedExecutionState; import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationStatus; @@ -50,44 +49,4 @@ void setUp() { rootContext = DurableContextImpl.createRootContext( executionManager, DurableConfig.builder().build(), null); } - - @Test - void createChildContext_setsThreadContextToChild() { - rootContext.createChildContext("child-1", "my-child"); - - var threadContext = executionManager.getCurrentThreadContext(); - assertNotNull(threadContext); - assertEquals("child-1", threadContext.threadId()); - assertEquals(ThreadType.CONTEXT, threadContext.threadType()); - } - - @Test - void createChildContextWithoutSettingThreadContext_preservesCallerThreadContext() { - var callerContext = new ThreadContext("caller-thread", ThreadType.CONTEXT); - executionManager.setCurrentThreadContext(callerContext); - - rootContext.createChildContextWithoutSettingThreadContext("child-1", "my-child"); - - // Thread context must remain unchanged - var threadContext = executionManager.getCurrentThreadContext(); - assertEquals("caller-thread", threadContext.threadId()); - } - - @Test - void createChildContextWithoutSettingThreadContext_returnsCorrectChildMetadata() { - var child = rootContext.createChildContextWithoutSettingThreadContext("child-42", "child-name"); - - assertEquals("child-42", child.getContextId()); - assertEquals("child-name", child.getContextName()); - } - - @Test - void createChildContextWithoutSettingThreadContext_whenNoThreadContextSet_leavesItNull() { - // Clear any existing thread context - executionManager.setCurrentThreadContext(null); - - rootContext.createChildContextWithoutSettingThreadContext("child-1", "my-child"); - - assertNull(executionManager.getCurrentThreadContext()); - } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/DurableExecutionTest.java b/sdk/src/test/java/software/amazon/lambda/durable/execution/DurableExecutionTest.java similarity index 98% rename from sdk/src/test/java/software/amazon/lambda/durable/DurableExecutionTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/execution/DurableExecutionTest.java index 3b966cdbd..a4ad86d3b 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/DurableExecutionTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/execution/DurableExecutionTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.execution; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -19,7 +19,8 @@ import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.awssdk.services.lambda.model.StepDetails; -import software.amazon.lambda.durable.execution.DurableExecutor; +import software.amazon.lambda.durable.DurableConfig; +import software.amazon.lambda.durable.TestUtils; import software.amazon.lambda.durable.model.DurableExecutionInput; import software.amazon.lambda.durable.model.ExecutionStatus; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/DurableExecutionWrapperTest.java b/sdk/src/test/java/software/amazon/lambda/durable/execution/DurableExecutionWrapperTest.java similarity index 96% rename from sdk/src/test/java/software/amazon/lambda/durable/DurableExecutionWrapperTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/execution/DurableExecutionWrapperTest.java index ac9bfdd7a..a64a19232 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/DurableExecutionWrapperTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/execution/DurableExecutionWrapperTest.java @@ -1,6 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable; +package software.amazon.lambda.durable.execution; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -14,8 +14,10 @@ import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; +import software.amazon.lambda.durable.DurableConfig; +import software.amazon.lambda.durable.DurableContext; +import software.amazon.lambda.durable.TestUtils; import software.amazon.lambda.durable.client.DurableExecutionClient; -import software.amazon.lambda.durable.execution.DurableExecutor; import software.amazon.lambda.durable.model.DurableExecutionInput; import software.amazon.lambda.durable.model.DurableExecutionOutput; import software.amazon.lambda.durable.model.ExecutionStatus; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/model/MapResultTest.java b/sdk/src/test/java/software/amazon/lambda/durable/model/MapResultTest.java index 09d97e6d2..ea8e49665 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/model/MapResultTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/model/MapResultTest.java @@ -9,8 +9,8 @@ class MapResultTest { - private static MapError testError(String message) { - return new MapError("java.lang.RuntimeException", message, null); + private static MapResult.MapError testError(String message) { + return new MapResult.MapError("java.lang.RuntimeException", message, null); } @Test @@ -28,7 +28,7 @@ void empty_returnsZeroSizeResult() { @Test void allSucceeded_trueWhenNoErrors() { var result = new MapResult<>( - List.of(MapResultItem.success("a"), MapResultItem.success("b")), + List.of(MapResult.MapResultItem.succeeded("a"), MapResult.MapResultItem.succeeded("b")), ConcurrencyCompletionStatus.ALL_COMPLETED); assertTrue(result.allSucceeded()); @@ -43,7 +43,7 @@ void allSucceeded_trueWhenNoErrors() { void allSucceeded_falseWhenAnyError() { var error = testError("fail"); var result = new MapResult<>( - List.of(MapResultItem.success("a"), MapResultItem.failure(error)), + List.of(MapResult.MapResultItem.succeeded("a"), MapResult.MapResultItem.failed(error)), ConcurrencyCompletionStatus.ALL_COMPLETED); assertFalse(result.allSucceeded()); @@ -53,7 +53,7 @@ void allSucceeded_falseWhenAnyError() { void getResult_returnsNullForFailedItem() { var error = testError("fail"); var result = new MapResult<>( - List.of(MapResultItem.success("a"), MapResultItem.failure(error)), + List.of(MapResult.MapResultItem.succeeded("a"), MapResult.MapResultItem.failed(error)), ConcurrencyCompletionStatus.ALL_COMPLETED); assertEquals("a", result.getResult(0)); @@ -64,7 +64,7 @@ void getResult_returnsNullForFailedItem() { void getError_returnsNullForSucceededItem() { var error = testError("fail"); var result = new MapResult<>( - List.of(MapResultItem.success("a"), MapResultItem.failure(error)), + List.of(MapResult.MapResultItem.succeeded("a"), MapResult.MapResultItem.failed(error)), ConcurrencyCompletionStatus.ALL_COMPLETED); assertNull(result.getError(0)); @@ -75,9 +75,9 @@ void getError_returnsNullForSucceededItem() { void succeeded_filtersNullResults() { var result = new MapResult<>( List.of( - MapResultItem.success("a"), - MapResultItem.failure(testError("fail")), - MapResultItem.success("c")), + MapResult.MapResultItem.succeeded("a"), + MapResult.MapResultItem.failed(testError("fail")), + MapResult.MapResultItem.succeeded("c")), ConcurrencyCompletionStatus.ALL_COMPLETED); assertEquals(List.of("a", "c"), result.succeeded()); @@ -87,7 +87,10 @@ void succeeded_filtersNullResults() { void failed_filtersNullErrors() { var error = testError("fail"); var result = new MapResult<>( - List.of(MapResultItem.success("a"), MapResultItem.failure(error), MapResultItem.success("c")), + List.of( + MapResult.MapResultItem.succeeded("a"), + MapResult.MapResultItem.failed(error), + MapResult.MapResultItem.succeeded("c")), ConcurrencyCompletionStatus.ALL_COMPLETED); var failures = result.failed(); @@ -98,29 +101,33 @@ void failed_filtersNullErrors() { @Test void completionReason_preserved() { var result = new MapResult<>( - List.of(MapResultItem.success("a")), ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED); + List.of(MapResult.MapResultItem.succeeded("a")), ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED); assertEquals(ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED, result.completionReason()); } @Test void items_returnsUnmodifiableList() { - var result = new MapResult<>(List.of(MapResultItem.success("a")), ConcurrencyCompletionStatus.ALL_COMPLETED); + var result = new MapResult<>( + List.of(MapResult.MapResultItem.succeeded("a")), ConcurrencyCompletionStatus.ALL_COMPLETED); - assertThrows(UnsupportedOperationException.class, () -> result.items().add(MapResultItem.success("b"))); + assertThrows( + UnsupportedOperationException.class, () -> result.items().add(MapResult.MapResultItem.succeeded("b"))); } @Test void getItem_returnsMapResultItem() { var result = new MapResult<>( - List.of(MapResultItem.success("a"), MapResultItem.failure(testError("fail"))), + List.of( + MapResult.MapResultItem.succeeded("a"), + MapResult.MapResultItem.failed(testError("fail"))), ConcurrencyCompletionStatus.ALL_COMPLETED); - assertEquals(MapResultItem.Status.SUCCEEDED, result.getItem(0).status()); + assertEquals(MapResult.MapResultItem.Status.SUCCEEDED, result.getItem(0).status()); assertEquals("a", result.getItem(0).result()); assertNull(result.getItem(0).error()); - assertEquals(MapResultItem.Status.FAILED, result.getItem(1).status()); + assertEquals(MapResult.MapResultItem.Status.FAILED, result.getItem(1).status()); assertNull(result.getItem(1).result()); assertNotNull(result.getItem(1).error()); } @@ -128,10 +135,10 @@ void getItem_returnsMapResultItem() { @Test void notStartedItems_haveNotStartedStatusAndNullResultAndError() { var result = new MapResult<>( - List.of(MapResultItem.success("a"), MapResultItem.notStarted()), + List.of(MapResult.MapResultItem.succeeded("a"), MapResult.MapResultItem.skipped()), ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED); - assertEquals(MapResultItem.Status.NOT_STARTED, result.getItem(1).status()); + assertEquals(MapResult.MapResultItem.Status.SKIPPED, result.getItem(1).status()); assertNull(result.getResult(1)); assertNull(result.getError(1)); } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/model/ParallelResultTest.java b/sdk/src/test/java/software/amazon/lambda/durable/model/ParallelResultTest.java index c5cfeebc8..328273111 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/model/ParallelResultTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/model/ParallelResultTest.java @@ -12,9 +12,9 @@ class ParallelResultTest { void allBranchesSucceed_countsAreCorrect() { var result = new ParallelResult(3, 3, 0, ConcurrencyCompletionStatus.ALL_COMPLETED); - assertEquals(3, result.getTotalBranches()); - assertEquals(3, result.getSucceededBranches()); - assertEquals(0, result.getFailedBranches()); - assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.getCompletionStatus()); + assertEquals(3, result.size()); + assertEquals(3, result.succeeded()); + assertEquals(0, result.failed()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionStatus()); } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/CallbackOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/CallbackOperationTest.java index 8f9f8775a..ebe2d7e99 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/CallbackOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/CallbackOperationTest.java @@ -13,10 +13,10 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.lambda.model.*; -import software.amazon.lambda.durable.CallbackConfig; import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.TestUtils; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.CallbackConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.CallbackFailedException; import software.amazon.lambda.durable.exception.CallbackTimeoutException; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java index 0e6158ed1..92be2e3ef 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java @@ -20,6 +20,7 @@ import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.RunInChildContextConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.ChildContextFailedException; import software.amazon.lambda.durable.exception.NonDeterministicExecutionException; @@ -58,13 +59,22 @@ private DurableConfig createConfig() { private ChildContextOperation createOperation(Function func) { return new ChildContextOperation<>( - OPERATION_IDENTIFIER, func, TypeToken.get(String.class), SERDES, durableContext); + OPERATION_IDENTIFIER, + func, + TypeToken.get(String.class), + RunInChildContextConfig.builder().serDes(SERDES).build(), + durableContext); } private ChildContextOperation createOperationWithParent( Function func, ConcurrencyOperation parent) { return new ChildContextOperation<>( - OPERATION_IDENTIFIER, func, TypeToken.get(String.class), SERDES, durableContext, parent); + OPERATION_IDENTIFIER, + func, + TypeToken.get(String.class), + RunInChildContextConfig.builder().serDes(SERDES).build(), + durableContext, + parent); } // ===== SUCCEEDED replay ===== @@ -259,40 +269,6 @@ void replayWithNameMismatchTerminatesExecution() { // ===== Parent ConcurrencyOperation support ===== - /** Parent's onItemComplete() is called when child succeeds. */ - @Test - void parentOnItemCompleteCalledOnChildSuccess() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("1")).thenReturn(null); - - var parent = mock(ConcurrencyOperation.class); - when(parent.isOperationCompleted()).thenReturn(false); - - var operation = createOperationWithParent(ctx -> "success", parent); - operation.execute(); - Thread.sleep(200); - - verify(parent).onItemComplete(operation); - } - - /** Parent's onItemComplete() is called when child fails. */ - @Test - void parentOnItemCompleteCalledOnChildFailure() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("1")).thenReturn(null); - - var parent = mock(ConcurrencyOperation.class); - when(parent.isOperationCompleted()).thenReturn(false); - - var operation = createOperationWithParent( - ctx -> { - throw new RuntimeException("branch failed"); - }, - parent); - operation.execute(); - Thread.sleep(200); - - verify(parent).onItemComplete(operation); - } - /** Child skips success checkpoint when parent operation has already completed. */ @Test void childSkipsSuccessCheckpointWhenParentAlreadyCompleted() throws Exception { @@ -330,104 +306,4 @@ void childSkipsFailureCheckpointWhenParentAlreadyCompleted() throws Exception { verify(executionManager, never()) .sendOperationUpdate(argThat(update -> update.action() == OperationAction.FAIL)); } - - // ===== onItemComplete called during replay ===== - - /** SUCCEEDED replay (terminal) — onItemComplete is called via markAlreadyCompleted(). */ - @Test - void replaySucceeded_callsParentOnItemComplete() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("1")) - .thenReturn(Operation.builder() - .id("1") - .name("test-context") - .type(OperationType.CONTEXT) - .subType(OperationSubType.RUN_IN_CHILD_CONTEXT.getValue()) - .status(OperationStatus.SUCCEEDED) - .contextDetails( - ContextDetails.builder().result("\"cached\"").build()) - .build()); - - var parent = mock(ConcurrencyOperation.class); - when(parent.isOperationCompleted()).thenReturn(false); - - var operation = createOperationWithParent(ctx -> "unused", parent); - operation.execute(); - - verify(parent).onItemComplete(operation); - } - - /** FAILED replay (terminal) — onItemComplete is called via markAlreadyCompleted(). */ - @Test - void replayFailed_callsParentOnItemComplete() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("1")) - .thenReturn(Operation.builder() - .id("1") - .name("test-context") - .type(OperationType.CONTEXT) - .subType(OperationSubType.RUN_IN_CHILD_CONTEXT.getValue()) - .status(OperationStatus.FAILED) - .contextDetails(ContextDetails.builder() - .error(ErrorObject.builder() - .errorType("java.lang.RuntimeException") - .errorMessage("original failure") - .build()) - .build()) - .build()); - - var parent = mock(ConcurrencyOperation.class); - when(parent.isOperationCompleted()).thenReturn(false); - - var operation = createOperationWithParent(ctx -> "unused", parent); - operation.execute(); - - verify(parent).onItemComplete(operation); - } - - /** STARTED replay — child re-executes and onItemComplete is called from the finally block. */ - @Test - void replayStarted_callsParentOnItemComplete() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("1")) - .thenReturn(Operation.builder() - .id("1") - .name("test-context") - .type(OperationType.CONTEXT) - .subType(OperationSubType.RUN_IN_CHILD_CONTEXT.getValue()) - .status(OperationStatus.STARTED) - .build()); - when(executionManager.hasOperationsForContext("1")).thenReturn(false); - - var parent = mock(ConcurrencyOperation.class); - when(parent.isOperationCompleted()).thenReturn(false); - - var operation = createOperationWithParent(ctx -> "re-executed", parent); - operation.execute(); - Thread.sleep(200); - - verify(parent).onItemComplete(operation); - } - - /** replayChildren=true — child re-executes and onItemComplete is called from the finally block. */ - @Test - void replayChildren_callsParentOnItemComplete() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("1")) - .thenReturn(Operation.builder() - .id("1") - .name("test-context") - .type(OperationType.CONTEXT) - .subType(OperationSubType.RUN_IN_CHILD_CONTEXT.getValue()) - .status(OperationStatus.SUCCEEDED) - .contextDetails( - ContextDetails.builder().replayChildren(true).build()) - .build()); - when(executionManager.hasOperationsForContext("1")).thenReturn(false); - - var parent = mock(ConcurrencyOperation.class); - when(parent.isOperationCompleted()).thenReturn(false); - - var operation = createOperationWithParent(ctx -> "reconstructed", parent); - operation.execute(); - Thread.sleep(200); - - verify(parent, atLeast(1)).onItemComplete(operation); - } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java index 67f7c2ebe..a453ceae9 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java @@ -12,7 +12,6 @@ import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Function; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.lambda.model.ContextDetails; @@ -20,8 +19,9 @@ import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.lambda.durable.DurableConfig; -import software.amazon.lambda.durable.DurableContext; +import software.amazon.lambda.durable.TestUtils; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.CompletionConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.ExecutionManager; import software.amazon.lambda.durable.execution.OperationIdGenerator; @@ -37,19 +37,18 @@ class ConcurrencyOperationTest { private static final SerDes SER_DES = new JacksonSerDes(); private static final String OPERATION_ID = "op-1"; + private static final String CHILD_OP_1 = TestUtils.hashOperationId(OPERATION_ID + "-1"); + private static final String CHILD_OP_2 = TestUtils.hashOperationId(OPERATION_ID + "-2"); private static final TypeToken RESULT_TYPE = TypeToken.get(Void.class); private DurableContextImpl durableContext; private DurableContextImpl childContext; private ExecutionManager executionManager; - private AtomicInteger operationIdCounter; - private OperationIdGenerator mockIdGenerator; @BeforeEach void setUp() { durableContext = mock(DurableContextImpl.class); executionManager = mock(ExecutionManager.class); - operationIdCounter = new AtomicInteger(0); var childContext = mock(DurableContextImpl.class); this.childContext = childContext; @@ -65,11 +64,7 @@ void setUp() { .withExecutorService(Executors.newCachedThreadPool()) .build()); when(durableContext.createChildContext(anyString(), anyString())).thenReturn(childContext); - when(durableContext.createChildContextWithoutSettingThreadContext(anyString(), anyString())) - .thenReturn(childContext); when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext("Root", ThreadType.CONTEXT)); - mockIdGenerator = mock(OperationIdGenerator.class); - when(mockIdGenerator.nextOperationId()).thenAnswer(inv -> "child-" + operationIdCounter.incrementAndGet()); // All child operations are NOT in replay when(executionManager.getOperationAndUpdateReplayState(anyString())).thenReturn(null); // Simulate the real backend: the parent concurrency operation is available in storage after completion @@ -86,19 +81,15 @@ void setUp() { when(executionManager.sendOperationUpdate(any())).thenReturn(CompletableFuture.completedFuture(null)); } - private TestConcurrencyOperation createOperation(int maxConcurrency, int minSuccessful, int toleratedFailureCount) - throws Exception { - TestConcurrencyOperation testConcurrencyOperation = new TestConcurrencyOperation( + private TestConcurrencyOperation createOperation(CompletionConfig completionConfig) throws Exception { + return new TestConcurrencyOperation( OperationIdentifier.of( OPERATION_ID, "test-concurrency", OperationType.CONTEXT, OperationSubType.PARALLEL), RESULT_TYPE, SER_DES, durableContext, - maxConcurrency, - minSuccessful, - toleratedFailureCount); - setOperationIdGenerator(testConcurrencyOperation, mockIdGenerator); - return testConcurrencyOperation; + Integer.MAX_VALUE, + completionConfig); } private void setOperationIdGenerator(ConcurrencyOperation op, OperationIdGenerator mockGenerator) @@ -112,9 +103,9 @@ private void setOperationIdGenerator(ConcurrencyOperation op, OperationIdGene @Test void allChildrenAlreadySucceed_callsHandleSuccess() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("child-1")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_1)) .thenReturn(Operation.builder() - .id("child-1") + .id(CHILD_OP_1) .name("branch-1") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -122,9 +113,9 @@ void allChildrenAlreadySucceed_callsHandleSuccess() throws Exception { .contextDetails( ContextDetails.builder().result("\"result-1\"").build()) .build()); - when(executionManager.getOperationAndUpdateReplayState("child-2")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_2)) .thenReturn(Operation.builder() - .id("child-2") + .id(CHILD_OP_2) .name("branch-2") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -134,38 +125,42 @@ void allChildrenAlreadySucceed_callsHandleSuccess() throws Exception { .build()); var functionCalled = new AtomicBoolean(false); - var op = createOperation(-1, -1, 0); - op.addItem( + var op = createOperation(CompletionConfig.allSuccessful()); + op.execute(); + op.enqueueItem( "branch-1", - ctx -> { + ctx1 -> { functionCalled.set(true); return "result-1"; }, TypeToken.get(String.class), - SER_DES); - op.addItem( + SER_DES, + OperationSubType.PARALLEL_BRANCH); + op.enqueueItem( "branch-2", ctx -> { functionCalled.set(true); return "result-2"; }, TypeToken.get(String.class), - SER_DES); + SER_DES, + OperationSubType.PARALLEL_BRANCH); op.exposedJoin(); assertTrue(op.isSuccessHandled()); assertFalse(op.isFailureHandled()); - assertEquals(2, op.getSucceededCount()); - assertEquals(0, op.getFailedCount()); + var items = op.getBranches(); + assertEquals(2, items.size()); + assertTrue(items.stream().allMatch(b -> b.getOperation().status().equals(OperationStatus.SUCCEEDED))); assertFalse(functionCalled.get(), "Functions should not be called during SUCCEEDED replay"); } @Test void singleChildAlreadySucceeds_fullCycle() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("child-1")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_1)) .thenReturn(Operation.builder() - .id("child-1") + .id(CHILD_OP_1) .name("only-branch") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -175,36 +170,27 @@ void singleChildAlreadySucceeds_fullCycle() throws Exception { .build()); var functionCalled = new AtomicBoolean(false); - var op = createOperation(-1, 1, 0); - op.addItem( + var op = createOperation(CompletionConfig.minSuccessful(1)); + op.enqueueItem( "only-branch", ctx -> { functionCalled.set(true); return "done"; }, TypeToken.get(String.class), - SER_DES); + SER_DES, + OperationSubType.PARALLEL_BRANCH); + op.execute(); op.exposedJoin(); assertTrue(op.isSuccessHandled()); - assertEquals(1, op.getSucceededCount()); - assertEquals(0, op.getFailedCount()); + var items = op.getBranches(); + assertEquals(1, items.size()); + assertEquals(OperationStatus.SUCCEEDED, items.get(0).getOperation().status()); assertFalse(functionCalled.get(), "Function should not be called during SUCCEEDED replay"); } - @Test - void addItem_usesRootChildContextAsParent() throws Exception { - var op = createOperation(-1, -1, 0); - - op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES); - - // rootContext is created via durableContext.createChildContext(...) in the constructor, - // so the parentContext passed to createItem must be that child context, not durableContext itself - assertNotSame(durableContext, op.getLastParentContext()); - assertSame(childContext, op.getLastParentContext()); - } - // ===== Test subclass ===== static class TestConcurrencyOperation extends ConcurrencyOperation { @@ -213,8 +199,6 @@ static class TestConcurrencyOperation extends ConcurrencyOperation { private boolean failureHandled = false; private final AtomicInteger executingCount = new AtomicInteger(0); private DurableContextImpl lastParentContext; - private final int minSuccessful; - private final int toleratedFailureCount; TestConcurrencyOperation( OperationIdentifier operationIdentifier, @@ -222,35 +206,15 @@ static class TestConcurrencyOperation extends ConcurrencyOperation { SerDes resultSerDes, DurableContextImpl durableContext, int maxConcurrency, - int minSuccessful, - int toleratedFailureCount) { - super(operationIdentifier, resultTypeToken, resultSerDes, durableContext, maxConcurrency); - this.minSuccessful = minSuccessful; - this.toleratedFailureCount = toleratedFailureCount; - } - - @Override - protected ChildContextOperation createItem( - String operationId, - String name, - Function function, - TypeToken resultType, - SerDes serDes, - DurableContextImpl parentContext) { - lastParentContext = parentContext; - return new ChildContextOperation( - OperationIdentifier.of(operationId, name, OperationType.CONTEXT, OperationSubType.PARALLEL_BRANCH), - function, - resultType, - serDes, - parentContext, - this) { - @Override - public void execute() { - executingCount.incrementAndGet(); - super.execute(); - } - }; + CompletionConfig completionConfig) { + super( + operationIdentifier, + resultTypeToken, + resultSerDes, + durableContext, + maxConcurrency, + completionConfig.minSuccessful(), + completionConfig.toleratedFailureCount()); } @Override @@ -265,46 +229,13 @@ protected void handleSuccess(ConcurrencyCompletionStatus completionStatus) { } @Override - protected void handleFailure(ConcurrencyCompletionStatus completionStatus) { - failureHandled = true; - onCheckpointComplete(Operation.builder() - .id(getOperationId()) - .status(OperationStatus.SUCCEEDED) // always success for parallel - .build()); - } - - @Override - protected void start() {} - - @Override - protected void replay(Operation existing) {} - - @Override - protected void validateItemCount() { - if (minSuccessful > getTotalItems() - getFailedCount()) { - throw new IllegalArgumentException("minSuccessful (" + minSuccessful - + ") exceeds the number of registered items (" + getTotalItems() + ")"); - } + protected void start() { + executeItems(); } @Override - protected ConcurrencyCompletionStatus canComplete() { - int succeeded = getSucceededCount(); - int failed = getFailedCount(); - - if (minSuccessful != -1 && succeeded >= minSuccessful) { - return ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED; - } - - if ((minSuccessful == -1 && failed > 0) || failed > toleratedFailureCount) { - return ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; - } - - if (isAllItemsFinished()) { - return ConcurrencyCompletionStatus.ALL_COMPLETED; - } - - return null; + protected void replay(Operation existing) { + executeItems(); } @Override diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/InvokeOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/InvokeOperationTest.java index f370805c9..daa75056c 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/InvokeOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/InvokeOperationTest.java @@ -14,8 +14,8 @@ import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; -import software.amazon.lambda.durable.InvokeConfig; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.InvokeConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.InvokeException; import software.amazon.lambda.durable.exception.InvokeFailedException; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java index a44e53459..a7693591b 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java @@ -12,7 +12,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executors; -import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.lambda.model.ContextDetails; @@ -21,7 +20,10 @@ import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.lambda.durable.DurableConfig; +import software.amazon.lambda.durable.TestUtils; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.CompletionConfig; +import software.amazon.lambda.durable.config.ParallelConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.execution.ExecutionManager; import software.amazon.lambda.durable.execution.OperationIdGenerator; @@ -37,17 +39,19 @@ class ParallelOperationTest { private static final SerDes SER_DES = new JacksonSerDes(); private static final String OPERATION_ID = "parallel-op-1"; + private static final String CHILD_OP_1 = TestUtils.hashOperationId(OPERATION_ID + "-1"); + private static final String CHILD_OP_2 = TestUtils.hashOperationId(OPERATION_ID + "-2"); private DurableContextImpl durableContext; private ExecutionManager executionManager; - private AtomicInteger operationIdCounter; - private OperationIdGenerator mockIdGenerator; @BeforeEach void setUp() { durableContext = mock(DurableContextImpl.class); executionManager = mock(ExecutionManager.class); - operationIdCounter = new AtomicInteger(0); + + when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext(null, ThreadType.CONTEXT)); + when(executionManager.getOperationAndUpdateReplayState(anyString())).thenReturn(null); var childContext = mock(DurableContextImpl.class); when(childContext.getExecutionManager()).thenReturn(executionManager); @@ -62,18 +66,11 @@ void setUp() { .withExecutorService(Executors.newCachedThreadPool()) .build()); when(durableContext.createChildContext(anyString(), anyString())).thenReturn(childContext); - when(durableContext.createChildContextWithoutSettingThreadContext(anyString(), anyString())) - .thenReturn(childContext); - when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext("Root", ThreadType.CONTEXT)); - // Default: no existing operations (fresh execution) - mockIdGenerator = mock(OperationIdGenerator.class); - when(mockIdGenerator.nextOperationId()).thenAnswer(inv -> "child-" + operationIdCounter.incrementAndGet()); - when(executionManager.getOperationAndUpdateReplayState(anyString())).thenReturn(null); // Capture registered operations so we can drive onCheckpointComplete callbacks. - var registeredOps = new ConcurrentHashMap>(); + var registeredOps = new ConcurrentHashMap(); doAnswer(inv -> { - BaseDurableOperation op = inv.getArgument(0); + BaseDurableOperation op = inv.getArgument(0); registeredOps.put(op.getOperationId(), op); return null; }) @@ -107,14 +104,15 @@ void setUp() { .sendOperationUpdate(any()); } - private ParallelOperation createOperation(int maxConcurrency, int minSuccessful, int toleratedFailureCount) { - return new ParallelOperation( + private ParallelOperation createOperation(CompletionConfig completionConfig) { + var op = new ParallelOperation( OperationIdentifier.of(OPERATION_ID, "test-parallel", OperationType.CONTEXT, OperationSubType.PARALLEL), SER_DES, durableContext, - maxConcurrency, - minSuccessful, - toleratedFailureCount); + ParallelConfig.builder().completionConfig(completionConfig).build()); + + op.execute(); + return op; } private void setOperationIdGenerator(ConcurrencyOperation op, OperationIdGenerator mockGenerator) @@ -127,32 +125,36 @@ private void setOperationIdGenerator(ConcurrencyOperation op, OperationIdGene // ===== Branch creation delegates to ConcurrencyOperation ===== @Test - void branchCreation_createsBranchWithParallelBranchSubType() throws Exception { - var op = createOperation(-1, -1, 0); + void branchCreation_createsBranchWithParallelBranchSubType() { + var op = createOperation(CompletionConfig.allSuccessful()); - var childOp = op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES); + var childOp = op.enqueueItem( + "branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); assertNotNull(childOp); assertEquals(OperationSubType.PARALLEL_BRANCH, childOp.getSubType()); } @Test - void branchCreation_multipleBranchesAllCreated() throws Exception { - var op = createOperation(-1, -1, 0); + void branchCreation_multipleBranchesAllCreated() { + var op = createOperation(CompletionConfig.allSuccessful()); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); - op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES); - op.addItem("branch-3", ctx -> "r3", TypeToken.get(String.class), SER_DES); + op.enqueueItem( + "branch-1", ctx2 -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem( + "branch-2", ctx1 -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem("branch-3", ctx -> "r3", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); - assertEquals(3, op.getTotalItems()); + assertEquals(3, op.getBranches().size()); } @Test void branchCreation_childOperationHasParentReference() throws Exception { - var op = createOperation(-1, -1, 0); + var op = createOperation(CompletionConfig.allSuccessful()); // The child operation should be a ChildContextOperation with this op as parent - var childOp = op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES); + var childOp = op.enqueueItem( + "branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); assertNotNull(childOp); // Verify it's a ChildContextOperation (the concrete type returned by createItem) @@ -163,9 +165,9 @@ void branchCreation_childOperationHasParentReference() throws Exception { @Test void allBranchesSucceed_sendsSucceedCheckpointAndReturnsCorrectResult() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("child-1")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_1)) .thenReturn(Operation.builder() - .id("child-1") + .id(CHILD_OP_1) .name("branch-1") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -173,9 +175,9 @@ void allBranchesSucceed_sendsSucceedCheckpointAndReturnsCorrectResult() throws E .contextDetails( ContextDetails.builder().result("\"r1\"").build()) .build()); - when(executionManager.getOperationAndUpdateReplayState("child-2")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_2)) .thenReturn(Operation.builder() - .id("child-2") + .id(CHILD_OP_2) .name("branch-2") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -184,28 +186,28 @@ void allBranchesSucceed_sendsSucceedCheckpointAndReturnsCorrectResult() throws E ContextDetails.builder().result("\"r2\"").build()) .build()); - var op = createOperation(-1, -1, 0); - setOperationIdGenerator(op, mockIdGenerator); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); - op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES); + var op = createOperation(CompletionConfig.allSuccessful()); + op.enqueueItem( + "branch-1", ctx1 -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); var result = op.get(); verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); - assertEquals(2, result.getTotalBranches()); - assertEquals(2, result.getSucceededBranches()); - assertEquals(0, result.getFailedBranches()); - assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.getCompletionStatus()); - assertTrue(result.getCompletionStatus().isSucceeded()); + assertEquals(2, result.size()); + assertEquals(2, result.succeeded()); + assertEquals(0, result.failed()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionStatus()); + assertTrue(result.completionStatus().isSucceeded()); } // ===== MinSuccessful satisfaction ===== @Test void minSuccessful_completesWhenThresholdMetAndReturnsResult() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("child-1")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_1)) .thenReturn(Operation.builder() - .id("child-1") + .id(CHILD_OP_1) .name("branch-1") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -214,18 +216,17 @@ void minSuccessful_completesWhenThresholdMetAndReturnsResult() throws Exception ContextDetails.builder().result("\"r1\"").build()) .build()); - var op = createOperation(-1, 1, 0); - setOperationIdGenerator(op, mockIdGenerator); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); + var op = createOperation(CompletionConfig.minSuccessful(1)); + op.enqueueItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); var result = op.get(); verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); - assertEquals(1, result.getTotalBranches()); - assertEquals(1, result.getSucceededBranches()); - assertEquals(0, result.getFailedBranches()); - assertEquals(ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED, result.getCompletionStatus()); - assertTrue(result.getCompletionStatus().isSucceeded()); + assertEquals(1, result.size()); + assertEquals(1, result.succeeded()); + assertEquals(0, result.failed()); + assertEquals(ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED, result.completionStatus()); + assertTrue(result.completionStatus().isSucceeded()); } // ===== Context hierarchy ===== @@ -234,9 +235,10 @@ void minSuccessful_completesWhenThresholdMetAndReturnsResult() throws Exception void contextHierarchy_branchesUseParallelContextAsParent() throws Exception { // Verify that branches are created with the parallel operation's context (durableContext) // as their parent — not some other context - var op = createOperation(-1, -1, 0); + var op = createOperation(CompletionConfig.allSuccessful()); - var childOp = op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES); + var childOp = op.enqueueItem( + "branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); // The child operation should be registered in the execution manager // (BaseDurableOperation constructor calls executionManager.registerOperation) @@ -256,9 +258,9 @@ void replay_fromStartedState_sendsSucceedCheckpointAndReturnsResult() throws Exc .subType(OperationSubType.PARALLEL.getValue()) .status(OperationStatus.STARTED) .build()); - when(executionManager.getOperationAndUpdateReplayState("child-1")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_1)) .thenReturn(Operation.builder() - .id("child-1") + .id(CHILD_OP_1) .name("branch-1") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -266,9 +268,9 @@ void replay_fromStartedState_sendsSucceedCheckpointAndReturnsResult() throws Exc .contextDetails( ContextDetails.builder().result("\"r1\"").build()) .build()); - when(executionManager.getOperationAndUpdateReplayState("child-2")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_2)) .thenReturn(Operation.builder() - .id("child-2") + .id(CHILD_OP_2) .name("branch-2") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -277,11 +279,10 @@ void replay_fromStartedState_sendsSucceedCheckpointAndReturnsResult() throws Exc ContextDetails.builder().result("\"r2\"").build()) .build()); - var op = createOperation(-1, -1, 0); - setOperationIdGenerator(op, mockIdGenerator); - op.execute(); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); - op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES); + var op = createOperation(CompletionConfig.allSuccessful()); + op.enqueueItem( + "branch-1", ctx1 -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); var result = op.get(); @@ -289,10 +290,10 @@ void replay_fromStartedState_sendsSucceedCheckpointAndReturnsResult() throws Exc .sendOperationUpdate(argThat(update -> update.action() == OperationAction.START)); verify(executionManager, times(1)) .sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); - assertEquals(2, result.getTotalBranches()); - assertEquals(2, result.getSucceededBranches()); - assertEquals(0, result.getFailedBranches()); - assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.getCompletionStatus()); + assertEquals(2, result.size()); + assertEquals(2, result.succeeded()); + assertEquals(0, result.failed()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionStatus()); } @Test @@ -305,9 +306,9 @@ void replay_fromSucceededState_skipsCheckpointAndReturnsResult() throws Exceptio .subType(OperationSubType.PARALLEL.getValue()) .status(OperationStatus.SUCCEEDED) .build()); - when(executionManager.getOperationAndUpdateReplayState("child-1")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_1)) .thenReturn(Operation.builder() - .id("child-1") + .id(CHILD_OP_1) .name("branch-1") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -315,9 +316,9 @@ void replay_fromSucceededState_skipsCheckpointAndReturnsResult() throws Exceptio .contextDetails( ContextDetails.builder().result("\"r1\"").build()) .build()); - when(executionManager.getOperationAndUpdateReplayState("child-2")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_2)) .thenReturn(Operation.builder() - .id("child-2") + .id(CHILD_OP_2) .name("branch-2") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -326,11 +327,10 @@ void replay_fromSucceededState_skipsCheckpointAndReturnsResult() throws Exceptio ContextDetails.builder().result("\"r2\"").build()) .build()); - var op = createOperation(-1, -1, 0); - setOperationIdGenerator(op, mockIdGenerator); - op.execute(); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); - op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES); + var op = createOperation(CompletionConfig.allSuccessful()); + op.enqueueItem( + "branch-1", ctx1 -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); var result = op.get(); @@ -338,51 +338,51 @@ void replay_fromSucceededState_skipsCheckpointAndReturnsResult() throws Exceptio .sendOperationUpdate(argThat(update -> update.action() == OperationAction.START)); verify(executionManager, never()) .sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); - assertEquals(2, result.getTotalBranches()); - assertEquals(2, result.getSucceededBranches()); - assertEquals(0, result.getFailedBranches()); - assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.getCompletionStatus()); + assertEquals(2, result.size()); + assertEquals(2, result.succeeded()); + assertEquals(0, result.failed()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionStatus()); } // ===== Branch failure sends SUCCEED checkpoint and returns result ===== @Test - void branchFailure_sendsSucceedCheckpointAndReturnsFailureCounts() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("child-1")) + void branchFailure_sendsSucceedCheckpointAndReturnsFailureCounts() { + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_1)) .thenReturn(Operation.builder() - .id("child-1") + .id(CHILD_OP_1) .name("branch-1") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) .status(OperationStatus.FAILED) .build()); - var op = createOperation(-1, -1, 0); - setOperationIdGenerator(op, mockIdGenerator); - op.addItem( + var op = createOperation(CompletionConfig.allSuccessful()); + op.enqueueItem( "branch-1", ctx -> { throw new RuntimeException("branch failed"); }, TypeToken.get(String.class), - SER_DES); + SER_DES, + OperationSubType.PARALLEL_BRANCH); - var result = assertDoesNotThrow(() -> op.get()); + var result = assertDoesNotThrow(op::get); verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); verify(executionManager, never()) .sendOperationUpdate(argThat(update -> update.action() == OperationAction.FAIL)); - assertEquals(1, result.getTotalBranches()); - assertEquals(0, result.getSucceededBranches()); - assertEquals(1, result.getFailedBranches()); - assertFalse(result.getCompletionStatus().isSucceeded()); + assertEquals(1, result.size()); + assertEquals(0, result.succeeded()); + assertEquals(1, result.failed()); + assertFalse(result.completionStatus().isSucceeded()); } @Test void get_someBranchesFail_returnsCorrectCountsAndFailureStatus() throws Exception { - when(executionManager.getOperationAndUpdateReplayState("child-1")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_1)) .thenReturn(Operation.builder() - .id("child-1") + .id(CHILD_OP_1) .name("branch-1") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -390,9 +390,9 @@ void get_someBranchesFail_returnsCorrectCountsAndFailureStatus() throws Exceptio .contextDetails( ContextDetails.builder().result("\"r1\"").build()) .build()); - when(executionManager.getOperationAndUpdateReplayState("child-2")) + when(executionManager.getOperationAndUpdateReplayState(CHILD_OP_2)) .thenReturn(Operation.builder() - .id("child-2") + .id(CHILD_OP_2) .name("branch-2") .type(OperationType.CONTEXT) .subType(OperationSubType.PARALLEL_BRANCH.getValue()) @@ -400,36 +400,37 @@ void get_someBranchesFail_returnsCorrectCountsAndFailureStatus() throws Exceptio .build()); // toleratedFailureCount=1 so the operation completes after both branches finish - var op = createOperation(-1, -1, 1); - setOperationIdGenerator(op, mockIdGenerator); - op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); - op.addItem( + var op = createOperation(CompletionConfig.toleratedFailureCount(1)); + op.enqueueItem( + "branch-1", ctx1 -> "r1", TypeToken.get(String.class), SER_DES, OperationSubType.PARALLEL_BRANCH); + op.enqueueItem( "branch-2", ctx -> { throw new RuntimeException("branch failed"); }, TypeToken.get(String.class), - SER_DES); + SER_DES, + OperationSubType.PARALLEL_BRANCH); var result = op.get(); verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); - assertEquals(2, result.getTotalBranches()); - assertEquals(1, result.getSucceededBranches()); - assertEquals(1, result.getFailedBranches()); - assertFalse(result.getCompletionStatus().isSucceeded()); + assertEquals(2, result.size()); + assertEquals(1, result.succeeded()); + assertEquals(1, result.failed()); + assertTrue(result.completionStatus().isSucceeded()); } @Test void get_zeroBranches_returnsAllZerosAndAllCompletedStatus() throws Exception { - var op = createOperation(-1, -1, 0); + var op = createOperation(CompletionConfig.allSuccessful()); var result = op.get(); - assertEquals(0, result.getTotalBranches()); - assertEquals(0, result.getSucceededBranches()); - assertEquals(0, result.getFailedBranches()); - assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.getCompletionStatus()); + assertEquals(0, result.size()); + assertEquals(0, result.succeeded()); + assertEquals(0, result.failed()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionStatus()); verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/BaseDurableOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/SerializableDurableOperationTest.java similarity index 85% rename from sdk/src/test/java/software/amazon/lambda/durable/operation/BaseDurableOperationTest.java rename to sdk/src/test/java/software/amazon/lambda/durable/operation/SerializableDurableOperationTest.java index 8df56e796..6ae8cc890 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/BaseDurableOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/SerializableDurableOperationTest.java @@ -39,7 +39,7 @@ import software.amazon.lambda.durable.serde.JacksonSerDes; import software.amazon.lambda.durable.serde.SerDes; -class BaseDurableOperationTest { +class SerializableDurableOperationTest { private static final String OPERATION_ID = "1"; private static final String CONTEXT_ID = "1-step"; @@ -67,8 +67,8 @@ void setUp() { @Test void getOperation() { - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() {} @@ -91,8 +91,8 @@ public String get() { @Test void waitForOperationCompletionThrowsIfOperationMissing() { when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)).thenReturn(null); - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() { markAlreadyCompleted(); @@ -115,8 +115,8 @@ public String get() { @Test void waitForOperationCompletionWhenRunningAndReadyToComplete() throws InterruptedException, ExecutionException, TimeoutException { - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() {} @@ -147,8 +147,8 @@ public String get() { @Test void waitForOperationCompletionWhenAlreadyCompleted() { - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() { markAlreadyCompleted(); @@ -171,8 +171,8 @@ public String get() { @Test void markAlreadyCompleted() { - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() { markAlreadyCompleted(); @@ -198,8 +198,8 @@ void validateReplayThrowsWhenTypeMismatch() { .thenReturn( Operation.builder().type(OperationType.CHAINED_INVOKE).build()); - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() { validateReplay(getOperation()); @@ -225,8 +225,8 @@ void validateReplayThrowsWhenNameMismatch() { .type(OPERATION_TYPE) .build()); - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() { validateReplay(getOperation()); @@ -248,8 +248,8 @@ public String get() { void validateReplayDoesNotThrowWhenNoOperation() { when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)).thenReturn(null); - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() { validateReplay(getOperation()); @@ -274,8 +274,8 @@ void validateReplayDoesNotThrowWhenNameAndTypeMatch() { .type(OPERATION_TYPE) .build()); - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() { validateReplay(getOperation()); @@ -294,8 +294,8 @@ public String get() { @Test void deserializeResult() { - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() {} @@ -315,8 +315,8 @@ public String get() { @Test void deserializeException() { - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() {} @@ -341,8 +341,8 @@ public String get() { @Test void polling() { - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() {} @@ -365,8 +365,8 @@ public String get() { void sendOperationUpdate() { var update = OperationUpdate.builder(); - BaseDurableOperation op = - new BaseDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { + SerializableDurableOperation op = + new SerializableDurableOperation<>(OPERATION_IDENTIFIER, RESULT_TYPE, SER_DES, durableContext) { @Override protected void start() {} diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/StepOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/StepOperationTest.java index 65cb64a7d..2eae32f0c 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/StepOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/StepOperationTest.java @@ -15,8 +15,8 @@ import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.awssdk.services.lambda.model.StepDetails; import software.amazon.lambda.durable.DurableConfig; -import software.amazon.lambda.durable.StepConfig; import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.StepFailedException; import software.amazon.lambda.durable.exception.StepInterruptedException; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/WaitForConditionOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/WaitForConditionOperationTest.java index 32f175ec1..2747e6440 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/WaitForConditionOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/WaitForConditionOperationTest.java @@ -18,7 +18,7 @@ import software.amazon.awssdk.services.lambda.model.StepDetails; import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.TypeToken; -import software.amazon.lambda.durable.WaitForConditionConfig; +import software.amazon.lambda.durable.config.WaitForConditionConfig; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.IllegalDurableOperationException; import software.amazon.lambda.durable.exception.NonDeterministicExecutionException; diff --git a/sdk/src/test/java/software/amazon/lambda/durable/retry/RetryStrategiesTest.java b/sdk/src/test/java/software/amazon/lambda/durable/retry/RetryStrategiesTest.java index 74f57966b..0578a8bf8 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/retry/RetryStrategiesTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/retry/RetryStrategiesTest.java @@ -6,7 +6,7 @@ import java.time.Duration; import org.junit.jupiter.api.Test; -import software.amazon.lambda.durable.StepConfig; +import software.amazon.lambda.durable.config.StepConfig; class RetryStrategiesTest {