- Remove split_at_end parameter from pipeline.transform(); always return a DataFrame
- Add pack_struct parameter to pack feature groups into struct columns
- Rename exporters: select_feature_groups_from_df -> get_groups, select_feature_groups -> get_groups_from_fg
- Add pack_structs() and unpack_struct() helper functions
- Remove split_from_merged() method from FeatureGroups (no longer needed)
- Rename dump_polars_dataset.py to dump_features.py with --pack-struct flag
- Update README with new CLI usage and struct column documentation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

master
parent
26a694298d
commit
5109ac4eb3
@ -0,0 +1,265 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Script to generate and dump transformed features from the alpha158_beta pipeline.
|
||||
|
||||
This script provides fine-grained control over the feature generation and dumping process:
|
||||
- Select which feature groups to dump (alpha158, market_ext, market_flag, merged, vae_input)
|
||||
- Choose output format (parquet, pickle, numpy)
|
||||
- Control date range and universe filtering
|
||||
- Save intermediate pipeline outputs
|
||||
- Enable streaming mode for large datasets (>1 year)
|
||||
|
||||
Usage:
|
||||
# Dump all features to parquet
|
||||
python dump_features.py --start-date 2025-01-01 --end-date 2025-01-31
|
||||
|
||||
# Dump only alpha158 features to pickle
|
||||
python dump_features.py --groups alpha158 --format pickle
|
||||
|
||||
# Dump with custom output path
|
||||
python dump_features.py --output /path/to/output.parquet
|
||||
|
||||
# Dump merged features with all columns
|
||||
python dump_features.py --groups merged --verbose
|
||||
|
||||
# Use streaming mode for large date ranges (>1 year)
|
||||
python dump_features.py --start-date 2020-01-01 --end-date 2023-12-31 --streaming
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
# Add src to path for imports
|
||||
SCRIPT_DIR = Path(__file__).parent
|
||||
sys.path.insert(0, str(SCRIPT_DIR.parent / 'src'))
|
||||
|
||||
from processors import (
|
||||
FeaturePipeline,
|
||||
FeatureGroups,
|
||||
VAE_INPUT_DIM,
|
||||
ALPHA158_COLS,
|
||||
MARKET_EXT_BASE_COLS,
|
||||
COLUMNS_TO_REMOVE,
|
||||
get_groups,
|
||||
dump_to_parquet,
|
||||
dump_to_pickle,
|
||||
dump_to_numpy,
|
||||
)
|
||||
|
||||
# Default output directory
|
||||
DEFAULT_OUTPUT_DIR = SCRIPT_DIR.parent / "data"
|
||||
|
||||
|
||||
def generate_and_dump(
    start_date: str,
    end_date: str,
    output_path: str,
    output_format: str = 'parquet',
    groups: Optional[List[str]] = None,
    universe: str = 'csiallx',
    filter_universe: bool = True,
    robust_zscore_params_path: Optional[str] = None,
    verbose: bool = True,
    pack_struct: bool = False,
    streaming: bool = False,
) -> None:
    """
    Generate features and dump to file.

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        output_path: Output file path
        output_format: Output format ('parquet', 'pickle', 'numpy')
        groups: Feature groups to dump (default: ['merged'])
        universe: Stock universe name
        filter_universe: Whether to filter to stock universe
        robust_zscore_params_path: Path to robust zscore parameters
        verbose: Whether to print progress
        pack_struct: If True, pack each feature group into struct columns
            (features_alpha158, features_market_ext, features_market_flag)
        streaming: If True, use Polars streaming mode for large datasets (>1 year)
    """
    if groups is None:
        groups = ['merged']

    print("=" * 60)
    print("Feature Dump Script")
    print("=" * 60)
    print(f"Date range: {start_date} to {end_date}")
    print(f"Output format: {output_format}")
    print(f"Feature groups: {groups}")
    print(f"Universe: {universe} (filter: {filter_universe})")
    print(f"Pack struct: {pack_struct}")
    print(f"Output path: {output_path}")
    print("=" * 60)

    # Initialize pipeline
    pipeline = FeaturePipeline(
        robust_zscore_params_path=robust_zscore_params_path
    )

    # Load data
    feature_groups = pipeline.load_data(
        start_date, end_date,
        filter_universe=filter_universe,
        universe_name=universe,
        streaming=streaming
    )

    # Apply transformations - pipeline always returns a merged DataFrame now
    df_transformed = pipeline.transform(feature_groups, pack_struct=pack_struct)

    # Select feature groups from merged DataFrame
    outputs = get_groups(df_transformed, groups, verbose, use_struct=False)

    # Ensure output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Dump to file(s)
    if output_format == 'numpy':
        # For numpy, we save the merged features
        dump_to_numpy(feature_groups, output_path, include_metadata=True, verbose=verbose)
    else:
        # pickle and parquet share identical fan-out logic; only the writer
        # function differs, so dispatch once instead of duplicating branches.
        dump_fn = dump_to_pickle if output_format == 'pickle' else dump_to_parquet
        _dump_outputs(outputs, output_path, dump_fn, verbose)

    print("=" * 60)
    print("Feature dump complete!")
    print("=" * 60)


