"""
    Copyright (c) 2019 Lumerical Inc. """

######## IMPORTS ########
# General purpose imports
import os
import scipy as sp

## We try to import autograd (and use its numpy wrapper), but fall back to plain numpy if autograd is not available.
## This is for future functionality and is currently not used.
try:
    import autograd.numpy as np
    from autograd import jacobian
except ImportError:
    import numpy as np

# Optimization specific imports
from lumopt.utilities.load_lumerical_scripts import load_from_lsf
from lumopt.utilities.wavelengths import Wavelengths
from lumopt.geometries.polygon import FunctionDefinedPolygon
from lumopt.figures_of_merit.modematch import ModeMatch
from lumopt.optimizers.generic_optimizers import ScipyOptimizers
from lumopt.optimization import Optimization
from lumopt.utilities.materials import Material

from numpy.random import rand

def runGratingOptimization(bandwidth_in_nm, etch_depth, n_grates, params, working_dir):
    """Build and run a 2D apodized grating-coupler optimization with lumopt.

    The grating is a polygon whose vertices are generated from four
    parameters: the starting position, the apodization rate R, and the
    coefficients a and b of the period equation. These are optimized with
    L-BFGS-B against a mode-overlap figure of merit (mode 1, 'Backward'
    direction, on the 'fom' monitor) whose target transmission is 1 at every
    wavelength.

    Parameters
    ----------
    bandwidth_in_nm : float
        Optimization bandwidth in nm, centred on 1550 nm. The number of
        wavelength points is ``int(bandwidth_in_nm / 10) + 1``.
    etch_depth : float
        Depth of the etched trenches in metres. It is subtracted from the
        220e-9 slab height.
    n_grates : int
        Number of etched trenches in the grating; must be at least 1.
    params : sequence of float
        Initial ``[x_start (um), R (1/um), a, b]``; see ``bounds`` below.
    working_dir : str
        Directory handed to ``Optimization.run``.

    Raises
    ------
    ValueError
        If ``n_grates`` is less than 1. Without this check a non-positive
        value would silently produce a single-trench grating, because the
        final trench is always emitted.
    """
    if n_grates < 1:
        raise ValueError("n_grates must be at least 1, got {}".format(n_grates))

    # Parameter bounds for the optimizer, in the same order as ``params``.
    bounds = [(-4,3),      #< Starting position (in um)
              (0,0.05),    #< Scaling parameter R (in 1/um)
              (1.5,3),     #< Parameter a
              (0,2)]       #< Parameter b

    def grating_params_pos(params):
        """Return the grating polygon vertices as an (N, 2) array of (x, y) in metres.

        The outline starts as a full-height slab at ``x_begin`` and drops to
        the etch floor at the start position. It then alternates an etched
        trench (x0..x1) with an unetched tooth (x1..x2) for ``n_grates``
        periods. The last tooth extends to ``x_end``, and the polygon closes
        along ``y0``. The fill factor F falls linearly with distance from the
        start position (apodization). The local period is
        ``lambda_c / (a + F*b)``.
        """
        lambda_c = 1.55e-6
        F0 = 0.95                   #< Fill factor at the start position

        height  = 220e-9
        y0      = 0
        x_begin = -5.1e-6
        x_end   = 22e-6

        y3 = y0 + height            #< Top of the unetched silicon
        y1 = y3 - etch_depth        #< Floor of the etched trenches

        x_start = params[0]*1e-6    #< First parameter: starting position (um -> m)
        R  = params[1]*1e6          #< Second parameter: apodization rate (1/um -> 1/m)
        a  = params[2]              #< Third parameter (dimensionless)
        b  = params[3]              #< Fourth parameter (dimensionless)

        def etched_cell(x0):
            """Return (end of the etched trench, local period) for the cell starting at x0."""
            F = F0 - R*(x0 - x_start)
            Lambda = lambda_c / (a + F*b)
            return x0 + (1-F)*Lambda, Lambda

        x0 = x_start
        # Collect vertices in a list and convert once at the end.
        # Calling np.concatenate per period would copy the whole array each time.
        verts = [[x_begin, y0], [x_begin, y3], [x0, y3], [x0, y1]]

        ## Every period except the last: trench (x0..x1), then tooth (x1..x2)
        for _ in range(n_grates - 1):
            x1, Lambda = etched_cell(x0)
            x2 = x0 + Lambda
            verts += [[x1, y1], [x1, y3], [x2, y3], [x2, y1]]
            x0 = x2

        ## Last period is special: its tooth runs to x_end, then close along y0
        x1, _ = etched_cell(x0)
        verts += [[x1, y1], [x1, y3], [x_end, y3], [x_end, y0]]

        return np.array(verts)

    geometry = FunctionDefinedPolygon(func = grating_params_pos, initial_params = params, bounds = bounds,
                                      z = 0.0, depth = 110e-9, eps_out = 1.44 ** 2, eps_in = 3.47668 ** 2,
                                      edge_precision = 5, dx = 1e-5)

    ######## DEFINE FIGURE OF MERIT ########
    fom = ModeMatch(monitor_name = 'fom', mode_number = 1, direction = 'Backward',
                    target_T_fwd = lambda wl: np.ones(wl.size), norm_p = 1)

    ######## DEFINE OPTIMIZATION ALGORITHM ########
    optimizer = ScipyOptimizers(max_iter = 25, method = 'L-BFGS-B', scaling_factor = 1, pgtol = 1e-6, ftol = 1e-7)

    ######## DEFINE BASE SIMULATION ########
    # The base .lsf script is expected to sit next to this file.
    base_script = load_from_lsf(os.path.join(os.path.dirname(__file__), 'grating_coupler_2D_TE_base.lsf'))

    ######## PUT EVERYTHING TOGETHER ########
    lambda_start = 1550 - bandwidth_in_nm/2
    lambda_end   = 1550 + bandwidth_in_nm/2
    lambda_pts   = int(bandwidth_in_nm/10)+1     #< One point per 10 nm, plus one
    wavelengths = Wavelengths(start = lambda_start*1e-9, stop = lambda_end*1e-9, points = lambda_pts)
    opt = Optimization(base_script = base_script, wavelengths = wavelengths, fom = fom, geometry = geometry,
                       optimizer = optimizer, hide_fdtd_cad = False, use_deps = True)

    ######## RUN THE OPTIMIZER ########
    opt.run(working_dir)


if __name__ == "__main__":
    # Example run: optimize a 25-trench apodized grating with an 80 nm etch
    # at the single centre wavelength of 1550 nm (zero bandwidth).
    script_dir = os.path.dirname(os.path.realpath(__file__))
    etch_depth_nm = 80

    runGratingOptimization(
        bandwidth_in_nm=0,                  # only the 1550 nm point
        etch_depth=etch_depth_nm*1e-9,      # nm -> m
        n_grates=25,
        params=[-2.5, 0.03, 2.4, 0.5369],   # [x_start (um), R (1/um), a, b]
        working_dir=os.path.join(script_dir, 'ApodizedGrating'),
    )
