========
Tutorial
========

*ONNX Runtime* provides an easy way to run
machine learned models with high performance on CPU or GPU
without dependencies on the training framework.
Machine learning frameworks are usually optimized for
batch training rather than for prediction, which is a
more common scenario in applications, sites, and services.
At a high level, you can:

1. Train a model using your favorite framework.
2. Convert or export the model into ONNX format. See
   `ONNX Tutorials <https://github.com/onnx/tutorials>`_ for more details.
3. Load and run the model using *ONNX Runtime*.

In this tutorial, we will briefly create a
pipeline with *scikit-learn*, convert it into
ONNX format and run the first predictions.

.. _l-logreg-example:

Step 1: Train a model using your favorite framework
+++++++++++++++++++++++++++++++++++++++++++++++++++

We'll use the famous Iris dataset.

.. runpython::
    :showcode:
    :store:
    :warningout: ImportWarning FutureWarning

    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split

    iris = load_iris()
    X, y = iris.data, iris.target
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    from sklearn.linear_model import LogisticRegression

    clr = LogisticRegression()
    clr.fit(X_train, y_train)
    print(clr)

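As a quick sanity check before converting, the classifier can be scored on the held-out split. This is a minimal standalone sketch; the fixed ``random_state`` and raised ``max_iter`` are illustrative choices, not part of the tutorial's pipeline:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Fixed seed so the reported score is reproducible (an assumption,
# not the tutorial's default split).
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# max_iter raised only to avoid lbfgs convergence warnings on this data.
clr = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print(clr.score(X_test, y_test))
```

Accuracy on Iris with logistic regression is typically well above 90%, which makes any later mismatch with the ONNX predictions easy to spot.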
Step 2: Convert or export the model into ONNX format
++++++++++++++++++++++++++++++++++++++++++++++++++++

`ONNX <https://github.com/onnx/onnx>`_ is a format to describe
the machine learned model.
It defines a set of commonly used operators to compose models.
There are `tools <https://github.com/onnx/tutorials>`_
to convert other model formats into ONNX. Here we will use
`sklearn-onnx <https://github.com/onnx/sklearn-onnx>`_.

.. runpython::
    :showcode:
    :restore:
    :store:
    :warningout: ImportWarning FutureWarning

    from skl2onnx import convert_sklearn
    from skl2onnx.common.data_types import FloatTensorType

    initial_type = [('float_input', FloatTensorType([None, 4]))]
    onx = convert_sklearn(clr, initial_types=initial_type)
    with open("logreg_iris.onnx", "wb") as f:
        f.write(onx.SerializeToString())

Step 3: Load and run the model using ONNX Runtime
+++++++++++++++++++++++++++++++++++++++++++++++++

We will use *ONNX Runtime* to compute the predictions
for this machine learning model.

.. runpython::
    :showcode:
    :restore:
    :store:

    import numpy
    import onnxruntime as rt

    sess = rt.InferenceSession(
        "logreg_iris.onnx", providers=rt.get_available_providers())
    input_name = sess.get_inputs()[0].name
    pred_onx = sess.run(None, {input_name: X_test.astype(numpy.float32)})[0]
    print(pred_onx)

The code can be changed to get one specific output
by specifying its name in a list.

.. runpython::
    :showcode:
    :restore:

    import numpy
    import onnxruntime as rt

    sess = rt.InferenceSession(
        "logreg_iris.onnx", providers=rt.get_available_providers())
    input_name = sess.get_inputs()[0].name
    label_name = sess.get_outputs()[0].name
    pred_onx = sess.run(
        [label_name], {input_name: X_test.astype(numpy.float32)})[0]
    print(pred_onx)