mminn20/NumericalAnalysis

GNN Project

1) Message Passing Neural Network (MPNN)

https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html

Generalizing the convolution operator to irregular domains is typically expressed as a neighborhood aggregation or message passing scheme.


$x_i^{(k)} = \gamma^{(k)} \left(x_i^{(k-1)}, \space \square_{j\in N(i)} \space \phi^{(k)}\left(x_i^{(k-1)}, \space x_j^{(k-1)}, \space e_{j,i}\right)\right)$

  • $x_i^{(k-1)}$ : denotes node features of node $i$ in layer $(k-1)$
  • $e_{j, i} \in R^D$ : edge features from node $j$ to node $i$ (optional)
  • $\square$ : a differentiable, permutation-invariant aggregation function, e.g. sum, mean, or max
  • $\gamma$ , $\phi$ : differentiable functions such as MLPs (Multi-Layer Perceptrons)

MPNN variant

$x_i^{(k)} = \gamma^{(k)} \left(\mathrm{CONCAT}\left[x_i^{(k-1)}, \space \sum_{j\in N(i)} \phi^{(k)}\left(e_{j,i}\cdot(x_j^{(k-1)}-x_i^{(k-1)})\right)\right]\right)$
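The variant above can be traced by hand on a tiny graph. The following is a minimal, framework-free sketch, not the project's actual layer: it assumes scalar edge weights $e_{j,i}$, sum aggregation, and identity functions standing in for the learned $\phi$ and $\gamma$. The name `mpnn_variant_step` is hypothetical.

```python
from typing import Callable, Dict, List, Tuple

Vector = List[float]


def mpnn_variant_step(
    x: List[Vector],
    edges: Dict[Tuple[int, int], float],
    phi: Callable[[Vector], Vector],
    gamma: Callable[[Vector], Vector],
) -> List[Vector]:
    """One update x^(k-1) -> x^(k).

    `edges` maps a directed pair (j, i) to a scalar weight e_{j,i}
    (an assumption made for this sketch).
    """
    dim = len(phi([0.0] * len(x[0])))      # output width of phi
    agg = [[0.0] * dim for _ in x]         # sum aggregation starts at zero
    for (j, i), e_ji in edges.items():
        # message from j to i: phi(e_{j,i} * (x_j - x_i))
        diff = [e_ji * (xj - xi) for xj, xi in zip(x[j], x[i])]
        msg = phi(diff)
        agg[i] = [a + m for a, m in zip(agg[i], msg)]
    # gamma(CONCAT[x_i, sum of incoming messages])
    return [gamma(xi + ai) for xi, ai in zip(x, agg)]


def identity(v: Vector) -> Vector:
    """Placeholder for a learned MLP."""
    return list(v)


x = [[1.0], [3.0], [6.0]]
edges = {(1, 0): 1.0, (2, 0): 0.5, (0, 1): 1.0}
out = mpnn_variant_step(x, edges, identity, identity)
print(out)  # [[1.0, 4.5], [3.0, -2.0], [6.0, 0.0]]
```

Node 2 has no incoming edge, so its aggregated message stays at zero and only its own features pass through the concatenation.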




1. Problem Definition

The goal of this project is to predict the dipole moment (μ) of molecules by training a graph-based machine learning model on molecular structure data. The molecules are built from the elements {C, H, N, O, F}, with atomic numbers {6, 1, 7, 8, 9} respectively. The provided data includes the SMILES notation (a string representation of the molecular structure), the isotropic polarizability, XYZ coordinates, and similar information. From this data we convert each molecular structure into a graph, train a model with the dipole moment as the target value, use the trained model to predict the dipole moment of each molecule, and verify how well the predictions match.


2. Initial Approach

For the first model we chose a neural network based on message passing, the basic form of a GNN and the topic emphasized most in class. We had already implemented an MPNN layer in a class lab (Lab3), which made it easier to get started. Using the code provided by the teaching assistant, we converted the mol data to csv to create train_list and test_list, and then designed the graph structure and trained on it with an MPNN.

We used the MessagePassing class provided by PyTorch Geometric and defined a MyNet class on top of torch's Module to set up the neural network training environment.

Screenshot 2025-03-17 at 4 20 45 AM

Because PyG's MessagePassing class handles the message propagation in the formula above automatically, we only implemented the functions corresponding to $\phi$ (message()) and $\gamma$ (update()).

Screenshot 2025-03-17 at 4 22 58 AM

Since this project is a regression task, we set the number of output channels of the last MPNN layer to 1. To output a scalar, we applied torch.mean to the one-dimensional vector produced by flatten() and used reshape() to return a scalar value. We also tried max pooling with a window large enough to cover the whole output vector (set to 100; this is the self.output layer in the screenshot above), but it did not improve performance, so we settled on torch.mean. We then increased the number of epochs step by step while monitoring the score. Beyond roughly 150 epochs the model overfit and performance dropped. We tried reducing the number of epochs and adding an aggregation() function, but adding aggregation() made the overfitting worse.
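The mean readout described above can be illustrated without any framework. This is a minimal sketch under the assumption that the last layer emits one value per atom. `mean_readout` and `max_readout` are hypothetical names, and plain lists stand in for tensors.

```python
from typing import List


def mean_readout(node_outputs: List[List[float]]) -> List[float]:
    """Flatten per-node outputs of shape (N, 1), average, return shape (1,)."""
    flat = [value for row in node_outputs for value in row]  # like flatten()
    return [sum(flat) / len(flat)]                           # like mean + reshape


def max_readout(node_outputs: List[List[float]]) -> List[float]:
    """Alternative readout: maximum over all per-node outputs."""
    flat = [value for row in node_outputs for value in row]
    return [max(flat)]


per_atom = [[1.0], [2.0], [6.0]]                 # one scalar per atom
shuffled = [per_atom[2], per_atom[0], per_atom[1]]

print(mean_readout(per_atom))   # [3.0]
print(mean_readout(shuffled))   # [3.0], atom order does not matter
print(max_readout(per_atom))    # [6.0]
```

Both readouts are permutation invariant and work for any number of atoms, so neither needs a fixed pooling window sized to the largest molecule.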

(Model above, 150 epochs. Score: 0.91624)

Afterwards we also attempted to implement the GCN covered in class, but could not finish it because of PyTorch syntax errors and similar problems. After reading the DimeNet paper shared by the teaching assistant, we took a new direction.


3. GNN Model Used for Training: DimeNet

The first model we used was the MPNN, a representative GNN framework. In molecular property prediction with a GNN, a graph is one molecule, the nodes are atoms, and the edges are determined either by a predefined molecular graph or by the distances between atoms. The GNN message is defined as follows.
Screenshot 2025-03-17 at 4 24 42 AM
In other words, the message function combines the current state of the target node, the current state of its neighbors, and the information on the edges connecting the node to those neighbors, and passes the result on as the next message for the target node.

In an MPNN, the message function and the update function are defined by the formula below. $x_i^{(k-1)}$ denotes the features of node $i$ in layer $(k-1)$, $\square$ is the aggregation function, and $e_{j,i} \in R^D$ denotes the edge features from node $j$ to node $i$. The MPNN performs its update using regression.

Screenshot 2025-03-17 at 4 25 39 AM

A general non-directional GNN therefore has a structure that works well for molecular prediction, but its update messages consist only of the previous atom embeddings and pairwise distance information. They are independent of directional information that matters for molecular prediction, such as bond angles and rotations.

In addition, a GNN does not use the full matrix of interatomic distances but a cutoff distance $c$. In that case, if the cutoff distance is 2.5Å or less, the GNN cannot tell apart molecules that have the same bond lengths and the same neighboring elements. To overcome these limitations of the GNN (MPNN) model, we looked for models that perform well on regression tasks over molecular data, and among the papers shared by the teaching assistant we chose DimeNet, which had the best performance.
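The cutoff rule above can be made concrete with a small example. This is a minimal sketch that assumes edges are created between every ordered pair of atoms whose Euclidean distance is at most the cutoff. The function name `cutoff_edges` and the coordinates are made up for illustration.

```python
import math
from typing import List, Sequence, Tuple


def cutoff_edges(
    coords: Sequence[Sequence[float]], cutoff: float
) -> List[Tuple[int, int]]:
    """Return directed edges (j, i) for all atom pairs within `cutoff` (in Å)."""
    edges = []
    for i, pos_i in enumerate(coords):
        for j, pos_j in enumerate(coords):
            if i != j and math.dist(pos_i, pos_j) <= cutoff:
                edges.append((j, i))  # message flows from j to i
    return edges


# Made-up XYZ coordinates in Å: three nearby atoms and one distant atom.
coords = [
    (0.00, 0.00, 0.00),
    (0.96, 0.00, 0.00),
    (-0.24, 0.93, 0.00),
    (5.00, 0.00, 0.00),
]
edges = cutoff_edges(coords, cutoff=2.5)
print(len(edges))  # 6: the three nearby atoms are fully connected
```

The distant atom receives no edges, so with a 2.5Å cutoff no information about it reaches the other atoms in a single message passing step.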

To address these problems of conventional GNNs, DimeNet embeds messages between atoms that carry directional information, namely the angles between bonds. Using these messages, it performs updates in a way similar to the belief propagation algorithm, which computes the distribution of unobserved nodes from the observed ones.

We divide the core of the DimeNet message into two parts: the directional embedding and the message embedding. First, a directional embedding is possible because the physical properties of a molecule are invariant to rotation. The distances between neighbors and the bond angles do not change when the atoms are rotated, so the model only needs to additionally account for the case where invariance is broken, which happens only when atoms within the cutoff distance interact with each other. DimeNet implements this by creating, for atom $i$ and each neighboring node $j$, a separate embedding $$m_{ji}$$ that is learned in the same way for the direction of every adjacent atom. Because the embedding $$m_{ji}$$ rotates together with the molecule, the relative directional information with respect to neighboring nodes is preserved.

For the message embedding, the directional embedding $$m_{ji}$$ of the atom pair $ji$ can be viewed as a message sent from atom $j$ to atom $i$. An arbitrary atom $i$ is embedded using its set of messages $$m_{ji}$$, and each message $$m_{ji}$$ is updated from the messages coming in from the neighboring nodes. The message is therefore defined using an update function and an aggregation as follows.

Screenshot 2025-03-17 at 4 31 46 AM

Screenshot 2025-03-17 at 4 31 20 AM
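For reference, the message update in the DimeNet paper, as we read it, takes the following form. Here $f_{update}$ and $f_{int}$ denote the update and interaction functions, $l$ is the layer index, and the sum runs over the neighbors $k$ of atom $j$ other than $i$:

$$m_{ji}^{(l+1)} = f_{update}\left(m_{ji}^{(l)}, \sum_{k \in N_j \setminus \{i\}} f_{int}\left(m_{kj}^{(l)}, e_{RBF}^{(ji)}, a_{SBF}^{(kj,ji)}\right)\right)$$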

$$a_{SBF}$$ (Spherical Bessel Function) is the SBF embedding of the bond angle $\angle kji$ and the interatomic distance $$d_{kj}$$ between atoms $k$ and $j$. $$e^{(ji)}_{RBF}$$ (Radial Basis Function) is the radial basis function embedding of the interatomic distance $$d_{ji}$$.
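The two geometric inputs named above, the angle $\angle kji$ and a radial embedding of a distance, can be sketched in plain Python. This is an illustrative sketch only. The radial basis uses the form $\sqrt{2/c}\,\sin(n\pi d/c)/d$ from the DimeNet paper and leaves out the paper's envelope function, and the names `angle_kji` and `radial_bessel_basis` are hypothetical.

```python
import math
from typing import List, Sequence


def angle_kji(
    pos_k: Sequence[float], pos_j: Sequence[float], pos_i: Sequence[float]
) -> float:
    """Angle at atom j between the directions j->k and j->i, in radians."""
    v_jk = [a - b for a, b in zip(pos_k, pos_j)]
    v_ji = [a - b for a, b in zip(pos_i, pos_j)]
    dot = sum(a * b for a, b in zip(v_jk, v_ji))
    cos_angle = dot / (math.hypot(*v_jk) * math.hypot(*v_ji))
    return math.acos(max(-1.0, min(1.0, cos_angle)))  # clamp rounding error


def radial_bessel_basis(d: float, cutoff: float, num_radial: int) -> List[float]:
    """Radial embedding of distance d (0 < d <= cutoff), envelope omitted."""
    scale = math.sqrt(2.0 / cutoff)
    return [
        scale * math.sin(n * math.pi * d / cutoff) / d
        for n in range(1, num_radial + 1)
    ]


right_angle = angle_kji((1.0, 0.0, 0.0), (0.0, 0.0, 0.0), (0.0, 1.0, 0.0))
rbf = radial_bessel_basis(d=1.2, cutoff=5.0, num_radial=6)
print(round(math.degrees(right_angle), 6))  # 90.0
print(len(rbf))                             # 6
```

Every basis function is zero at `d == cutoff` because $\sin(n\pi) = 0$, which is why this basis suits a model that ignores atoms beyond the cutoff.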

Screenshot 2025-03-17 at 4 35 25 AM

The overall DimeNet pipeline is shown in the architecture figure above. Broadly, the interatomic distances and angles are embedded through the RBF and SBF described above. The embedded values and the messages $$m_{kj}$$ are fed into the Embedding and Interaction blocks, and the resulting values are aggregated to form the output. The Embedding block builds its output from the atom types and the RBF embedding, which carries the distances between the central atom and its neighbors. The Interaction block builds its output from both the RBF and SBF embeddings together with the messages passed on from the previous layer. The aggregation of the outputs from one Embedding block and several Interaction blocks is the final output.


4. Comparative Analysis of Experimental Results

The model used in the experiments was configured as follows.

  1. Model baseline: PyTorch Geometric class (DimeNet)
  2. Model initial parameters:
    a. hidden_channels=128 -> a commonly used embedding size
    b. out_channels=1 -> one predicted value
    c. num_blocks=6 -> number of building blocks
    d. num_bilinear=8 -> number of bilinear tensors
    e. num_spherical=7 -> number of spherical harmonics
    f. num_radial=6 -> number of radial basis functions
  3. Loss function: torch.nn.L1Loss() -> MAE
  4. Optimizer: Adam (learning rate: 1e-5, with amsgrad)
  5. Scheduler: Exponential learning rate decay (0.98 per 3000 steps)
  6. Stochastic moving average (0.001 per step)
  7. Number of training epochs: 10
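The schedule and averaging settings in items 5 and 6 can be read as simple per-step arithmetic. This is a sketch of one plausible reading, not the project's training code. It assumes a continuous (non-staircase) exponential decay and interprets item 6 as an exponential moving average of the weights with a per-step factor of 0.001. The function names `learning_rate` and `ema_update` are hypothetical.

```python
from typing import List


def learning_rate(
    step: int,
    base_lr: float = 1e-5,
    decay_rate: float = 0.98,
    decay_steps: int = 3000,
) -> float:
    """Learning rate after `step` updates: base_lr * decay_rate^(step / decay_steps)."""
    return base_lr * decay_rate ** (step / decay_steps)


def ema_update(
    shadow: List[float], params: List[float], alpha: float = 0.001
) -> List[float]:
    """Move the averaged weights a fraction `alpha` toward the current weights."""
    return [(1.0 - alpha) * s + alpha * p for s, p in zip(shadow, params)]


for step in (0, 3000, 30000):
    print(step, learning_rate(step))

shadow = ema_update([0.0, 0.0], [1.0, -2.0])
print(shadow)  # [0.001, -0.002]
```

With these values the learning rate shrinks by only 2% every 3000 steps, so over a short 10-epoch run it stays close to the initial 1e-5.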

These settings follow the hyperparameters used for the QM9 dataset evaluation in Appendix B (Experimental setup) of the paper$^1$, with a few values modified to account for the difference in dataset size.

We also used the from_qm9_pretrained method of the DimeNet class to initialize training from the parameters of a model that performed well on the QM9 dataset.

In addition, we judged that encoding H as 0 in the x values (the atom numbering) of the data hindered learning, so we added 1 to every value before training. This did in fact improve the score.

The MAE score of the final trained model was 0.00712.


5. Discussion

The loss we reported in this project was very low compared with the other teams. The biggest factor was that the evaluation data for this project was a subset of the QM9 dataset, and we used a pretrained model that had already been trained on it. We chose that method because we expected that, even if QM9 were different from the evaluation data, a pretrained model with experience in learning molecules and predicting moment values would perform well. However, because the composition and format of the data differed, we did not realize that the evaluation data was a subset of QM9, and we ended up using a model pretrained on the test data. For this reason, the score reported on the leaderboard cannot be regarded as the score of a properly trained and evaluated model.
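An overlap of this kind can be detected before training by comparing molecule identifiers across the two datasets. The following is a minimal sketch that assumes each molecule has a comparable identifier such as a canonical SMILES string. The function `overlap_fraction` and the example identifiers are hypothetical and are not taken from the project data.

```python
from typing import Iterable


def overlap_fraction(pretrain_ids: Iterable[str], eval_ids: Iterable[str]) -> float:
    """Fraction of evaluation molecules that also occur in the pretraining data."""
    eval_set = set(eval_ids)
    if not eval_set:
        return 0.0
    return len(eval_set & set(pretrain_ids)) / len(eval_set)


# Hypothetical identifiers, for illustration only.
pretrain_ids = ["C", "CO", "CC"]
eval_ids = ["CO", "CC", "CN", "O"]
print(overlap_fraction(pretrain_ids, eval_ids))  # 0.5
```

A fraction well above zero signals that the evaluation score may be inflated by molecules the pretrained model has already seen.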

Nevertheless, while carrying out this project we read several papers and saw how many different ideas have gone into message passing in GNNs. We also learned that ideas that draw heavily on domain knowledge, as DimeNet does, can lead to very strong experimental results on the corresponding data. In order to follow the training setup described in the DimeNet paper, we also studied various scheduling techniques such as EMA and SWA. These should be useful later when building learning models, even ones that are not GNNs.

We implemented the whole process ourselves, from data preprocessing through model selection, the choice of loss function and optimizer, hyperparameter optimization, validation, and testing, and searched for references along the way. Through this effort to obtain results, we gained a great deal from the project.



References
