Source code for circuit_knitting.cutting.cutqc.wire_cutting_post_processing

# This code is a Qiskit project.

# (C) Copyright IBM 2022.

# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""File containing all cutting post processing functionality."""

from __future__ import annotations

import itertools
import multiprocessing as mp
from typing import Sequence, Any

import numpy as np
from qiskit import QuantumCircuit
from qiskit.utils.deprecation import deprecate_func
from qiskit.circuit import Qubit


@deprecate_func(
    removal_timeline="no sooner than CKT v0.8.0",
    since="0.7.0",
    package_name="circuit-knitting-toolbox",
    additional_msg="Use the wire cutting or automated cut-finding functionality in the ``circuit_knitting.cutting`` package. ",
)
def get_cut_qubit_pairs(
    complete_path_map: dict[Qubit, Sequence[dict[str, int | Qubit]]]
) -> list[tuple[dict[str, Any], dict[str, Any]]]:
    """
    Get O-rho cut qubit pairs.

    Iterates through the path for each of the qubits and keeps track of the
    each pair of neigbors.

    Args:
        complete_path_map: The dictionary containing all path information

    Returns:
        All pairs for each of the qubit paths
    """
    O_rho_pairs = []
    for input_qubit in complete_path_map:
        path = complete_path_map[input_qubit]
        if len(path) > 1:
            for path_ctr, item in enumerate(path[:-1]):
                O_qubit_tuple = item
                rho_qubit_tuple = path[path_ctr + 1]
                O_rho_pairs.append((O_qubit_tuple, rho_qubit_tuple))

    return O_rho_pairs


@deprecate_func(
    removal_timeline="no sooner than CKT v0.8.0",
    since="0.7.0",
    package_name="circuit-knitting-toolbox",
    additional_msg="Use the wire cutting or automated cut-finding functionality in the ``circuit_knitting.cutting`` package. ",
)
def get_label(label_idx: int, num_cuts: int) -> Sequence[str]:
    """
    Get the basis label for each cut point.

    Args:
        label_idx: The label to be applied for the current basis
        num_cuts: The number of cuts

    Returns:
        The list of the labels
    """
    assert label_idx < 4**num_cuts
    basis = ["I", "X", "Y", "Z"]
    label = []
    for position in range(num_cuts):
        digit = label_idx % 4
        label.append(basis[digit])
        label_idx = label_idx // 4
    return label


@deprecate_func(
    removal_timeline="no sooner than CKT v0.8.0",
    since="0.7.0",
    package_name="circuit-knitting-toolbox",
    additional_msg="Use the wire cutting or automated cut-finding functionality in the ``circuit_knitting.cutting`` package. ",
)
def attribute_label(
    label: Sequence[str],
    O_rho_pairs: list[tuple[dict[str, Any], dict[str, Any]]],
    num_subcircuits: int,
) -> dict[int, dict[str, str]]:
    """
    Get the label attributed to each subcircuit.

    subcircuit_label[subcircuit_idx]['init'] = str of init for the cut qubits.
    subcircuit_label[subcircuit_idx]['meas'] = str of meas for the cut qubits.

    Args:
        label: The labels of bases for each cut point, as generated by the
            get_label function
        O_rho_paris: The pairs of cut qubits, as generated by the get_cut_qubit_pairs
            function
        num_subcircuits: The number of subcircuits

    Returns:
        The dictionary containing the labels applied to each of the subcircuits
    """
    subcircuit_label = {
        subcircuit_idx: {"init": "", "meas": ""}
        for subcircuit_idx in range(num_subcircuits)
    }
    for c, pair in zip(label, O_rho_pairs):
        if len(pair) != 2:
            raise ValueError(
                "O_rho_pairs should only contain sequences of length 2: {pair}"
            )
        O_qubit = pair[0]
        rho_qubit: dict[str, int | Qubit] = pair[-1]
        subcircuit_label[O_qubit["subcircuit_idx"]]["meas"] += c
        subcircuit_label[rho_qubit["subcircuit_idx"]]["init"] += c
    return subcircuit_label


