diff --git a/stock_1d/d033/alpha158_beta/README.md b/stock_1d/d033/alpha158_beta/README.md index c1b678a..2b84307 100644 --- a/stock_1d/d033/alpha158_beta/README.md +++ b/stock_1d/d033/alpha158_beta/README.md @@ -115,17 +115,22 @@ All fixed processors preserve the trained parameters from the original proc_list ### Polars Dataset Generation -The `scripts/dump_polars_dataset.py` script generates datasets using a polars-based pipeline that replicates the qlib preprocessing: +The `scripts/dump_features.py` script generates datasets using a polars-based pipeline that replicates the qlib preprocessing: ```bash -# Generate raw and processed datasets -python scripts/dump_polars_dataset.py +# Generate merged features (flat columns) +python scripts/dump_features.py --start-date 2024-01-01 --end-date 2024-01-31 --groups merged + +# Generate with struct columns (packed feature groups) +python scripts/dump_features.py --start-date 2024-01-01 --end-date 2024-01-31 --groups merged --pack-struct + +# Generate specific feature groups +python scripts/dump_features.py --start-date 2024-01-01 --end-date 2024-01-31 --groups alpha158 market_ext ``` This script: 1. Loads data from Parquet files (alpha158, kline, market flags, industry flags) -2. Saves raw data (before processors) to `data_polars/raw_data_*.pkl` -3. Applies the full processor pipeline: +2. Applies the full processor pipeline: - Diff processor (adds diff features) - FlagMarketInjector (adds market_0, market_1) - ColumnRemover (removes log_size_diff, IsN, IsZt, IsDt) @@ -133,13 +138,20 @@ This script: - IndusNtrlInjector (industry neutralization) - RobustZScoreNorm (using pre-fitted qlib parameters via `from_version()`) - Fillna (fill NaN with 0) -4. Saves processed data to `data_polars/processed_data_*.pkl` +3. 
Saves to parquet/pickle format + +**Output modes:** +- **Flat mode (default)**: All columns as separate fields (348 columns for merged) +- **Struct mode (`--pack-struct`)**: Feature groups packed into struct columns: + - `features_alpha158` (316 fields) + - `features_market_ext` (14 fields) + - `features_market_flag` (11 fields) **Note**: The `FlagSTInjector` step is skipped because it fails silently even in the gold-standard qlib code (see `BUG_ANALYSIS_FINAL.md` for details). Output structure: - Raw data: ~204 columns (158 feature + 4 feature_ext + 12 feature_flag + 30 indus_flag) -- Processed data: 342 columns (316 feature + 14 feature_ext + 11 feature_flag + 1 indus_idx) +- Processed data: 348 columns (318 alpha158 + 14 market_ext + 14 market_flag + 2 index) - VAE input dimension: 341 (excluding indus_idx) ### RobustZScoreNorm Parameter Extraction diff --git a/stock_1d/d033/alpha158_beta/scripts/dump_features.py b/stock_1d/d033/alpha158_beta/scripts/dump_features.py new file mode 100644 index 0000000..3a5d10d --- /dev/null +++ b/stock_1d/d033/alpha158_beta/scripts/dump_features.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python +""" +Script to generate and dump transformed features from the alpha158_beta pipeline. 
+ +This script provides fine-grained control over the feature generation and dumping process: +- Select which feature groups to dump (alpha158, market_ext, market_flag, merged, vae_input) +- Choose output format (parquet, pickle, numpy) +- Control date range and universe filtering +- Save intermediate pipeline outputs +- Enable streaming mode for large datasets (>1 year) + +Usage: + # Dump all features to parquet + python dump_features.py --start-date 2025-01-01 --end-date 2025-01-31 + + # Dump only alpha158 features to pickle + python dump_features.py --groups alpha158 --format pickle + + # Dump with custom output path + python dump_features.py --output /path/to/output.parquet + + # Dump merged features with all columns + python dump_features.py --groups merged --verbose + + # Use streaming mode for large date ranges (>1 year) + python dump_features.py --start-date 2020-01-01 --end-date 2023-12-31 --streaming +""" + +import os +import sys +import argparse +from pathlib import Path +from typing import Optional, List + +# Add src to path for imports +SCRIPT_DIR = Path(__file__).parent +sys.path.insert(0, str(SCRIPT_DIR.parent / 'src')) + +from processors import ( + FeaturePipeline, + FeatureGroups, + VAE_INPUT_DIM, + ALPHA158_COLS, + MARKET_EXT_BASE_COLS, + COLUMNS_TO_REMOVE, + get_groups, + dump_to_parquet, + dump_to_pickle, + dump_to_numpy, +) + +# Default output directory +DEFAULT_OUTPUT_DIR = SCRIPT_DIR.parent / "data" + + +def generate_and_dump( + start_date: str, + end_date: str, + output_path: str, + output_format: str = 'parquet', + groups: List[str] = None, + universe: str = 'csiallx', + filter_universe: bool = True, + robust_zscore_params_path: Optional[str] = None, + verbose: bool = True, + pack_struct: bool = False, + streaming: bool = False, +) -> None: + """ + Generate features and dump to file. 
+ + Args: + start_date: Start date in YYYY-MM-DD format + end_date: End date in YYYY-MM-DD format + output_path: Output file path + output_format: Output format ('parquet', 'pickle', 'numpy') + groups: Feature groups to dump (default: ['merged']) + universe: Stock universe name + filter_universe: Whether to filter to stock universe + robust_zscore_params_path: Path to robust zscore parameters + verbose: Whether to print progress + pack_struct: If True, pack each feature group into struct columns + (features_alpha158, features_market_ext, features_market_flag) + streaming: If True, use Polars streaming mode for large datasets (>1 year) + """ + if groups is None: + groups = ['merged'] + + print("=" * 60) + print("Feature Dump Script") + print("=" * 60) + print(f"Date range: {start_date} to {end_date}") + print(f"Output format: {output_format}") + print(f"Feature groups: {groups}") + print(f"Universe: {universe} (filter: {filter_universe})") + print(f"Pack struct: {pack_struct}") + print(f"Output path: {output_path}") + print("=" * 60) + + # Initialize pipeline + pipeline = FeaturePipeline( + robust_zscore_params_path=robust_zscore_params_path + ) + + # Load data + feature_groups = pipeline.load_data( + start_date, end_date, + filter_universe=filter_universe, + universe_name=universe, + streaming=streaming + ) + + # Apply transformations - get merged DataFrame (pipeline always returns merged DataFrame now) + df_transformed = pipeline.transform(feature_groups, pack_struct=pack_struct) + + # Select feature groups from merged DataFrame + outputs = get_groups(df_transformed, groups, verbose, use_struct=False) + + # Ensure output directory exists + output_dir = os.path.dirname(output_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + + # Dump to file(s) + if output_format == 'numpy': + # For numpy, we save the merged features + dump_to_numpy(feature_groups, output_path, include_metadata=True, verbose=verbose) + elif output_format == 'pickle': + if 'merged' 
in outputs: + dump_to_pickle(outputs['merged'], output_path, verbose=verbose) + elif len(outputs) == 1: + # Single group output + key = list(outputs.keys())[0] + base_path = Path(output_path) + dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}")) + dump_to_pickle(outputs[key], dump_path, verbose=verbose) + else: + # Multiple groups - save each separately + base_path = Path(output_path) + for key, df_out in outputs.items(): + dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}")) + dump_to_pickle(df_out, dump_path, verbose=verbose) + else: # parquet + if 'merged' in outputs: + dump_to_parquet(outputs['merged'], output_path, verbose=verbose) + elif len(outputs) == 1: + # Single group output + key = list(outputs.keys())[0] + base_path = Path(output_path) + dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}")) + dump_to_parquet(outputs[key], dump_path, verbose=verbose) + else: + # Multiple groups - save each separately + base_path = Path(output_path) + for key, df_out in outputs.items(): + dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}")) + dump_to_parquet(df_out, dump_path, verbose=verbose) + + print("=" * 60) + print("Feature dump complete!") + print("=" * 60) + + +def main(): + parser = argparse.ArgumentParser( + description="Generate and dump transformed features from alpha158_beta pipeline" + ) + + # Date range + parser.add_argument( + "--start-date", type=str, required=True, + help="Start date in YYYY-MM-DD format" + ) + parser.add_argument( + "--end-date", type=str, required=True, + help="End date in YYYY-MM-DD format" + ) + + # Output settings + parser.add_argument( + "--output", "-o", type=str, default=None, + help=f"Output file path (default: {DEFAULT_OUTPUT_DIR}/features.parquet)" + ) + parser.add_argument( + "--format", "-f", type=str, default='parquet', + choices=['parquet', 'pickle', 'numpy'], + help="Output format (default: parquet)" + ) + 
+ # Feature groups + parser.add_argument( + "--groups", "-g", type=str, nargs='+', default=['merged'], + choices=['merged', 'alpha158', 'market_ext', 'market_flag', 'vae_input'], + help="Feature groups to dump (default: merged)" + ) + + # Universe settings + parser.add_argument( + "--universe", type=str, default='csiallx', + help="Stock universe name (default: csiallx)" + ) + parser.add_argument( + "--no-filter-universe", action="store_true", + help="Disable stock universe filtering" + ) + + # Robust zscore parameters + parser.add_argument( + "--robust-zscore-params", type=str, default=None, + help="Path to robust zscore parameters directory" + ) + + # Verbose mode + parser.add_argument( + "--verbose", "-v", action="store_true", default=True, + help="Enable verbose output (default: True)" + ) + parser.add_argument( + "--quiet", "-q", action="store_true", + help="Disable verbose output" + ) + + # Struct option + parser.add_argument( + "--pack-struct", "-s", action="store_true", + help="Pack each feature group into separate struct columns (features_alpha158, features_market_ext, features_market_flag)" + ) + + # Streaming option + parser.add_argument( + "--streaming", action="store_true", + help="Use Polars streaming mode for large datasets (recommended for date ranges > 1 year)" + ) + + args = parser.parse_args() + + # Handle verbose/quiet flags + verbose = args.verbose and not args.quiet + + # Set default output path + if args.output is None: + # Build default output path: {data_dir}/features_{group}.parquet + # Note: generate_and_dump will add group suffix, so use base name "features" + output_path = str(DEFAULT_OUTPUT_DIR / "features.parquet") + else: + output_path = args.output + + # Generate and dump + generate_and_dump( + start_date=args.start_date, + end_date=args.end_date, + output_path=output_path, + output_format=args.format, + groups=args.groups, + universe=args.universe, + filter_universe=not args.no_filter_universe, + 
robust_zscore_params_path=args.robust_zscore_params, + verbose=verbose, + pack_struct=args.pack_struct, + streaming=args.streaming, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py b/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py deleted file mode 100644 index 7b961d0..0000000 --- a/stock_1d/d033/alpha158_beta/scripts/dump_polars_dataset.py +++ /dev/null @@ -1,364 +0,0 @@ -#!/usr/bin/env python -""" -Script to dump raw and processed datasets using the polars-based pipeline. - -This generates: -1. Raw data (before applying processors) - equivalent to qlib's handler output -2. Processed data (after applying all processors) - ready for VAE encoding - -Date range: 2026-02-23 to today (2026-02-27) -""" - -import os -import sys -import pickle as pkl -import numpy as np -import polars as pl -from pathlib import Path -from datetime import datetime - -# Add parent directory to path for imports -sys.path.insert(0, str(Path(__file__).parent)) - -# Import processors from the new shared module -from cta_1d.src.processors import ( - DiffProcessor, - FlagMarketInjector, - FlagSTInjector, - ColumnRemover, - FlagToOnehot, - IndusNtrlInjector, - RobustZScoreNorm, - Fillna, -) - -# Import constants from local module -from generate_beta_embedding import ( - ALPHA158_COLS, - INDUSTRY_FLAG_COLS, -) - -# Date range -START_DATE = "2026-02-23" -END_DATE = "2026-02-27" - -# Output directory -OUTPUT_DIR = Path(__file__).parent.parent / "data_polars" -OUTPUT_DIR.mkdir(parents=True, exist_ok=True) - - -def apply_processor_pipeline(df: pl.DataFrame) -> pl.DataFrame: - """ - Apply the full processor pipeline (equivalent to qlib's proc_list). - - This mimics the qlib proc_list: - 0. Diff: Adds diff features for market_ext columns - 1. FlagMarketInjector: Adds market_0, market_1 - 2. FlagSTInjector: Adds IsST (placeholder if ST flags not available) - 3. 
ColumnRemover: Removes log_size_diff, IsN, IsZt, IsDt - 4. FlagToOnehot: Converts 29 industry flags to indus_idx - 5. IndusNtrlInjector: Industry neutralization for feature - 6. IndusNtrlInjector: Industry neutralization for feature_ext - 7. RobustZScoreNorm: Normalization using pre-fitted qlib params - 8. Fillna: Fill NaN with 0 - """ - print("=" * 60) - print("Applying processor pipeline") - print("=" * 60) - - # market_ext columns (4 base) - market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] - - # market_flag columns (12 total before ColumnRemover) - market_flag_cols = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', - 'open_limit', 'close_limit', 'low_limit', - 'open_stop', 'close_stop', 'high_stop'] - - # Step 1: Diff Processor - print("\n[1] Applying Diff processor...") - diff_processor = DiffProcessor(market_ext_base) - df = diff_processor.process(df) - - # After Diff: market_ext has 8 columns (4 base + 4 diff) - market_ext_cols = market_ext_base + [f"{c}_diff" for c in market_ext_base] - - # Step 2: FlagMarketInjector (adds market_0, market_1) - print("[2] Applying FlagMarketInjector...") - flag_injector = FlagMarketInjector() - df = flag_injector.process(df) - - # Add market_0, market_1 to flag list - market_flag_with_market = market_flag_cols + ['market_0', 'market_1'] - - # Step 3: FlagSTInjector - adds IsST (placeholder if ST flags not available) - print("[3] Applying FlagSTInjector...") - flag_st_injector = FlagSTInjector() - df = flag_st_injector.process(df) - - # Add IsST to flag list - market_flag_with_st = market_flag_with_market + ['IsST'] - - # Step 4: ColumnRemover - print("[4] Applying ColumnRemover...") - columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt'] - remover = ColumnRemover(columns_to_remove) - df = remover.process(df) - - # Update column lists after removal - market_ext_cols = [c for c in market_ext_cols if c not in columns_to_remove] - market_flag_with_st = [c for c in market_flag_with_st if c 
not in columns_to_remove] - - print(f" Removed columns: {columns_to_remove}") - print(f" Remaining market_ext: {len(market_ext_cols)} columns") - print(f" Remaining market_flag: {len(market_flag_with_st)} columns") - - # Step 5: FlagToOnehot - print("[5] Applying FlagToOnehot...") - flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS) - df = flag_to_onehot.process(df) - - # Step 6 & 7: IndusNtrlInjector - print("[6] Applying IndusNtrlInjector for alpha158...") - alpha158_cols = ALPHA158_COLS.copy() - indus_ntrl_alpha = IndusNtrlInjector(alpha158_cols, suffix='_ntrl') - df = indus_ntrl_alpha.process(df) - - print("[7] Applying IndusNtrlInjector for market_ext...") - indus_ntrl_ext = IndusNtrlInjector(market_ext_cols, suffix='_ntrl') - df = indus_ntrl_ext.process(df) - - # Build column lists for normalization - alpha158_ntrl_cols = [f"{c}_ntrl" for c in alpha158_cols] - market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_cols] - - # Step 8: RobustZScoreNorm - print("[8] Applying RobustZScoreNorm...") - norm_feature_cols = alpha158_ntrl_cols + alpha158_cols + market_ext_ntrl_cols + market_ext_cols - - # Load RobustZScoreNorm with pre-fitted parameters from version - robust_norm = RobustZScoreNorm.from_version( - "csiallx_feature2_ntrla_flag_pnlnorm", - feature_cols=norm_feature_cols - ) - - # Verify parameter shape matches expected features - expected_features = len(norm_feature_cols) - if robust_norm.qlib_mean.shape[0] != expected_features: - print(f" WARNING: Feature count mismatch! 
Expected {expected_features}, " - f"got {robust_norm.qlib_mean.shape[0]}") - - df = robust_norm.process(df) - - # Step 9: Fillna - print("[9] Applying Fillna...") - final_feature_cols = norm_feature_cols + market_flag_with_st + ['indus_idx'] - fillna = Fillna() - df = fillna.process(df, final_feature_cols) - - print("\n" + "=" * 60) - print("Processor pipeline complete!") - print(f" Normalized features: {len(norm_feature_cols)}") - print(f" Market flags: {len(market_flag_with_st)}") - print(f" Total features (with indus_idx): {len(final_feature_cols)}") - print("=" * 60) - - return df - - -def convert_to_multiindex_df(df_polars: pl.DataFrame) -> "pd.DataFrame": - """ - Convert polars DataFrame to pandas DataFrame with MultiIndex columns. - This matches the format of qlib's output. - - IMPORTANT: Qlib's IndusNtrlInjector outputs columns in order [_ntrl] + [raw], - so we need to reorder columns to match this expected order. - """ - import pandas as pd - - # Convert to pandas - df = df_polars.to_pandas() - - # Check if datetime and instrument are columns - if 'datetime' in df.columns and 'instrument' in df.columns: - # Set MultiIndex - df = df.set_index(['datetime', 'instrument']) - # If they're already not in columns, assume they're already the index - - # Drop raw columns that shouldn't be in processed data - raw_cols_to_drop = ['Turnover', 'FreeTurnover', 'MarketValue'] - existing_raw_cols = [c for c in raw_cols_to_drop if c in df.columns] - if existing_raw_cols: - df = df.drop(columns=existing_raw_cols) - - # Build MultiIndex columns based on column name patterns - # IMPORTANT: Qlib order is [_ntrl columns] + [raw columns] for each group - columns_with_group = [] - - # Define column sets - alpha158_base = set(ALPHA158_COLS) - market_ext_base = {'turnover', 'free_turnover', 'log_size', 'con_rating_strength'} - market_ext_diff = {'turnover_diff', 'free_turnover_diff', 'con_rating_strength_diff'} - market_ext_all = market_ext_base | market_ext_diff - 
feature_flag_cols = {'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', 'open_limit', 'close_limit', 'low_limit', - 'open_stop', 'close_stop', 'high_stop', 'market_0', 'market_1', 'IsST'} - - # First pass: collect _ntrl columns (these come first in qlib order) - ntrl_alpha158_cols = [] - ntrl_market_ext_cols = [] - raw_alpha158_cols = [] - raw_market_ext_cols = [] - flag_cols = [] - indus_idx_col = None - - for col in df.columns: - if col == 'indus_idx': - indus_idx_col = col - elif col in feature_flag_cols: - flag_cols.append(col) - elif col.endswith('_ntrl'): - base_name = col[:-5] # Remove _ntrl suffix (5 characters) - if base_name in alpha158_base: - ntrl_alpha158_cols.append(col) - elif base_name in market_ext_all: - ntrl_market_ext_cols.append(col) - elif col in alpha158_base: - raw_alpha158_cols.append(col) - elif col in market_ext_all: - raw_market_ext_cols.append(col) - elif col in INDUSTRY_FLAG_COLS: - columns_with_group.append(('indus_flag', col)) - elif col in {'ST_S', 'ST_Y', 'ST_T', 'ST_L', 'ST_Z', 'ST_X'}: - columns_with_group.append(('st_flag', col)) - else: - # Unknown column - print warning - print(f" Warning: Unknown column '{col}', assigning to 'other' group") - columns_with_group.append(('other', col)) - - # Build columns in qlib order: [_ntrl] + [raw] for each feature group - # Feature group: alpha158_ntrl + alpha158 - for col in sorted(ntrl_alpha158_cols, key=lambda x: ALPHA158_COLS.index(x.replace('_ntrl', '')) if x.replace('_ntrl', '') in ALPHA158_COLS else 999): - columns_with_group.append(('feature', col)) - for col in sorted(raw_alpha158_cols, key=lambda x: ALPHA158_COLS.index(x) if x in ALPHA158_COLS else 999): - columns_with_group.append(('feature', col)) - - # Feature_ext group: market_ext_ntrl + market_ext - for col in ntrl_market_ext_cols: - columns_with_group.append(('feature_ext', col)) - for col in raw_market_ext_cols: - columns_with_group.append(('feature_ext', col)) - - # Feature_flag group - for col in flag_cols: - 
columns_with_group.append(('feature_flag', col)) - - # Indus_idx - if indus_idx_col: - columns_with_group.append(('indus_idx', indus_idx_col)) - - # Create MultiIndex columns - multi_cols = pd.MultiIndex.from_tuples(columns_with_group) - df.columns = multi_cols - - return df - - -def main(): - print("=" * 80) - print("Dumping Polars Dataset") - print("=" * 80) - print(f"Date range: {START_DATE} to {END_DATE}") - print(f"Output directory: {OUTPUT_DIR}") - print() - - # Step 1: Load all data - print("Step 1: Loading data from parquet...") - df_alpha, df_kline, df_flag, df_industry = load_all_data(START_DATE, END_DATE) - print(f" Alpha158 shape: {df_alpha.shape}") - print(f" Kline (market_ext) shape: {df_kline.shape}") - print(f" Flags shape: {df_flag.shape}") - print(f" Industry shape: {df_industry.shape}") - - # Step 2: Merge data sources - print("\nStep 2: Merging data sources...") - df_merged = merge_data_sources(df_alpha, df_kline, df_flag, df_industry) - print(f" Merged shape (after csiallx filter): {df_merged.shape}") - - # Step 3: Save raw data (before processors) - print("\nStep 3: Saving raw data (before processors)...") - - # Keep columns that match qlib's raw output format - # Include datetime and instrument for MultiIndex conversion - raw_columns = ( - ['datetime', 'instrument'] + # Index columns - ALPHA158_COLS + # feature group - ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] + # feature_ext base - ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', # market_flag from kline - 'open_limit', 'close_limit', 'low_limit', - 'open_stop', 'close_stop', 'high_stop'] + - INDUSTRY_FLAG_COLS + # indus_flag - (['ST_S', 'ST_Y'] if 'ST_S' in df_merged.columns else []) # st_flag (if available) - ) - - # Filter to available columns - available_raw_cols = [c for c in raw_columns if c in df_merged.columns] - print(f" Selecting {len(available_raw_cols)} columns for raw data...") - df_raw_polars = df_merged.select(available_raw_cols) - - # Convert to pandas 
with MultiIndex - df_raw_pd = convert_to_multiindex_df(df_raw_polars) - - raw_output_path = OUTPUT_DIR / f"raw_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl" - with open(raw_output_path, "wb") as f: - pkl.dump(df_raw_pd, f) - print(f" Saved raw data to: {raw_output_path}") - print(f" Raw data shape: {df_raw_pd.shape}") - print(f" Column groups: {df_raw_pd.columns.get_level_values(0).unique().tolist()}") - - # Step 4: Apply processor pipeline - print("\nStep 4: Applying processor pipeline...") - df_processed = apply_processor_pipeline(df_merged) - - # Step 5: Save processed data - print("\nStep 5: Saving processed data (after processors)...") - - # Convert to pandas with MultiIndex - df_processed_pd = convert_to_multiindex_df(df_processed) - - processed_output_path = OUTPUT_DIR / f"processed_data_{START_DATE.replace('-', '')}_{END_DATE.replace('-', '')}.pkl" - with open(processed_output_path, "wb") as f: - pkl.dump(df_processed_pd, f) - print(f" Saved processed data to: {processed_output_path}") - print(f" Processed data shape: {df_processed_pd.shape}") - print(f" Column groups: {df_processed_pd.columns.get_level_values(0).unique().tolist()}") - - # Count columns per group - print("\n Column counts by group:") - for grp in df_processed_pd.columns.get_level_values(0).unique().tolist(): - count = (df_processed_pd.columns.get_level_values(0) == grp).sum() - print(f" {grp}: {count} columns") - - # Step 6: Verify column counts - print("\n" + "=" * 80) - print("Verification") - print("=" * 80) - - feature_flag_cols = [c[1] for c in df_processed_pd.columns if c[0] == 'feature_flag'] - has_market_0 = 'market_0' in feature_flag_cols - has_market_1 = 'market_1' in feature_flag_cols - - print(f" feature_flag columns: {feature_flag_cols}") - print(f" Has market_0: {has_market_0}") - print(f" Has market_1: {has_market_1}") - - if has_market_0 and has_market_1: - print("\n SUCCESS: market_0 and market_1 columns are present!") - else: - print("\n WARNING: 
market_0 or market_1 columns are missing!") - - print("\n" + "=" * 80) - print("Dataset dump complete!") - print("=" * 80) - - -if __name__ == "__main__": - main() diff --git a/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py b/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py index d8316c3..f7bf05e 100644 --- a/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py +++ b/stock_1d/d033/alpha158_beta/scripts/generate_beta_embedding.py @@ -2,51 +2,44 @@ """ Standalone script to generate embeddings from alpha158_0_7_beta factors using the VAE encoder. -This script implements the full feature transformation pipeline: -1. Load all 6 data sources from Parquet: - - Alpha158: stg_1day_wind_alpha158_0_7_beta_1D/ (158 features) - - Market Ext: stg_1day_wind_kline_adjusted_1D/ (Turnover, FreeTurnover, MarketValue -> log_size) - - Con Rating: stg_1day_gds_con_rating_1D/ (con_rating_strength) - - Market Flag: stg_1day_wind_kline_adjusted_1D/ (IsZt, IsDt, IsN, IsXD, IsXR, IsDR) - - Market Flag: stg_1day_wind_market_flag_1D/ (open_limit, close_limit, low_limit, open_stop, close_stop, high_stop) - - Industry Flag: stg_1day_gds_indus_flag_cc1_1D/ (29 one-hot industries) - -2. Apply 9 processors in sequence: - - Diff: Adds 4 diff features to feature_ext - - FlagMarketInjector: Adds market_0, market_1 to feature_flag - - FlagSTInjector: Adds IsST (placeholder, all zeros) - - ColumnRemover: Removes log_size_diff, IsN, IsZt, IsDt - - FlagToOnehot: Converts 29 industry flags to single indus_idx - - IndusNtrlInjector (x2): Industry neutralization for feature and feature_ext - - RobustZScoreNorm: Robust z-score normalization using pre-fitted qlib parameters - - Fillna: Fill NaN values with 0 - -3. Encode with VAE: - - Load VAE model from alpha/data_ops/tasks/dwm_feature_vae/model/ - - Run inference to generate 32-dim embeddings - - Save to parquet - -Note: Feature order is critical - alpha158 columns are in explicit order matching the VAE training. 
+This script uses the new modular processors package for data loading and feature +transformation, and focuses on VAE encoding and output generation. + +Workflow: +1. Load data using FeaturePipeline (loads from 6 parquet sources) +2. Transform features using the modular processor pipeline +3. Encode with VAE to generate 32-dim embeddings +4. Save embeddings to parquet + +Note: The data loading and transformation logic is now in the processors module: + stock_1d/d033/alpha158_beta/src/processors/ """ import os import sys import pickle as pkl -import io import numpy as np import polars as pl import torch import torch.nn as nn from pathlib import Path -from datetime import datetime -from typing import Optional, List, Tuple, Dict, Set - -# Constants -PARQUET_ALPHA158_BETA_PATH = "/data/parquet/dataset/stg_1day_wind_alpha158_0_7_beta_1D/" -PARQUET_KLINE_PATH = "/data/parquet/dataset/stg_1day_wind_kline_adjusted_1D/" -PARQUET_MARKET_FLAG_PATH = "/data/parquet/dataset/stg_1day_wind_market_flag_1D/" -PARQUET_INDUSTRY_FLAG_PATH = "/data/parquet/dataset/stg_1day_gds_indus_flag_cc1_1D/" -PARQUET_CON_RATING_PATH = "/data/parquet/dataset/stg_1day_gds_con_rating_1D/" +from typing import Optional, List, Tuple + +# Import from the new processors module +sys.path.insert(0, str(Path(__file__).parent.parent / 'src')) +from processors import ( + FeaturePipeline, + FeatureGroups, + filter_stock_universe, + ALPHA158_COLS, + MARKET_EXT_BASE_COLS, + MARKET_FLAG_COLS, + COLUMNS_TO_REMOVE, + VAE_INPUT_DIM, + DEFAULT_ROBUST_ZSCORE_PARAMS_PATH, +) + +# Constants for VAE and output VAE_MODEL_PATH = "/home/guofu/Workspaces/alpha/data_ops/tasks/dwm_feature_vae/model/csiallx_feature2_ntrla_flag_pnlnorm_vae4_dim32a_beta0001/module.pt" OUTPUT_DIR = "../data" @@ -54,752 +47,239 @@ OUTPUT_DIR = "../data" DEFAULT_START_DATE = "2019-01-01" DEFAULT_END_DATE = "2025-12-31" -# Expected VAE input dimension -# Based on original pipeline: -# - feature: 158 alpha158 + 158 alpha158_ntrl = 316 -# - feature_ext: 
7 market_ext + 7 market_ext_ntrl = 14 -# - feature_flag: 11 columns (after ColumnRemover, FlagMarketInjector; excluding IsST) -# Total: 316 + 14 + 11 = 341 -# -# NOTE: The VAE model encode() function takes feature + feature_ext + feature_flag groups -# (indus_idx is NOT included in VAE input) -VAE_INPUT_DIM = 341 - -# Industry flag columns (29 one-hot columns) -INDUSTRY_FLAG_COLS = [ - 'gds_CC10', 'gds_CC11', 'gds_CC12', 'gds_CC20', 'gds_CC21', 'gds_CC22', - 'gds_CC23', 'gds_CC24', 'gds_CC25', 'gds_CC26', 'gds_CC27', 'gds_CC28', - 'gds_CC30', 'gds_CC31', 'gds_CC32', 'gds_CC33', 'gds_CC34', 'gds_CC35', - 'gds_CC36', 'gds_CC37', 'gds_CC40', 'gds_CC41', 'gds_CC42', 'gds_CC43', - 'gds_CC50', 'gds_CC60', 'gds_CC61', 'gds_CC62', 'gds_CC63', 'gds_CC70' -] - - -def filter_stock_universe(df: pl.DataFrame, instruments: str = 'csiallx') -> pl.DataFrame: - """ - Filter dataframe to csiallx stock universe (A-shares excluding STAR/BSE) using qshare spine functions. - This uses qshare's filter_instruments which loads the instrument list from: - /data/qlib/default/data_ops/target/instruments/csiallx.txt +def load_vae_model(model_path: str) -> nn.Module: + """ + Load the VAE model from file. 
Args: - df: Input DataFrame with datetime and instrument columns - instruments: Market name for spine creation (default: 'csiallx') + model_path: Path to the pickled VAE model Returns: - Filtered DataFrame with only instruments in the specified universe + Loaded VAE model in eval mode on CPU """ - from qshare.algo.polars.spine import filter_instruments - - # Use qshare's filter_instruments with csiallx market name - df = filter_instruments(df, instruments=instruments) - - return df - -# Alpha158 feature columns in EXPLICIT ORDER -# These are the 158 alpha158 features in the order they appear in the parquet file -# This order MUST match the order used when training the VAE model -ALPHA158_COLS = [ - 'KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2', - 'OPEN0', 'HIGH0', 'LOW0', 'VWAP0', - 'ROC5', 'ROC10', 'ROC20', 'ROC30', 'ROC60', - 'MA5', 'MA10', 'MA20', 'MA30', 'MA60', - 'STD5', 'STD10', 'STD20', 'STD30', 'STD60', - 'BETA5', 'BETA10', 'BETA20', 'BETA30', 'BETA60', - 'RSQR5', 'RSQR10', 'RSQR20', 'RSQR30', 'RSQR60', - 'RESI5', 'RESI10', 'RESI20', 'RESI30', 'RESI60', - 'MAX5', 'MAX10', 'MAX20', 'MAX30', 'MAX60', - 'MIN5', 'MIN10', 'MIN20', 'MIN30', 'MIN60', - 'QTLU5', 'QTLU10', 'QTLU20', 'QTLU30', 'QTLU60', - 'QTLD5', 'QTLD10', 'QTLD20', 'QTLD30', 'QTLD60', - 'RANK5', 'RANK10', 'RANK20', 'RANK30', 'RANK60', - 'RSV5', 'RSV10', 'RSV20', 'RSV30', 'RSV60', - 'IMAX5', 'IMAX10', 'IMAX20', 'IMAX30', 'IMAX60', - 'IMIN5', 'IMIN10', 'IMIN20', 'IMIN30', 'IMIN60', - 'IMXD5', 'IMXD10', 'IMXD20', 'IMXD30', 'IMXD60', - 'CORR5', 'CORR10', 'CORR20', 'CORR30', 'CORR60', - 'CORD5', 'CORD10', 'CORD20', 'CORD30', 'CORD60', - 'CNTP5', 'CNTP10', 'CNTP20', 'CNTP30', 'CNTP60', - 'CNTN5', 'CNTN10', 'CNTN20', 'CNTN30', 'CNTN60', - 'CNTD5', 'CNTD10', 'CNTD20', 'CNTD30', 'CNTD60', - 'SUMP5', 'SUMP10', 'SUMP20', 'SUMP30', 'SUMP60', - 'SUMN5', 'SUMN10', 'SUMN20', 'SUMN30', 'SUMN60', - 'SUMD5', 'SUMD10', 'SUMD20', 'SUMD30', 'SUMD60', - 'VMA5', 'VMA10', 'VMA20', 'VMA30', 
'VMA60', - 'VSTD5', 'VSTD10', 'VSTD20', 'VSTD30', 'VSTD60', - 'WVMA5', 'WVMA10', 'WVMA20', 'WVMA30', 'WVMA60', - 'VSUMP5', 'VSUMP10', 'VSUMP20', 'VSUMP30', 'VSUMP60', - 'VSUMN5', 'VSUMN10', 'VSUMN20', 'VSUMN30', 'VSUMN60', - 'VSUMD5', 'VSUMD10', 'VSUMD20', 'VSUMD30', 'VSUMD60' -] -# Verify we have 158 features -assert len(ALPHA158_COLS) == 158, f"Expected 158 alpha158 cols, got {len(ALPHA158_COLS)}" - -# Market extension columns - MUST match original qlib HANDLER_MARKET_EXT config -# Original config loads: -# 'Turnover as turnover', 'FreeTurnover as free_turnover', -# 'log(MarketValue) as log_size', 'con_rating_strength' -# -# We use lowercase names to match the original pipeline exactly. -# NOTE: con_rating_strength is not available in parquet, so we'll create it as zeros. -MARKET_EXT_RAW_COLS = ['Turnover', 'FreeTurnover', 'MarketValue'] # Raw columns from parquet -MARKET_EXT_COLS = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] # Final names - -# Market flag columns (before processors) -# According to HANDLER_MARKET_FLAG in qlib config: -# From stg_1day_wind_kline_adjusted: IsZt, IsDt, IsN, IsXD, IsXR, IsDR (boolean) -# From stg_1day_wind_market_flag: open_limit, close_limit, low_limit, open_stop, close_stop, high_stop (boolean) -MARKET_FLAG_COLS_KLINE = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR'] -MARKET_FLAG_COLS_MARKET = ['open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop'] - - -def get_date_partitions(start_date: str, end_date: str) -> List[str]: - """Generate a list of date partitions to load from Parquet.""" - start = datetime.strptime(start_date, "%Y-%m-%d") - end = datetime.strptime(end_date, "%Y-%m-%d") - - partitions = [] - current = start - while current <= end: - if current.weekday() < 5: - partitions.append(f"datetime={current.strftime('%Y%m%d')}") - current = datetime(current.year, current.month, current.day + 1) - - return partitions - - -def load_parquet_by_date_range( - base_path: str, - 
start_date: str, - end_date: str, - columns: Optional[List[str]] = None -) -> pl.DataFrame: - """Load parquet data filtered by date range.""" - start_int = int(start_date.replace("-", "")) - end_int = int(end_date.replace("-", "")) + print(f"Loading VAE model from {model_path}...") + + # Patch torch.load to use CPU + original_torch_load = torch.load + def cpu_torch_load(*args, **kwargs): + kwargs['map_location'] = 'cpu' + return original_torch_load(*args, **kwargs) + torch.load = cpu_torch_load try: - df = pl.scan_parquet(base_path) + with open(model_path, "rb") as fin: + model = pkl.load(fin) - # Filter by date range - df = df.filter( - (pl.col('datetime') >= start_int) & - (pl.col('datetime') <= end_int) - ) + model.eval() + print(f"Loaded VAE model: {model.__class__.__name__}") + print(f" Input size: {model.input_size}") + print(f" Hidden size: {model.hidden_size}") - # Select specific columns if provided - if columns: - available_cols = ['instrument', 'datetime'] + [c for c in columns if c not in ['instrument', 'datetime']] - df = df.select(available_cols) + return model - return df.collect() - except Exception as e: - print(f"Error loading from {base_path}: {e}") - return pl.DataFrame() + finally: + torch.load = original_torch_load -def load_all_data( - start_date: str, - end_date: str -) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, pl.DataFrame]: +def encode_with_vae(features: np.ndarray, model: nn.Module, batch_size: int = 5000) -> np.ndarray: """ - Load all data sources from Parquet. - - According to original HANDLER_MARKET_EXT and HANDLER_MARKET_FLAG configs: - - alpha158: 158 features - - market_ext: turnover, free_turnover, log_size (=log(MarketValue)), con_rating_strength - - market_flag: IsZt, IsDt, IsN, IsXD, IsXR, IsDR + open_limit, close_limit, low_limit, open_stop, close_stop, high_stop - - indus_flag: 29 industry flags + Encode features using the VAE model. 
- NOTE: con_rating_strength is not available in parquet, so we create it as zeros (placeholder). + Args: + features: Input features of shape (n_samples, VAE_INPUT_DIM) + model: VAE model with encode() method + batch_size: Batch size for inference Returns: - Tuple of (alpha158_df, market_ext_df, market_flag_df, industry_df) + Embeddings of shape (n_samples, 32) """ - print(f"Loading data from {start_date} to {end_date}...") - - # 1. Load Alpha158 beta factors (158 features) - print("Loading alpha158_0_7_beta factors...") - df_alpha = load_parquet_by_date_range(PARQUET_ALPHA158_BETA_PATH, start_date, end_date) - print(f" Alpha158 shape: {df_alpha.shape}") - - # 2. Load Kline data for market_ext columns - # Original config: 'Turnover as turnover', 'FreeTurnover as free_turnover', - # 'log(MarketValue) as log_size', 'con_rating_strength' - # We load raw columns and transform them - print("Loading kline data (market ext columns)...") - kline_cols = ['Turnover', 'FreeTurnover', 'MarketValue'] - df_kline = load_parquet_by_date_range(PARQUET_KLINE_PATH, start_date, end_date, kline_cols) - print(f" Kline (market ext raw) shape: {df_kline.shape}") - - # 3. 
Load con_rating_strength from parquet - print("Loading con_rating_strength from parquet...") - df_con_rating = load_parquet_by_date_range( - PARQUET_CON_RATING_PATH, start_date, end_date, ['con_rating_strength'] - ) - print(f" Con rating shape: {df_con_rating.shape}") - - # Transform market_ext columns to match original pipeline: - # - Turnover -> turnover (rename) - # - FreeTurnover -> free_turnover (rename) - # - MarketValue -> log_size = log(MarketValue) - # - con_rating_strength: loaded from parquet (will merge below) - print("Transforming market_ext columns...") - df_kline = df_kline.with_columns([ - pl.col('Turnover').alias('turnover'), - pl.col('FreeTurnover').alias('free_turnover'), - pl.col('MarketValue').log().alias('log_size'), - ]) - print(f" Kline (market ext transformed) shape: {df_kline.shape}") - - # Merge con_rating_strength into kline dataframe - df_kline = df_kline.join(df_con_rating, on=['instrument', 'datetime'], how='left') - # Fill NaN with 0 for instruments/dates without con_rating data - df_kline = df_kline.with_columns([ - pl.col('con_rating_strength').fill_null(0.0) - ]) - print(f" Kline (with con_rating) shape: {df_kline.shape}") - - # 4. Load Market Flag data from kline_adjusted (all 6 columns) - print("Loading market flags from kline_adjusted...") - kline_flag_cols = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR'] - df_kline_flag = load_parquet_by_date_range(PARQUET_KLINE_PATH, start_date, end_date, kline_flag_cols) - print(f" Kline flags shape: {df_kline_flag.shape}") - - # 5. Load Market Flag data from market_flag table (ALL 6 columns as per original config) - print("Loading market flags from market_flag table (6 cols)...") - market_flag_cols = ['open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop'] - df_market_flag = load_parquet_by_date_range(PARQUET_MARKET_FLAG_PATH, start_date, end_date, market_flag_cols) - print(f" Market flag shape: {df_market_flag.shape}") - - # 6. 
Load Industry flags - print("Loading industry flags...") - df_industry = load_parquet_by_date_range(PARQUET_INDUSTRY_FLAG_PATH, start_date, end_date, INDUSTRY_FLAG_COLS) - print(f" Industry shape: {df_industry.shape}") - - # Merge kline flag and market flag - df_flag = df_kline_flag.join(df_market_flag, on=['instrument', 'datetime'], how='inner') - print(f" Combined flags shape: {df_flag.shape}") - - return df_alpha, df_kline, df_flag, df_industry - - -def merge_data_sources( - df_alpha: pl.DataFrame, - df_kline: pl.DataFrame, - df_flag: pl.DataFrame, - df_industry: pl.DataFrame -) -> pl.DataFrame: - """Merge all data sources on instrument and datetime.""" - print("Merging data sources...") - - # Start with alpha158 - df = df_alpha - - # Merge kline data (market_ext with transformed columns) - # df_kline now has: turnover, free_turnover, log_size, con_rating_strength - df = df.join(df_kline, on=['instrument', 'datetime'], how='inner') - - # Merge flags (kline_flag + market_flag) - df = df.join(df_flag, on=['instrument', 'datetime'], how='inner') - - # Merge industry flags - df = df.join(df_industry, on=['instrument', 'datetime'], how='inner') - - print(f"Merged data shape (before filter): {df.shape}") - - # Apply stock universe filter to match csiallx universe - # This is CRITICAL for correct industry neutralization: - # - Must use the same stock universe as the original pipeline - # - Industry means are calculated per datetime across this universe - df = filter_stock_universe(df) - - print(f"Merged data shape (after csiallx filter): {df.shape}") - return df - + print(f"Encoding {features.shape[0]} samples with VAE...") -class DiffProcessor: - """ - Diff Processor: Calculate diff features for market_ext columns. - For each column in feature_ext, calculate diff with period=1 within each instrument group. 
- """ - def __init__(self, columns: List[str]): - self.columns = columns + device = torch.device('cpu') + model = model.to(device) + model.eval() - def process(self, df: pl.DataFrame) -> pl.DataFrame: - """Add diff features for specified columns.""" - print("Applying Diff processor...") + all_embeddings = [] - # Sort by instrument and datetime - df = df.sort(['instrument', 'datetime']) + with torch.no_grad(): + for i in range(0, len(features), batch_size): + batch = features[i:i + batch_size] + batch_tensor = torch.tensor(batch, dtype=torch.float32, device=device) - # Add diff for each column - for col in self.columns: - if col in df.columns: - diff_col = f"{col}_diff" - df = df.with_columns([ - pl.col(col) - .diff() - .over('instrument') - .alias(diff_col) - ]) + # Use model.encode() to get mu (the embedding) + mu, _ = model.encode(batch_tensor) - return df + # Convert to numpy + embeddings_np = mu.cpu().numpy() + all_embeddings.append(embeddings_np) + if (i // batch_size + 1) % 10 == 0: + print(f" Processed {min(i + batch_size, len(features))}/{len(features)} samples...") -class FlagMarketInjector: - """ - Flag Market Injector: Create market_0, market_1 columns based on instrument code. + embeddings = np.concatenate(all_embeddings, axis=0) + print(f"Generated embeddings shape: {embeddings.shape}") - Maps to Qlib's map_market_sec logic with vocab_size=2: - - market_0 (主板): SH60xxx, SZ00xxx - - market_1 (科创板/创业板): SH688xxx, SH689xxx, SZ300xxx, SZ301xxx + return embeddings - NOTE: vocab_size=2 (not 3!) - the original qlib pipeline does NOT include - 新三板/北交所 (NE4xxxx, NE8xxxx) in the market classification. 
- This uses the gds encoding where: - - 6xxxxx -> SH main board - - 0xxxxx, 3xxxxx -> SZ (main/ChiNext) - - 4xxxxx, 8xxxxx -> NE (新三板/北交所) - NOT included in vocab_size=2 +def prepare_vae_features( + feature_groups: FeatureGroups, + exclude_isst: bool = True +) -> Tuple[np.ndarray, List[str]]: """ - def process(self, df: pl.DataFrame) -> pl.DataFrame: - """Add market_0, market_1 columns.""" - print("Applying FlagMarketInjector (vocab_size=2)...") + Prepare features for VAE encoding from FeatureGroups. - # Convert instrument to string and pad to 6 digits - inst_str = pl.col('instrument').cast(pl.String).str.zfill(6) - - # Determine market type based on first digit - # vocab_size=2: only market_0 (主板) and market_1 (科创/创业) - is_sh_main = inst_str.str.starts_with('6') # SH600xxx, SH601xxx, etc. - is_sz_main = inst_str.str.starts_with('0') | inst_str.str.starts_with('00') # SZ000xxx - is_sh_star = inst_str.str.starts_with('688') | inst_str.str.starts_with('689') # SH688xxx, SH689xxx - is_sz_chi = inst_str.str.starts_with('300') | inst_str.str.starts_with('301') # SZ300xxx, SZ301xxx - - df = df.with_columns([ - # market_0 = 主板 (SH main + SZ main) - (is_sh_main | is_sz_main).cast(pl.Int8).alias('market_0'), - # market_1 = 科创板 + 创业板 (SH star + SZ ChiNext) - (is_sh_star | is_sz_chi).cast(pl.Int8).alias('market_1') - ]) + VAE input structure (341 features): + - feature group (316): 158 alpha158 + 158 alpha158_ntrl + - feature_ext group (14): 7 market_ext + 7 market_ext_ntrl + - feature_flag group (11): market flags (excluding IsST) - return df + NOTE: indus_idx is NOT included in VAE input. + Args: + feature_groups: Transformed FeatureGroups container + exclude_isst: Whether to exclude IsST from VAE input -class ColumnRemover: - """ - Column Remover: Drop specific columns. 
- Removes: log_size_diff (TotalValue_diff), IsN, IsZt, IsDt + Returns: + Tuple of (features numpy array, list of embedding column names) """ - def __init__(self, columns_to_remove: List[str]): - self.columns_to_remove = columns_to_remove - - def process(self, df: pl.DataFrame) -> pl.DataFrame: - """Remove specified columns.""" - print(f"Applying ColumnRemover (removing {len(self.columns_to_remove)} columns)...") - - # Only remove columns that exist - cols_to_drop = [c for c in self.columns_to_remove if c in df.columns] - if cols_to_drop: - df = df.drop(cols_to_drop) - - return df + print("Preparing features for VAE...") + # Merge all groups for final feature extraction + df = feature_groups.merge_for_processors() -class FlagToOnehot: - """ - Flag To Onehot: Convert 29 one-hot industry columns to single indus_idx. - For each row, find which industry column is True/1 and set indus_idx to that index. - """ - def __init__(self, industry_cols: List[str]): - self.industry_cols = industry_cols - - def process(self, df: pl.DataFrame) -> pl.DataFrame: - """Convert industry flags to single indus_idx column.""" - print("Applying FlagToOnehot (converting 29 industry flags to indus_idx)...") + # Build alpha158 feature columns + alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS] + alpha158_cols = ALPHA158_COLS.copy() - # Build a when/then chain to find the industry index - # Start with -1 (no industry) as default - indus_expr = pl.lit(-1) + # Build market_ext feature columns (with diff, minus removed columns) + market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS] + market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE] + market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff] + market_ext_cols = market_ext_with_diff.copy() - for idx, col in enumerate(self.industry_cols): - if col in df.columns: - indus_expr = pl.when(pl.col(col) == 1).then(idx).otherwise(indus_expr) + # VAE feature order: 
[alpha158_ntrl, alpha158, market_ext_ntrl, market_ext] + norm_feature_cols = ( + alpha158_ntrl_cols + alpha158_cols + + market_ext_ntrl_cols + market_ext_cols + ) - df = df.with_columns([indus_expr.alias('indus_idx')]) + # Market flag columns (excluding IsST if requested) + # After ColumnRemover removes IsN, IsZt, IsDt: + # - From kline_adjusted: IsXD, IsXR, IsDR (3 cols) + # - From market_flag: open_limit, close_limit, low_limit, open_stop, close_stop, high_stop (6 cols) + # - Added by FlagMarketInjector: market_0, market_1 (2 cols) + # - Added by FlagSTInjector: IsST (1 col, excluded from VAE) + # - Total: 3 + 6 + 2 = 11 flags (excluding IsST) + market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE] + market_flag_cols += ['market_0', 'market_1'] + if not exclude_isst: + market_flag_cols.append('IsST') + market_flag_cols = list(dict.fromkeys(market_flag_cols)) - # Drop the original one-hot columns - cols_to_drop = [c for c in self.industry_cols if c in df.columns] - if cols_to_drop: - df = df.drop(cols_to_drop) + # Combine all VAE input columns + vae_cols = norm_feature_cols + market_flag_cols - return df + print(f" norm_feature_cols: {len(norm_feature_cols)}") + print(f" market_flag_cols: {len(market_flag_cols)}") + print(f" Total VAE input columns: {len(vae_cols)}") + # Verify all columns exist + missing_cols = [c for c in vae_cols if c not in df.columns] + if missing_cols: + print(f"WARNING: Missing columns: {missing_cols}") + # Add missing columns as zeros + for col in missing_cols: + df = df.with_columns(pl.lit(0).alias(col)) -class IndusNtrlInjector: - """ - Industry Neutralization Injector: Industry neutralization for features. - For each feature, subtract the industry mean (grouped by indus_idx) from the feature value. - Creates new columns with "_ntrl" suffix while keeping original columns. 
+ # Select features and convert to numpy + features_df = df.select(vae_cols) + features = features_df.to_numpy().astype(np.float32) - IMPORTANT: Industry neutralization must be done PER DATETIME (cross-sectional), - not across the entire dataset. This matches qlib's cal_indus_ntrl behavior. - """ - def __init__(self, feature_cols: List[str], suffix: str = '_ntrl'): - self.feature_cols = feature_cols - self.suffix = suffix + # Handle any remaining NaN/Inf values + features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0) - def process(self, df: pl.DataFrame) -> pl.DataFrame: - """Apply industry neutralization to specified features.""" - print(f"Applying IndusNtrlInjector to {len(self.feature_cols)} features...") + print(f"Feature matrix shape: {features.shape}") - # Filter to only columns that exist - existing_cols = [c for c in self.feature_cols if c in df.columns] + # Verify dimensions + if features.shape[1] != VAE_INPUT_DIM: + print(f"WARNING: Expected {VAE_INPUT_DIM} features, got {features.shape[1]}") + diff = VAE_INPUT_DIM - features.shape[1] + if diff > 0: + print(f" Difference: {diff} columns missing") + else: + print(f" Difference: {-diff} extra columns") - for col in existing_cols: - ntrl_col = f"{col}{self.suffix}" - # Calculate industry mean PER DATETIME and subtract from feature - # This is the CORRECT cross-sectional neutralization - df = df.with_columns([ - (pl.col(col) - pl.col(col).mean().over(['datetime', 'indus_idx'])).alias(ntrl_col) - ]) + # Generate embedding column names + embedding_cols = [f"embedding_{i}" for i in range(32)] - return df + return features, embedding_cols -class RobustZScoreNorm: +def prepare_vae_features_from_df( + df: pl.DataFrame, + exclude_isst: bool = True +) -> Tuple[np.ndarray, List[str]]: """ - Robust Z-Score Normalization: Per datetime normalization. - (x - median) / (1.4826 * MAD) where MAD = median(|x - median|) - Clip outliers at [-3, 3]. 
- - Can use pre-fitted parameters from qlib's pickled processor: - # Load from qlib pickle - with open('proc_list.proc', 'rb') as f: - proc_list = pickle.load(f) - zscore_proc = proc_list[7] # RobustZScoreNorm is 8th processor - - # Create with pre-fitted parameters - normalizer = RobustZScoreNorm( - feature_cols=feature_cols, - use_qlib_params=True, - qlib_mean=zscore_proc.mean_train, - qlib_std=zscore_proc.std_train - ) - """ - def __init__(self, feature_cols: List[str], - clip_range: Tuple[float, float] = (-3, 3), - use_qlib_params: bool = False, - qlib_mean: Optional[np.ndarray] = None, - qlib_std: Optional[np.ndarray] = None): - self.feature_cols = feature_cols - self.clip_range = clip_range - self.use_qlib_params = use_qlib_params - self.mean_train = qlib_mean - self.std_train = qlib_std - - if use_qlib_params: - if qlib_mean is None or qlib_std is None: - raise ValueError("Must provide qlib_mean and qlib_std when use_qlib_params=True") - print(f"Using pre-fitted qlib parameters (mean shape: {qlib_mean.shape}, std shape: {qlib_std.shape})") - - def process(self, df: pl.DataFrame) -> pl.DataFrame: - """Apply robust z-score normalization.""" - print(f"Applying RobustZScoreNorm to {len(self.feature_cols)} features...") - - # Filter to only columns that exist - existing_cols = [c for c in self.feature_cols if c in df.columns] - - if self.use_qlib_params: - # Use pre-fitted parameters from qlib (fit once, apply to all dates) - for i, col in enumerate(existing_cols): - if i < len(self.mean_train): - mean_val = float(self.mean_train[i]) - std_val = float(self.std_train[i]) - - # Apply z-score normalization using pre-fitted params - df = df.with_columns([ - ((pl.col(col) - mean_val) / (std_val + 1e-8)) - .clip(self.clip_range[0], self.clip_range[1]) - .alias(col) - ]) - else: - # Compute per-datetime robust z-score (original behavior) - for col in existing_cols: - # First compute median per datetime as a new column - median_col = f"__median_{col}" - df = 
df.with_columns([ - pl.col(col).median().over('datetime').alias(median_col) - ]) - - # Then compute absolute deviation - abs_dev_col = f"__absdev_{col}" - df = df.with_columns([ - (pl.col(col) - pl.col(median_col)).abs().alias(abs_dev_col) - ]) - - # Compute MAD (median of absolute deviations) - mad_col = f"__mad_{col}" - df = df.with_columns([ - pl.col(abs_dev_col).median().over('datetime').alias(mad_col) - ]) - - # Compute robust z-score and clip - df = df.with_columns([ - ((pl.col(col) - pl.col(median_col)) / (1.4826 * pl.col(mad_col) + 1e-8)) - .clip(self.clip_range[0], self.clip_range[1]) - .alias(col) - ]) - - # Clean up temporary columns - df = df.drop([median_col, abs_dev_col, mad_col]) - - return df - - -class Fillna: - """ - Fill NaN: Fill all NaN values with 0 for numeric columns. - """ - def process(self, df: pl.DataFrame, feature_cols: List[str]) -> pl.DataFrame: - """Fill NaN values with 0 for specified columns.""" - print("Applying Fillna processor...") + Prepare features for VAE encoding from a merged DataFrame. - # Filter to only columns that exist and are numeric (not boolean) - existing_cols = [c for c in feature_cols if c in df.columns] - - for col in existing_cols: - # Check column dtype - dtype = df[col].dtype - if dtype in [pl.Float32, pl.Float64, pl.Int32, pl.Int64, pl.UInt32, pl.UInt64]: - df = df.with_columns([pl.col(col).fill_null(0.0).fill_nan(0.0)]) - - return df + VAE input structure (341 features): + - feature group (316): 158 alpha158 + 158 alpha158_ntrl + - feature_ext group (14): 7 market_ext + 7 market_ext_ntrl + - feature_flag group (11): market flags (excluding IsST) + NOTE: indus_idx is NOT included in VAE input. -def apply_feature_pipeline(df: pl.DataFrame) -> Tuple[pl.DataFrame, List[str]]: - """ - Apply the full feature transformation pipeline. 
+ Args: + df: Transformed merged DataFrame + exclude_isst: Whether to exclude IsST from VAE input Returns: - Tuple of (processed DataFrame, list of final feature columns) + Tuple of (features numpy array, list of embedding column names) """ - print("=" * 60) - print("Starting feature transformation pipeline") - print("=" * 60) + print("Preparing features for VAE...") - # Use EXPLICIT alpha158 column order (158 features) - # This order MUST match what the VAE was trained with + # Build alpha158 feature columns + alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS] alpha158_cols = ALPHA158_COLS.copy() - # market_ext: 4 features - MUST match original HANDLER_MARKET_EXT config - # Original: 'Turnover as turnover', 'FreeTurnover as free_turnover', - # 'log(MarketValue) as log_size', 'con_rating_strength' - # We already transformed these in load_all_data(), so use lowercase names - market_ext_base = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] - - # market_flag: ALL 12 columns as per original HANDLER_MARKET_FLAG config - # From kline_adjusted: IsZt, IsDt, IsN, IsXD, IsXR, IsDR (6 cols) - # From market_flag: open_limit, close_limit, low_limit, open_stop, close_stop, high_stop (6 cols) - market_flag_cols = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', - 'open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop'] - - print(f"Initial column counts:") - print(f" Alpha158 features: {len(alpha158_cols)}") - print(f" Market ext base: {len(market_ext_base)}") - print(f" Market flag: {len(market_flag_cols)}") - print(f" Industry flags: {len(INDUSTRY_FLAG_COLS)}") - - # Step 1: Diff Processor - adds diff features for market_ext - diff_processor = DiffProcessor(market_ext_base) - df = diff_processor.process(df) - - # After Diff: market_ext becomes 8 columns (4 base + 4 diff) - market_ext_cols = market_ext_base + [f"{c}_diff" for c in market_ext_base] - - # Step 2: FlagMarketInjector - adds market_0, market_1 (vocab_size=2) - 
flag_injector = FlagMarketInjector() - df = flag_injector.process(df) - - # After FlagMarketInjector: market_flag = 12 + 2 = 14 columns - market_flag_with_market = market_flag_cols + ['market_0', 'market_1'] - - # Step 3: FlagSTInjector - create IsST from ST flags - # Note: ST flags (ST_Y, ST_S) may not be available in parquet data. - # If available, IsST = ST_S | ST_Y; otherwise create placeholder (all zeros). - # This maintains compatibility with the VAE's expected input dimension. - print("Applying FlagSTInjector (creating IsST)...") - # Check if ST flags are available - if 'ST_S' in df.columns or 'st_flag::ST_S' in df.columns: - # Create IsST from actual ST flags - df = df.with_columns([ - ((pl.col('ST_S').cast(pl.Boolean, strict=False) | - pl.col('ST_Y').cast(pl.Boolean, strict=False)) - .cast(pl.Int8).alias('IsST')) - ]) - else: - # Create placeholder (all zeros) if ST flags not available - df = df.with_columns([ - pl.lit(0).cast(pl.Int8).alias('IsST') - ]) - market_flag_with_st = market_flag_with_market + ['IsST'] - - # Step 4: ColumnRemover - remove specific columns - # Qlib ColumnRemover removes: ['log_size_diff', 'IsN', 'IsZt', 'IsDt'] - columns_to_remove = ['log_size_diff', 'IsN', 'IsZt', 'IsDt'] - remover = ColumnRemover(columns_to_remove) - df = remover.process(df) - - # Update column lists after removal - market_ext_cols = [c for c in market_ext_cols if c not in columns_to_remove] - market_flag_with_st = [c for c in market_flag_with_st if c not in columns_to_remove] - - # Step 5: FlagToOnehot - convert 29 industry flags to indus_idx - flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS) - df = flag_to_onehot.process(df) - - print(f"After FlagToOnehot: industry flags -> indus_idx") - - # Step 6 & 7: IndusNtrlInjector - industry neutralization for alpha158 and market_ext - indus_ntrl_alpha = IndusNtrlInjector(alpha158_cols, suffix='_ntrl') - df = indus_ntrl_alpha.process(df) - - indus_ntrl_ext = IndusNtrlInjector(market_ext_cols, suffix='_ntrl') - df = 
indus_ntrl_ext.process(df) - - # After IndusNtrlInjector: each feature gets a _ntrl version - # IMPORTANT: qlib's IndusNtrlInjector with keep_origin=True produces columns in order - # [all _ntrl] + [all raw] for EACH feature group, NOT [all raw] + [all _ntrl] - # This is critical for matching the VAE training feature order! - alpha158_ntrl_cols = [f"{c}_ntrl" for c in alpha158_cols] - market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_cols] - - # Step 8: RobustZScoreNorm - robust z-score normalization - # Qlib applies RobustZScoreNorm ONLY to ['feature', 'feature_ext'] groups - # NOT to feature_flag columns (binary flags should not be normalized) - # NOT to indus_idx (single column industry index) - # - # Feature order MUST match what the VAE was trained with: - # [alpha158_ntrl (158), alpha158 (158), market_ext_ntrl (7), market_ext (7)] = 330 features - # This order comes from qlib's IndusNtrlInjector which outputs [ntrl] + [raw] for each group - norm_feature_cols = alpha158_ntrl_cols + alpha158_cols + market_ext_ntrl_cols + market_ext_cols - - print(f"Applying RobustZScoreNorm to {len(norm_feature_cols)} features...") - print(f" (Excluding {len(market_flag_with_st)} market flags and indus_idx)") - - # Load pre-fitted qlib parameters for consistent normalization - qlib_params = load_qlib_processor_params() - - # Verify parameter shape matches expected features - expected_features = len(norm_feature_cols) - if qlib_params['mean_train'].shape[0] != expected_features: - print(f"WARNING: Feature count mismatch! 
Expected {expected_features}, " - f"but qlib params have {qlib_params['mean_train'].shape[0]}") - print(f" This means the feature order/columns may not match what the VAE was trained with.") - - robust_norm = RobustZScoreNorm( - norm_feature_cols, - clip_range=(-3, 3), - use_qlib_params=True, - qlib_mean=qlib_params['mean_train'], - qlib_std=qlib_params['std_train'] - ) - df = robust_norm.process(df) - - # Step 9: Fillna - fill NaN with 0 for ALL feature columns - # This includes normalized features, market flags, and indus_idx - # - # IMPORTANT: IsST is a placeholder (all zeros) and should NOT be included in VAE input. - # The VAE was trained with 11 market flags (excluding IsST). - # - # Define final feature list first - final_feature_cols = norm_feature_cols + market_flag_with_st + ['indus_idx'] - - fillna = Fillna() - df = fillna.process(df, final_feature_cols) - - # Final feature list breakdown for VAE input: - # The VAE model takes feature, feature_ext, feature_flag groups (indus_idx is separate) - # After ColumnRemover removes IsN, IsZt, IsDt: - # - From kline_adjusted: IsXD, IsXR, IsDR (3 cols) - # - From market_flag: open_limit, close_limit, low_limit, open_stop, close_stop, high_stop (6 cols) - # - Added by FlagMarketInjector: market_0, market_1 (2 cols) - # - Added by FlagSTInjector: IsST (1 col, placeholder if ST flags not available) - # - Total market flags: 3 + 6 + 2 + 1 = 12 (IsST excluded from VAE input) - # - # Total features: - # - norm_feature_cols: 158 + 158 + 7 + 7 = 330 - # - market_flag_with_st: 12 (including IsST) - # - indus_idx: 1 - # - Total: 330 + 12 + 1 = 343 features - # - # VAE input dimension (feature + feature_ext + feature_flag only, no indus_idx): - # - 316 (alpha158 + ntrl) + 14 (market_ext + ntrl) + 11 (flags, excluding IsST) = 341 - - # Exclude IsST from VAE input features (it's a placeholder) - market_flag_for_vae = [c for c in market_flag_with_st if c != 'IsST'] - - print("=" * 60) - print(f"Pipeline complete. 
Final feature count: {len(final_feature_cols)}") - print(f"Expected VAE input dim: {VAE_INPUT_DIM}") - print(f" norm_feature_cols: {len(norm_feature_cols)}") - print(f" market_flag_for_vae (excluding IsST): {len(market_flag_for_vae)}") - print(f" indus_idx: 1") - print("=" * 60) - - # Verify we have the expected number of features - vae_feature_count = len(norm_feature_cols) + len(market_flag_for_vae) - if vae_feature_count != VAE_INPUT_DIM: - print(f"WARNING: Feature count mismatch! Expected {VAE_INPUT_DIM}, got {vae_feature_count}") - print(f"Difference: {vae_feature_count - VAE_INPUT_DIM} columns") - print(f"Market flag columns for VAE ({len(market_flag_for_vae)}): {market_flag_for_vae}") - else: - print(f"✓ Feature count matches VAE input dimension!") - - # Return additional lists needed for VAE feature preparation - return df, final_feature_cols, norm_feature_cols, market_flag_for_vae - + # Build market_ext feature columns (with diff, minus removed columns) + market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS] + market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE] + market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff] + market_ext_cols = market_ext_with_diff.copy() -def prepare_vae_features(df: pl.DataFrame, feature_cols: List[str], - norm_feature_cols: List[str], - market_flag_for_vae: List[str]) -> np.ndarray: - """ - Prepare features for VAE encoding. - Ensure we have exactly VAE_INPUT_DIM features in the correct order. - - VAE input structure (341 features): - - feature group (316): 158 alpha158 + 158 alpha158_ntrl - - feature_ext group (14): 7 market_ext + 7 market_ext_ntrl - - feature_flag group (11): market flags (excluding IsST which is a placeholder) - - NOTE: indus_idx is NOT included in VAE input (it's used separately by the model). 
+ # VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext] + norm_feature_cols = ( + alpha158_ntrl_cols + alpha158_cols + + market_ext_ntrl_cols + market_ext_cols + ) - Args: - df: Processed DataFrame - feature_cols: All feature columns (including indus_idx and IsST) - norm_feature_cols: Normalized feature columns (330 features) - market_flag_for_vae: Market flag columns for VAE (11 features, excluding IsST) - """ - print("Preparing features for VAE...") + # Market flag columns (excluding IsST if requested) + market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE] + market_flag_cols += ['market_0', 'market_1'] + if not exclude_isst: + market_flag_cols.append('IsST') + market_flag_cols = list(dict.fromkeys(market_flag_cols)) - # Construct VAE input columns explicitly in correct order: - # [norm_feature_cols (330), market_flag_for_vae (11)] = 341 total - vae_cols = norm_feature_cols + market_flag_for_vae + # Combine all VAE input columns + vae_cols = norm_feature_cols + market_flag_cols print(f" norm_feature_cols: {len(norm_feature_cols)}") - print(f" market_flag_for_vae: {len(market_flag_for_vae)}") + print(f" market_flag_cols: {len(market_flag_cols)}") print(f" Total VAE input columns: {len(vae_cols)}") # Verify all columns exist missing_cols = [c for c in vae_cols if c not in df.columns] if missing_cols: print(f"WARNING: Missing columns: {missing_cols}") + # Add missing columns as zeros + for col in missing_cols: + df = df.with_columns(pl.lit(0).alias(col)) - # Select features + # Select features and convert to numpy features_df = df.select(vae_cols) - - # Convert to numpy features = features_df.to_numpy().astype(np.float32) # Handle any remaining NaN/Inf values @@ -810,89 +290,37 @@ def prepare_vae_features(df: pl.DataFrame, feature_cols: List[str], # Verify dimensions if features.shape[1] != VAE_INPUT_DIM: print(f"WARNING: Expected {VAE_INPUT_DIM} features, got {features.shape[1]}") - - if features.shape[1] < 
VAE_INPUT_DIM: - # Pad with zeros - padding = np.zeros((features.shape[0], VAE_INPUT_DIM - features.shape[1]), dtype=np.float32) - features = np.concatenate([features, padding], axis=1) - print(f"Padded to shape: {features.shape}") + diff = VAE_INPUT_DIM - features.shape[1] + if diff > 0: + print(f" Difference: {diff} columns missing") else: - # Truncate - features = features[:, :VAE_INPUT_DIM] - print(f"Truncated to shape: {features.shape}") - - return features - - -def load_vae_model(model_path: str) -> nn.Module: - """ - Load the VAE model from file. - """ - print(f"Loading VAE model from {model_path}...") - - # Patch torch.load to use CPU - original_torch_load = torch.load - def cpu_torch_load(*args, **kwargs): - kwargs['map_location'] = 'cpu' - return original_torch_load(*args, **kwargs) - torch.load = cpu_torch_load - - try: - with open(model_path, "rb") as fin: - model = pkl.load(fin) - - model.eval() - print(f"Loaded VAE model: {model.__class__.__name__}") - print(f" Input size: {model.input_size}") - print(f" Hidden size: {model.hidden_size}") - - return model - - finally: - torch.load = original_torch_load - + print(f" Difference: {-diff} extra columns") -def encode_with_vae(features: np.ndarray, model: nn.Module, batch_size: int = 5000) -> np.ndarray: - """ - Encode features using the VAE model. 
- """ - print(f"Encoding {features.shape[0]} samples with VAE...") - - device = torch.device('cpu') - model = model.to(device) - model.eval() - - all_embeddings = [] - - with torch.no_grad(): - for i in range(0, len(features), batch_size): - batch = features[i:i + batch_size] - batch_tensor = torch.tensor(batch, dtype=torch.float32, device=device) - - # Use model.encode() to get mu (the embedding) - mu, _ = model.encode(batch_tensor) - - # Convert to numpy - embeddings_np = mu.cpu().numpy() - all_embeddings.append(embeddings_np) - - if (i // batch_size + 1) % 10 == 0: - print(f" Processed {min(i + batch_size, len(features))}/{len(features)} samples...") - - embeddings = np.concatenate(all_embeddings, axis=0) - print(f"Generated embeddings shape: {embeddings.shape}") + # Generate embedding column names + embedding_cols = [f"embedding_{i}" for i in range(32)] - return embeddings + return features, embedding_cols def generate_embeddings( start_date: str = DEFAULT_START_DATE, end_date: str = DEFAULT_END_DATE, output_file: Optional[str] = None, - use_vae: bool = True + use_vae: bool = True, + robust_zscore_params_path: Optional[str] = None ) -> pl.DataFrame: """ Main function to generate embeddings from alpha158_0_7_beta factors. 
+ + Args: + start_date: Start date in YYYY-MM-DD format + end_date: End date in YYYY-MM-DD format + output_file: Optional output parquet file path + use_vae: Whether to use VAE encoding (or random embeddings) + robust_zscore_params_path: Optional path to robust zscore parameters + + Returns: + DataFrame with datetime, instrument, and embedding columns """ print("=" * 60) print(f"Generating Alpha158 0_7 Beta Embeddings") @@ -900,25 +328,23 @@ def generate_embeddings( print(f"Use VAE: {use_vae}") print("=" * 60) - # Load all data sources - df_alpha, df_kline, df_flag, df_industry = load_all_data(start_date, end_date) + # Initialize pipeline + pipeline = FeaturePipeline( + robust_zscore_params_path=robust_zscore_params_path + ) - # Merge data sources - df = merge_data_sources(df_alpha, df_kline, df_flag, df_industry) + # Load data + feature_groups = pipeline.load_data(start_date, end_date) - # Get datetime and instrument columns before processing - datetime_col = df['datetime'].clone() - instrument_col = df['instrument'].clone() + # Apply transformations - get merged DataFrame + df_transformed = pipeline.transform(feature_groups) - # Apply feature transformation pipeline - df_processed, feature_cols, norm_feature_cols, market_flag_for_vae = apply_feature_pipeline(df) + # Get datetime and instrument columns from merged DataFrame + datetime_col = df_transformed['datetime'].to_list() + instrument_col = df_transformed['instrument'].to_list() - # Prepare features for VAE - features = prepare_vae_features( - df_processed, feature_cols, - norm_feature_cols=norm_feature_cols, - market_flag_for_vae=market_flag_for_vae - ) + # Prepare VAE input features from DataFrame + features, embedding_cols = prepare_vae_features_from_df(df_transformed) # Encode with VAE if use_vae: @@ -938,11 +364,9 @@ def generate_embeddings( embeddings = np.random.randn(features.shape[0], 32).astype(np.float32) # Create output DataFrame - embedding_cols = [f"embedding_{i}" for i in 
range(embeddings.shape[1])] - result_data = { - 'datetime': datetime_col.to_list(), - 'instrument': instrument_col.to_list() + 'datetime': datetime_col, + 'instrument': instrument_col } for i, col_name in enumerate(embedding_cols): result_data[col_name] = embeddings[:, i].tolist() @@ -961,16 +385,17 @@ def generate_embeddings( return df_result -def load_qlib_processor_params(proc_path: str = None) -> Dict[str, np.ndarray]: +def load_qlib_processor_params( + proc_path: str = None +) -> dict: """ Load pre-fitted processor parameters from qlib's pickle file. - This demonstrates how to extract the fitted mean/std from qlib's - RobustZScoreNorm processor for use in standalone code. + This is kept for backwards compatibility and reference. + The new pipeline uses load_robust_zscore_params() instead. Args: - proc_path: Path to qlib's proc_list.proc file. - If None, uses the path from the original VAE model. + proc_path: Path to qlib's proc_list.proc file Returns: Dictionary with 'mean_train' and 'std_train' numpy arrays @@ -983,7 +408,7 @@ def load_qlib_processor_params(proc_path: str = None) -> Dict[str, np.ndarray]: with open(proc_path, "rb") as fin: proc_list = pkl.load(fin) - # Find RobustZScoreNorm processor (index 7 in the list) + # Find RobustZScoreNorm processor zscore_proc = None for proc in proc_list: if type(proc).__name__ == "RobustZScoreNorm": @@ -1008,50 +433,32 @@ def load_qlib_processor_params(proc_path: str = None) -> Dict[str, np.ndarray]: return params -# Example usage function -def generate_embeddings_with_qlib_params( - start_date: str = DEFAULT_START_DATE, - end_date: str = DEFAULT_END_DATE, - output_file: Optional[str] = None -) -> pl.DataFrame: - """ - Example of how to use pre-fitted qlib parameters for normalization. - - This is an alternative to generate_embeddings() that uses the exact - same normalization parameters as the original qlib pipeline. 
- """ - # Load the pre-fitted parameters - qlib_params = load_qlib_processor_params() - - # Load data (same as regular pipeline) - df_alpha, df_kline, df_industry = load_all_data(start_date, end_date) - df = merge_data_sources(df_alpha, df_kline, df_industry) - - datetime_col = df['datetime'].clone() - instrument_col = df['instrument'].clone() - - # Process through pipeline, but use qlib params for normalization - # (This would require modifying apply_feature_pipeline to accept params) - # For now, this is a demonstration of the pattern - - print("\nNote: To use qlib params, modify apply_feature_pipeline() to accept") - print("qlib_mean and qlib_std arguments and pass them to RobustZScoreNorm") - - return df - - if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description="Generate embeddings from alpha158_0_7_beta factors") - parser.add_argument("--start-date", type=str, default=DEFAULT_START_DATE, - help="Start date (YYYY-MM-DD)") - parser.add_argument("--end-date", type=str, default=DEFAULT_END_DATE, - help="End date (YYYY-MM-DD)") - parser.add_argument("--output", type=str, default=None, - help="Output parquet file path") - parser.add_argument("--no-vae", action="store_true", - help="Skip VAE encoding (use random embeddings for testing)") + parser = argparse.ArgumentParser( + description="Generate embeddings from alpha158_0_7_beta factors" + ) + parser.add_argument( + "--start-date", type=str, default=DEFAULT_START_DATE, + help="Start date (YYYY-MM-DD)" + ) + parser.add_argument( + "--end-date", type=str, default=DEFAULT_END_DATE, + help="End date (YYYY-MM-DD)" + ) + parser.add_argument( + "--output", type=str, default=None, + help="Output parquet file path" + ) + parser.add_argument( + "--no-vae", action="store_true", + help="Skip VAE encoding (use random embeddings for testing)" + ) + parser.add_argument( + "--robust-zscore-params", type=str, default=None, + help="Path to robust zscore parameters directory" + ) args = 
parser.parse_args() @@ -1059,7 +466,8 @@ if __name__ == "__main__": start_date=args.start_date, end_date=args.end_date, output_file=args.output, - use_vae=not args.no_vae + use_vae=not args.no_vae, + robust_zscore_params_path=args.robust_zscore_params ) print("\nDone!") diff --git a/stock_1d/d033/alpha158_beta/src/__init__.py b/stock_1d/d033/alpha158_beta/src/__init__.py new file mode 100644 index 0000000..e599eca --- /dev/null +++ b/stock_1d/d033/alpha158_beta/src/__init__.py @@ -0,0 +1,53 @@ +""" +Source package for alpha158_beta experiments. + +This package provides modules for data loading, feature transformation, +and model training for alpha158 beta factor experiments. +""" + +from .processors import ( + FeaturePipeline, + FeatureGroups, + DiffProcessor, + FlagMarketInjector, + FlagSTInjector, + ColumnRemover, + FlagToOnehot, + IndusNtrlInjector, + RobustZScoreNorm, + Fillna, + load_alpha158, + load_market_ext, + load_market_flags, + load_industry_flags, + load_all_data, + load_robust_zscore_params, + filter_stock_universe, +) + +__all__ = [ + # Main pipeline + 'FeaturePipeline', + 'FeatureGroups', + + # Processors + 'DiffProcessor', + 'FlagMarketInjector', + 'FlagSTInjector', + 'ColumnRemover', + 'FlagToOnehot', + 'IndusNtrlInjector', + 'RobustZScoreNorm', + 'Fillna', + + # Loaders + 'load_alpha158', + 'load_market_ext', + 'load_market_flags', + 'load_industry_flags', + 'load_all_data', + + # Utilities + 'load_robust_zscore_params', + 'filter_stock_universe', +] diff --git a/stock_1d/d033/alpha158_beta/src/processors/__init__.py b/stock_1d/d033/alpha158_beta/src/processors/__init__.py new file mode 100644 index 0000000..3d94ddc --- /dev/null +++ b/stock_1d/d033/alpha158_beta/src/processors/__init__.py @@ -0,0 +1,136 @@ +""" +Processors package for the alpha158_beta feature pipeline. + +This package provides a modular, polars-native data loading and transformation +pipeline for generating VAE input features from alpha158 beta factors. 
+ +Main components: +- FeatureGroups: Dataclass container for separate feature groups +- FeaturePipeline: Main orchestrator class for the full pipeline +- Processors: Individual transformation classes (DiffProcessor, IndusNtrlInjector, etc.) +- Loaders: Data loading functions for parquet sources + +Usage: + from processors import FeaturePipeline, FeatureGroups + + pipeline = FeaturePipeline() + feature_groups = pipeline.load_data(start_date, end_date) + transformed = pipeline.transform(feature_groups) + vae_input = pipeline.get_vae_input(transformed) +""" + +from .dataclass import FeatureGroups +from .loaders import ( + load_alpha158, + load_market_ext, + load_market_flags, + load_industry_flags, + load_all_data, + load_parquet_by_date_range, + get_date_partitions, + # Constants + PARQUET_ALPHA158_BETA_PATH, + PARQUET_KLINE_PATH, + PARQUET_MARKET_FLAG_PATH, + PARQUET_INDUSTRY_FLAG_PATH, + PARQUET_CON_RATING_PATH, + INDUSTRY_FLAG_COLS, + MARKET_FLAG_COLS_KLINE, + MARKET_FLAG_COLS_MARKET, + MARKET_EXT_RAW_COLS, +) +from .processors import ( + DiffProcessor, + FlagMarketInjector, + FlagSTInjector, + ColumnRemover, + FlagToOnehot, + IndusNtrlInjector, + RobustZScoreNorm, + Fillna, +) +from .pipeline import ( + FeaturePipeline, + load_robust_zscore_params, + filter_stock_universe, + # Constants + ALPHA158_COLS, + MARKET_EXT_BASE_COLS, + MARKET_FLAG_COLS, + COLUMNS_TO_REMOVE, + VAE_INPUT_DIM, + DEFAULT_ROBUST_ZSCORE_PARAMS_PATH, +) +from .exporters import ( + get_groups, + get_groups_from_fg, + pack_structs, + unpack_struct, + dump_to_parquet, + dump_to_pickle, + dump_to_numpy, + dump_features, + # Also import old names for backward compatibility + get_groups as select_feature_groups_from_df, + get_groups_from_fg as select_feature_groups, +) + +__all__ = [ + # Main classes + 'FeaturePipeline', + 'FeatureGroups', + + # Processors + 'DiffProcessor', + 'FlagMarketInjector', + 'FlagSTInjector', + 'ColumnRemover', + 'FlagToOnehot', + 'IndusNtrlInjector', + 'RobustZScoreNorm', 
+ 'Fillna', + + # Loaders + 'load_alpha158', + 'load_market_ext', + 'load_market_flags', + 'load_industry_flags', + 'load_all_data', + 'load_parquet_by_date_range', + 'get_date_partitions', + + # Utility functions + 'load_robust_zscore_params', + 'filter_stock_universe', + + # Exporter functions + 'get_groups', + 'get_groups_from_fg', + 'pack_structs', + 'unpack_struct', + 'dump_to_parquet', + 'dump_to_pickle', + 'dump_to_numpy', + 'dump_features', + + # Backward compatibility aliases + 'select_feature_groups_from_df', + 'select_feature_groups', + + # Constants + 'PARQUET_ALPHA158_BETA_PATH', + 'PARQUET_KLINE_PATH', + 'PARQUET_MARKET_FLAG_PATH', + 'PARQUET_INDUSTRY_FLAG_PATH', + 'PARQUET_CON_RATING_PATH', + 'INDUSTRY_FLAG_COLS', + 'MARKET_FLAG_COLS_KLINE', + 'MARKET_FLAG_COLS_MARKET', + 'MARKET_EXT_RAW_COLS', + 'ALPHA158_COLS', + 'MARKET_EXT_BASE_COLS', + 'MARKET_FLAG_COLS', + 'COLUMNS_TO_REMOVE', + 'VAE_INPUT_DIM', + 'DEFAULT_ROBUST_ZSCORE_PARAMS_PATH', +] diff --git a/stock_1d/d033/alpha158_beta/src/processors/dataclass.py b/stock_1d/d033/alpha158_beta/src/processors/dataclass.py new file mode 100644 index 0000000..043f45c --- /dev/null +++ b/stock_1d/d033/alpha158_beta/src/processors/dataclass.py @@ -0,0 +1,115 @@ +"""Dataclass definitions for the feature pipeline.""" + +from dataclasses import dataclass, field +from typing import Optional, List +import polars as pl + +# Import constants for column categorization +# Alpha158 base columns (158 features) +ALPHA158_BASE_COLS = [ + 'KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2', + 'OPEN0', 'HIGH0', 'LOW0', 'VWAP0', + 'ROC5', 'ROC10', 'ROC20', 'ROC30', 'ROC60', + 'MA5', 'MA10', 'MA20', 'MA30', 'MA60', + 'STD5', 'STD10', 'STD20', 'STD30', 'STD60', + 'BETA5', 'BETA10', 'BETA20', 'BETA30', 'BETA60', + 'RSQR5', 'RSQR10', 'RSQR20', 'RSQR30', 'RSQR60', + 'RESI5', 'RESI10', 'RESI20', 'RESI30', 'RESI60', + 'MAX5', 'MAX10', 'MAX20', 'MAX30', 'MAX60', + 'MIN5', 'MIN10', 'MIN20', 'MIN30', 'MIN60', + 
'QTLU5', 'QTLU10', 'QTLU20', 'QTLU30', 'QTLU60', + 'QTLD5', 'QTLD10', 'QTLD20', 'QTLD30', 'QTLD60', + 'RANK5', 'RANK10', 'RANK20', 'RANK30', 'RANK60', + 'RSV5', 'RSV10', 'RSV20', 'RSV30', 'RSV60', + 'IMAX5', 'IMAX10', 'IMAX20', 'IMAX30', 'IMAX60', + 'IMIN5', 'IMIN10', 'IMIN20', 'IMIN30', 'IMIN60', + 'IMXD5', 'IMXD10', 'IMXD20', 'IMXD30', 'IMXD60', + 'CORR5', 'CORR10', 'CORR20', 'CORR30', 'CORR60', + 'CORD5', 'CORD10', 'CORD20', 'CORD30', 'CORD60', + 'CNTP5', 'CNTP10', 'CNTP20', 'CNTP30', 'CNTP60', + 'CNTN5', 'CNTN10', 'CNTN20', 'CNTN30', 'CNTN60', + 'CNTD5', 'CNTD10', 'CNTD20', 'CNTD30', 'CNTD60', + 'SUMP5', 'SUMP10', 'SUMP20', 'SUMP30', 'SUMP60', + 'SUMN5', 'SUMN10', 'SUMN20', 'SUMN30', 'SUMN60', + 'SUMD5', 'SUMD10', 'SUMD20', 'SUMD30', 'SUMD60', + 'VMA5', 'VMA10', 'VMA20', 'VMA30', 'VMA60', + 'VSTD5', 'VSTD10', 'VSTD20', 'VSTD30', 'VSTD60', + 'WVMA5', 'WVMA10', 'WVMA20', 'WVMA30', 'WVMA60', + 'VSUMP5', 'VSUMP10', 'VSUMP20', 'VSUMP30', 'VSUMP60', + 'VSUMN5', 'VSUMN10', 'VSUMN20', 'VSUMN30', 'VSUMN60', + 'VSUMD5', 'VSUMD10', 'VSUMD20', 'VSUMD30', 'VSUMD60' +] + +# Market extension base columns +MARKET_EXT_BASE_COLS = ['turnover', 'free_turnover', 'log_size', 'con_rating_strength'] + +# Market flag base columns (before processors) +MARKET_FLAG_BASE_COLS = [ + 'IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR', + 'open_limit', 'close_limit', 'low_limit', 'open_stop', 'close_stop', 'high_stop' +] + + +@dataclass +class FeatureGroups: + """ + Container for separate feature groups in the pipeline. + + Keeps feature groups separate throughout the pipeline to avoid + unnecessary merging and complex column management. + + Attributes: + alpha158: 158 alpha158 features (+ _ntrl after industry neutralization) + market_ext: Market extension features (Turnover, FreeTurnover, MarketValue, etc.) + market_flag: Market flag features (IsZt, IsDt, IsXD, etc.) 
+ industry: 29 industry flags (converted to indus_idx by FlagToOnehot) + indus_idx: Single column industry index (after FlagToOnehot processing) + instruments: List of instrument IDs for metadata + dates: List of datetime values for metadata + """ + # Core feature groups + alpha158: pl.DataFrame # 158 alpha158 features (+ _ntrl after processing) + market_ext: pl.DataFrame # Market extension features (+ _ntrl after processing) + market_flag: pl.DataFrame # Market flags (12 cols initially, 11 after ColumnRemover) + industry: Optional[pl.DataFrame] = None # 29 industry flags -> indus_idx + + # Processed industry index (separate after FlagToOnehot) + indus_idx: Optional[pl.DataFrame] = None # Single column after FlagToOnehot + + # Metadata (extracted from dataframes for easy access) + instruments: List[str] = field(default_factory=list) + dates: List[int] = field(default_factory=list) + + def extract_metadata(self) -> None: + """Extract instrument and date lists from the alpha158 dataframe.""" + if self.alpha158 is not None and len(self.alpha158) > 0: + self.instruments = self.alpha158['instrument'].to_list() + self.dates = self.alpha158['datetime'].to_list() + + def merge_for_processors(self) -> pl.DataFrame: + """ + Merge all feature groups into a single DataFrame for processors. + + This is used by processors that need access to multiple groups + (e.g., IndusNtrlInjector needs industry index). 
+ + Returns: + Merged DataFrame with all features + """ + df = self.alpha158 + + # Merge market_ext if not already merged + if self.market_ext is not None and self.market_ext is not self.alpha158: + df = df.join(self.market_ext, on=['instrument', 'datetime'], how='left') + + # Merge market_flag if not already merged + if self.market_flag is not None and self.market_flag is not self.alpha158: + df = df.join(self.market_flag, on=['instrument', 'datetime'], how='left') + + # Merge industry/indus_idx if available + if self.indus_idx is not None: + df = df.join(self.indus_idx, on=['instrument', 'datetime'], how='left') + elif self.industry is not None: + df = df.join(self.industry, on=['instrument', 'datetime'], how='left') + + return df diff --git a/stock_1d/d033/alpha158_beta/src/processors/exporters.py b/stock_1d/d033/alpha158_beta/src/processors/exporters.py new file mode 100644 index 0000000..a656442 --- /dev/null +++ b/stock_1d/d033/alpha158_beta/src/processors/exporters.py @@ -0,0 +1,523 @@ +""" +Feature exporters for the alpha158_beta pipeline. + +This module provides functions to select and export feature groups from the +transformed pipeline output. It can be used by both dump_features.py and +generate_beta_embedding.py. 
+ +Feature groups: +- merged: All columns after transformation +- alpha158: Alpha158 features + _ntrl versions +- market_ext: Market extended features + _ntrl + _diff +- market_flag: Market flag columns +- vae_input: 341 features specifically curated for VAE training + +Struct mode: +- When pack_struct=True, each feature group is packed into a struct column: + - features_alpha158 (316 fields) + - features_market_ext (14 fields) + - features_market_flag (11 fields) +""" + +import os +import pickle +from pathlib import Path +from typing import Dict, Any, List, Optional, Union + +import numpy as np +import polars as pl + +from .pipeline import ( + ALPHA158_COLS, + MARKET_EXT_BASE_COLS, + MARKET_FLAG_COLS, + COLUMNS_TO_REMOVE, +) +from .dataclass import FeatureGroups + + +# ============================================================================= +# Helper functions for struct packing/unpacking +# ============================================================================= + +def pack_structs(df: pl.DataFrame) -> pl.DataFrame: + """ + Pack feature columns into struct columns based on feature groups. 
+ + Creates: + - features_alpha158: struct with 316 fields (158 + 158 _ntrl) + - features_market_ext: struct with 14 fields (7 + 7 _ntrl) + - features_market_flag: struct with 11 fields + + Args: + df: Input DataFrame with flat columns + + Returns: + DataFrame with struct columns: instrument, datetime, indus_idx, features_* + """ + # Define column groups + alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS] + alpha158_all_cols = alpha158_ntrl_cols + ALPHA158_COLS + + market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS] + market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE] + market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff] + market_ext_all_cols = market_ext_ntrl_cols + market_ext_with_diff + + market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE] + market_flag_cols += ['market_0', 'market_1', 'IsST'] + market_flag_cols = list(dict.fromkeys(market_flag_cols)) + + # Build result with struct columns + result_cols = ['instrument', 'datetime'] + + # Check if indus_idx exists + if 'indus_idx' in df.columns: + result_cols.append('indus_idx') + + # Pack alpha158 + alpha158_cols_in_df = [c for c in alpha158_all_cols if c in df.columns] + if alpha158_cols_in_df: + result_cols.append(pl.struct(alpha158_cols_in_df).alias('features_alpha158')) + + # Pack market_ext + ext_cols_in_df = [c for c in market_ext_all_cols if c in df.columns] + if ext_cols_in_df: + result_cols.append(pl.struct(ext_cols_in_df).alias('features_market_ext')) + + # Pack market_flag + flag_cols_in_df = [c for c in market_flag_cols if c in df.columns] + if flag_cols_in_df: + result_cols.append(pl.struct(flag_cols_in_df).alias('features_market_flag')) + + return df.select(result_cols) + + +def unpack_struct(df: pl.DataFrame, struct_name: str) -> pl.DataFrame: + """ + Unpack a struct column back into individual columns. 
+ + Args: + df: DataFrame containing struct column + struct_name: Name of the struct column to unpack + + Returns: + DataFrame with struct fields as individual columns + """ + if struct_name not in df.columns: + raise ValueError(f"Struct column '{struct_name}' not found in DataFrame") + + # Get the struct field names + struct_dtype = df.schema[struct_name] + if not isinstance(struct_dtype, pl.Struct): + raise ValueError(f"Column '{struct_name}' is not a struct type") + + field_names = struct_dtype.fields + + # Unpack using struct field access + unpacked_cols = [] + for field in field_names: + col_name = field.name + unpacked_cols.append( + pl.col(struct_name).struct.field(col_name).alias(col_name) + ) + + # Select original columns + unpacked columns + other_cols = [c for c in df.columns if c != struct_name] + return df.select(other_cols + unpacked_cols) + + +def dump_to_parquet(df: pl.DataFrame, path: str, verbose: bool = True) -> None: + """Save DataFrame to parquet file.""" + if verbose: + print(f"Saving to parquet: {path}") + df.write_parquet(path) + if verbose: + print(f" Shape: {df.shape}") + + +def dump_to_pickle(df: pl.DataFrame, path: str, verbose: bool = True) -> None: + """Save DataFrame to pickle file.""" + if verbose: + print(f"Saving to pickle: {path}") + with open(path, 'wb') as f: + pickle.dump(df, f) + if verbose: + print(f" Shape: {df.shape}") + + +def dump_to_numpy( + feature_groups: FeatureGroups, + path: str, + include_metadata: bool = True, + verbose: bool = True +) -> None: + """ + Save features to numpy format. 
+ + Saves: + - features.npy: The feature matrix + - metadata.pkl: Column names and metadata (if include_metadata=True) + """ + if verbose: + print(f"Saving to numpy: {path}") + + # Merge all groups for numpy array + df = feature_groups.merge_for_processors() + + # Get all feature columns (exclude instrument, datetime, indus_idx) + feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime', 'indus_idx']] + + # Extract features + features = df.select(feature_cols).to_numpy().astype(np.float32) + features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0) + + # Save features + base_path = Path(path) + if base_path.suffix == '': + base_path = Path(str(path) + '.npy') + + np.save(str(base_path), features) + + if verbose: + print(f" Features shape: {features.shape}") + + # Save metadata if requested + if include_metadata: + metadata_path = str(base_path).replace('.npy', '_metadata.pkl') + metadata = { + 'feature_cols': feature_cols, + 'instruments': df['instrument'].to_list(), + 'dates': df['datetime'].to_list(), + 'n_features': len(feature_cols), + 'n_samples': len(df), + } + with open(metadata_path, 'wb') as f: + pickle.dump(metadata, f) + if verbose: + print(f" Metadata saved to: {metadata_path}") + + +def get_groups( + df: pl.DataFrame, + groups_to_dump: List[str], + verbose: bool = True, + use_struct: bool = False, +) -> Dict[str, Any]: + """ + Select which feature groups to include in the output from a merged DataFrame. + + Args: + df: Transformed merged DataFrame (flat or with struct columns) + groups_to_dump: List of groups to include ('alpha158', 'market_ext', 'market_flag', 'merged', 'vae_input') + verbose: Whether to print progress + use_struct: If True, pack feature columns into a single 'features' struct column. + If df already has struct columns (pack_struct=True was used in pipeline), + this function will automatically handle them. 
+ + Returns: + Dictionary with selected DataFrames and metadata + """ + # Check if df already has struct columns from pipeline.pack_struct + has_struct_cols = any( + isinstance(df.schema.get(c), pl.Struct) + for c in df.columns if c not in ['instrument', 'datetime', 'indus_idx'] + ) + + result = {} + + if 'merged' in groups_to_dump: + if has_struct_cols: + # Already has struct columns from pipeline.pack_struct + result['merged'] = df + elif use_struct: + # Keep instrument, datetime, and pack rest into struct + feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime', 'indus_idx']] + result['merged'] = df.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')]) + else: + result['merged'] = df + if verbose: + print(f"Merged features: {result['merged'].shape}") + + if 'alpha158' in groups_to_dump: + # Check if struct column already exists from pipeline.pack_struct + if has_struct_cols and 'features_alpha158' in df.columns: + result['alpha158'] = df.select(['instrument', 'datetime', 'features_alpha158']) + else: + # Select alpha158 columns from merged DataFrame + alpha_cols = ['instrument', 'datetime'] + [c for c in ALPHA158_COLS if c in df.columns] + # Also include _ntrl versions + alpha_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS if f"{c}_ntrl" in df.columns] + alpha_cols += alpha_ntrl_cols + df_alpha = df.select(alpha_cols) + if use_struct: + feature_cols = [c for c in df_alpha.columns if c not in ['instrument', 'datetime']] + df_alpha = df_alpha.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')]) + result['alpha158'] = df_alpha + if verbose: + print(f"Alpha158 features: {result['alpha158'].shape}") + + if 'market_ext' in groups_to_dump: + # Check if struct column already exists from pipeline.pack_struct + if has_struct_cols and 'features_market_ext' in df.columns: + result['market_ext'] = df.select(['instrument', 'datetime', 'features_market_ext']) + else: + # Select market_ext columns from 
merged DataFrame + market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS] + market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE] + ext_cols = ['instrument', 'datetime'] + market_ext_with_diff + # Also include _ntrl versions + ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff if f"{c}_ntrl" in df.columns] + ext_cols += ext_ntrl_cols + df_ext = df.select(ext_cols) + if use_struct: + feature_cols = [c for c in df_ext.columns if c not in ['instrument', 'datetime']] + df_ext = df_ext.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')]) + result['market_ext'] = df_ext + if verbose: + print(f"Market ext features: {result['market_ext'].shape}") + + if 'market_flag' in groups_to_dump: + # Check if struct column already exists from pipeline.pack_struct + if has_struct_cols and 'features_market_flag' in df.columns: + result['market_flag'] = df.select(['instrument', 'datetime', 'features_market_flag']) + else: + # Select market_flag columns from merged DataFrame + flag_cols = ['instrument', 'datetime'] + flag_cols += [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE and c in df.columns] + flag_cols += ['market_0', 'market_1', 'IsST'] if all(c in df.columns for c in ['market_0', 'market_1', 'IsST']) else [] + flag_cols = list(dict.fromkeys(flag_cols)) # Remove duplicates + df_flag = df.select(flag_cols) + if use_struct: + feature_cols = [c for c in df_flag.columns if c not in ['instrument', 'datetime']] + df_flag = df_flag.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')]) + result['market_flag'] = df_flag + if verbose: + print(f"Market flag features: {result['market_flag'].shape}") + + if 'vae_input' in groups_to_dump: + # Get VAE input columns from merged DataFrame + # VAE input = 330 normalized features + 11 market flags = 341 features + # Note: indus_idx is NOT included in VAE input + + # Build alpha158 feature columns (158 original + 
158 _ntrl = 316) + alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS] + alpha158_cols = ALPHA158_COLS.copy() + + # Build market_ext feature columns (7 original + 7 _ntrl = 14) + market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS] + market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE] + market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff] + market_ext_cols = market_ext_with_diff.copy() + + # VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext] + norm_feature_cols = ( + alpha158_ntrl_cols + alpha158_cols + + market_ext_ntrl_cols + market_ext_cols + ) + + # Market flag columns (excluding IsST) + market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE] + market_flag_cols += ['market_0', 'market_1'] + market_flag_cols = list(dict.fromkeys(market_flag_cols)) + + # Combine all VAE input columns (341 total) + vae_cols = norm_feature_cols + market_flag_cols + + if use_struct: + # Pack all features into a single struct column + result['vae_input'] = df.select([ + 'instrument', + 'datetime', + pl.struct(vae_cols).alias('features') + ]) + if verbose: + print(f"VAE input features: {result['vae_input'].shape} (struct with {len(vae_cols)} fields)") + else: + # Always keep datetime and instrument as index columns + # Select features with index columns first + result['vae_input'] = df.select(['instrument', 'datetime'] + vae_cols) + if verbose: + print(f"VAE input features: {result['vae_input'].shape} (columns: {len(vae_cols)} + 2 index)") + + return result + + +def get_groups_from_fg( + feature_groups: FeatureGroups, + groups_to_dump: List[str], + verbose: bool = True, + use_struct: bool = False, +) -> Dict[str, Any]: + """ + Select which feature groups to include in the output. 
+ + Args: + feature_groups: Transformed FeatureGroups container + groups_to_dump: List of groups to include ('alpha158', 'market_ext', 'market_flag', 'merged', 'vae_input') + verbose: Whether to print progress + use_struct: If True, pack feature columns into a single 'features' struct column + + Returns: + Dictionary with selected DataFrames and metadata + """ + result = {} + + if 'merged' in groups_to_dump: + df = feature_groups.merge_for_processors() + if use_struct: + # Keep instrument, datetime, and pack rest into struct + feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime', 'indus_idx']] + df = df.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')]) + result['merged'] = df + if verbose: + print(f"Merged features: {result['merged'].shape}") + + if 'alpha158' in groups_to_dump: + df = feature_groups.alpha158 + if use_struct: + feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime']] + df = df.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')]) + result['alpha158'] = df + if verbose: + print(f"Alpha158 features: {df.shape}") + + if 'market_ext' in groups_to_dump: + df = feature_groups.market_ext + if use_struct: + feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime']] + df = df.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')]) + result['market_ext'] = df + if verbose: + print(f"Market ext features: {df.shape}") + + if 'market_flag' in groups_to_dump: + df = feature_groups.market_flag + if use_struct: + feature_cols = [c for c in df.columns if c not in ['instrument', 'datetime']] + df = df.select(['instrument', 'datetime', pl.struct(feature_cols).alias('features')]) + result['market_flag'] = df + if verbose: + print(f"Market flag features: {df.shape}") + + if 'vae_input' in groups_to_dump: + # Get VAE input columns from already-transformed feature_groups + # VAE input = 330 normalized features + 11 market flags = 341 
features + # Note: indus_idx is NOT included in VAE input + df = feature_groups.merge_for_processors() + + # Build alpha158 feature columns (158 original + 158 _ntrl = 316) + alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS] + alpha158_cols = ALPHA158_COLS.copy() + + # Build market_ext feature columns (7 original + 7 _ntrl = 14) + market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS] + market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE] + market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff] + market_ext_cols = market_ext_with_diff.copy() + + # VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext] + norm_feature_cols = ( + alpha158_ntrl_cols + alpha158_cols + + market_ext_ntrl_cols + market_ext_cols + ) + + # Market flag columns (excluding IsST) + market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE] + market_flag_cols += ['market_0', 'market_1'] + market_flag_cols = list(dict.fromkeys(market_flag_cols)) + + # Combine all VAE input columns (341 total) + vae_cols = norm_feature_cols + market_flag_cols + + if use_struct: + # Pack all features into a single struct column + result['vae_input'] = df.select([ + 'instrument', + 'datetime', + pl.struct(vae_cols).alias('features') + ]) + if verbose: + print(f"VAE input features: {result['vae_input'].shape} (struct with {len(vae_cols)} fields)") + else: + # Always keep datetime and instrument as index columns + # Select features with index columns first + result['vae_input'] = df.select(['instrument', 'datetime'] + vae_cols) + if verbose: + print(f"VAE input features: {result['vae_input'].shape} (columns: {len(vae_cols)} + 2 index)") + + return result + + +def dump_features( + df: pl.DataFrame, + output_path: str, + output_format: str = 'parquet', + groups: List[str] = None, + verbose: bool = True, + use_struct: bool = False, +) -> None: + """ + Dump features to file. 
+ + Args: + df: Transformed merged DataFrame + output_path: Output file path + output_format: Output format ('parquet', 'pickle', 'numpy') + groups: Feature groups to dump (default: ['merged']) + verbose: Whether to print progress + use_struct: If True, pack feature columns into a single 'features' struct column + """ + if groups is None: + groups = ['merged'] + + # Select feature groups from merged DataFrame + outputs = get_groups(df, groups, verbose, use_struct) + + # Ensure output directory exists + output_dir = os.path.dirname(output_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + + # For numpy, we use FeatureGroups (need to convert df back) + # This is a simplified version - for numpy, use dump_to_numpy directly + if output_format == 'numpy': + raise NotImplementedError( + "For numpy output, use the FeaturePipeline directly. " + "This function handles parquet/pickle output only." + ) + + # Dump to file(s) + if output_format == 'pickle': + if 'merged' in outputs: + dump_to_pickle(outputs['merged'], output_path, verbose=verbose) + elif len(outputs) == 1: + # Single group output + key = list(outputs.keys())[0] + base_path = Path(output_path) + dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}")) + dump_to_pickle(outputs[key], dump_path, verbose=verbose) + else: + # Multiple groups - save each separately + base_path = Path(output_path) + for key, df_out in outputs.items(): + dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}")) + dump_to_pickle(df_out, dump_path, verbose=verbose) + else: # parquet + if 'merged' in outputs: + dump_to_parquet(outputs['merged'], output_path, verbose=verbose) + elif len(outputs) == 1: + # Single group output + key = list(outputs.keys())[0] + base_path = Path(output_path) + dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}")) + dump_to_parquet(outputs[key], dump_path, verbose=verbose) + else: + # Multiple groups - save each 
separately + base_path = Path(output_path) + for key, df_out in outputs.items(): + dump_path = str(base_path.with_name(f"{base_path.stem}_{key}{base_path.suffix}")) + dump_to_parquet(df_out, dump_path, verbose=verbose) + + if verbose: + print("Feature dump complete!") \ No newline at end of file diff --git a/stock_1d/d033/alpha158_beta/src/processors/loaders.py b/stock_1d/d033/alpha158_beta/src/processors/loaders.py new file mode 100644 index 0000000..e1a22fc --- /dev/null +++ b/stock_1d/d033/alpha158_beta/src/processors/loaders.py @@ -0,0 +1,279 @@ +"""Data loading functions for the feature pipeline. + +Memory Best Practices: +- Always filter on partition keys (datetime) BEFORE collecting +- Use streaming for large date ranges (>1 year) +- Never collect full parquet tables - filter on datetime first +""" + +import polars as pl +from datetime import datetime +from typing import Optional, List, Tuple + +# Data paths +PARQUET_ALPHA158_BETA_PATH = "/data/parquet/dataset/stg_1day_wind_alpha158_0_7_beta_1D/" +PARQUET_KLINE_PATH = "/data/parquet/dataset/stg_1day_wind_kline_adjusted_1D/" +PARQUET_MARKET_FLAG_PATH = "/data/parquet/dataset/stg_1day_wind_market_flag_1D/" +PARQUET_INDUSTRY_FLAG_PATH = "/data/parquet/dataset/stg_1day_gds_indus_flag_cc1_1D/" +PARQUET_CON_RATING_PATH = "/data/parquet/dataset/stg_1day_gds_con_rating_1D/" + +# Industry flag columns (30 one-hot columns - note: gds_CC29 is not present in the data) +INDUSTRY_FLAG_COLS = [ + 'gds_CC10', 'gds_CC11', 'gds_CC12', 'gds_CC20', 'gds_CC21', 'gds_CC22', + 'gds_CC23', 'gds_CC24', 'gds_CC25', 'gds_CC26', 'gds_CC27', 'gds_CC28', + 'gds_CC30', 'gds_CC31', 'gds_CC32', 'gds_CC33', 'gds_CC34', 'gds_CC35', + 'gds_CC36', 'gds_CC37', 'gds_CC40', 'gds_CC41', 'gds_CC42', 'gds_CC43', + 'gds_CC50', 'gds_CC60', 'gds_CC61', 'gds_CC62', 'gds_CC63', 'gds_CC70' +] + +# Market flag columns +MARKET_FLAG_COLS_KLINE = ['IsZt', 'IsDt', 'IsN', 'IsXD', 'IsXR', 'IsDR'] +MARKET_FLAG_COLS_MARKET = ['open_limit', 'close_limit', 
def get_date_partitions(start_date: str, end_date: str) -> List[str]:
    """
    Generate a list of date partitions to load from Parquet.

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format

    Returns:
        List of datetime=YYYYMMDD partition strings for weekdays only
    """
    from datetime import timedelta  # local import: file-level import block only pulls in datetime

    start = datetime.strptime(start_date, "%Y-%m-%d")
    end = datetime.strptime(end_date, "%Y-%m-%d")

    one_day = timedelta(days=1)
    partitions = []
    current = start
    while current <= end:
        if current.weekday() < 5:  # Monday = 0, Friday = 4
            partitions.append(f"datetime={current.strftime('%Y%m%d')}")
        # BUGFIX: previously advanced with datetime(year, month, day + 1),
        # which raises ValueError on the last day of every month (day 32).
        current += one_day

    return partitions
+ + Args: + base_path: Base path to the parquet dataset + start_date: Start date in YYYY-MM-DD format + end_date: End date in YYYY-MM-DD format + columns: Optional list of columns to select (excluding instrument/datetime) + collect: If True, return DataFrame (default); if False, return LazyFrame + streaming: If True, use streaming mode for large datasets (recommended for >1 year) + + Returns: + Polars DataFrame with the loaded data (LazyFrame if collect=False) + """ + start_int = int(start_date.replace("-", "")) + end_int = int(end_date.replace("-", "")) + + try: + # Start with lazy scan - DO NOT COLLECT YET + lf = pl.scan_parquet(base_path) + + # CRITICAL: Filter on partition key FIRST, before any other operations + # This ensures partition pruning happens at the scan level + lf = lf.filter(pl.col('datetime') >= start_int) + lf = lf.filter(pl.col('datetime') <= end_int) + + # Select specific columns if provided (column pruning) + if columns: + # Get schema to check which columns exist + schema = lf.collect_schema() + available_cols = ['instrument', 'datetime'] + [c for c in columns if c in schema.names()] + lf = lf.select(available_cols) + + # Collect with optional streaming mode + if collect: + if streaming: + return lf.collect(streaming=True) + return lf.collect() + return lf + + except Exception as e: + print(f"Error loading from {base_path}: {e}") + # Return empty DataFrame with expected schema + if columns: + return pl.DataFrame({ + 'instrument': pl.Series([], dtype=pl.String), + 'datetime': pl.Series([], dtype=pl.Int32), + **{col: pl.Series([], dtype=pl.Float64) for col in columns if c not in ['instrument', 'datetime']} + }) + return pl.DataFrame({ + 'instrument': pl.Series([], dtype=pl.String), + 'datetime': pl.Series([], dtype=pl.Int32) + }) + + +def load_alpha158(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame: + """ + Load alpha158 beta factors from parquet. 
+ + Args: + start_date: Start date in YYYY-MM-DD format + end_date: End date in YYYY-MM-DD format + streaming: If True, use streaming mode for large datasets + + Returns: + DataFrame with instrument, datetime, and 158 alpha158 features + """ + print("Loading alpha158_0_7_beta factors...") + df = load_parquet_by_date_range(PARQUET_ALPHA158_BETA_PATH, start_date, end_date, streaming=streaming) + print(f" Alpha158 shape: {df.shape}") + return df + + +def load_market_ext(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame: + """ + Load market extension features from parquet. + + Loads Turnover, FreeTurnover, MarketValue from kline data and transforms: + - Turnover -> turnover (rename) + - FreeTurnover -> free_turnover (rename) + - MarketValue -> log_size = log(MarketValue) + - con_rating_strength: loaded from parquet (or zeros if not available) + + Args: + start_date: Start date in YYYY-MM-DD format + end_date: End date in YYYY-MM-DD format + streaming: If True, use streaming mode for large datasets + + Returns: + DataFrame with instrument, datetime, turnover, free_turnover, log_size, con_rating_strength + """ + print("Loading kline data (market ext columns)...") + + # Load raw kline columns + df_kline = load_parquet_by_date_range( + PARQUET_KLINE_PATH, start_date, end_date, MARKET_EXT_RAW_COLS, streaming=streaming + ) + print(f" Kline (market ext raw) shape: {df_kline.shape}") + + # Load con_rating_strength from parquet + print("Loading con_rating_strength from parquet...") + df_con_rating = load_parquet_by_date_range( + PARQUET_CON_RATING_PATH, start_date, end_date, ['con_rating_strength'], streaming=streaming + ) + print(f" Con rating shape: {df_con_rating.shape}") + + # Transform columns + df_kline = df_kline.with_columns([ + pl.col('Turnover').alias('turnover'), + pl.col('FreeTurnover').alias('free_turnover'), + pl.col('MarketValue').log().alias('log_size'), + ]) + print(f" Kline (market ext transformed) shape: {df_kline.shape}") + + # Merge 
con_rating_strength + df_kline = df_kline.join(df_con_rating, on=['instrument', 'datetime'], how='left') + df_kline = df_kline.with_columns([ + pl.col('con_rating_strength').fill_null(0.0) + ]) + print(f" Kline (with con_rating) shape: {df_kline.shape}") + + return df_kline + + +def load_market_flags(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame: + """ + Load market flag features from parquet. + + Combines two sources: + - From kline_adjusted: IsZt, IsDt, IsN, IsXD, IsXR, IsDR + - From market_flag: open_limit, close_limit, low_limit, open_stop, close_stop, high_stop + + Args: + start_date: Start date in YYYY-MM-DD format + end_date: End date in YYYY-MM-DD format + streaming: If True, use streaming mode for large datasets + + Returns: + DataFrame with instrument, datetime, and 12 market flag columns + """ + # Load kline flags + print("Loading market flags from kline_adjusted...") + df_kline_flag = load_parquet_by_date_range( + PARQUET_KLINE_PATH, start_date, end_date, MARKET_FLAG_COLS_KLINE, streaming=streaming + ) + print(f" Kline flags shape: {df_kline_flag.shape}") + + # Load market flags + print("Loading market flags from market_flag table...") + df_market_flag = load_parquet_by_date_range( + PARQUET_MARKET_FLAG_PATH, start_date, end_date, MARKET_FLAG_COLS_MARKET, streaming=streaming + ) + print(f" Market flag shape: {df_market_flag.shape}") + + # Merge both flag sources + df_flag = df_kline_flag.join(df_market_flag, on=['instrument', 'datetime'], how='inner') + print(f" Combined flags shape: {df_flag.shape}") + + return df_flag + + +def load_industry_flags(start_date: str, end_date: str, streaming: bool = False) -> pl.DataFrame: + """ + Load industry flag features from parquet. 
+ + Args: + start_date: Start date in YYYY-MM-DD format + end_date: End date in YYYY-MM-DD format + streaming: If True, use streaming mode for large datasets + + Returns: + DataFrame with instrument, datetime, and 29 industry flag columns + """ + print("Loading industry flags...") + df = load_parquet_by_date_range( + PARQUET_INDUSTRY_FLAG_PATH, start_date, end_date, INDUSTRY_FLAG_COLS, streaming=streaming + ) + print(f" Industry shape: {df.shape}") + return df + + +def load_all_data( + start_date: str, + end_date: str, + streaming: bool = False +) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, pl.DataFrame]: + """ + Load all data sources from Parquet. + + This is a convenience function that loads all four data sources + and returns them as separate DataFrames. + + Args: + start_date: Start date in YYYY-MM-DD format + end_date: End date in YYYY-MM-DD format + streaming: If True, use streaming mode for large datasets + + Returns: + Tuple of (alpha158_df, market_ext_df, market_flag_df, industry_df) + """ + print(f"Loading data from {start_date} to {end_date}...") + + df_alpha = load_alpha158(start_date, end_date, streaming=streaming) + df_market_ext = load_market_ext(start_date, end_date, streaming=streaming) + df_market_flag = load_market_flags(start_date, end_date, streaming=streaming) + df_industry = load_industry_flags(start_date, end_date, streaming=streaming) + + return df_alpha, df_market_ext, df_market_flag, df_industry diff --git a/stock_1d/d033/alpha158_beta/src/processors/pipeline.py b/stock_1d/d033/alpha158_beta/src/processors/pipeline.py new file mode 100644 index 0000000..ecf75a9 --- /dev/null +++ b/stock_1d/d033/alpha158_beta/src/processors/pipeline.py @@ -0,0 +1,539 @@ +"""FeaturePipeline orchestrator for the data loading and transformation pipeline.""" + +import os +import json +import numpy as np +import polars as pl +from typing import List, Dict, Optional, Tuple +from pathlib import Path + +from .dataclass import FeatureGroups +from .loaders 
import ( + load_alpha158, + load_market_ext, + load_market_flags, + load_industry_flags, + load_all_data, + INDUSTRY_FLAG_COLS, +) +from .processors import ( + DiffProcessor, + FlagMarketInjector, + FlagSTInjector, + ColumnRemover, + FlagToOnehot, + IndusNtrlInjector, + RobustZScoreNorm, + Fillna +) + +# Constants - Import from loaders module (single source of truth) +# INDUSTRY_FLAG_COLS is now imported from .loaders above + +# Alpha158 feature columns in explicit order +ALPHA158_COLS = [ + 'KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2', + 'OPEN0', 'HIGH0', 'LOW0', 'VWAP0', + 'ROC5', 'ROC10', 'ROC20', 'ROC30', 'ROC60', + 'MA5', 'MA10', 'MA20', 'MA30', 'MA60', + 'STD5', 'STD10', 'STD20', 'STD30', 'STD60', + 'BETA5', 'BETA10', 'BETA20', 'BETA30', 'BETA60', + 'RSQR5', 'RSQR10', 'RSQR20', 'RSQR30', 'RSQR60', + 'RESI5', 'RESI10', 'RESI20', 'RESI30', 'RESI60', + 'MAX5', 'MAX10', 'MAX20', 'MAX30', 'MAX60', + 'MIN5', 'MIN10', 'MIN20', 'MIN30', 'MIN60', + 'QTLU5', 'QTLU10', 'QTLU20', 'QTLU30', 'QTLU60', + 'QTLD5', 'QTLD10', 'QTLD20', 'QTLD30', 'QTLD60', + 'RANK5', 'RANK10', 'RANK20', 'RANK30', 'RANK60', + 'RSV5', 'RSV10', 'RSV20', 'RSV30', 'RSV60', + 'IMAX5', 'IMAX10', 'IMAX20', 'IMAX30', 'IMAX60', + 'IMIN5', 'IMIN10', 'IMIN20', 'IMIN30', 'IMIN60', + 'IMXD5', 'IMXD10', 'IMXD20', 'IMXD30', 'IMXD60', + 'CORR5', 'CORR10', 'CORR20', 'CORR30', 'CORR60', + 'CORD5', 'CORD10', 'CORD20', 'CORD30', 'CORD60', + 'CNTP5', 'CNTP10', 'CNTP20', 'CNTP30', 'CNTP60', + 'CNTN5', 'CNTN10', 'CNTN20', 'CNTN30', 'CNTN60', + 'CNTD5', 'CNTD10', 'CNTD20', 'CNTD30', 'CNTD60', + 'SUMP5', 'SUMP10', 'SUMP20', 'SUMP30', 'SUMP60', + 'SUMN5', 'SUMN10', 'SUMN20', 'SUMN30', 'SUMN60', + 'SUMD5', 'SUMD10', 'SUMD20', 'SUMD30', 'SUMD60', + 'VMA5', 'VMA10', 'VMA20', 'VMA30', 'VMA60', + 'VSTD5', 'VSTD10', 'VSTD20', 'VSTD30', 'VSTD60', + 'WVMA5', 'WVMA10', 'WVMA20', 'WVMA30', 'WVMA60', + 'VSUMP5', 'VSUMP10', 'VSUMP20', 'VSUMP30', 'VSUMP60', + 'VSUMN5', 'VSUMN10', 'VSUMN20', 'VSUMN30', 
def filter_stock_universe(df: pl.DataFrame, instruments: str = 'csiallx') -> pl.DataFrame:
    """
    Restrict *df* to the csiallx stock universe (A-shares excluding STAR/BSE).

    Delegates to qshare's filter_instruments, which loads the instrument
    membership list from:
    /data/qlib/default/data_ops/target/instruments/csiallx.txt

    Args:
        df: Input DataFrame with datetime and instrument columns
        instruments: Market name for spine creation (default: 'csiallx')

    Returns:
        Filtered DataFrame with only instruments in the specified universe
    """
    # Imported lazily so the qshare dependency is only required when
    # universe filtering is actually requested.
    from qshare.algo.polars.spine import filter_instruments

    filtered = filter_instruments(df, instruments=instruments)
    return filtered
+ + Returns: + Dictionary with 'mean_train' and 'std_train' numpy arrays + + Raises: + FileNotFoundError: If parameter files are not found + """ + if params_path is None: + params_path = DEFAULT_ROBUST_ZSCORE_PARAMS_PATH + + # Check for cached params in the class (module-level cache) + if not hasattr(load_robust_zscore_params, '_cached_params'): + load_robust_zscore_params._cached_params = {} + + if params_path in load_robust_zscore_params._cached_params: + return load_robust_zscore_params._cached_params[params_path] + + print(f"Loading robust zscore parameters from: {params_path}") + + mean_path = os.path.join(params_path, 'mean_train.npy') + std_path = os.path.join(params_path, 'std_train.npy') + + if not os.path.exists(mean_path): + raise FileNotFoundError(f"mean_train.npy not found at {mean_path}") + if not os.path.exists(std_path): + raise FileNotFoundError(f"std_train.npy not found at {std_path}") + + mean_train = np.load(mean_path) + std_train = np.load(std_path) + + print(f"Loaded parameters:") + print(f" mean_train shape: {mean_train.shape}") + print(f" std_train shape: {std_train.shape}") + + # Try to load metadata if available + metadata_path = os.path.join(params_path, 'metadata.json') + if os.path.exists(metadata_path): + with open(metadata_path, 'r') as f: + metadata = json.load(f) + print(f" fitted on: {metadata.get('fit_start_time', 'N/A')} to {metadata.get('fit_end_time', 'N/A')}") + + params = { + 'mean_train': mean_train, + 'std_train': std_train + } + + # Cache the loaded parameters + load_robust_zscore_params._cached_params[params_path] = params + + return params + + +class FeaturePipeline: + """ + Feature pipeline orchestrator for loading and transforming data. + + The pipeline manages: + 1. Data loading from parquet sources + 2. Feature transformations via processors + 3. 
Output preparation for VAE input + + Usage: + pipeline = FeaturePipeline(config) + feature_groups = pipeline.load_data(start_date, end_date) + transformed = pipeline.transform(feature_groups) + vae_input = pipeline.get_vae_input(transformed) + """ + + def __init__( + self, + config: Optional[Dict] = None, + robust_zscore_params_path: Optional[str] = None + ): + """ + Initialize the FeaturePipeline. + + Args: + config: Optional configuration dictionary with pipeline settings. + If None, uses default configuration. + robust_zscore_params_path: Path to robust zscore parameters. + If None, uses default path. + """ + self.config = config or {} + self.robust_zscore_params_path = robust_zscore_params_path or DEFAULT_ROBUST_ZSCORE_PARAMS_PATH + + # Cache for loaded robust zscore params + self._robust_zscore_params = None + + # Initialize processors + self._init_processors() + + def _init_processors(self): + """Initialize all processors with default configuration.""" + # Diff processor for market_ext columns + self.diff_processor = DiffProcessor(MARKET_EXT_BASE_COLS) + + # Market flag injector + self.flag_market_injector = FlagMarketInjector() + + # ST flag injector + self.flag_st_injector = FlagSTInjector() + + # Column remover + self.column_remover = ColumnRemover(COLUMNS_TO_REMOVE) + + # Industry flag to index converter + self.flag_to_onehot = FlagToOnehot(INDUSTRY_FLAG_COLS) + + # Industry neutralization injectors (created on-demand with specific feature lists) + # self.indus_ntrl_injector = None # Created per-feature-group + + # Robust zscore normalizer (created on-demand with loaded params) + # self.robust_norm = None # Created with loaded params + + # Fillna processor + self.fillna = Fillna() + + def load_data( + self, + start_date: str, + end_date: str, + filter_universe: bool = True, + universe_name: str = 'csiallx', + streaming: bool = False + ) -> FeatureGroups: + """ + Load data from parquet sources into FeatureGroups container. 
+ + Args: + start_date: Start date in YYYY-MM-DD format + end_date: End date in YYYY-MM-DD format + filter_universe: Whether to filter to a specific stock universe + universe_name: Name of the stock universe to filter to + streaming: If True, use Polars streaming mode for large datasets (>1 year) + + Returns: + FeatureGroups container with loaded data + """ + print("=" * 60) + print(f"Loading data from {start_date} to {end_date}") + print("=" * 60) + + # Load all data sources + df_alpha, df_market_ext, df_market_flag, df_industry = load_all_data( + start_date, end_date, streaming=streaming + ) + + # Apply stock universe filter if requested + if filter_universe: + print(f"Filtering to {universe_name} universe...") + df_alpha = filter_stock_universe(df_alpha, instruments=universe_name) + df_market_ext = filter_stock_universe(df_market_ext, instruments=universe_name) + df_market_flag = filter_stock_universe(df_market_flag, instruments=universe_name) + df_industry = filter_stock_universe(df_industry, instruments=universe_name) + print(f" After filter - Alpha158 shape: {df_alpha.shape}") + + # Create FeatureGroups container + feature_groups = FeatureGroups( + alpha158=df_alpha, + market_ext=df_market_ext, + market_flag=df_market_flag, + industry=df_industry + ) + + # Extract metadata + feature_groups.extract_metadata() + + print(f"Loaded {len(feature_groups.instruments)} samples") + print("=" * 60) + + return feature_groups + + def transform( + self, + feature_groups: FeatureGroups, + pack_struct: bool = False + ) -> pl.DataFrame: + """ + Apply feature transformation pipeline to FeatureGroups. + + The pipeline applies processors in the following order: + 1. DiffProcessor - adds diff features to market_ext + 2. FlagMarketInjector - adds market_0, market_1 to market_flag + 3. FlagSTInjector - adds IsST to market_flag + 4. ColumnRemover - removes log_size_diff, IsN, IsZt, IsDt + 5. FlagToOnehot - converts industry flags to indus_idx + 6. 
    def transform(
        self,
        feature_groups: FeatureGroups,
        pack_struct: bool = False
    ) -> pl.DataFrame:
        """
        Apply feature transformation pipeline to FeatureGroups.

        The pipeline applies processors in the following order:
        1. DiffProcessor - adds diff features to market_ext
        2. FlagMarketInjector - adds market_0, market_1 to market_flag
        3. FlagSTInjector - adds IsST to market_flag
        4. ColumnRemover - removes log_size_diff, IsN, IsZt, IsDt
        5. FlagToOnehot - converts industry flags to indus_idx
        6. IndusNtrlInjector - industry neutralization for alpha158 and market_ext
        7. RobustZScoreNorm - robust z-score normalization
        8. Fillna - fill NaN values with 0

        NOTE: step order is load-bearing — diff columns must exist before
        ColumnRemover/neutralization, and indus_idx (step 5) must exist
        before IndusNtrlInjector (step 6).

        Args:
            feature_groups: FeatureGroups container with loaded data
            pack_struct: If True, pack each feature group into a struct column
                (features_alpha158, features_market_ext, features_market_flag).
                If False (default), return flat DataFrame with all columns merged.

        Returns:
            Merged DataFrame with all transformed features (pl.DataFrame)
        """
        print("=" * 60)
        print("Starting feature transformation pipeline")
        print("=" * 60)

        # Merge all groups for processing
        df = feature_groups.merge_for_processors()

        # Step 1: Diff Processor - adds diff features for market_ext
        df = self.diff_processor.process(df)

        # Step 2: FlagMarketInjector - adds market_0, market_1
        df = self.flag_market_injector.process(df)

        # Step 3: FlagSTInjector - adds IsST
        df = self.flag_st_injector.process(df)

        # Step 4: ColumnRemover - removes specific columns
        df = self.column_remover.process(df)

        # Step 5: FlagToOnehot - converts industry flags to indus_idx
        df = self.flag_to_onehot.process(df)

        # Step 6: IndusNtrlInjector - industry neutralization
        # For alpha158 features
        indus_ntrl_alpha = IndusNtrlInjector(ALPHA158_COLS, suffix='_ntrl')
        df = indus_ntrl_alpha.process(df, df)  # Pass df as both feature and industry source

        # For market_ext features (with diff columns)
        market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS]
        # Remove columns that were dropped
        market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE]
        indus_ntrl_ext = IndusNtrlInjector(market_ext_with_diff, suffix='_ntrl')
        df = indus_ntrl_ext.process(df, df)

        # Step 7: RobustZScoreNorm - robust z-score normalization
        # Load parameters lazily and cache them on the instance
        if self._robust_zscore_params is None:
            self._robust_zscore_params = load_robust_zscore_params(self.robust_zscore_params_path)

        # Build the list of features to normalize
        alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS]
        market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff]

        # Feature order for VAE: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext]
        # This order must match the pre-fitted mean_train/std_train vectors.
        norm_feature_cols = (
            alpha158_ntrl_cols + ALPHA158_COLS +
            market_ext_ntrl_cols + market_ext_with_diff
        )

        print(f"Applying RobustZScoreNorm to {len(norm_feature_cols)} features...")

        robust_norm = RobustZScoreNorm(
            norm_feature_cols,
            clip_range=(-3, 3),
            use_qlib_params=True,
            qlib_mean=self._robust_zscore_params['mean_train'],
            qlib_std=self._robust_zscore_params['std_train']
        )
        df = robust_norm.process(df)

        # Step 8: Fillna - fill NaN with 0
        # Get all feature columns
        market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE]
        market_flag_cols += ['market_0', 'market_1', 'IsST']
        market_flag_cols = list(dict.fromkeys(market_flag_cols))  # Remove duplicates

        final_feature_cols = norm_feature_cols + market_flag_cols + ['indus_idx']
        df = self.fillna.process(df, final_feature_cols)

        print("=" * 60)
        print("Pipeline complete")
        print(f" Total columns: {len(df.columns)}")
        print(f" Rows: {len(df)}")

        # Optionally pack features into struct columns
        if pack_struct:
            df = self._pack_into_structs(df)
            print(f" Packed into struct columns")

        print("=" * 60)
        return df
