#!/usr/bin/env python

import argparse
import os
import glob

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Amplitude of the Gaussian well centred at (0, 5/3).
E1 = 4.0
# Amplitude of each of the two Gaussian wells centred at (1, 0) and (-1, 0).
E2 = 4.5

def _gaussian_terms(x, y):
    """Return the four Gaussian factors shared by the potential and its gradient.

    Order: bump at (0, 1/3), well at (0, 5/3), well at (1, 0), well at (-1, 0).
    Works elementwise on scalars or NumPy arrays of matching shape.
    """
    g_bump = np.exp(-x**2 - (y - 1.0/3)**2)
    g_top = np.exp(-x**2 - (y - 5.0/3)**2)
    g_right = np.exp(-(x - 1.0)**2 - y**2)
    g_left = np.exp(-(x + 1.0)**2 - y**2)
    return g_bump, g_top, g_right, g_left

def _potential(x, y):
    """Potential V(x, y): a Gaussian bump, three Gaussian wells, quartic confinement."""
    g_bump, g_top, g_right, g_left = _gaussian_terms(x, y)
    return (3.0*g_bump - E1*g_top - E2*g_right - E2*g_left
            + 2.0/10.0 * (x**4 + (y - 1.0/3)**4))

def v(x, y):
    """Evaluate the potential at (x, y).

    Elementwise for NumPy arrays, so it can be applied directly to a meshgrid.
    """
    return _potential(x, y)

def f(v):
    """Evaluate the potential at the point v = (x, y); same value as v(v[0], v[1]).

    The parameter name shadows the module-level function ``v``, so the shared
    helper is used instead of calling ``v`` directly.
    """
    return _potential(v[0], v[1])

def F(v):
    """Return the gradient (dV/dx, dV/dy) of the potential at v = (x, y).

    The integrator subtracts this vector in its drift term, i.e. the force
    acting on the particle is -F(v).
    """
    x = v[0]
    y = v[1]
    # Each exponential is evaluated once and reused by both components.
    g_bump, g_top, g_right, g_left = _gaussian_terms(x, y)

    Vx = (-6.0*x*g_bump + 2.0*E1*x*g_top
          + 2.0*E2*((x - 1.0)*g_right + (x + 1.0)*g_left)
          + 8.0/10 * x**3)

    Vy = (-6.0*(y - 1.0/3)*g_bump + 2.0*E1*(y - 5.0/3)*g_top
          + 2.0*E2*y*(g_right + g_left)
          + 8.0/10 * (y - 1.0/3)**3)

    return np.array([Vx, Vy])

def euler_propagation(v0, N, dt, beta, force=None, rng=None):
    """Integrate overdamped Langevin dynamics with the Euler-Maruyama scheme.

    Each step applies  x_{i+1} = x_i - force(x_i)*dt + sqrt(2*dt/beta) * xi_i,
    where xi_i is a vector of independent standard normal samples.

    Parameters
    ----------
    v0 : array_like, shape (dim,)
        Initial position; stored as the first row of the result.
    N : int
        Number of stored points including v0 (N - 1 steps are taken); must be >= 1.
    dt : float
        Time step.
    beta : float
        Inverse temperature; sets the noise amplitude sqrt(2*dt/beta).
    force : callable, optional
        Maps a position of shape (dim,) to the gradient subtracted in the drift.
        Defaults to the module-level gradient ``F``.
    rng : numpy.random.Generator, optional
        Source of the Gaussian increments. When omitted, the global NumPy
        random state (``np.random.randn``) is used.

    Returns
    -------
    numpy.ndarray, shape (N, dim)
        The trajectory, one position per row.

    Raises
    ------
    ValueError
        If N < 1.
    """
    if force is None:
        force = F
    if N < 1:
        raise ValueError("N must be >= 1, got {}".format(N))

    x0 = np.asarray(v0, dtype=float)
    dim = x0.shape[0]
    sigma = np.sqrt(2.0*dt/beta)
    # N rows are drawn although only N - 1 are used. This keeps the amount of
    # global random state consumed identical to earlier versions of this function.
    if rng is None:
        noise = np.random.randn(N, dim)
    else:
        noise = rng.standard_normal((N, dim))

    traj = np.zeros((N, dim))
    traj[0, :] = x0
    for i in range(N - 1):
        traj[i + 1, :] = traj[i, :] - force(traj[i, :])*dt + sigma*noise[i, :]
    return traj

def _parse_args(argv=None):
    """Parse and validate the command line; invalid values exit via parser.error."""
    descr_str = 'Brownian dynamics simulation in 2d potential'
    parser = argparse.ArgumentParser(description=descr_str)
    parser.add_argument('N', type=int, help='length of trajectory')
    parser.add_argument('beta', type=float, help='inverse temperature')
    # '-outfile' keeps the original single-dash spelling working; '--outfile'
    # is the conventional long-option form.
    parser.add_argument('-outfile', '--outfile', dest='outfile', type=str,
                        help='filename for the trajectory (.npy is written '
                             'with np.save, anything else as plain text)')
    parser.add_argument('--dt', type=float, default=1e-3,
                        help='time step (default: 1e-3)')
    parser.add_argument('--seed', type=int, default=None,
                        help="seed for NumPy's global random number generator")
    args = parser.parse_args(argv)

    # Validate before the (possibly long) simulation instead of failing in it.
    if args.N < 1:
        parser.error('N must be at least 1')
    if not args.beta > 0:  # also rejects NaN
        parser.error('beta must be positive')
    if not args.dt > 0:
        parser.error('dt must be positive')
    return args

def _save_trajectory(S, outfile):
    """Write trajectory S to outfile: .npy via np.save, any other name via np.savetxt."""
    if os.path.splitext(outfile)[1] == '.npy':
        np.save(outfile, S)
    else:
        # Non-.npy names get a plain-text file instead of being silently skipped.
        np.savetxt(outfile, S)

def _plot_trajectory(S):
    """Draw labelled contour lines of the potential v and overlay trajectory S."""
    nx = 100
    lx, Lx = -1.8, 1.8
    ly, Ly = -1.2, 2.2
    X, Y = np.meshgrid(np.linspace(lx, Lx, nx), np.linspace(ly, Ly, nx))
    Z = v(X, Y)
    levels = np.array([-3.5, -3.0, -2.5, -2.0, -1.5, -1.0, -0.5, 0.0])

    fig = plt.figure(0)
    ax = fig.add_subplot(111)
    ax.set_xlim(lx, Lx)
    ax.set_ylim(ly, Ly)
    ax.set_xlabel(r"$x$")
    ax.set_ylabel(r"$y$")
    cs = ax.contour(X, Y, Z, levels)
    ax.plot(S[:, 0], S[:, 1], c='k')
    plt.clabel(cs, inline=True, fontsize=10)
    plt.show()

def main(argv=None):
    """Run the simulation from (-1, 0), optionally save it, and plot it."""
    args = _parse_args(argv)
    if args.seed is not None:
        np.random.seed(args.seed)

    S0 = np.array([-1.0, 0.0])
    S = euler_propagation(S0, args.N, args.dt, args.beta)

    if args.outfile:
        _save_trajectory(S, args.outfile)

    _plot_trajectory(S)

if __name__ == "__main__":
    main()
