diff --git a/cta_1d/README.md b/cta_1d/README.md index 664105a..ab93c01 100644 --- a/cta_1d/README.md +++ b/cta_1d/README.md @@ -34,3 +34,79 @@ Predefined weights (from qshare.config.research.cta.labels): - `long_term`: [0.4, 0.2, 0.2, 0.2] Default: [0.2, 0.1, 0.3, 0.4] + +## Processors Module + +The `cta_1d.src.processors` module provides Polars-based data processors that replicate Qlib's preprocessing pipeline: + +### Available Processors + +| Processor | Description | +|-----------|-------------| +| `DiffProcessor` | Adds diff features with configurable period | +| `FlagMarketInjector` | Adds market_0, market_1 columns based on instrument codes | +| `FlagSTInjector` | Creates IsST column from ST flags | +| `ColumnRemover` | Removes specified columns | +| `FlagToOnehot` | Converts one-hot industry flags to single index column | +| `IndusNtrlInjector` | Industry neutralization per datetime | +| `RobustZScoreNorm` | Robust z-score normalization using median/MAD | +| `Fillna` | Fills NaN values with specified value | + +### RobustZScoreNorm with Pre-fitted Parameters + +The `RobustZScoreNorm` processor supports loading pre-fitted parameters from Qlib's `proc_list.proc`: + +```python +from cta_1d.src.processors import RobustZScoreNorm + +# Method 1: Load from saved version (recommended) +processor = RobustZScoreNorm.from_version("csiallx_feature2_ntrla_flag_pnlnorm") + +# Method 2: Load with direct parameters +processor = RobustZScoreNorm( + feature_cols=['KMID', 'KLEN', ...], + use_qlib_params=True, + qlib_mean=mean_array, + qlib_std=std_array +) + +# Apply normalization +df = processor.process(df) +``` + +### Parameter Extraction + +Extract parameters from Qlib's proc_list.proc: + +```bash +python stock_1d/d033/alpha158_beta/scripts/extract_qlib_params.py \ + --proc-list /path/to/proc_list.proc \ + --version my_version +``` + +Output structure: +``` +data/robust_zscore_params/{version}/ +├── mean_train.npy # Pre-fitted mean (330,) +├── std_train.npy # Pre-fitted 
std (330,) +└── metadata.json # Feature columns and metadata +``` + +### Pipeline Helper Functions + +```python +from cta_1d.src.processors import create_processor_pipeline, get_final_feature_columns + +# Create the standard processor pipeline from column groups +pipeline = create_processor_pipeline( + alpha158_cols=ALPHA158_COLS, + market_ext_base=MARKET_EXT_BASE, + market_flag_cols=MARKET_FLAG_COLS, + industry_flag_cols=INDUSTRY_FLAG_COLS, +) + +# Get final feature columns after industry neutralization +final_cols = get_final_feature_columns( + alpha158_cols=ALPHA158_COLS, + market_ext_base=MARKET_EXT_BASE, + market_flag_cols=MARKET_FLAG_COLS, +) +``` diff --git a/cta_1d/src/processors/test_processors.py b/cta_1d/src/processors/test_processors.py new file mode 100644 index 0000000..ada3c77 --- /dev/null +++ b/cta_1d/src/processors/test_processors.py @@ -0,0 +1,508 @@ +#!/usr/bin/env python +""" +Test script for the Polars processors module. + +This script verifies that all processors in cta_1d.src.processors work correctly +and produce expected outputs.
+""" + +import sys +import numpy as np +import polars as pl +from pathlib import Path + +# Add the repo root (three levels above this package) to sys.path so `cta_1d` is importable +sys.path.insert(0, str(Path(__file__).parents[3])) + +from cta_1d.src.processors import ( + DiffProcessor, + FlagMarketInjector, + FlagSTInjector, + ColumnRemover, + FlagToOnehot, + IndusNtrlInjector, + RobustZScoreNorm, + Fillna, + create_processor_pipeline, + get_final_feature_columns, +) + + +def create_test_data( + n_dates: int = 10, + n_instruments: int = 10, + include_st_flags: bool = True, + include_industry: bool = True, +) -> pl.DataFrame: + """Create test DataFrame with realistic structure.""" + np.random.seed(42) + + # Use string dates to avoid polars issues + dates = [f'2024-01-{i+1:02d}' for i in range(n_dates)] + instruments = [f'SH60000{i}' for i in range(n_instruments)] + + data = { + 'datetime': [], + 'instrument': [], + # Alpha158 features + 'KMID': [], + 'KLEN': [], + 'ROC5': [], + # Market ext + 'turnover': [], + 'free_turnover': [], + 'log_size': [], + 'con_rating_strength': [], + # Market flags + 'IsZt': [], + 'IsDt': [], + 'IsN': [], + 'IsXD': [], + 'IsXR': [], + 'IsDR': [], + 'open_limit': [], + 'close_limit': [], + 'low_limit': [], + 'open_stop': [], + 'close_stop': [], + 'high_stop': [], + } + + if include_st_flags: + data['ST_S'] = [] + data['ST_Y'] = [] + + if include_industry: + # Add 3 industry flags for testing + data['gds_CC10'] = [] + data['gds_CC11'] = [] + data['gds_CC20'] = [] + + for d in dates: + for inst in instruments: + data['datetime'].append(d) + data['instrument'].append(inst) + data['KMID'].append(np.random.randn() * 0.1) + data['KLEN'].append(np.random.randn()) + data['ROC5'].append(np.random.randn() * 0.05) + data['turnover'].append(np.random.randn() * 1e6 + 1e7) + data['free_turnover'].append(np.random.randn() * 1e6 + 1e7) + data['log_size'].append(np.log(1e10 + np.random.randn() * 1e9)) + data['con_rating_strength'].append(np.random.randn()) + data['IsZt'].append(np.random.randint(0, 2)) +
data['IsDt'].append(np.random.randint(0, 2)) + data['IsN'].append(np.random.randint(0, 2)) + data['IsXD'].append(np.random.randint(0, 2)) + data['IsXR'].append(np.random.randint(0, 2)) + data['IsDR'].append(np.random.randint(0, 2)) + data['open_limit'].append(np.random.randint(0, 2)) + data['close_limit'].append(np.random.randint(0, 2)) + data['low_limit'].append(np.random.randint(0, 2)) + data['open_stop'].append(np.random.randint(0, 2)) + data['close_stop'].append(np.random.randint(0, 2)) + data['high_stop'].append(np.random.randint(0, 2)) + + if include_st_flags: + data['ST_S'].append(np.random.randint(0, 2)) + data['ST_Y'].append(np.random.randint(0, 2)) + + if include_industry: + # Set exactly one industry to 1 + industry_choice = np.random.randint(0, 3) + data['gds_CC10'].append(1 if industry_choice == 0 else 0) + data['gds_CC11'].append(1 if industry_choice == 1 else 0) + data['gds_CC20'].append(1 if industry_choice == 2 else 0) + + return pl.DataFrame(data) + + +def test_diff_processor(): + """Test DiffProcessor.""" + print("Testing DiffProcessor...") + df = create_test_data(n_dates=5, n_instruments=3) + + processor = DiffProcessor(['turnover', 'log_size']) + df_out = processor.process(df) + + assert 'turnover_diff' in df_out.columns, "turnover_diff not created" + assert 'log_size_diff' in df_out.columns, "log_size_diff not created" + assert df_out.shape[0] == df.shape[0], "Row count changed" + + # Verify diff is calculated per instrument + inst_data = df_out.filter(pl.col('instrument') == 'SH600000') + turnover_diff = inst_data['turnover_diff'].to_list() + # First value should be null (no previous row to diff) + assert turnover_diff[0] is None or np.isnan(turnover_diff[0]), "First diff should be null" + + print(" PASSED") + + +def test_flag_market_injector(): + """Test FlagMarketInjector.""" + print("Testing FlagMarketInjector...") + + # Test with different market types + df = pl.DataFrame({ + 'datetime': ['2024-01-01'] * 6, + 'instrument': [ + 'SH600000', 
# SH main -> market_0=1 + 'SH688000', # SH STAR -> market_1=1 + 'SZ000001', # SZ main -> market_0=1 + 'SZ300001', # SZ ChiNext -> market_1=1 + 'NE400001', # NE (New Third Board) -> both 0 + 'SH601000', # SH main -> market_0=1 + ], + 'value': [1, 2, 3, 4, 5, 6], + }) + + processor = FlagMarketInjector() + df_out = processor.process(df) + + assert 'market_0' in df_out.columns, "market_0 not created" + assert 'market_1' in df_out.columns, "market_1 not created" + + results = df_out.select(['instrument', 'market_0', 'market_1']).rows() + + # Verify market classifications + assert results[0][1] == 1 and results[0][2] == 0, "SH600000 should be market_0" + assert results[1][1] == 0 and results[1][2] == 1, "SH688000 should be market_1" + assert results[2][1] == 1 and results[2][2] == 0, "SZ000001 should be market_0" + assert results[3][1] == 0 and results[3][2] == 1, "SZ300001 should be market_1" + assert results[4][1] == 0 and results[4][2] == 0, "NE400001 should be neither" + + print(" PASSED") + + +def test_flag_st_injector(): + """Test FlagSTInjector.""" + print("Testing FlagSTInjector...") + + # Test with ST flags + df = pl.DataFrame({ + 'datetime': ['2024-01-01'] * 4, + 'instrument': ['SH600000', 'SH600001', 'SH600002', 'SH600003'], + 'ST_S': [0, 1, 0, 1], + 'ST_Y': [0, 0, 1, 1], + }) + + processor = FlagSTInjector() + df_out = processor.process(df) + + assert 'IsST' in df_out.columns, "IsST not created" + + is_st = df_out['IsST'].to_list() + assert is_st[0] == 0, "ST_S=0, ST_Y=0 -> IsST=0" + assert is_st[1] == 1, "ST_S=1 -> IsST=1" + assert is_st[2] == 1, "ST_Y=1 -> IsST=1" + assert is_st[3] == 1, "ST_S=1, ST_Y=1 -> IsST=1" + + # Test without ST flags (placeholder mode) + df_no_st = pl.DataFrame({ + 'datetime': ['2024-01-01'] * 2, + 'instrument': ['SH600000', 'SH600001'], + 'value': [1, 2], + }) + + df_out_no_st = processor.process(df_no_st) + assert df_out_no_st['IsST'].sum() == 0, "Placeholder IsST should be all zeros" + + print(" PASSED") + + +def 
test_column_remover(): + """Test ColumnRemover.""" + print("Testing ColumnRemover...") + df = create_test_data(n_dates=3, n_instruments=2) + + processor = ColumnRemover(['IsZt', 'IsDt', 'nonexistent_column']) + df_out = processor.process(df) + + assert 'IsZt' not in df_out.columns, "IsZt not removed" + assert 'IsDt' not in df_out.columns, "IsDt not removed" + assert 'IsN' in df_out.columns, "IsN should remain" + + print(" PASSED") + + +def test_flag_to_onehot(): + """Test FlagToOnehot.""" + print("Testing FlagToOnehot...") + + df = pl.DataFrame({ + 'datetime': ['2024-01-01'] * 6, + 'instrument': [f'SH60000{i}' for i in range(6)], + 'gds_CC10': [1, 0, 0, 0, 0, 0], + 'gds_CC11': [0, 1, 0, 0, 0, 0], + 'gds_CC20': [0, 0, 1, 0, 0, 0], + 'value': [1, 2, 3, 4, 5, 6], + }) + + processor = FlagToOnehot(['gds_CC10', 'gds_CC11', 'gds_CC20']) + df_out = processor.process(df) + + assert 'indus_idx' in df_out.columns, "indus_idx not created" + assert 'gds_CC10' not in df_out.columns, "gds_CC10 not removed" + assert 'gds_CC11' not in df_out.columns, "gds_CC11 not removed" + assert 'gds_CC20' not in df_out.columns, "gds_CC20 not removed" + + indus_idx = df_out['indus_idx'].to_list() + assert indus_idx[0] == 0, "gds_CC10=1 -> indus_idx=0" + assert indus_idx[1] == 1, "gds_CC11=1 -> indus_idx=1" + assert indus_idx[2] == 2, "gds_CC20=1 -> indus_idx=2" + + print(" PASSED") + + +def test_indus_ntrl_injector(): + """Test IndusNtrlInjector.""" + print("Testing IndusNtrlInjector...") + + # Create data with clear industry groups (use string dates to avoid polars issues) + df = pl.DataFrame({ + 'datetime': ['2024-01-01'] * 6 + ['2024-01-02'] * 6, + 'instrument': [f'SH60000{i}' for i in range(6)] * 2, + 'indus_idx': [0, 0, 0, 1, 1, 1] * 2, # Two industries + 'KMID': [1.0, 2.0, 3.0, 10.0, 20.0, 30.0] * 2, # Clear industry means + }) + + processor = IndusNtrlInjector(['KMID'], suffix='_ntrl') + df_out = processor.process(df) + + assert 'KMID_ntrl' in df_out.columns, "KMID_ntrl not created" + + 
# Industry 0 mean = 2.0, Industry 1 mean = 20.0 + # Neutralized values should be centered around 0 within each industry + ntrl_values = df_out['KMID_ntrl'].to_list() + + # Check that industry means are approximately 0 after neutralization + df_indus_0 = df_out.filter(pl.col('indus_idx') == 0) + df_indus_1 = df_out.filter(pl.col('indus_idx') == 1) + + mean_0 = df_indus_0['KMID_ntrl'].mean() + mean_1 = df_indus_1['KMID_ntrl'].mean() + + assert abs(mean_0) < 1e-6, f"Industry 0 mean should be ~0, got {mean_0}" + assert abs(mean_1) < 1e-6, f"Industry 1 mean should be ~0, got {mean_1}" + + print(" PASSED") + + +def test_robust_zscore_norm(): + """Test RobustZScoreNorm.""" + print("Testing RobustZScoreNorm...") + + # Test with pre-fitted parameters + df = pl.DataFrame({ + 'datetime': ['2024-01-01'] * 10, + 'instrument': [f'SH60000{i}' for i in range(10)], + 'feat1': list(range(10)), # 0-9 + }) + + # Pre-fitted params: mean=4.5, std=~2.87 for 0-9 + mean_train = np.array([4.5]) + std_train = np.array([2.87]) + + processor = RobustZScoreNorm( + ['feat1'], + clip_range=(-3, 3), + use_qlib_params=True, + qlib_mean=mean_train, + qlib_std=std_train + ) + df_out = processor.process(df) + + # Verify normalization + # feat1=0 -> z = (0-4.5)/2.87 = -1.57 + # feat1=9 -> z = (9-4.5)/2.87 = 1.57 + assert df_out['feat1'].min() < -1.5, "Min z-score should be around -1.57" + assert df_out['feat1'].max() > 1.5, "Max z-score should be around 1.57" + + # Test per-datetime mode + processor2 = RobustZScoreNorm(['feat1'], use_qlib_params=False) + df_out2 = processor2.process(df.clone()) + + # After per-datetime normalization, median should be ~0 + median = df_out2['feat1'].median() + assert abs(median) < 0.1, f"Median should be ~0, got {median}" + + print(" PASSED") + + +def test_fillna(): + """Test Fillna.""" + print("Testing Fillna...") + + df = pl.DataFrame({ + 'datetime': ['2024-01-01'] * 5, + 'instrument': [f'SH60000{i}' for i in range(5)], + 'feat1': [1.0, None, 3.0, float('nan'), 5.0], + 
}) + + processor = Fillna(fill_value=0.0) + df_out = processor.process(df, ['feat1']) + + # Check no null/nan values remain + assert df_out['feat1'].null_count() == 0, "Null values remain" + assert not any(np.isnan(df_out['feat1'].to_numpy())), "NaN values remain" + + print(" PASSED") + + +def test_full_pipeline(): + """Test complete processor pipeline.""" + print("Testing full processor pipeline...") + + df = create_test_data(n_dates=5, n_instruments=5, include_st_flags=True, include_industry=True) + original_shape = df.shape + + # market_ext columns + market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] + + # Apply pipeline + # 1. Diff + diff_proc = DiffProcessor(market_ext_base) + df = diff_proc.process(df) + + # 2. FlagMarketInjector + market_proc = FlagMarketInjector() + df = market_proc.process(df) + + # 3. FlagSTInjector + st_proc = FlagSTInjector() + df = st_proc.process(df) + + # 4. ColumnRemover + remove_proc = ColumnRemover(['log_size_diff', 'IsN', 'IsZt', 'IsDt']) + df = remove_proc.process(df) + + # 5. FlagToOnehot + onehot_proc = FlagToOnehot(['gds_CC10', 'gds_CC11', 'gds_CC20']) + df = onehot_proc.process(df) + + # 6. IndusNtrlInjector + ntrl_proc = IndusNtrlInjector(['KMID', 'KLEN', 'ROC5'], suffix='_ntrl') + df = ntrl_proc.process(df) + + # 7. RobustZScoreNorm (per-datetime mode for testing) + norm_proc = RobustZScoreNorm(['KMID', 'KLEN', 'ROC5', 'KMID_ntrl', 'KLEN_ntrl', 'ROC5_ntrl'], + use_qlib_params=False) + df = norm_proc.process(df) + + # 8. 
Fillna + fillna_proc = Fillna(fill_value=0.0) + df = fillna_proc.process(df, df.columns) + + # Verify final structure + assert df.shape[0] == original_shape[0], "Row count changed" + assert 'indus_idx' in df.columns, "indus_idx missing" + assert 'market_0' in df.columns, "market_0 missing" + assert 'IsST' in df.columns, "IsST missing" + assert 'KMID_ntrl' in df.columns, "KMID_ntrl missing" + + print(" PASSED") + + +def test_create_processor_pipeline(): + """Test create_processor_pipeline utility.""" + print("Testing create_processor_pipeline...") + + pipeline = create_processor_pipeline( + alpha158_cols=['KMID', 'KLEN'], + market_ext_base=['turnover', 'log_size'], + market_flag_cols=['IsZt', 'IsDt'], + industry_flag_cols=['gds_CC10', 'gds_CC11'], + ) + + assert len(pipeline) == 7, f"Expected 7 processors, got {len(pipeline)}" + assert isinstance(pipeline[0], DiffProcessor) + assert isinstance(pipeline[1], FlagMarketInjector) + assert isinstance(pipeline[2], FlagSTInjector) + assert isinstance(pipeline[3], ColumnRemover) + assert isinstance(pipeline[4], FlagToOnehot) + assert isinstance(pipeline[5], IndusNtrlInjector) + assert isinstance(pipeline[6], IndusNtrlInjector) + + print(" PASSED") + + +def test_get_final_feature_columns(): + """Test get_final_feature_columns utility.""" + print("Testing get_final_feature_columns...") + + feature_struct = get_final_feature_columns( + alpha158_cols=['KMID', 'KLEN', 'ROC5'], + market_ext_base=['turnover', 'log_size'], + market_flag_cols=['IsZt', 'IsDt', 'IsN'], + ) + + # Verify structure + assert 'alpha158_cols' in feature_struct + assert 'alpha158_ntrl_cols' in feature_struct + assert 'market_ext_cols' in feature_struct + assert 'market_ext_ntrl_cols' in feature_struct + assert 'market_flag_cols' in feature_struct + assert 'norm_feature_cols' in feature_struct + assert 'all_feature_cols' in feature_struct + + # Verify counts + assert len(feature_struct['alpha158_cols']) == 3 + assert len(feature_struct['alpha158_ntrl_cols']) 
== 3 + # market_ext: 2 base + 2 diff - 1 removed (log_size_diff) = 3 + assert len(feature_struct['market_ext_cols']) == 3 + # market_flag: IsZt, IsDt, IsN removed (3) + market_0, market_1, IsST added (3) = 3 + assert len(feature_struct['market_flag_cols']) == 3 + + # Verify norm_feature_cols order (ntrl + raw for each group) + norm_cols = feature_struct['norm_feature_cols'] + assert norm_cols[0] == 'KMID_ntrl', "First norm col should be KMID_ntrl" + assert norm_cols[3] == 'KMID', "Fourth norm col should be KMID (after 3 ntrl)" + + print(" PASSED") + + +def main(): + """Run all tests.""" + print("=" * 60) + print("Polars Processors Module Tests") + print("=" * 60) + print() + + tests = [ + test_diff_processor, + test_flag_market_injector, + test_flag_st_injector, + test_column_remover, + test_flag_to_onehot, + test_indus_ntrl_injector, + test_robust_zscore_norm, + test_fillna, + test_full_pipeline, + test_create_processor_pipeline, + test_get_final_feature_columns, + ] + + passed = 0 + failed = 0 + + for test_func in tests: + try: + test_func() + passed += 1 + except AssertionError as e: + print(f" FAILED: {e}") + failed += 1 + except Exception as e: + print(f" ERROR: {e}") + failed += 1 + + print() + print("=" * 60) + print(f"Results: {passed} passed, {failed} failed") + print("=" * 60) + + return failed == 0 + + +if __name__ == "__main__": + success = main() + sys.exit(0 if success else 1) diff --git a/stock_1d/d033/alpha158_beta/data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/metadata.json b/stock_1d/d033/alpha158_beta/data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/metadata.json index 390a6ce..298f22d 100644 --- a/stock_1d/d033/alpha158_beta/data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/metadata.json +++ b/stock_1d/d033/alpha158_beta/data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/metadata.json @@ -1,6 +1,6 @@ { "version": "csiallx_feature2_ntrla_flag_pnlnorm", - "created_at": 
"2026-03-01T12:18:01.969109", + "created_at": "2026-03-01T13:11:57.144613", "source_file": "2013-01-01", "fit_start_time": "2013-01-01", "fit_end_time": "2018-12-31",