-
Notifications
You must be signed in to change notification settings - Fork 5
Add tests/fixes for inplace rules #158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
I'd like to do a bit of tidying up with this one but it can be merged if tests pass and people feel like it, I can just make a separate PR. |
Codecov Report✅ All modified and coverable lines are covered by tests.
🚀 New features to boost your workflow:
|
8a1283e to
772b61b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good to me!
I have two remaining questions:
- This also bumps the Mooncake version, should this come before or after #164 ?
I might be getting this wrong, but I am kind of confused now about why we arezero!ing thedargs. Aren't we supposed to accumulate the contributions of the gradient into that, not setting them to? I.e., should this not just bedarg1 += zero(darg1), or effectively we don't have to do that?
[EDIT] ignore this last part, I'm stupid.
facdda6 to
4528a73
Compare
| function copy_tangent(var::Mooncake.CoDual, Δargs) | ||
| dargs = make_mooncake_fdata(deepcopy(Δargs)) | ||
| copyto!(Mooncake.tangent(var), dargs) | ||
| return | ||
| end | ||
|
|
||
| function copy_tangent(var::Mooncake.CoDual, Δargs::Tuple) | ||
| dargs = make_mooncake_fdata.(deepcopy(Δargs)) | ||
| for (var_tangent, darg) in zip(Mooncake.tangent(var), dargs) | ||
| if var_tangent isa Mooncake.FData | ||
| for (var_f, darg_f) in zip(Mooncake._fields(var_tangent), Mooncake._fields(darg)) | ||
| copyto!(var_f, darg_f) | ||
| end | ||
| else | ||
| copyto!(var_tangent, darg) | ||
| end | ||
| end | ||
| return | ||
| end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should these functions have been called copy_tangent! for consistency with Julia naming guidelines?
| end | ||
|
|
||
| function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata) | ||
| function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata; ȳ = Δargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the role of Δargs and ȳ. Is Δargs is just the allocated space for the shadow variables of args, whereas ȳ contains the actual cotangents? But in most cases they are the same, and inplace_out contains Δargs as tangents and copy_tangent(inplace_out, ȳ) ends up copying ȳ into Δargs, i.e. into itself?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For clarity, this is just a question to make sure I understand. I don't take an issue with copying into itself in these test methods.
| dA_copy_ = Mooncake.arrayify(A, dA_copy)[2] | ||
| @test dA_inplace_ ≈ dA_copy_ | ||
| @test copy_args == inplace_args | ||
| if dargs_copy isa Tuple |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do the dargs_... variables contain at the end of the _get_..._derivative ? They are all set to zero in the pullback, right? Does it make sense to just test that both dargs_copy and dargs_inplace are both zero, rather than simply testing that they are equal (which is true if they are both zero)?
Added more tests to
test_pullbacks_matchto make sure the state of the arguments is restored, and the final argument derivatives match between inplace and non in place methods.Unfortunately, the Mooncake FD tester doesn't work well for our functions, because
Abecomes a scratch space, and the inputs are also the outputs (so get incremented twice under the FD scheme).