import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, AdaBoostClassifier

from bayes_net_simple import BayesianInferenceClassifier
from bayes_net_complex import PerfectOracleBayesCaptcha


def accuracy_score(y_true, y_pred):
    """
    Return the fraction of predictions that exactly match the true labels.

    Both inputs go through ``np.asarray`` first, so plain lists work and
    pandas Series are compared position by position instead of being
    aligned on their index (comparing two Series whose indexes differ
    would otherwise raise).

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels.
    y_pred : array-like
        Predicted labels, same shape as ``y_true``.

    Returns
    -------
    float
        Accuracy in [0, 1]; NaN for empty input (NumPy also warns about
        taking the mean of an empty array).

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` have different shapes.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        # Guard against silent broadcasting, e.g. (n,) vs (n, 1) -> (n, n).
        raise ValueError(
            f"y_true and y_pred must have the same shape, "
            f"got {y_true.shape} and {y_pred.shape}"
        )
    return float(np.mean(y_true == y_pred))


import numpy as np

def compute_calibration_curve(y_true, y_prob, n_bins=10):
    """
    Compute calibration (reliability) curve points from scratch.

    The interval [0, 1] is split into ``n_bins`` equal-width bins with edges
    ``np.linspace(0, 1, n_bins + 1)``. Bin ``b`` holds the probabilities
    ``p`` with ``edges[b] <= p < edges[b + 1]``. Probabilities >= 1.0 go to
    the last bin. Probabilities below ``edges[1]`` go to the first bin; this
    includes negative values and NaN, which are not filtered. Empty bins are
    skipped.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True labels. They are averaged per bin, so they should be 0/1
        (or boolean) for the result to be a fraction of positives.
    y_prob : array-like of shape (n_samples,)
        Predicted probability of the positive class for each sample.
    n_bins : int, default 10
        Number of equal-width bins.

    Returns
    -------
    mean_pred : np.ndarray
        Mean predicted probability of each non-empty bin, in bin order.
    frac_pos : np.ndarray
        Observed positive rate of each non-empty bin, in bin order.

    Raises
    ------
    ValueError
        If ``n_bins < 1`` or ``y_true`` and ``y_prob`` have different lengths.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)

    if n_bins < 1:
        raise ValueError(f"n_bins must be at least 1, got {n_bins}")
    if len(y_true) != len(y_prob):
        # A mismatch would otherwise pair labels with the wrong predictions
        # or drop samples without any warning.
        raise ValueError(
            f"y_true and y_prob must have the same length, "
            f"got {len(y_true)} and {len(y_prob)}"
        )

    # Bin edges from 0 to 1 (e.g. 0.0, 0.1, 0.2, ... for 10 bins).
    edges = np.linspace(0, 1, n_bins + 1)

    # Running per-bin totals, filled in a single pass over the samples
    # rather than rescanning every sample once per bin.
    prob_sums = [0] * n_bins
    label_sums = [0] * n_bins
    counts = [0] * n_bins

    for p, t in zip(y_prob, y_true):
        # Walk the upper edges until one exceeds p. For 0 <= p < 1 this
        # matches np.digitize(p, edges) - 1, written out by hand for clarity.
        # Stopping at the last bin keeps p >= 1.0 instead of dropping it.
        b = 0
        while b < n_bins - 1 and p >= edges[b + 1]:
            b += 1
        prob_sums[b] += p
        label_sums[b] += t
        counts[b] += 1

    mean_pred = []
    frac_pos = []
    for p_sum, t_sum, count in zip(prob_sums, label_sums, counts):
        if count == 0:
            continue  # empty bins carry no information; skip them
        # Average predicted probability vs. observed positive rate.
        mean_pred.append(p_sum / count)
        frac_pos.append(t_sum / count)

    return np.array(mean_pred), np.array(frac_pos)


def plot_calibration_curve(mean_predicted, frac_positives, title):
    """
    Draw a reliability diagram from precomputed per-bin statistics.

    A dashed diagonal marks perfect calibration. The model's curve is drawn
    with circle markers and labelled with ``title``, which is also used as
    the axes title. Both axes are fixed to [0, 1], and the figure is
    displayed with ``plt.show()``.

    Parameters
    ----------
    mean_predicted : array-like
        Mean predicted probability of each bin (x coordinates).
    frac_positives : array-like
        Observed fraction of positives of each bin (y coordinates).
    title : str
        Axes title and legend label for the model curve.
    """
    fig, ax = plt.subplots(figsize=(6, 6))

    # Reference: a perfectly calibrated model lies on y = x.
    unit = [0, 1]
    ax.plot(unit, unit, linestyle="--", label="Perfect calibration")

    # The model's reliability curve.
    ax.plot(
        mean_predicted,
        frac_positives,
        marker="o",
        linewidth=1.5,
        label=title,
    )

    ax.set_xlabel("Mean predicted probability")
    ax.set_ylabel("Fraction of positives")
    ax.set_title(title)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.grid(True, linestyle=":")
    ax.legend()
    fig.tight_layout()
    plt.show()


def main(csv_path='dataset.csv', n_bins=10):
    """
    Train the classifiers and show a calibration plot for each of them.

    Loads ``csv_path`` and splits it 70/30 into train/test sets with a fixed
    seed. It then fits every model in ``models`` on the training set. Finally,
    for each model it bins the positive-class probabilities on the test set
    and plots the resulting reliability diagram.

    Parameters
    ----------
    csv_path : str or path-like, default 'dataset.csv'
        CSV file containing the feature columns listed below and the
        ``Is_Human`` target column (presumably 0/1 -- the calibration code
        averages it as a positive rate; confirm against the dataset).
    n_bins : int, default 10
        Number of probability bins per calibration curve.
    """
    # Load dataset
    dataset = pd.read_csv(csv_path)

    features = [
        'Check_Time',
        'Challenge_Time',
        'Challenge_Errors',
        'Mouse_Path_Entropy',
        'Click_Speed',
        'Scroll_Count',
        'History_Captcha_Success',
        'History_Captcha_Count',
        'IP_Suspicious',
        'Device_Trust_Score',
    ]

    # Separate features and target
    X = dataset[features]
    y = dataset['Is_Human']

    # 70/30 train/test split; fixed seed so runs are reproducible.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42
    )

    models = {
        'Logistic Regression': LogisticRegression(max_iter=1000),
        'Random Forest': RandomForestClassifier(
            n_estimators=100, max_depth=6, random_state=42
        ),
    }

    for name, model in models.items():
        print("Training", name, "...")
        model.fit(X_train, y_train)

    # ---- Calibration plot for every trained model (binning done from scratch) ----
    for name, model in models.items():
        # Column 1 of predict_proba is the probability of model.classes_[1],
        # i.e. the positive class when the labels are 0/1.
        test_proba = model.predict_proba(X_test)[:, 1]

        mean_pred, frac_pos = compute_calibration_curve(
            y_true=y_test,
            y_prob=test_proba,
            n_bins=n_bins,
        )

        plot_calibration_curve(
            mean_predicted=mean_pred,
            frac_positives=frac_pos,
            title=name,
        )


if __name__ == '__main__':
    # Run the train-and-plot pipeline only when executed as a script,
    # never as a side effect of importing this module.
    main()