import pyboolector as bt
import clause_utils as cl
from btor2circuit import Btor2Circuit
from product_machine import ProductMachine
from queue import SimpleQueue

class TermsMap:
    "A map from bit width to a set of involved terms of that width (for ET abstraction)."
    m: dict[int, set[cl.Node]]
    "The internal data structure."
    unroll_height: tuple[int, int]
    "The current unroll height for each circuit (A, B), such that it suffices for the terms map."

    def __init__(self, btor2_lines_a: list[str], btor2_lines_b: list[str]):
        """
        Seed the map with the ObEq signal (as a width-1 term) and the copy-0
        state variables of circuits A and B.

        Both circuits are parsed into a scratch Boolector instance that is only
        used here, to enumerate state ids and their bit widths.
        """
        self.m = {}
        self.unroll_height = (1, 1)
        self.add_term(cl.OB_EQ_NODE, 1)  # ObEq
        scratch_solver = bt.Boolector()
        circuits = (
            ("@A", btor2_lines_a, cl.make_node_a),
            ("@B", btor2_lines_b, cl.make_node_b),
        )
        for prefix, btor2_lines, make_node in circuits:
            circuit = Btor2Circuit(scratch_solver, prefix, btor2_lines)
            for state_id, bit_width in circuit.state_id_to_width_dict().items():
                self.add_term(make_node(0, state_id), bit_width)

    def add_term(self, term: cl.Node, width: int):
        """
        Record `term` under `width`, then raise the unroll height if needed.

        A positive copy id counts towards circuit A's height, a negative one
        (by magnitude) towards circuit B's; copy id 0 leaves the height alone.
        """
        self.m.setdefault(width, set()).add(term)
        copy_id, _ = term
        height_a, height_b = self.unroll_height
        if copy_id > 0:
            height_a = max(height_a, copy_id)
        elif copy_id < 0:
            height_b = max(height_b, -copy_id)
        else:
            return
        self.unroll_height = (height_a, height_b)

    def print(self, printer: cl.Printer) -> str:
        "Render the map as `{width: {node, ...}, ...}`, naming nodes via `printer`."
        entries = []
        for width, nodes in self.m.items():
            node_names = ", ".join(printer.node(node) for node in nodes)
            entries.append(f"{width}: " + "{" + node_names + "}")
        return "{" + ", ".join(entries) + "}"


