You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
120 lines
3.5 KiB
120 lines
3.5 KiB
|
3 weeks ago
|
"""Common plotting utilities for experiments."""
|
||
|
|
|
||
|
|
import matplotlib.pyplot as plt
|
||
|
|
import seaborn as sns
|
||
|
|
import numpy as np
|
||
|
|
import pandas as pd
|
||
|
|
|
||
|
|
|
||
|
|
def setup_plot_style():
    """Apply the shared default look for experiment plots.

    Switches matplotlib to the ``seaborn-v0_8-whitegrid`` style, selects the
    seaborn ``husl`` colour palette, and sets global defaults of a 12x6-inch
    figure and a 10pt font.

    This mutates global matplotlib/seaborn state, so it affects every figure
    created after the call.
    """
    # The style must be applied first: plt.style.use() overwrites rcParams,
    # so the explicit overrides below have to come after it.
    plt.style.use('seaborn-v0_8-whitegrid')
    sns.set_palette("husl")
    plt.rcParams.update({
        'figure.figsize': (12, 6),
        'font.size': 10,
    })
|
||
|
|
|
||
|
|
|
||
|
|
def plot_ic_series(ic_by_date: pd.Series, title: str = "IC Over Time",
                   figsize: tuple = (14, 4), window: int = 20,
                   min_periods: int = 5) -> plt.Figure:
    """Plot IC time series with rolling mean.

    Draws the raw IC values, a rolling mean over ``window`` observations,
    a dashed horizontal line at the overall mean IC, and a faint reference
    line at zero.

    Args:
        ic_by_date: Series with datetime index and IC values
        title: Plot title
        figsize: Figure size
        window: Number of observations in the rolling-mean window
            (default 20). Also used in the legend label ``'<window>-day MA'``.
        min_periods: Minimum number of non-NaN observations a window needs
            to produce a rolling value (default 5). It is capped at
            ``window``, because pandas rejects ``min_periods > window``.

    Returns:
        Matplotlib figure
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Plot raw IC
    ax.plot(ic_by_date.index, ic_by_date.values, alpha=0.5, color='gray',
            label='Daily IC')

    # Plot rolling mean. Positions with fewer than min_periods observations
    # come out as NaN and are not drawn.
    rolling = ic_by_date.rolling(
        window, min_periods=min(min_periods, window)).mean()
    ax.plot(rolling.index, rolling.values, color='blue', linewidth=2,
            label=f'{window}-day MA')

    # Add mean line (Series.mean skips NaN by default)
    mean_ic = ic_by_date.mean()
    ax.axhline(y=mean_ic, color='red', linestyle='--',
               label=f'Mean IC: {mean_ic:.4f}')
    # Zero reference line
    ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)

    ax.set_title(title)
    ax.set_xlabel('Date')
    ax.set_ylabel('Information Coefficient')
    ax.legend(loc='upper right')

    # Lay out this specific figure rather than whichever one is "current".
    fig.tight_layout()
    return fig
|
||
|
|
|
||
|
|
|
||
|
|
def plot_cumulative_returns(returns: pd.Series, title: str = "Cumulative Returns",
                            figsize: tuple = (12, 6)) -> plt.Figure:
    """Plot cumulative returns.

    Compounds the per-period simple returns into a growth curve
    ``prod(1 + r)`` and draws it on a logarithmic y-axis. The last valid
    point is annotated with the total return as a percentage: green if
    positive, red otherwise.

    Args:
        returns: Series with datetime index and daily returns
        title: Plot title
        figsize: Figure size

    Returns:
        Matplotlib figure. For an empty or all-NaN ``returns`` the axes are
        left without an annotation instead of raising.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # cumprod skips NaN (leaving NaN at those positions) and keeps
    # compounding afterwards.
    # NOTE(review): a return <= -100% makes the curve non-positive, which a
    # log axis cannot display.
    cumulative = (1 + returns).cumprod()
    ax.plot(cumulative.index, cumulative.values, linewidth=1.5)

    ax.set_title(title)
    ax.set_xlabel('Date')
    ax.set_ylabel('Cumulative Return')
    ax.set_yscale('log')

    # Add final return annotation at the last non-NaN point. Skip it entirely
    # when there is no valid point; otherwise iloc[-1] would raise
    # IndexError on empty input or render "nan%" on a trailing NaN.
    valid = cumulative.dropna()
    if not valid.empty:
        final_value = valid.iloc[-1]
        final_return = final_value - 1
        ax.annotate(f'{final_return:.2%}',
                    xy=(valid.index[-1], final_value),
                    xytext=(10, 0), textcoords='offset points',
                    fontsize=10, color='green' if final_return > 0 else 'red')

    # Lay out this specific figure rather than whichever one is "current".
    fig.tight_layout()
    return fig
|
||
|
|
|
||
|
|
|
||
|
|
def plot_factor_distribution(factor: pd.Series, title: str = "Factor Distribution",
                             figsize: tuple = (10, 6), bins: int = 100) -> plt.Figure:
    """Plot factor distribution with statistics.

    The left panel is a histogram annotated with the mean, standard
    deviation, skew and kurtosis. The right panel is a Q-Q plot against a
    normal distribution.

    Args:
        factor: Series of factor values. NaN and +/-inf values are excluded
            from both panels and from the statistics.
        title: Plot title
        figsize: Figure size
        bins: Number of histogram bins (default 100).

    Returns:
        Matplotlib figure with two axes: histogram (left), Q-Q plot (right).
    """
    # Local import, kept function-scoped as in the original; presumably so
    # scipy is only required when this function is used. It is done before
    # creating the figure so a missing scipy does not leave an orphan figure.
    from scipy import stats

    # Drop NaN and +/-inf once. hist() raises ValueError on non-finite
    # values, and infinities would also poison the summary statistics.
    clean = factor.replace([np.inf, -np.inf], np.nan).dropna()

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # Histogram
    axes[0].hist(clean, bins=bins, alpha=0.7, edgecolor='black')
    axes[0].set_title(f'{title} - Distribution')
    axes[0].set_xlabel('Value')
    axes[0].set_ylabel('Frequency')

    # Q-Q plot against the normal distribution
    stats.probplot(clean, dist="norm", plot=axes[1])
    axes[1].set_title(f'{title} - Q-Q Plot')

    # Add statistics text. pandas kurtosis() is excess (Fisher) kurtosis,
    # i.e. 0 for a normal distribution.
    stats_text = f"Mean: {clean.mean():.4f}\nStd: {clean.std():.4f}\n"
    stats_text += f"Skew: {clean.skew():.4f}\nKurt: {clean.kurtosis():.4f}"
    axes[0].text(0.95, 0.95, stats_text, transform=axes[0].transAxes,
                 verticalalignment='top', horizontalalignment='right',
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # Lay out this specific figure rather than whichever one is "current".
    fig.tight_layout()
    return fig
|