{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f15b133a-12f3-4803-830d-820a471add0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import qlib\n",
    "from qlib.data import D\n",
    "from qlib.constant import REG_CN\n",
    "\n",
    "from qlib.data.pit import P, PRef"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b6e1089e-e51d-42ac-aee8-53d1a886a355",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pprint import pprint \n",
    "\n",
    "class PRelRef(P):\n",
    "    \n",
    "    def __init__(self, feature, rel_period):\n",
    "        super().__init__(feature)\n",
    "        self.rel_period = rel_period\n",
    "        self.unit = unit\n",
    "        \n",
    "    def __str__(self):\n",
    "        return f\"{super().__str__()}[{self.rel_period, self.unit}]\"\n",
    "    \n",
    "    def _load_feature(self, instrucument, start_index, end_index, cur_time):\n",
    "        #pprint(f\"{start_index}, {end_index}\")\n",
    "        #pprint(f\"{self.feature.get_longest_back_rolling()}, {self.feature.get_extended_window_size()}\")\n",
    "        #pprint(f\"{cur_time}, {self.rel_period}, {self.unit}\")\n",
    "        return self.feature.load(instrucument, start_index, end_index, cur_time, self.rel_period)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8730a9fb-9356-4847-b33d-370a3095df04",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from qlib.data.ops import ElemOperator\n",
    "from qlib.data.data import Cal\n",
    "\n",
    "def is_of_quarter(period:int, quarter:int) -> bool:\n",
    "    return (period - quarter) % 100 == 0\n",
    "\n",
    "\n",
    "class PDiff(P):\n",
    "    \"\"\"\n",
    "    还是继承P,而不是EleOperator,以减少麻烦。\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, feature, **kwargs):\n",
    "        super().__init__(feature)\n",
    "        self.rel_period = 1 if 'rel_period' not in kwargs else kwargs['rel_period']\n",
    "        self.skip_q1 = False if 'skip_q1' not in kwargs else kwargs['skip_q1']\n",
    "        \n",
    "    def _load_internal(self, instrument, start_index, end_index, freq):\n",
    "        _calendar = Cal.calendar(freq=freq)\n",
    "        resample_data = np.empty(end_index - start_index + 1, dtype=\"float32\")\n",
    "        \n",
    "        # 对日期区间逐一循环,考虑到使用PIT数据的模型一般最多到日频,单个股票序列长度最多到千级\n",
    "        for cur_index in range(start_index, end_index + 1):\n",
    "            cur_time = _calendar[cur_index]\n",
    "            # To load expression accurately, more historical data are required\n",
    "            start_ws, end_ws = self.get_extended_window_size()\n",
    "            if end_ws > 0:\n",
    "                raise ValueError(\n",
    "                    \"PIT database does not support referring to future period (e.g. expressions like `Ref('$$roewa_q', -1)` are not supported\"\n",
    "                )\n",
    "\n",
    "            # The calculated value will always the last element, so the end_offset is zero.\n",
    "            try:\n",
    "                s = self._load_feature(instrument, -start_ws, 0, cur_time)\n",
    "                pprint(s)\n",
    "                # 满足不需要做diff的条件:在需要跳过一季度的前提下,当前引用的财报期确实为一季度\n",
    "                if self.skip_q1 or is_of_quarter(s.index[-1], 1):\n",
    "                    resample_data[cur_index - start_index] = s.iloc[-1] if len(s) > 0 else np.nan\n",
    "                else:\n",
    "                    resample_data[cur_index - start_index] = (s.iloc[-1] - s.iloc[-2]) if len(s) > 1 else np.nan\n",
    "            except FileNotFoundError:\n",
    "                get_module_logger(\"base\").warning(f\"WARN: period data not found for {str(self)}\")\n",
    "                return pd.Series(dtype=\"float32\", name=str(self))\n",
    "\n",
    "        resample_series = pd.Series(\n",
    "            resample_data, index=pd.RangeIndex(start_index, end_index + 1), dtype=\"float32\", name=str(self)\n",
    "        )\n",
    "        return resample_series\n",
    "\n",
    "    def get_longest_back_rolling(self):\n",
    "        return self.feature.get_longest_back_rolling() + self.rel_period\n",
    "    \n",
    "    def get_extended_window_size(self):\n",
    "        # 这里需要考虑的是feature的windows size,而不仅仅是自身的windows size\n",
    "        lft_etd, rght_etd = self.feature.get_extended_window_size()\n",
    "        return lft_etd + self.rel_period, rght_etd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "abe001d9-ccb4-48a8-b5d0-91ef8862b14f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class PPairDiff(PairOperator):\n",
    "    \n",
    "    def __init__(self, feature_left, feature_right, **kwargs):\n",
    "        super().__init__(feature_left, feature_right)\n",
    "        self.rel_period = 1 if 'rel_period' not in kwargs else kwargs['rel_period']\n",
    "    \n",
    "\n",
    "    def _load_internal(self, instrument, start_index, end_index, *args):\n",
    "        assert any(\n",
    "            [isinstance(self.feature_left, Expression), self.feature_right, Expression]\n",
    "        ), \"at least one of two inputs is Expression instance\"\n",
    "\n",
    "        if isinstance(self.feature_left, Expression):\n",
    "            series_left = self.feature_left.load(instrument, start_index, end_index, *args)\n",
    "        else:\n",
    "            series_left = self.feature_left  # numeric value\n",
    "        if isinstance(self.feature_right, Expression):\n",
    "            series_right = self.feature_right.load(instrument, start_index, end_index, *args)\n",
    "        else:\n",
    "            series_right = self.feature_right\n",
    "\n",
    "        if self.N == 0:\n",
    "            series = getattr(series_left.expanding(min_periods=1), self.func)(series_right)\n",
    "        else:\n",
    "            series = getattr(series_left.rolling(self.N, min_periods=1), self.func)(series_right)\n",
    "        return series\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c7272be7-8df1-47b0-b18e-b631aca6e3cd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[41422:MainThread](2022-07-14 15:01:12,046) INFO - qlib.Initialization - [config.py:413] - default_conf: client.\n",
      "[41422:MainThread](2022-07-14 15:01:12,053) INFO - qlib.Initialization - [__init__.py:74] - qlib successfully initialized based on client settings.\n",
      "[41422:MainThread](2022-07-14 15:01:12,057) INFO - qlib.Initialization - [__init__.py:76] - data_path={'__DEFAULT_FREQ': PosixPath('/home/guofu/Workspaces/guofu/TslDataFeed/_data/test/target')}\n"
     ]
    }
   ],
   "source": [
    "qlib.init(provider_uri='_data/test/target/', region=REG_CN, custom_ops=[PDiff, PPairDiff])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "3939574b-80f4-48db-bd16-1636f92b2e02",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>PDiff($$净利润_q, skip_q1=True)</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>instrument</th>\n",
       "      <th>datetime</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"6\" valign=\"top\">sh600000</th>\n",
       "      <th>2021-03-26</th>\n",
       "      <td>1.593600e+10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2021-03-29</th>\n",
       "      <td>1.380300e+10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2021-03-30</th>\n",
       "      <td>1.380300e+10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2021-03-31</th>\n",
       "      <td>1.380300e+10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2021-04-01</th>\n",
       "      <td>1.380300e+10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2021-04-02</th>\n",
       "      <td>1.380300e+10</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                       PDiff($$净利润_q, skip_q1=True)\n",
       "instrument datetime                                \n",
       "sh600000   2021-03-26                  1.593600e+10\n",
       "           2021-03-29                  1.380300e+10\n",
       "           2021-03-30                  1.380300e+10\n",
       "           2021-03-31                  1.380300e+10\n",
       "           2021-04-01                  1.380300e+10\n",
       "           2021-04-02                  1.380300e+10"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D.features(['sh600000'], ['PDiff($$净利润_q, skip_q1=True)'], start_time='2021-03-26', end_time='2021-04-02', freq=\"day\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "679d0bcd-6975-43f3-b394-5e2e775a9b8f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "c1ef5dbb-930f-4d2d-ac3a-5287ddac6d6c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(202003 - 2) % 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c01d2ab-eafa-4492-bbe5-6e3c8382bd3c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}