Simplify pipeline with struct columns

- Remove split_at_end parameter from pipeline.transform(), always return DataFrame
- Add pack_struct parameter to pack feature groups into struct columns
- Rename exporters: select_feature_groups_from_df -> get_groups, select_feature_groups -> get_groups_from_fg
- Add pack_structs() and unpack_struct() helper functions
- Remove split_from_merged() method from FeatureGroups (no longer needed)
- Rename dump_polars_dataset.py to dump_features.py with --pack-struct flag
- Update README with new CLI usage and struct column documentation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
master
guofu 2 days ago
parent 26a694298d
commit 5109ac4eb3

@ -115,17 +115,22 @@ All fixed processors preserve the trained parameters from the original proc_list
### Polars Dataset Generation
The `scripts/dump_polars_dataset.py` script generates datasets using a polars-based pipeline that replicates the qlib preprocessing:
The `scripts/dump_features.py` script generates datasets using a polars-based pipeline that replicates the qlib preprocessing:
```bash
# Generate raw and processed datasets
python scripts/dump_polars_dataset.py
# Generate merged features (flat columns)
python scripts/dump_features.py --start-date 2024-01-01 --end-date 2024-01-31 --groups merged
# Generate with struct columns (packed feature groups)
python scripts/dump_features.py --start-date 2024-01-01 --end-date 2024-01-31 --groups merged --pack-struct
# Generate specific feature groups
python scripts/dump_features.py --start-date 2024-01-01 --end-date 2024-01-31 --groups alpha158 market_ext
```
This script:
1. Loads data from Parquet files (alpha158, kline, market flags, industry flags)
2. Saves raw data (before processors) to `data_polars/raw_data_*.pkl`
3. Applies the full processor pipeline:
2. Applies the full processor pipeline:
- Diff processor (adds diff features)
- FlagMarketInjector (adds market_0, market_1)
- ColumnRemover (removes log_size_diff, IsN, IsZt, IsDt)
@ -133,13 +138,20 @@ This script:
- IndusNtrlInjector (industry neutralization)
- RobustZScoreNorm (using pre-fitted qlib parameters via `from_version()`)
- Fillna (fill NaN with 0)
4. Saves processed data to `data_polars/processed_data_*.pkl`
3. Saves to parquet/pickle format
**Output modes:**
- **Flat mode (default)**: All columns as separate fields (348 columns for merged)
- **Struct mode (`--pack-struct`)**: Feature groups packed into struct columns:
- `features_alpha158` (316 fields)
- `features_market_ext` (14 fields)
- `features_market_flag` (11 fields)
**Note**: The `FlagSTInjector` step is skipped because it fails silently even in the gold-standard qlib code (see `BUG_ANALYSIS_FINAL.md` for details).
Output structure:
- Raw data: ~204 columns (158 feature + 4 feature_ext + 12 feature_flag + 30 indus_flag)
- Processed data: 342 columns (316 feature + 14 feature_ext + 11 feature_flag + 1 indus_idx)
- Processed data: 348 columns (318 alpha158 + 14 market_ext + 14 market_flag + 2 index)
- VAE input dimension: 341 (excluding indus_idx)
### RobustZScoreNorm Parameter Extraction

@ -0,0 +1,265 @@
#!/usr/bin/env python
"""
Script to generate and dump transformed features from the alpha158_beta pipeline.
This script provides fine-grained control over the feature generation and dumping process:
- Select which feature groups to dump (alpha158, market_ext, market_flag, merged, vae_input)
- Choose output format (parquet, pickle, numpy)
- Control date range and universe filtering
- Save intermediate pipeline outputs
- Enable streaming mode for large datasets (>1 year)
Usage:
# Dump all features to parquet
python dump_features.py --start-date 2025-01-01 --end-date 2025-01-31
# Dump only alpha158 features to pickle
python dump_features.py --groups alpha158 --format pickle
# Dump with custom output path
python dump_features.py --output /path/to/output.parquet
# Dump merged features with all columns
python dump_features.py --groups merged --verbose
# Use streaming mode for large date ranges (>1 year)
python dump_features.py --start-date 2020-01-01 --end-date 2023-12-31 --streaming
"""
import os
import sys
import argparse
from pathlib import Path
from typing import Optional, List
# Add src to path for imports
SCRIPT_DIR = Path(__file__).parent
sys.path.insert(0, str(SCRIPT_DIR.parent / 'src'))
from processors import (
FeaturePipeline,
FeatureGroups,
VAE_INPUT_DIM,
ALPHA158_COLS,
MARKET_EXT_BASE_COLS,
COLUMNS_TO_REMOVE,
get_groups,
dump_to_parquet,
dump_to_pickle,
dump_to_numpy,
)
# Default output directory
DEFAULT_OUTPUT_DIR = SCRIPT_DIR.parent / "data"
def generate_and_dump(
    start_date: str,
    end_date: str,
    output_path: str,
    output_format: str = 'parquet',
    groups: Optional[List[str]] = None,
    universe: str = 'csiallx',
    filter_universe: bool = True,
    robust_zscore_params_path: Optional[str] = None,
    verbose: bool = True,
    pack_struct: bool = False,
    streaming: bool = False,
) -> None:
    """
    Generate features and dump to file.

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        output_path: Output file path
        output_format: Output format ('parquet', 'pickle', 'numpy')
        groups: Feature groups to dump (default: ['merged'])
        universe: Stock universe name
        filter_universe: Whether to filter to stock universe
        robust_zscore_params_path: Path to robust zscore parameters
        verbose: Whether to print progress
        pack_struct: If True, pack each feature group into struct columns
            (features_alpha158, features_market_ext, features_market_flag)
        streaming: If True, use Polars streaming mode for large datasets (>1 year)
    """
    if groups is None:
        groups = ['merged']
    print("=" * 60)
    print("Feature Dump Script")
    print("=" * 60)
    print(f"Date range: {start_date} to {end_date}")
    print(f"Output format: {output_format}")
    print(f"Feature groups: {groups}")
    print(f"Universe: {universe} (filter: {filter_universe})")
    print(f"Pack struct: {pack_struct}")
    print(f"Output path: {output_path}")
    print("=" * 60)
    # Initialize pipeline
    pipeline = FeaturePipeline(
        robust_zscore_params_path=robust_zscore_params_path
    )
    # Load data
    feature_groups = pipeline.load_data(
        start_date, end_date,
        filter_universe=filter_universe,
        universe_name=universe,
        streaming=streaming
    )
    # Apply transformations - pipeline always returns a merged DataFrame
    df_transformed = pipeline.transform(feature_groups, pack_struct=pack_struct)
    # Select requested feature groups from the merged DataFrame
    outputs = get_groups(df_transformed, groups, verbose, use_struct=False)
    # Ensure output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # Dump to file(s)
    if output_format == 'numpy':
        # NOTE(review): this passes the pre-transform `feature_groups` rather
        # than `df_transformed`/`outputs`, while the original comment claimed
        # "we save the merged features" -- confirm dump_to_numpy is really
        # meant to receive the loaded FeatureGroups.
        dump_to_numpy(feature_groups, output_path, include_metadata=True, verbose=verbose)
    else:
        # parquet and pickle share identical routing logic; only the writer
        # differs, so select it once instead of duplicating the branches.
        writer = dump_to_pickle if output_format == 'pickle' else dump_to_parquet
        if 'merged' in outputs:
            writer(outputs['merged'], output_path, verbose=verbose)
        else:
            # One file per requested group, suffixed with the group name.
            # (A single group is just the one-iteration case of this loop.)
            base_path = Path(output_path)
            for key, df_out in outputs.items():
                dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}"))
                writer(df_out, dump_path, verbose=verbose)
    print("=" * 60)
    print("Feature dump complete!")
    print("=" * 60)
def main():
    """CLI entry point: parse arguments and run the feature dump."""
    cli = argparse.ArgumentParser(
        description="Generate and dump transformed features from alpha158_beta pipeline"
    )
    # --- date range (both required) ---
    cli.add_argument(
        "--start-date", type=str, required=True,
        help="Start date in YYYY-MM-DD format"
    )
    cli.add_argument(
        "--end-date", type=str, required=True,
        help="End date in YYYY-MM-DD format"
    )
    # --- output location and format ---
    cli.add_argument(
        "--output", "-o", type=str, default=None,
        help=f"Output file path (default: {DEFAULT_OUTPUT_DIR}/features.parquet)"
    )
    cli.add_argument(
        "--format", "-f", type=str, default='parquet',
        choices=['parquet', 'pickle', 'numpy'],
        help="Output format (default: parquet)"
    )
    # --- feature group selection ---
    cli.add_argument(
        "--groups", "-g", type=str, nargs='+', default=['merged'],
        choices=['merged', 'alpha158', 'market_ext', 'market_flag', 'vae_input'],
        help="Feature groups to dump (default: merged)"
    )
    # --- universe filtering ---
    cli.add_argument(
        "--universe", type=str, default='csiallx',
        help="Stock universe name (default: csiallx)"
    )
    cli.add_argument(
        "--no-filter-universe", action="store_true",
        help="Disable stock universe filtering"
    )
    # --- normalization parameters ---
    cli.add_argument(
        "--robust-zscore-params", type=str, default=None,
        help="Path to robust zscore parameters directory"
    )
    # --- verbosity: verbose defaults on, --quiet wins over --verbose ---
    cli.add_argument(
        "--verbose", "-v", action="store_true", default=True,
        help="Enable verbose output (default: True)"
    )
    cli.add_argument(
        "--quiet", "-q", action="store_true",
        help="Disable verbose output"
    )
    # --- struct packing ---
    cli.add_argument(
        "--pack-struct", "-s", action="store_true",
        help="Pack each feature group into separate struct columns (features_alpha158, features_market_ext, features_market_flag)"
    )
    # --- streaming mode ---
    cli.add_argument(
        "--streaming", action="store_true",
        help="Use Polars streaming mode for large datasets (recommended for date ranges > 1 year)"
    )
    opts = cli.parse_args()
    # --quiet overrides the (default-on) --verbose flag
    be_verbose = opts.verbose and not opts.quiet
    # Resolve output path; generate_and_dump appends group suffixes when
    # needed, so the default uses the base name "features".
    if opts.output is None:
        destination = str(DEFAULT_OUTPUT_DIR / "features.parquet")
    else:
        destination = opts.output
    generate_and_dump(
        start_date=opts.start_date,
        end_date=opts.end_date,
        output_path=destination,
        output_format=opts.format,
        groups=opts.groups,
        universe=opts.universe,
        filter_universe=not opts.no_filter_universe,
        robust_zscore_params_path=opts.robust_zscore_params,
        verbose=be_verbose,
        pack_struct=opts.pack_struct,
        streaming=opts.streaming,
    )

