[GRPO] Try returning hidden statex for GRPO#5142
[GRPO] Try returning hidden statex for GRPO#5142Datta0 wants to merge 2 commits intounslothai:mainfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements a fallback mechanism for GRPO to retrieve hidden states from models that do not natively support it by wrapping the forward method. Feedback suggests optimizing performance by pre-computing function signatures during installation and using them to robustly handle positional arguments. It was also recommended to log exceptions during argument binding to aid debugging.
|
|
||
| def _get_num_logits_to_keep(kwargs): | ||
| num_logits_to_keep = kwargs.get("num_logits_to_keep", 0) or 0 | ||
| logits_to_keep = kwargs.get("logits_to_keep", 0) or 0 |
There was a problem hiding this comment.
The _get_num_logits_to_keep function currently only checks kwargs. If logits_to_keep or num_logits_to_keep are passed as positional arguments, they will be missed. It's better to use the model's signature to bind the arguments and extract the values robustly. Per repository rules, ensure exceptions are logged rather than silently ignored.
| def _get_num_logits_to_keep(kwargs): | |
| num_logits_to_keep = kwargs.get("num_logits_to_keep", 0) or 0 | |
| logits_to_keep = kwargs.get("logits_to_keep", 0) or 0 | |
| def _get_num_logits_to_keep(sig, args, kwargs): | |
| try: | |
| bound = sig.bind_partial(*args, **kwargs) | |
| return max(bound.arguments.get("num_logits_to_keep", 0) or 0, | |
| bound.arguments.get("logits_to_keep", 0) or 0) | |
| except Exception as e: | |
| import logging | |
| logging.debug(f"Error binding signature: {e}") | |
| return max(kwargs.get("num_logits_to_keep", 0) or 0, | |
| kwargs.get("logits_to_keep", 0) or 0) |
References
- Avoid using broad, silent exception handlers like
except Exception: pass. Instead, log the exception, even if at a debug level, to aid in future debugging.
| ) | ||
| return outputs | ||
|
|
||
| hidden_states = hidden_states[-1] |
| return True | ||
| return False | ||
|
|
||
|
|
There was a problem hiding this comment.
To improve performance, _drop_forward_kwargs_consumed_positionally should accept a pre-computed signature instead of calling inspect.signature() on every forward pass. This avoids redundant introspection overhead during training.
| def _drop_forward_kwargs_consumed_positionally(sig, args, kwargs): |
References
- To improve efficiency, avoid redundant data iterations. Combine checks and transformations into a single loop and return computed values for callers to reuse.
| if len(args) == 0 or len(kwargs) == 0: | ||
| return kwargs | ||
|
|
||
| consumed_names = [] |
| return False | ||
|
|
There was a problem hiding this comment.
Pre-compute the signature once during installation to avoid the overhead of inspect.signature() on every forward pass.
| return False | |
| original_forward = target_model.forward | |
| sig = inspect.signature(original_forward) | |
| model_name = type(target_model).__name__ |
References
- To improve efficiency, avoid redundant data iterations. Combine checks and transformations into a single loop and return computed values for callers to reuse.
| return original_forward(*args, **kwargs) | ||
|
|
||
| forward_kwargs = _drop_forward_kwargs_consumed_positionally( |
520de3b to
6be5ab1
Compare
for more information, see https://pre-commit.ci
unslothai/unsloth-zoo#602 fixed an important issue where when some models return logits we were failing with shape mismatch. This is because for GRPO we generally expect hidden states to be returned with our wrappers (for most models ofc) adn lm_head is applied chunk wise to avoid materialising full large logits which are much larger than hidden states (4K vs 256K ish for eg)
This is an effort to make more models return hidden states for efficiency reasons. Orthogonal to the above mentioned PR :)
Ref: unslothai/unsloth-zoo#609