Conditionally Dropping Columns
filepath = '../data/movieData.csv'
What We’re Used To
Dropping columns of data in pandas is a pretty trivial task.
import pandas as pd
df = pd.read_csv(filepath)
df.head()
| | Rank | WeeklyGross | PctChangeWkGross | Theaters | DeltaTheaters | AvgRev | GrossToDate | Week | Thursday | name | year | Winner |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 17.0 | 967378 | NaN | 14.0 | NaN | 69098.0 | 967378 | 1 | 1990-11-18 | dances with wolves | 1990 | True |
| 1 | 9.0 | 3871641 | 300.0 | 14.0 | NaN | 276546.0 | 4839019 | 2 | 1990-11-25 | dances with wolves | 1990 | True |
| 2 | 3.0 | 12547813 | 224.0 | 1048.0 | 1034.0 | 11973.0 | 17386832 | 3 | 1990-12-02 | dances with wolves | 1990 | True |
| 3 | 4.0 | 9246632 | -26.3 | 1053.0 | 5.0 | 8781.0 | 26633464 | 4 | 1990-12-09 | dances with wolves | 1990 | True |
| 4 | 4.0 | 7272350 | -21.4 | 1051.0 | -2.0 | 6919.0 | 33905814 | 5 | 1990-12-16 | dances with wolves | 1990 | True |
We can either specify which columns we want to drop.
df.drop(['Rank', 'WeeklyGross'], axis=1).head()
| | PctChangeWkGross | Theaters | DeltaTheaters | AvgRev | GrossToDate | Week | Thursday | name | year | Winner |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | NaN | 14.0 | NaN | 69098.0 | 967378 | 1 | 1990-11-18 | dances with wolves | 1990 | True |
| 1 | 300.0 | 14.0 | NaN | 276546.0 | 4839019 | 2 | 1990-11-25 | dances with wolves | 1990 | True |
| 2 | 224.0 | 1048.0 | 1034.0 | 11973.0 | 17386832 | 3 | 1990-12-02 | dances with wolves | 1990 | True |
| 3 | -26.3 | 1053.0 | 5.0 | 8781.0 | 26633464 | 4 | 1990-12-09 | dances with wolves | 1990 | True |
| 4 | -21.4 | 1051.0 | -2.0 | 6919.0 | 33905814 | 5 | 1990-12-16 | dances with wolves | 1990 | True |
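(Equivalently, recent pandas versions let you name the columns directly with the columns keyword instead of axis=1; same result:)

```python
# same as passing axis=1
df.drop(columns=['Rank', 'WeeklyGross']).head()
```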
Or write some condition to filter on and pipe it into the DataFrame selector.
Let’s imagine we only want columns that have a '.' in them.
dotCounts = df.apply(lambda x: x.map(str)
                              .str.contains(r'\.')).sum()
dotCounts
Rank 3836
WeeklyGross 0
PctChangeWkGross 3625
Theaters 3836
DeltaTheaters 3389
AvgRev 3836
GrossToDate 0
Week 0
Thursday 0
name 13
year 0
Winner 0
dtype: int64
colsWithDots = dotCounts[dotCounts != 0].index
print(colsWithDots)
Index(['Rank', 'PctChangeWkGross', 'Theaters', 'DeltaTheaters', 'AvgRev',
       'name'],
      dtype='object')
df[colsWithDots].head()
| | Rank | PctChangeWkGross | Theaters | DeltaTheaters | AvgRev | name |
|---|---|---|---|---|---|---|
| 0 | 17.0 | NaN | 14.0 | NaN | 69098.0 | dances with wolves |
| 1 | 9.0 | 300.0 | 14.0 | NaN | 276546.0 | dances with wolves |
| 2 | 3.0 | 224.0 | 1048.0 | 1034.0 | 11973.0 | dances with wolves |
| 3 | 4.0 | -26.3 | 1053.0 | 5.0 | 8781.0 | dances with wolves |
| 4 | 4.0 | -21.4 | 1051.0 | -2.0 | 6919.0 | dances with wolves |
Ez pz
df['name'][df['name'].str.contains(r'\.')]
759 l.a. confidential
760 l.a. confidential
761 l.a. confidential
762 l.a. confidential
763 l.a. confidential
764 l.a. confidential
765 l.a. confidential
766 l.a. confidential
767 l.a. confidential
768 l.a. confidential
769 l.a. confidential
770 l.a. confidential
771 l.a. confidential
Name: name, dtype: object
I was curious, too.
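And since this post is about dropping columns rather than keeping them, the same index works just as well in reverse; a quick sketch using the colsWithDots from above:

```python
# drop every column whose values contain a '.'
df.drop(columns=colsWithDots).head()
```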
Now Spark
Similarly, if we want to read this in as a Spark DataFrame, we’d do the following.
import findspark
findspark.init()
import pyspark
sc = pyspark.SparkContext()
spark = pyspark.sql.SparkSession(sc)
df = spark.read.csv(filepath, header=True)
df.show(5)
+----+-----------+----------------+--------+-------------+--------+-----------+----+----------+------------------+----+------+
|Rank|WeeklyGross|PctChangeWkGross|Theaters|DeltaTheaters| AvgRev|GrossToDate|Week| Thursday| name|year|Winner|
+----+-----------+----------------+--------+-------------+--------+-----------+----+----------+------------------+----+------+
|17.0| 967378| null| 14.0| null| 69098.0| 967378| 1|1990-11-18|dances with wolves|1990| True|
| 9.0| 3871641| 300.0| 14.0| null|276546.0| 4839019| 2|1990-11-25|dances with wolves|1990| True|
| 3.0| 12547813| 224.0| 1048.0| 1034.0| 11973.0| 17386832| 3|1990-12-02|dances with wolves|1990| True|
| 4.0| 9246632| -26.3| 1053.0| 5.0| 8781.0| 26633464| 4|1990-12-09|dances with wolves|1990| True|
| 4.0| 7272350| -21.4| 1051.0| -2.0| 6919.0| 33905814| 5|1990-12-16|dances with wolves|1990| True|
+----+-----------+----------------+--------+-------------+--------+-----------+----+----------+------------------+----+------+
only showing top 5 rows
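One thing worth noticing before we go on: we didn’t pass inferSchema=True, so every column comes back typed as a string. That’s convenient here, because the substring checks below then apply to the numeric columns much like the map(str) trick did in pandas. A quick way to confirm (output not shown):

```python
# with header=True and no inferSchema, every column is read as a string
df.printSchema()
```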
But trying to drop columns is a little involved.
We can specify which column names we want to keep.
df.select('Rank', 'PctChangeWkGross', 'Theaters',
          'DeltaTheaters', 'AvgRev', 'name').show(5)
+----+----------------+--------+-------------+--------+------------------+
|Rank|PctChangeWkGross|Theaters|DeltaTheaters| AvgRev| name|
+----+----------------+--------+-------------+--------+------------------+
|17.0| null| 14.0| null| 69098.0|dances with wolves|
| 9.0| 300.0| 14.0| null|276546.0|dances with wolves|
| 3.0| 224.0| 1048.0| 1034.0| 11973.0|dances with wolves|
| 4.0| -26.3| 1053.0| 5.0| 8781.0|dances with wolves|
| 4.0| -21.4| 1051.0| -2.0| 6919.0|dances with wolves|
+----+----------------+--------+-------------+--------+------------------+
only showing top 5 rows
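As an aside, Spark DataFrames do have a drop method, so if we already knew which columns were the offenders we could name them directly; a quick sketch (output not shown):

```python
# drop known columns by name; equivalent to selecting everything else
df.drop('Rank', 'WeeklyGross').show(5)
```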
But in general we won’t know the column names ahead of time, so we need to figure out how to get at them procedurally.
Building Column Conditions
We’re going to need some handy functions to facilitate this.
from pyspark.sql.functions import col, count, when
For simplicity, let’s focus on just the PctChangeWkGross column. We’ll scale up after.
So first, we’ll use the col function on a string to make an instance of the Column class.
pcwg = col('PctChangeWkGross')
So let’s check if each value contains a '.' character.
df.select(
pcwg.contains('.')
).count()
3845
Hmm, that doesn’t seem right. That’s the length of everything, because DataFrame.count() counts rows, and contains() produces a value (True, False, or null) for every single row.
df.count()
3845
So we also need to bring in the count column function.
df.select(
count(pcwg.contains('.'))
).show()
+------------------------------------+
|count(contains(PctChangeWkGross, .))|
+------------------------------------+
| 3625|
+------------------------------------+
We shed about 220 records. That’s more like it. Let’s see what happens with a column that we know doesn’t have any dots.
Thursday is all yyyy-mm-dd.
df.select(
count(col('Thursday').contains('.'))
).show()
+----------------------------+
|count(contains(Thursday, .))|
+----------------------------+
| 3845|
+----------------------------+
Still wrong. It’s counting all of the False values too. For the last step, we need a way to keep those out of the count.
This is where we’ll use the when function. The top 5 values of Thursday look like this.
df.select('Thursday').show(5)
+----------+
| Thursday|
+----------+
|1990-11-18|
|1990-11-25|
|1990-12-02|
|1990-12-09|
|1990-12-16|
+----------+
only showing top 5 rows
when takes two arguments:
- A boolean Column condition, broadcast as a check against every row
- A value to return wherever the check evaluates to True
Anything that evaluates to False becomes NULL (we’ll see shortly how to supply a fallback value instead).
df.select(when(col('Thursday').contains('11'), 0)).show(5)
+-------------------------------------------+
|CASE WHEN contains(Thursday, 11) THEN 0 END|
+-------------------------------------------+
| 0|
| 0|
| null|
| null|
| null|
+-------------------------------------------+
only showing top 5 rows
And count will skip right over the NULLs.
df.select(count(when(col('Thursday').contains('11'), 0))).show(5)
+--------------------------------------------------+
|count(CASE WHEN contains(Thursday, 11) THEN 0 END)|
+--------------------------------------------------+
| 578|
+--------------------------------------------------+
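As promised above, if you do want a fallback value instead of NULL for the rows where the check fails, the Column that when builds can be chained with otherwise. A quick sketch (output not shown):

```python
# 0 where Thursday contains '11', 1 everywhere else (no NULLs produced)
df.select(when(col('Thursday').contains('11'), 0).otherwise(1)).show(5)
```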
At Scale
Now to do this for multiple columns, we need a clever list comprehension.
df.select(
[count(when(col(x).contains('.'), 0)) for x in df.columns]
).collect()
[Row(count(CASE WHEN contains(Rank, .) THEN 0 END)=3836, count(CASE WHEN contains(WeeklyGross, .) THEN 0 END)=0, count(CASE WHEN contains(PctChangeWkGross, .) THEN 0 END)=3625, count(CASE WHEN contains(Theaters, .) THEN 0 END)=3836, count(CASE WHEN contains(DeltaTheaters, .) THEN 0 END)=3389, count(CASE WHEN contains(AvgRev, .) THEN 0 END)=3836, count(CASE WHEN contains(GrossToDate, .) THEN 0 END)=0, count(CASE WHEN contains(Week, .) THEN 0 END)=0, count(CASE WHEN contains(Thursday, .) THEN 0 END)=0, count(CASE WHEN contains(name, .) THEN 0 END)=13, count(CASE WHEN contains(year, .) THEN 0 END)=0, count(CASE WHEN contains(Winner, .) THEN 0 END)=0)]
But that looks crazy gross. alias to the rescue.
dotCounts = df.select(
[count(when(col(x).contains('.'), 0)).alias(x) for x in df.columns]
)
dotCounts.show(5)
+----+-----------+----------------+--------+-------------+------+-----------+----+--------+----+----+------+
|Rank|WeeklyGross|PctChangeWkGross|Theaters|DeltaTheaters|AvgRev|GrossToDate|Week|Thursday|name|year|Winner|
+----+-----------+----------------+--------+-------------+------+-----------+----+--------+----+----+------+
|3836| 0| 3625| 3836| 3389| 3836| 0| 0| 0| 13| 0| 0|
+----+-----------+----------------+--------+-------------+------+-----------+----+--------+----+----+------+
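A small aside before the last step: the comprehension below calls first() once per column, which fires off a separate Spark action each time. Since dotCounts is a single row anyway, you could also pull it back once and work with an ordinary dict; a quick sketch, equivalent to what follows:

```python
# collect the one row of counts, then filter locally
counts = dotCounts.first().asDict()
colsWithoutDots = [c for c, n in counts.items() if n == 0]
```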
Let’s bring it home with some list comprehension that keeps only the columns without any dots.
colsWithoutDots = [c for c in dotCounts.columns
                   if dotCounts[[c]].first()[c] == 0]
df.select(colsWithoutDots).show(5)
+-----------+-----------+----+----------+----+------+
|WeeklyGross|GrossToDate|Week| Thursday|year|Winner|
+-----------+-----------+----+----------+----+------+
| 967378| 967378| 1|1990-11-18|1990| True|
| 3871641| 4839019| 2|1990-11-25|1990| True|
| 12547813| 17386832| 3|1990-12-02|1990| True|
| 9246632| 26633464| 4|1990-12-09|1990| True|
| 7272350| 33905814| 5|1990-12-16|1990| True|
+-----------+-----------+----+----------+----+------+
only showing top 5 rows
Dropping Columns with NULL Data
The same general approach applies to a more practical task: dropping columns with NULL values.
from pyspark.sql.functions import isnan, isnull
nullCounts = df.select([count(when(isnan(c)|isnull(c), c)).alias(c) for c in df.columns])
nullCounts.show()
+----+-----------+----------------+--------+-------------+------+-----------+----+--------+----+----+------+
|Rank|WeeklyGross|PctChangeWkGross|Theaters|DeltaTheaters|AvgRev|GrossToDate|Week|Thursday|name|year|Winner|
+----+-----------+----------------+--------+-------------+------+-----------+----+--------+----+----+------+
| 9| 0| 220| 9| 456| 9| 0| 0| 0| 0| 0| 0|
+----+-----------+----------------+--------+-------------+------+-----------+----+--------+----+----+------+
nonNullCols = [c for c in nullCounts.columns if nullCounts[[c]].first()[c] == 0]
df.select(nonNullCols).show(5)
+-----------+-----------+----+----------+------------------+----+------+
|WeeklyGross|GrossToDate|Week| Thursday| name|year|Winner|
+-----------+-----------+----+----------+------------------+----+------+
| 967378| 967378| 1|1990-11-18|dances with wolves|1990| True|
| 3871641| 4839019| 2|1990-11-25|dances with wolves|1990| True|
| 12547813| 17386832| 3|1990-12-02|dances with wolves|1990| True|
| 9246632| 26633464| 4|1990-12-09|dances with wolves|1990| True|
| 7272350| 33905814| 5|1990-12-16|dances with wolves|1990| True|
+-----------+-----------+----+----------+------------------+----+------+
only showing top 5 rows
This is probably handy enough to package into a function.
def drop_null_cols(frame, threshold):
    '''
    Drop columns from a `pyspark.sql.DataFrame` that
    have `threshold` or more NULL values
    '''
    nullCounts = frame.select([count(when(isnan(c)|isnull(c), c))
                               .alias(c) for c in frame.columns])
    nonNullCols = [c for c in nullCounts.columns
                   if nullCounts[[c]].first()[c] < threshold]
    return frame.select(nonNullCols)

drop_null_cols(df, 100).show(5)
+----+-----------+--------+--------+-----------+----+----------+------------------+----+------+
|Rank|WeeklyGross|Theaters| AvgRev|GrossToDate|Week| Thursday| name|year|Winner|
+----+-----------+--------+--------+-----------+----+----------+------------------+----+------+
|17.0| 967378| 14.0| 69098.0| 967378| 1|1990-11-18|dances with wolves|1990| True|
| 9.0| 3871641| 14.0|276546.0| 4839019| 2|1990-11-25|dances with wolves|1990| True|
| 3.0| 12547813| 1048.0| 11973.0| 17386832| 3|1990-12-02|dances with wolves|1990| True|
| 4.0| 9246632| 1053.0| 8781.0| 26633464| 4|1990-12-09|dances with wolves|1990| True|
| 4.0| 7272350| 1051.0| 6919.0| 33905814| 5|1990-12-16|dances with wolves|1990| True|
+----+-----------+--------+--------+-----------+----+----------+------------------+----+------+
only showing top 5 rows
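And since the NULL check and the dot check are really the same pattern, it’s tempting to generalize one step further. A sketch of a hypothetical helper that accepts any Column predicate (not something built into Spark, just an illustration of the same idea):

```python
def drop_cols_by_count(frame, predicate, threshold):
    '''
    Hypothetical generalization of drop_null_cols: drop columns where
    `predicate(col)` holds for `threshold` or more rows
    '''
    counts = frame.select([count(when(predicate(col(c)), c)).alias(c)
                           for c in frame.columns])
    keep = [c for c in counts.columns if counts.first()[c] < threshold]
    return frame.select(keep)

# e.g. drop any column where at least one value contains a '.'
# drop_cols_by_count(df, lambda c: c.contains('.'), 1).show(5)
```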