# Source code for qrisp.jasp.program_control.jrange_iterator

"""
\********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************/
"""

import jax.numpy as jnp
from jaxlib.xla_extension import ArrayImpl
from jax import jit

from qrisp.jasp.tracing_logic import check_for_tracing_mode

class JRangeIterator:
    """
    Iterator returned by ``jrange`` when it is called in tracing mode.

    Rather than unrolling the loop, ``__next__`` traces exactly two iterations
    inside a ``JIterationEnvironment``. This lets the environment capture which
    values get updated after each iteration. The third call closes the
    environment and stops. The actual jax loop primitive is compiled later in
    ``JIterationEnvironment.jcompile``.
    """

    def __init__(self, *args):
        # Mirror the three call signatures of the builtin ``range``.
        if not 1 <= len(args) <= 3:
            raise TypeError(
                f"jrange expected 1 to 3 arguments, got {len(args)}"
            )

        if len(args) == 1:
            # A single argument is the stop value. The implicit start of 0 is
            # derived from ``stop`` in __iter__, so it is kept as None here.
            self.start = None
            self.stop = jnp.asarray(args[0], dtype="int64")
            self.step = jnp.asarray(1, dtype="int64")
        else:
            self.start = jnp.asarray(args[0], dtype="int64")
            self.stop = jnp.asarray(args[1], dtype="int64")
            # A third argument denotes a non-trivial step.
            step = args[2] if len(args) == 3 else 1
            self.step = jnp.asarray(step, dtype="int64")

        # The loop index should be inclusive because this makes loop inversion
        # much easier. For more details check inv_transform.py.
        self.stop -= 1

    def __iter__(self):
        self.iteration = 0

        # Create the loop iteration index. Without an explicit start, the index
        # is computed from ``stop`` (stop - stop == 0) instead of a literal 0,
        # presumably so that it is a traced value like ``stop``.
        if self.start is None:
            self.loop_index = self.stop - self.stop
        else:
            self.loop_index = self.start + 0
        return self

    def __next__(self):
        """
        Trace the loop body twice, then stop.

        Call 1 enters the iteration environment and yields the start index.
        Call 2 cleans up, increments the index, re-enters the environment and
        yields the incremented index. Call 3 cleans up, increments, exits the
        environment and raises ``StopIteration``. Any further call keeps
        raising ``StopIteration``.
        """
        # An exhausted iterator must keep raising StopIteration rather than
        # silently returning None (iterator protocol).
        if self.iteration >= 3:
            raise StopIteration

        from qrisp.jasp import TracingQuantumSession

        self.iteration += 1

        if self.iteration == 1:
            from qrisp.environments import JIterationEnvironment

            self.iter_env = JIterationEnvironment()
            self.iter_env.__enter__()

            # Trivial addition on the loop cancellation index. This way the
            # index appears in the collected quantum environment jaxpr and
            # can therefore be identified as such.
            self.stop + 0

            self.iter_1_qvs = list(TracingQuantumSession.get_instance().qv_list)
            return self.loop_index

        if self.iteration == 2:
            self._collect_created_qvs(self.iter_1_qvs)

            # Perform the incrementation.
            self.loop_index += self.step

            # Exit the environment of the first iteration and enter the one
            # for the second iteration.
            self.iter_env.__exit__(None, None, None)
            self.iter_env.__enter__()

            # Same trivial addition as in the first iteration.
            self.stop + 0

            self.iter_2_qvs = list(TracingQuantumSession.get_instance().qv_list)
            return self.loop_index

        # Third call: close the second traced iteration and end the loop.
        self._collect_created_qvs(self.iter_2_qvs)
        self.loop_index += self.step
        self.iter_env.__exit__(None, None, None)
        raise StopIteration

    def _collect_created_qvs(self, previous_qvs):
        """
        Handle QuantumVariables created since ``previous_qvs`` was recorded.

        With ``gc_mode == "auto"``, each such variable is reset and deleted.
        With ``gc_mode == "debug"``, an exception is raised if any exist.
        """
        from qrisp.jasp import TracingQuantumSession
        from qrisp import reset

        qs = TracingQuantumSession.get_instance()
        # Sort the set difference by hash so the iteration order is fixed.
        created_qvs = sorted(set(qs.qv_list) - set(previous_qvs), key=hash)

        if qs.gc_mode == "auto":
            for qv in created_qvs:
                reset(qv)
                qv.delete()
        elif qs.gc_mode == "debug" and created_qvs:
            raise Exception(f"QuantumVariables {created_qvs} went out of scope without deletion during jrange")

