All files / src/models ModelRegistry.ts

93.54% Statements 29/31
84.61% Branches 33/39
100% Functions 10/10
93.54% Lines 29/31

Press n or j to go to the next uncovered block, b, p or k for the previous block.

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114          85x     10839x 2633160x               6x   6x 527x 188081x   527x 59x   468x                 6x             962x 962x                             941x 941x                                 51x 51x 5x     46x 46x 51x         51x     51x   51x       51x 51x     51x   51x   51x                
import { Model } from "./types.js";
import { modelsData } from "./models.js";
import { PricingRegistry } from "./PricingRegistry.js";
 
export class ModelRegistry {
  /**
   * Backing store for every known model, seeded from the bundled `modelsData`.
   *
   * NOTE(review): this is the imported `modelsData` array itself (not a copy), so `save()`
   * also mutates that module export. The double cast means the data is never checked
   * against `Model` at compile time — treat fields as possibly missing at runtime.
   */
  private static models: Model[] = modelsData as unknown as Model[];

  /**
   * Look up a model by id.
   *
   * @param modelId - Model id; compared case-insensitively.
   * @param provider - Optional provider filter, compared exactly. `undefined` or an empty
   *   string matches any provider.
   * @returns The first matching model, or `undefined` when none matches.
   */
  static find(modelId: string, provider?: string): Model | undefined {
    // Normalize the needle once instead of on every comparison.
    const wantedId = modelId.toLowerCase();
    return this.models.find(
      (m) => m.id.toLowerCase() === wantedId && (!provider || m.provider === provider)
    );
  }

  /**
   * Add or update models in the registry.
   *
   * An incoming model replaces an existing entry with the same provider and the same id
   * (case-insensitive, exactly like `find`); otherwise it is appended.
   *
   * @param models - A single model or a list of models to upsert, applied in order.
   */
  static save(models: Model | Model[]): void {
    const toAdd = Array.isArray(models) ? models : [models];

    for (const newModel of toAdd) {
      // Match ids the same way `find` does. A case-sensitive match here would append a
      // duplicate that `find` never returns, so the update would silently not take effect.
      const wantedId = newModel.id.toLowerCase();
      const index = this.models.findIndex(
        (m) => m.id.toLowerCase() === wantedId && m.provider === newModel.provider
      );
      if (index >= 0) {
        this.models[index] = newModel;
      } else {
        this.models.push(newModel);
      }
    }
  }

  /**
   * Get all available models.
   *
   * @returns The registry's live backing array (not a copy); mutating it mutates the registry.
   */
  static all(): Model[] {
    return this.models;
  }

  /**
   * Get the output-token limit for a model.
   *
   * @param modelId - Model id (case-insensitive).
   * @param provider - Provider the model must belong to.
   * @returns `max_output_tokens`, or `undefined` if the model is unknown or the value is
   *   absent. The `?? undefined` also folds a `null` in the raw data into `undefined`.
   */
  static getMaxOutputTokens(modelId: string, provider: string): number | undefined {
    const model = this.find(modelId, provider);
    return model?.max_output_tokens ?? undefined;
  }

  /**
   * Check if a model supports a capability.
   *
   * @param modelId - Model id (case-insensitive).
   * @param capability - Capability name, compared exactly against the model's list.
   * @param provider - Provider the model must belong to.
   * @returns `true` only when the model exists and lists the capability.
   */
  static supports(modelId: string, capability: string, provider: string): boolean {
    const model = this.find(modelId, provider);
    // `capabilities` is optional-chained because the data is cast, not validated; an entry
    // without it should report "unsupported" rather than throw.
    return model?.capabilities?.includes(capability) ?? false;
  }

  /**
   * Get context window size.
   *
   * @param modelId - Model id (case-insensitive).
   * @param provider - Provider the model must belong to.
   * @returns `context_window`, or `undefined` if the model is unknown or the value is
   *   absent (a `null` in the raw data is also returned as `undefined`).
   */
  static getContextWindow(modelId: string, provider: string): number | undefined {
    const model = this.find(modelId, provider);
    return model?.context_window ?? undefined;
  }

  /**
   * Calculate cost for usage.
   *
   * Prices come from `PricingRegistry.getPricing(...).text_tokens.standard` and are per
   * million tokens. Cost fields are rounded to 6 decimal places.
   *
   * @param usage - Token counts. The subtractions below assume `input_tokens` includes
   *   `cached_tokens` and `output_tokens` includes `reasoning_tokens`.
   * @param modelId - Model id used for the pricing lookup and the reasoning-price heuristic.
   * @param provider - Provider used for the pricing lookup.
   * @returns The same `usage` object unchanged when no standard text pricing exists;
   *   otherwise a copy of `usage` with `input_cost`, `output_cost` and `cost` added.
   */
  static calculateCost(
    usage: {
      input_tokens: number;
      output_tokens: number;
      total_tokens: number;
      cached_tokens?: number;
      reasoning_tokens?: number;
    },
    modelId: string,
    provider: string
  ) {
    const pricing = PricingRegistry.getPricing(modelId, provider);
    if (!pricing || !pricing.text_tokens?.standard) {
      // No usable pricing: hand the caller's object back untouched (no cost fields).
      return usage;
    }

    const prices = pricing.text_tokens.standard;
    const inputPrice = prices.input_per_million ?? 0;
    const outputPrice = prices.output_per_million ?? 0;

    const cachedTokens = usage.cached_tokens ?? 0;
    const reasoningTokens = usage.reasoning_tokens ?? 0;

    // Fallback when no explicit reasoning price is listed: 2.5x the output price for ids
    // containing "reasoner" or "3-7", otherwise the plain output price. The id is
    // lowercased so this heuristic is case-insensitive, like `find`.
    const normalizedId = modelId.toLowerCase();
    const reasoningPrice =
      prices.reasoning_output_per_million ??
      (normalizedId.includes("reasoner") || normalizedId.includes("3-7")
        ? outputPrice * 2.5
        : outputPrice);

    // Cached input defaults to half the standard input price when not listed.
    const cachedPrice = prices.cached_input_per_million ?? inputPrice / 2;

    // Clamp at zero so inconsistent counts (more cached than input, more reasoning than
    // output) cannot yield negative token counts and therefore negative cost components.
    const uncachedInputTokens = Math.max(0, usage.input_tokens - cachedTokens);
    const visibleOutputTokens = Math.max(0, usage.output_tokens - reasoningTokens);

    const inputCost =
      (uncachedInputTokens / 1_000_000) * inputPrice + (cachedTokens / 1_000_000) * cachedPrice;

    const outputCost =
      (visibleOutputTokens / 1_000_000) * outputPrice +
      (reasoningTokens / 1_000_000) * reasoningPrice;

    const totalCost = inputCost + outputCost;

    return {
      ...usage,
      input_cost: Number(inputCost.toFixed(6)),
      output_cost: Number(outputCost.toFixed(6)),
      cost: Number(totalCost.toFixed(6))
    };
  }
}