Source code for qrisp.jasp.jasp_expression.centerclass
"""\********************************************************************************* Copyright (c) 2025 the Qrisp authors** This program and the accompanying materials are made available under the* terms of the Eclipse Public License 2.0 which is available at* http://www.eclipse.org/legal/epl-2.0.** This Source Code may also be made available under the following Secondary* Licenses when the conditions for such availability set forth in the Eclipse* Public License, v. 2.0 are satisfied: GNU General Public License, version 2* with the GNU Classpath Exception which is* available at https://www.gnu.org/software/classpath/license.html.** SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0********************************************************************************/"""fromfunctoolsimportlru_cacheimportjaxfromjaximportmake_jaxprfromjax.coreimportJaxpr,Literalfromjax.tree_utilimporttree_flatten,tree_unflattenfromjax.errorsimportUnexpectedTracerErrorfromqrisp.jasp.jasp_expressionimportinvert_jaspr,collect_environmentsfromqrisp.jaspimporteval_jaxpr,pjit_to_gate,flatten_environments,cond_to_cl_controlfromqrisp.jasp.primitivesimportAbstractQuantumCircuit
class Jaspr(Jaxpr):
    """
    The ``Jaspr`` class enables an efficient representation of a wide variety
    of (hybrid) algorithms. For many applications, the representation is agnostic
    to the scale of the problem, implying function calls with 10 or 10000 qubits
    can be represented by the same object. The actual unfolding to a circuit-level
    description is outsourced to
    `established, classical compilation infrastructure <https://mlir.llvm.org/>`_,
    implying state-of-the-art compilation speed can be reached.

    As a subtype of ``jax.core.Jaxpr``, Jasprs are embedded into the mature
    `Jax ecosystem <https://github.com/n2cholas/awesome-jax>`_,
    which facilitates the compilation of classical
    `real-time computation <https://arxiv.org/abs/2206.12950>`_
    using some of the most advanced libraries in the world such as
    `CUDA <https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html>`_.
    Especially `machine learning <https://ai.google.dev/gemma/docs/jax_inference>`_
    and other scientific computation tasks are particularly well supported.

    To get a better understanding of the syntax and semantics of Jaxpr (and with
    that also Jaspr) please check
    `this link <https://jax.readthedocs.io/en/latest/jaxpr.html>`__.

    Similar to Jaxpr, Jaspr objects represent (hybrid) quantum algorithms in the
    form of a `functional programming language <https://en.wikipedia.org/wiki/Functional_programming>`_
    in `SSA-form <https://en.wikipedia.org/wiki/Static_single-assignment_form>`_.

    It is possible to compile Jaspr objects into QIR, which is facilitated by the
    `Catalyst framework <https://docs.pennylane.ai/projects/catalyst/en/stable/index.html>`__
    (check :meth:`to_catalyst_jaxpr` and :meth:`qjit` for more details).

    Qrisp scripts can be turned into Jaspr objects by calling the ``make_jaspr``
    function, which has similar semantics as
    `jax.make_jaxpr <https://jax.readthedocs.io/en/latest/_autosummary/jax.make_jaxpr.html>`_.

    ::

        from qrisp import *
        from qrisp.jasp import make_jaspr

        def test_fun(i):

            qv = QuantumFloat(i, -1)
            x(qv[0])
            cx(qv[0], qv[i-1])
            meas_res = measure(qv)
            meas_res += 1
            return meas_res


        jaspr = make_jaspr(test_fun)(4)
        print(jaspr)

    This will give you the following output:

    .. code-block::

        { lambda ; a:QuantumCircuit b:i32[]. let
            c:QuantumCircuit d:QubitArray = create_qubits a b
            e:Qubit = get_qubit d 0
            f:QuantumCircuit = x c e
            g:i32[] = sub b 1
            h:Qubit = get_qubit d g
            i:QuantumCircuit = cx f e h
            j:QuantumCircuit k:i32[] = measure i d
            l:f32[] = convert_element_type[new_dtype=float64 weak_type=True] k
            m:f32[] = mul l 0.5
            n:f32[] = add m 1.0
          in (j, n) }

    A defining feature of the Jaspr class is that the first input and the
    first output are always of QuantumCircuit type. Therefore, Jaspr objects always
    represent some (hybrid) quantum operation.

    Qrisp comes with a built-in Jaspr interpreter. To use it, simply call the
    object like a function:

    >>> print(jaspr(2))
    2.5
    >>> print(jaspr(4))
    5.5

    """

    __slots__ = "permeability", "isqfree", "hashvalue", "ctrl_jaspr", "envs_flattened", "consts"

    def __init__(self, *args, permeability=None, isqfree=None, ctrl_jaspr=None, **kwargs):
        """
        Create a Jaspr either from an existing Jaxpr (passed positionally or as
        ``jaxpr=``) or from the raw Jaxpr constructor keywords
        (``constvars``, ``invars``, ``outvars``, ``eqns``, ``effects``, ``debug_info``).

        Raises
        ------
        Exception
            If the first input or the first output is not of QuantumCircuit type.
        """

        # A single positional argument is interpreted as the Jaxpr to wrap.
        if len(args) == 1:
            kwargs["jaxpr"] = args[0]

        if "jaxpr" in kwargs:
            jaxpr = kwargs["jaxpr"]
            # Reuse the hash of the wrapped Jaxpr so that equal source objects hash alike.
            self.hashvalue = hash(jaxpr)
            Jaxpr.__init__(self,
                           constvars=jaxpr.constvars,
                           invars=jaxpr.invars,
                           outvars=jaxpr.outvars,
                           eqns=jaxpr.eqns,
                           effects=jaxpr.effects,
                           debug_info=jaxpr.debug_info)
        else:
            self.hashvalue = id(self)
            Jaxpr.__init__(self, **kwargs)

        # Record the permeability of every non-literal signature variable
        # (None when the caller did not provide a value for it).
        self.permeability = {}
        if permeability is None:
            permeability = {}
        for var in self.constvars + self.invars + self.outvars:
            if isinstance(var, Literal):
                continue
            self.permeability[var] = permeability.get(var, None)

        self.isqfree = isqfree
        self.ctrl_jaspr = ctrl_jaspr
        self.envs_flattened = False
        self.consts = []

        if not isinstance(self.invars[0].aval, AbstractQuantumCircuit):
            raise Exception(f"Tried to create a Jaspr from data that doesn't have a QuantumCircuit as first argument (got {type(self.invars[0].aval)} instead)")

        if not isinstance(self.outvars[0].aval, AbstractQuantumCircuit):
            raise Exception(f"Tried to create a Jaspr from data that doesn't have a QuantumCircuit as first entry of return type (got {type(self.outvars[0].aval)} instead)")

    def __hash__(self):
        return self.hashvalue

    def __eq__(self, other):
        # Jasprs compare by identity only.
        if not isinstance(other, Jaxpr):
            return False
        return id(self) == id(other)

    def copy(self):
        """
        Return a copy of this Jaspr.

        The variable/equation lists are copied shallowly, the controlled
        Jaspr (if any) is copied recursively and the ``envs_flattened`` flag
        as well as the constant values are carried over.
        """
        if self.ctrl_jaspr is None:
            ctrl_jaspr = None
        else:
            ctrl_jaspr = self.ctrl_jaspr.copy()

        res = Jaspr(permeability=self.permeability,
                    isqfree=self.isqfree,
                    ctrl_jaspr=ctrl_jaspr,
                    constvars=list(self.constvars),
                    invars=list(self.invars),
                    outvars=list(self.outvars),
                    eqns=list(self.eqns),
                    effects=self.effects,
                    debug_info=self.debug_info)

        res.envs_flattened = self.envs_flattened
        # The constvars are copied above, so the values bound to them have to be
        # carried over too; otherwise the copy would start with an empty consts list.
        res.consts = list(self.consts)
        return res

    def inverse(self):
        """
        Returns the inverse Jaspr (if applicable). For Jaspr that contain realtime
        computations or measurements, the inverse does not exist.

        Returns
        -------
        Jaspr
            The daggered Jaspr.

        Examples
        --------

        We create a simple script and inspect the daggered version:

        ::

            from qrisp import *
            from qrisp.jasp import make_jaspr

            def example_function(i):

                qv = QuantumVariable(i)
                cx(qv[0], qv[1])
                t(qv[1])
                return qv

            jaspr = make_jaspr(example_function)(2)

            print(jaspr.inverse())
            # Yields
            # { lambda ; a:QuantumCircuit b:i32[]. let
            #     c:QuantumCircuit d:QubitArray = create_qubits a b
            #     e:Qubit = get_qubit d 0
            #     f:Qubit = get_qubit d 1
            #     g:QuantumCircuit = t_dg c f
            #     h:QuantumCircuit = cx g e f
            #   in (h, d) }

        """
        return invert_jaspr(self)

    def control(self, num_ctrl, ctrl_state=-1):
        """
        Returns the controlled version of the Jaspr. The control qubits are added
        to the signature of the Jaspr as the arguments after the QuantumCircuit.

        Parameters
        ----------
        num_ctrl : int
            The number of controls to be added.
        ctrl_state : int or str, optional
            The control state on which to activate. The default is -1, which
            activates on the all-ones state.

        Returns
        -------
        Jaspr
            The controlled Jaspr.

        Examples
        --------

        We create a simple script and inspect the controlled version:

        ::

            from qrisp import *
            from qrisp.jasp import make_jaspr

            def example_function(i):

                qv = QuantumVariable(i)
                cx(qv[0], qv[1])
                t(qv[1])
                return qv

            jaspr = make_jaspr(example_function)(2)

            print(jaspr.control(2))
            # Yields
            # { lambda ; a:QuantumCircuit b:Qubit c:Qubit d:i32[]. let
            #     e:QuantumCircuit f:QubitArray = create_qubits a 1
            #     g:Qubit = get_qubit f 0
            #     h:QuantumCircuit = ccx e b c g
            #     i:QuantumCircuit j:QubitArray = create_qubits h d
            #     k:Qubit = get_qubit j 0
            #     l:Qubit = get_qubit j 1
            #     m:QuantumCircuit = ccx i g k l
            #     n:QuantumCircuit = ct m g l
            #     o:QuantumCircuit = ccx n b c g
            #   in (o, j) }

        We see that the control qubits are part of the function signature
        (``b`` and ``c``).

        """

        # Fast path: a pre-computed singly-controlled version is available.
        if self.ctrl_jaspr is not None and num_ctrl == 1 and ctrl_state == -1:
            return self.ctrl_jaspr

        from qrisp.jasp import ControlledJaspr

        # Normalize the control state to a bitstring of length num_ctrl.
        if isinstance(ctrl_state, int):
            if ctrl_state < 0:
                ctrl_state += 2**num_ctrl

            ctrl_state = bin(ctrl_state)[2:].zfill(num_ctrl)
        else:
            ctrl_state = str(ctrl_state)

        return ControlledJaspr.from_cache(self, ctrl_state)

    def to_qc(self, *args):
        """
        Converts the Jaspr into a :ref:`QuantumCircuit` if applicable. Circuit
        conversion of algorithms involving realtime computations is not possible.

        Parameters
        ----------
        *args : tuple
            The arguments to call the Jaspr with.

        Returns
        -------
        :ref:`QuantumCircuit`
            The resulting QuantumCircuit.
        return_values : tuple
            The return values of the Jaspr. QuantumVariable return types are
            returned as lists of Qubits.

        Examples
        --------

        We create a simple script and inspect the QuantumCircuit:

        ::

            from qrisp import *
            from qrisp.jasp import make_jaspr

            def example_function(i):

                qv = QuantumVariable(i)
                cx(qv[0], qv[1])
                t(qv[1])
                return qv

            jaspr = make_jaspr(example_function)(2)

            qc, qb_list = jaspr.to_qc(2)
            print(qc)
            # Yields
            # qb_0: ──■───────
            #       ┌─┴─┐┌───┐
            # qb_1: ┤ X ├┤ T ├
            #       └───┘└───┘

        """
        from qrisp import QuantumCircuit, Clbit

        jaspr = self

        def eqn_evaluator(eqn, context_dic):
            # Sub-Jasprs called via pjit are turned into gates.
            if eqn.primitive.name == "pjit" and isinstance(eqn.params["jaxpr"].jaxpr, Jaspr):
                return pjit_to_gate(eqn, context_dic, eqn_evaluator)
            # Conditionals are turned into classically controlled operations.
            elif eqn.primitive.name == "cond":
                return cond_to_cl_control(eqn, context_dic, eqn_evaluator)
            # Type conversions of measurement results simply forward the Clbit.
            elif eqn.primitive.name == "convert_element_type":
                if isinstance(context_dic[eqn.invars[0]], Clbit):
                    context_dic[eqn.outvars[0]] = context_dic[eqn.invars[0]]
                    return
            # Everything else is evaluated by the default interpreter.
            return True

        res = eval_jaxpr(jaspr, eqn_evaluator=eqn_evaluator)(*([QuantumCircuit()] + list(args)))

        return res

    def eval(self, *args, eqn_evaluator=lambda x, y: True):
        """Evaluate the Jaspr on ``args`` with ``eval_jaxpr`` and the given equation evaluator."""
        return eval_jaxpr(self, eqn_evaluator=eqn_evaluator)(*args)

    def flatten_environments(self):
        """
        Flattens all environments by applying the corresponding compilation
        routines such that no more ``q_env`` primitives are left.

        Returns
        -------
        Jaspr
            The Jaspr with flattened environments.

        Examples
        --------

        We create a Jaspr containing an :ref:`InversionEnvironment` and flatten.
        Since ``make_jaspr`` flattens environments by default, we pass
        ``flatten_envs=False`` to inspect the unflattened representation:

        ::

            def test_function(i):
                qv = QuantumVariable(i)

                with invert():
                    t(qv[0])
                    cx(qv[0], qv[1])

                return qv

            jaspr = make_jaspr(test_function, flatten_envs=False)(2)
            print(jaspr)

        ::

            { lambda ; a:QuantumCircuit b:i32[]. let
                c:QuantumCircuit d:QubitArray = create_qubits a b
                e:QuantumCircuit = q_env[
                  jaspr={ lambda ; f:QuantumCircuit d:QubitArray. let
                      g:Qubit = get_qubit d 0
                      h:QuantumCircuit = t f g
                      i:Qubit = get_qubit d 1
                      j:QuantumCircuit = cx h g i
                    in (j,) }
                  type=InversionEnvironment
                ] c d
              in (e, d) }

        You can see how the body of the :ref:`InversionEnvironment` is **collected**
        into another Jaspr. This reflects the fact that at their core,
        :ref:`QuantumEnvironments <QuantumEnvironment>` describe
        `higher-order quantum functions <https://en.wikipedia.org/wiki/Higher-order_function>`_
        (i.e. functions that operate on functions). In order to apply the
        transformations induced by the QuantumEnvironment, we can call
        ``jaspr.flatten_environments()``:

        >>> print(jaspr.flatten_environments())
        { lambda ; a:QuantumCircuit b:i32[]. let
            c:QuantumCircuit d:QubitArray = create_qubits a b
            e:Qubit = get_qubit d 0
            f:Qubit = get_qubit d 1
            g:QuantumCircuit = cx c e f
            h:QuantumCircuit = t_dg g e
          in (h, d) }

        We see that as expected, the order of the ``cx`` and the ``t`` gate has
        been switched and the ``t`` gate has been turned into a ``t_dg``.

        """
        # Note: inside the method body this name resolves to the module-level
        # flatten_environments imported from qrisp.jasp, not to this method.
        res = flatten_environments(self)
        if self.ctrl_jaspr is not None:
            res.ctrl_jaspr = self.ctrl_jaspr.flatten_environments()
        return res

    def __call__(self, *args):
        """Simulate the Jaspr on ``args`` using the built-in Jaspr simulator."""
        from qrisp.jasp.evaluation_tools.jaspification import simulate_jaspr

        return simulate_jaspr(self, *args)

    def inline(self, *args):
        """
        Evaluate the Jaspr inside the currently tracing quantum session.

        The session's abstract QuantumCircuit is fed as first argument and
        replaced by the returned one. Returns the remaining outputs as a tuple,
        or None if the Jaspr only returns the QuantumCircuit.
        """
        from qrisp.jasp import TracingQuantumSession

        qs = TracingQuantumSession.get_instance()
        abs_qc = qs.abs_qc

        res = eval_jaxpr(self)(*([abs_qc] + list(args)))

        if isinstance(res, tuple):
            new_abs_qc = res[0]
            res = res[1:]
        else:
            new_abs_qc = res
            res = None

        qs.abs_qc = new_abs_qc
        return res

    def embedd(self, *args, name=None, inline=False):
        """
        Embed the Jaspr into the currently tracing quantum session.

        If ``inline`` is False, the Jaspr is called through ``jax.jit``. The
        resulting equation is patched to reference this Jaspr and optionally
        renamed to ``name``. Otherwise the Jaspr is evaluated inline. Returns
        the non-circuit outputs (tuple) or None.
        """
        from qrisp.jasp import TracingQuantumSession

        qs = TracingQuantumSession.get_instance()
        abs_qc = qs.abs_qc

        if not inline:
            res = jax.jit(eval_jaxpr(self))(*([abs_qc] + list(args)))
            # NOTE(review): relies on JAX private internals to fetch the pjit
            # equation that was just recorded; may break across JAX versions.
            eqn = jax._src.core.thread_local_state.trace_state.trace_stack.dynamic.jaxpr_stack[0].eqns[-1]
            eqn.params["jaxpr"] = jax.core.ClosedJaxpr(self, eqn.params["jaxpr"].consts)
            if name is not None:
                eqn.params["name"] = name
        else:
            res = eval_jaxpr(self)(*([abs_qc] + list(args)))

        if isinstance(res, tuple):
            new_abs_qc = res[0]
            res = res[1:]
        else:
            new_abs_qc = res
            res = None

        qs.abs_qc = new_abs_qc
        return res

    def qjit(self, *args, function_name="jaspr_function"):
        """
        Leverages the Catalyst pipeline to compile a QIR representation of
        this function and executes that function using the Catalyst QIR runtime.

        Parameters
        ----------
        *args : iterable
            The arguments to call the function with.
        function_name : str, optional
            The name given to the compiled function. The default is
            "jaspr_function".

        Returns
        -------
            The values returned by the compiled, executed function. A single
            return value is unwrapped from its tuple/list.

        """
        # NOTE(review): despite the name, no environment flattening happens here.
        flattened_jaspr = self

        from qrisp.jasp.evaluation_tools.catalyst_interface import jaspr_to_catalyst_qjit
        qjit_obj = jaspr_to_catalyst_qjit(flattened_jaspr, function_name=function_name)
        res = qjit_obj.compiled_function(*args)
        if not isinstance(res, (tuple, list)):
            return res
        elif len(res) == 1:
            return res[0]
        else:
            return res

    def to_qasm(self, *args):
        """
        Compiles the Jaspr into an OpenQASM 2 string. Real-time control is
        possible as long as no computations on the measurement results are
        performed.

        Parameters
        ----------
        *args : list
            The arguments to call the :ref:`QuantumCircuit` evaluation with.

        Returns
        -------
        str
            The OpenQASM 2 string.

        Examples
        --------

        We create a simple script and inspect the QASM 2 string:

        ::

            from qrisp import *
            from qrisp.jasp import make_jaspr

            def main(i):
                qv = QuantumVariable(i)
                cx(qv[0], qv[1])
                t(qv[1])
                return qv

            jaspr = make_jaspr(main)(2)

            qasm_str = jaspr.to_qasm(2)
            print(qasm_str)
            # Yields
            # OPENQASM 2.0;
            # include "qelib1.inc";
            # qreg qb_59[1];
            # qreg qb_60[1];
            # cx qb_59[0],qb_60[0];
            # t qb_60[0];

        It is also possible to compile simple real-time control features:

        ::

            def main(phi):

                qf = QuantumFloat(5)

                h(qf)

                bl = measure(qf[0])

                with control(bl):
                    rz(phi, qf[1])
                    x(qf[1])

                return

            jaspr = make_jaspr(main)(0.5)
            print(jaspr.to_qasm(0.5))

        This gives:

        ::

            OPENQASM 2.0;
            include "qelib1.inc";
            qreg qb_59[1];
            qreg qb_60[1];
            qreg qb_61[1];
            qreg qb_62[1];
            qreg qb_63[1];
            creg cb_0[1];
            h qb_59[0];
            h qb_60[0];
            h qb_61[0];
            reset qb_61[0];
            h qb_62[0];
            reset qb_62[0];
            h qb_63[0];
            reset qb_63[0];
            measure qb_59[0] -> cb_0[0];
            reset qb_59[0];
            if(cb_0==1) rz(0.5) qb_60[0];
            if(cb_0==1) x qb_60[0];
            reset qb_60[0];

        """
        res = self.to_qc(*args)
        # With a single output, to_qc returns the circuit itself rather than a tuple.
        if len(self.outvars) == 1:
            res = [res]
        qrisp_qc = res[0]
        return qrisp_qc.qasm()

    def to_catalyst_jaxpr(self):
        """
        Compiles the jaspr to the corresponding
        `Catalyst jaxpr <https://docs.pennylane.ai/projects/catalyst/en/stable/index.html>`__.
        Environments are flattened before the conversion.

        Returns
        -------
        jax.core.Jaxpr
            The Jaxpr using Catalyst primitives.

        Examples
        --------

        We create a simple script and inspect the Catalyst Jaxpr:

        ::

            from qrisp import *
            from qrisp.jasp import make_jaspr

            def example_function(i):

                qv = QuantumFloat(i)
                cx(qv[0], qv[1])
                t(qv[1])

                meas_res = measure(qv)
                meas_res += 1
                return meas_res

            jaspr = make_jaspr(example_function)(2)

            print(jaspr.to_catalyst_jaxpr())
            # Yields
            # { lambda ; a:AbstractQreg() b:i64[] c:i32[]. let
            #     d:i64[] = convert_element_type[new_dtype=int64 weak_type=True] c
            #     e:i64[] = add b d
            #     f:i64[] = add b 0
            #     g:i64[] = add b 1
            #     h:AbstractQbit() = qextract a f
            #     i:AbstractQbit() = qextract a g
            #     j:AbstractQbit() k:AbstractQbit() = qinst[op=CNOT qubits_len=2] h i
            #     l:AbstractQreg() = qinsert a f j
            #     m:AbstractQreg() = qinsert l g k
            #     n:AbstractQbit() = qextract m g
            #     o:AbstractQbit() = qinst[op=T qubits_len=1] n
            #     p:AbstractQreg() = qinsert m g o
            #     q:i64[] = convert_element_type[new_dtype=int64 weak_type=True] c
            #     r:i64[] = add b q
            #     _:i64[] s:i64[] t:AbstractQreg() _:i64[] _:i64[] = while_loop[
            #       body_jaxpr={ lambda ; u:i64[] v:i64[] w:AbstractQreg() x:i64[] y:i64[]. let
            #           z:AbstractQbit() = qextract w u
            #           ba:bool[] bb:AbstractQbit() = qmeasure z
            #           bc:AbstractQreg() = qinsert w u bb
            #           bd:i64[] = sub u x
            #           be:i64[] = shift_left 2 bd
            #           bf:i64[] = convert_element_type[new_dtype=int64 weak_type=True] ba
            #           bg:i64[] = mul be bf
            #           bh:i64[] = add v bg
            #           bi:i64[] = add u 1
            #         in (bi, bh, bc, x, y) }
            #       body_nconsts=0
            #       cond_jaxpr={ lambda ; bj:i64[] bk:i64[] bl:AbstractQreg() bm:i64[] bn:i64[]. let
            #           bo:bool[] = ge bj bn
            #         in (bo,) }
            #       cond_nconsts=0
            #       nimplicit=0
            #       preserve_dimensions=True
            #     ] b 0 p b r
            #     bp:i32[] = convert_element_type[new_dtype=int64 weak_type=False] s
            #     bq:i32[] = mul bp 1
            #     br:i32[] = add bq 1
            #   in (t, e, br) }

        """
        from qrisp.jasp.evaluation_tools.catalyst_interface import jaspr_to_catalyst_jaxpr
        return jaspr_to_catalyst_jaxpr(self.flatten_environments())