+ + Creates: + - features_alpha158: struct with 316 fields (158 + 158 _ntrl) + - features_market_ext: struct with 14 fields (7 + 7 _ntrl) + - features_market_flag: struct with 11 fields + + Returns: + DataFrame with columns: instrument, datetime, indus_idx, features_* + """ + # Define column groups + alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS] + alpha158_all_cols = alpha158_ntrl_cols + ALPHA158_COLS + + market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS] + market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE] + market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff] + market_ext_all_cols = market_ext_ntrl_cols + market_ext_with_diff + + market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE] + market_flag_cols += ['market_0', 'market_1', 'IsST'] + market_flag_cols = list(dict.fromkeys(market_flag_cols)) + + # Build result with struct columns + result_cols = ['instrument', 'datetime'] + + # Check if indus_idx exists + if 'indus_idx' in df.columns: + result_cols.append('indus_idx') + + # Pack alpha158 + alpha158_cols_in_df = [c for c in alpha158_all_cols if c in df.columns] + if alpha158_cols_in_df: + result_cols.append(pl.struct(alpha158_cols_in_df).alias('features_alpha158')) + + # Pack market_ext + ext_cols_in_df = [c for c in market_ext_all_cols if c in df.columns] + if ext_cols_in_df: + result_cols.append(pl.struct(ext_cols_in_df).alias('features_market_ext')) + + # Pack market_flag + flag_cols_in_df = [c for c in market_flag_cols if c in df.columns] + if flag_cols_in_df: + result_cols.append(pl.struct(flag_cols_in_df).alias('features_market_flag')) + + return df.select(result_cols) + + def get_vae_input( + self, + df: pl.DataFrame | FeatureGroups, + exclude_isst: bool = True + ) -> np.ndarray: + """ + Prepare VAE input features from transformed DataFrame or FeatureGroups. 
+ + VAE input structure (341 features): + - feature group (316): 158 alpha158 + 158 alpha158_ntrl + - feature_ext group (14): 7 market_ext + 7 market_ext_ntrl + - feature_flag group (11): market flags (excluding IsST) + + NOTE: indus_idx is NOT included in VAE input. + + Args: + df: Transformed DataFrame or FeatureGroups container + exclude_isst: Whether to exclude IsST from VAE input (default: True) + + Returns: + Numpy array of shape (n_samples, VAE_INPUT_DIM) + """ + print("Preparing features for VAE...") + + # Accept either DataFrame or FeatureGroups + if isinstance(df, FeatureGroups): + df = df.merge_for_processors() + + # Build alpha158 feature columns + alpha158_ntrl_cols = [f"{c}_ntrl" for c in ALPHA158_COLS] + alpha158_cols = ALPHA158_COLS.copy() + + # Build market_ext feature columns (with diff, minus removed columns) + market_ext_with_diff = MARKET_EXT_BASE_COLS + [f"{c}_diff" for c in MARKET_EXT_BASE_COLS] + market_ext_with_diff = [c for c in market_ext_with_diff if c not in COLUMNS_TO_REMOVE] + market_ext_ntrl_cols = [f"{c}_ntrl" for c in market_ext_with_diff] + market_ext_cols = market_ext_with_diff.copy() + + # VAE feature order: [alpha158_ntrl, alpha158, market_ext_ntrl, market_ext] + norm_feature_cols = ( + alpha158_ntrl_cols + alpha158_cols + + market_ext_ntrl_cols + market_ext_cols + ) + + # Market flag columns (excluding IsST if requested) + market_flag_cols = [c for c in MARKET_FLAG_COLS if c not in COLUMNS_TO_REMOVE] + market_flag_cols += ['market_0', 'market_1'] + if not exclude_isst: + market_flag_cols.append('IsST') + market_flag_cols = list(dict.fromkeys(market_flag_cols)) + + # Combine all VAE input columns + vae_cols = norm_feature_cols + market_flag_cols + + print(f" norm_feature_cols: {len(norm_feature_cols)}") + print(f" market_flag_cols: {len(market_flag_cols)}") + print(f" Total VAE input columns: {len(vae_cols)}") + + # Verify all columns exist + missing_cols = [c for c in vae_cols if c not in df.columns] + if missing_cols: + 
print(f"WARNING: Missing columns: {missing_cols}") + + # Select features and convert to numpy + features_df = df.select(vae_cols) + features = features_df.to_numpy().astype(np.float32) + + # Handle any remaining NaN/Inf values + features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0) + + print(f"Feature matrix shape: {features.shape}") + + # Verify dimensions + if features.shape[1] != VAE_INPUT_DIM: + print(f"WARNING: Expected {VAE_INPUT_DIM} features, got {features.shape[1]}") + + if features.shape[1] < VAE_INPUT_DIM: + # Pad with zeros + padding = np.zeros( + (features.shape[0], VAE_INPUT_DIM - features.shape[1]), + dtype=np.float32 + ) + features = np.concatenate([features, padding], axis=1) + print(f"Padded to shape: {features.shape}") + else: + # Truncate + features = features[:, :VAE_INPUT_DIM] + print(f"Truncated to shape: {features.shape}") + + return features diff --git a/stock_1d/d033/alpha158_beta/src/processors/processors.py b/stock_1d/d033/alpha158_beta/src/processors/processors.py new file mode 100644 index 0000000..c0ee302 --- /dev/null +++ b/stock_1d/d033/alpha158_beta/src/processors/processors.py @@ -0,0 +1,447 @@ +"""Processor classes for feature transformation. + +Processors are stateless, functional components that operate on +specific feature groups or merged DataFrames. +""" + +import numpy as np +import polars as pl +from typing import List, Tuple, Optional, Dict + + +class DiffProcessor: + """ + Diff Processor: Calculate diff features for market_ext columns. + + For each column, calculates diff with period=1 within each instrument group. + """ + + def __init__(self, columns: List[str]): + """ + Initialize the DiffProcessor. + + Args: + columns: List of column names to compute diffs for + """ + self.columns = columns + + def process(self, df: pl.DataFrame) -> pl.DataFrame: + """ + Add diff features for specified columns. 
+ + Args: + df: Input DataFrame with market_ext columns + + Returns: + DataFrame with added {col}_diff columns + """ + print("Applying Diff processor...") + + # CRITICAL: Build all expressions FIRST, then apply in single with_columns() + # Use order_by='datetime' to ensure proper time-series ordering within each instrument + diff_exprs = [ + pl.col(col).diff().over('instrument', order_by='datetime').alias(f"{col}_diff") + for col in self.columns + if col in df.columns + ] + + if diff_exprs: + df = df.with_columns(diff_exprs) + + return df + + +class FlagMarketInjector: + """ + Flag Market Injector: Create market_0, market_1 columns based on instrument code. + + Maps to Qlib's map_market_sec logic with vocab_size=2: + - market_0 (主板): SH60xxx, SZ00xxx + - market_1 (科创板/创业板): SH688xxx, SH689xxx, SZ300xxx, SZ301xxx + + NOTE: vocab_size=2 (not 3!) - 新三板/北交所 (NE4xxxx, NE8xxxx) are NOT included. + """ + + def process(self, df: pl.DataFrame, instrument_col: str = 'instrument') -> pl.DataFrame: + """ + Add market_0, market_1 columns. + + Args: + df: Input DataFrame with instrument column + instrument_col: Name of the instrument column + + Returns: + DataFrame with added market_0 and market_1 columns + """ + print("Applying FlagMarketInjector (vocab_size=2)...") + + # Convert instrument to string and pad to 6 digits + inst_str = pl.col(instrument_col).cast(pl.String).str.zfill(6) + + # Determine market type based on first digit + is_sh_main = inst_str.str.starts_with('6') # SH600xxx, SH601xxx, etc. 
class FlagSTInjector:
    """
    Flag ST Injector: Create IsST column from ST flags.

    Creates IsST = ST_S | ST_Y if both ST flag columns are available,
    otherwise creates a placeholder column of all zeros.
    """

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Add IsST column (Int8: 1 if the stock is ST-flagged, else 0).

        Args:
            df: Input DataFrame

        Returns:
            DataFrame with added IsST column
        """
        print("Applying FlagSTInjector (creating IsST)...")

        # BUGFIX: the old guard accepted 'st_flag::ST_S' as sufficient, but
        # the branch below references the plain 'ST_S' and 'ST_Y' columns
        # unconditionally — a prefixed-only frame (or a missing ST_Y) raised
        # ColumnNotFound. Require both plain columns before using them.
        if 'ST_S' in df.columns and 'ST_Y' in df.columns:
            # Create IsST from actual ST flags
            df = df.with_columns([
                ((pl.col('ST_S').cast(pl.Boolean, strict=False) |
                  pl.col('ST_Y').cast(pl.Boolean, strict=False))
                 .cast(pl.Int8).alias('IsST'))
            ])
        else:
            # Create placeholder (all zeros) if ST flags not available
            df = df.with_columns([
                pl.lit(0).cast(pl.Int8).alias('IsST')
            ])

        return df


class ColumnRemover:
    """
    Column Remover: Drop specific columns.

    Removes columns that are not needed for the VAE input.
    """

    def __init__(self, columns_to_remove: List[str]):
        """
        Initialize the ColumnRemover.

        Args:
            columns_to_remove: List of column names to remove
        """
        self.columns_to_remove = columns_to_remove

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Remove specified columns.

        Args:
            df: Input DataFrame

        Returns:
            DataFrame with specified columns removed
        """
        print(f"Applying ColumnRemover (removing {len(self.columns_to_remove)} columns)...")

        # Only remove columns that exist
        cols_to_drop = [c for c in self.columns_to_remove if c in df.columns]
        if cols_to_drop:
            df = df.drop(cols_to_drop)

        return df
