import importlib
import gzip
import pickle
import functools

from pprint import pprint
from pathlib import Path
from tqdm import tqdm
#from tqdm.contrib.concurrent import process_map
from multiprocessing import Pool

import numpy as np
import pandas as pd

import dolphindb as ddb
import dolphindb.settings as keys

import sqlalchemy as sa

import ProtoBuffEntitys


def make_stock_daily_df(blob, type_name, stock_id):
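    """
    Decompress a gzip blob read from SQL Server, parse it as a ProtoBuffEntitys
    `{type_name}Array` message, and turn the records into a pandas DataFrame
    ready to be uploaded to DolphinDB (the code, date, time and array-valued
    columns are normalized below).
    """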
    blob = gzip.decompress(blob)
    # Look up the protobuf message class by name instead of eval'ing a dotted path.
    module = getattr(ProtoBuffEntitys, f"{type_name}Message_pb2")
    dataArray = getattr(module, f"{type_name}Array")()
    dataArray.ParseFromString(blob)

    data_dict_list = [
        {field.name : val for field, val in entry.ListFields()}
        for entry in dataArray.dataArray
    ]

    array_type_list = [
        field.name 
        for field, val in dataArray.dataArray[0].ListFields()
        if isinstance(field.default_value, list)
    ]
    #pprint(array_type_list)

    df = pd.DataFrame(data_dict_list)
    #df['code'] = make_symbol(df['code'])
    df['code'] = stock_id
    df['m_nDate'] = make_date(df['m_nDate'])
    df['m_nTime'] = df['m_nDate'] + make_time(df['m_nTime'])
    for field_name in array_type_list:
        df[field_name] = make_nparray(df[field_name])

    #print(f"Did create ddb table for dataframe of shape {df.shape}")
    # self.make_table_skeleton(type_name, df.shape[0])
    return df


def dump_stock_daily_to_ddb(row, type_name, stock_id):
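    """
    Worker function for one (stock, trading day) row: build the DataFrame and
    insert it into the partitioned DolphinDB table. A fresh DolphinDB session is
    opened here because the loader's session cannot be shared across worker processes.
    """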
    df_table_name = type_name
    df = make_stock_daily_df(row[2], type_name, stock_id)

    ddb_sess = ddb.session(DDBLoader.ddb_config['host'], 8848)
    ddb_sess.login(DDBLoader.ddb_config['username'], DDBLoader.ddb_config['password'])

    ddb_sess.upload({df_table_name : df})
    ddb_sess.run("tableInsert(loadTable('{dbPath}', `{partitioned_table_name}), {df_table_name})".format(
        dbPath = DDBLoader.ddb_path,
        partitioned_table_name = type_name + DDBLoader.ddb_partition_table_suffix,
        df_table_name = df_table_name
    ))



def make_symbol(series):
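    """
    Convert integer stock codes into exchange-suffixed symbols: zero-pad to 6 digits,
    then append '.SH' for codes starting with '6' and '.SZ' otherwise,
    e.g. 600000 -> '600000.SH'.
    """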
    return series.astype('int32').astype('str')\
        .apply(str.zfill, args=(6,))\
        .apply(lambda code : \
            code + '.SH' if code[0] == '6' \
            else code + '.SZ')


def make_date(series):
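    """Parse integer dates in YYYYMMDD form (e.g. 20220131) into pandas datetimes."""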
    return pd.to_datetime(
        series.astype(str), format='%Y%m%d')


def make_nparray(series):
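    """Convert list-valued cells into numpy arrays so they can be uploaded as DolphinDB array vectors."""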
    return series.apply(lambda x : np.array(x))


def make_time(series):
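    """
    Convert an integer time field (assumed to be encoded as HHMMSSmmm,
    e.g. 93000000 for 09:30:00.000) into a pandas Timedelta.
    """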
    s_hr = series // 10000000 * 3600000
    s_min = series % 10000000 // 100000 * 60000
    s_sec = series % 100000 // 1000 * 1000
    s_ms = series % 1000
    return pd.to_timedelta(s_hr + s_min + s_sec + s_ms, unit='ms')


class DDBLoader(object):
    """
    0. 从sql-server中读取calendar数据,并创建成员变量df_calendar,df_calendar可以保存在本地pickle作为缓存
        |- `def make_calendar_df(self) -> df_calendar`

    1. 创建ddb中的数据库,分区性质从calendar数据中获取
        |- `def create_ddb_database(self, df_calendar) -> void`
        |- `def load_ddb_database(self) -> void`

    2. 在ddb数据库中创建calendar表
        |- `def create_ddb_calendar(self, df_calendar) -> void`

    3. 创建ddb的分布式表结构
        |- `create_ddb_partition_table(self, hft_type_name)`
            |- `_make_table_skeleton(self, hft_type_name, capacity) -> memory_table_name`
    
    4. 从sql server的高频数据转录到dolpindb数据库中
        |- `dump_hft_to_ddb(self, type_name, stock_id, trade_date=None)`
    """

    hft_type_list = ['KLine', 'Order', 'Tick', 'TickQueue', 'Transe']
    
    protobuff_name_dict = {
        name : f"{name}Message_pb2" for name in hft_type_list
    }
    
    protobuff_module_dict = {
        type_name : importlib.import_module(f".{module_name}", package='ProtoBuffEntitys') 
        for type_name, module_name in protobuff_name_dict.items()
    }

    protobuff_desc_dict = {
        type_name : eval(f"ProtoBuffEntitys.{module_name}.{type_name}Array.{type_name}Data.DESCRIPTOR")
        for type_name, module_name in protobuff_name_dict.items()
    }

    mssql_name_dict = {
        type_name : (
            f"{type_name}" if type_name != 'TickQueue' \
            else f"TickQue"
        ) for type_name in hft_type_list
    }
    
    # The database path and the database handle name do not have to be identical
    ddb_path = "dfs://hft_stock_ts"
    ddb_dbname = "db_stock_ts"
    ddb_memory_table_suffix = "Memory"
    ddb_partition_table_suffix = "Partitioned"

    # The calendar table does not need partitioning, so a separate database could be created for it.
    # That database could also just be a plain csv; the difference between the two is not clear yet.
    #ddb_calendar_path = "dfs://daily_calendar"
    #ddb_calendar_dbname = "db_calendar"
    ddb_calendar_table_name = "Calendar"

    col_type_mapping = {
        'code' : 'SYMBOL',
        'm_nDate' : 'DATE',
        'm_nTime' : 'TIME',
        1 : 'FLOAT',
        3 : 'INT',
        5 : 'INT',
        13 : 'INT',
    }

    mssql_config = {
        'host' : '192.168.1.7',
        'username' : 'sa',
        'password' : 'passw0rd!'
    }

    ddb_config = {
        'host' : '192.168.1.7',
        'username' : 'admin',
        'password' : '123456'
    }

    num_workers = 8
    default_table_capacity = 10000
    ddb_dump_journal_fname = 'ddb_dump_journal.csv'


    def __init__(self):
        self.mssql_engine = sa.create_engine(
            "mssql+pyodbc://{username}:{password}@{host}/master?driver=ODBC+Driver+18+for+SQL+Server".format(**self.mssql_config),
            connect_args = {
                "TrustServerCertificate": "yes"
            }, echo=False
        )

        self.ddb_sess = ddb.session(self.ddb_config['host'], 8848)
        self.ddb_sess.login(self.ddb_config['username'], self.ddb_config['password'])


    def init_ddb_database(self, df_calendar):
        """
        1. 创建ddb_database
        2. 创建calendar表
        3. 创建数据分区表
        """
        # It is more convenient to have df_calendar passed in from the outside.
        #df_calendar = self.make_calendar_df()
        self.create_ddb_database(df_calendar)
        self.create_ddb_calendar(df_calendar)
        for hft_type_name in self.hft_type_list:
            self.create_ddb_partition_table(hft_type_name)


    def init_ddb_table_data(self, df_calendar, num_workers=None):
        """
        对每个股票进行循环,转录数据到分区表
        """
        stock_list = df_calendar['code'].unique().astype('str')

        # A Pool object must not be created repeatedly, so create it once outside the loops and pass it in.
        with Pool(self.num_workers if num_workers is None else num_workers) as pool:
            for hft_type_name in self.hft_type_list:
                print('Will work on hft type:', hft_type_name)
                with tqdm(stock_list) as pbar:
                    for stock_id in pbar:
                        pbar.set_description(f"Working on stock {stock_id}")
                        self.dump_hft_to_ddb(hft_type_name, stock_id, pbar=pbar, pool=pool)


    def _get_stock_date_list(self, cache=False):
        """
        Deprecated: This function is replaced by `create_ddb_calendar()`. 
        """
        if cache:
            with open('tmp.pkl', 'rb') as fin:
                stock_list, date_list = pickle.load(fin)
        else:
            with self.mssql_engine.connect() as conn:
                # Query the KLine table, mainly because it is the smallest one.
                stat = "select distinct S_INFO_WINDCODE, TRADE_DT from Level2BytesKline.dbo.KLine"
                rs = conn.execute(stat)
                stock_date_list = [(stock_name, date) for stock_name, date in rs.fetchall()]
            stock_list, date_list = zip(*stock_date_list)

        # cache
        #with open('tmp.pkl', 'wb') as fout:
        #    pickle.dump((stock_list, date_list), fout)

        return pd.Series(stock_list, dtype='str').unique(), \
                pd.Series(date_list, dtype='datetime64[D]').unique()


    def create_ddb_database(self, pd_calendar):
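        """
        Drop and recreate the COMPO-partitioned DFS database (date VALUE partition
        combined with a stock HASH partition, TSDB engine), then start a fresh dump journal.
        """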
        # Build `stock_list` and `date_list` from `pd_calendar`
        stock_list = pd_calendar['code'].unique().astype('str')
        date_list = pd_calendar['m_nDate'].unique().astype('datetime64[D]')

        # All high-frequency stock data can live in different tables of the same database.
        # The partition scheme is bound to the database, so every table in the same database must use the same scheme.
        # For high-frequency stock data a COMPO partition is used: the date sub-database is VALUE-partitioned
        # on m_nDate, and the stock sub-database is partitioned on code (HASH below, VALUE in the commented-out variant).
        if self.ddb_sess.existsDatabase(self.ddb_path):
            print('Will drop database:', self.ddb_path)
            self.ddb_sess.dropDatabase(self.ddb_path)

        # To create a COMPO-partitioned database, the two simply-partitioned sub-databases must be created first.
        # Here the data is partitioned first by date, then by stock.
        # Please note that when creating a DFS database with COMPO domain, 
        # the parameter dbPath for each partition level must be either an empty string or unspecified.
        db_date = self.ddb_sess.database('db_date', partitionType=keys.VALUE, partitions=date_list, dbPath='')

        # Using a DolphinDB script statement directly seems more convenient here.
        self.ddb_sess.run("""
            db_stock = database("", 5, [SYMBOL, 50])
        """)
        #self.ddb_sess.run("""
        #    db_stock = database("", 1, symbol({partitions}))
        #""".format(
        #    partitions = '`' + '`'.join(stock_list)
        #))

        self.ddb_sess.run("""
            {dbName} = database(
                directory = '{dbPath}', 
                partitionType = COMPO, 
                partitionScheme = [db_date, db_stock], 
                engine = "TSDB")
        """.format(
            dbName = self.ddb_dbname,
            dbPath = self.ddb_path
        ))

        self._load_ddb_dump_journal(recreate=True)


    def load_ddb_database(self):
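        """
        Attach to the already existing DFS database (created earlier by `create_ddb_database`)
        and load the existing dump journal instead of recreating it.
        """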
        db_date = self.ddb_sess.database('db_date', dbPath='')
        db_stock = self.ddb_sess.database('db_stock', dbPath='')
        
        self.ddb_sess.run("{dbName} = database(directory='{dbPath}')".format(
            dbName = self.ddb_dbname,
            dbPath = self.ddb_path
        ))

        self._load_ddb_dump_journal()

        
    def _load_ddb_dump_journal(self, recreate=False):
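        """
        Open the CSV dump journal that records which (type_name, stock_id) combinations
        have already been dumped. With `recreate=True` a new journal is started; otherwise
        the existing one is loaded and kept open for appending.
        """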
        if recreate or not Path(self.ddb_dump_journal_fname).exists():
            print('Will create new dump journal.')
            self.dump_journal_writer =  open(self.ddb_dump_journal_fname, 'w')
            self.dump_journal_writer.write("type_name,stock_id,status\n")
            self.dump_journal_writer.flush()
        else:
            print('Will load previous dump journal.')
            self.dump_journal_writer = open(self.ddb_dump_journal_fname, 'a')

        self.dump_journal_df = pd.read_csv(self.ddb_dump_journal_fname)
        self.dump_journal_df.set_index(['type_name', 'stock_id', 'status'], inplace=True)
        # dump_journal_df is loaded only once, at creation time, and never written to afterwards, so its index can be sorted here.
        self.dump_journal_df.sort_index(inplace=True)
        print('Did load the dump journal, shape', self.dump_journal_df.shape)
        #pprint(self.dump_journal_df.head())


    def create_ddb_calendar(self, df_calendar):
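        """
        Create the calendar table in DolphinDB: build an in-memory table, append
        `df_calendar` to it, then materialize it as a persistent (non-partitioned)
        table in the DFS database.
        """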
        mem_table = self.ddb_calendar_table_name + self.ddb_memory_table_suffix
        per_table = self.ddb_calendar_table_name
        # 1. Create a temporary in-memory table.
        # The number of calendar rows is roughly (number of stocks) * (number of trading days).
        self.ddb_sess.run("""
            {table_name} = table({capacity}:0, {col_names}, [{col_types}]);
        """.format(
            table_name = mem_table,
            capacity = 5000 * 1000,
            col_names = '`code`m_nDate',
            col_types = "SYMBOL, DATE"
        ))
        print('Did create the memory table')

        # 2. Insert all (code, date) rows into the in-memory table.
        appender = ddb.tableAppender(tableName=mem_table, ddbSession=self.ddb_sess)
        num = appender.append(df_calendar)
        print('Did append calendar data into ddb memory table, return code', num)

        # 3. Before creating the persistent table, a database object would normally be created from the path first.
        # However, it seems that one database can hold both partitioned and non-partitioned tables at the same time,
        # so no new database is created here for now.
        # Because the original database uses the TSDB engine, sortColumns must be given when calling createTable.
        #self.ddb_sess.run("""
        #    {db_name} = 
        #""")

        # 4. Create the persistent table directly from the in-memory table.
        if self.ddb_sess.existsTable(self.ddb_path, per_table):
            self.ddb_sess.dropTable(self.ddb_path, per_table)
        self.ddb_sess.run("""
            tableInsert(createTable(
                dbHandle={ddb_dbname},
                table={mem_table}, 
                tableName=`{per_table},
                sortColumns=`code`m_nDate,
                compressMethods={{"m_nDate":"delta"}}
            ), {mem_table})
        """.format(
            ddb_dbname = self.ddb_dbname,
            mem_table = mem_table,
            per_table = per_table
        ))
        print('Did create the persistent table with the memory table')


    def make_calendar_df(self):
        print('Will create calendar dataframe from SQL Server')
        # Query the KLine table, mainly because it is the smallest one.
        with self.mssql_engine.connect() as conn:
            stat = "select distinct S_INFO_WINDCODE, TRADE_DT from Level2BytesKline.dbo.KLine"
            rs = conn.execute(stat)
            stock_date_list = [(stock_name, date) for stock_name, date in rs.fetchall()]

        df_calendar = pd.DataFrame(stock_date_list, columns=['code', 'm_nDate'])
        df_calendar['m_nDate'] = make_date(df_calendar['m_nDate'])
        print('Did make the DataFrame for calendar')
        return df_calendar


    def _make_table_skeleton(self, hft_type_name, table_capacity=default_table_capacity):
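        """
        Create an in-memory DolphinDB table whose schema mirrors the protobuf descriptor
        of `hft_type_name`, and return the name of that table. The memory table is later
        used as the schema template for the partitioned table.
        """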

        def _make_tbl_config(field_list):
            """
            根据ProtoBuffEntity对象的Descriptor.fields,创建ddb标准的列名列表和列类型列表。
            """
            col_name_list, col_type_list = [], []
            for desc in field_list:
                col_name_list.append(desc.name)
                # Columns with an explicit type override; currently only the `code, `m_nDate and `m_nTime fields.
                if desc.name in self.col_type_mapping:
                    col_type_list.append(self.col_type_mapping[desc.name])
                # Map the ProtoBuffEntity type number to the corresponding DolphinDB type name.
                # If the default value is a list, the DolphinDB type must additionally be marked as an array type.
                # The ProtoBuffEntity type number only covers scalar types; arrays are detected via `default_value`.
                else:
                    col_type = self.col_type_mapping[desc.type]
                    if isinstance(desc.default_value, list):
                        col_type += '[]'
                    col_type_list.append(col_type)
            return col_name_list, col_type_list

        desc_obj = self.protobuff_desc_dict[hft_type_name]
        col_name_list, col_type_list = _make_tbl_config(desc_obj.fields)

        table_name = hft_type_name + self.ddb_memory_table_suffix
        print('-' * 80)
        print('Will create table structure:', table_name)

        self.ddb_sess.run("""
            {table_name} = table({capacity}:0, {col_names}, [{col_types}]);
        """.format(
            table_name = table_name,
            capacity = table_capacity,
            col_names = '`' + '`'.join(col_name_list),
            col_types = ', '.join([f"'{type_name}'" for type_name in col_type_list])
        ))
        res = self.ddb_sess.run(f"schema({table_name}).colDefs")
        pprint(res)
        print('-' * 80)
        return table_name


    def create_ddb_partition_table(self, hft_type_name):
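        """
        Create the partitioned DFS table for `hft_type_name`, using a small in-memory
        table built by `_make_table_skeleton` as its schema template.
        """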
        memory_table_name = self._make_table_skeleton(hft_type_name, 10)
        partition_table_name = hft_type_name + self.ddb_partition_table_suffix
        
        print('-' * 80)
        print('Will create partitioned table:', partition_table_name)

        self.ddb_sess.run("""
            {ddb_dbname}.createPartitionedTable(
                table = {memory_table_name}, 
                tableName = `{partition_table_name}, 
                partitionColumns = `m_nDate`code, 
                sortColumns = `code`m_nDate`m_nTime,
                compressMethods = {{"m_nDate":"delta", "m_nTime":"delta"}}
            )
        """.format(
            ddb_dbname = self.ddb_dbname,
            memory_table_name = memory_table_name,
            partition_table_name = partition_table_name
        ))

        res = self.ddb_sess.run(f"schema(loadTable('{self.ddb_path}', '{partition_table_name}')).colDefs")
        pprint(res)
        print('-' * 80)

    
    def dump_hft_to_ddb(self, type_name, stock_id, trade_date=None, pbar=None, pool=None):
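        """
        Dump all `type_name` rows for `stock_id` from SQL Server into the corresponding
        DolphinDB partitioned table, one (stock, trading day) row per worker task.
        Stocks already marked OK in the dump journal are skipped. `trade_date` is currently unused.
        """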
        if (type_name, stock_id, 'OK') in self.dump_journal_df.index:
            message = f"Will skip ({type_name}, {stock_id}) as it appears in the dump journal."
            if pbar is None:
                print(message)
            else:
                pbar.set_description(message)
            return
        
        self.dump_journal_writer.write(f"{type_name},{stock_id},START\n")
        self.dump_journal_writer.flush()
        
        # After some experimenting, batching the query per individual stock gives acceptable performance.
        # In MSSQL the index is on (S_INFO_WINDCODE, TRADE_DT).
        with self.mssql_engine.connect() as conn:
            stat = """
                select * from [Level2Bytes{mssql_type_name}].dbo.[{mssql_type_name}] 
                where S_INFO_WINDCODE='{stock_id}'
            """.format(
                mssql_type_name = self.mssql_name_dict[type_name],
                stock_id = stock_id
            )
            row_list = list(conn.execute(stat).fetchall())
            num_rows = len(row_list)

            if pbar:
                #pbar.set_description(f"Did get the result set for stock {stock_id} from mssql")
                pbar.set_description(f"Will work in parallel on dumping {stock_id} ({num_rows} rows)")
            else:
                print(f"Did get the result set for stock {stock_id} from mssql")

            # Each row holds all high-frequency records of one stock for one trading day.
            # Use multiprocessing to speed things up.
            
            #with Pool(self.num_workers if num_workers is None else num_workers) as pool:
            if pool is None:
                print("Will create a new Pool object, but this is not encouraged for large batch jobs.")
                pool = Pool(self.num_workers)

            with tqdm(total=num_rows, leave=False) as sub_pbar:
                for _ in pool.imap_unordered(
                    functools.partial(
                        dump_stock_daily_to_ddb, 
                        type_name = type_name,
                        stock_id = stock_id
                    ),
                    row_list
                ):
                    sub_pbar.update()

        self.dump_journal_writer.write(f"{type_name},{stock_id},OK\n")
        self.dump_journal_writer.flush()


def main():
    loader = DDBLoader()
    df_calendar = loader.make_calendar_df()

    loader.init_ddb_database(df_calendar)
    print('Did finish init_ddb_database')

    #loader.load_ddb_database()
    #print('Did load ddb database')

    loader.init_ddb_table_data(df_calendar)
    print('Did finish init_ddb_table_data')


if __name__ == '__main__':
    main()