
Commit 7381008

Merge pull request #89 from codelion/feat-add-local-inference
Feat add local inference
2 parents 476719c + ad90fd8 commit 7381008

File tree

4 files changed: +1586 −39 lines changed


README.md

+56 −19 lines changed
````diff
@@ -48,22 +48,6 @@ python optillm.py
  * Running on http://192.168.10.48:8000
 2024-09-06 07:57:14,212 - INFO - Press CTRL+C to quit
 ```
-
-### Starting the optillm proxy for a local server (e.g. llama.cpp)
-
-- Set the `OPENAI_API_KEY` env variable to a placeholder value
-  - e.g. `export OPENAI_API_KEY="no_key"`
-- Run `./llama-server -c 4096 -m path_to_model` to start the server with the specified model and a context length of 4096 tokens
-- Run `python3 optillm.py --base_url base_url` to start the proxy
-  - e.g. for llama.cpp, run `python3 optillm.py --base_url http://localhost:8080/v1`
-
-> [!WARNING]
-> Note that llama-server currently does not support sampling multiple responses from a model, which limits the available approaches to the following:
-> `cot_reflection`, `leap`, `plansearch`, `rstar`, `rto`, `self_consistency`, `re2`, and `z3`.
-
-> [!NOTE]
-> You'll later need to specify a model name in the OpenAI client configuration. Since llama-server was started with a single model, you can choose any name you want.
-
 ## Usage
 
 Once the proxy is running, you can use it as a drop in replacement for an OpenAI client by setting the `base_url` as `http://localhost:8000/v1`.
@@ -155,7 +139,60 @@ In the diagram:
 - `A` is an existing tool (like [oobabooga](https://github.com/oobabooga/text-generation-webui/)), framework (like [patchwork](https://github.com/patched-codes/patchwork))
 or your own code where you want to use the results from optillm. You can use it directly using any OpenAI client sdk.
 - `B` is the optillm service (running directly or in a docker container) that will send requests to the `base_url`.
-- `C` is any service providing an OpenAI API compatible chat completions endpoint.
+- `C` is any service providing an OpenAI API compatible chat completions endpoint.
+
+### Local inference server
+
+We support loading any HuggingFace model or LoRA directly in optillm. To use the built-in inference server set the `OPTILLM_API_KEY` to any value (e.g. `export OPTILLM_API_KEY="optillm"`)
+and then use the same in your OpenAI client. You can pass any HuggingFace model in model field. If it is a private model make sure you set the `HF_TOKEN` environment variable
+with your HuggingFace key. We also support adding any number of LoRAs on top of the model by using the `+` separator.
+
+E.g. The following code loads the base model `meta-llama/Llama-3.2-1B-Instruct` and then adds two LoRAs on top - `patched-codes/Llama-3.2-1B-FixVulns` and `patched-codes/Llama-3.2-1B-FastApply`.
+You can specify which LoRA to use using the `active_adapter` param in `extra_args` field of OpenAI SDK client. By default we will load the last specified adapter.
+
+```python
+OPENAI_BASE_URL = "http://localhost:8000/v1"
+OPENAI_KEY = "optillm"
+response = client.chat.completions.create(
+  model="meta-llama/Llama-3.2-1B-Instruct+patched-codes/Llama-3.2-1B-FastApply+patched-codes/Llama-3.2-1B-FixVulns",
+  messages=messages,
+  temperature=0.2,
+  logprobs = True,
+  top_logprobs = 3,
+  extra_body={"active_adapter": "patched-codes/Llama-3.2-1B-FastApply"},
+)
+```
+
+You can also use the alternate decoding techniques like `cot_decoding` and `entropy_decoding` directly with the local inference server.
+
+```python
+response = client.chat.completions.create(
+  model="meta-llama/Llama-3.2-1B-Instruct",
+  messages=messages,
+  temperature=0.2,
+  extra_body={
+    "decoding": "cot_decoding",  # or "entropy_decoding"
+    # CoT specific params
+    "k": 10,
+    "aggregate_paths": True,
+    # OR Entropy specific params
+    "top_k": 27,
+    "min_p": 0.03,
+  }
+)
+```
+
+### Starting the optillm proxy with an external server (e.g. llama.cpp or ollama)
+
+- Set the `OPENAI_API_KEY` env variable to a placeholder value
+  - e.g. `export OPENAI_API_KEY="sk-no-key"`
+- Run `./llama-server -c 4096 -m path_to_model` to start the server with the specified model and a context length of 4096 tokens
+- Run `python3 optillm.py --base_url base_url` to start the proxy
+  - e.g. for llama.cpp, run `python3 optillm.py --base_url http://localhost:8080/v1`
+
+> [!WARNING]
+> Note that llama-server (and ollama) currently does not support sampling multiple responses from a model, which limits the available approaches to the following:
+> `cot_reflection`, `leap`, `plansearch`, `rstar`, `rto`, `self_consistency`, `re2`, and `z3`. Use the built-in local inference server to use these approaches.
 
 ## Implemented techniques
 
@@ -256,9 +293,9 @@ Authorization: Bearer your_secret_api_key
 ### readurls&memory-gpt-4o-mini on Google FRAMES Benchmark (Oct 2024)
 | Model | Accuracy |
 | ----- | -------- |
-| readlurls&memory-gpt-4o-mini | 65.66 |
+| readurls&memory-gpt-4o-mini | 65.66 |
 | gpt-4o-mini | 50.0 |
-| readlurls&memory-Gemma2-9b | 30.1 |
+| readurls&memory-Gemma2-9b | 30.1 |
 | Gemma2-9b | 5.1 |
 | Gemma2-27b | 30.8 |
 | Gemini Flash 1.5 | 66.5 |
````
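
The README snippet above assumes an already-constructed OpenAI client. As a minimal sketch (not part of the commit), this is how such a client would typically be wired up against the new local inference server, assuming the `openai` Python SDK and a placeholder prompt:

```python
# Minimal sketch (not from this commit): pointing the OpenAI SDK at optillm's
# built-in local inference server. Assumes optillm was started with
# OPTILLM_API_KEY set, e.g. `export OPTILLM_API_KEY="optillm"`.
from openai import OpenAI

client = OpenAI(
    api_key="optillm",                    # must match OPTILLM_API_KEY
    base_url="http://localhost:8000/v1",  # the optillm proxy endpoint
)

messages = [{"role": "user", "content": "Write a haiku about merge commits."}]

response = client.chat.completions.create(
    model="meta-llama/Llama-3.2-1B-Instruct",  # any HuggingFace model id
    messages=messages,
    temperature=0.2,
)
print(response.choices[0].message.content)
```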

optillm.py

+93 −19 lines changed
````diff
@@ -11,6 +11,7 @@
 import asyncio
 import re
 from concurrent.futures import ThreadPoolExecutor
+from typing import Tuple, Optional, Union, Dict, Any, List
 
 # Import approach modules
 from optillm.mcts import chat_with_mcts
@@ -43,8 +44,13 @@
 
 def get_config():
     API_KEY = None
+    if os.environ.get("OPTILLM_API_KEY"):
+        # Use local inference engine
+        from optillm.inference import create_inference_client
+        API_KEY = os.environ.get("OPTILLM_API_KEY")
+        default_client = create_inference_client()
     # OpenAI, Azure, or LiteLLM API configuration
-    if os.environ.get("OPENAI_API_KEY"):
+    elif os.environ.get("OPENAI_API_KEY"):
         API_KEY = os.environ.get("OPENAI_API_KEY")
         base_url = server_config['base_url']
         if base_url != "":
@@ -78,7 +84,7 @@ def get_config():
 
 # Server configuration
 server_config = {
-    'approach': 'bon',
+    'approach': 'none',
     'mcts_simulations': 2,
     'mcts_exploration': 0.2,
     'mcts_depth': 1,
@@ -96,11 +102,52 @@ def get_config():
 }
 
 # List of known approaches
-known_approaches = ["mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar",
-                    "cot_reflection", "plansearch", "leap", "re2"]
+known_approaches = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency",
+                    "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"]
 
 plugin_approaches = {}
 
+def none_approach(
+    client: Any,
+    model: str,
+    original_messages: List[Dict[str, str]],
+    **kwargs
+) -> Dict[str, Any]:
+    """
+    Direct proxy approach that passes through all parameters to the underlying endpoint.
+
+    Args:
+        system_prompt: System prompt text (unused)
+        initial_query: Initial query/conversation (unused)
+        client: OpenAI client instance
+        model: Model identifier
+        original_messages: Original messages from the request
+        **kwargs: Additional parameters to pass through
+
+    Returns:
+        Dict[str, Any]: Full OpenAI API response
+    """
+    # Strip 'none-' prefix from model if present
+    if model.startswith('none-'):
+        model = model[5:]
+
+    try:
+        # Make the direct completion call with original messages and parameters
+        response = client.chat.completions.create(
+            model=model,
+            messages=original_messages,
+            **kwargs
+        )
+
+        # Convert to dict if it's not already
+        if hasattr(response, 'model_dump'):
+            return response.model_dump()
+        return response
+
+    except Exception as e:
+        logger.error(f"Error in none approach: {str(e)}")
+        raise
+
 def load_plugins():
     # Clear existing plugins first but modify the global dict in place
     plugin_approaches.clear()
````
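
For orientation, a small usage sketch of the `none_approach` function added above (not part of the commit; the client, API key, and model name below are placeholders):

```python
# Illustrative sketch only: none_approach simply forwards the request to the
# wrapped client. The API key, base_url, and model name are placeholders.
from openai import OpenAI

upstream = OpenAI(api_key="sk-placeholder", base_url="http://localhost:8080/v1")

result = none_approach(
    client=upstream,
    model="none-my-model",   # the 'none-' prefix is stripped inside the function
    original_messages=[{"role": "user", "content": "Hello"}],
    temperature=0.2,         # any extra kwargs are passed through unchanged
)
print(result["choices"][0]["message"]["content"])  # plain dict via model_dump()
```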
````diff
@@ -158,7 +205,7 @@ def load_plugins():
 
 def parse_combined_approach(model: str, known_approaches: list, plugin_approaches: dict):
     if model == 'auto':
-        return 'SINGLE', ['bon'], model
+        return 'SINGLE', ['none'], model
 
     parts = model.split('-')
     approaches = []
@@ -183,7 +230,7 @@ def parse_combined_approach(model: str, known_approaches: list, plugin_approache
             model_parts.append(part)
 
     if not approaches:
-        approaches = ['bon']
+        approaches = ['none']
         operation = 'SINGLE'
 
     actual_model = '-'.join(model_parts)
@@ -192,8 +239,21 @@ def parse_combined_approach(model: str, known_approaches: list, plugin_approache
 
 def execute_single_approach(approach, system_prompt, initial_query, client, model):
     if approach in known_approaches:
-        # Execute known approaches
-        if approach == 'mcts':
+        if approach == 'none':
+            # Extract kwargs from the request data
+            kwargs = {}
+            if hasattr(request, 'json'):
+                data = request.get_json()
+                messages = data.get('messages', [])
+                # Copy all parameters except 'model' and 'messages'
+                kwargs = {k: v for k, v in data.items()
+                          if k not in ['model', 'messages', 'optillm_approach']}
+            response = none_approach(original_messages=messages, client=client, model=model, **kwargs)
+
+            # For none approach, we return the response and a token count of 0
+            # since the full token count is already in the response
+            return response, 0
+        elif approach == 'mcts':
             return chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
                                   server_config['mcts_exploration'], server_config['mcts_depth'])
         elif approach == 'bon':
@@ -324,7 +384,6 @@ def proxy():
     bearer_token = ""
 
     if auth_header and auth_header.startswith("Bearer "):
-        # Extract the bearer token
         bearer_token = auth_header.split("Bearer ")[1].strip()
         logger.debug(f"Intercepted Bearer Token: {bearer_token}")
 
@@ -360,22 +419,37 @@ def proxy():
         client = default_client
 
     try:
+        # Check if any of the approaches is 'none'
+        contains_none = any(approach == 'none' for approach in approaches)
+
+        if operation == 'SINGLE' and approaches[0] == 'none':
+            # For none approach, return the response directly
+            result, _ = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
+            logger.debug(f'Direct proxy response: {result}')
+            return jsonify(result), 200
+
+        elif operation == 'AND' or operation == 'OR':
+            if contains_none:
+                raise ValueError("'none' approach cannot be combined with other approaches")
+
+        # Handle non-none approaches
         if operation == 'SINGLE':
-            final_response, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
+            response, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
         elif operation == 'AND':
-            final_response, completion_tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model)
+            response, completion_tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model)
         elif operation == 'OR':
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
-            final_response, completion_tokens = loop.run_until_complete(execute_parallel_approaches(approaches, system_prompt, initial_query, client, model))
+            response, completion_tokens = loop.run_until_complete(execute_parallel_approaches(approaches, system_prompt, initial_query, client, model))
         else:
            raise ValueError(f"Unknown operation: {operation}")
+
     except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        return jsonify({"error": str(e)}), 500
 
    if stream:
-        return Response(generate_streaming_response(final_response, model), content_type='text/event-stream')
+        return Response(generate_streaming_response(response, model), content_type='text/event-stream')
    else:
        response_data = {
            'model': model,
@@ -385,13 +459,13 @@ def proxy():
            }
        }
 
-        if isinstance(final_response, list):
-            for index, response in enumerate(final_response):
+        if isinstance(response, list):
+            for index, resp in enumerate(response):
                response_data['choices'].append({
                    'index': index,
                    'message': {
                        'role': 'assistant',
-                        'content': response,
+                        'content': resp,
                    },
                    'finish_reason': 'stop'
                })
@@ -400,13 +474,13 @@ def proxy():
                'index': 0,
                'message': {
                    'role': 'assistant',
-                    'content': final_response,
+                    'content': response,
                },
                'finish_reason': 'stop'
            })
 
-    logger.debug(f'API response: {response_data}')
-    return jsonify(response_data), 200
+        logger.debug(f'API response: {response_data}')
+        return jsonify(response_data), 200
 
 @app.route('/v1/models', methods=['GET'])
 def proxy_models():
````
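
Taken together, these routing changes mean a request with no approach prefix is now proxied straight through instead of defaulting to `bon`. A hedged client-side sketch (not part of the commit; model names and the API key are placeholders):

```python
# Sketch of the new default routing from a client's point of view.
# Assumes the optillm proxy is running on localhost:8000; names are placeholders.
from openai import OpenAI

client = OpenAI(api_key="optillm", base_url="http://localhost:8000/v1")
messages = [{"role": "user", "content": "Say hi"}]

# No approach prefix: parse_combined_approach now falls back to ['none'],
# so the request is passed through directly (previously it defaulted to 'bon').
direct = client.chat.completions.create(model="gpt-4o-mini", messages=messages)

# Explicit prefix: equivalent; the 'none-' part is stripped before the upstream call.
also_direct = client.chat.completions.create(model="none-gpt-4o-mini", messages=messages)

# Combining 'none' with other approaches in an AND/OR pipeline raises a
# ValueError in the proxy, which is returned to the client as an error response.
```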