class PartitionAssignment:
    """
    Represents a partition assignment (as in IC3+SA) as a hashable and immutable object.

    Instantiated with a concrete assignment and a map from bit-width to terms.
    The copy-ids of nodes must be de-primed.

    - Width-1 terms become (predicate, value) pairs.
    - For every wider bit width, the terms are partitioned into groups of
      terms that share the same concrete value.

    Equality and hashing use only these two components. `is_bad` is derived
    from the concrete value of `cl.OB_EQ_NODE` and is not compared separately.
    """
    pred_assignments: tuple[tuple[cl.Node, bool], ...]
    "Sorted predicate-value pairs."
    term_partitions: tuple[tuple[tuple[cl.Node, ...], ...], ...]
    "Sorted term-equality relationships (one sorted partition per bit width)."
    is_bad: bool
    "Is it a bad state?"

    def __eq__(self, other: object) -> bool:
        if self is other:
            return True
        if not isinstance(other, PartitionAssignment):
            # let Python fall back to its default handling instead of raising AttributeError
            return NotImplemented
        return (self.pred_assignments == other.pred_assignments
                and self.term_partitions == other.term_partitions)

    def __hash__(self) -> int:
        return hash((self.pred_assignments, self.term_partitions))

    def __init__(self, concrete: dict[cl.Node, str], terms_map: TermsMap):
        """
        Abstract `concrete` over the terms in `terms_map`.

        `concrete` maps each node to its value as a binary string ("0"/"1" for
        width-1 nodes). It must contain `cl.OB_EQ_NODE` and every term of
        `terms_map`; otherwise a KeyError is raised.
        """
        pred_assignments: list[tuple[cl.Node, bool]] = []
        term_partitions: list[tuple[tuple[cl.Node, ...], ...]] = []
        self.is_bad = concrete[cl.OB_EQ_NODE] == "0"
        for width, terms in terms_map.m.items():
            if width == 1:
                for term in terms:
                    pred_assignments.append((term, concrete[term] == "1"))
                continue
            # Sort by (value, term) so that equal-valued terms are adjacent, then
            # group consecutive runs. `None` never equals a value string, so the
            # first term always opens a new group (even for an empty-string value).
            groups: list[list[cl.Node]] = []
            last_value = None
            for value, term in sorted((concrete[t], t) for t in terms):
                if value == last_value:
                    groups[-1].append(term)
                else:
                    groups.append([term])
                    last_value = value
            term_partitions.append(tuple(sorted(tuple(sorted(group)) for group in groups)))
        # canonical (sorted) tuples make equal abstractions hash/compare equal
        self.pred_assignments = tuple(sorted(pred_assignments))
        self.term_partitions = tuple(sorted(term_partitions))

    def to_literals(self, positive: bool) -> list[cl.Literal]:
        """
        Converts the partition assignment to a list of literals (positive for cube, negative for clause).

        With `positive=True` the literals state the assignment itself. With
        `positive=False` every literal is negated, so that their disjunction
        excludes the assignment.
        """
        literals: list[cl.Literal] = []
        for node, value in self.pred_assignments:
            literals.append(cl.make_p_lit(node, value == positive))
        for partition in self.term_partitions:
            # inner-group equality: every member equals the group's first term
            for group in partition:
                for other in group[1:]:
                    literals.append(cl.make_e_lit(group[0], other, positive))
            # inter-group inequality: representatives of distinct groups differ
            for i, group_1 in enumerate(partition):
                for group_2 in partition[i + 1:]:
                    literals.append(cl.make_e_lit(group_1[0], group_2[0], not positive))
        return literals

    def to_str(self, printer: cl.Printer) -> str:
        "Converts the partition assignment to a human-readable conjunction string."
        L, R = "{", "}"
        items: list[str] = []
        for node, value in self.pred_assignments:
            items.append(f"{'' if value else '~'}{printer.node(node)}")
        for partition in self.term_partitions:
            groups_as_str = [', '.join(printer.node(node) for node in group) for group in partition]
            items.append(f"{L}{' | '.join(groups_as_str)}{R}")
        # escaped backslash: same runtime text as before, without the invalid-escape warning
        return " /\\ ".join(items)


class AcexTreeNode:
    """
    A node of the abstract counterexample tree.

    Child nodes may be shared, i.e. reachable from several parents.
    """
    value: PartitionAssignment
    "The region (partition assignment) this node stands for."
    a_stutter: "AcexTreeNode"
    "Child for the step delta (0, 1), where A stutters; `None` until set."
    b_stutter: "AcexTreeNode"
    "Child for the step delta (1, 0), where B stutters; `None` until set."
    sync: "AcexTreeNode"
    "Child for the step delta (1, 1), where A and B step together; `None` until set."
    frame_index: int
    "Index of innermost frame containing the node, ranged from 0 to k. (used in blocking phase)"
    parents: set["AcexTreeNode"]
    "The parent nodes. (used in blocking phase)"

    def __init__(self, value: PartitionAssignment, frame_index: int):
        self.value = value
        self.frame_index = frame_index
        # children are attached later by the owning tree
        self.a_stutter = self.b_stutter = self.sync = None
        self.parents = set()

    def __lt__(self, other: "AcexTreeNode"):
        "Arbitrary but consistent ordering by object identity (lets nodes sit in a PriorityQueue)."
        return id(other) > id(self)

    def is_orphan(self) -> bool:
        "True when no parent references this node."
        return not self.parents

    def child_by_delta(self, delta: tuple[int, int]) -> "AcexTreeNode":
        """
        Return the child for `delta` = (steps taken by A, steps taken by B).

        Raises ValueError for any delta other than (1, 1), (0, 1) or (1, 0).
        """
        slots = (((1, 1), "sync"), ((0, 1), "a_stutter"), ((1, 0), "b_stutter"))
        for known_delta, slot_name in slots:
            if delta == known_delta:
                return getattr(self, slot_name)
        raise ValueError(f"Invalid delta value: {delta}")


