diff --git a/src/laion_clap/training/data.py b/src/laion_clap/training/data.py
index 765db7c..3fd9a62 100644
--- a/src/laion_clap/training/data.py
+++ b/src/laion_clap/training/data.py
@@ -41,9 +41,9 @@ except ImportError:
     torchaudio = None
 
-bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
-roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
-bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+bert_tokenizer = None
+roberta_tokenizer = None
+bart_tokenizer = None
 
 
 def tokenizer(text, tmodel="roberta", max_length=77):
     """tokenizer for different models
@@ -51,10 +51,13 @@ def tokenizer(text, tmodel="roberta", max_length=77):
     max_length is default to 77 from the OpenAI CLIP parameters
     We assume text to be a single string, but it can also be a list of strings
     """
+    global bert_tokenizer, roberta_tokenizer, bart_tokenizer
     if tmodel == "transformer":
         return clip_tokenizer(text).squeeze(0)
 
     elif tmodel == "bert":
+        if bert_tokenizer is None:
+            bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
         result = bert_tokenizer(
             text,
             padding="max_length",
@@ -63,8 +66,10 @@ def tokenizer(text, tmodel="roberta", max_length=77):
             return_tensors="pt",
         )
         return {k: v.squeeze(0) for k, v in result.items()}
 
     elif tmodel == "roberta":
+        if roberta_tokenizer is None:
+            roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
         result = roberta_tokenizer(
             text,
             padding="max_length",
@@ -73,8 +78,10 @@ def tokenizer(text, tmodel="roberta", max_length=77):
             return_tensors="pt",
         )
         return {k: v.squeeze(0) for k, v in result.items()}
 
     elif tmodel == "bart":
+        if bart_tokenizer is None:
+            bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
         result = bart_tokenizer(
             text,
             padding="max_length",
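
For a quick sanity check of the lazy-loading path (not part of the patch itself), the function can be exercised as below. This is a sketch, not a test from the repo: it assumes the project's src/ directory is importable so that laion_clap.training.data resolves, and that the standard Hugging Face tokenizer output keys (input_ids, attention_mask) apply.

    from laion_clap.training.data import tokenizer

    # The first call with tmodel="roberta" triggers
    # RobertaTokenizer.from_pretrained("roberta-base"); later calls reuse the
    # module-level instance instead of reloading it.
    out = tokenizer("a dog barking in the distance", tmodel="roberta", max_length=77)
    print(out["input_ids"].shape)       # torch.Size([77]) after squeeze(0)
    print(out["attention_mask"].shape)  # torch.Size([77])

With this change, processes that only use tmodel="transformer" (or never call tokenizer at all) no longer download and construct the BERT, RoBERTa, and BART tokenizers at import time.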