Update documentation for processors module and commit test file

- Add comprehensive processors module documentation to cta_1d/README.md
  including usage examples for RobustZScoreNorm.from_version() and
  parameter extraction instructions

- Add test_processors.py test script for validating processor functionality

- Update metadata.json timestamp from parameter extraction

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
master
guofu 3 days ago
parent 89bd1a528e
commit ea011090f8

@ -34,3 +34,79 @@ Predefined weights (from qshare.config.research.cta.labels):
- `long_term`: [0.4, 0.2, 0.2, 0.2]
Default: [0.2, 0.1, 0.3, 0.4]
## Processors Module
The `cta_1d.src.processors` module provides Polars-based data processors that replicate Qlib's preprocessing pipeline:
### Available Processors
| Processor | Description |
|-----------|-------------|
| `DiffProcessor` | Adds diff features with configurable period |
| `FlagMarketInjector` | Adds market_0, market_1 columns based on instrument codes |
| `FlagSTInjector` | Creates IsST column from ST flags |
| `ColumnRemover` | Removes specified columns |
| `FlagToOnehot` | Converts one-hot industry flags to single index column |
| `IndusNtrlInjector` | Industry neutralization per datetime |
| `RobustZScoreNorm` | Robust z-score normalization using median/MAD |
| `Fillna` | Fills NaN values with specified value |
### RobustZScoreNorm with Pre-fitted Parameters
The `RobustZScoreNorm` processor supports loading pre-fitted parameters from Qlib's `proc_list.proc`:
```python
from cta_1d.src.processors import RobustZScoreNorm
# Method 1: Load from saved version (recommended)
processor = RobustZScoreNorm.from_version("csiallx_feature2_ntrla_flag_pnlnorm")
# Method 2: Load with direct parameters
processor = RobustZScoreNorm(
feature_cols=['KMID', 'KLEN', ...],
use_qlib_params=True,
qlib_mean=mean_array,
qlib_std=std_array
)
# Apply normalization
df = processor.process(df)
```
### Parameter Extraction
Extract parameters from Qlib's proc_list.proc:
```bash
python stock_1d/d033/alpha158_beta/scripts/extract_qlib_params.py \
--proc-list /path/to/proc_list.proc \
--version my_version
```
Output structure:
```
data/robust_zscore_params/{version}/
├── mean_train.npy # Pre-fitted mean (330,)
├── std_train.npy # Pre-fitted std (330,)
└── metadata.json # Feature columns and metadata
```
### Pipeline Helper Functions
```python
from cta_1d.src.processors import create_processor_pipeline, get_final_feature_columns
# Build the standard 7-stage pipeline from column groups
pipeline = create_processor_pipeline(
    alpha158_cols=ALPHA158_COLS,
    market_ext_base=['turnover', 'free_turnover', 'log_size', 'con_rating_strength'],
    market_flag_cols=MARKET_FLAG_COLS,
    industry_flag_cols=INDUSTRY_FLAG_COLS,
)
# Get final feature columns after industry neutralization
final_cols = get_final_feature_columns(
    alpha158_cols=ALPHA158_COLS,
    market_ext_base=MARKET_EXT_BASE,
    market_flag_cols=MARKET_FLAG_COLS,
)
```

