

.. image:: https://github.com/sdpython/onnx-extended/raw/main/_doc/_static/logo.png
    :width: 120

onnx-extended: extensions for onnx and onnxruntime
==================================================

.. image:: https://dev.azure.com/xavierdupre3/onnx-extended/_apis/build/status/sdpython.onnx-extended
    :target: https://dev.azure.com/xavierdupre3/onnx-extended/

.. image:: https://badge.fury.io/py/onnx-extended.svg
    :target: http://badge.fury.io/py/onnx-extended

.. image:: http://img.shields.io/github/issues/sdpython/onnx-extended.png
    :alt: GitHub Issues
    :target: https://github.com/sdpython/onnx-extended/issues

.. image:: https://img.shields.io/badge/license-MIT-blue.svg
    :alt: MIT License
    :target: https://opensource.org/license/MIT/

.. image:: https://img.shields.io/github/repo-size/sdpython/onnx-extended
    :target: https://github.com/sdpython/onnx-extended/
    :alt: size

.. image:: https://img.shields.io/badge/code%20style-black-000000.svg
    :target: https://github.com/psf/black

onnx-extended extends the list of operators supported by the onnx reference implementation and `onnxruntime <https://github.com/microsoft/onnxruntime>`_, or implements faster versions of them in C++. Documentation: `onnx-extended <https://sdpython.github.io/doc/onnx-extended/dev/>`_. Sources are available on `github/onnx-extended <https://github.com/sdpython/onnx-extended/>`_.

Use a C++ implementation of existing operators
++++++++++++++++++++++++++++++++++++++++++++++

.. code-block:: python

    import timeit
    import numpy as np
    from onnx import TensorProto
    from onnx.helper import (
        make_graph,
        make_model,
        make_node,
        make_opsetid,
        make_tensor_value_info,
    )
    from onnx.reference import ReferenceEvaluator
    from onnx_extended.reference import CReferenceEvaluator

    # A model with a single Conv node: 3x3 kernel, padding 1, stride 2.
    X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None, None, None])
    Y = make_tensor_value_info("Y", TensorProto.FLOAT, [None, None, None, None])
    B = make_tensor_value_info("B", TensorProto.FLOAT, [None, None, None, None])
    W = make_tensor_value_info("W", TensorProto.FLOAT, [None, None, None, None])
    node = make_node(
        "Conv",
        ["X", "W", "B"],
        ["Y"],
        pads=[1, 1, 1, 1],
        dilations=[1, 1],
        strides=[2, 2],
    )
    graph = make_graph([node], "g", [X, W, B], [Y])
    onnx_model = make_model(graph, opset_imports=[make_opsetid("", 16)])

    # Inputs: a 64x64 image, a 3x3 kernel filled with ones and a zero bias.
    sH, sW = 64, 64
    X = np.arange(sW * sH).reshape((1, 1, sH, sW)).astype(np.float32)
    W = np.ones((1, 1, 3, 3), dtype=np.float32)
    B = np.array([[[[0]]]], dtype=np.float32)

    sess1 = ReferenceEvaluator(onnx_model)  # python implementation from onnx
    sess2 = CReferenceEvaluator(onnx_model)  # 100 times faster

    # Both evaluators must return the same result.
    expected = sess1.run(None, {"X": X, "W": W, "B": B})[0]
    got = sess2.run(None, {"X": X, "W": W, "B": B})[0]
    diff = np.abs(expected - got).max()
    print(f"difference: {diff}")

    # Time 5 runs of each evaluator.
    f1 = lambda: sess1.run(None, {"X": X, "W": W, "B": B})[0]
    f2 = lambda: sess2.run(None, {"X": X, "W": W, "B": B})[0]
    print("onnx:", timeit.timeit(f1, globals=globals(), number=5))
    print("onnx-extended:", timeit.timeit(f2, globals=globals(), number=5))

::

    difference: 0.0
    onnx: 0.024006774998269975
    onnx-extended: 0.0002316169993719086
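
To see what the benchmarked node computes, here is a slow pure-Python sketch of a 2D Conv restricted to the case used above: one image, one input and one output channel, ``group=1`` and explicit ``pads``. It is written for illustration only and is neither the onnx reference implementation nor the C++ kernel of onnx-extended; the helper names are hypothetical. Along each spatial axis the output size is ``(size + pad_begin + pad_end - dilation * (kernel - 1) - 1) // stride + 1``, so the 64x64 input gives a 32x32 output.

.. code-block:: python

    def conv_output_size(size, kernel, pad_begin, pad_end, stride, dilation):
        """Length of one spatial axis after Conv with explicit pads."""
        effective_kernel = dilation * (kernel - 1) + 1
        return (size + pad_begin + pad_end - effective_kernel) // stride + 1


    def conv2d_single_channel(x, w, bias, pads, strides, dilations):
        """Naive 2D Conv for one image with one input and one output channel.

        ``x`` and ``w`` are lists of rows. ``pads`` follows the ONNX order
        [top, left, bottom, right]; positions falling in the padding count as 0.
        """
        top, left, bottom, right = pads
        sh, sw = strides
        dh, dw = dilations
        h, width = len(x), len(x[0])
        kh, kw = len(w), len(w[0])
        oh = conv_output_size(h, kh, top, bottom, sh, dh)
        ow = conv_output_size(width, kw, left, right, sw, dw)
        out = []
        for oy in range(oh):
            row = []
            for ox in range(ow):
                acc = bias
                for ky in range(kh):
                    iy = oy * sh - top + ky * dh  # input row under kernel row ky
                    if not 0 <= iy < h:
                        continue
                    for kx in range(kw):
                        ix = ox * sw - left + kx * dw  # input column under kernel column kx
                        if 0 <= ix < width:
                            acc += x[iy][ix] * w[ky][kx]
                row.append(acc)
            out.append(row)
        return out


    # Same configuration as the Conv node above: 64x64 input, 3x3 kernel of
    # ones, zero bias, pads=[1, 1, 1, 1], strides=[2, 2], dilations=[1, 1].
    sH, sW = 64, 64
    X = [[float(i * sW + j) for j in range(sW)] for i in range(sH)]
    W = [[1.0] * 3 for _ in range(3)]
    Y = conv2d_single_channel(X, W, 0.0, [1, 1, 1, 1], [2, 2], [1, 1])
    print(len(Y), len(Y[0]))  # 32 32
    print(Y[0][0])  # 130.0 = 0 + 1 + 64 + 65, the rest falls in the padding

Each output value costs nine multiply-adds in interpreted Python here, which gives an idea of why moving this loop to C++ pays off.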

Build with CUDA, openmp, eigen, onnxruntime
+++++++++++++++++++++++++++++++++++++++++++

The package also contains some toy examples showing how to build C++ extensions (`pybind11 <https://github.com/pybind/pybind11>`_, `cython <https://cython.org/>`_) with `openmp <https://www.openmp.org/>`_ and `eigen <https://eigen.tuxfamily.org/index.php>`_, with or without CUDA. It also shows how to create a custom operator for onnxruntime in C++.

The version released on `pypi/onnx-extended <https://pypi.org/project/onnx-extended/>`_ only works on CPU. The package must be built manually to enable the CUDA code. The build links with CUDA automatically if it finds it; otherwise, some extensions might not be available.

::

    python setup.py build_ext --inplace
    # pip install -e .

It is possible to use a specific version of CUDA:

::

    python setup.py build_ext --inplace --cuda-version=11.8
    # or (not working yet)
    # pip install -e . --config-settings="--cuda-version=11.8"
    # pip install -e . --global-option="--cuda-version=11.8"
    export USE_CUDA=11.8
    pip install -e .
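
Since the build only links with CUDA when it finds it, it can help to check which CUDA toolkit is visible before choosing a value for ``--cuda-version``. The following sketch is a hypothetical helper: it is not part of onnx-extended and is not the detection logic its build uses. It simply looks for ``nvcc`` on the ``PATH`` and parses the ``release X.Y`` field printed by ``nvcc --version``.

.. code-block:: python

    import re
    import shutil
    import subprocess


    def parse_nvcc_release(text):
        """Extract 'X.Y' from the output of ``nvcc --version``, or None."""
        match = re.search(r"release (\d+\.\d+)", text)
        return match.group(1) if match else None


    def local_cuda_version():
        """CUDA toolkit version of the nvcc found on PATH, or None if absent."""
        nvcc = shutil.which("nvcc")
        if nvcc is None:
            return None
        result = subprocess.run(
            [nvcc, "--version"], capture_output=True, text=True, check=False
        )
        return parse_nvcc_release(result.stdout)


    if __name__ == "__main__":
        version = local_cuda_version()
        if version is None:
            print("nvcc not found: some extensions might not be available")
        else:
            print(f"python setup.py build_ext --inplace --cuda-version={version}")

A toolkit installed outside the ``PATH`` is not seen by this check, so a ``None`` result does not prove CUDA is absent.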

`NVTX <https://github.com/NVIDIA/NVTX>`_ can be enabled with the following command:

::

    python setup.py build_ext --inplace --use_nvtx 1
    # or (not working yet)
    # pip install -e . --config-settings="--use_nvtx=1"
    pip install -e . --global-option "--use_nvtx=1"

Experimental cython binding for onnxruntime
+++++++++++++++++++++++++++++++++++++++++++

The python onnxruntime package relies on pybind11 to expose its functionalities. onnx-extended tries to build a cython wrapper around the C/C++ API of onnxruntime. cython relies on the python C API and is faster than pybind11. This difference may be significant when onnxruntime is used on small graphs and tensors, where the cost of each call through the binding is not negligible compared to the computation itself.
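
Whether the binding matters for a given model can be checked by timing many calls on small inputs, where the per-call overhead dominates. Below is a minimal stdlib-only helper, assuming both callables being compared run the same model on the same inputs. The API of the cython wrapper is not described in this README, so the call to compare against is left as a placeholder; ``per_call_seconds`` is a hypothetical name.

.. code-block:: python

    import statistics
    import timeit


    def per_call_seconds(fn, number=1000, repeat=5):
        """Median time in seconds of a single call to ``fn``.

        ``timeit.repeat`` returns the total time of ``number`` calls for each
        of ``repeat`` trials; the median per-call value is less sensitive to
        noise from other processes than the mean.
        """
        totals = timeit.repeat(fn, number=number, repeat=repeat)
        return statistics.median([t / number for t in totals])


    # Trivial callable so the snippet runs on its own; replace it with, e.g.,
    # lambda: sess.run(None, feeds) for an onnxruntime InferenceSession and
    # with the equivalent call of the cython wrapper.
    print(per_call_seconds(lambda: sum(range(10))))

Comparing the two values on a tiny graph, then on a larger one, shows how quickly the binding overhead becomes negligible.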

Custom kernels for onnxruntime
++++++++++++++++++++++++++++++

onnxruntime provides an API to add custom implementations of existing or new onnx operators. The following example registers the CPU kernels shipped with onnx-extended.

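::

    from onnxruntime import InferenceSession, SessionOptions
    from onnx_extended.ortops.optim.cpu import get_ort_ext_libs

    # Path(s) to the shared library implementing the custom CPU kernels.
    r = get_ort_ext_libs()
    opts = SessionOptions()
    if r is not None:
        opts.register_custom_ops_library(r[0])

    # onx_modified is an onnx ModelProto, defined beforehand, whose nodes
    # use operators implemented in the registered library.
    sess_cus = InferenceSession(
        onx_modified.SerializeToString(), opts, providers=["CPUExecutionProvider"]
    )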
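
A model only needs the extra library when some of its nodes use a domain that onnxruntime does not provide itself. The small helper below lists such domains; ``custom_domains`` is a hypothetical name, and it assumes ``nodes`` is ``model.graph.node`` of an onnx ModelProto, where each NodeProto has a ``domain`` field that is empty for the default ONNX domain. The domain used in the demo is made up for illustration.

.. code-block:: python

    from types import SimpleNamespace

    # Domains whose operators ship with onnxruntime itself: the standard ONNX
    # domains plus "com.microsoft" (onnxruntime contrib operators).
    BUILTIN_DOMAINS = frozenset({"", "ai.onnx", "ai.onnx.ml", "com.microsoft"})


    def custom_domains(nodes, builtin=BUILTIN_DOMAINS):
        """Sorted list of node domains that onnxruntime does not provide."""
        return sorted({node.domain for node in nodes} - set(builtin))


    # Stand-ins for NodeProto objects; only the ``domain`` attribute is read.
    nodes = [
        SimpleNamespace(op_type="Add", domain=""),
        SimpleNamespace(op_type="MyOp", domain="my.custom.domain"),  # hypothetical
    ]
    print(custom_domains(nodes))  # ['my.custom.domain']

With a real model, ``custom_domains(onx_modified.graph.node)`` returning a non-empty list means a library such as the one returned by ``get_ort_ext_libs()`` must be registered before creating the session.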