Skip to content

Commit 7272ca0

Browse files
authored
Multilingual Domain Classifier (#363)
* initial commit Signed-off-by: Sarah Yurick <[email protected]> * run black Signed-off-by: Sarah Yurick <[email protected]> * combine with DomainClassifier Signed-off-by: Sarah Yurick <[email protected]> * isort Signed-off-by: Sarah Yurick <[email protected]> * add links Signed-off-by: Sarah Yurick <[email protected]> * add praateek's suggestion Signed-off-by: Sarah Yurick <[email protected]> * add ryan's suggestion Signed-off-by: Sarah Yurick <[email protected]> * update readmes Signed-off-by: Sarah Yurick <[email protected]> * create MultilingualDomainClassifier Signed-off-by: Sarah Yurick <[email protected]> * add api Signed-off-by: Sarah Yurick <[email protected]> --------- Signed-off-by: Sarah Yurick <[email protected]>
1 parent edd6262 commit 7272ca0

File tree

11 files changed

+370
-34
lines changed

11 files changed

+370
-34
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ All of our text pipelines have great multilingual support.
2828
- [Heuristic Filtering](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html)
2929
- Classifier Filtering
3030
- [fastText](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html)
31-
- GPU-Accelerated models: [Domain, Quality, and Safety Classification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/distributeddataclassification.html)
31+
- GPU-Accelerated models: [Domain (English and multilingual), Quality, and Safety Classification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/distributeddataclassification.html)
3232
- **GPU-Accelerated Deduplication**
3333
- [Exact Deduplication](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html)
3434
- [Fuzzy Deduplication](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html) via MinHash Locality Sensitive Hashing

docs/user-guide/api/classifiers.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@ Classifiers
55
.. autoclass:: nemo_curator.classifiers.DomainClassifier
66
:members:
77

8+
.. autoclass:: nemo_curator.classifiers.MultilingualDomainClassifier
9+
:members:
10+
811
.. autoclass:: nemo_curator.classifiers.QualityClassifier
912
:members:
1013

1114
.. autoclass:: nemo_curator.classifiers.FineWebEduClassifier
1215
:members:
1316

1417
.. autoclass:: nemo_curator.classifiers.AegisClassifier
15-
:members:
18+
:members:

docs/user-guide/cpuvsgpu.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ The following NeMo Curator modules are GPU based.
6767
* Semantic Deduplication
6868
* Distributed Data Classification
6969

70-
* Domain Classification
70+
* Domain Classification (English and multilingual)
7171
* Quality Classification
7272

7373
GPU modules store the ``DocumentDataset`` using a ``cudf`` backend instead of a ``pandas`` one.

docs/user-guide/distributeddataclassification.rst

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@ NeMo Curator provides a module to help users run inference with pre-trained mode
1515
This is achieved by chunking the datasets across multiple computing nodes, each equipped with multiple GPUs, to accelerate the classification task in a distributed manner.
1616
Since the classification of a single text document is independent of other documents within the dataset, we can distribute the workload across multiple nodes and GPUs to perform parallel processing.
1717

18-
Domain, quality, content safety, and educational content models are tasks we include as examples within our module.
18+
Domain (English and multilingual), quality, content safety, and educational content models are tasks we include as examples within our module.
1919

2020
Here, we summarize why each is useful for training an LLM:
2121

2222
- The **Domain Classifier** is useful because it helps the LLM understand the context and specific domain of the input text. Because different domains have different linguistic characteristics and terminologies, an LLM's ability to generate contextually relevant responses can be improved by tailoring training data to a specific domain. Overall, this helps provide more accurate and specialized information.
2323

24+
- The **Multilingual Domain Classifier** is the same as the domain classifier, but has been trained to classify text in 52 languages, including English.
25+
2426
- The **Quality Classifier** is useful for filtering out noisy or low quality data. This allows the model to focus on learning from high quality and informative examples, which contributes to the LLM's robustness and enhances its ability to generate reliable and meaningful outputs. Additionally, quality classification helps mitigate biases and inaccuracies that may arise from poorly curated training data.
2527

2628
- The **AEGIS Safety Models** are essential for filtering harmful or risky content, which is critical for training models that should avoid learning from unsafe data. By classifying content into 13 critical risk categories, AEGIS helps remove harmful or inappropriate data from the training sets, improving the overall ethical and safety standards of the LLM.
@@ -45,7 +47,7 @@ Check out ``nemo_curator.classifiers.base.py`` for reference.
4547
Domain Classifier
4648
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4749

48-
The Domain Classifier is used to categorize text documents into specific domains or subject areas. This is particularly useful for organizing large datasets and tailoring the training data for domain-specific LLMs.
50+
The Domain Classifier is used to categorize English text documents into specific domains or subject areas. This is particularly useful for organizing large datasets and tailoring the training data for domain-specific LLMs.
4951

5052
Let's see how ``DomainClassifier`` works in a small excerpt taken from ``examples/classifiers/domain_example.py``:
5153

@@ -64,6 +66,29 @@ Let's see how ``DomainClassifier`` works in a small excerpt taken from ``example
6466
In this example, the domain classifier is obtained directly from `Hugging Face <https://huggingface.co/nvidia/domain-classifier>`_.
6567
It filters the input dataset to include only documents classified as "Games" or "Sports".
6668

69+
Multilingual Domain Classifier
70+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
71+
72+
The Multilingual Domain Classifier is used to categorize text documents across 52 languages into specific domains or subject areas.
73+
74+
Using the ``MultilingualDomainClassifier`` is very similar to using the ``DomainClassifier`` as described above. Here is an example:
75+
76+
.. code-block:: python
77+
78+
from nemo_curator.classifiers import MultilingualDomainClassifier
79+
80+
files = get_all_files_paths_under("japanese_books_dataset/")
81+
input_dataset = DocumentDataset.read_json(files, backend="cudf")
82+
83+
multilingual_domain_classifier = MultilingualDomainClassifier(
84+
filter_by=["Games", "Sports"],
85+
)
86+
result_dataset = multilingual_domain_classifier(dataset=input_dataset)
87+
88+
result_dataset.to_json("games_and_sports/")
89+
90+
For more information about the multilingual domain classifier, including its supported languages, please see the `nvidia/multilingual-domain-classifier <https://huggingface.co/nvidia/multilingual-domain-classifier>`_ on Hugging Face.
91+
6792
Quality Classifier
6893
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6994

examples/classifiers/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
## Text Classification
22

3-
The Python scripts in this directory demonstrate how to run classification on your text data with each of these 4 classifiers:
3+
The Python scripts in this directory demonstrate how to run classification on your text data with each of these 5 classifiers:
44

55
- Domain Classifier
6+
- Multilingual Domain Classifier
67
- Quality Classifier
78
- AEGIS Safety Models
89
- FineWeb Educational Content Classifier
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
import time
17+
18+
from nemo_curator.classifiers import MultilingualDomainClassifier
19+
from nemo_curator.datasets import DocumentDataset
20+
from nemo_curator.utils.distributed_utils import get_client
21+
from nemo_curator.utils.script_utils import ArgumentHelper
22+
23+
24+
def main(args):
25+
global_st = time.time()
26+
27+
# Input can be a string or list
28+
input_file_path = "/path/to/data"
29+
output_file_path = "./"
30+
31+
client_args = ArgumentHelper.parse_client_args(args)
32+
client_args["cluster_type"] = "gpu"
33+
client = get_client(**client_args)
34+
35+
input_dataset = DocumentDataset.read_json(
36+
input_file_path, backend="cudf", add_filename=True
37+
)
38+
39+
multilingual_domain_classifier = MultilingualDomainClassifier(
40+
filter_by=["Games", "Sports"]
41+
)
42+
result_dataset = multilingual_domain_classifier(dataset=input_dataset)
43+
44+
result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True)
45+
46+
global_et = time.time()
47+
print(
48+
f"Total time taken for multilingual domain classifier inference: {global_et-global_st} s",
49+
flush=True,
50+
)
51+
52+
client.close()
53+
54+
55+
def attach_args(
56+
parser=argparse.ArgumentParser(
57+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
58+
),
59+
):
60+
argumentHelper = ArgumentHelper(parser)
61+
argumentHelper.add_distributed_classifier_cluster_args()
62+
63+
return argumentHelper.parser
64+
65+
66+
if __name__ == "__main__":
67+
main(attach_args().parse_args())

