Rob's fork diff #64
base: main
Conversation
Reviewer's Guide

This PR refactors how partial payloads are stored and reconciled, unifies the ModelMesh and KServe flows under a single persistence interface, and adds a new data upload endpoint that orchestrates KServe request/response ingestion into TrustyAI storage.

Sequence diagram for the new data upload endpoint flow

```mermaid
sequenceDiagram
actor User
participant API as "Upload Endpoint"
participant Consumer as "consume_cloud_event()"
participant Storage as "StorageInterface"
participant ModelData
User->>API: POST /data/upload (UploadPayload)
API->>ModelData: Check if datasets exist
ModelData->>Storage: dataset_exists()
API->>Consumer: consume_cloud_event(response, req_id)
Consumer->>Storage: persist_partial_payload(response, req_id, is_input=False)
API->>Consumer: consume_cloud_event(request, req_id, tag)
Consumer->>Storage: persist_partial_payload(request, req_id, is_input=True)
API->>ModelData: Get new row counts
ModelData->>Storage: dataset_rows()
API-->>User: Return success with number of new datapoints
```

Class diagram for unified payload types and storage interface

```mermaid
classDiagram
class PartialPayloadId {
+Optional[str] prediction_id
+Optional[PartialKind] kind
+get_prediction_id()
+set_prediction_id(id: str)
+get_kind()
+set_kind(kind: PartialKind)
}
class InferencePartialPayload {
+Optional[PartialPayloadId] partialPayloadId
+Optional[Dict[str, str]] metadata
+Optional[str] data
+Optional[str] modelid
+get_id()
+set_id(id: str)
+get_kind()
+set_kind(kind: PartialKind)
+get_model_id()
+set_model_id(model_id: str)
}
class KServeData {
+str name
+List[int] shape
+str datatype
+Optional[Dict[str, str]] parameters
+List data
}
class KServeInferenceRequest {
+Optional[str] id
+Optional[Dict[str, str]] parameters
+List[KServeData] inputs
+Optional[List[KServeData]] outputs
}
class KServeInferenceResponse {
+str model_name
+Optional[str] model_version
+Optional[str] id
+Optional[Dict[str, str]] parameters
+List[KServeData] outputs
}
class StorageInterface {
+async dataset_exists(dataset_name: str)
+async dataset_rows(dataset_name: str)
+async dataset_shape(dataset_name: str)
+async write_data(dataset_name: str, new_rows, column_names: List[str])
+async read_data(dataset_name: str, start_row: int, n_rows: int)
+async get_original_column_names(dataset_name: str)
+async get_aliased_column_names(dataset_name: str)
+async apply_name_mapping(dataset_name: str, name_mapping: Dict[str, str])
+async delete_dataset(dataset_name: str)
+async persist_partial_payload(payload, payload_id, is_input: bool)
+async get_partial_payload(payload_id: str, is_input: bool, is_modelmesh: bool)
+async delete_partial_payload(payload_id: str, is_input: bool)
}
InferencePartialPayload --> PartialPayloadId
KServeInferenceRequest --> KServeData
KServeInferenceResponse --> KServeData
```
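To make the new endpoint concrete, here is an example request body in the shape the PR's tests construct (field values are illustrative; the structure follows the `UploadPayload`/`KServeData` classes above):

```python
# POST /data/upload — example body, mirroring the structure used in the PR's tests
upload_payload = {
    "model_name": "credit-model",   # illustrative model name
    "data_tag": "TRAINING",         # optional tag, checked by validate_data_tag()
    "is_ground_truth": False,
    "request": {
        "inputs": [
            {"name": "input", "shape": [2, 3], "datatype": "INT64",
             "data": [[0, 1, 2], [1, 2, 3]]}
        ]
    },
    "response": {
        "outputs": [
            {"name": "output", "shape": [2, 1], "datatype": "INT64",
             "data": [[0], [2]]}
        ]
    },
}
```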
Hey there - I've reviewed your changes - here's some feedback:
Blocking issues:
- Avoid SQL string concatenation: untrusted input concatenated into a raw SQL query can result in SQL injection. To execute raw queries safely, use prepared statements; SQLAlchemy provides TextualSQL for prepared statements with named parameters. For complex SQL composition, use the SQL Expression Language or Schema Definition Language; in most cases, the SQLAlchemy ORM is the better option. (Reported at nine locations; see Comments 6–14 below.)
General comments:
- In write_reconciled_data, wrap the storage writes and subsequent delete_partial_payload calls in a try/finally so that partial payload cleanup always runs even if a write fails.
- The get_partial_payload method takes an is_modelmesh boolean that leads to type-based branching; consider replacing this flag with separate methods or an enum to make the API clearer and reduce branching logic.
- The validate_data_tag function returns None or an error message string, which can be confusing; consider having it return a boolean or directly raise an exception to simplify its usage in the upload endpoint.
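For the first two comments, here is a minimal sketch of what the suggested shape could look like — the method signatures, the `PayloadSource` enum, and the surrounding storage calls are illustrative assumptions, not the repository's actual code:

```python
from enum import Enum


class PayloadSource(Enum):
    """Hypothetical replacement for the is_modelmesh boolean flag."""
    MODELMESH = "modelmesh"
    KSERVE = "kserve"


async def write_reconciled_data(self, dataset_name, rows, column_names, payload_id):
    """Sketch: partial-payload cleanup always runs, even if a write fails."""
    try:
        await self.storage.write_data(dataset_name, rows, column_names)
    finally:
        # Runs whether or not write_data raised, so stale partials never linger.
        await self.storage.delete_partial_payload(payload_id, is_input=True)
        await self.storage.delete_partial_payload(payload_id, is_input=False)


async def get_partial_payload(self, payload_id: str, is_input: bool, source: PayloadSource):
    """Sketch: dispatch on an explicit enum rather than an is_modelmesh boolean."""
    ...
```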
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- In write_reconciled_data, wrap the storage writes and subsequent delete_partial_payload calls in a try/finally so that partial payload cleanup always runs even if a write fails.
- The get_partial_payload method takes an is_modelmesh boolean that leads to type-based branching; consider replacing this flag with separate methods or an enum to make the API clearer and reduce branching logic.
- The validate_data_tag function returns None or an error message string, which can be confusing; consider having it return a boolean or directly raise an exception to simplify its usage in the upload endpoint.
## Individual Comments
### Comment 1
<location> `src/service/data/storage/maria/maria.py:105-107` </location>
<code_context>
```python
#=== DATASET QUERYING ==========================================================================
-    def dataset_exists(self, dataset_name: str) -> bool:
+    async def dataset_exists(self, dataset_name: str) -> bool:
        """
        Check if a dataset exists within the TrustyAI model data.
```
</code_context>
<issue_to_address>
**suggestion (performance):** Changed dataset_exists to async, but underlying DB call is still sync.
Refactor the DB access in dataset_exists to be asynchronous to prevent blocking the event loop.
Suggested implementation:
```python
async def dataset_exists(self, dataset_name: str) -> bool:
    """
    Check if a dataset exists within the TrustyAI model data.
    """
    # Assuming self._pool is an aiomysql pool
    async with self._pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT COUNT(*) FROM dataset_reference_table WHERE name=%s",
                (dataset_name,),
            )
            result = await cur.fetchone()
            return result[0] > 0
```
- If your MariaDB connection pool is not already async (e.g., using aiomysql), you will need to refactor the rest of your DB access code and initialize `self._pool` as an async pool.
- If you use a different async DB library, adjust the connection/cursor/execute logic accordingly.
- Ensure that any calling code is prepared to await this function.
</issue_to_address>
### Comment 2
<location> `src/service/data/storage/maria/maria.py:144-147` </location>
<code_context>
"""
- table_name = self._get_clean_table_name(dataset_name)
+ table_name = await self._get_clean_table_name(dataset_name)
with self.connection_manager as (conn, cursor):
cursor.execute(f"SHOW COLUMNS FROM {table_name}")
return len(cursor.fetchall()) - 1
</code_context>
<issue_to_address>
**suggestion (performance):** Database operations remain synchronous within async methods.
Using synchronous DB calls in async methods can block the event loop and impact performance. Please switch to an async DB driver or offload these operations to a thread pool.
```suggestion
table_name = await self._get_clean_table_name(dataset_name)

def get_column_count():
    with self.connection_manager as (conn, cursor):
        cursor.execute(f"SHOW COLUMNS FROM {table_name}")
        return len(cursor.fetchall()) - 1

# Offload the blocking DB call to a worker thread (import asyncio at module level).
return await asyncio.to_thread(get_column_count)
```
</issue_to_address>
### Comment 3
<location> `src/service/data/storage/pvc.py:25` </location>
<code_context>
- f"{PROTECTED_DATASET_SUFFIX}modelmesh_partial_payloads_outputs"
-)
-
+MAX_VOID_TYPE_LENGTH=1024
class H5PYContext:
</code_context>
<issue_to_address>
**suggestion:** MAX_VOID_TYPE_LENGTH is set to 1024, but hardcoded values (e.g., V400) are still used elsewhere.
Please update all instances of hardcoded values like "V400" to use MAX_VOID_TYPE_LENGTH for consistency.
Suggested implementation:
```python
if new_rows.dtype.itemsize > MAX_VOID_TYPE_LENGTH:
    raise ValueError(
        f"The datatype of the array to be serialized is {new_rows.dtype} - the largest serializable void type is V{MAX_VOID_TYPE_LENGTH}"
    )
new_rows = new_rows.astype(f"V{MAX_VOID_TYPE_LENGTH}")  # use constant for consistency
```
If there are other places in the file (or codebase) where hardcoded void type lengths like "V400", "V512", etc. are used, you should update those to use f"V{MAX_VOID_TYPE_LENGTH}" as well. This ensures consistency and makes future changes easier.
</issue_to_address>
### Comment 4
<location> `tests/service/test_consumer_endpoint_reconciliation.py:95-97` </location>
<code_context>
```python
async def _test_consume_input_payload(self):
    """Test consuming an input payload."""
-   self.mock_storage.persist_modelmesh_payload = mock.AsyncMock()
-   self.mock_storage.get_modelmesh_payload = mock.AsyncMock(return_value=None)
+   self.mock_storage.persist_partial_payload = mock.AsyncMock()
+   self.mock_storage.get_partial_payload = mock.AsyncMock(return_value=None)
    self.mock_parse_input.return_value = True
```
</code_context>
<issue_to_address>
**suggestion (testing):** Mocking updated to new storage interface in consumer endpoint tests.
Also, add tests to cover error scenarios, such as exceptions or malformed data returned by the storage interface.
Suggested implementation:
```python
async def _test_consume_input_payload(self):
    """Test consuming an input payload."""
    self.mock_storage.persist_partial_payload = mock.AsyncMock()
    self.mock_storage.get_partial_payload = mock.AsyncMock(return_value=None)
    self.mock_parse_input.return_value = True
    self.mock_parse_output.side_effect = ValueError("Not an output payload")
    response = self.client.post("/consumer/kserve/v2", json=inference_payload)
    print(response.text)
    self.assertEqual(response.status_code, 200)
    # ... (remaining assertions truncated in the original suggestion)

async def test_consume_input_payload_storage_exception(self):
    """Test error handling when storage persist_partial_payload raises an exception."""
    self.mock_storage.persist_partial_payload = mock.AsyncMock(side_effect=Exception("Storage error"))
    self.mock_storage.get_partial_payload = mock.AsyncMock(return_value=None)
    self.mock_parse_input.return_value = True
    response = self.client.post("/consumer/kserve/v2", json=inference_payload)
    print(response.text)
    self.assertEqual(response.status_code, 500)
    self.assertIn("Storage error", response.text)

async def test_consume_input_payload_malformed_data(self):
    """Test error handling when storage get_partial_payload returns malformed data."""
    self.mock_storage.persist_partial_payload = mock.AsyncMock()
    self.mock_storage.get_partial_payload = mock.AsyncMock(return_value="not-a-dict")
    self.mock_parse_input.return_value = True
    response = self.client.post("/consumer/kserve/v2", json=inference_payload)
    print(response.text)
    self.assertEqual(response.status_code, 400)
    self.assertIn("Malformed data", response.text)
```
- You may need to adjust the error assertions (`status_code`, error message) to match your actual error handling in the endpoint.
- Ensure that your endpoint raises the correct exceptions and returns the expected status codes/messages for these error scenarios.
- If `inference_payload` is not defined in the test class, you should define it or import it as needed.
</issue_to_address>
### Comment 5
<location> `tests/service/data/test_mariadb_storage.py:5-9` </location>
<code_context>
```python
    start_idx = dataset_idx
    n_rows = dataset_idx * 2
-   retrieved_full_dataset = self.storage.read_data(dataset_name)
-   retrieved_partial_dataset = self.storage.read_data(dataset_name, start_idx, n_rows)
+   retrieved_full_dataset = await self.storage.read_data(dataset_name)
+   retrieved_partial_dataset = await self.storage.read_data(dataset_name, start_idx, n_rows)
```
</code_context>
<issue_to_address>
**suggestion (testing):** Tests updated for async storage interface methods.
Also, add tests that cover async error scenarios, such as timeouts and task cancellations.
```suggestion
import asyncio
import unittest

import numpy as np

from src.service.data.storage.maria.maria import MariaDBStorage


class TestMariaDBStorageAsyncErrors(unittest.IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        self.storage = MariaDBStorage(
            host="localhost",
            port=3306,
            database="trustyai-database",
            attempt_migration=False
        )
        self.dataset_name = "test_dataset"
        await self.storage.reset_database()
        await self.storage.write_data(self.dataset_name, np.random.rand(10, 5))

    async def test_read_data_timeout(self):
        # Simulate a timeout by setting a very short timeout
        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(self.storage.read_data(self.dataset_name), timeout=0.0001)

    async def test_read_data_cancelled(self):
        # Simulate cancellation of the task
        task = asyncio.create_task(self.storage.read_data(self.dataset_name))
        await asyncio.sleep(0.0001)
        task.cancel()
        with self.assertRaises(asyncio.CancelledError):
            await task

    async def asyncTearDown(self):
        await self.storage.reset_database()
```
</issue_to_address>
### Comment 6
<location> `src/service/data/storage/maria/maria.py:146` </location>
<code_context>
cursor.execute(f"SHOW COLUMNS FROM {table_name}")
</code_context>
<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Avoid SQL string concatenation: untrusted input concatenated into a raw SQL query can result in SQL injection. To execute raw queries safely, use prepared statements; SQLAlchemy provides TextualSQL for prepared statements with named parameters. For complex SQL composition, use the SQL Expression Language or Schema Definition Language; in most cases, the SQLAlchemy ORM is the better option.
*Source: opengrep*
</issue_to_address>
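Because table and column names cannot be bound as query parameters, one way to satisfy these findings is to validate identifiers against a strict pattern before interpolation while binding all values. This is a minimal sketch under that assumption; `quote_identifier` is a hypothetical helper, not part of the PR:

```python
import re

_VALID_IDENTIFIER = re.compile(r"^[A-Za-z0-9_]+$")


def quote_identifier(name: str) -> str:
    """Reject anything but alphanumerics/underscores, then backtick-quote.

    Identifiers can't be parameterized, so they must be validated before
    being interpolated into the SQL string.
    """
    if not _VALID_IDENTIFIER.match(name):
        raise ValueError(f"Unsafe SQL identifier: {name!r}")
    return f"`{name}`"


# Values remain bound parameters, never string-interpolated:
# cursor.execute(f"SHOW COLUMNS FROM {quote_identifier(table_name)}")
# cursor.execute(
#     f"SELECT payload_data FROM {quote_identifier(self.partial_payload_table)} "
#     "WHERE payload_id=? AND is_input=?",
#     (payload_id, is_input),
# )
```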
### Comment 7
<location> `src/service/data/storage/maria/maria.py:277-280` </location>
<code_context>
```python
cursor.execute(
    f"SELECT * FROM `{table_name}` ORDER BY row_idx ASC LIMIT ? OFFSET ?",
    (n_rows, start_row)
)
```
</code_context>
<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Avoid SQL string concatenation: untrusted input concatenated into a raw SQL query can result in SQL injection. To execute raw queries safely, use prepared statements; SQLAlchemy provides TextualSQL for prepared statements with named parameters. For complex SQL composition, use the SQL Expression Language or Schema Definition Language; in most cases, the SQLAlchemy ORM is the better option.
*Source: opengrep*
</issue_to_address>
### Comment 8
<location> `src/service/data/storage/maria/maria.py:344-346` </location>
<code_context>
```python
cursor.execute(
    f"INSERT INTO `{self.partial_payload_table}` (payload_id, is_input, payload_data) VALUES (?, ?, ?)",
    (payload_id, is_input, pkl.dumps(payload.model_dump())))
```
</code_context>
<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Avoid SQL string concatenation: untrusted input concatenated into a raw SQL query can result in SQL injection. To execute raw queries safely, use prepared statements; SQLAlchemy provides TextualSQL for prepared statements with named parameters. For complex SQL composition, use the SQL Expression Language or Schema Definition Language; in most cases, the SQLAlchemy ORM is the better option.
*Source: opengrep*
</issue_to_address>
### Comment 9
<location> `src/service/data/storage/maria/maria.py:352` </location>
<code_context>
cursor.execute(f"SELECT payload_data FROM `{self.partial_payload_table}` WHERE payload_id=? AND is_input=?", (payload_id, is_input))
</code_context>
<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Avoid SQL string concatenation: untrusted input concatenated into a raw SQL query can result in SQL injection. To execute raw queries safely, use prepared statements; SQLAlchemy provides TextualSQL for prepared statements with named parameters. For complex SQL composition, use the SQL Expression Language or Schema Definition Language; in most cases, the SQLAlchemy ORM is the better option.
*Source: opengrep*
</issue_to_address>
### Comment 10
<location> `src/service/data/storage/maria/maria.py:367` </location>
<code_context>
cursor.execute(f"DELETE FROM {self.partial_payload_table} WHERE payload_id=? AND is_input=?", (payload_id, is_input))
</code_context>
<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Avoid SQL string concatenation: untrusted input concatenated into a raw SQL query can result in SQL injection. To execute raw queries safely, use prepared statements; SQLAlchemy provides TextualSQL for prepared statements with named parameters. For complex SQL composition, use the SQL Expression Language or Schema Definition Language; in most cases, the SQLAlchemy ORM is the better option.
*Source: opengrep*
</issue_to_address>
### Comment 11
<location> `src/service/data/storage/maria/maria.py:377` </location>
<code_context>
cursor.execute(f"DELETE FROM `{self.dataset_reference_table}` WHERE dataset_name=?", (dataset_name,))
</code_context>
<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Avoid SQL string concatenation: untrusted input concatenated into a raw SQL query can result in SQL injection. To execute raw queries safely, use prepared statements; SQLAlchemy provides TextualSQL for prepared statements with named parameters. For complex SQL composition, use the SQL Expression Language or Schema Definition Language; in most cases, the SQLAlchemy ORM is the better option.
*Source: opengrep*
</issue_to_address>
### Comment 12
<location> `src/service/data/storage/maria/maria.py:378` </location>
<code_context>
cursor.execute(f"DROP TABLE IF EXISTS `{table_name}`")
</code_context>
<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Avoid SQL string concatenation: untrusted input concatenated into a raw SQL query can result in SQL injection. To execute raw queries safely, use prepared statements; SQLAlchemy provides TextualSQL for prepared statements with named parameters. For complex SQL composition, use the SQL Expression Language or Schema Definition Language; in most cases, the SQLAlchemy ORM is the better option.
*Source: opengrep*
</issue_to_address>
### Comment 13
<location> `src/service/data/storage/maria/maria.py:390` </location>
<code_context>
cursor.execute(f"DROP TABLE IF EXISTS `{self.dataset_reference_table}`")
</code_context>
<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Avoid SQL string concatenation: untrusted input concatenated into a raw SQL query can result in SQL injection. To execute raw queries safely, use prepared statements; SQLAlchemy provides TextualSQL for prepared statements with named parameters. For complex SQL composition, use the SQL Expression Language or Schema Definition Language; in most cases, the SQLAlchemy ORM is the better option.
*Source: opengrep*
</issue_to_address>
### Comment 14
<location> `src/service/data/storage/maria/maria.py:391` </location>
<code_context>
cursor.execute(f"DROP TABLE IF EXISTS `{self.partial_payload_table}`")
</code_context>
<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Avoid SQL string concatenation: untrusted input concatenated into a raw SQL query can result in SQL injection. To execute raw queries safely, use prepared statements; SQLAlchemy provides TextualSQL for prepared statements with named parameters. For complex SQL composition, use the SQL Expression Language or Schema Definition Language; in most cases, the SQLAlchemy ORM is the better option.
*Source: opengrep*
</issue_to_address>
### Comment 15
<location> `tests/endpoints/test_upload_endpoint_maria.py:7` </location>
<code_context>
</code_context>
<issue_to_address>
**issue (code-quality):** Don't import test modules. ([`dont-import-test-modules`](https://docs.sourcery.ai/Reference/Rules-and-In-Line-Suggestions/Python/Default-Rules/dont-import-test-modules))
<details><summary>Explanation</summary>Don't import test modules.
Tests should be self-contained and not depend on each other. If a helper function is used by multiple tests, define it in a helper module instead of importing one test from another.
</details>
</issue_to_address>
### Comment 16
<location> `tests/endpoints/test_upload_endpoint_pvc.py:205-231` </location>
<code_context>
</code_context>
<issue_to_address>
**issue (code-quality):** Avoid loops in tests. ([`no-loop-in-tests`](https://docs.sourcery.ai/Reference/Rules-and-In-Line-Suggestions/Python/Default-Rules/no-loop-in-tests))
<details><summary>Explanation</summary>Avoid complex code, like loops, in test functions.
Google's software engineering guidelines say:
"Clear tests are trivially correct upon inspection"
To reach that, avoid complex code in tests:
* loops
* conditionals
Some ways to fix this:
* Use parametrized tests to get rid of the loop.
* Move the complex logic into helpers.
* Move the complex part into pytest fixtures.
> Complexity is most often introduced in the form of logic. Logic is defined via the imperative parts of programming languages such as operators, loops, and conditionals. When a piece of code contains logic, you need to do a bit of mental computation to determine its result instead of just reading it off of the screen. It doesn't take much logic to make a test more difficult to reason about.
Software Engineering at Google / [Don't Put Logic in Tests](https://abseil.io/resources/swe-book/html/ch12.html#donapostrophet_put_logic_in_tests)
</details>
</issue_to_address>
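As an illustration of the parametrized-test fix (assuming a pytest-style test module; for `unittest.TestCase` classes, `subTest` is the closer equivalent) — the dimensions and the module-level `client` below are hypothetical:

```python
import pytest


@pytest.mark.parametrize(
    "n_rows, n_input_cols, n_output_cols, datatype",
    [
        (1, 1, 1, "INT64"),
        (50, 10, 1, "INT64"),
        (50, 10, 3, "FP64"),
    ],
)
def test_upload_shapes(n_rows, n_input_cols, n_output_cols, datatype):
    # Each tuple becomes its own test case, replacing the in-test loop.
    payload = generate_payload(n_rows, n_input_cols, n_output_cols, datatype, "TRAINING")
    response = client.post("/data/upload", json=payload)
    assert response.status_code == 200
```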
### Comment 17
<location> `tests/endpoints/test_upload_endpoint_pvc.py:241-271` </location>
<code_context>
</code_context>
<issue_to_address>
**issue (code-quality):** Avoid loops in tests. ([`no-loop-in-tests`](https://docs.sourcery.ai/Reference/Rules-and-In-Line-Suggestions/Python/Default-Rules/no-loop-in-tests))
<details><summary>Explanation</summary>Avoid complex code, like loops, in test functions.
Google's software engineering guidelines say:
"Clear tests are trivially correct upon inspection"
To reach that, avoid complex code in tests:
* loops
* conditionals
Some ways to fix this:
* Use parametrized tests to get rid of the loop.
* Move the complex logic into helpers.
* Move the complex part into pytest fixtures.
> Complexity is most often introduced in the form of logic. Logic is defined via the imperative parts of programming languages such as operators, loops, and conditionals. When a piece of code contains logic, you need to do a bit of mental computation to determine its result instead of just reading it off of the screen. It doesn't take much logic to make a test more difficult to reason about.
Software Engineering at Google / [Don't Put Logic in Tests](https://abseil.io/resources/swe-book/html/ch12.html#donapostrophet_put_logic_in_tests)
</details>
</issue_to_address>
### Comment 18
<location> `tests/service/data/test_payload_reconciliation_maria.py:11` </location>
<code_context>
</code_context>
<issue_to_address>
**issue (code-quality):** Don't import test modules. ([`dont-import-test-modules`](https://docs.sourcery.ai/Reference/Rules-and-In-Line-Suggestions/Python/Default-Rules/dont-import-test-modules))
<details><summary>Explanation</summary>Don't import test modules.
Tests should be self-contained and not depend on each other. If a helper function is used by multiple tests, define it in a helper module instead of importing one test from another.
</details>
</issue_to_address>
### Comment 19
<location> `src/endpoints/data/data_upload.py:39` </location>
<code_context>
```python
@router.post("/data/upload")
async def upload(payload: UploadPayload) -> Dict[str, str]:
    """Upload model data"""
    # validate tag
    tag_validation_msg = validate_data_tag(payload.data_tag)
    if tag_validation_msg:
        raise HTTPException(status_code=400, detail=tag_validation_msg)
    try:
        logger.info(f"Received upload request for model: {payload.model_name}")
        # overwrite response model name with provided model name
        payload.response.model_name = payload.model_name
        req_id = str(uuid.uuid4())
        model_data = ModelData(payload.model_name)
        datasets_exist = await model_data.datasets_exist()
        if all(datasets_exist):
            previous_data_points = (await model_data.row_counts())[0]
        else:
            previous_data_points = 0
        await consume_cloud_event(payload.response, req_id)
        await consume_cloud_event(payload.request, req_id, tag=payload.data_tag)
        model_data = ModelData(payload.model_name)
        new_data_points = (await model_data.row_counts())[0]
        logger.info(f"Upload completed for model: {payload.model_name}")
        return {
            "status": "success",
            "message": f"{new_data_points-previous_data_points} datapoints successfully added to {payload.model_name} data."
        }
    except HTTPException as e:
        if "Could not reconcile_kserve KServe Inference" in str(e):
            raise HTTPException(status_code=400, detail=f"Could not upload payload for model {payload.model_name}: {str(e)}") from e
        raise e
    except Exception as e:
        logger.error(f"Unexpected error in upload endpoint for model {payload.model_name}: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
```
</code_context>
<issue_to_address>
**issue (code-quality):** We've found these issues:
- Use named expression to simplify assignment and conditional ([`use-named-expression`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/use-named-expression/))
- Explicitly raise from a previous error ([`raise-from-previous-error`](https://docs.sourcery.ai/Reference/Default-Rules/suggestions/raise-from-previous-error/))
</issue_to_address>
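A sketch of the two suggested fixes applied to the snippet above; the elided `try` body is unchanged:

```python
@router.post("/data/upload")
async def upload(payload: UploadPayload) -> Dict[str, str]:
    """Upload model data"""
    # Named expression (walrus) collapses the assignment and the check:
    if tag_validation_msg := validate_data_tag(payload.data_tag):
        raise HTTPException(status_code=400, detail=tag_validation_msg)
    try:
        ...  # body unchanged from the snippet above
    except Exception as e:
        logger.error(f"Unexpected error in upload endpoint for model {payload.model_name}: {str(e)}", exc_info=True)
        # Explicitly chain the original exception so its traceback is preserved:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
```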
### Comment 20
<location> `tests/endpoints/test_upload_endpoint_pvc.py:36` </location>
<code_context>
```python
def generate_payload(n_rows, n_input_cols, n_output_cols, datatype, tag, input_offset=0, output_offset=0):
    """Generate a test payload with specific dimensions and data types."""
    model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}"
    input_data = []
    for i in range(n_rows):
        if n_input_cols == 1:
            input_data.append(i + input_offset)
        else:
            row = [i + j + input_offset for j in range(n_input_cols)]
            input_data.append(row)
    output_data = []
    for i in range(n_rows):
        if n_output_cols == 1:
            output_data.append(i * 2 + output_offset)
        else:
            row = [i * 2 + j + output_offset for j in range(n_output_cols)]
            output_data.append(row)
    payload = {
        "model_name": model_name,
        "data_tag": tag,
        "is_ground_truth": False,
        "request": {
            "inputs": [
                {
                    "name": "input",
                    "shape": [n_rows, n_input_cols] if n_input_cols > 1 else [n_rows],
                    "datatype": datatype,
                    "data": input_data,
                }
            ]
        },
        "response": {
            "outputs": [
                {
                    "name": "output",
                    "shape": [n_rows, n_output_cols] if n_output_cols > 1 else [n_rows],
                    "datatype": datatype,
                    "data": output_data,
                }
            ]
        },
    }
    return payload
```
</code_context>
<issue_to_address>
**issue (code-quality):** Inline variable that is immediately returned ([`inline-immediately-returned-variable`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/inline-immediately-returned-variable/))
</issue_to_address>
### Comment 21
<location> `tests/endpoints/test_upload_endpoint_pvc.py:75` </location>
<code_context>
```python
def generate_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, tag):
    """Generate a test payload with multi-dimensional tensors like real data."""
    model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}"
    input_data = []
    for row_idx in range(n_rows):
        row = [row_idx + col_idx * 10 for col_idx in range(n_input_cols)]
        input_data.append(row)
    output_data = []
    for row_idx in range(n_rows):
        row = [row_idx * 2 + col_idx for col_idx in range(n_output_cols)]
        output_data.append(row)
    payload = {
        "model_name": model_name,
        "data_tag": tag,
        "is_ground_truth": False,
        "request": {
            "inputs": [
                {
                    "name": "multi_input",
                    "shape": [n_rows, n_input_cols],
                    "datatype": datatype,
                    "data": input_data,
                }
            ]
        },
        "response": {
            "outputs": [
                {
                    "name": "multi_output",
                    "shape": [n_rows, n_output_cols],
                    "datatype": datatype,
                    "data": output_data,
                }
            ]
        },
    }
    return payload
```
</code_context>
<issue_to_address>
**issue (code-quality):** Inline variable that is immediately returned ([`inline-immediately-returned-variable`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/inline-immediately-returned-variable/))
</issue_to_address>
### Comment 22
<location> `tests/endpoints/test_upload_endpoint_pvc.py:110` </location>
<code_context>
```python
def generate_mismatched_shape_no_unique_name_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, tag):
    """Generate a payload with mismatched shapes and non-unique names."""
    model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}"
    input_data_1 = [[row_idx + col_idx * 10 for col_idx in range(n_input_cols)] for row_idx in range(n_rows)]
    mismatched_rows = n_rows - 1 if n_rows > 1 else 1
    input_data_2 = [[row_idx + col_idx * 20 for col_idx in range(n_input_cols)] for row_idx in range(mismatched_rows)]
    output_data = [[row_idx * 2 + col_idx for col_idx in range(n_output_cols)] for row_idx in range(n_rows)]
    payload = {
        "model_name": model_name,
        "data_tag": tag,
        "is_ground_truth": False,
        "request": {
            "inputs": [
                {
                    "name": "same_name",
                    "shape": [n_rows, n_input_cols],
                    "datatype": datatype,
                    "data": input_data_1,
                },
                {
                    "name": "same_name",
                    "shape": [mismatched_rows, n_input_cols],
                    "datatype": datatype,
                    "data": input_data_2,
                },
            ]
        },
        "response": {
            "outputs": [
                {
                    "name": "multi_output",
                    "shape": [n_rows, n_output_cols],
                    "datatype": datatype,
                    "data": output_data,
                }
            ]
        },
    }
    return payload
```
</code_context>
<issue_to_address>
**issue (code-quality):** Inline variable that is immediately returned ([`inline-immediately-returned-variable`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/inline-immediately-returned-variable/))
</issue_to_address>
### Comment 23
<location> `tests/endpoints/test_upload_endpoint_pvc.py:182-192` </location>
<code_context>
```python
def post_test(self, payload, expected_status_code, check_msgs):
    """Post a payload and check the response."""
    response = self.client.post("/data/upload", json=payload)
    if response.status_code != expected_status_code:
        print(f"\n=== DEBUG INFO ===")
        print(f"Expected status: {expected_status_code}")
        print(f"Actual status: {response.status_code}")
        print(f"Response text: {response.text}")
        print(f"Response headers: {dict(response.headers)}")
        if hasattr(response, "json"):
            try:
                print(f"Response JSON: {response.json()}")
            except:
                pass
        print("==================")
    self.assertEqual(response.status_code, expected_status_code)
    return response
```
</code_context>
<issue_to_address>
**issue (code-quality):** We've found these issues:
- Extract code out into method ([`extract-method`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/extract-method/))
- Use `except Exception:` rather than bare `except:` ([`do-not-use-bare-except`](https://docs.sourcery.ai/Reference/Default-Rules/suggestions/do-not-use-bare-except/))
</issue_to_address>
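A hedged sketch of both fixes; the extracted helper name `_print_debug_info` is an illustrative choice, not an existing method:

```python
def _print_debug_info(self, response, expected_status_code):
    """Extracted helper: dump response details when the status code is unexpected."""
    print("\n=== DEBUG INFO ===")
    print(f"Expected status: {expected_status_code}")
    print(f"Actual status: {response.status_code}")
    print(f"Response text: {response.text}")
    print(f"Response headers: {dict(response.headers)}")
    try:
        print(f"Response JSON: {response.json()}")
    except Exception:  # narrowed from a bare except
        pass
    print("==================")


def post_test(self, payload, expected_status_code, check_msgs):
    """Post a payload and check the response."""
    response = self.client.post("/data/upload", json=payload)
    if response.status_code != expected_status_code:
        self._print_debug_info(response, expected_status_code)
    self.assertEqual(response.status_code, expected_status_code)
    return response
```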
### Comment 24
<location> `tests/endpoints/test_upload_endpoint_pvc.py:292-298` </location>
<code_context>
```python
def test_upload_multiple_tagging(self):
    """Test uploading data with multiple tags."""
    n_payload1 = 50
    n_payload2 = 51
    tag1 = "TRAINING"
    tag2 = "NOT TRAINING "
    model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}"
    payload1 = generate_payload(n_payload1, 10, 1, "INT64", tag1)
    payload1["model_name"] = model_name
    self.post_test(payload1, 200, [f"{n_payload1} datapoints"])
    payload2 = generate_payload(n_payload2, 10, 1, "INT64", tag2)
    payload2["model_name"] = model_name
    self.post_test(payload2, 200, [f"{n_payload2} datapoints"])
    tag1_count = count_rows_with_tag(model_name, tag1)
    tag2_count = count_rows_with_tag(model_name, tag2)
    self.assertEqual(tag1_count, n_payload1, f"Expected {n_payload1} rows with tag {tag1}")
    self.assertEqual(tag2_count, n_payload2, f"Expected {n_payload2} rows with tag {tag2}")
    input_rows, _, _ = asyncio.run(ModelData(payload1["model_name"]).row_counts())
    self.assertEqual(input_rows, n_payload1 + n_payload2, "Incorrect total number of rows")
```
</code_context>
<issue_to_address>
**issue (code-quality):** Extract duplicate code into method ([`extract-duplicate-method`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/extract-duplicate-method/))
</issue_to_address>
ruivieira left a comment:

Thanks @sheltoncyril!
Since this is mainly for syncing the forks, I'm fine with merging it, but the merge conflicts still need to be resolved.
```python
INPUT_SUFFIX = "_inputs"
OUTPUT_SUFFIX = "_outputs"
METADATA_SUFFIX = "_metadata"
GROUND_TRUTH_SUFFIX = "_ground_truth"
```

This will be shadowed by the GROUND_TRUTH_SUFFIX below.
```python
for kserve_data in get_data(payload):
    data.append(kserve_data.data)
    shapes.add(tuple(kserve_data.data.shape))
    shapes.add(tuple(kserve_data.shape))
```

Should we add a guard here in case name is None?
```python
async def consume_cloud_event(
    payload: Union[KServeInferenceRequest, KServeInferenceResponse],
    ce_id: Annotated[str | None, Header()] = None,
    tag: str = None
```

Validate with validate_data_tag()?
Summary by Sourcery

Unify partial payload handling for ModelMesh and KServe by consolidating storage methods and refactoring consumer endpoints, convert storage backends to fully async implementations, and introduce a new data upload API with comprehensive tests.