- Add extract_qlib_params.py script to extract pre-fitted mean/std parameters from Qlib's proc_list.proc and save as reusable .npy files with metadata.json - Add RobustZScoreNorm.from_version() class method to load saved parameters by version name, allowing multiple parameter versions to coexist - Update dump_polars_dataset.py to use from_version() instead of loading parameters directly from proc_list.proc - Update generate_beta_embedding.py to use qshare's filter_instruments for stock universe filtering - Save parameters to data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/ with 330 features (158 alpha158_ntrl + 158 alpha158_raw + 7 market_ext_ntrl + 7 market_ext_raw) - Update README.md with documentation for parameter extraction and usage Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
master
parent
8bd36c1939
commit
89bd1a528e
@ -0,0 +1,753 @@
|
||||
"""
|
||||
Polars-based data processors for financial feature transformation.
|
||||
|
||||
This module provides Polars implementations of Qlib-style data processors
|
||||
used in the data_ops pipeline. Each processor follows a consistent interface:
|
||||
- Takes a Polars DataFrame as input
|
||||
- Returns a transformed Polars DataFrame
|
||||
|
||||
Processors are organized by category:
|
||||
- Feature Engineering: DiffProcessor
|
||||
- Flag Injection: FlagMarketInjector, FlagSTInjector
|
||||
- Column Operations: ColumnRemover, FlagToOnehot
|
||||
- Normalization: IndusNtrlInjector, RobustZScoreNorm
|
||||
- Data Cleaning: Fillna
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Tuple, Optional, Dict
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
|
||||
|
||||
# Public API of this module, grouped by processor category.
__all__ = [
    # Feature Engineering
    'DiffProcessor',
    # Flag Injection
    'FlagMarketInjector',
    'FlagSTInjector',
    # Column Operations
    'ColumnRemover',
    'FlagToOnehot',
    # Normalization
    'IndusNtrlInjector',
    'RobustZScoreNorm',
    # Data Cleaning
    'Fillna',
]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Feature Engineering Processors
|
||||
# =============================================================================
|
||||
|
||||
class DiffProcessor:
    """
    Compute period-over-period differences within each instrument.

    For every requested column, a new column named ``{col}_{suffix}`` holding
    ``col.diff(periods)`` (computed per instrument group) is appended to the
    frame. Requested columns absent from the input are silently skipped.

    Attributes:
        columns: Names of the columns to difference
        suffix: Suffix appended to each diff column name (default: 'diff')
        periods: Shift used by the difference (default: 1)

    Example:
        >>> processor = DiffProcessor(['turnover', 'log_size'])
        >>> df = processor.process(df)
        >>> # Creates: turnover_diff, log_size_diff
    """

    def __init__(
        self,
        columns: List[str],
        suffix: str = 'diff',
        periods: int = 1
    ):
        self.columns = columns
        self.suffix = suffix
        self.periods = periods

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Append diff features for the configured columns.

        Args:
            df: Input DataFrame with 'instrument' and 'datetime' columns

        Returns:
            DataFrame with the original columns plus one diff column per
            requested column that exists in the input
        """
        # Rows must be ordered within each instrument for diff to be meaningful.
        df = df.sort(['instrument', 'datetime'])

        present = [name for name in self.columns if name in df.columns]
        for name in present:
            df = df.with_columns([
                pl.col(name)
                .diff(self.periods)
                .over('instrument')
                .alias(f"{name}_{self.suffix}")
            ])

        return df

    def __repr__(self) -> str:
        return f"DiffProcessor(columns={self.columns}, suffix='{self.suffix}', periods={self.periods})"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Flag Injection Processors
|
||||
# =============================================================================
|
||||
|
||||
class FlagMarketInjector:
    """
    Inject market-segment flags derived from instrument codes.

    Classifies stocks into market segments based on their instrument codes.
    Both code formats are accepted:
    - With exchange prefix: SH600000, SZ000001, SH688000, SZ300001
    - Numeric only: 600000, 000001, 688000, 301001

    Market classification:
    - market_0: Main board (SH60xxx, SZ00xxx, or 6xxxxx, 0xxxxx) - 主板
    - market_1: STAR/ChiNext (SH688xxx, SH689xxx, SZ300xxx, SZ301xxx,
      or 688xxx, 689xxx, 300xxx, 301xxx) - 科创板/创业板

    Note: 新三板/北交所 codes (NE40xxx, NE42xxx, NE43xxx, NE8xxx) are not
    classified - both flags stay 0 for those instruments.

    Example:
        >>> processor = FlagMarketInjector()
        >>> df = processor.process(df)
        >>> # Adds: market_0, market_1 (both int8)
    """

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Add market classification columns.

        Args:
            df: Input DataFrame with instrument column

        Returns:
            DataFrame with market_0, market_1 columns added
        """
        # Strip any exchange prefix so both code formats share one path
        # (e.g. SH600000 -> 600000, SZ000001 -> 000001).
        code = pl.col('instrument').cast(pl.String).str.replace_all("^(SH|SZ|NE)", "")
        lead1 = code.str.slice(0, 1)
        lead3 = code.str.slice(0, 3)

        # STAR market (688/689) and ChiNext (300/301) prefixes.
        on_star = (lead3 == '688') | (lead3 == '689')
        on_chinext = (lead3 == '300') | (lead3 == '301')

        # Main board: SH codes start with '6' (excluding STAR), SZ with '0'.
        sh_main = (lead1 == '6') & ~on_star
        sz_main = lead1 == '0'

        return df.with_columns([
            # market_0 = 主板
            (sh_main | sz_main).cast(pl.Int8).alias('market_0'),
            # market_1 = 科创板 + 创业板
            (on_star | on_chinext).cast(pl.Int8).alias('market_1')
        ])

    def __repr__(self) -> str:
        return "FlagMarketInjector()"
|
||||
|
||||
|
||||
class FlagSTInjector:
    """
    Inject ST (Special Treatment) flag.

    Creates an IsST column from the ST_S and ST_Y flags:
    - IsST = mark_st_as if ST_S or ST_Y is truthy (stock is ST or *ST)
    - IsST = 0 otherwise (including rows where both flags are null)

    If ST flags are not available, creates a placeholder column (all zeros).

    Attributes:
        mark_st_as: Value to mark ST stocks as (default: 1)

    Example:
        >>> processor = FlagSTInjector()
        >>> df = processor.process(df)
        >>> # Adds: IsST (int8)
    """

    def __init__(self, mark_st_as: int = 1):
        self.mark_st_as = mark_st_as

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Add IsST column from ST flags.

        Args:
            df: Input DataFrame with ST_S, ST_Y columns (or without, in which
                case a zero placeholder is emitted)

        Returns:
            DataFrame with IsST column added
        """
        if 'ST_S' in df.columns and 'ST_Y' in df.columns:
            # BUGFIX: mark_st_as was previously ignored — a plain
            # Boolean -> Int8 cast always produced 1 for ST rows regardless
            # of the configured value. Route the flag through when/then so
            # the configured mark value is honored (default 1 is unchanged).
            # Null flag values now fall into the otherwise-branch and yield 0
            # instead of propagating null, which downstream fill steps relied
            # on anyway.
            is_st = (pl.col('ST_S').cast(pl.Boolean, strict=False) |
                     pl.col('ST_Y').cast(pl.Boolean, strict=False))
            df = df.with_columns([
                pl.when(is_st)
                .then(self.mark_st_as)
                .otherwise(0)
                .cast(pl.Int8)
                .alias('IsST')
            ])
        else:
            # ST information missing: emit an all-zero placeholder so
            # downstream consumers can rely on the column existing.
            df = df.with_columns([
                pl.lit(0).cast(pl.Int8).alias('IsST')
            ])

        return df

    def __repr__(self) -> str:
        return f"FlagSTInjector(mark_st_as={self.mark_st_as})"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Column Operation Processors
