6-1. 평균 장 근사(Mean-field approximation)

2021. 8. 3. 18:37Bayesian

728x90

이번 글에서는 변 분추론의 방법 중 하나인 평균 장 근사(Mean-field approximation)에 대해 알아보겠습니다.

이번 글에서는 변분추론의 개념이 나오기 때문에, 이전 글들을 참고하시면 도움이 될 것입니다.

[관련 글]변분 베이즈 방법(Variational Bayes Method),범함수(Functional)


$\begin {align} \textrm {ln} p(\textbf {X})=&\displaystyle \int {q(\textbf {Z})~\textrm {ln} \left \{ \frac {p(\textbf {X},\textbf {Z})}{q(\textbf {Z})}\right \} d\textbf {Z}}-\displaystyle \int {q(\textbf {Z})~\textrm {ln} \left \{ \frac {p(\textbf {Z}|\textbf {X})}{q(\textbf {Z})}\right \} d\textbf {Z}} \\ =& \mathcal {L}(q)+\textrm {KL}(q||p) \end {align}$

변 분추론에서의 목표는 결합 분포의 로그-우도를 최대로 만드는 범함수(Variational)를 찾는 것 입니다. $\textrm{KL}(q||p)$텀은 0보다 큰 수 이기 때문에 결합분포의 로그-우도가 최대가 되려면 $\textrm {KL}(q||p)=0$이 되어야 합니다. 따라서 $p(\textbf {Z}|\textbf {X})$와 가장 가까운 $q(\textbf {Z})$를 찾아야 합니다.

$q(\textbf {Z})$의 분포를 정할 때, 다루기 쉬운 분포여야 하며, 충분히 복잡한 함수를 표현할 수 있어야 합니다. 이를 위해 지수족 분포 함수의 곱으로 이루어진 $q(\textbf {Z})$를 사용할 수 있는데 이 방법을 'mean-field approximation'라고 합니다.

즉, 다음과 같이 나타냅니다. 

$q(\textbf {Z})=\displaystyle \prod^{M}_{i=1}{q_i(\textbf {Z}_i)}$

하한에 해당하는 $\mathcal {L}(q)$만 생각해보겠습니다.

$\begin {align} \mathcal {L}(q) =& \displaystyle \int {\displaystyle \prod_i {q_i} \left \{ \textrm {ln} p(\textbf {X}, \textbf {Z})-\displaystyle \sum_i \textrm {ln} q_i \right \} d\textbf {Z} } \\ =& \displaystyle \int {q_{j} \left \{   \displaystyle \int {\textrm {ln} p(\textbf {X},\textbf {Z}) \displaystyle \prod_{i \neq j}{q_i}~d\textbf {Z}_i }\right \} d\textbf {Z}_j } - \displaystyle \int {q_j~\textrm {ln}{q_j}~d\textbf {Z}_j } + \textrm {const} \\ =& \displaystyle \int {q_{j}~\textrm{ln}~\tilde{p}(\textbf{X},\textbf{Z}) d\textbf{Z}_j } - \displaystyle \int{q_j~\textrm{ln}{q_j}~d\textbf{Z}_j } + \textrm{const} \end {align}$

여기서 $\begin {align} \textrm {ln}~\tilde {p}(\textbf {X},\textbf {Z}) = \mathbb {E}_{i \neq j}[\textrm {ln}~p(\textbf {X},\textbf {Z})]+\textrm {const} = \displaystyle \int {\textrm {ln}~p(\textbf {X},\textbf {Z}) \displaystyle \prod_{i \neq j}{q_i} d\textbf {Z}_i }\end {align}$

$\mathcal {L}(q)$식은 $q_j(\textbf {Z}_j)$와 $\tilde{p}(\textbf{X},\textbf{Z})$의 KL Divergence이기 때문에, 최적의 해는 $q_j(\textbf{Z}_j)$ 다음과 같습니다.

$\begin {align} &\textrm {ln}~{q_j^{*}(\textbf {Z}_j)}=\mathbb {E}[\textrm {ln}~p(\textbf {X},\textbf {Z})]+\textrm {const} \\ &q_j^{*}(\textbf {Z}_j)=\frac {\textrm {exp}{(\mathbb {E}_{i \neq j}[\textrm {ln}~p(\textbf {X},\textbf {Z})])}}{\displaystyle \int{\textrm{exp}{(\mathbb{E}_{i \neq j}[\textrm{ln}~p(\textbf{X},\textbf{Z})])}} d\textbf {Z}_j } \end {align} $

위의 식을 보면 각각의 요소(factors)들의 최적 값을 다른 요소들의 기댓값으로 나타낼 수 있다는 것을 알 수 있습니다.

따라서 'mean-field approximation'에서 $q(\textbf {Z})$의 최적 값은 각각의 요소의 최적 값을 반복적으로 갱신하여 구할 수 있습니다.