You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

212 lines
7.1 KiB

#!/usr/bin/env python
"""
Fetch original 0_7 predictions from DolphinDB and save to parquet.
This script:
1. Connects to DolphinDB
2. Queries the app_1day_multicast_longsignal_port table
3. Filters for version 'host140_exp20_d033'
4. Transforms columns (m_nDate -> datetime, code -> instrument)
5. Saves to local parquet file
"""
import os
import polars as pl
import pandas as pd
from datetime import datetime
from typing import Optional
# DolphinDB config (from CLAUDE.md)
# NOTE(review): credentials are hard-coded; username/password are not actually
# passed to get_ddb_sess() below (only host/port are used) — confirm whether
# they are needed at all, and consider environment variables if so.
DDB_CONFIG = {
    "host": "192.168.1.146",
    "port": 8848,
    "username": "admin",
    "password": "123456"
}
# DFS path of the partitioned table holding the daily long-signal portfolio.
TABLE_PATH = "dfs://daily_stock_run_multicast/app_1day_multicast_longsignal_port"
# Version prefix identifying the 0_7 model's rows (matched via startswith).
VERSION = "host140_exp20_d033"
# Default parquet destination, relative to this script's working directory.
OUTPUT_FILE = "../data/original_predictions_0_7.parquet"
def datetime_to_uint32(dt) -> int:
    """Convert a datetime-like value into its YYYYMMDD integer encoding.

    Numeric inputs are passed through (truncated to int). Any object
    exposing ``strftime`` (datetime, date, pandas Timestamp) is formatted
    as YYYYMMDD. Anything else is coerced with ``int()``.
    """
    if isinstance(dt, (int, float)):
        return int(dt)
    formatter = getattr(dt, 'strftime', None)
    if formatter is not None:
        return int(formatter('%Y%m%d'))
    return int(dt)
def tscode_to_uint32(code) -> int:
    """Convert a TS code (e.g., '000001.SZ') to a uint32 instrument code.

    Integer inputs are returned unchanged. Strings keep only the part
    before the first '.' (dropping the exchange suffix); parsing as int
    also strips leading zeros.
    """
    if isinstance(code, int):
        return code
    numeric_part, _sep, _suffix = str(code).partition('.')
    return int(numeric_part)
