Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions src/opengradient/client/_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import logging
from decimal import Decimal
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union

import numpy as np
from web3.datastructures import AttributeDict
Expand Down Expand Up @@ -36,11 +36,34 @@ def convert_to_fixed_point(number: float) -> Tuple[int, int]:
return value, decimals


def convert_fixed_point_to_python(value: int, decimals: int) -> Union[int, np.float32]:
    """
    Convert an on-chain fixed-point number back to a native Python/NumPy value.

    Integer semantics are preserved: when ``decimals`` is 0 the stored
    significand comes back as a plain ``int`` (fixes issue #103, where callers
    expecting int results received np.float32 and had to cast manually).
    Every other scale yields an ``np.float32``.

    Args:
        value: The integer significand stored on-chain.
        decimals: The scale factor exponent (value / 10**decimals).

    Returns:
        int if decimals == 0, np.float32 otherwise.
    """
    if decimals == 0:
        return int(value)
    # Do the division in Decimal to avoid float rounding before the final cast.
    scale = 10 ** Decimal(decimals)
    return np.float32(Decimal(value) / scale)


def convert_to_float32(value: int, decimals: int) -> np.float32:
    """
    Convert fixed point back into floating point.

    Deprecated: use convert_fixed_point_to_python() instead.

    Kept for backwards compatibility — always returns np.float32 regardless
    of the decimals value. New callers should use convert_fixed_point_to_python,
    which correctly returns int when decimals == 0.

    Args:
        value: The integer significand stored on-chain.
        decimals: The scale factor exponent (value / 10**decimals).

    Returns:
        np.float32 value of ``value / 10**decimals``.
    """
    # Emit a real DeprecationWarning so tooling (pytest, -W error) surfaces the
    # deprecation instead of it living only in the docstring. Local import keeps
    # the module's top-level import block unchanged.
    import warnings

    warnings.warn(
        "convert_to_float32 is deprecated; use convert_fixed_point_to_python instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return np.float32(Decimal(value) / (10 ** Decimal(decimals)))

Expand All @@ -61,11 +84,11 @@ def convert_to_model_input(inputs: Dict[str, np.ndarray]) -> Tuple[List[Tuple[st
for tensor_name, tensor_data in inputs.items():
# Convert to NP array if list or single object
if isinstance(tensor_data, list):
logging.debug(f"\tConverting {tensor_data} to np array")
logging.debug(f" Converting {tensor_data} to np array")
tensor_data = np.array(tensor_data)

if isinstance(tensor_data, (str, int, float)):
logging.debug(f"\tConverting single entry {tensor_data} to a list")
logging.debug(f" Converting single entry {tensor_data} to a list")
tensor_data = np.array([tensor_data])

# Check if type is np array
Expand All @@ -84,7 +107,7 @@ def convert_to_model_input(inputs: Dict[str, np.ndarray]) -> Tuple[List[Tuple[st
converted_tensor_data = np.array([convert_to_fixed_point(i) for i in flat_data], dtype=data_type)

input = (tensor_name, converted_tensor_data.tolist(), shape)
logging.debug("\tFloating tensor input: %s", input)
logging.debug(" Floating tensor input: %s", input)

number_tensors.append(input)
elif issubclass(tensor_data.dtype.type, np.integer):
Expand All @@ -93,13 +116,13 @@ def convert_to_model_input(inputs: Dict[str, np.ndarray]) -> Tuple[List[Tuple[st
converted_tensor_data = np.array([convert_to_fixed_point(int(i)) for i in flat_data], dtype=data_type)

input = (tensor_name, converted_tensor_data.tolist(), shape)
logging.debug("\tInteger tensor input: %s", input)
logging.debug(" Integer tensor input: %s", input)

number_tensors.append(input)
elif issubclass(tensor_data.dtype.type, np.str_):
# TODO (Kyle): Add shape into here as well
input = (tensor_name, [s for s in flat_data])
logging.debug("\tString tensor input: %s", input)
logging.debug(" String tensor input: %s", input)

string_tensors.append(input)
else:
Expand Down Expand Up @@ -131,10 +154,11 @@ def convert_to_model_output(event_data: AttributeDict) -> Dict[str, np.ndarray]:
name = tensor.get("name")
shape = tensor.get("shape")
values = []
# Convert from fixed point back into np.float32
# Use convert_fixed_point_to_python so integer tensors (decimals==0)
# come back as int instead of np.float32 (fixes issue #103).
for v in tensor.get("values", []):
if isinstance(v, (AttributeDict, dict)):
values.append(convert_to_float32(value=int(v.get("value")), decimals=int(v.get("decimals"))))
values.append(convert_fixed_point_to_python(value=int(v.get("value")), decimals=int(v.get("decimals"))))
else:
logging.warning(f"Unexpected number type: {type(v)}")
output_dict[name] = np.array(values).reshape(shape)
Expand Down Expand Up @@ -183,10 +207,11 @@ def convert_array_to_model_output(array_data: List) -> ModelOutput:
values = tensor[1]
shape = tensor[2]

# Convert from fixed point into np.float32
# Use convert_fixed_point_to_python so integer tensors (decimals==0)
# come back as int instead of np.float32 (fixes issue #103).
converted_values = []
for value in values:
converted_values.append(convert_to_float32(value=value[0], decimals=value[1]))
converted_values.append(convert_fixed_point_to_python(value=value[0], decimals=value[1]))

number_data[name] = np.array(converted_values).reshape(shape)

Expand Down
27 changes: 23 additions & 4 deletions src/opengradient/client/alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,19 @@ def execute_transaction():
model_output = convert_to_model_output(parsed_logs[0]["args"])
if len(model_output) == 0:
# check inference directly from node
parsed_logs = precompile_contract.events.ModelInferenceEvent().process_receipt(tx_receipt, errors=DISCARD)
inference_id = parsed_logs[0]["args"]["inferenceID"]
precompile_logs = precompile_contract.events.ModelInferenceEvent().process_receipt(tx_receipt, errors=DISCARD)
if not precompile_logs:
raise RuntimeError(
"ModelInferenceEvent not found in transaction logs. "
"Cannot fall back to node-side inference result."
)
inference_id = precompile_logs[0]["args"]["inferenceID"]
inference_result = self._get_inference_result_from_node(inference_id, inference_mode)
if inference_result is None:
raise RuntimeError(
f"Inference node returned no result for inference ID {inference_id!r}. "
"The result may not be available yet — retry after a short delay."
)
model_output = convert_to_model_output(inference_result)

return InferenceResult(tx_hash.hex(), model_output)
Expand Down Expand Up @@ -315,7 +325,7 @@ def deploy_transaction():
signed_txn = self._wallet_account.sign_transaction(transaction)
tx_hash = self._blockchain.eth.send_raw_transaction(signed_txn.raw_transaction)

tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=60)
tx_receipt = self._blockchain.eth.wait_for_transaction_receipt(tx_hash, timeout=INFERENCE_TX_TIMEOUT)

if tx_receipt["status"] == 0:
raise Exception(f"Contract deployment failed, transaction hash: {tx_hash.hex()}")
Expand Down Expand Up @@ -419,11 +429,20 @@ def run_workflow(self, contract_address: str) -> ModelOutput:
nonce = self._blockchain.eth.get_transaction_count(self._wallet_account.address, "pending")

run_function = contract.functions.run()

# Estimate gas instead of using a hardcoded 30M limit, which is wasteful
# and may exceed the block gas limit on some networks.
try:
estimated_gas = run_function.estimate_gas({"from": self._wallet_account.address})
gas_limit = int(estimated_gas * 3)
except Exception:
gas_limit = 30000000 # Conservative fallback if estimation fails

transaction = run_function.build_transaction(
{
"from": self._wallet_account.address,
"nonce": nonce,
"gas": 30000000,
"gas": gas_limit,
"gasPrice": self._blockchain.eth.gas_price,
"chainId": self._blockchain.eth.chain_id,
}
Expand Down
29 changes: 24 additions & 5 deletions src/opengradient/client/llm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""LLM chat and completion via TEE-verified execution with x402 payments."""

import json
import json as _json
import logging
import ssl
from dataclasses import dataclass
from typing import AsyncGenerator, Dict, List, Optional, Union

import httpx
from eth_account import Account
from eth_account.account import LocalAccount
from x402 import x402Client
Expand All @@ -31,6 +32,12 @@
_COMPLETION_ENDPOINT = "/v1/completions"
_REQUEST_TIMEOUT = 60

# Actionable guidance appended when the TEE LLM service answers HTTP 402,
# steering users toward the OPG/Permit2 approval step instead of leaving
# them with a bare status-code error.
_402_HINT = (
    "Payment required (HTTP 402): your wallet may have insufficient OPG token allowance. "
    "Call llm.ensure_opg_approval(opg_amount=<amount>) to approve Permit2 spending "
    "before making requests. Minimum amount is 0.05 OPG."
)


@dataclass
class _ChatParams:
Expand Down Expand Up @@ -267,7 +274,7 @@ async def completion(
)
response.raise_for_status()
raw_body = await response.aread()
result = json.loads(raw_body.decode())
result = _json.loads(raw_body.decode())
return TextGenerationOutput(
transaction_hash="external",
completion_output=result.get("completion"),
Expand All @@ -277,6 +284,10 @@ async def completion(
)
except RuntimeError:
raise
except httpx.HTTPStatusError as e:
if e.response.status_code == 402:
raise RuntimeError(_402_HINT) from e
raise RuntimeError(f"TEE LLM completion failed: {e}") from e
except Exception as e:
raise RuntimeError(f"TEE LLM completion failed: {e}") from e

Expand Down Expand Up @@ -354,7 +365,7 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text
)
response.raise_for_status()
raw_body = await response.aread()
result = json.loads(raw_body.decode())
result = _json.loads(raw_body.decode())

choices = result.get("choices")
if not choices:
Expand All @@ -377,6 +388,12 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text
)
except RuntimeError:
raise
except httpx.HTTPStatusError as e:
# Provide an actionable error message for the very common 402 case
# (issue #188 — users see a cryptic RuntimeError instead of guidance).
if e.response.status_code == 402:
raise RuntimeError(_402_HINT) from e
raise RuntimeError(f"TEE LLM chat failed: {e}") from e
except Exception as e:
raise RuntimeError(f"TEE LLM chat failed: {e}") from e

Expand Down Expand Up @@ -425,6 +442,8 @@ async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, Non
status_code = getattr(response, "status_code", None)
if status_code is not None and status_code >= 400:
body = await response.aread()
if status_code == 402:
raise RuntimeError(_402_HINT)
raise RuntimeError(f"TEE LLM streaming request failed with status {status_code}: {body.decode('utf-8', errors='replace')}")

buffer = b""
Expand Down Expand Up @@ -452,8 +471,8 @@ async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, Non
return

try:
data = json.loads(data_str)
except json.JSONDecodeError:
data = _json.loads(data_str)
except _json.JSONDecodeError:
continue

chunk = StreamChunk.from_sse_data(data)
Expand Down
22 changes: 19 additions & 3 deletions src/opengradient/client/tee_registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""TEE Registry client for fetching verified TEE endpoints and TLS certificates."""

import logging
import random
import ssl
from dataclasses import dataclass
from typing import List, NamedTuple, Optional
Expand Down Expand Up @@ -109,17 +110,32 @@ def get_active_tees_by_type(self, tee_type: int) -> List[TEEEndpoint]:

def get_llm_tee(self) -> Optional[TEEEndpoint]:
    """
    Return a randomly selected active LLM proxy TEE from the registry.

    Randomizing the selection distributes load across all healthy TEEs and
    avoids repeatedly routing to the same TEE when it starts failing
    (addresses issue #200 — improve TEE selection/retry logic).

    Returns:
        TEEEndpoint for a randomly chosen active LLM proxy TEE, or None if
        none are available.
    """
    tees = self.get_active_tees_by_type(TEE_TYPE_LLM_PROXY)
    if not tees:
        logger.warning("No active LLM proxy TEEs found in registry")
        return None

    # Randomly select from all active TEEs to distribute load and improve
    # resilience — if one TEE is failing, successive LLM() constructions
    # will eventually land on a healthy one. (The stale unconditional
    # `return tees[0]` from the pre-change version is removed; it would
    # have made this selection unreachable.)
    selected = random.choice(tees)
    logger.debug(
        "Selected TEE %s (endpoint=%s) from %d active LLM proxy TEE(s)",
        selected.tee_id,
        selected.endpoint,
        len(tees),
    )
    return selected


def build_ssl_context_from_der(der_cert: bytes) -> ssl.SSLContext:
Expand Down
Loading