|
||||
# =============================================================================
|
||||
|
||||
class ColumnRemover:
    """
    Drop a configured set of columns from the DataFrame.

    Names that are not present in the input are ignored rather than raising.

    Attributes:
        columns_to_remove: Column names to drop

    Example:
        >>> processor = ColumnRemover(['log_size_diff', 'IsN', 'IsZt', 'IsDt'])
        >>> df = processor.process(df)
        >>> # Removes the listed columns (where present)
    """

    def __init__(self, columns_to_remove: List[str]):
        self.columns_to_remove = columns_to_remove

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Drop the configured columns.

        Args:
            df: Input DataFrame

        Returns:
            DataFrame without the configured columns
        """
        present = [name for name in self.columns_to_remove if name in df.columns]
        return df.drop(present) if present else df

    def __repr__(self) -> str:
        return f"ColumnRemover(columns_to_remove={self.columns_to_remove})"
|
||||
|
||||
|
||||
class FlagToOnehot:
    """
    Collapse one-hot flag columns into a single integer index column.

    Writes the index (position within ``flag_columns``) of the matching flag
    into 'indus_idx', using -1 for rows where no flag is set. The when-chain
    is built so the highest-indexed set flag wins; for genuine one-hot input
    at most one flag is set, so the distinction never matters in practice.
    The original flag columns are dropped afterwards.

    Attributes:
        flag_columns: Boolean/0-1 flag column names, in index order

    Example:
        >>> processor = FlagToOnehot(['gds_CC10', 'gds_CC11', ...])
        >>> df = processor.process(df)
        >>> # Adds: indus_idx; drops the flag columns
    """

    def __init__(self, flag_columns: List[str]):
        self.flag_columns = flag_columns

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Convert flag columns to a single index column.

        Args:
            df: Input DataFrame with flag columns

        Returns:
            DataFrame with indus_idx column added and original flags dropped
        """
        # Default of -1 means "no industry"; each present flag wraps the
        # previous expression in another when/then layer.
        expr = pl.lit(-1)
        for position, name in enumerate(self.flag_columns):
            if name in df.columns:
                expr = pl.when(pl.col(name) == 1).then(position).otherwise(expr)

        df = df.with_columns([expr.alias('indus_idx')])

        dropped = [name for name in self.flag_columns if name in df.columns]
        if dropped:
            df = df.drop(dropped)

        return df

    def __repr__(self) -> str:
        return f"FlagToOnehot(flag_columns={len(self.flag_columns)} columns)"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Normalization Processors