@deprecate_func(
    removal_timeline="no sooner than CKT v0.8.0",
    since="0.7.0",
    package_name="circuit-knitting-toolbox",
    additional_msg="Use the wire cutting or automated cut-finding functionality in the ``circuit_knitting.cutting`` package. ",
)
def fill_label(
    subcircuit_idx: int,
    subcircuit: QuantumCircuit,
    subcircuit_label: dict[str, str],
    O_rho_pairs: list[tuple[dict[str, Any], dict[str, Any]]],
) -> tuple[list[str], list[str]]:
    """
    Given a subcircuit, its label and O-rho cut qubit pairs fill the full init, meas strings.

    Args:
        subcircuit_idx: The subcircuit index
        subcircuit: The subcircuit
        subcircuit_label: The dictionary containing the applied labels for the bases
        O_rho_pairs: The list of cut pair qubits

    Returns:
        A tuple containing the initial input bases and the measurements required
    """
    subcircuit_init_label_counter = 0
    subcircuit_meas_label_counter = 0
    init = ["zero" for q in range(subcircuit.num_qubits)]
    meas = ["comp" for q in range(subcircuit.num_qubits)]
    for pair in O_rho_pairs:
        if len(pair) != 2:
            raise ValueError(f"O_rho_pairs should be length 2: {O_rho_pairs}")
        O_qubit = pair[0]
        rho_qubit: dict[str, int | Qubit] = pair[-1]
        if O_qubit["subcircuit_idx"] == subcircuit_idx:
            qubit_idx = subcircuit.qubits.index(O_qubit["subcircuit_qubit"])
            meas[qubit_idx] = subcircuit_label["meas"][subcircuit_meas_label_counter]
            subcircuit_meas_label_counter += 1
        if rho_qubit["subcircuit_idx"] == subcircuit_idx:
            qubit_idx = subcircuit.qubits.index(rho_qubit["subcircuit_qubit"])
            init[qubit_idx] = subcircuit_label["init"][subcircuit_init_label_counter]
            subcircuit_init_label_counter += 1
    return init, meas


@deprecate_func(
    removal_timeline="no sooner than CKT v0.8.0",
    since="0.7.0",
    package_name="circuit-knitting-toolbox",
    additional_msg="Use the wire cutting or automated cut-finding functionality in the ``circuit_knitting.cutting`` package. ",
)
def get_init_meas(
    init_label: Sequence[str], meas_label: Sequence[str]
) -> list[tuple[tuple[str, ...], tuple[str, ...]]]:
    """
    Generate the initial measurements.

    Args:
        init_label: The list of initial bases
        meas: The list of measurement bases

    Returns:
        All subcircuit initial measurements required
    """
    init_combinations = []
    for x in init_label:
        if x == "zero":
            init_combinations.append(["zero"])
        elif x == "I":
            init_combinations.append(["+zero", "+one"])
        elif x == "X":
            init_combinations.append(["2plus", "-zero", "-one"])
        elif x == "Y":
            init_combinations.append(["2plusI", "-zero", "-one"])
        elif x == "Z":
            init_combinations.append(["+zero", "-one"])
        else:
            raise Exception("Illegal initilization symbol :", x)
    init_combinations = list(itertools.product(*init_combinations))  # type: ignore

    meas_combinations = []
    for x in meas_label:
        meas_combinations.append([x])
    meas_combinations = list(itertools.product(*meas_combinations))  # type: ignore

    subcircuit_init_meas = []
    for init in init_combinations:
        for meas in meas_combinations:
            subcircuit_init_meas.append((tuple(init), tuple(meas)))
    return subcircuit_init_meas


@deprecate_func(
    removal_timeline="no sooner than CKT v0.8.0",
    since="0.7.0",
    package_name="circuit-knitting-toolbox",
    additional_msg="Use the wire cutting or automated cut-finding functionality in the ``circuit_knitting.cutting`` package. ",
)
def convert_to_physical_init(init: list[str]) -> tuple[int, tuple[str, ...]]:
    """
    Convert the initial measurements to the physical representations.

    Args:
        init: The initial measurements, e.g. ["zero", "2plus"]

    Returns:
        A tuple containing the coefficient of the physical measurements
        and the physical measurement bases
    """
    coefficient = 1
    for idx, x in enumerate(init):
        if x == "zero":
            continue
        elif x == "+zero":
            init[idx] = "zero"
        elif x == "+one":
            init[idx] = "one"
        elif x == "2plus":
            init[idx] = "plus"
            coefficient *= 2
        elif x == "-zero":
            init[idx] = "zero"
            coefficient *= -1
        elif x == "-one":
            init[idx] = "one"
            coefficient *= -1
        elif x == "2plusI":
            init[idx] = "plusI"
            coefficient *= 2
        else:
            raise Exception("Illegal initilization symbol :", x)
    return coefficient, tuple(init)


