- Add comprehensive processors module documentation to cta_1d/README.md, including usage examples for RobustZScoreNorm.from_version() and parameter extraction instructions
- Add test_processors.py test script for validating processor functionality
- Update metadata.json timestamp from parameter extraction

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
parent
89bd1a528e
commit
ea011090f8
@ -0,0 +1,508 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Test script for the Polars processors module.
|
||||
|
||||
This script verifies that all processors in cta_1d.src.processors work correctly
|
||||
and produce expected outputs.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
from pathlib import Path
|
||||
|
||||
# Add cta_1d to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from cta_1d.src.processors import (
|
||||
DiffProcessor,
|
||||
FlagMarketInjector,
|
||||
FlagSTInjector,
|
||||
ColumnRemover,
|
||||
FlagToOnehot,
|
||||
IndusNtrlInjector,
|
||||
RobustZScoreNorm,
|
||||
Fillna,
|
||||
create_processor_pipeline,
|
||||
get_final_feature_columns,
|
||||
)
|
||||
|
||||
|
||||
def create_test_data(
    n_dates: int = 10,
    n_instruments: int = 10,
    include_st_flags: bool = True,
    include_industry: bool = True,
    seed: int = 42,
) -> pl.DataFrame:
    """Create a test DataFrame with a realistic feature structure.

    Builds a long-format frame with one row per (datetime, instrument) pair,
    containing Alpha158-style features, market-extension features, binary
    market flags, and optionally ST flags plus one-hot industry flags
    (exactly one industry set per row).

    Args:
        n_dates: Number of distinct trading dates (2024-01-01 onward).
        n_instruments: Number of distinct instruments (SH60000x codes).
        include_st_flags: If True, add random binary ST_S / ST_Y columns.
        include_industry: If True, add gds_CC10/gds_CC11/gds_CC20 one-hot
            columns with exactly one of them set to 1 per row.
        seed: Seed for numpy's global RNG so the output is reproducible.

    Returns:
        A polars DataFrame with ``n_dates * n_instruments`` rows.
    """
    np.random.seed(seed)

    # Use string dates to avoid polars dtype issues downstream.
    dates = [f'2024-01-{i+1:02d}' for i in range(n_dates)]
    instruments = [f'SH60000{i}' for i in range(n_instruments)]

    # Binary market flags: each is drawn independently as 0/1 per row.
    flag_cols = [
        'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
        'open_limit', 'close_limit', 'low_limit',
        'open_stop', 'close_stop', 'high_stop',
    ]
    industry_cols = ['gds_CC10', 'gds_CC11', 'gds_CC20']

    base_cols = [
        'datetime', 'instrument',
        # Alpha158 features
        'KMID', 'KLEN', 'ROC5',
        # Market ext
        'turnover', 'free_turnover', 'log_size', 'con_rating_strength',
    ]
    data: dict = {col: [] for col in base_cols + flag_cols}

    if include_st_flags:
        data['ST_S'] = []
        data['ST_Y'] = []

    if include_industry:
        for col in industry_cols:
            data[col] = []

    for d in dates:
        for inst in instruments:
            data['datetime'].append(d)
            data['instrument'].append(inst)
            data['KMID'].append(np.random.randn() * 0.1)
            data['KLEN'].append(np.random.randn())
            data['ROC5'].append(np.random.randn() * 0.05)
            data['turnover'].append(np.random.randn() * 1e6 + 1e7)
            data['free_turnover'].append(np.random.randn() * 1e6 + 1e7)
            data['log_size'].append(np.log(1e10 + np.random.randn() * 1e9))
            data['con_rating_strength'].append(np.random.randn())
            for col in flag_cols:
                data[col].append(np.random.randint(0, 2))

            if include_st_flags:
                data['ST_S'].append(np.random.randint(0, 2))
                data['ST_Y'].append(np.random.randint(0, 2))

            if include_industry:
                # Set exactly one industry flag to 1 for this row.
                industry_choice = np.random.randint(0, len(industry_cols))
                for idx, col in enumerate(industry_cols):
                    data[col].append(1 if industry_choice == idx else 0)

    return pl.DataFrame(data)
