3 changes: 2 additions & 1 deletion src/pruna/algorithms/torch_compile/torch_compile.py
@@ -35,6 +35,7 @@
is_transformers_pipeline_with_causal_lm,
)
from pruna.engine.save import SAVE_FUNCTIONS
from pruna.engine.save_artifacts import SAVE_ARTIFACTS_FUNCTIONS
from pruna.logging.logger import pruna_logger

# This allows for torch compile to use more cache memory to compile the model
@@ -187,7 +188,7 @@ def apply(self, model: Any, smash_config: SmashConfig) -> Any:

# importantly, the torch artifacts saving needs to be done *after* the before-compile-save
if smash_config["torch_compile_make_portable"]:
smash_config.save_fns.append(SAVE_FUNCTIONS.torch_artifacts.name)
smash_config.save_artifacts_fns.append(SAVE_ARTIFACTS_FUNCTIONS.torch_artifacts.name)
return output

def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
6 changes: 6 additions & 0 deletions src/pruna/config/smash_config.py
@@ -40,7 +40,9 @@
"device_map",
"cache_dir",
"save_fns",
"save_artifacts_fns",
"load_fns",
"load_artifacts_fns",
"reapply_after_load",
]

@@ -86,6 +88,8 @@ def __init__(

self.save_fns: list[str] = []
self.load_fns: list[str] = []
self.save_artifacts_fns: list[str] = []
self.load_artifacts_fns: list[str] = []
self.reapply_after_load: dict[str, str | None] = {}
self.tokenizer: PreTrainedTokenizerBase | None = None
self.processor: ProcessorMixin | None = None
@@ -350,6 +354,8 @@ def flush_configuration(self) -> None:
# flush also saving / load functionality associated with a specific configuration
self.save_fns = []
self.load_fns = []
self.save_artifacts_fns = []
self.load_artifacts_fns = []
self.reapply_after_load = {}

# reset potentially previously used cache directory
58 changes: 35 additions & 23 deletions src/pruna/engine/load.py
@@ -30,6 +30,7 @@
from transformers import pipeline

from pruna import SmashConfig
from pruna.engine.load_artifacts import load_artifacts
from pruna.engine.utils import load_json_config, move_to_device, set_to_best_available_device
from pruna.logging.logger import pruna_logger

@@ -56,6 +57,8 @@ def load_pruna_model(model_path: str | Path, **kwargs) -> tuple[Any, SmashConfig
"""
smash_config = SmashConfig()
smash_config.load_from_json(model_path)
if "torch_artifacts" in smash_config.load_fns:
_convert_to_artifact(smash_config, model_path)
@gsprochette (Collaborator, Author) commented on Feb 5, 2026:

@simlang in response to your comment I removed the logic in save.py and added this function:

  • update both the save and load fns
  • save the smash config so that it's fixed for when we remove this logic in a future version after some deprecation time
  • log a warning saying it was updated and asking to warn any remote provider (because those would never receive the warning nor have the update automatically saved)

Let me know if that's to your liking
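
For illustration, a minimal usage sketch of this migration path, assuming the on-disk layout introduced in this PR (the model path is illustrative, and the legacy config is assumed to contain a primary load function alongside 'torch_artifacts'):

from pruna.engine.load import load_pruna_model

# Loading a legacy smashed model once triggers _convert_to_artifact: it moves
# 'torch_artifacts' from save_fns/load_fns into save_artifacts_fns/load_artifacts_fns,
# logs the deprecation warning, and re-saves the config JSON under the model path.
model, smash_config = load_pruna_model("/path/to/legacy_smashed_model")

print(smash_config.load_fns)            # no longer contains 'torch_artifacts'
print(smash_config.load_artifacts_fns)  # ['torch_artifacts']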

# since the model was just loaded from a file, we do not need to prepare saving anymore
smash_config._prepare_saving = False

@@ -64,11 +67,6 @@
if len(smash_config.load_fns) == 0:
raise ValueError("Load function has not been set.")

# load torch artifacts if they exist
if LOAD_FUNCTIONS.torch_artifacts.name in smash_config.load_fns:
load_torch_artifacts(model_path, **kwargs)
smash_config.load_fns.remove(LOAD_FUNCTIONS.torch_artifacts.name)

if len(smash_config.load_fns) > 1:
pruna_logger.error(f"Load functions not used: {smash_config.load_fns[1:]}")

@@ -78,9 +76,41 @@
if any(algorithm is not None for algorithm in smash_config.reapply_after_load.values()):
model = resmash_fn(model, smash_config)

# load artifacts (e.g. to speed up the warmup or make it more consistent)
load_artifacts(model, model_path, smash_config)

return model, smash_config


def _convert_to_artifact(smash_config: SmashConfig, model_path: str | Path) -> None:
"""Convert legacy 'torch_artifacts' entries to the new artifact-based fields.

This handles older configs that still store 'torch_artifacts' under
'save_fns' or 'load_fns' instead of the corresponding '*_artifacts_fns'
fields.
"""
updated = False

if "torch_artifacts" in smash_config.save_fns:
smash_config.save_fns.remove("torch_artifacts")
smash_config.save_artifacts_fns.append("torch_artifacts")
updated = True

if "torch_artifacts" in smash_config.load_fns:
smash_config.load_fns.remove("torch_artifacts")
smash_config.load_artifacts_fns.append("torch_artifacts")
updated = True

if updated:
pruna_logger.warning(
"The legacy 'torch_artifacts' save/load function entry in your SmashConfig is deprecated; "
"your config file has been updated automatically. If you downloaded this smashed model, "
"please ask the provider to update their model by loading it once with a recent version "
"of Pruna and re-uploading it."
)
smash_config.save_to_json(model_path)


def load_pruna_model_from_pretrained(
repo_id: str,
revision: Optional[str] = None,
@@ -440,23 +470,6 @@ def load_quantized_model(quantized_path: str | Path) -> Any:
)


def load_torch_artifacts(model_path: str | Path, **kwargs) -> None:
"""
Load a torch artifacts from the given model path.

Parameters
----------
model_path : str | Path
The path to the model directory.
**kwargs : Any
Additional keyword arguments to pass to the model loading function.
"""
artifact_path = Path(model_path) / "artifact_bytes.bin"
artifact_bytes = artifact_path.read_bytes()

torch.compiler.load_cache_artifacts(artifact_bytes)


def load_hqq_diffusers(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any:
"""
Load a diffusers model from the given model path.
@@ -580,7 +593,6 @@ class LOAD_FUNCTIONS(Enum): # noqa: N801
pickled = partial(load_pickled)
hqq = partial(load_hqq)
hqq_diffusers = partial(load_hqq_diffusers)
torch_artifacts = partial(load_torch_artifacts)

def __call__(self, *args, **kwargs) -> Any:
"""
123 changes: 123 additions & 0 deletions src/pruna/engine/load_artifacts.py
@@ -0,0 +1,123 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any

import torch

from pruna.config.smash_config import SmashConfig
from pruna.logging.logger import pruna_logger


def load_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None:
"""
Load available artifacts.

This function is intended to be called after the main model load function.
It loads artifacts specific to different algorithms into the pre-loaded model.

Parameters
----------
model : Any
The model to load the artifacts for.
model_path : str | Path
The directory to load the artifacts from.
smash_config : SmashConfig
The SmashConfig object containing the load and save functions.

Returns
-------
None
The function does not return anything.
"""
artifact_fns = getattr(smash_config, "load_artifacts_fns", [])
if not artifact_fns:
return

for fn_name in artifact_fns:
# Only handle artifact loaders we explicitly know about here.
if fn_name not in LOAD_ARTIFACTS_FUNCTIONS.__members__:
continue

LOAD_ARTIFACTS_FUNCTIONS[fn_name](model, model_path, smash_config)


def load_torch_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None:
"""
Load torch artifacts from the given model path.

Parameters
----------
model : Any
The model to load the artifacts for.
model_path : str | Path
The directory to load the artifacts from.
smash_config : SmashConfig
The SmashConfig object containing the load and save functions.
"""
artifact_path = Path(model_path) / "artifact_bytes.bin"
if not artifact_path.exists():
pruna_logger.error(f"No torch artifacts found at {artifact_path}; skipping torch artifact loading.")
return

pruna_logger.info(f"Loading torch artifacts from {artifact_path}")
artifact_bytes = artifact_path.read_bytes()

torch.compiler.load_cache_artifacts(artifact_bytes)


class LOAD_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801
"""
Enumeration of *artifact* load functions.

Artifact loaders are functions that are called after the main model load
has completed. They attach additional runtime state to the already-loaded
model (e.g. compilation cache).

This enum provides callable functions for loading such artifacts.

Parameters
----------
value : callable
The artifact load function to be called.
names : str
The name of the enum member.
module : str
The module where the enum is defined.
qualname : str
The qualified name of the enum.
type : type
The type of the enum.
start : int
The start index for auto-numbering enum values.
boundary : enum.FlagBoundary or None
Boundary handling mode used by the Enum functional API for Flag and
IntFlag enums.

Examples
--------
>>> LOAD_ARTIFACTS_FUNCTIONS.torch_artifacts(model, model_path, smash_config)
# Torch artifacts loaded into the current runtime
"""

torch_artifacts = partial(load_torch_artifacts)

def __call__(self, *args, **kwargs) -> None:
"""Call the underlying load function."""
if self.value is not None:
self.value(*args, **kwargs)
40 changes: 4 additions & 36 deletions src/pruna/engine/save.py
@@ -36,6 +36,7 @@
SAVE_BEFORE_SMASH_CACHE_DIR,
)
from pruna.engine.model_checks import get_helpers, is_janus_llamagen_ar
from pruna.engine.save_artifacts import save_artifacts
from pruna.engine.utils import determine_dtype, monkeypatch
from pruna.logging.logger import pruna_logger

@@ -61,10 +62,6 @@ def save_pruna_model(model: Any, model_path: str | Path, smash_config: SmashConf
if not model_path.exists():
model_path.mkdir(parents=True, exist_ok=True)

if SAVE_FUNCTIONS.torch_artifacts.name in smash_config.save_fns:
save_torch_artifacts(model, model_path, smash_config)
smash_config.save_fns.remove(SAVE_FUNCTIONS.torch_artifacts.name)

# in the case of no specialized save functions, we use the model's original save function
if len(smash_config.save_fns) == 0:
pruna_logger.debug("Using model's original save function...")
@@ -87,11 +84,13 @@
else:
pruna_logger.debug(f"Several save functions stacked: {smash_config.save_fns}, defaulting to pickled")
save_fn = SAVE_FUNCTIONS.pickled
smash_config.load_fns = [LOAD_FUNCTIONS.pickled.name]

# execute selected save function
save_fn(model, model_path, smash_config)

# save artifacts as well
save_artifacts(model, model_path, smash_config)

# save smash config (includes tokenizer and processor)
smash_config.save_to_json(model_path)

@@ -460,36 +459,6 @@ def save_component(attr_name: str | None, module: torch.nn.Module, subpaths: lis
smash_config.load_fns.append(LOAD_FUNCTIONS.hqq_diffusers.name)


def save_torch_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None:
"""
Save the model by saving the torch artifacts.

Parameters
----------
model : Any
The model to save.
model_path : str | Path
The directory to save the model to.
smash_config : SmashConfig
The SmashConfig object containing the save and load functions.
"""
artifacts = torch.compiler.save_cache_artifacts()

assert artifacts is not None
artifact_bytes, _ = artifacts

# check if the bytes are empty
if artifact_bytes == b"\x00\x00\x00\x00\x00\x00\x00\x01":
pruna_logger.error(
"Model has not been run before. Please run the model before saving to construct the compilation graph."
)

artifact_path = Path(model_path) / "artifact_bytes.bin"
artifact_path.write_bytes(artifact_bytes)

smash_config.load_fns.append(LOAD_FUNCTIONS.torch_artifacts.name)


def reapply(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None:
"""
Reapply the model.
@@ -543,7 +512,6 @@ class SAVE_FUNCTIONS(Enum): # noqa: N801
hqq_diffusers = partial(save_model_hqq_diffusers)
save_before_apply = partial(save_before_apply)
reapply = partial(reapply)
torch_artifacts = partial(save_torch_artifacts)

def __call__(self, *args, **kwargs) -> None:
"""