Skip to content
20 changes: 13 additions & 7 deletions bittensor_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5100,13 +5100,19 @@ def stake_remove(
default=self.config.get("wallet_name") or defaults.wallet.name,
)
if include_hotkeys:
if len(include_hotkeys) > 1:
print_error("Cannot unstake_all from multiple hotkeys at once.")
return False
elif is_valid_ss58_address(include_hotkeys[0]):
hotkey_ss58_address = include_hotkeys[0]
else:
print_error("Invalid hotkey ss58 address.")
# Multiple hotkeys are now supported with batching
# Initialize wallet as it's needed to resolve hotkey names and get coldkey
wallet = self.wallet_ask(
wallet_name,
wallet_path,
wallet_hotkey,
ask_for=[WO.NAME, WO.PATH],
)
if len(include_hotkeys) == 1:
# Single hotkey - use hotkey_ss58_address for backward compatibility
if is_valid_ss58_address(include_hotkeys[0]):
hotkey_ss58_address = include_hotkeys[0]
# If it's a hotkey name, it will be handled by the unstake_all function
return False
elif all_hotkeys:
wallet = self.wallet_ask(
Expand Down
100 changes: 100 additions & 0 deletions bittensor_cli/src/bittensor/subtensor_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,6 +1304,106 @@ async def create_signed(call_to_sign, n):
)
return False, err_msg, None

async def sign_and_send_batch_extrinsic(
self,
calls: list[GenericCall],
wallet: Wallet,
wait_for_inclusion: bool = True,
wait_for_finalization: bool = False,
era: Optional[dict[str, int]] = None,
proxy: Optional[str] = None,
nonce: Optional[int] = None,
sign_with: Literal["coldkey", "hotkey", "coldkeypub"] = "coldkey",
batch_type: Literal["batch", "batch_all"] = "batch",
mev_protection: bool = False,
) -> tuple[bool, str, Optional[AsyncExtrinsicReceipt], list[dict]]:
"""
:param calls: List of prepared Call objects to batch together
:param wallet: the wallet whose key will be used to sign the extrinsic
:param wait_for_inclusion: whether to wait until the extrinsic call is included on the chain
:param wait_for_finalization: whether to wait until the extrinsic call is finalized on the chain
:param era: The length (in blocks) for which a transaction should be valid.
:param proxy: The real account used to create the proxy. None if not using a proxy for this call.
:param nonce: The nonce used to submit this extrinsic call.
:param sign_with: Determine which of the wallet's keypairs to use to sign the extrinsic call.
:param batch_type: "batch" (stops on first error) or "batch_all" (executes all, fails if any fail)
:param mev_protection: If set, uses Mev Protection on the extrinsic, thus encrypting it.

:return: (success, error message, extrinsic receipt | None, list of individual call results)
"""
if not calls:
return False, "No calls provided for batching operation", None, []

if len(calls) == 1:
success, err_msg, receipt = await self.sign_and_send_extrinsic(
call=calls[0],
wallet=wallet,
wait_for_inclusion=wait_for_inclusion,
wait_for_finalization=wait_for_finalization,
era=era,
proxy=proxy,
nonce=nonce,
sign_with=sign_with,
mev_protection=mev_protection,
)
return success, err_msg, receipt, [{"success": success, "error": err_msg}]

batch_call = await self.substrate.compose_call(
call_module="Utility",
call_function=batch_type,
call_params={"calls": calls},
)

success, err_msg, receipt = await self.sign_and_send_extrinsic(
call=batch_call,
wallet=wallet,
wait_for_inclusion=wait_for_inclusion,
wait_for_finalization=wait_for_finalization,
era=era,
proxy=proxy,
nonce=nonce,
sign_with=sign_with,
mev_protection=mev_protection,
)

# Parse batch results if successful
call_results = []
if success and receipt:
try:
# Extract batch execution results from receipt
# The receipt should contain information about which calls succeeded/failed
for i, call in enumerate(calls):
call_results.append(
{
"index": i,
"call": call,
"success": True, # Will be updated if we can parse receipt
}
)
except Exception:
# If we can't parse results, assume all succeeded if batch succeeded
for i, call in enumerate(calls):
call_results.append(
{
"index": i,
"call": call,
"success": success,
}
)
Comment on lines +1372 to +1392
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this try/except block is pointless - the code in the try block will never raise an exception (a simple loop with append). The comment Will be updated if we can parse receipt indicates an incomplete implementation.

Pls either implement parsing of events from the receipt to determine the success of individual calls, or remove the dead code