|
||||
|
||||
|
||||
def test_diff_processor():
    """DiffProcessor adds *_diff columns, computed within each instrument."""
    print("Testing DiffProcessor...")
    source = create_test_data(n_dates=5, n_instruments=3)

    diffed = DiffProcessor(['turnover', 'log_size']).process(source)

    assert 'turnover_diff' in diffed.columns, "turnover_diff not created"
    assert 'log_size_diff' in diffed.columns, "log_size_diff not created"
    assert diffed.shape[0] == source.shape[0], "Row count changed"

    # The diff must be computed within each instrument's own time series,
    # so an instrument's first observation has nothing to diff against.
    per_inst = diffed.filter(pl.col('instrument') == 'SH600000')
    first_diff = per_inst['turnover_diff'].to_list()[0]
    assert first_diff is None or np.isnan(first_diff), "First diff should be null"

    print(" PASSED")
|
||||
|
||||
|
||||
def test_flag_market_injector():
    """FlagMarketInjector classifies instruments into market_0/market_1 flags."""
    print("Testing FlagMarketInjector...")

    # One instrument per market category, plus an off-board NE code.
    codes = [
        'SH600000',  # SH main board    -> market_0=1
        'SH688000',  # SH STAR market   -> market_1=1
        'SZ000001',  # SZ main board    -> market_0=1
        'SZ300001',  # SZ ChiNext       -> market_1=1
        'NE400001',  # New Third Board  -> both flags 0
        'SH601000',  # SH main board    -> market_0=1
    ]
    frame = pl.DataFrame({
        'datetime': ['2024-01-01'] * len(codes),
        'instrument': codes,
        'value': list(range(1, len(codes) + 1)),
    })

    flagged = FlagMarketInjector().process(frame)

    assert 'market_0' in flagged.columns, "market_0 not created"
    assert 'market_1' in flagged.columns, "market_1 not created"

    rows = flagged.select(['instrument', 'market_0', 'market_1']).rows()

    # Verify the classification of each code as a (market_0, market_1) pair.
    assert (rows[0][1], rows[0][2]) == (1, 0), "SH600000 should be market_0"
    assert (rows[1][1], rows[1][2]) == (0, 1), "SH688000 should be market_1"
    assert (rows[2][1], rows[2][2]) == (1, 0), "SZ000001 should be market_0"
    assert (rows[3][1], rows[3][2]) == (0, 1), "SZ300001 should be market_1"
    assert (rows[4][1], rows[4][2]) == (0, 0), "NE400001 should be neither"

    print(" PASSED")
|
||||
|
||||
|
||||
def test_flag_st_injector():
    """FlagSTInjector ORs ST_S/ST_Y into IsST, or emits a zero placeholder."""
    print("Testing FlagSTInjector...")

    # Case 1: ST columns are present -> IsST is the logical OR of the two.
    with_flags = pl.DataFrame({
        'datetime': ['2024-01-01'] * 4,
        'instrument': ['SH600000', 'SH600001', 'SH600002', 'SH600003'],
        'ST_S': [0, 1, 0, 1],
        'ST_Y': [0, 0, 1, 1],
    })

    injector = FlagSTInjector()
    out = injector.process(with_flags)

    assert 'IsST' in out.columns, "IsST not created"

    observed = out['IsST'].to_list()
    assert observed[0] == 0, "ST_S=0, ST_Y=0 -> IsST=0"
    assert observed[1] == 1, "ST_S=1 -> IsST=1"
    assert observed[2] == 1, "ST_Y=1 -> IsST=1"
    assert observed[3] == 1, "ST_S=1, ST_Y=1 -> IsST=1"

    # Case 2: no ST columns at all -> IsST falls back to an all-zero column.
    without_flags = pl.DataFrame({
        'datetime': ['2024-01-01'] * 2,
        'instrument': ['SH600000', 'SH600001'],
        'value': [1, 2],
    })

    placeholder = injector.process(without_flags)
    assert placeholder['IsST'].sum() == 0, "Placeholder IsST should be all zeros"

    print(" PASSED")
