This weekend, I decided to perform a short test on the Group Relative Policy Optimization (GRPO) technique, introduced in DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models.

Since the release of the GRPO trainer in trl and the fantastic Unsloth.AI project, many tutorials have become available. But I decided to do something a bit different: train a model to reason in another language—in this case, my native language, Bulgarian.

Given the limited time, training data, and resources, this is more of a proof of concept, but I wanted to validate the approach. I hope it can be helpful for someone else looking to create a non-English reasoning model.

This article will describe the training process of the LLMBG-Llama-3.1-8B-BG-Reasoning-v0.1 – as far as I’m aware, the first publicly available reasoning model trained to “think” in Bulgarian. I’ll cover the data, training approach, and results.

The key highlight: the model starts “thinking” in Bulgarian after just 5 hours of training on a single L40S GPU. While I used Bulgarian, this approach can be adapted for training reasoning models in any language.

Introduction

In the opening sentences, I placed the word “think” in quotes. The same could be done for “reasoning” as there are fundamental differences in how humans and models think, particularly regarding whether models can truly think. However, that is a separate topic. In this article, I will use the terms reasoning and thinking without quotation marks from this point forward.

The release of DeepSeek-R1 has generated considerable excitement and progress in the field of model reasoning. If you’re new to the concept of reasoning in models, I recommend reading this article by Sebastian Raschka. It addresses whether reasoning is a solution for all problems.

💡
Spoiler alert: it’s not the best approach for resource-constrained or simple tasks, where standard models may perform better.

Below, you’ll find a quick comparison highlighting the strengths of reasoning models.

The key strengths and limitations of reasoning models. Source: https://sebastianraschka.com/blog/2025/understanding-reasoning-llms.html

DeepSeek-R1 is the primary driver of this open discussion about reasoning models and how to train them. However, the model mainly thinks in English and Chinese. Therefore, the purpose of this tutorial is to demonstrate how to express reasoning coherently in another language while consistently following the chain-of-thought (CoT) syntax.

If you are interested in learning more about these concepts, I can recommend this reading list, which I found in Andriy Burkov’s Artificial Intelligence #261 issue (which I also highly recommend).

Now, let’s start with the actual process.

The base model

Choosing the base model is a very important step: if the base model has not been pre-trained on a given language, it will be impossible to generate any meaningful text in it. I needed a model that I could train for 4-5 hours on ~2K samples (more details about the dataset will follow later in the article) on a 48GB VRAM GPU. Larger models would require more powerful instances.

This led me to choose Llama 3.1 8B. It is multilingual, and I found through a small experiment that it understands Bulgarian out of the box. While there are notable differences when comparing it to other models of similar size, these variations are highly specific to my language and may not apply in your case.

For instance, Qwen 2.5 could perform better for other European languages. Another option I considered is Gemma 2 9B; in fact, the current state-of-the-art Bulgarian models are based on Gemma.

However, I ultimately decided to go with Llama.
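For what it’s worth, such a sanity check can be as simple as sending the base model a Bulgarian prompt and seeing whether the reply stays in Bulgarian. A minimal sketch with plain transformers (the prompt is just an illustration, not the exact experiment I ran):

# Quick check: does the base model answer in Bulgarian?
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="meta-llama/meta-Llama-3.1-8B-Instruct",
    torch_dtype="auto",
    device_map="auto",
)
messages = [
    # "Explain in two sentences why the sky is blue."
    {"role": "user", "content": "Обясни с две изречения защо небето е синьо."}
]
print(pipe(messages, max_new_tokens=128)[0]["generated_text"][-1]["content"])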

Training

The training involved following the tutorial and blog article from Unsloth.AI, adjusting some parameters, and using an appropriate dataset. I trained the model using Lightning.AI, but for larger tasks, I prefer using GPU instances from OVH, mainly due to pricing.

For the dataset, I used a machine-translated Bulgarian version of GSM8K from here. The INSAIT Institute has only released the test split publicly, which I used to quickly validate my approach.

There is a full version here, petkopetkov/gsm8k-bg, but the answers and their original separators (the #### markers) were removed, which would have made the processing on my end messier and error-prone.

The goal was to identify a dataset that contains fixed, well-known answers, allowing the language model to handle the reasoning aspect by providing completions. Many other datasets could be modified for this purpose, and it is also relatively straightforward to create a synthetic dataset using a larger model.
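To make that concrete, here is a made-up record in the GSM8K style that the processing code later in the article expects: free-form solution text, followed by the #### separator and the verifiable final answer.

# A made-up sample (not from the real dataset) in the GSM8K answer format.
sample = {
    # "Ivan has 3 apples and buys 4 more. How many apples does he have?"
    "question": "Иван има 3 ябълки и купува още 4. Колко ябълки има той?",
    "answer": "Иван започва с 3 ябълки и добавя още 4, тоест 3 + 4 = 7.\n#### 7",
}
# Any dataset with verifiable final answers can be adapted by appending
# "#### <gold answer>" to each record before reusing the same pipeline.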

The notebooks are on the links here: Stage 1 and Stage 2.

They are two separate notebooks because I interrupted the process at step 200 and then resumed from the checkpoints. I wanted to continue in a fresh notebook while keeping my previous logging data.

Here you can find the full training log, including Training loss, Reward, Reward std, Completion length, and Kullback-Leibler divergence.

Alternatively, this repository contains all the notebooks and training logs in one place: https://github.com/s-emanuilov/LLMBG-Llama-3.1-8B-BG-Reasoning-v0.1

Now, about the process in more detail. The first step is installing the core libraries:


!pip install unsloth vllm
!pip install --upgrade pillow
!pip install diffusers
# Temporarily install a specific TRL nightly version
!pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b

It is essential to note that TRL is pinned to a specific commit, as GRPO support is still very new. Then:

from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

Here, the GRPO patch is applied to the FastLanguageModel class to make Unsloth work with GRPO. Next, we need to pick the base model, which is meta-llama/meta-Llama-3.1-8B-Instruct in this case. I chose to set max_seq_length to 1024 to provide more room for thinking and lora_rank to 64 to prioritize output quality over generation speed.


from unsloth import is_bfloat16_supported
import torch

max_seq_length = 1024
lora_rank = 64

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True,
    fast_inference = True,
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.65,
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)

I set the GPU memory utilization to 0.65 to balance resource usage: a higher ratio led to out-of-memory (OOM) errors, while a lower one would not fully use the available VRAM.

For gradient checkpointing, I chose the “unsloth” option, which is a new algorithm that allows for fine-tuning language models with long context windows. You can find more information here.

An important aspect of this experiment is the reward functions. I focused primarily on functions that verify and reward the structure of the output.

However, I believe there is significant room for optimization in this area. I will discuss potential additional reward functions in the future improvements section. For now, let’s take a look at the current process:

import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def get_gsm8k_questions(split = "test") -> Dataset:
    data = load_dataset('INSAIT-Institute/GSM8k-bgeval', 'default')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

In short, these closely follow the Unsloth notebook and mainly reward good output syntax. That is not a bad thing: it can be seen as aligning the model to respond in this specific format, and later on we can reward other aspects of the generation. Then, the GRPO configuration:

from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    use_vllm = True,
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1,
    num_generations = 6,
    max_prompt_length = 256,
    max_completion_length = 512,
    num_train_epochs = 1,
    # max_steps = 250,
    save_steps = 100,
    max_grad_norm = 0.1,
    report_to = "none",
    output_dir = "outputs",
)

I ran a complete training epoch with num_train_epochs set to 1 and increased the completion length to a maximum of 512 tokens by setting max_completion_length.

The learning rate was set to 5e-6, and I configured save_steps to 100 for checkpointing.

The other parameters were very similar to the standard Unsloth configurations.

Next came the trainer and the training process itself. A key point to note is the inclusion of the reward functions, where you can add your own:


trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()
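Since I stopped the run at step 200 and continued in a second notebook, it is worth noting that GRPOTrainer inherits the standard checkpointing behaviour of the transformers Trainer. A minimal sketch of resuming and then persisting the LoRA adapter (the directory name is illustrative):

# Resume from the latest checkpoint in output_dir (written every save_steps = 100 steps).
trainer.train(resume_from_checkpoint = True)

# Save the trained LoRA adapter and tokenizer; the path is illustrative.
model.save_pretrained("outputs/bg-reasoning-lora")
tokenizer.save_pretrained("outputs/bg-reasoning-lora")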

Results

In the first few steps, we can see obviously wrong answers. Example:

</отговор>
Extracted:
<причината>
На първата улица има 5 деца.
На втората улица към тези 5 деца се присъединят още 5 деца (1 на повишение), така че има 5 + 5 = 10 деца.
На третата улица към тези 10 деца се присъединяват още 10 деца (2 на повишение), така че сега има 10 + 10 = 20 деца.
На втората улица първите 5 деца се отказват, така че сега има 20 - 5 = 15 деца, които следват камиона.
</причина>
<отговор>
15
</отговор>

It is clear that the CoT prompt syntax is not being followed: while the system prompt requires the English tags, this output invents its own Bulgarian tags, which are not even closed consistently (<причината> opens, but </причина> closes).

However, this issue improves with training.

Below, you will find a raw chart of the reward at each training step.

I did not group the data by step, allowing you to see how it fluctuates at each stage. Nonetheless, the trend line shows a clear upward movement: rewards are increasing, and the model is learning.

GRPO training rewards per step
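If you want to recreate a chart like this from your own run, the per-step metrics live in the trainer’s log history. A minimal sketch, assuming matplotlib is available and that the reward is logged under the "reward" key (the exact key may differ between TRL versions):

# Plot the raw reward per training step from the trainer's log history.
import matplotlib.pyplot as plt

history = trainer.state.log_history
steps = [entry["step"] for entry in history if "reward" in entry]
rewards = [entry["reward"] for entry in history if "reward" in entry]

plt.plot(steps, rewards, alpha=0.5)
plt.xlabel("Step")
plt.ylabel("Reward")
plt.title("GRPO reward per training step")
plt.show()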

A few conclusions can be drawn from this:

  • Around step 190, there’s a clear breakthrough – the model’s reward scores jump significantly and become more consistent;
  • The reward functions are far from optimal. Even though the trend is positive, there is not much headroom left for learning: the maximum reward is capped at 3 by design, and the model improves with each step until it approaches that ceiling.

Results in the final steps look better:


Question: Лори печели $10 на час. Карън печели два пъти повече от Лори. Колко печели Карън за два дни, ако работи по 3 часа на ден? 

Answer:
120 

Response:
<reasoning>
Ако Лори печели 10 долара на час, а Карън печели два пъти повече, Карън печели 10 * 2 = 20 долара на час. Ако Карън работи 3 часа на ден, в два дни тя печели 3 * 20 * 2 = 120 долара.
</reasoning>
<answer>
120
</answer> 

Extracted (from the model's response):
120

=====

Question: Тоби чете книга с дължина 45 страници. Средно 200 думи на страница. Тоби може да чете със скорост 300 думи в минута. Той трябва да бъде на летището след 60 минути и планира да тръгне веднага щом завърши книгата. Отнема 10 минути, за да стигнете до летището. Колко минути ще подрани Тоби? 

Answer:
20 

Response:
<reasoning>
Дължината на книгата е 45 страници * 200 думи на страница = 9000 думи. Тоби може да чете 9000 думи / 300 думи в минута = 30 минути. Трябва да се добавят 10 минути за пътуването до летището, така че в общият период ще бъде 30 + 10 = 40 минути.
</reasoning>
<answer>
40
</answer> 

Extracted (from the model's response):
40

=====

Question: Ян има три пъти повече домашни любимци от Марсия. Марсия има с два повече домашни любимци от Синди. Ако Синди има четири домашни любимеца, колко общо домашни любимци имат тримата? 

Answer:
28 

Response:
<reasoning>
Ако Синди има четири домашни любимци, а Марсия има два повече домакински любимци от Синди, то Марсия има 4 + 2 = 6 домашни любимци. Ако Ян има три пъти повече домашни любимци от Марсия, то Ян има 3 * 6 = 18 домашни любимци. Общо повече има 4 + 6 + 18 = 28 домашни любимци.
</reasoning>
<answer>
28
</answer> 

Extracted (from the model's response):
28

Further improvements

As this is a proof of concept weekend project with poor data, there are many options for improvement.

First, the reward functions. The previous example illustrated how reward functions can assign points for structured output. However, many other functions could be added for future alignment. For instance, we could implement a function that rewards the generation of Cyrillic text or one that rewards longer reasoning text (see the sketch below). The possibilities are endless, as the cliché reminds us.
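As a sketch of what such an addition could look like, here is a hypothetical reward function with the same signature as the ones above, paying up to 0.5 depending on how much of the completion’s alphabetic text is Cyrillic:

import re

CYRILLIC = re.compile(r"[\u0400-\u04FF]")  # Unicode block that covers Bulgarian Cyrillic

def cyrillic_reward_func(completions, **kwargs) -> list[float]:
    """Hypothetical reward: up to 0.5, proportional to the share of Cyrillic letters."""
    rewards = []
    for completion in completions:
        text = completion[0]["content"]
        letters = [ch for ch in text if ch.isalpha()]
        if not letters:
            rewards.append(0.0)
            continue
        cyrillic_share = sum(1 for ch in letters if CYRILLIC.match(ch)) / len(letters)
        rewards.append(0.5 * cyrillic_share)
    return rewards

A reward for longer reasoning could follow the same pattern, for example scaling with the length of the text between the reasoning tags, capped so it cannot dominate the correctness reward.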

Second, the data. The training was based on machine-translated data using the Google Translate API (source), which isn’t extensive. For scaling purposes, there are many open-source datasets available that can be utilized. Additionally, I believe the Google Translate API is not the best option for this purpose. I have achieved much better translation quality using other methods, such as DeepL followed by verification from a larger language model. Generating such datasets can be costly, but for GRPO we don’t need a large number of samples. We could also incorporate synthetic data; I have obtained high-quality datasets from Claude 3.5 Sonnet.

Third, the base model. Llama 3.1 is a multilingual model trained on a large portion of the internet, which includes various languages. However, I’ve noticed that much of the generated text is not perfectly aligned with the Bulgarian language; at times, it confuses Bulgarian with Russian or English. This issue could potentially be addressed through continual pre-training, similar to the method used for the state-of-the-art Bulgarian model (BgGPT 1.0: Extending English-centric LLMs to Other Languages). I attempted to do this with LLMBG-Llama-3.1-8B-bnb-4bit, but my data is limited, and I have not achieved satisfactory results (so far). The output, sometimes full of messy symbols and poorly constructed sentences, reminds me of DeepSeek-R1-Zero from the original DeepSeek-R1 paper.

Finally, modifying the CoT syntax to allow short free-text answers instead of strictly verifiable ones could also be a good direction. Reward functions could then use similarity measures (such as BERTScore, edit distance, etc.).

However, this approach may result in a slower training process. Quick example:


from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from nltk.translate.bleu_score import sentence_bleu
from difflib import SequenceMatcher

def text_similarity_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
   """
   Combines multiple text similarity metrics to evaluate Bulgarian answers
   """
   responses = [extract_xml_answer(completion[0]['content']) for completion in completions]
   
   def calculate_similarities(response: str, target: str) -> float:
       # Convert to lowercase for better comparison
       response = response.lower()
       target = target.lower()
       
       # Calculate similarity using multiple metrics
       try:
           # TF-IDF Cosine Similarity 
           vectorizer = TfidfVectorizer()
           tfidf_matrix = vectorizer.fit_transform([response, target])
           cosine_sim = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0]
           
           # Sequence Matcher
           sequence_sim = SequenceMatcher(None, response, target).ratio()
           
           # BLEU Score for n-gram overlap
           bleu_score = sentence_bleu([target.split()], response.split())
           
        except Exception:
            return 0.0
           
       # Combine metrics with weights
       combined_score = (
           0.4 * cosine_sim +      # TF-IDF 
           0.4 * sequence_sim +     # Sequence similarity
           0.2 * bleu_score        # BLEU score
       )
       
       return min(combined_score * 2.0, 2.0)  # Scale to max 2.0
       
   return [calculate_similarities(r, a) for r, a in zip(responses, answer)]

# Example usage
test_completions = [[{
    "content": """
    <reasoning>
    Магазинът продава ябълки по 2 лева.
    </reasoning>
    <answer>
    Една ябълка струва 2 лева.
    </answer>
    """
}]]
test_answer = ["Цената на една ябълка е 2 лева."]

rewards = text_similarity_reward_func(None, test_completions, test_answer)
print(f"Similarity reward: {rewards[0]:.2f}")

Conclusion

It was an interesting journey to build this model. It was possible in a weekend only by standing on the shoulders of giants like DeepSeek, Hugging Face TRL, Unsloth.AI, Meta and Llama, the original OpenAI GSM8K dataset, and many others.

The results are not perfect, but the training budget, time spent, and data quality are negligible compared to what current providers invest (e.g., BgGPT was trained on 64 NVIDIA H100s).

The experiment shows that training a basic reasoning model in a non-English language is feasible with limited resources. While my reward functions focused on format and basic accuracy, there’s room for more complex approaches – like the text similarity metrics I outlined that could better evaluate reasoning quality in any language.

A key takeaway is that you don’t need massive computational resources to start experimenting with reasoning models. With a single GPU, some translated data, and the right open-source tools, you can build a foundation for more complex work. The GRPO technique proved particularly effective at teaching the model to maintain consistent reasoning patterns, even in a different language.

I hope this article serves as a practical guide for anyone wanting to train a reasoning model in their language. The code and approach can be adapted for other languages, and while the results won’t match commercial models, they provide a solid starting point for further research and experimentation.
