Boost C++ Libraries

"...one of the most highly regarded and expertly designed C++ library projects in the world." — Herb Sutter and Andrei Alexandrescu, C++ Coding Standards

This is the documentation for an old version of Boost. Click here to view this page for the latest version.

boost/numeric/ublas/tensor/extents.hpp

//
//  Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
//
//  Distributed under the Boost Software License, Version 1.0. (See
//  accompanying file LICENSE_1_0.txt or copy at
//  http://www.boost.org/LICENSE_1_0.txt)
//
//  The authors gratefully acknowledge the support of
//  Fraunhofer IOSB, Ettlingen, Germany
//


#ifndef BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP
#define BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP

#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

#include <cassert>

namespace boost {
namespace numeric {
namespace ublas {


/** @brief Template class for storing tensor extents with runtime variable size.
 *
 * Proxy template class of std::vector<int_type>.
 *
 * Extents constructed from a container, an initializer list or an iterator
 * range are checked with valid(), i.e. they must hold at least two elements
 * and no element may be zero. Only a default constructed (or cleared) object
 * is empty.
 *
 * @tparam int_type unsigned integer type used for the single extents
 */
template<class int_type>
class basic_extents
{
	static_assert( std::numeric_limits<int_type>::is_integer, "Static error in basic_extents: type must be of type integer.");
	static_assert(!std::numeric_limits<int_type>::is_signed,  "Static error in basic_extents: type must be of type unsigned integer.");

public:
	using base_type = std::vector<int_type>;
	using value_type = typename base_type::value_type;
	using const_reference = typename base_type::const_reference;
	using reference = typename base_type::reference;
	using size_type = typename base_type::size_type;
	using const_pointer = typename base_type::const_pointer;
	using const_iterator = typename base_type::const_iterator;


	/** @brief Default constructs basic_extents
	 *
	 * The resulting object is empty; no validity check is performed.
	 *
	 * @code auto ex = basic_extents<unsigned>{}; @endcode
	 */
	constexpr explicit basic_extents()
	  : _base{}
	{
	}

	/** @brief Copy constructs basic_extents from a one-dimensional container
	 *
	 * @code auto ex = basic_extents<unsigned>(  std::vector<unsigned>(3u,3u) ); @endcode
	 *
	 * @note checks if size > 1 and all elements > 0
	 *
	 * @param b one-dimensional std::vector<int_type> container
	 *
	 * @throws std::length_error if the copied extents are not valid()
	 */
	explicit basic_extents(base_type const& b)
	  : _base(b)
	{
		this->throw_if_invalid();
	}

	/** @brief Move constructs basic_extents from a one-dimensional container
	 *
	 * @code auto ex = basic_extents<unsigned>(  std::vector<unsigned>(3u,3u) ); @endcode
	 *
	 * @note checks if size > 1 and all elements > 0
	 *
	 * @param b one-dimensional container of type std::vector<int_type>
	 *
	 * @throws std::length_error if the moved extents are not valid()
	 */
	explicit basic_extents(base_type && b)
	  : _base(std::move(b))
	{
		this->throw_if_invalid();
	}

	/** @brief Constructs basic_extents from an initializer list
	 *
	 * @code auto ex = basic_extents<unsigned>{3,2,4}; @endcode
	 *
	 * @note checks if size > 1 and all elements > 0
	 *
	 * @param l one-dimensional list of type std::initializer_list<int_type>
	 *
	 * @throws std::length_error if the listed extents are not valid()
	 */
	basic_extents(std::initializer_list<value_type> l)
	  : basic_extents( base_type(l) )
	{
	}

	/** @brief Constructs basic_extents from a range specified by two iterators
	 *
	 * @code auto ex = basic_extents<unsigned>(a.begin(), a.end()); @endcode
	 *
	 * @note checks if size > 1 and all elements > 0
	 *
	 * @param first iterator pointing to the first element
	 * @param last iterator pointing to the next position after the last element
	 *
	 * @throws std::length_error if the extents in [first,last) are not valid()
	 */
	basic_extents(const_iterator first, const_iterator last)
	  : basic_extents ( base_type( first,last ) )
	{
	}

	/** @brief Copy constructs basic_extents */
	basic_extents(basic_extents const& l ) = default;

	/** @brief Move constructs basic_extents */
	basic_extents(basic_extents && l ) noexcept = default;

	~basic_extents() = default;

	/** @brief Assigns by swapping with the by-value parameter
	 *
	 * The copy or move of the right-hand side happens at the call site when
	 * the parameter is initialised; the body itself only swaps.
	 */
	basic_extents& operator=(basic_extents other) noexcept
	{
		swap (*this, other);
		return *this;
	}

	/** @brief Exchanges the stored extents of lhs and rhs */
	friend void swap(basic_extents& lhs, basic_extents& rhs) noexcept {
		lhs._base.swap(rhs._base);
	}



	/** @brief Returns true if this has a scalar shape
	 *
	 * @returns true if this is not empty and every extent equals 1, e.g. (1,1,[1,...,1])
	*/
	bool is_scalar() const
	{
		return
		    !_base.empty() &&
		    std::all_of(_base.begin(), _base.end(),
		                [](const_reference a){ return a == 1;});
	}

	/** @brief Returns true if this has a vector shape
	 *
	 * @returns true if (1,n,[1,...,1]) or (n,1,[1,...,1]) with n > 1.
	 *          A single extent (n) with n > 1 is reported as a vector as well.
	*/
	bool is_vector() const
	{
		if(_base.empty()){
			return false;
		}

		if(_base.size() == 1){
			return _base.front() > 1;
		}

		auto greater_one = [](const_reference a){ return a >  1;};
		auto equal_one   = [](const_reference a){ return a == 1;};

		// exactly one of the first two extents is > 1, all remaining ones are 1
		return
		    std::any_of(_base.begin(),   _base.begin()+2, greater_one) &&
		    std::any_of(_base.begin(),   _base.begin()+2, equal_one  ) &&
		    std::all_of(_base.begin()+2, _base.end(),     equal_one);
	}

