NestedLogit.approximate_fit
- NestedLogit.approximate_fit(X, y=None, progressbar=None, random_seed=None, *, fit_kwargs=None, sample_kwargs=None)
Fit a model using Variational Inference and return InferenceData.
This performs variational inference via pymc.fit, then draws posterior samples from the fitted approximation via Approximation.sample, returning an arviz.InferenceData compatible with the rest of the API (same structure as .fit).
- Parameters:
- X : array_like | array, shape (n_obs, n_features)
The training input samples. If scikit-learn is available, array-like, otherwise array.
- y : array_like | array, shape (n_obs,)
The target values (real numbers). If scikit-learn is available, array-like, otherwise array.
- progressbar : bool, optional
Specifies whether the fitting/sample progress bar should be displayed. Defaults to True.
- random_seed : Optional[RandomState]
Provides stochastic procedures with an initial random seed for reproducibility.
- fit_kwargs : dict, optional
Extra keyword arguments forwarded to pymc.fit (e.g., {"n": 10_000, "method": "advi"}).
- sample_kwargs : dict, optional
Extra keyword arguments forwarded to Approximation.sample (e.g., {"draws": 1_000}).
- Returns:
az.InferenceData
Inference data of the variationally fitted model.
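
The call pattern described above can be sketched as a small wrapper. This is a minimal illustration, not part of the library: `run_approximate_fit` and its arguments `n_iter`, `draws`, and `seed` are hypothetical names, and `model` is assumed to be an already-constructed object exposing `approximate_fit` with the signature documented here (for example, a `NestedLogit` instance). The dictionary keys mirror the examples given for `fit_kwargs` and `sample_kwargs`.

```python
from typing import Any, Optional


def run_approximate_fit(
    model: Any,
    X: Any,
    y: Optional[Any] = None,
    *,
    n_iter: int = 10_000,
    draws: int = 1_000,
    seed: Optional[int] = 42,
) -> Any:
    """Hypothetical helper: call `model.approximate_fit` with explicit VI settings."""
    # Forwarded to pymc.fit: number of optimization iterations and the VI method.
    fit_kwargs = {"n": n_iter, "method": "advi"}
    # Forwarded to Approximation.sample: number of posterior draws to generate.
    sample_kwargs = {"draws": draws}
    # fit_kwargs and sample_kwargs are keyword-only in the documented signature.
    return model.approximate_fit(
        X,
        y,
        progressbar=False,
        random_seed=seed,
        fit_kwargs=fit_kwargs,
        sample_kwargs=sample_kwargs,
    )


# Usage (assumes `nested_logit` is a configured model and X, y are prepared):
# idata = run_approximate_fit(nested_logit, X, y, n_iter=20_000, draws=2_000)
```

Keeping the two dictionaries separate reflects the two stages of the method: one set of options controls the optimization, the other controls sampling from the fitted approximation.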