All files / src/llm-orchestration/action-handlers discard-messages.handler.ts

70.47% Statements 74/105
61.76% Branches 21/34
86.66% Functions 13/15
72.16% Lines 70/97

Press n or j to go to the next uncovered block, b, p or k for the previous block.

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 3577x 7x 7x           7x 7x 7x 7x             7x     7x 31x     31x   31x 31x       1x                                               3x                                                                             13x 13x                                                                                                         22x 22x     22x         4x                           18x 2x                           16x   24x 24x   16x                         16x     16x 16x 47x       16x 16x 16x 22x 22x 19x   3x       16x 3x                         13x     38x     6x   4x 4x     38x 13x 13x 13x 38x       13x                   13x     13x 13x       13x 13x 21x 21x         13x     13x  
             12x 12x 12x   2x             2x           12x 12x     12x       12x                                             1x                        
import { Injectable } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm';
import { ActionHandler } from './action-handler.interface';
import {
  ActionExecutionResult,
  PlanExecutionContext,
} from '../llm-orchestration.interfaces';
import { SessionInputsService } from '../../session-inputs/session-inputs.service';
import { SessionInput } from '../../core-entities';
import { HistoryCompressionService } from '../history-compression.service';
import {
  buildHistory,
  computeTurnCharSize,
  DEFAULT_CHARS_PER_TOKEN,
  CONTEXT_TARGET_TOKENS,
  CONTEXT_FLOOR_TOKENS,
} from './get-messages.handler';
import { toShortId } from '../../utils';
 
@Injectable()
export class DiscardMessagesHandler implements ActionHandler {
  readonly toolName = 'discard_messages';

  constructor(
    private readonly sessionInputsService: SessionInputsService,
    @InjectRepository(SessionInput)
    private readonly sessionInputsRepository: Repository<SessionInput>,
    private readonly historyCompressionService: HistoryCompressionService,
  ) {}

  /** Tool metadata (name, description, argument schema) exposed to the LLM. */
  getMetadata() {
    return {
      name: this.toolName,
      description:
        'Discard messages from the LLM context window to extend session life. Discarded messages are excluded from future LLM calls but preserved in history. After discarding, a follow-up turn starts automatically so you can continue working. Call get_messages first to identify short_ids, then call discard_messages with IDs to remove.',
      arguments: [
        {
          name: 'input_ids',
          type: 'string' as const,
          description:
            'Comma-separated list of short_ids to discard (e.g., "a1b90,cd3ef"). Use get_messages to find short_ids.',
          required: true,
        },
        {
          name: 'reason',
          type: 'string' as const,
          description:
            'Brief explanation (2-3 sentences max) of what is being discarded, what is kept, and why.',
          required: true,
        },
      ],
    };
  }

  /** Markdown tool definition (usage, parameters, return shape, example). */
  getDefinition(): string {
    return `## discard_messages
 
Discard messages from the LLM context window to extend session life. Discarded messages are excluded from future LLM calls but preserved in history. After discarding, a follow-up turn starts automatically with reduced context so you can continue your task. Call \`get_messages\` first to identify short_ids, then call \`discard_messages\` with the IDs to remove.
 
**Important:** Check \`total_est_tokens\`, \`target_tokens\`, and \`floor_tokens\` from \`get_messages\` output. Cherry-pick individual messages until \`total_est_tokens\` is near \`target_tokens\`. **Never discard below \`floor_tokens\`.**
 
### Parameters:
- \`input_ids\` (string, required): Comma-separated list of short_ids to discard (e.g., "a1b90,cd3ef,ghi45"). Use \`get_messages\` to find short_ids.
- \`reason\` (string, required): Brief explanation (2-3 sentences max) of what is being discarded, what is kept, and why.
 
### Returns:
JSON object with:
- \`affected\`: Number of messages discarded
- \`short_ids\`: List of short_ids that were discarded
- \`reason\`: Your reason for discarding
- \`discarded_tokens\`: Estimated tokens freed by discarding
- \`new_total_tokens\`: Estimated total tokens remaining after discarding
- \`target_tokens\`: Target token count (${CONTEXT_TARGET_TOKENS})
- \`floor_tokens\`: Minimum token count (${CONTEXT_FLOOR_TOKENS})
- \`auto_discarded_get_messages\`: Number of get_messages calls that were auto-discarded
 
### Example:
\`\`\`typescript
tool: discard_messages
args: {
  input_ids: "a1b90,cd3ef,ghi45",
  reason: "Discarding old request_context calls and completed run_command outputs to free up context window. Keeping recent user prompts and code changes."
}
\`\`\``;
  }

