
Commit ac4fb49

Merge pull request #62 from ruivieira/fix-metrics
fix: Metrics consumers, HTTPS ports, global reconciler
2 parents: 0c0dc69 + 4b1a778

File tree

15 files changed: +1064 −113 lines

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ dependencies = [
     "pandas>=2.2.3,<3",
     "prometheus-client>=0.21.1,<0.24",
     "pydantic>=2.4.2,<3",
-    "uvicorn>=0.34.0,<0.39",
+    "hypercorn>=0.17.0,<0.19",
     "protobuf>=4.24.4,<7",
     "requests>=2.31.0,<3",
     "cryptography>=44.0.2,<47",

src/endpoints/consumer/consumer_endpoint.py

Lines changed: 36 additions & 0 deletions
@@ -14,6 +14,8 @@
 from src.service.data.storage import get_storage_interface
 from src.service.utils import list_utils
 from src.service.data.modelmesh_parser import ModelMeshPayloadParser, PartialPayload
+from src.service.data.datasources.data_source import DataSource
+from src.service.data.shared_data_source import get_shared_data_source

 # Define constants locally to avoid import issues
 INPUT_SUFFIX = "_inputs"
@@ -31,6 +33,10 @@
 unreconciled_inputs = {}
 unreconciled_outputs = {}

+def get_data_source():
+    """Get the shared data source instance."""
+    return get_shared_data_source()
+

 class PartialPayloadId(BaseModel):
     prediction_id: Optional[str] = None
@@ -236,6 +242,21 @@ async def reconcile_modelmesh_payloads(
         f"Current storage shapes for {model_id}: Inputs={shapes[0]}, Outputs={shapes[1]}, Metadata={shapes[2]}"
     )

+    # Add model to known models set so it can be discovered by the scheduler
+    data_source = get_data_source()
+    await data_source.add_model_to_known(model_id)
+    known_models = await data_source.get_known_models()
+    logger.info(f"Added model {model_id} to known models set. Current known models: {list(known_models)}")
+    logger.debug(f"DataSource instance id: {id(data_source)}")
+
+    # Mark that inference data has been recorded for this model
+    try:
+        metadata = await data_source.get_metadata(model_id)
+        metadata.set_recorded_inferences(True)
+        logger.info(f"Marked model {model_id} as having recorded inferences")
+    except Exception as e:
+        logger.warning(f"Could not update recorded_inferences flag for model {model_id}: {e}")
+
     # Clean up
     await storage_interface.delete_modelmesh_payload(request_id, True)
     await storage_interface.delete_modelmesh_payload(request_id, False)
@@ -338,6 +359,21 @@ async def reconcile(input_payload: KServeInferenceRequest, output_payload: KServ
         f"Metadata={shapes[2]}"
     )

+    # Add model to known models set so it can be discovered by the scheduler
+    data_source = get_data_source()
+    await data_source.add_model_to_known(output_payload.model_name)
+    known_models = await data_source.get_known_models()
+    logger.info(f"Added model {output_payload.model_name} to known models set. Current known models: {list(known_models)}")
+    logger.debug(f"DataSource instance id: {id(data_source)}")
+
+    # Mark that inference data has been recorded for this model
+    try:
+        metadata = await data_source.get_metadata(output_payload.model_name)
+        metadata.set_recorded_inferences(True)
+        logger.info(f"Marked model {output_payload.model_name} as having recorded inferences")
+    except Exception as e:
+        logger.warning(f"Could not update recorded_inferences flag for model {output_payload.model_name}: {e}")
+

 @router.post("/")
 async def consume_cloud_event(
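Both reconciliation paths now register models through get_shared_data_source() instead of constructing their own DataSource, so the consumer and the metrics scheduler see the same set of known models (the "global reconciler" of the commit title). A plausible minimal sketch of such a module follows, assuming a module-level singleton and a no-argument DataSource constructor; the actual src/service/data/shared_data_source.py is not shown in this diff.

# Hypothetical sketch of src/service/data/shared_data_source.py.
# Assumes DataSource() takes no arguments; the real module may differ.
from typing import Optional

from src.service.data.datasources.data_source import DataSource

_shared_data_source: Optional[DataSource] = None


def get_shared_data_source() -> DataSource:
    """Return the process-wide DataSource, creating it on first use."""
    global _shared_data_source
    if _shared_data_source is None:
        _shared_data_source = DataSource()
    return _shared_data_source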

src/endpoints/metadata.py

Lines changed: 233 additions & 11 deletions
@@ -3,9 +3,24 @@
 from typing import Dict, List
 import logging

+from src.service.data.storage import get_storage_interface
+from src.service.data.shared_data_source import get_shared_data_source
+from src.service.prometheus.shared_prometheus_scheduler import get_shared_prometheus_scheduler
+from src.service.constants import INPUT_SUFFIX, OUTPUT_SUFFIX
+
 router = APIRouter()
 logger = logging.getLogger(__name__)

+storage_interface = get_storage_interface()
+
+def get_data_source():
+    """Get the shared data source instance."""
+    return get_shared_data_source()
+
+def get_prometheus_scheduler():
+    """Get the shared prometheus scheduler instance."""
+    return get_shared_prometheus_scheduler()
+

 class NameMapping(BaseModel):
     modelId: str
@@ -18,15 +33,88 @@ class DataTagging(BaseModel):
     dataTagging: Dict[str, List[List[int]]] = {}


+class ModelIdRequest(BaseModel):
+    modelId: str
+
+
 @router.get("/info")
 async def get_service_info():
-    """Get a list of all inference ids within a particular model inference."""
+    """Get a comprehensive overview of the model inference datasets collected by TrustyAI and the metric computations that are scheduled over those datasets."""
     try:
-        # TODO: Implement
-        return {"models": [], "metrics": [], "version": "1.0.0rc0"}
+        logger.info("Retrieving service info")
+
+        # Get all known models from shared data source
+        data_source = get_data_source()
+        known_models = await data_source.get_known_models()
+        logger.info(f"DataSource instance id: {id(data_source)}")
+        logger.info(f"Found {len(known_models)} known models: {list(known_models)}")
+
+        service_metadata = {}
+
+        for model_id in known_models:
+            try:
+                # Get metadata for each model
+                model_metadata = await data_source.get_metadata(model_id)
+                num_observations = await data_source.get_num_observations(model_id)
+                has_inferences = await data_source.has_recorded_inferences(model_id)
+
+                # Get scheduled metrics for this model
+                scheduled_metadata = {}
+                try:
+                    scheduler = get_prometheus_scheduler()
+                    if scheduler:
+                        # Get all metric types and count scheduled requests per model
+                        all_requests = scheduler.get_all_requests()  # Should return dict of metric_name -> {request_id -> request}
+                        for metric_name, requests_dict in all_requests.items():
+                            count = 0
+                            for request_id, request in requests_dict.items():
+                                # Check if request is for this model (defensive access)
+                                request_model_id = getattr(request, 'model_id', getattr(request, 'modelId', None))
+                                if request_model_id == model_id:
+                                    count += 1
+                            if count > 0:
+                                scheduled_metadata[metric_name] = count
+                        logger.debug(f"Found {len(scheduled_metadata)} scheduled metric types for model {model_id}")
+                except Exception as e:
+                    logger.warning(f"Error retrieving scheduled metrics for model {model_id}: {e}")
+
+                # Transform to match expected format
+                service_metadata[model_id] = {
+                    "data": {
+                        "observations": num_observations,
+                        "hasRecordedInferences": has_inferences,
+                        "inputTensorName": model_metadata.input_tensor_name if model_metadata else "input",
+                        "outputTensorName": model_metadata.output_tensor_name if model_metadata else "output"
+                    },
+                    "metrics": {
+                        "scheduledMetadata": scheduled_metadata
+                    }
+                }
+
+                logger.debug(f"Retrieved metadata for model {model_id}: observations={num_observations}, hasInferences={has_inferences}")
+
+            except Exception as e:
+                logger.warning(f"Error retrieving metadata for model {model_id}: {e}")
+                # Still include the model in the response but with basic info
+                service_metadata[model_id] = {
+                    "data": {
+                        "observations": 0,
+                        "hasRecordedInferences": False,
+                        "inputTensorName": "input",
+                        "outputTensorName": "output"
+                    },
+                    "metrics": {"scheduledMetadata": {}},
+                    "error": str(e)
+                }
+
+        logger.info(f"Successfully retrieved service info for {len(service_metadata)} models")
+        return service_metadata
+
     except Exception as e:
         logger.error(f"Error retrieving service info: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Error retrieving service info: {str(e)}")
+        raise HTTPException(
+            status_code=500, detail=f"Error retrieving service info: {str(e)}"
+        ) from e


 @router.get("/info/inference/ids/{model}")
@@ -41,28 +129,162 @@ async def get_inference_ids(model: str, type: str = "all"):
         raise HTTPException(status_code=500, detail=f"Error retrieving inference IDs: {str(e)}")


+@router.get("/info/names")
+async def get_column_names():
+    """Get the current name mappings for all models."""
+    try:
+        logger.info("Retrieving name mappings for all models")
+
+        # Get all known models from shared data source
+        data_source = get_data_source()
+        known_models = await data_source.get_known_models()
+        logger.info(f"Found {len(known_models)} known models: {list(known_models)}")
+
+        name_mappings = {}
+
+        for model_id in known_models:
+            try:
+                input_dataset_name = model_id + INPUT_SUFFIX
+                output_dataset_name = model_id + OUTPUT_SUFFIX
+
+                input_exists = await storage_interface.dataset_exists(input_dataset_name)
+                output_exists = await storage_interface.dataset_exists(output_dataset_name)
+
+                model_mappings = {
+                    "modelId": model_id,
+                    "inputMapping": {},
+                    "outputMapping": {}
+                }
+
+                # Get input name mappings
+                if input_exists:
+                    try:
+                        original_input_names = await storage_interface.get_original_column_names(input_dataset_name)
+                        aliased_input_names = await storage_interface.get_aliased_column_names(input_dataset_name)
+
+                        if original_input_names is not None and aliased_input_names is not None:
+                            # Create mapping from original to aliased names
+                            input_mapping = {}
+                            for orig, alias in zip(list(original_input_names), list(aliased_input_names)):
+                                if orig != alias:  # Only include if there's an actual mapping
+                                    input_mapping[orig] = alias
+                            model_mappings["inputMapping"] = input_mapping
+
+                    except Exception as e:
+                        logger.warning(f"Error getting input name mappings for {model_id}: {e}")
+
+                # Get output name mappings
+                if output_exists:
+                    try:
+                        original_output_names = await storage_interface.get_original_column_names(output_dataset_name)
+                        aliased_output_names = await storage_interface.get_aliased_column_names(output_dataset_name)
+
+                        if original_output_names is not None and aliased_output_names is not None:
+                            # Create mapping from original to aliased names
+                            output_mapping = {}
+                            for orig, alias in zip(list(original_output_names), list(aliased_output_names)):
+                                if orig != alias:  # Only include if there's an actual mapping
+                                    output_mapping[orig] = alias
+                            model_mappings["outputMapping"] = output_mapping
+
+                    except Exception as e:
+                        logger.warning(f"Error getting output name mappings for {model_id}: {e}")
+
+                name_mappings[model_id] = model_mappings
+
+            except Exception as e:
+                logger.warning(f"Error getting name mappings for model {model_id}: {e}")
+
+        logger.info(f"Successfully retrieved name mappings for {len(name_mappings)} models")
+        return name_mappings
+
+    except Exception as e:
+        logger.error(f"Error retrieving name mappings: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Error retrieving name mappings: {str(e)}")
+
+
 @router.post("/info/names")
 async def apply_column_names(name_mapping: NameMapping):
     """Apply a set of human-readable column names to a particular inference."""
     try:
         logger.info(f"Applying column names for model: {name_mapping.modelId}")
-        # TODO: Implement
-        return {"status": "success", "message": "Column names applied successfully"}
+
+        model_id = name_mapping.modelId
+        input_dataset_name = model_id + INPUT_SUFFIX
+        output_dataset_name = model_id + OUTPUT_SUFFIX
+
+        # Check if the model datasets exist
+        input_exists = await storage_interface.dataset_exists(input_dataset_name)
+        output_exists = await storage_interface.dataset_exists(output_dataset_name)
+
+        if not input_exists and not output_exists:
+            error_msg = f"No metadata found for model={model_id}. This can happen if TrustyAI has not yet logged any inferences from this model."
+            logger.error(error_msg)
+            raise HTTPException(status_code=400, detail=error_msg)
+
+        # Apply input mappings if provided and dataset exists
+        if name_mapping.inputMapping and input_exists:
+            logger.info(f"Applying input mappings for model {model_id}: {name_mapping.inputMapping}")
+            await storage_interface.apply_name_mapping(input_dataset_name, name_mapping.inputMapping)
+
+        # Apply output mappings if provided and dataset exists
+        if name_mapping.outputMapping and output_exists:
+            logger.info(f"Applying output mappings for model {model_id}: {name_mapping.outputMapping}")
+            await storage_interface.apply_name_mapping(output_dataset_name, name_mapping.outputMapping)
+
+        logger.info(f"Name mappings successfully applied to model={model_id}")
+        return {"message": "Feature and output name mapping successfully applied."}
+
+    except HTTPException:
+        # Re-raise HTTP exceptions without wrapping
+        raise
     except Exception as e:
         logger.error(f"Error applying column names: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Error applying column names: {str(e)}")
+        raise HTTPException(
+            status_code=500, detail=f"Error applying column names: {str(e)}"
+        ) from e


 @router.delete("/info/names")
-async def remove_column_names(model_id: str):
+async def remove_column_names(request: ModelIdRequest):
     """Remove any column names that have been applied to a particular inference."""
     try:
+        model_id = request.modelId
         logger.info(f"Removing column names for model: {model_id}")
-        # TODO: Implement
-        return {"status": "success", "message": "Column names removed successfully"}
+
+        input_dataset_name = model_id + INPUT_SUFFIX
+        output_dataset_name = model_id + OUTPUT_SUFFIX
+
+        # Check if the model datasets exist
+        input_exists = await storage_interface.dataset_exists(input_dataset_name)
+        output_exists = await storage_interface.dataset_exists(output_dataset_name)
+
+        if not input_exists and not output_exists:
+            error_msg = f"No metadata found for model={model_id}. This can happen if TrustyAI has not yet logged any inferences from this model."
+            logger.error(error_msg)
+            raise HTTPException(status_code=400, detail=error_msg)
+
+        # Clear name mappings from input dataset if it exists
+        if input_exists:
+            logger.info(f"Clearing input name mappings for model {model_id}")
+            await storage_interface.clear_name_mapping(input_dataset_name)
+
+        # Clear name mappings from output dataset if it exists
+        if output_exists:
+            logger.info(f"Clearing output name mappings for model {model_id}")
+            await storage_interface.clear_name_mapping(output_dataset_name)
+
+        logger.info(f"Name mappings successfully cleared from model={model_id}")
+        return {"message": "Feature and output name mapping successfully cleared."}
+
+    except HTTPException:
+        # Re-raise HTTP exceptions without wrapping
+        raise
     except Exception as e:
         logger.error(f"Error removing column names: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Error removing column names: {str(e)}")
+        raise HTTPException(
+            status_code=500, detail=f"Error removing column names: {str(e)}"
        ) from e


 @router.get("/info/tags")
