#!/usr/bin/env python
"""
Fetch original 0_7 predictions from DolphinDB and save to parquet.

This script:
1. Connects to DolphinDB
2. Queries the app_1day_multicast_longsignal_port table
3. Filters for version 'host140_exp20_d033'
4. Transforms columns (m_nDate -> datetime, code -> instrument)
5. Saves to local parquet file
"""
|
import os
|
||
|
|
import polars as pl
|
||
|
|
import pandas as pd
|
||
|
|
from datetime import datetime
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
# DolphinDB config (from CLAUDE.md)
|
||
|
|
DDB_CONFIG = {
|
||
|
|
"host": "192.168.1.146",
|
||
|
|
"port": 8848,
|
||
|
|
"username": "admin",
|
||
|
|
"password": "123456"
|
||
|
|
}
|
||
|
|
|
||
|
|
TABLE_PATH = "dfs://daily_stock_run_multicast/app_1day_multicast_longsignal_port"
|
||
|
|
VERSION = "host140_exp20_d033"
|
||
|
|
OUTPUT_FILE = "../data/original_predictions_0_7.parquet"
|
||
|
|
|
||
|
|
|
||
|
|
def datetime_to_uint32(dt) -> int:
    """Convert datetime to YYYYMMDD uint32 format.

    Ints and floats are passed through (truncated); objects exposing
    ``strftime`` (datetime/date) are formatted as YYYYMMDD; anything else
    is coerced with ``int()``.
    """
    if isinstance(dt, (int, float)):
        return int(dt)
    try:
        formatted = dt.strftime('%Y%m%d')
    except AttributeError:
        # No strftime available: fall back to plain integer coercion.
        return int(dt)
    return int(formatted)
|
def tscode_to_uint32(code) -> int:
    """Convert TS code (e.g., '000001.SZ') to uint32 instrument code."""
    if isinstance(code, int):
        return code
    # Drop the '.XX' exchange suffix; int() strips any leading zeros.
    numeric_part, _, _ = str(code).partition('.')
    return int(numeric_part)
|
def _filter_rows(df_full, start_date, end_date):
    """Filter the raw pandas result to the target VERSION and optional date range."""
    # Version string contains additional parameters, so match by prefix.
    if 'version' in df_full.columns:
        df_pd = df_full[df_full['version'].str.startswith(VERSION)]
        print(f"Filtered to {len(df_pd)} rows for version '{VERSION}'")
        if len(df_pd) > 0:
            print(f"Matching versions: {df_pd['version'].unique()[:5]}")
    else:
        print("Warning: 'version' column not found, using all data")
        df_pd = df_full

    # m_nDate is datetime64; compare against parsed timestamps directly.
    if start_date and 'm_nDate' in df_pd.columns:
        df_pd = df_pd[df_pd['m_nDate'] >= pd.to_datetime(start_date)]
    if end_date and 'm_nDate' in df_pd.columns:
        df_pd = df_pd[df_pd['m_nDate'] <= pd.to_datetime(end_date)]

    print(f"After date filter: {len(df_pd)} rows")
    return df_pd


def _transform_predictions(df: pl.DataFrame) -> pl.DataFrame:
    """Normalize a raw polars frame to [datetime, instrument, prediction]."""
    # Rename m_nDate -> datetime and convert to uint32 YYYYMMDD.
    df = df.rename({"m_nDate": "datetime"})
    dt_dtype = df["datetime"].dtype
    if dt_dtype == pl.Datetime or dt_dtype == pl.Date:
        df = df.with_columns(
            pl.col("datetime").dt.strftime("%Y%m%d").cast(pl.UInt32).alias("datetime")
        )
    elif dt_dtype in (pl.Utf8, pl.String):
        # BUG FIX: str.replace only removes the FIRST '-'; replace_all is
        # required so 'YYYY-MM-DD' becomes 'YYYYMMDD' before the cast.
        df = df.with_columns(
            pl.col("datetime").str.replace_all("-", "").cast(pl.UInt32).alias("datetime")
        )
    else:
        # Already numeric, just cast.
        df = df.with_columns(pl.col("datetime").cast(pl.UInt32).alias("datetime"))

    # Rename code -> instrument; codes look like 'SH600085' / 'SZ000001',
    # so strip the exchange prefix and cast the digits to uint32.
    df = df.rename({"code": "instrument"})
    df = df.with_columns(
        pl.col("instrument")
        .str.replace("SH", "")
        .str.replace("SZ", "")
        .str.replace("BJ", "")
        .cast(pl.UInt32)
        .alias("instrument")
    )

    # The prediction column is 'weight' in this table; rename for consistency.
    if 'weight' in df.columns:
        df = df.rename({'weight': 'prediction'})
    else:
        # Fallback: first float column that isn't datetime/instrument.
        for col in df.columns:
            if col not in ['datetime', 'instrument'] and df[col].dtype in [pl.Float32, pl.Float64]:
                df = df.rename({col: 'prediction'})
                break

    return df.select(["datetime", "instrument", "prediction"])


