Remove deprecated crates and introduce OpenCL matcher integration

This commit is contained in:
2025-04-09 13:41:50 +02:00
parent 2a584e878f
commit 47bbf25ac7
17 changed files with 202 additions and 187 deletions

View File

@@ -0,0 +1,9 @@
# Manifest of the OpenCL-backed matcher crate.
[package]
name = "schemsearch-ocl-matcher"
version = "0.1.0"
edition = "2021"
[dependencies]
# Sibling crate in this workspace; the matcher imports `Match` and `SearchBehavior` from it.
schemsearch-common = { path = "../schemsearch-common" }
# OpenCL bindings used by the host code (`ProQue`, `Buffer`, `MemFlags`, `Platform`).
ocl = "0.19.7"
# Imported in code under the crate name `math`; the matcher uses `math::round::ceil`.
libmath = "0.2.1"

View File

@@ -0,0 +1,51 @@
/*
 * Counts, for every candidate origin in the schematic, how many pattern
 * blocks differ from the schematic blocks they overlay.
 *
 * One work-item handles one origin (x, y, z); the global work size is the
 * schematic size (width, height, depth). Block arrays are indexed as
 *   x + width * (z + y * depth)
 * and the result array uses the same layout.
 *
 * result      per-origin mismatch count; origins where the pattern would
 *             stick out of the schematic are never written, so they keep
 *             whatever value the host initialised the buffer with
 * schem       schematic block ids (read only)
 * pattern     pattern block ids (read only)
 * p_width, p_height, p_depth   pattern dimensions
 * air_id      block id compared against by the two air flags
 * ignore_air, air_as_any       non-zero enables the respective skip rule
 * skipamount  counting stops as soon as the mismatch count exceeds this
 *             value; the stored count is then skipamount + 1
 *
 * The block ids are declared as signed `int` because the host uploads
 * `i32` data and passes `air_id` as `i32`; the values are only compared
 * for equality, so signedness does not affect the outcome.
 */
__kernel void add(__global int* result,
                  __global const int* schem,
                  __global const int* pattern,
                  const int p_width,
                  const int p_height,
                  const int p_depth,
                  const int air_id,
                  const int ignore_air,
                  const int air_as_any,
                  const int skipamount) {
    int x = get_global_id(0);
    int y = get_global_id(1);
    int z = get_global_id(2);
    int width = get_global_size(0);
    int height = get_global_size(1);
    int depth = get_global_size(2);

    /* The pattern must fit completely inside the schematic at this origin. */
    if (x > width - p_width || y > height - p_height || z > depth - p_depth) {
        return;
    }

    /* Slot of this origin in the result array. */
    int out_idx = x + z * width + y * width * depth;

    int wrong_blocks = 0;
    for (int py = 0; py < p_height; py++) {
        for (int pz = 0; pz < p_depth; pz++) {
            for (int px = 0; px < p_width; px++) {
                int s_idx = (x + px) + width * ((z + pz) + (y + py) * depth);
                int p_idx = px + p_width * (pz + py * p_depth);
                int schem_block = schem[s_idx];
                int pattern_block = pattern[p_idx];

                /*
                 * Skip the comparison when ignore_air is set and the schematic
                 * block differs from air_id, or when air_as_any is set and the
                 * pattern block differs from air_id.
                 * NOTE(review): the flag names suggest `==` rather than `!=`
                 * here; left unchanged, confirm against the CPU matcher.
                 */
                if ((ignore_air && schem_block != air_id) || (air_as_any && pattern_block != air_id)) {
                    continue;
                }

                if (schem_block != pattern_block) {
                    wrong_blocks++;
                    /* Too many mismatches: store the truncated count and stop. */
                    if (wrong_blocks > skipamount) {
                        result[out_idx] = wrong_blocks;
                        return;
                    }
                }
            }
        }
    }

    result[out_idx] = wrong_blocks;
}

View File

