Created October 7, 2024 20:27
-
-
Save anna-geller/8c37a868939ea94a6c91f069dc4c215c to your computer and use it in GitHub Desktop.
Example showing how to create interactive workflows that dynamically adapt to user inputs using Kestra’s open-source orchestration platform and Modal’s serverless infrastructure.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Flow: interactive order-volume forecast that optionally offloads model
# training to Modal's serverless infrastructure and uploads artifacts to S3.
id: modal_forecast
namespace: company.team

# The inputs below form a dynamic UI form: `dependsOn` hides an input until
# the listed inputs satisfy the Pebble `condition`, so Modal-specific options
# only appear once `run_modal` is enabled (and forecast-customization options
# only once `customize_forecast` is enabled as well).
inputs:
  - id: s3_bucket
    displayName: S3 bucket name
    description: Name of an S3 bucket to upload final ML artifacts
    type: STRING
    defaults: kestra-us

  - id: run_modal
    displayName: Run forecasts on Modal
    description: Whether to run the forecast on Modal
    type: BOOLEAN
    defaults: true

  - id: dataset_url
    displayName: Data lake URL to the dataset used for the ML model
    description: Swap the `small` with `large` in the URL for a larger dataset and a more accurate forecast
    type: STRING
    defaults: "https://huggingface.co/datasets/kestra/datasets/resolve/main/modal/raw_orders_small.parquet"
    dependsOn:
      inputs:
        - run_modal
      condition: "{{ inputs.run_modal equals true }}"

  - id: cpu
    type: SELECT
    displayName: CPU request
    description: The number of CPU cores to allocate to the job
    defaults: "0.25"
    values: ["0.25", "0.5", "0.75", "1.0", "1.5", "2.0", "4.0", "8.0", "16.0", "32.0"]
    dependsOn:
      inputs:
        - run_modal
      condition: "{{ inputs.run_modal equals true }}"

  - id: memory
    type: SELECT
    displayName: Memory request
    description: Amount of memory in MiB
    defaults: "512"
    values: ["512", "1024", "2048", "4096", "8192", "16384", "32768"]
    dependsOn:
      inputs:
        - run_modal
      # Consistency fix: every sibling `dependsOn` uses the explicit
      # `equals true` form; the original used the bare truthy expression here.
      condition: "{{ inputs.run_modal equals true }}"

  - id: customize_forecast
    displayName: Customize forecast parameters
    description: Whether to customize the visualization and final artifact names
    type: BOOLEAN
    defaults: false

  - id: forecast_file
    displayName: Forecast file name
    description: Name of the forecast output file that will be stored in S3
    type: STRING
    defaults: forecast.parquet
    dependsOn:
      inputs:
        - run_modal
        - customize_forecast
      condition: "{{ inputs.run_modal equals true and inputs.customize_forecast equals true }}"

  - id: html_report
    displayName: HTML report file name
    description: Name of the HTML report that you can download from the Outputs tab in the Kestra UI
    type: STRING
    defaults: forecast.html
    dependsOn:
      inputs:
        - run_modal
        - customize_forecast
      condition: "{{ inputs.run_modal equals true and inputs.customize_forecast equals true }}"

  - id: color_history
    displayName: Time series color for historical data
    description: Color for historical data in the plot
    type: STRING
    defaults: blue
    dependsOn:
      inputs:
        - run_modal
        - customize_forecast
      condition: "{{ inputs.run_modal equals true and inputs.customize_forecast equals true }}"

  - id: color_prediction
    displayName: Time series color for forecasted data
    description: Color for predicted data in the plot
    type: STRING
    defaults: red
    dependsOn:
      inputs:
        - run_modal
        - customize_forecast
      condition: "{{ inputs.run_modal equals true and inputs.customize_forecast equals true }}"

  - id: nr_days_fcst
    displayName: Forecast days
    description: The number of days in the future to generate forecast for
    type: INT
    defaults: 180
    dependsOn:
      inputs:
        - run_modal
        - customize_forecast
      condition: "{{ inputs.run_modal equals true and inputs.customize_forecast equals true }}"