def jrange(*args):
    """
    Performs a loop with a dynamic bound. Similar to the Python native
    ``range``, this iterator can receive multiple arguments. If it receives
    just one, this value is interpreted as the stop value and the start value
    is assumed to be 0. Two arguments represent start and stop value, whereas
    three represent start, stop and step.

    .. warning::

        Similar to the :ref:`ClControlEnvironment <ClControlEnvironment>`,
        this feature must not have external carry values, implying values
        computed within the loop can't be used outside of the loop. It is
        however possible to carry on values from the previous iteration.

    .. warning::

        Each loop iteration must perform exactly the same instructions - the
        only thing that changes is the loop index.

    Parameters
    ----------
    start : int, optional
        The loop index to start at. Defaults to 0 if only one argument is
        given.
    stop : int
        The loop index to stop at (exclusive, as for ``range``).
    step : int, optional
        The value to increase the loop index by after each iteration.
        Defaults to 1.

    Returns
    -------
    JRangeIterator or range
        A ``JRangeIterator`` in tracing mode, otherwise a builtin ``range``
        over the arguments converted to ``int``.

    Examples
    --------

    We construct a function that encodes an integer into an arbitrarily sized
    :ref:`QuantumVariable`:

    ::

        from qrisp import QuantumFloat, control, x, measure
        from qrisp.jasp import jrange, make_jaspr, qache

        @qache
        def int_encoder(qv, encoding_int):

            for i in jrange(qv.size):
                with control(encoding_int & (1<<i)):
                    x(qv[i])

        def test_f(a, b):

            qv = QuantumFloat(a)

            int_encoder(qv, b+1)

            return measure(qv)

        jaspr = make_jaspr(test_f)(1,1)

    Test the result:

    >>> jaspr(5, 8)
    9
    >>> jaspr(5, 9)
    10

    We now give examples that violate the above rules (i.e. no carries and
    changing iteration behavior).

    To create a loop with carry behavior we simply return the final loop
    index:

    ::

        @qache
        def int_encoder(qv, encoding_int):

            for i in jrange(qv.size):
                with control(encoding_int & (1<<i)):
                    x(qv[i])

            return i

        def test_f(a, b):

            qv = QuantumFloat(a)

            int_encoder(qv, b+1)

            return measure(qv)

        jaspr = make_jaspr(test_f)(1,1)

    >>> jaspr(5, 8)
    Exception: Found jrange with external carry value

    To demonstrate the second kind of illegal behavior, we construct a loop
    that behaves differently on the first iteration:

    ::

        @qache
        def int_encoder(qv, encoding_int):

            flag = True
            for i in jrange(qv.size):
                if flag:
                    with control(encoding_int & (1<<i)):
                        x(qv[i])
                else:
                    x(qv[0])
                flag = False

        def test_f(a, b):

            qv = QuantumFloat(a)

            int_encoder(qv, b+1)

            return measure(qv)

        jaspr = make_jaspr(test_f)(1,1)

    In this script, ``int_encoder`` defines a boolean flag that changes the
    semantics of the iteration behavior. After the first iteration the flag
    is set to ``False`` such that the alternate behavior is activated.

    >>> jaspr(5, 8)
    Exception: Jax semantics changed during jrange iteration
    """
    if check_for_tracing_mode():
        new_args = []
        for i, arg in enumerate(args):
            # The step (third argument) is forwarded untouched. Concrete start
            # and stop values (Python ints or concrete jax arrays) are turned
            # into traced values via make_tracer.
            if i != 2 and isinstance(arg, (int, ArrayImpl)):
                new_args.append(make_tracer(arg))
            else:
                new_args.append(arg)
        return JRangeIterator(*new_args)

    # Outside of tracing mode a plain Python range suffices. Non-int
    # arguments are converted with int().
    return range(*[arg if isinstance(arg, int) else int(arg) for arg in args])
def make_tracer(x):
    """
    Turn a concrete value into a jax array produced by a jitted function.

    The value is created inside ``jax.jit``. Presumably this is so that, under
    an enclosing trace, it appears as a traced value rather than an inlined
    constant.

    Parameters
    ----------
    x : bool, int, float, complex or array-like with a ``dtype``
        Python scalars are mapped to ``bool_``, ``int64``, ``float64`` and
        ``complex128`` respectively. Objects exposing ``dtype`` and ``shape``
        (e.g. concrete jax arrays such as the ``ArrayImpl`` values forwarded
        by ``jrange``, or numpy arrays/scalars) keep their own dtype.

    Returns
    -------
    jax array holding ``x`` with the chosen dtype.

    Raises
    ------
    Exception
        If ``x`` has an unsupported type.
    """
    if isinstance(x, bool):
        # bool has to be tested before int because bool subclasses int.
        dtype = jnp.bool_
    elif isinstance(x, int):
        dtype = jnp.int64
    elif isinstance(x, float):
        dtype = jnp.float64
    elif isinstance(x, complex):
        # jax has no complex32; complex128 matches the 64-bit int/float choice.
        dtype = jnp.complex128
    elif hasattr(x, "dtype") and hasattr(x, "shape"):
        # Array-likes (e.g. concrete jax arrays from jrange) keep their dtype.
        dtype = x.dtype
    else:
        raise Exception(f"Don't know how to tracerize type {type(x)}")

    def tracerizer():
        return jnp.array(x, dtype)

    return jit(tracerizer)()


def jlen(x):
    """
    Return the length of ``x``.

    Python lists and tuples use ``len``. Any other object is assumed to expose
    a ``size`` attribute (such as a QuantumVariable or an array), which is
    returned.
    """
    if isinstance(x, (list, tuple)):
        return len(x)
    return x.size