Source code for qrisp.jasp.evaluation_tools.catalyst_qjit
"""
\********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************/
"""
import functools

from jax.tree_util import tree_flatten, tree_unflatten

from qrisp.jasp.jasp_expression import make_jaspr
[docs]
def qjit(function):
    """
    Decorator to leverage the jasp + Catalyst infrastructure to compile the
    given function to QIR and run it on the Catalyst QIR runtime.

    The compiled representation is cached on the decorated function (in the
    attribute ``jaspr_dict``), keyed by the tuple of argument types, so the
    jaspr for a given type signature is only built on the first call with
    that signature.

    Parameters
    ----------
    function : callable
        A function performing Qrisp code.

    Returns
    -------
    callable
        A function executing the compiled code. It accepts positional
        arguments only and carries the metadata (``__name__``, ``__doc__``,
        ...) of ``function``.

    Examples
    --------

    We write a simple function using the QuantumFloat quantum type and execute
    via ``qjit``:

    ::

        from qrisp import *
        from qrisp.jasp import qjit

        @qjit
        def test_fun(i):
            qv = QuantumFloat(i, -2)
            with invert():
                cx(qv[0], qv[qv.size-1])
                h(qv[0])
            meas_res = measure(qv)
            return meas_res + 3

    We execute the function a couple of times to demonstrate the randomness

    >>> test_fun(4)
    [array(5.25, dtype=float64)]
    >>> test_fun(5)
    [array(3., dtype=float64)]
    >>> test_fun(5)
    [array(7.25, dtype=float64)]

    """

    # functools.wraps keeps the name/docstring of the user function visible on
    # the returned wrapper (and exposes the original via ``__wrapped__``).
    @functools.wraps(function)
    def jitted_function(*args):
        # The cache lives on the user function itself and is created lazily
        # on the first call.
        if not hasattr(function, "jaspr_dict"):
            function.jaspr_dict = {}
        cache = function.jaspr_dict

        # Only the argument *types* form the cache key; the values are passed
        # to the cached jaspr at call time.
        signature = tuple(type(arg) for arg in args)

        if signature not in cache:
            cache[signature] = make_jaspr(function)(*args)

        return cache[signature].qjit(*args, function_name=function.__name__)

    return jitted_function