Batched Route Generation
The `BatchedBeamSearch` class generates synthesis routes for several target molecules in a single batched beam search. It supports variable batch sizes, a different path-start length per target, and a different maximum output length per target.
Features
- Variable Batch Sizes: Process any number of targets in a single batch
- Variable Path Start Lengths: Each target's initial decoder path (e.g. one that includes a starting material) can have a different length
- Variable Target Lengths: Different maximum output lengths per target
- Early Termination: Each batch item can finish independently
- GPU Efficient: Processing many targets per batch improves GPU utilization
Basic Usage
Single Target (Compatible with BeamSearchOptimized)
from pathlib import Path

from directmultistep import generate_routes

# Vocabulary/config file and the directory holding the model checkpoints.
CONFIG_PATH = Path("data/configs/dms_dictionary.yaml")
CKPT_DIR = Path("data/checkpoints")

target = "CNCc1ccccc1"
starting_material = "CN"
n_steps = 1

# Run beam search for one target, passing its starting material and step count.
routes = generate_routes(
    target=target,
    n_steps=n_steps,
    starting_material=starting_material,
    beam_size=5,
    model="flash",
    config_path=CONFIG_PATH,
    ckpt_dir=CKPT_DIR,
)

for route in routes:
    print(route)
Multiple Targets (Batched)
from pathlib import Path

from directmultistep import generate_routes_batched

# One row per target: (target SMILES, number of steps, starting material or None).
jobs = [
    ("CNCc1ccccc1", 1, "CN"),
    ("CCOc1ccccc1", 2, None),
    ("c1ccccc1", 1, None),
]
targets = [smiles for smiles, _, _ in jobs]
n_steps_list = [steps for _, steps, _ in jobs]
starting_materials = [material for _, _, material in jobs]

routes = generate_routes_batched(
    targets=targets,
    n_steps_list=n_steps_list,
    starting_materials=starting_materials,
    beam_size=5,
    model="flash",
    config_path=Path("data/configs/dms_dictionary.yaml"),
    ckpt_dir=Path("data/checkpoints"),
)

# routes[i] holds the routes found for targets[i]; show at most the first three.
for number, (target, target_routes) in enumerate(zip(targets, routes), start=1):
    print(f"Target {number}: {target}")
    print(f"Routes: {len(target_routes)}")
    for route in target_routes[:3]:
        print(f" {route}")
Advanced Usage
Using the Low-Level API
For more control, you can use the lower-level APIs:
from pathlib import Path

import torch

from directmultistep import (
    create_batched_beam_search,
    load_published_model,
    prepare_batched_input_tensors,
)
from directmultistep.utils.dataset import RoutesProcessing

# Load the published "flash" model and the routes/vocabulary helper.
model = load_published_model("flash", Path("data/checkpoints"))
rds = RoutesProcessing(metadata_path=Path("data/configs/dms_dictionary.yaml"))

# Beam search object able to decode several targets per batch.
beam_search = create_batched_beam_search(model, beam_size=5, rds=rds)

targets = ["CNCc1ccccc1", "CCOc1ccccc1"]
n_steps_list = [1, 2]
starting_materials = ["CN", None]

# Convert the inputs into batched tensors plus per-target path starts and max lengths.
encoder_batch, steps_batch, path_starts, target_lengths = prepare_batched_input_tensors(
    targets=targets,
    n_steps_list=n_steps_list,
    starting_materials=starting_materials,
    rds=rds,
    product_max_length=rds.product_max_length,
    sm_max_length=rds.sm_max_length,
)

# Put every input on the same device as the model before decoding.
device = next(model.parameters()).device
src = encoder_batch.to(device)
steps = None if steps_batch is None else steps_batch.to(device)
starts = [start.to(device) for start in path_starts]

results = beam_search.decode(
    src_BC=src,
    steps_B1=steps,
    path_starts=starts,
    target_lengths=target_lengths,
    progress_bar=True,
)

# results[batch_idx][beam_idx] is a (sequence, log_prob) pair.
for batch_idx, beam_results in enumerate(results):
    print(f"\nTarget {batch_idx}: {targets[batch_idx]}")
    for beam_idx, (sequence, log_prob) in enumerate(beam_results):
        print(f" Beam {beam_idx}: score={log_prob:.2f}, seq={sequence[:50]}...")
Custom Batch Processing
You can also directly use BatchedBeamSearch for custom processing:
from directmultistep.generation.tensor_gen import BatchedBeamSearch

# Create a beam search with explicit parameters instead of using the factory.
# NOTE: start_idx / pad_idx / end_idx must be the indices of the start, padding
# and end tokens in *your* vocabulary (rds.idx_to_token). The numbers below are
# illustrative; confirm them against your dictionary before use.
beam_search = BatchedBeamSearch(
    model=model,
    beam_size=10,
    start_idx=0,
    pad_idx=52,
    end_idx=22,
    max_length=1074,
    idx_to_token=rds.idx_to_token,
    device=device,
)

