import lock_and_attack as la
import model_adapter as ma

from torch.cuda import empty_cache
from lock_and_attack import print_tee

from typing import Optional
import random


def calculate_param_overhead(adapter: ma.Adapter, key_size: int):
    """Return the relative parameter overhead introduced by a lock of size `key_size`.

    The overhead is `key_size ** 2` divided by the total number of elements in
    `adapter.model.parameters()`.

    NOTE(review): the squared term presumably models a `key_size` x `key_size`
    matrix (called "R" in the original code) added by the lock -- confirm
    against `lock_and_attack`.

    Args:
        adapter: Adapter whose `model.parameters()` are counted.
        key_size: Size of the key (not its log2).

    Returns:
        The ratio of added parameters to existing parameters.
    """
    existing_param_count = 0
    for tensor in adapter.model.parameters():
        existing_param_count += tensor.numel()

    added_param_count = key_size ** 2
    return added_param_count / existing_param_count

def calculate_flop_overhead(adapter: ma.Adapter, log2_of_key_size: int, token_count=256):
    """Return the relative FLOP increase caused by locking layer 0 of the model.

    FLOPs are counted with fvcore's `FlopCountAnalysis` on a fixed prompt, once
    before and once after `la.apply_random_lock` is applied to layer index 0
    (called with positional flags `True, False`).

    Side effects: the adapter is left locked, and `model.config.use_cache` is
    set to `False` (when that attribute exists) and not restored.

    Args:
        adapter: Adapter providing `tokenizer` and `model`.
        log2_of_key_size: Base-2 logarithm of the key size passed to the lock.
        token_count: Number of times the word "Whatever" is repeated to build
            the prompt. NOTE(review): the resulting number of tokens depends on
            the tokenizer and is not necessarily equal to this value.

    Returns:
        `flops_after / flops_before - 1`.
    """
    from fvcore.nn import FlopCountAnalysis

    text = "Whatever" * token_count
    input_ids = adapter.tokenizer.encode(text, return_tensors='pt').to(adapter.model.device)

    # Disable the cache to get fvcore running.
    net = adapter.model
    if hasattr(net, 'config') and hasattr(net.config, 'use_cache'):
        net.config.use_cache = False

    # `adapter.model` is re-read for each measurement on purpose: the lock call
    # in between operates on the adapter.
    baseline_flops = FlopCountAnalysis(adapter.model, input_ids).total()
    la.apply_random_lock(adapter, 0, log2_of_key_size, True, False)
    locked_flops = FlopCountAnalysis(adapter.model, input_ids).total()

    return locked_flops / baseline_flops - 1


def evaluate_original_model(adapter: ma.Adapter):
    """Evaluate `adapter` with the MMLU and PPL evaluators and log both scores.

    No lock is applied here. Each score is written through `print_tee`, MMLU
    first and PPL second.
    """
    mmlu_score = la.MMLU_Evaluator().evaluate(adapter)
    print_tee(f"Original model MMLU score: {mmlu_score}")

    ppl_score = la.PPL_Evaluator().evaluate(adapter)
    print_tee(f"Original model PPL score: {ppl_score}")


def extra_permutation_sanity_check(adapter: ma.Adapter, layer_index: int, log2_of_key_size: int):
    """Print PPL scores for an outlier-locked layer under correct and random keys.

    The sequence of measurements, each printed with `print`, is:
      1. the model before locking;
      2. the locked model with the correct key and correct permutation;
      3. the locked model with a random key (the correct key is restored after);
      4. three runs, each with a freshly drawn random permutation.

    The adapter is left locked, with the last random permutation installed.

    Args:
        adapter: Adapter whose model is locked and evaluated.
        layer_index: Index into `adapter.get_mlp_list()` of the layer to lock.
        log2_of_key_size: Base-2 logarithm of the key size.
    """
    evaluator = la.PPL_Evaluator()
    size = 2 ** log2_of_key_size

    baseline_ppl = evaluator.evaluate(adapter)
    print(f"Original model PPL score: {baseline_ppl}")

    true_key, true_p_key = la.apply_outlier_lock(
        adapter, layer_index, log2_of_key_size, rotates=True, permutes=True,
    )
    locked_mlp = adapter.get_mlp_list()[layer_index]
    # Turn off both built-in key modules before supplying keys at runtime.
    locked_mlp.key_module_enabled = False
    locked_mlp.permutation_key_module_enabled = False
    locked_mlp.set_runtime_key(true_key)
    locked_mlp.set_runtime_permutation(true_p_key)
    correct_ppl = evaluator.evaluate(adapter)
    print(f"Locked model PPL score, with correct keys: {correct_ppl}")

    # Wrong key, correct permutation.
    locked_mlp.set_runtime_key(la.get_random_key(size))
    wrong_key_ppl = evaluator.evaluate(adapter)
    print(f"Locked model PPL score, with random key: {wrong_key_ppl}")
    locked_mlp.set_runtime_key(true_key)

    # Correct key, wrong permutation (three independent draws).
    for trial in range(3):
        locked_mlp.set_runtime_permutation(la.get_random_permutational_key(size))
        wrong_perm_ppl = evaluator.evaluate(adapter)
        print(f"Locked model PPL score {trial}, with random permutation: {wrong_perm_ppl}")


