diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index a767be42b08e..6831d558d896 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -213,7 +213,7 @@ class ChartPostSchema(Schema):
     query_context = fields.String(
         metadata={"description": query_context_description},
         allow_none=True,
-        validate=utils.validate_json,
+        validate=utils.validate_query_context_metadata,
     )
     query_context_generation = fields.Boolean(
         metadata={"description": query_context_generation_description}, allow_none=True
@@ -274,7 +274,9 @@ class ChartPutSchema(Schema):
         validate=utils.validate_json,
     )
     query_context = fields.String(
-        metadata={"description": query_context_description}, allow_none=True
+        metadata={"description": query_context_description},
+        allow_none=True,
+        validate=utils.validate_query_context_metadata,
     )
     query_context_generation = fields.Boolean(
         metadata={"description": query_context_generation_description}, allow_none=True
diff --git a/superset/utils/schema.py b/superset/utils/schema.py
index 8994b87c1d9c..eb5631e12f73 100644
--- a/superset/utils/schema.py
+++ b/superset/utils/schema.py
@@ -50,4 +50,37 @@ def validate_json(value: Union[bytes, bytearray, str]) -> None:
     try:
         json.validate_json(value)
     except json.JSONDecodeError as ex:
-        raise ValidationError("JSON not valid") from ex
+        error_msg = "JSON not valid"
+        raise ValidationError(error_msg) from ex
+
+
+def validate_query_context_metadata(value: Union[bytes, bytearray, str, None]) -> None:
+    """
+    Validator for query_context field to ensure it contains required metadata.
+
+    Validates that the query_context JSON contains the required 'datasource' and
+    'queries' fields needed for chart data retrieval.
+
+    :raises ValidationError: if value is not valid JSON or missing required fields
+    :param value: a JSON string that should contain datasource and queries metadata
+    """
+    if value is None or value == "":
+        return  # Allow None values and empty strings
+
+    # Reuse existing JSON validation logic
+    validate_json(value)
+
+    # Parse and validate the structure
+    parsed_data = json.loads(value)
+
+    # Validate required fields exist in the query_context
+    if not isinstance(parsed_data, dict):
+        error_msg = "Query context must be a valid JSON object"
+        raise ValidationError(error_msg)
+
+    # When query_context is provided (not None), validate it has required fields
+    required_fields = {"datasource", "queries"}
+    missing_fields: set[str] = required_fields - parsed_data.keys()
+    if missing_fields:
+        fields_str = ", ".join(sorted(missing_fields))
+        raise ValidationError(f"Query context is missing required fields: {fields_str}")
diff --git a/tests/unit_tests/charts/test_schemas.py b/tests/unit_tests/charts/test_schemas.py
index 5466a0deadd9..404c35291393 100644
--- a/tests/unit_tests/charts/test_schemas.py
+++ b/tests/unit_tests/charts/test_schemas.py
@@ -22,8 +22,11 @@
 from superset.charts.schemas import (
     ChartDataProphetOptionsSchema,
     ChartDataQueryObjectSchema,
+    ChartPostSchema,
+    ChartPutSchema,
     get_time_grain_choices,
 )
+from superset.utils import json
 
 
 def test_get_time_grain_choices(app_context: None) -> None:
@@ -152,3 +155,120 @@ def test_time_grain_validation_with_config_addons(app_context: None) -> None:
     }
     result = schema.load(custom_data)
     assert result["time_grain"] == "PT10M"
