Conversation

sheltoncyril (Contributor) commented Nov 10, 2025

Summary by Sourcery

Unify partial payload handling for ModelMesh and KServe by consolidating storage methods and refactoring consumer endpoints, convert storage backends to fully async implementations, and introduce a new data upload API with comprehensive tests.

New Features:

  • Add /data/upload endpoint for batch uploading model inference data with optional tagging and upload summary.
  • Introduce write_reconciled_data helper to centralize writing and cleanup of reconciled input, output, and metadata.
  • Implement reconcile_kserve to handle KServe request/response reconciliation alongside ModelMesh payloads.

Enhancements:

  • Unify partial payload persistence via generic persist_partial_payload/get_partial_payload/delete_partial_payload with is_modelmesh flag.
  • Refactor storage interface and PVC/MariaDB implementations to async methods, remove ModelMesh-specific functions, and support both ModelMesh and KServe payloads.
  • Replace module-level storage_interface with global singleton accessed via get_global_storage_interface.
  • Update ModelData to use global storage interface and add checks for dataset existence and metadata retrieval as DataFrame.

Build:

  • Add fastapi-utils dependency to pyproject.toml.

Tests:

  • Update existing unit tests for PVC and MariaDB storage to use async storage methods and the unified payload API.
  • Add extensive tests for the new /data/upload endpoint covering various tensor shapes, datatypes, tagging scenarios, and error cases.

sourcery-ai bot commented Nov 10, 2025

Reviewer's Guide

This PR refactors how partial payloads are stored and reconciled, unifies ModelMesh and KServe flows under a single persistence interface, and adds a new data upload endpoint that orchestrates KServe request/response ingestion into TrustyAI storage.

Sequence diagram for the new data upload endpoint flow

sequenceDiagram
    actor User
    participant API as "Upload Endpoint"
    participant Consumer as "consume_cloud_event()"
    participant Storage as "StorageInterface"
    participant ModelData
    User->>API: POST /data/upload (UploadPayload)
    API->>ModelData: Check if datasets exist
    ModelData->>Storage: dataset_exists()
    API->>Consumer: consume_cloud_event(response, req_id)
    Consumer->>Storage: persist_partial_payload(response, req_id, is_input=False)
    API->>Consumer: consume_cloud_event(request, req_id, tag)
    Consumer->>Storage: persist_partial_payload(request, req_id, is_input=True)
    API->>ModelData: Get new row counts
    ModelData->>Storage: dataset_rows()
    API-->>User: Return success with number of new datapoints
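For orientation, here is a minimal client-side sketch of that flow, using the payload shape exercised by the new upload tests. The base URL and all values are placeholders; the field names model_name, data_tag, request.inputs, and response.outputs are taken from the test payloads in this PR.

```python
import requests

BASE_URL = "http://localhost:8080"  # placeholder; point at the running TrustyAI service

payload = {
    "model_name": "example-model",
    "data_tag": "TRAINING",
    "is_ground_truth": False,
    "request": {
        "inputs": [
            {"name": "input", "shape": [3, 2], "datatype": "INT64",
             "data": [[1, 2], [3, 4], [5, 6]]}
        ]
    },
    "response": {
        "outputs": [
            {"name": "output", "shape": [3], "datatype": "INT64", "data": [2, 4, 6]}
        ]
    },
}

# The endpoint reconciles the request/response pair under a fresh request ID and
# reports how many datapoints were added for the model.
resp = requests.post(f"{BASE_URL}/data/upload", json=payload)
print(resp.status_code, resp.json())
```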

Class diagram for unified payload types and storage interface

classDiagram
    class PartialPayloadId {
        +Optional[str] prediction_id
        +Optional[PartialKind] kind
        +get_prediction_id()
        +set_prediction_id(id: str)
        +get_kind()
        +set_kind(kind: PartialKind)
    }
    class InferencePartialPayload {
        +Optional[PartialPayloadId] partialPayloadId
        +Optional[Dict[str, str]] metadata
        +Optional[str] data
        +Optional[str] modelid
        +get_id()
        +set_id(id: str)
        +get_kind()
        +set_kind(kind: PartialKind)
        +get_model_id()
        +set_model_id(model_id: str)
    }
    class KServeData {
        +str name
        +List[int] shape
        +str datatype
        +Optional[Dict[str, str]] parameters
        +List data
    }
    class KServeInferenceRequest {
        +Optional[str] id
        +Optional[Dict[str, str]] parameters
        +List[KServeData] inputs
        +Optional[List[KServeData]] outputs
    }
    class KServeInferenceResponse {
        +str model_name
        +Optional[str] model_version
        +Optional[str] id
        +Optional[Dict[str, str]] parameters
        +List[KServeData] outputs
    }
    class StorageInterface {
        +async dataset_exists(dataset_name: str)
        +async dataset_rows(dataset_name: str)
        +async dataset_shape(dataset_name: str)
        +async write_data(dataset_name: str, new_rows, column_names: List[str])
        +async read_data(dataset_name: str, start_row: int, n_rows: int)
        +async get_original_column_names(dataset_name: str)
        +async get_aliased_column_names(dataset_name: str)
        +async apply_name_mapping(dataset_name: str, name_mapping: Dict[str, str])
        +async delete_dataset(dataset_name: str)
        +async persist_partial_payload(payload, payload_id, is_input: bool)
        +async get_partial_payload(payload_id: str, is_input: bool, is_modelmesh: bool)
        +async delete_partial_payload(payload_id: str, is_input: bool)
    }
    InferencePartialPayload --> PartialPayloadId
    KServeInferenceRequest --> KServeData
    KServeInferenceResponse --> KServeData
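A rough Python rendering of the partial-payload portion of the diagram, as a sketch only: the method signatures follow the diagram, while the type alias and docstrings are assumptions.

```python
from abc import ABC, abstractmethod
from typing import Optional, Union

# The Pydantic models named in the diagram above.
Payload = Union["InferencePartialPayload", "KServeInferenceRequest", "KServeInferenceResponse"]


class StorageInterface(ABC):
    @abstractmethod
    async def persist_partial_payload(self, payload: Payload, payload_id: str, is_input: bool) -> None:
        """Store one half (input or output) of a not-yet-reconciled inference."""

    @abstractmethod
    async def get_partial_payload(
        self, payload_id: str, is_input: bool, is_modelmesh: bool
    ) -> Optional[Payload]:
        """Fetch a stored half, deserialized as a ModelMesh or KServe payload per the flag."""

    @abstractmethod
    async def delete_partial_payload(self, payload_id: str, is_input: bool) -> None:
        """Drop a stored half once the request/response pair has been reconciled."""
```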

File-Level Changes

Change | Details | Files
Unify and refactor consumer endpoint payload handling
  • Removed ModelMesh-only Pydantic classes and storage calls
  • Swapped get_storage_interface() for get_global_storage_interface()
  • Replaced persist_modelmesh_payload/get_modelmesh_payload/delete_modelmesh_payload with generic persist_partial_payload/get_partial_payload/delete_partial_payload (with is_modelmesh flag)
  • Added write_reconciled_data() to centralize writing of reconciled data
  • Updated reconcile_modelmesh_payloads and reconcile_kserve to use write_reconciled_data()
  • Adjusted POST /consumer/kserve/v2 and consume_cloud_event flows to use new signatures
src/endpoints/consumer/consumer_endpoint.py
src/endpoints/consumer/__init__.py
Implement new /data/upload endpoint
  • Introduced UploadPayload and validate_data_tag()
  • Orchestrate KServe request/response via consume_cloud_event
  • Generate unique request IDs and compute data point deltas
  • Return count of new datapoints or appropriate error
src/endpoints/data/data_upload.py
tests/endpoints/test_upload_endpoint_pvc.py
tests/endpoints/test_upload_endpoint_maria.py
Consolidate partial payload storage in PVC and MariaDB
  • Removed separate ModelMesh payload tables and methods
  • Added unified persist_partial_payload/get_partial_payload/delete_partial_payload with payload_id, is_input, is_modelmesh flags
  • Converted all I/O methods to async and updated read_data/write_data signatures
  • Enhanced serialization for void dtypes and enforced max void length
  • Cleaned up deprecated methods
src/service/data/storage/pvc.py
src/service/data/storage/maria/maria.py
src/service/data/storage/maria/legacy_maria_reader.py
src/service/data/storage/__init__.py
src/service/data/storage/storage_interface.py
src/service/utils/list_utils.py
Update StorageInterface and storage init
  • Made dataset_exists, dataset_rows, dataset_shape, write_data, read_data, name mapping, and delete_dataset async
  • Expanded partial payload abstract methods to include payload_id and is_modelmesh flag
  • Introduced get_global_storage_interface() with lazy initialization (see the sketch after this section)
src/service/data/storage/storage_interface.py
src/service/data/storage/__init__.py
Enhance ModelData for async storage and metadata access
  • Replaced synchronous interface calls with get_global_storage_interface()
  • Added datasets_exist() to check presence of datasets
  • Added get_metadata_as_df() to return metadata as Pandas DataFrame
src/service/data/model_data.py
Revise and extend tests for new async and unified storage
  • Updated all tests to call persist_partial_payload/get_partial_payload/delete_partial_payload
  • Changed sync assertions to await async storage methods
  • Extended tests for new upload endpoint on PVC and MariaDB
  • Refactored test fixtures to handle get_global_storage_interface
tests/service/data/test_payload_reconciliation_pvc.py
tests/service/data/test_payload_reconciliation_maria.py
tests/service/data/test_mariadb_storage.py
tests/service/data/test_mariadb_migration.py
tests/service/data/test_modelmesh_parser.py
tests/service/data/test_utils.py
tests/endpoints/test_upload_endpoint_pvc.py
tests/endpoints/test_upload_endpoint_maria.py
tests/service/test_consumer_endpoint_reconciliation.py
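
The lazily initialized global accessor mentioned above could look roughly like this. This is a sketch only; `_build_storage_from_config` is a hypothetical stand-in for whatever backend selection logic lives in src/service/data/storage/__init__.py.

```python
from typing import Optional

_global_storage_interface: Optional["StorageInterface"] = None


def get_global_storage_interface() -> "StorageInterface":
    """Return the process-wide storage backend, constructing it on first use."""
    global _global_storage_interface
    if _global_storage_interface is None:
        # Hypothetical helper: choose PVC or MariaDB based on service configuration.
        _global_storage_interface = _build_storage_from_config()
    return _global_storage_interface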

sourcery-ai bot (Contributor) left a comment:

Hey there - I've reviewed your changes - here's some feedback:

Blocking issues:

  • Avoid SQL string concatenation: untrusted input concatenated into a raw SQL query can result in SQL injection. To execute a raw query safely, use a prepared statement; SQLAlchemy provides TextualSQL for prepared statements with named parameters, and the SQL Expression Language, Schema Definition Language, or ORM for more complex cases. This finding was raised at nine locations in `src/service/data/storage/maria/maria.py` (see Comments 6–14 below).

General comments:

  • In write_reconciled_data, wrap the storage writes and subsequent delete_partial_payload calls in a try/finally so that partial payload cleanup always runs even if a write fails (see the sketch after this list).
  • The get_partial_payload method takes an is_modelmesh boolean that leads to type-based branching; consider replacing this flag with separate methods or an enum to make the API clearer and reduce branching logic.
  • The validate_data_tag function returns None or an error message string, which can be confusing; consider having it return a boolean or directly raise an exception to simplify its usage in the upload endpoint.
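
A minimal sketch of the first point, assuming an illustrative signature for write_reconciled_data (the real parameters may differ); the `_inputs`/`_outputs`/`_metadata` dataset suffixes and the write_data/delete_partial_payload signatures are taken from this PR.

```python
from typing import List, Tuple

Rows = Tuple[List, List[str]]  # (new_rows, column_names), as accepted by StorageInterface.write_data


async def write_reconciled_data(
    storage, model_name: str, inputs: Rows, outputs: Rows, metadata: Rows, payload_id: str
) -> None:
    """Write reconciled input/output/metadata rows, then always drop the partial payloads."""
    try:
        await storage.write_data(f"{model_name}_inputs", *inputs)
        await storage.write_data(f"{model_name}_outputs", *outputs)
        await storage.write_data(f"{model_name}_metadata", *metadata)
    finally:
        # Cleanup runs even if one of the writes above raised, so stale partials never linger.
        await storage.delete_partial_payload(payload_id, is_input=True)
        await storage.delete_partial_payload(payload_id, is_input=False)
```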
Prompt for AI Agents
Please address the comments from this code review:

## Overall Comments
- In write_reconciled_data, wrap the storage writes and subsequent delete_partial_payload calls in a try/finally so that partial payload cleanup always runs even if a write fails.
- The get_partial_payload method takes an is_modelmesh boolean that leads to type-based branching; consider replacing this flag with separate methods or an enum to make the API clearer and reduce branching logic.
- The validate_data_tag function returns None or an error message string, which can be confusing; consider having it return a boolean or directly raise an exception to simplify its usage in the upload endpoint.

## Individual Comments

### Comment 1
<location> `src/service/data/storage/maria/maria.py:105-107` </location>
<code_context>


     #=== DATASET QUERYING ==========================================================================
-    def dataset_exists(self, dataset_name: str) -> bool:
+    async def dataset_exists(self, dataset_name: str) -> bool:
         """
         Check if a dataset exists within the TrustyAI model data.
</code_context>

<issue_to_address>
**suggestion (performance):** Changed dataset_exists to async, but underlying DB call is still sync.

Refactor the DB access in dataset_exists to be asynchronous to prevent blocking the event loop.

Suggested implementation:

```python
    async def dataset_exists(self, dataset_name: str) -> bool:
        """
        Check if a dataset exists within the TrustyAI model data.
        """
        # Assuming self._pool is an aiomysql pool
        async with self._pool.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(
                    "SELECT COUNT(*) FROM dataset_reference_table WHERE name=%s",
                    (dataset_name,)
                )
                result = await cur.fetchone()
                return result[0] > 0

```

- If your MariaDB connection pool is not already async (e.g., using aiomysql), you will need to refactor the rest of your DB access code and initialize `self._pool` as an async pool.
- If you use a different async DB library, adjust the connection/cursor/execute logic accordingly.
- Ensure that any calling code is prepared to await this function.
</issue_to_address>

### Comment 2
<location> `src/service/data/storage/maria/maria.py:144-147` </location>
<code_context>
         """
-        table_name = self._get_clean_table_name(dataset_name)
+        table_name = await self._get_clean_table_name(dataset_name)
         with self.connection_manager as (conn, cursor):
             cursor.execute(f"SHOW COLUMNS FROM {table_name}")
             return len(cursor.fetchall()) - 1
</code_context>

<issue_to_address>
**suggestion (performance):** Database operations remain synchronous within async methods.

Using synchronous DB calls in async methods can block the event loop and impact performance. Please switch to an async DB driver or offload these operations to a thread pool.

```suggestion
        table_name = await self._get_clean_table_name(dataset_name)

        def get_column_count():
            with self.connection_manager as (conn, cursor):
                cursor.execute(f"SHOW COLUMNS FROM {table_name}")
                return len(cursor.fetchall()) - 1

        import asyncio
        return await asyncio.to_thread(get_column_count)
```
</issue_to_address>

### Comment 3
<location> `src/service/data/storage/pvc.py:25` </location>
<code_context>
-    f"{PROTECTED_DATASET_SUFFIX}modelmesh_partial_payloads_outputs"
-)
-
+MAX_VOID_TYPE_LENGTH=1024

 class H5PYContext:
</code_context>

<issue_to_address>
**suggestion:** MAX_VOID_TYPE_LENGTH is set to 1024, but hardcoded values (e.g., V400) are still used elsewhere.

Please update all instances of hardcoded values like "V400" to use MAX_VOID_TYPE_LENGTH for consistency.

Suggested implementation:

```python
            if new_rows.dtype.itemsize > MAX_VOID_TYPE_LENGTH:
                raise ValueError(
                    f"The datatype of the array to be serialized is {new_rows.dtype}- the largest serializable void type is V{MAX_VOID_TYPE_LENGTH}"
                )
            new_rows = new_rows.astype(f"V{MAX_VOID_TYPE_LENGTH}")  # use constant for consistency

```

If there are other places in the file (or codebase) where hardcoded void type lengths like "V400", "V512", etc. are used, you should update those to use f"V{MAX_VOID_TYPE_LENGTH}" as well. This ensures consistency and makes future changes easier.
</issue_to_address>

### Comment 4
<location> `tests/service/test_consumer_endpoint_reconciliation.py:95-97` </location>
<code_context>

     async def _test_consume_input_payload(self):
         """Test consuming an input payload."""
-        self.mock_storage.persist_modelmesh_payload = mock.AsyncMock()
-        self.mock_storage.get_modelmesh_payload = mock.AsyncMock(return_value=None)
+        self.mock_storage.persist_partial_payload = mock.AsyncMock()
+        self.mock_storage.get_partial_payload = mock.AsyncMock(return_value=None)
         self.mock_parse_input.return_value = True
</code_context>

<issue_to_address>
**suggestion (testing):** Mocking updated to new storage interface in consumer endpoint tests.

Also, add tests to cover error scenarios, such as exceptions or malformed data returned by the storage interface.

Suggested implementation:

```python
    async def _test_consume_input_payload(self):
        """Test consuming an input payload."""
        self.mock_storage.persist_partial_payload = mock.AsyncMock()
        self.mock_storage.get_partial_payload = mock.AsyncMock(return_value=None)
        self.mock_parse_input.return_value = True
        self.mock_parse_output.side_effect = ValueError("Not an output payload")

        response = self.client.post("/consumer/kserve/v2", json=inference_payload)
        print(response.text)

        self.assertEqual(response.status_code, 200)
        self.assertEqual(
            },

    async def test_consume_input_payload_storage_exception(self):
        """Test error handling when storage persist_partial_payload raises an exception."""
        self.mock_storage.persist_partial_payload = mock.AsyncMock(side_effect=Exception("Storage error"))
        self.mock_storage.get_partial_payload = mock.AsyncMock(return_value=None)
        self.mock_parse_input.return_value = True

        response = self.client.post("/consumer/kserve/v2", json=inference_payload)
        print(response.text)
        self.assertEqual(response.status_code, 500)
        self.assertIn("Storage error", response.text)

    async def test_consume_input_payload_malformed_data(self):
        """Test error handling when storage get_partial_payload returns malformed data."""
        self.mock_storage.persist_partial_payload = mock.AsyncMock()
        self.mock_storage.get_partial_payload = mock.AsyncMock(return_value="not-a-dict")
        self.mock_parse_input.return_value = True

        response = self.client.post("/consumer/kserve/v2", json=inference_payload)
        print(response.text)
        self.assertEqual(response.status_code, 400)
        self.assertIn("Malformed data", response.text)

```

- You may need to adjust the error assertions (`status_code`, error message) to match your actual error handling in the endpoint.
- Ensure that your endpoint raises the correct exceptions and returns the expected status codes/messages for these error scenarios.
- If `inference_payload` is not defined in the test class, you should define it or import it as needed.
</issue_to_address>

### Comment 5
<location> `tests/service/data/test_mariadb_storage.py:5-9` </location>
<code_context>

             start_idx = dataset_idx
             n_rows =  dataset_idx * 2
-            retrieved_full_dataset = self.storage.read_data(dataset_name)
-            retrieved_partial_dataset = self.storage.read_data(dataset_name, start_idx, n_rows)
+            retrieved_full_dataset = await self.storage.read_data(dataset_name)
+            retrieved_partial_dataset = await self.storage.read_data(dataset_name, start_idx, n_rows)

</code_context>

<issue_to_address>
**suggestion (testing):** Tests updated for async storage interface methods.

Also, add tests that cover async error scenarios, such as timeouts and task cancellations.

```suggestion
import asyncio
import unittest
import numpy as np

from src.service.data.storage.maria.maria import MariaDBStorage

class TestMariaDBStorageAsyncErrors(unittest.IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        self.storage = MariaDBStorage(
            host="localhost",
            port=3306,
            database="trustyai-database",
            attempt_migration=False
        )
        self.dataset_name = "test_dataset"
        await self.storage.reset_database()
        await self.storage.write_data(self.dataset_name, np.random.rand(10, 5))

    async def test_read_data_timeout(self):
        # Simulate a timeout by setting a very short timeout
        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(self.storage.read_data(self.dataset_name), timeout=0.0001)

    async def test_read_data_cancelled(self):
        # Simulate cancellation of the task
        task = asyncio.create_task(self.storage.read_data(self.dataset_name))
        await asyncio.sleep(0.0001)
        task.cancel()
        with self.assertRaises(asyncio.CancelledError):
            await task

    async def asyncTearDown(self):
        await self.storage.reset_database()
```
</issue_to_address>

### Comment 6
<location> `src/service/data/storage/maria/maria.py:146` </location>
<code_context>
            cursor.execute(f"SHOW COLUMNS FROM {table_name}")
</code_context>

<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Avoid SQL string concatenation: untrusted input concatenated into a raw SQL query can result in SQL injection. To execute a raw query safely, use a prepared statement. SQLAlchemy provides TextualSQL to easily use prepared statements with named parameters; for complex SQL composition, use the SQL Expression Language or the Schema Definition Language. In most cases, the SQLAlchemy ORM will be a better option.

*Source: opengrep*
</issue_to_address>

### Comment 7
<location> `src/service/data/storage/maria/maria.py:277-280` </location>
<code_context>
            cursor.execute(
                f"SELECT * FROM `{table_name}` ORDER BY row_idx ASC LIMIT ? OFFSET ?",
                (n_rows, start_row)
            )
</code_context>

<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Same finding as Comment 6: avoid concatenating untrusted input into raw SQL; use a prepared statement with bound parameters instead.

*Source: opengrep*
</issue_to_address>

### Comment 8
<location> `src/service/data/storage/maria/maria.py:344-346` </location>
<code_context>
            cursor.execute(
                f"INSERT INTO `{self.partial_payload_table}` (payload_id, is_input, payload_data) VALUES (?, ?, ?)",
                (payload_id, is_input, pkl.dumps(payload.model_dump())))
</code_context>

<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Same finding as Comment 6: avoid concatenating untrusted input into raw SQL; use a prepared statement with bound parameters instead.

*Source: opengrep*
</issue_to_address>

### Comment 9
<location> `src/service/data/storage/maria/maria.py:352` </location>
<code_context>
            cursor.execute(f"SELECT payload_data FROM `{self.partial_payload_table}` WHERE payload_id=? AND is_input=?", (payload_id, is_input))
</code_context>

<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Same finding as Comment 6: avoid concatenating untrusted input into raw SQL; use a prepared statement with bound parameters instead.

*Source: opengrep*
</issue_to_address>

### Comment 10
<location> `src/service/data/storage/maria/maria.py:367` </location>
<code_context>
            cursor.execute(f"DELETE FROM {self.partial_payload_table} WHERE payload_id=? AND is_input=?", (payload_id, is_input))
</code_context>

<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Same finding as Comment 6: avoid concatenating untrusted input into raw SQL; use a prepared statement with bound parameters instead.

*Source: opengrep*
</issue_to_address>

### Comment 11
<location> `src/service/data/storage/maria/maria.py:377` </location>
<code_context>
            cursor.execute(f"DELETE FROM `{self.dataset_reference_table}` WHERE dataset_name=?", (dataset_name,))
</code_context>

<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Same finding as Comment 6: avoid concatenating untrusted input into raw SQL; use a prepared statement with bound parameters instead.

*Source: opengrep*
</issue_to_address>

### Comment 12
<location> `src/service/data/storage/maria/maria.py:378` </location>
<code_context>
            cursor.execute(f"DROP TABLE IF EXISTS `{table_name}`")
</code_context>

<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Same finding as Comment 6: avoid concatenating untrusted input into raw SQL; use a prepared statement with bound parameters instead.

*Source: opengrep*
</issue_to_address>

### Comment 13
<location> `src/service/data/storage/maria/maria.py:390` </location>
<code_context>
            cursor.execute(f"DROP TABLE IF EXISTS `{self.dataset_reference_table}`")
</code_context>

<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Same finding as Comment 6: avoid concatenating untrusted input into raw SQL; use a prepared statement with bound parameters instead.

*Source: opengrep*
</issue_to_address>

### Comment 14
<location> `src/service/data/storage/maria/maria.py:391` </location>
<code_context>
            cursor.execute(f"DROP TABLE IF EXISTS `{self.partial_payload_table}`")
</code_context>

<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Same finding as Comment 6: avoid concatenating untrusted input into raw SQL; use a prepared statement with bound parameters instead.

*Source: opengrep*
</issue_to_address>

### Comment 15
<location> `tests/endpoints/test_upload_endpoint_maria.py:7` </location>
<code_context>

</code_context>

<issue_to_address>
**issue (code-quality):** Don't import test modules. ([`dont-import-test-modules`](https://docs.sourcery.ai/Reference/Rules-and-In-Line-Suggestions/Python/Default-Rules/dont-import-test-modules))

<details><summary>Explanation</summary>Don't import test modules.

Tests should be self-contained and don't depend on each other.

If a helper function is used by multiple tests,
define it in a helper module,
instead of importing one test from the other.
</details>
</issue_to_address>
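
One way to resolve this, sketched under the assumption that the shared code is the payload generator used by both upload test suites (the helper module name is hypothetical):

```python
# tests/endpoints/upload_test_helpers.py -- hypothetical shared helper module, so that
# test_upload_endpoint_maria.py no longer has to import from test_upload_endpoint_pvc.py.
import uuid

MODEL_ID = "example"  # placeholder; each suite can supply its own prefix


def generate_payload(n_rows, n_input_cols, n_output_cols, datatype, tag):
    """Build an upload payload in the shape used by the /data/upload tests."""
    input_data = [[i + j for j in range(n_input_cols)] for i in range(n_rows)]
    output_data = [[i * 2 + j for j in range(n_output_cols)] for i in range(n_rows)]
    return {
        "model_name": f"{MODEL_ID}_{uuid.uuid4().hex[:8]}",
        "data_tag": tag,
        "is_ground_truth": False,
        "request": {"inputs": [{"name": "input", "shape": [n_rows, n_input_cols],
                                "datatype": datatype, "data": input_data}]},
        "response": {"outputs": [{"name": "output", "shape": [n_rows, n_output_cols],
                                  "datatype": datatype, "data": output_data}]},
    }
```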

### Comment 16
<location> `tests/endpoints/test_upload_endpoint_pvc.py:205-231` </location>
<code_context>

</code_context>

<issue_to_address>
**issue (code-quality):** Avoid loops in tests. ([`no-loop-in-tests`](https://docs.sourcery.ai/Reference/Rules-and-In-Line-Suggestions/Python/Default-Rules/no-loop-in-tests))

<details><summary>Explanation</summary>Avoid complex code, like loops, in test functions.

Google's software engineering guidelines says:
"Clear tests are trivially correct upon inspection"
To reach that avoid complex code in tests:
* loops
* conditionals

Some ways to fix this:

* Use parametrized tests to get rid of the loop.
* Move the complex logic into helpers.
* Move the complex part into pytest fixtures.

> Complexity is most often introduced in the form of logic. Logic is defined via the imperative parts of programming languages such as operators, loops, and conditionals. When a piece of code contains logic, you need to do a bit of mental computation to determine its result instead of just reading it off of the screen. It doesn't take much logic to make a test more difficult to reason about.

Software Engineering at Google / [Don't Put Logic in Tests](https://abseil.io/resources/swe-book/html/ch12.html#donapostrophet_put_logic_in_tests)
</details>
</issue_to_address>
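
A parametrized alternative to such a loop could look like this. It is illustrative only: it assumes the case is written as a plain pytest function and that generate_payload lives in a shared helper module as suggested in Comment 15.

```python
import pytest

from upload_test_helpers import generate_payload  # hypothetical shared helper (see Comment 15)


@pytest.mark.parametrize(
    "n_rows, n_input_cols, n_output_cols, datatype",
    [(1, 1, 1, "INT64"), (5, 3, 2, "FP32"), (250, 10, 1, "INT64")],
)
def test_payload_shapes(n_rows, n_input_cols, n_output_cols, datatype):
    # Each combination becomes its own test case, so one failing shape no longer
    # hides the remaining iterations of the original loop.
    payload = generate_payload(n_rows, n_input_cols, n_output_cols, datatype, "TRAINING")
    assert payload["request"]["inputs"][0]["shape"][0] == n_rows
```

For unittest.TestCase-style suites, self.subTest(...) gives a similar per-case report without converting the class to pytest functions.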

### Comment 17
<location> `tests/endpoints/test_upload_endpoint_pvc.py:241-271` </location>
<code_context>

</code_context>

<issue_to_address>
**issue (code-quality):** Avoid loops in tests. ([`no-loop-in-tests`](https://docs.sourcery.ai/Reference/Rules-and-In-Line-Suggestions/Python/Default-Rules/no-loop-in-tests))

<details><summary>Explanation</summary>Avoid complex code, like loops, in test functions.

Google's software engineering guidelines says:
"Clear tests are trivially correct upon inspection"
To reach that avoid complex code in tests:
* loops
* conditionals

Some ways to fix this:

* Use parametrized tests to get rid of the loop.
* Move the complex logic into helpers.
* Move the complex part into pytest fixtures.

> Complexity is most often introduced in the form of logic. Logic is defined via the imperative parts of programming languages such as operators, loops, and conditionals. When a piece of code contains logic, you need to do a bit of mental computation to determine its result instead of just reading it off of the screen. It doesn't take much logic to make a test more difficult to reason about.

Software Engineering at Google / [Don't Put Logic in Tests](https://abseil.io/resources/swe-book/html/ch12.html#donapostrophet_put_logic_in_tests)
</details>
</issue_to_address>

### Comment 18
<location> `tests/service/data/test_payload_reconciliation_maria.py:11` </location>
<code_context>

</code_context>

<issue_to_address>
**issue (code-quality):** Don't import test modules. ([`dont-import-test-modules`](https://docs.sourcery.ai/Reference/Rules-and-In-Line-Suggestions/Python/Default-Rules/dont-import-test-modules))

<details><summary>Explanation</summary>Don't import test modules.

Tests should be self-contained and don't depend on each other.

If a helper function is used by multiple tests,
define it in a helper module,
instead of importing one test from the other.
</details>
</issue_to_address>

### Comment 19
<location> `src/endpoints/data/data_upload.py:39` </location>
<code_context>
@router.post("/data/upload")
async def upload(payload: UploadPayload) -> Dict[str, str]:
    """Upload model data"""

    # validate tag
    tag_validation_msg = validate_data_tag(payload.data_tag)
    if tag_validation_msg:
        raise HTTPException(status_code=400, detail=tag_validation_msg)
    try:
        logger.info(f"Received upload request for model: {payload.model_name}")

        # overwrite response model name with provided model name
        payload.response.model_name = payload.model_name

        req_id = str(uuid.uuid4())
        model_data = ModelData(payload.model_name)
        datasets_exist = await model_data.datasets_exist()

        if all(datasets_exist):
            previous_data_points = (await model_data.row_counts())[0]
        else:
            previous_data_points = 0

        await consume_cloud_event(payload.response, req_id)
        await consume_cloud_event(payload.request, req_id, tag=payload.data_tag)

        model_data =  ModelData(payload.model_name)
        new_data_points = (await model_data.row_counts())[0]

        logger.info(f"Upload completed for model: {payload.model_name}")

        return {
            "status": "success",
            "message": f"{new_data_points-previous_data_points} datapoints successfully added to {payload.model_name} data."
        }

    except HTTPException as e:
        if "Could not reconcile_kserve KServe Inference" in str(e):
            raise HTTPException(status_code=400, detail=f"Could not upload payload for model {payload.model_name}: {str(e)}") from e
        raise e
    except Exception as e:
        logger.error(f"Unexpected error in upload endpoint for model {payload.model_name}: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

</code_context>

<issue_to_address>
**issue (code-quality):** We've found these issues:

- Use named expression to simplify assignment and conditional ([`use-named-expression`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/use-named-expression/))
- Explicitly raise from a previous error ([`raise-from-previous-error`](https://docs.sourcery.ai/Reference/Default-Rules/suggestions/raise-from-previous-error/))
</issue_to_address>
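
A trimmed sketch of both suggestions applied to the endpoint above: validate_data_tag comes from the file under review, while _do_upload is a placeholder for the rest of the existing handler body.

```python
from fastapi import HTTPException


async def upload(payload):
    # use-named-expression: assign the validation message and test it in one step
    if tag_validation_msg := validate_data_tag(payload.data_tag):
        raise HTTPException(status_code=400, detail=tag_validation_msg)
    try:
        return await _do_upload(payload)  # placeholder for the existing handler body
    except HTTPException:
        raise
    except Exception as e:
        # raise-from-previous-error: keep the original traceback chained onto the 500
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}") from e
```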

### Comment 20
<location> `tests/endpoints/test_upload_endpoint_pvc.py:36` </location>
<code_context>
def generate_payload(n_rows, n_input_cols, n_output_cols, datatype, tag, input_offset=0, output_offset=0):
    """Generate a test payload with specific dimensions and data types."""
    model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}"
    input_data = []
    for i in range(n_rows):
        if n_input_cols == 1:
            input_data.append(i + input_offset)
        else:
            row = [i + j + input_offset for j in range(n_input_cols)]
            input_data.append(row)
    output_data = []
    for i in range(n_rows):
        if n_output_cols == 1:
            output_data.append(i * 2 + output_offset)
        else:
            row = [i * 2 + j + output_offset for j in range(n_output_cols)]
            output_data.append(row)
    payload = {
        "model_name": model_name,
        "data_tag": tag,
        "is_ground_truth": False,
        "request": {
            "inputs": [
                {
                    "name": "input",
                    "shape": [n_rows, n_input_cols] if n_input_cols > 1 else [n_rows],
                    "datatype": datatype,
                    "data": input_data,
                }
            ]
        },
        "response": {
            "outputs": [
                {
                    "name": "output",
                    "shape": [n_rows, n_output_cols] if n_output_cols > 1 else [n_rows],
                    "datatype": datatype,
                    "data": output_data,
                }
            ]
        },
    }
    return payload

</code_context>

<issue_to_address>
**issue (code-quality):** Inline variable that is immediately returned ([`inline-immediately-returned-variable`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/inline-immediately-returned-variable/))
</issue_to_address>

### Comment 21
<location> `tests/endpoints/test_upload_endpoint_pvc.py:75` </location>
<code_context>
def generate_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, tag):
    """Generate a test payload with multi-dimensional tensors like real data."""
    model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}"
    input_data = []
    for row_idx in range(n_rows):
        row = [row_idx + col_idx * 10 for col_idx in range(n_input_cols)]
        input_data.append(row)
    output_data = []
    for row_idx in range(n_rows):
        row = [row_idx * 2 + col_idx for col_idx in range(n_output_cols)]
        output_data.append(row)
    payload = {
        "model_name": model_name,
        "data_tag": tag,
        "is_ground_truth": False,
        "request": {
            "inputs": [
                {
                    "name": "multi_input",
                    "shape": [n_rows, n_input_cols],
                    "datatype": datatype,
                    "data": input_data,
                }
            ]
        },
        "response": {
            "outputs": [
                {
                    "name": "multi_output",
                    "shape": [n_rows, n_output_cols],
                    "datatype": datatype,
                    "data": output_data,
                }
            ]
        },
    }
    return payload

</code_context>

<issue_to_address>
**issue (code-quality):** Inline variable that is immediately returned ([`inline-immediately-returned-variable`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/inline-immediately-returned-variable/))
</issue_to_address>

### Comment 22
<location> `tests/endpoints/test_upload_endpoint_pvc.py:110` </location>
<code_context>
def generate_mismatched_shape_no_unique_name_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, tag):
    """Generate a payload with mismatched shapes and non-unique names."""
    model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}"
    input_data_1 = [[row_idx + col_idx * 10 for col_idx in range(n_input_cols)] for row_idx in range(n_rows)]
    mismatched_rows = n_rows - 1 if n_rows > 1 else 1
    input_data_2 = [[row_idx + col_idx * 20 for col_idx in range(n_input_cols)] for row_idx in range(mismatched_rows)]
    output_data = [[row_idx * 2 + col_idx for col_idx in range(n_output_cols)] for row_idx in range(n_rows)]
    payload = {
        "model_name": model_name,
        "data_tag": tag,
        "is_ground_truth": False,
        "request": {
            "inputs": [
                {
                    "name": "same_name",
                    "shape": [n_rows, n_input_cols],
                    "datatype": datatype,
                    "data": input_data_1,
                },
                {
                    "name": "same_name",
                    "shape": [mismatched_rows, n_input_cols],
                    "datatype": datatype,
                    "data": input_data_2,
                },
            ]
        },
        "response": {
            "outputs": [
                {
                    "name": "multi_output",
                    "shape": [n_rows, n_output_cols],
                    "datatype": datatype,
                    "data": output_data,
                }
            ]
        },
    }
    return payload

</code_context>

<issue_to_address>
**issue (code-quality):** Inline variable that is immediately returned ([`inline-immediately-returned-variable`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/inline-immediately-returned-variable/))
</issue_to_address>

### Comment 23
<location> `tests/endpoints/test_upload_endpoint_pvc.py:182-192` </location>
<code_context>
    def post_test(self, payload, expected_status_code, check_msgs):
        """Post a payload and check the response."""
        response = self.client.post("/data/upload", json=payload)
        if response.status_code != expected_status_code:
            print(f"\n=== DEBUG INFO ===")
            print(f"Expected status: {expected_status_code}")
            print(f"Actual status: {response.status_code}")
            print(f"Response text: {response.text}")
            print(f"Response headers: {dict(response.headers)}")
            if hasattr(response, "json"):
                try:
                    print(f"Response JSON: {response.json()}")
                except:
                    pass
            print("==================")

        self.assertEqual(response.status_code, expected_status_code)
        return response

</code_context>

<issue_to_address>
**issue (code-quality):** We've found these issues:

- Extract code out into method ([`extract-method`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/extract-method/))
- Use `except Exception:` rather than bare `except:` ([`do-not-use-bare-except`](https://docs.sourcery.ai/Reference/Default-Rules/suggestions/do-not-use-bare-except/))
</issue_to_address>
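
Both points could be addressed by pulling the debug dump into a helper on the test class, roughly like this (a sketch of the existing block, not new behaviour):

```python
def _print_debug_info(self, response, expected_status_code):
    """Extracted helper for the failure-diagnostics block in post_test."""
    print("\n=== DEBUG INFO ===")
    print(f"Expected status: {expected_status_code}")
    print(f"Actual status: {response.status_code}")
    print(f"Response text: {response.text}")
    print(f"Response headers: {dict(response.headers)}")
    try:
        print(f"Response JSON: {response.json()}")
    except Exception:  # narrower than a bare except: never swallows KeyboardInterrupt/SystemExit
        pass
    print("==================")
```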

### Comment 24
<location> `tests/endpoints/test_upload_endpoint_pvc.py:292-298` </location>
<code_context>
    def test_upload_multiple_tagging(self):
        """Test uploading data with multiple tags."""
        n_payload1 = 50
        n_payload2 = 51
        tag1 = "TRAINING"
        tag2 = "NOT TRAINING "
        model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}"

        payload1 = generate_payload(n_payload1, 10, 1, "INT64", tag1)
        payload1["model_name"] = model_name
        self.post_test(payload1, 200, [f"{n_payload1} datapoints"])

        payload2 = generate_payload(n_payload2, 10, 1, "INT64", tag2)
        payload2["model_name"] = model_name
        self.post_test(payload2, 200, [f"{n_payload2} datapoints"])

        tag1_count = count_rows_with_tag(model_name, tag1)
        tag2_count = count_rows_with_tag(model_name, tag2)

        self.assertEqual(tag1_count, n_payload1, f"Expected {n_payload1} rows with tag {tag1}")
        self.assertEqual(tag2_count, n_payload2, f"Expected {n_payload2} rows with tag {tag2}")

        input_rows, _, _ = asyncio.run(ModelData(payload1["model_name"]).row_counts())
        self.assertEqual(input_rows, n_payload1 + n_payload2, "Incorrect total number of rows")

</code_context>

<issue_to_address>
**issue (code-quality):** Extract duplicate code into method ([`extract-duplicate-method`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/extract-duplicate-method/))
</issue_to_address>
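
The duplicated generate-and-post steps could be factored into a small helper on the test class, for example (a sketch; the helper name is hypothetical):

```python
def _upload_tagged_payload(self, model_name, n_rows, tag):
    """Generate a payload for model_name, post it, and assert the reported datapoint count."""
    payload = generate_payload(n_rows, 10, 1, "INT64", tag)
    payload["model_name"] = model_name
    self.post_test(payload, 200, [f"{n_rows} datapoints"])
    return payload
```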

ruivieira (Member) left a comment:

Thanks @sheltoncyril!

Since this is mainly for syncing the forks, I'm fine with merging it, but the merge conflicts still need to be resolved.

INPUT_SUFFIX = "_inputs"
OUTPUT_SUFFIX = "_outputs"
METADATA_SUFFIX = "_metadata"
GROUND_TRUTH_SUFFIX = "_ground_truth"

ruivieira (Member): This will be shadowed by the GROUND_TRUTH_SUFFIX below.

for kserve_data in get_data(payload):
data.append(kserve_data.data)
shapes.add(tuple(kserve_data.data.shape))
shapes.add(tuple(kserve_data.shape))

ruivieira (Member): Should we add a guard here in case name is None?

async def consume_cloud_event(
payload: Union[KServeInferenceRequest, KServeInferenceResponse],
ce_id: Annotated[str | None, Header()] = None,
tag: str = None

ruivieira (Member): Validate with validate_data_tag()?
