Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def get_messages(
*,
run_ids: Sequence[int] | None = None,
is_reply: bool | None = None,
is_retrieved: bool | None = False,
limit: int | None = None,
) -> Sequence[Message]:
"""Retrieve messages based on the specified filters."""
Expand All @@ -90,9 +91,10 @@ def get_messages(
entry = self.msg_store[object_id]
message = entry.message

# Skip messages that have already been retrieved
if entry.is_retrieved:
continue
# Filter by retrieved status if specified
if is_retrieved is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about handling this as a boolean? so having an intermediate var that's False if the input is_retrieved is None or False?

if entry.is_retrieved != is_retrieved:
continue

# Skip messages whose run_id doesn't match the filter
if run_ids is not None:
Expand Down
9 changes: 4 additions & 5 deletions framework/py/flwr/supernode/nodestate/nodestate.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def get_messages(
*,
run_ids: Sequence[int] | None = None,
is_reply: bool | None = None,
is_retrieved: bool | None = False,
limit: int | None = None,
) -> Sequence[Message]:
"""Retrieve messages based on the specified filters.
Expand All @@ -70,18 +71,16 @@ def get_messages(
is_reply : Optional[bool] (default: None)
If True, filter for reply messages; if False, filter for non-reply
(instruction) messages.
is_retrieved : Optional[bool] (default: False)
If True, retrieve only messages that have already been retrieved.
If False, retrieve only messages that have not yet been retrieved.
limit : Optional[int] (default: None)
Maximum number of messages to return. If None, no limit is applied.

Returns
-------
Sequence[Message]
A sequence of messages matching the specified filters.

Notes
-----
**IMPORTANT:** Retrieved messages will **NOT** be returned again by subsequent
calls to this method, even if the filters match them.
"""

@abstractmethod
Expand Down
30 changes: 30 additions & 0 deletions framework/py/flwr/supernode/nodestate/nodestate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def test_store_and_get_message_basic(self) -> None:
({"run_ids": [2, 3]}, {"msg3", "msg4"}),
({"is_reply": True}, {"msg1", "msg4"}),
({"is_reply": True, "limit": 1}, {"msg1", "msg4"}),
({"is_retrieved": False}, {"msg1", "msg2", "msg3", "msg4"}),
({"is_retrieved": True}, set()),
({"run_ids": [1], "is_retrieved": False}, {"msg1", "msg2"}),
]
)
def test_get_message_with_filters(
Expand Down Expand Up @@ -159,6 +162,33 @@ def test_delete_message(self) -> None:
self.assertNotIn("msg1", msg_ids)
self.assertIn("msg2", msg_ids)

def test_get_messages_is_retrieved_filter(self) -> None:
"""Test retrieving messages based on is_retrieved status."""
# Prepare: store messages
self.state.store_message(make_dummy_message(1, False, "msg1"))
self.state.store_message(make_dummy_message(2, False, "msg2"))
self.state.store_message(make_dummy_message(3, False, "msg3"))

# Execute: retrieve msg1 and msg2
result1 = self.state.get_messages(run_ids=[1, 2], is_retrieved=False)
self.assertEqual(len(result1), 2)

# Assert: can retrieve already retrieved messages with is_retrieved=True
result2 = self.state.get_messages(is_retrieved=True)
result2_ids = {msg.metadata.message_id for msg in result2}
self.assertEqual(result2_ids, {"msg1", "msg2"})

# Assert: msg3 is still not retrieved
result3 = self.state.get_messages(is_retrieved=False)
result3_ids = {msg.metadata.message_id for msg in result3}
self.assertEqual(result3_ids, {"msg3"})

# Assert: can retrieve all messages with is_retrieved=None
result4 = self.state.get_messages(is_retrieved=None)
result4_ids = {msg.metadata.message_id for msg in result4}
# All three messages should be returned
self.assertEqual(result4_ids, {"msg1", "msg2", "msg3"})

def test_get_run_ids_with_pending_messages(self) -> None:
"""Test retrieving run IDs with pending messages."""
# Prepare: store messages for runs 1, 2, and 3
Expand Down
Loading