main_evaluate.py
import os
import fire
import json
import logging
from tqdm import tqdm
import torch
from dataset import Dataset
from model import Model

# = = = = = = = = = = = Logging Setup = = = = = = = = = = = = =
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = =


def do_model_prediction(input_data, model, batch_size):
    """Run inference with the model.

    batch_size == -1 passes the whole dataset to a single generate() call
    (for models that handle batching internally); batch_size == 1 iterates
    over the samples one at a time.
    """
    if batch_size not in [1, -1]:
        raise NotImplementedError("Batch size {} not implemented yet".format(batch_size))
    if batch_size == -1:
        model_predictions = model.generate(input_data)
    else:
        model_predictions = []
        for inputs in tqdm(input_data, leave=False):
            outputs = model.generate(inputs)
            # generate() may return a single prediction or a list of them
            if isinstance(outputs, list):
                model_predictions.extend(outputs)
            else:
                model_predictions.append(outputs)
    return model_predictions
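
# A minimal sketch of how do_model_prediction behaves under the two supported
# batch sizes. DummyModel is hypothetical and stands in for the project's
# Model class; it only needs a generate() method.
#
#     class DummyModel:
#         def generate(self, inputs):
#             if isinstance(inputs, list):  # batch_size == -1: whole dataset at once
#                 return ["pred"] * len(inputs)
#             return "pred"                 # batch_size == 1: one sample at a time
#
#     do_model_prediction(["a", "b"], DummyModel(), batch_size=-1)  # -> ["pred", "pred"]
#     do_model_prediction(["a", "b"], DummyModel(), batch_size=1)   # -> ["pred", "pred"]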


def main(
    dataset_name: str = None,
    model_name: str = None,
    batch_size: int = 1,  # effectively a dummy parameter; only 1 and -1 are supported (see do_model_prediction)
    overwrite: bool = False,
    metrics: str = None,
    number_of_samples: int = -1,
):
logger.info("= = "*20)
logger.info("Dataset name: {}".format(dataset_name))
logger.info("Model name: {}".format(model_name))
logger.info("Batch size: {}".format(batch_size))
logger.info("Overwrite: {}".format(overwrite))
logger.info("Metrics: {}".format(metrics))
logger.info("Number of samples: {}".format(number_of_samples))
logger.info("= = "*20)
# If the final score log exists, skip the evaluation
if not overwrite and os.path.exists('log/{}/{}_{}_score.json'.format(model_name, dataset_name, metrics)):
logger.info("Evaluation has been done before. Skip the evaluation.")
logger.info("\n\n\n\n\n")
return
if model_name == 'WavLLM_fairseq':
batch_size = -1
logger.info("Batch size is set to -1 for WavLLM_fairseq model.")
    dataset = Dataset(dataset_name, number_of_samples)

    if overwrite or not os.path.exists('log/{}/{}.json'.format(model_name, dataset_name)):
        logger.info("Overwrite is enabled or the results are not found. Trying to infer with the model: {}.".format(model_name))
        # Load model
        model = Model(model_name)
        # Tell the model which dataset it is being evaluated on
        model.dataset_name = dataset.dataset_name
        # Infer with the model
        model_predictions = do_model_prediction(dataset.input_data, model, batch_size=batch_size)
        data_with_model_predictions = dataset.dataset_processor.format_model_predictions(dataset.input_data, model_predictions)
        # Save the result with predictions
        os.makedirs('log/{}'.format(model_name), exist_ok=True)
        with open('log/{}/{}.json'.format(model_name, dataset_name), 'w') as f:
            json.dump(data_with_model_predictions, f, indent=4, ensure_ascii=False)

    with open('log/{}/{}.json'.format(model_name, dataset_name)) as f:
        data_with_model_predictions = json.load(f)

    # Metric evaluation
    try:
        # Clear the cache to avoid a memory leak; `model` is undefined when
        # cached predictions were loaded above, hence the guard
        logger.info("Clear the cache to avoid memory leak")
        del model
        torch.cuda.empty_cache()
    except NameError:
        pass

    results = dataset.dataset_processor.compute_score(data_with_model_predictions, metrics=metrics)

    # Print the result with metrics
    logger.info('= = = = = = = = = = = = = = = = =')
    logger.info('Dataset name: {}'.format(dataset_name.upper()))
    logger.info('Model name: {}'.format(model_name.upper()))
    logger.info(json.dumps({metrics: results[metrics]}, indent=4, ensure_ascii=False))
    logger.info('= = = = = = = = = = = = = = = = =')

    # Save the scores
    with open('log/{}/{}_{}_score.json'.format(model_name, dataset_name, metrics), 'w') as f:
        json.dump(results, f, indent=4, ensure_ascii=False)


if __name__ == "__main__":
    fire.Fire(main)
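
# Example invocation via python-fire, which maps CLI flags onto main()'s
# keyword arguments. The dataset, model, and metric names below are
# hypothetical placeholders; use the names defined by the local Dataset and
# Model modules.
#
#   python main_evaluate.py \
#       --dataset_name my_dataset \
#       --model_name my_model \
#       --metrics wer \
#       --number_of_samples 100 \
#       --overwrite True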