Jaspr

class Jaspr(*args, permeability=None, isqfree=None, ctrl_jaspr=None, **kwargs)

The Jaspr class enables an efficient representation of a wide variety of (hybrid) algorithms. For many applications, the representation is agnostic to the scale of the problem: function calls with 10 or 10000 qubits can be represented by the same object. The actual unfolding into a circuit-level description is delegated to established classical compilation infrastructure, so state-of-the-art compilation speed can be reached.

As a subtype of jax.core.Jaxpr, Jaspr objects are embedded into the mature Jax ecosystem. This facilitates the compilation of classical real-time computations with highly optimized toolchains and hardware backends such as CUDA. Machine learning and other scientific computing tasks are particularly well supported.

To get a better understanding of the syntax and semantics of Jaxpr (and with that also Jaspr), please consult the Jax documentation on Jaxprs.

Similar to Jaxpr, Jaspr objects represent (hybrid) quantum algorithms in the form of a functional programming language in static single-assignment (SSA) form.

It is possible to compile Jaspr objects into QIR, which is facilitated by the Catalyst framework (check qrisp.jasp.jaspr.to_qir() for more details).

Qrisp scripts can be turned into Jaspr objects by calling the make_jaspr function, which has semantics similar to jax.make_jaxpr.

from qrisp import *
from qrisp.jasp import make_jaspr

def test_fun(i):

    qv = QuantumFloat(i, -1)
    x(qv[0])
    cx(qv[0], qv[i-1])
    meas_res = measure(qv)
    meas_res += 1
    return meas_res

jaspr = make_jaspr(test_fun)(4)
print(jaspr)

This will give you the following output:

{ lambda ; a:QuantumCircuit b:i32[]. let
    c:QuantumCircuit d:QubitArray = create_qubits a b
    e:Qubit = get_qubit d 0
    f:QuantumCircuit = x c e
    g:i32[] = sub b 1
    h:Qubit = get_qubit d g
    i:QuantumCircuit = cx f e h
    j:QuantumCircuit k:i32[] = measure i d
    l:f32[] = convert_element_type[new_dtype=float32 weak_type=True] k
    m:f32[] = mul l 0.5
    n:f32[] = add m 1.0
  in (j, n) }

A defining feature of the Jaspr class is that the first input and the first output are always of QuantumCircuit type. Therefore, Jaspr objects always represent some (hybrid) quantum operation.

Qrisp comes with a built-in Jaspr interpreter. To use it, call the object like a function:

>>> print(jaspr(2))
2.5
>>> print(jaspr(4))
5.5
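test_fun starts from the all-zero state and applies only x and cx. Every qubit therefore stays in a computational basis state, so the interpreter results above can be reproduced by hand. The following plain-Python sketch does this without Qrisp. The function name trace_test_fun is hypothetical. The sketch assumes that the exponent -1 of QuantumFloat(i, -1) scales the measured integer by 2**-1, as the mul by 0.5 in the printed Jaspr suggests.

```python
def trace_test_fun(i: int) -> float:
    """Classically trace test_fun (valid because only x/cx act on |0...0>)."""
    bits = [0] * i             # QuantumFloat(i, -1) starts in |0...0>
    bits[0] ^= 1               # x(qv[0])
    bits[i - 1] ^= bits[0]     # cx(qv[0], qv[i-1]): flip target if control is 1
    # measure(qv): assume qubit k contributes 2**k to the raw integer
    raw = sum(b << k for k, b in enumerate(bits))
    value = raw * 0.5          # exponent -1 scales the integer by 2**-1
    return value + 1           # meas_res += 1

print(trace_test_fun(2))  # 2.5
print(trace_test_fun(4))  # 5.5
```

The last three steps correspond to the convert_element_type, mul and add equations of the printed Jaspr.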

Methods

Manipulation

Jaspr.inverse()

Returns the inverse Jaspr (if applicable).

Jaspr.control(num_ctrl[, ctrl_state])

Returns the controlled version of the Jaspr.
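The two manipulation methods have familiar meanings at the gate and matrix level. The following sketch illustrates them outside of Qrisp. Inverting a gate sequence reverses its order and replaces each gate by its adjoint. Controlling a unitary U on num_ctrl qubits applies U only on the basis states where the controls match ctrl_state. Several conventions here are assumptions made for this illustration, not Qrisp internals: the gate-tuple encoding, the ADJOINT table, the bitstring form of ctrl_state (all ones by default), and placing the controls as the most significant tensor factors.

```python
import numpy as np

# Hypothetical gate encoding: (gate_name, qubit_indices)
ADJOINT = {"t": "t_dg", "t_dg": "t", "s": "s_dg", "s_dg": "s",
           "x": "x", "h": "h", "cx": "cx"}

def invert_gates(gates):
    """Inverse of a gate sequence: reversed order, each gate replaced by its adjoint."""
    return [(ADJOINT[name], qubits) for name, qubits in reversed(gates)]

def controlled_matrix(u, num_ctrl, ctrl_state=None):
    """Matrix applying u iff the control qubits equal ctrl_state (controls first)."""
    if ctrl_state is None:
        ctrl_state = "1" * num_ctrl          # assumed default: all controls in |1>
    dim = u.shape[0]
    active = int(ctrl_state, 2)              # control pattern that activates u
    out = np.eye(2 ** num_ctrl * dim, dtype=complex)
    block = slice(active * dim, (active + 1) * dim)
    out[block, block] = u                    # identity on all other blocks
    return out

body = [("t", (0,)), ("cx", (0, 1))]
print(invert_gates(body))                    # [('cx', (0, 1)), ('t_dg', (0,))]

X = np.array([[0, 1], [1, 0]], dtype=complex)
cnot = controlled_matrix(X, 1)               # CNOT with the control as first qubit
anti_cnot = controlled_matrix(X, 1, "0")     # flips the target iff the control is |0>
```

Applying invert_gates twice returns the original sequence, as expected of an inverse.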

Evaluation

Jaspr.qjit(*args[, function_name])

Leverages the Catalyst pipeline to compile a QIR representation of this function and executes that function using the Catalyst QIR runtime.

Jaspr.to_qc(*args)

Converts the Jaspr into a QuantumCircuit if applicable.

Jaspr.to_qir()

Compiles the Jaspr to QIR using the Catalyst framework.

Jaspr.to_mlir()

Compiles the Jaspr to MLIR using the Catalyst dialect.

Jaspr.to_qasm(*args)

Compiles the Jaspr into an OpenQASM 2 string.

Jaspr.to_catalyst_jaxpr()

Compiles the Jaspr into the corresponding Catalyst jaxpr.

Advanced details

This section elaborates on how Jaspr objects are embedded into the Jax infrastructure. If you just want to accelerate your code, you can (probably) skip it. It is recommended to first get a solid understanding of Jax primitives and how to create a Jaxpr out of them.

jasp is designed to model dynamic quantum computations with a minimal set of primitives.

For this purpose, three new Jax data types are defined:

  • QuantumCircuit, which represents an object that tracks which manipulations are applied to the quantum state.

  • QubitArray, which represents an array of qubits whose size can be dynamic.

  • Qubit, which represents an individual qubit.

Before we describe how quantum computations are realized, we list some “administrative” primitives and their semantics:

  • create_qubits: Can be used to create new qubits. Takes a QuantumCircuit and a (dynamic) integer and returns a new QuantumCircuit and a QubitArray.

  • get_qubit: Extracts a Qubit from a QubitArray. Takes a QubitArray and a dynamic integer (indicating the position) and returns a Qubit.

  • get_size: Retrieves the size of a QubitArray. Takes a QubitArray and returns an integer (the size).
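To make the functional, SSA-style threading of the QuantumCircuit value concrete, here is a toy model of the three administrative primitives in plain Python. All names and data structures (ToyCircuit, tuples standing in for QubitArray, apply_gate) are hypothetical. They only mimic the signatures listed above and are not Qrisp's implementation.

```python
from typing import NamedTuple, Tuple

class ToyCircuit(NamedTuple):
    """Immutable stand-in for QuantumCircuit: qubit count plus recorded ops."""
    num_qubits: int
    ops: Tuple[tuple, ...]

def create_qubits(qc: ToyCircuit, n: int):
    """Return a NEW circuit and a 'QubitArray' (tuple of fresh qubit ids)."""
    qubit_array = tuple(range(qc.num_qubits, qc.num_qubits + n))
    return ToyCircuit(qc.num_qubits + n, qc.ops), qubit_array

def get_qubit(qubit_array, index: int):
    """Extract a single 'Qubit' (here: an integer id) from a 'QubitArray'."""
    return qubit_array[index]

def get_size(qubit_array) -> int:
    """Size of a 'QubitArray'."""
    return len(qubit_array)

def apply_gate(qc: ToyCircuit, name: str, *qubits):
    """Gates never mutate: they return a new circuit with the op appended."""
    return ToyCircuit(qc.num_qubits, qc.ops + ((name, qubits),))

# Variable names mirror the first equations of the earlier Jaspr printout
a = ToyCircuit(0, ())
c, d = create_qubits(a, 4)
e = get_qubit(d, 0)
f = apply_gate(c, "x", e)
print(get_size(d), f.ops)  # 4 (('x', (0,)),)
```

Because every primitive returns a fresh circuit value instead of mutating its input, a and c remain unchanged. This is the property that lets the circuit be threaded through an SSA program.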

To express quantum operations, the Operation class is elevated to a Jax primitive:

from qrisp import *
from qrisp.jasp import *

def test_function(i):
    qv = QuantumVariable(i)
    cx(qv[0], qv[1])
    bl = measure(qv[1])
    return qv, bl

print(make_jaspr(test_function)(2))
{ lambda ; a:QuantumCircuit b:i32[]. let
    c:QuantumCircuit d:QubitArray = create_qubits a b
    e:Qubit = get_qubit d 0
    f:Qubit = get_qubit d 1
    g:QuantumCircuit = cx c e f
    h:QuantumCircuit i:bool[] = measure g f
  in (h, d, i) }

The line starting with g: shows how an Operation is plugged into a Jaspr. The first argument is always a QuantumCircuit, the following arguments are Qubit objects, and the result is a new QuantumCircuit. With this structure, jasp stays very close to how quantum computations are modelled mathematically, namely as a unitary that is applied to a tensor on certain indices. Indeed, you can view the defined objects as precisely that (if it helps your programming or understanding): QuantumCircuit objects represent tensors, Qubit objects represent integer indices, and QubitArray objects represent arrays of indices.
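The tensor reading can be made concrete with NumPy. Store an n-qubit state as a tensor with n axes of size 2, one per qubit index. A gate is then applied by contracting its reshaped matrix with the axes named by its Qubit operands. The helper apply_on_indices is a hypothetical illustration of this viewpoint, not a description of how Qrisp simulates circuits.

```python
import numpy as np

def apply_on_indices(state, gate, qubits):
    """Apply a k-qubit gate (2**k x 2**k matrix) to the given tensor axes."""
    k = len(qubits)
    # Reshape to (out_1..out_k, in_1..in_k); the first listed qubit is the
    # most significant index of the matrix, so for CX it is the control.
    gate_t = gate.reshape((2,) * (2 * k))
    # Contract the gate's input axes with the state's qubit axes
    out = np.tensordot(gate_t, state, axes=(list(range(k, 2 * k)), list(qubits)))
    # tensordot puts the gate's output axes first; move them back into place
    return np.moveaxis(out, list(range(k)), list(qubits))

n = 3
state = np.zeros((2,) * n, dtype=complex)
state[(0,) * n] = 1.0                            # |000>, one axis per qubit index

X = np.array([[0, 1], [1, 0]], dtype=complex)
CX = np.array([[1, 0, 0, 0], [0, 1, 0, 0],
               [0, 0, 0, 1], [0, 0, 1, 0]], dtype=complex)

state = apply_on_indices(state, X, [0])          # x on index 0
state = apply_on_indices(state, CX, [0, 1])      # cx on indices (0, 1)
print(np.argwhere(np.abs(state) > 1e-12))        # [[1 1 0]]
```

Each call consumes one "tensor" and returns a new one, which parallels how every quantum primitive in a Jaspr consumes a QuantumCircuit and returns a new one.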

The measure primitive plays a special role here: unlike the other quantum operations, it returns not only a new QuantumCircuit but also a boolean value (the measurement outcome). It is also possible to call measure on a QubitArray:

def test_function(i):
    qv = QuantumVariable(i)
    cx(qv[0], qv[1])
    a = measure(qv)
    return a

print(make_jaspr(test_function)(2))
{ lambda ; a:QuantumCircuit b:i32[]. let
    c:QuantumCircuit d:QubitArray = create_qubits a b
    e:Qubit = get_qubit d 0
    f:Qubit = get_qubit d 1
    g:QuantumCircuit = cx c e f
    h:QuantumCircuit i:i32[] = measure g d
  in (h, i) }

In this case, an integer is returned instead of a boolean value. Both variants return values (bool/int32) that other Jax code can process directly, highlighting the seamless embedding of quantum computations into the Jax ecosystem.
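Measuring a QubitArray can be read as measuring each qubit and packing the boolean outcomes into a single integer. The sketch below shows this packing and its inverse in plain Python. The bit order (qubit k contributes 2**k) is an assumption made for this illustration, and the helper names are hypothetical.

```python
def pack_outcomes(bits):
    """Combine per-qubit boolean outcomes into one integer (qubit k -> 2**k)."""
    return sum(int(b) << k for k, b in enumerate(bits))

def unpack_outcome(value, size):
    """Recover the per-qubit booleans from a packed integer."""
    return [bool((value >> k) & 1) for k in range(size)]

outcomes = [True, False, True]           # hypothetical outcomes of three qubits
packed = pack_outcomes(outcomes)
print(packed)                            # 5
print(unpack_outcome(packed, 3))         # [True, False, True]
```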

QuantumEnvironments

QuantumEnvironments in jasp are also represented by a dedicated primitive (q_env):

def test_function(i):
    qv = QuantumVariable(i)

    with invert():
        t(qv[0])
        cx(qv[0], qv[1])

    return qv

jaspr = make_jaspr(test_function)(2)
print(jaspr)
{ lambda ; a:QuantumCircuit b:i32[]. let
    c:QuantumCircuit d:QubitArray = create_qubits a b
    e:QuantumCircuit = q_env[
    jaspr={ lambda ; f:QuantumCircuit d:QubitArray. let
        g:Qubit = get_qubit d 0
        h:QuantumCircuit = t f g
        i:Qubit = get_qubit d 1
        j:QuantumCircuit = cx h g i
        in (j,) }
    type=InversionEnvironment
    ] c d
  in (e, d) }

You can see how the body of the InversionEnvironment is collected into another Jaspr. This reflects the fact that, at their core, QuantumEnvironments describe higher-order quantum functions (i.e. functions that operate on functions). To apply the transformations induced by the QuantumEnvironment, we can call Jaspr.flatten_environments():

>>> print(jaspr.flatten_environments())
{ lambda ; a:QuantumCircuit b:i32[]. let
    c:QuantumCircuit d:QubitArray = create_qubits a b
    e:Qubit = get_qubit d 0
    f:Qubit = get_qubit d 1
    g:QuantumCircuit = cx c e f
    h:QuantumCircuit = t_dg g e
  in (h, d) }

As expected, the order of the cx and t gates has been reversed, and the t gate has been replaced by its adjoint t_dg.
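This result can be checked numerically. The body applies t and then cx, so its unitary is U = CX·(T⊗I). Its adjoint is U† = (T†⊗I)·CX, because CX is its own inverse. In other words, U† applies cx first and then t_dg. The NumPy sketch below verifies this, using qubit 0 as the left tensor factor (a convention chosen for this illustration).

```python
import numpy as np

T = np.diag([1, np.exp(1j * np.pi / 4)])        # t gate
T_dg = T.conj().T                                # t_dg gate
I = np.eye(2)
CX = np.array([[1, 0, 0, 0], [0, 1, 0, 0],
               [0, 0, 0, 1], [0, 0, 1, 0]], dtype=complex)

# Body of the InversionEnvironment: t on qubit 0, then cx(0, 1).
# Operators compose right to left, so the later gate stands on the left.
U = CX @ np.kron(T, I)

# Flattened result: cx(0, 1) first, then t_dg on qubit 0
U_flat = np.kron(T_dg, I) @ CX

print(np.allclose(U_flat, U.conj().T))           # True
print(np.allclose(U_flat @ U, np.eye(4)))        # True
```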