def evaluate_randomly_locked_model(
    adapter: ma.Adapter, layer_index: int, log2_of_key_size: int, count: int,
    picks_outlier: bool, rotates: bool, permutes: bool, flips: bool,
    mmlu_subset_size=None,
    way=16,
):
    """Lock one layer, then compare the correct key against `count` random keys.

    The layer is locked with `la.apply_outlier_lock` or `la.apply_random_lock`,
    both built-in key modules of the locked MLP are switched off, and keys are
    supplied at runtime instead. The correct key is evaluated first, followed
    by `count` random candidates. Every evaluation logs MMLU, PPL and JSD via
    `print_tee`; a final summary logs the reference score plus the mean and
    population standard deviation over the random candidates.

    Args:
        adapter: Adapter whose model is locked and evaluated.
        layer_index: Index into `adapter.get_mlp_list()` of the layer to lock.
        log2_of_key_size: Base-2 logarithm of the key size.
        count: Number of random candidates to evaluate. With `count <= 0` only
            the correct key is evaluated and the summary statistics are `nan`.
        picks_outlier: Use `la.apply_outlier_lock` when true, otherwise
            `la.apply_random_lock`.
        rotates: Forwarded to the lock function.
        permutes: Forwarded to the lock function; also enables drawing and
            installing random permutation keys.
        flips: Draw a random key per candidate when true; otherwise every
            candidate reuses the correct key.
        mmlu_subset_size: When not `None`, passed to
            `MMLU_Evaluator.set_random_index_subset`.
        way: Forwarded to the lock function and to
            `la.get_random_permutational_key`.

    Returns:
        None. All results are emitted through `print_tee`.
    """
    apply_lock = la.apply_outlier_lock if picks_outlier else la.apply_random_lock
    correct_key, correct_p_key = apply_lock(
        adapter, layer_index, log2_of_key_size, rotates, permutes,
        way=way,
    )
    mlp = adapter.get_mlp_list()[layer_index]
    mlp.key_module_enabled = False
    mlp.permutation_key_module_enabled = False
    key_size = len(correct_key)

    mmlu_evaluator = la.MMLU_Evaluator()
    if mmlu_subset_size is not None:
        mmlu_evaluator.set_random_index_subset(mmlu_subset_size)

    ppl_evaluator = la.PPL_Evaluator()

    jsd_evaluator = la.JSD_LossEvaluator(chunk_count=10)
    jsd_evaluator.set_reference_logits(adapter, layer_index)

    mmlu_list = []
    ppl_list = []

    def mismatch_fraction(candidate, reference):
        # Fraction of the first `key_size` positions where the two keys differ.
        # Indexing (rather than zip) keeps the comparison limited to `key_size`.
        return sum(1 for i in range(key_size) if candidate[i] != reference[i]) / key_size

    def evaluate_key(key, p_key):
        mlp.set_runtime_key(key)
        if permutes:
            mlp.set_runtime_permutation(p_key)
        # NOTE: despite the log label "fidelity", these are mismatch fractions:
        # 0.0 means the candidate is identical to the correct key.
        fidelity = mismatch_fraction(key, correct_key) if flips else None
        p_fidelity = mismatch_fraction(p_key, correct_p_key) if permutes else None
        mmlu = mmlu_evaluator.evaluate(adapter)
        ppl = ppl_evaluator.evaluate(adapter)
        jsd = jsd_evaluator.evaluate(adapter)
        print_tee(f"fidelity={fidelity} mmlu={mmlu:.3f} ppl={ppl:.3f} jsd={jsd:.3f}, p_fidelity={p_fidelity}")
        mmlu_list.append(mmlu)
        ppl_list.append(ppl)

    def report_statistics(label, scores):
        # scores[0] belongs to the correct key; the rest to random candidates.
        reference, candidates = scores[0], scores[1:]
        if candidates:
            average = sum(candidates) / len(candidates)
            stddev = (sum((x - average) ** 2 for x in candidates) / len(candidates)) ** 0.5
        else:
            # No random candidates were requested: report "nan" instead of
            # dividing by zero after the expensive reference evaluation.
            average = stddev = float("nan")
        print_tee(f"ref_{label}={reference:.3f} avg_{label}={average:.3f} stddev_{label}={stddev:.3f}")

    print_tee("Evaluating original key...")
    evaluate_key(correct_key, correct_p_key)
    for i in range(count):
        print_tee(f"Evaluating random key {i + 1}/{count}...")
        key = la.get_random_key(key_size) if flips else correct_key
        p_key = la.get_random_permutational_key(key_size, way) if permutes else None
        evaluate_key(key, p_key)

    report_statistics("mmlu", mmlu_list)
    report_statistics("ppl", ppl_list)


