import pyboolector as bt
import product_machine as PM
import greedy_product_machine as GPM
from typing import Union, Optional, Iterable
from queue import PriorityQueue, SimpleQueue

# Move types double as list indices (per-move-type lists are indexed by them, in this
# order), so the values must stay 0, 1, 2 with ONLY_A_MOVES last: fast/slow mode drops it.
BOTH_MOVE, ONLY_B_MOVES, ONLY_A_MOVES = 0, 1, 2
"Enums indicating which of the two circuits make a move."
MoveType = int
"One from (BOTH_MOVE, ONLY_B_MOVES, ONLY_A_MOVES)"

class Global:
    """
    Process-wide configuration plus pointers to the shared algorithm objects.

    Attributes are only declared here (no values), so each one must be assigned
    before it is first read.
    """
    # algorithm configurations:
    btor2_lines_a: list[str]
    "The lines of the first circuit's BTOR2 file."
    btor2_lines_b: list[str]
    "The lines of the second circuit's BTOR2 file."
    fast_slow_mode: bool
    "If `True`, circuit A is assumed to be faster than circuit B."

    # product-machine related:
    I: list[PM.Cube]
    "The initial relation."
    P: list[PM.Cube]
    "The safety property (ObEq)."
    state_var_wires: list[PM.Wire]
    "all state variables in the product machine, as wires"

    # just pointers to the actual objects:
    abstract_space: "AbstractSpace"
    "The abstract space used for the algorithm, gradually refined."
    blocking_manager: "BlockingManager"
    "The blocking manager used for the algorithm, keeping track of hard/soft blockers and corners."
    simulator: "Simulator"
    "The simulator for obtaining values of any node, given an initial assignment."
    greedy_runner: "GreedySchedulingRunner"
    "The SAT-based bounded model checker with greedy scheduling."
    obligation_manager: "ObligationManager"
    "The obligation manager, maintaining the obligation forest and the obligation queue."
    frame_manager: "FrameManager"
    "The frame manager, maintaining the frames for the IC3 algorithm."

    def all_move_types(self) -> list[MoveType]:
        """
        List the move types the algorithm considers, in ascending order.

        In fast/slow mode `ONLY_A_MOVES` is omitted. Either way the result is a
        prefix of `[BOTH_MOVE, ONLY_B_MOVES, ONLY_A_MOVES]`, so each entry can
        index a per-move-type list. A new list is returned on every call.
        """
        moves = [BOTH_MOVE, ONLY_B_MOVES]
        if not self.fast_slow_mode:
            moves.append(ONLY_A_MOVES)
        return moves
    
# Module-level singleton shared by every class below. `Global` declares its attributes
# without values, so each one must be assigned before it is first read.
global_ = Global()
"""
Some global vars (but really meant to be "constants")!

It serves as a global "look-up table" to speed up the entire equivalence checking algorithm.
Once a value is set, it should NOT be re-set again.
"""


# A bit-vector value is Boolector's assignment string; an array value is the tuple of
# (index, value) assignment-string pairs produced by `extract_concrete_state`.
Value = Union[str, tuple[tuple[str, str], ...]]
"Represents a value: str for a bitvector, a tuple of (index, value) string pairs for a single-dimension array."
State = list[tuple[PM.Wire, Value]]
"Represents a concrete state, mapping state vars (wires) to their values"

def extract_concrete_state(pm: PM.ProductMachine, ex_prime_a: int, ex_prime_b: int) -> State:
    """
    Extract the concrete state from `pm`'s solver (latest SAT assignment).

    Every wire of `global_.state_var_wires` is resolved at primes
    (`ex_prime_a`, `ex_prime_b`) and its model value recorded, in that order.
    Bit-vector values are kept as strings; array values become a tuple of
    (index, value) pairs.

    Raises:
        RuntimeError: if an assignment is neither a string nor a list.
    """
    state = []
    for wire in global_.state_var_wires:
        node = pm.wire_to_node(wire, ex_prime_a, ex_prime_b)
        # Read the model once; `.assignment` is a solver query, not a plain field.
        assignment = node.assignment
        # NOTE(review): Boolector documents that bit-vector assignment strings may contain
        # 'x' (don't-care) bits; such values cannot be fed back through `Const` in
        # `assume_concrete_starting_state` -- confirm whether normalisation is needed.
        if isinstance(assignment, str):
            # Per Boolector's documentation:
            # "If the queried node is a bit vector, its assignment is represented as string."
            value = assignment
        elif isinstance(assignment, list):
            # Per Boolector's documentation:
            # "If it is an array, its assignment is represented as a list of tuples (index, value)."
            # Stored as a tuple, matching the `Value` alias.
            value = tuple(assignment)
        else:  # other types not yet supported:
            raise RuntimeError(f"Unexpected value type: {assignment}")
        state.append((wire, value))
    return state

def assume_concrete_starting_state(state: State, pm: PM.ProductMachine):
    """
    Add solver assumptions pinning `pm`'s starting state (primes 0, 0) to `state`.

    A bit-vector value is assumed equal to its constant. An array value is
    assumed element-wise: `Read(node, index) == value` for every recorded pair.
    Assumptions only last for the solver's next `Sat()` call.

    Raises:
        TypeError: if a value is neither a bit-vector string nor an array
            assignment (tuple or list of (index, value) pairs).
    """
    solver = pm.solver
    for wire, value in state:
        node = pm.wire_to_node(wire, 0, 0)
        if isinstance(value, str):
            solver.Assume(solver.Const(value) == node)
        elif isinstance(value, (tuple, list)):
            # `extract_concrete_state` stores arrays as tuples, so tuples must be accepted
            # here; plain lists are still allowed for backward compatibility.
            for index, elem in value:
                solver.Assume(solver.Const(elem) == solver.Read(node, solver.Const(index)))
        else:
            # Raise rather than `assert`: asserts are stripped under `python -O`.
            raise TypeError(f"Value type not yet supported: {type(value).__name__}")


# An abstract state has one character per atom of the abstract space, in atom order
# (see `AbstractSpace.cube_as_abstract_state`).
AState = str
"Represents an abstract state, which is a string of 0, 1, and - (don't care)."

def cube_syntactically_contains(container: AState, containee: AState) -> bool:
    """
    Check whether cube `container` syntactically contains cube `containee`.

    Every non-'-' character of `container` must equal the character at the
    same position in `containee`; '-' positions of `container` match anything.
    No solver is consulted, so this is purely syntactic.
    """
    ## TODO: this is a naive solution, may be optimized e.g. with Python's big integer
    ## TODO: this just checks syntactic containment; may need an SMT solver to check semantic containment
    return all(bit == '-' or bit == containee[pos] for pos, bit in enumerate(container))

def dnf_syntactically_contains(container: Iterable[AState], containee: AState) -> bool:
    """
    Check whether some cube of the DNF `container` syntactically contains `containee`.

    Stops at the first containing cube; an empty `container` yields `False`.
    """
    for cube in container:
        if cube_syntactically_contains(cube, containee):
            return True
    return False

def unsat_core_intersection(cubes: Iterable[AState]) -> AState:
    """
    Combine UNSAT cores by conjoining their literals.

    Conjoining the literals intersects the cubes as sets of states. The first
    cube seeds the result; every later cube overwrites each position where it
    is not '-' (so on a conflict, the later cube wins).
    Intended for non-empty input; an empty iterable yields "".
    """
    merged: list[str] = []
    for cube in cubes:
        if not merged:
            merged = list(cube)
            continue
        for pos, ch in enumerate(cube):
            if ch != "-":
                merged[pos] = ch
    return "".join(merged)


