"""
White noise
===========

Random walk
-----------

Imagine that you repeatedly toss a coin. If it lands heads, you move one step
forward. If it lands tails, you move one step back. This is a random walk.
The function below simulates this process.

"""

import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from pylab import rcParams
# Global plot styling. The font is set through rcParams so that it actually
# applies to every figure; constructing a FontProperties object and discarding
# it has no effect on rendering.
rcParams['font.family'] = 'Helvetica'
rcParams['font.size'] = 11
# Figure size is given in inches: 14 cm x 14 cm (2.54 cm per inch).
rcParams['figure.figsize'] = 14/2.54, 14/2.54


def random_walk(w0, n):
    """Simulate a coin-toss random walk.

    Parameters
    ----------
    w0 : number
        Starting position.
    n : int
        Number of positions to record.

    Returns
    -------
    numpy.ndarray
        The ``n`` recorded positions, beginning with ``w0``; each position
        differs from the previous one by -1 or +1.
    """
    def _positions():
        here = w0
        for _ in range(n):
            # Record where we are before taking the next step.
            yield here
            # Toss the coin: one step back (-1) or one step forward (+1).
            here += random.choice([-1, 1])

    return np.array(list(_positions()))


##############################################################################
# Let's simulate this nine times and plot the output. 
# 

def plotOverTime(ax, w):
    """Plot the series ``w`` against its time index on the axes ``ax``.

    The series is drawn as a solid black line with the top and right spines
    hidden. With ``n = len(w)``, x ticks are placed every 25 units starting
    at 0, y ticks every 10 units from ``-n`` up to ``n``, and the view is
    limited to ``0 <= t <= n`` and ``-2*sqrt(n) <= w <= 2*sqrt(n)``.
    """
    steps = len(w)
    ax.plot(np.arange(steps), w, '-', color='k')
    ax.set_xlabel('Time: t')

    # Hide the box edges that carry no axis information.
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    half_height = 2*np.sqrt(steps)
    ax.set_xticks(np.arange(0, steps, step=25))
    ax.set_yticks(np.arange(-steps, steps+1, step=10))
    ax.set_xlim(0, steps)
    ax.set_ylim(-half_height, half_height)

# Draw nine independent coin-toss walks of length n on a 3x3 grid of axes.
fig, axs = plt.subplots(3, 3)
n = 100
for axsr in axs:
    for j, ax in enumerate(axsr):
        w = random_walk(0, n)
        plotOverTime(ax, w)
        # Only the left-most panel of each row gets a y-axis label.
        if j == 0:
            ax.set_ylabel('Position: w')

##############################################################################
# Notice that the random walk doesn't move very far. Over 100 time steps, it typically remains 
# within 20 steps of its starting position.
# 
#
# Normal noise
# ------------
# 
# In the above example the noise is generated by a coin toss. These outcomes 
# are distributed according to what is known as the Bernoulli distribution.
# Other distributions can be used to generate noise. A common choice is the Normal 
# (bell-shaped) distribution. 


def normal_walk(w0, n, std):
    """Simulate a random walk whose steps are Normally distributed.

    Parameters
    ----------
    w0 : number
        Starting position.
    n : int
        Number of positions to record.
    std : float
        Standard deviation of each zero-mean Normal step.

    Returns
    -------
    numpy.ndarray
        The ``n`` recorded positions, beginning with ``w0``.
    """
    # Draw one zero-mean Normal increment per time step.
    shocks = [np.random.normal(0, std) for _ in range(n)]

    trajectory = []
    position = w0
    for shock in shocks:
        # Record the current position, then apply this step's increment.
        trajectory.append(position)
        position += shock

    return np.array(trajectory)

# Draw nine independent Normal-noise walks (standard deviation 1) of length n
# on a 3x3 grid of axes.
fig, axs = plt.subplots(3, 3)
n = 100
for axsr in axs:
    for j, ax in enumerate(axsr):
        w = normal_walk(0, n, 1)
        plotOverTime(ax, w)
        # Only the left-most panel of each row gets a y-axis label.
        if j == 0:
            ax.set_ylabel('Position: w')

##############################################################################
# Notice that the step size changes each time step now, producing a less 
# jerky movement.
# 