@ -0,0 +1,508 @@
#!/usr/bin/env python
"""
Test script for the Polars processors module.
This script verifies that all processors in cta_1d.src.processors work correctly
and produce expected outputs.
"""
import sys
import numpy as np
import polars as pl
from pathlib import Path
# Add cta_1d to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from cta_1d.src.processors import (
DiffProcessor,
FlagMarketInjector,
FlagSTInjector,
ColumnRemover,
FlagToOnehot,
IndusNtrlInjector,
RobustZScoreNorm,
Fillna,
create_processor_pipeline,
get_final_feature_columns,
)
def create_test_data(
    n_dates: int = 10,
    n_instruments: int = 10,
    include_st_flags: bool = True,
    include_industry: bool = True,
) -> pl.DataFrame:
    """Build a synthetic frame shaped like real processor input.

    One row per (datetime, instrument) pair, with Alpha158-style features,
    market-extension columns, binary market flags, and optionally ST flags
    and one-hot industry flags.
    """
    np.random.seed(42)
    # Use string dates to avoid polars issues
    dates = [f'2024-01-{day + 1:02d}' for day in range(n_dates)]
    instruments = [f'SH60000{k}' for k in range(n_instruments)]

    base_cols = [
        'datetime', 'instrument',
        # Alpha158 features
        'KMID', 'KLEN', 'ROC5',
        # Market ext
        'turnover', 'free_turnover', 'log_size', 'con_rating_strength',
        # Market flags
        'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
        'open_limit', 'close_limit', 'low_limit',
        'open_stop', 'close_stop', 'high_stop',
    ]
    columns = {name: [] for name in base_cols}
    if include_st_flags:
        columns['ST_S'] = []
        columns['ST_Y'] = []
    if include_industry:
        # Three industry flags are enough to exercise one-hot handling.
        columns['gds_CC10'] = []
        columns['gds_CC11'] = []
        columns['gds_CC20'] = []

    # Binary flags drawn in this fixed order (matters for seeded RNG output).
    flag_names = [
        'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
        'open_limit', 'close_limit', 'low_limit',
        'open_stop', 'close_stop', 'high_stop',
    ]

    for date_str in dates:
        for code in instruments:
            columns['datetime'].append(date_str)
            columns['instrument'].append(code)
            columns['KMID'].append(np.random.randn() * 0.1)
            columns['KLEN'].append(np.random.randn())
            columns['ROC5'].append(np.random.randn() * 0.05)
            columns['turnover'].append(np.random.randn() * 1e6 + 1e7)
            columns['free_turnover'].append(np.random.randn() * 1e6 + 1e7)
            columns['log_size'].append(np.log(1e10 + np.random.randn() * 1e9))
            columns['con_rating_strength'].append(np.random.randn())
            for flag in flag_names:
                columns[flag].append(np.random.randint(0, 2))
            if include_st_flags:
                columns['ST_S'].append(np.random.randint(0, 2))
                columns['ST_Y'].append(np.random.randint(0, 2))
            if include_industry:
                # Set exactly one industry to 1
                chosen = np.random.randint(0, 3)
                columns['gds_CC10'].append(1 if chosen == 0 else 0)
                columns['gds_CC11'].append(1 if chosen == 1 else 0)
                columns['gds_CC20'].append(1 if chosen == 2 else 0)
    return pl.DataFrame(columns)
def test_diff_processor():
    """DiffProcessor should add *_diff columns computed per instrument."""
    print("Testing DiffProcessor...")
    frame = create_test_data(n_dates=5, n_instruments=3)
    result = DiffProcessor(['turnover', 'log_size']).process(frame)

    assert 'turnover_diff' in result.columns, "turnover_diff not created"
    assert 'log_size_diff' in result.columns, "log_size_diff not created"
    assert result.shape[0] == frame.shape[0], "Row count changed"

    # Diffs must be computed within each instrument, so the first
    # observation of an instrument has nothing to diff against.
    inst_rows = result.filter(pl.col('instrument') == 'SH600000')
    diffs = inst_rows['turnover_diff'].to_list()
    assert diffs[0] is None or np.isnan(diffs[0]), "First diff should be null"
    print(" PASSED")
def test_flag_market_injector():
    """FlagMarketInjector should classify instruments into market_0/market_1."""
    print("Testing FlagMarketInjector...")
    frame = pl.DataFrame({
        'datetime': ['2024-01-01'] * 6,
        'instrument': [
            'SH600000',  # SH main -> market_0=1
            'SH688000',  # SH STAR -> market_1=1
            'SZ000001',  # SZ main -> market_0=1
            'SZ300001',  # SZ ChiNext -> market_1=1
            'NE400001',  # NE (New Third Board) -> both 0
            'SH601000',  # SH main -> market_0=1
        ],
        'value': [1, 2, 3, 4, 5, 6],
    })
    out = FlagMarketInjector().process(frame)
    assert 'market_0' in out.columns, "market_0 not created"
    assert 'market_1' in out.columns, "market_1 not created"

    rows = out.select(['instrument', 'market_0', 'market_1']).rows()
    # (row index, expected market_0, expected market_1, failure message)
    expectations = [
        (0, 1, 0, "SH600000 should be market_0"),
        (1, 0, 1, "SH688000 should be market_1"),
        (2, 1, 0, "SZ000001 should be market_0"),
        (3, 0, 1, "SZ300001 should be market_1"),
        (4, 0, 0, "NE400001 should be neither"),
    ]
    for idx, m0, m1, msg in expectations:
        assert rows[idx][1] == m0 and rows[idx][2] == m1, msg
    print(" PASSED")
