"""
\********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************/
"""
from jax.lax import fori_loop, while_loop, cond
import jax
from qrisp.jasp.tracing_logic import TracingQuantumSession
from qrisp.jasp.primitives import AbstractQuantumCircuit
def q_while_loop(cond_fun, body_fun, init_val):
    """
    Jasp compatible version of
    `jax.lax.while_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html#jax.lax.while_loop>`_

    The parameters and semantics are the same as for the Jax version.

    In particular the following loop is performed

    ::

        def q_while_loop(cond_fun, body_fun, init_val):
            val = init_val
            while cond_fun(val):
                val = body_fun(val)
            return val

    Parameters
    ----------
    cond_fun : callable
        A function that evaluates the condition of the while loop. Must not
        contain any quantum operations.
    body_fun : callable
        A function describing the body of the loop.
    init_val : object
        An object to initialize the loop.

    Raises
    ------
    Exception
        Tried to modify quantum state during while condition evaluation.
    Exception
        The first input of the traced loop body is not the quantum circuit
        (reported as an implicit variable import).

    Returns
    -------
    val
        The loop value after the last iteration of ``body_fun``.

    Examples
    --------

    We write a dynamic loop that collects measurement values of qubits into
    an accumulator. Note that the accumulator variable is a carry value
    implying the loop could not be implemented using :ref:`jrange`.

    ::

        from qrisp import *
        from qrisp.jasp import *

        @jaspify
        def main(k):

            qf = QuantumFloat(6)

            def body_fun(val):
                i, acc, qf = val
                x(qf[i])
                acc += measure(qf[i])
                i += 1
                return i, acc, qf

            def cond_fun(val):
                return val[0] < 5

            i, acc, qf = q_while_loop(cond_fun, body_fun, (0, 0, qf))

            return acc, measure(qf)

        print(main(6))
        # Yields
        # (Array(5, dtype=int64), Array(31., dtype=float64))

    """
    qs = TracingQuantumSession.get_instance()

    def new_cond_fun(val):
        # ``val`` is the extended carry ``(abs_qc, user_val)``; only the user
        # part is handed to the user's condition.
        temp_qc = qs.abs_qc
        res = cond_fun(val[1])
        # If the session's circuit object was replaced while evaluating the
        # condition, the condition touched the quantum state.
        if qs.abs_qc is not temp_qc:
            raise Exception("Tried to modify quantum state during while condition evaluation")
        return res

    def new_body_fun(val):
        # Trace the user body against the circuit carried through the loop and
        # hand the resulting circuit back as part of the carry.
        qs.start_tracing(val[0])
        res = body_fun(val[1])
        abs_qc = qs.conclude_tracing()
        return (abs_qc, res)

    # Prepend the current abstract circuit to the carry.
    abs_qc = qs.abs_qc
    new_init_val = (abs_qc, init_val)

    while_res = while_loop(new_cond_fun, new_body_fun, new_init_val)

    # Most recently recorded equation of the current trace — presumably the
    # while equation emitted by the call above.
    # NOTE(review): this relies on JAX-internal state (``jax._src``) and must
    # stay directly after the ``while_loop`` call.
    eqn = jax._src.core.thread_local_state.trace_state.trace_stack.dynamic.jaxpr_stack[0].eqns[-1]
    body_jaxpr = eqn.params["body_jaxpr"].jaxpr

    # The quantum circuit is expected to be the first input of the body.
    if not isinstance(body_jaxpr.invars[0].aval, AbstractQuantumCircuit):
        raise Exception("Found implicit variable import in q_while. Please make sure all used variables are part of the body signature.")

    # Function-scope import, same path as used by ``q_cond``.
    from qrisp.jasp import Jaspr

    # Swap the plain body jaxpr for its Jaspr counterpart, keeping the consts.
    eqn.params["body_jaxpr"] = jax.core.ClosedJaxpr(
        Jaspr.from_cache(body_jaxpr), eqn.params["body_jaxpr"].consts
    )

    # Continue the session with the circuit returned by the loop.
    qs.abs_qc = while_res[0]
    return while_res[1]
