extract_stim

extract_stim(func=None, *, detector_order='chronological')

Decorator that extracts a Stim circuit from a Jasp-traceable function.

This decorator enables high-performance Clifford circuit simulation by converting Jasp-traceable Qrisp functions into Stim circuits. It handles the translation of quantum measurements, detectors, observables, and qubit references into Stim’s internal indices, allowing you to map simulation results back to the original function’s return values.

Return values are wrapped in typed numpy array subtypes for easy identification:

  • StimMeasurementHandles - measurement record indices (for slicing sampler results)

  • StimDetectorHandles - detector indices (for slicing detector sampler results)

  • StimObservableHandles - observable indices (for slicing observable results)

  • StimQubitIndices - qubit indices (for identifying which Stim qubits correspond to a QuantumVariable)

These typed arrays behave exactly like regular numpy arrays and can be used directly for slicing sample arrays, while also carrying type information about what they represent.
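As a rough illustration of how an array subtype can carry type information and still work as an index, consider the minimal sketch below. ExampleMeasurementHandles is a hypothetical stand-in defined only for this illustration, not Qrisp's actual class, and the sample array is random data rather than real Stim output.

```python
import numpy as np

# Hypothetical stand-in for a typed handle array (not Qrisp's implementation).
class ExampleMeasurementHandles(np.ndarray):
    def __new__(cls, indices):
        # View an integer array as the subclass; the data itself is unchanged.
        return np.asarray(indices, dtype=np.int64).view(cls)

handles = ExampleMeasurementHandles([1, 3])

# Fake sampler output: 4 shots with 5 measurement records each.
rng = np.random.default_rng(seed=0)
samples = rng.integers(0, 2, size=(4, 5)).astype(bool)

# The handles select columns like any integer index array.
selected = samples[:, handles]

print(type(handles).__name__)  # ExampleMeasurementHandles
print(selected.shape)          # (4, 2)
```

The subclass adds no behavior of its own; the type name alone is enough for isinstance checks to tell measurement, detector, observable, and qubit indices apart.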

Note

Stim only supports Clifford operations. Functions containing non-Clifford gates (e.g., T, RZ, RX) will raise an error during conversion.

Warning

Measurement post-processing limitation: Advanced quantum types like QuantumFloat apply post-processing to raw measurement results during decoding. For example, a QuantumFloat might convert a raw integer into a fractional value. This post-processing cannot be performed during Stim extraction because it is a classical computation on the measured bits, and the Stim circuit format cannot represent such computations.

For this reason, it is recommended to use QuantumVariable instead of QuantumFloat (or similar advanced types) when working with extract_stim. QuantumVariable’s decoder returns raw integer values without post-processing, making it fully compatible with Stim’s measurement record. You can then apply any necessary transformations manually after sampling. Values that have undergone such post-processing are represented by the ProcessedMeasurement class, which serves only as a placeholder.

See the Jaspr.to_qc documentation for more details on this limitation.
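A manual transformation applied after sampling might look like the following sketch. The sample bits are hard-coded instead of coming from Stim, the bit order is assumed to be little-endian (as in Example 3), and the exponent of -2 is a hypothetical value chosen to mimic a fixed-point number with two fractional bits.

```python
import numpy as np

# Assumed raw samples for a 3-qubit variable: one row per shot,
# least significant bit first.
raw_bits = np.array([[1, 0, 0],
                     [0, 1, 0],
                     [1, 1, 1]], dtype=bool)

# Step 1: bits -> raw integers (what a plain QuantumVariable decoder yields).
weights = 1 << np.arange(raw_bits.shape[1])   # [1, 2, 4]
raw_ints = raw_bits.dot(weights)              # [1, 2, 7]

# Step 2: manual post-processing, here a fixed-point scaling.
exponent = -2  # hypothetical exponent, chosen for illustration only
values = raw_ints * 2.0 ** exponent

print(raw_ints.tolist())  # [1, 2, 7]
print(values.tolist())    # [0.25, 0.5, 1.75]
```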

Parameters:
func : callable

A Jasp-traceable function that manipulates quantum variables and returns quantum measurement results, parity checks (detectors/observables), QuantumVariables (for qubit indices), and/or classical values.

detector_order : str, optional

Specifies the ordering of detectors in the returned Stim circuit.

  • "chronological" (default): Detectors appear in the circuit in the order in which they are created during execution of the function.

  • "return_order": Reorders detectors based on the function’s return values. This analyzes all StimDetectorHandles in the return values, flattens and concatenates them to form a permutation, then applies permute_detectors() to reorder the circuit accordingly. The detector handle values are also adjusted to reflect the new ordering.
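The index bookkeeping behind "return_order" can be pictured with plain numpy. This is only one reading of the description above, not Qrisp's implementation: the handle values are invented, and permute_detectors() itself is not called.

```python
import numpy as np

# Assumed handles in chronological numbering: the function returned
# (d_a, d_b), where d_a refers to detectors 2 and 0 and d_b to detector 1.
d_a = np.array([2, 0])
d_b = np.array([1])

# Flatten and concatenate in return order to obtain the permutation.
permutation = np.concatenate([d_a.ravel(), d_b.ravel()])  # [2, 0, 1]

# new_position[old_index] is where each detector lands after reordering.
new_position = np.empty_like(permutation)
new_position[permutation] = np.arange(permutation.size)

# Adjusted handles count upwards in return order.
d_a_new = new_position[d_a]  # [0, 1]
d_b_new = new_position[d_b]  # [2]

# One shot of chronological detector samples, reordered column-wise.
chronological = np.array([[True, False, False]])
reordered = chronological[:, permutation]

# Old handles on old samples agree with new handles on reordered samples.
print(np.array_equal(chronological[:, d_a], reordered[:, d_a_new]))  # True
```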

Returns:
callable

A decorated function that returns:

  • No return value: If func returns nothing, returns just the stim.Circuit object.

  • With return values: If func returns n values, returns a tuple of (n+1) elements:

    • Elements 0 to n-1: The function’s return values, where:

      • Classical values (integers, floats, etc.) are returned as-is.

      • Quantum measurements are returned as StimMeasurementHandles arrays containing measurement record indices. These can be used directly to slice the results from stim_circuit.compile_sampler().sample().

      • Parity detectors (from parity(..., expectation=0/1)) are returned as StimDetectorHandles arrays containing detector indices.

      • Parity observables (from parity(..., observable=True)) are returned as StimObservableHandles arrays containing observable indices.

      • QuantumVariables (unmeasured) are returned as StimQubitIndices arrays containing the Stim qubit indices for that variable.

    • Element n: The stim.Circuit object.
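Because the return shape depends on whether func returns anything, calling code that handles both cases may want to normalize it. The helper below is a hypothetical convenience, not part of Qrisp, and uses placeholder strings where real handle arrays and stim.Circuit objects would appear. It assumes a bare stim.Circuit is never a tuple.

```python
def split_extract_result(result):
    """Return (values, circuit) for either return shape described above."""
    if isinstance(result, tuple):
        # n return values followed by the circuit as the last element.
        *values, circuit = result
        return tuple(values), circuit
    # No return values: the result is the circuit itself.
    return (), result

# Placeholders stand in for handle arrays and the stim.Circuit object.
values, circuit = split_extract_result((6, "handles", "CIRCUIT"))
print(values, circuit)  # (6, 'handles') CIRCUIT

values_only_circuit, circuit_only = split_extract_result("CIRCUIT")
print(values_only_circuit, circuit_only)  # () CIRCUIT
```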

Examples

Example 1: No return value

When the function has no return value, only the Stim circuit is returned:

from qrisp import QuantumVariable, h, cx, measure
from qrisp.jasp import extract_stim

@extract_stim
def bell_state():
    qv = QuantumVariable(2)
    h(qv[0])
    cx(qv[0], qv[1])
    measure(qv)

stim_circuit = bell_state()
print(stim_circuit)
# Yields:
# H 0
# CX 0 1
# M 0 1

Example 2: Multiple return values with measurement indices

When returning one or more values, quantum measurements are returned as StimMeasurementHandles arrays, while classical values remain unchanged:

from qrisp import QuantumVariable, h, cx, measure
from qrisp.jasp import extract_stim

@extract_stim
def analyze_state(n):
    qf = QuantumVariable(n)
    h(qf)

    # Mid-circuit measurement
    first_qubit_result = measure(qf[0])

    # Classical computation
    classical_value = n * 2

    # Final measurement
    final_result = measure(qf)

    return classical_value, first_qubit_result, final_result

classical_val, first_meas_idx, final_meas_indices, stim_circuit = analyze_state(3)

print(f"Classical value: {classical_val}")  # 6 (unchanged)
print(f"First qubit measurement index: {first_meas_idx}")  # StimMeasurementHandles([0])
print(f"Final measurement indices: {final_meas_indices}")  # StimMeasurementHandles([1, 2, 3])
print(f"Type: {type(final_meas_indices)}")  # <class 'StimMeasurementHandles'>

Example 3: Sampling and slicing results

Use the measurement handles (which are numpy arrays) to slice results from Stim’s samples:

from qrisp import QuantumVariable, h, cx, measure
from qrisp.jasp import extract_stim

@extract_stim
def prepare_entangled_state():
    qf1 = QuantumVariable(2)
    qf2 = QuantumVariable(3)

    # Prepare qf1 in superposition
    h(qf1)

    # Entangle qf2 with qf1[0]
    for i in range(3):
        cx(qf1[0], qf2[i])

    result1 = measure(qf1)
    result2 = measure(qf2)

    return result1, result2

# Extract the circuit and measurement indices
qf1_indices, qf2_indices, stim_circuit = prepare_entangled_state()

print(f"qf1 measured at positions: {qf1_indices}")  # StimMeasurementHandles([0, 1])
print(f"qf2 measured at positions: {qf2_indices}")  # StimMeasurementHandles([2, 3, 4])

# Sample 1000 shots from the Stim circuit
sampler = stim_circuit.compile_sampler()
all_samples = sampler.sample(1000)  # Shape: (1000, 5) - 5 total measurements

# Slice the samples using the handle arrays directly (they are numpy arrays)
qf1_samples = all_samples[:, qf1_indices]  # Shape: (1000, 2)
qf2_samples = all_samples[:, qf2_indices]  # Shape: (1000, 3)

# Convert bit arrays to integers (little-endian)
import numpy as np
qf1_values = qf1_samples.dot(1 << np.arange(qf1_samples.shape[1]))
qf2_values = qf2_samples.dot(1 << np.arange(qf2_samples.shape[1]))

print(f"First 10 qf1 values: {qf1_values[:10]}")
print(f"First 10 qf2 values: {qf2_values[:10]}")

# Verify entanglement: when qf1[0]=0, all qf2 bits should be 0
qf1_first_bit = qf1_samples[:, 0]
assert np.all(qf2_samples[qf1_first_bit == 0] == 0)
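Once the bit arrays have been converted to integers, tallying the outcomes is a natural next step. The sketch below does not run Stim: it fabricates samples with the same correlation structure as Example 3 (every qf2 bit copies qf1[0]), so only the outcomes 0 and 7 can occur for qf2.

```python
import numpy as np

# Fabricated samples mimicking Example 3: qf1 is uniformly random,
# and all three qf2 bits are copies of qf1[0].
rng = np.random.default_rng(seed=1)
qf1_samples = rng.integers(0, 2, size=(1000, 2)).astype(bool)
qf2_samples = np.repeat(qf1_samples[:, :1], 3, axis=1)

# Little-endian conversion, as in the example above.
qf2_values = qf2_samples.dot(1 << np.arange(qf2_samples.shape[1]))

# Count how often each integer outcome occurred.
outcomes, counts = np.unique(qf2_values, return_counts=True)
for outcome, count in zip(outcomes.tolist(), counts.tolist()):
    print(f"qf2 = {outcome}: {count} shots")
```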

Example 4: Using parity checks (Detectors)

You can use the parity() function to define parity checks within your circuit. When extracted to Stim, these are converted into DETECTOR instructions and returned as StimDetectorHandles arrays:

from qrisp import QuantumVariable, h, cx, measure
from qrisp.jasp import extract_stim, parity
from qrisp.misc.stim_tools import stim_noise
import stim

@extract_stim
def selective_noise_demo():
    # Create two QuantumVariables for independent Bell pairs
    bell_pair_1 = QuantumVariable(2)
    bell_pair_2 = QuantumVariable(2)

    h(bell_pair_1[0]); cx(bell_pair_1[0], bell_pair_1[1])
    h(bell_pair_2[0]); cx(bell_pair_2[0], bell_pair_2[1])

    # Apply deterministic X error to one of the qubits in the second pair
    stim_noise("X_ERROR", 1.0, bell_pair_2[0])

    m1_0 = measure(bell_pair_1[0]); m1_1 = measure(bell_pair_1[1])
    m2_0 = measure(bell_pair_2[0]); m2_1 = measure(bell_pair_2[1])

    # Detector 1: expectation=0 implies we expect even parity
    d1 = parity(m1_0, m1_1, expectation=0)

    # Detector 2: Checks parity of second, noisy pair
    d2 = parity(m2_0, m2_1, expectation=0)

    return d1, d2

d1, d2, stim_circuit = selective_noise_demo()

print(f"Detector 1 index: {d1}")  # StimDetectorHandles([0])
print(f"Detector 2 index: {d2}")  # StimDetectorHandles([1])

sampler = stim_circuit.compile_detector_sampler()
detector_samples = sampler.sample(1)

# Slice detector results using the handles
d1_result = detector_samples[:, d1]
d2_result = detector_samples[:, d2]
print(f"D1: {d1_result}, D2: {d2_result}")  # D1: [[False]], D2: [[True]] (error in pair 2)
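To make the meaning of a detector sample concrete, the following sketch recomputes a parity check by hand with numpy. It is a simplified model, not how Stim evaluates detectors internally: it assumes the detector fires exactly when the parity of the listed measurements differs from the stated expectation, and the measurement rows are hard-coded.

```python
import numpy as np

# Hard-coded measurement pairs, one row per shot: (m0, m1).
pair_measurements = np.array([[0, 0],
                              [1, 1],
                              [1, 0],
                              [0, 1]], dtype=bool)

expectation = 0  # even parity expected, as in parity(..., expectation=0)

# Parity of each row: True when an odd number of bits are set.
parity_bits = np.logical_xor.reduce(pair_measurements, axis=1)

# The detector fires when the observed parity deviates from the expectation.
detector_fired = parity_bits ^ bool(expectation)

print(detector_fired.tolist())  # [False, False, True, True]
```

In Example 4 the X error flips one qubit of the second Bell pair, so its two measurements always disagree, which corresponds to the odd-parity rows here.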

Example 5: Defining Observables

Similarly, parity() with observable=True defines logical observables in Stim, returned as StimObservableHandles arrays:

from qrisp import QuantumVariable, h, measure
from qrisp.jasp import extract_stim, parity

@extract_stim
def observable_demo():
    qv = QuantumVariable(2)
    h(qv)
    m0 = measure(qv[0]); m1 = measure(qv[1])

    # Define an observable O = Z_0 Z_1
    logical_obs = parity(m0, m1, observable=True)
    return logical_obs

obs_idx, stim_circuit = observable_demo()
print(f"Observable index: {obs_idx}")  # StimObservableHandles([0])
# stim_circuit contains OBSERVABLE_INCLUDE(0) ...

Example 6: Returning qubit indices

When returning an unmeasured QuantumVariable, you get StimQubitIndices arrays containing the Stim qubit indices for that variable:

from qrisp import QuantumVariable, h, cx
from qrisp.jasp import extract_stim

@extract_stim
def qubit_index_demo():
    qv1 = QuantumVariable(2)
    qv2 = QuantumVariable(3)
    h(qv1)
    cx(qv1[0], qv2[0])
    return qv1, qv2

qv1_indices, qv2_indices, stim_circuit = qubit_index_demo()
print(f"qv1 uses Stim qubits: {qv1_indices}")  # StimQubitIndices([0, 1])
print(f"qv2 uses Stim qubits: {qv2_indices}")  # StimQubitIndices([2, 3, 4])
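One possible use of the qubit indices is a reverse lookup from a Stim qubit to the variable that owns it, for example when interpreting a hand-inspected circuit. The sketch below uses plain numpy arrays with the index values printed above in place of real StimQubitIndices, and the owner dictionary is an illustrative construct, not a Qrisp feature.

```python
import numpy as np

# Plain arrays standing in for the StimQubitIndices from the example above.
qv1_indices = np.array([0, 1])
qv2_indices = np.array([2, 3, 4])

# Map each Stim qubit to (variable name, position within that variable).
owner = {}
for name, indices in {"qv1": qv1_indices, "qv2": qv2_indices}.items():
    for position, stim_qubit in enumerate(indices.tolist()):
        owner[stim_qubit] = (name, position)

print(owner[3])  # ('qv2', 1)
```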