Source code for malva.combine

import h5py
import os
import shutil
import logging
import json
import uuid
import tempfile

from rich.progress import track
from concurrent.futures import ProcessPoolExecutor, as_completed
from threading import current_thread, main_thread

from malva.index import MalvaIndex
from malva.utils import check_directory_exists, check_file_exists

def combine_indices(combine_dir, project_uuids=None, project_id_offset=0):
    """Create a combined index in ``combine_dir`` that links to every sub-index.

    A new split-driver HDF5 file ``malva_index.h5`` is written into
    ``combine_dir``. For each chunk of each sub-index, three external links
    (``index_<k>_indices``, ``index_<k>_indptr``, ``index_<k>_data``) are
    created that point at the corresponding datasets of the sub-index, where
    ``k`` is a running chunk counter over all sub-indices.

    Args:
        combine_dir: Directory that contains one sub-directory per sub-index.
        project_uuids: Optional list of sub-index folders (relative to
            ``combine_dir`` or absolute). When ``None``, every directory
            found in ``combine_dir`` is used, in sorted order.
        project_id_offset: Value added to the running project counter to
            obtain the project ID stored in the mapping.

    Returns:
        A tuple ``(project_mapping, n_chunks_total)`` where ``project_mapping``
        maps the combined chunk number to ``(project_id, project_uuid)``.

    Raises:
        FileNotFoundError: If a sub-index directory exists but does not
            contain both ``malva_index.h5-r.h5`` and ``malva_index.h5-m.h5``.
    """
    output_file = os.path.join(combine_dir, "malva_index.h5")
    project_mapping = {}
    n_chunks_total = 0
    n_unique_projects = 0

    with h5py.File(output_file, 'w', driver='split') as f:
        if project_uuids is None:
            entries = sorted(os.listdir(combine_dir))
        else:
            # Keep the entries exactly as given; they are resolved against
            # combine_dir once, below. Joining here as well duplicated a
            # relative combine_dir and made every sub-index look missing.
            entries = list(project_uuids)

        # Only wrap the iteration with rich's ``track`` on the main thread.
        if current_thread() == main_thread():
            iterator = track(entries, description='Processing sub-indices')
        else:
            iterator = entries

        for entry in iterator:
            if os.path.isabs(entry):
                folder_path = entry
            else:
                folder_path = os.path.join(combine_dir, entry)

            if not os.path.isdir(folder_path):
                # Plain files (including this function's own output files)
                # are expected when listing the directory; an explicitly
                # requested project that is missing deserves a warning.
                if project_uuids is not None:
                    logging.warning(f"Sub-index folder {folder_path} is not a directory, skipping")
                continue

            index_file = os.path.join(folder_path, "malva_index.h5")
            if not os.path.exists(f'{index_file}-r.h5') or not os.path.exists(f'{index_file}-m.h5'):
                raise FileNotFoundError(f"The malva index {folder_path} was not found")

            project_id = project_id_offset + n_unique_projects
            # Take the UUID from the entry being processed. Indexing
            # project_uuids by project_id picks the wrong UUID as soon as an
            # entry is skipped or a non-zero offset is used.
            if project_uuids is None:
                project_uuid = os.path.basename(folder_path)
            else:
                project_uuid = entry

            with h5py.File(index_file, 'r', driver="split") as index_f:
                n_chunks = int(index_f.attrs['n_chunks'])
                for i in range(n_chunks):
                    for part in ("indices", "indptr", "data"):
                        f[f"index_{n_chunks_total}_{part}"] = h5py.ExternalLink(
                            f'{index_file}', f"index_{i}_{part}"
                        )
                    project_mapping[n_chunks_total] = (project_id, project_uuid)
                    n_chunks_total += 1

                if 'spatial_coord' in index_f and 'spatial_coord' not in f:
                    # TODO: fix so we store the spatial coordinates from
                    # various projects. Only works for single cell now!
                    f.create_dataset('spatial_coord', data=index_f['spatial_coord'])

                # The first sub-index that defines an attribute wins.
                for attr_name in ('kmer_size', 'coord_lims', 'n_spatial'):
                    if attr_name not in f.attrs and attr_name in index_f.attrs:
                        f.attrs[attr_name] = index_f.attrs[attr_name]

            n_unique_projects += 1

        # json.dumps turns the integer chunk keys into strings and the
        # (id, uuid) tuples into two-element lists.
        f.attrs['project_mapping'] = json.dumps(project_mapping)
        f.attrs['n_chunks'] = n_chunks_total

    logging.info(f"Created combined index with {n_chunks_total} chunks from {n_unique_projects} projects")
    return project_mapping, n_chunks_total