|
||||
# =============================================================================
|
||||
|
||||
class IndusNtrlInjector:
    """
    Cross-sectional industry neutralization of features.

    For each feature, subtracts the mean of that feature over all rows that
    share the same (datetime, indus_idx) pair, writing the result to a new
    column '{feature}{suffix}'. Original columns are kept. This mirrors
    Qlib's cal_indus_ntrl behavior.

    Attributes:
        feature_cols: Feature columns to neutralize
        suffix: Suffix for the neutralized column names (default: '_ntrl')

    Example:
        >>> processor = IndusNtrlInjector(['KMID', 'KLEN'], suffix='_ntrl')
        >>> df = processor.process(df)
        >>> # Adds: KMID_ntrl, KLEN_ntrl
    """

    def __init__(self, feature_cols: List[str], suffix: str = '_ntrl'):
        self.feature_cols = feature_cols
        self.suffix = suffix

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Apply industry neutralization to the configured features.

        Args:
            df: Input DataFrame with the feature columns and indus_idx

        Returns:
            DataFrame with neutralized columns added (suffix appended)
        """
        # Skip requested columns that are absent from the frame.
        for name in (c for c in self.feature_cols if c in df.columns):
            # Window over (datetime, indus_idx) gives the per-date industry
            # mean; subtracting it de-means the feature cross-sectionally.
            demeaned = pl.col(name) - pl.col(name).mean().over(['datetime', 'indus_idx'])
            df = df.with_columns([demeaned.alias(f"{name}{self.suffix}")])

        return df

    def __repr__(self) -> str:
        return f"IndusNtrlInjector(feature_cols={len(self.feature_cols)} columns, suffix='{self.suffix}')"
|
||||
|
||||
|
||||
@dataclass
class RobustZScoreNorm:
    """
    Robust z-score normalization using median/MAD.

    Applies the transformation: (x - median) / (1.4826 * MAD)
    where MAD = median(|x - median|)

    Supports two modes:
    1. Per-datetime computation (default): calculates median/MAD for each datetime
    2. Pre-fitted parameters: uses provided mean/std arrays (from Qlib processor)

    Attributes:
        feature_cols: List of feature columns to normalize. In pre-fitted mode
            the parameter arrays are aligned with this list BY POSITION.
        clip_range: Clip normalized values to this range (default: (-3, 3))
        use_qlib_params: Use pre-fitted parameters (default: False)
        qlib_mean: Pre-fitted mean array (required if use_qlib_params=True)
        qlib_std: Pre-fitted std array (required if use_qlib_params=True)

    Example:
        # Using pre-fitted Qlib parameters
        >>> processor = RobustZScoreNorm(
        ...     feature_cols=['KMID', 'KLEN'],
        ...     use_qlib_params=True,
        ...     qlib_mean=mean_array,
        ...     qlib_std=std_array
        ... )
        >>> df = processor.process(df)

        # Loading parameters from saved version
        >>> processor = RobustZScoreNorm.from_version("csiallx_feature2_ntrla_flag_pnlnorm")
        >>> df = processor.process(df)
    """

    feature_cols: List[str]
    clip_range: Tuple[float, float] = (-3.0, 3.0)
    use_qlib_params: bool = False
    qlib_mean: Optional[np.ndarray] = None
    qlib_std: Optional[np.ndarray] = None

    @classmethod
    def from_version(
        cls,
        version: str,
        feature_cols: Optional[List[str]] = None,
        clip_range: Tuple[float, float] = (-3.0, 3.0),
        params_dir: Optional[str] = None
    ) -> "RobustZScoreNorm":
        """
        Create a RobustZScoreNorm instance from saved parameters by version name.

        Loads pre-extracted mean_train and std_train from the parameter
        directory structure:
            {params_dir}/{version}/
            ├── mean_train.npy
            ├── std_train.npy
            └── metadata.json

        Args:
            version: Version name (e.g., "csiallx_feature2_ntrla_flag_pnlnorm")
            feature_cols: Optional list of feature columns. If None, the order
                is built from metadata.json (alpha158_ntrl + alpha158_raw +
                market_ext_ntrl + market_ext_raw)
            clip_range: Clip range for normalized values (default: (-3, 3))
            params_dir: Base directory for parameter versions. If None, uses:
                stock_1d/d033/alpha158_beta/data/robust_zscore_params/

        Returns:
            RobustZScoreNorm instance with pre-fitted parameters loaded

        Raises:
            FileNotFoundError: If version directory or parameter files not found
            ValueError: If feature columns are unavailable, or their count does
                not match the parameter shape

        Example:
            >>> processor = RobustZScoreNorm.from_version(
            ...     "csiallx_feature2_ntrla_flag_pnlnorm"
            ... )
            >>> df = processor.process(df)
        """
        import json
        from pathlib import Path

        if params_dir is None:
            # Default location, relative to this file:
            # cta_1d/src/processors/ -> alpha_lab/stock_1d/d033/alpha158_beta/data/robust_zscore_params/
            params_dir = Path(__file__).parent.parent.parent.parent / \
                "stock_1d" / "d033" / "alpha158_beta" / "data" / "robust_zscore_params"
        else:
            params_dir = Path(params_dir)

        version_dir = params_dir / version
        if not version_dir.exists():
            raise FileNotFoundError(
                f"Version directory not found: {version_dir}\n"
                f"Available versions should be in: {params_dir}"
            )

        mean_path = version_dir / "mean_train.npy"
        if not mean_path.exists():
            raise FileNotFoundError(f"mean_train.npy not found: {mean_path}")
        mean_train = np.load(mean_path)

        std_path = version_dir / "std_train.npy"
        if not std_path.exists():
            raise FileNotFoundError(f"std_train.npy not found: {std_path}")
        std_train = np.load(std_path)

        # Derive feature column names from metadata.json when not provided.
        metadata_path = version_dir / "metadata.json"
        if feature_cols is None and metadata_path.exists():
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)

            feature_columns = metadata.get('feature_columns', {})
            alpha158_ntrl = [f"{c}_ntrl" for c in feature_columns.get('alpha158_ntrl', [])]
            alpha158_raw = feature_columns.get('alpha158_raw', [])
            market_ext_ntrl = [f"{c}_ntrl" for c in feature_columns.get('market_ext_ntrl', [])]
            market_ext_raw = feature_columns.get('market_ext_raw', [])

            feature_cols = alpha158_ntrl + alpha158_raw + market_ext_ntrl + market_ext_raw

        # BUGFIX: previously a missing metadata.json left feature_cols as None,
        # which skipped validation here and only crashed later inside process().
        # Fail fast with a clear message instead.
        if not feature_cols:
            raise ValueError(
                f"feature_cols not provided and could not be derived from {metadata_path}"
            )

        # Positional alignment requires an exact count match.
        expected_count = len(mean_train)
        if len(feature_cols) != expected_count:
            raise ValueError(
                f"Feature column count ({len(feature_cols)}) does not match "
                f"parameter shape ({expected_count})"
            )

        print(f"Loaded RobustZScoreNorm parameters from version '{version}':")
        print(f"  mean_train shape: {mean_train.shape}")
        print(f"  std_train shape: {std_train.shape}")
        print(f"  feature_cols: {len(feature_cols)}")

        return cls(
            feature_cols=feature_cols,
            clip_range=clip_range,
            use_qlib_params=True,
            qlib_mean=mean_train,
            qlib_std=std_train
        )

    def __post_init__(self):
        """Validate and normalize the pre-fitted parameters, if any."""
        if self.use_qlib_params:
            if self.qlib_mean is None or self.qlib_std is None:
                raise ValueError(
                    "Must provide qlib_mean and qlib_std when use_qlib_params=True"
                )
            # Accept any array-like input for the parameter vectors.
            self.qlib_mean = np.asarray(self.qlib_mean)
            self.qlib_std = np.asarray(self.qlib_std)

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Apply robust z-score normalization.

        Args:
            df: Input DataFrame with feature columns

        Returns:
            DataFrame with the feature columns normalized and clipped in place
        """
        if self.use_qlib_params:
            # Pre-fitted parameters are aligned with self.feature_cols by
            # POSITION. BUGFIX: the previous code enumerated only the columns
            # present in df, so a single missing feature shifted the mean/std
            # applied to every subsequent column. Enumerate the full list and
            # skip absent columns individually instead.
            for i, col in enumerate(self.feature_cols):
                if i >= len(self.qlib_mean):
                    break
                if col not in df.columns:
                    continue
                mean_val = float(self.qlib_mean[i])
                std_val = float(self.qlib_std[i])

                # 1e-8 guards against division by zero for degenerate features.
                df = df.with_columns([
                    ((pl.col(col) - mean_val) / (std_val + 1e-8))
                    .clip(self.clip_range[0], self.clip_range[1])
                    .alias(col)
                ])
        else:
            # Per-datetime robust z-score: (x - median) / (1.4826 * MAD).
            existing_cols = [c for c in self.feature_cols if c in df.columns]
            for col in existing_cols:
                # Median per datetime cross-section.
                median_col = f"__median_{col}"
                df = df.with_columns([
                    pl.col(col).median().over('datetime').alias(median_col)
                ])

                # Absolute deviation from the per-date median.
                abs_dev_col = f"__absdev_{col}"
                df = df.with_columns([
                    (pl.col(col) - pl.col(median_col)).abs().alias(abs_dev_col)
                ])

                # MAD = median of the absolute deviations, per datetime.
                mad_col = f"__mad_{col}"
                df = df.with_columns([
                    pl.col(abs_dev_col).median().over('datetime').alias(mad_col)
                ])

                # 1.4826 makes MAD a consistent estimator of the standard
                # deviation under normality; 1e-8 avoids division by zero.
                df = df.with_columns([
                    ((pl.col(col) - pl.col(median_col)) / (1.4826 * pl.col(mad_col) + 1e-8))
                    .clip(self.clip_range[0], self.clip_range[1])
                    .alias(col)
                ])

                # Drop the temporary helper columns.
                df = df.drop([median_col, abs_dev_col, mad_col])

        return df
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Data Cleaning Processors
|
||||
# =============================================================================
|
||||
|
||||
class Fillna:
    """
    Fill missing values with a constant.

    Fills null values in the requested columns with ``fill_value``; for
    floating-point columns, NaN values are filled as well. Only numeric
    columns (Float32/64, Int32/64, UInt8/16/32/64) are touched.

    Attributes:
        fill_value: Value used to replace nulls/NaNs (default: 0.0)

    Example:
        >>> processor = Fillna(fill_value=0.0)
        >>> df = processor.process(df, ['KMID', 'KLEN'])
    """

    # Integer columns can carry nulls but can never hold NaN, so fill_nan
    # must only be applied to float columns (it is invalid on integers).
    _FLOAT_DTYPES = (pl.Float32, pl.Float64)
    _INT_DTYPES = (pl.Int32, pl.Int64, pl.UInt32, pl.UInt64, pl.UInt16, pl.UInt8)

    def __init__(self, fill_value: float = 0.0):
        self.fill_value = fill_value

    def process(self, df: pl.DataFrame, columns: List[str]) -> pl.DataFrame:
        """
        Fill missing values in the given columns.

        Args:
            df: Input DataFrame
            columns: Columns to fill; names absent from df are ignored

        Returns:
            DataFrame with nulls (and float NaNs) replaced by fill_value
        """
        exprs = []
        for col in columns:
            if col not in df.columns:
                continue
            dtype = df[col].dtype
            if dtype in self._FLOAT_DTYPES:
                exprs.append(
                    pl.col(col).fill_null(self.fill_value).fill_nan(self.fill_value)
                )
            elif dtype in self._INT_DTYPES:
                # BUGFIX: fill_nan was previously applied here as well, which
                # is not a valid operation for integer dtypes.
                exprs.append(pl.col(col).fill_null(self.fill_value))

        # Apply all fills in one pass instead of one with_columns per column.
        if exprs:
            df = df.with_columns(exprs)

        return df

    def __repr__(self) -> str:
        return f"Fillna(fill_value={self.fill_value})"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Processor Pipeline Utilities
