Skip to content

[GRPO] Try returning hidden statex for GRPO#5142

Draft
Datta0 wants to merge 2 commits intounslothai:mainfrom
Datta0:grpo-hidden-fallback
Draft

[GRPO] Try returning hidden statex for GRPO#5142
Datta0 wants to merge 2 commits intounslothai:mainfrom
Datta0:grpo-hidden-fallback

Conversation

@Datta0
Copy link
Copy Markdown
Collaborator

@Datta0 Datta0 commented Apr 23, 2026

unslothai/unsloth-zoo#602 fixed an important issue where when some models return logits we were failing with shape mismatch. This is because for GRPO we generally expect hidden states to be returned with our wrappers (for most models ofc) adn lm_head is applied chunk wise to avoid materialising full large logits which are much larger than hidden states (4K vs 256K ish for eg)
This is an effort to make more models return hidden states for efficiency reasons. Orthogonal to the above mentioned PR :)

Ref: unslothai/unsloth-zoo#609

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements a fallback mechanism for GRPO to retrieve hidden states from models that do not natively support it by wrapping the forward method. Feedback suggests optimizing performance by pre-computing function signatures during installation and using them to robustly handle positional arguments. It was also recommended to log exceptions during argument binding to aid debugging.

Comment thread unsloth/models/rl.py
Comment on lines +593 to +596

def _get_num_logits_to_keep(kwargs):
num_logits_to_keep = kwargs.get("num_logits_to_keep", 0) or 0
logits_to_keep = kwargs.get("logits_to_keep", 0) or 0
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The _get_num_logits_to_keep function currently only checks kwargs. If logits_to_keep or num_logits_to_keep are passed as positional arguments, they will be missed. It's better to use the model's signature to bind the arguments and extract the values robustly. Per repository rules, ensure exceptions are logged rather than silently ignored.

Suggested change
def _get_num_logits_to_keep(kwargs):
num_logits_to_keep = kwargs.get("num_logits_to_keep", 0) or 0
logits_to_keep = kwargs.get("logits_to_keep", 0) or 0
def _get_num_logits_to_keep(sig, args, kwargs):
try:
bound = sig.bind_partial(*args, **kwargs)
return max(bound.arguments.get("num_logits_to_keep", 0) or 0,
bound.arguments.get("logits_to_keep", 0) or 0)
except Exception as e:
import logging
logging.debug(f"Error binding signature: {e}")
return max(kwargs.get("num_logits_to_keep", 0) or 0,
kwargs.get("logits_to_keep", 0) or 0)
References
  1. Avoid using broad, silent exception handlers like except Exception: pass. Instead, log the exception, even if at a debug level, to aid in future debugging.

Comment thread unsloth/models/rl.py
)
return outputs

hidden_states = hidden_states[-1]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Use the pre-computed signature and original arguments to correctly extract num_logits_to_keep, accounting for positional arguments.

Suggested change
hidden_states = hidden_states[-1]
num_logits_to_keep = _get_num_logits_to_keep(sig, args, kwargs)

Comment thread unsloth/models/rl.py
return True
return False


Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve performance, _drop_forward_kwargs_consumed_positionally should accept a pre-computed signature instead of calling inspect.signature() on every forward pass. This avoids redundant introspection overhead during training.

Suggested change
def _drop_forward_kwargs_consumed_positionally(sig, args, kwargs):
References
  1. To improve efficiency, avoid redundant data iterations. Combine checks and transformations into a single loop and return computed values for callers to reuse.

Comment thread unsloth/models/rl.py
if len(args) == 0 or len(kwargs) == 0:
return kwargs

consumed_names = []
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Use the passed signature instead of re-computing it.

Suggested change
consumed_names = []
for parameter in sig.parameters.values():

Comment thread unsloth/models/rl.py
Comment on lines +632 to +633
return False

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Pre-compute the signature once during installation to avoid the overhead of inspect.signature() on every forward pass.

Suggested change
return False
original_forward = target_model.forward
sig = inspect.signature(original_forward)
model_name = type(target_model).__name__
References
  1. To improve efficiency, avoid redundant data iterations. Combine checks and transformations into a single loop and return computed values for callers to reuse.

Comment thread unsloth/models/rl.py
Comment on lines +639 to +641
return original_forward(*args, **kwargs)

forward_kwargs = _drop_forward_kwargs_consumed_positionally(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Pass the pre-computed signature to the helper function.

Suggested change
return original_forward(*args, **kwargs)
forward_kwargs = _drop_forward_kwargs_consumed_positionally(
forward_kwargs = _drop_forward_kwargs_consumed_positionally(
sig, args, kwargs
)

@Datta0 Datta0 force-pushed the grpo-hidden-fallback branch from 520de3b to 6be5ab1 Compare April 27, 2026 06:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant