Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: update JumpStart object detection example notebook to use JumpStart Model #4722

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install sagemaker ipywidgets --upgrade --quiet"
"!pip install sagemaker jupyterlab --upgrade --quiet\n",
"!pip install ipywidgets==7.6.5"
]
},
{
Expand Down Expand Up @@ -234,9 +235,7 @@
"metadata": {},
"outputs": [],
"source": [
"from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n",
"from sagemaker.model import Model\n",
"from sagemaker.predictor import Predictor\n",
"from sagemaker.jumpstart.model import JumpStartModel\n",
"from sagemaker.utils import name_from_base\n",
"\n",
"# model_version=\"*\" fetches the latest version of the model\n",
Expand All @@ -247,45 +246,17 @@
"\n",
"inference_instance_type = \"ml.p2.xlarge\"\n",
"\n",
"# Retrieve the inference docker container uri\n",
"deploy_image_uri = image_uris.retrieve(\n",
" region=None,\n",
" framework=None, # automatically inferred from model_id\n",
" image_scope=\"inference\",\n",
"# Create the SageMaker JumpStart model instance\n",
"model = JumpStartModel(\n",
" model_id=infer_model_id,\n",
" model_version=infer_model_version,\n",
" instance_type=inference_instance_type,\n",
")\n",
"\n",
"# Retrieve the inference script uri. This includes scripts for model loading, inference handling etc.\n",
"deploy_source_uri = script_uris.retrieve(\n",
" model_id=infer_model_id, model_version=infer_model_version, script_scope=\"inference\"\n",
")\n",
"\n",
"\n",
"# Retrieve the base model uri\n",
"base_model_uri = model_uris.retrieve(\n",
" model_id=infer_model_id, model_version=infer_model_version, model_scope=\"inference\"\n",
")\n",
"\n",
"\n",
"# Create the SageMaker model instance\n",
"model = Model(\n",
" image_uri=deploy_image_uri,\n",
" source_dir=deploy_source_uri,\n",
" model_data=base_model_uri,\n",
" entry_point=\"inference.py\", # entry point file in source_dir and present in deploy_source_uri\n",
" role=aws_role,\n",
" predictor_cls=Predictor,\n",
" name=endpoint_name,\n",
")\n",
"\n",
"# deploy the Model. Note that we need to pass Predictor class when we deploy model through Model class,\n",
"# for being able to run inference through the sagemaker API.\n",
"base_model_predictor = model.deploy(\n",
" initial_instance_count=1,\n",
" instance_type=inference_instance_type,\n",
" predictor_cls=Predictor,\n",
" endpoint_name=endpoint_name,\n",
")"
]
Expand Down Expand Up @@ -355,8 +326,7 @@
" return query_response\n",
"\n",
"\n",
"def parse_response(query_response):\n",
" model_predictions = json.loads(query_response)\n",
"def parse_response(model_predictions):\n",
" normalized_boxes, classes, scores, labels = (\n",
" model_predictions[\"normalized_boxes\"],\n",
" model_predictions[\"classes\"],\n",
Expand Down Expand Up @@ -837,8 +807,10 @@
"outputs": [],
"source": [
"query_response = query(finetuned_predictor, pedestrian_image_file_name)\n",
"model_predictions = json.loads(query_response)\n",
"\n",
"\n",
"normalized_boxes, classes_names, confidences = parse_response(query_response)\n",
"normalized_boxes, classes_names, confidences = parse_response(model_predictions)\n",
"display_predictions(pedestrian_image_file_name, normalized_boxes, classes_names, confidences)"
]
},
Expand Down