Batch

from tempora.utils.batch import Batch

Class for batches generated by a Sampler.

Batch(
    data: BatchData,
    seq_lens: SeqLens,
    feature_names: list[str],
    dtypes: list[Dtype],
    targets: BatchData | None = None,
    target_seq_lens: SeqLens | None = None,
    target_names: list[str] | None = None,
    target_dtypes: list[Dtype] | None = None,
    metadata: BatchMetadata | None = None
)

Parameters

Name             Description
data             Batch data, in either tensor or (packed) matrix form.
seq_lens         Length of each data sequence in the batch.
feature_names    Names of the batch features/columns.
dtypes           Dtypes of the batch data.
targets          Batch targets. Optional; defaults to None.
target_seq_lens  Length of each target sequence. Optional; defaults to None.
target_names     Names of the target features/columns. Optional; defaults to None.
target_dtypes    Dtypes of the targets. Optional; defaults to None.
metadata         Batch metadata. Optional; defaults to None.
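
To make the relationship between data, seq_lens, and the feature lists concrete, here is a minimal sketch. It does not use tempora: BatchSketch is a hypothetical stand-in that only mirrors the documented signature, plain nested lists replace BatchData, SeqLens, and Dtype, and the packed() helper reflects an assumed reading of "packed" (padding removed, valid time steps stacked into one matrix). The real Batch class may validate and store its inputs differently.

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-in mirroring the documented signature. NOT tempora's Batch.
@dataclass
class BatchSketch:
    data: list                    # assumed padded layout: [sequence][time step][feature]
    seq_lens: list                # number of valid time steps per sequence
    feature_names: list
    dtypes: list
    targets: Optional[list] = None
    target_seq_lens: Optional[list] = None
    target_names: Optional[list] = None
    target_dtypes: Optional[list] = None
    metadata: Optional[dict] = None

    def __post_init__(self):
        # Assumed consistency checks; the real class may behave differently.
        if len(self.feature_names) != len(self.dtypes):
            raise ValueError("feature_names and dtypes must have the same length")
        if len(self.seq_lens) != len(self.data):
            raise ValueError("seq_lens needs one entry per sequence in data")

    def packed(self):
        """Drop padding and stack the valid time steps into one 2-D matrix."""
        return [row for seq, n in zip(self.data, self.seq_lens) for row in seq[:n]]


batch = BatchSketch(
    data=[
        [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]],
        [[4.0, 40.0], [0.0, 0.0], [0.0, 0.0]],  # padded after the first step
    ],
    seq_lens=[3, 1],
    feature_names=["price", "volume"],  # illustrative names
    dtypes=["float32", "float32"],
)

packed = batch.packed()
print(len(packed))  # one row per valid time step across all sequences
```

In this sketch the packed matrix has sum(seq_lens) rows, which is why the per-sequence lengths have to travel with the data: without them the padded and packed forms cannot be converted into each other.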