NyayAI: Building an AI Legal Assistant for 1.4 Billion People — A Technical Deep Dive
I'm building a startup to make Indian law accessible to every lawyer, law student, and citizen in the country. Here's the technical story of how I went from zero to a working prototype — training a foundation model from scratch, fine-tuning on 4,000 instruction pairs, and building a production-ready RAG pipeline — all as a solo founder.
The Problem
India has 1.4 billion people and roughly 50 million active legal cases pending in its courts. Lawyers spend hours — sometimes days — digging through bare acts, constitutional articles, and decades of Supreme Court judgments just to find relevant precedents for a single case. The Indian legal system operates across 25+ High Courts, hundreds of tribunals, and a Supreme Court that has delivered judgments since 1950. The sheer volume is staggering.
And yet, the tooling available to lawyers is stuck in 2005. Paid databases like SCC Online and Manupatra charge thousands per month and still require manual keyword searches. Free resources like Indian Kanoon are search-only — no summaries, no analysis, no drafting. Generic AI tools like ChatGPT hallucinate case names, invent sections that don't exist, and have no depth in Indian law.
I wanted to change that.
NyayAI (न्याय = justice in Sanskrit) is an AI-powered legal assistant that understands Indian law — not superficially, but deeply. It can look up any section of any central act, summarize Supreme Court judgments, answer complex legal questions with grounded citations, and eventually draft legal documents. Think of it as ChatGPT, but one that actually passed the bar exam for Indian law.
This blog post is a technical deep dive into everything I've built so far — the data pipelines, the model architecture decisions, the training infrastructure, and the results. Every number, every line of code, every failed experiment is documented here.
Phase 0: The 103M Parameter Experiment (The Learning Phase)
Before touching any pretrained model, I wanted to understand transformers at the deepest level. Not "import transformers and call .fit()" — I mean implementing a GPT-style transformer from scratch in PyTorch.
Architecture
I built a decoder-only transformer with the following specifications:
| Parameter | Value |
|---|---|
| Total Parameters | 103,457,280 (~103M) |
| Layers | 9 |
| Attention Heads | 12 |
| Embedding Dimension | 768 |
| Context Window | 512 tokens |
| Vocabulary Size | 50,257 (GPT-2 tokenizer) |
| Output Head | Weight-tied with embedding layer |
The model was trained on 269 million tokens (1.25 GB) of Indian legal text — the same corpus I'd later use for the production pipeline. Training ran on NVIDIA A100 GPUs via Modal for 2 epochs across 59,000 gradient steps.
Results
| Metric | Value |
|---|---|
| Final Validation Loss | 2.46 |
| Perplexity | 11.7 |
| Training Time | ~8 hours |
A perplexity of 11.7 on legal text means the model learned the structure and vocabulary of Indian legal language reasonably well. It could generate coherent legal-sounding text, but it was not a useful model — it had no instruction-following capability and no factual grounding. It was a learning exercise, and it served its purpose brilliantly.
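The perplexity figure follows directly from the validation loss: for a mean per-token cross-entropy measured in nats, perplexity is simply its exponential. A quick check, assuming the reported 2.46 is that kind of loss:

```python
import math

def perplexity(cross_entropy_nats: float) -> float:
    """Perplexity is the exponential of the mean per-token cross-entropy (in nats)."""
    return math.exp(cross_entropy_nats)

val_loss = 2.46  # final validation loss from the results table
print(f"perplexity = {perplexity(val_loss):.1f}")  # prints "perplexity = 11.7"
```

Put differently, the model is on average about as uncertain as if it were choosing uniformly among roughly 12 tokens at each position.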
Key Takeaway: Building a transformer from scratch taught me more about attention mechanisms, positional encoding, loss landscapes, and gradient dynamics than any course or paper ever could. If you're serious about ML, I strongly recommend doing this at least once.
Phase 1: Data Acquisition — The Foundation of Everything
A model is only as good as its data. For NyayAI, I needed three categories of legal text:
- The Constitution of India — the supreme law, 395+ articles
- Central Acts (Bare Acts) — the 858 laws passed by Parliament
- Supreme Court Judgments — 75 years of case law (1950–2025)
1A. The Constitution of India
Source: A structured JSON file containing all articles with metadata (article number, title, description).
Pipeline: prepare_constitution.py — a straightforward JSON-to-text converter that:
- Parses each article from the JSON
- Cleans escaped newlines and normalizes whitespace
- Preserves repealed articles with notation
- Formats as structured text with `Article N — Title` headers
- Separates each article with `<|endoftext|>` tokens for clean document boundaries
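Here is a minimal sketch of that conversion. It assumes each JSON record carries `article`, `title`, and `description` keys and that repealed articles are flagged by a hypothetical `repealed` boolean; the real field names in the source file may differ.

```python
import re

EOT = "<|endoftext|>"

def format_article(record: dict) -> str:
    """Render one article as an 'Article N — Title' header plus cleaned body text."""
    body = record.get("description", "")
    body = body.replace("\\n", "\n")                # un-escape literal "\n" sequences
    body = re.sub(r"[ \t]+", " ", body)             # collapse runs of spaces and tabs
    body = re.sub(r"\n{3,}", "\n\n", body).strip()  # keep at most one blank line
    header = f"Article {record['article']} — {record['title']}"
    if record.get("repealed"):                      # hypothetical flag (assumption)
        header += " [Repealed]"
    return f"{header}\n{body}"

def build_corpus(records: list[dict]) -> str:
    """Join articles with end-of-text tokens so document boundaries stay explicit."""
    return f"\n{EOT}\n".join(format_article(r) for r in records) + f"\n{EOT}\n"

sample = [
    {"article": "14", "title": "Equality before law",
     "description": "The State shall not deny   to any person\\nequality before the law."},
    {"article": "31", "title": "Compulsory acquisition of property",
     "description": "", "repealed": True},
]
print(build_corpus(sample))
```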
Output:
| Metric | Value |
|---|---|
| Articles Processed | 395+ (including amendments) |
| Output File | constitution_training.txt |
| File Size | 502 KB |
| Estimated Tokens | ~106,000 |
The Constitution is small but dense — every article matters. The Preamble alone is one of the most frequently cited legal texts in Indian jurisprudence.
1B. Central Acts (858 Bare Acts)
This was significantly more complex. India has 858 central acts in force, ranging from the Indian Penal Code (1860) to the Digital Personal Data Protection Act (2023). These were stored as deeply nested JSON files with a schema that included:
Act Title, Act ID, Enactment Date, Act Definition
├── Chapters/Parts
│ ├── Sections
│ │ └── Paragraphs (strings or nested dicts with text/contains)
│ └── Subheadings
│ └── Sections
├── Schedules, Annexures, Appendix, Forms
└── Footnotes
Pipeline: prepare_central_acts.py — a recursive JSON traversal engine that:
- Handles BOM encoding: many Indian government JSON files contain a byte-order mark
- Recursively extracts paragraphs: the `extract_paragraphs()` function handles arbitrarily nested `text`/`contains` structures with proper indentation
- Cleans legislative artifacts: removes footnote reference numbers (`\d+\[` → `[`), strips decorative markers (`* * * * *`)
- Sorts sections numerically: a custom `sort_key()` function ensures Section 2 comes before Section 10 (not after, as string sorting would do)
- Processes chapters, subheadings, schedules, annexures, and footnotes, preserving the full hierarchical structure
- Outputs with `<|endoftext|>` boundaries between each act
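The two trickiest pieces, numeric section ordering and recursive paragraph extraction, might look roughly like this. It is an illustrative sketch rather than the project's actual code: the function names come from the description above, but the bodies, the helper `clean_artifacts`, and the sample data are assumptions.

```python
import re

def sort_key(section_id: str):
    """Order '2' < '3A' < '10' by comparing the numeric part first, then any suffix."""
    match = re.match(r"(\d+)(.*)", section_id.strip())
    if match is None:
        return (float("inf"), section_id)   # non-numeric ids sort last
    return (int(match.group(1)), match.group(2))

def extract_paragraphs(node, depth: int = 0) -> list[str]:
    """Flatten strings and nested {'text': ..., 'contains': [...]} dicts into indented lines."""
    indent = "  " * depth
    if isinstance(node, str):
        return [indent + node]
    if isinstance(node, list):
        return [line for child in node for line in extract_paragraphs(child, depth)]
    if isinstance(node, dict):
        lines = [indent + node["text"]] if node.get("text") else []
        return lines + extract_paragraphs(node.get("contains", []), depth + 1)
    return []

def clean_artifacts(text: str) -> str:
    """Drop footnote numbers fused to '[' and rows of decorative asterisks."""
    text = re.sub(r"\d+\[", "[", text)
    return re.sub(r"(?:\*\s*){3,}", "", text)

print(sorted(["10", "2", "3A", "1"], key=sort_key))   # ['1', '2', '3A', '10']
```

For the byte-order mark, opening the files with `open(path, encoding="utf-8-sig")` strips a leading BOM before `json.load` sees it.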
Output:
| Metric | Value |
|---|---|
| Acts Processed | 858 |
| Output File | central_acts_training.txt |
| File Size | 29.9 MB |
| Total Words | ~5,076,000 |
| Estimated Tokens | ~6,600,000 |
1C. Supreme Court Judgments (1950–2025)
This was the heavy lift — and the most valuable data. The Supreme Court of India has delivered tens of thousands of judgments over 75 years. I sourced these from the AWS Open Data Registry (s3://indian-supreme-court-judgments), a public bucket containing judgment PDFs and metadata JSONs organized by year.
Step 1: Download — download_sc_judgements.py
- Uses `boto3` with unsigned requests (public bucket, no auth needed)
- Downloads English judgment tar files and metadata tar files for each year (1950–2026)
- Implements resume support: skips files that already exist with correct size
- Downloads to `data/sc_judgments/` and `data/sc_metadata/`
- Progress logging with download speed tracking
Step 2: Extract & Process — prepare_sc_judgments.py
This script is the most complex in the entire pipeline. It:
- Extracts metadata tars: unpacks year-by-year JSON metadata files
- Parses metadata HTML: each judgment's metadata is stored as raw HTML. The `parse_metadata_html()` function uses regex to extract:
  - Case title (petitioner vs respondent)
  - Judges/Coram
  - Decision date
  - Case number
  - Bench size
  - Citation
  - Disposal nature
- Extracts text from PDFs: uses PyMuPDF (`fitz`) to extract text from judgment PDFs, then `clean_judgment_text()` removes:
  - Page headers/footers ("SUPREME COURT REPORTS", standalone page numbers)
  - Excessive whitespace
  - Year-only lines (standalone "1950", "2023", etc.)
- Matches PDFs to metadata: correlates each PDF with its extracted case metadata by path key
- Formats each judgment as a structured document with a header block (title, citation, case number, date, bench, disposal) followed by the full judgment text
- Processes year-by-year: streams output to avoid loading 1.5 GB of text into memory at once
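To make the cleaning step concrete, here is a hedged sketch of what a `clean_judgment_text()` along these lines could look like. The exact patterns are assumptions on my part; the real script's rules may be stricter or looser.

```python
import re

def clean_judgment_text(text: str) -> str:
    """Remove reporter headers, bare page numbers and year-only lines, then tidy whitespace."""
    kept = []
    for line in text.splitlines():
        stripped = line.strip()
        if stripped.upper() == "SUPREME COURT REPORTS":
            continue                          # running page header
        if re.fullmatch(r"\d{1,4}", stripped):
            continue                          # standalone page number or year ("17", "1950")
        kept.append(re.sub(r"[ \t]+", " ", stripped))
    cleaned = "\n".join(kept)
    return re.sub(r"\n{3,}", "\n\n", cleaned).strip()   # collapse runs of blank lines

raw = "SUPREME COURT REPORTS\n1950\n\n\n\nThe appeal   is allowed.\n17\nNo order as to costs.\n"
print(clean_judgment_text(raw))
```

One trade-off of a rule this simple: a paragraph number that sits alone on its own line would be dropped too, which is why the patterns are worth checking against a sample of real PDFs.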
Output:
| Metric | Value |
|---|---|
| Judgments Processed | 43,324 |
| Output File | sc_judgments_training.txt |
| File Size | 1.49 GB (1,588,861,395 bytes) |
| Total Words | ~261,000,000 |
| Estimated Tokens | ~339,300,000 |
| Time Span | 1950–2025 (75 years) |
Total Corpus Summary
| Source | File Size | Tokens (est.) |
|---|---|---|
| Constitution of India | 502 KB | ~106K |
| Central Acts (858 acts) | 29.9 MB | ~6.6M |
| SC Judgments (43,324 cases) | 1.49 GB | ~339.3M |
| Total | ~1.52 GB | ~346 Million |
This is a substantial legal corpus: 346 million tokens of structured, cleaned Indian legal text spanning 75 years of Supreme Court jurisprudence, the entire Constitution, and every central act in force. For context, GPT-2's WebText training set was about 40 GB of text from roughly 8 million web pages. This corpus is only a few percent of that by size, but it is focused entirely on a single domain.
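The token estimates in these tables are consistent with a flat ratio of about 1.3 tokens per word. That ratio is inferred from the reported numbers rather than taken from the pipeline, so treat the check below as arithmetic, not as the actual estimation code:

```python
TOKENS_PER_WORD = 1.3  # assumed heuristic, inferred from the tables

word_counts = {
    "central_acts": 5_076_000,
    "sc_judgments": 261_000_000,
}
for name, words in word_counts.items():
    print(f"{name}: ~{words * TOKENS_PER_WORD / 1e6:.1f}M tokens")

# Constitution + Central Acts + SC judgments, using the per-source estimates
total_tokens = 106_000 + 6_600_000 + 339_300_000
print(f"total: ~{total_tokens / 1e6:.0f}M tokens")   # total: ~346M tokens
```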
Phase 1.5: Synthetic Instruction Dataset Generation
A language model that can continue legal text is interesting but not useful. To make it follow instructions — answer questions, summarize cases, compare sections — I needed an instruction-response dataset.
Creating thousands of high-quality legal Q&A pairs by hand was not feasible. Instead, I built a synthetic data generation pipeline using Google's Gemini API.
Pipeline: generate_synthetic_qa.py
The approach:
- Random chunk sampling: for each batch, randomly select a ~40,000 character chunk from one of the three source files, with a weighted distribution:
  - 60% Supreme Court judgments (largest, most diverse)
  - 30% Central Acts (statute-heavy, structured)
  - 10% Constitution (fundamental, frequently referenced)
- Structured prompting: each chunk is sent to `gemini-3.1-flash-lite` with a carefully crafted prompt that enforces:
  - No hallucination: responses must be based strictly on the provided text excerpt
  - Diversity in length and complexity: each batch of 5 pairs follows a prescribed format:
    - Task 1: Very Long (3-4 paragraph comprehensive summary/brief)
    - Task 2: Medium (legal argument/analysis)
    - Task 3: Medium (comparison of concepts)
    - Task 4: Short (direct factual question)
    - Task 5: Short (yes/no client question with explanation)
- Structured output: uses Pydantic models (`QAPair`, `QAResponse`) with `response_mime_type: application/json` for reliable parsing
- Incremental saving: pairs are appended to a JSONL file as they're generated, with a running count. Supports resume (checks existing pair count on startup).
- Rate limiting: 4-second sleep between requests to respect the free tier (15 RPM).
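The sampling step can be sketched with the standard library alone. The source labels other than "Bare Act" and the helper names are assumptions for illustration; only the 60/30/10 weights and the ~40,000-character window come from the description above.

```python
import random

SOURCE_WEIGHTS = {"SC Judgment": 0.6, "Bare Act": 0.3, "Constitution": 0.1}
CHUNK_CHARS = 40_000

def pick_source(rng: random.Random) -> str:
    """Choose a source type according to the 60/30/10 weighting."""
    names = list(SOURCE_WEIGHTS)
    return rng.choices(names, weights=[SOURCE_WEIGHTS[n] for n in names], k=1)[0]

def sample_chunk(text: str, rng: random.Random, size: int = CHUNK_CHARS) -> str:
    """Return a random contiguous window of at most `size` characters."""
    if len(text) <= size:
        return text
    start = rng.randrange(len(text) - size + 1)
    return text[start:start + size]

rng = random.Random(0)
counts = {name: 0 for name in SOURCE_WEIGHTS}
for _ in range(10_000):
    counts[pick_source(rng)] += 1
print(counts)  # roughly 6000 / 3000 / 1000
```

A fixed 4-second pause between requests caps throughput at 60 / 4 = 15 requests per minute, which is where the rate limit figure comes from.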
Schema
Each output pair in legal_instruction_dataset.jsonl:
{
  "instruction": "What are the key provisions of Section 14 of the Hindu Succession Act?",
  "response": "Section 14 of the Hindu Succession Act, 1956, deals with...",
  "source_type": "Bare Act"
}
Output
| Metric | Value |
|---|---|
| Target Pairs | 4,000 |
| Generated Pairs | ~4,000 |
| Output File | legal_instruction_dataset.jsonl |
| File Size | 2.09 MB |
| Source Distribution | 60% judgments, 30% acts, 10% constitution |
| Generation Model | Gemini 3.1 Flash Lite |
| Cost | $0 (free tier API) |
The critical insight here: the quality of your instruction data matters far more than quantity. The Stanford Alpaca project used only 52K pairs to teach instruction-following to LLaMA. For a domain-specific model that already follows instructions, 2,000-4,000 high-quality, grounded pairs can be enough, as long as they're diverse in task type and faithful to the source material.
Phase 2: Fine-Tuning — Teaching the Model Indian Law
With data in hand, it was time to take a state-of-the-art pretrained model and teach it to be an Indian legal expert.
Model Selection: Qwen-3 4B Instruct
After evaluating several sub-6B parameter models (Phi-4-mini, SmolLM3-3B, Gemma-3n-E2B), I chose Qwen-3 4B Instruct (2507 variant) for several reasons:
| Factor | Why Qwen-3 4B |
|---|---|
| Reasoning | Exceptional chain-of-thought and instruction following |
| Multilingual | Strong Hindi support (critical for Indian legal market) |
| Architecture | Modern optimizations, efficient attention |
| Ecosystem | Massive HuggingFace community, well-documented |
| License | Apache 2.0 — fully commercial use |
| Size | 4B parameters — fits in a single L4 GPU (24GB) in bfloat16 |
Training Infrastructure
Everything runs on Modal — a serverless GPU cloud that lets you define your entire training pipeline in a single Python file and run it with one command.
# From finetune_training.py
import modal

app = modal.App("NyayAI_FineTuning")

# Container image: Python dependencies plus the instruction dataset baked in
image = (
    modal.Image.debian_slim()
    .pip_install_from_requirements("finetune_requirements.txt")
    .add_local_file("./data/legal_instruction_dataset.jsonl",
                    remote_path="/app/legal_instruction_dataset.jsonl")
)

# CHECKPOINTS_VOL is a Modal Volume defined elsewhere in the file
@app.cls(
    image=image,
    gpu="L4",  # NVIDIA L4, 24GB VRAM
    volumes={"/checkpoints": CHECKPOINTS_VOL},
    timeout=36000,  # 10 hours
)
class FineTuneScript:
    ...
The entire training pipeline — from data loading to checkpoint saving — is defined in a single file (finetune_training.py, 677 lines) and executes remotely on Modal. Checkpoints are saved to a Modal Volume and automatically downloaded to my local machine after each epoch.
LoRA: Training Smart, Not Expensive
Fine-tuning all 4 billion parameters would require multiple GPUs and cost hundreds of dollars. Instead, I implemented LoRA (Low-Rank Adaptation) from scratch — no HuggingFace PEFT library, no Unsloth, no shortcuts.
How LoRA Works
Instead of updating the full weight matrix W (size d × d), LoRA freezes W and learns the update as the product of two small matrices:
W' = W + (α/r)(A × B)
where A is (d × r), B is (r × d), and r << d
For rank r=16 and dimension d=768, instead of updating 589,824 parameters per weight matrix, you're updating 16×768 + 16×768 = 24,576 parameters, a 24x reduction.
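The arithmetic is easy to verify. The helper names below are just for illustration:

```python
def full_params(d_in: int, d_out: int) -> int:
    """Parameters in a dense weight matrix (bias ignored)."""
    return d_in * d_out

def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in the two low-rank factors A (d_in x r) and B (r x d_out)."""
    return d_in * rank + rank * d_out

d, r = 768, 16
print(full_params(d, d))                          # 589824
print(lora_params(d, d, r))                       # 24576
print(full_params(d, d) // lora_params(d, d, r))  # 24
```

For a square matrix the reduction factor is d / (2r), so it grows with the model's hidden size at a fixed rank.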
Implementation
import torch

class LORALayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        # A: down-projection (in_dim -> rank), default nn.Linear init
        self.A = torch.nn.Linear(in_dim, rank, bias=False)
        # B: up-projection (rank -> out_dim), zero-initialized
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
        # Scale by alpha / rank (32 / 16 = 2.0 with the settings used here)
        self.scaling = alpha / rank

    def forward(self, x):
        return self.scaling * (self.A(x) @ self.B)

class LinearWithLoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear  # frozen pretrained layer
        self.lora = LORALayer(linear.in_features, linear.out_features, rank, alpha)

    def forward(self, x):
        return self.linear(x) + self.lora(x)
The B matrix is initialized to zeros, so at the start of training the LoRA branch outputs A(x) @ 0 = 0 whatever the scaling factor. The model starts exactly where the pretrained model left off, with no disruption. As training progresses, the LoRA layers learn domain-specific adaptations while the base model stays frozen.
Target Modules
LoRA adapters were injected into the attention layers only:
lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
The injection uses recursive module replacement:
def replace_linear_with_lora(model, rank, alpha):
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear) and name in lora_target_modules:
            # Swap the attention projection for its LoRA-wrapped version
            setattr(model, name, LinearWithLoRA(module, rank, alpha))
        else:
            # Recurse into submodules that are not targeted projections
            replace_linear_with_lora(module, rank, alpha)
Hyperparameters
| Parameter | Value | Rationale |
|---|---|---|
| LoRA Rank | 16 | Sweet spot: enough capacity for domain adaptation without overfitting on ~4K pairs |
| LoRA Alpha | 32 | α/r = 2.0 scaling factor — standard choice |
| Peak Learning Rate | 2e-5 | Conservative — avoiding catastrophic forgetting of base model knowledge |
| Minimum Learning Rate | 2e-6 | 10x decay from peak |
| Warmup Steps | 50 | Quick ramp to prevent early instability |
| Batch Size | 4 | Fits in L4 VRAM with gradient checkpointing |
| Max Sequence Length | 8,192 | Training-time cap, well within Qwen-3's native context window |
| Weight Decay | 0.1 | Standard regularization |
| Gradient Clipping | 1.0 (max norm) | Prevents exploding gradients on long legal sequences |
| Optimizer | AdamW | Only over LoRA parameters |
| Precision | bfloat16 | Native on L4, no precision loss for this scale |
| Epochs | 2 | Sufficient for convergence on this dataset size |
Parameter Efficiency
| Category | Count |
|---|---|
| Total Model Parameters | ~4,000,000,000 |
| Frozen (Base Model) | ~3,988,200,000 |
| Trainable (LoRA) | ~11,800,000 |
| Parameter Ratio | ~0.30% |
We're training less than 0.3% of the model's parameters. The LoRA adapter checkpoint is ~135 MB — compared to the full model's ~8 GB in bfloat16.
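The ~11.8M trainable-parameter figure can be reproduced from the attention shapes. The dimensions below are assumptions taken from the published Qwen3-4B configuration (hidden size 2560, 36 layers, 32 query heads and 8 key/value heads of dimension 128), not from the training script, so treat this as a sanity check:

```python
HIDDEN, LAYERS, RANK = 2560, 36, 16
Q_DIM = 32 * 128    # query heads x head dim
KV_DIM = 8 * 128    # key/value heads x head dim (grouped-query attention)

# (in_features, out_features) for each LoRA-wrapped projection
projections = {
    "q_proj": (HIDDEN, Q_DIM),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (Q_DIM, HIDDEN),
}

# Each adapter adds rank * (in_features + out_features) parameters
per_layer = sum(RANK * (d_in + d_out) for d_in, d_out in projections.values())
trainable = per_layer * LAYERS
adapters = len(projections) * LAYERS

print(adapters)    # 144
print(trainable)   # 11796480
```

Under those assumptions the count lands on about 11.8M parameters across 144 adapted projections.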
Data Formatting: ChatML
Every instruction-response pair is formatted in ChatML (the template Qwen expects):
<|im_start|>system
You are an expert Indian Legal Assistant.<|im_end|>
<|im_start|>user
What are the key provisions of Section 14 of the Hindu Succession Act?<|im_end|>
<|im_start|>assistant
Section 14 of the Hindu Succession Act, 1956, is a landmark provision...<|im_end|>
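A small helper that produces this layout from one dataset record might look like the following. The function name `to_chatml` is hypothetical; the template itself is the one shown above.

```python
SYSTEM_PROMPT = "You are an expert Indian Legal Assistant."

def to_chatml(instruction: str, response: str, system: str = SYSTEM_PROMPT) -> str:
    """Wrap one instruction/response pair in ChatML turn markers."""
    return (
        f"<|im_start|>system\n{system}<|im_end|>\n"
        f"<|im_start|>user\n{instruction}<|im_end|>\n"
        f"<|im_start|>assistant\n{response}<|im_end|>"
    )

pair = {
    "instruction": "What are the key provisions of Section 14 of the Hindu Succession Act?",
    "response": "Section 14 of the Hindu Succession Act, 1956, is a landmark provision...",
}
print(to_chatml(pair["instruction"], pair["response"]))
```

Hugging Face tokenizers also expose `apply_chat_template`, which builds the same kind of string from a list of role/content messages using the template shipped with the model.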
Custom Collation: Dynamic Batch Padding
Rather than padding all sequences to the maximum sequence length (8,192 tokens), I implemented dynamic batch padding, where each batch is padded only to the length of its longest sequence:
def custom_collate_fn(batch, ignore_index=-100):
    batch_max_length = max(len(item["input_ids"].squeeze()) + 1 for item in batch)
    # ... pad each item to batch_max_length, not max_seq_length
This saves enormous amounts of compute. If a batch's longest sequence is 1,200 tokens, we're processing 1,200 × 4 = 4,800 tokens instead of 8,192 × 4 = 32,768 tokens. On average, this reduces compute by ~70-80%.
Padding tokens in the target sequence are masked with ignore_index=-100, so the loss function ignores them entirely.
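To show the idea end to end, here is a plain-list sketch of dynamic padding with masked targets. It is not the tensor-based `custom_collate_fn` from the training script: the `pad_id` value and the explicit left-shift of targets are assumptions that suit a loss comparing position t with target t.

```python
IGNORE_INDEX = -100

def collate(batch: list[list[int]], pad_id: int = 0, max_len: int = 8192):
    """Pad to the longest sequence in the batch and build next-token targets."""
    batch_max = min(max(len(ids) for ids in batch), max_len)
    inputs, targets = [], []
    for ids in batch:
        ids = ids[:batch_max]                 # truncate anything over the cap
        pad = batch_max - len(ids)
        inputs.append(ids + [pad_id] * pad)
        # Targets are the inputs shifted left by one; the final position and
        # every padding position are masked so the loss skips them.
        targets.append(ids[1:] + [IGNORE_INDEX] * (pad + 1))
    return inputs, targets

batch = [[5, 6, 7], [8, 9, 10, 11, 12]]
inputs, targets = collate(batch)
print(inputs)   # [[5, 6, 7, 0, 0], [8, 9, 10, 11, 12]]
print(targets)  # [[6, 7, -100, -100, -100], [9, 10, 11, 12, -100]]
```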
Learning Rate Schedule: Cosine with Linear Warmup
def get_lr(self, step, total_steps):
    if step < self.warmup_steps:
        # Linear warmup toward peak_lr
        return self.peak_lr * (step + 1) / self.warmup_steps
    # Cosine decay from peak_lr down to min_lr
    progress = (step - self.warmup_steps) / max(1, total_steps - self.warmup_steps)
    return self.min_lr + 0.5 * (self.peak_lr - self.min_lr) * (
        1 + math.cos(math.pi * progress)
    )
The schedule:
- Linear warmup (0 → 2e-5 over 50 steps) — prevents early training instability
- Cosine decay (2e-5 → 2e-6 over remaining steps) — smooth convergence without sharp drops
Memory Optimization: Gradient Checkpointing
With 4B parameters in bfloat16, the model alone takes ~8GB of VRAM. LoRA keeps gradients and optimizer states small, but activations for sequences of up to 8,192 tokens can still push memory past 24GB.
Gradient checkpointing solves this by trading compute for memory:
self.model.gradient_checkpointing_enable()
Instead of storing all intermediate activations during the forward pass (for use in the backward pass), it recomputes them on-the-fly during backpropagation. This costs ~30% more compute time but saves ~40% VRAM — the difference between fitting and OOM.
Dataset Split
| Split | Size | Purpose |
|---|---|---|
| Train | 85% (~3,400 pairs) | Model training |
| Validation | 5% (~200 pairs) | Loss monitoring, overfitting detection |
| Test | 10% (~400 pairs) | Final evaluation (held out) |
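A deterministic split with these proportions can be sketched as follows. The seed and the function name are assumptions; integer arithmetic is used so the cut points are exact.

```python
import random

def split_dataset(items: list, seed: int = 42):
    """Shuffle deterministically, then cut 85% / 5% / 10%."""
    shuffled = items[:]
    random.Random(seed).shuffle(shuffled)
    n_train = len(shuffled) * 85 // 100
    n_val = len(shuffled) * 5 // 100
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]          # remainder goes to the test set
    return train, val, test

train, val, test = split_dataset(list(range(4000)))
print(len(train), len(val), len(test))   # 3400 200 400

# With batch size 4 and no gradient accumulation, that is 850 steps per epoch.
steps_per_epoch = len(train) // 4
```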
Fault-Tolerant Training: The Generator Pattern
Training on cloud GPUs can fail for many reasons — preemption, network issues, timeouts. The training loop uses Python's generator pattern (yield) to stream results back to the local machine after each epoch:
@modal.method()
def train(self, ...):
    for epoch in range(num_epochs):
        # ... training loop ...
        # Save checkpoint to Modal Volume
        checkpoint_info = self.save_checkpoint(...)
        # Yield results back to local machine
        yield {
            "type": "epoch_complete",
            "epoch": epoch + 1,
            "train_loss": train_losses[-1],
            "val_loss": val_losses[-1],
            "checkpoint_info": checkpoint_info,
            "log": epoch_log,
        }
The local entrypoint catches each yield and immediately downloads the checkpoint and training log:
for result in trainer.train.remote_gen(...):
    if result["type"] == "epoch_complete":
        # Download checkpoint from Modal Volume to local disk
        with open(local_epoch_path, "wb") as f:
            for chunk in CHECKPOINTS_VOL.read_file(rel_path):
                f.write(chunk)
This means even if training crashes after epoch 1 completes, I already have the epoch 1 checkpoint and logs downloaded locally. Training can be resumed from any saved checkpoint using the --resume-from flag.
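Stripped of Modal, the pattern is just a generator whose consumer persists each result as soon as it arrives. The sketch below simulates a crash after the first epoch; `train` and `run` here are stand-ins, not the real remote method.

```python
from typing import Iterator, Optional

def train(num_epochs: int, fail_after: Optional[int] = None) -> Iterator[dict]:
    """Stand-in for the remote training method: yields once per finished epoch."""
    for epoch in range(1, num_epochs + 1):
        if fail_after is not None and epoch > fail_after:
            raise RuntimeError("simulated preemption")
        yield {"type": "epoch_complete", "epoch": epoch}

def run(num_epochs: int, fail_after: Optional[int] = None) -> list[int]:
    """Consume results as they arrive; anything saved before a crash is kept."""
    saved = []
    try:
        for result in train(num_epochs, fail_after):
            if result["type"] == "epoch_complete":
                saved.append(result["epoch"])   # real code downloads the checkpoint here
    except RuntimeError as err:
        print(f"training interrupted: {err}")
    return saved

print(run(2, fail_after=1))   # [1]
```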
Checkpointing: LoRA Weights Only
Checkpoints save only the LoRA parameters plus optimizer state — not the full 8GB model:
lora_state = {
    k: v for k, v in self.model.state_dict().items()
    if "lora" in k.lower()
}
payload = {
    "lora_state_dict": lora_state,
    "optimizer_state_dict": self.optimizer.state_dict(),
    "lora_config": {"rank": self.lora_rank, "alpha": self.lora_alpha},
}
Each checkpoint is ~135 MB instead of ~8 GB. Fast to save, fast to download, fast to resume.
Training Results
Training ran for 2 full epochs on an NVIDIA L4 GPU (24GB VRAM) via Modal.
Loss Curves
The training produced detailed per-step metrics logged every 5 steps:
| Metric | Epoch 1 End | Epoch 2 End (Final) |
|---|---|---|
| Training Loss | ~1.05 | ~0.69 |
| Validation Loss | ~1.00 | ~0.92 |
| Learning Rate | ~1.2e-5 (mid-decay) | ~2e-6 (minimum) |
| Tokens Processed | ~4.5M | ~9.0M |
| Global Steps | ~850 | ~1,700 |
Key Observations
- Smooth convergence — no loss spikes, no instability. The warmup + cosine schedule + gradient clipping combination worked perfectly.
- Limited overfitting: validation loss kept falling through epoch 2 (~1.00 → ~0.92) even as the train/validation gap widened (0.69 vs 0.92). A gap of that size points to mild overfitting, which is expected on a second pass over ~3,400 pairs.
- Rapid initial learning — the steepest loss drop happened in the first 200 steps of epoch 1, as the model quickly adapted to the legal domain's vocabulary and style.
- Diminishing returns in epoch 2 — most of the learning happened in epoch 1. Epoch 2 provided refinement but the marginal improvement was smaller.
Sample Generation
After each epoch, the model generates a response to a test prompt:
Prompt: "What are the key provisions of Section 14 of the Hindu Succession Act?"
The model produced a legally accurate, well-structured response citing the correct provisions — demonstrating genuine domain knowledge transfer from the training data.
Checkpoint Files
| File | Size | Contents |
|---|---|---|
| epoch_1_lora_adapter.pth | 135.3 MB | LoRA weights after epoch 1 |
| epoch_2_lora_adapter.pth | 135.3 MB | LoRA weights after epoch 2 |
| lora_adapter.pth | 135.3 MB | Final LoRA weights |
| finetune_log_epoch_1.json | 11 KB | Per-step metrics (epoch 1) |
| finetune_log_epoch_2.json | 22 KB | Per-step metrics (both epochs) |
| finetune_training_log.json | 21 KB | Complete training log |
Training Metric Visualization
I built a dedicated plotting script (plot_finetune.py) that generates four separate publication-quality plots:
- Loss vs. Steps — training and validation loss curves with epoch boundaries
- Loss vs. Tokens — same curves but x-axis is total tokens processed
- Learning Rate Schedule — visualizes the warmup + cosine decay
- Convergence Zoom — final 50% of training, zoomed in on the convergence behavior
Each plot is saved as a separate PNG with dark theme styling, grid lines, and proper annotations.
The Training Command
The entire fine-tuning run is launched with a single command:
modal run finetune_training.py --num-epochs 2 --eval-freq 5 --eval-iter 5
That's it. Modal provisions an L4 GPU, pulls the Docker image, loads the dataset, downloads Qwen-3 4B, injects LoRA adapters, trains for 2 epochs, saves checkpoints to a volume, and streams results back to my machine. When it's done, I have everything locally.
To resume from a checkpoint (if training was interrupted):
modal run finetune_training.py --resume-from "finetune_runs/20260518-041234/epoch_1_lora_adapter.pth"
Infrastructure Costs
| Phase | GPU | Time | Cost |
|---|---|---|---|
| Data Processing | CPU only | ~2 hours | $0 (local) |
| Synthetic QA Generation | None (API) | ~6 hours | $0 (Gemini free tier) |
| Fine-Tuning (2 epochs) | L4 (24GB) | ~2 hours | ~$3-5 (Modal) |
| Total | | | < $5 |
Read that again. The entire training pipeline — from raw legal text to a fine-tuned 4B parameter model — cost less than a cup of coffee.
Phase 3: The Production RAG Pipeline — Architecture, Sharding, & Serving
A fine-tuned model knows how to talk like a legal expert, but it doesn't remember specific facts. When a lawyer asks "What does Section 34 of the Indian Trusts Act say?", a model might generate something that sounds legally plausible but is entirely fabricated.
To solve this, I designed and built a production-grade, highly optimized RAG (Retrieval-Augmented Generation) pipeline. This lookup mechanism allows our fine-tuned Qwen model to query a massive vector database of Indian law, extract the exact legal provisions, and generate answers strictly grounded in the source material with pinpoint citations.
Here is the exact technical story of how the RAG pipeline was compiled, optimized, and served.
3A. LoRA Adapter Merging (merge_lora.py)
Running a model with active LoRA weights in production adds computational overhead and complicates serving. To achieve maximum inference speed and simplify deployment, I wrote merge_lora.py to mathematically blend the LoRA weights directly into the base Qwen-3 4B parameters.
- How it works: The script loads the base model weights, recursively injects our trained LoRA layers (loading our `lora_adapter.pth` checkpoint), and fuses the weight updates back into the main matrices: $$W_{\text{merged}} = W_{\text{base}} + \frac{\alpha}{r} (A \times B)$$
- Result: Fused 144 adapter projection layers in 20.4 seconds. The final standalone model (~7.5 GB in `bfloat16` precision) was saved directly to the `/checkpoints/merged_model` directory on the persistent Modal Volume.
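Why merging is safe can be shown on a toy example: the adapter path and the merged matrix give the same output. This pure-Python sketch uses the row-vector convention x @ W with W of shape (d_in, d_out); the tiny matrices and helper names are made up for illustration.

```python
def matmul(X, Y):
    """Plain-Python matrix product for small nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]

def merge(W, A, B, alpha, rank):
    """W_merged = W + (alpha / rank) * (A @ B), with W of shape (d_in, d_out)."""
    scale = alpha / rank
    delta = matmul(A, B)
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(W, delta)]

# Tiny example: d_in = d_out = 2, rank = 1
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[0.5], [1.0]]          # (d_in x r)
B = [[2.0, 4.0]]            # (r x d_out)
alpha, rank = 2, 1
x = [[3.0, 1.0]]            # one input row

# Adapter path: x @ W + (alpha / rank) * (x @ A) @ B
base = matmul(x, W)
lora = matmul(matmul(x, A), B)
adapter_out = [[b + (alpha / rank) * l for b, l in zip(base[0], lora[0])]]

# Merged path: a single matrix product
merged_out = matmul(x, merge(W, A, B, alpha, rank))
print(adapter_out, merged_out)   # [[13.0, 21.0]] [[13.0, 21.0]]
```

One practical caveat: PyTorch's `nn.Linear` stores its weight as (out_features, in_features), so a real merge has to transpose the A × B product before adding it.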
3B. Structure-Aware Legal Chunking (chunk_legal_docs.py)
Legal documents have natural, highly structured segmentations (articles, sections, subsections). Naive chunking (e.g., splitting every 500 characters blindly) splits legal clauses in half, completely ruining retrieval precision.
I wrote chunk_legal_docs.py to parse the three source document types into structured chunks while preserving critical legal metadata mappings:
- Constitution of India: Split by Article bounds $\rightarrow$ 468 chunks (average 1,025 characters).
- Central Acts: Split recursively by Section bounds $\rightarrow$ 23,152 chunks (average 1,364 characters).
- Supreme Court Judgments: Split by structured paragraphs, filtering out court headers $\rightarrow$ 330,673 chunks (average 4,756 characters).
- Output: 354,293 chunks compiled into a single 1.6 GB JSONL file (`chunks.jsonl`). Each chunk contains its text, `chunk_id`, and a metadata dictionary mapping its original source attributes (e.g., `article_number`, `act_title`, `section`, `case_title`, `year`).
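As an example of structure-aware splitting, here is a sketch for the Constitution file. It assumes the `Article N — Title` header format from the data-preparation step; the `chunk_id` scheme and the `source`/`title` metadata keys are hypothetical.

```python
import re

ARTICLE_HEADER = re.compile(r"^Article (\w+) — (.+)$", re.MULTILINE)

def chunk_by_article(text: str) -> list[dict]:
    """Split Constitution text at article headers, producing one chunk per article."""
    matches = list(ARTICLE_HEADER.finditer(text))
    chunks = []
    for i, m in enumerate(matches):
        # A chunk runs from this header up to the next header (or end of file)
        end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
        body = text[m.start():end].replace("<|endoftext|>", "").strip()
        chunks.append({
            "chunk_id": f"constitution-{m.group(1)}",     # hypothetical id scheme
            "text": body,
            "metadata": {"source": "constitution",
                         "article_number": m.group(1),
                         "title": m.group(2)},
        })
    return chunks

doc = (
    "Article 14 — Equality before law\nThe State shall not deny...\n<|endoftext|>\n"
    "Article 21 — Protection of life and personal liberty\nNo person shall be deprived...\n<|endoftext|>\n"
)
for chunk in chunk_by_article(doc):
    print(chunk["chunk_id"], len(chunk["text"]))
```

Because every chunk starts at a legal boundary, a retrieved chunk always carries a complete provision along with the metadata needed to cite it.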
3C. Massively Parallel GPU Map-Reduce Sharding (embed_modal.py)
Generating vector embeddings for 354,293 documents using a state-of-the-art multi-lingual model (BGE-M3) would take days on a single machine. To solve this, I built a highly distributed Map-Reduce pipeline using Modal.
graph TD
A[chunks.jsonl <br> 354,293 chunks] --> B[Coordinator Function]
B -->|Split into 32 shards| C[Shard Inputs]
C -->|Shard 0| D1[L4 GPU Worker 1]
C -->|Shard 1| D2[L4 GPU Worker 2]
C -->|...| D3[L4 GPU Worker ...]
C -->|Shard 31| D4[L4 GPU Worker 32]
D1 -->|Embed FP16| E1[11,000 vectors]
D2 -->|Embed FP16| E2[11,000 vectors]
D3 -->|Embed FP16| E3[... vectors]
D4 -->|Embed FP16| E4[11,000 vectors]
E1 --> F[Reduce / Concatenate]
E2 --> F
E3 --> F
E4 --> F
F --> G[(FAISS Index FlatIP <br> 354,293 x 1024)]
F --> H[(SQLite chunk_lookup.db)]
Here is how the distributed pipeline operates:
- The Map Phase: The coordinator divides the 354K chunks into 32 shards (~11,000 chunks per shard). Modal spins up 32 parallel L4 GPU containers at once.
- Pre-Caching & Instant Boot: The BGE-M3 model weights are baked directly into the Docker image layer, so workers skip the HuggingFace download at startup and boot quickly.
- FP16 Inference: Each worker runs native PyTorch `float16` inference over its ~11,000 texts, generating normalized dense embeddings.
- The Reduce Phase: The coordinator gathers the 32 output matrices and concatenates them in shard order into a single dense matrix of shape `(354293, 1024)`.
- FAISS Index Compilation: The combined embeddings are fed into a FAISS `IndexFlatIP` index (inner product, which equals cosine similarity on normalized vectors) and saved as `faiss_index.bin`. Simultaneously, a `chunk_lookup.json` dictionary is generated on the volume.
- Compute Time: The entire parallel sharding run finished in roughly 20-30 minutes of wall time.
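The sharding logic itself is simple. The sketch below shows one way to cut the chunk indices into contiguous, near-equal shards; how the real coordinator dispatches them to workers is not shown here, and `make_shards` is an illustrative name.

```python
def make_shards(n_items: int, n_shards: int) -> list[range]:
    """Split indices 0..n_items-1 into contiguous, near-equal shards."""
    base, extra = divmod(n_items, n_shards)
    shards, start = [], 0
    for i in range(n_shards):
        size = base + (1 if i < extra else 0)   # the first `extra` shards get one more
        shards.append(range(start, start + size))
        start += size
    return shards

shards = make_shards(354_293, 32)
print(len(shards[0]), len(shards[-1]))   # 11072 11071

# "Reduce": concatenating shard outputs in shard order restores the original
# row order, so row i of the embedding matrix still corresponds to chunk i.
order = [i for shard in shards for i in shard]
```

Keeping shards contiguous is what makes the reduce step a plain concatenation: no index remapping is needed between FAISS row ids and chunk ids.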
3D. Production FastAPI Serving & Optimizations (rag_inference.py)
To serve the RAG assistant, I built an extremely optimized FastAPI server hosted on Modal (rag_inference.py). It loads the merged Qwen model and BGE-M3 on a single cost-effective L4 GPU.
1. Zero-RAM SQLite Lookup Database (Startup Optimization)
- The Problem: Reading the 1.6 GB `chunk_lookup.json` into container memory on boot takes almost 2 minutes and consumes 1.6 GB of RAM, causing bottlenecks and potential Out-Of-Memory crashes.
- The Solution: On first startup, the server streams the JSON file line-by-line and compiles a local SQLite database (`chunk_lookup.db`) directly on the persistent volume (took 92.6s). On subsequent container boots, the JSON file is completely bypassed.
- The Result: The server opens a thread-safe SQLite connection almost instantly on boot (0.001 seconds) and reads chunk text from disk only when a query needs it, so startup RAM overhead is negligible.
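The standard-library `sqlite3` module is enough to sketch this. The table layout and the choice of `chunk_id` as the key are assumptions; the real database may be keyed by FAISS row index instead.

```python
import json
import sqlite3

def build_lookup(jsonl_lines, db_path: str = ":memory:") -> sqlite3.Connection:
    """Stream JSONL records into a SQLite table keyed by chunk_id."""
    # check_same_thread=False lets worker threads share the connection;
    # callers still need to serialize access themselves.
    conn = sqlite3.connect(db_path, check_same_thread=False)
    conn.execute(
        "CREATE TABLE IF NOT EXISTS chunks (chunk_id TEXT PRIMARY KEY, payload TEXT)"
    )
    # Generator: only one line is parsed and held in memory at a time
    rows = ((json.loads(line)["chunk_id"], line) for line in jsonl_lines if line.strip())
    conn.executemany("INSERT OR REPLACE INTO chunks VALUES (?, ?)", rows)
    conn.commit()
    return conn

def get_chunk(conn: sqlite3.Connection, chunk_id: str):
    """Fetch one chunk by id; only that row is read."""
    row = conn.execute(
        "SELECT payload FROM chunks WHERE chunk_id = ?", (chunk_id,)
    ).fetchone()
    return json.loads(row[0]) if row else None

lines = [
    '{"chunk_id": "c1", "text": "Article 21 ...", "metadata": {"article_number": "21"}}',
    '{"chunk_id": "c2", "text": "Section 13 ...", "metadata": {"section": "13"}}',
]
conn = build_lookup(lines)
print(get_chunk(conn, "c2")["text"])   # Section 13 ...
```

In the real server `db_path` would point at a file on the persistent volume, and the build step would run only when that file does not exist yet.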
2. VRAM Autocasting & Thread-Safe Real-time Streaming
- Autocasting: Inside the generation thread, both token lookup and model generation are wrapped in `torch.inference_mode()` and `torch.autocast(device_type="cuda", dtype=torch.bfloat16)`, which turns off autograd bookkeeping and keeps compute in bfloat16 to avoid memory spikes.
- ASGI Protection: Real-time token streaming is exposed via standard Server-Sent Events (SSE) at `/api/ask/stream`. Because LLM token generation is a long blocking call, running it directly inside an async FastAPI handler freezes the event loop and blocks all other users. I run generation in a separate native OS `Thread` that feeds a `TextIteratorStreamer`, and consume the tokens from a synchronous streaming generator, which FastAPI's thread pool handles safely without freezing the server.
- Strict EOS Enforcement: The system looks up the `<|im_end|>` and `<|endoftext|>` token IDs when the tokenizer loads and uses them as stop tokens, so generation ends at the close of the assistant turn instead of running on into an invented conversation.
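The threading pattern can be demonstrated without a model. In this sketch `fake_generate` stands in for the blocking `model.generate()` call and a `queue.Queue` plays the role of the streamer; both are simplifications, not the server's actual code.

```python
import queue
import threading

_DONE = object()   # sentinel marking the end of the stream

def fake_generate(prompt: str, out: queue.Queue) -> None:
    """Stand-in for model.generate(): pushes tokens as they are produced."""
    for token in ["Article", " 21", " guarantees", " life", "."]:
        out.put(token)
    out.put(_DONE)

def stream_tokens(prompt: str):
    """Synchronous generator: generation runs in a worker thread, tokens are yielded here."""
    out: queue.Queue = queue.Queue()
    worker = threading.Thread(target=fake_generate, args=(prompt, out), daemon=True)
    worker.start()
    while True:
        token = out.get()            # blocks this thread only, not an event loop
        if token is _DONE:
            break
        yield f"data: {token}\n\n"   # Server-Sent Events framing
    worker.join()

print("".join(stream_tokens("What does Article 21 guarantee?")))
```

Passing a plain (non-async) generator like this to a streaming response lets the framework iterate it in a thread pool, so the event loop stays free for other requests.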
3. Absolute Cost Safety
- The server class uses the `@app.cls(...)` decorator with `min_containers=0`.
- When there are no active users, the server scales down to zero active GPU containers, costing $0.00 in hosting fees.
- When a user query arrives, a cold start boots the image and models in ~10 seconds. Subsequent queries hit the warm container and skip that startup delay.
3E. Verification & End-to-End Test Results
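We verified both endpoints against our active server using test_endpoint.py. Both returned 200 OK with answers grounded in the retrieved sources, though retrieval is not perfect: the Hindu Marriage Act query below also pulled in Section 10 of the Divorce Act, 1869, which is a different statute.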
1. Blocking API Endpoint (/api/ask)
- Query: "What does Article 21 of the Indian Constitution guarantee?"
- Status: `200 OK`
- Total Latency: 5.34 seconds
- Generated Answer:
  > Article 21 guarantees the right to life and personal liberty. The Supreme Court has interpreted this right expansively, noting that it is not limited to mere survival but encompasses the right to live with dignity. This includes the right to privacy, which is viewed as an inalienable component of personal liberty. Additionally, Article 21 mandates that no person shall be deprived of life or personal liberty except according to procedure established by law, which must be fair, just, and reasonable.
- Sources Used:
  - `[SC_JUDGMENTS] Supreme Court: K.S. Puttaswamy v. Union of India (2017)`
  - `[SC_JUDGMENTS] Supreme Court: Common Cause v. Union of India (2017)`
  - `[SC_JUDGMENTS] Supreme Court: X v. Union of India (2023)`
2. Streaming API Endpoint (/api/ask/stream)
- Query: "What are the grounds for divorce under the Hindu Marriage Act?"
- Status: `200 OK`
- Stream Event 1 (Metadata Block):
  - `Source 1: [CENTRAL_ACTS] THE DIVORCE ACT, 1869 - Section 10`
  - `Source 2: [SC_JUDGMENTS] Supreme Court: Naveen Kohli v. Neelu Kohli (2006)`
  - `Source 3: [CENTRAL_ACTS] THE HINDU MARRIAGE ACT, 1955 - Section 13`
- Stream Event 2+ (Word-by-Word Tokens):
  > Under the Hindu Marriage Act, 1955, the grounds for divorce include: (i) adultery or voluntary sexual intercourse with another person; (ia) cruelty; (ib) desertion for a continuous period of not less than two years; (ii) conversion to another religion; (iii) incurable unsoundness of mind or mental disorder to the extent that cohabitation is unreasonable; (iv) virulent and incurable leprosy; (v) venereal disease in a communicable form; (vi) renunciation of the world by entering a religious order; and (vii) not being heard of as alive for seven years or more. Additionally, Section 13A provides grounds for divorce following a decree for judicial separation or restitution of conjugal rights where cohabitation or restitution has not been resumed for one year or upwards.
Project Structure
Nyay-ai/
├── data/
│ ├── constitution_of_india.json # Raw Constitution JSON
│ ├── prepare_constitution.py # Constitution → training text
│ ├── CentralActs/ # 858 act JSON files
│ ├── prepare_central_acts.py # Acts → training text
│ ├── sc_judgments/ # Downloaded SC judgment tars
│ ├── sc_metadata/ # Downloaded metadata tars
│ ├── download_sc_judgements.py # S3 downloader
│ ├── prepare_sc_judgments.py # PDFs → training text
│ ├── generate_synthetic_qa.py # Gemini-powered QA generation
│ ├── constitution_training.txt # 502 KB — processed Constitution
│ ├── central_acts_training.txt # 29.9 MB — processed Acts
│ ├── sc_judgments_training.txt # 1.49 GB — processed Judgments
│ └── legal_instruction_dataset.jsonl # 2.09 MB — 4K instruction pairs
├── rag_data/
│ ├── faiss_index.bin # 1.3 GB compiled FAISS vector index
│ ├── chunk_lookup.db # SQLite DB for 0ms startup seek
│ ├── chunk_lookup.json # 1.5 GB raw JSON backup
│ └── embedding_metadata.json # Embedding sharding stats
├── checkpoints/
│ ├── merged_model/ # Merged standalone Qwen-3 BF16 weights (7.5GB)
│ ├── lora_adapter.pth # 135 MB — final adapter weights
│ ├── finetune_training_log.json # Complete training log
│ └── plot_*.png # Training metric visualizations
├── chunk_legal_docs.py # Structure-aware legal document chunker
├── embed_modal.py # Modal Map-Reduce sharding embedding pipeline
├── merge_lora.py # Weights merging utility
├── rag_inference.py # FastAPI server with SQLite and thread-safe streaming
├── test_endpoint.py # Local E2E verification test suite
├── finetune_training.py # 677-line training pipeline
├── plot_finetune.py # Metric visualization script
├── utils.py # Shared utilities
├── BLOG.md # Complete project write-up
└── readme.md
Lessons Learned
1. Data Quality > Data Quantity
4,000 carefully structured instruction pairs, generated from real legal text with strict anti-hallucination prompting, taught the model more than I would expect from 50,000 sloppy pairs. The key was enforcing diversity in both task type (summaries, comparisons, Q&A, yes/no) and answer length (from one sentence to four paragraphs).
2. Standalone Merged Models are Faster and Cleaner
Merging the LoRA weights directly into the base parameters via merge_lora.py removed the inference-time adapter overhead, trimmed the memory footprint, and let the model load as a single standalone checkpoint at native speed.
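
Why merging removes the overhead can be shown on a toy example. The standard LoRA merge folds the low-rank update into the weight matrix, W' = W + (alpha / r) · B · A, so one matrix multiply replaces three. The sketch below uses plain Python lists and made-up 2x3 matrices; the function names and values are illustrative assumptions, not the contents of merge_lora.py.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(w, lora_a, lora_b, alpha, rank):
    """Return W + (alpha / rank) * B @ A, the standard LoRA merge."""
    scale = alpha / rank
    delta = matmul(lora_b, lora_a)
    return [
        [w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, delta)
    ]


# Toy shapes: W is 2x3, B is 2x1, A is 1x3 (rank 1).
W = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]
A = [[0.5, 1.0, -1.0]]
B = [[2.0], [4.0]]
x = [[1.0], [2.0], [3.0]]  # input as a column vector

merged = merge_lora(W, A, B, alpha=2.0, rank=1)

# Adapter path: base output plus the scaled low-rank correction.
base_out = matmul(W, x)
adapter_out = matmul(B, matmul(A, x))
unmerged = [[b[0] + 2.0 * a[0]] for b, a in zip(base_out, adapter_out)]

# Same output either way; the merged path needs a single matmul.
assert matmul(merged, x) == unmerged
```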
3. Bypass JSON in Production with SQLite
Loading large JSON configuration or lookup files (1.6GB+) at startup is a silent killer for cloud instances. Moving the lookup into a local, thread-safe SQLite database (chunk_lookup.db) mounted on a persistent volume dropped boot overhead from 2 minutes to 0.001 seconds while consuming 0 MB of startup RAM, because rows are read from disk on demand instead of being parsed into memory up front.
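
A minimal sketch of that lookup pattern with Python's built-in sqlite3 module follows. The table name `chunks`, its columns, and the helper names are assumptions for illustration; the real schema of chunk_lookup.db is not shown in this post.

```python
import os
import sqlite3
import tempfile


def build_lookup(db_path, chunks):
    """Write (chunk_id, text, source) rows once, offline."""
    conn = sqlite3.connect(db_path)
    conn.execute(
        "CREATE TABLE IF NOT EXISTS chunks ("
        "chunk_id INTEGER PRIMARY KEY, text TEXT NOT NULL, source TEXT NOT NULL)"
    )
    conn.executemany(
        "INSERT OR REPLACE INTO chunks (chunk_id, text, source) VALUES (?, ?, ?)",
        chunks,
    )
    conn.commit()
    conn.close()


def fetch_chunks(db_path, chunk_ids):
    """Fetch only the rows the vector search returned, in ranking order."""
    if not chunk_ids:
        return []
    conn = sqlite3.connect(db_path)  # one connection per call keeps threads isolated
    try:
        placeholders = ",".join("?" for _ in chunk_ids)
        rows = conn.execute(
            f"SELECT chunk_id, text, source FROM chunks WHERE chunk_id IN ({placeholders})",
            list(chunk_ids),
        ).fetchall()
    finally:
        conn.close()
    by_id = {row[0]: {"text": row[1], "source": row[2]} for row in rows}
    return [by_id[i] for i in chunk_ids if i in by_id]


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "chunk_lookup.db")
    build_lookup(path, [(0, "Article 21 ...", "CONSTITUTION"), (1, "Section 13 ...", "CENTRAL_ACTS")])
    print(fetch_chunks(path, [1, 0]))
```

Opening the database costs almost nothing because SQLite reads pages lazily; the primary-key index turns each lookup into a few page reads rather than a scan of the whole corpus.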
4. GPU Sharding for Rapid Large-Scale Embeddings
Embedding 354,000+ texts sequentially is a nightmare. Using Modal to orchestrate 32 L4 GPUs in parallel let me embed the entire dataset in ~20 minutes for under $4 of active cloud compute.
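
The map-reduce shape of that job can be sketched locally. In the sketch below a thread pool stands in for the remote GPU workers and `embed_shard` returns placeholder vectors, so it shows only the sharding and reassembly logic; the function names are hypothetical and this is not the code in embed_modal.py.

```python
from concurrent.futures import ThreadPoolExecutor


def make_shards(items, num_shards):
    """Split items into num_shards contiguous, near-equal slices."""
    size, extra = divmod(len(items), num_shards)
    shards, start = [], 0
    for i in range(num_shards):
        end = start + size + (1 if i < extra else 0)
        shards.append(items[start:end])
        start = end
    return shards


def embed_shard(texts):
    """Placeholder for a GPU worker: returns one fake vector per text."""
    return [[float(len(t))] for t in texts]


def embed_all(texts, num_shards=32):
    """Map shards to workers in parallel, then concatenate results in order."""
    shards = make_shards(texts, num_shards)
    with ThreadPoolExecutor(max_workers=num_shards) as pool:
        results = list(pool.map(embed_shard, shards))  # map preserves shard order
    return [vec for shard in results for vec in shard]


if __name__ == "__main__":
    vectors = embed_all(["Article 21", "Section 13", "Section 10"], num_shards=2)
    print(len(vectors))
```

With 354,293 chunks split 32 ways, each worker receives 11,071 or 11,072 texts. Keeping shards contiguous and reassembling them in order means row i of the final embedding matrix still corresponds to chunk i.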
5. Always Scale to Zero when Idle
For bootstrapped startups, keeping GPU endpoints active is an unnecessary expense. Leveraging min_containers=0 on serverless providers like Modal allows you to host fully functional, complex RAG prototypes completely free of charge when idle.
Technical Specs Summary
| Component | Specification |
|---|---|
| Base Model | Qwen-3 4B Instruct (2507) |
| Merged Model | Standalone bfloat16 weights (~7.5 GB) |
| Embedding Model | BAAI/bge-m3 (dense vector, FP16 precision) |
| FAISS Vector Index | IndexFlatIP (Cosine Similarity, 1024 dimensions) |
| Total Database Chunks | 354,293 chunks (1.6 GB corpus) |
| Lookup Engine | Thread-safe local SQLite database (chunk_lookup.db) |
| Server Framework | FastAPI (with SSE Server-Sent Events token streaming) |
| Concurrency Model | Native multi-thread worker with TextIteratorStreamer |
| API Endpoints | /api/ask (Blocking), /api/ask/stream (SSE Real-time Streaming) |
| Hosting Platform | Modal (Serverless GPU Cloud) |
| GPU Target | NVIDIA L4 (24GB VRAM) |
| Production Scale | min_containers=0 (Scales to zero when idle for $0.00/hr) |
| E2E Average Latency | 5.34 seconds for a full blocking answer / first tokens within milliseconds when streaming |
| Fine-Tuning Cost | < $5.00 |
| Embedding Cost | < $4.00 |
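
One detail behind the "IndexFlatIP (Cosine Similarity)" row: an inner-product index ranks by raw dot product, which equals cosine similarity only when both the stored vectors and the query are L2-normalized. The pure-Python sketch below shows that equivalence with a brute-force search over toy 2-dimensional vectors. It is an illustration of the idea, not FAISS itself, and the helper names are hypothetical.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm else list(vec)


def inner_product(a, b):
    return sum(x * y for x, y in zip(a, b))


def search(query, corpus, k=3):
    """Brute-force top-k by inner product over pre-normalized vectors."""
    q = l2_normalize(query)
    scored = [(inner_product(q, doc), idx) for idx, doc in enumerate(corpus)]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored[:k]


# Normalize once at index-build time, then again for each query.
corpus = [l2_normalize(v) for v in ([1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [-1.0, 0.0])]
top = search([2.0, 0.0], corpus, k=2)
```

Here the query `[2.0, 0.0]` has length 2, but after normalization its best match scores 1.0, the cosine of a zero-degree angle, regardless of the original magnitude.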
Follow This Project
NyayAI has moved from raw data to a working, optimized legal assistant. The prototype is complete, responsive, and ready to scale.
- GitHub: github.com/AshishRaj04/NyayAI-100M-Parameter-Legal-Foundation-Model
- Stack: PyTorch · Modal · Qwen-3 · FAISS · BGE-M3 · SQLite · FastAPI · Next.js
Built with obsession by a solo founder who believes every Indian deserves access to justice — and that the right AI can make that happen.





















