Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
3445d24
Add tests for blocking operation validation
kristjanvalur Mar 8, 2023
2be1bf4
Move validation into the task
kristjanvalur Apr 8, 2023
02772ee
Add some error/validation test cases
kristjanvalur Apr 9, 2023
7687c4e
Use duck typing to detect an ExecutionResult/GraphQLExeucitonResult
kristjanvalur Apr 9, 2023
6198007
add an async context getter for tests which is easily patchable.
kristjanvalur Apr 9, 2023
9ef3bf0
Add tests to ensure context_getter does not block connection
kristjanvalur Apr 9, 2023
7cf537e
Move context getter and root getter into worker task
kristjanvalur Apr 9, 2023
e41ffe9
Catch top level errors
kristjanvalur Jun 8, 2023
c3d6447
Add a test for the task error handler
kristjanvalur Jun 9, 2023
a9589e4
add release.md
kristjanvalur May 9, 2023
ca229a7
Remove dead code, fix coverage
kristjanvalur Jun 30, 2023
1a62b1a
remove special case for AsyncMock
kristjanvalur Aug 3, 2023
76716d0
Add "no cover" to schema code which is designed to not be hit.
kristjanvalur Nov 8, 2023
1ad929b
Update tests for litestar
kristjanvalur Mar 31, 2024
8ced4e4
Litestar integration must be excluded from long test, like Starlite.
kristjanvalur Mar 31, 2024
35a8e68
coverage
kristjanvalur Mar 31, 2024
4084639
Mark some test schema methods as no cover since they are not always used
kristjanvalur Apr 2, 2024
d17a4d4
Mypy support for SubscriptionExecutionResult
kristjanvalur Sep 7, 2024
a1d0695
ruff
kristjanvalur Sep 7, 2024
e43aca8
Remove unused method for coverage
kristjanvalur Sep 8, 2024
3d97deb
Merge branch 'main' into kristjan/validate-in-task
kristjanvalur Oct 13, 2024
1be5a06
revert the handler to original state
kristjanvalur Oct 13, 2024
c074e09
Remove tests for long contexts
kristjanvalur Oct 13, 2024
373400f
Revert "add an async context getter for tests which is easily patchab…
kristjanvalur Oct 13, 2024
c4d0b05
cleanup
kristjanvalur Oct 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

Fix error handling for query operations over graphql-transport-ws
25 changes: 25 additions & 0 deletions tests/views/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ def has_permission(self, source: Any, info: strawberry.Info, **kwargs: Any) -> b
return False


class ConditionalFailPermission(BasePermission):
@property
def message(self):
return f"failed after sleep {self.sleep}"

async def has_permission(self, source, info, **kwargs: Any) -> bool:
self.sleep = kwargs.get("sleep", None)
self.fail = kwargs.get("fail", True)
if self.sleep is not None:
await asyncio.sleep(kwargs["sleep"])
return not self.fail


class MyExtension(SchemaExtension):
def get_results(self) -> Dict[str, str]:
return {"example": "example"}
Expand Down Expand Up @@ -80,6 +93,12 @@ async def async_hello(self, name: Optional[str] = None, delay: float = 0) -> str
def always_fail(self) -> Optional[str]:
return "Hey"

@strawberry.field(permission_classes=[ConditionalFailPermission])
def conditional_fail(
self, sleep: Optional[float] = None, fail: bool = False
) -> str:
return "Hey"

