Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/bedrock_agentcore/_utils/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Shared configuration dataclasses for SDK clients."""

from dataclasses import dataclass


@dataclass
class WaitConfig:
"""Configuration for *_and_wait polling methods.

Args:
max_wait: Maximum seconds to wait. Default: 300. Must be >= 1.
poll_interval: Seconds between status checks. Default: 10. Must be >= 1.
"""

max_wait: int = 300
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if one function requires different max_waits than another? For example, if create_memories takes longer than 5 minutes usually, then create_memories_and_wait would usually fail with these defaults.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a *_and_wait method requires a larger window than the standard, the default max_wait can be set at the method level. Right now all the _and_wait methods pass None if the caller does not provide a WaitConfig. But if necessary for a specific case we could instead pass in a WaitConfig we define.

poll_interval: int = 10
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same answer as above. As for the exponential backoff, I think it would be good to have. Should we include that as an optional parameter or a default behavior?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets keep that as a p1


def __post_init__(self):
if self.max_wait < 1:
raise ValueError("max_wait must be at least 1")
if self.poll_interval < 1:
raise ValueError("poll_interval must be at least 1")
92 changes: 92 additions & 0 deletions src/bedrock_agentcore/_utils/polling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""Shared polling helpers for SDK clients."""

import logging
import time
from typing import Any, Callable, Dict, Optional, Set

from .config import WaitConfig

logger = logging.getLogger(__name__)


def wait_until(
poll_fn: Callable[[], Dict[str, Any]],
target: str,
failed: Set[str],
wait_config: Optional[WaitConfig] = None,
error_field: str = "statusReasons",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why was this default chosen?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The majority of CP resources use statusReasons as the error detail field, which is why I set it as the default. The few primitives whose CP resources use failureReason are set as such at the method level

) -> Dict[str, Any]:
"""Poll until a resource reaches the target status.

Args:
poll_fn: Zero-arg callable that returns the resource's current state.
target: The status to wait for (e.g. "ACTIVE", "READY").
failed: Statuses that indicate terminal failure.
wait_config: Optional WaitConfig for polling behavior.
error_field: Response field containing error details.

Returns:
Full response when target status is reached.

Raises:
RuntimeError: If the resource reaches a failed status.
TimeoutError: If target status is not reached within max_wait.
"""
wait = wait_config or WaitConfig()
start_time = time.time()
while True:
resp = poll_fn()
status = resp.get("status")
if status is None:
logger.warning("Response missing 'status' field: %s", resp)
if status == target:
return resp
if status in failed:
reason = resp.get(error_field, "Unknown")
raise RuntimeError("Reached %s: %s" % (status, reason))
if time.time() - start_time >= wait.max_wait:
break
time.sleep(wait.poll_interval)
raise TimeoutError("Did not reach %s within %d seconds" % (target, wait.max_wait))


def wait_until_deleted(
poll_fn: Callable[[], Dict[str, Any]],
not_found_code: str = "ResourceNotFoundException",
failed: Optional[Set[str]] = None,
wait_config: Optional[WaitConfig] = None,
error_field: str = "statusReasons",
) -> None:
"""Poll until a resource is deleted (raises not-found exception).

Args:
poll_fn: Zero-arg callable that calls the get API.
not_found_code: The error code indicating the resource is gone.
failed: Optional set of statuses that indicate deletion failed.
wait_config: Optional WaitConfig for polling behavior.
error_field: Response field containing error details.

Raises:
RuntimeError: If the resource reaches a failed status.
TimeoutError: If the resource is not deleted within max_wait.
"""
from botocore.exceptions import ClientError

wait = wait_config or WaitConfig()
start_time = time.time()
while True:
try:
resp = poll_fn()
if failed:
status = resp.get("status")
if status in failed:
reason = resp.get(error_field, "Unknown")
raise RuntimeError("Reached %s: %s" % (status, reason))
except ClientError as e:
if e.response["Error"]["Code"] == not_found_code:
return
raise
if time.time() - start_time >= wait.max_wait:
break
time.sleep(wait.poll_interval)
raise TimeoutError("Resource was not deleted within %d seconds" % wait.max_wait)
5 changes: 5 additions & 0 deletions src/bedrock_agentcore/_utils/snake_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return method(*args, **converted)

