diff --git a/sagemaker-train/src/sagemaker/train/defaults.py b/sagemaker-train/src/sagemaker/train/defaults.py index 5ba6ddcbb8..3c36b817ec 100644 --- a/sagemaker-train/src/sagemaker/train/defaults.py +++ b/sagemaker-train/src/sagemaker/train/defaults.py @@ -100,12 +100,15 @@ def get_compute(compute: Optional[Compute] = None) -> Compute: volume_size_in_gb=DEFAULT_VOLUME_SIZE, ) logger.info(f"Compute not provided. Using default:\n{compute}") - if compute.instance_type is None: - compute.instance_type = DEFAULT_INSTANCE_TYPE - logger.info(f"Instance type not provided. Using default:\n{DEFAULT_INSTANCE_TYPE}") - if compute.instance_count is None: - compute.instance_count = DEFAULT_INSTANCE_COUNT - logger.info(f"Instance count not provided. Using default:\n{compute.instance_count}") + if not compute.instance_groups: + if compute.instance_type is None: + compute.instance_type = DEFAULT_INSTANCE_TYPE + logger.info(f"Instance type not provided. Using default:\n{DEFAULT_INSTANCE_TYPE}") + if compute.instance_count is None: + compute.instance_count = DEFAULT_INSTANCE_COUNT + logger.info( + f"Instance count not provided. Using default:\n{compute.instance_count}" + ) if compute.volume_size_in_gb is None: compute.volume_size_in_gb = DEFAULT_VOLUME_SIZE logger.info(f"Volume size not provided. Using default:\n{compute.volume_size_in_gb}") @@ -225,11 +228,21 @@ def get_compute( ), ) logger.info(f"Compute not provided. Using default compute:\n{compute}") - if compute.instance_type is None and training_components_model.DefaultTrainingInstanceType: - compute.instance_type = training_components_model.DefaultTrainingInstanceType - logger.info( - f"Instance type not provided. Using default instance type:\n{compute.instance_type}" - ) + if not compute.instance_groups: + if ( + compute.instance_type is None + and training_components_model.DefaultTrainingInstanceType + ): + compute.instance_type = training_components_model.DefaultTrainingInstanceType + logger.info( + f"Instance type not provided. Using default instance type:" + f"\n{compute.instance_type}" + ) + if compute.instance_count is None: + compute.instance_count = DEFAULT_INSTANCE_COUNT + logger.info( + f"Instance count not provided. Using default instance count:\n{compute}" + ) if compute.volume_size_in_gb is None: compute.volume_size_in_gb = ( training_components_model.TrainingVolumeSize or DEFAULT_VOLUME_SIZE @@ -237,9 +250,6 @@ def get_compute( logger.info( f"Volume size not provided. Using default volume size:\n{compute.volume_size_in_gb}" ) - if compute.instance_count is None: - compute.instance_count = DEFAULT_INSTANCE_COUNT - logger.info(f"Instance count not provided. Using default instance count:\n{compute}") return compute def get_networking(