Cohort Analysis with Python
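Using Python to understand user retention is easier than you might think, and it only takes a few libraries. A cohort analysis groups customers by the month of their first purchase and then tracks how many of each group come back in every following month. For this tutorial, we will use pandas, Matplotlib, Seaborn and Python's built-in datetime module.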
You can follow along with the video to save time if you are in a rush.
You can also work through the code below to make sure you know how to build the full cohort analysis in Python. Make sure you change the file path to match where your file is saved. The full Jupyter notebook is hosted on GitHub here.
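If you want to try the steps before downloading the spreadsheet, a tiny stand-in DataFrame works. This is a minimal sketch that assumes only the two columns the code relies on, `CustomerID` and `InvoiceDate`; the IDs and dates are made up for illustration and are not taken from the Online Retail file. It also runs the same `dropna` cleaning step the tutorial uses.

```python
import pandas as pd

# Hypothetical stand-in for the Online Retail sheet: made-up IDs and dates,
# including one row that has no CustomerID
data = pd.DataFrame(
    {
        "CustomerID": [17850.0, 17850.0, None, 13047.0],
        "InvoiceDate": pd.to_datetime(
            ["2010-12-01 08:26", "2011-01-10 09:00", "2011-01-12 10:30", "2011-02-03 14:15"]
        ),
    }
)

# Same cleaning step as the tutorial: drop rows with no customer ID
data = data.dropna(subset=["CustomerID"])
print(len(data))  # 3
```

Using this `data` in place of the `pd.read_excel` line lets the remaining steps run, although a sample this small produces a very small heatmap.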
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# load in the data and take a look
data = pd.read_excel("C:/Users/User/Downloads/Online Retail.xlsx")
#drop rows with no customer ID
data = data.dropna(subset=['CustomerID'])
#create an invoice month
import datetime as dt
# function that truncates a timestamp to the first day of its month
def get_month(x):
    return dt.datetime(x.year, x.month, 1)
#apply the function
data['InvoiceMonth'] = data['InvoiceDate'].apply(get_month)
# create a Cohort Month column: each customer's earliest invoice month, i.e. the month they were acquired
data['Cohort Month'] = data.groupby('CustomerID')['InvoiceMonth'].transform('min')
data.head(30)
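Because `apply(get_month)` calls a Python function once per row, it can be slow on large files. One vectorized alternative (a swap, not the author's method) is to convert each timestamp to a monthly `Period` and back to a timestamp at the start of that month. The sketch below uses a few made-up rows, checks that this gives the same `InvoiceMonth` values as `get_month`, and then builds `Cohort Month` with the same `groupby(...).transform('min')` call as the tutorial.

```python
import datetime as dt

import pandas as pd

# Hypothetical rows: customer 17850 buys in Dec 2010 and Jan 2011,
# customer 13047 only in Feb 2011
data = pd.DataFrame(
    {
        "CustomerID": [17850.0, 17850.0, 13047.0],
        "InvoiceDate": pd.to_datetime(
            ["2010-12-01 08:26", "2011-01-10 09:00", "2011-02-03 14:15"]
        ),
    }
)

def get_month(x):
    # the tutorial's row-by-row version, kept here for comparison
    return dt.datetime(x.year, x.month, 1)

# Vectorized: truncate to a monthly Period, then back to the first day of that month
data["InvoiceMonth"] = data["InvoiceDate"].dt.to_period("M").dt.to_timestamp()

# Earliest invoice month per customer, broadcast back to every row
data["Cohort Month"] = data.groupby("CustomerID")["InvoiceMonth"].transform("min")
print(data)
```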
# helper that returns the day, month and year of a datetime column as separate Series, so they can be subtracted
def get_date_elements(df, column):
day = df[column].dt.day
month = df[column].dt.month
year = df[column].dt.year
return day, month, year
# get date elements for our cohort and invoice columns
_,Invoice_month,Invoice_year = get_date_elements(data,'InvoiceMonth')
_,Cohort_month,Cohort_year = get_date_elements(data,'Cohort Month')
# create a cohort index: months since acquisition, where 1 is the acquisition month itself
year_diff = Invoice_year - Cohort_year
month_diff = Invoice_month - Cohort_month
data['CohortIndex'] = year_diff * 12 + month_diff + 1
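To see why the formula is `year_diff*12 + month_diff + 1`, take a hypothetical customer acquired in December 2010 who buys again in February 2011: (2011 - 2010) * 12 + (2 - 12) + 1 = 12 - 10 + 1 = 3, so that purchase lands in column 3, while any purchase in the acquisition month itself gets index 1. The sketch below encodes this in a hypothetical helper, `cohort_index`, and cross-checks it against subtracting pandas monthly `Period` values, which is an alternative technique rather than the one the tutorial uses.

```python
import pandas as pd

def cohort_index(inv_year, inv_month, coh_year, coh_month):
    # hypothetical helper: months elapsed since acquisition, counting the
    # acquisition month as 1 (same arithmetic as the tutorial)
    return (inv_year - coh_year) * 12 + (inv_month - coh_month) + 1

# Acquired December 2010, purchased again February 2011
print(cohort_index(2011, 2, 2010, 12))  # 3

# Alternative: subtracting monthly Periods yields month offsets; .n is the count
invoice = pd.Series(pd.to_datetime(["2011-02-15", "2010-12-01"])).dt.to_period("M")
cohort = pd.Series(pd.to_datetime(["2010-12-03", "2010-12-03"])).dt.to_period("M")
period_index = (invoice - cohort).apply(lambda offset: offset.n) + 1
print(period_index.tolist())  # [3, 1]
```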
# count unique customers in each (Cohort Month, CohortIndex) group
cohort_data = data.groupby(['Cohort Month', 'CohortIndex'])['CustomerID'].nunique().reset_index()
# create a pivot table
cohort_table = cohort_data.pivot(index='Cohort Month', columns='CohortIndex', values='CustomerID')
# format the index as readable labels such as 'December 2010'
cohort_table.index = cohort_table.index.strftime('%B %Y')
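The `groupby`/`pivot` pair turns long data (one row per cohort month and index) into the familiar triangle. The sketch below uses three hypothetical customers so the result can be checked by hand: A and B join in December 2010 and C joins in January 2011. Combinations with no customers become `NaN`, which is why the lower-right corner of the real table is empty.

```python
import pandas as pd

# Hypothetical rows that already carry Cohort Month and CohortIndex
data = pd.DataFrame(
    {
        "CustomerID": ["A", "A", "A", "B", "B", "C", "C"],
        "Cohort Month": pd.to_datetime(["2010-12-01"] * 5 + ["2011-01-01"] * 2),
        "CohortIndex": [1, 1, 2, 1, 3, 1, 2],
    }
)

# Unique customers per (Cohort Month, CohortIndex); A's two index-1 rows count once
cohort_data = (
    data.groupby(["Cohort Month", "CohortIndex"])["CustomerID"].nunique().reset_index()
)

# One row per cohort, one column per cohort index
cohort_table = cohort_data.pivot(index="Cohort Month", columns="CohortIndex", values="CustomerID")
cohort_table.index = cohort_table.index.strftime("%B %Y")
print(cohort_table)
```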
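# visualize the customer counts as a heatmap
plt.figure(figsize=(21, 10))
# fmt='.0f' prints whole numbers; the default '.2g' would show a count like 948 as 9.5e+02
sns.heatmap(cohort_table, annot=True, fmt='.0f', cmap='Blues')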
# retention rate: divide each row by its cohort size (the first column)
new_cohort_table = cohort_table.divide(cohort_table.iloc[:,0],axis=0)
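Passing `axis=0` to `divide` matches the divisor to the rows, so every cell in a row is divided by that cohort's size in column 1; the default (`axis='columns'`) would instead try to match the cohort labels against the column labels. With made-up counts, a cohort of 200 customers that still has 50 active customers two months later shows 50 / 200 = 25% in column 3. The counts below are hypothetical.

```python
import pandas as pd

# Hypothetical cohort counts: rows are cohorts, columns are CohortIndex
cohort_table = pd.DataFrame(
    {1: [200, 80], 2: [70, 20], 3: [50, None]},
    index=["December 2010", "January 2011"],
)

# axis=0 aligns the divisor (column 1, the cohort size) with the rows
retention = cohort_table.divide(cohort_table.iloc[:, 0], axis=0)
print(retention.round(3))
```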
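# visualize the retention rates as percentages
plt.figure(figsize=(21, 10))
sns.heatmap(new_cohort_table, annot=True, fmt='.0%')
# plt.show() displays the figures when running as a script; Jupyter renders them automatically
plt.show()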
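Because column 1 is always 100%, it stretches the color scale and can wash out the later months. One optional tweak, sketched below on made-up retention values, is to cap the scale with `vmax` at the largest value outside column 1, label the axes, and save the figure with `savefig`. The file name `retention_heatmap.png` and the non-interactive `Agg` backend are assumptions so the sketch runs without a display.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so no display is needed

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical retention rates (fractions), shaped like new_cohort_table
retention = pd.DataFrame(
    {1: [1.0, 1.0], 2: [0.5, 0.25], 3: [0.4, None]},
    index=["December 2010", "January 2011"],
)

# Cap the color scale at the largest rate outside column 1 (0.5 here);
# the 100% cells are then drawn in the darkest color instead of setting the scale
vmax = retention.iloc[:, 1:].max().max()

fig, ax = plt.subplots(figsize=(8, 3))
sns.heatmap(retention, annot=True, fmt=".0%", cmap="Blues", vmin=0, vmax=vmax, ax=ax)
ax.set_title("Monthly retention by cohort")
ax.set_xlabel("Cohort index (months since first purchase)")
ax.set_ylabel("Cohort month")
fig.tight_layout()
fig.savefig("retention_heatmap.png", dpi=150)  # hypothetical output file name
```

Capping `vmax` is a display choice only; the annotated percentages still show the true values.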
