# Source code for qrisp.jasp.evaluation_tools.jaspification

"""
\********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************/
"""

from jax.tree_util import tree_flatten, tree_unflatten

from qrisp.jasp.interpreter_tools import extract_invalues, insert_outvalues, eval_jaxpr
from qrisp.jasp.evaluation_tools.buffered_quantum_state import BufferedQuantumState
from qrisp.core import recursive_qv_search
from qrisp.circuit import fast_append

def jaspify(func=None, terminal_sampling=False):
    """
    This simulator is the established Qrisp simulator linked to the Jasp
    infrastructure. Among a variety of simulation tricks, the simulator can
    leverage state sparsity, allowing simulations with up to hundreds of
    qubits!

    To be called as a decorator of a Jasp-traceable function. The decorator
    can be applied bare (``@jaspify``), with keyword arguments
    (``@jaspify(terminal_sampling = True)``) or with a single positional
    boolean (``@jaspify(True)``), which is interpreted as
    ``terminal_sampling``.

    .. note::

        If you are developing a hybrid algorithm like QAOA or VQE that relies
        heavily on sampling, please activate the ``terminal_sampling``
        feature.

    Parameters
    ----------
    func : callable
        The function to simulate.
    terminal_sampling : bool, optional
        Whether to leverage the terminal sampling strategy. Significantly fast
        for all sampling tasks but can yield incorrect results in some
        situations. Check out :ref:`terminal_sampling` for more details. The
        default is False.

    Returns
    -------
    callable
        A function performing the simulation. If ``func`` is omitted, a
        decorator producing such a function is returned instead.

    Raises
    ------
    Exception
        Raised by the returned function if ``recursive_qv_search`` finds
        anything in the simulation result, i.e. the simulated function tried
        to return a QuantumVariable.

    Examples
    --------

    We simulate a function creating a simple GHZ state:

    ::

        from qrisp import *
        from qrisp.jasp import *

        @jaspify
        def main():

            qf = QuantumFloat(5)

            h(qf[0])

            for i in range(1, 5):
                cx(qf[0], qf[i])

            return measure(qf)

        print(main())
        # Yields either 0 or 31

    To highlight the speed of the terminal sampling feature, we :ref:`sample`
    from a uniform superposition

    ::

        def state_prep():
            qf = QuantumFloat(5)
            h(qf)
            return qf

        @jaspify
        def without_terminal_sampling():
            sampling_func = sample(state_prep, shots = 10000)
            return sampling_func()

        @jaspify(terminal_sampling = True)
        def with_terminal_sampling():
            sampling_func = sample(state_prep, shots = 10000)
            return sampling_func()

    Benchmark the time difference:

    ::

        import time

        t0 = time.time()
        res = without_terminal_sampling()
        print(time.time() - t0)
        # Yields
        # 43.78982925

        t0 = time.time()
        res = with_terminal_sampling()
        print(time.time() - t0)
        # Yields
        # 0.550775527

    """
    from functools import wraps

    # Support the ``@jaspify(True)`` shorthand: a lone positional bool is the
    # terminal_sampling flag, not the function to decorate.
    if isinstance(func, bool):
        terminal_sampling = func
        func = None

    # Called with configuration only -> hand back the actual decorator.
    if func is None:
        return lambda x: jaspify(x, terminal_sampling=terminal_sampling)

    # Single-slot holder for the output tree structure recorded during the
    # most recent trace. Only the latest entry is kept so that a call whose
    # output structure differs from an earlier call is not unflattened with a
    # stale tree definition, and so that the holder does not grow per call.
    treedef_container = []

    def tracing_function(*args):
        # Flatten the user function's result so that only leaves are traced;
        # the structure is restored after the simulation.
        res = func(*args)
        flattened_values, tree_def = tree_flatten(res)
        treedef_container[:] = [tree_def]
        return flattened_values

    @wraps(func)
    def return_function(*args):
        # Imported at call time so that merely decorating a function does not
        # depend on this import.
        from qrisp.jasp import make_jaspr

        # To prevent "accidental deletion" induced non-determinism we set the
        # garbage collection mode to manual
        garbage_collection = "manual" if terminal_sampling else "auto"

        jaspr = make_jaspr(tracing_function, garbage_collection=garbage_collection)(*args)
        jaspr_res = simulate_jaspr(jaspr, *args, terminal_sampling=terminal_sampling)

        # NOTE(review): a single classical result is returned bare (it is not
        # re-wrapped into its original container) - existing behavior, kept
        # for backward compatibility.
        if isinstance(jaspr_res, tuple):
            jaspr_res = tree_unflatten(treedef_container[0], jaspr_res)

        if len(recursive_qv_search(jaspr_res)):
            raise Exception("Tried to jaspify function returning a QuantumVariable")

        return jaspr_res

    return return_function