def fetch_original_predictions(
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    output_file: str = OUTPUT_FILE
) -> pl.DataFrame:
    """
    Fetch original 0_7 predictions from DolphinDB.

    Loads the full partitioned table, then filters client-side for rows
    whose ``version`` starts with VERSION and (optionally) for the given
    date window, renames/casts columns to the canonical schema, and
    writes the result to parquet.

    Args:
        start_date: Optional start date filter (YYYY-MM-DD)
        end_date: Optional end date filter (YYYY-MM-DD)
        output_file: Output parquet file path

    Returns:
        Polars DataFrame with columns: [datetime, instrument, prediction]

    Raises:
        Exception: re-raised after logging on connection or query failure.
    """
    print("Fetching original 0_7 predictions from DolphinDB...")
    print(f"Table: {TABLE_PATH}")
    print(f"Version: {VERSION}")
    # Connect to DolphinDB
    try:
        from qshare.io.ddb import get_ddb_sess
        sess = get_ddb_sess(host=DDB_CONFIG["host"], port=DDB_CONFIG["port"])
        print(f"Connected to DolphinDB at {DDB_CONFIG['host']}:{DDB_CONFIG['port']}")
    except Exception as e:
        print(f"Error connecting to DolphinDB: {e}")
        raise
    # Split the dfs:// URI into database path and table name for loadTable().
    db_path, table_name = TABLE_PATH.replace("dfs://", "").split("/", 1)
    sql = f"""
    select * from loadTable("dfs://{db_path}", "{table_name}")
    """
    # Filtering happens in Python after loading, since DolphinDB's SQL
    # syntax for partitioned tables can be tricky.
    print(f"Executing SQL: {sql.strip()}")
    try:
        # Execute query and get pandas DataFrame
        df_full = sess.run(sql)
        print(f"Fetched {len(df_full)} total rows from DolphinDB")
        print(f"Columns: {df_full.columns.tolist()}")
        print(f"Sample:\n{df_full.head()}")
        print(f"Version values: {df_full['version'].unique()[:10] if 'version' in df_full.columns else 'N/A'}")
        # The stored version string carries extra parameters beyond the base
        # version id, so match by prefix rather than equality.
        if 'version' in df_full.columns:
            df_pd = df_full[df_full['version'].str.startswith(VERSION)]
            print(f"Filtered to {len(df_pd)} rows for version '{VERSION}'")
            if len(df_pd) > 0:
                print(f"Matching versions: {df_pd['version'].unique()[:5]}")
        else:
            print("Warning: 'version' column not found, using all data")
            df_pd = df_full
        # Apply date filters if specified. m_nDate is datetime64, so compare
        # against parsed timestamps (not YYYYMMDD ints).
        if start_date and 'm_nDate' in df_pd.columns:
            start_dt = pd.to_datetime(start_date)
            df_pd = df_pd[df_pd['m_nDate'] >= start_dt]
        if end_date and 'm_nDate' in df_pd.columns:
            end_dt = pd.to_datetime(end_date)
            df_pd = df_pd[df_pd['m_nDate'] <= end_dt]
        print(f"After date filter: {len(df_pd)} rows")
    except Exception as e:
        print(f"Error executing query: {e}")
        raise
    finally:
        sess.close()
    # Convert to Polars
    df = pl.from_pandas(df_pd)
    print(f"Columns in result: {df.columns}")
    print(f"Sample data:\n{df.head()}")
    # Rename m_nDate -> datetime and normalize to uint32 YYYYMMDD, whatever
    # dtype DolphinDB handed back.
    df = df.rename({"m_nDate": "datetime"})
    if df["datetime"].dtype in (pl.Datetime, pl.Date):
        # Datetime and Date take the same formatting path.
        df = df.with_columns([
            pl.col("datetime").dt.strftime("%Y%m%d").cast(pl.UInt32).alias("datetime")
        ])
    elif df["datetime"].dtype in [pl.Utf8, pl.String]:
        # BUG FIX: str.replace drops only the FIRST hyphen ("2024-01-02" ->
        # "202401-02"), which makes the UInt32 cast fail; replace_all removes
        # every hyphen.
        df = df.with_columns([
            pl.col("datetime").str.replace_all("-", "").cast(pl.UInt32).alias("datetime")
        ])
    else:
        # Already numeric, just cast
        df = df.with_columns([pl.col("datetime").cast(pl.UInt32).alias("datetime")])
    # Rename code -> instrument. Codes look like "SH600085" / "SZ000001":
    # strip the exchange prefix and cast the numeric remainder to uint32.
    df = df.rename({"code": "instrument"})
    df = df.with_columns([
        pl.col("instrument")
        .str.replace("SH", "")
        .str.replace("SZ", "")
        .str.replace("BJ", "")
        .cast(pl.UInt32)
        .alias("instrument")
    ])
    # The prediction column is 'weight' in this table; rename for consistency.
    if 'weight' in df.columns:
        df = df.rename({'weight': 'prediction'})
    else:
        # Fallback: first float column that isn't datetime or instrument.
        for col in df.columns:
            if col not in ['datetime', 'instrument'] and df[col].dtype in [pl.Float32, pl.Float64]:
                df = df.rename({col: 'prediction'})
                break
    # Select only the columns we need
    df = df.select(["datetime", "instrument", "prediction"])
    print(f"\nTransformed data:")
    print(f" Shape: {df.shape}")
    print(f" Columns: {df.columns}")
    print(f" Date range: {df['datetime'].min()} to {df['datetime'].max()}")
    print(f" Sample:\n{df.head()}")
    # Save to parquet. BUG FIX: guard the makedirs call — for a bare filename
    # os.path.dirname() returns "" and os.makedirs("") raises FileNotFoundError.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.write_parquet(output_file)
    print(f"\nSaved to: {output_file}")
    return df
if __name__ == "__main__":
    import argparse

    # CLI wrapper: forward the optional date window and output path to the
    # fetcher, then report completion.
    cli = argparse.ArgumentParser(description="Fetch original 0_7 predictions from DolphinDB")
    cli.add_argument("--start-date", type=str, default=None, help="Start date (YYYY-MM-DD)")
    cli.add_argument("--end-date", type=str, default=None, help="End date (YYYY-MM-DD)")
    cli.add_argument("--output", type=str, default=OUTPUT_FILE, help="Output parquet file")
    opts = cli.parse_args()

    df = fetch_original_predictions(
        start_date=opts.start_date,
        end_date=opts.end_date,
        output_file=opts.output,
    )
    print("\nDone!")