Exporting Scikit-learn Models to ONNX

Exporting Scikit-learn models to ONNX (Open Neural Network Exchange) lets you run them with tools and runtimes outside Python, including deployment on edge devices, in web services, and in mobile applications.

Key Characteristics

  • Converts Scikit-learn models into an interoperable format
  • Facilitates deployment in non-Python environments
  • The ONNX format can be consumed by multiple runtimes and frameworks (e.g., ONNX Runtime directly; others such as TensorFlow through separate converter packages)
  • Lightweight and optimized for inference

Basic Rules

  • Ensure the onnx and skl2onnx packages are installed
  • Only trained (fitted) models can be converted
  • Input data type and shape must be explicitly defined
  • Use Scikit-learn models for which skl2onnx provides a converter

Syntax Table

| SL NO | Technique | Syntax Example | Description |
|-------|-----------|----------------|-------------|
| 1 | Install Packages | pip install onnx skl2onnx | Installs required conversion libraries |
| 2 | Import Modules | from skl2onnx import convert_sklearn | Imports converter function |
| 3 | Define Input Type | initial_type = [('float_input', FloatTensorType([None, 4]))] | Defines input format for the model |
| 4 | Convert Model | onnx_model = convert_sklearn(model, initial_types=initial_type) | Converts model to ONNX format |
| 5 | Save Model to File | with open('model.onnx', 'wb') as f: f.write(onnx_model.SerializeToString()) | Saves model to disk |

Syntax Explanation

1. Install Packages

What is it?
Installs the necessary packages to convert and handle ONNX models.

Syntax:

pip install onnx skl2onnx

Explanation:

  • onnx: Core ONNX specification library for handling ONNX models.
  • skl2onnx: Used to convert trained Scikit-learn models into ONNX format.
  • Must be installed before conversion can begin.
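
A quick way to confirm the installation from Python is sketched below. The helper name missing_packages is hypothetical (it is not part of onnx or skl2onnx); it only uses the standard library to check whether each package can be found.

```python
import importlib.util


def missing_packages(names):
    """Return the names from `names` that cannot be found in this environment."""
    # find_spec returns None when a top-level package is not installed
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(["onnx", "skl2onnx"])
if missing:
    print("Install first: pip install " + " ".join(missing))
else:
    print("onnx and skl2onnx are available")
```

Running this before the conversion script gives a clearer message than an ImportError raised halfway through.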

2. Import Modules

What is it?
Imports the converter function and the input type class from skl2onnx.

Syntax:

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

Explanation:

  • convert_sklearn: The main function to convert Scikit-learn models.
  • FloatTensorType: Used to define the data type and shape of input for the ONNX model.
  • Both are required: the converter needs an explicit input declaration to build the ONNX graph.

3. Define Input Type

What is it?
Defines the input signature (name, data type, and shape) that the converted model will expect.

Syntax:

initial_type = [('float_input', FloatTensorType([None, 4]))]

Explanation:

  • Describes input as a tensor with dynamic batch size (None) and 4 features.
  • Ensures the converter knows the expected input shape and type.
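  • The feature count (4 here) must match the number of columns in your training data. FloatTensorType denotes 32-bit floats, so inputs should be cast to float32 at inference time.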

4. Convert Model

What is it?
Converts a trained Scikit-learn model to ONNX format.

Syntax:

onnx_model = convert_sklearn(model, initial_types=initial_type)

Explanation:

  • model is the fitted Scikit-learn model.
  • Uses initial_types to guide the conversion process.
  • Output is an ONNX model object that can be saved or deployed.

5. Save Model to File

What is it?
Serializes the ONNX model to a binary file.

Syntax:

with open("model.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())

Explanation:

  • Converts ONNX object into a byte string using .SerializeToString().
  • open(..., 'wb') writes the byte stream to disk.
  • Creates a portable .onnx file for deployment.

Real-Life Project: Iris Classifier Export

Project Overview

Train a simple Iris classification model and export it as an ONNX file for deployment.

Code Example

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
import onnx

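# Load dataset (150 samples, 4 features) and hold out 20% for testing
X, y = load_iris(return_X_y=True)
# random_state makes the split reproducible between runs
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)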

# Train model
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)

# Convert to ONNX
initial_type = [('float_input', FloatTensorType([None, 4]))]
onnx_model = convert_sklearn(model, initial_types=initial_type)

# Save to file
with open("iris_model.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())

Expected Output

  • An ONNX file named iris_model.onnx containing the serialized Logistic Regression model

Common Mistakes to Avoid

  • ❌ Forgetting to install skl2onnx before converting
  • ❌ Incorrect input shape or type in initial_type
  • ❌ Trying to convert unfitted models
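
Two of these mistakes can be caught before conversion. The sketch below uses Scikit-learn's check_is_fitted and the n_features_in_ attribute; the helper name is_fitted is hypothetical.

```python
from sklearn.datasets import load_iris
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.utils.validation import check_is_fitted


def is_fitted(estimator):
    """Return True if the estimator has been fitted, False otherwise."""
    try:
        check_is_fitted(estimator)
        return True
    except NotFittedError:
        return False


X, y = load_iris(return_X_y=True)
unfitted = LogisticRegression(max_iter=1000)
fitted = LogisticRegression(max_iter=1000).fit(X, y)

print(is_fitted(unfitted), is_fitted(fitted))  # False True

# The feature count seen during fit should match the shape used in initial_type
print(fitted.n_features_in_)  # 4
```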