|
||||
|
||||
|
||||
def test_column_remover():
    """ColumnRemover drops the listed columns and tolerates missing names."""
    print("Testing ColumnRemover...")
    frame = create_test_data(n_dates=3, n_instruments=2)

    # 'nonexistent_column' checks that absent names are skipped without error.
    trimmed = ColumnRemover(['IsZt', 'IsDt', 'nonexistent_column']).process(frame)

    remaining = trimmed.columns
    assert 'IsZt' not in remaining, "IsZt not removed"
    assert 'IsDt' not in remaining, "IsDt not removed"
    assert 'IsN' in remaining, "IsN should remain"

    print(" PASSED")
|
||||
|
||||
|
||||
def test_flag_to_onehot():
    """FlagToOnehot collapses one-hot flag columns into a single indus_idx."""
    print("Testing FlagToOnehot...")

    onehot_cols = ['gds_CC10', 'gds_CC11', 'gds_CC20']
    frame = pl.DataFrame({
        'datetime': ['2024-01-01'] * 6,
        'instrument': [f'SH60000{i}' for i in range(6)],
        'gds_CC10': [1, 0, 0, 0, 0, 0],
        'gds_CC11': [0, 1, 0, 0, 0, 0],
        'gds_CC20': [0, 0, 1, 0, 0, 0],
        'value': [1, 2, 3, 4, 5, 6],
    })

    encoded = FlagToOnehot(onehot_cols).process(frame)

    # The index column replaces the one-hot source columns.
    assert 'indus_idx' in encoded.columns, "indus_idx not created"
    assert 'gds_CC10' not in encoded.columns, "gds_CC10 not removed"
    assert 'gds_CC11' not in encoded.columns, "gds_CC11 not removed"
    assert 'gds_CC20' not in encoded.columns, "gds_CC20 not removed"

    # Each flag maps to its position in the input column list.
    idx = encoded['indus_idx'].to_list()
    assert idx[0] == 0, "gds_CC10=1 -> indus_idx=0"
    assert idx[1] == 1, "gds_CC11=1 -> indus_idx=1"
    assert idx[2] == 2, "gds_CC20=1 -> indus_idx=2"

    print(" PASSED")
|
||||
|
||||
|
||||
def test_indus_ntrl_injector():
    """Test IndusNtrlInjector.

    Builds two clearly separated industry groups and checks that, after
    neutralization, the per-industry mean of the injected *_ntrl column
    is approximately zero.
    """
    print("Testing IndusNtrlInjector...")

    # Create data with clear industry groups (string dates avoid polars issues).
    # Industry 0 holds KMID {1, 2, 3} (mean 2); industry 1 holds {10, 20, 30}
    # (mean 20); both repeated over two dates.
    df = pl.DataFrame({
        'datetime': ['2024-01-01'] * 6 + ['2024-01-02'] * 6,
        'instrument': [f'SH60000{i}' for i in range(6)] * 2,
        'indus_idx': [0, 0, 0, 1, 1, 1] * 2,  # Two industries
        'KMID': [1.0, 2.0, 3.0, 10.0, 20.0, 30.0] * 2,  # Clear industry means
    })

    processor = IndusNtrlInjector(['KMID'], suffix='_ntrl')
    df_out = processor.process(df)

    assert 'KMID_ntrl' in df_out.columns, "KMID_ntrl not created"

    # Neutralized values should be centered around 0 within each industry.
    # (The previously computed but unused full value list was removed.)
    df_indus_0 = df_out.filter(pl.col('indus_idx') == 0)
    df_indus_1 = df_out.filter(pl.col('indus_idx') == 1)

    mean_0 = df_indus_0['KMID_ntrl'].mean()
    mean_1 = df_indus_1['KMID_ntrl'].mean()

    assert abs(mean_0) < 1e-6, f"Industry 0 mean should be ~0, got {mean_0}"
    assert abs(mean_1) < 1e-6, f"Industry 1 mean should be ~0, got {mean_1}"

    print(" PASSED")