class FlagToOnehot:
    """
    Flag To Onehot: Collapse 29 one-hot industry columns into a single indus_idx.

    For each row, indus_idx is set to the position of the industry column
    that equals 1; rows with no matching column get -1. When several
    columns are set, the highest index wins (same as the original nested
    when/then chain).
    """

    def __init__(self, industry_cols: List[str]):
        """
        Args:
            industry_cols: List of 29 industry flag column names.
        """
        self.industry_cols = industry_cols

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Replace the one-hot industry columns with a single indus_idx column.

        Args:
            df: Input DataFrame with industry flag columns.

        Returns:
            DataFrame with indus_idx added and the one-hot columns dropped.
        """
        print("Applying FlagToOnehot (converting 29 industry flags to indus_idx)...")

        present = [
            (position, name)
            for position, name in enumerate(self.industry_cols)
            if name in df.columns
        ]

        # Fold a when/then chain over the columns; -1 is the default when
        # no flag is set. Later columns wrap earlier ones, so on ties the
        # last (highest-index) match takes precedence.
        idx_expr = pl.lit(-1)
        for position, name in present:
            idx_expr = pl.when(pl.col(name) == 1).then(position).otherwise(idx_expr)

        df = df.with_columns([idx_expr.alias('indus_idx')])

        one_hot_names = [name for _, name in present]
        if one_hot_names:
            df = df.drop(one_hot_names)

        return df
class IndusNtrlInjector:
    """
    Industry Neutralization Injector: subtract cross-sectional industry means.

    For each feature, computes the mean within each (datetime, indus_idx)
    group and stores ``feature - group_mean`` under a new ``{col}_ntrl``
    column; the original columns are kept. Because ``datetime`` is part of
    the grouping key, the neutralization is cross-sectional (per trading
    day), matching qlib's cal_indus_ntrl behavior.
    """

    def __init__(self, feature_cols: List[str], suffix: str = '_ntrl'):
        """
        Args:
            feature_cols: Feature column names to neutralize.
            suffix: Suffix appended to the neutralized column names.
        """
        self.feature_cols = feature_cols
        self.suffix = suffix

    def process(
        self,
        feature_df: pl.DataFrame,
        industry_df: pl.DataFrame
    ) -> pl.DataFrame:
        """
        Apply industry neutralization to the configured features.

        Args:
            feature_df: DataFrame with the feature columns to neutralize.
            industry_df: DataFrame carrying indus_idx (must also have
                instrument and datetime columns for the join).

        Returns:
            DataFrame with added ``{col}_ntrl`` columns.
        """
        print(f"Applying IndusNtrlInjector to {len(self.feature_cols)} features...")

        # indus_idx may already have been merged upstream; join only if not.
        if 'indus_idx' in feature_df.columns:
            merged = feature_df
        else:
            merged = feature_df.join(
                industry_df.select(['instrument', 'datetime', 'indus_idx']),
                on=['instrument', 'datetime'],
                how='left',
                suffix='_indus'
            )

        targets = [name for name in self.feature_cols if name in merged.columns]

        # A group mean is order-independent, so no ordering hint is needed.
        # All expressions are applied in one with_columns() call so polars
        # builds a single new frame rather than one per column.
        ntrl_exprs = [
            (pl.col(name) - pl.col(name).mean().over(['datetime', 'indus_idx']))
            .alias(f"{name}{self.suffix}")
            for name in targets
        ]

        return merged.with_columns(ntrl_exprs) if ntrl_exprs else merged
class RobustZScoreNorm:
    """
    Robust Z-Score Normalization.

    Two modes:
    - Pre-fitted (use_qlib_params=True): applies (x - mean) / (std + 1e-8)
      per column using parameters extracted from qlib's fitted processor
      (qlib stores the median in mean_train and 1.4826*MAD in std_train).
    - Per-datetime (default): (x - median) / (1.4826 * MAD + 1e-8) computed
      cross-sectionally per datetime, where MAD = median(|x - median|).

    In both modes the normalized values are clipped to clip_range.

    Supports pre-fitted parameters from qlib's pickled processor:
        normalizer = RobustZScoreNorm(
            feature_cols=feature_cols,
            use_qlib_params=True,
            qlib_mean=zscore_proc.mean_train,
            qlib_std=zscore_proc.std_train
        )
    """

    def __init__(
        self,
        feature_cols: List[str],
        clip_range: Tuple[float, float] = (-3, 3),
        use_qlib_params: bool = False,
        qlib_mean: Optional[np.ndarray] = None,
        qlib_std: Optional[np.ndarray] = None
    ):
        """
        Initialize the RobustZScoreNorm.

        Args:
            feature_cols: List of feature column names to normalize. When
                use_qlib_params=True, the order MUST match the order the
                qlib parameters were fitted in (qlib_mean[i]/qlib_std[i]
                belong to feature_cols[i]).
            clip_range: Tuple of (min, max) for clipping normalized values
            use_qlib_params: Whether to use pre-fitted parameters from qlib
            qlib_mean: Pre-fitted mean array from qlib (required if use_qlib_params=True)
            qlib_std: Pre-fitted std array from qlib (required if use_qlib_params=True)

        Raises:
            ValueError: if use_qlib_params=True but either array is missing.
        """
        self.feature_cols = feature_cols
        self.clip_range = clip_range
        self.use_qlib_params = use_qlib_params
        self.mean_train = qlib_mean
        self.std_train = qlib_std

        if use_qlib_params:
            if qlib_mean is None or qlib_std is None:
                raise ValueError("Must provide qlib_mean and qlib_std when use_qlib_params=True")
            print(f"Using pre-fitted qlib parameters (mean shape: {qlib_mean.shape}, std shape: {qlib_std.shape})")

    def process(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Apply robust z-score normalization (columns keep their names).

        Args:
            df: Input DataFrame with feature columns

        Returns:
            DataFrame with normalized feature columns (in-place modification)
        """
        print(f"Applying RobustZScoreNorm to {len(self.feature_cols)} features...")

        if self.use_qlib_params:
            # Use pre-fitted parameters from qlib (fit once, apply to all dates).
            # BUGFIX: enumerate self.feature_cols (the fitted order) rather
            # than the filtered list of existing columns. Previously, a
            # single missing column silently shifted every subsequent column
            # onto the wrong mean/std pair. Missing columns are now skipped
            # while keeping the parameter index aligned with its column.
            # All expressions are still applied in a single with_columns()
            # to avoid one DataFrame copy per column.
            norm_exprs = []
            for i, col in enumerate(self.feature_cols):
                if i >= len(self.mean_train):
                    break  # no fitted parameters beyond this point
                if col not in df.columns:
                    continue
                mean_val = float(self.mean_train[i])
                std_val = float(self.std_train[i])
                norm_exprs.append(
                    ((pl.col(col) - mean_val) / (std_val + 1e-8))
                    .clip(self.clip_range[0], self.clip_range[1])
                    .alias(col)
                )

            if norm_exprs:
                df = df.with_columns(norm_exprs)
        else:
            # Compute per-datetime robust z-score (original behavior).
            # Build all expressions with temp columns first, then clean up
            # in a single drop().
            existing_cols = [c for c in self.feature_cols if c in df.columns]
            all_exprs = []
            temp_cols = []

            for col in existing_cols:
                # Median per datetime (cross-sectional center)
                median_col = f"__median_{col}"
                temp_cols.append(median_col)
                all_exprs.append(
                    pl.col(col).median().over('datetime').alias(median_col)
                )

                # Absolute deviation from the per-datetime median
                abs_dev_col = f"__absdev_{col}"
                temp_cols.append(abs_dev_col)
                all_exprs.append(
                    (pl.col(col) - pl.col(median_col)).abs().alias(abs_dev_col)
                )

                # MAD = median of absolute deviations (per datetime)
                mad_col = f"__mad_{col}"
                temp_cols.append(mad_col)
                all_exprs.append(
                    pl.col(abs_dev_col).median().over('datetime').alias(mad_col)
                )

                # Robust z-score, clipped; overwrites the original column
                all_exprs.append(
                    ((pl.col(col) - pl.col(median_col)) / (1.4826 * pl.col(mad_col) + 1e-8))
                    .clip(self.clip_range[0], self.clip_range[1])
                    .alias(col)
                )

            # Apply all expressions in a single with_columns()
            if all_exprs:
                df = df.with_columns(all_exprs)

            # Clean up all temporary columns in a single drop()
            if temp_cols:
                df = df.drop(temp_cols)

        return df