class AbstractSpace:
    """
    Predicate-abstraction space over the product machine.

    Each atom owns one column of an abstract-state string. This class
    translates between such strings ('0' / '1' / '-') and literal lists
    (`PM.Cube` / `PM.Clause`) over the atoms.
    """
    atoms: list[PM.Atom]
    "The atoms in the abstract space."
    atom_to_index: dict[PM.Atom, int]
    "The index of each atom in the abstract space."
    relational: list[bool]
    "The i-th element is `True` iff the i-th atom is relational."
    max_height: int
    "The maximum heights of the two circuits, used to grow the product machine."

    def __init__(self):
        self.max_height = 0
        self.relational = []
        self.atom_to_index = {}
        self.atoms = []

    def add_atom(self, atom: PM.Atom):
        """Append `atom` as a new column and update the relational flags and `max_height`."""
        column = len(self.atoms)
        self.atom_to_index[atom] = column
        self.atoms.append(atom)
        self.relational.append(PM.is_relational(atom))
        needed = PM.prime_of_atom(atom) + 1
        if needed > self.max_height:
            self.max_height = needed

    def get_lit(self, i: int, positive: bool) -> PM.Literal:
        """Return the literal `(positive, atom)` for the atom in column `i`."""
        atom = self.atoms[i]
        return (positive, atom)

    def get_cube(self, cube_in: AState) -> PM.Cube:
        """Translate an abstract-state string into a cube: '1' -> positive, '0' -> negative, others skipped."""
        return [self.get_lit(i, ch == '1') for i, ch in enumerate(cube_in) if ch in ('0', '1')]

    def get_negation_of_cube(self, cube_in: AState) -> PM.Clause:
        """Translate an abstract-state string into the clause negating it (every literal flipped)."""
        return [self.get_lit(i, ch == '0') for i, ch in enumerate(cube_in) if ch in ('0', '1')]

    def cube_as_abstract_state(self, cube: PM.Cube) -> AState:
        """Render a cube as an abstract-state string; atoms it does not mention become '-'."""
        columns = ['-' for _ in self.atoms]
        for positive, atom in cube:
            columns[self.atom_to_index[atom]] = '1' if positive else '0'
        return ''.join(columns)

    def I_as_abstract_cubes(self) -> list[AState]:
        """Render every cube of the initial relation `global_.I` as an abstract state."""
        return list(map(self.cube_as_abstract_state, global_.I))

    def P_as_abstract_cubes(self) -> list[AState]:
        """Render every cube of the safety property `global_.P` as an abstract state."""
        return list(map(self.cube_as_abstract_state, global_.P))

    def extract_min_cube(self, pm: PM.ProductMachine, ex_prime_a: int, ex_prime_b: int) -> AState:
        """
        Read every atom's value from `pm`'s latest SAT model as a fully specified string.

        Any assignment other than '1' is recorded as '0'.
        """
        return ''.join(
            '1' if pm.atom_to_node(atom, ex_prime_a, ex_prime_b).assignment == '1' else '0'
            for atom in self.atoms
        )

    def min_cube_to_r_cube(self, min_cube: AState) -> AState:
        """
        Relax a min cube into its minimal R-cube.

        Non-relational columns are copied verbatim. Relational columns keep '1'
        and turn anything else into '-'.
        """
        out = []
        for idx, ch in enumerate(min_cube):
            if not self.relational[idx]:
                out.append(ch)
            elif ch == '1':
                out.append('1')
            else:
                out.append('-')
        return ''.join(out)

    def r_cube_to_min_cube(self, r_cube: AState) -> AState:
        """Turn a minimal R-cube back into a min cube by reading every '-' as '0'."""
        return ''.join('0' if ch == '-' else ch for ch in r_cube)
    

class Simulator:
    """
    A simulator for the product machine, utilizing Boolector's model generation.

    Used to obtain the values of any node, given an initial assignment.
    """
    solver: bt.Boolector
    "The SMT solver used by the simulator."
    pm: PM.ProductMachine
    "The product machine used with the solver."

    def __init__(self):
        self.solver = bt.Boolector()  ## TODO: If it's really slow, may try something like verilog simulator instead
        self.solver.Set_opt(bt.BTOR_OPT_MODEL_GEN, True)  ## TODO: Is 1 enough? May be faster than 2.
        # NOTE(review): Boolector's option docs describe level 1 as "models for asserted
        # expressions only"; the target nodes queried in `simulate` are not themselves
        # asserted or assumed -- confirm their assignments are reliable, or use level 2.
        self.solver.Set_opt(bt.BTOR_OPT_INCREMENTAL, True)  ## TODO: I doubt assuming is faster than from scratch?!
        self.pm = PM.ProductMachine(self.solver, "@A", "@B", global_.btor2_lines_a, global_.btor2_lines_b)

    def simulate(self, state: State, target_atoms: list[PM.Atom]) -> list[Value]:
        """
        Simulate the product machine from `state`, returning each target atom's assignment.

        The product machine is first grown deep enough for every target atom.
        The starting state is then pinned to `state` via assumptions, and the
        raw assignments are returned in the order of `target_atoms`.

        Raises:
            RuntimeError: if the solver does not report SAT.
        """
        for atom in target_atoms:
            atom_height = PM.prime_of_atom(atom) + 1
            self.pm.grow(atom_height, atom_height)
        target_nodes = [self.pm.atom_to_node(atom, 0, 0) for atom in target_atoms]
        assume_concrete_starting_state(state, self.pm)
        res = self.solver.Sat()
        if res != self.solver.SAT:
            # Same message format as every other solver-result check in this module.
            raise RuntimeError("Unexpected SAT result: " + str(res))
        return [node.assignment for node in target_nodes]

    def to_min_cube(self, state: State) -> AState:
        """
        By simulation, convert a concrete state to an abstract state (as min_cube).

        Normalised exactly like `AbstractSpace.extract_min_cube`: '1' stays '1',
        anything else (including a possible don't-care 'x') becomes '0'. The
        result is therefore always a valid min cube.
        """
        atoms = global_.abstract_space.atoms
        values = self.simulate(state, atoms)
        return ''.join('1' if value == '1' else '0' for value in values)
    