|
||||
# =============================================================================
|
||||
|
||||
def create_processor_pipeline(
    alpha158_cols: List[str],
    market_ext_base: List[str],
    market_flag_cols: List[str],
    industry_flag_cols: List[str],
    columns_to_remove: Optional[List[str]] = None,
    ntrl_suffix: str = '_ntrl'
) -> List:
    """
    Build the standard processor pipeline, in execution order.

    Order of the returned processors:
        1. DiffProcessor          - adds diff features for market_ext_base
        2. FlagMarketInjector     - adds market_0, market_1
        3. FlagSTInjector         - adds IsST
        4. ColumnRemover          - drops columns_to_remove
        5. FlagToOnehot           - collapses industry flags into indus_idx
        6. IndusNtrlInjector (x2) - neutralizes alpha158 and market_ext features

    RobustZScoreNorm and Fillna require fitted parameters and must be
    appended separately by the caller.

    Args:
        alpha158_cols: Alpha158 feature names
        market_ext_base: Market extension base columns
        market_flag_cols: Market flag columns (accepted for interface
            symmetry; not consumed by any processor built here)
        industry_flag_cols: Industry flag columns
        columns_to_remove: Columns to drop (default: ['log_size_diff', 'IsN', 'IsZt', 'IsDt'])
        ntrl_suffix: Suffix for neutralized columns

    Returns:
        List of processor instances in execution order
    """
    drop_cols = (
        ['log_size_diff', 'IsN', 'IsZt', 'IsDt']
        if columns_to_remove is None
        else columns_to_remove
    )

    pipeline = [DiffProcessor(market_ext_base)]
    pipeline.append(FlagMarketInjector())
    pipeline.append(FlagSTInjector())
    pipeline.append(ColumnRemover(drop_cols))
    pipeline.append(FlagToOnehot(industry_flag_cols))
    pipeline.append(IndusNtrlInjector(alpha158_cols, suffix=ntrl_suffix))
    pipeline.append(IndusNtrlInjector(market_ext_base, suffix=ntrl_suffix))
    return pipeline
|
||||
|
||||
|
||||
def get_final_feature_columns(
    alpha158_cols: List[str],
    market_ext_base: List[str],
    market_flag_cols: List[str],
    columns_to_remove: Optional[List[str]] = None,
    ntrl_suffix: str = '_ntrl'
) -> Dict[str, List[str]]:
    """
    Describe the feature-column layout produced by the processor pipeline.

    Useful for determining expected VAE input dimensions and verifying
    feature order without running the pipeline itself.

    Args:
        alpha158_cols: Alpha158 feature names
        market_ext_base: Market extension base columns (before Diff)
        market_flag_cols: Market flag columns (before ColumnRemover)
        columns_to_remove: Columns removed by the pipeline
        ntrl_suffix: Suffix for neutralized columns

    Returns:
        Dictionary with feature groups:
            - 'norm_feature_cols': Features to normalize (in order)
            - 'market_flag_cols': Market flag columns after processing
            - 'all_feature_cols': All feature columns including indus_idx
    """
    if columns_to_remove is None:
        columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt']
    removed = set(columns_to_remove)

    # DiffProcessor appends one '_diff' column per base column; ColumnRemover
    # then prunes the removal list from both groups.
    with_diff = market_ext_base + [f"{c}_diff" for c in market_ext_base]
    market_ext_final = [c for c in with_diff if c not in removed]

    # FlagMarketInjector adds market_0/market_1; FlagSTInjector adds IsST.
    flags = [c for c in market_flag_cols if c not in removed]
    flags += ['market_0', 'market_1', 'IsST']

    # Normalization order matches Qlib: ntrl + raw for each feature group.
    ntrl_alpha = [f"{c}{ntrl_suffix}" for c in alpha158_cols]
    ntrl_ext = [f"{c}{ntrl_suffix}" for c in market_ext_final]
    norm_cols = ntrl_alpha + alpha158_cols + ntrl_ext + market_ext_final

    return {
        'alpha158_cols': alpha158_cols,
        'alpha158_ntrl_cols': ntrl_alpha,
        'market_ext_cols': market_ext_final,
        'market_ext_ntrl_cols': ntrl_ext,
        'market_flag_cols': flags,
        'norm_feature_cols': norm_cols,
        'all_feature_cols': norm_cols + flags + ['indus_idx'],
    }
