From d86ef7c3a5eefb91ab7174dc36164904b43e6f93 Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Mon, 22 Aug 2022 17:22:01 +0530 Subject: [PATCH] Parameterize fluency, diversity and adequacy model tags --- parrot/parrot.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/parrot/parrot.py b/parrot/parrot.py index f8ec738..e232f56 100644 --- a/parrot/parrot.py +++ b/parrot/parrot.py @@ -1,6 +1,12 @@ class Parrot(): - def __init__(self, model_tag="prithivida/parrot_paraphraser_on_T5", use_gpu=False): + def __init__( + self, + model_tag="prithivida/parrot_paraphraser_on_T5", + adequacy_model="prithivida/parrot_adequacy_model", + fluency_model="prithivida/parrot_fluency_model", + diversity_model="paraphrase-distilroberta-base-v2", + use_gpu=False): from transformers import AutoTokenizer from transformers import AutoModelForSeq2SeqLM import pandas as pd @@ -9,9 +15,9 @@ def __init__(self, model_tag="prithivida/parrot_paraphraser_on_T5", use_gpu=Fals from parrot.filters import Diversity self.tokenizer = AutoTokenizer.from_pretrained(model_tag) self.model = AutoModelForSeq2SeqLM.from_pretrained(model_tag) - self.adequacy_score = Adequacy() - self.fluency_score = Fluency() - self.diversity_score= Diversity() + self.adequacy_score = Adequacy(model_tag=adequacy_model) + self.fluency_score = Fluency(model_tag=fluency_model) + self.diversity_score= Diversity(model_tag=diversity_model) def rephrase(self, input_phrase, use_gpu=False, diversity_ranker="levenshtein", do_diverse=False, style=1, max_length=32, adequacy_threshold = 0.90, fluency_threshold = 0.90): if use_gpu: