Sampling
- sample(state_prep=None, shots=0, post_processor=None)
The sample function takes samples from a state that is specified by a preparation procedure. This preparation procedure is supplied as a Python function that returns one or more QuantumVariables. sample returns a classical function that, when called, produces the samples as a JAX array whose shape is determined by the shots parameter (one entry per shot). Because of this, shots must be a static integer; dynamic values are not allowed. If you want to sample with a dynamic number of shots, look into Expectation Value.
- Parameters:
- state_prep : callable
A function returning one or more QuantumVariables. The state of these QuantumVariables is sampled. The state preparation function can only take classical values as arguments, because a quantum argument would have to be copied for each sampling iteration, which the no-cloning theorem prohibits.
- shots : int
The number of samples to take.
- post_processor : callable, optional
A function to apply to the samples directly after measuring. By default, no post-processing is applied.
- Returns:
- callable
A classical, JAX-traceable function returning a JAX array that contains the measurement results of each shot.
- Raises:
- Exception
Tried to sample with dynamic shots value (static integer required)
- Exception
Tried to sample from state preparation function taking a quantum value
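To make this contract concrete, here is a purely classical toy model in plain Python. It is a sketch under stated assumptions, not Qrisp's implementation: toy_sample and draw_psi are hypothetical names, the outcome distribution is hard-coded instead of coming from a quantum state, and applying post_processor shot by shot is an assumption of the model. It illustrates two points from the description above: the number of results is fixed by shots (in JAX, an array's shape must be known at trace time, which is why shots must be static), and the post-processor receives one argument per returned QuantumVariable.

```python
import random


def toy_sample(draw_outcome, shots, post_processor=None, seed=None):
    """Hypothetical classical toy model of sample's contract (not part of Qrisp).

    draw_outcome:   callable taking a random.Random and returning one outcome,
                    a tuple with one entry per QuantumVariable.
    shots:          must be a plain Python int, because it fixes the result length.
    post_processor: optional callable applied to each outcome (assumed per shot).
    """
    # Mimic the static-shots requirement: anything but a plain int is rejected.
    if not isinstance(shots, int) or shots < 0:
        raise TypeError("shots must be a static, non-negative integer")
    rng = random.Random(seed)
    results = []
    for _ in range(shots):
        outcome = draw_outcome(rng)  # one "measurement" per shot
        if post_processor is not None:
            # One argument per QuantumVariable, as in post_processor(x, y)
            outcome = post_processor(*outcome)
        results.append(outcome)
    return results  # exactly `shots` entries


def draw_psi(rng, k=3):
    """Hypothetical distribution: (0, 0) or (k, k), each with probability 1/2."""
    return (k, k) if rng.random() < 0.5 else (0, 0)


print(toy_sample(draw_psi, 5, seed=0))
print(toy_sample(draw_psi, 5, post_processor=lambda x, y: 2 * x + y // 2, seed=0))
```

Both calls use the same seed, so they draw identical outcomes and the second list is simply the first one passed through the post-processor.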
Examples
We prepare the state
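\[\ket{\psi} = \frac{1}{\sqrt{2}} \left(\ket{0}\ket{0}\ket{\text{False}} + \ket{k}\ket{k}\ket{\text{True}}\right)\]

where the three factors refer to a, b and qbl, respectively:

from qrisp import *
from qrisp.jasp import *

def state_prep(k):
    a = QuantumFloat(4)
    b = QuantumFloat(4)
    qbl = QuantumBool()
    # Put the control qubit into an equal superposition of False and True
    h(qbl)
    # Only in the True branch, encode k into a
    with control(qbl[0]):
        a[:] = k
    # Apply CNOTs from a onto b, entangling b with a
    cx(a, b)
    return a, b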
We then sample from the two QuantumFloats:
@jaspify
def main(k):
    sampling_function = sample(state_prep, shots=10)
    return sampling_function(k)

print(main(3))
# Yields
# [[3. 3.]
#  [0. 0.]
#  [0. 0.]
#  [3. 3.]
#  [0. 0.]
#  [0. 0.]
#  [3. 3.]
#  [3. 3.]
#  [0. 0.]
#  [0. 0.]]
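Each row of this output is one shot. Measuring a and b in the state prepared above returns each branch with probability equal to its squared amplitude (here k = 3, so the two branches give distinct outcomes):

\[P(a=0,\, b=0) = \left|\frac{1}{\sqrt{2}}\right|^2 = \frac{1}{2}, \qquad P(a=k,\, b=k) = \left|\frac{1}{\sqrt{2}}\right|^2 = \frac{1}{2}\]

With 10 shots, the expected number of [3. 3.] rows is therefore 10 · 1/2 = 5. The run shown contains four such rows, which is an ordinary statistical fluctuation; a different run will generally give a different split.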
To demonstrate the post-processing feature, we write a simple post-processing function:
def post_processor(x, y):
    return 2*x + y//2

@jaspify
def main(k):
    sampling_function = sample(state_prep, shots=10, post_processor=post_processor)
    return sampling_function(k)

print(main(4))
# Yields
# [10. 10. 0. 0. 0. 0. 0. 0. 10. 10.]
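Since the state only produces the outcomes (0, 0) and (k, k), the post-processed samples can take only two values. The following plain-Python check of the arithmetic (no Qrisp involved) reproduces the two values seen in the output above for k = 4. The printed samples are floats (10.), and the same arithmetic holds for float inputs.

```python
def post_processor(x, y):
    return 2*x + y//2

# Possible measurement outcomes of the example state for k = 4: (0, 0) and (4, 4)
print(post_processor(4, 4))      # 2*4 + 4//2 = 10
print(post_processor(0, 0))      # 0
# Floor division also works on floats, giving the float value seen in the output
print(post_processor(4.0, 4.0))  # 10.0
```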