def test_flag_st_injector():
    """FlagSTInjector should OR ST_S/ST_Y into IsST, or zero-fill without them."""
    print("Testing FlagSTInjector...")
    processor = FlagSTInjector()

    # Case 1: ST flag columns present -> IsST is their logical OR.
    with_flags = pl.DataFrame({
        'datetime': ['2024-01-01'] * 4,
        'instrument': ['SH600000', 'SH600001', 'SH600002', 'SH600003'],
        'ST_S': [0, 1, 0, 1],
        'ST_Y': [0, 0, 1, 1],
    })
    out = processor.process(with_flags)
    assert 'IsST' in out.columns, "IsST not created"
    observed = out['IsST'].to_list()
    expected = [
        (0, "ST_S=0, ST_Y=0 -> IsST=0"),
        (1, "ST_S=1 -> IsST=1"),
        (1, "ST_Y=1 -> IsST=1"),
        (1, "ST_S=1, ST_Y=1 -> IsST=1"),
    ]
    for row_idx, (want, msg) in enumerate(expected):
        assert observed[row_idx] == want, msg

    # Case 2: no ST columns -> all-zero placeholder column.
    without_flags = pl.DataFrame({
        'datetime': ['2024-01-01'] * 2,
        'instrument': ['SH600000', 'SH600001'],
        'value': [1, 2],
    })
    placeholder = processor.process(without_flags)
    assert placeholder['IsST'].sum() == 0, "Placeholder IsST should be all zeros"
    print(" PASSED")
def test_column_remover():
    """ColumnRemover should drop listed columns and tolerate missing names."""
    print("Testing ColumnRemover...")
    frame = create_test_data(n_dates=3, n_instruments=2)
    # 'nonexistent_column' verifies missing names do not raise.
    out = ColumnRemover(['IsZt', 'IsDt', 'nonexistent_column']).process(frame)
    remaining = out.columns
    assert 'IsZt' not in remaining, "IsZt not removed"
    assert 'IsDt' not in remaining, "IsDt not removed"
    assert 'IsN' in remaining, "IsN should remain"
    print(" PASSED")
def test_flag_to_onehot():
    """FlagToOnehot should collapse one-hot flags into a single index column."""
    print("Testing FlagToOnehot...")
    flag_cols = ['gds_CC10', 'gds_CC11', 'gds_CC20']
    frame = pl.DataFrame({
        'datetime': ['2024-01-01'] * 6,
        'instrument': [f'SH60000{i}' for i in range(6)],
        'gds_CC10': [1, 0, 0, 0, 0, 0],
        'gds_CC11': [0, 1, 0, 0, 0, 0],
        'gds_CC20': [0, 0, 1, 0, 0, 0],
        'value': [1, 2, 3, 4, 5, 6],
    })
    out = FlagToOnehot(flag_cols).process(frame)

    assert 'indus_idx' in out.columns, "indus_idx not created"
    # The source one-hot columns must be consumed by the conversion.
    assert 'gds_CC10' not in out.columns, "gds_CC10 not removed"
    assert 'gds_CC11' not in out.columns, "gds_CC11 not removed"
    assert 'gds_CC20' not in out.columns, "gds_CC20 not removed"

    indices = out['indus_idx'].to_list()
    assert indices[0] == 0, "gds_CC10=1 -> indus_idx=0"
    assert indices[1] == 1, "gds_CC11=1 -> indus_idx=1"
    assert indices[2] == 2, "gds_CC20=1 -> indus_idx=2"
    print(" PASSED")
def test_indus_ntrl_injector():
    """Test IndusNtrlInjector.

    Builds two clearly separated industry groups and checks that the
    neutralized column has approximately zero mean within each group.
    """
    print("Testing IndusNtrlInjector...")
    # Create data with clear industry groups (use string dates to avoid polars issues)
    df = pl.DataFrame({
        'datetime': ['2024-01-01'] * 6 + ['2024-01-02'] * 6,
        'instrument': [f'SH60000{i}' for i in range(6)] * 2,
        'indus_idx': [0, 0, 0, 1, 1, 1] * 2,  # Two industries
        'KMID': [1.0, 2.0, 3.0, 10.0, 20.0, 30.0] * 2,  # Clear industry means
    })
    processor = IndusNtrlInjector(['KMID'], suffix='_ntrl')
    df_out = processor.process(df)
    assert 'KMID_ntrl' in df_out.columns, "KMID_ntrl not created"
    # Industry 0 mean = 2.0, Industry 1 mean = 20.0.
    # After neutralization each industry's values should be centered on 0,
    # so the per-industry means must be ~0.
    # (Removed an unused `ntrl_values` local that the original assigned
    # but never read.)
    df_indus_0 = df_out.filter(pl.col('indus_idx') == 0)
    df_indus_1 = df_out.filter(pl.col('indus_idx') == 1)
    mean_0 = df_indus_0['KMID_ntrl'].mean()
    mean_1 = df_indus_1['KMID_ntrl'].mean()
    assert abs(mean_0) < 1e-6, f"Industry 0 mean should be ~0, got {mean_0}"
    assert abs(mean_1) < 1e-6, f"Industry 1 mean should be ~0, got {mean_1}"
    print(" PASSED")
