Extract RobustZScoreNorm parameters and add from_version() method

- Add extract_qlib_params.py script to extract pre-fitted mean/std parameters
  from Qlib's proc_list.proc and save as reusable .npy files with metadata.json

- Add RobustZScoreNorm.from_version() class method to load saved parameters
  by version name, supporting multiple parameter versions coexistence

- Update dump_polars_dataset.py to use from_version() instead of loading
  parameters directly from proc_list.proc

- Update generate_beta_embedding.py to use qshare's filter_instruments for
  stock universe filtering

- Save parameters to data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/
  with 330 features (158 alpha158_ntrl + 158 alpha158_raw + 7 market_ext_ntrl + 7 market_ext_raw)

- Update README.md with documentation for parameter extraction and usage

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
master
guofu 3 days ago
parent 8bd36c1939
commit 89bd1a528e

@ -0,0 +1,753 @@
"""
Polars-based data processors for financial feature transformation.
This module provides Polars implementations of Qlib-style data processors
used in the data_ops pipeline. Each processor follows a consistent interface:
- Takes a Polars DataFrame as input
- Returns a transformed Polars DataFrame
Processors are organized by category:
- Feature Engineering: DiffProcessor
- Flag Injection: FlagMarketInjector, FlagSTInjector
- Column Operations: ColumnRemover, FlagToOnehot
- Normalization: IndusNtrlInjector, RobustZScoreNorm
- Data Cleaning: Fillna
"""
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Dict
import numpy as np
import polars as pl
# Public API of this module, grouped by processor category.
__all__ = [
    # Feature Engineering
    'DiffProcessor',
    # Flag Injection
    'FlagMarketInjector',
    'FlagSTInjector',
    # Column Operations
    'ColumnRemover',
    'FlagToOnehot',
    # Normalization
    'IndusNtrlInjector',
    'RobustZScoreNorm',
    # Data Cleaning
    'Fillna',
]
# =============================================================================
# Feature Engineering Processors
# =============================================================================
class DiffProcessor:
    """
    Compute period-over-period differences within each instrument group.

    For every requested column, the value difference over ``periods`` rows is
    taken per instrument and stored under a new ``{col}_{suffix}`` column.
    Columns absent from the input frame are silently skipped.

    Attributes:
        columns: Columns to difference.
        suffix: Suffix appended to the generated column names (default: 'diff').
        periods: Number of rows to look back (default: 1).

    Example:
        >>> processor = DiffProcessor(['turnover', 'log_size'])
        >>> df = processor.process(df)
        >>> # Creates: turnover_diff, log_size_diff
    """

    def __init__(
        self,
        columns: List[str],
        suffix: str = 'diff',
        periods: int = 1
    ):
        self.columns = columns
        self.suffix = suffix
        self.periods = periods

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Return *df* with one diff column added per configured column.

        Args:
            df: Input DataFrame with datetime and instrument columns.

        Returns:
            DataFrame with original columns plus the suffixed diff columns.
        """
        # A deterministic per-instrument time ordering is required for diff().
        df = df.sort(['instrument', 'datetime'])
        for name in (c for c in self.columns if c in df.columns):
            df = df.with_columns([
                pl.col(name)
                .diff(self.periods)
                .over('instrument')
                .alias(f"{name}_{self.suffix}")
            ])
        return df

    def __repr__(self) -> str:
        return f"DiffProcessor(columns={self.columns}, suffix='{self.suffix}', periods={self.periods})"
# =============================================================================
# Flag Injection Processors
# =============================================================================
class FlagMarketInjector:
    """
    Inject market-segment flag columns derived from instrument codes.

    Accepts codes with or without an exchange prefix (SH600000 / 600000,
    SZ000001 / 000001, ...). Classification:
    - market_0: Main board — codes starting with '6' excluding 688/689,
      or starting with '0' (主板)
    - market_1: STAR / ChiNext — codes starting with 688, 689, 300, 301
      (科创板/创业板)
    Codes matching neither rule (e.g. 新三板/北交所 NE4xxxx, NE8xxx) get 0
    in both columns.

    Example:
        >>> processor = FlagMarketInjector()
        >>> df = processor.process(df)
        >>> # Adds: market_0, market_1 (both int8)
    """

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Return *df* with market_0 / market_1 classification columns added.

        Args:
            df: Input DataFrame with an instrument column.

        Returns:
            DataFrame with market_0, market_1 int8 columns appended.
        """
        # Normalize to the numeric part: SH600000 -> 600000, SZ000001 -> 000001.
        code = pl.col('instrument').cast(pl.String).str.replace_all("^(SH|SZ|NE)", "")
        lead1 = code.str.slice(0, 1)
        lead3 = code.str.slice(0, 3)

        # Main board: 6xxxxx minus the STAR prefixes, plus 0xxxxx.
        main_sh = (lead1 == '6') & ~(lead3 == '688') & ~(lead3 == '689')
        main_sz = lead1 == '0'
        # STAR market and ChiNext.
        star = (lead3 == '688') | (lead3 == '689')
        chinext = (lead3 == '300') | (lead3 == '301')

        return df.with_columns([
            (main_sh | main_sz).cast(pl.Int8).alias('market_0'),  # 主板
            (star | chinext).cast(pl.Int8).alias('market_1'),     # 科创板 + 创业板
        ])

    def __repr__(self) -> str:
        return "FlagMarketInjector()"
class FlagSTInjector:
    """
    Inject an ST (Special Treatment) indicator column.

    Builds ``IsST`` from the ST_S and ST_Y flags: 1 when either is truthy,
    0 otherwise. When neither flag column exists, an all-zero placeholder
    column is emitted instead.

    Attributes:
        mark_st_as: Value used to mark ST stocks (default: 1).

    Example:
        >>> processor = FlagSTInjector()
        >>> df = processor.process(df)
        >>> # Adds: IsST (int8)
    """

    def __init__(self, mark_st_as: int = 1):
        self.mark_st_as = mark_st_as

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Return *df* with an IsST column appended.

        Args:
            df: Input DataFrame, optionally carrying ST_S and ST_Y columns.

        Returns:
            DataFrame with the IsST int8 column added.
        """
        if 'ST_S' in df.columns and 'ST_Y' in df.columns:
            # strict=False tolerates non-boolean encodings of the flags.
            st_s = pl.col('ST_S').cast(pl.Boolean, strict=False)
            st_y = pl.col('ST_Y').cast(pl.Boolean, strict=False)
            is_st = (st_s | st_y).cast(pl.Int8)
        else:
            # No ST information available: emit an all-zero placeholder.
            is_st = pl.lit(0).cast(pl.Int8)
        return df.with_columns([is_st.alias('IsST')])

    def __repr__(self) -> str:
        return f"FlagSTInjector(mark_st_as={self.mark_st_as})"
# =============================================================================
# Column Operation Processors
# =============================================================================
class ColumnRemover:
    """
    Drop a configured set of columns from the DataFrame.

    Columns that are not present in the input are ignored.

    Attributes:
        columns_to_remove: Column names to drop.

    Example:
        >>> processor = ColumnRemover(['log_size_diff', 'IsN', 'IsZt', 'IsDt'])
        >>> df = processor.process(df)
        >>> # Removes specified columns
    """

    def __init__(self, columns_to_remove: List[str]):
        self.columns_to_remove = columns_to_remove

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Return *df* without the configured columns.

        Args:
            df: Input DataFrame.

        Returns:
            DataFrame with any present configured columns dropped.
        """
        present = [name for name in self.columns_to_remove if name in df.columns]
        return df.drop(present) if present else df

    def __repr__(self) -> str:
        return f"ColumnRemover(columns_to_remove={self.columns_to_remove})"
