gm.data.pad

gm.data.pad#

gemma.gm.data.pad(
element: jaxtyping.Shaped[Array, 'sequence'] | jaxtyping.Shaped[ndarray, 'sequence'] | Sequence[PyTree[L]] | Mapping[str, PyTree[L]],
max_length: int,
*,
truncate: bool = False,
fill_value: int = 0,
axis: int = -1,
) jaxtyping.Shaped[Array, 'max_length'] | jaxtyping.Shaped[ndarray, 'max_length'] | Sequence[PyTree[L]] | Mapping[str, PyTree[L]][source]

在序列末尾添加零以达到最大长度。

支持一次填充多个数组。

参数:
  • element – 要填充的序列。

  • max_length – 序列的最大长度。

  • truncate – 是否将序列截断为最大长度。如果为 False,则长度超过 max_length 的序列将引发错误。

  • fill_value – 用于填充序列的值。

  • axis – 填充序列的轴(目前仅支持 -1)。

返回:

长度为 max_length 的填充后序列。