def test_robust_zscore_norm():
    """RobustZScoreNorm: pre-fitted-parameter mode and per-datetime mode."""
    print("Testing RobustZScoreNorm...")
    frame = pl.DataFrame({
        'datetime': ['2024-01-01'] * 10,
        'instrument': [f'SH60000{i}' for i in range(10)],
        'feat1': list(range(10)),  # 0-9
    })

    # Pre-fitted mode: values 0..9 have mean 4.5 and std ~2.87, so the
    # extremes should normalize to roughly z = +/-1.57.
    fitted = RobustZScoreNorm(
        ['feat1'],
        clip_range=(-3, 3),
        use_qlib_params=True,
        qlib_mean=np.array([4.5]),
        qlib_std=np.array([2.87])
    )
    normed = fitted.process(frame)
    assert normed['feat1'].min() < -1.5, "Min z-score should be around -1.57"
    assert normed['feat1'].max() > 1.5, "Max z-score should be around 1.57"

    # Per-datetime mode: parameters are fit on the data itself, so the
    # median of the normalized column should be close to zero.
    per_dt = RobustZScoreNorm(['feat1'], use_qlib_params=False)
    normed_dt = per_dt.process(frame.clone())
    median = normed_dt['feat1'].median()
    assert abs(median) < 0.1, f"Median should be ~0, got {median}"
    print(" PASSED")
def test_fillna():
    """Fillna should replace both nulls and NaNs with the fill value."""
    print("Testing Fillna...")
    frame = pl.DataFrame({
        'datetime': ['2024-01-01'] * 5,
        'instrument': [f'SH60000{i}' for i in range(5)],
        'feat1': [1.0, None, 3.0, float('nan'), 5.0],
    })
    filled = Fillna(fill_value=0.0).process(frame, ['feat1'])
    # Both polars nulls and float NaNs must be gone afterwards.
    assert filled['feat1'].null_count() == 0, "Null values remain"
    assert not any(np.isnan(filled['feat1'].to_numpy())), "NaN values remain"
    print(" PASSED")
def test_full_pipeline():
    """Run every processor in sequence and sanity-check the resulting frame."""
    print("Testing full processor pipeline...")
    df = create_test_data(n_dates=5, n_instruments=5,
                          include_st_flags=True, include_industry=True)
    n_rows_before = df.shape[0]
    # market_ext columns
    market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']

    # Steps 1-6: diff features, market flags, ST flag, column pruning,
    # industry one-hot collapse, industry neutralization.
    stages = [
        DiffProcessor(market_ext_base),
        FlagMarketInjector(),
        FlagSTInjector(),
        ColumnRemover(['log_size_diff', 'IsN', 'IsZt', 'IsDt']),
        FlagToOnehot(['gds_CC10', 'gds_CC11', 'gds_CC20']),
        IndusNtrlInjector(['KMID', 'KLEN', 'ROC5'], suffix='_ntrl'),
    ]
    for stage in stages:
        df = stage.process(df)

    # Step 7: RobustZScoreNorm (per-datetime mode for testing).
    norm_cols = ['KMID', 'KLEN', 'ROC5', 'KMID_ntrl', 'KLEN_ntrl', 'ROC5_ntrl']
    df = RobustZScoreNorm(norm_cols, use_qlib_params=False).process(df)
    # Step 8: fill remaining nulls/NaNs across every column.
    df = Fillna(fill_value=0.0).process(df, df.columns)

    # Verify final structure
    assert df.shape[0] == n_rows_before, "Row count changed"
    assert 'indus_idx' in df.columns, "indus_idx missing"
    assert 'market_0' in df.columns, "market_0 missing"
    assert 'IsST' in df.columns, "IsST missing"
    assert 'KMID_ntrl' in df.columns, "KMID_ntrl missing"
    print(" PASSED")
