2 changes: 1 addition & 1 deletion Dockerfile
@@ -20,4 +20,4 @@ EXPOSE 4443


CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8080"]
#CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "4443", "--ssl-keyfile", "/etc/tls/internal/tls.key", "--ssl-certfile", "/etc/tls/internal/tls.crt"]
#CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "4443", "--ssl-keyfile", "/etc/tls/internal/tls.key", "--ssl-certfile", "/etc/tls/internal/tls.crt"]
5 changes: 2 additions & 3 deletions README.md
@@ -1,6 +1,6 @@
# TrustyAI Service

-👋 The TrustyAI Service is intended to be a hub for all kinds of Responsible AI workflows, such as
+👋 The TrustyAI Service is intended to be a hub for all kinds of Responsible AI workflows, such as
explainability, drift, and Large Language Model (LLM) evaluation. Designed as a REST server wrapping
a core Python library, the TrustyAI service is intended to be a tool that can operate in a local
environment, a Jupyter Notebook, or in Kubernetes.
@@ -15,7 +15,7 @@ environment, a Jupyter Notebook, or in Kubernetes.
- Meanshift

### ⚖️ Fairness ⚖️
-- Statistical Parity Difference
+- Statistical Parity Difference
- Disparate Impact Ratio
- Average Odds Ratio (WIP)
- Average Predictive Value Difference (WIP)
@@ -67,4 +67,3 @@ podman run -t $IMAGE_NAME -p 8080:8080 .
---
## ☎️ API ☎️
When the service is running, visit `localhost:8080/docs` to see the OpenAPI documentation!

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -30,7 +30,7 @@ dev = [
"httpx>=0.25.0,<0.26",
]
eval = [
"lm-eval[api]==0.4.4",
"lm-eval[api]==0.4.8",
"fastapi-utils>=0.8.0",
"typing-inspect==0.9.0",
]
53 changes: 30 additions & 23 deletions src/endpoints/consumer/consumer_endpoint.py
@@ -68,25 +68,27 @@ async def consume_inference_payload(payload: InferencePartialPayload):
return {"status": "success", "message": "Payload processed successfully"}
except Exception as e:
logger.error(f"Error processing inference payload: {str(e)}")
-raise HTTPException(
-status_code=500, detail=f"Error processing payload: {str(e)}"
-)
+raise HTTPException(status_code=500, detail=f"Error processing payload: {str(e)}")


def reconcile_mismatching_shape_error(shape_tuples, payload_type, payload_id):
msg = (f"Could not reconcile KServe Inference {payload_id}, because {payload_type} shapes were mismatched. "
f"When using multiple {payload_type}s to describe data columns, all shapes must match."
f"However, the following tensor shapes were found:")
msg = (
f"Could not reconcile KServe Inference {payload_id}, because {payload_type} shapes were mismatched. "
f"When using multiple {payload_type}s to describe data columns, all shapes must match."
f"However, the following tensor shapes were found:"
)
for i, (name, shape) in enumerate(shape_tuples):
msg += f"\n{i}:\t{name}:\t{shape}"
logger.error(msg)
raise HTTPException(status_code=400, detail=msg)


def reconcile_mismatching_row_count_error(payload_id, input_shape, output_shape):
msg = (f"Could not reconcile KServe Inference {payload_id}, because the number of "
f"output rows ({output_shape}) did not match the number of input rows "
f"({input_shape}).")
msg = (
f"Could not reconcile KServe Inference {payload_id}, because the number of "
f"output rows ({output_shape}) did not match the number of input rows "
f"({input_shape})."
)
logger.error(msg)
raise HTTPException(status_code=400, detail=msg)

