Major changes: - Fix FlagMarketInjector to add market_0, market_1 columns based on instrument codes - Fix FlagSTInjector to create IsST column from ST_S, ST_Y flags - Update generate_beta_embedding.py to handle IsST creation conditionally - Add dump_polars_dataset.py for generating raw and processed datasets - Add debug_data_divergence.py for comparing gold-standard vs polars output Documentation: - Update BUG_ANALYSIS_FINAL.md with IsST column issue discovery - Update README.md with polars dataset generation instructions Key discovery: - The FlagSTInjector in the gold-standard qlib code fails silently - The VAE was trained without IsST column (341 features, not 342) - The polars pipeline correctly skips FlagSTInjector to match gold-standard Generated dataset structure (2026-02-23 to 2026-02-27): - Raw data: 18,291 rows × 204 columns - Processed data: 18,291 rows × 342 columns (341 for VAE input) - market_0, market_1 columns correctly added to feature_flag group
parent
4d382dc6bd
commit
8bd36c1939
@ -0,0 +1,254 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
Debug script to compare gold-standard qlib data vs polars-based pipeline.
|
||||||
|
|
||||||
|
This script helps identify where the data loading and processing pipeline
|
||||||
|
starts to diverge from the gold-standard qlib output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import pickle as pkl
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import polars as pl
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Paths
|
||||||
|
GOLD_RAW_PATH = "/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/data/raw_data_20190101_20190131.pkl"
|
||||||
|
GOLD_PROC_PATH = "/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/data/processed_data_20190101_20190131.pkl"
|
||||||
|
PROC_LIST_PATH = "/home/guofu/Workspaces/alpha/data_ops/tasks/dwm_feature_vae/dataset/csiallx_feature2_ntrla_flag_pnlnorm/proc_list.proc"
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent / "scripts"))
|
||||||
|
|
||||||
|
def compare_raw_data():
    """Summarize the gold-standard raw pickle and return it for later steps.

    Prints the frame's shape, index names, column groups, per-group column
    counts, and a few sample values from known columns.
    """
    banner = "=" * 80
    print(banner)
    print("STEP 1: Compare RAW DATA (before proc_list)")
    print(banner)

    # The gold-standard raw data is a pre-dumped qlib pickle on disk.
    with open(GOLD_RAW_PATH, "rb") as f:
        gold_raw = pkl.load(f)

    # Hoist the level-0 values once instead of recomputing per line.
    level0 = gold_raw.columns.get_level_values(0)
    groups = level0.unique().tolist()

    print("\nGold standard raw data:")
    print(f" Shape: {gold_raw.shape}")
    print(f" Index: {gold_raw.index.names}")
    print(f" Column groups: {groups}")

    # Per-group column counts.
    for grp in groups:
        print(f" {grp}: {(level0 == grp).sum()} columns")

    # Spot-check a handful of known (group, name) columns.
    sample_cols = [('feature', 'KMID'),
                   ('feature_ext', 'turnover'),
                   ('feature_ext', 'log_size')]
    print("\n Sample values (first 3 rows):")
    for col in sample_cols:
        if col in gold_raw.columns:
            print(f" {col}: {gold_raw[col].iloc[:3].tolist()}")

    return gold_raw