@strawberry.field
async def error(self, message: str) -> AsyncGenerator[str, None]:
yield GraphQLError(message) # type: ignore
Expand Down Expand Up @@ -262,6 +281,12 @@ async def long_finalizer(
finally:
await asyncio.sleep(delay)

@strawberry.subscription(permission_classes=[ConditionalFailPermission])
async def conditional_fail(
self, sleep: Optional[float] = None, fail: bool = False
) -> AsyncGenerator[str, None]:
yield "Hey" # pragma: no cover


class Schema(strawberry.Schema):
def process_errors(
Expand Down
159 changes: 159 additions & 0 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from pytest_mock import MockerFixture

from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import (
BaseGraphQLTransportWSHandler,
)
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
CompleteMessage,
ConnectionAckMessage,
Expand Down Expand Up @@ -437,6 +440,28 @@ async def test_subscription_field_errors(ws: WebSocketClient):
process_errors.assert_called_once()


async def test_query_field_errors(ws: WebSocketClient):
await ws.send_json(
SubscribeMessage(
id="sub1",
payload=SubscribeMessagePayload(
query="query { notASubscriptionField }",
),
).as_dict()
)

response = await ws.receive_json()
assert response["type"] == ErrorMessage.type
assert response["id"] == "sub1"
assert len(response["payload"]) == 1
assert response["payload"][0].get("path") is None
assert response["payload"][0]["locations"] == [{"line": 1, "column": 9}]
assert (
response["payload"][0]["message"]
== "Cannot query field 'notASubscriptionField' on type 'Query'."
)


async def test_subscription_cancellation(ws: WebSocketClient):
await ws.send_json(
SubscribeMessage(
Expand Down Expand Up @@ -963,3 +988,137 @@ async def test_no_extensions_results_wont_send_extensions_in_payload(
mock.assert_called_once()
assert_next(response, "sub1", {"echo": "Hi"})
assert "extensions" not in response["payload"]


async def test_validation_query(ws: WebSocketClient):
"""
Test validation for query
"""
await ws.send_json(
SubscribeMessage(
id="sub1",
payload=SubscribeMessagePayload(
query="query { conditionalFail(fail:true) }"
),
Comment on lines +993 to +1002
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (testing): Inconsistency in error handling between queries and subscriptions

As noted in the PR description, there's a discrepancy in how validation errors are handled for queries vs. subscriptions. This test is failing because it expects an ErrorMessage, but receives a NextMessage instead. Consider updating the implementation to handle validation errors consistently across queries and subscriptions.

).as_dict()
)

# We expect an error message directly
response = await ws.receive_json()
assert response["type"] == ErrorMessage.type
assert response["id"] == "sub1"
assert len(response["payload"]) == 1
assert response["payload"][0].get("path") == ["conditionalFail"]
assert response["payload"][0]["message"] == "failed after sleep None"


async def test_validation_subscription(ws: WebSocketClient):
"""
Test validation for subscription
"""
await ws.send_json(
SubscribeMessage(
id="sub1",
payload=SubscribeMessagePayload(
query="subscription { conditionalFail(fail:true) }"
),
).as_dict()
)

# We expect an error message directly
response = await ws.receive_json()
assert response["type"] == ErrorMessage.type
assert response["id"] == "sub1"
assert len(response["payload"]) == 1
assert response["payload"][0].get("path") == ["conditionalFail"]
assert response["payload"][0]["message"] == "failed after sleep None"


async def test_long_validation_concurrent_query(ws: WebSocketClient):
"""
Test that the websocket is not blocked while validating a
single-result-operation
"""
await ws.send_json(
SubscribeMessage(
id="sub1",
payload=SubscribeMessagePayload(
query="query { conditionalFail(sleep:0.1) }"
),
).as_dict()
)
await ws.send_json(
SubscribeMessage(
id="sub2",
payload=SubscribeMessagePayload(
query="query { conditionalFail(fail:false) }"
),
).as_dict()
)

# we expect the second query to arrive first, because the
# first query is stuck in validation
response = await ws.receive_json()
assert_next(response, "sub2", {"conditionalFail": "Hey"})


async def test_long_validation_concurrent_subscription(ws: WebSocketClient):
"""
Test that the websocket is not blocked while validating a
subscription
"""
await ws.send_json(
SubscribeMessage(
id="sub1",
payload=SubscribeMessagePayload(
query="subscription { conditionalFail(sleep:0.1) }"
),
).as_dict()
)
await ws.send_json(
SubscribeMessage(
id="sub2",
payload=SubscribeMessagePayload(
query="query { conditionalFail(fail:false) }"
),
).as_dict()
)

# we expect the second query to arrive first, because the
# first operation is stuck in validation
response = await ws.receive_json()
assert_next(response, "sub2", {"conditionalFail": "Hey"})


async def test_task_error_handler(ws: WebSocketClient):
"""
Test that error handling works
"""
# can't use a simple Event here, because the handler may run
# on a different thread
wakeup = False

# a replacement method which causes an error in th eTask
async def op(*args: Any, **kwargs: Any):
nonlocal wakeup
wakeup = True
raise ZeroDivisionError("test")

with patch.object(BaseGraphQLTransportWSHandler, "task_logger") as logger:
with patch.object(BaseGraphQLTransportWSHandler, "handle_operation", op):
# send any old subscription request. It will raise an error
await ws.send_json(
SubscribeMessage(
id="sub1",
payload=SubscribeMessagePayload(
query="subscription { conditionalFail(sleep:0) }"
),
).as_dict()
)

# wait for the error to be logged. Must use timed loop and not event.
while not wakeup: # noqa: ASYNC110
await asyncio.sleep(0.01)
# and another little bit, for the thread to finish
await asyncio.sleep(0.01)
assert logger.exception.called