def process_merge_chunks(index_dir, merge_projects=False):
    """Merge the chunks of the index in ``index_dir`` in place.

    Nothing happens when the index has at most one chunk. Otherwise the
    chunks are merged into ``<index_file>.merged`` and the two resulting
    split files (``-r.h5`` and ``-m.h5``) replace the original ones.

    Args:
        index_dir: Directory of the index, passed to ``MalvaIndex``.
        merge_projects: Forwarded to ``MalvaIndex.merge_chunks``.
    """
    mindex = MalvaIndex(index_dir)
    mindex.verbose = True

    # Guard clause: a single-chunk index needs no merging.
    if not mindex.n_chunks > 1:
        return

    logging.info(f"Merging {mindex.n_chunks} chunks in index at {index_dir}")
    merged_file = f"{mindex.index_file}.merged"
    if merge_projects:
        logging.info("Merging projects with distinct project IDs")
    mindex.merge_chunks(merged_file, merge_projects=merge_projects)

    split_suffixes = ('-r.h5', '-m.h5')
    # Remove both original files first, then move the merged ones in.
    for suffix in split_suffixes:
        os.remove(f'{mindex.index_file}{suffix}')
    for suffix in split_suffixes:
        shutil.move(f'{merged_file}{suffix}', f'{mindex.index_file}{suffix}')
def process_group(group_idx, group, base_dir, project_uuids, merge_chunks, merge_projects, project_id_offset=0):
    """Combine one group of sub-indices into an intermediate index.

    A temporary directory is created next to ``base_dir``; every folder of
    the group is symlinked (or copied, if symlinking fails) into it, and
    ``combine_indices`` is run on that directory.

    Args:
        group_idx: Position of this group, returned unchanged.
        group: Names of the sub-index folders inside ``base_dir``.
        base_dir: Directory that holds the sub-index folders.
        project_uuids: Accepted for interface compatibility; not used here
            (``combine_indices`` is called with ``None``).
        merge_chunks: When true, merge the chunks of the intermediate index.
        merge_projects: Forwarded to ``process_merge_chunks``.
        project_id_offset: Forwarded to ``combine_indices``.

    Returns:
        A tuple ``(group_idx, mapping, temp_group_dir)``. The caller is
        responsible for removing ``temp_group_dir``.
    """
    parent_dir = os.path.dirname(os.path.abspath(base_dir))
    temp_group_dir = tempfile.mkdtemp(prefix="malva_group_", dir=parent_dir)

    for folder in sorted(group):
        # The link target must be absolute: a relative target is resolved
        # against the directory that contains the link (the temp dir), so a
        # relative base_dir would produce dangling links that
        # combine_indices then silently skips.
        src = os.path.abspath(os.path.join(base_dir, folder))
        dst = os.path.join(temp_group_dir, folder)
        try:
            os.symlink(src, dst)
        except (OSError, NotImplementedError) as e:
            logging.warning(f"Symlink failed for {src} with error {e}, using copytree instead.")
            shutil.copytree(src, dst)

    mapping, _ = combine_indices(temp_group_dir, None, project_id_offset=project_id_offset)
    if merge_chunks:
        process_merge_chunks(temp_group_dir, merge_projects)
    return group_idx, mapping, temp_group_dir
