"""
\********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************/
"""

import inspect

from jax.lax import while_loop, cond
import jax
import jax.numpy as jnp

from qrisp.circuit import XGate
from qrisp.jasp import TracingQuantumSession, AbstractQubitArray, DynamicQubitArray, qache
from qrisp.jasp.primitives import Measurement_p, OperationPrimitive, get_qubit_p, get_size_p, delete_qubits_p, reset_p

def RUS(*trial_function, **jit_kwargs):
    r"""
    Decorator to deploy repeat-until-success (RUS) components. At its core, RUS
    repeats a given quantum subroutine, followed by a qubit measurement, until
    the measurement returns the value ``1``. This step is prevalent in many
    important algorithms, among them the
    `HHL algorithm <https://arxiv.org/abs/0811.3171>`_ and the
    `LCU procedure <https://arxiv.org/abs/1202.5822>`_.

    Within Jasp, RUS steps can be realized by providing the quantum subroutine
    as a "trial function". The first return value of the trial function is a
    boolean (the success condition), and further return values are possible.
    The trial function is executed repeatedly until the success condition is
    ``True``.

    Note that the trial function cannot receive quantum arguments. After each
    trial, a fresh copy of these arguments would be required to perform the
    next iteration, which is prohibited by the no-cloning theorem. Classical
    arguments are, however, permitted.

    Parameters
    ----------
    trial_function : callable
        A function returning a boolean value as its first return value. Further
        return values are possible.
    static_argnums : int or list[int], optional
        An integer or a list of integers specifying which arguments are
        considered static in the sense of
        `jax.jit <https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html>`_.
        The first argument is indicated by 1, the second by 2, etc. The default
        is ``[]``.
    static_argnames : str or list[str], optional
        A string or a list of strings specifying which arguments are considered
        static in the sense of
        `jax.jit <https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html>`_.
        The default is ``[]``.

    Returns
    -------
    callable
        A function that performs the RUS protocol with the trial function. The
        return values of this function are the return values of the trial
        function WITHOUT the boolean value. If only a single value remains, it
        is returned directly rather than as a tuple.

    Examples
    --------

    To demonstrate the RUS behavior, we initialize a GHZ state

    .. math::

        \ket{\psi} = \frac{\ket{00000} + \ket{11111}}{\sqrt{2}}

    and measure the first qubit into a boolean value. This will be the value
    that cancels the repetition.

    This will collapse the GHZ state into either $\ket{00000}$ (which will
    cause a new repetition) or $\ket{11111}$ (which cancels the loop). After
    the repetition is canceled, we are therefore guaranteed to have the latter
    state.

    ::

        from qrisp.jasp import RUS, make_jaspr
        from qrisp import QuantumFloat, h, cx, measure

        @RUS
        def rus_trial_function():
            qf = QuantumFloat(5)
            h(qf[0])

            for i in range(1, 5):
                cx(qf[0], qf[i])

            cancelation_bool = measure(qf[0])
            return cancelation_bool, qf

        def call_RUS_example():

            qf = rus_trial_function()

            return measure(qf)

    Create the ``jaspr`` and simulate:

    ::

        jaspr = make_jaspr(call_RUS_example)()
        print(jaspr())
        # Yields: 31, which is the decimal version of 11111

    **Static arguments**

    To demonstrate the specification of static arguments, we will implement a
    simple `linear combination of unitaries <https://arxiv.org/abs/1202.5822>`_.
    Our implementation initializes a state proportional to

    .. math::

        \left( \sum_{i = 0}^N |c_i|^2 U_i \right) \ket{0}.

    We achieve this by specifying a set of unitaries $U_i$ in the form of a
    tuple of functions, each processing a :ref:`QuantumFloat`. The coefficients
    $c_i$ are specified through a function preparing the state

    .. math::

        \ket{\psi} = \sum_{i = 0}^N c_i \ket{i}.

    The LCU protocol applies this preparation, then the controlled unitaries,
    and then the inverse preparation. Projecting the case register back onto
    $\ket{0}$ therefore weights each $U_i$ by $|c_i|^2$.

    For the state preparation, we specify two options to experiment with: a
    two-qubit uniform superposition, and a function that brings only the first
    qubit into superposition.

    ::

        def state_prep_full(qv):
            h(qv[0])
            h(qv[1])

        def state_prep_half(qv):
            h(qv[0])

    For the first one, we have $c_0 = c_1 = c_2 = c_3 = \sqrt{0.25}$. The second
    one gives $c_0 = c_1 = \sqrt{0.5}$ and $c_2 = c_3 = 0$.

    The next step is to define the unitaries $U_i$ in the form of a tuple of
    functions.

    ::

        from qrisp.jasp import *
        from qrisp import *

        def case_function_0(x):
            x += 3

        def case_function_1(x):
            x += 4

        def case_function_2(x):
            x += 5

        def case_function_3(x):
            x += 6

        case_functions = (case_function_0,
                          case_function_1,
                          case_function_2,
                          case_function_3)

    Each of these functions represents a unitary $U_i$ with

    .. math::

        U_i \ket{0} = \ket{i+3}.

    Executing the linear combination of unitaries therefore gives

    .. math::

        \left( \sum_{i = 0}^N |c_i|^2 U_i \right) \ket{0} = \sum_{i = 0}^N |c_i|^2 \ket{i+3}.

    Now we implement the LCU procedure. The trial succeeds if the case
    indicator is measured in the state $\ket{0}$.

    ::

        import numpy as np

        # Specify the corresponding arguments of the block encoding as "static",
        # i.e. compile-time constants.

        @RUS(static_argnums = [2,3])
        def block_encoding(return_size, state_preparation, case_functions):

            # This QuantumFloat will be returned
            qf = QuantumFloat(return_size)

            # Specify the QuantumVariable that indicates which
            # case to execute
            n = int(np.ceil(np.log2(len(case_functions))))
            case_indicator = QuantumFloat(n)

            # Turn into a list of qubits
            case_indicator_qubits = [case_indicator[i] for i in range(n)]

            # Perform the LCU protocol
            with conjugate(state_preparation)(case_indicator):
                for i in range(len(case_functions)):
                    with control(case_indicator_qubits, ctrl_state = i):
                        case_functions[i](qf)

            # Compute the success condition
            success_bool = (measure(case_indicator) == 0)

            return success_bool, qf

    Finally, evaluate via the :ref:`terminal_sampling <terminal_sampling>` feature:

    ::

        @terminal_sampling
        def main():
            return block_encoding(4, state_prep_full, case_functions)

        print(main())
        # Yields: {3.0: 0.25, 4.0: 0.25, 5.0: 0.25, 6.0: 0.25}

    Evaluate the other state preparation function:

    ::

        @terminal_sampling
        def main():
            return block_encoding(4, state_prep_half, case_functions)

        print(main())
        # Yields: {3.0: 0.5, 4.0: 0.5}

    As expected, the full state preparation function yields a state
    proportional to

    .. math::

        \ket{3} + \ket{4} + \ket{5} + \ket{6}.

    The second state preparation gives us

    .. math::

        \ket{3} + \ket{4}.

    """

    if len(trial_function) == 0:
        return lambda x: RUS(x, **jit_kwargs)
    else:
        trial_function = trial_function[0]

    # The idea for implementing this feature is to execute the function once
    # to collect the output QuantumVariable object.
    # Subsequently, a jaspr is extracted, which is looped over until the
    # condition is met.

    def return_function(*trial_args):

        abs_qs = TracingQuantumSession.get_instance()

        initial_gc_mode = abs_qs.gc_mode
        abs_qs.gc_mode = "auto"

        # Execute the function
        first_iter_res = qache(trial_function, **jit_kwargs)(*trial_args)

        abs_qs.gc_mode = initial_gc_mode

        # Extract the jaspr
        eqn = jax._src.core.thread_local_state.trace_state.trace_stack.dynamic.jaxpr_stack[0].eqns[-1]
        ammended_trial_func_jaspr = eqn.params["jaxpr"].jaxpr

        from qrisp.jasp import collect_environments

        ammended_trial_func_jaspr = collect_environments(ammended_trial_func_jaspr)
        ammended_trial_func_jaspr = ammended_trial_func_jaspr.flatten_environments()

        # Filter out the static arguments
        if "static_argnums" in jit_kwargs:
            static_argnums = jit_kwargs["static_argnums"]
            if isinstance(static_argnums, int):
                static_argnums = [static_argnums]
            # Copy, so that appending below does not mutate the caller's list
            static_argnums = list(static_argnums)
        else:
            static_argnums = []

        if "static_argnames" in jit_kwargs:
            # FullArgSpec.args holds the names of the positional parameters
            argname_list = inspect.getfullargspec(trial_function).args
            for i in range(len(argname_list)):
                if argname_list[i] in jit_kwargs["static_argnames"]:
                    static_argnums.append(i)

        new_trial_args = []
        for i in range(len(trial_args)):
            if i not in static_argnums:
                new_trial_args.append(trial_args[i])

        trial_args = new_trial_args

        # Flatten the arguments and the result values
        arg_vals, arg_tree_def = jax.tree.flatten(trial_args)
        res_vals, res_tree_def = jax.tree.flatten(first_iter_res)

        # Next, we construct the body of the loop.
        # In order to work with the while_loop interface from jax,
        # this function receives a tuple of arguments and also returns
        # a tuple.
        # This tuple consists of three sections:
        # The first entry is the AbstractQuantumCircuit.
        # The next section contains the (flattened) trial function arguments.
        # The final section contains the (flattened) results of the previous
        # iteration, starting with the success boolean.

        combined_args = tuple([abs_qs.abs_qc] + list(arg_vals) + list(res_vals))

        n_res_vals = len(res_vals)
        n_arg_vals = len(arg_vals)

        def body_fun(args):
            # We now need to deallocate the AbstractQubitArrays from the previous
            # iteration since they are no longer needed.
            res_qv_vals = args[-n_res_vals:]

            abs_qc = args[0]
            for res_val in res_qv_vals:
                if isinstance(res_val.aval, AbstractQubitArray):
                    abs_qc = reset_p.bind(abs_qc, res_val)
                    abs_qc = delete_qubits_p.bind(abs_qc, res_val)

            # Next, we evaluate the trial function by evaluating the
            # corresponding jaspr.
            # Prepare the arguments tuple
            trial_args = [abs_qc] + list(args[1:1+n_arg_vals])

            # Evaluate the function
            trial_res = ammended_trial_func_jaspr.eval(*trial_args)

            # Return the results
            return tuple([trial_res[0]] + list(trial_args)[1:] + list(trial_res)[1:])

        def cond_fun(val):
            # The success boolean is the first result value, located directly
            # after the circuit and the n_arg_vals trial arguments.
            # Keep looping while it is False.
            return ~val[1+n_arg_vals]

        # We now evaluate the loop.
        # If the first iteration was already successful, we simply return the
        # results. To realize this behavior, we use a cond primitive.
        def true_fun(combined_args):
            return combined_args

        def false_fun(combined_args):
            # Here is the while_loop
            return while_loop(cond_fun, body_fun, init_val = combined_args)

        # Evaluate everything
        combined_res = cond(first_iter_res[0], true_fun, false_fun, combined_args)

        # Update the AbstractQuantumCircuit
        abs_qs.abs_qc = combined_res[0]

        # Extract the results of the trial function
        flat_trial_function_res = combined_res[1+n_arg_vals:1+n_arg_vals+n_res_vals]

        # The results are, however, still "flattened", i.e. if the trial function
        # returned a QuantumVariable, it shows up as an AbstractQubitArray.
        # We call the unflattening function with the tree structure of the
        # first iteration's results and the traced values of the loop.
        trial_function_res = jax.tree.unflatten(res_tree_def, flat_trial_function_res)

        # Return the results
        if len(first_iter_res) == 2:
            return trial_function_res[1]
        else:
            return trial_function_res[1:]

    return return_function
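
The control flow described in the `RUS` docstring can be sketched as follows. The trial runs once, runs again while its flag is `False`, and the flag is then dropped from the output. This is a purely classical sketch in plain Python. Randomness stands in for the quantum measurement, and no quantum state is simulated. `rus`, `ghz_like_trial`, `two_results_trial` and `average_trials` are hypothetical names for illustration and are not part of Qrisp.

```python
import random


def rus(trial_function):
    """Classical stand-in for the RUS control flow (hypothetical helper).

    The trial function returns (success_flag, *results). It is executed once
    and then re-executed while the flag is False. The flag is dropped from the
    output; a single remaining result is returned on its own, not as a tuple.
    """
    def wrapper(*args):
        res = trial_function(*args)
        while not res[0]:
            res = trial_function(*args)
        if len(res) == 2:
            return res[1]
        return res[1:]
    return wrapper


@rus
def ghz_like_trial(rng):
    # Hypothetical trial: succeeds with probability 1/2, like measuring the
    # first qubit of a GHZ state; on success the register holds 0b11111.
    success = rng.random() < 0.5
    return success, (31 if success else 0)


@rus
def two_results_trial():
    # Always succeeds; several results come back as a tuple.
    return True, "a", "b"


def average_trials(n_runs, seed=0):
    # Count the trials needed per run. For success probability p the count is
    # geometrically distributed with mean 1/p (here p = 1/2).
    rng = random.Random(seed)
    total = 0
    for _ in range(n_runs):
        total += 1
        while rng.random() >= 0.5:
            total += 1
    return total / n_runs


print(ghz_like_trial(random.Random(1234)))  # 31
print(two_results_trial())                  # ('a', 'b')
print(average_trials(10_000))               # close to 2.0
```

Passing `rng` as an argument mirrors the rule that trial functions may receive classical, but not quantum, arguments.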
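
The GHZ example in the docstring states that the loop always ends with the value 31. A small statevector check in plain Python supports this. It also shows that each trial succeeds with probability 1/2, so on average two trials are needed. The sketch assumes that qubit 0 is the least significant bit of the `QuantumFloat`, which does not matter for this symmetric state. `ghz_amplitudes` and `postselect_qubit` are hypothetical helpers, not Qrisp functions.

```python
import math


def ghz_amplitudes(n):
    # (|0...0> + |1...1>) / sqrt(2), indexed by the register's integer value
    amps = [0.0] * (2 ** n)
    amps[0] = amps[-1] = 1 / math.sqrt(2)
    return amps


def postselect_qubit(amps, qubit, outcome):
    # Project onto the basis states where `qubit` equals `outcome`, then
    # renormalize. Returns (probability of that outcome, projected state).
    kept = [a if ((i >> qubit) & 1) == outcome else 0.0 for i, a in enumerate(amps)]
    p = sum(a * a for a in kept)
    return p, [a / math.sqrt(p) for a in kept]


p_success, post = postselect_qubit(ghz_amplitudes(5), qubit=0, outcome=1)
distribution = {i: round(a * a, 12) for i, a in enumerate(post) if a != 0.0}

print(round(p_success, 12))  # 0.5
print(distribution)          # {31: 1.0}
print(1 / p_success)         # expected number of trials, approximately 2
```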
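
The sampled distributions of the LCU example can be reproduced classically by simulating the three LCU steps. The steps are: Hadamard preparation on the case register, then the controlled shifts $\ket{0} \to \ket{i+3}$, then the inverse preparation, followed by keeping the case-register outcome 0. This is a minimal sketch under two assumptions. First, qubit 0 of the case indicator is its least significant bit. Second, `ctrl_state = i` selects the basis state with integer value `i`. The docstring's reported `{3.0: 0.5, 4.0: 0.5}` for `state_prep_half` is consistent with this ordering. `hadamard` and `lcu_postselected` are hypothetical helpers, not Qrisp API.

```python
import math

S = 1 / math.sqrt(2)


def hadamard(amps, qubit):
    # Apply a Hadamard gate to `qubit` of a register stored as a list of real
    # amplitudes (qubit 0 = least significant bit).
    out = [0.0] * len(amps)
    mask = 1 << qubit
    for i, a in enumerate(amps):
        if i & mask:        # |1> -> (|0> - |1>) / sqrt(2)
            out[i ^ mask] += S * a
            out[i] -= S * a
        else:               # |0> -> (|0> + |1>) / sqrt(2)
            out[i] += S * a
            out[i | mask] += S * a
    return out


def lcu_postselected(prep_qubits, offsets=(3, 4, 5, 6), n_case=2):
    # PREP: Hadamards on `prep_qubits` of the case register, starting in |0>
    c = [1.0] + [0.0] * (2 ** n_case - 1)
    for q in prep_qubits:
        c = hadamard(c, q)
    post = {}
    for i, v in enumerate(offsets):
        # SELECT: case |i> maps the target |0> to |offsets[i]>
        case_amps = [0.0] * len(c)
        case_amps[i] = c[i]
        # PREP^dagger: this product of Hadamards is its own inverse
        for q in reversed(prep_qubits):
            case_amps = hadamard(case_amps, q)
        # Success condition: the case register is measured as 0
        post[v] = case_amps[0]
    p_success = sum(a * a for a in post.values())
    dist = {v: round(a * a / p_success, 9) for v, a in post.items() if abs(a) > 1e-12}
    return p_success, dist


p_full, dist_full = lcu_postselected([0, 1])  # like state_prep_full
p_half, dist_half = lcu_postselected([0])     # like state_prep_half
print(dist_full, round(p_full, 9))  # {3: 0.25, 4: 0.25, 5: 0.25, 6: 0.25} 0.25
print(dist_half, round(p_half, 9))  # {3: 0.5, 4: 0.5} 0.5
```

In this sketch the success probabilities are 0.25 and 0.5. The RUS loop would therefore need 4 and 2 trials on average, respectively.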
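
The index arithmetic in `return_function` can be exercised without JAX. The pieces are the carry tuple `(abs_qc, *arg_vals, *res_vals)`, the success flag at `1 + n_arg_vals`, and the final slice `[1 + n_arg_vals : 1 + n_arg_vals + n_res_vals]`. This toy sketch makes several assumptions. The pure-Python `while_loop` and `cond` mirror the documented semantics of `jax.lax.while_loop` and `jax.lax.cond`, and `not` replaces JAX's `~` on booleans. A trial counter stands in for the `AbstractQuantumCircuit`. The qubit reset and deletion performed by the real `body_fun` is omitted. `run_rus_layout` is a hypothetical name.

```python
def while_loop(cond_fun, body_fun, init_val):
    # Pure-Python version of the documented jax.lax.while_loop semantics
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def cond(pred, true_fun, false_fun, operand):
    # Pure-Python version of the documented jax.lax.cond semantics
    return true_fun(operand) if pred else false_fun(operand)


def run_rus_layout(threshold):
    # Toy trial that succeeds once `threshold` trials have been made.
    # carry = (circuit, *arg_vals, *res_vals) with res_vals = (flag, payload);
    # an int counting trials stands in for the circuit.
    arg_vals = (threshold,)
    n_arg_vals, n_res_vals = len(arg_vals), 2

    def trial(n_done, thr):
        n_done += 1
        return n_done, (n_done >= thr, 10 * n_done)

    # The first iteration is executed outside the loop
    circuit, res_vals = trial(0, *arg_vals)
    combined_args = (circuit, *arg_vals, *res_vals)

    def body_fun(carry):
        args = carry[1:1 + n_arg_vals]
        new_circuit, new_res = trial(carry[0], *args)
        return (new_circuit, *args, *new_res)

    def cond_fun(carry):
        # The success flag sits right after the circuit and the arguments
        return not carry[1 + n_arg_vals]

    combined_res = cond(res_vals[0], lambda c: c,
                        lambda c: while_loop(cond_fun, body_fun, c),
                        combined_args)
    results = combined_res[1 + n_arg_vals:1 + n_arg_vals + n_res_vals]
    return combined_res[0], results


print(run_rus_layout(1))  # (1, (True, 10)): first trial succeeded, no loop
print(run_rus_layout(3))  # (3, (True, 30)): two additional loop iterations
```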