Skip to content

Commit

Permalink
Some sagemaker fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dfangl committed Aug 9, 2024
1 parent 6089f83 commit fe7fabb
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions moto/sagemaker/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4764,15 +4764,15 @@ def create_cluster(
return cluster.arn

def describe_cluster(self, cluster_name: str) -> Dict[str, Any]:
if cluster_name.startswith("arn:aws:sagemaker:"):
if cluster_name.startswith(f"arn:{self.partition}:sagemaker:"):
cluster_name = (cluster_name.split(":")[-1]).split("/")[-1]
cluster = self.clusters.get(cluster_name)
if not cluster:
raise ValidationError(message=f"Could not find cluster '{cluster_name}'.")
return cluster.describe()

def delete_cluster(self, cluster_name: str) -> str:
if cluster_name.startswith("arn:aws:sagemaker:"):
if cluster_name.startswith(f"arn:{self.partition}:sagemaker:"):
cluster_name = (cluster_name.split(":")[-1]).split("/")[-1]
cluster = self.clusters.get(cluster_name)
if not cluster:
Expand All @@ -4783,7 +4783,7 @@ def delete_cluster(self, cluster_name: str) -> str:
return arn

def describe_cluster_node(self, cluster_name: str, node_id: str) -> Dict[str, Any]:
if cluster_name.startswith("arn:aws:sagemaker:"):
if cluster_name.startswith(f"arn:{self.partition}:sagemaker:"):
cluster_name = (cluster_name.split(":")[-1]).split("/")[-1]
cluster = self.clusters.get(cluster_name)
if not cluster:
Expand Down Expand Up @@ -4832,7 +4832,7 @@ def list_cluster_nodes(
sort_by: Optional[str],
sort_order: Optional[str],
) -> List[ClusterNode]:
if cluster_name.startswith("arn:aws:sagemaker:"):
if cluster_name.startswith(f"arn:{self.partition}:sagemaker:"):
cluster_name = (cluster_name.split(":")[-1]).split("/")[-1]
cluster = self.clusters.get(cluster_name)
if not cluster:
Expand Down Expand Up @@ -6052,9 +6052,7 @@ def response_create(self) -> Dict[str, str]:

@staticmethod
def arn_formatter(name: str, account_id: str, region: str) -> str:
return (
f"arn:aws:sagemaker:{region}:{account_id}:model-bias-job-definition/{name}"
)
return f"arn:{get_partition(region)}:sagemaker:{region}:{account_id}:model-bias-job-definition/{name}"

@property
def summary_object(self) -> Dict[str, str]:
Expand Down

0 comments on commit fe7fabb

Please sign in to comment.