class GreedySchedulingRunner:
    """
    A SAT-based bounded model checker with greedy scheduling.

    Two machines share one solver. `pm` is used to pin the starting state
    (concretely or abstractly); `gpm` unrolls the greedy schedule. Their
    initial states are tied together once, at construction.
    """
    solver: bt.Boolector
    "The SAT solver used for this runner."
    pm: PM.ProductMachine
    "The product machine used with the solver (to constrain initial states)."
    gpm: GPM.GreedyProductMachine
    "The greedy product machine used with the solver (to actually run greedy scheduling)"

    def __init__(self):
        solver = bt.Boolector()
        solver.Set_opt(bt.BTOR_OPT_MODEL_GEN, True)
        solver.Set_opt(bt.BTOR_OPT_INCREMENTAL, True)
        self.solver = solver
        self.pm = PM.ProductMachine(solver, "@A", "@B", global_.btor2_lines_a, global_.btor2_lines_b)
        self.gpm = GPM.GreedyProductMachine(solver, "#A", "#B", global_.btor2_lines_a, global_.btor2_lines_b, global_.fast_slow_mode)
        self._connect_initial_states()

    def _connect_initial_states(self):
        """Permanently assert that `pm` and `gpm` start in the same state, per circuit."""
        self._equate_states(self.pm.circuit_a[0], self.gpm.circuit_a[0][0])
        self._equate_states(self.pm.circuit_b[0], self.gpm.circuit_b[0][0])

    def _equate_states(self, pm_circuit, gpm_circuit):
        """Assert equality of every named current-state variable of `pm_circuit` and `gpm_circuit`."""
        for name in pm_circuit.state_names():
            self.solver.Assert(pm_circuit.curr_state_by_name(name) == gpm_circuit.curr_state_by_name(name))

    def _solve(self) -> bool:
        """Solve under the current assumptions: `True` for SAT, `False` for UNSAT; raise otherwise."""
        res = self.solver.Sat()
        if res == self.solver.UNSAT:
            return False
        if res != self.solver.SAT:
            raise RuntimeError("Unexpected SAT result: " + str(res))
        return True

    def run_from_concrete_state(self, n: int, state: State) -> bool:
        """Check whether `state` stays safe for `n` greedy steps (`True` iff no unsafe run exists)."""
        self.gpm.grow(n + 1)
        assume_concrete_starting_state(state, self.pm)
        self.solver.Assume(self.gpm.get_unsafe_signal_within_n(n))
        return not self._solve()

    def run_from_abstract_state(self, n: int, astate: AState) -> tuple[Optional[State], Optional[AState]]:
        """
        Search for a start state inside `astate` that becomes unsafe within `n` greedy steps.

        Returns `(concrete_state, min_cube)` of such a start state, or
        `(None, None)` when none exists.
        """
        space = global_.abstract_space
        height = space.max_height
        self.pm.grow(height, height)
        self.solver.Assume(self.pm.cube_to_node(space.get_cube(astate), 0, 0))
        self.gpm.grow(n + 1)
        self.solver.Assume(self.gpm.get_unsafe_signal_within_n(n))
        if not self._solve():
            return None, None
        return extract_concrete_state(self.pm, 0, 0), space.extract_min_cube(self.pm, 0, 0)


class AcexTreeEdge:
    """
    One transition in the abstract counterexample tree.

    It records which circuit(s) moved, plus the start and end states both as
    concrete states and as min cubes.
    """
    move_type: MoveType
    "Which circuit makes a move."
    start_state: State
    "The concrete state where the edge starts."
    start_min_cube: AState
    "The minimal cube of the start state."
    end_state: State
    "The concrete state where the edge ends."
    end_min_cube: AState
    "The minimal cube of the end state."

    def __init__(self, move_type: MoveType, start_state: State, start_min_cube: AState, end_state: State, end_min_cube: AState):
        self.move_type = move_type
        self.start_state, self.start_min_cube = start_state, start_min_cube
        self.end_state, self.end_min_cube = end_state, end_min_cube

    def __repr__(self):
        fields = (self.move_type, self.start_state, self.start_min_cube, self.end_state, self.end_min_cube)
        return "AcexTreeEdge(" + ", ".join(f"{field}" for field in fields) + ")"


class AcexTreeNode:
    """
    One node of the abstract counterexample tree: a min cube and its outgoing edges.

    A leaf has no children. `height` bounds how many abstract steps it takes
    to reach a bad state from here.
    """
    min_cube: AState
    "The minimal cube of the node."
    children: list[tuple[AcexTreeEdge, 'AcexTreeNode']]
    "The children of the node (if not the leaf), i-th child is of the i-th move type."
    height: int
    "An upper bound on how many steps it takes to (abstractly) reach a bad state."

    def __init__(self, min_cube: AState, children: list[tuple[AcexTreeEdge, 'AcexTreeNode']], height: int):
        self.min_cube, self.children, self.height = min_cube, children, height

    def __repr__(self):
        parts = ["AcexTreeNode(", f"{self.min_cube},", f"{self.children},", f"{self.height}", ")"]
        return "\n".join(parts) + "\n"


class ObligationTreeNode:
    """
    A node of the obligation forest.

    Each node wraps a single `AcexTreeEdge`, so nodes behave more like edges
    than like `AcexTreeNode`s. A node may have any number of children.
    """
    edge: AcexTreeEdge
    "The edge that the node represents. (Its`.end_min_cube`'s r_cube is the obligation.)"
    child_nodes: list['ObligationTreeNode']
    "The children of the node (Can be of any number; don't mess up with `AcexTreeNode`!)"

    def __init__(self, edge: AcexTreeEdge):
        self.child_nodes = []
        self.edge = edge

    def add_child_node(self, child_node: 'ObligationTreeNode'):
        """Append `child_node` as the last child of this node (no duplicate check)."""
        self.child_nodes.append(child_node)

    def __lt__(self, other: 'ObligationTreeNode') -> bool:
        # Lets `(level, node)` tuples be ordered in a PriorityQueue when levels tie:
        # comparing the wrapped edges' ids gives an arbitrary but consistent order.
        return id(other.edge) > id(self.edge)


class ObligationManager:
    """
    Maintains the obligation forest and the obligation queue.

    At any time an obligation node is in one of three situations:
    - queued, with its level recorded in `obligation_to_level`;
    - finished, in `done_obligations`;
    - popped but not yet re-added or finished by the caller.
    """
    obligations: PriorityQueue[tuple[int, ObligationTreeNode]]
    "The obligations to be processed, as a priority queue of (frame index, obligation node); lowest level first."
    done_obligations: list[ObligationTreeNode]
    "The obligations already pushed out of the outermost frame."
    obligation_roots: list[ObligationTreeNode]
    "The roots of the obligation trees (all residing in `I`)."
    obligation_to_level: dict[ObligationTreeNode, int]
    "The level of each obligation (i.e. the frame index). Used for replaying during update."

    def __init__(self):
        self.reset_all()

    def reset_all(self):
        "Clears all obligations."
        self.obligations = PriorityQueue()
        self.done_obligations = []
        self.obligation_roots = []
        self.obligation_to_level = {}

    def add(self, level: int, node: ObligationTreeNode, parent: Optional[ObligationTreeNode]):
        "Pushes an obligation to the priority queue, and adds it to the obligation forest."
        self.re_add(level, node)
        if parent is None:
            self.obligation_roots.append(node)
        else:
            parent.add_child_node(node)

    def is_done(self) -> bool:
        "Returns `True` iff there are no obligations left."
        return self.obligations.empty()

    def pop(self) -> tuple[int, ObligationTreeNode]:
        "Pops the obligation with the lowest level."
        level, node = self.obligations.get()
        self.obligation_to_level.pop(node)
        return level, node

    def re_add(self, level: int, node: ObligationTreeNode):
        "Re-adds an obligation to the priority queue (no-op if it is already queued, at any level)."
        if node in self.obligation_to_level:
            return  # already in the queue, redundant, skip
        self.obligations.put((level, node))
        self.obligation_to_level[node] = level

    def finish(self, node: ObligationTreeNode):
        "Finishes an obligation, moving it to `done_obligations`."
        self.done_obligations.append(node)

    def undone_all(self, level: int):
        "Undoes all done obligations, and put them back to some level."
        for node in self.done_obligations:
            self.re_add(level, node)
        self.done_obligations = []

    def _copy_assign(self, other: 'ObligationManager'):
        "Copies the obligations from another obligation manager."
        self.obligations = other.obligations
        self.done_obligations = other.done_obligations
        self.obligation_roots = other.obligation_roots
        self.obligation_to_level = other.obligation_to_level

    def replay_all(self):
        """
        Rebuild the obligation forest according to the soft blockers and removed schedules.

        The forest is walked from its roots. An obligation whose edge starts in
        a soft blocker, or uses a removed schedule, is dropped together with its
        whole subtree. Every surviving node is re-linked under its surviving
        parent, or kept as a root. Survivors that were queued keep their level;
        all other survivors are recorded as done.
        """
        temp = ObligationManager()
        blocking_manager = global_.blocking_manager
        # Reversed so that roots (and children below) keep their original order.
        stack = [(root, None) for root in reversed(self.obligation_roots)]
        visited = set()
        while stack:
            edge_node, parent = stack.pop()
            if edge_node in visited:
                continue
            visited.add(edge_node)
            edge = edge_node.edge
            if edge.start_min_cube in blocking_manager.soft_blockers:
                continue
            if (edge.move_type, edge.start_min_cube) in blocking_manager.removed_schedules:
                continue
            # Nodes are reused, not copied. Snapshot the old children and reset the list,
            # so the node is re-linked with exactly its surviving children (no duplicates,
            # and no pruned nodes).
            old_children = edge_node.child_nodes
            edge_node.child_nodes = []
            # Link EVERY survivor, done ones included. Otherwise a done root (and its
            # subtree) would disappear from `obligation_roots` and from later forest walks.
            if parent is None:
                temp.obligation_roots.append(edge_node)
            else:
                parent.add_child_node(edge_node)
            level = self.obligation_to_level.get(edge_node)
            if level is not None:
                temp.re_add(level, edge_node)
            else:
                temp.done_obligations.append(edge_node)
            for child_node in reversed(old_children):
                stack.append((child_node, edge_node))
        self._copy_assign(temp)


