import pandas as pd
import xarray as xr
import numpy as np
import glob
import cartopy.crs as ccrs

def read_dirlist(expset):
    '''
    Look up the run directories (each ending in ``nc/``) of an experiment set.

    expset: 'REF', 'AFF' or 'OAE'. Any other value yields an empty list.
    Returns a fresh list on every call, with one directory per realization
    in realization order, so callers may mutate it safely.
    '''
    experiment_table = (
        ('REF', (
            '/scratch/usr/shkhwwey/output_others/FOCI1.20.0-CC104_RCP_ESM_spinup2089/nc/',
            '/scratch/usr/shkhwwey/output_others/FOCI1.20.0-CC105_RCP_ESM_spinup2099/nc/',
            '/scratch/usr/shkhwwey/output_others/FOCI1.20.0-CC106_RCP_ESM_spinup2079/nc/',
        )),
        ('AFF', (
            '/scratch/usr/shkhwwey/models/foci2.0/experiments/FOCI20.0-HW003_esm_ssp585_ssp126Lu/nc/',
            '/scratch/usr/shkhwwey/models/foci2.0/experiments/FOCI20.0-HW007_esm_ssp585_ssp126Lu/nc/',
            '/scratch/usr/shkhwwey/models/foci2.0/experiments/FOCI20.0-HW008_esm_ssp585_ssp126Lu/nc/',
        )),
        ('OAE', (
            '/scratch/usr/shkhwwey/models/foci2.0/experiments/FOCI20.0-HW002_esm_ssp585_ocn_alk/nc/',
            '/scratch/usr/shkhwwey/models/foci2.0/experiments/FOCI20.0-HW005_esm_ssp585_ocn_alk/nc/',
            '/scratch/usr/shkhwwey/models/foci2.0/experiments/FOCI20.0-HW006_esm_ssp585_ocn_alk/nc/',
        )),
    )
    # Equality comparison (not a dict lookup) keeps the original's tolerance
    # of unhashable / oddly-typed `expset` values.
    for label, directories in experiment_table:
        if expset == label:
            return list(directories)
    return []

def _first_match(pattern):
    '''
    Return the lexicographically first path matching the glob `pattern`.

    Sorting makes the choice deterministic when several files match. The
    bare ``glob.glob(...)[0]`` it replaces picked whichever file the OS
    listed first. Raises IndexError, the type the bare indexing raised,
    but names the pattern that found nothing.
    '''
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise IndexError(f'no file matches {pattern!r}')
    return matches[0]


def get_file_list(expset, stream, yrstart=2090, yrend=2100, real=(0,)):
    '''
    Return the yearly output files of one stream for one realization.

    Parameters
    ----------
    expset : experiment set passed to read_dirlist() ('REF', 'AFF' or 'OAE').
    stream : NEMO stream 'ptrc_T', 'diad_T' or 'grid_T', or one of
             'echam', 'veg', 'jsbach', 'co2'.
    yrstart, yrend : half-open year range [yrstart, yrend).
    real : realization index, or an iterable of indices. Only the FIRST
           index is used. The original implementation returned from inside
           its realization loop on the first pass, and callers expect a flat
           list of paths for a single member.

    Notes
    -----
    For expset == 'REF', NEMO streams are not read from read_dirlist('REF').
    They come from FOCI20.0-HW10<ireal+1>_esm_ssp585/outdata/nemo/ym/ instead.
    All other combinations resolve relative to read_dirlist(expset)[ireal].

    Returns
    -------
    list of str, one matching path per year.

    Raises
    ------
    ValueError : unknown `stream`, or `real` is empty.
    IndexError : `ireal` is out of range for `expset`, or no file matches
                 the pattern for some year.
    '''
    import numbers

    nemo_streams = ('ptrc_T', 'diad_T', 'grid_T')
    # File-name prefixes of the per-year files that live directly in dirname.
    atmos_prefixes = {
        'echam': 'echam6_echam_',
        'veg': 'veg_',
        'jsbach': 'jsbach_',
        'co2': 'echam6_co2_',
    }
    if stream not in nemo_streams and stream not in atmos_prefixes:
        raise ValueError(
            f'unknown stream {stream!r}; expected one of '
            f'{list(nemo_streams) + list(atmos_prefixes)}'
        )

    # Accept a bare index as well as an iterable of indices.
    reals = [real] if isinstance(real, numbers.Integral) else list(real)
    if not reals:
        raise ValueError('real must contain at least one realization index')
    ireal = reals[0]

    # Evaluated for every expset/stream (as before), so a bad ireal fails here.
    dirname = read_dirlist(expset)[ireal]

    def pattern_for(year):
        '''Build the glob pattern of `stream`'s file for one year.'''
        if stream in atmos_prefixes:
            return f'{dirname}{atmos_prefixes[stream]}{year}.nc'
        if expset == 'REF':
            return (
                '/scratch/usr/shkhwwey/models/foci2.0/experiments/'
                f'FOCI20.0-HW10{ireal + 1}_esm_ssp585/outdata/nemo/ym/'
                f'*{year}0101*_{stream}*.nc'
            )
        return f'{dirname}../outdata/nemo/ym/*{year}0101*_{stream}*.nc'

    return [_first_match(pattern_for(year)) for year in range(yrstart, yrend)]

