import inspect
from jax.lax import while_loop, cond
import jax
import jax.numpy as jnp
from qrisp.circuit import XGate
from qrisp.jasp import TracingQuantumSession, AbstractQubitArray, DynamicQubitArray, qache
from qrisp.jasp.primitives import Measurement_p, OperationPrimitive, get_qubit_p, get_size_p, delete_qubits_p, reset_p


def RUS(*trial_function, **jit_kwargs):
    r"""
Decorator to deploy repeat-until-success (RUS) components. At its core,
RUS repeats a given quantum subroutine followed by a qubit measurement until
the measurement returns the value ``1``. This step is prevalent
in many important algorithms, among them the
`HHL algorithm <https://arxiv.org/abs/0811.3171>`_ and the
`LCU procedure <https://arxiv.org/abs/1202.5822>`_.

Within Jasp, RUS steps can be realized by providing the quantum subroutine
as a "trial function", which returns a boolean value (the success condition:
the trial is repeated until this value is ``True``) and possibly other
return values.

It is important to note that the trial function cannot receive quantum
arguments. This is because after each trial, a new copy of these arguments
would be required to perform the next iteration, which is prohibited by
the no-cloning theorem. It is, however, legal to provide classical arguments.
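
The control flow described above can be illustrated without any quantum
resources. The following is a classical toy model, not Qrisp code: a biased
coin flip from ``jax.random`` stands in for the measurement, and the structure
(run one trial, then let ``jax.lax.cond`` decide whether a
``jax.lax.while_loop`` is needed at all) mirrors the structure of this
decorator's implementation. ``toy_trial``, ``toy_rus`` and the success
probability of 0.3 are hypothetical choices made for this sketch.

```python
import jax
import jax.numpy as jnp
from jax.lax import cond, while_loop

def toy_trial(key):
    # Hypothetical stand-in for a quantum trial: a biased coin flip plays
    # the role of the measured qubit. Returns (success flag, new PRNG key).
    key, subkey = jax.random.split(key)
    success = jax.random.bernoulli(subkey, p=0.3)
    return success, key

def toy_rus(key):
    # First trial, executed unconditionally (like the first iteration in RUS)
    success, key = toy_trial(key)
    # Loop carry: (success flag, PRNG key, number of trials so far)
    init = (success, key, jnp.int32(1))

    def cond_fun(carry):
        # Keep looping while the last trial failed
        return ~carry[0]

    def body_fun(carry):
        _, key, n = carry
        success, key = toy_trial(key)
        return (success, key, n + 1)

    # Skip the while_loop entirely if the first trial already succeeded
    return cond(success,
                lambda c: c,
                lambda c: while_loop(cond_fun, body_fun, c),
                init)

success, _, n_trials = jax.jit(toy_rus)(jax.random.PRNGKey(0))
print(bool(success), int(n_trials))
```

A ``while_loop`` is needed here because the number of trials is only known at
run time; a Python loop would be unrolled at trace time and could not depend
on a measurement outcome.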

Parameters
----------
trial_function : callable
    A function returning a boolean value as the first return value. More
    return values are possible.
static_argnums : int or list[int], optional
    An integer or a list of integers specifying which arguments are considered
    static in the sense of `jax.jit <https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html>`_.
    The first argument is indicated by 1, the second by 2, etc. The default
    is ``[]``.
static_argnames : str or list[str], optional
    A string or a list of strings specifying which arguments are considered
    static in the sense of `jax.jit <https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html>`_.
    The default is ``[]``.

Returns
-------
callable
    A function that performs the RUS protocol with the trial function. The
    return values of this function are the return values of the trial function
    *without* the leading boolean value. If the trial function returns exactly
    one value besides the boolean, that value is returned directly; otherwise
    the remaining values are returned as a tuple.
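
Notes
-----

Unlike ``jax.jit``, the ``static_argnums`` positions described above are
counted from 1. The plain-Python sketch below illustrates how static arguments,
given by position or by name, can be separated from the dynamic ones that a
traced loop has to carry. ``dynamic_args`` and the signature-only
``block_encoding`` stand-in are hypothetical illustrations, not part of Qrisp.

```python
import inspect

def dynamic_args(fn, args, static_argnums=(), static_argnames=()):
    # Hypothetical helper (not part of Qrisp). Positions are 1-based, as
    # documented for RUS, so position 1 refers to args[0].
    if isinstance(static_argnums, int):
        static_argnums = [static_argnums]
    if isinstance(static_argnames, str):
        static_argnames = [static_argnames]
    static = set(static_argnums)
    # Translate static argument names into 1-based positions
    for pos, name in enumerate(inspect.getfullargspec(fn).args, start=1):
        if name in static_argnames:
            static.add(pos)
    # The remaining (dynamic) arguments are the ones a traced loop can carry
    return [arg for pos, arg in enumerate(args, start=1) if pos not in static]

def block_encoding(return_size, state_preparation, case_functions):
    # Signature-only stand-in for the LCU example below
    pass

by_position = dynamic_args(block_encoding, (4, 'prep', 'cases'),
                           static_argnums=[2, 3])
by_name = dynamic_args(block_encoding, (4, 'prep', 'cases'),
                       static_argnames=['state_preparation', 'case_functions'])
print(by_position, by_name)
# → [4] [4]
```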

Examples
--------

To demonstrate the RUS behavior, we initialize a GHZ state

.. math::

    \ket{\psi} = \frac{\ket{00000} + \ket{11111}}{\sqrt{2}}

and measure the first qubit into a boolean value, which serves as the
success condition. The measurement collapses the GHZ state into either
$\ket{00000}$ (which triggers another repetition) or $\ket{11111}$ (which
ends the loop). Once the loop has ended, we are therefore guaranteed to hold
the latter state.

::

    from qrisp.jasp import RUS, make_jaspr
    from qrisp import QuantumFloat, h, cx, measure

    @RUS
    def rus_trial_function():
        qf = QuantumFloat(5)

        # Prepare the GHZ state
        h(qf[0])
        for i in range(1, 5):
            cx(qf[0], qf[i])

        cancelation_bool = measure(qf[0])

        return cancelation_bool, qf

    def call_RUS_example():
        qf = rus_trial_function()
        return measure(qf)

Create the ``jaspr`` and simulate:

::

    jaspr = make_jaspr(call_RUS_example)()
    print(jaspr())
    # Yields 31, which is the decimal version of 11111
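
To see why the result is always 31, the RUS loop can be emulated classically
on a NumPy statevector. This is an illustrative sketch, not how Qrisp
simulates a ``jaspr``. It assumes that qubit 0 is the least significant bit of
the integer, although for a GHZ state every qubit gives the same result.

```python
import numpy as np

n_qubits = 5
dim = 2**n_qubits

# GHZ state: amplitude 1/sqrt(2) on |00000> (integer 0) and |11111> (integer 31)
ghz = np.zeros(dim)
ghz[0] = ghz[dim - 1] = 1 / np.sqrt(2)

# Basis states in which qubit 0 (assumed least significant bit) is 1
qubit0_is_one = (np.arange(dim) & 1) == 1

rng = np.random.default_rng(seed=1)

def trial():
    # Prepare a fresh GHZ state and measure qubit 0 (Born rule)
    p_one = np.sum(np.abs(ghz[qubit0_is_one])**2)
    outcome = bool(rng.random() < p_one)
    # Collapse onto the basis states consistent with the outcome, renormalize
    collapsed = np.where(qubit0_is_one == outcome, ghz, 0.0)
    return outcome, collapsed / np.linalg.norm(collapsed)

# Repeat until the measurement of qubit 0 returns 1
n_trials = 1
success, state = trial()
while not success:
    n_trials += 1
    success, state = trial()

# Only the basis state 31 (binary 11111) survives the post-selection
result = int(np.argmax(np.abs(state)))
print(result, n_trials)
```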

**Static arguments**

To demonstrate the specification of static arguments, we will implement a
simple `linear combination of unitaries <https://arxiv.org/abs/1202.5822>`_.
Our implementation initializes a state of the form

.. math::

    \left( \sum_{i = 0}^N c_i U_i \right) \ket{0}.

We achieve this by specifying a set of unitaries $U_i$ in the form of a
tuple of functions, each processing a :ref:`QuantumFloat`.
The coefficients $c_i$ are specified through a function preparing the state

.. math::

    \ket{\psi} = \sum_{i = 0}^N c_i \ket{i}.

For the state preparation function, we specify two options to experiment with:
a two-qubit uniform superposition and a function that brings only the first
qubit into superposition.

::

    def state_prep_full(qv):
        h(qv[0])
        h(qv[1])

    def state_prep_half(qv):
        h(qv[0])

For the first one, we have $c_0 = c_1 = c_2 = c_3 = \sqrt{0.25}$. The second one
gives $c_0 = c_1 = \sqrt{0.5}$ and $c_2 = c_3 = 0$.
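
These coefficients can be read off the first column of the corresponding
two-qubit matrices. The sketch assumes the integer encoding in which ``qv[0]``
is the least significant bit, so that ``state_prep_half`` corresponds to
$I \otimes H$ and ``state_prep_full`` to $H \otimes H$.

```python
import numpy as np

H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)
I2 = np.eye(2)

# Column 0 of each matrix is the prepared state when starting from |00>
c_full = np.kron(H, H)[:, 0]   # h on qv[0] and qv[1]
c_half = np.kron(I2, H)[:, 0]  # h on qv[0] only (assumed least significant bit)
print(c_full, c_half)
```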
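
The next step is to define the unitaries $U_i$ in the form of a tuple
of functions.

::

    from qrisp.jasp import *
    from qrisp import *

    def case_function_0(x):
        x += 3

    def case_function_1(x):
        x += 4

    def case_function_2(x):
        x += 5

    def case_function_3(x):
        x += 6

    case_functions = (case_function_0,
                      case_function_1,
                      case_function_2,
                      case_function_3)

These functions each represent the unitary

.. math::

    U_i \ket{0} = \ket{i+3}.

Executing a linear combination of unitaries therefore gives

.. math::

    \left( \sum_{i = 0}^N c_i U_i \right) \ket{0} = \sum_{i = 0}^N c_i \ket{i+3}.

Strictly speaking, because the block below applies the state preparation, the
controlled unitaries and then the inverse state preparation before
post-selecting the case indicator on $\ket{0}$, the resulting state is
proportional to $\sum_i |c_i|^2 U_i \ket{0}$. For the two state preparations
above, all nonzero $|c_i|^2$ are equal, so both expressions describe the same
normalized state.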
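
The output distributions of the LCU example can be checked with a small NumPy
statevector calculation that builds the prepare/select/unprepare circuit as
explicit matrices. This is a sketch, not Qrisp code. It assumes the integer
encoding with ``qv[0]`` as the least significant bit (so ``state_prep_half``
corresponds to ``np.kron(I2, H)``) and models ``x += s`` on a 4-qubit
``QuantumFloat`` as a cyclic shift modulo 16; no wraparound occurs for the
values used here. The third, hypothetical state preparation with unequal
amplitudes shows that the post-selected weights are $|c_i|^2$.

```python
import numpy as np

H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)
I2 = np.eye(2)

def shift(s, dim=16):
    # Permutation |k> -> |k + s mod dim>, modeling "x += s" on 4 qubits
    U = np.zeros((dim, dim))
    for k in range(dim):
        U[(k + s) % dim, k] = 1
    return U

def lcu_postselect(V, n_cases=4, dim=16):
    # SELECT = sum_i |i><i| (x) U_i with U_i = shift(i + 3)
    select = sum(np.kron(np.diag(np.eye(n_cases)[i]), shift(i + 3, dim))
                 for i in range(n_cases))
    # Prepare the case indicator, apply SELECT, then undo the preparation
    circuit = np.kron(V.conj().T, np.eye(dim)) @ select @ np.kron(V, np.eye(dim))
    # Apply to |0>|0> and keep the branch where the case indicator is 0
    branch = circuit[:, 0][:dim]
    p_success = float(np.sum(np.abs(branch)**2))
    probs = {k: float(abs(a)**2 / p_success)
             for k, a in enumerate(branch) if abs(a) > 1e-12}
    return probs, p_success

full, p_full = lcu_postselect(np.kron(H, H))   # state_prep_full
half, p_half = lcu_postselect(np.kron(I2, H))  # state_prep_half

# Hypothetical unequal preparation: c_0 = sqrt(0.8), c_1 = sqrt(0.2)
a, b = np.sqrt(0.8), np.sqrt(0.2)
skewed, _ = lcu_postselect(np.kron(I2, np.array([[a, -b], [b, a]])))
print(full, half, skewed)
```

With equal weights the squared coefficients do not change the normalized
result, but for the unequal preparation the outcome 3 appears with
probability $0.64/0.68$ rather than $0.8$.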

Now we implement the LCU procedure.

::

    import numpy as np

    # Specify the corresponding arguments of the block encoding as "static",
    # i.e. compile-time constants.

    @RUS(static_argnums = [2, 3])
    def block_encoding(return_size, state_preparation, case_functions):

        # This QuantumFloat will be returned
        qf = QuantumFloat(return_size)

        # Specify the QuantumVariable that indicates which
        # case to execute
        n = int(np.ceil(np.log2(len(case_functions))))
        case_indicator = QuantumFloat(n)

        # Turn into a list of qubits
        case_indicator_qubits = [case_indicator[i] for i in range(n)]

        # Perform the LCU protocol
        with conjugate(state_preparation)(case_indicator):
            for i in range(len(case_functions)):
                with control(case_indicator_qubits, ctrl_state = i):
                    case_functions[i](qf)

        # Compute the success condition
        success_bool = (measure(case_indicator) == 0)

        return success_bool, qf

Finally, evaluate via the :ref:`terminal_sampling <terminal_sampling>`
decorator:

::

    @terminal_sampling
    def main():
        return block_encoding(4, state_prep_full, case_functions)

    print(main())
    # Yields: {3.0: 0.25, 4.0: 0.25, 5.0: 0.25, 6.0: 0.25}

Evaluate the other state preparation function:

::

    @terminal_sampling
    def main():
        return block_encoding(4, state_prep_half, case_functions)

    print(main())
    # Yields: {3.0: 0.5, 4.0: 0.5}

As expected, the full state preparation function yields a state proportional
to

.. math::

    \ket{3} + \ket{4} + \ket{5} + \ket{6}.

The second state preparation gives us

.. math::

    \ket{3} + \ket{4}.

    """

    if len(trial_function) == 0:
        # Called with keyword arguments only, e.g. @RUS(static_argnums = [2, 3])
        return lambda x : RUS(x, **jit_kwargs)

    trial_function = trial_function[0]

    # The idea for implementing this feature is to execute the function once
    # to collect the output QuantumVariable object.
    # Subsequently, a jaspr is extracted, which is looped over until the
    # condition is met.

    def return_function(*trial_args):

        abs_qs = TracingQuantumSession.get_instance()
        initial_gc_mode = abs_qs.gc_mode
        abs_qs.gc_mode = "auto"
        # Execute the function
        first_iter_res = qache(trial_function, **jit_kwargs)(*trial_args)
        abs_qs.gc_mode = initial_gc_mode

        # Extract the jaspr
        eqn = jax._src.core.thread_local_state.trace_state.trace_stack.dynamic.jaxpr_stack[0].eqns[-1]
        ammended_trial_func_jaspr = eqn.params["jaxpr"].jaxpr

        from qrisp.jasp import collect_environments
        ammended_trial_func_jaspr = collect_environments(ammended_trial_func_jaspr)
        ammended_trial_func_jaspr = ammended_trial_func_jaspr.flatten_environments()

        # Filter out the static arguments.
        # Positions in static_argnums are 1-based (see the docstring), so the
        # trial argument trial_args[i] corresponds to position i + 1.
        if "static_argnums" in jit_kwargs:
            static_argnums = jit_kwargs["static_argnums"]
            if isinstance(static_argnums, int):
                static_argnums = [static_argnums]
            # Copy to avoid mutating the list passed by the user
            static_argnums = list(static_argnums)
        else:
            static_argnums = []

        if "static_argnames" in jit_kwargs:
            static_argnames = jit_kwargs["static_argnames"]
            if isinstance(static_argnames, str):
                static_argnames = [static_argnames]
            # getfullargspec returns a named tuple; .args holds the argument names
            argname_list = inspect.getfullargspec(trial_function).args
            for i in range(len(argname_list)):
                if argname_list[i] in static_argnames:
                    static_argnums.append(i + 1)

        new_trial_args = []
        for i in range(len(trial_args)):
            if i + 1 not in static_argnums:
                new_trial_args.append(trial_args[i])

        trial_args = new_trial_args

        # Flatten the arguments and the result values
        arg_vals, arg_tree_def = jax.tree.flatten(trial_args)
        res_vals, res_tree_def = jax.tree.flatten(first_iter_res)

        # Next we construct the body of the loop.
        # In order to work with the while_loop interface from jax,
        # this function receives a tuple of arguments and also returns
        # a tuple.
        # This tuple contains several sections of argument types:
        # The first entry is the AbstractQuantumCircuit.
        # The next section contains the (flattened) trial function arguments.
        # The final section contains the results from the previous iteration.
        combined_args = tuple([abs_qs.abs_qc] + list(arg_vals) + list(res_vals))

        n_res_vals = len(res_vals)
        n_arg_vals = len(arg_vals)

        def body_fun(args):
            # We now need to deallocate the AbstractQubitArrays from the previous
            # iteration since they are no longer needed.
            res_qv_vals = args[-n_res_vals:]

            abs_qc = args[0]
            for res_val in res_qv_vals:
                if isinstance(res_val.aval, AbstractQubitArray):
                    abs_qc = reset_p.bind(abs_qc, res_val)
                    abs_qc = delete_qubits_p.bind(abs_qc, res_val)

            # Next we evaluate the trial function by evaluating the corresponding jaspr
            # Prepare the arguments tuple
            trial_args = [abs_qc] + list(args[1:1+n_arg_vals])

            # Evaluate the function
            trial_res = ammended_trial_func_jaspr.eval(*trial_args)

            # Return the results in the same layout as combined_args
            return tuple([trial_res[0]] + list(trial_args)[1:] + list(trial_res)[1:])

        def cond_fun(val):
            # The success boolean is the first result value, located directly
            # after the AbstractQuantumCircuit and the trial function arguments.
            # Keep looping while it is False.
            return ~val[1+n_arg_vals]

        # We now evaluate the loop.
        # If the first iteration was already successful, we simply return the results.
        # To realize this behavior we use a cond primitive.
        def true_fun(combined_args):
            return combined_args

        def false_fun(combined_args):
            # Here is the while_loop
            return while_loop(cond_fun, body_fun, init_val = combined_args)

        # Evaluate everything
        combined_res = cond(first_iter_res[0], true_fun, false_fun, combined_args)

        # Update the AbstractQuantumCircuit
        abs_qs.abs_qc = combined_res[0]

        # Extract the results of the trial function
        flat_trial_function_res = combined_res[1+n_arg_vals:1+n_arg_vals+n_res_vals]

        # The results are, however, still "flattened", i.e. if the trial function
        # returned a QuantumVariable, it shows up as an AbstractQubitArray.
        # We call the unflattening function with the auxiliary result values of the
        # first iteration and the traced values of the loop.
        trial_function_res = jax.tree.unflatten(res_tree_def, flat_trial_function_res)

        # Return the results without the success boolean
        if len(first_iter_res) == 2:
            return trial_function_res[1]
        else:
            return trial_function_res[1:]

    return return_function