	/** @brief Returns true if this has a matrix shape
	 *
	 * @returns true if (m,n,[1,...,1]) with m > 1 and n > 1
	*/
	bool is_matrix() const
	{
		if(_base.size() < 2){
			return false;
		}

		auto greater_one = [](const_reference a){ return a >  1;};
		auto equal_one   = [](const_reference a){ return a == 1;};

		return
		    std::all_of(_base.begin(),   _base.begin()+2, greater_one) &&
		    std::all_of(_base.begin()+2, _base.end(),     equal_one  );
	}

	/** @brief Returns true if this has a tensor shape
	 *
	 * @returns true if there are at least three extents and at least one
	 *          extent from the third position onwards is greater than 1
	*/
	bool is_tensor() const
	{
		if(_base.size() < 3){
			return false;
		}

		auto greater_one = [](const_reference a){ return a > 1;};

		return std::any_of(_base.begin()+2, _base.end(), greater_one);
	}

	/** @brief Returns a pointer to the contiguously stored extents */
	const_pointer data() const
	{
		return this->_base.data();
	}

	/** @brief Returns the extent at position p without bounds checking */
	const_reference operator[] (size_type p) const
	{
		return this->_base[p];
	}

	/** @brief Returns the extent at position p, bounds checked by std::vector::at */
	const_reference at (size_type p) const
	{
		return this->_base.at(p);
	}

	/** @brief Returns a mutable reference to the extent at position p without bounds checking
	 *
	 * @note values written through the reference are not re-validated
	 */
	reference operator[] (size_type p)
	{
		return this->_base[p];
	}

	/** @brief Returns a mutable reference to the extent at position p, bounds checked by std::vector::at
	 *
	 * @note values written through the reference are not re-validated
	 */
	reference at (size_type p)
	{
		return this->_base.at(p);
	}


	/** @brief Returns true if no extents are stored */
	bool empty() const
	{
		return this->_base.empty();
	}

	/** @brief Returns the number of stored extents */
	size_type size() const
	{
		return this->_base.size();
	}

	/** @brief Returns true if size > 1 and all elements > 0 */
	bool valid() const
	{
		return
		    this->size() > 1 &&
		    std::none_of(_base.begin(), _base.end(),
		                 [](const_reference a){ return a == value_type(0); });
	}

	/** @brief Returns the number of elements a tensor holds with this
	 *
	 * @returns 0 if this is empty, otherwise the product of all extents
	 *
	 * @note overflow of size_type is not detected
	 */
	size_type product() const
	{
		if(_base.empty()){
			return 0;
		}

		// The accumulator has the type of the initial value. Seeding with
		// size_type (instead of unsigned long) keeps the full width on
		// platforms where unsigned long is narrower than std::size_t.
		return std::accumulate(_base.begin(), _base.end(), size_type{1}, std::multiplies<>());

	}


	/** @brief Eliminates singleton dimensions when size > 2
	 *
	 * The result always keeps at least two extents: if fewer than two
	 * non-singleton extents remain, trailing ones are appended.
	 *
	 * squeeze {  1,1} -> {  1,1}
	 * squeeze {  2,1} -> {  2,1}
	 * squeeze {  1,2} -> {  1,2}
	 *
	 * squeeze {1,2,3} -> {  2,3}
	 * squeeze {2,1,3} -> {  2,3}
	 * squeeze {1,3,1} -> {  3,1}
	 * squeeze {1,1,1} -> {  1,1}
	 *
	*/
	basic_extents squeeze() const
	{
		if(this->size() <= 2){
			return *this;
		}

		auto new_extent = basic_extents{};
		new_extent._base.reserve(this->size());
		std::remove_copy(this->_base.begin(), this->_base.end(),
		                 std::back_inserter(new_extent._base), value_type{1});

		// Removing every singleton may leave fewer than two extents, e.g.
		// {1,3,1} -> {3}. Pad with ones so that the result stays valid().
		if(new_extent._base.size() < 2){
			new_extent._base.resize(2, value_type{1});
		}

		return new_extent;

	}

	/** @brief Removes all stored extents; afterwards empty() is true */
	void clear()
	{
		this->_base.clear();
	}

	/** @brief Returns true if both objects store the same sequence of extents */
	bool operator == (basic_extents const& b) const
	{
		return _base == b._base;
	}

	/** @brief Returns true if the stored sequences of extents differ */
	bool operator != (basic_extents const& b) const
	{
		return !( *this == b );
	}

	/** @brief Returns a const iterator to the first extent */
	const_iterator
	begin() const
	{
		return _base.begin();
	}

	/** @brief Returns a const iterator past the last extent */
	const_iterator
	end() const
	{
		return _base.end();
	}

	/** @brief Returns read-only access to the underlying std::vector */
	base_type const& base() const { return _base; }

private:

	/** @brief Throws std::length_error if the stored extents are not valid() */
	void throw_if_invalid() const
	{
		if (!this->valid()){
			throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not valid: it must have at least two elements and all elements must be greater than zero.");
		}
	}

	base_type _base;

};

/** @brief Alias for basic_extents with extents of type std::size_t. */
using shape = basic_extents<std::size_t>;

} // namespace ublas
} // namespace numeric
} // namespace boost

#endif