@ -1,364 +0,0 @@
#!/usr/bin/env python
"""
Script to dump raw and processed datasets using the polars-based pipeline.
This generates:
1. Raw data (before applying processors) - equivalent to qlib's handler output
2. Processed data (after applying all processors) - ready for VAE encoding
Date range: 2026-02-23 to today (2026-02-27)
"""
import os
import sys
import pickle as pkl
import numpy as np
import polars as pl
from pathlib import Path
from datetime import datetime
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
# Import processors from the new shared module
from cta_1d.src.processors import (
DiffProcessor,
FlagMarketInjector,
FlagSTInjector,
ColumnRemover,
FlagToOnehot,
IndusNtrlInjector,
RobustZScoreNorm,
Fillna,
)
# Import constants from local module
from generate_beta_embedding import (
ALPHA158_COLS,
INDUSTRY_FLAG_COLS,
)
# Date range
START_DATE = "2026-02-23"
END_DATE = "2026-02-27"
# Output directory
OUTPUT_DIR = Path(__file__).parent.parent / "data_polars"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame:
    """
    Apply the full processor pipeline (equivalent to qlib's proc_list).

    This mimics the qlib proc_list:
    0. Diff: Adds diff features for market_ext columns
    1. FlagMarketInjector: Adds market_0, market_1
    2. FlagSTInjector: Adds IsST (placeholder if ST flags not available)
    3. ColumnRemover: Removes log_size_diff, IsN, IsZt, IsDt
    4. FlagToOnehot: Converts 29 industry flags to indus_idx
    5. IndusNtrlInjector: Industry neutralization for feature
    6. IndusNtrlInjector: Industry neutralization for feature_ext
    7. RobustZScoreNorm: Normalization using pre-fitted qlib params
    8. Fillna: Fill NaN with 0

    Args:
        df: Merged input frame containing alpha158, market_ext, market_flag
            and industry flag columns.

    Returns:
        The transformed DataFrame after all processors have run.
    """
    print("=" * 60)
    print("Applying processor pipeline")
    print("=" * 60)
    # market_ext columns (4 base)
    market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']
    # market_flag columns (12 total before ColumnRemover)
    market_flag_cols = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
                        'open_limit', 'close_limit', 'low_limit',
                        'open_stop', 'close_stop', 'high_stop']
    # Step 1: Diff Processor
    print("\n[1] Applying Diff processor...")
    diff_processor = DiffProcessor(market_ext_base)
    df = diff_processor.process(df)
    # After Diff: market_ext has 8 columns (4 base + 4 diff)
    market_ext_cols = market_ext_base + [f"{c}_diff" for c in market_ext_base]
    # Step 2: FlagMarketInjector (adds market_0, market_1)
    print("[2] Applying FlagMarketInjector...")
    flag_injector = FlagMarketInjector()
    df = flag_injector.process(df)
    # Add market_0, market_1 to flag list
    market_flag_with_market = market_flag_cols + ['market_0', 'market_1']
    # Step 3: FlagSTInjector - adds IsST (placeholder if ST flags not available)
    print("[3] Applying FlagSTInjector...")
    flag_st_injector = FlagSTInjector()
    df = flag_st_injector.process(df)
    # Add IsST to flag list
    market_flag_with_st = market_flag_with_market + ['IsST']
    # Step 4: ColumnRemover
    print("[4] Applying ColumnRemover...")
    columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt']
    remover = ColumnRemover(columns_to_remove)
    df = remover.process(df)
    # Update column lists after removal so later steps see the surviving set
    market_ext_cols = [c for c in market_ext_cols if c not in columns_to_remove]
    market_flag_with_st = [c for c in market_flag_with_st if c not in columns_to_remove]
    print(f" Removed columns: {columns_to_remove}")
    print(f" Remaining market_ext: {len(market_ext_cols)} columns")
    print(f" Remaining market_flag: {len(market_flag_with_st)} columns")
    # Step 5: FlagToOnehot (collapses industry flags into a single indus_idx)
    print("[5] Applying FlagToOnehot...")
    flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS)
    df = flag_to_onehot.process(df)
    # Step 6 & 7: IndusNtrlInjector (adds *_ntrl neutralized companions)
    print("[6] Applying IndusNtrlInjector for alpha158...")
    alpha158_cols = ALPHA158_COLS.copy()
    indus_ntrl_alpha = IndusNtrlInjector(alpha158_cols, suffix='_ntrl')
    df = indus_ntrl_alpha.process(df)
    print("[7] Applying IndusNtrlInjector for market_ext...")
    indus_ntrl_ext = IndusNtrlInjector(market_ext_cols, suffix='_ntrl')
    df = indus_ntrl_ext.process(df)
    # Build column lists for normalization
    alpha158_ntrl_cols = [f"{c}_ntrl" for c in alpha158_cols]
    market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_cols]
    # Step 8: RobustZScoreNorm
    # Order matters: [_ntrl] + [raw] per group must match the fitted params.
    print("[8] Applying RobustZScoreNorm...")
    norm_feature_cols = alpha158_ntrl_cols + alpha158_cols + market_ext_ntrl_cols + market_ext_cols
    # Load RobustZScoreNorm with pre-fitted parameters from version
    robust_norm = RobustZScoreNorm.from_version(
        "csiallx_feature2_ntrla_flag_pnlnorm",
        feature_cols=norm_feature_cols
    )
    # Verify parameter shape matches expected features (warn only, no abort)
    expected_features = len(norm_feature_cols)
    if robust_norm.qlib_mean.shape[0] != expected_features:
        print(f" WARNING: Feature count mismatch! Expected {expected_features}, "
              f"got {robust_norm.qlib_mean.shape[0]}")
    df = robust_norm.process(df)
    # Step 9: Fillna (fill NaN with 0 across all final feature columns)
    print("[9] Applying Fillna...")
    final_feature_cols = norm_feature_cols + market_flag_with_st + ['indus_idx']
    fillna = Fillna()
    df = fillna.process(df, final_feature_cols)
    print("\n" + "=" * 60)
    print("Processor pipeline complete!")
    print(f" Normalized features: {len(norm_feature_cols)}")
    print(f" Market flags: {len(market_flag_with_st)}")
    print(f" Total features (with indus_idx): {len(final_feature_cols)}")
    print("=" * 60)
    return df
def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame":
    """
    Convert polars DataFrame to pandas DataFrame with MultiIndex columns.
    This matches the format of qlib's output.

    IMPORTANT: Qlib's IndusNtrlInjector outputs columns in order [_ntrl] + [raw],
    so we need to reorder columns to match this expected order.

    Args:
        df_polars: Flat polars frame, optionally containing 'datetime' and
            'instrument' columns (used to build the row MultiIndex).

    Returns:
        pandas DataFrame with a two-level (group, column) column MultiIndex.
    """
    import pandas as pd
    # Convert to pandas
    df = df_polars.to_pandas()
    # Check if datetime and instrument are columns
    if 'datetime' in df.columns and 'instrument' in df.columns:
        # Set MultiIndex
        df = df.set_index(['datetime', 'instrument'])
    # If they're already not in columns, assume they're already the index
    # Drop raw columns that shouldn't be in processed data
    raw_cols_to_drop = ['Turnover', 'FreeTurnover', 'MarketValue']
    existing_raw_cols = [c for c in raw_cols_to_drop if c in df.columns]
    if existing_raw_cols:
        df = df.drop(columns=existing_raw_cols)
    # Build MultiIndex columns based on column name patterns
    # IMPORTANT: Qlib order is [_ntrl columns] + [raw columns] for each group
    columns_with_group = []
    # Define column sets
    alpha158_base = set(ALPHA158_COLS)
    market_ext_base = {'turnover', 'free_turnover', 'log_size', 'con_rating_strength'}
    # 'log_size_diff' is absent here -- presumably because ColumnRemover drops
    # it earlier in the pipeline; verify if called on unprocessed data.
    market_ext_diff = {'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'}
    market_ext_all = market_ext_base | market_ext_diff
    feature_flag_cols = {'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', 'open_limit', 'close_limit', 'low_limit',
                         'open_stop', 'close_stop', 'high_stop', 'market_0', 'market_1', 'IsST'}
    # First pass: classify every column into per-group buckets
    # (_ntrl columns are collected separately since they come first in qlib order)
    ntrl_alpha158_cols = []
    ntrl_market_ext_cols = []
    raw_alpha158_cols = []
    raw_market_ext_cols = []
    flag_cols = []
    indus_idx_col = None
    for col in df.columns:
        if col == 'indus_idx':
            indus_idx_col = col
        elif col in feature_flag_cols:
            flag_cols.append(col)
        elif col.endswith('_ntrl'):
            base_name = col[:-5]  # Remove _ntrl suffix (5 characters)
            # NOTE(review): a _ntrl column whose base matches neither set is
            # silently dropped from columns_with_group while remaining in df,
            # which would make the MultiIndex length mismatch -- confirm this
            # cannot happen with real pipeline output.
            if base_name in alpha158_base:
                ntrl_alpha158_cols.append(col)
            elif base_name in market_ext_all:
                ntrl_market_ext_cols.append(col)
        elif col in alpha158_base:
            raw_alpha158_cols.append(col)
        elif col in market_ext_all:
            raw_market_ext_cols.append(col)
        elif col in INDUSTRY_FLAG_COLS:
            columns_with_group.append(('indus_flag', col))
        elif col in {'ST_S', 'ST_Y', 'ST_T', 'ST_L', 'ST_Z', 'ST_X'}:
            columns_with_group.append(('st_flag', col))
        else:
            # Unknown column - print warning
            print(f" Warning: Unknown column '{col}', assigning to 'other' group")
            columns_with_group.append(('other', col))
    # Build columns in qlib order: [_ntrl] + [raw] for each feature group
    # Feature group: alpha158_ntrl + alpha158, sorted by canonical column order
    # (unknown names sort last via the 999 sentinel)
    for col in sorted(ntrl_alpha158_cols, key=lambda x: ALPHA158_COLS.index(x.replace('_ntrl', '')) if x.replace('_ntrl', '') in ALPHA158_COLS else 999):
        columns_with_group.append(('feature', col))
    for col in sorted(raw_alpha158_cols, key=lambda x: ALPHA158_COLS.index(x) if x in ALPHA158_COLS else 999):
        columns_with_group.append(('feature', col))
    # Feature_ext group: market_ext_ntrl + market_ext
    for col in ntrl_market_ext_cols:
        columns_with_group.append(('feature_ext', col))
    for col in raw_market_ext_cols:
        columns_with_group.append(('feature_ext', col))
    # Feature_flag group
    for col in flag_cols:
        columns_with_group.append(('feature_flag', col))
    # Indus_idx
    if indus_idx_col:
        columns_with_group.append(('indus_idx', indus_idx_col))
    # Create MultiIndex columns
    multi_cols = pd.MultiIndex.from_tuples(columns_with_group)
    df.columns = multi_cols
    return df
def main():
    """Load, merge, process and dump the polars dataset for the fixed
    START_DATE..END_DATE range, saving both raw and processed pickles
    to OUTPUT_DIR."""
    print("=" * 80)
    print("Dumping Polars Dataset")
    print("=" * 80)
    print(f"Date range: {START_DATE} to {END_DATE}")
    print(f"Output directory: {OUTPUT_DIR}")
    print()
    # Step 1: Load all data
    # NOTE(review): load_all_data and merge_data_sources are defined elsewhere
    # in this file (outside this block's view).
    print("Step 1: Loading data from parquet...")
    df_alpha, df_kline, df_flag, df_industry = load_all_data(START_DATE, END_DATE)
    print(f" Alpha158 shape: {df_alpha.shape}")
    print(f" Kline (market_ext) shape: {df_kline.shape}")
    print(f" Flags shape: {df_flag.shape}")
    print(f" Industry shape: {df_industry.shape}")
    # Step 2: Merge data sources
    print("\nStep 2: Merging data sources...")
    df_merged = merge_data_sources(df_alpha, df_kline, df_flag, df_industry)
    print(f" Merged shape (after csiallx filter): {df_merged.shape}")
    # Step 3: Save raw data (before processors)
    print("\nStep 3: Saving raw data (before processors)...")
    # Keep columns that match qlib's raw output format
    # Include datetime and instrument for MultiIndex conversion
    raw_columns = (
        ['datetime', 'instrument'] +  # Index columns
        ALPHA158_COLS +  # feature group
        ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] +  # feature_ext base
        ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',  # market_flag from kline
         'open_limit', 'close_limit', 'low_limit',
         'open_stop', 'close_stop', 'high_stop'] +
        INDUSTRY_FLAG_COLS +  # indus_flag
        (['ST_S', 'ST_Y'] if 'ST_S' in df_merged.columns else [])  # st_flag (if available)
    )
    # Filter to available columns so select() never fails on missing names
    available_raw_cols = [c for c in raw_columns if c in df_merged.columns]
    print(f" Selecting {len(available_raw_cols)} columns for raw data...")
    df_raw_polars = df_merged.select(available_raw_cols)
    # Convert to pandas with MultiIndex
    df_raw_pd = convert_to_multiindex_df(df_raw_polars)
    raw_output_path = OUTPUT_DIR / f"raw_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl"
    with open(raw_output_path, "wb") as f:
        pkl.dump(df_raw_pd, f)
    print(f" Saved raw data to: {raw_output_path}")
    print(f" Raw data shape: {df_raw_pd.shape}")
    print(f" Column groups: {df_raw_pd.columns.get_level_values(0).unique().tolist()}")
    # Step 4: Apply processor pipeline
    print("\nStep 4: Applying processor pipeline...")
    df_processed = apply_processor_pipeline(df_merged)
    # Step 5: Save processed data
    print("\nStep 5: Saving processed data (after processors)...")
    # Convert to pandas with MultiIndex
    df_processed_pd = convert_to_multiindex_df(df_processed)
    processed_output_path = OUTPUT_DIR / f"processed_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl"
    with open(processed_output_path, "wb") as f:
        pkl.dump(df_processed_pd, f)
    print(f" Saved processed data to: {processed_output_path}")
    print(f" Processed data shape: {df_processed_pd.shape}")
    print(f" Column groups: {df_processed_pd.columns.get_level_values(0).unique().tolist()}")
    # Count columns per group
    print("\n Column counts by group:")
    for grp in df_processed_pd.columns.get_level_values(0).unique().tolist():
        count = (df_processed_pd.columns.get_level_values(0) == grp).sum()
        print(f" {grp}: {count} columns")
    # Step 6: Verify column counts
    # Sanity check: FlagMarketInjector should have added market_0/market_1
    print("\n" + "=" * 80)
    print("Verification")
    print("=" * 80)
    feature_flag_cols = [c[1] for c in df_processed_pd.columns if c[0] == 'feature_flag']
    has_market_0 = 'market_0' in feature_flag_cols
    has_market_1 = 'market_1' in feature_flag_cols
    print(f" feature_flag columns: {feature_flag_cols}")
    print(f" Has market_0: {has_market_0}")
    print(f" Has market_1: {has_market_1}")
    if has_market_0 and has_market_1:
        print("\n SUCCESS: market_0 and market_1 columns are present!")
    else:
        print("\n WARNING: market_0 or market_1 columns are missing!")
    print("\n" + "=" * 80)
    print("Dataset dump complete!")
    print("=" * 80)

