import re
from pprint import pprint
import base64
import os
from openai import OpenAI
import random
import json
import math
import ast  # needed for extract_dict


# Lower bound applied to the target probability in score_summary so that
# math.log never receives 0.
epsilon = 0.00001

# Read the API key through a context manager so the file handle is closed
# as soon as the key has been read.
with open("key.txt") as _key_file:
    api_key = _key_file.read().strip()
client = OpenAI(api_key=api_key)

# Location of the on-disk JSON cache of raw LLM responses.
CACHE_PATH = "gpt_cache.json"

#######################################################################
#                              MAIN LOGIC
#######################################################################

def _read_text(path):
    """Return the full contents of the text file at *path*, closing it afterwards."""
    with open(path) as handle:
        return handle.read()


def main(target_id="MR0001", n_distractors=10, n_runs=5, seed=42):
    """
    Score every summary against the target record over several randomized runs.

    Records are parsed from ``raw.txt`` and the four summaries are read from
    ``summary_great.txt``, ``summary_good.txt``, ``summary_ok.txt`` and
    ``summary_bad.txt``. In each run a fresh set of distractor records is
    sampled, and every summary is scored against the same candidate set.

    Args:
        target_id: ID of the record the summaries are supposed to describe.
        n_distractors: Number of distractor records sampled per run.
        n_runs: Number of independent runs to average over.
        seed: Seed passed to ``random.seed`` so the sampled candidate sets
            (which are part of the cache key) are reproducible.

    Returns:
        Dict mapping summary name to the average probability assigned to
        ``target_id`` (NaN when no runs were performed).

    Raises:
        KeyError: If ``target_id`` is not among the parsed records.
        ValueError: If ``n_distractors`` is less than 1 or exceeds the number
            of non-target records available.
    """
    records = parse_records(_read_text('raw.txt'))

    if target_id not in records:
        raise KeyError(f"Target record {target_id} not found in raw.txt")

    # Load the four summaries
    summaries = {
        'great': _read_text('summary_great.txt'),
        'good': _read_text('summary_good.txt'),
        'ok': _read_text('summary_ok.txt'),
        'bad': _read_text('summary_bad.txt'),
    }

    # === SET RANDOM SEED FOR REPRODUCIBILITY (important for cache hits) ===
    random.seed(seed)

    # All possible distractor IDs (exclude target)
    all_ids = [rid for rid in records if rid != target_id]

    # Fail early with a clear message instead of random.sample's generic error
    # (too few records) or a math domain error from log(0) further down.
    if not 1 <= n_distractors <= len(all_ids):
        raise ValueError(
            f"n_distractors must be between 1 and {len(all_ids)} "
            f"(number of non-target records), got {n_distractors}"
        )

    # Store scores for each summary across runs
    scores = {name: [] for name in summaries}

    for run in range(n_runs):
        # === Pick a new set of distractors for this run ===
        distractor_ids = random.sample(all_ids, n_distractors)

        # Candidate IDs: target + distractors
        candidate_ids = [target_id] + distractor_ids

        print(f"\n=== Run {run + 1} ===")
        print("Candidate IDs:", candidate_ids)

        # Use the SAME candidate_ids for all summaries in this run
        for name, summary_text in summaries.items():
            print(f"\nScoring summary: {name}")
            p_target = score_summary(
                records=records,
                target_id=target_id,
                summary=summary_text,
                candidate_ids=candidate_ids,
            )
            scores[name].append(p_target)

    # Compute averages (NaN when a summary has no scores, i.e. n_runs == 0)
    avg_scores = {
        name: (sum(vals) / len(vals)) if vals else float('nan')
        for name, vals in scores.items()
    }

    print(f"\n=== Average probability on true record ({target_id}) ===")

    for name, avg in avg_scores.items():
        if math.isnan(avg):
            print(f"{name}: NaN")
            continue

        # Figure printed as "MI <": log(p) + log(N), in nats
        mi = math.log(avg) + math.log(n_distractors)

        print(f"{name}:\tp = {avg:.2f},  MI < {mi:.2f} nats")

    # Save final cache state to disk
    save_cache(gpt_cache)

    return avg_scores


