# -*- coding: utf-8 -*-
"""
Created on Sun May 28 15:44:00 2023

@author: Martin

This script plots the xy electric-field vector of a plane wave propagating in the z-direction
for a given Jones vector j = (Ex, Ey * e^(i*delta)), where delta is the phase shift between the two electric field components. A quarter-wave plate and a half-wave plate can then be applied to this polarization state.

phi:    angle of the QWP
alpha:  angle of the HWP
M_rot:  2D-rotation matrix
Ex,Ey:  electric field components
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button
%matplotlib auto

def Lambda_viertel(phi):
    """Return the Jones matrix of a quarter-wave plate rotated by ``phi``.

    The retarder matrix diag(1, 1j) is sandwiched between a 2D rotation
    and its inverse: R(-theta) * diag(1, 1j) * R(theta), with theta being
    ``phi`` converted to radians.

    Parameters
    ----------
    phi : float
        Rotation angle of the plate in degrees.

    Returns
    -------
    numpy.matrix
        Complex 2x2 Jones matrix.
    """
    theta = phi * np.pi / 180
    c, s = np.cos(theta), np.sin(theta)

    retarder = np.matrix([[1, 0], [0, 1j]])
    rotate = np.matrix([[c, s], [-s, c]])
    rotate_back = np.matrix([[c, -s], [s, c]])

    return rotate_back @ (retarder @ rotate)

def Lambda_halbe(phi):
    """Return the Jones matrix of a half-wave plate rotated by ``phi``.

    The retarder matrix diag(1, -1) is sandwiched between a 2D rotation
    and its inverse: R(-theta) * diag(1, -1) * R(theta), with theta being
    ``phi`` converted to radians.

    Parameters
    ----------
    phi : float
        Rotation angle of the plate in degrees.

    Returns
    -------
    numpy.matrix
        Complex 2x2 Jones matrix.
    """
    theta = phi * np.pi / 180
    c, s = np.cos(theta), np.sin(theta)

    retarder = np.matrix([[1, 0], [0, -1]], dtype=complex)
    rotate = np.matrix([[c, s], [-s, c]])
    rotate_back = np.matrix([[c, -s], [s, c]])

    return rotate_back @ (retarder @ rotate)

def x(Ex,delta):
    """Sample the x field component over one full period.

    Parameters
    ----------
    Ex : float
        Amplitude of the x component.
    delta : float
        Phase offset in radians subtracted from the sampling angle.

    Returns
    -------
    numpy.ndarray
        100 samples of ``Ex * cos(angle - delta)`` for angle in [0, 2*pi].
    """
    shifted = np.linspace(0, 2*np.pi, 100) - delta
    return Ex * np.cos(shifted)

def y(Ey,delta):
    """Sample the y field component over one full period.

    Parameters
    ----------
    Ey : float
        Amplitude of the y component.
    delta : float
        Phase offset in radians subtracted from the sampling angle.

    Returns
    -------
    numpy.ndarray
        100 samples of ``Ey * cos(angle - delta)`` for angle in [0, 2*pi].
    """
    shifted = np.linspace(0, 2*np.pi, 100) - delta
    return Ey * np.cos(shifted)

def plot_Jones(j):
    """Turn a Jones vector into the x/y traces of the field over one period.

    Despite the name, nothing is drawn here; the caller feeds the returned
    arrays into the line artists.

    Parameters
    ----------
    j : indexable
        Jones vector whose entries ``j[0]`` and ``j[1]`` each hold a single
        complex value (e.g. a 2x1 ``numpy.matrix``).

    Returns
    -------
    tuple
        ``(x_samples, y_samples)`` as produced by ``x()`` and ``y()`` from
        the magnitude and phase of the respective component.
    """
    comp_x, comp_y = j[0], j[1]

    # Magnitude and phase of each complex component as plain scalars.
    amp_x, phase_x = np.abs(comp_x).item(), np.angle(comp_x).item()
    amp_y, phase_y = np.abs(comp_y).item(), np.angle(comp_y).item()

    return x(amp_x, phase_x), y(amp_y, phase_y)

# Toggle state of the two wave plates; both start switched off.
use_lamb_halbe = use_lamb_viertel = False

# Initial slider values: field amplitudes, then phase shift and the two
# plate angles (all three angles in degrees).
Ex = Ey = 1
delta = phi = alpha = 0

# make figure; the bottom 30 % is kept free for the sliders and buttons
fig = plt.figure()
plt.subplots_adjust(bottom=0.3)


# initialize the (empty) plots; the data is filled in later by
# update_values() and cleared again by the button callbacks
Jones_in,     = plt.plot([],[], label="Jones in")
Jones_out,    = plt.plot([],[], label='Jones ($\\lambda/4$)')
Jones_out2,   = plt.plot([],[], label="Jones ($\\lambda/2$)")
# dashed orange line (no legend entry) marking the quarter-wave plate angle
phi_indicator,= plt.plot([],[], "--", c="tab:orange")



plt.xlabel("Polarisationsachse $x$")
# Fixed: this was a second plt.xlabel() call, which overwrote the x label
# above and left the y axis without any label.
plt.ylabel("Polarisationsachse $y$")
# NOTE(review): the amplitude sliders defined further down go up to 2, so
# curves can leave this +-1.5 window -- confirm whether that is intended.
plt.xlim(-1.5,1.5)
plt.ylim(-1.5,1.5)
plt.legend(loc="upper left")
# equal aspect so that circular polarization is drawn as a circle
plt.gca().set_aspect("equal")


# update the plot
def update_values(val=None):
    """Recompute the polarization curves from the widget state and redraw.

    Registered as the ``on_changed`` callback of every slider and also
    called directly by the button callbacks and once at start-up.

    Parameters
    ----------
    val : object, optional
        Value handed over by the widget callback. It is ignored; the
        current values are read from the sliders themselves.
    """
    amp_x = Ex_slider.val
    amp_y = Ey_slider.val
    delta_deg = delta_slider.val
    qwp_deg = phi_slider.val
    hwp_deg = alpha_slider.val

    # Input Jones vector (Ex, Ey * exp(i*delta)); delta is given in degrees.
    j_in = np.matrix([[amp_x], [amp_y*np.exp(1j*delta_deg*np.pi/180)]])
    j_qwp = np.matmul(Lambda_viertel(qwp_deg), j_in)

    # The half-wave plate sits behind the quarter-wave plate; if the latter
    # is switched off it acts on the input state directly.
    hwp_input = j_qwp if use_lamb_viertel else j_in
    j_hwp = np.matmul(Lambda_halbe(hwp_deg), hwp_input)

    Jones_in.set_data(*plot_Jones(j_in))

    # Disabled curves are left untouched (the button callbacks empty them).
    if use_lamb_halbe:
        Jones_out2.set_data(*plot_Jones(j_hwp))

    if use_lamb_viertel:
        Jones_out.set_data(*plot_Jones(j_qwp))

        # Line through the origin along the quarter-wave plate angle.
        theta = qwp_deg * (np.pi/180)
        radius = np.sqrt(amp_x**2 + amp_y**2)
        phi_indicator.set_data(
            [radius*np.cos(theta), radius*np.cos(theta+np.pi)],
            [radius*np.sin(theta), radius*np.sin(theta+np.pi)],
        )

    fig.canvas.draw_idle()
    
def use_lambda_halbe(val=None):
    """Toggle the half-wave plate curve on or off (button callback).

    Flips the module-level ``use_lamb_halbe`` flag, recolors the button
    and refreshes the plot.

    Parameters
    ----------
    val : object, optional
        Event passed by ``Button.on_clicked``; ignored.
    """
    global use_lamb_halbe

    use_lamb_halbe = not use_lamb_halbe

    if use_lamb_halbe:
        button_l2.color = "green"
    else:
        # update_values() does not touch a disabled curve, so empty it here.
        Jones_out2.set_xdata([])
        Jones_out2.set_ydata([])
        button_l2.color = "0.85"

    # Fixed: switching the plate off used to return without requesting a
    # redraw, so the cleared curve stayed visible until another event
    # repainted the canvas. update_values() ends with draw_idle().
    update_values()

def use_lambda_viertel(val=None):
    """Toggle the quarter-wave plate curve on or off (button callback).

    Flips the module-level ``use_lamb_viertel`` flag, recolors the button
    and refreshes the plot.

    Parameters
    ----------
    val : object, optional
        Event passed by ``Button.on_clicked``; ignored.
    """
    global use_lamb_viertel

    use_lamb_viertel = not use_lamb_viertel

    if use_lamb_viertel:
        button_l4.color = "green"
    else:
        # update_values() does not touch disabled artists, so empty the
        # curve and its angle indicator here.
        Jones_out.set_xdata([])
        Jones_out.set_ydata([])
        phi_indicator.set_xdata([])
        phi_indicator.set_ydata([])
        button_l4.color = "0.85"

    # Fixed: switching the plate off used to return without recomputing or
    # redrawing. The half-wave plate curve depends on whether this plate is
    # in the beam path, so it kept showing the state behind the removed
    # plate, and the cleared curve stayed visible until the next repaint.
    update_values()
    
# Widget axes, each given as [left, bottom, width, height] in figure
# coordinates: left column for the input state, right column for the plates.
axwave1 = plt.axes([0.1, 0.15, 0.3, 0.03])
axwave2 = plt.axes([0.1, 0.1, 0.3, 0.03])
axwave3 = plt.axes([0.1, 0.05, 0.3, 0.03])
axwave4 = plt.axes([0.6, 0.15, 0.3, 0.03])
axwave5 = plt.axes([0.6, 0.1, 0.3, 0.03])
axwave6 = plt.axes([0.625, 0.05, 0.1, 0.03])
axwave7 = plt.axes([0.775, 0.05, 0.1, 0.03])

# Sliders for the two amplitudes, the phase shift and the plate angles
# (phase and angles in degrees).
Ex_slider = Slider(axwave1, "$E_x$", valmin=1, valmax=2, valinit=Ex, valfmt="%.1f")
Ey_slider = Slider(axwave2, "$E_y$", valmin=1, valmax=2, valinit=Ey, valfmt="%.1f")
delta_slider = Slider(axwave3, "$\\delta$", valmin=-90, valmax=90, valinit=delta, valfmt="%.1f")
phi_slider = Slider(axwave4, "$\\lambda/4$-Winkel", valmin=-90, valmax=90, valinit=phi, valfmt="%.1f")
alpha_slider = Slider(axwave5, "$\\lambda/2$-Winkel", valmin=-90, valmax=90, valinit=alpha, valfmt="%.1f")

# Buttons that switch the two wave plates on and off.
button_l4 = Button(axwave6, "Use $\\lambda/4$", hovercolor="0.975")
button_l2 = Button(axwave7, "Use $\\lambda/2$", hovercolor="0.975")
button_l2.on_clicked(use_lambda_halbe)
button_l4.on_clicked(use_lambda_viertel)

# Every slider change triggers a full recomputation of the curves.
for _slider in (Ex_slider, Ey_slider, delta_slider, phi_slider, alpha_slider):
    _slider.on_changed(update_values)

# Draw the initial state.
update_values()
