From 89bd1a528edfe5bc855dce2a25f65ffdf52af018 Mon Sep 17 00:00:00 2001
From: guofu
Date: Sun, 1 Mar 2026 12:56:44 +0800
Subject: [PATCH] Extract RobustZScoreNorm parameters and add from_version()
method
- Add extract_qlib_params.py script to extract pre-fitted mean/std parameters
from Qlib's proc_list.proc and save as reusable .npy files with metadata.json
- Add RobustZScoreNorm.from_version() class method to load saved parameters
by version name, supporting multiple parameter versions coexistence
- Update dump_polars_dataset.py to use from_version() instead of loading
parameters directly from proc_list.proc
- Update generate_beta_embedding.py to use qshare's filter_instruments for
stock universe filtering
- Save parameters to data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/
with 330 features (158 alpha158_ntrl + 158 alpha158_raw + 7 market_ext_ntrl + 7 market_ext_raw)
- Update README.md with documentation for parameter extraction and usage
Co-Authored-By: Claude Opus 4.6
---
cta_1d/src/processors/__init__.py | 753 ++++++++++++++++++
stock_1d/d033/alpha158_beta/README.md | 59 +-
.../metadata.json | 366 +++++++++
.../scripts/dump_polars_dataset.py | 92 ++-
.../scripts/extract_qlib_params.py | 305 +++++++
.../scripts/generate_beta_embedding.py | 45 +-
6 files changed, 1553 insertions(+), 67 deletions(-)
create mode 100644 cta_1d/src/processors/__init__.py
create mode 100644 stock_1d/d033/alpha158_beta/data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/metadata.json
create mode 100644 stock_1d/d033/alpha158_beta/scripts/extract_qlib_params.py
diff --git a/cta_1d/src/processors/__init__.py b/cta_1d/src/processors/__init__.py
new file mode 100644
index 0000000..da8476e
--- /dev/null
+++ b/cta_1d/src/processors/__init__.py
@@ -0,0 +1,753 @@
+"""
+Polars-based data processors for financial feature transformation.
+
+This module provides Polars implementations of Qlib-style data processors
+used in the data_ops pipeline. Each processor follows a consistent interface:
+- Takes a Polars DataFrame as input
+- Returns a transformed Polars DataFrame
+
+Processors are organized by category:
+- Feature Engineering: DiffProcessor
+- Flag Injection: FlagMarketInjector, FlagSTInjector
+- Column Operations: ColumnRemover, FlagToOnehot
+- Normalization: IndusNtrlInjector, RobustZScoreNorm
+- Data Cleaning: Fillna
+"""
+
+from dataclasses import dataclass, field
+from typing import List, Tuple, Optional, Dict
+import numpy as np
+import polars as pl
+
+
+__all__ = [
+ # Feature Engineering
+ 'DiffProcessor',
+ # Flag Injection
+ 'FlagMarketInjector',
+ 'FlagSTInjector',
+ # Column Operations
+ 'ColumnRemover',
+ 'FlagToOnehot',
+ # Normalization
+ 'IndusNtrlInjector',
+ 'RobustZScoreNorm',
+ # Data Cleaning
+ 'Fillna',
+]
+
+
+# =============================================================================
+# Feature Engineering Processors
+# =============================================================================
+
class DiffProcessor:
    """
    Compute period-over-period differences for selected columns, grouped
    by instrument.

    For every requested column that exists in the frame, a new column named
    ``{col}_{suffix}`` is added containing ``diff(periods)`` evaluated
    within each instrument group (rows are sorted by instrument/datetime
    first so the diff is chronological).

    Attributes:
        columns: Columns to difference.
        suffix: Suffix appended to each new column name (default: 'diff').
        periods: Shift used by the diff (default: 1).

    Example:
        >>> processor = DiffProcessor(['turnover', 'log_size'])
        >>> df = processor.process(df)
        >>> # Creates: turnover_diff, log_size_diff
    """

    def __init__(self, columns: List[str], suffix: str = 'diff', periods: int = 1):
        self.columns = columns
        self.suffix = suffix
        self.periods = periods

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Append diff columns for each configured column present in *df*.

        Args:
            df: Input DataFrame with 'datetime' and 'instrument' columns.

        Returns:
            DataFrame with the original columns plus the new diff columns.
        """
        # Sorting guarantees diffs are taken in chronological order per stock.
        df = df.sort(['instrument', 'datetime'])

        requested = [c for c in self.columns if c in df.columns]
        for source in requested:
            target = f"{source}_{self.suffix}"
            df = df.with_columns([
                pl.col(source)
                .diff(self.periods)
                .over('instrument')
                .alias(target)
            ])

        return df

    def __repr__(self) -> str:
        return f"DiffProcessor(columns={self.columns}, suffix='{self.suffix}', periods={self.periods})"
+
+
+# =============================================================================
+# Flag Injection Processors
+# =============================================================================
+
class FlagMarketInjector:
    """
    Add market-segment flag columns derived from instrument codes.

    Codes are handled with or without an exchange prefix (SH/SZ/NE):
    SH600000 and 600000 classify identically.

    Classification:
        - market_0 (主板 / main board): codes starting with '6' — except
          688xxx/689xxx — or starting with '0'.
        - market_1 (科创板/创业板 / STAR + ChiNext): 688xxx, 689xxx,
          300xxx, 301xxx.

    Codes outside both groups (e.g. 新三板/北交所 NE40xxx, NE42xxx,
    NE43xxx, NE8xxx) receive 0 in both flag columns.

    Example:
        >>> processor = FlagMarketInjector()
        >>> df = processor.process(df)
        >>> # Adds: market_0, market_1 (both int8)
    """

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Append market_0 / market_1 Int8 flag columns.

        Args:
            df: Input DataFrame with an 'instrument' column.

        Returns:
            DataFrame with the two flag columns added.
        """
        # Normalize the code: cast to string, then strip any exchange
        # prefix, e.g. SH600000 -> 600000, SZ000001 -> 000001.
        code = (
            pl.col('instrument')
            .cast(pl.String)
            .str.replace_all("^(SH|SZ|NE)", "")
        )

        lead1 = code.str.slice(0, 1)
        lead3 = code.str.slice(0, 3)

        # STAR market and ChiNext boards.
        star_board = (lead3 == '688') | (lead3 == '689')
        chinext_board = (lead3 == '300') | (lead3 == '301')

        # Main board: '6' codes minus the STAR market, plus '0' codes.
        sh_main = (lead1 == '6') & ~(lead3 == '688') & ~(lead3 == '689')
        sz_main = lead1 == '0'

        return df.with_columns([
            # market_0 = 主板
            (sh_main | sz_main).cast(pl.Int8).alias('market_0'),
            # market_1 = 科创板 + 创业板
            (star_board | chinext_board).cast(pl.Int8).alias('market_1'),
        ])

    def __repr__(self) -> str:
        return "FlagMarketInjector()"
+
+
class FlagSTInjector:
    """
    Add an IsST (Special Treatment) flag column.

    IsST is 1 when either ST_S or ST_Y is truthy (stock is ST or *ST) and
    0 otherwise. When neither source column is present, an all-zero
    placeholder column is created instead.

    Attributes:
        mark_st_as: Value intended to mark ST stocks (default: 1).
            NOTE: process() currently always emits 0/1 regardless of this
            value; the attribute is only surfaced via repr().
    """

    def __init__(self, mark_st_as: int = 1):
        self.mark_st_as = mark_st_as

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Append the IsST Int8 column.

        Args:
            df: Input DataFrame, optionally carrying ST_S / ST_Y columns.

        Returns:
            DataFrame with IsST added.
        """
        has_st_sources = 'ST_S' in df.columns and 'ST_Y' in df.columns

        if not has_st_sources:
            # No source flags available: emit an all-zero placeholder.
            return df.with_columns([
                pl.lit(0).cast(pl.Int8).alias('IsST')
            ])

        # Non-strict boolean casts tolerate nulls / odd encodings.
        st_any = (
            pl.col('ST_S').cast(pl.Boolean, strict=False)
            | pl.col('ST_Y').cast(pl.Boolean, strict=False)
        )
        return df.with_columns([
            st_any.cast(pl.Int8).alias('IsST')
        ])

    def __repr__(self) -> str:
        return f"FlagSTInjector(mark_st_as={self.mark_st_as})"
+
+
+# =============================================================================
+# Column Operation Processors
+# =============================================================================
+
class ColumnRemover:
    """
    Drop a configured set of columns, silently ignoring any that are absent.

    Attributes:
        columns_to_remove: Column names to drop when present.

    Example:
        >>> processor = ColumnRemover(['log_size_diff', 'IsN', 'IsZt', 'IsDt'])
        >>> df = processor.process(df)
        >>> # Removes specified columns
    """

    def __init__(self, columns_to_remove: List[str]):
        self.columns_to_remove = columns_to_remove

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Drop the configured columns that actually exist in *df*.

        Args:
            df: Input DataFrame.

        Returns:
            DataFrame without the dropped columns.
        """
        present = set(df.columns)
        doomed = [name for name in self.columns_to_remove if name in present]
        return df.drop(doomed) if doomed else df

    def __repr__(self) -> str:
        return f"ColumnRemover(columns_to_remove={self.columns_to_remove})"