def make_jaspr(fun, garbage_collection="auto", flatten_envs=True, **jax_kwargs):
    """
    Create a function that traces ``fun`` into a :class:`Jaspr`.

    Parameters
    ----------
    fun : callable
        The Qrisp function to trace.
    garbage_collection : str, optional
        Forwarded to ``TracingQuantumSession.start_tracing``. The default is "auto".
    flatten_envs : bool, optional
        If True, the collected quantum environments are flattened via
        :meth:`Jaspr.flatten_environments`. The default is True.
    **jax_kwargs
        Keyword arguments forwarded to ``jax.make_jaxpr``. ``static_argnums``
        (an int or a list/tuple of ints) is shifted by one, because the traced
        function receives the AbstractQuantumCircuit as an additional first
        argument.

    Returns
    -------
    callable
        A function that, called with the arguments for ``fun``, returns the
        traced Jaspr.
    """
    from qrisp.jasp import AbstractQuantumCircuit, TracingQuantumSession, check_for_tracing_mode
    from qrisp.core.quantum_variable import QuantumVariable, flatten_qv, unflatten_qv
    from qrisp.core import recursive_qv_search

    def jaspr_creator(*args, **kwargs):

        qs = TracingQuantumSession.get_instance()

        # Close any tracing quantum sessions that might have not been
        # properly closed due to whatever reason.
        if not check_for_tracing_mode():
            while qs.abs_qc is not None:
                qs.conclude_tracing()

        # This function will be traced by Jax.
        # Note that we add the abs_qc keyword as the tracing quantum circuit
        def ammended_function(abs_qc, *args, **kwargs):

            qs.start_tracing(abs_qc, garbage_collection)

            # If the signature contains QuantumVariables, these QuantumVariables went
            # through a flattening/unflattening procedure. The unflattening creates
            # a copy of the QuantumVariable object, which is however not yet registered in any
            # QuantumSession. We register these QuantumVariables in the current QuantumSession.
            arg_qvs = recursive_qv_search(args)
            for qv in arg_qvs:
                qs.register_qv(qv, None)

            try:
                res = fun(*args, **kwargs)
            except Exception:
                # Leave the session in a clean state, then propagate the
                # original exception with its traceback untouched.
                qs.conclude_tracing()
                raise

            res_qvs = recursive_qv_search(res)

            qs.garbage_collection(spare_qv_list=arg_qvs + res_qvs)

            res_qc = qs.conclude_tracing()

            return res_qc, res

        try:
            closed_jaxpr = make_jaxpr(ammended_function, **jax_kwargs)(AbstractQuantumCircuit(), *args, **kwargs)
        except UnexpectedTracerError as e:
            if "intermediate value with type QuantumCircuit" in str(e):
                raise Exception(
                    "Lost track of QuantumCircuit during tracing. This might have been "
                    "caused by a missing quantum_kernel decorator. Please visit "
                    "https://www.qrisp.eu/reference/Jasp/Quantum%20Kernel.html for more details"
                ) from e
            raise

        jaxpr = closed_jaxpr.jaxpr

        # Collect the environments
        # This means that the quantum environments no longer appear as
        # enter/exit primitives but as primitive that "call" a certain Jaspr.
        # NOTE(review): Jaspr.from_cache is not defined in the visible class
        # body; confirm it is provided elsewhere.
        res = Jaspr.from_cache(collect_environments(jaxpr))

        if flatten_envs:
            res = res.flatten_environments()

        res.consts = closed_jaxpr.consts

        return res

    # Since we are calling the "ammended function", where the first parameter
    # is the AbstractQuantumCircuit, we need to move the static_argnums indicator.
    # jaspr_creator reads jax_kwargs lazily through its closure, so rebinding
    # the name here is seen by it.
    if "static_argnums" in jax_kwargs:
        jax_kwargs = dict(jax_kwargs)
        static_argnums = jax_kwargs["static_argnums"]

        def _shift(argnum):
            # Negative indices count from the end of the signature and are
            # unaffected by prepending the QuantumCircuit argument.
            return argnum + 1 if argnum >= 0 else argnum

        if isinstance(static_argnums, list):
            jax_kwargs["static_argnums"] = [_shift(argnum) for argnum in static_argnums]
        elif isinstance(static_argnums, tuple):
            jax_kwargs["static_argnums"] = tuple(_shift(argnum) for argnum in static_argnums)
        else:
            jax_kwargs["static_argnums"] = _shift(static_argnums)

    return jaspr_creator


def check_aval_equivalence(invars_1, invars_2):
    """
    Return True if both variable sequences have the same length and the
    abstract values (``.aval``) at each position are of the same type.
    """
    avals_1 = [invar.aval for invar in invars_1]
    avals_2 = [invar.aval for invar in invars_2]

    # Signatures of different length can never be equivalent.
    if len(avals_1) != len(avals_2):
        return False

    return all(type(aval_1) == type(aval_2) for aval_1, aval_2 in zip(avals_1, avals_2))
Get in touch!
If you are interested in Qrisp or high-level quantum algorithm research in general, connect with us on our
Slack workspace.