#!/usr/bin/env python
"""
Fetch embedding data from DolphinDB and save to parquet.

This script:
1. Connects to DolphinDB
2. Queries the dwm_1day_multicast_csencode table
3. Filters by version (default: DEFAULT_VERSION, 'csix_alpha158b_ext2_zscore_vae4')
4. Filters by date range
5. Transforms columns (m_nDate -> datetime, code -> instrument)
6. Saves to local parquet file
"""

import os

import polars as pl
import pandas as pd  # noqa: F401 -- kept: other parts of the project may import via this module
from datetime import datetime  # noqa: F401 -- kept: see above
from typing import Optional  # noqa: F401 -- kept: see above

# DolphinDB config (from CLAUDE.md)
DDB_CONFIG = {
    "host": "192.168.1.146",
    "port": 8848,
    "username": "admin",
    "password": "123456",
}

DB_PATH = "dfs://daily_stock_run_multicast"
TABLE_NAME = "dwm_1day_multicast_csencode"

DEFAULT_VERSION = "csix_alpha158b_ext2_zscore_vae4"
DEFAULT_START_DATE = "2019-01-01"
DEFAULT_END_DATE = "2025-12-31"
OUTPUT_FILE = "../data/embeddings_from_ddb.parquet"


def _normalize_datetime(df: pl.DataFrame) -> pl.DataFrame:
    """Rename ``m_nDate`` -> ``datetime`` and normalize it to uint32 YYYYMMDD.

    Handles temporal (Date/Datetime), string ("YYYY-MM-DD"), and already
    numeric columns. No-op if ``m_nDate`` is absent.
    """
    if "m_nDate" not in df.columns:
        return df
    df = df.rename({"m_nDate": "datetime"})
    dtype = df["datetime"].dtype
    if dtype in (pl.Datetime, pl.Date):
        # Temporal column: render as YYYYMMDD, then cast.
        expr = pl.col("datetime").dt.strftime("%Y%m%d").cast(pl.UInt32)
    elif dtype in (pl.Utf8, pl.String):
        # BUGFIX: str.replace only removes the FIRST match, leaving
        # "YYYYMM-DD" which fails the UInt32 cast. replace_all strips
        # every dash from "YYYY-MM-DD".
        expr = pl.col("datetime").str.replace_all("-", "").cast(pl.UInt32)
    else:
        # Assume it is already a numeric YYYYMMDD representation.
        expr = pl.col("datetime").cast(pl.UInt32)
    return df.with_columns(expr.alias("datetime"))


def _normalize_instrument(df: pl.DataFrame) -> pl.DataFrame:
    """Rename ``code`` -> ``instrument`` and strip the exchange prefix.

    Converts TS-style codes (e.g. 'SH600085') to uint32 by removing the
    SH/SZ/BJ prefix. No-op if ``code`` is absent.
    """
    if "code" not in df.columns:
        return df
    df = df.rename({"code": "instrument"})
    return df.with_columns(
        pl.col("instrument")
        .str.replace("SH", "")
        .str.replace("SZ", "")
        .str.replace("BJ", "")
        .cast(pl.UInt32)
        .alias("instrument")
    )


def _expand_values(df: pl.DataFrame) -> pl.DataFrame:
    """Expand a list-typed ``values`` column into ``embedding_i`` columns.

    If ``values`` is not a list column, instead select
    [datetime, instrument, <sorted remaining columns>].
    """
    if "values" in df.columns and df["values"].dtype == pl.List:
        # Guard against an empty result set: df["values"][0] would raise.
        first_val = df["values"][0] if df.height > 0 else None
        if first_val is not None:
            emb_dim = len(first_val)
            print(f"Detected embedding dimension: {emb_dim}")
            embedding_cols = [f"embedding_{i}" for i in range(emb_dim)]
            # One batched with_columns instead of one call per dimension.
            df = df.with_columns(
                [
                    pl.col("values").list.get(i).alias(name)
                    for i, name in enumerate(embedding_cols)
                ]
            )
            df = df.drop("values")
            # Reorder columns: datetime, instrument, embedding_0, embedding_1, ...
            df = df.select(["datetime", "instrument"] + embedding_cols)
            print(f"Expanded embeddings into {emb_dim} columns")
    else:
        # Embedding columns are typically pre-expanded ('feature_0', 'emb_0', ...).
        core_cols = ["datetime", "instrument"]
        embedding_cols = [c for c in df.columns if c not in core_cols + ["version"]]
        # NOTE: lexicographic sort (embedding_10 < embedding_2) -- preserved
        # from the original behavior.
        df = df.select(core_cols + sorted(embedding_cols))
    return df