def test_attack(
    adapter: ma.Adapter, layer_index: int, log2_of_key_size: int, epochs: int, ite_ga_sgd: bool, timeout: float,
    uses_jsd: bool, oracle_guided: bool,
    flips: bool, rotates: bool, permutes: bool,
    enforces_binary: bool,
    way: Optional[int],
):
    """Outlier-lock one layer, run an attack on it, and log the resulting scores.

    Args:
        adapter: Adapter whose model is locked and attacked.
        layer_index: Index into `adapter.get_mlp_list()` of the layer to lock.
        log2_of_key_size: Base-2 logarithm of the key size.
        epochs: Passed to the attacker's `main_loop`.
        ite_ga_sgd: Select `la.GA_Attacker` when true, `la.SGD_Attacker` otherwise.
        timeout: Passed to the attacker's `main_loop`.
        uses_jsd: SGD attacker only. When false the attacker is told to use the
            cross-entropy loss instead; also selects its chunk count (5 vs 10).
        oracle_guided: SGD attacker only; forwarded as
            `oracle_guided_cross_entropy_loss`.
        flips: SGD attacker `uses_key` flag; also decides whether the recovered
            key or the correct key is installed for the final evaluation.
        rotates: Forwarded to `la.apply_outlier_lock`.
        permutes: Forwarded to `la.apply_outlier_lock`; SGD attacker
            `uses_permutation_key` flag; also decides which permutation key is
            installed for the final evaluation.
        enforces_binary: Forwarded to the SGD attacker. The final keys are
            installed on the layer only when this is true.
        way: Forwarded to `la.apply_outlier_lock` and to the GA attacker.

    Returns:
        None. The final MMLU and PPL scores are emitted through `print_tee`.
    """
    true_key, true_p_key = la.apply_outlier_lock(
        adapter, layer_index, log2_of_key_size, rotates, permutes, way=way,
    )
    locked_mlp = adapter.get_mlp_list()[layer_index]

    if not ite_ga_sgd:
        # The gradient attack starts with the correct keys installed as the
        # runtime keys.
        locked_mlp.set_runtime_key(true_key)
        if permutes:
            locked_mlp.set_runtime_permutation(true_p_key)
        attacker = la.SGD_Attacker(
            adapter, layer_index,
            uses_key=flips,
            uses_permutation_key=permutes,
            enforces_binary=enforces_binary,
            uses_cross_entropy_loss_instead=not uses_jsd,
            oracle_guided_cross_entropy_loss=oracle_guided,
            chunk_count=5 if uses_jsd else 10,
        )
    else:
        attacker = la.GA_Attacker(adapter, layer_index, chunk_count=5, way=way)

    recovered_key, recovered_p_key, _ = attacker.main_loop(epochs, timeout)

    print_tee("Evaluating final key...")
    if enforces_binary:
        locked_mlp.set_runtime_key(recovered_key if flips else true_key)
        # NOTE(review): unlike the pre-attack setup above, this call is not
        # guarded by `permutes`; confirm `set_runtime_permutation` accepts the
        # permutation key returned by the lock when `permutes` is false.
        locked_mlp.set_runtime_permutation(recovered_p_key if permutes else true_p_key)

    mmlu_evaluator = la.MMLU_Evaluator()
    ppl_evaluator = la.PPL_Evaluator()
    final_mmlu = mmlu_evaluator.evaluate(adapter)
    final_ppl = ppl_evaluator.evaluate(adapter)

    print_tee(f"Final accurate mmlu: {final_mmlu}")
    print_tee(f"Final accurate ppl: {final_ppl}")
       

# Models available to the experiment scripts. Each entry is a tuple of
# (adapter class, index of the layer to lock, display name used in log output).
# The trailing number on each line is the entry's list index, which the scripts
# use as the adapter id.
ADAPTER_CONFIGS = [
    (ma.Minitron4B_Adapter, 1, "Minitron 4B"),                      # 0
    (ma.Minitron8B_Adapter, 1, "Minitron 8B"),                      # 1
    (ma.LlamaMinitron4B_Adapter, 1, "Llama-3.1 Minitron 4B W"),     # 2
    (ma.MistralNeMo8B_Adapter, 0, "Mistral NeMo Minitron 8B"),      # 3
]

def main_result_script():
    """
    Script to generate data for our main results in the paper.

    The values marked with TODO below select the GPU, the models (indices into
    `ADAPTER_CONFIGS`), the key sizes and the experiments to run. A fresh
    adapter is built for every single measurement; the reference to it is
    dropped before `empty_cache()` is called so that the model's memory can
    actually be released before the next model is loaded.
    """
    # import torch
    # torch.autograd.set_detect_anomaly(True)

    ## Which GPU to use?
    gpu_id = 0                          # TODO: Change this line to run on a different GPU

    ## Which models to test?
    adapter_ids = [0, 1, 2, 3]          # TODO: Change this line to run only on a subset of models (see `ADAPTER_CONFIGS`)

    ## Which key sizes to test?
    # key size | log_2
    #  256     | 8
    #  512     | 9
    #  ...
    #  8192    | 13
    log2_of_key_sizes = [8, 9, 10, 11, 12, 13]      # TODO: Change this line to run only on a subset of key sizes

    ## Which experiments to run?
    tests_original_performance = True   # TODO: Change to `False` to skip this test
    tests_hpnn_locking = True           # TODO: Change to `False` to skip this test
    tests_lla_locking = True            # TODO: Change to `False` to skip this test
    tests_genetic_attack = True         # TODO: Change to `False` to skip this test
    tests_gradient_attack = True        # TODO: Change to `False` to skip this test
    tests_overhead = False              # TODO: Change to `True` to measure parameter/FLOP overhead

    sanity_check = False                # TODO: Change to `True` to run the permutation sanity check first

    ## Point both project modules at the selected GPU
    ma.DEVICE = f"cuda:{gpu_id}"
    la.DEVICE = f"cuda:{gpu_id}"

    ## Sanity check (uses the first configured key size only)
    if sanity_check:
        for adapter_id in adapter_ids:
            adapter_class, layer_to_lock, display_name = ADAPTER_CONFIGS[adapter_id]
            print_tee(f"Testing {display_name} for sanity check...")
            adapter: ma.Adapter = adapter_class()
            extra_permutation_sanity_check(adapter, layer_to_lock, log2_of_key_sizes[0])
            # empty_cache() only returns memory that is no longer in use, so the
            # model has to be unreferenced first.
            del adapter
            empty_cache()
            print_tee("")

    ## Let's run the tests!
    print_tee("Evaluating the main table...")

    if tests_original_performance:
        for adapter_id in adapter_ids:
            adapter_class, _, display_name = ADAPTER_CONFIGS[adapter_id]
            print_tee(f"Evaluating original model {display_name}...")
            adapter: ma.Adapter = adapter_class()
            evaluate_original_model(adapter)
            del adapter  # release the model before clearing the cache
            empty_cache()
            print_tee("")

    if tests_hpnn_locking:
        random_key_count = 100      # TODO: Use a different value to balance statistical accuracy v.s. runtime
        mmlu_subset_size = 1024     # TODO: Use a different value to balance statistical accuracy v.s. runtime
        for adapter_id in adapter_ids:
            # The configured layer is not used here: a random layer is locked instead.
            adapter_class, _, display_name = ADAPTER_CONFIGS[adapter_id]
            print_tee(f"Evaluating HPNN locking {display_name}...")
            for log2_of_key_size in log2_of_key_sizes:
                key_size = 2 ** log2_of_key_size
                print_tee(f"Key size: {key_size}")
                adapter: ma.Adapter = adapter_class()
                # Layer index drawn uniformly from 0..9 (both ends included).
                # NOTE: `random` is not seeded here, so the choice differs between runs.
                random_layer_to_lock = random.randint(0, 9)
                evaluate_randomly_locked_model(adapter, random_layer_to_lock, log2_of_key_size, random_key_count,
                    picks_outlier=False, rotates=False, permutes=False, flips=True, mmlu_subset_size=mmlu_subset_size)
                del adapter  # release the model before clearing the cache
                empty_cache()
                print_tee("")

    if tests_lla_locking:
        random_key_count = 100      # TODO: Use a different value to balance statistical accuracy v.s. runtime
        mmlu_subset_size = 1024     # TODO: Use a different value to balance statistical accuracy v.s. runtime
        for adapter_id in adapter_ids:
            adapter_class, layer_to_lock, display_name = ADAPTER_CONFIGS[adapter_id]
            print_tee(f"Evaluating LLA locking {display_name}...")
            for log2_of_key_size in log2_of_key_sizes:
                key_size = 2 ** log2_of_key_size
                print_tee(f"Key size: {key_size}")
                adapter: ma.Adapter = adapter_class()
                evaluate_randomly_locked_model(adapter, layer_to_lock, log2_of_key_size, random_key_count,
                    picks_outlier=True, rotates=True, permutes=True, flips=False, mmlu_subset_size=mmlu_subset_size)
                del adapter  # release the model before clearing the cache
                empty_cache()
                print_tee("")

    if tests_genetic_attack:
        repeat = 3                  # TODO: Use a different value to balance statistical accuracy v.s. runtime
        max_epochs = 2000           # TODO: Use a different value to reflect attacker's patience
        timeout = 7200              # TODO: Use a different value to reflect attacker's patience

        way = 16                    # TODO: Group size for permutation

        for adapter_id in adapter_ids:
            adapter_class, layer_to_lock, display_name = ADAPTER_CONFIGS[adapter_id]
            print_tee(f"Evaluating genetic attack {display_name}...")
            for log2_of_key_size in log2_of_key_sizes:
                key_size = 2 ** log2_of_key_size
                print_tee(f"Key size: {key_size}")
                for i in range(repeat):
                    print_tee(f"Repeat {i + 1}/{repeat}...")
                    adapter: ma.Adapter = adapter_class()
                    test_attack(adapter, layer_to_lock, log2_of_key_size, max_epochs, ite_ga_sgd=True, timeout=timeout,
                                uses_jsd=True, oracle_guided=True,
                                flips=False, rotates=True, permutes=True,
                                enforces_binary=True,
                                way=way)
                    del adapter  # release the model before clearing the cache
                    empty_cache()
                    print_tee("")

    if tests_gradient_attack:
        repeat = 3                  # TODO: Use a different value to balance statistical accuracy v.s. runtime
        max_epochs = 2000           # TODO: Use a different value to reflect attacker's patience
        timeout = 7200              # TODO: Use a different value to reflect attacker's patience
        uses_jsd = True             # TODO: Change to `True` to use JSD loss, `False` to use CrossEntropy loss
        oracle_guided = True        # TODO: Change to `True` for oracle guided attack, `False` for oracle-less attack

        enforces_binary = True      # TODO: Change to `True` to enforce binary keys, `False` to allow any real-valued keys
        flips = False               # TODO: Enables key-controlled negation?
        rotates = True              # TODO: Enables Hadamard-matrix-based rotation?
        permutes = True             # TODO: Enables key-controlled permutation?

        way = 16                    # TODO: Group size for permutation

        for adapter_id in adapter_ids:
            adapter_class, layer_to_lock, display_name = ADAPTER_CONFIGS[adapter_id]
            print_tee(f"Evaluating gradient attack {display_name}... flips={flips} rotates={rotates} permutes={permutes} way={way}...")
            for log2_of_key_size in log2_of_key_sizes:
                key_size = 2 ** log2_of_key_size
                print_tee(f"Key size: {key_size}")
                for i in range(repeat):
                    print_tee(f"Repeat {i + 1}/{repeat}...")
                    adapter: ma.Adapter = adapter_class()
                    test_attack(adapter, layer_to_lock, log2_of_key_size, max_epochs, ite_ga_sgd=False, timeout=timeout,
                                uses_jsd=uses_jsd, oracle_guided=oracle_guided, flips=flips, rotates=rotates, permutes=permutes,
                                enforces_binary=enforces_binary,
                                way=way)
                    del adapter  # release the model before clearing the cache
                    empty_cache()
                    print_tee("")

    if tests_overhead:
        for adapter_id in adapter_ids:
            adapter_class, _, display_name = ADAPTER_CONFIGS[adapter_id]
            print_tee(f"Evaluating overhead {display_name}...")
            for log2_of_key_size in log2_of_key_sizes:
                key_size = 2 ** log2_of_key_size
                print_tee(f"Key size: {key_size}")
                adapter: ma.Adapter = adapter_class()
                # Parameter overhead is measured first, on the unlocked model;
                # the FLOP measurement locks the adapter as a side effect.
                param_overhead = calculate_param_overhead(adapter, key_size)
                flop_overhead = calculate_flop_overhead(adapter, log2_of_key_size)
                print_tee(f"Param overhead: {param_overhead}")
                print_tee(f"FLOP overhead: {flop_overhead}")
                del adapter  # release the model before clearing the cache
                empty_cache()
                print_tee("")

    print_tee("Finished evaluating all tests.")


