# Source code for malva.serve.serve

import io
import logging
import numpy as np
import pandas as pd
from flask import Flask, render_template, send_file, request, session, jsonify, Blueprint, g, url_for, make_response
from flask_session import Session
from flask_cors import CORS
from PIL import Image
import uuid
import os
from dataclasses import dataclass
from typing import Optional, Tuple, List, Union
import datashader as ds
from skimage.filters import gaussian
from scipy.spatial import cKDTree
import threading
from pathlib import Path
import tempfile
import re

# for proxy functionality
from urllib.parse import urljoin
from werkzeug.middleware.proxy_fix import ProxyFix

from malva.index import MalvaIndex
from malva.dbutils import handle_sequence
from malva.serve.reportgen import HTMLReportGenerator
# from malva.utils import check_file_exists
# from malva.serve.modeling import handle_natural_query, setup_model
# from malva.serve.templates.strings import HINT_SEQUENCE_QUERY

# Configure logging: set the root logger to INFO (basicConfig is a no-op if
# the root logger already has handlers) and create this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# Errors
class SequenceValidationError(ValueError):
    """Raised when user-supplied query input cannot be searched.

    Besides the message (available via ``str(err)``), the error carries an
    optional ``help_text`` hint that the ``/search`` endpoint returns to the
    client under the ``"help"`` key.
    """

    def __init__(self, message: str, help_text: Optional[str] = None):
        self.help_text = help_text
        super().__init__(message)
# Constants
MAX_LOAD_ALL = 10 ** 7  # cap on index entries sampled for the background layer
TILE_SIZE = 256  # edge length of a rendered map tile, in pixels
MAX_SEQUENCE_LENGTH = 1_000  # overwritten from CLI args by _run_serve

# A fully transparent RGBA tile plus its PNG encoding, built once at import.
# NOTE: EMPTY_TILE_BYTES is a single stateful stream; anything that reads or
# closes it affects every later user of this object.
EMPTY_TILE = np.zeros(shape=(TILE_SIZE, TILE_SIZE, 4), dtype="uint8")
EMPTY_TILE_BYTES = io.BytesIO()
Image.fromarray(EMPTY_TILE).save(EMPTY_TILE_BYTES, "PNG")
EMPTY_TILE_BYTES.seek(0)
def interactive_query_standard(
    sequence: str,
    sliding_size: int = 128,
    pct_threshold: float = 0.65,
    low_complexity_filter: bool = True,
    countmaxkmer: int = 100_000,
    countminkmer: int = 10,
) -> Tuple[np.ndarray, np.ndarray, List]:
    """Run a ``where`` query on the global k-mer index and unpack the first hit.

    Args:
        sequence: query passed straight to ``kmer_index.where``. NOTE(review):
            the ``/search`` route passes a list of sequences here despite the
            ``str`` annotation.
        sliding_size: forwarded as ``sliding_size``.
        pct_threshold: forwarded as ``pct_threshold``.
        low_complexity_filter: accepted for API compatibility but currently
            NOT forwarded to the index query.
        countmaxkmer: upper k-mer count bound, forwarded as ``count_at_most``.
        countminkmer: lower k-mer count bound, forwarded as ``count_at_least``.

    Returns:
        ``(locs, ints, where_abundant)`` taken from element 0 of the result.
    """
    logger.info(f"Querying sequence '{sequence}'")
    where_options = dict(
        sliding_size=sliding_size,
        pct_threshold=pct_threshold,
        count_at_most=int(countmaxkmer),
        count_at_least=int(countminkmer),
        use_background_model=False,
        show_coverage=False,
        force_reload=False,
        max_mem="1M",
    )
    first_hit = global_state.kmer_index.where(sequence, **where_options)[0]
    locs, ints, where_abundant = first_hit
    return locs, ints, where_abundant
class MalvaProxyFix:
    """WSGI middleware that mounts the wrapped app under ``/api/malva/view/<uuid>``.

    Requests whose path lies under that prefix are rewritten so that
    ``SCRIPT_NAME`` holds the prefix and ``PATH_INFO`` holds the remainder,
    which lets Flask route them and build URLs relative to the mount point.
    """

    def __init__(self, app, uuid):
        self.app = app
        self.prefix = f"/api/malva/view/{uuid}"

    def __call__(self, environ, start_response):
        # PEP 3333 allows PATH_INFO to be missing; treat that as an empty path.
        path = environ.get('PATH_INFO', '')

        # Requests may arrive with the leading "/api/malva" already removed
        # (path starts with "/view/"); restore the full path first.
        if path.startswith('/view/'):
            path = f"/api/malva{path}"
            environ['PATH_INFO'] = path

        # Strip the prefix only on a path-segment boundary, so that uuid "abc"
        # does not capture "/api/malva/view/abcdef/...". The remainder is then
        # either empty or begins with "/".
        if path == self.prefix or path.startswith(self.prefix + '/'):
            environ['SCRIPT_NAME'] = self.prefix
            environ['PATH_INFO'] = path[len(self.prefix):]

        return self.app(environ, start_response)