+
+
class FlagToOnehot:
    """
    Collapse one-hot flag columns into a single integer index column.

    Builds an ``indus_idx`` column holding the list position of a set flag,
    or -1 when no flag is set, then drops the original flag columns.
    Because the when/then chain is built outermost-last, if several flags
    are set the HIGHEST-index one wins.

    Attributes:
        flag_columns: Boolean/0-1 flag column names, in index order.

    Example:
        >>> processor = FlagToOnehot(['gds_CC10', 'gds_CC11', ...])
        >>> df = processor.process(df)
        >>> # Adds: indus_idx (flag index, or -1 when none set)
    """

    def __init__(self, flag_columns: List[str]):
        self.flag_columns = flag_columns

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Replace the flag columns with a single indus_idx column.

        Args:
            df: Input DataFrame with the flag columns.

        Returns:
            DataFrame with indus_idx added and the flags dropped.
        """
        available = [(i, name) for i, name in enumerate(self.flag_columns)
                     if name in df.columns]

        # Fold the flags into nested when/then expressions over a -1
        # default; each later column wraps (and thus shadows) earlier ones.
        selector = pl.lit(-1)
        for i, name in available:
            selector = pl.when(pl.col(name) == 1).then(i).otherwise(selector)

        df = df.with_columns([selector.alias('indus_idx')])

        to_drop = [name for _, name in available]
        return df.drop(to_drop) if to_drop else df

    def __repr__(self) -> str:
        return f"FlagToOnehot(flag_columns={len(self.flag_columns)} columns)"
+
+
+# =============================================================================
+# Normalization Processors
+# =============================================================================
+
class IndusNtrlInjector:
    """
    Industry neutralization for features.

    For each feature, subtracts the industry mean — computed per
    (datetime, indus_idx) group — from the feature value and stores the
    result in a new column ``{col}{suffix}``, keeping the original column.

    This is a cross-sectional neutralization per datetime, matching Qlib's
    cal_indus_ntrl behavior.

    Attributes:
        feature_cols: Feature columns to neutralize.
        suffix: Suffix for the neutralized column names (default: '_ntrl').

    Example:
        >>> processor = IndusNtrlInjector(['KMID', 'KLEN'], suffix='_ntrl')
        >>> df = processor.process(df)
        >>> # Adds: KMID_ntrl, KLEN_ntrl
    """

    def __init__(self, feature_cols: List[str], suffix: str = '_ntrl'):
        self.feature_cols = feature_cols
        self.suffix = suffix

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Add industry-neutralized versions of the configured features.

        Args:
            df: Input DataFrame with the feature columns plus 'datetime'
                and 'indus_idx'.

        Returns:
            DataFrame with one new '<col><suffix>' column per feature.
        """
        present = [name for name in self.feature_cols if name in df.columns]

        for feature in present:
            # Cross-sectional demeaning: subtract the mean of the
            # (datetime, industry) group via a window expression.
            group_mean = pl.col(feature).mean().over(['datetime', 'indus_idx'])
            df = df.with_columns([
                (pl.col(feature) - group_mean).alias(f"{feature}{self.suffix}")
            ])

        return df

    def __repr__(self) -> str:
        return f"IndusNtrlInjector(feature_cols={len(self.feature_cols)} columns, suffix='{self.suffix}')"
