from typing import Union, Optional
from torch import Tensor

import math
import random
import pandas as pd
import time

import torch
import torch.nn.functional as F
import torch.optim as optim
import os

import model_adapter as ma

from model_adapter import DTYPE, DEVICE

LOGGER = f"trace_{int(time.time())}.log"  # per-run log file name, stamped with the start time (seconds)


def print_tee(text: str, file=LOGGER):
    """Print ``text`` to stdout and append it, followed by a newline, to ``file``.

    Append mode already creates the file when it does not exist, so no existence check
    is needed. The file is written as UTF-8 so arbitrary text (e.g. model output) cannot
    fail on a non-UTF-8 locale. Closing the file at the end of the ``with`` block flushes it.
    """
    print(text)
    with open(file, "a", encoding="utf-8") as f:
        f.write(text + "\n")


def choose_outliers_by_variance(mlp: ma.LockableMLP, outlier_count: int) -> tuple[list[int], list[float]]:
    """Pick the ``outlier_count`` intermediate neurons whose activations vary the most.

    Reads ``mlp.last_tensor_intermediate`` (recorded by a forward pass with
    ``keeps_last_tensors=True``) and takes its first entry along dim 0, which must be a
    2-D matrix (presumably (tokens, neurons) — TODO confirm). Neurons are ranked by their
    variance over dim 0.

    Returns:
        (indices, variances) of the selected neurons, highest variance first.
    """
    assert mlp.last_tensor_intermediate is not None, "No intermediate tensor found. Please run the model with `keeps_last_tensors=True`."
    activations = mlp.last_tensor_intermediate[0]
    _, neuron_count = activations.shape  # unpacking also enforces a 2-D slice
    assert neuron_count >= outlier_count, "The number of outliers cannot exceed the number of dimensions."

    top = torch.topk(activations.var(dim=0), outlier_count)
    return top.indices.tolist(), top.values.tolist()


def choose_outliers_by_scoring(
    mlp: ma.LockableMLP, outlier_count: int, tau: float
) -> tuple[Optional[list[int]], Optional[list[float]], int, float]:
    """Pick intermediate neurons that feed the MLP's outlier output features.

    1. Output features whose mean |output| exceeds ``tau`` times the mean over all
       features are "feature outliers".
    2. Each intermediate neuron is scored by (sum of |down_proj weight| from that neuron into
       the feature outliers) * (mean |activation| of that neuron).
    3. The ``outlier_count`` highest-scoring neurons are returned.

    Requires ``mlp.last_tensor_intermediate`` and ``mlp.last_tensor_output`` from a forward
    pass with ``keeps_last_tensors=True``; only their first entry along dim 0 is used.
    NOTE(review): indexing ``down_proj.weight`` by output feature assumes an nn.Linear-style
    weight of shape (hidden, intermediate) — confirm in model_adapter.

    Returns:
        (indices, scores, feature_outlier_count, max_ratio), where ``max_ratio`` is the
        largest per-feature mean |output| divided by the overall mean. When no feature
        exceeds the threshold, ``indices`` and ``scores`` are None and the count is 0.
    """
    assert mlp.last_tensor_intermediate is not None, "No intermediate tensor found. Please run the model with `keeps_last_tensors=True`."
    assert mlp.last_tensor_output is not None, "No output tensor found. Please run the model with `keeps_last_tensors=True`."
    matrix_i = mlp.last_tensor_intermediate[0]
    matrix_o = mlp.last_tensor_output[0]
    _, n = matrix_i.shape
    assert n >= outlier_count, "The number of outliers cannot exceed the number of dimensions."

    # Pure analysis: no autograd graph through the (possibly trainable) down_proj weight.
    with torch.no_grad():
        avg_magnitude_o = matrix_o.abs().mean(dim=0)
        mean_magnitude_o = avg_magnitude_o.mean()
        feature_outlier_indices = torch.where(avg_magnitude_o > tau * mean_magnitude_o)[0]
        feature_outlier_count = feature_outlier_indices.shape[0]
        max_ratio = (avg_magnitude_o / mean_magnitude_o).max().item()

        if feature_outlier_count == 0:
            return None, None, 0, max_ratio  # no feature outliers for given tau

        weight_abs_sums = mlp.down_proj.weight[feature_outlier_indices, :].abs().sum(dim=0)
        avg_magnitude_i = matrix_i.abs().mean(dim=0)
        scores = weight_abs_sums * avg_magnitude_i
        outliers = torch.topk(scores, outlier_count)
    return outliers.indices.tolist(), outliers.values.tolist(), feature_outlier_count, max_ratio


def run_on_text(adapter: ma.Adapter, text: str):
    """Run one no-grad forward pass on ``text`` so every MLP records its last tensors.

    ``keeps_last_tensors`` is switched on for every MLP returned by
    ``adapter.get_mlp_list()`` for the duration of the pass and switched off again
    afterwards — also when tokenization or the forward pass raises, so later passes
    do not keep recording tensors.
    """
    mlps = adapter.get_mlp_list()
    for mlp in mlps:
        mlp.keeps_last_tensors = True

    try:
        with torch.no_grad():
            inputs = adapter.tokenizer(text, return_tensors="pt").to(adapter.model.device)
            _ = adapter.model(**inputs)
    finally:
        for mlp in mlps:
            mlp.keeps_last_tensors = False