|
||||||
|
|
||||||
|
|
||||||
|
def compare_processed_data():
    """Summarize the gold-standard processed pickle and return it.

    Mirrors compare_raw_data() but reads the post-proc_list dump and also
    samples the industry-neutralized (_ntrl) variants of the key columns.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("STEP 2: Compare PROCESSED DATA (after proc_list)")
    print(banner)

    # Processed gold-standard data (after the full qlib proc_list ran).
    with open(GOLD_PROC_PATH, "rb") as f:
        gold_proc = pkl.load(f)

    level0 = gold_proc.columns.get_level_values(0)
    groups = level0.unique().tolist()

    print("\nGold standard processed data:")
    print(f" Shape: {gold_proc.shape}")
    print(f" Index: {gold_proc.index.names}")
    print(f" Column groups: {groups}")

    # Per-group column counts.
    for grp in groups:
        print(f" {grp}: {(level0 == grp).sum()} columns")

    # Sample both the raw and the neutralized variants of two columns.
    sample_cols = [('feature', 'KMID'), ('feature', 'KMID_ntrl'),
                   ('feature_ext', 'turnover'), ('feature_ext', 'turnover_ntrl')]
    print("\n Sample values (first 3 rows):")
    for col in sample_cols:
        if col in gold_proc.columns:
            print(f" {col}: {gold_proc[col].iloc[:3].tolist()}")

    return gold_proc
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_processor_pipeline(gold_raw, gold_proc):
    """Analyze what transformations happened in the proc_list.

    Loads the serialized processor list, prints each processor's class name,
    and reports how the column set changed between the raw and processed
    gold-standard frames (counts, additions/removals, _ntrl patterns).

    Args:
        gold_raw: gold-standard DataFrame before proc_list (MultiIndex columns).
        gold_proc: gold-standard DataFrame after proc_list.
    """
    print("\n" + "=" * 80)
    print("STEP 3: Analyze Processor Transformations")
    print("=" * 80)

    # Load the pickled processor pipeline (NOTE: pickle on trusted local
    # artifacts only — never on untrusted input).
    with open(PROC_LIST_PATH, "rb") as f:
        proc_list = pkl.load(f)

    print(f"\nProcessor pipeline ({len(proc_list)} processors):")
    for i, proc in enumerate(proc_list):
        print(f" [{i}] {type(proc).__name__}")

    # Analyze column changes.
    # FIX: the original printed "+{diff}", which renders as "+-N" when the
    # processed frame has FEWER columns; ":+d" always shows a correct sign.
    col_delta = gold_proc.shape[1] - gold_raw.shape[1]
    print("\nColumn count changes:")
    print(f" Before: {gold_raw.shape[1]} columns")
    print(f" After: {gold_proc.shape[1]} columns")
    print(f" Change: {col_delta:+d} columns")

    # Check which columns were added/removed.
    gold_raw_cols = set(gold_raw.columns)
    gold_proc_cols = set(gold_proc.columns)

    added_cols = gold_proc_cols - gold_raw_cols
    removed_cols = gold_raw_cols - gold_proc_cols

    print(f"\n Added columns: {len(added_cols)}")
    print(f" Removed columns: {len(removed_cols)}")

    if removed_cols:
        print(f" Removed: {list(removed_cols)[:10]}...")

    # Check feature column patterns: how many columns got a _ntrl
    # (industry-neutralized) twin in the 'feature' group.
    print("\nFeature column patterns in processed data:")
    feature_cols = [c for c in gold_proc.columns if c[0] == 'feature']
    ntrl_cols = [c for c in feature_cols if c[1].endswith('_ntrl')]
    raw_cols = [c for c in feature_cols if not c[1].endswith('_ntrl')]
    print(f" Total feature columns: {len(feature_cols)}")
    print(f" _ntrl columns: {len(ntrl_cols)}")
    print(f" raw columns: {len(raw_cols)}")
|
||||||
|
|
||||||
|
|
||||||
|
def check_polars_pipeline():
    """Run the polars-based pipeline and compare it against the gold standard.

    Covers steps 4-6 of the debug flow: load + merge data through the polars
    pipeline, compare column names against the gold-standard raw pickle, then
    compare summary statistics (min/max/mean) for the first common columns.

    All exceptions are caught and printed so the output of the earlier steps
    is preserved even when the polars pipeline cannot run.
    """
    print("\n" + "=" * 80)
    print("STEP 4: Generate data using Polars pipeline")
    print("=" * 80)

    try:
        # Only the names actually used below are imported; the original also
        # pulled in apply_feature_pipeline / filter_stock_universe unused.
        from generate_beta_embedding import load_all_data, merge_data_sources

        # Load data using polars pipeline.
        print("\nLoading data with polars pipeline...")
        df_alpha, df_kline, df_flag, df_industry = load_all_data(
            "2019-01-01", "2019-01-31"
        )

        print(f"\nPolars data sources loaded:")
        print(f" Alpha158: {df_alpha.shape}")
        print(f" Kline (market_ext): {df_kline.shape}")
        print(f" Flags: {df_flag.shape}")
        print(f" Industry: {df_industry.shape}")

        # Merge all sources into one frame.
        df_merged = merge_data_sources(df_alpha, df_kline, df_flag, df_industry)
        print(f"\nAfter merge: {df_merged.shape}")

        # Convert to pandas for easier comparison against the gold pickle.
        df_pandas = df_merged.to_pandas()
        df_pandas = df_pandas.set_index(['datetime', 'instrument'])

        print(f"\nAfter converting to pandas MultiIndex: {df_pandas.shape}")

        # Compare column names against the gold-standard raw data.
        with open(GOLD_RAW_PATH, "rb") as f:
            gold_raw = pkl.load(f)

        print("\n" + "=" * 80)
        print("STEP 5: Compare Column Names (Gold vs Polars)")
        print("=" * 80)

        # Gold columns are (group, name) tuples, polars columns are flat
        # strings, so both are compared via their str() form.
        gold_cols = set(str(c) for c in gold_raw.columns)
        polars_cols = set(str(c) for c in df_pandas.columns)

        common_cols = gold_cols & polars_cols
        only_in_gold = gold_cols - polars_cols
        only_in_polars = polars_cols - gold_cols

        print(f"\n Common columns: {len(common_cols)}")
        print(f" Only in gold standard: {len(only_in_gold)}")
        print(f" Only in polars: {len(only_in_polars)}")

        if only_in_gold:
            print(f"\n Columns only in gold standard (first 20):")
            for col in list(only_in_gold)[:20]:
                print(f" {col}")

        if only_in_polars:
            print(f"\n Columns only in polars (first 20):")
            for col in list(only_in_polars)[:20]:
                print(f" {col}")

        # Check common columns values.
        print("\n" + "=" * 80)
        print("STEP 6: Compare Values for Common Columns")
        print("=" * 80)

        # FIX: the original matched columns with a nested scan over both
        # column lists (O(gold * polars)). Build a one-pass string->column
        # lookup instead; setdefault keeps the FIRST polars column for a
        # given key, matching the original `break` semantics.
        polars_by_str = {}
        for pc in df_pandas.columns:
            polars_by_str.setdefault(str(pc), pc)
        common_tuples = [(gc, polars_by_str[str(gc)])
                         for gc in gold_raw.columns
                         if str(gc) in polars_by_str]

        print(f"\nComparing {len(common_tuples)} common columns...")

        # Compare summary stats for the first few common columns only; this
        # is a smoke test, not an exhaustive diff.
        matching_count = 0
        diff_count = 0
        for gc, pc in common_tuples[:20]:
            gold_vals = gold_raw[gc].dropna().values
            polars_vals = df_pandas[pc].dropna().values

            if len(gold_vals) > 0 and len(polars_vals) > 0:
                # Compare min, max, mean with a relative tolerance.
                if np.allclose([gold_vals.min(), gold_vals.max(), gold_vals.mean()],
                               [polars_vals.min(), polars_vals.max(), polars_vals.mean()],
                               rtol=1e-5):
                    matching_count += 1
                else:
                    diff_count += 1
                    if diff_count <= 3:  # only show the first few diffs
                        print(f" DIFF: {gc}")
                        print(f" Gold: min={gold_vals.min():.6f}, max={gold_vals.max():.6f}, mean={gold_vals.mean():.6f}")
                        print(f" Polars: min={polars_vals.min():.6f}, max={polars_vals.max():.6f}, mean={polars_vals.mean():.6f}")

        print(f"\n Matching columns: {matching_count}")
        print(f" Different columns: {diff_count}")

    except Exception as e:
        # Broad catch is deliberate: this step is best-effort debugging and
        # must not abort the script; the traceback is printed for diagnosis.
        print(f"\nError running polars pipeline: {e}")
        import traceback
        traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    banner = "=" * 80
    print(banner)
    print("DATA DIVERGENCE DEBUG SCRIPT")
    print("Comparing gold-standard qlib output vs polars-based pipeline")
    print(banner)

    # Steps 1-2: summarize the gold-standard raw and processed dumps.
    gold_raw = compare_raw_data()
    gold_proc = compare_processed_data()

    # Step 3: inspect the serialized processor pipeline and column deltas.
    analyze_processor_pipeline(gold_raw, gold_proc)

    # Steps 4-6: run the polars pipeline and compare names/values.
    check_polars_pipeline()

    print("\n" + banner)
    print("DEBUG COMPLETE")
    print(banner)
|
||||||
@ -0,0 +1,330 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
Script to dump raw and processed datasets using the polars-based pipeline.
|
||||||
|
|
||||||
|
This generates:
|
||||||
|
1. Raw data (before applying processors) - equivalent to qlib's handler output
|
||||||
|
2. Processed data (after applying all processors) - ready for VAE encoding
|
||||||
|
|
||||||
|
Date range: 2026-02-23 to today (2026-02-27)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import pickle as pkl
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from generate_beta_embedding import (
|
||||||
|
load_all_data,
|
||||||
|
merge_data_sources,
|
||||||
|
filter_stock_universe,
|
||||||
|
DiffProcessor,
|
||||||
|
FlagMarketInjector,
|
||||||
|
ColumnRemover,
|
||||||
|
FlagToOnehot,
|
||||||
|
IndusNtrlInjector,
|
||||||
|
RobustZScoreNorm,
|
||||||
|
Fillna,
|
||||||
|
load_qlib_processor_params,
|
||||||
|
ALPHA158_COLS,
|
||||||
|
INDUSTRY_FLAG_COLS,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Date range
|
||||||
|
START_DATE = "2026-02-23"
|
||||||
|
END_DATE = "2026-02-27"
|
||||||
|
|
||||||
|
# Output directory
|
||||||
|
OUTPUT_DIR = Path(__file__).parent.parent / "data_polars"
|
||||||
|
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame:
    """
    Apply the full processor pipeline (equivalent to qlib's proc_list).

    This mimics the qlib proc_list:
    0. Diff: Adds diff features for market_ext columns
    1. FlagMarketInjector: Adds market_0, market_1
    2. FlagSTInjector: SKIPPED (fails even in gold-standard)
    3. ColumnRemover: Removes log_size_diff, IsN, IsZt, IsDt
    4. FlagToOnehot: Converts 29 industry flags to indus_idx
    5. IndusNtrlInjector: Industry neutralization for feature
    6. IndusNtrlInjector: Industry neutralization for feature_ext
    7. RobustZScoreNorm: Normalization using pre-fitted qlib params
    8. Fillna: Fill NaN with 0

    Args:
        df: merged polars frame (output of merge_data_sources).

    Returns:
        The frame with all processors applied, ready for VAE encoding.

    NOTE(review): the processors mutate/extend `df` sequentially; the step
    order must not be changed, because later steps consume columns created
    by earlier ones (e.g. _diff columns, indus_idx).
    """
    print("=" * 60)
    print("Applying processor pipeline")
    print("=" * 60)

    # market_ext columns (4 base)
    market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']

    # market_flag columns (12 total before ColumnRemover)
    market_flag_cols = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
                        'open_limit', 'close_limit', 'low_limit',
                        'open_stop', 'close_stop', 'high_stop']

    # Step 1: Diff Processor
    print("\n[1] Applying Diff processor...")
    diff_processor = DiffProcessor(market_ext_base)
    df = diff_processor.process(df)

    # After Diff: market_ext has 8 columns (4 base + 4 diff)
    market_ext_cols = market_ext_base + [f"{c}_diff" for c in market_ext_base]

    # Step 2: FlagMarketInjector (adds market_0, market_1)
    print("[2] Applying FlagMarketInjector...")
    flag_injector = FlagMarketInjector()
    df = flag_injector.process(df)

    # Add market_0, market_1 to flag list
    market_flag_with_market = market_flag_cols + ['market_0', 'market_1']

    # Step 3: FlagSTInjector - SKIPPED (fails even in gold-standard)
    # The gold-standard qlib run never produced IsST (the injector fails
    # silently there), so skipping keeps 341 features as the VAE expects.
    print("[3] Skipping FlagSTInjector (as per gold-standard behavior)...")
    market_flag_with_st = market_flag_with_market  # No IsST added

    # Step 4: ColumnRemover
    print("[4] Applying ColumnRemover...")
    columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt']
    remover = ColumnRemover(columns_to_remove)
    df = remover.process(df)

    # Update column lists after removal so the bookkeeping lists stay in
    # sync with the frame's actual columns.
    market_ext_cols = [c for c in market_ext_cols if c not in columns_to_remove]
    market_flag_with_st = [c for c in market_flag_with_st if c not in columns_to_remove]

    print(f" Removed columns: {columns_to_remove}")
    print(f" Remaining market_ext: {len(market_ext_cols)} columns")
    print(f" Remaining market_flag: {len(market_flag_with_st)} columns")

    # Step 5: FlagToOnehot
    print("[5] Applying FlagToOnehot...")
    flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS)
    df = flag_to_onehot.process(df)

    # Step 6 & 7: IndusNtrlInjector (adds a *_ntrl twin for each input col)
    print("[6] Applying IndusNtrlInjector for alpha158...")
    alpha158_cols = ALPHA158_COLS.copy()
    indus_ntrl_alpha = IndusNtrlInjector(alpha158_cols, suffix='_ntrl')
    df = indus_ntrl_alpha.process(df)

    print("[7] Applying IndusNtrlInjector for market_ext...")
    indus_ntrl_ext = IndusNtrlInjector(market_ext_cols, suffix='_ntrl')
    df = indus_ntrl_ext.process(df)

    # Build column lists for normalization
    alpha158_ntrl_cols = [f"{c}_ntrl" for c in alpha158_cols]
    market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_cols]

    # Step 8: RobustZScoreNorm
    print("[8] Applying RobustZScoreNorm...")
    # NOTE(review): this ordering (alpha_ntrl, alpha, ext_ntrl, ext) must
    # match the layout of the pre-fitted qlib mean_train/std_train vectors
    # loaded below — TODO confirm against the qlib processor dump.
    norm_feature_cols = alpha158_ntrl_cols + alpha158_cols + market_ext_ntrl_cols + market_ext_cols

    qlib_params = load_qlib_processor_params()

    # Verify parameter shape (warn only; normalization proceeds regardless)
    expected_features = len(norm_feature_cols)
    if qlib_params['mean_train'].shape[0] != expected_features:
        print(f" WARNING: Feature count mismatch! Expected {expected_features}, "
              f"got {qlib_params['mean_train'].shape[0]}")

    robust_norm = RobustZScoreNorm(
        norm_feature_cols,
        clip_range=(-3, 3),
        use_qlib_params=True,
        qlib_mean=qlib_params['mean_train'],
        qlib_std=qlib_params['std_train']
    )
    df = robust_norm.process(df)

    # Step 9: Fillna (applied to every final feature column, flags included)
    print("[9] Applying Fillna...")
    final_feature_cols = norm_feature_cols + market_flag_with_st + ['indus_idx']
    fillna = Fillna()
    df = fillna.process(df, final_feature_cols)

    print("\n" + "=" * 60)
    print("Processor pipeline complete!")
    print(f" Normalized features: {len(norm_feature_cols)}")
    print(f" Market flags: {len(market_flag_with_st)}")
    print(f" Total features (with indus_idx): {len(final_feature_cols)}")
    print("=" * 60)

    return df
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame":
    """
    Convert polars DataFrame to pandas DataFrame with MultiIndex columns.
    This matches the format of qlib's output.

    Each flat column name is assigned a (group, name) pair: 'feature'
    (alpha158 and their _ntrl twins), 'feature_ext' (market_ext base/diff
    columns and their _ntrl twins), 'feature_flag', 'indus_flag', 'st_flag',
    'indus_idx'; anything unrecognized goes to 'other' with a warning.

    Args:
        df_polars: frame with flat column names; may still carry 'datetime'
            and 'instrument' columns, which become the row MultiIndex.

    Returns:
        pandas DataFrame with a (datetime, instrument) row MultiIndex and
        (group, name) column MultiIndex.
    """
    import pandas as pd

    # Convert to pandas
    df = df_polars.to_pandas()

    # Check if datetime and instrument are columns
    if 'datetime' in df.columns and 'instrument' in df.columns:
        # Set MultiIndex
        df = df.set_index(['datetime', 'instrument'])
    # If they're not in columns, assume they're already the index

    # Drop raw columns that shouldn't be in processed data
    raw_cols_to_drop = ['Turnover', 'FreeTurnover', 'MarketValue']
    existing_raw_cols = [c for c in raw_cols_to_drop if c in df.columns]
    if existing_raw_cols:
        df = df.drop(columns=existing_raw_cols)

    # Define column sets used for group assignment
    alpha158_base = set(ALPHA158_COLS)
    market_ext_base = {'turnover', 'free_turnover', 'log_size', 'con_rating_strength'}
    # FIX: derive the diff set from the base set so the two cannot drift.
    # The original hand-listed it and omitted 'log_size_diff'; that column is
    # normally removed by ColumnRemover, but if it ever survives it belongs
    # in 'feature_ext', not the 'other' fallback.
    market_ext_diff = {f"{c}_diff" for c in market_ext_base}
    market_ext_all = market_ext_base | market_ext_diff
    feature_flag_cols = {'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
                         'open_limit', 'close_limit', 'low_limit',
                         'open_stop', 'close_stop', 'high_stop',
                         'market_0', 'market_1', 'IsST'}
    st_flag_cols = {'ST_S', 'ST_Y', 'ST_T', 'ST_L', 'ST_Z', 'ST_X'}

    ntrl_suffix_len = len('_ntrl')
    # Build MultiIndex columns based on column name patterns
    columns_with_group = []
    for col in df.columns:
        if col == 'indus_idx':
            columns_with_group.append(('indus_idx', col))
        elif col in feature_flag_cols:
            columns_with_group.append(('feature_flag', col))
        elif col.endswith('_ntrl'):
            # Neutralized columns inherit the group of their base column.
            base_name = col[:-ntrl_suffix_len]
            if base_name in alpha158_base:
                columns_with_group.append(('feature', col))
            elif base_name in market_ext_all:
                columns_with_group.append(('feature_ext', col))
            else:
                columns_with_group.append(('feature', col))  # Default to feature
        elif col in alpha158_base:
            columns_with_group.append(('feature', col))
        elif col in market_ext_all:
            columns_with_group.append(('feature_ext', col))
        elif col in INDUSTRY_FLAG_COLS:
            columns_with_group.append(('indus_flag', col))
        elif col in st_flag_cols:
            columns_with_group.append(('st_flag', col))
        else:
            # Unknown column - print warning so schema drift is visible
            print(f" Warning: Unknown column '{col}', assigning to 'other' group")
            columns_with_group.append(('other', col))

    # Create MultiIndex columns
    df.columns = pd.MultiIndex.from_tuples(columns_with_group)

    return df
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Dump the raw and processed datasets produced by the polars pipeline.

    Loads and merges all data sources for [START_DATE, END_DATE], pickles
    the raw (pre-processor) frame, runs the full processor pipeline, pickles
    the processed frame, and prints verification output (column groups and
    presence of market_0/market_1). Both pickles land in OUTPUT_DIR.
    """
    print("=" * 80)
    print("Dumping Polars Dataset")
    print("=" * 80)
    print(f"Date range: {START_DATE} to {END_DATE}")
    print(f"Output directory: {OUTPUT_DIR}")
    print()

    # Step 1: Load all data
    print("Step 1: Loading data from parquet...")
    df_alpha, df_kline, df_flag, df_industry = load_all_data(START_DATE, END_DATE)
    print(f" Alpha158 shape: {df_alpha.shape}")
    print(f" Kline (market_ext) shape: {df_kline.shape}")
    print(f" Flags shape: {df_flag.shape}")
    print(f" Industry shape: {df_industry.shape}")

    # Step 2: Merge data sources (also applies the csiallx universe filter,
    # per the print below — presumably inside merge_data_sources; confirm)
    print("\nStep 2: Merging data sources...")
    df_merged = merge_data_sources(df_alpha, df_kline, df_flag, df_industry)
    print(f" Merged shape (after csiallx filter): {df_merged.shape}")

    # Step 3: Save raw data (before processors)
    print("\nStep 3: Saving raw data (before processors)...")

    # Keep columns that match qlib's raw output format
    # Include datetime and instrument for MultiIndex conversion
    raw_columns = (
        ['datetime', 'instrument'] +  # Index columns
        ALPHA158_COLS +  # feature group
        ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] +  # feature_ext base
        ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',  # market_flag from kline
         'open_limit', 'close_limit', 'low_limit',
         'open_stop', 'close_stop', 'high_stop'] +
        INDUSTRY_FLAG_COLS +  # indus_flag
        (['ST_S', 'ST_Y'] if 'ST_S' in df_merged.columns else [])  # st_flag (if available)
    )

    # Filter to available columns (some may be missing for this date range)
    available_raw_cols = [c for c in raw_columns if c in df_merged.columns]
    print(f" Selecting {len(available_raw_cols)} columns for raw data...")
    df_raw_polars = df_merged.select(available_raw_cols)

    # Convert to pandas with MultiIndex
    df_raw_pd = convert_to_multiindex_df(df_raw_polars)

    raw_output_path = OUTPUT_DIR / f"raw_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl"
    with open(raw_output_path, "wb") as f:
        pkl.dump(df_raw_pd, f)
    print(f" Saved raw data to: {raw_output_path}")
    print(f" Raw data shape: {df_raw_pd.shape}")
    print(f" Column groups: {df_raw_pd.columns.get_level_values(0).unique().tolist()}")

    # Step 4: Apply processor pipeline (to the FULL merged frame, not the
    # raw-column subset — the processors need the extra source columns)
    print("\nStep 4: Applying processor pipeline...")
    df_processed = apply_processor_pipeline(df_merged)

    # Step 5: Save processed data
    print("\nStep 5: Saving processed data (after processors)...")

    # Convert to pandas with MultiIndex
    df_processed_pd = convert_to_multiindex_df(df_processed)

    processed_output_path = OUTPUT_DIR / f"processed_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl"
    with open(processed_output_path, "wb") as f:
        pkl.dump(df_processed_pd, f)
    print(f" Saved processed data to: {processed_output_path}")
    print(f" Processed data shape: {df_processed_pd.shape}")
    print(f" Column groups: {df_processed_pd.columns.get_level_values(0).unique().tolist()}")

    # Count columns per group
    print("\n Column counts by group:")
    for grp in df_processed_pd.columns.get_level_values(0).unique().tolist():
        count = (df_processed_pd.columns.get_level_values(0) == grp).sum()
        print(f" {grp}: {count} columns")

    # Step 6: Verify column counts — the market_0/market_1 flags were the
    # original bug, so their presence is checked explicitly
    print("\n" + "=" * 80)
    print("Verification")
    print("=" * 80)

    feature_flag_cols = [c[1] for c in df_processed_pd.columns if c[0] == 'feature_flag']
    has_market_0 = 'market_0' in feature_flag_cols
    has_market_1 = 'market_1' in feature_flag_cols

    print(f" feature_flag columns: {feature_flag_cols}")
    print(f" Has market_0: {has_market_0}")
    print(f" Has market_1: {has_market_1}")

    if has_market_0 and has_market_1:
        print("\n SUCCESS: market_0 and market_1 columns are present!")
    else:
        print("\n WARNING: market_0 or market_1 columns are missing!")

    print("\n" + "=" * 80)
    print("Dataset dump complete!")
    print("=" * 80)
|
||||||
|
|
||||||
|
|
||||||
|
# Standard script entry point: run the full dump when executed directly.
if __name__ == "__main__":
    main()
|
||||||
Loading…
Reference in new issue