@ -0,0 +1,53 @@
"""
Source package for alpha158_beta experiments.
This package provides modules for data loading, feature transformation,
and model training for alpha158 beta factor experiments.
"""
from .processors import (
FeaturePipeline,
FeatureGroups,
DiffProcessor,
FlagMarketInjector,
FlagSTInjector,
ColumnRemover,
FlagToOnehot,
IndusNtrlInjector,
RobustZScoreNorm,
Fillna,
load_alpha158,
load_market_ext,
load_market_flags,
load_industry_flags,
load_all_data,
load_robust_zscore_params,
filter_stock_universe,
)
# Explicit public API for `from <package> import *` (PEP 8 convention).
__all__ = [
    # Main pipeline
    'FeaturePipeline',
    'FeatureGroups',
    # Processors
    'DiffProcessor',
    'FlagMarketInjector',
    'FlagSTInjector',
    'ColumnRemover',
    'FlagToOnehot',
    'IndusNtrlInjector',
    'RobustZScoreNorm',
    'Fillna',
    # Loaders
    'load_alpha158',
    'load_market_ext',
    'load_market_flags',
    'load_industry_flags',
    'load_all_data',
    # Utilities
    'load_robust_zscore_params',
    'filter_stock_universe',
]

@ -0,0 +1,136 @@
"""
Processors package for the alpha158_beta feature pipeline.
This package provides a modular, polars-native data loading and transformation
pipeline for generating VAE input features from alpha158 beta factors.
Main components:
- FeatureGroups: Dataclass container for separate feature groups
- FeaturePipeline: Main orchestrator class for the full pipeline
- Processors: Individual transformation classes (DiffProcessor, IndusNtrlInjector, etc.)
- Loaders: Data loading functions for parquet sources
Usage:
from processors import FeaturePipeline, FeatureGroups
pipeline = FeaturePipeline()
feature_groups = pipeline.load_data(start_date, end_date)
transformed = pipeline.transform(feature_groups)
vae_input = pipeline.get_vae_input(transformed)
"""
from .dataclass import FeatureGroups
from .loaders import (
load_alpha158,
load_market_ext,
load_market_flags,
load_industry_flags,
load_all_data,
load_parquet_by_date_range,
get_date_partitions,
# Constants
PARQUET_ALPHA158_BETA_PATH,
PARQUET_KLINE_PATH,
PARQUET_MARKET_FLAG_PATH,
PARQUET_INDUSTRY_FLAG_PATH,
PARQUET_CON_RATING_PATH,
INDUSTRY_FLAG_COLS,
MARKET_FLAG_COLS_KLINE,
MARKET_FLAG_COLS_MARKET,
MARKET_EXT_RAW_COLS,
)
from .processors import (
DiffProcessor,
FlagMarketInjector,
FlagSTInjector,
ColumnRemover,
FlagToOnehot,
IndusNtrlInjector,
RobustZScoreNorm,
Fillna,
)
from .pipeline import (
FeaturePipeline,
load_robust_zscore_params,
filter_stock_universe,
# Constants
ALPHA158_COLS,
MARKET_EXT_BASE_COLS,
MARKET_FLAG_COLS,
COLUMNS_TO_REMOVE,
VAE_INPUT_DIM,
DEFAULT_ROBUST_ZSCORE_PARAMS_PATH,
)
from .exporters import (
get_groups,
get_groups_from_fg,
pack_structs,
unpack_struct,
dump_to_parquet,
dump_to_pickle,
dump_to_numpy,
dump_features,
# Also import old names for backward compatibility
get_groups as select_feature_groups_from_df,
get_groups_from_fg as select_feature_groups,
)
# Explicit public API for `from processors import *` (PEP 8 convention).
__all__ = [
    # Main classes
    'FeaturePipeline',
    'FeatureGroups',
    # Processors
    'DiffProcessor',
    'FlagMarketInjector',
    'FlagSTInjector',
    'ColumnRemover',
    'FlagToOnehot',
    'IndusNtrlInjector',
    'RobustZScoreNorm',
    'Fillna',
    # Loaders
    'load_alpha158',
    'load_market_ext',
    'load_market_flags',
    'load_industry_flags',
    'load_all_data',
    'load_parquet_by_date_range',
    'get_date_partitions',
    # Utility functions
    'load_robust_zscore_params',
    'filter_stock_universe',
    # Exporter functions
    'get_groups',
    'get_groups_from_fg',
    'pack_structs',
    'unpack_struct',
    'dump_to_parquet',
    'dump_to_pickle',
    'dump_to_numpy',
    'dump_features',
    # Backward compatibility aliases (pre-rename exporter names)
    'select_feature_groups_from_df',
    'select_feature_groups',
    # Constants
    'PARQUET_ALPHA158_BETA_PATH',
    'PARQUET_KLINE_PATH',
    'PARQUET_MARKET_FLAG_PATH',
    'PARQUET_INDUSTRY_FLAG_PATH',
    'PARQUET_CON_RATING_PATH',
    'INDUSTRY_FLAG_COLS',
    'MARKET_FLAG_COLS_KLINE',
    'MARKET_FLAG_COLS_MARKET',
    'MARKET_EXT_RAW_COLS',
    'ALPHA158_COLS',
    'MARKET_EXT_BASE_COLS',
    'MARKET_FLAG_COLS',
    'COLUMNS_TO_REMOVE',
    'VAE_INPUT_DIM',
    'DEFAULT_ROBUST_ZSCORE_PARAMS_PATH',
]

@ -0,0 +1,115 @@
"""Dataclass definitions for the feature pipeline."""
from dataclasses import dataclass, field
from typing import Optional, List
import polars as pl
# Import constants for column categorization
# Alpha158 base columns (158 features: 13 price/candlestick fields + 29
# rolling families x 5 windows of 5/10/20/30/60 days)
ALPHA158_BASE_COLS = [
    'KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2',
    'OPEN0', 'HIGH0', 'LOW0', 'VWAP0',
    'ROC5', 'ROC10', 'ROC20', 'ROC30', 'ROC60',
    'MA5', 'MA10', 'MA20', 'MA30', 'MA60',
    'STD5', 'STD10', 'STD20', 'STD30', 'STD60',
    'BETA5', 'BETA10', 'BETA20', 'BETA30', 'BETA60',
    'RSQR5', 'RSQR10', 'RSQR20', 'RSQR30', 'RSQR60',
    'RESI5', 'RESI10', 'RESI20', 'RESI30', 'RESI60',
    'MAX5', 'MAX10', 'MAX20', 'MAX30', 'MAX60',
    'MIN5', 'MIN10', 'MIN20', 'MIN30', 'MIN60',
    'QTLU5', 'QTLU10', 'QTLU20', 'QTLU30', 'QTLU60',
    'QTLD5', 'QTLD10', 'QTLD20', 'QTLD30', 'QTLD60',
    'RANK5', 'RANK10', 'RANK20', 'RANK30', 'RANK60',
    'RSV5', 'RSV10', 'RSV20', 'RSV30', 'RSV60',
    'IMAX5', 'IMAX10', 'IMAX20', 'IMAX30', 'IMAX60',
    'IMIN5', 'IMIN10', 'IMIN20', 'IMIN30', 'IMIN60',
    'IMXD5', 'IMXD10', 'IMXD20', 'IMXD30', 'IMXD60',
    'CORR5', 'CORR10', 'CORR20', 'CORR30', 'CORR60',
    'CORD5', 'CORD10', 'CORD20', 'CORD30', 'CORD60',
    'CNTP5', 'CNTP10', 'CNTP20', 'CNTP30', 'CNTP60',
    'CNTN5', 'CNTN10', 'CNTN20', 'CNTN30', 'CNTN60',
    'CNTD5', 'CNTD10', 'CNTD20', 'CNTD30', 'CNTD60',
    'SUMP5', 'SUMP10', 'SUMP20', 'SUMP30', 'SUMP60',
    'SUMN5', 'SUMN10', 'SUMN20', 'SUMN30', 'SUMN60',
    'SUMD5', 'SUMD10', 'SUMD20', 'SUMD30', 'SUMD60',
    'VMA5', 'VMA10', 'VMA20', 'VMA30', 'VMA60',
    'VSTD5', 'VSTD10', 'VSTD20', 'VSTD30', 'VSTD60',
    'WVMA5', 'WVMA10', 'WVMA20', 'WVMA30', 'WVMA60',
    'VSUMP5', 'VSUMP10', 'VSUMP20', 'VSUMP30', 'VSUMP60',
    'VSUMN5', 'VSUMN10', 'VSUMN20', 'VSUMN30', 'VSUMN60',
    'VSUMD5', 'VSUMD10', 'VSUMD20', 'VSUMD30', 'VSUMD60'
]
# Market extension base columns (before the Diff processor adds *_diff)
MARKET_EXT_BASE_COLS = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']
# Market flag base columns (12 total, before processors; IsN/IsZt/IsDt are
# later dropped by ColumnRemover)
MARKET_FLAG_BASE_COLS = [
    'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
    'open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop'
]
@dataclass
class FeatureGroups:
    """
    Container for separate feature groups in the pipeline.

    Keeps feature groups separate throughout the pipeline to avoid
    unnecessary merging and complex column management.

    Attributes:
        alpha158: 158 alpha158 features (+ _ntrl after industry neutralization)
        market_ext: Market extension features (turnover, free_turnover,
            log_size, con_rating_strength, plus _diff/_ntrl after processing)
        market_flag: Market flag features (IsZt, IsDt, IsXD, etc.)
        industry: 29 industry flags (converted to indus_idx by FlagToOnehot)
        indus_idx: Single column industry index (after FlagToOnehot processing)
        instruments: List of instrument IDs for metadata
        dates: List of datetime values for metadata
    """
    # Core feature groups
    alpha158: pl.DataFrame  # 158 alpha158 features (+ _ntrl after processing)
    market_ext: pl.DataFrame  # Market extension features (+ _ntrl after processing)
    market_flag: pl.DataFrame  # Market flags (12 cols initially, 11 after ColumnRemover)
    industry: Optional[pl.DataFrame] = None  # 29 industry flags -> indus_idx
    # Processed industry index (separate after FlagToOnehot)
    indus_idx: Optional[pl.DataFrame] = None  # Single column after FlagToOnehot
    # Metadata (extracted from dataframes for easy access)
    instruments: List[str] = field(default_factory=list)
    # NOTE(review): extract_metadata() copies the raw 'datetime' column values
    # here; List[int] assumes epoch-like ints -- confirm the actual dtype.
    dates: List[int] = field(default_factory=list)

    def extract_metadata(self) -> None:
        """Extract instrument and date lists from the alpha158 dataframe.

        No-op when alpha158 is missing or empty.
        """
        if self.alpha158 is not None and len(self.alpha158) > 0:
            self.instruments = self.alpha158['instrument'].to_list()
            self.dates = self.alpha158['datetime'].to_list()

    def merge_for_processors(self) -> pl.DataFrame:
        """
        Merge all feature groups into a single DataFrame for processors.

        This is used by processors that need access to multiple groups
        (e.g., IndusNtrlInjector needs industry index).

        Returns:
            Merged DataFrame with all features, left-joined onto alpha158
            on (instrument, datetime).
        """
        df = self.alpha158
        # Merge market_ext if not already merged (identity check guards the
        # case where the same frame object is shared across groups)
        if self.market_ext is not None and self.market_ext is not self.alpha158:
            df = df.join(self.market_ext, on=['instrument', 'datetime'], how='left')
        # Merge market_flag if not already merged
        if self.market_flag is not None and self.market_flag is not self.alpha158:
            df = df.join(self.market_flag, on=['instrument', 'datetime'], how='left')
        # Prefer the processed indus_idx over raw industry flags when both exist
        if self.indus_idx is not None:
            df = df.join(self.indus_idx, on=['instrument', 'datetime'], how='left')
        elif self.industry is not None:
            df = df.join(self.industry, on=['instrument', 'datetime'], how='left')
        return df