+
+
@dataclass
class RobustZScoreNorm:
    """
    Robust z-score normalization using median/MAD.

    Applies the transformation: (x - median) / (1.4826 * MAD)
    where MAD = median(|x - median|)

    Supports two modes:
    1. Per-datetime computation (default): calculates median/MAD for each
       datetime cross-section.
    2. Pre-fitted parameters: uses provided mean/std arrays (from the Qlib
       processor). The arrays are POSITIONALLY aligned with ``feature_cols``:
       ``qlib_mean[i]`` / ``qlib_std[i]`` belong to ``feature_cols[i]``.

    Attributes:
        feature_cols: List of feature columns to normalize
        clip_range: Clip normalized values to this range (default: (-3, 3))
        use_qlib_params: Use pre-fitted parameters (default: False)
        qlib_mean: Pre-fitted mean array (required if use_qlib_params=True)
        qlib_std: Pre-fitted std array (required if use_qlib_params=True)

    Example:
        # Using pre-fitted Qlib parameters
        >>> processor = RobustZScoreNorm(
        ...     feature_cols=['KMID', 'KLEN'],
        ...     use_qlib_params=True,
        ...     qlib_mean=mean_array,
        ...     qlib_std=std_array
        ... )
        >>> df = processor.process(df)

        # Loading parameters from saved version
        >>> processor = RobustZScoreNorm.from_version("csiallx_feature2_ntrla_flag_pnlnorm")
        >>> df = processor.process(df)
    """

    feature_cols: List[str]
    clip_range: Tuple[float, float] = (-3.0, 3.0)
    use_qlib_params: bool = False
    qlib_mean: Optional[np.ndarray] = None
    qlib_std: Optional[np.ndarray] = None

    @classmethod
    def from_version(
        cls,
        version: str,
        feature_cols: Optional[List[str]] = None,
        clip_range: Tuple[float, float] = (-3.0, 3.0),
        params_dir: Optional[str] = None
    ) -> "RobustZScoreNorm":
        """
        Create a RobustZScoreNorm instance from saved parameters by version name.

        This loads pre-extracted mean_train and std_train from the parameter
        directory structure:
            {params_dir}/{version}/
            ├── mean_train.npy
            ├── std_train.npy
            └── metadata.json

        Args:
            version: Version name (e.g., "csiallx_feature2_ntrla_flag_pnlnorm")
            feature_cols: Optional list of feature columns. If None, uses the
                order from metadata.json (alpha158_ntrl + alpha158_raw +
                market_ext_ntrl + market_ext_raw)
            clip_range: Clip range for normalized values (default: (-3, 3))
            params_dir: Base directory for parameter versions. If None, uses:
                stock_1d/d033/alpha158_beta/data/robust_zscore_params/

        Returns:
            RobustZScoreNorm instance with pre-fitted parameters loaded

        Raises:
            FileNotFoundError: If the version directory or a parameter file
                is missing, or if feature_cols is None and metadata.json
                does not exist
            ValueError: If mean/std shapes disagree, or the feature column
                count doesn't match the parameter shape

        Example:
            >>> processor = RobustZScoreNorm.from_version(
            ...     "csiallx_feature2_ntrla_flag_pnlnorm"
            ... )
            >>> df = processor.process(df)
        """
        import json
        from pathlib import Path

        # Resolve the base parameter directory.
        if params_dir is None:
            # Default: go from cta_1d/src/processors/ up to the repo root,
            # then into stock_1d/d033/alpha158_beta/data/robust_zscore_params/
            params_dir = Path(__file__).parent.parent.parent.parent / \
                "stock_1d" / "d033" / "alpha158_beta" / "data" / "robust_zscore_params"
        else:
            params_dir = Path(params_dir)

        version_dir = params_dir / version

        # Check version directory exists
        if not version_dir.exists():
            raise FileNotFoundError(
                f"Version directory not found: {version_dir}\n"
                f"Available versions should be in: {params_dir}"
            )

        # Load mean_train.npy
        mean_path = version_dir / "mean_train.npy"
        if not mean_path.exists():
            raise FileNotFoundError(f"mean_train.npy not found: {mean_path}")
        mean_train = np.load(mean_path)

        # Load std_train.npy
        std_path = version_dir / "std_train.npy"
        if not std_path.exists():
            raise FileNotFoundError(f"std_train.npy not found: {std_path}")
        std_train = np.load(std_path)

        # The two arrays must describe the same set of features.
        if mean_train.shape != std_train.shape:
            raise ValueError(
                f"mean_train shape {mean_train.shape} does not match "
                f"std_train shape {std_train.shape}"
            )

        # Build feature columns from metadata.json when not provided.
        # Previously a missing metadata file silently yielded
        # feature_cols=None, which crashed later in process(); now it is an
        # explicit, immediate error.
        if feature_cols is None:
            metadata_path = version_dir / "metadata.json"
            if not metadata_path.exists():
                raise FileNotFoundError(
                    f"metadata.json not found: {metadata_path} "
                    f"(pass feature_cols explicitly when metadata is unavailable)"
                )
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)

            feature_columns = metadata.get('feature_columns', {})
            # Qlib parameter order: alpha158_ntrl + alpha158_raw +
            # market_ext_ntrl + market_ext_raw
            alpha158_ntrl = [f"{c}_ntrl" for c in feature_columns.get('alpha158_ntrl', [])]
            alpha158_raw = feature_columns.get('alpha158_raw', [])
            market_ext_ntrl = [f"{c}_ntrl" for c in feature_columns.get('market_ext_ntrl', [])]
            market_ext_raw = feature_columns.get('market_ext_raw', [])

            feature_cols = alpha158_ntrl + alpha158_raw + market_ext_ntrl + market_ext_raw

        # Validate feature column count matches parameter shape.
        # (Now unconditional; validation used to be skipped for an empty list.)
        expected_count = len(mean_train)
        if len(feature_cols) != expected_count:
            raise ValueError(
                f"Feature column count ({len(feature_cols)}) does not match "
                f"parameter shape ({expected_count})"
            )

        print(f"Loaded RobustZScoreNorm parameters from version '{version}':")
        print(f"  mean_train shape: {mean_train.shape}")
        print(f"  std_train shape: {std_train.shape}")
        print(f"  feature_cols: {len(feature_cols) if feature_cols else 'not specified'}")

        return cls(
            feature_cols=feature_cols,
            clip_range=clip_range,
            use_qlib_params=True,
            qlib_mean=mean_train,
            qlib_std=std_train
        )

    def __post_init__(self):
        """Validate parameters after initialization."""
        if self.use_qlib_params:
            if self.qlib_mean is None or self.qlib_std is None:
                raise ValueError(
                    "Must provide qlib_mean and qlib_std when use_qlib_params=True"
                )
            # Convert to numpy arrays if not already
            self.qlib_mean = np.asarray(self.qlib_mean)
            self.qlib_std = np.asarray(self.qlib_std)
            # Misaligned arrays would silently mis-scale features later.
            if self.qlib_mean.shape != self.qlib_std.shape:
                raise ValueError(
                    "qlib_mean and qlib_std must have the same shape"
                )

    def process(self, df: "pl.DataFrame") -> "pl.DataFrame":
        """
        Apply robust z-score normalization.

        Args:
            df: Input DataFrame with feature columns

        Returns:
            DataFrame with the feature columns replaced by their
            normalized (and clipped) values
        """
        if self.use_qlib_params:
            # Pre-fitted parameters are aligned positionally with
            # self.feature_cols, so enumerate the FULL list and skip absent
            # columns; previously only existing columns were enumerated,
            # which shifted the parameter index of every column after a
            # missing one (bug fix).
            column_set = set(df.columns)
            n_params = len(self.qlib_mean)
            for i, col in enumerate(self.feature_cols):
                if i >= n_params:
                    break  # no parameters remain for later columns
                if col not in column_set:
                    continue
                mean_val = float(self.qlib_mean[i])
                std_val = float(self.qlib_std[i])

                # (x - mean) / std with a small epsilon to avoid div-by-zero
                df = df.with_columns([
                    ((pl.col(col) - mean_val) / (std_val + 1e-8))
                    .clip(self.clip_range[0], self.clip_range[1])
                    .alias(col)
                ])
        else:
            # Compute per-datetime robust z-score
            existing_cols = [c for c in self.feature_cols if c in df.columns]
            for col in existing_cols:
                # Median per datetime cross-section
                median_col = f"__median_{col}"
                df = df.with_columns([
                    pl.col(col).median().over('datetime').alias(median_col)
                ])

                # Absolute deviation from the median
                abs_dev_col = f"__absdev_{col}"
                df = df.with_columns([
                    (pl.col(col) - pl.col(median_col)).abs().alias(abs_dev_col)
                ])

                # MAD (median of absolute deviations), also per datetime
                mad_col = f"__mad_{col}"
                df = df.with_columns([
                    pl.col(abs_dev_col).median().over('datetime').alias(mad_col)
                ])

                # Robust z-score: 1.4826 * MAD estimates the std of a
                # normal distribution; clip to the configured range
                df = df.with_columns([
                    ((pl.col(col) - pl.col(median_col)) / (1.4826 * pl.col(mad_col) + 1e-8))
                    .clip(self.clip_range[0], self.clip_range[1])
                    .alias(col)
                ])

                # Clean up temporary columns
                df = df.drop([median_col, abs_dev_col, mad_col])

        return df
+
+
+# =============================================================================
+# Data Cleaning Processors
+# =============================================================================
+
class Fillna:
    """
    Replace null/NaN values in selected columns with a constant.

    Only columns whose dtype is one of the numeric types actually checked
    (Float32, Float64, Int32, Int64, UInt8, UInt16, UInt32, UInt64) are
    touched; every other column is left untouched.

    Attributes:
        fill_value: Replacement value (default: 0.0).

    Example:
        >>> processor = Fillna(fill_value=0.0)
        >>> df = processor.process(df, ['KMID', 'KLEN'])
    """

    def __init__(self, fill_value: float = 0.0):
        self.fill_value = fill_value

    def process(self, df: pl.DataFrame, columns: List[str]) -> pl.DataFrame:
        """
        Fill null/NaN values in the requested columns.

        Args:
            df: Input DataFrame.
            columns: Columns to fill.

        Returns:
            DataFrame with nulls and NaNs replaced in eligible columns.
        """
        eligible_dtypes = (pl.Float32, pl.Float64, pl.Int32, pl.Int64,
                           pl.UInt32, pl.UInt64, pl.UInt16, pl.UInt8)

        for name in columns:
            if name not in df.columns:
                continue
            if df[name].dtype in eligible_dtypes:
                # Handle both polars nulls and float NaNs.
                df = df.with_columns([
                    pl.col(name).fill_null(self.fill_value).fill_nan(self.fill_value)
                ])

        return df

    def __repr__(self) -> str:
        return f"Fillna(fill_value={self.fill_value})"
+
+
+# =============================================================================
+# Processor Pipeline Utilities
+# =============================================================================
+
def create_processor_pipeline(
    alpha158_cols: List[str],
    market_ext_base: List[str],
    market_flag_cols: List[str],
    industry_flag_cols: List[str],
    columns_to_remove: Optional[List[str]] = None,
    ntrl_suffix: str = '_ntrl'
) -> List:
    """
    Build the ordered processor list for the data_ops pipeline.

    Order of execution:
    1. DiffProcessor - adds diff features
    2. FlagMarketInjector - adds market_0, market_1
    3. FlagSTInjector - adds IsST
    4. ColumnRemover - removes specified columns
    5. FlagToOnehot - converts industry flags to index
    6. IndusNtrlInjector (x2) - neutralizes alpha158 and market_ext
    (RobustZScoreNorm and Fillna require fitted parameters and must be
    appended by the caller.)

    Args:
        alpha158_cols: List of alpha158 feature names
        market_ext_base: List of market extension base columns
        market_flag_cols: List of market flag columns.
            NOTE: accepted for interface symmetry but not referenced by
            any processor constructed here.
        industry_flag_cols: List of industry flag columns
        columns_to_remove: Columns to remove (default:
            ['log_size_diff', 'IsN', 'IsZt', 'IsDt'])
        ntrl_suffix: Suffix for neutralized columns

    Returns:
        List of processor instances in execution order
    """
    removals = (['log_size_diff', 'IsN', 'IsZt', 'IsDt']
                if columns_to_remove is None else columns_to_remove)

    pipeline: List = [
        DiffProcessor(market_ext_base),
        FlagMarketInjector(),
        FlagSTInjector(),
        ColumnRemover(removals),
        FlagToOnehot(industry_flag_cols),
        IndusNtrlInjector(alpha158_cols, suffix=ntrl_suffix),
        IndusNtrlInjector(market_ext_base, suffix=ntrl_suffix),
    ]
    # RobustZScoreNorm and Fillna need fitted parameters; add them separately.
    return pipeline