def stimulate(func=None):
    """
    This function leverages the
    `Stim simulator <https://github.com/quantumlib/Stim?tab=readme-ov-file>`_
    to evaluate a Jasp-traceable function containing only Clifford gates.
    Stim is a popular tool to simulate quantum error correction codes.

    .. note::

        To use this simulator, you need stim installed, which can be achieved
        via ``pip install stim``.

    Parameters
    ----------
    func : callable
        The function to simulate. If omitted (``@stimulate()``), the decorator
        itself is returned so that it can be applied afterwards.

    Returns
    -------
    callable
        A function performing the simulation.

    Raises
    ------
    Exception
        Raised by the returned function if ``recursive_qv_search`` finds
        anything in the simulation result, i.e. the simulated function tried
        to return a QuantumVariable.

    Examples
    --------

    We simulate a function creating a simple GHZ state:

    ::

        from qrisp import *
        from qrisp.jasp import *

        @stimulate
        def main():

            qf = QuantumFloat(5)

            h(qf[0])

            for i in range(1, 5):
                cx(qf[0], qf[i])

            return measure(qf)

        print(main())
        # Yields either 0 or 31

    The ``stimulate`` decorator can also simulate real-time features:

    ::

        @stimulate
        def main():

            qf = QuantumFloat(5)

            h(qf[0])

            cl_bl = measure(qf[0])

            with control(cl_bl):
                for i in range(1, 5):
                    x(qf[i])

            return measure(qf)

        print(main())
        # Yields either 0 or 31

    """
    from functools import wraps

    # The signature advertises ``func = None``; without this guard the
    # returned wrapper would fail with "'NoneType' object is not callable" as
    # soon as it is invoked. Mirror the decorator-factory behavior of jaspify.
    if func is None:
        return stimulate

    # Single-slot holder for the output tree structure recorded during the
    # most recent trace. Only the latest entry is kept so that a call whose
    # output structure differs from an earlier call is not unflattened with a
    # stale tree definition, and so that the holder does not grow per call.
    treedef_container = []

    def tracing_function(*args):
        # Flatten the user function's result so that only leaves are traced;
        # the structure is restored after the simulation.
        res = func(*args)
        flattened_values, tree_def = tree_flatten(res)
        treedef_container[:] = [tree_def]
        return flattened_values

    @wraps(func)
    def return_function(*args):
        # Imported at call time so that merely decorating a function does not
        # depend on this import.
        from qrisp.jasp import make_jaspr

        jaspr = make_jaspr(tracing_function)(*args)
        jaspr_res = simulate_jaspr(jaspr, *args, simulator="stim")

        # NOTE(review): a single classical result is returned bare (it is not
        # re-wrapped into its original container) - existing behavior, kept
        # for backward compatibility.
        if isinstance(jaspr_res, tuple):
            jaspr_res = tree_unflatten(treedef_container[0], jaspr_res)

        if len(recursive_qv_search(jaspr_res)):
            raise Exception("Tried to simulate function returning a QuantumVariable")

        return jaspr_res

    return return_function
