
Inference Engine

NumPyro inference engines for prophet models.

The classes in this module take a model and its data and perform inference using NumPyro.

InferenceEngine

Class representing an inference engine for a given model.

Parameters:

- model (Callable, required): The model to be used for inference.
- rng_key (Optional[PRNGKey], default None): The random number generator key. If not provided, a default key with value 0 will be used.

Attributes:

- model (Callable): The model used for inference.
- rng_key (PRNGKey): The random number generator key.

Source code in src/prophetverse/engine.py
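# Imports assumed from the module header of engine.py (not shown in this excerpt).
# _DEFAULT_PREDICT_NUM_SAMPLES, used by MAPInferenceEngine, is a module-level
# constant defined in the same file; its value is not shown here.
from typing import Callable

import jax
import jax.numpy as jnp
import numpyro
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta
from numpyro.infer.initialization import init_to_mean
from numpyro.infer.svi import SVIRunResult
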
class InferenceEngine:
    """
    Class representing an inference engine for a given model.

    Parameters
    ----------
    model : Callable
        The model to be used for inference.
    rng_key : Optional[jax.random.PRNGKey]
        The random number generator key. If not provided, a default key with value 0
        will be used.

    Attributes
    ----------
    model : Callable
        The model used for inference.
    rng_key : jax.random.PRNGKey
        The random number generator key.
    """

    def __init__(self, model: Callable, rng_key=None):
        self.model = model
        if rng_key is None:
            rng_key = jax.random.PRNGKey(0)
        self.rng_key = rng_key

    def infer(self, **kwargs):  # pragma: no cover
        """
        Perform inference using the specified model.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments to be passed to the model.

        Returns
        -------
        The result of the inference.
        """
        raise NotImplementedError("infer method must be implemented in subclass")

    def predict(self, **kwargs):  # pragma: no cover
        """
        Generate predictions using the specified model.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments to be passed to the model.

        Returns
        -------
        The predictions generated by the model.
        """
        raise NotImplementedError("predict method must be implemented in subclass")
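The base class only fixes a contract: `infer(**kwargs)` fits and returns `self`, and `predict(**kwargs)` returns samples computed from the fitted state. The sketch below shows a minimal subclass under that contract. To run without JAX it re-declares a stand-in `InferenceEngine` whose `rng_key` is a plain integer. `PointEstimateEngine` and `toy_model` are hypothetical names, not prophetverse API, and `toy_model` is an ordinary function rather than a NumPyro model.

```python
from typing import Any, Callable, Optional


class InferenceEngine:
    """Stand-in for the base class above; rng_key is a plain int so JAX is not needed."""

    def __init__(self, model: Callable, rng_key: Optional[Any] = None):
        self.model = model
        self.rng_key = 0 if rng_key is None else rng_key

    def infer(self, **kwargs):
        raise NotImplementedError("infer method must be implemented in subclass")

    def predict(self, **kwargs):
        raise NotImplementedError("predict method must be implemented in subclass")


class PointEstimateEngine(InferenceEngine):
    """Hypothetical subclass: 'infers' by calling the model once on the data."""

    def infer(self, **kwargs):
        # Fitted state gets a trailing underscore, as in the engines below.
        self.params_ = self.model(**kwargs)
        return self  # returning self allows engine.infer(...).predict(...)

    def predict(self, horizon: int = 1, **kwargs):
        if not hasattr(self, "params_"):
            raise RuntimeError("call infer() before predict()")
        # Repeat the fitted location once per forecast step.
        return {"obs": [self.params_["loc"]] * horizon}


def toy_model(y):
    """Hypothetical 'model': returns the sample mean as its only parameter."""
    return {"loc": sum(y) / len(y)}


engine = PointEstimateEngine(toy_model).infer(y=[1.0, 2.0, 3.0])
print(engine.predict(horizon=2))  # {'obs': [2.0, 2.0]}
```

Returning `self` from `infer` is what lets the concrete engines below be fitted and queried in one chained expression.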

infer(**kwargs)

Perform inference using the specified model.

Parameters:

- **kwargs (default {}): Additional keyword arguments to be passed to the model.

Returns:

- The result of the inference.

predict(**kwargs)

Generate predictions using the specified model.

Parameters:

- **kwargs (default {}): Additional keyword arguments to be passed to the model.

Returns:

- The predictions generated by the model.

MAPInferenceEngine

Bases: InferenceEngine

Maximum a Posteriori (MAP) Inference Engine.

This class performs MAP inference by running Stochastic Variational Inference (SVI) with an AutoDelta guide, which places a point mass on each latent variable. It provides methods for inference and prediction.

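Parameters:

- model (Callable, required): The probabilistic model to perform inference on.
- optimizer_factory (Callable[[], _NumPyroOptim], default None): A zero-argument callable that returns the optimizer to use for SVI. If None, default_optimizer_factory is used.
- num_steps (int, default 10000): The number of optimization steps to perform.
- num_samples (int, default _DEFAULT_PREDICT_NUM_SAMPLES): The number of samples drawn by predict.
- rng_key (PRNGKey, default None): The random number generator key. If None, a default key with value 0 is used.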
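Why does SVI with an AutoDelta guide give a MAP estimate? A delta guide puts all of its mass on a single point, so its entropy term does not depend on that point and the ELBO reduces to the log joint density log p(y, θ). Maximizing it is therefore maximizing the unnormalized posterior. The pure-Python sketch below illustrates the reduction on a hypothetical conjugate model (y_i ~ Normal(mu, sigma), mu ~ Normal(0, tau)), using plain gradient ascent in place of NumPyro's SVI and Adam; all function names and data are illustrative, not prophetverse API.

```python
def log_joint_grad(mu, y, sigma=1.0, tau=10.0):
    """d/dmu of log Normal(mu | 0, tau) + sum_i log Normal(y_i | mu, sigma)."""
    return -mu / tau**2 + sum(yi - mu for yi in y) / sigma**2


def map_by_gradient_ascent(y, sigma=1.0, tau=10.0, step_size=0.01, num_steps=5000):
    """Maximize the log joint over a single point mu (what a delta guide reduces to)."""
    mu = 0.0  # start at the prior mean, loosely analogous to init_to_mean()
    for _ in range(num_steps):
        mu += step_size * log_joint_grad(mu, y, sigma, tau)
    return mu


def map_closed_form(y, sigma=1.0, tau=10.0):
    """Exact MAP for this conjugate Normal-Normal model, for comparison."""
    precision = len(y) / sigma**2 + 1.0 / tau**2
    return (sum(y) / sigma**2) / precision


y = [1.8, 2.1, 2.4, 1.9, 2.3]
print(map_by_gradient_ascent(y), map_closed_form(y))
```

Both numbers agree, and the MAP sits slightly below the sample mean of 2.1 because the prior pulls it toward 0.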
Source code in src/prophetverse/engine.py
class MAPInferenceEngine(InferenceEngine):
    """
    Maximum a Posteriori (MAP) Inference Engine.

    This class performs MAP inference by running Stochastic Variational Inference
    (SVI) with an AutoDelta guide. It provides methods for inference and prediction.

    Parameters
    ----------
    model : Callable
        The probabilistic model to perform inference on.
    optimizer_factory : Callable[[], numpyro.optim._NumPyroOptim], optional
        A zero-argument callable that returns the optimizer to use for SVI.
        Defaults to None, in which case ``default_optimizer_factory`` is used.
    num_steps : int, optional
        The number of optimization steps to perform. Defaults to 10000.
    num_samples : int, optional
        The number of samples drawn by ``predict``. Defaults to
        ``_DEFAULT_PREDICT_NUM_SAMPLES``.
    rng_key : jax.random.PRNGKey, optional
        The random number generator key. Defaults to None.
    """

    def __init__(
        self,
        model: Callable,
        optimizer_factory: Callable[[], numpyro.optim._NumPyroOptim] = None,
        num_steps=10000,
        num_samples=_DEFAULT_PREDICT_NUM_SAMPLES,
        rng_key=None,
    ):
        if optimizer_factory is None:
            optimizer_factory = self.default_optimizer_factory
        self.optimizer_factory = optimizer_factory
        self.num_steps = num_steps
        self.num_samples = num_samples
        super().__init__(model, rng_key)

    def default_optimizer_factory(self):
        """Create the default optimizer for SVI."""
        return numpyro.optim.Adam(step_size=0.001)

    def infer(self, **kwargs):
        """
        Perform MAP inference.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments to be passed to the model.

        Returns
        -------
        self
            The updated MAPInferenceEngine object.
        """
        self.guide_ = AutoDelta(self.model, init_loc_fn=init_to_mean())
        svi_ = SVI(self.model, self.guide_, self.optimizer_factory(), loss=Trace_ELBO())
        self.run_results_: SVIRunResult = svi_.run(
            rng_key=self.rng_key, num_steps=self.num_steps, **kwargs
        )

        self.raise_error_if_nan_loss(self.run_results_)

        self.posterior_samples_ = self.guide_.sample_posterior(
            self.rng_key, params=self.run_results_.params, **kwargs
        )
        return self

    def raise_error_if_nan_loss(self, run_results: SVIRunResult):
        """
        Raise an error if the loss is NaN.

        Parameters
        ----------
        run_results : SVIRunResult
            The result of the SVI run.

        Raises
        ------
        MAPInferenceEngineError
            If the last loss is NaN.
        """
        losses = run_results.losses
        if jnp.isnan(losses)[-1]:
            msg = "NaN losses in MAPInferenceEngine."
            msg += " Try decreasing the learning rate or changing the model specs."
            msg += " If the problem persists, please open an issue at"
            msg += " https://github.com/felipeangelimvieira/prophetverse"
            raise MAPInferenceEngineError(msg)

    def predict(self, **kwargs):
        """
        Generate predictions using the trained model.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments to be passed to the model.

        Returns
        -------
        dict[str, jnp.ndarray]
            The predicted samples generated by the model, keyed by sample site name.
        """
        predictive = numpyro.infer.Predictive(
            self.model,
            params=self.run_results_.params,
            guide=self.guide_,
            num_samples=self.num_samples,
        )
        # Store the samples on the engine instance.
        self.samples_ = predictive(rng_key=self.rng_key, **kwargs)
        return self.samples_

default_optimizer_factory()

Create the default optimizer for SVI: numpyro.optim.Adam with step_size=0.001.
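One property of Adam worth knowing when tuning step_size: when the gradient keeps a consistent sign, the bias-corrected update has a length close to step_size regardless of how large or small the gradient is. With the defaults (step_size=0.001, num_steps=10000) a parameter therefore moves on the order of 10 units at most in the optimizer's parameter space. The textbook scalar Adam below is an illustration, not NumPyro's implementation, and `adam_minimize` is a hypothetical helper; it shows the effect on linear objectives with very different slopes.

```python
import math


def adam_minimize(grad, x0, step_size=0.001, num_steps=1000, b1=0.9, b2=0.999, eps=1e-8):
    """Textbook Adam on a scalar parameter (illustrative; not NumPyro's code)."""
    x, m, v = x0, 0.0, 0.0
    for t in range(1, num_steps + 1):
        g = grad(x)
        m = b1 * m + (1 - b1) * g          # running mean of gradients
        v = b2 * v + (1 - b2) * g * g      # running mean of squared gradients
        m_hat = m / (1 - b1**t)            # bias corrections for the zero init
        v_hat = v / (1 - b2**t)
        x -= step_size * m_hat / (math.sqrt(v_hat) + eps)
    return x


# Minimize f(x) = c * x for three very different slopes c (the gradient is constant c).
for c in (1e-3, 1.0, 1e3):
    x_final = adam_minimize(lambda x, c=c: c, x0=0.0)
    # Each run moves x by about step_size * num_steps = 1.0, whatever c is.
    print(f"slope {c:g}: x = {x_final:.4f}")
```

Because the update is normalized by the gradient's running scale, lowering step_size (as the NaN error message below suggests) directly shrinks how far each step can move the parameters.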


infer(**kwargs)

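Perform MAP inference: fit an AutoDelta guide with SVI for num_steps steps, raise MAPInferenceEngineError if the final loss is NaN, and store the fitted guide_, run_results_ and posterior_samples_.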

Parameters:

- **kwargs (default {}): Additional keyword arguments to be passed to the model.

Returns:

- self: The updated MAPInferenceEngine object.


predict(**kwargs)

Generate predictions using the trained model.

Parameters:

- **kwargs (default {}): Additional keyword arguments to be passed to the model.

Returns:

- dict[str, ndarray]: The predicted samples generated by the model, keyed by sample site name.
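Because the fitted guide is an AutoDelta, each of the num_samples model runs inside Predictive sees the same point estimate, so the spread of the predictive samples reflects observation noise only, not parameter uncertainty. A pure-Python sketch, assuming a hypothetical model y ~ Normal(mu, sigma) and a hypothetical helper `map_predictive`:

```python
import random
import statistics


def map_predictive(mu_map, sigma, num_samples, seed=0):
    """Predictive draws for y ~ Normal(mu, sigma) when the 'posterior' is a point mass."""
    rng = random.Random(seed)
    # A delta guide returns the same parameter value on every draw...
    mu_draws = [mu_map] * num_samples
    # ...so observation noise is the only source of spread.
    obs_draws = [rng.gauss(mu, sigma) for mu in mu_draws]
    return {"mu": mu_draws, "obs": obs_draws}


draws = map_predictive(mu_map=2.1, sigma=0.5, num_samples=1000)
print(len(set(draws["mu"])), round(statistics.stdev(draws["obs"]), 2))
```

If calibrated uncertainty about the parameters matters, MCMCInferenceEngine propagates it into the predictions instead.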


raise_error_if_nan_loss(run_results)

Raise an error if the loss is NaN.

Parameters:

- run_results (SVIRunResult, required): The result of the SVI run.

Raises:

- MAPInferenceEngineError: If the last loss is NaN.


MAPInferenceEngineError

Bases: Exception

Exception raised for NaN losses in MAPInferenceEngine.

Source code in src/prophetverse/engine.py
class MAPInferenceEngineError(Exception):
    """Exception raised for NaN losses in MAPInferenceEngine."""

    def __init__(self, message="NaN losses in MAPInferenceEngine"):
        self.message = message
        super().__init__(self.message)

MCMCInferenceEngine

Bases: InferenceEngine

Perform MCMC (Markov Chain Monte Carlo) inference for a given model.

Parameters:

- model (Callable, required): The model function to perform inference on.
- num_samples (int, default 1000): The number of MCMC samples to draw per chain.
- num_warmup (int, default 200): The number of warmup samples to discard.
- num_chains (int, default 1): The number of MCMC chains to run in parallel.
- dense_mass (bool, default False): Whether to use a dense mass matrix for the NUTS sampler.
- rng_key (Optional[PRNGKey], default None): The random number generator key. If None, a default key with value 0 is used.

Attributes:

- num_samples (int): The number of MCMC samples to draw per chain.
- num_warmup (int): The number of warmup samples to discard.
- num_chains (int): The number of MCMC chains to run in parallel.
- dense_mass (bool): Whether to use a dense mass matrix for the NUTS sampler.
- mcmc_ (MCMC): The MCMC object used for inference.
- posterior_samples_ (Dict[str, ndarray]): The posterior samples obtained from MCMC.
- samples_predictive_ (Dict[str, ndarray]): The posterior predictive samples generated by predict.
- samples_ (Dict[str, ndarray]): The posterior samples from the MCMC run, stored again by predict.
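In NumPyro, num_samples is a per-chain count, warmup draws are thrown away, and MCMC.get_samples() concatenates the chains by default (group_by_chain=False), so posterior_samples_ holds num_samples * num_chains draws per site. The sketch below makes those semantics concrete with random-walk Metropolis, a much simpler sampler swapped in for NUTS, on a hypothetical Normal-Normal model; all names are illustrative, not prophetverse API.

```python
import math
import random


def log_post(mu, y, sigma=1.0, tau=10.0):
    """Unnormalized log posterior for y_i ~ Normal(mu, sigma), mu ~ Normal(0, tau)."""
    return -mu * mu / (2 * tau**2) - sum((yi - mu) ** 2 for yi in y) / (2 * sigma**2)


def run_chain(y, num_warmup, num_samples, rng, proposal_scale=0.5):
    """Random-walk Metropolis (a stand-in for NUTS); returns post-warmup draws only."""
    mu, lp = 0.0, log_post(0.0, y)
    kept = []
    for i in range(num_warmup + num_samples):
        proposal = mu + rng.gauss(0.0, proposal_scale)
        lp_proposal = log_post(proposal, y)
        # Accept with probability min(1, exp(lp_proposal - lp)); 1 - random() lies in (0, 1].
        if math.log(1.0 - rng.random()) < lp_proposal - lp:
            mu, lp = proposal, lp_proposal
        if i >= num_warmup:  # warmup draws are discarded
            kept.append(mu)
    return kept


def run_mcmc(y, num_samples=1000, num_warmup=200, num_chains=1, seed=0):
    """Run independent chains, then flatten them into one list of draws."""
    rng = random.Random(seed)
    chains = [run_chain(y, num_warmup, num_samples, rng) for _ in range(num_chains)]
    return [draw for chain in chains for draw in chain]


y = [1.8, 2.1, 2.4, 1.9, 2.3]
draws = run_mcmc(y, num_samples=500, num_warmup=100, num_chains=4)
print(len(draws))  # 2000 = num_samples * num_chains
```

The chains here run one after another; NumPyro can instead run them in parallel across devices.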

Source code in src/prophetverse/engine.py
class MCMCInferenceEngine(InferenceEngine):
    """
    Perform MCMC (Markov Chain Monte Carlo) inference for a given model.

    Parameters
    ----------
    model : Callable
        The model function to perform inference on.
    num_samples : int
        The number of MCMC samples to draw per chain.
    num_warmup : int
        The number of warmup samples to discard.
    num_chains : int
        The number of MCMC chains to run in parallel.
    dense_mass : bool
        Whether to use a dense mass matrix for the NUTS sampler.
    rng_key : jax.random.PRNGKey, optional
        The random number generator key. Defaults to None.

    Attributes
    ----------
    num_samples : int
        The number of MCMC samples to draw per chain.
    num_warmup : int
        The number of warmup samples to discard.
    num_chains : int
        The number of MCMC chains to run in parallel.
    dense_mass : bool
        Whether to use a dense mass matrix for the NUTS sampler.
    mcmc_ : MCMC
        The MCMC object used for inference.
    posterior_samples_ : Dict[str, np.ndarray]
        The posterior samples obtained from MCMC.
    samples_predictive_ : Dict[str, np.ndarray]
        The posterior predictive samples generated by ``predict``.
    samples_ : Dict[str, np.ndarray]
        The posterior samples from the MCMC run, stored again by ``predict``.
    """

    def __init__(
        self,
        model: Callable,
        num_samples=1000,
        num_warmup=200,
        num_chains=1,
        dense_mass=False,
        rng_key=None,
    ):
        self.num_samples = num_samples
        self.num_warmup = num_warmup
        self.num_chains = num_chains
        self.dense_mass = dense_mass
        super().__init__(model, rng_key)

    def infer(self, **kwargs):
        """
        Run MCMC inference.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments to be passed to the MCMC run method.

        Returns
        -------
        self
            The MCMCInferenceEngine object.
        """
        self.mcmc_ = MCMC(
            NUTS(self.model, dense_mass=self.dense_mass, init_strategy=init_to_mean()),
            num_samples=self.num_samples,
            num_warmup=self.num_warmup,
            num_chains=self.num_chains,
        )
        self.mcmc_.run(self.rng_key, **kwargs)
        self.posterior_samples_ = self.mcmc_.get_samples()
        return self

    def predict(self, **kwargs):
        """
        Generate predictive samples.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments to be passed to the Predictive call.

        Returns
        -------
        Dict[str, np.ndarray]
            The predictive samples.
        """
        # The number of predictive draws is taken from the batch size of
        # posterior_samples_ (one draw per retained MCMC sample, across all chains).
        predictive = Predictive(self.model, self.posterior_samples_)

        # Store results on the engine instance, as listed under Attributes.
        self.samples_predictive_ = predictive(self.rng_key, **kwargs)
        self.samples_ = self.mcmc_.get_samples()
        return self.samples_predictive_

infer(**kwargs)

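Run MCMC inference with the NUTS sampler, initialized at the prior mean (init_to_mean), and store the draws in posterior_samples_.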

Parameters:

- **kwargs (default {}): Additional keyword arguments to be passed to the MCMC run method.

Returns:

- self: The MCMCInferenceEngine object.


predict(**kwargs)

Generate predictive samples.

Parameters:

- **kwargs (default {}): Additional keyword arguments to be passed to the Predictive call.

Returns:

- Dict[str, ndarray]: The predictive samples.
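Unlike the MAP engine, each posterior draw here carries a different parameter value, so predictive samples combine parameter uncertainty with observation noise. For y ~ Normal(mu, sigma), the law of total variance gives a predictive variance of sigma^2 plus the posterior variance of mu. The pure-Python sketch below mimics what a Predictive pass does with posterior_samples_, simulating one observation per posterior draw; `posterior_predictive` and the draws are hypothetical.

```python
import random
import statistics


def posterior_predictive(mu_draws, sigma=1.0, seed=0):
    """Simulate one observation y ~ Normal(mu, sigma) per posterior draw of mu."""
    rng = random.Random(seed)
    return {"obs": [rng.gauss(mu, sigma) for mu in mu_draws]}


# Hypothetical posterior draws with a wide spread (standard deviation 3).
mu_draws = [-3.0, 3.0] * 500
out = posterior_predictive(mu_draws)
print(len(out["obs"]), round(statistics.stdev(out["obs"]), 2))
```

The predictive standard deviation lands near sqrt(1 + 9), well above the observation noise sigma = 1, because the spread of mu carries through.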