class BlockingManager:
    """
    Maintains hard/soft blockers and corners.

    Hard blockers are concrete states that reach a bad state within k steps.
    For each new abstract space they are re-abstracted into soft blockers
    (abstract counterexample tree nodes). A soft corner records, per
    (start cube, move type), an edge into an already blocked cube.
    """
    hard_blockers: list[tuple[State, int]]
    "Concrete states(.0) that is k(.1)-bad, i.e. moves into a bad state within k steps"
    soft_blockers: dict[AState, AcexTreeNode]
    "Abstract states that corresponds to an abstract counterexample tree node"
    soft_corners: dict[tuple[AState, MoveType], tuple[AcexTreeEdge, AcexTreeNode]]
    "Abstract states that has an edge to an abstract counterexample tree node"
    I: list[AState]
    "The initial relation in the abstract space."
    removed_schedules: set[tuple[MoveType, AState]]
    "Keeping track of removed schedules so far."

    def __init__(self):
        self.hard_blockers = []
        # The per-abstract-space containers are (re)built by `reset_for_new_abstract_space`.
        # They start empty so that readers such as `FrameManager.add_frame` and
        # `ObligationManager.replay_all` do not hit AttributeError before the first reset.
        self.soft_blockers = {}
        self.soft_corners = {}
        self.I = []
        self.removed_schedules = set()

    def reset_for_new_abstract_space(self) -> Optional[AcexTreeNode]:
        "Resets the soft blockers and corners -- when the abstract space is updated."
        self.soft_blockers = {}
        self.soft_corners = {}
        self.I = global_.abstract_space.I_as_abstract_cubes()
        self.removed_schedules = set()
        self._process_hard_blockers()
        return None  # No "hard corners" to consider for now...

    def _process_hard_blockers(self):
        "converts hard blockers to soft blockers, according to the abstract space"
        for state, height in self.hard_blockers:
            min_cube = global_.simulator.to_min_cube(state)
            if min_cube in self.soft_blockers:
                continue  # keep the first height recorded for this min cube
            tree_node = AcexTreeNode(min_cube, [], height)
            self.soft_blockers[min_cube] = tree_node

    def _merge_to_blocker_if_ready(self, start_min_cube: AState) -> bool:
        "If all edges from the state are already in `soft_corners`, merge it to a blocker; else do nothing."
        if not all((start_min_cube, move_type) in self.soft_corners for move_type in global_.all_move_types()):
            return False
        children = []
        heights = []
        for move_type in global_.all_move_types():
            edge, child_node = self.soft_corners[(start_min_cube, move_type)]
            heights.append(child_node.height)
            children.append((edge, child_node))
        height = max(heights) + 1
        tree_node = AcexTreeNode(start_min_cube, children, height)
        self.soft_blockers[start_min_cube] = tree_node
        return True

    def update_scheduling(self, new_bad_min_cube: AState) -> Optional[AcexTreeNode]:
        """
        Propagate a newly blocked min cube backwards through the obligation forest.

        `new_bad_min_cube` is expected to already be a key of `soft_blockers`
        (it is looked up there). For every obligation edge whose end R-cube
        contains a blocked cube, a soft corner is recorded. If that makes every
        move type of the edge's start cube a corner, the start cube becomes a
        soft blocker and is propagated in turn. Otherwise the edge's schedule
        is removed from all frames.

        Returns the blocker node as soon as a blocked cube lies inside `I`.
        Otherwise replays the obligation forest and returns `None`.
        """
        abstract_space = global_.abstract_space
        obligation_manager = global_.obligation_manager
        queue: SimpleQueue[AState] = SimpleQueue()  # blocked min cubes still to propagate backwards
        queue.put(new_bad_min_cube)
        while not queue.empty():
            cube = queue.get()
            if dnf_syntactically_contains(self.I, cube):
                return self.soft_blockers[cube]
            stack = list(obligation_manager.obligation_roots)
            visited = set()
            while stack:
                edge_node = stack.pop()
                if edge_node in visited:
                    continue
                visited.add(edge_node)
                edge = edge_node.edge
                obligation = abstract_space.min_cube_to_r_cube(edge.end_min_cube)
                if cube_syntactically_contains(obligation, cube):
                    start = edge.start_min_cube
                    self.soft_corners[(start, edge.move_type)] = (edge, self.soft_blockers[cube])
                    if (start not in self.soft_blockers) and self._merge_to_blocker_if_ready(start):
                        queue.put(start)
                    else:
                        self._remove_schedule(edge.move_type, start)
                stack.extend(edge_node.child_nodes)
        obligation_manager.replay_all()
        return None

    def _remove_schedule(self, move_type: MoveType, start_min_cube: AState):
        "Removes the schedule `(move_type, start_min_cube)` from all frames, unless it was already removed."
        key = (move_type, start_min_cube)
        if key in self.removed_schedules:
            # Already asserted into every existing frame; `FrameManager.add_frame` replays it
            # into new frames. (Presumes `self` is `global_.blocking_manager`.)
            return
        global_.frame_manager.remove_schedule_for_all_frames(move_type, start_min_cube)
        self.removed_schedules.add(key)
                

