Otimização LoRA: Módulos Alvo para LLMs Eficientes
Introdução à Otimização de LLMs com LoRA
Ajustar um grande modelo de linguagem (LLM) para uma tarefa específica exige atualizações em bilhões de parâmetros ao longo de trilhões de tokens, com os custos inerentes em recursos de GPU e tempo. A Adaptação de Baixa Classificação (LoRA) surge como uma alternativa mais eficiente, que congela os pesos originais do modelo, mas introduz matrizes leves em submódulos específicos do modelo, ou "módulos".
Essas matrizes, comumente chamadas de "adaptadores", modificam os pesos dos módulos, possibilitando não apenas um ajuste fino eficiente, mas também:
- Serviço de modelo sob demanda, o que reduz drasticamente os custos de inferência.
- Compartilhamento do modelo base entre GPUs, diminuindo os requisitos de memória.
- Menor sobrecarga de download.
- Inferência paralela entre múltiplos adaptadores.
A questão central reside em onde inserir esses adaptadores no modelo. Empiricamente, direcionar mais módulos e módulos maiores tende a aumentar o desempenho, pois permite maior flexibilidade na personalização, mas também eleva os custos de treinamento e inferência. Utilizar um subconjunto menor e bem escolhido preserva a maioria dos ganhos com uma eficiência significativamente melhor.
Utilizando o LLM de raciocínio multimodal Nova 2.0 Lite da Amazon como nosso modelo base, estabelecemos o objetivo de identificar um subconjunto de configurações padronizadas de módulos alvo que funcionem eficazmente na vasta maioria dos casos de uso de clientes.
A Arquitetura Transformer
Modelos Transformer — os modelos responsáveis por todos os recentes e notáveis avanços da IA — consistem em grande parte de blocos repetidos múltiplas vezes. Cada bloco, por sua vez, possui dois componentes principais: um mecanismo de atenção, que determina a relevância dos tokens vistos anteriormente para o token atualmente processado, e uma rede feed-forward, uma rede neural convencional que realiza processamento adicional nas saídas do mecanismo de atenção.
O mecanismo de atenção envolve três matrizes diferentes, que recebem seus nomes do design de banco de dados:
- Matriz de consulta (query): Representa a relevância do token atual para os outros tokens na sequência de entrada.
- Matriz de chave (key): Representa a relevância de outros tokens entre si.
- Matriz de valor (value): Representa o conteúdo bruto desses outros tokens.
A multiplicação das três matrizes juntas cria, essencialmente, uma receita para a próxima saída do Transformer. Para reduzir a complexidade computacional, essas multiplicações ocorrem em um espaço com dimensões reduzidas. As próprias matrizes e os resultados de sua multiplicação devem então ser projetados de volta para as dimensões originais da entrada.
LoRA aproxima as atualizações de peso usando um produto de duas matrizes menores, reduzindo drasticamente o número de parâmetros treináveis. A técnica é tipicamente aplicada às camadas de projeção de atenção e às camadas de rede feed-forward. Esses módulos são candidatos ideais porque constituem a maior parte dos parâmetros do Transformer, governam diretamente o aprendizado de representação e exibem alinhamento natural com aproximações de baixa classificação. Evidências empíricas mostram que as mudanças de peso nessas camadas frequentemente se encontram dentro de um subespaço de baixa dimensão durante o ajuste fino.
Seleção de Módulos Alvo
A seleção dos módulos alvo corretos afeta diretamente a precisão, latência e eficiência computacional. A escolha ótima dos módulos alvo é principalmente uma função de (a) o modelo base que está sendo ajustado (ou seja, sua arquitetura, distribuições de dados pré e pós-treinamento, etc.) e (b) o domínio/modalidade de personalização.
Ao ajustar o Nova 2.0 Lite, equilibramos dois objetivos concorrentes:
- Maximizar a precisão em diversas tarefas e modalidades.
- Minimizar a latência para preservar os benefícios de eficiência do LoRA.
Investigamos a aplicação de LoRA em quatro módulos diferentes em cada bloco Transformer: as camadas de projeção de consulta, chave e valor (qkv); a camada o_proj; e duas camadas totalmente conectadas diferentes na rede feed-forward, gate_up_proj e gate_down_proj (referidas como fc1 e fc2). Abaixo estão as compensações para esses módulos, tanto individualmente quanto em combinação, com base em resultados publicados na literatura e estudos empíricos.
| Combination | Expected accuracy | Expected latency | Use case | | :------------------------- | :---------------------------------- | :----------------------- | :------- | | qkv only | Good (baseline) | Lowest | | | o_proj only | Moderate | Lowest | | | qkv + o_proj | High | Low to moderate (+5–10%) | | | qkv + fc1 / fc2 | Very high (close to full fine tuning) | Moderate (+10–15%) | | | o_proj + fc1 / fc2 | Good to high | Moderate (+5–10%) | | | qkv + o_proj + fc1 / fc2 | Highest (near-full fine tuning) | High (+15–20%) | | | All modules | Maximum | Highest (+20–25%) | |
Metodologia Experimental
Conduzimos um estudo de ablação abrangente, treinando múltiplas variantes de SFT (Supervised Fine-Tuning) LoRA em sete conjuntos de dados abrangendo dados de texto e visuais, em tarefas de raciocínio (ou seja, os próprios conjuntos de dados de treinamento incluem conteúdo de raciocínio) e não-raciocínio. Os conjuntos de dados cobriram diversos desafios, desde simples perguntas e respostas até sumarização de contexto longo e extração estruturada de JSON.
| Dataset | Modality | Reasoning traces | Domain | Tasks | Training size | Eval size | Eval metric | Source | | :---------- | :------- | :--------------- | :--------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------ | :-------- | :------------------- | :----- | | FinCOT | Txt | Yes | Finance | Financial-reasoning dataset. Samples consist of complex financial queries, along with reasoning traces obtained from GPT-4o. Predictions are typically complex tables or calculations based on the input. | 7436 | 1147 | Accuracy | | | GovReport | Txt | No | Goverment Doc | Large-context (30-40K tokens) summarization | 17457 | 837 | RougeLsum | | | MedMCQA | Txt | No | Medical | Dataset for multiple-choice QA — also used in Nova 1.0 | 20k | 3683 | Accuracy | | | MedReason | Txt | Yes | Medical | Medical-reasoning dataset that consists of questions and answers compiled from various medical benchmarks (MedQA, MedMCQA, etc.), along with synthetic, high-quality reasoning traces. (This uses the same eval set as MedMCQA.) | 31682 | 3683 | Accuracy | | | CoCoHD | Txt | No | Political Doc | A complex benchmark consisting of large-context (>20K tokens) transcripts of congressional hearings. The output is expected to be a summary in a specific JSON format, consisting of the members present, topic discussed, outcomes, etc. | 732 | 1053 | Averaged key and value match rate | | | Llava-COT | Image | Yes | Image understanding, General/Science | Multimodal, image benchmark consisting of Q&A reasoning questions. The dataset includes high-quality reasoning traces. | 10k | 270 | Exact match rate | | | Invoice OCR | Image | No | Image understanding | OCR benchmark that takes an input image and produces a JSON file with fields from the image. | 1400 | 447 | Accuracy | |
Todos os experimentos usaram o checkpoint de disponibilidade geral do Nova 2.0 Lite com hiperparâmetros consistentes em todos os módulos alvo, incluindo a taxa de aprendizado e os valores alfa.
| Target dataset | Setting | SFT LoRA target performance | Nova 2.0 Lite performance | | :------------- | :---------------- | :-------------------------- | :------------------------ | | Fin-COT | qkv | 67.09% | 72.12% | | | o_proj | 68.30% | | | | fc1 | 75.35% | | | | fc2 | 60.24% | | | | o_proj + fc1 | 61.38% | | | | qkv + fc2 | 60.31% | | | | o_proj + fc2 | 62.79% | | | | qkv + fc1 | 68.37% | | | | All target modules | 66.15% | | | CoCoHD | qkv | 19.64% | 45.14% | | | o_proj | 65.88% | | | | fc1 | 41.96% | | | | fc2 | 17.62% | | | | o_proj + fc1 | 76.83% | | | | qkv + fc2 | 66.47% | | | | o_proj + fc2 | 79.14% | | | | qkv + fc1 | 45.45% | | | | All target modules | 82.75% | | | GovReport | o_proj | 41.25% | 38.90% | | | fc1 | 39.69% | | | | o_proj + fc1 | 41.74% | | | | o_proj + fc2 | 42.16% | | | | qkv + fc1 | 41.66% | | | | qkv + fc2 | 39.02% | | | | All target modules | 41.95% | | | Llava-COT | qkv | 64.26% | 16.22% | | | o_proj | 64.26% | | | | fc1 | 65.92% | | | | fc2 | 65.02% | | | | o_proj + fc1 | 63.21% | | | | qkv + fc2 | 62.76% | | | | o_proj + fc2 | 66.37% | | | | qkv + fc1 | 66.52% | | | | All target modules | 63.96% | | | Invoice OCR | o_proj | 89.07% | 14.10% | | | o_proj + fc1 | 90.03% | | | | qkv + fc2 | 87.84% | | | | o_proj + fc2 | 89.47% | | | | qkv + fc1 | 88.55% | | | | All target modules | 90.11% | | | MedReason | o_proj | 24.55% | 1.68% | | | o_proj + fc1 | 20.88% | | | | qkv + fc2 | 8.39% | | | | o_proj + fc2 | 20.36% | | | | qkv + fc1 | 4.32% | | | | All target modules | 26.72% | | | MedMCQA | qkv | 62.18% | 1.68% | | | o_proj | 63.10% | | | | fc1 | 12.90% | | | | fc2 | 59.98% | | | | o_proj + fc1 | 61.39% | | | | qkv + fc2 | 65.63% | | | | o_proj + fc2 | 64.95% | | | | qkv + fc1 | 57.21% | | | | All target modules | 66.11% | |
Principais Descobertas
Nosso estudo de ablação revelou insights cruciais para a seleção de módulos alvo em LoRA:
- O
o_projé o alvo único mais robusto: A configuraçãoo_proj-only demonstrou notável consistência, nunca falhando abertamente em nenhuma tarefa e geralmente performando dentro de poucos pontos percentuais da melhor configuração (ou seja, usando todos os módulos alvo). Em MedMCQA, CoCoHD, GovReport, LLaVA-CoT e Invoice OCR,o_proj-only igualou ou chegou muito perto do desempenho ideal, tornando-o uma escolha padrão atraente que equilibra desempenho e simplicidade. Há evidências emergentes de que este módulo desempenha um papel fundamental no raciocínio, o que pode explicar sua eficácia aqui. qkv-only mostra instabilidade: Emboraqkv-only tenha se saído bem em MedMCQA, ele exibiu extrema variabilidade, performando abaixo da linha de base em CoCoHD e mostrando resultados pouco notáveis em outros lugares. Isso se alinha com a hipótese de que o LoRA apenas de atenção pode subajustar em tarefas que exigem recursos mais ricos da rede feed-forward, em vez de depender de roteamento de token modificado.- Combinações de módulos fornecem ganhos modestos: Combinações como
o_proj + fc2ou "todos os módulos alvo" frequentemente alcançaram as maiores pontuações por conjunto de dados (particularmente em CoCoHD, MedReason e Invoice OCR). No entanto, as melhorias em relação ao melhor módulo único foram tipicamente modestas, geralmente de 1-3 pontos percentuais. - A dificuldade da tarefa amplifica o impacto da configuração: Em benchmarks desafiadores onde o modelo base teve um desempenho ruim, a escolha dos módulos alvo teve um impacto maior. Por exemplo, em CoCoHD (contexto longo, geração complexa de JSON),
o_proj + fc2alcançou uma melhoria absoluta de +15% em relação ao modelo base, em comparação com apenas +3% como_projsozinho. - LoRA supera consistentemente os modelos base: Em quase todos os conjuntos de dados, qualquer configuração razoável de LoRA superou dramaticamente o modelo base. Por exemplo, MedReason, MedMCQA, LLaVA-CoT e Invoice OCR mostraram melhorias de uma precisão de linha de base de ~1-16% para 60-90%+ com LoRA. A exceção notável foi Fin-COT, onde apenas certas configurações (notavelmente fc1) excederam o desempenho da linha de base, sugerindo sensibilidade específica da tarefa à estratégia de adaptação.
Recomendações
Com base em nossas descobertas, oferecemos as seguintes recomendações para a seleção de módulos alvo LoRA:
- Para cenários priorizando a precisão: Recomendamos
o_proj + fc2como a configuração ideal para tarefas de texto e multimodais, mostrando melhorias de 2-12% sobreo_projsozinho em benchmarks. - Para eficiência e desempenho equilibrados:
o_proj-only oferece um excelente padrão, proporcionando desempenho robusto com sobrecarga mínima de latência — particularmente valioso ao servir múltiplos adaptadores ou operar sob restrições de recursos. - Para tarefas desafiadoras: Tais como benchmarks com contexto longo ou requisitos de geração complexa, ou outras tarefas onde os modelos base lutam, a precisão adicional de
o_proj + fc2justifica o modesto aumento de latência.
Direções Futuras
Nossa pesquisa abre várias avenidas promissoras para otimizações adicionais:
- Configurações específicas de modalidade e tarefa: Segmentar a seleção de módulos alvo por modalidade e dificuldade da tarefa (por exemplo, cenários de contexto longo) poderia render configurações especializadas com melhores compensações de precisão-latência.
- Otimização de hiperparâmetros por módulo: Uma extensa otimização de hiperparâmetros para cada configuração de módulo alvo poderia desbloquear ganhos de desempenho adicionais, embora os custos computacionais permaneçam uma consideração.
- LoRA de dois estágios para identificação precoce de candidatos: Alavancar abordagens de LoRA de dois estágios que usam dinâmicas de treinamento, gradientes, etc., para determinar a importância de diferentes módulos/camadas poderia ajudar a identificar configurações promissoras no início do treinamento, reduzindo o custo de buscas abrangentes de hiperparâmetros.
- Poda de camadas para redução de latência: Usar treinamento de dois estágios para identificar e podar camadas não utilizadas poderia reduzir ainda mais a latência de inferência, mantendo a precisão.
Neste estudo abrangente, demonstramos que a seleção criteriosa de módulos alvo no ajuste fino de LoRA pode aprimorar a precisão enquanto preserva as vantagens de eficiência que tornam o LoRA atraente para implantações em produção. A camada o_proj emerge como um alvo único notavelmente robusto, enquanto as combinações o_proj + fc2 oferecem a melhor precisão para tarefas desafiadoras. Em média, o LoRA com o_proj está dentro de 2% de o_proj + fc2 em termos de precisão, mas apresenta 22,6% menos latência (o TPOT p95 diminui de 10.085ms para 7.803ms). Essas descobertas fornecem uma base sólida para padronizar as configurações de LoRA em diversos casos de uso de clientes, equilibrando as demandas concorrentes de desempenho do modelo e eficiência computacional.
Aguardando Login...