@ -0,0 +1,523 @@
"""
Feature exporters for the alpha158_beta pipeline.
This module provides functions to select and export feature groups from the
transformed pipeline output. It can be used by both dump_features.py and
generate_beta_embedding.py.
Feature groups:
- merged: All columns after transformation
- alpha158: Alpha158 features + _ntrl versions
- market_ext: Market extended features + _ntrl + _diff
- market_flag: Market flag columns
- vae_input: 341 features specifically curated for VAE training
Struct mode:
- When pack_struct=True, each feature group is packed into a struct column:
- features_alpha158 (316 fields)
- features_market_ext (14 fields)
- features_market_flag (11 fields)
"""
import os
import pickle
from pathlib import Path
from typing import Dict, Any, List, Optional, Union
import numpy as np
import polars as pl
from .pipeline import (
ALPHA158_COLS,
MARKET_EXT_BASE_COLS,
MARKET_FLAG_COLS,
COLUMNS_TO_REMOVE,
)
from .dataclass import FeatureGroups
# =============================================================================
# Helper functions for struct packing/unpacking
# =============================================================================
def pack_structs(df: pl.DataFrame) -> pl.DataFrame:
    """
    Collapse flat feature columns into one struct column per feature group.

    Output columns: instrument, datetime, indus_idx (when present), plus
        - features_alpha158: 158 base + 158 _ntrl fields
        - features_market_ext: base/_diff columns + their _ntrl versions
        - features_market_flag: flag columns incl. market_0, market_1, IsST

    A struct column is emitted only when at least one of its member columns
    exists in *df*.

    Args:
        df: Input DataFrame with flat columns.

    Returns:
        DataFrame with index columns followed by the struct columns.
    """
    # Column membership per group (same ordering as the flat layout).
    alpha_group = [f"{c}_ntrl" for c in ALPHA158_COLS] + ALPHA158_COLS

    ext_base = [
        c
        for c in MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
        if c not in COLUMNS_TO_REMOVE
    ]
    ext_group = [f"{c}_ntrl" for c in ext_base] + ext_base

    flag_group = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
    flag_group = list(dict.fromkeys(flag_group + ['market_0', 'market_1', 'IsST']))

    selection = ['instrument', 'datetime']
    if 'indus_idx' in df.columns:
        selection.append('indus_idx')

    for member_cols, struct_name in (
        (alpha_group, 'features_alpha158'),
        (ext_group, 'features_market_ext'),
        (flag_group, 'features_market_flag'),
    ):
        present = [c for c in member_cols if c in df.columns]
        if present:
            selection.append(pl.struct(present).alias(struct_name))

    return df.select(selection)
def unpack_struct(df: pl.DataFrame, struct_name: str) -> pl.DataFrame:
    """
    Expand a struct column back into individual flat columns.

    The struct's fields are appended after the remaining columns, each
    keeping its field name.

    Args:
        df: DataFrame containing the struct column.
        struct_name: Name of the struct column to expand.

    Returns:
        DataFrame where *struct_name* is replaced by its individual fields.

    Raises:
        ValueError: If the column is missing or is not a struct dtype.
    """
    if struct_name not in df.columns:
        raise ValueError(f"Struct column '{struct_name}' not found in DataFrame")

    dtype = df.schema[struct_name]
    if not isinstance(dtype, pl.Struct):
        raise ValueError(f"Column '{struct_name}' is not a struct type")

    # One expression per struct field, aliased to the bare field name.
    field_exprs = [
        pl.col(struct_name).struct.field(f.name).alias(f.name)
        for f in dtype.fields
    ]
    remaining = [c for c in df.columns if c != struct_name]
    return df.select(remaining + field_exprs)
def dump_to_parquet(df: pl.DataFrame, path: str, verbose: bool = True) -> None:
    """Write *df* to a parquet file at *path*, logging progress when verbose."""
    if verbose:
        print(f"Saving to parquet: {path}")
    df.write_parquet(path)
    if verbose:
        print(f" Shape: {df.shape}")
def dump_to_pickle(df: pl.DataFrame, path: str, verbose: bool = True) -> None:
    """Serialize *df* to a pickle file at *path*, logging progress when verbose."""
    if verbose:
        print(f"Saving to pickle: {path}")
    with open(path, 'wb') as fh:
        pickle.dump(df, fh)
    if verbose:
        print(f" Shape: {df.shape}")
def dump_to_numpy(
    feature_groups: FeatureGroups,
    path: str,
    include_metadata: bool = True,
    verbose: bool = True
) -> None:
    """
    Save all merged features as a float32 numpy matrix.

    Writes:
        - <path>.npy: matrix of shape (n_samples, n_features), NaN/inf
          replaced with 0.0
        - <path>_metadata.pkl: feature column names plus row index info
          (only when include_metadata=True)

    Args:
        feature_groups: Transformed FeatureGroups container.
        path: Output path; '.npy' is appended when no suffix is given.
        include_metadata: Also write the companion metadata pickle.
        verbose: Whether to print progress.
    """
    if verbose:
        print(f"Saving to numpy: {path}")

    # All groups flattened into one frame; index columns are excluded
    # from the numeric matrix.
    merged = feature_groups.merge_for_processors()
    index_cols = ['instrument', 'datetime', 'indus_idx']
    feature_cols = [c for c in merged.columns if c not in index_cols]

    matrix = merged.select(feature_cols).to_numpy().astype(np.float32)
    matrix = np.nan_to_num(matrix, nan=0.0, posinf=0.0, neginf=0.0)

    target = Path(path)
    if target.suffix == '':
        target = Path(str(path) + '.npy')
    np.save(str(target), matrix)
    if verbose:
        print(f" Features shape: {matrix.shape}")

    if include_metadata:
        metadata_path = str(target).replace('.npy', '_metadata.pkl')
        metadata = {
            'feature_cols': feature_cols,
            'instruments': merged['instrument'].to_list(),
            'dates': merged['datetime'].to_list(),
            'n_features': len(feature_cols),
            'n_samples': len(merged),
        }
        with open(metadata_path, 'wb') as fh:
            pickle.dump(metadata, fh)
        if verbose:
            print(f" Metadata saved to: {metadata_path}")
def get_groups(
    df: pl.DataFrame,
    groups_to_dump: List[str],
    verbose: bool = True,
    use_struct: bool = False,
) -> Dict[str, Any]:
    """
    Select which feature groups to include in the output from a merged DataFrame.

    Args:
        df: Transformed merged DataFrame (flat or with struct columns)
        groups_to_dump: List of groups to include
            ('alpha158', 'market_ext', 'market_flag', 'merged', 'vae_input')
        verbose: Whether to print progress
        use_struct: If True, pack feature columns into a single 'features' struct
            column. If df already has struct columns (pack_struct=True was used in
            the pipeline), those are reused as-is.

    Returns:
        Dictionary mapping group name -> selected DataFrame
    """
    # Detect struct columns produced by pipeline.pack_struct so we can pass
    # them through instead of re-selecting flat columns.
    has_struct_cols = any(
        isinstance(df.schema.get(c), pl.Struct)
        for c in df.columns if c not in ['instrument', 'datetime', 'indus_idx']
    )
    result = {}
    if 'merged' in groups_to_dump:
        if has_struct_cols:
            # Already has struct columns from pipeline.pack_struct
            result['merged'] = df
        elif use_struct:
            # Keep instrument, datetime, and pack rest into struct
            feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime', 'indus_idx']]
            result['merged'] = df.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
        else:
            result['merged'] = df
        if verbose:
            print(f"Merged features: {result['merged'].shape}")
    if 'alpha158' in groups_to_dump:
        # Check if struct column already exists from pipeline.pack_struct
        if has_struct_cols and 'features_alpha158' in df.columns:
            result['alpha158'] = df.select(['instrument', 'datetime', 'features_alpha158'])
        else:
            # Select alpha158 columns (plus _ntrl versions) present in the frame
            alpha_cols = ['instrument', 'datetime'] + [c for c in ALPHA158_COLS if c in df.columns]
            alpha_cols += [f"{c}_ntrl" for c in ALPHA158_COLS if f"{c}_ntrl" in df.columns]
            df_alpha = df.select(alpha_cols)
            if use_struct:
                feature_cols = [c for c in df_alpha.columns if c not in ['instrument', 'datetime']]
                df_alpha = df_alpha.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
            result['alpha158'] = df_alpha
        if verbose:
            print(f"Alpha158 features: {result['alpha158'].shape}")
    if 'market_ext' in groups_to_dump:
        # Check if struct column already exists from pipeline.pack_struct
        if has_struct_cols and 'features_market_ext' in df.columns:
            result['market_ext'] = df.select(['instrument', 'datetime', 'features_market_ext'])
        else:
            market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
            market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
            # BUGFIX: only select columns actually present in df (consistent with
            # the alpha158 branch above); previously the base/_diff columns were
            # selected unconditionally and raised ColumnNotFoundError when absent.
            ext_cols = ['instrument', 'datetime'] + [c for c in market_ext_with_diff if c in df.columns]
            ext_cols += [f"{c}_ntrl" for c in market_ext_with_diff if f"{c}_ntrl" in df.columns]
            df_ext = df.select(ext_cols)
            if use_struct:
                feature_cols = [c for c in df_ext.columns if c not in ['instrument', 'datetime']]
                df_ext = df_ext.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
            result['market_ext'] = df_ext
        if verbose:
            print(f"Market ext features: {result['market_ext'].shape}")
    if 'market_flag' in groups_to_dump:
        # Check if struct column already exists from pipeline.pack_struct
        if has_struct_cols and 'features_market_flag' in df.columns:
            result['market_flag'] = df.select(['instrument', 'datetime', 'features_market_flag'])
        else:
            # Select market_flag columns from merged DataFrame
            flag_cols = ['instrument', 'datetime']
            flag_cols += [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE and c in df.columns]
            # The injected trio is added all-or-nothing
            flag_cols += ['market_0', 'market_1', 'IsST'] if all(c in df.columns for c in ['market_0', 'market_1', 'IsST']) else []
            flag_cols = list(dict.fromkeys(flag_cols))  # Remove duplicates
            df_flag = df.select(flag_cols)
            if use_struct:
                feature_cols = [c for c in df_flag.columns if c not in ['instrument', 'datetime']]
                df_flag = df_flag.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
            result['market_flag'] = df_flag
        if verbose:
            print(f"Market flag features: {result['market_flag'].shape}")
    if 'vae_input' in groups_to_dump:
        # VAE input = 330 normalized features + 11 market flags = 341 features
        # Note: indus_idx is NOT included in VAE input.
        # Feature ORDER matters here (must match the fitted normalization params),
        # so columns are selected strictly; a missing column raises, which is the
        # intended contract for VAE input.
        alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
        alpha158_cols = ALPHA158_COLS.copy()
        market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
        market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
        market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
        market_ext_cols = market_ext_with_diff.copy()
        # VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
        norm_feature_cols = (
            alpha158_ntrl_cols + alpha158_cols +
            market_ext_ntrl_cols + market_ext_cols
        )
        # Market flag columns (excluding IsST)
        market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
        market_flag_cols += ['market_0', 'market_1']
        market_flag_cols = list(dict.fromkeys(market_flag_cols))
        # Combine all VAE input columns (341 total)
        vae_cols = norm_feature_cols + market_flag_cols
        if use_struct:
            # Pack all features into a single struct column
            result['vae_input'] = df.select([
                'instrument',
                'datetime',
                pl.struct(vae_cols).alias('features')
            ])
            if verbose:
                print(f"VAE input features: {result['vae_input'].shape} (struct with {len(vae_cols)} fields)")
        else:
            # Always keep datetime and instrument as index columns
            result['vae_input'] = df.select(['instrument', 'datetime'] + vae_cols)
            if verbose:
                print(f"VAE input features: {result['vae_input'].shape} (columns: {len(vae_cols)} + 2 index)")
    return result
