#!/usr/bin/env python3
"""
Compare generated embeddings with database embeddings (0_7 version).

Handles format conversion for datetime and instrument columns.

SUMMARY OF FINDINGS:
- Generated embeddings and database embeddings have DIFFERENT values
- Instrument mapping: 430xxx -> SHxxxxx, 830xxx -> SZxxxxx, 6xxxxx -> SH6xxxxx
- Correlation between corresponding dimensions: ~0.0067 (essentially zero)
- The generated embeddings are NOT the same as the database 0_7 embeddings
- Possible reasons:
  1. Different model weights/versions used for generation
  2. Different input features or normalization
  3. Different random seed or inference configuration
"""

import traceback
from pathlib import Path
from typing import Optional, Sequence, Tuple

import numpy as np
import polars as pl

# Wide-format parquet (embedding_0, embedding_1, ...) written by the generation run.
GENERATED_EMBEDDING_PATH = Path(
    '/home/guofu/Workspaces/alpha_lab/stock_1d/d033/alpha158_beta/data/'
    'embedding_0_7_beta.parquet'
)
# Database partition root; each day is read from ``datetime=<date_str>/0.parquet``.
DATABASE_EMBEDDING_ROOT = Path(
    '/data/parquet/dataset/dwm_1day_multicast_csencode_1D/'
    'version=csiallx_feature2_ntrla_flag_pnlnorm_vae4_dim32a_beta0001'
)
# Dates analysed when the script is run directly (YYYYMMDD integers).
DEFAULT_DATES = (20190102, 20200102, 20240102)


def instrument_int_to_code(inst_int: int) -> str:
    """Convert an integer instrument code to an exchange-prefixed string.

    The encoding in the embedding file appears to use:
    - 4xxxxx -> SH + the last five digits (Shanghai; the mapping to real
      six-digit codes is non-trivial, so this rarely matches the database)
    - 8xxxxx -> SZ + the last five digits (Shenzhen; same caveat)
    - Plain six-digit codes (600xxx, 000xxx, 300xxx). Codes starting with 6
      become ``SH``, everything else ``SZ``. Because the value is stored as an
      integer, codes such as 000001 arrive as ``1``; they are zero-padded back
      to six digits before classification.

    Note: the exact mapping from 430017 -> SH600021 requires the original
    features file, so the 4xxxxx/8xxxxx branch is only an approximation.

    Args:
        inst_int: Instrument code as stored in the generated embedding file.

    Returns:
        An ``SH``/``SZ``-prefixed code, or ``str(inst_int)`` unchanged when it
        is not a six-character code after zero-padding.
    """
    inst_str = str(inst_int)
    # Integer storage drops leading zeros (000001 -> 1); restore them so
    # Shenzhen 000xxx/001xxx/002xxx codes can match the database.
    if inst_str.isdigit() and len(inst_str) < 6:
        inst_str = inst_str.zfill(6)

    if len(inst_str) != 6:
        return inst_str

    lead = inst_str[0]
    if lead in ('4', '8'):
        exchange = 'SH' if lead == '4' else 'SZ'
        # The mapping from 430xxx -> 600xxx is not 1:1; return the remaining
        # digits as-is for matching attempts.
        return f"{exchange}{inst_str[1:]}"

    return f"SH{inst_str}" if lead == '6' else f"SZ{inst_str}"


def load_generated_embedding(
    date_int: int,
    sample_n: Optional[int] = None,
    gen_path: Path = GENERATED_EMBEDDING_PATH,
) -> pl.DataFrame:
    """Load generated embeddings for one date and add comparison columns.

    Args:
        date_int: Value matched against the file's ``datetime`` column.
        sample_n: If truthy, keep only the first ``sample_n`` rows.
        gen_path: Wide-format parquet file to read.

    Returns:
        The filtered frame plus ``values`` (the ``embedding_<i>`` columns packed
        into one list column, ordered by ``i``), ``datetime_uint32``,
        ``instrument_orig``, ``instrument_str`` and ``instrument_code``
        (from :func:`instrument_int_to_code`).

    Raises:
        ValueError: if the file has no ``embedding_*`` columns.
    """
    lf = pl.scan_parquet(gen_path).filter(pl.col('datetime') == date_int)
    if sample_n:
        lf = lf.head(sample_n)
    df = lf.collect()

    embedding_cols = sorted(
        (c for c in df.columns if c.startswith('embedding_')),
        key=lambda c: int(c.split('_')[1]),
    )
    if not embedding_cols:
        raise ValueError(f"No 'embedding_*' columns found in {gen_path}")

    return df.with_columns(
        # Pack the wide columns into a single list column, in dimension order,
        # so it has the same shape as the database 'values' column.
        pl.concat_list(embedding_cols).alias('values'),
        pl.col('datetime').cast(pl.UInt32).alias('datetime_uint32'),
        pl.col('instrument').alias('instrument_orig'),
        pl.col('instrument').cast(pl.String).alias('instrument_str'),
        pl.col('instrument')
        .map_elements(instrument_int_to_code, return_dtype=pl.String)
        .alias('instrument_code'),
    )


