
# This file is part of a one-day mentoring program for high-school students.
# Experimental code! No warranty.
# Author: Tim Hempel, FU Berlin
# Date: 2022

import numpy as np
import matplotlib.pyplot as plt
from deeptime.data import triple_well_2d, custom_sde
import logging, signal
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from deeptime.data import triple_well_1d

# Game options
goal = np.array([0., 1.5]) # goal state location (x, y)
abstol = .1 # goal counts as reached when ||x - goal|| < abstol (see check_found)
n_steps = 20 # recorded simulation steps
traj_start = [-1., 0] # trajectory starting point (x, y); marked "START" by plot_landscape

# Span of the grid on which the potential is evaluated
# (not referenced by the functions below; presumably used by the notebook — confirm)
xymin, xymax = -3, 3

# Axis limits used by plot_landscape
xmin, xmax = -2.5, 2.5
ymin, ymax = -1.8, 2.7
# Read at import time, relative to the current working directory; presumably
# the picture wrapped into ``image_box`` by the notebook — confirm.
image = plt.imread('Gfp_and_fluorophore.png')


def gaussian_energy(x, amplitude=15, origin=(0., 0.), sigma=1):
    """ Potential energy of a Gaussian bump as a function of coordinate.

    E(x) = amplitude * exp(-||x - origin||^2 / (2 * sigma^2))

    Parameters
    ----------
    x : array_like, shape (N, dim) or (dim,)
        Coordinate(s). A single point yields an output of shape (1,).
    amplitude : float
        Height of the Gaussian at ``origin``.
    origin : array_like, shape (dim,)
        Centre of the Gaussian.
    sigma : float
        Width of the Gaussian.

    Returns
    -------
    ndarray, shape (N,) (or (1,) for a single point)
    """
    # ``[None]`` adds a leading axis so a single point is treated as a batch of one.
    diff = np.asarray(x) - np.asarray(origin)[None]
    sq_dist = np.sum(diff ** 2, axis=-1)
    return amplitude * np.exp(-sq_dist / (2 * sigma ** 2))

def gaussian_force(x, amplitude=15, origin=(0., 0.), sigma=1):
    """ Force of a Gaussian bump acting on a particle, as a function of coordinate.

    F(x) = -dE/dx = (amplitude / sigma^2) * exp(-||x - origin||^2 / (2 sigma^2)) * (x - origin),
    i.e. the particle is pushed away from ``origin``.

    Parameters
    ----------
    x : array_like, shape (N, dim) or (dim,)
        Coordinate(s) of the particle(s).
    amplitude, origin, sigma :
        Same meaning as in ``gaussian_energy``.

    Returns
    -------
    ndarray with the same shape as ``x - origin``.
    """
    diff = np.asarray(x) - np.asarray(origin)
    # keepdims=True keeps the per-point prefactor as a column (N, 1), so it
    # scales each point's displacement. Without it an (N,) prefactor would be
    # broadcast along the coordinate axis: wrong for N == dim, an error otherwise.
    sq_dist = np.sum(diff ** 2, axis=-1, keepdims=True)
    prefactor = (amplitude / sigma ** 2) * np.exp(-sq_dist / (2 * sigma ** 2))
    return prefactor * diff


_triplewell = triple_well_2d(n_steps=100)


def triplewell_energy(x):
    """ Potential energy of the deeptime 2D triple-well system at ``x``.

    Thin wrapper around ``_triplewell.potential``.
    """
    potential_fn = _triplewell.potential
    return potential_fn(x)


def triplewell_force(x):
    """ Force of the triple well acting on a particle at ``x``.

    Thin wrapper around the deeptime system's ``f(t, x)``, evaluated with ``t = 0``.
    """
    rhs_fn = _triplewell.f
    return rhs_fn(0, x)
    