+
+
+def test_chart_post_schema_query_context_validation(app_context: None) -> None:
+    """Test that ChartPostSchema validates query_context contains required metadata"""
+    schema = ChartPostSchema()
+
+    # Valid query_context with datasource and queries should pass
+    valid_query_context = json.dumps(
+        {
+            "datasource": {"type": "table", "id": 1},
+            "queries": [{"metrics": ["count"], "columns": []}],
+        }
+    )
+    valid_data = {
+        "slice_name": "Test Chart",
+        "datasource_id": 1,
+        "datasource_type": "table",
+        "query_context": valid_query_context,
+    }
+    result = schema.load(valid_data)
+    assert result["query_context"] == valid_query_context
+
+    # None query_context should be allowed (allow_none=True)
+    none_data = {
+        "slice_name": "Test Chart",
+        "datasource_id": 1,
+        "datasource_type": "table",
+        "query_context": None,
+    }
+    result = schema.load(none_data)
+    assert result["query_context"] is None
+
+    # Query context missing 'datasource' field should fail
+    missing_datasource = json.dumps(
+        {"queries": [{"metrics": ["count"], "columns": []}]}
+    )
+    invalid_data_1 = {
+        "slice_name": "Test Chart",
+        "datasource_id": 1,
+        "datasource_type": "table",
+        "query_context": missing_datasource,
+    }
+    with pytest.raises(ValidationError) as exc_info:
+        schema.load(invalid_data_1)
+    assert "query_context" in exc_info.value.messages
+    assert "datasource" in str(exc_info.value.messages["query_context"])
+
+    # Query context missing 'queries' field should fail
+    missing_queries = json.dumps({"datasource": {"type": "table", "id": 1}})
+    invalid_data_2 = {
+        "slice_name": "Test Chart",
+        "datasource_id": 1,
+        "datasource_type": "table",
+        "query_context": missing_queries,
+    }
+    with pytest.raises(ValidationError) as exc_info:
+        schema.load(invalid_data_2)
+    assert "query_context" in exc_info.value.messages
+    assert "queries" in str(exc_info.value.messages["query_context"])
+
+    # Query context missing both 'datasource' and 'queries' should fail
+    empty_query_context = json.dumps({})
+    invalid_data_3 = {
+        "slice_name": "Test Chart",
+        "datasource_id": 1,
+        "datasource_type": "table",
+        "query_context": empty_query_context,
+    }
+    with pytest.raises(ValidationError) as exc_info:
+        schema.load(invalid_data_3)
+    assert "query_context" in exc_info.value.messages
+    assert "datasource" in str(exc_info.value.messages["query_context"])
+    assert "queries" in str(exc_info.value.messages["query_context"])
+
+    # Invalid JSON should fail
+    invalid_json = "not valid json"
+    invalid_data_4 = {
+        "slice_name": "Test Chart",
+        "datasource_id": 1,
+        "datasource_type": "table",
+        "query_context": invalid_json,
+    }
+    with pytest.raises(ValidationError) as exc_info:
+        schema.load(invalid_data_4)
+    assert "query_context" in exc_info.value.messages
+
+
+def test_chart_put_schema_query_context_validation(app_context: None) -> None:
+    """Test that ChartPutSchema validates query_context contains required metadata"""
+    schema = ChartPutSchema()
+
+    # Valid query_context with datasource and queries should pass
+    valid_query_context = json.dumps(
+        {
+            "datasource": {"type": "table", "id": 1},
+            "queries": [{"metrics": ["count"], "columns": []}],
+        }
+    )
+    valid_data = {
+        "slice_name": "Updated Chart",
+        "query_context": valid_query_context,
+    }
+    result = schema.load(valid_data)
+    assert result["query_context"] == valid_query_context
+
+    # Query context missing required fields should fail
+    missing_datasource = json.dumps(
+        {"queries": [{"metrics": ["count"], "columns": []}]}
+    )
+    invalid_data = {
+        "slice_name": "Updated Chart",
+        "query_context": missing_datasource,
+    }
+    with pytest.raises(ValidationError) as exc_info:
+        schema.load(invalid_data)
+    assert "query_context" in exc_info.value.messages
+    assert "datasource" in str(exc_info.value.messages["query_context"])
diff --git a/tests/unit_tests/utils/test_schema.py b/tests/unit_tests/utils/test_schema.py
new file mode 100644
index 000000000000..29f60b22abe9
--- /dev/null
+++ b/tests/unit_tests/utils/test_schema.py
@@ -0,0 +1,327 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Unit tests for schema validation utilities."""
+
+import pytest
+from marshmallow import ValidationError
+
+from superset.utils import json
+from superset.utils.schema import (
+    OneOfCaseInsensitive,
+    validate_json,
+    validate_query_context_metadata,
+)
+
+
+def test_validate_json_valid() -> None:
+    """Test validate_json with valid JSON string."""
+    valid_json = '{"key": "value", "number": 123}'
+    # Should not raise any exception
+    validate_json(valid_json)
+
+
+def test_validate_json_valid_bytes() -> None:
+    """Test validate_json with valid JSON bytes."""
+    valid_json = b'{"key": "value", "number": 123}'
+    # Should not raise any exception
+    validate_json(valid_json)
+
+
+def test_validate_json_invalid() -> None:
+    """Test validate_json with invalid JSON."""
+    invalid_json = '{"key": "value", "number": 123'
+    with pytest.raises(ValidationError) as exc_info:
+        validate_json(invalid_json)
+    assert "JSON not valid" in str(exc_info.value)
+
+
+def test_validate_json_empty_string() -> None:
+    """Test validate_json with empty string - empty strings are allowed."""
+    # Empty strings do not raise an error because validate_json has early return
+    validate_json("")  # Should not raise any exception
+
+
+def test_validate_json_whitespace_only() -> None:
+    """Test validate_json with whitespace-only string."""
+    with pytest.raises(ValidationError) as exc_info:
+        validate_json(" ")
+    assert "JSON not valid" in str(exc_info.value)
+
+
+def test_validate_json_not_json() -> None:
+    """Test validate_json with non-JSON string."""
+    not_json = "this is not json"
+    with pytest.raises(ValidationError) as exc_info:
+        validate_json(not_json)
+    assert "JSON not valid" in str(exc_info.value)
+
+
+def test_validate_query_context_metadata_none() -> None:
+    """Test validate_query_context_metadata allows None values."""
+    # Should not raise any exception for None
+    validate_query_context_metadata(None)
+
+
+def test_validate_query_context_metadata_valid() -> None:
+    """Test validate_query_context_metadata with valid query context."""
+    valid_query_context = json.dumps(
+        {
+            "datasource": {"type": "table", "id": 1},
+            "queries": [{"metrics": ["count"], "columns": []}],
+        },
+    )
+    # Should not raise any exception
+    validate_query_context_metadata(valid_query_context)
+
+
+def test_validate_query_context_metadata_valid_bytes() -> None:
+    """Test validate_query_context_metadata with valid query context as bytes."""
+    valid_query_context = json.dumps(
+        {
+            "datasource": {"type": "table", "id": 1},
+            "queries": [{"metrics": ["count"], "columns": []}],
+        },
+    ).encode("utf-8")
+    # Should not raise any exception
+    validate_query_context_metadata(valid_query_context)
+
+
+def test_validate_query_context_metadata_invalid_json() -> None:
+    """Test validate_query_context_metadata with invalid JSON."""
+    invalid_json = '{"datasource": {"type": "table"'
+    with pytest.raises(ValidationError) as exc_info:
+        validate_query_context_metadata(invalid_json)
+    assert "JSON not valid" in str(exc_info.value)
+
+
+def test_validate_query_context_metadata_not_dict() -> None:
+    """Test validate_query_context_metadata with non-dict JSON."""
+    not_dict = json.dumps(["array", "values"])
+    with pytest.raises(ValidationError) as exc_info:
+        validate_query_context_metadata(not_dict)
+    assert "Query context must be a valid JSON object" in str(exc_info.value)
+
+
+def test_validate_query_context_metadata_missing_datasource() -> None:
+    """Test validate_query_context_metadata with missing datasource field."""
+    missing_datasource = json.dumps(
+        {"queries": [{"metrics": ["count"], "columns": []}]},
+    )
+    with pytest.raises(ValidationError) as exc_info:
+        validate_query_context_metadata(missing_datasource)
+    error_message = str(exc_info.value)
+    assert "Query context is missing required fields" in error_message
+    assert "datasource" in error_message
+
+
+def test_validate_query_context_metadata_missing_queries() -> None:
+    """Test validate_query_context_metadata with missing queries field."""
+    missing_queries = json.dumps({"datasource": {"type": "table", "id": 1}})
+    with pytest.raises(ValidationError) as exc_info:
+        validate_query_context_metadata(missing_queries)
+    error_message = str(exc_info.value)
+    assert "Query context is missing required fields" in error_message
+    assert "queries" in error_message
+
+
+def test_validate_query_context_metadata_missing_both_fields() -> None:
+    """Test validate_query_context_metadata with both required fields missing."""
+    empty_context = json.dumps({})
+    with pytest.raises(ValidationError) as exc_info:
+        validate_query_context_metadata(empty_context)
+    error_message = str(exc_info.value)
+    assert "Query context is missing required fields" in error_message
+    assert "datasource" in error_message
+    assert "queries" in error_message
+
+
+def test_validate_query_context_metadata_extra_fields() -> None:
+    """Test validate_query_context_metadata allows extra fields."""
+    context_with_extras = json.dumps(
+        {
+            "datasource": {"type": "table", "id": 1},
+            "queries": [{"metrics": ["count"], "columns": []}],
+            "extra_field": "extra_value",
+            "another_field": 123,
+        },
+    )
+    # Should not raise any exception - extra fields are allowed
+    validate_query_context_metadata(context_with_extras)
+
+
+def test_validate_query_context_metadata_empty_values() -> None:
+    """Test validate_query_context_metadata with empty but present values."""
+    context_with_empty = json.dumps(
+        {
+            "datasource": {},
+            "queries": [],
+        },
+    )
+    # Should not raise any exception - fields exist even if empty
+    validate_query_context_metadata(context_with_empty)
+
+
+def test_validate_query_context_metadata_null_datasource() -> None:
+    """Test validate_query_context_metadata with null datasource value."""
+    context_with_null = json.dumps(
+        {
+            "datasource": None,
+            "queries": [{"metrics": ["count"]}],
+        },
+    )
+    # Should not raise any exception - field exists even if null
+    validate_query_context_metadata(context_with_null)
+
+
+def test_validate_query_context_metadata_null_queries() -> None:
+    """Test validate_query_context_metadata with null queries value."""
+    context_with_null = json.dumps(
+        {
+            "datasource": {"type": "table", "id": 1},
+            "queries": None,
+        },
+    )
+    # Should not raise any exception - field exists even if null
+    validate_query_context_metadata(context_with_null)
+
+
+def test_validate_query_context_metadata_empty_string() -> None:
+    """Test validate_query_context_metadata with empty string."""
+    # Empty string should be treated as None and not raise error
+    validate_query_context_metadata("")
+
+
+def test_validate_query_context_metadata_whitespace() -> None:
+    """Test validate_query_context_metadata with whitespace-only string."""
+    with pytest.raises(ValidationError) as exc_info:
+        validate_query_context_metadata(" ")
+    assert "JSON not valid" in str(exc_info.value)
+
+
+def test_validate_query_context_metadata_string_value() -> None:
+    """Test validate_query_context_metadata with plain string instead of JSON object."""
+    plain_string = json.dumps("just a string")
+    with pytest.raises(ValidationError) as exc_info:
+        validate_query_context_metadata(plain_string)
+    assert "Query context must be a valid JSON object" in str(exc_info.value)
+
+
+def test_validate_query_context_metadata_number_value() -> None:
+    """Test validate_query_context_metadata with number instead of JSON object."""
+    number_value = json.dumps(12345)
+    with pytest.raises(ValidationError) as exc_info:
+        validate_query_context_metadata(number_value)
+    assert "Query context must be a valid JSON object" in str(exc_info.value)
+
+
+def test_validate_query_context_metadata_boolean_value() -> None:
+    """Test validate_query_context_metadata with boolean instead of JSON object."""
+    bool_value = json.dumps(obj=True)
+    with pytest.raises(ValidationError) as exc_info:
+        validate_query_context_metadata(bool_value)
+    assert "Query context must be a valid JSON object" in str(exc_info.value)
+
+
+def test_validate_query_context_metadata_complex_nested_structure() -> None:
+    """Test validate_query_context_metadata with complex nested structure."""
+    complex_context = json.dumps(
+        {
+            "datasource": {
+                "type": "table",
+                "id": 1,
+                "schema": "public",
+                "table_name": "my_table",
+                "columns": [
+                    {"name": "col1", "type": "VARCHAR"},
+                    {"name": "col2", "type": "INTEGER"},
+                ],
+            },
+            "queries": [
+                {
+                    "metrics": ["count", "sum", "avg"],
+                    "columns": ["col1", "col2"],
+                    "filters": [{"col": "col1", "op": "==", "val": "value"}],
+                    "orderby": [["col1", True]],
+                    "extras": {"where": "col1 IS NOT NULL"},
+                },
+            ],
+            "result_format": "json",
+            "result_type": "full",
+        },
+    )
+    # Should not raise any exception - has required fields plus extras
+    validate_query_context_metadata(complex_context)
+
+
+def test_one_of_case_insensitive_valid_lowercase() -> None:
+    """Test OneOfCaseInsensitive validator with lowercase value."""
+    validator = OneOfCaseInsensitive(["Option1", "Option2", "Option3"])
+    result = validator("option1")
+    assert result == "option1"
+
+
+def test_one_of_case_insensitive_valid_uppercase() -> None:
+    """Test OneOfCaseInsensitive validator with uppercase value."""
+    validator = OneOfCaseInsensitive(["Option1", "Option2", "Option3"])
+    result = validator("OPTION2")
+    assert result == "OPTION2"
+
+
+def test_one_of_case_insensitive_valid_mixed_case() -> None:
+    """Test OneOfCaseInsensitive validator with mixed case value."""
+    validator = OneOfCaseInsensitive(["Option1", "Option2", "Option3"])
+    result = validator("OpTiOn3")
+    assert result == "OpTiOn3"
+
+
+def test_one_of_case_insensitive_invalid() -> None:
+    """Test OneOfCaseInsensitive validator with invalid value."""
+    validator = OneOfCaseInsensitive(["Option1", "Option2", "Option3"])
+    with pytest.raises(ValidationError):
+        validator("invalid")
+
+
+def test_one_of_case_insensitive_non_string_valid() -> None:
+    """Test OneOfCaseInsensitive validator with non-string valid value."""
+    validator = OneOfCaseInsensitive([1, 2, 3])
+    result = validator(2)
+    assert result == 2
+
+
+def test_one_of_case_insensitive_non_string_invalid() -> None:
+    """Test OneOfCaseInsensitive validator with non-string invalid value."""
+    validator = OneOfCaseInsensitive([1, 2, 3])
+    with pytest.raises(ValidationError):
+        validator(4)
+
+
+def test_one_of_case_insensitive_mixed_types() -> None:
+    """Test OneOfCaseInsensitive validator with mixed types in choices."""
+    validator = OneOfCaseInsensitive(["Option1", 2, "Option3"])
+    result = validator("option1")
+    assert result == "option1"
+    result = validator(2)
+    assert result == 2
+
+
+def test_one_of_case_insensitive_type_error() -> None:
+    """Test OneOfCaseInsensitive validator with incomparable types."""
+    validator = OneOfCaseInsensitive(["Option1", "Option2"])
+    # Passing a dict or other non-comparable type should raise ValidationError
+    with pytest.raises(ValidationError):
+        validator({"key": "value"})