  /**
   * Build a FAILURE result in the shape every error path of this handler uses.
   *
   * @param errorMessage - message surfaced to the caller as `error_message`
   * @param logMessage - message stored in the execution log; defaults to
   *   `errorMessage` (some paths log a shorter variant)
   */
  private failure(
    errorMessage: string,
    logMessage: string = errorMessage,
  ): ActionExecutionResult {
    return {
      status: 'FAILURE',
      summary: 'Failed to discard messages',
      error_message: errorMessage,
      execution_log: {
        output: '',
        error_message: logMessage,
      },
    };
  }

  /**
   * Estimate total tokens for a set of session inputs using the same
   * algorithm as GetMessagesHandler (buildHistory + compress + derive ratio).
   *
   * @param inputs - session inputs, in the order they should be replayed
   * @returns the summed estimate and a per-input estimate aligned with `inputs`
   */
  private async estimateTokens(
    inputs: SessionInput[],
  ): Promise<{ totalTokens: number; perInputTokens: number[] }> {
    if (inputs.length === 0) {
      return { totalTokens: 0, perInputTokens: [] };
    }

    const { history, turnToInputIndex } = buildHistory(inputs);
    const compressedHistory =
      await this.historyCompressionService.compress(history);

    // Attribute each compressed turn's char size back to the SessionInput it
    // came from; turns that map outside `inputs` are ignored.
    const compressedCharSizes = new Array<number>(inputs.length).fill(0);
    for (let i = 0; i < compressedHistory.length; i++) {
      const turnCharSize = computeTurnCharSize(compressedHistory[i]);
      const inputIdx = turnToInputIndex[i];
      if (inputIdx >= 0 && inputIdx < inputs.length) {
        compressedCharSizes[inputIdx] += turnCharSize;
      }
    }

    // Calibrate chars-per-token from the most recent model input that
    // reports input_token_count; otherwise fall back to the default ratio.
    let ratio = DEFAULT_CHARS_PER_TOKEN;
    let ratioInputIndex = -1;
    let reportedTokens = 0;
    for (let i = inputs.length - 1; i >= 0; i--) {
      const tokenCount = inputs[i].input_token_count;
      if (inputs[i].role === 'model' && tokenCount != null && tokenCount > 0) {
        ratioInputIndex = i;
        reportedTokens = tokenCount;
        break;
      }
    }

    if (ratioInputIndex >= 0) {
      // That model call's prompt consisted of everything up to and including
      // its own input, so compare against the cumulative char count.
      let cumulativeChars = 0;
      for (let i = 0; i <= ratioInputIndex; i++) {
        cumulativeChars += compressedCharSizes[i];
      }
      if (cumulativeChars > 0) {
        ratio = cumulativeChars / reportedTokens;
      }
    }

    const perInputTokens = compressedCharSizes.map((cs) =>
      cs > 0 ? Math.round(cs / ratio) : 0,
    );

    const totalTokens = perInputTokens.reduce((sum, t) => sum + t, 0);
    return { totalTokens, perInputTokens };
  }

