How to create a scatter plot with categorical data in Pandas using matplotlib?

Susan Sarandon
Release: 2024-11-18 08:55:03

Creating Scatter Plots Categorized by a Key in Pandas DataFrames

In data visualization, scatter plots are commonly used to discern relationships between numerical variables. However, when there are additional categorical variables that contribute to the analysis, it becomes necessary to represent them within the scatter plot. This question explores an efficient way of plotting two variables while conveying the third as discrete categories.

The asker first tried df.groupby, but that did not produce the desired result. The following sample DataFrame illustrates the problem:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

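# Month-start dates; the older 'M' alias is deprecated in pandas 2.2+ (the index is not used by the plot)
df = pd.DataFrame(np.random.normal(10, 1, 30).reshape(10, 3),
                  index=pd.date_range('2010-01-01', freq='MS', periods=10),
                  columns=('one', 'two', 'three'))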
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
ax1.scatter(df['one'], df['two'], marker='o', c=df['key1'], alpha=0.8)
plt.show()

This colors each marker according to the numeric 'key1' column, but a single scatter call returns one collection mapped through a colormap, so ax.legend() has no per-category entries to show. One way to get both the colors and a legend is to draw each category separately.
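If you would rather keep a single scatter call, matplotlib 3.1 and later can build legend entries from the collection itself with PathCollection.legend_elements(). The sketch below assumes string labels like those in the grouped example, so it first encodes them as integer codes with pd.Categorical (scatter's c= argument expects numbers or color specs). The 'viridis' colormap and marker size are arbitrary choices.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

np.random.seed(1974)

# Sample data with a string-valued category column
num = 20
x, y = np.random.random((2, num))
df = pd.DataFrame(dict(x=x, y=y, label=np.random.choice(['a', 'b', 'c'], num)))

# c= needs numbers (or color specs), so encode each label as an integer code
cats = pd.Categorical(df['label'])
codes = cats.codes  # 0, 1, 2 ... in the order of cats.categories

fig, ax = plt.subplots()
sc = ax.scatter(df['x'], df['y'], c=codes, cmap='viridis', s=60, alpha=0.8)

# legend_elements (matplotlib >= 3.1) returns one handle per unique code, sorted ascending
handles, _ = sc.legend_elements(prop='colors')
# Replace the default numeric labels with the category names
ax.legend(handles, list(cats.categories), title='label')
plt.show()
```

Because the categories are derived from the data, every code appears at least once, so the sorted handles line up with cats.categories.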

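The solution below groups the rows by category and calls plot once per group instead of scatter. With linestyle='' plot draws markers only, so the result looks like a scatter plot, and each call creates one labelled line that ax.legend() can list. The matplotlib documentation also notes that plot is faster than scatter when the markers do not vary in size or color: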

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
fig, ax = plt.subplots()
ax.margins(0.05)  # Optional, just adds 5% padding to the autoscaling
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend()

plt.show()

This produces a scatter plot in which each category is drawn in its own color from the default color cycle (all with the same 'o' marker), together with a legend that labels the categories.
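Since every group above uses marker='o', categories differ only by color. A minimal sketch that also gives each group its own marker shape pairs the groupby iterator with itertools.cycle; the marker list is an arbitrary choice, and shapes repeat once there are more categories than markers.

```python
import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

np.random.seed(1974)

num = 20
x, y = np.random.random((2, num))
df = pd.DataFrame(dict(x=x, y=y, label=np.random.choice(['a', 'b', 'c'], num)))

# Marker shapes to cycle through (arbitrary selection)
markers = itertools.cycle(['o', 's', '^', 'D', 'v'])

fig, ax = plt.subplots()
ax.margins(0.05)
for (name, group), marker in zip(df.groupby('label'), markers):
    # linestyle='' draws markers only; each call adds one labelled line for the legend
    ax.plot(group['x'], group['y'], marker=marker, linestyle='', ms=10, label=name)
ax.legend(numpoints=1)
plt.show()
```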

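For a different look, the original answer applied the pandas plotting stylesheet through rcParams and took per-group colors from a pandas color helper. Those APIs (pd.tools.plotting.mpl_stylesheet, the private _get_standard_colors, and matplotlib's Axes.set_color_cycle) are not available in current pandas and matplotlib releases. A version that runs today can use matplotlib's built-in 'ggplot' style and Axes.set_prop_cycle instead:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
# matplotlib's built-in 'ggplot' style stands in for the removed pandas stylesheet
plt.style.use('ggplot')
# One qualitative color per group (the Set1 colormap has 9 distinct colors)
colors = [plt.cm.Set1(i) for i in range(len(groups))]

fig, ax = plt.subplots()
ax.set_prop_cycle(color=colors)  # replaces the removed Axes.set_color_cycle
ax.margins(0.05)
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend(numpoints=1, loc='upper left')

plt.show()

This version gives the plot ggplot-style axes and assigns one Set1 color per category, whereas the original drew random colors (color_type='random').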
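Colors taken from a cycle depend on the order of the groups, so a category can change color between figures if the set of labels changes. When stable colors matter, one option is an explicit category-to-color mapping passed through plot's color argument. The palette dictionary below is a hypothetical choice using matplotlib's named 'tab:' colors.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

np.random.seed(1974)

num = 20
x, y = np.random.random((2, num))
df = pd.DataFrame(dict(x=x, y=y, label=np.random.choice(['a', 'b', 'c'], num)))

# Hypothetical fixed mapping from category to color
palette = {'a': 'tab:blue', 'b': 'tab:orange', 'c': 'tab:green'}

fig, ax = plt.subplots()
ax.margins(0.05)
for name, group in df.groupby('label'):
    # The color comes from the mapping, not from the axes' color cycle
    ax.plot(group['x'], group['y'], marker='o', linestyle='', ms=12,
            color=palette[name], label=name)
ax.legend(numpoints=1, loc='upper left')
plt.show()
```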


source:php.cn