from turtle import pos
from scipy import stats
import math
import numpy as np
import matplotlib.pyplot as plt
import pprint

# Module-level pretty-printer with a 4-space indent.
# NOTE(review): not referenced by any code visible in this file; confirm
# whether it is still needed.
pp = pprint.PrettyPrinter(indent=4)


def main():
    """
    Run the demo end to end.

    Start from the population prior, fold in three observed test results
    one at a time (each posterior becomes the next prior), then plot the
    final belief.
    """
    observations = [
        {'size': 0.7, 'correct': False},
        {'size': 0.8, 'correct': False},
        {'size': 0.2, 'correct': True},
    ]
    belief = get_prior_belief()
    for observation in observations:
        belief = update_belief(belief, observation)
    plot(belief)

def update_belief(prior, observation):
    """
    Apply Bayes' rule for a single observation.

    Args:
        prior: dict mapping each ability level to its current probability.
            It is not modified.
        observation: dict describing one test result, passed through to
            p_observation_given_ability.

    Returns:
        A new dict with the same keys. Each prior probability is multiplied
        by the observation's likelihood under that ability, and the results
        are normalized to sum to one.
    """
    posterior = {
        ability: p_observation_given_ability(ability, observation) * weight
        for ability, weight in prior.items()
    }
    normalize(posterior)
    return posterior

def p_observation_given_ability(ability, observation):
    """
    Return the likelihood of one observation given a patient's ability.

    Args:
        ability: the hypothesized ability level.
        observation: dict with 'size' (the font size, forwarded to
            p_correct_given_ability) and 'correct' (truthy when the patient
            identified the letter).

    Returns:
        P(correct | ability, size) if the answer was correct, otherwise
        1 - P(correct | ability, size).
    """
    p_correct = p_correct_given_ability(ability, observation['size'])
    return p_correct if observation['correct'] else 1 - p_correct

# Item-response model parameters used by p_correct_given_ability.
GUESS = 0.05    # chance of a correct answer when the response ignores ability
SLIP = 0.01     # mixture weight of that ability-independent response
SCALING = 10    # steepness of the logistic curve in (ability - difficulty)
def p_correct_given_ability(ability, font_size):
    """
    Probability that a patient of the given ability reads a letter correctly.

    This is an item-response-theory model mixed with a slip/guess term.
    With weight SLIP the answer is independent of ability and correct with
    probability GUESS. With weight 1 - SLIP it follows a logistic curve in
    (ability - difficulty), where difficulty = 1 - font_size, so larger
    letters are easier.
    """
    p_correct_if_attentive = sigmoid(SCALING * (ability - (1 - font_size)))
    p_correct_if_slipped = GUESS
    return SLIP * p_correct_if_slipped + (1 - SLIP) * p_correct_if_attentive

def get_prior_belief():
    """
    Build the prior distribution over ability for the general population.

    The shape reportedly comes from a historical study. Abilities are the
    101 evenly spaced points 0.00, 0.01, ..., 1.00. Each point is weighted
    by scipy's gumbel_r density (loc=0.03, scale=0.3) evaluated at
    1 - ability, and the weights are then normalized to sum to one.

    Returns:
        dict mapping each ability (numpy float) to its prior probability.
    """
    abilities = np.linspace(0, 1, num=101)
    prior = {
        ability: stats.gumbel_r.pdf(1 - ability, 0.03, 0.3)
        for ability in abilities
    }
    normalize(prior)
    return prior

def normalize(belief):
    """
    Rescale a belief in place so that its values sum to one.

    Args:
        belief: dict mapping each hypothesis (e.g. an ability level) to an
            unnormalized weight. Modified in place. An empty dict is left
            unchanged.

    Returns:
        None.

    Raises:
        ValueError: if a non-empty belief's weights sum to zero, since no
            rescaling can turn them into a probability distribution.
    """
    if not belief:
        return
    total = sum(belief.values())
    if total == 0:
        # Without this guard, Python floats raise ZeroDivisionError and numpy
        # floats silently turn every entry into nan.
        raise ValueError("cannot normalize a belief whose weights sum to zero")
    # Reassigning existing keys while iterating is safe; the dict size is
    # unchanged.
    for key in belief:
        belief[key] /= total

# NOTE: sigmoid no longer needs this clamp value because it cannot overflow;
# the constant is kept in case other code imports it.
EPSILON = 0.00001
def sigmoid(x):
    """
    Numerically stable logistic function 1 / (1 + exp(-x)).

    Args:
        x: a real number (Python or numpy scalar).

    Returns:
        A float in [0, 1] that increases monotonically with x. For extreme
        inputs the result may round to exactly 0.0 or 1.0.
    """
    # Choose the formula by sign so math.exp only ever sees a non-positive
    # argument. math.exp(-x) raises OverflowError once -x exceeds ~709, and
    # the old fallback value (EPSILON) broke monotonicity there.
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    z = math.exp(x)
    return z / (1 + z)

def pretty_print(belief):
    """
    Print a belief as "ability, probability" lines in ascending ability order.

    Args:
        belief: dict mapping numeric ability values to probabilities. Any set
            of numeric keys works (for example the 101-point grid built by
            get_prior_belief); keys are not assumed to lie on a fixed grid.
    """
    for ability in sorted(belief):
        print(f'{ability:.2f}, {belief[ability]:.5f}')
        
def plot(belief):
    """
    Show a scatter plot of a belief.

    Ability values go on the x-axis and their probabilities on the y-axis,
    drawn with small markers (s=2). Blocks in plt.show() until the window
    is closed.
    """
    abilities = list(belief.keys())
    probabilities = list(belief.values())
    plt.scatter(abilities, probabilities, s=2)
    plt.show()

# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()

