Scalable Time-Series Forecasting in Python

Overview

Whether predicting daily demand for thousands of products or the number of workers to staff across many distribution centers, generating operational forecasts in parallel is a common task for data scientists. Accordingly, the goal of this post is to outline an approach for creating many forecasts via PySpark. We’ll cover some common data-cleaning steps that often precede forecasting, and then generate several thousand week-level demand predictions for a variety of consumer products. Note that we will not cover how to implement this workflow in a cloud computing environment (which is where it would typically run in a real-world setting), nor will we delve into model tuning or selection. The goal is to provide a straightforward workflow for quickly generating many time-series forecasts in parallel.

Getting Started

We’ll use data originally provided by Walmart that represents weekly demand for products at the store-department level. All code for this post is stored in the Codeforest Repository. Before diving into the details, let’s briefly review the key modules and files.

conf.json - A configuration file that defines various parameters for our job. It’s a good practice to keep these parameters outside of your actual code, as it makes it easier for others (or future you!) to adapt and extend the workflow to other use cases.

pyspark_fcast.py - Our main module, where the forecasting gets done. We’ll cover this in detail below.

fcast_data_frame.py - A class responsible for common pre-forecasting data transformations. These include filling in missing values, filtering out time-series with only a few observations, and log transforming our outcome variable. Visit here for access to all methods.

You’ll also need to import the following packages to follow along with the tutorial.

import argparse
import json
import logging
import os
import re
from datetime import datetime
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd

from fbprophet import Prophet # fbprophet==0.7.1 & pystan==2.18.0
from pyspark.sql import SparkSession # pyspark==3.0.1
from pyspark.sql.functions import lit
from pyspark.sql.types import (DateType, FloatType, IntegerType, StructField,
                               StructType)
                               
from pyspark_ts_fcast.fcast_data_frame import FcastDataFrame

Assuming the imports were successful, we’ll peek at a few rows of our data to get a feel for the format.

Sample Data

| store | dept | date       | weekly_sales |
|-------|------|------------|--------------|
| 1     | 1    | 2010-02-05 | 24924        |
| 1     | 1    | 2010-02-12 | 46039        |
| 1     | 1    | 2010-02-19 | 41596        |
| 1     | 1    | 2010-02-26 | 19404        |
| 1     | 1    | 2010-03-05 | 21828        |

Let’s now discuss the process of passing and documenting the forecasting parameters. We’ll execute the following from the command line to generate our forecasts:
python3 pyspark_fcast.py --forecast-config-file 'config/conf.json'
Here we are passing in the location of our configuration file and extracting the parameters. Don’t worry if the individual parameters don’t make sense now. I’ll explain each in greater detail below.

args = read_args()

with open(args.forecast_config_file) as f:
    config = json.load(f)

log_input_params(config=config)

# forecasting parameters
input_data_path = config["input_data_path"]
fcast_params = config["fcast_parameters"]
group_fields = fcast_params["group_fields"]
date_field = fcast_params["date_field"]
yvar_field = fcast_params["yvar_field"]
ts_frequency = fcast_params["ts_frequency"]
fcast_floor = fcast_params["forecast_floor"]
fcast_cap = fcast_params["forecast_cap"]
min_obs_threshold = fcast_params['min_obs_count']

# spark parameters
spark_n_threads = str(config['spark_n_threads'])
java_home = config["java_home"]
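
For reference, the structure that pyspark_fcast.py expects in conf.json mirrors the keys read above. The dictionary below is an illustrative sketch of that structure; the values (paths, thresholds, Java location, app name) are assumptions for demonstration, not the exact settings used in this post.

# illustrative config structure; values are placeholders, not the post's exact settings
example_config = {
    "app_name": "pyspark_forecasting",
    "input_data_path": "data/sales_data.csv",
    "java_home": "/usr/lib/jvm/java-8-openjdk-amd64",
    "spark_n_threads": 4,
    "fcast_parameters": {
        "group_fields": ["store", "dept"],
        "date_field": "date",
        "yvar_field": "weekly_sales",
        "ts_frequency": "W-FRI",
        "forecast_floor": 0.5,
        "forecast_cap": 1.5,
        "min_obs_count": 25,
    },
}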

Note the two helper functions: read_args and log_input_params.

def read_args() -> argparse.Namespace:
    """Read Forecasting arguments 

    Returns:
        argparse.Namespace: argparse Namespace
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--forecast-config-file", type=str)
    return parser.parse_args()

read_args reads the location of our configuration file from the command line, and log_input_params documents which parameters we’re using.

logging.basicConfig(
    format="%(levelname)s - %(asctime)s - %(filename)s - %(message)s",
    level=logging.INFO,
    filename="run_{start_time}.log".format(
        start_time=datetime.now().strftime("%Y-%m-%d %H-%M-%S")
    ),
)

def log_input_params(config: dict) -> None:
    """Logs all parameters in configuration file

    Args:
        config (dict): Configuration parameters for forecast and data
    """
    params = pd.json_normalize(config).transpose()
    for name, value in zip(params.index, params.iloc[:, 0]):
        logging.info("input params:" + name + "-" + str(value))

There are several benefits to documenting our inputs. First, we can validate that the correct parameters have been passed to our forecasting process, and having a record of these values facilitates debugging. Second, it is useful for experimentation: we can try out different parameters to see which combination provides the best results. Logging does not receive a lot of attention in the data science world, but it is incredibly useful and will save you time as your project matures.
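
With the format defined above, each configuration parameter lands in the run log on its own line. The entries look roughly like the following (the timestamp and values shown here are illustrative):

INFO - 2021-04-05 09:15:02,113 - pyspark_fcast.py - input params:input_data_path-data/sales_data.csv
INFO - 2021-04-05 09:15:02,114 - pyspark_fcast.py - input params:fcast_parameters.ts_frequency-W-FRI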

We have our parameters and have set up logging. Next, we’ll read in the data stored here and execute some basic field formatting with clean_names.

def clean_names(df: pd.DataFrame) -> pd.DataFrame:
    """Applies the following transformations to column names:
        - Converts camel case to snake case
        - Replaces any double underscore with a single underscore
        - Removes any whitespace within a name
        - Replaces periods with underscores

    Args:
        df (pd.DataFrame): Dataframe with untransformed column names

    Returns:
        pd.DataFrame: Dataframe with transformed column names
    """
    cols = df.columns
    cols = [re.sub(r"(?<!^)(?=[A-Z])", "_", x).lower() for x in cols]
    cols = [re.sub(r"_{2,}", "_", x) for x in cols]
    cols = [re.sub(r"\s", "", x) for x in cols]
    cols = [re.sub(r"\.", "_", x) for x in cols]
    df.columns = cols
    return df
sales_df = pd.read_csv(input_data_path)
sales_df = clean_names(sales_df)

If you don’t have a clean_names-type function as part of your codebase, I’d highly recommend creating one. It’s a function I use frequently when reading data from various sources, and it encourages a standardized way of formatting field names.
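
As a quick illustration, here is what clean_names would do to a hypothetical set of inconsistently formatted column names (the names below are made up for demonstration):

# hypothetical column names, purely for illustration
messy_df = pd.DataFrame(columns=["Weekly_Sales", "Store Id", "dept.name"])
print(clean_names(messy_df).columns.tolist())
# ['weekly_sales', 'store_id', 'dept_name']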

Now that we have our data, we’ll do some pre-forecasting data cleaning. The main steps are outlined below:

Filter groups with limited observations - It’s a good idea to generate predictions only for items where you have some historical data. While cold-start forecasting is an interesting space, it is outside the scope of this post. Thus, we put a minimum threshold on the number of data points per group. This also protects the job itself, as some forecasting algorithms will refuse to fit a model on only a handful of observations, causing your program to crash.

Replace negative values with zero - I’m assuming a negative value represents a returned product. Our goal is to forecast demand, not demand minus returns. This is an assumption that would need to be validated with domain knowledge.

Pad missing values - Accounting for missing data is an easy step to overlook for one simple reason: Missing values in time-series data are not usually flagged as “missing”. For example, a store may shut down for six weeks of renovations. As a result, there will be a series of dates that have no sales data. Identifying these gaps is pivotal for generating reliable forecasts. I’ve provided a brief example below to illustrate what this looks like from a data perspective.

Incomplete Data

| store | dept | date       | weekly_sales |
|-------|------|------------|--------------|
| 1     | 1    | 2010-02-05 | 24924        |
| 1     | 1    | 2010-02-19 | 41596        |
| 1     | 1    | 2010-02-26 | 19404        |
| 1     | 1    | 2010-03-19 | 22137        |
| 1     | 1    | 2010-03-26 | 26229        |

Padded Data

| store | dept | date       | weekly_sales |
|-------|------|------------|--------------|
| 1     | 1    | 2010-02-05 | 24924        |
| NA    | NA   | 2010-02-12 | NA           |
| 1     | 1    | 2010-02-19 | 41596        |
| 1     | 1    | 2010-02-26 | 19404        |
| NA    | NA   | 2010-03-05 | NA           |
| NA    | NA   | 2010-03-12 | NA           |
| 1     | 1    | 2010-03-19 | 22137        |
| 1     | 1    | 2010-03-26 | 26229        |

We’ll go back and fill or “interpolate” those missing values in the weekly_sales field in a later step.

Filter groups with long ‘streaks’ of missing observations - Building on the previous example, let’s say the store closes for six months instead of six weeks, so half of the year will not have any sales information. We could fill those gaps with a reasonable value, such as the average, but this won’t capture the overall trend, seasonality, or potential holiday/event effects that help explain variation in our outcome variable. I’ll often exclude these time-series initially, and then try to understand why such long streaks of values are missing. In this case, we’ll set a limit of four weeks (i.e., if any time-series has more than four consecutive dates missing, we exclude it from the final forecasting step).

Interpolate missing values - Fills in missing data with “reasonable” values. We’ll use the overall mean of each series, a simple and easy-to-understand technique. There are better approaches that account for seasonality or local trends. However, the goal here isn’t to generate the best forecast but instead to create a good starting point from which to iterate. (A simplified sketch of the padding, interpolation, and bounding steps appears after this list.)

Add forecasting bounds - This function is specific to the Prophet API and is not required to generate a forecast via PySpark. However, when you cannot inspect the quality of each forecast, adding some “guardrails” can prevent errant predictions that erode trust with your stakeholders. The floor and cap fields provide bounds that a forecast cannot fall below or exceed. For example, if the minimum value in a time-series is 10 and the maximum is 100, a floor multiplier of 0.5 and a cap multiplier of 1.5 ensure that all forecasted values are no higher than 150 (100 * 1.5) and no lower than 5 (10 * 0.5). Again, these decisions are often driven by the forecaster’s domain knowledge. We’ll go a bit deeper on these fields below as well.

Log transform outcome variable - Log transforming our outcome variable is an effective approach to reduce the influence of outliers and stabilize variance that increases over time. An alternative is a Box-Cox transformation (see here for more details), which can yield better results than a log transformation. However, I often start with a log transformation because it does not require us to keep track of a transformation parameter, which is something you’ll need to do with a Box-Cox transformation. Are we seeing a theme here? Start simple.
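
The actual implementations of these steps live in the FcastDataFrame class introduced below, but to make the padding, interpolation, and bounding ideas concrete, here is a simplified pandas sketch of what they might look like for a single store-department group. Treat it as an illustration of the logic rather than the exact code in fcast_data_frame.py; the helper name and default arguments are made up for this example.

def pad_fill_and_bound(group_df: pd.DataFrame,
                       date_field: str = "date",
                       yvar_field: str = "weekly_sales",
                       ts_frequency: str = "W-FRI",
                       floor_multiplier: float = 0.5,
                       cap_multiplier: float = 1.5) -> pd.DataFrame:
    """Simplified sketch of the pad / interpolate / bound steps for one group."""
    group_df = group_df.set_index(pd.to_datetime(group_df[date_field]))
    # pad: re-index on a complete weekly calendar so gaps show up as NaN rows
    full_calendar = pd.date_range(
        group_df.index.min(), group_df.index.max(), freq=ts_frequency
    )
    group_df = group_df.reindex(full_calendar)
    group_df[date_field] = group_df.index
    # interpolate: fill gaps with the series mean (simple and easy to reason about)
    group_df[yvar_field] = group_df[yvar_field].fillna(group_df[yvar_field].mean())
    # bounds: scale the observed min and max by the floor / cap multipliers
    group_df["floor_value"] = group_df[yvar_field].min() * floor_multiplier
    group_df["cap_value"] = group_df[yvar_field].max() * cap_multiplier
    return group_df.reset_index(drop=True)

Applied per group (e.g., sales_df.groupby(["store", "dept"], group_keys=False).apply(pad_fill_and_bound)), this yields the padded, filled, and bounded series that feed the forecasting step. In the actual class, each of these operations is a separate method so the changes can be logged between steps.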

Whew - that was a lot of information, but we can finally implement all of these data-cleaning steps via the FcastDataFrame class. Its design was inspired by the sklearn.pipeline class: a sequence of transformations that prepare and clean grouped time-series data prior to generating forecasts.

class FcastDataFrame:
    """Use for pre-processing data prior to forecasting"""
    def __init__(
        self,
        df: pd.DataFrame,
        group_fields: List[str],
        date_field: str,
        yvar_field: str,
        ts_frequency: str,
    ):
        """
        Args:
            df (pd.DataFrame): dataframe with the data to be forecasted
            group_fields (List[str]): grouping fields. These are often
                represented by attributes of each unit
                (e.g., store id, product id, etc.).
            date_field (str): date field
            yvar_field (str): outcome ("y") field
            ts_frequency (str): granularity of the data. For example,
                data recorded on a weekly basis, every Friday, would
                be "W-FRI". Note that sub-day level (e.g., hourly or
                minute-level) data is not supported.
        """
        self.df = df
        self.group_fields = group_fields
        self.date_field = date_field
        self.yvar_field = yvar_field
        self.ts_frequency = ts_frequency
        
fcast_df = FcastDataFrame(
        df=sales_df,
        group_fields=group_fields,
        date_field=date_field,
        yvar_field=yvar_field,
        ts_frequency=ts_frequency,
    )        

While we won’t cover all methods in this class, I’ll briefly review one of the methods – filter_groups_min_obs – to illustrate the structure of the class.

def filter_groups_min_obs(self, min_obs_threshold: int):
    """Filters groups based on some minimum number of observations 
       required for forecasting

    Args:
        min_obs_threshold (int): removes all groups with observation
                                 counts at or below this threshold
    """
    n_unique_groups = self.df[self.group_fields].drop_duplicates().shape[0]
    min_obs_filter_df = (
        self.df.groupby(self.group_fields)[self.yvar_field]
        .count()
        .reset_index()
        .rename(columns={self.yvar_field: "obs_count"})
        .query(f"obs_count > {str(min_obs_threshold)}")
        .drop(columns="obs_count")
    )
    n_remaining_groups = min_obs_filter_df.shape[0]
    df = pd.merge(self.df, min_obs_filter_df, how="inner", on=self.group_fields)
    self.df = df
    logger.info("N groups dropped: {}".format(n_unique_groups - n_remaining_groups))

Each data transformation takes in our data, applies some filtering, cleaning, or formatting, logs the changes, and then replaces the existing DataFrame with the updated DataFrame. This pattern is applied at each step until we are satisfied with the changes. Let’s apply these filtering and cleaning steps below.

# filter out groups with less than min number of observations
fcast_df.filter_groups_min_obs(min_obs_threshold=min_obs_threshold)  
# replace any negative value with a zero
fcast_df.replace_negative_value_with_zero()
# replace missing dates between start and end of time-series by group
fcast_df.pad_missing_values()
# filter groups with consecutive missing streak longer than 4
fcast_df.filter_groups_max_missing_streak(max_streak=4)
# impute missing values
fcast_df.fill_missing_values()
# add upper and lower bounds for forecasting
fcast_df.add_forecast_bounds(
    floor_multiplier=fcast_floor, 
    cap_multiplier=fcast_cap
)
# log transform outcome, floor, and cap values
fcast_df.log_transform_values(yvar_field, "floor_value", "cap_value")
# return transformed data
fcast_df_trans = fcast_df.return_transformed_df()

Now we are ready to do some forecasting. In the next section, we’ll produce our forecasts from the cleaned and prepared data.

Pyspark Forecasting

Let’s start by translating the field names to those that Prophet understands. For example, our date variable should be named ds and our outcome variable y. We’ll use the prep_for_prophet function to make the transition.

def prep_for_prophet(
    df: pd.DataFrame, yvar_field: str, date_field: str, group_fields: List[str]
) -> pd.DataFrame:
    """Renames key field names to be compatible with Prophet Forecasting API

    Args:
        df (pd.DataFrame): Contains data that will be used to generate forecasts
        yvar_field (str): outcome ("y") field name
        date_field (str): date field name
        group_fields (List[str]): grouping fields. These are often
                represented by attributes of each unit
                (e.g., store id, product id, etc.).

    Returns:
        pd.DataFrame: Data with compatible field names
    """
    fields = df.columns.tolist()
    cap_value_index = [
        index 
        for index, value in enumerate(["cap_value" in x for x in fields]) 
        if value
    ]
    floor_value_index = [
        index
        for index, value in enumerate(["floor_value" in x for x in fields])
        if value
    ]
    if cap_value_index and floor_value_index:
        df = df.rename(
            columns={
                fields[cap_value_index[0]]: "cap",
                fields[floor_value_index[0]]: "floor",
            }
        )
        group_fields = group_fields + ["cap", "floor"]
    df = df[group_fields + [date_field] + [yvar_field]]
    df = df.rename(columns={date_field: "ds", yvar_field: "y"})
    df["ds"] = pd.to_datetime(df["ds"])
    return df
fcast_df_prophet_input = prep_for_prophet(
        df=fcast_df_trans,
        yvar_field="weekly_sales_prep_log1p",
        date_field=date_field,
        group_fields=group_fields,
    )    

With our data prepared, we’ll shift over to creating a Spark session and indicating where our Java installation is located. Note this step will vary depending on your computing environment.

os.environ["JAVA_HOME"] = java_home

SPARK = (
    SparkSession.builder.master(f"local[{spark_n_threads}]")
    .appName(config["app_name"])
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .getOrCreate()
)

Next, we’ll define the schema (or format) of our input and output data.

INPUT_SCHEMA = StructType(
        [
            StructField("store", IntegerType(), True),
            StructField("dept", IntegerType(), True),
            StructField("cap", FloatType(), True),
            StructField("floor", FloatType(), True),
            StructField("ds", DateType(), True),
            StructField("y", FloatType(), True),
        ]
    )
    
OUTPUT_SCHEMA = StructType(
        [
            StructField("ds", DateType(), True),
            StructField("store", IntegerType(), True),
            StructField("dept", IntegerType(), True),
            StructField("yhat_lower", FloatType(), True), 
            StructField("yhat_upper", FloatType(), True),
            StructField("yhat", FloatType(), True),
        ]
)    

We’ll now translate our Pandas DataFrame to a Spark DataFrame and pass in the schema we defined above.

fcast_spark_prophet_input = SPARK.createDataFrame(
        fcast_df_prophet_input, schema=INPUT_SCHEMA
)

The function below does the actual forecasting and we’ll spend some time unpacking what’s happening here.

def run_forecast(keys, df):
    """Generate time-series forecast 

    Args:
        keys: Grouping keys
        df: Spark Dataframe 
    """
    fields = ["ds", "store", "dept", "yhat_lower", "yhat_upper","yhat"]
    store, dept = keys
    cap = df["cap"][0]
    floor = df["floor"][0]
    model = Prophet(
        interval_width=0.95,
        growth="logistic",
        yearly_seasonality=True,
        seasonality_mode="additive",
    )
    model.add_country_holidays(country_name="US")
    model.fit(df)
    future_df = model.make_future_dataframe(
        periods=13, freq="W-FRI", include_history=False
    )
    future_df["cap"] = cap
    future_df["floor"] = floor
    results_df = model.predict(future_df)
    results_df["store"] = store
    results_df["dept"] = dept
    results_df = results_df[fields]
    return results_df

Let’s start by discussing the Prophet model, which automatically determines many forecasting settings, like seasonality, during the model-fitting process. Below is a brief summary of some of the key settings:

interval_width - Interval width quantifies uncertainty in our forecast. Wider intervals indicate greater uncertainty. Here, we are indicating that the actual values should fall outside of the interval ~5% of the time. By default, Prophet is set to 80%, which is less conservative than our setting here. Providing a measure of uncertainty is perhaps even more important than the forecast itself, as it allows a business to hedge against the risk of being wrong. For example, imagine a product has a high margin and a low inventory holding cost. In this instance, you would want to plan to a high percentile, as you rarely want to stock out of this product.

yearly_seasonality - Setting this to True encodes my belief that there is week-over-week variation that repeats itself over the course of a year. For example, sales for items like sandals or sunscreen are likely higher in Summer weeks and lower in Winter weeks. There are two other seasonality options not used above - weekly and daily. Weekly captures day-of-week variation, while daily captures hour-by-hour variation within a day. Our data is at the week level, so we can ignore these two settings.

growth - Growth is a way to encode our beliefs regarding whether a forecast should reach a “saturation” point across the prediction horizon (see here for official documentation). For example, customer acquisition slows as a market matures and will eventually reach a saturation point (i.e., the total addressable market has been acquired). This setting is typically used for “long-range” forecasting on the scale of several years. Our forecasting horizon is much shorter at only 13 weeks. However, I like to codify what I consider to be a reasonable amount of growth, via the “cap” parameter, as well as contraction, via the “floor” parameter, especially when I cannot inspect each result.

seasonality_mode - I’ve selected “additive” for this parameter based on my belief that the magnitude of seasonal changes does not vary across time. Recall that our outcome variable has already been log-transformed, so we are actually applying an additive decomposition to the log-transformed values.

add_country_holidays - Holidays tend to drive increases in consumption of particular products, and some holidays, like Easter, do not fall on the same date every year. Thus, you can improve forecasting accuracy by anticipating how demand shifts around holiday dates. One feature not used in the current post (but incredibly useful) is the ability to apply a lower_window and upper_window to each holiday date. Continuing with our Easter example, you can imagine that egg sales increase in the days leading up to Easter, while sales on the holiday itself may not be that high, unless you are doing some last-minute shopping. By extending the lower_window parameter for this holiday to something like -5, you can capture the elevated demand during the five days that precede Easter.
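
To make the windowing idea concrete, the snippet below shows roughly how a custom Easter holiday with a five-day lead-up window could be passed to Prophet via its holidays argument. This is an illustration of the API rather than code used in this post, and the dates and settings are examples only.

# illustrative only: a custom Easter definition with a five-day lead-up window
easter_df = pd.DataFrame({
    "holiday": "easter",
    "ds": pd.to_datetime(["2010-04-04", "2011-04-24", "2012-04-08"]),
    "lower_window": -5,  # include the five days leading up to each Easter date
    "upper_window": 0,
})
model_with_windows = Prophet(
    holidays=easter_df,
    growth="logistic",
    yearly_seasonality=True,
    seasonality_mode="additive",
)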

Now that we are familiar with how the model is being configured, let’s generate the forecasts. Runtime will depend on how many threads you are using; with four threads, it took about 20 minutes to complete.

fcast_df_prophet_output = (
    fcast_spark_prophet_input.groupBy(group_fields)
    .applyInPandas(func=run_forecast, schema=OUTPUT_SCHEMA)
    .withColumn("part", lit("forecast"))
    .withColumn("fcast_date", lit(datetime.now().strftime("%Y-%m-%d")))
    .toPandas()
    .rename(
        columns={
            "yhat": yvar_field,
            "yhat_lower": f"{yvar_field}_lb",
            "yhat_upper": f"{yvar_field}_ub",
            "ds": date_field,
        }
    )
)

We should have 13-week forecasts for all store-department combinations. Our next steps are to combine the forecasts with the historical data and invert our log-transformation of the outcome variable to get back to our original scale. Note that np.log1p and np.expm1 are inverses of one another, and elegantly deal with zero values by adding/subtracting a value of “1” to avoid taking the log of zero, which is undefined and will make your code go 💥. Lastly, we’ll write the results out to our root directory.

fcast_df_prophet_input["part"] = "actuals"
fcast_df_prophet_input = fcast_df_prophet_input.rename(
    columns={"y": yvar_field, "ds": date_field}
)
del fcast_df_prophet_input["cap"]
del fcast_df_prophet_input["floor"]

ret_df = pd.concat([fcast_df_prophet_input, fcast_df_prophet_output])
ret_df = ret_df.apply(lambda x: round(np.expm1(x)) if yvar_field in x.name else x)

ret_df.to_csv(Path.cwd() / "sales_data_forecast.csv", index=False)
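
As a quick sanity check on that inverse relationship, a round trip through the two functions returns the original values, zeros included:

values = np.array([0.0, 10.0, 24924.0])
np.allclose(np.expm1(np.log1p(values)), values)  # True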

Quality Assurance

We’ll transition back to the world of R for some quick quality-assurance work. Let’s read in our forecasts and examine a few store-department combinations. Note there are much more formal ways to validate the performance of our models, but our objective is to do a quick sanity check (i.e., “do the forecasts look reasonable for a few randomly sampled groups?”). The raw output is stored on GitHub. Let’s start by examining the first and last five rows for a single store-department combination.

library(tidyverse)
library(timetk)
library(lubridate)

fcast_df_url = "https://raw.githubusercontent.com/thecodeforest/codeforest_datasets/main/pyspark_forecasting_data/sales_data_forecast.csv"
fcast_df = read_csv(fcast_df_url)
df_store_dept_sample <- fcast_df %>% 
  filter(store == 1, dept == 1) %>% 
  mutate(date = as_date(date))

Top 5 Rows of Forecasting Data

| store | dept | date       | weekly_sales | part    | weekly_sales_lb | weekly_sales_ub | fcast_date |
|-------|------|------------|--------------|---------|-----------------|-----------------|------------|
| 1     | 1    | 2010-02-05 | 24924        | actuals | NA              | NA              | NA         |
| 1     | 1    | 2010-02-12 | 46039        | actuals | NA              | NA              | NA         |
| 1     | 1    | 2010-02-19 | 41596        | actuals | NA              | NA              | NA         |
| 1     | 1    | 2010-02-26 | 19404        | actuals | NA              | NA              | NA         |
| 1     | 1    | 2010-03-05 | 21828        | actuals | NA              | NA              | NA         |

Bottom 5 Rows of Forecasting Data

| store | dept | date       | weekly_sales | part     | weekly_sales_lb | weekly_sales_ub | fcast_date |
|-------|------|------------|--------------|----------|-----------------|-----------------|------------|
| 1     | 1    | 2012-12-28 | 30948        | forecast | 21883           | 42839           | 2021-04-05 |
| 1     | 1    | 2013-01-04 | 21138        | forecast | 14793           | 30024           | 2021-04-05 |
| 1     | 1    | 2013-01-11 | 16149        | forecast | 11384           | 22832           | 2021-04-05 |
| 1     | 1    | 2013-01-18 | 15553        | forecast | 10712           | 21662           | 2021-04-05 |
| 1     | 1    | 2013-01-25 | 18954        | forecast | 13475           | 27282           | 2021-04-05 |

Let’s sample a few forecasts and plot them out.

set.seed(2021)
fcast_df %>% 
  filter(store < 3,
         dept %in% c(fcast_df %>% distinct(dept) %>% sample_n(2) %>% pull())
         ) %>% 
  mutate(store = paste0('Store: ', store),
         dept = paste0('Dept: ', dept),
         store_id = paste(store, dept, sep=' ')) %>% 
  select(date, store_id, contains('weekly')) %>% 
  pivot_longer(contains('weekly')) %>%  
  mutate(name = str_to_title(str_replace_all(name, '_', ' '))) %>% 
  ggplot(aes(date, value, color = name)) + 
  geom_line(size = 1.5, alpha = 0.8) + 
  facet_grid(store_id ~ ., scales = 'free') + 
  theme_bw() + 
  scale_y_continuous(labels = scales::comma_format()) + 
  labs(x = 'Date',
       y = 'Weekly Sales',
       color = NULL,
       title = 'Sample Forecasts'
       ) + 
  theme(legend.position = "top",
        legend.text = element_text(size = 12),
        strip.text.y = element_text(size = 12),
        plot.title = element_text(size = 14)
        )

Overall, the forecasts appear to capture changes in the trend and seasonal variation. A more formal approach to this problem is to do back-testing by holding out some historical data and generating forecasts against it. However, this is a great starting point from which to build more advanced models and incorporate external variables to further improve our forecasts. Hopefully this is enough to get you started on your way to forecasting at an enterprise scale. Until next time, happy forecasting!

Mark LeBoeuf
Data Scientist