
ModelTrainer with heterogeneous cluster (instance_groups) fails — defaults.py injects instance_type/instance_count unconditionally #5555

@brunopistone

Description


PySDK Version

  • PySDK V2 (2.x)
  • PySDK V3 (3.x)

Describe the bug
When creating a ModelTrainer with a Compute config that uses instance_groups (a heterogeneous cluster), the Defaults.get_compute() method in sagemaker/train/defaults.py injects the default instance_type (ml.m5.xlarge) and instance_count (1) whenever those fields are None, regardless of whether instance_groups is set. With heterogeneous clusters, instance_type and instance_count are intentionally None because they are mutually exclusive with instance_groups. As a result, the CreateTrainingJob API call includes both InstanceType/InstanceCount and InstanceGroups in the ResourceConfig, which the SageMaker API rejects with a ValidationException.
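The service-side rule can be sketched as a small client-side check. This is a hypothetical helper (validate_resource_config and ResourceConfigError are not part of the SageMaker SDK); it only reproduces the constraint stated in the ValidationException message in the logs of this issue, applied to the request shape the current defaults produce.

```python
from typing import Any, Dict


class ResourceConfigError(ValueError):
    """Hypothetical stand-in for the service-side ValidationException."""


def validate_resource_config(resource_config: Dict[str, Any]) -> None:
    """Raise if top-level InstanceType/InstanceCount are combined with InstanceGroups.

    Hypothetical check mirroring the CreateTrainingJob error quoted in this issue.
    """
    has_groups = bool(resource_config.get("InstanceGroups"))
    has_top_level = (
        resource_config.get("InstanceType") is not None
        or resource_config.get("InstanceCount") is not None
    )
    if has_groups and has_top_level:
        raise ResourceConfigError(
            "InstanceType or InstanceCount cannot be specified with InstanceGroups"
        )


# Shape of the request produced today: defaults injected next to InstanceGroups.
injected = {
    "VolumeSizeInGB": 30,
    "InstanceType": "ml.m5.xlarge",
    "InstanceCount": 1,
    "InstanceGroups": [
        {"InstanceType": "ml.t3.large", "InstanceCount": 1,
         "InstanceGroupName": "head-instance-group"},
    ],
}

try:
    validate_resource_config(injected)
except ResourceConfigError as err:
    print(f"rejected: {err}")
```

A config with only InstanceGroups, or only the top-level fields, passes the check; only the mix is rejected.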

The bug is in sagemaker-train/src/sagemaker/train/defaults.py, lines 103-108:

if compute.instance_type is None:
    compute.instance_type = DEFAULT_INSTANCE_TYPE
if compute.instance_count is None:
    compute.instance_count = DEFAULT_INSTANCE_COUNT

These checks should be guarded with if not compute.instance_groups: to skip default population when instance groups are configured.

The same pattern exists in the JumpStart defaults path at lines 228-242.
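Because the same unguarded pattern appears in two places, one option is to factor the guard into a single helper that both paths call. The sketch below uses hypothetical stand-ins (ComputeStub, InstanceGroupStub, apply_compute_defaults) rather than the SDK's real Compute and InstanceGroup classes, and takes the default values from this issue. The early return is equivalent to wrapping the two checks in if not compute.instance_groups:.

```python
from dataclasses import dataclass
from typing import List, Optional

# Default values quoted in this issue for Defaults.get_compute().
DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge"
DEFAULT_INSTANCE_COUNT = 1


@dataclass
class InstanceGroupStub:
    """Hypothetical stand-in for sagemaker.train.configs.InstanceGroup."""

    instance_group_name: str
    instance_type: str
    instance_count: int


@dataclass
class ComputeStub:
    """Hypothetical stand-in for sagemaker.train.configs.Compute."""

    instance_type: Optional[str] = None
    instance_count: Optional[int] = None
    instance_groups: Optional[List[InstanceGroupStub]] = None


def apply_compute_defaults(compute: ComputeStub) -> ComputeStub:
    """Fill single-instance-type defaults; leave heterogeneous clusters untouched.

    Hypothetical shared helper for both Defaults.get_compute() and the
    JumpStart defaults path, so the guard lives in one place.
    """
    if compute.instance_groups:
        # instance_type/instance_count are mutually exclusive with instance_groups.
        return compute
    if compute.instance_type is None:
        compute.instance_type = DEFAULT_INSTANCE_TYPE
    if compute.instance_count is None:
        compute.instance_count = DEFAULT_INSTANCE_COUNT
    return compute


hetero = apply_compute_defaults(
    ComputeStub(
        instance_groups=[
            InstanceGroupStub("head-instance-group", "ml.t3.large", 1),
            InstanceGroupStub("worker-instance-group-1", "ml.m5.2xlarge", 2),
        ]
    )
)
single = apply_compute_defaults(ComputeStub())
```

Here hetero keeps instance_type and instance_count as None, while single receives the defaults. Whether the real fix uses a shared helper or repeats the guard in both paths is a maintainer choice; a single helper just keeps the two call sites from drifting apart.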

To reproduce

from sagemaker.train.configs import (
    Compute,
    InstanceGroup,
    OutputDataConfig,
    SourceCode,
    StoppingCondition,
)
from sagemaker.train.model_trainer import ModelTrainer

instance_groups = [
    InstanceGroup(
        instance_group_name="head-instance-group",
        instance_type="ml.t3.large",
        instance_count=1,
    ),
    InstanceGroup(
        instance_group_name="worker-instance-group-1",
        instance_type="ml.m5.2xlarge",
        instance_count=2,
    ),
]

compute_configs = Compute(
    instance_groups=instance_groups,
    keep_alive_period_in_seconds=0,
)

source_code = SourceCode(
    source_dir="./scripts",
    requirements="requirements.txt",
    command="python launcher.py",
)

model_trainer = ModelTrainer(
    training_image="<training-image-uri>",
    source_code=source_code,
    base_job_name="test-heterogeneous",
    compute=compute_configs,
    stopping_condition=StoppingCondition(max_runtime_in_seconds=18000),
    output_data_config=OutputDataConfig(
        s3_output_path="s3://my-bucket/output", compression_type="NONE"
    ),
    role="<role-arn>",
)

model_trainer.train(wait=False)

Expected behavior
The ResourceConfig sent to the CreateTrainingJob API should contain InstanceGroups only, with no top-level InstanceType/InstanceCount:

{
  "ResourceConfig": {
    "VolumeSizeInGB": 30,
    "InstanceGroups": [
      {"InstanceType": "ml.t3.large", "InstanceCount": 1, "InstanceGroupName": "head-instance-group"},
      {"InstanceType": "ml.m5.2xlarge", "InstanceCount": 2, "InstanceGroupName": "worker-instance-group-1"}
    ]
  }
}
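The expected payload follows if unset top-level fields are simply left out of the request. The helper below (to_resource_config) is a hypothetical illustration, not the SDK's serializer, and VolumeSizeInGB=30 is copied from the expected output rather than from any documented default.

```python
import json
from typing import Any, Dict, List, Optional


def to_resource_config(
    volume_size_in_gb: int,
    instance_type: Optional[str] = None,
    instance_count: Optional[int] = None,
    instance_groups: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Build a ResourceConfig dict, omitting top-level fields that are None (hypothetical)."""
    config: Dict[str, Any] = {"VolumeSizeInGB": volume_size_in_gb}
    if instance_type is not None:
        config["InstanceType"] = instance_type
    if instance_count is not None:
        config["InstanceCount"] = instance_count
    if instance_groups:
        config["InstanceGroups"] = instance_groups
    return config


groups = [
    {"InstanceType": "ml.t3.large", "InstanceCount": 1,
     "InstanceGroupName": "head-instance-group"},
    {"InstanceType": "ml.m5.2xlarge", "InstanceCount": 2,
     "InstanceGroupName": "worker-instance-group-1"},
]

# With guarded defaults, instance_type/instance_count stay None for a heterogeneous cluster.
request = {"ResourceConfig": to_resource_config(30, instance_groups=groups)}
print(json.dumps(request, indent=2))
```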

Screenshots or logs

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>:2                                                                                    │
│                                                                                                  │
│   1 # starting the train job with our uploaded datasets as input                                 │
│ ❱ 2 model_trainer.train(input_data_config=data, wait=False)                                      │
│   3                                                                                              │
│                                                                                                  │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/sagemaker/core/teleme │
│ try/telemetry_logging.py:168 in wrapper                                                          │
│                                                                                                  │
│   165 │   │   │   │   │   caught_ex = e                                                          │
│   166 │   │   │   │   finally:                                                                   │
│   167 │   │   │   │   │   if caught_ex:                                                          │
│ ❱ 168 │   │   │   │   │   │   raise caught_ex                                                    │
│   169 │   │   │   │   │   return response  # pylint: disable=W0150                               │
│   170 │   │   │   else:                                                                          │
│   171 │   │   │   │   logger.debug(                                                              │
│                                                                                                  │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/sagemaker/core/teleme │
│ try/telemetry_logging.py:139 in wrapper                                                          │
│                                                                                                  │
│   136 │   │   │   │   start_timer = perf_counter()                                               │
│   137 │   │   │   │   try:                                                                       │
│   138 │   │   │   │   │   # Call the original function                                           │
│ ❱ 139 │   │   │   │   │   response = func(*args, **kwargs)                                       │
│   140 │   │   │   │   │   stop_timer = perf_counter()                                            │
│   141 │   │   │   │   │   elapsed = stop_timer - start_timer                                     │
│   142 │   │   │   │   │   extra += f"&x-latency={round(elapsed, 2)}"                             │
│                                                                                                  │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/sagemaker/core/workfl │
│ ow/pipeline_context.py:346 in wrapper                                                            │
│                                                                                                  │
│   343 │   │   │                                                                                  │
│   344 │   │   │   return _StepArguments(retrieve_caller_name(self_instance), run_func, *args,    │
│   345 │   │                                                                                      │
│ ❱ 346 │   │   return run_func(*args, **kwargs)                                                   │
│   347 │                                                                                          │
│   348 │   return wrapper                                                                         │
│   349                                                                                            │
│                                                                                                  │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/pydantic/_internal/_v │
│ alidate_call.py:39 in wrapper_function                                                           │
│                                                                                                  │
│    36 │   │                                                                                      │
│    37 │   │   @functools.wraps(wrapped)                                                          │
│    38 │   │   def wrapper_function(*args, **kwargs):                                             │
│ ❱  39 │   │   │   return wrapper(*args, **kwargs)                                                │
│    40 │                                                                                          │
│    41 │   # We need to manually update this because `partial` object has no `__name__` and `__   │
│    42 │   wrapper_function.__name__ = extract_function_name(wrapped)                             │
│                                                                                                  │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/pydantic/_internal/_v │
│ alidate_call.py:136 in __call__                                                                  │
│                                                                                                  │
│   133 │   │   if not self.__pydantic_complete__:                                                 │
│   134 │   │   │   self._create_validators()                                                      │
│   135 │   │                                                                                      │
│ ❱ 136 │   │   res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args,   │
│   137 │   │   if self.__return_pydantic_validator__:                                             │
│   138 │   │   │   return self.__return_pydantic_validator__(res)                                 │
│   139 │   │   else:                                                                              │
│                                                                                                  │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/sagemaker/train/model │
│ _trainer.py:781 in train                                                                         │
│                                                                                                  │
│    778 │   │   │   │   self.sagemaker_session._intercept_create_request(training_request, None,  │
│    779 │   │   │   │   return                                                                    │
│    780 │   │   │                                                                                 │
│ ❱  781 │   │   │   training_job = TrainingJob.create(                                            │
│    782 │   │   │   │   session=self.sagemaker_session.boto_session,                              │
│    783 │   │   │   │   **training_request                                                        │
│    784 │   │   │   )                                                                             │
│                                                                                                  │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/sagemaker/core/resour │
│ ces.py:35418 in wrapper                                                                          │
│                                                                                                  │
│   35415 │   │   │   │   "tensor_board_output_config": {"s3_output_path": {"type": "string"}},    │
│   35416 │   │   │   │   "profiler_config": {"s3_output_path": {"type": "string"}},               │
│   35417 │   │   │   }                                                                            │
│ ❱ 35418 │   │   │   return create_func(                                                          │
│   35419 │   │   │   │   *args,                                                                   │
│   35420 │   │   │   │   **Base.get_updated_kwargs_with_configured_attributes(                    │
│   35421 │   │   │   │   │   config_schema_for_resource, "TrainingJob", **kwargs                  │
│                                                                                                  │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/sagemaker/core/resour │
│ ces.py:143 in wrapper                                                                            │
│                                                                                                  │
│     140 │   │   @functools.wraps(func)                                                           │
│     141 │   │   def wrapper(*args, **kwargs):                                                    │
│     142 │   │   │   config = dict(arbitrary_types_allowed=True)                                  │
│ ❱   143 │   │   │   return validate_call(config=config)(func)(*args, **kwargs)                   │
│     144 │   │                                                                                    │
│     145 │   │   return wrapper                                                                   │
│     146                                                                                          │
│                                                                                                  │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/pydantic/_internal/_v │
│ alidate_call.py:39 in wrapper_function                                                           │
│                                                                                                  │
│    36 │   │                                                                                      │
│    37 │   │   @functools.wraps(wrapped)                                                          │
│    38 │   │   def wrapper_function(*args, **kwargs):                                             │
│ ❱  39 │   │   │   return wrapper(*args, **kwargs)                                                │
│    40 │                                                                                          │
│    41 │   # We need to manually update this because `partial` object has no `__name__` and `__   │
│    42 │   wrapper_function.__name__ = extract_function_name(wrapped)                             │
│                                                                                                  │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/pydantic/_internal/_v │
│ alidate_call.py:136 in __call__                                                                  │
│                                                                                                  │
│   133 │   │   if not self.__pydantic_complete__:                                                 │
│   134 │   │   │   self._create_validators()                                                      │
│   135 │   │                                                                                      │
│ ❱ 136 │   │   res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args,   │
│   137 │   │   if self.__return_pydantic_validator__:                                             │
│   138 │   │   │   return self.__return_pydantic_validator__(res)                                 │
│   139 │   │   else:                                                                              │
│                                                                                                  │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/sagemaker/core/resour │
│ ces.py:35595 in create                                                                           │
│                                                                                                  │
│   35592 │   │   logger.debug(f"Serialized input request: {operation_input_args}")                │
│   35593 │   │                                                                                    │
│   35594 │   │   # create the resource                                                            │
│ ❱ 35595 │   │   response = client.create_training_job(**operation_input_args)                    │
│   35596 │   │   logger.debug(f"Response: {response}")                                            │
│   35597 │   │                                                                                    │
│   35598 │   │   return cls.get(training_job_name=training_job_name, session=session, region=regi │
│                                                                                                  │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/botocore/client.py:60 │
│ 2 in _api_call                                                                                   │
│                                                                                                  │
│    599 │   │   │   │   │   f"{py_operation_name}() only accepts keyword arguments."              │
│    600 │   │   │   │   )                                                                         │
│    601 │   │   │   # The "self" in this scope is referring to the BaseClient.                    │
│ ❱  602 │   │   │   return self._make_api_call(operation_name, kwargs)                            │
│    603 │   │                                                                                     │
│    604 │   │   _api_call.__name__ = str(py_operation_name)                                       │
│    605                                                                                           │
│                                                                                                  │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/botocore/context.py:1 │
│ 23 in wrapper                                                                                    │
│                                                                                                  │
│   120 │   │   │   with start_as_current_context():                                               │
│   121 │   │   │   │   if hook:                                                                   │
│   122 │   │   │   │   │   hook()                                                                 │
│ ❱ 123 │   │   │   │   return func(*args, **kwargs)                                               │
│   124 │   │                                                                                      │
│   125 │   │   return wrapper                                                                     │
│   126                                                                                            │
│                                                                                                  │
│ /Users/bpistone/miniconda3/envs/python312-sm3/lib/python3.12/site-packages/botocore/client.py:10 │
│ 78 in _make_api_call                                                                             │
│                                                                                                  │
│   1075 │   │   │   │   'error_code_override'                                                     │
│   1076 │   │   │   ) or error_info.get("Code")                                                   │
│   1077 │   │   │   error_class = self.exceptions.from_code(error_code)                           │
│ ❱ 1078 │   │   │   raise error_class(parsed_response, operation_name)                            │
│   1079 │   │   else:                                                                             │
│   1080 │   │   │   return parsed_response                                                        │
│   1081                                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ClientError: An error occurred (ValidationException) when calling the CreateTrainingJob operation: InstanceType or 
InstanceCount cannot be specified with InstanceGroups

System information
A description of your system. Please provide:

  • SageMaker Python SDK version: 3.4.1
  • Framework name: PyTorch
  • Framework version: 2.8.0
  • Python version: 3.12.12
  • CPU or GPU: CPU
  • Custom Docker image: N

Additional context
Suggested fix in sagemaker-train/src/sagemaker/train/defaults.py:

# In Defaults.get_compute() (line 103)
if not compute.instance_groups:
    if compute.instance_type is None:
        compute.instance_type = DEFAULT_INSTANCE_TYPE
    if compute.instance_count is None:
        compute.instance_count = DEFAULT_INSTANCE_COUNT