def hadamard_matrix(n: int):
    """Generate the 2^n x 2^n Sylvester-Hadamard matrix as a PyTorch tensor.

    Built iteratively as H_k = [[1, 1], [1, -1]] (kron) H_{k-1}, starting from H_0 = [[1]],
    using the module-wide DTYPE and DEVICE.

    Raises:
        ValueError: if ``n`` is negative (the construction is undefined there).
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}.")
    base = torch.tensor([[1, 1], [1, -1]], dtype=DTYPE, device=DEVICE)
    matrix = torch.tensor([[1]], dtype=DTYPE, device=DEVICE)
    for _ in range(n):
        matrix = torch.kron(base, matrix)
    return matrix
    
def get_random_rotation_matrix(n: int):
    """Return a random orthogonal 2^n x 2^n matrix: H @ diag(s) / sqrt(2^n).

    H is the Hadamard matrix and s is a vector of independent random signs (+1/-1). Multiplying
    by s broadcasts over the last dimension, i.e. it flips the sign of random columns. The
    result is orthogonal, but its determinant is not guaranteed to be +1.
    """
    hadamard = hadamard_matrix(n)
    column_signs = torch.randint(0, 2, (2**n,), dtype=DTYPE, device=DEVICE) * 2 - 1
    hadamard *= column_signs
    hadamard /= math.sqrt(2**n)
    return hadamard

def get_id_rotation_matrix(n: int):
    """Return the 2^n x 2^n identity — the "no rotation" counterpart of the random rotation."""
    dim = 2 ** n
    return torch.eye(dim, dtype=DTYPE, device=DEVICE)

def get_random_key(size: int):
    """Return ``size`` independent fair coin flips as a list of booleans."""
    outcomes = (True, False)
    return [random.choice(outcomes) for _ in range(size)]

def get_random_permutational_key(size: int, way=None) -> list[int]:
    """Return a random permutation of ``range(size)`` that only shuffles within groups.

    The positions are split into consecutive groups of ``way`` elements (one group spanning
    everything when ``way`` is None). Each group is independently shuffled among its own
    indices, so ``key[g*way:(g+1)*way]`` is always a permutation of that same index range.
    """
    group = size if way is None else way
    assert size % group == 0, "Size must be divisible by way."
    return [
        start + offset
        for start in range(0, size, group)
        for offset in random.sample(range(group), group)
    ]

def get_random_key_with_flipping_count(key: list[bool], flip_count: int) -> list[bool]:
    """Return a new key that differs from ``key`` in exactly ``flip_count`` random positions.

    The input list is not modified.
    """
    assert flip_count <= len(key), "Flip count must be less than or equal to the key size."
    positions_to_flip = set(random.sample(range(len(key)), flip_count))
    return [not bit if pos in positions_to_flip else bit for pos, bit in enumerate(key)]


def apply_random_lock(
    adapter: ma.Adapter, layer_index: int, log2_of_key_size: int, rotates: bool, permutes: bool,
    way: Optional[int] = None
) -> tuple[list[bool], Optional[list[int]]]:
    """Lock ``2**log2_of_key_size`` randomly chosen intermediate neurons of one MLP.

    The MLP at ``layer_index`` is first replaced by a lockable MLP. It is then locked with:
    - a random binary key;
    - a random sign-flipped Hadamard rotation (or the identity when ``rotates`` is False);
    - optionally (``permutes``) a random permutation key grouped by ``way``.

    Returns:
        (key, permutation_key); ``permutation_key`` is None when ``permutes`` is False.
    """
    adapter.replace_nth_mlp_into_lockable_mlp(layer_index)
    mlp = adapter.get_mlp_list()[layer_index]

    key_size = 2 ** log2_of_key_size
    locked_indices = random.sample(range(mlp.intermediate_size), key_size)
    key = get_random_key(key_size)
    make_rotation = get_random_rotation_matrix if rotates else get_id_rotation_matrix
    rotation_matrix = make_rotation(log2_of_key_size)
    permutation_key = get_random_permutational_key(key_size, way) if permutes else None
    # Forward ``way`` to the lock exactly like apply_outlier_lock does, so the MLP is told the
    # group size the permutation key was generated with (None keeps lock's own default).
    mlp.lock(locked_indices, key, rotation_matrix, permutation_key, way)
    return (key, permutation_key)

def apply_outlier_lock(
    adapter: ma.Adapter, layer_index: int, log2_of_key_size: int, rotates: bool, permutes: bool, tau=5.0,
    reference_file="dataset/training_text.txt", offset=0, length=500,
    way: Optional[int] = None
) -> tuple[list[bool], Optional[list[int]]]:
    """Lock the ``2**log2_of_key_size`` highest-scoring outlier neurons of one MLP.

    The MLP at ``layer_index`` is replaced by a lockable MLP. The model is run on lines
    ``[offset, offset + length)`` of ``reference_file`` to record activations, and the neurons
    to lock are chosen with ``choose_outliers_by_scoring(mlp, key_size, tau)``.

    Returns:
        (key, permutation_key); ``permutation_key`` is None when ``permutes`` is False.

    Raises:
        ValueError: if no output feature exceeds the ``tau`` threshold, so there are no
            outlier neurons to lock.
    """
    adapter.replace_nth_mlp_into_lockable_mlp(layer_index)
    mlp = adapter.get_mlp_list()[layer_index]

    # Read the reference slice and close the file before the (slow) forward pass.
    with open(reference_file, "r") as f:
        lines = f.readlines()
    text = "".join(lines[offset : offset + length])
    run_on_text(adapter, text)

    key_size = 2 ** log2_of_key_size
    locked_indices, _, _, max_ratio = choose_outliers_by_scoring(mlp, key_size, tau)
    if locked_indices is None:
        raise ValueError(
            f"No feature outliers found in layer {layer_index} for tau={tau} "
            f"(largest magnitude ratio is {max_ratio:.4f}); try a smaller tau."
        )
    key = get_random_key(key_size)
    rotation_matrix = (get_random_rotation_matrix if rotates else get_id_rotation_matrix)(log2_of_key_size)
    permutation_key = get_random_permutational_key(key_size, way) if permutes else None
    mlp.lock(locked_indices, key, rotation_matrix, permutation_key, way)
    return (key, permutation_key)


def get_mmlu_prompt(question: dict[str, str]) -> str:
    """Format an MMLU record as the question, four lettered options and a trailing "Answer:".

    ``question`` must provide the keys ``prompt``, ``A``, ``B``, ``C`` and ``D``.
    """
    header = f"{question['prompt']}"
    choices = [f"{letter}. {question[letter]}" for letter in "ABCD"]
    return "\n".join([header, *choices, "Answer:"])

class MMLU_Evaluator:
    """Greedy next-token accuracy on (a subset of) an MMLU-style CSV file.

    Each CSV row needs the columns ``prompt``, ``A``-``D`` (used by ``get_mmlu_prompt``)
    and ``answer`` (compared with the model's stripped, decoded top-1 next token).
    """
    questions: list[dict[str, str]]
    question_indices: list[int]  # indices of questions to be evaluated

    def __init__(self, mmlu_file="dataset/mmlu_test.csv"):
        self.questions = pd.read_csv(mmlu_file).to_dict(orient="records")
        # Evaluate every question until a subset is selected.
        self.question_indices = list(range(len(self.questions)))

    def set_random_index_subset(self, subset_size: int):
        """Restrict evaluation to ``subset_size`` distinct questions drawn uniformly at random."""
        self.question_indices = random.sample(range(len(self.questions)), subset_size)

    def evaluate(self, adapter: ma.Adapter) -> float:
        """Return the fraction of selected questions whose greedy next token equals the answer."""
        total = len(self.question_indices)
        hits = 0
        for position, question_index in enumerate(self.question_indices, start=1):
            question = self.questions[question_index]
            print(f"Testing question {position}/{total}...", end="\r")
            encoded = adapter.tokenizer(get_mmlu_prompt(question), return_tensors="pt").to(DEVICE)
            with torch.no_grad():
                last_position_logits = adapter.model(**encoded).logits[:, -1, :]
                best_token_id = torch.argmax(last_position_logits, dim=-1).item()
                decoded = adapter.tokenizer.decode(best_token_id)
                if decoded.strip() == question["answer"]:
                    hits += 1
        return hits / total


class PPL_Evaluator:
    """Token-level perplexity of the model on chunks of a reference text file."""
    reference_texts: list[str]
    text_indices: list[int]  # indices of texts to be evaluated
    skip_first_n_tokens: int  # as the first n tokens usually have high perplexity, we skip them

    def __init__(self, reference_file="dataset/test_text.txt", lines_per_chunk=100, skip_first_n_tokens=10):
        with open(reference_file, "r") as f:
            lines = f.readlines()
        # Chunks of ``lines_per_chunk`` lines. The final chunk (the one that may be shorter
        # than ``lines_per_chunk``) is always dropped.
        self.reference_texts = ["".join(lines[i : i + lines_per_chunk]) for i in range(0, len(lines), lines_per_chunk)][:-1]
        self.text_indices = list(range(len(self.reference_texts)))
        self.skip_first_n_tokens = skip_first_n_tokens

    def set_random_index_subset(self, subset_size: int):
        """Evaluate only ``subset_size`` distinct chunks chosen uniformly at random."""
        self.text_indices = random.sample(range(len(self.reference_texts)), subset_size)

    def evaluate(self, adapter: ma.Adapter) -> float:
        """Return exp(mean next-token cross-entropy) over the selected chunks.

        In each chunk the first ``skip_first_n_tokens`` predictions are excluded. A chunk
        with fewer predictions simply contributes nothing. The loss is computed in float32,
        like ``average_jsd`` and ``average_cross_entropy``, so low-precision model logits do
        not round the per-token losses.

        Raises:
            ValueError: if no prediction is left to score in any selected chunk.
        """
        ## Credit: Thanks Gemini 2.5 Pro for helping design this function!!
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        curr_cross_entropy, curr_token_count = 0.0, 0
        for i, index in enumerate(self.text_indices):
            text = self.reference_texts[index]
            print(f"Testing text {i + 1}/{len(self.text_indices)}...", end="\r")
            inputs = adapter.tokenizer(text, return_tensors="pt").to(DEVICE)
            with torch.no_grad():
                outputs = adapter.model(**inputs)
                # Logits at position t predict the token at position t + 1.
                shift_logits = outputs.logits[..., :-1, :].to(torch.float32)
                shift_labels = inputs['input_ids'][..., 1:]

                # Flatten the tokens
                shift_logits = shift_logits.reshape(-1, shift_logits.size(-1))
                shift_labels = shift_labels.reshape(-1)

                loss = loss_fct(shift_logits, shift_labels)
                # Count exactly the losses that are summed (never a negative count).
                kept = loss[self.skip_first_n_tokens:]
                curr_cross_entropy += kept.sum().item()
                curr_token_count += kept.numel()
        if curr_token_count == 0:
            raise ValueError("No tokens left to evaluate after skipping the first tokens of each text.")
        mean_cross_entropy = curr_cross_entropy / curr_token_count
        perplexity = math.exp(mean_cross_entropy)
        return perplexity


def average_jsd(logits_oracle: Tensor, logits_eval: Tensor) -> Tensor:
    """Mean Jensen-Shannon divergence (in nats) between two sets of predicted distributions.

    Both inputs are logits whose last dimension is the vocabulary. Every leading index is an
    independent position, and the per-position JSD is averaged over all positions. The result
    is a scalar float32 tensor in [0, ln 2] (up to the epsilon clamping below). It is
    differentiable with respect to both inputs.
    """
    ## Credit: Thanks Gemini 2.5 Pro for debugging and fixing this function!!
    # Flatten to [positions, vocab]: ``reduction="batchmean"`` divides by the size of dim 0,
    # which is only the number of positions for 2-D input.
    logits_oracle = logits_oracle.to(torch.float32)
    logits_oracle = logits_oracle.reshape(-1, logits_oracle.shape[-1])
    logits_eval = logits_eval.to(torch.float32)
    logits_eval = logits_eval.reshape(-1, logits_eval.shape[-1])

    probs_oracle = F.softmax(logits_oracle, dim=-1)
    probs_eval = F.softmax(logits_eval, dim=-1)

    # Keep probabilities away from 0 so the logs below stay finite.
    epsilon_probs = 1e-12
    probs_oracle_clamped = torch.clamp(probs_oracle, min=epsilon_probs)
    probs_eval_clamped = torch.clamp(probs_eval, min=epsilon_probs)
    mean_probs = (probs_oracle_clamped + probs_eval_clamped) / 2

    epsilon_log = 1e-12
    log_mean_probs = (mean_probs + epsilon_log).log()

    # kl_div(input=log M, target=P) computes KL(P || M).
    kl_oracle = F.kl_div(log_mean_probs, probs_oracle_clamped, reduction="batchmean", log_target=False)
    kl_eval = F.kl_div(log_mean_probs, probs_eval_clamped, reduction="batchmean", log_target=False)

    jsd = (kl_oracle + kl_eval) / 2
    return jsd

class JSD_LossEvaluator:
    """Loss = JSD between the model's logits and reference logits recorded with the correct keys.

    The reference texts are the first ``chunk_count`` chunks of ``lines_per_chunk`` lines of
    ``reference_file``. ``set_reference_logits`` must be called once before ``get_loss`` or
    ``evaluate``.
    """
    reference_texts: list[str]
    reference_logits: list[Tensor]  # one logits tensor per reference text (filled by set_reference_logits)

    def __init__(self, reference_file="dataset/training_text.txt", lines_per_chunk=100, chunk_count=24):
        with open(reference_file, "r") as f:
            lines = f.readlines()
        assert len(lines) >= lines_per_chunk * chunk_count, "Not enough lines in the reference file."
        self.reference_texts = ["".join(lines[i : i + lines_per_chunk]) for i in range(0, lines_per_chunk * chunk_count, lines_per_chunk)]
        self.reference_logits = []

    def set_reference_logits(self, adapter: ma.Adapter, layer_index: int):
        """Record the model's logits on every reference text with the correct keys in place.

        The MLP at ``layer_index`` must be locked. Its runtime key (and permutation key, if
        one exists) is set to the correct value and both key modules are disabled while the
        logits are computed. Afterwards the two ``*_enabled`` flags are restored, also if a
        forward pass raises; the runtime keys stay at the correct values.
        """
        assert len(self.reference_logits) == 0, "Reference logits already set. Please clear them before setting again."
        mlp = adapter.get_mlp_list()[layer_index]
        assert mlp.locked, "Please lock the MLP first."
        mlp.set_runtime_key(mlp.correct_key)
        saved_key_module_enabled = mlp.key_module_enabled
        mlp.key_module_enabled = False
        if mlp.correct_permutation_key is not None:
            mlp.set_runtime_permutation(mlp.correct_permutation_key)
        saved_permutation_key_enabled = mlp.permutation_key_module_enabled
        mlp.permutation_key_module_enabled = False

        try:
            for i, text in enumerate(self.reference_texts):
                print(f"Getting logits for text {i + 1}/{len(self.reference_texts)}...", end="\r")
                inputs = adapter.tokenizer(text, return_tensors="pt").to(DEVICE)
                with torch.no_grad():
                    outputs = adapter.model(**inputs)
                    logits = outputs.logits[0]
                    self.reference_logits.append(logits)
        finally:
            mlp.key_module_enabled = saved_key_module_enabled
            mlp.permutation_key_module_enabled = saved_permutation_key_enabled

    def get_loss(self, adapter: ma.Adapter, text_index: int, sgd: bool) -> Tensor:
        """Return the JSD loss on reference text ``text_index``.

        An autograd graph is built only when ``sgd`` is True; otherwise the forward pass
        runs under ``torch.no_grad()``.
        """
        assert len(self.reference_logits) == len(self.reference_texts), "Reference logits length mismatch."
        text = self.reference_texts[text_index]
        inputs = adapter.tokenizer(text, return_tensors="pt").to(DEVICE)
        reference_logits = self.reference_logits[text_index]

        def calc_loss() -> Tensor:
            outputs = adapter.model(**inputs)
            logits = outputs.logits[0]
            loss = average_jsd(reference_logits, logits)
            return loss
        if sgd:
            loss = calc_loss()
        else:
            with torch.no_grad():
                loss = calc_loss()
        return loss

    def evaluate(self, adapter: ma.Adapter) -> float:
        """Return the mean JSD loss over all reference texts (no gradients)."""
        total_loss = 0.0
        for i in range(len(self.reference_texts)):
            print(f"Testing text {i + 1}/{len(self.reference_texts)}...", end="\r")
            loss = self.get_loss(adapter, i, False)
            total_loss += loss.item()
        average_loss = total_loss / len(self.reference_texts)
        return average_loss
    
    
def average_cross_entropy(reference_tokens: Tensor, logits_eval: Tensor) -> Tensor:
    """Mean cross-entropy of ``logits_eval`` against the target ids ``reference_tokens``.

    Args:
        reference_tokens: target token ids, shape ``[seq_len]``.
        logits_eval: logits, shape ``[seq_len, vocab_size]``; upcast to float32 first.

    Returns:
        A scalar float32 tensor, differentiable with respect to ``logits_eval``.
    """
    upcast_logits = logits_eval.float()
    return F.cross_entropy(upcast_logits, reference_tokens)

class CrossEntropy_LossEvaluator:
    """Loss = next-token cross-entropy of the model against fixed reference token sequences.

    The reference texts are the first ``chunk_count`` chunks of ``lines_per_chunk`` lines of
    ``reference_file``. Optionally they are replaced with text sampled from the unlocked model
    (``modify_reference_texts_with_oracle``). ``set_reference_tokens`` must be called before
    ``get_loss``.
    """
    reference_texts: list[str]
    reference_tokens: list[Tensor]  # token ids of each reference text (filled by set_reference_tokens)

    def __init__(self, reference_file="dataset/training_text.txt", lines_per_chunk=100, chunk_count=24):
        with open(reference_file, "r") as f:
            lines = f.readlines()
        assert len(lines) >= lines_per_chunk * chunk_count, "Not enough lines in the reference file."
        self.reference_texts = ["".join(lines[i : i + lines_per_chunk]) for i in range(0, lines_per_chunk * chunk_count, lines_per_chunk)]
        self.reference_tokens = []

    def modify_reference_texts_with_oracle(self, adapter: ma.Adapter, layer_index: int, prompt_length=10, temperature=0.7):
        """Replace every reference text with a sample from the model with the correct keys.

        Each text's first ``prompt_length`` tokens are used as the prompt. Generation samples
        with ``temperature`` and ``top_p=0.9`` up to the text's original token length.
        During generation the MLP at ``layer_index`` (which must be locked) runs with its
        correct runtime key/permutation and both key modules disabled. The two ``*_enabled``
        flags are restored afterwards, also if generation raises.
        """
        mlp = adapter.get_mlp_list()[layer_index]
        assert mlp.locked, "Please lock the MLP first."
        mlp.set_runtime_key(mlp.correct_key)
        saved_key_module_enabled = mlp.key_module_enabled
        mlp.key_module_enabled = False
        if mlp.correct_permutation_key is not None:
            mlp.set_runtime_permutation(mlp.correct_permutation_key)
        saved_permutation_key_enabled = mlp.permutation_key_module_enabled
        mlp.permutation_key_module_enabled = False

        try:
            for i, text in enumerate(self.reference_texts):
                print(f"Generating text {i + 1}/{len(self.reference_texts)}...", end="\r")
                inputs = adapter.tokenizer(text, return_tensors="pt").to(DEVICE)
                targeted_length = inputs["input_ids"].shape[1]
                prompt = inputs["input_ids"][0][:prompt_length]
                outputs = adapter.model.generate(
                    prompt.unsqueeze(0),
                    do_sample=True,
                    max_length=targeted_length,
                    temperature=temperature,
                    top_k=0,
                    top_p=0.9,
                    num_return_sequences=1,
                )
                self.reference_texts[i] = adapter.tokenizer.decode(outputs[0], skip_special_tokens=True)

                # sanity check:
                print(f"Oracle generated text {i + 1}/{len(self.reference_texts)}: {self.reference_texts[i]}")
        finally:
            mlp.key_module_enabled = saved_key_module_enabled
            mlp.permutation_key_module_enabled = saved_permutation_key_enabled

    def set_reference_tokens(self, adapter: ma.Adapter):
        """Tokenize every reference text once and store the token ids."""
        assert len(self.reference_tokens) == 0, "Reference tokens already set. Please clear them before setting again."
        for text in self.reference_texts:
            tokens = adapter.tokenizer(text, return_tensors="pt").to(DEVICE)["input_ids"][0]
            self.reference_tokens.append(tokens)

    def get_loss(self, adapter: ma.Adapter, text_index: int, sgd: bool, skip_first_n_tokens=10) -> Tensor:
        """Return the mean next-token cross-entropy on reference text ``text_index``.

        The predictions for the first ``skip_first_n_tokens`` positions are ignored. An
        autograd graph is built only when ``sgd`` is True. NOTE(review): a text with at most
        ``skip_first_n_tokens + 1`` tokens leaves nothing to score (mean over an empty
        selection).
        """
        assert len(self.reference_tokens) == len(self.reference_texts), "Reference tokens length mismatch."
        text = self.reference_texts[text_index]
        inputs = adapter.tokenizer(text, return_tensors="pt").to(DEVICE)
        reference_tokens = self.reference_tokens[text_index]

        def calc_loss() -> Tensor:
            outputs = adapter.model(**inputs)
            # Logits at position t predict token t + 1, hence the one-position shift.
            logits = outputs.logits[0][skip_first_n_tokens:-1]
            tokens = reference_tokens[skip_first_n_tokens+1:]
            loss = average_cross_entropy(tokens, logits)
            return loss
        if sgd:
            loss = calc_loss()
        else:
            with torch.no_grad():
                loss = calc_loss()
        return loss


def evaluate_specific_key(
    adapter: ma.Adapter, mlp: ma.LockableMLP, runtime_key: list[bool], runtime_permutation_key: Optional[list[int]],
    mmlu_evaluator: MMLU_Evaluator, ppl_evaluator: PPL_Evaluator,
    enforce_binary=True,
) -> tuple[float, float, float, Optional[float]]:
    """Measure MMLU accuracy and perplexity of the model with the given key(s) in ``mlp``.

    When ``enforce_binary`` is True:
    - both key modules are disabled;
    - ``runtime_key`` (and ``runtime_permutation_key``, if usable) are installed as the
      runtime keys for the evaluation;
    - the two ``*_enabled`` flags are restored afterwards, also on error.

    When it is False, the MLP is evaluated in its current state and the keys are only used
    to compute the fidelities.

    If ``runtime_permutation_key`` is None, the MLP's current runtime permutation is left
    untouched and no permutation fidelity is computed.

    Returns:
        (fidelity, mmlu_accuracy, perplexity, permutation_fidelity). ``permutation_fidelity``
        is None when the MLP has no permutation key or none is given.
    """
    # The permutation is installed and scored only when both sides exist.
    scores_permutation = mlp.correct_permutation_key is not None and runtime_permutation_key is not None

    if enforce_binary:
        saved_key_module_enabled = mlp.key_module_enabled
        mlp.key_module_enabled = False
        saved_permutation_key_enabled = mlp.permutation_key_module_enabled
        mlp.permutation_key_module_enabled = False
        if scores_permutation:
            mlp.set_runtime_permutation(runtime_permutation_key)
        mlp.set_runtime_key(runtime_key)

    try:
        key_size = len(runtime_key)
        fidelity = sum(1 for i in range(key_size) if runtime_key[i] == mlp.correct_key[i]) / key_size
        mmlu = mmlu_evaluator.evaluate(adapter)
        ppl = ppl_evaluator.evaluate(adapter)

        permutation_fidelity = None
        if scores_permutation:
            p_key_size = len(runtime_permutation_key)
            permutation_fidelity = sum(
                1 for i in range(p_key_size) if runtime_permutation_key[i] == mlp.correct_permutation_key[i]
            ) / p_key_size
    finally:
        if enforce_binary:
            mlp.key_module_enabled = saved_key_module_enabled
            mlp.permutation_key_module_enabled = saved_permutation_key_enabled
    return fidelity, mmlu, ppl, permutation_fidelity


class GA_Parameters:
    """Hyper-parameters of GA_Attacker, read as class attributes."""
    POPULATION_SIZE: int = 50  # individuals per generation
    MUTATION_RATE: float = 0.01  # per-position probability of an in-group swap
    CROSSOVER_RATE: float = 0.8  # probability that a parent pair is crossed over
    TOURNAMENT_SIZE: int = 5  # contestants per tournament selection
    CUTOFF: float = 1e-5  # stop early once the best fitness (JSD) drops below this

class GA_Attacker:
    """Genetic-algorithm attack that searches for the permutation key of one locked MLP.

    Individuals are candidate permutation keys; the binary key is not searched (the correct
    key is used for evaluation). Fitness is the JSD between the model's logits under a
    candidate permutation and reference logits recorded with the correct keys, so lower is
    better. Crossover and mutation only move neurons within groups of ``way`` positions, so
    every individual stays a valid way-grouped permutation.
    """
    adapter: ma.Adapter
    mlp: ma.LockableMLP

    mmlu_evaluator: MMLU_Evaluator
    ppl_evaluator: PPL_Evaluator
    jsd_loss_evaluator: JSD_LossEvaluator

    way: int  # group size of the permutational key, must be greater than 0

    def __init__(self, adapter: ma.Adapter, layer_index: int, reference_file="dataset/training_text.txt",
                 lines_per_chunk=100, chunk_count=10,
                 way=None):
        """Prepare the attack on the MLP at ``layer_index``.

        Args:
            way: group size of the permutation key. None (the default) means a single group
                spanning the whole key, matching ``get_random_permutational_key``.

        Raises:
            ValueError: if the MLP was locked without a permutation key.
        """
        self.adapter = adapter
        self.mlp = adapter.get_mlp_list()[layer_index]
        # Validate before the expensive evaluator / reference-logit setup below.
        if self.mlp.correct_permutation_key is None:
            raise ValueError("Should use a p-key")
        if way is None:
            way = len(self.mlp.correct_permutation_key)
        assert way > 0, "Way must be greater than 0."
        self.way = way

        self.mlp.key_module_enabled = False
        self.mlp.permutation_key_module_enabled = False

        self.mmlu_evaluator = MMLU_Evaluator()
        self.mmlu_evaluator.set_random_index_subset(256)
        self.ppl_evaluator = PPL_Evaluator("dataset/test_text.txt")

        self.jsd_loss_evaluator = JSD_LossEvaluator(reference_file, lines_per_chunk, chunk_count)
        self.jsd_loss_evaluator.set_reference_logits(adapter, layer_index)

    def calculate_fitness(self, runtime_permutation: list[int], text_index: int) -> float:
        "Calculates fitness by JSD on one reference text (lower is better)."
        self.mlp.set_runtime_permutation(runtime_permutation)
        loss = self.jsd_loss_evaluator.get_loss(self.adapter, text_index, False)
        return loss.item()

    def initialize_population(self) -> list[list[int]]:
        "Creates the initial population of random way-grouped permutation keys."
        key_size = len(self.mlp.correct_permutation_key)
        return [get_random_permutational_key(key_size, self.way) for _ in range(GA_Parameters.POPULATION_SIZE)]

    def evaluate_population(self, population: list[list[int]]) -> list[float]:
        "Calculates the fitness (JSD) of each individual on one randomly chosen reference text."
        fitnesses = []
        text_index = random.randint(0, len(self.jsd_loss_evaluator.reference_texts) - 1)  # Randomly select a text index for evaluation
        for i, key in enumerate(population):
            print(f"Evaluating individual {i + 1}/{len(population)}...", end="\r")
            fitness = self.calculate_fitness(key, text_index)
            fitnesses.append(fitness)
        return fitnesses

    def tournament_selection(self, population: list[list[int]], fitnesses: list[float]) -> list[int]:
        "Selects a parent using tournament selection (minimization)."
        selection_ix = random.randint(0, len(population) - 1)
        for _ in range(GA_Parameters.TOURNAMENT_SIZE - 1):
            ix = random.randint(0, len(population) - 1)
            # Keep the individual with the lower JSD (better fitness)
            if fitnesses[ix] < fitnesses[selection_ix]:
                selection_ix = ix
        return population[selection_ix]

    def single_point_crossover(self, parent1: list[int], parent2: list[int]) -> tuple[list[int], list[int]]:
        "Performs single-point crossover between two parents (new lists are always returned)."
        # Only cross at way boundaries, so every group is taken whole from one parent.
        if random.random() < GA_Parameters.CROSSOVER_RATE:
            point = (random.randint(1, len(parent1) - 1) // self.way) * self.way
            child1 = parent1[:point] + parent2[point:]
            child2 = parent2[:point] + parent1[point:]
            return child1, child2
        else:
            return parent1.copy(), parent2.copy()

    def mutate(self, key: list[int]) -> list[int]:
        "Mutates ``key`` in place by random swaps within way groups and returns it."
        for i in range(len(key)):
            group_id, neuron_id = divmod(i, self.way)
            another_neuron_id = random.randint(0, self.way - 1)
            j = group_id * self.way + another_neuron_id
            if random.random() < GA_Parameters.MUTATION_RATE:
                key[i], key[j] = key[j], key[i]  # swap two neurons within the same group
        return key

    def main_loop(self, generations: int, timeout_seconds: Optional[float] = None) -> tuple[list[bool], Optional[list[int]], float]:
        """Run the genetic algorithm.

        Stops after ``generations`` generations, when the best fitness drops below
        ``GA_Parameters.CUTOFF``, or once ``timeout_seconds`` have elapsed.

        Returns:
            (correct binary key, best permutation key found, its fitness). The fitness is a
            JSD sampled on a single reference text.
        """
        # 1. Initialization
        population = self.initialize_population()
        best_key = None
        best_fitness = float('inf')
        print_tee(f"Starting GA with Population Size={GA_Parameters.POPULATION_SIZE}, Generations={generations}")
        print_tee("-" * 30)

        start_time = time.time()
        for generation in range(generations):
            # 2. Evaluation
            fitnesses = self.evaluate_population(population)

            # Track best individual found so far
            current_best_fitness_idx = fitnesses.index(min(fitnesses))
            current_best_fitness = fitnesses[current_best_fitness_idx]
            current_best_key = population[current_best_fitness_idx].copy()
            if current_best_fitness < best_fitness:
                best_fitness = current_best_fitness
                best_key = current_best_key

            # Optional: Print progress
            current_fidelity = sum(1 for i in range(len(current_best_key)) if current_best_key[i] == self.mlp.correct_permutation_key[i]) / len(current_best_key)
            curr_acc_time = time.time() - start_time
            print_tee(f"Generation {generation + 1}/{generations}, Time Elapsed = {curr_acc_time:.2f}s, "
                      f"Current Best Fitness (JSD) = {current_best_fitness:.4f}, Fidelity = {current_fidelity:.4f}")

            if (generation + 1) % 10 == 0:  # start an evaluation every 10 generations
                _, eval_mmlu, eval_ppl, _ = evaluate_specific_key(
                    self.adapter, self.mlp, self.mlp.correct_key, current_best_key,
                    self.mmlu_evaluator, self.ppl_evaluator
                )
                print_tee(f"Evaluation at Generation {generation + 1}: MMLU = {eval_mmlu:.4f}, PPL = {eval_ppl:.4f}")

            if best_fitness < GA_Parameters.CUTOFF:
                print_tee(f"[Early stopping at Generation {generation + 1} with fitness {best_fitness:.4f} smaller than cutoff {GA_Parameters.CUTOFF}.]")
                break

            # 3. Selection
            selected_parents = [self.tournament_selection(population, fitnesses) for _ in range(GA_Parameters.POPULATION_SIZE)]
            # 4. Crossover & Mutation (Create next generation)
            next_population = []
            for i in range(0, GA_Parameters.POPULATION_SIZE, 2):
                # Ensure we don't go out of bounds if POPULATION_SIZE is odd
                if i + 1 >= GA_Parameters.POPULATION_SIZE:
                    # If odd population size, just mutate the last selected parent
                    child = self.mutate(selected_parents[i].copy())
                    next_population.append(child)
                    continue

                parent1, parent2 = selected_parents[i], selected_parents[i + 1]
                # 5. Crossover
                child1, child2 = self.single_point_crossover(parent1, parent2)
                # 6. Mutation
                self.mutate(child1)
                self.mutate(child2)
                next_population.extend([child1, child2])

            # 7. Replace the old population with the new one
            population = next_population

            time_elapsed = time.time() - start_time
            if timeout_seconds and time_elapsed > timeout_seconds:
                print_tee(f"[Timeout after {timeout_seconds} seconds. Stopping GA.]")
                break

        # Final "quick evaluation" of the best key found:
        print_tee("Final evaluation of the best key found...")
        eval_fidelity, eval_mmlu, eval_ppl, _ = evaluate_specific_key(
            self.adapter, self.mlp, self.mlp.correct_key, best_key,
            self.mmlu_evaluator, self.ppl_evaluator
        )
        print_tee(f"Final Evaluation: MMLU (sampled, not accurate) = {eval_mmlu:.4f}, PPL = {eval_ppl:.4f}, Fidelity = {eval_fidelity:.4f}")
        print_tee(f"JSD (sampled, not accurate) with best key: {best_fitness:.4f}")
        return self.mlp.correct_key, best_key, best_fitness


class SGD_Parameters:
    """Hyper-parameters of SGD_Attacker, read as class attributes."""
    LEARNING_RATE: float = 0.03  # Adam learning rate
    CUTOFF: float = 1e-5  # stop early once the mean epoch loss drops below this

class SGD_Attacker:
    """Gradient-based attack that learns the key and/or permutation key of one locked MLP.

    During training, ``mlp.key_module`` / ``mlp.permutation_key_module`` are enabled
    according to ``uses_key`` / ``uses_permutation_key``. The model is optimized with Adam
    against either a JSD loss (oracle logits) or a cross-entropy loss.
    """
    adapter: ma.Adapter
    mlp: ma.LockableMLP

    mmlu_evaluator: MMLU_Evaluator
    ppl_evaluator: PPL_Evaluator
    loss_evaluator: Union[JSD_LossEvaluator, CrossEntropy_LossEvaluator]

    uses_key: bool  # enable and train mlp.key_module
    uses_permutation_key: bool  # enable and train mlp.permutation_key_module
    enforces_binary: bool  # final evaluation with discretized keys and key modules disabled

    def __init__(self,
        adapter: ma.Adapter, layer_index: int,
        uses_key: bool, uses_permutation_key: bool, enforces_binary: bool,
        reference_file="dataset/training_text.txt", lines_per_chunk=100, chunk_count=10,
        uses_cross_entropy_loss_instead=False,
        oracle_guided_cross_entropy_loss=False,
    ):
        assert uses_key or uses_permutation_key, "At least one of the keys must be used."
        self.uses_key = uses_key
        self.uses_permutation_key = uses_permutation_key
        self.enforces_binary = enforces_binary

        self.adapter = adapter
        self.mlp = adapter.get_mlp_list()[layer_index]

        self.mmlu_evaluator = MMLU_Evaluator()
        self.mmlu_evaluator.set_random_index_subset(256)
        self.ppl_evaluator = PPL_Evaluator("dataset/test_text.txt")

        if uses_cross_entropy_loss_instead:
            self.loss_evaluator = CrossEntropy_LossEvaluator(reference_file, lines_per_chunk, chunk_count)
            if oracle_guided_cross_entropy_loss:
                self.loss_evaluator.modify_reference_texts_with_oracle(adapter, layer_index)
            self.loss_evaluator.set_reference_tokens(adapter)
        else:
            self.loss_evaluator = JSD_LossEvaluator(reference_file, lines_per_chunk, chunk_count)
            self.loss_evaluator.set_reference_logits(adapter, layer_index)

    def main_loop(self, epochs: int, timeout_seconds: Optional[float] = None) -> tuple[list[bool], Optional[list[int]], float]:
        """Run the SGD attack for up to ``epochs`` passes over the reference chunks.

        Stops early when the mean epoch loss drops below ``SGD_Parameters.CUTOFF`` or once
        ``timeout_seconds`` have elapsed.

        Returns:
            (final_key, final_permutation_key, final_loss). ``final_loss`` is the mean
            training loss of the last completed epoch, or NaN if no epoch ran.
        """
        # NOTE(review): all model parameters are handed to Adam, which only updates tensors
        # that receive gradients. This presumably relies on the base model being frozen so
        # that only the key modules train — confirm in model_adapter.
        optimizer = optim.Adam(self.adapter.model.parameters(), lr=SGD_Parameters.LEARNING_RATE)
        correct_key = self.mlp.correct_key
        correct_permutation_key = self.mlp.correct_permutation_key
        loss_type = "JSD" if isinstance(self.loss_evaluator, JSD_LossEvaluator) else "CrossEntropy"

        print_tee(f"Starting SGD with Learning Rate={SGD_Parameters.LEARNING_RATE}, Epochs={epochs}")
        print_tee("-" * 30)
        start_time = time.time()

        self.mlp.key_module_enabled = self.uses_key
        self.mlp.permutation_key_module_enabled = self.uses_permutation_key

        # Bound up front so the final report also works when no epoch runs.
        loss = float("nan")
        for epoch in range(epochs):
            losses = []
            progress = epoch / epochs
            if self.uses_key:
                self.mlp.key_module.progress = progress
            if self.uses_permutation_key:
                self.mlp.permutation_key_module.progress = progress

            # get chunks in a random order:
            chunk_count = len(self.loss_evaluator.reference_texts)
            indices = random.sample(range(chunk_count), chunk_count)

            for j, i in enumerate(indices):
                print(f"Epoch {epoch + 1}/{epochs}, chunk {j + 1}/{chunk_count}", end="\r")
                optimizer.zero_grad()
                batch_loss = self.loss_evaluator.get_loss(self.adapter, i, True)
                batch_loss.backward()
                optimizer.step()
                losses.append(batch_loss.item())

            fidelity, permutation_fidelity, curr_permutation_key = None, None, None
            curr_key = self.mlp.correct_key
            if self.uses_key:
                curr_key = self.mlp.key_module.get_key_bits()
                fidelity = sum(curr_key[i] == correct_key[i] for i in range(len(correct_key))) / len(correct_key)
            if self.uses_permutation_key:
                curr_permutation_key = self.mlp.permutation_key_module.get_permutation_key()
                permutation_fidelity = sum(curr_permutation_key[i] == correct_permutation_key[i]
                                           for i in range(len(correct_permutation_key))) / len(correct_permutation_key)
            loss = sum(losses) / len(losses)
            curr_acc_time = time.time() - start_time
            print_tee(f"Epoch {epoch + 1}/{epochs}, Time Elapsed = {curr_acc_time:.2f}s, Loss ({loss_type}) = {loss:.8f}, "
                      f"Fidelity = {fidelity}, P-Fidelity = {permutation_fidelity}")

            if (epoch + 1) % 20 == 0:
                for now_enforces_binary in [True, False]:
                    _, eval_mmlu, eval_ppl, _ = evaluate_specific_key(
                        self.adapter, self.mlp, curr_key, curr_permutation_key,
                        self.mmlu_evaluator, self.ppl_evaluator, enforce_binary=now_enforces_binary
                    )
                    print_tee(f"Evaluation at Epoch {epoch + 1}: MMLU = {eval_mmlu:.4f}, PPL = {eval_ppl:.4f} (Enforce Binary = {now_enforces_binary})")

            if loss < SGD_Parameters.CUTOFF:
                print_tee(f"[Early stopping at Epoch {epoch + 1} with loss {loss:.4f} smaller than cutoff {SGD_Parameters.CUTOFF}.]")
                break

            time_elapsed = time.time() - start_time
            if timeout_seconds and time_elapsed > timeout_seconds:
                print_tee(f"[Timeout after {timeout_seconds} seconds. Stopping SGD.]")
                break

        # Final "quick evaluation" of the best key found:
        print_tee("Final evaluation of the best key found...")
        final_key = self.mlp.key_module.get_key_bits() if self.uses_key else self.mlp.correct_key
        final_p_key = self.mlp.permutation_key_module.get_permutation_key() if self.uses_permutation_key else None
        final_loss = loss
        eval_fidelity, eval_mmlu, eval_ppl, eval_p_fidelity = evaluate_specific_key(
            self.adapter, self.mlp, final_key, final_p_key,
            self.mmlu_evaluator, self.ppl_evaluator, enforce_binary=self.enforces_binary
        )
        print_tee(f"Final Evaluation: MMLU (sampled, not accurate) = {eval_mmlu:.4f}, PPL = {eval_ppl:.4f}, "
                  f"Fidelity = {eval_fidelity}, P-Fidelity = {eval_p_fidelity}")
        print_tee(f"{loss_type} (sampled, not accurate) with best key: {final_loss:.4f}")

        if self.enforces_binary:
            self.mlp.key_module_enabled = False
            self.mlp.permutation_key_module_enabled = False

        return final_key, final_p_key, final_loss
    