tasks:
  # Gate: only run the Modal workload when the user opted in via `run_modal`.
  - id: check_whether_to_run_modal
    type: io.kestra.plugin.core.flow.If
    condition: "{{ inputs.run_modal equals true }}"
    then:
      - id: modal
        type: io.kestra.plugin.modal.cli.ModalCLI
        description: The script.py is added inlined for portability. In production, it should be stored in a Namespace File.
        inputFiles:
          script.py: |
            import os
            from typing import Tuple

            import modal
            import warnings

            warnings.filterwarnings("ignore", category=FutureWarning)

            # Forward the CPU/memory sizing and AWS credentials from the
            # container environment (set by the Kestra task below) into the
            # Modal function's environment.
            app = modal.App(
                "order-forecast",
                secrets=[
                    modal.Secret.from_local_environ(
                        env_keys=[
                            "CPU",
                            "MEMORY",
                            "AWS_ACCESS_KEY_ID",
                            "AWS_SECRET_ACCESS_KEY",
                            "AWS_DEFAULT_REGION",
                        ]
                    )
                ],
            )

            image = modal.Image.debian_slim().pip_install(
                "pandas",
                "boto3",
                "kestra",
                "pyarrow",
                "plotly",
                "statsmodels",
            )

            # CPU and MEMORY are set via the task's `env` block; float/int
            # conversion would raise if they were ever missing.
            @app.function(image=image, cpu=float(os.getenv("CPU")), memory=int(os.getenv("MEMORY")))
            def predict_order_volume(
                forecast_file: str,
                html_report: str,
                s3_bucket: str,
                nr_days_fcst: int,
                dataset_url: str,
                color_history: str,
                color_prediction: str,
            ) -> Tuple[str, str]:
                """Train a SARIMA model on daily order totals, forecast
                `nr_days_fcst` days ahead, write a Parquet forecast and an
                HTML Plotly report, and upload both to S3."""
                import datetime

                import boto3
                from kestra import Kestra
                import pandas as pd
                import plotly.graph_objs as go
                from statsmodels.tsa.statespace.sarimax import SARIMAX

                # ==================== EXTRACT =================
                df = pd.read_parquet(dataset_url)
                initial_nr_rows = len(df)
                print(f"Number of rows in the dataset: {initial_nr_rows}")

                # ==================== TRANSFORM ====================
                # Extract 'ds' (date) from 'ordered_at' and use 'order_total' as 'y'
                df["ds"] = pd.to_datetime(df["ordered_at"]).dt.date
                df = df.groupby("ds").agg({"order_total": "sum"}).reset_index()
                df.rename(columns={"order_total": "y"}, inplace=True)

                # Ensure daily frequency
                df["ds"] = pd.to_datetime(df["ds"])
                df.set_index("ds", inplace=True)
                df = df.asfreq("D", fill_value=0)  # Fill missing days with 0 order totals
                nr_rows_daily = len(df)

                # ==================== TRAIN SARIMA MODEL ====================
                # Weekly seasonality (period 7) on the daily series.
                model = SARIMAX(df["y"], order=(1, 1, 1), seasonal_order=(1, 1, 1, 7))
                sarima_fit = model.fit(disp=False)

                # ==================== PREDICT ====================
                future = sarima_fit.get_forecast(steps=nr_days_fcst)
                forecast = future.summary_frame()

                # Create future dates
                future_dates = pd.date_range(
                    df.index.max() + datetime.timedelta(days=1), periods=nr_days_fcst
                )
                forecast_df = pd.DataFrame({"ds": future_dates, "yhat": forecast["mean"]})
                forecast_df.to_parquet(forecast_file)

                # ==================== VISUALIZE WITH PLOTLY ====================
                forecast_fig = go.Figure()
                forecast_fig.add_trace(
                    go.Scatter(
                        x=df.index,
                        y=df["y"],
                        mode="lines",
                        name="Historical Order Volume",
                        line=dict(color=color_history),
                    )
                )
                forecast_fig.add_trace(
                    go.Scatter(
                        x=forecast_df["ds"],
                        y=forecast_df["yhat"],
                        mode="lines",
                        name="Predicted Order Volume",
                        line=dict(color=color_prediction),
                    )
                )
                forecast_fig.update_layout(
                    title=f"Order Volume Prediction for the Next {nr_days_fcst} Days",
                    xaxis_title="Date",
                    yaxis_title="Order Total",
                    legend_title="Legend",
                    xaxis=dict(showgrid=True),
                    yaxis=dict(showgrid=True),
                )
                forecast_fig.write_html(html_report)

                # ==================== UPLOAD TO S3 ====================
                s3 = boto3.client(
                    "s3",
                    aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
                    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
                    region_name=os.getenv("AWS_DEFAULT_REGION"),
                )
                files_to_upload = [html_report, forecast_file]
                for file_name in files_to_upload:
                    s3.upload_file(file_name, s3_bucket, file_name)
                    print(f"File {file_name} uploaded to {s3_bucket}.")

                # Surface row counts and artifact names in the Kestra UI.
                Kestra.outputs(
                    dict(
                        initial_nr_rows=initial_nr_rows,
                        nr_rows_daily=nr_rows_daily,
                        forecast_file=forecast_file,
                        html_report=html_report,
                    )
                )
                return forecast_file, html_report

            @app.local_entrypoint()
            def generate_and_predict(
                forecast_file: str,
                html_report: str,
                s3_bucket: str,
                nr_days_fcst: int,
                dataset_url: str,
                color_history: str,
                color_prediction: str,
            ) -> None:
                """CLI entrypoint invoked by `modal run`; executes the
                forecast remotely and prints the resulting artifact names."""
                results = predict_order_volume.remote(
                    forecast_file,
                    html_report,
                    s3_bucket,
                    nr_days_fcst,
                    dataset_url,
                    color_history,
                    color_prediction,
                )
                print(f"Forecast file: {results[0]}, HTML report: {results[1]}")
        namespaceFiles:
          enabled: true
        # NOTE(review): this command renders forecast_file/html_report/colors
        # unconditionally even though those inputs are hidden unless
        # customize_forecast is true — confirm their `defaults` still render
        # when the dependsOn condition is unmet.
        commands:
          - modal run script.py --forecast-file {{inputs.forecast_file}} --html-report {{inputs.html_report}} --s3-bucket {{inputs.s3_bucket}} --nr-days-fcst {{inputs.nr_days_fcst}} --dataset-url {{inputs.dataset_url}} --color-history {{inputs.color_history}} --color-prediction {{inputs.color_prediction}}
        containerImage: ghcr.io/kestra-io/modal:latest
        # Credentials are read from the Kestra KV store, never hard-coded.
        env:
          MODAL_TOKEN_ID: "{{ kv('MODAL_TOKEN_ID') }}"
          MODAL_TOKEN_SECRET: "{{ kv('MODAL_TOKEN_SECRET') }}"
          AWS_ACCESS_KEY_ID: "{{ kv('AWS_ACCESS_KEY_ID') }}"
          AWS_SECRET_ACCESS_KEY: "{{ kv('AWS_SECRET_ACCESS_KEY') }}"
          AWS_DEFAULT_REGION: "{{ kv('AWS_DEFAULT_REGION') }}"
          CPU: "{{ inputs.cpu }}"
          MEMORY: "{{ inputs.memory }}"

      # Fetch the generated HTML report back into Kestra internal storage so
      # it appears in the execution's Outputs tab.
      # NOTE(review): the source paste lost all indentation — this task is
      # placed inside `then:` (it only makes sense after the Modal upload);
      # confirm against the original flow whether it was a top-level task.
      - id: download_from_s3
        type: io.kestra.plugin.aws.s3.Download
        accessKeyId: "{{ kv('AWS_ACCESS_KEY_ID') }}"
        secretKeyId: "{{ kv('AWS_SECRET_ACCESS_KEY') }}"
        region: "{{ kv('AWS_DEFAULT_REGION') }}"
        bucket: "{{ inputs.s3_bucket }}"
        key: "{{ inputs.html_report }}"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment