Essential Data Quality Checks for Data Pipelines: A Comprehensive Guide with PySpark Code Examples
Data quality is of paramount importance for any organization that relies on data-driven decision making. Ensuring the quality of data in a data pipeline is a critical aspect of data engineering, as it helps maintain trust in the data and prevents inaccurate insights or erroneous decisions.
In this blog post, we will delve into 20 essential data quality checks that you should implement in your data pipeline. For each check, we will provide a realistic example and the corresponding PySpark code to help you understand the concept and apply it in your own data engineering projects.
1. Completeness Check
2. Accuracy Check
3. Consistency Check
4. Validity Check
5. Timeliness Check
6. Uniqueness Check
7. Range Check
8. Format Check
9. Statistical Check
10. Referential Integrity Check
11. Duplicated Columns Check
12. Domain-Specific Checks
13. Outlier Check
14. Data Type Check
15. Null Value Check
16. Duplication Check
17. Redundancy Check
18. Integrity Check
19. Precision Check
20. Consistency with Business Rules
1. Completeness Check
Example: In a retail dataset, we need to ensure that all transactions have the necessary information such as transaction_id, product_id, customer_id, and transaction_amount.
PySpark code:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum as spark_sum
# Initialize Spark session
spark = SparkSession.builder \
.appName("Completeness Check") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Check for missing values in each column
missing_values = retail_data.select(*(col(c).isNull().alias(c) for c in retail_data.columns))
# Count missing values
missing_count = missing_values.agg(*(spark_sum(col(c).cast("int")).alias(c) for c in retail_data.columns))
# Display missing value count for each column
missing_count.show()
# Filter out rows with missing values
complete_data = retail_data.dropna()
# Display the number of rows before and after filtering
print(f"Number of rows before filtering: {retail_data.count()}")
print(f"Number of rows after filtering: {complete_data.count()}")
# Stop Spark session
spark.stop()
Explanation:
First, we import the necessary libraries and functions from PySpark.
We initialize a SparkSession, which is the entry point for any Spark functionality.
We load the retail dataset using the spark.read.csv method, specifying the path to the dataset, and setting `header=True` and `inferSchema=True` to read the header and infer the schema.
We create a DataFrame called `missing_values` by selecting each column and checking if the value is null, using the `isNull()` function.
We count the number of missing values in each column by aggregating the `missing_values` DataFrame and casting the boolean values to integers (True=1, False=0) before summing them.
We display the count of missing values for each column using the show() method on the `missing_count` DataFrame.
We filter out rows with missing values using the `dropna()` method on the retail_data DataFrame and store the result in a new DataFrame called `complete_data`.
We display the number of rows before and after filtering to show the impact of removing rows with missing values.
Finally, we stop the Spark session.
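Note that dropna() with no arguments removes a row if any column is null. If only certain fields are mandatory, a gentler variation restricts the check to those columns; a minimal sketch, assuming the four required columns from the example above:
# Drop rows only when a required column is null
required_columns = ["transaction_id", "product_id", "customer_id", "transaction_amount"]
complete_data = retail_data.dropna(subset=required_columns)
print(f"Rows with all required fields present: {complete_data.count()}")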
2. Accuracy Check
Example: In a retail dataset, we want to ensure that the transaction_amount column has accurate values, i.e., it should be greater than or equal to zero.
PySpark code:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
# Initialize Spark session
spark = SparkSession.builder \
.appName("Accuracy Check") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Check for inaccurate values in transaction_amount
inaccurate_values = retail_data.filter(col("transaction_amount") < 0)
# Display the count of inaccurate values
print(f"Number of inaccurate transaction_amount values: {inaccurate_values.count()}")
# Filter out rows with inaccurate values
accurate_data = retail_data.filter(col("transaction_amount") >= 0)
# Display the number of rows before and after filtering
print(f"Number of rows before filtering: {retail_data.count()}")
print(f"Number of rows after filtering: {accurate_data.count()}")
# Stop Spark session
spark.stop()
Explanation:
We create a DataFrame called inaccurate_values by filtering the retail_data DataFrame using the `filter()` method, and specifying a condition that checks if the transaction_amount column has values less than zero.
We display the count of inaccurate values by calling the count() method on the `inaccurate_values` DataFrame.
We create a new DataFrame called `accurate_data` by filtering the retail_data DataFrame again, this time using a condition that checks if the transaction_amount column has values greater than or equal to zero.
We display the number of rows before and after filtering to show the impact of removing rows with inaccurate values.
Finally, we stop the Spark session.
3. Consistency Check
Example: In a retail dataset, we want to ensure consistency between two columns: transaction_id and transaction_date. If a transaction_id is duplicated, the corresponding transaction_date values should be the same.
PySpark code:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, countDistinct, first
# Initialize Spark session
spark = SparkSession.builder \
.appName("Consistency Check") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Group by transaction_id and aggregate the distinct date count and first occurrence of transaction_date
grouped_data = retail_data.groupBy("transaction_id") \
    .agg(countDistinct("transaction_date").alias("date_count"),
         first("transaction_date").alias("first_date"))
# Check for inconsistency: transaction_id with more than one distinct transaction_date
inconsistent_data = grouped_data.filter(col("date_count") > 1)
# Display the count of inconsistent rows
print(f"Number of inconsistent transaction_id values: {inconsistent_data.count()}")
# Filter out rows with inconsistent data by joining with original dataset
consistent_data = retail_data.join(inconsistent_data, on="transaction_id", how="left_anti")
# Display the number of rows before and after filtering
print(f"Number of rows before filtering: {retail_data.count()}")
print(f"Number of rows after filtering: {consistent_data.count()}")
# Stop Spark session
spark.stop()
Explanation:
We group the retail_data DataFrame by transaction_id and aggregate the number of distinct transaction_date values and the first occurrence of transaction_date for each group using the `groupBy()` and `agg()` methods.
We create a DataFrame called inconsistent_data by filtering the grouped_data DataFrame with a condition that checks if a transaction_id has more than one distinct transaction_date value.
We display the count of inconsistent rows by calling the `count()` method on the inconsistent_data DataFrame.
We create a new DataFrame called consistent_data by performing a left anti join between the retail_data DataFrame and the inconsistent_data DataFrame on the transaction_id column. This operation filters out rows with inconsistent data.
We display the number of rows before and after filtering to show the impact of removing rows with inconsistent data.
Finally, we stop the Spark session.
4. Validity Check
Example: In a retail dataset, we want to ensure that the email addresses in the customer_email column are in a valid format.
PySpark code:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
import re
# Initialize Spark session
spark = SparkSession.builder \
.appName("Validity Check") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Define a function to check if an email address is valid
def is_valid_email(email):
    if email is None:
        return False
    email_regex = r'^[\w\.-]+@[\w\.-]+\.\w+$'
    return bool(re.match(email_regex, email))
# Register the function as a UDF (User-Defined Function)
from pyspark.sql.types import BooleanType
from pyspark.sql.functions import udf
is_valid_email_udf = udf(is_valid_email, BooleanType())
# Check for invalid email addresses
invalid_emails = retail_data.filter(~is_valid_email_udf(col("customer_email")))
# Display the count of invalid email addresses
print(f"Number of invalid email addresses: {invalid_emails.count()}")
# Filter out rows with invalid email addresses
valid_data = retail_data.filter(is_valid_email_udf(col("customer_email")))
# Display the number of rows before and after filtering
print(f"Number of rows before filtering: {retail_data.count()}")
print(f"Number of rows after filtering: {valid_data.count()}")
# Stop Spark session
spark.stop()
Explanation:
We define a function is_valid_email that takes an email address as input and returns True if the email address is valid, and False otherwise (null values are treated as invalid). We use the re.match() function with a regular expression to validate the email format.
We register the is_valid_email function as a UDF (User-Defined Function) in PySpark, specifying the return data type as BooleanType.
We create a DataFrame called invalid_emails by filtering the retail_data DataFrame using the `filter()` method and the `is_valid_email_udf()` function to check for invalid email addresses.
We display the count of invalid email addresses by calling the count() method on the invalid_emails DataFrame.
We create a new DataFrame called valid_data by filtering the retail_data DataFrame again, this time using the `is_valid_email_udf()` function to check for valid email addresses.
We display the number of rows before and after filtering to show the impact of removing rows with invalid email addresses.
Finally, we stop the Spark session.
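Because Python UDFs carry serialization overhead, the same validity check can also be expressed with Spark's built-in rlike() column method; a minimal sketch, assuming the same customer_email column and regex as above:
# Built-in regex match instead of a UDF; rows with a null email are counted as invalid
email_regex = r'^[\w\.-]+@[\w\.-]+\.\w+$'
valid_data = retail_data.filter(col("customer_email").rlike(email_regex))
invalid_emails = retail_data.filter(col("customer_email").isNull() | ~col("customer_email").rlike(email_regex))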
5. Timeliness Check
Example: In a retail dataset, we want to ensure that the transaction_date column only contains records from the current year.
PySpark code:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, year
from datetime import datetime
# Initialize Spark session
spark = SparkSession.builder \
.appName("Timeliness Check") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Get the current year
current_year = datetime.now().year
# Check for outdated transaction_date values
outdated_transactions = retail_data.filter(year(col("transaction_date")) != current_year)
# Display the count of outdated transactions
print(f"Number of outdated transactions: {outdated_transactions.count()}")
# Filter out rows with outdated transaction_date values
timely_data = retail_data.filter(year(col("transaction_date")) == current_year)
# Display the number of rows before and after filtering
print(f"Number of rows before filtering: {retail_data.count()}")
print(f"Number of rows after filtering: {timely_data.count()}")
# Stop Spark session
spark.stop()
Explanation:
We get the current year using `datetime.now().year` from the datetime module.
We create a DataFrame called `outdated_transactions` by filtering the retail_data DataFrame using the filter() method and a condition that checks if the year of the transaction_date column is not equal to the current year.
We display the count of outdated transactions by calling the count() method on the `outdated_transactions` DataFrame.
We create a new DataFrame called `timely_data` by filtering the retail_data DataFrame again, this time using a condition that checks if the year of the transaction_date column is equal to the current year.
We display the number of rows before and after filtering to show the impact of removing rows with outdated transaction_date values.
Finally, we stop the Spark session.
6. Uniqueness Check
Example: In a retail dataset, we want to ensure that the transaction_id column contains unique values and no duplicates.
PySpark code:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count
# Initialize Spark session
spark = SparkSession.builder \
.appName("Uniqueness Check") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Group by transaction_id and count the occurrences
grouped_data = retail_data.groupBy("transaction_id").agg(count("*").alias("count"))
# Check for duplicate transaction_id values
duplicate_transactions = grouped_data.filter(col("count") > 1)
# Display the count of duplicate transactions
print(f"Number of duplicate transactions: {duplicate_transactions.count()}")
# Filter out rows with duplicate transaction_id values by joining with original dataset
unique_data = retail_data.join(duplicate_transactions, on="transaction_id", how="left_anti")
# Display the number of rows before and after filtering
print(f"Number of rows before filtering: {retail_data.count()}")
print(f"Number of rows after filtering: {unique_data.count()}")
# Stop Spark session
spark.stop()
Explanation:
We group the retail_data DataFrame by transaction_id and count the occurrences of each transaction_id using the `groupBy()` and `agg()` methods.
We create a DataFrame called `duplicate_transactions` by filtering the grouped_data DataFrame using the filter() method and a condition that checks if the count of transaction_id occurrences is greater than 1.
We display the count of duplicate transactions by calling the count() method on the `duplicate_transactions` DataFrame.
We create a new DataFrame called `unique_data` by performing a left anti join between the retail_data DataFrame and the duplicate_transactions DataFrame on the transaction_id column. This operation filters out rows with duplicate transaction_id values.
We display the number of rows before and after filtering to show the impact of removing rows with duplicate `transaction_id` values.
Finally, we stop the Spark session.
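If you only need a pass/fail signal rather than the offending rows, comparing the total row count with the distinct key count is a cheaper first test; a minimal sketch on the same DataFrame:
# Quick uniqueness assertion: row count vs. distinct transaction_id count
total_rows = retail_data.count()
distinct_ids = retail_data.select("transaction_id").distinct().count()
print(f"transaction_id values are unique: {total_rows == distinct_ids}")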
7. Range Check
Example: In a retail dataset, we want to ensure that the prices in the product_price column fall within a specific range, such as between $1 and $10,000.
PySpark code:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
# Initialize Spark session
spark = SparkSession.builder \
.appName("Range Check") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Define the range for product_price
min_price = 1
max_price = 10000
# Check for product_price values outside the defined range
out_of_range_prices = retail_data.filter((col("product_price") < min_price) | (col("product_price") > max_price))
# Display the count of out of range product_price values
print(f"Number of out of range product prices: {out_of_range_prices.count()}")
# Filter out rows with product_price values outside the defined range
in_range_data = retail_data.filter((col("product_price") >= min_price) & (col("product_price") <= max_price))
# Display the number of rows before and after filtering
print(f"Number of rows before filtering: {retail_data.count()}")
print(f"Number of rows after filtering: {in_range_data.count()}")
# Stop Spark session
spark.stop()
Explanation:
We define the range for the `product_price` column by setting the `min_price` and `max_price` variables.
We create a DataFrame called `out_of_range_prices` by filtering the retail_data DataFrame using the `filter()` method and a condition that checks if the product_price column values are outside the defined range.
We display the count of out of range product prices by calling the count() method on the `out_of_range_prices` DataFrame.
We create a new DataFrame called `in_range_data` by filtering the retail_data DataFrame again, this time using a condition that checks if the product_price column values are within the defined range.
We display the number of rows before and after filtering to show the impact of removing rows with out of range product prices.
Finally, we stop the Spark session.
8. Format Check
Example: In a retail dataset, we want to ensure that the dates in the transaction_date column are in a specific format, such as "yyyy-MM-dd".
PySpark code:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, to_date
from pyspark.sql.types import StringType
# Initialize Spark session
spark = SparkSession.builder \
.appName("Format Check") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Define the date format
date_format = "yyyy-MM-dd"# Check if the transaction_date column is in the correct format by converting it to the expected format and comparing the result
retail_data = retail_data.withColumn("formatted_date", to_date(col("transaction_date"), date_format).cast(StringType()))
incorrect_format = retail_data.filter(col("formatted_date").isNull() | (col("transaction_date") != col("formatted_date")))
# Display the count of transaction_date values with incorrect format
print(f"Number of transaction dates with incorrect format: {incorrect_format.count()}")
# Filter out rows with incorrect date format
correct_format_data = retail_data.filter(col("transaction_date") == col("formatted_date"))
# Display the number of rows before and after filtering
print(f"Number of rows before filtering: {retail_data.count()}")
print(f"Number of rows after filtering: {correct_format_data.count()}")
# Stop Spark session
spark.stop()
Explanation:
We define the expected date format by setting the date_format variable.
We create a new column called "formatted_date" in the retail_data DataFrame by converting the transaction_date column to the expected date format using the `to_date()` function and casting the result to StringType.
We create a DataFrame called `incorrect_format` by filtering the retail_data DataFrame using the `filter()` method and a condition that flags rows where formatted_date is null (the value could not be parsed with the expected format) or differs from the transaction_date value.
We display the count of transaction_date values with incorrect format by calling the `count()` method on the incorrect_format DataFrame.
We create a new DataFrame called correct_format_data by filtering the retail_data DataFrame again, this time using a condition that checks if the transaction_date column values are equal to the formatted_date column values.
We display the number of rows before and after filtering to show the impact of removing rows with incorrect date format.
Finally, we stop the Spark session.
9. Statistical Check
Example: In a retail dataset, we want to check if the product_price column follows a normal distribution.
PySpark code:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, skewness, kurtosis
import scipy.stats as stats
# Initialize Spark session
spark = SparkSession.builder \
.appName("Statistical Check") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Calculate skewness and kurtosis for product_price
price_summary = retail_data.select(skewness("product_price").alias("skewness"), kurtosis("product_price").alias("kurtosis")).collect()
price_skewness = price_summary[0]["skewness"]
price_kurtosis = price_summary[0]["kurtosis"]
# Display skewness and kurtosis
print(f"Skewness of product prices: {price_skewness}")
print(f"Kurtosis of product prices: {price_kurtosis}")
# Check if product_price follows a normal distribution using skewness and kurtosis
alpha = 0.05
product_prices = retail_data.select("product_price").dropna().rdd.flatMap(lambda row: row).collect()
skewness_test_statistic, skewness_p_value = stats.skewtest(product_prices)
kurtosis_test_statistic, kurtosis_p_value = stats.kurtosistest(product_prices)
is_normal_distribution = skewness_p_value > alpha and kurtosis_p_value > alpha
print(f"Does product_price follow a normal distribution? {is_normal_distribution}")
# Stop Spark session
spark.stop()
Explanation:
We calculate the skewness and kurtosis of the product_price column using the skewness() and kurtosis() functions from PySpark, and collect the results in a list called `price_summary`.
We extract the skewness and kurtosis values from the `price_summary` list and store them in the `price_skewness` and `price_kurtosis` variables.
We display the skewness and kurtosis values to give an idea of the distribution of the product_price column.
We use the scipy library's `skewtest()` and `kurtosistest()` functions to test if the product_price column follows a normal distribution. We pass the product_price values to these functions by dropping nulls, converting the DataFrame column to an RDD, and collecting the values into a Python list.
We check if the product_price follows a normal distribution by comparing the p-values from the skewness and kurtosis tests with the significance level (alpha), which we set to 0.05.
We display the result of the normal distribution check.
Finally, we stop the Spark session.
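Collecting the whole column to the driver for scipy can be impractical on large tables. One workaround is to run the tests on a sample, reusing the stats module and alpha from the example above; a minimal sketch, assuming a 10% sample (an arbitrary fraction chosen for illustration) is representative enough:
# Run the normality tests on a sample instead of the full column
sampled_prices = (retail_data.select("product_price").dropna()
                  .sample(fraction=0.1, seed=42)
                  .rdd.flatMap(lambda row: row).collect())
_, sample_skew_p = stats.skewtest(sampled_prices)
_, sample_kurt_p = stats.kurtosistest(sampled_prices)
print(f"Sampled normality check passed: {sample_skew_p > alpha and sample_kurt_p > alpha}")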
10. Referential Integrity Check
Example: In a retail dataset with two tables - one containing product information and the other containing sales transactions - we want to ensure that all product IDs in the sales table have a corresponding entry in the product table.
PySpark code:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
# Initialize Spark session
spark = SparkSession.builder \
.appName("Referential Integrity Check") \
.getOrCreate()
# Load product and sales datasets
product_data = spark.read.csv("path/to/product_data.csv", header=True, inferSchema=True)
sales_data = spark.read.csv("path/to/sales_data.csv", header=True, inferSchema=True)
# Check referential integrity between product_data and sales_data
missing_products = sales_data.join(product_data, sales_data["product_id"] == product_data["id"], "left_anti")
# Display the count of sales records with missing product data
print(f"Number of sales records with missing product data: {missing_products.count()}")
# Filter out sales records with missing product data
valid_sales_data = sales_data.join(product_data, sales_data["product_id"] == product_data["id"], "inner")
# Display the number of rows before and after filtering
print(f"Number of sales records before filtering: {sales_data.count()}")
print(f"Number of sales records after filtering: {valid_sales_data.count()}")
# Stop Spark session
spark.stop()
Explanation:
We perform a referential integrity check between the product_data and sales_data DataFrames by joining them using the "left_anti" join type. This type of join returns only the rows from the sales_data DataFrame that do not have a corresponding match in the product_data DataFrame based on the product_id column.
We store the result of the join in a DataFrame called `missing_products`.
We display the count of sales records with missing product data by calling the count() method on the `missing_products` DataFrame.
We create a new DataFrame called `valid_sales_data` by filtering the sales_data DataFrame using an "inner" join with the product_data DataFrame. This join type returns only the rows with matching product IDs in both DataFrames.
We display the number of sales records before and after filtering to show the impact of removing rows with missing product data.
Finally, we stop the Spark session.
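If the product table is much smaller than the sales table, broadcasting just its key column keeps the anti join from shuffling the large side; a minimal sketch under that size assumption:
from pyspark.sql.functions import broadcast
# Broadcast only the distinct product ids so the large sales table is not shuffled
product_ids = product_data.select(col("id").alias("product_id")).distinct()
missing_products = sales_data.join(broadcast(product_ids), on="product_id", how="left_anti")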
11. Duplicated Columns Check
Example: In a retail dataset, we want to ensure that there are no duplicated columns with the same name and content.
PySpark code:
from pyspark.sql import SparkSession
# Initialize Spark session
spark = SparkSession.builder \
.appName("Duplicated Columns Check") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Get column names
column_names = retail_data.columns
# Find duplicated columns
duplicated_columns = []
for i in range(len(column_names)):
for j in range(i + 1, len(column_names)):
column_i = retail_data.select(column_names[i])
column_j = retail_data.select(column_names[j])
if column_i.distinct().count() == column_j.distinct().count() and \
column_i.distinct().intersect(column_j.distinct()).count() == column_i.distinct().count():
duplicated_columns.append((column_names[i], column_names[j]))
# Display duplicated columns
print(f"Duplicated columns: {duplicated_columns}")
# Stop Spark session
spark.stop()
Explanation:
We get the column names of the retail_data DataFrame using the columns attribute.
We initialize an empty list called `duplicated_columns` to store the names of duplicated columns.
We iterate through all pairs of columns using two nested loops. For each pair of columns, we select the columns separately using the `select()` method and store them in the `column_i` and `column_j` variables.
We compare the distinct values of both columns by counting their distinct values and checking if the intersection of their distinct values has the same count as their individual distinct values. If both conditions hold, we treat the columns as duplicates and append their names as a tuple to the duplicated_columns list. Note that this compares sets of distinct values rather than row-by-row content; a stricter row-level check is sketched after this explanation.
We display the duplicated columns by printing the duplicated_columns list.
Finally, we stop the Spark session.
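The distinct-and-intersect comparison launches several jobs per column pair and only compares value sets. A stricter row-by-row variant can be written with a null-safe equality filter; a minimal sketch using a hypothetical columns_identical helper, assuming the compared columns have comparable types:
from pyspark.sql import functions as F
# Two columns are treated as duplicates only if they agree on every row (nulls compare equal)
def columns_identical(df, col_a, col_b):
    return df.filter(~F.col(col_a).eqNullSafe(F.col(col_b))).count() == 0

duplicated_columns = [(a, b)
                      for i, a in enumerate(column_names)
                      for b in column_names[i + 1:]
                      if columns_identical(retail_data, a, b)]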
12. Domain-Specific Checks
Example: In a retail dataset, we want to ensure that all product categories follow a predefined set of allowed categories.
PySpark code:
from pyspark.sql import SparkSession
# Initialize Spark session
spark = SparkSession.builder \
.appName("Domain-Specific Checks") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Define allowed product categories
allowed_categories = {"Electronics", "Clothing", "Home & Kitchen", "Sports & Outdoors", "Health & Beauty"}
# Check if all product categories are in the allowed categories set
invalid_categories = retail_data.select("category").distinct().filter(~retail_data["category"].isin(allowed_categories))
# Display invalid categories
print("Invalid categories:")
invalid_categories.show()
# Stop Spark session
spark.stop()
Explanation:
We define the allowed product categories in a set called `allowed_categories`.
We perform a domain-specific check by selecting the distinct values of the "category" column and filtering them using the `filter()` method and the `isin()` function. We use the `~` operator to negate the condition, so we only keep the categories that are not in the `allowed_categories` set.
We store the result in a DataFrame called `invalid_categories`.
We display the invalid categories by calling the `show()` method on the `invalid_categories` DataFrame.
Finally, we stop the Spark session.
13. Outlier Check
Example: In a retail dataset, we want to identify and filter out sales records with outlier sales amounts, which may skew our analysis.
PySpark code:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, mean, stddev_pop
# Initialize Spark session
spark = SparkSession.builder \
.appName("Outlier Check") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Calculate mean and standard deviation of sales amounts
stats = retail_data.agg(mean("sales_amount").alias("mean_sales_amount"),
                        stddev_pop("sales_amount").alias("stddev_sales_amount")).collect()[0]
mean_sales_amount = stats["mean_sales_amount"]
stddev_sales_amount = stats["stddev_sales_amount"]
# Define outlier threshold (e.g., 3 standard deviations from the mean)
outlier_threshold = 3 * stddev_sales_amount
# Identify and filter out sales records with outlier sales amounts
outliers = retail_data.filter((col("sales_amount") < mean_sales_amount - outlier_threshold) |
                              (col("sales_amount") > mean_sales_amount + outlier_threshold))
filtered_data = retail_data.filter((col("sales_amount") >= mean_sales_amount - outlier_threshold) &
                                   (col("sales_amount") <= mean_sales_amount + outlier_threshold))
# Display outliers
print("Outliers:")
outliers.show()
# Display the number of rows before and after filtering
print(f"Number of sales records before filtering: {retail_data.count()}")
print(f"Number of sales records after filtering: {filtered_data.count()}")
# Stop Spark session
spark.stop()
Explanation:
We calculate the mean and standard deviation of the sales_amount column using the `agg()` method and store the result in a `stats` variable. We then extract the mean and standard deviation values from the `stats` variable.
We define an outlier threshold, which is a multiple of the standard deviation (e.g., 3 standard deviations) away from the mean sales amount.
We identify and filter out sales records with outlier sales amounts using the filter() method and the `col()` function. We create two DataFrames: outliers containing the outlier records and `filtered_data` containing the remaining records.
We display the outliers by calling the `show()` method on the `outliers` DataFrame.
We display the number of sales records before and after filtering to show the impact of removing outlier rows.
Finally, we stop the Spark session.
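The mean and standard deviation are themselves inflated by the extreme values they are meant to detect, so an interquartile-range (IQR) rule is a common alternative; a minimal sketch using approxQuantile with the conventional 1.5 multiplier:
# IQR-based outlier bounds computed from approximate quartiles (relative error 0.01)
q1, q3 = retail_data.approxQuantile("sales_amount", [0.25, 0.75], 0.01)
iqr = q3 - q1
lower_bound, upper_bound = q1 - 1.5 * iqr, q3 + 1.5 * iqr
outliers = retail_data.filter((col("sales_amount") < lower_bound) | (col("sales_amount") > upper_bound))
filtered_data = retail_data.filter(col("sales_amount").between(lower_bound, upper_bound))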
14. Data Type Check
Example: In a retail dataset, we want to ensure that the data types of columns are correct and consistent across the dataset.
PySpark code:
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, DoubleType, IntegerType, TimestampType
# Initialize Spark session
spark = SparkSession.builder \
.appName("Data Type Check") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Define expected data types for columns
expected_data_types = {
"transaction_id": StringType(),
"customer_id": StringType(),
"product_id": StringType(),
"quantity": IntegerType(),
"sales_amount": DoubleType(),
"transaction_date": TimestampType()
}
# Check data types and collect incorrect data types
incorrect_data_types = {}
for column_name, expected_data_type in expected_data_types.items():
actual_data_type = retail_data.schema[column_name].dataType
if actual_data_type != expected_data_type:
incorrect_data_types[column_name] = (actual_data_type, expected_data_type)
# Display incorrect data types
print("Incorrect data types:")
for column_name, (actual_data_type, expected_data_type) in incorrect_data_types.items():
print(f"{column_name}: actual={actual_data_type}, expected={expected_data_type}")
# Stop Spark session
spark.stop()
Explanation:
We define the expected data types for each column in the dataset using a dictionary called `expected_data_types`.
We iterate through the `expected_data_types` dictionary and check the actual data type of each column against the expected data type. If the actual data type is different from the expected data type, we store the column name and the actual and expected data types in a dictionary called `incorrect_data_types`.
We display the incorrect data types by iterating through the `incorrect_data_types` dictionary and printing the column name, actual data type, and expected data type.
Finally, we stop the Spark session.
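Instead of inferring the schema and checking it after the fact, you can enforce the expected types at read time; a minimal sketch that builds an explicit schema from the expected_data_types mapping, assuming the columns appear in the CSV in this order:
from pyspark.sql.types import StructType, StructField
# Enforce the expected types while reading; FAILFAST aborts on rows that do not parse
explicit_schema = StructType([StructField(name, dtype, nullable=True)
                              for name, dtype in expected_data_types.items()])
retail_data = spark.read.csv("path/to/retail_data.csv", header=True,
                             schema=explicit_schema, mode="FAILFAST")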
15. Null Value Check
Example: In a retail dataset, we want to ensure that there are no null values in the critical columns, such as the transaction_id and sales_amount columns.
PySpark code:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, when
# Initialize Spark session
spark = SparkSession.builder \
.appName("Null Value Check") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Define critical columns for null value check
critical_columns = ["transaction_id", "sales_amount"]
# Check for null values in critical columns and collect the counts
null_counts = {}
for column_name in critical_columns:
null_count = retail_data.where(col(column_name).isNull()).count()
null_counts[column_name] = null_count
# Display null counts for critical columns
print("Null value counts:")
for column_name, null_count in null_counts.items():
print(f"{column_name}: {null_count}")
# Stop Spark session
spark.stop()
Explanation:
We define the critical columns for the null value check in a list called `critical_columns`.
We iterate through the `critical_columns` list and use the `where()` method and the `isNull()` function to filter rows with null values in the specified column. We then count the number of rows with null values and store the count in a dictionary called `null_counts`.
We display the null counts for each critical column by iterating through the `null_counts` dictionary and printing the column name and the null count.
Finally, we stop the Spark session.
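Looping over columns triggers one Spark job per column; the same counts can be gathered in a single pass with a conditional aggregation, using the count and when functions already imported above. A minimal sketch:
# Single-pass null counts for all critical columns
null_counts_row = retail_data.select(
    *[count(when(col(c).isNull(), c)).alias(c) for c in critical_columns]
).collect()[0]
for column_name in critical_columns:
    print(f"{column_name}: {null_counts_row[column_name]}")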
16. Duplication Check
Example: In a retail dataset, we want to ensure that there are no duplicate records or rows, based on the transaction_id column.
PySpark code:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
# Initialize Spark session
spark = SparkSession.builder \
.appName("Duplication Check") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Find and display duplicate records based on transaction_id
duplicates = retail_data.groupBy("transaction_id").count().where(col("count") > 1)
print("Duplicate records:")
duplicates.show()
# Remove duplicate records, keeping one record per transaction_id
unique_data = retail_data.dropDuplicates(subset=["transaction_id"])
# Display the number of rows before and after removing duplicates
print(f"Number of records before removing duplicates: {retail_data.count()}")
print(f"Number of records after removing duplicates: {unique_data.count()}")
# Stop Spark session
spark.stop()
Explanation:
We find duplicate records based on the "transaction_id" column by using the `groupBy()` method and the `count()` function. We filter the groups with more than one record using the `where()` method and the `col()` function.
We store the result in a DataFrame called duplicates and display the duplicate records by calling the `show()` method on the duplicates DataFrame.
We remove duplicate records using the `dropDuplicates()` method, specifying the subset of columns to consider for duplicate detection (in this case, "transaction_id"). Note that `dropDuplicates()` keeps one arbitrary record per group rather than a guaranteed "first" record; a deterministic variant is sketched below.
We store the result in a DataFrame called `unique_data` and display the number of rows before and after removing duplicates to show the impact of the deduplication process.
Finally, we stop the Spark session.
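If you need control over which duplicate survives (for example, the earliest transaction), a window function makes the choice explicit; a minimal sketch, assuming transaction_date is the tie-breaker:
from pyspark.sql import Window
from pyspark.sql.functions import row_number
# Keep the earliest record per transaction_id instead of an arbitrary one
window_spec = Window.partitionBy("transaction_id").orderBy(col("transaction_date").asc())
unique_data = (retail_data
               .withColumn("row_num", row_number().over(window_spec))
               .filter(col("row_num") == 1)
               .drop("row_num"))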
17. Redundancy Check
Example: In a retail dataset, we want to ensure that there are no redundant or unnecessary columns, such as duplicate columns or columns derived from other columns in the dataset.
PySpark code:
from pyspark.sql import SparkSession
# Initialize Spark session
spark = SparkSession.builder \
.appName("Redundancy Check") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Define redundant columns
redundant_columns = ["sales_amount_in_cents"]
# Remove redundant columns from the dataset
cleaned_data = retail_data.drop(*redundant_columns)
# Display the schema before and after removing redundant columns
print("Schema before removing redundant columns:")
retail_data.printSchema()
print("Schema after removing redundant columns:")
cleaned_data.printSchema()
# Stop Spark session
spark.stop()
Explanation:
We import the necessary libraries from PySpark and initialize a SparkSession.
We load the retail dataset using the spark.read.csv method, specifying the path to the dataset, and setting header=True and inferSchema=True to read the header and infer the schema.
We define the redundant columns in a list called `redundant_columns`. In this example, we assume that the "sales_amount_in_cents" column is redundant, as it can be derived from the "sales_amount" column.
We remove the redundant columns from the dataset using the `drop()` method and passing the list of redundant columns, unpacked using the * operator.
We store the result in a DataFrame called `cleaned_data` and display the schema before and after removing the redundant columns by calling the `printSchema()` method on the `retail_data` and `cleaned_data` DataFrames.
Finally, we stop the Spark session.
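Before dropping a column as redundant, it is worth confirming that it really is derivable; a minimal sketch that first checks the assumed sales_amount_in_cents = sales_amount * 100 relationship (the *100 factor is an assumption for this example):
from pyspark.sql.functions import col, round as spark_round
# Only drop the column if the assumed derivation holds on every row
mismatches = retail_data.filter(
    spark_round(col("sales_amount") * 100, 0) != col("sales_amount_in_cents")
).count()
if mismatches == 0:
    cleaned_data = retail_data.drop("sales_amount_in_cents")
else:
    print(f"Not dropped: {mismatches} rows violate the assumed derivation")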
18. Integrity Check
Example: In a retail dataset, we want to ensure that the data is not tampered with or corrupted by comparing the checksum of the dataset with the expected checksum.
PySpark code:
import hashlib
from pyspark.sql import SparkSession
def file_checksum(file_path, algorithm="sha256"):
hasher = hashlib.new(algorithm)
with open(file_path, "rb") as file:
for chunk in iter(lambda: file.read(4096), b""):
hasher.update(chunk)
return hasher.hexdigest()
# Initialize Spark session
spark = SparkSession.builder \
.appName("Integrity Check") \
.getOrCreate()
# Define the path to the retail dataset and the expected checksum
retail_data_path = "path/to/retail_data.csv"
expected_checksum = "your_expected_checksum_here"# Calculate the actual checksum of the dataset
actual_checksum = file_checksum(retail_data_path)
# Compare the actual checksum with the expected checksum
if actual_checksum == expected_checksum:
print("Data integrity check passed.")
retail_data = spark.read.csv(retail_data_path, header=True, inferSchema=True)
else:
print("Data integrity check failed.")
print(f"Expected checksum: {expected_checksum}")
print(f"Actual checksum: {actual_checksum}")
# Stop Spark session
spark.stop()
Explanation:
We import the necessary libraries and functions from PySpark and hashlib, and define a helper function called `file_checksum` that takes a file path and a hashing algorithm as input and returns the file's checksum.
We initialize a SparkSession.
We define the path to the retail dataset and the expected checksum.
We calculate the actual checksum of the dataset using the `file_checksum` function.
We compare the actual checksum with the expected checksum. If they match, we proceed with loading the dataset using the `spark.read.csv` method, specifying the path to the dataset, and setting `header=True` and `inferSchema=True` to read the header and infer the schema. If the checksums do not match, we display an error message with the expected and actual checksums.
Finally, we stop the Spark session.
19. Precision Check
Example: In a retail dataset, we want to ensure that the data is precise and accurate to the desired level of detail. In this case, we will check that the decimal precision of the sales_amount column is correct (e.g., up to two decimal places).
PySpark code:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import BooleanType
from decimal import Decimal
# User-defined function to check the decimal precision of a value
def check_decimal_precision(value, precision):
    if value is None:
        return False
    # Count decimal places from the string representation to avoid floating-point artifacts
    return -Decimal(str(value)).as_tuple().exponent <= precision
# Register the user-defined function in PySpark
check_precision_udf = udf(check_decimal_precision, BooleanType())
# Initialize Spark session
spark = SparkSession.builder \
.appName("Precision Check") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Define the desired precision for the sales_amount column
desired_precision = 2
# Check the precision of the sales_amount column using the user-defined function
invalid_precision_data = retail_data.where(~check_precision_udf(col("sales_amount"), desired_precision))
# Display the records with invalid precision
print("Records with invalid precision:")
invalid_precision_data.show()
# Stop Spark session
spark.stop()
Explanation:
We import the necessary libraries and functions from PySpark, along with Decimal from Python's decimal module.
We define a user-defined function (UDF) called `check_decimal_precision` that takes a value and a desired precision as input and returns True if the value has the correct decimal precision, and False otherwise. The number of decimal places is read from the value's string representation via Decimal, which avoids floating-point artifacts.
We register the UDF in PySpark using the `udf()` function.
We initialize a SparkSession.
We load the retail dataset using the `spark.read.csv` method, specifying the path to the dataset, and setting `header=True` and `inferSchema=True` to read the header and infer the schema.
We define the desired precision for the sales_amount column (e.g., 2 decimal places).
We use the `where()` method and the registered UDF to filter rows with sales_amount values that do not have the correct decimal precision. We store the result in a DataFrame called `invalid_precision_data`.
We display the records with invalid precision by calling the `show()` method on the invalid_precision_data DataFrame.
Finally, we stop the Spark session.
20. Consistency with Business Rules
Example: In a retail dataset, we want to ensure that the data is consistent with the business rules and logic. For instance, we will check that the total sales amount in the dataset matches the sum of individual sales amounts.
PySpark code:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum
# Initialize Spark session
spark = SparkSession.builder \
.appName("Consistency with Business Rules") \
.getOrCreate()
# Load retail dataset
retail_data = spark.read.csv("path/to/retail_data.csv", header=True, inferSchema=True)
# Calculate the sum of individual sales amounts
sum_individual_sales = retail_data.select(sum(col("sales_amount"))).first()[0]
# Define the total sales amount in the dataset (from another source or calculation)
total_sales_amount = 123456.78
# Check if the sum of individual sales amounts matches the total sales amount
if round(sum_individual_sales, 2) == round(total_sales_amount, 2):
print("The data is consistent with business rules.")
else:
print("The data is not consistent with business rules.")
print(f"Sum of individual sales amounts: {sum_individual_sales:.2f}")
print(f"Total sales amount: {total_sales_amount:.2f}")
# Stop Spark session
spark.stop()
Explanation:
We calculate the sum of individual sales amounts by selecting the "sales_amount" column, applying the `sum()` function, and using `first()` to retrieve the first row of the result.
We define the total sales amount in the dataset (from another source or calculation).
We check if the sum of individual sales amounts matches the total sales amount. If they match, we print a message stating that the data is consistent with business rules. If they do not match, we display an error message with the sum of individual sales amounts and the total sales amount.
Finally, we stop the Spark session.
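Because floating-point aggregation order can make an exact comparison (even after rounding) brittle, a tolerance-based variant is often safer; a minimal sketch, with the one-cent tolerance being an assumption:
# Compare against a small tolerance instead of exact equality
tolerance = 0.01
is_consistent = abs(sum_individual_sales - total_sales_amount) <= tolerance
print(f"Consistent with business rules (within {tolerance}): {is_consistent}")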