---
# Workflow configuration to train a VAE model on alpha158 0_7_beta data.
#
# This creates a VAE-encoded version of the 0_7_beta factors that can be used
# for prediction comparison with the original 0_7 model.
#
# The VAE has to be trained separately, before the prediction comparison:
# it is a prerequisite for the prediction script that uses its output.

# Experiment identifier for this workflow.
experiment_name: vae_alpha158_0_7_beta

# Data provider location and market region.
qlib_init:
  provider_uri: "~/.qlib/data_ops/target"
  region: cn

# Shared values. Each one is anchored so other sections can alias it
# instead of repeating the literal.
#
# The dates are intentionally left as plain (unquoted) scalars, exactly as
# before; a loader that resolves timestamps will yield date objects, not
# strings.

# Full range of data to load.
load_start: &load_start 2013-01-01
load_end: &load_end 2023-09-30

# Range used for fitting / training.
train_start: &train_start 2013-01-01
train_end: &train_end 2018-12-31

benchmark_name: &benchmark_name SH000985
market: &market csiallx

dataset_cache_path: &dataset_cache_path tasks/artifacts/csiallx_dataset_alpha158_0_7_beta_vae.pkl

# DolphinDB configuration
# NOTE(review): host and credentials are committed here in plain text.
# Consider supplying them from an environment variable or a secret store.
ddb_config: &ddb_config
  host: 192.168.1.146
  port: 8848
  username: "admin"
  password: "123456"

# Handler settings: time ranges, instrument universe, DolphinDB connection,
# the list of source handlers, and the inference-time processors.
# NOTE(review): nesting was reconstructed from a flattened copy of this file.
# In particular, `query_config` is assumed to sit under each handler's
# `kwargs` next to `col_set` -- confirm against the handler signatures.
data_handler_config: &data_handler_config
  start_time: *load_start
  end_time: *load_end
  fit_start_time: *train_start
  fit_end_time: *train_end
  instruments: *market
  ddb_config: *ddb_config
  handler_list:
    # Alpha158 0_7_beta features
    - class: DDBZWindDataHandler
      module_path: qlib.contrib.data.ddb_handlers.ddb_wind_handler
      kwargs:
        col_set: "feature"
        query_config:
          - db_path: "dfs://daily_stock_run"
            dtype: "float32"
            field_list: "alpha158"  # All alpha158 factors
            table_name: "stg_1day_wind_alpha158_0_7_beta"  # Use the beta version
    # Additional handlers as needed
    # Market value, exposed as `total_size` in the "risk_factor" column set.
    - class: DDBZWindDataHandler
      module_path: qlib.contrib.data.ddb_handlers.ddb_wind_handler
      kwargs:
        col_set: "risk_factor"
        query_config:
          - db_path: "dfs://daily_stock_run"
            dtype: "float32"
            field_list: ["MarketValue as total_size"]
            table_name: "stg_1day_wind_kline_adjusted"
    # Industry flags.
    - class: DDBZWindDataHandler
      module_path: qlib.contrib.data.ddb_handlers.ddb_indus_flag_handler
      kwargs:
        col_set: "indus_flag"
        query_config:
          - db_path: "dfs://daily_stock_run"
            dtype: "bool"
            field_list: "industry_code_cc.csv"
            table_name: "stg_1day_gds_indus_flag_cc1"
    # ST flags.
    - class: DDBZWindDataHandler
      module_path: qlib.contrib.data.ddb_handlers.ddb_st_flag_handler
      kwargs:
        col_set: "st_flag"
        query_config:
          - db_path: "dfs://daily_stock_run"
            dtype: "bool"
            field_list: ["ST_Y", "ST_S", "ST_T", "ST_L", "ST_Z", "ST_X"]
            table_name: "stg_1day_wind_st_flag"
  # Processors are listed in order: flag -> one-hot, neutralisation,
  # robust z-score normalisation, then NaN filling.
  infer_processors:
    - class: FlagToOnehot
      module_path: qlib.contrib.data.processor_flag
      kwargs:
        fields_group: indus_flag
        onehot_group: indus_idx
    - class: FactorNtrlInjector
      module_path: qlib.contrib.data.processor_ntrl
      kwargs:
        fields_group: "feature"
        factor_col: "risk_factor"
        dummy_col: "indus_idx"
        ntrl_type: "size_indus"
    - class: RobustZScoreNorm
      kwargs:
        fields_group: ["feature"]
        clip_outlier: true
    - class: Fillna
      kwargs:
        fields_group: ["feature"]

# Training task: model, dataset and recorders.
# NOTE(review): nesting was reconstructed from a flattened copy of this file.
# Two placements are assumptions to confirm against VAEModel's signature:
#   - `nn_module` is placed inside `model_config`;
#   - `checkpoint` is placed inside `optim_config`.
task:
  model:
    class: VAEModel
    module_path: qlib.contrib.model.task.task_vae_flat
    kwargs:
      model_config:
        hidden_size: 32  # Same as the original model for consistency
        nn_module:
          class: VAE
          module_path: qlib.contrib.model.module.module_vae
          kwargs:
            variational: true
      optim_config:
        seed: 1234567
        bootstrap_config: 1.2
        # Floats are written with an explicit mantissa dot (1.0e-3): YAML 1.1
        # loaders such as PyYAML resolve a bare `1e-3` as a string.
        distort_config: 1.0e-3
        beta: 1.0e-3  # KL divergence weight
        n_epochs: 300
        early_stop: 10
        lr: 1.0e-3
        optimizer: adamw
        batch_size: 10000
        n_jobs: 4
        checkpoint:
          save_path: tasks/artifacts/checkpoints/csiallx_alpha158_0_7_beta_vae32
  dataset:
    class: DatasetH
    module_path: qlib.data.dataset
    kwargs:
      config_module: qlib.contrib.data.config
      from_cache: *dataset_cache_path
      require_setup: true
      handler:
        class: AggHandler
        module_path: qlib.contrib.data.agg_handler
        kwargs: *data_handler_config
      segments:
        # Train on the fitting window; "test" spans the full loaded range.
        train: [*train_start, *train_end]
        test: [*load_start, *load_end]
  record:
    - class: SignalRecord
      module_path: qlib.contrib.workflow.record_temp
      kwargs:
        model: <MODEL>
        dataset: <DATASET>
        col_set: "feature"