def score_summary(records, target_id, summary, candidate_ids):
    """
    Ask the LLM to assign probabilities over candidate_ids for a given summary.

    Args:
        records: Dict mapping record ID to the record's full text.
        target_id: ID of the record the summary truly belongs to.
        summary: Summary text shown to the model.
        candidate_ids: Record IDs (target plus distractors) offered to the model.

    Returns:
        The normalized probability assigned to ``target_id``, floored at
        ``epsilon`` so its logarithm is always defined.

    Raises:
        KeyError: If a candidate ID is missing from ``records``.
        ValueError: If no dictionary can be parsed from the LLM output, or if
            one of its values cannot be interpreted as a number.
    """

    # Build text for candidates
    candidates_text = "\n\n".join(
        f"Medical record ID: {mrn}\n{records[mrn]}"
        for mrn in candidate_ids
    )

    # Build the full prompt
    text_prompt = f"""
You are an expert clinician reviewing anonymized electronic medical records.

You are given:
1. A short summary describing a patient's presentation and clinical course.
2. A set of candidate medical records from different patients.

Your task:
For each candidate medical record ID, assign a probability that this summary
most likely belongs to that record.

Output requirements:
- Return ONLY a JSON object.
- Keys must be the medical record IDs exactly as given (e.g. "MR0001").
- Values must be probabilities (numbers) with EXACTLY 3 decimal places.
- The probabilities must sum to 1.0 (allowing normal rounding).
- Do NOT include any extra text, explanations, or fields.

Summary of the patient:
{summary}

Candidate medical records (each begins with its ID line):
{candidates_text}
"""

    # === get raw result from cache or model ===
    raw_result = get_llm_raw_result(summary, candidate_ids, text_prompt)
    print(f"LLM raw result: {raw_result}")

    # Parse returned dict
    result_dict = extract_dict(raw_result)
    if result_dict is None:
        raise ValueError("Could not parse a dictionary from LLM output.")

    # Coerce every value to float. A value that is not numeric is reported
    # explicitly rather than being left in place to fail inside sum() later.
    weights = {}
    for record_id, value in result_dict.items():
        try:
            weights[record_id] = float(value)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Non-numeric probability for {record_id!r} in LLM output: {value!r}"
            ) from err

    # Normalize probability mass
    pmf = normalize(weights)

    # Probability on the true target record (floored so log() is defined)
    p_goal = max(pmf.get(target_id, 0.0), epsilon)
    print(f"p(target={target_id}) = {p_goal:.3f}")

    # Optional: mutual information debug. Skipped when there are no
    # distractors, because log(0) is undefined.
    n_distractors = len(candidate_ids) - 1
    if n_distractors > 0:
        mutual_info = math.log(p_goal) + math.log(n_distractors)
        print(f"MI < {mutual_info:.2f} nats")

    return p_goal


#######################################################################
#                           LLM CALL WRAPPER
#######################################################################

def get_llm_raw_result(summary, candidate_ids, text_prompt):
    """
    Wrapper around the LLM call with persistent caching.

    Args:
        summary: Summary text; part of the cache key.
        candidate_ids: Candidate record IDs in order; part of the cache key.
        text_prompt: Full prompt sent to the model on a cache miss.

    Returns:
        The model's raw text response, stripped of surrounding whitespace.

    Raises:
        ValueError: If the model response carries no text content.

    NOTE(review): the cache key covers only (summary, candidate_ids), so a
    cached answer is reused even if the record texts, the prompt wording or
    the model name change. Delete the cache file after such changes.
    """

    # --- Convert key to a stable JSON-serializable string ---
    key = json.dumps({
        "summary": summary,
        "candidate_ids": list(candidate_ids)
    }, sort_keys=True)

    # --- Check cache ---
    if key in gpt_cache:
        print("Using cached LLM result.")
        return gpt_cache[key]

    # --- Call the model if not in cache ---
    response = client.chat.completions.create(
        model="gpt-4.1-mini",
        messages=[
            {"role": "system", "content": "You are a careful, detail-oriented medical data annotator."},
            {"role": "user", "content": text_prompt},
        ],
        temperature=0.0,
    )

    # The message content may be None; fail with a clear error before
    # anything is written to the cache.
    content = response.choices[0].message.content
    if content is None:
        raise ValueError("LLM response contained no text content.")

    raw_result = content.strip()

    # --- Store in cache and save to disk ---
    gpt_cache[key] = raw_result
    save_cache(gpt_cache)

    return raw_result


#######################################################################
#                        UTILITY FUNCTIONS
#######################################################################