|
||||
@ -0,0 +1,366 @@
|
||||
{
|
||||
"version": "csiallx_feature2_ntrla_flag_pnlnorm",
|
||||
"created_at": "2026-03-01T12:18:01.969109",
|
||||
"source_file": "2013-01-01",
|
||||
"fit_start_time": "2013-01-01",
|
||||
"fit_end_time": "2018-12-31",
|
||||
"fields_group": [
|
||||
"feature",
|
||||
"feature_ext"
|
||||
],
|
||||
"feature_columns": {
|
||||
"alpha158_ntrl": [
|
||||
"KMID",
|
||||
"KLEN",
|
||||
"KMID2",
|
||||
"KUP",
|
||||
"KUP2",
|
||||
"KLOW",
|
||||
"KLOW2",
|
||||
"KSFT",
|
||||
"KSFT2",
|
||||
"OPEN0",
|
||||
"HIGH0",
|
||||
"LOW0",
|
||||
"VWAP0",
|
||||
"ROC5",
|
||||
"ROC10",
|
||||
"ROC20",
|
||||
"ROC30",
|
||||
"ROC60",
|
||||
"MA5",
|
||||
"MA10",
|
||||
"MA20",
|
||||
"MA30",
|
||||
"MA60",
|
||||
"STD5",
|
||||
"STD10",
|
||||
"STD20",
|
||||
"STD30",
|
||||
"STD60",
|
||||
"BETA5",
|
||||
"BETA10",
|
||||
"BETA20",
|
||||
"BETA30",
|
||||
"BETA60",
|
||||
"RSQR5",
|
||||
"RSQR10",
|
||||
"RSQR20",
|
||||
"RSQR30",
|
||||
"RSQR60",
|
||||
"RESI5",
|
||||
"RESI10",
|
||||
"RESI20",
|
||||
"RESI30",
|
||||
"RESI60",
|
||||
"MAX5",
|
||||
"MAX10",
|
||||
"MAX20",
|
||||
"MAX30",
|
||||
"MAX60",
|
||||
"MIN5",
|
||||
"MIN10",
|
||||
"MIN20",
|
||||
"MIN30",
|
||||
"MIN60",
|
||||
"QTLU5",
|
||||
"QTLU10",
|
||||
"QTLU20",
|
||||
"QTLU30",
|
||||
"QTLU60",
|
||||
"QTLD5",
|
||||
"QTLD10",
|
||||
"QTLD20",
|
||||
"QTLD30",
|
||||
"QTLD60",
|
||||
"RANK5",
|
||||
"RANK10",
|
||||
"RANK20",
|
||||
"RANK30",
|
||||
"RANK60",
|
||||
"RSV5",
|
||||
"RSV10",
|
||||
"RSV20",
|
||||
"RSV30",
|
||||
"RSV60",
|
||||
"IMAX5",
|
||||
"IMAX10",
|
||||
"IMAX20",
|
||||
"IMAX30",
|
||||
"IMAX60",
|
||||
"IMIN5",
|
||||
"IMIN10",
|
||||
"IMIN20",
|
||||
"IMIN30",
|
||||
"IMIN60",
|
||||
"IMXD5",
|
||||
"IMXD10",
|
||||
"IMXD20",
|
||||
"IMXD30",
|
||||
"IMXD60",
|
||||
"CORR5",
|
||||
"CORR10",
|
||||
"CORR20",
|
||||
"CORR30",
|
||||
"CORR60",
|
||||
"CORD5",
|
||||
"CORD10",
|
||||
"CORD20",
|
||||
"CORD30",
|
||||
"CORD60",
|
||||
"CNTP5",
|
||||
"CNTP10",
|
||||
"CNTP20",
|
||||
"CNTP30",
|
||||
"CNTP60",
|
||||
"CNTN5",
|
||||
"CNTN10",
|
||||
"CNTN20",
|
||||
"CNTN30",
|
||||
"CNTN60",
|
||||
"CNTD5",
|
||||
"CNTD10",
|
||||
"CNTD20",
|
||||
"CNTD30",
|
||||
"CNTD60",
|
||||
"SUMP5",
|
||||
"SUMP10",
|
||||
"SUMP20",
|
||||
"SUMP30",
|
||||
"SUMP60",
|
||||
"SUMN5",
|
||||
"SUMN10",
|
||||
"SUMN20",
|
||||
"SUMN30",
|
||||
"SUMN60",
|
||||
"SUMD5",
|
||||
"SUMD10",
|
||||
"SUMD20",
|
||||
"SUMD30",
|
||||
"SUMD60",
|
||||
"VMA5",
|
||||
"VMA10",
|
||||
"VMA20",
|
||||
"VMA30",
|
||||
"VMA60",
|
||||
"VSTD5",
|
||||
"VSTD10",
|
||||
"VSTD20",
|
||||
"VSTD30",
|
||||
"VSTD60",
|
||||
"WVMA5",
|
||||
"WVMA10",
|
||||
"WVMA20",
|
||||
"WVMA30",
|
||||
"WVMA60",
|
||||
"VSUMP5",
|
||||
"VSUMP10",
|
||||
"VSUMP20",
|
||||
"VSUMP30",
|
||||
"VSUMP60",
|
||||
"VSUMN5",
|
||||
"VSUMN10",
|
||||
"VSUMN20",
|
||||
"VSUMN30",
|
||||
"VSUMN60",
|
||||
"VSUMD5",
|
||||
"VSUMD10",
|
||||
"VSUMD20",
|
||||
"VSUMD30",
|
||||
"VSUMD60"
|
||||
],
|
||||
"alpha158_raw": [
|
||||
"KMID",
|
||||
"KLEN",
|
||||
"KMID2",
|
||||
"KUP",
|
||||
"KUP2",
|
||||
"KLOW",
|
||||
"KLOW2",
|
||||
"KSFT",
|
||||
"KSFT2",
|
||||
"OPEN0",
|
||||
"HIGH0",
|
||||
"LOW0",
|
||||
"VWAP0",
|
||||
"ROC5",
|
||||
"ROC10",
|
||||
"ROC20",
|
||||
"ROC30",
|
||||
"ROC60",
|
||||
"MA5",
|
||||
"MA10",
|
||||
"MA20",
|
||||
"MA30",
|
||||
"MA60",
|
||||
"STD5",
|
||||
"STD10",
|
||||
"STD20",
|
||||
"STD30",
|
||||
"STD60",
|
||||
"BETA5",
|
||||
"BETA10",
|
||||
"BETA20",
|
||||
"BETA30",
|
||||
"BETA60",
|
||||
"RSQR5",
|
||||
"RSQR10",
|
||||
"RSQR20",
|
||||
"RSQR30",
|
||||
"RSQR60",
|
||||
"RESI5",
|
||||
"RESI10",
|
||||
"RESI20",
|
||||
"RESI30",
|
||||
"RESI60",
|
||||
"MAX5",
|
||||
"MAX10",
|
||||
"MAX20",
|
||||
"MAX30",
|
||||
"MAX60",
|
||||
"MIN5",
|
||||
"MIN10",
|
||||
"MIN20",
|
||||
"MIN30",
|
||||
"MIN60",
|
||||
"QTLU5",
|
||||
"QTLU10",
|
||||
"QTLU20",
|
||||
"QTLU30",
|
||||
"QTLU60",
|
||||
"QTLD5",
|
||||
"QTLD10",
|
||||
"QTLD20",
|
||||
"QTLD30",
|
||||
"QTLD60",
|
||||
"RANK5",
|
||||
"RANK10",
|
||||
"RANK20",
|
||||
"RANK30",
|
||||
"RANK60",
|
||||
"RSV5",
|
||||
"RSV10",
|
||||
"RSV20",
|
||||
"RSV30",
|
||||
"RSV60",
|
||||
"IMAX5",
|
||||
"IMAX10",
|
||||
"IMAX20",
|
||||
"IMAX30",
|
||||
"IMAX60",
|
||||
"IMIN5",
|
||||
"IMIN10",
|
||||
"IMIN20",
|
||||
"IMIN30",
|
||||
"IMIN60",
|
||||
"IMXD5",
|
||||
"IMXD10",
|
||||
"IMXD20",
|
||||
"IMXD30",
|
||||
"IMXD60",
|
||||
"CORR5",
|
||||
"CORR10",
|
||||
"CORR20",
|
||||
"CORR30",
|
||||
"CORR60",
|
||||
"CORD5",
|
||||
"CORD10",
|
||||
"CORD20",
|
||||
"CORD30",
|
||||
"CORD60",
|
||||
"CNTP5",
|
||||
"CNTP10",
|
||||
"CNTP20",
|
||||
"CNTP30",
|
||||
"CNTP60",
|
||||
"CNTN5",
|
||||
"CNTN10",
|
||||
"CNTN20",
|
||||
"CNTN30",
|
||||
"CNTN60",
|
||||
"CNTD5",
|
||||
"CNTD10",
|
||||
"CNTD20",
|
||||
"CNTD30",
|
||||
"CNTD60",
|
||||
"SUMP5",
|
||||
"SUMP10",
|
||||
"SUMP20",
|
||||
"SUMP30",
|
||||
"SUMP60",
|
||||
"SUMN5",
|
||||
"SUMN10",
|
||||
"SUMN20",
|
||||
"SUMN30",
|
||||
"SUMN60",
|
||||
"SUMD5",
|
||||
"SUMD10",
|
||||
"SUMD20",
|
||||
"SUMD30",
|
||||
"SUMD60",
|
||||
"VMA5",
|
||||
"VMA10",
|
||||
"VMA20",
|
||||
"VMA30",
|
||||
"VMA60",
|
||||
"VSTD5",
|
||||
"VSTD10",
|
||||
"VSTD20",
|
||||
"VSTD30",
|
||||
"VSTD60",
|
||||
"WVMA5",
|
||||
"WVMA10",
|
||||
"WVMA20",
|
||||
"WVMA30",
|
||||
"WVMA60",
|
||||
"VSUMP5",
|
||||
"VSUMP10",
|
||||
"VSUMP20",
|
||||
"VSUMP30",
|
||||
"VSUMP60",
|
||||
"VSUMN5",
|
||||
"VSUMN10",
|
||||
"VSUMN20",
|
||||
"VSUMN30",
|
||||
"VSUMN60",
|
||||
"VSUMD5",
|
||||
"VSUMD10",
|
||||
"VSUMD20",
|
||||
"VSUMD30",
|
||||
"VSUMD60"
|
||||
],
|
||||
"market_ext_ntrl": [
|
||||
"turnover",
|
||||
"free_turnover",
|
||||
"log_size",
|
||||
"con_rating_strength",
|
||||
"turnover_diff",
|
||||
"free_turnover_diff",
|
||||
"con_rating_strength_diff"
|
||||
],
|
||||
"market_ext_raw": [
|
||||
"turnover",
|
||||
"free_turnover",
|
||||
"log_size",
|
||||
"con_rating_strength",
|
||||
"turnover_diff",
|
||||
"free_turnover_diff",
|
||||
"con_rating_strength_diff"
|
||||
]
|
||||
},
|
||||
"feature_count": {
|
||||
"alpha158_ntrl": 158,
|
||||
"alpha158_raw": 158,
|
||||
"market_ext_ntrl": 7,
|
||||
"market_ext_raw": 7,
|
||||
"total": 330
|
||||
},
|
||||
"parameter_shapes": {
|
||||
"mean_train": [
|
||||
330
|
||||
],
|
||||
"std_train": [
|
||||
330
|
||||
]
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,305 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Extract RobustZScoreNorm parameters from Qlib's proc_list.proc file.
|
||||
|
||||
This script extracts the pre-fitted mean_train and std_train parameters from
|
||||
Qlib's RobustZScoreNorm processor and saves them as reusable .npy files.
|
||||
|
||||
The extracted parameters can be used by the Polars RobustZScoreNorm processor
|
||||
via the from_version() class method.
|
||||
|
||||
Output structure:
|
||||
stock_1d/d033/alpha158_beta/data/robust_zscore_params/
|
||||
└── {version}/
|
||||
├── mean_train.npy
|
||||
├── std_train.npy
|
||||
└── metadata.json
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import pickle as pkl
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
# Default paths
# Source pickle produced by the Qlib dataset pipeline; override via --proc-list.
DEFAULT_PROC_LIST_PATH = "/home/guofu/Workspaces/alpha/data_ops/tasks/dwm_feature_vae/dataset/csiallx_feature2_ntrla_flag_pnlnorm/proc_list.proc"
# Version name used for the output parameter directory; override via --version.
DEFAULT_VERSION = "csiallx_feature2_ntrla_flag_pnlnorm"

# Alpha158 columns in order (158 features)
# NOTE: this order must match the column order the RobustZScoreNorm
# parameters were fitted on in Qlib — do not reorder or edit casually.
ALPHA158_COLS = [
    'KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2',
    'OPEN0', 'HIGH0', 'LOW0', 'VWAP0',
    'ROC5', 'ROC10', 'ROC20', 'ROC30', 'ROC60',
    'MA5', 'MA10', 'MA20', 'MA30', 'MA60',
    'STD5', 'STD10', 'STD20', 'STD30', 'STD60',
    'BETA5', 'BETA10', 'BETA20', 'BETA30', 'BETA60',
    'RSQR5', 'RSQR10', 'RSQR20', 'RSQR30', 'RSQR60',
    'RESI5', 'RESI10', 'RESI20', 'RESI30', 'RESI60',
    'MAX5', 'MAX10', 'MAX20', 'MAX30', 'MAX60',
    'MIN5', 'MIN10', 'MIN20', 'MIN30', 'MIN60',
    'QTLU5', 'QTLU10', 'QTLU20', 'QTLU30', 'QTLU60',
    'QTLD5', 'QTLD10', 'QTLD20', 'QTLD30', 'QTLD60',
    'RANK5', 'RANK10', 'RANK20', 'RANK30', 'RANK60',
    'RSV5', 'RSV10', 'RSV20', 'RSV30', 'RSV60',
    'IMAX5', 'IMAX10', 'IMAX20', 'IMAX30', 'IMAX60',
    'IMIN5', 'IMIN10', 'IMIN20', 'IMIN30', 'IMIN60',
    'IMXD5', 'IMXD10', 'IMXD20', 'IMXD30', 'IMXD60',
    'CORR5', 'CORR10', 'CORR20', 'CORR30', 'CORR60',
    'CORD5', 'CORD10', 'CORD20', 'CORD30', 'CORD60',
    'CNTP5', 'CNTP10', 'CNTP20', 'CNTP30', 'CNTP60',
    'CNTN5', 'CNTN10', 'CNTN20', 'CNTN30', 'CNTN60',
    'CNTD5', 'CNTD10', 'CNTD20', 'CNTD30', 'CNTD60',
    'SUMP5', 'SUMP10', 'SUMP20', 'SUMP30', 'SUMP60',
    'SUMN5', 'SUMN10', 'SUMN20', 'SUMN30', 'SUMN60',
    'SUMD5', 'SUMD10', 'SUMD20', 'SUMD30', 'SUMD60',
    'VMA5', 'VMA10', 'VMA20', 'VMA30', 'VMA60',
    'VSTD5', 'VSTD10', 'VSTD20', 'VSTD30', 'VSTD60',
    'WVMA5', 'WVMA10', 'WVMA20', 'WVMA30', 'WVMA60',
    'VSUMP5', 'VSUMP10', 'VSUMP20', 'VSUMP30', 'VSUMP60',
    'VSUMN5', 'VSUMN10', 'VSUMN20', 'VSUMN30', 'VSUMN60',
    'VSUMD5', 'VSUMD10', 'VSUMD20', 'VSUMD30', 'VSUMD60'
]
# Import-time guard against accidental edits to the list above.
assert len(ALPHA158_COLS) == 158, f"Expected 158 alpha158 cols, got {len(ALPHA158_COLS)}"
|
||||
|
||||
|
||||
def extract_robust_zscore_params(proc_list_path: str) -> dict:
    """
    Pull the fitted RobustZScoreNorm statistics out of a Qlib proc_list.proc pickle.

    Args:
        proc_list_path: Path to the proc_list.proc pickle file.

    Returns:
        Dictionary with keys:
            mean_train: numpy array of per-feature robust means
            std_train: numpy array of per-feature robust stds
            fit_start_time / fit_end_time: fit window, when recorded
            fields_group: field groups the processor was fitted on, when recorded

    Raises:
        ValueError: if no RobustZScoreNorm processor is present in the pickle.
    """
    print(f"Loading proc_list.proc from: {proc_list_path}")

    with open(proc_list_path, 'rb') as fh:
        processors = pkl.load(fh)

    print(f"Loaded {len(processors)} processors from proc_list")

    # Walk every processor and remember the RobustZScoreNorm instance.
    # If several exist, the last one wins (matches the original behavior).
    target = None
    for idx, processor in enumerate(processors):
        cls_name = type(processor).__name__
        print(f" [{idx}] {cls_name}")
        if cls_name == "RobustZScoreNorm":
            target = processor

    if target is None:
        raise ValueError("RobustZScoreNorm processor not found in proc_list")

    # Optional attributes are read defensively; absent ones become None.
    params = {
        'mean_train': target.mean_train,
        'std_train': target.std_train,
        'fit_start_time': getattr(target, 'fit_start_time', None),
        'fit_end_time': getattr(target, 'fit_end_time', None),
        'fields_group': getattr(target, 'fields_group', None),
    }

    print(f"\nExtracted RobustZScoreNorm parameters:")
    print(f" mean_train shape: {params['mean_train'].shape}, dtype: {params['mean_train'].dtype}")
    print(f" std_train shape: {params['std_train'].shape}, dtype: {params['std_train'].dtype}")
    print(f" fit_start_time: {params['fit_start_time']}")
    print(f" fit_end_time: {params['fit_end_time']}")
    print(f" fields_group: {params['fields_group']}")

    return params
|
||||
|
||||
|
||||
def build_feature_column_names() -> list:
    """
    Assemble the ordered list of all 330 feature column names.

    Order (matching Qlib's output):
        1. alpha158 neutralized   (158, "<name>_ntrl")
        2. alpha158 raw           (158)
        3. market_ext neutralized (7, "<name>_ntrl")
        4. market_ext raw         (7)

    The market_ext group is 4 base columns plus their 4 diffs, minus
    log_size_diff (dropped by ColumnRemover upstream), leaving 7.
    """
    # Alpha158: neutralized names first, then raw names.
    ntrl_alpha = [f"{name}_ntrl" for name in ALPHA158_COLS]
    raw_alpha = list(ALPHA158_COLS)

    # market_ext: base + diff columns, with log_size_diff filtered out.
    base_names = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']
    diff_names = ['turnover_diff', 'free_turnover_diff', 'log_size_diff', 'con_rating_strength_diff']
    kept_names = [name for name in base_names + diff_names if name != 'log_size_diff']

    ntrl_market = [f"{name}_ntrl" for name in kept_names]
    raw_market = list(kept_names)

    # Full column list in Qlib order.
    ordered = ntrl_alpha + raw_alpha + ntrl_market + raw_market

    print(f"\nBuilt feature column names:")
    print(f" alpha158_ntrl: {len(ntrl_alpha)} features")
    print(f" alpha158_raw: {len(raw_alpha)} features")
    print(f" market_ext_ntrl: {len(ntrl_market)} features")
    print(f" market_ext_raw: {len(raw_market)} features")
    print(f" Total: {len(ordered)} features")

    return ordered
|
||||
|
||||
|
||||
def save_parameters(
    params: dict,
    feature_cols: list,
    output_dir: str,
    version: str,
    source_file: str = None
):
    """
    Save extracted parameters to output directory.

    Creates:
    - mean_train.npy
    - std_train.npy
    - metadata.json

    Args:
        params: Output of extract_robust_zscore_params(); must contain
            'mean_train' and 'std_train' numpy arrays plus fit metadata.
        feature_cols: Ordered feature names matching the parameter vectors;
            its length is recorded as the total feature count.
        output_dir: Base directory that holds one subdirectory per version.
        version: Version name; files are written to output_dir/version/.
        source_file: Optional path of the proc_list.proc the parameters were
            extracted from, recorded in metadata.json (default: "unknown").
            Backward-compatible addition: callers passing four positional
            arguments behave as before.

    Returns:
        Path to the version directory containing the saved files.
    """
    output_path = Path(output_dir) / version
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"\nSaving parameters to: {output_path}")

    # Save mean_train.npy
    mean_path = output_path / "mean_train.npy"
    np.save(mean_path, params['mean_train'])
    print(f" Saved mean_train to: {mean_path}")

    # Save std_train.npy
    std_path = output_path / "std_train.npy"
    np.save(std_path, params['std_train'])
    print(f" Saved std_train to: {std_path}")

    # fields_group may be a single string (e.g. "feature"); wrap it in a list
    # rather than letting list() explode it into individual characters.
    fields_group = params.get('fields_group')
    if fields_group is None:
        fields_group_meta = None
    elif isinstance(fields_group, str):
        fields_group_meta = [fields_group]
    else:
        fields_group_meta = list(fields_group)

    # Build metadata
    metadata = {
        'version': version,
        'created_at': datetime.now().isoformat(),
        # Record the actual source pickle path. (Previously this field was
        # mistakenly populated with fit_start_time.)
        'source_file': str(source_file) if source_file is not None else 'unknown',
        'fit_start_time': str(params['fit_start_time']),
        'fit_end_time': str(params['fit_end_time']),
        'fields_group': fields_group_meta,
        'feature_columns': {
            'alpha158_ntrl': ALPHA158_COLS,  # Store base names, _ntrl is implied
            'alpha158_raw': ALPHA158_COLS,
            'market_ext_ntrl': [
                'turnover', 'free_turnover', 'log_size', 'con_rating_strength',
                'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'
            ],
            'market_ext_raw': [
                'turnover', 'free_turnover', 'log_size', 'con_rating_strength',
                'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'
            ],
        },
        'feature_count': {
            'alpha158_ntrl': 158,
            'alpha158_raw': 158,
            'market_ext_ntrl': 7,
            'market_ext_raw': 7,
            # Derived from the actual column list instead of hard-coding 330.
            'total': len(feature_cols)
        },
        'parameter_shapes': {
            'mean_train': list(params['mean_train'].shape),
            'std_train': list(params['std_train'].shape)
        }
    }

    # Save metadata.json
    metadata_path = output_path / "metadata.json"
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)
    print(f" Saved metadata to: {metadata_path}")

    return output_path
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, extract parameters, verify, and save."""
    import argparse

    arg_parser = argparse.ArgumentParser(
        description="Extract RobustZScoreNorm parameters from Qlib's proc_list.proc"
    )
    arg_parser.add_argument(
        '--proc-list',
        type=str,
        default=DEFAULT_PROC_LIST_PATH,
        help=f"Path to proc_list.proc (default: {DEFAULT_PROC_LIST_PATH})"
    )
    arg_parser.add_argument(
        '--version',
        type=str,
        default=DEFAULT_VERSION,
        help=f"Version name for output directory (default: {DEFAULT_VERSION})"
    )
    arg_parser.add_argument(
        '--output-dir',
        type=str,
        default=str(Path(__file__).parent.parent / "data" / "robust_zscore_params"),
        help="Output directory for parameter files"
    )

    opts = arg_parser.parse_args()

    rule = "=" * 80
    print(rule)
    print("Extract Qlib RobustZScoreNorm Parameters")
    print(rule)
    print(f"Source: {opts.proc_list}")
    print(f"Version: {opts.version}")
    print(f"Output: {opts.output_dir}")
    print()

    # Step 1: pull mean/std out of the pickled processor list.
    extracted = extract_robust_zscore_params(opts.proc_list)

    # Step 2: reconstruct the expected feature column order.
    columns = build_feature_column_names()

    # Step 3: sanity-check parameter length against the expected columns.
    if extracted['mean_train'].shape[0] != len(columns):
        print(f"\nWARNING: Parameter shape mismatch!")
        print(f" Expected: {len(columns)} features")
        print(f" Got: {extracted['mean_train'].shape[0]} parameters")
    else:
        print(f"\n✓ Parameter shape matches feature count ({len(columns)})")

    # Step 4: write .npy files and metadata.json.
    saved_to = save_parameters(extracted, columns, opts.output_dir, opts.version)

    print("\n" + rule)
    print("Extraction complete!")
    print(rule)
    print(f"Output files:")
    print(f" {saved_to}/mean_train.npy")
    print(f" {saved_to}/std_train.npy")
    print(f" {saved_to}/metadata.json")
    print()
    print("Usage in Polars RobustZScoreNorm:")
    print(f' norm = RobustZScoreNorm.from_version("{opts.version}")')
    print(rule)
|
||||
|
||||
|
||||
# Script entry point: run the extraction with command-line arguments.
if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in new issue