return wrapper


def convert_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Convert snake_case kwargs to camelCase for direct boto3 calls."""
return {snake_to_camel(k): v for k, v in kwargs.items()}
114 changes: 114 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Tests for shared _utils: pagination and polling."""

from unittest.mock import Mock, patch

import pytest
from botocore.exceptions import ClientError

from bedrock_agentcore._utils.polling import wait_until, wait_until_deleted


class TestWaitUntil:
def test_immediate_success(self):
poll_fn = Mock(return_value={"status": "ACTIVE"})
result = wait_until(poll_fn, "ACTIVE", {"FAILED"})
assert result["status"] == "ACTIVE"
poll_fn.assert_called_once()

@patch("bedrock_agentcore._utils.polling.time.sleep")
@patch(
"bedrock_agentcore._utils.polling.time.time",
side_effect=[0, 0, 0, 1, 1],
)
def test_polls_until_target(self, _mock_time, _mock_sleep):
poll_fn = Mock(
side_effect=[{"status": "CREATING"}, {"status": "ACTIVE"}],
)
result = wait_until(poll_fn, "ACTIVE", {"FAILED"})
assert result["status"] == "ACTIVE"
assert poll_fn.call_count == 2

def test_raises_on_failed_status(self):
poll_fn = Mock(
return_value={"status": "FAILED", "statusReasons": ["broke"]},
)
with pytest.raises(RuntimeError, match="FAILED"):
wait_until(poll_fn, "ACTIVE", {"FAILED"})

def test_custom_error_field(self):
poll_fn = Mock(
return_value={
"status": "CREATE_FAILED",
"failureReason": "bad config",
},
)
with pytest.raises(RuntimeError, match="bad config"):
wait_until(
poll_fn,
"ACTIVE",
{"CREATE_FAILED"},
error_field="failureReason",
)

@patch("bedrock_agentcore._utils.polling.time.sleep")
@patch(
"bedrock_agentcore._utils.polling.time.time",
side_effect=[0, 0, 0, 301],
)
def test_timeout(self, _mock_time, _mock_sleep):
poll_fn = Mock(return_value={"status": "CREATING"})
with pytest.raises(TimeoutError):
wait_until(poll_fn, "ACTIVE", {"FAILED"})


class TestWaitUntilDeleted:
def test_immediate_not_found(self):
poll_fn = Mock(
side_effect=ClientError(
{"Error": {"Code": "ResourceNotFoundException", "Message": ""}},
"Get",
),
)
wait_until_deleted(poll_fn)
poll_fn.assert_called_once()

@patch("bedrock_agentcore._utils.polling.time.sleep")
@patch(
"bedrock_agentcore._utils.polling.time.time",
side_effect=[0, 0, 0, 1, 1],
)
def test_polls_then_deleted(self, _mock_time, _mock_sleep):
poll_fn = Mock(
side_effect=[
{"status": "DELETING"},
ClientError(
{"Error": {"Code": "ResourceNotFoundException", "Message": ""}},
"Get",
),
],
)
wait_until_deleted(poll_fn)
assert poll_fn.call_count == 2

def test_raises_on_failed_status(self):
poll_fn = Mock(
return_value={
"status": "DELETE_FAILED",
"statusReasons": ["stuck"],
},
)
with pytest.raises(RuntimeError, match="DELETE_FAILED"):
wait_until_deleted(
poll_fn,
failed={"DELETE_FAILED"},
)

@patch("bedrock_agentcore._utils.polling.time.sleep")
@patch(
"bedrock_agentcore._utils.polling.time.time",
side_effect=[0, 0, 0, 301],
)
def test_timeout(self, _mock_time, _mock_sleep):
poll_fn = Mock(return_value={"status": "DELETING"})
with pytest.raises(TimeoutError):
wait_until_deleted(poll_fn)
Loading