https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html
Generalizing the convolution operator to irregular domains is typically expressed as a neighborhood aggregation or message passing scheme.
-
$x_i^{(k-1)}$ : denotes node features of node$i$ in layer$(k-1)$ -
$e_{j, i} \in R^D$ : edge features from node$j$ to node$i$ (optional) -
$\square$ : a differentiable, permutation invariant function, e.g. sum, mean, or max -
$\gamma$ ,$\phi$ : differentiable functions such as MLPs (Multiple Layer Perceptrons)
ํ๋ก์ ํธ ์ฃผ์ ๋ ๊ทธ๋ํ๋ฅผ ์ด์ฉํ ๋จธ์ ๋ฌ๋์ผ๋ก ๋ถ์ ๊ตฌ์กฐ์ ๋ํ ๋ฐ์ดํฐ๋ฅผ ํ์ต์์ผ ์๊ทน์ ๋ชจ๋จผํธ(ฮผ)๋ฅผ ์์ธกํ๋ ๊ฒ์
๋๋ค. ์ฌ์ฉ๋ ์์์งํฉ์ {C, H, N, O, F}์ด๊ณ , ๊ฐ ์์๋ฒํธ {6, 1, 7, 8, 9}๋ฅผ ๊ฐ์ง๋๋ค. ์ฃผ์ด์ง ๋ฐ์ดํฐ๋ ๋ถ์๊ตฌ์กฐ๋ฅผ ๋ฌธ์์ด๋ก ๋ํ๋ธ SMILES ํ๊ธฐ, ๋ฑ๋ฐฉํฅ ๋ถ๊ทน๋ฅ , XYZ ์ขํ ๋ฑ์ ์ ๋ณด๋ฅผ ์ ๊ณตํฉ๋๋ค. ์ด๋ฌํ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํด ๋ถ์๊ตฌ์กฐ๋ฅผ ๊ทธ๋ํ ๊ตฌ์กฐ๋ก ๋ง๋ค์ด ํ๊ฒ ๊ฐ์ธ ์๊ทน์ ๋ชจ๋จผํธ๋ฅผ ํฌํจํ์ฌ ํ์ตํ๊ณ , ํ์ต์ด ๋๋ ๋ชจ๋ธ์ ์ด์ฉํด ๊ฐ ๋ถ์์ ๋ํ ์๊ทน์ ๋ชจ๋จผํธ๋ฅผ ์์ธกํ๊ณ , ์ด ๊ฒฐ๊ณผ๊ฐ ์ ๋ฐ์๋์๋์ง ๊ฒ์ฆํฉ๋๋ค.
์ฐ์ ์ด๊ธฐ ํ์ต ๋ชจ๋ธ์ GNN์ ๊ธฐ๋ณธ ํํ๋ก ์์
์๊ฐ์ ๊ฐ์ฅ ์ค์ํ๊ฒ ๋ค๋ฃจ์ด์ก๋ Message Passing์ ์ด์ฉํ Neural Network model์ ์๊ฐํ์์ต๋๋ค. ๋ง์นจ ์์
์๊ฐ ์ค์ต(Lab3)์์ MPNN Layer๋ฅผ ๊ตฌํํ์๊ณ ๋ณด๋ค ์ฝ๊ฒ ํ๋ก์ ํธ๋ฅผ ์งํํ ์ ์์์ต๋๋ค. ์กฐ๊ต๋๊ป์ ์ ๊ณตํด์ฃผ์ ์ฝ๋๋ฅผ ํตํด mol ๋ฐ์ดํฐ๋ฅผ csv๋ก ๋ณํํ์ฌ ๊ฐ train_list์ test_list๋ฅผ ์์ฑํ๊ณ , MPNN์ ์ด์ฉํ์ฌ ๊ทธ๋ํ ๊ตฌ์กฐ๋ฅผ ์ค๊ณํ์ฌ ํ์ตํ๋ ๋ฐฉํฅ์ผ๋ก ์งํํ์ต๋๋ค.
Pytorch geometric์ด ์ ๊ณตํ๋ MessagePassing ํด๋์ค๋ฅผ ์ด์ฉํ๊ณ , MyNet ํด๋์ค๋ฅผ ํ์ฑํ์ฌ torch์์ ์ ๊ณตํ๋ Module์ ์ด์ฉํด Neural Network ํ์ต ํ๊ฒฝ์ ๊ตฌ์ถํ์ต๋๋ค.
PyG์ MessagePassing ํด๋์ค๋ฅผ ์ด์ฉํ๋ฉด, ์์ ์์์์ ๋ฉ์์ง ์ ํ๋ฅผ ์๋์ผ๋ก ์ฒ๋ฆฌํด์ฃผ๊ธฐ ๋๋ฌธ์ ์ ์์์ ฮฆ message()์ ฮณ update() ํจ์๋ง ๋ฐ๋ก ๊ตฌํํ์์ต๋๋ค.
๋ณธ Project์ task๋ Regression task ์ด๋ฏ๋ก ๋ง์ง๋ง MPNN Layer์ ์ถ๋ ฅ ์ฑ๋์ 1๊ฐ๋ก ์ค์ ํ๊ณ , ์ค์นผ๋ผ ๊ฐ์ ์ถ๋ ฅํ๊ธฐ ์ํด flatten() ๋ฉ์๋๋ก ์์ฑ๋ one-dimensional vector์ torch.mean์ ์ทจํ๊ณ reshape() ํจ์๋ฅผ ์ฌ์ฉํ์ฌ ์ค์นผ๋ผ ๊ฐ์ ๋ฐํํ์ต๋๋ค. Output vector ์ ์ฒด๋ฅผ ์ปค๋ฒํ๋ ์ฌ์ด์ฆ(100์ผ๋ก ์ค์ )๋ก Max pooling์ ์ด์ฉํด๋ณด๊ธฐ๋ ํ์์ผ๋(์ ์ฌ์ง์ self.output layer), ์ฑ๋ฅ ํฅ์์ ์ด๋ ค์์ด ์์ด torch.mean์ ์ด์ฉํ๊ฒ ๋์์ต๋๋ค. ๊ทธ ํ ๋ชจ๋ธ ํ์ต ์ต์ ํ๋ฅผ ์ํด epoch ๊ฐ์ ์ฆ๊ฐ์ํค๋ฉฐ score๋ฅผ ๊พธ์คํ ํ์ธํ์์ต๋๋ค. ์ฝ 150ํ ์ด์๋ถํฐ๋ overfitting์ด ๋ฐ์ํ์ฌ ์ฑ๋ฅ์ด ๊ฐ์ํ์๊ณ , epoch๊ฐ์ ์ค์ด๊ณ aggregation() ํจ์๋ฅผ ์ถ๊ฐํ์ฌ ์ต์ ํํ๊ณ ์ ํ์์ผ๋, aggregation() ์ถ๊ฐ๋ overfitting ์ ๋๊ฐ ๋ ์ฌํด์ง๋ ๊ฒ์ ํ์ธํ ์ ์์์ต๋๋ค.
(์ ๋ชจ๋ธ + epoch 150ํ Score: 0.91624)
์ดํ ์์
์๊ฐ์ ํ์ตํ์๋ GCN ๊ตฌํ์๋ ๋์ ํ์์ผ๋, pytorch ๋ฌธ๋ฒ ์ค๋ฅ ๋ฑ์ผ๋ก ์์ฑ์ ์ํค์ง ๋ชปํ์๊ณ , ์กฐ๊ต๋๊ป์ ์ฌ๋ ค์ฃผ์ DimeNet ๋
ผ๋ฌธ์ ์ฝ๊ณ ์๋ก์ด ๋ฐฉํฅ์ ์ก๊ฒ ๋์์ต๋๋ค.
์ฒ์ ์ฌ์ฉํ ํ์ต ๋ชจ๋ธ์ GNN์ ๋ํ์ ์ธ ํ๋ ์์ํฌ์ธ MPNN์
๋๋ค. GNN ๋ชจ๋ธ์ ์ด์ฉํ ๋ถ์ ์์ธก์์ ๊ทธ๋ํ๋ ํ๋์ ๋ถ์, ๋
ธ๋๋ ์์, ์ฃ์ง๋ ๋ฏธ๋ฆฌ ์ ์๋ ๋ถ์ ๊ตฌ์กฐ ๊ทธ๋ํ ๋๋ ์์ ๊ฐ์ ๊ฑฐ๋ฆฌ๋ก ๊ฒฐ์ ๋ฉ๋๋ค. GNN์ ๋ฉ์์ง๋ ๋ค์๊ณผ ๊ฐ์ด ์ ์๋ฉ๋๋ค.

