From 5506913ca044ada8d27aa4642c80764484424df8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carl=20Thom=C3=A9?=
Date: Fri, 20 Jun 2025 15:24:45 +0200
Subject: [PATCH] Download models on first tokenizer call instead of module
 import

---
 src/laion_clap/training/data.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/laion_clap/training/data.py b/src/laion_clap/training/data.py
index 765db7c..3fd9a62 100644
--- a/src/laion_clap/training/data.py
+++ b/src/laion_clap/training/data.py
@@ -41,9 +41,9 @@
 except ImportError:
     torchaudio = None
 
-bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
-roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
-bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+bert_tokenizer = None
+roberta_tokenizer = None
+bart_tokenizer = None
 
 def tokenizer(text, tmodel="roberta", max_length=77):
     """tokenizer for different models
@@ -51,10 +51,12 @@ def tokenizer(text, tmodel="roberta", max_length=77):
     max_length is default to 77 from the OpenAI CLIP parameters
     We assume text to be a single string, but it can also be a list of strings
     """
+    global bert_tokenizer, roberta_tokenizer, bart_tokenizer
     if tmodel == "transformer":
         return clip_tokenizer(text).squeeze(0)
-
     elif tmodel == "bert":
+        if bert_tokenizer is None:
+            bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
         result = bert_tokenizer(
             text,
             padding="max_length",
@@ -63,8 +65,9 @@
             return_tensors="pt",
         )
         return {k: v.squeeze(0) for k, v in result.items()}
-
     elif tmodel == "roberta":
+        if roberta_tokenizer is None:
+            roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
         result = roberta_tokenizer(
             text,
             padding="max_length",
@@ -73,8 +76,9 @@
             return_tensors="pt",
         )
         return {k: v.squeeze(0) for k, v in result.items()}
-
     elif tmodel == "bart":
+        if bart_tokenizer is None:
+            bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
         result = bart_tokenizer(
             text,
             padding="max_length",
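
Reviewer note (not part of the patch): the change replaces eager
`from_pretrained` downloads at module import with lazy initialization
guarded by `is None` checks and a `global` statement. A minimal
standalone sketch of the same pattern, shown here with
`functools.lru_cache` instead of module globals; the helper names
`get_roberta_tokenizer` and `tokenize` are illustrative, not from the
codebase:

    from functools import lru_cache

    from transformers import RobertaTokenizer

    @lru_cache(maxsize=None)
    def get_roberta_tokenizer():
        # Downloads (or loads the local cache of) "roberta-base" on the
        # first call only; later calls reuse the memoized instance.
        return RobertaTokenizer.from_pretrained("roberta-base")

    def tokenize(text, max_length=77):
        # Mirrors the patched tokenizer() call: pad/truncate to the CLIP
        # context length of 77 and return PyTorch tensors.
        result = get_roberta_tokenizer()(
            text,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        return {k: v.squeeze(0) for k, v in result.items()}

One caveat that applies to both variants: neither the `global` guard in
the patch nor `lru_cache` serializes the first call across threads, so
two threads hitting a cold cache may both trigger the download. For
DataLoader workers, which are separate processes, this is harmless
beyond the duplicated fetch.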