def get_groups_from_fg(
    feature_groups: FeatureGroups,
    groups_to_dump: List[str],
    verbose: bool = True,
    use_struct: bool = False,
) -> Dict[str, Any]:
    """
    Select which feature groups to include in the output.

    Unlike get_groups(), this variant reads from a FeatureGroups container
    (the per-group frames) instead of an already-merged DataFrame.

    Args:
        feature_groups: Transformed FeatureGroups container
        groups_to_dump: List of groups to include ('alpha158', 'market_ext', 'market_flag', 'merged', 'vae_input')
        verbose: Whether to print progress
        use_struct: If True, pack feature columns into a single 'features' struct column

    Returns:
        Dictionary with selected DataFrames and metadata
    """
    result = {}
    if 'merged' in groups_to_dump:
        df = feature_groups.merge_for_processors()
        if use_struct:
            # Keep instrument, datetime, and pack rest into struct
            feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime', 'indus_idx']]
            df = df.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
        result['merged'] = df
        if verbose:
            print(f"Merged features: {result['merged'].shape}")
    if 'alpha158' in groups_to_dump:
        df = feature_groups.alpha158
        if use_struct:
            feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime']]
            df = df.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
        result['alpha158'] = df
        if verbose:
            print(f"Alpha158 features: {df.shape}")
    if 'market_ext' in groups_to_dump:
        df = feature_groups.market_ext
        if use_struct:
            feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime']]
            df = df.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
        result['market_ext'] = df
        if verbose:
            print(f"Market ext features: {df.shape}")
    if 'market_flag' in groups_to_dump:
        df = feature_groups.market_flag
        if use_struct:
            feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime']]
            df = df.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
        result['market_flag'] = df
        if verbose:
            print(f"Market flag features: {df.shape}")
    if 'vae_input' in groups_to_dump:
        # Get VAE input columns from already-transformed feature_groups
        # VAE input = 330 normalized features + 11 market flags = 341 features
        # Note: indus_idx is NOT included in VAE input
        df = feature_groups.merge_for_processors()
        # Build alpha158 feature columns (158 original + 158 _ntrl = 316)
        alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
        alpha158_cols = ALPHA158_COLS.copy()
        # Build market_ext feature columns (7 original + 7 _ntrl = 14)
        market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
        market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
        market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
        market_ext_cols = market_ext_with_diff.copy()
        # VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
        # NOTE(review): this ordering presumably must match the fitted
        # normalization parameters — do not reorder; confirm against
        # the robust zscore param layout.
        norm_feature_cols = (
            alpha158_ntrl_cols + alpha158_cols +
            market_ext_ntrl_cols + market_ext_cols
        )
        # Market flag columns (excluding IsST)
        market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
        market_flag_cols += ['market_0', 'market_1']
        market_flag_cols = list(dict.fromkeys(market_flag_cols))
        # Combine all VAE input columns (341 total)
        # NOTE(review): columns are selected without existence checks here;
        # a missing column will raise — upstream pipeline must guarantee them.
        vae_cols = norm_feature_cols + market_flag_cols
        if use_struct:
            # Pack all features into a single struct column
            result['vae_input'] = df.select([
                'instrument',
                'datetime',
                pl.struct(vae_cols).alias('features')
            ])
            if verbose:
                print(f"VAE input features: {result['vae_input'].shape} (struct with {len(vae_cols)} fields)")
        else:
            # Always keep datetime and instrument as index columns
            # Select features with index columns first
            result['vae_input'] = df.select(['instrument', 'datetime'] + vae_cols)
            if verbose:
                print(f"VAE input features: {result['vae_input'].shape} (columns: {len(vae_cols)} + 2 index)")
    return result
def dump_features(
    df: pl.DataFrame,
    output_path: str,
    output_format: str = 'parquet',
    groups: Optional[List[str]] = None,
    verbose: bool = True,
    use_struct: bool = False,
) -> None:
    """
    Dump features to file.

    When 'merged' is among the selected groups, its frame is written to
    *output_path* directly; otherwise each group is written to
    <stem>_<group><suffix> next to *output_path*.

    Args:
        df: Transformed merged DataFrame
        output_path: Output file path
        output_format: Output format ('parquet', 'pickle'); 'numpy' is rejected
        groups: Feature groups to dump (default: ['merged'])
        verbose: Whether to print progress
        use_struct: If True, pack feature columns into a single 'features' struct column

    Raises:
        NotImplementedError: For output_format='numpy' (use the FeaturePipeline
            / dump_to_numpy path instead).
    """
    if groups is None:
        groups = ['merged']
    # Select feature groups from merged DataFrame
    outputs = get_groups(df, groups, verbose, use_struct)
    # Ensure output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # For numpy, use dump_to_numpy / the FeaturePipeline directly
    if output_format == 'numpy':
        raise NotImplementedError(
            "For numpy output, use the FeaturePipeline directly. "
            "This function handles parquet/pickle output only."
        )
    # The pickle and parquet paths were previously duplicated verbatim;
    # only the writer differs, so select it once.
    writer = dump_to_pickle if output_format == 'pickle' else dump_to_parquet
    if 'merged' in outputs:
        writer(outputs['merged'], output_path, verbose=verbose)
    else:
        # One file per group, suffixed with the group name.
        # (A single group goes through the same path as multiple groups.)
        base_path = Path(output_path)
        for key, df_out in outputs.items():
            dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}"))
            writer(df_out, dump_path, verbose=verbose)
    if verbose:
        print("Feature dump complete!")

@ -0,0 +1,279 @@
"""Data loading functions for the feature pipeline.
Memory Best Practices:
- Always filter on partition keys (datetime) BEFORE collecting
- Use streaming for large date ranges (>1 year)
- Never collect full parquet tables - filter on datetime first
"""
from datetime import datetime, timedelta
from typing import List, Optional, Tuple

import polars as pl
# Data paths (partitioned parquet datasets; partition key is 'datetime')
PARQUET_ALPHA158_BETA_PATH = "/data/parquet/dataset/stg_1day_wind_alpha158_0_7_beta_1D/"
PARQUET_KLINE_PATH = "/data/parquet/dataset/stg_1day_wind_kline_adjusted_1D/"
PARQUET_MARKET_FLAG_PATH = "/data/parquet/dataset/stg_1day_wind_market_flag_1D/"
PARQUET_INDUSTRY_FLAG_PATH = "/data/parquet/dataset/stg_1day_gds_indus_flag_cc1_1D/"
PARQUET_CON_RATING_PATH = "/data/parquet/dataset/stg_1day_gds_con_rating_1D/"
# Industry flag columns (30 one-hot columns - note: gds_CC29 is not present in the data)
INDUSTRY_FLAG_COLS = [
    'gds_CC10', 'gds_CC11', 'gds_CC12', 'gds_CC20', 'gds_CC21', 'gds_CC22',
    'gds_CC23', 'gds_CC24', 'gds_CC25', 'gds_CC26', 'gds_CC27', 'gds_CC28',
    'gds_CC30', 'gds_CC31', 'gds_CC32', 'gds_CC33', 'gds_CC34', 'gds_CC35',
    'gds_CC36', 'gds_CC37', 'gds_CC40', 'gds_CC41', 'gds_CC42', 'gds_CC43',
    'gds_CC50', 'gds_CC60', 'gds_CC61', 'gds_CC62', 'gds_CC63', 'gds_CC70'
]
# Market flag columns
# Flags sourced from the kline_adjusted table (see load_market_flags)
MARKET_FLAG_COLS_KLINE = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR']
# Flags sourced from the market_flag table (see load_market_flags)
MARKET_FLAG_COLS_MARKET = ['open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop']
# Market extension raw columns
# Raw kline columns later renamed/derived into turnover / free_turnover / log_size
# (see load_market_ext)
MARKET_EXT_RAW_COLS = ['Turnover', 'FreeTurnover', 'MarketValue']
def get_date_partitions(start_date: str, end_date: str) -> List[str]:
    """
    Generate a list of date partitions to load from Parquet.

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format

    Returns:
        List of datetime=YYYYMMDD partition strings for weekdays only
        (both endpoints inclusive)
    """
    start = datetime.strptime(start_date, "%Y-%m-%d")
    end = datetime.strptime(end_date, "%Y-%m-%d")
    partitions = []
    current = start
    while current <= end:
        if current.weekday() < 5:  # Monday = 0, Friday = 4
            partitions.append(f"datetime={current.strftime('%Y%m%d')}")
        # BUGFIX: previously advanced via datetime(year, month, day + 1),
        # which raises ValueError at the end of any month (e.g. day=32).
        current += timedelta(days=1)
    return partitions
def load_parquet_by_date_range(
    base_path: str,
    start_date: str,
    end_date: str,
    columns: Optional[List[str]] = None,
    collect: bool = True,
    streaming: bool = False
) -> pl.LazyFrame | pl.DataFrame:
    """
    Load parquet data filtered by date range.

    CRITICAL: This function filters on the datetime partition key BEFORE collecting.
    This is essential for memory efficiency with large parquet datasets.

    Args:
        base_path: Base path to the parquet dataset
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        columns: Optional list of columns to select (excluding instrument/datetime)
        collect: If True, return DataFrame (default); if False, return LazyFrame
        streaming: If True, use streaming mode for large datasets (recommended for >1 year)

    Returns:
        Polars DataFrame with the loaded data (LazyFrame if collect=False).
        On any load error, an empty DataFrame with the expected schema is
        returned instead of raising.
    """
    start_int = int(start_date.replace("-", ""))
    end_int = int(end_date.replace("-", ""))
    try:
        # Start with lazy scan - DO NOT COLLECT YET
        lf = pl.scan_parquet(base_path)
        # CRITICAL: Filter on partition key FIRST, before any other operations.
        # This ensures partition pruning happens at the scan level.
        lf = lf.filter(pl.col('datetime') >= start_int)
        lf = lf.filter(pl.col('datetime') <= end_int)
        # Select specific columns if provided (column pruning)
        if columns:
            # Get schema to check which columns exist
            schema = lf.collect_schema()
            available_cols = ['instrument', 'datetime'] + [c for c in columns if c in schema.names()]
            lf = lf.select(available_cols)
        # Collect with optional streaming mode
        if collect:
            if streaming:
                return lf.collect(streaming=True)
            return lf.collect()
        return lf
    except Exception as e:
        print(f"Error loading from {base_path}: {e}")
        # Return empty DataFrame with expected schema
        if columns:
            # BUGFIX: the comprehension filter previously referenced an
            # undefined name `c` instead of `col`, raising NameError on
            # this fallback path.
            return pl.DataFrame({
                'instrument': pl.Series([], dtype=pl.String),
                'datetime': pl.Series([], dtype=pl.Int32),
                **{col: pl.Series([], dtype=pl.Float64)
                   for col in columns if col not in ['instrument', 'datetime']}
            })
        return pl.DataFrame({
            'instrument': pl.Series([], dtype=pl.String),
            'datetime': pl.Series([], dtype=pl.Int32)
        })