class FrameReachSolver:
    r"""
    Maintains a per-frame solver for frame reachability queries (shaped `SAT(? /\ T /\ F[i]')`.)

    One instance exists per (frame, move type). Primed state is read at
    (`ex_prime_a`, `ex_prime_b`), i.e. one step ahead only for the circuit(s)
    that move.
    """
    solver: bt.Boolector
    "The SAT solver."
    pm: PM.ProductMachine
    "The product machine used with the solver."
    move_type: MoveType
    "Which of the two circuits make a move."
    ex_prime_a: int
    "either 1 if a moves, or 0 otherwise"
    ex_prime_b: int
    "either 1 if b moves, or 0 otherwise"

    def __init__(self, move_type: MoveType):
        # create and configure the solver:
        self.solver = bt.Boolector()
        self.solver.Set_opt(bt.BTOR_OPT_MODEL_GEN, True)
        self.solver.Set_opt(bt.BTOR_OPT_INCREMENTAL, True)

        # set up the product machine:
        self.pm = PM.ProductMachine(self.solver, "@A", "@B", global_.btor2_lines_a, global_.btor2_lines_b)
        # Atoms reach prime `max_height - 1`; the primed copy may add one more step.
        height_to_grow = global_.abstract_space.max_height + 1  # will be fixed throughout an epoch
        self.pm.grow(height_to_grow, height_to_grow)

        # consider the move type:
        self.move_type = move_type
        self.ex_prime_a = 0 if self.move_type == ONLY_B_MOVES else 1
        self.ex_prime_b = 0 if self.move_type == ONLY_A_MOVES else 1

    def assert_clause(self, clause: PM.Clause):
        "Asserts a clause in `F[i]'` to the solver. "
        node = self.pm.clause_to_node(clause, self.ex_prime_a, self.ex_prime_b)
        self.solver.Assert(node)

    def assert_sched_condition(self, condition: PM.Clause):
        "Asserts the scheduling condition (over the unprimed state) to the solver."
        node = self.pm.clause_to_node(condition, 0, 0)
        self.solver.Assert(node)

    def _solve(self) -> bool:
        "Solves under the current assumptions: `True` for SAT, `False` for UNSAT; raises otherwise."
        res = self.solver.Sat()
        if res == self.solver.UNSAT:
            return False
        if res != self.solver.SAT:
            raise RuntimeError("Unexpected SAT result: " + str(res))
        return True

    def _extract_acex_tree_edge(self) -> AcexTreeEdge:
        "Assuming SAT, extracts the start/end state pairs, both concrete and abstract."
        start_state = extract_concrete_state(self.pm, 0, 0)
        start_min_cube = global_.abstract_space.extract_min_cube(self.pm, 0, 0)
        end_state = extract_concrete_state(self.pm, self.ex_prime_a, self.ex_prime_b)
        end_min_cube = global_.abstract_space.extract_min_cube(self.pm, self.ex_prime_a, self.ex_prime_b)
        return AcexTreeEdge(self.move_type, start_state, start_min_cube, end_state, end_min_cube)

    def _is_P_literal(self, lit: PM.Literal) -> bool:
        "Returns `True` iff the literal is in P."
        return any(lit in cube for cube in global_.P)

    def sat_query_for_initial(self) -> Optional[AcexTreeEdge]:
        r"Queries `SAT(I /\ c /\ T /\ F[i]')`, returning the model if SAT."
        self.solver.Assume(self.pm.dnf_to_node(global_.I, 0, 0))
        if not self._solve():
            return None
        return self._extract_acex_tree_edge()

    def sat_query_for_obligation(self, r_cube: AState) -> Union[AcexTreeEdge, AState]:
        r"""
        Queries `SAT(r_cube /\ c /\ T /\ F[i]')`, returning the model if SAT; UNSAT core otherwise.

        The core is a cube over the same columns as `r_cube`. Literals that
        were not needed for UNSAT become '-'. P-literals are asserted rather
        than assumed, so they are always kept in the core; dropping them would
        yield a cube that is not necessarily UNSAT.

        Raises:
            RuntimeError: if the solver reports neither SAT nor UNSAT.
        """
        space = global_.abstract_space
        self.solver.Push()
        try:
            kept_columns = set()  # P-literal columns: asserted, hence always part of the core
            assumed_columns = []
            assumed_nodes = []
            for i, ch in enumerate(r_cube):
                if ch == '-':
                    continue
                lit = space.get_lit(i, ch == '1')
                node = self.pm.lit_to_node(lit, 0, 0)
                if self._is_P_literal(lit):  # P-literals cannot be generalized away
                    self.solver.Assert(node)  # undone by the `Pop()` below
                    kept_columns.add(i)
                else:
                    assumed_columns.append(i)
                    assumed_nodes.append(node)
            if assumed_nodes:
                self.solver.Assume(*assumed_nodes)
            if not self._solve():
                core_columns = set(kept_columns)
                if assumed_nodes:
                    failed = self.solver.Failed(*assumed_nodes)
                    if not isinstance(failed, list):  # pyboolector may return a bare bool for one assumption
                        failed = [failed]
                    core_columns.update(col for col, in_core in zip(assumed_columns, failed) if in_core)
                return "".join(ch if i in core_columns else '-' for i, ch in enumerate(r_cube))
            return self._extract_acex_tree_edge()  # must run before `Pop()`
        finally:
            # Always drop the temporary P-literal assertions, also on the error path.
            self.solver.Pop()

    def sat_query_for_propagation(self, cube: AState) -> bool:
        r"Queries `SAT(cube /\ c /\ T /\ F[i]')`, returning `True` iff SAT."
        self.solver.Assume(self.pm.cube_to_node(global_.abstract_space.get_cube(cube), 0, 0))
        return self._solve()


class Frame:
    """
    One IC3 frame: the cubes whose negations it adds, plus one reachability solver per move type.

    Frames are delta-encoded. A clause stored in `delta_clauses` of F[i] is
    also asserted into every inner frame's solvers (see `FrameManager.add_clause`).
    Every frame is seeded with the negation of I; F[0] additionally receives
    the negation of P (see `FrameManager.reset`).
    """
    delta_clauses: set[AState]
    "The clause (actually CUBE!!) difference between this frame `F[i]` and its outer frame `F[i+1]`."
    reach_solvers: list[FrameReachSolver]
    "The reachability solvers for this frame, i-th element for the i-th move type."

    def __init__(self):
        self.reach_solvers = list(map(FrameReachSolver, global_.all_move_types()))
        self.delta_clauses = set()


class FrameManager:
    """
    Manages the frames for the IC3 algorithm.

    A clause added at index i is asserted into the solvers of F[i] and of
    every inner frame, but stored only in F[i]'s `delta_clauses`.
    """
    frames: list[Frame]
    "A list of frames as in IC3."

    def __init__(self):
        self.frames = []

    def reset(self):
        "Resets the frames for a new epoch."
        self.frames = []
        self.add_frame()
        for cube in global_.abstract_space.P_as_abstract_cubes():  # initialize F[0] as the negated P
            self.add_clause(0, cube)

    def k(self) -> int:
        "Returns the maximal frame index."
        return len(self.frames) - 1

    def get_frame(self, index: int) -> Frame:
        "Returns the frame at the given index."
        return self.frames[index]

    def get_frame_k(self) -> Frame:
        "Returns the frame at the maximal index."
        return self.get_frame(self.k())

    def add_frame(self):
        "Adds a new outermost frame, seeded with the negated I and all removed schedules."
        frame = Frame()
        self.frames.append(frame)
        for cube in global_.abstract_space.I_as_abstract_cubes():  # initialize the new frame as the negated I:
            self.add_clause(self.k(), cube)
        for (move_type, cube) in global_.blocking_manager.removed_schedules:  # replay removed schedules:
            negation = global_.abstract_space.get_negation_of_cube(cube)
            frame.reach_solvers[move_type].assert_sched_condition(negation)

    def add_clause(self, index: int, cube: AState):
        """
        Add the clause `not cube` to F[index] and all its inner frames.

        Any delta cube of those frames that `cube` syntactically contains is
        dropped as redundant. `cube` is then recorded in F[index]'s delta only.
        """
        clause = global_.abstract_space.get_negation_of_cube(cube)  # loop-invariant: build it once
        for i in range(index, -1, -1):
            frame = self.frames[i]
            for solver in frame.reach_solvers:
                solver.assert_clause(clause)
            frame.delta_clauses = {c for c in frame.delta_clauses if not cube_syntactically_contains(cube, c)}
        self.frames[index].delta_clauses.add(cube)

    def propagate_clause(self, index: int, cube: AState):
        "Propagates a clause (negation of a cube) from `F[i]` to `F[i+1]`."
        self.frames[index].delta_clauses.remove(cube)
        self.frames[index + 1].delta_clauses.add(cube)
        clause = global_.abstract_space.get_negation_of_cube(cube)
        for solver in self.frames[index + 1].reach_solvers:
            solver.assert_clause(clause)

    def remove_schedule_for_all_frames(self, move_type: MoveType, cube: AState):
        "Removes the schedule (as a cube), so that the `move_type` is no longer considered in that cube."
        negation = global_.abstract_space.get_negation_of_cube(cube)  # same for every frame
        for frame in self.frames:
            frame.reach_solvers[move_type].assert_sched_condition(negation)
        global_.blocking_manager.removed_schedules.add((move_type, cube))