@dataclass
class SpatialData:
    """A point cloud with per-point values and its bounding box.

    ``bounds`` is ordered ``(xmin, xmax, ymin, ymax)``; the read-only
    properties below expose each component by name.
    """

    coordinates: np.ndarray  # Nx2 array of x,y coordinates
    values: np.ndarray  # N-length array of values
    bounds: Tuple[float, float, float, float]  # xmin, xmax, ymin, ymax

    # Not annotated, so dataclass does not treat these as fields.
    xmin = property(lambda self: self.bounds[0], doc="Lower x bound (bounds[0]).")
    xmax = property(lambda self: self.bounds[1], doc="Upper x bound (bounds[1]).")
    ymin = property(lambda self: self.bounds[2], doc="Lower y bound (bounds[2]).")
    ymax = property(lambda self: self.bounds[3], doc="Upper y bound (bounds[3]).")
class GlobalState:
    """Process-wide singleton holding the loaded k-mer index and spatial data.

    Every ``GlobalState()`` call returns the same instance. ``__init__`` only
    assigns defaults the first time it runs, so later calls do not reset
    loaded state.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # Double-checked creation: the lock is only taken while no instance exists.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if not hasattr(self, 'initialized'):
            self.coordinates = None
            self.values = None
            self.bounds = None
            # Filled in by initialize_from_index(). They are defaulted here so
            # that get_coordinates() and callers checking `_loc_all is not None`
            # see None, instead of hitting AttributeError, when the state was
            # set up via initialize().
            self.kmer_index = None
            self.xy = None
            self._loc_all = None
            self._abu_all = None
            self.initialized = False

    def initialize_from_index(self, index_path: str, max_mem: float = 32):
        """Open a malva index and populate the global state from it.

        Sets ``kmer_index``, the background sample (``_loc_all`` unique
        locations and ``_abu_all`` their counts), ``bounds`` and ``xy``, then
        marks the state initialized.

        Args:
            index_path: path handed to ``MalvaIndex``.
            max_mem: forwarded as ``max_mem`` to the warm-up ``where`` query.

        Raises:
            Any exception raised while opening or reading the index is logged
            and re-raised.
        """
        with self._lock:
            try:
                logger.info("Loading malva index and metadata")
                self.kmer_index = MalvaIndex(index_path, verbose=True)
                self.kmer_index.open()

                # Background sample: when the index holds at least
                # 2*MAX_LOAD_ALL entries use the slice [MAX_LOAD_ALL, 2*MAX_LOAD_ALL),
                # otherwise the first MAX_LOAD_ALL entries.
                index_data = self.kmer_index.index["index_0_data"]
                if len(index_data) >= 2 * MAX_LOAD_ALL:
                    sample = index_data[MAX_LOAD_ALL:(2 * MAX_LOAD_ALL)]
                else:
                    sample = index_data[0:MAX_LOAD_ALL]
                self._loc_all, self._abu_all = np.unique(sample, return_counts=True)

                logger.info("Loading the pointers into memory. Might take a while.")
                # Throwaway query whose purpose (per the log line above) is to
                # pull the index pointers into memory.
                self.kmer_index.where(
                    "A" * self.kmer_index.kmer_size + "T",
                    max_mem=max_mem,
                    use_background_model=False,
                )

                # Coordinate bounds as (xmin, xmax, ymin, ymax); the maxima are
                # padded by one.
                self.bounds = (
                    self.kmer_index.coord_lims[0],
                    self.kmer_index.coord_lims[1] + 1,
                    self.kmer_index.coord_lims[2],
                    self.kmer_index.coord_lims[3] + 1,
                )

                # Store the spatial coordinates
                self.xy = self.kmer_index.spatial_coord[:]
                self.initialized = True
            except Exception as e:
                logger.error(f"Error initializing from index: {str(e)}")
                raise

    def initialize(self, coordinates: np.ndarray, values: np.ndarray, bounds: Tuple):
        """Initialize directly from in-memory data instead of an index file."""
        with self._lock:
            self.coordinates = coordinates
            self.values = values
            self.bounds = bounds
            self.initialized = True

    def get_bounds(self):
        """Return the stored coordinate bounds.

        Raises:
            RuntimeError: if the state has not been initialized.
        """
        if not self.initialized:
            raise RuntimeError("GlobalState not initialized")
        return self.bounds

    def get_coordinates(self):
        """Return the index spatial coordinates, else the ``initialize()`` coordinates.

        Raises:
            RuntimeError: if the state has not been initialized.
        """
        if not self.initialized:
            raise RuntimeError("GlobalState not initialized")
        return self.xy if self.xy is not None else self.coordinates
class UserSession:
    """Per-user tile-rendering state.

    Holds two point layers: a background layer seeded from the global index
    sample, and the most recent query result. Each layer has a KD-tree for
    spatial lookup and an intensity range used to scale tile pixels
    consistently. Layer attributes are published together under ``self.lock``.
    """

    def __init__(self, session_id: str):
        self.session_id = session_id
        # Not written by this class; cleanup() deletes it if it exists.
        self.data_path = Path(tempfile.gettempdir()) / f"user_{session_id}.nc"
        self.lock = threading.Lock()
        self.data = None
        self.query_results = None
        self.max_values = []
        self.background_data = None
        # Intensity ranges used for consistent scaling across tiles
        self.background_intensity_range = None
        self.query_intensity_range = None
        # KD-trees for fast spatial lookup of each layer
        self.background_tree = None
        self.query_tree = None

    def initialize_background(self):
        """Seed the background layer from the sample held in ``global_state``.

        Does nothing unless the global state is initialized and has a loaded
        sample (``_loc_all``). Sample locations are mapped through
        ``global_state.xy``, and the sample counts (``_abu_all``) become the
        intensities.
        """
        # getattr: the global state may have been initialized without a sample.
        loc_all = getattr(global_state, '_loc_all', None)
        if not (global_state.initialized and loc_all is not None):
            return
        logger.info("Initializing background data for session")
        background_coords = global_state.xy[loc_all]
        background_values = global_state._abu_all
        intensity_range = self._calculate_intensity_range(background_values)
        tree = cKDTree(background_coords)
        with self.lock:
            self.background_intensity_range = intensity_range
            self.background_data = {
                'locations': background_coords,
                'intensities': background_values
            }
            self.background_tree = tree
        logger.info(f"Background data initialized with intensity range: {intensity_range}")

    def _calculate_intensity_range(self, values: np.ndarray) -> Tuple[float, float]:
        """Return a robust (low, high) range: the 1st and 99th percentiles.

        Returns ``(0, 1)`` for empty input. If both percentiles coincide,
        ``high`` is set to ``low + 1`` so the range is never zero-width.
        """
        if len(values) == 0:
            return (0, 1)
        # Percentiles instead of min/max so outliers do not dominate scaling
        min_val = np.percentile(values, 1)
        max_val = np.percentile(values, 99)
        if min_val == max_val:
            max_val = min_val + 1
        return (min_val, max_val)

    def add_query_result(self, locations: np.ndarray, intensities: np.ndarray):
        """Replace the query layer with new locations and intensities."""
        # The range and tree are built outside the lock, so concurrent tile
        # renders are not blocked. The three attributes are then published
        # together so readers never pair new results with a stale tree.
        intensity_range = self._calculate_intensity_range(intensities)
        tree = cKDTree(locations)
        with self.lock:
            self.query_intensity_range = intensity_range
            self.query_results = {
                'locations': locations,
                'intensities': intensities
            }
            self.query_tree = tree
        logger.info(f"Query results added with intensity range: {intensity_range}")

    def add_background_data(self, coordinates: np.ndarray, values: np.ndarray):
        """Replace the background layer (locations, values and KD-tree)."""
        tree = cKDTree(coordinates)
        with self.lock:
            self.background_data = {
                'locations': coordinates,
                'intensities': values
            }
            self.background_tree = tree

    def _rasterize_points(self, points: np.ndarray, values: np.ndarray, tree: cKDTree,
                          x_range: Tuple[float, float], y_range: Tuple[float, float],
                          value_range: Optional[Tuple[float, float]] = None,
                          sigma: float = 1.0) -> np.ndarray:
        """Rasterize one point layer into a TILE_SIZE x TILE_SIZE uint8 image.

        The requested box is widened to a square (side = its larger extent,
        same centre) to preserve aspect ratio. Points inside the half-open
        square [min, max) are aggregated per pixel with datashader (mean of
        ``values``) and scaled to 0..255. Scaling uses ``value_range`` when it
        is given and non-degenerate, otherwise the tile's own maximum. The
        result is then Gaussian-blurred.

        Args:
            points: x/y coordinates indexed by ``tree`` (column 0 = x, 1 = y).
            values: per-point intensities aligned with ``points``.
            tree: KD-tree built over ``points``.
            x_range: requested (min, max) along x.
            y_range: requested (min, max) along y.
            value_range: optional (low, high) for consistent scaling across tiles.
            sigma: Gaussian blur width in pixels.

        Returns:
            A (TILE_SIZE, TILE_SIZE) uint8 array, all zeros when no point
            falls inside the square.
        """
        empty = np.zeros((TILE_SIZE, TILE_SIZE), dtype=np.uint8)
        if points is None or len(points) == 0:
            return empty

        half_range = max(x_range[1] - x_range[0], y_range[1] - y_range[0]) * 0.5
        x_center = (x_range[0] + x_range[1]) * 0.5
        y_center = (y_range[0] + y_range[1]) * 0.5
        x_adjusted = (x_center - half_range, x_center + half_range)
        y_adjusted = (y_center - half_range, y_center + half_range)

        # Coarse candidate lookup: the circle circumscribing the square window.
        indices = tree.query_ball_point(np.array([x_center, y_center]), np.sqrt(2) * half_range)
        if not indices:
            return empty
        candidate_points = points[indices]
        candidate_values = values[indices]

        # Exact half-open box filter, so a point on an edge shared by two
        # neighbouring tiles is drawn in only one of them.
        inside = ((candidate_points[:, 0] >= x_adjusted[0]) &
                  (candidate_points[:, 0] < x_adjusted[1]) &
                  (candidate_points[:, 1] >= y_adjusted[0]) &
                  (candidate_points[:, 1] < y_adjusted[1]))
        if not inside.any():
            return empty
        filtered_points = candidate_points[inside]
        filtered_values = candidate_values[inside]

        canvas = ds.Canvas(plot_width=TILE_SIZE, plot_height=TILE_SIZE,
                           x_range=x_adjusted, y_range=y_adjusted)
        df = pd.DataFrame({
            'x': filtered_points[:, 0],
            'y': filtered_points[:, 1],
            'val': filtered_values
        })
        agg = canvas.points(df, 'x', 'y', agg=ds.mean('val'))
        # Empty pixels come back as NaN from the mean aggregation
        img = np.nan_to_num(agg.values, copy=False)

        if value_range is not None and value_range[1] > value_range[0]:
            img = np.clip(img, value_range[0], value_range[1])
            img = (img - value_range[0]) / (value_range[1] - value_range[0]) * 255
        elif img.max() > 0:
            img = img / img.max() * 255

        img = gaussian(img, sigma=sigma, preserve_range=True)
        return img.astype(np.uint8)

    def get_tile_data(self, x: int, y: int, zoom: int) -> np.ndarray:
        """Render the RGBA tile (x, y) at ``zoom`` for this user's layers.

        Tiles are square. Both axes are split into ``2**zoom`` steps of the
        larger data extent, starting at the index's lower coordinate limits.
        The background layer goes to the red channel and query results to the
        green channel. Alpha is 255 wherever either channel is non-zero.

        Raises:
            RuntimeError: if the global state has not been initialized.
        """
        if not global_state.initialized:
            raise RuntimeError("Global state not initialized")

        coord_lims = global_state.kmer_index.coord_lims
        max_dim = max(coord_lims[1] - coord_lims[0], coord_lims[3] - coord_lims[2])
        tile_size = max_dim / (2.0 ** zoom)
        x_bounds = (coord_lims[0] + x * tile_size, coord_lims[0] + (x + 1) * tile_size)
        y_bounds = (coord_lims[2] + y * tile_size, coord_lims[2] + (y + 1) * tile_size)
        _sigma = 1 if zoom <= 1 else 1.2

        # Consistent snapshot of both layers. If requests are served
        # concurrently, add_query_result() may replace them mid-render.
        with self.lock:
            background = self.background_data
            background_tree = self.background_tree
            background_range = self.background_intensity_range
            query = self.query_results
            query_tree = self.query_tree
            query_range = self.query_intensity_range

        img_all = np.zeros((TILE_SIZE, TILE_SIZE, 4), dtype=np.uint8)

        if background is not None:
            img_all[:, :, 0] = self._rasterize_points(
                background['locations'], background['intensities'], background_tree,
                x_bounds, y_bounds, value_range=background_range, sigma=_sigma
            )

        if query is not None and len(query.get('locations', [])) > 0:
            img_all[:, :, 1] = self._rasterize_points(
                query['locations'], query['intensities'], query_tree,
                x_bounds, y_bounds, value_range=query_range, sigma=_sigma
            )

        # Opaque only where at least one layer has signal
        img_all[:, :, 3] = (np.maximum(img_all[:, :, 0], img_all[:, :, 1]) > 0) * 255
        return img_all

    def cleanup(self):
        """Delete ``data_path`` if it exists."""
        with self.lock:
            if self.data_path.exists():
                self.data_path.unlink()
def create_app(init_state=True, _uuid=None):
    """Build and configure the malva Flask application.

    Args:
        init_state: when True and the global state is initialized, every new
            user session is seeded with the background point layer.
        _uuid: optional deployment id. When set, the app is mounted under
            ``/api/malva/view/<_uuid>/`` via ``MalvaProxyFix`` and ``ProxyFix``.

    Returns:
        The configured ``Flask`` application.
    """
    app = Flask(__name__)

    # Configure app. NOTE(review): SECRET_KEY falls back to a fixed "dev_key"
    # when the environment variable is unset.
    app.config.update(
        SECRET_KEY=os.environ.get("SECRET_KEY", "dev_key"),
        SESSION_TYPE="filesystem",
        SESSION_PERMANENT=True,
    )

    if _uuid:
        # Set application root for proper URL generation
        app.config['APPLICATION_ROOT'] = f"/api/malva/view/{_uuid}/"
        # Apply proxy fixes in correct order
        app.wsgi_app = MalvaProxyFix(app.wsgi_app, _uuid)
        app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1, x_host=1)

    Session(app)

    # Transparent PNG for tiles outside the data bounds. It is kept as
    # immutable bytes and wrapped in a fresh BytesIO per response, because
    # send_file consumes and closes the stream it is given. A shared stream
    # would therefore only work for the first empty-tile request.
    empty_tile_buffer = io.BytesIO()
    Image.fromarray(EMPTY_TILE).save(empty_tile_buffer, format='PNG')
    empty_tile_png = empty_tile_buffer.getvalue()

    # Session management
    user_sessions = {}
    session_lock = threading.Lock()

    def get_user_session():
        """Return this browser session's UserSession, creating it on first use."""
        if 'user_id' not in session:
            session['user_id'] = str(uuid.uuid4())
        user_id = session['user_id']
        with session_lock:
            if user_id not in user_sessions:
                new_session = UserSession(user_id)
                if init_state and global_state.initialized:
                    new_session.initialize_background()
                user_sessions[user_id] = new_session
            return user_sessions[user_id]

    @app.before_request
    def before_request():
        """Attach the user's UserSession to ``g`` for every non-health request."""
        if not request.path.startswith("/health"):
            try:
                g.user_session = get_user_session()
            except Exception as e:
                logger.error(f"Error in before_request: {str(e)}")
                return jsonify({"error": "Internal server error"}), 500

    @app.route("/parse_queried", methods=["POST"])
    def parse_queried():
        """Return the last query plus per-position scores scaled to [0, 1].

        Scores are computed from ``session["where_abundant"]`` rows of
        (position, value): values are averaged per position, smoothed with a
        24-wide rolling mean (NaN -> 0), then min-max normalized.
        """
        try:
            if "where_abundant" not in session:
                return jsonify({
                    "success": False,
                    "query_term": "",
                    "sequence": "",
                    "scores": []
                })

            scores = [0]
            if len(session["where_abundant"]) > 2:
                data = np.array(session["where_abundant"])
                df = (
                    pd.DataFrame({"pos": data[:, 0], "val": data[:, 1]})
                    .groupby("pos")
                    .mean()
                    .rolling(window=24)
                    .mean()
                )
                # tolist() yields native Python floats for JSON serialization
                scores = df["val"].fillna(0).tolist()

            if scores:
                # Normalize only if we have scores
                min_val = min(scores)
                max_val = max(scores)
                if max_val > min_val:
                    scores = [(x - min_val) / (max_val - min_val) for x in scores]

            response_data = {
                "success": True,
                "query_term": session.get("query_term", ""),
                "sequence": session.get("query_seq", ""),
                "scores": scores
            }
            return jsonify(response_data)
        except Exception as e:
            logger.error(f"Error in parse_queried: {str(e)}")
            return jsonify({
                "success": False,
                "error": str(e)
            }), 500

    @app.route("/tiles/<zoom>/<int:x>/<int:y>.png")
    def tile(zoom: str, x: int, y: int):
        """Generate and serve one PNG map tile."""
        try:
            zoom = int(float(zoom))
            coord_lims = global_state.kmer_index.coord_lims

            # Same square tiling as UserSession.get_tile_data: both axes are
            # divided by the larger data extent.
            max_dim = max(coord_lims[1] - coord_lims[0], coord_lims[3] - coord_lims[2])
            tile_span = max_dim / (2.0 ** zoom)
            x_tile_min = coord_lims[0] + x * tile_span
            x_tile_max = coord_lims[0] + (x + 1) * tile_span
            y_tile_min = coord_lims[2] + y * tile_span
            y_tile_max = coord_lims[2] + (y + 1) * tile_span

            # Tile completely outside the data bounds: serve the transparent tile
            if (x_tile_min > coord_lims[1] or x_tile_max < coord_lims[0] or
                    y_tile_min > coord_lims[3] or y_tile_max < coord_lims[2]):
                return send_file(
                    io.BytesIO(empty_tile_png),
                    mimetype='image/png',
                )

            img_array = g.user_session.get_tile_data(x, y, zoom)

            img_io = io.BytesIO()
            Image.fromarray(img_array).save(
                img_io,
                format='PNG',
                optimize=True,
                compress_level=1  # Faster compression
            )
            img_io.seek(0)
            return send_file(img_io, mimetype='image/png')

        except Exception as e:
            logger.error(f"Error generating tile: {str(e)}")
            return str(e), 400

    @app.route("/health")
    def health():
        """Health check endpoint"""
        return jsonify({
            "status": "healthy" if global_state.initialized else "initializing",
            "message": "Service is ready" if global_state.initialized else "Service is initializing"
        })

    def parse_gene_format(query: str) -> str:
        """Normalize a gene query to the ``gene:...`` form (``;type:cdna`` added if no ';')."""
        if query.startswith('gene:') or query.startswith('ensembl:'):
            return query
        if ';' in query and not query.startswith('gene:'):
            return f"gene:{query}"
        return f"gene:{query};type:cdna"

    def validate_window_size(sequences: Union[List[str], List[List[str]]],
                             window_size: int, kmer_size: int) -> Tuple[int, Optional[str]]:
        """
        Validate window size against sequence lengths and k-mer size.
        Handles both flat lists of sequences and lists of isoform lists.

        Args:
            sequences: List of sequences or list of isoform lists
            window_size: Requested window size
            kmer_size: Size of k-mers used in index

        Returns:
            Tuple of (validated_size, warning_message)

        Raises:
            SequenceValidationError: if the window or any sequence is shorter
                than ``kmer_size``, or no sequences are given.
        """
        if window_size < kmer_size:
            raise SequenceValidationError(
                f"Window size ({window_size}) cannot be smaller than k-mer size ({kmer_size})",
                "Please increase the window size parameter"
            )

        # Flatten isoform lists so the shortest sequence can be found
        flat_sequences = []
        for seq in sequences:
            if isinstance(seq, list):
                flat_sequences.extend(seq)
            else:
                flat_sequences.append(seq)

        if not flat_sequences:
            raise SequenceValidationError(
                "No valid sequences to process",
                "No sequences were found for the provided input"
            )

        min_seq_length = min(len(seq) for seq in flat_sequences)
        if min_seq_length < kmer_size:
            raise SequenceValidationError(
                f"Shortest sequence length ({min_seq_length}) cannot be smaller than k-mer size ({kmer_size})",
                "One or more sequences are too short for analysis"
            )

        if window_size > min_seq_length:
            new_size = min_seq_length
            return new_size, f"Window size adjusted to {new_size} to match shortest sequence/isoform"

        return window_size, None

    def process_sequence_input(query: str) -> List[str]:
        """
        Process input query into list of sequences.
        Handles raw sequences, FASTA format, and gene IDs.
        Supports multiple separators: commas, spaces, newlines.

        FASTA records are returned as "header\\nline\\n..." strings; raw
        nucleotide strings are upper-cased; anything else becomes a
        ``gene:...;type:cdna`` query.
        """
        query = query.strip()
        if not query:
            return []

        # FASTA input
        if query.startswith('>'):
            sequences = []
            current_seq = []
            for line in query.split('\n'):
                line = line.strip()
                if not line:
                    continue
                if line.startswith('>'):
                    if current_seq:
                        sequences.append('\n'.join(current_seq))
                    current_seq = [line]
                else:
                    current_seq.append(line)
            if current_seq:
                sequences.append('\n'.join(current_seq))
            return sequences

        # Not FASTA - handle as raw sequence(s) or gene IDs
        parts = []
        for line in query.split('\n'):
            line_parts = re.split(r'[,\s]+', line.strip())
            parts.extend(part for part in line_parts if part)

        sequences = []
        for part in parts:
            cleaned = part.strip()
            if not cleaned:
                continue
            if cleaned.startswith('gene:') or cleaned.startswith('ensembl:'):
                if ';type:' not in cleaned:
                    cleaned = f"{cleaned};type:cdna"
                sequences.append(cleaned)
            elif re.match(r'^[ACGTNUacgtnu]+$', cleaned):
                sequences.append(cleaned.upper())
            else:
                sequences.append(f"gene:{cleaned};type:cdna")

        return sequences

    @app.route("/search", methods=["POST"])
    def search():
        """Query the k-mer index for the sequences/genes given in ``selectsequence``.

        Parameters are read from the query string. On success the user's
        query layer is replaced, and query metadata is stored in the Flask
        session (read back by ``/parse_queried``). Validation problems are
        returned as 400 with a ``help`` hint.
        """
        try:
            query = request.args.get("selectsequence", "").strip()
            if not query:
                raise SequenceValidationError("No sequence provided", "Please enter a sequence or gene name")

            logger.info("Starting sequence processing...")

            # Get parameters (countmaxkmer/countminkmer are log10 exponents)
            sliding_size = int(request.args.get("sliding_size", 128))
            pct_threshold = float(request.args.get("pct_threshold", 0.65))
            low_complexity_filter = request.args.get("low_complexity_filter", "").lower() in ["true", "1"]
            countmaxkmer = int(float(request.args.get("countmaxkmer", 5)))
            countminkmer = int(float(request.args.get("countminkmer", 1)))

            # Process input
            processed_sequences = process_sequence_input(query)
            logger.info(f"Initial processing returned {len(processed_sequences)} sequences")

            if not processed_sequences:
                raise SequenceValidationError(
                    "No valid sequences found in input",
                    "Check the format of your input. Examples are shown below."
                )

            # Resolve gene IDs to sequences and validate raw/FASTA input
            final_sequences = []
            warnings = []
            errors = []

            for seq in processed_sequences:
                try:
                    if seq.startswith('>'):
                        # FASTA record: header line followed by sequence lines
                        lines = seq.split('\n')
                        sequence = ''.join(lines[1:]).upper()
                        if not re.match(r'^[ACGTNU]+$', sequence):
                            errors.append(f"Invalid FASTA sequence: {lines[0]}")
                            continue
                        final_sequences.append(sequence)
                    elif seq.startswith('gene:') or seq.startswith('ensembl:') or ';' in seq:
                        # Gene ID lookup
                        result = handle_sequence(seq)
                        if result is None:
                            errors.append(f"Gene not found: {seq}")
                            continue
                        if isinstance(result, list):
                            final_sequences.extend(result)
                        else:
                            final_sequences.append(result)
                    else:
                        # Must be raw sequence or simple gene
                        if re.match(r'^[ACGTNUacgtnu]+$', seq):
                            final_sequences.append(seq.upper())
                        else:
                            result = handle_sequence(f"gene:{seq}")
                            if result is None:
                                errors.append(f"Gene not found: {seq}")
                                continue
                            if isinstance(result, list):
                                final_sequences.extend(result)
                            else:
                                final_sequences.append(result)
                except Exception as e:
                    errors.append(f"Error processing {seq[:20]}: {str(e)}")
                    continue

            logger.info(f"Final processing yielded {len(final_sequences)} sequences")

            if not final_sequences:
                error_msg = "No valid sequences could be processed."
                if errors:
                    error_msg += f" Errors: {'; '.join(errors)}"
                raise SequenceValidationError(error_msg, "Please check your input and try again")

            # Validate window size
            kmer_size = global_state.kmer_index.kmer_size
            min_seq_length = min(len(seq) for seq in final_sequences)

            if min_seq_length < kmer_size:
                raise SequenceValidationError(
                    f"Sequence length ({min_seq_length}) cannot be smaller than k-mer size ({kmer_size})",
                    "One or more sequences are too short for analysis"
                )

            if sliding_size > min_seq_length:
                sliding_size = min_seq_length
                warnings.append(f"Window size adjusted to {sliding_size} to match shortest sequence")

            # Query the index
            try:
                locs, ints, where_abundant = interactive_query_standard(
                    final_sequences,
                    sliding_size=sliding_size,
                    pct_threshold=pct_threshold,
                    low_complexity_filter=low_complexity_filter,
                    countmaxkmer=10**countmaxkmer,
                    countminkmer=10**countminkmer
                )
            except Exception as e:
                logger.error(f"Query error: {str(e)}")
                raise SequenceValidationError(
                    "Error searching sequences",
                    "Try adjusting the search parameters"
                )

            if len(locs) == 0:
                raise SequenceValidationError(
                    "No results found for any sequences",
                    "Try adjusting the search parameters"
                )

            # Store results
            g.user_session.add_query_result(
                global_state.xy[locs],
                ints
            )

            # Store first sequence for display
            session["query_seq"] = final_sequences[0]
            session["query_term"] = query
            session["where_abundant"] = where_abundant.tolist() if isinstance(where_abundant, np.ndarray) else where_abundant

            response = {"success": True}
            if warnings:
                response["warnings"] = warnings
            if errors:
                response["errors"] = errors

            logger.info("Search completed successfully")
            return jsonify(response)

        except SequenceValidationError as e:
            return jsonify({
                "error": str(e),
                "help": e.help_text
            }), 400
        except Exception as e:
            logger.error(f"Unexpected error in search: {str(e)}")
            return jsonify({
                "error": "An unexpected error occurred",
                "help": "Please try again or contact support if the problem persists"
            }), 500

    @app.route("/save_report", methods=["POST"])
    def save_report():
        """Build an HTML report from the session and return it as a zip download."""
        try:
            generator = HTMLReportGenerator(session)
            zip_data = generator.create_report_zip()

            response = make_response(zip_data)
            response.headers['Content-Type'] = 'application/zip'
            response.headers['Content-Disposition'] = 'attachment; filename=malva_report.zip'
            return response

        except Exception as e:
            logger.error(f"Error generating report: {str(e)}")
            return jsonify({
                "error": "Failed to generate report",
                "help": "Please try again or contact support if the problem persists"
            }), 500

    @app.route("/", methods=["GET"])
    def index():
        """Render the viewer page with the (padded) coordinate maxima."""
        try:
            _xmax = global_state.kmer_index.coord_lims[1] + 1
            _ymax = global_state.kmer_index.coord_lims[3] + 1
            return render_template("index.html", xmax=_xmax, ymax=_ymax)
        except Exception as e:
            return jsonify({"error": str(e)}), 500

    return app
