import pandas as pd
import xarray as xr
import numpy as np
import glob
def get_ens_ACCESS(exp_name, var_name, yrstart=2040, yrend=2060, concat=True,
                   nreal=10,
                   base_dir='/scratch/projects/shk00056/CDRMIP/ACCESS-ESM1-5/',
                   verbose=True):
    '''
    Read realizations r1..r<nreal> of ACCESS-ESM1-5 for one experiment/variable.

    For each realization, the NetCDF files matching
    '<base_dir><exp_name>/<var_name>/real/*r<ireal>i*.nc' are opened with
    xr.open_mfdataset. `var_name` is then sliced by position along its first
    dimension, keeping 12 records per year for yrstart..yrend (inclusive).
    NOTE(review): the slicing assumes monthly records starting with the first
    record of 2015 (`yrini`) -- confirm for new experiments.

    exp_name, var_name: experiment and variable names used in the directory
        path (var_name is also the variable read from the dataset).
    yrstart, yrend: first and last year to keep (both inclusive).
    concat: if truthy, return one DataArray with the realizations stacked along
        a new 'ireal' dimension; otherwise return the list of DataArrays.
    nreal: number of realizations to read (default 10, as before).
    base_dir: root directory of the ACCESS-ESM1-5 data; must end with '/'.
    verbose: print the matched file list of each realization (the previous,
        always-on behaviour).

    Raises FileNotFoundError if no file matches for a realization.
    '''
    yrini = 2015  # first year of the record; positional offsets are relative to it
    istart = (yrstart - yrini) * 12
    iend = (yrend - yrini) * 12 + 12  # +12 includes the 12 records of yrend
    dir_name = base_dir + exp_name + '/' + var_name + '/real/'
    var_concat = []
    for ireal in range(1, nreal + 1):
        pattern = dir_name + '*r' + str(ireal) + 'i*.nc'
        file_list = sorted(glob.glob(pattern))
        if verbose:
            print(file_list)
        if not file_list:
            # Name the offending pattern instead of surfacing xarray's generic error.
            raise FileNotFoundError(f'No files match {pattern}')
        ds = xr.open_mfdataset(file_list)
        var_concat.append(ds[var_name][istart:iend])
    if concat:
        return xr.concat(var_concat, dim='ireal')
    return var_concat


def ens_values(df_list, var_name, yrstart=2015, yrend=2100):
    """
    Summarize an ensemble of yearly time series across realizations.

    df_list: list of DataFrames (one per realization), each indexed by year and
        holding column `var_name` with a value for every year in [yrstart, yrend).
    var_name: column to summarize.
    yrstart, yrend: first year (inclusive) and end year (exclusive); the
        defaults reproduce the original fixed 2015-2099 window.

    Returns a DataFrame indexed by year with columns 'ens_range' (max - min),
    'ens_mean', 'ens_min' and 'ens_max', each taken across realizations.

    Raises ValueError if df_list is empty, and KeyError (from pandas .loc) if a
    year is missing from a realization.
    """
    if len(df_list) == 0:
        raise ValueError('ens_values: df_list is empty')
    years = np.arange(yrstart, yrend)
    # Shape (nreal, nyear): one row per realization.
    values = np.array([df[var_name].loc[years].to_numpy() for df in df_list])
    ens_min = values.min(axis=0)
    ens_max = values.max(axis=0)
    df_ens = pd.DataFrame(
        {
            'ens_range': ens_max - ens_min,
            'ens_mean': values.mean(axis=0),
            'ens_min': ens_min,
            'ens_max': ens_max,
        },
        index=years,
    )
    return df_ens

def cOcean_from_fgco2(df, yrstart, yrend=2100):
    """
    Accumulate the yearly 'fgco2' column into a running total named 'cOcean'.

    df: DataFrame indexed by year (sorted ascending) with an 'fgco2' column.
    yrstart: first year of the accumulation (inclusive).
    yrend: end year (exclusive); the default 2100 reproduces the original
        fixed window.

    Returns a DataFrame with a single column 'cOcean' and an index named 'year'
    covering range(yrstart, yrend). Each value is the sum of 'fgco2' from
    yrstart through that year. Years absent from df contribute nothing, so a
    year before the first available one gives 0. NaN fluxes are skipped, as
    Series.sum does.
    """
    years = range(yrstart, yrend)
    window = df['fgco2'].loc[yrstart:yrend - 1]
    # One cumulative sum instead of re-summing the growing window every year.
    running = window.fillna(0).cumsum()
    # Carry totals across missing years; years before any data accumulate to 0.
    running = running.reindex(years, method='ffill').fillna(0)
    cOcean = pd.DataFrame({'cOcean': running.to_numpy()}, index=years)
    cOcean.index.name = 'year'
    return cOcean

