Visualization Module

Comprehensive ML visualization toolkit with interactive dashboards, model analysis, and real-time monitoring capabilities.

Overview

The Visualization module provides tools for data exploration, model analysis, and interactive dashboard creation. It supports both static plots and interactive visualizations.

Basic Usage

from ai_ml_framework.visualization import MLVisualizer

visualizer = MLVisualizer()
visualizer.plot_data_overview(df)
visualizer.plot_model_comparison(results)

Interactive Dashboard

from ai_ml_framework.visualization import DashboardBuilder

dashboard = DashboardBuilder()
dashboard.create_data_exploration_dashboard(df)
dashboard.run_dashboard()

MLVisualizer

class MLVisualizer(style: str = 'seaborn-v0_8', figsize: Tuple[int, int] = (12, 8))

Comprehensive ML visualization toolkit. Constructor options are illustrated in the sketch after the feature list below.

Key Features

  • Data exploration plots
  • Model comparison charts
  • Performance metrics visualization
  • Feature importance plots
  • Learning curves
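
The style argument is presumably a standard Matplotlib style name and figsize a (width, height) tuple in inches. A quick sketch for checking which styles your installation offers; the alternative configuration at the end is hypothetical:

from ai_ml_framework.visualization import MLVisualizer
import matplotlib.pyplot as plt

# List the style names Matplotlib knows about on this installation;
# 'seaborn-v0_8' is the default assumed by MLVisualizer
print(plt.style.available)

# A hypothetical alternative: different style, larger canvas
visualizer = MLVisualizer(style='ggplot', figsize=(16, 10))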

Methods

plot_data_overview(df: pd.DataFrame, save_path: Optional[str] = None)

Create comprehensive data overview visualization.

Parameters:
  • df: Input DataFrame
  • save_path: Optional path to save the plot
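
For orientation, a rough hand-rolled version of the same overview using plain pandas; this is a sketch of the usual recipe, not the framework's internals:

import pandas as pd

df = pd.read_csv('your_data.csv')

# Manual overview: shape, dtypes, summary statistics, per-column histograms
print(df.shape)
print(df.dtypes)
print(df.describe(include='all'))
df.hist(figsize=(12, 8))  # one histogram per numeric column
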
plot_model_comparison(model_results: Dict[str, Dict[str, float]], metric: str = 'accuracy', save_path: Optional[str] = None)

Compare the performance of multiple models on a chosen metric.

Parameters:
  • model_results: Model results dictionary
  • metric: Name of the metric to compare; must be a key in each model's result dictionary
  • save_path: Optional path to save the plot
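
model_results maps each model name to a metric-name-to-score dictionary; the fuller Model Comparison example below uses the same shape. A minimal call, reusing the visualizer from Basic Usage:

results = {
    'Random Forest': {'accuracy': 0.92, 'f1': 0.92},
    'XGBoost': {'accuracy': 0.94, 'f1': 0.94},
}
visualizer.plot_model_comparison(results, metric='f1', save_path='f1_comparison.png')
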
plot_feature_importance(model: Any, feature_names: List[str], save_path: Optional[str] = None)

Plot feature importance.

Parameters:
  • model: Trained model (typically one exposing feature_importances_ or coef_)
  • feature_names: Feature names, in the order the model was trained on
  • save_path: Optional path to save the plot
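
A minimal standalone sketch of the same idea, assuming the model exposes importances the way scikit-learn tree ensembles do (via feature_importances_):

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Toy data and model; tree ensembles expose feature_importances_
X, y = make_classification(n_samples=200, n_features=5, random_state=0)
names = [f'feature_{i}' for i in range(5)]
model = RandomForestClassifier(random_state=0).fit(X, y)

# The same information plot_feature_importance presumably visualizes
plt.barh(names, model.feature_importances_)
plt.xlabel('importance')
plt.tight_layout()
plt.show()
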
plot_learning_curve(model: Any, X: pd.DataFrame, y: pd.Series, save_path: Optional[str] = None)

Plot learning curve.

Parameters:
  • model: Model to evaluate
  • X: Feature matrix
  • y: Target vector
  • save_path: Optional path to save the plot
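
A learning curve tracks training and validation scores as the training set grows. A standalone sketch built directly on scikit-learn's learning_curve, which plot_learning_curve presumably wraps in some form:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import learning_curve

X, y = make_classification(n_samples=500, random_state=0)
model = RandomForestClassifier(random_state=0)

# Scores at increasing training-set sizes, averaged over 5 CV folds
sizes, train_scores, val_scores = learning_curve(
    model, X, y, cv=5, train_sizes=np.linspace(0.1, 1.0, 5)
)
plt.plot(sizes, train_scores.mean(axis=1), label='train')
plt.plot(sizes, val_scores.mean(axis=1), label='validation')
plt.xlabel('training examples')
plt.ylabel('score')
plt.legend()
plt.show()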

DashboardBuilder

class DashboardBuilder(title: str = "ML Dashboard", layout: str = "wide")

Interactive dashboard builder for ML experiments.

Key Features

  • Streamlit integration
  • Real-time monitoring
  • Interactive charts
  • Multi-tab layouts
  • Custom widgets

Methods

create_data_exploration_dashboard(df: pd.DataFrame)

Create data exploration dashboard.

Parameters:
  • df: Input DataFrame
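
Given the Streamlit integration, this presumably assembles widgets like the ones below. A minimal hand-rolled equivalent (save as app.py and launch with streamlit run app.py):

import pandas as pd
import streamlit as st

df = pd.read_csv('your_data.csv')

st.title('Data Exploration')
st.dataframe(df.describe())

# Let the user pick a numeric column and chart its distribution
column = st.selectbox('Column', df.select_dtypes('number').columns)
st.bar_chart(df[column].value_counts().sort_index())
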
create_model_evaluation_dashboard(models: Dict[str, Any], X: pd.DataFrame, y: pd.Series)

Create model evaluation dashboard.

Parameters:
  • models: Dictionary of trained models
  • X: Feature matrix
  • y: Target vector

run_dashboard(dashboard_type: str = "experiment", **kwargs)

Run the dashboard.

Parameters:
  • dashboard_type: Type of dashboard to launch ("experiment" or "monitoring", as in the examples below)
  • **kwargs: Additional arguments such as port, host, and refresh_interval (see the examples)

Examples

Data Visualization

python
from ai_ml_framework.visualization import MLVisualizer
import pandas as pd

# Load data
df = pd.read_csv('your_data.csv')

# Create visualizer
visualizer = MLVisualizer(style='seaborn-v0_8', figsize=(12, 8))

# Data overview
visualizer.plot_data_overview(df, save_path='data_overview.png')

# Distribution plots
visualizer.plot_distributions(df, save_path='distributions.png')

# Correlation heatmap
visualizer.plot_correlation_matrix(df, save_path='correlation.png')

# Missing values analysis
visualizer.plot_missing_values(df, save_path='missing_values.png')

# Outlier detection
visualizer.plot_outliers(df, save_path='outliers.png')

print("Data visualization complete!")

Model Comparison

python
# Model results
model_results = {
    'Random Forest': {
        'accuracy': 0.92,
        'precision': 0.91,
        'recall': 0.93,
        'f1': 0.92
    },
    'XGBoost': {
        'accuracy': 0.94,
        'precision': 0.93,
        'recall': 0.95,
        'f1': 0.94
    },
    'LightGBM': {
        'accuracy': 0.93,
        'precision': 0.92,
        'recall': 0.94,
        'f1': 0.93
    }
}

# Compare models
visualizer.plot_model_comparison(model_results, metric='accuracy')
visualizer.plot_model_comparison(model_results, metric='precision')
visualizer.plot_model_comparison(model_results, metric='recall')
visualizer.plot_model_comparison(model_results, metric='f1')

# Create radar chart for all metrics
visualizer.plot_model_radar_chart(model_results, save_path='model_radar.png')

# Feature importance
# models is a dict of fitted estimators (e.g. {'XGBoost': xgb_clf, ...})
# trained earlier; it holds the models themselves, not the metrics above
best_model = models['XGBoost']
feature_names = df.columns.drop('target').tolist()
visualizer.plot_feature_importance(best_model, feature_names, save_path='feature_importance.png')
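
If you need the same comparison outside the framework, a grouped bar chart over model_results takes only a few lines of plain Matplotlib (a standalone sketch):

python
import numpy as np
import matplotlib.pyplot as plt

# One group of bars per metric, one bar per model
metrics = ['accuracy', 'precision', 'recall', 'f1']
names = list(model_results)
x = np.arange(len(metrics))
width = 0.8 / len(names)

for i, name in enumerate(names):
    scores = [model_results[name][m] for m in metrics]
    plt.bar(x + i * width, scores, width, label=name)

plt.xticks(x + width * (len(names) - 1) / 2, metrics)
plt.ylim(0.8, 1.0)
plt.legend()
plt.savefig('model_comparison_manual.png')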

Learning Curves

python
# Learning curves for model analysis
# (X_train/y_train and X_test/y_test come from an earlier train/test split)

# Plot learning curve
visualizer.plot_learning_curve(
    best_model, 
    X_train, 
    y_train,
    save_path='learning_curve.png'
)

# Plot validation curve
visualizer.plot_validation_curve(
    best_model,
    X_train,
    y_train,
    param_name='n_estimators',
    param_range=[50, 100, 150, 200],
    save_path='validation_curve.png'
)

# Plot confusion matrix
visualizer.plot_confusion_matrix(
    best_model,
    X_test,
    y_test,
    save_path='confusion_matrix.png'
)

# Plot ROC curve
visualizer.plot_roc_curve(
    best_model,
    X_test,
    y_test,
    save_path='roc_curve.png'
)
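
scikit-learn's display helpers produce the same two diagnostics directly from a fitted estimator, which is handy for cross-checking the framework's plots:

python
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, RocCurveDisplay

# Confusion matrix straight from the fitted estimator
ConfusionMatrixDisplay.from_estimator(best_model, X_test, y_test)
plt.savefig('confusion_matrix_sklearn.png')

# ROC curve (assumes a binary classifier)
RocCurveDisplay.from_estimator(best_model, X_test, y_test)
plt.savefig('roc_curve_sklearn.png')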

Interactive Dashboard

python
from ai_ml_framework.visualization import DashboardBuilder

# Create dashboard builder
dashboard = DashboardBuilder(
    title="ML Experiment Dashboard",
    layout="wide"
)

# Create data exploration dashboard (df, models, X_test, and y_test are
# reused from the earlier examples)
dashboard.create_data_exploration_dashboard(df)

# Create model evaluation dashboard
dashboard.create_model_evaluation_dashboard(models, X_test, y_test)

# Create real-time monitoring dashboard
dashboard.create_real_time_monitoring_dashboard('metrics.db')

# Add custom widgets
dashboard.add_custom_metric_widget("Custom Metric", lambda: 0.95)
dashboard.add_alert_widget("Performance Alert", threshold=0.8)

# run_dashboard() presumably blocks while serving, so announce the URL first
print("Dashboard running at http://localhost:8501")

dashboard.run_dashboard(
    dashboard_type="experiment",
    port=8501,
    host="0.0.0.0"
)

Advanced Visualizations

python
# Advanced visualization techniques

# SHAP values for model interpretability
visualizer.plot_shap_values(
    best_model, 
    X_test, 
    feature_names,
    save_path='shap_values.png'
)

# Partial dependence plots
visualizer.plot_partial_dependence(
    best_model,
    X_test,
    feature_names,
    save_path='partial_dependence.png'
)

# Residual plots for regression
visualizer.plot_residuals(
    regression_model,
    X_test,
    y_test,
    save_path='residuals.png'
)

# Time series visualization (assumes time_series_data was prepared earlier;
# check against None rather than truth-testing the object)
if time_series_data is not None:
    visualizer.plot_time_series(
        time_series_data,
        save_path='time_series.png'
    )

# Clustering visualization (assumes X_pca and clustering_labels exist;
# truth-testing a NumPy array raises, so check against None)
if clustering_labels is not None:
    visualizer.plot_clusters(
        X_pca,
        clustering_labels,
        save_path='clusters.png'
    )

# Create comprehensive report
visualizer.create_visualization_report(
    df=df,
    models=models,
    X_test=X_test,
    y_test=y_test,
    output_path='ml_report.html'
)

print("Advanced visualizations complete!")

Real-time Monitoring

python
# Real-time monitoring dashboard
import random
import sqlite3
import time

# Create monitoring database
conn = sqlite3.connect('metrics.db')
cursor = conn.cursor()

# Create metrics table
cursor.execute('''
    CREATE TABLE IF NOT EXISTS metrics (
        timestamp TEXT,
        model_name TEXT,
        accuracy REAL,
        latency REAL,
        requests INTEGER
    )
''')

# Simulate real-time data
def simulate_metrics():
    return {
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        'model_name': 'XGBoost',
        'accuracy': random.uniform(0.85, 0.95),
        'latency': random.uniform(10, 50),
        'requests': random.randint(100, 1000)
    }

# Insert sample data (rows created within the same second share a
# timestamp, so later queries order by rowid instead)
for _ in range(100):
    metrics = simulate_metrics()
    cursor.execute('''
        INSERT INTO metrics VALUES (?, ?, ?, ?, ?)
    ''', (metrics['timestamp'], metrics['model_name'], 
          metrics['accuracy'], metrics['latency'], metrics['requests']))

conn.commit()

# Create real-time dashboard
dashboard = DashboardBuilder()
dashboard.create_real_time_monitoring_dashboard('metrics.db')

# Add alerts
dashboard.add_alert_widget(
    "Low Accuracy Alert",
    lambda: cursor.execute('SELECT accuracy FROM metrics ORDER BY rowid DESC LIMIT 1').fetchone()[0] < 0.90
)

dashboard.add_alert_widget(
    "High Latency Alert",
    lambda: cursor.execute('SELECT latency FROM metrics ORDER BY rowid DESC LIMIT 1').fetchone()[0] > 40
)

# Run monitoring dashboard (run_dashboard() presumably blocks while serving)
print("Monitoring dashboard running at http://localhost:8502")

dashboard.run_dashboard(
    dashboard_type="monitoring",
    port=8502,
    refresh_interval=5
)