import pyboolector as bt
from btor2circuit import Btor2Circuit

class GreedyProductMachine:
    """
    Greedily-scheduled product machine of two circuits, each with unrolled copies
    (can be lazily built and added to the solver).

    At every unrolling index ``i`` each circuit gets two copies. The ``_0``
    copy is the one that takes part in the product trace. The ``_1`` copy is a
    look-ahead copy whose current state is asserted equal to the ``_0`` copy's
    next state.

    Comparing the two copies gives the circuit's *activity* signal at index
    ``i``: its valid signal flips, or it is currently valid and some output
    changes. When exactly one circuit is active, that circuit holds its state
    for the step and only the other one advances; otherwise both advance. In
    fast-slow mode circuit B is never held back.

    ``P[i]`` holds when, at index ``i``, both circuits are invalid, or both
    are valid and agree on every output.
    """
    solver: bt.Boolector
    "The underlying SAT solver."
    suffix_a: str
    "The suffix for the first circuit, e.g. `@A`."
    suffix_b: str
    "The suffix for the second circuit, e.g. `@B`."
    btor2_lines_a: list[str]
    "The lines of the first circuit's BTOR2 file."
    btor2_lines_b: list[str]
    "The lines of the second circuit's BTOR2 file."
    circuit_a: list[tuple[Btor2Circuit, Btor2Circuit]]
    "The unrolling copies of the first circuit."
    circuit_b: list[tuple[Btor2Circuit, Btor2Circuit]]
    "The unrolling copies of the second circuit."
    fast_slow_mode: bool
    "If True, the product machine is in fast-slow mode (i.e. the first circuit is the fast one)."
    a_is_active: list[bt.BoolectorNode]
    "A list whose i-th element indicates whether the i-th copy of circuit A is active. (w.r.t. the i+1 -th copy)"
    b_is_active: list[bt.BoolectorNode]
    "A list whose i-th element indicates whether the i-th copy of circuit B is active. (w.r.t. the i+1 -th copy)"
    P: list[bt.BoolectorNode]
    "A list whose i-th element is the safety property P at time i."
    
    def __init__(self, solver: bt.Boolector, suffix_a: str, suffix_b: str, btor2_lines_a: list[str], btor2_lines_b: list[str], fast_slow_mode: bool):
        """
        Create the product machine and immediately build unrolling index 0.

        :param solver: Boolector instance that receives all nodes and assertions.
        :param suffix_a: name prefix for the copies of the first circuit
            (the index and ``_0``/``_1`` are appended to it).
        :param suffix_b: name prefix for the copies of the second circuit.
        :param btor2_lines_a: BTOR2 lines of the first circuit.
        :param btor2_lines_b: BTOR2 lines of the second circuit.
        :param fast_slow_mode: if True, circuit B is never held back by the
            scheduler (see :meth:`grow`).
        """
        self.solver = solver
        self.suffix_a = suffix_a
        self.suffix_b = suffix_b
        self.btor2_lines_a = btor2_lines_a
        self.btor2_lines_b = btor2_lines_b
        self.fast_slow_mode = fast_slow_mode
        self.circuit_a = []
        self.circuit_b = []
        self.a_is_active = []
        self.b_is_active = []
        self.P = []
        self.grow(1)

    def _plain_join_states(self, c_curr: Btor2Circuit, c_next: Btor2Circuit) -> bt.BoolectorNode:
        """
        Join the next states of `c_curr` with the current states of `c_next`, and return the `active` signal.

        For every state name of ``c_curr``, assert in the solver that
        ``c_curr``'s next state equals ``c_next``'s current state.

        The returned node is true iff the step from ``c_curr`` to ``c_next``
        is active, meaning one of:

        - the valid signal becomes true;
        - the valid signal becomes false;
        - ``c_curr`` is valid and at least one output differs between the
          two copies.

        Output names are taken from ``c_curr``; ``c_next`` is expected to
        expose the same outputs (in :meth:`grow` both are built from the
        same BTOR2 lines).
        """
        for name in c_curr.state_names():
            self.solver.Assert(c_curr.next_state_by_name(name) == c_next.curr_state_by_name(name))
        becomes_valid = (~c_curr.valid_signal()) & c_next.valid_signal()
        becomes_invalid = c_curr.valid_signal() & (~c_next.valid_signal())
        # OR together "this output differs" over all outputs.
        output_changes = self.solver.Const(False)
        for name in c_curr.output_names():
            output_changes = output_changes | (c_curr.output_by_name(name) != c_next.output_by_name(name))
        is_active = becomes_valid | becomes_invalid | (c_curr.valid_signal() & output_changes)
        return is_active

    def grow(self, height: int):
        """
        Grow the product machine to the given height. (Idempotent: does nothing if grown enough)

        Each new index ``i`` adds:

        - two copies (``_0`` and ``_1``) of each circuit;
        - their activity signals (appended to ``a_is_active`` / ``b_is_active``);
        - the property ``P[i]``;
        - for ``i > 0``, the assertions linking the ``_0`` copies of index
          ``i - 1`` to those of index ``i``.

        All nodes and assertions are added to ``self.solver``.
        """
        while len(self.P) < height:
            index = len(self.P)
            # `_0` is the copy on the product trace; `_1` is a look-ahead copy
            # whose current state is tied to `_0`'s next state, used only to
            # compute this index's activity signal.
            a_copy_0 = Btor2Circuit(self.solver, f"{self.suffix_a}{index}_0", self.btor2_lines_a)
            a_copy_1 = Btor2Circuit(self.solver, f"{self.suffix_a}{index}_1", self.btor2_lines_a)
            self.a_is_active.append(self._plain_join_states(a_copy_0, a_copy_1))
            b_copy_0 = Btor2Circuit(self.solver, f"{self.suffix_b}{index}_0", self.btor2_lines_b)
            b_copy_1 = Btor2Circuit(self.solver, f"{self.suffix_b}{index}_1", self.btor2_lines_b)
            self.b_is_active.append(self._plain_join_states(b_copy_0, b_copy_1))
            self.circuit_a.append((a_copy_0, a_copy_1))
            self.circuit_b.append((b_copy_0, b_copy_1))
            # build the P signal: both invalid, or both valid with equal outputs.
            # Output names are taken from circuit A; B must expose the same names.
            outputs_all_eq = self.solver.Const(True)
            for name in a_copy_0.output_names():
                output_eq = a_copy_0.output_by_name(name) == b_copy_0.output_by_name(name)
                outputs_all_eq = outputs_all_eq & output_eq
            both_invalid = (~a_copy_0.valid_signal()) & (~b_copy_0.valid_signal())
            both_valid = a_copy_0.valid_signal() & b_copy_0.valid_signal()
            ob_eq = (both_valid & outputs_all_eq) | both_invalid
            self.P.append(ob_eq)
            # join the previous copies (index 0 has no predecessor):
            if index == 0:
                continue
            a_prev_0, b_prev_0 = self.circuit_a[index - 1][0], self.circuit_b[index - 1][0]
            a_is_active, b_is_active = self.a_is_active[index - 1], self.b_is_active[index - 1]
            # Greedy schedule: if exactly one circuit's step is active, that
            # circuit holds its state and only the other advances. When both
            # or neither are active, both advance. In fast-slow mode B is
            # never held, so only A can be made to wait.
            if self.fast_slow_mode:
                only_a_moves = self.solver.Const(False)
                only_b_moves = a_is_active & (~b_is_active)
            else:
                only_a_moves = (~a_is_active) & b_is_active
                only_b_moves = a_is_active & (~b_is_active)
            # A holds its previous current state when only B moves, otherwise takes its next state.
            for name in a_copy_0.state_names():
                next_state = self.solver.Cond(only_b_moves, a_prev_0.curr_state_by_name(name), a_prev_0.next_state_by_name(name))
                self.solver.Assert(a_copy_0.curr_state_by_name(name) == next_state)
            # B holds its previous current state when only A moves, otherwise takes its next state.
            for name in b_copy_0.state_names():
                next_state = self.solver.Cond(only_a_moves, b_prev_0.curr_state_by_name(name), b_prev_0.next_state_by_name(name))
                self.solver.Assert(b_copy_0.curr_state_by_name(name) == next_state)

    def get_unsafe_signal_within_n(self, n: int) -> bt.BoolectorNode:
        """
        Return a signal that is True if the unsafe state is reached within n steps.

        The result is the disjunction of ``~P[i]`` for ``i = 0 .. n``
        (inclusive). For ``n < 0`` it is the constant False.

        The machine is first grown to height ``n + 1``, because ``grow`` is
        idempotent and the copies are built lazily. Callers therefore need
        not call :meth:`grow` themselves. This may add nodes and assertions
        to the solver.
        """
        # Previously an un-grown machine raised IndexError here.
        self.grow(n + 1)
        unsafe = self.solver.Const(False)
        for i in range(n + 1):
            unsafe = unsafe | ~self.P[i]
        return unsafe
