Fix FlagMarketInjector and FlagSTInjector, add polars dataset dump script

Major changes:
- Fix FixedFlagMarketInjector to add market_0, market_1 columns based on instrument codes
- Fix FixedFlagSTInjector to create IsST column from ST_S, ST_Y flags
- Update generate_beta_embedding.py to handle IsST creation conditionally
- Add dump_polars_dataset.py for generating raw and processed datasets
- Add debug_data_divergence.py for comparing gold-standard vs polars output

Documentation:
- Update BUG_ANALYSIS_FINAL.md with IsST column issue discovery
- Update README.md with polars dataset generation instructions

Key discovery:
- The FlagSTInjector in the gold-standard qlib code fails silently
- The VAE was trained without IsST column (341 features, not 342)
- The polars pipeline correctly skips FlagSTInjector to match gold-standard

Generated dataset structure (2026-02-23 to 2026-02-27):
- Raw data: 18,291 rows × 204 columns
- Processed data: 18,291 rows × 342 columns (341 for VAE input)
- market_0, market_1 columns correctly added to feature_flag group
guofu 4 days ago
parent 4d382dc6bd
commit 8bd36c1939

@@ -4,10 +4,11 @@
After fixing all identified bugs, the feature count now matches (341), but the embeddings remain uncorrelated with the database 0_7 version.
**Latest Version**: v5
**Latest Version**: v6
- Feature count: 341 ✓ (matches VAE input dim)
- Mean correlation with DB: 0.0050 (essentially zero)
- Status: All identified bugs fixed, but embeddings still differ
- Status: All identified bugs fixed, IsST issue documented
- **New**: Polars-based dataset generation script added (`scripts/dump_polars_dataset.py`)
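For reference, a mean-correlation figure like the 0.0050 above can be computed as the average per-dimension Pearson correlation between two embedding matrices. A minimal sketch (function name and shapes are illustrative, not from the repo):

```python
import numpy as np

def mean_embedding_correlation(emb_a, emb_b):
    """Mean per-dimension Pearson correlation between two (n_samples, n_dims) matrices."""
    corrs = []
    for j in range(emb_a.shape[1]):
        a, b = emb_a[:, j], emb_b[:, j]
        if a.std() == 0 or b.std() == 0:
            continue  # constant dimensions carry no correlation signal
        corrs.append(np.corrcoef(a, b)[0, 1])
    return float(np.mean(corrs))

rng = np.random.default_rng(0)
x = rng.normal(size=(500, 8))
print(round(mean_embedding_correlation(x, x), 4))  # identical embeddings -> 1.0
# unrelated embeddings correlate near zero, as observed between v6 and the DB version
print(round(mean_embedding_correlation(x, rng.normal(size=(500, 8))), 4))
```

A value near zero, as seen here, means the pipelines differ in substance, not just in scale.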
---
@@ -40,6 +41,79 @@ After fixing all identified bugs, the feature count now matches (341), but the e
- **Fix**: vocab_size=2 + 4 market_flag cols = 341 features
- **Impact**: VAE input dimension matches
### 6. Fixed* Processors Not Adding Required Columns ✓ FIXED
- **Bug**: `FixedFlagMarketInjector` only converted dtype but didn't add `market_0`, `market_1` columns
- **Bug**: `FixedFlagSTInjector` only converted dtype but didn't create `IsST` column from `ST_S`, `ST_Y`
- **Fix**:
- `FixedFlagMarketInjector`: Now adds `market_0` (SH60xxx, SZ00xxx) and `market_1` (SH688xxx, SH689xxx, SZ300xxx, SZ301xxx)
- `FixedFlagSTInjector`: Now creates `IsST = ST_S | ST_Y`
- **Impact**: Processed data now has 408 columns (was 405), matching original qlib output
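The market classification rule behind `market_0`/`market_1` can be sketched in plain Python (prefix rules taken from the fix above; `market_flags` is a hypothetical helper, not a repo function):

```python
def market_flags(instrument):
    """Classify an instrument code into (market_0, market_1) flags.

    market_0: main-board codes (SH60xxx, SZ00xxx)
    market_1: STAR/ChiNext codes (SH688xxx, SH689xxx, SZ300xxx, SZ301xxx)
    """
    market_0 = int(instrument.startswith(("SH60", "SZ00")))
    market_1 = int(instrument.startswith(("SH688", "SH689", "SZ300", "SZ301")))
    return market_0, market_1

for code in ["SH600000", "SZ000001", "SH688001", "SZ300750", "SH900901"]:
    print(code, market_flags(code))
# SH600000 (1, 0), SZ000001 (1, 0), SH688001 (0, 1), SZ300750 (0, 1), SH900901 (0, 0)
```

Codes outside both prefix sets (e.g. B-shares) get `(0, 0)` in both columns.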
---
## Important Discovery: IsST Column Issue in Gold-Standard Code
### Problem Description
The `FlagSTInjector` processor in the original qlib proc_list is supposed to create an `IsST` column in the `feature_flag` group from the `ST_S` and `ST_Y` columns in the `st_flag` group. However, this processor **fails silently** even in the gold-standard qlib code.
### Root Cause
The `FlagSTInjector` processor attempts to access columns using a format that doesn't match the actual column structure in the data:
1. **Expected format**: The processor expects columns like `st_flag::ST_S` and `st_flag::ST_Y` (string format with `::` separator)
2. **Actual format**: The qlib handler produces MultiIndex tuple columns like `('st_flag', 'ST_S')` and `('st_flag', 'ST_Y')`
This format mismatch causes the processor to fail to find the ST flag columns, and thus no `IsST` column is created.
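The mismatch can be reproduced with a toy column list (the exact lookup code in the original processor is not shown here; this only illustrates why a `::`-format lookup never matches tuple columns):

```python
# Columns as produced by the qlib handler: MultiIndex tuples.
columns = [("st_flag", "ST_S"), ("st_flag", "ST_Y"), ("feature", "KMID")]

# A lookup in the '::' string format silently finds nothing:
# str(("st_flag", "ST_S")) is "('st_flag', 'ST_S')", never "st_flag::ST_S".
wanted = ["st_flag::ST_S", "st_flag::ST_Y"]
print([c for c in columns if str(c) in wanted])  # [] -> IsST is never created

# A tuple-aware lookup finds both ST columns.
print([c for c in columns if c[0] == "st_flag" and c[1] in ("ST_S", "ST_Y")])
# [('st_flag', 'ST_S'), ('st_flag', 'ST_Y')]
```

Because the failed lookup simply yields an empty list, no exception is raised — hence the silent failure.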
### Evidence
```python
# Check proc_list
import pickle as pkl

with open('proc_list.proc', 'rb') as f:
    proc_list = pkl.load(f)

# FlagSTInjector config
flag_st = proc_list[2]
print(f"fields_group: {flag_st.fields_group}")  # 'feature_flag'
print(f"col_name: {flag_st.col_name}")          # 'IsST'
print(f"st_group: {flag_st.st_group}")          # 'st_flag'

# Check if IsST exists in processed data
with open('processed_data.pkl', 'rb') as f:
    df = pkl.load(f)

feature_flag_cols = [c[1] for c in df.columns if c[0] == 'feature_flag']
print('IsST' in feature_flag_cols)  # False!
```
### Impact
- **VAE training**: The VAE model was trained on data **without** the `IsST` column
- **VAE input dimension**: 341 features (excluding IsST), not 342
- **Polars pipeline**: Should also skip `IsST` to maintain compatibility
### Resolution
The polars-based pipeline (`dump_polars_dataset.py`) now correctly **skips** the `FlagSTInjector` step to match the gold-standard behavior:
```python
# Step 3: FlagSTInjector - SKIPPED (fails even in gold-standard)
print("[3] Skipping FlagSTInjector (as per gold-standard behavior)...")
market_flag_with_st = market_flag_with_market # No IsST added
```
### Lessons Learned
1. **Verify processor execution**: Don't assume all processors in the proc_list executed successfully. Check the output data to verify expected columns exist.
2. **Column format matters**: The qlib processors were designed for specific column formats (MultiIndex tuples vs `::` separator strings). Format mismatches can cause silent failures.
3. **Match the gold-standard bugs**: When replicating a pipeline, sometimes you need to replicate the bugs too. The VAE was trained on data without `IsST`, so our pipeline must also exclude it.
4. **Debug by comparing intermediate outputs**: Use scripts like `debug_data_divergence.py` to compare raw and processed data between the gold-standard and polars pipelines.
---
## Correlation Results (v5)

@@ -18,7 +18,8 @@ stock_1d/d033/alpha158_beta/
│ ├── generate_returns.py # Generate actual returns from kline data
│ ├── fetch_predictions.py # Fetch original predictions from DolphinDB
│ ├── predict_with_embedding.py # Generate predictions using beta embeddings
│ └── compare_predictions.py # Compare 0_7 vs 0_7_beta predictions
│ ├── compare_predictions.py # Compare 0_7 vs 0_7_beta predictions
│ └── dump_polars_dataset.py # Dump raw and processed datasets using polars pipeline
├── src/ # Source modules
│ └── qlib_loader.py # Qlib data loader with configurable date range
├── config/ # Configuration files
@@ -98,10 +99,40 @@ The `qlib_loader.py` includes fixed implementations of qlib processors that corr
- `FixedColumnRemover` - Handles `::` separator format
- `FixedRobustZScoreNorm` - Uses trained `mean_train`/`std_train` parameters from pickle
- `FixedIndusNtrlInjector` - Industry neutralization with `::` format
- Other fixed processors for the full preprocessing pipeline
- `FixedFlagMarketInjector` - Adds `market_0`, `market_1` columns based on instrument codes
- `FixedFlagSTInjector` - Creates `IsST` column from `ST_S`, `ST_Y` flags
All fixed processors preserve the trained parameters from the original proc_list pickle.
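As an illustration of how those pre-fitted parameters are reused, the core of a robust z-score step with trained `mean_train`/`std_train` arrays looks roughly like this (a sketch only; the real `FixedRobustZScoreNorm` also handles column selection and loading the pickle):

```python
import numpy as np

def robust_zscore_with_trained_params(x, mean_train, std_train, clip=(-3, 3)):
    """Normalize with pre-fitted per-feature parameters, then clip.

    mean_train / std_train come from the original qlib fit (proc_list pickle);
    reusing them instead of refitting keeps the polars pipeline aligned with
    what the VAE saw during training.
    """
    z = (x - mean_train) / std_train
    return np.clip(z, *clip)

x = np.array([[1.0, 10.0], [2.0, 40.0]])
mean_train = np.array([1.5, 20.0])
std_train = np.array([0.5, 10.0])
print(robust_zscore_with_trained_params(x, mean_train, std_train))
# [[-1. -1.]
#  [ 1.  2.]]
```

The clip range `(-3, 3)` matches the value passed to `RobustZScoreNorm` in `dump_polars_dataset.py`.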
### Polars Dataset Generation
The `scripts/dump_polars_dataset.py` script generates datasets using a polars-based pipeline that replicates the qlib preprocessing:
```bash
# Generate raw and processed datasets
python scripts/dump_polars_dataset.py
```
This script:
1. Loads data from Parquet files (alpha158, kline, market flags, industry flags)
2. Saves raw data (before processors) to `data_polars/raw_data_*.pkl`
3. Applies the full processor pipeline:
- Diff processor (adds diff features)
- FlagMarketInjector (adds market_0, market_1)
- ColumnRemover (removes log_size_diff, IsN, IsZt, IsDt)
- FlagToOnehot (converts 29 industry flags to indus_idx)
- IndusNtrlInjector (industry neutralization)
- RobustZScoreNorm (using pre-fitted qlib parameters)
- Fillna (fill NaN with 0)
4. Saves processed data to `data_polars/processed_data_*.pkl`
**Note**: The `FlagSTInjector` step is skipped because it fails silently even in the gold-standard qlib code (see `BUG_ANALYSIS_FINAL.md` for details).
Output structure:
- Raw data: ~204 columns (158 feature + 4 feature_ext + 12 feature_flag + 30 indus_flag)
- Processed data: 342 columns (316 feature + 14 feature_ext + 11 feature_flag + 1 indus_idx)
- VAE input dimension: 341 (excluding indus_idx)
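A quick way to sanity-check a dumped pickle is to count columns per group and derive the VAE input width. The sketch below uses a tiny synthetic frame with the same MultiIndex layout (group sizes here are illustrative, not the real 316/14/11/1 split):

```python
import numpy as np
import pandas as pd

# Tiny frame with the same column-group layout as the dumped pickle
cols = pd.MultiIndex.from_tuples([
    ("feature", "KMID"), ("feature", "KMID_ntrl"),
    ("feature_ext", "turnover"),
    ("feature_flag", "market_0"), ("feature_flag", "market_1"),
    ("indus_idx", "indus_idx"),
])
df = pd.DataFrame(np.zeros((3, len(cols))), columns=cols)

# Group counts, as printed by dump_polars_dataset.py
counts = df.columns.get_level_values(0).value_counts().to_dict()
print(counts)

# VAE input = everything except indus_idx (341 columns on the real dataset)
vae_cols = [c for c in df.columns if c[0] != "indus_idx"]
print(len(vae_cols))  # 5
```

On the real processed pickle the same two checks should report 342 total columns and a 341-wide VAE input.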
## Workflow
### 1. Generate Beta Embeddings

@@ -0,0 +1,254 @@
#!/usr/bin/env python
"""
Debug script to compare gold-standard qlib data vs polars-based pipeline.

This script helps identify where the data loading and processing pipeline
starts to diverge from the gold-standard qlib output.
"""
import os
import sys
import pickle as pkl
import numpy as np
import pandas as pd
import polars as pl
from pathlib import Path

# Paths
GOLD_RAW_PATH = "/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/data/raw_data_20190101_20190131.pkl"
GOLD_PROC_PATH = "/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/data/processed_data_20190101_20190131.pkl"
PROC_LIST_PATH = "/home/guofu/Workspaces/alpha/data_ops/tasks/dwm_feature_vae/dataset/csiallx_feature2_ntrla_flag_pnlnorm/proc_list.proc"

sys.path.insert(0, str(Path(__file__).parent.parent / "scripts"))


def compare_raw_data():
    """Compare raw data from gold standard vs polars pipeline."""
    print("=" * 80)
    print("STEP 1: Compare RAW DATA (before proc_list)")
    print("=" * 80)

    # Load gold standard raw data
    with open(GOLD_RAW_PATH, "rb") as f:
        gold_raw = pkl.load(f)

    print("\nGold standard raw data:")
    print(f"  Shape: {gold_raw.shape}")
    print(f"  Index: {gold_raw.index.names}")
    print(f"  Column groups: {gold_raw.columns.get_level_values(0).unique().tolist()}")

    # Count columns per group
    for grp in gold_raw.columns.get_level_values(0).unique().tolist():
        count = (gold_raw.columns.get_level_values(0) == grp).sum()
        print(f"    {grp}: {count} columns")

    # Show sample values for key columns
    print("\n  Sample values (first 3 rows):")
    for col in [('feature', 'KMID'), ('feature_ext', 'turnover'), ('feature_ext', 'log_size')]:
        if col in gold_raw.columns:
            print(f"    {col}: {gold_raw[col].iloc[:3].tolist()}")

    return gold_raw


def compare_processed_data():
    """Compare processed data from gold standard vs polars pipeline."""
    print("\n" + "=" * 80)
    print("STEP 2: Compare PROCESSED DATA (after proc_list)")
    print("=" * 80)

    # Load gold standard processed data
    with open(GOLD_PROC_PATH, "rb") as f:
        gold_proc = pkl.load(f)

    print("\nGold standard processed data:")
    print(f"  Shape: {gold_proc.shape}")
    print(f"  Index: {gold_proc.index.names}")
    print(f"  Column groups: {gold_proc.columns.get_level_values(0).unique().tolist()}")

    # Count columns per group
    for grp in gold_proc.columns.get_level_values(0).unique().tolist():
        count = (gold_proc.columns.get_level_values(0) == grp).sum()
        print(f"    {grp}: {count} columns")

    # Show sample values for key columns
    print("\n  Sample values (first 3 rows):")
    for col in [('feature', 'KMID'), ('feature', 'KMID_ntrl'),
                ('feature_ext', 'turnover'), ('feature_ext', 'turnover_ntrl')]:
        if col in gold_proc.columns:
            print(f"    {col}: {gold_proc[col].iloc[:3].tolist()}")

    return gold_proc


def analyze_processor_pipeline(gold_raw, gold_proc):
    """Analyze what transformations happened in the proc_list."""
    print("\n" + "=" * 80)
    print("STEP 3: Analyze Processor Transformations")
    print("=" * 80)

    # Load proc_list
    with open(PROC_LIST_PATH, "rb") as f:
        proc_list = pkl.load(f)

    print(f"\nProcessor pipeline ({len(proc_list)} processors):")
    for i, proc in enumerate(proc_list):
        print(f"  [{i}] {type(proc).__name__}")

    # Analyze column changes
    print("\nColumn count changes:")
    print(f"  Before: {gold_raw.shape[1]} columns")
    print(f"  After:  {gold_proc.shape[1]} columns")
    print(f"  Change: +{gold_proc.shape[1] - gold_raw.shape[1]} columns")

    # Check which columns were added/removed
    gold_raw_cols = set(gold_raw.columns)
    gold_proc_cols = set(gold_proc.columns)
    added_cols = gold_proc_cols - gold_raw_cols
    removed_cols = gold_raw_cols - gold_proc_cols
    print(f"\n  Added columns: {len(added_cols)}")
    print(f"  Removed columns: {len(removed_cols)}")
    if removed_cols:
        print(f"  Removed: {list(removed_cols)[:10]}...")

    # Check feature column patterns
    print("\nFeature column patterns in processed data:")
    feature_cols = [c for c in gold_proc.columns if c[0] == 'feature']
    ntrl_cols = [c for c in feature_cols if c[1].endswith('_ntrl')]
    raw_cols = [c for c in feature_cols if not c[1].endswith('_ntrl')]
    print(f"  Total feature columns: {len(feature_cols)}")
    print(f"  _ntrl columns: {len(ntrl_cols)}")
    print(f"  raw columns: {len(raw_cols)}")


def check_polars_pipeline():
    """Run the polars-based pipeline and compare."""
    print("\n" + "=" * 80)
    print("STEP 4: Generate data using Polars pipeline")
    print("=" * 80)

    try:
        from generate_beta_embedding import (
            load_all_data, merge_data_sources, apply_feature_pipeline,
            filter_stock_universe
        )

        # Load data using polars pipeline
        print("\nLoading data with polars pipeline...")
        df_alpha, df_kline, df_flag, df_industry = load_all_data(
            "2019-01-01", "2019-01-31"
        )
        print("\nPolars data sources loaded:")
        print(f"  Alpha158: {df_alpha.shape}")
        print(f"  Kline (market_ext): {df_kline.shape}")
        print(f"  Flags: {df_flag.shape}")
        print(f"  Industry: {df_industry.shape}")

        # Merge
        df_merged = merge_data_sources(df_alpha, df_kline, df_flag, df_industry)
        print(f"\nAfter merge: {df_merged.shape}")

        # Convert to pandas for easier comparison
        df_pandas = df_merged.to_pandas()
        df_pandas = df_pandas.set_index(['datetime', 'instrument'])
        print(f"\nAfter converting to pandas MultiIndex: {df_pandas.shape}")

        # Compare column names
        with open(GOLD_RAW_PATH, "rb") as f:
            gold_raw = pkl.load(f)

        print("\n" + "=" * 80)
        print("STEP 5: Compare Column Names (Gold vs Polars)")
        print("=" * 80)

        gold_cols = set(str(c) for c in gold_raw.columns)
        polars_cols = set(str(c) for c in df_pandas.columns)
        common_cols = gold_cols & polars_cols
        only_in_gold = gold_cols - polars_cols
        only_in_polars = polars_cols - gold_cols

        print(f"\n  Common columns: {len(common_cols)}")
        print(f"  Only in gold standard: {len(only_in_gold)}")
        print(f"  Only in polars: {len(only_in_polars)}")
        if only_in_gold:
            print("\n  Columns only in gold standard (first 20):")
            for col in list(only_in_gold)[:20]:
                print(f"    {col}")
        if only_in_polars:
            print("\n  Columns only in polars (first 20):")
            for col in list(only_in_polars)[:20]:
                print(f"    {col}")

        # Check common columns values
        print("\n" + "=" * 80)
        print("STEP 6: Compare Values for Common Columns")
        print("=" * 80)

        # Get common columns as tuples
        common_tuples = []
        for gc in gold_raw.columns:
            gc_str = str(gc)
            for pc in df_pandas.columns:
                if str(pc) == gc_str:
                    common_tuples.append((gc, pc))
                    break

        print(f"\nComparing {len(common_tuples)} common columns...")

        # Compare first few columns
        matching_count = 0
        diff_count = 0
        for i, (gc, pc) in enumerate(common_tuples[:20]):
            gold_vals = gold_raw[gc].dropna().values
            polars_vals = df_pandas[pc].dropna().values
            if len(gold_vals) > 0 and len(polars_vals) > 0:
                # Compare min, max, mean
                if np.allclose([gold_vals.min(), gold_vals.max(), gold_vals.mean()],
                               [polars_vals.min(), polars_vals.max(), polars_vals.mean()],
                               rtol=1e-5):
                    matching_count += 1
                else:
                    diff_count += 1
                    if diff_count <= 3:
                        print(f"  DIFF: {gc}")
                        print(f"    Gold:   min={gold_vals.min():.6f}, max={gold_vals.max():.6f}, mean={gold_vals.mean():.6f}")
                        print(f"    Polars: min={polars_vals.min():.6f}, max={polars_vals.max():.6f}, mean={polars_vals.mean():.6f}")

        print(f"\n  Matching columns: {matching_count}")
        print(f"  Different columns: {diff_count}")

    except Exception as e:
        print(f"\nError running polars pipeline: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    print("=" * 80)
    print("DATA DIVERGENCE DEBUG SCRIPT")
    print("Comparing gold-standard qlib output vs polars-based pipeline")
    print("=" * 80)

    # Step 1: Check raw data
    gold_raw = compare_raw_data()

    # Step 2: Check processed data
    gold_proc = compare_processed_data()

    # Step 3: Analyze processor transformations
    analyze_processor_pipeline(gold_raw, gold_proc)

    # Steps 4-6: Run polars pipeline and compare
    check_polars_pipeline()

    print("\n" + "=" * 80)
    print("DEBUG COMPLETE")
    print("=" * 80)

@@ -0,0 +1,330 @@
#!/usr/bin/env python
"""
Script to dump raw and processed datasets using the polars-based pipeline.

This generates:
1. Raw data (before applying processors) - equivalent to qlib's handler output
2. Processed data (after applying all processors) - ready for VAE encoding

Date range: 2026-02-23 to today (2026-02-27)
"""
import os
import sys
import pickle as pkl
import numpy as np
import polars as pl
from pathlib import Path
from datetime import datetime

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))

from generate_beta_embedding import (
    load_all_data,
    merge_data_sources,
    filter_stock_universe,
    DiffProcessor,
    FlagMarketInjector,
    ColumnRemover,
    FlagToOnehot,
    IndusNtrlInjector,
    RobustZScoreNorm,
    Fillna,
    load_qlib_processor_params,
    ALPHA158_COLS,
    INDUSTRY_FLAG_COLS,
)

# Date range
START_DATE = "2026-02-23"
END_DATE = "2026-02-27"

# Output directory
OUTPUT_DIR = Path(__file__).parent.parent / "data_polars"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame:
    """
    Apply the full processor pipeline (equivalent to qlib's proc_list).

    This mimics the qlib proc_list:
    0. Diff: Adds diff features for market_ext columns
    1. FlagMarketInjector: Adds market_0, market_1
    2. FlagSTInjector: SKIPPED (fails even in gold-standard)
    3. ColumnRemover: Removes log_size_diff, IsN, IsZt, IsDt
    4. FlagToOnehot: Converts 29 industry flags to indus_idx
    5. IndusNtrlInjector: Industry neutralization for feature
    6. IndusNtrlInjector: Industry neutralization for feature_ext
    7. RobustZScoreNorm: Normalization using pre-fitted qlib params
    8. Fillna: Fill NaN with 0
    """
    print("=" * 60)
    print("Applying processor pipeline")
    print("=" * 60)

    # market_ext columns (4 base)
    market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']

    # market_flag columns (12 total before ColumnRemover)
    market_flag_cols = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
                        'open_limit', 'close_limit', 'low_limit',
                        'open_stop', 'close_stop', 'high_stop']

    # Step 1: Diff Processor
    print("\n[1] Applying Diff processor...")
    diff_processor = DiffProcessor(market_ext_base)
    df = diff_processor.process(df)
    # After Diff: market_ext has 8 columns (4 base + 4 diff)
    market_ext_cols = market_ext_base + [f"{c}_diff" for c in market_ext_base]

    # Step 2: FlagMarketInjector (adds market_0, market_1)
    print("[2] Applying FlagMarketInjector...")
    flag_injector = FlagMarketInjector()
    df = flag_injector.process(df)
    # Add market_0, market_1 to flag list
    market_flag_with_market = market_flag_cols + ['market_0', 'market_1']

    # Step 3: FlagSTInjector - SKIPPED (fails even in gold-standard)
    print("[3] Skipping FlagSTInjector (as per gold-standard behavior)...")
    market_flag_with_st = market_flag_with_market  # No IsST added

    # Step 4: ColumnRemover
    print("[4] Applying ColumnRemover...")
    columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt']
    remover = ColumnRemover(columns_to_remove)
    df = remover.process(df)
    # Update column lists after removal
    market_ext_cols = [c for c in market_ext_cols if c not in columns_to_remove]
    market_flag_with_st = [c for c in market_flag_with_st if c not in columns_to_remove]
    print(f"  Removed columns: {columns_to_remove}")
    print(f"  Remaining market_ext: {len(market_ext_cols)} columns")
    print(f"  Remaining market_flag: {len(market_flag_with_st)} columns")

    # Step 5: FlagToOnehot
    print("[5] Applying FlagToOnehot...")
    flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS)
    df = flag_to_onehot.process(df)

    # Steps 6 & 7: IndusNtrlInjector
    print("[6] Applying IndusNtrlInjector for alpha158...")
    alpha158_cols = ALPHA158_COLS.copy()
    indus_ntrl_alpha = IndusNtrlInjector(alpha158_cols, suffix='_ntrl')
    df = indus_ntrl_alpha.process(df)

    print("[7] Applying IndusNtrlInjector for market_ext...")
    indus_ntrl_ext = IndusNtrlInjector(market_ext_cols, suffix='_ntrl')
    df = indus_ntrl_ext.process(df)

    # Build column lists for normalization
    alpha158_ntrl_cols = [f"{c}_ntrl" for c in alpha158_cols]
    market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_cols]

    # Step 8: RobustZScoreNorm
    print("[8] Applying RobustZScoreNorm...")
    norm_feature_cols = alpha158_ntrl_cols + alpha158_cols + market_ext_ntrl_cols + market_ext_cols
    qlib_params = load_qlib_processor_params()
    # Verify parameter shape
    expected_features = len(norm_feature_cols)
    if qlib_params['mean_train'].shape[0] != expected_features:
        print(f"  WARNING: Feature count mismatch! Expected {expected_features}, "
              f"got {qlib_params['mean_train'].shape[0]}")
    robust_norm = RobustZScoreNorm(
        norm_feature_cols,
        clip_range=(-3, 3),
        use_qlib_params=True,
        qlib_mean=qlib_params['mean_train'],
        qlib_std=qlib_params['std_train']
    )
    df = robust_norm.process(df)

    # Step 9: Fillna
    print("[9] Applying Fillna...")
    final_feature_cols = norm_feature_cols + market_flag_with_st + ['indus_idx']
    fillna = Fillna()
    df = fillna.process(df, final_feature_cols)

    print("\n" + "=" * 60)
    print("Processor pipeline complete!")
    print(f"  Normalized features: {len(norm_feature_cols)}")
    print(f"  Market flags: {len(market_flag_with_st)}")
    print(f"  Total features (with indus_idx): {len(final_feature_cols)}")
    print("=" * 60)
    return df


def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame":
    """
    Convert polars DataFrame to pandas DataFrame with MultiIndex columns.

    This matches the format of qlib's output.
    """
    import pandas as pd

    # Convert to pandas
    df = df_polars.to_pandas()

    # Set the MultiIndex if datetime/instrument are still ordinary columns;
    # if they are not in the columns, they are assumed to already be the index
    if 'datetime' in df.columns and 'instrument' in df.columns:
        df = df.set_index(['datetime', 'instrument'])

    # Drop raw columns that shouldn't be in processed data
    raw_cols_to_drop = ['Turnover', 'FreeTurnover', 'MarketValue']
    existing_raw_cols = [c for c in raw_cols_to_drop if c in df.columns]
    if existing_raw_cols:
        df = df.drop(columns=existing_raw_cols)

    # Build MultiIndex columns based on column name patterns
    columns_with_group = []

    # Define column sets
    alpha158_base = set(ALPHA158_COLS)
    market_ext_base = {'turnover', 'free_turnover', 'log_size', 'con_rating_strength'}
    market_ext_diff = {'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'}
    market_ext_all = market_ext_base | market_ext_diff
    feature_flag_cols = {'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
                         'open_limit', 'close_limit', 'low_limit',
                         'open_stop', 'close_stop', 'high_stop',
                         'market_0', 'market_1', 'IsST'}

    for col in df.columns:
        if col == 'indus_idx':
            columns_with_group.append(('indus_idx', col))
        elif col in feature_flag_cols:
            columns_with_group.append(('feature_flag', col))
        elif col.endswith('_ntrl'):
            base_name = col[:-5]  # Remove '_ntrl' suffix (5 characters)
            if base_name in alpha158_base:
                columns_with_group.append(('feature', col))
            elif base_name in market_ext_all:
                columns_with_group.append(('feature_ext', col))
            else:
                columns_with_group.append(('feature', col))  # Default to feature
        elif col in alpha158_base:
            columns_with_group.append(('feature', col))
        elif col in market_ext_all:
            columns_with_group.append(('feature_ext', col))
        elif col in INDUSTRY_FLAG_COLS:
            columns_with_group.append(('indus_flag', col))
        elif col in {'ST_S', 'ST_Y', 'ST_T', 'ST_L', 'ST_Z', 'ST_X'}:
            columns_with_group.append(('st_flag', col))
        else:
            # Unknown column - print warning
            print(f"  Warning: Unknown column '{col}', assigning to 'other' group")
            columns_with_group.append(('other', col))

    # Create MultiIndex columns
    multi_cols = pd.MultiIndex.from_tuples(columns_with_group)
    df.columns = multi_cols
    return df


def main():
    print("=" * 80)
    print("Dumping Polars Dataset")
    print("=" * 80)
    print(f"Date range: {START_DATE} to {END_DATE}")
    print(f"Output directory: {OUTPUT_DIR}")
    print()

    # Step 1: Load all data
    print("Step 1: Loading data from parquet...")
    df_alpha, df_kline, df_flag, df_industry = load_all_data(START_DATE, END_DATE)
    print(f"  Alpha158 shape: {df_alpha.shape}")
    print(f"  Kline (market_ext) shape: {df_kline.shape}")
    print(f"  Flags shape: {df_flag.shape}")
    print(f"  Industry shape: {df_industry.shape}")

    # Step 2: Merge data sources
    print("\nStep 2: Merging data sources...")
    df_merged = merge_data_sources(df_alpha, df_kline, df_flag, df_industry)
    print(f"  Merged shape (after csiallx filter): {df_merged.shape}")

    # Step 3: Save raw data (before processors)
    print("\nStep 3: Saving raw data (before processors)...")
    # Keep columns that match qlib's raw output format;
    # include datetime and instrument for MultiIndex conversion
    raw_columns = (
        ['datetime', 'instrument'] +  # Index columns
        ALPHA158_COLS +  # feature group
        ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] +  # feature_ext base
        ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',  # market_flag from kline
         'open_limit', 'close_limit', 'low_limit',
         'open_stop', 'close_stop', 'high_stop'] +
        INDUSTRY_FLAG_COLS +  # indus_flag
        (['ST_S', 'ST_Y'] if 'ST_S' in df_merged.columns else [])  # st_flag (if available)
    )
    # Filter to available columns
    available_raw_cols = [c for c in raw_columns if c in df_merged.columns]
    print(f"  Selecting {len(available_raw_cols)} columns for raw data...")
    df_raw_polars = df_merged.select(available_raw_cols)

    # Convert to pandas with MultiIndex
    df_raw_pd = convert_to_multiindex_df(df_raw_polars)
    raw_output_path = OUTPUT_DIR / f"raw_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl"
    with open(raw_output_path, "wb") as f:
        pkl.dump(df_raw_pd, f)
    print(f"  Saved raw data to: {raw_output_path}")
    print(f"  Raw data shape: {df_raw_pd.shape}")
    print(f"  Column groups: {df_raw_pd.columns.get_level_values(0).unique().tolist()}")

    # Step 4: Apply processor pipeline
    print("\nStep 4: Applying processor pipeline...")
    df_processed = apply_processor_pipeline(df_merged)

    # Step 5: Save processed data
    print("\nStep 5: Saving processed data (after processors)...")
    # Convert to pandas with MultiIndex
    df_processed_pd = convert_to_multiindex_df(df_processed)
    processed_output_path = OUTPUT_DIR / f"processed_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl"
    with open(processed_output_path, "wb") as f:
        pkl.dump(df_processed_pd, f)
    print(f"  Saved processed data to: {processed_output_path}")
    print(f"  Processed data shape: {df_processed_pd.shape}")
    print(f"  Column groups: {df_processed_pd.columns.get_level_values(0).unique().tolist()}")

    # Count columns per group
    print("\n  Column counts by group:")
    for grp in df_processed_pd.columns.get_level_values(0).unique().tolist():
        count = (df_processed_pd.columns.get_level_values(0) == grp).sum()
        print(f"    {grp}: {count} columns")

    # Step 6: Verify column counts
    print("\n" + "=" * 80)
    print("Verification")
    print("=" * 80)
    feature_flag_cols = [c[1] for c in df_processed_pd.columns if c[0] == 'feature_flag']
    has_market_0 = 'market_0' in feature_flag_cols
    has_market_1 = 'market_1' in feature_flag_cols
    print(f"  feature_flag columns: {feature_flag_cols}")
    print(f"  Has market_0: {has_market_0}")
    print(f"  Has market_1: {has_market_1}")
    if has_market_0 and has_market_1:
        print("\n  SUCCESS: market_0 and market_1 columns are present!")
    else:
        print("\n  WARNING: market_0 or market_1 columns are missing!")

    print("\n" + "=" * 80)
    print("Dataset dump complete!")
    print("=" * 80)


if __name__ == "__main__":
    main()

@@ -638,14 +638,24 @@ def apply_feature_pipeline(df: pl.DataFrame) -> Tuple[pl.DataFrame, List[str]]:
    # After FlagMarketInjector: market_flag = 12 + 2 = 14 columns
    market_flag_with_market = market_flag_cols + ['market_0', 'market_1']

    # Step 3: FlagSTInjector - create IsST
    # Note: Actual ST flags (ST_Y, ST_S, etc.) are not available in the parquet data.
    # Since the VAE was trained with IsST, we create a placeholder (all zeros).
    # Step 3: FlagSTInjector - create IsST from ST flags
    # Note: ST flags (ST_Y, ST_S) may not be available in parquet data.
    # If available, IsST = ST_S | ST_Y; otherwise create a placeholder (all zeros).
    # This maintains compatibility with the VAE's expected input dimension.
    print("Applying FlagSTInjector (creating IsST placeholder)...")
    df = df.with_columns([
        pl.lit(0).cast(pl.Int8).alias('IsST')
    ])
    print("Applying FlagSTInjector (creating IsST)...")
    # Check that both ST flags are present under their plain column names
    # (pl.col below requires the plain names, so 'st_flag::ST_S' would not do)
    if 'ST_S' in df.columns and 'ST_Y' in df.columns:
        # Create IsST from actual ST flags
        df = df.with_columns([
            (pl.col('ST_S').cast(pl.Boolean, strict=False) |
             pl.col('ST_Y').cast(pl.Boolean, strict=False))
            .cast(pl.Int8).alias('IsST')
        ])
    else:
        # Create placeholder (all zeros) if ST flags not available
        df = df.with_columns([
            pl.lit(0).cast(pl.Int8).alias('IsST')
        ])

    market_flag_with_st = market_flag_with_market + ['IsST']

    # Step 4: ColumnRemover - remove specific columns
@@ -728,16 +738,17 @@ def apply_feature_pipeline(df: pl.DataFrame) -> Tuple[pl.DataFrame, List[str]]:
    # - From kline_adjusted: IsXD, IsXR, IsDR (3 cols)
    # - From market_flag: open_limit, close_limit, low_limit, open_stop, close_stop, high_stop (6 cols)
    # - Added by FlagMarketInjector: market_0, market_1 (2 cols)
    # - Total market flags for VAE: 3 + 6 + 2 = 11 (IsST is excluded)
    # - Added by FlagSTInjector: IsST (1 col, placeholder if ST flags not available)
    # - Total market flags: 3 + 6 + 2 + 1 = 12 (IsST excluded from VAE input)
    #
    # Total features:
    # - norm_feature_cols: 158 + 158 + 7 + 7 = 330
    # - market_flag_with_st: 11 (excluding IsST)
    # - market_flag_with_st: 12 (including IsST)
    # - indus_idx: 1
    # - Total: 330 + 11 + 1 = 342 features (but IsST stays in df for completeness)
    # - Total: 330 + 12 + 1 = 343 features
    #
    # VAE input dimension (feature + feature_ext + feature_flag only, no indus_idx):
    # - 316 (alpha158 + ntrl) + 14 (market_ext + ntrl) + 11 (flags) = 341
    # - 316 (alpha158 + ntrl) + 14 (market_ext + ntrl) + 11 (flags, excluding IsST) = 341
    # Exclude IsST from VAE input features (the VAE was trained without it)
    market_flag_for_vae = [c for c in market_flag_with_st if c != 'IsST']

@@ -452,21 +452,56 @@ class FixedFillna:
class FixedFlagMarketInjector:
    """Fixed FlagMarketInjector that handles :: separator format."""
    """Fixed FlagMarketInjector that handles :: separator format.

    This processor adds market classification columns based on instrument codes:
    - market_0: Main board stocks (SH60xxx, SZ00xxx)
    - market_1: STAR/ChiNext stocks (SH688xxx, SH689xxx, SZ300xxx, SZ301xxx)

    This matches the original qlib FlagMarketInjector behavior with vocab_size=2.
    """

    def __init__(self, fields_group, vocab_size=2):
        self.fields_group = fields_group
        self.vocab_size = vocab_size

    def __call__(self, df):
        cols = [c for c in df.columns if c.startswith(f"{self.fields_group}::")]
        for col in cols:
            df[col] = df[col].astype('int8')
        import pandas as pd

        # Get instrument codes from index
        inst = df.index.get_level_values('instrument')

        # market_0: SH60xxx or SZ00xxx (main board)
        market_0 = inst.str.startswith('SH60') | inst.str.startswith('SZ00')
        # market_1: SH688xxx, SH689xxx, SZ300xxx, SZ301xxx (STAR/ChiNext)
        market_1 = (inst.str.startswith('SH688') | inst.str.startswith('SH689') |
                    inst.str.startswith('SZ300') | inst.str.startswith('SZ301'))

        # Add columns to feature_flag group
        df[('feature_flag', 'market_0')] = market_0.astype('int8')
        df[('feature_flag', 'market_1')] = market_1.astype('int8')

        # Also convert existing feature_flag columns to int8;
        # handle both :: separator string format and MultiIndex tuple format
        for col in df.columns:
            if isinstance(col, str) and col.startswith(f"{self.fields_group}::"):
                df[col] = df[col].astype('int8')
            elif isinstance(col, tuple) and col[0] == self.fields_group:
                df[col] = df[col].astype('int8')
        return df


class FixedFlagSTInjector:
    """Fixed FlagSTInjector that handles :: separator format."""
    """Fixed FlagSTInjector that handles :: separator format.

    This processor creates the IsST column from ST_S and ST_Y flags:
    - IsST = 1 if ST_S or ST_Y is True (stock is ST or *ST)
    - IsST = 0 otherwise

    This matches the original qlib FlagSTInjector behavior.
    """

    def __init__(self, fields_group, st_group="st_flag", col_name="IsST"):
        self.fields_group = fields_group
@@ -474,9 +509,24 @@ class FixedFlagSTInjector:
        self.col_name = col_name

    def __call__(self, df):
        cols = [c for c in df.columns if c.startswith(f"{self.st_group}::")]
        for col in cols:
            df[col] = df[col].astype('int8')
        import pandas as pd

        # IsST = True if ST_S or ST_Y is True
        st_s_col = ('st_flag', 'ST_S')
        st_y_col = ('st_flag', 'ST_Y')
        if st_s_col in df.columns and st_y_col in df.columns:
            is_st = df[st_s_col].astype(bool) | df[st_y_col].astype(bool)
            df[('feature_flag', 'IsST')] = is_st.astype('int8')

        # Also convert existing st_flag columns to int8;
        # handle both :: separator string format and MultiIndex tuple format
        for col in df.columns:
            if isinstance(col, str) and col.startswith(f"{self.st_group}::"):
                df[col] = df[col].astype('int8')
            elif isinstance(col, tuple) and col[0] == self.st_group:
                df[col] = df[col].astype('int8')
        return df
