Image Segmentation with SAM

Facebook Research's Segment Anything Model (SAM) is highly capable at segmenting objects in images.

In this tutorial, we'll use this model to demonstrate how to deploy a computer vision model to a containerized REST endpoint with Modelbit using Git.

Setup

Ensure that you already have a Modelbit account, and that you've cloned the repo locally using modelbit clone.

Segment Anything depends on a few Python packages that you'll want to install locally for local testing.

First, there's the Segment Anything library itself, from Facebook Research.

pip install git+https://github.com/facebookresearch/segment-anything.git

You'll want versions of torch, torchvision, opencv-python and matplotlib appropriate to your system. These versions work well on CPU:

pip install --extra-index-url=https://download.pytorch.org/whl/ torch==2.0.1+cpu torchvision==0.15.2 opencv-python==4.8.0.74 matplotlib==3.7.2
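Before moving on, it can help to confirm which versions actually landed in your environment. The snippet below is a minimal sketch using only the standard library; the helper name installed_versions is hypothetical and not part of Modelbit or Segment Anything.

```python
from importlib import metadata


def installed_versions(package_names):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in package_names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Report the packages pinned in the pip command above.
for name, version in installed_versions(
    ["torch", "torchvision", "opencv-python", "matplotlib"]
).items():
    print(f"{name}: {version or 'not installed'}")
```

If any package prints "not installed", re-run the pip command before continuing.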

Finally, let's create the deployment directory. From the root modelbit directory of your clone, run:

mkdir deployments/find_object
wget -P deployments/find_object https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

That second line will pull down the SAM checkpoint file that we'll use to instantiate the model.

Source Code

As with most computer vision models, using SAM requires some image manipulation code for converting images to the right format, calculating bounding boxes from masks, and so forth.

For convenience let's put that code into a source file called image_utils.py in the find_object deployment directory:

deployments/find_object/image_utils.py
import matplotlib.pyplot as plt
import modelbit
import numpy as np

def mask2boundingbox(mask):
    # Scan every pixel and track the extremes of the masked region.
    x_min = None
    x_max = None
    y_min = None
    y_max = None
    for y, row in enumerate(mask):
        for x, val in enumerate(row):
            if val:
                if x_min is None or x_min > x:
                    x_min = x
                if y_min is None or y_min > y:
                    y_min = y
                if x_max is None or x_max < x:
                    x_max = x
                if y_max is None or y_max < y:
                    y_max = y
    return x_min, y_min, x_max, y_max

def show_image(img, points=None, mask=None, box=()):
    # Plot the image with an optional prompt point, mask overlay and bounding box.
    im = plt.figure(figsize=(10, 10))
    plt.imshow(img)
    if points:
        show_points(np.array([[points[0], points[1]]]), np.array([1]), plt.gca())
    if mask is not None and mask.any():
        show_mask(mask, plt.gca())
    if box:
        show_box(box, plt.gca())
    plt.axis('on')
    modelbit.log_image(im)
    plt.show()

def show_points(coords, labels, ax, marker_size=375):
    # Label 1 marks a foreground point (green), label 0 a background point (red).
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_mask(mask, ax, random_color=False):
    # Overlay the mask as a semi-transparent RGBA image.
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_box(box, ax):
    # box is (x_min, y_min, x_max, y_max), as returned by mask2boundingbox.
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='red', facecolor=(0, 0, 0, 0), lw=2))
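The nested loops in mask2boundingbox visit every pixel in pure Python, which can be slow on large images. As an optional alternative, here is a minimal sketch of the same calculation using NumPy's vectorized operations. The name mask2boundingbox_fast is hypothetical, and the sketch assumes the mask is a 2-D array indexed as mask[y][x], as in the loop version.

```python
import numpy as np


def mask2boundingbox_fast(mask):
    """Return (x_min, y_min, x_max, y_max) of the truthy pixels in a 2-D mask."""
    # np.nonzero returns row indices (y) first, then column indices (x).
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        # Mirror the loop version, which returns all None for an empty mask.
        return None, None, None, None
    # Convert NumPy integers to plain ints so the result serializes cleanly.
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())


# A 5x7 mask with a filled rectangle spanning rows 1-3 and columns 2-5.
example = np.zeros((5, 7), dtype=bool)
example[1:4, 2:6] = True
print(mask2boundingbox_fast(example))  # (2, 1, 5, 3)
```

Both versions return the same tuple, so either one works with show_box.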

Additionally, we'll want a file called source.py that contains the entry point for the inference code. That code will load the model, perform the inference, manipulate the inputs and outputs as desired, and return the result:

deployments/find_object/source.py
# All the code outside of a function is run once, at boot time, and does not impact inference performance
import sys
import urllib.request

import cv2
import numpy as np
from segment_anything import SamPredictor, sam_model_registry

from image_utils import mask2boundingbox, show_image

# For optimal performance make sure this isn't inside `find_object`!
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam)

# find_object is called directly at inference time in production
def find_object(url: str, x_coord: int, y_coord: int):
    # Download the image and decode it as a 3-channel BGR array
    url_response = urllib.request.urlopen(url)
    img = cv2.imdecode(
        np.array(bytearray(url_response.read()), dtype=np.uint8), cv2.IMREAD_COLOR)
    # SAM expects RGB, while OpenCV decodes to BGR
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    predictor.set_image(img)

    # Prompt SAM with a single foreground point (label 1)
    masks, scores, logits = predictor.predict(
        point_coords=np.array([[x_coord, y_coord]]),
        point_labels=np.array([1]),
        multimask_output=True,
    )

    # Keep the highest-scoring of the candidate masks
    top_score = 0
    best_mask = None
    for i, score in enumerate(scores):
        if score > top_score:
            top_score = score
            best_mask = masks[i]

    bbox = mask2boundingbox(best_mask)
    show_image(img, (x_coord, y_coord), best_mask, bbox)

    return bbox

# This code only runs during local testing.
if __name__ == '__main__':
    print(find_object(sys.argv[1], int(sys.argv[2]), int(sys.argv[3])))
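The loop that picks the highest-scoring mask can also be written with np.argmax. The sketch below uses stand-in arrays shaped like the masks and scores that find_object unpacks from predictor.predict, so it runs without loading SAM. The helper name pick_best_mask is hypothetical. One behavioral difference: the loop leaves best_mask as None when no score is above zero, while this version always returns a mask when at least one score exists.

```python
import numpy as np


def pick_best_mask(masks, scores):
    """Return the mask with the highest score, or None if there are no scores."""
    if len(scores) == 0:
        return None
    return masks[int(np.argmax(scores))]


# Stand-in data: three candidate 4x4 masks with one score each.
masks = np.zeros((3, 4, 4), dtype=bool)
masks[1, 1:3, 1:3] = True
scores = np.array([0.71, 0.93, 0.88])

best = pick_best_mask(masks, scores)
print(int(best.sum()))  # 4, because the second mask has four True pixels
```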

At this point you should be able to test this code locally and see it return an object's bounding box! Run it from the deployments/find_object directory so that the relative checkpoint path resolves. For example:

$ python3 source.py https://montereyzoo.org/wp-content/uploads/2017/04/big-cats-6.jpg 225 150
(183, 43, 558, 326)

Local performance depends on your system's resources, so don't be discouraged if it's slow! This model is quite speedy in Modelbit production.

Configuration

Before deploying we'll want to add two more files to configure Modelbit's Python environment and REST API.

The first is a standard requirements.txt file with the model's dependencies. For production we'll use builds of torch and torchvision compiled for CUDA 12.1:

deployments/find_object/requirements.txt
--extra-index-url=https://download.pytorch.org/whl/
git+https://github.com/facebookresearch/segment-anything.git
matplotlib==3.7.1
numpy==1.25.2
opencv-python-headless==4.10.0.82
torch==2.3.0+cu121
torchvision==0.18.0+cu121

The second is metadata.yaml which configures Modelbit's REST API. Here's a basic one that'll do the job:

deployments/find_object/metadata.yaml
owner: <your email address>
runtimeInfo:
  capabilities:
    - gpu=T4
  mainFunction: find_object
  mainFunctionArgs:
    - url:str
    - x_coord:int
    - y_coord:int
  pythonVersion: "3.10"
  systemPackages:
    - git
    - python3-opencv
schemaVersion: 2
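Notice that the mainFunctionArgs entries list the same names and types, in the same order, as the find_object signature in source.py. If you edit one and forget the other, a small local check can catch the drift. This is a minimal sketch, not part of Modelbit: the helper args_match is hypothetical, and find_object here is a stand-in with the tutorial's signature so the snippet runs without loading the model.

```python
import inspect


def find_object(url: str, x_coord: int, y_coord: int):
    """Stand-in with the same signature as find_object in source.py."""
    return None


def args_match(func, declared_args):
    """Check "name:type" strings against func's annotated parameters, in order."""
    actual = []
    for param in inspect.signature(func).parameters.values():
        # Fall back to str() for annotations that have no __name__ attribute.
        type_name = getattr(param.annotation, "__name__", str(param.annotation))
        actual.append(f"{param.name}:{type_name}")
    return actual == list(declared_args)


# The entries copied from metadata.yaml above.
declared = ["url:str", "x_coord:int", "y_coord:int"]
print(args_match(find_object, declared))  # True
```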

Deployment

Now that you have all five files (sam_vit_b_01ec64.pth, image_utils.py, source.py, requirements.txt and metadata.yaml) in your deployments/find_object directory, go ahead and commit them to Git.

To begin, a quick git status should show you that the directory is not added yet:

$ git status
On branch main
Your branch is up to date with 'origin/main'.

Untracked files:
  (use "git add <file>..." to include in what will be committed)
        deployments/find_object/

nothing added to commit but untracked files present (use "git add" to track)

Go ahead and add the files:

$ git add deployments/find_object/
Encrypting 'deployments/find_object/sam_vit_b_01ec64.pth': 386MB [00:01, 380MB/s]
Uploading 'deployments/find_object/sam_vit_b_01ec64.pth': 100%|████████████████████████| 347M/347M [00:08<00:00, 41.9MB/s]

Notice the uploader for the big checkpoint file! Because you ran modelbit clone, Modelbit is uploading the file to high-performance model storage and replacing it with a stub in the repo itself. This will be transparent to you while providing optimal performance.

Your git status should now look like this:

$ git status
On branch main
Your branch is up to date with 'origin/main'.

Changes to be committed:
  (use "git restore --staged <file>..." to unstage)
        new file:   deployments/find_object/image_utils.py
        new file:   deployments/find_object/metadata.yaml
        new file:   deployments/find_object/requirements.txt
        new file:   deployments/find_object/sam_vit_b_01ec64.pth
        new file:   deployments/find_object/source.py

Go ahead and commit:

$ git commit -m "adding the object finder model"

Validating deployments...
✅ find_object
✅ predict_weather

[main a3ade0b] adding the object finder model
5 files changed, 124 insertions(+)
create mode 100644 deployments/find_object/image_utils.py
create mode 100644 deployments/find_object/metadata.yaml
create mode 100644 deployments/find_object/requirements.txt
create mode 100644 deployments/find_object/sam_vit_b_01ec64.pth
create mode 100644 deployments/find_object/source.py

As you can see, Modelbit validates client-side that the deployments are correctly configured before you push. In this case, if you already have a getting-started deployment like predict_weather, that one is being re-validated as well.

Finally, to deploy the model, go ahead and push:

$ git push
Connected to workspace <your-workspace>
Enumerating objects: 11, done.
Counting objects: 100% (11/11), done.
Delta compression using up to 12 threads
Compressing objects: 100% (9/9), done.
Writing objects: 100% (9/9), 2.28 KiB | 2.28 MiB/s, done.
Total 9 (delta 2), reused 0 (delta 0), pack-reused 0
remote:
remote: 1 deployment had changes. Modelbit is deploying these changes to production:
remote: - find_object: https://<your-cluster>.modelbit.com/w/<your-workspace>/main/deployments/find-object/overview
remote:
To git.modelbit.com:your-workspace
2fa58a2..a3ade0b main -> main

You'll notice that Modelbit detected a new deployment and deployed it! Click the link in the push output to watch the deployment boot and run, then perform your first inferences!
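Once the deployment is running, you can also call it from code. The following is a minimal sketch using only the standard library. Both the endpoint URL pattern and the {"data": [...]} request body are assumptions here, so copy the exact URL and request format from your deployment's API page in Modelbit. The helper names build_payload and call_find_object are hypothetical.

```python
import json
import urllib.request

# Assumption: placeholder only. Replace with the URL shown for your deployment.
ENDPOINT = "https://<your-workspace>.app.modelbit.com/v1/find_object/latest"


def build_payload(url, x_coord, y_coord):
    """Package the three find_object arguments, in order, as a JSON request body."""
    # Assumption: the arguments are passed as a list under a "data" key.
    return json.dumps({"data": [url, x_coord, y_coord]}).encode("utf-8")


def call_find_object(endpoint, url, x_coord, y_coord, timeout=60):
    """POST the arguments to the endpoint and return the decoded JSON response."""
    request = urllib.request.Request(
        endpoint,
        data=build_payload(url, x_coord, y_coord),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


# Inspect the request body without sending anything over the network.
print(build_payload("https://example.com/cat.jpg", 225, 150).decode("utf-8"))
```

To make a real request, call call_find_object(ENDPOINT, image_url, x, y) after replacing the placeholder endpoint.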