def _run_serve(args):
    """Start the malva server described by parsed CLI ``args``.

    Uses ``args.max_len``, ``args.index_in``, ``args.max_mem``, ``args.uuid``,
    ``args.address`` and ``args.port``. The module-level ``global_state`` is
    loaded from the index before the app is built, because the app's handlers
    read it. Flask's built-in server is then run. Any error is logged and
    re-raised.
    """
    global MAX_SEQUENCE_LENGTH, global_state
    MAX_SEQUENCE_LENGTH = args.max_len

    cors_settings = {
        "origins": "*",
        "allow_headers": ["Content-Type", "Authorization"],
        "supports_credentials": True,
    }

    try:
        global_state = GlobalState()
        global_state.initialize_from_index(args.index_in, max_mem=args.max_mem)
        logger.info("Initialized global state from index")

        # Built exactly once, after the global state is ready
        app = create_app(init_state=True, _uuid=args.uuid)
        # NOTE(review): origins="*" combined with supports_credentials=True may
        # let any site make credentialed requests -- confirm this is intended.
        CORS(app, resources={r"/*": cors_settings})

        app.run(debug=False, host=args.address, port=args.port, use_reloader=False)
    except Exception as e:
        logger.error(f"Error running serve: {str(e)}")
        raise


app = None

if __name__ == "__main__":
    from malva.cli import get_serve_parser

    _run_serve(get_serve_parser().parse_args())