JAGS の「カウントプロセス」形式でパラメトリック生存モデルを表現する 質問する

JAGS の「カウントプロセス」形式でパラメトリック生存モデルを表現する 質問する

私は、JAGS で時間とともに変化する共変量を考慮した生存モデルを構築しようとしています。パラメトリック モデルにしたいのですが、たとえば生存がワイブル分布に従うと仮定します (ただし、ハザードが変化できるようにしたいので、指数関数では単純すぎます)。つまり、これは基本的に、flexsurvパラメトリック モデルで時間とともに変化する共変量を考慮したパッケージで実行できることのベイジアン バージョンです。

したがって、私はデータを「カウントプロセス」形式で入力できるようにしたいと考えています。各被験者には複数の行があり、それぞれが共変量が一定であった時間間隔に対応しています(このPDFでまたはこここれは、またはパッケージ(start, stop]で許可される定式化です。survivalflexurv

残念ながら、JAGS で生存分析を実行する方法の説明では、被験者ごとに 1 行を想定しているようです。

このより単純なアプローチを採用し、カウント プロセス形式に拡張しようとしましたが、モデルは分布を正しく推定しません。

失敗した試み:

ここに例があります。まず、データを生成します。

library('dplyr')
library('survival')

## Make the Data: -----
set.seed(3)
n_sub <- 1000
current_date <- 365*2

true_shape <- 2
true_scale <- 365

dat <- data_frame(person = 1:n_sub,
                  true_duration = rweibull(n = n_sub, shape = true_shape, scale = true_scale),
                  person_start_time = runif(n_sub, min= 0, max= true_scale*2),
                  person_censored = (person_start_time + true_duration) > current_date,
                  person_duration = ifelse(person_censored, current_date - person_start_time, true_duration)
)

  person person_start_time person_censored person_duration
   (int)             (dbl)           (lgl)           (dbl)
1      1          11.81416           FALSE        487.4553
2      2         114.20900           FALSE        168.7674
3      3          75.34220           FALSE        356.6298
4      4         339.98225           FALSE        385.5119
5      5         389.23357           FALSE        259.9791
6      6         253.71067           FALSE        259.0032
7      7         419.52305            TRUE        310.4770

次に、データを被験者ごとに 2 つの観察に分割します。各被験者を時間 = 300 で分割します (被験者が時間 = 300 に到達しなかった場合、その被験者には 1 つの観察しか得られません)。

## Split into multiple observations per person: --------
cens_point <- 300 # <----- try changing to 0 for no split; if so, model correctly estimates
dat_split <- dat %>%
  group_by(person) %>%
  do(data_frame(
    split = ifelse(.$person_duration > cens_point, cens_point, .$person_duration),
    START = c(0, split[1]),
    END = c(split[1], .$person_duration),
    TINTERVAL = c(split[1], .$person_duration - split[1]), 
    CENS = c(ifelse(.$person_duration > cens_point, 1, .$person_censored), .$person_censored), # <— edited original post here due to bug; but problem still present when fixing bug
    TINTERVAL_CENS = ifelse(CENS, NA, TINTERVAL),
    END_CENS = ifelse(CENS, NA, END)
  )) %>%
  filter(TINTERVAL != 0)

  person    split START      END TINTERVAL CENS TINTERVAL_CENS
   (int)    (dbl) (dbl)    (dbl)     (dbl) (dbl)        (dbl)
1      1 300.0000     0 300.0000 300.00000     1           NA
2      1 300.0000   300 487.4553 187.45530     0    187.45530
3      2 168.7674     0 168.7674 168.76738     1           NA
4      3 300.0000     0 300.0000 300.00000     1           NA
5      3 300.0000   300 356.6298  56.62979     0     56.62979
6      4 300.0000     0 300.0000 300.00000     1           NA

これで、JAGS モデルを設定できます。

## Set-Up JAGS Model -------
dat_jags <- as.list(dat_split)
dat_jags$N <- length(dat_jags$TINTERVAL)
inits <- replicate(n = 2, simplify = FALSE, expr = {
       list(TINTERVAL_CENS = with(dat_jags, ifelse(CENS, TINTERVAL + 1, NA)),
            END_CENS = with(dat_jags, ifelse(CENS, END + 1, NA)) )
})

model_string <- 
"
  model {
    # set priors on reparameterized version, as suggested
    # here: https://sourceforge.net/p/mcmc-jags/discussion/610036/thread/d5249e71/?limit=25#8c3b
    log_a ~ dnorm(0, .001) 
    log(a) <- log_a
    log_b ~ dnorm(0, .001)
    log(b) <- log_b
    nu <- a
    lambda <- (1/b)^a
    
    for (i in 1:N) {
      # Estimate Subject-Durations:
      CENS[i] ~ dinterval(TINTERVAL_CENS[i], TINTERVAL[i])
      TINTERVAL_CENS[i] ~ dweibull( nu, lambda )
    }
  }
"

library('runjags')
param_monitors <- c('a', 'b', 'nu', 'lambda') 
fit_jags <- run.jags(model = model_string,
                     burnin = 1000, sample = 1000, 
                     monitor = param_monitors,
                     n.chains = 2, data = dat_jags, inits = inits)
# estimates:
fit_jags
# actual:
c(a=true_shape, b=true_scale)

分割点がどこにあるかによって、モデルは基礎となる分布の非常に異なるパラメータを推定します。データがカウント プロセス形式に分割されていない場合にのみ、パラメータが適切に取得されます。この種の問題では、この方法でデータをフォーマットするのは適切ではないようです。

もし私が仮定を見落としていて、私の問題が JAGS に関係するものではなく、むしろ私が問題を定式化する方法に関係するものである場合、提案は大歓迎です。時間によって変化する共変量をパラメトリック生存モデルで使用できない (そして、一定のハザードを仮定し、実際には基礎となる分布を推定しない Cox モデルなどのモデルでのみ使用できる) ことに絶望しているかもしれませんが、上で述べたように、flexsurvregR のパッケージはパラメトリック モデルでの定式化に対応しています(start, stop]

このようなモデルを別の言語(たとえば、JAGS ではなく STAN)で構築する方法を知っている方がいらっしゃいましたら、それも教えていただけると幸いです。

編集:

Chris Jackson は電子メールで役立つアドバイスを提供しています。

JAGS の切り捨てには T() 構造が必要だと思います。基本的に、人が生存しているが共変量が一定である各期間 (t[i], t[i+1]) について、生存時間は期間の開始時に左切り捨てされ、終了時には右打ち切りされる可能性があります。したがって、次のように記述します。y[i] ~ dweib(shape, scale[i])T(t[i], )

私はこの提案を次のように実装してみました:

model {
  # same as before
  log_a ~ dnorm(0, .01) 
  log(a) <- log_a
  log_b ~ dnorm(0, .01)
  log(b) <- log_b
  nu <- a
  lambda <- (1/b)^a
  
  for (i in 1:N) {
    # modified to include left-truncation
    CENS[i] ~ dinterval(END_CENS[i], END[i])
    END_CENS[i] ~ dweibull( nu, lambda )T(START[i],)
  }
}

残念ながら、これではうまくいきません。古いコードでは、モデルはスケール パラメータをほぼ正しく取得していましたが、形状パラメータについてはひどい結果でした。この新しいコードでは、正しい形状パラメータに非常に近づきますが、スケール パラメータを常に過大評価しています。過大評価の度合いは、分割点がどれだけ遅れるかと相関していることに気づきました。分割点が早い場合 ( cens_point = 50)、実際には過大評価はありません。遅い場合 ( cens_point = 350)、過大評価が大きくなります。

問題は、観測値の「二重カウント」に関連しているのではないかと思いました。t=300 で検閲された観測値を見た後、同じ人物から t=400 で検閲されていない観測値を見ると、この人物はワイブル パラメータに関する推論に 2 つのデータ ポイントを提供しているように見えますが、実際には 1 つのポイントしか提供していないはずです。そのため、各人物にランダム効果を組み込むことを試みましたが、これは完全に失敗し、パラメータの推定値が非常に大きく (50 ~ 90 の範囲) なりましたnu。その理由はわかりませんが、おそらく別の投稿で質問することになると思います。問題が関連しているかどうかはわかりませんので、この投稿全体のコード (そのモデルの JAGS コードを含む) を参照してください。ここ

ベストアンサー1

rstanarmSTAN のラッパーであるパッケージを使用できます。これにより、標準の R 式表記を使用して生存モデルを記述できます。stan_surv関数は、「カウント プロセス」形式で引数を受け入れます。ワイブルを含むさまざまな基本ハザード関数を使用してモデルを適合できます。

rstanarm-関数の存続部分はstan_survCRANではまだ利用できないため、パッケージを直接インストールする必要があります。mc-stan.org

install.packages("rstanarm", repos = c("https://mc-stan.org/r-packages/", getOption("repos")))

以下のコードをご覧ください:

library(dplyr)
library(survival)
library(rstanarm)

## Make the Data: -----
set.seed(3)
n_sub <- 1000
current_date <- 365*2

true_shape <- 2
true_scale <- 365

dat <- data_frame(person = 1:n_sub,
                  true_duration = rweibull(n = n_sub, shape = true_shape, scale = true_scale),
                  person_start_time = runif(n_sub, min= 0, max= true_scale*2),
                  person_censored = (person_start_time + true_duration) > current_date,
                  person_duration = ifelse(person_censored, current_date - person_start_time, true_duration)
)

## Split into multiple observations per person: --------
cens_point <- 300 # <----- try changing to 0 for no split; if so, model correctly estimates
dat_split <- dat %>%
  group_by(person) %>%
  do(data_frame(
    split = ifelse(.$person_duration > cens_point, cens_point, .$person_duration),
    START = c(0, split[1]),
    END = c(split[1], .$person_duration),
    TINTERVAL = c(split[1], .$person_duration - split[1]), 
    CENS = c(ifelse(.$person_duration > cens_point, 1, .$person_censored), .$person_censored), # <— edited original post here due to bug; but problem still present when fixing bug
    TINTERVAL_CENS = ifelse(CENS, NA, TINTERVAL),
    END_CENS = ifelse(CENS, NA, END)
  )) %>%
  filter(TINTERVAL != 0)
dat_split$CENS <- as.integer(!(dat_split$CENS))


# Fit STAN survival model
mod_tvc <- stan_surv(
  formula = Surv(START, END, CENS) ~ 1,
  data = dat_split,
  iter = 1000,
  chains = 2,
  basehaz = "weibull-aft")

# Print fit coefficients
mod_tvc$coefficients[2]
unname(exp(mod_tvc$coefficients[1]))

出力は真の値()と一致するtrue_shape <- 2; true_scale <- 365

> mod_tvc$coefficients[2]
weibull-shape 
     1.943157 
> unname(exp(mod_tvc$coefficients[1]))
[1] 360.6058

また、を使用して STAN ソースを調べ、rstan::get_stanmodel(mod_tvc$stanfit)STAN コードを JAGS で行った試行と比較することもできます。

おすすめ記事