def fetch_original_predictions(
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    output_file: str = OUTPUT_FILE
) -> pl.DataFrame:
    """
    Fetch original 0_7 predictions from DolphinDB.

    Args:
        start_date: Optional start date filter (YYYY-MM-DD)
        end_date: Optional end date filter (YYYY-MM-DD)
        output_file: Output parquet file path

    Returns:
        Polars DataFrame with columns: [datetime, instrument, prediction]

    Raises:
        Exception: re-raised from the DolphinDB connect or query on failure.
    """
    print("Fetching original 0_7 predictions from DolphinDB...")
    print(f"Table: {TABLE_PATH}")
    print(f"Version: {VERSION}")

    # Connect to DolphinDB (import is local: qshare is an optional runtime dep).
    try:
        from qshare.io.ddb import get_ddb_sess
        sess = get_ddb_sess(host=DDB_CONFIG["host"], port=DDB_CONFIG["port"])
        print(f"Connected to DolphinDB at {DDB_CONFIG['host']}:{DDB_CONFIG['port']}")
    except Exception as e:
        print(f"Error connecting to DolphinDB: {e}")
        raise

    # Split "dfs://db/table" into database path and table name for loadTable.
    db_path, table_name = TABLE_PATH.replace("dfs://", "").split("/", 1)

    # Pull everything and filter in Python, since DolphinDB's SQL syntax
    # for partitioned tables can be tricky.
    sql = f"""
    select * from loadTable("dfs://{db_path}", "{table_name}")
    """
    print(f"Executing SQL: {sql.strip()}")

    try:
        # Execute query and get a pandas DataFrame back.
        df_full = sess.run(sql)
        print(f"Fetched {len(df_full)} total rows from DolphinDB")
        print(f"Columns: {df_full.columns.tolist()}")
        print(f"Sample:\n{df_full.head()}")
        print(f"Version values: {df_full['version'].unique()[:10] if 'version' in df_full.columns else 'N/A'}")

        df_pd = _filter_rows(df_full, start_date, end_date)
    except Exception as e:
        print(f"Error executing query: {e}")
        raise
    finally:
        sess.close()

    # Convert to Polars and normalize the schema.
    df = pl.from_pandas(df_pd)
    print(f"Columns in result: {df.columns}")
    print(f"Sample data:\n{df.head()}")

    df = _transform_predictions(df)

    print(f"\nTransformed data:")
    print(f"  Shape: {df.shape}")
    print(f"  Columns: {df.columns}")
    print(f"  Date range: {df['datetime'].min()} to {df['datetime'].max()}")
    print(f"  Sample:\n{df.head()}")

    # Save to parquet; guard makedirs against bare filenames whose
    # dirname is '' (os.makedirs('') raises FileNotFoundError).
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.write_parquet(output_file)
    print(f"\nSaved to: {output_file}")

    return df
|
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(
        description="Fetch original 0_7 predictions from DolphinDB"
    )
    cli.add_argument("--start-date", type=str, default=None, help="Start date (YYYY-MM-DD)")
    cli.add_argument("--end-date", type=str, default=None, help="End date (YYYY-MM-DD)")
    cli.add_argument("--output", type=str, default=OUTPUT_FILE, help="Output parquet file")
    opts = cli.parse_args()

    df = fetch_original_predictions(
        start_date=opts.start_date,
        end_date=opts.end_date,
        output_file=opts.output,
    )

    print("\nDone!")