[docs] @deprecate_func( removal_timeline="no sooner than CKT v0.8.0", since="0.7.0", package_name="circuit-knitting-toolbox", additional_msg="Use the wire cutting or automated cut-finding functionality in the ``circuit_knitting.cutting`` package. ", ) def generate_summation_terms( subcircuits: Sequence[QuantumCircuit], complete_path_map: dict[Qubit, Sequence[dict[str, int | Qubit]]], num_cuts: int, ) -> tuple[ list[dict[int, int]], dict[int, dict[tuple[str, str], tuple[int, Sequence[tuple[int, int]]]]], dict[int, dict[tuple[tuple[str, ...], tuple[Any, ...]], int]], ]: """ Generate all summation terms for the final reconstructions. Final CutQC reconstruction result = Sum(summation_terms). summation_terms (list): [summation_term_0, summation_term_1, ...] --> 4^#cuts elements. | summation_term[subcircuit_idx] = subcircuit_entry_idx. | E.g. summation_term = {0:0,1:13,2:7} = Kron(subcircuit_0_entry_0, subcircuit_1_entry_13, subcircuit_2_entry_7). | subcircuit_entries[subcircuit_idx][init_label,meas_label] = subcircuit_entry_idx, kronecker_term. | kronecker_term (list): (coefficient, subcircuit_instance_idx). | Add coefficient*subcircuit_instance to subcircuit_entry. subcircuit_instances[subcircuit_idx][init,meas] = subcircuit_instance_idx. Args: subcircuits: The list of subcircuits complete_path_map: The paths of all the qubits through the circuit DAGs num_cuts: The number of cuts Returns: A tuple containing a summation term dict, a subcircuit entry dict, and a subcircuit instances dict """ summation_terms = [] subcircuit_entries: dict[ int, dict[tuple[str, str], tuple[int, Sequence[tuple[int, int]]]] ] = {subcircuit_idx: {} for subcircuit_idx in range(len(subcircuits))} subcircuit_instances: dict[ int, dict[tuple[tuple[str, ...], tuple[Any, ...]], int] ] = {subcircuit_idx: {} for subcircuit_idx in range(len(subcircuits))} O_rho_pairs = get_cut_qubit_pairs(complete_path_map=complete_path_map) for summation_term_idx in range(4**num_cuts): label = get_label(label_idx=summation_term_idx, num_cuts=num_cuts) # print('%d/%d summation term:'%(summation_term_idx+1,4**num_cuts),label) summation_term = {} subcircuit_labels = attribute_label( label=label, O_rho_pairs=O_rho_pairs, num_subcircuits=len(subcircuits) ) for subcircuit_idx in range(len(subcircuits)): # print('subcircuit %d label :'%subcircuit_idx,subcircuit_labels[subcircuit_idx]) subcircuit_entry_key = ( subcircuit_labels[subcircuit_idx]["init"], subcircuit_labels[subcircuit_idx]["meas"], ) if subcircuit_entry_key in subcircuit_entries[subcircuit_idx]: subcircuit_entry_idx, kronecker_term = subcircuit_entries[ subcircuit_idx ][subcircuit_entry_key] # print('Already in, subcircuit_entry {:d} : {}'.format(subcircuit_entry_idx,kronecker_term)) else: subcircuit_full_label = fill_label( subcircuit_idx=subcircuit_idx, subcircuit=subcircuits[subcircuit_idx], subcircuit_label=subcircuit_labels[subcircuit_idx], O_rho_pairs=O_rho_pairs, ) # print('Full label :',subcircuit_full_label) if len(subcircuit_full_label) != 2: raise ValueError( f"subcircuit_full_label variable should be a length-2 tuple: {subcircuit_full_label}" ) subcircuit_init_meas = get_init_meas( init_label=subcircuit_full_label[0], meas_label=subcircuit_full_label[-1], ) kronecker_term = [] for init_meas in subcircuit_init_meas: if len(init_meas) != 2: raise ValueError( f"init_meas variable should be a length-2 tuple: {init_meas}" ) init: tuple[str, ...] = init_meas[0] meas: tuple[str, ...] = init_meas[-1] coefficient, init = convert_to_physical_init( init=list(init_meas[0]) ) if (init, meas) in subcircuit_instances[subcircuit_idx]: subcircuit_instance_idx = subcircuit_instances[subcircuit_idx][ (init, meas) ] else: subcircuit_instance_idx = len( subcircuit_instances[subcircuit_idx] ) subcircuit_instances[subcircuit_idx][ (init, meas) ] = subcircuit_instance_idx kronecker_term.append((coefficient, subcircuit_instance_idx)) subcircuit_entry_idx = len(subcircuit_entries[subcircuit_idx]) subcircuit_entries[subcircuit_idx][subcircuit_entry_key] = ( subcircuit_entry_idx, kronecker_term, ) # print('New subcircuit_entry, {:d} : {}'.format(subcircuit_entry_idx,kronecker_term)) summation_term[subcircuit_idx] = subcircuit_entry_idx summation_terms.append(summation_term) # print('summation_term =',summation_term,'\n') return summation_terms, subcircuit_entries, subcircuit_instances
@deprecate_func( removal_timeline="no sooner than CKT v0.8.0", since="0.7.0", package_name="circuit-knitting-toolbox", additional_msg="Use the wire cutting or automated cut-finding functionality in the ``circuit_knitting.cutting`` package. ", ) def naive_compute( subcircuit_order: Sequence[int], summation_terms: Sequence[dict[int, int]], subcircuit_entry_probs: dict[int, dict[int, np.ndarray]], ) -> tuple[np.ndarray | None, dict[str, int]]: """ Reconstruct the full probability distribution from the subcircuits. This function is called within the build function, meant to be used in a multipooling manner. Args: subcircuit_order: The order of the subcircuit inputs summation_terms: The summation terms, as generated from generate_summation_terms subcircuit_entry_probs: The input probabilities from each of the subcircuit executions Returns: A tuple containing the reconstructed probability distribution and the approximate computational overhead of the function """ reconstructed_prob = None overhead = {"additions": 0, "multiplications": 0} for summation_term in summation_terms: summation_term_prob = None for subcircuit_idx in subcircuit_order: subcircuit_entry_idx = summation_term[subcircuit_idx] subcircuit_entry_prob = subcircuit_entry_probs[subcircuit_idx][ subcircuit_entry_idx ] if summation_term_prob is None: summation_term_prob = subcircuit_entry_prob else: summation_term_prob = np.kron( summation_term_prob, subcircuit_entry_prob ) overhead["multiplications"] += len(summation_term_prob) if reconstructed_prob is None: reconstructed_prob = summation_term_prob else: reconstructed_prob += summation_term_prob overhead["additions"] += len(reconstructed_prob) return reconstructed_prob, overhead
[docs] @deprecate_func( removal_timeline="no sooner than CKT v0.8.0", since="0.7.0", package_name="circuit-knitting-toolbox", additional_msg="Use the wire cutting or automated cut-finding functionality in the ``circuit_knitting.cutting`` package. ", ) def build( summation_terms: Sequence[dict[int, int]], subcircuit_entry_probs: dict[int, dict[int, np.ndarray]], num_cuts: int, num_threads: int, ) -> tuple[np.ndarray, list[int], dict[str, int]]: """ Reconstruct the full probability distribution from the subcircuits. Args: summation_terms: The summation terms used to generate the full vector, as generated in generate_summation_terms subcircuit_entry_probs: The probabilities vectors from the subcircuit executions num_cuts: The number of cuts num_threads: The number of threads to use for multithreading Returns: A tuple containing the reconstructed probability distribution of the full circuit, the ordering of the distribution, and the computational post-processing overhead """ smart_order = sorted( list(subcircuit_entry_probs.keys()), key=lambda subcircuit_idx: len(subcircuit_entry_probs[subcircuit_idx][0]), ) args = [] for i in range(num_threads * 5): segment_summation_terms = find_process_jobs( jobs=summation_terms, rank=i, num_workers=num_threads * 5 ) if len(segment_summation_terms) == 0: break arg = (smart_order, segment_summation_terms, subcircuit_entry_probs) args.append(arg) # Why "spawn"? See https://pythonspeed.com/articles/python-multiprocessing/ with mp.get_context("spawn").Pool(num_threads) as pool: results = pool.starmap(naive_compute, args) overhead = {"additions": 0, "multiplications": 0} reconstructed_prob = None for result in results: thread_reconstructed_prob, thread_overhead = result if reconstructed_prob is None: reconstructed_prob = thread_reconstructed_prob else: reconstructed_prob += thread_reconstructed_prob overhead["additions"] += thread_overhead["additions"] overhead["multiplications"] += thread_overhead["multiplications"] reconstructed_prob /= 2**num_cuts if reconstructed_prob is None: raise ValueError("Something went wrong during the build.") return reconstructed_prob, smart_order, overhead
@deprecate_func( removal_timeline="no sooner than CKT v0.8.0", since="0.7.0", package_name="circuit-knitting-toolbox", additional_msg="Use the wire cutting or automated cut-finding functionality in the ``circuit_knitting.cutting`` package. ", ) def find_process_jobs( jobs: Sequence[Any], rank: int, num_workers: int ) -> Sequence[Any]: """ Split up the total jobs into subjobs to be multithreaded. Args: jobs: All summation terms to be used rank: The input thread num_workers: 5 * the number of threads Returns: The sublist of jobs to be executed on the given thread """ count = int(len(jobs) / num_workers) remainder = len(jobs) % num_workers if rank < remainder: jobs_start = rank * (count + 1) jobs_stop = jobs_start + count + 1 else: jobs_start = rank * count + remainder jobs_stop = jobs_start + (count - 1) + 1 process_jobs = list(jobs[jobs_start:jobs_stop]) return process_jobs