Inference Engine

Module for the inference engines.

BaseInferenceEngine

Bases: BaseObject

Class representing an inference engine for a given model.

Parameters:

model : Callable
    The model to be used for inference. Required.
rng_key : Optional[PRNGKey]
    The random number generator key. If not provided, a default key with value 0 will be used. Defaults to None.

Attributes:

model : Callable
    The model used for inference.
rng_key : PRNGKey
    The random number generator key.

Source code in src/prophetverse/engine/base.py
class BaseInferenceEngine(BaseObject):
    """
    Class representing an inference engine for a given model.

    Parameters
    ----------
    model : Callable
        The model to be used for inference.
    rng_key : Optional[jax.random.PRNGKey]
        The random number generator key. If not provided, a default key with value 0
        will be used.

    Attributes
    ----------
    model : Callable
        The model used for inference.
    rng_key : jax.random.PRNGKey
        The random number generator key.
    """

    _tags = {
        "object_type": "inference_engine",
    }

    def __init__(self, rng_key=None):
        self.rng_key = rng_key

        if rng_key is None:
            rng_key = jax.random.PRNGKey(0)
        self._rng_key = rng_key

    # pragma: no cover
    def infer(self, model, **kwargs):
        """
        Perform inference using the specified model.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments to be passed to the model.

        Returns
        -------
        The result of the inference.
        """
        self.model_ = model
        self._infer(**kwargs)

    # pragma: no cover
    def _infer(self, **kwargs):  # pragma: no cover
        """
        Perform inference using the specified model.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments to be passed to the model.

        Returns
        -------
        The result of the inference.
        """
        raise NotImplementedError("infer method must be implemented in subclass")

    # pragma: no cover
    def predict(self, **kwargs):
        """
        Generate predictions using the specified model.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments to be passed to the model.

        Returns
        -------
        The predictions generated by the model.
        """
        return self._predict(**kwargs)

    # pragma: no cover
    def _predict(self, **kwargs):  # pragma: no cover
        """
        Generate predictions using the specified model.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments to be passed to the model.

        Returns
        -------
        The predictions generated by the model.
        """
        raise NotImplementedError("predict method must be implemented in subclass")

infer(model, **kwargs)

Perform inference using the specified model.

Parameters:

model : Callable
    The model to perform inference on.
**kwargs
    Additional keyword arguments to be passed to the model.

Returns:

The result of the inference.
Source code in src/prophetverse/engine/base.py
def infer(self, model, **kwargs):
    """
    Perform inference using the specified model.

    Parameters
    ----------
    **kwargs
        Additional keyword arguments to be passed to the model.

    Returns
    -------
    The result of the inference.
    """
    self.model_ = model
    self._infer(**kwargs)

predict(**kwargs)

Generate predictions using the specified model.

Parameters:

**kwargs
    Additional keyword arguments to be passed to the model.

Returns:

The predictions generated by the model.
Source code in src/prophetverse/engine/base.py
def predict(self, **kwargs):
    """
    Generate predictions using the specified model.

    Parameters
    ----------
    **kwargs
        Additional keyword arguments to be passed to the model.

    Returns
    -------
    The predictions generated by the model.
    """
    return self._predict(**kwargs)
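
Putting the two public methods together, a typical call sequence looks like the following; toy_model and PriorPredictiveEngine (sketched after the base class source above) are illustrative assumptions, not part of the library:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def toy_model(y=None):
    """A toy NumPyro model used only for illustration."""
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=y)


engine = PriorPredictiveEngine()                         # illustrative subclass from the sketch above
engine.infer(toy_model, y=jnp.array([0.1, 0.3, -0.2]))   # stores the model as model_ and calls _infer
samples = engine.predict(y=None)                         # delegates to _predict; a dict keyed by site name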

MAPInferenceEngine

Bases: BaseInferenceEngine

Maximum a Posteriori (MAP) Inference Engine.

This class performs MAP inference using Stochastic Variational Inference (SVI) with an AutoDelta guide. It provides methods for inference and prediction.

Parameters:

model : Callable
    The probabilistic model to perform inference on.
optimizer_factory : numpyro.optim._NumPyroOptim, optional
    Deprecated in favor of `optimizer`. The optimizer to use for SVI. Defaults to None.
optimizer : Optional[BaseOptimizer]
    The optimizer or solver to use for SVI. If neither `optimizer` nor `optimizer_factory` is provided, an LBFGSSolver is used. Defaults to None.
num_steps : int, optional
    The number of optimization steps to perform. Defaults to 10000.
num_samples : int, optional
    The number of predictive samples drawn by `predict`. Defaults to _DEFAULT_PREDICT_NUM_SAMPLES.
rng_key : jax.random.PRNGKey, optional
    The random number generator key. Defaults to None.
progress_bar : bool, optional
    Whether to display a progress bar during SVI. Defaults to DEFAULT_PROGRESS_BAR.
stable_update : bool, optional
    Whether to use NumPyro's stable SVI update, which only applies updates that yield finite values. Defaults to False.
forward_mode_differentiation : bool, optional
    Whether to use forward-mode differentiation in SVI. Defaults to False.
init_loc_fn : Callable, optional
    Initialization strategy for the AutoDelta guide. Defaults to init_to_mean().
Source code in src/prophetverse/engine/map.py
class MAPInferenceEngine(BaseInferenceEngine):
    """
    Maximum a Posteriori (MAP) Inference Engine.

    This class performs MAP inference using Stochastic Variational Inference (SVI)
    with AutoDelta guide. It provides methods for inference and prediction.

    Parameters
    ----------
    model : Callable
        The probabilistic model to perform inference on.
    optimizer_factory : numpyro.optim._NumPyroOptim, optional
        The optimizer to use for SVI. Defaults to None.
    num_steps : int, optional
        The number of optimization steps to perform. Defaults to 10000.
    rng_key : jax.random.PRNGKey, optional
        The random number generator key. Defaults to None.
    """

    _tags = {
        "inference_method": "map",
    }

    def __init__(
        self,
        optimizer_factory: numpyro.optim._NumPyroOptim = None,
        optimizer: Optional[BaseOptimizer] = None,
        num_steps=10_000,
        num_samples=_DEFAULT_PREDICT_NUM_SAMPLES,
        rng_key=None,
        progress_bar: bool = DEFAULT_PROGRESS_BAR,
        stable_update=False,
        forward_mode_differentiation=False,
        init_loc_fn=None,
    ):

        self.optimizer_factory = optimizer_factory
        self.optimizer = optimizer
        self.num_steps = num_steps
        self.num_samples = num_samples
        self.progress_bar = progress_bar
        self.stable_update = stable_update
        self.forward_mode_differentiation = forward_mode_differentiation
        self.init_loc_fn = init_loc_fn
        super().__init__(rng_key)

        deprecation_warning(
            "optimizer_factory",
            "0.5.0",
            "Please use the `optimizer` parameter instead.",
        )

        if optimizer_factory is None and optimizer is None:
            optimizer = LBFGSSolver()

        if self.optimizer is None and optimizer_factory is not None:
            optimizer = _OptimizerFromCallable(optimizer_factory)
        self._optimizer = optimizer

        self._init_loc_fn = init_loc_fn
        if init_loc_fn is None:
            self._init_loc_fn = init_to_mean()

        self._num_steps = num_steps

        if self._optimizer.get_tag("is_solver", False):  # type: ignore[union-attr]
            self._optimizer = self._optimizer.set_max_iter(  # type: ignore[union-attr]
                self._num_steps
            )
            self._num_steps = 1

    def _infer(self, **kwargs):
        """
        Perform MAP inference.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments to be passed to the model.

        Returns
        -------
        self
            The updated MAPInferenceEngine object.
        """
        self.guide_ = AutoDelta(self.model_, init_loc_fn=self._init_loc_fn)

        def get_result(
            rng_key,
            model,
            guide,
            optimizer,
            num_steps,
            progress_bar,
            stable_update,
            forward_mode_differentiation,
            **kwargs,
        ) -> SVIRunResult:
            svi_ = SVI(
                model,
                guide,
                optimizer,
                loss=Trace_ELBO(),
            )
            return svi_.run(
                rng_key=rng_key,
                progress_bar=progress_bar,
                stable_update=stable_update,
                num_steps=num_steps,
                forward_mode_differentiation=forward_mode_differentiation,
                **kwargs,
            )

        self.run_results_: SVIRunResult = get_result(
            self._rng_key,
            self.model_,
            self.guide_,
            self._optimizer.create_optimizer(),
            self._num_steps,
            stable_update=self.stable_update,
            progress_bar=self.progress_bar,
            forward_mode_differentiation=self.forward_mode_differentiation,
            **kwargs,
        )

        self.raise_error_if_nan_loss(self.run_results_)

        self.posterior_samples_ = self.guide_.sample_posterior(
            self._rng_key, params=self.run_results_.params, **kwargs
        )
        return self

    def raise_error_if_nan_loss(self, run_results: SVIRunResult):
        """
        Raise an error if the loss is NaN.

        Parameters
        ----------
        run_results : SVIRunResult
            The result of the SVI run.

        Raises
        ------
        MAPInferenceEngineError
            If the last loss is NaN.
        """
        losses = run_results.losses
        if jnp.isnan(losses)[-1]:
            msg = "NaN losses in MAPInferenceEngine."
            msg += " Try decreasing the learning rate or changing the model specs."
            msg += " If the problem persists, please open an issue at"
            msg += " https://github.com/felipeangelimvieira/prophetverse"
            raise MAPInferenceEngineError(msg)

    def _predict(self, **kwargs):
        """
        Generate predictions using the trained model.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments to be passed to the model.

        Returns
        -------
        dict
            The predicted samples generated by the model.
        """
        predictive = numpyro.infer.Predictive(
            self.model_,
            params=self.run_results_.params,
            guide=self.guide_,
            # posterior_samples=self.posterior_samples_,
            num_samples=self.num_samples,
        )
        self.samples_ = predictive(rng_key=self._rng_key, **kwargs)
        return self.samples_

    @classmethod
    def get_test_params(*args, **kwargs):
        """Return test params for unit testing."""
        return [
            {
                "optimizer": LBFGSSolver(),
                "num_steps": 100,
            },
            {
                "optimizer": AdamOptimizer(),
                "num_steps": 100,
            },
        ]
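
For orientation, here is a minimal usage sketch with the default solver; the toy model and data are illustrative, and the prophetverse.engine.map import path is inferred from the source location above:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

from prophetverse.engine.map import MAPInferenceEngine  # path inferred from the source location above


def toy_model(y=None):
    """A toy NumPyro model used only for illustration."""
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(mu, 0.5), obs=y)


engine = MAPInferenceEngine(num_steps=1_000, progress_bar=False)
engine.infer(toy_model, y=jnp.array([0.8, 1.1, 0.9, 1.0]))  # runs SVI with an AutoDelta guide
samples = engine.predict(y=None)                            # num_samples predictive draws per site

Note that when the configured optimizer has the is_solver tag (as the default LBFGSSolver does), num_steps is forwarded to the solver's maximum number of iterations and a single SVI step is run, as shown in the __init__ listing above.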

get_test_params(*args, **kwargs) classmethod

Return test params for unit testing.

Source code in src/prophetverse/engine/map.py
@classmethod
def get_test_params(*args, **kwargs):
    """Return test params for unit testing."""
    return [
        {
            "optimizer": LBFGSSolver(),
            "num_steps": 100,
        },
        {
            "optimizer": AdamOptimizer(),
            "num_steps": 100,
        },
    ]

raise_error_if_nan_loss(run_results)

Raise an error if the loss is NaN.

Parameters:

run_results : SVIRunResult
    The result of the SVI run. Required.

Raises:

MAPInferenceEngineError
    If the last loss is NaN.

Source code in src/prophetverse/engine/map.py
def raise_error_if_nan_loss(self, run_results: SVIRunResult):
    """
    Raise an error if the loss is NaN.

    Parameters
    ----------
    run_results : SVIRunResult
        The result of the SVI run.

    Raises
    ------
    MAPInferenceEngineError
        If the last loss is NaN.
    """
    losses = run_results.losses
    if jnp.isnan(losses)[-1]:
        msg = "NaN losses in MAPInferenceEngine."
        msg += " Try decreasing the learning rate or changing the model specs."
        msg += " If the problem persists, please open an issue at"
        msg += " https://github.com/felipeangelimvieira/prophetverse"
        raise MAPInferenceEngineError(msg)

MAPInferenceEngineError

Bases: Exception

Exception raised for NaN losses in MAPInferenceEngine.

Source code in src/prophetverse/engine/map.py
class MAPInferenceEngineError(Exception):
    """Exception raised for NaN losses in MAPInferenceEngine."""

    def __init__(self, message="NaN losses in MAPInferenceEngine"):
        self.message = message
        super().__init__(self.message)
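
Because raise_error_if_nan_loss turns a diverged optimization into an exception, callers can catch MAPInferenceEngineError and adjust the optimizer or model before refitting. A hedged sketch, reusing the illustrative toy_model from the usage example above:

import jax.numpy as jnp

from prophetverse.engine.map import MAPInferenceEngine, MAPInferenceEngineError

y_train = jnp.array([0.8, 1.1, 0.9, 1.0])  # illustrative data

engine = MAPInferenceEngine(num_steps=1_000, progress_bar=False)
try:
    engine.infer(toy_model, y=y_train)  # toy_model as defined in the sketch above
except MAPInferenceEngineError as err:
    # The final SVI loss was NaN. As the message suggests, lower the learning
    # rate, switch optimizers, or revisit the model specification, then refit.
    print(f"MAP inference diverged: {err}")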

MCMCInferenceEngine

Bases: BaseInferenceEngine

Perform MCMC (Markov Chain Monte Carlo) inference for a given model.

Parameters:

model : Callable
    The model function to perform inference on.
num_samples : int
    The number of MCMC samples to draw. Defaults to 1000.
num_warmup : int
    The number of warmup samples to discard. Defaults to 200.
num_chains : int
    The number of MCMC chains to run in parallel. Defaults to 1.
dense_mass : bool
    Whether to use a dense mass matrix for the NUTS sampler. Defaults to False.
rng_key : Optional
    The random number generator key. Defaults to None.

Attributes:

num_samples : int
    The number of MCMC samples to draw.
num_warmup : int
    The number of warmup samples to discard.
num_chains : int
    The number of MCMC chains to run in parallel.
dense_mass : bool
    Whether to use a dense mass matrix for the NUTS sampler.
mcmc_ : MCMC
    The MCMC object used for inference.
posterior_samples_ : Dict[str, ndarray]
    The posterior samples obtained from MCMC.
samples_predictive_ : Dict[str, ndarray]
    The predictive samples obtained from MCMC.
samples_ : Dict[str, ndarray]
    The MCMC samples obtained from MCMC.

Source code in src/prophetverse/engine/mcmc.py
class MCMCInferenceEngine(BaseInferenceEngine):
    """
    Perform MCMC (Markov Chain Monte Carlo) inference for a given model.

    Parameters
    ----------
    model : Callable
        The model function to perform inference on.
    num_samples : int
        The number of MCMC samples to draw.
    num_warmup : int
        The number of warmup samples to discard.
    num_chains : int
        The number of MCMC chains to run in parallel.
    dense_mass : bool
        Whether to use dense mass matrix for NUTS sampler.
    rng_key : Optional
        The random number generator key.

    Attributes
    ----------
    num_samples : int
        The number of MCMC samples to draw.
    num_warmup : int
        The number of warmup samples to discard.
    num_chains : int
        The number of MCMC chains to run in parallel.
    dense_mass : bool
        Whether to use dense mass matrix for NUTS sampler.
    mcmc_ : MCMC
        The MCMC object used for inference.
    posterior_samples_ : Dict[str, np.ndarray]
        The posterior samples obtained from MCMC.
    samples_predictive_ : Dict[str, np.ndarray]
        The predictive samples obtained from MCMC.
    samples_ : Dict[str, np.ndarray]
        The MCMC samples obtained from MCMC.
    """

    _tags = {
        "inference_method": "mcmc",
    }

    def __init__(
        self,
        num_samples=1000,
        num_warmup=200,
        num_chains=1,
        dense_mass=False,
        rng_key=None,
    ):
        self.num_samples = num_samples
        self.num_warmup = num_warmup
        self.num_chains = num_chains
        self.dense_mass = dense_mass
        super().__init__(rng_key)

    def _infer(self, **kwargs):
        """
        Run MCMC inference.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments to be passed to the MCMC run method.

        Returns
        -------
        self
            The MCMCInferenceEngine object.
        """

        def get_posterior_samples(
            rng_key,
            model,
            dense_mass,
            init_strategy,
            num_samples,
            num_warmup,
            num_chains,
            **kwargs
        ) -> MCMC:
            mcmc_ = MCMC(
                NUTS(model, dense_mass=dense_mass, init_strategy=init_strategy),
                num_samples=num_samples,
                num_warmup=num_warmup,
                num_chains=num_chains,
            )
            mcmc_.run(rng_key, **kwargs)
            return mcmc_.get_samples()

        self.posterior_samples_ = get_posterior_samples(
            self._rng_key,
            self.model_,
            self.dense_mass,
            init_strategy=init_to_mean,
            num_samples=self.num_samples,
            num_warmup=self.num_warmup,
            num_chains=self.num_chains,
            **kwargs
        )
        return self

    def _predict(self, **kwargs):
        """
        Generate predictive samples.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments to be passed to the Predictive method.

        Returns
        -------
        Dict[str, np.ndarray]
            The predictive samples.
        """
        predictive = Predictive(
            self.model_, self.posterior_samples_, num_samples=self.num_samples
        )

        self.samples_predictive_ = predictive(self._rng_key, **kwargs)
        return self.samples_predictive_
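
As with the MAP engine, a minimal usage sketch; the toy model and data are illustrative, and the prophetverse.engine.mcmc import path is inferred from the source location above:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

from prophetverse.engine.mcmc import MCMCInferenceEngine  # path inferred from the source location above


def toy_model(y=None):
    """A toy NumPyro model used only for illustration."""
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(mu, 0.5), obs=y)


engine = MCMCInferenceEngine(num_samples=500, num_warmup=100)
engine.infer(toy_model, y=jnp.array([0.8, 1.1, 0.9, 1.0]))  # runs NUTS; draws stored in posterior_samples_
samples = engine.predict(y=None)                            # posterior-predictive draws keyed by site name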