Source code for qrisp.jasp.evaluation_tools.boolean_simulation

"""
\********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************/
"""

import jax.numpy as jnp
from jax import jit

from qrisp.jasp import make_jaspr

from qrisp.jasp.interpreter_tools.interpreters.cl_func_interpreter import jaspr_to_cl_func_jaxpr
from qrisp.jasp.interpreter_tools import Jlist, eval_jaxpr

def boolean_simulation(*func, bit_array_padding = 2**16):
    r"""
    Decorator to simulate Jasp functions containing only classical logic (such as X, CX, CCX).
    This decorator transforms the function into a Jax expression without any quantum primitives
    and leverages the Jax compilation pipeline to compile a highly efficient simulation.

    .. note::

        The ``boolean_simulation`` decorator checks whether deleted :ref:`QuantumVariables <QuantumVariable>`
        have been properly uncomputed and issues a warning otherwise. It therefore provides a valuable tool
        for verifying the correctness of your algorithms at scale.

    Parameters
    ----------
    func : callable
        A Python function performing Jasp logic.
    bit_array_padding : int, optional
        The size of the classical array containing the (classical) bits which are simulated.
        Since Jax doesn't allow dynamically sized arrays but Jasp supports dynamically sized
        QuantumVariables, the array has to be "padded". The padding therefore sets an upper
        bound on the number of qubits that executing ``func`` may require. A large padding
        slows down the simulation but prevents overflows, which silently produce faulty results.
        The simulation is performed without any memory management, so even qubits that have
        been deallocated count toward the padding. The default is ``2**16``. The minimum value is 64.

    Returns
    -------
    simulator_function
        A function performing the simulation for the given input parameters.

    Examples
    --------

    We create a simple script that demonstrates the functionality:

    ::

        from qrisp import QuantumFloat, measure
        from qrisp.jasp import boolean_simulation, jrange

        @boolean_simulation
        def main(i, j):

            a = QuantumFloat(10)
            b = QuantumFloat(10)

            a[:] = i
            b[:] = j

            c = QuantumFloat(30)

            for k in jrange(150):
                c += a*b

            return measure(c)

    This script computes the product of the two inputs 150 times and adds it into the same
    QuantumFloat each time. The expected result is therefore ``i*j*150``.

    >>> main(1,2)
    Array(300., dtype=float64)
    >>> main(3,4)
    Array(1800., dtype=float64)

    Next we demonstrate the behavior under a faulty uncomputation:

    ::

        @boolean_simulation
        def main(i):

            a = QuantumFloat(10)

            a[:] = i

            a.delete()

            return

    >>> main(0)
    >>> main(1)
    WARNING: Faulty uncomputation found during simulation.
    >>> main(3)
    WARNING: Faulty uncomputation found during simulation.
    WARNING: Faulty uncomputation found during simulation.

    In the first case the deletion is valid, because ``a`` is still in the $\ket{0}$ state.
    In the second case the first qubit is in the $\ket{1}$ state, so the deletion is not valid.
    In the third case both the first and the second qubit are in the $\ket{1}$ state
    (because 3 = ``11`` in binary), so there are two warnings.

    **Padding**

    We demonstrate the effects of the padding feature. For this we recreate the first script
    with a smaller padding:

    ::

        @boolean_simulation(bit_array_padding = 64)
        def main(i, j):

            a = QuantumFloat(10)
            b = QuantumFloat(10)

            a[:] = i
            b[:] = j

            c = QuantumFloat(30)

            for k in jrange(150):
                c += a*b

            return measure(c)

    >>> main(1,2)
    Array(8.92323439e+08, dtype=float64)

    The result is faulty because the script needs more than 64 qubits. Increasing the padding
    ensures that enough qubits are available, at the cost of simulation speed.
    """

    if len(func) == 0:
        # Called as @boolean_simulation(bit_array_padding = ...): return a decorator
        # that receives the function in a second call.
        return lambda x : boolean_simulation(x, bit_array_padding = bit_array_padding)
    else:
        func = func[0]

    if bit_array_padding < 64:
        raise Exception("Tried to initialize boolean_simulation with less than 64 bits")

    @jit
    def return_function(*args):

        # Trace func into a Jaspr for the given arguments
        jaspr = make_jaspr(func, garbage_collection="manual")(*args)

        # Convert the Jaspr into a Jaxpr without quantum primitives that operates
        # on a classical bit array of size bit_array_padding
        cl_func_jaxpr = jaspr_to_cl_func_jaxpr(jaspr.flatten_environments(), bit_array_padding)

        # The first input of the converted Jaxpr is the bit array; start with all bits at zero
        aval = cl_func_jaxpr.invars[0].aval
        bit_array = jnp.zeros(aval.shape, dtype = aval.dtype)

        # Pool of free qubit indices 0, ..., bit_array_padding - 1, flattened into its Jax leaves
        free_qubit_list = Jlist(jnp.arange(bit_array_padding), max_size = bit_array_padding).flatten()[0]

        boolean_quantum_circuit = (bit_array, *free_qubit_list)

        res = eval_jaxpr(cl_func_jaxpr)(*boolean_quantum_circuit, *args)

        # Outputs beyond the first three are the return values of func:
        # a single value is returned directly, several as a sequence, none as None.
        if len(res) == 4:
            return res[3]
        elif len(res) == 3:
            return None
        else:
            return res[3:]

    return return_function
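The docstring describes three mechanisms: classical gates acting on a fixed-size ("padded") bit array, qubit allocation without memory management, and a check that deleted qubits are back in $\ket{0}$. The following toy model in plain Python sketches those ideas under stated assumptions. `ToyBooleanSimulator` and `encode` are hypothetical names, not part of Qrisp, and the real simulator compiles a converted Jaxpr with Jax instead of interpreting gates one at a time. One deliberate difference: this sketch raises `OverflowError` when the padding is exhausted, whereas the padding example in the docstring shows `boolean_simulation` returning a faulty result instead.

```python
# Hypothetical toy model, NOT Qrisp's implementation: a boolean simulator that
# stores qubits as plain bits in a fixed-size ("padded") array.


class ToyBooleanSimulator:
    def __init__(self, padding=64):
        self.bits = [0] * padding  # fixed-size bit array, all qubits start in |0>
        self.next_free = 0         # bump allocator: indices are never reused
        self.warnings = 0          # number of faulty uncomputations seen

    def allocate(self, n):
        """Return n fresh qubit indices (deleted qubits still occupy the array)."""
        if self.next_free + n > len(self.bits):
            # Design choice of this sketch: fail loudly instead of producing garbage.
            raise OverflowError("padding too small for the requested qubits")
        indices = list(range(self.next_free, self.next_free + n))
        self.next_free += n
        return indices

    # Classical reversible gates act on the bits directly.
    def x(self, t):
        self.bits[t] ^= 1

    def cx(self, c, t):
        self.bits[t] ^= self.bits[c]

    def ccx(self, c1, c2, t):
        self.bits[t] ^= self.bits[c1] & self.bits[c2]

    def delete(self, qubits):
        """Deleting is valid only if every qubit is back in |0>."""
        for q in qubits:
            if self.bits[q]:
                self.warnings += 1
                print("WARNING: Faulty uncomputation found during simulation.")

    def measure(self, qubits):
        """Read a register as an unsigned integer, first qubit least significant."""
        return sum(self.bits[q] << k for k, q in enumerate(qubits))


def encode(sim, qubits, value):
    """Hypothetical helper: write `value` into a register using X gates."""
    for k, q in enumerate(qubits):
        if (value >> k) & 1:
            sim.x(q)


# Mirror of main(3) from the docstring: 3 = 0b11, so two qubits are left in |1>.
sim = ToyBooleanSimulator(padding=64)
a = sim.allocate(10)
encode(sim, a, 3)
sim.delete(a)                  # prints the warning twice

# Compute x AND y into an ancilla, read it, then uncompute it cleanly.
sim2 = ToyBooleanSimulator(padding=64)
x, y, anc = sim2.allocate(3)
sim2.x(x)
sim2.x(y)
sim2.ccx(x, y, anc)            # anc = 1 AND 1 = 1
and_result = sim2.measure([anc])
sim2.ccx(x, y, anc)            # the second CCX restores anc to 0
sim2.delete([anc])             # valid deletion, no warning

# Without memory management, deleted qubits still count toward the padding.
sim3 = ToyBooleanSimulator(padding=64)
first = sim3.allocate(30)
sim3.delete(first)             # valid deletion, but indices 0..29 are not recycled
sim3.allocate(30)              # 60 of 64 bits are now used
try:
    sim3.allocate(10)          # would need 70 bits in total
except OverflowError as err:
    print("overflow:", err)
```

Because each deletion only checks bits that are set, the number of warnings equals the number of qubits left in $\ket{1}$, which matches the one and two warnings shown for `main(1)` and `main(3)`.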
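The `*func` signature is what allows `boolean_simulation` to be used both bare (`@boolean_simulation`) and with arguments (`@boolean_simulation(bit_array_padding = 64)`). The pattern can be sketched in isolation as follows. `padded` is a hypothetical decorator, not part of Qrisp, which only reports its padding next to the result as a stand-in for the real simulation.

```python
import functools


def padded(*func, padding=64):
    """Hypothetical decorator usable as @padded or @padded(padding=...)."""
    if len(func) == 0:
        # Called with keyword arguments only: return a decorator that
        # receives the function in a second call.
        return lambda f: padded(f, padding=padding)
    f = func[0]

    # As in boolean_simulation, the check only runs once a function is supplied.
    if padding < 64:
        raise ValueError("padding must be at least 64")

    @functools.wraps(f)
    def wrapper(*args):
        # Placeholder for the real work: report the padding next to the result.
        return f(*args), padding

    return wrapper


@padded
def add(a, b):
    return a + b


@padded(padding=128)
def mul(a, b):
    return a * b


print(add(1, 2))  # (3, 64)
print(mul(3, 4))  # (12, 128)
```

Because the minimum-padding check in `boolean_simulation` sits after the early `return`, calling `boolean_simulation(bit_array_padding = 32)` by itself does not raise; the exception appears only once the returned decorator is applied to a function.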