# Source code for talon.cli.commands.filter

#!/usr/bin/env python
import logging
import pickle
from typing import List, Tuple

import nibabel as nib
import numpy as np
import scipy.sparse as sp

import talon
import talon.cli

# Description of the `filter` sub-command, passed to argparse as the parser
# description in `add_parser`.
DESCRIPTION = (
    "Use TALON to filter a tractogram "
    "with the Volume Fraction forward model."
)


def add_parser(subparsers):
    """Register the ``filter`` sub-command and return its argument parser.

    Args:
        subparsers: object exposing ``add_parser(name, description=..., help=...)``
            (presumably the action returned by
            ``argparse.ArgumentParser.add_subparsers()``).

    Returns:
        The newly created ``filter`` parser. Its ``func`` default is set to
        :func:`run`, so the dispatcher can call it with the parsed options.
    """
    parser = subparsers.add_parser(
        'filter',
        description=DESCRIPTION,
        help='Filter a tractogram using TALON.')

    # Positional arguments.
    parser.add_argument(
        'in_tracks',
        type=str,
        help='Input tractogram file in RAS+ and mm space. '
             'The streamline coordinate (0,0,0) refers to the center of the '
             'voxel. '
             'Must be in NiBabel-readable format (.trk or .tck).'
    )

    parser.add_argument(
        'in_data',
        type=str,
        help='Input data to be fitted. Serves also as reference space for '
             'tractogram. '
             'Must be in NiBabel-readable format (.nii or .nii.gz).'
    )

    parser.add_argument(
        'out_weights',
        type=str,
        help='Output text file containing the streamline weights.'
    )

    parser.add_argument(
        '--operator-type',
        type=str,
        choices=['reference', 'fast', 'opencl'],
        default='fast',
        help='Type of operator to use. Default: `fast`.'
    )
    talon.cli.utils.add_ndir_to_input(parser)

    # Options forwarded to the regularization term and to the solver.
    solver_options = parser.add_argument_group('Solver options')
    solver_options.add_argument(
        '--allow-negative-x',
        action='store_true',
        help='Disables the non-negativity constraint.'
    )

    solver_options.add_argument(
        '--sigma',
        type=float,
        default=0.0,
        metavar='value',
        help='Sets the regularization scale parameter as in (Frigo, 2021). '
             'The final value of lambda is `sigma*max(||At*data||/gwei)`, '
             'where sigma is the passed parameter, `||At*data||` is the 2-norm '
             'of the product between the transposed linear operator and the '
             'data, and `gwei` is the vector of the weights associated with '
             'each group of streamlines. Default: %(default)s.'
    )

    solver_options.add_argument(
        '--streamline-assignment',
        type=str,
        metavar='file',
        help='Activates the group sparsity regularization by specifying the '
             'node assignments of each streamline to some parcellation. '
             'Typically, this file is produced by the MRtrix3 command '
             '`tck2connectome` with the option `-out_assignment`. '
             'The file is expected to be in text format with one row per '
             'streamline. E.g., if the first row is [5, 14], the first '
             'streamline will be bundled together with all the streamlines '
             'whose rows are [5, 14] or [14, 5].'
    )

    solver_options.add_argument(
        '--connectome',
        type=str,
        metavar='file',
        help='Activates the FIT regularization by specifying the connectivity '
             'matrix. '
             'Each streamline bundle is associated with the entry in the '
             'connectivity matrix corresponding to the region labels that it '
             'connects. '
             'E.g., the bundle connecting regions 5 and 14 is associated with '
             'the entry [5, 14] of the connectivity matrix. '
             'Note that the first row and column correspond to the zero '
             'label. '
             'Must be used together with `--streamline-assignment`.'
    )

    solver_options.add_argument(
        '--objective-relative-tolerance',
        type=float,
        default=1e-6,
        metavar='value',
        help='Sets relative tolerance on cost function. Default: %(default)s.'
    )

    solver_options.add_argument(
        '--x-absolute-tolerance',
        type=float,
        default=1e-6,
        metavar='value',
        help='Sets absolute tolerance on variable. Default: %(default)s.'
    )

    solver_options.add_argument(
        '--maxiter',
        type=int,
        default=1000,
        metavar='count',
        help='Sets maximum number of iterations. Default: %(default)s.'
    )

    # Input operator (skips voxelization when given).
    parser.add_argument(
        '--precomputed-indices-weights',
        type=str,
        nargs=2,
        metavar=('file_idx', 'file_wei'),
        help='Uses the indices and weights passed as input to build the '
             'linear operator. '
             'E.g. `--precomputed-indices-weights <indices>.npz '
             '<weights>.npz`. '
             'The two matrices must be defined on the same number of '
             'directions as the one used in the current call of this script.'
    )

    # Output operator: at most one of the two formats can be requested.
    operator_format = parser.add_mutually_exclusive_group(required=False)
    operator_format.add_argument(
        '--save-generators-indices-weights',
        type=str,
        nargs=3,
        metavar=('file_gen', 'file_idx', 'file_wei'),
        help='Saves the linear operator as three separate files '
             '`<generators>.npy <indices>.npz <weights>.npz`. '
             'All types of operator can be saved in this format.'
    )
    operator_format.add_argument(
        '--save-operator-pickle',
        type=str,
        metavar='file',
        help='Saves the linear operator with pickle. Only available when '
             '--operator-type is set to `reference` or `fast`.'
    )

    talon.cli.utils.add_verbosity_and_force_to_parser(parser)

    parser.set_defaults(func=run)
    return parser


