from numpy import binary_repr
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister
from qiskit_aer import AerSimulator
import matplotlib
matplotlib.use("TkAgg")   # Force GUI backend
import matplotlib.pyplot as plt

def create_carry_gate():
    """Return the 4-qubit CARRY gate used by the ripple-carry adder.

    Qubit order is (carry_in, a, b, carry_out). carry_out is XORed with
    (a AND b) XOR (carry_in AND (a XOR b)), i.e. the carry out of the bit,
    and b is left holding a XOR b.
    """
    c_in, a, b, c_out = range(4)
    circuit = QuantumCircuit(4, name='CARRY')
    circuit.ccx(a, b, c_out)      # generate: a AND b
    circuit.cx(a, b)              # b <- a XOR b
    circuit.ccx(c_in, b, c_out)   # propagate: carry_in AND (a XOR b)
    return circuit.to_gate()

def create_sum_gate():
    """Return the 3-qubit SUM gate: b <- carry_in XOR a XOR b.

    Qubit order is (carry_in, a, b); carry_in and a are only used as controls.
    """
    c_in, a, b = range(3)
    circuit = QuantumCircuit(3, name='SUM')
    # Carry is folded in first, then a.
    for control in (c_in, a):
        circuit.cx(control, b)
    return circuit.to_gate()

# Gate instances shared by every FullAdder circuit built below.
CARRY, SUM = create_carry_gate(), create_sum_gate()

class FullAdder:
    """n-bit quantum ripple-carry adder assembled from the CARRY and SUM gates.

    Registers:
      * ``A``     -- n qubits holding the first operand. Only ever used as a
        control, so it is left unchanged.
      * ``B``     -- n qubits holding the second operand; overwritten with the
        low n bits of the result.
      * ``Carry`` -- n + 1 qubits. ``Carry[0]`` is the carry-in and
        ``Carry[1..n-1]`` are scratch carries; all start in |0> and the scratch
        carries are uncomputed back to |0>. ``Carry[n]`` receives the carry-out.
      * ``classical`` -- n + 1 bits; ``B[i]`` is measured into bit i and
        ``Carry[n]`` into bit n.
    """

    def __init__(self, bits):
        """Build the adder circuit for ``bits``-bit operands.

        Raises:
            ValueError: if ``bits`` is less than 1 (the circuit addresses the
                most significant qubit, so at least one bit is required).
        """
        if bits < 1:
            raise ValueError(f"bits must be at least 1, got {bits!r}")
        self.n = bits
        self.classical = ClassicalRegister(self.n + 1)
        self.A = QuantumRegister(self.n, 'A')
        self.B = QuantumRegister(self.n, 'B')
        self.Carry = QuantumRegister(self.n + 1, 'Carry')

        self.qc = QuantumCircuit(self.A, self.B, self.Carry, self.classical,
                                 name=f"{bits}-bit Quantum Full Adder")

        # Forward pass: Carry[i + 1] receives the carry out of bit i. Each CARRY
        # also leaves B[i] holding A[i] XOR B[i].
        for i in range(self.n):
            self.qc.append(CARRY,
                           [self.Carry[i], self.A[i], self.B[i], self.Carry[i + 1]])

        # No CARRY inverse is applied to the top bit, so its A XOR B must be
        # undone explicitly before SUM writes the final sum bit into B[n-1].
        top = self.n - 1
        self.qc.cx(self.A[top], self.B[top])

        # Backward pass: write each sum bit, then apply CARRY^-1 to the bit
        # below. This restores B[i-1] to its input value and clears Carry[i]
        # before that bit's SUM.
        carry_dg = CARRY.inverse()
        for i in reversed(range(self.n)):
            self.qc.append(SUM, [self.Carry[i], self.A[i], self.B[i]])
            if i > 0:
                self.qc.append(carry_dg,
                               [self.Carry[i - 1], self.A[i - 1], self.B[i - 1], self.Carry[i]])

    def draw(self, case=None):
        """Print the circuit name and return a matplotlib drawing of the adder.

        With ``case="decompose"`` the CARRY/SUM gates are expanded one level.
        """
        print(self.qc.name)
        if case == "decompose":
            return self.qc.decompose().draw('mpl')
        return self.qc.draw(output="mpl", style="clifford")

    def _check_operand(self, name, value):
        """Raise ValueError unless ``value`` fits in an n-bit operand register."""
        limit = 2 ** self.n
        if not 0 <= value < limit:
            raise ValueError(
                f"{name}={value!r} does not fit in a {self.n}-bit register "
                f"(expected 0 <= {name} < {limit})"
            )

    def _run(self, core, a, b):
        """Load a into A and b into B, apply ``core``, measure, and return the key.

        The returned string has n + 1 characters: classical[n] (Carry[n]) first,
        then classical[n-1..0] (B[n-1..0], i.e. MSB..LSB).
        """
        sim = AerSimulator()

        temp = QuantumCircuit(self.A, self.B, self.Carry, self.classical)

        # Basis-state preparation with X gates; bits_*[i] is bit i (LSB first),
        # matching qubit index i. Operands are range-checked by the caller.
        bits_a = binary_repr(a, width=self.n)[::-1]
        bits_b = binary_repr(b, width=self.n)[::-1]
        for i in range(self.n):
            if bits_a[i] == '1':
                temp.x(self.A[i])
            if bits_b[i] == '1':
                temp.x(self.B[i])

        temp = temp.compose(core)
        temp.barrier()

        # Result bits B[i] -> classical[i]; final carry -> classical[n].
        for i in range(self.n):
            temp.measure(self.B[i], self.classical[i])
        temp.measure(self.Carry[self.n], self.classical[self.n])

        # Decompose custom gates so Aer accepts them.
        temp = temp.decompose(reps=10)

        counts = sim.run(temp).result().get_counts()

        # Qiskit writes count keys with the highest classical bit leftmost, so the
        # key already reads carry followed by the sum bits MSB..LSB; no remapping
        # is needed. The circuit permutes a basis state, so noiseless simulation
        # is expected to give a single outcome. Taking the most frequent key is a
        # safeguard.
        return max(counts, key=counts.get)

    def add(self, a, b):
        """Return a + b as an (n + 1)-character bitstring: carry, then sum MSB..LSB.

        Raises:
            ValueError: if ``a`` or ``b`` is outside ``0 <= x < 2**n``
                (previously such values were silently truncated or wrapped).
        """
        self._check_operand("a", a)
        self._check_operand("b", b)
        return self._run(self.qc, a, b)

    def minus(self, a, b):
        """Return a - b as an (n + 1)-bit two's-complement bitstring.

        Runs the adder circuit in reverse with A=b and B=a, which yields
        (a - b) mod 2**(n + 1). The leading bit is 1 exactly when a < b.

        Raises:
            ValueError: if ``a`` or ``b`` is outside ``0 <= x < 2**n``.
        """
        self._check_operand("a", a)
        self._check_operand("b", b)
        # Pass the inverse as a local instead of temporarily swapping self.qc, so
        # an exception mid-run can never leave the instance holding the inverted
        # circuit.
        return self._run(self.qc.inverse(), b, a)

# -------------------------------
# Example usage
# -------------------------------
if __name__ == "__main__":
    adder = FullAdder(bits=3)
    adder.draw()
    plt.show()   # needed when run as a script; Jupyter renders figures inline

    demo_pairs = [(1, 1), (2, 3), (5, 3), (7, 7)]
    for x, y in demo_pairs:
        r = adder.add(x, y)
        print(f"{x} + {y} = 0b{r} = {int(r, 2)}")