Dataset
- torch.util.data 참조 바로가기
Dataset 사용 전 참고사항
- CASE1. 훈련 데이터셋을 메모리에 다 들고 있을 정도로 여유가 있는 경우. (아래의 예제를 참조, Dataframe으로 데이터셋 전체를 들고 있는 방법)
- CASE2. 메모리에 다 들고 있을 수 없을 만큼 데이터의 양이 많은 경우.
- IterableDataset을 참조
Custom Dataset 생성 방법
- torch.util.data.Dataset을 상속
- len 함수 구현
- 전체 데이터셋의 크기를 반환하도록 구현
- init 함수 구현
- Datapoint의 변환메소드나 사전 / Tokenizer 등의 초기화 작업을 진행
- getitem 함수 구현
- 객체[i] 형태처럼 index를 통해 특정 원소를 참조하기 위해 구현하는 함수
- iter 함수 구현 (IterableDataset 에서 필요)
- for 문과 같은 iteration 에서 어떤 순서로 원소를 가져오는 지 구현하는 부분
Transformers 사용하기
일반적으로 Dataset 의 초기화 과정에서는 원본 형태의 Datapoint를 들고 있으며, 이 Datapoint를 torch 의 네트워크에 주입해 주기 위해 Tensor 형태로 바꾸기 위한 일련의 작업 과정이 필요하다. 보통 이러한 데이터 변환을 담당하는 Class 들을 구현한 후 이들의 객체를 리스트로 담아 Dataset의 초기화 인자로 넘겨주어 Dataset에서 리턴하는 모든 datapoint가 동일하게 변환과정을 거치도록 구현한다.
텍스트의 경우, tokenizing / padding / index변환 등의 변환 과정이 구현된 객체가 transformers 인자로 전달된다.
일반적으로 transformer의 마지막 step은 torch 의 Tensor로 변환하는 과정이다.
class ToIdxTransform(object): def __init__(self, vocab_path, max_seq_length = 30, unk_tok = '[UNK]', pad_tok = '[PAD]'): with codecs.open(vocab_path, 'r', 'utf-8') as fin: chars = [i.strip('\n') for i in fin.readlines() if i.strip()] self.chars = chars self.idx_map = dict(zip(self.chars, range(len(self.chars)))) self.max_seq_length = max_seq_length self.unk_tok = unk_tok self.pad_tok = pad_tok assert self.unk_tok in self.idx_map assert self.pad_tok in self.idx_map self.unk_idx = self.idx_map[self.unk_tok] self.pad_idx = self.idx_map[self.pad_tok] def __call__(self, dp): d = dp['document'] char_ids = [self.idx_map.get(c, self.unk_idx) for c in d][:self.max_seq_length] char_ids.extend([self.pad_idx] * (self.max_seq_length - len(char_ids))) dp['document'] = char_ids return dp
Custom Dataset 의 구현 예시
class CharacterSequenceDataset(Dataset):
def __init__(self, df, transformers=None):
self.df = df
self.labels = df['label'].map(lambda x : int(x)).to_list()
self.titles = df['title'].map(lambda x : str(x)).to_list()
self.transformers = transformers
assert len(self.labels) == len(self.titles)
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
# Data point
dp = OrderedDict( {
'label' : self.labels[idx],
'title' : self.titles[idx]
})
if self.transformers is not None:
for transformer in self.transformers:
dp = transformer(dp)
return dp
'pytorch' 카테고리의 다른 글
02. model(network) 구현하기 (0) | 2020.10.03 |
---|---|
CrossEntropyLoss in pytorch (feat. log_softmax, nll_loss) (0) | 2020.09.30 |