Commit 1a26110 (1 parent: 874716b)
Showing 8 changed files with 188 additions and 4 deletions.
@@ -0,0 +1,5 @@
use candle::Device;

// Pick the best available backend: Metal first, then CUDA, then fall back to CPU.
pub fn get_device() -> Device {
    Device::new_metal(0).unwrap_or(Device::new_cuda(0).unwrap_or(Device::Cpu))
}
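One note on this helper: Result::unwrap_or evaluates its argument eagerly, so the CUDA probe runs (and harmlessly returns an error) even on machines where the Metal device is created successfully. A lazier variant, shown here only as a sketch (the name get_device_lazy is hypothetical and not part of this commit), could chain or_else instead:

use candle::Device;

// Hypothetical alternative: try Metal, then CUDA, then fall back to CPU,
// attempting each backend only if the previous one failed.
pub fn get_device_lazy() -> Device {
    Device::new_metal(0)
        .or_else(|_| Device::new_cuda(0))
        .unwrap_or(Device::Cpu)
}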
@@ -14,3 +14,4 @@ pub mod capture_screenshot_by_window;
#[cfg(target_os = "windows")]
pub use microsoft::perform_ocr_windows;

pub use tesseract::perform_ocr_tesseract;
pub mod multimodal_embeddings;
@@ -0,0 +1,116 @@
use std::ops::Mul;

use anyhow::Result;
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::siglip::{Config, Model as SiglipModel};
use image::DynamicImage;
use tokenizers::Tokenizer;

pub struct MultimodalEmbedder {
    model: SiglipModel,
    tokenizer: Tokenizer,
    device: Device,
    config: Config,
}

impl MultimodalEmbedder {
    pub fn new(device: &Device) -> Result<Self> {
        let config = Config::base_patch16_224();

        // Load the model weights from the safetensors file on the Hugging Face Hub
        let model_file = {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("google/siglip-base-patch16-224".to_string());
            api.get("model.safetensors")?
        };

        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, device)? };

        let model = SiglipModel::new(&config, vb)?;
        let tokenizer = Self::get_tokenizer(None)?;

        Ok(Self {
            model,
            tokenizer,
            device: device.clone(),
            config,
        })
    }

    fn get_tokenizer(tokenizer_path: Option<String>) -> Result<Tokenizer> {
        let tokenizer_path = match tokenizer_path {
            None => {
                let api = hf_hub::api::sync::Api::new()?;
                let api = api.model("google/siglip-base-patch16-224".to_string());
                api.get("tokenizer.json")?
            }
            Some(path) => path.into(),
        };

        Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
    }

    pub fn compute_embeddings(
        &self,
        image: &DynamicImage,
        ocr_text: &str,
    ) -> Result<(Tensor, Tensor)> {
        let image_tensor = self.preprocess_image(image)?;
        let text_tensor = self.tokenize_text(ocr_text)?;

        let (text_embeddings, image_embeddings) =
            self.model.forward(&image_tensor, &text_tensor)?;
        Ok((text_embeddings, image_embeddings))
    }

    pub fn compute_similarity(
        &self,
        text_embeddings: &Tensor,
        image_embeddings: &Tensor,
    ) -> anyhow::Result<Tensor> {
        // Compute the dot product between text and image embeddings
        let similarity = text_embeddings.matmul(&image_embeddings.transpose(0, 1)?)?;

        // Apply softmax to turn the scores into probabilities
        let similarity = softmax(&similarity, 1)?;

        Ok(similarity)
    }

    fn preprocess_image(&self, image: &DynamicImage) -> Result<Tensor> {
        let image_size = self.config.vision_config.image_size;
        let img = image.resize_to_fill(
            image_size as u32,
            image_size as u32,
            image::imageops::FilterType::Triangle,
        );
        let img = img.to_rgb8();
        let img = img.into_raw();
        // HWC u8 -> CHW f32 scaled to [-1, 1], with a leading batch dimension
        let img = Tensor::from_vec(img, (image_size, image_size, 3), &self.device)?
            .permute((2, 0, 1))?
            .to_dtype(DType::F32)?
            .affine(2. / 255., -1.)?
            .unsqueeze(0)?;
        Ok(img)
    }

    fn tokenize_text(&self, text: &str) -> anyhow::Result<Tensor> {
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!(e))?;
        let mut tokens = encoding.get_ids().to_vec();
        let max_len = self.config.text_config.max_position_embeddings;
        let pad_id = self.config.text_config.pad_token_id;

        // Truncate inputs longer than the model's context, then pad to the fixed length
        tokens.truncate(max_len);
        if tokens.len() < max_len {
            tokens.extend(vec![pad_id; max_len - tokens.len()]);
        }

        let input_ids = Tensor::new(vec![tokens], &self.device)?;
        Ok(input_ids)
    }
}
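A property of compute_similarity worth noting: the softmax runs along dimension 1, so with a single text and a single image the matmul yields a 1x1 tensor and the result is always 1.0; the probabilities only become informative when several candidates are compared at once. Below is a sketch of ranking several frame embeddings against one text embedding (not part of the commit; rank_frames is a hypothetical helper, and it assumes embeddings come out with shape (1, dim) as the matmul above implies):

use candle::Tensor;

// Hypothetical helper: score one OCR text against a batch of frames.
// `image_embeddings_per_frame` holds one (1, dim) embedding per frame.
fn rank_frames(
    embedder: &MultimodalEmbedder,
    text_embeddings: &Tensor,
    image_embeddings_per_frame: &[Tensor],
) -> anyhow::Result<Tensor> {
    // Stack the per-frame embeddings into an (n_frames, dim) tensor so the
    // softmax in compute_similarity distributes probability over the frames.
    let stacked = Tensor::cat(image_embeddings_per_frame, 0)?;
    embedder.compute_similarity(text_embeddings, &stacked)
}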
@@ -0,0 +1,45 @@
use anyhow::Result;
use candle::Device;
use image::DynamicImage;
use screenpipe_core::get_device;
use screenpipe_vision::multimodal_embeddings::MultimodalEmbedder;
use std::time::Instant;

// Mock function to simulate screenshot capture
fn capture_screenshot() -> Result<DynamicImage> {
    // For this test, we'll create a dummy image
    let img = DynamicImage::new_rgb8(224, 224);
    Ok(img)
}

#[test]
fn test_screenshot_and_embedding_speed() -> Result<()> {
    let device = get_device();
    let embedder = MultimodalEmbedder::new(&device).unwrap();

    let start = Instant::now();

    // Capture screenshot
    let screenshot = capture_screenshot()?;
    let screenshot_time = start.elapsed();

    // Perform OCR (mocked for this test)
    let ocr_text = "This is a test OCR text";

    // Compute embeddings
    let embedding_start = Instant::now();
    let (text_embeddings, image_embeddings) = embedder.compute_embeddings(&screenshot, ocr_text)?;
    let embedding_time = embedding_start.elapsed();

    // Compute similarity
    let similarity = embedder.compute_similarity(&text_embeddings, &image_embeddings)?;

    let total_time = start.elapsed();

    println!("Screenshot capture time: {:?}", screenshot_time);
    println!("Embedding computation time: {:?}", embedding_time);
    println!("Total processing time: {:?}", total_time);
    println!("Similarity shape: {:?}", similarity.shape());

    Ok(())
}
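Because MultimodalEmbedder::new downloads model.safetensors and tokenizer.json from the Hugging Face Hub, the first run of this test needs network access and takes noticeably longer than later cached runs. Assuming the test lands in the screenpipe-vision crate (the package name is inferred from the screenpipe_vision import and may differ), it could be run with cargo test -p screenpipe-vision test_screenshot_and_embedding_speed -- --nocapture to see the timing output.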
Wow, created with a calm, positive vibe.