def q_fori_loop(lower, upper, body_fun, init_val):
    """
    Jasp compatible version of
    `jax.lax.fori_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html#jax.lax.fori_loop>`_

    The parameters and semantics are the same as for the Jax version.

    In particular the following loop is performed

    ::

        def q_fori_loop(lower, upper, body_fun, init_val):
            val = init_val
            for i in range(lower, upper):
                val = body_fun(i, val)
            return val

    Parameters
    ----------
    lower : int or jax.core.Tracer
        An integer representing the loop index lower bound (inclusive).
    upper : int or jax.core.Tracer
        An integer representing the loop index upper bound (exclusive).
    body_fun : callable
        The function describing the loop body.
    init_val : object
        Some object to initialize the loop with.

    Returns
    -------
    val : object
        The return value of body_fun after the final iteration.

    Examples
    --------

    We write a dynamic loop that collects measurement values of qubits into
    an accumulator:

    ::

        from qrisp import *
        from qrisp.jasp import *

        @jaspify
        def main(k):

            qf = QuantumFloat(6)

            def body_fun(i, val):
                acc, qf = val
                x(qf[i])
                acc += measure(qf[i])
                return acc, qf

            acc, qf = q_fori_loop(0, k, body_fun, (0, qf))

            return acc, measure(qf)

        print(main(5))
        # Yields:
        # (Array(5, dtype=int64), Array(31., dtype=float64))

    """
    # The loop is expressed through ``q_while_loop`` with the carry
    # ``(user_val, index, upper)``.

    def new_body_fun(val):
        body_val, i, upper_bound = val
        # Run the user body, advance the index, pass the bound through.
        return (body_fun(i, body_val), i + 1, upper_bound)

    def new_cond_fun(val):
        _, i, upper_bound = val
        return i < upper_bound

    # Only the user part of the final carry is returned.
    return q_while_loop(new_cond_fun, new_body_fun, (init_val, lower, upper))[0]
def q_cond(pred, true_fun, false_fun, *operands):
    r"""
    Jasp compatible version of
    `jax.lax.cond <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.cond.html#jax.lax.cond>`_

    The parameters and semantics are the same as for the Jax version.

    Performs the following semantics:

    ::

        def q_cond(pred, true_fun, false_fun, *operands):
            if pred:
                return true_fun(*operands)
            else:
                return false_fun(*operands)

    Parameters
    ----------
    pred : bool or jax.core.Tracer
        A boolean value, deciding which function gets executed.
    true_fun : callable
        The function that is executed when ``pred`` is True.
    false_fun : callable
        The function that is executed when ``pred`` is False.
    *operands : tuple
        The input values for both functions.

    Raises
    ------
    Exception
        The first input of the traced branch is not the quantum circuit
        (reported as an implicit variable import).

    Returns
    -------
    object
        The return value of the respective function.

    Examples
    --------

    We write a script that brings a :ref:`QuantumBool` into superposition and
    subsequently measures it. If the measurement result is ``False`` we flip
    it such that in the end, the bool will always be in the $\ket{\text{True}}$
    state.

    ::

        from qrisp import *
        from qrisp.jasp import *

        @jaspify
        def main():

            def false_fun(qbl):
                qbl.flip()
                return qbl

            def true_fun(qbl):
                return qbl

            qbl = QuantumBool()
            h(qbl)
            pred = measure(qbl)

            qbl = q_cond(pred,
                         true_fun,
                         false_fun,
                         qbl)

            return measure(qbl)

        print(main())
        # Yields:
        # True

    """
    qs = TracingQuantumSession.get_instance()

    # Both wrappers are called with exactly two arguments: the abstract
    # circuit and the tuple of user operands (see ``new_operands`` below).

    def new_true_fun(branch_qc, branch_operands):
        qs.start_tracing(branch_qc)
        res = true_fun(*branch_operands)
        abs_qc = qs.conclude_tracing()
        return (abs_qc, res)

    def new_false_fun(branch_qc, branch_operands):
        qs.start_tracing(branch_qc)
        res = false_fun(*branch_operands)
        abs_qc = qs.conclude_tracing()
        return (abs_qc, res)

    # Prepend the current abstract circuit to the operands.
    abs_qc = qs.abs_qc
    new_operands = (abs_qc, operands)

    cond_res = cond(pred, new_true_fun, new_false_fun, *new_operands)

    # Most recently recorded equation of the current trace — presumably the
    # cond equation emitted by the call above.
    # NOTE(review): this relies on JAX-internal state (``jax._src``) and must
    # stay directly after the ``cond`` call.
    eqn = jax._src.core.thread_local_state.trace_state.trace_stack.dynamic.jaxpr_stack[0].eqns[-1]

    # Index 0 is treated as the false branch and index 1 as the true branch.
    false_jaxpr = eqn.params["branches"][0].jaxpr
    true_jaxpr = eqn.params["branches"][1].jaxpr

    # The quantum circuit is expected to be the first input of the branch.
    if not isinstance(false_jaxpr.invars[0].aval, AbstractQuantumCircuit):
        raise Exception("Found implicit variable import in q_cond. Please make sure all used variables are part of the body signature.")

    from qrisp.jasp import Jaspr

    # Swap both branch jaxprs for their Jaspr counterparts, keeping the
    # consts and the branch order.
    eqn.params["branches"] = (
        jax.core.ClosedJaxpr(Jaspr.from_cache(false_jaxpr), eqn.params["branches"][0].consts),
        jax.core.ClosedJaxpr(Jaspr.from_cache(true_jaxpr), eqn.params["branches"][1].consts),
    )

    # Continue the session with the circuit returned by the conditional.
    qs.abs_qc = cond_res[0]
    return cond_res[1]