Source code for qrisp.jasp.program_control.ev

"""
\********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************/
"""

import jax
import jax.numpy as jnp

from qrisp.jasp.tracing_logic import quantum_kernel

# The following function implements the expectation_value feature.
# The basic functionality would be relatively straightforward to implement,
# however there are some complications. The reason for that is that the resulting
# jaxpr should be "readable" by the terminal sampling interpreter.
# Terminal sampling means that instead of performing the simulations "shots"-times
# it is performed once and the shots are then sampled from that distribution.
# Naturally this implies a massive performance increase, which is why a lot
# of effort is spent to realize a smooth implementation.

# The underlying idea to make the feature easily "readable" by the terminal
# sampling interpreter is to structure one iteration of sampling into three
# steps.

# 1. Evaluating the user function, which generates the distribution.
# 2. Sampling from that distribution via the "measure" function.
# 3. Decoding and postprocessing the measurement results.

# For the final two steps we deploy some custom logic to realize the terminal
# sampling behavior. To simplify the automatic processing of these steps,
# we capture each into individual pjit calls.

# The terminal sampling interpreter then identifies each steps via the
# eqn.params["name"] attribute and executes the custom logic.


[docs] def expectation_value(state_prep, shots, return_dict = False, post_processor = None): r""" The ``expectation_value`` function allows to estimate the expectation value from a state that is specified by a preparation procedure. This preparation procedure can be supplied via a Python function that returns one or more :ref:`QuantumVariables <QuantumVariable>`. Parameters ---------- state_prep : callable A function returning one or more :ref:`QuantumVariables <QuantumVariable>`. The expectation value from this state will be computed. The state preparation function can only take classical values as arguments. This is because a quantum value would need to be copied for each sampling iteration, which is prohibited by the no-cloning theorem. shots : int or jax.core.Tracer The amount of samples to take to compute the expectation value. post_processor : callable, optional A classical Jax traceable function to apply to the results directly after measuring. By default no post processing is applied. Raises ------ Exception Tried to sample from state preparation function taking a quantum value Returns ------- callable A function returning a Jax array containing the expectation value. Examples -------- We prepare the state .. math:: \ket{\psi_k} = \frac{1}{\sqrt{2}} \left(\ket{0}\ket{0}\ket{\text{False}} + \ket{k}\ket{k}\ket{\text{True}}\right) :: from qrisp import * from qrisp.jasp import * def state_prep(k): a = QuantumFloat(4) b = QuantumFloat(4) qbl = QuantumBool() h(qbl) with control(qbl[0]): a[:] = k cx(a, b) return a, b And compute the expectation value of the QuantumFloats :: @jaspify def main(k): ev_function = expectation_value(state_prep, shots = 50) return ev_function(k) print(main(3)) # Yields # [1.44 1.44] The true value 1.5 is not reached because of `shot noise <https://en.wikipedia.org/wiki/Shot_noise>`_. To improve the approximation, feel free to increase the shots! To demonstrate the ``post_processor`` keyword we define a simple post processing function :: def post_processor(x, y): return x*y @jaspify def main(k): ev_function = expectation_value(state_prep, shots = 50) return ev_function(k) print(main(3)) # Yields # 4.338 This result is expected because the inputs of ``post_processor`` are either (0,0) or (3,3) with 50% probability, so we get .. math:: 4.5 = \frac{3\cdot 3 + 0\cdot 0}{2} """ from qrisp.jasp import make_tracer, qache from qrisp.core import QuantumVariable, measure if isinstance(shots, int): shots = make_tracer(shots) if post_processor is None: def identity(*args): return args post_processor = identity # Qache the user function @qache def user_func(*args): return state_prep(*args) # This function performs the logic to evaluate the expectation value def expectation_value_eval_function(*args): for arg in args: if isinstance(arg, QuantumVariable): raise Exception("Tried to sample from state preparation function taking a quantum value") # We now construct a loop to evaluate the expectation value via adding # the decoded and postprocessed measurement result into an accumulator. # The following function is the loop body, which is kernelized. @quantum_kernel def sampling_body_func(i, args): # Evaluate the user function acc = args[0] qv_tuple = user_func(*args[1:]) if not isinstance(qv_tuple, tuple): qv_tuple = (qv_tuple,) # Ensure all results are QuantumVariables for qv in qv_tuple: if not isinstance(qv, QuantumVariable): raise Exception("Tried to sample from function not returning a QuantumVariable") # Trace the DynamicQubitArray measurements # Since we execute the measurements on the .reg attribute, no decoding # is applied. The decoding happens in sampling_helper_2 @qache def sampling_helper_1(*args): res_list = [] for reg in args: res_list.append(measure(reg)) return tuple(res_list) measurement_ints = sampling_helper_1(*[qv.reg for qv in qv_tuple]) # Trace the decoding @jax.jit def sampling_helper_2(*meas_ints): res_list = [] for i in range(len(qv_tuple)): res_list.append(qv_tuple[i].jdecoder(meas_ints[i])) # Apply the post processing return post_processor(*res_list) decoded_values = sampling_helper_2(*(list(measurement_ints))) if isinstance(decoded_values, tuple) and len(decoded_values) != 1: # Save the return amount (for more details check the comment of the) # initialization command of return_amount return_amount.append(len(decoded_values)) if acc.shape[0] == 1: raise AuxException() # Turn into jax array and add to the accumulator meas_res = jnp.array(decoded_values) acc += meas_res # Return the updated accumulator for the next loop iteration. return (acc, *args[1:]) # This list captures the amount of return values. The strategy here is # to initially assume only one QuantumVariable is returned, which is then # added to the expectation value accumulator. If more than one is returned, # the amount is saved in this list and an exception is raised, which # subsequently causes another call but this time with the correct accumulator # dimension. return_amount = [] try: loop_res = jax.lax.fori_loop(0, shots, sampling_body_func, (jnp.array([0.]), *args)) return loop_res[0][0]/shots except AuxException: loop_res = jax.lax.fori_loop(0, shots, sampling_body_func, (jnp.array([0.]*return_amount[0]), *args)) return loop_res[0]/shots if return_dict: expectation_value_eval_function.__name__ = "dict_sampling_eval_function" return jax.jit(expectation_value_eval_function)
class AuxException(Exception): pass