class AcexTree:
    """
    Represents an abstract counterexample tree.

    Nodes are deduplicated by their `PartitionAssignment`: asking for a child
    with a region that already has a node reuses that node. This is why a node
    can have several parents.
    """
    root: AcexTreeNode
    "The root of the tree."
    value_to_node: dict[PartitionAssignment, AcexTreeNode]
    "Maps partition assignments to tree nodes, to avoid duplicates."

    def __init__(self, root: AcexTreeNode):
        self.root = root
        self.value_to_node = {root.value: root}

    def _add_child_helper(self, region: PartitionAssignment, frame_index: int) -> tuple[AcexTreeNode, bool]:
        """
        Adds a child node. Returns the child and whether the child is new.

        A reused node keeps its existing `frame_index`.
        """
        if region in self.value_to_node:  # reuse existing node
            return self.value_to_node[region], False
        child = AcexTreeNode(region, frame_index)
        self.value_to_node[region] = child
        return child, True

    def _attach_child(self, parent: AcexTreeNode, slot: str, region: PartitionAssignment,
                      frame_index: int) -> bool:
        """
        Store the (possibly reused) node for `region` in `parent.<slot>`.

        Also registers `parent` in that node's `parents`. Returns whether the
        node is new.
        """
        # NOTE(review): a node previously stored in `slot` keeps `parent` in its
        # `parents`; callers presumably call `give_up_children` first — confirm.
        child, is_new = self._add_child_helper(region, frame_index)
        child.parents.add(parent)
        setattr(parent, slot, child)
        return is_new

    def add_child_sync(self, parent: AcexTreeNode, region: PartitionAssignment, frame_index: int) -> bool:
        "Adds a `sync` child node. Returns whether the child is new."
        return self._attach_child(parent, "sync", region, frame_index)

    def add_child_a_stutter(self, parent: AcexTreeNode, region: PartitionAssignment, frame_index: int) -> bool:
        "Adds a `a_stutter` child node. Returns whether the child is new."
        return self._attach_child(parent, "a_stutter", region, frame_index)

    def add_child_b_stutter(self, parent: AcexTreeNode, region: PartitionAssignment, frame_index: int) -> bool:
        "Adds a `b_stutter` child node. Returns whether the child is new."
        return self._attach_child(parent, "b_stutter", region, frame_index)

    def give_up_children(self, parent: AcexTreeNode) -> list[AcexTreeNode]:
        """
        Gives up all children from the parent node. Returns the list of children.

        Children are returned in the order a_stutter, b_stutter, sync. Each one
        no longer lists `parent` among its parents.
        """
        children = []
        for slot in ("a_stutter", "b_stutter", "sync"):
            child = getattr(parent, slot)
            if child is not None:
                children.append(child)
                child.parents.discard(parent)
                setattr(parent, slot, None)
        return children

    def remove_orphan(self, node: AcexTreeNode) -> None:
        """
        Removes an orphan node from the `value_to_node` dict.

        The entry is only removed if it still refers to `node` itself. If the
        region was later re-registered for a different node, that mapping is
        left alone.
        """
        if self.value_to_node.get(node.value) is node:
            del self.value_to_node[node.value]

    def print_tree(self, printer: cl.Printer, fast_slow_mode: bool) -> list[str]:
        """
        Prints the tree as lines, in breadth-first order from the root.

        Each visit gets a fresh index. A node already printed is shown as
        `<index> = <first index>`. A child slot that has not been filled yet is
        shown as `(missing)` instead of crashing.
        """
        lines: list[str] = []
        queue: SimpleQueue[tuple[int, AcexTreeNode]] = SimpleQueue()
        queue.put((0, self.root))
        visited: dict[AcexTreeNode, int] = {}

        next_index = 1
        while not queue.empty():
            index, node = queue.get()
            if node in visited:
                lines.append(f"{index} = {visited[node]}")
                continue
            visited[node] = index
            lines.append(f"{index}: {node.value.to_str(printer)}")
            if node.value.is_bad:
                lines.append("    BAD STATE REACHED.")
                continue
            edges = [("SYNC", node.sync), ("A-STUTTER", node.a_stutter)]
            if not fast_slow_mode:
                edges.append(("B-STUTTER", node.b_stutter))
            next_line = ""
            for label, child in edges:
                if child is None:
                    next_line += f"    {label} -> (missing)"
                    continue
                next_line += f"    {label} -> {next_index}"
                queue.put((next_index, child))
                next_index += 1
            lines.append(next_line)
        return lines