def load_operator_and_data(
        indices: sp.coo_matrix, weights: sp.coo_matrix, fdata: str,
        operator_type: str, ndir: int,
) -> Tuple[talon.core.LinearOperator, np.ndarray]:
    """Load the data volume and build the TALON linear operator on it.

    The data are first masked using a temporary ``'reference'`` operator
    built from unit generators. Then the operator of the requested type is
    built on the same ``indices`` and ``weights``.

    Args:
        indices: indices matrix (as produced by voxelization or loaded from
            ``.npz``).
        weights: weights matrix; its number of rows must equal the number of
            samples in the data image.
        fdata: path to a NiBabel-readable image. All of its dimensions are
            flattened into one vector.
        operator_type: one of ``'reference'``, ``'fast'`` or ``'opencl'``.
        ndir: number of directions, i.e. number of rows of the generators.

    Returns:
        ``(op, data)``: the linear operator and the masked data vector. For
        ``'opencl'`` each data sample is followed by three zeros
        (``[d0, 0, 0, 0, d1, 0, 0, 0, ...]``), matching generators with four
        columns of which only the first is non-zero.

    Raises:
        ValueError: if the data size differs from ``weights.shape[0]``.
    """
    data = nib.load(fdata).get_fdata().ravel()
    logging.info('Loaded data')

    # Check compatibility before building any operator on the data.
    if data.size != weights.shape[0]:
        raise ValueError('Data volume and weights matrix are not compatible.')

    # Mask the data with a throw-away reference operator (unit generators).
    unit_generators = np.ones((ndir, 1), dtype=np.float64)
    mask_op = talon.operator(unit_generators, indices, weights,
                             operator_type='reference')
    data = talon.utils.mask_data(data, mask_op)
    logging.info('Preprocessed data')
    # Release the temporary operator before building the final one.
    del mask_op
    del unit_generators

    # Define the operator of the requested type.
    if operator_type == 'opencl':
        # NOTE(review): the 4-wide layout presumably matches what the OpenCL
        # operator expects; only the first generator column is used.
        generators = np.zeros((ndir, 4), dtype=talon.core.DATATYPE)
        generators[:, 0] = 1.0
        # Interleave every sample with three zeros to match that layout.
        data = np.kron(data[:, None], np.array([[1., 0, 0, 0]]).T).squeeze()
    else:
        generators = np.ones((ndir, 1), dtype=talon.core.DATATYPE)

    logging.info(f'Loading operator of type: {operator_type}')
    op = talon.operator(generators, indices, weights,
                        operator_type=operator_type)
    logging.info('Loaded operator')

    return op, data


def compute_indices_and_generators(in_tracks: str, in_data: str, ndir: int):
    """Voxelize a tractogram on the voxel grid of a reference image.

    The streamlines are mapped from world (RAS+ mm) coordinates to voxel
    coordinates using the inverse of the reference image affine. They are
    then passed to ``talon.voxelize`` together with ``ndir`` directions from
    ``talon.utils.directions``.

    Args:
        in_tracks: path to a NiBabel-readable tractogram (.trk or .tck).
        in_data: path to the NiBabel-readable reference image.
        ndir: number of directions used for the voxelization.

    Returns:
        The result of ``talon.voxelize``. ``run`` in this module unpacks it as
        the ``(indices, weights)`` pair. Despite the function name, no
        generators are returned.
    """
    logging.info('Voxelizing tractogram')
    reference = nib.load(in_data)
    # The image affine maps voxel -> world, so its inverse maps world -> voxel.
    world_to_voxel = np.linalg.inv(reference.affine)
    grid_shape = reference.shape
    directions = talon.utils.directions(number_of_points=ndir)
    tractogram_file = nib.streamlines.load(in_tracks)
    logging.info('Changing streamlines coordinates to voxel space')
    tractogram_file.tractogram.apply_affine(world_to_voxel)
    return talon.voxelize(tractogram_file.streamlines, directions, grid_shape)


