diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py index c4fadec6a16a7..c99228a3e425b 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py @@ -335,7 +335,7 @@ def execute(self, context: Context) -> dict: if self.deferrable and self.wait_for_completion: response = self.hook.describe_processing_job(self.config["ProcessingJobName"]) status = response["ProcessingJobStatus"] - if status in self.hook.failed_states: + if status in self.hook.processing_job_failed_states: raise AirflowException(f"SageMaker job failed because {response['FailureReason']}") if status == "Completed": self.log.info("%s completed successfully.", self.task_id) diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_processing.py b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_processing.py index a18058c745e42..3bddcf88f3c0d 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_processing.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_processing.py @@ -319,6 +319,34 @@ def test_operator_failed_before_defer( assert not mock_defer.called + @mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator.defer") + @mock.patch.object( + SageMakerHook, + "describe_processing_job", + return_value={"ProcessingJobStatus": "Stopped", "FailureReason": "It stopped"}, + ) + @mock.patch.object( + SageMakerHook, + "create_processing_job", + return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}, + ) + @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False) + def test_operator_stopped_before_defer( + self, + mock_job_exists, + mock_processing, + mock_describe, + mock_defer, + ): + sagemaker_operator = SageMakerProcessingOperator( + **self.defer_processing_config_kwargs, + config=CREATE_PROCESSING_PARAMS, + ) + with pytest.raises(AirflowException): + sagemaker_operator.execute(context=None) + + assert not mock_defer.called + @mock.patch.object( SageMakerHook, "describe_processing_job",