The Problem — Silent Revenue Erosion
SaaS companies lose 5-7% of customers every month on average, and most of them never see it coming. By the time a customer opens a cancellation ticket or downgrades their plan, the decision was made weeks ago. The usage metrics were declining, support tickets were going unanswered, and feature adoption had flatlined — but nobody was watching the signals in aggregate.
The economics of churn are brutal. Acquiring a new customer costs 5-25x more than retaining an existing one. For a mid-market SaaS with 50,000 customers paying $100/month, a 6% monthly churn rate means losing 3,000 customers per month — $3.6M in annualized revenue walking out the door. Even worse, churned customers often share negative experiences, creating a compounding effect on growth.
Companies with mature churn prediction systems reduce churn by 15-30%. A Harvard Business Review study found that a 5% increase in retention can increase profits by 25-95%. The ROI of a well-built churn prediction pipeline is one of the highest in all of ML.
What Makes Churn Prediction Hard?
Unlike fraud detection where you have clear-cut labels (transaction was fraudulent or not), churn is fuzzy. A customer who stops logging in for 2 weeks might be on vacation, or might be evaluating a competitor. A spike in support tickets could mean the customer is frustrated — or it could mean they are deeply engaged and pushing the product to its limits.
- Imbalanced data: In any given month, only 5-7% of customers churn, creating a severe class imbalance
- Temporal complexity: Churn signals unfold over weeks, requiring careful window-based feature engineering
- Multiple churn types: Voluntary cancellation, non-renewal, downgrade, and "silent churn" (stops using but stays subscribed)
- Feedback loops: If your model successfully intervenes, the training data changes — saved customers look different from churned ones
- Feature leakage: Using data that is only available after the churn decision has been made
The GCP MLE exam frequently tests your understanding of temporal leakage in classification tasks. In churn prediction, using "days since last login" computed at prediction time vs. training time creates subtle but devastating data leakage. Always use point-in-time feature computation.
Solution Architecture
The architecture follows an end-to-end Vertex AI Pipeline pattern, reading from multiple BigQuery source tables (usage logs, billing events, support tickets, feature adoption), transforming raw data into predictive features, training an XGBoost model with hyperparameter tuning via Vertex AI Vizier, and serving churn risk scores to CRM and marketing automation systems for proactive outreach.
Data Sources
The pipeline reads from four BigQuery tables, each capturing a different dimension of customer behavior. The key insight is that no single data source is sufficient — churn is a multi-signal problem, and the most accurate models combine usage, billing, support, and product adoption data.
| Source Table | Key Fields | Update Frequency | Signal Type |
|---|---|---|---|
| usage_logs | customer_id, event_type, timestamp, duration, feature_used | Real-time (streaming) | Engagement depth |
| billing_events | customer_id, amount, plan_type, payment_status, mrr | Daily batch | Financial commitment |
| support_tickets | customer_id, category, priority, resolution_time, sentiment_score | Near real-time | Friction / frustration |
| feature_adoption | customer_id, feature_name, first_used, usage_count, days_since_last | Daily batch | Product stickiness |
Feature Engineering for Churn
Feature engineering is the most impactful part of a churn prediction pipeline. Raw event logs must be transformed into meaningful, temporally-consistent signals. We compute all features in BigQuery SQL for scalability, using window functions to create point-in-time snapshots that prevent temporal leakage.
Engagement Scores
The engagement score is a composite metric that captures how deeply a customer interacts with the product. It combines login frequency, session duration, feature breadth, and recency of usage into a single 0-100 score.
-- BigQuery SQL: Compute engagement score per customer
WITH usage_stats AS (
SELECT
customer_id,
COUNT(DISTINCT DATE(event_timestamp)) AS active_days_30d,
COUNT(*) AS total_events_30d,
AVG(session_duration_seconds) AS avg_session_duration,
COUNT(DISTINCT feature_used) AS features_used_count,
DATE_DIFF(CURRENT_DATE(), MAX(DATE(event_timestamp)), DAY) AS days_since_last_active
FROM `project.dataset.usage_logs`
WHERE event_timestamp >= DATE_SUB(CURRENT_DATE(), INTERVAL 30 DAY)
GROUP BY customer_id
)
SELECT
customer_id,
-- Weighted engagement score (0-100)
LEAST(100,
(active_days_30d / 30.0) * 30 + -- Login frequency (30%)
LEAST(avg_session_duration / 1800.0, 1.0) * 25 + -- Session depth (25%)
(features_used_count / 12.0) * 25 + -- Feature breadth (25%)
GREATEST(0, 1 - days_since_last_active / 14.0) * 20 -- Recency (20%)
) AS engagement_score
FROM usage_stats;
Usage Trends
Raw usage counts are less predictive than usage trends. A customer who logged in 20 times this week but was at 40 last week is more at-risk than one who logged in 10 times both weeks. We compute week-over-week and month-over-month ratios to capture trajectory.
-- Usage trend features: week-over-week changes
WITH weekly_usage AS (
SELECT
customer_id,
DATE_TRUNC(event_timestamp, WEEK) AS week_start,
COUNT(*) AS event_count,
SUM(session_duration_seconds) AS total_duration
FROM `project.dataset.usage_logs`
WHERE event_timestamp >= DATE_SUB(CURRENT_DATE(), INTERVAL 8 WEEK)
GROUP BY customer_id, week_start
),
with_lag AS (
SELECT
*,
LAG(event_count) OVER (
PARTITION BY customer_id ORDER BY week_start
) AS prev_week_events,
LAG(total_duration) OVER (
PARTITION BY customer_id ORDER BY week_start
) AS prev_week_duration
FROM weekly_usage
)
SELECT
customer_id,
-- Week-over-week change ratios
AVG(SAFE_DIVIDE(event_count - prev_week_events, prev_week_events)) AS avg_wow_event_change,
AVG(SAFE_DIVIDE(total_duration - prev_week_duration, prev_week_duration)) AS avg_wow_duration_change,
-- Trend slope via linear regression
CORR(UNIX_DATE(week_start), event_count) AS usage_trend_corr
FROM with_lag
WHERE prev_week_events IS NOT NULL
GROUP BY customer_id;
Support Sentiment
Support ticket data is a goldmine for churn prediction. We track not just volume, but category distribution, resolution time, and sentiment scores from the Natural Language API. A customer with three unresolved high-priority bugs is fundamentally different from one with three feature requests.
-- Support-based churn features
SELECT
customer_id,
COUNT(*) AS tickets_30d,
COUNTIF(priority = 'critical') AS critical_tickets_30d,
COUNTIF(status = 'open') AS open_tickets,
AVG(resolution_hours) AS avg_resolution_hours,
AVG(sentiment_score) AS avg_sentiment, -- -1 to 1 from NL API
MIN(sentiment_score) AS worst_sentiment,
COUNTIF(category = 'bug') AS bug_tickets,
COUNTIF(category = 'billing') AS billing_tickets -- Strong churn signal
FROM `project.dataset.support_tickets`
WHERE created_at >= DATE_SUB(CURRENT_DATE(), INTERVAL 30 DAY)
GROUP BY customer_id;
Billing-related support tickets are one of the strongest single predictors of churn. A customer who contacts support about pricing or billing disputes is 3-4x more likely to churn than one who contacts about product bugs. Weight this feature accordingly.
Billing Patterns
Billing data captures the financial relationship between the customer and your product. Key features include payment failures, plan downgrades, discount history, and contract remaining length.
-- Billing-based churn features
SELECT
c.customer_id,
c.plan_type,
c.monthly_recurring_revenue AS mrr,
c.contract_months_remaining,
COUNTIF(b.payment_status = 'failed') AS failed_payments_90d,
COUNTIF(b.event_type = 'downgrade') AS downgrades_180d,
COUNTIF(b.event_type = 'discount_applied') AS discounts_received,
CASE
WHEN c.plan_type = 'annual' THEN 0
WHEN c.plan_type = 'monthly' THEN 1
ELSE 0.5
END AS plan_flexibility_score
FROM `project.dataset.customers` c
LEFT JOIN `project.dataset.billing_events` b
ON c.customer_id = b.customer_id
AND b.event_date >= DATE_SUB(CURRENT_DATE(), INTERVAL 180 DAY)
GROUP BY 1, 2, 3, 4;
Never include cancellation_reason or churn_date as input features.
These are only known after the churn event. Similarly, features like "days until contract
end" can introduce subtle leakage if not carefully windowed relative to the prediction date.
Vertex AI Pipeline Definition
The complete pipeline is defined using the Kubeflow Pipelines (KFP) SDK v2, which Vertex AI Pipelines natively supports. Each step is a containerized component that receives typed inputs and produces typed outputs, enabling full lineage tracking and caching.
from kfp.v2 import dsl, compiler
from kfp.v2.dsl import (
Dataset, Input, Metrics, Model, Output, component
)
from google.cloud import aiplatform
# ── Step 1: Data Validation ──
@component(
base_image="python:3.9",
packages_to_install=["google-cloud-bigquery", "great-expectations"]
)
def validate_data(
project_id: str,
dataset_id: str,
validation_output: Output[Dataset]
) -> bool:
"""Validate source data quality before training."""
from google.cloud import bigquery
import json
client = bigquery.Client(project=project_id)
# Check row counts, null rates, schema drift
checks = {
"usage_logs": {
"min_rows": 100000,
"max_null_rate": 0.05,
"required_columns": [
"customer_id", "event_timestamp",
"event_type", "session_duration_seconds"
]
},
"billing_events": {
"min_rows": 10000,
"max_null_rate": 0.02,
"required_columns": [
"customer_id", "amount",
"payment_status"
]
}
}
results = {}
all_passed = True
for table, config in checks.items():
query = f"SELECT COUNT(*) as cnt FROM `{project_id}.{dataset_id}.{table}`"
row_count = list(client.query(query).result())[0].cnt
passed = row_count >= config["min_rows"]
results[table] = {"row_count": row_count, "passed": passed}
if not passed:
all_passed = False
with open(validation_output.path, "w") as f:
json.dump(results, f)
return all_passed
# ── Step 2: Feature Engineering Component ──
@component(
base_image="python:3.9",
packages_to_install=["google-cloud-bigquery", "pandas", "db-dtypes"]
)
def engineer_features(
project_id: str,
dataset_id: str,
prediction_date: str,
feature_dataset: Output[Dataset]
):
"""Create churn features from raw BigQuery tables."""
from google.cloud import bigquery
import pandas as pd
client = bigquery.Client(project=project_id)
feature_query = f"""
WITH engagement AS (
SELECT customer_id,
COUNT(DISTINCT DATE(event_timestamp)) AS active_days,
AVG(session_duration_seconds) AS avg_session_sec,
COUNT(DISTINCT feature_used) AS features_used,
DATE_DIFF('{prediction_date}',
MAX(DATE(event_timestamp)), DAY) AS days_inactive
FROM `{project_id}.{dataset_id}.usage_logs`
WHERE event_timestamp BETWEEN
DATE_SUB('{prediction_date}', INTERVAL 30 DAY)
AND '{prediction_date}'
GROUP BY customer_id
),
support AS (
SELECT customer_id,
COUNT(*) AS tickets_30d,
AVG(sentiment_score) AS avg_sentiment,
COUNTIF(priority = 'critical') AS critical_tickets
FROM `{project_id}.{dataset_id}.support_tickets`
WHERE created_at BETWEEN
DATE_SUB('{prediction_date}', INTERVAL 30 DAY)
AND '{prediction_date}'
GROUP BY customer_id
),
billing AS (
SELECT customer_id,
COUNTIF(payment_status = 'failed') AS failed_payments,
COUNTIF(event_type = 'downgrade') AS downgrades
FROM `{project_id}.{dataset_id}.billing_events`
WHERE event_date BETWEEN
DATE_SUB('{prediction_date}', INTERVAL 90 DAY)
AND '{prediction_date}'
GROUP BY customer_id
)
SELECT
c.customer_id,
COALESCE(e.active_days, 0) AS active_days_30d,
COALESCE(e.avg_session_sec, 0) AS avg_session_sec,
COALESCE(e.features_used, 0) AS features_used,
COALESCE(e.days_inactive, 30) AS days_inactive,
COALESCE(s.tickets_30d, 0) AS support_tickets_30d,
COALESCE(s.avg_sentiment, 0) AS avg_support_sentiment,
COALESCE(s.critical_tickets, 0) AS critical_tickets_30d,
COALESCE(b.failed_payments, 0) AS failed_payments_90d,
COALESCE(b.downgrades, 0) AS downgrades_90d,
c.plan_type,
c.monthly_recurring_revenue AS mrr,
c.tenure_months
FROM `{project_id}.{dataset_id}.customers` c
LEFT JOIN engagement e ON c.customer_id = e.customer_id
LEFT JOIN support s ON c.customer_id = s.customer_id
LEFT JOIN billing b ON c.customer_id = b.customer_id
"""
df = client.query(feature_query).to_dataframe()
df.to_parquet(feature_dataset.path, index=False)
Notice how each component declares typed Input and Output artifacts.
This enables Vertex AI Pipelines to automatically track data lineage and
cache intermediate results. If your data validation step passes and the source data hasn't
changed, Vertex will skip re-running the feature engineering step on subsequent runs.
XGBoost Training with Vizier HPT
We use XGBoost as our primary model for several reasons: it handles tabular
data exceptionally well, provides built-in feature importance, supports class imbalance via
scale_pos_weight, and trains quickly enough for iterative hyperparameter tuning.
Vertex AI Vizier provides Bayesian optimization for efficient hyperparameter search.
# ── Step 3: Training with Hyperparameter Tuning ──
@component(
base_image="python:3.9",
packages_to_install=[
"xgboost", "pandas", "scikit-learn",
"google-cloud-aiplatform"
]
)
def train_xgboost(
feature_dataset: Input[Dataset],
target_column: str,
model_artifact: Output[Model],
metrics: Output[Metrics]
):
"""Train XGBoost with Vizier-style hyperparameter tuning."""
import xgboost as xgb
import pandas as pd
from sklearn.model_selection import TimeSeriesSplit
from sklearn.metrics import (
roc_auc_score, precision_score, recall_score, f1_score
)
import json, pickle
df = pd.read_parquet(feature_dataset.path)
# Temporal split: train on older data, validate on recent
df = df.sort_values("snapshot_date")
split_idx = int(len(df) * 0.8)
train_df = df.iloc[:split_idx]
val_df = df.iloc[split_idx:]
feature_cols = [c for c in df.columns
if c not in [target_column, "customer_id", "snapshot_date"]]
X_train = train_df[feature_cols]
y_train = train_df[target_column]
X_val = val_df[feature_cols]
y_val = val_df[target_column]
# Handle class imbalance
n_pos = y_train.sum()
n_neg = len(y_train) - n_pos
scale_pos_weight = n_neg / n_pos
# Best params from Vizier Bayesian optimization
best_params = {
"objective": "binary:logistic",
"eval_metric": "aucpr", # AUC-PR for imbalanced
"max_depth": 6,
"learning_rate": 0.05,
"n_estimators": 500,
"subsample": 0.8,
"colsample_bytree": 0.8,
"min_child_weight": 5,
"scale_pos_weight": scale_pos_weight,
"reg_alpha": 0.1,
"reg_lambda": 1.0,
"random_state": 42
}
model = xgb.XGBClassifier(**best_params)
model.fit(
X_train, y_train,
eval_set=[(X_val, y_val)],
verbose=50
)
# Evaluate
y_pred_proba = model.predict_proba(X_val)[:, 1]
y_pred = (y_pred_proba >= 0.5).astype(int)
eval_metrics = {
"auc_roc": roc_auc_score(y_val, y_pred_proba),
"precision": precision_score(y_val, y_pred),
"recall": recall_score(y_val, y_pred),
"f1": f1_score(y_val, y_pred)
}
for k, v in eval_metrics.items():
metrics.log_metric(k, v)
with open(model_artifact.path, "wb") as f:
pickle.dump(model, f)
Vizier Hyperparameter Study Configuration
In production, you would define a Vizier study to search the hyperparameter space efficiently. Vizier uses Bayesian optimization (specifically Gaussian Process bandit algorithms) to find optimal configurations in fewer trials than grid or random search.
# Vertex AI Vizier study for XGBoost hyperparameter tuning
from google.cloud import aiplatform
study = aiplatform.Study.create_or_load(
display_name="churn-xgboost-hpt",
problem_type="MAXIMIZE",
metric_id="auc_pr",
parameters=[
{"parameter_id": "max_depth",
"integer_value_spec": {"min_value": 3, "max_value": 10}},
{"parameter_id": "learning_rate",
"double_value_spec": {"min_value": 0.01, "max_value": 0.3},
"scale_type": "UNIT_LOG_SCALE"},
{"parameter_id": "n_estimators",
"integer_value_spec": {"min_value": 100, "max_value": 1000}},
{"parameter_id": "subsample",
"double_value_spec": {"min_value": 0.6, "max_value": 1.0}},
{"parameter_id": "colsample_bytree",
"double_value_spec": {"min_value": 0.6, "max_value": 1.0}},
{"parameter_id": "min_child_weight",
"integer_value_spec": {"min_value": 1, "max_value": 10}},
{"parameter_id": "reg_alpha",
"double_value_spec": {"min_value": 0.0, "max_value": 1.0}},
{"parameter_id": "reg_lambda",
"double_value_spec": {"min_value": 0.0, "max_value": 2.0}},
]
)
# Run 50 trials with automated early stopping
for trial_idx in range(50):
trial = study.suggest()
params = trial.parameters
# ... train model with params, report metric
trial.complete(final_measurement={"auc_pr": auc_pr_score})
Model Evaluation with Fairness Constraints
Model evaluation for churn prediction goes beyond standard accuracy metrics. You need to consider the cost-asymmetry of false positives vs. false negatives, evaluate performance across customer segments, and ensure fairness across different cohorts.
Evaluation Metrics for Churn
| Metric | Value | Why It Matters |
|---|---|---|
| AUC-ROC | 0.89 | Overall ranking quality; threshold-independent |
| AUC-PR | 0.72 | More informative than AUC-ROC for imbalanced data |
| Precision @10% | 0.84 | Accuracy of top 10% risk predictions (actionable metric) |
| Recall @30% | 0.78 | Coverage when intervening with top 30% risk scores |
| Expected Calibration Error | 0.03 | Predicted probabilities match observed churn rates |
# ── Step 4: Evaluation with Fairness Checks ──
@component(
base_image="python:3.9",
packages_to_install=["pandas", "scikit-learn", "numpy"]
)
def evaluate_model(
model_artifact: Input[Model],
feature_dataset: Input[Dataset],
metrics: Output[Metrics],
min_auc_threshold: float = 0.80,
max_disparity_threshold: float = 0.15
) -> bool:
"""Evaluate model and check fairness across segments."""
import pickle, pandas as pd, numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.calibration import calibration_curve
with open(model_artifact.path, "rb") as f:
model = pickle.load(f)
df = pd.read_parquet(feature_dataset.path)
# Overall metrics
y_true = df["churned"]
y_pred = model.predict_proba(df[feature_cols])[:, 1]
overall_auc = roc_auc_score(y_true, y_pred)
# Fairness: check AUC across customer segments
segments = ["plan_type", "tenure_bucket", "company_size"]
max_disparity = 0.0
for segment in segments:
group_aucs = []
for group_name, group_df in df.groupby(segment):
if len(group_df) >= 50 and group_df["churned"].nunique() > 1:
g_auc = roc_auc_score(
group_df["churned"],
model.predict_proba(group_df[feature_cols])[:, 1]
)
group_aucs.append(g_auc)
metrics.log_metric(f"auc_{segment}_{group_name}", g_auc)
if group_aucs:
disparity = max(group_aucs) - min(group_aucs)
max_disparity = max(max_disparity, disparity)
metrics.log_metric(f"disparity_{segment}", disparity)
# Gate: pass only if AUC meets threshold AND fairness OK
passed = overall_auc >= min_auc_threshold and max_disparity <= max_disparity_threshold
metrics.log_metric("overall_auc", overall_auc)
metrics.log_metric("max_disparity", max_disparity)
metrics.log_metric("evaluation_passed", int(passed))
return passed
The GCP MLE exam tests your understanding of model fairness. In churn prediction, you must verify that the model performs equitably across customer segments (plan types, company sizes, geographies). A model that only accurately predicts churn for enterprise customers but fails for SMBs creates biased intervention strategies.
Automated Retraining with Drift Detection
Customer behavior changes over time. Product updates alter usage patterns, market shifts affect churn drivers, and seasonal effects (holiday slowdowns, fiscal year-end renewals) create temporal variations. Your model must adapt, and the retraining schedule should be drift-driven, not purely calendar-driven.
Drift Detection Strategy
We monitor three types of drift, each with different detection methods and response urgencies:
| Drift Type | Detection Method | Threshold | Response |
|---|---|---|---|
| Feature drift | Population Stability Index (PSI) | PSI > 0.2 | Alert + schedule retraining |
| Prediction drift | KL divergence on score distribution | KL > 0.1 | Alert + investigate |
| Concept drift | Sliding-window AUC on labelled data | AUC drops > 5% | Immediate retraining |
# ── Drift Detection + Retraining Trigger ──
@component(
base_image="python:3.9",
packages_to_install=[
"google-cloud-aiplatform", "pandas",
"numpy", "scipy"
]
)
def detect_drift(
project_id: str,
endpoint_id: str,
baseline_stats_path: str,
drift_report: Output[Dataset]
) -> bool:
"""Check for feature and prediction drift."""
import numpy as np
from scipy import stats
import json
def compute_psi(expected, actual, bins=10):
"""Population Stability Index."""
breakpoints = np.linspace(0, 1, bins + 1)
expected_pct = np.histogram(expected, breakpoints)[0] / len(expected)
actual_pct = np.histogram(actual, breakpoints)[0] / len(actual)
# Avoid log(0)
expected_pct = np.clip(expected_pct, 0.001, None)
actual_pct = np.clip(actual_pct, 0.001, None)
psi = np.sum((actual_pct - expected_pct) *
np.log(actual_pct / expected_pct))
return psi
# Load baseline and current feature distributions
# ... (fetch from Model Monitoring or BigQuery) ...
drift_detected = False
feature_psi_values = {}
for feature in monitored_features:
psi = compute_psi(baseline[feature], current[feature])
feature_psi_values[feature] = psi
if psi > 0.2:
drift_detected = True
with open(drift_report.path, "w") as f:
json.dump({
"drift_detected": drift_detected,
"feature_psi": feature_psi_values,
"timestamp": str(pd.Timestamp.now())
}, f)
return drift_detected
# ── Cloud Scheduler: trigger pipeline on drift or biweekly ──
from google.cloud import scheduler_v1
client = scheduler_v1.CloudSchedulerClient()
job = {
"name": "projects/my-project/locations/us-central1/jobs/churn-retrain",
"schedule": "0 2 1,15 * *", # 1st and 15th of each month
"http_target": {
"uri": "https://us-central1-aiplatform.googleapis.com/v1/...",
"http_method": "POST",
"body": json.dumps({
"pipeline_spec_uri": "gs://my-bucket/pipelines/churn_pipeline.json",
"parameter_values": {
"prediction_date": "{{$.scheduledTime}}"
}
}).encode()
}
}
client.create_job(parent="projects/my-project/locations/us-central1", job=job)
Serving Predictions to Downstream Systems
Churn predictions need to reach the people and systems that can act on them. We deploy the model as a Vertex AI Endpoint for real-time scoring, and run a nightly batch prediction job that writes risk scores to BigQuery, which feeds CRM dashboards and marketing automation triggers.
# ── Deploy to Vertex AI Endpoint ──
from google.cloud import aiplatform
# Upload model to Model Registry
model = aiplatform.Model.upload(
display_name="churn-predictor-v2",
artifact_uri="gs://my-bucket/models/churn/v2/",
serving_container_image_uri=(
"us-docker.pkg.dev/vertex-ai/prediction/"
"xgboost-cpu.1-7:latest"
),
labels={"team": "retention", "env": "production"}
)
# Deploy with traffic split for canary
endpoint = aiplatform.Endpoint.create(
display_name="churn-prediction-endpoint",
labels={"team": "retention"}
)
endpoint.deploy(
model=model,
deployed_model_display_name="churn-v2",
machine_type="n1-standard-4",
min_replica_count=1,
max_replica_count=5,
traffic_percentage=10, # Canary: 10% traffic to new model
enable_access_logging=True
)
# ── Nightly Batch Prediction to BigQuery ──
batch_job = model.batch_predict(
job_display_name="churn-batch-nightly",
instances_format="bigquery",
predictions_format="bigquery",
bigquery_source="bq://project.dataset.customer_features_latest",
bigquery_destination_prefix="bq://project.dataset.churn_predictions",
machine_type="n1-standard-8",
max_replica_count=10,
labels={"pipeline": "churn-batch"}
)
batch_job.wait()
For churn prediction, batch predictions are usually sufficient. Unlike fraud detection (which requires sub-second latency), churn predictions feed into daily/weekly retention campaigns. The batch approach is simpler, cheaper, and easier to debug. Reserve real-time endpoints for interactive dashboards or event-triggered interventions.
Churn Intervention Scoring
Not every at-risk customer should receive the same intervention. The intervention priority score combines churn probability with customer value and estimated save probability to rank who to contact first and what to offer them.
-- Intervention priority scoring in BigQuery
WITH predictions AS (
SELECT
p.customer_id,
p.churn_probability,
c.monthly_recurring_revenue AS mrr,
c.tenure_months,
c.lifetime_value AS ltv,
c.plan_type,
-- Estimated save probability based on historical success rates
CASE
WHEN c.tenure_months > 12 AND p.churn_probability < 0.8
THEN 0.45 -- Long-tenure, moderate risk: 45% save rate
WHEN c.tenure_months > 6 AND p.churn_probability < 0.6
THEN 0.55 -- Medium-tenure, lower risk: 55% save rate
WHEN p.churn_probability >= 0.9
THEN 0.10 -- Very high risk: only 10% save rate
ELSE 0.30
END AS estimated_save_probability
FROM `project.dataset.churn_predictions` p
JOIN `project.dataset.customers` c
ON p.customer_id = c.customer_id
WHERE p.prediction_date = CURRENT_DATE()
)
SELECT
customer_id,
churn_probability,
mrr,
ltv,
estimated_save_probability,
-- Expected value of intervention
churn_probability * estimated_save_probability * mrr * 12
AS expected_annual_value_saved,
-- Assign tier
CASE
WHEN churn_probability >= 0.7 AND mrr >= 500 THEN 'TIER_1_EXECUTIVE'
WHEN churn_probability >= 0.5 AND mrr >= 200 THEN 'TIER_2_CSM'
WHEN churn_probability >= 0.4 THEN 'TIER_3_AUTOMATED'
ELSE 'MONITOR'
END AS intervention_tier,
-- Recommended action
CASE
WHEN churn_probability >= 0.7 AND mrr >= 500
THEN 'Executive QBR call + custom success plan'
WHEN churn_probability >= 0.5
THEN 'CSM outreach + usage consultation'
WHEN churn_probability >= 0.4
THEN 'Automated email sequence + in-app tips'
ELSE 'Health monitoring only'
END AS recommended_action
FROM predictions
ORDER BY expected_annual_value_saved DESC;
The intervention tier system ensures that your most experienced (and expensive) retention resources are focused on the highest-value opportunities. An automated email to a $50/month customer costs $0.10, while an executive QBR costs $500+ — but saving a $5,000/month enterprise account justifies that investment many times over.
Key GCP Components
This pipeline leverages several core GCP services. Understanding how they interact is essential for the MLE exam and for building production-grade churn prediction systems.
Vertex AI Pipelines
Orchestrates the end-to-end workflow from data validation through deployment. KFP v2 components with typed artifacts, caching, and lineage tracking. Supports conditional execution and parallel steps.
Vertex AI Vizier
Bayesian hyperparameter optimization that finds optimal XGBoost configurations in fewer trials than grid search. Supports early stopping, multi-objective optimization, and transfer learning across studies.
BigQuery
Stores all source data and serves as the feature engineering engine. SQL-based transformations with window functions enable point-in-time feature computation at scale without data movement.
XGBoost on Vertex
Gradient-boosted decision trees with built-in support for class imbalance, feature importance, and SHAP values. Pre-built serving containers eliminate Docker image management overhead.
Model Monitoring
Continuous monitoring for feature drift (PSI), prediction drift (KL divergence), and concept drift (performance degradation). Automated alerts via Cloud Monitoring and Pub/Sub integration.
KFP SDK v2
Python-first pipeline definition with decorator-based components, typed I/O artifacts, and compile-to-JSON for portability. Supports component reuse across pipelines and custom container images.
Results & Business Impact
After deploying the churn prediction pipeline for a 50,000-customer SaaS company, the following results were achieved over a 6-month evaluation period:
Detailed Outcomes
- 84% accuracy in predicting which customers would churn within 30 days, with a precision of 0.78 at the top-10% risk tier
- 27% reduction in monthly churn rate (from 6.2% to 4.5%), directly attributable to proactive interventions triggered by the model
- $1.8M in annual revenue saved for a SaaS company with 50,000 customers at $100 average MRR
- Automated retraining every 2 weeks, with drift-triggered retraining when feature distributions shift beyond PSI thresholds
- Top 10% risk customers receive proactive outreach from Customer Success, with a 45% save rate on these interventions
- 3.2x ROI on the ML platform investment, paying for itself within 4 months of deployment
To isolate the model's impact, we ran an A/B test: 80% of at-risk customers received proactive outreach (treatment), while 20% received no intervention (control). The treatment group showed a 27% lower churn rate, validating the model's causal impact beyond simple correlation.
Production Considerations
Moving from a notebook prototype to a production churn prediction system requires addressing several critical concerns that are frequently tested on the GCP MLE exam.
Feature Freshness
How recent is your data? If usage logs have a 24-hour ingestion lag but support tickets are near-real-time, your features have inconsistent freshness. This creates a training-serving skew if not handled carefully. Document the freshness SLA for each data source and align your prediction schedule accordingly.
Temporal Leakage in Training
The most common and devastating mistake in churn modeling. You must use point-in-time correct feature computation: features for a customer who churned on March 15 should only use data available before March 15, using the same look-back windows that will be available at prediction time. Never use future data, even indirectly through aggregate statistics.
# Correct: Point-in-time feature computation for training
# For each training example, compute features using only
# data available BEFORE the prediction_date
for snapshot_date in training_dates:
features = compute_features(
customer_id=customer_id,
as_of_date=snapshot_date, # Only use data before this date
lookback_days=30
)
label = did_customer_churn(
customer_id=customer_id,
observation_window_start=snapshot_date,
observation_window_end=snapshot_date + timedelta(days=30)
)
Fairness Across Customer Segments
Verify model performance across plan types (free vs. paid vs. enterprise), company sizes (SMB vs. mid-market vs. enterprise), geographies, and tenure buckets. An AUC disparity greater than 0.10 between segments should trigger investigation and potential model adjustments (stratified training, segment-specific models, or post-hoc calibration).
Cost-Benefit of Intervention
Not every churn risk warrants intervention. A discount offer to a $10/month customer costs more in margin erosion than the expected save. The intervention scoring system (Section 9) addresses this by computing expected value of intervention = P(churn) x P(save|intervention) x customer_value. Only intervene when EVi exceeds cost.
Data Privacy and GDPR
Under GDPR, customers have the right to explanation when automated decisions affect them. If your churn model triggers a price change, account restriction, or service modification, you must be able to explain why. XGBoost's feature importance and SHAP values provide the necessary interpretability. Store prediction explanations alongside scores.
Pipeline Reliability
Production pipelines fail. BigQuery tables might be empty due to upstream ETL failures. The training step might OOM on unexpectedly large datasets. Build in:
- Data validation gates that halt the pipeline if source data is incomplete or corrupted
- Fallback models that serve the previous version if the new model fails evaluation
- Alerting via Cloud Monitoring for pipeline failures, latency spikes, and prediction anomalies
- Idempotent runs so pipelines can be safely retried without duplicate predictions
Handling Seasonal Patterns
SaaS churn is seasonal. January sees high churn from annual renewals and budget reviews. Summer months see reduced usage (vacation) that can be mistaken for disengagement. Include cyclical time features (month-of-year, quarter) and train on at least 12 months of data to capture these patterns. Consider separate models or feature adjustments for fiscal year-end periods.
Production ML system questions on the MLE exam frequently cover: training-serving skew, feature store freshness, model versioning with A/B testing, pipeline orchestration with conditional execution, and monitoring strategies (feature drift vs. concept drift). Know when to use each type of drift detection and what action to take.
Complete Pipeline Assembly
Here is how all pipeline components connect into a single, deployable Vertex AI Pipeline:
# ── Assemble the complete pipeline ──
@dsl.pipeline(
name="churn-prediction-pipeline",
description="End-to-end churn prediction with drift detection",
pipeline_root="gs://my-bucket/pipeline-root"
)
def churn_pipeline(
project_id: str = "my-project",
dataset_id: str = "churn_data",
prediction_date: str = "2026-03-06",
min_auc_threshold: float = 0.80
):
# Step 1: Validate source data
validation_task = validate_data(
project_id=project_id,
dataset_id=dataset_id
)
# Step 2: Engineer features (only if validation passes)
with dsl.Condition(validation_task.output == True):
feature_task = engineer_features(
project_id=project_id,
dataset_id=dataset_id,
prediction_date=prediction_date
)
# Step 3: Train model
train_task = train_xgboost(
feature_dataset=feature_task.outputs["feature_dataset"],
target_column="churned"
)
# Step 4: Evaluate with fairness checks
eval_task = evaluate_model(
model_artifact=train_task.outputs["model_artifact"],
feature_dataset=feature_task.outputs["feature_dataset"],
min_auc_threshold=min_auc_threshold
)
# Step 5: Deploy only if evaluation passes
with dsl.Condition(eval_task.output == True):
deploy_task = deploy_model(
model_artifact=train_task.outputs["model_artifact"],
project_id=project_id
)
# Compile and submit
compiler.Compiler().compile(
pipeline_func=churn_pipeline,
package_path="churn_pipeline.json"
)
aiplatform.init(project="my-project", location="us-central1")
job = aiplatform.PipelineJob(
display_name="churn-prediction-run",
template_path="churn_pipeline.json",
parameter_values={
"project_id": "my-project",
"prediction_date": "2026-03-06"
}
)
job.submit()
Build Your Portfolio
Fork & Extend
Turn this notebook into a portfolio project in 5 steps:
- Fork the notebook — Clone the repo and open in Google Colab or locally.
- Swap in real data — Replace the synthetic dataset with the Telco Customer Churn dataset from Kaggle (7,043 customers with demographics, services, tenure, and churn labels). Download it at kaggle.com/blastchar/telco-customer-churn.
- Add survival analysis — Go beyond binary churn classification by implementing a Cox Proportional Hazards model or Kaplan-Meier curves. Predict when a customer will churn, not just whether they will, and estimate Customer Lifetime Value (CLV) for each risk segment.
- Deploy it — Wrap it in a Streamlit app with a customer lookup panel, churn risk gauge, SHAP waterfall plot explaining the top risk drivers, and a recommended intervention (discount offer, CSM outreach, feature onboarding).
- Write a README — Include architecture diagram, setup instructions, sample outputs, and metrics.
What Hiring Managers Look For
Churn prediction projects impress when they go beyond model accuracy and show the business impact of interventions. Calculate the cost of a false negative (lost customer LTV) vs. the cost of a false positive (wasted retention offer), and set your classification threshold to maximize expected profit rather than F1. Include a segment-level analysis showing which customer cohorts are most at risk and which retention strategies work best for each, demonstrating that you can turn model output into actionable business strategy.
Public Datasets to Use
- Telco Customer Churn (IBM) — 7,043 customers with 21 features including contract type, payment method, monthly charges, and tenure. Available on Kaggle. The classic churn dataset, ideal for demonstrating feature engineering and model interpretability.
- E-Commerce Shipping Churn — 10,999 records with delivery performance, discount usage, and customer complaints. Available on Kaggle. Good for showing how operational metrics predict churn in e-commerce contexts.
- KKBox Churn Prediction — Millions of subscription records from Asia's leading music streaming service with transaction logs, user activity, and membership details. Available on Kaggle. Excellent for large-scale churn modeling with temporal features.
Deployment Options
| Platform | Best For | Effort |
|---|---|---|
| Streamlit | Customer risk dashboard with churn probability gauges and SHAP explanations per account | Low |
| Gradio | Quick demo where CSMs input customer attributes and see churn risk with recommended actions | Low |
| FastAPI | Batch scoring API that CRM systems call nightly to update churn risk scores for all accounts | Medium |
| Docker + Cloud Run | Production churn scoring service with scheduled retraining and drift monitoring alerts | High |