Image segmentation with SAM 2

Facebook Research's Segment Anything 2 (SAM 2) is a state-of-the-art model for segmenting objects in images and videos. In this tutorial we'll deploy the image predictor to a REST API. Given an image URL and a prompting point, the API will return the bounding box of the highest-salience object at that point.

We'll walk through two different ways to deploy this model to Modelbit. First, we'll use the Ultralytics package, which provides a high-performance runtime optimized for production. This is the quickest and highest-performance way to deploy.

We'll also walk through how to deploy the SAM2 package from Facebook Research's own GitHub account. This is best for those who are learning and want to follow Facebook's own tutorial closely.

Deploying SAM2 with Ultralytics

The fastest and highest-performance way to deploy SAM2 to Modelbit is with the Ultralytics package. Here are the steps to deploy a SAM2 model with Ultralytics from a Python environment to Modelbit. Note that a Colab notebook with at least a T4 GPU is recommended as a development environment.

Notebook setup

Start by installing Modelbit and logging in, of course:

!pip install --upgrade modelbit
import modelbit

mb = modelbit.login()

Additionally, we'll want to install the ultralytics package:

!pip install ultralytics

And import the relevant SAM code from that package:

from ultralytics import SAM

Loading the model and making predictions

To load the model, simply write:

with modelbit.setup("load_model"):
    model = SAM("sam2_b.pt")

The first time you run this code, the SAM function call will download the checkpoint locally. sam2_b.pt corresponds to the base SAM2 checkpoint. Any of the checkpoints will work fine.

Note the use of with modelbit.setup(...) here. This is important so that when we deploy the model, this code is run just once at boot time, and then kept warm.
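For example, if you'd like to try a different checkpoint size, just change the filename passed to SAM. Here's a hedged example assuming Ultralytics' SAM2 naming, where tiny, small, base, and large variants are available; check the Ultralytics docs for the current list of filenames:

# Example: load the large SAM2 checkpoint instead of the base one
with modelbit.setup("load_model"):
    model = SAM("sam2_l.pt")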

To make predictions, simply run:

def find_object(url: str, x_coord: int, y_coord: int):
    results = model(url, points=[x_coord, y_coord], labels=[1])
    # Return the bounding box (x1, y1, x2, y2) of the top result as a plain list
    return results[0].boxes.xyxy[0].tolist()

This code returns the bounding box of the highest-scoring object found at the prompting point. Explore the results object to see what other information you can return!
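Before deploying, you can also try the function locally in the notebook. The image URL and click coordinates below are placeholders, not part of the tutorial; substitute any image you like:

# Hypothetical local test: point at an object in an example image
box = find_object("https://example.com/dog.jpg", 400, 300)
print(box)  # [x1, y1, x2, y2] in pixel coordinates

# Dig into the full Results object for more detail
results = model("https://example.com/dog.jpg", points=[400, 300], labels=[1])
print(results[0].boxes.xyxy)   # bounding boxes of the segmented object(s)
print(results[0].masks.xy[0])  # polygon outline of the first mask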

Deploying the SAM2 model

Finally, let's deploy this model to a REST API in production!

Here's the code:

mb.deploy(find_object, extra_files=["sam2_b.pt"], setup="load_model", require_gpu=True)

Let's walk through it. find_object is the code that runs the inference, so that's our first parameter. We want to include our checkpoint file, sam2_b.pt, so it goes in extra_files. As we mentioned, we want our setup code to run at boot time, so we pass it via setup. And finally, we want a GPU for this deployment, hence require_gpu=True.

That's it! After you run this, click the resulting link to explore the SAM2 REST API you just deployed!
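As a rough sketch of calling that REST API from Python: the workspace URL below is a placeholder, and the exact URL and payload format for your deployment are shown on its API page in Modelbit.

import requests

# Hypothetical endpoint URL; substitute your own workspace's URL from the API page
response = requests.post(
    "https://<your-workspace>.app.modelbit.com/v1/find_object/latest",
    json={"data": ["https://example.com/dog.jpg", 400, 300]},
)
print(response.json())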

Deploying SAM2 from Facebook Research's GitHub

If you want to hew closely to Facebook Research's own GitHub package and guide, follow these instructions.

info

The SAM2 package that Facebook provides on their own GitHub includes a lot of developer tooling that makes the package much larger. This can make it more cumbersome to deploy, and slower to run in production. We recommend using Ultralytics in production. However, you can follow along below to see a deeper look at the model itself and deploy Facebook's own package if you so choose.

Notebook setup

To install into our notebook, we'll follow Facebook's own installation instructions. Note that a Colab notebook with at least a T4 GPU is recommended.

First, clone the repo, install the package, and download the checkpoint files:

!git clone https://github.com/facebookresearch/segment-anything-2.git
!pip install -e segment-anything-2
!segment-anything-2/checkpoints/download_ckpts.sh

As always, make sure Modelbit is installed in the notebook, and login:

!pip install --upgrade modelbit
import modelbit

mb = modelbit.login()

And let's make sure we have our imports in order:

import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import numpy as np
import matplotlib.pyplot as plt
import urllib.request
import cv2
import os

Image helper code

To help us work with the model, we'll want some helper code to calculate bounding boxes, render boxes and masks on images, and so forth. Here's some general purpose image code that we'll use:

# Given the pixels of an image mask, return the mask's bounding box
def mask2boundingbox(mask):
    x_min = None
    x_max = None
    y_min = None
    y_max = None
    for y, row in enumerate(mask):
        for x, val in enumerate(row):
            if val:
                if x_min is None or x_min > x:
                    x_min = x
                if y_min is None or y_min > y:
                    y_min = y
                if x_max is None or x_max < x:
                    x_max = x
                if y_max is None or y_max < y:
                    y_max = y
    return x_min, y_min, x_max, y_max

# Render a mask in matplotlib
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

# Render a point as a star in matplotlib
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

# Render a box in matplotlib
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='red', facecolor=(0,0,0,0), lw=2))

# Render an image in matplotlib, optionally with a prompt point, mask, and bounding box
def show_image(img, points=None, mask=np.array([]), box=()):
    im = plt.figure(figsize=(10, 10))
    plt.imshow(img)
    if points:
        show_points(np.array([[points[0], points[1]]]), np.array([1]), plt.gca())
    if mask.any():
        show_mask(mask, plt.gca())
    if box:
        show_box(box, plt.gca())
    plt.axis('on')
    modelbit.log_image(im)
    plt.show()
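As a quick sanity check of mask2boundingbox, here's a toy example (a made-up 4x5 mask, not part of the tutorial):

import numpy as np

# A tiny boolean mask with a 2x3 block of True values
toy_mask = np.zeros((4, 5), dtype=bool)
toy_mask[1:3, 1:4] = True

print(mask2boundingbox(toy_mask))  # (1, 1, 3, 2) -> x_min, y_min, x_max, y_max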

Setting up our predictor

Next, let's choose our checkpoint file and load it up. Here we've chosen the large version of the model, but any version will do.

Note that we wrap this code in modelbit.setup. This is important so that when we deploy this code, Modelbit runs the setup code only once, at boot time, and then leaves it warm for you.

with modelbit.setup(name="load_model"):
    if 'PYTHONPATH' in os.environ:
        os.environ['PYTHONPATH'] = f"{os.getcwd()}:{os.environ['PYTHONPATH']}"
    else:
        os.environ['PYTHONPATH'] = os.getcwd()
    checkpoint = "./sam2_hiera_large.pt"
    model_cfg = "sam2_hiera_l.yaml"
    predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))

You'll notice the first four lines of code, which add the current directory to the PYTHONPATH. SAM2 expects the parent directory of the configs to be on the PYTHONPATH, so we need to do that as part of our setup code.

The next 3 lines just load up the checkpoint file into the SAM2 predictor! We want that to run at setup time too, of course, so that inferences are nice and fast afterward.
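If you'd rather use a smaller model, swap the checkpoint and config pair inside the setup block above. Here's a hedged example, assuming the filenames that download_ckpts.sh fetches (adjust to the files you actually downloaded):

# Example: the small variant instead of large
checkpoint = "./sam2_hiera_small.pt"
model_cfg = "sam2_hiera_s.yaml"
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))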

Finding an object in an image

Now that we have our predictor, let's use it to find objects in an image! Here's the code:

def find_object(url: str, x_coord: int, y_coord: int):
    url_response = urllib.request.urlopen(url)
    img = cv2.imdecode(np.array(bytearray(url_response.read()), dtype=np.uint8), -1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    predictor.set_image(img)

    masks, scores, logits = predictor.predict(
        point_coords=np.array([[x_coord, y_coord]]),
        point_labels=np.array([1]),
        multimask_output=True,
    )

    top_score = 0
    best_mask = None
    for i, score in enumerate(scores):
        if score > top_score:
            top_score = score
            best_mask = masks[i]

    bbox = mask2boundingbox(best_mask)
    show_image(img, (x_coord, y_coord), best_mask, bbox)

    return bbox

This function downloads the image from the specified URL, runs the all-important predictor.predict(...), picks the highest-scoring mask, and returns its bounding box!
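You can try it locally in the notebook before deploying. The URL and coordinates here are placeholders; point at any object in any image you like:

# Hypothetical local test
bbox = find_object("https://example.com/dog.jpg", 400, 300)
print(bbox)  # (x_min, y_min, x_max, y_max) of the segmented object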

Deploying to production

Finally, let's deploy! Since find_object is the name of the method that performs the prediction, that's what we'll deploy. Here's the full call to modelbit.deploy:

mb.deploy(find_object,
          extra_files={
              'sam2_hiera_large.pt': 'sam2_hiera_large.pt',
              'segment-anything-2/sam2_configs': 'sam2_configs'
          },
          python_packages=['git+https://github.com/facebookresearch/segment-anything-2.git'],
          setup='load_model',
          require_gpu=True)

Let's walk through this step by step. find_object is the actual code that finds the object, so that's our first parameter. We want to make sure to bring along our checkpoint file as well as model configs, so we include those as extra_files. Of course we need the SAM2 module, so that's included in python_packages. We want to run the setup code that loads the model that was in that modelbit.setup(...) block earlier, so we specify that with the setup parameter. And finally, we want a GPU for this deployment, so we set require_gpu=True.

That's it! Modelbit will take our SAM2 predictor and deploy it behind a REST API.