class FlagToOnehot:
    """
    Collapse one-hot flag columns into a single index column.

    Builds ``indus_idx`` holding the index of the flag that is set, with -1
    as the default when no flag is set. If multiple flags are set, the flag
    appearing LATER in ``flag_columns`` wins, because it is the outermost
    when/then in the generated expression chain. The original flag columns
    are dropped afterwards.

    Attributes:
        flag_columns: Ordered list of boolean flag column names.

    Example:
        >>> processor = FlagToOnehot(['gds_CC10', 'gds_CC11', ...])
        >>> df = processor.process(df)
        >>> # Adds: indus_idx; drops the flag columns
    """

    def __init__(self, flag_columns: List[str]):
        self.flag_columns = flag_columns

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Return *df* with indus_idx added and the flag columns removed.

        Args:
            df: Input DataFrame with the one-hot flag columns.

        Returns:
            DataFrame with the indus_idx column, original flags dropped.
        """
        expr = pl.lit(-1)  # default: no industry flag set
        for position, name in enumerate(self.flag_columns):
            if name in df.columns:
                # Each iteration wraps the previous chain, so later columns
                # take priority over earlier ones.
                expr = pl.when(pl.col(name) == 1).then(position).otherwise(expr)
        df = df.with_columns([expr.alias('indus_idx')])

        present = [name for name in self.flag_columns if name in df.columns]
        return df.drop(present) if present else df

    def __repr__(self) -> str:
        return f"FlagToOnehot(flag_columns={len(self.flag_columns)} columns)"
# =============================================================================
# Normalization Processors
# =============================================================================
class IndusNtrlInjector:
    """
    Industry neutralization for features.

    For each requested feature, the cross-sectional mean of its industry
    group (grouped by datetime and indus_idx) is subtracted from the raw
    value. Results are written to new ``{col}{suffix}`` columns while the
    original columns are kept, matching Qlib's cal_indus_ntrl behavior.

    Attributes:
        feature_cols: Feature columns to neutralize.
        suffix: Suffix for the neutralized column names (default: '_ntrl').

    Example:
        >>> processor = IndusNtrlInjector(['KMID', 'KLEN'], suffix='_ntrl')
        >>> df = processor.process(df)
        >>> # Adds: KMID_ntrl, KLEN_ntrl
    """

    def __init__(self, feature_cols: List[str], suffix: str = '_ntrl'):
        self.feature_cols = feature_cols
        self.suffix = suffix

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Return *df* with one neutralized column per existing feature.

        Args:
            df: Input DataFrame with the feature columns and indus_idx.

        Returns:
            DataFrame with the suffixed neutralized columns added.
        """
        for name in (c for c in self.feature_cols if c in df.columns):
            # Per-datetime industry mean; over() keeps row alignment intact.
            group_mean = pl.col(name).mean().over(['datetime', 'indus_idx'])
            df = df.with_columns([
                (pl.col(name) - group_mean).alias(f"{name}{self.suffix}")
            ])
        return df

    def __repr__(self) -> str:
        return f"IndusNtrlInjector(feature_cols={len(self.feature_cols)} columns, suffix='{self.suffix}')"