class CexTreeNode:
    """
    A node of the concrete counterexample tree.

    Its children can be shared. The sharing is keyed by prime height, so it
    need not mirror the sharing in the abstract tree.
    """
    value: dict[cl.Node, str]
    "Concrete node assignment (binary strings); `None` until `set_assignment` is called."
    prime_height: tuple[int, int]
    "The prime height of the node (in circuit A and B, respectively)."
    is_bad: bool
    "Whether the node is a bad state."

    def __init__(self, prime_height: tuple[int, int]):
        self.prime_height = prime_height
        self.is_bad = False
        self.value = None  # filled in once a model is available

    def set_assignment(self, value: dict[cl.Node, str]) -> None:
        "Attach the concrete assignment `value` to this node."
        self.value = value

    def print_assignment(self, printer: cl.Printer) -> str:
        "Render each `node = decimal (binary)` pair, comma separated."
        rendered = []
        for node, bits in self.value.items():
            rendered.append(f"{printer.node(node)} = {int(bits, 2)} ({bits})")
        return ", ".join(rendered)


class CexTree:
    """
    Represents a concrete counterexample tree.

    There is at most one node per prime height (A, B). A child of the node at
    (a, b) is the node at (a + delta_a, b + delta_b).
    """
    root: CexTreeNode
    "The root of the tree."
    prime_to_node: dict[tuple[int, int], CexTreeNode]
    "Maps prime heights to tree nodes."

    def __init__(self, root: CexTreeNode):
        self.root = root
        self.prime_to_node = {root.prime_height: root}

    def add_child(self, prime: tuple[int, int]) -> bool:
        """
        Adds a child node at some prime height. Returns whether the child is new.

        If a node already exists at `prime`, it is kept and nothing is created.
        The node itself can be fetched via `prime_to_node[prime]`.
        """
        if prime in self.prime_to_node:  # reuse existing node
            return False
        self.prime_to_node[prime] = CexTreeNode(prime)
        return True

    def print_tree(self, printer: cl.Printer, fast_slow_mode: bool) -> list[str]:
        """
        Prints the tree as lines, ordered by prime height.

        Each node contributes its assignment line, followed by a line listing
        the successors present in the tree (or a bad-state marker).
        """
        lines: list[str] = []
        for prime in sorted(self.prime_to_node):
            cex_node = self.prime_to_node[prime]
            lines.append(f"{prime}: {cex_node.print_assignment(printer)}")
            if cex_node.is_bad:
                lines.append("    BAD STATE REACHED.")
                continue
            prime_a, prime_b = prime
            successors = [("SYNC", (prime_a + 1, prime_b + 1)), ("A-STUTTER", (prime_a, prime_b + 1))]
            if not fast_slow_mode:
                successors.append(("B-STUTTER", (prime_a + 1, prime_b)))
            lines.append("".join(
                f"    {label} -> {target}"
                for label, target in successors
                if target in self.prime_to_node
            ))
        return lines
    

