Posted by 自学气象人 on 2024-3-9 23:32:43

A Summary of Code for 50 Common Statistical Charts (1)


Setup
# !pip install brewer2mpl
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import warnings; warnings.filterwarnings(action='once')

large = 22; med = 16; small = 12
params = {'axes.titlesize': large,
          'legend.fontsize': med,
          'figure.figsize': (16, 10),
          'axes.labelsize': med,
          'xtick.labelsize': med,
          'ytick.labelsize': med,
          'figure.titlesize': large}
plt.rcParams.update(params)
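try:
    plt.style.use('seaborn-whitegrid')
except OSError:
    # matplotlib >= 3.6 renamed the built-in seaborn styles to 'seaborn-v0_8-*'
    plt.style.use('seaborn-v0_8-whitegrid')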
sns.set_style("white")
%matplotlib inline

# Version
print(mpl.__version__)#> 3.0.0
print(sns.__version__)#> 0.9.0
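plt.rcParams.update(params) changes the defaults for every figure drawn later in the session. If you would rather scope such settings to a single figure, matplotlib also provides rc_context(). A minimal sketch, reusing the med = 16 value from above; the "Agg" backend is only chosen so the snippet runs without a display, and the axis label text is illustrative:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

med = 16
before = plt.rcParams["axes.labelsize"]

# Font-size settings apply only inside this block instead of globally
with plt.rc_context({"axes.labelsize": med, "xtick.labelsize": med}):
    inside = plt.rcParams["axes.labelsize"]
    fig, ax = plt.subplots()   # axes created here pick up the scoped label size
    ax.set_xlabel("Area")

after = plt.rcParams["axes.labelsize"]  # the previous value is restored on exit
plt.close(fig)
```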

Correlation
The plots in this section visualize the relationship between two or more variables, that is, how one variable changes with respect to another.


1. Scatter plot
A scatter plot is a classic, fundamental plot for studying the relationship between two variables. If your data contains multiple groups, you may want to draw each group in a different color. In matplotlib, you can do this conveniently by calling plt.scatter() once per group.
# Import dataset
midwest = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/midwest_filter.csv")

# Prepare Data
# Create as many colors as there are unique midwest['category']
categories = np.unique(midwest['category'])
colors = [plt.cm.tab10(i/float(len(categories)-1)) for i in range(len(categories))]

# Draw Plot for Each Category
plt.figure(figsize=(16,10), dpi=80, facecolor='w', edgecolor='k')

for i, category in enumerate(categories):
    # Plot only the rows that belong to the current category
    plt.scatter('area', 'poptotal',
                data=midwest.loc[midwest.category==category, :],
                s=20, c=[colors[i]], label=str(category))

# Decorations
plt.gca().set(xlim=(0.0,0.1), ylim=(0,90000),
            xlabel='Area', ylabel='Population')

plt.xticks(fontsize=12); plt.yticks(fontsize=12)
plt.title("Scatterplot of Midwest Area vs Population", fontsize=22)
plt.legend(fontsize=12)   
plt.show()
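The example above downloads the midwest data. As a minimal offline sketch of the same per-group coloring idea, the snippet below builds a small synthetic DataFrame (the column names area, poptotal and category mirror the midwest data, but the values and the category labels "A", "B", "C" are random placeholders) and draws one scatter layer per category, so each group gets its own color and legend entry. The helper scatter_by_group is hypothetical, written only for this illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in for the midwest data: random values, three categories
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "area": rng.uniform(0.0, 0.1, 90),
    "poptotal": rng.uniform(0, 90000, 90),
    "category": np.repeat(["A", "B", "C"], 30),
})

def scatter_by_group(data, x, y, group, ax=None):
    """Draw one scatter layer per unique value of `group`, each in its own color."""
    if ax is None:
        ax = plt.gca()
    for i, g in enumerate(np.unique(data[group])):
        subset = data.loc[data[group] == g, :]
        ax.scatter(subset[x], subset[y], s=20, color=plt.cm.tab10(i), label=str(g))
    ax.legend()
    return ax

fig, ax = plt.subplots(figsize=(8, 5))
scatter_by_group(df, "area", "poptotal", "category", ax=ax)
ax.set(xlabel="Area", ylabel="Population")
```

Calling scatter once per group, rather than passing a single array of colors, is what gives each group its own legend entry.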



2. Bubble plot with Encircling
Sometimes you want to draw a boundary around a group of points to emphasize their importance. In this example, you select the records from the dataframe that should be encircled and pass them to the encircle() function defined in the code below.
from matplotlib import patches
from scipy.spatial import ConvexHull
import warnings; warnings.simplefilter('ignore')
sns.set_style("white")

# Step 1: Prepare Data
midwest = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/midwest_filter.csv")

# As many colors as there are unique midwest['category']
categories = np.unique(midwest['category'])
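colors = [plt.cm.tab10(i/float(len(categories)-1)) for i in range(len(categories))]

# Step 2: Draw Scatterplot with unique color for each category
fig = plt.figure(figsize=(16,10), dpi=80, facecolor='w', edgecolor='k')

for i, category in enumerate(categories):
    # Marker size comes from the 'dot_size' column of the current category's rows
    plt.scatter('area', 'poptotal', data=midwest.loc[midwest.category==category, :], s='dot_size', c=[colors[i]], label=str(category), edgecolors='black', linewidths=.5)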

# Step 3: Encircling
# https://stackoverflow.com/questions/44575681/how-do-i-encircle-different-data-sets-in-scatter-plot
def encircle(x, y, ax=None, **kw):
    """Draw a polygon along the convex hull of the points (x, y)."""
    if not ax: ax = plt.gca()
    p = np.c_[x, y]                                 # (n, 2) array of point coordinates
    hull = ConvexHull(p)
    poly = plt.Polygon(p[hull.vertices, :], **kw)   # keep only the hull's boundary points
    ax.add_patch(poly)

# Select data to be encircled
# (the filter was garbled in the source text; selecting Indiana ('IN') counties is an assumed example)
midwest_encircle_data = midwest.loc[midwest.state=='IN', :]

# Draw polygon surrounding vertices   
encircle(midwest_encircle_data.area, midwest_encircle_data.poptotal, ec="k", fc="gold", alpha=0.1)
encircle(midwest_encircle_data.area, midwest_encircle_data.poptotal, ec="firebrick", fc="none", linewidth=1.5)

# Step 4: Decorations
plt.gca().set(xlim=(0.0,0.1), ylim=(0,90000),
            xlabel='Area', ylabel='Population')

plt.xticks(fontsize=12); plt.yticks(fontsize=12)
plt.title("Bubble Plot with Encircling", fontsize=22)
plt.legend(fontsize=12)   
plt.show()
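encircle() relies on scipy.spatial.ConvexHull: hull.vertices holds the indices of the points on the hull boundary (in counterclockwise order for 2-D input), which is why p[hull.vertices, :] can be passed straight to plt.Polygon. A minimal check with made-up points, the four corners of a unit square plus one interior point, shows that only the corners become polygon vertices:

```python
import numpy as np
from scipy.spatial import ConvexHull

# Made-up points: the corners of a unit square plus one point strictly inside
x = np.array([0.0, 1.0, 1.0, 0.0, 0.5])
y = np.array([0.0, 0.0, 1.0, 1.0, 0.5])
p = np.c_[x, y]                  # stack into an (n, 2) array, as encircle() does

hull = ConvexHull(p)
boundary = p[hull.vertices, :]   # the vertices encircle() hands to plt.Polygon
print(sorted(hull.vertices.tolist()))
print(hull.volume)               # for 2-D input, "volume" is the enclosed area
```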


3. Scatter plot with linear regression line of best fit
If you want to understand how two variables change with respect to each other, draw a line of best fit. The plot below shows how the line of best fit differs among the groups in the data. To disable the grouping and draw a single line of best fit for the entire dataset, remove the hue="cyl" parameter from the sns.lmplot() call below.
# Import Data
df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")
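# Keep a subset of cylinder counts (the filter was garbled in the source text; 4 and 8 cylinders is an assumed example)
df_select = df.loc[df.cyl.isin([4,8]), :]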

# Plot
sns.set_style("white")
gridobj = sns.lmplot(x="displ", y="hwy", hue="cyl", data=df_select,
                     height=7, aspect=1.6, robust=True, palette='tab10',
                     scatter_kws=dict(s=60, linewidths=.7, edgecolors='black'))

# Decorations
gridobj.set(xlim=(0.5,7.5), ylim=(0,50))
plt.title("Scatterplot with line of best fit grouped by number of cylinders", fontsize=20)
plt.show()
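sns.lmplot() fits one regression line per hue level. As a rough illustration of what "one line per group" means numerically, the sketch below fits an ordinary least-squares line per group with np.polyfit on a small hypothetical table whose groups lie exactly on known lines. Note that robust=True in the call above makes seaborn use a robust regression (via statsmodels) instead, so its lines can differ from plain least squares when outliers are present.

```python
import numpy as np
import pandas as pd

# Hypothetical data: two groups, each lying exactly on its own straight line
displ = np.array([1.0, 2.0, 3.0, 4.0])
df = pd.DataFrame({
    "displ": np.concatenate([displ, displ]),
    "hwy": np.concatenate([40 - 3 * displ, 30 - 2 * displ]),
    "cyl": [4] * 4 + [8] * 4,
})

# One least-squares line, as (slope, intercept), per group, similar to hue="cyl"
fits = {
    cyl: np.polyfit(g["displ"], g["hwy"], deg=1)
    for cyl, g in df.groupby("cyl")
}
for cyl, (slope, intercept) in fits.items():
    print(f"cyl={cyl}: hwy = {intercept:.1f} + ({slope:.1f}) * displ")
```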


Each regression line in its own column
Alternatively, you can show the best fit line for each group in its own column. To do this, set the col parameter of sns.lmplot() to the grouping column (col="cyl" below).
# Import Data
df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")
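# Keep a subset of cylinder counts (the filter was garbled in the source text; 4 and 8 cylinders is an assumed example)
df_select = df.loc[df.cyl.isin([4,8]), :]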

# Each line in its own column
sns.set_style("white")
gridobj = sns.lmplot(x="displ", y="hwy",
                     data=df_select,
                     height=7,
                     robust=True,
                     palette='Set1',
                     col="cyl",
                     scatter_kws=dict(s=60, linewidths=.7, edgecolors='black'))

# Decorations
gridobj.set(xlim=(0.5,7.5), ylim=(0,50))
plt.show()


4. Jittering with stripplot
Often, multiple data points have exactly the same X and Y values, so they are drawn on top of each other and hide one another. To avoid this, jitter the points slightly so you can see each of them. seaborn's stripplot() makes this convenient.
# Import Data
df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")

# Draw Stripplot
fig, ax = plt.subplots(figsize=(16,10), dpi=80)   
sns.stripplot(x=df.cty, y=df.hwy, jitter=0.25, size=8, ax=ax, linewidth=.5)

# Decorations
plt.title('Use jittered plots to avoid overlapping of points', fontsize=22)
plt.show()
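Jittering simply adds a small random offset to each point's position before drawing. A minimal NumPy sketch, assuming uniform noise in [-0.25, 0.25] to echo jitter=0.25 (how exactly seaborn maps that value to noise is an assumption here) and applied along x only, much as stripplot jitters along its categorical axis. The helper jitter() and the five duplicated points are made up for illustration:

```python
import numpy as np

def jitter(values, amount=0.25, seed=None):
    """Return a float copy of `values` with uniform noise in [-amount, amount) added."""
    rng = np.random.default_rng(seed)
    values = np.asarray(values, dtype=float)
    return values + rng.uniform(-amount, amount, size=values.shape)

# Made-up data: five observations that share exactly the same (x, y) location
x = np.array([18, 18, 18, 18, 18])
y = np.array([26, 26, 26, 26, 26])

x_jittered = jitter(x, amount=0.25, seed=42)   # spread the points along x only
print(np.unique(x).size, np.unique(x_jittered).size)
```

Keeping the noise small relative to the spacing between distinct values preserves which value each point belongs to while making overlaps visible.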


5. Counts Plot
Another way to avoid overlapping points is to scale each dot by how many points lie at that spot: the larger the dot, the more points are concentrated there.

# Import Data
df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")
df_counts = df.groupby(['hwy','cty']).size().reset_index(name='counts')

# Draw scatterplot with marker size mapped to the number of overlapping points
fig, ax = plt.subplots(figsize=(16,10), dpi=80)
sns.scatterplot(x='cty', y='hwy', size='counts', sizes=(20, 400), data=df_counts, ax=ax)

# Decorations
plt.title('Counts Plot - Size of circle is bigger as more points overlap', fontsize=22)
plt.show()
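The groupby(...).size() step is what turns overlapping rows into counts. A minimal sketch on a made-up six-row table shows the intermediate df_counts frame; the factor 20 that converts counts into marker area (the s argument of plt.scatter, measured in points squared) is an arbitrary assumption:

```python
import pandas as pd

# Made-up (cty, hwy) pairs with repeated coordinates
df = pd.DataFrame({"cty": [18, 18, 18, 21, 21, 15],
                   "hwy": [26, 26, 26, 29, 29, 20]})

# One row per distinct (hwy, cty) pair, plus how many raw rows share it
df_counts = df.groupby(["hwy", "cty"]).size().reset_index(name="counts")

# Marker area grows linearly with the count
df_counts["marker_size"] = df_counts["counts"] * 20
print(df_counts)
```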




Source: WeChat official account 自学气象人