@@ -0,0 +1,95 @@
use ocl::{Buffer, MemFlags, ProQue, Platform};
use ocl::SpatialDims::Three;
use schemsearch_common::{Match, SearchBehavior};
use math::round::ceil;
const KERNEL: &str = include_str!("kernel.cl");
pub fn ocl_available() -> bool {
!Platform::list().is_empty()
}
/// Searches `schem` for occurrences of `pattern` on an OpenCL device.
///
/// Forwards every argument unchanged to `search_ocl` and converts an OpenCL
/// error into its string representation, so callers do not need to depend on
/// the `ocl` error type.
///
/// # Errors
///
/// Returns the stringified error reported by `search_ocl`.
pub fn ocl_search(
    schem: &[i32],
    schem_size: [usize; 3],
    pattern: &[i32],
    pattern_size: [usize; 3],
    air_id: i32,
    search_behavior: SearchBehavior,
) -> Result<Vec<Match>, String> {
    match search_ocl(schem, schem_size, pattern, pattern_size, air_id, search_behavior) {
        Ok(matches) => Ok(matches),
        Err(err) => Err(err.to_string()),
    }
}
fn search_ocl(
schem: &[i32],
schem_size: [usize; 3],
pattern: &[i32],
pattern_size: [usize; 3],
air_id: i32,
search_behavior: SearchBehavior,
) -> ocl::Result<Vec<Match>> {
let pattern_width = pattern_size[0];
let pattern_height = pattern_size[1];
let pattern_length = pattern_size[2];
let schem_width = schem_size[0];
let schem_height = schem_size[1];
let schem_length = schem_size[2];
let pattern_blocks = (pattern_width * pattern_height * pattern_length) as f32;
let skip_amount = ceil((pattern_blocks * (1.0 - search_behavior.threshold)) as f64, 0) as i32;
let pro_que = ProQue::builder()
.src(KERNEL)
.dims(Three(schem_width, schem_height, schem_length))
.build()?;
let buffer = Buffer::builder()
.queue(pro_que.queue().clone())
.flags(MemFlags::new().read_write())
.fill_val(-1)
.len(schem.len())
.build()?;
let schem_buffer = create_schem_buffer(schem, &pro_que)?;
let pattern_buffer = create_schem_buffer(pattern, &pro_que)?;
let kernel = pro_que.kernel_builder("add")
.arg(&buffer)
.arg(&schem_buffer)
.arg(&pattern_buffer)
.arg(pattern_width as i32)
.arg(pattern_height as i32)
.arg(pattern_length as i32)
.arg(air_id) // air_id
.arg(search_behavior.ignore_air as u32) // ignore_air
.arg(search_behavior.air_as_any as u32) // air_as_any
.arg(skip_amount)
.build()?;
unsafe { kernel.enq()?; }
let mut vec = vec![0i32; buffer.len()];
buffer.read(&mut vec).enq()?;
Ok(vec.into_iter().enumerate().filter(|(_, v)| *v < skip_amount && *v != -1).map(|(i, v)| {
Match {
x: (i % schem_width) as u16,
y: ((i / (schem_width * schem_length)) % schem_height) as u16,
z: ((i / schem_width) % schem_length) as u16,
percent: (pattern_blocks - v as f32) / pattern_blocks,
}
}).collect())
}
/// Creates a read-only device buffer on the queue of `pro_que`, sized to
/// `pattern.len()` and initialised with a copy of `pattern`.
///
/// Despite the parameter name this is used for any block-id slice, not only
/// for patterns.
///
/// # Errors
///
/// Returns the error reported by the buffer builder.
fn create_schem_buffer(pattern: &[i32], pro_que: &ProQue) -> ocl::Result<Buffer<i32>> {
    let queue = pro_que.queue().clone();
    let access = MemFlags::new().read_only();
    let builder = Buffer::builder()
        .queue(queue)
        .flags(access)
        .len(pattern.len())
        .copy_host_slice(pattern);
    builder.build()
}