def simulate_jaspr(jaspr, *args, terminal_sampling=False, simulator="qrisp"):
    """
    Evaluate a Jaspr equation by equation on a ``BufferedQuantumState``.

    Parameters
    ----------
    jaspr : Jaspr
        The Jaspr to evaluate. Its ``outvars`` and ``consts`` attributes are
        read here.
    *args
        The call arguments. They are flattened with ``tree_flatten`` and
        passed to the evaluation after a fresh ``BufferedQuantumState``.
    terminal_sampling : bool, optional
        If True, ``pjit`` equations named ``expectation_value_eval_function``,
        ``sampling_eval_function`` or ``dict_sampling_eval_function`` are
        delegated to ``terminal_sampling_evaluator``. The default is False.
    simulator : str, optional
        Either ``"qrisp"`` or ``"stim"``. The default is ``"qrisp"``.

    Returns
    -------
    object
        ``None`` if the Jaspr has exactly one outvar (nothing is evaluated in
        that case), the second result if it has exactly two outvars, otherwise
        all results after the first one.

    Raises
    ------
    Exception
        If ``simulator`` is neither ``"qrisp"`` nor ``"stim"``, or if
        terminal sampling is requested together with ``"stim"``.
    """
    from qrisp.alg_primitives.mcx_algs.circuit_library import gidney_qc

    # A single outvar means there is nothing besides the first result to
    # return, so the evaluation is skipped entirely.
    # NOTE(review): presumably that first outvar is the quantum state - confirm.
    if len(jaspr.outvars) == 1:
        return None

    if simulator not in ("qrisp", "stim"):
        raise Exception(f"Don't know simulator {simulator}")
    if simulator == "stim" and terminal_sampling:
        raise Exception("Terminal sampling with stim is currently not implemented")

    initial_state = BufferedQuantumState(simulator)
    flat_args, _ = tree_flatten(args)
    call_args = [initial_state, *flat_args]

    # Maps the names of the sampling related pjit functions to the mode
    # string handed to terminal_sampling_evaluator.
    sampling_modes = {
        "expectation_value_eval_function": "ev",
        "sampling_eval_function": "array",
        "dict_sampling_eval_function": "dict",
    }

    def evaluate_eqn(eqn, context_dic):
        primitive_name = eqn.primitive.name

        if primitive_name == "jasp.quantum_kernel":
            # Quantum kernels receive their own, fresh quantum state.
            insert_outvalues(eqn, context_dic, BufferedQuantumState(simulator))
            return None

        if primitive_name != "pjit":
            # Returning True leaves the equation to eval_jaxpr.
            # NOTE(review): presumably this triggers the default evaluation -
            # confirm against qrisp.jasp.interpreter_tools.
            return True

        function_name = eqn.params["name"]

        if terminal_sampling:
            from qrisp.jasp.interpreter_tools import terminal_sampling_evaluator

            if function_name in sampling_modes:
                sampling_mode = sampling_modes[function_name]
                terminal_sampling_evaluator(sampling_mode)(
                    eqn, context_dic, eqn_evaluator=evaluate_eqn
                )
                return None

        invalues = extract_invalues(eqn, context_dic)

        if function_name == "gidney_mcx_inv":
            # We simulate the inverse Gidney mcx via the non-hybrid version
            # because the hybrid version prevents the simulator from fusing
            # gates, which slows down the simulation
            invalues[0].append(gidney_qc.inverse().to_gate(), invalues[1:])
            outvalues = [invalues[0]]
        else:
            # Recurse into the nested jaxpr with the same evaluator.
            outvalues = eval_jaxpr(eqn.params["jaxpr"], eqn_evaluator=evaluate_eqn)(*invalues)

        if not isinstance(outvalues, (list, tuple)):
            outvalues = [outvalues]

        insert_outvalues(eqn, context_dic, outvalues)
        return None

    # NOTE(review): fast_append(3) presumably selects a faster gate-appending
    # mode for the duration of the simulation - confirm in qrisp.circuit.
    with fast_append(3):
        res = eval_jaxpr(jaspr, eqn_evaluator=evaluate_eqn)(*(call_args + jaspr.consts))

    # The first result is dropped; a lone remaining result is returned bare.
    if len(jaspr.outvars) != 2:
        return res[1:]
    return res[1]