diff --git a/framework/py/flwr/supernode/nodestate/in_memory_nodestate.py b/framework/py/flwr/supernode/nodestate/in_memory_nodestate.py index 17f20bfbe1ca..d002420c5543 100644 --- a/framework/py/flwr/supernode/nodestate/in_memory_nodestate.py +++ b/framework/py/flwr/supernode/nodestate/in_memory_nodestate.py @@ -79,6 +79,7 @@ def get_messages( *, run_ids: Sequence[int] | None = None, is_reply: bool | None = None, + is_retrieved: bool | None = False, limit: int | None = None, ) -> Sequence[Message]: """Retrieve messages based on the specified filters.""" @@ -90,9 +91,10 @@ def get_messages( entry = self.msg_store[object_id] message = entry.message - # Skip messages that have already been retrieved - if entry.is_retrieved: - continue + # Filter by retrieved status if specified + if is_retrieved is not None: + if entry.is_retrieved != is_retrieved: + continue # Skip messages whose run_id doesn't match the filter if run_ids is not None: diff --git a/framework/py/flwr/supernode/nodestate/nodestate.py b/framework/py/flwr/supernode/nodestate/nodestate.py index 08d3e0ec6854..dcbc61b08903 100644 --- a/framework/py/flwr/supernode/nodestate/nodestate.py +++ b/framework/py/flwr/supernode/nodestate/nodestate.py @@ -55,6 +55,7 @@ def get_messages( *, run_ids: Sequence[int] | None = None, is_reply: bool | None = None, + is_retrieved: bool | None = False, limit: int | None = None, ) -> Sequence[Message]: """Retrieve messages based on the specified filters. @@ -70,6 +71,9 @@ def get_messages( is_reply : Optional[bool] (default: None) If True, filter for reply messages; if False, filter for non-reply (instruction) messages. + is_retrieved : Optional[bool] (default: False) + If True, retrieve only messages that have already been retrieved. + If False, retrieve only messages that have not yet been retrieved. limit : Optional[int] (default: None) Maximum number of messages to return. If None, no limit is applied. @@ -77,11 +81,6 @@ def get_messages( ------- Sequence[Message] A sequence of messages matching the specified filters. - - Notes - ----- - **IMPORTANT:** Retrieved messages will **NOT** be returned again by subsequent - calls to this method, even if the filters match them. """ @abstractmethod diff --git a/framework/py/flwr/supernode/nodestate/nodestate_test.py b/framework/py/flwr/supernode/nodestate/nodestate_test.py index 1d1c66dd38e3..bc8850cc364b 100644 --- a/framework/py/flwr/supernode/nodestate/nodestate_test.py +++ b/framework/py/flwr/supernode/nodestate/nodestate_test.py @@ -116,6 +116,9 @@ def test_store_and_get_message_basic(self) -> None: ({"run_ids": [2, 3]}, {"msg3", "msg4"}), ({"is_reply": True}, {"msg1", "msg4"}), ({"is_reply": True, "limit": 1}, {"msg1", "msg4"}), + ({"is_retrieved": False}, {"msg1", "msg2", "msg3", "msg4"}), + ({"is_retrieved": True}, set()), + ({"run_ids": [1], "is_retrieved": False}, {"msg1", "msg2"}), ] ) def test_get_message_with_filters( @@ -159,6 +162,33 @@ def test_delete_message(self) -> None: self.assertNotIn("msg1", msg_ids) self.assertIn("msg2", msg_ids) + def test_get_messages_is_retrieved_filter(self) -> None: + """Test retrieving messages based on is_retrieved status.""" + # Prepare: store messages + self.state.store_message(make_dummy_message(1, False, "msg1")) + self.state.store_message(make_dummy_message(2, False, "msg2")) + self.state.store_message(make_dummy_message(3, False, "msg3")) + + # Execute: retrieve msg1 and msg2 + result1 = self.state.get_messages(run_ids=[1, 2], is_retrieved=False) + self.assertEqual(len(result1), 2) + + # Assert: can retrieve already retrieved messages with is_retrieved=True + result2 = self.state.get_messages(is_retrieved=True) + result2_ids = {msg.metadata.message_id for msg in result2} + self.assertEqual(result2_ids, {"msg1", "msg2"}) + + # Assert: msg3 is still not retrieved + result3 = self.state.get_messages(is_retrieved=False) + result3_ids = {msg.metadata.message_id for msg in result3} + self.assertEqual(result3_ids, {"msg3"}) + + # Assert: can retrieve all messages with is_retrieved=None + result4 = self.state.get_messages(is_retrieved=None) + result4_ids = {msg.metadata.message_id for msg in result4} + # All three messages should be returned + self.assertEqual(result4_ids, {"msg1", "msg2", "msg3"}) + def test_get_run_ids_with_pending_messages(self) -> None: """Test retrieving run IDs with pending messages.""" # Prepare: store messages for runs 1, 2, and 3