์ฆ, ๋ฉ์์ง ํจ์๋ ํ๊ฒ ๋
ธ๋์ ํ์ฌ ์ํ๊ฐ, ํ๊ฒ ๋
ธ๋์ ์ด์์ ํ์ฌ ์ํ, ๊ทธ๋ฆฌ๊ณ ํด๋น ๋
ธ๋์ ์ด์์ ์ฐ๊ฒฐํ๋ ์ฃ์ง์ ์ ๋ณด๋ฅผ ํฉํ์ฌ ํ๊ฒ ๋
ธ๋์ ๋ค์ ๋ฉ์์ง๋ก ์ ๋ฌํฉ๋๋ค.
MPNN์์ ๋ฉ์ธ์ง ํจ์์ ์
๋ฐ์ดํธ ํจ์๋ ์๋์ ์์์ผ๋ก ์ ์๋ฉ๋๋ค. ๐ฅ!(#%&)๋ (๐ โ 1) ๋ ์ด์ด์ ๋
ธ๋ ๐์ ํน์ฑ์ ๋ํ๋ด๊ณ , โก๋ aggregation ํจ์, ๐ โR0 ๋
ธ๋ ๐์์ ๋
ธ๋ ๐์ ๋
ธ๋ ํน์ฑ์
๋ํ๋
๋๋ค. MPNN์ ํ๊ท(Regression)๋ฅผ ์ด์ฉํด ์
๋ฐ์ดํธ ํฉ๋๋ค.
๋ฐ๋ผ์ ์ผ๋ฐ์ ์ธ non-directional GNN์ ๋ถ์ ์์ธกํ ๋ ํจ๊ณผ์ ์ธ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง์ง๋ง, GNN์ ์
๋ฐ์ดํธ ๋ฉ์ธ์ง๋ ์ด์ ์๋ฒ ๋ฉ ์์์ ์์์ ๊ฑฐ๋ฆฌ ์ ๋ณด๋ก๋ง ๊ตฌ์ฑ๋๊ณ , ๋ถ์ ์์ธก์ ์ค์ํ ๋ฐฉํฅ ์ ๋ณด์ธ ๊ฒฐํฉ๊ฐ, ํ์ ๋ฑ์๋ ๋
๋ฆฝ๋ ํน์ฑ์ ๊ฐ์ง๋๋ค.
๋ํ, GNN์ ์ ์ฒด ์์ ๊ฐ ๊ฑฐ๋ฆฌ์ ๋ํ ์ ๋ณด๋ฅผ ๋ด์ ํ๋ ฌ์ ์ฌ์ฉํ์ง ์๊ณ cutoff ๊ฑฐ๋ฆฌ ๐๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ด ๊ฒฝ์ฐ, ๊ฒฐํฉ ๊ธธ์ด์ ์ด์ ์์๊ฐ ๊ฐ์ ๋ถ์๊ฐ์ cutoff ๊ฑฐ๋ฆฌ๊ฐ 2.5ร
๋ณด๋ค ์๊ฑฐ๋ ๊ฐ๋ค๋ฉด, GNN์ ์ด ๋ถ์๊ฐ ์ด๋ค ๊ฒ์ธ์ง ๊ตฌ๋ถํ์ง ๋ชปํฉ๋๋ค. GNN(MPNN) ๋ชจ๋ธ์ ์ด๋ฌํ ํ๊ณ์ ์ ๊ทน๋ณตํ๊ณ ์, molecular data์ regression task์์ ๋์ ์ฑ๋ฅ์ ๊ฐ์ง๋ ๋ชจ๋ธ์ ์ฐพ๊ฒ ๋์๊ณ , ๊ทธ์ค์์๋ ์กฐ๊ต๋๊ป์ ์ฌ๋ ค์ฃผ์ ๋
ผ๋ฌธ ์ค ๊ฐ์ฅ ์ฑ๋ฅ์ด ์ข์๋ DimeNet์ ์ฌ์ฉํ์์ต๋๋ค.
DimeNet์ ๊ธฐ์กด GNN์ ๋ฌธ์ ์ ์ ๋ณด์ํ๊ธฐ ์ํด ๋ฐฉํฅ ์ ๋ณด, ๊ฒฐํฉ ์ฌ์ด์ ๊ฐ๋๋ฅผ ๋ด์ ๋ฉ์ธ์ง๋ฅผ ์์ ์ฌ์ด์ ์๋ฒ ๋ฉํ๊ณ , ์ด ๋ฉ์ธ์ง๋ฅผ ์ด์ฉํ์ฌ ์ ๋ขฐ์ ํ ์๊ณ ๋ฆฌ์ฆ ๊ฐ์ด ๊ด์ธก๋ ๋
ธ๋๋ฅผ ๋ฐํ์ผ๋ก ๊ด์ธก๋์ง ์์ ๋
ธ๋์ ๋ถํฌ๋ฅผ ๊ณ์ฐํ์ฌ ์
๋ฐ์ดํธํฉ๋๋ค.
DimeNet ๋ฉ์ธ์ง์ ํต์ฌ์ ๋ฐฉํฅ ์๋ฒ ๋ฉ, ๋ฉ์์ง ์๋ฒ ๋ฉ์ผ๋ก ๋๋์์ต๋๋ค. ๋จผ์ , ๋ฐฉํฅ์ฑ์ ๊ฐ์ง ์๋ฒ ๋ฉ์ด ๊ฐ๋ฅํ ๊ฒ์ ์์์ ๋ฌผ๋ฆฌ์ ํน์ฑ๊ฐ์ ํ์ ์ ๋ถ๋ณํ๊ธฐ ๋๋ฌธ์
๋๋ค. ์์๊ฐ ํ์ ์ ํด๋ ์ด์ ๊ฐ ๊ฑฐ๋ฆฌ์ ๊ฒฐํฉ๊ฐ์ด ๋ณํ์ง ์๋ ๊ฒ์ ์ด์ฉํ์ฌ ์ฐจ๋จ๊ฑฐ๋ฆฌ ๋ด์ ์์๋ค์ด ์๋ก ๋ฐ์ํ ๋๋ง ๋ถ๋ณ์ฑ์ด ๊นจ์ง๋ ๊ฒฝ์ฐ๋ฅผ ๋ชจ๋ธ์ ์ถ๊ฐ์ ์ผ๋ก ๋ฐ์ํ๋ฉด ๋ฉ๋๋ค. DimeNet์ ์์ i์ ์ด์ ๋
ธ๋ j์ ๋ํด ๊ฐ ์ธ์ ํ ์์์ ๋ฐฉํฅ์ผ๋ก ๋์ผํ๊ฒ ํ์ตํ๋ ๋ณ๋์ ์๋ฒ ๋ฉ
๋ฉ์์ง ์๋ฒ ๋ฉ์ ์์ ์ ๐๐์ ๋ํ ๋ฐฉํฅ ์๋ฒ ๋ฉ
DimeNet์ ์ ๋ฐ์ ์ธ ํ์ต๊ณผ์ ์ ์์ ์ํคํ
์ฒ ๊ทธ๋ฆผ์ ํตํด ํ์ธํ ์ ์์ต๋๋ค. ํฌ๊ฒ๋ ์์์ ์ธ๊ธํ RBF์ SBF๋ฅผ ํตํด ์์๊ฐ ๊ฑฐ๋ฆฌ์ ๊ฐ๋ ์ ๋ณด๋ฅผ ์๋ฒ ๋ฉํ๊ณ , ์๋ฒ ๋ฉ ๋ ๊ฐ๊ณผ ๋ฉ์ธ์ง
์คํ์ ์ฐ์ธ ๋ชจ๋ธ์ ๋ค์๊ณผ ๊ฐ์ด ์ค์ ํ์์ต๋๋ค.
- Model baseline: Python Geometric class
- Model initial parameter:
a. hidden_channels=128 -> ๋ง์ด ์ฌ์ฉํ๋ ์๋ฒ ๋ฉ ์ฌ์ด์ฆ ์ด์ฉ
b. out_channels=1 -> ์์ธก๊ฐ 1๊ฐ
c. num_blocks=6 -> Building block ์
d. num_bilinear=8 -> Bilinear tensor ์
e. num_spherical=7 -> spherical harmonics ์
f. num_radial=6 -> radial basis function ์ - Loss function: torch.nn.L1Loss() -> MAE
- Optimizer: ADAM (learning rate: 1e-5, amsgrad ์ด์ฉ)
- Scheduler: Exponential learning rate decay (3000 step ๋น 0.98)
- Stochastic moving average (step ๋น 0.001)
- ํ์ต Epoch ์: 10
ํด๋น ์ค์ ๊ฐ๋ค์
๋ํ DimeNet ํด๋์ค์ ํฌํจ๋ from_qm9_pretrained ๋ฉ์๋๋ฅผ ์ด์ฉํ์ฌ QM9 dataset์์ ์ฐ์ํ ์ฑ๋ฅ์ ๋ณด์ธ ๋ชจ๋ธ์ ํ๋ผ๋ฏธํฐ๋ฅผ ๋ฐ์ ํ์ต์ ์์ํ์์ต๋๋ค.
์ถ๊ฐ๋ก, data ์ค x๊ฐ (์์ ๋๋ฒ๋ง)์ H๋ฅผ 0์ผ๋ก ํ์ํ์๋ค๋ ์ ์ด ํ์ต์ ๋ฐฉํด๊ฐ ๋๋ค๊ณ ํ๋จํ์ฌ ๋ชจ๋ ๊ฐ์ 1์ ๋ํ์ฌ ํ์ตํ์์ต๋๋ค. ์ด๋ ์ค์ ๋ก ์ ์ ์์น์ ํจ๊ณผ๊ฐ ์์์ต๋๋ค.
์ต์ข
ํ์ต์ด ์๋ฃ๋ ๋ชจ๋ธ์ MAE score๋ 0,00712์ด์์ต๋๋ค.
๋ณธ ํ๋ก์ ํธ์์ ๋ค๋ฅธ ํ๋ค์ ๋นํด reportํ Loss๊ฐ์ด ๋งค์ฐ ๋ฎ์์ต๋๋ค. ๊ฐ์ฅ ํฐ ์ํฅ์ ์ค ๋ถ๋ถ์ ๋ณธ ํ๋ก์ ํธ์ ํ๊ฐ data๊ฐ qm9 dataset์ subset์ด์๋๋ฐ, ์ด๋ฅผ ์ด๋ฏธ ํ์ต์ ์๋ฃํ pretrained model์ ์ฌ์ฉํ์๊ธฐ ๋๋ฌธ์
๋๋ค. Qm9๊ฐ ํ๊ฐ data์ ๋ค๋ฅธ data๋ผ๋ molecular๋ฅผ ํ ํ๊ณ ๋ชจ๋จผํธ๊ฐ์ ์์ธกํ๋ task์ ๋ํด ํ์ตํ ๊ฒฝํ์ด ์๋ pretrained model์ ์ด์ฉํ๋ฉด ์ฑ๋ฅ์ด ์ข์ ๊ฒ์ผ๋ก ์์ธกํ๊ณ ํด๋น ๋ฉ์๋๋ฅผ ์ฌ์ฉํ๊ฒ ๋์์ต๋๋ค. ๊ทธ๋ฌ๋ data ๊ตฌ์ฑ๊ณผ ๊ทธ ํ์์ด ๋ฌ๋ผ ํ๊ฐ data๊ฐ qm9์ subset data์ธ ๊ฒ์ ์ธ์งํ์ง ๋ชปํ๊ณ test data๋ก pretrain ๋ model์ ์ฌ์ฉํ๊ฒ ๋์์ต๋๋ค. ์ด๋ฌํ ์ ์์ leaderboard์ report ๋ ์ ์๋ ์ฌ๋ฐ๋ฅด๊ฒ ํ์ต ๋ฐ ํ๊ฐ๋ ๋ชจ๋ธ์ ์ ์๋ผ๊ณ ๋ณผ ์ ์์ต๋๋ค.
๊ทธ๋ฌ๋ ๋ณธ ํ๋ก์ ํธ๋ฅผ ์ํํ๋ฉด์ ์ฌ๋ฌ ๋
ผ๋ฌธ์ ์ฝ๊ณ , GNN์์ Message passing์ ๋ค์ํ ์ฌ๋๋ค์ ์์ด๋์ด๊ฐ ๋ค์ด๊ฐ์๋ ๊ฒ์ ํ์ธํ์์ต๋๋ค. ๋ํ์ฌ, DimeNet๊ณผ ๊ฐ์ด domain knowledge๊ฐ ๋ง์ด ์ฐ์ธ ์์ด๋์ด๋ ๊ณง ํด๋น data์์ ๋งค์ฐ ๋์ ์ฑ๋ฅ์ ์คํ ๊ฒฐ๊ณผ๋ฅผ ์ป์ ์ ์๊ฒ ๋จ์ ์ ์ ์์์ต๋๋ค. ์ถ๊ฐ๋ก DimeNet ๋
ผ๋ฌธ์ ์์ฑ๋ ํ์ต ์ธํ
๋ฐฉ๋ฒ์ ์ด์ฉํ๊ธฐ ์ํด EMA, SWA ๋ฑ์ ๋ค์ํ scheduling ๊ธฐ๋ฒ์ ๋ํด ํ์ตํ ์ ์์์ต๋๋ค. ์ด๋ ํ์ ๊ผญ GNN์ด ์๋๋๋ผ๋ learning model์ ์ ์ํ ๋ ์ ์ฉํ๊ฒ ์ฌ์ฉํ ์ ์์ ๊ฒ์
๋๋ค.
Data ๊ฐ๊ณต๋ถํฐ ํ์ต ๋ชจ๋ธ ์ ์ , loss function ๋ฐ optimizer ์ ํ, hyperparameter optimization, validation ๋ฐ test๊น์ง ์ ๊ณผ์ ์ ์ง์ ๊ตฌํํด๋ณด๊ณ , ๋ ํผ๋ฐ์ค๋ฅผ ์ฐพ์๋ณด๋ฉฐ ๊ฒฐ๊ณผ๋ฅผ ์ป๊ธฐ ์ํด ๋
ธ๋ ฅํ ๊ณผ์ ์ ํตํด ๋งค์ฐ ๋ง์ ๊ฒ์ ์ป์ ์ ์์๋ ํ๋ก์ ํธ์์ต๋๋ค.
-
Creating Message Passing Networks
https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html
-
Source code for
torch_geometric.nn.conv.message_passing -
J.Gasteiger et al. โDirectional Message Passing for Molecular Graphsโ, ICLR, 2020