def _solve_in_chunks(op, data, regterm, maxiter: int,
                     objective_relative_tolerance: float,
                     x_absolute_tolerance: float,
                     chunk_size: int = 25) -> np.ndarray:
    """Run ``talon.solve`` in warm-started chunks and return the estimate.

    Each call performs at most ``chunk_size`` iterations and starts from the
    previous estimate, so that the cost is logged periodically. Iteration
    stops when any of these holds:

    - the solver reports success;
    - ``maxiter`` iterations have been performed in total;
    - a call performs zero iterations without success. Without this check
      the loop could repeat forever.

    Returns:
        The last estimate. If no iteration was run (``maxiter < 1``), this is
        an all-zero vector of length ``op.shape[1]``.
    """
    x = np.zeros(op.shape[1], dtype=talon.core.DATATYPE)
    total_nit = 0
    converged = False
    while not converged and total_nit < maxiter:
        max_nit = min(maxiter - total_nit, chunk_size)
        sol = talon.solve(op, data, reg_term=regterm,
                          cost_reltol=objective_relative_tolerance,
                          x_abstol=x_absolute_tolerance,
                          max_nit=max_nit, x0=x, verbose='NONE')
        total_nit += sol.nit
        logging.info(f'Iteration: {total_nit} ~ '
                     f'1/2||Ax-y||2: {sol.fun[0]} ~ '
                     f'Omega: {sol.fun[1]} ~ '
                     f'Total: {np.sum(sol.fun)}')
        x = sol.x
        converged = bool(sol.success)
        if sol.nit == 0 and not converged:
            # No progress was made: repeating the same call would never end.
            break
    if not converged:
        logging.warning(f'Solver did not converge after {total_nit} '
                        f'iterations (maxiter={maxiter}).')
    return x


def _save_operator(op, operator_type: str,
                   save_generators_indices_weights: List[str],
                   save_operator_pickle: str) -> None:
    """Save the linear operator in the requested format, if any.

    ``save_generators_indices_weights`` takes precedence over
    ``save_operator_pickle``. Pickling is only allowed for the
    ``'reference'`` and ``'fast'`` operator types.

    Raises:
        ValueError: if pickling is requested for another operator type.
    """
    if save_generators_indices_weights is not None:
        logging.info(
            'Saving linear operator in {} {} {}.'.format(
                *save_generators_indices_weights))
        np.save(save_generators_indices_weights[0], op.generators)
        sp.save_npz(save_generators_indices_weights[1],
                    op._indices_of_generators)
        sp.save_npz(save_generators_indices_weights[2], op._weights)
    elif save_operator_pickle:
        if operator_type not in ('reference', 'fast'):
            raise ValueError('Operator type does not allow to save in pickle '
                             'format.')
        logging.info(f'Saving linear operator in {save_operator_pickle}.')
        with open(save_operator_pickle, 'wb') as f:
            pickle.dump(op, f)