def test_create_processor_pipeline():
    """create_processor_pipeline should return 7 processors in a fixed order."""
    print("Testing create_processor_pipeline...")
    pipeline = create_processor_pipeline(
        alpha158_cols=['KMID', 'KLEN'],
        market_ext_base=['turnover', 'log_size'],
        market_flag_cols=['IsZt', 'IsDt'],
        industry_flag_cols=['gds_CC10', 'gds_CC11'],
    )
    assert len(pipeline) == 7, f"Expected 7 processors, got {len(pipeline)}"
    # The stage order is part of the pipeline contract.
    expected_types = (
        DiffProcessor,
        FlagMarketInjector,
        FlagSTInjector,
        ColumnRemover,
        FlagToOnehot,
        IndusNtrlInjector,
        IndusNtrlInjector,
    )
    for stage, expected in zip(pipeline, expected_types):
        assert isinstance(stage, expected)
    print(" PASSED")
def test_get_final_feature_columns():
    """get_final_feature_columns should expose the expected keys and counts."""
    print("Testing get_final_feature_columns...")
    feature_struct = get_final_feature_columns(
        alpha158_cols=['KMID', 'KLEN', 'ROC5'],
        market_ext_base=['turnover', 'log_size'],
        market_flag_cols=['IsZt', 'IsDt', 'IsN'],
    )
    # Every expected group key must be present.
    for key in (
        'alpha158_cols',
        'alpha158_ntrl_cols',
        'market_ext_cols',
        'market_ext_ntrl_cols',
        'market_flag_cols',
        'norm_feature_cols',
        'all_feature_cols',
    ):
        assert key in feature_struct
    # Verify counts
    assert len(feature_struct['alpha158_cols']) == 3
    assert len(feature_struct['alpha158_ntrl_cols']) == 3
    # market_ext: 2 base + 2 diff - 1 removed (log_size_diff) = 3
    assert len(feature_struct['market_ext_cols']) == 3
    # market_flag: IsZt, IsDt, IsN removed (3) + market_0, market_1, IsST added (3) = 3
    assert len(feature_struct['market_flag_cols']) == 3
    # Verify norm_feature_cols order (ntrl + raw for each group)
    norm_cols = feature_struct['norm_feature_cols']
    assert norm_cols[0] == 'KMID_ntrl', "First norm col should be KMID_ntrl"
    assert norm_cols[3] == 'KMID', "Fourth norm col should be KMID (after 3 ntrl)"
    print(" PASSED")
def main():
    """Run every test function and report a pass/fail summary.

    Returns True when all tests passed, so the caller can derive an
    exit code from the result.
    """
    banner = "=" * 60
    print(banner)
    print("Polars Processors Module Tests")
    print(banner)
    print()
    tests = [
        test_diff_processor,
        test_flag_market_injector,
        test_flag_st_injector,
        test_column_remover,
        test_flag_to_onehot,
        test_indus_ntrl_injector,
        test_robust_zscore_norm,
        test_fillna,
        test_full_pipeline,
        test_create_processor_pipeline,
        test_get_final_feature_columns,
    ]
    passed = failed = 0
    for test_func in tests:
        try:
            test_func()
        except AssertionError as e:
            # Assertion failures are expected test outcomes, not crashes.
            print(f" FAILED: {e}")
            failed += 1
        except Exception as e:
            # Anything else is an unexpected error in the test itself.
            print(f" ERROR: {e}")
            failed += 1
        else:
            passed += 1
    print()
    print(banner)
    print(f"Results: {passed} passed, {failed} failed")
    print(banner)
    return failed == 0
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)

@ -1,6 +1,6 @@
{ {
"version": "csiallx_feature2_ntrla_flag_pnlnorm", "version": "csiallx_feature2_ntrla_flag_pnlnorm",
"created_at": "2026-03-01T12:18:01.969109", "created_at": "2026-03-01T13:11:57.144613",
"source_file": "2013-01-01", "source_file": "2013-01-01",
"fit_start_time": "2013-01-01", "fit_start_time": "2013-01-01",
"fit_end_time": "2018-12-31", "fit_end_time": "2018-12-31",

Loading…
Cancel
Save