def harmonic_sphere_energy(x, origin=(0., 0.), radius=.1, k=1.):
    """ Energy of a harmonic wall around a sphere.

    Zero for points within ``radius`` of ``origin``. Outside, the energy is
    ``0.5 * k * d**2``, where ``d`` is the distance to the sphere surface.

    Parameters
    ----------
    x : array_like, shape (N, dim) or (dim,)
        Coordinate(s). A single point yields an output of shape (1,).
    origin : array_like, shape (dim,)
        Centre of the sphere.
    radius : float
        Sphere radius.
    k : float
        Spring constant of the wall.

    Returns
    -------
    ndarray, shape (N,) (or (1,) for a single point)
    """
    dist_to_origin = np.linalg.norm(np.asarray(x) - np.asarray(origin)[None], axis=-1)
    dist_to_sphere = dist_to_origin - radius
    # Size the output from the distances, not from len(x): for a single point
    # len(x) is the dimension, not the number of points. ``> 0`` is False for
    # NaN, so such entries stay at zero energy as before.
    return np.where(dist_to_sphere > 0, 0.5 * k * dist_to_sphere ** 2, 0.)


def harmonic_sphere_force(x, origin=(0., 0.), radius=.1, k=1.):
    """ Force of the harmonic wall around a sphere, for a single point ``x``.

    Inside the sphere (distance to ``origin`` <= ``radius``) the force is zero.
    Outside, it points back towards ``origin`` with magnitude ``k * d``, where
    ``d`` is the distance to the sphere surface.

    Parameters
    ----------
    x : array_like, shape (dim,)
        Position of a single particle.
    origin : array_like, shape (dim,)
        Centre of the sphere.
    radius : float
        Sphere radius.
    k : float
        Spring constant of the wall.

    Returns
    -------
    ndarray, shape (dim,)
        Always an array, so it can be added to other forces with ``+``.
    """
    diff = np.asarray(x, dtype=float) - np.asarray(origin, dtype=float)
    dist_to_origin = np.linalg.norm(diff)
    dist_to_sphere = dist_to_origin - radius
    if dist_to_sphere > 0:
        return -k * dist_to_sphere * diff / dist_to_origin
    # Zero force with the right dimension (previously a 2-element list).
    return np.zeros_like(diff)
    
    
def total_energy(x, gaussians=(), **sphere_kw):
    """ Total potential energy: 2D triple well plus all Gaussian biases.

    Parameters
    ----------
    x : array_like
        Coordinate(s) passed to ``triplewell_energy`` and ``gaussian_energy``.
    gaussians : iterable of dict
        One dict of keyword arguments for ``gaussian_energy`` per bias.
    **sphere_kw
        Accepted for interface compatibility. Currently unused, because the
        harmonic confinement sphere is disabled.

    Returns
    -------
    Triple-well energy plus the sum of the Gaussian bias energies.
    """
    # The harmonic confinement is disabled. To enable it, add
    # harmonic_sphere_energy(x, **sphere_kw) to the result.
    e = triplewell_energy(x)
    for g_kw in gaussians:
        # Out-of-place add, so the array returned by deeptime is never modified.
        e = e + gaussian_energy(x, **g_kw)
    return e

def total_force(x, gaussians=(), **sphere_kw):
    """ Total force: 2D triple-well force plus the force of every Gaussian bias.

    Parameters
    ----------
    x : array_like
        Position passed to ``triplewell_force`` and ``gaussian_force``.
    gaussians : iterable of dict
        One dict of keyword arguments for ``gaussian_force`` per bias.
    **sphere_kw
        Accepted for interface compatibility. Currently unused, because the
        harmonic confinement sphere is disabled.

    Returns
    -------
    Sum of all force contributions.
    """
    # The harmonic confinement is disabled. To enable it, add
    # harmonic_sphere_force(x, **sphere_kw) to the result.
    f = triplewell_force(x)
    for g_kw in gaussians:
        # Out-of-place add: avoids modifying deeptime's returned array. If f
        # were a list, ``+=`` would extend it instead of adding element-wise.
        f = f + gaussian_force(x, **g_kw)
    return f

_triplewell_1d = triple_well_1d(h=1e-3, n_steps=100)


