Skip to content

Commit 290fdc2

Browse files
authored
Add training examples (#5341)
1 parent 16c2591 commit 290fdc2

19 files changed

+3296
-0
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Sagemaker modules container drivers directory."""
14+
from __future__ import absolute_import
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Sagemaker modules container drivers - common directory."""
14+
from __future__ import absolute_import
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module provides utility functions for the container drivers."""
14+
from __future__ import absolute_import
15+
16+
import os
17+
import logging
18+
import sys
19+
import subprocess
20+
import traceback
21+
import json
22+
23+
from typing import List, Dict, Any, Tuple, IO, Optional
24+
25+
# Initialize logger
26+
SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20)
27+
logger = logging.getLogger(__name__)
28+
console_handler = logging.StreamHandler(sys.stdout)
29+
logger.addHandler(console_handler)
30+
logger.setLevel(int(SM_LOG_LEVEL))
31+
32+
FAILURE_FILE = "/opt/ml/output/failure"
33+
DEFAULT_FAILURE_MESSAGE = """
34+
Training Execution failed.
35+
For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'.
36+
TrainingJob - {training_job_name}
37+
"""
38+
39+
USER_CODE_PATH = "/opt/ml/input/data/code"
40+
SOURCE_CODE_JSON = "/opt/ml/input/data/sm_drivers/sourcecode.json"
41+
DISTRIBUTED_JSON = "/opt/ml/input/data/sm_drivers/distributed.json"
42+
43+
HYPERPARAMETERS_JSON = "/opt/ml/input/config/hyperparameters.json"
44+
45+
SM_EFA_NCCL_INSTANCES = [
46+
"ml.g4dn.8xlarge",
47+
"ml.g4dn.12xlarge",
48+
"ml.g5.48xlarge",
49+
"ml.p3dn.24xlarge",
50+
"ml.p4d.24xlarge",
51+
"ml.p4de.24xlarge",
52+
"ml.p5.48xlarge",
53+
"ml.trn1.32xlarge",
54+
]
55+
56+
SM_EFA_RDMA_INSTANCES = [
57+
"ml.p4d.24xlarge",
58+
"ml.p4de.24xlarge",
59+
"ml.trn1.32xlarge",
60+
]
61+
62+
63+
def write_failure_file(message: Optional[str] = None):
64+
"""Write a failure file with the message."""
65+
if message is None:
66+
message = DEFAULT_FAILURE_MESSAGE.format(training_job_name=os.environ["TRAINING_JOB_NAME"])
67+
if not os.path.exists(FAILURE_FILE):
68+
with open(FAILURE_FILE, "w") as f:
69+
f.write(message)
70+
71+
72+
def read_source_code_json(source_code_json: Dict[str, Any] = SOURCE_CODE_JSON):
73+
"""Read the source code config json file."""
74+
try:
75+
with open(source_code_json, "r") as f:
76+
source_code_dict = json.load(f) or {}
77+
except FileNotFoundError:
78+
source_code_dict = {}
79+
return source_code_dict
80+
81+
82+
def read_distributed_json(distributed_json: Dict[str, Any] = DISTRIBUTED_JSON):
83+
"""Read the distribution config json file."""
84+
try:
85+
with open(distributed_json, "r") as f:
86+
distributed_dict = json.load(f) or {}
87+
except FileNotFoundError:
88+
distributed_dict = {}
89+
return distributed_dict
90+
91+
92+
def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAMETERS_JSON):
93+
"""Read the hyperparameters config json file."""
94+
try:
95+
with open(hyperparameters_json, "r") as f:
96+
hyperparameters_dict = json.load(f) or {}
97+
except FileNotFoundError:
98+
hyperparameters_dict = {}
99+
return hyperparameters_dict
100+
101+
102+
def get_process_count(process_count: Optional[int] = None) -> int:
103+
"""Get the number of processes to run on each node in the training job."""
104+
return (
105+
process_count
106+
or int(os.environ.get("SM_NUM_GPUS", 0))
107+
or int(os.environ.get("SM_NUM_NEURONS", 0))
108+
or 1
109+
)
110+
111+
112+
def hyperparameters_to_cli_args(hyperparameters: Dict[str, Any]) -> List[str]:
113+
"""Convert the hyperparameters to CLI arguments."""
114+
cli_args = []
115+
for key, value in hyperparameters.items():
116+
value = safe_deserialize(value)
117+
cli_args.extend([f"--{key}", safe_serialize(value)])
118+
119+
return cli_args
120+
121+
122+
def safe_deserialize(data: Any) -> Any:
123+
"""Safely deserialize data from a JSON string.
124+
125+
This function handles the following cases:
126+
1. If `data` is not a string, it returns the input as-is.
127+
2. If `data` is a JSON-encoded string, it attempts to deserialize it using `json.loads()`.
128+
3. If `data` is a string but cannot be decoded as JSON, it returns the original string.
129+
130+
Returns:
131+
Any: The deserialized data, or the original input if it cannot be JSON-decoded.
132+
"""
133+
if not isinstance(data, str):
134+
return data
135+
136+
try:
137+
return json.loads(data)
138+
except json.JSONDecodeError:
139+
return data
140+
141+
142+
def safe_serialize(data):
143+
"""Serialize the data without wrapping strings in quotes.
144+
145+
This function handles the following cases:
146+
1. If `data` is a string, it returns the string as-is without wrapping in quotes.
147+
2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
148+
the JSON-encoded string using `json.dumps()`.
149+
3. If `data` cannot be serialized (e.g., a custom object), it returns the string
150+
representation of the data using `str(data)`.
151+
152+
Args:
153+
data (Any): The data to serialize.
154+
155+
Returns:
156+
str: The serialized JSON-compatible string or the string representation of the input.
157+
"""
158+
if isinstance(data, str):
159+
return data
160+
try:
161+
return json.dumps(data)
162+
except TypeError:
163+
return str(data)
164+
165+
166+
def get_python_executable() -> str:
167+
"""Get the python executable path."""
168+
return sys.executable
169+
170+
171+
def log_subprocess_output(pipe: IO[bytes]):
172+
"""Log the output from the subprocess."""
173+
for line in iter(pipe.readline, b""):
174+
logger.info(line.decode("utf-8").strip())
175+
176+
177+
def execute_commands(commands: List[str]) -> Tuple[int, str]:
178+
"""Execute the provided commands and return exit code with failure traceback if any."""
179+
try:
180+
process = subprocess.Popen(
181+
commands,
182+
stdout=subprocess.PIPE,
183+
stderr=subprocess.STDOUT,
184+
)
185+
with process.stdout:
186+
log_subprocess_output(process.stdout)
187+
exitcode = process.wait()
188+
if exitcode != 0:
189+
raise subprocess.CalledProcessError(exitcode, commands)
190+
return exitcode, ""
191+
except subprocess.CalledProcessError as e:
192+
# Capture the traceback in case of failure
193+
error_traceback = traceback.format_exc()
194+
print(f"Command failed with exit code {e.returncode}. Traceback: {error_traceback}")
195+
return e.returncode, error_traceback
196+
197+
198+
def is_worker_node() -> bool:
199+
"""Check if the current node is a worker node."""
200+
return os.environ.get("SM_CURRENT_HOST") != os.environ.get("SM_MASTER_ADDR")
201+
202+
203+
def is_master_node() -> bool:
204+
"""Check if the current node is the master node."""
205+
return os.environ.get("SM_CURRENT_HOST") == os.environ.get("SM_MASTER_ADDR")
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Sagemaker modules container drivers - drivers directory."""
14+
from __future__ import absolute_import
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module is the entry point for the Basic Script Driver."""
14+
from __future__ import absolute_import
15+
16+
import os
17+
import sys
18+
import json
19+
import shlex
20+
21+
from pathlib import Path
22+
from typing import List
23+
24+
sys.path.insert(0, str(Path(__file__).parent.parent))
25+
26+
from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
27+
logger,
28+
get_python_executable,
29+
write_failure_file,
30+
hyperparameters_to_cli_args,
31+
execute_commands,
32+
)
33+
34+
35+
def create_commands() -> List[str]:
36+
"""Create the commands to execute."""
37+
entry_script = os.environ["SM_ENTRY_SCRIPT"]
38+
hyperparameters = json.loads(os.environ["SM_HPS"])
39+
python_executable = get_python_executable()
40+
41+
args = hyperparameters_to_cli_args(hyperparameters)
42+
if entry_script.endswith(".py"):
43+
commands = [python_executable, entry_script]
44+
commands += args
45+
elif entry_script.endswith(".sh"):
46+
args_str = " ".join(shlex.quote(arg) for arg in args)
47+
commands = [
48+
"/bin/sh",
49+
"-c",
50+
f"chmod +x {entry_script} && ./{entry_script} {args_str}",
51+
]
52+
else:
53+
raise ValueError(
54+
f"Unsupported entry script type: {entry_script}. Only .py and .sh are supported."
55+
)
56+
return commands
57+
58+
59+
def main():
60+
"""Main function for the Basic Script Driver.
61+
62+
This function is the entry point for the Basic Script Driver.
63+
64+
Execution Lifecycle:
65+
1. Read the source code and hyperparameters JSON files.
66+
2. Set hyperparameters as command line arguments.
67+
3. Create the commands to execute.
68+
4. Execute the commands.
69+
"""
70+
71+
cmd = create_commands()
72+
73+
logger.info(f"Executing command: {' '.join(cmd)}")
74+
exit_code, traceback = execute_commands(cmd)
75+
if exit_code != 0:
76+
write_failure_file(traceback)
77+
sys.exit(exit_code)
78+
79+
80+
if __name__ == "__main__":
81+
main()

0 commit comments

Comments
 (0)