You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

146 lines
5.0 KiB

# Step 1: train a VAE model on the 0_7_beta data.
# This must be run separately, as a prerequisite for the companion prediction script.
#
# Workflow configuration to train a VAE model on alpha158 0_7_beta data.
# This creates a VAE-encoded version of the 0_7_beta factors that can be used
# for prediction comparison with the original 0_7 model.
# Experiment name for this workflow.
experiment_name: vae_alpha158_0_7_beta

# Qlib initialisation settings. The children must be indented under
# `qlib_init`; at column 0 they would become unrelated top-level keys and
# `qlib_init` itself would be null.
qlib_init:
  provider_uri: "~/.qlib/data_ops/target"
  region: cn
# Shared scalars, anchored so that later sections can alias them.
# Full data span: aliased as the handler's start_time/end_time and as the `test` segment.
load_start: &load_start 2013-01-01
load_end: &load_end 2023-09-30
# Fitting window: aliased as fit_start_time/fit_end_time and as the `train` segment.
train_start: &train_start 2013-01-01
train_end: &train_end 2018-12-31
# Benchmark identifier; anchored here but not aliased in the portion of the file shown.
benchmark_name: &benchmark_name SH000985
# Instrument universe: aliased as the handler's `instruments`.
market: &market csiallx
# Pickle path: aliased as the dataset's `from_cache`.
dataset_cache_path: &dataset_cache_path tasks/artifacts/csiallx_dataset_alpha158_0_7_beta_vae.pkl
# DolphinDB configuration, anchored so that `data_handler_config` can alias it.
# The connection fields must be nested under the anchored key; otherwise the
# alias resolves to null and the fields become stray top-level keys.
# NOTE(review): the credentials below are stored in plain text. Prefer
# injecting them from the environment or a secret store instead of committing
# them to version control.
ddb_config: &ddb_config
  host: "192.168.1.146"
  port: 8848
  username: "admin"
  password: "123456"
# Handler configuration, anchored so the dataset section can pass it as the
# aggregate handler's kwargs. The nesting below is reconstructed from the key
# names because the original indentation was lost.
# NOTE(review): `infer_processors` is assumed to be a sibling of `handler_list`
# inside this mapping - confirm against the handler's signature.
data_handler_config: &data_handler_config
  start_time: *load_start
  end_time: *load_end
  fit_start_time: *train_start
  fit_end_time: *train_end
  instruments: *market
  ddb_config: *ddb_config

  # One entry per source table; each entry writes into its own `col_set`.
  handler_list:
    # Alpha158 0_7_beta features
    - class: DDBZWindDataHandler
      module_path: qlib.contrib.data.ddb_handlers.ddb_wind_handler
      kwargs:
        col_set: "feature"
        query_config:
          - db_path: "dfs://daily_stock_run"
            dtype: "float32"
            field_list: "alpha158"  # All alpha158 factors
            table_name: "stg_1day_wind_alpha158_0_7_beta"  # Use the beta version

    # Additional handlers as needed
    # Market value, exposed as `total_size` in the `risk_factor` column set.
    - class: DDBZWindDataHandler
      module_path: qlib.contrib.data.ddb_handlers.ddb_wind_handler
      kwargs:
        col_set: "risk_factor"
        query_config:
          - db_path: "dfs://daily_stock_run"
            dtype: "float32"
            field_list: ["MarketValue as total_size"]
            table_name: "stg_1day_wind_kline_adjusted"

    # Industry flags.
    - class: DDBZWindDataHandler
      module_path: qlib.contrib.data.ddb_handlers.ddb_indus_flag_handler
      kwargs:
        col_set: "indus_flag"
        query_config:
          - db_path: "dfs://daily_stock_run"
            dtype: "bool"
            field_list: "industry_code_cc.csv"
            table_name: "stg_1day_gds_indus_flag_cc1"

    # ST flags.
    - class: DDBZWindDataHandler
      module_path: qlib.contrib.data.ddb_handlers.ddb_st_flag_handler
      kwargs:
        col_set: "st_flag"
        query_config:
          - db_path: "dfs://daily_stock_run"
            dtype: "bool"
            field_list: ["ST_Y", "ST_S", "ST_T", "ST_L", "ST_Z", "ST_X"]
            table_name: "stg_1day_wind_st_flag"

  # Processors listed in the order given in the original file.
  infer_processors:
    # Reads the `indus_flag` group and writes the `indus_idx` group.
    - class: FlagToOnehot
      module_path: qlib.contrib.data.processor_flag
      kwargs:
        fields_group: indus_flag
        onehot_group: indus_idx

    # Operates on `feature` using `risk_factor` and `indus_idx`.
    # NOTE(review): presumably size/industry neutralisation, judging by
    # `ntrl_type` - confirm against the processor implementation.
    - class: FactorNtrlInjector
      module_path: qlib.contrib.data.processor_ntrl
      kwargs:
        fields_group: "feature"
        factor_col: "risk_factor"
        dummy_col: "indus_idx"
        ntrl_type: "size_indus"

    - class: RobustZScoreNorm
      kwargs:
        fields_group: ["feature"]
        clip_outlier: true

    - class: Fillna
      kwargs:
        fields_group: ["feature"]
# Task definition: model, dataset and record templates. The nesting below is
# reconstructed from the key names because the original indentation was lost.
task:
  model:
    class: VAEModel
    module_path: qlib.contrib.model.task.task_vae_flat
    kwargs:
      model_config:
        hidden_size: 32  # Same as the original model for consistency
        # NOTE(review): `nn_module` is assumed to belong to `model_config`;
        # move it up one level if VAEModel expects it as a direct kwarg.
        nn_module:
          class: VAE
          module_path: qlib.contrib.model.module.module_vae
          kwargs:
            variational: true
      optim_config:
        seed: 1234567
        bootstrap_config: 1.2
        # Floats are written with an explicit mantissa dot (`1.0e-3`, not
        # `1e-3`): YAML 1.1 loaders such as PyYAML resolve the dot-less form
        # as a string rather than a float.
        distort_config: 1.0e-3
        beta: 1.0e-3  # KL divergence weight
        n_epochs: 300
        early_stop: 10
        lr: 1.0e-3
        optimizer: adamw
        batch_size: 10000
        n_jobs: 4
        # NOTE(review): `checkpoint` is assumed to belong to `optim_config`;
        # move it up one level if VAEModel expects it as a direct kwarg.
        checkpoint:
          save_path: tasks/artifacts/checkpoints/csiallx_alpha158_0_7_beta_vae32

  dataset:
    class: DatasetH
    module_path: qlib.data.dataset
    kwargs:
      config_module: qlib.contrib.data.config
      from_cache: *dataset_cache_path
      require_setup: true
      handler:
        class: AggHandler
        module_path: qlib.contrib.data.agg_handler
        kwargs: *data_handler_config
      segments:
        train: [*train_start, *train_end]
        # The test segment spans the whole load range, so it includes the
        # training window.
        test: [*load_start, *load_end]

  record:
    - class: SignalRecord
      module_path: qlib.contrib.workflow.record_temp
      kwargs:
        model: <MODEL>
        dataset: <DATASET>
        col_set: "feature"