All files / src/model-router index.ts

100% Statements 55/55
100% Branches 23/23
100% Functions 6/6
100% Lines 55/55

Press n or j to go to the next uncovered block, b, p or k for the previous block.

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128                                                                                                                    1x 1x 1x 1x 1x 1x 1x 1x 1x 1x 1x 1x       53x 53x 53x 53x 53x 53x 53x 53x       1x 18x 18x 18x   18x 18x   13x 5x 1x 1x 5x 3x 3x 5x     9x 9x 9x   9x 9x 9x 9x 9x 9x 13x   18x 2x 2x   18x 3x 3x   18x 5x 5x 18x 18x  
/**
 * AgentKits — Model Router Module
 *
 * Auto-select the best model based on cost/speed/quality preference.
 * Configurable routing rules.
 *
 * Usage:
 *   import { createModelRouter } from 'agentkits/model-router';
 *   const router = createModelRouter({ preference: 'balanced' });
 *   const model = router.select('Translate this sentence');
 */
 
// ── Types ──────────────────────────────────────────────────────────
 
/** Strategy the router optimizes for when scoring models. */
export type RoutingPreference = 'cost' | 'speed' | 'quality' | 'balanced';

/** Static description of one model, used for preference-based scoring. */
export interface ModelProfile {
  /** Model identifier as passed to the provider API (e.g. 'gpt-4o'). */
  id: string;
  /** Provider key (e.g. 'openai', 'gemini'). */
  provider: string;
  /** Relative cost score 1-10 (1=cheapest) */
  costScore: number;
  /** Relative speed score 1-10 (10=fastest — DEFAULT_MODELS assigns 9 to the fastest 'mini'/'flash' models, 6 to the slowest) */
  speedScore: number;
  /** Relative quality score 1-10 (10=best) */
  qualityScore: number;
  /** Max context length */
  maxContext?: number;
}

/** Declarative override checked before scoring; first matching rule wins. */
export interface RoutingRule {
  /** If prompt length exceeds this, use longContext model */
  maxPromptLength?: number;
  /** Keywords that trigger a specific model */
  keywords?: string[];
  /** Model to use when rule matches */
  model: string;
  /** Provider of the model named by this rule. */
  provider: string;
}

/** Options accepted by createModelRouter; every field is optional. */
export interface ModelRouterConfig {
  /** Scoring strategy; defaults to 'balanced'. */
  preference?: RoutingPreference;
  /** Model pool to score; defaults to DEFAULT_MODELS. */
  models?: ModelProfile[];
  /** Rules evaluated in order before any scoring happens. */
  rules?: RoutingRule[];
}

/** Public surface returned by createModelRouter. */
export interface ModelRouter {
  /** Select best model for given prompt */
  select(prompt: string): { model: string; provider: string; reason: string };
  /** Add a model profile */
  addModel(profile: ModelProfile): void;
  /** Add a routing rule */
  addRule(rule: RoutingRule): void;
  /** List all known models */
  listModels(): ModelProfile[];
}
 
// ── Default model profiles ─────────────────────────────────────────
 
/**
 * Built-in model pool used when ModelRouterConfig.models is not supplied.
 * Score conventions: costScore 1=cheapest, qualityScore 10=best, and
 * speedScore 10=fastest (the fast 'mini'/'flash' tiers score 9 here).
 * maxContext is the provider-advertised context window in tokens.
 */
const DEFAULT_MODELS: ModelProfile[] = [
  { id: 'gpt-4o',           provider: 'openai',    costScore: 6, speedScore: 7, qualityScore: 9, maxContext: 128000 },
  { id: 'gpt-4o-mini',      provider: 'openai',    costScore: 2, speedScore: 9, qualityScore: 7, maxContext: 128000 },
  { id: 'gemini-2.5-flash', provider: 'gemini',    costScore: 2, speedScore: 9, qualityScore: 8, maxContext: 1000000 },
  { id: 'deepseek-chat',    provider: 'deepseek',  costScore: 1, speedScore: 7, qualityScore: 8, maxContext: 64000 },
  { id: 'qwen-max',         provider: 'dashscope', costScore: 3, speedScore: 7, qualityScore: 8, maxContext: 32000 },
  { id: 'glm-4-plus',       provider: 'zhipu',     costScore: 3, speedScore: 7, qualityScore: 7, maxContext: 128000 },
  { id: 'grok-3',           provider: 'grok',      costScore: 5, speedScore: 7, qualityScore: 8, maxContext: 131072 },
  { id: 'command-r-plus',   provider: 'cohere',    costScore: 4, speedScore: 6, qualityScore: 8, maxContext: 128000 },
  { id: 'moonshot-v1-auto', provider: 'moonshot',  costScore: 2, speedScore: 7, qualityScore: 7, maxContext: 128000 },
  { id: 'MiniMax-Text-01',  provider: 'minimax',   costScore: 2, speedScore: 7, qualityScore: 7, maxContext: 1000000 },
];
 
