Creating Scatter Plots Categorized by a Key in Pandas DataFrames
In data visualization, scatter plots are commonly used to reveal relationships between two numerical variables. When a categorical variable also matters to the analysis, it needs to be represented in the same plot. This article shows how to plot two variables from a DataFrame while conveying a third as discrete categories, each with its own legend entry.
Initially, attempts were made using df.groupby, but they did not yield the desired results. The sample DataFrame provided serves to illustrate the issue:
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt

    df = pd.DataFrame(np.random.normal(10, 1, 30).reshape(10, 3),
                      index=pd.date_range('2010-01-01', freq='M', periods=10),
                      columns=('one', 'two', 'three'))
    df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

    fig1 = plt.figure(1)
    ax1 = fig1.add_subplot(111)
    ax1.scatter(df['one'], df['two'], marker='o', c=df['key1'], alpha=0.8)
    plt.show()
This approach successfully colors the markers according to the 'key1' column, but it lacks a legend to distinguish the categories. To achieve both, a different method is required.
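One solution is to call plot once per group instead of calling scatter once for all points. Each plot call creates a separate series that can carry its own label, which is exactly what the legend needs: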
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd

    np.random.seed(1974)

    # Generate Data
    num = 20
    x, y = np.random.random((2, num))
    labels = np.random.choice(['a', 'b', 'c'], num)
    df = pd.DataFrame(dict(x=x, y=y, label=labels))

    groups = df.groupby('label')

    # Plot
    fig, ax = plt.subplots()
    ax.margins(0.05)  # Optional, just adds 5% padding to the autoscaling
    for name, group in groups:
        ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
    ax.legend()

    plt.show()
This code generates a scatter plot in which each category is drawn in its own color (the marker shape is the same for all groups), with a legend that labels the categories.
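To connect this back to the sample DataFrame from earlier, the same group-by-group plotting can be wrapped in a small function and applied to the 'key1' column. The following is a minimal sketch, not part of the original answer: the helper name scatter_by_key, its parameters, and the random seed are assumptions made for illustration, and the date index of the original sample is omitted for brevity.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def scatter_by_key(df, x, y, key, ax=None):
    """Hypothetical helper: draw one marker-only series per distinct value of df[key]."""
    if ax is None:
        _, ax = plt.subplots()
    # groupby yields (category value, sub-DataFrame) pairs, sorted by category value.
    for name, group in df.groupby(key):
        # linestyle='' suppresses the connecting line, so only markers are drawn.
        ax.plot(group[x], group[y], marker='o', linestyle='', ms=8, label=str(name))
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    ax.legend(title=key)
    return ax


# Data shaped like the earlier sample; the seed value is an arbitrary choice.
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.normal(10, 1, (10, 3)), columns=['one', 'two', 'three'])
df['key1'] = [4, 4, 4, 6, 6, 6, 8, 8, 8, 8]

ax = scatter_by_key(df, 'one', 'two', 'key1')
```

Calling plt.show() afterwards displays the figure. Because the function returns the Axes, titles or further series can be added before showing it.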
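For a more customized look, the original approach applied the pandas plot style through rcParams and drew its colors from a pandas helper. The calls it relied on (pd.tools.plotting.mpl_stylesheet, the private pd.tools.plotting._get_standard_colors, and ax.set_color_cycle) are no longer available in current pandas and matplotlib releases. The version below aims for a similar effect using matplotlib's built-in 'ggplot' style and ax.set_prop_cycle as stand-ins:
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd

    np.random.seed(1974)

    # Generate Data
    num = 20
    x, y = np.random.random((2, num))
    labels = np.random.choice(['a', 'b', 'c'], num)
    df = pd.DataFrame(dict(x=x, y=y, label=labels))

    groups = df.groupby('label')

    # Plot
    plt.style.use('ggplot')  # stand-in for the removed pd.tools.plotting.mpl_stylesheet
    # One color per group, taken from a matplotlib colormap
    # (replaces the removed private helper _get_standard_colors)
    cmap = plt.get_cmap('tab10')
    colors = [cmap(i) for i in range(len(groups))]

    fig, ax = plt.subplots()
    ax.set_prop_cycle(color=colors)  # replaces the removed ax.set_color_cycle
    ax.margins(0.05)
    for name, group in groups:
        ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
    ax.legend(numpoints=1, loc='upper left')

    plt.show()
This modification gives the plot a styled background and an explicitly chosen color for each category.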