Skip to content

Commit dababa9

Browse files
Add Dataset Container for ML workflows (#158)
* Sample models training scripts * gitignore * Add bundle to df converter and update fhir helpers * Add dataset class and mappers * Update tests * Add example feature schema * Use config instead of params for bundle conversion * Refactor io/ module * Consolidate tests * Refactor fhir helper module for clearer separation of utils * Add loading as dict for ml workflows to MIMIC loader * Fix import * Update docs
1 parent f74850f commit dababa9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+5935
-819
lines changed
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Sepsis Prediction Inference Script
4+
5+
Demonstrates how to load and use the trained sepsis prediction model.
6+
7+
Requirements:
8+
- pip install scikit-learn xgboost joblib pandas numpy
9+
10+
Usage:
11+
- python sepsis_prediction_inference.py
12+
"""
13+
14+
import pandas as pd
15+
import numpy as np
16+
from pathlib import Path
17+
from typing import Dict, Union, Tuple
18+
import joblib
19+
20+
21+
def load_model(model_path: Union[str, Path]) -> Dict:
    """
    Read a saved sepsis model bundle from disk and print a short summary of it.

    Args:
        model_path: Location of the file written with joblib

    Returns:
        The loaded dictionary, unchanged. This function reads its "metadata"
        entry; other functions in this script also read "model" and "scaler".
    """
    print(f"Loading model from {model_path}...")
    # NOTE: joblib.load unpickles the file, so only load bundles from a trusted source.
    bundle = joblib.load(model_path)

    info = bundle["metadata"]
    print(f" Model: {info['model_name']}")
    print(f" Training date: {info['training_date']}")
    print(f" Features: {', '.join(info['feature_names'])}")

    # Evaluation scores stored alongside the model at training time.
    scores = info["metrics"]
    print(f" Test F1-score: {scores['f1']:.4f}")
    print(f" Test AUC-ROC: {scores['auc']:.4f}")

    # The tuned decision threshold is optional; report it only when it was saved.
    if "optimal_threshold" in scores:
        print(f" Optimal threshold: {scores['optimal_threshold']:.4f}")
        print(f" Optimal F1-score: {scores['optimal_f1']:.4f}")

    return bundle
46+
47+
48+
def predict_sepsis(
    model_data: Dict, patient_features: pd.DataFrame, use_optimal_threshold: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Score patients with the loaded model and convert the scores to 0/1 labels.

    Args:
        model_data: Dictionary with "model", "scaler" and "metadata" entries
        patient_features: DataFrame containing at least the columns named in
            metadata["feature_names"]
        use_optimal_threshold: When True and metadata["metrics"] contains
            "optimal_threshold", use it as the cut-off; otherwise 0.5 is used
            (default: True)

    Returns:
        Tuple of (predictions, probabilities): probabilities is column 1 of
        model.predict_proba, predictions is 1 where probability >= threshold
    """
    estimator = model_data["model"]
    scaler = model_data["scaler"]
    metadata = model_data["metadata"]

    # Select and reorder the columns to match the order stored in the metadata.
    ordered = patient_features[metadata["feature_names"]]

    # A scaler of None means the features go to the model as-is
    # (the original note says scaling applies to Logistic Regression).
    model_input = ordered if scaler is None else scaler.transform(ordered)
    probabilities = estimator.predict_proba(model_input)[:, 1]

    # Default cut-off, overridden by the stored threshold when requested and present.
    threshold = 0.5
    if use_optimal_threshold:
        metrics = metadata["metrics"]
        if "optimal_threshold" in metrics:
            threshold = metrics["optimal_threshold"]

    predictions = (probabilities >= threshold).astype(int)
    return predictions, probabilities
86+
87+
88+
def create_example_patients() -> pd.DataFrame:
    """
    Build a small, hard-coded table of five demonstration patients.

    Returns:
        DataFrame with one row per patient and one column per model feature
    """
    # Rows, in order:
    #   Patient 1: Healthy patient (low risk)
    #   Patient 2: Moderate risk (some abnormal values)
    #   Patient 3: Low risk (normal values)
    #   Patient 4: High risk for sepsis (multiple severe abnormalities)
    #   Patient 5: Critical sepsis risk (severe multi-organ dysfunction)
    feature_columns = {
        # beats/min (normal: 60-100)
        "heart_rate": [85, 110, 75, 130, 145],
        # Celsius (normal: 36.5-37.5, hypothermia <36)
        "temperature": [37.2, 38.5, 36.8, 39.2, 35.5],
        # breaths/min (normal: 12-20)
        "respiratory_rate": [16, 24, 14, 30, 35],
        # x10^9/L (normal: 4-11)
        "wbc": [8.5, 15.2, 7.0, 18.5, 22.0],
        # mmol/L (normal: <2, severe sepsis: >4)
        "lactate": [1.2, 3.5, 0.9, 4.8, 6.5],
        # mg/dL (normal: 0.6-1.2)
        "creatinine": [0.9, 1.8, 0.8, 2.5, 3.2],
        # years
        "age": [45, 68, 35, 72, 78],
        # 1=Male, 0=Female
        "gender_encoded": [1, 0, 1, 1, 0],
    }
    return pd.DataFrame(feature_columns)
127+
128+
129+
def interpret_results(
    predictions: np.ndarray,
    probabilities: np.ndarray,
    patient_features: pd.DataFrame,
    *,
    high_risk_cutoff: float = 0.7,
    moderate_risk_cutoff: float = 0.4,
) -> None:
    """
    Print a per-patient report of the prediction results.

    Args:
        predictions: Binary predictions (0=no sepsis, 1=sepsis)
        probabilities: Probability scores, one per prediction
        patient_features: Original patient features, one row per prediction;
            must contain heart_rate, temperature, respiratory_rate, wbc,
            lactate and creatinine columns
        high_risk_cutoff: Probabilities at or above this are reported as HIGH
            (default: 0.7)
        moderate_risk_cutoff: Probabilities at or above this, but below
            high_risk_cutoff, are reported as MODERATE (default: 0.4)

    Raises:
        ValueError: If the three inputs do not all describe the same number
            of patients
    """
    # Fail before printing anything: mismatched lengths would otherwise
    # produce a truncated or partially printed report.
    n_patients = len(predictions)
    if len(probabilities) != n_patients or len(patient_features) != n_patients:
        raise ValueError(
            "predictions, probabilities and patient_features must have the same "
            f"length (got {n_patients}, {len(probabilities)}, {len(patient_features)})"
        )

    print("\n" + "=" * 80)
    print("SEPSIS PREDICTION RESULTS")
    print("=" * 80)

    for position, (prediction, probability) in enumerate(zip(predictions, probabilities)):
        # Fetch the row once instead of re-indexing the frame for every field.
        row = patient_features.iloc[position]

        print(f"\nPatient {position + 1}:")
        print(f" Risk Score: {probability:.2%}")
        print(f" Prediction: {'SEPSIS RISK' if prediction == 1 else 'Low Risk'}")

        # Show key vital signs
        print(" Key Features:")
        print(f" Heart Rate: {row['heart_rate']:.1f} bpm")
        print(f" Temperature: {row['temperature']:.1f}°C")
        print(f" Respiratory Rate: {row['respiratory_rate']:.1f} /min")
        print(f" WBC: {row['wbc']:.1f} x10^9/L")
        print(f" Lactate: {row['lactate']:.1f} mmol/L")
        print(f" Creatinine: {row['creatinine']:.2f} mg/dL")

        # Risk band is derived from the raw probability, independently of the
        # threshold that produced the binary prediction above.
        if probability >= high_risk_cutoff:
            risk_level = "HIGH"
        elif probability >= moderate_risk_cutoff:
            risk_level = "MODERATE"
        else:
            risk_level = "LOW"

        print(f" Clinical Interpretation: {risk_level} RISK")

    print("\n" + "=" * 80)
171+
172+
173+
def main():
    """Run the demo end to end: load the model, score example patients, print a report."""
    rule = "=" * 80

    # The model file is looked up relative to this script, under models/.
    model_path = Path(__file__).parent / "models" / "sepsis_model.pkl"

    print(rule)
    print("Sepsis Prediction Inference")
    print(rule + "\n")

    # Step 1: load the saved model bundle
    model_data = load_model(model_path)

    # Step 2: build the demonstration inputs
    print("\nCreating example patient data...")
    patient_features = create_example_patients()
    print(f"Number of patients: {len(patient_features)}")

    # Step 3: score the patients
    print("\nMaking predictions...")
    predictions, probabilities = predict_sepsis(
        model_data,
        patient_features,
        use_optimal_threshold=True,
    )

    # Step 4: print the per-patient report
    interpret_results(predictions, probabilities, patient_features)

    print("\n" + rule)
    print("Inference complete!")
    print(rule)
203+
204+
205+
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)