def _dump_outputs(outputs, output_path: str, dump_fn, verbose: bool) -> None:
    """
    Write selected feature-group DataFrames using the given writer function.

    'merged' (when present) is written to output_path as-is; every other
    group gets a "{stem}_{group}{suffix}" file alongside output_path.
    This mirrors the original single/multi-group behavior: the single-group
    case is just the one-iteration instance of the loop.
    """
    if 'merged' in outputs:
        dump_fn(outputs['merged'], output_path, verbose=verbose)
        return
    base_path = Path(output_path)
    for key, df_out in outputs.items():
        dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}"))
        dump_fn(df_out, dump_path, verbose=verbose)
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the feature dump script."""
    parser = argparse.ArgumentParser(
        description="Generate and dump transformed features from alpha158_beta pipeline"
    )

    # Date range
    parser.add_argument(
        "--start-date", type=str, required=True,
        help="Start date in YYYY-MM-DD format"
    )
    parser.add_argument(
        "--end-date", type=str, required=True,
        help="End date in YYYY-MM-DD format"
    )

    # Output settings
    parser.add_argument(
        "--output", "-o", type=str, default=None,
        help=f"Output file path (default: {DEFAULT_OUTPUT_DIR}/features.parquet)"
    )
    parser.add_argument(
        "--format", "-f", type=str, default='parquet',
        choices=['parquet', 'pickle', 'numpy'],
        help="Output format (default: parquet)"
    )

    # Feature groups
    parser.add_argument(
        "--groups", "-g", type=str, nargs='+', default=['merged'],
        choices=['merged', 'alpha158', 'market_ext', 'market_flag', 'vae_input'],
        help="Feature groups to dump (default: merged)"
    )

    # Universe settings
    parser.add_argument(
        "--universe", type=str, default='csiallx',
        help="Stock universe name (default: csiallx)"
    )
    parser.add_argument(
        "--no-filter-universe", action="store_true",
        help="Disable stock universe filtering"
    )

    # Robust zscore parameters
    parser.add_argument(
        "--robust-zscore-params", type=str, default=None,
        help="Path to robust zscore parameters directory"
    )

    # Verbose mode
    # NOTE: --verbose defaults to True, so -v is effectively a no-op; -q is
    # the switch that actually changes behavior. Kept as-is for CLI stability.
    parser.add_argument(
        "--verbose", "-v", action="store_true", default=True,
        help="Enable verbose output (default: True)"
    )
    parser.add_argument(
        "--quiet", "-q", action="store_true",
        help="Disable verbose output"
    )

    # Struct option
    parser.add_argument(
        "--pack-struct", "-s", action="store_true",
        help="Pack each feature group into separate struct columns (features_alpha158, features_market_ext, features_market_flag)"
    )

    # Streaming option
    parser.add_argument(
        "--streaming", action="store_true",
        help="Use Polars streaming mode for large datasets (recommended for date ranges > 1 year)"
    )

    return parser


def main():
    """CLI entry point: parse arguments, resolve defaults, run generate_and_dump."""
    args = _build_parser().parse_args()

    # Handle verbose/quiet flags: --quiet wins over the default-on --verbose
    verbose = args.verbose and not args.quiet

    # Set default output path.
    # Note: generate_and_dump adds a per-group suffix when needed, so the
    # base name "features" is used here.
    if args.output is None:
        output_path = str(DEFAULT_OUTPUT_DIR / "features.parquet")
    else:
        output_path = args.output

    # Generate and dump
    generate_and_dump(
        start_date=args.start_date,
        end_date=args.end_date,
        output_path=output_path,
        output_format=args.format,
        groups=args.groups,
        universe=args.universe,
        filter_universe=not args.no_filter_universe,
        robust_zscore_params_path=args.robust_zscore_params,
        verbose=verbose,
        pack_struct=args.pack_struct,
        streaming=args.streaming,
    )


if __name__ == "__main__":
    main()
|
||||
@ -1,364 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Script to dump raw and processed datasets using the polars-based pipeline.
|
||||
|
||||
This generates:
|
||||
1. Raw data (before applying processors) - equivalent to qlib's handler output
|
||||
2. Processed data (after applying all processors) - ready for VAE encoding
|
||||
|
||||
Date range: 2026-02-23 to today (2026-02-27)
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import pickle as pkl
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
# Import processors from the new shared module
|
||||
from cta_1d.src.processors import (
|
||||
DiffProcessor,
|
||||
FlagMarketInjector,
|
||||
FlagSTInjector,
|
||||
ColumnRemover,
|
||||
FlagToOnehot,
|
||||
IndusNtrlInjector,
|
||||
RobustZScoreNorm,
|
||||
Fillna,
|
||||
)
|
||||
|
||||
# Import constants from local module
|
||||
from generate_beta_embedding import (
|
||||
ALPHA158_COLS,
|
||||
INDUSTRY_FLAG_COLS,
|
||||
)
|
||||
|
||||
# Date range
|
||||
START_DATE = "2026-02-23"
|
||||
END_DATE = "2026-02-27"
|
||||
|
||||
# Output directory
|
||||
OUTPUT_DIR = Path(__file__).parent.parent / "data_polars"
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame:
    """
    Apply the full processor pipeline (equivalent to qlib's proc_list).

    This mimics the qlib proc_list:
    0. Diff: Adds diff features for market_ext columns
    1. FlagMarketInjector: Adds market_0, market_1
    2. FlagSTInjector: Adds IsST (placeholder if ST flags not available)
    3. ColumnRemover: Removes log_size_diff, IsN, IsZt, IsDt
    4. FlagToOnehot: Converts 29 industry flags to indus_idx
    5. IndusNtrlInjector: Industry neutralization for feature
    6. IndusNtrlInjector: Industry neutralization for feature_ext
    7. RobustZScoreNorm: Normalization using pre-fitted qlib params
    8. Fillna: Fill NaN with 0

    NOTE(review): step order matters — column bookkeeping lists are rebuilt
    after each step, and RobustZScoreNorm's pre-fitted parameter vector is
    positional, so norm_feature_cols must stay in exactly this order to line
    up with the fitted params. Do not reorder steps or column lists.
    """
    print("=" * 60)
    print("Applying processor pipeline")
    print("=" * 60)

    # market_ext columns (4 base)
    market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']

    # market_flag columns (12 total before ColumnRemover)
    market_flag_cols = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
                        'open_limit', 'close_limit', 'low_limit',
                        'open_stop', 'close_stop', 'high_stop']

    # Step 1: Diff Processor
    print("\n[1] Applying Diff processor...")
    diff_processor = DiffProcessor(market_ext_base)
    df = diff_processor.process(df)

    # After Diff: market_ext has 8 columns (4 base + 4 diff)
    market_ext_cols = market_ext_base + [f"{c}_diff" for c in market_ext_base]

    # Step 2: FlagMarketInjector (adds market_0, market_1)
    print("[2] Applying FlagMarketInjector...")
    flag_injector = FlagMarketInjector()
    df = flag_injector.process(df)

    # Add market_0, market_1 to flag list
    market_flag_with_market = market_flag_cols + ['market_0', 'market_1']

    # Step 3: FlagSTInjector - adds IsST (placeholder if ST flags not available)
    print("[3] Applying FlagSTInjector...")
    flag_st_injector = FlagSTInjector()
    df = flag_st_injector.process(df)

    # Add IsST to flag list
    market_flag_with_st = market_flag_with_market + ['IsST']

    # Step 4: ColumnRemover
    print("[4] Applying ColumnRemover...")
    columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt']
    remover = ColumnRemover(columns_to_remove)
    df = remover.process(df)

    # Update column lists after removal so downstream steps see only what
    # actually remains in the frame
    market_ext_cols = [c for c in market_ext_cols if c not in columns_to_remove]
    market_flag_with_st = [c for c in market_flag_with_st if c not in columns_to_remove]

    print(f" Removed columns: {columns_to_remove}")
    print(f" Remaining market_ext: {len(market_ext_cols)} columns")
    print(f" Remaining market_flag: {len(market_flag_with_st)} columns")

    # Step 5: FlagToOnehot
    print("[5] Applying FlagToOnehot...")
    flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS)
    df = flag_to_onehot.process(df)

    # Step 6 & 7: IndusNtrlInjector (adds *_ntrl columns; presumably requires
    # indus_idx from FlagToOnehot — confirm against IndusNtrlInjector impl)
    print("[6] Applying IndusNtrlInjector for alpha158...")
    alpha158_cols = ALPHA158_COLS.copy()
    indus_ntrl_alpha = IndusNtrlInjector(alpha158_cols, suffix='_ntrl')
    df = indus_ntrl_alpha.process(df)

    print("[7] Applying IndusNtrlInjector for market_ext...")
    indus_ntrl_ext = IndusNtrlInjector(market_ext_cols, suffix='_ntrl')
    df = indus_ntrl_ext.process(df)

    # Build column lists for normalization
    alpha158_ntrl_cols = [f"{c}_ntrl" for c in alpha158_cols]
    market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_cols]

    # Step 8: RobustZScoreNorm
    print("[8] Applying RobustZScoreNorm...")
    # Order is [ntrl] + [raw] per group — must match the fitted params below
    norm_feature_cols = alpha158_ntrl_cols + alpha158_cols + market_ext_ntrl_cols + market_ext_cols

    # Load RobustZScoreNorm with pre-fitted parameters from version
    robust_norm = RobustZScoreNorm.from_version(
        "csiallx_feature2_ntrla_flag_pnlnorm",
        feature_cols=norm_feature_cols
    )

    # Verify parameter shape matches expected features (warn only; a mismatch
    # would silently mis-normalize columns otherwise)
    expected_features = len(norm_feature_cols)
    if robust_norm.qlib_mean.shape[0] != expected_features:
        print(f" WARNING: Feature count mismatch! Expected {expected_features}, "
              f"got {robust_norm.qlib_mean.shape[0]}")

    df = robust_norm.process(df)

    # Step 9: Fillna
    print("[9] Applying Fillna...")
    final_feature_cols = norm_feature_cols + market_flag_with_st + ['indus_idx']
    fillna = Fillna()
    df = fillna.process(df, final_feature_cols)

    print("\n" + "=" * 60)
    print("Processor pipeline complete!")
    print(f" Normalized features: {len(norm_feature_cols)}")
    print(f" Market flags: {len(market_flag_with_st)}")
    print(f" Total features (with indus_idx): {len(final_feature_cols)}")
    print("=" * 60)

    return df
|
||||
|
||||
|
||||
def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame":
    """
    Convert polars DataFrame to pandas DataFrame with MultiIndex columns.
    This matches the format of qlib's output.

    IMPORTANT: Qlib's IndusNtrlInjector outputs columns in order [_ntrl] + [raw],
    so the data columns are physically reordered into that order before the
    MultiIndex labels are attached. (Assigning `df.columns = multi_cols` alone
    only renames positionally — since columns_with_group is built in a
    different order than df.columns, that would silently attach group/name
    labels to the wrong data columns.)
    """
    import pandas as pd

    # Convert to pandas
    df = df_polars.to_pandas()

    # Check if datetime and instrument are columns
    if 'datetime' in df.columns and 'instrument' in df.columns:
        # Set MultiIndex
        df = df.set_index(['datetime', 'instrument'])
    # If they're already not in columns, assume they're already the index

    # Drop raw columns that shouldn't be in processed data
    raw_cols_to_drop = ['Turnover', 'FreeTurnover', 'MarketValue']
    existing_raw_cols = [c for c in raw_cols_to_drop if c in df.columns]
    if existing_raw_cols:
        df = df.drop(columns=existing_raw_cols)

    # Build MultiIndex columns based on column name patterns
    # IMPORTANT: Qlib order is [_ntrl columns] + [raw columns] for each group
    columns_with_group = []

    # Define column sets
    alpha158_base = set(ALPHA158_COLS)
    market_ext_base = {'turnover', 'free_turnover', 'log_size', 'con_rating_strength'}
    market_ext_diff = {'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'}
    market_ext_all = market_ext_base | market_ext_diff
    feature_flag_cols = {'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', 'open_limit', 'close_limit', 'low_limit',
                         'open_stop', 'close_stop', 'high_stop', 'market_0', 'market_1', 'IsST'}

    # First pass: bucket columns by group; _ntrl columns come first in qlib order
    ntrl_alpha158_cols = []
    ntrl_market_ext_cols = []
    raw_alpha158_cols = []
    raw_market_ext_cols = []
    flag_cols = []
    indus_idx_col = None

    for col in df.columns:
        if col == 'indus_idx':
            indus_idx_col = col
        elif col in feature_flag_cols:
            flag_cols.append(col)
        elif col.endswith('_ntrl'):
            base_name = col[:-5]  # Remove _ntrl suffix (5 characters)
            if base_name in alpha158_base:
                ntrl_alpha158_cols.append(col)
            elif base_name in market_ext_all:
                ntrl_market_ext_cols.append(col)
        elif col in alpha158_base:
            raw_alpha158_cols.append(col)
        elif col in market_ext_all:
            raw_market_ext_cols.append(col)
        elif col in INDUSTRY_FLAG_COLS:
            columns_with_group.append(('indus_flag', col))
        elif col in {'ST_S', 'ST_Y', 'ST_T', 'ST_L', 'ST_Z', 'ST_X'}:
            columns_with_group.append(('st_flag', col))
        else:
            # Unknown column - print warning
            print(f" Warning: Unknown column '{col}', assigning to 'other' group")
            columns_with_group.append(('other', col))

    # Build columns in qlib order: [_ntrl] + [raw] for each feature group
    # Feature group: alpha158_ntrl + alpha158 (sorted to canonical ALPHA158 order)
    for col in sorted(ntrl_alpha158_cols, key=lambda x: ALPHA158_COLS.index(x.replace('_ntrl', '')) if x.replace('_ntrl', '') in ALPHA158_COLS else 999):
        columns_with_group.append(('feature', col))
    for col in sorted(raw_alpha158_cols, key=lambda x: ALPHA158_COLS.index(x) if x in ALPHA158_COLS else 999):
        columns_with_group.append(('feature', col))

    # Feature_ext group: market_ext_ntrl + market_ext
    for col in ntrl_market_ext_cols:
        columns_with_group.append(('feature_ext', col))
    for col in raw_market_ext_cols:
        columns_with_group.append(('feature_ext', col))

    # Feature_flag group
    for col in flag_cols:
        columns_with_group.append(('feature_flag', col))

    # Indus_idx
    if indus_idx_col:
        columns_with_group.append(('indus_idx', indus_idx_col))

    # BUG FIX: reorder the DATA to match the label order before attaching the
    # MultiIndex, so each (group, name) pair labels the column it describes.
    ordered_cols = [col for _, col in columns_with_group]
    df = df[ordered_cols]
    df.columns = pd.MultiIndex.from_tuples(columns_with_group)

    return df
|
||||
|
||||
|
||||
def main():
    """
    Dump raw and processed datasets for START_DATE..END_DATE as pickles.

    Writes two MultiIndex pandas pickles into OUTPUT_DIR:
    raw_data_*.pkl (before processors) and processed_data_*.pkl (after the
    full processor pipeline), then verifies market_0/market_1 are present.
    """
    # NOTE(review): load_all_data and merge_data_sources are not among this
    # module's visible imports — confirm they are brought into scope elsewhere,
    # otherwise this function raises NameError at runtime.
    print("=" * 80)
    print("Dumping Polars Dataset")
    print("=" * 80)
    print(f"Date range: {START_DATE} to {END_DATE}")
    print(f"Output directory: {OUTPUT_DIR}")
    print()

    # Step 1: Load all data
    print("Step 1: Loading data from parquet...")
    df_alpha, df_kline, df_flag, df_industry = load_all_data(START_DATE, END_DATE)
    print(f" Alpha158 shape: {df_alpha.shape}")
    print(f" Kline (market_ext) shape: {df_kline.shape}")
    print(f" Flags shape: {df_flag.shape}")
    print(f" Industry shape: {df_industry.shape}")

    # Step 2: Merge data sources
    print("\nStep 2: Merging data sources...")
    df_merged = merge_data_sources(df_alpha, df_kline, df_flag, df_industry)
    print(f" Merged shape (after csiallx filter): {df_merged.shape}")

    # Step 3: Save raw data (before processors)
    print("\nStep 3: Saving raw data (before processors)...")

    # Keep columns that match qlib's raw output format
    # Include datetime and instrument for MultiIndex conversion
    raw_columns = (
        ['datetime', 'instrument'] +  # Index columns
        ALPHA158_COLS +  # feature group
        ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] +  # feature_ext base
        ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',  # market_flag from kline
         'open_limit', 'close_limit', 'low_limit',
         'open_stop', 'close_stop', 'high_stop'] +
        INDUSTRY_FLAG_COLS +  # indus_flag
        (['ST_S', 'ST_Y'] if 'ST_S' in df_merged.columns else [])  # st_flag (if available)
    )

    # Filter to available columns (some sources may be missing columns)
    available_raw_cols = [c for c in raw_columns if c in df_merged.columns]
    print(f" Selecting {len(available_raw_cols)} columns for raw data...")
    df_raw_polars = df_merged.select(available_raw_cols)

    # Convert to pandas with MultiIndex
    df_raw_pd = convert_to_multiindex_df(df_raw_polars)

    raw_output_path = OUTPUT_DIR / f"raw_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl"
    with open(raw_output_path, "wb") as f:
        pkl.dump(df_raw_pd, f)
    print(f" Saved raw data to: {raw_output_path}")
    print(f" Raw data shape: {df_raw_pd.shape}")
    print(f" Column groups: {df_raw_pd.columns.get_level_values(0).unique().tolist()}")

    # Step 4: Apply processor pipeline
    print("\nStep 4: Applying processor pipeline...")
    df_processed = apply_processor_pipeline(df_merged)

    # Step 5: Save processed data
    print("\nStep 5: Saving processed data (after processors)...")

    # Convert to pandas with MultiIndex
    df_processed_pd = convert_to_multiindex_df(df_processed)

    processed_output_path = OUTPUT_DIR / f"processed_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl"
    with open(processed_output_path, "wb") as f:
        pkl.dump(df_processed_pd, f)
    print(f" Saved processed data to: {processed_output_path}")
    print(f" Processed data shape: {df_processed_pd.shape}")
    print(f" Column groups: {df_processed_pd.columns.get_level_values(0).unique().tolist()}")

    # Count columns per group
    print("\n Column counts by group:")
    for grp in df_processed_pd.columns.get_level_values(0).unique().tolist():
        count = (df_processed_pd.columns.get_level_values(0) == grp).sum()
        print(f" {grp}: {count} columns")

    # Step 6: Verify column counts (sanity check that FlagMarketInjector ran)
    print("\n" + "=" * 80)
    print("Verification")
    print("=" * 80)

    feature_flag_cols = [c[1] for c in df_processed_pd.columns if c[0] == 'feature_flag']
    has_market_0 = 'market_0' in feature_flag_cols
    has_market_1 = 'market_1' in feature_flag_cols

    print(f" feature_flag columns: {feature_flag_cols}")
    print(f" Has market_0: {has_market_0}")
    print(f" Has market_1: {has_market_1}")

    if has_market_0 and has_market_1:
        print("\n SUCCESS: market_0 and market_1 columns are present!")
    else:
        print("\n WARNING: market_0 or market_1 columns are missing!")

    print("\n" + "=" * 80)
    print("Dataset dump complete!")
    print("=" * 80)


if __name__ == "__main__":
    main()
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,53 @@
|
||||
"""
|
||||
Source package for alpha158_beta experiments.
|
||||
|
||||
This package provides modules for data loading, feature transformation,
|
||||
and model training for alpha158 beta factor experiments.
|
||||
"""
|
||||
|
||||
from .processors import (
|
||||
FeaturePipeline,
|
||||
FeatureGroups,
|
||||
DiffProcessor,
|
||||
FlagMarketInjector,
|
||||
FlagSTInjector,
|
||||
ColumnRemover,
|
||||
FlagToOnehot,
|
||||
IndusNtrlInjector,
|
||||
RobustZScoreNorm,
|
||||
Fillna,
|
||||
load_alpha158,
|
||||
load_market_ext,
|
||||
load_market_flags,
|
||||
load_industry_flags,
|
||||
load_all_data,
|
||||
load_robust_zscore_params,
|
||||
filter_stock_universe,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Main pipeline
|
||||
'FeaturePipeline',
|
||||
'FeatureGroups',
|
||||
|
||||
# Processors
|
||||
'DiffProcessor',
|
||||
'FlagMarketInjector',
|
||||
'FlagSTInjector',
|
||||
'ColumnRemover',
|
||||
'FlagToOnehot',
|
||||
'IndusNtrlInjector',
|
||||
'RobustZScoreNorm',
|
||||
'Fillna',
|
||||
|
||||
# Loaders
|
||||
'load_alpha158',
|
||||
'load_market_ext',
|
||||
'load_market_flags',
|
||||
'load_industry_flags',
|
||||
'load_all_data',
|
||||
|
||||
# Utilities
|
||||
'load_robust_zscore_params',
|
||||
'filter_stock_universe',
|
||||
]
|
||||
@ -0,0 +1,136 @@
|
||||
"""
|
||||
Processors package for the alpha158_beta feature pipeline.
|
||||
|
||||
This package provides a modular, polars-native data loading and transformation
|
||||
pipeline for generating VAE input features from alpha158 beta factors.
|
||||
|
||||
Main components:
|
||||
- FeatureGroups: Dataclass container for separate feature groups
|
||||
- FeaturePipeline: Main orchestrator class for the full pipeline
|
||||
- Processors: Individual transformation classes (DiffProcessor, IndusNtrlInjector, etc.)
|
||||
- Loaders: Data loading functions for parquet sources
|
||||
|
||||
Usage:
|
||||
from processors import FeaturePipeline, FeatureGroups
|
||||
|
||||
pipeline = FeaturePipeline()
|
||||
feature_groups = pipeline.load_data(start_date, end_date)
|
||||
transformed = pipeline.transform(feature_groups)
|
||||
vae_input = pipeline.get_vae_input(transformed)
|
||||
"""
|
||||
|
||||
from .dataclass import FeatureGroups
|
||||
from .loaders import (
|
||||
load_alpha158,
|
||||
load_market_ext,
|
||||
load_market_flags,
|
||||
load_industry_flags,
|
||||
load_all_data,
|
||||
load_parquet_by_date_range,
|
||||
get_date_partitions,
|
||||
# Constants
|
||||
PARQUET_ALPHA158_BETA_PATH,
|
||||
PARQUET_KLINE_PATH,
|
||||
PARQUET_MARKET_FLAG_PATH,
|
||||
PARQUET_INDUSTRY_FLAG_PATH,
|
||||
PARQUET_CON_RATING_PATH,
|
||||
INDUSTRY_FLAG_COLS,
|
||||
MARKET_FLAG_COLS_KLINE,
|
||||
MARKET_FLAG_COLS_MARKET,
|
||||
MARKET_EXT_RAW_COLS,
|
||||
)
|
||||
from .processors import (
|
||||
DiffProcessor,
|
||||
FlagMarketInjector,
|
||||
FlagSTInjector,
|
||||
ColumnRemover,
|
||||
FlagToOnehot,
|
||||
IndusNtrlInjector,
|
||||
RobustZScoreNorm,
|
||||
Fillna,
|
||||
)
|
||||
from .pipeline import (
|
||||
FeaturePipeline,
|
||||
load_robust_zscore_params,
|
||||
filter_stock_universe,
|
||||
# Constants
|
||||
ALPHA158_COLS,
|
||||
MARKET_EXT_BASE_COLS,
|
||||
MARKET_FLAG_COLS,
|
||||
COLUMNS_TO_REMOVE,
|
||||
VAE_INPUT_DIM,
|
||||
DEFAULT_ROBUST_ZSCORE_PARAMS_PATH,
|
||||
)
|
||||
from .exporters import (
|
||||
get_groups,
|
||||
get_groups_from_fg,
|
||||
pack_structs,
|
||||
unpack_struct,
|
||||
dump_to_parquet,
|
||||
dump_to_pickle,
|
||||
dump_to_numpy,
|
||||
dump_features,
|
||||
# Also import old names for backward compatibility
|
||||
get_groups as select_feature_groups_from_df,
|
||||
get_groups_from_fg as select_feature_groups,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Main classes
|
||||
'FeaturePipeline',
|
||||
'FeatureGroups',
|
||||
|
||||
# Processors
|
||||
'DiffProcessor',
|
||||
'FlagMarketInjector',
|
||||
'FlagSTInjector',
|
||||
'ColumnRemover',
|
||||
'FlagToOnehot',
|
||||
'IndusNtrlInjector',
|
||||
'RobustZScoreNorm',
|
||||
'Fillna',
|
||||
|
||||
# Loaders
|
||||
'load_alpha158',
|
||||
'load_market_ext',
|
||||
'load_market_flags',
|
||||
'load_industry_flags',
|
||||
'load_all_data',
|
||||
'load_parquet_by_date_range',
|
||||
'get_date_partitions',
|
||||
|
||||
# Utility functions
|
||||
'load_robust_zscore_params',
|
||||
'filter_stock_universe',
|
||||
|
||||
# Exporter functions
|
||||
'get_groups',
|
||||
'get_groups_from_fg',
|
||||
'pack_structs',
|
||||
'unpack_struct',
|
||||
'dump_to_parquet',
|
||||
'dump_to_pickle',
|
||||
'dump_to_numpy',
|
||||
'dump_features',
|
||||
|
||||
# Backward compatibility aliases
|
||||
'select_feature_groups_from_df',
|
||||
'select_feature_groups',
|
||||
|
||||
# Constants
|
||||
'PARQUET_ALPHA158_BETA_PATH',
|
||||
'PARQUET_KLINE_PATH',
|
||||
'PARQUET_MARKET_FLAG_PATH',
|
||||
'PARQUET_INDUSTRY_FLAG_PATH',
|
||||
'PARQUET_CON_RATING_PATH',
|
||||
'INDUSTRY_FLAG_COLS',
|
||||
'MARKET_FLAG_COLS_KLINE',
|
||||
'MARKET_FLAG_COLS_MARKET',
|
||||
'MARKET_EXT_RAW_COLS',
|
||||
'ALPHA158_COLS',
|
||||
'MARKET_EXT_BASE_COLS',
|
||||
'MARKET_FLAG_COLS',
|
||||
'COLUMNS_TO_REMOVE',
|
||||
'VAE_INPUT_DIM',
|
||||
'DEFAULT_ROBUST_ZSCORE_PARAMS_PATH',
|
||||
]
|
||||
@ -0,0 +1,115 @@
|
||||
"""Dataclass definitions for the feature pipeline."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List
|
||||
import polars as pl
|
||||
|
||||
# Import constants for column categorization
|
||||
# Alpha158 base columns (158 features)
|
||||
# Alpha158 base columns (158 features): 9 candlestick (K-bar) shape features,
# 4 normalized price levels, and 29 rolling-window families computed over
# windows of 5/10/20/30/60 days. Generated programmatically so the
# family x window pattern cannot drift out of sync with a hand-typed list.
_KBAR_COLS = ['KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2']
_PRICE_COLS = ['OPEN0', 'HIGH0', 'LOW0', 'VWAP0']
_ROLLING_FAMILIES = [
    'ROC', 'MA', 'STD', 'BETA', 'RSQR', 'RESI', 'MAX', 'MIN', 'QTLU', 'QTLD',
    'RANK', 'RSV', 'IMAX', 'IMIN', 'IMXD', 'CORR', 'CORD', 'CNTP', 'CNTN',
    'CNTD', 'SUMP', 'SUMN', 'SUMD', 'VMA', 'VSTD', 'WVMA', 'VSUMP', 'VSUMN', 'VSUMD',
]
_ROLLING_WINDOWS = [5, 10, 20, 30, 60]

# Family-by-family expansion (all windows of a family before the next family
# starts), matching the original hand-written ordering exactly.
ALPHA158_BASE_COLS = (
    _KBAR_COLS
    + _PRICE_COLS
    + [f"{family}{window}" for family in _ROLLING_FAMILIES for window in _ROLLING_WINDOWS]
)

# Market extension base columns
MARKET_EXT_BASE_COLS = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']

# Market flag base columns (before processors)
MARKET_FLAG_BASE_COLS = [
    'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
    'open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop'
]
|
||||
|
||||
|
||||
@dataclass
class FeatureGroups:
    """
    Container for separate feature groups in the pipeline.

    Keeps feature groups separate throughout the pipeline to avoid
    unnecessary merging and complex column management.

    Attributes:
        alpha158: 158 alpha158 features (+ _ntrl after industry neutralization)
        market_ext: Market extension features (Turnover, FreeTurnover, MarketValue, etc.)
        market_flag: Market flag features (IsZt, IsDt, IsXD, etc.)
        industry: Industry flag columns (converted to indus_idx by FlagToOnehot)
        indus_idx: Single column industry index (after FlagToOnehot processing)
        instruments: List of instrument IDs for metadata
        dates: List of date keys for metadata; typed as ints (presumably
            YYYYMMDD partition keys - TODO confirm against the loader)
    """
    # Core feature groups
    alpha158: pl.DataFrame  # 158 alpha158 features (+ _ntrl after processing)
    market_ext: pl.DataFrame  # Market extension features (+ _ntrl after processing)
    market_flag: pl.DataFrame  # Market flags (12 cols initially, 11 after ColumnRemover)
    industry: Optional[pl.DataFrame] = None  # raw industry flags -> indus_idx

    # Processed industry index (separate after FlagToOnehot)
    indus_idx: Optional[pl.DataFrame] = None  # Single column after FlagToOnehot

    # Metadata (extracted from dataframes for easy access)
    instruments: List[str] = field(default_factory=list)
    dates: List[int] = field(default_factory=list)

    def extract_metadata(self) -> None:
        """Extract instrument and date lists from the alpha158 dataframe."""
        # No-op (rather than an error) when alpha158 is missing or empty,
        # e.g. before any data has been loaded.
        if self.alpha158 is not None and len(self.alpha158) > 0:
            self.instruments = self.alpha158['instrument'].to_list()
            self.dates = self.alpha158['datetime'].to_list()

    def merge_for_processors(self) -> pl.DataFrame:
        """
        Merge all feature groups into a single DataFrame for processors.

        This is used by processors that need access to multiple groups
        (e.g., IndusNtrlInjector needs industry index).

        All joins are left joins on (instrument, datetime), so alpha158
        rows drive the result.

        Returns:
            Merged DataFrame with all features
        """
        df = self.alpha158

        # The `is not self.alpha158` identity checks skip the join when a
        # group already shares the alpha158 frame (i.e. was merged earlier).
        # Merge market_ext if not already merged
        if self.market_ext is not None and self.market_ext is not self.alpha158:
            df = df.join(self.market_ext, on=['instrument', 'datetime'], how='left')

        # Merge market_flag if not already merged
        if self.market_flag is not None and self.market_flag is not self.alpha158:
            df = df.join(self.market_flag, on=['instrument', 'datetime'], how='left')

        # Prefer the processed single-column index over the raw flag columns.
        if self.indus_idx is not None:
            df = df.join(self.indus_idx, on=['instrument', 'datetime'], how='left')
        elif self.industry is not None:
            df = df.join(self.industry, on=['instrument', 'datetime'], how='left')

        return df
|
||||
@ -0,0 +1,523 @@
|
||||
"""
|
||||
Feature exporters for the alpha158_beta pipeline.
|
||||
|
||||
This module provides functions to select and export feature groups from the
|
||||
transformed pipeline output. It can be used by both dump_features.py and
|
||||
generate_beta_embedding.py.
|
||||
|
||||
Feature groups:
|
||||
- merged: All columns after transformation
|
||||
- alpha158: Alpha158 features + _ntrl versions
|
||||
- market_ext: Market extended features + _ntrl + _diff
|
||||
- market_flag: Market flag columns
|
||||
- vae_input: 341 features specifically curated for VAE training
|
||||
|
||||
Struct mode:
|
||||
- When pack_struct=True, each feature group is packed into a struct column:
|
||||
- features_alpha158 (316 fields)
|
||||
- features_market_ext (14 fields)
|
||||
- features_market_flag (11 fields)
|
||||
"""
|
||||
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
|
||||
from .pipeline import (
|
||||
ALPHA158_COLS,
|
||||
MARKET_EXT_BASE_COLS,
|
||||
MARKET_FLAG_COLS,
|
||||
COLUMNS_TO_REMOVE,
|
||||
)
|
||||
from .dataclass import FeatureGroups
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper functions for struct packing/unpacking
|
||||
# =============================================================================
|
||||
|
||||
def pack_structs(df: pl.DataFrame) -> pl.DataFrame:
    """
    Pack flat feature columns into per-group struct columns.

    Creates:
        - features_alpha158: struct with 316 fields (158 + 158 _ntrl)
        - features_market_ext: struct with 14 fields (7 + 7 _ntrl)
        - features_market_flag: struct with 11 fields

    Only columns actually present in *df* are packed; a struct column is
    emitted only when at least one of its candidate fields exists.

    Args:
        df: Input DataFrame with flat columns

    Returns:
        DataFrame with struct columns: instrument, datetime, indus_idx, features_*
    """
    present = set(df.columns)

    # alpha158 candidates: _ntrl variants first, then the raw features.
    alpha_candidates = [f"{name}_ntrl" for name in ALPHA158_COLS] + list(ALPHA158_COLS)

    # market_ext candidates: base + _diff (minus removed columns), _ntrl first.
    ext_base = [
        name
        for name in MARKET_EXT_BASE_COLS + [f"{name}_diff" for name in MARKET_EXT_BASE_COLS]
        if name not in COLUMNS_TO_REMOVE
    ]
    ext_candidates = [f"{name}_ntrl" for name in ext_base] + ext_base

    # market_flag candidates: surviving flags plus processor-added columns,
    # de-duplicated while preserving order.
    flag_candidates = list(dict.fromkeys(
        [name for name in MARKET_FLAG_COLS if name not in COLUMNS_TO_REMOVE]
        + ['market_0', 'market_1', 'IsST']
    ))

    # Index columns always lead; indus_idx is kept flat when available.
    selection = ['instrument', 'datetime']
    if 'indus_idx' in df.columns:
        selection.append('indus_idx')

    for struct_name, candidates in (
        ('features_alpha158', alpha_candidates),
        ('features_market_ext', ext_candidates),
        ('features_market_flag', flag_candidates),
    ):
        kept = [name for name in candidates if name in present]
        if kept:
            selection.append(pl.struct(kept).alias(struct_name))

    return df.select(selection)
|
||||
|
||||
|
||||
def unpack_struct(df: pl.DataFrame, struct_name: str) -> pl.DataFrame:
    """
    Expand a struct column back into one flat column per struct field.

    The unpacked field columns are appended after the remaining original
    columns, in the struct's field order.

    Args:
        df: DataFrame containing struct column
        struct_name: Name of the struct column to unpack

    Returns:
        DataFrame with struct fields as individual columns

    Raises:
        ValueError: If *struct_name* is missing or is not a struct column.
    """
    if struct_name not in df.columns:
        raise ValueError(f"Struct column '{struct_name}' not found in DataFrame")

    dtype = df.schema[struct_name]
    if not isinstance(dtype, pl.Struct):
        raise ValueError(f"Column '{struct_name}' is not a struct type")

    # One expression per struct field, each aliased to the field's own name.
    flat_exprs = [
        pl.col(struct_name).struct.field(fld.name).alias(fld.name)
        for fld in dtype.fields
    ]

    # Drop the struct column itself; keep everything else in original order.
    remaining = [name for name in df.columns if name != struct_name]
    return df.select(remaining + flat_exprs)
|
||||
|
||||
|
||||
def dump_to_parquet(df: pl.DataFrame, path: str, verbose: bool = True) -> None:
    """Write *df* to *path* as a parquet file, optionally logging progress."""
    if verbose:
        print(f"Saving to parquet: {path}")

    df.write_parquet(path)

    if verbose:
        print(f" Shape: {df.shape}")
|
||||
|
||||
|
||||
def dump_to_pickle(df: pl.DataFrame, path: str, verbose: bool = True) -> None:
|
||||
"""Save DataFrame to pickle file."""
|
||||
if verbose:
|
||||
print(f"Saving to pickle: {path}")
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump(df, f)
|
||||
if verbose:
|
||||
print(f" Shape: {df.shape}")
|
||||
|
||||
|
||||
def dump_to_numpy(
    feature_groups: FeatureGroups,
    path: str,
    include_metadata: bool = True,
    verbose: bool = True
) -> None:
    """
    Save features to numpy format.

    Saves:
        - features.npy: float32 feature matrix (NaN/+inf/-inf replaced with 0.0)
        - <stem>_metadata.pkl: column names and metadata (if include_metadata=True)

    Args:
        feature_groups: Transformed FeatureGroups container
        path: Output path; '.npy' is appended when the path has no suffix
        include_metadata: Also write a pickle with column names, instruments,
            dates, and matrix dimensions
        verbose: Whether to print progress
    """
    if verbose:
        print(f"Saving to numpy: {path}")

    # Merge all groups into one wide frame so a single matrix can be built.
    df = feature_groups.merge_for_processors()

    # Feature columns exclude the index/metadata columns.
    feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime', 'indus_idx']]

    # Extract features; non-finite values are zeroed for downstream consumers.
    features = df.select(feature_cols).to_numpy().astype(np.float32)
    features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)

    # Save features
    base_path = Path(path)
    if base_path.suffix == '':
        base_path = Path(str(path) + '.npy')

    np.save(str(base_path), features)

    if verbose:
        print(f" Features shape: {features.shape}")

    # Save metadata if requested
    if include_metadata:
        # BUG FIX: the previous str.replace('.npy', '_metadata.pkl') silently
        # produced an unchanged (or wrong) path whenever the caller passed a
        # non-'.npy' suffix, or matched '.npy' elsewhere in the path. Derive
        # the sibling filename structurally instead.
        metadata_path = str(base_path.with_name(f"{base_path.stem}_metadata.pkl"))
        metadata = {
            'feature_cols': feature_cols,
            'instruments': df['instrument'].to_list(),
            'dates': df['datetime'].to_list(),
            'n_features': len(feature_cols),
            'n_samples': len(df),
        }
        with open(metadata_path, 'wb') as f:
            pickle.dump(metadata, f)
        if verbose:
            print(f" Metadata saved to: {metadata_path}")