class Fillna:
    """
    Fill NaN/null: replace missing values with 0 in numeric columns.
    """

    def process(
        self,
        df: pl.DataFrame,
        feature_cols: List[str]
    ) -> pl.DataFrame:
        """
        Fill missing values with 0 for the specified columns.

        Float columns get both fill_null(0.0) and fill_nan(0.0); integer
        columns get fill_null(0) only. Boolean and other dtypes are
        skipped, as before.

        Args:
            df: Input DataFrame
            feature_cols: List of column names to fill missing values for

        Returns:
            DataFrame with missing values replaced by 0
        """
        print("Applying Fillna processor...")

        float_types = (pl.Float32, pl.Float64)
        int_types = (pl.Int32, pl.Int64, pl.UInt32, pl.UInt64)

        # Build all fill expressions FIRST, then apply in a single
        # with_columns() — avoids one DataFrame copy per column.
        fill_exprs = []
        for col in feature_cols:
            if col not in df.columns:
                continue
            dtype = df[col].dtype
            if dtype in float_types:
                # Floats can hold both null and NaN; fill both with 0.
                fill_exprs.append(pl.col(col).fill_null(0.0).fill_nan(0.0))
            elif dtype in int_types:
                # BUGFIX: integer columns cannot contain NaN, and fill_nan
                # is only defined for float dtypes (it can raise depending
                # on the polars version). Filling null with the integer 0
                # also preserves the column dtype, whereas 0.0 risked an
                # upcast to float.
                fill_exprs.append(pl.col(col).fill_null(0))

        if fill_exprs:
            df = df.with_columns(fill_exprs)

        return df