Dataset

Dataset 사용 전 참고사항

  • CASE1. 훈련 데이터셋을 메모리에 다 들고 있을 정도로 여유가 있는 경우. (아래의 예제를 참조, Dataframe으로 데이터셋 전체를 들고 있는 방법)
  • CASE2. 메모리에 다 들고 있을 수 없을 만큼 데이터의 양이 많은 경우.
    • IterableDataset을 참조

Custom Dataset 생성 방법

  1. torch.util.data.Dataset을 상속
  2. len 함수 구현
    • 전체 데이터셋의 크기를 반환하도록 구현
  3. init 함수 구현
    • Datapoint의 변환메소드나 사전 / Tokenizer 등의 초기화 작업을 진행
  4. getitem 함수 구현
    • 객체[i] 형태처럼 index를 통해 특정 원소를 참조하기 위해 구현하는 함수
  5. iter 함수 구현 (IterableDataset 에서 필요)
    • for 문과 같은 iteration 에서 어떤 순서로 원소를 가져오는 지 구현하는 부분

Transformers 사용하기

  • 일반적으로 Dataset 의 초기화 과정에서는 원본 형태의 Datapoint를 들고 있으며, 이 Datapoint를 torch 의 네트워크에 주입해 주기 위해 Tensor 형태로 바꾸기 위한 일련의 작업 과정이 필요하다. 보통 이러한 데이터 변환을 담당하는 Class 들을 구현한 후 이들의 객체를 리스트로 담아 Dataset의 초기화 인자로 넘겨주어 Dataset에서 리턴하는 모든 datapoint가 동일하게 변환과정을 거치도록 구현한다.

  • 텍스트의 경우, tokenizing / padding / index변환 등의 변환 과정이 구현된 객체가 transformers 인자로 전달된다.

  • 일반적으로 transformer의 마지막 step은 torch 의 Tensor로 변환하는 과정이다.

    class ToIdxTransform(object):
      def __init__(self, vocab_path, max_seq_length = 30, unk_tok = '[UNK]', pad_tok = '[PAD]'):
          with codecs.open(vocab_path, 'r', 'utf-8') as fin:
              chars = [i.strip('\n') for i in fin.readlines() if i.strip()]
    
          self.chars = chars
          self.idx_map = dict(zip(self.chars, range(len(self.chars))))
          self.max_seq_length = max_seq_length
    
          self.unk_tok = unk_tok
          self.pad_tok = pad_tok
    
          assert self.unk_tok in self.idx_map
          assert self.pad_tok in self.idx_map
    
          self.unk_idx = self.idx_map[self.unk_tok]
          self.pad_idx = self.idx_map[self.pad_tok]
    
      def __call__(self, dp):
          d = dp['document']
          char_ids = [self.idx_map.get(c, self.unk_idx) for c in d][:self.max_seq_length]
          char_ids.extend([self.pad_idx] * (self.max_seq_length - len(char_ids)))
    
          dp['document'] = char_ids
          return dp

Custom Dataset 의 구현 예시

class CharacterSequenceDataset(Dataset):
    def __init__(self, df, transformers=None):
        self.df = df
        self.labels = df['label'].map(lambda x : int(x)).to_list()
        self.titles = df['title'].map(lambda x : str(x)).to_list()
        self.transformers = transformers

        assert len(self.labels) == len(self.titles)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Data point
        dp = OrderedDict( {
            'label' : self.labels[idx],
            'title' : self.titles[idx]
        })

        if self.transformers is not None:
            for transformer in self.transformers:
                dp = transformer(dp)

        return dp

'pytorch' 카테고리의 다른 글

02. model(network) 구현하기  (0) 2020.10.03
CrossEntropyLoss in pytorch (feat. log_softmax, nll_loss)  (0) 2020.09.30

+ Recent posts