def load_alpha158(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame:
    """
    Load the alpha158 beta factor table for [start_date, end_date].

    Args:
        start_date: Start date in YYYY-MM-DD format.
        end_date: End date in YYYY-MM-DD format.
        streaming: If True, collect lazily in streaming mode (large ranges).

    Returns:
        DataFrame with instrument, datetime, and the alpha158 feature columns.
    """
    print("Loading alpha158_0_7_beta factors...")
    frame = load_parquet_by_date_range(
        PARQUET_ALPHA158_BETA_PATH, start_date, end_date, streaming=streaming
    )
    print(f" Alpha158 shape: {frame.shape}")
    return frame
def load_market_ext(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame:
    """
    Build the market-extension feature frame from kline + consensus-rating data.

    Derivations applied to the raw kline columns (raw columns are kept):
        - Turnover      -> turnover (alias)
        - FreeTurnover  -> free_turnover (alias)
        - MarketValue   -> log_size = log(MarketValue)
    con_rating_strength is left-joined in and null-filled with 0.0.

    Args:
        start_date: Start date in YYYY-MM-DD format.
        end_date: End date in YYYY-MM-DD format.
        streaming: If True, use streaming mode for large datasets.

    Returns:
        DataFrame with instrument, datetime, turnover, free_turnover,
        log_size, con_rating_strength (plus the raw kline columns).
    """
    print("Loading kline data (market ext columns)...")
    kline = load_parquet_by_date_range(
        PARQUET_KLINE_PATH, start_date, end_date, MARKET_EXT_RAW_COLS, streaming=streaming
    )
    print(f" Kline (market ext raw) shape: {kline.shape}")

    print("Loading con_rating_strength from parquet...")
    ratings = load_parquet_by_date_range(
        PARQUET_CON_RATING_PATH, start_date, end_date, ['con_rating_strength'], streaming=streaming
    )
    print(f" Con rating shape: {ratings.shape}")

    # Derive the renamed / log-transformed feature columns.
    kline = kline.with_columns([
        pl.col('Turnover').alias('turnover'),
        pl.col('FreeTurnover').alias('free_turnover'),
        pl.col('MarketValue').log().alias('log_size'),
    ])
    print(f" Kline (market ext transformed) shape: {kline.shape}")

    # Attach consensus rating strength; instruments without a rating get 0.0.
    kline = kline.join(ratings, on=['instrument', 'datetime'], how='left')
    kline = kline.with_columns([
        pl.col('con_rating_strength').fill_null(0.0)
    ])
    print(f" Kline (with con_rating) shape: {kline.shape}")
    return kline
def load_market_flags(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame:
    """
    Load and combine market flag features from their two parquet sources.

    - kline_adjusted supplies: IsZt, IsDt, IsN, IsXD, IsXR, IsDR
    - market_flag supplies: open_limit, close_limit, low_limit,
      open_stop, close_stop, high_stop

    The two sources are inner-joined on (instrument, datetime).

    Args:
        start_date: Start date in YYYY-MM-DD format.
        end_date: End date in YYYY-MM-DD format.
        streaming: If True, use streaming mode for large datasets.

    Returns:
        DataFrame with instrument, datetime, and 12 market flag columns.
    """
    print("Loading market flags from kline_adjusted...")
    kline_flags = load_parquet_by_date_range(
        PARQUET_KLINE_PATH, start_date, end_date, MARKET_FLAG_COLS_KLINE, streaming=streaming
    )
    print(f" Kline flags shape: {kline_flags.shape}")

    print("Loading market flags from market_flag table...")
    market_flags = load_parquet_by_date_range(
        PARQUET_MARKET_FLAG_PATH, start_date, end_date, MARKET_FLAG_COLS_MARKET, streaming=streaming
    )
    print(f" Market flag shape: {market_flags.shape}")

    combined = kline_flags.join(market_flags, on=['instrument', 'datetime'], how='inner')
    print(f" Combined flags shape: {combined.shape}")
    return combined
def load_industry_flags(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame:
    """
    Load the one-hot industry flag columns for the given date range.

    Args:
        start_date: Start date in YYYY-MM-DD format.
        end_date: End date in YYYY-MM-DD format.
        streaming: If True, use streaming mode for large datasets.

    Returns:
        DataFrame with instrument, datetime, and the INDUSTRY_FLAG_COLS columns.
    """
    print("Loading industry flags...")
    frame = load_parquet_by_date_range(
        PARQUET_INDUSTRY_FLAG_PATH, start_date, end_date, INDUSTRY_FLAG_COLS, streaming=streaming
    )
    print(f" Industry shape: {frame.shape}")
    return frame
def load_all_data(
    start_date: str,
    end_date: str,
    streaming: bool = False
) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """
    Load all four parquet data sources for the given date range.

    Convenience wrapper around the individual loaders; each source is
    loaded in turn and returned separately.

    Args:
        start_date: Start date in YYYY-MM-DD format.
        end_date: End date in YYYY-MM-DD format.
        streaming: If True, use streaming mode for large datasets.

    Returns:
        Tuple of (alpha158_df, market_ext_df, market_flag_df, industry_df).
    """
    print(f"Loading data from {start_date} to {end_date}...")
    # Tuple elements are evaluated left to right, preserving the original
    # load (and log) order.
    return (
        load_alpha158(start_date, end_date, streaming=streaming),
        load_market_ext(start_date, end_date, streaming=streaming),
        load_market_flags(start_date, end_date, streaming=streaming),
        load_industry_flags(start_date, end_date, streaming=streaming),
    )

@ -0,0 +1,539 @@
"""FeaturePipeline orchestrator for the data loading and transformation pipeline."""
import os
import json
import numpy as np
import polars as pl
from typing import List, Dict, Optional, Tuple
from pathlib import Path
from .dataclass import FeatureGroups
from .loaders import (
load_alpha158,
load_market_ext,
load_market_flags,
load_industry_flags,
load_all_data,
INDUSTRY_FLAG_COLS,
)
from .processors import (
DiffProcessor,
FlagMarketInjector,
FlagSTInjector,
ColumnRemover,
FlagToOnehot,
IndusNtrlInjector,
RobustZScoreNorm,
Fillna
)
# Constants - Import from loaders module (single source of truth)
# INDUSTRY_FLAG_COLS is now imported from .loaders above
# Alpha158 feature columns in explicit order (158 columns; order matters
# downstream, e.g. for the VAE input layout in the exporters)
ALPHA158_COLS = [
    'KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2',
    'OPEN0', 'HIGH0', 'LOW0', 'VWAP0',
    'ROC5', 'ROC10', 'ROC20', 'ROC30', 'ROC60',
    'MA5', 'MA10', 'MA20', 'MA30', 'MA60',
    'STD5', 'STD10', 'STD20', 'STD30', 'STD60',
    'BETA5', 'BETA10', 'BETA20', 'BETA30', 'BETA60',
    'RSQR5', 'RSQR10', 'RSQR20', 'RSQR30', 'RSQR60',
    'RESI5', 'RESI10', 'RESI20', 'RESI30', 'RESI60',
    'MAX5', 'MAX10', 'MAX20', 'MAX30', 'MAX60',
    'MIN5', 'MIN10', 'MIN20', 'MIN30', 'MIN60',
    'QTLU5', 'QTLU10', 'QTLU20', 'QTLU30', 'QTLU60',
    'QTLD5', 'QTLD10', 'QTLD20', 'QTLD30', 'QTLD60',
    'RANK5', 'RANK10', 'RANK20', 'RANK30', 'RANK60',
    'RSV5', 'RSV10', 'RSV20', 'RSV30', 'RSV60',
    'IMAX5', 'IMAX10', 'IMAX20', 'IMAX30', 'IMAX60',
    'IMIN5', 'IMIN10', 'IMIN20', 'IMIN30', 'IMIN60',
    'IMXD5', 'IMXD10', 'IMXD20', 'IMXD30', 'IMXD60',
    'CORR5', 'CORR10', 'CORR20', 'CORR30', 'CORR60',
    'CORD5', 'CORD10', 'CORD20', 'CORD30', 'CORD60',
    'CNTP5', 'CNTP10', 'CNTP20', 'CNTP30', 'CNTP60',
    'CNTN5', 'CNTN10', 'CNTN20', 'CNTN30', 'CNTN60',
    'CNTD5', 'CNTD10', 'CNTD20', 'CNTD30', 'CNTD60',
    'SUMP5', 'SUMP10', 'SUMP20', 'SUMP30', 'SUMP60',
    'SUMN5', 'SUMN10', 'SUMN20', 'SUMN30', 'SUMN60',
    'SUMD5', 'SUMD10', 'SUMD20', 'SUMD30', 'SUMD60',
    'VMA5', 'VMA10', 'VMA20', 'VMA30', 'VMA60',
    'VSTD5', 'VSTD10', 'VSTD20', 'VSTD30', 'VSTD60',
    'WVMA5', 'WVMA10', 'WVMA20', 'WVMA30', 'WVMA60',
    'VSUMP5', 'VSUMP10', 'VSUMP20', 'VSUMP30', 'VSUMP60',
    'VSUMN5', 'VSUMN10', 'VSUMN20', 'VSUMN30', 'VSUMN60',
    'VSUMD5', 'VSUMD10', 'VSUMD20', 'VSUMD30', 'VSUMD60'
]
# Market extension base columns (derived from kline / consensus-rating data)
MARKET_EXT_BASE_COLS = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']
# Market flag columns (before processors)
MARKET_FLAG_COLS = [
    'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
    'open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop'
]
# Columns to remove after FlagMarketInjector and FlagSTInjector
COLUMNS_TO_REMOVE = ['log_size_diff', 'IsN', 'IsZt', 'IsDt']
# Expected VAE input dimension (presumably 316 alpha158 + 14 market_ext
# + 11 flags = 341; see the exporters module)
VAE_INPUT_DIM = 341
# Default robust zscore parameters path
# NOTE(review): machine-specific absolute path — confirm availability on the
# target host before deploying.
DEFAULT_ROBUST_ZSCORE_PARAMS_PATH = (
    "/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/"
    "data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/"
)
def filter_stock_universe(df: pl.DataFrame, instruments: str = 'csiallx') -> pl.DataFrame:
    """
    Restrict *df* to the given stock universe (default csiallx: A-shares
    excluding STAR/BSE).

    Thin wrapper over qshare's filter_instruments, which resolves the
    instrument list from:
    /data/qlib/default/data_ops/target/instruments/csiallx.txt

    Args:
        df: Input DataFrame with datetime and instrument columns.
        instruments: Universe name passed through to qshare (default 'csiallx').

    Returns:
        DataFrame filtered to instruments in the requested universe.
    """
    # Imported lazily so importing this module does not require qshare.
    from qshare.algo.polars.spine import filter_instruments

    return filter_instruments(df, instruments=instruments)
def load_robust_zscore_params(
    params_path: str = None
) -> Dict[str, np.ndarray]:
    """
    Load pre-fitted RobustZScoreNorm parameters from numpy files.

    Reads mean_train.npy and std_train.npy from *params_path* and caches
    the result on the function object, so repeated calls for the same
    directory skip the file I/O.

    Args:
        params_path: Directory containing mean_train.npy and std_train.npy.
            Defaults to DEFAULT_ROBUST_ZSCORE_PARAMS_PATH when None.

    Returns:
        Dictionary with 'mean_train' and 'std_train' numpy arrays.

    Raises:
        FileNotFoundError: If either parameter file is missing.
    """
    if params_path is None:
        params_path = DEFAULT_ROBUST_ZSCORE_PARAMS_PATH

    # Module-level cache stored as a function attribute.
    cache = getattr(load_robust_zscore_params, '_cached_params', None)
    if cache is None:
        cache = {}
        load_robust_zscore_params._cached_params = cache
    if params_path in cache:
        return cache[params_path]

    print(f"Loading robust zscore parameters from: {params_path}")
    mean_file = os.path.join(params_path, 'mean_train.npy')
    std_file = os.path.join(params_path, 'std_train.npy')
    if not os.path.exists(mean_file):
        raise FileNotFoundError(f"mean_train.npy not found at {mean_file}")
    if not os.path.exists(std_file):
        raise FileNotFoundError(f"std_train.npy not found at {std_file}")

    mean_train = np.load(mean_file)
    std_train = np.load(std_file)
    print("Loaded parameters:")
    print(f" mean_train shape: {mean_train.shape}")
    print(f" std_train shape: {std_train.shape}")

    # Report the fit window when a metadata file accompanies the params.
    metadata_file = os.path.join(params_path, 'metadata.json')
    if os.path.exists(metadata_file):
        with open(metadata_file, 'r') as fh:
            metadata = json.load(fh)
        print(f" fitted on: {metadata.get('fit_start_time', 'N/A')} to {metadata.get('fit_end_time', 'N/A')}")

    params = {
        'mean_train': mean_train,
        'std_train': std_train
    }
    cache[params_path] = params
    return params
class FeaturePipeline:
    """
    Feature pipeline orchestrator for loading and transforming data.

    The pipeline manages:
    1. Data loading from parquet sources
    2. Feature transformations via processors
    3. Output preparation for VAE input

    Usage:
        pipeline = FeaturePipeline(config)
        feature_groups = pipeline.load_data(start_date, end_date)
        transformed = pipeline.transform(feature_groups)
        vae_input = pipeline.get_vae_input(transformed)
    """
    def __init__(
        self,
        config: Optional[Dict] = None,
        robust_zscore_params_path: Optional[str] = None
    ):
        """
        Initialize the FeaturePipeline.

        Args:
            config: Optional configuration dictionary with pipeline settings.
                If None, uses default configuration.
            robust_zscore_params_path: Path to robust zscore parameters.
                If None, uses default path.
        """
        self.config = config or {}
        self.robust_zscore_params_path = robust_zscore_params_path or DEFAULT_ROBUST_ZSCORE_PARAMS_PATH
        # Cache for loaded robust zscore params (populated lazily in transform())
        self._robust_zscore_params = None
        # Initialize processors
        self._init_processors()
    def _init_processors(self) -> None:
        """Initialize all processors with default configuration."""
        # Diff processor for market_ext columns
        self.diff_processor = DiffProcessor(MARKET_EXT_BASE_COLS)
        # Market flag injector
        self.flag_market_injector = FlagMarketInjector()
        # ST flag injector
        self.flag_st_injector = FlagSTInjector()
        # Column remover
        self.column_remover = ColumnRemover(COLUMNS_TO_REMOVE)
        # Industry flag to index converter
        self.flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS)
        # Industry neutralization injectors (created on-demand with specific feature lists)
        # self.indus_ntrl_injector = None  # Created per-feature-group
        # Robust zscore normalizer (created on-demand with loaded params)
        # self.robust_norm = None  # Created with loaded params
        # Fillna processor
        self.fillna = Fillna()
    def load_data(
        self,
        start_date: str,
        end_date: str,
        filter_universe: bool = True,
        universe_name: str = 'csiallx',
        streaming: bool = False
    ) -> FeatureGroups:
        """
        Load data from parquet sources into FeatureGroups container.

        Args:
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            filter_universe: Whether to filter to a specific stock universe
            universe_name: Name of the stock universe to filter to
            streaming: If True, use Polars streaming mode for large datasets (>1 year)

        Returns:
            FeatureGroups container with loaded data (metadata extracted)
        """
        print("=" * 60)
        print(f"Loading data from {start_date} to {end_date}")
        print("=" * 60)
        # Load all data sources
        df_alpha, df_market_ext, df_market_flag, df_industry = load_all_data(
            start_date, end_date, streaming=streaming
        )
        # Apply stock universe filter if requested (each source filtered
        # independently so the groups stay row-consistent downstream)
        if filter_universe:
            print(f"Filtering to {universe_name} universe...")
            df_alpha = filter_stock_universe(df_alpha, instruments=universe_name)
            df_market_ext = filter_stock_universe(df_market_ext, instruments=universe_name)
            df_market_flag = filter_stock_universe(df_market_flag, instruments=universe_name)
            df_industry = filter_stock_universe(df_industry, instruments=universe_name)
            print(f"  After filter - Alpha158 shape: {df_alpha.shape}")
        # Create FeatureGroups container
        feature_groups = FeatureGroups(
            alpha158=df_alpha,
            market_ext=df_market_ext,
            market_flag=df_market_flag,
            industry=df_industry
        )
        # Extract metadata
        feature_groups.extract_metadata()
        print(f"Loaded {len(feature_groups.instruments)} samples")
        print("=" * 60)
        return feature_groups
    def transform(
        self,
        feature_groups: FeatureGroups,
        pack_struct: bool = False
    ) -> pl.DataFrame:
        """
        Apply feature transformation pipeline to FeatureGroups.

        The pipeline applies processors in the following order:
        1. DiffProcessor - adds diff features to market_ext
        2. FlagMarketInjector - adds market_0, market_1 to market_flag
        3. FlagSTInjector - adds IsST to market_flag
        4. ColumnRemover - removes log_size_diff, IsN, IsZt, IsDt
        5. FlagToOnehot - converts industry flags to indus_idx
        6. IndusNtrlInjector - industry neutralization for alpha158 and market_ext
        7. RobustZScoreNorm - robust z-score normalization
        8. Fillna - fill NaN values with 0

        Args:
            feature_groups: FeatureGroups container with loaded data
            pack_struct: If True, pack each feature group into a struct column
                (features_alpha158, features_market_ext, features_market_flag).
                If False (default), return flat DataFrame with all columns merged.

        Returns:
            Merged DataFrame with all transformed features (pl.DataFrame)
        """
        print("=" * 60)
        print("Starting feature transformation pipeline")
        print("=" * 60)
        # Merge all groups for processing
        df = feature_groups.merge_for_processors()
        # Step 1: Diff Processor - adds diff features for market_ext
        df = self.diff_processor.process(df)
        # Step 2: FlagMarketInjector - adds market_0, market_1
        df = self.flag_market_injector.process(df)
        # Step 3: FlagSTInjector - adds IsST
        df = self.flag_st_injector.process(df)
        # Step 4: ColumnRemover - removes specific columns
        df = self.column_remover.process(df)
        # Step 5: FlagToOnehot - converts industry flags to indus_idx
        df = self.flag_to_onehot.process(df)
        # Step 6: IndusNtrlInjector - industry neutralization
        # For alpha158 features
        indus_ntrl_alpha = IndusNtrlInjector(ALPHA158_COLS, suffix='_ntrl')
        df = indus_ntrl_alpha.process(df, df)  # Pass df as both feature and industry source
        # For market_ext features (with diff columns)
        market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
        # Remove columns that were dropped
        market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
        indus_ntrl_ext = IndusNtrlInjector(market_ext_with_diff, suffix='_ntrl')
        df = indus_ntrl_ext.process(df, df)
        # Step 7: RobustZScoreNorm - robust z-score normalization
        # Load parameters and create normalizer (lazily, cached on the instance)
        if self._robust_zscore_params is None:
            self._robust_zscore_params = load_robust_zscore_params(self.robust_zscore_params_path)
        # Build the list of features to normalize
        alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
        market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
        # Feature order for VAE: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
        # NOTE: this order must match the order the pre-fitted mean/std arrays
        # were computed with, since RobustZScoreNorm aligns them positionally.
        norm_feature_cols = (
            alpha158_ntrl_cols + ALPHA158_COLS +
            market_ext_ntrl_cols + market_ext_with_diff
        )
        print(f"Applying RobustZScoreNorm to {len(norm_feature_cols)} features...")
        robust_norm = RobustZScoreNorm(
            norm_feature_cols,
            clip_range=(-3, 3),
            use_qlib_params=True,
            qlib_mean=self._robust_zscore_params['mean_train'],
            qlib_std=self._robust_zscore_params['std_train']
        )
        df = robust_norm.process(df)
        # Step 8: Fillna - fill NaN with 0
        # Get all feature columns
        market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
        market_flag_cols += ['market_0', 'market_1', 'IsST']
        market_flag_cols = list(dict.fromkeys(market_flag_cols))  # Remove duplicates
        final_feature_cols = norm_feature_cols + market_flag_cols + ['indus_idx']
        df = self.fillna.process(df, final_feature_cols)
        print("=" * 60)
        print("Pipeline complete")
        print(f"  Total columns: {len(df.columns)}")
        print(f"  Rows: {len(df)}")
        # Optionally pack features into struct columns
        if pack_struct:
            df = self._pack_into_structs(df)
            print(f"  Packed into struct columns")
        print("=" * 60)
        return df
    def _pack_into_structs(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Pack feature groups into struct columns.

        Creates:
        - features_alpha158: struct with 316 fields (158 + 158 _ntrl)
        - features_market_ext: struct with 14 fields (7 + 7 _ntrl)
        - features_market_flag: struct with 11 fields

        Returns:
            DataFrame with columns: instrument, datetime, indus_idx, features_*
        """
        # Define column groups (these must mirror the lists built in transform())
        alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
        alpha158_all_cols = alpha158_ntrl_cols + ALPHA158_COLS
        market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
        market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
        market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
        market_ext_all_cols = market_ext_ntrl_cols + market_ext_with_diff
        market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
        market_flag_cols += ['market_0', 'market_1', 'IsST']
        market_flag_cols = list(dict.fromkeys(market_flag_cols))
        # Build result with struct columns; a mix of plain column names and
        # pl.struct expressions is accepted by DataFrame.select
        result_cols = ['instrument', 'datetime']
        # Check if indus_idx exists
        if 'indus_idx' in df.columns:
            result_cols.append('indus_idx')
        # Pack alpha158
        alpha158_cols_in_df = [c for c in alpha158_all_cols if c in df.columns]
        if alpha158_cols_in_df:
            result_cols.append(pl.struct(alpha158_cols_in_df).alias('features_alpha158'))
        # Pack market_ext
        ext_cols_in_df = [c for c in market_ext_all_cols if c in df.columns]
        if ext_cols_in_df:
            result_cols.append(pl.struct(ext_cols_in_df).alias('features_market_ext'))
        # Pack market_flag
        flag_cols_in_df = [c for c in market_flag_cols if c in df.columns]
        if flag_cols_in_df:
            result_cols.append(pl.struct(flag_cols_in_df).alias('features_market_flag'))
        return df.select(result_cols)
    def get_vae_input(
        self,
        df: pl.DataFrame | FeatureGroups,
        exclude_isst: bool = True
    ) -> np.ndarray:
        """
        Prepare VAE input features from transformed DataFrame or FeatureGroups.

        VAE input structure (341 features):
        - feature group (316): 158 alpha158 + 158 alpha158_ntrl
        - feature_ext group (14): 7 market_ext + 7 market_ext_ntrl
        - feature_flag group (11): market flags (excluding IsST)

        NOTE: indus_idx is NOT included in VAE input.

        Args:
            df: Transformed DataFrame or FeatureGroups container
            exclude_isst: Whether to exclude IsST from VAE input (default: True)

        Returns:
            Numpy array of shape (n_samples, VAE_INPUT_DIM), float32, with
            NaN/Inf replaced by 0.
        """
        print("Preparing features for VAE...")
        # Accept either DataFrame or FeatureGroups
        if isinstance(df, FeatureGroups):
            df = df.merge_for_processors()
        # Build alpha158 feature columns
        alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
        alpha158_cols = ALPHA158_COLS.copy()
        # Build market_ext feature columns (with diff, minus removed columns)
        market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
        market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
        market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
        market_ext_cols = market_ext_with_diff.copy()
        # VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
        norm_feature_cols = (
            alpha158_ntrl_cols + alpha158_cols +
            market_ext_ntrl_cols + market_ext_cols
        )
        # Market flag columns (excluding IsST if requested)
        market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
        market_flag_cols += ['market_0', 'market_1']
        if not exclude_isst:
            market_flag_cols.append('IsST')
        market_flag_cols = list(dict.fromkeys(market_flag_cols))
        # Combine all VAE input columns
        vae_cols = norm_feature_cols + market_flag_cols
        print(f"  norm_feature_cols: {len(norm_feature_cols)}")
        print(f"  market_flag_cols: {len(market_flag_cols)}")
        print(f"  Total VAE input columns: {len(vae_cols)}")
        # Verify all columns exist
        missing_cols = [c for c in vae_cols if c not in df.columns]
        if missing_cols:
            # NOTE(review): the select below will still raise if any column is
            # missing -- this warning never prevents the failure. Confirm
            # whether missing columns should be dropped or zero-filled here.
            print(f"WARNING: Missing columns: {missing_cols}")
        # Select features and convert to numpy
        features_df = df.select(vae_cols)
        features = features_df.to_numpy().astype(np.float32)
        # Handle any remaining NaN/Inf values
        features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
        print(f"Feature matrix shape: {features.shape}")
        # Verify dimensions
        if features.shape[1] != VAE_INPUT_DIM:
            print(f"WARNING: Expected {VAE_INPUT_DIM} features, got {features.shape[1]}")
            if features.shape[1] < VAE_INPUT_DIM:
                # Pad with zeros
                padding = np.zeros(
                    (features.shape[0], VAE_INPUT_DIM - features.shape[1]),
                    dtype=np.float32
                )
                features = np.concatenate([features, padding], axis=1)
                print(f"Padded to shape: {features.shape}")
            else:
                # Truncate
                features = features[:, :VAE_INPUT_DIM]
                print(f"Truncated to shape: {features.shape}")
        return features

@ -0,0 +1,447 @@
"""Processor classes for feature transformation.
Processors are stateless, functional components that operate on
specific feature groups or merged DataFrames.
"""
import numpy as np
import polars as pl
from typing import List, Tuple, Optional, Dict
class DiffProcessor:
    """
    Diff Processor: Calculate diff features for market_ext columns.

    For each configured column, computes a period-1 difference within each
    instrument group and stores it as "{col}_diff".
    """
    def __init__(self, columns: List[str]):
        """
        Initialize the DiffProcessor.

        Args:
            columns: List of column names to compute diffs for
        """
        self.columns = columns
    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Add diff features for specified columns.

        Args:
            df: Input DataFrame with market_ext columns

        Returns:
            DataFrame with added {col}_diff columns
        """
        print("Applying Diff processor...")
        # Collect every expression up front so one with_columns() call
        # materializes them all, avoiding a DataFrame copy per column.
        # order_by='datetime' guarantees chronological ordering within each
        # instrument before the period-1 difference is taken.
        wanted = [name for name in self.columns if name in df.columns]
        exprs = []
        for name in wanted:
            expr = (
                pl.col(name)
                .diff()
                .over('instrument', order_by='datetime')
                .alias(f"{name}_diff")
            )
            exprs.append(expr)
        if not exprs:
            return df
        return df.with_columns(exprs)
class FlagMarketInjector:
    """
    Flag Market Injector: Create market_0, market_1 columns based on instrument code.

    Maps to Qlib's map_market_sec logic with vocab_size=2:
    - market_0 (主板): SH60xxx, SZ00xxx
    - market_1 (科创板/创业板): SH688xxx, SH689xxx, SZ300xxx, SZ301xxx

    NOTE: vocab_size=2 (not 3!) - 新三板/北交所 (NE4xxxx, NE8xxxx) are NOT included.
    """
    def process(self, df: pl.DataFrame, instrument_col: str = 'instrument') -> pl.DataFrame:
        """
        Add market_0, market_1 columns.

        Args:
            df: Input DataFrame with instrument column
            instrument_col: Name of the instrument column

        Returns:
            DataFrame with added market_0 and market_1 columns
        """
        print("Applying FlagMarketInjector (vocab_size=2)...")
        # Convert instrument to string and pad to 6 digits
        inst_str = pl.col(instrument_col).cast(pl.String).str.zfill(6)
        # Determine market type from the code prefix.
        # BUGFIX: the SH main-board test must use the '60' prefix, not '6'.
        # A bare '6' prefix also matches STAR-market codes (688xxx/689xxx),
        # which previously set BOTH market_0 and market_1 for those
        # instruments, contradicting the exclusive mapping documented above.
        is_sh_main = inst_str.str.starts_with('60')  # SH600xxx, SH601xxx, SH603xxx, ...
        is_sz_main = inst_str.str.starts_with('0')   # SZ000xxx, SZ001xxx, ...
        is_sh_star = inst_str.str.starts_with('688') | inst_str.str.starts_with('689')  # SH688xxx, SH689xxx
        is_sz_chi = inst_str.str.starts_with('300') | inst_str.str.starts_with('301')  # SZ300xxx, SZ301xxx
        df = df.with_columns([
            # market_0 = 主板 (SH main + SZ main)
            (is_sh_main | is_sz_main).cast(pl.Int8).alias('market_0'),
            # market_1 = 科创板 + 创业板 (SH star + SZ ChiNext)
            (is_sh_star | is_sz_chi).cast(pl.Int8).alias('market_1')
        ])
        return df
class FlagSTInjector:
    """
    Flag ST Injector: Create IsST column from ST flags.

    Creates IsST = ST_S | ST_Y from whichever of those flag columns are
    present, otherwise creates a placeholder column of all zeros.
    """
    # Source ST flag columns, OR-ed together to form IsST.
    _ST_FLAG_COLS = ('ST_S', 'ST_Y')
    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Add IsST column.

        Args:
            df: Input DataFrame

        Returns:
            DataFrame with added IsST column (Int8, 0/1)
        """
        print("Applying FlagSTInjector (creating IsST)...")
        # BUGFIX: the previous check (`'ST_S' in df.columns or
        # 'st_flag::ST_S' in df.columns`) took the flag branch even when only
        # the namespaced column existed, then unconditionally referenced
        # pl.col('ST_S') and pl.col('ST_Y') -- crashing with
        # ColumnNotFoundError if either plain column was missing. Now only
        # the flag columns that actually exist are combined.
        present = [c for c in self._ST_FLAG_COLS if c in df.columns]
        if present:
            # OR together all available ST flags (non-strict cast tolerates
            # non-boolean encodings such as 0/1 integers)
            is_st = pl.col(present[0]).cast(pl.Boolean, strict=False)
            for col in present[1:]:
                is_st = is_st | pl.col(col).cast(pl.Boolean, strict=False)
            df = df.with_columns([is_st.cast(pl.Int8).alias('IsST')])
        else:
            # No ST information available: emit an all-zero placeholder so
            # downstream consumers always find the column.
            df = df.with_columns([
                pl.lit(0).cast(pl.Int8).alias('IsST')
            ])
        return df
class ColumnRemover:
    """
    Column Remover: Drop specific columns.

    Removes columns that are not needed for the VAE input; columns absent
    from the frame are silently ignored.
    """
    def __init__(self, columns_to_remove: List[str]):
        """
        Initialize the ColumnRemover.

        Args:
            columns_to_remove: List of column names to remove
        """
        self.columns_to_remove = columns_to_remove
    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Remove specified columns.

        Args:
            df: Input DataFrame

        Returns:
            DataFrame with specified columns removed
        """
        print(f"Applying ColumnRemover (removing {len(self.columns_to_remove)} columns)...")
        # Restrict to columns actually present so drop() never raises
        present = [name for name in self.columns_to_remove if name in df.columns]
        return df.drop(present) if present else df
class FlagToOnehot:
    """
    Flag To Onehot: Convert 29 one-hot industry columns to single indus_idx.

    For each row, indus_idx is set to the position (within the configured
    column list) of the industry flag that equals 1; rows with no flag set
    get -1.
    """
    def __init__(self, industry_cols: List[str]):
        """
        Initialize the FlagToOnehot.

        Args:
            industry_cols: List of 29 industry flag column names
        """
        self.industry_cols = industry_cols
    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Convert industry flags to single indus_idx column.

        Args:
            df: Input DataFrame with industry flag columns

        Returns:
            DataFrame with indus_idx column (original industry columns removed)
        """
        print("Applying FlagToOnehot (converting 29 industry flags to indus_idx)...")
        # Fold the flags into one when/then chain, defaulting to -1 (no
        # industry). Because the chain is built column by column, a later
        # column wins if several flags are set for the same row.
        # Enumerate the FULL configured list so positions stay stable even
        # when some columns are missing from df.
        chain = pl.lit(-1)
        present = []
        for position, name in enumerate(self.industry_cols):
            if name not in df.columns:
                continue
            present.append(name)
            chain = pl.when(pl.col(name) == 1).then(position).otherwise(chain)
        df = df.with_columns([chain.alias('indus_idx')])
        # Drop the consumed one-hot columns
        if present:
            df = df.drop(present)
        return df
class IndusNtrlInjector:
    """
    Industry Neutralization Injector: Industry neutralization for features.

    Subtracts the per-(datetime, indus_idx) cross-sectional mean from each
    feature, writing the result to a new "{col}_ntrl" column while leaving
    the original column untouched.

    IMPORTANT: Neutralization is done PER DATETIME (cross-sectional), not
    across the entire dataset. This matches qlib's cal_indus_ntrl behavior.
    """
    def __init__(self, feature_cols: List[str], suffix: str = '_ntrl'):
        """
        Initialize the IndusNtrlInjector.

        Args:
            feature_cols: List of feature column names to neutralize
            suffix: Suffix to append to neutralized column names
        """
        self.feature_cols = feature_cols
        self.suffix = suffix
    def process(
        self,
        feature_df: pl.DataFrame,
        industry_df: pl.DataFrame
    ) -> pl.DataFrame:
        """
        Apply industry neutralization to specified features.

        Args:
            feature_df: DataFrame with feature columns to neutralize
            industry_df: DataFrame with indus_idx column (must have instrument, datetime)

        Returns:
            DataFrame with added {col}_ntrl columns
        """
        print(f"Applying IndusNtrlInjector to {len(self.feature_cols)} features...")
        if 'indus_idx' in feature_df.columns:
            # Industry index already present; no join needed
            df = feature_df
        else:
            # Pull indus_idx in from the industry source on (instrument, datetime)
            df = feature_df.join(
                industry_df.select(['instrument', 'datetime', 'indus_idx']),
                on=['instrument', 'datetime'],
                how='left',
                suffix='_indus'
            )
        # Build every expression first so a single with_columns() call applies
        # them all (avoids one intermediate frame per feature).
        exprs = []
        for name in self.feature_cols:
            if name not in df.columns:
                continue
            group_mean = pl.col(name).mean().over(
                ['datetime', 'indus_idx'], order_by='datetime'
            )
            exprs.append((pl.col(name) - group_mean).alias(f"{name}{self.suffix}"))
        if exprs:
            df = df.with_columns(exprs)
        return df
class RobustZScoreNorm:
    """
    Robust Z-Score Normalization: Per datetime normalization.

    (x - median) / (1.4826 * MAD) where MAD = median(|x - median|)
    Clip outliers at [-3, 3].

    Supports pre-fitted parameters from qlib's pickled processor:

        normalizer = RobustZScoreNorm(
            feature_cols=feature_cols,
            use_qlib_params=True,
            qlib_mean=zscore_proc.mean_train,
            qlib_std=zscore_proc.std_train
        )
    """
    def __init__(
        self,
        feature_cols: List[str],
        clip_range: Tuple[float, float] = (-3, 3),
        use_qlib_params: bool = False,
        qlib_mean: Optional[np.ndarray] = None,
        qlib_std: Optional[np.ndarray] = None
    ):
        """
        Initialize the RobustZScoreNorm.

        Args:
            feature_cols: List of feature column names to normalize. When
                use_qlib_params=True, qlib_mean/qlib_std are aligned
                POSITIONALLY with this list.
            clip_range: Tuple of (min, max) for clipping normalized values
            use_qlib_params: Whether to use pre-fitted parameters from qlib
            qlib_mean: Pre-fitted mean array from qlib (required if use_qlib_params=True)
            qlib_std: Pre-fitted std array from qlib (required if use_qlib_params=True)

        Raises:
            ValueError: If use_qlib_params is True but either array is missing.
        """
        self.feature_cols = feature_cols
        self.clip_range = clip_range
        self.use_qlib_params = use_qlib_params
        self.mean_train = qlib_mean
        self.std_train = qlib_std
        if use_qlib_params:
            if qlib_mean is None or qlib_std is None:
                raise ValueError("Must provide qlib_mean and qlib_std when use_qlib_params=True")
            print(f"Using pre-fitted qlib parameters (mean shape: {qlib_mean.shape}, std shape: {qlib_std.shape})")
    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Apply robust z-score normalization.

        Args:
            df: Input DataFrame with feature columns

        Returns:
            DataFrame with normalized feature columns (columns overwritten
            in place; no new columns remain)
        """
        print(f"Applying RobustZScoreNorm to {len(self.feature_cols)} features...")
        # Filter to only columns that exist
        existing_cols = [c for c in self.feature_cols if c in df.columns]
        if self.use_qlib_params:
            # Use pre-fitted parameters from qlib (fit once, apply to all dates).
            # Build all normalization expressions first, then apply in a single
            # with_columns() to avoid one DataFrame copy per column.
            # BUGFIX: enumerate self.feature_cols (not the filtered existing
            # list) so each column keeps its own positional mean/std even when
            # some configured columns are absent from df; previously a missing
            # column shifted every later column onto its neighbor's parameters.
            norm_exprs = []
            for i, col in enumerate(self.feature_cols):
                if col not in df.columns or i >= len(self.mean_train):
                    continue
                mean_val = float(self.mean_train[i])
                std_val = float(self.std_train[i])
                # Epsilon guards against division by zero for constant features
                norm_exprs.append(
                    ((pl.col(col) - mean_val) / (std_val + 1e-8))
                    .clip(self.clip_range[0], self.clip_range[1])
                    .alias(col)
                )
            if norm_exprs:
                df = df.with_columns(norm_exprs)
        else:
            # Per-datetime robust z-score computed from the data itself.
            # BUGFIX: expressions within a single with_columns() are evaluated
            # in parallel against the INPUT frame and cannot reference temp
            # columns created in the same call, so the previous one-shot build
            # raised at runtime. The temporaries are now materialized in
            # stages, still batched (one with_columns per stage, not per column).
            if existing_cols:
                median_cols = {col: f"__median_{col}" for col in existing_cols}
                abs_dev_cols = {col: f"__absdev_{col}" for col in existing_cols}
                mad_cols = {col: f"__mad_{col}" for col in existing_cols}
                # Stage 1: per-datetime median of each feature
                df = df.with_columns([
                    pl.col(col).median().over('datetime').alias(median_cols[col])
                    for col in existing_cols
                ])
                # Stage 2: absolute deviation from the median
                df = df.with_columns([
                    (pl.col(col) - pl.col(median_cols[col])).abs().alias(abs_dev_cols[col])
                    for col in existing_cols
                ])
                # Stage 3: MAD = per-datetime median of the absolute deviations
                df = df.with_columns([
                    pl.col(abs_dev_cols[col]).median().over('datetime').alias(mad_cols[col])
                    for col in existing_cols
                ])
                # Stage 4: robust z-score, clipped, overwriting the original
                # column. 1.4826 scales MAD into a consistent estimator of the
                # standard deviation; the epsilon guards against MAD == 0.
                df = df.with_columns([
                    ((pl.col(col) - pl.col(median_cols[col])) / (1.4826 * pl.col(mad_cols[col]) + 1e-8))
                    .clip(self.clip_range[0], self.clip_range[1])
                    .alias(col)
                    for col in existing_cols
                ])
                # Clean up all temporary columns in a single drop()
                df = df.drop(
                    list(median_cols.values()) +
                    list(abs_dev_cols.values()) +
                    list(mad_cols.values())
                )
        return df
class Fillna:
    """
    Fill NaN: Fill all NaN values with 0 for numeric columns.
    """
    def process(
        self,
        df: pl.DataFrame,
        feature_cols: List[str]
    ) -> pl.DataFrame:
        """
        Fill NaN values with 0 for specified columns.

        Args:
            df: Input DataFrame
            feature_cols: List of column names to fill NaN values for

        Returns:
            DataFrame with NaN values filled with 0
        """
        print("Applying Fillna processor...")
        # Dtype whitelist for filling. NOTE(review): Int8/Int16/UInt8/UInt16
        # columns (e.g. the flag columns) are not covered here -- confirm
        # leaving them unfilled is intentional.
        fillable = (pl.Float32, pl.Float64, pl.Int32, pl.Int64, pl.UInt32, pl.UInt64)
        targets = [
            name for name in feature_cols
            if name in df.columns and df[name].dtype in fillable
        ]
        if not targets:
            return df
        # One with_columns() call applies every fill at once instead of
        # copying the frame per column.
        return df.with_columns(
            [pl.col(name).fill_null(0.0).fill_nan(0.0) for name in targets]
        )
Loading…
Cancel
Save