import numpy as np
from scipy.stats import norm, poisson, binom

class BayesianInferenceClassifier:
    """
    Oracle Bayes classifier

    Observed features (in this order):
      0: Check_Time
      1: Challenge_Time
      2: Challenge_Errors
      3: Mouse_Path_Entropy
      4: Click_Speed
      5: Scroll_Count
      6: History_Captcha_Success   (used here as a success *rate* in [0, 1])
      7: History_Captcha_Count
      8: IP_Suspicious
      9: Device_Trust_Score

    It does NOT fit parameters from data; it uses the known generative process.

    For every sample the class-conditional likelihood P(x | class) is computed
    by summing over the latent user_type (three per class) and averaging the
    feature likelihood over Monte Carlo draws of the latent traits
    (motor, visual, patience) from their per-type Gaussian prior.  The
    posterior then follows from Bayes' rule with fixed class priors.

    Because fresh trait samples are drawn from ``self.rng`` on every
    likelihood evaluation, the predicted probabilities are Monte Carlo
    estimates: repeated calls on the same instance consume the generator and
    may return slightly different values.
    """

    def __init__(self, n_trait_samples=100, noise_level=0.5,
                 huge_neg=-1e30, random_state=None):
        """
        n_trait_samples : number of Monte Carlo samples per user_type per class
        noise_level     : scales the trait and feature noise standard deviations
        huge_neg        : finite stand-in for log(0), returned whenever a
                          likelihood is zero or cannot be evaluated
        random_state    : seed (or anything accepted by np.random.default_rng)
                          for the trait sampler
        """
        self.n_trait_samples = n_trait_samples
        self.noise_level = noise_level
        self.huge_neg = huge_neg
        self.rng = np.random.default_rng(random_state)

        # Class priors
        self.prior_human = 0.6
        self.prior_bot   = 0.4

        # user_type sets and priors
        # humans: 0,1,2 ; bots: 3,4,5
        self.human_types = np.array([0, 1, 2])
        self.bot_types   = np.array([3, 4, 5])
        self.human_type_probs = {0: 0.4, 1: 0.4, 2: 0.2}
        self.bot_type_probs   = {3: 0.5, 4: 0.3, 5: 0.2}

        # trait means per user_type (from generator)
        self.motor_means = {
            0: 1.2,  # careful human
            1: 1.0,  # rushed human
            2: 0.8,  # distracted human
            3: 0.3,  # script bot
            4: 1.1,  # stealth bot
            5: 0.9,  # solver bot
        }
        self.visual_means = {
            0: 1.2,
            1: 1.1,
            2: 0.8,
            3: 0.2,
            4: 0.7,
            5: 1.4,
        }
        self.patience_means = {
            0: 1.3,
            1: 0.2,
            2: 0.5,
            3: -0.2,
            4: 0.8,
            5: 0.1,
        }

        # trait std dev = sqrt(0.5^2 + (0.2*noise)^2)
        self.trait_sd = np.sqrt(0.5**2 + (0.2 * noise_level)**2)

        # Feature noise scales (from generator; approximate)
        self.sigma_check   = 0.3 + 0.3 * noise_level         # Check_Time
        self.sigma_chal    = 0.8 + 0.5 * noise_level         # Challenge_Time
        self.sigma_entropy = 0.3 + 0.3 * noise_level         # Mouse_Path_Entropy
        self.sigma_click   = 0.4 + 0.3 * noise_level         # Click_Speed
        self.sigma_hist    = 3.0                             # History_Captcha_Count
        self.sigma_dev     = 0.8 + 0.4 * noise_level         # Device_Trust_Score

    # ------------------------------------------------------------------
    # sklearn API
    # ------------------------------------------------------------------
    def fit(self, X, y=None):
        """No-op: all parameters are fixed in __init__. Returns self."""
        return self

    def predict_proba(self, X):
        """
        Returns array of shape (n_samples, 2):
        [:,0] = P(Is_Human=0), [:,1] = P(Is_Human=1)

        X must be 2-D with 10 columns in the order listed in the class
        docstring.  Rows for which both class likelihoods collapse to
        ``huge_neg`` (invalid features) come out as 0.5 / 0.5 with the default
        ``huge_neg``, because the log-priors are absorbed by its magnitude.
        """
        X = np.asarray(X)
        n = X.shape[0]
        out = np.zeros((n, 2), dtype=float)

        for i, row in enumerate(X):
            log_p_x_h = self._log_p_x_given_class(row, is_human=1)
            log_p_x_b = self._log_p_x_given_class(row, is_human=0)

            # add priors
            log_h = log_p_x_h + np.log(self.prior_human)
            log_b = log_p_x_b + np.log(self.prior_bot)

            # Normalise in a numerically stable way (shift by the max).
            m = max(log_h, log_b)
            ph = np.exp(log_h - m)
            pb = np.exp(log_b - m)
            s = ph + pb
            out[i, 0] = pb / s
            out[i, 1] = ph / s

        return out

    def predict(self, X):
        """Return 1 (human) where P(Is_Human=1) >= 0.5, else 0 (bot)."""
        proba = self.predict_proba(X)
        return (proba[:, 1] >= 0.5).astype(int)

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------
    def _log_sum_exp(self, log_vals):
        """
        log(sum(exp(log_vals))) with stability.

        Returns ``huge_neg`` for an empty input, when every entry is -inf, or
        when the maximum is NaN / +inf (the sum cannot be evaluated).
        """
        log_vals = np.asarray(log_vals, dtype=float)
        if log_vals.size == 0:
            return self.huge_neg
        m = np.max(log_vals)
        if not np.isfinite(m):
            return self.huge_neg
        return m + np.log(np.sum(np.exp(log_vals - m)))

    def _log_mc_integrate(self, ll_vals):
        """
        log(mean(exp(ll_vals))) with stability.

        Samples with log-likelihood -inf contribute exp(-inf) = 0 to the mean
        but still count in the denominator; dropping them would bias the
        Monte Carlo estimate upward.  Returns ``huge_neg`` for an empty input,
        when every entry is -inf, or when the maximum is NaN / +inf.
        """
        ll_vals = np.asarray(ll_vals, dtype=float)
        if ll_vals.size == 0:
            return self.huge_neg
        m = np.max(ll_vals)
        if not np.isfinite(m):
            return self.huge_neg
        return m + np.log(np.mean(np.exp(ll_vals - m)))

    def _sample_traits_given_type(self, user_type, n_samples):
        """Sample (motor, visual, patience) from the prior for a given user_type."""
        mu_m = self.motor_means[user_type]
        mu_v = self.visual_means[user_type]
        mu_p = self.patience_means[user_type]
        sd = self.trait_sd

        motor = self.rng.normal(mu_m, sd, n_samples)
        visual = self.rng.normal(mu_v, sd, n_samples)
        patience = self.rng.normal(mu_p, sd, n_samples)
        return motor, visual, patience

    def _log_p_x_given_class(self, x, is_human):
        """
        log P(x | Is_Human=is_human) = log sum_t P(t|class) * ∫ p(x|traits,t,class)p(traits|t) dtraits
        We approximate ∫ via Monte Carlo, and sum over user_type exactly.

        Returns ``huge_neg`` when the features are non-finite or outside the
        supported range.  Raises ValueError if x does not have 10 entries.
        """
        x = np.asarray(x, dtype=float)
        if x.shape[0] != 10:
            raise ValueError("Expected 10 features per sample.")

        # NaN / inf cannot be scored (and would break the integer conversions
        # of the count features), so treat them like any other invalid input.
        if not np.all(np.isfinite(x)):
            return self.huge_neg

        check_time = x[0]
        challenge_time = x[1]
        challenge_errors = x[2]
        scroll_count = x[5]
        hist_count = x[7]

        # Sanity checks
        if check_time <= 0 or challenge_time <= 0:
            return self.huge_neg
        if hist_count < 0 or scroll_count < 0 or challenge_errors < 0:
            return self.huge_neg

        # Choose user_type set and priors
        if is_human == 1:
            types = self.human_types
            type_prior = self.human_type_probs
        else:
            types = self.bot_types
            type_prior = self.bot_type_probs

        # One term per type: log P(t|class) + log P(x|t,class)
        log_terms = []

        for t in types:
            # Monte Carlo over traits for this type
            motor, visual, patience = self._sample_traits_given_type(t, self.n_trait_samples)
            ll_total = self._log_likelihood_features(
                x, is_human, t, motor, visual, patience
            )
            log_p_x_given_t = self._log_mc_integrate(ll_total)
            log_terms.append(np.log(type_prior[t]) + log_p_x_given_t)

        # log sum_t P(t|class) * P(x|t,class): the type priors are already
        # folded into each term, so this is a plain log-sum-exp (not a mean).
        return self._log_sum_exp(log_terms)

    def _log_likelihood_features(self, x, is_human, user_type, motor, visual, patience):
        """
        Compute log-likelihood of all features given:
          - is_human (0/1)
          - user_type
          - arrays motor, visual, patience (same shape)
        Returns vector of log-likelihoods (one per Monte Carlo sample).

        If Challenge_Errors or Scroll_Count is not integer-valued, or
        History_Captcha_Count is negative, every entry is ``huge_neg``.
        """
        (check_time,
         challenge_time,
         challenge_errors,
         mouse_entropy,
         click_speed,
         scroll_count,
         hist_success_rate,
         hist_count,
         ip_susp,
         device_trust) = x

        motor   = np.asarray(motor)
        visual  = np.asarray(visual)
        patience = np.asarray(patience)

        # Validate the count features up front; the Poisson terms below are
        # only defined for integer counts.
        if not float(challenge_errors).is_integer():
            return np.full_like(motor, self.huge_neg)
        if not float(scroll_count).is_integer():
            return np.full_like(motor, self.huge_neg)
        if hist_count < 0:
            return np.full_like(motor, self.huge_neg)

        ll_total = np.zeros_like(motor, dtype=float)

        # Indicator flags for the user_type; used as 0/1 multipliers below.
        is_u0 = (user_type == 0)
        is_u1 = (user_type == 1)
        is_u2 = (user_type == 2)
        is_u3 = (user_type == 3)
        is_u4 = (user_type == 4)
        is_u5 = (user_type == 5)

        # ------------------------------------------------------------------
        # 1) Check_Time ~ approx Normal(mu_check, sigma_check),
        # with special behavior for script bots (user_type==3) via clipped mean.
        mu_check = (
            1.5
            - 0.5 * motor
            - 0.4 * patience
            + 0.3 * is_u2
            + 0.3 * is_u3
        )
        # For script bots, generator tends to compress check_time near ~0.15;
        # we approximate this by slightly pulling the mean toward 0.15.
        if is_u3:
            mu_check = 0.5 * mu_check + 0.5 * 0.15

        ll_total += norm.logpdf(check_time, loc=mu_check, scale=self.sigma_check)

        # ------------------------------------------------------------------
        # 2) Challenge_Time ~ Normal(mu_chal, sigma_chal)
        mu_chal = (
            4.0
            - 0.8 * visual
            - 0.3 * patience
            + 0.3 * is_u2
            + 0.2 * is_u4
        )
        ll_total += norm.logpdf(challenge_time, loc=mu_chal, scale=self.sigma_chal)

        # ------------------------------------------------------------------
        # 3) Challenge_Errors ~ Poisson(lambda_err), rate clipped to [0.05, 4]
        k_err = int(challenge_errors)
        log_lambda_err = (
            -0.7 * visual
            + 0.4 * is_u3
            - 0.6 * is_u5
        )
        lambda_err = np.clip(np.exp(log_lambda_err), 0.05, 4.0)
        ll_total += poisson.logpmf(k_err, mu=lambda_err)

        # ------------------------------------------------------------------
        # 4) Mouse_Path_Entropy ~ Normal(mu_entropy, sigma_entropy)
        mu_entropy = (
            1.0
            + 0.6 * motor
            + 0.4 * patience
            - 0.8 * is_u3
            + 0.3 * is_u4
        )
        ll_total += norm.logpdf(mouse_entropy, loc=mu_entropy, scale=self.sigma_entropy)

        # ------------------------------------------------------------------
        # 5) Click_Speed ~ Normal(mu_click, sigma_click)
        mu_click = (
            1.5
            + 0.4 * motor
            - 0.3 * patience
            + 0.5 * is_u3
            + 0.2 * is_u5
        )
        ll_total += norm.logpdf(click_speed, loc=mu_click, scale=self.sigma_click)

        # ------------------------------------------------------------------
        # 6) Scroll_Count ~ Poisson(lambda_scroll), rate clipped to [0.2, 20]
        k_scroll = int(scroll_count)
        log_lambda_scroll = (
            1.0
            + 1.0 * patience
            + 0.2 * is_u2
            - 0.7 * is_u3
        )
        lambda_scroll = np.clip(np.exp(log_lambda_scroll), 0.2, 20.0)
        ll_total += poisson.logpmf(k_scroll, mu=lambda_scroll)

        # ------------------------------------------------------------------
        # 7) History_Captcha_Count ~ approx Normal(mu_hist, sigma_hist)
        # The mean depends on user_type only, so this term is the same for
        # every Monte Carlo sample.
        mu_hist = (
            5.0
            + 4.0 * is_u0
            + 3.0 * is_u1
            + 1.0 * is_u2
            + 5.0 * is_u4
            + 7.0 * is_u5
        )
        ll_total += norm.logpdf(hist_count, loc=mu_hist, scale=self.sigma_hist)

        # ------------------------------------------------------------------
        # 8) History_Captcha_Success: Binomial(hist_trials, p_success)
        # The observed rate is converted back to a success count; at least one
        # trial is assumed so the binomial is defined when hist_count is 0.
        hist_trials = max(int(hist_count), 1)
        successes_obs = int(round(hist_success_rate * hist_trials))
        successes_obs = max(0, min(successes_obs, hist_trials))

        success_logit = (
            0.4 * visual
            + 0.3 * motor
            + 0.1 * patience
            - 0.4 * is_u3
            + 0.3 * is_u0
        )
        p_success = 1.0 / (1.0 + np.exp(-success_logit))
        # Keep probabilities away from 0/1 so the log-pmf stays finite.
        p_success = np.clip(p_success, 1e-4, 1.0 - 1e-4)
        ll_total += binom.logpmf(successes_obs, n=hist_trials, p=p_success)

        # ------------------------------------------------------------------
        # 9) IP_Suspicious: Bernoulli (any value other than 1 is scored as 0)
        ip_logit = (
            -0.7 * is_human
            + 0.3 * is_u3
            + 0.2 * is_u5
            + 0.1 * is_u4
        )
        p_ip = 1.0 / (1.0 + np.exp(-ip_logit))
        p_ip = np.clip(p_ip, 1e-4, 1.0 - 1e-4)
        if ip_susp == 1:
            ll_total += np.log(p_ip)
        else:
            ll_total += np.log(1.0 - p_ip)

        # ------------------------------------------------------------------
        # 10) Device_Trust_Score ~ Normal(mu_dev, sigma_dev)
        mu_dev = (
            0.8 * is_human
            - 0.5 * is_u3
            - 0.2 * is_u5
            + 0.3 * is_u0
        )
        ll_total += norm.logpdf(device_trust, loc=mu_dev, scale=self.sigma_dev)

        return ll_total