def load_database_embedding(
    date_str: str,
    db_root: Path = DATABASE_EMBEDDING_ROOT,
) -> Optional[pl.DataFrame]:
    """Load the database embedding partition for one date.

    Args:
        date_str: Partition value, used as ``datetime=<date_str>``.
        db_root: Root directory of the partitioned dataset.

    Returns:
        The partition with an extra ``datetime_int`` column (``datetime`` cast
        to Int64), or ``None`` if the partition file does not exist.
    """
    db_path = db_root / f'datetime={date_str}' / '0.parquet'
    if not db_path.exists():
        return None
    return pl.read_parquet(db_path).with_columns(
        pl.col('datetime').cast(pl.Int64).alias('datetime_int')
    )


def _print_header(title: str) -> None:
    """Print ``title`` framed by two 80-character rules."""
    rule = '=' * 80
    print(f"\n{rule}")
    print(title)
    print(rule)


def _load_pair(date_int: int) -> Optional[Tuple[pl.DataFrame, pl.DataFrame]]:
    """Load (generated, database) frames for a date; print an error and return None if the DB partition is missing."""
    gen_df = load_generated_embedding(date_int)
    db_df = load_database_embedding(str(date_int))
    if db_df is None:
        print(f"ERROR: Database embedding not found for {date_int}")
        return None
    return gen_df, db_df


def _join_on_instrument(gen_df: pl.DataFrame, db_df: pl.DataFrame) -> pl.DataFrame:
    """Inner-join generated and database embeddings on instrument code.

    Joining (rather than sorting both sides independently) keeps every
    generated row paired with its own database row even if either side has
    duplicate instruments. Output columns: ``instrument_code``, ``values``
    (generated) and ``values_db`` (database), sorted by ``instrument_code``.
    """
    return (
        gen_df.select(['instrument_code', 'values'])
        .join(
            db_df.select(['instrument', 'values']),
            left_on='instrument_code',
            right_on='instrument',
            how='inner',
            suffix='_db',
        )
        .sort('instrument_code')
    )


