Building RL GRPO

06 Sep 2025 - cohlem

This post is a continuation of my previous blog post, where I experimented with the basics of distributed training. Here I explain how to apply those ideas to RL training.

First, let’s familiarize ourselves with the GRPO algorithm.
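For reference, GRPO (introduced in the DeepSeekMath paper) works as follows. For a prompt $q$ the policy samples a group of $G$ responses $o_1,\dots,o_G$ (here $G = 8$) with scalar rewards $r_1,\dots,r_G$. Each response gets a group-relative advantage, shared by all of its tokens, and the policy is trained with a PPO-style clipped ratio against the policy that generated the rollouts. The original objective also subtracts a KL penalty $\beta\,D_{\mathrm{KL}}$ toward a reference model; the pseudocode below has no reference model, so this sketch assumes $\beta = 0$.

```latex
\hat{A}_i = \frac{r_i - \operatorname{mean}(r_1,\dots,r_G)}{\operatorname{std}(r_1,\dots,r_G)}

\rho_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t}\mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t}\mid q, o_{i,<t})}
  = \exp\!\big(\log\pi_\theta(o_{i,t}\mid q, o_{i,<t}) - \log\pi_{\theta_{\text{old}}}(o_{i,t}\mid q, o_{i,<t})\big)

\mathcal{L}(\theta) = -\frac{1}{G}\sum_{i=1}^{G}\frac{1}{|o_i|}\sum_{t=1}^{|o_i|}
  \min\Big(\rho_{i,t}(\theta)\,\hat{A}_i,\;
  \operatorname{clip}\big(\rho_{i,t}(\theta),\,1-\epsilon,\,1+\epsilon\big)\,\hat{A}_i\Big)
```

In the pseudocode's terms, `old_logprobs` are the $\log\pi_{\theta_{\text{old}}}$ values, `logprobs` are the $\log\pi_\theta$ values, and `grpo_advantage` produces $\hat{A}_i$.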

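import torch

model = initialize_model()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
train_dataloader = initialize_train_dataloader(batch_size=per_rollout_size)

# Rollout step
for data in train_dataloader:  # each data batch holds per_rollout_size prompts
    # generate 8 responses per prompt, i.e. per_rollout_size * 8 samples in total
    rollout_data = model.generate(data, responses_per_prompt=8)
    with torch.no_grad():
        # log-probs under the policy that generated the rollouts (pi_old); no gradients needed
        old_logprobs = calc_logprobs(rollout_data)
        entropy = calc_entropy(rollout_data)  # tracked for monitoring only
    advantage = grpo_advantage(rollout_data)  # rewards normalized within each group of 8

    # Update step: split the rollout into update_per_rollout mini-batches,
    # keeping each sample's old_logprobs and advantage attached to it
    update_dataloader = initialize_update_dataloader(
        rollout_data, old_logprobs, advantage,
        batch_size=len(rollout_data) // update_per_rollout,
    )
    for update_data, batch_old_logprobs, batch_advantage in update_dataloader:
        logprobs = calc_logprobs(update_data)  # with gradients, under the current policy
        loss = grpo_loss(logprobs, batch_old_logprobs, batch_advantage)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()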
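The two GRPO-specific helpers in the pseudocode, `grpo_advantage` and `grpo_loss`, can be sketched in plain Python. This is a minimal, dependency-free sketch under several assumptions: the signatures are hypothetical and differ from the pseudocode's, rewards arrive as a flat list in which each prompt's `group_size` responses are contiguous, log-probs are per-token lists with padding already removed, the clip range is 0.2, and there is no KL term (the pseudocode has no reference model). It uses the population standard deviation; some implementations use the sample version. A real implementation would work on padded PyTorch tensors with a loss mask.

```python
import math
from statistics import mean, pstdev


def grpo_advantage(rewards, group_size, eps=1e-6):
    """Group-relative advantage: normalize each reward within its prompt's group."""
    assert len(rewards) % group_size == 0, "rewards must hold whole groups"
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mu, sigma = mean(group), pstdev(group)
        # eps avoids division by zero when every response in the group has the same reward
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages


def grpo_loss(logprobs, old_logprobs, advantages, clip_eps=0.2):
    """Negated clipped surrogate: mean over tokens per response, then mean over responses.

    logprobs / old_logprobs: one list of per-token log-probs per response.
    advantages: one scalar per response, shared by all of its tokens.
    """
    per_response = []
    for new_lp, old_lp, adv in zip(logprobs, old_logprobs, advantages):
        token_terms = []
        for new, old in zip(new_lp, old_lp):
            ratio = math.exp(new - old)  # pi_theta / pi_old for this token
            clipped = min(max(ratio, 1 - clip_eps), 1 + clip_eps)
            # pessimistic (smaller) of the unclipped and clipped objectives
            token_terms.append(min(ratio * adv, clipped * adv))
        per_response.append(sum(token_terms) / len(token_terms))
    # negate so that minimizing the loss maximizes the surrogate objective
    return -sum(per_response) / len(per_response)
```

Because the advantage is normalized within each group, responses that beat their siblings are pushed up and the rest are pushed down; a group in which every response earns the same reward gets zero advantage and therefore contributes no gradient.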

The generic algorithm looks something like the above. In our case, we generate rollouts with SGLang, so the model that generates rollouts (inside the SGLang engine) and the model on which we call optimizer.step() are two separate copies. We therefore need to push the freshly updated weights into SGLang’s model after every update, so that the next rollout is sampled from the current policy.
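The ordering constraint above can be illustrated with a toy, dependency-free sketch. `ToyTrainer`, `ToyInferenceEngine`, `sync_weights` and `train` are hypothetical stand-ins that use dicts of floats instead of tensors: the trainer plays the FSDP-wrapped policy and the engine plays SGLang’s separate copy of the weights. A version counter makes it easy to check that each rollout is generated from the weights produced by the latest optimizer step.

```python
class ToyTrainer:
    """Hypothetical stand-in for the policy we call optimizer.step() on."""

    def __init__(self, weights):
        self.weights = dict(weights)
        self.version = 0

    def optimizer_step(self, grads, lr=0.1):
        # plain SGD on a dict of scalar "parameters"
        for name, grad in grads.items():
            self.weights[name] -= lr * grad
        self.version += 1


class ToyInferenceEngine:
    """Hypothetical stand-in for the SGLang engine, which keeps its own weight copy."""

    def __init__(self, weights):
        self.weights = dict(weights)
        self.version = 0

    def update_weights(self, named_weights, version):
        self.weights.update(named_weights)
        self.version = version


def sync_weights(trainer, engine):
    # push every updated parameter from the trainer into the inference engine
    engine.update_weights(trainer.weights.items(), trainer.version)


def train(num_steps, sync=True):
    trainer = ToyTrainer({"w": 1.0})
    engine = ToyInferenceEngine({"w": 1.0})
    stale_rollouts = 0
    for _ in range(num_steps):
        # rollout: generation reads engine.weights, which should match the trainer's
        if engine.version != trainer.version:
            stale_rollouts += 1
        trainer.optimizer_step({"w": 1.0})
        if sync:
            sync_weights(trainer, engine)
    return trainer, engine, stale_rollouts
```

Running `train(3, sync=False)` produces two stale rollouts: every rollout after the first samples from a policy that no longer matches the one being optimized. In the real pipeline `sync_weights` would be replaced by SGLang’s weight-update API (for example the `update_weights_from_tensor` method on its offline `Engine`, or `update_weights_from_disk`); the exact signatures depend on the SGLang version, so check its documentation.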

First, let’s initialize multiple processes and assign each process a GPU. The code below is standard boilerplate that we use in every distributed run, so it should be mostly self-explanatory.


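import os

import torch
import torch.distributed as dist


def setup():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    backend = 'nccl' if device == 'cuda' else 'gloo'
    # the launcher (e.g. torchrun) sets these variables for every process
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node
    if device == 'cuda':
        torch.cuda.set_device(local_rank)
    # disable NCCL's cuMem allocator and NVLink SHARP (NVLS); set before init so NCCL picks them up
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    # Initialize with explicit parameters
    dist.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
    )
    return rank, world_size, local_rank, device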
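setup() only reads RANK, WORLD_SIZE and LOCAL_RANK; it assumes a launcher such as `torchrun` has already set them. The sketch below uses a hypothetical helper, `torchrun_env`, to show how the three values relate, assuming every node runs the same number of processes and ranks are numbered node by node. It also shows why `torch.cuda.set_device` receives LOCAL_RANK rather than RANK: the GPU index restarts at 0 on every node.

```python
def torchrun_env(nnodes, nproc_per_node):
    """Yield the env vars each worker would see (assumes uniform nodes, ranks numbered node by node)."""
    world_size = nnodes * nproc_per_node
    for node_rank in range(nnodes):
        for local_rank in range(nproc_per_node):
            yield {
                "RANK": str(node_rank * nproc_per_node + local_rank),  # global, unique across nodes
                "WORLD_SIZE": str(world_size),                          # same for every process
                "LOCAL_RANK": str(local_rank),                          # restarts at 0 on each node
            }


envs = list(torchrun_env(nnodes=2, nproc_per_node=4))
```

For example, `torchrun --nnodes=2 --nproc_per_node=4 train.py` launches 8 processes; the one with RANK 5 runs on the second node with LOCAL_RANK 1, so it binds to `cuda:1` on that machine.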

Next, we define a parent class, Worker, that the actor (and the critic, if PPO is used) inherits from.

import os

import torch
from torch.distributed.device_mesh import init_device_mesh


class Worker:
    """
    The policy that is updated at every gradient step. We roll out with this policy's
    parameters and use its log-probs; we also copy its weights to serve as the old policy.
    """

    def __init__(self, config):
        # self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        self.config = config
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        world_size = int(os.environ['WORLD_SIZE'])
        # first make a device mesh: FSDP shards across whatever remains after DDP and TP
        fsdp_size = world_size // (config.ddp_size * config.tp_size)
        # this mesh will only be used for model partition
        self.mesh = init_device_mesh(device, (config.ddp_size, fsdp_size, config.tp_size),
                                     mesh_dim_names=["DDP", "FSDP", "TP"])
        self.dp_size = world_size // config.tp_size
        # this mesh will be used for data parallelism
        self.device_mesh = init_device_mesh(device, (self.dp_size, config.tp_size),
                                            mesh_dim_names=["DP", "TP"])

    def prepare_optimizer(self):
        # self.model must be assigned before this call (presumably by the actor/critic subclass)
        self.model.gradient_checkpointing_enable()
        if self.config.tp_size > 1:
            self.model = prepare_tp_model(self.model, self.mesh)

        self.model = prepare_dp_model(self.model, self.mesh)
        # create the optimizer after wrapping, so it sees the wrapped (sharded) parameters
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.config.lr)

        # offload the model to cpu
        load_model_to_device(self, "cpu")
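To see how the two meshes built in `__init__` relate, the following dependency-free sketch recomputes each rank’s coordinates in both. It assumes `init_device_mesh` assigns ranks in row-major order (the mesh is `arange(world_size)` reshaped to the mesh shape), and `mesh_layout` is a hypothetical helper, not part of PyTorch. Under that layout the two meshes agree on the TP coordinate, and the DP coordinate of the 2D mesh is simply the DDP and FSDP coordinates of the 3D mesh flattened into one: DP = DDP * fsdp_size + FSDP.

```python
def mesh_layout(world_size, ddp_size, tp_size):
    """Rank -> coordinates in the (DDP, FSDP, TP) mesh and the (DP, TP) mesh, row-major."""
    assert world_size % (ddp_size * tp_size) == 0
    fsdp_size = world_size // (ddp_size * tp_size)
    dp_size = world_size // tp_size
    layout = {}
    for rank in range(world_size):
        # 3D mesh used for model partition: shape (ddp_size, fsdp_size, tp_size)
        ddp, rest = divmod(rank, fsdp_size * tp_size)
        fsdp, tp = divmod(rest, tp_size)
        # 2D mesh used for data parallelism: shape (dp_size, tp_size)
        dp, tp_2d = divmod(rank, tp_size)
        assert tp == tp_2d, "both meshes should agree on the TP coordinate"
        assert dp == ddp * fsdp_size + fsdp, "DP is DDP and FSDP flattened"
        layout[rank] = {"DDP": ddp, "FSDP": fsdp, "TP": tp, "DP": dp}
    return fsdp_size, dp_size, layout


fsdp_size, dp_size, layout = mesh_layout(world_size=8, ddp_size=2, tp_size=2)
```

With 8 GPUs, `ddp_size=2` and `tp_size=2`, rank 5 sits at (DDP=1, FSDP=0, TP=1) in the model-partition mesh and at (DP=2, TP=1) in the data-parallel mesh. Ranks 4 and 5 share a TP group, so they must receive the same input, while the 4 positions along DP each receive a different slice of the batch.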

Notice that we initialize two different device meshes, self.mesh and self.device_mesh. They serve different purposes. self.mesh will be used for model