
53_Matplotlib_And_Seaborn_For_Data_Visualization

Category: AI & Data Science Tools
Type: AI/ML Tool or Library
Generated on: 2025-08-26 11:08:14
For: Data Science, Machine Learning & Technical Interviews


Matplotlib & Seaborn Cheatsheet for Data Visualization


This cheatsheet provides a comprehensive guide to Matplotlib and Seaborn, focusing on their use in AI/ML workflows.

1. Tool/Library Overview

  • Matplotlib: Python’s foundational plotting library. Provides a low-level interface for creating static, animated, and interactive visualizations. Highly customizable, but can be verbose.
    • Use Cases: Basic plotting, creating custom visualizations, building plots from scratch, embedding plots in applications.
  • Seaborn: A high-level statistical data visualization library built on top of Matplotlib. Simplifies creating informative and aesthetically pleasing statistical graphics.
    • Use Cases: Exploring relationships between variables, visualizing distributions, creating categorical plots, enhancing Matplotlib plots with statistical context.

2. Installation & Setup

# Install Matplotlib
pip install matplotlib
# Install Seaborn
pip install seaborn

Basic Import:

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np # For example data

3. Core Features & API

3.1 Matplotlib

  • plt.figure(): Creates a new figure (the overall canvas).
    • figsize=(width, height): Sets the figure size in inches.
    • dpi: Sets dots per inch (resolution).
  • plt.subplot() / plt.subplots(): Adds one or more subplots to the figure.
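    • plt.subplot(nrows, ncols, index): index is 1-based and counts left to right, then top to bottom.
    • fig, axes = plt.subplots(nrows, ncols): Returns the figure and its axes. With the default squeeze=True, axes is a single Axes for a 1×1 grid, a 1D array for a single row or column, and a 2D array otherwise.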
  • Plotting functions:
    • plt.plot(x, y, format_string, **kwargs): Creates a line plot.
      • x, y: Data for the x and y axes.
      • format_string: Controls line style, marker, and color (e.g., ‘r-’, ‘bo--’).
      • label: Label for the plot, used in legends.
      • linewidth: Line width.
      • alpha: Transparency.
    • plt.scatter(x, y, s=size, c=color, marker=marker, alpha=alpha): Creates a scatter plot.
      • s: Marker size.
      • c: Marker color (can be a single color or an array for a colormap).
      • marker: Marker style (e.g., ‘o’, ‘x’, ‘^’).
    • plt.bar(x, height, width=0.8, bottom=None, align='center', **kwargs): Creates a bar chart.
      • x: X-coordinates of the bars.
      • height: Heights of the bars.
      • width: Width of the bars.
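    • plt.hist(x, bins=num_bins, range=(xmin, xmax), density=False, **kwargs): Creates a histogram.
      • x: Data to be histogrammed.
      • bins: Number of bins, a sequence of bin edges, or a strategy name such as ‘auto’.
      • density: If True, scales the bar heights so the total bar area is 1 (a probability density). The default, False, plots raw counts.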
    • plt.imshow(data, cmap='viridis', aspect='auto', interpolation='nearest'): Displays an image or 2D array.
      • data: 2D array of pixel values.
      • cmap: Colormap (e.g., ‘viridis’, ‘gray’, ‘jet’).
      • aspect: Aspect ratio (‘auto’, ‘equal’).
      • interpolation: Interpolation method (e.g., ‘nearest’, ‘bilinear’).
  • Axes customization:
    • plt.xlabel(label): Sets the x-axis label.
    • plt.ylabel(label): Sets the y-axis label.
    • plt.title(title): Sets the plot title.
    • plt.xlim(xmin, xmax): Sets the x-axis limits.
    • plt.ylim(ymin, ymax): Sets the y-axis limits.
    • plt.xticks(ticks, labels): Sets x-axis tick locations and labels.
    • plt.yticks(ticks, labels): Sets y-axis tick locations and labels.
    • ax.spines['top'].set_visible(False): Removes the top spine of an axis (similar for ‘right’, ‘bottom’, ‘left’).
    • ax.tick_params(axis='x', labelsize=10): Customize tick appearance.
  • Legends and annotations:
    • plt.legend(loc='best'): Adds a legend to the plot.
      • loc: Legend location (e.g., ‘best’, ‘upper right’, ‘lower left’).
    • plt.text(x, y, text, fontsize=12, ha='center', va='center'): Adds text to the plot.
      • x, y: Coordinates of the text.
      • ha: Horizontal alignment.
      • va: Vertical alignment.
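    • plt.annotate(text, xy=(x, y), xytext=(x_text, y_text), arrowprops=dict(arrowstyle='->')): Adds an annotation with an arrow running from the text at xytext to the point xy. Both positions are in data coordinates by default. To give xytext as an offset from xy instead, set textcoords='offset points'.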
  • Saving plots:
    • plt.savefig('filename.png', dpi=300, bbox_inches='tight'): Saves the plot to a file.
      • dpi: Resolution.
      • bbox_inches='tight': Removes extra whitespace around the plot.
  • Displaying plots:
    • plt.show(): Displays the plot.
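The pyplot functions above act on the "current" figure and axes. Below is a minimal sketch of the same operations through the object-oriented interface, using explicit fig and ax objects, which scales better to multi-panel layouts. The function name make_overview_figure and the sample data are illustrative only and are not part of Matplotlib.

```python
import io

import matplotlib.pyplot as plt
import numpy as np


def make_overview_figure():
    """Build a 1x2 figure: a line/scatter panel and a density histogram."""
    rng = np.random.default_rng(0)  # seeded so the figure is reproducible
    x = np.linspace(0, 10, 100)

    # With one row, squeeze=True (the default) makes `axes` a 1D array of Axes
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4), dpi=100)
    ax_line, ax_hist = axes

    # Left panel: blue dashed line plus a sparse red triangle scatter overlay
    ax_line.plot(x, np.sin(x), 'b--', linewidth=2, label='sin(x)')
    ax_line.scatter(x[::10], np.sin(x[::10]), s=40, c='red', marker='^',
                    label='samples')
    ax_line.set_xlabel('x')      # OO counterpart of plt.xlabel
    ax_line.set_ylabel('y')      # OO counterpart of plt.ylabel
    ax_line.set_title('Line and scatter')
    ax_line.set_xlim(0, 10)      # OO counterpart of plt.xlim
    ax_line.legend(loc='lower left')
    # xy is the point being labelled; xytext is where the text is drawn (data coords)
    ax_line.annotate('peak', xy=(np.pi / 2, 1.0), xytext=(3, 0.6),
                     arrowprops=dict(arrowstyle='->'))
    ax_line.spines['top'].set_visible(False)
    ax_line.spines['right'].set_visible(False)

    # Right panel: histogram normalized to a probability density
    ax_hist.hist(rng.normal(size=1000), bins=30, density=True,
                 color='lightgreen', edgecolor='black')
    ax_hist.set_title('Density histogram')
    ax_hist.tick_params(axis='x', labelsize=10)

    fig.tight_layout()
    return fig, axes


fig, axes = make_overview_figure()
buffer = io.BytesIO()  # in-memory target; a path such as 'overview.png' works the same way
fig.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
```

In an interactive session, follow this with plt.show(). Passing explicit Axes around is also what lets Seaborn functions draw into a chosen subplot through their ax= parameter.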

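3.2 Seaborn

Since seaborn 0.12, x, y, hue and most other parameters of the plotting functions are keyword-only, so write calls such as sns.barplot(data=df, x='category', y='value'). The positional-looking signatures below are shorthand for those keyword arguments.
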
  • Distribution Plots:
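    • sns.distplot(x, bins=None, kde=True, rug=False, hist=True): Combines a histogram, KDE (Kernel Density Estimate), and rug plot. Deprecated since seaborn 0.11. Use sns.histplot(..., kde=True) for a single axes or sns.displot(...) for a figure-level plot instead.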
      • x: Data to be plotted.
      • bins: Number of bins in the histogram.
      • kde: Whether to plot the KDE.
      • rug: Whether to plot the rug plot.
      • hist: Whether to plot the histogram.
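    • sns.kdeplot(x, fill=True): Creates a kernel density estimate plot. fill=True shades the area under the curve; it replaces the shade parameter, which was deprecated in seaborn 0.11.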
    • sns.rugplot(x): Creates a rug plot (marks data points along the axis).
  • Categorical Plots:
    • sns.barplot(x, y, data, estimator=np.mean, errorbar=('ci', 95), order=None): Creates a bar plot showing an estimate (the mean by default) of a numerical variable for each category, with error bars.
      • x: Column name for categorical variable.
      • y: Column name for numerical variable.
      • data: DataFrame.
      • estimator: Function to use for estimating the central tendency (default: np.mean).
      • errorbar: Error-bar method, e.g., ('ci', 95) (the default), ‘sd’, or None to disable. Replaces the older ci parameter, which was deprecated in seaborn 0.12.
      • order: Order to plot the categories.
    • sns.countplot(x, data, order=None): Creates a count plot showing the number of occurrences of each category.
    • sns.boxplot(x, y, data, hue=None, orient='v'): Creates a box plot showing the distribution of a variable for different categories.
      • hue: Another categorical variable to split the data by.
      • orient: Orientation (‘v’ for vertical, ‘h’ for horizontal).
    • sns.violinplot(x, y, data, hue=None, split=False, inner='quartile'): Creates a violin plot, similar to a box plot but showing the KDE.
      • split: If True, splits the violins for each hue.
      • inner: Representation of the data within the violin (‘quartile’, ‘box’, ‘point’, ‘stick’).
    • sns.stripplot(x, y, data, jitter=True, dodge=False): Creates a scatter plot showing the distribution of data points for each category.
      • jitter: Adds random noise to the points to avoid overplotting.
      • dodge: If True, separates points for each hue.
    • sns.swarmplot(x, y, data, dodge=False): Creates a swarm plot, similar to a strip plot but positions points to avoid overlapping.
  • Relational Plots:
    • sns.scatterplot(x, y, data, hue=None, size=None, style=None): Creates a scatter plot to visualize the relationship between two variables.
      • hue: Another variable to color the points by.
      • size: Another variable to size the points by.
      • style: Another variable to change the marker style by.
    • sns.lineplot(x, y, data, hue=None, style=None): Creates a line plot to visualize the relationship between two variables.
  • Regression Plots:
    • sns.regplot(x, y, data, ci=95, scatter=True, fit_reg=True): Creates a scatter plot with a regression line.
      • ci: Confidence interval size.
      • scatter: Whether to plot the scatter points.
      • fit_reg: Whether to fit the regression line.
    • sns.lmplot(x, y, data, hue=None, col=None, row=None): Creates a regression plot with faceting (splitting the plot into multiple subplots).
      • col: Variable whose levels each get their own column of subplots.
      • row: Variable whose levels each get their own row of subplots.
  • Matrix Plots:
    • sns.heatmap(data, annot=True, cmap='viridis', fmt=".2f"): Creates a heatmap to visualize a matrix of values.
      • data: 2D array or DataFrame. A DataFrame's index and column names are used as the tick labels.
      • annot: Whether to display the values in each cell.
      • cmap: Colormap.
      • fmt: Format string for the annotations.
    • sns.clustermap(data, cmap='viridis', standard_scale=None, method='ward'): Creates a clustered heatmap.
      • standard_scale: 0 or 1. Rescales each row (0) or each column (1) to the [0, 1] range (min-max scaling). Use z_score instead for zero-mean, unit-variance standardization.
      • method: Linkage method for clustering (e.g., ‘ward’, ‘average’, ‘complete’).
    • sns.pairplot(data, hue=None, diag_kind='hist'): Creates a matrix of scatter plots showing the relationships between all pairs of variables in a DataFrame.
      • hue: Variable to color the points by.
      • diag_kind: Kind of plot to show on the diagonal (‘hist’, ‘kde’).
  • Style and Aesthetics:
    • sns.set_style(style, rc=None): Sets the overall style of the plots.
      • style: ‘darkgrid’, ‘whitegrid’, ‘dark’, ‘white’, ‘ticks’.
      • rc: Dictionary of Matplotlib parameters to override.
    • sns.set_palette(palette, n_colors=None, desat=None): Sets the color palette.
      • palette: ‘deep’, ‘muted’, ‘pastel’, ‘bright’, ‘dark’, ‘colorblind’, or a list of colors.
    • sns.despine(fig=None, ax=None, top=True, right=True, left=False, bottom=False, offset=None, trim=False): Removes spines from the plot.
    • sns.color_palette(palette, n_colors=None, desat=None): Returns a list of colors from a palette.

4. Practical Examples

4.1 Matplotlib Examples

# Example Data
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
# Basic Line Plot
plt.figure(figsize=(8, 6)) # Set figure size
plt.plot(x, y1, label='sin(x)')
plt.plot(x, y2, label='cos(x)')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Sine and Cosine Waves')
plt.legend()
plt.grid(True) # Add grid lines
plt.savefig('sine_cosine.png')
plt.show()

Output: A plot showing sine and cosine waves.

# Scatter Plot
plt.figure(figsize=(8, 6))
plt.scatter(x, y1, c='red', marker='o', label='sin(x) points')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.title('Scatter Plot of Sine Wave')
plt.legend()
plt.show()

Output: A scatter plot of the sine wave.

# Bar Chart
categories = ['A', 'B', 'C', 'D']
values = [25, 40, 30, 50]
plt.figure(figsize=(8, 6))
plt.bar(categories, values, color='skyblue')
plt.xlabel('Category')
plt.ylabel('Value')
plt.title('Bar Chart of Values')
plt.show()

Output: A bar chart showing the values for each category.

# Histogram
data = np.random.randn(1000) # Generate random data
plt.figure(figsize=(8, 6))
plt.hist(data, bins=30, color='lightgreen', edgecolor='black')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.title('Histogram of Random Data')
plt.show()

Output: A histogram of the random data.
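With density=True, the bar heights are rescaled so that the total area (height times bin width) equals 1. This makes the histogram directly comparable with a probability density function. The sketch below checks that property and overlays the standard normal PDF, which is written out by hand here rather than taken from a statistics library.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
samples = rng.standard_normal(1000)

fig, ax = plt.subplots(figsize=(8, 6))
# hist returns the bar heights, the bin edges, and the drawn bar artists
heights, edges, _ = ax.hist(samples, bins=30, density=True,
                            color='lightgreen', edgecolor='black', label='samples')

# Standard normal PDF: exp(-x^2 / 2) / sqrt(2*pi)
grid = np.linspace(edges[0], edges[-1], 200)
ax.plot(grid, np.exp(-grid**2 / 2) / np.sqrt(2 * np.pi), 'k-', label='N(0, 1) pdf')
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.legend()

# With density=True the bars integrate to 1 over the data range
area = np.sum(heights * np.diff(edges))
print(f'total area = {area:.3f}')  # prints "total area = 1.000"
```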

# Image Plot
image_data = np.random.rand(100, 100) # Generate random image data
plt.figure(figsize=(8, 6))
plt.imshow(image_data, cmap='gray')
plt.colorbar()
plt.title('Random Image')
plt.show()

Output: A grayscale image of random noise.

4.2 Seaborn Examples

# Example DataFrame
data = pd.DataFrame({
    'category': ['A', 'A', 'B', 'B', 'C', 'C'],
    'value': [10, 15, 20, 25, 30, 35],
    'group': ['X', 'Y', 'X', 'Y', 'X', 'Y']
})
# Bar Plot
sns.barplot(x='category', y='value', data=data, hue='group')
plt.title('Bar Plot with Seaborn')
plt.show()

Output: A bar plot showing the values for each category, separated by group.

# Scatter Plot
sns.scatterplot(x='value', y='category', data=data, hue='group', size='value')
plt.title('Scatter Plot with Seaborn')
plt.show()

Output: A scatter plot showing the relationship between value and category, colored by group and sized by value.

# Distribution Plot
sns.displot(data['value'], kde=True)
plt.title('Distribution Plot with Seaborn')
plt.show()

Output: A distribution plot showing the distribution of the ‘value’ column.

# Heatmap (Correlation Matrix)
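# 'value' is the only numeric column here, so this is a trivial 1x1 matrix (always 1.0);
# pass a DataFrame with several numeric columns to see pairwise correlations
correlation_matrix = data[['value']].corr()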
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm')
plt.title('Correlation Heatmap')
plt.show()

Output: A heatmap of the correlation matrix.
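The correlation example above uses a single column, so its matrix has only one cell. The sketch below uses several synthetic numeric columns (hypothetical features built to be correlated, anti-correlated, and unrelated). Because a correlation matrix is symmetric, it masks the redundant upper triangle.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(1)
n = 200
base = rng.normal(size=n)
# Hypothetical features: 'b' tracks 'a', 'c' mirrors it, 'd' is unrelated noise
features = pd.DataFrame({
    'a': base,
    'b': base + rng.normal(scale=0.5, size=n),
    'c': -base + rng.normal(scale=0.5, size=n),
    'd': rng.normal(size=n),
})

corr = features.corr()  # pairwise Pearson correlations
# True above the diagonal: those cells duplicate the lower triangle and are hidden
mask = np.triu(np.ones_like(corr, dtype=bool), k=1)

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(corr, mask=mask, annot=True, fmt='.2f', cmap='coolwarm',
            vmin=-1, vmax=1, square=True, ax=ax)
ax.set_title('Correlation Heatmap')
```

Fixing vmin=-1 and vmax=1 keeps the color scale centered on zero, so colors are comparable across different heatmaps.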

# Pair Plot
iris = sns.load_dataset('iris') # load iris dataset
sns.pairplot(iris, hue='species')
plt.show()

Output: A matrix of scatter plots showing the relationships between all pairs of variables in the Iris dataset, colored by species.

5. Advanced Usage

  • Customizing Matplotlib Plots with Seaborn Styles:
sns.set_style('whitegrid') # Set seaborn style
fig, ax = plt.subplots(figsize=(8, 6)) # Create a matplotlib figure and axes
ax.plot(x, y1, label='sin(x)') # Plot with matplotlib
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Sine Wave with Seaborn Style')
ax.legend()
plt.show()
  • Creating Facetted Plots with Seaborn:
tips = sns.load_dataset('tips')
sns.catplot(x='day', y='total_bill', hue='sex', col='smoker', data=tips, kind='bar')
plt.show()
  • Using Custom Color Palettes:
custom_palette = ['#e41a1c', '#377eb8', '#4daf4a'] # Define a custom color palette
sns.set_palette(custom_palette) # Set the palette
sns.barplot(x='category', y='value', data=data, hue='group')
plt.show()
  • Working with Subplots:
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 8))
# Plotting on each subplot
sns.histplot(data['value'], ax=axes[0, 0], kde=True)
sns.scatterplot(x='value', y='category', data=data, hue='group', ax=axes[0, 1])
sns.barplot(x='category', y='value', data=data, hue='group', ax=axes[1, 0])
sns.boxplot(x='category', y='value', data=data, hue='group', ax=axes[1, 1])
plt.tight_layout() # Adjust subplot parameters for a tight layout.
plt.show()
  • Performance Considerations:

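    • Prepare data with vectorized NumPy/pandas operations instead of Python loops, and pass whole arrays to a single plotting call rather than plotting point by point.
    • Limit the number of points in scatter plots to avoid overplotting, e.g., by subsampling or using small, semi-transparent markers. For very large datasets, aggregate instead, e.g., with a hexbin plot (plt.hexbin).
    • Choose histogram bins deliberately. Too few bins hide structure, and too many turn the plot into noise. bins='auto' lets NumPy pick a data-driven count.
    • When saving plots, choose a suitable file format and resolution (DPI). Vector formats (e.g., SVG, PDF) scale without pixelation, but they grow large and slow to render with many thousands of markers. For dense plots, save as PNG, or set rasterized=True on the dense artist so the rest of the figure stays vector.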
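The large-data advice above can be sketched with synthetic data, contrasting two approaches. hexbin aggregates points into hexagonal cells colored by count. rasterized=True keeps a dense scatter layer as a bitmap inside an otherwise vector PDF. The 100,000-point size and gridsize=50 are arbitrary choices for illustration.

```python
import io

import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=100_000)
y = x + rng.normal(scale=0.5, size=100_000)

fig, (ax_hex, ax_scatter) = plt.subplots(1, 2, figsize=(12, 5))

# hexbin: one polygon per occupied cell instead of 100,000 markers
hb = ax_hex.hexbin(x, y, gridsize=50, cmap='viridis', mincnt=1)
fig.colorbar(hb, ax=ax_hex, label='count')
ax_hex.set_title('hexbin')

# Small, transparent markers; rasterized=True embeds this layer as a bitmap
ax_scatter.scatter(x, y, s=1, alpha=0.1, rasterized=True)
ax_scatter.set_title('rasterized scatter')

# Text and axes stay vector in the PDF; only the scatter layer is rasterized at dpi
pdf_buffer = io.BytesIO()
fig.savefig(pdf_buffer, format='pdf', dpi=150)
```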

6. Tips & Tricks

  • Tab Completion: Use tab completion in your IDE to explore the available functions and parameters.
  • Help Documentation: Use help(plt.plot) or help(sns.barplot) to access the documentation for a specific function.
  • Online Examples: Search online for examples of specific plot types to get inspiration and code snippets.
  • Customizing Styles: Create custom style sheets to apply consistent styling to your plots. You can load these with plt.style.use('my_custom_style.mplstyle').
  • Interactive Plots: For zoomable, hover-enabled plots, consider Plotly or Bokeh. Both accept the same pandas DataFrames and NumPy arrays you pass to Matplotlib and Seaborn.
  • Use plt.tight_layout() to prevent labels from overlapping.
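  • plt.clf() clears the current figure, plt.cla() clears the current axes, and plt.close(fig) closes a figure and frees its memory. Closing matters when creating many figures in a loop.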
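plt.style.use changes styling for the rest of the session. plt.style.context applies a style (a built-in name, an .mplstyle path, or a dict of rcParams) only inside a with block. The sketch below combines it with plt.close in a loop that saves several figures. The rc values and the sine_*.png file names are arbitrary, and the files go to a temporary directory.

```python
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# rc overrides applied only inside the with-block (a dict works like a .mplstyle file)
custom_style = {'axes.grid': True, 'lines.linewidth': 2.5, 'font.size': 12}

out_dir = Path(tempfile.mkdtemp())
x = np.linspace(0, 10, 100)

# Styles in the list are applied left to right, so custom_style overrides 'ggplot'
with plt.style.context(['ggplot', custom_style]):
    for freq in (1, 2, 3):
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.plot(x, np.sin(freq * x))
        ax.set_title(f'sin({freq}x)')
        fig.savefig(out_dir / f'sine_{freq}.png')
        plt.close(fig)  # without this, every figure stays open and holds memory

# Leaving the with-block restores the previous rcParams
```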

7. Integration

  • Pandas: Seaborn is designed to work seamlessly with Pandas DataFrames. Most Seaborn plotting functions accept a DataFrame as input and column names as arguments.
  • NumPy: Both Matplotlib and Seaborn accept NumPy arrays directly as plotting input.
  • Scikit-learn: Visualize model performance and results using Matplotlib and Seaborn.
# Example: Visualizing Scikit-learn model results
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
# Sample data (the Iris dataset, reloaded so this block stands alone)
iris = sns.load_dataset('iris')
X, y = iris.drop('species', axis=1), iris['species']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# Train a logistic regression model
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)
# Predict on the test set
y_pred = model.predict(X_test)
# Create a confusion matrix: rows are actual classes, columns predicted, both in `labels` order
labels = iris['species'].unique()
cm = confusion_matrix(y_test, y_pred, labels=labels)
# Visualize the confusion matrix using Seaborn, labelling the axes with the same class order
sns.heatmap(cm, annot=True, cmap='Blues', xticklabels=labels, yticklabels=labels)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()
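Raw counts are hard to compare when classes have different sizes. The sketch below row-normalizes a confusion matrix with pandas.crosstab(normalize='index'), so each cell is the fraction of an actual class that received each predicted label. The y_true and y_pred values are made-up placeholders standing in for y_test and y_pred from a fitted model.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical labels standing in for y_test / y_pred from a fitted model
y_true = pd.Series(['cat', 'cat', 'cat', 'dog', 'dog', 'bird', 'bird', 'bird', 'bird', 'dog'],
                   name='Actual')
y_pred = pd.Series(['cat', 'dog', 'cat', 'dog', 'dog', 'bird', 'cat', 'bird', 'bird', 'dog'],
                   name='Predicted')

# normalize='index' divides each row by its total, so every row sums to 1
cm_norm = pd.crosstab(y_true, y_pred, normalize='index')

fig, ax = plt.subplots(figsize=(5, 4))
# The Series names become the axis labels via the crosstab's index/column names
sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues', vmin=0, vmax=1, ax=ax)
ax.set_title('Row-normalized Confusion Matrix')
```

The diagonal then reads directly as per-class recall: here 0.75 for bird, about 0.67 for cat, and 1.00 for dog.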

8. Further Resources