Make predictions with scikit-learn models in ONNX format


Open Neural Network Exchange (ONNX) provides a uniform format designed to represent models from any machine learning framework. BigQuery ML support for ONNX allows you to:

  • Train a model using your favorite framework.
  • Convert the model into ONNX model format. For more information, see Converting to ONNX format.
  • Import the ONNX model into BigQuery and make predictions using BigQuery ML.

This tutorial shows you how to import ONNX models trained with scikit-learn into a BigQuery dataset and use them to make predictions from a SQL query. You can import ONNX models using these interfaces:

  • The Google Cloud console
  • The bq query command in the bq command-line tool
  • The BigQuery API

For more information about importing ONNX models into BigQuery, including format and storage requirements, see The CREATE MODEL statement for importing ONNX models.


In this tutorial, you will:

  • Create and train models with scikit-learn.
  • Convert the models to ONNX format using sklearn-onnx.
  • Import the ONNX models into BigQuery and make predictions.

Train a classification model with scikit-learn

Create and train a scikit-learn pipeline on the Iris dataset:

import numpy
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier

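data = load_iris()
X = data.data[:, :4]
y = data.target

# Shuffle the rows so the samples are not ordered by class.
ind = numpy.arange(X.shape[0])
numpy.random.shuffle(ind)
X = X[ind, :].copy()
y = y[ind].copy()

pipe = Pipeline([('scaler', StandardScaler()),
                ('clr', RandomForestClassifier())])
pipe.fit(X, y)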

Convert the model into ONNX format and save

Use sklearn-onnx to convert the scikit-learn pipeline into an ONNX model named pipeline_rf.onnx:

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# Disable zipmap as it is not supported in BigQuery ML.
options = {id(pipe): {'zipmap': False}}

# Define input features. scikit-learn does not store information about the
# training dataset. It is not always possible to retrieve the number of features
# or their types. That's why the function needs another argument called initial_types.
initial_types = [
   ('sepal_length', FloatTensorType([None, 1])),
   ('sepal_width', FloatTensorType([None, 1])),
   ('petal_length', FloatTensorType([None, 1])),
   ('petal_width', FloatTensorType([None, 1])),
]

# Convert the model.
model_onnx = convert_sklearn(
   pipe, 'pipeline_rf', initial_types=initial_types, options=options
)

# Save the serialized model to disk.
with open('pipeline_rf.onnx', 'wb') as f:
   f.write(model_onnx.SerializeToString())

Upload the ONNX model to Cloud Storage

Create a Cloud Storage bucket to store the ONNX model file, and then upload the saved ONNX model file to your Cloud Storage bucket. For more information, see Upload objects from a filesystem.
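
One way to script the upload is to call the gcloud CLI from Python. The following is a minimal sketch that only assembles a gcloud storage cp command. It assumes the gcloud CLI is installed and authenticated and that the bucket already exists; the helper name build_upload_command and the bucket name my-onnx-bucket are hypothetical placeholders.

```python
import shlex


def build_upload_command(local_path, bucket, object_name):
    """Return the argv list for copying a local file to Cloud Storage.

    The bucket is assumed to exist already; this helper does not create it.
    """
    destination = f"gs://{bucket}/{object_name}"
    return ["gcloud", "storage", "cp", local_path, destination]


# 'my-onnx-bucket' is a placeholder bucket name, not one from this tutorial.
cmd = build_upload_command("pipeline_rf.onnx", "my-onnx-bucket", "pipeline_rf.onnx")
print(shlex.join(cmd))
```

To run the copy, the list can be passed to subprocess.run(cmd, check=True).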

Import the ONNX model into BigQuery

This step assumes you have uploaded the ONNX model to your Cloud Storage bucket. An example model is stored at gs://cloud-samples-data/bigquery/ml/onnx/pipeline_rf.onnx.


  1. In the Google Cloud console, go to the BigQuery page.

  2. In the query editor, enter a CREATE MODEL statement like the following.

     CREATE OR REPLACE MODEL `mydataset.mymodel`
      OPTIONS (MODEL_TYPE='ONNX',
       MODEL_PATH='gs://bucket/path/to/onnx_model/*')

    For example:

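     CREATE OR REPLACE MODEL `example_dataset.imported_onnx_model`
      OPTIONS (MODEL_TYPE='ONNX',
       MODEL_PATH='gs://cloud-samples-data/bigquery/ml/onnx/pipeline_rf.onnx')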

    The preceding query imports the ONNX model located at gs://cloud-samples-data/bigquery/ml/onnx/pipeline_rf.onnx as a BigQuery model named imported_onnx_model.

  3. Your new model should now appear in the Resources panel. As you expand each of the datasets in a project, models are listed along with the other BigQuery resources in the datasets. Models are indicated by the model icon.

  4. If you select the new model in the Resources panel, information about the model appears below the Query editor.

    [Figure: ONNX model info]


To import an ONNX model from Cloud Storage, run a batch query by entering a command like the following:

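bq query \
--use_legacy_sql=false \
'CREATE MODEL `mydataset.mymodel`
OPTIONS (MODEL_TYPE="ONNX",
  MODEL_PATH="gs://bucket/path/to/onnx_model/*")'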

For example:

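bq query --use_legacy_sql=false \
'CREATE OR REPLACE MODEL `example_dataset.imported_onnx_model`
OPTIONS (MODEL_TYPE="ONNX",
  MODEL_PATH="gs://cloud-samples-data/bigquery/ml/onnx/pipeline_rf.onnx")'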

After you import the model, it appears in the output of bq ls [dataset_name]:

$ bq ls example_dataset

       tableId          Type    Labels   Time Partitioning
 --------------------- ------- -------- -------------------
  imported_onnx_model   MODEL


Insert a new job and populate the jobs#configuration.query property, as in the following request body:

  "query": "CREATE MODEL `project_id.mydataset.mymodel` OPTIONS(MODEL_TYPE='ONNX', MODEL_PATH='gs://bucket/path/to/onnx_model/*')"

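For illustration, a request body for the jobs.insert method can be assembled in Python with the standard json module. This is a minimal sketch, not a complete API client: it nests the statement under configuration.query and sets useLegacySql to false so the statement is parsed as GoogleSQL. The helper name build_query_job_body is hypothetical, and the model name and path are the placeholders used earlier in this tutorial.

```python
import json


def build_query_job_body(sql):
    """Wrap a GoogleSQL statement in a jobs.insert request body."""
    return {
        "configuration": {
            "query": {
                "query": sql,
                # CREATE MODEL is GoogleSQL, so legacy SQL must be off.
                "useLegacySql": False,
            }
        }
    }


# Placeholder model name and Cloud Storage path.
sql = (
    "CREATE MODEL `mydataset.mymodel` "
    "OPTIONS(MODEL_TYPE='ONNX', "
    "MODEL_PATH='gs://bucket/path/to/onnx_model/*')"
)
body = build_query_job_body(sql)
print(json.dumps(body, indent=2))
```

The same helper can wrap any other statement, such as an ML.PREDICT query.
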
Make predictions with the imported ONNX model


  1. In the Google Cloud console, go to the BigQuery page.

  2. In the query editor, enter a query using ML.PREDICT like the following.

     SELECT *
       FROM ML.PREDICT(MODEL `example_dataset.imported_onnx_model`,
          (
           SELECT * FROM `bigquery-public-data.ml_datasets.iris`
          )
       )

    The preceding query uses the model named imported_onnx_model in the dataset example_dataset in the current project to make predictions from input data in the public table iris, from the dataset ml_datasets in the project bigquery-public-data. The ONNX model expects four float inputs: sepal_length, sepal_width, petal_length, and petal_width. These match the initial_types defined when the model was converted to ONNX, so the subquery selects the whole iris table, which contains these four input columns.

    The model outputs the columns label and probabilities, as well as the columns from the input table.

    • label represents the predicted class label.
    • probabilities is an array containing the predicted probability for each class.

    The query result is similar to the following:

    [Figure: query results]
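
To illustrate how the two output columns typically relate for a classifier, the following sketch picks the class with the highest probability from one row's probabilities array. It assumes the array is ordered like the Iris class labels used in training (0, 1, 2). The helper name is hypothetical, and the probability values are made up for illustration; they are not actual query output.

```python
def most_likely_class(probabilities, classes=(0, 1, 2)):
    """Return the (class, probability) pair with the highest probability."""
    if len(probabilities) != len(classes):
        raise ValueError("expected one probability per class")
    best = max(range(len(classes)), key=lambda i: probabilities[i])
    return classes[best], probabilities[best]


# Illustrative values only; not taken from a real query result.
row_probabilities = [0.02, 0.91, 0.07]
label, confidence = most_likely_class(row_probabilities)
print(label, confidence)
```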


To make predictions from input data in the table input_data, enter a command like the following, using the imported ONNX model my_model:

bq query \
--use_legacy_sql=false \
'SELECT *
 FROM ML.PREDICT(
   MODEL `my_project.my_dataset.my_model`,
   (SELECT * FROM input_data))'

For example:

bq query \
--use_legacy_sql=false \
'SELECT *
 FROM ML.PREDICT(
  MODEL `example_dataset.imported_onnx_model`,
  (SELECT * FROM `bigquery-public-data.ml_datasets.iris`))'
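
Because the model's input names correspond to columns produced by the subquery, it can be clearer to select the feature columns explicitly instead of using SELECT *. The following sketch builds such a query string in Python. The helper name predict_sql is hypothetical, and the column list is the one declared in initial_types.

```python
FEATURES = ["sepal_length", "sepal_width", "petal_length", "petal_width"]


def predict_sql(model, table, features=FEATURES):
    """Build an ML.PREDICT query that passes only the listed columns."""
    columns = ", ".join(features)
    return (
        "SELECT * FROM ML.PREDICT(\n"
        f"  MODEL `{model}`,\n"
        f"  (SELECT {columns} FROM `{table}`))"
    )


query = predict_sql(
    "example_dataset.imported_onnx_model",
    "bigquery-public-data.ml_datasets.iris",
)
print(query)
```

The resulting string can be passed to bq query or placed in an API request body.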


Insert a new job and populate the jobs#configuration.query property as in the following request body:

  "query": "SELECT * FROM ML.PREDICT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM input_data))"

