Source code for qrisp.environments.classical_control_environment

"""
\********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************/
"""

from jax.lax import cond
import jax

from qrisp.environments import QuantumEnvironment
from qrisp.jasp import extract_invalues, insert_outvalues, check_for_tracing_mode

class ClControlEnvironment(QuantumEnvironment):
    r"""
    The ``ClControlEnvironment`` executes quantum code conditioned on classical
    values. Its semantics are similar to those of the :ref:`ControlEnvironment`,
    which means that this environment can also be entered using the ``control``
    keyword.

    .. warning::

        Contrary to the :ref:`ControlEnvironment`, the ``ClControlEnvironment``
        must not have "carry values". This means that no value that is created
        inside this environment may be used outside of the environment.

    Parameters
    ----------
    ctrl_bls : object or list
        The classical value(s) the environment is conditioned on. A value that
        is not a list is wrapped into a one-element list.
    ctrl_state : int or str, optional
        The pattern the control values have to match. Bit ``i`` of the integer
        is the value required of ``ctrl_bls[i]``. A string is parsed as a base-2
        literal, so its last character belongs to ``ctrl_bls[0]``. The default
        ``-1`` requires every control value to be true.
    invert : bool, optional
        If set to ``True``, the condition evaluated by ``jcompile`` is negated.
        The default is ``False``.

    Examples
    ========

    We condition a quantum computation on the outcome of a previous
    measurement.

    ::

        from qrisp import *
        from qrisp.jasp import make_jaspr

        def test_f(i):

            a = QuantumFloat(3)
            a[:] = i
            b = measure(a)

            with control(b == 4):
                x(a[0])

            return measure(a)

        jaspr = make_jaspr(test_f)(1)

    This jaspr receives an integer and encodes that integer into the
    :ref:`QuantumFloat` `a`. Subsequently `a` is measured and an X gate is
    applied onto the 0-th qubit of `a` if the measurement value is 4.

    We can now evaluate the jaspr on several inputs

    >>> jaspr(1)
    1
    >>> jaspr(2)
    2
    >>> jaspr(3)
    3
    >>> jaspr(4)
    5

    We see that in the case where 4 was encoded, the X gate was indeed
    executed.

    To elaborate the restriction of carry values, we give an example that would
    be illegal:

    ::

        def test_f(i):

            a = QuantumFloat(3)
            a[:] = i
            b = measure(a)

            with control(b == 4):
                c = QuantumFloat(2)

            return measure(c)

        jaspr = make_jaspr(test_f)(1)

    This script creates a ``QuantumFloat`` `c` within the classical control
    environment and subsequently uses `c` outside of the environment (in the
    return statement).

    It is however possible to create (quantum-)values within the environment
    and use them still within the environment:

    ::

        from qrisp import *
        from qrisp.jasp import make_jaspr

        def test_f(i):

            a = QuantumFloat(3)
            a[:] = i
            b = measure(a)

            with control(b == 4):
                c = QuantumFloat(2)
                h(c[0])
                d = measure(c)

                # If c is measured to 1
                # flip a and uncompute c
                with control(d == 1):
                    x(a[0])
                    x(c[0])

                c.delete()

            return measure(a)

        jaspr = make_jaspr(test_f)(1)

    This script allocates another :ref:`QuantumFloat` `c` within the
    ClControlEnvironment and applies a Hadamard gate to the 0-th qubit.
    Subsequently the whole ``QuantumFloat`` is measured. If the measurement
    turns out to be one, the zeroth qubit of `a` is flipped (similar to the
    above examples) and furthermore `c` is brought back to the $\ket{0}$ state.

    >>> jaspr(4)
    5
    >>> jaspr(4)
    4

    """

    def __init__(self, ctrl_bls, ctrl_state=-1, invert=False):
        """Store the control values and normalise ``ctrl_state`` to an integer."""
        # A single control value is treated as a list with one entry.
        if not isinstance(ctrl_bls, list):
            ctrl_bls = [ctrl_bls]
        self.ctrl_bls = ctrl_bls

        QuantumEnvironment.__init__(self, env_args=ctrl_bls)

        num_ctrls = len(self.ctrl_bls)

        # Bit strings are turned into integers; the all-ones string is mapped
        # to -1, which the modulo below turns back into the all-ones integer.
        if isinstance(ctrl_state, str):
            if ctrl_state == "1" * num_ctrls:
                ctrl_state = -1
            else:
                ctrl_state = int(ctrl_state, 2)

        # Wrap into [0, 2**num_ctrls) so that negative states become plain
        # bit masks.
        self.ctrl_state = ctrl_state % (2**num_ctrls)
        self.invert = invert

    def compile(self):
        """
        Compile the environment only if every control value matches the
        corresponding bit of ``ctrl_state``; otherwise do nothing.
        """
        # NOTE(review): ``self.invert`` is not consulted here, only in
        # ``jcompile`` - confirm that this is intended.
        for position, ctrl_bl in enumerate(self.ctrl_bls):
            required = bool((self.ctrl_state >> position) & 1)
            if ctrl_bl != required:
                # Condition not met: the body is left uncompiled.
                return

        QuantumEnvironment.compile(self)

    def __exit__(self, exception_type, exception_value, traceback):
        """
        Leave the environment.

        Outside of tracing mode, an exception raised in the body is suppressed
        (by returning ``True``) if at least one control value differs from its
        bit in ``ctrl_state``. In every other case ``None`` is returned and a
        pending exception propagates.
        """
        static_error_appeared = (
            not check_for_tracing_mode()
            and exception_type is not None
            and any(
                (ctrl_bl ^ (self.ctrl_state >> position)) & 1
                for position, ctrl_bl in enumerate(self.ctrl_bls)
            )
        )

        # The parent always gets to do its own exit handling first.
        QuantumEnvironment.__exit__(self, exception_type, exception_value, traceback)

        if static_error_appeared:
            return True

    def jcompile(self, eqn, context_dic):
        """
        Replace the environment equation ``eqn`` by a ``jax.lax.cond`` call
        whose condition is built from the control variables.

        Raises
        ------
        Exception
            If the flattened body has more than one output variable, i.e. the
            environment has a carry value.
        """
        invalues = extract_invalues(eqn, context_dic)
        num_ctrls = len(self.ctrl_bls)

        # Positions 1 .. num_ctrls hold the variables representing the
        # control values.
        ctrl_vars = invalues[1 : num_ctrls + 1]

        # Everything else (position 0 and the tail) is used by the body.
        env_vars = [invalues[0]] + invalues[num_ctrls + 1 :]

        # Flatten the environments in the body
        body_jaspr = eqn.params["jaspr"].flatten_environments()

        if len(body_jaspr.outvars) > 1:
            raise Exception("Found ClControlEnvironment with carry value")

        # Build the condition: AND over all control variables, where a
        # variable is negated if its bit in ctrl_state is 0. A plain bit test
        # suffices, also for -1, because -1 has every bit set.
        cond_bl = ctrl_vars[0]
        if (self.ctrl_state & 1) == 0:
            cond_bl = ~cond_bl

        for position, term in enumerate(ctrl_vars[1:], start=1):
            if ((self.ctrl_state >> position) & 1) == 0:
                term = ~term
            cond_bl = cond_bl & term

        if self.invert:
            cond_bl = ~cond_bl

        # Both branches are traced as the identity on the first operand; the
        # second branch is swapped for the real body further below.
        def identity_fun(*args):
            return args[0]

        res_abs_qc = cond(cond_bl, identity_fun, identity_fun, *env_vars)

        insert_outvalues(eqn, context_dic, [res_abs_qc])

        # Fetch the last equation recorded by the current trace - presumably
        # the cond equation emitted just above (relies on JAX internals;
        # TODO confirm against the JAX version in use).
        cond_eqn = jax._src.core.thread_local_state.trace_state.trace_stack.dynamic.jaxpr_stack[0].eqns[-1]

        identity_branch = cond_eqn.params["branches"][0]
        body_branch = cond_eqn.params["branches"][1]

        from qrisp.jasp import Jaspr

        # Branch 0 keeps the identity (wrapped as a Jaspr), branch 1 becomes
        # the environment body.
        cond_eqn.params["branches"] = (
            jax.core.ClosedJaxpr(
                Jaspr.from_cache(identity_branch.jaxpr), identity_branch.consts
            ),
            jax.core.ClosedJaxpr(body_jaspr, body_branch.consts),
        )