|
||||
|
||||
|
||||
def test_robust_zscore_norm():
    """RobustZScoreNorm: pre-fitted-parameter mode and per-datetime mode."""
    print("Testing RobustZScoreNorm...")

    # Ten instruments on a single date, feature values 0..9.
    frame = pl.DataFrame({
        'datetime': ['2024-01-01'] * 10,
        'instrument': [f'SH60000{i}' for i in range(10)],
        'feat1': list(range(10)),
    })

    # Mode 1: normalize with pre-fitted (qlib) parameters.
    # For the values 0..9: mean=4.5, std ~ 2.87.
    fitted = RobustZScoreNorm(
        ['feat1'],
        clip_range=(-3, 3),
        use_qlib_params=True,
        qlib_mean=np.array([4.5]),
        qlib_std=np.array([2.87]),
    )
    normed = fitted.process(frame)

    # Extremes map to z = (0-4.5)/2.87 = -1.57 and z = (9-4.5)/2.87 = 1.57.
    assert normed['feat1'].min() < -1.5, "Min z-score should be around -1.57"
    assert normed['feat1'].max() > 1.5, "Max z-score should be around 1.57"

    # Mode 2: fit per datetime cross-section instead of stored parameters.
    per_date = RobustZScoreNorm(['feat1'], use_qlib_params=False)
    renormed = per_date.process(frame.clone())

    # After per-datetime normalization the median should sit near zero.
    median = renormed['feat1'].median()
    assert abs(median) < 0.1, f"Median should be ~0, got {median}"

    print(" PASSED")
|
||||
|
||||
|
||||
def test_fillna():
    """Fillna replaces both polars nulls and float NaNs with the fill value."""
    print("Testing Fillna...")

    frame = pl.DataFrame({
        'datetime': ['2024-01-01'] * 5,
        'instrument': [f'SH60000{i}' for i in range(5)],
        'feat1': [1.0, None, 3.0, float('nan'), 5.0],
    })

    filled = Fillna(fill_value=0.0).process(frame, ['feat1'])

    # Neither null nor NaN values may survive the fill.
    assert filled['feat1'].null_count() == 0, "Null values remain"
    assert not np.isnan(filled['feat1'].to_numpy()).any(), "NaN values remain"

    print(" PASSED")
|
||||
|
||||
|
||||
def test_full_pipeline():
    """Test complete processor pipeline.

    Runs all eight processors in their intended order on synthetic data,
    then checks that the row count is preserved and that the expected
    derived columns (indus_idx, market_0, IsST, KMID_ntrl) exist.
    """
    print("Testing full processor pipeline...")

    df = create_test_data(n_dates=5, n_instruments=5, include_st_flags=True, include_industry=True)
    # Remember the input shape: processors must not add or drop rows.
    original_shape = df.shape

    # market_ext columns
    market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']

    # Apply pipeline
    # 1. Diff
    diff_proc = DiffProcessor(market_ext_base)
    df = diff_proc.process(df)

    # 2. FlagMarketInjector
    market_proc = FlagMarketInjector()
    df = market_proc.process(df)

    # 3. FlagSTInjector
    st_proc = FlagSTInjector()
    df = st_proc.process(df)

    # 4. ColumnRemover
    remove_proc = ColumnRemover(['log_size_diff', 'IsN', 'IsZt', 'IsDt'])
    df = remove_proc.process(df)

    # 5. FlagToOnehot -- creates indus_idx, which step 6 relies on.
    onehot_proc = FlagToOnehot(['gds_CC10', 'gds_CC11', 'gds_CC20'])
    df = onehot_proc.process(df)

    # 6. IndusNtrlInjector
    ntrl_proc = IndusNtrlInjector(['KMID', 'KLEN', 'ROC5'], suffix='_ntrl')
    df = ntrl_proc.process(df)

    # 7. RobustZScoreNorm (per-datetime mode for testing)
    norm_proc = RobustZScoreNorm(['KMID', 'KLEN', 'ROC5', 'KMID_ntrl', 'KLEN_ntrl', 'ROC5_ntrl'],
                                 use_qlib_params=False)
    df = norm_proc.process(df)

    # 8. Fillna -- last, so it also cleans NaNs produced by earlier steps.
    fillna_proc = Fillna(fill_value=0.0)
    df = fillna_proc.process(df, df.columns)

    # Verify final structure
    assert df.shape[0] == original_shape[0], "Row count changed"
    assert 'indus_idx' in df.columns, "indus_idx missing"
    assert 'market_0' in df.columns, "market_0 missing"
    assert 'IsST' in df.columns, "IsST missing"
    assert 'KMID_ntrl' in df.columns, "KMID_ntrl missing"

    print(" PASSED")