def hierarchical_combine(base_dir, project_uuids, merge_chunks=False, merge_projects=False, group_size=16, threads=1):
    """Combine many sub-indices in two levels.

    When ``base_dir`` holds at most ``group_size`` sub-index folders,
    ``combine_indices`` is called directly (followed by a chunk merge if
    requested). Otherwise the folders are split into groups of
    ``group_size``, each group is combined into an intermediate index by
    ``process_group`` in a process pool, and the intermediate indices are
    combined into the final index, which is copied into ``base_dir``.

    Args:
        base_dir: Directory that holds the sub-index folders.
        project_uuids: Passed to ``combine_indices`` on the direct path and
            to ``process_group`` on the hierarchical path.
        merge_chunks: When true, merge chunks after each combine step.
        merge_projects: Forwarded to the group-level chunk merge.
        group_size: Maximum number of sub-indices per group.
        threads: Number of worker processes.

    Returns:
        None.
    """
    all_indices = [
        d for d in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, d))
        and d not in ["malva_index.h5", "malva_index.h5-r.h5", "malva_index.h5-m.h5"]
    ]
    all_indices.sort()

    if len(all_indices) <= group_size:
        logging.info("Number of indices is within direct merge limit. Calling combine_indices directly.")
        combine_indices(base_dir, project_uuids)
        if merge_chunks:
            process_merge_chunks(base_dir, merge_projects)
        return

    logging.info(f"Performing hierarchical merge on {len(all_indices)} indices in groups of {group_size}")

    # Each group gets the number of projects that precede it as its offset,
    # so project IDs are unique across groups.
    groups = []
    current_project_offset = 0
    for start in range(0, len(all_indices), group_size):
        group = all_indices[start:start + group_size]
        groups.append((len(groups), group, current_project_offset))
        current_project_offset += len(group)

    results = []
    failed_groups = []
    with ProcessPoolExecutor(max_workers=threads) as executor:
        futures = {
            executor.submit(
                process_group, group_idx, group, base_dir, project_uuids,
                merge_chunks, merge_projects, project_id_offset
            ): group_idx
            for group_idx, group, project_id_offset in groups
        }
        for future in as_completed(futures):
            try:
                result = future.result()
                results.append(result)
                logging.info(f"Completed processing group {futures[future]}")
            except Exception as e:
                logging.error(f"Error processing group {futures[future]}: {e}")
                failed_groups.append(futures[future])

    # Best-effort: failed groups are left out, but say so explicitly.
    if failed_groups:
        logging.error(
            f"{len(failed_groups)} group(s) failed and are missing from the final index: {sorted(failed_groups)}"
        )

    # Sort results by group index to maintain correct order
    results.sort(key=lambda x: x[0])
    intermediate_dirs = [temp_group_dir for _, _, temp_group_dir in results]

    # merge the intermediate merged indices into a final combined index
    parent_dir = os.path.dirname(os.path.abspath(base_dir))
    final_temp_dir = tempfile.mkdtemp(prefix="malva_final_", dir=parent_dir)
    logging.info(f"Merging {len(intermediate_dirs)} intermediate indices into final index")

    # combine_indices processes the folders in sorted (lexicographic) name
    # order. Zero-pad the counter so that order equals the numeric order
    # used for the project mapping below; without padding,
    # "intermediate_10" sorts before "intermediate_2".
    width = len(str(max(len(intermediate_dirs) - 1, 0)))
    for idx, inter_dir in enumerate(intermediate_dirs):
        link_name = os.path.join(final_temp_dir, f"intermediate_{idx:0{width}d}")
        try:
            os.symlink(inter_dir, link_name)
        except Exception as e:
            logging.warning(f"Symlink failed for {inter_dir} with error {e}, using copytree instead.")
            shutil.copytree(inter_dir, link_name)

    combine_indices(final_temp_dir, project_uuids=None)
    if merge_chunks:
        process_merge_chunks(final_temp_dir, merge_projects=False)

    # Rebuild the project mapping from the intermediate indices, which carry
    # the original project IDs and UUIDs of the sub-indices.
    final_mapping = {}
    global_chunk_idx = 0
    logging.info("Reading project mappings from intermediate indices...")
    for inter_dir in intermediate_dirs:
        inter_index = os.path.join(inter_dir, "malva_index.h5")
        try:
            with h5py.File(inter_index, 'r', driver='split') as f:
                if 'project_mapping' in f.attrs:
                    sub_mapping = json.loads(f.attrs['project_mapping'])
                    # Keys are strings after the JSON round trip; order them
                    # numerically.
                    for key in sorted(sub_mapping.keys(), key=int):
                        project_id, project_uuid = sub_mapping[key]
                        final_mapping[str(global_chunk_idx)] = [project_id, project_uuid]
                        global_chunk_idx += 1
        except Exception as e:
            logging.error(f"Error reading project mapping from {inter_index}: {e}")

    final_index = os.path.join(final_temp_dir, "malva_index.h5")
    with h5py.File(final_index, 'r+', driver='split') as f:
        f.attrs['project_mapping'] = json.dumps(final_mapping)
    logging.info(f"Updated final index with project mapping for {len(final_mapping)} chunks")

    # NOTE(review): when merge_chunks is False the final index still holds
    # ExternalLinks into final_temp_dir / the intermediate directories, which
    # are removed below -- confirm that this path yields a usable index.

    # Ensure file exists before moving
    for suffix in ["", "-r.h5", "-m.h5"]:
        source_file = f"{final_index}{suffix}"
        target_file = os.path.join(base_dir, f"malva_index.h5{suffix}")
        if os.path.exists(source_file):
            logging.info(f"Moving {source_file} to {target_file}")
            shutil.copy2(source_file, target_file)
            try:
                os.remove(source_file)
            except Exception as e:
                logging.warning(f"Could not remove source file {source_file}: {e}")
        else:
            logging.warning(f"Source file {source_file} does not exist, skipping move operation")

    # Clean up with error handling
    try:
        for temp_dir in intermediate_dirs:
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
        if os.path.exists(final_temp_dir):
            shutil.rmtree(final_temp_dir)
    except Exception as e:
        logging.error(f"Error during cleanup: {e}")

    logging.info("Hierarchical merge complete and temporary files cleaned up.")
