Source code for qrisp.jasp.program_control.jrange_iterator
r"""
\********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************/
"""

import jax.numpy as jnp
from jaxlib.xla_extension import ArrayImpl
from jax import jit

from qrisp.jasp.tracing_logic import check_for_tracing_mode


class JRangeIterator:
    """
    Iterator returned by ``jrange`` while in tracing mode.

    The iterator yields the loop index exactly twice. Each of the two loop
    bodies is recorded inside a ``JIterationEnvironment``; comparing the two
    recorded iterations reveals which values are updated per iteration. The
    actual jax loop primitive is compiled later in
    ``JIterationEnvironment.jcompile``.

    Parameters
    ----------
    *args
        One to three values with the same meaning as for the builtin
        ``range``: ``(stop,)``, ``(start, stop)`` or ``(start, stop, step)``.

    Raises
    ------
    TypeError
        If fewer than one or more than three arguments are given.
    """

    def __init__(self, *args):
        # Without this check an invalid signature would only fail further
        # down with an unrelated AttributeError on ``self.stop``.
        if not 1 <= len(args) <= 3:
            raise TypeError(
                f"jrange expected between 1 and 3 arguments, got {len(args)}"
            )

        # Differentiate between the 3 possible cases of input signature
        if len(args) == 1:
            # In the case of one input argument, this argument is the stop
            # value. The missing start value is marked with None and
            # materialized in __iter__.
            self.start = None
            self.stop = jnp.asarray(args[0], dtype="int64")
            self.step = jnp.asarray(1, dtype="int64")
        elif len(args) == 2:
            self.start = jnp.asarray(args[0], dtype="int64")
            self.stop = jnp.asarray(args[1], dtype="int64")
            self.step = jnp.asarray(1, dtype="int64")
        else:
            # Three arguments denote the case of a non-trivial step
            self.start = jnp.asarray(args[0], dtype="int64")
            self.stop = jnp.asarray(args[1], dtype="int64")
            self.step = jnp.asarray(args[2], dtype="int64")

        # The loop index should be inclusive because this makes loop inversion
        # much easier. For more details check inv_transform.py.
        self.stop -= 1

    def __iter__(self):
        """Reset the iteration counter and create the loop index value."""
        self.iteration = 0

        # We create the loop iteration index tracer
        if self.start is None:
            # No start value was given: derive a zero from the stop value.
            self.loop_index = self.stop - self.stop
        else:
            self.loop_index = self.start + 0

        return self

    def __next__(self):
        """
        Advance the two-iteration tracing procedure.

        The first two calls return the loop index; the third call closes the
        iteration environment and raises ``StopIteration``.
        """
        # The idea is now to trace two iterations to capture what values get
        # updated after each iteration.
        # We capture the loop semantics using the JIterationEnvironment.
        # The actual jax loop primitive is then compiled in
        # JIterationEnvironment.jcompile
        from qrisp.jasp import TracingQuantumSession

        self.iteration += 1

        if self.iteration == 1:
            from qrisp.environments import JIterationEnvironment

            self.iter_env = JIterationEnvironment()

            # Enter the environment
            self.iter_env.__enter__()

            # We perform a trivial addition on the loop cancelation index.
            # This way the loop cancelation index will appear in the collected
            # quantum environment jaxpr and can therefore be identified as
            # such.
            self.stop + 0

            # Remember which QuantumVariables exist before the loop body runs.
            self.iter_1_qvs = list(TracingQuantumSession.get_instance().qv_list)
            return self.loop_index

        if self.iteration == 2:
            self._collect_garbage(self.iter_1_qvs)

            # Perform the incrementation
            self.loop_index += self.step

            # Exit the old environment and enter the new one.
            self.iter_env.__exit__(None, None, None)
            self.iter_env.__enter__()

            # Same trivial addition as in the first iteration.
            self.stop + 0

            self.iter_2_qvs = list(TracingQuantumSession.get_instance().qv_list)
            return self.loop_index

        if self.iteration == 3:
            self._collect_garbage(self.iter_2_qvs)
            self.loop_index += self.step
            self.iter_env.__exit__(None, None, None)

        # Raised on the third call and on every later call, so an exhausted
        # iterator stays exhausted.
        raise StopIteration

    def _collect_garbage(self, qvs_before):
        """
        Handle QuantumVariables created since ``qvs_before`` was recorded.

        Parameters
        ----------
        qvs_before : list
            Snapshot of the session's ``qv_list`` taken before the loop body
            of the finished iteration was executed.

        Raises
        ------
        Exception
            If the session's ``gc_mode`` is ``"debug"`` and the loop body left
            QuantumVariables behind.
        """
        from qrisp import reset
        from qrisp.jasp import TracingQuantumSession

        qs = TracingQuantumSession.get_instance()

        # Sort by hash so the variables are processed in a fixed order rather
        # than in set iteration order.
        created_qvs = sorted(set(qs.qv_list) - set(qvs_before), key=hash)

        if qs.gc_mode == "auto":
            for qv in created_qvs:
                reset(qv)
                qv.delete()
        elif qs.gc_mode == "debug" and created_qvs:
            raise Exception(
                f"QuantumVariables {created_qvs} went out of scope without deletion during jrange"
            )