@dataclass
class RobustZScoreNorm:
    """
    Robust z-score normalization using median/MAD.

    Applies the transformation: (x - median) / (1.4826 * MAD)
    where MAD = median(|x - median|).

    Supports two modes:
    1. Per-datetime computation (default): median/MAD are computed
       cross-sectionally for each datetime.
    2. Pre-fitted parameters: uses provided mean/std arrays extracted from a
       Qlib processor (fit once, applied to every date).

    Attributes:
        feature_cols: List of feature columns to normalize.
        clip_range: Clip normalized values to this range (default: (-3, 3)).
        use_qlib_params: Use pre-fitted parameters (default: False).
        qlib_mean: Pre-fitted mean array (required if use_qlib_params=True).
        qlib_std: Pre-fitted std array (required if use_qlib_params=True).

    Example:
        # Using pre-fitted Qlib parameters
        >>> processor = RobustZScoreNorm(
        ...     feature_cols=['KMID', 'KLEN'],
        ...     use_qlib_params=True,
        ...     qlib_mean=mean_array,
        ...     qlib_std=std_array
        ... )
        >>> df = processor.process(df)

        # Loading parameters from saved version
        >>> processor = RobustZScoreNorm.from_version("csiallx_feature2_ntrla_flag_pnlnorm")
        >>> df = processor.process(df)
    """
    feature_cols: List[str]
    clip_range: Tuple[float, float] = (-3.0, 3.0)
    use_qlib_params: bool = False
    qlib_mean: Optional[np.ndarray] = None
    qlib_std: Optional[np.ndarray] = None

    @classmethod
    def from_version(
        cls,
        version: str,
        feature_cols: Optional[List[str]] = None,
        clip_range: Tuple[float, float] = (-3.0, 3.0),
        # FIX: was annotated `str = None`; None is a valid (default) value.
        params_dir: Optional[str] = None
    ) -> "RobustZScoreNorm":
        """
        Create a RobustZScoreNorm instance from saved parameters by version name.

        This loads pre-extracted mean_train and std_train from the parameter
        directory structure:
            {params_dir}/{version}/
                mean_train.npy
                std_train.npy
                metadata.json

        Args:
            version: Version name (e.g., "csiallx_feature2_ntrla_flag_pnlnorm")
            feature_cols: Optional list of feature columns. If None, the order
                is taken from metadata.json (alpha158_ntrl + alpha158_raw +
                market_ext_ntrl + market_ext_raw).
            clip_range: Clip range for normalized values (default: (-3, 3))
            params_dir: Base directory for parameter versions. If None, uses:
                stock_1d/d033/alpha158_beta/data/robust_zscore_params/

        Returns:
            RobustZScoreNorm instance with pre-fitted parameters loaded.

        Raises:
            FileNotFoundError: If version directory or parameter files not found.
            ValueError: If feature columns cannot be determined, or their count
                does not match the parameter shape.

        Example:
            >>> processor = RobustZScoreNorm.from_version(
            ...     "csiallx_feature2_ntrla_flag_pnlnorm"
            ... )
            >>> df = processor.process(df)
        """
        import json
        from pathlib import Path

        if params_dir is None:
            # Default layout: go from cta_1d/src/processors/ up to the repo
            # root, then into stock_1d/d033/alpha158_beta/data/.
            base = Path(__file__).parent.parent.parent.parent
            params_dir = base / "stock_1d" / "d033" / "alpha158_beta" / "data" / "robust_zscore_params"
        else:
            params_dir = Path(params_dir)

        version_dir = params_dir / version
        if not version_dir.exists():
            raise FileNotFoundError(
                f"Version directory not found: {version_dir}\n"
                f"Available versions should be in: {params_dir}"
            )

        mean_path = version_dir / "mean_train.npy"
        if not mean_path.exists():
            raise FileNotFoundError(f"mean_train.npy not found: {mean_path}")
        mean_train = np.load(mean_path)

        std_path = version_dir / "std_train.npy"
        if not std_path.exists():
            raise FileNotFoundError(f"std_train.npy not found: {std_path}")
        std_train = np.load(std_path)

        # Derive feature column names from metadata.json when not provided.
        metadata_path = version_dir / "metadata.json"
        if feature_cols is None and metadata_path.exists():
            metadata = json.loads(metadata_path.read_text(encoding='utf-8'))
            groups = metadata.get('feature_columns', {})
            # Qlib ordering: ntrl block, then raw block, per feature group.
            feature_cols = (
                [f"{c}_ntrl" for c in groups.get('alpha158_ntrl', [])]
                + list(groups.get('alpha158_raw', []))
                + [f"{c}_ntrl" for c in groups.get('market_ext_ntrl', [])]
                + list(groups.get('market_ext_raw', []))
            )

        # FIX: fail fast instead of constructing an instance whose process()
        # would crash later on feature_cols=None.
        if feature_cols is None:
            raise ValueError(
                f"feature_cols not provided and metadata.json missing in {version_dir}; "
                "cannot determine feature column order"
            )

        # FIX: validate whenever feature_cols is known (the original skipped
        # this check for an empty list due to truthiness testing).
        expected_count = len(mean_train)
        if len(feature_cols) != expected_count:
            raise ValueError(
                f"Feature column count ({len(feature_cols)}) does not match "
                f"parameter shape ({expected_count})"
            )

        print(f"Loaded RobustZScoreNorm parameters from version '{version}':")
        print(f"  mean_train shape: {mean_train.shape}")
        print(f"  std_train shape: {std_train.shape}")
        print(f"  feature_cols: {len(feature_cols)}")

        return cls(
            feature_cols=feature_cols,
            clip_range=clip_range,
            use_qlib_params=True,
            qlib_mean=mean_train,
            qlib_std=std_train
        )

    def __post_init__(self):
        """Validate parameters after initialization."""
        if self.use_qlib_params:
            if self.qlib_mean is None or self.qlib_std is None:
                raise ValueError(
                    "Must provide qlib_mean and qlib_std when use_qlib_params=True"
                )
            # Normalize to numpy arrays so indexing below is uniform.
            self.qlib_mean = np.asarray(self.qlib_mean)
            self.qlib_std = np.asarray(self.qlib_std)

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Apply robust z-score normalization.

        Args:
            df: Input DataFrame with the feature columns.

        Returns:
            DataFrame with the configured columns normalized in place
            (columns are overwritten, not suffixed).
        """
        # Only touch columns that actually exist in the frame.
        existing_cols = [c for c in self.feature_cols if c in df.columns]
        if self.use_qlib_params:
            # Fit-once mode: apply the pre-fitted (mean, std) per column.
            # Columns beyond the parameter length are left untouched.
            for i, col in enumerate(existing_cols):
                if i < len(self.qlib_mean):
                    mean_val = float(self.qlib_mean[i])
                    std_val = float(self.qlib_std[i])
                    # 1e-8 guards against division by zero for constant features.
                    df = df.with_columns([
                        ((pl.col(col) - mean_val) / (std_val + 1e-8))
                        .clip(self.clip_range[0], self.clip_range[1])
                        .alias(col)
                    ])
        else:
            # Per-datetime mode: robust z-score via median/MAD computed
            # cross-sectionally for each datetime.
            for col in existing_cols:
                median_col = f"__median_{col}"
                df = df.with_columns([
                    pl.col(col).median().over('datetime').alias(median_col)
                ])
                abs_dev_col = f"__absdev_{col}"
                df = df.with_columns([
                    (pl.col(col) - pl.col(median_col)).abs().alias(abs_dev_col)
                ])
                mad_col = f"__mad_{col}"
                df = df.with_columns([
                    pl.col(abs_dev_col).median().over('datetime').alias(mad_col)
                ])
                # 1.4826 scales MAD to a consistent std estimate under normality.
                df = df.with_columns([
                    ((pl.col(col) - pl.col(median_col)) / (1.4826 * pl.col(mad_col) + 1e-8))
                    .clip(self.clip_range[0], self.clip_range[1])
                    .alias(col)
                ])
                # Drop the temporary working columns.
                df = df.drop([median_col, abs_dev_col, mad_col])
        return df
# =============================================================================
# Data Cleaning Processors
# =============================================================================
class Fillna:
    """
    Fill NaN/null values in numeric columns with a fixed value.

    Non-numeric columns and columns absent from the frame are skipped;
    only float and (un)signed integer dtypes are processed.

    Attributes:
        fill_value: Value used to replace NaN/null (default: 0.0).

    Example:
        >>> processor = Fillna(fill_value=0.0)
        >>> df = processor.process(df, ['KMID', 'KLEN'])
    """

    def __init__(self, fill_value: float = 0.0):
        self.fill_value = fill_value

    def process(self, df: pl.DataFrame, columns: List[str]) -> pl.DataFrame:
        """
        Return *df* with NaN/null values filled in the given columns.

        Args:
            df: Input DataFrame.
            columns: Columns in which to fill NaN/null values.

        Returns:
            DataFrame with NaN and null values replaced by fill_value.
        """
        numeric_dtypes = (pl.Float32, pl.Float64, pl.Int32, pl.Int64,
                          pl.UInt32, pl.UInt64, pl.UInt16, pl.UInt8)
        for name in columns:
            if name not in df.columns:
                continue
            if df[name].dtype in numeric_dtypes:
                # fill_null covers missing values; fill_nan covers float NaN.
                df = df.with_columns([
                    pl.col(name).fill_null(self.fill_value).fill_nan(self.fill_value)
                ])
        return df

    def __repr__(self) -> str:
        return f"Fillna(fill_value={self.fill_value})"
# =============================================================================
# Processor Pipeline Utilities
# =============================================================================
def create_processor_pipeline(
    alpha158_cols: List[str],
    market_ext_base: List[str],
    market_flag_cols: List[str],
    industry_flag_cols: List[str],
    columns_to_remove: Optional[List[str]] = None,
    ntrl_suffix: str = '_ntrl'
) -> List:
    """
    Build the ordered processor list for the standard pipeline.

    Execution order:
        1. DiffProcessor - adds diff features for market_ext columns
        2. FlagMarketInjector - adds market_0, market_1
        3. FlagSTInjector - adds IsST
        4. ColumnRemover - removes the configured columns
        5. FlagToOnehot - collapses industry flags into indus_idx
        6. IndusNtrlInjector (x2) - neutralizes alpha158 and market_ext

    RobustZScoreNorm and Fillna require fitted parameters and must be
    appended by the caller.

    Args:
        alpha158_cols: Alpha158 feature names.
        market_ext_base: Market extension base columns.
        market_flag_cols: Market flag columns (currently unused here; kept
            for interface symmetry with get_final_feature_columns).
        industry_flag_cols: Industry flag columns.
        columns_to_remove: Columns to drop
            (default: ['log_size_diff', 'IsN', 'IsZt', 'IsDt']).
        ntrl_suffix: Suffix for neutralized columns.

    Returns:
        List of processor instances in execution order.
    """
    drop_cols = (['log_size_diff', 'IsN', 'IsZt', 'IsDt']
                 if columns_to_remove is None else columns_to_remove)
    pipeline = [
        DiffProcessor(market_ext_base),
        FlagMarketInjector(),
        FlagSTInjector(),
        ColumnRemover(drop_cols),
        FlagToOnehot(industry_flag_cols),
        IndusNtrlInjector(alpha158_cols, suffix=ntrl_suffix),
        IndusNtrlInjector(market_ext_base, suffix=ntrl_suffix),
    ]
    return pipeline
def get_final_feature_columns(
    alpha158_cols: List[str],
    market_ext_base: List[str],
    market_flag_cols: List[str],
    columns_to_remove: Optional[List[str]] = None,
    ntrl_suffix: str = '_ntrl'
) -> Dict[str, List[str]]:
    """
    Predict the feature-column layout produced by the processor pipeline.

    Useful for determining expected model input dimensions and for verifying
    feature ordering ahead of time.

    Args:
        alpha158_cols: Alpha158 feature names.
        market_ext_base: Market extension base columns (before Diff).
        market_flag_cols: Market flag columns (before ColumnRemover).
        columns_to_remove: Columns dropped by ColumnRemover
            (default: ['log_size_diff', 'IsN', 'IsZt', 'IsDt']).
        ntrl_suffix: Suffix for neutralized columns.

    Returns:
        Dictionary with feature groups, including:
            - 'norm_feature_cols': features to normalize, in order
            - 'market_flag_cols': flag columns after processing
            - 'all_feature_cols': everything including indus_idx
    """
    if columns_to_remove is None:
        columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt']
    removed = set(columns_to_remove)

    # DiffProcessor appends a *_diff twin for every market_ext base column;
    # ColumnRemover then strips the configured names.
    ext_with_diff = list(market_ext_base) + [f"{c}_diff" for c in market_ext_base]
    ext_final = [c for c in ext_with_diff if c not in removed]

    # Flags survive ColumnRemover, then gain market_0/market_1 and IsST.
    flags_final = [c for c in market_flag_cols if c not in removed]
    flags_final += ['market_0', 'market_1', 'IsST']

    alpha_ntrl = [f"{c}{ntrl_suffix}" for c in alpha158_cols]
    ext_ntrl = [f"{c}{ntrl_suffix}" for c in ext_final]
    # Qlib ordering: ntrl block first, then raw block, for each group.
    norm_cols = alpha_ntrl + alpha158_cols + ext_ntrl + ext_final

    return {
        'alpha158_cols': alpha158_cols,
        'alpha158_ntrl_cols': alpha_ntrl,
        'market_ext_cols': ext_final,
        'market_ext_ntrl_cols': ext_ntrl,
        'market_flag_cols': flags_final,
        'norm_feature_cols': norm_cols,
        'all_feature_cols': norm_cols + flags_final + ['indus_idx'],
    }

@ -19,18 +19,27 @@ stock_1d/d033/alpha158_beta/
│ ├── fetch_predictions.py # Fetch original predictions from DolphinDB
│ ├── predict_with_embedding.py # Generate predictions using beta embeddings
│ ├── compare_predictions.py # Compare 0_7 vs 0_7_beta predictions
│ └── dump_polars_dataset.py # Dump raw and processed datasets using polars pipeline
│ ├── dump_polars_dataset.py # Dump raw and processed datasets using polars pipeline
│ └── extract_qlib_params.py # Extract RobustZScoreNorm parameters from Qlib proc_list
├── src/ # Source modules
│ └── qlib_loader.py # Qlib data loader with configurable date range
├── config/ # Configuration files
│ └── handler.yaml # Modified handler with configurable end date
└── data/ # Data files (gitignored)
├── embedding_0_7_beta.parquet
├── predictions_beta_embedding.parquet
├── original_predictions_0_7.parquet
├── actual_returns.parquet
├── raw_data_*.pkl # Raw data before preprocessing
└── processed_data_*.pkl # Processed data after preprocessing
├── data/ # Data files (gitignored)
│ ├── robust_zscore_params/ # Pre-fitted normalization parameters
│ │ └── csiallx_feature2_ntrla_flag_pnlnorm/
│ │ ├── mean_train.npy
│ │ ├── std_train.npy
│ │ └── metadata.json
│ ├── embedding_0_7_beta.parquet
│ ├── predictions_beta_embedding.parquet
│ ├── original_predictions_0_7.parquet
│ ├── actual_returns.parquet
│ ├── raw_data_*.pkl # Raw data before preprocessing
│ └── processed_data_*.pkl # Processed data after preprocessing
└── data_polars/ # Polars-generated datasets (gitignored)
├── raw_data_*.pkl
└── processed_data_*.pkl
```
## Data Loading with Configurable Date Range
@ -122,7 +131,7 @@ This script:
- ColumnRemover (removes log_size_diff, IsN, IsZt, IsDt)
- FlagToOnehot (converts 29 industry flags to indus_idx)
- IndusNtrlInjector (industry neutralization)
- RobustZScoreNorm (using pre-fitted qlib parameters)
- RobustZScoreNorm (using pre-fitted qlib parameters via `from_version()`)
- Fillna (fill NaN with 0)
4. Saves processed data to `data_polars/processed_data_*.pkl`
@ -133,6 +142,38 @@ Output structure:
- Processed data: 342 columns (316 feature + 14 feature_ext + 11 feature_flag + 1 indus_idx)
- VAE input dimension: 341 (excluding indus_idx)
### RobustZScoreNorm Parameter Extraction
The pipeline uses pre-fitted normalization parameters extracted from Qlib's `proc_list.proc` file. These parameters are stored in `data/robust_zscore_params/` and can be loaded using the `RobustZScoreNorm.from_version()` method.
**Extract parameters from Qlib proc_list:**
```bash
python scripts/extract_qlib_params.py --version csiallx_feature2_ntrla_flag_pnlnorm
```
This creates:
- `data/robust_zscore_params/{version}/mean_train.npy` - Pre-fitted mean parameters (330,)
- `data/robust_zscore_params/{version}/std_train.npy` - Pre-fitted std parameters (330,)
- `data/robust_zscore_params/{version}/metadata.json` - Feature column names and metadata
**Use in Polars processors:**
```python
from cta_1d.src.processors import RobustZScoreNorm
# Load pre-fitted parameters by version name
processor = RobustZScoreNorm.from_version("csiallx_feature2_ntrla_flag_pnlnorm")
# Apply normalization to DataFrame
df = processor.process(df)
```
**Parameter details:**
- Fit period: 2013-01-01 to 2018-12-31
- Feature count: 330 (158 alpha158_ntrl + 158 alpha158_raw + 7 market_ext_ntrl + 7 market_ext_raw)
- Fields: ['feature', 'feature_ext']
## Workflow
### 1. Generate Beta Embeddings

@ -0,0 +1,366 @@
{
"version": "csiallx_feature2_ntrla_flag_pnlnorm",
"created_at": "2026-03-01T12:18:01.969109",
"source_file": "2013-01-01",
"fit_start_time": "2013-01-01",
"fit_end_time": "2018-12-31",
"fields_group": [
"feature",
"feature_ext"
],
"feature_columns": {
"alpha158_ntrl": [
"KMID",
"KLEN",
"KMID2",
"KUP",
"KUP2",
"KLOW",
"KLOW2",
"KSFT",
"KSFT2",
"OPEN0",
"HIGH0",
"LOW0",
"VWAP0",
"ROC5",
"ROC10",
"ROC20",
"ROC30",
"ROC60",
"MA5",
"MA10",
"MA20",
"MA30",
"MA60",
"STD5",
"STD10",
"STD20",
"STD30",
"STD60",
"BETA5",
"BETA10",
"BETA20",
"BETA30",
"BETA60",
"RSQR5",
"RSQR10",
"RSQR20",
"RSQR30",
"RSQR60",
"RESI5",
"RESI10",
"RESI20",
"RESI30",
"RESI60",
"MAX5",
"MAX10",
"MAX20",
"MAX30",
"MAX60",
"MIN5",
"MIN10",
"MIN20",
"MIN30",
"MIN60",
"QTLU5",
"QTLU10",
"QTLU20",
"QTLU30",
"QTLU60",
"QTLD5",
"QTLD10",
"QTLD20",
"QTLD30",
"QTLD60",
"RANK5",
"RANK10",
"RANK20",
"RANK30",
"RANK60",
"RSV5",
"RSV10",
"RSV20",
"RSV30",
"RSV60",
"IMAX5",
"IMAX10",
"IMAX20",
"IMAX30",
"IMAX60",
"IMIN5",
"IMIN10",
"IMIN20",
"IMIN30",
"IMIN60",
"IMXD5",
"IMXD10",
"IMXD20",
"IMXD30",
"IMXD60",
"CORR5",
"CORR10",
"CORR20",
"CORR30",
"CORR60",
"CORD5",
"CORD10",
"CORD20",
"CORD30",
"CORD60",
"CNTP5",
"CNTP10",
"CNTP20",
"CNTP30",
"CNTP60",
"CNTN5",
"CNTN10",
"CNTN20",
"CNTN30",
"CNTN60",
"CNTD5",
"CNTD10",
"CNTD20",
"CNTD30",
"CNTD60",
"SUMP5",
"SUMP10",
"SUMP20",
"SUMP30",
"SUMP60",
"SUMN5",
"SUMN10",
"SUMN20",
"SUMN30",
"SUMN60",
"SUMD5",
"SUMD10",
"SUMD20",
"SUMD30",
"SUMD60",
"VMA5",
"VMA10",
"VMA20",
"VMA30",
"VMA60",
"VSTD5",
"VSTD10",
"VSTD20",
"VSTD30",
"VSTD60",
"WVMA5",
"WVMA10",
"WVMA20",
"WVMA30",
"WVMA60",
"VSUMP5",
"VSUMP10",
"VSUMP20",
"VSUMP30",
"VSUMP60",
"VSUMN5",
"VSUMN10",
"VSUMN20",
"VSUMN30",
"VSUMN60",
"VSUMD5",
"VSUMD10",
"VSUMD20",
"VSUMD30",
"VSUMD60"
],
"alpha158_raw": [
"KMID",
"KLEN",
"KMID2",
"KUP",
"KUP2",
"KLOW",
"KLOW2",
"KSFT",
"KSFT2",
"OPEN0",
"HIGH0",
"LOW0",
"VWAP0",
"ROC5",
"ROC10",
"ROC20",
"ROC30",
"ROC60",
"MA5",
"MA10",
"MA20",
"MA30",
"MA60",
"STD5",
"STD10",
"STD20",
"STD30",
"STD60",
"BETA5",
"BETA10",
"BETA20",
"BETA30",
"BETA60",
"RSQR5",
"RSQR10",
"RSQR20",
"RSQR30",
"RSQR60",
"RESI5",
"RESI10",
"RESI20",
"RESI30",
"RESI60",
"MAX5",
"MAX10",
"MAX20",
"MAX30",
"MAX60",
"MIN5",
"MIN10",
"MIN20",
"MIN30",
"MIN60",
"QTLU5",
"QTLU10",
"QTLU20",
"QTLU30",
"QTLU60",
"QTLD5",
"QTLD10",
"QTLD20",
"QTLD30",
"QTLD60",
"RANK5",
"RANK10",
"RANK20",
"RANK30",
"RANK60",
"RSV5",
"RSV10",
"RSV20",
"RSV30",
"RSV60",
"IMAX5",
"IMAX10",
"IMAX20",
"IMAX30",
"IMAX60",
"IMIN5",
"IMIN10",
"IMIN20",
"IMIN30",
"IMIN60",
"IMXD5",
"IMXD10",
"IMXD20",
"IMXD30",
"IMXD60",
"CORR5",
"CORR10",
"CORR20",
"CORR30",
"CORR60",
"CORD5",
"CORD10",
"CORD20",
"CORD30",
"CORD60",
"CNTP5",
"CNTP10",
"CNTP20",
"CNTP30",
"CNTP60",
"CNTN5",
"CNTN10",
"CNTN20",
"CNTN30",
"CNTN60",
"CNTD5",
"CNTD10",
"CNTD20",
"CNTD30",
"CNTD60",
"SUMP5",
"SUMP10",
"SUMP20",
"SUMP30",
"SUMP60",
"SUMN5",
"SUMN10",
"SUMN20",
"SUMN30",
"SUMN60",
"SUMD5",
"SUMD10",
"SUMD20",
"SUMD30",
"SUMD60",
"VMA5",
"VMA10",
"VMA20",
"VMA30",
"VMA60",
"VSTD5",
"VSTD10",
"VSTD20",
"VSTD30",
"VSTD60",
"WVMA5",
"WVMA10",
"WVMA20",
"WVMA30",
"WVMA60",
"VSUMP5",
"VSUMP10",
"VSUMP20",
"VSUMP30",
"VSUMP60",
"VSUMN5",
"VSUMN10",
"VSUMN20",
"VSUMN30",
"VSUMN60",
"VSUMD5",
"VSUMD10",
"VSUMD20",
"VSUMD30",
"VSUMD60"
],
"market_ext_ntrl": [
"turnover",
"free_turnover",
"log_size",
"con_rating_strength",
"turnover_diff",
"free_turnover_diff",
"con_rating_strength_diff"
],
"market_ext_raw": [
"turnover",
"free_turnover",
"log_size",
"con_rating_strength",
"turnover_diff",
"free_turnover_diff",
"con_rating_strength_diff"
]
},
"feature_count": {
"alpha158_ntrl": 158,
"alpha158_raw": 158,
"market_ext_ntrl": 7,
"market_ext_raw": 7,
"total": 330
},
"parameter_shapes": {
"mean_train": [
330
],
"std_train": [
330
]
}
}

@ -20,18 +20,20 @@ from datetime import datetime
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from generate_beta_embedding import (
load_all_data,
merge_data_sources,
filter_stock_universe,
# Import processors from the new shared module
from cta_1d.src.processors import (
DiffProcessor,
FlagMarketInjector,
FlagSTInjector,
ColumnRemover,
FlagToOnehot,
IndusNtrlInjector,
RobustZScoreNorm,
Fillna,
load_qlib_processor_params,
)
# Import constants from local module
from generate_beta_embedding import (
ALPHA158_COLS,
INDUSTRY_FLAG_COLS,
)
@ -52,7 +54,7 @@ def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame:
This mimics the qlib proc_list:
0. Diff: Adds diff features for market_ext columns
1. FlagMarketInjector: Adds market_0, market_1
2. FlagSTInjector: SKIPPED (fails even in gold-standard)
2. FlagSTInjector: Adds IsST (placeholder if ST flags not available)
3. ColumnRemover: Removes log_size_diff, IsN, IsZt, IsDt
4. FlagToOnehot: Converts 29 industry flags to indus_idx
5. IndusNtrlInjector: Industry neutralization for feature
@ -88,9 +90,13 @@ def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame:
# Add market_0, market_1 to flag list
market_flag_with_market = market_flag_cols + ['market_0', 'market_1']
# Step 3: FlagSTInjector - SKIPPED (fails even in gold-standard)
print("[3] Skipping FlagSTInjector (as per gold-standard behavior)...")
market_flag_with_st = market_flag_with_market # No IsST added
# Step 3: FlagSTInjector - adds IsST (placeholder if ST flags not available)
print("[3] Applying FlagSTInjector...")
flag_st_injector = FlagSTInjector()
df = flag_st_injector.process(df)
# Add IsST to flag list
market_flag_with_st = market_flag_with_market + ['IsST']
# Step 4: ColumnRemover
print("[4] Applying ColumnRemover...")
@ -129,21 +135,18 @@ def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame:
print("[8] Applying RobustZScoreNorm...")
norm_feature_cols = alpha158_ntrl_cols + alpha158_cols + market_ext_ntrl_cols + market_ext_cols
qlib_params = load_qlib_processor_params()
# Load RobustZScoreNorm with pre-fitted parameters from version
robust_norm = RobustZScoreNorm.from_version(
"csiallx_feature2_ntrla_flag_pnlnorm",
feature_cols=norm_feature_cols
)
# Verify parameter shape
# Verify parameter shape matches expected features
expected_features = len(norm_feature_cols)
if qlib_params['mean_train'].shape[0] != expected_features:
if robust_norm.qlib_mean.shape[0] != expected_features:
print(f" WARNING: Feature count mismatch! Expected {expected_features}, "
f"got {qlib_params['mean_train'].shape[0]}")
robust_norm = RobustZScoreNorm(
norm_feature_cols,
clip_range=(-3, 3),
use_qlib_params=True,
qlib_mean=qlib_params['mean_train'],
qlib_std=qlib_params['std_train']
)
f"got {robust_norm.qlib_mean.shape[0]}")
df = robust_norm.process(df)
# Step 9: Fillna
@ -166,6 +169,9 @@ def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame":
"""
Convert polars DataFrame to pandas DataFrame with MultiIndex columns.
This matches the format of qlib's output.
IMPORTANT: Qlib's IndusNtrlInjector outputs columns in order [_ntrl] + [raw],
so we need to reorder columns to match this expected order.
"""
import pandas as pd
@ -185,6 +191,7 @@ def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame":
df = df.drop(columns=existing_raw_cols)
# Build MultiIndex columns based on column name patterns
# IMPORTANT: Qlib order is [_ntrl columns] + [raw columns] for each group
columns_with_group = []
# Define column sets
@ -195,23 +202,29 @@ def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame":
feature_flag_cols = {'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', 'open_limit', 'close_limit', 'low_limit',
'open_stop', 'close_stop', 'high_stop', 'market_0', 'market_1', 'IsST'}
# First pass: collect _ntrl columns (these come first in qlib order)
ntrl_alpha158_cols = []
ntrl_market_ext_cols = []
raw_alpha158_cols = []
raw_market_ext_cols = []
flag_cols = []
indus_idx_col = None
for col in df.columns:
if col == 'indus_idx':
columns_with_group.append(('indus_idx', col))
indus_idx_col = col
elif col in feature_flag_cols:
columns_with_group.append(('feature_flag', col))
flag_cols.append(col)
elif col.endswith('_ntrl'):
base_name = col[:-5] # Remove _ntrl suffix (5 characters)
if base_name in alpha158_base:
columns_with_group.append(('feature', col))
ntrl_alpha158_cols.append(col)
elif base_name in market_ext_all:
columns_with_group.append(('feature_ext', col))
else:
columns_with_group.append(('feature', col)) # Default to feature
ntrl_market_ext_cols.append(col)
elif col in alpha158_base:
columns_with_group.append(('feature', col))
raw_alpha158_cols.append(col)
elif col in market_ext_all:
columns_with_group.append(('feature_ext', col))
raw_market_ext_cols.append(col)
elif col in INDUSTRY_FLAG_COLS:
columns_with_group.append(('indus_flag', col))
elif col in {'ST_S', 'ST_Y', 'ST_T', 'ST_L', 'ST_Z', 'ST_X'}:
@ -221,6 +234,27 @@ def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame":
print(f" Warning: Unknown column '{col}', assigning to 'other' group")
columns_with_group.append(('other', col))
# Build columns in qlib order: [_ntrl] + [raw] for each feature group
# Feature group: alpha158_ntrl + alpha158
for col in sorted(ntrl_alpha158_cols, key=lambda x: ALPHA158_COLS.index(x.replace('_ntrl', '')) if x.replace('_ntrl', '') in ALPHA158_COLS else 999):
columns_with_group.append(('feature', col))
for col in sorted(raw_alpha158_cols, key=lambda x: ALPHA158_COLS.index(x) if x in ALPHA158_COLS else 999):
columns_with_group.append(('feature', col))
# Feature_ext group: market_ext_ntrl + market_ext
for col in ntrl_market_ext_cols:
columns_with_group.append(('feature_ext', col))
for col in raw_market_ext_cols:
columns_with_group.append(('feature_ext', col))
# Feature_flag group
for col in flag_cols:
columns_with_group.append(('feature_flag', col))
# Indus_idx
if indus_idx_col:
columns_with_group.append(('indus_idx', indus_idx_col))
# Create MultiIndex columns
multi_cols = pd.MultiIndex.from_tuples(columns_with_group)
df.columns = multi_cols

@ -0,0 +1,305 @@
#!/usr/bin/env python
"""
Extract RobustZScoreNorm parameters from Qlib's proc_list.proc file.
This script extracts the pre-fitted mean_train and std_train parameters from
Qlib's RobustZScoreNorm processor and saves them as reusable .npy files.
The extracted parameters can be used by the Polars RobustZScoreNorm processor
via the from_version() class method.
Output structure:
stock_1d/d033/alpha158_beta/data/robust_zscore_params/
{version}/
mean_train.npy
std_train.npy
metadata.json
"""
import os
import sys
import json
import pickle as pkl
from pathlib import Path
from datetime import datetime
import numpy as np
# Default paths
# Source proc_list.proc produced by the Qlib dataset pipeline; override via --proc-list.
DEFAULT_PROC_LIST_PATH = "/home/guofu/Workspaces/alpha/data_ops/tasks/dwm_feature_vae/dataset/csiallx_feature2_ntrla_flag_pnlnorm/proc_list.proc"
# Version (output subdirectory) name under data/robust_zscore_params/; override via --version.
DEFAULT_VERSION = "csiallx_feature2_ntrla_flag_pnlnorm"
# Alpha158 columns in order (158 features)
# NOTE: order matters — mean_train/std_train are positional arrays, so this
# list must match the column order Qlib used when fitting RobustZScoreNorm.
ALPHA158_COLS = [
    'KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2',
    'OPEN0', 'HIGH0', 'LOW0', 'VWAP0',
    'ROC5', 'ROC10', 'ROC20', 'ROC30', 'ROC60',
    'MA5', 'MA10', 'MA20', 'MA30', 'MA60',
    'STD5', 'STD10', 'STD20', 'STD30', 'STD60',
    'BETA5', 'BETA10', 'BETA20', 'BETA30', 'BETA60',
    'RSQR5', 'RSQR10', 'RSQR20', 'RSQR30', 'RSQR60',
    'RESI5', 'RESI10', 'RESI20', 'RESI30', 'RESI60',
    'MAX5', 'MAX10', 'MAX20', 'MAX30', 'MAX60',
    'MIN5', 'MIN10', 'MIN20', 'MIN30', 'MIN60',
    'QTLU5', 'QTLU10', 'QTLU20', 'QTLU30', 'QTLU60',
    'QTLD5', 'QTLD10', 'QTLD20', 'QTLD30', 'QTLD60',
    'RANK5', 'RANK10', 'RANK20', 'RANK30', 'RANK60',
    'RSV5', 'RSV10', 'RSV20', 'RSV30', 'RSV60',
    'IMAX5', 'IMAX10', 'IMAX20', 'IMAX30', 'IMAX60',
    'IMIN5', 'IMIN10', 'IMIN20', 'IMIN30', 'IMIN60',
    'IMXD5', 'IMXD10', 'IMXD20', 'IMXD30', 'IMXD60',
    'CORR5', 'CORR10', 'CORR20', 'CORR30', 'CORR60',
    'CORD5', 'CORD10', 'CORD20', 'CORD30', 'CORD60',
    'CNTP5', 'CNTP10', 'CNTP20', 'CNTP30', 'CNTP60',
    'CNTN5', 'CNTN10', 'CNTN20', 'CNTN30', 'CNTN60',
    'CNTD5', 'CNTD10', 'CNTD20', 'CNTD30', 'CNTD60',
    'SUMP5', 'SUMP10', 'SUMP20', 'SUMP30', 'SUMP60',
    'SUMN5', 'SUMN10', 'SUMN20', 'SUMN30', 'SUMN60',
    'SUMD5', 'SUMD10', 'SUMD20', 'SUMD30', 'SUMD60',
    'VMA5', 'VMA10', 'VMA20', 'VMA30', 'VMA60',
    'VSTD5', 'VSTD10', 'VSTD20', 'VSTD30', 'VSTD60',
    'WVMA5', 'WVMA10', 'WVMA20', 'WVMA30', 'WVMA60',
    'VSUMP5', 'VSUMP10', 'VSUMP20', 'VSUMP30', 'VSUMP60',
    'VSUMN5', 'VSUMN10', 'VSUMN20', 'VSUMN30', 'VSUMN60',
    'VSUMD5', 'VSUMD10', 'VSUMD20', 'VSUMD30', 'VSUMD60'
]
# Fail fast at import time if the hand-maintained list ever drifts from 158.
assert len(ALPHA158_COLS) == 158, f"Expected 158 alpha158 cols, got {len(ALPHA158_COLS)}"
def extract_robust_zscore_params(proc_list_path: str) -> dict:
    """
    Extract RobustZScoreNorm parameters from Qlib's proc_list.proc file.

    Args:
        proc_list_path: Path to the proc_list.proc pickle file

    Returns:
        Dictionary containing:
            - mean_train: numpy array of shape (330,)
            - std_train: numpy array of shape (330,)
            - fit_start_time: datetime string (or None if absent)
            - fit_end_time: datetime string (or None if absent)
            - fields_group: list of field groups (or None if absent)

    Raises:
        ValueError: if no RobustZScoreNorm processor is found, or if the
            found processor has no fitted mean_train/std_train parameters.
    """
    print(f"Loading proc_list.proc from: {proc_list_path}")
    # SECURITY NOTE: pickle.load executes arbitrary code embedded in the file.
    # Only run this against proc_list.proc files produced by a trusted pipeline.
    with open(proc_list_path, 'rb') as f:
        proc_list = pkl.load(f)
    print(f"Loaded {len(proc_list)} processors from proc_list")
    # Find RobustZScoreNorm processor (typically at index 7). Scan the whole
    # list (no early break) so every processor is printed; if several match,
    # the last one wins.
    zscore_proc = None
    for i, proc in enumerate(proc_list):
        proc_name = type(proc).__name__
        print(f" [{i}] {proc_name}")
        if proc_name == "RobustZScoreNorm":
            zscore_proc = proc
    if zscore_proc is None:
        raise ValueError("RobustZScoreNorm processor not found in proc_list")
    # Fail with a clear message if the processor was never fitted; a direct
    # attribute access would otherwise raise a context-free AttributeError.
    if not hasattr(zscore_proc, 'mean_train') or not hasattr(zscore_proc, 'std_train'):
        raise ValueError(
            "RobustZScoreNorm processor has no fitted mean_train/std_train parameters"
        )
    # Extract parameters; the fit_* and fields_group attributes are optional.
    params = {
        'mean_train': zscore_proc.mean_train,
        'std_train': zscore_proc.std_train,
        'fit_start_time': getattr(zscore_proc, 'fit_start_time', None),
        'fit_end_time': getattr(zscore_proc, 'fit_end_time', None),
        'fields_group': getattr(zscore_proc, 'fields_group', None),
    }
    print(f"\nExtracted RobustZScoreNorm parameters:")
    print(f" mean_train shape: {params['mean_train'].shape}, dtype: {params['mean_train'].dtype}")
    print(f" std_train shape: {params['std_train'].shape}, dtype: {params['std_train'].dtype}")
    print(f" fit_start_time: {params['fit_start_time']}")
    print(f" fit_end_time: {params['fit_end_time']}")
    print(f" fields_group: {params['fields_group']}")
    return params
def build_feature_column_names() -> list:
    """
    Build the complete ordered list of 330 feature column names.

    Layout (must match the positional order of mean_train/std_train):
        1. alpha158_ntrl  (158 features)
        2. alpha158_raw   (158 features)
        3. market_ext_ntrl  (7 features)
        4. market_ext_raw   (7 features)

    The market_ext set starts from 4 base columns plus their 4 diffs; the
    ColumnRemover step drops log_size_diff, leaving 7 columns per group.
    """
    # market_ext: 4 base + 4 diff, then drop log_size_diff (ColumnRemover)
    base_market = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']
    diff_market = [f"{c}_diff" for c in base_market]
    kept_market = [c for c in base_market + diff_market if c != 'log_size_diff']

    # Neutralized variants carry a _ntrl suffix; raw variants keep base names.
    alpha158_ntrl = [c + "_ntrl" for c in ALPHA158_COLS]
    alpha158_raw = list(ALPHA158_COLS)
    market_ext_ntrl = [c + "_ntrl" for c in kept_market]
    market_ext_raw = list(kept_market)

    # Concatenate in Qlib order: [_ntrl] then [raw] for each feature group.
    feature_cols = alpha158_ntrl + alpha158_raw + market_ext_ntrl + market_ext_raw
    print(f"\nBuilt feature column names:")
    print(f" alpha158_ntrl: {len(alpha158_ntrl)} features")
    print(f" alpha158_raw: {len(alpha158_raw)} features")
    print(f" market_ext_ntrl: {len(market_ext_ntrl)} features")
    print(f" market_ext_raw: {len(market_ext_raw)} features")
    print(f" Total: {len(feature_cols)} features")
    return feature_cols
def save_parameters(
    params: dict,
    feature_cols: list,
    output_dir: str,
    version: str,
    source_file: str = None
):
    """
    Save extracted parameters to output directory.

    Creates under ``{output_dir}/{version}/``:
        - mean_train.npy
        - std_train.npy
        - metadata.json

    Args:
        params: Dict from extract_robust_zscore_params() containing
            mean_train, std_train, fit_start_time, fit_end_time, fields_group.
        feature_cols: Full ordered feature-name list (for shape validation by
            the caller; metadata records the canonical per-group lists).
        output_dir: Root directory to save into.
        version: Version name; used as the output subdirectory.
        source_file: Optional path of the proc_list.proc the parameters were
            extracted from, recorded in metadata.json ('unknown' if omitted).

    Returns:
        Path to the created version directory.
    """
    output_path = Path(output_dir) / version
    output_path.mkdir(parents=True, exist_ok=True)
    print(f"\nSaving parameters to: {output_path}")
    # Save mean_train.npy
    mean_path = output_path / "mean_train.npy"
    np.save(mean_path, params['mean_train'])
    print(f" Saved mean_train to: {mean_path}")
    # Save std_train.npy
    std_path = output_path / "std_train.npy"
    np.save(std_path, params['std_train'])
    print(f" Saved std_train to: {std_path}")
    # Normalize fields_group for JSON: Qlib may store a single string (e.g.
    # 'feature'); list('feature') would explode it into characters.
    fields_group = params['fields_group']
    if not fields_group:
        fields_group = None
    elif isinstance(fields_group, str):
        fields_group = [fields_group]
    else:
        fields_group = list(fields_group)
    # Build metadata
    metadata = {
        'version': version,
        'created_at': datetime.now().isoformat(),
        # BUGFIX: this previously recorded fit_start_time under 'source_file'.
        'source_file': str(source_file) if source_file is not None else 'unknown',
        'fit_start_time': str(params['fit_start_time']),
        'fit_end_time': str(params['fit_end_time']),
        'fields_group': fields_group,
        'feature_columns': {
            'alpha158_ntrl': ALPHA158_COLS,  # Store base names, _ntrl is implied
            'alpha158_raw': ALPHA158_COLS,
            'market_ext_ntrl': [
                'turnover', 'free_turnover', 'log_size', 'con_rating_strength',
                'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'
            ],
            'market_ext_raw': [
                'turnover', 'free_turnover', 'log_size', 'con_rating_strength',
                'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'
            ],
        },
        'feature_count': {
            'alpha158_ntrl': 158,
            'alpha158_raw': 158,
            'market_ext_ntrl': 7,
            'market_ext_raw': 7,
            'total': 330
        },
        'parameter_shapes': {
            'mean_train': list(params['mean_train'].shape),
            'std_train': list(params['std_train'].shape)
        }
    }
    # Save metadata.json
    metadata_path = output_path / "metadata.json"
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)
    print(f" Saved metadata to: {metadata_path}")
    return output_path
def main():
    """Command-line entry point: parse arguments, extract, verify, and save."""
    import argparse

    cli = argparse.ArgumentParser(
        description="Extract RobustZScoreNorm parameters from Qlib's proc_list.proc"
    )
    cli.add_argument(
        '--proc-list',
        type=str,
        default=DEFAULT_PROC_LIST_PATH,
        help=f"Path to proc_list.proc (default: {DEFAULT_PROC_LIST_PATH})"
    )
    cli.add_argument(
        '--version',
        type=str,
        default=DEFAULT_VERSION,
        help=f"Version name for output directory (default: {DEFAULT_VERSION})"
    )
    cli.add_argument(
        '--output-dir',
        type=str,
        default=str(Path(__file__).parent.parent / "data" / "robust_zscore_params"),
        help="Output directory for parameter files"
    )
    args = cli.parse_args()

    banner = "=" * 80
    print(banner)
    print("Extract Qlib RobustZScoreNorm Parameters")
    print(banner)
    print(f"Source: {args.proc_list}")
    print(f"Version: {args.version}")
    print(f"Output: {args.output_dir}")
    print()

    # Step 1: pull the fitted parameters out of the pickled processor list
    params = extract_robust_zscore_params(args.proc_list)

    # Step 2: reconstruct the 330-column feature order the parameters map to
    feature_cols = build_feature_column_names()

    # Step 3: sanity-check parameter length against the feature list
    n_expected = len(feature_cols)
    n_actual = params['mean_train'].shape[0]
    if n_actual != n_expected:
        print(f"\nWARNING: Parameter shape mismatch!")
        print(f" Expected: {n_expected} features")
        print(f" Got: {n_actual} parameters")
    else:
        print(f"\n✓ Parameter shape matches feature count ({n_expected})")

    # Step 4: persist the .npy arrays plus metadata.json
    output_path = save_parameters(params, feature_cols, args.output_dir, args.version)

    print("\n" + banner)
    print("Extraction complete!")
    print(banner)
    print(f"Output files:")
    print(f" {output_path}/mean_train.npy")
    print(f" {output_path}/std_train.npy")
    print(f" {output_path}/metadata.json")
    print()
    print("Usage in Polars RobustZScoreNorm:")
    print(f' norm = RobustZScoreNorm.from_version("{args.version}")')
    print(banner)


if __name__ == "__main__":
    main()

@ -74,38 +74,25 @@ INDUSTRY_FLAG_COLS = [
'gds_CC50', 'gds_CC60', 'gds_CC61', 'gds_CC62', 'gds_CC63', 'gds_CC70'
]
# Stock universe filter: csiallx = All A-shares excluding BSE/NEEQ and STAR market
# This matches the original qlib handler configuration
# - Include: SH600xxx, SH601xxx, SH603xxx, SH605xxx (Shanghai Main Board)
# - Include: SZ000xxx, SZ001xxx, SZ002xxx, SZ003xxx (Shenzhen Main Board)
# - Include: SZ300xxx, SZ301xxx (ChiNext)
# - Exclude: SH688xxx, SH689xxx (STAR Market/科创板)
# - Exclude: 4xxxxx, 8xxxxx (BSE/NEEQ/北交所/新三板)
def filter_stock_universe(df: pl.DataFrame) -> pl.DataFrame:
"""
Filter dataframe to csiallx stock universe (A-shares only).
This filter matches the original qlib handler configuration which excludes:
- BSE/NEEQ stocks (4xxxxx, 8xxxxx)
- STAR Market stocks (688xxx, 689xxx)
def filter_stock_universe(df: pl.DataFrame, instruments: str = 'csiallx') -> pl.DataFrame:
"""
inst_str = pl.col('instrument').cast(pl.String).str.zfill(6)
Filter dataframe to csiallx stock universe (A-shares excluding STAR/BSE) using qshare spine functions.
# Define inclusion patterns
is_sh_main = inst_str.str.starts_with('60') | inst_str.str.starts_with('61')
is_sz_main = inst_str.str.starts_with('0')
is_chi_next = inst_str.str.starts_with('300') | inst_str.str.starts_with('301')
This uses qshare's filter_instruments which loads the instrument list from:
/data/qlib/default/data_ops/target/instruments/csiallx.txt
# Define exclusion patterns (explicitly exclude these)
is_star = inst_str.str.starts_with('688') | inst_str.str.starts_with('689')
is_bseeq = inst_str.str.starts_with('4') | inst_str.str.starts_with('8')
Args:
df: Input DataFrame with datetime and instrument columns
instruments: Market name for spine creation (default: 'csiallx')
# Filter: include main boards and ChiNext, exclude STAR and BSE/NEEQ
df = df.filter(
(is_sh_main | is_sz_main | is_chi_next) &
(~is_star) &
(~is_bseeq)
)
Returns:
Filtered DataFrame with only instruments in the specified universe
"""
from qshare.algo.polars.spine import filter_instruments
# Use qshare's filter_instruments with csiallx market name
df = filter_instruments(df, instruments=instruments)
return df

Loading…
Cancel
Save