All files / src/providers/openai/Models.ts

85.71% Statements 24/28
63.33% Branches 19/30
76.92% Functions 10/13
85.18% Lines 23/27

Press n or j to go to the next uncovered block, b, p or k for the previous block.

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115              45x 45x       2799x       466x 466x 439x       466x               469x               469x 469x       466x 466x       469x 469x       4x 4x 4x               4x 4x       4x 469x 469x   469x                                                                                    
import { ModelInfo } from "../Provider.js";
import { Capabilities } from "./Capabilities.js";
import { ModelRegistry } from "../../models/ModelRegistry.js";
import { buildUrl } from "./utils.js";
 
/**
 * One entry of the `data` array returned by an OpenAI-compatible
 * `GET /models` call. Only `id` is checked at runtime; `created` and
 * `owned_by` are passed through into `metadata` exactly as the server sent them.
 */
interface OpenAIModelEntry {
  id: string;
  created?: number;
  owned_by?: string;
}

/**
 * Lists the models exposed by an OpenAI-compatible `GET /models` endpoint and
 * enriches each one with `ModelRegistry` data, falling back to `Capabilities`
 * for anything the registry does not know.
 *
 * `getProviderName()` and the `get*` / `formatDisplayName` hooks are
 * `protected` so subclasses can reuse this logic for another provider.
 */
export class OpenAIModels {
  /**
   * @param baseUrl - API base URL; `/models` is appended via `buildUrl`.
   * @param apiKey - Sent as a `Bearer` token in the `Authorization` header.
   */
  constructor(
    protected readonly baseUrl: string,
    protected readonly apiKey: string
  ) {}

  /** Provider key used for every registry lookup and stamped on each result. */
  protected getProviderName(): string {
    return "openai";
  }

  /**
   * Display name for `modelId`: the registry name when one exists and differs
   * from the raw id, otherwise the name `Capabilities` derives from the id.
   */
  protected formatDisplayName(modelId: string): string {
    const model = ModelRegistry.find(modelId, this.getProviderName());
    if (model?.name && model.name !== modelId) return model.name;
    return Capabilities.formatDisplayName(modelId);
  }

  /**
   * Context window: registry first, then `Capabilities`. Falsy values from
   * either source are treated as unknown, and `null` is returned if neither
   * source has one.
   */
  protected getContextWindow(modelId: string): number | null {
    return (
      ModelRegistry.getContextWindow(modelId, this.getProviderName()) ||
      Capabilities.getContextWindow(modelId) ||
      null
    );
  }

  /**
   * Maximum output tokens: registry first, then `Capabilities`. Falsy values
   * are treated as unknown, and `null` is returned if neither source has one.
   */
  protected getMaxOutputTokens(modelId: string): number | null {
    return (
      ModelRegistry.getMaxOutputTokens(modelId, this.getProviderName()) ||
      Capabilities.getMaxOutputTokens(modelId) ||
      null
    );
  }

  /** Modalities from the registry entry, else from `Capabilities`. */
  protected getModalities(modelId: string) {
    const model = ModelRegistry.find(modelId, this.getProviderName());
    return model?.modalities || Capabilities.getModalities(modelId);
  }

  /** Capability list from the registry entry, else from `Capabilities`. */
  protected getCapabilities(modelId: string): string[] {
    const model = ModelRegistry.find(modelId, this.getProviderName());
    return model?.capabilities || Capabilities.getCapabilities(modelId);
  }

  /** Pricing from the registry entry, else from `Capabilities`. */
  protected getPricing(modelId: string) {
    const model = ModelRegistry.find(modelId, this.getProviderName());
    return model?.pricing || Capabilities.getPricing(modelId);
  }

  /**
   * Fetches `GET {baseUrl}/models` and maps each entry to a `ModelInfo`,
   * resolving names, limits, modalities, capabilities and pricing through the
   * protected hooks above. Registry metadata is merged with the API's
   * `created` / `owned_by` fields, and the API fields win on conflict.
   *
   * This is best-effort. The provider's registry entries are returned instead
   * of rejecting when any of these happen:
   * - a network error;
   * - a non-2xx status;
   * - an undecodable body;
   * - a body with no `data` array;
   * - an error thrown while mapping.
   */
  async execute(): Promise<ModelInfo[]> {
    const provider = this.getProviderName();
    try {
      const response = await fetch(buildUrl(this.baseUrl, "/models"), {
        method: "GET",
        headers: {
          Authorization: `Bearer ${this.apiKey}`,
          "Content-Type": "application/json"
        }
      });

      if (response.ok) {
        const entries = this.extractModelEntries(await response.json());
        if (entries) {
          return entries.map((m) => {
            const modelId = m.id;
            const registryModel = ModelRegistry.find(modelId, provider);

            return {
              id: modelId,
              name: this.formatDisplayName(modelId),
              provider: provider,
              family: registryModel?.family || modelId,
              context_window: this.getContextWindow(modelId),
              max_output_tokens: this.getMaxOutputTokens(modelId),
              modalities: this.getModalities(modelId),
              capabilities: this.getCapabilities(modelId),
              pricing: this.getPricing(modelId),
              metadata: {
                ...(registryModel?.metadata || {}),
                created: m.created,
                owned_by: m.owned_by
              }
            };
          });
        }
      }
    } catch {
      // Deliberately swallowed: any failure above degrades to registry data.
    }

    return this.listFromRegistry(provider);
  }

  /** Registry lookup for `modelId` under this provider (`undefined` if absent). */
  find(modelId: string) {
    return ModelRegistry.find(modelId, this.getProviderName());
  }

  /**
   * Validates a decoded `/models` body. Returns `null` when there is no `data`
   * array, so the caller can fall back to the registry. Entries that are not
   * objects with a string `id` are dropped rather than emitted with an
   * undefined id.
   */
  private extractModelEntries(payload: unknown): OpenAIModelEntry[] | null {
    if (typeof payload !== "object" || payload === null) return null;
    const { data } = payload as { data?: unknown };
    if (!Array.isArray(data)) return null;
    return data.filter(
      (entry): entry is OpenAIModelEntry =>
        typeof entry === "object" &&
        entry !== null &&
        typeof (entry as { id?: unknown }).id === "string"
    );
  }

  /**
   * Static fallback: every registry entry for `provider`, with missing
   * limits normalized to `null`, missing pricing to `{}`, and a missing
   * family to the model id.
   */
  private listFromRegistry(provider: string): ModelInfo[] {
    return ModelRegistry.all()
      .filter((m) => m.provider === provider)
      .map((m) => ({
        id: m.id,
        name: m.name,
        family: m.family || m.id,
        provider: provider,
        context_window: m.context_window ?? null,
        capabilities: m.capabilities,
        modalities: m.modalities,
        max_output_tokens: m.max_output_tokens ?? null,
        pricing: m.pricing || {}
      })) as ModelInfo[];
  }
}