From f37bd55d7c8c696a4f632734fcae0e11d23067bc Mon Sep 17 00:00:00 2001 From: Gaspar Rochette Date: Wed, 14 Jan 2026 15:50:31 +0000 Subject: [PATCH 1/8] refactor: formalize artifact saving --- src/pruna/config/smash_config.py | 4 + src/pruna/engine/load.py | 27 +----- src/pruna/engine/load_artifacts.py | 98 +++++++++++++++++++++ src/pruna/engine/save.py | 43 ++------- src/pruna/engine/save_artifacts.py | 134 +++++++++++++++++++++++++++++ 5 files changed, 248 insertions(+), 58 deletions(-) create mode 100644 src/pruna/engine/load_artifacts.py create mode 100644 src/pruna/engine/save_artifacts.py diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index ba277633..6f3aebfb 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -40,7 +40,9 @@ "device_map", "cache_dir", "save_fns", + "save_artifacts_fns", "load_fns", + "load_artifacts_fns", "reapply_after_load", ] @@ -86,6 +88,8 @@ def __init__( self.save_fns: list[str] = [] self.load_fns: list[str] = [] + self.save_artifacts_fns: list[str] = [] + self.load_artifacts_fns: list[str] = [] self.reapply_after_load: dict[str, str | None] = {} self.tokenizer: PreTrainedTokenizerBase | None = None self.processor: ProcessorMixin | None = None diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index b310b593..a5d853dd 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -30,6 +30,7 @@ from transformers import pipeline from pruna import SmashConfig +from pruna.engine.load_artifacts import load_artifacts from pruna.engine.utils import load_json_config, move_to_device, set_to_best_available_device from pruna.logging.logger import pruna_logger @@ -64,11 +65,6 @@ def load_pruna_model(model_path: str | Path, **kwargs) -> tuple[Any, SmashConfig if len(smash_config.load_fns) == 0: raise ValueError("Load function has not been set.") - # load torch artifacts if they exist - if LOAD_FUNCTIONS.torch_artifacts.name in smash_config.load_fns: - load_torch_artifacts(model_path, **kwargs) - smash_config.load_fns.remove(LOAD_FUNCTIONS.torch_artifacts.name) - if len(smash_config.load_fns) > 1: pruna_logger.error(f"Load functions not used: {smash_config.load_fns[1:]}") @@ -78,6 +74,9 @@ def load_pruna_model(model_path: str | Path, **kwargs) -> tuple[Any, SmashConfig if any(algorithm is not None for algorithm in smash_config.reapply_after_load.values()): model = resmash_fn(model, smash_config) + # load artifacts (e.g. speed up the warmup or make it more consistent) + load_artifacts(model, model_path, smash_config) + return model, smash_config @@ -440,23 +439,6 @@ def load_quantized_model(quantized_path: str | Path) -> Any: ) -def load_torch_artifacts(model_path: str | Path, **kwargs) -> None: - """ - Load a torch artifacts from the given model path. - - Parameters - ---------- - model_path : str | Path - The path to the model directory. - **kwargs : Any - Additional keyword arguments to pass to the model loading function. - """ - artifact_path = Path(model_path) / "artifact_bytes.bin" - artifact_bytes = artifact_path.read_bytes() - - torch.compiler.load_cache_artifacts(artifact_bytes) - - def load_hqq_diffusers(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: """ Load a diffusers model from the given model path. @@ -580,7 +562,6 @@ class LOAD_FUNCTIONS(Enum): # noqa: N801 pickled = partial(load_pickled) hqq = partial(load_hqq) hqq_diffusers = partial(load_hqq_diffusers) - torch_artifacts = partial(load_torch_artifacts) def __call__(self, *args, **kwargs) -> Any: """ diff --git a/src/pruna/engine/load_artifacts.py b/src/pruna/engine/load_artifacts.py new file mode 100644 index 00000000..1447e0e4 --- /dev/null +++ b/src/pruna/engine/load_artifacts.py @@ -0,0 +1,98 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from enum import Enum +from functools import partial +from pathlib import Path +from typing import Any + +import torch + +from pruna.config.smash_config import SmashConfig +from pruna.logging.logger import pruna_logger + + +def load_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: + """ + Load available artifacts. + + This function is intended to be called after the main model load function. + It loads artifacts specific to different algorithms into the pre-loaded model. + + Parameters + ---------- + model : Any + The model to load the artifacts for. + model_path : str | Path + The directory to load the artifacts from. + smash_config : SmashConfig + The SmashConfig object containing the load and save functions. + + Returns + ------- + None + The function does not return anything. + """ + artifact_fns = getattr(smash_config, "load_artifacts_fns", []) + if not artifact_fns: + return + + for fn_name in artifact_fns: + # Only handle artifact loaders we explicitly know about here. + if fn_name not in LOAD_ARTIFACTS_FUNCTIONS.__members__: + continue + + LOAD_ARTIFACTS_FUNCTIONS[fn_name](model, model_path, smash_config) + + +def load_torch_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: + """ + Load a torch artifacts from the given model path. + + Parameters + ---------- + model : Any + The model to load the artifacts for. + model_path : str | Path + The directory to load the artifacts from. + smash_config : SmashConfig + The SmashConfig object containing the load and save functions. + """ + artifact_path = Path(model_path) / "artifact_bytes.bin" + if not artifact_path.exists(): + pruna_logger.error(f"No torch artifacts found at {artifact_path}; skipping torch artifact loading.") + return + + pruna_logger.info(f"Loading torch artifacts from {artifact_path}") + artifact_bytes = artifact_path.read_bytes() + + torch.compiler.load_cache_artifacts(artifact_bytes) + + +class LOAD_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 + """ + Enumeration of *artifact* load functions. + + Artifact loaders are functions that are called *after* the main model load + has completed. They attach additional runtime state to the already-loaded + model (e.g. FP8 scales) or restore global caches. + """ + + torch_artifacts = partial(load_torch_artifacts) + + def __call__(self, *args, **kwargs) -> None: + """Call the underlying load function.""" + if self.value is not None: + self.value(*args, **kwargs) diff --git a/src/pruna/engine/save.py b/src/pruna/engine/save.py index 0c741017..63643a37 100644 --- a/src/pruna/engine/save.py +++ b/src/pruna/engine/save.py @@ -36,6 +36,7 @@ SAVE_BEFORE_SMASH_CACHE_DIR, ) from pruna.engine.model_checks import get_helpers, is_janus_llamagen_ar +from pruna.engine.save_artifacts import save_artifacts from pruna.engine.utils import determine_dtype, monkeypatch from pruna.logging.logger import pruna_logger @@ -61,9 +62,10 @@ def save_pruna_model(model: Any, model_path: str | Path, smash_config: SmashConf if not model_path.exists(): model_path.mkdir(parents=True, exist_ok=True) - if SAVE_FUNCTIONS.torch_artifacts.name in smash_config.save_fns: - save_torch_artifacts(model, model_path, smash_config) - smash_config.save_fns.remove(SAVE_FUNCTIONS.torch_artifacts.name) + # Backward compatibility with torch artifacts save function + if "torch_artifacts" in smash_config.save_fns: + smash_config.save_fns.remove("torch_artifacts") + smash_config.save_artifacts_fns.append("torch_artifacts") # in the case of no specialized save functions, we use the model's original save function if len(smash_config.save_fns) == 0: @@ -87,11 +89,13 @@ def save_pruna_model(model: Any, model_path: str | Path, smash_config: SmashConf else: pruna_logger.debug(f"Several save functions stacked: {smash_config.save_fns}, defaulting to pickled") save_fn = SAVE_FUNCTIONS.pickled - smash_config.load_fns = [LOAD_FUNCTIONS.pickled.name] # execute selected save function save_fn(model, model_path, smash_config) + # save artifacts as well + save_artifacts(model, model_path, smash_config) + # save smash config (includes tokenizer and processor) smash_config.save_to_json(model_path) @@ -460,36 +464,6 @@ def save_component(attr_name: str | None, module: torch.nn.Module, subpaths: lis smash_config.load_fns.append(LOAD_FUNCTIONS.hqq_diffusers.name) -def save_torch_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: - """ - Save the model by saving the torch artifacts. - - Parameters - ---------- - model : Any - The model to save. - model_path : str | Path - The directory to save the model to. - smash_config : SmashConfig - The SmashConfig object containing the save and load functions. - """ - artifacts = torch.compiler.save_cache_artifacts() - - assert artifacts is not None - artifact_bytes, _ = artifacts - - # check if the bytes are empty - if artifact_bytes == b"\x00\x00\x00\x00\x00\x00\x00\x01": - pruna_logger.error( - "Model has not been run before. Please run the model before saving to construct the compilation graph." - ) - - artifact_path = Path(model_path) / "artifact_bytes.bin" - artifact_path.write_bytes(artifact_bytes) - - smash_config.load_fns.append(LOAD_FUNCTIONS.torch_artifacts.name) - - def reapply(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: """ Reapply the model. @@ -543,7 +517,6 @@ class SAVE_FUNCTIONS(Enum): # noqa: N801 hqq_diffusers = partial(save_model_hqq_diffusers) save_before_apply = partial(save_before_apply) reapply = partial(reapply) - torch_artifacts = partial(save_torch_artifacts) def __call__(self, *args, **kwargs) -> None: """ diff --git a/src/pruna/engine/save_artifacts.py b/src/pruna/engine/save_artifacts.py new file mode 100644 index 00000000..8907a0f6 --- /dev/null +++ b/src/pruna/engine/save_artifacts.py @@ -0,0 +1,134 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum +from functools import partial +from pathlib import Path +from typing import Any + +import torch + +from pruna.config.smash_config import SmashConfig +from pruna.engine.load import LOAD_FUNCTIONS +from pruna.logging.logger import pruna_logger + + +def save_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: + """ + Save all configured artifacts for a model. + + This function is intended to be called *after* the main model save function + (e.g. :func:`save_pruna_model`). It iterates over + ``smash_config.save_artifacts_fns`` and invokes each corresponding + :class:`SAVE_ARTIFACTS_FUNCTIONS` member. Each artifact saver is independent + and is responsible for appending its own load function(s) to + ``smash_config.load_fns`` as needed. + + Parameters + ---------- + model : Any + The model to save artifacts for. + model_path : str | Path + The directory where the model and its artifacts are saved. + smash_config : SmashConfig + The SmashConfig object containing the artifact save function names in + ``save_artifacts_fns``. + """ + smash_config.load_artifacts_fns.clear() # accumulate as we run the save artifact functions + + artifact_fns = getattr(smash_config, "save_artifacts_fns", []) + for fn_name in artifact_fns: + try: + SAVE_ARTIFACTS_FUNCTIONS[fn_name](model, model_path, smash_config) + except KeyError: + pruna_logger.error( + "Unknown artifact save function '%s' in smash_config.save_artifacts_fns; skipping.", fn_name + ) + + +def save_torch_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: + """ + Save the model by saving the torch artifacts. + + Parameters + ---------- + model : Any + The model to save. + model_path : str | Path + The directory to save the model to. + smash_config : SmashConfig + The SmashConfig object containing the save and load functions. + """ + artifacts = torch.compiler.save_cache_artifacts() + + assert artifacts is not None + artifact_bytes, _ = artifacts + + # check if the bytes are empty + if artifact_bytes == b"\x00\x00\x00\x00\x00\x00\x00\x01": + pruna_logger.error( + "Model has not been run before. Please run the model before saving to construct the compilation graph." + ) + + artifact_path = Path(model_path) / "artifact_bytes.bin" + artifact_path.write_bytes(artifact_bytes) + + smash_config.load_fns.append(LOAD_FUNCTIONS.torch_artifacts.name) + + +class SAVE_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 + """ + Enumeration of *artifact* save functions. + + Artifact savers are called after the main model save function has run. + They produce additional artifacts (e.g. compilation caches) to speed up + warmup or make the inference before and after loading consistent. + + This enum provides callable functions for saving such artifacts. + + Parameters + ---------- + value : callable + The artifact save function to be called. + names : str + The name of the enum member. + module : str + The module where the enum is defined. + qualname : str + The qualified name of the enum. + type : type + The type of the enum. + start : int + The start index for auto-numbering enum values. + + Examples + -------- + >>> SAVE_ARTIFACTS_FUNCTIONS.torch_artifacts(model, save_path, smash_config) + # Torch artifacts saved alongside the main model + """ + + torch_artifacts = partial(save_torch_artifacts) + + def __call__(self, *args, **kwargs) -> None: + """ + Call the underlying save function. + + Parameters + ---------- + args : Any + The arguments to pass to the save function. + kwargs : Any + The keyword arguments to pass to the save function. + """ + if self.value is not None: + self.value(*args, **kwargs) From 4b74ed21478ec66f1885fc0e04cd993c46e12a2f Mon Sep 17 00:00:00 2001 From: Gaspar Rochette Date: Wed, 14 Jan 2026 16:51:25 +0000 Subject: [PATCH 2/8] fix: load artifacts docstring --- src/pruna/engine/load_artifacts.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/pruna/engine/load_artifacts.py b/src/pruna/engine/load_artifacts.py index 1447e0e4..de52e483 100644 --- a/src/pruna/engine/load_artifacts.py +++ b/src/pruna/engine/load_artifacts.py @@ -85,9 +85,31 @@ class LOAD_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 """ Enumeration of *artifact* load functions. - Artifact loaders are functions that are called *after* the main model load + Artifact loaders are functions that are called after the main model load has completed. They attach additional runtime state to the already-loaded - model (e.g. FP8 scales) or restore global caches. + model (e.g. compilation cache). + + This enum provides callable functions for loading such artifacts. + + Parameters + ---------- + value : callable + The artifact load function to be called. + names : str + The name of the enum member. + module : str + The module where the enum is defined. + qualname : str + The qualified name of the enum. + type : type + The type of the enum. + start : int + The start index for auto-numbering enum values. + + Examples + -------- + >>> LOAD_ARTIFACTS_FUNCTIONS.torch_artifacts(model, model_path, smash_config) + # Torch artifacts loaded into the current runtime """ torch_artifacts = partial(load_torch_artifacts) From 86086b5f4c10e99e38907aaf9f6cf76201aba10e Mon Sep 17 00:00:00 2001 From: Gaspar Rochette Date: Wed, 14 Jan 2026 16:57:27 +0000 Subject: [PATCH 3/8] fix: register torch compile save as save artifacts --- src/pruna/algorithms/torch_compile/torch_compile.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pruna/algorithms/torch_compile/torch_compile.py b/src/pruna/algorithms/torch_compile/torch_compile.py index 36fa15a3..6149be7a 100644 --- a/src/pruna/algorithms/torch_compile/torch_compile.py +++ b/src/pruna/algorithms/torch_compile/torch_compile.py @@ -35,6 +35,7 @@ is_transformers_pipeline_with_causal_lm, ) from pruna.engine.save import SAVE_FUNCTIONS +from pruna.engine.save_artifacts import SAVE_ARTIFACTS_FUNCTIONS from pruna.logging.logger import pruna_logger # This allows for torch compile to use more cache memory to compile the model @@ -187,7 +188,7 @@ def apply(self, model: Any, smash_config: SmashConfig) -> Any: # importantly, the torch artifacts saving need to be done *after* the before-compile-save if smash_config["torch_compile_make_portable"]: - smash_config.save_fns.append(SAVE_FUNCTIONS.torch_artifacts.name) + smash_config.save_artifacts_fns.append(SAVE_ARTIFACTS_FUNCTIONS.torch_artifacts.name) return output def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: From 42ffaab8205054eafd446d9429eb290cf779b3bf Mon Sep 17 00:00:00 2001 From: Gaspar Rochette Date: Thu, 15 Jan 2026 15:33:07 +0000 Subject: [PATCH 4/8] fix: address bugbot comments --- src/pruna/config/smash_config.py | 2 ++ src/pruna/engine/load.py | 5 +++++ src/pruna/engine/save_artifacts.py | 4 ++-- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index 6f3aebfb..8b0f640d 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -354,6 +354,8 @@ def flush_configuration(self) -> None: # flush also saving / load functionality associated with a specific configuration self.save_fns = [] self.load_fns = [] + self.save_artifacts_fns = [] + self.load_artifacts_fns = [] self.reapply_after_load = {} # reset potentially previously used cache directory diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index a5d853dd..aa363cc3 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -60,6 +60,11 @@ def load_pruna_model(model_path: str | Path, **kwargs) -> tuple[Any, SmashConfig # since the model was just loaded from a file, we do not need to prepare saving anymore smash_config._prepare_saving = False + # Backward compatibility with torch artifacts save function + if "torch_artifacts" in smash_config.load_fns: + smash_config.load_fns.remove("torch_artifacts") + smash_config.load_artifacts_fns.append("torch_artifacts") + resmash_fn = kwargs.pop("resmash_fn", resmash) if len(smash_config.load_fns) == 0: diff --git a/src/pruna/engine/save_artifacts.py b/src/pruna/engine/save_artifacts.py index 8907a0f6..df899d68 100644 --- a/src/pruna/engine/save_artifacts.py +++ b/src/pruna/engine/save_artifacts.py @@ -19,7 +19,7 @@ import torch from pruna.config.smash_config import SmashConfig -from pruna.engine.load import LOAD_FUNCTIONS +from pruna.engine.load_artifacts import LOAD_ARTIFACTS_FUNCTIONS from pruna.logging.logger import pruna_logger @@ -83,7 +83,7 @@ def save_torch_artifacts(model: Any, model_path: str | Path, smash_config: Smash artifact_path = Path(model_path) / "artifact_bytes.bin" artifact_path.write_bytes(artifact_bytes) - smash_config.load_fns.append(LOAD_FUNCTIONS.torch_artifacts.name) + smash_config.load_artifacts_fns.append(LOAD_ARTIFACTS_FUNCTIONS.torch_artifacts.name) class SAVE_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 From 759a7247141dbbca6bce71fcf9ba2d9a65c37cbf Mon Sep 17 00:00:00 2001 From: Gaspar Rochette Date: Thu, 15 Jan 2026 15:39:43 +0000 Subject: [PATCH 5/8] fix: make typing tighter for artifact loading --- src/pruna/engine/load_artifacts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pruna/engine/load_artifacts.py b/src/pruna/engine/load_artifacts.py index de52e483..1d147e23 100644 --- a/src/pruna/engine/load_artifacts.py +++ b/src/pruna/engine/load_artifacts.py @@ -16,15 +16,15 @@ from enum import Enum from functools import partial from pathlib import Path -from typing import Any import torch from pruna.config.smash_config import SmashConfig +from pruna.engine.pruna_model import PrunaModel from pruna.logging.logger import pruna_logger -def load_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: +def load_artifacts(model: PrunaModel, model_path: str | Path, smash_config: SmashConfig) -> None: """ Load available artifacts. @@ -57,7 +57,7 @@ def load_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig LOAD_ARTIFACTS_FUNCTIONS[fn_name](model, model_path, smash_config) -def load_torch_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: +def load_torch_artifacts(model: PrunaModel, model_path: str | Path, smash_config: SmashConfig) -> None: """ Load a torch artifacts from the given model path. From 1ddd4400943389c753de02517026c7f528b63129 Mon Sep 17 00:00:00 2001 From: Gaspar Rochette Date: Wed, 21 Jan 2026 14:13:04 +0000 Subject: [PATCH 6/8] fix: circular import --- src/pruna/engine/load_artifacts.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pruna/engine/load_artifacts.py b/src/pruna/engine/load_artifacts.py index 1d147e23..6868493a 100644 --- a/src/pruna/engine/load_artifacts.py +++ b/src/pruna/engine/load_artifacts.py @@ -16,15 +16,15 @@ from enum import Enum from functools import partial from pathlib import Path +from typing import Any import torch from pruna.config.smash_config import SmashConfig -from pruna.engine.pruna_model import PrunaModel from pruna.logging.logger import pruna_logger -def load_artifacts(model: PrunaModel, model_path: str | Path, smash_config: SmashConfig) -> None: +def load_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: """ Load available artifacts. @@ -57,9 +57,9 @@ def load_artifacts(model: PrunaModel, model_path: str | Path, smash_config: Smas LOAD_ARTIFACTS_FUNCTIONS[fn_name](model, model_path, smash_config) -def load_torch_artifacts(model: PrunaModel, model_path: str | Path, smash_config: SmashConfig) -> None: +def load_torch_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: """ - Load a torch artifacts from the given model path. + Load torch artifacts from the given model path. Parameters ---------- From 5742ff2f0d81ed66e1f116ed919872636f5560fa Mon Sep 17 00:00:00 2001 From: Gaspar Rochette Date: Thu, 5 Feb 2026 12:50:55 +0000 Subject: [PATCH 7/8] docs: add deprecation warning to torch_artifact save/load and save fixed smash_config --- src/pruna/engine/load.py | 36 +++++++++++++++++++++++++++++++----- src/pruna/engine/save.py | 5 ----- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index aa363cc3..7679b09a 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -57,14 +57,11 @@ def load_pruna_model(model_path: str | Path, **kwargs) -> tuple[Any, SmashConfig """ smash_config = SmashConfig() smash_config.load_from_json(model_path) + if "torch_artifacts" in smash_config.load_fns: + _convert_to_artifact(smash_config, model_path) # since the model was just loaded from a file, we do not need to prepare saving anymore smash_config._prepare_saving = False - # Backward compatibility with torch artifacts save function - if "torch_artifacts" in smash_config.load_fns: - smash_config.load_fns.remove("torch_artifacts") - smash_config.load_artifacts_fns.append("torch_artifacts") - resmash_fn = kwargs.pop("resmash_fn", resmash) if len(smash_config.load_fns) == 0: @@ -85,6 +82,35 @@ def load_pruna_model(model_path: str | Path, **kwargs) -> tuple[Any, SmashConfig return model, smash_config +def _convert_to_artifact(smash_config: SmashConfig, model_path: str | Path) -> None: + """Convert legacy 'torch_artifacts' entries to the new artifact-based fields. + + This handles older configs that still store 'torch_artifacts' under + 'save_fns' or 'load_fns' instead of the corresponding '*_artifacts_fns' + fields. + """ + updated = False + + if "torch_artifacts" in smash_config.save_fns: + smash_config.save_fns.remove("torch_artifacts") + smash_config.save_artifacts_fns.append("torch_artifacts") + updated = True + + if "torch_artifacts" in smash_config.load_fns: + smash_config.load_fns.remove("torch_artifacts") + smash_config.load_artifacts_fns.append("torch_artifacts") + updated = True + + if updated: + pruna_logger.warning( + "The legacy 'torch_artifacts' save/load function entry in your SmashConfig is deprecated; " + "your config file has been updated automatically. If you downloaded this smashed model, " + "please ask the provider to update their model by loading it once with a recent version " + "of Pruna and re-uploading it." + ) + smash_config.save_to_json(model_path) + + def load_pruna_model_from_pretrained( repo_id: str, revision: Optional[str] = None, diff --git a/src/pruna/engine/save.py b/src/pruna/engine/save.py index 63643a37..10e40951 100644 --- a/src/pruna/engine/save.py +++ b/src/pruna/engine/save.py @@ -62,11 +62,6 @@ def save_pruna_model(model: Any, model_path: str | Path, smash_config: SmashConf if not model_path.exists(): model_path.mkdir(parents=True, exist_ok=True) - # Backward compatibility with torch artifacts save function - if "torch_artifacts" in smash_config.save_fns: - smash_config.save_fns.remove("torch_artifacts") - smash_config.save_artifacts_fns.append("torch_artifacts") - # in the case of no specialized save functions, we use the model's original save function if len(smash_config.save_fns) == 0: pruna_logger.debug("Using model's original save function...") From fc1d31d80f96dd9bab98e3f0a05f3d8327f68fc8 Mon Sep 17 00:00:00 2001 From: Gaspar Rochette Date: Thu, 5 Feb 2026 12:53:27 +0000 Subject: [PATCH 8/8] docs: fix enum docstring --- src/pruna/engine/load_artifacts.py | 3 +++ src/pruna/engine/save_artifacts.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/pruna/engine/load_artifacts.py b/src/pruna/engine/load_artifacts.py index 6868493a..f79cf5b9 100644 --- a/src/pruna/engine/load_artifacts.py +++ b/src/pruna/engine/load_artifacts.py @@ -105,6 +105,9 @@ class LOAD_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 The type of the enum. start : int The start index for auto-numbering enum values. + boundary : enum.FlagBoundary or None + Boundary handling mode used by the Enum functional API for Flag and + IntFlag enums. Examples -------- diff --git a/src/pruna/engine/save_artifacts.py b/src/pruna/engine/save_artifacts.py index df899d68..e9da3e08 100644 --- a/src/pruna/engine/save_artifacts.py +++ b/src/pruna/engine/save_artifacts.py @@ -110,6 +110,9 @@ class SAVE_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 The type of the enum. start : int The start index for auto-numbering enum values. + boundary : enum.FlagBoundary or None + Boundary handling mode used by the Enum functional API for Flag and + IntFlag enums. Examples --------