class AcexCommonParentSolver:
    """
    Used by `AcexRefiner`.
    Checks if there is a common concrete state reaching the target abstract states at the same time.

    (Note: the lifecycle of its object is just one refinement. Beware in case this changes.)
    """
    solver: bt.Boolector
    "The SAT solver used for this solver."
    pm: PM.ProductMachine
    "The product machine used with the solver."

    def __init__(self):
        solver = bt.Boolector()
        solver.Set_opt(bt.BTOR_OPT_MODEL_GEN, True)
        solver.Set_opt(bt.BTOR_OPT_INCREMENTAL, True)
        self.solver = solver
        self.pm = PM.ProductMachine(solver, "@A", "@B", global_.btor2_lines_a, global_.btor2_lines_b)
        depth = global_.abstract_space.max_height + 1
        self.pm.grow(depth, depth)  # NOTE: this being inside `__init__` is one consequence of such lifecycle

    @staticmethod
    def _ex_primes(move_type: MoveType) -> tuple[int, int]:
        """Return `(ex_prime_a, ex_prime_b)`: 1 for each circuit that moves under `move_type`, else 0."""
        return (0 if move_type == ONLY_B_MOVES else 1, 0 if move_type == ONLY_A_MOVES else 1)

    def query_for_reachability(self, start: AState, ends: list[AState]) -> Optional[tuple[State, list[State]]]:
        """
        Look for one concrete state in `start` that reaches `ends[m]` under every move type `m`.

        Returns `(start_state, end_states)` with one end state per move type
        (in `all_move_types()` order), or `None` when no common parent exists.
        """
        space = global_.abstract_space
        self.solver.Assume(self.pm.cube_to_node(space.get_cube(start), 0, 0))
        move_types = global_.all_move_types()
        for move_type in move_types:
            prime_a, prime_b = self._ex_primes(move_type)
            self.solver.Assume(self.pm.cube_to_node(space.get_cube(ends[move_type]), prime_a, prime_b))
        res = self.solver.Sat()
        if res == self.solver.UNSAT:
            return None
        if res != self.solver.SAT:
            raise RuntimeError("Unexpected SAT result: " + str(res))
        start_state = extract_concrete_state(self.pm, 0, 0)
        end_states = [extract_concrete_state(self.pm, *self._ex_primes(move_type)) for move_type in move_types]
        return start_state, end_states


