
Monitoring Training

This guide explains how to monitor and visualize training progress for DMS models.

Basic Usage

The simplest way to visualize training progress is to use the plotting utilities provided in use-examples/visualize_train_curves.py.

Run Configuration

Use RunConfig to specify which training runs to visualize:

from directmultistep.analysis.training import RunConfig

run = RunConfig(
    run_name="flash_10M",      # Folder name of the run
    trace_name="Flash Model",  # Display name for the traces
    include_val=True,          # Whether to include the validation curve
)

Training Curves

The plot_training_curves function creates a figure showing:

  • Training loss curves (solid lines)
  • Validation loss curves (dotted lines with markers)
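  • X-axis shows the number of processed tokens by default (selectable with the x_axis argument; log_x and log_y switch either axis to a log scale)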
  • Hovering over validation points shows epoch information
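Validation loss is usually logged less often than training loss, so the validation curve has fewer points than the training curve. The sketch below shows that split in plain Python. It assumes each logged record is a dict with the columns described under Advanced Usage and that a missing validation value is None. The helper split_train_val and the record values are hypothetical and made up for illustration. They are not part of directmultistep.

```python
from typing import Any

Record = dict[str, Any]


def split_train_val(
    records: list[Record], x_axis: str = "processed_tokens"
) -> tuple[list[tuple[float, float]], list[tuple[float, float, int]]]:
    """Hypothetical helper: separate training points from validation points.

    Training points are (x, train_loss); validation points are
    (x, val_loss, epoch) so the epoch can be shown on hover.
    """
    train_points = [
        (r[x_axis], r["train_loss"]) for r in records if r.get("train_loss") is not None
    ]
    val_points = [
        (r[x_axis], r["val_loss"], r["epoch"])
        for r in records
        if r.get("val_loss") is not None  # validation is only logged on some rows
    ]
    return train_points, val_points


# Made-up records, not real DMS results.
records = [
    {"processed_tokens": 1000, "train_loss": 2.1, "val_loss": None, "epoch": 0},
    {"processed_tokens": 2000, "train_loss": 1.7, "val_loss": 1.8, "epoch": 0},
    {"processed_tokens": 3000, "train_loss": 1.5, "val_loss": None, "epoch": 1},
    {"processed_tokens": 4000, "train_loss": 1.4, "val_loss": 1.6, "epoch": 1},
]
train, val = split_train_val(records)
print(len(train), val)  # 4 [(2000, 1.8, 0), (4000, 1.6, 1)]
```

Keeping the epoch alongside each validation point is what allows a hover label to show which epoch the point came from.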

Learning Rate Curves

The plot_learning_rates function visualizes the learning rate schedule:

  • Shows learning rate vs. training step
  • Useful for verifying learning rate schedules
  • Multiple runs can be compared on the same plot
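One way to verify a schedule is to compare the logged learning rates with the values you expect at each step. The sketch below does this under stated assumptions. It uses a linear-warmup-then-cosine-decay schedule chosen only for illustration, because this guide does not describe the schedule DMS actually uses. It also relies on hypothetical parameters (peak_lr, warmup_steps, total_steps) and synthetic (step, lr) pairs that stand in for the logged train_lr values.

```python
import math


def expected_lr(step: int, peak_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Illustrative schedule only: linear warmup to peak_lr, then cosine decay to 0."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))


def schedule_mismatches(logged, peak_lr, warmup_steps, total_steps, rel_tol=1e-6):
    """Return the steps whose logged learning rate deviates from expected_lr."""
    return [
        step
        for step, lr in logged
        if not math.isclose(
            lr,
            expected_lr(step, peak_lr, warmup_steps, total_steps),
            rel_tol=rel_tol,
            abs_tol=1e-12,  # tolerate float noise near zero
        )
    ]


# Synthetic (step, lr) pairs standing in for logged values.
peak, warmup, total = 1e-3, 10, 100
logged = [(s, expected_lr(s, peak, warmup, total)) for s in range(total + 1)]
logged[50] = (50, 0.0)  # simulate one wrong value
print(schedule_mismatches(logged, peak, warmup, total))  # [50]
```

A numeric check like this complements the plot. It can catch a small deviation at a single step that would be hard to see on a curve.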

Advanced Usage

For more control over visualization, you can load the training data directly:

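from pathlib import Path

from directmultistep.analysis.training import load_training_df

# Placeholder: directory that holds the run folders (adjust to your setup)
train_path = Path("path/to/training_runs")

# Load training data
df = load_training_df(train_path, "flash_10M")

# Skip specific versions of the run by ID
df = load_training_df(train_path, "flash_10M", ignore_ids=[0, 1])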

The returned DataFrame contains columns:

  • processed_tokens: Number of tokens processed
  • train_loss: Training loss
  • val_loss: Validation loss (if available)
  • train_lr: Learning rate
  • epoch: Current epoch
  • Additional metrics depending on the training configuration
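For example, you can find the point with the lowest validation loss by dropping rows that have no validation value and then taking the row at the minimum. The sketch below uses pandas. The toy DataFrame contains made-up values that only mimic the columns listed above and stands in for the output of load_training_df.

```python
import pandas as pd

# Made-up values mimicking the documented columns; in practice use
# df = load_training_df(train_path, "flash_10M").
df = pd.DataFrame(
    {
        "processed_tokens": [1_000, 2_000, 3_000, 4_000],
        "train_loss": [2.1, 1.7, 1.5, 1.4],
        "val_loss": [None, 1.8, None, 1.6],  # validation logged only on some rows
        "train_lr": [1e-4, 2e-4, 2e-4, 1.5e-4],
        "epoch": [0, 0, 1, 1],
    }
)

val_rows = df.dropna(subset=["val_loss"])  # keep rows that have a validation loss
best = val_rows.loc[val_rows["val_loss"].idxmin()]  # row with the lowest validation loss
print(
    f"best val_loss={best['val_loss']:.2f} at epoch {int(best['epoch'])}, "
    f"{int(best['processed_tokens'])} tokens"
)
# best val_loss=1.60 at epoch 1, 4000 tokens
```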

Source Code

directmultistep.analysis.training

RunConfig dataclass

Configuration for a training run visualization.

Source code in src/directmultistep/analysis/training.py
from dataclasses import dataclass


@dataclass
class RunConfig:
    """Configuration for a training run visualization."""

    run_name: str  # Folder name of the run
    trace_name: str  # Display name for the traces
    include_val: bool = True  # Whether to include validation curve
    ignore_ids: list[int] | None = None  # Version IDs to ignore when loading data

plot_training_curves(train_path, runs, x_axis='processed_tokens', log_x=False, log_y=False)

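Plot the training loss of each run in runs (plus its validation loss when include_val is set) and return the figure. Runs with no data are skipped.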

Source code in src/directmultistep/analysis/training.py
def plot_training_curves(
    train_path: Path,
    runs: list[RunConfig],
    x_axis: str = "processed_tokens",
    log_x: bool = False,
    log_y: bool = False,
) -> go.Figure:
    """makes a graph. you get it."""
    traces = []
    for i, run in enumerate(runs):
        data = load_training_data(train_path, run.run_name, run.ignore_ids)
        if not data:
            logger.debug(f"no data found for {run.run_name}, skipping.")
            continue

        color_idx = i % len(style.colors_light)
        traces.append(create_train_trace(data, run.trace_name, style.colors_light[color_idx], x_axis))
        if run.include_val:
            traces.append(create_val_trace(data, run.trace_name, style.colors_dark[color_idx], x_axis))

    fig = go.Figure(data=traces)
    fig.update_layout(
        title="Training Loss",
        xaxis_title=x_axis,
        yaxis_title="Loss",
        xaxis_type="log" if log_x else "linear",
        yaxis_type="log" if log_y else "linear",
    )
    style.apply_development_style(fig)
    return fig
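In plot_training_curves, each run's color index comes from its position in runs, taken modulo the length of the light palette. As a result, a run that is skipped because it has no data still uses up an index. A run's training and validation traces share that index in the light and dark palettes respectively. The sketch below mirrors only this bookkeeping. The helper assign_colors and the palettes are hypothetical stand-ins for the loop and for style.colors_light and style.colors_dark. A set of names stands in for the runs that have data.

```python
def assign_colors(run_names, runs_with_data, colors_light, colors_dark):
    """Hypothetical mirror of the color bookkeeping in plot_training_curves."""
    assignments = {}
    for i, name in enumerate(run_names):
        if name not in runs_with_data:
            continue  # skipped, but index i is still consumed
        color_idx = i % len(colors_light)
        # (train color, validation color) share the same index
        assignments[name] = (colors_light[color_idx], colors_dark[color_idx])
    return assignments


light = ["#a6cee3", "#b2df8a"]  # hypothetical palettes
dark = ["#1f78b4", "#33a02c"]
print(assign_colors(["a", "b", "c"], {"a", "c"}, light, dark))
# {'a': ('#a6cee3', '#1f78b4'), 'c': ('#a6cee3', '#1f78b4')}
```

In this example run "b" has no data, so runs "a" and "c" end up with the same colors. Because the index is computed from the light palette's length, the dark palette must be at least as long as the light one.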

plot_learning_rates(train_path, runs)

Plot the learning rate of each run in runs against training step and return the figure. Runs with no data are skipped.

Source code in src/directmultistep/analysis/training.py
def plot_learning_rates(train_path: Path, runs: list[RunConfig]) -> go.Figure:
    """makes another graph. also obvious."""
    traces = []
    for run in runs:
        data = load_training_data(train_path, run.run_name, run.ignore_ids)
        if data:
            traces.append(get_lr_trace(data, run.trace_name))

    fig = go.Figure(data=traces)
    fig.update_layout(title="Learning Rate", xaxis_title="Step", yaxis_title="Learning Rate", width=800)
    style.apply_development_style(fig)
    return fig