|
||||
|
||||
|
||||
def get_groups(
    df: pl.DataFrame,
    groups_to_dump: List[str],
    verbose: bool = True,
    use_struct: bool = False,
) -> Dict[str, Any]:
    """
    Select which feature groups to include in the output from a merged DataFrame.

    Works on either a flat DataFrame or one that pipeline.pack_struct has
    already packed into features_* struct columns; struct columns are
    detected and reused instead of being rebuilt.

    Args:
        df: Transformed merged DataFrame (flat or with struct columns)
        groups_to_dump: List of groups to include ('alpha158', 'market_ext',
            'market_flag', 'merged', 'vae_input')
        verbose: Whether to print progress
        use_struct: If True, pack each selected group's feature columns into a
            single 'features' struct column. Ignored for groups that already
            carry a pipeline-built struct column.

    Returns:
        Dictionary mapping each requested group name to its DataFrame
    """
    # Check if df already has struct columns from pipeline.pack_struct
    # (any non-index column of Struct dtype counts).
    has_struct_cols = any(
        isinstance(df.schema.get(c), pl.Struct)
        for c in df.columns if c not in ['instrument', 'datetime', 'indus_idx']
    )

    result = {}

    if 'merged' in groups_to_dump:
        if has_struct_cols:
            # Already has struct columns from pipeline.pack_struct
            result['merged'] = df
        elif use_struct:
            # Keep instrument, datetime, and pack rest into struct
            feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime', 'indus_idx']]
            result['merged'] = df.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
        else:
            result['merged'] = df
        if verbose:
            print(f"Merged features: {result['merged'].shape}")

    if 'alpha158' in groups_to_dump:
        # Check if struct column already exists from pipeline.pack_struct
        if has_struct_cols and 'features_alpha158' in df.columns:
            result['alpha158'] = df.select(['instrument', 'datetime', 'features_alpha158'])
        else:
            # Select alpha158 columns from merged DataFrame; only columns
            # actually present in df are taken (raw first, then _ntrl).
            alpha_cols = ['instrument', 'datetime'] + [c for c in ALPHA158_COLS if c in df.columns]
            # Also include _ntrl versions
            alpha_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS if f"{c}_ntrl" in df.columns]
            alpha_cols += alpha_ntrl_cols
            df_alpha = df.select(alpha_cols)
            if use_struct:
                feature_cols = [c for c in df_alpha.columns if c not in ['instrument', 'datetime']]
                df_alpha = df_alpha.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
            result['alpha158'] = df_alpha
        if verbose:
            print(f"Alpha158 features: {result['alpha158'].shape}")

    if 'market_ext' in groups_to_dump:
        # Check if struct column already exists from pipeline.pack_struct
        if has_struct_cols and 'features_market_ext' in df.columns:
            result['market_ext'] = df.select(['instrument', 'datetime', 'features_market_ext'])
        else:
            # Select market_ext columns from merged DataFrame:
            # base + _diff derivatives, minus removed columns.
            market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
            market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
            ext_cols = ['instrument', 'datetime'] + market_ext_with_diff
            # Also include _ntrl versions
            ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff if f"{c}_ntrl" in df.columns]
            ext_cols += ext_ntrl_cols
            df_ext = df.select(ext_cols)
            if use_struct:
                feature_cols = [c for c in df_ext.columns if c not in ['instrument', 'datetime']]
                df_ext = df_ext.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
            result['market_ext'] = df_ext
        if verbose:
            print(f"Market ext features: {result['market_ext'].shape}")

    if 'market_flag' in groups_to_dump:
        # Check if struct column already exists from pipeline.pack_struct
        if has_struct_cols and 'features_market_flag' in df.columns:
            result['market_flag'] = df.select(['instrument', 'datetime', 'features_market_flag'])
        else:
            # Select market_flag columns from merged DataFrame
            flag_cols = ['instrument', 'datetime']
            flag_cols += [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE and c in df.columns]
            # Processor-added columns are taken all-or-nothing.
            flag_cols += ['market_0', 'market_1', 'IsST'] if all(c in df.columns for c in ['market_0', 'market_1', 'IsST']) else []
            flag_cols = list(dict.fromkeys(flag_cols))  # Remove duplicates
            df_flag = df.select(flag_cols)
            if use_struct:
                feature_cols = [c for c in df_flag.columns if c not in ['instrument', 'datetime']]
                df_flag = df_flag.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
            result['market_flag'] = df_flag
        if verbose:
            print(f"Market flag features: {result['market_flag'].shape}")

    if 'vae_input' in groups_to_dump:
        # Get VAE input columns from merged DataFrame
        # VAE input = 330 normalized features + 11 market flags = 341 features
        # Note: indus_idx is NOT included in VAE input
        # NOTE(review): unlike the other branches, these columns are NOT
        # filtered by presence in df; df.select will raise if any expected
        # column is missing - presumably intentional fail-fast. Confirm.

        # Build alpha158 feature columns (158 original + 158 _ntrl = 316)
        alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
        alpha158_cols = ALPHA158_COLS.copy()

        # Build market_ext feature columns (7 original + 7 _ntrl = 14)
        market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
        market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
        market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
        market_ext_cols = market_ext_with_diff.copy()

        # VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
        norm_feature_cols = (
            alpha158_ntrl_cols + alpha158_cols +
            market_ext_ntrl_cols + market_ext_cols
        )

        # Market flag columns (excluding IsST)
        market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
        market_flag_cols += ['market_0', 'market_1']
        market_flag_cols = list(dict.fromkeys(market_flag_cols))

        # Combine all VAE input columns (341 total)
        vae_cols = norm_feature_cols + market_flag_cols

        if use_struct:
            # Pack all features into a single struct column
            result['vae_input'] = df.select([
                'instrument',
                'datetime',
                pl.struct(vae_cols).alias('features')
            ])
            if verbose:
                print(f"VAE input features: {result['vae_input'].shape} (struct with {len(vae_cols)} fields)")
        else:
            # Always keep datetime and instrument as index columns
            # Select features with index columns first
            result['vae_input'] = df.select(['instrument', 'datetime'] + vae_cols)
            if verbose:
                print(f"VAE input features: {result['vae_input'].shape} (columns: {len(vae_cols)} + 2 index)")

    return result
