8

[Space-S x KaKR] 그래프 러닝 및 해커톤

SPACE-S

torch_geometric 사용자를 위한 Data 생성 코드 snippet입니다!

근근
2022.09.27 01:51
452
import os
import torch
from torch_geometric.data import Data
import pathlib

# ============= Training Data ============== #
data = torch.load("qm9_train_data.pt") # training data 불러오기

y = data['mu']  # target 값: dipole moment value
num_nodes = data['num_atoms']  # 분자 내 원자 수(==그래프의 노드 갯수)
num_edges = data['num_bonds']  # 분자 내 결합 수(==그래프의 엣지 갯수)
coords = data['x']  # 각 원자의 3d 좌표 값
atomic_numbers = data['atomic_numbers']  # 각 원자의 원자번호
edge = data['edge']  # 엣지 인덱스와 결합 종류

pathlib.Path('processed_train').mkdir(parents=True, exist_ok=True) 

for i in range(len(y)):
    y_s = torch.tensor(y[i], dtype=torch.float)
    num_node_s = num_nodes[i]
    num_edge_s = num_edges[i]
    coord = torch.tensor(coords[i][:num_node_s])  
    atomic_num = torch.tensor(atomic_numbers[i][:num_node_s, :], dtype=torch.long)
    edge_index = torch.tensor(edge[i][:num_edge_s, :2], dtype=torch.long).t()
    edge_attr = torch.tensor(edge[i][:num_edge_s, 2], dtype=torch.long).unsqueeze(1)
    sample = Data(pos=coord, z=atomic_num, y=y_s, edge_index=edge_index, edge_attr=edge_attr)
    torch.save(sample, f'processed_train/data_{i}_train.pt')

# ============= Test Data ============== #
data_test = torch.load("qm9_test_data.pt") # test data 불러오기 

num_nodes = data_test['num_atoms']
num_edges = data_test['num_bonds']
coords = data_test['x']
atomic_numbers = data_test['atomic_numbers']
edge_test = data_test['edge']

pathlib.Path('processed_test').mkdir(parents=True, exist_ok=True) 

for i in range(len(num_nodes)):
    num_node_s = num_nodes[i]
    num_edge_s = num_edges[i]
    coord = torch.tensor(coords[i][:num_node_s])
    edge_index = torch.tensor(edge_test[i][:num_edge_s, :2], dtype=torch.long).t()
    edge_attr = torch.tensor(edge_test[i][:num_edge_s, 2], dtype=torch.long).unsqueeze(1)
    atomic_num = torch.tensor(atomic_numbers[i][:num_node_s, :], dtype=torch.long)
    sample = Data(pos=coord, z=atomic_num, edge_index=edge_index, edge_attr=edge_attr)
    torch.save(sample, f'processed_test/data_{i}_test.pt')   

torch_geometric 라이브러리를 활용하여 학습 및 예측을 진행하고자 하시는 분들을 위해서, 각각의 분자구조를 서로 다른 파일로 저장할 수 있도록 코드를 작성해보았습니다. 궁금하신 사항은 댓글로 달아주시면 제가 아는 선에서 최대한 답변드리도록 노력하겠습니다. 

감사합니다.

8
2개의 댓글
로그인 후 이용해주세요!