|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""This module provides utility functions for the container drivers.""" |
| 14 | +from __future__ import absolute_import |
| 15 | + |
| 16 | +import os |
| 17 | +import logging |
| 18 | +import sys |
| 19 | +import subprocess |
| 20 | +import traceback |
| 21 | +import json |
| 22 | + |
| 23 | +from typing import List, Dict, Any, Tuple, IO, Optional |
| 24 | + |
| 25 | +# Initialize logger |
| 26 | +SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20) |
| 27 | +logger = logging.getLogger(__name__) |
| 28 | +console_handler = logging.StreamHandler(sys.stdout) |
| 29 | +logger.addHandler(console_handler) |
| 30 | +logger.setLevel(int(SM_LOG_LEVEL)) |
| 31 | + |
| 32 | +FAILURE_FILE = "/opt/ml/output/failure" |
| 33 | +DEFAULT_FAILURE_MESSAGE = """ |
| 34 | +Training Execution failed. |
| 35 | +For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'. |
| 36 | +TrainingJob - {training_job_name} |
| 37 | +""" |
| 38 | + |
| 39 | +USER_CODE_PATH = "/opt/ml/input/data/code" |
| 40 | +SOURCE_CODE_JSON = "/opt/ml/input/data/sm_drivers/sourcecode.json" |
| 41 | +DISTRIBUTED_JSON = "/opt/ml/input/data/sm_drivers/distributed.json" |
| 42 | + |
| 43 | +HYPERPARAMETERS_JSON = "/opt/ml/input/config/hyperparameters.json" |
| 44 | + |
| 45 | +SM_EFA_NCCL_INSTANCES = [ |
| 46 | + "ml.g4dn.8xlarge", |
| 47 | + "ml.g4dn.12xlarge", |
| 48 | + "ml.g5.48xlarge", |
| 49 | + "ml.p3dn.24xlarge", |
| 50 | + "ml.p4d.24xlarge", |
| 51 | + "ml.p4de.24xlarge", |
| 52 | + "ml.p5.48xlarge", |
| 53 | + "ml.trn1.32xlarge", |
| 54 | +] |
| 55 | + |
| 56 | +SM_EFA_RDMA_INSTANCES = [ |
| 57 | + "ml.p4d.24xlarge", |
| 58 | + "ml.p4de.24xlarge", |
| 59 | + "ml.trn1.32xlarge", |
| 60 | +] |
| 61 | + |
| 62 | + |
| 63 | +def write_failure_file(message: Optional[str] = None): |
| 64 | + """Write a failure file with the message.""" |
| 65 | + if message is None: |
| 66 | + message = DEFAULT_FAILURE_MESSAGE.format(training_job_name=os.environ["TRAINING_JOB_NAME"]) |
| 67 | + if not os.path.exists(FAILURE_FILE): |
| 68 | + with open(FAILURE_FILE, "w") as f: |
| 69 | + f.write(message) |
| 70 | + |
| 71 | + |
| 72 | +def read_source_code_json(source_code_json: Dict[str, Any] = SOURCE_CODE_JSON): |
| 73 | + """Read the source code config json file.""" |
| 74 | + try: |
| 75 | + with open(source_code_json, "r") as f: |
| 76 | + source_code_dict = json.load(f) or {} |
| 77 | + except FileNotFoundError: |
| 78 | + source_code_dict = {} |
| 79 | + return source_code_dict |
| 80 | + |
| 81 | + |
| 82 | +def read_distributed_json(distributed_json: Dict[str, Any] = DISTRIBUTED_JSON): |
| 83 | + """Read the distribution config json file.""" |
| 84 | + try: |
| 85 | + with open(distributed_json, "r") as f: |
| 86 | + distributed_dict = json.load(f) or {} |
| 87 | + except FileNotFoundError: |
| 88 | + distributed_dict = {} |
| 89 | + return distributed_dict |
| 90 | + |
| 91 | + |
| 92 | +def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAMETERS_JSON): |
| 93 | + """Read the hyperparameters config json file.""" |
| 94 | + try: |
| 95 | + with open(hyperparameters_json, "r") as f: |
| 96 | + hyperparameters_dict = json.load(f) or {} |
| 97 | + except FileNotFoundError: |
| 98 | + hyperparameters_dict = {} |
| 99 | + return hyperparameters_dict |
| 100 | + |
| 101 | + |
| 102 | +def get_process_count(process_count: Optional[int] = None) -> int: |
| 103 | + """Get the number of processes to run on each node in the training job.""" |
| 104 | + return ( |
| 105 | + process_count |
| 106 | + or int(os.environ.get("SM_NUM_GPUS", 0)) |
| 107 | + or int(os.environ.get("SM_NUM_NEURONS", 0)) |
| 108 | + or 1 |
| 109 | + ) |
| 110 | + |
| 111 | + |
| 112 | +def hyperparameters_to_cli_args(hyperparameters: Dict[str, Any]) -> List[str]: |
| 113 | + """Convert the hyperparameters to CLI arguments.""" |
| 114 | + cli_args = [] |
| 115 | + for key, value in hyperparameters.items(): |
| 116 | + value = safe_deserialize(value) |
| 117 | + cli_args.extend([f"--{key}", safe_serialize(value)]) |
| 118 | + |
| 119 | + return cli_args |
| 120 | + |
| 121 | + |
| 122 | +def safe_deserialize(data: Any) -> Any: |
| 123 | + """Safely deserialize data from a JSON string. |
| 124 | +
|
| 125 | + This function handles the following cases: |
| 126 | + 1. If `data` is not a string, it returns the input as-is. |
| 127 | + 2. If `data` is a JSON-encoded string, it attempts to deserialize it using `json.loads()`. |
| 128 | + 3. If `data` is a string but cannot be decoded as JSON, it returns the original string. |
| 129 | +
|
| 130 | + Returns: |
| 131 | + Any: The deserialized data, or the original input if it cannot be JSON-decoded. |
| 132 | + """ |
| 133 | + if not isinstance(data, str): |
| 134 | + return data |
| 135 | + |
| 136 | + try: |
| 137 | + return json.loads(data) |
| 138 | + except json.JSONDecodeError: |
| 139 | + return data |
| 140 | + |
| 141 | + |
| 142 | +def safe_serialize(data): |
| 143 | + """Serialize the data without wrapping strings in quotes. |
| 144 | +
|
| 145 | + This function handles the following cases: |
| 146 | + 1. If `data` is a string, it returns the string as-is without wrapping in quotes. |
| 147 | + 2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns |
| 148 | + the JSON-encoded string using `json.dumps()`. |
| 149 | + 3. If `data` cannot be serialized (e.g., a custom object), it returns the string |
| 150 | + representation of the data using `str(data)`. |
| 151 | +
|
| 152 | + Args: |
| 153 | + data (Any): The data to serialize. |
| 154 | +
|
| 155 | + Returns: |
| 156 | + str: The serialized JSON-compatible string or the string representation of the input. |
| 157 | + """ |
| 158 | + if isinstance(data, str): |
| 159 | + return data |
| 160 | + try: |
| 161 | + return json.dumps(data) |
| 162 | + except TypeError: |
| 163 | + return str(data) |
| 164 | + |
| 165 | + |
| 166 | +def get_python_executable() -> str: |
| 167 | + """Get the python executable path.""" |
| 168 | + return sys.executable |
| 169 | + |
| 170 | + |
| 171 | +def log_subprocess_output(pipe: IO[bytes]): |
| 172 | + """Log the output from the subprocess.""" |
| 173 | + for line in iter(pipe.readline, b""): |
| 174 | + logger.info(line.decode("utf-8").strip()) |
| 175 | + |
| 176 | + |
| 177 | +def execute_commands(commands: List[str]) -> Tuple[int, str]: |
| 178 | + """Execute the provided commands and return exit code with failure traceback if any.""" |
| 179 | + try: |
| 180 | + process = subprocess.Popen( |
| 181 | + commands, |
| 182 | + stdout=subprocess.PIPE, |
| 183 | + stderr=subprocess.STDOUT, |
| 184 | + ) |
| 185 | + with process.stdout: |
| 186 | + log_subprocess_output(process.stdout) |
| 187 | + exitcode = process.wait() |
| 188 | + if exitcode != 0: |
| 189 | + raise subprocess.CalledProcessError(exitcode, commands) |
| 190 | + return exitcode, "" |
| 191 | + except subprocess.CalledProcessError as e: |
| 192 | + # Capture the traceback in case of failure |
| 193 | + error_traceback = traceback.format_exc() |
| 194 | + print(f"Command failed with exit code {e.returncode}. Traceback: {error_traceback}") |
| 195 | + return e.returncode, error_traceback |
| 196 | + |
| 197 | + |
| 198 | +def is_worker_node() -> bool: |
| 199 | + """Check if the current node is a worker node.""" |
| 200 | + return os.environ.get("SM_CURRENT_HOST") != os.environ.get("SM_MASTER_ADDR") |
| 201 | + |
| 202 | + |
| 203 | +def is_master_node() -> bool: |
| 204 | + """Check if the current node is the master node.""" |
| 205 | + return os.environ.get("SM_CURRENT_HOST") == os.environ.get("SM_MASTER_ADDR") |
0 commit comments