+
+
def get_final_feature_columns(
    alpha158_cols: List[str],
    market_ext_base: List[str],
    market_flag_cols: List[str],
    columns_to_remove: Optional[List[str]] = None,
    ntrl_suffix: str = '_ntrl'
) -> Dict[str, List[str]]:
    """
    Compute the feature column structure the pipeline produces.

    Useful for deriving expected VAE input dimensions and verifying
    feature order without running the processors.

    Args:
        alpha158_cols: List of alpha158 feature names
        market_ext_base: List of market extension base columns (before Diff)
        market_flag_cols: List of market flag columns (before ColumnRemover)
        columns_to_remove: Columns to remove (default:
            ['log_size_diff', 'IsN', 'IsZt', 'IsDt'])
        ntrl_suffix: Suffix for neutralized columns

    Returns:
        Dictionary with feature groups, notably:
            - 'norm_feature_cols': features to normalize, in Qlib order
              (alpha158_ntrl + alpha158 + market_ext_ntrl + market_ext)
            - 'market_flag_cols': flag columns after processing
            - 'all_feature_cols': everything, ending with indus_idx
    """
    removed = set(['log_size_diff', 'IsN', 'IsZt', 'IsDt']
                  if columns_to_remove is None else columns_to_remove)

    # DiffProcessor appends one "_diff" column per base column; the
    # ColumnRemover step then strips the removed names.
    with_diffs = market_ext_base + [f"{name}_diff" for name in market_ext_base]
    market_ext_final = [name for name in with_diffs if name not in removed]

    # Flags: survivors of ColumnRemover, then market_0/market_1 (from
    # FlagMarketInjector) and IsST (from FlagSTInjector).
    market_flag_final = [name for name in market_flag_cols if name not in removed]
    market_flag_final += ['market_0', 'market_1', 'IsST']

    alpha158_ntrl = [f"{name}{ntrl_suffix}" for name in alpha158_cols]
    market_ext_ntrl = [f"{name}{ntrl_suffix}" for name in market_ext_final]

    # Normalization order matches Qlib: ntrl group then raw group, twice.
    norm_feature_cols = alpha158_ntrl + alpha158_cols + market_ext_ntrl + market_ext_final

    return {
        'alpha158_cols': alpha158_cols,
        'alpha158_ntrl_cols': alpha158_ntrl,
        'market_ext_cols': market_ext_final,
        'market_ext_ntrl_cols': market_ext_ntrl,
        'market_flag_cols': market_flag_final,
        'norm_feature_cols': norm_feature_cols,
        'all_feature_cols': norm_feature_cols + market_flag_final + ['indus_idx'],
    }