def triplewell_1d_energy(x):
    """ Potential energy of the deeptime 1D triple-well system at ``x``."""
    energy_fn = _triplewell_1d.potential
    return energy_fn(x)


def triplewell_1d_force(x):
    """ Force of the 1D triple well at ``x``.

    Evaluates the system's right-hand side with time argument ``1.``.
    NOTE(review): this goes through the private ``_impl`` attribute of the
    deeptime object and may break with other deeptime versions. The 2D wrapper
    uses the public ``f(t, x)``; check whether the same works here.
    """
    rhs_fn = _triplewell_1d._impl.rhs
    return rhs_fn(1., x)

def total_energy_1d(x, gaussians=(), **sphere_kw):
    """ Total potential energy in 1D: triple well plus all Gaussian biases.

    Parameters
    ----------
    x : array_like or scalar
        1D position(s).
    gaussians : iterable of dict
        One dict of keyword arguments for ``gaussian_energy`` per bias.
    **sphere_kw
        Accepted for interface compatibility. Unused.

    Returns
    -------
    Triple-well energy plus the sum of the Gaussian bias energies.
    """
    e = triplewell_1d_energy(x)
    # gaussian_energy expects points along the first axis and coordinates along
    # the last axis, so positions of shape (N,) are turned into a column (N, 1).
    # np.asarray also lets plain Python floats through.
    x_col = np.asarray(x)[None].T
    for g_kw in gaussians:
        # Out-of-place add, so the array returned by deeptime is never modified.
        e = e + gaussian_energy(x_col, **g_kw)
    return e

def total_force_1d(x, gaussians=(), **sphere_kw):
    """ Total force in 1D: triple-well force plus the force of every Gaussian bias.

    Parameters
    ----------
    x : array_like
        1D position of the particle.
    gaussians : iterable of dict
        One dict of keyword arguments for ``gaussian_force`` per bias.
    **sphere_kw
        Accepted for interface compatibility. Unused.

    Returns
    -------
    Sum of all force contributions.
    """
    f = triplewell_1d_force(x)
    for g_kw in gaussians:
        # Out-of-place add: avoids modifying the returned array. If f were a
        # list, ``+=`` would extend it instead of adding element-wise.
        f = f + gaussian_force(x, **g_kw)
    return f

class DelayedKeyboardInterrupt:
    """ Context manager that postpones SIGINT (e.g. a Jupyter notebook
    "interrupt") until the ``with`` body has finished.

    Inside the block, a SIGINT is only recorded. On exit, the previous handler
    is restored and, if a signal arrived, it is passed to that handler:
    - a callable handler (e.g. Python's default, which raises
      KeyboardInterrupt) is called with the recorded arguments;
    - ``SIG_DFL`` is emulated by raising KeyboardInterrupt;
    - ``SIG_IGN`` leaves the signal ignored.

    Must be used from the main thread, because ``signal.signal`` only works there.
    """

    def __enter__(self):
        self.signal_received = False
        self.old_handler = signal.signal(signal.SIGINT, self.handler)
        return self

    def handler(self, sig, frame):
        # Record the signal instead of interrupting the running code.
        self.signal_received = (sig, frame)
        logging.debug('SIGINT received. Delaying KeyboardInterrupt.')

    def __exit__(self, exc_type, exc_value, traceback):
        signal.signal(signal.SIGINT, self.old_handler)
        if not self.signal_received:
            return
        old_handler = self.old_handler
        if callable(old_handler):
            old_handler(*self.signal_received)
        elif old_handler == signal.SIG_DFL:
            # SIG_DFL is not callable; approximate the default action.
            raise KeyboardInterrupt
        # Otherwise (SIG_IGN) the signal was meant to be ignored.