def _embedding_matrices(comparison: pl.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
    """Return (generated, database) embeddings from a joined frame as float arrays, one row per instrument."""
    gen_embs = np.array(comparison['values'].to_list(), dtype=float)
    db_embs = np.array(comparison['values_db'].to_list(), dtype=float)
    return gen_embs, db_embs


def _print_mapping_hints(gen_df: pl.DataFrame, db_df: pl.DataFrame) -> None:
    """Print instrument samples from both sides to help discover a code mapping."""
    print("\nNo common instruments found with code conversion!")
    print("\nTrying to find mapping patterns...")

    print("\nGenerated instrument samples (original, converted):")
    head = gen_df.head(20)
    for orig, conv in zip(head['instrument_orig'].to_list(),
                          head['instrument_code'].to_list()):
        print(f"  {orig} -> {conv}")

    print("\nDatabase instrument samples:")
    for inst in db_df['instrument'].head(20).to_list():
        print(f"  {inst}")

    # Sort the numeric parts of both sides to see whether a position-based
    # alignment is plausible. Database codes whose suffix (after the 2-char
    # prefix) is not purely numeric are skipped instead of crashing int().
    gen_sorted = sorted(gen_df['instrument_orig'].to_list())
    db_sorted = sorted(
        int(inst[2:]) for inst in db_df['instrument'].to_list()
        if inst[2:].isdigit()
    )
    print("\n--- Attempting position-based matching ---")
    print(f"Generated sorted (first 10): {gen_sorted[:10]}")
    print(f"Database sorted (first 10): {db_sorted[:10]}")


def _compare_common_embeddings(gen_df: pl.DataFrame, db_df: pl.DataFrame,
                               n_common: int) -> None:
    """Print per-instrument difference statistics for matched embeddings."""
    print(f"\n--- Comparing embeddings for {n_common} common instruments ---")
    comparison = _join_on_instrument(gen_df, db_df)
    if comparison.is_empty():
        return

    gen_embs, db_embs = _embedding_matrices(comparison)
    if gen_embs.shape != db_embs.shape:
        print(f"ERROR: embedding shapes differ: generated {gen_embs.shape} "
              f"vs database {db_embs.shape}")
        return

    # Row-wise metrics, vectorised over all matched instruments.
    diff = gen_embs - db_embs
    l2_norm_diff = np.linalg.norm(diff, axis=1)
    db_emb_norm = np.linalg.norm(db_embs, axis=1)
    diff_df = pl.DataFrame({
        'instrument': comparison['instrument_code'],
        'l2_norm_diff': l2_norm_diff,
        # Epsilon guards against division by a zero-norm database embedding.
        'relative_diff': l2_norm_diff / (db_emb_norm + 1e-10),
        'max_abs_diff': np.abs(diff).max(axis=1),
        'gen_emb_norm': np.linalg.norm(gen_embs, axis=1),
        'db_emb_norm': db_emb_norm,
    })

    print("\nDifference statistics:")
    print(diff_df.select(['l2_norm_diff', 'relative_diff', 'max_abs_diff']).describe())

    max_rel_diff = diff_df['relative_diff'].max()
    print(f"\nMax relative difference: {max_rel_diff:.6e}")
    if max_rel_diff < 1e-5:
        print("✓ Embeddings match within numerical precision!")
    elif max_rel_diff < 0.01:
        print("~ Embeddings are very similar")
    else:
        print("✗ Embeddings differ significantly")

    print("\nSample comparison:")
    for d in diff_df.head(5).iter_rows(named=True):
        print(f"  {d['instrument']}: gen_norm={d['gen_emb_norm']:.4f}, "
              f"db_norm={d['db_emb_norm']:.4f}, rel_diff={d['relative_diff']:.6e}")


def analyze_instrument_mapping(date_int: int) -> None:
    """Analyze the instrument mapping between generated and database embeddings.

    Prints row counts and samples from both sources, then the overlap of
    converted instrument codes. If codes overlap, prints embedding difference
    statistics; otherwise prints samples to help find a mapping.
    """
    _print_header(f"Analyzing instrument mapping for date: {date_int}")
    pair = _load_pair(date_int)
    if pair is None:
        return
    gen_df, db_df = pair

    print(f"\nGenerated embeddings: {gen_df.shape[0]} rows")
    print(f"Database embeddings: {db_df.shape[0]} rows")

    print("\n--- Generated Embedding Sample ---")
    print(gen_df.select(['datetime', 'instrument_orig', 'instrument_str',
                         'instrument_code', 'values']).head(10))
    print("\n--- Database Embedding Sample ---")
    print(db_df.head(10))

    gen_insts = set(gen_df['instrument_code'].to_list())
    db_insts = set(db_df['instrument'].to_list())
    common = gen_insts & db_insts

    print("\n--- Matching Results (with code conversion) ---")
    print(f"Common instruments: {len(common)}")
    print(f"Generated only: {len(gen_insts - db_insts)}")
    print(f"Database only: {len(db_insts - gen_insts)}")

    if common:
        _compare_common_embeddings(gen_df, db_df, len(common))
    else:
        _print_mapping_hints(gen_df, db_df)


def calculate_correlation(date_int: int) -> None:
    """Print per-dimension and overall correlation between generated and database embeddings.

    Rows are aligned by joining on instrument code. The dimension count is taken
    from the data rather than assumed. Dimensions with zero variance (undefined
    correlation) are reported and excluded from the summary statistics.
    """
    _print_header(f"Correlation Analysis for date: {date_int}")
    pair = _load_pair(date_int)
    if pair is None:
        return
    gen_df, db_df = pair

    common = set(gen_df['instrument_code'].to_list()) & set(db_df['instrument'].to_list())
    print(f"\nCommon instruments: {len(common)}")
    if not common:
        print("No common instruments found!")
        return

    gen_embs, db_embs = _embedding_matrices(_join_on_instrument(gen_df, db_df))
    print(f"Generated embeddings shape: {gen_embs.shape}")
    print(f"Database embeddings shape: {db_embs.shape}")
    if gen_embs.shape != db_embs.shape:
        print("ERROR: generated and database embedding shapes differ; cannot correlate.")
        return
    if gen_embs.shape[0] < 2:
        print("Need at least 2 aligned instruments to compute correlations.")
        return

    n_dims = gen_embs.shape[1]
    # Zero-variance columns give NaN; they are handled explicitly below.
    with np.errstate(divide='ignore', invalid='ignore'):
        correlations = np.array([
            np.corrcoef(gen_embs[:, i], db_embs[:, i])[0, 1] for i in range(n_dims)
        ])
        overall_corr = np.corrcoef(gen_embs.ravel(), db_embs.ravel())[0, 1]

    valid = correlations[np.isfinite(correlations)]
    print(f"\nCorrelation statistics across {n_dims} dimensions:")
    if valid.size < n_dims:
        print(f"  ({n_dims - valid.size} dimension(s) have zero variance; "
              f"correlation undefined)")
    if valid.size == 0:
        print("\nCONCLUSION: correlation is undefined for every dimension")
        return
    print(f"  Mean:   {valid.mean():.4f}")
    print(f"  Median: {np.median(valid):.4f}")
    print(f"  Min:    {valid.min():.4f}")
    print(f"  Max:    {valid.max():.4f}")

    print(f"\nOverall correlation (all dims flattened): {overall_corr:.4f}")

    mean_corr = valid.mean()
    if abs(mean_corr) < 0.1:
        print("\n✗ CONCLUSION: Embeddings are NOT correlated (essentially independent)")
    elif abs(mean_corr) < 0.5:
        print("\n~ CONCLUSION: Weak correlation between embeddings")
    else:
        print(f"\n✓ CONCLUSION: {'Strong' if abs(mean_corr) > 0.8 else 'Moderate'} correlation")


def main(dates: Sequence[int] = DEFAULT_DATES) -> None:
    """Run both analyses for each date; a failure on one date does not stop the others."""
    for date in dates:
        try:
            analyze_instrument_mapping(date)
            calculate_correlation(date)
        except Exception as e:  # top-level boundary: report and continue
            print(f"\nError analyzing date {date}: {e}")
            traceback.print_exc()


if __name__ == '__main__':
    main()