def fetch_embeddings(
    start_date: str = DEFAULT_START_DATE,
    end_date: str = DEFAULT_END_DATE,
    version: str = DEFAULT_VERSION,
    output_file: str = OUTPUT_FILE,
) -> pl.DataFrame:
    """
    Fetch embedding data from DolphinDB.

    Args:
        start_date: Start date filter (YYYY-MM-DD)
        end_date: End date filter (YYYY-MM-DD)
        version: Version string to filter by
        output_file: Output parquet file path

    Returns:
        Polars DataFrame with columns:
        [datetime, instrument, embedding_0, embedding_1, ...]

    Raises:
        Exception: re-raised from the DolphinDB connection or query.
    """
    print("=" * 60)
    print("Fetching embedding data from DolphinDB")
    print("=" * 60)
    print(f"Database: {DB_PATH}")
    print(f"Table: {TABLE_NAME}")
    print(f"Version: {version}")
    print(f"Date range: {start_date} to {end_date}")

    # Connect to DolphinDB
    try:
        from qshare.io.ddb import get_ddb_sess

        sess = get_ddb_sess(host=DDB_CONFIG["host"], port=DDB_CONFIG["port"])
        print(f"Connected to DolphinDB at {DDB_CONFIG['host']}:{DDB_CONFIG['port']}")
    except Exception as e:
        print(f"Error connecting to DolphinDB: {e}")
        raise

    # Convert date strings to DolphinDB date format (YYYY.MM.DD)
    start_ddb = start_date.replace("-", ".")
    end_ddb = end_date.replace("-", ".")

    # Build SQL query with filters in the WHERE clause.
    # Note: DolphinDB requires date() function for date literals.
    # Use single-line SQL to avoid parsing issues.
    # SECURITY NOTE(review): `version` is interpolated directly into the SQL
    # string. Acceptable for a trusted-config CLI, but do not feed it
    # untrusted input.
    sql = f'select * from loadTable("{DB_PATH}", "{TABLE_NAME}") where version = "{version}" and m_nDate >= date({start_ddb}) and m_nDate <= date({end_ddb})'
    print(f"Executing SQL: {sql.strip()}")

    try:
        # Execute query and get pandas DataFrame
        df_pd = sess.run(sql)
        print(f"Fetched {len(df_pd)} rows from DolphinDB")
        print(f"Columns: {df_pd.columns.tolist()}")
        if len(df_pd) > 0:
            print(f"Sample:\n{df_pd.head()}")
    except Exception as e:
        print(f"Error executing query: {e}")
        raise
    finally:
        sess.close()

    # Convert to Polars
    df = pl.from_pandas(df_pd)
    print(f"Columns in result: {df.columns}")

    # Transform columns
    df = _normalize_datetime(df)
    df = _normalize_instrument(df)

    # Drop version column if present (no longer needed)
    if "version" in df.columns:
        df = df.drop("version")

    # Expand list-typed embedding vectors (or select/sort pre-expanded columns)
    df = _expand_values(df)

    print(f"\nTransformed data:")
    print(f" Shape: {df.shape}")
    print(
        f" Columns: {df.columns[:10]}..."
        if len(df.columns) > 10
        else f" Columns: {df.columns}"
    )
    print(f" Date range: {df['datetime'].min()} to {df['datetime'].max()}")
    print(f" Instrument count: {df['instrument'].n_unique()}")
    print(f" Sample:\n{df.head()}")

    # Save to parquet.
    # BUGFIX: guard dirname -- os.makedirs("") raises when output_file is a
    # bare filename with no directory component.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.write_parquet(output_file)
    print(f"\nSaved to: {output_file}")

    return df


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Fetch embedding data from DolphinDB")
    parser.add_argument("--start-date", type=str, default=DEFAULT_START_DATE, help="Start date (YYYY-MM-DD)")
    parser.add_argument("--end-date", type=str, default=DEFAULT_END_DATE, help="End date (YYYY-MM-DD)")
    parser.add_argument("--version", type=str, default=DEFAULT_VERSION, help="Version string to filter by")
    parser.add_argument("--output", type=str, default=OUTPUT_FILE, help="Output parquet file")

    args = parser.parse_args()

    df = fetch_embeddings(
        start_date=args.start_date,
        end_date=args.end_date,
        version=args.version,
        output_file=args.output,
    )

    print("\nDone!")