class AcexRefiner:
    """
    Check and refine the abstract counterexample tree.

    `refine` collects the tree nodes (children before parents) and marks the "bad" ones, i.e.
    those from which the greedy runner finds a concrete state reaching a neg-P state. It then
    either reports a concrete counterexample (when the root itself is bad), or picks one type-1 /
    type-2 node and derives new atoms for the abstract space from it.

    (Note: the lifecycle of its object is just one refinement. Beware in case this changes.)
    """
    root: AcexTreeNode
    "The root of the abstract counterexample tree."
    common_parent_solver: AcexCommonParentSolver
    "The solver used to check if there is a common parent for two abstract states."
    ## intermediate results:
    nodes_in_topo_order: list[AcexTreeNode]
    "The nodes in the tree in topological order (children before parents)."
    bad_nodes_to_witness: dict[AcexTreeNode, State]
    "The nodes that are bad (i.e. minimal r-cube containing what moves to a neg-P state)."
    greedy_n: int
    "The bound for the greedy scheduling. (Can be e.g. set to 0, or passed from `Checker`)"
    type_1_nodes: dict[AcexTreeNode, tuple[State, State]]  # TODO: 1-to-many, should be a list instead??
    "The nodes for type-1 refinement: reachability granularity. Bad(.0) v.s. Good(.1)."
    type_2_nodes: dict[AcexTreeNode, list[State]]
    "The nodes for type-2 refinement: scheduling granularity. Each reaches bad for some move_type."
    ## final results:
    cex: list[State]
    "The (concrete) counterexample trace."  # For now, can just give the initial state... will fix
    new_atoms: list[PM.Atom]
    "The new atoms to be added to the circuit."

    def __init__(self, root: AcexTreeNode, greedy_n: int):
        self.root = root
        self.common_parent_solver = AcexCommonParentSolver()
        self.nodes_in_topo_order = []
        self.bad_nodes_to_witness = {}
        self.greedy_n = greedy_n
        self.type_1_nodes = {}
        self.type_2_nodes = {}
        # final results: filled in by `refine`, defined up front so the attributes always exist
        self.cex = []
        self.new_atoms = []

    def _get_nodes_in_topo_order(self, node: AcexTreeNode, in_list: set[AcexTreeNode]) -> set[AcexTreeNode]:
        """
        Get & save all nodes (from leaves) in topological order; note children may be shared!!

        Depth-first: every child is appended to `self.nodes_in_topo_order` before its parent.
        `in_list` holds the nodes already visited, so shared children are appended only once.
        Returns the updated visited set.
        """
        if node in in_list:
            return in_list
        if len(node.children) != 0:
            for move_type in global_.all_move_types():
                _, child = node.children[move_type]
                in_list = self._get_nodes_in_topo_order(child, in_list)
        in_list.add(node)
        self.nodes_in_topo_order.append(node)
        return in_list

    def _filter_bad_nodes(self):
        """
        Filters out the bad nodes in the tree.

        For every node, runs the greedy runner from the node's r-cube at the node's height. When a
        concrete state is found, the node is recorded as bad with that state as witness, and
        `greedy_n` is raised to at least the node's height. Witnesses at height > 0 are also
        registered as hard blockers.
        """
        for node in self.nodes_in_topo_order:
            node_r_cube = global_.abstract_space.min_cube_to_r_cube(node.min_cube)
            state, _ = global_.greedy_runner.run_from_abstract_state(node.height, node_r_cube)
            if state is not None:
                if node.height > 0:
                    global_.blocking_manager.hard_blockers.append((state, node.height))
                self.bad_nodes_to_witness[node] = state
                self.greedy_n = max(self.greedy_n, node.height)

    def _sort_out_type_1_and_type_2_nodes(self):
        """
        Sorts out the type-1 and type-2 nodes in the tree.

        For every node with children:
        - If a child has both a bad witness and a good concrete end state reached along its edge,
          the node becomes type-1 with that (bad, good) pair.
        - Otherwise, consider the case where the node itself is good but every edge's concrete end
          state is bad. The common-parent solver is then asked whether one concrete state of the
          node reaches all child abstract states at once.
          - If it does, any good concrete end state of that parent yields a type-1 (bad, good) pair.
          - If it does not, the node becomes type-2, and the edges' concrete start states are kept.
        """
        for node in self.nodes_in_topo_order:
            bad_end_count = 0
            if len(node.children) == 0:
                continue  # needs to have children
            for move_type in global_.all_move_types():
                edge, child = node.children[move_type]
                if not (global_.greedy_runner.run_from_concrete_state(child.height, edge.end_state)):
                    bad_end_count += 1
                    continue  # needs a good concrete state in child
                if not (child in self.bad_nodes_to_witness):
                    continue  # needs a bad concrete state in child
                self.type_1_nodes[node] = (self.bad_nodes_to_witness[child], edge.end_state)
            if node in self.bad_nodes_to_witness:
                continue  # needs the parent itself to be good
            if bad_end_count < len(global_.all_move_types()):
                continue  # needs all ends to be bad
            ends = []
            for move_type in global_.all_move_types():
                edge, _ = node.children[move_type]
                ends.append(edge.end_min_cube)
            assignment = self.common_parent_solver.query_for_reachability(node.min_cube, ends)
            if assignment is not None:
                _, end_states = assignment
                for move_type in global_.all_move_types():
                    edge, child = node.children[move_type]
                    end_state = end_states[move_type]
                    if global_.greedy_runner.run_from_concrete_state(child.height, end_state):
                        self.type_1_nodes[node] = (edge.end_state, end_state)
            else:
                start_states = []
                for move_type in global_.all_move_types():
                    edge, _ = node.children[move_type]
                    start_states.append(edge.start_state)
                self.type_2_nodes[node] = start_states

    def _refine_type_1(self, node: AcexTreeNode):  # TODO: super naive, should be re-designed (in many ways...)
        """
        Refines on the specified type-1 node.

        Simulates the (bad, good) state pair and scans wires at prime 0, 1, 2, ... (no upper bound).
        It stops at the first of either:
        - a single-bit wire whose value differs between the two states, which becomes the new atom;
        - an A/B wire pair that is equal under the good state but differs under the bad state,
          which becomes an equality atom.
        """
        bad_state, good_state = self.type_1_nodes[node]
        prime = 0
        while True:
            # good value -> indices (into `wires`) of circuit-A / circuit-B wires carrying it
            term_value_to_wire_a: dict[Value, list[int]] = {}
            term_value_to_wire_b: dict[Value, list[int]] = {}
            wires = global_.simulator.pm.get_all_wires([prime], [prime])
            good_values = global_.simulator.simulate(good_state, wires)
            bad_values = global_.simulator.simulate(bad_state, wires)
            for i, wire in enumerate(wires):
                good_value, bad_value = good_values[i], bad_values[i]
                if isinstance(good_value, str) and len(good_value) == 1:  # single-bit
                    if good_value != bad_value:
                        self.new_atoms = [wire]
                        return
                if wire.is_a:
                    this, that = term_value_to_wire_a, term_value_to_wire_b
                else:
                    this, that = term_value_to_wire_b, term_value_to_wire_a
                this.setdefault(good_value, []).append(i)
                for j in that.get(good_value, []):  # equality relation
                    if bad_value != bad_values[j]:
                        self.new_atoms = [PM.EqWirePair(wires[i], wires[j])]
                        return
            prime += 1

    def _refine_type_2(self, node: AcexTreeNode):  # TODO: super naive, should be re-designed
        """
        Refines on the specified type-2 node.

        Simulates the concrete start state recorded for each move type and scans wires at prime
        0, 1, 2, ... (no upper bound). The first single-bit wire whose value is not the same
        across all start states becomes the new atom.
        """
        start_states = self.type_2_nodes[node]
        prime = 0
        while True:
            wires = global_.simulator.pm.get_all_wires([prime], [prime])
            values = []
            for move_type in global_.all_move_types():
                start_state = start_states[move_type]
                values.append(global_.simulator.simulate(start_state, wires))
            for i, wire in enumerate(wires):
                reference = values[0][i]
                if not (isinstance(reference, str) and len(reference) == 1):
                    continue  # only single-bit wires are considered
                # equivalent to the former pairwise 2-/3-way comparison, for any number of move types
                if any(per_state[i] != reference for per_state in values[1:]):
                    self.new_atoms = [wire]
                    return
            prime += 1

    def refine(self) -> bool:
        """
        The main function to refine the abstract counterexample tree. Returns `False` if concrete cex is found.

        On `True`, `self.new_atoms` holds the atoms to add; on `False`, `self.cex` holds the cex.
        Raises `RuntimeError` if no node qualifies for a type-1 or type-2 refinement.
        """
        self._get_nodes_in_topo_order(self.root, set())
        self._filter_bad_nodes()
        if self.root in self.bad_nodes_to_witness:
            self.cex = [self.bad_nodes_to_witness[self.root]]  # only 1 state, will fix
            return False
        self._sort_out_type_1_and_type_2_nodes()
        ## TODO: below is naive, should be re-designed
        if self.type_1_nodes:
            self._refine_type_1(next(iter(self.type_1_nodes)))
        elif self.type_2_nodes:
            self._refine_type_2(next(iter(self.type_2_nodes)))
        else:
            # previously a bare StopIteration escaped here
            raise RuntimeError("Refinement failed: no type-1 or type-2 node in the abstract counterexample tree.")
        return True


