Commit c5f08db
Add support for STS-B dataset with unit tests (#1714)
* Add support for STS-B dataset + unit test
* Modify tests + docstring
* Add dataset documentation
* Add shuffle and sharding
1 parent e631624 commit c5f08db

File tree: 4 files changed, +186 -0 lines changed

docs/source/datasets.rst (+5)

@@ -82,6 +82,11 @@ SST2
 
 .. autofunction:: SST2
 
+STSB
+~~~~
+
+.. autofunction:: STSB
+
 YahooAnswers
 ~~~~~~~~~~~~
 

test/datasets/test_stsb.py (+89)

@@ -0,0 +1,89 @@
import os
import random
import tarfile
from collections import defaultdict
from unittest.mock import patch

from parameterized import parameterized
from torchtext.datasets.stsb import STSB

from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode
from ..common.torchtext_test_case import TorchtextTestCase


def _get_mock_dataset(root_dir):
    """
    root_dir: directory to the mocked dataset
    """
    base_dir = os.path.join(root_dir, "STSB")
    temp_dataset_dir = os.path.join(base_dir, "stsbenchmark")
    os.makedirs(temp_dataset_dir, exist_ok=True)

    seed = 1
    mocked_data = defaultdict(list)
    for file_name, name in zip(["sts-train.csv", "sts-dev.csv", "sts-test.csv"], ["train", "dev", "test"]):
        txt_file = os.path.join(temp_dataset_dir, file_name)
        with open(txt_file, "w", encoding="utf-8") as f:
            for i in range(5):
                label = random.uniform(0, 5)
                rand_string_1 = get_random_unicode(seed)
                rand_string_2 = get_random_unicode(seed + 1)
                rand_string_3 = get_random_unicode(seed + 2)
                rand_string_4 = get_random_unicode(seed + 3)
                rand_string_5 = get_random_unicode(seed + 4)
                dataset_line = (i, label, rand_string_4, rand_string_5)
                # append line to correct dataset split
                mocked_data[name].append(dataset_line)
                f.write(
                    f"{rand_string_1}\t{rand_string_2}\t{rand_string_3}\t{i}\t{label}\t{rand_string_4}\t{rand_string_5}\n"
                )
                seed += 1
                # case with quotes to test arg `quoting=csv.QUOTE_NONE`
                dataset_line = (i, label, rand_string_4, rand_string_5)
                # append line to correct dataset split
                mocked_data[name].append(dataset_line)
                f.write(
                    f'{rand_string_1}"\t"{rand_string_2}\t{rand_string_3}\t{i}\t{label}\t{rand_string_4}\t{rand_string_5}\n'
                )

    compressed_dataset_path = os.path.join(base_dir, "Stsbenchmark.tar.gz")
    # create tar file from dataset folder
    with tarfile.open(compressed_dataset_path, "w:gz") as tar:
        tar.add(temp_dataset_dir, arcname="stsbenchmark")

    return mocked_data


class TestSTSB(TempDirMixin, TorchtextTestCase):
    root_dir = None
    samples = []

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples = _get_mock_dataset(cls.root_dir)
        cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
        cls.patcher.start()

    @classmethod
    def tearDownClass(cls):
        cls.patcher.stop()
        super().tearDownClass()

    @parameterized.expand(["train", "dev", "test"])
    def test_stsb(self, split):
        dataset = STSB(root=self.root_dir, split=split)

        samples = list(dataset)
        expected_samples = self.samples[split]
        for sample, expected_sample in zip_equal(samples, expected_samples):
            self.assertEqual(sample, expected_sample)

    @parameterized.expand(["train", "dev", "test"])
    def test_stsb_split_argument(self, split):
        dataset1 = STSB(root=self.root_dir, split=split)
        (dataset2,) = STSB(root=self.root_dir, split=(split,))

        for d1, d2 in zip_equal(dataset1, dataset2):
            self.assertEqual(d1, d2)
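The quoted-row case in the mock above exists because STS-B sentences can contain literal double-quote characters, which is why the reader in stsb.py parses with `quoting=csv.QUOTE_NONE`. A minimal standalone illustration of the difference, using only the standard-library csv module on a hypothetical row shaped like the mocked one:

import csv
import io

# A row shaped like the quoted test case: a double quote sits just before
# and just after the first tab (hypothetical values).
row = 'a"\t"b\tc\t0\t2.5\tsentence one\tsentence two\n'

# Default quoting: the quote that opens the second field starts a quoted
# field, so the remaining tabs are swallowed into that single field.
print(next(csv.reader(io.StringIO(row), delimiter="\t")))

# QUOTE_NONE: quotes stay literal and every tab splits a field, so columns
# 3-6 (index, score, sentence1, sentence2) come out intact.
print(next(csv.reader(io.StringIO(row), delimiter="\t", quoting=csv.QUOTE_NONE)))

This mirrors what the tests check end to end: each mocked split contains one plain row and one quoted row, and both must round-trip through the dataset pipeline unchanged.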

torchtext/datasets/__init__.py (+2)

@@ -20,6 +20,7 @@
 from .squad1 import SQuAD1
 from .squad2 import SQuAD2
 from .sst2 import SST2
+from .stsb import STSB
 from .udpos import UDPOS
 from .wikitext103 import WikiText103
 from .wikitext2 import WikiText2
@@ -48,6 +49,7 @@
     "SQuAD2": SQuAD2,
     "SogouNews": SogouNews,
     "SST2": SST2,
+    "STSB": STSB,
     "UDPOS": UDPOS,
     "WikiText103": WikiText103,
     "WikiText2": WikiText2,

torchtext/datasets/stsb.py (+90)

@@ -0,0 +1,90 @@
import csv
import os

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
    _create_dataset_directory,
    _wrap_split_argument,
)

if is_module_available("torchdata"):
    from torchdata.datapipes.iter import FileOpener, IterableWrapper

    # we import HttpReader from _download_hooks so we can swap out public URLs
    # with internal URLs when the dataset is used within Facebook
    from torchtext._download_hooks import HttpReader


URL = "http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz"

MD5 = "4eb0065aba063ef77873d3a9c8088811"

NUM_LINES = {
    "train": 5749,
    "dev": 1500,
    "test": 1379,
}

_PATH = "Stsbenchmark.tar.gz"

DATASET_NAME = "STSB"

_EXTRACTED_FILES = {
    "train": os.path.join("stsbenchmark", "sts-train.csv"),
    "dev": os.path.join("stsbenchmark", "sts-dev.csv"),
    "test": os.path.join("stsbenchmark", "sts-test.csv"),
}


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev", "test"))
def STSB(root, split):
    """STSB Dataset

    For additional details refer to https://ixa2.si.ehu.eus/stswiki/index.php/STSbenchmark

    Number of lines per split:
        - train: 5749
        - dev: 1500
        - test: 1379

    Args:
        root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
        split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `dev`, `test`)

    :returns: DataPipe that yields tuple of (index (int), label (float), sentence1 (str), sentence2 (str))
    :rtype: (int, float, str, str)
    """
    # TODO Remove this after removing conditional dependency
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError(
            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
        )

    def _filepath_fn(x=_PATH):
        return os.path.join(root, os.path.basename(x))

    def _extracted_filepath_fn(_=None):
        return _filepath_fn(_EXTRACTED_FILES[split])

    def _filter_fn(x):
        return _EXTRACTED_FILES[split] in x[0]

    def _modify_res(x):
        return (int(x[3]), float(x[4]), x[5], x[6])

    url_dp = IterableWrapper([URL])
    cache_compressed_dp = url_dp.on_disk_cache(
        filepath_fn=_filepath_fn,
        hash_dict={_filepath_fn(URL): MD5},
        hash_type="md5",
    )
    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(_filter_fn)
    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

    data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
    parsed_data = data_dp.parse_csv(delimiter="\t", quoting=csv.QUOTE_NONE).map(_modify_res)
    return parsed_data.shuffle().set_shuffle(False).sharding_filter()
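For reference, a minimal usage sketch (not part of the diff; it assumes torchdata is installed and that the benchmark archive downloads into the default cache directory):

from torchtext.datasets import STSB

# A single split name returns a single DataPipe; the archive is downloaded,
# hash-checked, and extracted on first use.
dev = STSB(split="dev")
for index, score, sentence1, sentence2 in dev:
    # index is an int, score a float in [0, 5], the sentences are strings
    print(index, score, sentence1, sentence2)
    break

# A tuple of split names returns one DataPipe per split.
train, dev, test = STSB(split=("train", "dev", "test"))

The trailing shuffle().set_shuffle(False).sharding_filter() is what the "Add shuffle and sharding" bullet in the commit message refers to: shuffle and sharding steps are attached to the pipe but shuffling stays off by default, so a downstream DataLoader can enable it and distribute shards across workers.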
