
sample(state_prep=None, shots=0, post_processor=None)

The sample function takes samples from a state that is specified by a preparation procedure. This procedure is supplied as a Python function that returns one or more QuantumVariables.

The samples are returned as a Jax array whose shape is determined by the shots parameter. Since JAX requires array shapes to be known at tracing time, shots must be a static integer; dynamic (traced) values are not supported. To sample with a dynamic number of shots, see Expectation Value.
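To see why shots has to be static, consider plain JAX (no Qrisp involved): an array shape cannot depend on a traced value. In the sketch below, `fake_samples` and `zeros_with_traced_length` are hypothetical helpers. The first marks its shape argument as static via `static_argnames` and compiles; the second receives a traced length and is rejected with a `TypeError`.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for a sampler: the output shape depends on `shots`,
# so `shots` must be marked static for jax.jit.
@partial(jax.jit, static_argnames="shots")
def fake_samples(key, shots):
    # One random 4-bit integer per shot; shape (shots,)
    return jax.random.randint(key, (shots,), 0, 16)


def zeros_with_traced_length(n):
    # Under jax.jit, `n` is a tracer, so it cannot define an array shape.
    return jnp.zeros(n)


key = jax.random.PRNGKey(0)
samples = fake_samples(key, shots=10)
print(samples.shape)  # (10,)

try:
    jax.jit(zeros_with_traced_length)(10)
    dynamic_rejected = False
except TypeError:
    dynamic_rejected = True
print(dynamic_rejected)  # True
```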


Parameters:

state_prep : callable

A function returning one or more QuantumVariables; the state of these QuantumVariables is sampled. The state preparation function may only take classical values as arguments: a quantum argument would have to be copied for every sampling iteration, which the no-cloning theorem forbids.


shots : int

The number of samples to take. Must be a static integer.

post_processor : callable, optional

A function applied to the samples directly after measurement. It receives one argument per QuantumVariable returned by state_prep (see the example below). By default, no post-processing is applied.


Returns:

callable

A classical, Jax-traceable function that returns a Jax array containing the measurement results of each shot.


Raises:

An error if shots is a dynamic value (a static integer is required).


An error if the state preparation function takes a quantum value as an argument.
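The documented behavior can be summarized with a purely classical model. This sketch is an illustration under stated assumptions, not Qrisp's implementation: `toy_sample` and `toy_state_prep` are hypothetical names, `toy_state_prep` directly returns one measurement outcome of the state from the example below instead of preparing a quantum state, and the post processor is assumed to act elementwise with one argument per returned variable. It shows the returned sampling function, the re-preparation for every shot, the resulting array shapes, and the static-shots check.

```python
import numpy as np

rng = np.random.default_rng(0)


def toy_state_prep(k):
    # Hypothetical classical stand-in: returns one measurement outcome of
    # (|0>|0>|False> + |k>|k>|True>)/sqrt(2), i.e. (0, 0) or (k, k).
    return (k, k) if rng.random() < 0.5 else (0, 0)


def toy_sample(state_prep, shots, post_processor=None):
    """Hypothetical classical model of sample(); not Qrisp's implementation."""
    if not isinstance(shots, int):
        # Mirrors the requirement that shots is a static integer
        raise TypeError("shots must be a static integer")

    def sampling_function(*args):
        # Re-run the preparation for every shot instead of copying a state
        rows = [state_prep(*args) for _ in range(shots)]
        samples = np.array(rows, dtype=float)  # shape (shots, n_vars)
        if post_processor is not None:
            # Assumed: one argument per returned variable, applied elementwise
            samples = post_processor(*samples.T)  # shape (shots,)
        return samples

    return sampling_function


raw = toy_sample(toy_state_prep, shots=10)(3)
print(raw.shape)  # (10, 2)

processed = toy_sample(toy_state_prep, shots=10,
                       post_processor=lambda x, y: 2*x + y//2)(3)
print(processed.shape)  # (10,)
```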


Examples:

We prepare the state

\[\ket{\psi} = \frac{1}{\sqrt{2}} \left(\ket{0}\ket{0}\ket{\text{False}} + \ket{k}\ket{k}\ket{\text{True}}\right)\]
from qrisp import *
from qrisp.jasp import *

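def state_prep(k):
    # Two 4-qubit QuantumFloats, both starting in |0>
    a = QuantumFloat(4)
    b = QuantumFloat(4)

    # Put the QuantumBool into an equal superposition of False and True
    qbl = QuantumBool()
    h(qbl)

    # Encode k into a only on the True branch
    with control(qbl[0]):
        a[:] = k

    # Copy the computational-basis value of a into b (qubit-wise CNOT)
    cx(a, b)

    return a, b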

We then sample from the QuantumFloats:

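@jaspify
def main(k):

    # Build a sampling function that draws 10 shots from state_prep
    sampling_function = sample(state_prep,
                               shots = 10)

    # Call it with the classical argument k
    return sampling_function(k)

print(main(3))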


# Yields
# [[3. 3.]
#  [0. 0.]
#  [0. 0.]
#  [3. 3.]
#  [0. 0.]
#  [0. 0.]
#  [3. 3.]
#  [3. 3.]
#  [0. 0.]
#  [0. 0.]]
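The output pattern can be read off the prepared state. Measuring gives each branch with probability equal to its squared amplitude, and since only a and b are returned, every shot yields either (0, 0) or (k, k):

```latex
P\big((a,b) = (0,0)\big) = \left|\tfrac{1}{\sqrt{2}}\right|^2 = \tfrac{1}{2},
\qquad
P\big((a,b) = (k,k)\big) = \left|\tfrac{1}{\sqrt{2}}\right|^2 = \tfrac{1}{2}
```

The output above is consistent with k = 3; over only 10 shots the two outcomes need not appear equally often (here 4 and 6 times).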

To demonstrate the post processing feature, we write a simple post processing function:

def post_processor(x, y):
    return 2*x + y//2

@jaspify
def main(k):

    # Apply post_processor to the measured values of a and b in every shot
    sampling_function = sample(state_prep,
                               shots = 10,
                               post_processor = post_processor)

    return sampling_function(k)

print(main(4))

# Yields
# [10. 10.  0.  0.  0.  0.  0.  0. 10. 10.]
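The post-processed output can be checked by hand. Because every shot yields either (0, 0) or (k, k), only two values can appear. The plain-Python check below (no Qrisp needed) evaluates both branches and shows that the value 10 corresponds to calling main with k = 4.

```python
def post_processor(x, y):
    # Same post processor as in the example above
    return 2*x + y//2


# The (0, 0) branch always maps to 0
print(post_processor(0.0, 0.0))  # 0.0

# Which 4-bit values of k map the (k, k) branch to 10?
print([k for k in range(16) if post_processor(k, k) == 10])  # [4]
```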