Source code for qrisp.environments.custom_control_environment
"""\********************************************************************************* Copyright (c) 2025 the Qrisp authors** This program and the accompanying materials are made available under the* terms of the Eclipse Public License 2.0 which is available at* http://www.eclipse.org/legal/epl-2.0.** This Source Code may also be made available under the following Secondary* Licenses when the conditions for such availability set forth in the Eclipse* Public License, v. 2.0 are satisfied: GNU General Public License, version 2* with the GNU Classpath Exception which is* available at https://www.gnu.org/software/classpath/license.html.** SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0********************************************************************************/"""importinspectimportjaximportjax.numpyasjnpfromqrisp.environments.quantum_environmentsimportQuantumEnvironmentfromqrisp.environments.gate_wrap_environmentimportGateWrapEnvironmentfromqrisp.circuitimportOperation,QuantumCircuit,Instructionfromqrisp.environments.iteration_environmentimportIterationEnvironmentfromqrisp.coreimportmergefromqrisp.jaspimportcheck_for_tracing_mode,qache,AbstractQubit,make_jaspr
def custom_control(*func, **cusc_kwargs):
    """
    The ``custom_control`` decorator allows specifying the controlled version of
    the decorated function. If this function is called within a
    :ref:`ControlEnvironment` or a :ref:`ConditionEnvironment`, the controlled
    version is executed instead.

    Specific controlled versions of quantum functions are very common in many
    scientific publications. This is because the general control procedure can
    significantly increase resource demands.

    In order to use the ``custom_control`` decorator, you need to add the ``ctrl``
    keyword to your function signature. If called within a controlled context,
    this keyword will receive the corresponding control qubit.

    For more details consult the examples section.


    Parameters
    ----------
    func : function, optional
        A function of QuantumVariables, which has the ``ctrl`` keyword. If no
        function is given (i.e. the decorator is used as
        ``@custom_control(**kwargs)``), a decorator is returned that applies
        ``custom_control`` with these keyword arguments.
    **cusc_kwargs
        Keyword arguments forwarded to ``qache`` and, in tracing mode, to
        ``make_jaspr`` when the controlled version is traced.

    Returns
    -------
    adaptive_control_function : function
        A function which will execute its controlled version, if called
        within a :ref:`ControlEnvironment` or a :ref:`ConditionEnvironment`.

    Notes
    -----
    If the decorated function has a ``ctrl_method`` parameter and the enclosing
    environment is a :ref:`ControlEnvironment`, the environment's
    ``ctrl_method`` is passed to the function.

    Examples
    --------

    We create a swap function with custom control.

    ::

        from qrisp import mcx, cx, custom_control

        @custom_control
        def swap(a, b, ctrl = None):

            if ctrl is None:

                cx(a, b)
                cx(b, a)
                cx(a, b)

            else:

                cx(a, b)
                mcx([ctrl, b], a)
                cx(a, b)


    Test the non-controlled version:

    ::

        from qrisp import QuantumBool

        a = QuantumBool()
        b = QuantumBool()

        swap(a, b)

        print(a.qs)

    ::

        QuantumCircuit:
        --------------
                  ┌───┐
        a.0: ──■──┤ X ├──■──
             ┌─┴─┐└─┬─┘┌─┴─┐
        b.0: ┤ X ├──■──┤ X ├
             └───┘     └───┘
        Live QuantumVariables:
        ---------------------
        QuantumBool a
        QuantumBool b


    Test the controlled version:

    ::

        from qrisp import control

        a = QuantumBool()
        b = QuantumBool()
        ctrl_qbl = QuantumBool()

        with control(ctrl_qbl):

            swap(a,b)

        print(a.qs.transpile(1))

    ::

                         ┌───┐
               a.0: ──■──┤ X ├──■──
                    ┌─┴─┐└─┬─┘┌─┴─┐
               b.0: ┤ X ├──■──┤ X ├
                    └───┘  │  └───┘
        ctrl_qbl.0: ───────■───────

    """

    # Support both ``@custom_control`` and ``@custom_control(**kwargs)``:
    # without a positional function, return a decorator that re-enters with
    # the captured keyword arguments.
    if len(func) == 0:
        return lambda x: custom_control(x, **cusc_kwargs)
    else:
        func = func[0]

    # The idea to realize the custom control feature in traced mode is to
    # first trace the non-controlled version into a pjit primitive using
    # the qache feature and then trace the controlled version.
    # The controlled version is then stored in the params attribute.

    # Qache the function (in non-traced mode, this has no effect)
    func = qache(func, **cusc_kwargs)

    def adaptive_control_function(*args, **kwargs):
        # Dispatches to the non-controlled or controlled version depending on
        # the surrounding environment (non-traced mode) or records a traced
        # controlled version alongside the qached call (tracing mode).

        if not check_for_tracing_mode():

            # Local imports, presumably to avoid circular imports at module
            # load time — NOTE(review): confirm.
            from qrisp.core import recursive_qs_search
            from qrisp import (
                merge,
                ControlEnvironment,
                ConditionEnvironment,
                QuantumEnvironment,
                InversionEnvironment,
                ConjugationEnvironment,
            )

            # Collect the quantum sessions of all quantum objects found in the
            # positional arguments (and of an explicitly passed ``ctrl`` qubit)
            # and merge them, so that the environment stack of the first
            # session can be inspected below.
            qs_list = recursive_qs_search(args)

            if "ctrl" in kwargs:
                if kwargs["ctrl"] is not None:
                    # presumably ``qs()`` returns the session of the ctrl qubit
                    # — NOTE(review): confirm.
                    qs_list.append(kwargs["ctrl"].qs())

            merge(qs_list)

            # No quantum arguments: there is no environment to inspect.
            if len(qs_list) == 0:
                return func(*args, **kwargs)

            qs = qs_list[0]

            # Search for a Control/Condition Environment and get the control qubit.
            # The stack is walked from the innermost environment outwards:
            # - plain (un-subclassed) QuantumEnvironments are skipped,
            # - Inversion/Conjugation/GateWrap environments are skipped,
            # - IterationEnvironments are skipped unless ``precompile`` is set,
            # - any other environment ends the search without a control qubit.
            control_qb = None
            for env in qs.env_stack[::-1]:
                if type(env) == QuantumEnvironment:
                    continue
                if isinstance(env, (ControlEnvironment, ConditionEnvironment)):
                    control_qb = env.condition_truth_value
                    break
                if not isinstance(
                    env,
                    (InversionEnvironment, ConjugationEnvironment, GateWrapEnvironment),
                ):
                    if isinstance(env, IterationEnvironment):
                        if env.precompile:
                            break
                        else:
                            continue
                    break

            # If no control qubit was found, simply execute the function
            if control_qb is None:
                return func(*args, **kwargs)

            # Check whether the function supports the ctrl_method kwarg and adjust
            # the kwargs accordingly. ``env`` is the environment at which the loop
            # above stopped, i.e. the Control/Condition environment providing
            # ``control_qb``.
            # NOTE(review): ``getfullargspec(...)[0]`` is the ``args`` field only,
            # so a keyword-only ``ctrl_method`` parameter is not detected. It also
            # inspects the qached wrapper, not the original function — confirm
            # that qache preserves the signature.
            if "ctrl_method" in list(inspect.getfullargspec(func))[0] and isinstance(
                env, ControlEnvironment
            ):
                kwargs.update({"ctrl_method": env.ctrl_method})

            # In the case that a qubit was found, we use the CustomControlEnvironment
            # (defined below). This environment gate-wraps the function and compiles
            # it to a specific Operation subtype called CustomControlledOperation.
            # The Condition/Control Environment compiler recognizes this Operation type
            # and processes it accordingly.
            with CustomControlEnvironment(control_qb, func.__name__):
                # Pass the control qubit as ``ctrl``, overriding an explicit
                # value if one was given.
                if "ctrl" in kwargs:
                    kwargs["ctrl"] = control_qb
                    res = func(*args, **kwargs)
                else:
                    res = func(*args, ctrl=control_qb, **kwargs)

        else:

            # Convert plain Python scalars among the positional arguments into
            # jax arrays, presumably so the qached call sees array arguments
            # instead of Python constants — TODO confirm intent.
            # ``bool`` is checked before ``int`` because bool is a subclass of int.
            # NOTE(review): without ``jax_enable_x64`` JAX may downcast the
            # 64-bit dtypes below to 32-bit — confirm the expected config.
            args = list(args)
            for i in range(len(args)):
                if isinstance(args[i], bool):
                    args[i] = jnp.array(args[i], dtype=jnp.bool)
                elif isinstance(args[i], int):
                    args[i] = jnp.array(args[i], dtype=jnp.int64)
                elif isinstance(args[i], float):
                    args[i] = jnp.array(args[i], dtype=jnp.float64)
                elif isinstance(args[i], complex):
                    args[i] = jnp.array(args[i], dtype=jnp.complex64)

            # Call the (qached) function
            res = func(*args, **kwargs)

            # Retrieve the pjit equation: assumed to be the last equation of
            # the first jaxpr on the dynamic jaxpr stack, i.e. the one the
            # qached call above just emitted.
            # NOTE(review): this reaches into private JAX internals
            # (``jax._src.core``) and may break across JAX versions.
            jit_eqn = jax._src.core.thread_local_state.trace_state.trace_stack.dynamic.jaxpr_stack[0].eqns[-1]

            # Trace the controlled version only once per cached jaxpr; later
            # calls find ``ctrl_jaspr`` already set.
            if not jit_eqn.params["jaxpr"].jaxpr.ctrl_jaspr:
                # Trace the controlled version
                new_kwargs = dict(kwargs)
                ctrl_aval = AbstractQubit()
                new_kwargs["ctrl"] = ctrl_aval

                from qrisp.jasp import TracingQuantumSession

                # NOTE(review): ``abs_qs`` is not used below; confirm whether the
                # ``get_instance()`` call is needed for a side effect.
                abs_qs = TracingQuantumSession.get_instance()

                controlled_jaspr = make_jaspr(func, **cusc_kwargs)(*args, **new_kwargs)

                # Find the variable that contains the control qubit.
                # NOTE(review): there is no ``else`` clause. If no invar's aval is
                # ``ctrl_aval``, ``i`` stays at the last index and that variable
                # is moved instead. For an empty invars list, ``i`` is unbound.
                for i, invar in enumerate(controlled_jaspr.invars):
                    if invar.aval is ctrl_aval:
                        break

                # Move it to the place after the QuantumCircuit argument
                controlled_jaspr.invars.insert(1, controlled_jaspr.invars.pop(i))

                # Store controlled version
                jit_eqn.params["jaxpr"].jaxpr.ctrl_jaspr = controlled_jaspr

        return res

    return adaptive_control_function