This repository was archived by the owner on Jun 24, 2024. It is now read-only.

Commit db5bc8e

Add loading code for ggjt

It can now load the model, but it is not yet working.

1 parent af5415f · commit db5bc8e

File tree

5 files changed: +237 -131 lines changed

Cargo.lock

+10 (generated lockfile; diff not rendered)

ggml/src/lib.rs

+9 -1

@@ -322,13 +322,21 @@ impl Tensor {
     /// # Safety
     ///
     /// The data must not be mutated while being read from.
-    pub unsafe fn data(&self) -> *mut c_void {
+    pub unsafe fn data(&self) -> *const c_void {
         self.with_alive_ctx(|| {
             // SAFETY: The with_alive_call guarantees the context is alive
             unsafe { *self.ptr.as_ptr() }.data
         })
     }
 
+    /// Set the tensor's data pointer (useful for mmap-ed data)
+    pub unsafe fn set_data(&self, data_ptr: *mut c_void) {
+        self.with_alive_ctx(|| {
+            // SAFETY: The with_alive_call guarantees the context is alive
+            unsafe { *self.ptr.as_ptr() }.data = data_ptr;
+        })
+    }
+
     /// Number of elements in this tensor.
     pub fn nelements(&self) -> usize {
         self.with_alive_ctx(|| {
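
For context, a minimal sketch (not part of this commit) of how the new setter is meant to pair with an mmap-ed weights file: the tensor is pointed directly at bytes inside the mapping instead of having its data copied into the ggml context. The helper name and offset here are hypothetical; only ggml::Tensor::set_data and memmap2::Mmap appear in this change.

    use std::os::raw::c_void;

    use memmap2::Mmap;

    /// Point `tensor` at the bytes starting at `offset` inside the mapped model file.
    ///
    /// SAFETY: the caller must guarantee that `offset` is in bounds and that `mmap`
    /// outlives the tensor.
    unsafe fn point_tensor_at_mmap(tensor: &ggml::Tensor, mmap: &Mmap, offset: usize) {
        tensor.set_data(mmap.as_ptr().add(offset) as *mut c_void);
    }

Because set_data only swaps the raw pointer, whoever owns the mapping has to keep it alive; that is what the new mmap field on Model (further down in this commit) is for.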

llama-rs/Cargo.toml

+1

@@ -16,3 +16,4 @@ rand = { workspace = true }
 serde = { version = "1.0.156", features = ["derive"] }
 serde_bytes = "0.11"
 bincode = "1.3.3"
+memmap2 = "0.5.10"
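
memmap2 provides the read-only memory mapping used by the new loading path. A minimal, self-contained usage sketch (the file name is a placeholder):

    use std::fs::File;

    use memmap2::Mmap;

    fn main() -> std::io::Result<()> {
        // Map the whole file read-only; Mmap derefs to &[u8].
        let file = File::open("model.bin")?;
        let mmap = unsafe { Mmap::map(&file)? };
        println!("mapped {} bytes", mmap.len());
        Ok(())
    }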

llama-rs/src/lib.rs

+17 -8

@@ -14,6 +14,7 @@ use std::{
     time,
 };
 
+use memmap2::Mmap;
 use thiserror::Error;
 
 use partial_sort::PartialSort;
@@ -66,6 +67,8 @@ pub struct Model {
 
     tensors: HashMap<String, ggml::Tensor>,
 
+    mmap: Option<Mmap>,
+
     // Must be kept alive for the model
     _context: ggml::Context,
 }
@@ -502,7 +505,7 @@ pub enum LoadError {
        /// The name of the tensor.
        tensor_name: String,
        /// The format type that was encountered.
-        ftype: u32,
+        ftype: i32,
        /// The path that failed.
        path: PathBuf,
    },
@@ -585,12 +588,13 @@ impl Model {
 
         let main_path = path.as_ref();
 
-        let mut reader =
-            BufReader::new(
-                File::open(main_path).map_err(|e| LoadError::OpenFileFailed {
+        let file = File::open(main_path).map_err(|e| LoadError::OpenFileFailed {
                     source: e,
                     path: main_path.to_owned(),
-                })?,
+                })?;
+        let mut reader =
+            BufReader::new(
+                &file,
             );
 
         // Verify magic
@@ -732,7 +736,7 @@ impl Model {
         // Initialize the context
         let context = ggml::Context::init(ctx_size);
 
-        let model = {
+        let mut model = {
             let mut tensors = HashMap::new();
 
             let tok_embeddings = context.new_tensor_2d(wtype, n_embd, n_vocab);
@@ -796,15 +800,20 @@ impl Model {
                layers,
                tensors,
                _context: context,
+                mmap: None,
            }
        };
 
        match model_type {
            ModelType::GGMF | ModelType::Unversioned => {
-                load_weights_ggmf_or_unversioned(reader, main_path, load_progress_callback, &model)?
+                let file_offset = reader.stream_position()?;
+                drop(reader);
+                load_weights_ggmf_or_unversioned(file_offset, main_path, load_progress_callback, &model)?
            }
            ModelType::GGJT => {
-                load_weights_ggjt(reader, main_path, load_progress_callback, &model)?
+                let mmap = unsafe { Mmap::map(&file)? };
+                load_weights_ggjt(&mut reader, &mmap, main_path, load_progress_callback, &model)?;
+                model.mmap = Some(mmap);
            }
        }
 
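
The body of load_weights_ggjt is not part of this commit, so the following is only a hypothetical sketch of how an mmap-based loading loop could use the new set_data: each tensor named in the file is pointed at its bytes inside the mapping rather than copied. TensorEntry and read_tensor_entry are invented placeholders for the ggjt header parsing, not APIs from this repository.

    use std::collections::HashMap;
    use std::io::{BufRead, Seek, SeekFrom};
    use std::os::raw::c_void;

    use memmap2::Mmap;

    /// Hypothetical description of one tensor record in the file.
    struct TensorEntry {
        name: String,
        /// Absolute offset of the tensor's raw data within the file.
        data_offset: u64,
        /// Length of the raw data in bytes.
        data_len: u64,
    }

    /// Placeholder: parse the next tensor header, returning None at end of file.
    fn read_tensor_entry(_reader: &mut (impl BufRead + Seek)) -> std::io::Result<Option<TensorEntry>> {
        todo!("header parsing is not shown in this commit")
    }

    /// Sketch of the loading loop: no weight data is copied into the ggml context;
    /// every tensor is pointed straight at the mapped file.
    fn load_weights_ggjt_sketch(
        reader: &mut (impl BufRead + Seek),
        mmap: &Mmap,
        tensors: &HashMap<String, ggml::Tensor>,
    ) -> std::io::Result<()> {
        while let Some(entry) = read_tensor_entry(reader)? {
            let tensor = tensors
                .get(&entry.name)
                .expect("tensor listed in the file but missing from the model");

            // SAFETY: the mapping is stored in Model::mmap, so it outlives the tensor.
            unsafe {
                tensor.set_data(mmap.as_ptr().add(entry.data_offset as usize) as *mut c_void);
            }

            // Skip past the raw data so the next header can be read.
            reader.seek(SeekFrom::Current(entry.data_len as i64))?;
        }
        Ok(())
    }

This also mirrors the call-site changes above: the GGMF/unversioned path now takes the current file offset and drops the buffered reader, while the ggjt path keeps both the reader (for headers) and the mapping (for data), storing the Mmap on the model so the pointers installed with set_data stay valid.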
