from dataclasses import dataclass
from typing import TYPE_CHECKING
import numpy as np
from .ant_link import AntLink, LossChannel
if TYPE_CHECKING:
from ..node import Node
@dataclass
class AntLinkInternalObjectID(AntLink):
    """Refer to a different ObjectId on the same Node.

    Backpropagation through this link is forwarded to whatever object the
    owning node resolves for ``self.id.object_id``.
    """

    def backpropagate(
        self,
        node: "Node",
        gradient: np.ndarray,
        losstype: LossChannel,
    ) -> None:
        """Forward a gradient to the referenced object on ``node``.

        The gradient is handed unchanged to the target's own
        ``backpropagate``; this link applies no transformation itself.

        Args:
            node: Node on which the referenced object is looked up via
                ``node.find_object_reference``.
            gradient: Gradient to forward; described as the gradient with
                respect to the exchange vector. Its shape is not checked
                here.
            losstype: Loss channel the gradient belongs to; forwarded as-is.

        Returns:
            None. The call is a no-op when the link carries no object id or
            when the node cannot resolve the reference.
        """
        object_id = self.id.object_id
        if object_id is None:
            # No target object configured for this link; nothing to do.
            return

        ref = node.find_object_reference(object_id)
        if ref is None:
            # NOTE(review): unresolved references are silently skipped —
            # confirm this best-effort behavior is intended.
            return

        ref.backpropagate(gradient, losstype)