Skip to content

Instantly share code, notes, and snippets.

@anna-geller
Created October 7, 2024 20:27
Show Gist options
  • Save anna-geller/8c37a868939ea94a6c91f069dc4c215c to your computer and use it in GitHub Desktop.
Save anna-geller/8c37a868939ea94a6c91f069dc4c215c to your computer and use it in GitHub Desktop.
Example showing how to create interactive workflows that dynamically adapt to user inputs using Kestra’s open-source orchestration platform and Modal’s serverless infrastructure.
# Kestra flow that forecasts order volume on Modal.
# Inputs declared with `dependsOn` are conditional on earlier answers:
# `run_modal` gates the dataset/CPU/memory inputs, and `run_modal` together
# with `customize_forecast` gates the file-name, color and horizon inputs.
id: modal_forecast
namespace: company.team

inputs:
  # ---- Always-present inputs ----
  - id: s3_bucket
    displayName: S3 bucket name
    description: Name of an S3 bucket to upload final ML artifacts
    type: STRING
    defaults: kestra-us

  - id: run_modal
    displayName: Run forecasts on Modal
    description: Whether to run the forecast on Modal
    type: BOOLEAN
    defaults: true

  # ---- Inputs conditional on run_modal ----
  - id: dataset_url
    displayName: Data lake URL to the dataset used for the ML model
    description: Swap the `small` with `large` in the URL for a larger dataset and a more accurate forecast
    type: STRING
    defaults: "https://huggingface.co/datasets/kestra/datasets/resolve/main/modal/raw_orders_small.parquet"
    dependsOn:
      inputs:
        - run_modal
      condition: "{{ inputs.run_modal equals true }}"

  # SELECT values are quoted so they stay strings ("1.0" must not become 1).
  - id: cpu
    type: SELECT
    displayName: CPU request
    description: The number of CPU cores to allocate to the job
    defaults: "0.25"
    values: ["0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0", "8.0", "16.0", "32.0"]
    dependsOn:
      inputs:
        - run_modal
      condition: "{{ inputs.run_modal equals true }}"

  - id: memory
    type: SELECT
    displayName: Memory request
    description: Amount of memory in MiB
    defaults: "512"
    values: ["512", "1024", "2048", "4096", "8192", "16384", "32768"]
    dependsOn:
      inputs:
        - run_modal
      # Same explicit comparison as the sibling inputs, for consistency.
      condition: "{{ inputs.run_modal equals true }}"

  - id: customize_forecast
    displayName: Customize forecast parameters
    description: Whether to customize the visualization and final artifact names
    type: BOOLEAN
    defaults: false

  # ---- Inputs conditional on run_modal AND customize_forecast ----
  - id: forecast_file
    displayName: Forecast file name
    description: Name of the forecast output file that will be stored in S3
    type: STRING
    defaults: forecast.parquet
    dependsOn:
      inputs:
        - run_modal
        - customize_forecast
      condition: "{{ inputs.run_modal equals true and inputs.customize_forecast equals true }}"

  - id: html_report
    displayName: HTML report file name
    description: Name of the HTML report that you can download from the Outputs tab in the Kestra UI
    type: STRING
    defaults: forecast.html
    dependsOn:
      inputs:
        - run_modal
        - customize_forecast
      condition: "{{ inputs.run_modal equals true and inputs.customize_forecast equals true }}"

  - id: color_history
    displayName: Time series color for historical data
    description: Color for historical data in the plot
    type: STRING
    defaults: blue
    dependsOn:
      inputs:
        - run_modal
        - customize_forecast
      condition: "{{ inputs.run_modal equals true and inputs.customize_forecast equals true }}"

  - id: color_prediction
    displayName: Time series color for forecasted data
    description: Color for predicted data in the plot
    type: STRING
    defaults: red
    dependsOn:
      inputs:
        - run_modal
        - customize_forecast
      condition: "{{ inputs.run_modal equals true and inputs.customize_forecast equals true }}"

  - id: nr_days_fcst
    displayName: Forecast days
    description: The number of days in the future to generate forecast for
    type: INT
    defaults: 180
    dependsOn:
      inputs:
        - run_modal
        - customize_forecast
      condition: "{{ inputs.run_modal equals true and inputs.customize_forecast equals true }}"

