Luke Salamone's Blog

Custom PyTorch Collate Function

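If your Dataset class looks something like

    from torch.utils.data import Dataset

    class MyDataset(Dataset):
        # ... boilerplate ...
        def __getitem__(self, idx):
            item = self.data[idx]
            # One (anchor, positive, negative) triplet per sample
            return item['anchor'], item['positive'], item['negative']

your collate function should be

    def collate_fn(data):
        # data is a list of (anchor, positive, negative) tuples from __getitem__;
        # zip(*data) regroups it into three tuples: anchors, positives, negatives
        anchors, pos, neg = zip(*data)
        # Assumption: tokenizer is a Hugging Face tokenizer defined elsewhere;
        # padding=True pads each group to its longest sequence in the batch
        anchors = tokenizer(anchors, return_tensors="pt", padding=True)
        pos = tokenizer(pos, return_tensors="pt", padding=True)
        neg = tokenizer(neg, return_tensors="pt", padding=True)
        return anchors, pos, neg

and you can ...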
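The step that does the real work is zip(*data). A DataLoader calls collate_fn with a list holding one __getitem__ result per sample, and unpacking that list into zip transposes it into one tuple per field. The following is a minimal, dependency-free sketch of that behavior, not the author's code: toy_tokenizer is a hypothetical stand-in for a real tokenizer called with padding=True (it uses each word's length as a fake token id and pads every text to the longest one in the batch), and the hand-written batch imitates what a DataLoader with batch_size=2 would pass in.

```python
def toy_tokenizer(texts, pad_id=0):
    """Hypothetical tokenizer stand-in (not a real library API).

    Splits each text on whitespace, uses each word's length as a fake token id,
    and pads every sequence to the longest one, like padding=True would.
    """
    ids = [[len(word) for word in text.split()] for text in texts]
    longest = max(len(seq) for seq in ids)
    input_ids = [seq + [pad_id] * (longest - len(seq)) for seq in ids]
    attention_mask = [[1] * len(seq) + [0] * (longest - len(seq)) for seq in ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def collate_fn(data):
    # data is a list of (anchor, positive, negative) tuples, one per sample.
    # zip(*data) transposes it into three tuples: all anchors, all positives,
    # and all negatives in the batch.
    anchors, pos, neg = zip(*data)
    return toy_tokenizer(anchors), toy_tokenizer(pos), toy_tokenizer(neg)


# Imitates the list a DataLoader with batch_size=2 would hand to collate_fn.
batch = [
    ("a cat", "a small cat", "a truck"),
    ("the red sun", "sun", "cold moon rises"),
]
anchors, pos, neg = collate_fn(batch)
print(anchors["input_ids"])       # [[1, 3, 0], [3, 3, 3]]
print(anchors["attention_mask"])  # [[1, 1, 0], [1, 1, 1]]
```

Because each field is tokenized separately, anchors, positives, and negatives are each padded only to the longest text in their own group, which is the same per-group padding the three tokenizer calls above produce.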