def parse_records(text: str):
    """
    Split raw text into individual records keyed by patient ID.

    Records are separated by a line made of five or more dashes. A chunk that
    is blank, or that has no "Patient ID: <id>" field, is skipped. When the
    same ID appears more than once, the later record replaces the earlier one.

    Returns a dict mapping patient ID to the stripped record text.
    """
    id_pattern = re.compile(r'Patient ID:\s*(\S+)')
    parsed = {}

    for chunk in re.split(r'\n-{5,}\n', text):
        body = chunk.strip()
        # Blank chunks cannot contain an ID, so skip the search for them.
        found = id_pattern.search(body) if body else None
        if found is not None:
            parsed[found.group(1)] = body

    return parsed


def extract_dict(text: str):
    """
    Extract a dictionary from free-form LLM output.

    Fenced code blocks (plain, ``json`` or ``python``) are tried first, then
    the whole text. For each candidate, the span from the first "{" to the
    last "}" is parsed as JSON, then as a Python literal. If both fail, the
    JSON object that starts at the first "{" is decoded on its own, so that
    trailing text which happens to contain braces does not defeat parsing.

    Returns the parsed dict, or None when no candidate yields one.
    """
    code_blocks = re.findall(r"```(?:json|python)?\s*(.*?)```", text, re.DOTALL)
    candidates = code_blocks + [text]

    for chunk in candidates:
        chunk = chunk.strip()
        if not chunk:
            continue

        brace_match = re.search(r"\{.*\}", chunk, re.DOTALL)
        if not brace_match:
            continue

        obj_str = brace_match.group(0)

        # 1) Strict JSON over the whole brace span.
        try:
            return json.loads(obj_str)
        except (ValueError, RecursionError):
            pass

        # 2) Python literal syntax (e.g. single-quoted keys).
        try:
            parsed = ast.literal_eval(obj_str)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            parsed = None
        if isinstance(parsed, dict):
            return parsed

        # 3) Leading JSON object only, ignoring whatever follows it.
        try:
            parsed, _end = json.JSONDecoder().raw_decode(obj_str)
        except (ValueError, RecursionError):
            continue
        if isinstance(parsed, dict):
            return parsed

    return None


def normalize(pmf):
    """
    Rescale the values of *pmf* so that they sum to 1.

    Returns a new dict with the same keys. When the values sum to zero
    (including the empty dict) the input object itself is returned unchanged.
    """
    mass = sum(pmf.values())
    if mass != 0:
        scaled = {}
        for key, weight in pmf.items():
            scaled[key] = weight / mass
        return scaled
    return pmf


#######################################################################
#                           CACHE HELPERS
#######################################################################

def load_cache():
    """
    Safely load GPT cache from disk, or return empty dict.

    An empty dict is returned when the cache file is missing, when it cannot
    be read or parsed, or when it holds JSON that is not an object (a list or
    string would otherwise be returned and break later cache writes).
    """
    if not os.path.exists(CACHE_PATH):
        print("Cache file not found, starting with empty cache.")
        return {}

    try:
        with open(CACHE_PATH, "r", encoding="utf-8") as f:
            cache = json.load(f)
    except Exception as e:
        # Deliberately broad: a broken cache must never stop the run.
        print(f"Failed to load cache ({e}), starting with empty cache.")
        return {}

    if not isinstance(cache, dict):
        print(
            f"Failed to load cache (expected a JSON object, got "
            f"{type(cache).__name__}), starting with empty cache."
        )
        return {}

    print(f"Loaded cache with {len(cache)} entries.")
    return cache


def save_cache(cache):
    """
    Safely save GPT cache to disk.

    The cache is written to a temporary file next to CACHE_PATH and then
    moved over it with os.replace, so a failed or interrupted dump does not
    truncate the existing cache file. Errors are printed, not raised.
    """
    tmp_path = CACHE_PATH + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(cache, f, indent=2)
        os.replace(tmp_path, CACHE_PATH)
        print(f"Saved cache with {len(cache)} entries.")
    except Exception as e:
        print(f"ERROR saving cache: {e}")
        # Best-effort cleanup of a partially written temp file.
        try:
            os.remove(tmp_path)
        except OSError:
            pass


# Load cache ONCE at runtime. This is a module-level statement, so the cache
# file is read whenever this module is imported or run. The resulting dict is
# the shared in-memory cache used by get_llm_raw_result and saved by main.
gpt_cache = load_cache()



#######################################################################

# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()