diff --git a/dashboard/README.md b/dashboard/README.md index 870bb16..573e3bf 100644 --- a/dashboard/README.md +++ b/dashboard/README.md @@ -54,7 +54,7 @@ conda-lock install --name synapse-gui environment-lock.yml 2. Move to the [dashboard/](./) directory. -3. Set up the database settings (read-only) and the AmSC MLflow API key: +3. Set up the database settings (read-only) and the [AmSC MLflow API](https://profile.american-science-cloud.org) key: ```bash export SF_DB_HOST='127.0.0.1' export SF_DB_READONLY_PASSWORD='your_password_here' # Use SINGLE quotes around the password! diff --git a/dashboard/app.py b/dashboard/app.py index cb5108f..79ac774 100644 --- a/dashboard/app.py +++ b/dashboard/app.py @@ -1,3 +1,4 @@ +import asyncio from bson.objectid import ObjectId import os import re @@ -6,13 +7,21 @@ from trame.ui.vuetify3 import SinglePageWithDrawerLayout from trame.widgets import plotly, router, vuetify3 as vuetify, html -from model_manager import ModelManager, model_type_dict +from model_manager import ( + GENESIS_MODEL_TYPE, + GENESIS_LOGO_URL, + ModelManager, + clear_model_load_errors, + is_missing_mlflow_model, + load_model_from_mlflow_with_progress, + model_type_dict, +) from outputs_manager import OutputManager from optimization_manager import OptimizationManager from parameters_manager import ParametersManager from calibration_manager import SimulationCalibrationManager from sfapi_manager import load_sfapi_card -from state_manager import server, state, ctrl, initialize_state +from state_manager import server, state, initialize_state from error_manager import error_panel, add_error from utils import ( data_depth_panel, @@ -33,6 +42,7 @@ par_manager = None opt_manager = None cal_manager = None +PLOTS_FIGURE_STATE = "plots_figure" # list of available experiments experiments = load_experiments() @@ -42,6 +52,13 @@ # ----------------------------------------------------------------------------- +def update_plot_figure(fig): + """Replace the figure shown in the Plots card.""" + state[PLOTS_FIGURE_STATE] = plotly.Figure.to_data(fig) + state.dirty(PLOTS_FIGURE_STATE) + state.flush() + + def update( reset_model=True, reset_output=True, @@ -52,6 +69,7 @@ def update( reset_gui_route_nersc=True, reset_gui_route_chat=True, reset_gui_layout=True, + preloaded_model_manager=None, **kwargs, ): print("Updating...") @@ -79,10 +97,15 @@ def update( cal_manager = SimulationCalibrationManager(simulation_calibration) # reset model if reset_model: - mod_manager = ModelManager( - config_dict=config_dict, - model_type=model_type_dict[state.model_type_verbose], - ) + state.model_available = False + if preloaded_model_manager is None: + mod_manager = ModelManager( + config_dict=config_dict, + model_type=model_type_dict[state.model_type_verbose], + ) + else: + mod_manager = preloaded_model_manager + state.model_available = mod_manager.avail() opt_manager = OptimizationManager(mod_manager) # reset parameters if reset_parameters: @@ -102,7 +125,6 @@ def update( # reset GUI layout if reset_gui_layout: gui_setup() - # reset plots if reset_plots: fig = plot( exp_data=exp_data, @@ -110,7 +132,69 @@ def update( model_manager=mod_manager, cal_manager=cal_manager, ) - ctrl.figure_update(fig) + update_plot_figure(fig) + + +async def update_with_model_download_indicator(**update_kwargs): + """Run a dashboard update with visible download feedback for NN models.""" + show_model_download = ( + update_kwargs.get("reset_model", True) + and state.model_type_verbose == GENESIS_MODEL_TYPE + ) + load_error = None + if show_model_download: + experiment = state.experiment + model_type_verbose = state.model_type_verbose + config_dict = load_config_dict(experiment) + model_type = model_type_dict[model_type_verbose] + state.model_available = False + state.model_downloading = True + state.model_download_status = "Downloading from American Science Cloud..." + state.model_download_progress = None + clear_model_load_errors() + state.flush() + await asyncio.sleep(0.05) + try: + loaded_model = await asyncio.to_thread( + load_model_from_mlflow_with_progress, + config_dict, + model_type, + asyncio.get_running_loop(), + ) + except Exception as e: + loaded_model = None + model_name = f"synapse-{config_dict['experiment']}_{model_type}" + if is_missing_mlflow_model(e): + print(f"Model {model_name} not found in MLflow; continuing without it.") + else: + load_error = e + if ( + state.experiment != experiment + or state.model_type_verbose != model_type_verbose + ): + state.model_downloading = False + state.model_download_status = None + state.model_download_progress = None + state.flush() + return + update_kwargs["preloaded_model_manager"] = ModelManager( + config_dict=config_dict, + model_type=model_type, + loaded_model=loaded_model, + ) + try: + update(**update_kwargs) + if load_error is not None: + title = f"Unable to load model {model_type}" + msg = f"Error occurred when loading model from MLflow: {load_error}" + add_error(title, msg) + state.flush() + finally: + if show_model_download: + state.model_downloading = False + state.model_download_status = None + state.model_download_progress = None + state.flush() @state.change( @@ -127,70 +211,68 @@ def update( "simulation_calibration", "use_inferred_calibration", ) -def reset(**kwargs): - # skip if triggered on server ready (all state variables marked as modified) - if len(state.modified_keys) == 1: - print(f"Reacting to state change in {state.modified_keys}...") - if any( - key in state.modified_keys - for key in [ - "experiment", - "experiment_date_range", - ] - ): - update( - reset_model=True, - reset_output=True, - reset_parameters=True, - reset_calibration=True, - reset_plots=True, - reset_gui_route_home=True, - reset_gui_route_nersc=False, - reset_gui_route_chat=False, - reset_gui_layout=False, - ) - elif any( - key in state.modified_keys - for key in [ - "model_type_verbose", - "model_training_time", - ] - ): - update( - reset_model=True, - reset_output=False, - reset_parameters=False, - reset_calibration=False, - reset_plots=True, - reset_gui_route_home=True, - reset_gui_route_nersc=False, - reset_gui_route_chat=False, - reset_gui_layout=False, - ) - elif any( - key in state.modified_keys - for key in [ - "displayed_output", - "parameters", - "opacity", - "parameters_min", - "parameters_max", - "parameters_show_all", - "simulation_calibration", - "use_inferred_calibration", - ] - ): - update( - reset_model=False, - reset_output=False, - reset_parameters=False, - reset_calibration=False, - reset_plots=True, - reset_gui_route_home=False, - reset_gui_route_nersc=False, - reset_gui_route_chat=False, - reset_gui_layout=False, - ) +async def reset(**kwargs): + experiment_keys = { + "experiment", + "experiment_date_range", + } + model_keys = { + "model_type_verbose", + "model_training_time", + } + plot_keys = { + "displayed_output", + "parameters", + "opacity", + "parameters_min", + "parameters_max", + "parameters_show_all", + "simulation_calibration", + "use_inferred_calibration", + } + watched_keys = experiment_keys | model_keys | plot_keys + modified_keys = set(state.modified_keys) & watched_keys + + if not modified_keys or modified_keys == watched_keys: + return + + print(f"Reacting to state change in {modified_keys}...") + if modified_keys & experiment_keys: + await update_with_model_download_indicator( + reset_model=True, + reset_output=True, + reset_parameters=True, + reset_calibration=True, + reset_plots=True, + reset_gui_route_home=True, + reset_gui_route_nersc=False, + reset_gui_route_chat=False, + reset_gui_layout=False, + ) + elif modified_keys & model_keys: + await update_with_model_download_indicator( + reset_model=True, + reset_output=False, + reset_parameters=False, + reset_calibration=False, + reset_plots=True, + reset_gui_route_home=True, + reset_gui_route_nersc=False, + reset_gui_route_chat=False, + reset_gui_layout=False, + ) + elif modified_keys & plot_keys: + update( + reset_model=False, + reset_output=False, + reset_parameters=False, + reset_calibration=False, + reset_plots=True, + reset_gui_route_home=False, + reset_gui_route_nersc=False, + reset_gui_route_chat=False, + reset_gui_layout=False, + ) def find_simulation(event, db): @@ -345,7 +427,8 @@ def home_route(): with vuetify.VContainer( style=f"height: {400 * len(state.parameters)}px;" ): - figure = plotly.Figure( + plotly.Figure( + state_variable_name=PLOTS_FIGURE_STATE, display_mode_bar="true", config={"responsive": True}, click=( @@ -353,7 +436,6 @@ def home_route(): "[utils.safe($event)]", ), ) - ctrl.figure_update = figure.update # NERSC route @@ -382,6 +464,8 @@ def chat_route(): # GUI layout def gui_setup(): print("Setting GUI layout...") + if GENESIS_LOGO_URL: + state.trame__favicon = GENESIS_LOGO_URL with SinglePageWithDrawerLayout(server) as layout: layout.title.set_text("Synapse") # add toolbar components diff --git a/dashboard/logos/AmSC_300px.png b/dashboard/logos/AmSC_300px.png new file mode 100644 index 0000000..a002c3a Binary files /dev/null and b/dashboard/logos/AmSC_300px.png differ diff --git a/dashboard/logos/genesis_80px.png b/dashboard/logos/genesis_80px.png new file mode 100644 index 0000000..5a1e92c Binary files /dev/null and b/dashboard/logos/genesis_80px.png differ diff --git a/dashboard/model_manager.py b/dashboard/model_manager.py index 6f8095a..f2fd5af 100644 --- a/dashboard/model_manager.py +++ b/dashboard/model_manager.py @@ -1,4 +1,5 @@ import asyncio +from contextlib import contextmanager from datetime import datetime from pathlib import Path import tempfile @@ -6,20 +7,157 @@ import yaml import re import mlflow +import mlflow.store.artifact.artifact_repo as mlflow_artifact_repo +import mlflow.store.artifact.cloud_artifact_repo as mlflow_cloud_artifact_repo +import mlflow.utils.file_utils as mlflow_file_utils +from mlflow.exceptions import MlflowException +from trame.assets.local import LocalFileManager from sfapi_client import AsyncClient from sfapi_client.compute import Machine -from trame.widgets import vuetify3 as vuetify +from trame.widgets import vuetify3 as vuetify, html from utils import timer, load_config_dict, create_date_filter from calibration_manager import build_inferred_calibration from error_manager import add_error from sfapi_manager import monitor_sfapi_job from state_manager import state +LOGO_DIR = Path(__file__).parent / "logos" +AMSC_MLFLOW_URL = "https://mlflow.american-science-cloud.org" +GENESIS_MODEL_TYPE = "Neural Network (single)" +GENESIS_LOGO_PATH = LOGO_DIR / "genesis_80px.png" +AMSC_LOGO_PATH = LOGO_DIR / "AmSC_300px.png" +GENESIS_LOGO_URL = ( + LocalFileManager(LOGO_DIR).url("genesis_logo", GENESIS_LOGO_PATH) + if GENESIS_LOGO_PATH.is_file() + else None +) +AMSC_LOGO_URL = ( + LocalFileManager(LOGO_DIR).url("amsc_logo", AMSC_LOGO_PATH) + if AMSC_LOGO_PATH.is_file() + else None +) +MODEL_DOWNLOAD_ACTIVE_EXPR = ( + f"model_downloading && model_type_verbose === '{GENESIS_MODEL_TYPE}'" +) + model_type_dict = { "Gaussian Process": "GP", - "Neural Network (single)": "NN", + GENESIS_MODEL_TYPE: "NN", "Neural Network (ensemble)": "ensemble_NN", } +AMSC_MLFLOW_MODEL_URL_EXPR = ( + f"'{AMSC_MLFLOW_URL}/#/models/synapse-' + experiment + '_' + " + "(model_type_verbose === 'Gaussian Process' ? 'GP' : " + f"model_type_verbose === '{GENESIS_MODEL_TYPE}' ? 'NN' : " + "model_type_verbose === 'Neural Network (ensemble)' ? 'ensemble_NN' : " + "model_type_verbose)" +) + +_NO_PRELOADED_MODEL = object() + + +def clear_model_load_errors(): + """Remove stale model-load errors before starting another load attempt.""" + if state.errors is None: + return + errors = [ + error + for error in state.errors + if not str(error.get("title", "")).startswith("Unable to load model") + ] + if len(errors) != len(state.errors): + state.errors = errors + state.dirty("errors") + + +def is_missing_mlflow_model(error): + """Return whether an MLflow error represents a missing registered model.""" + return isinstance(error, MlflowException) and ( + error.error_code == "RESOURCE_DOES_NOT_EXIST" + or "RESOURCE_DOES_NOT_EXIST" in str(error) + ) + + +def load_model_from_mlflow(config_dict, model_type): + """Load the latest registered MLflow model for an experiment configuration.""" + if "mlflow" not in config_dict or not config_dict["mlflow"].get("tracking_uri"): + print( + f"No mlflow.tracking_uri in configuration file for {config_dict['experiment']}; cannot load model from MLflow." + ) + return None + + mlflow.set_tracking_uri(config_dict["mlflow"]["tracking_uri"]) + # When using the AmSC MLflow: inject the X-Api-Key into the requests to authenticate with the MLflow server + # (See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch) + if ( + config_dict["mlflow"]["tracking_uri"] + == "https://mlflow.american-science-cloud.org" + ): + enable_amsc_x_api_key(config_dict) + + experiment = config_dict["experiment"] + model_name = f"synapse-{experiment}_{model_type}" + return ( + mlflow.pyfunc.load_model(f"models:/{model_name}/latest") + .unwrap_python_model() + .model + ) + + +@contextmanager +def mlflow_artifact_progress_to_state(loop): + """Expose MLflow artifact download progress through dashboard state.""" + progress_bar_modules = [ + mlflow_file_utils, + mlflow_artifact_repo, + mlflow_cloud_artifact_repo, + ] + original_progress_bars = { + module: module.ArtifactProgressBar for module in progress_bar_modules + } + original_progress_bar = mlflow_file_utils.ArtifactProgressBar + + def set_download_progress(progress, total): + """Publish the current download completion percentage to the GUI.""" + + def update_progress_state(): + if total: + state.model_download_progress = min(100, progress / total * 100) + else: + state.model_download_progress = None + state.flush() + + loop.call_soon_threadsafe(update_progress_state) + + class TrameArtifactProgressBar(original_progress_bar): + def __init__(self, desc, total, step, **kwargs): + super().__init__(desc, total, step, **kwargs) + self.trame_progress = 0 + if desc.startswith("Downloading"): + set_download_progress(self.trame_progress, self.total) + + def update(self): + super().update() + self.trame_progress = min( + self.total, + self.trame_progress + self.step, + ) + if self.desc.startswith("Downloading"): + set_download_progress(self.trame_progress, self.total) + + for module in progress_bar_modules: + module.ArtifactProgressBar = TrameArtifactProgressBar + try: + yield + finally: + for module, progress_bar in original_progress_bars.items(): + module.ArtifactProgressBar = progress_bar + + +def load_model_from_mlflow_with_progress(config_dict, model_type, loop): + """Load an MLflow model while reporting artifact download progress.""" + with mlflow_artifact_progress_to_state(loop): + return load_model_from_mlflow(config_dict, model_type) def enable_amsc_x_api_key(config_dict): @@ -66,36 +204,23 @@ def patched(host_creds, endpoint, method, *args, **kwargs): class ModelManager: - def __init__(self, config_dict, model_type): + def __init__(self, config_dict, model_type, loaded_model=_NO_PRELOADED_MODEL): print("Initializing model manager...") + clear_model_load_errors() self.__model = None self.__model_type = model_type - if "mlflow" not in config_dict or not config_dict["mlflow"].get("tracking_uri"): - print( - f"No mlflow.tracking_uri in configuration file for {config_dict['experiment']}; cannot load model from MLflow." - ) - return - - mlflow.set_tracking_uri(config_dict["mlflow"]["tracking_uri"]) - # When using the AmSC MLflow: inject the X-Api-Key into the requests to authenticate with the MLflow server - # (See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch) - if ( - config_dict["mlflow"]["tracking_uri"] - == "https://mlflow.american-science-cloud.org" - ): - enable_amsc_x_api_key(config_dict) - experiment = config_dict["experiment"] model_name = f"synapse-{experiment}_{model_type}" try: - # Download model from MLflow server self.__model = ( - mlflow.pyfunc.load_model(f"models:/{model_name}/latest") - .unwrap_python_model() - .model + load_model_from_mlflow(config_dict, model_type) + if loaded_model is _NO_PRELOADED_MODEL + else loaded_model ) + if self.__model is None: + return if model_type not in ("NN", "ensemble_NN", "GP"): raise ValueError(f"Unsupported model type: {model_type}") # Populate inferred calibration in physics units for GUI @@ -105,13 +230,16 @@ def __init__(self, config_dict, model_type): config_dict["inputs"], config_dict["outputs"] ) except Exception as e: + if is_missing_mlflow_model(e): + print(f"Model {model_name} not found in MLflow; continuing without it.") + return title = f"Unable to load model {model_type}" msg = f"Error occurred when loading model from MLflow: {e}" add_error(title, msg) print(msg) def avail(self): - print("Checking model availability...") + # print("Checking model availability...") model_avail = True if self.__model is not None else False return model_avail @@ -353,32 +481,120 @@ def panel(self): print("Setting model card...") # list of available model types model_type_list = [ - "Gaussian Process", - "Neural Network (single)", - "Neural Network (ensemble)", + {"title": "Gaussian Process", "value": "Gaussian Process"}, + {"title": GENESIS_MODEL_TYPE, "value": GENESIS_MODEL_TYPE}, + { + "title": "Neural Network (ensemble)", + "value": "Neural Network (ensemble)", + }, ] + if GENESIS_LOGO_URL: + model_type_list[1]["logo"] = GENESIS_LOGO_URL + model_type_cols = 8 if AMSC_LOGO_URL else 12 with vuetify.VExpansionPanels(v_model=("expand_panel_control_model", 0)): with vuetify.VExpansionPanel( title="Control: Models", style="font-size: 20px; font-weight: 500;", ): with vuetify.VExpansionPanelText(): - with vuetify.VRow(): - with vuetify.VCol(): - vuetify.VSelect( + with vuetify.VRow(align="center"): + with vuetify.VCol(cols=model_type_cols): + with vuetify.VSelect( v_model=("model_type_verbose",), label="Model type", items=(model_type_list,), + item_title="title", + item_value="value", dense=True, + ): + with html.Template(v_slot_item="{ props, item }"): + with vuetify.VListItem( + v_bind=("{ ...props, title: undefined }",) + ): + with html.Div( + classes="d-flex align-center w-100" + ): + html.Span(v_text=("item.title",)) + vuetify.VSpacer() + vuetify.VImg( + v_if=("item.raw.logo",), + src=("item.raw.logo",), + width=40, + height=24, + max_width=40, + alt="Genesis", + style="margin-left: 12px;", + ) + if AMSC_LOGO_URL: + with vuetify.VCol( + cols=4, + classes="d-flex align-center justify-end", + ): + with html.A( + v_if=("model_available",), + href=(AMSC_MLFLOW_MODEL_URL_EXPR,), + target="_blank", + rel="noopener noreferrer", + title="Open selected model in AmSC MLflow", + style=( + "display: block; width: 100%; " + "max-width: 300px; margin-left: auto; " + "cursor: pointer;" + ), + ): + vuetify.VImg( + src=AMSC_LOGO_URL, + alt="AmSC", + max_width=300, + max_height=72, + contain=True, + style="width: 100%;", + ) + vuetify.VImg( + v_if=("!model_available",), + src=AMSC_LOGO_URL, + alt="AmSC", + max_width=300, + max_height=72, + contain=True, + title="Selected model is not available in AmSC MLflow", + style=( + "width: 100%; max-width: 300px; " + "margin-left: auto;" + ), + ) + with vuetify.VRow( + v_if=(MODEL_DOWNLOAD_ACTIVE_EXPR,), + no_gutters=True, + align="center", + style="margin-top: -8px; margin-bottom: 8px;", + ): + with vuetify.VCol(cols=model_type_cols): + with html.Div( + classes="d-flex align-center text-caption text-medium-emphasis mb-1" + ): + vuetify.VIcon( + "mdi-cloud-download-outline", + size=16, + classes="mr-1", + ) + html.Span(v_text=("model_download_status",)) + vuetify.VSpacer() + html.Span( + v_if=("model_download_progress !== null",), + v_text=( + "`${Math.round(model_download_progress)}%`", + ), + ) + vuetify.VProgressLinear( + indeterminate=("model_download_progress === null",), + model_value=("model_download_progress",), + color="primary", + height=4, + rounded=True, ) - with vuetify.VCol(): - vuetify.VTextField( - v_model_number=("model_training_status",), - label="Training status", - readonly=True, - ) - with vuetify.VRow(): - with vuetify.VCol(): + with vuetify.VRow(align="center"): + with vuetify.VCol(cols="auto"): vuetify.VBtn( "Train", click=self.training_trigger, @@ -387,3 +603,9 @@ def panel(self): ), style="text-transform: none", ) + with vuetify.VCol(cols=6, style="margin-left: auto;"): + vuetify.VTextField( + v_model_number=("model_training_status",), + label="Training status", + readonly=True, + ) diff --git a/dashboard/state_manager.py b/dashboard/state_manager.py index 8d8d31e..93ddb45 100644 --- a/dashboard/state_manager.py +++ b/dashboard/state_manager.py @@ -32,6 +32,10 @@ def initialize_state(): state.model_training_mode = "local" state.model_training_status = None state.model_training_time = None + state.model_available = False + state.model_downloading = False + state.model_download_status = None + state.model_download_progress = None # Optimization state.optimization_type = "Maximize" state.optimization_status = None