qrisp.jasp.Jaspr.extract_post_processing

Jaspr.extract_post_processing(*args)

Extracts the post-processing logic from this Jaspr and returns a function that performs the post-processing on measurement results.

This method is useful for separating the quantum circuit from the classical post-processing of measurement results. The quantum circuit can be executed on a NISQ-style backend to obtain measurement results, and then the post-processing function can be applied to those results to obtain the final output.
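Because several bitstrings can post-process to the same output, it is often convenient to merge their counts after applying the post-processing function. The sketch below is illustrative only: `aggregate_post_processed` and `toy_post_proc` are hypothetical names, the counts format (bitstring mapped to an integer count) is an assumption, and `toy_post_proc` merely stands in for a function returned by `extract_post_processing`, so the snippet runs without Qrisp.

```python
from collections import Counter
from typing import Any, Callable, Dict


def aggregate_post_processed(
    counts: Dict[str, int], post_proc: Callable[[str], Any]
) -> Counter:
    """Merge counts of bitstrings that post-process to the same value.

    JAX arrays are not hashable, so every output component is converted to a
    plain Python scalar with .item() before being used as a dictionary key.
    """
    merged = Counter()
    for bitstring, count in counts.items():
        processed = post_proc(bitstring)
        # A single return value is wrapped so it can be handled like a tuple.
        if not isinstance(processed, tuple):
            processed = (processed,)
        key = tuple(x.item() if hasattr(x, "item") else x for x in processed)
        merged[key] += count
    return merged


# Hypothetical stand-in for a function returned by extract_post_processing,
# mimicking a return value of (meas_1 + 2, meas_2). Reading meas_1 from the
# first character and meas_2 from the second is an assumption.
def toy_post_proc(bitstring: str):
    meas_1, meas_2 = (c == "1" for c in bitstring)
    return meas_1 + 2, meas_2


print(aggregate_post_processed({"00": 480, "01": 520}, toy_post_proc))
```

With a real post-processing function, the same call should work on a counts dictionary of that shape; the `.item()` conversion is what makes JAX outputs usable as keys.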

Note

It is not possible to extract a QuantumCircuit from a Jaspr that involves real-time computation, but it is possible to extract its post-processing function.

Parameters:
*args : tuple

The static argument values that were used for circuit extraction. These will be bound into the post-processing function as Literals.

Returns:
callable

A function that takes measurement results and returns the post-processed results. Accepts either a string of ‘0’ and ‘1’ characters or a JAX array of booleans with shape (n,). String inputs are automatically converted to boolean arrays.
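The string-to-boolean conversion described above can be pictured with the following pure-Python sketch. It is not Qrisp's implementation: the helper name `bitstring_to_bools` is hypothetical, and mapping character k of the string to element k of the result is an assumption, since the bit ordering is not specified here.

```python
def bitstring_to_bools(bitstring: str) -> list:
    """Convert a string of '0'/'1' characters into a list of booleans.

    Character k becomes element k (an assumed ordering, for illustration).
    """
    invalid = set(bitstring) - {"0", "1"}
    if invalid:
        raise ValueError(f"unexpected characters: {sorted(invalid)}")
    return [c == "1" for c in bitstring]


print(bitstring_to_bools("0110"))
```

With JAX installed, `jnp.array(bitstring_to_bools("01"))` produces a boolean array of shape (2,), i.e. the array form of input; whether it matches the internal conversion of the string "01" depends on the ordering assumption above.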

Examples

We create a Jaspr that performs post-processing on measurement results:

from qrisp import *
import jax.numpy as jnp

@make_jaspr
def example_function(i):
    qv = QuantumFloat(5)
    # First measurement
    meas_1 = measure(qv[i])
    h(qv[1])
    # Second measurement
    meas_2 = measure(qv[1])
    # Classical post-processing
    return meas_1 + 2, meas_2

jaspr = example_function(1)

# Extract the quantum circuit
a, b, qc = jaspr.to_qc(1)

# Extract the post-processing function with the SAME arguments
post_proc = jaspr.extract_post_processing(1)

# Execute qc on a backend to get measurement results
results = qc.run()

# Apply post-processing to each result
for bitstring, count in results.items():
    processed = post_proc(bitstring)
    print(f"{bitstring} -> {processed}")

# Yields:
# 00 -> (Array(2, dtype=int64), Array(False, dtype=bool))
# 01 -> (Array(2, dtype=int64), Array(True, dtype=bool))

# Array input is also accepted (useful for JAX jitting):
meas_array = jnp.array([False, True])
processed = post_proc(meas_array)

Note that the static arguments (in this case, 1) must be identical to those used for circuit extraction, because they determine the structure of both the quantum circuit and the post-processing logic.
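One way to avoid mismatched static arguments is to route both extractions through a single call. The helper below is a hypothetical sketch, not part of Qrisp; it relies only on `to_qc` and `extract_post_processing`, and assumes, as in `a, b, qc = jaspr.to_qc(1)`, that `to_qc` returns the QuantumCircuit as the last element of a tuple.

```python
def extract_circuit_and_post_processing(jaspr, *args):
    """Build the circuit and its post-processing function from ONE set of
    static arguments, so the two cannot drift apart.

    Assumption: to_qc returns the QuantumCircuit as the last element of a
    tuple or list; any other return value is treated as the circuit itself.
    """
    extracted = jaspr.to_qc(*args)
    qc = extracted[-1] if isinstance(extracted, (tuple, list)) else extracted
    # The same *args are forwarded, so the static values always match.
    post_proc = jaspr.extract_post_processing(*args)
    return qc, post_proc
```

For the example above this would be called as `qc, post_proc = extract_circuit_and_post_processing(jaspr, 1)`.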