Engenharia

Otimização LoRA: Módulos Alvo para LLMs Eficientes

Introdução à Otimização de LLMs com LoRA

Ajustar um grande modelo de linguagem (LLM) para uma tarefa específica exige atualizações em bilhões de parâmetros ao longo de trilhões de tokens, com os custos inerentes em recursos de GPU e tempo. A Adaptação de Baixa Classificação (LoRA) surge como uma alternativa mais eficiente, que congela os pesos originais do modelo, mas introduz matrizes leves em submódulos específicos do modelo, ou "módulos".

Essas matrizes, comumente chamadas de "adaptadores", modificam os pesos dos módulos, possibilitando não apenas um ajuste fino eficiente, mas também:

A questão central reside em onde inserir esses adaptadores no modelo. Empiricamente, direcionar mais módulos e módulos maiores tende a aumentar o desempenho, pois permite maior flexibilidade na personalização, mas também eleva os custos de treinamento e inferência. Utilizar um subconjunto menor e bem escolhido preserva a maioria dos ganhos com uma eficiência significativamente melhor.

Utilizando o LLM de raciocínio multimodal Nova 2.0 Lite da Amazon como nosso modelo base, estabelecemos o objetivo de identificar um subconjunto de configurações padronizadas de módulos alvo que funcionem eficazmente na vasta maioria dos casos de uso de clientes.

A Arquitetura Transformer

Modelos Transformer — os modelos responsáveis por todos os recentes e notáveis avanços da IA — consistem em grande parte de blocos repetidos múltiplas vezes. Cada bloco, por sua vez, possui dois componentes principais: um mecanismo de atenção, que determina a relevância dos tokens vistos anteriormente para o token atualmente processado, e uma rede feed-forward, uma rede neural convencional que realiza processamento adicional nas saídas do mecanismo de atenção.

O mecanismo de atenção envolve três matrizes diferentes, que recebem seus nomes do design de banco de dados:

A multiplicação das três matrizes juntas cria, essencialmente, uma receita para a próxima saída do Transformer. Para reduzir a complexidade computacional, essas multiplicações ocorrem em um espaço com dimensões reduzidas. As próprias matrizes e os resultados de sua multiplicação devem então ser projetados de volta para as dimensões originais da entrada.

LoRA aproxima as atualizações de peso usando um produto de duas matrizes menores, reduzindo drasticamente o número de parâmetros treináveis. A técnica é tipicamente aplicada às camadas de projeção de atenção e às camadas de rede feed-forward. Esses módulos são candidatos ideais porque constituem a maior parte dos parâmetros do Transformer, governam diretamente o aprendizado de representação e exibem alinhamento natural com aproximações de baixa classificação. Evidências empíricas mostram que as mudanças de peso nessas camadas frequentemente se encontram dentro de um subespaço de baixa dimensão durante o ajuste fino.

Seleção de Módulos Alvo

A seleção dos módulos alvo corretos afeta diretamente a precisão, latência e eficiência computacional. A escolha ótima dos módulos alvo é principalmente uma função de (a) o modelo base que está sendo ajustado (ou seja, sua arquitetura, distribuições de dados pré e pós-treinamento, etc.) e (b) o domínio/modalidade de personalização.

Ao ajustar o Nova 2.0 Lite, equilibramos dois objetivos concorrentes:

Investigamos a aplicação de LoRA em quatro módulos diferentes em cada bloco Transformer: as camadas de projeção de consulta, chave e valor (qkv); a camada o_proj; e duas camadas totalmente conectadas diferentes na rede feed-forward, gate_up_proj e gate_down_proj (referidas como fc1 e fc2). Abaixo estão as compensações para esses módulos, tanto individualmente quanto em combinação, com base em resultados publicados na literatura e estudos empíricos.

| Combination | Expected accuracy | Expected latency | Use case | | :------------------------- | :---------------------------------- | :----------------------- | :------- | | qkv only | Good (baseline) | Lowest | | | o_proj only | Moderate | Lowest | | | qkv + o_proj | High | Low to moderate (+5–10%) | | | qkv + fc1 / fc2 | Very high (close to full fine tuning) | Moderate (+10–15%) | | | o_proj + fc1 / fc2 | Good to high | Moderate (+5–10%) | | | qkv + o_proj + fc1 / fc2 | Highest (near-full fine tuning) | High (+15–20%) | | | All modules | Maximum | Highest (+20–25%) | |

Metodologia Experimental

Conduzimos um estudo de ablação abrangente, treinando múltiplas variantes de SFT (Supervised Fine-Tuning) LoRA em sete conjuntos de dados abrangendo dados de texto e visuais, em tarefas de raciocínio (ou seja, os próprios conjuntos de dados de treinamento incluem conteúdo de raciocínio) e não-raciocínio. Os conjuntos de dados cobriram diversos desafios, desde simples perguntas e respostas até sumarização de contexto longo e extração estruturada de JSON.

