Visualization Module
Comprehensive ML visualization toolkit with interactive dashboards, model analysis, and real-time monitoring capabilities.
Overview
The Visualization module provides comprehensive tools for data exploration, model analysis, and interactive dashboard creation. It supports both static plots and interactive visualizations.
MLVisualizer
class MLVisualizer(style: str = 'seaborn-v0_8', figsize: Tuple[int, int] = (12, 8))
Comprehensive ML visualization toolkit.
Key Features
- Data exploration plots
- Model comparison charts
- Performance metrics visualization
- Feature importance plots
- Learning curves
Methods
plot_data_overview(df: pd.DataFrame, save_path: str = None)
Create comprehensive data overview visualization.
Parameters:
df: Input DataFrame
save_path: Path to save plot
plot_model_comparison(model_results: Dict[str, Dict[str, float]], metric: str = 'accuracy', save_path: str = None)
Compare the performance of multiple models.
Parameters:
model_results: Model results dictionary
metric: Metric to compare
save_path: Path to save plot
plot_feature_importance(model: Any, feature_names: List[str], save_path: str = None)
Plot feature importance.
Parameters:
model: Trained model
feature_names: Feature names
save_path: Path to save plot
plot_learning_curve(model: Any, X: pd.DataFrame, y: pd.Series, save_path: str = None)
Plot learning curve.
Parameters:
model: Model to evaluate
X: Feature matrix
y: Target vector
save_path: Path to save plot
DashboardBuilder
class DashboardBuilder(title: str = "ML Dashboard", layout: str = "wide")
Interactive dashboard builder for ML experiments.
Key Features
- Streamlit integration
- Real-time monitoring
- Interactive charts
- Multi-tab layouts
- Custom widgets
Methods
create_data_exploration_dashboard(df: pd.DataFrame)
Create data exploration dashboard.
Parameters:
df: Input DataFrame
create_model_evaluation_dashboard(models: Dict[str, Any], X: pd.DataFrame, y: pd.Series)
Create model evaluation dashboard.
Parameters:
models: Dictionary of trained models
X: Feature matrix
y: Target vector
run_dashboard(dashboard_type: str = "experiment", **kwargs)
Run the dashboard.
Parameters:
dashboard_type: Type of dashboard
**kwargs: Additional arguments
Examples
Data Visualization
python
from ai_ml_framework.visualization import MLVisualizer
import pandas as pd
# Load data
df = pd.read_csv('your_data.csv')
# Create visualizer
visualizer = MLVisualizer(style='seaborn-v0_8', figsize=(12, 8))
# Data overview
visualizer.plot_data_overview(df, save_path='data_overview.png')
# Distribution plots
visualizer.plot_distributions(df, save_path='distributions.png')
# Correlation heatmap
visualizer.plot_correlation_matrix(df, save_path='correlation.png')
# Missing values analysis
visualizer.plot_missing_values(df, save_path='missing_values.png')
# Outlier detection
visualizer.plot_outliers(df, save_path='outliers.png')
print("Data visualization complete!")
Model Comparison
python
# Model results
model_results = {
'Random Forest': {
'accuracy': 0.92,
'precision': 0.91,
'recall': 0.93,
'f1': 0.92
},
'XGBoost': {
'accuracy': 0.94,
'precision': 0.93,
'recall': 0.95,
'f1': 0.94
},
'LightGBM': {
'accuracy': 0.93,
'precision': 0.92,
'recall': 0.94,
'f1': 0.93
}
}
# Compare models
visualizer.plot_model_comparison(model_results, metric='accuracy')
visualizer.plot_model_comparison(model_results, metric='precision')
visualizer.plot_model_comparison(model_results, metric='recall')
visualizer.plot_model_comparison(model_results, metric='f1')
# Create radar chart for all metrics
visualizer.plot_model_radar_chart(model_results, save_path='model_radar.png')
# Feature importance
best_model = models['XGBoost']
feature_names = df.columns.drop('target')
visualizer.plot_feature_importance(best_model, feature_names, save_path='feature_importance.png')
Learning Curves
python
# Learning curves for model analysis
from sklearn.model_selection import learning_curve
# Plot learning curve
visualizer.plot_learning_curve(
best_model,
X_train,
y_train,
save_path='learning_curve.png'
)
# Plot validation curve
visualizer.plot_validation_curve(
best_model,
X_train,
y_train,
param_name='n_estimators',
param_range=[50, 100, 150, 200],
save_path='validation_curve.png'
)
# Plot confusion matrix
visualizer.plot_confusion_matrix(
best_model,
X_test,
y_test,
save_path='confusion_matrix.png'
)
# Plot ROC curve
visualizer.plot_roc_curve(
best_model,
X_test,
y_test,
save_path='roc_curve.png'
)
Interactive Dashboard
python
from ai_ml_framework.visualization import DashboardBuilder
# Create dashboard builder
dashboard = DashboardBuilder(
title="ML Experiment Dashboard",
layout="wide"
)
# Create data exploration dashboard
dashboard.create_data_exploration_dashboard(df)
# Create model evaluation dashboard
dashboard.create_model_evaluation_dashboard(models, X_test, y_test)
# Create real-time monitoring dashboard
dashboard.create_real_time_monitoring_dashboard('metrics.db')
# Add custom widgets
dashboard.add_custom_metric_widget("Custom Metric", lambda: 0.95)
dashboard.add_alert_widget("Performance Alert", threshold=0.8)
# Run dashboard
dashboard.run_dashboard(
dashboard_type="experiment",
port=8501,
host="0.0.0.0"
)
print("Dashboard running at http://localhost:8501")
Advanced Visualizations
python
# Advanced visualization techniques
# SHAP values for model interpretability
visualizer.plot_shap_values(
best_model,
X_test,
feature_names,
save_path='shap_values.png'
)
# Partial dependence plots
visualizer.plot_partial_dependence(
best_model,
X_test,
feature_names,
save_path='partial_dependence.png'
)
# Residual plots for regression
visualizer.plot_residuals(
regression_model,
X_test,
y_test,
save_path='residuals.png'
)
# Time series visualization
if time_series_data:
visualizer.plot_time_series(
time_series_data,
save_path='time_series.png'
)
# Clustering visualization
if clustering_labels:
visualizer.plot_clusters(
X_pca,
clustering_labels,
save_path='clusters.png'
)
# Create comprehensive report
visualizer.create_visualization_report(
df=df,
models=models,
X_test=X_test,
y_test=y_test,
output_path='ml_report.html'
)
print("Advanced visualizations complete!")
Real-time Monitoring
python
# Real-time monitoring dashboard
import time
import sqlite3
# Create monitoring database
conn = sqlite3.connect('metrics.db')
cursor = conn.cursor()
# Create metrics table
cursor.execute('''
CREATE TABLE IF NOT EXISTS metrics (
timestamp TEXT,
model_name TEXT,
accuracy REAL,
latency REAL,
requests INTEGER
)
''')
# Simulate real-time data
def simulate_metrics():
import random
return {
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
'model_name': 'XGBoost',
'accuracy': random.uniform(0.85, 0.95),
'latency': random.uniform(10, 50),
'requests': random.randint(100, 1000)
}
# Insert sample data
for _ in range(100):
metrics = simulate_metrics()
cursor.execute('''
INSERT INTO metrics VALUES (?, ?, ?, ?, ?)
''', (metrics['timestamp'], metrics['model_name'],
metrics['accuracy'], metrics['latency'], metrics['requests']))
conn.commit()
# Create real-time dashboard
dashboard = DashboardBuilder()
dashboard.create_real_time_monitoring_dashboard('metrics.db')
# Add alerts
dashboard.add_alert_widget(
"Low Accuracy Alert",
lambda: cursor.execute('SELECT accuracy FROM metrics ORDER BY timestamp DESC LIMIT 1').fetchone()[0] < 0.90
)
dashboard.add_alert_widget(
"High Latency Alert",
lambda: cursor.execute('SELECT latency FROM metrics ORDER BY timestamp DESC LIMIT 1').fetchone()[0] > 40
)
# Run monitoring dashboard
dashboard.run_dashboard(
dashboard_type="monitoring",
port=8502,
refresh_interval=5
)
print("Monitoring dashboard running at http://localhost:8502")