{ "cells": [ { "cell_type": "markdown", "id": "4432747a-903e-450d-a9c5-3173aff8c4ac", "metadata": {}, "source": [ "# Add arbitrary Stan code to an HBOMS model\n", "\n", "One goal of HBOMS is to produce human-readable Stan models that can be further developed and customized by the user. However, modifying an HBOMS-generated model makes it hard to change the basic parameter structure, as the modifications would have to be re-applied after each change. Therefore, the user can also include \"plugin\" Stan code in the HBOMS model definition. This plugin code is added as-is at the end of the code blocks (or beginning for the `functions` block). \n", "\n", "To use this functionality, the user must provide a Python dictionary with keys equal to the names of the code blocks. For example\n", "\n", "```python\n", "plugin_code = {\n", " \"data\" : \"real SomeDataPoint;\",\n", " \"model\" : \"SomeDataPoint ~ exponential(lambda);\"\n", "}\n", "```\n", "\n", "will add the declaration for `SomeDataPoint` to the `data` block, and a \"sampling statement\" to the model block. This sampling statement uses a parameter `lambda` and so it is important that this is defined in the HBOMS model.\n", "\n", "The `plugin_code` dictionary is then passed as a keyword argument in the `HbomsModel` initiation\n", "\n", "```python\n", "model = hboms.HbomsModel(\n", " name = \"my_model\",\n", " # other arguments...\n", " plugin_code = plugin_code\n", ")\n", "```\n", "\n", "\n", "## Example: SEIR model with serial interval data\n", "\n", "As an example, we look at an epidemiological model of an infectious disease in which we have time series data, but also additional observations in the form of observed serial intervals (the time between symptom onset of a primary and secondary patient). The epidemiological model will be the SEIR model (susceptible, exposed, infected, recovered), and we will pretend that the serial interval is the same as the generation interval (the time between infections). 
\n", "\n", "The ODEs for the SEIR model are given by \n", "\n", "\\begin{equation}\n", "\\begin{split}\n", "\\frac{dS}{dt} &= -\\beta SI \\\\\n", "\\frac{dE}{dt} &= \\beta SI - \\alpha E \\\\\n", "\\frac{dI}{dt} &= \\alpha E - \\gamma I\n", "\\end{split}\n", "\\end{equation}\n", "\n", "and initial condition $S(0) = 1-\\epsilon$, $E(0) = 0$, $I(0) = \\epsilon$.\n", "We'll assume that the observed data at time `t` follows a Poisson distribution with mean `SampleSize * I(t)`.\n", "Let's start with creating the necessary ingredients of an HBOMS model." ] }, { "cell_type": "code", "execution_count": 11, "id": "c7056445-e6e5-4a60-a6c4-b966cbe844f1", "metadata": {}, "outputs": [], "source": [ "import hboms\n", "\n", "params = [\n", " hboms.Parameter(\"beta\", 1.0, \"random\", scale=0.1), \n", " hboms.Parameter(\"alpha\", 1.0, \"random\", scale=0.1),\n", " hboms.Parameter(\"gamma\", 0.5, \"random\", scale=0.1),\n", " hboms.Parameter(\"SampleSize\", 1e2, \"const\")\n", "]\n", "\n", "init = \"\"\"\n", "S_0 = 0.998;\n", "E_0 = 0.0;\n", "I_0 = 0.002; // we're lazy and use a hard-coded epsilon\n", "\"\"\"\n", "\n", "odes = \"\"\"\n", "ddt_S = -beta * S * I;\n", "ddt_E = beta * S * I - alpha * E;\n", "ddt_I = alpha * E - gamma * I;\n", "\"\"\"\n", "\n", "state = [hboms.Variable(\"S\"), hboms.Variable(\"E\"), hboms.Variable(\"I\")]\n", "obs = [hboms.Observation(\"Counts\", data_type=\"int\")]\n", "dists = [hboms.StanDist(\"poisson\", \"Counts\", [\"I * SampleSize\"])]" ] }, { "cell_type": "markdown", "id": "637962a3-d236-4ed3-97f1-c7ece398974c", "metadata": {}, "source": [ "As additional data, we have a number of observed generation intervals $T_G$, which in the SEIR model are hypoexponentially distributed with rate parameters $\\alpha$ and $\\gamma$. A hypoexponential random variable is the sum of two exponentially distributed random variables. Hence it is like the Erlang distribution, but the rates can differ between the summands.\n", "In e.g. 
[Roberts and Heesterbeek](https://doi.org/10.1007/s00285-007-0112-8), we find an explicit formula for the PDF\n", "\n", "$$ f(t) = \\frac{\\alpha\\gamma}{\\gamma-\\alpha} (e^{-\\alpha t} - e^{-\\gamma t}) $$\n", "\n", "This equation is a bit annoying for estimating $\\alpha$ and $\\gamma$, though, because it has a removable singularity at $\\alpha = \\gamma$. Therefore, we use a somewhat less explicit, but equivalent, equation\n", "\n", "$$ f(t) = -(1,0) \\exp(\\Theta t) \\Theta (1,1)'\\quad \\mbox{where}\\quad \\Theta = \\left(\\begin{array}{cc}-\\alpha & \\alpha \\\\ 0 &-\\gamma \\end{array}\\right) $$\n", "and $\\exp$ is the matrix exponential. We trust that the implementation of the matrix exponential in Stan is efficient and accurate.\n", "\n", "To implement this, we will add some plugin code to the `functions`, `data` and `model` blocks.\n", "We define the `hypoexp` distribution in the `functions` block, declare a data array (and its size) in the `data` block and write a \"sampling statement\" for the `model` block. For defining the data array, we can make use of the variable `R`, which is the number of units.\n", "For the sampling statements, we'll make use of the parameters $\\alpha$ and $\\gamma$.\n", "\n", ":::{caution}\n", "As these are random parameters, we will have to index them. If you were to change $\\alpha$ or $\\gamma$ to fixed parameters, you would have to make sure the indexing is removed in the plugin code.\n", ":::\n", "\n", "With this plugin code in place, we can finally initialize the HBOMS model." 
] }, { "cell_type": "code", "execution_count": 2, "id": "9df0ecf3-65a8-4096-ab60-1d6458808f0e", "metadata": {}, "outputs": [], "source": [ "plugin_code_functions = \"\"\"\n", "real hypoexp_lpdf(real t, real a, real b) {\n", " matrix[2, 2] Theta = [[-a, a], [0,-b]];\n", " return log(sum((-matrix_exp(Theta .* t) * Theta)[1,:]));\n", "}\n", "\"\"\"\n", "\n", "plugin_code_data = \"\"\"\n", "int NumIntervalSamples;\n", "array[R, NumIntervalSamples] real IntervalSamples;\n", "\"\"\"\n", "\n", "plugin_code_model = \"\"\"\n", "for ( r in 1:R ) {\n", " for ( n in 1:NumIntervalSamples ) {\n", " IntervalSamples[r,n] ~ hypoexp(alpha[r], gamma[r]);\n", " }\n", "}\n", "\"\"\"\n", "\n", "plugin_code = {\n", " \"functions\" : plugin_code_functions,\n", " \"data\" : plugin_code_data,\n", " \"model\" : plugin_code_model\n", "}\n", "\n", "model = hboms.HbomsModel(\n", " \"plugin_model\",\n", " state,\n", " odes,\n", " init,\n", " params,\n", " obs,\n", " dists,\n", " plugin_code=plugin_code\n", ")" ] }, { "cell_type": "markdown", "id": "b9fd9796-8dc3-4e00-8f96-4fbc6d723293", "metadata": {}, "source": [ "Using the `show_stan_model` utility, we can have a look at what HBOMS has generated.\n", "Notice that the plugin function is defined before the other functions. The motivation is that we might want to define functions that are used in the ODE model definition. The other plugin code is added at the end of the code blocks." ] }, { "cell_type": "code", "execution_count": 3, "id": "9a157227-3a47-4ce7-a35d-ec1e89bd46ed", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "
\n", "functions {\n",
" /* User-provided plugin code */\n",
" real hypoexp_lpdf(real t, real a, real b) {\n",
" matrix[2, 2] Theta = [[-a, a], [0,-b]];\n",
" return log(sum((-matrix_exp(Theta .* t) * Theta)[1,:]));\n",
" }\n",
" /* vector field */\n",
" vector ode_fun(real t, vector state, real beta, real alpha, real gamma) {\n",
" /* unpack the state variables */\n",
" real S = state[1];\n",
" real E = state[2];\n",
" real I = state[3];\n",
" /* declare derivatives of state variables */\n",
" real ddt_S, ddt_E, ddt_I;\n",
" /* user-defined ODEs */\n",
" ddt_S = -beta * S * I;\n",
" ddt_E = beta * S * I - alpha * E;\n",
" ddt_I = alpha * E - gamma * I;\n",
" /* return literal vector with derivatives */\n",
" return [ddt_S, ddt_E, ddt_I]';\n",
" }\n",
" /* initial condition */\n",
" vector gen_init() {\n",
" /* declare initial variables */\n",
" real S_0, E_0, I_0;\n",
" /* user-defined initial condition */\n",
" S_0 = 0.998;\n",
" E_0 = 0.0;\n",
" I_0 = 0.002;\n",
" /* return literal vector with initival state variables */\n",
" return [S_0, E_0, I_0]';\n",
" }\n",
" /* IVP solver function */\n",
" array[] vector solve_ivp(data array[] real Time, real beta, real alpha, real gamma) {\n",
" int N = num_elements(Time); /* number of time points */\n",
" vector[3] init_state = gen_init(); /* generate initial state */\n",
" array[N] vector[3] sol; /* allocate space for solution */\n",
" sol[1:N, 1:3] = ode_rk45(ode_fun, init_state, 0.0, Time[1:N], beta, alpha, gamma); /* solve initial value problem */\n",
" return sol; /* return the solution */\n",
" }\n",
" /* auxiliary function for map_rect */\n",
" vector map_rect_helper_fun(vector ppar, vector upar, data array[] real rdat, data array[] int idat) {\n",
" int N = idat[1]; /* number of time points */\n",
" /* solve initial value problem */\n",
" array[N] vector[3] sol = solve_ivp(rdat[1:N + 0], upar[1], upar[2], upar[3]);\n",
" /* concatenate solution into a vector */\n",
" vector[3 * N] res;\n",
" for ( n in 1:N ) res[(n - 1) * 3 + 1:n * 3] = sol[n];\n",
" return res;\n",
" }\n",
" /* log-likelihood function */\n",
" vector loglik_fun(int Counts, vector state, data real SampleSize) {\n",
" /* unpack required state variables */\n",
" real I = state[3];\n",
" /* declare log-lik variables */\n",
" real ll_Counts = 0.0;\n",
" /* user-defined log-likelihood */\n",
" ll_Counts = poisson_lpmf(Counts | I * SampleSize); /* poisson log-likelihood */\n",
" return [ll_Counts]';\n",
" }\n",
" /* rng function */\n",
" int Counts_rng(vector state, data real SampleSize) {\n",
" /* unpack required state variables */\n",
" real I = state[3];\n",
" /* declare variable to-be-returned */\n",
" int Counts;\n",
" /* user-defined sampler */\n",
" Counts = poisson_rng(I * SampleSize); /* random poisson sample */\n",
" return Counts;\n",
" }\n",
"}\n",
"\n",
"data {\n",
" int<lower=0> R; /* number of units */\n",
" array[R] int<lower=0> N; /* number of observations per unit */\n",
" array[R, max(N)] real<lower=0.0> Time; /* observation times */\n",
" /* observations */\n",
" array[R, max(N)] int Counts;\n",
" int<lower=0> NSim; /* number of simulation time points */\n",
" array[R, NSim] real<lower=0.0> TimeSim; /* simulation time points */\n",
" /* constants */\n",
" real<lower=0.0> SampleSize;\n",
" /* User-provided plugin code */\n",
" int NumIntervalSamples;\n",
" array[R, NumIntervalSamples] real IntervalSamples;\n",
"}\n",
"\n",
"transformed data {\n",
" /* declarations */\n",
" array[R, max(N)] real rdats;\n",
" array[R, 1] int idats; /* integer data */\n",
" /* definitions */\n",
" rdats[:, 1:max(N)] = Time;\n",
" idats[:, 1] = N;\n",
"}\n",
"\n",
"parameters {\n",
" /* individual parameters (and their hyper-parameters) */\n",
" array[R] real<lower=0.0> beta;\n",
" real loc_beta;\n",
" real<lower=0.0> scale_beta;\n",
" array[R] real<lower=0.0> alpha;\n",
" real loc_alpha;\n",
" real<lower=0.0> scale_alpha;\n",
" array[R] real<lower=0.0> gamma;\n",
" real loc_gamma;\n",
" real<lower=0.0> scale_gamma;\n",
"}\n",
"\n",
"transformed parameters {\n",
" array[R] vector[3] upars; /* prepare data structure for map_rect */\n",
" vector[0] ppar; /* no population parameters required for map_rect */\n",
" /* assign unit-parameters to array of vectors */\n",
" upars[:, 1] = beta;\n",
" upars[:, 2] = alpha;\n",
" upars[:, 3] = gamma;\n",
"}\n",
"\n",
"model {\n",
" /* solve ODEs in parallel */\n",
" vector[sum(N) * 3] concat_res = map_rect(map_rect_helper_fun, ppar, upars, rdats, idats);\n",
" /* compute log-likelihood of observations */\n",
" for ( r in 1:R ) {\n",
" for ( n in 1:N[r] ) {\n",
" /* extract state */\n",
" int idx = 3 * (sum(N[:r - 1]) + n - 1) + 1;\n",
" vector[3] state = concat_res[idx:idx + 2];\n",
" /* compute likelihood of observation given state */\n",
" target += loglik_fun(Counts[r, n], state, SampleSize);\n",
" }\n",
" }\n",
" /* prior */\n",
" beta ~ lognormal(loc_beta, scale_beta);\n",
" loc_beta ~ student_t(3.0, 0.0, 2.5);\n",
" scale_beta ~ student_t(3.0, 0.0, 2.5);\n",
" alpha ~ lognormal(loc_alpha, scale_alpha);\n",
" loc_alpha ~ student_t(3.0, 0.0, 2.5);\n",
" scale_alpha ~ student_t(3.0, 0.0, 2.5);\n",
" gamma ~ lognormal(loc_gamma, scale_gamma);\n",
" loc_gamma ~ student_t(3.0, 0.0, 2.5);\n",
" scale_gamma ~ student_t(3.0, 0.0, 2.5);\n",
" /* User-provided plugin code */\n",
" for ( r in 1:R ) {\n",
" for ( n in 1:NumIntervalSamples ) {\n",
" IntervalSamples[r,n] ~ hypoexp(alpha[r], gamma[r]);\n",
" }\n",
" }\n",
"}\n",
"\n",
"generated quantities {\n",
" array[R, NSim] real S_sim, E_sim, I_sim;\n",
" array[R, max(N)] int Counts_sim;\n",
" vector[sum(N) * 1] log_lik; /* vector of log-likelihoods for model comparison */\n",
" for ( r in 1:R ) {\n",
" array[NSim] vector[3] u_sim = solve_ivp(TimeSim[r], beta[r], alpha[r], gamma[r]); /* solve ODEs at simulation times */\n",
" array[N[r]] vector[3] u_sim_obs = solve_ivp(Time[r, 1:N[r]], beta[r], alpha[r], gamma[r]); /* solve ODEs at observation times */\n",
" for ( n in 1:NSim ) {\n",
" S_sim[r, n] = u_sim[n, 1];\n",
" E_sim[r, n] = u_sim[n, 2];\n",
" I_sim[r, n] = u_sim[n, 3];\n",
" }\n",
" for ( n in 1:N[r] ) {\n",
" int idx = 1 * (sum(N[:r - 1]) + n - 1) + 1;\n",
" log_lik[idx:idx + 0] = loglik_fun(Counts[r, n], u_sim_obs[n], SampleSize); /* record log-likelihood of each observation */\n",
" Counts_sim[r, n] = Counts_rng(u_sim_obs[n], SampleSize); /* simulate data at observation times */\n",
" }\n",
" }\n",
"}\n",
"