Conditionally Dropping Columns
filepath = '../data/movieData.csv'
What We’re Used To
Dropping columns of data in pandas is a pretty trivial task.
import pandas as pd
df = pd.read_csv(filepath)
df.head()
| | Rank | WeeklyGross | PctChangeWkGross | Theaters | DeltaTheaters | AvgRev | GrossToDate | Week | Thursday | name | year | Winner |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 17.0 | 967378 | NaN | 14.0 | NaN | 69098.0 | 967378 | 1 | 1990-11-18 | dances with wolves | 1990 | True |
1 | 9.0 | 3871641 | 300.0 | 14.0 | NaN | 276546.0 | 4839019 | 2 | 1990-11-25 | dances with wolves | 1990 | True |
2 | 3.0 | 12547813 | 224.0 | 1048.0 | 1034.0 | 11973.0 | 17386832 | 3 | 1990-12-02 | dances with wolves | 1990 | True |
3 | 4.0 | 9246632 | -26.3 | 1053.0 | 5.0 | 8781.0 | 26633464 | 4 | 1990-12-09 | dances with wolves | 1990 | True |
4 | 4.0 | 7272350 | -21.4 | 1051.0 | -2.0 | 6919.0 | 33905814 | 5 | 1990-12-16 | dances with wolves | 1990 | True |
We can either specify which columns we want to drop.
df.drop(['Rank', 'WeeklyGross'], axis=1).head()
| | PctChangeWkGross | Theaters | DeltaTheaters | AvgRev | GrossToDate | Week | Thursday | name | year | Winner |
---|---|---|---|---|---|---|---|---|---|---|
0 | NaN | 14.0 | NaN | 69098.0 | 967378 | 1 | 1990-11-18 | dances with wolves | 1990 | True |
1 | 300.0 | 14.0 | NaN | 276546.0 | 4839019 | 2 | 1990-11-25 | dances with wolves | 1990 | True |
2 | 224.0 | 1048.0 | 1034.0 | 11973.0 | 17386832 | 3 | 1990-12-02 | dances with wolves | 1990 | True |
3 | -26.3 | 1053.0 | 5.0 | 8781.0 | 26633464 | 4 | 1990-12-09 | dances with wolves | 1990 | True |
4 | -21.4 | 1051.0 | -2.0 | 6919.0 | 33905814 | 5 | 1990-12-16 | dances with wolves | 1990 | True |
Or write some condition to filter on and pipe it into the DataFrame selector. Let’s imagine we only want columns that have a '.' in them.
dotCounts = df.apply(lambda x: x.map(str)
                     .str.contains(r'\.')).sum()
dotCounts
Rank 3836
WeeklyGross 0
PctChangeWkGross 3625
Theaters 3836
DeltaTheaters 3389
AvgRev 3836
GrossToDate 0
Week 0
Thursday 0
name 13
year 0
Winner 0
dtype: int64
colsWithdots = dotCounts[dotCounts != 0].index
print(colsWithdots)
Index(['Rank', 'PctChangeWkGross', 'Theaters', 'DeltaTheaters', 'AvgRev',
'name'],
dtype='object')
df[colsWithdots].head()
| | Rank | PctChangeWkGross | Theaters | DeltaTheaters | AvgRev | name |
---|---|---|---|---|---|---|
0 | 17.0 | NaN | 14.0 | NaN | 69098.0 | dances with wolves |
1 | 9.0 | 300.0 | 14.0 | NaN | 276546.0 | dances with wolves |
2 | 3.0 | 224.0 | 1048.0 | 1034.0 | 11973.0 | dances with wolves |
3 | 4.0 | -26.3 | 1053.0 | 5.0 | 8781.0 | dances with wolves |
4 | 4.0 | -21.4 | 1051.0 | -2.0 | 6919.0 | dances with wolves |
Ez pz
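For completeness, the same selection can be collapsed into a single boolean-indexing expression on the columns axis; a quick sketch, equivalent to the two-step version above:
# keep only the columns whose stringified values contain at least one '.'
df.loc[:, df.apply(lambda x: x.map(str).str.contains(r'\.')).sum() > 0].head()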
df['name'][df['name'].str.contains(r'\.')]
759 l.a. confidential
760 l.a. confidential
761 l.a. confidential
762 l.a. confidential
763 l.a. confidential
764 l.a. confidential
765 l.a. confidential
766 l.a. confidential
767 l.a. confidential
768 l.a. confidential
769 l.a. confidential
770 l.a. confidential
771 l.a. confidential
Name: name, dtype: object
I was curious, too.
Now Spark
Similarly, if we want to read this in as a Spark DataFrame, we’d do the following.
import findspark
findspark.init()
import pyspark
sc = pyspark.SparkContext()
spark = pyspark.sql.SparkSession(sc)
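(Aside: if you’re not going through findspark, the more common way to get a session is the builder API; the app name below is just an arbitrary label. Either way, the rest of the post works the same.)
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('movieData').getOrCreate()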
df = spark.read.csv(filepath, header=True)
df.show(5)
+----+-----------+----------------+--------+-------------+--------+-----------+----+----------+------------------+----+------+
|Rank|WeeklyGross|PctChangeWkGross|Theaters|DeltaTheaters| AvgRev|GrossToDate|Week| Thursday| name|year|Winner|
+----+-----------+----------------+--------+-------------+--------+-----------+----+----------+------------------+----+------+
|17.0| 967378| null| 14.0| null| 69098.0| 967378| 1|1990-11-18|dances with wolves|1990| True|
| 9.0| 3871641| 300.0| 14.0| null|276546.0| 4839019| 2|1990-11-25|dances with wolves|1990| True|
| 3.0| 12547813| 224.0| 1048.0| 1034.0| 11973.0| 17386832| 3|1990-12-02|dances with wolves|1990| True|
| 4.0| 9246632| -26.3| 1053.0| 5.0| 8781.0| 26633464| 4|1990-12-09|dances with wolves|1990| True|
| 4.0| 7272350| -21.4| 1051.0| -2.0| 6919.0| 33905814| 5|1990-12-16|dances with wolves|1990| True|
+----+-----------+----------------+--------+-------------+--------+-----------+----+----------+------------------+----+------+
only showing top 5 rows
But dropping columns based on a condition is a little more involved.
We can specify which column names we want to keep.
df.select('Rank', 'PctChangeWkGross', 'Theaters',
          'DeltaTheaters', 'AvgRev', 'name').show(5)
+----+----------------+--------+-------------+--------+------------------+
|Rank|PctChangeWkGross|Theaters|DeltaTheaters| AvgRev| name|
+----+----------------+--------+-------------+--------+------------------+
|17.0| null| 14.0| null| 69098.0|dances with wolves|
| 9.0| 300.0| 14.0| null|276546.0|dances with wolves|
| 3.0| 224.0| 1048.0| 1034.0| 11973.0|dances with wolves|
| 4.0| -26.3| 1053.0| 5.0| 8781.0|dances with wolves|
| 4.0| -21.4| 1051.0| -2.0| 6919.0|dances with wolves|
+----+----------------+--------+-------------+--------+------------------+
only showing top 5 rows
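For the record, Spark DataFrames have a drop method too, so we could just as easily name the columns we want gone; a sketch of the complement of the select above:
df.drop('WeeklyGross', 'GrossToDate', 'Week', 'Thursday', 'year', 'Winner').show(5)
Either way, we’re still hand-writing the column names.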
But in general we won’t know the column names ahead of time, so we need a way to get at them programmatically.
Building Column Conditions
We’re going to need some handy functions to facilitate this.
from pyspark.sql.functions import col, count, when
For simplicity, let’s focus on just the PctChangeWkGross column. We’ll scale up after.
So first, we’ll use the col function on a string to make an instance of the Column class.
pcwg = col('PctChangeWkGross')
So let’s check if each value contains a '.' character.
df.select(
    pcwg.contains('.')
).count()
3845
Hmm, that doesn’t seem right. That’s the length of everything.
df.count()
3845
So we additionally need the count column function.
df.select(
    count(pcwg.contains('.'))
).show()
+------------------------------------+
|count(contains(PctChangeWkGross, .))|
+------------------------------------+
| 3625|
+------------------------------------+
We shed like 200 records. That’s more like it. Let’s see what happens with a column that we know doesn’t have any dots. Thursday is all yyyy-mm-dd.
df.select(
    count(col('Thursday').contains('.'))
).show()
+----------------------------+
|count(contains(Thursday, .))|
+----------------------------+
| 3845|
+----------------------------+
Still wrong. It’s counting a bunch of False values. Last step: we need a way to make it not count these values.
This is where we’ll use the when function. The top 5 values of Thursday look like this.
df.select('Thursday').show(5)
+----------+
| Thursday|
+----------+
|1990-11-18|
|1990-11-25|
|1990-12-02|
|1990-12-09|
|1990-12-16|
+----------+
only showing top 5 rows
when takes two arguments:
- A Column of values with a broadcasted check
- A value for “if the check evaluates to True”
Anything that evaluates to False becomes NULL.
df.select((when(col('Thursday').contains('11'), 0))).show(5)
+-------------------------------------------+
|CASE WHEN contains(Thursday, 11) THEN 0 END|
+-------------------------------------------+
| 0|
| 0|
| null|
| null|
| null|
+-------------------------------------------+
only showing top 5 rows
And count will skip right over it.
df.select(count(when(col('Thursday').contains('11'), 0))).show(5)
+--------------------------------------------------+
|count(CASE WHEN contains(Thursday, 11) THEN 0 END)|
+--------------------------------------------------+
| 578|
+--------------------------------------------------+
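As an aside, when can be chained with otherwise to fill in the False branch instead of NULL; we deliberately leave it off here, since the NULLs are exactly what lets count do the filtering. A sketch of what it would look like:
df.select(when(col('Thursday').contains('11'), 0).otherwise(1)).show(5)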
At Scale
Now to do this for multiple columns, we need some clever list comprehension.
df.select(
    [count(when(col(x).contains('.'), 0)) for x in df.columns]
).collect()
[Row(count(CASE WHEN contains(Rank, .) THEN 0 END)=3836, count(CASE WHEN contains(WeeklyGross, .) THEN 0 END)=0, count(CASE WHEN contains(PctChangeWkGross, .) THEN 0 END)=3625, count(CASE WHEN contains(Theaters, .) THEN 0 END)=3836, count(CASE WHEN contains(DeltaTheaters, .) THEN 0 END)=3389, count(CASE WHEN contains(AvgRev, .) THEN 0 END)=3836, count(CASE WHEN contains(GrossToDate, .) THEN 0 END)=0, count(CASE WHEN contains(Week, .) THEN 0 END)=0, count(CASE WHEN contains(Thursday, .) THEN 0 END)=0, count(CASE WHEN contains(name, .) THEN 0 END)=13, count(CASE WHEN contains(year, .) THEN 0 END)=0, count(CASE WHEN contains(Winner, .) THEN 0 END)=0)]
But that looks crazy gross. alias to the rescue.
dotCounts = df.select(
    [count(when(col(x).contains('.'), 0)).alias(x) for x in df.columns]
)
dotCounts.show(5)
+----+-----------+----------------+--------+-------------+------+-----------+----+--------+----+----+------+
|Rank|WeeklyGross|PctChangeWkGross|Theaters|DeltaTheaters|AvgRev|GrossToDate|Week|Thursday|name|year|Winner|
+----+-----------+----------------+--------+-------------+------+-----------+----+--------+----+----+------+
|3836| 0| 3625| 3836| 3389| 3836| 0| 0| 0| 13| 0| 0|
+----+-----------+----------------+--------+-------------+------+-----------+----+--------+----+----+------+
Let’s bring it home with some list comprehension, keeping only the columns that don’t have any dots.
colsWithoutDots = [c for c in dotCounts.columns
                   if dotCounts[[c]].first()[c] == 0]
df.select(colsWithoutDots).show(5)
+-----------+-----------+----+----------+----+------+
|WeeklyGross|GrossToDate|Week| Thursday|year|Winner|
+-----------+-----------+----+----------+----+------+
| 967378| 967378| 1|1990-11-18|1990| True|
| 3871641| 4839019| 2|1990-11-25|1990| True|
| 12547813| 17386832| 3|1990-12-02|1990| True|
| 9246632| 26633464| 4|1990-12-09|1990| True|
| 7272350| 33905814| 5|1990-12-16|1990| True|
+-----------+-----------+----+----------+----+------+
only showing top 5 rows
Dropping Columns with NULL Data
The same general approach applies to a more practical task: dropping columns with NULL values.
from pyspark.sql.functions import isnan, isnull
nullCounts = df.select([count(when(isnan(c) | isnull(c), c)).alias(c)
                        for c in df.columns])
nullCounts.show()
+----+-----------+----------------+--------+-------------+------+-----------+----+--------+----+----+------+
|Rank|WeeklyGross|PctChangeWkGross|Theaters|DeltaTheaters|AvgRev|GrossToDate|Week|Thursday|name|year|Winner|
+----+-----------+----------------+--------+-------------+------+-----------+----+--------+----+----+------+
| 9| 0| 220| 9| 456| 9| 0| 0| 0| 0| 0| 0|
+----+-----------+----------------+--------+-------------+------+-----------+----+--------+----+----+------+
nonNullCols = [c for c in nullCounts.columns
               if nullCounts[[c]].first()[c] == 0]
df.select(nonNullCols).show(5)
+-----------+-----------+----+----------+------------------+----+------+
|WeeklyGross|GrossToDate|Week| Thursday| name|year|Winner|
+-----------+-----------+----+----------+------------------+----+------+
| 967378| 967378| 1|1990-11-18|dances with wolves|1990| True|
| 3871641| 4839019| 2|1990-11-25|dances with wolves|1990| True|
| 12547813| 17386832| 3|1990-12-02|dances with wolves|1990| True|
| 9246632| 26633464| 4|1990-12-09|dances with wolves|1990| True|
| 7272350| 33905814| 5|1990-12-16|dances with wolves|1990| True|
+-----------+-----------+----+----------+------------------+----+------+
only showing top 5 rows
This is probably handy enough to package into a function.
def drop_null_cols(frame, threshold):
    '''
    Drop columns from a pyspark DataFrame that have
    `threshold` or more NULL values
    '''
    nullCounts = frame.select([count(when(isnan(c) | isnull(c), c))
                               .alias(c) for c in frame.columns])
    nonNullCols = [c for c in nullCounts.columns
                   if nullCounts[[c]].first()[c] < threshold]
    return frame.select(nonNullCols)
drop_null_cols(df, 100).show(5)
+----+-----------+--------+--------+-----------+----+----------+------------------+----+------+
|Rank|WeeklyGross|Theaters| AvgRev|GrossToDate|Week| Thursday| name|year|Winner|
+----+-----------+--------+--------+-----------+----+----------+------------------+----+------+
|17.0| 967378| 14.0| 69098.0| 967378| 1|1990-11-18|dances with wolves|1990| True|
| 9.0| 3871641| 14.0|276546.0| 4839019| 2|1990-11-25|dances with wolves|1990| True|
| 3.0| 12547813| 1048.0| 11973.0| 17386832| 3|1990-12-02|dances with wolves|1990| True|
| 4.0| 9246632| 1053.0| 8781.0| 26633464| 4|1990-12-09|dances with wolves|1990| True|
| 4.0| 7272350| 1051.0| 6919.0| 33905814| 5|1990-12-16|dances with wolves|1990| True|
+----+-----------+--------+--------+-----------+----+----------+------------------+----+------+
only showing top 5 rows
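One natural tweak, sketched here as a hypothetical wrapper rather than anything from the original: express the threshold as a fraction of the row count, so the same call behaves sensibly across datasets of different sizes.
def drop_null_cols_frac(frame, frac=0.05):
    '''
    Drop columns whose NULL count is at least `frac`
    of the total rows (hypothetical wrapper around drop_null_cols)
    '''
    return drop_null_cols(frame, frame.count() * frac)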