From eca938fb5caf843dea46c2773cddf70877d80268 Mon Sep 17 00:00:00 2001 From: Mufaddal Rohawala Date: Thu, 19 Feb 2026 17:21:16 -0800 Subject: [PATCH] fix(train): Skip default instance_type/instance_count when instance_groups is set Guard the default injection of instance_type and instance_count in TrainDefaults.get_compute() and JumpStartTrainDefaults.get_compute() so that these values are not populated when instance_groups is configured. The SageMaker API treats instance_type/instance_count and instance_groups as mutually exclusive in ResourceConfig, and unconditionally setting defaults causes a ValidationException. Fixes #5555 --- .../src/sagemaker/train/defaults.py | 38 ++++++++++++------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/defaults.py b/sagemaker-train/src/sagemaker/train/defaults.py index 5ba6ddcbb8..3c36b817ec 100644 --- a/sagemaker-train/src/sagemaker/train/defaults.py +++ b/sagemaker-train/src/sagemaker/train/defaults.py @@ -100,12 +100,15 @@ def get_compute(compute: Optional[Compute] = None) -> Compute: volume_size_in_gb=DEFAULT_VOLUME_SIZE, ) logger.info(f"Compute not provided. Using default:\n{compute}") - if compute.instance_type is None: - compute.instance_type = DEFAULT_INSTANCE_TYPE - logger.info(f"Instance type not provided. Using default:\n{DEFAULT_INSTANCE_TYPE}") - if compute.instance_count is None: - compute.instance_count = DEFAULT_INSTANCE_COUNT - logger.info(f"Instance count not provided. Using default:\n{compute.instance_count}") + if not compute.instance_groups: + if compute.instance_type is None: + compute.instance_type = DEFAULT_INSTANCE_TYPE + logger.info(f"Instance type not provided. Using default:\n{DEFAULT_INSTANCE_TYPE}") + if compute.instance_count is None: + compute.instance_count = DEFAULT_INSTANCE_COUNT + logger.info( + f"Instance count not provided. Using default:\n{compute.instance_count}" + ) if compute.volume_size_in_gb is None: compute.volume_size_in_gb = DEFAULT_VOLUME_SIZE logger.info(f"Volume size not provided. Using default:\n{compute.volume_size_in_gb}") @@ -225,11 +228,21 @@ def get_compute( ), ) logger.info(f"Compute not provided. Using default compute:\n{compute}") - if compute.instance_type is None and training_components_model.DefaultTrainingInstanceType: - compute.instance_type = training_components_model.DefaultTrainingInstanceType - logger.info( - f"Instance type not provided. Using default instance type:\n{compute.instance_type}" - ) + if not compute.instance_groups: + if ( + compute.instance_type is None + and training_components_model.DefaultTrainingInstanceType + ): + compute.instance_type = training_components_model.DefaultTrainingInstanceType + logger.info( + f"Instance type not provided. Using default instance type:" + f"\n{compute.instance_type}" + ) + if compute.instance_count is None: + compute.instance_count = DEFAULT_INSTANCE_COUNT + logger.info( + f"Instance count not provided. Using default instance count:\n{compute}" + ) if compute.volume_size_in_gb is None: compute.volume_size_in_gb = ( training_components_model.TrainingVolumeSize or DEFAULT_VOLUME_SIZE @@ -237,9 +250,6 @@ def get_compute( logger.info( f"Volume size not provided. Using default volume size:\n{compute.volume_size_in_gb}" ) - if compute.instance_count is None: - compute.instance_count = DEFAULT_INSTANCE_COUNT - logger.info(f"Instance count not provided. Using default instance count:\n{compute}") return compute def get_networking(