tasks:
  # Run the Modal job (and fetch its report) only when the user opted in.
  - id: check_whether_to_run_modal
    type: io.kestra.plugin.core.flow.If
    condition: "{{ inputs.run_modal equals true }}"
    then:
      - id: modal
        type: io.kestra.plugin.modal.cli.ModalCLI
        description: The script.py is added inlined for portability. In production, it should be stored in a Namespace File.
        inputFiles:
          script.py: |
            import os
            from typing import Tuple
            import modal
            import warnings

            warnings.filterwarnings("ignore", category=FutureWarning)

            # The listed local environment variables are passed to the app as a
            # secret: CPU and MEMORY are read below to size the function, and
            # the AWS_* keys are read by the S3 client inside the function.
            app = modal.App(
                "order-forecast",
                secrets=[
                    modal.Secret.from_local_environ(
                        env_keys=[
                            "CPU",
                            "MEMORY",
                            "AWS_ACCESS_KEY_ID",
                            "AWS_SECRET_ACCESS_KEY",
                            "AWS_DEFAULT_REGION",
                        ]
                    )
                ],
            )

            image = modal.Image.debian_slim().pip_install(
                "pandas",
                "boto3",
                "kestra",
                "pyarrow",
                "plotly",
                "statsmodels",
            )


            @app.function(
                image=image,
                cpu=float(os.getenv("CPU")),
                memory=int(os.getenv("MEMORY")),
            )
            def predict_order_volume(
                forecast_file: str,
                html_report: str,
                s3_bucket: str,
                nr_days_fcst: int,
                dataset_url: str,
                color_history: str,
                color_prediction: str,
            ) -> Tuple[str, str]:
                """Forecast daily order totals and publish the artifacts.

                Reads the parquet dataset at dataset_url, aggregates order_total
                per day, fits a SARIMA model with a 7-day seasonal period,
                forecasts nr_days_fcst days ahead, writes the forecast
                (forecast_file) and a Plotly chart (html_report), uploads both
                to s3_bucket and emits Kestra outputs.

                Returns the (forecast_file, html_report) names.
                """
                import datetime
                import boto3
                from kestra import Kestra
                import pandas as pd
                import plotly.graph_objs as go
                from statsmodels.tsa.statespace.sarimax import SARIMAX

                # ==================== EXTRACT =================
                df = pd.read_parquet(dataset_url)
                initial_nr_rows = len(df)
                print(f"Number of rows in the dataset: {initial_nr_rows}")

                # ==================== TRANSFORM ====================
                # Extract 'ds' (date) from 'ordered_at' and use 'order_total' as 'y'
                df["ds"] = pd.to_datetime(df["ordered_at"]).dt.date
                df = df.groupby("ds").agg({"order_total": "sum"}).reset_index()
                df.rename(columns={"order_total": "y"}, inplace=True)

                # Ensure daily frequency
                df["ds"] = pd.to_datetime(df["ds"])
                df.set_index("ds", inplace=True)
                df = df.asfreq("D", fill_value=0)  # Fill missing days with 0 order totals
                nr_rows_daily = len(df)

                # ==================== TRAIN SARIMA MODEL ====================
                model = SARIMAX(df["y"], order=(1, 1, 1), seasonal_order=(1, 1, 1, 7))
                sarima_fit = model.fit(disp=False)

                # ==================== PREDICT ====================
                future = sarima_fit.get_forecast(steps=nr_days_fcst)
                forecast = future.summary_frame()

                # Create future dates, starting the day after the last observation
                future_dates = pd.date_range(
                    df.index.max() + datetime.timedelta(days=1), periods=nr_days_fcst
                )
                forecast_df = pd.DataFrame({"ds": future_dates, "yhat": forecast["mean"]})
                forecast_df.to_parquet(forecast_file)

                # ==================== VISUALIZE WITH PLOTLY ====================
                forecast_fig = go.Figure()
                forecast_fig.add_trace(
                    go.Scatter(
                        x=df.index,
                        y=df["y"],
                        mode="lines",
                        name="Historical Order Volume",
                        line=dict(color=color_history),
                    )
                )
                forecast_fig.add_trace(
                    go.Scatter(
                        x=forecast_df["ds"],
                        y=forecast_df["yhat"],
                        mode="lines",
                        name="Predicted Order Volume",
                        line=dict(color=color_prediction),
                    )
                )
                forecast_fig.update_layout(
                    title=f"Order Volume Prediction for the Next {nr_days_fcst} Days",
                    xaxis_title="Date",
                    yaxis_title="Order Total",
                    legend_title="Legend",
                    xaxis=dict(showgrid=True),
                    yaxis=dict(showgrid=True),
                )
                forecast_fig.write_html(html_report)

                # ==================== UPLOAD TO S3 ====================
                s3 = boto3.client(
                    "s3",
                    aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
                    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
                    region_name=os.getenv("AWS_DEFAULT_REGION"),
                )
                # Each file is stored under an S3 key equal to its local name.
                files_to_upload = [html_report, forecast_file]
                for file_name in files_to_upload:
                    s3.upload_file(file_name, s3_bucket, file_name)
                    print(f"File {file_name} uploaded to {s3_bucket}.")

                Kestra.outputs(
                    dict(
                        initial_nr_rows=initial_nr_rows,
                        nr_rows_daily=nr_rows_daily,
                        forecast_file=forecast_file,
                        html_report=html_report,
                    )
                )
                return forecast_file, html_report


            @app.local_entrypoint()
            def generate_and_predict(
                forecast_file: str,
                html_report: str,
                s3_bucket: str,
                nr_days_fcst: int,
                dataset_url: str,
                color_history: str,
                color_prediction: str,
            ) -> None:
                """Entry point for `modal run`: call predict_order_volume remotely
                with the given arguments and print the resulting file names."""
                results = predict_order_volume.remote(
                    forecast_file,
                    html_report,
                    s3_bucket,
                    nr_days_fcst,
                    dataset_url,
                    color_history,
                    color_prediction,
                )
                print(f"Forecast file: {results[0]}, HTML report: {results[1]}")
        namespaceFiles:
          enabled: true
        commands:
          # Folded scalar (>-) renders as a single command line. Every templated
          # value is single-quoted so that URLs containing `&`/`?` or colors such
          # as `rgb(0, 0, 255)` reach the script as one argument instead of being
          # interpreted by the shell. A value containing a single quote would
          # still break the quoting.
          - >-
            modal run script.py
            --forecast-file '{{inputs.forecast_file}}'
            --html-report '{{inputs.html_report}}'
            --s3-bucket '{{inputs.s3_bucket}}'
            --nr-days-fcst '{{inputs.nr_days_fcst}}'
            --dataset-url '{{inputs.dataset_url}}'
            --color-history '{{inputs.color_history}}'
            --color-prediction '{{inputs.color_prediction}}'
        containerImage: ghcr.io/kestra-io/modal:latest
        env:
          MODAL_TOKEN_ID: "{{ kv('MODAL_TOKEN_ID') }}"
          MODAL_TOKEN_SECRET: "{{ kv('MODAL_TOKEN_SECRET') }}"
          AWS_ACCESS_KEY_ID: "{{ kv('AWS_ACCESS_KEY_ID') }}"
          AWS_SECRET_ACCESS_KEY: "{{ kv('AWS_SECRET_ACCESS_KEY') }}"
          AWS_DEFAULT_REGION: "{{ kv('AWS_DEFAULT_REGION') }}"
          # SELECT inputs are strings; script.py converts them to float/int.
          CPU: "{{ inputs.cpu }}"
          MEMORY: "{{ inputs.memory }}"

      # Fetch the HTML report that the modal task uploaded to the bucket.
      # NOTE(review): nesting was reconstructed from a flattened source. This
      # task is placed inside `then` because the report only exists when the
      # Modal task ran; confirm it was not meant to be a top-level task.
      - id: download_from_s3
        type: io.kestra.plugin.aws.s3.Download
        accessKeyId: "{{ kv('AWS_ACCESS_KEY_ID') }}"
        secretKeyId: "{{ kv('AWS_SECRET_ACCESS_KEY') }}"
        region: "{{ kv('AWS_DEFAULT_REGION') }}"
        bucket: "{{ inputs.s3_bucket }}"
        key: "{{ inputs.html_report }}"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment