Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions sdk/src/main/java/software/amazon/lambda/durable/BaseContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,27 @@
import software.amazon.lambda.durable.execution.ExecutionManager;
import software.amazon.lambda.durable.logging.DurableLogger;

public abstract class BaseContext {
public abstract class BaseContext implements AutoCloseable {
protected final ExecutionManager executionManager;
private final DurableConfig durableConfig;
private final Context lambdaContext;
private final ExecutionContext executionContext;
private final String contextId;
private final String contextName;
private boolean isReplaying;

/** Creates a new BaseContext instance. */
protected BaseContext(
ExecutionManager executionManager, DurableConfig durableConfig, Context lambdaContext, String contextId) {
ExecutionManager executionManager,
DurableConfig durableConfig,
Context lambdaContext,
String contextId,
String contextName) {
this.executionManager = executionManager;
this.durableConfig = durableConfig;
this.lambdaContext = lambdaContext;
this.contextId = contextId;
this.contextName = contextName;
this.executionContext = new ExecutionContext(executionManager.getDurableExecutionArn());
this.isReplaying = executionManager.hasOperationsForContext(contextId);
}
Expand Down Expand Up @@ -71,6 +77,10 @@ public String getContextId() {
return contextId;
}

/**
 * Returns the name this context was constructed with.
 *
 * @return the context name as passed to the constructor
 */
public String getContextName() {
    return this.contextName;
}

/**
 * Returns the execution manager this context was constructed with.
 *
 * @return the execution manager backing this context
 */
public ExecutionManager getExecutionManager() {
    return this.executionManager;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,17 @@

public class DurableContext extends BaseContext {
private final AtomicInteger operationCounter;
private final DurableLogger logger;
private volatile DurableLogger logger;

/** Shared initialization — sets all fields. */
private DurableContext(
ExecutionManager executionManager, DurableConfig durableConfig, Context lambdaContext, String contextId) {
super(executionManager, durableConfig, lambdaContext, contextId);
ExecutionManager executionManager,
DurableConfig durableConfig,
Context lambdaContext,
String contextId,
String contextName) {
super(executionManager, durableConfig, lambdaContext, contextId, contextName);
this.operationCounter = new AtomicInteger(0);

var requestId = lambdaContext != null ? lambdaContext.getAwsRequestId() : null;
this.logger = new DurableLogger(
LoggerFactory.getLogger(DurableContext.class),
executionManager,
requestId,
durableConfig.getLoggerConfig().suppressReplayLogs());
}

/**
Expand All @@ -48,7 +45,7 @@ private DurableContext(
*/
public static DurableContext createRootContext(
ExecutionManager executionManager, DurableConfig durableConfig, Context lambdaContext) {
return new DurableContext(executionManager, durableConfig, lambdaContext, null);
return new DurableContext(executionManager, durableConfig, lambdaContext, null, null);
}

/**
Expand All @@ -57,8 +54,9 @@ public static DurableContext createRootContext(
* @param childContextId the child context's ID (the CONTEXT operation's operation ID)
* @return a new DurableContext for the child context
*/
public DurableContext createChildContext(String childContextId) {
return new DurableContext(executionManager, getDurableConfig(), getLambdaContext(), childContextId);
public DurableContext createChildContext(String childContextId, String childContextName) {
return new DurableContext(
executionManager, getDurableConfig(), getLambdaContext(), childContextId, childContextName);
}

/**
Expand All @@ -67,8 +65,9 @@ public DurableContext createChildContext(String childContextId) {
* @param stepOperationId the ID of the step operation (used for thread registration)
* @return a new StepContext instance
*/
public StepContext createStepContext(String stepOperationId) {
return new StepContext(executionManager, getDurableConfig(), getLambdaContext(), stepOperationId);
public StepContext createStepContext(String stepOperationId, String stepOperationName, int attempt) {
return new StepContext(
executionManager, getDurableConfig(), getLambdaContext(), stepOperationId, stepOperationName, attempt);
}

// ========== step methods ==========
Expand Down Expand Up @@ -305,9 +304,27 @@ public <T> DurableFuture<T> runInChildContextAsync(
* @return the durable logger
*/
@Override
public DurableLogger getLogger() {
    // Fast path: read the volatile field once and return it if the logger already exists.
    DurableLogger current = logger;
    if (current != null) {
        return current;
    }
    synchronized (this) {
        // Re-check under the lock so that only one thread ever constructs the logger.
        if (logger == null) {
            logger = new DurableLogger(LoggerFactory.getLogger(DurableContext.class), this);
        }
        return logger;
    }
}

/**
* Clears the logger's thread properties. Called during context destruction to prevent memory leaks and ensure clean
* state for subsequent executions.
*/
@Override
public void close() {
    // The logger is created lazily by getLogger(); if it was never requested there is nothing to close.
    // Snapshot the volatile field once so the null check and the call use the same reference.
    DurableLogger current = logger;
    if (current != null) {
        current.close();
    }
}

/**
* Get the next operationId. For root contexts, returns sequential IDs like "1", "2", "3". For child contexts,
* prefixes with the contextId to ensure global uniqueness, e.g. "1-1", "1-2" for operations inside child context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,52 +38,45 @@ public static <I, O> DurableExecutionOutput execute(
Class<I> inputType,
BiFunction<I, DurableContext, O> handler,
DurableConfig config) {
var executionManager = new ExecutionManager(
input.durableExecutionArn(), input.checkpointToken(), input.initialExecutionState(), config);

var handlerFuture = CompletableFuture.supplyAsync(
() -> {
var userInput =
extractUserInput(executionManager.getExecutionOperation(), config.getSerDes(), inputType);
// Create context in the executor thread so it detects the correct thread name
var context = DurableContext.createRootContext(executionManager, config, lambdaContext);
executionManager.registerActiveThread(null);
executionManager.setCurrentThreadContext(new ThreadContext(null, ThreadType.CONTEXT));
return handler.apply(userInput, context);
},
config.getExecutorService()); // Get executor from config for running user code

// Execute the handlerFuture in ExecutionManager. If it completes successfully, the output of user function
// will be returned. Otherwise, it will complete exceptionally with a SuspendExecutionException or a failure.
return executionManager
.runUntilCompleteOrSuspend(handlerFuture)
.handle((result, ex) -> {
if (ex != null) {
// an exception thrown from handlerFuture or suspension/termination occurred
Throwable cause = ExceptionHelper.unwrapCompletableFuture(ex);
if (cause instanceof SuspendExecutionException) {
return DurableExecutionOutput.pending();
try (var executionManager = new ExecutionManager(input, config)) {
executionManager.registerActiveThread(null);
var handlerFuture = CompletableFuture.supplyAsync(
() -> {
var userInput = extractUserInput(
executionManager.getExecutionOperation(), config.getSerDes(), inputType);
// use try-with-resources to clear logger properties
try (var context = DurableContext.createRootContext(executionManager, config, lambdaContext)) {
// Create context in the executor thread so it detects the correct thread name
executionManager.setCurrentThreadContext(new ThreadContext(null, ThreadType.CONTEXT));
return handler.apply(userInput, context);
}
},
config.getExecutorService()); // Get executor from config for running user code

// Execute the handlerFuture in ExecutionManager. If it completes successfully, the output of user function
// will be returned. Otherwise, it will complete exceptionally with a SuspendExecutionException or a
// failure.
return executionManager
.runUntilCompleteOrSuspend(handlerFuture)
.handle((result, ex) -> {
if (ex != null) {
// an exception thrown from handlerFuture or suspension/termination occurred
Throwable cause = ExceptionHelper.unwrapCompletableFuture(ex);
if (cause instanceof SuspendExecutionException) {
return DurableExecutionOutput.pending();
}

logger.debug("Execution failed: {}", cause.getMessage());
return DurableExecutionOutput.failure(buildErrorObject(cause, config.getSerDes()));
}
// user handler complete successfully
var outputPayload = config.getSerDes().serialize(result);

logger.debug("Execution failed: {}", cause.getMessage());
return DurableExecutionOutput.failure(buildErrorObject(cause, config.getSerDes()));
}
// user handler complete successfully
var outputPayload = config.getSerDes().serialize(result);

logger.debug("Execution completed");
return DurableExecutionOutput.success(handleLargePayload(executionManager, outputPayload));
})
.whenComplete((v, ex) -> {
// We shutdown the execution to make sure remaining checkpoint calls in the queue are drained
// We DO NOT shutdown the executor since it should stay warm for re-invokes against a warm Lambda
// runtime.
// For example, a re-invoke after a wait should re-use the same executor instance from
// DurableConfig.
// userExecutor.shutdown();
executionManager.shutdown();
})
.join();
logger.debug("Execution completed");
return DurableExecutionOutput.success(handleLargePayload(executionManager, outputPayload));
})
.join();
}
}

private static String handleLargePayload(ExecutionManager executionManager, String outputPayload) {
Expand Down
36 changes: 27 additions & 9 deletions sdk/src/main/java/software/amazon/lambda/durable/StepContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import software.amazon.lambda.durable.logging.DurableLogger;

public class StepContext extends BaseContext {
private final DurableLogger logger;
private volatile DurableLogger logger;
private final int attempt;

/**
* Creates a new StepContext instance for use in step operations.
Expand All @@ -22,19 +23,36 @@ protected StepContext(
ExecutionManager executionManager,
DurableConfig durableConfig,
Context lambdaContext,
String stepOperationId) {
super(executionManager, durableConfig, lambdaContext, stepOperationId);
String stepOperationId,
String stepOperationName,
int attempt) {
super(executionManager, durableConfig, lambdaContext, stepOperationId, stepOperationName);
this.attempt = attempt;
}

var requestId = lambdaContext != null ? lambdaContext.getAwsRequestId() : null;
this.logger = new DurableLogger(
LoggerFactory.getLogger(StepContext.class),
executionManager,
requestId,
durableConfig.getLoggerConfig().suppressReplayLogs());
/**
 * Returns the attempt value this step context was constructed with.
 *
 * @return the current attempt
 */
public int getAttempt() {
    return this.attempt;
}

@Override
public DurableLogger getLogger() {
    // The logger is built on first use rather than in the constructor.
    if (logger != null) {
        return logger;
    }
    synchronized (this) {
        // Check again while holding the lock: another thread may have created it in the meantime.
        if (logger == null) {
            logger = new DurableLogger(LoggerFactory.getLogger(StepContext.class), this);
        }
    }
    return logger;
}

/** Closes this context's logger if one has been created; otherwise does nothing. */
@Override
public void close() {
    if (logger == null) {
        // getLogger() was never called, so no logger exists to close.
        return;
    }
    logger.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import software.amazon.awssdk.services.lambda.model.OperationUpdate;
import software.amazon.lambda.durable.DurableConfig;
import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException;
import software.amazon.lambda.durable.model.DurableExecutionInput;
import software.amazon.lambda.durable.operation.BaseDurableOperation;

/**
Expand All @@ -45,7 +46,7 @@
*
* @see InternalExecutor
*/
public class ExecutionManager {
public class ExecutionManager implements AutoCloseable {

private static final Logger logger = LoggerFactory.getLogger(ExecutionManager.class);

Expand All @@ -65,25 +66,21 @@ public class ExecutionManager {
// ===== Checkpoint Batching =====
private final CheckpointBatcher checkpointBatcher;

public ExecutionManager(
String durableExecutionArn,
String checkpointToken,
CheckpointUpdatedExecutionState initialExecutionState,
DurableConfig config) {
this.durableExecutionArn = durableExecutionArn;
public ExecutionManager(DurableExecutionInput input, DurableConfig config) {
this.durableExecutionArn = input.durableExecutionArn();

// Create checkpoint batcher for internal coordination
this.checkpointBatcher =
new CheckpointBatcher(config, durableExecutionArn, checkpointToken, this::onCheckpointComplete);
new CheckpointBatcher(config, durableExecutionArn, input.checkpointToken(), this::onCheckpointComplete);

this.operationStorage = checkpointBatcher.fetchAllPages(initialExecutionState).stream()
this.operationStorage = checkpointBatcher.fetchAllPages(input.initialExecutionState()).stream()
.collect(Collectors.toConcurrentMap(Operation::id, op -> op));

// Start in REPLAY mode if we have more than just the initial EXECUTION operation
this.executionMode =
new AtomicReference<>(operationStorage.size() > 1 ? ExecutionMode.REPLAY : ExecutionMode.EXECUTION);

executionOp = findExecutionOp(initialExecutionState);
executionOp = findExecutionOp(input.initialExecutionState());

// Validate initial operation is an EXECUTION operation
if (executionOp == null) {
Expand Down Expand Up @@ -248,7 +245,9 @@ public CompletableFuture<Operation> pollForOperationUpdates(String operationId,
}

// ===== Utilities =====
public void shutdown() {
/**
 * Shuts down the checkpoint batcher owned by this manager.
 *
 * <p>This is the {@link AutoCloseable} hook, so it runs automatically when the manager is used in a
 * try-with-resources statement.
 */
@Override
public void close() {
    this.checkpointBatcher.shutdown();
}

Expand Down
Loading