def run(in_tracks: str, in_data: str, out_weights: str, force: bool,
        ndir: int, precomputed_indices_weights: List[str],
        save_generators_indices_weights: List[str],
        save_operator_pickle: str, operator_type: str,
        streamline_assignment: str, connectome: str, sigma: float,
        allow_negative_x: bool, maxiter: int,
        objective_relative_tolerance: float, x_absolute_tolerance: float,
        **kwargs):
    """Filter a tractogram with TALON and save the streamline weights.

    Args:
        in_tracks: str
            Input tractogram file in RAS+ and mm space. The streamline
            coordinate (0,0,0) refers to the center of the voxel. Must be in
            NiBabel-readable format (.trk or .tck).
        in_data: str
            Input data to be fitted. Serves also as reference space for the
            tractogram. Must be in NiBabel-readable format (.nii or .nii.gz).
        out_weights: str
            Output text file containing the streamline weights.
        force: bool
            True if the output files can be overwritten, False otherwise.
        ndir: int
            Number of directions for the voxelization.
        precomputed_indices_weights: List
            Uses the indices and weights passed as input to build the linear
            operator. The two matrices must be defined on the same number of
            directions (ndir) as the one used in this call.
        save_generators_indices_weights: List
            Saves the linear operator as three separate files.
        save_operator_pickle: str
            Saves the linear operator with pickle. Only available when
            `operator_type` is 'fast' or 'reference'.
        operator_type: str
            Type of operator to use. Choices: 'reference', 'fast', 'opencl'.
        streamline_assignment: str
            Path to the file whose rows contain the assignment of each
            streamline. E.g., if the n-th row is '5 17', the n-th streamline
            is assigned to regions 5 and 17. The region labels must be integer
            values separated by a blank space. Lines starting with # are
            skipped.
        connectome: str
            Path to the connectivity matrix to be employed, in txt format.
            The first row and column correspond to the zero label. Requires
            `streamline_assignment`.
        sigma: float
            Sets the regularization scale parameter as in (Frigo, 2021).
            The final value of lambda is `sigma*max(||At*data||/gwei)`, where
            sigma is the passed parameter, `||At*data||` is the 2-norm of the
            product between the transposed linear operator and the data, and
            `gwei` is the vector of the weights associated with each group of
            streamlines.
        allow_negative_x: bool
            Disables the non-negativity constraint.
        maxiter: int
            Maximum total number of solver iterations.
        objective_relative_tolerance: float
            Sets relative tolerance on cost function.
        x_absolute_tolerance: float
            Sets absolute tolerance on variable.
        **kwargs:
            Ignored. Presumably absorbs the remaining parsed CLI entries
            (e.g. the ``func`` default set by ``add_parser``); confirm.

    Raises:
        ValueError: if `connectome` is given without `streamline_assignment`,
            or if pickling is requested for the 'opencl' operator. Both are
            checked before any computation. ``load_operator_and_data`` also
            raises ValueError when data and weights are incompatible.
    """
    # Reject impossible option combinations before any expensive work.
    if connectome is not None and streamline_assignment is None:
        raise ValueError('The connectome can only be used together with the '
                         'streamline assignment.')
    if (save_generators_indices_weights is None and save_operator_pickle
            and operator_type not in ('reference', 'fast')):
        raise ValueError('Operator type does not allow to save in pickle '
                         'format.')

    files_to_check = [out_weights]
    if save_generators_indices_weights is not None:
        files_to_check.extend(save_generators_indices_weights)
    if save_operator_pickle is not None:
        files_to_check.append(save_operator_pickle)
    for path in files_to_check:
        talon.cli.utils.check_can_write_file(path, force=force)

    if precomputed_indices_weights is not None:
        indices_matrix = sp.load_npz(precomputed_indices_weights[0])
        weights_matrix = sp.load_npz(precomputed_indices_weights[1])
    else:
        indices_matrix, weights_matrix = compute_indices_and_generators(
            in_tracks, in_data, ndir=ndir)

    op, data = load_operator_and_data(indices_matrix, weights_matrix, in_data,
                                      operator_type, ndir)

    # Group-sparsity regularization, only active with a streamline assignment.
    llambda = 0.0
    groups = None
    group_weights = None
    if streamline_assignment is not None:
        mapping = talon.cli.utils.assignment_to_mapping(streamline_assignment)
        c = np.loadtxt(connectome) if connectome is not None else None
        groups, group_weights = talon.cli.utils.mapping_to_groups_weights(
            mapping, c)
        logging.info(f'Number of streamline bundles: {len(groups)}')
        llambda = np.linalg.norm(op.T @ data, 2)
        llambda *= np.max(1 / group_weights)
        llambda *= sigma
    logging.info(f'Regularization parameter: {llambda}')

    regterm = talon.regularization(
        non_negativity=not allow_negative_x,
        regularization_parameter=llambda,
        groups=groups,
        weights=group_weights
    )
    logging.info(f'Non-negativity: {regterm.non_negativity}')

    x = _solve_in_chunks(op, data, regterm, maxiter,
                         objective_relative_tolerance, x_absolute_tolerance)

    logging.info(f'Saving result in {out_weights}')
    np.savetxt(out_weights, x)

    _save_operator(op, operator_type, save_generators_indices_weights,
                   save_operator_pickle)