Custom PyTorch Collate Function
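If your Dataset class looks something like

    from torch.utils.data import Dataset

    class MyDataset(Dataset):
        # ... boilerplate ...
        def __getitem__(self, idx):
            item = self.data[idx]
            return item['anchor'], item['positive'], item['negative']

your collate function could be

    # `tokenizer` is assumed to be a Hugging Face tokenizer created elsewhere
    def collate_fn(data):
        # data is a list of (anchor, positive, negative) tuples;
        # zip(*data) regroups it into three tuples, one per field
        anchors, pos, neg = zip(*data)
        # tokenize each field as one batch, padded to its longest sequence
        anchors = tokenizer(list(anchors), return_tensors="pt", padding=True)
        pos = tokenizer(list(pos), return_tensors="pt", padding=True)
        neg = tokenizer(list(neg), return_tensors="pt", padding=True)
        return anchors, pos, neg

and you can ...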
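The following minimal pure-Python sketch shows what `zip(*data)` and per-field padding do to a batch, without requiring `transformers` or `torch`. `ToyTokenizer` is a hypothetical stand-in for a Hugging Face tokenizer. It assigns integer ids to whitespace-separated words and pads with 0, and it returns plain lists rather than tensors. The sample sentences are invented for illustration.

```python
# Pure-Python sketch of the triplet collate step.
# ToyTokenizer is HYPOTHETICAL: a stand-in for a Hugging Face tokenizer that
# maps words to integer ids and pads each call's batch to its longest sequence.


class ToyTokenizer:
    def __init__(self):
        self.vocab = {"[PAD]": 0}  # id 0 is reserved for padding

    def __call__(self, texts, padding=True):
        # assign a new id to each unseen word, reuse ids for known words
        ids = [[self.vocab.setdefault(w, len(self.vocab)) for w in t.split()]
               for t in texts]
        max_len = max(len(seq) for seq in ids)
        input_ids, attention_mask = [], []
        for seq in ids:
            pad = max_len - len(seq) if padding else 0
            input_ids.append(seq + [0] * pad)
            attention_mask.append([1] * len(seq) + [0] * pad)
        return {"input_ids": input_ids, "attention_mask": attention_mask}


tokenizer = ToyTokenizer()


def collate_fn(data):
    # data: list of (anchor, positive, negative) triples, one per dataset item
    anchors, pos, neg = zip(*data)
    # each field is tokenized in its own call, so each is padded independently
    return tokenizer(anchors), tokenizer(pos), tokenizer(neg)


batch = [
    ("a cat", "a small cat", "a dog"),
    ("the red car", "a car", "a tree"),
]
anchors, pos, neg = collate_fn(batch)
print(anchors["input_ids"])  # [[1, 2, 0], [3, 4, 5]]
```

In real code, the function is passed to PyTorch's `DataLoader` through its `collate_fn` argument, for example `DataLoader(dataset, batch_size=2, collate_fn=collate_fn)`. With automatic batching, the loader calls it on a list of `batch_size` items returned by `__getitem__`. Each field is tokenized in a separate call, so `padding=True` pads anchors, positives and negatives to their own longest sequence. The three batches can therefore have different widths (here 3, 3 and 2).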