|
||||
|
||||
|
||||
def get_groups_from_fg(
    feature_groups: FeatureGroups,
    groups_to_dump: List[str],
    verbose: bool = True,
    use_struct: bool = False,
) -> Dict[str, Any]:
    """
    Select which feature groups to include in the output.

    FeatureGroups-based counterpart to get_groups: individual groups are
    taken straight from the container, while 'merged' and 'vae_input' use
    feature_groups.merge_for_processors() to build the wide frame first.

    Args:
        feature_groups: Transformed FeatureGroups container
        groups_to_dump: List of groups to include ('alpha158', 'market_ext',
            'market_flag', 'merged', 'vae_input')
        verbose: Whether to print progress
        use_struct: If True, pack each group's feature columns into a single
            'features' struct column

    Returns:
        Dictionary mapping each requested group name to its DataFrame
    """
    result = {}

    if 'merged' in groups_to_dump:
        df = feature_groups.merge_for_processors()
        if use_struct:
            # Keep instrument, datetime, and pack rest into struct
            feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime', 'indus_idx']]
            df = df.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
        result['merged'] = df
        if verbose:
            print(f"Merged features: {result['merged'].shape}")

    if 'alpha158' in groups_to_dump:
        df = feature_groups.alpha158
        if use_struct:
            feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime']]
            df = df.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
        result['alpha158'] = df
        if verbose:
            print(f"Alpha158 features: {df.shape}")

    if 'market_ext' in groups_to_dump:
        df = feature_groups.market_ext
        if use_struct:
            feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime']]
            df = df.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
        result['market_ext'] = df
        if verbose:
            print(f"Market ext features: {df.shape}")

    if 'market_flag' in groups_to_dump:
        df = feature_groups.market_flag
        if use_struct:
            feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime']]
            df = df.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')])
        result['market_flag'] = df
        if verbose:
            print(f"Market flag features: {df.shape}")

    if 'vae_input' in groups_to_dump:
        # Get VAE input columns from already-transformed feature_groups
        # VAE input = 330 normalized features + 11 market flags = 341 features
        # Note: indus_idx is NOT included in VAE input
        # NOTE(review): this column-list construction duplicates the
        # 'vae_input' branch of get_groups - keep the two in sync.
        df = feature_groups.merge_for_processors()

        # Build alpha158 feature columns (158 original + 158 _ntrl = 316)
        alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
        alpha158_cols = ALPHA158_COLS.copy()

        # Build market_ext feature columns (7 original + 7 _ntrl = 14)
        market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
        market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
        market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
        market_ext_cols = market_ext_with_diff.copy()

        # VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
        norm_feature_cols = (
            alpha158_ntrl_cols + alpha158_cols +
            market_ext_ntrl_cols + market_ext_cols
        )

        # Market flag columns (excluding IsST)
        market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
        market_flag_cols += ['market_0', 'market_1']
        market_flag_cols = list(dict.fromkeys(market_flag_cols))

        # Combine all VAE input columns (341 total)
        vae_cols = norm_feature_cols + market_flag_cols

        if use_struct:
            # Pack all features into a single struct column
            result['vae_input'] = df.select([
                'instrument',
                'datetime',
                pl.struct(vae_cols).alias('features')
            ])
            if verbose:
                print(f"VAE input features: {result['vae_input'].shape} (struct with {len(vae_cols)} fields)")
        else:
            # Always keep datetime and instrument as index columns
            # Select features with index columns first
            result['vae_input'] = df.select(['instrument', 'datetime'] + vae_cols)
            if verbose:
                print(f"VAE input features: {result['vae_input'].shape} (columns: {len(vae_cols)} + 2 index)")

    return result