|
||||
|
||||
|
||||
def test_create_processor_pipeline():
    """create_processor_pipeline returns the seven processors in order."""
    print("Testing create_processor_pipeline...")

    pipeline = create_processor_pipeline(
        alpha158_cols=['KMID', 'KLEN'],
        market_ext_base=['turnover', 'log_size'],
        market_flag_cols=['IsZt', 'IsDt'],
        industry_flag_cols=['gds_CC10', 'gds_CC11'],
    )

    assert len(pipeline) == 7, f"Expected 7 processors, got {len(pipeline)}"

    # The factory must emit the processors in exactly this order.
    expected_types = [
        DiffProcessor,
        FlagMarketInjector,
        FlagSTInjector,
        ColumnRemover,
        FlagToOnehot,
        IndusNtrlInjector,
        IndusNtrlInjector,
    ]
    for stage, expected in zip(pipeline, expected_types):
        assert isinstance(stage, expected)

    print(" PASSED")
|
||||
|
||||
|
||||
def test_get_final_feature_columns():
    """get_final_feature_columns reports the post-pipeline column layout."""
    print("Testing get_final_feature_columns...")

    feature_struct = get_final_feature_columns(
        alpha158_cols=['KMID', 'KLEN', 'ROC5'],
        market_ext_base=['turnover', 'log_size'],
        market_flag_cols=['IsZt', 'IsDt', 'IsN'],
    )

    # All expected keys must be present in the returned structure.
    expected_keys = (
        'alpha158_cols',
        'alpha158_ntrl_cols',
        'market_ext_cols',
        'market_ext_ntrl_cols',
        'market_flag_cols',
        'norm_feature_cols',
        'all_feature_cols',
    )
    for key in expected_keys:
        assert key in feature_struct

    # Verify group sizes.
    assert len(feature_struct['alpha158_cols']) == 3
    assert len(feature_struct['alpha158_ntrl_cols']) == 3
    # market_ext: 2 base + 2 diff - 1 removed (log_size_diff) = 3
    assert len(feature_struct['market_ext_cols']) == 3
    # market_flag: IsZt/IsDt/IsN removed (3); market_0, market_1, IsST added (3)
    assert len(feature_struct['market_flag_cols']) == 3

    # norm_feature_cols lists each group's _ntrl columns before the raw ones.
    norm_cols = feature_struct['norm_feature_cols']
    assert norm_cols[0] == 'KMID_ntrl', "First norm col should be KMID_ntrl"
    assert norm_cols[3] == 'KMID', "Fourth norm col should be KMID (after 3 ntrl)"

    print(" PASSED")
|
||||
|
||||
|
||||
def main():
    """Run every processor test and print a pass/fail summary.

    Each test runs in isolation: an AssertionError counts as a test failure,
    while any other exception is reported (with its traceback, so the failing
    line inside the processor is visible) and counted as a failure too.

    Returns:
        True if every test passed, False otherwise.
    """
    import traceback

    print("=" * 60)
    print("Polars Processors Module Tests")
    print("=" * 60)
    print()

    tests = [
        test_diff_processor,
        test_flag_market_injector,
        test_flag_st_injector,
        test_column_remover,
        test_flag_to_onehot,
        test_indus_ntrl_injector,
        test_robust_zscore_norm,
        test_fillna,
        test_full_pipeline,
        test_create_processor_pipeline,
        test_get_final_feature_columns,
    ]

    passed = 0
    failed = 0

    for test_func in tests:
        try:
            test_func()
            passed += 1
        except AssertionError as e:
            print(f" FAILED: {e}")
            failed += 1
        except Exception as e:
            # Unexpected (non-assertion) failure: keep the traceback so the
            # offending line is visible, then continue with the remaining tests.
            print(f" ERROR: {e}")
            traceback.print_exc()
            failed += 1

    print()
    print("=" * 60)
    print(f"Results: {passed} passed, {failed} failed")
    print("=" * 60)

    return failed == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CI-friendly exit status: 0 when every test passed, 1 otherwise.
    success = main()
    sys.exit(0 if success else 1)
|
||||
Loading…
Reference in new issue