diff --git a/cookbook/sepsis_prediction_inference.py b/cookbook/sepsis_prediction_inference.py new file mode 100644 index 00000000..33edb858 --- /dev/null +++ b/cookbook/sepsis_prediction_inference.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 +""" +Sepsis Prediction Inference Script + +Demonstrates how to load and use the trained sepsis prediction model. + +Requirements: +- pip install scikit-learn xgboost joblib pandas numpy + +Usage: +- python sepsis_prediction_inference.py +""" + +import pandas as pd +import numpy as np +from pathlib import Path +from typing import Dict, Union, Tuple +import joblib + + +def load_model(model_path: Union[str, Path]) -> Dict: + """ + Load trained sepsis prediction model. + + Args: + model_path: Path to saved model file + + Returns: + Dictionary containing model, scaler, and metadata + """ + print(f"Loading model from {model_path}...") + model_data = joblib.load(model_path) + + metadata = model_data["metadata"] + print(f" Model: {metadata['model_name']}") + print(f" Training date: {metadata['training_date']}") + print(f" Features: {', '.join(metadata['feature_names'])}") + print(f" Test F1-score: {metadata['metrics']['f1']:.4f}") + print(f" Test AUC-ROC: {metadata['metrics']['auc']:.4f}") + + if "optimal_threshold" in metadata["metrics"]: + print(f" Optimal threshold: {metadata['metrics']['optimal_threshold']:.4f}") + print(f" Optimal F1-score: {metadata['metrics']['optimal_f1']:.4f}") + + return model_data + + +def predict_sepsis( + model_data: Dict, patient_features: pd.DataFrame, use_optimal_threshold: bool = True +) -> Tuple[np.ndarray, np.ndarray]: + """ + Predict sepsis risk for patient(s). + + Args: + model_data: Dictionary containing model, scaler, and metadata + patient_features: DataFrame with patient features + use_optimal_threshold: Whether to use optimal threshold (default: True) + + Returns: + Tuple of (predictions, probabilities) + """ + model = model_data["model"] + scaler = model_data["scaler"] + metadata = model_data["metadata"] + feature_names = metadata["feature_names"] + + # Ensure features are in correct order + patient_features = patient_features[feature_names] + + # Apply scaling if Logistic Regression + if scaler is not None: + patient_features_scaled = scaler.transform(patient_features) + probabilities = model.predict_proba(patient_features_scaled)[:, 1] + else: + probabilities = model.predict_proba(patient_features)[:, 1] + + # Use optimal threshold if available and requested + if use_optimal_threshold and "optimal_threshold" in metadata["metrics"]: + threshold = metadata["metrics"]["optimal_threshold"] + else: + threshold = 0.5 + + predictions = (probabilities >= threshold).astype(int) + + return predictions, probabilities + + +def create_example_patients() -> pd.DataFrame: + """ + Create example patient data for demonstration. 
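+    Values are synthetic and chosen to span the normal-to-critical range for
+    each feature; they are illustrative only, not real patient measurements.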
+ + Returns: + DataFrame with example patient features + """ + # Example patient data + # Patient 1: Healthy patient (low risk) + # Patient 2: Moderate risk (some abnormal values) + # Patient 3: Low risk (normal values) + # Patient 4: High risk for sepsis (multiple severe abnormalities) + # Patient 5: Critical sepsis risk (severe multi-organ dysfunction) + patients = pd.DataFrame( + { + "heart_rate": [85, 110, 75, 130, 145], # beats/min (normal: 60-100) + "temperature": [ + 37.2, + 38.5, + 36.8, + 39.2, + 35.5, + ], # Celsius (normal: 36.5-37.5, hypothermia <36) + "respiratory_rate": [16, 24, 14, 30, 35], # breaths/min (normal: 12-20) + "wbc": [8.5, 15.2, 7.0, 18.5, 22.0], # x10^9/L (normal: 4-11) + "lactate": [ + 1.2, + 3.5, + 0.9, + 4.8, + 6.5, + ], # mmol/L (normal: <2, severe sepsis: >4) + "creatinine": [0.9, 1.8, 0.8, 2.5, 3.2], # mg/dL (normal: 0.6-1.2) + "age": [45, 68, 35, 72, 78], # years + "gender_encoded": [1, 0, 1, 1, 0], # 1=Male, 0=Female + } + ) + + return patients + + +def interpret_results( + predictions: np.ndarray, probabilities: np.ndarray, patient_features: pd.DataFrame +) -> None: + """ + Interpret and display prediction results. + + Args: + predictions: Binary predictions (0=no sepsis, 1=sepsis) + probabilities: Probability scores + patient_features: Original patient features + """ + print("\n" + "=" * 80) + print("SEPSIS PREDICTION RESULTS") + print("=" * 80) + + for i in range(len(predictions)): + print(f"\nPatient {i+1}:") + print(f" Risk Score: {probabilities[i]:.2%}") + print(f" Prediction: {'SEPSIS RISK' if predictions[i] == 1 else 'Low Risk'}") + + # Show key vital signs + print(" Key Features:") + print(f" Heart Rate: {patient_features.iloc[i]['heart_rate']:.1f} bpm") + print(f" Temperature: {patient_features.iloc[i]['temperature']:.1f}°C") + print( + f" Respiratory Rate: {patient_features.iloc[i]['respiratory_rate']:.1f} /min" + ) + print(f" WBC: {patient_features.iloc[i]['wbc']:.1f} x10^9/L") + print(f" Lactate: {patient_features.iloc[i]['lactate']:.1f} mmol/L") + print(f" Creatinine: {patient_features.iloc[i]['creatinine']:.2f} mg/dL") + + # Risk interpretation + if probabilities[i] >= 0.7: + risk_level = "HIGH" + elif probabilities[i] >= 0.4: + risk_level = "MODERATE" + else: + risk_level = "LOW" + + print(f" Clinical Interpretation: {risk_level} RISK") + + print("\n" + "=" * 80) + + +def main(): + """Main inference pipeline.""" + # Model path (relative to script location) + script_dir = Path(__file__).parent + model_path = script_dir / "models" / "sepsis_model.pkl" + + print("=" * 80) + print("Sepsis Prediction Inference") + print("=" * 80 + "\n") + + # Load model + model_data = load_model(model_path) + + # Create example patients + print("\nCreating example patient data...") + patient_features = create_example_patients() + print(f"Number of patients: {len(patient_features)}") + + # Make predictions + print("\nMaking predictions...") + predictions, probabilities = predict_sepsis( + model_data, patient_features, use_optimal_threshold=True + ) + + # Interpret results + interpret_results(predictions, probabilities, patient_features) + + print("\n" + "=" * 80) + print("Inference complete!") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/cookbook/sepsis_prediction_training.py b/cookbook/sepsis_prediction_training.py new file mode 100644 index 00000000..a0ea85ce --- /dev/null +++ b/cookbook/sepsis_prediction_training.py @@ -0,0 +1,1039 @@ +#!/usr/bin/env python3 +""" +Sepsis Prediction Training Script + +Trains Random Forest, XGBoost, 
and Logistic Regression models for sepsis prediction
+using MIMIC-IV clinical database data.
+
+Requirements:
+- pip install scikit-learn xgboost joblib pandas numpy
+
+Run:
+- python sepsis_prediction_training.py
+"""
+
+import pandas as pd
+import numpy as np
+from pathlib import Path
+from datetime import datetime
+from typing import Dict, Tuple, List, Any, Union
+
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler
+from sklearn.metrics import (
+    accuracy_score,
+    precision_score,
+    recall_score,
+    f1_score,
+    roc_auc_score,
+    precision_recall_curve,
+)
+import xgboost as xgb
+import joblib
+
+
+# MIMIC-IV ItemID mappings for features
+CHARTEVENTS_ITEMIDS = {
+    "heart_rate": 220045,  # Heart Rate
+    "temperature_f": 223761,  # Temperature Fahrenheit
+    "temperature_c": 223762,  # Temperature Celsius
+    "respiratory_rate": 220210,
+}
+
+LABEVENTS_ITEMIDS = {
+    "wbc": [51300, 51301],  # White Blood Cell Count
+    "lactate": 50813,
+    "creatinine": 50912,
+}
+
+# Sepsis ICD-10 codes
+SEPSIS_ICD10_CODES = [
+    "A41.9",  # Sepsis, unspecified organism
+    "A40",  # Streptococcal sepsis (starts with)
+    "A41",  # Other sepsis (starts with)
+    "R65.20",  # Severe sepsis without shock
+    "R65.21",  # Severe sepsis with shock
+    "R65.1",  # SIRS (Systemic Inflammatory Response Syndrome)
+    "A41.0",  # Sepsis due to Staphylococcus aureus
+    "A41.1",  # Sepsis due to other specified staphylococcus
+    "A41.2",  # Sepsis due to unspecified staphylococcus
+    "A41.3",  # Sepsis due to Haemophilus influenzae
+    "A41.4",  # Sepsis due to anaerobes
+    "A41.5",  # Sepsis due to other Gram-negative organisms
+    "A41.50",  # Sepsis due to unspecified Gram-negative organism
+    "A41.51",  # Sepsis due to Escherichia coli
+    "A41.52",  # Sepsis due to Pseudomonas
+    "A41.53",  # Sepsis due to Serratia
+    "A41.59",  # Other Gram-negative sepsis
+    "A41.8",  # Other specified sepsis
+    "A41.81",  # Sepsis due to Enterococcus
+    "A41.89",  # Other specified sepsis
+]
+
+# Sepsis ICD-9 codes (for older data)
+SEPSIS_ICD9_CODES = [
+    "038",  # Septicemia (starts with)
+    "99591",  # Sepsis
+    "99592",  # Severe sepsis
+    "78552",  # Septic shock
+]
+
+
+def load_mimic_data(data_dir: str) -> Dict[str, pd.DataFrame]:
+    """
+    Load all required MIMIC-IV CSV tables.
+
+    Args:
+        data_dir: Path to MIMIC-IV dataset directory
+
+    Returns:
+        Dictionary mapping table names to DataFrames
+    """
+    data_dir = Path(data_dir)
+
+    print("Loading MIMIC-IV data...")
+
+    tables = {
+        "patients": pd.read_csv(
+            data_dir / "hosp" / "patients.csv.gz", compression="gzip", low_memory=False
+        ),
+        "admissions": pd.read_csv(
+            data_dir / "hosp" / "admissions.csv.gz",
+            compression="gzip",
+            low_memory=False,
+        ),
+        "icustays": pd.read_csv(
+            data_dir / "icu" / "icustays.csv.gz", compression="gzip", low_memory=False
+        ),
+        "chartevents": pd.read_csv(
+            data_dir / "icu" / "chartevents.csv.gz",
+            compression="gzip",
+            low_memory=False,
+        ),
+        "labevents": pd.read_csv(
+            data_dir / "hosp" / "labevents.csv.gz", compression="gzip", low_memory=False
+        ),
+        "diagnoses_icd": pd.read_csv(
+            data_dir / "hosp" / "diagnoses_icd.csv.gz",
+            compression="gzip",
+            low_memory=False,
+        ),
+    }
+
+    print(f"Loaded {len(tables)} tables")
+    for name, df in tables.items():
+        print(f"  {name}: {len(df)} rows")
+
+    return tables
+
+
+def extract_chartevents_features(
+    chartevents: pd.DataFrame, icustays: pd.DataFrame
+) -> pd.DataFrame:
+    """
+    Extract vital-sign features (heart rate, temperature, respiratory rate) from the chartevents table.
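+    Values are aggregated as per-stay means over the first 24 hours of the ICU
+    stay; Fahrenheit temperatures are converted to Celsius when no Celsius
+    readings are available.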
+ + Args: + chartevents: Chart events DataFrame + icustays: ICU stays DataFrame + + Returns: + DataFrame with features per stay_id + """ + print("Extracting chartevents features...") + + # Filter to relevant itemids + relevant_itemids = list(CHARTEVENTS_ITEMIDS.values()) + chartevents_filtered = chartevents[ + chartevents["itemid"].isin(relevant_itemids) + ].copy() + + # Merge with icustays to get stay times + chartevents_merged = chartevents_filtered.merge( + icustays[["stay_id", "intime", "outtime"]], on="stay_id", how="inner" + ) + + # Convert charttime to datetime + chartevents_merged["charttime"] = pd.to_datetime(chartevents_merged["charttime"]) + chartevents_merged["intime"] = pd.to_datetime(chartevents_merged["intime"]) + + # Filter to first 24 hours of ICU stay + chartevents_merged = chartevents_merged[ + (chartevents_merged["charttime"] >= chartevents_merged["intime"]) + & ( + chartevents_merged["charttime"] + <= chartevents_merged["intime"] + pd.Timedelta(hours=24) + ) + ] + + # Extract numeric values + chartevents_merged["valuenum"] = pd.to_numeric( + chartevents_merged["valuenum"], errors="coerce" + ) + + # Aggregate by stay_id and itemid (take mean) + features = [] + + for stay_id in icustays["stay_id"].unique(): + stay_data = chartevents_merged[chartevents_merged["stay_id"] == stay_id] + + feature_row = {"stay_id": stay_id} + + # Heart Rate + hr_data = stay_data[stay_data["itemid"] == CHARTEVENTS_ITEMIDS["heart_rate"]][ + "valuenum" + ] + feature_row["heart_rate"] = hr_data.mean() if not hr_data.empty else np.nan + + # Temperature (prefer Celsius, convert Fahrenheit if needed) + temp_c = stay_data[stay_data["itemid"] == CHARTEVENTS_ITEMIDS["temperature_c"]][ + "valuenum" + ] + temp_f = stay_data[stay_data["itemid"] == CHARTEVENTS_ITEMIDS["temperature_f"]][ + "valuenum" + ] + + if not temp_c.empty: + feature_row["temperature"] = temp_c.mean() + elif not temp_f.empty: + # Convert Fahrenheit to Celsius + feature_row["temperature"] = (temp_f.mean() - 32) * 5 / 9 + else: + feature_row["temperature"] = np.nan + + # Respiratory Rate + rr_data = stay_data[ + stay_data["itemid"] == CHARTEVENTS_ITEMIDS["respiratory_rate"] + ]["valuenum"] + feature_row["respiratory_rate"] = ( + rr_data.mean() if not rr_data.empty else np.nan + ) + + features.append(feature_row) + + return pd.DataFrame(features) + + +def extract_labevents_features( + labevents: pd.DataFrame, icustays: pd.DataFrame +) -> pd.DataFrame: + """ + Extract 2-3 lab values from labevents table. 
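+    The lab values are WBC, lactate, and creatinine, aggregated as per-stay
+    means over the first 24 hours and linked to ICU stays via hadm_id.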
+ + Args: + labevents: Lab events DataFrame + icustays: ICU stays DataFrame + + Returns: + DataFrame with features per stay_id + """ + print("Extracting labevents features...") + + # Get relevant itemids + relevant_itemids = [ + LABEVENTS_ITEMIDS["lactate"], + LABEVENTS_ITEMIDS["creatinine"], + ] + LABEVENTS_ITEMIDS["wbc"] + + labevents_filtered = labevents[labevents["itemid"].isin(relevant_itemids)].copy() + + # Merge with icustays via admissions + # First need to get hadm_id from icustays + icustays_with_hadm = icustays[["stay_id", "hadm_id", "intime"]].copy() + + # Labevents links via hadm_id, then we need to link to stay_id + labevents_merged = labevents_filtered.merge( + icustays_with_hadm, on="hadm_id", how="inner" + ) + + # Convert charttime to datetime + labevents_merged["charttime"] = pd.to_datetime(labevents_merged["charttime"]) + labevents_merged["intime"] = pd.to_datetime(labevents_merged["intime"]) + + # Filter to first 24 hours of ICU stay + labevents_merged = labevents_merged[ + (labevents_merged["charttime"] >= labevents_merged["intime"]) + & ( + labevents_merged["charttime"] + <= labevents_merged["intime"] + pd.Timedelta(hours=24) + ) + ] + + # Extract numeric values + labevents_merged["valuenum"] = pd.to_numeric( + labevents_merged["valuenum"], errors="coerce" + ) + + # Aggregate by stay_id and itemid + features = [] + + for stay_id in icustays["stay_id"].unique(): + stay_data = labevents_merged[labevents_merged["stay_id"] == stay_id] + + feature_row = {"stay_id": stay_id} + + # WBC (check both itemids) + wbc_data = stay_data[stay_data["itemid"].isin(LABEVENTS_ITEMIDS["wbc"])][ + "valuenum" + ] + feature_row["wbc"] = wbc_data.mean() if not wbc_data.empty else np.nan + + # Lactate + lactate_data = stay_data[stay_data["itemid"] == LABEVENTS_ITEMIDS["lactate"]][ + "valuenum" + ] + feature_row["lactate"] = ( + lactate_data.mean() if not lactate_data.empty else np.nan + ) + + # Creatinine + creatinine_data = stay_data[ + stay_data["itemid"] == LABEVENTS_ITEMIDS["creatinine"] + ]["valuenum"] + feature_row["creatinine"] = ( + creatinine_data.mean() if not creatinine_data.empty else np.nan + ) + + features.append(feature_row) + + return pd.DataFrame(features) + + +def extract_demographics( + patients: pd.DataFrame, admissions: pd.DataFrame, icustays: pd.DataFrame +) -> pd.DataFrame: + """ + Extract age and gender from patients table. + + Args: + patients: Patients DataFrame + admissions: Admissions DataFrame (not used, kept for compatibility) + icustays: ICU stays DataFrame + + Returns: + DataFrame with demographics per stay_id + """ + print("Extracting demographics...") + + # icustays already has subject_id, so merge directly with patients + icustays_with_patient = icustays[["stay_id", "subject_id"]].merge( + patients[["subject_id", "gender", "anchor_age"]], on="subject_id", how="left" + ) + + # Use anchor_age if available, otherwise calculate from anchor_year and anchor_age + # For demo data, anchor_age should be available + demographics = icustays_with_patient[["stay_id", "anchor_age", "gender"]].copy() + demographics.rename(columns={"anchor_age": "age"}, inplace=True) + + # Encode gender (M=1, F=0) + demographics["gender_encoded"] = (demographics["gender"] == "M").astype(int) + + return demographics[["stay_id", "age", "gender_encoded"]] + + +def extract_sepsis_labels( + diagnoses_icd: pd.DataFrame, icustays: pd.DataFrame +) -> pd.DataFrame: + """ + Extract sepsis labels from diagnoses_icd table. + Checks both ICD-9 and ICD-10 codes to maximize positive samples. 
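+    An ICU stay is labeled positive when its hospital admission (hadm_id)
+    carries at least one matching sepsis diagnosis code.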
+
+    Args:
+        diagnoses_icd: Diagnoses ICD DataFrame
+        icustays: ICU stays DataFrame
+
+    Returns:
+        DataFrame with sepsis labels per stay_id
+    """
+    print("Extracting sepsis labels...")
+
+    # Check what ICD versions are available
+    icd_versions = diagnoses_icd["icd_version"].unique()
+    print(f"  Available ICD versions: {sorted(icd_versions)}")
+
+    all_sepsis_diagnoses = []
+
+    # Check ICD-10 codes
+    if 10 in icd_versions:
+        diagnoses_icd10 = diagnoses_icd[diagnoses_icd["icd_version"] == 10].copy()
+        print(f"  ICD-10 diagnoses: {len(diagnoses_icd10)} rows")
+
+        sepsis_mask = pd.Series(
+            [False] * len(diagnoses_icd10), index=diagnoses_icd10.index
+        )
+
+        # MIMIC-IV stores ICD codes without dots (e.g., "A419", "R6520"), so
+        # strip dots from the target codes before exact matching.
+        for code in SEPSIS_ICD10_CODES:
+            if "." not in code or code.endswith("."):
+                # Pattern match (e.g., "A41" matches "A41x")
+                code_prefix = code.rstrip(".")
+                mask = diagnoses_icd10["icd_code"].str.startswith(code_prefix, na=False)
+                sepsis_mask |= mask
+                if mask.sum() > 0:
+                    print(
+                        f"    Found {mask.sum()} ICD-10 diagnoses matching pattern '{code}'"
+                    )
+            else:
+                # Exact match on the dot-stripped code
+                mask = diagnoses_icd10["icd_code"] == code.replace(".", "")
+                sepsis_mask |= mask
+                if mask.sum() > 0:
+                    print(
+                        f"    Found {mask.sum()} ICD-10 diagnoses with exact code '{code}'"
+                    )
+
+        sepsis_icd10 = diagnoses_icd10[sepsis_mask].copy()
+        if len(sepsis_icd10) > 0:
+            all_sepsis_diagnoses.append(sepsis_icd10)
+            print(f"  Total ICD-10 sepsis diagnoses: {len(sepsis_icd10)}")
+
+    # Check ICD-9 codes
+    if 9 in icd_versions:
+        diagnoses_icd9 = diagnoses_icd[diagnoses_icd["icd_version"] == 9].copy()
+        print(f"  ICD-9 diagnoses: {len(diagnoses_icd9)} rows")
+
+        sepsis_mask = pd.Series(
+            [False] * len(diagnoses_icd9), index=diagnoses_icd9.index
+        )
+
+        for code in SEPSIS_ICD9_CODES:
+            if len(code) <= 3 or code.endswith("."):
+                # Pattern match (e.g., "038" matches "038x")
+                code_prefix = code.rstrip(".")
+                mask = diagnoses_icd9["icd_code"].str.startswith(code_prefix, na=False)
+                sepsis_mask |= mask
+                if mask.sum() > 0:
+                    print(
+                        f"    Found {mask.sum()} ICD-9 diagnoses matching pattern '{code}'"
+                    )
+            else:
+                # Exact match (ICD-9 codes are already stored without dots)
+                mask = diagnoses_icd9["icd_code"] == code
+                sepsis_mask |= mask
+                if mask.sum() > 0:
+                    print(
+                        f"    Found {mask.sum()} ICD-9 diagnoses with exact code '{code}'"
+                    )
+
+        sepsis_icd9 = diagnoses_icd9[sepsis_mask].copy()
+        if len(sepsis_icd9) > 0:
+            all_sepsis_diagnoses.append(sepsis_icd9)
+            print(f"  Total ICD-9 sepsis diagnoses: {len(sepsis_icd9)}")
+
+    # Combine all sepsis diagnoses
+    if all_sepsis_diagnoses:
+        sepsis_diagnoses = pd.concat(all_sepsis_diagnoses, ignore_index=True)
+        print(f"  Total sepsis diagnoses (ICD-9 + ICD-10): {len(sepsis_diagnoses)}")
+
+        if len(sepsis_diagnoses) > 0:
+            print(
+                f"  Sample sepsis ICD codes: {sepsis_diagnoses['icd_code'].unique()[:15].tolist()}"
+            )
+            print(
+                f"  Unique hadm_id with sepsis: {sepsis_diagnoses['hadm_id'].nunique()}"
+            )
+    else:
+        sepsis_diagnoses = pd.DataFrame(columns=diagnoses_icd.columns)
+        print("  No sepsis diagnoses found")
+
+    # Merge with icustays to get stay_id
+    icustays_with_hadm = icustays[["stay_id", "hadm_id"]].copy()
+
+    if len(sepsis_diagnoses) > 0:
+        sepsis_labels = icustays_with_hadm.merge(
+            sepsis_diagnoses[["hadm_id"]].drop_duplicates(),
+            on="hadm_id",
+            how="left",
+            indicator=True,
+        )
+    else:
+        sepsis_labels = icustays_with_hadm.copy()
+        sepsis_labels["_merge"] = "left_only"
+
+    # Create binary label (1 if sepsis, 0 otherwise)
+    sepsis_labels["sepsis"] = (sepsis_labels["_merge"] == "both").astype(int)
+
+    sepsis_count = sepsis_labels["sepsis"].sum()
+    print(
+        f"  ICU stays with sepsis: {sepsis_count}/{len(sepsis_labels)} ({sepsis_count/len(sepsis_labels)*100:.2f}%)"
+    )
+
+    return sepsis_labels[["stay_id", "sepsis"]]
+
+
+def print_feature_summary(X: pd.DataFrame):
+    """Print feature statistics with FHIR mapping information.
+
+    Args:
+        X: Feature matrix with actual data
+    """
+    print("\n" + "=" * 120)
+    print("FEATURE SUMMARY: MIMIC-IV → Model → FHIR Mapping")
+    print("=" * 120)
+
+    # Define FHIR mappings for each feature
+    fhir_mappings = {
+        "heart_rate": {
+            "mimic_table": "chartevents",
+            "mimic_itemid": "220045",
+            "fhir_resource": "Observation",
+            "fhir_code": "8867-4",
+            "fhir_system": "LOINC",
+            "fhir_display": "Heart rate",
+        },
+        "temperature": {
+            "mimic_table": "chartevents",
+            "mimic_itemid": "223762/223761",
+            "fhir_resource": "Observation",
+            "fhir_code": "8310-5",
+            "fhir_system": "LOINC",
+            "fhir_display": "Body temperature",
+        },
+        "respiratory_rate": {
+            "mimic_table": "chartevents",
+            "mimic_itemid": "220210",
+            "fhir_resource": "Observation",
+            "fhir_code": "9279-1",
+            "fhir_system": "LOINC",
+            "fhir_display": "Respiratory rate",
+        },
+        "wbc": {
+            "mimic_table": "labevents",
+            "mimic_itemid": "51300/51301",
+            "fhir_resource": "Observation",
+            "fhir_code": "6690-2",
+            "fhir_system": "LOINC",
+            "fhir_display": "Leukocytes [#/volume] in Blood",
+        },
+        "lactate": {
+            "mimic_table": "labevents",
+            "mimic_itemid": "50813",
+            "fhir_resource": "Observation",
+            "fhir_code": "2524-7",
+            "fhir_system": "LOINC",
+            "fhir_display": "Lactate [Moles/volume] in Blood",
+        },
+        "creatinine": {
+            "mimic_table": "labevents",
+            "mimic_itemid": "50912",
+            "fhir_resource": "Observation",
+            "fhir_code": "2160-0",
+            "fhir_system": "LOINC",
+            "fhir_display": "Creatinine [Mass/volume] in Serum or Plasma",
+        },
+        "age": {
+            "mimic_table": "patients",
+            "mimic_itemid": "anchor_age",
+            "fhir_resource": "Patient",
+            "fhir_code": "birthDate",
+            "fhir_system": "FHIR Core",
+            "fhir_display": "Patient birth date (calculate age)",
+        },
+        "gender_encoded": {
+            "mimic_table": "patients",
+            "mimic_itemid": "gender",
+            "fhir_resource": "Patient",
+            "fhir_code": "gender",
+            "fhir_system": "FHIR Core",
+            "fhir_display": "Administrative Gender (M/F)",
+        },
+    }
+
+    print(
+        f"\n{'Feature':<20} {'Mean±SD':<20} {'MIMIC Source':<20} {'FHIR Resource':<20} {'FHIR Code (System)':<30}"
+    )
+    print("-" * 120)
+
+    for feature in X.columns:
+        mapping = fhir_mappings.get(feature, {})
+
+        # Calculate statistics
+        mean_val = X[feature].mean()
+        std_val = X[feature].std()
+
+        # Format based on feature type
+        if feature == "gender_encoded":
+            stats = f"{mean_val:.2f} (M={X[feature].sum():.0f})"
+        else:
+            stats = f"{mean_val:.2f}±{std_val:.2f}"
+
+        mimic_source = f"{mapping.get('mimic_table', 'N/A')} ({mapping.get('mimic_itemid', 'N/A')})"
+        fhir_resource = mapping.get("fhir_resource", "N/A")
+        fhir_code = (
+            f"{mapping.get('fhir_code', 'N/A')} ({mapping.get('fhir_system', 'N/A')})"
+        )
+
+        print(
+            f"{feature:<20} {stats:<20} {mimic_source:<20} {fhir_resource:<20} {fhir_code:<30}"
+        )
+
+    print("\n" + "=" * 120)
+    print(
+        "Note: Statistics calculated from first 24 hours of ICU stay. Missing values imputed with median."
+    )
+    print("=" * 120 + "\n")
+
+
+def create_feature_matrix(
+    chartevents_features: pd.DataFrame,
+    labevents_features: pd.DataFrame,
+    demographics: pd.DataFrame,
+    sepsis_labels: pd.DataFrame,
+) -> Tuple[pd.DataFrame, pd.Series]:
+    """
+    Create feature matrix and labels from extracted features.
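+    Feature sources are outer-merged on stay_id (stays missing one source are
+    kept, with NaN values) and then inner-merged with the sepsis labels.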
+ + Args: + chartevents_features: Chart events features + labevents_features: Lab events features + demographics: Demographics features + sepsis_labels: Sepsis labels + + Returns: + Tuple of (feature matrix, labels) + """ + print("Creating feature matrix...") + + # Merge all features on stay_id + features = ( + chartevents_features.merge(labevents_features, on="stay_id", how="outer") + .merge(demographics, on="stay_id", how="outer") + .merge(sepsis_labels, on="stay_id", how="inner") + ) + + # Select feature columns (exclude stay_id and sepsis) + feature_cols = [ + "heart_rate", + "temperature", + "respiratory_rate", + "wbc", + "lactate", + "creatinine", + "age", + "gender_encoded", + ] + + X = features[feature_cols].copy() + y = features["sepsis"].copy() + + print(f"Feature matrix shape: {X.shape}") + print(f"Sepsis cases: {y.sum()} ({y.sum() / len(y) * 100:.2f}%)") + + return X, y + + +def train_models(X_train: pd.DataFrame, y_train: pd.Series) -> Dict[str, Any]: + """ + Train all three models (Random Forest, XGBoost, Logistic Regression). + + Args: + X_train: Training features + y_train: Training labels + + Returns: + Dictionary of trained models + """ + print("\nTraining models...") + + models = {} + + # Check if we have any positive samples + positive_samples = y_train.sum() + total_samples = len(y_train) + positive_rate = positive_samples / total_samples if total_samples > 0 else 0.0 + + print( + f" Positive samples: {positive_samples}/{total_samples} ({positive_rate*100:.2f}%)" + ) + + # Random Forest - use class_weight to handle imbalance + print(" Training Random Forest...") + rf = RandomForestClassifier( + n_estimators=100, + random_state=42, + n_jobs=-1, + class_weight="balanced", # Automatically adjust for class imbalance + ) + rf.fit(X_train, y_train) + models["RandomForest"] = rf + + # XGBoost - handle case with no positive samples + print(" Training XGBoost...") + if positive_samples == 0: + # When there are no positive samples, set base_score to a small value + # and use scale_pos_weight to avoid errors + xgb_model = xgb.XGBClassifier( + random_state=42, + n_jobs=-1, + eval_metric="logloss", + base_score=0.01, # Small positive value instead of 0 + scale_pos_weight=1.0, + ) + else: + # Calculate scale_pos_weight for imbalanced data + scale_pos_weight = (total_samples - positive_samples) / positive_samples + xgb_model = xgb.XGBClassifier( + random_state=42, + n_jobs=-1, + eval_metric="logloss", + scale_pos_weight=scale_pos_weight, + ) + xgb_model.fit(X_train, y_train) + models["XGBoost"] = xgb_model + + # Logistic Regression (with scaling) - use class_weight to handle imbalance + print(" Training Logistic Regression...") + scaler = StandardScaler() + X_train_scaled = scaler.fit_transform(X_train) + lr = LogisticRegression( + random_state=42, + max_iter=1000, + class_weight="balanced", # Automatically adjust for class imbalance + ) + lr.fit(X_train_scaled, y_train) + models["LogisticRegression"] = lr + models["scaler"] = scaler # Store scaler for later use + + return models + + +def evaluate_models( + models: Dict[str, Any], + X_test: pd.DataFrame, + y_test: pd.Series, + feature_names: List[str], +) -> Dict[str, Dict[str, float]]: + """ + Evaluate and compare all models. 
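+    Metrics are reported at the default 0.5 threshold and, when the test set
+    contains both classes, at the F1-optimal threshold derived from the
+    precision-recall curve.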
+ + Args: + models: Dictionary of trained models + X_test: Test features + y_test: Test labels + feature_names: List of feature names + + Returns: + Dictionary of evaluation metrics for each model + """ + print("\nEvaluating models...") + print( + f"Test set: {len(y_test)} samples, {y_test.sum()} positive ({y_test.sum()/len(y_test)*100:.2f}%)" + ) + + results = {} + + for name, model in models.items(): + if name == "scaler": + continue + + # Get probability predictions + if name == "LogisticRegression": + X_test_scaled = models["scaler"].transform(X_test) + y_pred_proba = model.predict_proba(X_test_scaled)[:, 1] + else: + y_pred_proba = model.predict_proba(X_test)[:, 1] + + # Use default threshold (0.5) for predictions + y_pred = (y_pred_proba >= 0.5).astype(int) + + # Calculate metrics with default threshold + metrics = { + "accuracy": accuracy_score(y_test, y_pred), + "precision": precision_score(y_test, y_pred, zero_division=0), + "recall": recall_score(y_test, y_pred, zero_division=0), + "f1": f1_score(y_test, y_pred, zero_division=0), + "auc": roc_auc_score(y_test, y_pred_proba) + if len(np.unique(y_test)) > 1 + else 0.0, + } + + # Try to find optimal threshold for F1 score + if len(np.unique(y_test)) > 1 and y_test.sum() > 0: + precision, recall, thresholds = precision_recall_curve(y_test, y_pred_proba) + f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10) + optimal_idx = np.argmax(f1_scores) + optimal_threshold = ( + thresholds[optimal_idx] if optimal_idx < len(thresholds) else 0.5 + ) + optimal_f1 = f1_scores[optimal_idx] + + # Predictions with optimal threshold + y_pred_optimal = (y_pred_proba >= optimal_threshold).astype(int) + metrics["optimal_threshold"] = optimal_threshold + metrics["optimal_f1"] = optimal_f1 + metrics["optimal_precision"] = precision_score( + y_test, y_pred_optimal, zero_division=0 + ) + metrics["optimal_recall"] = recall_score( + y_test, y_pred_optimal, zero_division=0 + ) + else: + metrics["optimal_threshold"] = 0.5 + metrics["optimal_f1"] = 0.0 + metrics["optimal_precision"] = 0.0 + metrics["optimal_recall"] = 0.0 + + results[name] = metrics + + print(f"\n{name}:") + print( + f" Predictions: {y_pred.sum()} positive predicted (actual: {y_test.sum()})" + ) + print(f" Accuracy: {metrics['accuracy']:.4f}") + print(f" Precision: {metrics['precision']:.4f}") + print(f" Recall: {metrics['recall']:.4f}") + print(f" F1-score: {metrics['f1']:.4f}") + print(f" AUC-ROC: {metrics['auc']:.4f}") + if metrics["optimal_f1"] > 0: + print(f" Optimal threshold: {metrics['optimal_threshold']:.4f}") + print(f" Optimal F1-score: {metrics['optimal_f1']:.4f}") + print(f" Optimal Precision: {metrics['optimal_precision']:.4f}") + print(f" Optimal Recall: {metrics['optimal_recall']:.4f}") + + # Show feature importance for tree-based models + if hasattr(model, "feature_importances_"): + print("\n Top 5 Feature Importances:") + importances = model.feature_importances_ + indices = np.argsort(importances)[::-1][:5] + for idx in indices: + print(f" {feature_names[idx]}: {importances[idx]:.4f}") + + return results + + +def select_best_model( + models: Dict[str, Any], + results: Dict[str, Dict[str, float]], + metric: str = "f1", +) -> Tuple[str, Any, Dict[str, float]]: + """ + Select best model based on specified metric. 
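+    The optimal-threshold variant of the chosen metric is preferred when
+    available; AUC is threshold-independent and used as-is.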
+ + Args: + models: Dictionary of trained models + results: Evaluation results + metric: Metric to optimize ("f1", "recall", "precision", "auc") + + Returns: + Tuple of (best model name, best model, best metrics) + """ + print(f"\nSelecting best model based on {metric}...") + + # Get the appropriate metric value (prefer optimal if available) + def get_metric_value(metrics, metric_name): + if metric_name == "f1": + return metrics.get("optimal_f1", metrics["f1"]) + elif metric_name == "recall": + return metrics.get("optimal_recall", metrics["recall"]) + elif metric_name == "precision": + return metrics.get("optimal_precision", metrics["precision"]) + elif metric_name == "auc": + return metrics.get("auc", 0.0) + else: + return metrics.get("optimal_f1", metrics["f1"]) + + best_name = max(results.keys(), key=lambda k: get_metric_value(results[k], metric)) + best_model = models[best_name] + best_metrics = results[best_name] + + best_value = get_metric_value(best_metrics, metric) + print(f"Best model: {best_name} ({metric}: {best_value:.4f})") + + return best_name, best_model, best_metrics + + +def save_model( + model: Any, + model_name: str, + feature_names: List[str], + metrics: Dict[str, float], + scaler: Any, + output_path: Union[str, Path], +) -> None: + """ + Save the best model with metadata. + + Args: + model: Trained model + model_name: Name of the model + feature_names: List of feature names + metrics: Evaluation metrics + scaler: StandardScaler (if Logistic Regression, None otherwise) + output_path: Path to save model + """ + print(f"\nSaving model to {output_path}...") + + # Create output directory if it doesn't exist + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Prepare metadata + metadata = { + "model_name": model_name, + "training_date": datetime.now().isoformat(), + "feature_names": feature_names, + "metrics": metrics, + "itemid_mappings": { + "chartevents": CHARTEVENTS_ITEMIDS, + "labevents": LABEVENTS_ITEMIDS, + }, + "sepsis_icd_codes": { + "icd10": SEPSIS_ICD10_CODES, + "icd9": SEPSIS_ICD9_CODES, + }, + } + + # Save model and metadata + model_data = { + "model": model, + "scaler": scaler, + "metadata": metadata, + } + + joblib.dump(model_data, output_path) + + print("Model saved successfully!") + + +def main(): + """Main training pipeline.""" + # Data directory + data_dir = "../datasets/mimic-iv-clinical-database-demo-2.2" + + # Output path (relative to script location) + script_dir = Path(__file__).parent + output_path = script_dir / "models" / "sepsis_model.pkl" + + print("=" * 60) + print("Sepsis Prediction Model Training") + print("=" * 60) + + # Load data + tables = load_mimic_data(data_dir) + + # Extract features + chartevents_features = extract_chartevents_features( + tables["chartevents"], tables["icustays"] + ) + labevents_features = extract_labevents_features( + tables["labevents"], tables["icustays"] + ) + demographics = extract_demographics( + tables["patients"], tables["admissions"], tables["icustays"] + ) + + # Extract labels + sepsis_labels = extract_sepsis_labels(tables["diagnoses_icd"], tables["icustays"]) + + # Create feature matrix + X, y = create_feature_matrix( + chartevents_features, + labevents_features, + demographics, + sepsis_labels, + ) + + # Handle missing values (impute with median) + print("\nHandling missing values...") + missing_before = X.isnull().sum().sum() + print(f" Missing values before imputation: {missing_before}") + X = X.fillna(X.median()) + + # Print feature summary with actual data 
statistics
+    print_feature_summary(X)
+
+    # Split data with careful stratification to ensure positive samples in both sets
+    print("\nSplitting data...")
+    if len(np.unique(y)) > 1 and y.sum() > 0:
+        # Use stratification to ensure positive samples in both train and test
+        X_train, X_test, y_train, y_test = train_test_split(
+            X, y, test_size=0.2, random_state=42, stratify=y
+        )
+        print(
+            f"  Training set: {len(X_train)} samples ({y_train.sum()} positive, {y_train.sum()/len(y_train)*100:.2f}%)"
+        )
+        print(
+            f"  Test set: {len(X_test)} samples ({y_test.sum()} positive, {y_test.sum()/len(y_test)*100:.2f}%)"
+        )
+
+        # Warn if test set has no positive samples (shouldn't happen with stratify, but check anyway)
+        if y_test.sum() == 0:
+            print(
+                "  WARNING: Test set has no positive samples! Consider using a different random seed."
+            )
+    else:
+        print(
+            "  Warning: No positive samples or only one class. Skipping stratification."
+        )
+        X_train, X_test, y_train, y_test = train_test_split(
+            X, y, test_size=0.2, random_state=42
+        )
+        print(f"  Training set: {len(X_train)} samples")
+        print(f"  Test set: {len(X_test)} samples")
+
+    # Apply oversampling to training data to balance classes
+    print("\nApplying oversampling to training data...")
+    try:
+        from imblearn.over_sampling import SMOTE
+
+        # Only apply SMOTE if we have positive samples
+        if y_train.sum() > 0 and len(np.unique(y_train)) > 1:
+            print(
+                f"  Before oversampling: {len(X_train)} samples ({y_train.sum()} positive, {y_train.sum()/len(y_train)*100:.2f}%)"
+            )
+            # Ensure k_neighbors doesn't exceed available positive samples
+            k_neighbors = min(5, max(1, y_train.sum() - 1))
+            smote = SMOTE(random_state=42, k_neighbors=k_neighbors)
+            X_train_resampled, y_train_resampled = smote.fit_resample(X_train, y_train)
+            print(
+                f"  After oversampling: {len(X_train_resampled)} samples ({y_train_resampled.sum()} positive, {y_train_resampled.sum()/len(X_train_resampled)*100:.2f}%)"
+            )
+            # Rebuild as pandas objects with a fresh index: the resampled data
+            # is longer than the original, so reusing the original training
+            # index would mismatch lengths and corrupt the synthetic rows.
+            X_train = pd.DataFrame(X_train_resampled, columns=X_train.columns)
+            y_train = pd.Series(y_train_resampled)
+        else:
+            print("  Skipping oversampling: insufficient positive samples")
+    except ImportError as e:
+        print(
+            "  imbalanced-learn not installed. Install with: pip install imbalanced-learn"
+        )
+        print(f"  Error: {e}")
+        print("  Proceeding without oversampling...")
+
+    # Train models
+    models = train_models(X_train, y_train)
+
+    # Evaluate models
+    feature_names = X.columns.tolist()
+    results = evaluate_models(models, X_test, y_test, feature_names)
+
+    # Select best model (can change metric: "f1", "recall", "precision", "auc")
+    # For sepsis prediction, recall (sensitivity) is often most important
+    best_name, best_model, best_metrics = select_best_model(
+        models, results, metric="f1"
+    )
+
+    # Save best model
+    scaler = models.get("scaler")
+    save_model(
+        best_model,
+        best_name,
+        feature_names,
+        best_metrics,
+        scaler,
+        output_path,
+    )
+
+    print("\n" + "=" * 60)
+    print("Training complete!")
+    print("=" * 60)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/docs/cookbook/clinical_coding.md b/docs/cookbook/clinical_coding.md
index 46b34a05..129afad8 100644
--- a/docs/cookbook/clinical_coding.md
+++ b/docs/cookbook/clinical_coding.md
@@ -50,7 +50,7 @@ MEDPLUM_SCOPE=openid
 ## Add the CDA Adapter
 
-First we'll need to convert the incoming CDA XML to FHIR. The [CdaAdapter](../reference/pipeline/adapters/cdaadapter.md) enables round-trip conversion between CDA and FHIR using the [InteropEngine](../reference/interop/engine.md) for seamless legacy-to-modern data integration.
+First we'll need to convert the incoming CDA XML to FHIR. The [CdaAdapter](../reference/io/adapters/cdaadapter.md) enables round-trip conversion between CDA and FHIR using the [InteropEngine](../reference/interop/engine.md) for seamless legacy-to-modern data integration.
 
 ```python
diff --git a/docs/cookbook/discharge_summarizer.md b/docs/cookbook/discharge_summarizer.md
index 897fd010..8fd510d2 100644
--- a/docs/cookbook/discharge_summarizer.md
+++ b/docs/cookbook/discharge_summarizer.md
@@ -92,7 +92,7 @@ The `SummarizationPipeline` automatically:
 ## Add the CDS FHIR Adapter
 
-The [CdsFhirAdapter](../reference/pipeline/adapters/cdsfhiradapter.md) converts between CDS Hooks requests and HealthChain's [Document](../reference/pipeline/data_container.md) format. This makes it easy to work with FHIR data in CDS workflows.
+The [CdsFhirAdapter](../reference/io/adapters/cdsfhiradapter.md) converts between CDS Hooks requests and HealthChain's [Document](../reference/io/containers/document.md) format. This makes it easy to work with FHIR data in CDS workflows.
 
 ```python
 from healthchain.io import CdsFhirAdapter
diff --git a/docs/cookbook/multi_ehr_aggregation.md b/docs/cookbook/multi_ehr_aggregation.md
index d0f3d6ce..2e03800d 100644
--- a/docs/cookbook/multi_ehr_aggregation.md
+++ b/docs/cookbook/multi_ehr_aggregation.md
@@ -142,7 +142,7 @@ uvicorn.run(app)
 
 For additional processing like terminology mapping or quality checks, create a Document [Pipeline](../reference/pipeline/pipeline.md).
 
-Document pipelines are optimized for text and structured data processing, such as FHIR resources. When you initialize a [Document](../reference/pipeline/data_container.md) with FHIR [Bundle](https://www.hl7.org/fhir/condition.html) data, it automatically extracts and separates metadata resources from the clinical resources for easier inspection and error handling:
+Document pipelines are optimized for text and structured data processing, such as FHIR resources. When you initialize a [Document](../reference/io/containers/document.md) with FHIR [Bundle](https://www.hl7.org/fhir/bundle.html) data, it automatically extracts and separates metadata resources from the clinical resources for easier inspection and error handling:
 
 ```python
 # Initialize Document with a Bundle
diff --git a/docs/quickstart.md b/docs/quickstart.md
index fe722882..08f03972 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -51,11 +51,11 @@ You can build pipelines with three different approaches:
 
 #### 1. Quick Inline Functions
 
-For quick experiments, start by picking the right [**Container**](./reference/pipeline/data_container.md) when you initialize your pipeline (e.g. `Pipeline[Document]()` for clinical text).
+For quick experiments, start by picking the right [**Container**](./reference/io/containers/containers.md) when you initialize your pipeline (e.g. `Pipeline[Document]()` for clinical text).
 
 Containers make your pipeline FHIR-native by loading and transforming your data (free text, EHR resources, etc.) into structured FHIR-ready formats. Just add your processing functions with `@add_node`, compile with `.build()`, and your pipeline is ready to process FHIR data end-to-end.
-[(Full Documentation on Container)](./reference/pipeline/data_container.md) +[(Full Documentation on Containers)](./reference/io/containers/containers.md) ```python from healthchain.pipeline import Pipeline @@ -109,9 +109,9 @@ doc = Document("Patient presents with hypertension.") output = pipe(doc) ``` -You can process legacy healthcare data formats too. [**Adapters**](./reference/pipeline/adapters/adapters.md) convert between healthcare formats like [CDA](https://www.hl7.org/implement/standards/product_brief.cfm?product_id=7) and your pipeline — just parse, process, and format without worrying about low-level data conversion. +You can process legacy healthcare data formats too. [**Adapters**](./reference/io/adapters/adapters.md) convert between healthcare formats like [CDA](https://www.hl7.org/implement/standards/product_brief.cfm?product_id=7) and your pipeline — just parse, process, and format without worrying about low-level data conversion. -[(Full Documentation on Adapters)](./reference/pipeline/adapters/adapters.md) +[(Full Documentation on Adapters)](./reference/io/adapters/adapters.md) ```python from healthchain.io import CdaAdapter diff --git a/docs/reference/pipeline/adapters/adapters.md b/docs/reference/io/adapters/adapters.md similarity index 94% rename from docs/reference/pipeline/adapters/adapters.md rename to docs/reference/io/adapters/adapters.md index f68f96f0..cc94e725 100644 --- a/docs/reference/pipeline/adapters/adapters.md +++ b/docs/reference/io/adapters/adapters.md @@ -8,7 +8,7 @@ Unlike the legacy connector pattern, adapters are used explicitly and provide cl Adapters parse data from specific healthcare formats into FHIR resources and store them in a `Document` container for processing. -([Document API Reference](../../../api/containers.md#healthchain.io.containers.document.Document)) +([Document API Reference](../../api/containers.md#healthchain.io.containers.document.Document)) | Adapter | Input Format | Output Format | FHIR Resources | Document Access | |---------|--------------|---------------|----------------|-----------------| @@ -67,6 +67,8 @@ print(f"Allergies: {doc.fhir.allergy_list}") response = adapter.format(doc) # Document → CdaResponse ``` +For more details on the Document container, see [Document](../containers/document.md). + ## Adapter Configuration ### Custom Interop Engine diff --git a/docs/reference/pipeline/adapters/cdaadapter.md b/docs/reference/io/adapters/cdaadapter.md similarity index 100% rename from docs/reference/pipeline/adapters/cdaadapter.md rename to docs/reference/io/adapters/cdaadapter.md diff --git a/docs/reference/pipeline/adapters/cdsfhiradapter.md b/docs/reference/io/adapters/cdsfhiradapter.md similarity index 100% rename from docs/reference/pipeline/adapters/cdsfhiradapter.md rename to docs/reference/io/adapters/cdsfhiradapter.md diff --git a/docs/reference/io/containers/containers.md b/docs/reference/io/containers/containers.md new file mode 100644 index 00000000..3961ab35 --- /dev/null +++ b/docs/reference/io/containers/containers.md @@ -0,0 +1,29 @@ +# Containers + +The `healthchain.io.containers` module provides FHIR-native containers for healthcare data processing. These containers handle the complexities of clinical data formats while providing a clean Python interface for NLP/ML pipelines. 
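+
+As a quick orientation, here is a minimal sketch of the two containers side by side; it assumes a FHIR Bundle `bundle` and a feature schema YAML are already on hand (see the linked pages below for the full APIs):
+
+```python
+from healthchain.io import Document
+from healthchain.io.containers import Dataset
+
+# Text-centric workflows: wrap a clinical note in a Document
+doc = Document("Patient presents with hypertension.")
+
+# Tabular ML workflows: extract schema-defined features from a FHIR Bundle
+dataset = Dataset.from_fhir_bundle(bundle, schema="path/to/schema.yaml")
+```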
+ +## Available Containers + +| Container | Purpose | Use Cases | +|-----------|---------|-----------| +| [**Document**](document.md) | Clinical text + FHIR resources | Clinical notes, discharge summaries, CDS workflows | +| [**Dataset**](dataset.md) | ML-ready features from FHIR | Model training/inference, feature engineering | + +## DataContainer 📦 + +`DataContainer` is a generic base class for storing data of any type. It provides serialization methods that other containers inherit. + +```python +from healthchain.io.containers import DataContainer + +# Create a DataContainer with string data +container = DataContainer("Some data") + +# Convert to dictionary and JSON +data_dict = container.to_dict() +data_json = container.to_json() + +# Create from dictionary or JSON +container_from_dict = DataContainer.from_dict(data_dict) +container_from_json = DataContainer.from_json(data_json) +``` diff --git a/docs/reference/io/containers/dataset.md b/docs/reference/io/containers/dataset.md new file mode 100644 index 00000000..731e5310 --- /dev/null +++ b/docs/reference/io/containers/dataset.md @@ -0,0 +1,113 @@ +# Dataset 📊 + +The `Dataset` is a pandas DataFrame wrapper designed for healthcare ML workflows: it extracts ML-ready features from FHIR Bundles using schemas, validates data types, and converts model predictions back into clinical decision support resources ([RiskAssessment](https://hl7.org/fhir/riskassessment.html)). + +## Usage + +The two most helpful methods in the `Dataset` class are: + +- `from_fhir_bundle()`: Extract ML-ready features from a FHIR Bundle using a feature schema. +- `to_risk_assessment()`: Convert model predictions into FHIR RiskAssessment resources for clinical consumption. + +!!! tip "Feature Schemas" + Define features once in YAML and reuse across training, validation, and inference. See [FHIR Feature Mapper](../mappers/fhir_feature.md) for schema details. + + +```python +from healthchain.io.containers import Dataset + +# 1. Extract ML features from a FHIR Bundle using a feature schema +dataset = Dataset.from_fhir_bundle(bundle, schema="path/to/schema.yaml") + +# 2. Inspect the features as a pandas DataFrame +print(dataset.data.head()) +print("Columns:", dataset.columns) + +# 3. Validate the dataset against the schema (checks for missing/invalid fields) +validation_result = dataset.validate(schema="path/to/schema.yaml") +print("Validation Result:", validation_result) + +# 4. Run inference using your ML model +predictions = model.predict(dataset.data) +probabilities = model.predict_proba(dataset.data)[:, 1] + +# 5. Convert predictions to FHIR RiskAssessment resources for downstream use +risk_assessments = dataset.to_risk_assessment( + predictions=predictions, + probabilities=probabilities, + outcome_code="A41.9", + outcome_display="Sepsis, unspecified", + model_name="SepsisRiskModel", + model_version="1.0" +) +``` + +This workflow lets you convert FHIR healthcare data into DataFrames for ML, and then easily package predictions as standardized FHIR artifacts. + + +??? 
example "Example RiskAssessment Output" + ```json + { + "resourceType": "RiskAssessment", + "id": "hc-a1b2c3d4", + "status": "final", + "subject": { + "reference": "Patient/123" + }, + "method": { + "coding": [{ + "system": "https://healthchain.github.io/ml-models", + "code": "RandomForestClassifier", + "display": "RandomForestClassifier v2.1" + }] + }, + "prediction": [{ + "outcome": { + "coding": [{ + "system": "http://hl7.org/fhir/sid/icd-10", + "code": "A41.9", + "display": "Sepsis, unspecified" + }] + }, + "probabilityDecimal": 0.85, + "qualitativeRisk": { + "coding": [{ + "system": "http://terminology.hl7.org/CodeSystem/risk-probability", + "code": "high", + "display": "High Risk" + }] + } + }], + "note": [{ + "text": "ML prediction: Positive (probability: 85.00%, risk: high)" + }] + } + ``` + +### Properties and Methods + +Common Dataset operations: + +```python +# Metadata +print(dataset.columns) # List of feature names +print(dataset.row_count()) # Number of samples +print(dataset.column_count()) # Number of features +print(dataset.describe()) # Summary statistics + +# Data access +df = dataset.data # Underlying pandas DataFrame +dtypes = dataset.dtypes # Feature data types + +# Data manipulation +dataset.remove_column('temp_feature') # Drop a feature +``` + +## Resource Documentation + +- [FHIR RiskAssessment](https://www.hl7.org/fhir/riskassessment.html) +- [FHIR Observation](https://www.hl7.org/fhir/observation.html) + +## API Reference + +See the [Dataset API Reference](../../api/containers.md#healthchain.io.containers.dataset) for detailed class documentation. diff --git a/docs/reference/io/containers/document.md b/docs/reference/io/containers/document.md new file mode 100644 index 00000000..fecfda3b --- /dev/null +++ b/docs/reference/io/containers/document.md @@ -0,0 +1,235 @@ +# Document 📄 + +The `Document` class is a container for working with both clinical text and structured healthcare data. It natively manages FHIR resources, runs NLP over raw notes, tracks clinical document relationships, stores decision support outputs, and holds LLM model predictions. + +Use Document containers for clinical notes, discharge summaries, patient records, and any healthcare data that combines text with structured FHIR resources. 
+ +## Usage + +The main things you'll do with `Document`: + +- Store and update clinical notes and FHIR Bundles +- Extract and manipulate diagnoses, meds, allergies, and documents +- Run NLP to extract entities or embeddings from text +- Generate & store CDS Hooks cards (recommendations, alerts) +- Attach model predictions for downstream use + + +## API Overview + +**Document** has four key components (all accessible as attributes): + +| Attribute | For | +|---|---| +| `doc.fhir` | FHIR management—Clinical lists, Bundles, DocReference, patient info | +| `doc.nlp` | NLP features—entities, tokens, embeddings | +| `doc.cds` | Decision support—recommendation cards, actions | +| `doc.models` | ML/LLM outputs—store/retrieve predictions, generations | + + +### FHIR Data (`doc.fhir`) + +- Automatic FHIR Bundle creation and management +- Resource type validation +- Easy access to clinical data lists (e.g., problems, medications, allergies) +- OperationOutcome and Provenance resources automatically extracted and accessible as `doc.fhir.operation_outcomes` and `doc.fhir.provenances` (removed from main bundle) + +**Convenience Accessors** + +| Attribute | Description | +|-------------------|---------------------------------------------------------| +| `patient` | First Patient resource in the bundle (or `None`) | +| `patients` | List of Patient resources | +| `problem_list` | List of Condition resources (diagnoses, problems) | +| `medication_list` | List of MedicationStatement resources | +| `allergy_list` | List of AllergyIntolerance resources | + +**Document Reference Management** + +- Document relationship tracking (parent/child/sibling) +- Attachment handling with base64 encoding +- Document family retrieval + +**CDS Support** + +- Support for CDS Hooks prefetch resources +- Resource indexing by type + + +```python +from healthchain.io import Document +from healthchain.fhir import ( + create_condition, + create_document_reference, +) + +# Initialize with clinical text from EHR +doc = Document("Patient presents with uncontrolled hypertension and Type 2 diabetes") + +# Build problem list with SNOMED CT codes +doc.fhir.problem_list = [ + create_condition( + subject="Patient/123", + code="38341003", + display="Hypertension" + ), + create_condition( + subject="Patient/123", + code="44054006", + display="Type 2 diabetes mellitus" + ) +] + +# Track document versions and amendments +initial_note = create_document_reference( + data="Initial assessment: Patient presents with chest pain", + content_type="text/plain", + description="Initial ED note" +) +initial_id = doc.fhir.add_document_reference(initial_note) + +# Add amended note +amended_note = create_document_reference( + data="Amended: Patient presents with chest pain, ruling out cardiac etiology", + content_type="text/plain", + description="Amended ED note" +) +amended_id = doc.fhir.add_document_reference( + amended_note, + parent_id=initial_id, + relationship_type="replaces" +) + +# Retrieve document history for audit trail +family = doc.fhir.get_document_reference_family(amended_id) +print(f"Original note: {family['parents'][0].description}") + + +# Handle errors and track data provenance +if doc.fhir.operation_outcomes: + for outcome in doc.fhir.operation_outcomes: + print(f"Warning: {outcome.issue[0].diagnostics}") + +# Access patient demographics +if doc.fhir.patient: + print(f"Patient: {doc.fhir.patient.name[0].given[0]} {doc.fhir.patient.name[0].family}") + +# Prepare data for CDS Hooks integration +prefetch = { + "Condition": 
doc.fhir.problem_list, + "MedicationStatement": doc.fhir.medication_list, +} +doc.fhir.prefetch_resources = prefetch + +# CDS service can query prefetch data +conditions = doc.fhir.get_prefetch_resources("Condition") +print(f"Active conditions: {len(conditions)}") +``` + +### NLP (`doc.nlp`) + +- Medical text features: tokens, entities (`get_entities()`), embeddings (`get_embeddings()`) +- Direct spaCy doc access, fast word counting + +```python +# Extract medical concepts from clinical note +doc = Document("Patient diagnosed with pneumonia, started on azithromycin") + +# Get medical entities +entities = doc.nlp.get_entities() +for entity in entities: + print(f"{entity.text}: {entity.label_}") # "pneumonia: CONDITION" + +# Access full spaCy document for custom processing +spacy_doc = doc.nlp.get_spacy_doc() +for ent in spacy_doc.ents: + if hasattr(ent._, "cui"): + print(f"{ent.text} -> SNOMED: {ent._.cui}") +``` + +### Clinical Decision Support (`doc.cds`) + +- `cards`: Clinical recommendation cards displayed in EHR workflows +- `actions`: Suggested interventions (orders, referrals, documentation) + +```python +from healthchain.models import Card, Action + +# Generate clinical alert +doc.cds.cards = [ + Card( + summary="Drug interaction detected", + indicator="critical", + detail="Warfarin + NSAIDs: Increased bleeding risk", + source={"label": "Clinical Decision Support"}, + ) +] + +# Suggest action +doc.cds.actions = [ + Action( + type="create", + description="Order CBC to monitor platelets", + resource={ + "resourceType": "ServiceRequest", + "code": {"text": "Complete Blood Count"} + } + ) +] +``` + + +### LLM Model Outputs (`doc.models`) + +- `get_output(model_name, task)`: Retrieve model predictions by name and task +- `get_generated_text(model_name, task)`: Extract generated text from LLMs +- Supports Hugging Face, LangChain, spaCy, and custom models + +```python +# Store classification results +doc.models.add_output( + model_name="clinical_classifier", + task="diagnosis_prediction", + output={"prediction": "diabetes", "confidence": 0.95} +) + +# Store LLM summary +doc.models.add_output( + model_name="gpt4", + task="summarization", + output="Patient presents with classic diabetic symptoms..." +) + +# Retrieve outputs +diagnosis = doc.models.get_output("clinical_classifier", "diagnosis_prediction") +summary = doc.models.get_generated_text("gpt4", "summarization") +``` + +### Properties and Methods + +```python +# FHIR access +print(doc.fhir.problem_list) +print(doc.fhir.patient) + +# NLP +tokens = doc.nlp.get_tokens() +ents = doc.nlp.get_entities() + +# Clinical decision support +cards = doc.cds.cards + +# Model outputs +doc.models.add_output("my_model", "task", output={"foo": "bar"}) +print(doc.models.get_output("my_model", "task")) +``` + +## Resource Docs + +- [FHIR Bundle](https://www.hl7.org/fhir/bundle.html) +- [FHIR Condition](https://www.hl7.org/fhir/condition.html) +- [FHIR DocumentReference](https://www.hl7.org/fhir/documentreference.html) + +## API Reference + +See [Document API Reference](../../api/containers.md#healthchain.io.containers.document) for full details. diff --git a/docs/reference/io/mappers/fhir_feature.md b/docs/reference/io/mappers/fhir_feature.md new file mode 100644 index 00000000..2efae592 --- /dev/null +++ b/docs/reference/io/mappers/fhir_feature.md @@ -0,0 +1,132 @@ +# FHIR Feature Mapper + +The `FHIRFeatureMapper` allows you to easily extract relevant features from FHIR Bundles based on a declarative schema. 
This makes it simple to generate ML-ready DataFrames for downstream analysis and modeling. + +## Overview + +The mapper uses feature schemas—YAML configs that define which clinical data to extract and how to transform it. This enables: + +- **Declarative mapping**: Define features in YAML, not code +- **Reproducible pipelines**: Same schema = same features across train/test/prod +- **Built-in validation**: Type checking catches mismatches before inference +- **FHIR-native**: Works with any EHR's FHIR Bundle + +## Usage + +Write a YAML file specifying which FHIR resources and codes to extract, desired data types, and any transformations: + +```yaml +name: sepsis_prediction_features +version: "1.0" +description: Feature schema for sepsis risk model + +# Optional: Control how patient age is calculated +metadata: + age_calculation: event_date # Calculate age at event time + event_date_source: Observation # Use earliest observation date + event_date_strategy: earliest + +features: + # Vital signs from Observations + heart_rate: + fhir_resource: Observation + code: "220045" # MIMIC-IV itemID + code_system: http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-chartevents-d-items + display: Heart Rate + unit: bpm + dtype: float64 + required: true + + # Demographics from Patient resource + age: + fhir_resource: Patient + field: birthDate # Extract this field + transform: calculate_age # Apply this transformation + dtype: int64 + required: true +``` + +### Standalone Use + +In most cases you should use the [**Dataset**](../containers/dataset.md) API to automatically load your schema and extract features. It's the easiest and most robust workflow. + +For advanced usage, you can load a schema and use the `FHIRFeatureMapper` directly for more control: + +```python +from healthchain.io.mappers import FHIRFeatureMapper +from healthchain.io.containers import FeatureSchema + +# Manually load your YAML feature schema +schema = FeatureSchema.from_yaml("configs/features/my_model.yaml") + +# Create the feature mapper with your schema +mapper = FHIRFeatureMapper(schema) + +# Extract features from a FHIR Bundle +features_df = mapper.map(bundle, aggregation="mean") + +print(features_df.head()) +# (Optional) Access patient references +patient_refs = features_df["patient_ref"].tolist() +``` + +### Aggregation Strategies + +When a patient has multiple observations for the same code (e.g., multiple temperature readings), specify how to aggregate them: + +```python +# Take the mean of all values +dataset = Dataset.from_fhir_bundle(bundle, schema, aggregation="mean") + +# Use the most recent value +dataset = Dataset.from_fhir_bundle(bundle, schema, aggregation="last") + +# Other options: "median", "max", "min" +``` + +### Multiple Code Systems + +Different EHRs use different code systems. 
You can map the same clinical concept across systems: + +```yaml +# LOINC code for heart rate (standard) +heart_rate_loinc: + fhir_resource: Observation + code: "8867-4" + code_system: http://loinc.org + display: Heart Rate + dtype: float64 + +# MIMIC-IV internal code +heart_rate_mimic: + fhir_resource: Observation + code: "220045" + code_system: http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-chartevents-d-items + display: Heart Rate + dtype: float64 +``` + +Then in your code, merge as needed: + +```python +df = dataset.data +# Combine both columns, preferring LOINC +df['heart_rate'] = df['heart_rate_loinc'].fillna(df['heart_rate_mimic']) +``` + +### Validation and Error Handling + +Check that incoming data matches your training schema: + +```python +from healthchain.io.containers import FeatureSchema + +schema = FeatureSchema.from_yaml("configs/features/my_model.yaml") +result = dataset.validate(schema, raise_on_error=False) +``` + +## Related Documentation + +- [Dataset Container](../containers/dataset.md) - Complete Dataset API reference +- [Mappers Overview](mappers.md) - Other mapper types +- [FHIR Helpers](../../utilities/fhir_helpers.md) - Creating FHIR resources diff --git a/docs/reference/io/mappers/mappers.md b/docs/reference/io/mappers/mappers.md new file mode 100644 index 00000000..4c8c41b4 --- /dev/null +++ b/docs/reference/io/mappers/mappers.md @@ -0,0 +1,20 @@ +# Mappers + +Mappers transform data between different healthcare formats and structures. They enable standardized data conversion workflows while maintaining clinical semantics and validation. + +## Available Mappers + +| Mapper | Source Format | Target Format | Primary Use Case | +|--------|---------------|---------------|------------------| +| [**FHIRFeatureMapper**](fhir_feature.md) | FHIR Bundle | pandas DataFrame | Extract ML-ready features from FHIR resources | + +### Future Mappers (Planned) + +- **FHIR-to-FHIR Mapper**: Transform between FHIR resource types +- **FHIR-to-OMOP Mapper**: Convert between FHIR and OMOP Common Data Model + +## Related Documentation + +- [Containers](../containers/containers.md) - Data containers that use mappers +- [Dataset](../containers/dataset.md) - Uses FHIRFeatureMapper for feature extraction +- [Adapters](../adapters/adapters.md) - Convert between healthcare protocols diff --git a/docs/reference/pipeline/data_container.md b/docs/reference/pipeline/data_container.md deleted file mode 100644 index 4e3e952d..00000000 --- a/docs/reference/pipeline/data_container.md +++ /dev/null @@ -1,356 +0,0 @@ -# Data Container - -The `healthchain.io.containers` module provides FHIR-native containers for healthcare data processing. These containers handle the complexities of clinical data formats while providing a clean Python interface for NLP/ML pipelines. - -## DataContainer 📦 - -`DataContainer` is a generic base class for storing data of any type. - -```python -from healthchain.io.containers import DataContainer - -# Create a DataContainer with string data -container = DataContainer("Some data") - -# Convert to dictionary and JSON -data_dict = container.to_dict() -data_json = container.to_json() - -# Create from dictionary or JSON -container_from_dict = DataContainer.from_dict(data_dict) -container_from_json = DataContainer.from_json(data_json) -``` - -## Document 📄 - -The `Document` class is HealthChain's core container for clinical text and structured healthcare data. 
It handles FHIR resources natively, automatically manages validation and conversion, and integrates seamlessly with NLP models and CDS workflows. - -Use Document containers for clinical notes, discharge summaries, patient records, and any healthcare data that combines text with structured FHIR resources. - -| Attribute | Access | Primary Purpose | Key Features | Common Use Cases | -|-----------|--------|----------------|--------------|------------------| -| [**FHIR Data**](#fhir-data-docfhir) | `doc.fhir` | Manage clinical data in FHIR format | • Resource bundles
• Clinical lists (problems, meds, allergies)
• Document references
• CDS prefetch | • Store patient records
• Track medical history
• Manage clinical documents | -| [**NLP**](#nlp-component-docnlp) | `doc.nlp` | Process and analyze text | • Tokenization
• Entity recognition
• Embeddings
• spaCy integration | • Extract medical terms
• Analyze clinical text
• Generate features | -| [**CDS**](#clinical-decision-support-doccds) | `doc.cds` | Clinical decision support | • Recommendation cards
• Suggested actions
• Clinical alerts | • Generate alerts
• Suggest interventions
• Guide clinical decisions | -| [**Model Outputs**](#model-outputs-docmodels) | `doc.models` | Store ML model results | • Multi-framework support
• Task-specific outputs
• Text generation | • Store classifications
• Keep predictions
• Track generations | - -### FHIR Data (`doc.fhir`) - -The FHIR component provides production-ready management of FHIR resources with automatic validation, error handling, and convenient accessors for common clinical workflows: - -**Storage and Management:** - - - Automatic `Bundle` creation and management - - Resource type validation - - Convenient access to common clinical data lists - - Automatic extraction of `OperationOutcome` and `Provenance` resources into `doc.fhir.operation_outcomes` and `doc.fhir.provenances` (removed from bundle) - -**Convenience Accessors:** - -- `patient`: First Patient resource in the bundle, or `None` -- `patients`: List of Patient resources -- `problem_list`: List of `Condition` resources (diagnoses, problems) -- `medication_list`: List of `MedicationStatement` resources -- `allergy_list`: List of `AllergyIntolerance` resources - -**Document Reference Management:** - - - Document relationship tracking (parent/child/sibling) - - Attachment handling with `base64` encoding - - Document family retrieval - -**CDS Support:** - - - Support for CDS Hooks prefetch resources - - Resource indexing by type - -**Example: Clinical Documentation Workflow** - -```python -from healthchain.io import Document -from healthchain.fhir import ( - create_condition, - create_medication_statement, - create_document_reference, -) - -# Initialize with clinical text from EHR -doc = Document("Patient presents with uncontrolled hypertension and Type 2 diabetes") - -# Build problem list with SNOMED CT codes -doc.fhir.problem_list = [ - create_condition( - subject="Patient/123", - code="38341003", - display="Hypertension" - ), - create_condition( - subject="Patient/123", - code="44054006", - display="Type 2 diabetes mellitus" - ) -] - -# Document current medications -doc.fhir.medication_list = [ - create_medication_statement( - subject="Patient/123", - code="197361", - display="Lisinopril 10 MG" - ), - create_medication_statement( - subject="Patient/123", - code="860975", - display="Metformin 500 MG" - ) -] - -# Track document versions and amendments -initial_note = create_document_reference( - data="Initial assessment: Patient presents with chest pain", - content_type="text/plain", - description="Initial ED note" -) -initial_id = doc.fhir.add_document_reference(initial_note) - -# Add amended note -amended_note = create_document_reference( - data="Amended: Patient presents with chest pain, ruling out cardiac etiology", - content_type="text/plain", - description="Amended ED note" -) -amended_id = doc.fhir.add_document_reference( - amended_note, - parent_id=initial_id, - relationship_type="replaces" -) - -# Retrieve document history for audit trail -family = doc.fhir.get_document_reference_family(amended_id) -print(f"Original note: {family['parents'][0].description}") - -# Prepare data for CDS Hooks integration -prefetch = { - "Condition": doc.fhir.problem_list, - "MedicationStatement": doc.fhir.medication_list, -} -doc.fhir.prefetch_resources = prefetch - -# CDS service can query prefetch data -conditions = doc.fhir.get_prefetch_resources("Condition") -print(f"Active conditions: {len(conditions)}") - -# Handle errors and track data provenance -if doc.fhir.operation_outcomes: - for outcome in doc.fhir.operation_outcomes: - print(f"Warning: {outcome.issue[0].diagnostics}") - -# Access patient demographics -if doc.fhir.patient: - print(f"Patient: {doc.fhir.patient.name[0].given[0]} {doc.fhir.patient.name[0].family}") -``` - -**Technical Notes:** - -- All FHIR resources are validated using 
[fhir.resources](https://github.com/nazrulworld/fhir.resources) -- Document relationships follow the FHIR [DocumentReference.relatesTo](https://www.hl7.org/fhir/documentreference-definitions.html#DocumentReference.relatesTo) standard - -**Resource Documentation:** - -- [FHIR Bundle](https://www.hl7.org/fhir/bundle.html) -- [FHIR DocumentReference](https://www.hl7.org/fhir/documentreference.html) -- [FHIR Condition](https://www.hl7.org/fhir/condition.html) -- [FHIR MedicationStatement](https://www.hl7.org/fhir/medicationstatement.html) -- [FHIR AllergyIntolerance](https://www.hl7.org/fhir/allergyintolerance.html) - -### NLP Component (`doc.nlp`) - -Process clinical text with medical NLP models and access extracted features: - -- `get_tokens()`: Tokenized clinical text for downstream processing -- `get_entities()`: Medical entities with optional CUI codes (SNOMED CT, RxNorm) -- `get_embeddings()`: Vector representations for similarity search and clustering -- `get_spacy_doc()`: Direct access to spaCy document for custom processing -- `word_count()`: Token-based word count - -**Example: Medical Entity Extraction** -```python -# Extract medical concepts from clinical note -doc = Document("Patient diagnosed with pneumonia, started on azithromycin") - -# Get medical entities -entities = doc.nlp.get_entities() -for entity in entities: - print(f"{entity.text}: {entity.label_}") # "pneumonia: CONDITION" - -# Access full spaCy document for custom processing -spacy_doc = doc.nlp.get_spacy_doc() -for ent in spacy_doc.ents: - if hasattr(ent._, "cui"): - print(f"{ent.text} -> SNOMED: {ent._.cui}") -``` - - -### Clinical Decision Support (`doc.cds`) - -Generate CDS Hooks cards and actions for real-time EHR integration: - -- `cards`: Clinical recommendation cards displayed in EHR workflows -- `actions`: Suggested interventions (orders, referrals, documentation) - -**Example: CDS Hooks Response** -```python -from healthchain.models import Card, Action - -# Generate clinical alert -doc.cds.cards = [ - Card( - summary="Drug interaction detected", - indicator="critical", - detail="Warfarin + NSAIDs: Increased bleeding risk", - source={"label": "Clinical Decision Support"}, - ) -] - -# Suggest action -doc.cds.actions = [ - Action( - type="create", - description="Order CBC to monitor platelets", - resource={ - "resourceType": "ServiceRequest", - "code": {"text": "Complete Blood Count"} - } - ) -] -``` - - -### Model Outputs (`doc.models`) - -Store and retrieve ML model predictions across multiple frameworks: - -- `get_output(model_name, task)`: Retrieve model predictions by name and task -- `get_generated_text(model_name, task)`: Extract generated text from LLMs -- Supports Hugging Face, LangChain, spaCy, and custom models - -**Example: Multi-Model Pipeline** -```python -# Store classification results -doc.models.add_output( - model_name="clinical_classifier", - task="diagnosis_prediction", - output={"prediction": "diabetes", "confidence": 0.95} -) - -# Store LLM summary -doc.models.add_output( - model_name="gpt4", - task="summarization", - output="Patient presents with classic diabetic symptoms..." 
-) - -# Retrieve outputs -diagnosis = doc.models.get_output("clinical_classifier", "diagnosis_prediction") -summary = doc.models.get_generated_text("gpt4", "summarization") -``` - -**Example: Complete Clinical Workflow** -```python -from healthchain.io import Document -from healthchain.fhir import create_condition -from healthchain.models import Card, Action - -# Initialize with clinical note from EHR -doc = Document("67yo M presents with acute chest pain radiating to left arm, diaphoresis") - -# Process with NLP model -print(f"Clinical note length: {doc.nlp.word_count()} words") -entities = doc.nlp.get_entities() - -# Extract FHIR conditions from text -spacy_doc = doc.nlp.get_spacy_doc() -for ent in spacy_doc.ents: - if ent.label_ == "CONDITION" and hasattr(ent._, "cui"): - doc.fhir.problem_list.append( - create_condition( - subject="Patient/123", - code=ent._.cui, - display=ent.text - ) - ) - -# Or use helper method for automatic extraction -doc.update_problem_list_from_nlp() - -# Generate CDS alert based on findings -doc.cds.cards = [ - Card( - summary="STEMI Alert - Activate Cath Lab", - indicator="critical", - detail="Patient meets criteria for ST-elevation myocardial infarction", - source={"label": "Cardiology Protocol"}, - ) -] - -# Track model predictions -doc.models.add_output( - model_name="cardiac_risk_model", - task="classification", - output={"risk_level": "high", "score": 0.89} -) - -# Access all components -print(f"Problem list: {len(doc.fhir.problem_list)} conditions") -print(f"CDS cards: {len(doc.cds.cards)} alerts") -print(f"Risk assessment: {doc.models.get_output('cardiac_risk_model', 'classification')}") -``` - -[Document API Reference](../../api/containers.md#healthchain.io.containers.document) - -## Tabular 📊 - -The `Tabular` class handles structured healthcare data like lab results, patient cohorts, and claims data. It wraps pandas DataFrame with healthcare-specific operations. - -**Example: Patient Cohort Analysis** -```python -import pandas as pd -from healthchain.io.containers import Tabular - -# Load patient cohort data -df = pd.DataFrame({ - 'patient_id': ['P001', 'P002', 'P003'], - 'age': [45, 62, 58], - 'diagnosis': ['diabetes', 'hypertension', 'diabetes'], - 'hba1c': [7.2, None, 8.1] -}) -cohort = Tabular(df) - -# Analyze cohort characteristics -print(f"Cohort size: {cohort.row_count()} patients") -print(f"Average age: {cohort.data['age'].mean():.1f} years") -print(f"\nClinical measures:\n{cohort.describe()}") - -# Filter for diabetic patients -diabetic_cohort = cohort.data[cohort.data['diagnosis'] == 'diabetes'] -print(f"\nDiabetic patients: {len(diabetic_cohort)}") -print(f"Mean HbA1c: {diabetic_cohort['hba1c'].mean():.1f}%") - -# Export for reporting -cohort.to_csv('patient_cohort_analysis.csv') -``` - -**Example: Lab Results Processing** -```python -# Load lab results from EHR export -labs = Tabular.from_csv('lab_results.csv') - -print(f"Total lab orders: {labs.row_count()}") -print(f"Test types: {labs.data['test_name'].nunique()}") - -# Identify abnormal results -abnormal = labs.data[labs.data['flag'] == 'ABNORMAL'] -print(f"Abnormal results: {len(abnormal)} ({len(abnormal)/labs.row_count()*100:.1f}%)") -``` - -These containers provide a consistent, FHIR-aware interface for healthcare data processing throughout HealthChain pipelines, handling validation, conversion, and integration with clinical workflows automatically. 
diff --git a/docs/reference/pipeline/pipeline.md b/docs/reference/pipeline/pipeline.md index 5fc99da9..af65d73c 100644 --- a/docs/reference/pipeline/pipeline.md +++ b/docs/reference/pipeline/pipeline.md @@ -8,7 +8,7 @@ Choose from prebuilt pipelines tailored to standard clinical workflows, or build HealthChain comes with a set of end-to-end pipeline implementations of common healthcare data processing tasks. -These prebuilt pipelines handle FHIR conversion, validation, and EHR integration for you. They work out-of-the-box with [**Adapters**](./adapters/adapters.md) and [**Gateways**](../gateway/gateway.md), supporting CDS Hooks, NoteReader CDI, and FHIR APIs. They're great for a quick setup to build more complex integrations on top of. +These prebuilt pipelines handle FHIR conversion, validation, and EHR integration for you. They work out-of-the-box with [**Adapters**](../io/adapters/adapters.md) and [**Gateways**](../gateway/gateway.md), supporting CDS Hooks, NoteReader CDI, and FHIR APIs. They're great for a quick setup to build more complex integrations on top of. | Pipeline | Container | Use Case | Description | Example Application | @@ -61,9 +61,9 @@ pipeline = MedicalCodingPipeline.from_local_model( ## Freestyle 🕺 -[**Containers**](./data_container.md) are at the core of HealthChain pipelines: they define your data type and flow through each pipeline step, just like spaCy’s `Doc`. +[**Containers**](../io/containers/containers.md) are at the core of HealthChain pipelines: they define your data type and flow through each pipeline step, just like spaCy's `Doc`. -Specify the container (e.g. `Document` or `Tabular`) when creating your pipeline (`Pipeline[Document]()`). Each node processes and returns the container, enabling smooth, type-safe, modular workflows and direct FHIR conversion. +Specify the container (e.g. [Document](../io/containers/document.md) or [Dataset](../io/containers/dataset.md)) when creating your pipeline (`Pipeline[Document]()`). Each node processes and returns the container, enabling smooth, type-safe, modular workflows and direct FHIR conversion. ```python from healthchain.pipeline import Pipeline @@ -282,7 +282,7 @@ print(pipeline.stages) Adapters let you easily convert between healthcare formats (CDA, FHIR, CDS Hooks) and HealthChain Documents. Keep your ML pipeline format-agnostic while always getting FHIR-ready outputs. 
-[(Full Documentation on Adapters)](./adapters/adapters.md) +[(Full Documentation on Adapters)](../io/adapters/adapters.md) ```python from healthchain.io import CdaAdapter, Document diff --git a/docs/reference/utilities/sandbox.md b/docs/reference/utilities/sandbox.md index 3c666bb8..5ab4d952 100644 --- a/docs/reference/utilities/sandbox.md +++ b/docs/reference/utilities/sandbox.md @@ -140,6 +140,38 @@ data_dir/ ) ``` +=== "Direct Loader for ML Workflows" + ```python + # Use loader directly for ML pipelines (faster, no validation) + from healthchain.sandbox.loaders import MimicOnFHIRLoader + from healthchain.io import Dataset + + loader = MimicOnFHIRLoader() + + # as_dict=True: Returns single bundle dict (fast, no FHIR validation) + # Suitable for ML feature extraction workflows + bundle = loader.load( + data_dir="./data/mimic-iv-fhir", + resource_types=["MimicObservationChartevents", "MimicPatient"], + as_dict=True + ) + + # Convert to DataFrame for ML + dataset = Dataset.from_fhir_bundle( + bundle, + schema="healthchain/configs/features/sepsis_vitals.yaml" + ) + df = dataset.data + + # as_dict=False (default): Returns Dict[str, Bundle] + # Validated Bundle objects grouped by resource type (for CDS Hooks) + bundles = loader.load( + data_dir="./data/mimic-iv-fhir", + resource_types=["MimicMedication", "MimicCondition"] + ) + # Use bundles["medicationstatement"] and bundles["condition"] + ``` + ### Synthea Loader Synthetic patient data generated by [Synthea](https://synthea.mitre.org), containing realistic FHIR Bundles (typically 100-500 resources per patient). Ideal for single-patient workflows that require diverse data scenarios. diff --git a/healthchain/configs/features/sepsis_vitals.yaml b/healthchain/configs/features/sepsis_vitals.yaml new file mode 100644 index 00000000..52133afc --- /dev/null +++ b/healthchain/configs/features/sepsis_vitals.yaml @@ -0,0 +1,85 @@ +name: sepsis_prediction_features +version: "1.0" +description: Feature schema for sepsis prediction model trained on MIMIC-IV data + +model_info: + model_type: Random Forest / XGBoost / Logistic Regression + training_data: MIMIC-IV Clinical Database Demo + target: Sepsis (ICD-9/ICD-10 codes) + prediction_window: First 24 hours of ICU stay + +metadata: + age_calculation: event_date + event_date_source: Observation + event_date_strategy: earliest + +features: + heart_rate: + fhir_resource: Observation + code: "220045" + code_system: http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-chartevents-d-items + display: Heart Rate + unit: bpm + dtype: float64 + required: true + + temperature: + fhir_resource: Observation + code: "223761" + code_system: http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-chartevents-d-items + display: Temperature Fahrenheit + unit: °F + dtype: float64 + required: true + + respiratory_rate: + fhir_resource: Observation + code: "220210" + code_system: http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-chartevents-d-items + display: Respiratory Rate + unit: insp/min + dtype: float64 + required: true + + wbc: + fhir_resource: Observation + code: "51301" + code_system: http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-labitems + display: White Blood Cells + unit: K/uL + dtype: float64 + required: true + + lactate: + fhir_resource: Observation + code: "50843" + code_system: http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-labitems + display: Lactate Dehydrogenase, Ascites + unit: IU/L + dtype: float64 + required: true + + creatinine: + fhir_resource: Observation + code: "50912" + code_system: 
http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-labitems + display: Creatinine + unit: mg/dL + dtype: float64 + required: true + + age: + fhir_resource: Patient + field: birthDate + transform: calculate_age + dtype: int64 + required: true + display: Patient age calculated from birth date + + gender_encoded: + fhir_resource: Patient + field: gender + transform: encode_gender + dtype: int64 + required: true + display: Administrative gender (M=1, F=0) diff --git a/healthchain/fhir/__init__.py b/healthchain/fhir/__init__.py index 9da081d3..9193ccd4 100644 --- a/healthchain/fhir/__init__.py +++ b/healthchain/fhir/__init__.py @@ -1,23 +1,32 @@ """FHIR utilities for HealthChain.""" -from healthchain.fhir.helpers import ( +from healthchain.fhir.resourcehelpers import ( create_condition, create_medication_statement, create_allergy_intolerance, - create_single_codeable_concept, - create_single_reaction, - set_condition_category, - read_content_attachment, + create_value_quantity_observation, + create_patient, + create_risk_assessment_from_prediction, create_document_reference, create_document_reference_content, + set_condition_category, + add_provenance_metadata, + add_coding_to_codeable_concept, +) + +from healthchain.fhir.elementhelpers import ( + create_single_codeable_concept, + create_single_reaction, create_single_attachment, +) + +from healthchain.fhir.readers import ( create_resource_from_dict, convert_prefetch_to_fhir_objects, - add_provenance_metadata, - add_coding_to_codeable_concept, + read_content_attachment, ) -from healthchain.fhir.bundle_helpers import ( +from healthchain.fhir.bundlehelpers import ( create_bundle, add_resource, get_resources, @@ -27,11 +36,30 @@ count_resources, ) +from healthchain.fhir.dataframe import ( + BundleConverterConfig, + bundle_to_dataframe, + get_supported_resources, + get_resource_info, + print_supported_resources, +) + +from healthchain.fhir.utilities import ( + calculate_age_from_birthdate, + calculate_age_from_event_date, + encode_gender, +) + __all__ = [ # Resource creation "create_condition", "create_medication_statement", "create_allergy_intolerance", + "create_value_quantity_observation", + "create_patient", + "create_risk_assessment_from_prediction", + "create_document_reference", + # Element creation "create_single_codeable_concept", "create_single_reaction", "set_condition_category", @@ -39,11 +67,14 @@ "create_document_reference", "create_document_reference_content", "create_single_attachment", - "create_resource_from_dict", - "convert_prefetch_to_fhir_objects", # Resource modification + "set_condition_category", "add_provenance_metadata", "add_coding_to_codeable_concept", + # Conversions and readers + "create_resource_from_dict", + "convert_prefetch_to_fhir_objects", + "read_content_attachment", # Bundle operations "create_bundle", "add_resource", @@ -52,4 +83,14 @@ "merge_bundles", "extract_resources", "count_resources", + # Bundle to DataFrame conversion + "BundleConverterConfig", + "bundle_to_dataframe", + "get_supported_resources", + "get_resource_info", + "print_supported_resources", + # Utility functions + "calculate_age_from_birthdate", + "calculate_age_from_event_date", + "encode_gender", ] diff --git a/healthchain/fhir/bundle_helpers.py b/healthchain/fhir/bundlehelpers.py similarity index 100% rename from healthchain/fhir/bundle_helpers.py rename to healthchain/fhir/bundlehelpers.py diff --git a/healthchain/fhir/dataframe.py b/healthchain/fhir/dataframe.py new file mode 100644 index 00000000..0c9bfdb2 --- /dev/null +++ 
b/healthchain/fhir/dataframe.py @@ -0,0 +1,600 @@ +"""FHIR to DataFrame converters. + +This module provides generic functions to convert FHIR Bundles to pandas DataFrames +for analysis and ML model deployment. + +In instances where there are multiple codes present for a single resource, the first code is used as the primary code. +""" + +import pandas as pd +import logging + +from typing import Any, Dict, List, Union, Optional, Literal +from collections import defaultdict +from fhir.resources.bundle import Bundle +from pydantic import BaseModel, field_validator, ConfigDict + +from healthchain.fhir.utilities import ( + calculate_age_from_birthdate, + calculate_age_from_event_date, + encode_gender, +) + +logger = logging.getLogger(__name__) + + +# Resource handler registry +SUPPORTED_RESOURCES = { + "Patient": { + "handler": "_flatten_patient", + "description": "Patient demographics (age, gender)", + "output_columns": ["age", "gender"], + }, + "Observation": { + "handler": "_flatten_observations", + "description": "Clinical observations (vitals, labs)", + "output_columns": "Dynamic based on observation codes", + "options": ["aggregation"], + }, + "Condition": { + "handler": "_flatten_conditions", + "description": "Conditions/diagnoses as binary indicators", + "output_columns": "Dynamic: condition_{code}_{display}", + }, + "MedicationStatement": { + "handler": "_flatten_medications", + "description": "Medications as binary indicators", + "output_columns": "Dynamic: medication_{code}_{display}", + }, +} + + +class BundleConverterConfig(BaseModel): + """Configuration for FHIR Bundle to DataFrame conversion. + + This configuration object controls which FHIR resources are processed and how + they are converted to DataFrame columns for ML model deployment. + + Attributes: + resources: List of FHIR resource types to include in the conversion + observation_aggregation: How to aggregate multiple observation values + age_calculation: Method for calculating patient age + event_date_source: Which resource to extract event date from + event_date_strategy: Which date to use when multiple dates exist + resource_options: Resource-specific configuration options (extensible) + + Example: + >>> config = BundleConverterConfig( + ... resources=["Patient", "Observation", "Condition"], + ... observation_aggregation="median" + ... ) + >>> df = bundle_to_dataframe(bundle, config=config) + """ + + # Core resources to include + resources: List[str] = ["Patient", "Observation"] + + # Observation-specific options + observation_aggregation: Literal["mean", "median", "max", "min", "last"] = "mean" + + # Patient age calculation + age_calculation: Literal["current_date", "event_date"] = "current_date" + event_date_source: Literal["Observation", "Encounter"] = "Observation" + event_date_strategy: Literal["earliest", "latest", "first"] = "earliest" + + # Resource-specific options (extensible for future use) + resource_options: Dict[str, Dict[str, Any]] = {} + + model_config = ConfigDict(extra="allow") + + @field_validator("resources") + @classmethod + def validate_resources(cls, v): + """Validate that requested resources are supported and warn about unsupported ones.""" + supported = get_supported_resources() + unsupported = [r for r in v if r not in supported] + if unsupported: + logger.warning( + f"Unsupported resources will be skipped: {unsupported}. " + f"Supported resources: {supported}" + ) + return v + + +def get_supported_resources() -> List[str]: + """Get list of supported FHIR resource types. 
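+
+    The list mirrors the keys of the ``SUPPORTED_RESOURCES`` registry above, so
+    any resource handler registered there is reported automatically.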
+ + Returns: + List of resource type names that can be converted to DataFrame columns + + Example: + >>> resources = get_supported_resources() + >>> print(resources) + ['Patient', 'Observation', 'Condition', 'MedicationStatement'] + """ + return list(SUPPORTED_RESOURCES.keys()) + + +def get_resource_info(resource_type: str) -> Dict[str, Any]: + """Get detailed information about a supported resource type. + + Args: + resource_type: FHIR resource type name + + Returns: + Dictionary with resource handler information, or empty dict if unsupported + + Example: + >>> info = get_resource_info("Observation") + >>> print(info["description"]) + 'Clinical observations (vitals, labs)' + """ + return SUPPORTED_RESOURCES.get(resource_type, {}) + + +def print_supported_resources() -> None: + """Print user-friendly list of supported FHIR resources for conversion. + + Example: + >>> from healthchain.fhir.converters import print_supported_resources + >>> print_supported_resources() + Supported FHIR Resources for ML Dataset Conversion: + + ✓ Patient + Patient demographics (age, gender) + Columns: age, gender + ... + """ + print("Supported FHIR Resources for ML Dataset Conversion:\n") + for resource, info in SUPPORTED_RESOURCES.items(): + print(f" ✓ {resource}") + print(f" {info['description']}") + if isinstance(info["output_columns"], list): + print(f" Columns: {', '.join(info['output_columns'])}") + else: + print(f" Columns: {info['output_columns']}") + if info.get("options"): + print(f" Options: {', '.join(info['options'])}") + print() + + +def _get_field(resource: Dict, field_name: str, default=None): + """Get field value from a dictionary.""" + return resource.get(field_name, default) + + +def _get_reference(field: Union[str, Dict[str, Any]]) -> Optional[str]: + """Extract reference string from a FHIR Reference field.""" + + if not field: + return None + + # Case 1: Already a string + if isinstance(field, str): + return field + + # Case 2: Dict with 'reference' field + return _get_field(field, "reference") + + +def extract_observation_value(observation: Dict) -> Optional[float]: + """Extract numeric value from an Observation dict. + + Handles different value types (valueQuantity, valueInteger, valueString) and + attempts to convert to float. + """ + + try: + value_quantity = _get_field(observation, "valueQuantity") + if value_quantity: + value = _get_field(value_quantity, "value") + if value is not None: + return float(value) + + value_int = _get_field(observation, "valueInteger") + if value_int is not None: + return float(value_int) + + value_str = _get_field(observation, "valueString") + if value_str: + return float(value_str) + + except (ValueError, TypeError): + pass + + return None + + +def extract_event_date( + resources: Dict[str, List[Any]], + source: str = "Observation", + strategy: str = "earliest", +) -> Optional[str]: + """Extract event date from patient resources for age calculation. + + Used primarily for MIMIC-IV on FHIR datasets where age is calculated + based on event dates rather than current date. 
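+    (MIMIC timestamps are date-shifted during de-identification, so an age
+    computed against today's date would not be meaningful.)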
+ + Args: + resources: Dictionary of patient resources (from group_bundle_by_patient) + source: Which resource type to extract date from ("Observation" or "Encounter") + strategy: Which date to use ("earliest", "latest", "first") + + Returns: + Event date in ISO format, or None if no suitable date found + + Example: + >>> resources = {"observations": [obs1, obs2], "encounters": [enc1]} + >>> event_date = extract_event_date(resources, source="Observation", strategy="earliest") + """ + if source == "Observation": + items = resources.get("observations", []) + date_field = "effectiveDateTime" + elif source == "Encounter": + items = resources.get("encounters", []) + date_field = "period" + else: + return None + + if not items: + return None + + dates = [] + for item in items: + if source == "Encounter": + # Extract start date from period + period = _get_field(item, date_field) + if period: + start = _get_field(period, "start") + if start: + dates.append(start) + else: + # Direct date field + date_value = _get_field(item, date_field) + if date_value: + dates.append(date_value) + + if not dates: + return None + + # Apply strategy + if strategy == "earliest": + return min(dates) + elif strategy == "latest": + return max(dates) + elif strategy == "first": + return dates[0] + else: + return min(dates) # Default to earliest + + +def group_bundle_by_patient( + bundle: Union[Bundle, Dict[str, Any]], +) -> Dict[str, Dict[str, List[Any]]]: + """Group Bundle resources by patient reference. + + Organizes FHIR resources in a Bundle by their associated patient, making it easier + to process patient-centric data. Accepts both Pydantic Bundle objects and dicts, + converts to dict internally for performance. + + Args: + bundle: FHIR Bundle resource (Pydantic object or dict) + + Returns: + Dictionary mapping patient references to their resources: + { + "Patient/123": { + "patient": Patient resource dict, + "observations": [Observation dict, ...], + "conditions": [Condition dict, ...], + ... 
+ } + } + """ + if not isinstance(bundle, dict): + bundle = bundle.model_dump() + + patient_data = defaultdict( + lambda: { + "patient": None, + "observations": [], + "conditions": [], + "medications": [], + "allergies": [], + "procedures": [], + "encounters": [], + "other": [], + } + ) + + # Get bundle entries + entries = _get_field(bundle, "entry") + if not entries: + return dict(patient_data) + + for entry in entries: + # Get resource from entry + resource = _get_field(entry, "resource") + if not resource: + continue + + resource_type = _get_field(resource, "resourceType") + resource_id = _get_field(resource, "id") + + if resource_type == "Patient": + patient_ref = f"Patient/{resource_id}" + patient_data[patient_ref]["patient"] = resource + + else: + # Get patient reference from resource + subject = _get_field(resource, "subject") + patient_field = _get_field(resource, "patient") + + patient_ref = _get_reference(subject) or _get_reference(patient_field) + + if patient_ref: + # Add to appropriate list based on resource type + if resource_type == "Observation": + patient_data[patient_ref]["observations"].append(resource) + elif resource_type == "Condition": + patient_data[patient_ref]["conditions"].append(resource) + elif resource_type == "MedicationStatement": + patient_data[patient_ref]["medications"].append(resource) + elif resource_type == "AllergyIntolerance": + patient_data[patient_ref]["allergies"].append(resource) + elif resource_type == "Procedure": + patient_data[patient_ref]["procedures"].append(resource) + elif resource_type == "Encounter": + patient_data[patient_ref]["encounters"].append(resource) + else: + patient_data[patient_ref]["other"].append(resource) + + return dict(patient_data) + + +def bundle_to_dataframe( + bundle: Union[Bundle, Dict[str, Any]], + config: Optional[BundleConverterConfig] = None, +) -> pd.DataFrame: + """Convert a FHIR Bundle to a pandas DataFrame. + + Converts FHIR resources to a tabular format with one row per patient. + Uses a configuration object to control which resources are processed and how. + + Args: + bundle: FHIR Bundle resource (object or dict) + config: BundleConverterConfig object specifying conversion behavior. + If None, uses default config (Patient + Observation with mean aggregation) + + Returns: + DataFrame with one row per patient and columns for each feature + + Example: + >>> from healthchain.fhir.converters import BundleConverterConfig + >>> + >>> # Default behavior + >>> df = bundle_to_dataframe(bundle) + >>> + >>> # Custom config + >>> config = BundleConverterConfig( + ... resources=["Patient", "Observation", "Condition"], + ... observation_aggregation="median", + ... age_calculation="event_date" + ... 
) + >>> df = bundle_to_dataframe(bundle, config=config) + """ + # Use default config if not provided + if config is None: + config = BundleConverterConfig() + + # Group resources by patient + patient_data = group_bundle_by_patient(bundle) + + if not patient_data: + return pd.DataFrame() + + # Build rows for each patient + rows = [] + for patient_ref, resources in patient_data.items(): + row = {"patient_ref": patient_ref} + + # Process each requested resource type using registry + for resource_type in config.resources: + handler_info = SUPPORTED_RESOURCES.get(resource_type) + + if not handler_info: + # Skip unsupported resources gracefully (already warned by validator) + continue + + # Get handler function by name + handler_name = handler_info["handler"] + handler = globals()[handler_name] + + # Call handler with standardized signature + features = handler(resources, config) + if features: + row.update(features) + + rows.append(row) + + return pd.DataFrame(rows) + + +def _flatten_patient( + resources: Dict[str, Any], config: BundleConverterConfig +) -> Dict[str, Any]: + """Flatten patient demographics into feature columns. + + Args: + resources: Dictionary of patient resources + config: Converter configuration + + Returns: + Dictionary with age and gender features + """ + if not resources["patient"]: + return {} + + features = {} + patient = resources["patient"] + + birth_date = _get_field(patient, "birthDate") + gender = _get_field(patient, "gender") + + # Calculate age based on configuration + if config.age_calculation == "event_date": + event_date = extract_event_date( + resources, config.event_date_source, config.event_date_strategy + ) + features["age"] = calculate_age_from_event_date(birth_date, event_date) + else: + features["age"] = calculate_age_from_birthdate(birth_date) + + features["gender"] = encode_gender(gender) + + return features + + +def _flatten_observations( + resources: Dict[str, Any], config: BundleConverterConfig +) -> Dict[str, float]: + """Flatten observations into feature columns. 
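+
+    Observations are grouped by their primary coding and aggregated according
+    to ``config.observation_aggregation``. Each code produces one column named
+    ``{code}_{display}`` with spaces replaced by underscores, e.g.
+    ``220045_Heart_Rate`` for the MIMIC heart-rate item.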
+ + Args: + resources: Dictionary of patient resources + config: Converter configuration + + Returns: + Dictionary with observation features + """ + observations = resources.get("observations", []) + aggregation = config.observation_aggregation + import numpy as np + + # Group observations by code + obs_by_code = defaultdict(list) + + for obs in observations: + code_field = _get_field(obs, "code") + if not code_field: + continue + + coding_array = _get_field(code_field, "coding") + if not coding_array or len(coding_array) == 0: + continue + + coding = coding_array[0] + code = _get_field(coding, "code") + display = _get_field(coding, "display") or code + system = _get_field(coding, "system") + + value = extract_observation_value(obs) + if value is not None: + obs_by_code[code].append( + { + "value": value, + "display": display, + "system": system, + } + ) + + # Aggregate and create feature columns + features = {} + for code, obs_list in obs_by_code.items(): + values = [item["value"] for item in obs_list] + display = obs_list[0]["display"] + + # Create column name: code_display + col_name = f"{code}_{display.replace(' ', '_')}" + + # Aggregate values + if aggregation == "mean": + features[col_name] = np.mean(values) + elif aggregation == "median": + features[col_name] = np.median(values) + elif aggregation == "max": + features[col_name] = np.max(values) + elif aggregation == "min": + features[col_name] = np.min(values) + elif aggregation == "last": + features[col_name] = values[-1] + else: + features[col_name] = np.mean(values) + + return features + + +def _flatten_conditions( + resources: Dict[str, Any], config: BundleConverterConfig +) -> Dict[str, int]: + """Flatten conditions into binary indicator columns. + + Args: + resources: Dictionary of patient resources + config: Converter configuration + + Returns: + Dictionary with condition indicator features + """ + conditions = resources.get("conditions", []) + features = {} + + for condition in conditions: + code_field = _get_field(condition, "code") + if not code_field: + continue + + coding_array = _get_field(code_field, "coding") + if not coding_array or len(coding_array) == 0: + continue + + # Get primary coding + coding = coding_array[0] + code = _get_field(coding, "code") + display = _get_field(coding, "display") or code + + # Create column name: condition_code_display + col_name = f"condition_{code}_{display.replace(' ', '_')}" + features[col_name] = 1 + + return features + + +def _flatten_medications( + resources: Dict[str, Any], config: BundleConverterConfig +) -> Dict[str, int]: + """Flatten medications into binary indicator columns. 
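+
+    Each medication with a primary coding produces a column named
+    ``medication_{code}_{display}`` (spaces replaced by underscores) set to 1.
+    Patients without that medication have no such key, so the column is NaN
+    for them in the assembled DataFrame.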
+ + Args: + resources: Dictionary of patient resources + config: Converter configuration + + Returns: + Dictionary with medication indicator features + """ + medications = resources.get("medications", []) + features = {} + + for med in medications: + medication = _get_field(med, "medication") + if not medication: + continue + + med_concept = _get_field(medication, "concept") + if not med_concept: + continue + + coding_array = _get_field(med_concept, "coding") + if not coding_array or len(coding_array) == 0: + continue + + # Get primary coding + coding = coding_array[0] + code = _get_field(coding, "code") + display = _get_field(coding, "display") or code + + # Create column name: medication_code_display + col_name = f"medication_{code}_{display.replace(' ', '_')}" + features[col_name] = 1 + + return features diff --git a/healthchain/fhir/elementhelpers.py b/healthchain/fhir/elementhelpers.py new file mode 100644 index 00000000..c4b4532f --- /dev/null +++ b/healthchain/fhir/elementhelpers.py @@ -0,0 +1,109 @@ +"""FHIR element creation functions. + +This module provides convenience functions for creating FHIR elements that are used +as building blocks within FHIR resources (e.g., CodeableConcept, Attachment, Coding). +""" + +import logging +import base64 +import datetime + +from typing import Optional, List, Dict, Any +from fhir.resources.codeableconcept import CodeableConcept +from fhir.resources.codeablereference import CodeableReference +from fhir.resources.coding import Coding +from fhir.resources.attachment import Attachment + +logger = logging.getLogger(__name__) + + +def create_single_codeable_concept( + code: str, + display: Optional[str] = None, + system: Optional[str] = "http://snomed.info/sct", +) -> CodeableConcept: + """ + Create a minimal FHIR CodeableConcept with a single coding. + + Args: + code: REQUIRED. The code value from the code system + display: The display name for the code + system: The code system (default: SNOMED CT) + + Returns: + CodeableConcept: A FHIR CodeableConcept resource with a single coding + """ + return CodeableConcept(coding=[Coding(system=system, code=code, display=display)]) + + +def create_single_reaction( + code: str, + display: Optional[str] = None, + system: Optional[str] = "http://snomed.info/sct", + severity: Optional[str] = None, +) -> List[Dict[str, Any]]: + """Create a minimal FHIR Reaction with a single coding. + + Creates a FHIR Reaction object with a single manifestation coding. The manifestation + describes the clinical reaction that was observed. The severity indicates how severe + the reaction was. + + Args: + code: REQUIRED. The code value from the code system representing the reaction manifestation + display: The display name for the manifestation code + system: The code system for the manifestation code (default: SNOMED CT) + severity: The severity of the reaction (mild, moderate, severe) + + Returns: + A list containing a single FHIR Reaction dictionary with manifestation and severity fields + """ + return [ + { + "manifestation": [ + CodeableReference( + concept=CodeableConcept( + coding=[Coding(system=system, code=code, display=display)] + ) + ) + ], + "severity": severity, + } + ] + + +def create_single_attachment( + content_type: Optional[str] = None, + data: Optional[str] = None, + url: Optional[str] = None, + title: Optional[str] = "Attachment created by HealthChain", +) -> Attachment: + """Create a minimal FHIR Attachment. + + Creates a FHIR Attachment resource with basic fields. Either data or url should be provided. 
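+    When neither is supplied, a warning is logged and an attachment with empty
+    content is still returned.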
+ If data is provided, it will be base64 encoded. + + Args: + content_type: The MIME type of the content + data: The actual data content to be base64 encoded + url: The URL where the data can be found + title: A title for the attachment (default: "Attachment created by HealthChain") + + Returns: + Attachment: A FHIR Attachment resource with basic metadata and content + """ + + if not data and not url: + logger.warning("No data or url provided for attachment") + + if data: + data = base64.b64encode(data.encode("utf-8")).decode("utf-8") + + return Attachment( + contentType=content_type, + data=data, + url=url, + title=title, + creation=datetime.datetime.now(datetime.timezone.utc).strftime( + "%Y-%m-%dT%H:%M:%S%z" + ), + ) diff --git a/healthchain/fhir/readers.py b/healthchain/fhir/readers.py new file mode 100644 index 00000000..7d7bbd06 --- /dev/null +++ b/healthchain/fhir/readers.py @@ -0,0 +1,137 @@ +"""FHIR conversion and reading functions. + +This module provides functions for converting between different FHIR representations +and reading data from FHIR resources. +""" + +import logging +import importlib + +from typing import Optional, Dict, Any, List +from fhir.resources.resource import Resource +from fhir.resources.documentreference import DocumentReference + +logger = logging.getLogger(__name__) + + +def create_resource_from_dict( + resource_dict: Dict, resource_type: str +) -> Optional[Resource]: + """Create a FHIR resource instance from a dictionary + + Args: + resource_dict: Dictionary representation of the resource + resource_type: Type of FHIR resource to create + + Returns: + Optional[Resource]: FHIR resource instance or None if creation failed + """ + try: + resource_module = importlib.import_module( + f"fhir.resources.{resource_type.lower()}" + ) + resource_class = getattr(resource_module, resource_type) + return resource_class(**resource_dict) + except Exception as e: + logger.error(f"Failed to create FHIR resource: {str(e)}") + return None + + +def convert_prefetch_to_fhir_objects( + prefetch_dict: Dict[str, Any], +) -> Dict[str, Resource]: + """Convert a dictionary of FHIR resource dicts to FHIR Resource objects. + + Takes a prefetch dictionary where values may be either dict representations of FHIR + resources or already instantiated FHIR Resource objects, and ensures all values are + FHIR Resource objects. + + Args: + prefetch_dict: Dictionary mapping keys to FHIR resource dicts or objects + + Returns: + Dict[str, Resource]: Dictionary with same keys but all values as FHIR Resource objects + + Example: + >>> prefetch = { + ... "patient": {"resourceType": "Patient", "id": "123"}, + ... "condition": Condition(id="456", ...) + ... 
} + >>> fhir_objects = convert_prefetch_to_fhir_objects(prefetch) + >>> isinstance(fhir_objects["patient"], Patient) # True + >>> isinstance(fhir_objects["condition"], Condition) # True + """ + from fhir.resources import get_fhir_model_class + + result: Dict[str, Resource] = {} + + for key, resource_data in prefetch_dict.items(): + if isinstance(resource_data, dict): + # Convert dict to FHIR Resource object + resource_type = resource_data.get("resourceType") + if resource_type: + try: + resource_class = get_fhir_model_class(resource_type) + result[key] = resource_class(**resource_data) + except Exception as e: + logger.warning( + f"Failed to convert {resource_type} to FHIR object: {e}" + ) + result[key] = resource_data + else: + logger.warning( + f"No resourceType found for key '{key}', keeping as dict" + ) + result[key] = resource_data + elif isinstance(resource_data, Resource): + # Already a FHIR object + result[key] = resource_data + else: + logger.warning(f"Unexpected type for key '{key}': {type(resource_data)}") + result[key] = resource_data + + return result + + +def read_content_attachment( + document_reference: DocumentReference, + include_data: bool = True, +) -> Optional[List[Dict[str, Any]]]: + """Read the attachments in a human readable format from a FHIR DocumentReference content field. + + Args: + document_reference: The FHIR DocumentReference resource + include_data: Whether to include the data of the attachments. If true, the data will be also be decoded (default: True) + + Returns: + Optional[List[Dict[str, Any]]]: List of dictionaries containing attachment data and metadata, + or None if no attachments are found: + [ + { + "data": str, + "metadata": Dict[str, Any] + } + ] + """ + if not document_reference.content: + return None + + attachments = [] + for content in document_reference.content: + attachment = content.attachment + result = {} + + if include_data: + result["data"] = ( + attachment.url if attachment.url else attachment.data.decode("utf-8") + ) + + result["metadata"] = { + "content_type": attachment.contentType, + "title": attachment.title, + "creation": attachment.creation, + } + + attachments.append(result) + + return attachments diff --git a/healthchain/fhir/helpers.py b/healthchain/fhir/resourcehelpers.py similarity index 62% rename from healthchain/fhir/helpers.py rename to healthchain/fhir/resourcehelpers.py index 95114272..bb278cef 100644 --- a/healthchain/fhir/helpers.py +++ b/healthchain/fhir/resourcehelpers.py @@ -1,214 +1,43 @@ -"""Convenience functions for creating minimal FHIR resources. +"""FHIR resource creation and modification functions. + +This module provides convenience functions for creating and modifying FHIR resources. + Patterns: -- create_*(): create a new FHIR resource with sensible defaults - useful for dev, use with caution -- add_*(): add data to resources with list fields safely (e.g. coding) -- set_*(): set the field of specific resources with soft validation (e.g. category) -- read_*(): return a human readable format of the data in a resource (e.g. attachments) +- create_*(): create a new FHIR resource with sensible defaults +- set_*(): set specific fields of resources with soft validation +- add_*(): add data to resources safely + +Parameters marked REQUIRED are required by FHIR specification. 
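+
+Example (a minimal sketch):
+    >>> from healthchain.fhir import create_condition
+    >>> condition = create_condition(
+    ...     subject="Patient/123", code="38341003", display="Hypertension"
+    ... )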
""" import logging -import base64 import datetime -import uuid -import importlib -from typing import Optional, List, Dict, Any +from typing import List, Optional, Dict, Any +from fhir.resources.coding import Coding from fhir.resources.condition import Condition +from fhir.resources.identifier import Identifier from fhir.resources.medicationstatement import MedicationStatement from fhir.resources.allergyintolerance import AllergyIntolerance from fhir.resources.documentreference import DocumentReference -from fhir.resources.codeableconcept import CodeableConcept -from fhir.resources.codeablereference import CodeableReference -from fhir.resources.coding import Coding -from fhir.resources.attachment import Attachment +from fhir.resources.observation import Observation from fhir.resources.resource import Resource +from fhir.resources.riskassessment import RiskAssessment +from fhir.resources.patient import Patient +from fhir.resources.quantity import Quantity +from fhir.resources.codeableconcept import CodeableConcept from fhir.resources.reference import Reference from fhir.resources.meta import Meta +from healthchain.fhir.elementhelpers import ( + create_single_codeable_concept, + create_single_attachment, +) +from healthchain.fhir.utilities import _generate_id logger = logging.getLogger(__name__) -def _generate_id() -> str: - """Generate a unique ID prefixed with 'hc-'. - - Returns: - str: A unique ID string prefixed with 'hc-' - """ - return f"hc-{str(uuid.uuid4())}" - - -def create_resource_from_dict( - resource_dict: Dict, resource_type: str -) -> Optional[Resource]: - """Create a FHIR resource instance from a dictionary - - Args: - resource_dict: Dictionary representation of the resource - resource_type: Type of FHIR resource to create - - Returns: - Optional[Resource]: FHIR resource instance or None if creation failed - """ - try: - resource_module = importlib.import_module( - f"fhir.resources.{resource_type.lower()}" - ) - resource_class = getattr(resource_module, resource_type) - return resource_class(**resource_dict) - except Exception as e: - logger.error(f"Failed to create FHIR resource: {str(e)}") - return None - - -def convert_prefetch_to_fhir_objects( - prefetch_dict: Dict[str, Any], -) -> Dict[str, Resource]: - """Convert a dictionary of FHIR resource dicts to FHIR Resource objects. - - Takes a prefetch dictionary where values may be either dict representations of FHIR - resources or already instantiated FHIR Resource objects, and ensures all values are - FHIR Resource objects. - - Args: - prefetch_dict: Dictionary mapping keys to FHIR resource dicts or objects - - Returns: - Dict[str, Resource]: Dictionary with same keys but all values as FHIR Resource objects - - Example: - >>> prefetch = { - ... "patient": {"resourceType": "Patient", "id": "123"}, - ... "condition": Condition(id="456", ...) - ... 
} - >>> fhir_objects = convert_prefetch_to_fhir_objects(prefetch) - >>> isinstance(fhir_objects["patient"], Patient) # True - >>> isinstance(fhir_objects["condition"], Condition) # True - """ - from fhir.resources import get_fhir_model_class - - result: Dict[str, Resource] = {} - - for key, resource_data in prefetch_dict.items(): - if isinstance(resource_data, dict): - # Convert dict to FHIR Resource object - resource_type = resource_data.get("resourceType") - if resource_type: - try: - resource_class = get_fhir_model_class(resource_type) - result[key] = resource_class(**resource_data) - except Exception as e: - logger.warning( - f"Failed to convert {resource_type} to FHIR object: {e}" - ) - result[key] = resource_data - else: - logger.warning( - f"No resourceType found for key '{key}', keeping as dict" - ) - result[key] = resource_data - elif isinstance(resource_data, Resource): - # Already a FHIR object - result[key] = resource_data - else: - logger.warning(f"Unexpected type for key '{key}': {type(resource_data)}") - result[key] = resource_data - - return result - - -def create_single_codeable_concept( - code: str, - display: Optional[str] = None, - system: Optional[str] = "http://snomed.info/sct", -) -> CodeableConcept: - """ - Create a minimal FHIR CodeableConcept with a single coding. - - Args: - code: REQUIRED. The code value from the code system - display: The display name for the code - system: The code system (default: SNOMED CT) - - Returns: - CodeableConcept: A FHIR CodeableConcept resource with a single coding - """ - return CodeableConcept(coding=[Coding(system=system, code=code, display=display)]) - - -def create_single_reaction( - code: str, - display: Optional[str] = None, - system: Optional[str] = "http://snomed.info/sct", - severity: Optional[str] = None, -) -> List[Dict[str, Any]]: - """Create a minimal FHIR Reaction with a single coding. - - Creates a FHIR Reaction object with a single manifestation coding. The manifestation - describes the clinical reaction that was observed. The severity indicates how severe - the reaction was. - - Args: - code: REQUIRED. The code value from the code system representing the reaction manifestation - display: The display name for the manifestation code - system: The code system for the manifestation code (default: SNOMED CT) - severity: The severity of the reaction (mild, moderate, severe) - - Returns: - A list containing a single FHIR Reaction dictionary with manifestation and severity fields - """ - return [ - { - "manifestation": [ - CodeableReference( - concept=CodeableConcept( - coding=[Coding(system=system, code=code, display=display)] - ) - ) - ], - "severity": severity, - } - ] - - -def create_single_attachment( - content_type: Optional[str] = None, - data: Optional[str] = None, - url: Optional[str] = None, - title: Optional[str] = "Attachment created by HealthChain", -) -> Attachment: - """Create a minimal FHIR Attachment. - - Creates a FHIR Attachment resource with basic fields. Either data or url should be provided. - If data is provided, it will be base64 encoded. 
- - Args: - content_type: The MIME type of the content - data: The actual data content to be base64 encoded - url: The URL where the data can be found - title: A title for the attachment (default: "Attachment created by HealthChain") - - Returns: - Attachment: A FHIR Attachment resource with basic metadata and content - """ - - if not data and not url: - logger.warning("No data or url provided for attachment") - - if data: - data = base64.b64encode(data.encode("utf-8")).decode("utf-8") - - return Attachment( - contentType=content_type, - data=data, - url=url, - title=title, - creation=datetime.datetime.now(datetime.timezone.utc).strftime( - "%Y-%m-%dT%H:%M:%S%z" - ), - ) - - def create_condition( subject: str, clinical_status: str = "active", @@ -321,6 +150,187 @@ def create_allergy_intolerance( return allergy +def create_value_quantity_observation( + code: str, + value: float, + unit: str, + status: str = "final", + subject: Optional[str] = None, + system: str = "http://loinc.org", + display: Optional[str] = None, + effective_datetime: Optional[str] = None, +) -> Observation: + """ + Create a minimal FHIR Observation for vital signs or laboratory values. + If you need to create a more complex observation, use the FHIR Observation resource directly. + https://hl7.org/fhir/observation.html + + Args: + status: REQUIRED. The status of the observation (default: "final") + code: REQUIRED. The observation code (e.g., LOINC code for the measurement) + value: The numeric value of the observation + unit: The unit of measure (e.g., "beats/min", "mg/dL") + system: The code system for the observation code (default: LOINC) + display: The display name for the observation code + effective_datetime: When the observation was made (ISO format). Uses current time if not provided. + subject: Reference to the patient (e.g. "Patient/123") + + Returns: + Observation: A FHIR Observation resource with an auto-generated ID prefixed with 'hc-' + """ + if not effective_datetime: + effective_datetime = datetime.datetime.now(datetime.timezone.utc).strftime( + "%Y-%m-%dT%H:%M:%S%z" + ) + subject_ref = None + if subject is not None: + subject_ref = Reference(reference=subject) + + observation = Observation( + id=_generate_id(), + status=status, + code=create_single_codeable_concept(code, display, system), + subject=subject_ref, + effectiveDateTime=effective_datetime, + valueQuantity=Quantity( + value=value, unit=unit, system="http://unitsofmeasure.org", code=unit + ), + ) + + return observation + + +def create_patient( + gender: Optional[str] = None, + birth_date: Optional[str] = None, + identifier: Optional[str] = None, + identifier_system: Optional[str] = "http://hospital.example.org", +) -> Patient: + """ + Create a minimal FHIR Patient resource with basic gender and birthdate + If you need to create a more complex patient, use the FHIR Patient resource directly + https://hl7.org/fhir/patient.html (No required fields). 
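+
+    Example:
+        >>> patient = create_patient(gender="female", birth_date="1980-01-01")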
+ + Args: + gender: Administrative gender (male, female, other, unknown) + birth_date: Birth date in YYYY-MM-DD format + identifier: Optional identifier value for the patient (e.g., MRN) + identifier_system: The system for the identifier (default: "http://hospital.example.org") + + Returns: + Patient: A FHIR Patient resource with an auto-generated ID prefixed with 'hc-' + """ + patient_id = _generate_id() + + patient_data = {"id": patient_id} + + if birth_date: + patient_data["birthDate"] = birth_date + + if gender: + patient_data["gender"] = gender.lower() + + if identifier: + patient_data["identifier"] = [ + Identifier( + system=identifier_system, + value=identifier, + ) + ] + + patient = Patient(**patient_data) + return patient + + +def create_risk_assessment_from_prediction( + subject: str, + prediction: Dict[str, Any], + status: str = "final", + method: Optional[CodeableConcept] = None, + basis: Optional[List[Reference]] = None, + comment: Optional[str] = None, + occurrence_datetime: Optional[str] = None, +) -> RiskAssessment: + """ + Create a FHIR RiskAssessment from ML model prediction output. + If you need to create a more complex risk assessment, use the FHIR RiskAssessment resource directly. + https://hl7.org/fhir/riskassessment.html + + Args: + subject: REQUIRED. Reference to the patient (e.g. "Patient/123") + prediction: Dictionary containing prediction details with keys: + - outcome: CodeableConcept or dict with code, display, system for the predicted outcome + - probability: float between 0 and 1 representing the risk probability + - qualitative_risk: Optional str indicating risk level (e.g., "high", "moderate", "low") + status: REQUIRED. The status of the assessment (default: "final") + method: Optional CodeableConcept describing the assessment method/model used + basis: Optional list of References to observations or other resources used as input + comment: Optional text comment about the assessment + + occurrence_datetime: When the assessment was made (ISO format). Uses current time if not provided. + + Returns: + RiskAssessment: A FHIR RiskAssessment resource with an auto-generated ID prefixed with 'hc-' + + Example: + >>> prediction = { + ... "outcome": {"code": "A41.9", "display": "Sepsis", "system": "http://hl7.org/fhir/sid/icd-10"}, + ... "probability": 0.85, + ... "qualitative_risk": "high" + ... 
} + >>> risk = create_risk_assessment_from_prediction("Patient/123", prediction) + """ + if not occurrence_datetime: + occurrence_datetime = datetime.datetime.now(datetime.timezone.utc).strftime( + "%Y-%m-%dT%H:%M:%S%z" + ) + + outcome = prediction.get("outcome") + if isinstance(outcome, dict): + outcome_concept = create_single_codeable_concept( + code=outcome["code"], + display=outcome.get("display"), + system=outcome.get("system", "http://snomed.info/sct"), + ) + else: + outcome_concept = outcome + + prediction_data = { + "outcome": outcome_concept, + } + + if "probability" in prediction: + prediction_data["probabilityDecimal"] = prediction["probability"] + + if "qualitative_risk" in prediction: + prediction_data["qualitativeRisk"] = create_single_codeable_concept( + code=prediction["qualitative_risk"], + display=prediction["qualitative_risk"].capitalize(), + system="http://terminology.hl7.org/CodeSystem/risk-probability", + ) + + risk_assessment_data = { + "id": _generate_id(), + "status": status, + "subject": Reference(reference=subject), + "occurrenceDateTime": occurrence_datetime, + "prediction": [prediction_data], + } + + if method: + risk_assessment_data["method"] = method + + if basis: + risk_assessment_data["basis"] = basis + + if comment: + risk_assessment_data["note"] = [{"text": comment}] + + risk_assessment = RiskAssessment(**risk_assessment_data) + + return risk_assessment + + def create_document_reference( data: Optional[Any] = None, url: Optional[str] = None, @@ -566,47 +576,3 @@ def add_coding_to_codeable_concept( codeable_concept.coding.append(Coding(system=system, code=code, display=display)) return codeable_concept - - -def read_content_attachment( - document_reference: DocumentReference, - include_data: bool = True, -) -> Optional[List[Dict[str, Any]]]: - """Read the attachments in a human readable format from a FHIR DocumentReference content field. - - Args: - document_reference: The FHIR DocumentReference resource - include_data: Whether to include the data of the attachments. If true, the data will be also be decoded (default: True) - - Returns: - Optional[List[Dict[str, Any]]]: List of dictionaries containing attachment data and metadata, - or None if no attachments are found: - [ - { - "data": str, - "metadata": Dict[str, Any] - } - ] - """ - if not document_reference.content: - return None - - attachments = [] - for content in document_reference.content: - attachment = content.attachment - result = {} - - if include_data: - result["data"] = ( - attachment.url if attachment.url else attachment.data.decode("utf-8") - ) - - result["metadata"] = { - "content_type": attachment.contentType, - "title": attachment.title, - "creation": attachment.creation, - } - - attachments.append(result) - - return attachments diff --git a/healthchain/fhir/utilities.py b/healthchain/fhir/utilities.py new file mode 100644 index 00000000..31788d42 --- /dev/null +++ b/healthchain/fhir/utilities.py @@ -0,0 +1,117 @@ +"""FHIR utility functions. + +This module provides utility functions for common operations like ID generation, +age calculation, and gender encoding. +""" + +import datetime +import uuid +from typing import Optional + + +def _generate_id() -> str: + """Generate a unique ID prefixed with 'hc-'. + + Returns: + str: A unique ID string prefixed with 'hc-' + """ + return f"hc-{str(uuid.uuid4())}" + + +def calculate_age_from_birthdate(birth_date: str) -> Optional[int]: + """Calculate age in years from a birth date string.
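+
+    Example (output depends on today's date, so the value shown is illustrative):
+        >>> calculate_age_from_birthdate("1990-06-15")  # doctest: +SKIP
+        35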
+ + Args: + birth_date: Birth date in ISO format (YYYY-MM-DD or full ISO datetime) + + Returns: + Age in years, or None if birth date is invalid + """ + if not birth_date: + return None + + try: + if isinstance(birth_date, str): + # Remove timezone info for simpler parsing + birth_date_clean = birth_date.replace("Z", "").split("T")[0] + birth_dt = datetime.datetime.strptime(birth_date_clean, "%Y-%m-%d") + else: + birth_dt = birth_date + + # Calculate age + today = datetime.datetime.now() + age = today.year - birth_dt.year + + # Adjust if birthday hasn't occurred this year + if (today.month, today.day) < (birth_dt.month, birth_dt.day): + age -= 1 + + return age + except (ValueError, AttributeError, TypeError): + return None + + +def calculate_age_from_event_date(birth_date: str, event_date: str) -> Optional[int]: + """Calculate age in years from birth date and event date (MIMIC-IV style). + + Uses the formula: age = year(eventDate) - year(birthDate) + This matches MIMIC-IV on FHIR de-identified age calculation. + + Args: + birth_date: Birth date in ISO format (YYYY-MM-DD or full ISO datetime) + event_date: Event date in ISO format (YYYY-MM-DD or full ISO datetime) + + Returns: + Age in years based on year difference, or None if dates are invalid + + Example: + >>> calculate_age_from_event_date("1990-06-15", "2020-03-10") + 30 + """ + if not birth_date or not event_date: + return None + + try: + # Parse birth date + if isinstance(birth_date, str): + birth_date_clean = birth_date.replace("Z", "").split("T")[0] + birth_year = int(birth_date_clean.split("-")[0]) + else: + birth_year = birth_date.year + + # Parse event date + if isinstance(event_date, str): + event_date_clean = event_date.replace("Z", "").split("T")[0] + event_year = int(event_date_clean.split("-")[0]) + else: + event_year = event_date.year + + # MIMIC-IV style: simple year difference + age = event_year - birth_year + + return age + except (ValueError, AttributeError, TypeError, IndexError): + return None + + +def encode_gender(gender: str) -> Optional[int]: + """Encode gender as integer for ML models. + + Standard encoding: Male=1, Female=0, Other/Unknown=None + + Args: + gender: Gender string (case-insensitive) + + Returns: + Encoded gender (1 for male, 0 for female, None for other/unknown) + """ + if not gender: + return None + + gender_lower = gender.lower() + if gender_lower in ["male", "m"]: + return 1 + elif gender_lower in ["female", "f"]: + return 0 + else: + return None diff --git a/healthchain/io/__init__.py b/healthchain/io/__init__.py index 52bc38ef..2e33328c 100644 --- a/healthchain/io/__init__.py +++ b/healthchain/io/__init__.py @@ -1,15 +1,32 @@ -from .containers import DataContainer, Document, Tabular -from .base import BaseAdapter +"""IO module for data containers, adapters, and mappers. 
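+
+A minimal import sketch (these names are re-exported below):
+
+    from healthchain.io import Dataset, FHIRFeatureMapper, TimeWindow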
+ +This module provides: +- Containers: Data structures for documents and datasets +- Adapters: Convert external formats (CDA, CDS Hooks) to/from HealthChain +- Mappers: Transform clinical data between formats (FHIR to pandas, FHIR versions) +""" + +from .containers import DataContainer, Document, Dataset, FeatureSchema +from .adapters.base import BaseAdapter from .adapters.cdaadapter import CdaAdapter from .adapters.cdsfhiradapter import CdsFhirAdapter +from .mappers import BaseMapper, FHIRFeatureMapper +from .types import TimeWindow, ValidationResult __all__ = [ # Containers "DataContainer", "Document", - "Tabular", + "Dataset", + "FeatureSchema", # Adapters "BaseAdapter", "CdaAdapter", "CdsFhirAdapter", + # Mappers + "BaseMapper", + "FHIRFeatureMapper", + # Types + "TimeWindow", + "ValidationResult", ] diff --git a/healthchain/io/adapters/__init__.py b/healthchain/io/adapters/__init__.py index 6fb1012a..cc698d79 100644 --- a/healthchain/io/adapters/__init__.py +++ b/healthchain/io/adapters/__init__.py @@ -7,5 +7,6 @@ from .cdaadapter import CdaAdapter from .cdsfhiradapter import CdsFhirAdapter +from .base import BaseAdapter -__all__ = ["CdaAdapter", "CdsFhirAdapter"] +__all__ = ["CdaAdapter", "CdsFhirAdapter", "BaseAdapter"] diff --git a/healthchain/io/base.py b/healthchain/io/adapters/base.py similarity index 100% rename from healthchain/io/base.py rename to healthchain/io/adapters/base.py diff --git a/healthchain/io/adapters/cdaadapter.py b/healthchain/io/adapters/cdaadapter.py index 8271f52e..91d5d456 100644 --- a/healthchain/io/adapters/cdaadapter.py +++ b/healthchain/io/adapters/cdaadapter.py @@ -2,7 +2,7 @@ from typing import Optional from healthchain.io.containers import Document -from healthchain.io.base import BaseAdapter +from healthchain.io.adapters.base import BaseAdapter from healthchain.interop import create_interop, FormatType, InteropEngine from healthchain.models.requests.cdarequest import CdaRequest from healthchain.models.responses.cdaresponse import CdaResponse diff --git a/healthchain/io/adapters/cdsfhiradapter.py b/healthchain/io/adapters/cdsfhiradapter.py index 7d3be0e7..42d36572 100644 --- a/healthchain/io/adapters/cdsfhiradapter.py +++ b/healthchain/io/adapters/cdsfhiradapter.py @@ -4,7 +4,7 @@ from fhir.resources.documentreference import DocumentReference from healthchain.io.containers import Document -from healthchain.io.base import BaseAdapter +from healthchain.io.adapters.base import BaseAdapter from healthchain.models.requests.cdsrequest import CDSRequest from healthchain.models.responses.cdsresponse import CDSResponse from healthchain.fhir import read_content_attachment, convert_prefetch_to_fhir_objects diff --git a/healthchain/io/containers/__init__.py b/healthchain/io/containers/__init__.py index 05db1a95..d11372a8 100644 --- a/healthchain/io/containers/__init__.py +++ b/healthchain/io/containers/__init__.py @@ -1,5 +1,6 @@ from .base import DataContainer, BaseDocument from .document import Document -from .tabular import Tabular +from .dataset import Dataset +from .featureschema import FeatureSchema -__all__ = ["DataContainer", "BaseDocument", "Document", "Tabular"] +__all__ = ["DataContainer", "BaseDocument", "Document", "Dataset", "FeatureSchema"] diff --git a/healthchain/io/containers/dataset.py b/healthchain/io/containers/dataset.py new file mode 100644 index 00000000..39740be5 --- /dev/null +++ b/healthchain/io/containers/dataset.py @@ -0,0 +1,307 @@ +import pandas as pd +import numpy as np + +from dataclasses import dataclass +from pathlib import 
Path +from typing import Any, Dict, Iterator, List, Union, Optional + +from fhir.resources.bundle import Bundle +from fhir.resources.riskassessment import RiskAssessment + +from healthchain.io.containers.base import DataContainer +from healthchain.io.containers.featureschema import FeatureSchema +from healthchain.io.mappers.fhirfeaturemapper import FHIRFeatureMapper +from healthchain.io.types import ValidationResult +from healthchain.fhir.resourcehelpers import ( + create_risk_assessment_from_prediction, + create_single_codeable_concept, +) + + +@dataclass +class Dataset(DataContainer[pd.DataFrame]): + """ + A container for tabular data optimized for ML inference, lightweight wrapper around a pandas DataFrame. + + Methods: + from_csv: Load Dataset from CSV. + from_dict: Load Dataset from dict. + from_fhir_bundle: Create Dataset from FHIR Bundle and schema. + to_csv: Save Dataset to CSV. + to_risk_assessment: Convert predictions to FHIR RiskAssessment. + """ + + def __post_init__(self): + if not isinstance(self.data, pd.DataFrame): + raise TypeError("data must be a pandas DataFrame") + + @property + def columns(self) -> List[str]: + return list(self.data.columns) + + @property + def index(self) -> pd.Index: + return self.data.index + + @property + def dtypes(self) -> Dict[str, str]: + return {col: str(dtype) for col, dtype in self.data.dtypes.items()} + + def column_count(self) -> int: + return len(self.columns) + + def row_count(self) -> int: + return len(self.data) + + def get_dtype(self, column: str) -> str: + return str(self.data[column].dtype) + + def __iter__(self) -> Iterator[str]: + return iter(self.columns) + + def __len__(self) -> int: + return self.row_count() + + def describe(self) -> str: + return f"Dataset with {self.column_count()} columns and {self.row_count()} rows" + + def remove_column(self, name: str) -> None: + self.data.drop(columns=[name], inplace=True) + + @classmethod + def from_csv(cls, path: str, **kwargs) -> "Dataset": + return cls(pd.read_csv(path, **kwargs)) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Dataset": + df = pd.DataFrame(data["data"]) + return cls(df) + + def to_csv(self, path: str, **kwargs) -> None: + self.data.to_csv(path, **kwargs) + + @classmethod + def from_fhir_bundle( + cls, + bundle: Union[Bundle, Dict[str, Any]], + schema: Union[str, Path, FeatureSchema], + aggregation: str = "mean", + ) -> "Dataset": + """Create Dataset from a FHIR Bundle using a feature schema. + + Extracts features from FHIR resources according to the schema specification, + converting FHIR data to a pandas DataFrame suitable for ML inference. + + Args: + bundle: FHIR Bundle resource (object or dict) + schema: FeatureSchema object, or path to YAML schema file + aggregation: How to aggregate multiple observation values (default: "mean") + Options: "mean", "median", "max", "min", "last" (default: "mean") + + Returns: + Dataset container with extracted features + + Example: + >>> from fhir.resources.bundle import Bundle + >>> bundle = Bundle(**patient_data) + >>> dataset = Dataset.from_fhir_bundle( + ... bundle, + ... schema="healthchain/configs/features/sepsis_vitals.yaml" + ... 
) + >>> df = dataset.data + """ + # Load schema if path provided + if isinstance(schema, (str, Path)): + schema = FeatureSchema.from_yaml(schema) + + # Extract features using mapper + mapper = FHIRFeatureMapper(schema) + df = mapper.extract_features(bundle, aggregation=aggregation) + + return cls(df) + + def validate( + self, schema: FeatureSchema, raise_on_error: bool = False + ) -> ValidationResult: + """Validate DataFrame against a feature schema. + + Checks that required features are present and have correct data types. + + Args: + schema: FeatureSchema to validate against + raise_on_error: Whether to raise exception on validation failure + + Returns: + ValidationResult with validation status and details + + Raises: + ValueError: If raise_on_error is True and validation fails + + Example: + >>> schema = FeatureSchema.from_yaml("configs/features/sepsis_vitals.yaml") + >>> result = dataset.validate(schema) + >>> if not result.valid: + ... print(result.errors) + """ + result = ValidationResult(valid=True) + + # Check for missing required features + required = schema.get_required_features() + missing = [f for f in required if f not in self.data.columns] + + for feature in missing: + result.add_missing_feature(feature) + + # Check data types for present features + for feature_name, mapping in schema.features.items(): + if feature_name in self.data.columns: + actual_dtype = str(self.data[feature_name].dtype) + expected_dtype = mapping.dtype + + # Check for type mismatches (allow some flexibility) + if not self._dtypes_compatible(actual_dtype, expected_dtype): + result.add_type_mismatch(feature_name, expected_dtype, actual_dtype) + + # Warn about optional missing features + optional = set(schema.get_feature_names()) - set(required) + missing_optional = [f for f in optional if f not in self.data.columns] + + for feature in missing_optional: + result.add_warning(f"Optional feature '{feature}' is missing") + + if raise_on_error and not result.valid: + raise ValueError(str(result)) + + return result + + def _dtypes_compatible(self, actual: str, expected: str) -> bool: + """Check if actual dtype is compatible with expected dtype. + + Args: + actual: Actual dtype string + expected: Expected dtype string + + Returns: + True if dtypes are compatible + """ + # Handle numeric types flexibly + numeric_types = {"int64", "int32", "float64", "float32"} + if expected in numeric_types and actual in numeric_types: + return True + + # Exact match for non-numeric types + return actual == expected + + def to_risk_assessment( + self, + predictions: np.ndarray, + probabilities: np.ndarray, + outcome_code: str, + outcome_display: str, + outcome_system: str = "http://hl7.org/fhir/sid/icd-10", + model_name: Optional[str] = None, + model_version: Optional[str] = None, + high_threshold: float = 0.7, + moderate_threshold: float = 0.4, + ) -> List[RiskAssessment]: + """Convert model predictions to FHIR RiskAssessment resources. + + Creates RiskAssessment resources from ML model output, suitable for + including in FHIR Bundles or sending to FHIR servers. 
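+
+        Probabilities are binned into qualitative risk: "high" if probability >=
+        high_threshold, "moderate" if >= moderate_threshold, otherwise "low".
+        Rows are matched to the prediction arrays via the DataFrame index, so a
+        default RangeIndex (0..n-1) is assumed.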
+ + Args: + predictions: Binary predictions array (0/1) + probabilities: Probability scores array (0-1) + outcome_code: Code for the predicted outcome (e.g., "A41.9" for sepsis) + outcome_display: Display text for the outcome (e.g., "Sepsis") + outcome_system: Code system for the outcome (default: ICD-10) + model_name: Name of the ML model (optional) + model_version: Version of the ML model (optional) + high_threshold: Threshold for high risk (default: 0.7) + moderate_threshold: Threshold for moderate risk (default: 0.4) + + Returns: + List of RiskAssessment resources, one per patient + + Example: + >>> predictions = np.array([0, 1, 0]) + >>> probabilities = np.array([0.15, 0.85, 0.32]) + >>> risk_assessments = dataset.to_risk_assessment( + ... predictions, + ... probabilities, + ... outcome_code="A41.9", + ... outcome_display="Sepsis, unspecified", + ... model_name="RandomForest", + ... model_version="1.0" + ... ) + """ + if len(predictions) != len(self.data): + raise ValueError( + f"Predictions length ({len(predictions)}) must match " + f"DataFrame length ({len(self.data)})" + ) + + if len(probabilities) != len(self.data): + raise ValueError( + f"Probabilities length ({len(probabilities)}) must match " + f"DataFrame length ({len(self.data)})" + ) + + risk_assessments = [] + + # Get patient references + if "patient_ref" not in self.data.columns: + raise ValueError("DataFrame must have 'patient_ref' column") + + for idx, row in self.data.iterrows(): + patient_ref = row["patient_ref"] + prediction = int(predictions[idx]) + probability = float(probabilities[idx]) + + # Determine qualitative risk + if probability >= high_threshold: + qualitative_risk = "high" + elif probability >= moderate_threshold: + qualitative_risk = "moderate" + else: + qualitative_risk = "low" + + # Build prediction dict + prediction_dict = { + "outcome": { + "code": outcome_code, + "display": outcome_display, + "system": outcome_system, + }, + "probability": probability, + "qualitative_risk": qualitative_risk, + } + + # Create method CodeableConcept if model info provided + method = None + if model_name: + method = create_single_codeable_concept( + code=model_name, + display=f"{model_name} v{model_version}" + if model_version + else model_name, + system="https://healthchain.github.io/ml-models", + ) + + # Create comment with prediction details + comment = ( + f"ML prediction: {'Positive' if prediction == 1 else 'Negative'} " + f"(probability: {probability:.2%}, risk: {qualitative_risk})" + ) + + # Create RiskAssessment + risk_assessment = create_risk_assessment_from_prediction( + subject=patient_ref, + prediction=prediction_dict, + method=method, + comment=comment, + ) + + risk_assessments.append(risk_assessment) + + return risk_assessments diff --git a/healthchain/io/containers/featureschema.py b/healthchain/io/containers/featureschema.py new file mode 100644 index 00000000..6504d4ef --- /dev/null +++ b/healthchain/io/containers/featureschema.py @@ -0,0 +1,220 @@ +"""Feature schema definitions for FHIR to Dataset data conversion. + +This module provides classes to define and manage feature schemas that map +FHIR resources to pandas DataFrame columns for ML model deployment. 
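+
+Example schema (hypothetical YAML; keys mirror the FeatureMapping fields
+defined below — Observation features need a code and code_system, Patient
+features need a field):
+
+    name: sepsis_vitals
+    version: "1.0"
+    features:
+      heart_rate:
+        fhir_resource: Observation
+        code: "8867-4"
+        code_system: "http://loinc.org"
+      age:
+        fhir_resource: Patient
+        field: birthDate
+        dtype: int64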
+""" + +import yaml +from pathlib import Path +from typing import Dict, List, Optional, Union, Any +from pydantic import BaseModel, field_validator, ConfigDict, model_validator + + +class FeatureMapping(BaseModel): + """Maps a single feature to its FHIR source.""" + + name: str + fhir_resource: str + code: Optional[str] = None + code_system: Optional[str] = None + field: Optional[str] = None + transform: Optional[str] = None + dtype: str = "float64" + required: bool = True + unit: Optional[str] = None + display: Optional[str] = None + + model_config = ConfigDict(extra="allow") + + @model_validator(mode="after") + def validate_resource_requirements(self) -> "FeatureMapping": + """Validate the feature mapping configuration based on resource type.""" + if self.fhir_resource == "Observation": + if not self.code: + raise ValueError( + f"Feature '{self.name}': Observation resources require a 'code'" + ) + if not self.code_system: + raise ValueError( + f"Feature '{self.name}': Observation resources require a 'code_system'" + ) + elif self.fhir_resource == "Patient": + if not self.field: + raise ValueError( + f"Feature '{self.name}': Patient resources require a 'field'" + ) + return self + + @classmethod + def from_dict(cls, name: str, data: Dict[str, Any]) -> "FeatureMapping": + """Create a FeatureMapping from a dictionary. + + Args: + name: The feature name + data: Dictionary containing feature configuration + + Returns: + FeatureMapping instance + """ + return cls(name=name, **data) + + +class FeatureSchema(BaseModel): + """Schema defining how to extract features from FHIR resources.""" + + name: str + version: str + features: Dict[str, FeatureMapping] = {} + description: Optional[str] = None + model_info: Optional[Dict[str, Any]] = None + metadata: Optional[Dict[str, Any]] = None + + model_config = ConfigDict(extra="allow") + + @field_validator("features", mode="before") + @classmethod + def convert_feature_dicts(cls, v): + """Convert feature dicts to FeatureMapping objects if needed.""" + if v and isinstance(v, dict): + # Check if values are dicts (need conversion) or already FeatureMapping + if v and isinstance(list(v.values())[0], dict): + return { + name: FeatureMapping.from_dict(name, mapping) + for name, mapping in v.items() + } + return v + + @classmethod + def from_yaml(cls, path: Union[str, Path]) -> "FeatureSchema": + """Load schema from a YAML file. + + Args: + path: Path to the YAML file + + Returns: + FeatureSchema instance + + Example: + >>> schema = FeatureSchema.from_yaml("configs/features/sepsis_vitals.yaml") + """ + path = Path(path) + with open(path, "r") as f: + data = yaml.safe_load(f) + + return cls.model_validate(data) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "FeatureSchema": + """Create a FeatureSchema from a dictionary. + + Args: + data: Dictionary containing schema configuration + + Returns: + FeatureSchema instance + """ + return cls.model_validate(data) + + def to_dict(self) -> Dict[str, Any]: + """Convert schema to dictionary format. + + Returns: + Dictionary representation of the schema + """ + result = { + "name": self.name, + "version": self.version, + "description": self.description, + "model_info": self.model_info, + "features": { + name: { + k: v + for k, v in mapping.model_dump().items() + if k != "name" and v is not None + } + for name, mapping in self.features.items() + }, + } + if self.metadata: + result["metadata"] = self.metadata + return result + + def to_yaml(self, path: Union[str, Path]) -> None: + """Save schema to a YAML file. 
+ + Args: + path: Path where the YAML file will be saved + """ + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + + with open(path, "w") as f: + yaml.dump(self.to_dict(), f, default_flow_style=False, sort_keys=False) + + def get_feature_names(self) -> List[str]: + """Get list of feature names in order. + + Returns: + List of feature names + """ + return list(self.features.keys()) + + def get_required_features(self) -> List[str]: + """Get list of required feature names. + + Returns: + List of required feature names + """ + return [name for name, mapping in self.features.items() if mapping.required] + + def get_features_by_resource(self, resource_type: str) -> Dict[str, FeatureMapping]: + """Get all features mapped to a specific FHIR resource type. + + Args: + resource_type: FHIR resource type (e.g., "Observation", "Patient") + + Returns: + Dictionary of features for the specified resource type + """ + return { + name: mapping + for name, mapping in self.features.items() + if mapping.fhir_resource == resource_type + } + + def get_observation_codes(self) -> Dict[str, FeatureMapping]: + """Get all Observation features with their codes. + + Returns: + Dictionary mapping codes to feature mappings + """ + observations = self.get_features_by_resource("Observation") + return { + mapping.code: mapping for mapping in observations.values() if mapping.code + } + + def validate_dataframe_columns(self, columns: List[str]) -> Dict[str, Any]: + """Validate that a DataFrame has the expected columns. + + Args: + columns: List of column names from a DataFrame + + Returns: + Dictionary with validation results: + - valid: bool + - missing_required: List of missing required features + - unexpected: List of unexpected columns + """ + expected = set(self.get_feature_names()) + actual = set(columns) + required = set(self.get_required_features()) + + missing_required = list(required - actual) + unexpected = list(actual - expected) + + return { + "valid": len(missing_required) == 0, + "missing_required": missing_required, + "unexpected": unexpected, + "missing_optional": list((expected - required) - actual), + } diff --git a/healthchain/io/containers/tabular.py b/healthchain/io/containers/tabular.py deleted file mode 100644 index 809bb16e..00000000 --- a/healthchain/io/containers/tabular.py +++ /dev/null @@ -1,81 +0,0 @@ -import pandas as pd - -from dataclasses import dataclass -from typing import Any, Dict, Iterator, List - -from healthchain.io.containers.base import DataContainer - - -@dataclass -class Tabular(DataContainer[pd.DataFrame]): - """ - A container for tabular data, wrapping a pandas DataFrame. - - Attributes: - data (pd.DataFrame): The pandas DataFrame containing the tabular data. - - Methods: - __post_init__(): Validates that the data is a pandas DataFrame. - columns: Property that returns a list of column names. - index: Property that returns the DataFrame's index. - dtypes: Property that returns a dictionary of column names and their data types. - column_count(): Returns the number of columns in the DataFrame. - row_count(): Returns the number of rows in the DataFrame. - get_dtype(column: str): Returns the data type of a specific column. - __iter__(): Returns an iterator over the column names. - __len__(): Returns the number of rows in the DataFrame. - describe(): Returns a string description of the tabular data. - remove_column(name: str): Removes a column from the DataFrame. - from_csv(path: str, **kwargs): Class method to create a Tabular object from a CSV file. 
- from_dict(data: Dict[str, Any]): Class method to create a Tabular object from a dictionary. - to_csv(path: str, **kwargs): Saves the DataFrame to a CSV file. - """ - - def __post_init__(self): - if not isinstance(self.data, pd.DataFrame): - raise TypeError("data must be a pandas DataFrame") - - @property - def columns(self) -> List[str]: - return list(self.data.columns) - - @property - def index(self) -> pd.Index: - return self.data.index - - @property - def dtypes(self) -> Dict[str, str]: - return {col: str(dtype) for col, dtype in self.data.dtypes.items()} - - def column_count(self) -> int: - return len(self.columns) - - def row_count(self) -> int: - return len(self.data) - - def get_dtype(self, column: str) -> str: - return str(self.data[column].dtype) - - def __iter__(self) -> Iterator[str]: - return iter(self.columns) - - def __len__(self) -> int: - return self.row_count() - - def describe(self) -> str: - return f"Tabular data with {self.column_count()} columns and {self.row_count()} rows" - - def remove_column(self, name: str) -> None: - self.data.drop(columns=[name], inplace=True) - - @classmethod - def from_csv(cls, path: str, **kwargs) -> "Tabular": - return cls(pd.read_csv(path, **kwargs)) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "Tabular": - df = pd.DataFrame(**data["data"]) - return cls(df) - - def to_csv(self, path: str, **kwargs) -> None: - self.data.to_csv(path, **kwargs) diff --git a/healthchain/io/mappers/__init__.py b/healthchain/io/mappers/__init__.py new file mode 100644 index 00000000..2bee9cff --- /dev/null +++ b/healthchain/io/mappers/__init__.py @@ -0,0 +1,12 @@ +"""Clinical data mappers for transformations between formats. + +Mappers handle transformations between different clinical data formats: +- FHIR to pandas (ML feature extraction) +- FHIR version migrations +- Clinical standard conversions (FHIR to OMOP) +""" + +from .base import BaseMapper +from .fhirfeaturemapper import FHIRFeatureMapper + +__all__ = ["BaseMapper", "FHIRFeatureMapper"] diff --git a/healthchain/io/mappers/base.py b/healthchain/io/mappers/base.py new file mode 100644 index 00000000..ddc4031c --- /dev/null +++ b/healthchain/io/mappers/base.py @@ -0,0 +1,48 @@ +"""Base mapper for clinical data transformations. + +Mappers handle transformations between different clinical data formats and +representations, including: +- Clinical standard conversions (FHIR versions, FHIR to OMOP) +- Feature extraction for ML (FHIR to pandas) +- Data model transformations +""" + +from abc import ABC, abstractmethod +from typing import Generic, TypeVar + +SourceType = TypeVar("SourceType") +TargetType = TypeVar("TargetType") + + +class BaseMapper(Generic[SourceType, TargetType], ABC): + """ + Abstract base class for clinical data mappers. + + Mappers transform clinical data between different formats and representations, + distinct from Adapters which handle external message format conversion. + + Use mappers for: + - FHIR to pandas feature extraction (ML workflows) + - FHIR version migrations (R4 to R5) + - Clinical standard conversions (FHIR to OMOP) + - Semantic and structural data transformations + + Example: + >>> class FHIRFeatureMapper(BaseMapper[Bundle, pd.DataFrame]): + ... def map(self, source: Bundle) -> pd.DataFrame: + ... # Extract features from FHIR Bundle + ... return dataframe + """ + + @abstractmethod + def map(self, source: SourceType) -> TargetType: + """ + Transform source data to target format. 
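+
+        Example (sketch; FHIRFeatureMapper in this package implements this
+        contract for Bundle -> DataFrame):
+            >>> mapper = FHIRFeatureMapper(schema)
+            >>> df = mapper.map(bundle)  # Bundle in, DataFrame out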
+ + Args: + source: Source data in input format + + Returns: + Transformed data in target format + """ + pass diff --git a/healthchain/io/mappers/fhirfeaturemapper.py b/healthchain/io/mappers/fhirfeaturemapper.py new file mode 100644 index 00000000..24eba60f --- /dev/null +++ b/healthchain/io/mappers/fhirfeaturemapper.py @@ -0,0 +1,149 @@ +"""Schema-driven FHIR to feature mapper for ML model deployment. + +This module provides schema-driven feature extraction from FHIR Bundles, +using FeatureSchema to specify which features to extract and how to transform them. +""" + +from typing import Any, Dict, Union +import pandas as pd +import numpy as np + +from fhir.resources.bundle import Bundle + +from healthchain.io.containers.featureschema import FeatureSchema +from healthchain.io.mappers.base import BaseMapper +from healthchain.fhir.dataframe import bundle_to_dataframe, BundleConverterConfig + + +class FHIRFeatureMapper(BaseMapper[Bundle, pd.DataFrame]): + """Schema-driven mapper from FHIR resources to DataFrame features. + + Uses a FeatureSchema to extract and transform specific features from FHIR Bundles. + Leverages the generic bundle_to_dataframe converter and filters/renames columns + based on the schema. + """ + + def __init__(self, schema: FeatureSchema): + self.schema = schema + + def map(self, source: Bundle) -> pd.DataFrame: + """Transform FHIR Bundle to DataFrame using default aggregation. + Args: + source: FHIR Bundle resource + + Returns: + DataFrame with extracted features + """ + return self.extract_features(source) + + def extract_features( + self, + bundle: Union[Bundle, Dict[str, Any]], + aggregation: str = "mean", + ) -> pd.DataFrame: + """Extract features from a FHIR Bundle according to the schema. + + Args: + bundle: FHIR Bundle resource (object or dict) + aggregation: How to aggregate multiple observation values (default: "mean") + Options: "mean", "median", "max", "min", "last" (default: "mean") + + Returns: + DataFrame with one row per patient and columns matching schema features + + Example: + >>> from healthchain.io.containers.featureschema import FeatureSchema + >>> schema = FeatureSchema.from_yaml("configs/features/sepsis_vitals.yaml") + >>> mapper = FHIRFeatureMapper(schema) + >>> df = mapper.extract_features(bundle) + """ + # Build config from schema + config = self._build_config_from_schema(aggregation) + + # Extract features using config + df = bundle_to_dataframe(bundle, config=config) + + if df.empty: + return pd.DataFrame( + columns=["patient_ref"] + self.schema.get_feature_names() + ) + + # Map generic column names to schema feature names + df_mapped = self._map_columns_to_schema(df) + + # Ensure all schema features are present (fill missing with NaN) + feature_names = self.schema.get_feature_names() + for feature in feature_names: + if feature not in df_mapped.columns: + df_mapped[feature] = np.nan + + # Reorder columns to match schema + df_mapped = df_mapped[["patient_ref"] + feature_names] + + return df_mapped + + def _build_config_from_schema(self, aggregation: str) -> BundleConverterConfig: + """Build converter config from feature schema. 
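+
+        Resource types are collected from the schema's feature mappings;
+        optional schema.metadata keys ("age_calculation", "event_date_source",
+        "event_date_strategy") override the converter defaults.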
+ + Args: + aggregation: Aggregation method for observations + + Returns: + BundleConverterConfig configured based on schema + """ + # Determine which resources are needed from schema + resources = set() + for feature in self.schema.features.values(): + resources.add(feature.fhir_resource) + + # Extract age calculation metadata if present + metadata = self.schema.metadata or {} + age_calculation = metadata.get("age_calculation", "current_date") + event_date_source = metadata.get("event_date_source", "Observation") + event_date_strategy = metadata.get("event_date_strategy", "earliest") + + return BundleConverterConfig( + resources=list(resources), + observation_aggregation=aggregation, + age_calculation=age_calculation, + event_date_source=event_date_source, + event_date_strategy=event_date_strategy, + ) + + def _map_columns_to_schema(self, df: pd.DataFrame) -> pd.DataFrame: + """Map generic DataFrame columns to schema feature names. + + Args: + df: DataFrame from bundle_to_dataframe + + Returns: + DataFrame with columns renamed according to schema + """ + rename_map = {} + + # Map observation columns + obs_features = self.schema.get_features_by_resource("Observation") + for feature_name, mapping in obs_features.items(): + # Generic converter creates columns like: "8867-4_Heart_rate" + # Find matching column in df + for col in df.columns: + if col.startswith(mapping.code): + rename_map[col] = feature_name + break + + # Map patient columns (already have correct names from helpers) + patient_features = self.schema.get_features_by_resource("Patient") + for feature_name, mapping in patient_features.items(): + if mapping.field == "birthDate": + # Generic converter uses "age" + if "age" in df.columns: + rename_map["age"] = feature_name + elif mapping.field == "gender": + # Generic converter uses "gender" + if "gender" in df.columns: + rename_map["gender"] = feature_name + + # Rename columns + df_renamed = df.rename(columns=rename_map) + + return df_renamed diff --git a/healthchain/io/types.py b/healthchain/io/types.py new file mode 100644 index 00000000..260cab61 --- /dev/null +++ b/healthchain/io/types.py @@ -0,0 +1,139 @@ +"""Type definitions for IO operations. + +This module provides common types used across IO operations, particularly +for FHIR to Dataset data conversion. +""" + +from dataclasses import dataclass +from typing import List, Dict, Tuple +from pydantic import BaseModel, Field, field_validator + + +class TimeWindow(BaseModel): + """Defines a time window for filtering temporal data. + + Used to extract data from a specific time period relative to a reference point, + such as the first 24 hours after ICU admission. 
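+
+    The effective window starts at the reference time plus offset_hours and
+    spans the following `hours` hours.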
+ + Attributes: + reference_field: Field name in the FHIR resource marking the reference time + (e.g., "intime" for ICU admission, "admittime" for hospital admission) + hours: Duration of the time window in hours from the reference point + offset_hours: Number of hours to offset from the reference point (default: 0) + For example, offset_hours=6 and hours=24 would capture hours 6-30 + + Example: + >>> # Capture first 24 hours after ICU admission + >>> window = TimeWindow(reference_field="intime", hours=24) + >>> + >>> # Capture hours 6-30 after admission + >>> window = TimeWindow(reference_field="admittime", hours=24, offset_hours=6) + """ + + reference_field: str + hours: int + offset_hours: int = Field(default=0) + + @field_validator("hours") + @classmethod + def hours_must_be_positive(cls, v): + if v <= 0: + raise ValueError("hours must be positive") + return v + + @field_validator("offset_hours") + @classmethod + def offset_hours_non_negative(cls, v): + if v < 0: + raise ValueError("offset_hours must be non-negative") + return v + + +@dataclass +class ValidationResult: + """Result of data validation operations. + + Attributes: + valid: Overall validation status + missing_features: List of required features that are missing + type_mismatches: Dictionary mapping feature names to (expected, actual) type tuples + warnings: List of non-critical validation warnings + errors: List of validation errors + + Example: + >>> result = ValidationResult( + ... valid=False, + ... missing_features=["heart_rate"], + ... type_mismatches={"age": ("int64", "object")}, + ... warnings=["Optional feature 'temperature' is missing"], + ... errors=["Required feature 'heart_rate' is missing"] + ... ) + """ + + valid: bool + missing_features: List[str] = None + type_mismatches: Dict[str, Tuple[str, str]] = None + warnings: List[str] = None + errors: List[str] = None + + def __post_init__(self): + """Initialize empty lists and dicts for None values.""" + if self.missing_features is None: + self.missing_features = [] + if self.type_mismatches is None: + self.type_mismatches = {} + if self.warnings is None: + self.warnings = [] + if self.errors is None: + self.errors = [] + + def __str__(self) -> str: + """Human-readable validation result.""" + if self.valid: + return "Validation passed" + + lines = ["Validation failed:"] + + if self.errors: + lines.append("\nErrors:") + for error in self.errors: + lines.append(f" - {error}") + + if self.missing_features: + lines.append("\nMissing features:") + for feature in self.missing_features: + lines.append(f" - {feature}") + + if self.type_mismatches: + lines.append("\nType mismatches:") + for feature, (expected, actual) in self.type_mismatches.items(): + lines.append(f" - {feature}: expected {expected}, got {actual}") + + if self.warnings: + lines.append("\nWarnings:") + for warning in self.warnings: + lines.append(f" - {warning}") + + return "\n".join(lines) + + def add_error(self, error: str) -> None: + """Add an error to the validation result.""" + self.errors.append(error) + self.valid = False + + def add_warning(self, warning: str) -> None: + """Add a warning to the validation result.""" + self.warnings.append(warning) + + def add_missing_feature(self, feature: str) -> None: + """Add a missing feature.""" + self.missing_features.append(feature) + self.errors.append(f"Required feature '{feature}' is missing") + self.valid = False + + def add_type_mismatch(self, feature: str, expected: str, actual: str) -> None: + """Add a type mismatch.""" + self.type_mismatches[feature] 
= (expected, actual) + self.errors.append( + f"Type mismatch for '{feature}': expected {expected}, got {actual}" + ) diff --git a/healthchain/sandbox/generators/conditiongenerators.py b/healthchain/sandbox/generators/conditiongenerators.py index 366b9984..09dd6354 100644 --- a/healthchain/sandbox/generators/conditiongenerators.py +++ b/healthchain/sandbox/generators/conditiongenerators.py @@ -4,7 +4,7 @@ from fhir.resources.reference import Reference from fhir.resources.condition import ConditionStage, ConditionParticipant -from healthchain.fhir.helpers import create_single_codeable_concept, create_condition +from healthchain.fhir import create_single_codeable_concept, create_condition from healthchain.sandbox.generators.basegenerators import ( BaseGenerator, generator_registry, diff --git a/healthchain/sandbox/loaders/mimic.py b/healthchain/sandbox/loaders/mimic.py index be79adc3..57e03219 100644 --- a/healthchain/sandbox/loaders/mimic.py +++ b/healthchain/sandbox/loaders/mimic.py @@ -7,7 +7,7 @@ import logging import random from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from fhir.resources.R4B.bundle import Bundle @@ -54,28 +54,49 @@ def load( resource_types: Optional[List[str]] = None, sample_size: Optional[int] = None, random_seed: Optional[int] = None, + as_dict: bool = False, **kwargs, - ) -> Dict: + ) -> Union[Dict[str, Bundle], Dict[str, Any]]: """ - Load MIMIC-on-FHIR data as a dict of FHIR Bundles. + Load MIMIC-on-FHIR data as FHIR Bundle(s). Args: data_dir: Path to root MIMIC-on-FHIR directory (expects a /fhir subdir with .ndjson.gz files) resource_types: Resource type names to load (e.g., ["MimicMedication"]). Required. sample_size: Number of resources to randomly sample per type (loads all if None) random_seed: Seed for sampling + as_dict: If True, return single bundle dict (fast, no validation - for ML workflows). + If False, return dict of validated Bundle objects grouped by resource type (for CDS Hooks). + Default: False **kwargs: Reserved for future use Returns: - Dict mapping resource type (e.g., "MedicationStatement") to FHIR R4B Bundle + If as_dict=False: Dict[str, Bundle] - validated Pydantic Bundle objects grouped by resource type + Example: {"observation": Bundle(...), "patient": Bundle(...)} + If as_dict=True: Dict[str, Any] - single combined bundle dict (no validation) + Example: {"type": "collection", "entry": [...]} Raises: FileNotFoundError: If directory or resource files not found ValueError: If resource_types is None/empty or resources fail validation - Example: + Examples: + CDS Hooks prefetch format (validated, grouped by resource type): >>> loader = MimicOnFHIRLoader() - >>> loader.load(data_dir="./data/mimic-iv-fhir", resource_types=["MimicMedication"], sample_size=100) + >>> prefetch = loader.load( + ... data_dir="./data/mimic-iv-fhir", + ... resource_types=["MimicMedication", "MimicCondition"] + ... ) + >>> prefetch["medicationstatement"] # Pydantic Bundle object + + ML workflow (single bundle dict, fast, no validation): + >>> bundle = loader.load( + ... data_dir="./data/mimic-iv-fhir", + ... resource_types=["MimicObservationChartevents", "MimicPatient"], + ... as_dict=True + ... 
) + >>> from healthchain.io import Dataset + >>> dataset = Dataset.from_fhir_bundle(bundle, schema="sepsis_vitals.yaml") """ data_dir = Path(data_dir) @@ -141,6 +162,15 @@ def load( f"No valid resources loaded from specified resource types: {resource_types}" ) + # ML workflow + if as_dict: + all_entries = [] + for resources in resources_by_type.values(): + all_entries.extend([{"resource": r} for r in resources]) + + return {"type": "collection", "entry": all_entries} + + # CDS Hooks prefetch bundles = {} for fhir_type, resources in resources_by_type.items(): bundles[fhir_type.lower()] = Bundle( diff --git a/mkdocs.yml b/mkdocs.yml index b3e54e5e..8692f0d3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -28,17 +28,24 @@ nav: - Protocols: - CDS Hooks: reference/gateway/cdshooks.md - SOAP/CDA: reference/gateway/soap_cda.md + - I/O: + - Containers: + - Overview: reference/io/containers/containers.md + - Document: reference/io/containers/document.md + - Dataset: reference/io/containers/dataset.md + - Adapters: + - Overview: reference/io/adapters/adapters.md + - CDA Adapter: reference/io/adapters/cdaadapter.md + - CDS FHIR Adapter: reference/io/adapters/cdsfhiradapter.md + - Mappers: + - Overview: reference/io/mappers/mappers.md + - FHIR Feature Mapper: reference/io/mappers/fhir_feature.md - Pipeline: - Overview: reference/pipeline/pipeline.md - - Data Container: reference/pipeline/data_container.md - Components: - Overview: reference/pipeline/components/components.md - CdsCardCreator: reference/pipeline/components/cdscardcreator.md - FHIRProblemListExtractor: reference/pipeline/components/fhirproblemextractor.md - - Adapters: - - Overview: reference/pipeline/adapters/adapters.md - - CDA Adapter: reference/pipeline/adapters/cdaadapter.md - - CDS FHIR Adapter: reference/pipeline/adapters/cdsfhiradapter.md - Prebuilt Pipelines: - Medical Coding: reference/pipeline/prebuilt_pipelines/medicalcoding.md - Summarization: reference/pipeline/prebuilt_pipelines/summarization.md diff --git a/tests/containers/conftest.py b/tests/containers/conftest.py deleted file mode 100644 index bc77ee38..00000000 --- a/tests/containers/conftest.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest -from healthchain.io.containers.document import FhirData, Document -from healthchain.fhir import create_bundle, create_document_reference - - -@pytest.fixture -def fhir_data(): - return FhirData() - - -@pytest.fixture -def sample_bundle(): - return create_bundle() - - -@pytest.fixture -def sample_document(): - return Document("This is a sample text for testing.") - - -@pytest.fixture -def sample_document_reference(): - return create_document_reference( - data="test content", - content_type="text/plain", - description="Test Document", - ) - - -@pytest.fixture -def document_family(): - """Create a family of related documents.""" - original = create_document_reference( - data="original content", - content_type="text/plain", - description="Original Report", - ) - - summary = create_document_reference( - data="summary content", content_type="text/plain", description="Summary" - ) - - translation = create_document_reference( - data="translated content", content_type="text/plain", description="Translation" - ) - - return original, summary, translation diff --git a/tests/fhir/test_bundle_helpers.py b/tests/fhir/test_bundle_helpers.py index d82ba692..cfc89c4e 100644 --- a/tests/fhir/test_bundle_helpers.py +++ b/tests/fhir/test_bundle_helpers.py @@ -7,7 +7,7 @@ from fhir.resources.allergyintolerance import AllergyIntolerance from 
fhir.resources.documentreference import DocumentReference -from healthchain.fhir.bundle_helpers import ( +from healthchain.fhir.bundlehelpers import ( create_bundle, add_resource, get_resources, diff --git a/tests/fhir/test_converters.py b/tests/fhir/test_converters.py new file mode 100644 index 00000000..40e06725 --- /dev/null +++ b/tests/fhir/test_converters.py @@ -0,0 +1,525 @@ +"""Tests for FHIR converters module. + +Tests the converter functions that transform FHIR Bundles to DataFrames, +with focus on the dict-based conversion architecture. +""" + +import pytest +import pandas as pd + +from healthchain.fhir.dataframe import ( + extract_observation_value, + group_bundle_by_patient, + bundle_to_dataframe, + extract_event_date, + get_supported_resources, + get_resource_info, + BundleConverterConfig, +) +from healthchain.fhir import ( + create_bundle, + add_resource, + create_patient, + create_value_quantity_observation, + create_condition, + create_medication_statement, +) + + +@pytest.mark.parametrize( + "obs_dict,expected", + [ + ({"valueQuantity": {"value": 85.0}}, 85.0), + ({"valueInteger": 100}, 100.0), + ({"valueString": "98.6"}, 98.6), + ({}, None), + ({"valueString": "not a number"}, None), + ({"valueBoolean": True}, None), + ], +) +def test_extract_observation_value_handles_value_types(obs_dict, expected): + """extract_observation_value handles different value types and invalid values.""" + assert extract_observation_value(obs_dict) == expected + + +def test_group_bundle_by_patient_handles_both_input_types(): + """group_bundle_by_patient handles Pydantic Bundle and dict input.""" + # Test Pydantic input + pydantic_bundle = create_bundle() + patient1 = create_patient("male", "1980-01-01") + patient1.id = "123" + add_resource(pydantic_bundle, patient1) + add_resource( + pydantic_bundle, + create_value_quantity_observation( + subject="Patient/123", code="8867-4", value=85.0, unit="bpm" + ), + ) + + result = group_bundle_by_patient(pydantic_bundle) + assert "Patient/123" in result + assert isinstance(result["Patient/123"]["patient"], dict) + assert result["Patient/123"]["patient"]["resourceType"] == "Patient" + + # Test dict input + dict_bundle = { + "resourceType": "Bundle", + "type": "collection", + "entry": [ + {"resource": {"resourceType": "Patient", "id": "456", "gender": "female"}}, + { + "resource": { + "resourceType": "Observation", + "subject": {"reference": "Patient/456"}, + "code": {"coding": [{"code": "8310-5"}]}, + "valueQuantity": {"value": 37.0}, + } + }, + ], + } + + result = group_bundle_by_patient(dict_bundle) + assert "Patient/456" in result + assert len(result["Patient/456"]["observations"]) == 1 + + +def test_group_bundle_by_patient_handles_reference_formats(): + """group_bundle_by_patient handles string and dict references, plus patient field.""" + bundle_dict = { + "resourceType": "Bundle", + "type": "collection", + "entry": [ + {"resource": {"resourceType": "Patient", "id": "789"}}, + { + "resource": { + "resourceType": "Observation", + "subject": "Patient/789", # String reference + "code": {"coding": [{"code": "8867-4"}]}, + "valueQuantity": {"value": 90.0}, + } + }, + { + "resource": { + "resourceType": "AllergyIntolerance", + "patient": {"reference": "Patient/789"}, # Uses patient field + "code": {"coding": [{"code": "123"}]}, + } + }, + ], + } + + result = group_bundle_by_patient(bundle_dict) + assert len(result["Patient/789"]["observations"]) == 1 + assert len(result["Patient/789"]["allergies"]) == 1 + + +def 
test_group_bundle_by_patient_groups_multiple_resource_types(): + """group_bundle_by_patient correctly categorizes different resource types.""" + bundle = create_bundle() + patient = create_patient("male", "1980-01-01") + patient.id = "999" + + add_resource(bundle, patient) + + # Add one of each type + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/999", code="8867-4", value=85.0, unit="bpm" + ), + ) + + from healthchain.fhir import ( + create_condition, + create_medication_statement, + create_allergy_intolerance, + ) + + add_resource(bundle, create_condition("Patient/999", code="E11.9")) + add_resource(bundle, create_medication_statement("Patient/999", code="123")) + add_resource(bundle, create_allergy_intolerance("Patient/999", code="456")) + + result = group_bundle_by_patient(bundle) + + assert len(result["Patient/999"]["observations"]) == 1 + assert len(result["Patient/999"]["conditions"]) == 1 + assert len(result["Patient/999"]["medications"]) == 1 + assert len(result["Patient/999"]["allergies"]) == 1 + + +def test_bundle_to_dataframe_basic_conversion(): + """bundle_to_dataframe converts both Pydantic and dict Bundles to DataFrames.""" + # Test with Pydantic Bundle + pydantic_bundle = create_bundle() + patient = create_patient("male", "1980-01-01") + patient.id = "123" + add_resource(pydantic_bundle, patient) + add_resource( + pydantic_bundle, + create_value_quantity_observation( + subject="Patient/123", + code="8867-4", + value=85.0, + unit="bpm", + display="Heart rate", + ), + ) + + df = bundle_to_dataframe(pydantic_bundle) + assert isinstance(df, pd.DataFrame) + assert len(df) == 1 + assert "age" in df.columns and "gender" in df.columns + assert "8867-4_Heart_rate" in df.columns + assert df["8867-4_Heart_rate"].iloc[0] == 85.0 + + # Test with dict Bundle + dict_bundle = { + "resourceType": "Bundle", + "type": "collection", + "entry": [ + { + "resource": { + "resourceType": "Patient", + "id": "456", + "gender": "female", + "birthDate": "1990-05-15", + } + }, + { + "resource": { + "resourceType": "Observation", + "subject": {"reference": "Patient/456"}, + "code": { + "coding": [{"code": "8310-5", "display": "Body temperature"}] + }, + "valueQuantity": {"value": 37.0}, + } + }, + ], + } + + df = bundle_to_dataframe(dict_bundle) + assert len(df) == 1 + assert "8310-5_Body_temperature" in df.columns + + +@pytest.mark.parametrize( + "resources,source,strategy,expected", + [ + ( + { + "observations": [ + {"effectiveDateTime": "2024-01-15"}, + {"effectiveDateTime": "2024-01-10"}, + {"effectiveDateTime": "2024-01-20"}, + ] + }, + "Observation", + "earliest", + "2024-01-10", + ), + ( + { + "observations": [ + {"effectiveDateTime": "2024-01-15"}, + {"effectiveDateTime": "2024-01-10"}, + {"effectiveDateTime": "2024-01-20"}, + ] + }, + "Observation", + "latest", + "2024-01-20", + ), + ( + { + "observations": [ + {"effectiveDateTime": "2024-01-15"}, + {"effectiveDateTime": "2024-01-10"}, + {"effectiveDateTime": "2024-01-20"}, + ] + }, + "Observation", + "first", + "2024-01-15", + ), + ( + { + "encounters": [ + {"period": {"start": "2024-01-15T10:00:00Z"}}, + {"period": {"start": "2024-01-10T08:00:00Z"}}, + ] + }, + "Encounter", + "earliest", + "2024-01-10T08:00:00Z", + ), + ({}, "Observation", "earliest", None), + ({"observations": []}, "Observation", "earliest", None), + ], +) +def test_extract_event_date_strategies_and_sources( + resources, source, strategy, expected +): + """extract_event_date handles different strategies and resource sources.""" + assert 
extract_event_date(resources, source=source, strategy=strategy) == expected + + +@pytest.mark.parametrize( + "aggregation,values,expected", + [ + ("mean", [85.0, 92.0], 88.5), + ("median", [85.0, 92.0, 100.0], 92.0), + ("max", [85.0, 92.0, 100.0], 100.0), + ("min", [85.0, 92.0, 100.0], 85.0), + ("last", [85.0, 92.0, 100.0], 100.0), + ], +) +def test_bundle_to_dataframe_observation_aggregation_strategies( + aggregation, values, expected +): + """bundle_to_dataframe applies different aggregation strategies correctly.""" + bundle = create_bundle() + patient = create_patient("male", "1980-01-01") + patient.id = "123" + add_resource(bundle, patient) + + # Add multiple observations with same code + for value in values: + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/123", + code="8867-4", + value=value, + unit="bpm", + display="Heart rate", + ), + ) + + config = BundleConverterConfig( + resources=["Patient", "Observation"], observation_aggregation=aggregation + ) + df = bundle_to_dataframe(bundle, config=config) + + assert df["8867-4_Heart_rate"].iloc[0] == expected + + +def test_bundle_to_dataframe_age_calculation_modes(): + """bundle_to_dataframe calculates age from current date or event date.""" + bundle = create_bundle() + patient = create_patient("male", "1980-01-01") + patient.id = "123" + add_resource(bundle, patient) + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/123", + code="8867-4", + value=85.0, + unit="bpm", + effective_datetime="2020-01-01T00:00:00Z", + ), + ) + + # Test event_date calculation + config = BundleConverterConfig( + age_calculation="event_date", + event_date_source="Observation", + event_date_strategy="earliest", + ) + df = bundle_to_dataframe(bundle, config=config) + assert df["age"].iloc[0] == 40 # 2020 - 1980 + + # Test current_date calculation (default) + config_default = BundleConverterConfig() + df_default = bundle_to_dataframe(bundle, config=config_default) + assert df_default["age"].iloc[0] is not None + + +def test_bundle_to_dataframe_creates_binary_indicators_for_conditions_and_medications(): + """bundle_to_dataframe creates binary indicator columns for conditions and medications.""" + bundle = create_bundle() + patient = create_patient("male", "1980-01-01") + patient.id = "123" + add_resource(bundle, patient) + add_resource( + bundle, create_condition("Patient/123", code="E11.9", display="Type_2_diabetes") + ) + add_resource( + bundle, + create_medication_statement("Patient/123", code="1049221", display="Insulin"), + ) + + config = BundleConverterConfig( + resources=["Patient", "Condition", "MedicationStatement"] + ) + df = bundle_to_dataframe(bundle, config=config) + + assert "condition_E11.9_Type_2_diabetes" in df.columns + assert df["condition_E11.9_Type_2_diabetes"].iloc[0] == 1 + assert "medication_1049221_Insulin" in df.columns + assert df["medication_1049221_Insulin"].iloc[0] == 1 + + +def test_bundle_to_dataframe_handles_edge_cases(): + """bundle_to_dataframe handles empty bundles and malformed data gracefully.""" + # Empty bundle + empty_bundle = create_bundle() + df = bundle_to_dataframe(empty_bundle) + assert isinstance(df, pd.DataFrame) and len(df) == 0 + + # Missing coding arrays - should skip bad observation + bundle_dict = { + "resourceType": "Bundle", + "type": "collection", + "entry": [ + { + "resource": { + "resourceType": "Patient", + "id": "123", + "gender": "male", + "birthDate": "1980-01-01", + } + }, + { + "resource": { + "resourceType": "Observation", + "subject": 
{"reference": "Patient/123"}, + "code": {}, # Missing coding array + "valueQuantity": {"value": 85.0}, + } + }, + ], + } + + df = bundle_to_dataframe(bundle_dict) + assert len(df) == 1 + assert df["patient_ref"].iloc[0] == "Patient/123" + + # Missing display - should use code as fallback + bundle_with_condition = create_bundle() + patient = create_patient("male", "1980-01-01") + patient.id = "456" + add_resource(bundle_with_condition, patient) + add_resource(bundle_with_condition, create_condition("Patient/456", code="E11.9")) + + config = BundleConverterConfig(resources=["Patient", "Condition"]) + df = bundle_to_dataframe(bundle_with_condition, config=config) + assert "condition_E11.9_E11.9" in df.columns # Code used as display + + +def test_bundle_to_dataframe_handles_multiple_patients(): + """bundle_to_dataframe creates one row per patient in multi-patient bundles.""" + bundle = create_bundle() + + # Add first patient with observations + patient1 = create_patient("male", "1980-01-01") + patient1.id = "123" + add_resource(bundle, patient1) + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/123", + code="8867-4", + value=85.0, + unit="bpm", + display="Heart rate", + ), + ) + + # Add second patient with observations + patient2 = create_patient("female", "1990-05-15") + patient2.id = "456" + add_resource(bundle, patient2) + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/456", + code="8867-4", + value=72.0, + unit="bpm", + display="Heart rate", + ), + ) + + df = bundle_to_dataframe(bundle) + + assert len(df) == 2 + assert set(df["patient_ref"]) == {"Patient/123", "Patient/456"} + assert df[df["patient_ref"] == "Patient/123"]["8867-4_Heart_rate"].iloc[0] == 85.0 + assert df[df["patient_ref"] == "Patient/456"]["8867-4_Heart_rate"].iloc[0] == 72.0 + + +def test_bundle_converter_config_defaults(): + """BundleConverterConfig uses sensible defaults.""" + config = BundleConverterConfig() + + assert config.resources == ["Patient", "Observation"] + assert config.observation_aggregation == "mean" + assert config.age_calculation == "current_date" + assert config.event_date_source == "Observation" + assert config.event_date_strategy == "earliest" + + +def test_bundle_converter_config_validates_unsupported_resources(caplog): + """BundleConverterConfig warns about unsupported resources but doesn't fail.""" + import logging + + caplog.set_level(logging.WARNING) + + config = BundleConverterConfig( + resources=["Patient", "Observation", "UnsupportedResource", "AnotherFakeOne"] + ) + + # Should still create config successfully + assert "Patient" in config.resources + assert "Observation" in config.resources + + # Should have logged warnings + assert any("UnsupportedResource" in record.message for record in caplog.records) + + +def test_get_supported_resources_returns_expected_types(): + """get_supported_resources returns list of supported resource types.""" + resources = get_supported_resources() + + assert isinstance(resources, list) + assert "Patient" in resources + assert "Observation" in resources + assert "Condition" in resources + assert "MedicationStatement" in resources + + +def test_get_resource_info_returns_handler_details(): + """get_resource_info returns metadata for supported resources.""" + obs_info = get_resource_info("Observation") + + assert obs_info["handler"] == "_flatten_observations" + assert "description" in obs_info + assert "observation" in obs_info["description"].lower() + + # Unsupported resource returns empty dict + 
assert get_resource_info("UnsupportedResource") == {} + + +def test_bundle_to_dataframe_skips_unsupported_resources_gracefully(): + """bundle_to_dataframe skips unsupported resources without error.""" + bundle = create_bundle() + patient = create_patient("male", "1980-01-01") + patient.id = "123" + add_resource(bundle, patient) + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/123", code="8867-4", value=85.0, unit="bpm" + ), + ) + + # Include unsupported resource types in config + config = BundleConverterConfig( + resources=["Patient", "Observation", "UnsupportedType"] + ) + + # Should not raise error, just skip unsupported types + df = bundle_to_dataframe(bundle, config=config) + assert len(df) == 1 diff --git a/tests/fhir/test_helpers.py b/tests/fhir/test_helpers.py index 75763b16..35f7ce24 100644 --- a/tests/fhir/test_helpers.py +++ b/tests/fhir/test_helpers.py @@ -9,7 +9,7 @@ from datetime import datetime -from healthchain.fhir.helpers import ( +from healthchain.fhir import ( create_resource_from_dict, create_single_codeable_concept, create_single_reaction, @@ -23,6 +23,8 @@ read_content_attachment, add_provenance_metadata, add_coding_to_codeable_concept, + calculate_age_from_birthdate, + calculate_age_from_event_date, ) import pytest @@ -336,3 +338,72 @@ def test_set_condition_category_invalid_raises(): def test_create_condition_without_code_is_none(): cond = create_condition(subject="Patient/1") assert cond.code is None + + +def test_calculate_age_from_birthdate(): + """Test standard age calculation from birth date.""" + # Test with date 30 years ago + from datetime import datetime + + birth_year = datetime.now().year - 30 + birth_date = f"{birth_year}-06-15" + + age = calculate_age_from_birthdate(birth_date) + assert age is not None + # Age should be 29 or 30 depending on current date + assert age in [29, 30] + + +def test_calculate_age_from_birthdate_with_full_datetime(): + """Test age calculation with full ISO datetime.""" + from datetime import datetime + + birth_year = datetime.now().year - 25 + birth_date = f"{birth_year}-03-10T10:30:00Z" + + age = calculate_age_from_birthdate(birth_date) + assert age is not None + assert age in [24, 25] + + +def test_calculate_age_from_birthdate_invalid(): + """Test age calculation with invalid date.""" + assert calculate_age_from_birthdate(None) is None + assert calculate_age_from_birthdate("") is None + assert calculate_age_from_birthdate("invalid") is None + + +def test_calculate_age_from_event_date(): + """Test MIMIC-IV style age calculation using event date.""" + birth_date = "1990-06-15" + event_date = "2020-03-10" + + age = calculate_age_from_event_date(birth_date, event_date) + assert age == 30 # 2020 - 1990 = 30 + + +def test_calculate_age_from_event_date_with_full_datetime(): + """Test MIMIC-IV style calculation with full ISO datetime.""" + birth_date = "1985-12-25T08:00:00Z" + event_date = "2023-01-15T14:30:00Z" + + age = calculate_age_from_event_date(birth_date, event_date) + assert age == 38 # 2023 - 1985 = 38 + + +def test_calculate_age_from_event_date_same_year(): + """Test MIMIC-IV style calculation when birth and event in same year.""" + birth_date = "2020-01-01" + event_date = "2020-12-31" + + age = calculate_age_from_event_date(birth_date, event_date) + assert age == 0 # Same year = 0 + + +def test_calculate_age_from_event_date_invalid(): + """Test MIMIC-IV style calculation with invalid dates.""" + assert calculate_age_from_event_date(None, "2020-01-01") is None + assert 
calculate_age_from_event_date("1990-01-01", None) is None + assert calculate_age_from_event_date("", "2020-01-01") is None + assert calculate_age_from_event_date("invalid", "2020-01-01") is None + assert calculate_age_from_event_date("1990-01-01", "invalid") is None diff --git a/tests/io/conftest.py b/tests/io/conftest.py new file mode 100644 index 00000000..efbe95dd --- /dev/null +++ b/tests/io/conftest.py @@ -0,0 +1,212 @@ +import pytest +import pandas as pd +from pathlib import Path + +from healthchain.io.containers.featureschema import FeatureSchema +from healthchain.io.containers.dataset import Dataset +from healthchain.fhir import create_bundle +from healthchain.fhir import create_patient, create_value_quantity_observation + + +@pytest.fixture +def sepsis_schema(): + """Load the actual sepsis_vitals.yaml schema. + + Uses the real schema file for integration-style testing. + """ + schema_path = Path("healthchain/configs/features/sepsis_vitals.yaml") + return FeatureSchema.from_yaml(schema_path) + + +@pytest.fixture +def minimal_schema(): + """Minimal schema with required and optional features. + + Useful for testing basic functionality without all the complexity + of the full sepsis schema. + """ + return FeatureSchema.from_dict( + { + "name": "test_schema", + "version": "1.0", + "features": { + "heart_rate": { + "fhir_resource": "Observation", + "code": "8867-4", + "code_system": "http://loinc.org", + "display": "Heart rate", + "dtype": "float64", + "required": True, + }, + "temperature": { + "fhir_resource": "Observation", + "code": "8310-5", + "code_system": "http://loinc.org", + "display": "Body temperature", + "dtype": "float64", + "required": False, + }, + "age": { + "fhir_resource": "Patient", + "field": "birthDate", + "transform": "calculate_age", + "dtype": "int64", + "required": True, + }, + "gender_encoded": { + "fhir_resource": "Patient", + "field": "gender", + "transform": "encode_gender", + "dtype": "int64", + "required": True, + }, + }, + } + ) + + +@pytest.fixture +def observation_bundle(): + """Bundle with patient and observations matching minimal schema. + + Contains a single patient with heart rate and temperature observations. 
+    """
+    from healthchain.fhir import add_resource
+
+    bundle = create_bundle()
+    patient = create_patient("male", "1980-01-01")
+    patient.id = "123"
+
+    # Use add_resource to properly add resources to the bundle
+    add_resource(bundle, patient)
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123",
+            code="8867-4",
+            value=85.0,
+            unit="bpm",
+            system="http://loinc.org",
+            display="Heart rate",
+        ),
+    )
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123",
+            code="8310-5",
+            value=37.0,
+            unit="Cel",  # UCUM Celsius; 37.0 is a Celsius reading
+            system="http://loinc.org",
+            display="Body temperature",
+        ),
+    )
+
+    return bundle
+
+
+@pytest.fixture
+def observation_bundle_with_duplicates():
+    """Bundle with multiple observations of the same type for testing aggregation."""
+    from healthchain.fhir import add_resource
+
+    bundle = create_bundle()
+    patient = create_patient("male", "1980-01-01")
+    patient.id = "123"
+
+    # Use add_resource consistently, as in observation_bundle
+    add_resource(bundle, patient)
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123",
+            code="8867-4",
+            value=85.0,
+            unit="bpm",
+            system="http://loinc.org",
+            display="Heart rate",
+        ),
+    )
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123",
+            code="8867-4",
+            value=90.0,
+            unit="bpm",
+            system="http://loinc.org",
+            display="Heart rate",
+        ),
+    )
+    add_resource(
+        bundle,
+        create_value_quantity_observation(
+            subject="Patient/123",
+            code="8867-4",
+            value=88.0,
+            unit="bpm",
+            system="http://loinc.org",
+            display="Heart rate",
+        ),
+    )
+
+    return bundle
+
+
+@pytest.fixture
+def empty_observation_bundle():
+    """Bundle with patient but no observations."""
+    from healthchain.fhir import add_resource
+
+    bundle = create_bundle()
+    patient = create_patient("female", "1990-05-15")
+    patient.id = "456"
+
+    add_resource(bundle, patient)
+    return bundle
+
+
+@pytest.fixture
+def sample_dataset():
+    """Sample dataset with minimal schema features.
+
+    Contains two patients with complete feature data.
+    """
+    data = {
+        "patient_ref": ["Patient/1", "Patient/2"],
+        "heart_rate": [85.0, 92.0],
+        "temperature": [37.0, 37.5],
+        "age": [45, 62],
+        "gender_encoded": [1, 0],
+    }
+    return Dataset(pd.DataFrame(data))
+
+
+@pytest.fixture
+def sample_dataset_incomplete():
+    """Sample dataset missing required features.
+
+    Useful for testing validation logic.
+    """
+    data = {
+        "patient_ref": ["Patient/1", "Patient/2"],
+        "heart_rate": [85.0, 92.0],
+        # Missing temperature (optional), age, and gender_encoded (required)
+    }
+    return Dataset(pd.DataFrame(data))
+
+
+@pytest.fixture
+def sample_dataset_wrong_types():
+    """Sample dataset with incorrect data types.
+
+    Useful for testing type validation logic.
+ """ + data = { + "patient_ref": ["Patient/1", "Patient/2"], + "heart_rate": ["85.0", "92.0"], # String instead of float + "temperature": [37.0, 37.5], + "age": [45.5, 62.5], # Float instead of int + "gender_encoded": [1, 0], + } + return Dataset(pd.DataFrame(data)) diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py new file mode 100644 index 00000000..be2e25f1 --- /dev/null +++ b/tests/io/test_dataset.py @@ -0,0 +1,294 @@ +import pytest +import pandas as pd +import numpy as np + + +from healthchain.io.containers.dataset import Dataset + + +def test_dataset_from_fhir_bundle(observation_bundle, minimal_schema): + """Dataset.from_fhir_bundle extracts features using schema.""" + dataset = Dataset.from_fhir_bundle(observation_bundle, minimal_schema) + + assert len(dataset.data) == 1 + assert "patient_ref" in dataset.columns + assert "heart_rate" in dataset.columns + assert "temperature" in dataset.columns + assert "age" in dataset.columns + + +def test_dataset_from_fhir_bundle_with_yaml_path(observation_bundle): + """Dataset.from_fhir_bundle accepts YAML schema path.""" + schema_path = "healthchain/configs/features/sepsis_vitals.yaml" + dataset = Dataset.from_fhir_bundle(observation_bundle, schema_path) + + assert len(dataset.data) == 1 + assert "patient_ref" in dataset.columns + + +def test_dataset_from_fhir_bundle_with_aggregation( + observation_bundle_with_duplicates, minimal_schema +): + """Dataset.from_fhir_bundle respects aggregation parameter.""" + dataset_mean = Dataset.from_fhir_bundle( + observation_bundle_with_duplicates, minimal_schema, aggregation="mean" + ) + dataset_max = Dataset.from_fhir_bundle( + observation_bundle_with_duplicates, minimal_schema, aggregation="max" + ) + + assert dataset_mean.data["heart_rate"].iloc[0] == pytest.approx(87.666667, rel=1e-5) + assert dataset_max.data["heart_rate"].iloc[0] == 90.0 + + +def test_dataset_validate_with_complete_data(sample_dataset, minimal_schema): + """Dataset.validate passes with complete valid data.""" + result = sample_dataset.validate(minimal_schema) + + assert result.valid is True + assert len(result.missing_features) == 0 + assert len(result.errors) == 0 + + +def test_dataset_validate_detects_missing_required_features( + sample_dataset_incomplete, minimal_schema +): + """Dataset.validate detects missing required features.""" + result = sample_dataset_incomplete.validate(minimal_schema) + + assert result.valid is False + assert len(result.missing_features) > 0 + assert "age" in result.missing_features + assert "gender_encoded" in result.missing_features + + +def test_dataset_validate_raises_on_error_when_requested( + sample_dataset_incomplete, minimal_schema +): + """Dataset.validate raises exception when raise_on_error is True.""" + with pytest.raises(ValueError, match="Validation failed"): + sample_dataset_incomplete.validate(minimal_schema, raise_on_error=True) + + +def test_dataset_validate_detects_type_mismatches( + sample_dataset_wrong_types, minimal_schema +): + """Dataset.validate detects incorrect data types.""" + result = sample_dataset_wrong_types.validate(minimal_schema) + + # Type mismatches are recorded even if they don't fail validation due to dtype_compatible + assert len(result.type_mismatches) > 0 + # heart_rate should be object (string) instead of float64 + assert "heart_rate" in result.type_mismatches + # Check that errors were added for the type mismatches + assert len(result.errors) > 0 + assert any("heart_rate" in error for error in result.errors) + + +def 
test_dataset_validate_warns_about_missing_optional(minimal_schema): + """Dataset.validate generates warnings for missing optional features.""" + data = pd.DataFrame( + { + "patient_ref": ["Patient/1"], + "heart_rate": [85.0], + "age": [45], + "gender_encoded": [1], + # Missing optional "temperature" + } + ) + dataset = Dataset(data) + + result = dataset.validate(minimal_schema) + + assert result.valid is True + assert len(result.warnings) > 0 + assert any("temperature" in w for w in result.warnings) + + +def test_dataset_dtype_compatibility_allows_numeric_flexibility(): + """Dataset._dtypes_compatible allows flexibility between numeric types.""" + data = pd.DataFrame( + { + "patient_ref": ["Patient/1"], + "value_int": [45], # int64 + "value_float": [45.0], # float64 + } + ) + dataset = Dataset(data) + + # int64 and float64 should be compatible + assert dataset._dtypes_compatible("int64", "float64") + assert dataset._dtypes_compatible("float64", "int64") + assert dataset._dtypes_compatible("int32", "float64") + + +def test_dataset_to_risk_assessment_creates_resources_with_metadata(sample_dataset): + """Dataset.to_risk_assessment creates RiskAssessment resources with probabilities, model metadata, and comments.""" + predictions = np.array([0, 1]) + probabilities = np.array([0.15, 0.85]) + + # Test with model metadata + risks = sample_dataset.to_risk_assessment( + predictions, + probabilities, + outcome_code="A41.9", + outcome_display="Sepsis", + model_name="RandomForest", + model_version="1.0", + ) + + # Basic structure + assert len(risks) == 2 + assert risks[0].subject.reference == "Patient/1" + assert risks[1].subject.reference == "Patient/2" + assert risks[0].status == "final" + + # Probabilities + assert risks[0].prediction[0].probabilityDecimal == 0.15 + assert risks[1].prediction[0].probabilityDecimal == 0.85 + + # Model metadata + assert risks[0].method is not None + assert risks[0].method.coding[0].code == "RandomForest" + assert "v1.0" in risks[0].method.coding[0].display + + # Comments + assert risks[0].note is not None + assert "Negative" in risks[0].note[0].text + assert "15.00%" in risks[0].note[0].text + assert "low" in risks[0].note[0].text + assert "Positive" in risks[1].note[0].text + assert "85.00%" in risks[1].note[0].text + assert "high" in risks[1].note[0].text + + +@pytest.mark.parametrize( + "predictions,probabilities,expected_risks", + [ + ([0, 1, 0], [0.15, 0.85, 0.55], ["low", "high", "moderate"]), + ([0, 1, 0], [0.0, 1.0, 0.5], ["low", "high", "moderate"]), # Edge cases + ], +) +def test_dataset_to_risk_assessment_categorizes_risk_levels( + predictions, probabilities, expected_risks +): + """Dataset.to_risk_assessment correctly categorizes risk levels including edge probabilities.""" + data = pd.DataFrame( + { + "patient_ref": ["Patient/1", "Patient/2", "Patient/3"], + "heart_rate": [85.0, 92.0, 88.0], + "temperature": [37.0, 37.5, 37.2], + "age": [45, 62, 50], + "gender_encoded": [1, 0, 1], + } + ) + dataset = Dataset(data) + + risks = dataset.to_risk_assessment( + np.array(predictions), + np.array(probabilities), + outcome_code="A41.9", + outcome_display="Sepsis", + ) + + for i, expected_risk in enumerate(expected_risks): + assert risks[i].prediction[0].qualitativeRisk.coding[0].code == expected_risk + + +@pytest.mark.parametrize( + "data_dict,predictions,probabilities,expected_error", + [ + ( + {"heart_rate": [85.0, 92.0], "age": [45, 62]}, # Missing patient_ref + [0, 1], + [0.15, 0.85], + "DataFrame must have 'patient_ref' column", + ), + ( + {"patient_ref": 
["Patient/1", "Patient/2"], "value": [1, 2]}, + [0], # Wrong prediction length + [0.15, 0.85], + "Predictions length .* must match", + ), + ( + {"patient_ref": ["Patient/1", "Patient/2"], "value": [1, 2]}, + [0, 1], + [0.15], # Wrong probability length + "Probabilities length .* must match", + ), + ], +) +def test_dataset_to_risk_assessment_validation_errors( + data_dict, predictions, probabilities, expected_error +): + """Dataset.to_risk_assessment validates required columns and array lengths.""" + data = pd.DataFrame(data_dict) + dataset = Dataset(data) + + with pytest.raises(ValueError, match=expected_error): + dataset.to_risk_assessment( + np.array(predictions), + np.array(probabilities), + outcome_code="A41.9", + outcome_display="Sepsis", + ) + + +def test_dataset_from_csv_loads_correctly(tmp_path): + """Dataset.from_csv loads CSV files into DataFrame.""" + csv_file = tmp_path / "test.csv" + csv_file.write_text( + "patient_ref,heart_rate,age\nPatient/1,85.0,45\nPatient/2,92.0,62" + ) + + dataset = Dataset.from_csv(str(csv_file)) + + assert len(dataset.data) == 2 + assert "patient_ref" in dataset.columns + assert dataset.data["heart_rate"].iloc[0] == 85.0 + + +def test_dataset_from_dict_creates_dataframe(): + """Dataset.from_dict creates DataFrame from dict.""" + data_dict = { + "data": {"patient_ref": ["Patient/1", "Patient/2"], "heart_rate": [85.0, 92.0]} + } + + dataset = Dataset.from_dict(data_dict) + + assert len(dataset.data) == 2 + assert "patient_ref" in dataset.columns + assert "heart_rate" in dataset.columns + assert dataset.data["heart_rate"].iloc[0] == 85.0 + + +def test_dataset_to_csv_saves_correctly(tmp_path, sample_dataset): + """Dataset.to_csv exports DataFrame to CSV.""" + csv_file = tmp_path / "output.csv" + + sample_dataset.to_csv(str(csv_file), index=False) + + assert csv_file.exists() + df = pd.read_csv(csv_file) + assert len(df) == 2 + assert "patient_ref" in df.columns + + +def test_dataset_rejects_non_dataframe_input(): + """Dataset validates input is a DataFrame in __post_init__.""" + with pytest.raises(TypeError, match="data must be a pandas DataFrame"): + Dataset([{"patient_ref": "Patient/1"}]) + + +def test_dataset_to_risk_assessment_validates_probability_length(): + """Dataset.to_risk_assessment validates probabilities array length.""" + data = pd.DataFrame({"patient_ref": ["Patient/1", "Patient/2"], "value": [1, 2]}) + dataset = Dataset(data) + + predictions = np.array([0, 1]) + probabilities = np.array([0.15]) # Wrong length + + with pytest.raises(ValueError, match="Probabilities length .* must match"): + dataset.to_risk_assessment( + predictions, probabilities, outcome_code="A41.9", outcome_display="Sepsis" + ) diff --git a/tests/containers/test_document.py b/tests/io/test_document.py similarity index 86% rename from tests/containers/test_document.py rename to tests/io/test_document.py index 9faa038f..ab86a839 100644 --- a/tests/containers/test_document.py +++ b/tests/io/test_document.py @@ -6,6 +6,11 @@ from healthchain.fhir import create_bundle, add_resource, create_condition +@pytest.fixture +def sample_document(): + return Document("This is a sample text for testing.") + + def test_document_initialization(sample_document): """Test basic Document initialization and properties.""" assert sample_document.data == "This is a sample text for testing." 
@@ -23,55 +28,6 @@ def test_document_initialization(sample_document): assert sample_document.nlp.get_embeddings() is None -def test_document_properties(sample_document): - """Test Document property access.""" - # Test property access - assert hasattr(sample_document, "nlp") - assert hasattr(sample_document, "fhir") - assert hasattr(sample_document, "cds") - assert hasattr(sample_document, "models") - - -def test_document_word_count(sample_document): - """Test word count functionality.""" - assert sample_document.word_count() == 7 - - -def test_document_iteration(sample_document): - """Test document iteration over tokens.""" - tokens = list(sample_document) - assert tokens == [ - "This", - "is", - "a", - "sample", - "text", - "for", - "testing.", - ] - - -def test_document_length(sample_document): - """Test document length.""" - assert len(sample_document) == 34 # Length of the text string - - -def test_document_post_init(sample_document): - """Test post-initialization behavior.""" - # Test that text is set from data - assert sample_document.text == sample_document.data - # Test that basic tokenization is performed - assert len(sample_document.nlp._tokens) > 0 - - -def test_empty_document(): - """Test Document initialization with empty text.""" - doc = Document("") - assert doc.text == "" - assert doc.nlp._tokens == [] - assert doc.word_count() == 0 - - @pytest.mark.parametrize( "data_builder, expect_bundle, expected_entries, expected_text", [ diff --git a/tests/io/test_feature_schema.py b/tests/io/test_feature_schema.py new file mode 100644 index 00000000..4569434f --- /dev/null +++ b/tests/io/test_feature_schema.py @@ -0,0 +1,244 @@ +import pytest +import tempfile +from pathlib import Path + +from healthchain.io.containers.featureschema import FeatureSchema, FeatureMapping + + +@pytest.mark.parametrize( + "mapping_data,expected_error", + [ + ( + {"fhir_resource": "Observation"}, + "Observation resources require a 'code'", + ), + ( + {"fhir_resource": "Observation", "code": "123"}, + "Observation resources require a 'code_system'", + ), + ( + {"fhir_resource": "Observation", "code_system": "http://loinc.org"}, + "Observation resources require a 'code'", + ), + ( + {"fhir_resource": "Patient"}, + "Patient resources require a 'field'", + ), + ], +) +def test_feature_mapping_required_fields_and_validations(mapping_data, expected_error): + """FeatureMapping enforces required fields and validates resource-specific requirements.""" + with pytest.raises(ValueError, match=expected_error): + FeatureMapping(name="test_feature", dtype="float64", **mapping_data) + + +def test_feature_schema_loads_from_yaml(sepsis_schema): + """FeatureSchema.from_yaml loads the sepsis_vitals schema correctly.""" + assert sepsis_schema.name == "sepsis_prediction_features" + assert sepsis_schema.version == "1.0" + assert len(sepsis_schema.features) == 8 + assert "heart_rate" in sepsis_schema.features + assert "age" in sepsis_schema.features + + +def test_feature_schema_from_dict(minimal_schema): + """FeatureSchema.from_dict creates schema with proper FeatureMapping objects.""" + assert minimal_schema.name == "test_schema" + assert isinstance(minimal_schema.features["heart_rate"], FeatureMapping) + assert minimal_schema.features["heart_rate"].required is True + assert minimal_schema.features["temperature"].required is False + + +def test_feature_schema_to_dict_and_back_handles_unknown_and_nested_fields( + minimal_schema, +): + """FeatureSchema.to_dict/from_dict: unknown fields are allowed (Pydantic extra='allow').""" + # 
Add an unknown field at the top-level + schema_dict = minimal_schema.to_dict() + schema_dict["extra_top_level"] = "foo" + # Add extra/unknown fields at the feature level + schema_dict["features"]["heart_rate"]["unknown_field"] = 12345 + schema_dict["features"]["temperature"]["nested_field"] = {"inner": ["a", {"b": 7}]} + + # With Pydantic extra='allow', unknown fields are accepted and preserved + loaded = FeatureSchema.from_dict(schema_dict) + + # Core fields should still be correct + assert loaded.name == minimal_schema.name + assert loaded.version == minimal_schema.version + assert len(loaded.features) == len(minimal_schema.features) + + # Unknown fields are preserved in the model + assert "heart_rate" in loaded.features + assert loaded.features["heart_rate"].code == "8867-4" + + +def test_feature_schema_to_yaml_and_back(minimal_schema): + """FeatureSchema can be saved to YAML and reloaded.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + temp_path = f.name + + try: + minimal_schema.to_yaml(temp_path) + loaded = FeatureSchema.from_yaml(temp_path) + + assert loaded.name == minimal_schema.name + assert len(loaded.features) == len(minimal_schema.features) + assert loaded.features["heart_rate"].code == "8867-4" + finally: + Path(temp_path).unlink() + + +def test_feature_schema_required_vs_optional_distinction(minimal_schema): + """FeatureSchema correctly distinguishes required from optional features.""" + required = minimal_schema.get_required_features() + all_features = minimal_schema.get_feature_names() + + # Required features should be a subset of all features + assert set(required).issubset(set(all_features)) + + # Temperature is optional, others are required + assert "temperature" not in required + assert len(required) == 3 + assert all(f in required for f in ["heart_rate", "age", "gender_encoded"]) + + +@pytest.mark.parametrize( + "columns, expected_valid, missing_required, missing_optional, unexpected", + [ + ( + ["heart_rate", "temperature"], # missing required + False, + {"age", "gender_encoded"}, + set(), + set(), + ), + ( + ["heart_rate", "age", "gender_encoded"], # missing optional + True, + set(), + {"temperature"}, + set(), + ), + ( + ["heart_rate", "age", "gender_encoded", "unexpected_col"], # unexpected col + True, + set(), + set(), + {"unexpected_col"}, + ), + ], +) +def test_feature_schema_validate_dataframe_columns_various_cases( + minimal_schema, + columns, + expected_valid, + missing_required, + missing_optional, + unexpected, +): + """FeatureSchema.validate_dataframe_columns: missing required, optional, and unexpected columns.""" + result = minimal_schema.validate_dataframe_columns(columns) + assert result["valid"] is expected_valid + assert set(result["missing_required"]) == missing_required + if missing_optional: + assert set(result["missing_optional"]) == missing_optional + if unexpected: + assert set(result["unexpected"]) == unexpected + + +def test_feature_schema_get_features_by_resource(minimal_schema): + """FeatureSchema.get_features_by_resource filters features by FHIR resource type.""" + observations = minimal_schema.get_features_by_resource("Observation") + patients = minimal_schema.get_features_by_resource("Patient") + + assert len(observations) == 2 # heart_rate, temperature + assert "heart_rate" in observations + assert "temperature" in observations + + assert len(patients) == 2 # age, gender_encoded + assert "age" in patients + assert "gender_encoded" in patients + + # Non-existent resource type returns empty dict + 
assert minimal_schema.get_features_by_resource("Condition") == {} + + +def test_feature_schema_get_observation_codes(minimal_schema): + """FeatureSchema.get_observation_codes returns mapping of codes to features.""" + obs_codes = minimal_schema.get_observation_codes() + + assert "8867-4" in obs_codes # heart_rate code + assert "8310-5" in obs_codes # temperature code + assert obs_codes["8867-4"].name == "heart_rate" + assert obs_codes["8310-5"].name == "temperature" + + +def test_feature_schema_get_feature_names_preserves_order(minimal_schema): + """FeatureSchema.get_feature_names returns features in definition order.""" + names = minimal_schema.get_feature_names() + + assert isinstance(names, list) + assert len(names) == 4 + # Order should match the features dict order + assert names == ["heart_rate", "temperature", "age", "gender_encoded"] + + +def test_feature_schema_from_yaml_handles_malformed_file(): + """FeatureSchema.from_yaml raises error for malformed YAML.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write("invalid: yaml: content: [\n") # Malformed YAML + temp_path = f.name + + try: + with pytest.raises( + Exception + ): # Could be yaml.YAMLError or other parsing errors + FeatureSchema.from_yaml(temp_path) + finally: + Path(temp_path).unlink() + + +def test_feature_mapping_from_dict_creates_instance(): + """FeatureMapping.from_dict creates instance with name parameter.""" + mapping_data = { + "fhir_resource": "Observation", + "code": "8867-4", + "code_system": "http://loinc.org", + "dtype": "float64", + "required": True, + } + + mapping = FeatureMapping.from_dict("test_feature", mapping_data) + + assert mapping.name == "test_feature" + assert mapping.code == "8867-4" + assert mapping.fhir_resource == "Observation" + assert mapping.required is True + + +def test_feature_schema_handles_optional_fields(minimal_schema): + """FeatureSchema preserves optional metadata fields.""" + # Check that optional fields can be None + assert minimal_schema.description is None or isinstance( + minimal_schema.description, str + ) + assert minimal_schema.model_info is None or isinstance( + minimal_schema.model_info, dict + ) + + # Create schema with metadata + schema_with_metadata = FeatureSchema.from_dict( + { + "name": "test", + "version": "1.0", + "description": "Test description", + "model_info": {"type": "RandomForest"}, + "metadata": {"custom_field": "value"}, + "features": {}, + } + ) + + assert schema_with_metadata.description == "Test description" + assert schema_with_metadata.model_info["type"] == "RandomForest" + assert schema_with_metadata.metadata["custom_field"] == "value" diff --git a/tests/containers/test_fhir_data.py b/tests/io/test_fhir_data.py similarity index 86% rename from tests/containers/test_fhir_data.py rename to tests/io/test_fhir_data.py index 90830e3e..0c28dc13 100644 --- a/tests/containers/test_fhir_data.py +++ b/tests/io/test_fhir_data.py @@ -1,12 +1,41 @@ +import pytest +from healthchain.io.containers.document import FhirData + from healthchain.fhir import create_condition, create_document_reference -def test_bundle_operations(fhir_data, sample_bundle): - """Test basic bundle operations.""" - assert fhir_data.bundle is None +@pytest.fixture +def fhir_data(): + return FhirData() + + +@pytest.fixture +def sample_document_reference(): + return create_document_reference( + data="test content", + content_type="text/plain", + description="Test Document", + ) + + +@pytest.fixture +def document_family(): + """Create a family of 
related documents.""" + original = create_document_reference( + data="original content", + content_type="text/plain", + description="Original Report", + ) + + summary = create_document_reference( + data="summary content", content_type="text/plain", description="Summary" + ) + + translation = create_document_reference( + data="translated content", content_type="text/plain", description="Translation" + ) - fhir_data.bundle = sample_bundle - assert fhir_data.bundle == sample_bundle + return original, summary, translation def test_resource_operations(fhir_data): diff --git a/tests/io/test_fhir_feature_mapper.py b/tests/io/test_fhir_feature_mapper.py new file mode 100644 index 00000000..8c6bb8eb --- /dev/null +++ b/tests/io/test_fhir_feature_mapper.py @@ -0,0 +1,363 @@ +import pytest +import numpy as np + +from healthchain.io.mappers.fhirfeaturemapper import FHIRFeatureMapper + + +def test_mapper_extracts_features_from_bundle(observation_bundle, minimal_schema): + """FHIRFeatureMapper extracts features matching schema from FHIR Bundle.""" + mapper = FHIRFeatureMapper(minimal_schema) + df = mapper.extract_features(observation_bundle) + + assert len(df) == 1 + assert "patient_ref" in df.columns + assert df["patient_ref"].iloc[0] == "Patient/123" + assert "heart_rate" in df.columns + assert "temperature" in df.columns + assert "age" in df.columns + assert "gender_encoded" in df.columns + + +@pytest.mark.parametrize( + "aggregation,expected_value", + [ + ("mean", 87.666667), + ("median", 88.0), + ("max", 90.0), + ("min", 85.0), + ("last", 88.0), + ], +) +def test_mapper_aggregation_methods( + observation_bundle_with_duplicates, minimal_schema, aggregation, expected_value +): + """FHIRFeatureMapper correctly aggregates multiple observation values.""" + mapper = FHIRFeatureMapper(minimal_schema) + df = mapper.extract_features( + observation_bundle_with_duplicates, aggregation=aggregation + ) + + assert len(df) == 1 + assert df["heart_rate"].iloc[0] == pytest.approx(expected_value, rel=1e-5) + + +def test_mapper_fills_missing_observations_with_nan( + empty_observation_bundle, minimal_schema +): + """FHIRFeatureMapper fills missing observations with NaN.""" + mapper = FHIRFeatureMapper(minimal_schema) + df = mapper.extract_features(empty_observation_bundle) + + assert len(df) == 1 + assert df["patient_ref"].iloc[0] == "Patient/456" + # Patient features should be present + assert df["age"].notna().iloc[0] + assert df["gender_encoded"].notna().iloc[0] + # Observation features should be NaN + assert np.isnan(df["heart_rate"].iloc[0]) + assert np.isnan(df["temperature"].iloc[0]) + + +def test_mapper_column_mapping_from_generic_to_schema(): + """FHIRFeatureMapper correctly maps generic column names to schema feature names.""" + from healthchain.fhir import ( + create_bundle, + add_resource, + create_patient, + create_value_quantity_observation, + ) + from healthchain.io.containers.featureschema import FeatureSchema + + # Create schema with specific LOINC codes + schema = FeatureSchema.from_dict( + { + "name": "test_schema", + "version": "1.0", + "features": { + "hr": { # Schema uses "hr" as feature name + "fhir_resource": "Observation", + "code": "8867-4", # LOINC for heart rate + "code_system": "http://loinc.org", + "dtype": "float64", + "required": True, + } + }, + } + ) + + # Create bundle with observation that has code 8867-4 + bundle = create_bundle() + patient = create_patient("male", "1980-01-01") + patient.id = "123" + add_resource(bundle, patient) + add_resource( + bundle, + 
create_value_quantity_observation( + subject="Patient/123", + code="8867-4", + value=85.0, + unit="bpm", + system="http://loinc.org", + display="Heart rate", + ), + ) + + mapper = FHIRFeatureMapper(schema) + df = mapper.extract_features(bundle) + + # Should be renamed to "hr" not "8867-4_Heart_rate" + assert "hr" in df.columns + assert df["hr"].iloc[0] == 85.0 + + +def test_mapper_handles_bundle_with_no_matching_observations(observation_bundle): + """FHIRFeatureMapper handles bundle with observations that don't match schema.""" + from healthchain.io.containers.featureschema import FeatureSchema + + # Schema with different codes than what's in the bundle + schema = FeatureSchema.from_dict( + { + "name": "test_schema", + "version": "1.0", + "features": { + "blood_pressure": { + "fhir_resource": "Observation", + "code": "85354-9", # Different code + "code_system": "http://loinc.org", + "dtype": "float64", + "required": False, + } + }, + } + ) + + mapper = FHIRFeatureMapper(schema) + df = mapper.extract_features(observation_bundle) + + assert len(df) == 1 + assert "blood_pressure" in df.columns + assert np.isnan(df["blood_pressure"].iloc[0]) + + +def test_mapper_extracts_patient_demographics(observation_bundle, minimal_schema): + """FHIRFeatureMapper correctly extracts and transforms patient demographics.""" + mapper = FHIRFeatureMapper(minimal_schema) + df = mapper.extract_features(observation_bundle) + + # Age should be calculated from birthDate (1980-01-01) + assert df["age"].iloc[0] > 40 # Age should be around 44-45 + assert df["age"].dtype == np.int64 + + # Gender should be encoded (male = 1) + assert df["gender_encoded"].iloc[0] == 1 + assert df["gender_encoded"].dtype == np.int64 + + +def test_mapper_preserves_column_order_from_schema(observation_bundle, minimal_schema): + """FHIRFeatureMapper returns DataFrame with columns ordered as in schema.""" + mapper = FHIRFeatureMapper(minimal_schema) + df = mapper.extract_features(observation_bundle) + + expected_order = ["patient_ref"] + minimal_schema.get_feature_names() + assert list(df.columns) == expected_order + + +def test_mapper_handles_multiple_patients(): + """FHIRFeatureMapper processes multiple patients in a bundle.""" + from healthchain.fhir import ( + create_bundle, + add_resource, + create_patient, + create_value_quantity_observation, + ) + from healthchain.io.containers.featureschema import FeatureSchema + + schema = FeatureSchema.from_dict( + { + "name": "test_schema", + "version": "1.0", + "features": { + "heart_rate": { + "fhir_resource": "Observation", + "code": "8867-4", + "code_system": "http://loinc.org", + "dtype": "float64", + "required": True, + } + }, + } + ) + + bundle = create_bundle() + patient1 = create_patient("male", "1980-01-01") + patient1.id = "123" + patient2 = create_patient("female", "1990-05-15") + patient2.id = "456" + + add_resource(bundle, patient1) + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/123", + code="8867-4", + value=85.0, + unit="bpm", + system="http://loinc.org", + ), + ) + add_resource(bundle, patient2) + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/456", + code="8867-4", + value=92.0, + unit="bpm", + system="http://loinc.org", + ), + ) + + mapper = FHIRFeatureMapper(schema) + df = mapper.extract_features(bundle) + + assert len(df) == 2 + assert set(df["patient_ref"]) == {"Patient/123", "Patient/456"} + assert 85.0 in df["heart_rate"].values + assert 92.0 in df["heart_rate"].values + + +def 
test_mapper_aggregation_with_mixed_values(): + """FHIRFeatureMapper aggregates correctly with extreme value differences.""" + from healthchain.fhir import ( + create_bundle, + add_resource, + create_patient, + create_value_quantity_observation, + ) + from healthchain.io.containers.featureschema import FeatureSchema + + schema = FeatureSchema.from_dict( + { + "name": "test_schema", + "version": "1.0", + "features": { + "heart_rate": { + "fhir_resource": "Observation", + "code": "8867-4", + "code_system": "http://loinc.org", + "dtype": "float64", + "required": True, + } + }, + } + ) + + bundle = create_bundle() + patient = create_patient("male", "1980-01-01") + patient.id = "123" + + # Extreme values + add_resource(bundle, patient) + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/123", + code="8867-4", + value=50.0, + unit="bpm", + system="http://loinc.org", + ), + ) + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/123", + code="8867-4", + value=100.0, + unit="bpm", + system="http://loinc.org", + ), + ) + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/123", + code="8867-4", + value=75.0, + unit="bpm", + system="http://loinc.org", + ), + ) + + mapper = FHIRFeatureMapper(schema) + + # Test different aggregation methods + df_mean = mapper.extract_features(bundle, aggregation="mean") + assert df_mean["heart_rate"].iloc[0] == 75.0 + + df_max = mapper.extract_features(bundle, aggregation="max") + assert df_max["heart_rate"].iloc[0] == 100.0 + + df_min = mapper.extract_features(bundle, aggregation="min") + assert df_min["heart_rate"].iloc[0] == 50.0 + + +def test_mapper_with_schema_metadata_configuration(): + """FHIRFeatureMapper uses schema metadata for age calculation.""" + from healthchain.fhir import ( + create_bundle, + add_resource, + create_patient, + create_value_quantity_observation, + ) + from healthchain.io.containers.featureschema import FeatureSchema + + schema = FeatureSchema.from_dict( + { + "name": "test_schema", + "version": "1.0", + "metadata": { + "age_calculation": "event_date", + "event_date_source": "Observation", + "event_date_strategy": "earliest", + }, + "features": { + "heart_rate": { + "fhir_resource": "Observation", + "code": "8867-4", + "code_system": "http://loinc.org", + "dtype": "float64", + "required": True, + }, + "age": { + "fhir_resource": "Patient", + "field": "birthDate", + "transform": "calculate_age", + "dtype": "int64", + "required": True, + }, + }, + } + ) + + bundle = create_bundle() + patient = create_patient("male", "1980-01-01") + patient.id = "123" + + add_resource(bundle, patient) + add_resource( + bundle, + create_value_quantity_observation( + subject="Patient/123", + code="8867-4", + value=85.0, + unit="bpm", + effective_datetime="2020-01-01T00:00:00Z", + ), + ) + + mapper = FHIRFeatureMapper(schema) + df = mapper.extract_features(bundle) + + # Age should be calculated from birthdate to event date (40 years) + assert df["age"].iloc[0] == 40 diff --git a/tests/sandbox/test_mimic_loader.py b/tests/sandbox/test_mimic_loader.py index 0c2614e2..bf4e10dd 100644 --- a/tests/sandbox/test_mimic_loader.py +++ b/tests/sandbox/test_mimic_loader.py @@ -316,3 +316,115 @@ def test_mimic_loader_skips_resources_without_resource_type(temp_mimic_data_dir) # Should only load the valid resource bundle = result["medicationstatement"] assert len(bundle.entry) == 1 + + +def test_mimic_loader_as_dict_returns_plain_dict( + temp_mimic_data_dir, mock_medication_resources +): + 
"""MimicOnFHIRLoader with as_dict=True returns plain dict (not Pydantic Bundle).""" + fhir_dir = temp_mimic_data_dir / "fhir" + create_ndjson_gz_file( + fhir_dir / "MimicMedication.ndjson.gz", mock_medication_resources + ) + + loader = MimicOnFHIRLoader() + result = loader.load( + data_dir=str(temp_mimic_data_dir), + resource_types=["MimicMedication"], + as_dict=True, + ) + + # Should return a plain dict, not Dict[str, Bundle] + assert isinstance(result, dict) + assert "type" in result + assert result["type"] == "collection" + assert "entry" in result + assert isinstance(result["entry"], list) + assert len(result["entry"]) == 2 + + +def test_mimic_loader_as_dict_combines_multiple_resource_types( + temp_mimic_data_dir, mock_medication_resources, mock_condition_resources +): + """MimicOnFHIRLoader with as_dict=True combines all resources into single bundle.""" + fhir_dir = temp_mimic_data_dir / "fhir" + create_ndjson_gz_file( + fhir_dir / "MimicMedication.ndjson.gz", mock_medication_resources + ) + create_ndjson_gz_file( + fhir_dir / "MimicCondition.ndjson.gz", mock_condition_resources + ) + + loader = MimicOnFHIRLoader() + result = loader.load( + data_dir=str(temp_mimic_data_dir), + resource_types=["MimicMedication", "MimicCondition"], + as_dict=True, + ) + + # Should be a single bundle dict with all resources combined + assert isinstance(result, dict) + assert result["type"] == "collection" + assert len(result["entry"]) == 3 # 2 medications + 1 condition + + # Verify resource types are mixed + resource_types = {entry["resource"]["resourceType"] for entry in result["entry"]} + assert resource_types == {"MedicationStatement", "Condition"} + + +def test_mimic_loader_default_returns_validated_bundles( + temp_mimic_data_dir, mock_medication_resources, mock_condition_resources +): + """MimicOnFHIRLoader with as_dict=False (default) returns validated Bundle objects.""" + fhir_dir = temp_mimic_data_dir / "fhir" + create_ndjson_gz_file( + fhir_dir / "MimicMedication.ndjson.gz", mock_medication_resources + ) + create_ndjson_gz_file( + fhir_dir / "MimicCondition.ndjson.gz", mock_condition_resources + ) + + loader = MimicOnFHIRLoader() + result = loader.load( + data_dir=str(temp_mimic_data_dir), + resource_types=["MimicMedication", "MimicCondition"], + as_dict=False, # Explicit default + ) + + # Should return Dict[str, Bundle] with validated Pydantic objects + assert isinstance(result, dict) + assert "medicationstatement" in result + assert "condition" in result + + # Each value should be a Pydantic Bundle + assert type(result["medicationstatement"]).__name__ == "Bundle" + assert type(result["condition"]).__name__ == "Bundle" + assert len(result["medicationstatement"].entry) == 2 + assert len(result["condition"].entry) == 1 + + +def test_mimic_loader_as_dict_structure_matches_fhir_bundle( + temp_mimic_data_dir, mock_medication_resources +): + """MimicOnFHIRLoader with as_dict=True produces valid FHIR Bundle structure.""" + fhir_dir = temp_mimic_data_dir / "fhir" + create_ndjson_gz_file( + fhir_dir / "MimicMedication.ndjson.gz", mock_medication_resources + ) + + loader = MimicOnFHIRLoader() + result = loader.load( + data_dir=str(temp_mimic_data_dir), + resource_types=["MimicMedication"], + as_dict=True, + ) + + # Verify FHIR Bundle structure + assert result["type"] == "collection" + assert isinstance(result["entry"], list) + + # Each entry should have resource field + for entry in result["entry"]: + assert "resource" in entry + assert "resourceType" in entry["resource"] + assert 
entry["resource"]["resourceType"] == "MedicationStatement"