class CexSolver:
    """
    Maintains a SAT solver for cex simulation, given an acex.

    Two product machines share the same Boolector instance:
    - `builder` has the initial relation applied. It is constrained, one
      concrete-tree node at a time, to the regions of the abstract tree.
    - `guider` has no initial relation. For the acex node being extended, it
      is constrained to the node's region at (0, 0) and to its children's
      regions at (1, 1), (0, 1) and, outside fast/slow mode, (1, 0).

    A miter signal is the XNOR of the same literal evaluated in the builder (at
    some prime height) and in the guider (at (0, 0)).
    """
    solver: bt.Boolector
    "The SAT solver."
    fast_slow_mode: bool
    "If `True`, circuit A is assumed to be faster than circuit B."
    acex_tree: AcexTree
    "The abstract counterexample tree."
    cex_tree: CexTree
    "The concrete counterexample tree."
    terms_map: TermsMap
    "The current terms map used by the checker."
    builder: ProductMachine
    "The product machine for building the cex tree."
    guider: ProductMachine
    "The product machine for guiding the build of the cex tree."

    def __init__(self, btor2_lines_a: list[str], btor2_lines_b: list[str], fast_slow_mode: bool,
    acex_tree: AcexTree, terms_map: TermsMap):
        self.fast_slow_mode = fast_slow_mode
        self.acex_tree = acex_tree
        self.cex_tree = CexTree(CexTreeNode((0, 0)))
        self.terms_map = terms_map

        # create and configure the solver:
        self.solver = bt.Boolector()
        self.solver.Set_opt(bt.BTOR_OPT_MODEL_GEN, True)
        self.solver.Set_opt(bt.BTOR_OPT_INCREMENTAL, True)

        # set up the product machines:
        a_height, b_height = self.terms_map.unroll_height
        self.builder = ProductMachine(self.solver, "@A", "@B", btor2_lines_a, btor2_lines_b)
        self.builder.grow(a_height, b_height)
        self.builder.build_ob_eq_signal(0, 0)
        self.builder.apply_initial_relation()

        self.guider = ProductMachine(self.solver, "#A", "#B", btor2_lines_a, btor2_lines_b)
        self.guider.grow(a_height + 1, b_height + 1)
        self.guider.build_ob_eq_signal(0, 0)
        self.guider.build_ob_eq_signal(1, 1)  # sync
        self.guider.build_ob_eq_signal(0, 1)  # fast stutter
        if not self.fast_slow_mode:
            self.guider.build_ob_eq_signal(1, 0)  # slow stutter

    def assert_guider_constraint(self, acex_node: AcexTreeNode):
        """
        Asserts the guider constraint for the given acex node.

        The node's region is asserted at guider prime (0, 0), and each child's
        region at the matching step: sync (1, 1), a_stutter (0, 1), and, unless
        in fast/slow mode, b_stutter (1, 0).
        """
        targets = [(acex_node, 0, 0), (acex_node.sync, 1, 1), (acex_node.a_stutter, 0, 1)]
        if not self.fast_slow_mode:
            targets.append((acex_node.b_stutter, 1, 0))
        for node, prime_a, prime_b in targets:
            for literal in node.value.to_literals(True):
                self.solver.Assert(self.guider.get_literal_by_number(literal, prime_a, prime_b))

    def add_all_nodes_for_circuit_a(self, width_to_node_ids: dict[int, list[cl.Node]], copy_id: int):
        "Adds all nodes for circuit A (at copy `copy_id`) to the miter-term map, keyed by width."
        # NOTE: a narrower variant (only next-state nodes, via `name_to_next_state_id`)
        # was tried previously; all nodes are used at present.
        for local_id, node in self.builder.circuit_a[0].id_to_node.items():
            width_to_node_ids.setdefault(node.width, []).append(cl.make_node_a(copy_id, local_id))

    def add_all_nodes_for_circuit_b(self, width_to_node_ids: dict[int, list[cl.Node]], copy_id: int):
        "Adds all nodes for circuit B (at copy `copy_id`) to the miter-term map, keyed by width."
        # NOTE: a narrower variant (only next-state nodes, via `name_to_next_state_id`)
        # was tried previously; all nodes are used at present.
        for local_id, node in self.builder.circuit_b[0].id_to_node.items():
            width_to_node_ids.setdefault(node.width, []).append(cl.make_node_b(copy_id, local_id))

    def get_inner_pairwise_literals(self, width_to_node_ids: dict[int, list[cl.Node]]) -> list[cl.Literal]:
        """
        Returns a list of inner-pairwise literals for the miter-term map.

        Each width-1 node yields a positive predicate literal. Wider nodes
        yield an equality literal for every unordered pair of the same width.
        """
        literals = []
        for width, node_ids in width_to_node_ids.items():
            if width == 1:
                for node_id in node_ids:
                    literals.append(cl.make_p_lit(node_id, True))
                continue
            for i, node_1 in enumerate(node_ids):
                for node_2 in node_ids[i + 1:]:
                    literals.append(cl.make_e_lit(node_1, node_2, True))
        return literals

    def get_cross_pairwise_literals(self, width_to_node_ids_1: dict[int, list[cl.Node]],
    width_to_node_ids_2: dict[int, list[cl.Node]]) -> list[cl.Literal]:
        """
        Returns a list of cross-pairwise literals for the miter-term maps.

        For every width > 1, one equality literal is produced per (map-1 node,
        map-2 node) pair. Widths absent from the second map contribute nothing.
        """
        literals = []
        for width, node_ids_1 in width_to_node_ids_1.items():
            if width == 1:
                continue
            for node_1 in node_ids_1:
                for node_2 in width_to_node_ids_2.get(width, ()):
                    literals.append(cl.make_e_lit(node_1, node_2, True))
        return literals

    def get_miter_signal_by_literal(self, literal: cl.Literal, prime_a: int, prime_b: int) -> bt.BoolectorNode:
        "Returns the miter signal (builder at (prime_a, prime_b) XNOR guider at (0, 0)) for the given literal."
        builder_node = self.builder.get_literal_by_number(literal, prime_a, prime_b)
        guider_node = self.guider.get_literal_by_number(literal, 0, 0)
        return self.solver.Xnor(builder_node, guider_node)

    def _failed_literals(self, literals: list[cl.Literal], signals: list[bt.BoolectorNode]) -> list[cl.Literal]:
        "Returns the literals whose miter signal is a failed assumption of the last UNSAT query."
        if not signals:
            return []
        failed = self.solver.Failed(*signals)
        # pyboolector's `Failed` returns a single bool (not a list) for one assumption.
        if isinstance(failed, bool):
            failed = [failed]
        return [literal for literal, is_failed in zip(literals, failed) if is_failed]

    def check_miters_or_update_terms_map(self, literals: list[cl.Literal], signals: list[bt.BoolectorNode]) -> bool:
        """
        Use a SAT query to check the miters. Returns `True` iff SAT. Updates terms map if UNSAT.

        `signals[i]` must be the miter of `literals[i]`; callers `Assume` the
        signals right before calling this. On UNSAT, every node of every literal
        whose miter is a failed assumption is added to the terms map, with its
        width taken from the guider.

        Raises:
            RuntimeError: if the solver answers neither SAT nor UNSAT.
        """
        res = self.solver.Sat()
        if res == self.solver.UNSAT:
            # NOTE(review): if no assumption failed, the terms map is left unchanged.
            for literal in self._failed_literals(literals, signals):
                for node in literal[1:]:  # the literal's nodes follow its first entry
                    width = self.guider.get_node_by_number(node, 0, 0).width
                    self.terms_map.add_term(node, width)
            return False

        if res != self.solver.SAT:
            raise RuntimeError(f"Unexpected solver result: {res}")
        return True

    def try_extend_cex_tree(self, prime_a: int, prime_b: int, acex_node: AcexTreeNode) -> bool:
        """
        Tries to extend the cex tree. Returns whether it's successful. If not, mutates the terms map.

        Everything asserted here is scoped by Push/Pop. The solver is popped
        on every exit path, including when an exception propagates.
        """
        self.solver.Push()
        try:
            self.assert_guider_constraint(acex_node)

            ## Phase 1: miters among all nodes of the copies below the unroll height.
            width_to_node_ids_1: dict[int, list[cl.Node]] = {}
            a_height, b_height = self.terms_map.unroll_height
            for copy_id in range(a_height):
                self.add_all_nodes_for_circuit_a(width_to_node_ids_1, copy_id)
            for copy_id in range(b_height):
                self.add_all_nodes_for_circuit_b(width_to_node_ids_1, copy_id)
            literals = self.get_inner_pairwise_literals(width_to_node_ids_1)
            signals = [self.get_miter_signal_by_literal(literal, prime_a, prime_b) for literal in literals]
            self.solver.Assume(*signals)
            if not self.check_miters_or_update_terms_map(literals, signals):
                return False

            ## Phase 2: keep phase-1 miters; add miters involving the top copies.
            self.solver.Assert(*signals)  # DO NOT USE `.Fixate_assumptions`!! BUGGY!!
            width_to_node_ids_2: dict[int, list[cl.Node]] = {}
            self.add_all_nodes_for_circuit_a(width_to_node_ids_2, a_height)
            self.add_all_nodes_for_circuit_b(width_to_node_ids_2, b_height)
            literals = self.get_cross_pairwise_literals(width_to_node_ids_1, width_to_node_ids_2)
            literals.extend(self.get_inner_pairwise_literals(width_to_node_ids_2))
            signals = [self.get_miter_signal_by_literal(literal, prime_a, prime_b) for literal in literals]
            self.solver.Assume(*signals)
            return self.check_miters_or_update_terms_map(literals, signals)
        finally:
            self.solver.Pop()

    def grow_builder(self, prime_a: int, prime_b: int):
        "Grows the builder to the given prime height (plus the unroll height), and adds the ObEq signals"
        a_height, b_height = self.terms_map.unroll_height
        self.builder.grow(a_height + prime_a, b_height + prime_b)
        self.builder.build_ob_eq_signal(prime_a, prime_b)

    def add_builder_constraint(self, region: PartitionAssignment, prime_a: int, prime_b: int):
        "Adds a region constraint to the builder (permanently asserted at the given prime height)."
        literals = region.to_literals(True)
        lit_nodes = [self.builder.get_literal_by_number(literal, prime_a, prime_b) for literal in literals]
        self.solver.Assert(*lit_nodes)

    def finalize_cex_tree(self):
        """
        Finalizes the cex tree by filling the concrete assignments into the cex tree.

        For each node, the state variables of A at copy `prime_a` and of B at
        copy `prime_b` are read from the current model.

        Raises:
            RuntimeError: if the accumulated builder constraints are not satisfiable.
        """
        # kept outside any `assert`: the Sat() call must run to produce the model
        res = self.solver.Sat()
        if res != self.solver.SAT:
            raise RuntimeError(f"Cannot finalize cex tree, unexpected solver result: {res}")
        state_ids_a = self.builder.circuit_a[0].name_to_state_id.values()
        state_ids_b = self.builder.circuit_b[0].name_to_state_id.values()
        for (prime_a, prime_b), tree_node in self.cex_tree.prime_to_node.items():
            assignment: dict[cl.Node, str] = {}
            for _id in state_ids_a:
                node = cl.make_node_a(prime_a, _id)
                assignment[node] = self.builder.circuit_a[prime_a].id_to_node[_id].assignment
            for _id in state_ids_b:
                node = cl.make_node_b(prime_b, _id)
                assignment[node] = self.builder.circuit_b[prime_b].id_to_node[_id].assignment
            tree_node.set_assignment(assignment)

    def build_cex_tree(self) -> bool:
        """
        Builds the concrete counterexample tree. Returns `True` iff the build succeeds.
        - If build succeeds, `self.cex_tree` is set to the built tree.
        - If build fails, `self.terms_map` (which is also checker's) is updated.

        The tree is expanded breadth-first. A prime height reached from several
        parents gets a single concrete node.
        """
        self.add_builder_constraint(self.acex_tree.root.value, 0, 0)
        if self.acex_tree.root.value.is_bad:  # 0-step cex
            self.cex_tree.root.is_bad = True
            self.finalize_cex_tree()
            return True

        queue: SimpleQueue[tuple[int, int, AcexTreeNode]] = SimpleQueue()
        deltas = ([] if self.fast_slow_mode else [(1, 0)]) + [(0, 1), (1, 1)]
        # NOTE(review): the root's children are attached without a
        # `try_extend_cex_tree` check, unlike later nodes — confirm intended.
        for delta_a, delta_b in deltas:
            # call outside the assert so the child is still added under `python -O`
            is_new = self.cex_tree.add_child((delta_a, delta_b))
            assert is_new, f"cex tree already has a node at {(delta_a, delta_b)}"
            acex_child = self.acex_tree.root.child_by_delta((delta_a, delta_b))
            self.grow_builder(delta_a, delta_b)
            self.add_builder_constraint(acex_child.value, delta_a, delta_b)
            queue.put((delta_a, delta_b, acex_child))

        while not queue.empty():
            prime_a, prime_b, acex_node = queue.get()
            if acex_node.value.is_bad:
                self.cex_tree.prime_to_node[(prime_a, prime_b)].is_bad = True
                continue
            for delta_a, delta_b in deltas:
                new_prime_a, new_prime_b = prime_a + delta_a, prime_b + delta_b
                if (new_prime_a, new_prime_b) not in self.cex_tree.prime_to_node:
                    self.grow_builder(new_prime_a, new_prime_b)
            if not self.try_extend_cex_tree(prime_a, prime_b, acex_node):
                return False  # extend failure: terms map has been updated
            for delta_a, delta_b in deltas:
                new_prime_a, new_prime_b = prime_a + delta_a, prime_b + delta_b
                if self.cex_tree.add_child((new_prime_a, new_prime_b)):
                    acex_child = acex_node.child_by_delta((delta_a, delta_b))
                    self.add_builder_constraint(acex_child.value, new_prime_a, new_prime_b)
                    queue.put((new_prime_a, new_prime_b, acex_child))

        self.finalize_cex_tree()
        return True