// ── Scoring ────────────────────────────────────────────────────────
 
/**
 * Compute a composite desirability score for a model under the given
 * routing preference. Higher is better.
 *
 * Score conventions (see ModelProfile / DEFAULT_MODELS): costScore
 * 1=cheapest, speedScore 10=fastest, qualityScore 10=best.
 */
function scoreModel(m: ModelProfile, pref: RoutingPreference): number {
  switch (pref) {
    // Cheap models (low costScore) are rewarded; quality is a tiebreaker.
    case 'cost':     return (10 - m.costScore) * 3 + m.qualityScore;
    // BUG FIX: DEFAULT_MODELS uses 10=fastest, so a faster model has a
    // HIGHER speedScore. The old `(10 - m.speedScore)` inverted this and
    // ranked the slowest models first under the 'speed' preference.
    case 'speed':    return m.speedScore * 3 + m.qualityScore;
    case 'quality':  return m.qualityScore * 3 + (10 - m.costScore);
    // Same inversion fix for the speed component of 'balanced'.
    case 'balanced': return m.qualityScore * 2 + (10 - m.costScore) + m.speedScore;
  }
}
 
// ── Factory ────────────────────────────────────────────────────────
 
/**
 * Create a model router.
 *
 * Routing order: explicit rules (in insertion order) are checked first;
 * if none match, every known model is scored against the configured
 * preference and the highest-scoring model wins.
 *
 * @param config - Optional preference, model pool, and rules.
 * @returns A ModelRouter; its select() throws if the model pool is empty
 *          and no rule matched.
 */
export function createModelRouter(config: ModelRouterConfig = {}): ModelRouter {
  const preference = config.preference ?? 'balanced';
  // Defensive copies so addModel/addRule never mutate caller-owned arrays.
  const models = [...(config.models ?? DEFAULT_MODELS)];
  const rules = [...(config.rules ?? [])];

  return {
    select(prompt: string): { model: string; provider: string; reason: string } {
      // Hoisted out of the loop: previously recomputed per keyword check.
      const promptLower = prompt.toLowerCase();

      // Rules take precedence over preference-based scoring.
      for (const rule of rules) {
        // `!= null` rather than truthiness so a limit of 0 is honored.
        if (rule.maxPromptLength != null && prompt.length > rule.maxPromptLength) {
          return { model: rule.model, provider: rule.provider, reason: `Prompt exceeds ${rule.maxPromptLength} chars` };
        }
        // Keyword match is case-insensitive on both sides.
        if (rule.keywords?.some(kw => promptLower.includes(kw.toLowerCase()))) {
          return { model: rule.model, provider: rule.provider, reason: `Matched keyword rule` };
        }
      }

      // BUG FIX: an empty model pool previously crashed with
      // "Cannot read properties of undefined (reading 'id')" on scored[0].
      if (models.length === 0) {
        throw new Error('Model router has no models to select from');
      }

      // Score every model under the configured preference; highest wins.
      const scored = models
        .map(m => ({ ...m, score: scoreModel(m, preference) }))
        .sort((a, b) => b.score - a.score);

      const best = scored[0];
      return {
        model: best.id,
        provider: best.provider,
        reason: `Best ${preference} score (${best.score.toFixed(1)})`,
      };
    },

    addModel(profile: ModelProfile) {
      models.push(profile);
    },

    addRule(rule: RoutingRule) {
      rules.push(rule);
    },

    listModels(): ModelProfile[] {
      // Copy so callers cannot mutate the internal pool.
      return [...models];
    },
  };
}