gm.evals.SamplerEvaluator

gm.evals.SamplerEvaluator#

class gemma.gm.evals.SamplerEvaluator(**kwargs)[source]

基类: kauldron.evals.evaluators.EvaluatorBase

采样评估器。

该评估器期望数据集包含 Seq2SeqTask 转换。

max_new_tokens

生成的最大新 token 数量。总共,模型将处理 input_length + max_new_tokens

类型:

int

num_examples

采样多少个示例。

类型:

int | None

ds

要评估的数据集。请注意,数据集必须是未批处理的,并且包含原始 str 字段。

类型:

kauldron.data.pipelines.Pipeline

model

要使用的模型。

类型:

flax.linen.module.Module

losses

要计算的损失。损失和指标可以通过键 preds.text 访问预测文本。

类型:

collections.abc.Mapping[str, kauldron.losses.base.Loss]

metrics

要计算的指标。损失和指标可以通过键 preds.text 访问预测文本。

类型:

collections.abc.Mapping[str, kauldron.metrics.base.Metric]

summaries

可选的要写入的摘要。

类型:

collections.abc.Mapping[str, kauldron.summaries.base.Summary]

max_new_tokens: int
num_examples: int | None = 1
ds: kd.data.Pipeline = _FakeRootCfg('cfg.eval_ds')
model: nn.Module = _FakeRootCfg('cfg.model')
losses: Mapping[str, kd.losses.Loss]
metrics: Mapping[str, kd.metrics.Metric]
summaries: Mapping[str, kd.summaries.Summary]
evaluate(
state: kauldron.train.train_step.TrainState,
step: int,
) Any[source]

运行此评估器,然后写入并可选地返回结果。

property examples: list[Any]

从数据集中提取 prompt。