import numpy as np
from scipy.stats import norm, poisson, binom
from scipy.special import logsumexp
from tqdm import tqdm

class PerfectOracleBayesCaptcha:
    """
    Bayes oracle (exact up to numerical quadrature) for generate_captcha_behavior_dataset_v1.

    The constants below are intended to mirror that generator, which is not part
    of this file; verify them against the generator before trusting the output.

    - Enumerates user_type exactly per class.
    - Integrates traits (motor, visual, patience) with 3D Gauss–Hermite.
    - Integrates inner simulator noise (lognormal & logistic-normal) with 1D Gauss–Hermite.
    - Preserves point masses from clipping and the 'min' transform for script bots' Check_Time.

    Each input row must hold ten features in this order:
    check_time, challenge_time, challenge_errors, mouse_entropy, click_speed,
    scroll_count, hist_success_rate, hist_count, ip_susp, device_trust.

    sklearn-style API: fit (no-op), predict_proba, predict.

    Parameters
    ----------
    noise_level : float
        Scales the trait and feature noise standard deviations.
    class_prior_human : float
        Prior probability of the human class (class 1).
    gh_nodes_traits : int
        Gauss–Hermite nodes per trait dimension (the 3D grid has n**3 nodes).
    gh_nodes_noise : int
        Gauss–Hermite nodes for each inner 1D noise integral.
    tol : float
        Tolerance for "value equals an integer" and "value sits at the clip bound".
    random_state : object
        Stored for sklearn-style parameter symmetry; the oracle is deterministic
        and never reads it.
    """

    def __init__(self,
                 noise_level=0.5,
                 class_prior_human=0.6,
                 gh_nodes_traits=5,     # 7→9 for higher precision (3D grid => 343/729 nodes)
                 gh_nodes_noise=11,     # 11–25 is a good range for inner 1D integrals
                 tol=1e-9,
                 random_state=None):
        self.noise_level = noise_level
        self.prior_h = class_prior_human
        self.prior_b = 1 - class_prior_human
        self.gh_nodes_traits = gh_nodes_traits
        self.gh_nodes_noise = gh_nodes_noise
        self.tol = tol
        self.random_state = random_state  # kept only so the argument is not silently lost

        # User-type priors (types 0-2 are human, 3-5 are bot)
        self.human_types = np.array([0, 1, 2])
        self.bot_types   = np.array([3, 4, 5])
        self.type_prob_h = {0: 0.4, 1: 0.4, 2: 0.2}
        self.type_prob_b = {3: 0.5, 4: 0.3, 5: 0.2}

        # Trait priors (means per type); independent normals with a shared sd
        self.motor_mean   = {0:1.2, 1:1.0, 2:0.8, 3:0.3, 4:1.1, 5:0.9}
        self.visual_mean  = {0:1.2, 1:1.1, 2:0.8, 3:0.2, 4:0.7, 5:1.4}
        self.patience_mean= {0:1.3, 1:0.2, 2:0.5, 3:-0.2, 4:0.8, 5:0.1}
        # sd = sqrt(0.5^2 + (0.2*noise)^2)
        self.trait_sd = np.sqrt(0.5**2 + (0.2*self.noise_level)**2)

        # Feature noise scales
        self.sigma_check   = 0.3 + 0.3*self.noise_level
        self.sigma_chal    = 0.8 + 0.5*self.noise_level
        self.sigma_entropy = 0.3 + 0.3*self.noise_level
        self.sigma_click   = 0.4 + 0.3*self.noise_level
        self.sigma_hist    = 3.0
        self.sigma_dev     = 0.8 + 0.4*self.noise_level

        # Inner noise (on rates / logits)
        self.sigma_err_log = 0.3*self.noise_level                 # Challenge_Errors rate log-noise
        self.sigma_scroll_log = 0.4*self.noise_level              # Scroll_Count rate log-noise
        self.sigma_success_logit = 1.0 + 0.4*self.noise_level     # History_Captcha_Success logit noise
        self.sigma_ip_logit = 1.0 + 0.5*self.noise_level          # IP_Suspicious logit noise

        # Lower clip bounds of the continuous features
        self.check_L = 0.05
        self.chal_L  = 0.3
        self.entropy_L = 0.1
        self.click_L = 0.2

        self.err_L, self.err_U = 0.05, 4.0                       # lambda clip for errors
        self.scroll_L, self.scroll_U = 0.2, 20.0                 # lambda clip for scrolls
        self.success_p_L, self.success_p_U = 0.05, 0.95          # clipping of success prob
        self.script_check_mu, self.script_check_sd = 0.15, 0.05  # min-with Normal for script bots

        # 1D Gauss–Hermite nodes/weights for inner noise
        self._x1, self._w1 = np.polynomial.hermite.hermgauss(self.gh_nodes_noise)
        # 1D nodes/weights for traits (reused per dimension)
        self._xt, self._wt = np.polynomial.hermite.hermgauss(self.gh_nodes_traits)

        # Normalising constants of the Gauss–Hermite rule, in log space
        self._log_sqrt_pi  = 0.5*np.log(np.pi)
        self._log_pi_32    = 1.5*np.log(np.pi)  # (sqrt(pi))^3 in log

    # ------------------- sklearn API -------------------

    def fit(self, X, y=None):
        """No-op: every parameter is fixed at construction. Returns self."""
        return self

    def predict_proba(self, X):
        """
        Return an (n_rows, 2) array: column 0 is P(bot | x), column 1 is P(human | x).

        Rows that are impossible under both classes get the class priors
        instead of NaN.
        """
        X = np.asarray(X)
        out = np.zeros((X.shape[0], 2), dtype=float)
        # total= lets tqdm show a real progress bar; enumerate() hides the length.
        for i, row in tqdm(enumerate(X), total=X.shape[0]):
            out[i, 0], out[i, 1] = self._posterior_from_row(row)
        return out

    def predict(self, X):
        """Return 1 (human) where P(human | x) >= 0.5, else 0 (bot)."""
        return (self.predict_proba(X)[:, 1] >= 0.5).astype(int)

    def _posterior_from_row(self, row):
        """Return (P(bot | row), P(human | row)) for one feature row."""
        log_h = np.log(self.prior_h) + self._log_p_x_given_class(row, is_human=1)
        log_b = np.log(self.prior_b) + self._log_p_x_given_class(row, is_human=0)
        top = max(log_h, log_b)
        if top == -np.inf:
            # The row has zero likelihood under both classes (e.g. a negative or
            # non-integer count). -inf - (-inf) would be NaN, so the data carry
            # no information and the posterior is the prior.
            return self.prior_b, self.prior_h
        # Subtract the max before exponentiating to avoid underflow.
        w_h = np.exp(log_h - top)
        w_b = np.exp(log_b - top)
        total = w_h + w_b
        return w_b/total, w_h/total

    # ------------------- top-level class likelihood -------------------

    def _log_p_x_given_class(self, x, is_human):
        """log P(x | class) = log sum_t P(t|class) * ∫ p(x|t,traits)p(traits|t) dtraits"""
        if is_human == 1:
            types, type_probs = self.human_types, self.type_prob_h
        else:
            types, type_probs = self.bot_types, self.type_prob_b

        log_terms = [
            np.log(type_probs[t]) + self._log_integral_over_traits(x, is_human, t)
            for t in types
        ]
        # log-sum over user types
        return logsumexp(log_terms)

    # ------------------- 3D traits integral via GH -------------------

    def _log_integral_over_traits(self, x, is_human, user_type):
        """
        log ∫ p(x | type, traits) p(traits | type) dtraits on the GH product grid.

        value = (1/(sqrt(pi)^3)) * sum_{i,j,k} w_i w_j w_k * L(m_i, v_j, p_k),
        with each trait node placed at mean + sd*sqrt(2)*x_node.
        """
        mean_m = self.motor_mean[user_type]
        mean_v = self.visual_mean[user_type]
        mean_p = self.patience_mean[user_type]

        # Node offsets and log-weights are shared by all three dimensions.
        offsets = self.trait_sd * np.sqrt(2) * self._xt
        log_w = np.log(self._wt)

        log_terms = []
        for off_m, lw_m in zip(offsets, log_w):
            for off_v, lw_v in zip(offsets, log_w):
                for off_p, lw_p in zip(offsets, log_w):
                    logL = self._log_likelihood_given_traits(
                        x, is_human, user_type,
                        mean_m + off_m, mean_v + off_v, mean_p + off_p)
                    log_terms.append(lw_m + lw_v + lw_p + logL)

        # subtract log((sqrt(pi))^3)
        return logsumexp(log_terms) - self._log_pi_32

    # ------------------- per-sample, given traits -------------------

    def _as_count(self, value):
        """Return value as a non-negative int, or None if it is not one within tol."""
        if not np.isfinite(value) or value < 0:
            return None
        k = int(round(value))
        if abs(k - value) > self.tol:
            return None
        return k

    def _log_likelihood_given_traits(self, x, is_human, t, motor, visual, patience):
        """
        log p(x | user type t, motor, visual, patience): sum of the ten
        per-feature log-likelihoods. Returns -inf if a count feature is negative
        or non-integer, or if ip_susp does not round to 0 or 1.
        """
        (check_time,
         challenge_time,
         challenge_errors,
         mouse_entropy,
         click_speed,
         scroll_count,
         hist_success_rate,
         hist_count,
         ip_susp,
         device_trust) = x

        # Validate the discrete features first so impossible rows exit cheaply.
        # All three counts get the same integrality test.
        k_err = self._as_count(challenge_errors)
        k_scroll = self._as_count(scroll_count)
        k_hist = self._as_count(hist_count)
        if k_err is None or k_scroll is None or k_hist is None:
            return -np.inf
        if not np.isfinite(ip_susp):
            return -np.inf
        y_ip = int(round(ip_susp))
        if y_ip not in (0, 1):
            return -np.inf

        # 1) Check_Time
        mu_check = 1.5 - 0.5*motor - 0.4*patience + 0.3*(t==2) + 0.3*(t==3)
        if t == 3:  # script bots: Y = min(Y0, S), then clip at L
            log_like_check = self._log_pdf_min_normals_clipped(
                check_time,
                mu1=mu_check, sd1=self.sigma_check,
                mu2=self.script_check_mu, sd2=self.script_check_sd,
                L=self.check_L)
        else:  # non-script: clip normal at L
            log_like_check = self._log_pdf_normal_with_point_at_L(
                check_time, mu_check, self.sigma_check, L=self.check_L)

        # 2) Challenge_Time: Normal + clip at 0.3
        mu_chal = 4.0 - 0.8*visual - 0.3*patience + 0.3*(t==2) + 0.2*(t==4)
        log_like_chal = self._log_pdf_normal_with_point_at_L(
            challenge_time, mu_chal, self.sigma_chal, L=self.chal_L)

        # 3) Challenge_Errors: Poisson with lognormal rate and clipping on lambda
        # log_lambda = -0.7*visual + 0.4*(t==3) - 0.6*(t==5) + N(0, sigma_err_log)
        mu_err = -0.7*visual + 0.4*(t==3) - 0.6*(t==5)
        log_like_err = self._log_pois_lognorm_clipped(
            k_err, mu_err, self.sigma_err_log, L=self.err_L, U=self.err_U)

        # 4) Mouse_Path_Entropy: Normal + clip at 0.1
        mu_ent = 1.0 + 0.6*motor + 0.4*patience - 0.8*(t==3) + 0.3*(t==4)
        log_like_ent = self._log_pdf_normal_with_point_at_L(
            mouse_entropy, mu_ent, self.sigma_entropy, L=self.entropy_L)

        # 5) Click_Speed: Normal + clip at 0.2
        mu_clk = 1.5 + 0.4*motor - 0.3*patience + 0.5*(t==3) + 0.2*(t==5)
        log_like_clk = self._log_pdf_normal_with_point_at_L(
            click_speed, mu_clk, self.sigma_click, L=self.click_L)

        # 6) Scroll_Count: Poisson with lognormal rate, clipped [0.2, 20]
        mu_scroll = 1.0 + 1.0*patience + 0.2*(t==2) - 0.7*(t==3)
        log_like_scroll = self._log_pois_lognorm_clipped(
            k_scroll, mu_scroll, self.sigma_scroll_log,
            L=self.scroll_L, U=self.scroll_U)

        # 7) History_Captcha_Count: discretized Normal (clip<0 to 0; int via floor)
        mu_hist = 5.0 + 4.0*(t==0) + 3.0*(t==1) + 1.0*(t==2) + 5.0*(t==4) + 7.0*(t==5)
        log_like_hist = self._log_discrete_from_normal_intfloor(k_hist, mu_hist, self.sigma_hist)

        # 8) History_Captcha_Success: Binomial(H, p), p = clip(sigmoid(μ+σZ), [0.05,0.95])
        H = max(k_hist, 1)  # at least one trial, so a zero history still has a rate
        # Recover the success count from the observed rate, clamped to [0, H].
        s_obs = int(round(hist_success_rate * H))
        s_obs = min(max(s_obs, 0), H)
        mu_succ = 0.4*visual + 0.3*motor + 0.1*patience - 0.4*(t==3) + 0.3*(t==0)
        log_like_succ = self._log_binom_logistic_normal(
            s_obs, H, mu_succ, self.sigma_success_logit,
            pL=self.success_p_L, pU=self.success_p_U)

        # 9) IP_Suspicious: Bernoulli with p = clip(sigmoid(μ+σZ), [0.05,0.95])
        mu_ip = -0.7*is_human + 0.3*(t==3) + 0.2*(t==5) + 0.1*(t==4)
        log_like_ip = self._log_bernoulli_logistic_normal(
            y_ip, mu_ip, self.sigma_ip_logit, pL=0.05, pU=0.95)

        # 10) Device_Trust_Score: Normal
        mu_dev = 0.8*is_human - 0.5*(t==3) - 0.2*(t==5) + 0.3*(t==0)
        log_like_dev = norm.logpdf(device_trust, loc=mu_dev, scale=self.sigma_dev)

        return (log_like_check + log_like_chal + log_like_err + log_like_ent +
                log_like_clk + log_like_scroll + log_like_hist + log_like_succ +
                log_like_ip + log_like_dev)

    # ------------------- likelihood primitives -------------------

    def _log_pdf_normal_with_point_at_L(self, y, mu, sd, L):
        """
        Clipped-from-below normal: log-mass Φ((L-mu)/sd) at y <= L + tol,
        normal log-density above. logcdf keeps far-tail masses accurate
        instead of flooring them at log(1e-300).
        """
        if y <= L + self.tol:
            return norm.logcdf((L - mu)/sd)
        return norm.logpdf(y, loc=mu, scale=sd)

    def _log_pdf_min_normals_clipped(self, y, mu1, sd1, mu2, sd2, L):
        """
        Z = min(Y1, Y2), Yi ~ Normal; then clip Z below at L.
        For y > L: f_Z(y) = f1(y)*(1-F2(y)) + f2(y)*(1-F1(y))
        Mass at L: P(Z <= L) = 1 - (1-F1(L))*(1-F2(L)) = F1(L) + F2(L)*(1-F1(L))
        Everything is evaluated in log space so tails neither underflow nor cancel.
        """
        if y <= L + self.tol:
            z1 = (L - mu1)/sd1
            z2 = (L - mu2)/sd2
            # Sum-of-positive-terms form of the mass: no 1 - (...) cancellation.
            return np.logaddexp(norm.logcdf(z1), norm.logcdf(z2) + norm.logsf(z1))
        z1 = (y - mu1)/sd1
        z2 = (y - mu2)/sd2
        first_is_min = norm.logpdf(y, loc=mu1, scale=sd1) + norm.logsf(z2)
        second_is_min = norm.logpdf(y, loc=mu2, scale=sd2) + norm.logsf(z1)
        return np.logaddexp(first_is_min, second_is_min)

    def _noise_nodes(self):
        """Return (standard-normal nodes, log-weights) of the inner 1D GH rule."""
        return np.sqrt(2)*self._x1, np.log(self._w1)

    # ---- Poisson with clipped lognormal rate via GH expectation ----

    def _log_pois_lognorm_clipped(self, k, mu_log, sd_log, L, U):
        """
        log E_Z[ Pois(k | lambda(Z)) ] where lambda(Z)=clip(exp(mu_log + sd_log*Z), L, U),
        Z ~ N(0,1). Use Gauss-Hermite: 1/sqrt(pi) sum w_i * exp(log_pmf_i)
        """
        z, log_w = self._noise_nodes()
        lam = np.clip(np.exp(mu_log + sd_log*z), L, U)
        return logsumexp(log_w + poisson.logpmf(k, mu=lam)) - self._log_sqrt_pi

    # ---- Bernoulli / Binomial with logistic-normal via GH expectation ----

    def _log_bernoulli_logistic_normal(self, y, mu_logit, sd_logit, pL, pU):
        """log E_Z[ Bernoulli(y | p(Z)) ], p(Z) = clip(sigmoid(mu_logit + sd_logit*Z), pL, pU)."""
        z, log_w = self._noise_nodes()
        p = np.clip(1.0/(1.0 + np.exp(-(mu_logit + sd_logit*z))), pL, pU)
        log_p = np.log(p) if y == 1 else np.log1p(-p)
        return logsumexp(log_w + log_p) - self._log_sqrt_pi

    def _log_binom_logistic_normal(self, s, n, mu_logit, sd_logit, pL, pU):
        """log E_Z[ Binomial(s | n, p(Z)) ], p(Z) = clip(sigmoid(mu_logit + sd_logit*Z), pL, pU)."""
        z, log_w = self._noise_nodes()
        p = np.clip(1.0/(1.0 + np.exp(-(mu_logit + sd_logit*z))), pL, pU)
        return logsumexp(log_w + binom.logpmf(s, n=n, p=p)) - self._log_sqrt_pi

    # ---- Discrete floor(int) from normal with clip-to-0 ----

    @staticmethod
    def _log_normal_interval(a, b):
        """log(Φ(b) - Φ(a)) for a < b, computed from whichever tail is smaller."""
        if a > 0:
            # Upper tail: survival functions are small, so their difference is accurate.
            hi, lo = norm.logsf(a), norm.logsf(b)
        else:
            hi, lo = norm.logcdf(b), norm.logcdf(a)
        # log(e^hi - e^lo) = hi + log(1 - e^(lo - hi))
        return hi + np.log(-np.expm1(lo - hi))

    def _log_discrete_from_normal_intfloor(self, k, mu, sd):
        """
        log P(K=k) where K = max(floor(Y),0), Y~N(mu, sd^2).
        Returns -inf when k is negative or not an integer within tol.
        """
        k_int = self._as_count(k)
        if k_int is None:
            return -np.inf
        if k_int == 0:
            # K = 0 collects everything below 1: P(Y < 1)
            return norm.logcdf((1 - mu)/sd)
        return self._log_normal_interval((k_int - mu)/sd, (k_int + 1 - mu)/sd)