"""********************************************************************************* Copyright (c) 2024 the Qrisp authors** This program and the accompanying materials are made available under the* terms of the Eclipse Public License 2.0 which is available at* http://www.eclipse.org/legal/epl-2.0.** This Source Code may also be made available under the following Secondary* Licenses when the conditions for such availability set forth in the Eclipse* Public License, v. 2.0 are satisfied: GNU General Public License, version 2* with the GNU Classpath Exception which is* available at https://www.gnu.org/software/classpath/license.html.** SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0********************************************************************************"""fromqrispimportQuantumArray,mcp,conjugate,invertfromqrisp.jaspimportq_fori_loop,q_cond,check_for_tracing_modefromjaximportlaximportsympyasspimportnumpyasnpimportjax.numpyasjnp
def QITE(qarg, U_0, exp_H, s, k, method="GC"):
    r"""
    Performs `Double-Bracket Quantum Imaginary-Time Evolution (DB-QITE) <https://arxiv.org/abs/2412.04554>`_.

    Given a Hamiltonian :ref:`Operator <Operators>` $H$, this method implements the unitary $U_k$ that is recursively defined by either of

    * Group commutator (GC) approximation:

    .. math::

        U_{k+1} = e^{i\sqrt{s_k}H}e^{i\sqrt{s_k}\omega_k}e^{-i\sqrt{s_k}H}U_k

    * Higher-order product formula (HOPF) approximation:

    .. math::

        U_{k+1} = e^{i\phi\sqrt{s_k}H}e^{i\phi\sqrt{s_k}\omega_k}e^{-i\sqrt{s_k}H}e^{-i(1+\phi)\sqrt{s_k}\omega_k}e^{i(1-\phi)\sqrt{s_k}H}U_k

    where $\phi=(\sqrt{5}-1)/2$, $\omega_k=\ket{\omega_k}\bra{\omega_k}$ is the projector onto the state
    $\ket{\omega_k}=U_k\ket{0}$, and $e^{it\omega_k}=U_ke^{it\ket{0}\bra{0}}U_k^{\dagger}$ is the corresponding
    phase rotation around $\ket{\omega_k}$ (a reflection for $t=\pi$).

    Parameters
    ----------
    qarg : :ref:`QuantumVariable` or :ref:`QuantumArray`
        The quantum argument on which quantum imaginary time evolution is performed.
    U_0 : function
        A Python function that takes a QuantumVariable or QuantumArray ``qarg`` as input, and prepares the initial state.
    exp_H : function
        A Python function that takes a QuantumVariable or QuantumArray ``qarg`` and time ``t`` as input, and performs forward evolution $e^{-itH}$.
    s : list[float] or list[Sympy.Symbol]
        A list of evolution times for each step. Step $j$ (counting from 1) uses ``s[j-1]``,
        so at least ``k`` entries are required. In Jasp (tracing) mode the entries must be
        numeric; they are converted to a JAX array.
    k : int
        The number of steps.
    method : str, optional
        The method for approximating the double-bracket flow (DBF). Available are ``GC`` and ``HOPF``. The default is ``GC``.

    Raises
    ------
    ValueError
        If ``method`` is neither ``GC`` nor ``HOPF``.

    Examples
    --------

    We utilize QITE to approximate the ground state energy of a Heisenberg chain.
    We start by defining the lattice graph $G$:

    ::

        import networkx as nx

        # Create a graph
        N = 4
        G = nx.Graph()
        G.add_edges_from([(k,k+1) for k in range(N-1)])

    Next, we set up the Heisenberg Hamiltonian and calculate the ground state energy classically:

    ::

        from qrisp.operators import X, Y, Z

        def create_heisenberg_hamiltonian(G):
            H = sum(X(i)*X(j)+Y(i)*Y(j)+Z(i)*Z(j) for (i,j) in G.edges())
            return H

        H = create_heisenberg_hamiltonian(G)
        print(H)
        print(H.ground_state_energy())

    As explained :ref:`in this example <VQEHeisenberg>`, a suitable initial approximation for the ground state is given by a tensor product of singlet states
    $\frac{1}{\sqrt{2}}(\ket{10}-\ket{01})$ corresponding to a maximal matching of the graph $G$.
    Accordingly, we define the function ``U_0``:

    ::

        from qrisp import QuantumVariable
        from qrisp.vqe.problems.heisenberg import create_heisenberg_init_function

        M = nx.maximal_matching(G)
        U_0 = create_heisenberg_init_function(M)

        def state_prep():
            qv = QuantumVariable(N)
            U_0(qv)
            return qv

        E_0 = H.expectation_value(state_prep)()
        print(E_0)

    For the function ``exp_H`` that performs forward evolution $e^{-itH}$, we use the
    :meth:`trotterization <qrisp.operators.qubit.QubitOperator.trotterization>` method with 5 Trotter steps:

    ::

        def exp_H(qv, t):
            H.trotterization(method='commuting')(qv,t,5)

    With all the necessary ingredients, we use QITE to approximate the ground state:

    ::

        import numpy as np
        import sympy as sp
        from qrisp.qite import QITE

        steps = 4
        s_values = np.linspace(.01,.3,10)

        theta = sp.Symbol('theta')
        optimal_s = [theta]
        optimal_energies = [E_0]

        for k in range(1,steps+1):

            # Perform k steps of QITE
            def state_prep():
                qv = QuantumVariable(N)
                QITE(qv, U_0, exp_H, optimal_s, k)
                return qv

            qv = state_prep()
            qc = qv.qs.compile()

            # Find optimal evolution time
            # Use "precompiled_qc" keyword argument to avoid repeated compilation of the QITE circuit
            energies = [H.expectation_value(state_prep, diagonalisation_method='commuting',
                                            subs_dic={theta:s_}, precompiled_qc=qc)()
                        for s_ in s_values]
            index = np.argmin(energies)
            s_min = s_values[index]

            optimal_s.insert(-1,s_min)
            optimal_energies.append(energies[index])

        print(optimal_energies)

    Finally, we visualize the results:

    ::

        import matplotlib.pyplot as plt

        evolution_times = [sum(optimal_s[i] for i in range(k)) for k in range(steps+1)]

        plt.xlabel('Evolution time', fontsize=15, color='#444444')
        plt.ylabel('Energy', fontsize=15, color='#444444')
        plt.axhline(y=H.ground_state_energy(), color='#6929C4', linestyle='--', linewidth=2, label='Exact energy')
        plt.plot(evolution_times, optimal_energies, c='#20306f', marker="o", linestyle='solid', linewidth=3, zorder=3, label='DB-QITE')
        plt.legend(fontsize=12, labelcolor='linecolor')
        plt.tick_params(axis='both', labelsize=12)
        plt.grid()
        plt.show()

    .. figure:: /_static/heisenberg_qite.png
        :scale: 80%
        :align: center

    """

    # Fail fast: an unknown method would otherwise skip every evolution step
    # without any notice (neither branch below would apply gates).
    if method not in ("GC", "HOPF"):
        raise ValueError(
            f"Unknown method {method!r}; available methods are 'GC' and 'HOPF'."
        )

    if not check_for_tracing_mode():
        # Static mode: build U_k directly from its recursive definition.
        if k == 0:
            U_0(qarg)
            return

        s_ = sp.sqrt(s[k - 1])

        def conjugator(qarg):
            # Applies U_{k-1}^dagger.
            with invert():
                QITE(qarg, U_0, exp_H, s, k - 1, method=method)

        def reflection(qarg, t_):
            # Applies U_{k-1} e^{i t_ |0><0|} U_{k-1}^dagger, i.e. e^{i t_ omega_{k-1}}.
            with conjugate(conjugator)(qarg):
                if isinstance(qarg, QuantumArray):
                    # Linear-time flattening (``sum(lists, [])`` is quadratic).
                    qubits = [qb for qv in qarg.flatten() for qb in qv.reg]
                    mcp(t_, qubits, ctrl_state=0, method="khattar")
                else:
                    mcp(t_, qarg, ctrl_state=0, method="khattar")

        # Both methods start by preparing U_{k-1}|0>.
        QITE(qarg, U_0, exp_H, s, k - 1, method=method)

        if method == "GC":
            with conjugate(exp_H)(qarg, s_):
                reflection(qarg, s_)
        else:  # method == "HOPF"
            phi = (sp.sqrt(5) - 1) / 2
            # exp_H performs forward evolution $e^{-itH}$
            exp_H(qarg, -(1 - phi) * s_)
            reflection(qarg, -(1 + phi) * s_)
            exp_H(qarg, s_)
            reflection(qarg, phi * s_)
            exp_H(qarg, -phi * s_)
        return

    # Jasp (tracing) mode.
    #
    # To create a jasp-compatible implementation of QITE, we need to remove the
    # recursive structure. We achieve this by fully expanding the recursive
    # formula for $U_k$ down to the $k=0$ level. From there, we find a tree
    # structure with branching factor 3 (GC) or 5 (HOPF) where some branches are
    # inverted due to the presence of conjugate operators $U_i^\dagger$.
    # We traverse the tree depth-first using up-, down-, bounce-, and
    # leaf-operations that we obtain from inspecting the formula for $U_k$.

    # ``s`` is indexed with traced integers below, which a Python list does not
    # support; convert it once (no-op if it already is a JAX array).
    s = jnp.asarray(s)

    def int_to_base(n, base=3, max_digits=10):
        """
        Get the array representation of an integer ``n`` with base ``base``.
        The array has length ``max_digits`` and the least significant digit is at index ``0``.
        Digits beyond ``max_digits`` are dropped, so leaf positions are only
        represented faithfully for ``k <= max_digits`` steps.
        """

        def cond_fun(state):
            n, digits, i = state
            return jnp.logical_and(n > 0, i < max_digits)

        def body_fun(state):
            n, digits, i = state
            digits = digits.at[i].set(n % base)
            n = n // base
            return n, digits, i + 1

        init_digits = jnp.zeros((max_digits,), dtype=jnp.int32)
        _, digits, _ = lax.while_loop(cond_fun, body_fun, (n, init_digits, 0))
        return digits

    # Define basic operations
    def U_0_dag(q_arg):
        with invert():
            U_0(q_arg)

    def exp_00(q_arg, time):
        # Multi-controlled phase by ``time`` on the all-zero state.
        mcp(time, q_arg, ctrl_state=0)

    def skip(q_arg, time):
        # No-op branch for ``q_cond``.
        return None

    if method == "GC":

        def body_fun(i, val):
            qarg = val
            # Obtain old and new position (base-3 digits, least significant first)
            old_pos = int_to_base(i, 3)
            new_pos = int_to_base(i + 1, 3)
            # Obtain largest changed index + 1
            num_changes = jnp.count_nonzero(new_pos != old_pos)
            # Digit that is incremented when moving on to the next leaf
            pivot = old_pos[num_changes - 1]
            sqrt_s = jnp.sqrt(s[num_changes - 1])

            # Compute which operations must be inverted
            inv_mode_leaf = jnp.equal(jnp.count_nonzero(old_pos == 1) % 2, 0)
            inv_mode_up = inv_mode_leaf
            inv_mode_bounce = jnp.logical_xor(inv_mode_up, pivot == 1)
            not_bounce = jnp.logical_not(inv_mode_bounce)
            inv_mode_down = jnp.equal(jnp.count_nonzero(new_pos == 1) % 2, 0)

            # Apply U_0
            q_cond(inv_mode_leaf, U_0, U_0_dag, qarg)

            # Go up the branch
            time = q_fori_loop(
                0, num_changes - 1, lambda j, t: t + jnp.sqrt(s[j]), 0
            )
            q_cond(inv_mode_up, exp_H, skip, qarg, -time)

            # Bounce to next branch
            q_cond(jnp.logical_and(pivot == 0, inv_mode_bounce), exp_H, skip, qarg, sqrt_s)
            q_cond(jnp.logical_and(pivot == 1, inv_mode_bounce), exp_00, skip, qarg, sqrt_s)
            q_cond(jnp.logical_and(pivot == 0, not_bounce), exp_00, skip, qarg, -sqrt_s)
            q_cond(jnp.logical_and(pivot == 1, not_bounce), exp_H, skip, qarg, -sqrt_s)

            # Go down to leaf
            q_cond(inv_mode_down, skip, exp_H, qarg, time)
            return qarg

        # Iterate all leaves except the last
        q_fori_loop(0, 3**k - 1, body_fun, qarg)

        # Do last leaf
        U_0(qarg)
        time = lax.fori_loop(0, k, lambda j, t: t + jnp.sqrt(s[j]), 0)
        exp_H(qarg, -time)

    else:  # method == "HOPF"
        phi = (jnp.sqrt(5) - 1) / 2

        def body_fun(i, val):
            qarg = val
            # Obtain old and new position (base-5 digits, least significant first)
            old_pos = int_to_base(i, 5)
            new_pos = int_to_base(i + 1, 5)
            # Obtain largest changed index + 1
            num_changes = jnp.count_nonzero(new_pos != old_pos)
            # Digit that is incremented when moving on to the next leaf
            pivot = old_pos[num_changes - 1]
            sqrt_s = jnp.sqrt(s[num_changes - 1])

            # Compute which operations must be inverted: the parity of digits
            # equal to 1 or 3 decides the orientation.
            inv_mode_leaf = (
                jnp.count_nonzero(old_pos == 1) + jnp.count_nonzero(old_pos == 3)
            ) % 2 == 0
            inv_mode_up = inv_mode_leaf
            inv_mode_bounce = jnp.logical_xor(
                inv_mode_up, jnp.logical_or(pivot == 1, pivot == 3)
            )
            not_bounce = jnp.logical_not(inv_mode_bounce)
            inv_mode_down = (
                jnp.count_nonzero(new_pos == 1) + jnp.count_nonzero(new_pos == 3)
            ) % 2 == 0

            # Apply U_0
            q_cond(inv_mode_leaf, U_0, U_0_dag, qarg)

            # Go up the branch
            time = (
                q_fori_loop(0, num_changes - 1, lambda j, t: t + jnp.sqrt(s[j]), 0)
                * phi
            )
            q_cond(inv_mode_up, exp_H, skip, qarg, -time)

            # Bounce to next branch
            q_cond(jnp.logical_and(pivot == 0, inv_mode_bounce), exp_H, skip, qarg, -(1 - phi) * sqrt_s)
            q_cond(jnp.logical_and(pivot == 1, inv_mode_bounce), exp_00, skip, qarg, -(1 + phi) * sqrt_s)
            q_cond(jnp.logical_and(pivot == 2, inv_mode_bounce), exp_H, skip, qarg, sqrt_s)
            q_cond(jnp.logical_and(pivot == 3, inv_mode_bounce), exp_00, skip, qarg, phi * sqrt_s)
            q_cond(jnp.logical_and(pivot == 3, not_bounce), exp_H, skip, qarg, (1 - phi) * sqrt_s)
            q_cond(jnp.logical_and(pivot == 2, not_bounce), exp_00, skip, qarg, (1 + phi) * sqrt_s)
            q_cond(jnp.logical_and(pivot == 1, not_bounce), exp_H, skip, qarg, -sqrt_s)
            q_cond(jnp.logical_and(pivot == 0, not_bounce), exp_00, skip, qarg, -phi * sqrt_s)

            # Go down to leaf
            q_cond(inv_mode_down, skip, exp_H, qarg, time)
            return qarg

        # Iterate all leaves except the last
        q_fori_loop(0, 5**k - 1, body_fun, qarg)

        # Do last leaf
        U_0(qarg)
        time = -phi * lax.fori_loop(0, k, lambda j, t: t + jnp.sqrt(s[j]), 0)
        exp_H(qarg, time)
Get in touch!
If you are interested in Qrisp or high-level quantum algorithm research in general, connect with us on our
Slack workspace.