class Figure1:
    """Load per-experiment CSV time series of one variable, grouped into experiment sets.

    After a ``read_*`` call, ``self.model`` names the model and
    ``self.dflist[i]`` holds the DataFrames (indexed by the CSV's 'year'
    column) that were read successfully for the i-th experiment set. For
    ``read_FOCI``/``read_MPIESM``/``read_model`` that set is
    ``self.expnamelist[i]`` (a list of names). For ``read_ACCESS_ensemble`` it
    is the single name ``self.expnamelist[i]``, one CSV per realization.
    Unreadable files are skipped (except MPI-ESM cLand, see ``read_MPIESM``),
    so a set may hold fewer DataFrames than names.
    """

    def __init__(self, var_name, exp):
        """
        var_name: variable name used in the CSV file names (e.g. 'cLand').
        exp: experiment selector such as 'OAE', 'AFF', 'BOTH' or 'SynTra';
             which values are meaningful depends on the ``read_*`` method.
        """
        self.var_name = var_name
        self.exp = exp
        self.expnamelist = []
        self.dflist = []

    @staticmethod
    def _append_first_csv(target, paths, debug, message):
        """Append to ``target`` the first CSV in ``paths`` that pd.read_csv can load.

        If none can be read, nothing is appended and ``message`` is printed
        when ``debug`` is truthy.
        """
        for path in paths:
            try:
                target.append(pd.read_csv(path, index_col='year'))
                return
            except (OSError, ValueError):
                # OSError: missing/unreadable file. ValueError: malformed CSV or
                # no 'year' column. Try the next candidate path.
                continue
        if debug:
            print(message)

    def read_FOCI(self, debug=False):
        """Read the FOCI experiment sets from '../data/csv/'.

        For var_name 'cLand', '<exp>_combined.csv' is tried first and
        '<exp>_cLand.csv' is the fallback. Sets self.yrstart = 2015 when exp is
        'OAE' or 'AFF'.
        """
        self.model = 'FOCI'
        var_name = self.var_name
        exp = self.exp
        self.expnamelist = [
            ['CC104', 'CC105', 'CC106'],  # 0 REF
            ['HW002', 'HW005', 'HW006'],  # 1 OAE
            ['HW003', 'HW007', 'HW008'],  # 2 AFF
            ['HW013', 'HW012', 'HW011'],  # 3 SynTraREF
            ['HW022', 'HW023', 'HW024'],  # 4 SynTraAR
            ['HW028', 'HW032', 'HW033'],  # 5 SynTraOAE
            ['HW029', 'HW030', 'HW031'],  # 6 SynTraBOTH
            ['HW034', 'HW035', 'HW036'],  # 7 SynTraHalfOAE
            ['HW037', 'HW038', 'HW039'],  # 8 SynTraHalfBOTH
            ['HW040', 'HW041', 'HW042'],  # 9 SynTraHalfAR
        ]
        if exp in ('OAE', 'AFF'):
            self.yrstart = 2015
        self.dflist = []
        for expnames in self.expnamelist:
            loaded = []
            self.dflist.append(loaded)
            for expname in expnames:
                if var_name == 'cLand':
                    paths = ['../data/csv/' + expname + '_combined.csv',
                             '../data/csv/' + expname + '_cLand.csv']
                else:
                    paths = ['../data/csv/' + expname + '_' + var_name + '.csv']
                self._append_first_csv(loaded, paths, debug,
                                       f'Error reading {expname} {var_name}!')

    def read_MPIESM(self, debug=False):
        """Read MPI-ESM1-2-LR experiment sets for exp 'SynTra', 'OAE', 'AFF' or 'BOTH'.

        Files are 'MPI-ESM1-2-LR_<exp>_<var_name>.csv'. They are read from the
        SynTra directory for exp 'SynTra' and from '../data/csv/' otherwise.
        For var_name 'cLand' a read failure is NOT caught and propagates to the
        caller. Sets self.yrstart to 2020 for 'OAE' and 2015 for 'AFF'.
        """
        self.model = 'MPI-ESM'
        var_name = self.var_name
        exp = self.exp
        if exp == 'SynTra':
            self.expnamelist = [
                ['ymo1001', 'ymo1002', 'ymo1003'],  # 0 REF
                ['ymo1021', 'ymo1022', 'ymo1045'],  # 1 AR
                ['ymo1089', 'ymo1090', 'ymo1091'],  # 2 OAE
                ['ymo1092', 'ymo1093', 'ymo1094'],  # 3 BOTH
            ]
        elif exp == 'OAE':
            self.expnamelist = [
                ['essp585_0003', 'essp585_0004', 'essp585_0005'],
                ['esm-ssp585-ocn-alk_EXP1', 'esm-ssp585-ocn-alk_EXP2', 'esm-ssp585-ocn-alk_EXP3'],
                [],
            ]
            self.yrstart = 2020
        elif exp == 'AFF':
            self.expnamelist = [
                ['essp585_0004', 'essp585_0005', 'esm-ssp585'],
                [],
                ['ymo_esm_ssp585_ssp126Lu_2', 'ymo_esm_ssp585_ssp126Lu_3', 'esm-ssp585-ssp126Lu'],
            ]
            self.yrstart = 2015
        elif exp == 'BOTH':
            self.expnamelist = [
                ['essp585_0003', 'essp585_0004', 'essp585_0005'],
                ['esm-ssp585-ocn-alk_EXP1', 'esm-ssp585-ocn-alk_EXP2', 'esm-ssp585-ocn-alk_EXP3'],
                ['essp585_0004', 'essp585_0005', 'esm-ssp585'],
                ['ymo_esm_ssp585_ssp126Lu_2', 'ymo_esm_ssp585_ssp126Lu_3', 'esm-ssp585-ssp126Lu'],
            ]
        else:
            self.expnamelist = []
        if exp == 'SynTra':
            base_dir = '/home/shkhwwey/OAE/SynTra_analysis/csv/MPIESM_csv/'
        else:
            base_dir = '../data/csv/'
        self.dflist = []
        for expnames in self.expnamelist:
            loaded = []
            self.dflist.append(loaded)
            for expname in expnames:
                path = base_dir + 'MPI-ESM1-2-LR_' + expname + '_' + var_name + '.csv'
                if var_name == 'cLand':
                    # cLand files are required here: errors are deliberately not caught.
                    loaded.append(pd.read_csv(path, index_col='year'))
                else:
                    self._append_first_csv(loaded, [path], debug,
                                           f'Error reading {expname} {var_name}!')

    def read_ACCESS_ensemble(self, debug=False, nreal=10):
        """Read ACCESS-ESM1-5 realizations r1..r<nreal> for two experiments.

        self.expnamelist is a flat list of experiment names. self.dflist[i]
        holds the realizations read from
        '../data/csv/ACCESS-ESM1-5_<expnamelist[i]>_<var_name>_r<n>.csv'.
        """
        base_dir = '../data/csv/'
        self.model = 'ACCESS-ESM1-5'
        var_name = self.var_name
        self.expnamelist = ['esm-ssp585', 'esm-ssp585-ssp126Lu']
        # Start fresh: appending to a dflist left by an earlier read_* call
        # would put ACCESS data into the previous model's lists.
        self.dflist = []
        for expname in self.expnamelist:
            loaded = []
            self.dflist.append(loaded)
            for ireal in range(1, nreal + 1):
                path = (base_dir + 'ACCESS-ESM1-5_' + expname + '_' + var_name
                        + '_r' + str(ireal) + '.csv')
                self._append_first_csv(loaded, [path], debug,
                                       f'Error reading {expname} {var_name} r{ireal}!')

    def read_model(self, model, debug=False):
        """Read '../data/csv/<model>_<exp>_<var_name>.csv' for each experiment set.

        For exp 'OAE' or 'AFF' the sets are esm-ssp585, esm-ssp585-ocn-alk and
        esm-ssp585-ssp126Lu (one run each).
        """
        base_dir = '../data/csv/'
        self.model = model
        exp = self.exp
        var_name = self.var_name
        if exp in ('OAE', 'AFF'):
            self.expnamelist = [['esm-ssp585'], ['esm-ssp585-ocn-alk'], ['esm-ssp585-ssp126Lu']]
        # NOTE(review): for any other exp (e.g. 'BOTH') expnamelist is not set
        # here, so whatever an earlier read_* call stored is reused -- confirm
        # this is intended.
        self.dflist = []
        for expnames in self.expnamelist:
            loaded = []
            self.dflist.append(loaded)
            for expname in expnames:
                path = base_dir + model + '_' + expname + '_' + var_name + '.csv'
                self._append_first_csv(loaded, [path], debug,
                                       f' Error reading {expname} {var_name} for {model}!')
                            
