You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

120 lines
3.5 KiB

"""Common plotting utilities for experiments."""
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
def setup_plot_style():
    """Apply the shared default look for experiment plots.

    Switches matplotlib to the ``seaborn-v0_8-whitegrid`` style, sets the
    seaborn colour palette to ``husl``, and overrides the global default
    figure size (12 x 6 inches) and base font size (10 pt).

    This mutates process-wide matplotlib state, so it affects every figure
    created afterwards.

    NOTE: the ``seaborn-v0_8-*`` style names are the ones shipped by
    matplotlib >= 3.6; older versions will not recognise this name.
    """
    plt.style.use('seaborn-v0_8-whitegrid')
    sns.set_palette("husl")
    # Applied after the style sheet so these values take precedence over
    # anything the style itself sets for the same keys.
    plt.rcParams.update({
        'figure.figsize': (12, 6),
        'font.size': 10,
    })
def plot_ic_series(ic_by_date: pd.Series, title: str = "IC Over Time",
                   figsize: tuple = (14, 4), window: int = 20,
                   min_periods: int = 5) -> plt.Figure:
    """Plot an information-coefficient (IC) time series with its rolling mean.

    Draws the raw per-date IC in grey, a rolling mean over ``window``
    observations in blue, a dashed red line at the overall mean IC, and a
    faint zero reference line.

    Args:
        ic_by_date: Series with datetime index and IC values.
        title: Plot title.
        figsize: Figure size passed to ``plt.subplots``.
        window: Number of observations in the rolling-mean window. The
            default of 20 reproduces the previous fixed behaviour.
        min_periods: Minimum number of non-NaN observations required to emit
            a rolling value. Points before that are NaN and are not drawn.
            Clipped to ``window``, because pandas rejects
            ``min_periods > window``.

    Returns:
        Matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Raw IC, de-emphasised so the smoothed line stands out.
    ax.plot(ic_by_date.index, ic_by_date.values, alpha=0.5, color='gray',
            label='Daily IC')

    # Rolling mean; clip min_periods so a small custom window stays valid.
    rolling = ic_by_date.rolling(window, min_periods=min(min_periods, window)).mean()
    ax.plot(rolling.index, rolling.values, color='blue', linewidth=2,
            label=f'{window}-day MA')

    # Overall mean IC (pandas skips NaN by default).
    mean_ic = ic_by_date.mean()
    ax.axhline(y=mean_ic, color='red', linestyle='--',
               label=f'Mean IC: {mean_ic:.4f}')
    ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)

    ax.set_title(title)
    ax.set_xlabel('Date')
    ax.set_ylabel('Information Coefficient')
    ax.legend(loc='upper right')
    # Lay out this figure explicitly rather than whatever pyplot deems current.
    fig.tight_layout()
    return fig
def plot_cumulative_returns(returns: pd.Series, title: str = "Cumulative Returns",
                            figsize: tuple = (12, 6)) -> plt.Figure:
    """Plot compounded cumulative returns on a log-scaled y axis.

    The curve is ``(1 + returns).cumprod()``, i.e. the growth of one unit
    invested at the start. The last valid point is annotated with the total
    return as a percentage: green if positive, red otherwise.

    Args:
        returns: Series with datetime index and daily returns.
        title: Plot title.
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        Matplotlib figure. The annotation is omitted when ``returns`` is
        empty or contains no valid values.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Compound the simple returns into a growth-of-1 curve.
    cumulative = (1 + returns).cumprod()
    ax.plot(cumulative.index, cumulative.values, linewidth=1.5)
    ax.set_title(title)
    ax.set_xlabel('Date')
    ax.set_ylabel('Cumulative Return')
    ax.set_yscale('log')

    # Annotate the last *valid* point. A bare iloc[-1] raises IndexError on
    # an empty series and would label "nan%" if the final return is missing.
    valid = cumulative.dropna()
    if not valid.empty:
        last_date = valid.index[-1]
        last_value = valid.iloc[-1]
        final_return = last_value - 1
        ax.annotate(f'{final_return:.2%}',
                    xy=(last_date, last_value),
                    xytext=(10, 0), textcoords='offset points',
                    fontsize=10, color='green' if final_return > 0 else 'red')

    fig.tight_layout()
    return fig
def plot_factor_distribution(factor: pd.Series, title: str = "Factor Distribution",
                             figsize: tuple = (10, 6), bins: int = 100) -> plt.Figure:
    """Plot a factor's histogram and normal Q-Q plot with summary statistics.

    Non-finite values (NaN and +/-inf) are removed before plotting and before
    the statistics are computed. ``Axes.hist`` cannot bin infinite values,
    and NaN was already being dropped.

    Args:
        factor: Series of factor values.
        title: Plot title prefix used for both panels.
        figsize: Figure size passed to ``plt.subplots``.
        bins: Number of histogram bins. The default of 100 reproduces the
            previous fixed behaviour.

    Returns:
        Matplotlib figure with two axes: the histogram on the left and the
        Q-Q plot on the right. The Q-Q plot is skipped if no finite values
        remain.
    """
    # Local import keeps scipy an optional dependency of this module. Doing it
    # first means a missing scipy fails before a figure is created.
    from scipy import stats

    # Filter once and reuse for every panel and statistic.
    values = factor.replace([np.inf, -np.inf], np.nan).dropna()

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # Histogram
    axes[0].hist(values, bins=bins, alpha=0.7, edgecolor='black')
    axes[0].set_title(f'{title} - Distribution')
    axes[0].set_xlabel('Value')
    axes[0].set_ylabel('Frequency')

    # Q-Q plot against a normal distribution
    if not values.empty:
        stats.probplot(values, dist="norm", plot=axes[1])
    axes[1].set_title(f'{title} - Q-Q Plot')

    # Summary statistics. pandas' kurtosis() is excess (Fisher) kurtosis,
    # so a normal distribution reads as ~0.
    stats_text = f"Mean: {values.mean():.4f}\nStd: {values.std():.4f}\n"
    stats_text += f"Skew: {values.skew():.4f}\nKurt: {values.kurtosis():.4f}"
    axes[0].text(0.95, 0.95, stats_text, transform=axes[0].transAxes,
                 verticalalignment='top', horizontalalignment='right',
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    fig.tight_layout()
    return fig