From ea011090f8f8f529e8cab37ae6f06227d0941bf6 Mon Sep 17 00:00:00 2001
From: guofu
Date: Sun, 1 Mar 2026 13:16:06 +0800
Subject: [PATCH] Update documentation for processors module and commit test
file
- Add comprehensive processors module documentation to cta_1d/README.md
including usage examples for RobustZScoreNorm.from_version() and
parameter extraction instructions
- Add test_processors.py test script for validating processor functionality
- Update metadata.json timestamp from parameter extraction
Co-Authored-By: Claude Opus 4.6
---
cta_1d/README.md | 76 +++
cta_1d/src/processors/test_processors.py | 508 ++++++++++++++++++
.../metadata.json | 2 +-
3 files changed, 585 insertions(+), 1 deletion(-)
create mode 100644 cta_1d/src/processors/test_processors.py
diff --git a/cta_1d/README.md b/cta_1d/README.md
index 664105a..ab93c01 100644
--- a/cta_1d/README.md
+++ b/cta_1d/README.md
@@ -34,3 +34,79 @@ Predefined weights (from qshare.config.research.cta.labels):
- `long_term`: [0.4, 0.2, 0.2, 0.2]
Default: [0.2, 0.1, 0.3, 0.4]
+
+## Processors Module
+
+The `cta_1d.src.processors` module provides Polars-based data processors that replicate Qlib's preprocessing pipeline:
+
+### Available Processors
+
+| Processor | Description |
+|-----------|-------------|
+| `DiffProcessor` | Adds diff features with configurable period |
+| `FlagMarketInjector` | Adds market_0, market_1 columns based on instrument codes |
+| `FlagSTInjector` | Creates IsST column from ST flags |
+| `ColumnRemover` | Removes specified columns |
+| `FlagToOnehot` | Converts one-hot industry flags to single index column |
+| `IndusNtrlInjector` | Industry neutralization per datetime |
+| `RobustZScoreNorm` | Robust z-score normalization using median/MAD |
+| `Fillna` | Fills NaN values with specified value |
+
+### RobustZScoreNorm with Pre-fitted Parameters
+
+The `RobustZScoreNorm` processor supports loading pre-fitted parameters from Qlib's `proc_list.proc`:
+
+```python
+from cta_1d.src.processors import RobustZScoreNorm
+
+# Method 1: Load from saved version (recommended)
+processor = RobustZScoreNorm.from_version("csiallx_feature2_ntrla_flag_pnlnorm")
+
+# Method 2: Load with direct parameters
+processor = RobustZScoreNorm(
+ feature_cols=['KMID', 'KLEN', ...],
+ use_qlib_params=True,
+ qlib_mean=mean_array,
+ qlib_std=std_array
+)
+
+# Apply normalization
+df = processor.process(df)
+```
+
+### Parameter Extraction
+
+Extract parameters from Qlib's proc_list.proc:
+
+```bash
+python stock_1d/d033/alpha158_beta/scripts/extract_qlib_params.py \
+ --proc-list /path/to/proc_list.proc \
+ --version my_version
+```
+
+Output structure:
+```
+data/robust_zscore_params/{version}/
+├── mean_train.npy # Pre-fitted mean (330,)
+├── std_train.npy # Pre-fitted std (330,)
+└── metadata.json # Feature columns and metadata
+```
+
+### Pipeline Helper Functions
+
+```python
+from cta_1d.src.processors import create_processor_pipeline, get_final_feature_columns
+
+# Create pipeline from column specifications
+pipeline = create_processor_pipeline(
+    alpha158_cols=ALPHA158_COLS,
+    market_ext_base=['turnover', 'free_turnover'],
+    market_flag_cols=MARKET_FLAG_COLS, industry_flag_cols=INDUSTRY_FLAG_COLS,
+)
+
+# Get final feature columns after industry neutralization
+final_cols = get_final_feature_columns(
+    alpha158_cols=ALPHA158_COLS,
+    market_ext_base=MARKET_EXT_BASE, market_flag_cols=MARKET_FLAG_COLS,
+)
+```
diff --git a/cta_1d/src/processors/test_processors.py b/cta_1d/src/processors/test_processors.py
new file mode 100644
index 0000000..ada3c77
--- /dev/null
+++ b/cta_1d/src/processors/test_processors.py
@@ -0,0 +1,508 @@
+#!/usr/bin/env python
+"""
+Test script for the Polars processors module.
+
+This script verifies that all processors in cta_1d.src.processors work correctly
+and produce expected outputs.
+"""
+
+import sys
+import numpy as np
+import polars as pl
+from pathlib import Path
+
+# Add repo root to path so the `cta_1d.src.processors` import below resolves
+sys.path.insert(0, str(Path(__file__).parents[3]))
+
+from cta_1d.src.processors import (
+ DiffProcessor,
+ FlagMarketInjector,
+ FlagSTInjector,
+ ColumnRemover,
+ FlagToOnehot,
+ IndusNtrlInjector,
+ RobustZScoreNorm,
+ Fillna,
+ create_processor_pipeline,
+ get_final_feature_columns,
+)
+
+
+def create_test_data(
+ n_dates: int = 10,
+ n_instruments: int = 10,
+ include_st_flags: bool = True,
+ include_industry: bool = True,
+) -> pl.DataFrame:
+ """Create test DataFrame with realistic structure."""
+ np.random.seed(42)
+
+ # Use string dates to avoid polars issues
+ dates = [f'2024-01-{i+1:02d}' for i in range(n_dates)]
+ instruments = [f'SH60000{i}' for i in range(n_instruments)]
+
+ data = {
+ 'datetime': [],
+ 'instrument': [],
+ # Alpha158 features
+ 'KMID': [],
+ 'KLEN': [],
+ 'ROC5': [],
+ # Market ext
+ 'turnover': [],
+ 'free_turnover': [],
+ 'log_size': [],
+ 'con_rating_strength': [],
+ # Market flags
+ 'IsZt': [],
+ 'IsDt': [],
+ 'IsN': [],
+ 'IsXD': [],
+ 'IsXR': [],
+ 'IsDR': [],
+ 'open_limit': [],
+ 'close_limit': [],
+ 'low_limit': [],
+ 'open_stop': [],
+ 'close_stop': [],
+ 'high_stop': [],
+ }
+
+ if include_st_flags:
+ data['ST_S'] = []
+ data['ST_Y'] = []
+
+ if include_industry:
+ # Add 3 industry flags for testing
+ data['gds_CC10'] = []
+ data['gds_CC11'] = []
+ data['gds_CC20'] = []
+
+ for d in dates:
+ for inst in instruments:
+ data['datetime'].append(d)
+ data['instrument'].append(inst)
+ data['KMID'].append(np.random.randn() * 0.1)
+ data['KLEN'].append(np.random.randn())
+ data['ROC5'].append(np.random.randn() * 0.05)
+ data['turnover'].append(np.random.randn() * 1e6 + 1e7)
+ data['free_turnover'].append(np.random.randn() * 1e6 + 1e7)
+ data['log_size'].append(np.log(1e10 + np.random.randn() * 1e9))
+ data['con_rating_strength'].append(np.random.randn())
+ data['IsZt'].append(np.random.randint(0, 2))
+ data['IsDt'].append(np.random.randint(0, 2))
+ data['IsN'].append(np.random.randint(0, 2))
+ data['IsXD'].append(np.random.randint(0, 2))
+ data['IsXR'].append(np.random.randint(0, 2))
+ data['IsDR'].append(np.random.randint(0, 2))
+ data['open_limit'].append(np.random.randint(0, 2))
+ data['close_limit'].append(np.random.randint(0, 2))
+ data['low_limit'].append(np.random.randint(0, 2))
+ data['open_stop'].append(np.random.randint(0, 2))
+ data['close_stop'].append(np.random.randint(0, 2))
+ data['high_stop'].append(np.random.randint(0, 2))
+
+ if include_st_flags:
+ data['ST_S'].append(np.random.randint(0, 2))
+ data['ST_Y'].append(np.random.randint(0, 2))
+
+ if include_industry:
+ # Set exactly one industry to 1
+ industry_choice = np.random.randint(0, 3)
+ data['gds_CC10'].append(1 if industry_choice == 0 else 0)
+ data['gds_CC11'].append(1 if industry_choice == 1 else 0)
+ data['gds_CC20'].append(1 if industry_choice == 2 else 0)
+
+ return pl.DataFrame(data)
+
+
+def test_diff_processor():
+ """Test DiffProcessor."""
+ print("Testing DiffProcessor...")
+ df = create_test_data(n_dates=5, n_instruments=3)
+
+ processor = DiffProcessor(['turnover', 'log_size'])
+ df_out = processor.process(df)
+
+ assert 'turnover_diff' in df_out.columns, "turnover_diff not created"
+ assert 'log_size_diff' in df_out.columns, "log_size_diff not created"
+ assert df_out.shape[0] == df.shape[0], "Row count changed"
+
+ # Verify diff is calculated per instrument
+ inst_data = df_out.filter(pl.col('instrument') == 'SH600000')
+ turnover_diff = inst_data['turnover_diff'].to_list()
+ # First value should be null (no previous row to diff)
+ assert turnover_diff[0] is None or np.isnan(turnover_diff[0]), "First diff should be null"
+
+ print(" PASSED")
+
+
+def test_flag_market_injector():
+ """Test FlagMarketInjector."""
+ print("Testing FlagMarketInjector...")
+
+ # Test with different market types
+ df = pl.DataFrame({
+ 'datetime': ['2024-01-01'] * 6,
+ 'instrument': [
+ 'SH600000', # SH main -> market_0=1
+ 'SH688000', # SH STAR -> market_1=1
+ 'SZ000001', # SZ main -> market_0=1
+ 'SZ300001', # SZ ChiNext -> market_1=1
+ 'NE400001', # NE (New Third Board) -> both 0
+ 'SH601000', # SH main -> market_0=1
+ ],
+ 'value': [1, 2, 3, 4, 5, 6],
+ })
+
+ processor = FlagMarketInjector()
+ df_out = processor.process(df)
+
+ assert 'market_0' in df_out.columns, "market_0 not created"
+ assert 'market_1' in df_out.columns, "market_1 not created"
+
+ results = df_out.select(['instrument', 'market_0', 'market_1']).rows()
+
+ # Verify market classifications
+ assert results[0][1] == 1 and results[0][2] == 0, "SH600000 should be market_0"
+ assert results[1][1] == 0 and results[1][2] == 1, "SH688000 should be market_1"
+ assert results[2][1] == 1 and results[2][2] == 0, "SZ000001 should be market_0"
+ assert results[3][1] == 0 and results[3][2] == 1, "SZ300001 should be market_1"
+ assert results[4][1] == 0 and results[4][2] == 0, "NE400001 should be neither"
+
+ print(" PASSED")
+
+
+def test_flag_st_injector():
+ """Test FlagSTInjector."""
+ print("Testing FlagSTInjector...")
+
+ # Test with ST flags
+ df = pl.DataFrame({
+ 'datetime': ['2024-01-01'] * 4,
+ 'instrument': ['SH600000', 'SH600001', 'SH600002', 'SH600003'],
+ 'ST_S': [0, 1, 0, 1],
+ 'ST_Y': [0, 0, 1, 1],
+ })
+
+ processor = FlagSTInjector()
+ df_out = processor.process(df)
+
+ assert 'IsST' in df_out.columns, "IsST not created"
+
+ is_st = df_out['IsST'].to_list()
+ assert is_st[0] == 0, "ST_S=0, ST_Y=0 -> IsST=0"
+ assert is_st[1] == 1, "ST_S=1 -> IsST=1"
+ assert is_st[2] == 1, "ST_Y=1 -> IsST=1"
+ assert is_st[3] == 1, "ST_S=1, ST_Y=1 -> IsST=1"
+
+ # Test without ST flags (placeholder mode)
+ df_no_st = pl.DataFrame({
+ 'datetime': ['2024-01-01'] * 2,
+ 'instrument': ['SH600000', 'SH600001'],
+ 'value': [1, 2],
+ })
+
+ df_out_no_st = processor.process(df_no_st)
+ assert df_out_no_st['IsST'].sum() == 0, "Placeholder IsST should be all zeros"
+
+ print(" PASSED")
+
+
+def test_column_remover():
+ """Test ColumnRemover."""
+ print("Testing ColumnRemover...")
+ df = create_test_data(n_dates=3, n_instruments=2)
+
+ processor = ColumnRemover(['IsZt', 'IsDt', 'nonexistent_column'])
+ df_out = processor.process(df)
+
+ assert 'IsZt' not in df_out.columns, "IsZt not removed"
+ assert 'IsDt' not in df_out.columns, "IsDt not removed"
+ assert 'IsN' in df_out.columns, "IsN should remain"
+
+ print(" PASSED")
+
+
+def test_flag_to_onehot():
+ """Test FlagToOnehot."""
+ print("Testing FlagToOnehot...")
+
+ df = pl.DataFrame({
+ 'datetime': ['2024-01-01'] * 6,
+ 'instrument': [f'SH60000{i}' for i in range(6)],
+ 'gds_CC10': [1, 0, 0, 0, 0, 0],
+ 'gds_CC11': [0, 1, 0, 0, 0, 0],
+ 'gds_CC20': [0, 0, 1, 0, 0, 0],
+ 'value': [1, 2, 3, 4, 5, 6],
+ })
+
+ processor = FlagToOnehot(['gds_CC10', 'gds_CC11', 'gds_CC20'])
+ df_out = processor.process(df)
+
+ assert 'indus_idx' in df_out.columns, "indus_idx not created"
+ assert 'gds_CC10' not in df_out.columns, "gds_CC10 not removed"
+ assert 'gds_CC11' not in df_out.columns, "gds_CC11 not removed"
+ assert 'gds_CC20' not in df_out.columns, "gds_CC20 not removed"
+
+ indus_idx = df_out['indus_idx'].to_list()
+ assert indus_idx[0] == 0, "gds_CC10=1 -> indus_idx=0"
+ assert indus_idx[1] == 1, "gds_CC11=1 -> indus_idx=1"
+ assert indus_idx[2] == 2, "gds_CC20=1 -> indus_idx=2"
+
+ print(" PASSED")
+
+
+def test_indus_ntrl_injector():
+ """Test IndusNtrlInjector."""
+ print("Testing IndusNtrlInjector...")
+
+ # Create data with clear industry groups (use string dates to avoid polars issues)
+ df = pl.DataFrame({
+ 'datetime': ['2024-01-01'] * 6 + ['2024-01-02'] * 6,
+ 'instrument': [f'SH60000{i}' for i in range(6)] * 2,
+ 'indus_idx': [0, 0, 0, 1, 1, 1] * 2, # Two industries
+ 'KMID': [1.0, 2.0, 3.0, 10.0, 20.0, 30.0] * 2, # Clear industry means
+ })
+
+ processor = IndusNtrlInjector(['KMID'], suffix='_ntrl')
+ df_out = processor.process(df)
+
+ assert 'KMID_ntrl' in df_out.columns, "KMID_ntrl not created"
+
+ # Industry 0 mean = 2.0, Industry 1 mean = 20.0
+ # Neutralized values should be centered around 0 within each industry
+ ntrl_values = df_out['KMID_ntrl'].to_list()
+
+ # Check that industry means are approximately 0 after neutralization
+ df_indus_0 = df_out.filter(pl.col('indus_idx') == 0)
+ df_indus_1 = df_out.filter(pl.col('indus_idx') == 1)
+
+ mean_0 = df_indus_0['KMID_ntrl'].mean()
+ mean_1 = df_indus_1['KMID_ntrl'].mean()
+
+ assert abs(mean_0) < 1e-6, f"Industry 0 mean should be ~0, got {mean_0}"
+ assert abs(mean_1) < 1e-6, f"Industry 1 mean should be ~0, got {mean_1}"
+
+ print(" PASSED")
+
+
+def test_robust_zscore_norm():
+ """Test RobustZScoreNorm."""
+ print("Testing RobustZScoreNorm...")
+
+ # Test with pre-fitted parameters
+ df = pl.DataFrame({
+ 'datetime': ['2024-01-01'] * 10,
+ 'instrument': [f'SH60000{i}' for i in range(10)],
+ 'feat1': list(range(10)), # 0-9
+ })
+
+ # Pre-fitted params: mean=4.5, std=~2.87 for 0-9
+ mean_train = np.array([4.5])
+ std_train = np.array([2.87])
+
+ processor = RobustZScoreNorm(
+ ['feat1'],
+ clip_range=(-3, 3),
+ use_qlib_params=True,
+ qlib_mean=mean_train,
+ qlib_std=std_train
+ )
+ df_out = processor.process(df)
+
+ # Verify normalization
+ # feat1=0 -> z = (0-4.5)/2.87 = -1.57
+ # feat1=9 -> z = (9-4.5)/2.87 = 1.57
+ assert df_out['feat1'].min() < -1.5, "Min z-score should be around -1.57"
+ assert df_out['feat1'].max() > 1.5, "Max z-score should be around 1.57"
+
+ # Test per-datetime mode
+ processor2 = RobustZScoreNorm(['feat1'], use_qlib_params=False)
+ df_out2 = processor2.process(df.clone())
+
+ # After per-datetime normalization, median should be ~0
+ median = df_out2['feat1'].median()
+ assert abs(median) < 0.1, f"Median should be ~0, got {median}"
+
+ print(" PASSED")
+
+
+def test_fillna():
+ """Test Fillna."""
+ print("Testing Fillna...")
+
+ df = pl.DataFrame({
+ 'datetime': ['2024-01-01'] * 5,
+ 'instrument': [f'SH60000{i}' for i in range(5)],
+ 'feat1': [1.0, None, 3.0, float('nan'), 5.0],
+ })
+
+ processor = Fillna(fill_value=0.0)
+ df_out = processor.process(df, ['feat1'])
+
+ # Check no null/nan values remain
+ assert df_out['feat1'].null_count() == 0, "Null values remain"
+ assert not any(np.isnan(df_out['feat1'].to_numpy())), "NaN values remain"
+
+ print(" PASSED")
+
+
+def test_full_pipeline():
+ """Test complete processor pipeline."""
+ print("Testing full processor pipeline...")
+
+ df = create_test_data(n_dates=5, n_instruments=5, include_st_flags=True, include_industry=True)
+ original_shape = df.shape
+
+ # market_ext columns
+ market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']
+
+ # Apply pipeline
+ # 1. Diff
+ diff_proc = DiffProcessor(market_ext_base)
+ df = diff_proc.process(df)
+
+ # 2. FlagMarketInjector
+ market_proc = FlagMarketInjector()
+ df = market_proc.process(df)
+
+ # 3. FlagSTInjector
+ st_proc = FlagSTInjector()
+ df = st_proc.process(df)
+
+ # 4. ColumnRemover
+ remove_proc = ColumnRemover(['log_size_diff', 'IsN', 'IsZt', 'IsDt'])
+ df = remove_proc.process(df)
+
+ # 5. FlagToOnehot
+ onehot_proc = FlagToOnehot(['gds_CC10', 'gds_CC11', 'gds_CC20'])
+ df = onehot_proc.process(df)
+
+ # 6. IndusNtrlInjector
+ ntrl_proc = IndusNtrlInjector(['KMID', 'KLEN', 'ROC5'], suffix='_ntrl')
+ df = ntrl_proc.process(df)
+
+ # 7. RobustZScoreNorm (per-datetime mode for testing)
+ norm_proc = RobustZScoreNorm(['KMID', 'KLEN', 'ROC5', 'KMID_ntrl', 'KLEN_ntrl', 'ROC5_ntrl'],
+ use_qlib_params=False)
+ df = norm_proc.process(df)
+
+ # 8. Fillna
+ fillna_proc = Fillna(fill_value=0.0)
+ df = fillna_proc.process(df, df.columns)
+
+ # Verify final structure
+ assert df.shape[0] == original_shape[0], "Row count changed"
+ assert 'indus_idx' in df.columns, "indus_idx missing"
+ assert 'market_0' in df.columns, "market_0 missing"
+ assert 'IsST' in df.columns, "IsST missing"
+ assert 'KMID_ntrl' in df.columns, "KMID_ntrl missing"
+
+ print(" PASSED")
+
+
+def test_create_processor_pipeline():
+ """Test create_processor_pipeline utility."""
+ print("Testing create_processor_pipeline...")
+
+ pipeline = create_processor_pipeline(
+ alpha158_cols=['KMID', 'KLEN'],
+ market_ext_base=['turnover', 'log_size'],
+ market_flag_cols=['IsZt', 'IsDt'],
+ industry_flag_cols=['gds_CC10', 'gds_CC11'],
+ )
+
+ assert len(pipeline) == 7, f"Expected 7 processors, got {len(pipeline)}"
+ assert isinstance(pipeline[0], DiffProcessor)
+ assert isinstance(pipeline[1], FlagMarketInjector)
+ assert isinstance(pipeline[2], FlagSTInjector)
+ assert isinstance(pipeline[3], ColumnRemover)
+ assert isinstance(pipeline[4], FlagToOnehot)
+ assert isinstance(pipeline[5], IndusNtrlInjector)
+ assert isinstance(pipeline[6], IndusNtrlInjector)
+
+ print(" PASSED")
+
+
+def test_get_final_feature_columns():
+ """Test get_final_feature_columns utility."""
+ print("Testing get_final_feature_columns...")
+
+ feature_struct = get_final_feature_columns(
+ alpha158_cols=['KMID', 'KLEN', 'ROC5'],
+ market_ext_base=['turnover', 'log_size'],
+ market_flag_cols=['IsZt', 'IsDt', 'IsN'],
+ )
+
+ # Verify structure
+ assert 'alpha158_cols' in feature_struct
+ assert 'alpha158_ntrl_cols' in feature_struct
+ assert 'market_ext_cols' in feature_struct
+ assert 'market_ext_ntrl_cols' in feature_struct
+ assert 'market_flag_cols' in feature_struct
+ assert 'norm_feature_cols' in feature_struct
+ assert 'all_feature_cols' in feature_struct
+
+ # Verify counts
+ assert len(feature_struct['alpha158_cols']) == 3
+ assert len(feature_struct['alpha158_ntrl_cols']) == 3
+ # market_ext: 2 base + 2 diff - 1 removed (log_size_diff) = 3
+ assert len(feature_struct['market_ext_cols']) == 3
+ # market_flag: IsZt, IsDt, IsN removed (3) + market_0, market_1, IsST added (3) = 3
+ assert len(feature_struct['market_flag_cols']) == 3
+
+ # Verify norm_feature_cols order (ntrl + raw for each group)
+ norm_cols = feature_struct['norm_feature_cols']
+ assert norm_cols[0] == 'KMID_ntrl', "First norm col should be KMID_ntrl"
+ assert norm_cols[3] == 'KMID', "Fourth norm col should be KMID (after 3 ntrl)"
+
+ print(" PASSED")
+
+
+def main():
+ """Run all tests."""
+ print("=" * 60)
+ print("Polars Processors Module Tests")
+ print("=" * 60)
+ print()
+
+ tests = [
+ test_diff_processor,
+ test_flag_market_injector,
+ test_flag_st_injector,
+ test_column_remover,
+ test_flag_to_onehot,
+ test_indus_ntrl_injector,
+ test_robust_zscore_norm,
+ test_fillna,
+ test_full_pipeline,
+ test_create_processor_pipeline,
+ test_get_final_feature_columns,
+ ]
+
+ passed = 0
+ failed = 0
+
+ for test_func in tests:
+ try:
+ test_func()
+ passed += 1
+ except AssertionError as e:
+ print(f" FAILED: {e}")
+ failed += 1
+ except Exception as e:
+ print(f" ERROR: {e}")
+ failed += 1
+
+ print()
+ print("=" * 60)
+ print(f"Results: {passed} passed, {failed} failed")
+ print("=" * 60)
+
+ return failed == 0
+
+
+if __name__ == "__main__":
+ success = main()
+ sys.exit(0 if success else 1)
diff --git a/stock_1d/d033/alpha158_beta/data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/metadata.json b/stock_1d/d033/alpha158_beta/data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/metadata.json
index 390a6ce..298f22d 100644
--- a/stock_1d/d033/alpha158_beta/data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/metadata.json
+++ b/stock_1d/d033/alpha158_beta/data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/metadata.json
@@ -1,6 +1,6 @@
{
"version": "csiallx_feature2_ntrla_flag_pnlnorm",
- "created_at": "2026-03-01T12:18:01.969109",
+ "created_at": "2026-03-01T13:11:57.144613",
"source_file": "2013-01-01",
"fit_start_time": "2013-01-01",
"fit_end_time": "2018-12-31",