def ensemble_mean_considering_consistence(var_diff):
    '''
    Calculate the ensemble mean of a multi-realization ensemble where the
    result is robust, and fill all other points with np.nan.

    Robustness is judged on the sign of change:
    - with three members, all three must agree (all > 0 or all < 0);
    - with ten members, at least eight must agree.
    A zero or NaN member value counts toward neither sign.

    var_diff is expected to be an xarray DataArray. Its leading axis is the
    realization dimension, named 'ireal': agreement counts are taken over
    axis 0 and the mean over dim='ireal'.
    NOTE(review): both assume 'ireal' is the first dimension; confirm at callers.

    Raises ValueError for any other ensemble size. The previous version
    printed a message and then failed with UnboundLocalError.
    '''
    # Ensemble size -> minimum number of members that must share a sign.
    min_agree_by_size = {3: 3, 10: 8}
    n_real = var_diff.shape[0]
    if n_real not in min_agree_by_size:
        raise ValueError(f'number of realizations not 3 or 10!! (got {n_real})')
    min_agree = min_agree_by_size[n_real]

    n_positive = (var_diff > 0).sum(axis=0)
    n_negative = (var_diff < 0).sum(axis=0)
    robust = (n_positive >= min_agree) | (n_negative >= min_agree)
    return xr.where(robust, var_diff.mean(dim='ireal'), np.nan)

def add_subplot_ens(foo,ax,vmin=None,vmax=None,unit='',title='',cmap='PiYG',fillna=True,cb=False,real=3):
    '''
    Map plot of the sign-robust ensemble mean of a multi-realization field.

    Calls ensemble_mean_considering_consistence() to compute the ensemble
    mean. Points where members disagree are NaN; with `fillna` they are
    set to 0 before plotting.

    Parameters
    ----------
    foo : ensemble field with a realization dimension 'ireal' (see
          ensemble_mean_considering_consistence).
    ax : axes to draw on. It must accept `transform=`, `gridlines`,
         `set_global` and `coastlines`, i.e. a cartopy GeoAxes.
    vmin, vmax, cmap : forwarded to DataArray.plot (which uses robust=True).
    unit : colorbar label, used only when `cb` is true.
    title : axes title.
    fillna : replace non-robust (NaN) points by 0 before plotting.
    cb : add a horizontal colorbar below the axes.
    real : unused; kept for backward compatibility. The original loop only
           ever executed its `ireal == real` iteration, so this value never
           affected the plot.

    Returns
    -------
    The mappable returned by DataArray.plot.

    Raises
    ------
    ValueError : the field could be plotted with none of the coordinate
                 pairs nav_lat/nav_lon, lat/lon, latitude/longitude.
    '''
    fontsize = 18

    bar = ensemble_mean_considering_consistence(foo)
    if fillna:
        bar = bar.fillna(0)

    plot_kwargs = dict(ax=ax, transform=ccrs.PlateCarree(),
                       vmax=vmax, vmin=vmin,
                       add_labels=False, robust=True, add_colorbar=False, cmap=cmap)
    # NEMO grids use nav_lat/nav_lon; regular grids use lat/lon or
    # latitude/longitude. Try each naming convention in turn.
    errors = []
    for lat_name, lon_name in (('nav_lat', 'nav_lon'), ('lat', 'lon'), ('latitude', 'longitude')):
        try:
            im = bar.plot(y=lat_name, x=lon_name, **plot_kwargs)
        except Exception as err:  # plot raises various types for missing coords
            errors.append(f'{lat_name}/{lon_name}: {err}')
        else:
            break
    else:
        raise ValueError('tried nav_lat/nav_lon, lat/lon, and latitude/longitude: '
                         + '; '.join(errors))

    if cb:
        # ax.figure is the figure that owns the axes; pyplot is not needed.
        colorbar = ax.figure.colorbar(im, ax=ax, orientation='horizontal', pad=0.05, shrink=0.5)
        colorbar.set_label(unit, size=fontsize)
        colorbar.ax.tick_params(labelsize=fontsize)

    ax.gridlines(draw_labels=False)
    ax.set_global()
    ax.coastlines()
    ax.set_title(title, fontsize=fontsize)
    return im
