53_Matplotlib_And_Seaborn_For_Data_Visualization
Category: AI & Data Science Tools
Type: AI/ML Tool or Library
Generated on: 2025-08-26 11:08:14
For: Data Science, Machine Learning & Technical Interviews
Matplotlib & Seaborn Cheatsheet for Data Visualization
This cheatsheet provides a comprehensive guide to Matplotlib and Seaborn, focusing on their use in AI/ML workflows.
1. Tool/Library Overview
- Matplotlib: Python’s foundational plotting library. Provides a low-level interface for creating static, animated, and interactive visualizations. Highly customizable, but can be verbose.
- Use Cases: Basic plotting, creating custom visualizations, building plots from scratch, embedding plots in applications.
- Seaborn: A high-level statistical data visualization library built on top of Matplotlib. Simplifies creating informative and aesthetically pleasing statistical graphics.
- Use Cases: Exploring relationships between variables, visualizing distributions, creating categorical plots, enhancing Matplotlib plots with statistical context.
2. Installation & Setup
```bash
# Install Matplotlib
pip install matplotlib
# Install Seaborn
pip install seaborn
```

Basic Import:

```python
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np  # For example data
```

3. Core Features & API
3.1 Matplotlib
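- plt.figure(): Creates a new figure (the overall canvas).
  - figsize=(width, height): Sets the figure size in inches.
  - dpi: Sets dots per inch (resolution).
- plt.subplot() / plt.subplots(): Add subplots to a figure.
  - plt.subplot(nrows, ncols, index): Adds one subplot to the current figure at position index (1-based, counted row by row) in an nrows-by-ncols grid, and returns its Axes.
  - fig, axes = plt.subplots(nrows, ncols): Creates a new figure with a full grid of subplots and returns the figure and a NumPy array of Axes objects (a single Axes when nrows = ncols = 1, because squeeze=True by default).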
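Here is a minimal sketch of how figures and Axes relate. It assumes Matplotlib 3.x, and the grid sizes and titles are arbitrary. It shows three things:
- plt.subplots() returns a 2-D NumPy array of Axes for a grid, but a single Axes for a 1x1 layout.
- squeeze=False always forces an array.
- The older plt.subplot(nrows, ncols, index) form uses a 1-based index.

```python
import matplotlib.pyplot as plt
import numpy as np

# A 2x2 grid: axes is a 2-D NumPy array, indexed as axes[row, col]
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8, 6), dpi=100)

# A single subplot: squeeze=True (the default) returns one Axes, not an array
fig1, ax = plt.subplots()

# squeeze=False always returns a 2-D array, handy when the grid size is computed
fig2, axes2 = plt.subplots(1, 3, squeeze=False)

# State-machine style: subplot(nrows, ncols, index) with a 1-based index
plt.figure(figsize=(6, 3))
left = plt.subplot(1, 2, 1)
right = plt.subplot(1, 2, 2)

# Iterate over every panel of the grid without caring about its shape
for i, a in enumerate(axes.flat):
    a.set_title(f"panel {i}")
```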
- Plotting functions:
  - plt.plot(x, y, format_string, **kwargs): Creates a line plot.
    - x, y: Data for the x and y axes.
    - format_string: Shorthand for color, marker, and line style (e.g., 'r-' for a solid red line, 'bo--' for blue circle markers joined by a dashed line).
    - label: Label for the plot, used in legends.
    - linewidth: Line width.
    - alpha: Transparency (0 is fully transparent, 1 is opaque).
  - plt.scatter(x, y, s=size, c=color, marker=marker, alpha=alpha): Creates a scatter plot.
    - s: Marker size (in points squared).
    - c: Marker color (a single color, or an array of values mapped through a colormap).
    - marker: Marker style (e.g., 'o', 'x', '^').
  - plt.bar(x, height, width=0.8, bottom=None, align='center', **kwargs): Creates a bar chart.
    - x: X-coordinates of the bars.
    - height: Heights of the bars.
    - width: Width of the bars.
  - plt.hist(x, bins=num_bins, range=(xmin, xmax), density=True, **kwargs): Creates a histogram.
    - x: Data to be histogrammed.
    - bins: Number of bins (or a sequence of bin edges).
    - range: Lower and upper limits of the bins; values outside it are ignored.
    - density: If True, normalizes the histogram so the total bar area is 1 (a probability density). The default, False, plots raw counts.
  - plt.imshow(data, cmap='viridis', aspect='auto', interpolation='nearest'): Displays an image or 2D array.
    - data: 2D array of pixel values.
    - cmap: Colormap (e.g., 'viridis', 'gray', 'jet').
    - aspect: Aspect ratio ('auto', 'equal').
    - interpolation: Interpolation method (e.g., 'nearest', 'bilinear').
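The sketch below calls each plotting function above once, on synthetic NumPy data invented for illustration. It also checks one property of density=True: the bar areas (height times bin width) sum to 1, so the heights are densities rather than counts.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)  # seeded so the synthetic data is reproducible
x = np.linspace(0, 10, 50)

fig, axes = plt.subplots(2, 3, figsize=(12, 6))
ax_line, ax_scatter, ax_bar, ax_hist, ax_img, ax_spare = axes.flat

# Line plot: 'bo--' = blue, circle markers, dashed line
ax_line.plot(x, np.sin(x), 'bo--', label='sin(x)', linewidth=1, alpha=0.7)

# Scatter: an array passed to c is mapped through the colormap
sc = ax_scatter.scatter(x, np.cos(x), s=20, c=x, cmap='viridis', marker='^')
fig.colorbar(sc, ax=ax_scatter)

# Bar chart with categorical x positions and an explicit width
bars = ax_bar.bar(['A', 'B', 'C'], [3, 7, 5], width=0.6)

# Histogram: density=True rescales heights so the total bar area is 1
samples = rng.normal(size=1000)
counts, edges, patches = ax_hist.hist(samples, bins=30, range=(-4, 4), density=True)
area = np.sum(counts * np.diff(edges))

# imshow: a 2-D array rendered as a grayscale image
img = ax_img.imshow(rng.random((20, 20)), cmap='gray', aspect='equal',
                    interpolation='nearest')

ax_spare.axis('off')  # unused panel
fig.tight_layout()
```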
- Axes customization:
  - plt.xlabel(label): Sets the x-axis label.
  - plt.ylabel(label): Sets the y-axis label.
  - plt.title(title): Sets the plot title.
  - plt.xlim(xmin, xmax): Sets the x-axis limits.
  - plt.ylim(ymin, ymax): Sets the y-axis limits.
  - plt.xticks(ticks, labels): Sets x-axis tick locations and labels.
  - plt.yticks(ticks, labels): Sets y-axis tick locations and labels.
  - ax.spines['top'].set_visible(False): Hides the top spine (border line) of an Axes object ax; the same works for 'right', 'bottom', and 'left'.
  - ax.tick_params(axis='x', labelsize=10): Customizes tick appearance.
  - With an Axes object (e.g., from plt.subplots()), the equivalents are ax.set_xlabel, ax.set_ylabel, ax.set_title, ax.set_xlim, ax.set_ylim, ax.set_xticks, and ax.set_yticks.
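The same customizations through an Axes object, as a minimal sketch. It assumes Matplotlib 3.5 or later, because the labels argument of set_xticks was added in 3.5. The plotted curve and tick labels are arbitrary.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(x, x ** 2)

# OO equivalents of plt.xlabel / plt.ylabel / plt.title / plt.xlim / plt.ylim
ax.set_xlabel('x')
ax.set_ylabel('x squared')
ax.set_title('Axes customization')
ax.set_xlim(0, 10)
ax.set_ylim(0, 100)

# Tick positions plus custom labels (labels= needs Matplotlib >= 3.5)
ax.set_xticks([0, 5, 10], labels=['start', 'middle', 'end'])

# Hide the top and right spines and set the x tick label size
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.tick_params(axis='x', labelsize=10)
```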
- Legends and annotations:
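  - plt.legend(loc='best'): Adds a legend built from the label= arguments of the plotted elements.
    - loc: Legend location (e.g., 'best', 'upper right', 'lower left').
  - plt.text(x, y, text, fontsize=12, ha='center', va='center'): Adds text to the plot.
    - x, y: Coordinates of the text.
    - ha: Horizontal alignment.
    - va: Vertical alignment.
  - plt.annotate(text, xy=(x, y), xytext=(x_text, y_text), arrowprops=dict(arrowstyle='->')): Adds an annotation with an arrow pointing from the text to the point xy.
    - xy: The point being annotated.
    - xytext: Position of the text, in the same coordinates as xy by default; pass textcoords='offset points' to make it an offset from xy instead.
    - arrowprops: Arrow properties; omit it for text without an arrow.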
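A sketch combining legend, text, and annotate on a synthetic sine/cosine plot. The first annotation places its text at a position in data coordinates. The second uses textcoords='offset points', so xytext=(20, -20) means 20 points right and 20 points down from xy. Positions and labels are illustrative.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 200)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x), label='sin(x)')
ax.plot(x, np.cos(x), label='cos(x)')
leg = ax.legend(loc='upper right')  # entries come from the label= arguments

# Free-floating text centred on the data point (pi, 0)
ax.text(np.pi, 0, 'x = pi', fontsize=12, ha='center', va='center')

# xytext in data coordinates: the text sits at (3.5, 0.6), the arrow points to xy
ax.annotate('peak of sin(x)', xy=(np.pi / 2, 1.0), xytext=(3.5, 0.6),
            arrowprops=dict(arrowstyle='->'))

# xytext as an offset: 20 points right and 20 points down from xy
ax.annotate('trough', xy=(3 * np.pi / 2, -1.0), xytext=(20, -20),
            textcoords='offset points', arrowprops=dict(arrowstyle='->'))
```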
- Saving plots:
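  - plt.savefig('filename.png', dpi=300, bbox_inches='tight'): Saves the current figure to a file; the format is inferred from the extension (for example .png, .pdf, .svg).
    - dpi: Resolution.
    - bbox_inches='tight': Removes extra whitespace around the plot.
    - Call savefig before plt.show(): in a script the figure is usually closed once its window is dismissed, and saving afterwards writes a blank image.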
- Displaying plots:
  - plt.show(): Displays all open figures; in a script it blocks until the windows are closed.
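A sketch of the save-before-show order. Writing into an in-memory io.BytesIO buffer rather than to a file name is only a convenience here. plt.close(fig) releases the figure, which matters in scripts that generate many plots.

```python
import io
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(np.arange(10), np.arange(10) ** 2)
ax.set_title('Saved before show()')

# Save first; a file name such as 'plot.png' or 'plot.pdf' works the same way
buf = io.BytesIO()
fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')

# In an interactive session, call plt.show() here, after saving
plt.close(fig)  # free the figure once it has been written
png_bytes = buf.getvalue()
```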
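3.2 Seaborn
Since seaborn 0.12, data is the only argument that can be passed positionally; x, y, hue, and similar variables must be passed as keywords (e.g., sns.barplot(data=df, x='day', y='total_bill')).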
- Distribution Plots:
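  - sns.distplot(x, bins=None, kde=True, rug=False, hist=True): Combines a histogram, KDE (Kernel Density Estimate), and rug plot. Deprecated since seaborn 0.11; use sns.histplot(data=df, x=x, kde=True) (axes-level) or sns.displot(data=df, x=x, kde=True, rug=True) (figure-level) instead.
    - x: Data to be plotted.
    - bins: Number of bins in the histogram.
    - kde: Whether to plot the KDE.
    - rug: Whether to plot the rug plot.
    - hist: Whether to plot the histogram.
  - sns.kdeplot(data=df, x=x, fill=True): Creates a kernel density estimate plot. fill=True shades the area under the curve; it replaces the older shade=True, deprecated in seaborn 0.11.
  - sns.rugplot(data=df, x=x): Creates a rug plot (marks data points along the axis).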
- Categorical Plots:
  - sns.barplot(data=df, x=x, y=y, estimator=np.mean, errorbar='sd', order=None): Creates a bar plot showing an estimate (by default the mean) of a numerical variable for each category.
    - x: Column name for the categorical variable.
    - y: Column name for the numerical variable.
    - data: DataFrame.
    - estimator: Function used to estimate the central tendency (default: the mean).
    - errorbar: Error bar method, e.g. 'sd', 'se', or ('ci', 95) (the default); None disables error bars. Before seaborn 0.12 this was the ci parameter (ci='sd', ci=95, or ci=None).
    - order: Order in which to plot the categories.
  - sns.countplot(data=df, x=x, order=None): Creates a count plot showing the number of occurrences of each category.
  - sns.boxplot(data=df, x=x, y=y, hue=None, orient='v'): Creates a box plot showing the distribution of a variable for different categories.
    - hue: Another categorical variable to split the data by.
    - orient: Orientation ('v' for vertical, 'h' for horizontal).
  - sns.violinplot(data=df, x=x, y=y, hue=None, split=False, inner='quartile'): Creates a violin plot, similar to a box plot but showing the KDE of each distribution.
    - split: If True, draws half-violins for the hue levels instead of full mirrored ones, which makes two levels easy to compare.
    - inner: Representation of the data inside the violin ('box' (the default), 'quartile', 'point', 'stick', or None).
  - sns.stripplot(data=df, x=x, y=y, jitter=True, dodge=False): Creates a scatter plot of the individual data points in each category.
    - jitter: Adds random noise along the categorical axis to reduce overplotting.
    - dodge: If True, separates the points for each hue level.
  - sns.swarmplot(data=df, x=x, y=y, dodge=False): Creates a swarm plot, similar to a strip plot but positioning points so they do not overlap (best for small to moderate datasets).
- Relational Plots:
  - sns.scatterplot(data=df, x=x, y=y, hue=None, size=None, style=None): Creates a scatter plot to visualize the relationship between two variables.
    - hue: Variable used to color the points.
    - size: Variable used to size the points.
    - style: Variable used to set the marker style.
  - sns.lineplot(data=df, x=x, y=y, hue=None, style=None): Creates a line plot to visualize the relationship between two variables. When several rows share an x value, it plots their mean with a 95% confidence band by default.
- Regression Plots:
  - sns.regplot(data=df, x=x, y=y, ci=95, scatter=True, fit_reg=True): Creates a scatter plot with a fitted linear regression line.
    - ci: Size of the confidence interval for the regression estimate (None disables it).
    - scatter: Whether to plot the scatter points.
    - fit_reg: Whether to fit and draw the regression line.
  - sns.lmplot(data=df, x=x, y=y, hue=None, col=None, row=None): Figure-level version of regplot that adds faceting (splitting the plot into multiple subplots).
    - col: Variable whose levels define the subplot columns.
    - row: Variable whose levels define the subplot rows.
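For a first-order fit, regplot draws an ordinary least-squares line, so its slope should match numpy.polyfit on the same data. This sketch uses synthetic data with true slope 2, intercept 1, and Gaussian noise, all invented. ci=None skips the bootstrapped confidence band, and 'half' is a hypothetical facet column for lmplot.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(0)
x = rng.uniform(0, 10, 100)
df = pd.DataFrame({'x': x, 'y': 2 * x + 1 + rng.normal(scale=1.0, size=100)})

fig, ax = plt.subplots()
sns.regplot(data=df, x='x', y='y', ci=None, scatter_kws={'alpha': 0.5}, ax=ax)

# Read the plotted regression line back from the Axes
line = ax.lines[0]
xs, ys = line.get_xdata(), line.get_ydata()
plotted_slope = (ys[-1] - ys[0]) / (xs[-1] - xs[0])

# The same least-squares fit computed directly (degree-1 polynomial)
slope, intercept = np.polyfit(df['x'], df['y'], deg=1)

# lmplot: figure-level version, faceted by a hypothetical 'half' column
df['half'] = np.where(df['x'] < 5, 'low', 'high')
g = sns.lmplot(data=df, x='x', y='y', col='half', ci=None)
```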
- Matrix Plots:
  - sns.heatmap(data, annot=True, cmap='viridis', fmt=".2f"): Creates a heatmap to visualize a matrix of values.
    - data: 2D array or DataFrame (a DataFrame's index and columns become the axis labels).
    - annot: Whether to display the values in each cell.
    - cmap: Colormap.
    - fmt: Format string for the annotations.
  - sns.clustermap(data, cmap='viridis', standard_scale=None, method='ward'): Creates a heatmap whose rows and columns are reordered by hierarchical clustering (requires SciPy).
    - standard_scale: Rescales each row (0) or each column (1) to the 0-1 range (subtract the minimum, divide by the range); use z_score for z-score standardization.
    - method: Linkage method for clustering (e.g., 'ward', 'average', 'complete'; the default is 'average').
  - sns.pairplot(data, hue=None, diag_kind='hist'): Creates a grid of scatter plots for every pair of numeric columns in a DataFrame, with univariate plots on the diagonal.
    - hue: Variable to color the points by.
    - diag_kind: Kind of plot to show on the diagonal ('auto', 'hist', 'kde', or None; 'auto' uses 'kde' when hue is set and 'hist' otherwise).
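A common ML use of heatmap is a feature correlation matrix. This sketch uses four synthetic features, with 'b' built to correlate with 'a'. vmin=-1 and vmax=1 pin the color scale to the full correlation range. clustermap (which needs SciPy) and pairplot accept data the same way but create their own figures.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(1)
a = rng.normal(size=200)
features = pd.DataFrame({
    'a': a,
    'b': 0.8 * a + rng.normal(scale=0.5, size=200),  # deliberately correlated with a
    'c': rng.normal(size=200),
    'd': rng.normal(size=200),
})

corr = features.corr()  # 4 x 4 Pearson correlation matrix

fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(corr, annot=True, fmt='.2f', cmap='coolwarm', vmin=-1, vmax=1,
            square=True, ax=ax)
ax.set_title('Feature correlations')
```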
- Style and Aesthetics:
  - sns.set_style(style, rc=None): Sets the overall style of the plots.
    - style: 'darkgrid', 'whitegrid', 'dark', 'white', 'ticks'.
    - rc: Dictionary of Matplotlib parameters to override.
  - sns.set_palette(palette, n_colors=None, desat=None): Sets the color palette.
    - palette: 'deep', 'muted', 'pastel', 'bright', 'dark', 'colorblind', or a list of colors.
  - sns.despine(fig=None, ax=None, top=True, right=True, left=False, bottom=False, offset=None, trim=False): Removes spines from the plot.
  - sns.color_palette(palette, n_colors=None, desat=None): Returns a list of colors from a palette.
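A sketch of the style helpers, assuming seaborn 0.11 or later. sns.axes_style also works as a context manager, so a style can apply to a single figure. The 'colorblind' palette, the rc override, and the hex colors are example choices.

```python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Global settings: white background with grid lines, colour-blind-friendly palette
sns.set_style('whitegrid', rc={'axes.edgecolor': '0.3'})
sns.set_palette('colorblind')

# color_palette returns a list of RGB tuples usable anywhere a colour is accepted
colors = sns.color_palette('colorblind', n_colors=3)
custom = sns.color_palette(['#e41a1c', '#377eb8', '#4daf4a'])

# Use a different style for this one figure only
with sns.axes_style('ticks'):
    fig, ax = plt.subplots()
    for i, color in enumerate(colors):
        ax.plot(np.arange(5), np.arange(5) * (i + 1), color=color)
    sns.despine(ax=ax, offset=5, trim=True)  # drop top/right spines, offset the rest
```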
4. Practical Examples
4.1 Matplotlib Examples
```python
# Example Data
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)

# Basic Line Plot
plt.figure(figsize=(8, 6))  # Set figure size
plt.plot(x, y1, label='sin(x)')
plt.plot(x, y2, label='cos(x)')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Sine and Cosine Waves')
plt.legend()
plt.grid(True)  # Add grid lines
plt.savefig('sine_cosine.png')  # Save before plt.show()
plt.show()
```

Output: A plot showing sine and cosine waves.
```python
# Scatter Plot
plt.figure(figsize=(8, 6))
plt.scatter(x, y1, c='red', marker='o', label='sin(x) points')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.title('Scatter Plot of Sine Wave')
plt.legend()
plt.show()
```

Output: A scatter plot of the sine wave.
```python
# Bar Chart
categories = ['A', 'B', 'C', 'D']
values = [25, 40, 30, 50]

plt.figure(figsize=(8, 6))
plt.bar(categories, values, color='skyblue')
plt.xlabel('Category')
plt.ylabel('Value')
plt.title('Bar Chart of Values')
plt.show()
```

Output: A bar chart showing the values for each category.
```python
# Histogram
data = np.random.randn(1000)  # Generate random data

plt.figure(figsize=(8, 6))
plt.hist(data, bins=30, color='lightgreen', edgecolor='black')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.title('Histogram of Random Data')
plt.show()
```

Output: A histogram of the random data.
```python
# Image Plot
image_data = np.random.rand(100, 100)  # Generate random image data

plt.figure(figsize=(8, 6))
plt.imshow(image_data, cmap='gray')
plt.colorbar()
plt.title('Random Image')
plt.show()
```

Output: A grayscale image of random noise.
4.2 Seaborn Examples
```python
# Example DataFrame
data = pd.DataFrame({
    'category': ['A', 'A', 'B', 'B', 'C', 'C'],
    'value': [10, 15, 20, 25, 30, 35],
    'group': ['X', 'Y', 'X', 'Y', 'X', 'Y']
})
```
```python
# Bar Plot
sns.barplot(x='category', y='value', data=data, hue='group')
plt.title('Bar Plot with Seaborn')
plt.show()
```

Output: A bar plot showing the values for each category, separated by group.
```python
# Scatter Plot
sns.scatterplot(x='value', y='category', data=data, hue='group', size='value')
plt.title('Scatter Plot with Seaborn')
plt.show()
```

Output: A scatter plot showing the relationship between value and category, colored by group and sized by value.
```python
# Distribution Plot (displot is figure-level and creates its own figure)
sns.displot(data=data, x='value', kde=True)
plt.title('Distribution Plot with Seaborn')
plt.show()
```

Output: A distribution plot (histogram with a KDE curve) of the 'value' column.
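```python
# Heatmap (Correlation Matrix)
# 'data' has only one numeric column (a 1x1 matrix), so use several numeric features
numeric_df = pd.DataFrame(np.random.randn(100, 4), columns=['f1', 'f2', 'f3', 'f4'])
correlation_matrix = numeric_df.corr()
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
plt.title('Correlation Heatmap')
plt.show()
```

Output: A 4x4 heatmap of pairwise correlations (1 on the diagonal, close to 0 elsewhere, since the random columns are independent).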
```python
# Pair Plot
iris = sns.load_dataset('iris')  # Load the Iris example dataset (downloaded on first use)
sns.pairplot(iris, hue='species')
plt.show()
```

Output: A matrix of scatter plots showing the relationships between all pairs of variables in the Iris dataset, colored by species.
5. Advanced Usage
- Customizing Matplotlib Plots with Seaborn Styles:
```python
sns.set_style('whitegrid')  # Set seaborn style
fig, ax = plt.subplots(figsize=(8, 6))  # Create a matplotlib figure and axes

ax.plot(x, y1, label='sin(x)')  # Plot with matplotlib
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Sine Wave with Seaborn Style')
ax.legend()
plt.show()
```

- Creating Faceted Plots with Seaborn:
```python
tips = sns.load_dataset('tips')
sns.catplot(x='day', y='total_bill', hue='sex', col='smoker', data=tips, kind='bar')
plt.show()
```

- Using Custom Color Palettes:
```python
custom_palette = ['#e41a1c', '#377eb8', '#4daf4a']  # Define a custom color palette
sns.set_palette(custom_palette)  # Set the palette
sns.barplot(x='category', y='value', data=data, hue='group')
plt.show()
```

- Working with Subplots:
```python
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 8))

# Plotting on each subplot
sns.histplot(data['value'], ax=axes[0, 0], kde=True)
sns.scatterplot(x='value', y='category', data=data, hue='group', ax=axes[0, 1])
sns.barplot(x='category', y='value', data=data, hue='group', ax=axes[1, 0])
sns.boxplot(x='category', y='value', data=data, hue='group', ax=axes[1, 1])

plt.tight_layout()  # Adjust subplot parameters for a tight layout
plt.show()
```

- Performance Considerations:
- For large datasets, pass whole NumPy arrays or DataFrame columns to a single plotting call instead of calling plt.plot or plt.scatter once per point in a Python loop.
- Limit the number of data points plotted in scatter plots to avoid overplotting (for example by sampling), or switch to hexbin plots (plt.hexbin or sns.jointplot(kind='hex')) for large datasets.
- Use appropriate bin sizes for histograms.
- When saving plots, choose a suitable file format and resolution (DPI). Vector formats (e.g. SVG, PDF) are preferred for scalability.
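A sketch of two options for large datasets, using 100,000 synthetic points; the point count and gridsize are arbitrary. hexbin replaces overplotted markers with binned counts. rasterized=True renders a dense scatter as an embedded image, so a vector PDF stays small.

```python
import io
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
n = 100_000
x = rng.normal(size=n)
y = x + rng.normal(scale=0.5, size=n)

fig, (ax_hex, ax_scatter) = plt.subplots(1, 2, figsize=(10, 4))

# One hexagon per bin, coloured by how many points fall inside it
hb = ax_hex.hexbin(x, y, gridsize=50, cmap='viridis', mincnt=1)
fig.colorbar(hb, ax=ax_hex, label='count')

# If a scatter is still needed: small, faint markers, rasterized
ax_scatter.scatter(x, y, s=1, alpha=0.1, rasterized=True)

# The rasterized scatter is embedded as an image inside the vector PDF
buf = io.BytesIO()
fig.savefig(buf, format='pdf', dpi=150)
```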
6. Tips & Tricks
- Tab Completion: Use tab completion in your IDE to explore the available functions and parameters.
- Help Documentation: Use help(plt.plot) or help(sns.barplot) to access the documentation for a specific function.
- Online Examples: Search online for examples of specific plot types to get inspiration and code snippets.
- Customizing Styles: Create custom style sheets (.mplstyle files of rcParams settings) to apply consistent styling to your plots. You can load these with plt.style.use('my_custom_style.mplstyle').
- Interactive Plots: For interactive plots, consider libraries like Plotly or Bokeh, which accept the same pandas DataFrames and NumPy arrays you pass to Matplotlib and Seaborn.
- Use plt.tight_layout() to prevent labels from overlapping.
- Use plt.clf() to clear the current figure or plt.cla() to clear the current axes; plt.close() closes a figure entirely and frees its memory.
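The custom style-sheet tip, sketched end to end. The file name follows the my_custom_style.mplstyle example above, and the three settings are arbitrary. plt.style.use(path) would apply the sheet globally, while plt.style.context(path) applies it only inside a with block.

```python
import os
import tempfile
import matplotlib.pyplot as plt

# A style sheet is a plain text file of 'rcParam: value' lines
style_text = """
axes.grid: True
lines.linewidth: 3
figure.figsize: 6, 4
"""
style_path = os.path.join(tempfile.mkdtemp(), 'my_custom_style.mplstyle')
with open(style_path, 'w') as f:
    f.write(style_text)

# Apply the sheet only inside this block
with plt.style.context(style_path):
    fig, ax = plt.subplots()
    line, = ax.plot([0, 1, 2], [0, 1, 4])
    linewidth_inside = line.get_linewidth()

# Outside the block the previous rcParams are restored
default_linewidth = plt.rcParams['lines.linewidth']
plt.close(fig)  # close the figure and free its memory
```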
7. Integration
- Pandas: Seaborn is designed to work seamlessly with Pandas DataFrames. Most Seaborn plotting functions accept a DataFrame as input and column names as arguments.
- NumPy: Matplotlib and Seaborn use NumPy arrays for data storage and manipulation.
- Scikit-learn: Visualize model performance and results using Matplotlib and Seaborn.
```python
# Example: Visualizing Scikit-learn model results
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

# Sample data (Iris example dataset, downloaded by seaborn on first use)
iris = sns.load_dataset('iris')
X, y = iris.drop('species', axis=1), iris['species']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Train a logistic regression model
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)

# Predict on the test set
y_pred = model.predict(X_test)

# Create a confusion matrix with a fixed label order
labels = iris['species'].unique()
cm = confusion_matrix(y_test, y_pred, labels=labels)

# Visualize the confusion matrix using Seaborn (fmt='d' prints integer counts)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()
```

8. Further Resources
- Matplotlib Official Documentation: https://matplotlib.org/stable/contents.html
- Seaborn Official Documentation: https://seaborn.pydata.org/
- Matplotlib Gallery: https://matplotlib.org/stable/gallery/index.html
- Seaborn Gallery: https://seaborn.pydata.org/examples/index.html
- Real Python Matplotlib Tutorials: https://realpython.com/tutorials/matplotlib/
- Kaggle Data Visualization Courses: https://www.kaggle.com/learn/data-visualization