diff --git a/stock_1d/d033/alpha158_beta/README.md b/stock_1d/d033/alpha158_beta/README.md
index bf6fdf6..c1b678a 100644
--- a/stock_1d/d033/alpha158_beta/README.md
+++ b/stock_1d/d033/alpha158_beta/README.md
@@ -19,18 +19,27 @@ stock_1d/d033/alpha158_beta/
│ ├── fetch_predictions.py # Fetch original predictions from DolphinDB
│ ├── predict_with_embedding.py # Generate predictions using beta embeddings
│ ├── compare_predictions.py # Compare 0_7 vs 0_7_beta predictions
-│ └── dump_polars_dataset.py # Dump raw and processed datasets using polars pipeline
+│ ├── dump_polars_dataset.py # Dump raw and processed datasets using polars pipeline
+│ └── extract_qlib_params.py # Extract RobustZScoreNorm parameters from Qlib proc_list
├── src/ # Source modules
│ └── qlib_loader.py # Qlib data loader with configurable date range
├── config/ # Configuration files
│ └── handler.yaml # Modified handler with configurable end date
-└── data/ # Data files (gitignored)
- ├── embedding_0_7_beta.parquet
- ├── predictions_beta_embedding.parquet
- ├── original_predictions_0_7.parquet
- ├── actual_returns.parquet
- ├── raw_data_*.pkl # Raw data before preprocessing
- └── processed_data_*.pkl # Processed data after preprocessing
+├── data/ # Data files (gitignored)
+│ ├── robust_zscore_params/ # Pre-fitted normalization parameters
+│ │ └── csiallx_feature2_ntrla_flag_pnlnorm/
+│ │ ├── mean_train.npy
+│ │ ├── std_train.npy
+│ │ └── metadata.json
+│ ├── embedding_0_7_beta.parquet
+│ ├── predictions_beta_embedding.parquet
+│ ├── original_predictions_0_7.parquet
+│ ├── actual_returns.parquet
+│ ├── raw_data_*.pkl # Raw data before preprocessing
+│ └── processed_data_*.pkl # Processed data after preprocessing
+└── data_polars/ # Polars-generated datasets (gitignored)
+ ├── raw_data_*.pkl
+ └── processed_data_*.pkl
```
## Data Loading with Configurable Date Range
@@ -122,7 +131,7 @@ This script:
- ColumnRemover (removes log_size_diff, IsN, IsZt, IsDt)
- FlagToOnehot (converts 29 industry flags to indus_idx)
- IndusNtrlInjector (industry neutralization)
- - RobustZScoreNorm (using pre-fitted qlib parameters)
+ - RobustZScoreNorm (using pre-fitted qlib parameters via `from_version()`)
- Fillna (fill NaN with 0)
4. Saves processed data to `data_polars/processed_data_*.pkl`
@@ -133,6 +142,38 @@ Output structure:
- Processed data: 342 columns (316 feature + 14 feature_ext + 11 feature_flag + 1 indus_idx)
- VAE input dimension: 341 (excluding indus_idx)
+### RobustZScoreNorm Parameter Extraction
+
+The pipeline uses pre-fitted normalization parameters extracted from Qlib's `proc_list.proc` file. These parameters are stored in `data/robust_zscore_params/` and can be loaded using the `RobustZScoreNorm.from_version()` method.
+
+**Extract parameters from Qlib proc_list:**
+
+```bash
+python scripts/extract_qlib_params.py --version csiallx_feature2_ntrla_flag_pnlnorm
+```
+
+This creates:
+- `data/robust_zscore_params/{version}/mean_train.npy` - Pre-fitted mean parameters (330,)
+- `data/robust_zscore_params/{version}/std_train.npy` - Pre-fitted std parameters (330,)
+- `data/robust_zscore_params/{version}/metadata.json` - Feature column names and metadata
+
+**Use in Polars processors:**
+
+```python
+from cta_1d.src.processors import RobustZScoreNorm
+
+# Load pre-fitted parameters by version name
+processor = RobustZScoreNorm.from_version("csiallx_feature2_ntrla_flag_pnlnorm")
+
+# Apply normalization to DataFrame
+df = processor.process(df)
+```
+
+**Parameter details:**
+- Fit period: 2013-01-01 to 2018-12-31
+- Feature count: 330 (158 alpha158_ntrl + 158 alpha158_raw + 7 market_ext_ntrl + 7 market_ext_raw)
+- Fields: ['feature', 'feature_ext']
+
## Workflow
### 1. Generate Beta Embeddings
diff --git a/stock_1d/d033/alpha158_beta/data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/metadata.json b/stock_1d/d033/alpha158_beta/data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/metadata.json
new file mode 100644
index 0000000..390a6ce
--- /dev/null
+++ b/stock_1d/d033/alpha158_beta/data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/metadata.json
@@ -0,0 +1,366 @@
+{
+ "version": "csiallx_feature2_ntrla_flag_pnlnorm",
+ "created_at": "2026-03-01T12:18:01.969109",
+ "source_file": "2013-01-01",
+ "fit_start_time": "2013-01-01",
+ "fit_end_time": "2018-12-31",
+ "fields_group": [
+ "feature",
+ "feature_ext"
+ ],
+ "feature_columns": {
+ "alpha158_ntrl": [
+ "KMID",
+ "KLEN",
+ "KMID2",
+ "KUP",
+ "KUP2",
+ "KLOW",
+ "KLOW2",
+ "KSFT",
+ "KSFT2",
+ "OPEN0",
+ "HIGH0",
+ "LOW0",
+ "VWAP0",
+ "ROC5",
+ "ROC10",
+ "ROC20",
+ "ROC30",
+ "ROC60",
+ "MA5",
+ "MA10",
+ "MA20",
+ "MA30",
+ "MA60",
+ "STD5",
+ "STD10",
+ "STD20",
+ "STD30",
+ "STD60",
+ "BETA5",
+ "BETA10",
+ "BETA20",
+ "BETA30",
+ "BETA60",
+ "RSQR5",
+ "RSQR10",
+ "RSQR20",
+ "RSQR30",
+ "RSQR60",
+ "RESI5",
+ "RESI10",
+ "RESI20",
+ "RESI30",
+ "RESI60",
+ "MAX5",
+ "MAX10",
+ "MAX20",
+ "MAX30",
+ "MAX60",
+ "MIN5",
+ "MIN10",
+ "MIN20",
+ "MIN30",
+ "MIN60",
+ "QTLU5",
+ "QTLU10",
+ "QTLU20",
+ "QTLU30",
+ "QTLU60",
+ "QTLD5",
+ "QTLD10",
+ "QTLD20",
+ "QTLD30",
+ "QTLD60",
+ "RANK5",
+ "RANK10",
+ "RANK20",
+ "RANK30",
+ "RANK60",
+ "RSV5",
+ "RSV10",
+ "RSV20",
+ "RSV30",
+ "RSV60",
+ "IMAX5",
+ "IMAX10",
+ "IMAX20",
+ "IMAX30",
+ "IMAX60",
+ "IMIN5",
+ "IMIN10",
+ "IMIN20",
+ "IMIN30",
+ "IMIN60",
+ "IMXD5",
+ "IMXD10",
+ "IMXD20",
+ "IMXD30",
+ "IMXD60",
+ "CORR5",
+ "CORR10",
+ "CORR20",
+ "CORR30",
+ "CORR60",
+ "CORD5",
+ "CORD10",
+ "CORD20",
+ "CORD30",
+ "CORD60",
+ "CNTP5",
+ "CNTP10",
+ "CNTP20",
+ "CNTP30",
+ "CNTP60",
+ "CNTN5",
+ "CNTN10",
+ "CNTN20",
+ "CNTN30",
+ "CNTN60",
+ "CNTD5",
+ "CNTD10",
+ "CNTD20",
+ "CNTD30",
+ "CNTD60",
+ "SUMP5",
+ "SUMP10",
+ "SUMP20",
+ "SUMP30",
+ "SUMP60",
+ "SUMN5",
+ "SUMN10",
+ "SUMN20",
+ "SUMN30",
+ "SUMN60",
+ "SUMD5",
+ "SUMD10",
+ "SUMD20",
+ "SUMD30",
+ "SUMD60",
+ "VMA5",
+ "VMA10",
+ "VMA20",
+ "VMA30",
+ "VMA60",
+ "VSTD5",
+ "VSTD10",
+ "VSTD20",
+ "VSTD30",
+ "VSTD60",
+ "WVMA5",
+ "WVMA10",
+ "WVMA20",
+ "WVMA30",
+ "WVMA60",
+ "VSUMP5",
+ "VSUMP10",
+ "VSUMP20",
+ "VSUMP30",
+ "VSUMP60",
+ "VSUMN5",
+ "VSUMN10",
+ "VSUMN20",
+ "VSUMN30",
+ "VSUMN60",
+ "VSUMD5",
+ "VSUMD10",
+ "VSUMD20",
+ "VSUMD30",
+ "VSUMD60"
+ ],
+ "alpha158_raw": [
+ "KMID",
+ "KLEN",
+ "KMID2",
+ "KUP",
+ "KUP2",
+ "KLOW",
+ "KLOW2",
+ "KSFT",
+ "KSFT2",
+ "OPEN0",
+ "HIGH0",
+ "LOW0",
+ "VWAP0",
+ "ROC5",
+ "ROC10",
+ "ROC20",
+ "ROC30",
+ "ROC60",
+ "MA5",
+ "MA10",
+ "MA20",
+ "MA30",
+ "MA60",
+ "STD5",
+ "STD10",
+ "STD20",
+ "STD30",
+ "STD60",
+ "BETA5",
+ "BETA10",
+ "BETA20",
+ "BETA30",
+ "BETA60",
+ "RSQR5",
+ "RSQR10",
+ "RSQR20",
+ "RSQR30",
+ "RSQR60",
+ "RESI5",
+ "RESI10",
+ "RESI20",
+ "RESI30",
+ "RESI60",
+ "MAX5",
+ "MAX10",
+ "MAX20",
+ "MAX30",
+ "MAX60",
+ "MIN5",
+ "MIN10",
+ "MIN20",
+ "MIN30",
+ "MIN60",
+ "QTLU5",
+ "QTLU10",
+ "QTLU20",
+ "QTLU30",
+ "QTLU60",
+ "QTLD5",
+ "QTLD10",
+ "QTLD20",
+ "QTLD30",
+ "QTLD60",
+ "RANK5",
+ "RANK10",
+ "RANK20",
+ "RANK30",
+ "RANK60",
+ "RSV5",
+ "RSV10",
+ "RSV20",
+ "RSV30",
+ "RSV60",
+ "IMAX5",
+ "IMAX10",
+ "IMAX20",
+ "IMAX30",
+ "IMAX60",
+ "IMIN5",
+ "IMIN10",
+ "IMIN20",
+ "IMIN30",
+ "IMIN60",
+ "IMXD5",
+ "IMXD10",
+ "IMXD20",
+ "IMXD30",
+ "IMXD60",
+ "CORR5",
+ "CORR10",
+ "CORR20",
+ "CORR30",
+ "CORR60",
+ "CORD5",
+ "CORD10",
+ "CORD20",
+ "CORD30",
+ "CORD60",
+ "CNTP5",
+ "CNTP10",
+ "CNTP20",
+ "CNTP30",
+ "CNTP60",
+ "CNTN5",
+ "CNTN10",
+ "CNTN20",
+ "CNTN30",
+ "CNTN60",
+ "CNTD5",
+ "CNTD10",
+ "CNTD20",
+ "CNTD30",
+ "CNTD60",
+ "SUMP5",
+ "SUMP10",
+ "SUMP20",
+ "SUMP30",
+ "SUMP60",
+ "SUMN5",
+ "SUMN10",
+ "SUMN20",
+ "SUMN30",
+ "SUMN60",
+ "SUMD5",
+ "SUMD10",
+ "SUMD20",
+ "SUMD30",
+ "SUMD60",
+ "VMA5",
+ "VMA10",
+ "VMA20",
+ "VMA30",
+ "VMA60",
+ "VSTD5",
+ "VSTD10",
+ "VSTD20",
+ "VSTD30",
+ "VSTD60",
+ "WVMA5",
+ "WVMA10",
+ "WVMA20",
+ "WVMA30",
+ "WVMA60",
+ "VSUMP5",
+ "VSUMP10",
+ "VSUMP20",
+ "VSUMP30",
+ "VSUMP60",
+ "VSUMN5",
+ "VSUMN10",
+ "VSUMN20",
+ "VSUMN30",
+ "VSUMN60",
+ "VSUMD5",
+ "VSUMD10",
+ "VSUMD20",
+ "VSUMD30",
+ "VSUMD60"
+ ],
+ "market_ext_ntrl": [
+ "turnover",
+ "free_turnover",
+ "log_size",
+ "con_rating_strength",
+ "turnover_diff",
+ "free_turnover_diff",
+ "con_rating_strength_diff"
+ ],
+ "market_ext_raw": [
+ "turnover",
+ "free_turnover",
+ "log_size",
+ "con_rating_strength",
+ "turnover_diff",
+ "free_turnover_diff",
+ "con_rating_strength_diff"
+ ]
+ },
+ "feature_count": {
+ "alpha158_ntrl": 158,
+ "alpha158_raw": 158,
+ "market_ext_ntrl": 7,
+ "market_ext_raw": 7,
+ "total": 330
+ },
+ "parameter_shapes": {
+ "mean_train": [
+ 330
+ ],
+ "std_train": [
+ 330
+ ]
+ }
+}
\ No newline at end of file
diff --git a/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py b/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py
index cc928df..7b961d0 100644
--- a/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py
+++ b/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py
@@ -20,18 +20,20 @@ from datetime import datetime
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
-from generate_beta_embedding import (
- load_all_data,
- merge_data_sources,
- filter_stock_universe,
+# Import processors from the new shared module
+from cta_1d.src.processors import (
DiffProcessor,
FlagMarketInjector,
+ FlagSTInjector,
ColumnRemover,
FlagToOnehot,
IndusNtrlInjector,
RobustZScoreNorm,
Fillna,
- load_qlib_processor_params,
+)
+
+# Import constants from local module
+from generate_beta_embedding import (
ALPHA158_COLS,
INDUSTRY_FLAG_COLS,
)
@@ -52,7 +54,7 @@ def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame:
This mimics the qlib proc_list:
0. Diff: Adds diff features for market_ext columns
1. FlagMarketInjector: Adds market_0, market_1
- 2. FlagSTInjector: SKIPPED (fails even in gold-standard)
+ 2. FlagSTInjector: Adds IsST (placeholder if ST flags not available)
3. ColumnRemover: Removes log_size_diff, IsN, IsZt, IsDt
4. FlagToOnehot: Converts 29 industry flags to indus_idx
5. IndusNtrlInjector: Industry neutralization for feature
@@ -88,9 +90,13 @@ def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame:
# Add market_0, market_1 to flag list
market_flag_with_market = market_flag_cols + ['market_0', 'market_1']
- # Step 3: FlagSTInjector - SKIPPED (fails even in gold-standard)
- print("[3] Skipping FlagSTInjector (as per gold-standard behavior)...")
- market_flag_with_st = market_flag_with_market # No IsST added
+ # Step 3: FlagSTInjector - adds IsST (placeholder if ST flags not available)
+ print("[3] Applying FlagSTInjector...")
+ flag_st_injector = FlagSTInjector()
+ df = flag_st_injector.process(df)
+
+ # Add IsST to flag list
+ market_flag_with_st = market_flag_with_market + ['IsST']
# Step 4: ColumnRemover
print("[4] Applying ColumnRemover...")
@@ -129,21 +135,18 @@ def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame:
print("[8] Applying RobustZScoreNorm...")
norm_feature_cols = alpha158_ntrl_cols + alpha158_cols + market_ext_ntrl_cols + market_ext_cols
- qlib_params = load_qlib_processor_params()
+ # Load RobustZScoreNorm with pre-fitted parameters from version
+ robust_norm = RobustZScoreNorm.from_version(
+ "csiallx_feature2_ntrla_flag_pnlnorm",
+ feature_cols=norm_feature_cols
+ )
- # Verify parameter shape
+ # Verify parameter shape matches expected features
expected_features = len(norm_feature_cols)
- if qlib_params['mean_train'].shape[0] != expected_features:
+ if robust_norm.qlib_mean.shape[0] != expected_features:
print(f" WARNING: Feature count mismatch! Expected {expected_features}, "
- f"got {qlib_params['mean_train'].shape[0]}")
-
- robust_norm = RobustZScoreNorm(
- norm_feature_cols,
- clip_range=(-3, 3),
- use_qlib_params=True,
- qlib_mean=qlib_params['mean_train'],
- qlib_std=qlib_params['std_train']
- )
+ f"got {robust_norm.qlib_mean.shape[0]}")
+
df = robust_norm.process(df)
# Step 9: Fillna
@@ -166,6 +169,9 @@ def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame":
"""
Convert polars DataFrame to pandas DataFrame with MultiIndex columns.
This matches the format of qlib's output.
+
+ IMPORTANT: Qlib's IndusNtrlInjector outputs columns in order [_ntrl] + [raw],
+ so we need to reorder columns to match this expected order.
"""
import pandas as pd
@@ -185,6 +191,7 @@ def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame":
df = df.drop(columns=existing_raw_cols)
# Build MultiIndex columns based on column name patterns
+ # IMPORTANT: Qlib order is [_ntrl columns] + [raw columns] for each group
columns_with_group = []
# Define column sets
@@ -195,23 +202,29 @@ def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame":
feature_flag_cols = {'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', 'open_limit', 'close_limit', 'low_limit',
'open_stop', 'close_stop', 'high_stop', 'market_0', 'market_1', 'IsST'}
+ # First pass: collect _ntrl columns (these come first in qlib order)
+ ntrl_alpha158_cols = []
+ ntrl_market_ext_cols = []
+ raw_alpha158_cols = []
+ raw_market_ext_cols = []
+ flag_cols = []
+ indus_idx_col = None
+
for col in df.columns:
if col == 'indus_idx':
- columns_with_group.append(('indus_idx', col))
+ indus_idx_col = col
elif col in feature_flag_cols:
- columns_with_group.append(('feature_flag', col))
+ flag_cols.append(col)
elif col.endswith('_ntrl'):
base_name = col[:-5] # Remove _ntrl suffix (5 characters)
if base_name in alpha158_base:
- columns_with_group.append(('feature', col))
+ ntrl_alpha158_cols.append(col)
elif base_name in market_ext_all:
- columns_with_group.append(('feature_ext', col))
- else:
- columns_with_group.append(('feature', col)) # Default to feature
+ ntrl_market_ext_cols.append(col)
elif col in alpha158_base:
- columns_with_group.append(('feature', col))
+ raw_alpha158_cols.append(col)
elif col in market_ext_all:
- columns_with_group.append(('feature_ext', col))
+ raw_market_ext_cols.append(col)
elif col in INDUSTRY_FLAG_COLS:
columns_with_group.append(('indus_flag', col))
elif col in {'ST_S', 'ST_Y', 'ST_T', 'ST_L', 'ST_Z', 'ST_X'}:
@@ -221,6 +234,27 @@ def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame":
print(f" Warning: Unknown column '{col}', assigning to 'other' group")
columns_with_group.append(('other', col))
+ # Build columns in qlib order: [_ntrl] + [raw] for each feature group
+ # Feature group: alpha158_ntrl + alpha158
+ for col in sorted(ntrl_alpha158_cols, key=lambda x: ALPHA158_COLS.index(x.replace('_ntrl', '')) if x.replace('_ntrl', '') in ALPHA158_COLS else 999):
+ columns_with_group.append(('feature', col))
+ for col in sorted(raw_alpha158_cols, key=lambda x: ALPHA158_COLS.index(x) if x in ALPHA158_COLS else 999):
+ columns_with_group.append(('feature', col))
+
+ # Feature_ext group: market_ext_ntrl + market_ext
+ for col in ntrl_market_ext_cols:
+ columns_with_group.append(('feature_ext', col))
+ for col in raw_market_ext_cols:
+ columns_with_group.append(('feature_ext', col))
+
+ # Feature_flag group
+ for col in flag_cols:
+ columns_with_group.append(('feature_flag', col))
+
+ # Indus_idx
+ if indus_idx_col:
+ columns_with_group.append(('indus_idx', indus_idx_col))
+
# Create MultiIndex columns
multi_cols = pd.MultiIndex.from_tuples(columns_with_group)
df.columns = multi_cols
diff --git a/stock_1d/d033/alpha158_beta/scripts/extract_qlib_params.py b/stock_1d/d033/alpha158_beta/scripts/extract_qlib_params.py
new file mode 100644
index 0000000..1d73846
--- /dev/null
+++ b/stock_1d/d033/alpha158_beta/scripts/extract_qlib_params.py
@@ -0,0 +1,305 @@
+#!/usr/bin/env python
+"""
+Extract RobustZScoreNorm parameters from Qlib's proc_list.proc file.
+
+This script extracts the pre-fitted mean_train and std_train parameters from
+Qlib's RobustZScoreNorm processor and saves them as reusable .npy files.
+
+The extracted parameters can be used by the Polars RobustZScoreNorm processor
+via the from_version() class method.
+
+Output structure:
+ stock_1d/d033/alpha158_beta/data/robust_zscore_params/
+ └── {version}/
+ ├── mean_train.npy
+ ├── std_train.npy
+ └── metadata.json
+"""
+
+import os
+import sys
+import json
+import pickle as pkl
+from pathlib import Path
+from datetime import datetime
+import numpy as np
+
+# Default paths
+DEFAULT_PROC_LIST_PATH = "/home/guofu/Workspaces/alpha/data_ops/tasks/dwm_feature_vae/dataset/csiallx_feature2_ntrla_flag_pnlnorm/proc_list.proc"
+DEFAULT_VERSION = "csiallx_feature2_ntrla_flag_pnlnorm"
+
+# Alpha158 columns in order (158 features)
+ALPHA158_COLS = [
+ 'KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2',
+ 'OPEN0', 'HIGH0', 'LOW0', 'VWAP0',
+ 'ROC5', 'ROC10', 'ROC20', 'ROC30', 'ROC60',
+ 'MA5', 'MA10', 'MA20', 'MA30', 'MA60',
+ 'STD5', 'STD10', 'STD20', 'STD30', 'STD60',
+ 'BETA5', 'BETA10', 'BETA20', 'BETA30', 'BETA60',
+ 'RSQR5', 'RSQR10', 'RSQR20', 'RSQR30', 'RSQR60',
+ 'RESI5', 'RESI10', 'RESI20', 'RESI30', 'RESI60',
+ 'MAX5', 'MAX10', 'MAX20', 'MAX30', 'MAX60',
+ 'MIN5', 'MIN10', 'MIN20', 'MIN30', 'MIN60',
+ 'QTLU5', 'QTLU10', 'QTLU20', 'QTLU30', 'QTLU60',
+ 'QTLD5', 'QTLD10', 'QTLD20', 'QTLD30', 'QTLD60',
+ 'RANK5', 'RANK10', 'RANK20', 'RANK30', 'RANK60',
+ 'RSV5', 'RSV10', 'RSV20', 'RSV30', 'RSV60',
+ 'IMAX5', 'IMAX10', 'IMAX20', 'IMAX30', 'IMAX60',
+ 'IMIN5', 'IMIN10', 'IMIN20', 'IMIN30', 'IMIN60',
+ 'IMXD5', 'IMXD10', 'IMXD20', 'IMXD30', 'IMXD60',
+ 'CORR5', 'CORR10', 'CORR20', 'CORR30', 'CORR60',
+ 'CORD5', 'CORD10', 'CORD20', 'CORD30', 'CORD60',
+ 'CNTP5', 'CNTP10', 'CNTP20', 'CNTP30', 'CNTP60',
+ 'CNTN5', 'CNTN10', 'CNTN20', 'CNTN30', 'CNTN60',
+ 'CNTD5', 'CNTD10', 'CNTD20', 'CNTD30', 'CNTD60',
+ 'SUMP5', 'SUMP10', 'SUMP20', 'SUMP30', 'SUMP60',
+ 'SUMN5', 'SUMN10', 'SUMN20', 'SUMN30', 'SUMN60',
+ 'SUMD5', 'SUMD10', 'SUMD20', 'SUMD30', 'SUMD60',
+ 'VMA5', 'VMA10', 'VMA20', 'VMA30', 'VMA60',
+ 'VSTD5', 'VSTD10', 'VSTD20', 'VSTD30', 'VSTD60',
+ 'WVMA5', 'WVMA10', 'WVMA20', 'WVMA30', 'WVMA60',
+ 'VSUMP5', 'VSUMP10', 'VSUMP20', 'VSUMP30', 'VSUMP60',
+ 'VSUMN5', 'VSUMN10', 'VSUMN20', 'VSUMN30', 'VSUMN60',
+ 'VSUMD5', 'VSUMD10', 'VSUMD20', 'VSUMD30', 'VSUMD60'
+]
+assert len(ALPHA158_COLS) == 158, f"Expected 158 alpha158 cols, got {len(ALPHA158_COLS)}"
+
+
+def extract_robust_zscore_params(proc_list_path: str) -> dict:
+ """
+ Extract RobustZScoreNorm parameters from Qlib's proc_list.proc file.
+
+ Args:
+ proc_list_path: Path to the proc_list.proc pickle file
+
+ Returns:
+ Dictionary containing:
+ - mean_train: numpy array of shape (330,)
+ - std_train: numpy array of shape (330,)
+ - fit_start_time: datetime string
+ - fit_end_time: datetime string
+ - fields_group: list of field groups
+ """
+ print(f"Loading proc_list.proc from: {proc_list_path}")
+
+ with open(proc_list_path, 'rb') as f:
+ proc_list = pkl.load(f)
+
+ print(f"Loaded {len(proc_list)} processors from proc_list")
+
+ # Find RobustZScoreNorm processor (typically at index 7)
+ zscore_proc = None
+ for i, proc in enumerate(proc_list):
+ proc_name = type(proc).__name__
+ print(f" [{i}] {proc_name}")
+ if proc_name == "RobustZScoreNorm":
+ zscore_proc = proc
+
+ if zscore_proc is None:
+ raise ValueError("RobustZScoreNorm processor not found in proc_list")
+
+ # Extract parameters
+ params = {
+ 'mean_train': zscore_proc.mean_train,
+ 'std_train': zscore_proc.std_train,
+ 'fit_start_time': getattr(zscore_proc, 'fit_start_time', None),
+ 'fit_end_time': getattr(zscore_proc, 'fit_end_time', None),
+ 'fields_group': getattr(zscore_proc, 'fields_group', None),
+ }
+
+ print(f"\nExtracted RobustZScoreNorm parameters:")
+ print(f" mean_train shape: {params['mean_train'].shape}, dtype: {params['mean_train'].dtype}")
+ print(f" std_train shape: {params['std_train'].shape}, dtype: {params['std_train'].dtype}")
+ print(f" fit_start_time: {params['fit_start_time']}")
+ print(f" fit_end_time: {params['fit_end_time']}")
+ print(f" fields_group: {params['fields_group']}")
+
+ return params
+
+
+def build_feature_column_names() -> list:
+ """
+ Build the complete list of 330 feature column names in order.
+
+ Feature order (330 total):
+ 1. alpha158_ntrl (158 features)
+ 2. alpha158_raw (158 features)
+ 3. market_ext_ntrl (7 features)
+ 4. market_ext_raw (7 features)
+
+ market_ext columns (after processing):
+ - Base: turnover, free_turnover, log_size, con_rating_strength
+ - Diff: turnover_diff, free_turnover_diff, con_rating_strength_diff
+ - Note: log_size_diff is removed by ColumnRemover
+ """
+ # Alpha158 neutralized columns (158)
+ alpha158_ntrl = [f"{c}_ntrl" for c in ALPHA158_COLS]
+
+ # Alpha158 raw columns (158)
+ alpha158_raw = ALPHA158_COLS.copy()
+
+ # market_ext columns (7 after ColumnRemover)
+ # After Diff: 4 base + 4 diff = 8
+ # After ColumnRemover (removes log_size_diff): 7 remain
+ market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']
+ market_ext_diff = ['turnover_diff', 'free_turnover_diff', 'log_size_diff', 'con_rating_strength_diff']
+ market_ext_all = market_ext_base + market_ext_diff
+ market_ext_final = [c for c in market_ext_all if c != 'log_size_diff']
+
+ # market_ext neutralized columns (7)
+ market_ext_ntrl = [f"{c}_ntrl" for c in market_ext_final]
+
+ # market_ext raw columns (7)
+ market_ext_raw = market_ext_final.copy()
+
+ # Combine all feature columns in Qlib order
+ feature_cols = alpha158_ntrl + alpha158_raw + market_ext_ntrl + market_ext_raw
+
+ print(f"\nBuilt feature column names:")
+ print(f" alpha158_ntrl: {len(alpha158_ntrl)} features")
+ print(f" alpha158_raw: {len(alpha158_raw)} features")
+ print(f" market_ext_ntrl: {len(market_ext_ntrl)} features")
+ print(f" market_ext_raw: {len(market_ext_raw)} features")
+ print(f" Total: {len(feature_cols)} features")
+
+ return feature_cols
+
+
+def save_parameters(
+ params: dict,
+ feature_cols: list,
+ output_dir: str,
+ version: str
+):
+ """
+ Save extracted parameters to output directory.
+
+ Creates:
+ - mean_train.npy
+ - std_train.npy
+ - metadata.json
+ """
+ output_path = Path(output_dir) / version
+ output_path.mkdir(parents=True, exist_ok=True)
+
+ print(f"\nSaving parameters to: {output_path}")
+
+ # Save mean_train.npy
+ mean_path = output_path / "mean_train.npy"
+ np.save(mean_path, params['mean_train'])
+ print(f" Saved mean_train to: {mean_path}")
+
+ # Save std_train.npy
+ std_path = output_path / "std_train.npy"
+ np.save(std_path, params['std_train'])
+ print(f" Saved std_train to: {std_path}")
+
+ # Build metadata
+ metadata = {
+ 'version': version,
+ 'created_at': datetime.now().isoformat(),
+ 'source_file': str(params.get('fit_start_time', 'unknown')),
+ 'fit_start_time': str(params['fit_start_time']),
+ 'fit_end_time': str(params['fit_end_time']),
+ 'fields_group': list(params['fields_group']) if params['fields_group'] else None,
+ 'feature_columns': {
+ 'alpha158_ntrl': ALPHA158_COLS, # Store base names, _ntrl is implied
+ 'alpha158_raw': ALPHA158_COLS,
+ 'market_ext_ntrl': [
+ 'turnover', 'free_turnover', 'log_size', 'con_rating_strength',
+ 'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'
+ ],
+ 'market_ext_raw': [
+ 'turnover', 'free_turnover', 'log_size', 'con_rating_strength',
+ 'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'
+ ],
+ },
+ 'feature_count': {
+ 'alpha158_ntrl': 158,
+ 'alpha158_raw': 158,
+ 'market_ext_ntrl': 7,
+ 'market_ext_raw': 7,
+ 'total': 330
+ },
+ 'parameter_shapes': {
+ 'mean_train': list(params['mean_train'].shape),
+ 'std_train': list(params['std_train'].shape)
+ }
+ }
+
+ # Save metadata.json
+ metadata_path = output_path / "metadata.json"
+ with open(metadata_path, 'w') as f:
+ json.dump(metadata, f, indent=2)
+ print(f" Saved metadata to: {metadata_path}")
+
+ return output_path
+
+
+def main():
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ description="Extract RobustZScoreNorm parameters from Qlib's proc_list.proc"
+ )
+ parser.add_argument(
+ '--proc-list',
+ type=str,
+ default=DEFAULT_PROC_LIST_PATH,
+ help=f"Path to proc_list.proc (default: {DEFAULT_PROC_LIST_PATH})"
+ )
+ parser.add_argument(
+ '--version',
+ type=str,
+ default=DEFAULT_VERSION,
+ help=f"Version name for output directory (default: {DEFAULT_VERSION})"
+ )
+ parser.add_argument(
+ '--output-dir',
+ type=str,
+ default=str(Path(__file__).parent.parent / "data" / "robust_zscore_params"),
+ help="Output directory for parameter files"
+ )
+
+ args = parser.parse_args()
+
+ print("=" * 80)
+ print("Extract Qlib RobustZScoreNorm Parameters")
+ print("=" * 80)
+ print(f"Source: {args.proc_list}")
+ print(f"Version: {args.version}")
+ print(f"Output: {args.output_dir}")
+ print()
+
+ # Step 1: Extract parameters from proc_list.proc
+ params = extract_robust_zscore_params(args.proc_list)
+
+ # Step 2: Build feature column names
+ feature_cols = build_feature_column_names()
+
+ # Step 3: Verify parameter shape matches feature count
+ if params['mean_train'].shape[0] != len(feature_cols):
+ print(f"\nWARNING: Parameter shape mismatch!")
+ print(f" Expected: {len(feature_cols)} features")
+ print(f" Got: {params['mean_train'].shape[0]} parameters")
+ else:
+ print(f"\n✓ Parameter shape matches feature count ({len(feature_cols)})")
+
+ # Step 4: Save parameters
+ output_path = save_parameters(params, feature_cols, args.output_dir, args.version)
+
+ print("\n" + "=" * 80)
+ print("Extraction complete!")
+ print("=" * 80)
+ print(f"Output files:")
+ print(f" {output_path}/mean_train.npy")
+ print(f" {output_path}/std_train.npy")
+ print(f" {output_path}/metadata.json")
+ print()
+ print("Usage in Polars RobustZScoreNorm:")
+ print(f' norm = RobustZScoreNorm.from_version("{args.version}")')
+ print("=" * 80)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py b/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py
index 6902d31..d8316c3 100644
--- a/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py
+++ b/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py
@@ -74,38 +74,25 @@ INDUSTRY_FLAG_COLS = [
'gds_CC50', 'gds_CC60', 'gds_CC61', 'gds_CC62', 'gds_CC63', 'gds_CC70'
]
-# Stock universe filter: csiallx = All A-shares excluding BSE/NEEQ and STAR market
-# This matches the original qlib handler configuration
-# - Include: SH600xxx, SH601xxx, SH603xxx, SH605xxx (Shanghai Main Board)
-# - Include: SZ000xxx, SZ001xxx, SZ002xxx, SZ003xxx (Shenzhen Main Board)
-# - Include: SZ300xxx, SZ301xxx (ChiNext)
-# - Exclude: SH688xxx, SH689xxx (STAR Market/科创板)
-# - Exclude: 4xxxxx, 8xxxxx (BSE/NEEQ/北交所/新三板)
-def filter_stock_universe(df: pl.DataFrame) -> pl.DataFrame:
+
+def filter_stock_universe(df: pl.DataFrame, instruments: str = 'csiallx') -> pl.DataFrame:
"""
- Filter dataframe to csiallx stock universe (A-shares only).
+ Filter dataframe to csiallx stock universe (A-shares excluding STAR/BSE) using qshare spine functions.
+
+ This uses qshare's filter_instruments which loads the instrument list from:
+ /data/qlib/default/data_ops/target/instruments/csiallx.txt
+
+ Args:
+ df: Input DataFrame with datetime and instrument columns
+ instruments: Market name for spine creation (default: 'csiallx')
- This filter matches the original qlib handler configuration which excludes:
- - BSE/NEEQ stocks (4xxxxx, 8xxxxx)
- - STAR Market stocks (688xxx, 689xxx)
+ Returns:
+ Filtered DataFrame with only instruments in the specified universe
"""
- inst_str = pl.col('instrument').cast(pl.String).str.zfill(6)
-
- # Define inclusion patterns
- is_sh_main = inst_str.str.starts_with('60') | inst_str.str.starts_with('61')
- is_sz_main = inst_str.str.starts_with('0')
- is_chi_next = inst_str.str.starts_with('300') | inst_str.str.starts_with('301')
-
- # Define exclusion patterns (explicitly exclude these)
- is_star = inst_str.str.starts_with('688') | inst_str.str.starts_with('689')
- is_bseeq = inst_str.str.starts_with('4') | inst_str.str.starts_with('8')
-
- # Filter: include main boards and ChiNext, exclude STAR and BSE/NEEQ
- df = df.filter(
- (is_sh_main | is_sz_main | is_chi_next) &
- (~is_star) &
- (~is_bseeq)
- )
+ from qshare.algo.polars.spine import filter_instruments
+
+ # Use qshare's filter_instruments with csiallx market name
+ df = filter_instruments(df, instruments=instruments)
return df