Custom/gradients dispatch #632

Open

nfeybesse wants to merge 11 commits into tensorflow:master from nfeybesse:custom/gradients-dispatch

Conversation

@nfeybesse
Contributor

  1. Initial Problem

You wanted to register multiple custom gradients in Java using
TensorFlow.registerCustomGradient(...).

Observed symptom:

After registering a few gradients (≈ 5–10),

TFJ_RegisterCustomGradient(opType, adapter) received adapter_ptr = 0 on the C++ side,

which resulted in:

either a refusal to register the gradient,

or a SIGSEGV later during backpropagation.

Key observation:

If the “important” gradient was registered first, it worked.

Subsequent ones failed → this was a cumulative issue, not related to the specific op.

  2. Actual Root Cause

It was not:

a JNI signature bug,

an InfoMap issue,

nor a casting or ABI problem.

👉 The real cause was a limitation in JavaCPP FunctionPointer callbacks:

each TFJ_GradFuncAdapter allocates a native thunk,

after a certain number of such allocations, JavaCPP silently passes a null pointer (0),

the TensorFlow C++ runtime then receives an invalid callback pointer.

👉 Conclusion:
Creating one native callback per gradient is not scalable.

  3. Principle of the Definitive Fix

Instead of:

1 gradient = 1 native callback

We switched to:

1 single native callback

with dispatching in Java based on opType

This mirrors how TensorFlow itself handles Python gradient dispatch on the C++ side.

  4. Final Architecture
    A. A Single Native Callback (Singleton)

A single TFJ_GradFuncAdapter instance

Registered with TensorFlow C++ for all ops

As a result:

no more adapter_ptr = 0

no practical limit on the number of custom gradients

B. Java-side Dispatch by opType

A Java dispatcher selects the correct gradient during backpropagation:

TensorFlow C++

CustomGradFunc (C++)

TFJ_GradFuncAdapter.call(...)

DispatchingGradientAdapter.apply(...)

CustomGradient / RawCustomGradient for the corresponding op
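
The dispatch chain above can be sketched as a minimal, self-contained model. This is not the PR's actual API — `GradientDispatcher` and the `Function<List<Double>, List<Double>>` gradient signature are illustrative stand-ins for `DispatchingGradientAdapter` and the real gradient types — but it shows the core idea: one entry point, per-`opType` routing in Java.

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

// Illustrative sketch of the single-callback dispatch pattern.
// A gradient is modeled as a function from upstream gradients to input gradients.
final class GradientDispatcher {
    private final Map<String, Function<List<Double>, List<Double>>> registry =
            new ConcurrentHashMap<>();

    void register(String opType, Function<List<Double>, List<Double>> grad) {
        if (registry.putIfAbsent(opType, grad) != null) {
            throw new IllegalArgumentException("gradient already registered for " + opType);
        }
    }

    // The single native callback would funnel every op through this one method,
    // selecting the Java gradient by opType during backpropagation.
    List<Double> apply(String opType, List<Double> upstream) {
        Function<List<Double>, List<Double>> grad = registry.get(opType);
        if (grad == null) {
            throw new IllegalStateException("no gradient registered for " + opType);
        }
        return grad.apply(upstream);
    }
}
```

Because the registry is an ordinary Java map, there is only ever one native function pointer, regardless of how many gradients are registered.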

  5. Proper Handling of Visibility Constraints
    Problem

NativeScope and Ops have package-private constructors

They are only accessible from org.tensorflow.op

Solution

DispatchingGradientAdapter is package-private and lives in org.tensorflow.op

A public GradientDispatch class acts as a bridge

TensorFlow.java only sees the public TFJ_GradFuncAdapter type

➡️ This strictly respects TensorFlow Java’s internal design, with no hacks.

  6. Correct Support for “NoGradient”
    Problem

Returning null on the Java side caused a NullPointerException

The native code did not correctly support TF_Output.oper == nullptr

Fixes

Java side (AbstractGradientAdapter):

null is now translated into:

TF_Output { oper = nullptr, index = 0 }

C++ side (CustomGradFunc):

out.oper == nullptr is interpreted as NoGradient

No dangerous dereference

No crashes / no SIGSEGV
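
The null-to-sentinel translation can be modeled in a short sketch. `TfOutput` below is a hypothetical stand-in for the native `TF_Output` struct (`oper`, `index`), with a `null` handle modeling `oper == nullptr`; the real adapter works on JavaCPP pointers, not `Long` handles.

```java
import java.util.ArrayList;
import java.util.List;

// Stand-in for the native TF_Output struct { oper, index }.
final class TfOutput {
    final Long oper;  // null models oper == nullptr
    final int index;
    TfOutput(Long oper, int index) { this.oper = oper; this.index = index; }
    boolean isNoGradient() { return oper == null; }
}

final class GradientOutputs {
    // A null Java-side gradient becomes TF_Output { oper = nullptr, index = 0 },
    // which the C++ bridge interprets as NoGradient instead of dereferencing it.
    static List<TfOutput> toNative(List<Long> javaGrads) {
        List<TfOutput> out = new ArrayList<>(javaGrads.size());
        for (Long g : javaGrads) {
            out.add(new TfOutput(g, 0));
        }
        return out;
    }
}
```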

  7. Cleanup of the C++ Bridge (CustomGradFunc)

Applied corrections:

Removed a double loop that was adding gradients twice

Consistent handling of NoGradient

Single, safe memory deallocation (free(outputs))

Preserved defensive hardening:

checks on num_outputs

outputs == nullptr

etc.

  8. Final State
    What now works

✔ Registering dozens (or hundreds) of custom gradients

✔ Registration order no longer matters

✔ No more adapter_ptr = 0

✔ No JNI crashes / no SIGSEGV

✔ Proper support for partial gradients (NoGradient)

✔ Architecture aligned with native TensorFlow

What was avoided

❌ Fragile JavaCPP patches

❌ Dependency on internal allocation details

❌ Workarounds based on registration order

  9. In One Sentence

We replaced a non-scalable architecture (“N gradients = N native callbacks”) with a scalable one (“1 native callback + Java dispatch”), while properly fixing NoGradient handling and strictly respecting TensorFlow Java’s internal constraints.

unordered_map<string, TFJ_GradFuncAdapter> g_grad_func_adapters;

// Cast helper (inspired by TF C-API)
template <typename T, typename U>
Collaborator

Can you fix this diff to remove all the formatting changes so we can see just the functional changes to CustomGradFunc?

return false;
}

bool registered = GradOpRegistry::Global()->Register(op_type, CustomGradFunc);
Collaborator

Please fix the formatting to reduce the diff.

@Craigacp
Collaborator

This looks like a fairly complicated fix to work around a bug in JavaCPP? Is it not better to fix it there?

@nfeybesse
Contributor Author

Thanks for the question — it’s a fair concern.

This change is indeed a workaround for a limitation in JavaCPP (bytedeco/javacpp#1205), where multiple native callbacks of the same kind cannot be reliably registered and invoked. In practice, only the last registered gradient adapter survives, which makes it impossible to support more than one Java custom gradient per process.

Fixing this directly in JavaCPP would be ideal in theory, but in practice it is not a viable short- or medium-term option for TensorFlow Java:

The issue is deep in JavaCPP’s native callback and lifetime management.

TensorFlow Java depends on JavaCPP as an external project, and cannot reasonably block feature development or correctness fixes on changes there.

Even with a JavaCPP fix, TensorFlow Java would still need a stable, deterministic way to manage gradient dispatch per op type.

For these reasons, this PR follows the same architectural pattern already used by TensorFlow itself.

TensorFlow Python does not register one native callback per op.
Instead, it registers a single C++ gradient hook and performs runtime dispatch based on the op type (via the gradient registry). In other words, Python also uses a centralized dispatcher rather than relying on multiple independent native callbacks.

This PR mirrors that design on the Java side:

A single native CustomGradFunc is registered with TensorFlow.

That function dispatches to the appropriate Java gradient implementation based on op_type.

This avoids the JavaCPP limitation entirely, while matching TensorFlow’s own gradient architecture.

As a result, the solution is:

robust and deterministic,

consistent with TensorFlow’s Python design,

backward-compatible,

and does not require changes to JavaCPP or TensorFlow C++.

In short: while the root cause is a JavaCPP limitation, centralizing gradient dispatch is not a hack — it is the same model TensorFlow already uses, adapted to the Java runtime constraints.

@nfeybesse
Contributor Author

bytedeco/javacpp#648

@nfeybesse nfeybesse force-pushed the custom/gradients-dispatch branch from d7bc382 to 8d80312 on February 11, 2026 10:30

final String opType = operation.type();

RawCustomGradient rg = raw.get(opType);
Collaborator

This logic prefers raw gradients over typed ones, but there isn't anything documented about why it prefers them or if it makes sense to add both raw and typed gradients for the same op. It would be good to clarify this, and if it doesn't make sense to have both kinds of gradients the adapter should reject them in the puts.

Collaborator

Can you add javadoc to the top of this class noting the overall purpose of it (to provide Java side dispatching for gradients mirroring TF-Python), that it only accepts either raw or typed gradients for a given op, and that it rejects duplicate assignments.

…, remove inline ifs, add license headers and imports

- Prevent dual registration of raw and typed gradients for the same op type
- Use putIfAbsent and explicit exceptions to avoid silent overwrites
- Replace inline if statements in tfj_gradients_impl.cc with brace blocks
- Add Apache 2.0 headers to new files
- Replace fully-qualified GradientDispatch reference with import
@nfeybesse
Contributor Author

Thanks for the review!

I’ve pushed an update that:

  • Enforces mutual exclusion between raw and typed gradient registrations
  • Prevents silent overwrites via putIfAbsent
  • Replaces inline if statements with brace blocks in tfj_gradients_impl.cc
  • Adds the standard Apache 2.0 headers
  • Uses an import for GradientDispatch

Let me know if anything else should be adjusted.
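
The mutual-exclusion contract described above can be sketched as follows. `RawGrad` and `TypedGrad` are placeholder interfaces, not the PR's actual `RawCustomGradient`/`CustomGradient` types; the point is the two-map check plus `putIfAbsent` so neither a raw/typed conflict nor a silent overwrite can slip through.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Sketch of the registration contract: for a given op type, accept either
// a raw or a typed gradient, never both, and reject duplicate registrations.
final class GradientRegistry {
    interface RawGrad {}
    interface TypedGrad {}

    private final Map<String, RawGrad> raw = new ConcurrentHashMap<>();
    private final Map<String, TypedGrad> typed = new ConcurrentHashMap<>();

    void putRaw(String opType, RawGrad g) {
        if (typed.containsKey(opType)) {
            throw new IllegalArgumentException("typed gradient already registered for " + opType);
        }
        if (raw.putIfAbsent(opType, g) != null) {
            throw new IllegalArgumentException("raw gradient already registered for " + opType);
        }
    }

    void putTyped(String opType, TypedGrad g) {
        if (raw.containsKey(opType)) {
            throw new IllegalArgumentException("raw gradient already registered for " + opType);
        }
        if (typed.putIfAbsent(opType, g) != null) {
            throw new IllegalArgumentException("typed gradient already registered for " + opType);
        }
    }
}
```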

Document purpose as Java-side gradient dispatcher mirroring TF-Python,
clarify raw vs typed gradient registration contract, and note duplicate
registration rejection.
@Craigacp
Collaborator

I'm trying to figure out why it needs to make a fresh anonymous subclass of TFJ_GradFuncAdapter rather than returning the subclass it already has. While I was doing that I noticed that there isn't actually a test that the original bug is fixed, by adding more than 10 gradients to a session and checking they all exist. Can you add such a test?

@nfeybesse
Contributor Author

Good question.

The reason we return a fresh anonymous subclass of TFJ_GradFuncAdapter (instead of reusing the original subclass instance) is related to how FunctionPointer instances are managed on the JVM side and how their native address is captured.

TFJ_GradFuncAdapter extends FunctionPointer. Each instance is associated with a native trampoline pointer allocated by JavaCPP. That native pointer is what TensorFlow C actually stores internally when we register the gradient.

If we were to reuse the same adapter instance:

We would share a single native function pointer across multiple gradient registrations.

The lifetime of that pointer would be tied to the original Java object.

In some scenarios (especially with multiple graphs or repeated registrations), this can lead to subtle issues:

pointer reuse across registrations,

unexpected deallocation / GC interaction,

or native-side bookkeeping assuming distinct callbacks per registration.

By returning a fresh anonymous subclass each time, we guarantee:

A distinct FunctionPointer instance.

A distinct native trampoline pointer.

No accidental sharing of callback state across gradient registrations.

In other words, the anonymous subclass is not about Java polymorphism — it's about forcing allocation of a fresh native callback binding, so the C side never sees a reused function pointer instance.

That said, if you think it's clearer, I can refactor the code to make that intent explicit (e.g. by adding a short comment explaining that we deliberately allocate a new FunctionPointer instance per registration to avoid native pointer reuse).

nfeybesse added a commit to nfeybesse/tensorflow that referenced this pull request Feb 24, 2026
…ence

This adds a regression test for PR tensorflow#632.

The test dynamically discovers op types with no registered gradient
(using TF_GetAllOpList + TensorFlow.hasGradient), registers 11 custom
gradients, and verifies that all are present in the native gradient
registry.

This directly validates that registering more than 10 gradients works
and that all entries are correctly stored in the native registry,
without relying on Graph.addGradients() execution.

Addresses reviewer comment about missing test for >10 gradients.
opType,
(tf, op, gradInputs) -> {
int n = op.numInputs();
java.util.ArrayList<org.tensorflow.Operand<?>> grads = new java.util.ArrayList<>(n);
Collaborator

ArrayList is imported, don't fully qualify it.

@Craigacp
Collaborator

I thought the point of this change was that there was only ever a single native function pointer, which points to the DispatchingGradientAdapter which then dispatches on the Java side to the gradient op?

@nfeybesse
Contributor Author

Yes, that’s correct — the intent of this change is that there is a single native function pointer registered, which points to the DispatchingGradientAdapter, and all gradient dispatching happens on the Java side based on opType.

The anonymous subclass is not meant to introduce multiple native callbacks. It is only used to ensure we get a properly allocated and strongly reachable FunctionPointer instance on the JavaCPP side, avoiding any lifecycle or pointer ownership issues. Conceptually, there is still just one native entry point, and all routing is done in Java.

In parallel, I’m also working on stabilizing and refining the gradients for branching operations (e.g., If, Switch, etc.), which rely heavily on this dispatching mechanism.

@nfeybesse nfeybesse force-pushed the custom/gradients-dispatch branch from b0dbea2 to 56539b9 on February 24, 2026 16:26
@Craigacp
Collaborator

Yes, that’s correct — the intent of this change is that there is a single native function pointer registered, which points to the DispatchingGradientAdapter, and all gradient dispatching happens on the Java side based on opType.

The anonymous subclass is not meant to introduce multiple native callbacks. It is only used to ensure we get a properly allocated and strongly reachable FunctionPointer instance on the JavaCPP side, avoiding any lifecycle or pointer ownership issues. Conceptually, there is still just one native entry point, and all routing is done in Java.

In parallel, I’m also working on stabilizing and refining the gradients for branching operations (e.g., If, Switch, etc.), which rely heavily on this dispatching mechanism.

That doesn't make any sense to me. If the logic is all Java side, then we don't need to worry about JavaCPP. It'll be reachable because it's stored in the ConcurrentHashMap inside DispatchingGradientAdapter. The new gradient adapters don't interact with the native code at all as the native code calls up into them.

Consequently the native code in tfj_gradients_impl can be simplified so it doesn't contain a map for the gradient adapters because that logic lives in Java now, it only needs a reference to the DispatchingGradientAdapter which can be set on the first call.

Replace the native per-op unordered_map of TFJ_GradFuncAdapter with a
single global dispatch adapter.

The native layer now registers CustomGradFunc per op type in the
GradOpRegistry, but always calls the same TFJ_GradFuncAdapter instance.
All opType-based routing is handled on the Java side by
DispatchingGradientAdapter.

This aligns the native implementation with the intended design:
there is only one native function pointer registered, and dispatch
logic lives entirely in Java.

Also fixes unsafe casting of Scope* to TFJ_Scope* by constructing a
temporary TFJ_Scope wrapper instead.
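
The "single global dispatch adapter, set on the first call" idea from this commit can be sketched in Java. `Adapter` is a placeholder for the Java dispatch entry point; the real code stores a JavaCPP `TFJ_GradFuncAdapter` on the native side, but the set-once semantics are the same.

```java
import java.util.concurrent.atomic.AtomicReference;

// Sketch: a single global adapter reference, installed exactly once.
// Every subsequent native gradient call goes through the same instance.
final class GlobalAdapterHolder {
    interface Adapter { double[] call(String opType, double[] grads); }

    private static final AtomicReference<Adapter> INSTANCE = new AtomicReference<>();

    // Installs the adapter on first use; later calls are no-ops, so there is
    // only ever one callback target no matter how many gradients register.
    static void setOnce(Adapter a) {
        INSTANCE.compareAndSet(null, a);
    }

    static double[] dispatch(String opType, double[] grads) {
        Adapter a = INSTANCE.get();
        if (a == null) {
            throw new IllegalStateException("no dispatch adapter installed");
        }
        return a.call(opType, grads);
    }
}
```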