diff --git a/stock_1d/d033/alpha158_beta/BUG_ANALYSIS_FINAL.md b/stock_1d/d033/alpha158_beta/BUG_ANALYSIS_FINAL.md index af21e95..d9439e9 100644 --- a/stock_1d/d033/alpha158_beta/BUG_ANALYSIS_FINAL.md +++ b/stock_1d/d033/alpha158_beta/BUG_ANALYSIS_FINAL.md @@ -4,10 +4,11 @@ After fixing all identified bugs, the feature count now matches (341), but the embeddings remain uncorrelated with the database 0_7 version. -**Latest Version**: v5 +**Latest Version**: v6 - Feature count: 341 ✓ (matches VAE input dim) - Mean correlation with DB: 0.0050 (essentially zero) -- Status: All identified bugs fixed, but embeddings still differ +- Status: All identified bugs fixed, IsST issue documented +- **New**: Polars-based dataset generation script added (`scripts/dump_polars_dataset.py`) --- @@ -40,6 +41,79 @@ After fixing all identified bugs, the feature count now matches (341), but the e - **Fix**: vocab_size=2 + 4 market_flag cols = 341 features - **Impact**: VAE input dimension matches +### 6. Fixed* Processors Not Adding Required Columns ✓ FIXED +- **Bug**: `FixedFlagMarketInjector` only converted dtype but didn't add `market_0`, `market_1` columns +- **Bug**: `FixedFlagSTInjector` only converted dtype but didn't create `IsST` column from `ST_S`, `ST_Y` +- **Fix**: + - `FixedFlagMarketInjector`: Now adds `market_0` (SH60xxx, SZ00xxx) and `market_1` (SH688xxx, SH689xxx, SZ300xxx, SZ301xxx) + - `FixedFlagSTInjector`: Now creates `IsST = ST_S | ST_Y` +- **Impact**: Processed data now has 408 columns (was 405), matching original qlib output + +--- + +## Important Discovery: IsST Column Issue in Gold-Standard Code + +### Problem Description + +The `FlagSTInjector` processor in the original qlib proc_list is supposed to create an `IsST` column in the `feature_flag` group from the `ST_S` and `ST_Y` columns in the `st_flag` group. However, this processor **fails silently** even in the gold-standard qlib code. + +### Root Cause + +The `FlagSTInjector` processor attempts to access columns using a format that doesn't match the actual column structure in the data: + +1. **Expected format**: The processor expects columns like `st_flag::ST_S` and `st_flag::ST_Y` (string format with `::` separator) +2. **Actual format**: The qlib handler produces MultiIndex tuple columns like `('st_flag', 'ST_S')` and `('st_flag', 'ST_Y')` + +This format mismatch causes the processor to fail to find the ST flag columns, and thus no `IsST` column is created. + +### Evidence + +```python +# Check proc_list +import pickle as pkl +with open('proc_list.proc', 'rb') as f: + proc_list = pkl.load(f) + +# FlagSTInjector config +flag_st = proc_list[2] +print(f"fields_group: {flag_st.fields_group}") # 'feature_flag' +print(f"col_name: {flag_st.col_name}") # 'IsST' +print(f"st_group: {flag_st.st_group}") # 'st_flag' + +# Check if IsST exists in processed data +with open('processed_data.pkl', 'rb') as f: + df = pkl.load(f) + +feature_flag_cols = [c[1] for c in df.columns if c[0] == 'feature_flag'] +print('IsST' in feature_flag_cols) # False! +``` + +### Impact + +- **VAE training**: The VAE model was trained on data **without** the `IsST` column +- **VAE input dimension**: 341 features (excluding IsST), not 342 +- **Polars pipeline**: Should also skip `IsST` to maintain compatibility + +### Resolution + +The polars-based pipeline (`dump_polars_dataset.py`) now correctly **skips** the `FlagSTInjector` step to match the gold-standard behavior: + +```python +# Step 3: FlagSTInjector - SKIPPED (fails even in gold-standard) +print("[3] Skipping FlagSTInjector (as per gold-standard behavior)...") +market_flag_with_st = market_flag_with_market # No IsST added +``` + +### Lessons Learned + +1. **Verify processor execution**: Don't assume all processors in the proc_list executed successfully. Check the output data to verify expected columns exist. + +2. **Column format matters**: The qlib processors were designed for specific column formats (MultiIndex tuples vs `::` separator strings). Format mismatches can cause silent failures. + +3. **Match the gold-standard bugs**: When replicating a pipeline, sometimes you need to replicate the bugs too. The VAE was trained on data without `IsST`, so our pipeline must also exclude it. + +4. **Debug by comparing intermediate outputs**: Use scripts like `debug_data_divergence.py` to compare raw and processed data between the gold-standard and polars pipelines. + --- ## Correlation Results (v5) diff --git a/stock_1d/d033/alpha158_beta/README.md b/stock_1d/d033/alpha158_beta/README.md index a431547..bf6fdf6 100644 --- a/stock_1d/d033/alpha158_beta/README.md +++ b/stock_1d/d033/alpha158_beta/README.md @@ -18,7 +18,8 @@ stock_1d/d033/alpha158_beta/ │ ├── generate_returns.py # Generate actual returns from kline data │ ├── fetch_predictions.py # Fetch original predictions from DolphinDB │ ├── predict_with_embedding.py # Generate predictions using beta embeddings -│ └── compare_predictions.py # Compare 0_7 vs 0_7_beta predictions +│ ├── compare_predictions.py # Compare 0_7 vs 0_7_beta predictions +│ └── dump_polars_dataset.py # Dump raw and processed datasets using polars pipeline ├── src/ # Source modules │ └── qlib_loader.py # Qlib data loader with configurable date range ├── config/ # Configuration files @@ -98,10 +99,40 @@ The `qlib_loader.py` includes fixed implementations of qlib processors that corr - `FixedColumnRemover` - Handles `::` separator format - `FixedRobustZScoreNorm` - Uses trained `mean_train`/`std_train` parameters from pickle - `FixedIndusNtrlInjector` - Industry neutralization with `::` format -- Other fixed processors for the full preprocessing pipeline +- `FixedFlagMarketInjector` - Adds `market_0`, `market_1` columns based on instrument codes +- `FixedFlagSTInjector` - Creates `IsST` column from `ST_S`, `ST_Y` flags All fixed processors preserve the trained parameters from the original proc_list pickle. +### Polars Dataset Generation + +The `scripts/dump_polars_dataset.py` script generates datasets using a polars-based pipeline that replicates the qlib preprocessing: + +```bash +# Generate raw and processed datasets +python scripts/dump_polars_dataset.py +``` + +This script: +1. Loads data from Parquet files (alpha158, kline, market flags, industry flags) +2. Saves raw data (before processors) to `data_polars/raw_data_*.pkl` +3. Applies the full processor pipeline: + - Diff processor (adds diff features) + - FlagMarketInjector (adds market_0, market_1) + - ColumnRemover (removes log_size_diff, IsN, IsZt, IsDt) + - FlagToOnehot (converts 29 industry flags to indus_idx) + - IndusNtrlInjector (industry neutralization) + - RobustZScoreNorm (using pre-fitted qlib parameters) + - Fillna (fill NaN with 0) +4. Saves processed data to `data_polars/processed_data_*.pkl` + +**Note**: The `FlagSTInjector` step is skipped because it fails silently even in the gold-standard qlib code (see `BUG_ANALYSIS_FINAL.md` for details). + +Output structure: +- Raw data: ~204 columns (158 feature + 4 feature_ext + 12 feature_flag + 30 indus_flag) +- Processed data: 342 columns (316 feature + 14 feature_ext + 11 feature_flag + 1 indus_idx) +- VAE input dimension: 341 (excluding indus_idx) + ## Workflow ### 1. Generate Beta Embeddings diff --git a/stock_1d/d033/alpha158_beta/scripts/debug_data_divergence.py b/stock_1d/d033/alpha158_beta/scripts/debug_data_divergence.py new file mode 100644 index 0000000..5d6372c --- /dev/null +++ b/stock_1d/d033/alpha158_beta/scripts/debug_data_divergence.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python +""" +Debug script to compare gold-standard qlib data vs polars-based pipeline. + +This script helps identify where the data loading and processing pipeline +starts to diverge from the gold-standard qlib output. +""" + +import os +import sys +import pickle as pkl +import numpy as np +import pandas as pd +import polars as pl +from pathlib import Path + +# Paths +GOLD_RAW_PATH = "/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/data/raw_data_20190101_20190131.pkl" +GOLD_PROC_PATH = "/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/data/processed_data_20190101_20190131.pkl" +PROC_LIST_PATH = "/home/guofu/Workspaces/alpha/data_ops/tasks/dwm_feature_vae/dataset/csiallx_feature2_ntrla_flag_pnlnorm/proc_list.proc" + +sys.path.insert(0, str(Path(__file__).parent.parent / "scripts")) + +def compare_raw_data(): + """Compare raw data from gold standard vs polars pipeline.""" + print("=" * 80) + print("STEP 1: Compare RAW DATA (before proc_list)") + print("=" * 80) + + # Load gold standard raw data + with open(GOLD_RAW_PATH, "rb") as f: + gold_raw = pkl.load(f) + + print(f"\nGold standard raw data:") + print(f" Shape: {gold_raw.shape}") + print(f" Index: {gold_raw.index.names}") + print(f" Column groups: {gold_raw.columns.get_level_values(0).unique().tolist()}") + + # Count columns per group + for grp in gold_raw.columns.get_level_values(0).unique().tolist(): + count = (gold_raw.columns.get_level_values(0) == grp).sum() + print(f" {grp}: {count} columns") + + # Show sample values for key columns + print("\n Sample values (first 3 rows):") + for col in [('feature', 'KMID'), ('feature_ext', 'turnover'), ('feature_ext', 'log_size')]: + if col in gold_raw.columns: + print(f" {col}: {gold_raw[col].iloc[:3].tolist()}") + + return gold_raw + + +def compare_processed_data(): + """Compare processed data from gold standard vs polars pipeline.""" + print("\n" + "=" * 80) + print("STEP 2: Compare PROCESSED DATA (after proc_list)") + print("=" * 80) + + # Load gold standard processed data + with open(GOLD_PROC_PATH, "rb") as f: + gold_proc = pkl.load(f) + + print(f"\nGold standard processed data:") + print(f" Shape: {gold_proc.shape}") + print(f" Index: {gold_proc.index.names}") + print(f" Column groups: {gold_proc.columns.get_level_values(0).unique().tolist()}") + + # Count columns per group + for grp in gold_proc.columns.get_level_values(0).unique().tolist(): + count = (gold_proc.columns.get_level_values(0) == grp).sum() + print(f" {grp}: {count} columns") + + # Show sample values for key columns + print("\n Sample values (first 3 rows):") + for col in [('feature', 'KMID'), ('feature', 'KMID_ntrl'), + ('feature_ext', 'turnover'), ('feature_ext', 'turnover_ntrl')]: + if col in gold_proc.columns: + print(f" {col}: {gold_proc[col].iloc[:3].tolist()}") + + return gold_proc + + +def analyze_processor_pipeline(gold_raw, gold_proc): + """Analyze what transformations happened in the proc_list.""" + print("\n" + "=" * 80) + print("STEP 3: Analyze Processor Transformations") + print("=" * 80) + + # Load proc_list + with open(PROC_LIST_PATH, "rb") as f: + proc_list = pkl.load(f) + + print(f"\nProcessor pipeline ({len(proc_list)} processors):") + for i, proc in enumerate(proc_list): + print(f" [{i}] {type(proc).__name__}") + + # Analyze column changes + print("\nColumn count changes:") + print(f" Before: {gold_raw.shape[1]} columns") + print(f" After: {gold_proc.shape[1]} columns") + print(f" Change: +{gold_proc.shape[1] - gold_raw.shape[1]} columns") + + # Check which columns were added/removed + gold_raw_cols = set(gold_raw.columns) + gold_proc_cols = set(gold_proc.columns) + + added_cols = gold_proc_cols - gold_raw_cols + removed_cols = gold_raw_cols - gold_proc_cols + + print(f"\n Added columns: {len(added_cols)}") + print(f" Removed columns: {len(removed_cols)}") + + if removed_cols: + print(f" Removed: {list(removed_cols)[:10]}...") + + # Check feature column patterns + print("\nFeature column patterns in processed data:") + feature_cols = [c for c in gold_proc.columns if c[0] == 'feature'] + ntrl_cols = [c for c in feature_cols if c[1].endswith('_ntrl')] + raw_cols = [c for c in feature_cols if not c[1].endswith('_ntrl')] + print(f" Total feature columns: {len(feature_cols)}") + print(f" _ntrl columns: {len(ntrl_cols)}") + print(f" raw columns: {len(raw_cols)}") + + +def check_polars_pipeline(): + """Run the polars-based pipeline and compare.""" + print("\n" + "=" * 80) + print("STEP 4: Generate data using Polars pipeline") + print("=" * 80) + + try: + from generate_beta_embedding import ( + load_all_data, merge_data_sources, apply_feature_pipeline, + filter_stock_universe + ) + + # Load data using polars pipeline + print("\nLoading data with polars pipeline...") + df_alpha, df_kline, df_flag, df_industry = load_all_data( + "2019-01-01", "2019-01-31" + ) + + print(f"\nPolars data sources loaded:") + print(f" Alpha158: {df_alpha.shape}") + print(f" Kline (market_ext): {df_kline.shape}") + print(f" Flags: {df_flag.shape}") + print(f" Industry: {df_industry.shape}") + + # Merge + df_merged = merge_data_sources(df_alpha, df_kline, df_flag, df_industry) + print(f"\nAfter merge: {df_merged.shape}") + + # Convert to pandas for easier comparison + df_pandas = df_merged.to_pandas() + df_pandas = df_pandas.set_index(['datetime', 'instrument']) + + print(f"\nAfter converting to pandas MultiIndex: {df_pandas.shape}") + + # Compare column names + with open(GOLD_RAW_PATH, "rb") as f: + gold_raw = pkl.load(f) + + print("\n" + "=" * 80) + print("STEP 5: Compare Column Names (Gold vs Polars)") + print("=" * 80) + + gold_cols = set(str(c) for c in gold_raw.columns) + polars_cols = set(str(c) for c in df_pandas.columns) + + common_cols = gold_cols & polars_cols + only_in_gold = gold_cols - polars_cols + only_in_polars = polars_cols - gold_cols + + print(f"\n Common columns: {len(common_cols)}") + print(f" Only in gold standard: {len(only_in_gold)}") + print(f" Only in polars: {len(only_in_polars)}") + + if only_in_gold: + print(f"\n Columns only in gold standard (first 20):") + for col in list(only_in_gold)[:20]: + print(f" {col}") + + if only_in_polars: + print(f"\n Columns only in polars (first 20):") + for col in list(only_in_polars)[:20]: + print(f" {col}") + + # Check common columns values + print("\n" + "=" * 80) + print("STEP 6: Compare Values for Common Columns") + print("=" * 80) + + # Get common columns as tuples + common_tuples = [] + for gc in gold_raw.columns: + gc_str = str(gc) + for pc in df_pandas.columns: + if str(pc) == gc_str: + common_tuples.append((gc, pc)) + break + + print(f"\nComparing {len(common_tuples)} common columns...") + + # Compare first few columns + matching_count = 0 + diff_count = 0 + for i, (gc, pc) in enumerate(common_tuples[:20]): + gold_vals = gold_raw[gc].dropna().values + polars_vals = df_pandas[pc].dropna().values + + if len(gold_vals) > 0 and len(polars_vals) > 0: + # Compare min, max, mean + if np.allclose([gold_vals.min(), gold_vals.max(), gold_vals.mean()], + [polars_vals.min(), polars_vals.max(), polars_vals.mean()], + rtol=1e-5): + matching_count += 1 + else: + diff_count += 1 + if diff_count <= 3: + print(f" DIFF: {gc}") + print(f" Gold: min={gold_vals.min():.6f}, max={gold_vals.max():.6f}, mean={gold_vals.mean():.6f}") + print(f" Polars: min={polars_vals.min():.6f}, max={polars_vals.max():.6f}, mean={polars_vals.mean():.6f}") + + print(f"\n Matching columns: {matching_count}") + print(f" Different columns: {diff_count}") + + except Exception as e: + print(f"\nError running polars pipeline: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + print("=" * 80) + print("DATA DIVERGENCE DEBUG SCRIPT") + print("Comparing gold-standard qlib output vs polars-based pipeline") + print("=" * 80) + + # Step 1: Check raw data + gold_raw = compare_raw_data() + + # Step 2: Check processed data + gold_proc = compare_processed_data() + + # Step 3: Analyze processor transformations + analyze_processor_pipeline(gold_raw, gold_proc) + + # Step 4 & 5: Run polars pipeline and compare + check_polars_pipeline() + + print("\n" + "=" * 80) + print("DEBUG COMPLETE") + print("=" * 80) diff --git a/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py b/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py new file mode 100644 index 0000000..cc928df --- /dev/null +++ b/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py @@ -0,0 +1,330 @@ +#!/usr/bin/env python +""" +Script to dump raw and processed datasets using the polars-based pipeline. + +This generates: +1. Raw data (before applying processors) - equivalent to qlib's handler output +2. Processed data (after applying all processors) - ready for VAE encoding + +Date range: 2026-02-23 to today (2026-02-27) +""" + +import os +import sys +import pickle as pkl +import numpy as np +import polars as pl +from pathlib import Path +from datetime import datetime + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent)) + +from generate_beta_embedding import ( + load_all_data, + merge_data_sources, + filter_stock_universe, + DiffProcessor, + FlagMarketInjector, + ColumnRemover, + FlagToOnehot, + IndusNtrlInjector, + RobustZScoreNorm, + Fillna, + load_qlib_processor_params, + ALPHA158_COLS, + INDUSTRY_FLAG_COLS, +) + +# Date range +START_DATE = "2026-02-23" +END_DATE = "2026-02-27" + +# Output directory +OUTPUT_DIR = Path(__file__).parent.parent / "data_polars" +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + + +def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame: + """ + Apply the full processor pipeline (equivalent to qlib's proc_list). + + This mimics the qlib proc_list: + 0. Diff: Adds diff features for market_ext columns + 1. FlagMarketInjector: Adds market_0, market_1 + 2. FlagSTInjector: SKIPPED (fails even in gold-standard) + 3. ColumnRemover: Removes log_size_diff, IsN, IsZt, IsDt + 4. FlagToOnehot: Converts 29 industry flags to indus_idx + 5. IndusNtrlInjector: Industry neutralization for feature + 6. IndusNtrlInjector: Industry neutralization for feature_ext + 7. RobustZScoreNorm: Normalization using pre-fitted qlib params + 8. Fillna: Fill NaN with 0 + """ + print("=" * 60) + print("Applying processor pipeline") + print("=" * 60) + + # market_ext columns (4 base) + market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] + + # market_flag columns (12 total before ColumnRemover) + market_flag_cols = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', + 'open_limit', 'close_limit', 'low_limit', + 'open_stop', 'close_stop', 'high_stop'] + + # Step 1: Diff Processor + print("\n[1] Applying Diff processor...") + diff_processor = DiffProcessor(market_ext_base) + df = diff_processor.process(df) + + # After Diff: market_ext has 8 columns (4 base + 4 diff) + market_ext_cols = market_ext_base + [f"{c}_diff" for c in market_ext_base] + + # Step 2: FlagMarketInjector (adds market_0, market_1) + print("[2] Applying FlagMarketInjector...") + flag_injector = FlagMarketInjector() + df = flag_injector.process(df) + + # Add market_0, market_1 to flag list + market_flag_with_market = market_flag_cols + ['market_0', 'market_1'] + + # Step 3: FlagSTInjector - SKIPPED (fails even in gold-standard) + print("[3] Skipping FlagSTInjector (as per gold-standard behavior)...") + market_flag_with_st = market_flag_with_market # No IsST added + + # Step 4: ColumnRemover + print("[4] Applying ColumnRemover...") + columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt'] + remover = ColumnRemover(columns_to_remove) + df = remover.process(df) + + # Update column lists after removal + market_ext_cols = [c for c in market_ext_cols if c not in columns_to_remove] + market_flag_with_st = [c for c in market_flag_with_st if c not in columns_to_remove] + + print(f" Removed columns: {columns_to_remove}") + print(f" Remaining market_ext: {len(market_ext_cols)} columns") + print(f" Remaining market_flag: {len(market_flag_with_st)} columns") + + # Step 5: FlagToOnehot + print("[5] Applying FlagToOnehot...") + flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS) + df = flag_to_onehot.process(df) + + # Step 6 & 7: IndusNtrlInjector + print("[6] Applying IndusNtrlInjector for alpha158...") + alpha158_cols = ALPHA158_COLS.copy() + indus_ntrl_alpha = IndusNtrlInjector(alpha158_cols, suffix='_ntrl') + df = indus_ntrl_alpha.process(df) + + print("[7] Applying IndusNtrlInjector for market_ext...") + indus_ntrl_ext = IndusNtrlInjector(market_ext_cols, suffix='_ntrl') + df = indus_ntrl_ext.process(df) + + # Build column lists for normalization + alpha158_ntrl_cols = [f"{c}_ntrl" for c in alpha158_cols] + market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_cols] + + # Step 8: RobustZScoreNorm + print("[8] Applying RobustZScoreNorm...") + norm_feature_cols = alpha158_ntrl_cols + alpha158_cols + market_ext_ntrl_cols + market_ext_cols + + qlib_params = load_qlib_processor_params() + + # Verify parameter shape + expected_features = len(norm_feature_cols) + if qlib_params['mean_train'].shape[0] != expected_features: + print(f" WARNING: Feature count mismatch! Expected {expected_features}, " + f"got {qlib_params['mean_train'].shape[0]}") + + robust_norm = RobustZScoreNorm( + norm_feature_cols, + clip_range=(-3, 3), + use_qlib_params=True, + qlib_mean=qlib_params['mean_train'], + qlib_std=qlib_params['std_train'] + ) + df = robust_norm.process(df) + + # Step 9: Fillna + print("[9] Applying Fillna...") + final_feature_cols = norm_feature_cols + market_flag_with_st + ['indus_idx'] + fillna = Fillna() + df = fillna.process(df, final_feature_cols) + + print("\n" + "=" * 60) + print("Processor pipeline complete!") + print(f" Normalized features: {len(norm_feature_cols)}") + print(f" Market flags: {len(market_flag_with_st)}") + print(f" Total features (with indus_idx): {len(final_feature_cols)}") + print("=" * 60) + + return df + + +def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame": + """ + Convert polars DataFrame to pandas DataFrame with MultiIndex columns. + This matches the format of qlib's output. + """ + import pandas as pd + + # Convert to pandas + df = df_polars.to_pandas() + + # Check if datetime and instrument are columns + if 'datetime' in df.columns and 'instrument' in df.columns: + # Set MultiIndex + df = df.set_index(['datetime', 'instrument']) + # If they're already not in columns, assume they're already the index + + # Drop raw columns that shouldn't be in processed data + raw_cols_to_drop = ['Turnover', 'FreeTurnover', 'MarketValue'] + existing_raw_cols = [c for c in raw_cols_to_drop if c in df.columns] + if existing_raw_cols: + df = df.drop(columns=existing_raw_cols) + + # Build MultiIndex columns based on column name patterns + columns_with_group = [] + + # Define column sets + alpha158_base = set(ALPHA158_COLS) + market_ext_base = {'turnover', 'free_turnover', 'log_size', 'con_rating_strength'} + market_ext_diff = {'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'} + market_ext_all = market_ext_base | market_ext_diff + feature_flag_cols = {'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', 'open_limit', 'close_limit', 'low_limit', + 'open_stop', 'close_stop', 'high_stop', 'market_0', 'market_1', 'IsST'} + + for col in df.columns: + if col == 'indus_idx': + columns_with_group.append(('indus_idx', col)) + elif col in feature_flag_cols: + columns_with_group.append(('feature_flag', col)) + elif col.endswith('_ntrl'): + base_name = col[:-5] # Remove _ntrl suffix (5 characters) + if base_name in alpha158_base: + columns_with_group.append(('feature', col)) + elif base_name in market_ext_all: + columns_with_group.append(('feature_ext', col)) + else: + columns_with_group.append(('feature', col)) # Default to feature + elif col in alpha158_base: + columns_with_group.append(('feature', col)) + elif col in market_ext_all: + columns_with_group.append(('feature_ext', col)) + elif col in INDUSTRY_FLAG_COLS: + columns_with_group.append(('indus_flag', col)) + elif col in {'ST_S', 'ST_Y', 'ST_T', 'ST_L', 'ST_Z', 'ST_X'}: + columns_with_group.append(('st_flag', col)) + else: + # Unknown column - print warning + print(f" Warning: Unknown column '{col}', assigning to 'other' group") + columns_with_group.append(('other', col)) + + # Create MultiIndex columns + multi_cols = pd.MultiIndex.from_tuples(columns_with_group) + df.columns = multi_cols + + return df + + +def main(): + print("=" * 80) + print("Dumping Polars Dataset") + print("=" * 80) + print(f"Date range: {START_DATE} to {END_DATE}") + print(f"Output directory: {OUTPUT_DIR}") + print() + + # Step 1: Load all data + print("Step 1: Loading data from parquet...") + df_alpha, df_kline, df_flag, df_industry = load_all_data(START_DATE, END_DATE) + print(f" Alpha158 shape: {df_alpha.shape}") + print(f" Kline (market_ext) shape: {df_kline.shape}") + print(f" Flags shape: {df_flag.shape}") + print(f" Industry shape: {df_industry.shape}") + + # Step 2: Merge data sources + print("\nStep 2: Merging data sources...") + df_merged = merge_data_sources(df_alpha, df_kline, df_flag, df_industry) + print(f" Merged shape (after csiallx filter): {df_merged.shape}") + + # Step 3: Save raw data (before processors) + print("\nStep 3: Saving raw data (before processors)...") + + # Keep columns that match qlib's raw output format + # Include datetime and instrument for MultiIndex conversion + raw_columns = ( + ['datetime', 'instrument'] + # Index columns + ALPHA158_COLS + # feature group + ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] + # feature_ext base + ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', # market_flag from kline + 'open_limit', 'close_limit', 'low_limit', + 'open_stop', 'close_stop', 'high_stop'] + + INDUSTRY_FLAG_COLS + # indus_flag + (['ST_S', 'ST_Y'] if 'ST_S' in df_merged.columns else []) # st_flag (if available) + ) + + # Filter to available columns + available_raw_cols = [c for c in raw_columns if c in df_merged.columns] + print(f" Selecting {len(available_raw_cols)} columns for raw data...") + df_raw_polars = df_merged.select(available_raw_cols) + + # Convert to pandas with MultiIndex + df_raw_pd = convert_to_multiindex_df(df_raw_polars) + + raw_output_path = OUTPUT_DIR / f"raw_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl" + with open(raw_output_path, "wb") as f: + pkl.dump(df_raw_pd, f) + print(f" Saved raw data to: {raw_output_path}") + print(f" Raw data shape: {df_raw_pd.shape}") + print(f" Column groups: {df_raw_pd.columns.get_level_values(0).unique().tolist()}") + + # Step 4: Apply processor pipeline + print("\nStep 4: Applying processor pipeline...") + df_processed = apply_processor_pipeline(df_merged) + + # Step 5: Save processed data + print("\nStep 5: Saving processed data (after processors)...") + + # Convert to pandas with MultiIndex + df_processed_pd = convert_to_multiindex_df(df_processed) + + processed_output_path = OUTPUT_DIR / f"processed_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl" + with open(processed_output_path, "wb") as f: + pkl.dump(df_processed_pd, f) + print(f" Saved processed data to: {processed_output_path}") + print(f" Processed data shape: {df_processed_pd.shape}") + print(f" Column groups: {df_processed_pd.columns.get_level_values(0).unique().tolist()}") + + # Count columns per group + print("\n Column counts by group:") + for grp in df_processed_pd.columns.get_level_values(0).unique().tolist(): + count = (df_processed_pd.columns.get_level_values(0) == grp).sum() + print(f" {grp}: {count} columns") + + # Step 6: Verify column counts + print("\n" + "=" * 80) + print("Verification") + print("=" * 80) + + feature_flag_cols = [c[1] for c in df_processed_pd.columns if c[0] == 'feature_flag'] + has_market_0 = 'market_0' in feature_flag_cols + has_market_1 = 'market_1' in feature_flag_cols + + print(f" feature_flag columns: {feature_flag_cols}") + print(f" Has market_0: {has_market_0}") + print(f" Has market_1: {has_market_1}") + + if has_market_0 and has_market_1: + print("\n SUCCESS: market_0 and market_1 columns are present!") + else: + print("\n WARNING: market_0 or market_1 columns are missing!") + + print("\n" + "=" * 80) + print("Dataset dump complete!") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py b/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py index 06239e7..6902d31 100644 --- a/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py +++ b/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py @@ -638,14 +638,24 @@ def apply_feature_pipeline(df: pl.DataFrame) -> Tuple[pl.DataFrame, List[str]]: # After FlagMarketInjector: market_flag = 12 + 2 = 14 columns market_flag_with_market = market_flag_cols + ['market_0', 'market_1'] - # Step 3: FlagSTInjector - create IsST - # Note: Actual ST flags (ST_Y, ST_S, etc.) are not available in the parquet data. - # Since the VAE was trained with IsST, we create a placeholder (all zeros). + # Step 3: FlagSTInjector - create IsST from ST flags + # Note: ST flags (ST_Y, ST_S) may not be available in parquet data. + # If available, IsST = ST_S | ST_Y; otherwise create placeholder (all zeros). # This maintains compatibility with the VAE's expected input dimension. - print("Applying FlagSTInjector (creating IsST placeholder)...") - df = df.with_columns([ - pl.lit(0).cast(pl.Int8).alias('IsST') - ]) + print("Applying FlagSTInjector (creating IsST)...") + # Check if ST flags are available + if 'ST_S' in df.columns or 'st_flag::ST_S' in df.columns: + # Create IsST from actual ST flags + df = df.with_columns([ + ((pl.col('ST_S').cast(pl.Boolean, strict=False) | + pl.col('ST_Y').cast(pl.Boolean, strict=False)) + .cast(pl.Int8).alias('IsST')) + ]) + else: + # Create placeholder (all zeros) if ST flags not available + df = df.with_columns([ + pl.lit(0).cast(pl.Int8).alias('IsST') + ]) market_flag_with_st = market_flag_with_market + ['IsST'] # Step 4: ColumnRemover - remove specific columns @@ -728,16 +738,17 @@ def apply_feature_pipeline(df: pl.DataFrame) -> Tuple[pl.DataFrame, List[str]]: # - From kline_adjusted: IsXD, IsXR, IsDR (3 cols) # - From market_flag: open_limit, close_limit, low_limit, open_stop, close_stop, high_stop (6 cols) # - Added by FlagMarketInjector: market_0, market_1 (2 cols) - # - Total market flags for VAE: 3 + 6 + 2 = 11 (IsST is excluded) + # - Added by FlagSTInjector: IsST (1 col, placeholder if ST flags not available) + # - Total market flags: 3 + 6 + 2 + 1 = 12 (IsST excluded from VAE input) # # Total features: # - norm_feature_cols: 158 + 158 + 7 + 7 = 330 - # - market_flag_with_st: 11 (excluding IsST) + # - market_flag_with_st: 12 (including IsST) # - indus_idx: 1 - # - Total: 330 + 11 + 1 = 342 features (but IsST stays in df for completeness) + # - Total: 330 + 12 + 1 = 343 features # # VAE input dimension (feature + feature_ext + feature_flag only, no indus_idx): - # - 316 (alpha158 + ntrl) + 14 (market_ext + ntrl) + 11 (flags) = 341 ✓ + # - 316 (alpha158 + ntrl) + 14 (market_ext + ntrl) + 11 (flags, excluding IsST) = 341 # Exclude IsST from VAE input features (it's a placeholder) market_flag_for_vae = [c for c in market_flag_with_st if c != 'IsST'] diff --git a/stock_1d/d033/alpha158_beta/src/qlib_loader.py b/stock_1d/d033/alpha158_beta/src/qlib_loader.py index 5909bbe..b27f237 100644 --- a/stock_1d/d033/alpha158_beta/src/qlib_loader.py +++ b/stock_1d/d033/alpha158_beta/src/qlib_loader.py @@ -452,21 +452,56 @@ class FixedFillna: class FixedFlagMarketInjector: - """Fixed FlagMarketInjector that handles :: separator format.""" + """Fixed FlagMarketInjector that handles :: separator format. + + This processor adds market classification columns based on instrument codes: + - market_0: Main board stocks (SH60xxx, SZ00xxx) + - market_1: STAR/ChiNext stocks (SH688xxx, SH689xxx, SZ300xxx, SZ301xxx) + + This matches the original qlib FlagMarketInjector behavior with vocab_size=2. + """ def __init__(self, fields_group, vocab_size=2): self.fields_group = fields_group self.vocab_size = vocab_size def __call__(self, df): - cols = [c for c in df.columns if c.startswith(f"{self.fields_group}::")] - for col in cols: - df[col] = df[col].astype('int8') + import pandas as pd + + # Get instrument codes from index + inst = df.index.get_level_values('instrument') + + # market_0: SH60xxx or SZ00xxx (主板) + market_0 = inst.str.startswith('SH60') | inst.str.startswith('SZ00') + + # market_1: SH688xxx, SH689xxx, SZ300xxx, SZ301xxx (科创/创业板) + market_1 = (inst.str.startswith('SH688') | inst.str.startswith('SH689') | + inst.str.startswith('SZ300') | inst.str.startswith('SZ301')) + + # Add columns to feature_flag group + df[('feature_flag', 'market_0')] = market_0.astype('int8') + df[('feature_flag', 'market_1')] = market_1.astype('int8') + + # Also convert existing feature_flag columns to int8 + # Handle both :: separator string format and MultiIndex tuple format + for col in df.columns: + if isinstance(col, str) and col.startswith(f"{self.fields_group}::"): + df[col] = df[col].astype('int8') + elif isinstance(col, tuple) and col[0] == self.fields_group: + df[col] = df[col].astype('int8') + return df class FixedFlagSTInjector: - """Fixed FlagSTInjector that handles :: separator format.""" + """Fixed FlagSTInjector that handles :: separator format. + + This processor creates the IsST column from ST_S and ST_Y flags: + - IsST = 1 if ST_S or ST_Y is True (stock is ST or *ST) + - IsST = 0 otherwise + + This matches the original qlib FlagSTInjector behavior. + """ def __init__(self, fields_group, st_group="st_flag", col_name="IsST"): self.fields_group = fields_group @@ -474,9 +509,24 @@ class FixedFlagSTInjector: self.col_name = col_name def __call__(self, df): - cols = [c for c in df.columns if c.startswith(f"{self.st_group}::")] - for col in cols: - df[col] = df[col].astype('int8') + import pandas as pd + + # IsST = True if ST_S or ST_Y is True + st_s_col = ('st_flag', 'ST_S') + st_y_col = ('st_flag', 'ST_Y') + + if st_s_col in df.columns and st_y_col in df.columns: + is_st = df[st_s_col].astype(bool) | df[st_y_col].astype(bool) + df[('feature_flag', 'IsST')] = is_st.astype('int8') + + # Also convert existing st_flag columns to int8 + # Handle both :: separator string format and MultiIndex tuple format + for col in df.columns: + if isinstance(col, str) and col.startswith(f"{self.st_group}::"): + df[col] = df[col].astype('int8') + elif isinstance(col, tuple) and col[0] == self.st_group: + df[col] = df[col].astype('int8') + return df