  /**
   * Mark the requested session inputs as discarded and report the estimated
   * token savings.
   *
   * Also auto-discards any still-active model inputs whose actions are all
   * `get_messages` calls. Token estimates are computed before the discard is
   * applied, over the session's currently non-discarded inputs.
   *
   * @param args - `input_ids` (comma-separated short_ids) and `reason`
   * @param context - provides the `session_id` to operate on
   * @returns SUCCESS with a JSON summary, or FAILURE on validation/lookup/
   *   persistence errors (errors are never thrown to the caller)
   */
  async execute(
    args: Record<string, any>,
    context: PlanExecutionContext,
  ): Promise<ActionExecutionResult> {
    try {
      const { input_ids, reason } = args;

      if (
        !input_ids ||
        typeof input_ids !== 'string' ||
        input_ids.trim() === ''
      ) {
        return this.failure(
          'input_ids is required and must be a non-empty comma-separated string of short_ids',
        );
      }

      if (!reason || typeof reason !== 'string' || reason.trim() === '') {
        return this.failure(
          'reason is required and must be a non-empty string explaining what is discarded, what is kept, and why',
        );
      }

      // Split, trim, drop empty entries (e.g. "a,,b" or a trailing comma) and
      // de-duplicate so a repeated id is neither toggled nor reported twice.
      const shortIds: string[] = Array.from(
        new Set<string>(
          input_ids
            .split(',')
            .map((id: string) => id.trim())
            .filter((id: string) => id.length > 0),
        ),
      );

      if (shortIds.length === 0) {
        return this.failure('No valid IDs provided after parsing input_ids');
      }

      // Fetch all session inputs to resolve short IDs to full UUIDs.
      const allInputs = await this.sessionInputsService.findAllBySessionId(
        context.session_id,
      );
      // NOTE(review): if two inputs ever shared a short_id, the later one
      // would silently win here — confirm toShortId is unique per session.
      const shortToFull = new Map<string, string>();
      for (const inp of allInputs) {
        shortToFull.set(toShortId(inp.id), inp.id);
      }

      const resolvedIds: string[] = [];
      const unresolved: string[] = [];
      for (const shortId of shortIds) {
        const fullId = shortToFull.get(shortId);
        if (fullId) {
          resolvedIds.push(fullId);
        } else {
          unresolved.push(shortId);
        }
      }

      if (unresolved.length > 0) {
        const missing = unresolved.join(', ');
        return this.failure(
          `Could not resolve short_ids: ${missing}. Use get_messages to find valid short_ids.`,
          `Could not resolve short_ids: ${missing}`,
        );
      }

      // Auto-discard still-active model inputs consisting solely of
      // get_messages actions (the call is redundant once acted upon).
      // Already-discarded inputs and explicitly requested ones are skipped so
      // the auto-discard count is not inflated on repeated calls.
      const resolvedIdSet = new Set(resolvedIds);
      const getMessagesInputIds = allInputs
        .filter(
          (inp) =>
            !inp.is_discarded &&
            !resolvedIdSet.has(inp.id) &&
            inp.role === 'model' &&
            inp.aiActions &&
            inp.aiActions.length > 0 &&
            inp.aiActions.every((a) => a.action_type === 'get_messages'),
        )
        .map((inp) => inp.id);

      // Estimate tokens BEFORE discarding, over the currently active context
      // (loaded with the relations the history builder needs).
      const activeInputs = await this.sessionInputsRepository.find({
        where: {
          session: { id: context.session_id },
          is_discarded: false,
        },
        relations: ['aiActions', 'aiActions.executionLogs'],
        order: { sequence_number: 'ASC' },
      });

      const { totalTokens: oldTotalTokens, perInputTokens } =
        await this.estimateTokens(activeInputs);

      // Only active inputs carry an estimate, so ids that were already
      // discarded contribute nothing to the freed-token count.
      const idsBeingDiscarded = new Set([
        ...resolvedIds,
        ...getMessagesInputIds,
      ]);
      const discardedTokens = activeInputs.reduce(
        (sum, inp, i) =>
          idsBeingDiscarded.has(inp.id) ? sum + perInputTokens[i] : sum,
        0,
      );
      const newTotalTokens = oldTotalTokens - discardedTokens;

      const result = await this.sessionInputsService.toggleDiscardedBatch(
        context.session_id,
        { input_ids: resolvedIds, is_discarded: true },
      );

      // Auto-discarded inputs are toggled separately so their count can be
      // reported on its own; their tokens are already in discardedTokens.
      let autoDiscardedCount = 0;
      if (getMessagesInputIds.length > 0) {
        const autoResult = await this.sessionInputsService.toggleDiscardedBatch(
          context.session_id,
          { input_ids: getMessagesInputIds, is_discarded: true },
        );
        autoDiscardedCount = autoResult.affected;
      }

      // No is_final — a follow-up turn starts automatically after discarding
      // so the AI can continue working with reduced context.
      const tokenSummary = `~${discardedTokens} tokens freed. New total: ~${newTotalTokens} tokens.`;
      const summary =
        autoDiscardedCount > 0
          ? `Discarded ${result.affected} message(s) + ${autoDiscardedCount} get_messages call(s). ${tokenSummary}`
          : `Discarded ${result.affected} message(s). ${tokenSummary}`;

      const trimmedReason = reason.trim();

      return {
        status: 'SUCCESS',
        summary,
        execution_log: {
          output: JSON.stringify({
            affected: result.affected,
            short_ids: shortIds,
            reason: trimmedReason,
            discarded_tokens: discardedTokens,
            new_total_tokens: newTotalTokens,
            target_tokens: CONTEXT_TARGET_TOKENS,
            floor_tokens: CONTEXT_FLOOR_TOKENS,
            auto_discarded_get_messages: autoDiscardedCount,
          }),
        },
        persisted_args: {
          arguments: JSON.stringify({
            input_ids: shortIds,
            reason: trimmedReason,
          }),
        },
      };
    } catch (error) {
      // The thrown value is not guaranteed to be an Error instance.
      const message = error instanceof Error ? error.message : String(error);
      return this.failure(message);
    }
  }
}