| Dataset | Modality | Reasoning traces | Domain | Tasks | Training size | Eval size | Eval metric | Source | | :---------- | :------- | :--------------- | :--------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------ | :-------- | :------------------- | :----- | | FinCOT | Txt | Yes | Finance | Financial-reasoning dataset. Samples consist of complex financial queries, along with reasoning traces obtained from GPT-4o. Predictions are typically complex tables or calculations based on the input. | 7436 | 1147 | Accuracy | | | GovReport | Txt | No | Goverment Doc | Large-context (30-40K tokens) summarization | 17457 | 837 | RougeLsum | | | MedMCQA | Txt | No | Medical | Dataset for multiple-choice QA — also used in Nova 1.0 | 20k | 3683 | Accuracy | | | MedReason | Txt | Yes | Medical | Medical-reasoning dataset that consists of questions and answers compiled from various medical benchmarks (MedQA, MedMCQA, etc.), along with synthetic, high-quality reasoning traces. (This uses the same eval set as MedMCQA.) | 31682 | 3683 | Accuracy | | | CoCoHD | Txt | No | Political Doc | A complex benchmark consisting of large-context (>20K tokens) transcripts of congressional hearings. The output is expected to be a summary in a specific JSON format, consisting of the members present, topic discussed, outcomes, etc. | 732 | 1053 | Averaged key and value match rate | | | Llava-COT | Image | Yes | Image understanding, General/Science | Multimodal, image benchmark consisting of Q&A reasoning questions. The dataset includes high-quality reasoning traces. | 10k | 270 | Exact match rate | | | Invoice OCR | Image | No | Image understanding | OCR benchmark that takes an input image and produces a JSON file with fields from the image. | 1400 | 447 | Accuracy | |

Todos os experimentos usaram o checkpoint de disponibilidade geral do Nova 2.0 Lite com hiperparâmetros consistentes em todos os módulos alvo, incluindo a taxa de aprendizado e os valores alfa.

| Target dataset | Setting | SFT LoRA target performance | Nova 2.0 Lite performance | | :------------- | :---------------- | :-------------------------- | :------------------------ | | Fin-COT | qkv | 67.09% | 72.12% | | | o_proj | 68.30% | | | | fc1 | 75.35% | | | | fc2 | 60.24% | | | | o_proj + fc1 | 61.38% | | | | qkv + fc2 | 60.31% | | | | o_proj + fc2 | 62.79% | | | | qkv + fc1 | 68.37% | | | | All target modules | 66.15% | | | CoCoHD | qkv | 19.64% | 45.14% | | | o_proj | 65.88% | | | | fc1 | 41.96% | | | | fc2 | 17.62% | | | | o_proj + fc1 | 76.83% | | | | qkv + fc2 | 66.47% | | | | o_proj + fc2 | 79.14% | | | | qkv + fc1 | 45.45% | | | | All target modules | 82.75% | | | GovReport | o_proj | 41.25% | 38.90% | | | fc1 | 39.69% | | | | o_proj + fc1 | 41.74% | | | | o_proj + fc2 | 42.16% | | | | qkv + fc1 | 41.66% | | | | qkv + fc2 | 39.02% | | | | All target modules | 41.95% | | | Llava-COT | qkv | 64.26% | 16.22% | | | o_proj | 64.26% | | | | fc1 | 65.92% | | | | fc2 | 65.02% | | | | o_proj + fc1 | 63.21% | | | | qkv + fc2 | 62.76% | | | | o_proj + fc2 | 66.37% | | | | qkv + fc1 | 66.52% | | | | All target modules | 63.96% | | | Invoice OCR | o_proj | 89.07% | 14.10% | | | o_proj + fc1 | 90.03% | | | | qkv + fc2 | 87.84% | | | | o_proj + fc2 | 89.47% | | | | qkv + fc1 | 88.55% | | | | All target modules | 90.11% | | | MedReason | o_proj | 24.55% | 1.68% | | | o_proj + fc1 | 20.88% | | | | qkv + fc2 | 8.39% | | | | o_proj + fc2 | 20.36% | | | | qkv + fc1 | 4.32% | | | | All target modules | 26.72% | | | MedMCQA | qkv | 62.18% | 1.68% | | | o_proj | 63.10% | | | | fc1 | 12.90% | | | | fc2 | 59.98% | | | | o_proj + fc1 | 61.39% | | | | qkv + fc2 | 65.63% | | | | o_proj + fc2 | 64.95% | | | | qkv + fc1 | 57.21% | | | | All target modules | 66.11% | |

Principais Descobertas

Nosso estudo de ablação revelou insights cruciais para a seleção de módulos alvo em LoRA:

Recomendações

Com base em nossas descobertas, oferecemos as seguintes recomendações para a seleção de módulos alvo LoRA:

Direções Futuras

Nossa pesquisa abre várias avenidas promissoras para otimizações adicionais:

Neste estudo abrangente, demonstramos que a seleção criteriosa de módulos alvo no ajuste fino de LoRA pode aprimorar a precisão enquanto preserva as vantagens de eficiência que tornam o LoRA atraente para implantações em produção. A camada o_proj emerge como um alvo único notavelmente robusto, enquanto as combinações o_proj + fc2 oferecem a melhor precisão para tarefas desafiadoras. Em média, o LoRA com o_proj está dentro de 2% de o_proj + fc2 em termos de precisão, mas apresenta 22,6% menos latência (o TPOT p95 diminui de 10.085ms para 7.803ms). Essas descobertas fornecem uma base sólida para padronizar as configurações de LoRA em diversos casos de uso de clientes, equilibrando as demandas concorrentes de desempenho do modelo e eficiência computacional.

Comentários

Interações
Seu Perfil

Aguardando Login...