else:
# If batch failed, mark all as failed
for i, call in enumerate(calls):
call_results.append(
{
"index": i,
"call": call,
"success": False,
"error": err_msg,
}
)

return success, err_msg, receipt, call_results

async def get_children(self, hotkey, netuid) -> tuple[bool, list, str]:
"""
This method retrieves the children of a given hotkey and netuid. It queries the SubtensorModule's ChildKeys
Expand Down
240 changes: 200 additions & 40 deletions bittensor_cli/src/commands/stake/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,56 +474,216 @@ async def stake_extrinsic(
successes = defaultdict(dict)
error_messages = defaultdict(dict)
extrinsic_ids = defaultdict(dict)
with console.status(f"\n:satellite: Staking on netuid(s): {netuids} ...") as status:
if safe_staking:
stake_coroutines = {}
for i, (ni, am, curr, price_with_tolerance) in enumerate(
zip(
netuids,
amounts_to_stake,
current_stake_balances,
prices_with_tolerance,

# Collect all calls for batching
calls_to_batch = []
call_metadata = [] # Track (netuid, staking_address, amount, current_stake, price_with_tolerance) for each call

with console.status(
f"\n:satellite: Preparing batch staking on netuid(s): {netuids} ..."
) as status:
# Get next nonce for batch
next_nonce = await subtensor.substrate.get_account_next_index(coldkey_ss58)
# Get block_hash at the beginning to speed up compose_call operations
block_hash = await subtensor.substrate.get_chain_head()

# Collect all calls - iterate through the same order as when building the lists
# The lists are built in order: for each hotkey, for each netuid
list_idx = 0
price_idx = 0
for hotkey in hotkeys_to_stake_to:
for netuid in netuids:
# Safety check: if we've processed all items from the first loop, stop
if list_idx >= len(amounts_to_stake):
break

# Verify subnet exists (same check as first loop)
# If subnet doesn't exist, it was skipped in first loop, so list_idx won't advance
# We need to skip it here too to stay in sync
subnet_info = all_subnets.get(netuid)
if not subnet_info:
# This netuid was skipped in first loop (doesn't exist)
# Don't advance list_idx, just continue to next netuid
continue

am = amounts_to_stake[list_idx]
curr = current_stake_balances[list_idx]
staking_address = hotkey[1]
price_with_tol = (
prices_with_tolerance[price_idx]
if safe_staking and price_idx < len(prices_with_tolerance)
else None
)
):
for _, staking_address in hotkeys_to_stake_to:
# Regular extrinsic for root subnet
if ni == 0:
stake_coroutines[(ni, staking_address)] = stake_extrinsic(
netuid_i=ni,
amount_=am,
current=curr,
staking_address_ss58=staking_address,
status_=status,
)
else:
stake_coroutines[(ni, staking_address)] = safe_stake_extrinsic(
netuid_=ni,
amount_=am,
current_stake=curr,
hotkey_ss58_=staking_address,
price_limit=price_with_tolerance,
status_=status,
if safe_staking:
price_idx += 1

call_metadata.append(
(netuid, staking_address, am, curr, price_with_tol)
)

if safe_staking and netuid != 0 and price_with_tol:
# Safe staking for non-root subnets
call = await subtensor.substrate.compose_call(
call_module="SubtensorModule",
call_function="add_stake_limit",
call_params={
"hotkey": staking_address,
"netuid": netuid,
"amount_staked": am.rao,
"limit_price": price_with_tol.rao,
"allow_partial": allow_partial_stake,
},
block_hash=block_hash,
)
else:
# Regular staking for root subnet or non-safe staking
call = await subtensor.substrate.compose_call(
call_module="SubtensorModule",
call_function="add_stake",
call_params={
"hotkey": staking_address,
"netuid": netuid,
"amount_staked": am.rao,
},
block_hash=block_hash,
)
calls_to_batch.append(call)
list_idx += 1

# If we have multiple calls, batch them; otherwise send single call
if len(calls_to_batch) > 1:
status.update(
f"\n:satellite: Batching {len(calls_to_batch)} stake operations..."
)
(
batch_success,
batch_err_msg,
batch_receipt,
call_results,
) = await subtensor.sign_and_send_batch_extrinsic(
calls=calls_to_batch,
wallet=wallet,
era={"period": era},
proxy=proxy,
nonce=next_nonce,
mev_protection=mev_protection,
batch_type="batch_all", # Use batch_all to execute all even if some fail
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a misunderstanding. batch_all means: all calls will be executed, but if at least one fails, all will be rolled back. it doesn't mean execute all even if some fail.

If you need to execute all regardless of errors you need force_batch (which is in the signature, but not used).
Pls, clarify the requirements and correct the comments/logic

)

if batch_success and batch_receipt:
if mev_protection:
Comment on lines +573 to +574
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd recommend you to create something like handle_mev_protection function with proper parameters to handle the result. You broke DRY rule.
Currently, you have 3 exact the same code section in stake_add, unstake, and unstake_all calls.

inner_hash = batch_err_msg
mev_shield_id = await extract_mev_shield_id(batch_receipt)
(
mev_success,
mev_error,
batch_receipt,
) = await wait_for_extrinsic_by_hash(
subtensor=subtensor,
extrinsic_hash=inner_hash,
shield_id=mev_shield_id,
submit_block_hash=batch_receipt.block_hash,
status=status,
)
if not mev_success:
status.stop()
print_error(f"\n:cross_mark: [red]Failed[/red]: {mev_error}")
batch_success = False
batch_err_msg = mev_error

if batch_success:
if not json_output:
await print_extrinsic_id(batch_receipt)
batch_ext_id = await batch_receipt.get_extrinsic_identifier()

# Fetch updated balances for display
block_hash = await subtensor.substrate.get_chain_head()
current_balance = await subtensor.get_balance(
coldkey_ss58, block_hash
)

# Fetch all stake balances in parallel
if not json_output:
stake_fetch_tasks = [
subtensor.get_stake(
hotkey_ss58=staking_address,
coldkey_ss58=coldkey_ss58,
netuid=ni,
block_hash=block_hash,
)
for ni, staking_address, _, _, _ in call_metadata
]
new_stakes = await asyncio.gather(*stake_fetch_tasks)

# Process results for each call
for idx, (ni, staking_address, am, curr, _) in enumerate(
call_metadata
):
# For batch_all, we assume all succeeded if batch succeeded
# Individual call results would need to be parsed from receipt events
successes[ni][staking_address] = True
error_messages[ni][staking_address] = ""
extrinsic_ids[ni][staking_address] = batch_ext_id

if not json_output:
new_stake = new_stakes[idx]
console.print(
f":white_heavy_check_mark: [dark_sea_green3]Finalized. "
f"Stake added to netuid: {ni}, hotkey: {staking_address}[/dark_sea_green3]"
)
console.print(
f"Subnet: [{COLOR_PALETTE['GENERAL']['SUBHEADING']}]"
f"{ni}[/{COLOR_PALETTE['GENERAL']['SUBHEADING']}] "
f"Stake:\n"
f" [blue]{curr}[/blue] "
f":arrow_right: "
f"[{COLOR_PALETTE['STAKE']['STAKE_AMOUNT']}]{new_stake}\n"
)

# Show final coldkey balance
if not json_output:
console.print(
f"Coldkey Balance:\n "
f"[blue]{current_wallet_balance}[/blue] "
f":arrow_right: "
f"[{COLOR_PALETTE['STAKE']['STAKE_AMOUNT']}]{current_balance}"
)
else:
stake_coroutines = {
(ni, staking_address): stake_extrinsic(
else:
# Batch failed
for ni, staking_address, _, _, _ in call_metadata:
successes[ni][staking_address] = False
error_messages[ni][staking_address] = batch_err_msg
else:
# Batch submission failed
for ni, staking_address, _, _, _ in call_metadata:
successes[ni][staking_address] = False
error_messages[ni][staking_address] = (
batch_err_msg or "Batch submission failed"
)
elif len(calls_to_batch) == 1:
# Single call - use regular extrinsic
ni, staking_address, am, curr, price_with_tol = call_metadata[0]

if safe_staking and ni != 0 and price_with_tol:
success, er_msg, ext_receipt = await safe_stake_extrinsic(
netuid_=ni,
amount_=am,
current_stake=curr,
hotkey_ss58_=staking_address,
price_limit=price_with_tol,
status_=status,
)
else:
success, er_msg, ext_receipt = await stake_extrinsic(
netuid_i=ni,
amount_=am,
current=curr,
staking_address_ss58=staking_address,
status_=status,
)
for i, (ni, am, curr) in enumerate(
zip(netuids, amounts_to_stake, current_stake_balances)
)
for _, staking_address in hotkeys_to_stake_to
}
# We can gather them all at once but balance reporting will be in race-condition.
for (ni, staking_address), coroutine in stake_coroutines.items():
success, er_msg, ext_receipt = await coroutine
successes[ni][staking_address] = success
error_messages[ni][staking_address] = er_msg
if success:
if success and ext_receipt:
extrinsic_ids[ni][
staking_address
] = await ext_receipt.get_extrinsic_identifier()
Expand Down
Loading
Loading