|
||||
|
||||
|
||||
def dump_features(
    df: pl.DataFrame,
    output_path: str,
    output_format: str = 'parquet',
    groups: List[str] = None,
    verbose: bool = True,
    use_struct: bool = False,
) -> None:
    """
    Dump features to file.

    When 'merged' is among the selected groups it is written directly to
    output_path; otherwise each selected group is written to a sibling file
    named <stem>_<group><suffix>.

    NOTE(review): when 'merged' is selected alongside other groups, only the
    merged output is written (matching prior behavior) - confirm this is
    intended.

    Args:
        df: Transformed merged DataFrame
        output_path: Output file path
        output_format: Output format ('parquet', 'pickle', 'numpy')
        groups: Feature groups to dump (default: ['merged'])
        verbose: Whether to print progress
        use_struct: If True, pack feature columns into a single 'features' struct column

    Raises:
        NotImplementedError: If output_format == 'numpy'; use the
            FeaturePipeline / dump_to_numpy path for numpy output.
    """
    if groups is None:
        groups = ['merged']

    # Select feature groups from merged DataFrame
    outputs = get_groups(df, groups, verbose, use_struct)

    # Ensure output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # For numpy, we use FeatureGroups (need to convert df back)
    # This is a simplified version - for numpy, use dump_to_numpy directly
    if output_format == 'numpy':
        raise NotImplementedError(
            "For numpy output, use the FeaturePipeline directly. "
            "This function handles parquet/pickle output only."
        )

    # The pickle and parquet paths share identical routing logic; only the
    # writer differs, so dispatch once instead of duplicating both branches.
    writer = dump_to_pickle if output_format == 'pickle' else dump_to_parquet

    if 'merged' in outputs:
        # Canonical single-file output: the merged frame goes to output_path.
        writer(outputs['merged'], output_path, verbose=verbose)
    else:
        # One file per selected group, each with a _<group> suffix before
        # the extension (same naming for one group or several).
        base_path = Path(output_path)
        for key, df_out in outputs.items():
            dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}"))
            writer(df_out, dump_path, verbose=verbose)

    if verbose:
        print("Feature dump complete!")
|
||||
@ -0,0 +1,279 @@
|
||||
"""Data loading functions for the feature pipeline.
|
||||
|
||||
Memory Best Practices:
|
||||
- Always filter on partition keys (datetime) BEFORE collecting
|
||||
- Use streaming for large date ranges (>1 year)
|
||||
- Never collect full parquet tables - filter on datetime first
|
||||
"""
|
||||
|
||||
import polars as pl
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
# Data paths (partitioned parquet datasets; the partition key is `datetime`)
PARQUET_ALPHA158_BETA_PATH = "/data/parquet/dataset/stg_1day_wind_alpha158_0_7_beta_1D/"
PARQUET_KLINE_PATH = "/data/parquet/dataset/stg_1day_wind_kline_adjusted_1D/"
PARQUET_MARKET_FLAG_PATH = "/data/parquet/dataset/stg_1day_wind_market_flag_1D/"
PARQUET_INDUSTRY_FLAG_PATH = "/data/parquet/dataset/stg_1day_gds_indus_flag_cc1_1D/"
PARQUET_CON_RATING_PATH = "/data/parquet/dataset/stg_1day_gds_con_rating_1D/"

# Industry flag columns (30 one-hot columns - note: gds_CC29 is not present in the data)
INDUSTRY_FLAG_COLS = [
    'gds_CC10', 'gds_CC11', 'gds_CC12', 'gds_CC20', 'gds_CC21', 'gds_CC22',
    'gds_CC23', 'gds_CC24', 'gds_CC25', 'gds_CC26', 'gds_CC27', 'gds_CC28',
    'gds_CC30', 'gds_CC31', 'gds_CC32', 'gds_CC33', 'gds_CC34', 'gds_CC35',
    'gds_CC36', 'gds_CC37', 'gds_CC40', 'gds_CC41', 'gds_CC42', 'gds_CC43',
    'gds_CC50', 'gds_CC60', 'gds_CC61', 'gds_CC62', 'gds_CC63', 'gds_CC70'
]

# Market flag columns, split by source table (see load_market_flags)
MARKET_FLAG_COLS_KLINE = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR']
MARKET_FLAG_COLS_MARKET = ['open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop']

# Market extension raw columns (renamed/derived downstream by load_market_ext)
MARKET_EXT_RAW_COLS = ['Turnover', 'FreeTurnover', 'MarketValue']
|
||||
|
||||
|
||||
def get_date_partitions(start_date: str, end_date: str) -> List[str]:
    """
    Generate a list of date partitions to load from Parquet.

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format

    Returns:
        List of datetime=YYYYMMDD partition strings for weekdays only
        (both endpoints inclusive)
    """
    start = datetime.strptime(start_date, "%Y-%m-%d")
    end = datetime.strptime(end_date, "%Y-%m-%d")

    partitions = []
    current = start
    while current <= end:
        if current.weekday() < 5:  # Monday = 0, Friday = 4
            partitions.append(f"datetime={current.strftime('%Y%m%d')}")
        # BUG FIX: advancing via datetime(year, month, day + 1) raised
        # ValueError at every month end (e.g. day 32). Ordinal arithmetic
        # rolls over months and years correctly.
        current = datetime.fromordinal(current.toordinal() + 1)

    return partitions
|
||||
|
||||
|
||||
def load_parquet_by_date_range(
    base_path: str,
    start_date: str,
    end_date: str,
    columns: Optional[List[str]] = None,
    collect: bool = True,
    streaming: bool = False
) -> pl.LazyFrame | pl.DataFrame:
    """
    Load parquet data filtered by date range.

    CRITICAL: This function filters on the datetime partition key BEFORE collecting.
    This is essential for memory efficiency with large parquet datasets.

    Args:
        base_path: Base path to the parquet dataset
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        columns: Optional list of columns to select (excluding instrument/datetime)
        collect: If True, return DataFrame (default); if False, return LazyFrame
        streaming: If True, use streaming mode for large datasets (recommended for >1 year)

    Returns:
        Polars DataFrame with the loaded data (LazyFrame if collect=False).
        On load failure, an empty DataFrame with the expected schema is
        returned (best-effort fallback; the error is printed).
    """
    # Partition keys are stored as YYYYMMDD ints.
    start_int = int(start_date.replace("-", ""))
    end_int = int(end_date.replace("-", ""))

    try:
        # Start with lazy scan - DO NOT COLLECT YET
        lf = pl.scan_parquet(base_path)

        # CRITICAL: Filter on partition key FIRST, before any other operations.
        # This ensures partition pruning happens at the scan level.
        lf = lf.filter(pl.col('datetime') >= start_int)
        lf = lf.filter(pl.col('datetime') <= end_int)

        # Select specific columns if provided (column pruning); silently
        # drops requested columns that are absent from the dataset schema.
        if columns:
            schema = lf.collect_schema()
            available_cols = ['instrument', 'datetime'] + [c for c in columns if c in schema.names()]
            lf = lf.select(available_cols)

        # Collect with optional streaming mode
        if collect:
            if streaming:
                return lf.collect(streaming=True)
            return lf.collect()
        return lf

    except Exception as e:
        print(f"Error loading from {base_path}: {e}")
        # Return empty DataFrame with expected schema
        empty_schema = {
            'instrument': pl.Series([], dtype=pl.String),
            'datetime': pl.Series([], dtype=pl.Int32),
        }
        if columns:
            # BUG FIX: the fallback comprehension previously referenced an
            # undefined name `c` (`for col in columns if c not in ...`),
            # raising NameError and masking the original load error.
            empty_schema.update({
                col: pl.Series([], dtype=pl.Float64)
                for col in columns if col not in ['instrument', 'datetime']
            })
        return pl.DataFrame(empty_schema)
