Image segmentation with SAM

Facebook Research's Segment Anything Model (SAM) is highly capable at segmenting objects in images. In this tutorial we'll deploy Segment Anything to a REST endpoint to make an image segmentation API.

We'll follow the sample code in Segment Anything's getting started notebook to initialize the model and add some helper functions. Then we'll deploy to Modelbit.

Environment setup

Install Segment Anything from their git repository:

pip install git+https://github.com/facebookresearch/segment-anything.git

Install the packages we'll need to run Segment Anything:

pip install --extra-index-url=https://download.pytorch.org/whl/ \
torch==2.0.1+cpu torchvision==0.15.2 opencv-python==4.8.0.74 matplotlib==3.7.2
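
Before moving on, you can sanity-check that the packages import with the versions we pinned (a quick check we're adding here, not part of the original notebook):

import torch, torchvision, cv2, matplotlib

print(torch.__version__)        # expect 2.0.1+cpu
print(torchvision.__version__)  # expect 0.15.2
print(cv2.__version__)          # expect 4.8.0 (from opencv-python 4.8.0.74)
print(matplotlib.__version__)   # expect 3.7.2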

And download a Segment Anything checkpoint file:

wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

Helper functions

Now we'll define several helper functions to process image data with Segment Anything.

These functions paint the segmentation mask onto the original image, so it's easy for us humans to see:

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

def show_mask(mask, ax, random_color=False):
    # Overlay the mask on the plot as a semi-transparent color
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    # Draw positive points as yellow stars and negative points as red stars
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='yellow', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

And these functions draw a bounding box around the segmentation mask:

def get_bounding_box(mask):
    # Scan the mask for the smallest and largest x and y coordinates
    # of any pixel in the segment
    x_min = None
    x_max = None
    y_min = None
    y_max = None
    for y, row in enumerate(mask):
        for x, val in enumerate(row):
            if val:
                if x_min is None or x_min > x:
                    x_min = x
                if y_min is None or y_min > y:
                    y_min = y
                if x_max is None or x_max < x:
                    x_max = x
                if y_max is None or y_max < y:
                    y_max = y
    return x_min, y_min, x_max, y_max

def show_box(box, ax):
    # Draw the bounding box as a yellow rectangle on the plot
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0),
                      w,
                      h,
                      edgecolor='yellow',
                      facecolor=(0, 0, 0, 0),
                      lw=2))
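
As an aside, the nested loops in get_bounding_box visit every pixel in pure Python, which can be slow on large masks. A vectorized equivalent using np.where (our own alternative sketch, not from the original tutorial) looks like this:

def get_bounding_box_vectorized(mask):
    # np.where on a 2D boolean mask returns the row (y) and column (x)
    # indices of every True pixel
    ys, xs = np.where(mask)
    if len(xs) == 0:
        return None, None, None, None  # empty mask: no bounding box
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

It returns the same (x_min, y_min, x_max, y_max) tuple for any non-empty mask.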

Load the model

Loading the Segment Anything model involves loading the checkpoint file we downloaded earlier and initializing a SamPredictor:

from segment_anything import SamPredictor, sam_model_registry
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam)
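
This runs the model on CPU, matching the +cpu torch build we installed. If your environment instead has a CUDA-enabled torch build and a GPU, you could move the model to the GPU before creating the predictor; a minimal sketch, assuming a GPU is available:

import torch
from segment_anything import SamPredictor, sam_model_registry

device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
sam.to(device=device)  # SAM is a torch.nn.Module, so .to() moves its weights
predictor = SamPredictor(sam)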

Make a function to segment images

As the last step before we deploy, we need to use all the previous functions to segment an image, draw the segmentation graphics onto the image, and return the bounding box of the segment.

We'll bundle this logic into a function that takes the image URL we want to segment, and the X and Y coordinates of a point inside the image where our segment should focus:

import requests

def segment_image(image_url, point_x, point_y):
    # download the image into cv2
    resp = requests.get(image_url)
    image = np.asarray(bytearray(resp.content))
    image = cv2.imdecode(image, -1)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # load the image into Segment Anything and get the segment mask
    predictor.set_image(image)
    input_point = np.array([[point_x, point_y]])
    input_label = np.array([1])
    mask, score, logit = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    mask = mask[0]

    # Draw the image, the point we selected, the mask, and the bounding box
    fig = plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    bbox = get_bounding_box(mask)
    show_box(bbox, plt.gca())
    plt.show()
    mb.log_image(fig)  # show the segmented image in Modelbit logs

    # return the bounding box of our segment
    return bbox

We're calling mb.log_image in the above code so that images generated inside the REST API are recorded in Modelbit's logs.

Test and deploy the segment API

Log into Modelbit:

import modelbit

mb = modelbit.login()

Then test that the segment_image function works. We'll use a picture of a successful shopping trip and segment out a bag of groceries:

image_url = "https://doc.modelbit.com/img/groceries.jpg"
segment_image(image_url, 400, 275)

Which we can see in our notebook output: the bag of groceries highlighted with the segmentation mask, the selected point, and its yellow bounding box.

Finally, deploy the segment_image function to Modelbit. We'll include the git repo and the versions of torch we're using in python_packages:

mb.deploy(segment_image, python_packages=[
    "git+https://github.com/facebookresearch/segment-anything.git",
    "torch==2.0.1+cpu",
    "torchvision==0.15.2"
])

Once deployed, call the segmentation API over REST to get bounding boxes of image segments!

curl -s -XPOST "https://<your-workspace-url>/v1/segment_image/latest" \
-d '{"data": ["https://doc.modelbit.com/img/groceries.jpg", 400, 275]}' | json_pp
{
  "data": [367, 166, 445, 341]
}
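
You can also call the endpoint from Python instead of curl; a minimal sketch using requests (replace <your-workspace-url> with your actual workspace URL):

import requests

url = "https://<your-workspace-url>/v1/segment_image/latest"
payload = {"data": ["https://doc.modelbit.com/img/groceries.jpg", 400, 275]}
response = requests.post(url, json=payload)
print(response.json())  # e.g. {"data": [367, 166, 445, 341]}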
Tip

Your Modelbit workspace may not be configured with enough resources for large image models like Segment Anything. If you hit an Out of Memory issue, contact Modelbit Support to upgrade your environment.