"""
Created on Fri May  5 15:00:12 2023

@author: Martin
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button

# No inline plotting: show the figure in its own window so the sliders are
# interactive. `%matplotlib auto` is IPython-only syntax (a SyntaxError under
# plain `python`), so invoke the magic through the IPython API when it exists.
try:
    get_ipython().run_line_magic("matplotlib", "auto")  # noqa: F821
except NameError:
    pass  # not running under IPython/Spyder: keep matplotlib's default backend


# Default parameters in SI units; the sliders below start from these values.
d = 10e-6        # slit width [m]
Lambda = 60e-6   # slit spacing [m]
lamb = 600e-9    # wavelength [m]
z = 10           # z-position (propagation distance) [m]
N = 7            # number of slits
x = np.linspace(-np.pi/3,np.pi/3,2000)  # observation coordinate (enters as x/z)

k = 2*np.pi/lamb  # wave number [1/m]

def sinc(d,k,x,z):
    """Return the normalised single-slit intensity envelope.

    Computes ``(sin(beta) / beta)**2`` with ``beta = d/2 * k * x / z``.

    Parameters
    ----------
    d : float
        Slit width.
    k : float
        Wave number ``2*pi/lambda``.
    x : float or array_like
        Observation coordinate(s).
    z : float
        Propagation distance (same length unit as ``x``).

    Returns
    -------
    float or ndarray
        Intensity in ``[0, 1]``; exactly 1 where ``beta == 0``.
    """
    beta = d / 2 * k * np.asarray(x, dtype=float) / z
    # np.sinc(t) = sin(pi*t)/(pi*t) with the removable singularity at t == 0
    # filled in as 1, so this is sin(beta)/beta without the 0/0 -> nan at x == 0.
    return np.sinc(beta / np.pi) ** 2

def sinus(d,k,x,z,N,Lambda):
    """Return the normalised N-slit diffraction intensity.

    Computes ``(sin(N*alpha)/sin(alpha) * sin(beta)/beta)**2 / N**2`` with
    ``alpha = Lambda/2 * k * x / z`` and ``beta = d/2 * k * x / z``: the
    multi-slit interference factor modulated by the single-slit envelope.

    Parameters
    ----------
    d : float
        Slit width.
    k : float
        Wave number ``2*pi/lambda``.
    x : float or array_like
        Observation coordinate(s).
    z : float
        Propagation distance (same length unit as ``x``).
    N : int
        Number of slits (non-zero).
    Lambda : float
        Slit spacing.

    Returns
    -------
    float or ndarray
        Intensity normalised so that the central maximum (``x == 0``) is 1.
    """
    x = np.asarray(x, dtype=float)
    alpha = Lambda / 2 * k * x / z
    beta = d / 2 * k * x / z

    # At principal maxima sin(alpha) -> 0 and sin(N*alpha)/sin(alpha) -> +-N,
    # so the squared ratio tends to N**2. Evaluated literally it is 0/0 (nan at
    # x == 0, and at every x when Lambda == 0) or rounding noise very close to
    # a maximum, so use the limit there. The substitution error is of order
    # N**2 * tol**2 and therefore negligible.
    tol = 1e-9
    den = np.sin(alpha)
    near_max = np.abs(den) < tol
    ratio = np.sin(N * alpha) / np.where(near_max, 1.0, den)
    grating_sq = np.where(near_max, float(N) ** 2, ratio ** 2)

    # np.sinc(t) = sin(pi*t)/(pi*t) and equals 1 at t == 0, i.e. sin(beta)/beta
    # without the removable singularity at beta == 0.
    envelope_sq = np.sinc(beta / np.pi) ** 2
    return grating_sq * envelope_sq / N**2



# Build the figure: one axes, with the bottom quarter left free for the sliders.
fig, ax = plt.subplots(figsize=(8,6))
fig.subplots_adjust(bottom=0.25)
ax.set_xlabel("$x$")
ax.set_ylabel("$|u(x,y,z)|^2$")
ax.set_ylim(0,1.2)
# Inward-pointing ticks on both axes.
ax.tick_params(axis="both", direction="in")

# Two curves initialised to all ones: the envelope ("Einhüllende") and the
# diffraction pattern ("Beugungsmuster"); their y-data is replaced by update_wave().
line_sinc, = ax.plot(x, np.ones_like(x), label="Einhüllende")
line_beug, = ax.plot(x, np.ones_like(x), label="Beugungsmuster")
ax.legend(loc="upper left")


# update the plot
def update_wave(val=None):
    """Recompute and redraw both curves from the current slider values.

    Registered as the ``on_changed`` callback of every slider and also called
    once directly to draw the initial state.

    Parameters
    ----------
    val : float, optional
        Value passed by the slider that fired; unused, because all sliders are
        read directly.
    """
    # The sliders show convenient units; convert them to SI (metres).
    d = d_slider.val * 1e-6            # slit width, µm -> m
    Lambda = Lambda_slider.val * 1e-6  # slit spacing, µm -> m
    lamb = lamb_slider.val * 1e-9      # wavelength, nm -> m
    z = z_slider.val                   # z-position, already in m
    N = int(N_slider.val)              # number of slits, truncated to int
    k = 2*np.pi/lamb

    line_sinc.set_ydata(sinc(d,k,x,z))
    line_beug.set_ydata(sinus(d,k,x,z,N,Lambda))

    # Schedule a redraw instead of blocking inside the slider callback.
    fig.canvas.draw_idle()


    
# Slider axes in figure coordinates [left, bottom, width, height]: slit
# geometry and wavelength on the left, z-position and slit count on the right.
axwave_d = fig.add_axes([0.15, 0.15, 0.3, 0.03])
axwave_Lambda = fig.add_axes([0.15, 0.1, 0.3, 0.03])
axwave_lamb = fig.add_axes([0.15, 0.05, 0.3, 0.03])
axwave_z = fig.add_axes([0.60, 0.15, 0.3, 0.03])
axwave_N = fig.add_axes([0.60, 0.1, 0.3, 0.03])

# Sliders work in display units (µm, nm, m); update_wave converts them to SI.
d_slider = Slider(axwave_d, 'Spaltbreite', 0.1, 100, valinit=d*1e6, valfmt='%.1f µm')
# With a spacing of 0 the grating factor sin(N*a)/sin(a) is 0/0 at every x
# (a == 0), so the spacing starts at the same minimum as the slit width.
Lambda_slider = Slider(axwave_Lambda, 'Spaltabstand', 0.1, 600, valinit=Lambda*1e6, valfmt='%.1f µm')
lamb_slider = Slider(axwave_lamb, 'Wellenlänge $\\lambda$', 300, 1000, valinit=lamb*1e9, valfmt='%.d nm')
z_slider = Slider(axwave_z, '$z$-Position', 1, 100, valinit=z, valfmt='%.1f m')
# valstep=1 snaps the slit count to whole numbers, so the value shown on the
# slider is exactly the N used for the plot.
N_slider = Slider(axwave_N, 'Anzahl $N$', 1, 20, valinit=N, valfmt='%.d', valstep=1)

for slider in (d_slider, Lambda_slider, lamb_slider, z_slider, N_slider):
    slider.on_changed(update_wave)

# Draw the initial state, then hand control to the GUI event loop.
update_wave()
plt.show()