diff --git a/src/pruna/algorithms/torch_compile/torch_compile.py b/src/pruna/algorithms/torch_compile/torch_compile.py index 36fa15a3..6149be7a 100644 --- a/src/pruna/algorithms/torch_compile/torch_compile.py +++ b/src/pruna/algorithms/torch_compile/torch_compile.py @@ -35,6 +35,7 @@ is_transformers_pipeline_with_causal_lm, ) from pruna.engine.save import SAVE_FUNCTIONS +from pruna.engine.save_artifacts import SAVE_ARTIFACTS_FUNCTIONS from pruna.logging.logger import pruna_logger # This allows for torch compile to use more cache memory to compile the model @@ -187,7 +188,7 @@ def apply(self, model: Any, smash_config: SmashConfig) -> Any: # importantly, the torch artifacts saving need to be done *after* the before-compile-save if smash_config["torch_compile_make_portable"]: - smash_config.save_fns.append(SAVE_FUNCTIONS.torch_artifacts.name) + smash_config.save_artifacts_fns.append(SAVE_ARTIFACTS_FUNCTIONS.torch_artifacts.name) return output def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index ba277633..8b0f640d 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -40,7 +40,9 @@ "device_map", "cache_dir", "save_fns", + "save_artifacts_fns", "load_fns", + "load_artifacts_fns", "reapply_after_load", ] @@ -86,6 +88,8 @@ def __init__( self.save_fns: list[str] = [] self.load_fns: list[str] = [] + self.save_artifacts_fns: list[str] = [] + self.load_artifacts_fns: list[str] = [] self.reapply_after_load: dict[str, str | None] = {} self.tokenizer: PreTrainedTokenizerBase | None = None self.processor: ProcessorMixin | None = None @@ -350,6 +354,8 @@ def flush_configuration(self) -> None: # flush also saving / load functionality associated with a specific configuration self.save_fns = [] self.load_fns = [] + self.save_artifacts_fns = [] + self.load_artifacts_fns = [] self.reapply_after_load = {} # reset potentially previously used 
cache directory diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index b310b593..7679b09a 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -30,6 +30,7 @@ from transformers import pipeline from pruna import SmashConfig +from pruna.engine.load_artifacts import load_artifacts from pruna.engine.utils import load_json_config, move_to_device, set_to_best_available_device from pruna.logging.logger import pruna_logger @@ -56,6 +57,8 @@ def load_pruna_model(model_path: str | Path, **kwargs) -> tuple[Any, SmashConfig """ smash_config = SmashConfig() smash_config.load_from_json(model_path) + if "torch_artifacts" in smash_config.load_fns: + _convert_to_artifact(smash_config, model_path) # since the model was just loaded from a file, we do not need to prepare saving anymore smash_config._prepare_saving = False @@ -64,11 +67,6 @@ def load_pruna_model(model_path: str | Path, **kwargs) -> tuple[Any, SmashConfig if len(smash_config.load_fns) == 0: raise ValueError("Load function has not been set.") - # load torch artifacts if they exist - if LOAD_FUNCTIONS.torch_artifacts.name in smash_config.load_fns: - load_torch_artifacts(model_path, **kwargs) - smash_config.load_fns.remove(LOAD_FUNCTIONS.torch_artifacts.name) - if len(smash_config.load_fns) > 1: pruna_logger.error(f"Load functions not used: {smash_config.load_fns[1:]}") @@ -78,9 +76,41 @@ def load_pruna_model(model_path: str | Path, **kwargs) -> tuple[Any, SmashConfig if any(algorithm is not None for algorithm in smash_config.reapply_after_load.values()): model = resmash_fn(model, smash_config) + # load artifacts (e.g. speed up the warmup or make it more consistent) + load_artifacts(model, model_path, smash_config) + return model, smash_config +def _convert_to_artifact(smash_config: SmashConfig, model_path: str | Path) -> None: + """Convert legacy 'torch_artifacts' entries to the new artifact-based fields. 
+ + This handles older configs that still store 'torch_artifacts' under + 'save_fns' or 'load_fns' instead of the corresponding '*_artifacts_fns' + fields. + """ + updated = False + + if "torch_artifacts" in smash_config.save_fns: + smash_config.save_fns.remove("torch_artifacts") + smash_config.save_artifacts_fns.append("torch_artifacts") + updated = True + + if "torch_artifacts" in smash_config.load_fns: + smash_config.load_fns.remove("torch_artifacts") + smash_config.load_artifacts_fns.append("torch_artifacts") + updated = True + + if updated: + pruna_logger.warning( + "The legacy 'torch_artifacts' save/load function entry in your SmashConfig is deprecated; " + "your config file has been updated automatically. If you downloaded this smashed model, " + "please ask the provider to update their model by loading it once with a recent version " + "of Pruna and re-uploading it." + ) + smash_config.save_to_json(model_path) + + def load_pruna_model_from_pretrained( repo_id: str, revision: Optional[str] = None, @@ -440,23 +470,6 @@ def load_quantized_model(quantized_path: str | Path) -> Any: ) -def load_torch_artifacts(model_path: str | Path, **kwargs) -> None: - """ - Load a torch artifacts from the given model path. - - Parameters - ---------- - model_path : str | Path - The path to the model directory. - **kwargs : Any - Additional keyword arguments to pass to the model loading function. - """ - artifact_path = Path(model_path) / "artifact_bytes.bin" - artifact_bytes = artifact_path.read_bytes() - - torch.compiler.load_cache_artifacts(artifact_bytes) - - def load_hqq_diffusers(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: """ Load a diffusers model from the given model path. 
@@ -580,7 +593,6 @@ class LOAD_FUNCTIONS(Enum): # noqa: N801 pickled = partial(load_pickled) hqq = partial(load_hqq) hqq_diffusers = partial(load_hqq_diffusers) - torch_artifacts = partial(load_torch_artifacts) def __call__(self, *args, **kwargs) -> Any: """ diff --git a/src/pruna/engine/load_artifacts.py b/src/pruna/engine/load_artifacts.py new file mode 100644 index 00000000..f79cf5b9 --- /dev/null +++ b/src/pruna/engine/load_artifacts.py @@ -0,0 +1,123 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from enum import Enum +from functools import partial +from pathlib import Path +from typing import Any + +import torch + +from pruna.config.smash_config import SmashConfig +from pruna.logging.logger import pruna_logger + + +def load_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: + """ + Load available artifacts. + + This function is intended to be called after the main model load function. + It loads artifacts specific to different algorithms into the pre-loaded model. + + Parameters + ---------- + model : Any + The model to load the artifacts for. + model_path : str | Path + The directory to load the artifacts from. + smash_config : SmashConfig + The SmashConfig object containing the load and save functions. + + Returns + ------- + None + The function does not return anything. 
+ """ + artifact_fns = getattr(smash_config, "load_artifacts_fns", []) + if not artifact_fns: + return + + for fn_name in artifact_fns: + # Only handle artifact loaders we explicitly know about here. + if fn_name not in LOAD_ARTIFACTS_FUNCTIONS.__members__: + continue + + LOAD_ARTIFACTS_FUNCTIONS[fn_name](model, model_path, smash_config) + + +def load_torch_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: + """ + Load torch artifacts from the given model path. + + Parameters + ---------- + model : Any + The model to load the artifacts for. + model_path : str | Path + The directory to load the artifacts from. + smash_config : SmashConfig + The SmashConfig object containing the load and save functions. + """ + artifact_path = Path(model_path) / "artifact_bytes.bin" + if not artifact_path.exists(): + pruna_logger.error(f"No torch artifacts found at {artifact_path}; skipping torch artifact loading.") + return + + pruna_logger.info(f"Loading torch artifacts from {artifact_path}") + artifact_bytes = artifact_path.read_bytes() + + torch.compiler.load_cache_artifacts(artifact_bytes) + + +class LOAD_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 + """ + Enumeration of *artifact* load functions. + + Artifact loaders are functions that are called after the main model load + has completed. They attach additional runtime state to the already-loaded + model (e.g. compilation cache). + + This enum provides callable functions for loading such artifacts. + + Parameters + ---------- + value : callable + The artifact load function to be called. + names : str + The name of the enum member. + module : str + The module where the enum is defined. + qualname : str + The qualified name of the enum. + type : type + The type of the enum. + start : int + The start index for auto-numbering enum values. + boundary : enum.FlagBoundary or None + Boundary handling mode used by the Enum functional API for Flag and + IntFlag enums. 
+ + Examples + -------- + >>> LOAD_ARTIFACTS_FUNCTIONS.torch_artifacts(model, model_path, smash_config) + # Torch artifacts loaded into the current runtime + """ + + torch_artifacts = partial(load_torch_artifacts) + + def __call__(self, *args, **kwargs) -> None: + """Call the underlying load function.""" + if self.value is not None: + self.value(*args, **kwargs) diff --git a/src/pruna/engine/save.py b/src/pruna/engine/save.py index 0c741017..10e40951 100644 --- a/src/pruna/engine/save.py +++ b/src/pruna/engine/save.py @@ -36,6 +36,7 @@ SAVE_BEFORE_SMASH_CACHE_DIR, ) from pruna.engine.model_checks import get_helpers, is_janus_llamagen_ar +from pruna.engine.save_artifacts import save_artifacts from pruna.engine.utils import determine_dtype, monkeypatch from pruna.logging.logger import pruna_logger @@ -61,10 +62,6 @@ def save_pruna_model(model: Any, model_path: str | Path, smash_config: SmashConf if not model_path.exists(): model_path.mkdir(parents=True, exist_ok=True) - if SAVE_FUNCTIONS.torch_artifacts.name in smash_config.save_fns: - save_torch_artifacts(model, model_path, smash_config) - smash_config.save_fns.remove(SAVE_FUNCTIONS.torch_artifacts.name) - # in the case of no specialized save functions, we use the model's original save function if len(smash_config.save_fns) == 0: pruna_logger.debug("Using model's original save function...") @@ -87,11 +84,13 @@ def save_pruna_model(model: Any, model_path: str | Path, smash_config: SmashConf else: pruna_logger.debug(f"Several save functions stacked: {smash_config.save_fns}, defaulting to pickled") save_fn = SAVE_FUNCTIONS.pickled - smash_config.load_fns = [LOAD_FUNCTIONS.pickled.name] # execute selected save function save_fn(model, model_path, smash_config) + # save artifacts as well + save_artifacts(model, model_path, smash_config) + # save smash config (includes tokenizer and processor) smash_config.save_to_json(model_path) @@ -460,36 +459,6 @@ def save_component(attr_name: str | None, module: torch.nn.Module, 
subpaths: lis smash_config.load_fns.append(LOAD_FUNCTIONS.hqq_diffusers.name) -def save_torch_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: - """ - Save the model by saving the torch artifacts. - - Parameters - ---------- - model : Any - The model to save. - model_path : str | Path - The directory to save the model to. - smash_config : SmashConfig - The SmashConfig object containing the save and load functions. - """ - artifacts = torch.compiler.save_cache_artifacts() - - assert artifacts is not None - artifact_bytes, _ = artifacts - - # check if the bytes are empty - if artifact_bytes == b"\x00\x00\x00\x00\x00\x00\x00\x01": - pruna_logger.error( - "Model has not been run before. Please run the model before saving to construct the compilation graph." - ) - - artifact_path = Path(model_path) / "artifact_bytes.bin" - artifact_path.write_bytes(artifact_bytes) - - smash_config.load_fns.append(LOAD_FUNCTIONS.torch_artifacts.name) - - def reapply(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: """ Reapply the model. @@ -543,7 +512,6 @@ class SAVE_FUNCTIONS(Enum): # noqa: N801 hqq_diffusers = partial(save_model_hqq_diffusers) save_before_apply = partial(save_before_apply) reapply = partial(reapply) - torch_artifacts = partial(save_torch_artifacts) def __call__(self, *args, **kwargs) -> None: """ diff --git a/src/pruna/engine/save_artifacts.py b/src/pruna/engine/save_artifacts.py new file mode 100644 index 00000000..e9da3e08 --- /dev/null +++ b/src/pruna/engine/save_artifacts.py @@ -0,0 +1,137 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum +from functools import partial +from pathlib import Path +from typing import Any + +import torch + +from pruna.config.smash_config import SmashConfig +from pruna.engine.load_artifacts import LOAD_ARTIFACTS_FUNCTIONS +from pruna.logging.logger import pruna_logger + + +def save_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: + """ + Save all configured artifacts for a model. + + This function is intended to be called *after* the main model save function + (e.g. :func:`save_pruna_model`). It iterates over + ``smash_config.save_artifacts_fns`` and invokes each corresponding + :class:`SAVE_ARTIFACTS_FUNCTIONS` member. Each artifact saver is independent + and is responsible for appending its own load function(s) to + ``smash_config.load_artifacts_fns`` as needed. + + Parameters + ---------- + model : Any + The model to save artifacts for. + model_path : str | Path + The directory where the model and its artifacts are saved. + smash_config : SmashConfig + The SmashConfig object containing the artifact save function names in + ``save_artifacts_fns``. 
+ """ + smash_config.load_artifacts_fns.clear() # accumulate as we run the save artifact functions + + artifact_fns = getattr(smash_config, "save_artifacts_fns", []) + for fn_name in artifact_fns: + try: + SAVE_ARTIFACTS_FUNCTIONS[fn_name](model, model_path, smash_config) + except KeyError: + pruna_logger.error( + "Unknown artifact save function '%s' in smash_config.save_artifacts_fns; skipping.", fn_name + ) + + +def save_torch_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: + """ + Save the model by saving the torch artifacts. + + Parameters + ---------- + model : Any + The model to save. + model_path : str | Path + The directory to save the model to. + smash_config : SmashConfig + The SmashConfig object containing the save and load functions. + """ + artifacts = torch.compiler.save_cache_artifacts() + + assert artifacts is not None + artifact_bytes, _ = artifacts + + # check if the bytes are empty + if artifact_bytes == b"\x00\x00\x00\x00\x00\x00\x00\x01": + pruna_logger.error( + "Model has not been run before. Please run the model before saving to construct the compilation graph." + ) + + artifact_path = Path(model_path) / "artifact_bytes.bin" + artifact_path.write_bytes(artifact_bytes) + + smash_config.load_artifacts_fns.append(LOAD_ARTIFACTS_FUNCTIONS.torch_artifacts.name) + + +class SAVE_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 + """ + Enumeration of *artifact* save functions. + + Artifact savers are called after the main model save function has run. + They produce additional artifacts (e.g. compilation caches) to speed up + warmup or make the inference before and after loading consistent. + + This enum provides callable functions for saving such artifacts. + + Parameters + ---------- + value : callable + The artifact save function to be called. + names : str + The name of the enum member. + module : str + The module where the enum is defined. + qualname : str + The qualified name of the enum. 
+ type : type + The type of the enum. + start : int + The start index for auto-numbering enum values. + boundary : enum.FlagBoundary or None + Boundary handling mode used by the Enum functional API for Flag and + IntFlag enums. + + Examples + -------- + >>> SAVE_ARTIFACTS_FUNCTIONS.torch_artifacts(model, save_path, smash_config) + # Torch artifacts saved alongside the main model + """ + + torch_artifacts = partial(save_torch_artifacts) + + def __call__(self, *args, **kwargs) -> None: + """ + Call the underlying save function. + + Parameters + ---------- + args : Any + The arguments to pass to the save function. + kwargs : Any + The keyword arguments to pass to the save function. + """ + if self.value is not None: + self.value(*args, **kwargs)