def _run_combine(args):
    """Entry point of the ``combine`` command.

    Reads ``index_in``, ``merge_projects``, ``uuid``, ``merge_chunks`` and
    ``threads`` from ``args``. Unless a combined index already exists in
    ``args.index_in``, the sub-indices found there are combined (directly
    for up to 16 folders, hierarchically otherwise). If ``merge_chunks`` is
    set and the combined index has more than one chunk, the chunks are
    merged at the end.

    Args:
        args: Parsed command-line arguments (an ``argparse`` namespace).

    Returns:
        None.
    """
    if not check_directory_exists(args.index_in):
        logging.error("Base directory does not exist")
        return

    project_uuids = None
    if args.merge_projects and check_file_exists(args.uuid):
        with open(args.uuid) as file:
            # Skip blank lines: an empty UUID would resolve to the base
            # directory itself and be treated as a sub-index folder.
            project_uuids = [line.rstrip() for line in file if line.strip()]

    index_out = os.path.join(args.index_in, "malva_index.h5")

    # Check if the combined index already exists.
    if check_file_exists(index_out + "-r.h5") and check_file_exists(index_out + "-m.h5"):
        logging.warning(f"The combined (non-merged) index already exists at {index_out}")
    else:
        # Identify sub-indices in the base directory.
        index_dirs = [
            d for d in os.listdir(args.index_in)
            if os.path.isdir(os.path.join(args.index_in, d))
            and d not in ["malva_index.h5", "malva_index.h5-r.h5", "malva_index.h5-m.h5"]
        ]
        if len(index_dirs) <= 16:
            logging.info("Number of indices is 16 or less, merging directly.")
            combine_indices(args.index_in, project_uuids)
        else:
            logging.info("Large number of indices detected, performing hierarchical merging.")
            hierarchical_combine(
                args.index_in,
                project_uuids,
                merge_chunks=args.merge_chunks,
                merge_projects=args.merge_projects,
                threads=args.threads,
            )

        # we re-configure the number of spatial merged coordinates
        # NOTE(review): assumed to apply only to a freshly combined index;
        # confirm whether it should also run when the index already existed.
        mindex = MalvaIndex(args.index_in, verbose=True)
        mindex.open("r+")
        mindex.index.attrs['n_spatial'] = 1_000_000_000
        mindex.close()

    # After combining, if merge_chunks is enabled and there is more than one chunk,
    # perform a final chunk merge on the combined index.
    if args.merge_chunks:
        # Create a MalvaIndex instance for the final merged index.
        mindex = MalvaIndex(args.index_in, verbose=True)
        if mindex.n_chunks > 1:
            logging.info(f"Now, {mindex.n_chunks} chunks will be merged at final level")
            merged_file = f"{mindex.index_file}.merged"
            if args.merge_projects:
                logging.info("Merging projects with distinct project IDs at final level")
            mindex.merge_chunks(merged_file, merge_projects=args.merge_projects)
            # Replace the unmerged split files with the merged ones.
            os.remove(f'{mindex.index_file}-r.h5')
            os.remove(f'{mindex.index_file}-m.h5')
            shutil.move(f'{merged_file}-r.h5', f'{mindex.index_file}-r.h5')
            shutil.move(f'{merged_file}-m.h5', f'{mindex.index_file}-m.h5')

    logging.info("SUCCESS!")


if __name__ == "__main__":
    from malva.cli import get_combine_parser

    args = get_combine_parser().parse_args()
    # Pass the parsed arguments; calling _run_combine() without them raises
    # a TypeError.
    _run_combine(args)