gm.evals.SamplerEvaluator#
- class gemma.gm.evals.SamplerEvaluator(**kwargs)[source]
基类:
kauldron.evals.evaluators.EvaluatorBase
采样评估器。
该评估器期望数据集包含
Seq2SeqTask
转换。- max_new_tokens
生成的最大新 token 数量。总共,模型将处理 input_length + max_new_tokens。
- 类型:
int
- num_examples
采样多少个示例。
- 类型:
int | None
- ds
要评估的数据集。请注意,数据集必须是未批处理的,并且包含原始 str 字段。
- 类型:
kauldron.data.pipelines.Pipeline
- model
要使用的模型。
- 类型:
flax.linen.module.Module
- losses
要计算的损失。损失和指标可以通过键 preds.text 访问预测文本。
- 类型:
collections.abc.Mapping[str, kauldron.losses.base.Loss]
- metrics
要计算的指标。损失和指标可以通过键 preds.text 访问预测文本。
- 类型:
collections.abc.Mapping[str, kauldron.metrics.base.Metric]
- summaries
可选的要写入的摘要。
- 类型:
collections.abc.Mapping[str, kauldron.summaries.base.Summary]
- max_new_tokens: int
- num_examples: int | None = 1
- ds: kd.data.Pipeline = _FakeRootCfg('cfg.eval_ds')
- model: nn.Module = _FakeRootCfg('cfg.model')
- losses: Mapping[str, kd.losses.Loss]
- metrics: Mapping[str, kd.metrics.Metric]
- summaries: Mapping[str, kd.summaries.Summary]
- evaluate(
- state: kauldron.train.train_step.TrainState,
- step: int,
运行此评估器,然后写入并可选地返回结果。
- property examples: list[Any]
从数据集中提取 prompt。