This post extends the post on Scatter Plot in matplotlib and seaborn. A scatter plot is an important type of visualization for showing the relationship between two continuous variables, optionally grouped by a categorical variable. The relationship can be positive, negative, or neutral. In this post we will look at the scatter plot, when to use it, and how to create one in plotly. The data for this post, the iris dataset and the GDP per capita dataset, can be downloaded here.
When to Use Scatter Plot
Use a scatter plot to visualize the relationship between two continuous variables, optionally grouped by a categorical variable.
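Whether a relationship is positive, negative, or neutral can also be checked numerically with the Pearson correlation coefficient. The sketch below is plain Python with no plotting. The names `pearson_r` and `describe_relationship` are hypothetical, and the 0.1 "neutral" band is an arbitrary illustrative cut-off, not a standard threshold.

```python
import math


def pearson_r(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    if len(xs) != len(ys) or len(xs) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Covariance and variances (the 1/n factors cancel in the ratio)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    if var_x == 0 or var_y == 0:
        raise ValueError("correlation is undefined for a constant sequence")
    return cov / math.sqrt(var_x * var_y)


def describe_relationship(r, neutral_band=0.1):
    # neutral_band is an illustrative assumption, not a standard value
    if r > neutral_band:
        return "positive"
    if r < -neutral_band:
        return "negative"
    return "neutral"


xs = [1, 2, 3, 4, 5]
print(describe_relationship(pearson_r(xs, [2, 4, 6, 8, 10])))  # positive
print(describe_relationship(pearson_r(xs, [10, 8, 6, 4, 2])))  # negative
print(describe_relationship(pearson_r(xs, [1, 4, 9, 4, 1])))   # neutral
```

Note that the last pair has an obvious (non-linear) pattern yet a correlation of zero, which is one reason to look at the scatter plot rather than rely on the number alone.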
How to Use Scatter Plot
- Add a line of best fit (trendline) to show the direction of the relationship clearly.
- Avoid overplotting. Reduce the clutter of data points as much as possible so that the relationship between the variables is clearly visible.
- When plotting more than two variables, use colour to distinguish the categories of the extra variable.
- When interpreting the relationship between variables, remember that correlation does not imply causation.
Scatter Plot in Plotly
Load Required Libraries
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
Load Data
iris_df=pd.read_csv('iris.csv')
iris_df.head()
Simple Scatter Plot
# Keep only the Iris-setosa rows
iris_simple_scatterplot_df=iris_df[iris_df['class'].isin(['Iris-setosa'])]
fig = px.scatter(iris_simple_scatterplot_df, x='sepal_width', y='sepal_length')
# Centre the title, place the legend at the top right and label the axes to match x and y
fig.update_layout(title={'text': 'Sepal Width vs Length','y':0.95,'x':0.5, 'xanchor': 'center','yanchor': 'top'},
legend=dict(yanchor="top",y=0.95,xanchor="right",x=0.95),
autosize=True,margin=dict(t=70,b=0,l=0,r=0), xaxis_title='Sepal Width', yaxis_title='Sepal Length',
font=dict(size=20, family='Times New Roman', color='brown') )
fig.update_xaxes(showline=True, linewidth=1, linecolor='white', gridwidth=3, gridcolor='white', mirror=True)
fig.update_yaxes(showline=True, linewidth=1, linecolor='white', gridwidth=3, gridcolor='white', mirror=True)
fig.show()
Scatter Plot with Grouped Data
# color='class' draws one colour (and one legend entry) per iris class
fig = px.scatter(iris_df, x='sepal_width', y='sepal_length',color='class')
fig.update_layout(title={'text': 'Sepal Width vs Length','y':0.95,'x':0.5, 'xanchor': 'center','yanchor': 'top'},
legend=dict(yanchor="top",y=0.95,xanchor="right",x=0.95),
autosize=True,margin=dict(t=70,b=0,l=0,r=0), xaxis_title='Sepal Width', yaxis_title='Sepal Length',
font=dict(size=20, family='Times New Roman', color='brown') )
fig.update_xaxes(showline=True, linewidth=1, linecolor='white', gridwidth=3, gridcolor='white', mirror=True)
fig.update_yaxes(showline=True, linewidth=1, linecolor='white', gridwidth=3, gridcolor='white', mirror=True)
fig.show()
Scatter Plot with Trendline
Plotly can overlay a trendline on a scatter plot. The trendline='ols' option fits an ordinary least squares (OLS) regression line, which plotly computes with the statsmodels library, so install statsmodels first by running the command below in a terminal:
pip install statsmodels --upgrade
Now let’s plot the scatter plot and a trendline.
iris_simple_scatterplot_df=iris_df[iris_df['class'].isin(['Iris-setosa'])]
# trendline="ols" overlays an ordinary least squares fit of sepal_length on sepal_width
fig = px.scatter(iris_simple_scatterplot_df, x='sepal_width', y='sepal_length', trendline="ols")
fig.update_layout(title={'text': 'Sepal Width vs Length','y':0.95,'x':0.5, 'xanchor': 'center','yanchor': 'top'},
legend=dict(yanchor="top",y=0.95,xanchor="right",x=0.95),
autosize=True,margin=dict(t=70,b=0,l=0,r=0), xaxis_title='Sepal Width', yaxis_title='Sepal Length',
font=dict(size=20, family='Times New Roman', color='brown') )
fig.update_xaxes(showline=True, linewidth=1, linecolor='white', gridwidth=3, gridcolor='white', mirror=True)
fig.update_yaxes(showline=True, linewidth=1, linecolor='white', gridwidth=3, gridcolor='white', mirror=True)
fig.show()
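For a single predictor, the OLS line that the trendline represents has a closed form: the slope is the covariance of x and y divided by the variance of x, and the line passes through the point of means. The plain-Python sketch below (hypothetical helper `ols_fit`) shows that calculation; plotly itself delegates the fit to statsmodels, so this is an illustration of the maths rather than plotly's code.

```python
def ols_fit(xs, ys):
    """Ordinary least squares line y = slope * x + intercept for one predictor."""
    n = len(xs)
    if n != len(ys) or n < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    if sxx == 0:
        raise ValueError("all x values are identical; the slope is undefined")
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    # The fitted line always passes through (mean_x, mean_y)
    intercept = mean_y - slope * mean_x
    return slope, intercept


slope, intercept = ols_fit([1, 2, 3, 4], [3, 5, 7, 9])
print(slope, intercept)  # 2.0 1.0
```

A positive slope corresponds to a positive relationship and a negative slope to a negative one, which is why the trendline makes the direction easy to read.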
Scatter Plot with Trendline and Grouped Data
# With color='class', a separate OLS trendline is fitted for each class
fig = px.scatter(iris_df, x='sepal_width', y='sepal_length',color='class',trendline='ols')
fig.update_layout(title={'text': 'Sepal Width vs Length','y':0.95,'x':0.5, 'xanchor': 'center','yanchor': 'top'},
legend=dict(yanchor="top",y=0.95,xanchor="right",x=0.95),
autosize=True,margin=dict(t=70,b=0,l=0,r=0), xaxis_title='Sepal Width', yaxis_title='Sepal Length',
font=dict(size=20, family='Times New Roman', color='brown') )
fig.update_xaxes(showline=True, linewidth=1, linecolor='white', gridwidth=3, gridcolor='white', mirror=True)
fig.update_yaxes(showline=True, linewidth=1, linecolor='white', gridwidth=3, gridcolor='white', mirror=True)
fig.show()
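Bubble Chart
A bubble chart is a scatter plot in which the size of each marker encodes a third numeric variable, here the population.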
gdpPercap_df=pd.read_csv('gdpPercap.csv')
gdpPercap_df.head()
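# Keep only 2007. This assumes the year column is parsed as integers (pandas' default for
# whole numbers), so compare with 2007 rather than the string '2007', which would match nothing
fig = px.scatter(gdpPercap_df[gdpPercap_df['year'].isin([2007])], x='gdpPercap',
y='lifeExp',color='continent',size='pop',log_x=True)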
fig.update_layout(title={'text': 'GDP Per Capita vs Life Expectancy','y':0.95,'x':0.5, 'xanchor': 'center','yanchor': 'top'},
legend=dict(yanchor="bottom",y=0.095,xanchor="right",x=0.95),
autosize=True,margin=dict(t=70,b=0,l=0,r=0), xaxis_title='GDP Per Capita', yaxis_title='Life Expectancy',
font=dict(size=20, family='Times New Roman', color='brown') )
fig.update_xaxes(showline=True, linewidth=1, linecolor='white', gridwidth=3, gridcolor='white', mirror=True)
fig.update_yaxes(showline=True, linewidth=1, linecolor='white', gridwidth=3, gridcolor='white', mirror=True)
fig.show()
For the complete code, check the Jupyter notebook here.
Conclusion
In this post we have looked at the scatter plot and how to use it. Scatter plots are useful for depicting the relationship between two continuous variables, optionally grouped by a categorical variable. In the next post we will look at area charts and how to create them in plotly. To learn about pie charts in plotly, check our previous post here. To learn about scatter plots in seaborn, check our post here.