|
||||
|
||||
|
||||
def load_alpha158(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame:
    """
    Load alpha158 beta factors from parquet.

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        streaming: If True, use streaming mode for large datasets

    Returns:
        DataFrame with instrument, datetime, and 158 alpha158 features
    """
    print("Loading alpha158_0_7_beta factors...")
    df = load_parquet_by_date_range(
        PARQUET_ALPHA158_BETA_PATH,
        start_date,
        end_date,
        streaming=streaming,
    )
    print(f" Alpha158 shape: {df.shape}")
    return df
|
||||
|
||||
|
||||
def load_market_ext(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame:
    """
    Load market extension features from parquet.

    Derivations from the raw kline columns:
        - Turnover -> turnover (copy under new name)
        - FreeTurnover -> free_turnover (copy under new name)
        - MarketValue -> log_size = log(MarketValue)
        - con_rating_strength: left-joined from its own table; missing
          values are filled with 0.0

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        streaming: If True, use streaming mode for large datasets

    Returns:
        DataFrame with instrument, datetime, turnover, free_turnover,
        log_size, con_rating_strength
    """
    print("Loading kline data (market ext columns)...")
    df_kline = load_parquet_by_date_range(
        PARQUET_KLINE_PATH,
        start_date,
        end_date,
        MARKET_EXT_RAW_COLS,
        streaming=streaming,
    )
    print(f" Kline (market ext raw) shape: {df_kline.shape}")

    print("Loading con_rating_strength from parquet...")
    df_con_rating = load_parquet_by_date_range(
        PARQUET_CON_RATING_PATH,
        start_date,
        end_date,
        ['con_rating_strength'],
        streaming=streaming,
    )
    print(f" Con rating shape: {df_con_rating.shape}")

    # Derive the renamed / log-transformed columns alongside the raw ones.
    derived = [
        pl.col('Turnover').alias('turnover'),
        pl.col('FreeTurnover').alias('free_turnover'),
        pl.col('MarketValue').log().alias('log_size'),
    ]
    df_kline = df_kline.with_columns(derived)
    print(f" Kline (market ext transformed) shape: {df_kline.shape}")

    # Attach rating strength; instruments without a rating row get 0.0.
    df_kline = df_kline.join(df_con_rating, on=['instrument', 'datetime'], how='left')
    df_kline = df_kline.with_columns([pl.col('con_rating_strength').fill_null(0.0)])
    print(f" Kline (with con_rating) shape: {df_kline.shape}")

    return df_kline
|
||||
|
||||
|
||||
def load_market_flags(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame:
    """
    Load market flag features from parquet.

    Two sources are combined on (instrument, datetime):
    - kline_adjusted: IsZt, IsDt, IsN, IsXD, IsXR, IsDR
    - market_flag: open_limit, close_limit, low_limit, open_stop, close_stop, high_stop

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        streaming: If True, use streaming mode for large datasets

    Returns:
        DataFrame with instrument, datetime, and 12 market flag columns
    """
    print("Loading market flags from kline_adjusted...")
    kline_flags = load_parquet_by_date_range(
        PARQUET_KLINE_PATH, start_date, end_date, MARKET_FLAG_COLS_KLINE, streaming=streaming
    )
    print(f" Kline flags shape: {kline_flags.shape}")

    print("Loading market flags from market_flag table...")
    market_flags = load_parquet_by_date_range(
        PARQUET_MARKET_FLAG_PATH, start_date, end_date, MARKET_FLAG_COLS_MARKET, streaming=streaming
    )
    print(f" Market flag shape: {market_flags.shape}")

    # Inner join: keep only rows present in BOTH flag sources.
    combined = kline_flags.join(market_flags, on=['instrument', 'datetime'], how='inner')
    print(f" Combined flags shape: {combined.shape}")

    return combined
|
||||
|
||||
|
||||
def load_industry_flags(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame:
    """
    Load industry flag features from parquet.

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        streaming: If True, use streaming mode for large datasets

    Returns:
        DataFrame with instrument, datetime, and 29 industry flag columns
    """
    print("Loading industry flags...")
    industry_df = load_parquet_by_date_range(
        PARQUET_INDUSTRY_FLAG_PATH,
        start_date,
        end_date,
        INDUSTRY_FLAG_COLS,
        streaming=streaming,
    )
    print(f" Industry shape: {industry_df.shape}")
    return industry_df
|
||||
|
||||
|
||||
def load_all_data(
    start_date: str,
    end_date: str,
    streaming: bool = False
) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """
    Load all data sources from Parquet.

    Convenience wrapper that invokes each of the four loaders with the same
    date range and streaming setting.

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        streaming: If True, use streaming mode for large datasets

    Returns:
        Tuple of (alpha158_df, market_ext_df, market_flag_df, industry_df)
    """
    print(f"Loading data from {start_date} to {end_date}...")

    # Run the loaders in a fixed order; the tuple order is part of the contract.
    loaders = (load_alpha158, load_market_ext, load_market_flags, load_industry_flags)
    alpha_df, ext_df, flag_df, industry_df = (
        loader(start_date, end_date, streaming=streaming) for loader in loaders
    )
    return alpha_df, ext_df, flag_df, industry_df
|
||||
@ -0,0 +1,539 @@
|
||||
"""FeaturePipeline orchestrator for the data loading and transformation pipeline."""
|
||||
|
||||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from pathlib import Path
|
||||
|
||||
from .dataclass import FeatureGroups
|
||||
from .loaders import (
|
||||
load_alpha158,
|
||||
load_market_ext,
|
||||
load_market_flags,
|
||||
load_industry_flags,
|
||||
load_all_data,
|
||||
INDUSTRY_FLAG_COLS,
|
||||
)
|
||||
from .processors import (
|
||||
DiffProcessor,
|
||||
FlagMarketInjector,
|
||||
FlagSTInjector,
|
||||
ColumnRemover,
|
||||
FlagToOnehot,
|
||||
IndusNtrlInjector,
|
||||
RobustZScoreNorm,
|
||||
Fillna
|
||||
)
|
||||
|
||||
# Constants - Import from loaders module (single source of truth)
# INDUSTRY_FLAG_COLS is now imported from .loaders above

# Alpha158 feature columns in explicit order.
# The order matters: it defines the column layout used by RobustZScoreNorm
# (pre-fitted mean/std arrays are positional) and by the VAE input matrix.
ALPHA158_COLS = [
    'KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2',
    'OPEN0', 'HIGH0', 'LOW0', 'VWAP0',
    'ROC5', 'ROC10', 'ROC20', 'ROC30', 'ROC60',
    'MA5', 'MA10', 'MA20', 'MA30', 'MA60',
    'STD5', 'STD10', 'STD20', 'STD30', 'STD60',
    'BETA5', 'BETA10', 'BETA20', 'BETA30', 'BETA60',
    'RSQR5', 'RSQR10', 'RSQR20', 'RSQR30', 'RSQR60',
    'RESI5', 'RESI10', 'RESI20', 'RESI30', 'RESI60',
    'MAX5', 'MAX10', 'MAX20', 'MAX30', 'MAX60',
    'MIN5', 'MIN10', 'MIN20', 'MIN30', 'MIN60',
    'QTLU5', 'QTLU10', 'QTLU20', 'QTLU30', 'QTLU60',
    'QTLD5', 'QTLD10', 'QTLD20', 'QTLD30', 'QTLD60',
    'RANK5', 'RANK10', 'RANK20', 'RANK30', 'RANK60',
    'RSV5', 'RSV10', 'RSV20', 'RSV30', 'RSV60',
    'IMAX5', 'IMAX10', 'IMAX20', 'IMAX30', 'IMAX60',
    'IMIN5', 'IMIN10', 'IMIN20', 'IMIN30', 'IMIN60',
    'IMXD5', 'IMXD10', 'IMXD20', 'IMXD30', 'IMXD60',
    'CORR5', 'CORR10', 'CORR20', 'CORR30', 'CORR60',
    'CORD5', 'CORD10', 'CORD20', 'CORD30', 'CORD60',
    'CNTP5', 'CNTP10', 'CNTP20', 'CNTP30', 'CNTP60',
    'CNTN5', 'CNTN10', 'CNTN20', 'CNTN30', 'CNTN60',
    'CNTD5', 'CNTD10', 'CNTD20', 'CNTD30', 'CNTD60',
    'SUMP5', 'SUMP10', 'SUMP20', 'SUMP30', 'SUMP60',
    'SUMN5', 'SUMN10', 'SUMN20', 'SUMN30', 'SUMN60',
    'SUMD5', 'SUMD10', 'SUMD20', 'SUMD30', 'SUMD60',
    'VMA5', 'VMA10', 'VMA20', 'VMA30', 'VMA60',
    'VSTD5', 'VSTD10', 'VSTD20', 'VSTD30', 'VSTD60',
    'WVMA5', 'WVMA10', 'WVMA20', 'WVMA30', 'WVMA60',
    'VSUMP5', 'VSUMP10', 'VSUMP20', 'VSUMP30', 'VSUMP60',
    'VSUMN5', 'VSUMN10', 'VSUMN20', 'VSUMN30', 'VSUMN60',
    'VSUMD5', 'VSUMD10', 'VSUMD20', 'VSUMD30', 'VSUMD60'
]

# Market extension base columns (as produced by load_market_ext).
# DiffProcessor later derives a "{col}_diff" companion for each of these.
MARKET_EXT_BASE_COLS = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength']

# Market flag columns (before processors).
# FlagMarketInjector/FlagSTInjector add market_0, market_1, IsST on top.
MARKET_FLAG_COLS = [
    'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR',
    'open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop'
]

# Columns to remove after FlagMarketInjector and FlagSTInjector
COLUMNS_TO_REMOVE = ['log_size_diff', 'IsN', 'IsZt', 'IsDt']

# Expected VAE input dimension:
# 316 alpha158 (158 raw + 158 _ntrl) + 14 market_ext (7 raw + 7 _ntrl)
# + 11 market flags (IsST excluded) = 341.
VAE_INPUT_DIM = 341

# Default robust zscore parameters path
# NOTE(review): absolute path to a specific workstation — consider making this
# configurable via environment/config; confirm availability in deployment.
DEFAULT_ROBUST_ZSCORE_PARAMS_PATH = (
    "/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/"
    "data/robust_zscore_params/csiallx_feature2_ntrla_flag_pnlnorm/"
)
|
||||
|
||||
|
||||
def filter_stock_universe(df: pl.DataFrame, instruments: str = 'csiallx') -> pl.DataFrame:
    """
    Filter dataframe to csiallx stock universe (A-shares excluding STAR/BSE).

    Delegates to qshare's filter_instruments, which loads the instrument list
    from:
        /data/qlib/default/data_ops/target/instruments/csiallx.txt

    Args:
        df: Input DataFrame with datetime and instrument columns
        instruments: Market name for spine creation (default: 'csiallx')

    Returns:
        Filtered DataFrame with only instruments in the specified universe
    """
    # Lazy import: the module stays usable without qshare unless filtering is requested.
    from qshare.algo.polars.spine import filter_instruments

    filtered = filter_instruments(df, instruments=instruments)
    return filtered
|
||||
|
||||
|
||||
def load_robust_zscore_params(
    params_path: Optional[str] = None
) -> Dict[str, np.ndarray]:
    """
    Load pre-fitted RobustZScoreNorm parameters from numpy files.

    Loads mean_train.npy and std_train.npy from the specified directory.
    Parameters are cached per path (as a function attribute) to avoid
    repeated file I/O.

    Args:
        params_path: Path to the directory containing mean_train.npy and std_train.npy.
            If None, uses the default path.

    Returns:
        Dictionary with 'mean_train' and 'std_train' numpy arrays

    Raises:
        FileNotFoundError: If parameter files are not found
    """
    # BUGFIX (annotation): parameter was declared `params_path: str = None`,
    # which is mis-typed — None is an accepted value, so it is Optional[str].
    if params_path is None:
        params_path = DEFAULT_ROBUST_ZSCORE_PARAMS_PATH

    # Module-level cache stored as a function attribute, keyed by path.
    cache = getattr(load_robust_zscore_params, '_cached_params', None)
    if cache is None:
        cache = {}
        load_robust_zscore_params._cached_params = cache
    if params_path in cache:
        return cache[params_path]

    print(f"Loading robust zscore parameters from: {params_path}")

    mean_path = os.path.join(params_path, 'mean_train.npy')
    std_path = os.path.join(params_path, 'std_train.npy')

    # Fail fast with a precise message for each missing file.
    if not os.path.exists(mean_path):
        raise FileNotFoundError(f"mean_train.npy not found at {mean_path}")
    if not os.path.exists(std_path):
        raise FileNotFoundError(f"std_train.npy not found at {std_path}")

    mean_train = np.load(mean_path)
    std_train = np.load(std_path)

    print(f"Loaded parameters:")
    print(f" mean_train shape: {mean_train.shape}")
    print(f" std_train shape: {std_train.shape}")

    # Metadata is optional; when present, report the fit window for traceability.
    metadata_path = os.path.join(params_path, 'metadata.json')
    if os.path.exists(metadata_path):
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        print(f" fitted on: {metadata.get('fit_start_time', 'N/A')} to {metadata.get('fit_end_time', 'N/A')}")

    params = {
        'mean_train': mean_train,
        'std_train': std_train
    }

    # Cache the loaded parameters for subsequent calls with the same path.
    cache[params_path] = params

    return params
|
||||
|
||||
|
||||
class FeaturePipeline:
    """
    Feature pipeline orchestrator for loading and transforming data.

    The pipeline manages:
    1. Data loading from parquet sources
    2. Feature transformations via processors
    3. Output preparation for VAE input

    Usage:
        pipeline = FeaturePipeline(config)
        feature_groups = pipeline.load_data(start_date, end_date)
        transformed = pipeline.transform(feature_groups)
        vae_input = pipeline.get_vae_input(transformed)
    """

    def __init__(
        self,
        config: Optional[Dict] = None,
        robust_zscore_params_path: Optional[str] = None
    ):
        """
        Initialize the FeaturePipeline.

        Args:
            config: Optional configuration dictionary with pipeline settings.
                If None, uses default configuration.
            robust_zscore_params_path: Path to robust zscore parameters.
                If None, uses default path.
        """
        self.config = config or {}
        self.robust_zscore_params_path = robust_zscore_params_path or DEFAULT_ROBUST_ZSCORE_PARAMS_PATH

        # Cache for loaded robust zscore params; populated lazily on first transform()
        self._robust_zscore_params = None

        # Initialize processors
        self._init_processors()

    def _init_processors(self) -> None:
        """Initialize all processors with default configuration."""
        # Diff processor for market_ext columns
        self.diff_processor = DiffProcessor(MARKET_EXT_BASE_COLS)

        # Market flag injector
        self.flag_market_injector = FlagMarketInjector()

        # ST flag injector
        self.flag_st_injector = FlagSTInjector()

        # Column remover
        self.column_remover = ColumnRemover(COLUMNS_TO_REMOVE)

        # Industry flag to index converter
        self.flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS)

        # Industry neutralization injectors (created on-demand with specific feature lists)
        # self.indus_ntrl_injector = None  # Created per-feature-group

        # Robust zscore normalizer (created on-demand with loaded params)
        # self.robust_norm = None  # Created with loaded params

        # Fillna processor
        self.fillna = Fillna()

    def load_data(
        self,
        start_date: str,
        end_date: str,
        filter_universe: bool = True,
        universe_name: str = 'csiallx',
        streaming: bool = False
    ) -> FeatureGroups:
        """
        Load data from parquet sources into FeatureGroups container.

        Args:
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            filter_universe: Whether to filter to a specific stock universe
            universe_name: Name of the stock universe to filter to
            streaming: If True, use Polars streaming mode for large datasets (>1 year)

        Returns:
            FeatureGroups container with loaded data
        """
        print("=" * 60)
        print(f"Loading data from {start_date} to {end_date}")
        print("=" * 60)

        # Load all data sources
        df_alpha, df_market_ext, df_market_flag, df_industry = load_all_data(
            start_date, end_date, streaming=streaming
        )

        # Apply stock universe filter if requested.
        # Each group is filtered independently so they stay row-aligned on the
        # same universe before being merged downstream.
        if filter_universe:
            print(f"Filtering to {universe_name} universe...")
            df_alpha = filter_stock_universe(df_alpha, instruments=universe_name)
            df_market_ext = filter_stock_universe(df_market_ext, instruments=universe_name)
            df_market_flag = filter_stock_universe(df_market_flag, instruments=universe_name)
            df_industry = filter_stock_universe(df_industry, instruments=universe_name)
            print(f" After filter - Alpha158 shape: {df_alpha.shape}")

        # Create FeatureGroups container
        feature_groups = FeatureGroups(
            alpha158=df_alpha,
            market_ext=df_market_ext,
            market_flag=df_market_flag,
            industry=df_industry
        )

        # Extract metadata (instrument/datetime bookkeeping; defined in .dataclass)
        feature_groups.extract_metadata()

        print(f"Loaded {len(feature_groups.instruments)} samples")
        print("=" * 60)

        return feature_groups

    def transform(
        self,
        feature_groups: FeatureGroups,
        pack_struct: bool = False
    ) -> pl.DataFrame:
        """
        Apply feature transformation pipeline to FeatureGroups.

        The pipeline applies processors in the following order:
        1. DiffProcessor - adds diff features to market_ext
        2. FlagMarketInjector - adds market_0, market_1 to market_flag
        3. FlagSTInjector - adds IsST to market_flag
        4. ColumnRemover - removes log_size_diff, IsN, IsZt, IsDt
        5. FlagToOnehot - converts industry flags to indus_idx
        6. IndusNtrlInjector - industry neutralization for alpha158 and market_ext
        7. RobustZScoreNorm - robust z-score normalization
        8. Fillna - fill NaN values with 0

        NOTE: the processor order is load-bearing — e.g. FlagToOnehot must run
        before IndusNtrlInjector (which groups by indus_idx), and ColumnRemover
        must run before the market_ext neutralization list is built.

        Args:
            feature_groups: FeatureGroups container with loaded data
            pack_struct: If True, pack each feature group into a struct column
                (features_alpha158, features_market_ext, features_market_flag).
                If False (default), return flat DataFrame with all columns merged.

        Returns:
            Merged DataFrame with all transformed features (pl.DataFrame)
        """
        print("=" * 60)
        print("Starting feature transformation pipeline")
        print("=" * 60)

        # Merge all groups into one wide frame for processing
        # NOTE(review): merge_for_processors is defined in .dataclass — assumed
        # to join the groups on (instrument, datetime); confirm.
        df = feature_groups.merge_for_processors()

        # Step 1: Diff Processor - adds diff features for market_ext
        df = self.diff_processor.process(df)

        # Step 2: FlagMarketInjector - adds market_0, market_1
        df = self.flag_market_injector.process(df)

        # Step 3: FlagSTInjector - adds IsST
        df = self.flag_st_injector.process(df)

        # Step 4: ColumnRemover - removes specific columns
        df = self.column_remover.process(df)

        # Step 5: FlagToOnehot - converts industry flags to indus_idx
        df = self.flag_to_onehot.process(df)

        # Step 6: IndusNtrlInjector - industry neutralization
        # For alpha158 features
        indus_ntrl_alpha = IndusNtrlInjector(ALPHA158_COLS, suffix='_ntrl')
        df = indus_ntrl_alpha.process(df, df)  # Pass df as both feature and industry source

        # For market_ext features (with diff columns)
        market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
        # Remove columns that were dropped (e.g. log_size_diff)
        market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
        indus_ntrl_ext = IndusNtrlInjector(market_ext_with_diff, suffix='_ntrl')
        df = indus_ntrl_ext.process(df, df)

        # Step 7: RobustZScoreNorm - robust z-score normalization
        # Load parameters lazily on first use and create normalizer
        if self._robust_zscore_params is None:
            self._robust_zscore_params = load_robust_zscore_params(self.robust_zscore_params_path)

        # Build the list of features to normalize
        alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
        market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]

        # Feature order for VAE: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
        # This order must match the positional layout of the pre-fitted mean/std.
        norm_feature_cols = (
            alpha158_ntrl_cols + ALPHA158_COLS +
            market_ext_ntrl_cols + market_ext_with_diff
        )

        print(f"Applying RobustZScoreNorm to {len(norm_feature_cols)} features...")

        robust_norm = RobustZScoreNorm(
            norm_feature_cols,
            clip_range=(-3, 3),
            use_qlib_params=True,
            qlib_mean=self._robust_zscore_params['mean_train'],
            qlib_std=self._robust_zscore_params['std_train']
        )
        df = robust_norm.process(df)

        # Step 8: Fillna - fill NaN with 0
        # Get all feature columns (normalized + flags + industry index)
        market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
        market_flag_cols += ['market_0', 'market_1', 'IsST']
        market_flag_cols = list(dict.fromkeys(market_flag_cols))  # Remove duplicates, keep order

        final_feature_cols = norm_feature_cols + market_flag_cols + ['indus_idx']
        df = self.fillna.process(df, final_feature_cols)

        print("=" * 60)
        print("Pipeline complete")
        print(f" Total columns: {len(df.columns)}")
        print(f" Rows: {len(df)}")

        # Optionally pack features into struct columns
        if pack_struct:
            df = self._pack_into_structs(df)
            print(f" Packed into struct columns")

        print("=" * 60)
        return df

    def _pack_into_structs(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Pack feature groups into struct columns.

        Creates:
        - features_alpha158: struct with 316 fields (158 + 158 _ntrl)
        - features_market_ext: struct with 14 fields (7 + 7 _ntrl)
        - features_market_flag: struct with 11 fields

        Returns:
            DataFrame with columns: instrument, datetime, indus_idx, features_*
        """
        # Define column groups — this mirrors the grouping logic in transform()
        # and get_vae_input(); keep the three in sync when changing column sets.
        alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
        alpha158_all_cols = alpha158_ntrl_cols + ALPHA158_COLS

        market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
        market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
        market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
        market_ext_all_cols = market_ext_ntrl_cols + market_ext_with_diff

        market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
        market_flag_cols += ['market_0', 'market_1', 'IsST']
        market_flag_cols = list(dict.fromkeys(market_flag_cols))

        # Build result with struct columns; identity columns stay flat
        result_cols = ['instrument', 'datetime']

        # Check if indus_idx exists
        if 'indus_idx' in df.columns:
            result_cols.append('indus_idx')

        # Pack alpha158 (only columns actually present in df)
        alpha158_cols_in_df = [c for c in alpha158_all_cols if c in df.columns]
        if alpha158_cols_in_df:
            result_cols.append(pl.struct(alpha158_cols_in_df).alias('features_alpha158'))

        # Pack market_ext
        ext_cols_in_df = [c for c in market_ext_all_cols if c in df.columns]
        if ext_cols_in_df:
            result_cols.append(pl.struct(ext_cols_in_df).alias('features_market_ext'))

        # Pack market_flag
        flag_cols_in_df = [c for c in market_flag_cols if c in df.columns]
        if flag_cols_in_df:
            result_cols.append(pl.struct(flag_cols_in_df).alias('features_market_flag'))

        return df.select(result_cols)

    def get_vae_input(
        self,
        df: pl.DataFrame | FeatureGroups,
        exclude_isst: bool = True
    ) -> np.ndarray:
        """
        Prepare VAE input features from transformed DataFrame or FeatureGroups.

        VAE input structure (341 features):
        - feature group (316): 158 alpha158 + 158 alpha158_ntrl
        - feature_ext group (14): 7 market_ext + 7 market_ext_ntrl
        - feature_flag group (11): market flags (excluding IsST)

        NOTE: indus_idx is NOT included in VAE input.

        Args:
            df: Transformed DataFrame or FeatureGroups container
            exclude_isst: Whether to exclude IsST from VAE input (default: True)

        Returns:
            Numpy array of shape (n_samples, VAE_INPUT_DIM)
        """
        print("Preparing features for VAE...")

        # Accept either DataFrame or FeatureGroups
        if isinstance(df, FeatureGroups):
            df = df.merge_for_processors()

        # Build alpha158 feature columns
        alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
        alpha158_cols = ALPHA158_COLS.copy()

        # Build market_ext feature columns (with diff, minus removed columns)
        market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
        market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
        market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]
        market_ext_cols = market_ext_with_diff.copy()

        # VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
        # — must match the order used for normalization in transform().
        norm_feature_cols = (
            alpha158_ntrl_cols + alpha158_cols +
            market_ext_ntrl_cols + market_ext_cols
        )

        # Market flag columns (excluding IsST if requested)
        market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
        market_flag_cols += ['market_0', 'market_1']
        if not exclude_isst:
            market_flag_cols.append('IsST')
        market_flag_cols = list(dict.fromkeys(market_flag_cols))

        # Combine all VAE input columns
        vae_cols = norm_feature_cols + market_flag_cols

        print(f" norm_feature_cols: {len(norm_feature_cols)}")
        print(f" market_flag_cols: {len(market_flag_cols)}")
        print(f" Total VAE input columns: {len(vae_cols)}")

        # Verify all columns exist (warn only; missing columns surface in select below)
        missing_cols = [c for c in vae_cols if c not in df.columns]
        if missing_cols:
            print(f"WARNING: Missing columns: {missing_cols}")

        # Select features and convert to numpy (float32 for the model)
        features_df = df.select(vae_cols)
        features = features_df.to_numpy().astype(np.float32)

        # Handle any remaining NaN/Inf values defensively
        features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)

        print(f"Feature matrix shape: {features.shape}")

        # Verify dimensions; pad/truncate to the fixed VAE width so the model
        # always receives VAE_INPUT_DIM columns even if the schema drifts.
        if features.shape[1] != VAE_INPUT_DIM:
            print(f"WARNING: Expected {VAE_INPUT_DIM} features, got {features.shape[1]}")

            if features.shape[1] < VAE_INPUT_DIM:
                # Pad with zeros
                padding = np.zeros(
                    (features.shape[0], VAE_INPUT_DIM - features.shape[1]),
                    dtype=np.float32
                )
                features = np.concatenate([features, padding], axis=1)
                print(f"Padded to shape: {features.shape}")
            else:
                # Truncate
                features = features[:, :VAE_INPUT_DIM]
                print(f"Truncated to shape: {features.shape}")

        return features
|
||||
@ -0,0 +1,447 @@
|
||||
"""Processor classes for feature transformation.
|
||||
|
||||
Processors are stateless, functional components that operate on
|
||||
specific feature groups or merged DataFrames.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
from typing import List, Tuple, Optional, Dict
|
||||
|
||||
|
||||
class DiffProcessor:
    """
    Diff Processor: Calculate diff features for market_ext columns.

    For each column, calculates diff with period=1 within each instrument group.
    """

    def __init__(self, columns: List[str]):
        """
        Initialize the DiffProcessor.

        Args:
            columns: List of column names to compute diffs for
        """
        self.columns = columns

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Add diff features for specified columns.

        Args:
            df: Input DataFrame with market_ext columns

        Returns:
            DataFrame with added {col}_diff columns
        """
        print("Applying Diff processor...")

        # Build every expression up front and apply them in a single
        # with_columns() call.  order_by='datetime' guarantees chronological
        # ordering within each instrument's window.
        present = [name for name in self.columns if name in df.columns]
        exprs = [
            pl.col(name)
            .diff()
            .over('instrument', order_by='datetime')
            .alias(f"{name}_diff")
            for name in present
        ]

        return df.with_columns(exprs) if exprs else df
|
||||
|
||||
|
||||
class FlagMarketInjector:
    """
    Flag Market Injector: Create market_0, market_1 columns based on instrument code.

    Maps to Qlib's map_market_sec logic with vocab_size=2:
    - market_0 (主板): SH60xxx, SZ00xxx
    - market_1 (科创板/创业板): SH688xxx, SH689xxx, SZ300xxx, SZ301xxx

    NOTE: vocab_size=2 (not 3!) - 新三板/北交所 (NE4xxxx, NE8xxxx) are NOT included.
    """

    def process(self, df: pl.DataFrame, instrument_col: str = 'instrument') -> pl.DataFrame:
        """
        Add market_0, market_1 columns.

        Args:
            df: Input DataFrame with instrument column
            instrument_col: Name of the instrument column

        Returns:
            DataFrame with added market_0 and market_1 columns
        """
        print("Applying FlagMarketInjector (vocab_size=2)...")

        # Convert instrument to string and pad to 6 digits
        inst_str = pl.col(instrument_col).cast(pl.String).str.zfill(6)

        # Determine market type based on code prefix
        is_sh_star = inst_str.str.starts_with('688') | inst_str.str.starts_with('689')  # SH688xxx, SH689xxx
        is_sz_chi = inst_str.str.starts_with('300') | inst_str.str.starts_with('301')   # SZ300xxx, SZ301xxx
        is_sh_main = inst_str.str.starts_with('6')  # SH600xxx, SH601xxx, ... (also matches 688/689!)
        is_sz_main = inst_str.str.starts_with('0')  # SZ000xxx

        # BUGFIX: '688xxx'/'689xxx' codes also start with '6', so STAR-board
        # stocks previously received market_0 = 1 AND market_1 = 1, breaking
        # the documented exclusive board mapping.  Exclude board-1 members
        # from board 0 so the two flags are mutually exclusive.
        is_board_1 = is_sh_star | is_sz_chi
        is_board_0 = (is_sh_main | is_sz_main) & ~is_board_1

        df = df.with_columns([
            # market_0 = 主板 (SH main + SZ main, excluding STAR/ChiNext)
            is_board_0.cast(pl.Int8).alias('market_0'),
            # market_1 = 科创板 + 创业板 (SH star + SZ ChiNext)
            is_board_1.cast(pl.Int8).alias('market_1')
        ])

        return df
|
||||
|
||||
|
||||
class FlagSTInjector:
    """
    Flag ST Injector: Create IsST column from ST flags.

    Creates IsST = ST_S | ST_Y if ST flags are available,
    otherwise creates a placeholder column of all zeros.
    """

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Add IsST column.

        Args:
            df: Input DataFrame

        Returns:
            DataFrame with added IsST column
        """
        print("Applying FlagSTInjector (creating IsST)...")

        # BUGFIX: the original guard was
        #   if 'ST_S' in df.columns or 'st_flag::ST_S' in df.columns:
        # but the branch then unconditionally referenced pl.col('ST_S') AND
        # pl.col('ST_Y').  If only the prefixed column existed, or ST_Y was
        # absent, this raised ColumnNotFoundError at collect time.  Require
        # both plain columns before computing the real flag.
        if 'ST_S' in df.columns and 'ST_Y' in df.columns:
            # Create IsST from actual ST flags (either source flag set => ST)
            df = df.with_columns([
                ((pl.col('ST_S').cast(pl.Boolean, strict=False) |
                  pl.col('ST_Y').cast(pl.Boolean, strict=False))
                 .cast(pl.Int8).alias('IsST'))
            ])
        else:
            # Create placeholder (all zeros) if ST flags not available
            df = df.with_columns([
                pl.lit(0).cast(pl.Int8).alias('IsST')
            ])

        return df
|
||||
|
||||
|
||||
class ColumnRemover:
    """
    Column Remover: Drop specific columns.

    Removes columns that are not needed for the VAE input.
    """

    def __init__(self, columns_to_remove: List[str]):
        """
        Initialize the ColumnRemover.

        Args:
            columns_to_remove: List of column names to remove
        """
        self.columns_to_remove = columns_to_remove

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Remove specified columns.

        Args:
            df: Input DataFrame

        Returns:
            DataFrame with specified columns removed
        """
        print(f"Applying ColumnRemover (removing {len(self.columns_to_remove)} columns)...")

        # Dropping a missing column would raise, so restrict to present ones.
        present = [name for name in self.columns_to_remove if name in df.columns]
        return df.drop(present) if present else df
|
||||
|
||||
|
||||
class FlagToOnehot:
    """
    Flag To Onehot: Convert 29 one-hot industry columns to single indus_idx.

    For each row, finds which industry column is True/1 and sets indus_idx to that index.
    """

    def __init__(self, industry_cols: List[str]):
        """
        Initialize the FlagToOnehot.

        Args:
            industry_cols: List of 29 industry flag column names
        """
        self.industry_cols = industry_cols

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Convert industry flags to single indus_idx column.

        Args:
            df: Input DataFrame with industry flag columns

        Returns:
            DataFrame with indus_idx column (original industry columns removed)
        """
        print("Applying FlagToOnehot (converting 29 industry flags to indus_idx)...")

        # Fold a when/then chain over the flag columns, defaulting to -1
        # ("no industry").  Each column wraps the previous chain as its
        # otherwise-branch, so if multiple flags were ever set the highest
        # index would win — same priority as the original chain order.
        present = [(idx, name) for idx, name in enumerate(self.industry_cols) if name in df.columns]

        index_expr = pl.lit(-1)
        for idx, name in present:
            index_expr = pl.when(pl.col(name) == 1).then(idx).otherwise(index_expr)

        df = df.with_columns([index_expr.alias('indus_idx')])

        # The one-hot source columns are fully encoded in indus_idx; drop them.
        drop_cols = [name for _, name in present]
        return df.drop(drop_cols) if drop_cols else df
|
||||
|
||||
|
||||
class IndusNtrlInjector:
    """
    Industry Neutralization Injector: Industry neutralization for features.

    For each feature, subtracts the industry mean (grouped by indus_idx)
    from the feature value. Creates new columns with "_ntrl" suffix while
    keeping original columns.

    IMPORTANT: Industry neutralization is done PER DATETIME (cross-sectional),
    not across the entire dataset. This matches qlib's cal_indus_ntrl behavior.
    """

    def __init__(self, feature_cols: List[str], suffix: str = '_ntrl'):
        """
        Initialize the IndusNtrlInjector.

        Args:
            feature_cols: List of feature column names to neutralize
            suffix: Suffix to append to neutralized column names
        """
        self.feature_cols = feature_cols
        self.suffix = suffix

    def process(
        self,
        feature_df: pl.DataFrame,
        industry_df: pl.DataFrame
    ) -> pl.DataFrame:
        """
        Apply industry neutralization to specified features.

        Args:
            feature_df: DataFrame with feature columns to neutralize
            industry_df: DataFrame with indus_idx column (must have instrument, datetime)

        Returns:
            DataFrame with added {col}_ntrl columns
        """
        print(f"Applying IndusNtrlInjector to {len(self.feature_cols)} features...")

        # Reuse feature_df directly when it already carries indus_idx;
        # otherwise pull the index in from industry_df via a left join.
        if 'indus_idx' in feature_df.columns:
            merged = feature_df
        else:
            industry_slim = industry_df.select(['instrument', 'datetime', 'indus_idx'])
            merged = feature_df.join(
                industry_slim,
                on=['instrument', 'datetime'],
                how='left',
                suffix='_indus'
            )

        # Only neutralize columns actually present in the frame.
        targets = [name for name in self.feature_cols if name in merged.columns]

        # Build all expressions first and apply them in one with_columns().
        # The window is (datetime, indus_idx), i.e. per-date cross-sectional
        # demeaning within each industry; order_by keeps ordering deterministic.
        demean_exprs = [
            (
                pl.col(name)
                - pl.col(name).mean().over(['datetime', 'indus_idx'], order_by='datetime')
            ).alias(f"{name}{self.suffix}")
            for name in targets
        ]

        return merged.with_columns(demean_exprs) if demean_exprs else merged
|
||||
|
||||
|
||||
class RobustZScoreNorm:
    """
    Robust Z-Score Normalization: per-datetime normalization.

    (x - median) / (1.4826 * MAD) where MAD = median(|x - median|)
    Clip outliers at [-3, 3].

    Supports pre-fitted parameters from qlib's pickled processor:
        normalizer = RobustZScoreNorm(
            feature_cols=feature_cols,
            use_qlib_params=True,
            qlib_mean=zscore_proc.mean_train,
            qlib_std=zscore_proc.std_train
        )
    """

    def __init__(
        self,
        feature_cols: List[str],
        clip_range: Tuple[float, float] = (-3, 3),
        use_qlib_params: bool = False,
        qlib_mean: Optional[np.ndarray] = None,
        qlib_std: Optional[np.ndarray] = None
    ):
        """
        Initialize the RobustZScoreNorm.

        Args:
            feature_cols: List of feature column names to normalize.
                When use_qlib_params=True, qlib_mean/qlib_std must be
                aligned with THIS list's order.
            clip_range: Tuple of (min, max) for clipping normalized values
            use_qlib_params: Whether to use pre-fitted parameters from qlib
            qlib_mean: Pre-fitted mean array from qlib (required if use_qlib_params=True)
            qlib_std: Pre-fitted std array from qlib (required if use_qlib_params=True)

        Raises:
            ValueError: If use_qlib_params is True but either parameter
                array is missing.
        """
        self.feature_cols = feature_cols
        self.clip_range = clip_range
        self.use_qlib_params = use_qlib_params
        self.mean_train = qlib_mean
        self.std_train = qlib_std

        if use_qlib_params:
            if qlib_mean is None or qlib_std is None:
                raise ValueError("Must provide qlib_mean and qlib_std when use_qlib_params=True")
            print(f"Using pre-fitted qlib parameters (mean shape: {qlib_mean.shape}, std shape: {qlib_std.shape})")

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Apply robust z-score normalization.

        Args:
            df: Input DataFrame with feature columns

        Returns:
            DataFrame with normalized feature columns (columns are
            overwritten in place; no new columns are added)
        """
        print(f"Applying RobustZScoreNorm to {len(self.feature_cols)} features...")

        # Filter to only columns that exist
        existing_cols = [c for c in self.feature_cols if c in df.columns]

        if self.use_qlib_params:
            # Use pre-fitted parameters from qlib (fit once, apply to all dates).
            # CRITICAL: build all normalization expressions FIRST, then apply
            # in a single with_columns() — a per-column loop would copy the
            # frame once per column (330 columns = 330 copies).
            norm_exprs = []
            for col in existing_cols:
                # BUGFIX: look up the parameter slot via the column's position
                # in the ORIGINAL feature_cols list. The previous code indexed
                # mean_train/std_train by position within the FILTERED
                # existing_cols list, which silently paired the wrong
                # mean/std with every column after a missing one.
                i = self.feature_cols.index(col)
                if i < len(self.mean_train):
                    mean_val = float(self.mean_train[i])
                    std_val = float(self.std_train[i])
                    # 1e-8 guards against division by zero for constant features.
                    norm_exprs.append(
                        ((pl.col(col) - mean_val) / (std_val + 1e-8))
                        .clip(self.clip_range[0], self.clip_range[1])
                        .alias(col)
                    )

            if norm_exprs:
                df = df.with_columns(norm_exprs)
        else:
            # Compute per-datetime robust z-score (original behavior).
            # CRITICAL: build all expressions with temp columns first, then
            # clean up in a single drop() to avoid per-column copies.
            all_exprs = []
            temp_cols = []

            for col in existing_cols:
                # Median per datetime (cross-sectional center).
                median_col = f"__median_{col}"
                temp_cols.append(median_col)
                all_exprs.append(
                    pl.col(col).median().over('datetime').alias(median_col)
                )

                # Absolute deviation from the per-datetime median.
                abs_dev_col = f"__absdev_{col}"
                temp_cols.append(abs_dev_col)
                all_exprs.append(
                    (pl.col(col) - pl.col(median_col)).abs().alias(abs_dev_col)
                )

                # MAD (median of absolute deviations) per datetime.
                mad_col = f"__mad_{col}"
                temp_cols.append(mad_col)
                all_exprs.append(
                    pl.col(abs_dev_col).median().over('datetime').alias(mad_col)
                )

                # Robust z-score, clipped; overwrites the original column.
                # 1.4826 scales MAD to be a consistent estimator of std for
                # normally distributed data; 1e-8 guards zero MAD.
                all_exprs.append(
                    ((pl.col(col) - pl.col(median_col)) / (1.4826 * pl.col(mad_col) + 1e-8))
                    .clip(self.clip_range[0], self.clip_range[1])
                    .alias(col)
                )

            # Apply all expressions in a single with_columns()
            if all_exprs:
                df = df.with_columns(all_exprs)

            # Clean up all temporary columns in a single drop()
            if temp_cols:
                df = df.drop(temp_cols)

        return df
|
||||
|
||||
|
||||
class Fillna:
    """
    Fill NaN: fill all missing values (null and NaN) with 0 for numeric columns.
    """

    def process(
        self,
        df: pl.DataFrame,
        feature_cols: List[str]
    ) -> pl.DataFrame:
        """
        Fill missing values with 0 for the specified columns.

        Args:
            df: Input DataFrame
            feature_cols: List of column names to fill missing values for

        Returns:
            DataFrame with null/NaN values filled with 0
        """
        print("Applying Fillna processor...")

        numeric_dtypes = (pl.Float32, pl.Float64, pl.Int32, pl.Int64, pl.UInt32, pl.UInt64)
        float_dtypes = (pl.Float32, pl.Float64)

        # Filter to only columns that exist and are numeric (not boolean).
        existing_cols = [
            c for c in feature_cols
            if c in df.columns and df[c].dtype in numeric_dtypes
        ]

        # CRITICAL: build all fill expressions FIRST, then apply in a single
        # with_columns() — a per-column loop would copy the frame each time
        # (~345 columns = ~345 copies).
        # BUGFIX: fill_nan is only meaningful for float dtypes (integers
        # cannot hold NaN, and polars rejects fill_nan on integer columns on
        # some versions). Integer columns also get an integer fill value so
        # fill_null(0.0) cannot upcast them to float.
        fill_exprs = [
            pl.col(col).fill_null(0.0).fill_nan(0.0)
            if df[col].dtype in float_dtypes
            else pl.col(col).fill_null(0)
            for col in existing_cols
        ]

        if fill_exprs:
            df = df.with_columns(fill_exprs)

        return df
|
||||
Loading…
Reference in new issue