
Module function reference

This is a concise reference for exported workflow functions in timeshap_app.py.

Data preparation

generate_demo_data(n_entities=180, time_steps=70, n_features=6, seed=SEED) -> pd.DataFrame

  • Generates synthetic long-format sequence data with lag structure.
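The following is a minimal sketch of such a generator. It is not the app's implementation. The column names (entity_id, time, f0, f1, ..., target) are hypothetical. The lag rule, in which the label at step t depends on f0 at t-1 and on f1 at t-2, is also an assumption, so the real function's columns and lag structure may differ. The sketch needs n_features >= 2.

```python
import numpy as np
import pandas as pd


def generate_demo_data_sketch(n_entities=180, time_steps=70, n_features=6, seed=0):
    """Hypothetical long-format generator: one row per (entity, time step)."""
    rng = np.random.default_rng(seed)
    frames = []
    for entity in range(n_entities):
        # Random-walk features, so consecutive time steps are correlated.
        feats = np.cumsum(rng.normal(0.0, 0.5, size=(time_steps, n_features)), axis=0)
        # Assumed lag rule: label at step t depends on f0 at t-1 and f1 at t-2.
        signal = np.zeros(time_steps)
        signal[1:] += feats[:-1, 0]
        signal[2:] += 0.5 * feats[:-2, 1]
        target = (signal + rng.normal(0.0, 0.5, size=time_steps) > 0).astype(int)
        frame = pd.DataFrame(feats, columns=[f"f{i}" for i in range(n_features)])
        frame.insert(0, "time", np.arange(time_steps))
        frame.insert(0, "entity_id", entity)  # scalar is broadcast to every row
        frame["target"] = target
        frames.append(frame)
    return pd.concat(frames, ignore_index=True)
```

Passing the same seed reproduces the same DataFrame, which keeps demo runs comparable.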

detect_feature_cols(df, entity_col, time_col, target_col) -> list[str]

  • Returns every column except the selected entity, time, and target columns.
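The exclusion rule fits in a few lines. This sketch assumes that the DataFrame's original column order should be preserved, which the reference does not state.

```python
import pandas as pd


def detect_feature_cols_sketch(df: pd.DataFrame, entity_col: str, time_col: str, target_col: str) -> list[str]:
    excluded = {entity_col, time_col, target_col}
    # Keep every remaining column in the DataFrame's own order (assumption).
    return [c for c in df.columns if c not in excluded]
```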

build_sequences(df, feature_cols, entity_col, time_col, target_col, seq_len) -> (x, y, windows_raw)

  • Groups rows by entity, sorts each group by time, and builds sliding windows of seq_len steps together with their labels.
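Here is a sketch of the windowing step. It rests on three assumptions. Each window's label is taken to be the target at its last time step. windows_raw is taken to be a per-window metadata table with the entity, start_time, and end_time columns. The _sketch suffix marks the function as hypothetical.

```python
import numpy as np
import pandas as pd


def build_sequences_sketch(df, feature_cols, entity_col, time_col, target_col, seq_len):
    xs, ys, meta = [], [], []
    for entity, group in df.groupby(entity_col, sort=False):
        group = group.sort_values(time_col)  # windows must follow time order
        feats = group[feature_cols].to_numpy(dtype=np.float32)
        labels = group[target_col].to_numpy()
        times = group[time_col].to_numpy()
        # One window per valid end position; label = target at the last step (assumption).
        for end in range(seq_len, len(group) + 1):
            xs.append(feats[end - seq_len:end])
            ys.append(labels[end - 1])
            meta.append({entity_col: entity,
                         "start_time": times[end - seq_len],
                         "end_time": times[end - 1]})
    x = np.stack(xs) if xs else np.empty((0, seq_len, len(feature_cols)), dtype=np.float32)
    y = np.asarray(ys, dtype=np.float32)
    return x, y, pd.DataFrame(meta)
```

An entity with fewer than seq_len rows contributes no windows.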

fit_scaler(train_windows) -> (mean, std)

  • Computes the per-feature mean and standard deviation used for normalization.
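This sketch assumes that the statistics are pooled over all windows and time steps. It also replaces near-zero standard deviations with 1.0, which is an assumed guard. apply_scaler is a hypothetical helper.

```python
import numpy as np


def fit_scaler_sketch(train_windows: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    # (n_windows, seq_len, n_features) -> (n_rows, n_features), pooling all time steps.
    flat = train_windows.reshape(-1, train_windows.shape[-1])
    mean = flat.mean(axis=0)
    std = flat.std(axis=0)
    # Constant features would divide by zero; use 1.0 so they scale to 0 (assumption).
    std = np.where(std < 1e-8, 1.0, std)
    return mean, std


def apply_scaler(windows: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    # Broadcasting applies the per-feature stats along the last axis.
    return (windows - mean) / std
```

Fitting only on training windows keeps validation statistics out of the scaler.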

train_val_split_by_entity(windows_raw, val_frac, entity_col, seed=SEED) -> (train_idx, val_idx)

  • Splits windows by entity, so all windows of a given entity fall in either the training set or the validation set; returns the two index arrays.
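The next sketch shows an entity-level split. It assumes that windows_raw is a DataFrame with one row per window. It also assumes that val_frac applies to the number of entities rather than to the number of windows.

```python
import numpy as np
import pandas as pd


def train_val_split_by_entity_sketch(windows_raw: pd.DataFrame, val_frac: float, entity_col: str, seed: int = 0):
    rng = np.random.default_rng(seed)
    entities = windows_raw[entity_col].unique()
    rng.shuffle(entities)
    # Fraction is applied to entities, with at least one validation entity (assumption).
    n_val = max(1, int(round(len(entities) * val_frac)))
    is_val = windows_raw[entity_col].isin(entities[:n_val]).to_numpy()
    return np.flatnonzero(~is_val), np.flatnonzero(is_val)
```

Splitting by entity rather than by window matters because sliding windows from the same entity overlap. A window-level split would leak nearly identical samples into validation.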

Model training and inference

train_model(x_train, y_train, x_val, y_val, input_dim, hidden_dim, epochs, lr, batch_size) -> (model, history)

  • Trains the built-in GRUClassifier with a BCE-with-logits loss.
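BCE-with-logits fuses the sigmoid with binary cross-entropy, so the loss is computed from raw logits z in the stable form max(z, 0) - z·y + log(1 + exp(-|z|)). PyTorch exposes this loss as torch.nn.BCEWithLogitsLoss. The NumPy sketch below illustrates only the loss, not train_model itself, and compares the stable form with the naive one.

```python
import numpy as np


def bce_with_logits(logits, targets) -> float:
    z = np.asarray(logits, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    # Same value as -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))], rearranged so exp never overflows.
    per_item = np.maximum(z, 0.0) - z * y + np.log1p(np.exp(-np.abs(z)))
    return float(per_item.mean())


def bce_naive(logits, targets) -> float:
    # Sigmoid first, then log: breaks down when the probability rounds to 0 or 1.
    p = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float64)))
    y = np.asarray(targets, dtype=np.float64)
    return float(-(y * np.log(p) + (1.0 - y) * np.log(1.0 - p)).mean())
```

With a logit of 1000 and label 0, the stable form returns 1000.0, while the naive form evaluates log(0) and returns inf.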

predict_prob_and_hidden(model, x_np, hs_np=None) -> (prob, seq_hidden, hidden_state)

  • Runs inference and returns the predicted probability, the sequence hidden outputs (seq_hidden), and the hidden state (hidden_state).
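To make the three return values concrete, here is a NumPy sketch of a single-layer GRU forward pass. It uses PyTorch's gate equations (reset r, update z, candidate n) with random hypothetical weights. It assumes three things: the probability comes from a linear head on the final hidden state, hs is an initial hidden state of shape (batch, hidden_dim), and the function and parameter names are illustrative. The real GRUClassifier may differ in layers, shapes, and head.

```python
import numpy as np


def _sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def make_demo_params(input_dim, hidden_dim, seed=0):
    """Hypothetical weights; gates are stacked in (r, z, n) order as in torch.nn.GRU."""
    rng = np.random.default_rng(seed)
    scale = 1.0 / np.sqrt(hidden_dim)
    shapes = {"W_i": (3 * hidden_dim, input_dim), "W_h": (3 * hidden_dim, hidden_dim),
              "b_i": (3 * hidden_dim,), "b_h": (3 * hidden_dim,),
              "w_out": (hidden_dim,), "b_out": (1,)}
    return {k: rng.uniform(-scale, scale, size=s) for k, s in shapes.items()}


def predict_prob_and_hidden_sketch(params, x, hs=None):
    batch, seq_len, _ = x.shape
    hidden_dim = params["W_h"].shape[1]
    # Start from zeros unless a hidden state is supplied.
    h = np.zeros((batch, hidden_dim)) if hs is None else np.asarray(hs, dtype=np.float64)
    outputs = []
    for t in range(seq_len):
        i_r, i_z, i_n = np.split(x[:, t, :] @ params["W_i"].T + params["b_i"], 3, axis=1)
        h_r, h_z, h_n = np.split(h @ params["W_h"].T + params["b_h"], 3, axis=1)
        r = _sigmoid(i_r + h_r)          # reset gate
        z = _sigmoid(i_z + h_z)          # update gate
        n = np.tanh(i_n + r * h_n)       # candidate state
        h = (1.0 - z) * n + z * h        # new hidden state
        outputs.append(h)
    seq_hidden = np.stack(outputs, axis=1)              # (batch, seq_len, hidden_dim)
    prob = _sigmoid(h @ params["w_out"] + params["b_out"])  # (batch,)
    return prob, seq_hidden, h
```

The final hidden state can be fed back in as hs. Running a window in two halves therefore gives the same result as a single pass, which is what lets a caller resume from a stored hidden state.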

load_sequence_model_from_checkpoint(checkpoint_file, input_dim, seq_len, hidden_dim_fallback, lightning_class_path="", model_attr_name="model") -> (model, notes)

  • Loads the model through a Lightning class path, falling back to a direct checkpoint or state-dict load.
  • Validates the loaded model's output contract before returning it.
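The two loading paths can be illustrated without PyTorch. A dotted class path is resolved with importlib. A Lightning checkpoint stores weights under "state_dict", with keys prefixed by the wrapped submodule's attribute name (for example "model.gru..."), and it is unwrapped by stripping f"{model_attr_name}.". Both helpers are hypothetical. The real function also instantiates the model and validates its outputs, which this sketch omits.

```python
import importlib


def resolve_class_path(class_path: str):
    """Turn 'package.module.ClassName' into the object it names."""
    module_name, _, attr = class_path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.ClassName', got {class_path!r}")
    return getattr(importlib.import_module(module_name), attr)


def extract_inner_state_dict(checkpoint: dict, model_attr_name: str = "model") -> dict:
    """Return the wrapped model's weights from a Lightning-style or plain checkpoint."""
    # Lightning checkpoints keep weights under "state_dict"; a plain save may be the dict itself.
    state = checkpoint.get("state_dict", checkpoint)
    prefix = f"{model_attr_name}."
    inner = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
    # If no key carries the prefix, assume the dict already targets the bare model.
    return inner if inner else dict(state)
```

Keys outside the prefix, such as loss-module buffers, are dropped. This matches loading only the inner model.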

TimeSHAP integration

resolve_timeshap_functions() -> (local_pruning, local_event, local_feat)

  • Resolves the TimeSHAP local APIs across the supported module paths.
  • Includes a SHAP compatibility shim for the Kernel vs. KernelExplainer naming difference.
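The next sketch shows a fallback pattern for this kind of resolution. It assumes that the resolver tries candidate module paths in order and that the shim simply aliases one attribute name to the other. Both helper names are hypothetical, and the candidate list is an assumption. A real call might pass TimeSHAP paths such as timeshap.explainer together with ["local_pruning", "local_event", "local_feat"].

```python
import importlib
from types import ModuleType


def resolve_functions(candidate_modules, names):
    """Return the named callables from the first candidate module that provides all of them."""
    errors = []
    for module_name in candidate_modules:
        try:
            module = importlib.import_module(module_name)
        except ImportError as exc:  # includes ModuleNotFoundError
            errors.append(f"{module_name}: {exc}")
            continue
        if all(hasattr(module, n) for n in names):
            return tuple(getattr(module, n) for n in names)
        errors.append(f"{module_name}: missing one of {list(names)}")
    raise ImportError("could not resolve " + ", ".join(names) + "; tried " + "; ".join(errors))


def ensure_alias(module: ModuleType, missing: str, existing: str) -> None:
    """Expose module.<existing> under the name <missing> when only <existing> is defined."""
    if not hasattr(module, missing) and hasattr(module, existing):
        setattr(module, missing, getattr(module, existing))
```

Collecting every failure in the error message makes version mismatches easier to diagnose than reporting only the last ImportError.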

run_local_timeshap(model, sequence_scaled, feature_cols, entity_col, time_col, tol, nsamples) -> (pruning_out, event_out, feat_out)

  • Runs the local TimeSHAP flow (pruning, event-level, and feature-level explanations) for a single sequence window.

UI and rendering helpers

render_history(history) -> None

  • Renders training and validation plots with Plotly in Streamlit.

app() -> None

  • Main Streamlit application entrypoint.