Executive summary

You can use Stan to write a statistical model, compile it to efficient C++ and native code, and have the result linked into a running Python (or R) process. That gives you access not just to the obvious HMC sampling methods but also to the underlying (log) posterior density $\pi(\theta \mid y)$ and its gradient, exposed as log_prob and grad_log_prob respectively. This document shows how to do that in Python.
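Concretely, the whole pattern fits in a few lines. Here is a minimal sketch using a toy one-parameter model of my own (the model actually used in this notebook is defined below):

import pystan
import numpy as np

toy = pystan.StanModel(model_code="""
parameters { real theta; }
model { theta ~ normal(0, 1); }
""")
toy_fit = toy.sampling(iter=100, chains=1)  # any call that returns a fit object will do

upar = [0.5]                           # a point in the unconstrained parameter space
print(toy_fit.log_prob(upar=upar))     # log density at that point (up to a constant)
print(toy_fit.grad_log_prob(upar))     # its gradient, computed by Stan's autodiff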

Context

On the Stan team, we get a lot of questions from people looking to use Stan for algorithm development. Stan is opinionated software: it believes that models are mostly separate from inference mechanisms, and most of the people looking to do algorithm development with us buy into that. They typically want to add some mechanism for handling discrete parameters, or another type of posterior approximation as an alternative to ADVI, for a given statistical model.

I claim this is actually pretty easy to do in practice, but we haven't written anything up about it, because the Stan team values robustness and generality first and foremost and we're still waiting for something that beats the No-U-Turn Sampler (which has itself been improved in our source code since the original paper). A few of us would like to support the broader research community in using Stan for algorithm development, even if we never ship any of these algorithms with the Stan distribution. This document illustrates how that can work.

In [6]:
import pystan
import numpy as np
In [85]:
# Construct an arbitrary model
model_code = """
data {
  int numObs;
  int numGroups;
  matrix[numObs, numGroups] group;
  vector[numObs] y;
}
parameters {
  vector[numGroups] beta;
  real<lower=0> sigma;
}
model {
  beta ~ normal(0, 20);
  sigma ~ normal(0, 5);
  y ~ normal(group * beta, sigma);
}
"""
model = pystan.StanModel(model_code=model_code)
INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_390a0ec538f0cb4f0a4483ba147e0080 NOW.
In [108]:
# fake data
def gen_fake_data(N=100, K=3):
    group_ = np.random.choice(K, N)
    # one-hot encode in matrix
    group = np.zeros((N, K))
    group[np.arange(N), group_] = 1
    beta = np.random.normal(0, 20, K)
    sigma = np.abs(np.random.normal(0, 5))
    print("sigma: {}".format(sigma))
    y = np.random.normal(group.dot(beta), sigma)
    return dict(group=group, y=y, numObs=N, numGroups=K, beta=beta)
fd = gen_fake_data()
fd
sigma: 9.157885217081992
Out[108]:
{'group': array([[0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.]]),
 'y': array([-24.76301019, -35.97887663, -20.23274477, -41.49054707,
        -16.02999448, -34.09165568, -24.88020563, -24.05463857,
        -16.86881039, -18.8714874 , -30.42428944, -11.29308102,
        -24.34802981, -29.40753916, -28.59344693, -32.39550442,
        -36.37132741, -16.20657938, -24.92878376, -13.25484089,
        -13.17939153, -19.27899561, -35.26058827, -15.55430219,
        -12.2679381 ,  -7.88959863, -22.76132319, -25.95365739,
        -25.42946611, -18.30691545, -17.3644884 , -12.23038274,
        -26.18277742, -32.8749564 , -27.22624725, -26.41887204,
        -16.35451777, -20.67424802, -27.49359065, -24.22740679,
        -16.48492267, -14.92948987, -39.5280163 , -16.32757889,
        -18.03326048, -16.02600095, -24.86888698, -11.08624993,
        -39.63109918, -21.36258758, -37.90687437, -26.49305471,
        -19.65015011, -29.10433207, -22.13268255, -31.98791592,
          3.45796284, -26.93794208, -21.262007  , -26.21311793,
         -9.91721967, -28.45472649, -25.29096601, -15.85858002,
        -19.52721789, -22.61223839, -16.48710779, -17.49848262,
        -27.80772762,  -6.55352763, -38.28107976, -40.83528445,
        -25.64275568, -20.9426822 , -39.5144073 , -12.74040907,
        -13.68907224, -25.04254282, -16.91582672, -36.15031289,
        -25.97056788, -18.37441026, -26.20311141, -20.06767311,
        -33.2656917 , -33.07670545, -34.38271064, -23.44184368,
        -23.32873611,  -1.62293945, -27.85781366, -16.54050565,
         -8.79686481, -34.06324398, -21.78571493, -16.59197712,
        -22.35961064, -31.80027921, -30.64009453, -25.19952429]),
 'numObs': 100,
 'numGroups': 3,
 'beta': array([-25.87069529, -22.30629952, -19.12452159])}
In [109]:
# You must first run one of the methods that gives you a fit object:
fit = model.sampling(data=fd)
fit
Out[109]:
Inference for Stan model: anon_model_390a0ec538f0cb4f0a4483ba147e0080.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

          mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
beta[1]  -24.4    0.02   1.45 -27.25 -25.37 -24.42 -23.46 -21.52   5131    1.0
beta[2] -23.44    0.02   1.45 -26.26 -24.43 -23.46 -22.46 -20.59   4251    1.0
beta[3] -20.77    0.03   1.74 -24.19 -21.93 -20.79 -19.66 -17.22   4088    1.0
sigma     8.81    0.01   0.63   7.68   8.36   8.79   9.22  10.14   3649    1.0
lp__    -269.8    0.03   1.47 -273.5 -270.5 -269.4 -268.7 -268.0   1796    1.0

Samples were drawn using NUTS at Sat Oct 20 22:47:52 2018.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).
In [39]:
# This then exposes a log prob and grad_log_prob method:
list(filter(lambda x: "prob" in x, dir(fit)))
Out[39]:
['grad_log_prob', 'log_prob']
In [193]:
help(fit.grad_log_prob)
Help on built-in function grad_log_prob:

grad_log_prob(...) method of stanfit4anon_model_390a0ec538f0cb4f0a4483ba147e0080_913984840570262847.StanFit4Model instance
    Expose the grad_log_prob of the model to stan_fit so user
    can call this function.
    
    Parameters
    ----------
    upar : array
        The real parameters on the unconstrained space.
    adjust_transform : bool
        Whether we add the term due to the transform from constrained
        space to unconstrained space implicitly done in Stan.

In [200]:
fit.log_prob(upar=[-24.4, -23.4, -20.7, np.log(8.8)]) # adjust_transform seems to default to True
Out[200]:
-267.78711575183934
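If I understand the adjust_transform flag correctly, it controls whether the log-Jacobian of the constrained-to-unconstrained transform is added. Here only sigma is constrained (lower bound 0), its unconstrained version is log(sigma), and the Jacobian term is log(sigma) itself, so the two settings should differ by exactly that amount. A quick check (my code, not the original notebook's):

upar = [-24.4, -23.4, -20.7, np.log(8.8)]
lp_adj = fit.log_prob(upar=upar, adjust_transform=True)
lp_raw = fit.log_prob(upar=upar, adjust_transform=False)
# If the convention above is right, this difference should be np.log(8.8) ~= 2.17
print(lp_adj - lp_raw, np.log(8.8))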
In [145]:
# We can pass parameter values to `log_prob` in the order we defined them,
# which is shown in `flatnames`
print(fit.flatnames)
fit.log_prob(upar=[-50, 40, 32, 2.8])
['beta[1]', 'beta[2]', 'beta[3]', 'sigma']
Out[145]:
-764.6526305472523
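A sanity check I find useful (not part of the original notebook) is to compare grad_log_prob against a central finite difference of log_prob; the two should agree to several decimal places. The helper fd_grad below is hypothetical:

def fd_grad(fit, upar, eps=1e-6):
    # central finite differences of log_prob, one coordinate at a time
    upar = np.asarray(upar, dtype=float)
    g = np.zeros_like(upar)
    for j in range(len(upar)):
        up, dn = upar.copy(), upar.copy()
        up[j] += eps
        dn[j] -= eps
        g[j] = (fit.log_prob(upar=list(up)) - fit.log_prob(upar=list(dn))) / (2 * eps)
    return g

theta = [-24.4, -23.4, -20.7, np.log(8.8)]
print(fit.grad_log_prob(theta))  # Stan's autodiff gradient
print(fd_grad(fit, theta))       # should match closely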

Stochastic gradient descent with Stan

Stochastic gradient descent is basically just gradient descent with a random data subset at each iteration (here it is really gradient ascent, since we are maximizing the log posterior). Stan does not implement this, but we can do it easily using the log_prob and grad_log_prob methods of the fit object PyStan returns. To be concrete, the algorithm will: randomly initialize the parameters, evaluate grad_log_prob at that position in parameter space on a fresh random subset of the data, adjust the parameters along that gradient, and repeat the last two steps until the parameters change by less than some tolerance.
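In symbols (my notation, not the notebook's): writing $\theta_t$ for the current unconstrained parameter vector, $S_t$ for the random subset drawn at iteration $t$, and $\alpha$ for the step size, each iteration performs

$\theta_{t+1} = \theta_t + \alpha \, \nabla_\theta \log \pi(\theta \mid y_{S_t}) \big|_{\theta = \theta_t},$

which is exactly what the loop below does, with grad_log_prob supplying the gradient.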

In [148]:
# Generate a ton of data (or this would be less fun)
sgd_data = gen_fake_data(N=1000000)
sigma: 2.6148919343314736
In [268]:
# Assume the observations are always keyed 'y' and the design matrix 'group',
# since this is just an illustration and not production-ready code
def sgd(model, data, numParams, size=10000, alpha=0.0001, tol=1e-4):
    params = np.arange(numParams)  # arbitrary starting value so the first convergence check fails
    new_params = np.random.normal(0, 20, numParams)
    new_params[-1] = np.log(8.8)  # hack - sigma lives on the log scale in unconstrained space, and SGD is touchy about starting values
    print("new_params: {}".format(new_params))
    i = 0
    while np.abs(np.sum(new_params - params)) > tol:
        params = new_params.copy()
        # Draw a random minibatch of observations (sampled with replacement)
        subset = np.random.choice(np.arange(len(data['y'])), size)
        # Hack: run a one-iteration Fixed_param "fit" just to bind the minibatch to the model,
        # which gives us a fit object exposing grad_log_prob for that subset
        model_with_data = model.sampling(iter=1, algorithm='Fixed_param',
                                         data={**data,
                                               'y': data['y'][subset],
                                               'group': data['group'][subset],
                                               'numObs': size})
        grad = model_with_data.grad_log_prob(params, adjust_transform=True)
        #print("grad {}".format(grad))
        new_params += alpha * grad
        i += 1
        if i % 50 == 0:
            print("new params: {}".format(new_params))
    return params
print(sgd_data['beta'])
sgd(model, sgd_data, 4)
[-16.02126103  15.29848408   7.19441698]
new_params: [ 1.96796133 22.80975954  9.85273041  2.17475172]
new params: [-0.37768434 21.8305395   9.50682107  2.25083747]
new params: [-3.06085882 20.7139119   9.1058192   2.06121228]
new params: [-6.28408818 19.36283677  8.62658011  1.83093959]
new params: [-10.43155098  17.62031151   8.01021932   1.48260734]
new params: [-14.83405766  15.78608479   7.36925782   0.90130219]
new params: [-15.92467355  15.33906932   7.19869735   0.98572609]
new params: [-16.02137515  15.2997305    7.18322427   1.02415639]
new params: [-16.02514756  15.3059415    7.17593135   0.92693517]
new params: [-16.01917244  15.30327473   7.18611837   0.97838085]
new params: [-16.02410886  15.31047737   7.18688107   1.04860497]
new params: [-16.03099437  15.28555268   7.20274812   1.02933739]
new params: [-16.02934058  15.30490861   7.17599538   0.9481738 ]
new params: [-16.01501387  15.2982072    7.18548136   1.0183951 ]
new params: [-16.01690784  15.29332343   7.19543123   0.85908236]
new params: [-16.02178972  15.28882103   7.17410878   1.12567647]
new params: [-16.02469626  15.29361907   7.17094974   1.03576549]
new params: [-16.01071213  15.28540296   7.20041093   0.87728606]
new params: [-16.01045924  15.29718304   7.18037604   0.86822714]
new params: [-16.01283494  15.30441838   7.17886016   0.86468173]
new params: [-15.9988999   15.29369511   7.16942873   0.91479659]
Out[268]:
array([-16.00822695,  15.29228693,   7.1787844 ,   0.96846103])
In [269]:
np.exp(0.96846103)
Out[269]:
2.633887864259263

As you can see, SGD takes forever and seems to severely underestimate the variance. Oh well, at least you can farm it out to millions of computers at once.

Expectation Maximization

Expectation maximization can be thought of as generalizing k-means, with its discrete cluster assignments, to models that have additional parameters (which may depend on the cluster assignments). Bob Carpenter has a good write-up; I'll try to summarize the algorithm briefly here.

The algorithm is traditionally broken into two steps: one calculates the expectation of the cluster assignments (or other latent parameters) and uses it to generate the assignments, and the other maximizes the remaining parameters given those assignments. One way to accomplish this with Stan is to provide two Stan models, one for each step. I'll call these expectation.stan and maximization.stan, using Markov chain Monte Carlo to calculate the integral in the expectation step and Stan's (L-)BFGS optimizer for the maximization step. The model we already defined can be used as-is for the maximization step.
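For reference (standard EM notation, not from the notebook): writing $z$ for the latent group assignments and $\theta$ for the remaining parameters, iteration $t$ computes

$Q(\theta \mid \theta^{(t)}) = \mathbb{E}_{z \sim p(z \mid y, \theta^{(t)})}\left[\log p(y, z \mid \theta)\right]$ (E step), then $\theta^{(t+1)} = \operatorname{arg\,max}_\theta Q(\theta \mid \theta^{(t)})$ (M step).

Below, the expectation over $z$ is approximated with MCMC draws from the expectation model (summarized by the posterior mean of each observation's simplex), and the arg max is computed by running the maximization model through optimizing.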

In [105]:
maxi_model = model
In [116]:
expectation_model = pystan.StanModel(model_code="""
data {
  int numObs;
  int numGroups;
  row_vector[numGroups] beta;
  vector[numObs] y;
}
parameters {
  simplex[numGroups] group[numObs];
  real<lower=0> sigma;
}
model {
  vector[numObs] mu;
  for (n in 1:numObs)
    mu[n] = beta * group[n];
  beta ~ normal(0, 20);
  sigma ~ normal(0, 5);
  y ~ normal(mu, sigma);
}
""")
INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_5fb00b74056ca004b9d6262d4b3f4b40 NOW.
In [270]:
efit = expectation_model.sampling(fd)
efit
WARNING:pystan:Rhat for parameter sigma is 1.1157325303401602!
WARNING:pystan:Rhat for parameter lp__ is 1.1221980792747852!
WARNING:pystan:Rhat above 1.1 or below 0.9 indicates that the chains very likely have not mixed
WARNING:pystan:Chain 1: E-BFMI = 0.079956287503388
WARNING:pystan:Chain 2: E-BFMI = 0.10182956992244153
WARNING:pystan:Chain 3: E-BFMI = 0.06689687643557556
WARNING:pystan:Chain 4: E-BFMI = 0.06306580770061987
WARNING:pystan:E-BFMI below 0.2 indicates you may need to reparameterize your model
WARNING:pystan:Truncated summary with the 'fit.__repr__' method. For the full summary use 'print(fit)'
Out[270]:
Warning: Shown data is truncated to 100 parameters
For the full summary use 'print(fit)'

Inference for Stan model: anon_model_5fb00b74056ca004b9d6262d4b3f4b40.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

              mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
group[1,1]     0.3  2.4e-3   0.18   0.01   0.15    0.3   0.45    0.6   5246    1.0
group[2,1]    0.27  2.2e-3   0.15   0.02   0.14   0.27   0.39   0.53   4675    1.0
group[3,1]    0.32  2.7e-3   0.18   0.02   0.16   0.32   0.47   0.63   4678    1.0
group[4,1]    0.24  2.1e-3   0.14   0.01   0.12   0.24   0.36   0.49   4847    1.0
group[5,1]    0.33  2.6e-3   0.19   0.01   0.17   0.33   0.49   0.66   5419    1.0
group[6,1]    0.26  2.2e-3   0.16   0.01   0.13   0.26    0.4   0.53   5185    1.0
group[7,1]     0.3  2.4e-3   0.17   0.02   0.16   0.29   0.43   0.59   4874    1.0
group[8,1]     0.3  2.6e-3   0.17   0.02   0.16    0.3   0.45    0.6   4354    1.0
group[9,1]    0.33  3.0e-3   0.19   0.02   0.17   0.33   0.49   0.65   4080    1.0
group[10,1]   0.31  2.6e-3   0.19   0.01   0.15   0.31   0.47   0.63   5168    1.0
group[11,1]   0.28  2.5e-3   0.17   0.01   0.13   0.28   0.42   0.56   4516    1.0
group[12,1]   0.34  3.1e-3    0.2   0.02   0.16   0.34   0.51   0.68   4236    1.0
group[13,1]    0.3  2.4e-3   0.18   0.01   0.14    0.3   0.45    0.6   5590    1.0
group[14,1]   0.28  2.6e-3   0.16   0.02   0.15   0.28   0.42   0.57   4085    1.0
group[15,1]   0.28  2.5e-3   0.17   0.02   0.14   0.28   0.42   0.57   4520    1.0
group[16,1]   0.27  2.4e-3   0.16   0.01   0.14   0.27   0.41   0.55   4424    1.0
group[17,1]   0.26  2.4e-3   0.15   0.01   0.13   0.26   0.38   0.52   4083    1.0
group[18,1]   0.33  2.8e-3   0.19   0.02   0.16   0.33    0.5   0.65   4677    1.0
group[19,1]    0.3  2.4e-3   0.18   0.02   0.15   0.29   0.45   0.59   5281    1.0
group[20,1]   0.33  2.9e-3   0.19   0.01   0.16   0.33    0.5   0.67   4387    1.0
group[21,1]   0.33  2.9e-3    0.2   0.02   0.16   0.32    0.5   0.67   4706    1.0
group[22,1]   0.32  2.7e-3   0.18   0.02   0.16   0.32   0.47   0.62   4504    1.0
group[23,1]   0.27  2.3e-3   0.15   0.01   0.14   0.27   0.39   0.52   4440    1.0
group[24,1]   0.32  3.1e-3   0.19   0.02   0.16   0.32    0.5   0.65   3866    1.0
group[25,1]   0.34  3.0e-3    0.2   0.02   0.17   0.34   0.51   0.67   4269    1.0
group[26,1]   0.35  2.9e-3    0.2   0.02   0.18   0.35   0.52    0.7   4932    1.0
group[27,1]   0.31  2.7e-3   0.18   0.01   0.15   0.31   0.46   0.61   4457    1.0
group[28,1]   0.29  2.5e-3   0.17   0.01   0.14   0.29   0.44   0.59   4729    1.0
group[29,1]   0.29  2.5e-3   0.18   0.01   0.14   0.29   0.44   0.59   4811    1.0
group[30,1]   0.32  2.9e-3   0.19   0.01   0.15   0.32   0.48   0.65   4322    1.0
group[31,1]   0.32  2.7e-3   0.19   0.02   0.16   0.32   0.48   0.64   4944    1.0
group[32,1]   0.34  2.9e-3    0.2   0.02   0.16   0.34   0.51   0.68   4884    1.0
group[33,1]   0.29  2.4e-3   0.17   0.01   0.15   0.29   0.44   0.58   4944    1.0
group[34,1]   0.27  2.3e-3   0.16   0.02   0.14   0.28    0.4   0.54   4586    1.0
group[35,1]   0.29  2.7e-3   0.17   0.01   0.14   0.29   0.44   0.58   4135    1.0
group[36,1]    0.3  2.6e-3   0.18   0.01   0.14    0.3   0.45   0.59   4502    1.0
group[37,1]   0.33  2.4e-3   0.19   0.01   0.16   0.32   0.49   0.65   6215    1.0
group[38,1]   0.31  2.7e-3   0.18   0.01   0.16   0.31   0.47   0.62   4543    1.0
group[39,1]   0.29  2.3e-3   0.17   0.02   0.14   0.29   0.43   0.58   5457    1.0
group[40,1]    0.3  2.5e-3   0.18   0.01   0.15   0.31   0.46    0.6   5111    1.0
group[41,1]   0.33  2.9e-3   0.19   0.01   0.16   0.32    0.5   0.66   4473    1.0
group[42,1]   0.33  2.8e-3   0.19   0.01   0.16   0.34    0.5   0.66   4930    1.0
group[43,1]   0.25  2.2e-3   0.15   0.01   0.12   0.25   0.37    0.5   4738    1.0
group[44,1]   0.33  2.7e-3   0.19   0.02   0.16   0.33   0.49   0.65   4896    1.0
group[45,1]   0.32  2.9e-3   0.19   0.01   0.16   0.32   0.48   0.63   4255    1.0
group[46,1]   0.33  3.0e-3   0.19   0.02   0.17   0.32   0.49   0.65   4007    1.0
group[47,1]    0.3  2.6e-3   0.17   0.02   0.15   0.29   0.45    0.6   4622    1.0
group[48,1]   0.35  2.9e-3    0.2   0.02   0.17   0.35   0.52   0.69   4940    1.0
group[49,1]   0.25  2.2e-3   0.15   0.01   0.12   0.24   0.37    0.5   4587    1.0
group[50,1]   0.31  2.8e-3   0.18   0.01   0.15   0.31   0.47   0.62   4331    1.0
group[51,1]   0.25  2.2e-3   0.15   0.01   0.13   0.25   0.38   0.51   4546    1.0
group[52,1]    0.3  2.6e-3   0.17   0.02   0.16    0.3   0.44   0.59   4195    1.0
group[53,1]   0.31  2.7e-3   0.18   0.02   0.15   0.31   0.47   0.62   4555    1.0
group[54,1]   0.28  2.5e-3   0.17   0.02   0.14   0.27   0.42   0.57   4427    1.0
group[55,1]   0.31  2.8e-3   0.18   0.01   0.16   0.31   0.46   0.61   4002    1.0
group[56,1]   0.28  2.5e-3   0.16   0.01   0.14   0.27   0.41   0.55   4006    1.0
group[57,1]   0.39  3.4e-3   0.23   0.02   0.19   0.39   0.58   0.77   4461    1.0
group[58,1]   0.29  2.5e-3   0.17   0.02   0.15   0.29   0.44   0.58   4700    1.0
group[59,1]   0.31  2.6e-3   0.18   0.02   0.16   0.31   0.46   0.62   4811    1.0
group[60,1]    0.3  2.5e-3   0.17   0.01   0.15   0.29   0.44   0.59   4581    1.0
group[61,1]   0.35  3.2e-3   0.21   0.02   0.17   0.35   0.53   0.69   4253    1.0
group[62,1]   0.29  2.4e-3   0.17   0.01   0.14   0.28   0.43   0.58   5007    1.0
group[63,1]    0.3  2.3e-3   0.17   0.01   0.15   0.29   0.44    0.6   5757    1.0
group[64,1]   0.33  2.8e-3   0.19   0.02   0.16   0.32   0.48   0.65   4508    1.0
group[65,1]   0.32  3.1e-3   0.19   0.01   0.15   0.32   0.47   0.63   3528    1.0
group[66,1]   0.31  2.7e-3   0.18   0.01   0.16   0.31   0.46   0.61   4477    1.0
group[67,1]   0.33  2.9e-3   0.19   0.02   0.16   0.33   0.49   0.65   4293    1.0
group[68,1]   0.32  2.9e-3   0.19   0.02   0.16   0.32   0.48   0.64   4092    1.0
group[69,1]   0.29  2.5e-3   0.17   0.02   0.14   0.29   0.43   0.57   4485    1.0
group[70,1]   0.36  3.0e-3   0.21   0.02   0.18   0.36   0.54   0.72   5046    1.0
group[71,1]   0.25  2.2e-3   0.15   0.01   0.12   0.25   0.37    0.5   4485    1.0
group[72,1]   0.25  1.9e-3   0.15   0.01   0.12   0.25   0.38    0.5   5822    1.0
group[73,1]    0.3  2.5e-3   0.17   0.02   0.15   0.29   0.44   0.59   4565    1.0
group[74,1]   0.32  2.6e-3   0.18   0.02   0.16   0.32   0.47   0.62   4746    1.0
group[75,1]   0.25  2.2e-3   0.15   0.01   0.12   0.24   0.37    0.5   4387    1.0
group[76,1]   0.34  2.9e-3    0.2   0.02   0.16   0.33   0.51   0.68   4885    1.0
group[77,1]   0.33  2.7e-3    0.2   0.01   0.16   0.33    0.5   0.67   5608    1.0
group[78,1]    0.3  2.6e-3   0.18   0.01   0.15    0.3   0.45   0.59   4512    1.0
group[79,1]   0.32  2.8e-3   0.19   0.01   0.16   0.32   0.48   0.64   4559    1.0
group[80,1]   0.26  2.2e-3   0.15   0.01   0.13   0.26   0.39   0.52   4488    1.0
group[81,1]    0.3  2.6e-3   0.17   0.02   0.15   0.29   0.44   0.59   4442    1.0
group[82,1]   0.32  2.7e-3   0.19   0.02   0.16   0.32   0.48   0.63   4677    1.0
group[83,1]   0.29  2.6e-3   0.17   0.01   0.15   0.29   0.44   0.58   4451    1.0
group[84,1]   0.32  2.5e-3   0.18   0.02   0.16   0.32   0.47   0.62   5309    1.0
group[85,1]   0.27  2.2e-3   0.16   0.02   0.14   0.27    0.4   0.54   5271    1.0
group[86,1]   0.27  2.1e-3   0.16   0.02   0.14   0.27    0.4   0.54   5271    1.0
group[87,1]   0.27  2.3e-3   0.15   0.02   0.13   0.27    0.4   0.53   4713    1.0
group[88,1]   0.31  2.6e-3   0.18   0.02   0.15   0.31   0.46    0.6   4535    1.0
group[89,1]   0.31  2.7e-3   0.18   0.02   0.16   0.31   0.46    0.6   4314    1.0
group[90,1]   0.38  3.2e-3   0.22   0.02   0.19   0.38   0.57   0.74   4540    1.0
group[91,1]   0.29  2.4e-3   0.17   0.01   0.14   0.28   0.43   0.58   4958    1.0
group[92,1]   0.33  2.8e-3   0.19   0.01   0.17   0.34    0.5   0.65   4815    1.0
group[93,1]   0.36  2.9e-3   0.21   0.02   0.17   0.36   0.53    0.7   5141    1.0
group[94,1]   0.27  2.2e-3   0.15   0.02   0.14   0.27    0.4   0.53   5093    1.0
group[95,1]   0.32  2.8e-3   0.18   0.01   0.16   0.32   0.47   0.62   4335    1.0
group[96,1]   0.33  2.8e-3   0.19   0.02   0.16   0.33    0.5   0.65   4617    1.0
group[97,1]   0.31  2.5e-3   0.18   0.02   0.15   0.31   0.46   0.61   5255    1.0
group[98,1]   0.27  2.4e-3   0.16   0.01   0.13   0.27   0.41   0.55   4504    1.0
group[99,1]   0.28  2.5e-3   0.16   0.01   0.14   0.28   0.42   0.56   4318    1.0
lp__        -602.8    8.82  54.14 -701.7 -644.1 -601.7 -562.1 -505.6     38   1.12

Samples were drawn using NUTS at Sun Oct 21 00:35:14 2018.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).
In [271]:
def em_algo(maxi_model, expectation_model, data, num_iter=10):
    print("Original beta: {}".format(data['beta']))
    # Random initialization for the parameters
    data['beta'] = np.random.normal(0, 20, len(data['beta']))
    for i in range(num_iter):
        print("Iteration {}; beta: {}".format(i, data['beta']))
        # E step: sample the latent group assignments given the current beta
        efit = expectation_model.sampling(data)
        # Use the posterior mean of the sampled simplexes as the soft group assignment
        data['group'] = np.mean(efit.extract()['group'], axis=0)

        # M step: maximize over beta (and sigma) given those soft assignments
        mfit = maxi_model.optimizing(data)
        data['beta'] = mfit['beta']
em_algo(maxi_model, expectation_model, fd)
Original beta: [  34.81878147 -112.64493292   46.28279639]
Iteration 0; beta: [ -3.59344058 -32.86852646  21.8326248 ]
Iteration 1; beta: [ 20.9338156  -86.47885418  46.93093845]
WARNING:pystan:n_eff / iter for parameter lp__ is 0.0006772846691974476!
WARNING:pystan:n_eff / iter below 0.001 indicates that the effective sample size has likely been overestimated
WARNING:pystan:Rhat for parameter sigma is 1.5772541462411396!
WARNING:pystan:Rhat for parameter lp__ is 2.2929696460277094!
WARNING:pystan:Rhat above 1.1 or below 0.9 indicates that the chains very likely have not mixed
WARNING:pystan:33 of 4000 iterations ended with a divergence (0.825%).
WARNING:pystan:Try running with adapt_delta larger than 0.8 to remove the divergences.
WARNING:pystan:954 of 4000 iterations saturated the maximum tree depth of 10 (23.85%)
WARNING:pystan:Run again with max_treedepth larger than 10 to avoid saturation
WARNING:pystan:Chain 1: E-BFMI = 0.08636177088090548
WARNING:pystan:Chain 2: E-BFMI = 0.1183776457445831
WARNING:pystan:Chain 3: E-BFMI = 0.024184527416314404
WARNING:pystan:Chain 4: E-BFMI = 0.0377874880074901
WARNING:pystan:E-BFMI below 0.2 indicates you may need to reparameterize your model
Iteration 2; beta: [ 20.91314216 -87.35267788  46.98427761]
WARNING:pystan:Rhat for parameter sigma is 1.1662902113775195!
WARNING:pystan:Rhat for parameter lp__ is 1.2119837156258804!
WARNING:pystan:Rhat above 1.1 or below 0.9 indicates that the chains very likely have not mixed
WARNING:pystan:2 of 4000 iterations ended with a divergence (0.05%).
WARNING:pystan:Try running with adapt_delta larger than 0.8 to remove the divergences.
WARNING:pystan:Chain 1: E-BFMI = 0.04755913631432413
WARNING:pystan:Chain 2: E-BFMI = 0.03675889809161817
WARNING:pystan:Chain 3: E-BFMI = 0.06411132073869108
WARNING:pystan:Chain 4: E-BFMI = 0.03561881104416588
WARNING:pystan:E-BFMI below 0.2 indicates you may need to reparameterize your model
Iteration 3; beta: [ 21.49493122 -88.13704701  46.3259473 ]
WARNING:pystan:Rhat for parameter lp__ is 1.1213335016555586!
WARNING:pystan:Rhat above 1.1 or below 0.9 indicates that the chains very likely have not mixed
WARNING:pystan:Chain 1: E-BFMI = 0.023793654865949664
WARNING:pystan:Chain 2: E-BFMI = 0.12081905558602336
WARNING:pystan:Chain 3: E-BFMI = 0.06737230575837842
WARNING:pystan:Chain 4: E-BFMI = 0.04131102555268011
WARNING:pystan:E-BFMI below 0.2 indicates you may need to reparameterize your model
Iteration 4; beta: [ 26.75281521 -89.2801291   39.89030341]
WARNING:pystan:Chain 1: E-BFMI = 0.07256835650813424
WARNING:pystan:Chain 2: E-BFMI = 0.08232131628476169
WARNING:pystan:Chain 3: E-BFMI = 0.0956485940248879
WARNING:pystan:Chain 4: E-BFMI = 0.07396086232764862
WARNING:pystan:E-BFMI below 0.2 indicates you may need to reparameterize your model
Iteration 5; beta: [ 21.45265578 -90.37360984  45.76797409]
WARNING:pystan:Rhat for parameter sigma is 1.2106669582831677!
WARNING:pystan:Rhat for parameter lp__ is 1.3455533516161826!
WARNING:pystan:Rhat above 1.1 or below 0.9 indicates that the chains very likely have not mixed
WARNING:pystan:1 of 4000 iterations ended with a divergence (0.025%).
WARNING:pystan:Try running with adapt_delta larger than 0.8 to remove the divergences.
WARNING:pystan:Chain 1: E-BFMI = 0.08093989430990671
WARNING:pystan:Chain 2: E-BFMI = 0.04485334291887987
WARNING:pystan:Chain 3: E-BFMI = 0.0402337144623449
WARNING:pystan:Chain 4: E-BFMI = 0.09438939951333641
WARNING:pystan:E-BFMI below 0.2 indicates you may need to reparameterize your model
Iteration 6; beta: [ 22.3658989  -91.13397713  44.56188515]
WARNING:pystan:Rhat for parameter sigma is 1.1687867446346574!
WARNING:pystan:Rhat for parameter lp__ is 1.1924936689298336!
WARNING:pystan:Rhat above 1.1 or below 0.9 indicates that the chains very likely have not mixed
WARNING:pystan:Chain 1: E-BFMI = 0.0779069583773524
WARNING:pystan:Chain 2: E-BFMI = 0.05118534236948631
WARNING:pystan:Chain 3: E-BFMI = 0.05005323025048814
WARNING:pystan:Chain 4: E-BFMI = 0.056202162274556336
WARNING:pystan:E-BFMI below 0.2 indicates you may need to reparameterize your model
Iteration 7; beta: [ 22.18616077 -91.89531043  44.62643986]
WARNING:pystan:Chain 1: E-BFMI = 0.11152796900714382
WARNING:pystan:Chain 2: E-BFMI = 0.04707671104563218
WARNING:pystan:Chain 3: E-BFMI = 0.11882304910974623
WARNING:pystan:Chain 4: E-BFMI = 0.03660767458667997
WARNING:pystan:E-BFMI below 0.2 indicates you may need to reparameterize your model
Iteration 8; beta: [ 22.25029701 -92.60420504  44.40234528]
WARNING:pystan:Rhat for parameter sigma is 1.1193038531183725!
WARNING:pystan:Rhat for parameter lp__ is 1.2071429213007652!
WARNING:pystan:Rhat above 1.1 or below 0.9 indicates that the chains very likely have not mixed
WARNING:pystan:1000 of 4000 iterations saturated the maximum tree depth of 10 (25.0%)
WARNING:pystan:Run again with max_treedepth larger than 10 to avoid saturation
WARNING:pystan:Chain 1: E-BFMI = 0.07468723493614429
WARNING:pystan:Chain 2: E-BFMI = 0.028783557701503534
WARNING:pystan:Chain 3: E-BFMI = 0.01687177261674604
WARNING:pystan:Chain 4: E-BFMI = 0.09237032602968896
WARNING:pystan:E-BFMI below 0.2 indicates you may need to reparameterize your model
Iteration 9; beta: [ 22.46913085 -93.43458416  44.09826419]
WARNING:pystan:Rhat for parameter sigma is 1.195771019355573!
WARNING:pystan:Rhat for parameter lp__ is 1.2822946971097566!
WARNING:pystan:Rhat above 1.1 or below 0.9 indicates that the chains very likely have not mixed
WARNING:pystan:Chain 1: E-BFMI = 0.06588655593571956
WARNING:pystan:Chain 2: E-BFMI = 0.08392865446456031
WARNING:pystan:Chain 3: E-BFMI = 0.10500837409569222
WARNING:pystan:Chain 4: E-BFMI = 0.04798908203005176
WARNING:pystan:E-BFMI below 0.2 indicates you may need to reparameterize your model