
Batch classification with XGBoost

This example builds on the flower classification with XGBoost example. In that deployment, source.py performs inferences one at a time, calling the model's predict function once per inference request.

In this example we'll use DataFrame mode to run inferences on whole DataFrames.
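To see why batching matters, here is a small sketch that counts predict calls in each style. StubModel and its petal_length rule are hypothetical stand-ins for the real XGBClassifier, used only to make the difference visible without installing anything:

```python
# StubModel is a hypothetical stand-in for the real XGBClassifier; it only
# counts how many times predict() is called.
class StubModel:
    def __init__(self):
        self.calls = 0

    def predict(self, rows):
        self.calls += 1
        # Toy rule for illustration: petal_length (column 2) < 2.5 -> class 0, else class 1
        return [0 if row[2] < 2.5 else 1 for row in rows]


rows = [
    [4.9, 2.4, 3.3, 1.0],
    [5.1, 3.5, 1.4, 0.3],
]

# One-at-a-time style: one predict() call per inference request
per_row_model = StubModel()
per_row_preds = [per_row_model.predict([row])[0] for row in rows]

# DataFrame-mode style: a single predict() call for the whole batch
batch_model = StubModel()
batch_preds = batch_model.predict(rows)

print(per_row_preds, per_row_model.calls)  # [1, 0] 2
print(batch_preds, batch_model.calls)      # [1, 0] 1
```

Both styles produce the same predictions; the batch style simply hands the model all rows at once, which with a real model typically avoids paying per-call overhead for every row.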

Before you begin, make sure you've cloned your Modelbit workspace.

Creating the deployment

We'll call this deployment xgb_iris_dataframe. Create the directory deployments/xgb_iris_dataframe in your Modelbit repo. All files created in this tutorial will be under this directory.

We're going to create four files under deployments/xgb_iris_dataframe/:

  • my_model.pkl: Our example model artifact. It'll be an XGBClassifier
  • source.py: The code we'll use to load and execute the model
  • requirements.txt: The list of pip packages needed in this deployment's environment
  • metadata.yaml: A configuration file that tells Modelbit how to run the deployment
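While you work through the steps, a small checker can confirm the directory is complete. This is a hypothetical helper (missing_deployment_files is not part of Modelbit), assuming you run it from the root of your cloned Modelbit repo:

```python
from pathlib import Path

# The four files this tutorial creates under deployments/xgb_iris_dataframe/
REQUIRED_FILES = ["my_model.pkl", "source.py", "requirements.txt", "metadata.yaml"]


def missing_deployment_files(deployment_dir):
    """Return the names of required files that are not present in deployment_dir."""
    base = Path(deployment_dir)
    return [name for name in REQUIRED_FILES if not (base / name).is_file()]


# Run from the root of your cloned Modelbit repo; an empty list means all four files exist
print(missing_deployment_files("deployments/xgb_iris_dataframe"))
```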

Creating my_model.pkl

Run the following script within deployments/xgb_iris_dataframe. It trains and saves a simple classifier adapted from the XGBoost getting started guide:

from xgboost import XGBClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import pickle

# Load the Iris dataset and hold out 20% of it for testing
data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(data['data'], data['target'], test_size=.2)

# Iris has three classes, so use a multiclass objective
bst = XGBClassifier(n_estimators=2, max_depth=2, learning_rate=1, objective='multi:softprob')
bst.fit(X_train, y_train)

# Save the trained model next to source.py
with open("my_model.pkl", "wb") as f:
    pickle.dump(bst, f)

You can delete this script once my_model.pkl has been created.

info

For future reference, the features in the sample Iris dataset are 4 floats (sepal_length, sepal_width, petal_length, petal_width) with a target class of 0, 1, or 2, corresponding to the Setosa, Versicolor, and Virginica species.
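Because my_model.pkl is trained on a plain NumPy array, the model most likely matches features by column position rather than by name, so the column order above matters. The sketch below, with hypothetical names FEATURE_NAMES, CLASS_NAMES, and to_feature_row, shows one way to put a record's fields in that order and map a class index back to its species:

```python
# Column order of the Iris features as listed above; assumed to be the order
# the model saw during training, since it was trained on an unnamed NumPy array.
FEATURE_NAMES = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
CLASS_NAMES = ["setosa", "versicolor", "virginica"]  # list index = target class 0, 1, 2


def to_feature_row(record):
    """Order a {feature_name: value} dict into the column order the model expects."""
    missing = [name for name in FEATURE_NAMES if name not in record]
    if missing:
        raise KeyError(f"missing features: {missing}")
    return [float(record[name]) for name in FEATURE_NAMES]


# Keys deliberately out of order
record = {"petal_width": 1.0, "sepal_length": 4.9, "petal_length": 3.3, "sepal_width": 2.4}
print(to_feature_row(record))  # [4.9, 2.4, 3.3, 1.0]
print(CLASS_NAMES[1])          # versicolor
```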

Creating source.py

This file contains the function that gets called for each DataFrame sent to the API. By convention, we'll give the main function the same name as the deployment, xgb_iris_dataframe.

deployments/xgb_iris_dataframe/source.py
import pickle
import pandas as pd

# Load the model once, when the deployment starts
with open("my_model.pkl", "rb") as f:
    my_model = pickle.load(f)

flower_class_names = ['setosa', 'versicolor', 'virginica']

# main function
def xgb_iris_dataframe(df: pd.DataFrame) -> list[str]:
    # Run predictions on the whole DataFrame in one call
    predictions = my_model.predict(df)

    # In DataFrame mode, always return an iterable with the same length as the input DataFrame
    return [flower_class_names[p] for p in predictions]

# for local testing
if __name__ == "__main__":
    my_df = pd.DataFrame.from_records([
        {"sepal_length": 4.9, "sepal_width": 2.4, "petal_length": 3.3, "petal_width": 1.0},
        {"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.3}
    ])
    print(xgb_iris_dataframe(my_df))

This deployment's code loads my_model from its pickle file and then executes inferences by passing a DataFrame to the predict function. With DataFrame mode, Modelbit's API will take care of converting the input to a DataFrame before calling your main function.

Running this file locally runs the test code at the bottom, outputting the following:

['versicolor', 'setosa']
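The contract noted in source.py (one result per input row, in the same order) is easy to check locally before deploying. Here's a minimal sketch; check_batch_output and fake_main are hypothetical, and fake_main's petal_length rule only stands in for the real model:

```python
def check_batch_output(inputs, outputs):
    """Raise if a DataFrame-mode style function did not return one result per input row."""
    outputs = list(outputs)
    if len(outputs) != len(inputs):
        raise ValueError(f"expected {len(inputs)} results, got {len(outputs)}")
    return outputs


def fake_main(rows):
    # Hypothetical stand-in for xgb_iris_dataframe: a toy rule on petal_length (column 2)
    return ["setosa" if row[2] < 2.5 else "versicolor" for row in rows]


rows = [[4.9, 2.4, 3.3, 1.0], [5.1, 3.5, 1.4, 0.3]]
print(check_batch_output(rows, fake_main(rows)))  # ['versicolor', 'setosa']
```

In source.py's test block you could call check_batch_output(my_df, xgb_iris_dataframe(my_df)) instead, since len() works on a DataFrame as well as on a list.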

Creating requirements.txt

For Modelbit to run this deployment, it needs to know which packages and versions to install. These requirements get built into a Docker image that'll run source.py. In this case we need xgboost and pandas:

deployments/xgb_iris_dataframe/requirements.txt
pandas==2.2.2
xgboost==2.0.3

Creating metadata.yaml

The last file is the deployment's configuration file, which tells Modelbit how to create the API. In this case we need to specify that we're using DataFrame mode so the API knows how to build the DataFrame it passes to the main function.

We do that by specifying the DataFrame fields in dataframeModeColumns:

deployments/xgb_iris_dataframe/metadata.yaml
owner: you@company.com
runtimeInfo:
  dataframeModeColumns:
    - dtype: float64
      example: 4.9
      name: sepal_length
    - dtype: float64
      example: 2.4
      name: sepal_width
    - dtype: float64
      example: 3.3
      name: petal_length
    - dtype: float64
      example: 1.0
      name: petal_width
  mainFunction: xgb_iris_dataframe
  mainFunctionArgs:
    - df:Any
  pythonVersion: "3.10"
schemaVersion: 2

tip

Make sure to update the owner field to your email address.
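Writing dataframeModeColumns by hand gets repetitive as the number of columns grows. The hypothetical helpers below generate the entries from one example row. They only map Python floats to float64, the one dtype this tutorial uses, and assume two-space YAML indentation under runtimeInfo:

```python
# Hypothetical helpers: derive dataframeModeColumns entries from one example row.
# Only float -> float64 is mapped (the dtype used in this tutorial); other types raise.
DTYPES = {float: "float64"}


def dataframe_mode_columns(example_row):
    columns = []
    for name, value in example_row.items():
        dtype = DTYPES.get(type(value))
        if dtype is None:
            raise TypeError(f"no dtype mapping for {name!r} ({type(value).__name__})")
        columns.append({"dtype": dtype, "example": value, "name": name})
    return columns


def to_yaml_lines(columns):
    # Assumes two-space YAML indentation with dataframeModeColumns nested under runtimeInfo
    lines = ["  dataframeModeColumns:"]
    for col in columns:
        lines.append(f"    - dtype: {col['dtype']}")
        lines.append(f"      example: {col['example']}")
        lines.append(f"      name: {col['name']}")
    return lines


example = {"sepal_length": 4.9, "sepal_width": 2.4, "petal_length": 3.3, "petal_width": 1.0}
print("\n".join(to_yaml_lines(dataframe_mode_columns(example))))
```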

To validate the format of this file, run modelbit validate. It'll print out errors if something looks wrong.
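As an extra local check, not a replacement for modelbit validate, you can confirm that the mainFunction value names a function actually defined in source.py. This sketch uses Python's ast module; defined_functions is a hypothetical helper, and in practice you'd read the real file with Path("deployments/xgb_iris_dataframe/source.py").read_text():

```python
import ast


def defined_functions(source_text):
    """Return the names of top-level functions defined in a Python source string."""
    tree = ast.parse(source_text)
    return {node.name for node in tree.body if isinstance(node, ast.FunctionDef)}


# A trimmed copy of source.py, inlined so the example runs on its own
source_text = '''
import pandas as pd

def xgb_iris_dataframe(df: pd.DataFrame) -> list[str]:
    return []
'''

main_function = "xgb_iris_dataframe"  # value of runtimeInfo.mainFunction in metadata.yaml
print(main_function in defined_functions(source_text))  # True
```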

At this point we're ready to send this deployment to Modelbit!

Deploy to Modelbit

Run the usual git commands to send these files to Modelbit:

git add .
git commit -m "creating xgb_iris_dataframe deployment"
git push

You'll see output like the following:

Enumerating objects: 21, done.
Counting objects: 100% (20/20), done.
Delta compression using up to 16 threads
Compressing objects: 100% (8/8), done.
Writing objects: 100% (9/9), 1.77 KiB | 1.77 MiB/s, done.
Total 9 (delta 2), reused 0 (delta 0), pack-reused 0
remote:
remote:
remote: 1 deployment had changes. Modelbit is deploying these changes to production:
remote: - xgb_iris_dataframe: https://<YOUR_WORKSPACE>/main/deployments/xgb_iris_dataframe/overview
remote:

Click the link to view your deployment in Modelbit.

Call the deployment

You can call the deployment several different ways. Sample code with your specific endpoint URL is available in the API Endpoints tab of your deployment.

When using DataFrame mode, results come back as a list of [ID, result] pairs, where each ID corresponds to a row's index in the input data.
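If you need the results back in input order, for example to add them as a new column, you can sort them by ID. Here's a minimal sketch, assuming the IDs are the 0-based positions of a DataFrame with the default RangeIndex; results_in_input_order is a hypothetical helper:

```python
def results_in_input_order(response, n_rows):
    """Reorder a DataFrame-mode response {"data": [[id, result], ...]} by row ID."""
    by_id = {row_id: result for row_id, result in response["data"]}
    missing = [i for i in range(n_rows) if i not in by_id]
    if missing:
        raise ValueError(f"no result for rows {missing}")
    return [by_id[i] for i in range(n_rows)]


# Same shape as the example return value, deliberately shuffled
response = {"data": [[1, "setosa"], [0, "versicolor"]]}
print(results_in_input_order(response, 2))  # ['versicolor', 'setosa']
```

With a DataFrame like my_df, my_df["species"] = results_in_input_order(response, len(my_df)) would attach one prediction per row.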

Calling your deployment will look like the following:

Using modelbit.get_inference in a Python environment to send a DataFrame:

import modelbit
import pandas as pd

# Make your DataFrame
my_df = pd.DataFrame.from_records([
    {"sepal_length": 4.9, "sepal_width": 2.4, "petal_length": 3.3, "petal_width": 1.0},
    {"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.3}
])

# Call your deployment
modelbit.get_inference(
    deployment="xgb_iris_dataframe",
    workspace="<YOUR_WORKSPACE>",
    region="<YOUR_REGION>",
    data=my_df)

# return value
{"data": [[0, "versicolor"], [1, "setosa"]]}