Prediction Models

Overview

The prediction models form the core of FunnelStory’s intelligence layer, enabling us to forecast how likely a customer is to churn or to be retained. The system is built on ensemble learning: each model combines the outputs of several machine learning algorithms, and the churn and retention predictions are then merged into a single composite score.

The entire lifecycle—from data querying and preprocessing to model training, prediction, and result storage—is managed within the prediction package.

Core Concepts and Architecture

1. Ensemble Modeling

Instead of relying on a single algorithm, we use an ensemble of models to improve predictive accuracy and stability. Our primary ensemble consists of:

  • learn.LogisticRegression: A linear model that provides an interpretable baseline and captures linear relationships between the features and the outcome.
  • learn.RandomForest: A tree-based model that excels at capturing complex, non-linear interactions between features.

The final prediction score is a weighted average of the outputs from each model in the ensemble. The weights (e.g., 70% for Logistic Regression, 30% for Random Forest) are configurable via EnsembleWeights, allowing us to tune the model’s behavior based on performance.

2. Model Training (Model.Train)

The training process is orchestrated by the Model.Train method and the API handler PostModelsTrain.

  1. Data Acquisition: The process begins by querying the necessary data for either accounts (QueryAccountPredictionData) or users (QueryUserPredictionData). These functions gather the raw data points that are later transformed into features.
  2. Feature Engineering & Preprocessing: The raw data is converted into a set of learn.Points. Each Point represents an entity (an account or user) and contains:
    • A set of numeric Features (e.g., activity:30_days:App Page Viewed).
    • A set of categorical Traits (e.g., trait:country=USA).
    • A Value of 1.0 if it’s a target entity (e.g., a known churned account) or 0.0 otherwise.
  3. Concurrent Training: Because the Logistic Regression and Random Forest models are independent of each other, Model.Train trains them in parallel goroutines, which reduces total training time.
  4. Model Storage: Once trained, the resulting model—including its learned weights and parameters—is serialized to JSON and stored in the database.

3. Prediction Generation (Model.Predict & tick)

Predictions are generated periodically by a background job managed by prediction.Accessor.Tick.

  1. Data Fetching: The tick function queries fresh data for all relevant entities to create up-to-date learn.Points.
  2. Ensemble Prediction (Model.predict): For each point, the predict method calculates the weighted average score from the trained Logistic Regression and Random Forest models.
  3. Composite Scoring: Two separate models are used: one for predicting churn and one for predicting retention. The CompositeScore function calculates the final score as retentionScore - churnScore. This gives a single, intuitive metric where positive values indicate health and negative values indicate risk.
  4. Conversational AI Analysis:
    • If recent conversational data (meetings, support tickets, notes) exists for an account, a separate ConversationalModel analyzes this unstructured data to produce its own prediction (conversationPrediction).
    • The conversational prediction is then used to adjust the initial composite score, so that qualitative signals from conversations complement the quantitative product-usage data.
  5. Factor Analysis & Recommendations:
    • To provide “explainable AI,” the system analyzes the contribution of each feature to the final score using a “what-if” approach. It calculates how the prediction would change if a feature’s value were different (e.g., set to zero or its median).
    • These PredictionFactor results are sorted by impact and returned with the prediction.
    • The system also generates actionable Recommendations (e.g., “increase” a certain activity) by testing how changes to a feature’s value would impact the score.

4. Data Structures

  • prediction.Model: The main struct representing a trained prediction model. It contains the trained instances of learn.LogisticRegression and learn.RandomForest, along with configuration parameters like EnsembleWeights, target IDs, and factor configurations.
  • prediction.Prediction: The output of a prediction run for a single entity. It includes the final Value (score), the contributing Factors, actionable Recommendations, and the ConversationsPrediction if applicable.
  • learn.Point: The fundamental data structure for our ML models. It represents a single training or prediction instance with its features and traits.
  • api.PostModelsTrainPayload: The JSON payload used by the API to configure and trigger a model training run. It allows for the specification of target entities, feature inclusion/exclusion, and model hyperparameters.

Data Confidence (Prediction.EvaluateDataConfidence)

To help users judge how far to trust a prediction, we calculate a DataConfidence score for each one. This score evaluates the quality and completeness of the underlying data based on:

  • Recency of Activity: Whether the account has had recent product usage data.
  • Presence of Conversations: Whether there is recent conversational data to analyze.
  • Activity Coverage: Whether the account is performing activities that a majority of other accounts typically perform.

The confidence level (High, Medium, Low) and the reasons for it are surfaced to the user, providing transparency into the prediction’s reliability.

API Endpoints

  • GET /api/internal/prediction/models: Retrieves all configured prediction models.
  • GET /api/internal/prediction/models/{model_type}: Fetches a specific model (e.g., churn, retention).
  • POST /api/internal/prediction/models/{model_type}/train: The primary endpoint for initiating a model training process. It accepts a PostModelsTrainPayload to define the training set and parameters.
  • POST /api/internal/prediction/models/{model_type}/data: A utility endpoint to fetch the raw training data (learn.Points) for a given configuration without actually training a model.