def pltsin(ax, plot_arr, origin, image_box=None, **kw):
    """
    Continuously updating plot function for the 2D trajectory.

    The first call for an axes draws the trajectory line and a circle at the
    current position. Later calls only move these artists and then redraw.

    Parameters
    ----------
    ax : matplotlib Axes
    plot_arr : ndarray, shape (T, 2)
        Trajectory so far; the last row is the current position.
    origin :
        Unused; kept for interface compatibility.
    image_box : OffsetImage, optional
        If given, drawn at the current position and replaces the image drawn
        by the previous call.
    **kw
        Forwarded to ``ax.plot`` when the line is first created.
    """
    # Tag our artists with a gid so we can find them again by identity rather
    # than by position in ax.lines/ax.patches/ax.artists. Removing
    # ax.artists[0] would delete the goal image added by plot_landscape.
    line_gid = 'pltsin-trajectory'
    marker_gid = 'pltsin-marker'
    image_gid = 'pltsin-image'
    position = plot_arr[-1]

    line = next((ln for ln in ax.lines if ln.get_gid() == line_gid), None)
    if line is not None:
        line.set_xdata(plot_arr[:, 0])
        line.set_ydata(plot_arr[:, 1])
        marker = next((p for p in ax.patches if p.get_gid() == marker_gid), None)
        if marker is not None:
            marker.set_center(position)
    else:
        line = ax.plot(*plot_arr.T, 'k', **kw)[0]
        line.set_gid(line_gid)
        marker = plt.Circle(position, .1)
        marker.set_gid(marker_gid)
        ax.add_patch(marker)

    if image_box is not None:
        # Drop the image from the previous call before adding the new one.
        for artist in [a for a in ax.artists if a.get_gid() == image_gid]:
            artist.remove()
        ab = AnnotationBbox(image_box, position, frameon=False)
        ab.set_gid(image_gid)
        ax.add_artist(ab)

    ax.figure.canvas.draw()
    
    

def plot_landscape(_potential_landscape, xy, ax=None, image_box=None, levels=10):
    """ Draw the 2D energy landscape together with the goal and start markers.

    Parameters
    ----------
    _potential_landscape : ndarray
        Energy values on the grid spanned by ``xy``. Reshaped to
        ``(len(xy), len(xy))`` before contouring.
    xy : ndarray, shape (n,)
        Grid coordinates, used for both the x and the y axis.
    ax : matplotlib Axes, optional
        Axes to draw on. A new 5x5 inch figure is created if omitted.
    image_box : OffsetImage, optional
        Image drawn at the goal. If omitted, the goal is labelled with text.
    levels : int
        Number of contour levels, evenly spaced over the energy range [-4, 5].

    Returns
    -------
    The axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(5, 5))
    n_grid = xy.shape[0]
    potential_landscape = _potential_landscape.reshape((n_grid, n_grid))

    ax.contourf(xy, xy, potential_landscape, levels=np.linspace(-4, 5, levels), cmap='Greys')

    # NOTE(review): the drawn goal radius is sqrt(abstol) (~0.32 for abstol=.1),
    # but check_found() only accepts ||x - goal|| < abstol. Confirm which is intended.
    goal_circl = plt.Circle(goal, np.sqrt(abstol), fc='none', ec='red', linewidth=3)
    ax.add_patch(goal_circl)
    if image_box:
        ab = AnnotationBbox(image_box, goal, frameon=False)
        ax.add_artist(ab)
    else:
        ax.annotate('ZIEL\n', goal)

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    ax.scatter(*traj_start, marker='x', s=100, linewidth=3)
    ax.annotate('START\n', traj_start)
    return ax
    
def plot_triplewell_1d(x, gaussians, ax=None, show_constituents=False):
    """ Plot the biased 1D triple-well energy, optionally with its parts.

    Parameters
    ----------
    x : ndarray, shape (N,)
        Positions at which the energy is evaluated.
    gaussians : iterable of dict
        Bias definitions, one dict of ``gaussian_energy`` kwargs per bias.
    ax : matplotlib Axes, optional
        Axes to draw on. A new figure is created if omitted.
    show_constituents : bool
        If True, also draw the unbiased triple well and every bias separately,
        and add a legend.

    Returns
    -------
    The axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(x, total_energy_1d(x, gaussians=gaussians), 'magenta', linewidth=3, label='sum')

    if show_constituents:
        ax.plot(x, triplewell_1d_energy(x), 'k--', label='original')

        for n, g_kw in enumerate(gaussians):
            # Draw on ``ax`` rather than the pyplot "current" axes, so a
            # caller-supplied ``ax`` gets all curves and the legend.
            ax.plot(x, gaussian_energy(x[None].T, **g_kw), 'r:', label=f'bias {n}')
        ax.legend()
    return ax

def pltsin1d(ax, plot_arr, origin, gaussians=(), image_box=None, **kw):
    """
    Continuously updating plot function for a 1D trajectory on its energy curve.

    The trajectory is drawn as a black line following the total energy
    (triple well + ``gaussians``), and a circle marks the current position.
    The first call for an axes creates these artists. Later calls move them.

    Parameters
    ----------
    ax : matplotlib Axes
        E.g. the axes returned by ``plot_triplewell_1d``.
    plot_arr : ndarray
        Positions visited so far. The current position is read as
        ``plot_arr[-1][0]``, so presumably shape (T, 1) — confirm.
    origin :
        Unused; kept for interface compatibility.
    gaussians : iterable of dict
        Bias definitions forwarded to ``total_energy_1d``.
    image_box : OffsetImage, optional
        If given, drawn at the current point on the energy curve and replaces
        the image drawn by the previous call.
    **kw
        Forwarded to ``ax.plot`` when the line is first created.
    """
    # Tag our artists with a gid so they are found by identity. Counting
    # ax.lines fails when plot_triplewell_1d(show_constituents=True) has already
    # added extra curves.
    line_gid = 'pltsin1d-trajectory'
    marker_gid = 'pltsin1d-marker'
    image_gid = 'pltsin1d-image'

    traj_energy = total_energy_1d(np.squeeze(plot_arr), gaussians=gaussians)
    x_now = plot_arr[-1][0]
    # Reduce the energy of the single current point to a plain float, so it
    # can be used as a 2D (x, y) anchor.
    center = (x_now, float(np.squeeze(total_energy_1d(x_now, gaussians=gaussians))))

    line = next((ln for ln in ax.lines if ln.get_gid() == line_gid), None)
    if line is not None:
        line.set_xdata(plot_arr)
        line.set_ydata(traj_energy)
        marker = next((p for p in ax.patches if p.get_gid() == marker_gid), None)
        if marker is not None:
            marker.set_center(center)
    else:
        line = ax.plot(plot_arr, traj_energy, 'k', **kw)[0]
        line.set_gid(line_gid)
        marker = plt.Circle(center, .1)
        marker.set_gid(marker_gid)
        ax.add_patch(marker)

    if image_box is not None:
        for artist in [a for a in ax.artists if a.get_gid() == image_gid]:
            artist.remove()
        # Anchor at (x, E(x)). plot_arr[-1] holds only the x coordinate, which
        # is not a valid 2D anchor.
        ab = AnnotationBbox(image_box, center, frameon=False)
        ab.set_gid(image_gid)
        ax.add_artist(ab)

    ax.figure.canvas.draw()
    

def check_found(x):
    """ Check whether position ``x`` has reached the goal.

    The goal counts as reached when ``||x - goal|| < abstol``. In that case the
    hex-encoded solution message is printed.
    NOTE(review): plot_landscape draws the goal circle with radius sqrt(abstol),
    which is larger than the region accepted here. Confirm which is intended.

    Parameters
    ----------
    x : array_like, same shape as ``goal``
        Current position.

    Returns
    -------
    bool
        True if the goal is reached, False otherwise.
    """
    # The solution is hex-encoded so it cannot be read directly in the source.
    # It was generated with e.g.:
    #   asdf = 'Antwort.'
    #   print(asdf.encode().hex())
    if np.linalg.norm(x - goal) < abstol:
        print(bytes.fromhex('4a41574f484c212044696520416e74776f7274206973742034322e').decode('ascii'))
        return True
    return False