# Decode with a custom maximum output length for each batch item.
# The batch prepared above has two targets, so exactly two lengths are given,
# and both are kept <= max_length (1074).
# The inputs are moved to the model's device, as in the previous example.
results = beam_search.decode(
    src_BC=encoder_batch.to(device),
    steps_B1=steps_batch.to(device) if steps_batch is not None else None,
    path_starts=[ps.to(device) for ps in path_starts],
    target_lengths=[500, 1000],  # one max length per target
    progress_bar=True,
)
API Reference
High-Level Functions
generate_routes_batched
def generate_routes_batched(
targets: Sequence[str],
n_steps_list: Sequence[int] | None,
starting_materials: Sequence[str | None],
beam_size: int,
model: ModelName | torch.nn.Module,
config_path: Path,
ckpt_dir: Path | None = None,
commercial_stock: set[str] | None = None,
use_fp16: bool = False,
) -> list[list[str]]:
Generate synthesis routes for multiple targets using batched beam search.
Arguments:
- targets: List of SMILES strings of target molecules
- n_steps_list: Number of synthesis steps for each target, or None when using an explorer model (no steps tensor is then passed to the model)
- starting_materials: List of starting materials for each target (can contain None)
- beam_size: Beam size for the beam search
- model: Either a model name or a torch.nn.Module
- config_path: Path to the model configuration file
- ckpt_dir: Directory containing model checkpoints (required if model is a string)
- commercial_stock: Set of commercially available starting materials (SMILES)
- use_fp16: Whether to use half precision (FP16)
Returns: a list of lists, where each inner list contains the valid routes for the corresponding target (in the same order as `targets`)
Utility Functions
prepare_batched_input_tensors
def prepare_batched_input_tensors(
targets: Sequence[str],
n_steps_list: Sequence[int] | None,
starting_materials: Sequence[str | None],
rds: RoutesProcessing,
product_max_length: int,
sm_max_length: int,
use_fp16: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor], list[int]]:
Prepare batched input tensors for the model.
Returns:
- encoder_batch: Batched input tensor for the encoder [B, C]
- steps_batch: Batched tensor of steps [B, 1], or None if all n_steps are None
- path_starts: List of initial path tensors for decoder (variable lengths)
- target_lengths: List of target max lengths per batch item
create_batched_beam_search
def create_batched_beam_search(
model: torch.nn.Module,
beam_size: int,
rds: RoutesProcessing
) -> BatchedBeamSearch:
Create a batched beam search object that supports variable batch sizes and lengths.
BatchedBeamSearch Class
class BatchedBeamSearch:
    """Beam search that decodes several targets in one batch.

    Supports any positive batch size, a decoder path start of a different
    length for each batch item, and a different maximum output length for
    each batch item.
    """
    def __init__(
        self,
        model: nn.Module,
        beam_size: int,
        start_idx: int,
        pad_idx: int,
        end_idx: int,
        max_length: int,
        idx_to_token: dict[int, str],
        device: torch.device,
    ):
        """Configure the batched beam search.

        Args:
            model: The sequence model to decode with.
            beam_size: Number of beams returned per batch item.
            start_idx: Vocabulary index of the start token.
            pad_idx: Vocabulary index of the padding token.
            end_idx: Vocabulary index of the end token.
            max_length: Maximum decoded sequence length. NOTE(review):
                presumably the global cap when per-item target_lengths are
                not given to decode(); confirm.
            idx_to_token: Mapping from token index to token string; the
                indices above must agree with it.
            device: Device on which decoding runs.
        """
        ...
    def decode(
        self,
        src_BC: Tensor,
        steps_B1: Tensor | None,
        path_starts: list[Tensor | None] | None = None,
        target_lengths: list[int] | None = None,
        progress_bar: bool = True,
        token_processor: Callable[[list[str]], str] | None = None,
    ) -> list[list[tuple[str, float]]]:
        """Decode a batch of encoded targets.

        Args:
            src_BC: Batched encoder input of shape [B, C].
            steps_B1: Batched step counts of shape [B, 1], or None (e.g. for
                explorer models).
            path_starts: Initial decoder path tensor for each batch item;
                lengths may differ between items. NOTE(review): behaviour for
                None / None entries is not shown here; confirm.
            target_lengths: Maximum output length for each batch item (one
                entry per item). NOTE(review): None presumably falls back to
                max_length; confirm.
            progress_bar: Whether to display a progress bar.
            token_processor: Optional callable turning a list of decoded
                tokens into a string. NOTE(review): presumably overrides the
                default token joining; confirm.

        Returns:
            One list per batch item, each holding one (sequence, log_prob)
            tuple per beam, i.e. results[batch_idx][beam_idx] ==
            (sequence, log_prob).
        """
        ...
Performance Considerations
- Batch Size: Larger batches improve GPU utilization but increase memory usage
- Variable Lengths: The implementation handles variable lengths efficiently by grouping active beams
- Early Termination: Batch items that finish early are removed from computation
- Memory Usage: Peak memory scales with $\text{batch\_size} \times \text{beam\_size} \times \text{max\_sequence\_length}$
Comparison with BeamSearchOptimized
| Feature | BeamSearchOptimized | BatchedBeamSearch |
|---|---|---|
| Batch Size | Only 1 | Any positive integer |
| Variable Path Starts | No | Yes |
| Variable Target Lengths | No | Yes |
| Early Termination | All beams together | Per batch item |
| API Compatibility | Single target | Multiple targets |
For single-target generation, both implementations produce identical results. BeamSearchOptimized supports only a batch size of 1, so use BatchedBeamSearch to process multiple targets efficiently in one batch.