class Checker:
    """
    Checks whether two circuits are stuttering equivalent.
    - When equivalent, also returns the invariant.
    - When non-equivalent, also returns the cex trace/tree.
    """
    invariant: list[PM.Cube]
    "The invariant returned (as DNF) when the two circuits are equivalent."
    cex_trace: list[State]
    "The counterexample trace returned when the two circuits are non-equivalent."
    # some hyper-params?
    greedy_n: int
    "The bound for the greedy scheduling."

    def __init__(self, btor2_lines_a: list[str], btor2_lines_b: list[str], fast_slow_mode: bool, greedy_n: int = 10):
        """
        Sets up the shared `global_` state (configuration and fresh managers) for a new check.

        `btor2_lines_a` / `btor2_lines_b` are the lines of the two circuits' BTOR2 files. With
        `fast_slow_mode`, circuit A is assumed to be faster than circuit B, so only the
        `BOTH_MOVE` and `ONLY_B_MOVES` move types are used. `greedy_n` is the initial bound for
        the greedy scheduling.
        """
        global_.btor2_lines_a = btor2_lines_a
        global_.btor2_lines_b = btor2_lines_b
        global_.fast_slow_mode = fast_slow_mode
        global_.abstract_space = AbstractSpace()
        global_.blocking_manager = BlockingManager()
        global_.simulator = Simulator()
        global_.greedy_runner = GreedySchedulingRunner()
        global_.obligation_manager = ObligationManager()
        global_.frame_manager = FrameManager()
        self.invariant = None
        self.cex_trace = None
        self.greedy_n = greedy_n

    def _initial_check(self) -> bool:
        """
        Essentially checks if I /\\ ~P is satisfiable.

        Also stores `I`, `P` and the state-variable wires in `global_`. Returns `True` when
        UNSAT (check passes). Returns `False` when SAT, with the offending initial state stored
        in `self.cex_trace`.
        """
        solver = bt.Boolector()
        solver.Set_opt(bt.BTOR_OPT_MODEL_GEN, True)
        pm = PM.ProductMachine(solver, "@A", "@B", global_.btor2_lines_a, global_.btor2_lines_b)
        global_.I, global_.P = pm.I(), pm.P()
        global_.state_var_wires = pm.all_state_wires()
        node_I = pm.dnf_to_node(global_.I, 0, 0)
        node_P = pm.dnf_to_node(global_.P, 0, 0)
        solver.Assert(node_I)
        solver.Assert(solver.Not(node_P))
        res = solver.Sat()
        if res == solver.UNSAT:  # check pass!
            return True
        if res != solver.SAT:
            raise RuntimeError("Unexpected SAT result: " + str(res))
        # counterexample found, record it:
        state = extract_concrete_state(pm, 0, 0)
        self.cex_trace = [state]
        return False

    def _init_abstract_space(self):
        "Initializes the abstract space, which minimally includes all atoms in I and P."
        # assumes global_.I and global_.P are already set
        # `dict.fromkeys` de-duplicates while keeping first-seen order, so atoms are added in a
        # deterministic order (iterating a plain `set` of objects may vary from run to run).
        unique_atoms = dict.fromkeys(atom for cube in global_.I + global_.P for _, atom in cube)
        for atom in unique_atoms:
            global_.abstract_space.add_atom(atom)

    def _add_initial_obligations(self) -> bool:
        "Adds initial obligations (those starting from `I`) to the priority queue. Returns `True` iff added."
        F_k = global_.frame_manager.get_frame_k()
        added = False
        for solver in F_k.reach_solvers:
            edge = solver.sat_query_for_initial()
            if edge is None:
                continue
            global_.obligation_manager.add(global_.frame_manager.k(), ObligationTreeNode(edge), None)
            added = True
        return added

    def _handle_obligation(self, level: int, node: ObligationTreeNode) -> Optional[AcexTreeNode]:
        """
        Handles an obligation at the given level.

        When a bad state is reached (cases 1-3 below), returns the result of
        `bm.update_scheduling`. Otherwise the obligation is handled in one of three ways:
        - it is finished or re-queued, returning what `om.finish` / `om.re_add` return;
        - new child obligations are added one level lower and it is re-queued at `level`;
        - it is blocked by an UNSAT core intersection and pushed to the next level.
        NOTE(review): `epoch` treats any non-`None` result as an acex, so `om.finish` /
        `om.re_add` are presumably expected to return `None` -- confirm.
        """
        bm, fm, om = global_.blocking_manager, global_.frame_manager, global_.obligation_manager
        edge = node.edge
        # case 1/3 of reaching a bad state:
        if level == 0:
            bm.soft_blockers[edge.end_min_cube] = AcexTreeNode(edge.end_min_cube, [], 0)
            return bm.update_scheduling(edge.end_min_cube)
        # case 2/3 of reaching a bad state:
        obligation = global_.abstract_space.min_cube_to_r_cube(edge.end_min_cube)
        for blocker in bm.soft_blockers:
            if cube_syntactically_contains(obligation, blocker):
                return bm.update_scheduling(blocker)
        # case 3/3 of reaching a bad state:
        new_state, new_blocker = global_.greedy_runner.run_from_abstract_state(self.greedy_n, obligation)
        if new_state is not None:
            bm.hard_blockers.append((new_state, self.greedy_n))
            bm.soft_blockers[new_blocker] = AcexTreeNode(new_blocker, [], self.greedy_n)
            return bm.update_scheduling(new_blocker)
        # if obligation is already pushed out, refind where it is:
        for i in range(fm.k(), level-1, -1):  # level i := k, k-1, ..., level
            if dnf_syntactically_contains(fm.get_frame(i).delta_clauses, edge.end_min_cube):  # out of i!
                return om.finish(node) if i == fm.k() else om.re_add(i + 1, node)
        # if SAT, add obligations:
        F_i = fm.get_frame(level - 1)
        cores = []
        for solver in F_i.reach_solvers:
            witness_or_core = solver.sat_query_for_obligation(obligation)
            if isinstance(witness_or_core, AcexTreeEdge):
                om.add(level - 1, ObligationTreeNode(witness_or_core), node)
            else:
                cores.append(witness_or_core)
        # if UNSAT, push the obligation out:
        if len(cores) == len(F_i.reach_solvers):
            fm.add_clause(level, unsat_core_intersection(cores))
            return om.finish(node) if level == fm.k() else om.re_add(level + 1, node)
        # some query was SAT: revisit this obligation at the same level later
        om.re_add(level, node)
        return None

    def _propagation_phase(self) -> bool:
        "Runs the propagation phase of the IC3 algorithm. Returns `True` iff an invariant is found."
        fm = global_.frame_manager
        for i in range(fm.k()):
            F_i = fm.get_frame(i)
            clauses = list(F_i.delta_clauses)  # copy before mutation (propagate_clause may change it)
            for clause in clauses:
                if not any(solver.sat_query_for_propagation(clause) for solver in F_i.reach_solvers):
                    fm.propagate_clause(i, clause)
            if len(F_i.delta_clauses) == 0:  # inductive invariant found!
                self.invariant = []
                for j in range(i + 1, fm.k() + 1):
                    for clause in fm.get_frame(j).delta_clauses:  ## TODO: simplify the invariant!
                        self.invariant.append(global_.abstract_space.get_cube(clause))
                return True
        return False

    def epoch(self) -> Optional[AcexTreeNode]:
        """
        Runs one epoch of the equivalence checking algorithm.

        Returns an abstract counterexample tree node, or `None` once the propagation phase finds
        an invariant (stored in `self.invariant`).
        """
        bm, fm, om = global_.blocking_manager, global_.frame_manager, global_.obligation_manager
        acex = bm.reset_for_new_abstract_space()
        if acex is not None:
            return acex
        fm.reset()
        om.reset_all()
        while True:  # the IC3 loop:
            om.undone_all(fm.k())
            while True:  # the obligation loop:
                if om.is_done() and not self._add_initial_obligations():
                    break
                level, node = om.pop()  # choose an innermost obligation
                acex = self._handle_obligation(level, node)
                if acex is not None:
                    return acex
            fm.add_frame()
            if self._propagation_phase():
                return None

    def check(self) -> bool:
        "The main function of the equivalence checker. Returns `True` iff the circuits are equivalent."
        if not self._initial_check():
            return False
        self._init_abstract_space()
        while True:  # the "epoch" loop:
            print("new epoch")
            acex = self.epoch()
            if acex is None:  # not even an abstract counterexample; success!
                self.print_invariant()
                return True
            refiner = AcexRefiner(acex, 0)  # perform a refinement
            if not refiner.refine():
                self.cex_trace = refiner.cex
                return False
            # NOTE(review): the refiner starts from bound 0 and only raises it to the heights of bad
            # nodes it finds, so this may lower the user-supplied `greedy_n` -- confirm intended.
            self.greedy_n = refiner.greedy_n
            for atom in refiner.new_atoms:
                global_.abstract_space.add_atom(atom)

    def print_invariant(self):
        "Prints the invariant found (one cube per line)."
        print("Invariant:")
        for cube in self.invariant:
            print(cube)