def jrange(*args):
    """
    Performs a loop with a dynamic bound. Similar to the Python native
    ``range``, this iterator can receive multiple arguments. If it receives
    just one, this value is interpreted as the stop value and the start value
    is assumed to be 0. Two arguments represent start and stop value, whereas
    three represent start, stop and step.

    Outside of tracing mode, the arguments are converted to ``int`` and a
    regular Python ``range`` is returned.

    .. warning::

        Similar to the :ref:`ClControlEnvironment <ClControlEnvironment>`, this
        feature must not have external carry values, implying values computed
        within the loop can't be used outside of the loop. It is however
        possible to carry on values from the previous iteration.

    .. warning::

        Each loop iteration must perform exactly the same instructions - the
        only thing that changes is the loop index

    Parameters
    ----------
    start : int
        The loop index to start at.
    stop : int
        The loop index to stop at.
    step : int
        The value to increase the loop index by after each iteration.

    Examples
    --------

    We construct a function that encodes an integer into an arbitrarily sized
    :ref:`QuantumVariable`:

    ::

        from qrisp import QuantumFloat, control, measure, x
        from qrisp.jasp import jrange, make_jaspr, qache

        @qache
        def int_encoder(qv, encoding_int):

            for i in jrange(qv.size):
                with control(encoding_int & (1<<i)):
                    x(qv[i])

        def test_f(a, b):

            qv = QuantumFloat(a)

            int_encoder(qv, b+1)

            return measure(qv)

        jaspr = make_jaspr(test_f)(1,1)

    Test the result:

    >>> jaspr(5, 8)
    9
    >>> jaspr(5, 9)
    10

    We now give examples that violate the above rules (i.e. no carries and
    changing iteration behavior).

    To create a loop with carry behavior we simply return the final loop index

    ::

        @qache
        def int_encoder(qv, encoding_int):

            for i in jrange(qv.size):
                with control(encoding_int & (1<<i)):
                    x(qv[i])
            return i

        def test_f(a, b):

            qv = QuantumFloat(a)

            int_encoder(qv, b+1)

            return measure(qv)

        jaspr = make_jaspr(test_f)(1,1)

    >>> jaspr(5, 8)
    Exception: Found jrange with external carry value

    To demonstrate the second kind of illegal behavior, we construct a loop
    that behaves differently on the first iteration:

    ::

        @qache
        def int_encoder(qv, encoding_int):

            flag = True
            for i in jrange(qv.size):
                if flag:
                    with control(encoding_int & (1<<i)):
                        x(qv[i])
                else:
                    x(qv[0])
                flag = False

        def test_f(a, b):

            qv = QuantumFloat(a)

            int_encoder(qv, b+1)

            return measure(qv)

        jaspr = make_jaspr(test_f)(1,1)

    In this script, ``int_encoder`` defines a boolean flag that changes the
    semantics of the iteration behavior. After the first iteration the flag is
    set to ``False`` such that the alternate behavior is activated.

    >>> jaspr(5, 8)
    Exception: Jax semantics changed during jrange iteration
    """

    if check_for_tracing_mode():
        new_args = []
        for position, arg in enumerate(args):
            # Only start and stop (positions 0 and 1) are passed through
            # make_tracer; the step (position 2) is forwarded as given.
            if position != 2 and isinstance(arg, (int, ArrayImpl)):
                new_args.append(make_tracer(arg))
            else:
                new_args.append(arg)

        return JRangeIterator(*new_args)

    # Not tracing: behave like the builtin range. Anything that is not
    # already an int is converted first.
    return range(*(arg if isinstance(arg, int) else int(arg) for arg in args))
def make_tracer(x):
    """
    Turn a constant into a jax array by building it inside a jitted function.

    Parameters
    ----------
    x : bool, int, float, complex or array-like with a ``dtype`` attribute
        The value to convert. Python scalars are mapped to ``bool``,
        ``int64``, ``float64`` and ``complex128`` respectively; objects that
        carry their own ``dtype`` (e.g. jax or numpy arrays) keep it.

    Returns
    -------
    The result of calling a jitted, argument-free function that returns
    ``jnp.array(x, dtype)``.

    Raises
    ------
    Exception
        If no dtype can be determined for ``x``.
    """
    # bool has to be tested before int because bool is a subclass of int.
    if isinstance(x, bool):
        dtype = jnp.bool_
    elif isinstance(x, int):
        dtype = jnp.int64
    elif isinstance(x, float):
        dtype = jnp.float64
    elif isinstance(x, complex):
        # jax.numpy has no complex32; complex128 matches the 64 bit
        # int/float choices above.
        dtype = jnp.complex128
    elif hasattr(x, "dtype"):
        # Array-like input (jrange forwards ArrayImpl values here).
        dtype = x.dtype
    else:
        raise Exception(f"Don't know how to tracerize type {type(x)}")

    def tracerizer():
        return jnp.array(x, dtype)

    return jit(tracerizer)()


def jlen(x):
    """
    Return the length of ``x``.

    For a ``list`` this is ``len(x)``; for every other object the ``size``
    attribute is returned.
    """
    if isinstance(x, list):
        return len(x)
    return x.size
Get in touch!
If you are interested in Qrisp or high-level quantum algorithm research in general, connect with us on our
Slack workspace.