-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Description
PySDK Version
- PySDK V2 (2.x)
- PySDK V3 (3.x)
Describe the bug
When creating a ModelTrainer with a Compute config that uses instance_groups (heterogeneous cluster), the Defaults.get_compute() method in sagemaker/train/defaults.py unconditionally injects the default instance_type (ml.m5.xlarge) and instance_count (1) whenever those fields are None. With heterogeneous clusters, instance_type and instance_count are intentionally None because they are mutually exclusive with `instance_groups`. As a result, the `CreateTrainingJob` API call includes both InstanceType/InstanceCount and InstanceGroups in the ResourceConfig, which the SageMaker API rejects with a ValidationException.
The bug is in sagemaker-train/src/sagemaker/train/defaults.py, lines 103-108:
if compute.instance_type is None:
    compute.instance_type = DEFAULT_INSTANCE_TYPE
if compute.instance_count is None:
    compute.instance_count = DEFAULT_INSTANCE_COUNT
These checks should be guarded with `if not compute.instance_groups:` so that the defaults are skipped when instance groups are configured.
The same pattern exists in the JumpStart defaults path at lines 228-242.
To reproduce
from sagemaker.train.configs import (
Compute,
InstanceGroup,
OutputDataConfig,
SourceCode,
StoppingCondition,
)
from sagemaker.train.model_trainer import ModelTrainer
instance_groups = [
InstanceGroup(
instance_group_name="head-instance-group",
instance_type="ml.t3.large",
instance_count=1,
),
InstanceGroup(
instance_group_name="worker-instance-group-1",
instance_type="ml.m5.2xlarge",
instance_count=2,
),
]
compute_configs = Compute(
instance_groups=instance_groups,
keep_alive_period_in_seconds=0,
)
source_code = SourceCode(
source_dir="./scripts",
requirements="requirements.txt",
command="python launcher.py",
)
model_trainer = ModelTrainer(
training_image="<training-image-uri>",
source_code=source_code,
base_job_name="test-heterogeneous",
compute=compute_configs,
stopping_condition=StoppingCondition(max_runtime_in_seconds=18000),
output_data_config=OutputDataConfig(
s3_output_path="s3://my-bucket/output", compression_type="NONE"
),
role="<role-arn>",
)
model_trainer.train(wait=False)
Expected behavior
The ResourceConfig sent to the CreateTrainingJob API should only contain InstanceGroups (without top-level InstanceType/InstanceCount):
{
"ResourceConfig": {
"VolumeSizeInGB": 30,
"InstanceGroups": [
{"InstanceType": "ml.t3.large", "InstanceCount": 1, "InstanceGroupName": "head-instance-group"},
{"InstanceType": "ml.m5.2xlarge", "InstanceCount": 2, "InstanceGroupName": "worker-instance-group-1"}
]
}
}
Screenshots or logs
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>:2 │
│ │
│ 1 # starting the train job with our uploaded datasets as input │
│ ❱ 2 model_trainer.train(input_data_config=data, wait=False) │
│ 3 │
│ │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/sagemaker/core/teleme │
│ try/telemetry_logging.py:168 in wrapper │
│ │
│ 165 │ │ │ │ │ caught_ex = e │
│ 166 │ │ │ │ finally: │
│ 167 │ │ │ │ │ if caught_ex: │
│ ❱ 168 │ │ │ │ │ │ raise caught_ex │
│ 169 │ │ │ │ │ return response # pylint: disable=W0150 │
│ 170 │ │ │ else: │
│ 171 │ │ │ │ logger.debug( │
│ │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/sagemaker/core/teleme │
│ try/telemetry_logging.py:139 in wrapper │
│ │
│ 136 │ │ │ │ start_timer = perf_counter() │
│ 137 │ │ │ │ try: │
│ 138 │ │ │ │ │ # Call the original function │
│ ❱ 139 │ │ │ │ │ response = func(*args, **kwargs) │
│ 140 │ │ │ │ │ stop_timer = perf_counter() │
│ 141 │ │ │ │ │ elapsed = stop_timer - start_timer │
│ 142 │ │ │ │ │ extra += f"&x-latency={round(elapsed, 2)}" │
│ │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/sagemaker/core/workfl │
│ ow/pipeline_context.py:346 in wrapper │
│ │
│ 343 │ │ │ │
│ 344 │ │ │ return _StepArguments(retrieve_caller_name(self_instance), run_func, *args, │
│ 345 │ │ │
│ ❱ 346 │ │ return run_func(*args, **kwargs) │
│ 347 │ │
│ 348 │ return wrapper │
│ 349 │
│ │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/pydantic/_internal/_v │
│ alidate_call.py:39 in wrapper_function │
│ │
│ 36 │ │ │
│ 37 │ │ @functools.wraps(wrapped) │
│ 38 │ │ def wrapper_function(*args, **kwargs): │
│ ❱ 39 │ │ │ return wrapper(*args, **kwargs) │
│ 40 │ │
│ 41 │ # We need to manually update this because `partial` object has no `__name__` and `__ │
│ 42 │ wrapper_function.__name__ = extract_function_name(wrapped) │
│ │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/pydantic/_internal/_v │
│ alidate_call.py:136 in __call__ │
│ │
│ 133 │ │ if not self.__pydantic_complete__: │
│ 134 │ │ │ self._create_validators() │
│ 135 │ │ │
│ ❱ 136 │ │ res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, │
│ 137 │ │ if self.__return_pydantic_validator__: │
│ 138 │ │ │ return self.__return_pydantic_validator__(res) │
│ 139 │ │ else: │
│ │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/sagemaker/train/model │
│ _trainer.py:781 in train │
│ │
│ 778 │ │ │ │ self.sagemaker_session._intercept_create_request(training_request, None, │
│ 779 │ │ │ │ return │
│ 780 │ │ │ │
│ ❱ 781 │ │ │ training_job = TrainingJob.create( │
│ 782 │ │ │ │ session=self.sagemaker_session.boto_session, │
│ 783 │ │ │ │ **training_request │
│ 784 │ │ │ ) │
│ │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/sagemaker/core/resour │
│ ces.py:35418 in wrapper │
│ │
│ 35415 │ │ │ │ "tensor_board_output_config": {"s3_output_path": {"type": "string"}}, │
│ 35416 │ │ │ │ "profiler_config": {"s3_output_path": {"type": "string"}}, │
│ 35417 │ │ │ } │
│ ❱ 35418 │ │ │ return create_func( │
│ 35419 │ │ │ │ *args, │
│ 35420 │ │ │ │ **Base.get_updated_kwargs_with_configured_attributes( │
│ 35421 │ │ │ │ │ config_schema_for_resource, "TrainingJob", **kwargs │
│ │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/sagemaker/core/resour │
│ ces.py:143 in wrapper │
│ │
│ 140 │ │ @functools.wraps(func) │
│ 141 │ │ def wrapper(*args, **kwargs): │
│ 142 │ │ │ config = dict(arbitrary_types_allowed=True) │
│ ❱ 143 │ │ │ return validate_call(config=config)(func)(*args, **kwargs) │
│ 144 │ │ │
│ 145 │ │ return wrapper │
│ 146 │
│ │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/pydantic/_internal/_v │
│ alidate_call.py:39 in wrapper_function │
│ │
│ 36 │ │ │
│ 37 │ │ @functools.wraps(wrapped) │
│ 38 │ │ def wrapper_function(*args, **kwargs): │
│ ❱ 39 │ │ │ return wrapper(*args, **kwargs) │
│ 40 │ │
│ 41 │ # We need to manually update this because `partial` object has no `__name__` and `__ │
│ 42 │ wrapper_function.__name__ = extract_function_name(wrapped) │
│ │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/pydantic/_internal/_v │
│ alidate_call.py:136 in __call__ │
│ │
│ 133 │ │ if not self.__pydantic_complete__: │
│ 134 │ │ │ self._create_validators() │
│ 135 │ │ │
│ ❱ 136 │ │ res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, │
│ 137 │ │ if self.__return_pydantic_validator__: │
│ 138 │ │ │ return self.__return_pydantic_validator__(res) │
│ 139 │ │ else: │
│ │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/sagemaker/core/resour │
│ ces.py:35595 in create │
│ │
│ 35592 │ │ logger.debug(f"Serialized input request: {operation_input_args}") │
│ 35593 │ │ │
│ 35594 │ │ # create the resource │
│ ❱ 35595 │ │ response = client.create_training_job(**operation_input_args) │
│ 35596 │ │ logger.debug(f"Response: {response}") │
│ 35597 │ │ │
│ 35598 │ │ return cls.get(training_job_name=training_job_name, session=session, region=regi │
│ │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/botocore/client.py:60 │
│ 2 in _api_call │
│ │
│ 599 │ │ │ │ │ f"{py_operation_name}() only accepts keyword arguments." │
│ 600 │ │ │ │ ) │
│ 601 │ │ │ # The "self" in this scope is referring to the BaseClient. │
│ ❱ 602 │ │ │ return self._make_api_call(operation_name, kwargs) │
│ 603 │ │ │
│ 604 │ │ _api_call.__name__ = str(py_operation_name) │
│ 605 │
│ │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/botocore/context.py:1 │
│ 23 in wrapper │
│ │
│ 120 │ │ │ with start_as_current_context(): │
│ 121 │ │ │ │ if hook: │
│ 122 │ │ │ │ │ hook() │
│ ❱ 123 │ │ │ │ return func(*args, **kwargs) │
│ 124 │ │ │
│ 125 │ │ return wrapper │
│ 126 │
│ │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/botocore/client.py:10 │
│ 78 in _make_api_call │
│ │
│ 1075 │ │ │ │ 'error_code_override' │
│ 1076 │ │ │ ) or error_info.get("Code") │
│ 1077 │ │ │ error_class = self.exceptions.from_code(error_code) │
│ ❱ 1078 │ │ │ raise error_class(parsed_response, operation_name) │
│ 1079 │ │ else: │
│ 1080 │ │ │ return parsed_response │
│ 1081 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ClientError: An error occurred (ValidationException) when calling the CreateTrainingJob operation: InstanceType or
InstanceCount cannot be specified with InstanceGroups
System information
A description of your system. Please provide:
- SageMaker Python SDK version: 3.4.1
- Framework name: PyTorch
- Framework version: 2.8.0
- Python version: 3.12.12
- CPU or GPU: CPU
- Custom Docker image: N
Additional context
Suggested fix in sagemaker-train/src/sagemaker/train/defaults.py:
# In Defaults.get_compute() (line 103)
if not compute.instance_groups:
    if compute.instance_type is None:
        compute.instance_type = DEFAULT_INSTANCE_TYPE
    if compute.instance_count is None:
        compute.instance_count = DEFAULT_INSTANCE_COUNT