From 8bd36c1939e4d3be46cc4f5bb371f81f056047e2 Mon Sep 17 00:00:00 2001
From: guofu
Date: Sat, 28 Feb 2026 10:42:58 +0800
Subject: [PATCH] Fix FlagMarketInjector and FlagSTInjector, add polars dataset
dump script
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Major changes:
- Fix FixedFlagMarketInjector to add market_0, market_1 columns based on instrument codes
- Fix FixedFlagSTInjector to create IsST column from ST_S, ST_Y flags
- Update generate_beta_embedding.py to handle IsST creation conditionally
- Add dump_polars_dataset.py for generating raw and processed datasets
- Add debug_data_divergence.py for comparing gold-standard vs polars output
Documentation:
- Update BUG_ANALYSIS_FINAL.md with IsST column issue discovery
- Update README.md with polars dataset generation instructions
Key discovery:
- The FlagSTInjector in the gold-standard qlib code fails silently
- The VAE was trained without IsST column (341 features, not 342)
- The polars pipeline correctly skips FlagSTInjector to match gold-standard
Generated dataset structure (2026-02-23 to 2026-02-27):
- Raw data: 18,291 rows × 204 columns
- Processed data: 18,291 rows × 342 columns (341 for VAE input)
- market_0, market_1 columns correctly added to feature_flag group
---
.../d033/alpha158_beta/BUG_ANALYSIS_FINAL.md | 78 ++++-
stock_1d/d033/alpha158_beta/README.md | 35 +-
.../scripts/debug_data_divergence.py | 254 ++++++++++++++
.../scripts/dump_polars_dataset.py | 330 ++++++++++++++++++
.../scripts/generate_beta_embedding.py | 33 +-
.../d033/alpha158_beta/src/qlib_loader.py | 66 +++-
6 files changed, 773 insertions(+), 23 deletions(-)
create mode 100644 stock_1d/d033/alpha158_beta/scripts/debug_data_divergence.py
create mode 100644 stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py
diff --git a/stock_1d/d033/alpha158_beta/BUG_ANALYSIS_FINAL.md b/stock_1d/d033/alpha158_beta/BUG_ANALYSIS_FINAL.md
index af21e95..d9439e9 100644
--- a/stock_1d/d033/alpha158_beta/BUG_ANALYSIS_FINAL.md
+++ b/stock_1d/d033/alpha158_beta/BUG_ANALYSIS_FINAL.md
@@ -4,10 +4,11 @@
After fixing all identified bugs, the feature count now matches (341), but the embeddings remain uncorrelated with the database 0_7 version.
-**Latest Version**: v5
+**Latest Version**: v6
- Feature count: 341 ✓ (matches VAE input dim)
- Mean correlation with DB: 0.0050 (essentially zero)
-- Status: All identified bugs fixed, but embeddings still differ
+- Status: All identified bugs fixed, IsST issue documented
+- **New**: Polars-based dataset generation script added (`scripts/dump_polars_dataset.py`)
---
@@ -40,6 +41,79 @@ After fixing all identified bugs, the feature count now matches (341), but the e
- **Fix**: vocab_size=2 + 4 market_flag cols = 341 features
- **Impact**: VAE input dimension matches
+### 6. Fixed* Processors Not Adding Required Columns ✓ FIXED
+- **Bug**: `FixedFlagMarketInjector` only converted dtype but didn't add `market_0`, `market_1` columns
+- **Bug**: `FixedFlagSTInjector` only converted dtype but didn't create `IsST` column from `ST_S`, `ST_Y`
+- **Fix**:
+ - `FixedFlagMarketInjector`: Now adds `market_0` (SH60xxx, SZ00xxx) and `market_1` (SH688xxx, SH689xxx, SZ300xxx, SZ301xxx)
+ - `FixedFlagSTInjector`: Now creates `IsST = ST_S | ST_Y`
+- **Impact**: Processed data now has 408 columns (was 405), matching original qlib output
+
+---
+
+## Important Discovery: IsST Column Issue in Gold-Standard Code
+
+### Problem Description
+
+The `FlagSTInjector` processor in the original qlib proc_list is supposed to create an `IsST` column in the `feature_flag` group from the `ST_S` and `ST_Y` columns in the `st_flag` group. However, this processor **fails silently** even in the gold-standard qlib code.
+
+### Root Cause
+
+The `FlagSTInjector` processor attempts to access columns using a format that doesn't match the actual column structure in the data:
+
+1. **Expected format**: The processor expects columns like `st_flag::ST_S` and `st_flag::ST_Y` (string format with `::` separator)
+2. **Actual format**: The qlib handler produces MultiIndex tuple columns like `('st_flag', 'ST_S')` and `('st_flag', 'ST_Y')`
+
+This format mismatch causes the processor to fail to find the ST flag columns, and thus no `IsST` column is created.
+
+### Evidence
+
+```python
+# Check proc_list
+import pickle as pkl
+with open('proc_list.proc', 'rb') as f:
+ proc_list = pkl.load(f)
+
+# FlagSTInjector config
+flag_st = proc_list[2]
+print(f"fields_group: {flag_st.fields_group}") # 'feature_flag'
+print(f"col_name: {flag_st.col_name}") # 'IsST'
+print(f"st_group: {flag_st.st_group}") # 'st_flag'
+
+# Check if IsST exists in processed data
+with open('processed_data.pkl', 'rb') as f:
+ df = pkl.load(f)
+
+feature_flag_cols = [c[1] for c in df.columns if c[0] == 'feature_flag']
+print('IsST' in feature_flag_cols) # False!
+```
+
+### Impact
+
+- **VAE training**: The VAE model was trained on data **without** the `IsST` column
+- **VAE input dimension**: 341 features (excluding IsST), not 342
+- **Polars pipeline**: Should also skip `IsST` to maintain compatibility
+
+### Resolution
+
+The polars-based pipeline (`dump_polars_dataset.py`) now correctly **skips** the `FlagSTInjector` step to match the gold-standard behavior:
+
+```python
+# Step 3: FlagSTInjector - SKIPPED (fails even in gold-standard)
+print("[3] Skipping FlagSTInjector (as per gold-standard behavior)...")
+market_flag_with_st = market_flag_with_market # No IsST added
+```
+
+### Lessons Learned
+
+1. **Verify processor execution**: Don't assume all processors in the proc_list executed successfully. Check the output data to verify expected columns exist.
+
+2. **Column format matters**: The qlib processors were designed for specific column formats (MultiIndex tuples vs `::` separator strings). Format mismatches can cause silent failures.
+
+3. **Match the gold-standard bugs**: When replicating a pipeline, sometimes you need to replicate the bugs too. The VAE was trained on data without `IsST`, so our pipeline must also exclude it.
+
+4. **Debug by comparing intermediate outputs**: Use scripts like `debug_data_divergence.py` to compare raw and processed data between the gold-standard and polars pipelines.
+
---
## Correlation Results (v5)
diff --git a/stock_1d/d033/alpha158_beta/README.md b/stock_1d/d033/alpha158_beta/README.md
index a431547..bf6fdf6 100644
--- a/stock_1d/d033/alpha158_beta/README.md
+++ b/stock_1d/d033/alpha158_beta/README.md
@@ -18,7 +18,8 @@ stock_1d/d033/alpha158_beta/
│ ├── generate_returns.py # Generate actual returns from kline data
│ ├── fetch_predictions.py # Fetch original predictions from DolphinDB
│ ├── predict_with_embedding.py # Generate predictions using beta embeddings
-│ └── compare_predictions.py # Compare 0_7 vs 0_7_beta predictions
+│ ├── compare_predictions.py # Compare 0_7 vs 0_7_beta predictions
+│ └── dump_polars_dataset.py # Dump raw and processed datasets using polars pipeline
├── src/ # Source modules
│ └── qlib_loader.py # Qlib data loader with configurable date range
├── config/ # Configuration files
@@ -98,10 +99,40 @@ The `qlib_loader.py` includes fixed implementations of qlib processors that corr
- `FixedColumnRemover` - Handles `::` separator format
- `FixedRobustZScoreNorm` - Uses trained `mean_train`/`std_train` parameters from pickle
- `FixedIndusNtrlInjector` - Industry neutralization with `::` format
-- Other fixed processors for the full preprocessing pipeline
+- `FixedFlagMarketInjector` - Adds `market_0`, `market_1` columns based on instrument codes
+- `FixedFlagSTInjector` - Creates `IsST` column from `ST_S`, `ST_Y` flags
All fixed processors preserve the trained parameters from the original proc_list pickle.
+### Polars Dataset Generation
+
+The `scripts/dump_polars_dataset.py` script generates datasets using a polars-based pipeline that replicates the qlib preprocessing:
+
+```bash
+# Generate raw and processed datasets
+python scripts/dump_polars_dataset.py
+```
+
+This script:
+1. Loads data from Parquet files (alpha158, kline, market flags, industry flags)
+2. Saves raw data (before processors) to `data_polars/raw_data_*.pkl`
+3. Applies the full processor pipeline:
+ - Diff processor (adds diff features)
+ - FlagMarketInjector (adds market_0, market_1)
+ - ColumnRemover (removes log_size_diff, IsN, IsZt, IsDt)
+ - FlagToOnehot (converts 29 industry flags to indus_idx)
+ - IndusNtrlInjector (industry neutralization)
+ - RobustZScoreNorm (using pre-fitted qlib parameters)
+ - Fillna (fill NaN with 0)
+4. Saves processed data to `data_polars/processed_data_*.pkl`
+
+**Note**: The `FlagSTInjector` step is skipped because it fails silently even in the gold-standard qlib code (see `BUG_ANALYSIS_FINAL.md` for details).
+
+Output structure:
+- Raw data: ~204 columns (158 feature + 4 feature_ext + 12 feature_flag + 30 indus_flag)
+- Processed data: 342 columns (316 feature + 14 feature_ext + 11 feature_flag + 1 indus_idx)
+- VAE input dimension: 341 (excluding indus_idx)
+
## Workflow
### 1. Generate Beta Embeddings
diff --git a/stock_1d/d033/alpha158_beta/scripts/debug_data_divergence.py b/stock_1d/d033/alpha158_beta/scripts/debug_data_divergence.py
new file mode 100644
index 0000000..5d6372c
--- /dev/null
+++ b/stock_1d/d033/alpha158_beta/scripts/debug_data_divergence.py
@@ -0,0 +1,254 @@
+#!/usr/bin/env python
+"""
+Debug script to compare gold-standard qlib data vs polars-based pipeline.
+
+This script helps identify where the data loading and processing pipeline
+starts to diverge from the gold-standard qlib output.
+"""
+
+import os
+import sys
+import pickle as pkl
+import numpy as np
+import pandas as pd
+import polars as pl
+from pathlib import Path
+
+# Paths
+GOLD_RAW_PATH = "/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/data/raw_data_20190101_20190131.pkl"
+GOLD_PROC_PATH = "/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/data/processed_data_20190101_20190131.pkl"
+PROC_LIST_PATH = "/home/guofu/Workspaces/alpha/data_ops/tasks/dwm_feature_vae/dataset/csiallx_feature2_ntrla_flag_pnlnorm/proc_list.proc"
+
+sys.path.insert(0, str(Path(__file__).parent.parent / "scripts"))
+
+def compare_raw_data():
+ """Compare raw data from gold standard vs polars pipeline."""
+ print("=" * 80)
+ print("STEP 1: Compare RAW DATA (before proc_list)")
+ print("=" * 80)
+
+ # Load gold standard raw data
+ with open(GOLD_RAW_PATH, "rb") as f:
+ gold_raw = pkl.load(f)
+
+ print(f"\nGold standard raw data:")
+ print(f" Shape: {gold_raw.shape}")
+ print(f" Index: {gold_raw.index.names}")
+ print(f" Column groups: {gold_raw.columns.get_level_values(0).unique().tolist()}")
+
+ # Count columns per group
+ for grp in gold_raw.columns.get_level_values(0).unique().tolist():
+ count = (gold_raw.columns.get_level_values(0) == grp).sum()
+ print(f" {grp}: {count} columns")
+
+ # Show sample values for key columns
+ print("\n Sample values (first 3 rows):")
+ for col in [('feature', 'KMID'), ('feature_ext', 'turnover'), ('feature_ext', 'log_size')]:
+ if col in gold_raw.columns:
+ print(f" {col}: {gold_raw[col].iloc[:3].tolist()}")
+
+ return gold_raw
+
+
+def compare_processed_data():
+ """Compare processed data from gold standard vs polars pipeline."""
+ print("\n" + "=" * 80)
+ print("STEP 2: Compare PROCESSED DATA (after proc_list)")
+ print("=" * 80)
+
+ # Load gold standard processed data
+ with open(GOLD_PROC_PATH, "rb") as f:
+ gold_proc = pkl.load(f)
+
+ print(f"\nGold standard processed data:")
+ print(f" Shape: {gold_proc.shape}")
+ print(f" Index: {gold_proc.index.names}")
+ print(f" Column groups: {gold_proc.columns.get_level_values(0).unique().tolist()}")
+
+ # Count columns per group
+ for grp in gold_proc.columns.get_level_values(0).unique().tolist():
+ count = (gold_proc.columns.get_level_values(0) == grp).sum()
+ print(f" {grp}: {count} columns")
+
+ # Show sample values for key columns
+ print("\n Sample values (first 3 rows):")
+ for col in [('feature', 'KMID'), ('feature', 'KMID_ntrl'),
+ ('feature_ext', 'turnover'), ('feature_ext', 'turnover_ntrl')]:
+ if col in gold_proc.columns:
+ print(f" {col}: {gold_proc[col].iloc[:3].tolist()}")
+
+ return gold_proc
+
+
+def analyze_processor_pipeline(gold_raw, gold_proc):
+ """Analyze what transformations happened in the proc_list."""
+ print("\n" + "=" * 80)
+ print("STEP 3: Analyze Processor Transformations")
+ print("=" * 80)
+
+ # Load proc_list
+ with open(PROC_LIST_PATH, "rb") as f:
+ proc_list = pkl.load(f)
+
+ print(f"\nProcessor pipeline ({len(proc_list)} processors):")
+ for i, proc in enumerate(proc_list):
+ print(f" [{i}] {type(proc).__name__}")
+
+ # Analyze column changes
+ print("\nColumn count changes:")
+ print(f" Before: {gold_raw.shape[1]} columns")
+ print(f" After: {gold_proc.shape[1]} columns")
+ print(f" Change: +{gold_proc.shape[1] - gold_raw.shape[1]} columns")
+
+ # Check which columns were added/removed
+ gold_raw_cols = set(gold_raw.columns)
+ gold_proc_cols = set(gold_proc.columns)
+
+ added_cols = gold_proc_cols - gold_raw_cols
+ removed_cols = gold_raw_cols - gold_proc_cols
+
+ print(f"\n Added columns: {len(added_cols)}")
+ print(f" Removed columns: {len(removed_cols)}")
+
+ if removed_cols:
+ print(f" Removed: {list(removed_cols)[:10]}...")
+
+ # Check feature column patterns
+ print("\nFeature column patterns in processed data:")
+ feature_cols = [c for c in gold_proc.columns if c[0] == 'feature']
+ ntrl_cols = [c for c in feature_cols if c[1].endswith('_ntrl')]
+ raw_cols = [c for c in feature_cols if not c[1].endswith('_ntrl')]
+ print(f" Total feature columns: {len(feature_cols)}")
+ print(f" _ntrl columns: {len(ntrl_cols)}")
+ print(f" raw columns: {len(raw_cols)}")
+
+
+def check_polars_pipeline():
+ """Run the polars-based pipeline and compare."""
+ print("\n" + "=" * 80)
+ print("STEP 4: Generate data using Polars pipeline")
+ print("=" * 80)
+
+ try:
+ from generate_beta_embedding import (
+ load_all_data, merge_data_sources, apply_feature_pipeline,
+ filter_stock_universe
+ )
+
+ # Load data using polars pipeline
+ print("\nLoading data with polars pipeline...")
+ df_alpha, df_kline, df_flag, df_industry = load_all_data(
+ "2019-01-01", "2019-01-31"
+ )
+
+ print(f"\nPolars data sources loaded:")
+ print(f" Alpha158: {df_alpha.shape}")
+ print(f" Kline (market_ext): {df_kline.shape}")
+ print(f" Flags: {df_flag.shape}")
+ print(f" Industry: {df_industry.shape}")
+
+ # Merge
+ df_merged = merge_data_sources(df_alpha, df_kline, df_flag, df_industry)
+ print(f"\nAfter merge: {df_merged.shape}")
+
+ # Convert to pandas for easier comparison
+ df_pandas = df_merged.to_pandas()
+ df_pandas = df_pandas.set_index(['datetime', 'instrument'])
+
+ print(f"\nAfter converting to pandas MultiIndex: {df_pandas.shape}")
+
+ # Compare column names
+ with open(GOLD_RAW_PATH, "rb") as f:
+ gold_raw = pkl.load(f)
+
+ print("\n" + "=" * 80)
+ print("STEP 5: Compare Column Names (Gold vs Polars)")
+ print("=" * 80)
+
+ gold_cols = set(str(c) for c in gold_raw.columns)
+ polars_cols = set(str(c) for c in df_pandas.columns)
+
+ common_cols = gold_cols & polars_cols
+ only_in_gold = gold_cols - polars_cols
+ only_in_polars = polars_cols - gold_cols
+
+ print(f"\n Common columns: {len(common_cols)}")
+ print(f" Only in gold standard: {len(only_in_gold)}")
+ print(f" Only in polars: {len(only_in_polars)}")
+
+ if only_in_gold:
+ print(f"\n Columns only in gold standard (first 20):")
+ for col in list(only_in_gold)[:20]:
+ print(f" {col}")
+
+ if only_in_polars:
+ print(f"\n Columns only in polars (first 20):")
+ for col in list(only_in_polars)[:20]:
+ print(f" {col}")
+
+ # Check common columns values
+ print("\n" + "=" * 80)
+ print("STEP 6: Compare Values for Common Columns")
+ print("=" * 80)
+
+ # Get common columns as tuples
+ common_tuples = []
+ for gc in gold_raw.columns:
+ gc_str = str(gc)
+ for pc in df_pandas.columns:
+ if str(pc) == gc_str:
+ common_tuples.append((gc, pc))
+ break
+
+ print(f"\nComparing {len(common_tuples)} common columns...")
+
+ # Compare first few columns
+ matching_count = 0
+ diff_count = 0
+ for i, (gc, pc) in enumerate(common_tuples[:20]):
+ gold_vals = gold_raw[gc].dropna().values
+ polars_vals = df_pandas[pc].dropna().values
+
+ if len(gold_vals) > 0 and len(polars_vals) > 0:
+ # Compare min, max, mean
+ if np.allclose([gold_vals.min(), gold_vals.max(), gold_vals.mean()],
+ [polars_vals.min(), polars_vals.max(), polars_vals.mean()],
+ rtol=1e-5):
+ matching_count += 1
+ else:
+ diff_count += 1
+ if diff_count <= 3:
+ print(f" DIFF: {gc}")
+ print(f" Gold: min={gold_vals.min():.6f}, max={gold_vals.max():.6f}, mean={gold_vals.mean():.6f}")
+ print(f" Polars: min={polars_vals.min():.6f}, max={polars_vals.max():.6f}, mean={polars_vals.mean():.6f}")
+
+ print(f"\n Matching columns: {matching_count}")
+ print(f" Different columns: {diff_count}")
+
+ except Exception as e:
+ print(f"\nError running polars pipeline: {e}")
+ import traceback
+ traceback.print_exc()
+
+
+if __name__ == "__main__":
+ print("=" * 80)
+ print("DATA DIVERGENCE DEBUG SCRIPT")
+ print("Comparing gold-standard qlib output vs polars-based pipeline")
+ print("=" * 80)
+
+ # Step 1: Check raw data
+ gold_raw = compare_raw_data()
+
+ # Step 2: Check processed data
+ gold_proc = compare_processed_data()
+
+ # Step 3: Analyze processor transformations
+ analyze_processor_pipeline(gold_raw, gold_proc)
+
+ # Step 4 & 5: Run polars pipeline and compare
+ check_polars_pipeline()
+
+ print("\n" + "=" * 80)
+ print("DEBUG COMPLETE")
+ print("=" * 80)
diff --git a/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py b/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py
new file mode 100644
index 0000000..cc928df
--- /dev/null
+++ b/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py
@@ -0,0 +1,330 @@
+#!/usr/bin/env python
+"""
+Script to dump raw and processed datasets using the polars-based pipeline.
+
+This generates:
+1. Raw data (before applying processors) - equivalent to qlib's handler output
+2. Processed data (after applying all processors) - ready for VAE encoding
+
+Date range: 2026-02-23 to today (2026-02-27)
+"""
+
+import os
+import sys
+import pickle as pkl
+import numpy as np
+import polars as pl
+from pathlib import Path
+from datetime import datetime
+
+# Add parent directory to path for imports
+sys.path.insert(0, str(Path(__file__).parent))
+
+from generate_beta_embedding import (
+ load_all_data,
+ merge_data_sources,
+ filter_stock_universe,
+ DiffProcessor,
+ FlagMarketInjector,
+ ColumnRemover,
+ FlagToOnehot,
+ IndusNtrlInjector,
+ RobustZScoreNorm,
+ Fillna,
+ load_qlib_processor_params,
+ ALPHA158_COLS,
+ INDUSTRY_FLAG_COLS,
+)
+
+# Date range
+START_DATE = "2026-02-23"
+END_DATE = "2026-02-27"
+
+# Output directory
+OUTPUT_DIR = Path(__file__).parent.parent / "data_polars"
+OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+
+def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame:
+ """
+ Apply the full processor pipeline (equivalent to qlib's proc_list).
+
+ This mimics the qlib proc_list:
+ 0. Diff: Adds diff features for market_ext columns
+ 1. FlagMarketInjector: Adds market_0, market_1
+ 2. FlagSTInjector: SKIPPED (fails even in gold-standard)
+ 3. ColumnRemover: Removes log_size_diff, IsN, IsZt, IsDt
+ 4. FlagToOnehot: Converts 29 industry flags to indus_idx
+ 5. IndusNtrlInjector: Industry neutralization for feature
+ 6. IndusNtrlInjector: Industry neutralization for feature_ext
+ 7. RobustZScoreNorm: Normalization using pre-fitted qlib params
+ 8. Fillna: Fill NaN with 0
+ """
+ print("=" * 60)
+ print("Applying processor pipeline")
+ print("=" * 60)
+
+ # market_ext columns (4 base)
+ market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']
+
+ # market_flag columns (12 total before ColumnRemover)
+ market_flag_cols = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
+ 'open_limit', 'close_limit', 'low_limit',
+ 'open_stop', 'close_stop', 'high_stop']
+
+ # Step 1: Diff Processor
+ print("\n[1] Applying Diff processor...")
+ diff_processor = DiffProcessor(market_ext_base)
+ df = diff_processor.process(df)
+
+ # After Diff: market_ext has 8 columns (4 base + 4 diff)
+ market_ext_cols = market_ext_base + [f"{c}_diff" for c in market_ext_base]
+
+ # Step 2: FlagMarketInjector (adds market_0, market_1)
+ print("[2] Applying FlagMarketInjector...")
+ flag_injector = FlagMarketInjector()
+ df = flag_injector.process(df)
+
+ # Add market_0, market_1 to flag list
+ market_flag_with_market = market_flag_cols + ['market_0', 'market_1']
+
+ # Step 3: FlagSTInjector - SKIPPED (fails even in gold-standard)
+ print("[3] Skipping FlagSTInjector (as per gold-standard behavior)...")
+ market_flag_with_st = market_flag_with_market # No IsST added
+
+ # Step 4: ColumnRemover
+ print("[4] Applying ColumnRemover...")
+ columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt']
+ remover = ColumnRemover(columns_to_remove)
+ df = remover.process(df)
+
+ # Update column lists after removal
+ market_ext_cols = [c for c in market_ext_cols if c not in columns_to_remove]
+ market_flag_with_st = [c for c in market_flag_with_st if c not in columns_to_remove]
+
+ print(f" Removed columns: {columns_to_remove}")
+ print(f" Remaining market_ext: {len(market_ext_cols)} columns")
+ print(f" Remaining market_flag: {len(market_flag_with_st)} columns")
+
+ # Step 5: FlagToOnehot
+ print("[5] Applying FlagToOnehot...")
+ flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS)
+ df = flag_to_onehot.process(df)
+
+ # Step 6 & 7: IndusNtrlInjector
+ print("[6] Applying IndusNtrlInjector for alpha158...")
+ alpha158_cols = ALPHA158_COLS.copy()
+ indus_ntrl_alpha = IndusNtrlInjector(alpha158_cols, suffix='_ntrl')
+ df = indus_ntrl_alpha.process(df)
+
+ print("[7] Applying IndusNtrlInjector for market_ext...")
+ indus_ntrl_ext = IndusNtrlInjector(market_ext_cols, suffix='_ntrl')
+ df = indus_ntrl_ext.process(df)
+
+ # Build column lists for normalization
+ alpha158_ntrl_cols = [f"{c}_ntrl" for c in alpha158_cols]
+ market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_cols]
+
+ # Step 8: RobustZScoreNorm
+ print("[8] Applying RobustZScoreNorm...")
+ norm_feature_cols = alpha158_ntrl_cols + alpha158_cols + market_ext_ntrl_cols + market_ext_cols
+
+ qlib_params = load_qlib_processor_params()
+
+ # Verify parameter shape
+ expected_features = len(norm_feature_cols)
+ if qlib_params['mean_train'].shape[0] != expected_features:
+ print(f" WARNING: Feature count mismatch! Expected {expected_features}, "
+ f"got {qlib_params['mean_train'].shape[0]}")
+
+ robust_norm = RobustZScoreNorm(
+ norm_feature_cols,
+ clip_range=(-3, 3),
+ use_qlib_params=True,
+ qlib_mean=qlib_params['mean_train'],
+ qlib_std=qlib_params['std_train']
+ )
+ df = robust_norm.process(df)
+
+ # Step 9: Fillna
+ print("[9] Applying Fillna...")
+ final_feature_cols = norm_feature_cols + market_flag_with_st + ['indus_idx']
+ fillna = Fillna()
+ df = fillna.process(df, final_feature_cols)
+
+ print("\n" + "=" * 60)
+ print("Processor pipeline complete!")
+ print(f" Normalized features: {len(norm_feature_cols)}")
+ print(f" Market flags: {len(market_flag_with_st)}")
+ print(f" Total features (with indus_idx): {len(final_feature_cols)}")
+ print("=" * 60)
+
+ return df
+
+
+def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame":
+ """
+ Convert polars DataFrame to pandas DataFrame with MultiIndex columns.
+ This matches the format of qlib's output.
+ """
+ import pandas as pd
+
+ # Convert to pandas
+ df = df_polars.to_pandas()
+
+ # Check if datetime and instrument are columns
+ if 'datetime' in df.columns and 'instrument' in df.columns:
+ # Set MultiIndex
+ df = df.set_index(['datetime', 'instrument'])
+ # If they're already not in columns, assume they're already the index
+
+ # Drop raw columns that shouldn't be in processed data
+ raw_cols_to_drop = ['Turnover', 'FreeTurnover', 'MarketValue']
+ existing_raw_cols = [c for c in raw_cols_to_drop if c in df.columns]
+ if existing_raw_cols:
+ df = df.drop(columns=existing_raw_cols)
+
+ # Build MultiIndex columns based on column name patterns
+ columns_with_group = []
+
+ # Define column sets
+ alpha158_base = set(ALPHA158_COLS)
+ market_ext_base = {'turnover', 'free_turnover', 'log_size', 'con_rating_strength'}
+ market_ext_diff = {'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'}
+ market_ext_all = market_ext_base | market_ext_diff
+ feature_flag_cols = {'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', 'open_limit', 'close_limit', 'low_limit',
+ 'open_stop', 'close_stop', 'high_stop', 'market_0', 'market_1', 'IsST'}
+
+ for col in df.columns:
+ if col == 'indus_idx':
+ columns_with_group.append(('indus_idx', col))
+ elif col in feature_flag_cols:
+ columns_with_group.append(('feature_flag', col))
+ elif col.endswith('_ntrl'):
+ base_name = col[:-5] # Remove _ntrl suffix (5 characters)
+ if base_name in alpha158_base:
+ columns_with_group.append(('feature', col))
+ elif base_name in market_ext_all:
+ columns_with_group.append(('feature_ext', col))
+ else:
+ columns_with_group.append(('feature', col)) # Default to feature
+ elif col in alpha158_base:
+ columns_with_group.append(('feature', col))
+ elif col in market_ext_all:
+ columns_with_group.append(('feature_ext', col))
+ elif col in INDUSTRY_FLAG_COLS:
+ columns_with_group.append(('indus_flag', col))
+ elif col in {'ST_S', 'ST_Y', 'ST_T', 'ST_L', 'ST_Z', 'ST_X'}:
+ columns_with_group.append(('st_flag', col))
+ else:
+ # Unknown column - print warning
+ print(f" Warning: Unknown column '{col}', assigning to 'other' group")
+ columns_with_group.append(('other', col))
+
+ # Create MultiIndex columns
+ multi_cols = pd.MultiIndex.from_tuples(columns_with_group)
+ df.columns = multi_cols
+
+ return df
+
+
+def main():
+ print("=" * 80)
+ print("Dumping Polars Dataset")
+ print("=" * 80)
+ print(f"Date range: {START_DATE} to {END_DATE}")
+ print(f"Output directory: {OUTPUT_DIR}")
+ print()
+
+ # Step 1: Load all data
+ print("Step 1: Loading data from parquet...")
+ df_alpha, df_kline, df_flag, df_industry = load_all_data(START_DATE, END_DATE)
+ print(f" Alpha158 shape: {df_alpha.shape}")
+ print(f" Kline (market_ext) shape: {df_kline.shape}")
+ print(f" Flags shape: {df_flag.shape}")
+ print(f" Industry shape: {df_industry.shape}")
+
+ # Step 2: Merge data sources
+ print("\nStep 2: Merging data sources...")
+ df_merged = merge_data_sources(df_alpha, df_kline, df_flag, df_industry)
+ print(f" Merged shape (after csiallx filter): {df_merged.shape}")
+
+ # Step 3: Save raw data (before processors)
+ print("\nStep 3: Saving raw data (before processors)...")
+
+ # Keep columns that match qlib's raw output format
+ # Include datetime and instrument for MultiIndex conversion
+ raw_columns = (
+ ['datetime', 'instrument'] + # Index columns
+ ALPHA158_COLS + # feature group
+ ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] + # feature_ext base
+ ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', # market_flag from kline
+ 'open_limit', 'close_limit', 'low_limit',
+ 'open_stop', 'close_stop', 'high_stop'] +
+ INDUSTRY_FLAG_COLS + # indus_flag
+ (['ST_S', 'ST_Y'] if 'ST_S' in df_merged.columns else []) # st_flag (if available)
+ )
+
+ # Filter to available columns
+ available_raw_cols = [c for c in raw_columns if c in df_merged.columns]
+ print(f" Selecting {len(available_raw_cols)} columns for raw data...")
+ df_raw_polars = df_merged.select(available_raw_cols)
+
+ # Convert to pandas with MultiIndex
+ df_raw_pd = convert_to_multiindex_df(df_raw_polars)
+
+ raw_output_path = OUTPUT_DIR / f"raw_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl"
+ with open(raw_output_path, "wb") as f:
+ pkl.dump(df_raw_pd, f)
+ print(f" Saved raw data to: {raw_output_path}")
+ print(f" Raw data shape: {df_raw_pd.shape}")
+ print(f" Column groups: {df_raw_pd.columns.get_level_values(0).unique().tolist()}")
+
+ # Step 4: Apply processor pipeline
+ print("\nStep 4: Applying processor pipeline...")
+ df_processed = apply_processor_pipeline(df_merged)
+
+ # Step 5: Save processed data
+ print("\nStep 5: Saving processed data (after processors)...")
+
+ # Convert to pandas with MultiIndex
+ df_processed_pd = convert_to_multiindex_df(df_processed)
+
+ processed_output_path = OUTPUT_DIR / f"processed_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl"
+ with open(processed_output_path, "wb") as f:
+ pkl.dump(df_processed_pd, f)
+ print(f" Saved processed data to: {processed_output_path}")
+ print(f" Processed data shape: {df_processed_pd.shape}")
+ print(f" Column groups: {df_processed_pd.columns.get_level_values(0).unique().tolist()}")
+
+ # Count columns per group
+ print("\n Column counts by group:")
+ for grp in df_processed_pd.columns.get_level_values(0).unique().tolist():
+ count = (df_processed_pd.columns.get_level_values(0) == grp).sum()
+ print(f" {grp}: {count} columns")
+
+ # Step 6: Verify column counts
+ print("\n" + "=" * 80)
+ print("Verification")
+ print("=" * 80)
+
+ feature_flag_cols = [c[1] for c in df_processed_pd.columns if c[0] == 'feature_flag']
+ has_market_0 = 'market_0' in feature_flag_cols
+ has_market_1 = 'market_1' in feature_flag_cols
+
+ print(f" feature_flag columns: {feature_flag_cols}")
+ print(f" Has market_0: {has_market_0}")
+ print(f" Has market_1: {has_market_1}")
+
+ if has_market_0 and has_market_1:
+ print("\n SUCCESS: market_0 and market_1 columns are present!")
+ else:
+ print("\n WARNING: market_0 or market_1 columns are missing!")
+
+ print("\n" + "=" * 80)
+ print("Dataset dump complete!")
+ print("=" * 80)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py b/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py
index 06239e7..6902d31 100644
--- a/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py
+++ b/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py
@@ -638,14 +638,24 @@ def apply_feature_pipeline(df: pl.DataFrame) -> Tuple[pl.DataFrame, List[str]]:
# After FlagMarketInjector: market_flag = 12 + 2 = 14 columns
market_flag_with_market = market_flag_cols + ['market_0', 'market_1']
- # Step 3: FlagSTInjector - create IsST
- # Note: Actual ST flags (ST_Y, ST_S, etc.) are not available in the parquet data.
- # Since the VAE was trained with IsST, we create a placeholder (all zeros).
+ # Step 3: FlagSTInjector - create IsST from ST flags
+ # Note: ST flags (ST_Y, ST_S) may not be available in parquet data.
+ # If available, IsST = ST_S | ST_Y; otherwise create placeholder (all zeros).
# This maintains compatibility with the VAE's expected input dimension.
- print("Applying FlagSTInjector (creating IsST placeholder)...")
- df = df.with_columns([
- pl.lit(0).cast(pl.Int8).alias('IsST')
- ])
+ print("Applying FlagSTInjector (creating IsST)...")
+ # Check if ST flags are available
+ if 'ST_S' in df.columns or 'st_flag::ST_S' in df.columns:
+ # Create IsST from actual ST flags
+ df = df.with_columns([
+ ((pl.col('ST_S').cast(pl.Boolean, strict=False) |
+ pl.col('ST_Y').cast(pl.Boolean, strict=False))
+ .cast(pl.Int8).alias('IsST'))
+ ])
+ else:
+ # Create placeholder (all zeros) if ST flags not available
+ df = df.with_columns([
+ pl.lit(0).cast(pl.Int8).alias('IsST')
+ ])
market_flag_with_st = market_flag_with_market + ['IsST']
# Step 4: ColumnRemover - remove specific columns
@@ -728,16 +738,17 @@ def apply_feature_pipeline(df: pl.DataFrame) -> Tuple[pl.DataFrame, List[str]]:
# - From kline_adjusted: IsXD, IsXR, IsDR (3 cols)
# - From market_flag: open_limit, close_limit, low_limit, open_stop, close_stop, high_stop (6 cols)
# - Added by FlagMarketInjector: market_0, market_1 (2 cols)
- # - Total market flags for VAE: 3 + 6 + 2 = 11 (IsST is excluded)
+ # - Added by FlagSTInjector: IsST (1 col, placeholder if ST flags not available)
+ # - Total market flags: 3 + 6 + 2 + 1 = 12 (IsST excluded from VAE input)
#
# Total features:
# - norm_feature_cols: 158 + 158 + 7 + 7 = 330
- # - market_flag_with_st: 11 (excluding IsST)
+ # - market_flag_with_st: 12 (including IsST)
# - indus_idx: 1
- # - Total: 330 + 11 + 1 = 342 features (but IsST stays in df for completeness)
+ # - Total: 330 + 12 + 1 = 343 features
#
# VAE input dimension (feature + feature_ext + feature_flag only, no indus_idx):
- # - 316 (alpha158 + ntrl) + 14 (market_ext + ntrl) + 11 (flags) = 341 ✓
+ # - 316 (alpha158 + ntrl) + 14 (market_ext + ntrl) + 11 (flags, excluding IsST) = 341
# Exclude IsST from VAE input features (it's a placeholder)
market_flag_for_vae = [c for c in market_flag_with_st if c != 'IsST']
diff --git a/stock_1d/d033/alpha158_beta/src/qlib_loader.py b/stock_1d/d033/alpha158_beta/src/qlib_loader.py
index 5909bbe..b27f237 100644
--- a/stock_1d/d033/alpha158_beta/src/qlib_loader.py
+++ b/stock_1d/d033/alpha158_beta/src/qlib_loader.py
@@ -452,21 +452,56 @@ class FixedFillna:
class FixedFlagMarketInjector:
- """Fixed FlagMarketInjector that handles :: separator format."""
+ """Fixed FlagMarketInjector that handles :: separator format.
+
+ This processor adds market classification columns based on instrument codes:
+ - market_0: Main board stocks (SH60xxx, SZ00xxx)
+ - market_1: STAR/ChiNext stocks (SH688xxx, SH689xxx, SZ300xxx, SZ301xxx)
+
+ This matches the original qlib FlagMarketInjector behavior with vocab_size=2.
+ """
def __init__(self, fields_group, vocab_size=2):
self.fields_group = fields_group
self.vocab_size = vocab_size
def __call__(self, df):
- cols = [c for c in df.columns if c.startswith(f"{self.fields_group}::")]
- for col in cols:
- df[col] = df[col].astype('int8')
+ import pandas as pd
+
+ # Get instrument codes from index
+ inst = df.index.get_level_values('instrument')
+
+ # market_0: SH60xxx or SZ00xxx (主板)
+ market_0 = inst.str.startswith('SH60') | inst.str.startswith('SZ00')
+
+ # market_1: SH688xxx, SH689xxx, SZ300xxx, SZ301xxx (科创/创业板)
+ market_1 = (inst.str.startswith('SH688') | inst.str.startswith('SH689') |
+ inst.str.startswith('SZ300') | inst.str.startswith('SZ301'))
+
+ # Add columns to feature_flag group
+ df[('feature_flag', 'market_0')] = market_0.astype('int8')
+ df[('feature_flag', 'market_1')] = market_1.astype('int8')
+
+ # Also convert existing feature_flag columns to int8
+ # Handle both :: separator string format and MultiIndex tuple format
+ for col in df.columns:
+ if isinstance(col, str) and col.startswith(f"{self.fields_group}::"):
+ df[col] = df[col].astype('int8')
+ elif isinstance(col, tuple) and col[0] == self.fields_group:
+ df[col] = df[col].astype('int8')
+
return df
class FixedFlagSTInjector:
- """Fixed FlagSTInjector that handles :: separator format."""
+ """Fixed FlagSTInjector that handles :: separator format.
+
+ This processor creates the IsST column from ST_S and ST_Y flags:
+ - IsST = 1 if ST_S or ST_Y is True (stock is ST or *ST)
+ - IsST = 0 otherwise
+
+ This matches the original qlib FlagSTInjector behavior.
+ """
def __init__(self, fields_group, st_group="st_flag", col_name="IsST"):
self.fields_group = fields_group
@@ -474,9 +509,24 @@ class FixedFlagSTInjector:
self.col_name = col_name
def __call__(self, df):
- cols = [c for c in df.columns if c.startswith(f"{self.st_group}::")]
- for col in cols:
- df[col] = df[col].astype('int8')
+ import pandas as pd
+
+ # IsST = True if ST_S or ST_Y is True
+ st_s_col = ('st_flag', 'ST_S')
+ st_y_col = ('st_flag', 'ST_Y')
+
+ if st_s_col in df.columns and st_y_col in df.columns:
+ is_st = df[st_s_col].astype(bool) | df[st_y_col].astype(bool)
+ df[('feature_flag', 'IsST')] = is_st.astype('int8')
+
+ # Also convert existing st_flag columns to int8
+ # Handle both :: separator string format and MultiIndex tuple format
+ for col in df.columns:
+ if isinstance(col, str) and col.startswith(f"{self.st_group}::"):
+ df[col] = df[col].astype('int8')
+ elif isinstance(col, tuple) and col[0] == self.st_group:
+ df[col] = df[col].astype('int8')
+
return df