def ablation_study_script():
    """
    Script to generate data for the ablation study in the paper.

    For one model, runs `evaluate_randomly_locked_model` once per ablation
    configuration and key size. A fresh adapter is built for every run; the
    reference to it is dropped before `empty_cache()` is called so that the
    model's memory can actually be released before the next model is loaded.
    """
    gpu_id = 0
    ma.DEVICE = f"cuda:{gpu_id}"
    la.DEVICE = f"cuda:{gpu_id}"

    ## Which model to test?
    adapter_id = 3                      # TODO: Change this line to run ablation study on another model

    ## Which key sizes to test?
    log2_of_key_sizes = [10]            # TODO: Change this line to run on other subset of key sizes

    ## How many points to produce for each configuration?
    random_key_count = 100              # TODO: Use a different value to balance statistical accuracy v.s. runtime

    ## What is the subset size for MMLU?
    mmlu_subset_size = 1024             # TODO: Use a different value to balance statistical accuracy v.s. runtime

    ## Setting up...
    adapter_class, layer_to_lock, display_name = ADAPTER_CONFIGS[adapter_id]
    # Pick any of the layers 0..3 other than the configured one.
    # NOTE: `random` is not seeded here, so the choice differs between runs.
    alternate_layer_to_lock = random.choice([i for i in range(0, 4) if i != layer_to_lock])

    # Each entry: (description, layer index, picks_outlier, rotates, flips, permutes)
    ablation_configs = [
        ("Locking a Different Block", alternate_layer_to_lock, True, True, False, True),
        ("Locking Random Neurons", layer_to_lock, False, True, False, True),
        ("Locking Without Obfuscation", layer_to_lock, True, False, False, True),
        ("Locking with HPNN Negation", layer_to_lock, True, True, True, False),
        ("LLA As-Is", layer_to_lock, True, True, False, True),
    ]

    ## Let's run the tests!
    print_tee(f"Evaluating ablation study for model: {display_name}...")

    for config_description, config_layer, picks_outlier, rotates, flips, permutes in ablation_configs:
        print_tee(f"Evaluating '{config_description}' for model: {display_name}...")
        for log2_of_key_size in log2_of_key_sizes:
            key_size = 2 ** log2_of_key_size
            print_tee(f"Key size: {key_size}")
            adapter: ma.Adapter = adapter_class()
            evaluate_randomly_locked_model(adapter, config_layer, log2_of_key_size, random_key_count,
                picks_outlier=picks_outlier, rotates=rotates, permutes=permutes, flips=flips, mmlu_subset_size=mmlu_subset_size,
                way=16,)
            # empty_cache() only returns memory that is no longer in use, so the
            # model has to be unreferenced first.
            del adapter
            empty_cache()
            print_tee("")

    print_tee("Finished evaluating all tests.")


if __name__ == "__main__":
    # When executed as a program, run the main experiments followed by the
    # ablation study, then print a completion marker.
    main_result_script()         ## TODO: comment out this line to skip the main experiment script
    ablation_study_script()     ## TODO: comment out this line to skip the ablation study script
    print("Done!")