nemo_curator/classifiers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616

1717
os.environ["RAPIDS_NO_INITIALIZE"] = "1"
1818
from .aegis import AegisClassifier, InstructionDataGuardClassifier
19-
from .domain import DomainClassifier
19+
from .domain import DomainClassifier, MultilingualDomainClassifier
2020
from .fineweb_edu import FineWebEduClassifier
2121
from .quality import QualityClassifier
2222

2323
__all__ = [
2424
"DomainClassifier",
25+
"MultilingualDomainClassifier",
2526
"QualityClassifier",
2627
"AegisClassifier",
2728
"InstructionDataGuardClassifier",

nemo_curator/classifiers/domain.py

Lines changed: 128 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,15 @@
2828
from nemo_curator.datasets import DocumentDataset
2929

3030
DOMAIN_IDENTIFIER = "nvidia/domain-classifier"
31+
DOMAIN_BASE_MODEL = "microsoft/deberta-v3-base"
32+
MULTILINGUAL_DOMAIN_IDENTIFIER = "nvidia/multilingual-domain-classifier"
33+
MULTILINGUAL_DOMAIN_BASE_MODEL = "microsoft/mdeberta-v3-base"
3134

3235

3336
@dataclass
3437
class DomainModelConfig:
35-
model: str = "microsoft/deberta-v3-base"
38+
identifier: str = DOMAIN_IDENTIFIER
39+
base_model: str = DOMAIN_BASE_MODEL
3640
fc_dropout: float = 0.2
3741
max_len: int = 512
3842

@@ -49,44 +53,30 @@ def __init__(
4953
if max_mem_gb is None:
5054
max_mem_gb = _get_suggest_memory_for_classifier()
5155

52-
super().__init__(self.config.model, max_mem_gb=max_mem_gb)
56+
super().__init__(self.config.base_model, max_mem_gb=max_mem_gb)
5357

5458
def load_model(self, device: str = "cuda"):
55-
model = HFDeberta.from_pretrained(DOMAIN_IDENTIFIER)
59+
model = HFDeberta.from_pretrained(self.config.identifier)
5660
model.set_autocast(self.autocast)
5761
model = model.to(device)
5862
return model.eval()
5963

6064
def load_tokenizer(self):
61-
return AutoTokenizer.from_pretrained(DOMAIN_IDENTIFIER)
65+
return AutoTokenizer.from_pretrained(self.config.identifier)
6266

6367
def load_config(self):
64-
return AutoConfig.from_pretrained(DOMAIN_IDENTIFIER)
68+
return AutoConfig.from_pretrained(self.config.identifier)
6569

6670

67-
class DomainClassifier(DistributedDataClassifier):
71+
class _DomainClassifier(DistributedDataClassifier):
6872
"""
69-
DomainClassifier is a specialized classifier designed for domain classification tasks, utilizing the
70-
NVIDIA Domain Classifier model (https://huggingface.co/nvidia/domain-classifier). This class is optimized
71-
for running on multi-node, multi-GPU setups to enable fast and efficient inference on large datasets.
72-
73-
Attributes:
74-
filter_by (list[str], optional): The classes to filter the dataset by.
75-
If None, all classes will be included. Defaults to None.
76-
batch_size (int): The number of samples per batch for inference. Defaults to 256.
77-
text_field (str): The field in the dataset that should be classified.
78-
pred_column (str): The column name where predictions will be stored. Defaults to "domain_pred".
79-
prob_column (str, optional): The column name where prediction probabilities will be stored. Defaults to None.
80-
max_chars (int): The maximum number of characters in each document to consider for classification. Defaults to 2000.
81-
device_type (str): The type of device to use for inference, either "cuda" or "cpu". Defaults to "cuda".
82-
autocast (bool): Whether to use mixed precision for faster inference. Defaults to True.
83-
max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
84-
it defaults to the available GPU memory minus 4 GB.
85-
73+
Parent class for DomainClassifier and MultilingualDomainClassifier,
74+
since their implementations are almost identical.
8675
"""
8776

8877
def __init__(
8978
self,
79+
multilingual: bool = False,
9080
filter_by: Optional[List[str]] = None,
9181
batch_size: int = 256,
9282
text_field: str = "text",
@@ -97,7 +87,20 @@ def __init__(
9787
autocast: bool = True,
9888
max_mem_gb: Optional[int] = None,
9989
):
100-
config = AutoConfig.from_pretrained(DOMAIN_IDENTIFIER)
90+
self.multilingual = multilingual
91+
92+
if multilingual:
93+
config = AutoConfig.from_pretrained(MULTILINGUAL_DOMAIN_IDENTIFIER)
94+
model_config = DomainModelConfig(
95+
identifier=MULTILINGUAL_DOMAIN_IDENTIFIER,
96+
base_model=MULTILINGUAL_DOMAIN_BASE_MODEL,
97+
)
98+
else:
99+
config = AutoConfig.from_pretrained(DOMAIN_IDENTIFIER)
100+
model_config = DomainModelConfig(
101+
identifier=DOMAIN_IDENTIFIER,
102+
base_model=DOMAIN_BASE_MODEL,
103+
)
101104

102105
self.text_field = text_field
103106
self.prob_column = prob_column
@@ -106,7 +109,7 @@ def __init__(
106109
self.out_dim = len(self.labels)
107110

108111
model = DomainModel(
109-
config=DomainModelConfig, autocast=autocast, max_mem_gb=max_mem_gb
112+
config=model_config, autocast=autocast, max_mem_gb=max_mem_gb
110113
)
111114

112115
super().__init__(
@@ -122,7 +125,11 @@ def __init__(
122125
)
123126

124127
def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset:
125-
print("Starting domain classifier inference", flush=True)
128+
if self.multilingual:
129+
print("Starting multilingual domain classifier inference", flush=True)
130+
else:
131+
print("Starting domain classifier inference", flush=True)
132+
126133
df = dataset.df
127134
df = _run_classifier_helper(
128135
df=df,
@@ -135,3 +142,98 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset:
135142
prob_col=self.prob_column,
136143
)
137144
return DocumentDataset(df)
145+
146+
147+
class DomainClassifier(_DomainClassifier):
148+
"""
149+
DomainClassifier is a specialized classifier designed for English text domain classification tasks,
150+
utilizing the NVIDIA Domain Classifier (https://huggingface.co/nvidia/domain-classifier) model.
151+
This class is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference on large datasets.
152+
153+
Attributes:
154+
filter_by (list[str], optional): The classes to filter the dataset by.
155+
If None, all classes will be included. Defaults to None.
156+
batch_size (int): The number of samples per batch for inference. Defaults to 256.
157+
text_field (str): The field in the dataset that should be classified.
158+
pred_column (str): The column name where predictions will be stored. Defaults to "domain_pred".
159+
prob_column (str, optional): The column name where prediction probabilities will be stored. Defaults to None.
160+
max_chars (int): The maximum number of characters in each document to consider for classification. Defaults to 2000.
161+
device_type (str): The type of device to use for inference, either "cuda" or "cpu". Defaults to "cuda".
162+
autocast (bool): Whether to use mixed precision for faster inference. Defaults to True.
163+
max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
164+
it defaults to the available GPU memory minus 4 GB.
165+
166+
"""
167+
168+
def __init__(
169+
self,
170+
filter_by: Optional[List[str]] = None,
171+
batch_size: int = 256,
172+
text_field: str = "text",
173+
pred_column: str = "domain_pred",
174+
prob_column: Optional[str] = None,
175+
max_chars: int = 2000,
176+
device_type: str = "cuda",
177+
autocast: bool = True,
178+
max_mem_gb: Optional[int] = None,
179+
):
180+
super().__init__(
181+
multilingual=False,
182+
filter_by=filter_by,
183+
batch_size=batch_size,
184+
text_field=text_field,
185+
pred_column=pred_column,
186+
prob_column=prob_column,
187+
max_chars=max_chars,
188+
device_type=device_type,
189+
autocast=autocast,
190+
max_mem_gb=max_mem_gb,
191+
)
192+
193+
194+
class MultilingualDomainClassifier(_DomainClassifier):
195+
"""
196+
MultilingualDomainClassifier is a specialized classifier designed for domain classification tasks,
197+
utilizing the NVIDIA Multilingual Domain Classifier (https://huggingface.co/nvidia/multilingual-domain-classifier) model.
198+
It supports domain classification across 52 languages.
199+
This class is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference on large datasets.
200+
201+
Attributes:
202+
filter_by (list[str], optional): The classes to filter the dataset by.
203+
If None, all classes will be included. Defaults to None.
204+
batch_size (int): The number of samples per batch for inference. Defaults to 256.
205+
text_field (str): The field in the dataset that should be classified.
206+
pred_column (str): The column name where predictions will be stored. Defaults to "domain_pred".
207+
prob_column (str, optional): The column name where prediction probabilities will be stored. Defaults to None.
208+
max_chars (int): The maximum number of characters in each document to consider for classification. Defaults to 2000.
209+
device_type (str): The type of device to use for inference, either "cuda" or "cpu". Defaults to "cuda".
210+
autocast (bool): Whether to use mixed precision for faster inference. Defaults to True.
211+
max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
212+
it defaults to the available GPU memory minus 4 GB.
213+
214+
"""
215+
216+
def __init__(
217+
self,
218+
filter_by: Optional[List[str]] = None,
219+
batch_size: int = 256,
220+
text_field: str = "text",
221+
pred_column: str = "domain_pred",
222+
prob_column: Optional[str] = None,
223+
max_chars: int = 2000,
224+
device_type: str = "cuda",
225+
autocast: bool = True,
226+
max_mem_gb: Optional[int] = None,
227+
):
228+
super().__init__(
229+
multilingual=True,
230+
filter_by=filter_by,
231+
batch_size=batch_size,
232+
text_field=text_field,
233+
pred_column=pred_column,
234+
prob_column=prob_column,
235+
max_chars=max_chars,
236+
device_type=device_type,
237+
autocast=autocast,
238+
max_mem_gb=max_mem_gb,
239+
)

0 commit comments

Comments
 (0)