@@ -112,9 +114,7 @@ def process_payload(payload, get_data: Callable, enforced_first_shape: int = Non
return np.array(data).T, column_names
else:
reconcile_mismatching_shape_error(
-shape_tuples,
-"input" if enforced_first_shape is None else "output",
-payload.id
+shape_tuples, "input" if enforced_first_shape is None else "output", payload.id
)
else: # single tensor case: we have one tensor of shape [nrows, d1, d2, ...., dN]
kserve_data: KServeData = get_data(payload)[0]
@@ -153,18 +153,23 @@ async def reconcile(input_payload: KServeInferenceRequest, output_payload: KServ
tg.create_task(storage_inferface.write_data(output_dataset, output_array, output_names))
tg.create_task(storage_inferface.write_data(metadata_dataset, metadata, metadata_names))

-shapes = await (ModelData(output_payload.model_name).shapes())
-logger.info(f"Successfully reconciled KServe inference {input_payload.id}, "
-f"consisting of {input_array.shape[0]:,} rows from {output_payload.model_name}.")
-logger.debug(f"Current storage shapes for {output_payload.model_name}: "
-f"Inputs={shapes[0]}, "
-f"Outputs={shapes[1]}, "
-f"Metadata={shapes[2]}")
+shapes = await ModelData(output_payload.model_name).shapes()
+logger.info(
+f"Successfully reconciled KServe inference {input_payload.id}, "
+f"consisting of {input_array.shape[0]:,} rows from {output_payload.model_name}."
+)
+logger.debug(
+f"Current storage shapes for {output_payload.model_name}: "
+f"Inputs={shapes[0]}, "
+f"Outputs={shapes[1]}, "
+f"Metadata={shapes[2]}"
+)


@router.post("/")
-async def consume_cloud_event(payload: Union[KServeInferenceRequest, KServeInferenceResponse],
-ce_id: Annotated[str | None, Header()] = None):
+async def consume_cloud_event(
+payload: Union[KServeInferenceRequest, KServeInferenceResponse], ce_id: Annotated[str | None, Header()] = None
+):
# set payload id from cloud event header
payload.id = ce_id

@@ -185,8 +190,10 @@ async def consume_cloud_event(payload: Union[KServeInferenceRequest, KServeInfer

elif isinstance(payload, KServeInferenceResponse):
if len(payload.outputs) == 0:
msg = (f"KServe Inference Output {payload.id} received from model={payload.model_name}, "
f"but data field was empty. Payload will not be saved.")
msg = (
f"KServe Inference Output {payload.id} received from model={payload.model_name}, "
f"but data field was empty. Payload will not be saved."
)
logger.error(msg)
raise HTTPException(status_code=400, detail=msg)
else:
49 changes: 22 additions & 27 deletions src/endpoints/evaluation/lm_evaluation_harness.py
@@ -13,7 +13,9 @@
from lm_eval.__main__ import setup_parser as lm_eval_setup_parser
from fastapi_utils.tasks import repeat_every
except ImportError:
raise ImportError("The TrustyAI service was not built with LM-Evaluation-Harness support, use `pip install .[eval]`")
raise ImportError(
"The TrustyAI service was not built with LM-Evaluation-Harness support, use `pip install .[eval]`"
)

from pydantic import BaseModel, create_model

@@ -26,7 +28,7 @@

router = APIRouter()
logger = logging.getLogger(__name__)
-API_PREFIX= "/eval/lm-evaluation-harness"
+API_PREFIX = "/eval/lm-evaluation-harness"


# === STATIC API OBJECTS ===========================================================================
@@ -37,6 +39,7 @@ class JobStatus(Enum):
QUEUED = "Queued"
STOPPED = "Stopped"


class LMEvalJobSummary(BaseModel):
job_id: int
argument: str
@@ -56,18 +59,16 @@ class AllLMEvalJobs(BaseModel):


# === Dynamic API Object from LM-Eval CLI ==========================================================
-NON_CLI_ARGUMENTS = {
-"env_vars": (Dict[str, str], {}),
-"lm_eval_path": (str, f"{sys.executable} -m lm_eval")
-}
+NON_CLI_ARGUMENTS = {"env_vars": (Dict[str, str], {}), "lm_eval_path": (str, f"{sys.executable} -m lm_eval")}


def get_lm_eval_arguments():
"""Grab all fields from an argparse specification into a dictionary"""
parser = lm_eval_setup_parser() # grab lm-eval argparse specification

args = {}
for action in parser._positionals._actions:
arg = {"cli": action.option_strings[0], "argparse_type":action.__class__.__name__}
arg = {"cli": action.option_strings[0], "argparse_type": action.__class__.__name__}
if action.__class__.__name__ == "_StoreTrueAction":
arg["type"] = bool
arg["default"] = False
@@ -86,10 +87,11 @@ def get_lm_eval_arguments():
def get_model():
"""Build a Pydantic model from the lm-eval argparse arguments, adding in a few config variables of our own as well"""
args = get_lm_eval_arguments()
-model_args = {k:(v['type'],v['default']) for k,v in args.items()}
+model_args = {k: (v["type"], v["default"]) for k, v in args.items()}
model_args.update(NON_CLI_ARGUMENTS)
return create_model("LMEvalRequest", **model_args)


# Dynamically create the lm-eval-harness job request from the library's argparse
LMEvalRequest = get_model()

@@ -147,8 +149,8 @@ def convert_to_cli(request: LMEvalRequest):

cli_cmd += " "
arg = args[field]
-if arg['argparse_type'] in {"_StoreTrueAction", "_StoreFalseAction"}:
-cli_cmd += args[field]['cli']
+if arg["argparse_type"] in {"_StoreTrueAction", "_StoreFalseAction"}:
+cli_cmd += args[field]["cli"]
else:
field_value = getattr(request, field)
field_value = shlex.quote(field_value) if isinstance(field_value, str) else field_value
@@ -216,10 +218,7 @@ def _launch_job(job: LMEvalJob):
os.set_blocking(p.stderr.fileno(), False)

# register the subprocess in the global registry
-job_registry[job.job_id].mark_launch(
-process=p,
-start_time=datetime.datetime.now(datetime.timezone.utc).isoformat()
-)
+job_registry[job.job_id].mark_launch(process=p, start_time=datetime.datetime.now(datetime.timezone.utc).isoformat())


# === ROUTER =======================================================================================
@@ -244,19 +243,15 @@ def lm_eval_job(request: LMEvalRequest):

# store job
job_id = _generate_job_id()
-queued_job = LMEvalJob(
-job_id=job_id,
-request=request,
-argument=cli_cmd
-)
+queued_job = LMEvalJob(job_id=job_id, request=request, argument=cli_cmd)
job_queue.put(job_id)
job_registry[job_id] = queued_job

return {"status": "success", "message": f"Job {job_id} successfully queued.", "job_id": job_id}


# === METADATA =====================================================================================
@router.get(API_PREFIX+"/jobs", summary="List all running jobs")
@router.get(API_PREFIX + "/jobs", summary="List all running jobs")
def list_running_lm_eval_jobs(include_finished: bool = True) -> AllLMEvalJobs:
"""Provide a list of all lm-evaluation-harness jobs with attached summary information"""

@@ -271,7 +266,7 @@ def list_running_lm_eval_jobs(include_finished: bool = True) -> AllLMEvalJobs:
return AllLMEvalJobs(jobs=jobs)


@router.get(API_PREFIX+"/job/{job_id}", summary="Get information about a specific job")
@router.get(API_PREFIX + "/job/{job_id}", summary="Get information about a specific job")
def check_lm_eval_job(job_id: int) -> LMEvalJobDetail:
"""Get detailed report of an lm-evaluation-harness job by ID"""

@@ -301,12 +296,12 @@ def check_lm_eval_job(job_id: int) -> LMEvalJobDetail:
exit_code=status_code,
inference_progress_pct=job.progress,
stdout=job.cumulative_out,
-stderr=job.cumulative_err
+stderr=job.cumulative_err,
)


# === DELETE DATA ==================================================================================
@router.delete(API_PREFIX+"/job/{id}", summary="Delete an lm-evaluation-harness job's data from the server.")
@router.delete(API_PREFIX + "/job/{id}", summary="Delete an lm-evaluation-harness job's data from the server.")
def delete_lm_eval_job(job_id: int):
"""Delete an lm-evaluation-harness job's data from the server by ID, terminating the job if it's still running"""
if job_id not in job_registry:
@@ -317,7 +312,7 @@ def delete_lm_eval_job(job_id: int):
return {"status": "success", "message": f"Job {job_id} deleted successfully."}


@router.delete(API_PREFIX+"/jobs", summary="Delete data from all lm-evaluation-harness jobs from the server.")
@router.delete(API_PREFIX + "/jobs", summary="Delete data from all lm-evaluation-harness jobs from the server.")
def delete_all_lm_eval_job():
"""Delete data from all lm-evaluation-harness job's data from the server, terminating any job that its still running"""

@@ -330,7 +325,7 @@ def delete_all_lm_eval_job():


# === STOP JOBS ====================================================================================
@router.get(API_PREFIX+"/job/{job_id}/dequeue", summary="Stop a running lm-evaluation-harness job.")
@router.get(API_PREFIX + "/job/{job_id}/dequeue", summary="Stop a running lm-evaluation-harness job.")
def stop_lm_eval_job(job_id: int):
"""Stop an lm-evaluation-harness job by ID"""

@@ -350,7 +345,7 @@ def stop_lm_eval_job(job_id: int):
return {"status": "success", "message": f"Job {job_id} has already completed."}


@router.get(API_PREFIX+"/jobs/dequeue", summary="Stop all running lm-evaluation-harness jobs.")
@router.get(API_PREFIX + "/jobs/dequeue", summary="Stop all running lm-evaluation-harness jobs.")
def stop_all_lm_eval_job():
"""Stop all lm-evaluation-harness jobs"""

@@ -359,4 +354,4 @@ def stop_all_lm_eval_job():
stop_lm_eval_job(job_id)
stopped.append(job_id)

return {"status": "success", "message": f"Jobs {stopped} stopped successfully."}
return {"status": "success", "message": f"Jobs {stopped} stopped successfully."}
16 changes: 4 additions & 12 deletions src/endpoints/explainers/global_explainer.py
@@ -21,28 +21,20 @@ class GlobalExplanationRequest(BaseModel):
async def global_lime_explanation(request: GlobalExplanationRequest):
"""Compute a global LIME explanation."""
try:
-logger.info(
-f"Computing global LIME explanation for model: {request.modelConfig.name}"
-)
+logger.info(f"Computing global LIME explanation for model: {request.modelConfig.name}")
# TODO: Implement
except Exception as e:
logger.error(f"Error computing global LIME explanation: {str(e)}")
-raise HTTPException(
-status_code=500, detail=f"Error computing explanation: {str(e)}"
-)
+raise HTTPException(status_code=500, detail=f"Error computing explanation: {str(e)}")


@router.post("/explainers/global/pdp")
async def global_pdp_explanation(request: GlobalExplanationRequest):
"""Compute a global PDP explanation."""
try:
-logger.info(
-f"Computing global PDP explanation for model: {request.modelConfig.name}"
-)
+logger.info(f"Computing global PDP explanation for model: {request.modelConfig.name}")
# TODO: Implement
return {"status": "success", "explanation": {}}
except Exception as e:
logger.error(f"Error computing global PDP explanation: {str(e)}")
-raise HTTPException(
-status_code=500, detail=f"Error computing explanation: {str(e)}"
-)
+raise HTTPException(status_code=500, detail=f"Error computing explanation: {str(e)}")