@fblissjr
Created July 28, 2024 22:45
loom + mlx lm design

Loom LLM Provider Integration Design Document

1. Overview

This document outlines the process of adding the MLX LLM provider to Loom and demonstrates how to extend this approach to easily add other providers in the future. The goal is to create a flexible and extensible system that maintains consistency with existing provider implementations.

2. Adding MLX Provider

2.1 Update Common Types (common.ts)

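// Add MLX to the list of providers
export const PROVIDERS = [
  // ... existing providers
  "mlx",
] as const; // `as const` keeps the literal names, so Provider is a union rather than `string`
export type Provider = (typeof PROVIDERS)[number];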

// Add MLX-specific properties
type ProviderProps = {
  // ... existing providers
  "mlx": { 
    serverUrl: string;
  };
};

// Update ModelPreset type
export type ModelPreset<P extends Provider> = SharedPresetSettings & 
  (P extends keyof ProviderProps ? ProviderProps[P] : {}) & 
  { provider: P };

// Update LoomSettings interface
export interface LoomSettings {
  // ... existing settings
  mlxServerUrl?: string;
}
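
To see how the conditional type in ModelPreset resolves, here is a minimal self-contained sketch. The SharedPresetSettings fields (name, model, contextLength) are assumed from the preset example later in this document, and the "openai" entry with its organization field is only a hypothetical stand-in for the existing providers.

```typescript
// Stand-in provider list; `as const` keeps the literal names.
const PROVIDERS = ["openai", "mlx"] as const;
type Provider = (typeof PROVIDERS)[number]; // "openai" | "mlx"

// Assumed shared fields, taken from the preset example in this document.
type SharedPresetSettings = {
  name: string;
  model: string;
  contextLength: number;
};

type ProviderProps = {
  openai: { organization: string }; // hypothetical props for an existing provider
  mlx: { serverUrl: string };
};

type ModelPreset<P extends Provider> = SharedPresetSettings &
  (P extends keyof ProviderProps ? ProviderProps[P] : {}) &
  { provider: P };

// This compiles only because serverUrl is present; omitting it, or adding
// `organization`, would be reported as a type error.
const mlxPreset: ModelPreset<"mlx"> = {
  name: "MLX Default",
  provider: "mlx",
  model: "mlx-default",
  contextLength: 4096,
  serverUrl: "http://localhost:8080",
};
```

The intersection means a preset is checked against the shared fields and against exactly the extra fields declared for its provider.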

2.2 Create MLX Provider File (mlxProvider.ts)

import { requestUrl } from "obsidian";
import { CompletionResult } from "./types";

export async function completeMLX(
  serverUrl: string,
  prompt: string,
  maxTokens: number,
  temperature: number,
): Promise<CompletionResult> {
  try {
    const response = await requestUrl({
      url: `${serverUrl}/v1/completions`,
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({
        model: "mlx",
        prompt,
        max_tokens: maxTokens,
        temperature,
      }),
    });

    if (response.status !== 200) {
      throw new Error(`MLX server error: ${response.status}`);
    }

    const data = JSON.parse(response.text);
    const completions = data.choices.map((choice: any) => choice.text);
    return { ok: true, completions };
  } catch (e: any) {
    // e may not be an Error, so read status and message defensively
    return { ok: false, status: e?.status ?? 500, message: e?.message ?? String(e) };
  }
}
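
The function above trusts that the response body has a choices array of objects with a text field. A more defensive parse could be sketched as follows. The CompletionResult shape is inferred from the return statements above rather than taken from Loom's source, and parseCompletionsBody is a hypothetical helper name.

```typescript
// Shape inferred from the return values in completeMLX; an assumption.
type CompletionResult =
  | { ok: true; completions: string[] }
  | { ok: false; status: number; message: string };

// Parse a completions body of the form { choices: [{ text: string }, ...] }.
function parseCompletionsBody(body: string): CompletionResult {
  let data: unknown;
  try {
    data = JSON.parse(body);
  } catch {
    return { ok: false, status: 500, message: "Response was not valid JSON" };
  }

  const choices = (data as { choices?: unknown } | null)?.choices;
  if (!Array.isArray(choices)) {
    return { ok: false, status: 500, message: "Response has no choices array" };
  }

  const completions: string[] = [];
  for (const choice of choices) {
    const text = (choice as { text?: unknown } | null)?.text;
    if (typeof text !== "string") {
      return { ok: false, status: 500, message: "Choice is missing a text field" };
    }
    completions.push(text);
  }
  return { ok: true, completions };
}
```

Keeping the parse in a pure function also makes it testable without an Obsidian runtime or a running server.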

2.3 Update Main Plugin Class (main.ts)

import { completeMLX } from './mlxProvider';

export default class LoomPlugin extends Plugin {
  // ... existing code

  async generate(file: TFile, rootNode: string | null) {
    // ... existing code

    const completionMethods: Record<Provider, (prompt: string) => Promise<CompletionResult>> = {
      // ... existing providers
      mlx: (prompt: string) => completeMLX(
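        // mlxServerUrl is optional; fall back to the placeholder URL from the settings tab
        this.settings.mlxServerUrl || "http://localhost:8080",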
        prompt,
        this.settings.maxTokens,
        this.settings.temperature
      ),
    };

    // ... rest of the method
  }

  // ... rest of the class
}
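
The lookup that the "rest of the method" presumably performs can be sketched in isolation. The provider names and stub functions below are placeholders so the example runs without network access; only the Record<Provider, ...> shape comes from the code above.

```typescript
type Provider = "openai" | "mlx"; // "openai" stands in for the existing providers
type CompletionResult =
  | { ok: true; completions: string[] }
  | { ok: false; status: number; message: string };
type CompletionMethod = (prompt: string) => Promise<CompletionResult>;

// Stub providers; the real entries would call completeMLX and friends.
const completionMethods: Record<Provider, CompletionMethod> = {
  openai: async (prompt) => ({ ok: true, completions: [`openai:${prompt}`] }),
  mlx: async (prompt) => ({ ok: true, completions: [`mlx:${prompt}`] }),
};

// Pick the function for the active provider and call it.
function complete(provider: Provider, prompt: string): Promise<CompletionResult> {
  return completionMethods[provider](prompt);
}
```

Because the object is typed as a Record over Provider, leaving out an entry for a provider in the union is a compile-time error.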

2.4 Update Settings Tab (settings.ts)

class LoomSettingTab extends PluginSettingTab {
  // ... existing code

  display(): void {
    // ... existing code

    new Setting(containerEl)
      .setName("MLX Server URL")
      .setDesc("Enter the URL for the MLX server")
      .addText(text => text
        .setPlaceholder("http://localhost:8080")
        .setValue(this.plugin.settings.mlxServerUrl || "")
        .onChange(async (value) => {
          this.plugin.settings.mlxServerUrl = value;
          await this.plugin.saveSettings();
        }));

    // ... rest of the method
  }
}
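
Since the provider builds its endpoint as `${serverUrl}/v1/completions`, a value typed with a trailing slash would produce a double slash. One way to guard against that is sketched below. normalizeServerUrl and completionsEndpoint are hypothetical helpers, and the default URL is the placeholder used in the settings tab above.

```typescript
const DEFAULT_MLX_SERVER_URL = "http://localhost:8080";

// Trim whitespace and trailing slashes; fall back to the default when the
// field is empty or does not look like an http(s) URL.
function normalizeServerUrl(input: string | undefined): string {
  const trimmed = (input ?? "").trim().replace(/\/+$/, "");
  if (trimmed === "") return DEFAULT_MLX_SERVER_URL;
  if (!/^https?:\/\/\S+$/i.test(trimmed)) return DEFAULT_MLX_SERVER_URL;
  return trimmed;
}

// Build the completions endpoint from whatever the user typed.
function completionsEndpoint(serverUrl: string | undefined): string {
  return `${normalizeServerUrl(serverUrl)}/v1/completions`;
}
```

Falling back silently keeps generation working, but it can hide a typo; showing a notice in the settings tab instead would be an equally reasonable choice.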

2.5 Add MLX Model Presets (presets.ts)

const MLX_PRESETS: ModelPreset<"mlx">[] = [
  {
    name: "MLX Default",
    provider: "mlx",
    model: "mlx-default",
    contextLength: 4096,
    serverUrl: "http://localhost:8080",
  },
];

export const ALL_PRESETS = [
  ...EXISTING_PRESETS,
  ...MLX_PRESETS,
];

3. Adding Additional Providers

To add a new provider, follow these steps:

  1. Update common.ts:

    • Add the new provider to PROVIDERS array
    • Add provider-specific properties to ProviderProps
    • Update LoomSettings if needed
  2. Create a new provider file (e.g., newProvider.ts):

    • Implement the completion function for the new provider
    • Follow the same structure as mlxProvider.ts
  3. Update main.ts:

    • Import the new provider's completion function
    • Add the new provider to completionMethods in the generate method
  4. Update settings.ts:

    • Add any necessary settings for the new provider
  5. Update presets.ts:

    • Add presets for the new provider

Example: Adding a Hypothetical "NewLLM" Provider

3.1 Update common.ts

export const PROVIDERS = [
  // ... existing providers
  "newLLM",
] as const;
export type Provider = (typeof PROVIDERS)[number];

type ProviderProps = {
  // ... existing providers
  "newLLM": { 
    apiKey: string;
    modelVersion: string;
  };
};

export interface LoomSettings {
  // ... existing settings
  newLLMApiKey?: string;
}

3.2 Create newLLMProvider.ts

import { requestUrl } from "obsidian";
import { CompletionResult } from "./types";

export async function completeNewLLM(
  apiKey: string,
  modelVersion: string,
  prompt: string,
  maxTokens: number,
  temperature: number,
): Promise<CompletionResult> {
  try {
    const response = await requestUrl({
      url: "https://api.newllm.com/v1/generate",
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        "Authorization": `Bearer ${apiKey}`,
      },
      body: JSON.stringify({
        model: modelVersion,
        prompt,
        max_tokens: maxTokens,
        temperature,
      }),
    });

    if (response.status !== 200) {
      throw new Error(`NewLLM API error: ${response.status}`);
    }

    const data = JSON.parse(response.text);
    const completions = data.generations.map((gen: any) => gen.text);
    return { ok: true, completions };
  } catch (e: any) {
    // e may not be an Error, so read status and message defensively
    return { ok: false, status: e?.status ?? 500, message: e?.message ?? String(e) };
  }
}

3.3 Update main.ts

import { completeNewLLM } from './newLLMProvider';

export default class LoomPlugin extends Plugin {
  // ... existing code

  async generate(file: TFile, rootNode: string | null) {
    // ... existing code

    const completionMethods: Record<Provider, (prompt: string) => Promise<CompletionResult>> = {
      // ... existing providers
      newLLM: (prompt: string) => completeNewLLM(
        this.settings.newLLMApiKey || "", // optional in LoomSettings; an empty key fails authentication server-side
        getPreset(this.settings).modelVersion,
        prompt,
        this.settings.maxTokens,
        this.settings.temperature
      ),
    };

    // ... rest of the method
  }

  // ... rest of the class
}

3.4 Update settings.ts

class LoomSettingTab extends PluginSettingTab {
  // ... existing code

  display(): void {
    // ... existing code

    new Setting(containerEl)
      .setName("NewLLM API Key")
      .setDesc("Enter your NewLLM API key")
      .addText(text => text
        .setPlaceholder("Enter API key")
        .setValue(this.plugin.settings.newLLMApiKey || "")
        .onChange(async (value) => {
          this.plugin.settings.newLLMApiKey = value;
          await this.plugin.saveSettings();
        }));

    // ... rest of the method
  }
}

3.5 Update presets.ts

const NEWLLM_PRESETS: ModelPreset<"newLLM">[] = [
  {
    name: "NewLLM v1",
    provider: "newLLM",
    modelVersion: "v1",
    contextLength: 2048,
    apiKey: "",
  },
  {
    name: "NewLLM v2",
    provider: "newLLM",
    modelVersion: "v2",
    contextLength: 4096,
    apiKey: "",
  },
];

export const ALL_PRESETS = [
  ...EXISTING_PRESETS,
  ...MLX_PRESETS,
  ...NEWLLM_PRESETS,
];

4. Ensuring Flexibility and Extensibility

  1. Modular Structure: Each provider has its own file, making it easy to add or modify providers without affecting others.

  2. Consistent Interface: All provider completion functions follow the same structure, returning a CompletionResult.

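  3. Type Safety: Typing completionMethods as Record<Provider, ...> and presets as ModelPreset<P> lets the compiler check that every provider has a completion function and that each preset carries its provider-specific fields, as long as Provider is a union of literal names rather than plain string.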

  4. Dynamic Provider Selection: The completionMethods object in main.ts allows for dynamic selection of the appropriate provider based on settings.

  5. Extensible Settings: The settings system can easily accommodate new provider-specific settings.

  6. Flexible Presets: The preset system allows for easy addition of new provider-specific presets.
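
The consistent interface in point 2 could be enforced by a shared wrapper rather than by repeating the same try/catch in every provider file. The following is a sketch under that assumption. toCompletionResult is a hypothetical helper, and the CompletionResult shape is inferred from the provider code in this document.

```typescript
type CompletionResult =
  | { ok: true; completions: string[] }
  | { ok: false; status: number; message: string };

// Run a provider call and convert any thrown value into the failure shape,
// so each provider file only has to implement the happy path.
async function toCompletionResult(
  run: () => Promise<string[]>,
): Promise<CompletionResult> {
  try {
    return { ok: true, completions: await run() };
  } catch (e) {
    const err = e as unknown;
    const hasStatus =
      typeof err === "object" &&
      err !== null &&
      typeof (err as { status?: unknown }).status === "number";
    const status = hasStatus ? (err as { status: number }).status : 500;
    const message = err instanceof Error ? err.message : String(err);
    return { ok: false, status, message };
  }
}
```

A provider would then pass in a function that performs the request and returns the list of completion strings, throwing on any failure.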

Loom Logit Sampling and Advanced Generation Features Analysis

This section analyzes the existing functionality in the Loom codebase related to logit sampling and advanced generation features:

1. Existing Generation Parameters

In the LoomSettings interface (likely in common.ts), we can see several parameters that affect text generation:

export interface LoomSettings {
  // ... other settings
  maxTokens: number;
  temperature: number;
  topP: number;
  frequencyPenalty: number;
  presencePenalty: number;
  repetitionPenalty: number;
  repetitionContextSize: number;
  logitBias: Record<string, number> | null;
}

These parameters indicate that Loom already supports various sampling and control methods:

  • temperature: Controls randomness in sampling.
  • topP: Implements nucleus sampling.
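  • frequencyPenalty and presencePenalty: Lower the scores of tokens that already appear in the text. The frequency penalty grows with each repetition, while the presence penalty applies once, as soon as a token has appeared at all.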
  • repetitionPenalty and repetitionContextSize: Help prevent repetitive text.
  • logitBias: Adds a fixed bias to the logits of specific tokens, making them more or less likely to be sampled.
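
As a rough illustration of what the penalty parameters do to logits, the sketch below uses two common formulations: the subtractive frequency/presence penalty described in OpenAI's API documentation, and the multiplicative repetition penalty used by several open-source samplers. Whether a given provider applies exactly these formulas is an assumption, and both function names are hypothetical.

```typescript
// counts[i] is how many times token i already appears in the context.
function applyFrequencyAndPresence(
  logits: number[],
  counts: number[],
  frequencyPenalty: number,
  presencePenalty: number,
): number[] {
  return logits.map((logit, i) => {
    const count = counts[i] ?? 0;
    // Frequency term scales with the count; presence term is all-or-nothing.
    return logit - count * frequencyPenalty - (count > 0 ? presencePenalty : 0);
  });
}

// seen holds the token ids found in the recent context window.
function applyRepetitionPenalty(
  logits: number[],
  seen: Set<number>,
  penalty: number,
): number[] {
  return logits.map((logit, i) => {
    if (!seen.has(i)) return logit;
    // Push the logit toward "less likely" regardless of its sign.
    return logit > 0 ? logit / penalty : logit * penalty;
  });
}
```

With logits [2, 2, 2], counts [0, 1, 3], a frequency penalty of 0.5 and a presence penalty of 1, the first function yields [2, 0.5, -0.5]: the unseen token is untouched and the most repeated token is penalized most.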

2. Generation Logic

The core generation logic is likely implemented in the generate method of the LoomPlugin class. This method probably uses these parameters when calling the underlying LLM API.

3. Branching Feature

Loom's branching feature is one of its key aspects. This feature likely works by:

  1. Generating multiple completions for a given prompt.
  2. Creating separate "nodes" or "branches" for each completion.
  3. Allowing the user to explore different generated paths.

4. Logit Sampling Integration

While Loom doesn't seem to have explicit "logit sampling" functionality, it does have features that affect token probabilities during generation. Here's how we could leverage and extend these features:

4.1. Utilizing Existing Parameters

We can use the existing parameters to influence the logit distribution:

  • Adjust temperature to control overall randomness.
  • Use topP for nucleus sampling, which is a form of logit filtering.
  • Employ logitBias to manually adjust probabilities of specific tokens.
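
To make the effect of temperature and logitBias concrete, here is a small sketch of the standard softmax-with-temperature computation and an additive bias step. It assumes logitBias is keyed by token id as a string, which matches its Record<string, number> type but is otherwise an assumption; the function names are illustrative.

```typescript
// Softmax with temperature: divide logits by T (T > 0) before normalizing.
function softmax(logits: number[], temperature = 1): number[] {
  const scaled = logits.map((l) => l / temperature);
  const max = Math.max(...scaled); // subtract the max for numerical stability
  const exps = scaled.map((l) => Math.exp(l - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Add per-token biases to the logits; tokens without an entry are unchanged.
function applyLogitBias(logits: number[], bias: Record<string, number>): number[] {
  return logits.map((l, i) => l + (bias[String(i)] ?? 0));
}
```

Lower temperatures concentrate probability on the highest logit and higher temperatures flatten the distribution. A large negative bias effectively bans a token.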

4.2. Extending Logit Manipulation

To add more advanced logit sampling:

  1. Add a new parameter to LoomSettings:
export interface LoomSettings {
  // ... existing settings
  logitSamplingMethod: 'top_k' | 'nucleus' | 'temperature' | 'custom';
  topK: number; // read as this.settings.topK when 'top_k' is selected
  customLogitFunction?: string; // A user-defined function as a string
}
  2. Implement logit sampling logic in the generation process:
// In main.ts or a new sampling.ts file
function applyLogitSampling(logits: number[], method: string, params: any): number[] {
  switch (method) {
    case 'top_k':
      return topKSampling(logits, params.k);
    case 'nucleus':
      return nucleusSampling(logits, params.p);
    case 'temperature':
      return temperatureSampling(logits, params.temp);
    case 'custom':
      return customSampling(logits, params.customFunction);
    default:
      return logits;
  }
}

// Implement these functions:
function topKSampling(logits: number[], k: number): number[] { /* ... */ }
function nucleusSampling(logits: number[], p: number): number[] { /* ... */ }
function temperatureSampling(logits: number[], temp: number): number[] { /* ... */ }
function customSampling(logits: number[], customFunction: string): number[] { /* ... */ }
  3. Integrate with the branching feature:
// In the generate method of LoomPlugin
async generate(file: TFile, rootNode: string | null) {
  // ... existing code

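  // rawLogits is assumed to hold the next-token logits. Most HTTP completion
  // endpoints return text (and at most top log-probabilities), so this step
  // needs a provider that exposes logits or applies the sampling server-side.
  const sampledLogits = applyLogitSampling(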
    rawLogits,
    this.settings.logitSamplingMethod,
    {
      k: this.settings.topK,
      p: this.settings.topP,
      temp: this.settings.temperature,
      customFunction: this.settings.customLogitFunction
    }
  );

  const completions = generateCompletions(sampledLogits, this.settings.n);
  
  // Create branches for each completion
  for (const completion of completions) {
    this.createBranch(file, rootNode, completion);
  }

  // ... rest of the method
}
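
The sampling functions in step 2 are left as stubs. One possible shape for the top-k and nucleus cases is sketched below, under the assumption that each function returns filtered logits with excluded tokens set to -Infinity and that a separate step draws the token. The names differ from the stubs on purpose, since this is an illustration rather than the intended implementation.

```typescript
function softmax(logits: number[]): number[] {
  const max = Math.max(...logits);
  const exps = logits.map((l) => Math.exp(l - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Keep the k largest logits and mask the rest. Ties at the threshold are
// all kept, so slightly more than k tokens can survive.
function topKFilter(logits: number[], k: number): number[] {
  if (k <= 0 || k >= logits.length) return logits.slice();
  const threshold = logits.slice().sort((a, b) => b - a)[k - 1];
  return logits.map((l) => (l >= threshold ? l : -Infinity));
}

// Nucleus (top-p): keep the smallest set of most likely tokens whose
// cumulative probability reaches p, and mask the rest.
function nucleusFilter(logits: number[], p: number): number[] {
  const probs = softmax(logits);
  const order = probs.map((_, i) => i).sort((a, b) => probs[b] - probs[a]);
  const keep = new Set<number>();
  let cumulative = 0;
  for (const i of order) {
    keep.add(i);
    cumulative += probs[i];
    if (cumulative >= p) break;
  }
  return logits.map((l, i) => (keep.has(i) ? l : -Infinity));
}

// Draw one token index from the (filtered) logits. The random source is
// injectable so the function can be tested deterministically.
function sampleToken(logits: number[], random: () => number = Math.random): number {
  const probs = softmax(logits);
  let r = random();
  for (let i = 0; i < probs.length; i++) {
    r -= probs[i];
    if (r < 0) return i;
  }
  return probs.length - 1; // guard against floating-point rounding
}
```

Masked tokens get probability zero after the softmax, so sampleToken can never pick them.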

4.3. UI for Logit Sampling

Add UI elements in the settings to allow users to choose and configure logit sampling methods:

// In settings.ts
new Setting(containerEl)
  .setName("Logit Sampling Method")
  .setDesc("Choose the method for sampling logits")
  .addDropdown(dropdown => dropdown
    .addOptions({
      'top_k': 'Top-K',
      'nucleus': 'Nucleus (Top-P)',
      'temperature': 'Temperature',
      'custom': 'Custom Function'
    })
    .setValue(this.plugin.settings.logitSamplingMethod)
    .onChange(async (value) => {
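      // The dropdown yields a plain string; narrow it to the settings union
      this.plugin.settings.logitSamplingMethod = value as LoomSettings["logitSamplingMethod"];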
      await this.plugin.saveSettings();
    }));

// Add additional settings for each method's parameters
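
Because the dropdown callback receives a plain string, the value can be checked at runtime before it is stored. The sketch below derives the union type, the labels, and a guard from one list. All names are hypothetical, and the option keys mirror the dropdown above.

```typescript
const SAMPLING_METHODS = ["top_k", "nucleus", "temperature", "custom"] as const;
type SamplingMethod = (typeof SAMPLING_METHODS)[number];

// Labels for addOptions(), keyed by the same union so none can be forgotten.
const SAMPLING_LABELS: Record<SamplingMethod, string> = {
  top_k: "Top-K",
  nucleus: "Nucleus (Top-P)",
  temperature: "Temperature",
  custom: "Custom Function",
};

// Type guard: true only for one of the known method names.
function isSamplingMethod(value: string): value is SamplingMethod {
  return (SAMPLING_METHODS as readonly string[]).indexOf(value) !== -1;
}

// Return the parsed method, or the fallback for unknown or stale values.
function parseSamplingMethod(value: string, fallback: SamplingMethod): SamplingMethod {
  return isSamplingMethod(value) ? value : fallback;
}
```

The same guard is useful when loading saved settings, where an older or hand-edited file may contain a method name that no longer exists.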

5. Leveraging Branching with Logit Sampling

By integrating advanced logit sampling with Loom's branching feature, we can:

  1. Generate multiple, diverse completions using different sampling methods.
  2. Create branches that represent different sampling strategies.
  3. Allow users to compare outputs from various logit manipulation techniques.
  4. Provide a playground for experimenting with custom logit functions.

This integration would make Loom a powerful tool for exploring different text generation strategies and their outcomes.
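
One way to represent "a branch per sampling strategy" is sketched below. The Strategy type and the injected generate function are hypothetical; in the plugin, generate would call the provider with the given settings.

```typescript
type Strategy =
  | { method: "top_k"; k: number }
  | { method: "nucleus"; p: number }
  | { method: "temperature"; temp: number };

interface Branch {
  label: string; // shown to the user so branches can be compared
  text: string;
}

// Human-readable label for a strategy.
function describe(s: Strategy): string {
  switch (s.method) {
    case "top_k":
      return `top_k (k=${s.k})`;
    case "nucleus":
      return `nucleus (p=${s.p})`;
    case "temperature":
      return `temperature (T=${s.temp})`;
  }
}

// Produce one labeled branch per strategy for the same prompt.
function branchPerStrategy(
  prompt: string,
  strategies: Strategy[],
  generate: (prompt: string, strategy: Strategy) => string,
): Branch[] {
  return strategies.map((s) => ({ label: describe(s), text: generate(prompt, s) }));
}
```

Labeling each branch with the strategy that produced it is what makes side-by-side comparison possible once the branches are shown in the tree.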
