#if defined(__MACH__) | |
#include <mach/clock.h> | |
#include <mach/mach.h> | |
#endif | |
#if !defined(__WIN32__) | |
#include <sys/stat.h> | |
#include <sys/types.h> | |
#if !defined(__ANDROID__) && (!defined(MSHADOW_USE_SSE) || MSHADOW_USE_SSE == 1) | |
#include <emmintrin.h> | |
#endif | |
#endif | |
#include <algorithm> | |
#include <array> | |
#include <assert.h> | |
#include <atomic> | |
#include <cblas.h> | |
#include <cctype> | |
#include <cfloat> | |
#include <chrono> | |
#include <climits> | |
#include <cmath> | |
#include <condition_variable> | |
#include <cstddef> | |
#include <cstdint> | |
#include <cstdio> | |
#include <cstdlib> | |
#include <cstring> | |
#include <ctime> | |
#include <deque> | |
#include <dirent.h> | |
#include <emmintrin.h> | |
#include <errno.h> | |
#include <execinfo.h> | |
#include <fstream> | |
#include <functional> | |
#include <inttypes.h> | |
#include <iostream> | |
#include <istream> | |
#include <limits> | |
#include <list> | |
#include <map> | |
#include <memory> | |
#include <mutex> | |
#include <new> | |
#include <ostream> | |
#include <queue> | |
#include <random> | |
#include <regex> | |
#include <sched.h> | |
#include <set> | |
#include <sstream> | |
#include <stdbool.h> | |
#include <stddef.h> | |
#include <stdexcept> | |
#include <stdint.h> | |
#include <stdlib.h> | |
#include <streambuf> | |
#include <string> | |
#include <thread> | |
#include <time.h> | |
#include <tuple> | |
#include <type_traits> | |
#include <typeindex> | |
#include <typeinfo> | |
#include <unordered_map> | |
#include <unordered_set> | |
#include <utility> | |
#include <vector> | |
//===== EXPANDING: dmlc-minimum0.cc ===== | |
/*! | |
* Copyright 2015 by Contributors. | |
* \brief Minimum DMLC library amalgamation, used for easy plug-in of the dmlc lib.
* Normally this is not needed. | |
*/ | |
//===== EXPANDING: ../dmlc-core/src/io/line_split.cc ===== | |
// Copyright by Contributors | |
//===== EXPANDING: ../dmlc-core/include/dmlc/io.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file io.h | |
* \brief defines serializable interface of dmlc | |
*/ | |
#ifndef DMLC_IO_H_ | |
#define DMLC_IO_H_ | |
//===== EXPANDING: ../dmlc-core/include/dmlc/logging.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file logging.h | |
* \brief defines logging macros of dmlc | |
* allows use of GLOG, falling back to an internal
* implementation when disabled | |
*/ | |
#ifndef DMLC_LOGGING_H_ | |
#define DMLC_LOGGING_H_ | |
//===== EXPANDING: ../dmlc-core/include/dmlc/base.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file base.h | |
* \brief defines configuration macros | |
*/ | |
#ifndef DMLC_BASE_H_ | |
#define DMLC_BASE_H_ | |
/*! \brief whether to use glog for logging */
#ifndef DMLC_USE_GLOG | |
#define DMLC_USE_GLOG 0 | |
#endif | |
/*!
* \brief whether to throw dmlc::Error instead of
* directly calling abort when a FATAL error occurred
* NOTE: this may still not be perfect.
* do not use FATAL and CHECK in destructors
*/
#ifndef DMLC_LOG_FATAL_THROW | |
#define DMLC_LOG_FATAL_THROW 1 | |
#endif | |
/*!
* \brief whether to always log a message before throwing
* This can help identify errors that cannot be caught.
*/
#ifndef DMLC_LOG_BEFORE_THROW | |
#define DMLC_LOG_BEFORE_THROW 1 | |
#endif | |
/*! | |
* \brief Whether to use customized logger, | |
* whose output can be decided by other libraries. | |
*/ | |
#ifndef DMLC_LOG_CUSTOMIZE | |
#define DMLC_LOG_CUSTOMIZE 0 | |
#endif | |
/*!
* \brief Whether to print a stack trace for fatal errors,
* enabled on Linux when using gcc.
*/
#if (!defined(DMLC_LOG_STACK_TRACE) && defined(__GNUC__)) | |
#define DMLC_LOG_STACK_TRACE 1 | |
#endif | |
/*! \brief whether to compile with hdfs support */
#ifndef DMLC_USE_HDFS | |
#define DMLC_USE_HDFS 0 | |
#endif | |
/*! \brief whether to compile with s3 support */
#ifndef DMLC_USE_S3 | |
#define DMLC_USE_S3 0 | |
#endif | |
/*! \brief whether or not to use the parameter server */
#ifndef DMLC_USE_PS | |
#define DMLC_USE_PS 0 | |
#endif | |
/*! \brief whether or not to use C++11 support */
#ifndef DMLC_USE_CXX11 | |
#define DMLC_USE_CXX11 (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\ | |
__cplusplus >= 201103L || defined(_MSC_VER)) | |
#endif | |
/*! \brief strict CXX11 support */ | |
#ifndef DMLC_STRICT_CXX11 | |
#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L || defined(_MSC_VER)) | |
#endif | |
/*! \brief whether RTTI is enabled */ | |
#ifndef DMLC_ENABLE_RTTI | |
#define DMLC_ENABLE_RTTI 1 | |
#endif | |
/// check if g++ is before 4.6 | |
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__) | |
#if __GNUC__ == 4 && __GNUC_MINOR__ < 6 | |
#pragma message("Will need g++-4.6 or higher to compile all " \
"the features in dmlc-core, " \
"compile without c++0x, some features may be disabled")
#undef DMLC_USE_CXX11 | |
#define DMLC_USE_CXX11 0 | |
#endif | |
#endif | |
/*! | |
* \brief Enable std::thread related modules, | |
* Used to disable some modules when compiling with MinGW.
*/ | |
#ifndef DMLC_ENABLE_STD_THREAD | |
#define DMLC_ENABLE_STD_THREAD DMLC_USE_CXX11 | |
#endif | |
/*! \brief whether to enable regex support; actually needs g++-4.9 or higher */
#ifndef DMLC_USE_REGEX | |
#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER)) | |
#endif | |
/*! \brief helper macro to suppress unused-variable warnings */
#if defined(__GNUC__) | |
#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused)) | |
#else | |
#define DMLC_ATTRIBUTE_UNUSED | |
#endif | |
/*! \brief helper macro to generate string concat */ | |
#define DMLC_STR_CONCAT_(__x, __y) __x##__y | |
#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y) | |
/*! | |
* \brief Disable copy constructor and assignment operator. | |
* | |
* If C++11 is supported, both copy and move constructors and | |
* assignment operators are deleted explicitly. Otherwise, they are | |
* only declared but not implemented. Place this macro in private | |
* section if C++11 is not available. | |
*/ | |
#ifndef DISALLOW_COPY_AND_ASSIGN | |
# if DMLC_USE_CXX11 | |
# define DISALLOW_COPY_AND_ASSIGN(T) \ | |
T(T const&) = delete; \ | |
T(T&&) = delete; \ | |
T& operator=(T const&) = delete; \ | |
T& operator=(T&&) = delete | |
# else | |
# define DISALLOW_COPY_AND_ASSIGN(T) \ | |
T(T const&); \ | |
T& operator=(T const&) | |
# endif | |
#endif | |
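// A minimal usage sketch (illustrative only; the class name Resource is
// hypothetical): DISALLOW_COPY_AND_ASSIGN is placed inside the class body,
// in the private section when C++11 is not available.
// \code
// class Resource {
//  public:
//   Resource() {}
//  private:
//   DISALLOW_COPY_AND_ASSIGN(Resource);
// };
// \endcode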
///
/// code block to handle fopen64 / large-file support on various platforms
///
#if !defined(__GNUC__) | |
#define fopen64 std::fopen | |
#endif | |
#if (defined __MINGW32__) && !(defined __MINGW64__) | |
#define fopen64 std::fopen | |
#endif | |
#ifdef _MSC_VER | |
#if _MSC_VER < 1900 | |
// NOTE: sprintf_s is not equivalent to snprintf,
// they are only equivalent on success, which is sufficient for our case
#define snprintf sprintf_s | |
#define vsnprintf vsprintf_s | |
#endif | |
#else | |
#ifdef _FILE_OFFSET_BITS | |
#if _FILE_OFFSET_BITS == 32 | |
#pragma message("Warning: FILE OFFSET BITS defined to be 32 bit") | |
#endif | |
#endif | |
#ifdef __APPLE__ | |
#define off64_t off_t | |
#define fopen64 std::fopen | |
#endif | |
extern "C" { | |
} | |
#endif | |
#ifdef _MSC_VER | |
//! \cond Doxygen_Suppress | |
typedef signed char int8_t; | |
typedef __int16 int16_t; | |
typedef __int32 int32_t; | |
typedef __int64 int64_t; | |
typedef unsigned char uint8_t; | |
typedef unsigned __int16 uint16_t; | |
typedef unsigned __int32 uint32_t; | |
typedef unsigned __int64 uint64_t; | |
//! \endcond | |
#else | |
#endif | |
#if defined(_MSC_VER) && _MSC_VER < 1900 | |
#define noexcept_true throw () | |
#define noexcept_false | |
#define noexcept(a) noexcept_##a | |
#endif | |
#if DMLC_USE_CXX11 | |
#define DMLC_THROW_EXCEPTION noexcept(false) | |
#define DMLC_NO_EXCEPTION noexcept(true) | |
#else | |
#define DMLC_THROW_EXCEPTION | |
#define DMLC_NO_EXCEPTION | |
#endif | |
/*! \brief namespace for dmlc */ | |
namespace dmlc { | |
/*! | |
* \brief safely get the beginning address of a vector | |
* \param vec input vector | |
* \return beginning address of a vector | |
*/ | |
template<typename T> | |
inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*) | |
if (vec.size() == 0) { | |
return NULL; | |
} else { | |
return &vec[0]; | |
} | |
} | |
/*! | |
* \brief get the beginning address of a const vector | |
* \param vec input vector | |
* \return beginning address of a vector | |
*/ | |
template<typename T> | |
inline const T *BeginPtr(const std::vector<T> &vec) { | |
if (vec.size() == 0) { | |
return NULL; | |
} else { | |
return &vec[0]; | |
} | |
} | |
/*! | |
* \brief get the beginning address of a string | |
* \param str input string | |
* \return beginning address of a string | |
*/ | |
inline char* BeginPtr(std::string &str) { // NOLINT(*) | |
if (str.length() == 0) return NULL; | |
return &str[0]; | |
} | |
/*! | |
* \brief get the beginning address of a const string | |
* \param str input string | |
* \return beginning address of a string | |
*/ | |
inline const char* BeginPtr(const std::string &str) { | |
if (str.length() == 0) return NULL; | |
return &str[0]; | |
} | |
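// A minimal usage sketch (illustrative only): BeginPtr safely yields NULL for
// empty containers, where &vec[0] would be undefined behaviour.
// \code
// std::vector<int> v;
// int *p0 = dmlc::BeginPtr(v);  // NULL, since v is empty
// v.push_back(42);
// int *p1 = dmlc::BeginPtr(v);  // points to v[0]
// \endcode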
} // namespace dmlc | |
#if defined(_MSC_VER) && _MSC_VER < 1900 | |
#define constexpr const | |
#define alignof __alignof | |
#endif | |
#endif // DMLC_BASE_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/base.h ===== | |
#if DMLC_LOG_STACK_TRACE | |
#endif | |
namespace dmlc { | |
/*! | |
* \brief exception class that will be thrown by | |
* default logger if DMLC_LOG_FATAL_THROW == 1 | |
*/ | |
struct Error : public std::runtime_error { | |
/*! | |
* \brief constructor | |
* \param s the error message | |
*/ | |
explicit Error(const std::string &s) : std::runtime_error(s) {} | |
}; | |
} // namespace dmlc | |
#if DMLC_USE_GLOG | |
namespace dmlc { | |
/*! | |
* \brief optionally redirect to google's init log | |
* \param argv0 The program name, i.e. argv[0].
*/ | |
inline void InitLogging(const char* argv0) { | |
google::InitGoogleLogging(argv0); | |
} | |
} // namespace dmlc | |
#else | |
// use a light version of glog | |
#if defined(_MSC_VER) | |
#pragma warning(disable : 4722) | |
#endif | |
namespace dmlc { | |
inline void InitLogging(const char* argv0) { | |
// DO NOTHING | |
} | |
class LogCheckError { | |
public: | |
LogCheckError() : str(nullptr) {} | |
explicit LogCheckError(const std::string& str_) : str(new std::string(str_)) {} | |
~LogCheckError() { if (str != nullptr) delete str; } | |
operator bool() {return str != nullptr; } | |
std::string* str; | |
}; | |
#define DEFINE_CHECK_FUNC(name, op) \ | |
template <typename X, typename Y> \ | |
inline LogCheckError LogCheck##name(const X& x, const Y& y) { \ | |
if (x op y) return LogCheckError(); \ | |
std::ostringstream os; \ | |
os << " (" << x << " vs. " << y << ") "; /* CHECK_XX(x, y) requires x and y can be serialized to string. Use CHECK(x OP y) otherwise. NOLINT(*) */ \ | |
return LogCheckError(os.str()); \ | |
} \ | |
inline LogCheckError LogCheck##name(int x, int y) { \ | |
return LogCheck##name<int, int>(x, y); \ | |
} | |
#define CHECK_BINARY_OP(name, op, x, y) \ | |
if (dmlc::LogCheckError _check_err = dmlc::LogCheck##name(x, y)) \ | |
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \ | |
<< "Check failed: " << #x " " #op " " #y << *(_check_err.str) | |
DEFINE_CHECK_FUNC(_LT, <) | |
DEFINE_CHECK_FUNC(_GT, >) | |
DEFINE_CHECK_FUNC(_LE, <=) | |
DEFINE_CHECK_FUNC(_GE, >=) | |
DEFINE_CHECK_FUNC(_EQ, ==) | |
DEFINE_CHECK_FUNC(_NE, !=) | |
// Always-on checking | |
#define CHECK(x) \ | |
if (!(x)) \ | |
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \ | |
<< "Check failed: " #x << ' ' | |
#define CHECK_LT(x, y) CHECK_BINARY_OP(_LT, <, x, y) | |
#define CHECK_GT(x, y) CHECK_BINARY_OP(_GT, >, x, y) | |
#define CHECK_LE(x, y) CHECK_BINARY_OP(_LE, <=, x, y) | |
#define CHECK_GE(x, y) CHECK_BINARY_OP(_GE, >=, x, y) | |
#define CHECK_EQ(x, y) CHECK_BINARY_OP(_EQ, ==, x, y) | |
#define CHECK_NE(x, y) CHECK_BINARY_OP(_NE, !=, x, y) | |
#define CHECK_NOTNULL(x) \ | |
((x) == NULL ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', (x) : (x)) // NOLINT(*) | |
// Debug-only checking. | |
#ifdef NDEBUG | |
#define DCHECK(x) \ | |
while (false) CHECK(x) | |
#define DCHECK_LT(x, y) \ | |
while (false) CHECK((x) < (y)) | |
#define DCHECK_GT(x, y) \ | |
while (false) CHECK((x) > (y)) | |
#define DCHECK_LE(x, y) \ | |
while (false) CHECK((x) <= (y)) | |
#define DCHECK_GE(x, y) \ | |
while (false) CHECK((x) >= (y)) | |
#define DCHECK_EQ(x, y) \ | |
while (false) CHECK((x) == (y)) | |
#define DCHECK_NE(x, y) \ | |
while (false) CHECK((x) != (y)) | |
#else | |
#define DCHECK(x) CHECK(x) | |
#define DCHECK_LT(x, y) CHECK((x) < (y)) | |
#define DCHECK_GT(x, y) CHECK((x) > (y)) | |
#define DCHECK_LE(x, y) CHECK((x) <= (y)) | |
#define DCHECK_GE(x, y) CHECK((x) >= (y)) | |
#define DCHECK_EQ(x, y) CHECK((x) == (y)) | |
#define DCHECK_NE(x, y) CHECK((x) != (y)) | |
#endif // NDEBUG | |
#if DMLC_LOG_CUSTOMIZE | |
#define LOG_INFO dmlc::CustomLogMessage(__FILE__, __LINE__) | |
#else | |
#define LOG_INFO dmlc::LogMessage(__FILE__, __LINE__) | |
#endif | |
#define LOG_ERROR LOG_INFO | |
#define LOG_WARNING LOG_INFO | |
#define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__) | |
#define LOG_QFATAL LOG_FATAL | |
// Poor man's version of VLOG
#define VLOG(x) LOG_INFO.stream() | |
#define LOG(severity) LOG_##severity.stream() | |
#define LG LOG_INFO.stream() | |
#define LOG_IF(severity, condition) \ | |
!(condition) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) | |
#ifdef NDEBUG | |
#define LOG_DFATAL LOG_ERROR | |
#define DFATAL ERROR | |
#define DLOG(severity) true ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) | |
#define DLOG_IF(severity, condition) \ | |
(true || !(condition)) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) | |
#else | |
#define LOG_DFATAL LOG_FATAL | |
#define DFATAL FATAL | |
#define DLOG(severity) LOG(severity) | |
#define DLOG_IF(severity, condition) LOG_IF(severity, condition) | |
#endif | |
// Poor man's version of LOG_EVERY_N
#define LOG_EVERY_N(severity, n) LOG(severity) | |
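// A minimal usage sketch of the macros above (illustrative only; ReadCount and
// the variables are hypothetical). CHECK_* aborts, or throws dmlc::Error when
// DMLC_LOG_FATAL_THROW is 1, with a formatted message; LOG(severity) streams
// to the logger.
// \code
// int n = ReadCount();
// CHECK_GT(n, 0) << "need at least one element, got " << n;
// FILE *fp = fopen("data.bin", "rb");
// CHECK(fp != NULL) << "cannot open data.bin";
// LOG(INFO) << "loaded " << n << " elements";
// LOG_IF(WARNING, n > 1000000) << "large input, this may be slow";
// \endcode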
class DateLogger { | |
public: | |
DateLogger() { | |
#if defined(_MSC_VER) | |
_tzset(); | |
#endif | |
} | |
const char* HumanDate() { | |
#if defined(_MSC_VER) | |
_strtime_s(buffer_, sizeof(buffer_)); | |
#else | |
time_t time_value = time(NULL); | |
struct tm *pnow; | |
#if !defined(_WIN32) | |
struct tm now; | |
pnow = localtime_r(&time_value, &now); | |
#else | |
pnow = localtime(&time_value); // NOLINT(*) | |
#endif | |
snprintf(buffer_, sizeof(buffer_), "%02d:%02d:%02d", | |
pnow->tm_hour, pnow->tm_min, pnow->tm_sec); | |
#endif | |
return buffer_; | |
} | |
private: | |
char buffer_[9]; | |
}; | |
class LogMessage { | |
public: | |
LogMessage(const char* file, int line) | |
: | |
#ifdef __ANDROID__ | |
log_stream_(std::cout) | |
#else | |
log_stream_(std::cerr) | |
#endif | |
{ | |
log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" | |
<< line << ": "; | |
} | |
~LogMessage() { log_stream_ << '\n'; } | |
std::ostream& stream() { return log_stream_; } | |
protected: | |
std::ostream& log_stream_; | |
private: | |
DateLogger pretty_date_; | |
LogMessage(const LogMessage&); | |
void operator=(const LogMessage&); | |
}; | |
// customized logger that allows the user to define where to log the message.
class CustomLogMessage { | |
public: | |
CustomLogMessage(const char* file, int line) { | |
log_stream_ << "[" << DateLogger().HumanDate() << "] " << file << ":" | |
<< line << ": "; | |
} | |
~CustomLogMessage() { | |
Log(log_stream_.str()); | |
} | |
std::ostream& stream() { return log_stream_; } | |
/*! | |
* \brief customized logging of the message. | |
* This function won't be implemented by libdmlc | |
* \param msg The message to be logged. | |
*/ | |
static void Log(const std::string& msg); | |
private: | |
std::ostringstream log_stream_; | |
}; | |
#if DMLC_LOG_FATAL_THROW == 0 | |
class LogMessageFatal : public LogMessage { | |
public: | |
LogMessageFatal(const char* file, int line) : LogMessage(file, line) {} | |
~LogMessageFatal() { | |
#if DMLC_LOG_STACK_TRACE | |
const int MAX_STACK_SIZE = 256; | |
void *stack[MAX_STACK_SIZE]; | |
int nframes = backtrace(stack, MAX_STACK_SIZE); | |
log_stream_ << "\n\n" << "Stack trace returned " << nframes << " entries:\n"; | |
char **msgs = backtrace_symbols(stack, nframes); | |
if (msgs != nullptr) { | |
for (int i = 0; i < nframes; ++i) { | |
log_stream_ << "[bt] (" << i << ") " << msgs[i] << "\n"; | |
} | |
} | |
#endif | |
log_stream_ << "\n"; | |
abort(); | |
} | |
private: | |
LogMessageFatal(const LogMessageFatal&); | |
void operator=(const LogMessageFatal&); | |
}; | |
#else | |
class LogMessageFatal { | |
public: | |
LogMessageFatal(const char* file, int line) { | |
log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" | |
<< line << ": "; | |
} | |
std::ostringstream &stream() { return log_stream_; } | |
~LogMessageFatal() DMLC_THROW_EXCEPTION { | |
#if DMLC_LOG_STACK_TRACE | |
const int MAX_STACK_SIZE = 256; | |
void *stack[MAX_STACK_SIZE]; | |
int nframes = backtrace(stack, MAX_STACK_SIZE); | |
log_stream_ << "\n\n" << "Stack trace returned " << nframes << " entries:\n"; | |
char **msgs = backtrace_symbols(stack, nframes); | |
if (msgs != nullptr) { | |
for (int i = 0; i < nframes; ++i) { | |
log_stream_ << "[bt] (" << i << ") " << msgs[i] << "\n"; | |
} | |
} | |
#endif | |
// throwing out of destructor is evil | |
// hopefully we can do it here | |
// also log the message before throw | |
#if DMLC_LOG_BEFORE_THROW | |
LOG(ERROR) << log_stream_.str(); | |
#endif | |
throw Error(log_stream_.str()); | |
} | |
private: | |
std::ostringstream log_stream_; | |
DateLogger pretty_date_; | |
LogMessageFatal(const LogMessageFatal&); | |
void operator=(const LogMessageFatal&); | |
}; | |
#endif | |
// This class is used to explicitly ignore values in the conditional | |
// logging macros. This avoids compiler warnings like "value computed | |
// is not used" and "statement has no effect". | |
class LogMessageVoidify { | |
public: | |
LogMessageVoidify() {} | |
// This has to be an operator with a precedence lower than << but | |
// higher than "?:". See its usage. | |
void operator&(std::ostream&) {} | |
}; | |
} // namespace dmlc | |
#endif | |
#endif // DMLC_LOGGING_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/logging.h ===== | |
// include uint64_t only to make io standalone | |
#ifdef _MSC_VER | |
/*! \brief uint64 */ | |
typedef unsigned __int64 uint64_t; | |
#else | |
#endif | |
/*! \brief namespace for dmlc */ | |
namespace dmlc { | |
/*! | |
* \brief interface of stream I/O for serialization | |
*/ | |
class Stream { // NOLINT(*) | |
public: | |
/*! | |
* \brief reads data from a stream | |
* \param ptr pointer to a memory buffer | |
* \param size block size | |
* \return the size of data read | |
*/ | |
virtual size_t Read(void *ptr, size_t size) = 0; | |
/*! | |
* \brief writes data to a stream | |
* \param ptr pointer to a memory buffer | |
* \param size block size | |
*/ | |
virtual void Write(const void *ptr, size_t size) = 0; | |
/*! \brief virtual destructor */ | |
virtual ~Stream(void) {} | |
/*! | |
* \brief generic factory function | |
* create a stream; the stream will close the underlying file upon deletion
*
* \param uri the URI of the input; currently hdfs://, s3://, and file:// are
* supported; by default file:// is used
* \param flag can be "w", "r", "a"
* \param allow_null whether NULL can be returned, or directly report error
* \return the created stream, can be NULL when allow_null == true and the file does not exist
*/ | |
static Stream *Create(const char *uri, | |
const char* const flag, | |
bool allow_null = false); | |
// helper functions to write/read different data structures | |
/*! | |
* \brief writes data to the stream
*
* dmlc::Stream supports Write/Read of most STL
* composites and base types.
* If the data type is not supported, a compile-time error will
* be issued.
* | |
* \param data data to be written | |
* \tparam T the data type to be written | |
*/ | |
template<typename T> | |
inline void Write(const T &data); | |
/*! | |
* \brief loads data from the stream.
*
* dmlc::Stream supports Write/Read of most STL
* composites and base types.
* If the data type is not supported, a compile-time error will
* be issued.
*
* \param out_data placeholder for the data to be deserialized
* \return whether the load was successful | |
*/ | |
template<typename T> | |
inline bool Read(T *out_data); | |
}; | |
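// A minimal usage sketch (illustrative only; the file name is hypothetical):
// creating a Stream via the factory above and round-tripping a value with the
// templated Write/Read helpers.
// \code
// dmlc::Stream *fo = dmlc::Stream::Create("tmp.bin", "w");
// std::string msg = "hello";
// fo->Write(msg);
// delete fo;  // closes the underlying file
// dmlc::Stream *fi = dmlc::Stream::Create("tmp.bin", "r");
// std::string back;
// CHECK(fi->Read(&back)) << "failed to read the string back";
// delete fi;
// \endcode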
/*! \brief interface of i/o stream that supports seek */
class SeekStream: public Stream { | |
public: | |
// virtual destructor | |
virtual ~SeekStream(void) {} | |
/*! \brief seek to certain position of the file */ | |
virtual void Seek(size_t pos) = 0; | |
/*! \brief tell the position of the stream */ | |
virtual size_t Tell(void) = 0; | |
/*! | |
* \brief generic factory function | |
* create a SeekStream for reading only;
* the stream will close the underlying file upon deletion
* an error is reported and the program exits when creation fails
* \param uri the URI of the input; currently hdfs://, s3://, and file:// are
* supported; by default file:// is used
* \param allow_null whether NULL can be returned, or directly report error
* \return the created stream, can be NULL when allow_null == true and the file does not exist
*/ | |
static SeekStream *CreateForRead(const char *uri, | |
bool allow_null = false); | |
}; | |
/*! \brief interface for serializable objects */ | |
class Serializable { | |
public: | |
/*! \brief virtual destructor */ | |
virtual ~Serializable() {} | |
/*! | |
* \brief load the model from a stream | |
* \param fi stream where to load the model from | |
*/ | |
virtual void Load(Stream *fi) = 0; | |
/*! | |
* \brief saves the model to a stream | |
* \param fo stream where to save the model to | |
*/ | |
virtual void Save(Stream *fo) const = 0; | |
}; | |
/*! | |
* \brief InputSplit allows reading records from one split of the data;
* each split is an independent part, and together the splits cover the whole dataset
*
* see InputSplit::Create for the definition of a record
*/ | |
class InputSplit { | |
public: | |
/*! \brief a blob of memory region */ | |
struct Blob { | |
/*! \brief points to start of the memory region */ | |
void *dptr; | |
/*! \brief size of the memory region */ | |
size_t size; | |
}; | |
/*! | |
* \brief hint to the InputSplit how large a chunk
* it should return from NextChunk;
* this is a hint and may not be enforced,
* but InputSplit will try to adjust its internal buffer
* size to the hinted value
* \param chunk_size the chunk size | |
*/ | |
virtual void HintChunkSize(size_t chunk_size) {} | |
/*! \brief get the total size of the InputSplit */ | |
virtual size_t GetTotalSize(void) = 0; | |
/*! \brief reset the position of InputSplit to beginning */ | |
virtual void BeforeFirst(void) = 0; | |
/*! | |
* \brief get the next record, the returning value | |
* is valid until next call to NextRecord or NextChunk | |
* caller can modify the memory content of out_rec | |
* | |
* For text, out_rec contains a single line.
* For recordio, out_rec contains the content of one record (with the header stripped).
* | |
* \param out_rec used to store the result | |
* \return true if we can successfully get next record | |
* false if we reached end of split | |
* \sa InputSplit::Create for definition of record | |
*/ | |
virtual bool NextRecord(Blob *out_rec) = 0; | |
/*! | |
* \brief get a chunk of memory that can contain multiple records, | |
* the caller needs to parse the content of the resulting chunk, | |
* for a text file, out_chunk can contain data of multiple lines
* for recordio, out_chunk can contain multiple records (including headers)
*
* This function ensures there won't be a partial record in the chunk;
* caller can modify the memory content of out_chunk, | |
* the memory is valid until next call to NextRecord or NextChunk | |
* | |
* Usually NextRecord is sufficient, NextChunk can be used by some | |
* multi-threaded parsers to parse the input content | |
* | |
* \param out_chunk used to store the result | |
* \return true if we can successfully get next record | |
* false if we reached end of split | |
* \sa InputSplit::Create for definition of record | |
* \sa RecordIOChunkReader to parse recordio content from out_chunk | |
*/ | |
virtual bool NextChunk(Blob *out_chunk) = 0; | |
/*! \brief destructor*/ | |
virtual ~InputSplit(void) {} | |
/*! | |
* \brief reset the InputSplit to a certain part id,
* The InputSplit will be pointed to the head of the new specified segment. | |
* This feature may not be supported by every implementation of InputSplit. | |
* \param part_index The part id of the new input. | |
* \param num_parts The total number of parts. | |
*/ | |
virtual void ResetPartition(unsigned part_index, unsigned num_parts) = 0; | |
/*! | |
* \brief factory function: | |
* create input split given a uri | |
* \param uri the uri of the input, can contain hdfs prefix | |
* \param part_index the part id of current input | |
* \param num_parts total number of splits | |
* \param type type of record | |
* List of possible types: "text", "recordio" | |
* - "text": | |
* text file, each line is treated as a record | |
* input split will split on '\\n' or '\\r' | |
* - "recordio": | |
* binary recordio file, see recordio.h | |
* \return a new input split | |
* \sa InputSplit::Type | |
*/ | |
static InputSplit* Create(const char *uri, | |
unsigned part_index, | |
unsigned num_parts, | |
const char *type); | |
}; | |
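// A minimal usage sketch (illustrative only; the uri is hypothetical): reading
// a text file record by record through the factory above, using a single
// partition.
// \code
// dmlc::InputSplit *split =
//     dmlc::InputSplit::Create("./data.txt", 0, 1, "text");
// dmlc::InputSplit::Blob rec;
// while (split->NextRecord(&rec)) {
//   std::string line(static_cast<char*>(rec.dptr), rec.size);
//   // process one line here
// }
// delete split;
// \endcode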
/*! | |
* \brief a std::ostream class that can wrap Stream objects;
* use it as an ostream whose output goes to the underlying Stream
* | |
* Usage example: | |
* \code | |
* | |
* Stream *fs = Stream::Create("hdfs:///test.txt", "w"); | |
* dmlc::ostream os(fs); | |
* os << "hello world" << std::endl; | |
* delete fs; | |
* \endcode | |
*/ | |
class ostream : public std::basic_ostream<char> { | |
public: | |
/*! | |
* \brief construct std::ostream type | |
* \param stream the Stream output to be used | |
* \param buffer_size internal streambuf size | |
*/ | |
explicit ostream(Stream *stream, | |
size_t buffer_size = (1 << 10)) | |
: std::basic_ostream<char>(NULL), buf_(buffer_size) { | |
this->set_stream(stream); | |
} | |
// explicitly synchronize the buffer
virtual ~ostream() DMLC_NO_EXCEPTION { | |
buf_.pubsync(); | |
} | |
/*! | |
* \brief set internal stream to be stream, reset states | |
* \param stream new stream as output | |
*/ | |
inline void set_stream(Stream *stream) { | |
buf_.set_stream(stream); | |
this->rdbuf(&buf_); | |
} | |
/*! \return how many bytes we have written so far */
inline size_t bytes_written(void) const { | |
return buf_.bytes_out(); | |
} | |
private: | |
// internal streambuf | |
class OutBuf : public std::streambuf { | |
public: | |
explicit OutBuf(size_t buffer_size) | |
: stream_(NULL), buffer_(buffer_size), bytes_out_(0) { | |
if (buffer_size == 0) buffer_.resize(2); | |
} | |
// set stream to the buffer | |
inline void set_stream(Stream *stream); | |
inline size_t bytes_out() const { return bytes_out_; } | |
private: | |
/*! \brief internal stream by StreamBuf */ | |
Stream *stream_; | |
/*! \brief internal buffer */ | |
std::vector<char> buffer_; | |
/*! \brief number of bytes written so far */ | |
size_t bytes_out_; | |
// override sync | |
inline int_type sync(void); | |
// override overflow | |
inline int_type overflow(int c); | |
}; | |
/*! \brief buffer of the stream */ | |
OutBuf buf_; | |
}; | |
/*! | |
* \brief a std::istream class that can wrap Stream objects;
* use it as an istream whose input comes from the underlying Stream
* | |
* Usage example: | |
* \code | |
* | |
* Stream *fs = Stream::Create("hdfs:///test.txt", "r"); | |
* dmlc::istream is(fs); | |
* is >> mydata; | |
* delete fs; | |
* \endcode | |
*/ | |
class istream : public std::basic_istream<char> { | |
public: | |
/*! | |
* \brief construct std::istream type
* \param stream the Stream input to be used
* \param buffer_size internal buffer size | |
*/ | |
explicit istream(Stream *stream, | |
size_t buffer_size = (1 << 10)) | |
: std::basic_istream<char>(NULL), buf_(buffer_size) { | |
this->set_stream(stream); | |
} | |
virtual ~istream() DMLC_NO_EXCEPTION {} | |
/*! | |
* \brief set internal stream to be stream, reset states | |
* \param stream new stream as output | |
*/ | |
inline void set_stream(Stream *stream) { | |
buf_.set_stream(stream); | |
this->rdbuf(&buf_); | |
} | |
/*! \return how many bytes we have read so far */
inline size_t bytes_read(void) const { | |
return buf_.bytes_read(); | |
} | |
private: | |
// internal streambuf | |
class InBuf : public std::streambuf { | |
public: | |
explicit InBuf(size_t buffer_size) | |
: stream_(NULL), bytes_read_(0), | |
buffer_(buffer_size) { | |
if (buffer_size == 0) buffer_.resize(2); | |
} | |
// set stream to the buffer | |
inline void set_stream(Stream *stream); | |
// return how many bytes read so far | |
inline size_t bytes_read(void) const { | |
return bytes_read_; | |
} | |
private: | |
/*! \brief internal stream by StreamBuf */ | |
Stream *stream_; | |
/*! \brief how many bytes we read so far */ | |
size_t bytes_read_; | |
/*! \brief internal buffer */ | |
std::vector<char> buffer_; | |
// override underflow | |
inline int_type underflow(); | |
}; | |
/*! \brief input buffer */ | |
InBuf buf_; | |
}; | |
} // namespace dmlc | |
//===== EXPANDING: ../dmlc-core/include/dmlc/serializer.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file serializer.h | |
* \brief serializer template class that helps serialization. | |
* This file does not need to be used directly by most users.
*/ | |
#ifndef DMLC_SERIALIZER_H_ | |
#define DMLC_SERIALIZER_H_ | |
//===== EXPANDING: ../dmlc-core/include/dmlc/type_traits.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file type_traits.h | |
* \brief type traits information header | |
*/ | |
#ifndef DMLC_TYPE_TRAITS_H_ | |
#define DMLC_TYPE_TRAITS_H_ | |
#if DMLC_USE_CXX11 | |
#endif | |
namespace dmlc { | |
/*! | |
* \brief whether a type is a POD type
* \tparam T the type to query | |
*/ | |
template<typename T> | |
struct is_pod { | |
#if DMLC_USE_CXX11 | |
/*! \brief the value of the traits */ | |
static const bool value = std::is_pod<T>::value; | |
#else | |
/*! \brief the value of the traits */ | |
static const bool value = false; | |
#endif | |
}; | |
/*! | |
* \brief whether a type is an integral type
* \tparam T the type to query | |
*/ | |
template<typename T> | |
struct is_integral { | |
#if DMLC_USE_CXX11 | |
/*! \brief the value of the traits */ | |
static const bool value = std::is_integral<T>::value; | |
#else | |
/*! \brief the value of the traits */ | |
static const bool value = false; | |
#endif | |
}; | |
/*! | |
* \brief whether a type is a floating-point type
* \tparam T the type to query | |
*/ | |
template<typename T> | |
struct is_floating_point { | |
#if DMLC_USE_CXX11 | |
/*! \brief the value of the traits */ | |
static const bool value = std::is_floating_point<T>::value; | |
#else | |
/*! \brief the value of the traits */ | |
static const bool value = false; | |
#endif | |
}; | |
/*! | |
* \brief whether a type is an arithmetic type
* \tparam T the type to query | |
*/ | |
template<typename T> | |
struct is_arithmetic { | |
#if DMLC_USE_CXX11 | |
/*! \brief the value of the traits */ | |
static const bool value = std::is_arithmetic<T>::value; | |
#else | |
/*! \brief the value of the traits */ | |
static const bool value = (dmlc::is_integral<T>::value || | |
dmlc::is_floating_point<T>::value); | |
#endif | |
}; | |
/*! | |
* \brief the string representation of type name | |
* \tparam T the type to query | |
* \return a const string of typename. | |
*/ | |
template<typename T> | |
inline const char* type_name() { | |
return ""; | |
} | |
/*! | |
* \brief whether a type has save/load functions
* \tparam T the type to query | |
*/ | |
template<typename T> | |
struct has_saveload { | |
/*! \brief the value of the traits */ | |
static const bool value = false; | |
}; | |
/*! | |
* \brief template to select type based on condition | |
* For example, IfThenElseType<true, int, float>::Type will give int | |
* \tparam cond the condition | |
* \tparam Then the typename to be returned if cond is true | |
* \tparam Else the typename to be returned if cond is false
*/ | |
template<bool cond, typename Then, typename Else> | |
struct IfThenElseType; | |
/*! \brief macro to quickly declare traits information */ | |
#define DMLC_DECLARE_TRAITS(Trait, Type, Value) \ | |
template<> \ | |
struct Trait<Type> { \ | |
static const bool value = Value; \ | |
} | |
/*! \brief macro to quickly declare traits information */ | |
#define DMLC_DECLARE_TYPE_NAME(Type, Name) \ | |
template<> \ | |
inline const char* type_name<Type>() { \ | |
return Name; \ | |
} | |
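// A minimal usage sketch (illustrative only; MyParam is a hypothetical user
// type): the macros above register trait information and a readable type name
// for a user-defined class.
// \code
// struct MyParam {
//   int a; float b;
//   void Save(Stream *fo) const;  // user-provided serialization
//   bool Load(Stream *fi);
// };
// DMLC_DECLARE_TRAITS(has_saveload, MyParam, true);  // picked up by the serializer
// DMLC_DECLARE_TYPE_NAME(MyParam, "my-param");
// \endcode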
//! \cond Doxygen_Suppress | |
// declare special traits when C++11 is not available | |
#if DMLC_USE_CXX11 == 0 | |
DMLC_DECLARE_TRAITS(is_pod, char, true); | |
DMLC_DECLARE_TRAITS(is_pod, int8_t, true); | |
DMLC_DECLARE_TRAITS(is_pod, int16_t, true); | |
DMLC_DECLARE_TRAITS(is_pod, int32_t, true); | |
DMLC_DECLARE_TRAITS(is_pod, int64_t, true); | |
DMLC_DECLARE_TRAITS(is_pod, uint8_t, true); | |
DMLC_DECLARE_TRAITS(is_pod, uint16_t, true); | |
DMLC_DECLARE_TRAITS(is_pod, uint32_t, true); | |
DMLC_DECLARE_TRAITS(is_pod, uint64_t, true); | |
DMLC_DECLARE_TRAITS(is_pod, float, true); | |
DMLC_DECLARE_TRAITS(is_pod, double, true); | |
DMLC_DECLARE_TRAITS(is_integral, char, true); | |
DMLC_DECLARE_TRAITS(is_integral, int8_t, true); | |
DMLC_DECLARE_TRAITS(is_integral, int16_t, true); | |
DMLC_DECLARE_TRAITS(is_integral, int32_t, true); | |
DMLC_DECLARE_TRAITS(is_integral, int64_t, true); | |
DMLC_DECLARE_TRAITS(is_integral, uint8_t, true); | |
DMLC_DECLARE_TRAITS(is_integral, uint16_t, true); | |
DMLC_DECLARE_TRAITS(is_integral, uint32_t, true); | |
DMLC_DECLARE_TRAITS(is_integral, uint64_t, true); | |
DMLC_DECLARE_TRAITS(is_floating_point, float, true); | |
DMLC_DECLARE_TRAITS(is_floating_point, double, true); | |
#endif | |
DMLC_DECLARE_TYPE_NAME(float, "float"); | |
DMLC_DECLARE_TYPE_NAME(double, "double"); | |
DMLC_DECLARE_TYPE_NAME(int, "int"); | |
DMLC_DECLARE_TYPE_NAME(uint32_t, "int (non-negative)"); | |
DMLC_DECLARE_TYPE_NAME(uint64_t, "long (non-negative)"); | |
DMLC_DECLARE_TYPE_NAME(std::string, "string"); | |
DMLC_DECLARE_TYPE_NAME(bool, "boolean"); | |
template<typename Then, typename Else> | |
struct IfThenElseType<true, Then, Else> { | |
typedef Then Type; | |
}; | |
template<typename Then, typename Else> | |
struct IfThenElseType<false, Then, Else> { | |
typedef Else Type; | |
}; | |
//! \endcond | |
} // namespace dmlc | |
#endif // DMLC_TYPE_TRAITS_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/type_traits.h ===== | |
#if DMLC_USE_CXX11 | |
#endif | |
namespace dmlc { | |
/*! \brief internal namespace for serializers */ | |
namespace serializer { | |
/*! | |
* \brief generic serialization handler | |
* \tparam T the type to be serialized | |
*/ | |
template<typename T> | |
struct Handler; | |
//! \cond Doxygen_Suppress | |
/*! | |
* \brief Serializer that redirects calls based on a condition
* \tparam cond the condition | |
* \tparam Then the serializer used for then condition | |
* \tparam Else the serializer used for else condition | |
* \tparam Return the type of data the serializer handles | |
*/ | |
template<bool cond, typename Then, typename Else, typename Return> | |
struct IfThenElse; | |
template<typename Then, typename Else, typename T> | |
struct IfThenElse<true, Then, Else, T> { | |
inline static void Write(Stream *strm, const T &data) { | |
Then::Write(strm, data); | |
} | |
inline static bool Read(Stream *strm, T *data) { | |
return Then::Read(strm, data); | |
} | |
}; | |
template<typename Then, typename Else, typename T> | |
struct IfThenElse<false, Then, Else, T> { | |
inline static void Write(Stream *strm, const T &data) { | |
Else::Write(strm, data); | |
} | |
inline static bool Read(Stream *strm, T *data) { | |
return Else::Read(strm, data); | |
} | |
}; | |
/*! \brief Serializer for POD (plain-old-data) data */
template<typename T> | |
struct PODHandler { | |
inline static void Write(Stream *strm, const T &data) { | |
strm->Write(&data, sizeof(T)); | |
} | |
inline static bool Read(Stream *strm, T *dptr) { | |
return strm->Read((void*)dptr, sizeof(T)) == sizeof(T); // NOLINT(*) | |
} | |
}; | |
// serializer for classes that have save/load functions
template<typename T> | |
struct SaveLoadClassHandler { | |
inline static void Write(Stream *strm, const T &data) { | |
data.Save(strm); | |
} | |
inline static bool Read(Stream *strm, T *data) { | |
return data->Load(strm); | |
} | |
}; | |
/*! | |
* \brief dummy class for undefined serialization. | |
* This is used to generate an error message when the user tries to
* serialize something that is not supported. | |
* \tparam T the type to be serialized | |
*/ | |
template<typename T> | |
struct UndefinedSerializerFor { | |
}; | |
/*! | |
* \brief Serializer handler for std::vector<T> where T is POD type. | |
* \tparam T element type | |
*/ | |
template<typename T> | |
struct PODVectorHandler { | |
inline static void Write(Stream *strm, const std::vector<T> &vec) { | |
uint64_t sz = static_cast<uint64_t>(vec.size()); | |
strm->Write(&sz, sizeof(sz)); | |
if (sz != 0) { | |
strm->Write(&vec[0], sizeof(T) * vec.size()); | |
} | |
} | |
inline static bool Read(Stream *strm, std::vector<T> *out_vec) { | |
uint64_t sz; | |
if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false; | |
size_t size = static_cast<size_t>(sz); | |
out_vec->resize(size); | |
if (sz != 0) { | |
size_t nbytes = sizeof(T) * size; | |
return strm->Read(&(*out_vec)[0], nbytes) == nbytes; | |
} | |
return true; | |
} | |
}; | |
/*! | |
* \brief Serializer handler for std::vector<T> where T can be a composite type
* \tparam T element type | |
*/ | |
template<typename T> | |
struct ComposeVectorHandler { | |
inline static void Write(Stream *strm, const std::vector<T> &vec) { | |
uint64_t sz = static_cast<uint64_t>(vec.size()); | |
strm->Write(&sz, sizeof(sz)); | |
for (size_t i = 0; i < vec.size(); ++i) { | |
Handler<T>::Write(strm, vec[i]); | |
} | |
} | |
inline static bool Read(Stream *strm, std::vector<T> *out_vec) { | |
uint64_t sz; | |
if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false; | |
size_t size = static_cast<size_t>(sz); | |
out_vec->resize(size); | |
for (size_t i = 0; i < size; ++i) { | |
if (!Handler<T>::Read(strm, &(*out_vec)[i])) return false; | |
} | |
return true; | |
} | |
}; | |
/*! | |
* \brief Serializer handler for std::basic_string<T> where T is POD type. | |
* \tparam T element type | |
*/ | |
template<typename T> | |
struct PODStringHandler { | |
inline static void Write(Stream *strm, const std::basic_string<T> &vec) { | |
uint64_t sz = static_cast<uint64_t>(vec.length()); | |
strm->Write(&sz, sizeof(sz)); | |
if (sz != 0) { | |
strm->Write(&vec[0], sizeof(T) * vec.length()); | |
} | |
} | |
inline static bool Read(Stream *strm, std::basic_string<T> *out_vec) { | |
uint64_t sz; | |
if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false; | |
size_t size = static_cast<size_t>(sz); | |
out_vec->resize(size); | |
if (sz != 0) { | |
size_t nbytes = sizeof(T) * size; | |
return strm->Read(&(*out_vec)[0], nbytes) == nbytes; | |
} | |
return true; | |
} | |
}; | |
/*! \brief Serializer for std::pair */ | |
template<typename TA, typename TB> | |
struct PairHandler { | |
inline static void Write(Stream *strm, const std::pair<TA, TB> &data) { | |
Handler<TA>::Write(strm, data.first); | |
Handler<TB>::Write(strm, data.second); | |
} | |
inline static bool Read(Stream *strm, std::pair<TA, TB> *data) { | |
return Handler<TA>::Read(strm, &(data->first)) && | |
Handler<TB>::Read(strm, &(data->second)); | |
} | |
}; | |
// set-type handler that can handle most collection type cases
template<typename ContainerType> | |
struct CollectionHandler { | |
inline static void Write(Stream *strm, const ContainerType &data) { | |
typedef typename ContainerType::value_type ElemType; | |
// dump data to vector | |
std::vector<ElemType> vdata(data.begin(), data.end()); | |
// serialize the vector | |
Handler<std::vector<ElemType> >::Write(strm, vdata); | |
} | |
inline static bool Read(Stream *strm, ContainerType *data) { | |
typedef typename ContainerType::value_type ElemType; | |
std::vector<ElemType> vdata; | |
if (!Handler<std::vector<ElemType> >::Read(strm, &vdata)) return false; | |
data->clear(); | |
data->insert(vdata.begin(), vdata.end()); | |
return true; | |
} | |
}; | |
// handler that can handle most list types
// whose insert function takes an additional position iterator
template<typename ListType> | |
struct ListHandler { | |
inline static void Write(Stream *strm, const ListType &data) { | |
typedef typename ListType::value_type ElemType; | |
// dump data to vector | |
std::vector<ElemType> vdata(data.begin(), data.end()); | |
// serialize the vector | |
Handler<std::vector<ElemType> >::Write(strm, vdata); | |
} | |
inline static bool Read(Stream *strm, ListType *data) { | |
typedef typename ListType::value_type ElemType; | |
std::vector<ElemType> vdata; | |
if (!Handler<std::vector<ElemType> >::Read(strm, &vdata)) return false; | |
data->clear(); | |
data->insert(data->begin(), vdata.begin(), vdata.end()); | |
return true; | |
} | |
}; | |
//! \endcond | |
/*! | |
* \brief generic serialization handler for type T | |
* | |
* User can define specialization of this class to support | |
* composite serialization of their own class. | |
* | |
* \tparam T the type to be serialized | |
*/ | |
template<typename T> | |
struct Handler { | |
/*! | |
* \brief write data to the stream
* \param strm the stream to which the data is written.
* \param data the data object to be serialized
*/ | |
inline static void Write(Stream *strm, const T &data) { | |
IfThenElse<dmlc::is_pod<T>::value, | |
PODHandler<T>, | |
IfThenElse<dmlc::has_saveload<T>::value, | |
SaveLoadClassHandler<T>, | |
UndefinedSerializerFor<T>, T>, | |
T> | |
::Write(strm, data); | |
} | |
/*! | |
* \brief read data from the stream
* \param strm the stream from which the data is read.
* \param data the pointer to the data object to read into
* \return whether the read is successful | |
*/ | |
inline static bool Read(Stream *strm, T *data) { | |
return IfThenElse<dmlc::is_pod<T>::value, | |
PODHandler<T>, | |
IfThenElse<dmlc::has_saveload<T>::value, | |
SaveLoadClassHandler<T>, | |
UndefinedSerializerFor<T>, T>, | |
T> | |
::Read(strm, data); | |
} | |
}; | |
//! \cond Doxygen_Suppress | |
template<typename T> | |
struct Handler<std::vector<T> > { | |
inline static void Write(Stream *strm, const std::vector<T> &data) { | |
IfThenElse<dmlc::is_pod<T>::value, | |
PODVectorHandler<T>, | |
ComposeVectorHandler<T>, std::vector<T> > | |
::Write(strm, data); | |
} | |
inline static bool Read(Stream *strm, std::vector<T> *data) { | |
return IfThenElse<dmlc::is_pod<T>::value, | |
PODVectorHandler<T>, | |
ComposeVectorHandler<T>, | |
std::vector<T> > | |
::Read(strm, data); | |
} | |
}; | |
template<typename T> | |
struct Handler<std::basic_string<T> > { | |
inline static void Write(Stream *strm, const std::basic_string<T> &data) { | |
IfThenElse<dmlc::is_pod<T>::value, | |
PODStringHandler<T>, | |
UndefinedSerializerFor<T>, | |
std::basic_string<T> > | |
::Write(strm, data); | |
} | |
inline static bool Read(Stream *strm, std::basic_string<T> *data) { | |
return IfThenElse<dmlc::is_pod<T>::value, | |
PODStringHandler<T>, | |
UndefinedSerializerFor<T>, | |
std::basic_string<T> > | |
::Read(strm, data); | |
} | |
}; | |
template<typename TA, typename TB> | |
struct Handler<std::pair<TA, TB> > { | |
inline static void Write(Stream *strm, const std::pair<TA, TB> &data) { | |
IfThenElse<dmlc::is_pod<TA>::value && dmlc::is_pod<TB>::value, | |
PODHandler<std::pair<TA, TB> >, | |
PairHandler<TA, TB>, | |
std::pair<TA, TB> > | |
::Write(strm, data); | |
} | |
inline static bool Read(Stream *strm, std::pair<TA, TB> *data) { | |
return IfThenElse<dmlc::is_pod<TA>::value && dmlc::is_pod<TB>::value, | |
PODHandler<std::pair<TA, TB> >, | |
PairHandler<TA, TB>, | |
std::pair<TA, TB> > | |
::Read(strm, data); | |
} | |
}; | |
template<typename K, typename V> | |
struct Handler<std::map<K, V> > | |
: public CollectionHandler<std::map<K, V> > { | |
}; | |
template<typename K, typename V> | |
struct Handler<std::multimap<K, V> > | |
: public CollectionHandler<std::multimap<K, V> > { | |
}; | |
template<typename T> | |
struct Handler<std::set<T> > | |
: public CollectionHandler<std::set<T> > { | |
}; | |
template<typename T> | |
struct Handler<std::multiset<T> > | |
: public CollectionHandler<std::multiset<T> > { | |
}; | |
template<typename T> | |
struct Handler<std::list<T> > | |
: public ListHandler<std::list<T> > { | |
}; | |
template<typename T> | |
struct Handler<std::deque<T> > | |
: public ListHandler<std::deque<T> > { | |
}; | |
#if DMLC_USE_CXX11 | |
template<typename K, typename V> | |
struct Handler<std::unordered_map<K, V> > | |
: public CollectionHandler<std::unordered_map<K, V> > { | |
}; | |
template<typename K, typename V> | |
struct Handler<std::unordered_multimap<K, V> > | |
: public CollectionHandler<std::unordered_multimap<K, V> > { | |
}; | |
template<typename T> | |
struct Handler<std::unordered_set<T> > | |
: public CollectionHandler<std::unordered_set<T> > { | |
}; | |
template<typename T> | |
struct Handler<std::unordered_multiset<T> > | |
: public CollectionHandler<std::unordered_multiset<T> > { | |
}; | |
#endif | |
//! \endcond | |
} // namespace serializer | |
} // namespace dmlc | |
#endif // DMLC_SERIALIZER_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/serializer.h ===== | |
namespace dmlc { | |
// implementations of inline functions | |
template<typename T> | |
inline void Stream::Write(const T &data) { | |
serializer::Handler<T>::Write(this, data); | |
} | |
template<typename T> | |
inline bool Stream::Read(T *out_data) { | |
return serializer::Handler<T>::Read(this, out_data); | |
} | |
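// A minimal usage sketch (illustrative only; strm_out and strm_in are assumed
// to be streams obtained from Stream::Create): with the Handler dispatch above,
// composite STL types are serialized recursively.
// \code
// std::map<std::string, std::vector<float> > table;
// table["weights"].push_back(0.5f);
// strm_out->Write(table);               // writes the size, then each pair
// std::map<std::string, std::vector<float> > loaded;
// CHECK(strm_in->Read(&loaded)) << "corrupted stream";
// \endcode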
// implementations for ostream | |
inline void ostream::OutBuf::set_stream(Stream *stream) { | |
if (stream_ != NULL) this->pubsync(); | |
this->stream_ = stream; | |
this->setp(&buffer_[0], &buffer_[0] + buffer_.size() - 1); | |
} | |
inline int ostream::OutBuf::sync(void) { | |
if (stream_ == NULL) return -1; | |
std::ptrdiff_t n = pptr() - pbase(); | |
stream_->Write(pbase(), n); | |
this->pbump(-static_cast<int>(n)); | |
bytes_out_ += n; | |
return 0; | |
} | |
inline int ostream::OutBuf::overflow(int c) { | |
*(this->pptr()) = c; | |
std::ptrdiff_t n = pptr() - pbase(); | |
this->pbump(-static_cast<int>(n)); | |
if (c == EOF) { | |
stream_->Write(pbase(), n); | |
bytes_out_ += n; | |
} else { | |
stream_->Write(pbase(), n + 1); | |
bytes_out_ += n + 1; | |
} | |
return c; | |
} | |
// implementations for istream | |
inline void istream::InBuf::set_stream(Stream *stream) { | |
stream_ = stream; | |
this->setg(&buffer_[0], &buffer_[0], &buffer_[0]); | |
} | |
inline int istream::InBuf::underflow() { | |
char *bhead = &buffer_[0]; | |
if (this->gptr() == this->egptr()) { | |
size_t sz = stream_->Read(bhead, buffer_.size()); | |
this->setg(bhead, bhead, bhead + sz); | |
bytes_read_ += sz; | |
} | |
if (this->gptr() == this->egptr()) { | |
return traits_type::eof(); | |
} else { | |
return traits_type::to_int_type(*gptr()); | |
} | |
} | |
} // namespace dmlc | |
#endif // DMLC_IO_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/io.h ===== | |
//===== EXPANDING: ../dmlc-core/src/io/line_split.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file line_split.h | |
* \brief base class implementation of input splitter | |
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_IO_LINE_SPLIT_H_ | |
#define DMLC_IO_LINE_SPLIT_H_ | |
//===== EXPANDING: ../dmlc-core/src/io/input_split_base.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file input_split_base.h | |
* \brief base class to construct input split from multiple files | |
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_IO_INPUT_SPLIT_BASE_H_ | |
#define DMLC_IO_INPUT_SPLIT_BASE_H_ | |
//===== EXPANDING: ../dmlc-core/src/io/filesys.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file filesys.h
* \brief general file system io interface | |
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_IO_FILESYS_H_ | |
#define DMLC_IO_FILESYS_H_ | |
namespace dmlc { | |
namespace io { | |
/*! \brief common data structure for URI */ | |
struct URI { | |
/*! \brief protocol */ | |
std::string protocol; | |
/*! | |
* \brief host name, namenode for HDFS, bucket name for s3 | |
*/ | |
std::string host; | |
/*! \brief name of the path */ | |
std::string name; | |
/*! \brief enable default constructor */ | |
URI(void) {} | |
/*! | |
* \brief construct from URI string | |
*/ | |
explicit URI(const char *uri) { | |
const char *p = std::strstr(uri, "://"); | |
if (p == NULL) { | |
name = uri; | |
} else { | |
protocol = std::string(uri, p - uri + 3); | |
uri = p + 3; | |
p = std::strchr(uri, '/'); | |
if (p == NULL) { | |
host = uri; name = '/'; | |
} else { | |
host = std::string(uri, p - uri); | |
name = p; | |
} | |
} | |
} | |
/*! \brief string representation */ | |
inline std::string str(void) const { | |
return protocol + host + name; | |
} | |
}; | |
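// A minimal sketch of how the constructor above splits a URI (illustrative
// only; the address is hypothetical).
// \code
// dmlc::io::URI u("hdfs://namenode/user/data.txt");
// // u.protocol == "hdfs://", u.host == "namenode", u.name == "/user/data.txt"
// \endcode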
/*! \brief type of file */ | |
enum FileType { | |
/*! \brief the file is file */ | |
kFile, | |
/*! \brief the file is directory */ | |
kDirectory | |
}; | |
/*! \brief use to store file information */ | |
struct FileInfo { | |
/*! \brief full path to the file */ | |
URI path; | |
/*! \brief the size of the file */ | |
size_t size; | |
/*! \brief the type of the file */ | |
FileType type; | |
/*! \brief default constructor */ | |
FileInfo() : size(0), type(kFile) {} | |
}; | |
/*! \brief file system interface */
class FileSystem { | |
public: | |
/*! | |
* \brief get singleton of filesystem instance according to URI | |
* \param path can be s3://..., hdfs://..., file://..., | |
* or an empty string (will return the local filesystem)
* \return a corresponding filesystem, report error if | |
* we cannot find a matching system | |
*/ | |
static FileSystem *GetInstance(const URI &path); | |
/*! \brief virtual destructor */ | |
virtual ~FileSystem() {} | |
/*! | |
* \brief get information about a path | |
* \param path the path to the file | |
* \return the information about the file | |
*/ | |
virtual FileInfo GetPathInfo(const URI &path) = 0; | |
/*! | |
* \brief list files in a directory | |
* \param path the path to the directory
* \param out_list the output information about the files | |
*/ | |
virtual void ListDirectory(const URI &path, std::vector<FileInfo> *out_list) = 0; | |
/*! | |
* \brief open a stream | |
* \param path the path to the file
* \param flag can be "w", "r", "a"
* \param allow_null whether NULL can be returned, or directly report error | |
* \return the created stream, can be NULL when allow_null == true and file do not exist | |
*/ | |
virtual Stream *Open(const URI &path, | |
const char* const flag, | |
bool allow_null = false) = 0; | |
/*! | |
* \brief open a seekable stream for read | |
* \param path the path to the file | |
* \param allow_null whether NULL can be returned, or directly report error | |
* \return the created stream, can be NULL when allow_null == true and file do not exist | |
*/ | |
virtual SeekStream *OpenForRead(const URI &path, | |
bool allow_null = false) = 0; | |
}; | |
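// A minimal usage sketch (illustrative only; the directory is hypothetical):
// obtaining the filesystem singleton for a URI and listing its contents.
// \code
// dmlc::io::URI dir("./data/");
// dmlc::io::FileSystem *fs = dmlc::io::FileSystem::GetInstance(dir);
// std::vector<dmlc::io::FileInfo> files;
// fs->ListDirectory(dir, &files);
// for (size_t i = 0; i < files.size(); ++i) {
//   LOG(INFO) << files[i].path.str() << " size=" << files[i].size;
// }
// \endcode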
} // namespace io | |
} // namespace dmlc | |
#endif // DMLC_IO_FILESYS_H_ | |
//===== EXPANDED: ../dmlc-core/src/io/filesys.h ===== | |
namespace dmlc { | |
namespace io { | |
/*! \brief class to construct input split from multiple files */ | |
class InputSplitBase : public InputSplit { | |
public: | |
/*! | |
* \brief helper struct to hold chunk data | |
* with internal pointer to move along the record | |
*/ | |
struct Chunk { | |
char *begin; | |
char *end; | |
std::vector<size_t> data; | |
explicit Chunk(size_t buffer_size) | |
: begin(NULL), end(NULL), | |
data(buffer_size + 1) {} | |
// load chunk from split | |
bool Load(InputSplitBase *split, size_t buffer_size); | |
}; | |
// 2M size_t elements per chunk buffer (16 MB on 64-bit platforms)
static const size_t kBufferSize = 2UL << 20UL; | |
// destructor | |
virtual ~InputSplitBase(void); | |
// implement BeforeFirst | |
virtual void BeforeFirst(void); | |
virtual void HintChunkSize(size_t chunk_size) { | |
buffer_size_ = std::max(chunk_size / sizeof(size_t), buffer_size_); | |
} | |
virtual size_t GetTotalSize(void) { | |
return file_offset_.back(); | |
} | |
// implement next record | |
virtual bool NextRecord(Blob *out_rec) { | |
while (!ExtractNextRecord(out_rec, &tmp_chunk_)) { | |
if (!tmp_chunk_.Load(this, buffer_size_)) return false; | |
} | |
return true; | |
} | |
// implement next chunk | |
virtual bool NextChunk(Blob *out_chunk) { | |
while (!ExtractNextChunk(out_chunk, &tmp_chunk_)) { | |
if (!tmp_chunk_.Load(this, buffer_size_)) return false; | |
} | |
return true; | |
} | |
// implement ResetPartition. | |
virtual void ResetPartition(unsigned rank, unsigned nsplit); | |
/*! | |
* \brief read a chunk of data into buf | |
* the data can span multiple records, | |
* but cannot contain partial records | |
* | |
* \param buf the memory region of the buffer, | |
* should be properly aligned to 64 bits | |
* \param size the maximum size of memory, | |
* after the function returns, it stores the size of the chunk | |
* \return whether end of file was reached | |
*/ | |
bool ReadChunk(void *buf, size_t *size); | |
/*! | |
* \brief extract the next chunk of records from the chunk buffer
* \param out_rchunk the output chunk
* \param chunk the chunk information
* \return true if a non-empty chunk is extracted
* false if the chunk has already been fully consumed
*/ | |
bool ExtractNextChunk(Blob *out_rchunk, Chunk *chunk); | |
/*! | |
* \brief extract next record from the chunk | |
* \param out_rec the output record | |
* \param chunk the chunk information | |
* \return true if a non-empty record is extracted
* false if the chunk has already been fully consumed
*/ | |
virtual bool ExtractNextRecord(Blob *out_rec, Chunk *chunk) = 0; | |
protected: | |
// constructor | |
InputSplitBase() | |
: fs_(NULL), | |
align_bytes_(8), | |
tmp_chunk_(kBufferSize), | |
buffer_size_(kBufferSize) {} | |
/*! | |
* \brief initialize the base before doing anything
* \param fs the filesystem ptr
* \param uri the uri of the files
* \param align_bytes the head of the split must be a multiple of align_bytes;
* this also checks that file sizes are multiples of align_bytes
*/ | |
void Init(FileSystem *fs, | |
const char *uri, | |
size_t align_bytes); | |
// to be implemented by child class | |
/*! | |
* \brief seek to the beginning of the first record | |
* starting from the current file pointer
* \return how many bytes were skipped
*/ | |
virtual size_t SeekRecordBegin(Stream *fi) = 0; | |
/*! | |
* \brief find the last occurrence of the record header
* \param begin beginning of the buffer | |
* \param end end of the buffer | |
* \return the pointer between [begin, end] indicating the | |
* last record head | |
*/ | |
virtual const char* | |
FindLastRecordBegin(const char *begin, const char *end) = 0; | |
private: | |
/*! \brief FileSystem */ | |
FileSystem *filesys_; | |
/*! \brief information about files */ | |
std::vector<FileInfo> files_; | |
/*! \brief current input stream */ | |
SeekStream *fs_; | |
/*! \brief bytes to be aligned */ | |
size_t align_bytes_; | |
/*! \brief index of the file currently being read */
size_t file_ptr_;
/*! \brief index of the file where the end of the split lies */
size_t file_ptr_end_;
/*! \brief get the current offset */ | |
size_t offset_curr_; | |
/*! \brief beginning of offset */ | |
size_t offset_begin_; | |
/*! \brief end of the offset */ | |
size_t offset_end_; | |
/*! \brief temporary chunk */
Chunk tmp_chunk_; | |
/*! \brief buffer size */ | |
size_t buffer_size_; | |
/*! \brief byte-offset of each file */ | |
std::vector<size_t> file_offset_; | |
/*! \brief internal overflow buffer */ | |
std::string overflow_; | |
/*! \brief initialize information in files */ | |
void InitInputFileInfo(const std::string& uri); | |
  /*! \brief strip trailing occurrences of ch from the end of str */ | |
std::string StripEnd(std::string str, char ch); | |
/*! \brief same as stream.Read */ | |
size_t Read(void *ptr, size_t size); | |
}; | |
} // namespace io | |
} // namespace dmlc | |
#endif // DMLC_IO_INPUT_SPLIT_BASE_H_ | |
//===== EXPANDED: ../dmlc-core/src/io/input_split_base.h ===== | |
namespace dmlc { | |
namespace io { | |
/*! \brief class that splits the input files by lines */ | |
class LineSplitter : public InputSplitBase { | |
public: | |
LineSplitter(FileSystem *fs, | |
const char *uri, | |
unsigned rank, | |
unsigned nsplit) { | |
this->Init(fs, uri, 1); | |
this->ResetPartition(rank, nsplit); | |
} | |
virtual bool ExtractNextRecord(Blob *out_rec, Chunk *chunk); | |
protected: | |
virtual size_t SeekRecordBegin(Stream *fi); | |
virtual const char* | |
FindLastRecordBegin(const char *begin, const char *end); | |
}; | |
} // namespace io | |
} // namespace dmlc | |
#endif // DMLC_IO_LINE_SPLIT_H_ | |
//===== EXPANDED: ../dmlc-core/src/io/line_split.h ===== | |
namespace dmlc { | |
namespace io { | |
size_t LineSplitter::SeekRecordBegin(Stream *fi) { | |
char c = '\0'; | |
size_t nstep = 0; | |
  // search till first end-of-line | |
while (true) { | |
if (fi->Read(&c, sizeof(c)) == 0) return nstep; | |
nstep += 1; | |
if (c == '\n' || c == '\r') break; | |
} | |
// search until first non-endofline | |
while (true) { | |
if (fi->Read(&c, sizeof(c)) == 0) return nstep; | |
if (c != '\n' && c != '\r') break; | |
// non-end-of-line should not count | |
nstep += 1; | |
} | |
return nstep; | |
} | |
const char* LineSplitter::FindLastRecordBegin(const char *begin, | |
const char *end) { | |
CHECK(begin != end); | |
for (const char *p = end - 1; p != begin; --p) { | |
if (*p == '\n' || *p == '\r') return p + 1; | |
} | |
return begin; | |
} | |
bool LineSplitter::ExtractNextRecord(Blob *out_rec, Chunk *chunk) { | |
if (chunk->begin == chunk->end) return false; | |
char *p; | |
for (p = chunk->begin; p != chunk->end; ++p) { | |
if (*p == '\n' || *p == '\r') break; | |
} | |
for (; p != chunk->end; ++p) { | |
if (*p != '\n' && *p != '\r') break; | |
} | |
  // null-terminate the record for safety | |
if (p == chunk->end) { | |
*p = '\0'; | |
} else { | |
*(p - 1) = '\0'; | |
} | |
out_rec->dptr = chunk->begin; | |
out_rec->size = p - chunk->begin; | |
chunk->begin = p; | |
return true; | |
} | |
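// Illustrative walk-through (hypothetical data, not from dmlc-core): for a | |
// chunk holding the bytes "ab\ncd", the first call to ExtractNextRecord | |
// returns a record whose dptr points at "ab" (the newline is overwritten | |
// with '\0' so the record can be used as a C string) and advances | |
// chunk->begin to "cd"; the second call returns "cd", and a third call | |
// returns false because the chunk is exhausted. | |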
} // namespace io | |
} // namespace dmlc | |
//===== EXPANDED: ../dmlc-core/src/io/line_split.cc ===== | |
//===== EXPANDING: ../dmlc-core/src/io/recordio_split.cc ===== | |
// Copyright by Contributors | |
//===== EXPANDING: ../dmlc-core/include/dmlc/recordio.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file recordio.h | |
* \brief recordio that is able to pack binary data into a splittable | |
* format, useful to exchange data in binary serialization, | |
* such as binary raw data or protobuf | |
*/ | |
#ifndef DMLC_RECORDIO_H_ | |
#define DMLC_RECORDIO_H_ | |
namespace dmlc { | |
/*! | |
 * \brief writer of binary recordio | |
 *  binary format for recordio | |
 *  recordio format: magic lrecord data pad | |
 * | |
 *  - magic is the magic number | |
 *  - pad is simply padding space to align the record to 4 bytes | |
 *  - lrecord encodes the length and the continuation flag | |
 *     - data.length() = (lrecord & ((1U << 29U) - 1)); | |
 *     - cflag == (lrecord >> 29U) & 7; | |
 * | |
 *  cflag is used to handle the (rare) special case where the magic number | |
 *  occurs in the data sequence. | |
 * | |
 *  In that case, the data is split into multiple records at | |
 *  the cells that contain the magic number | |
 * | |
 *  (1) cflag == 0: this is a complete record; | |
 *  (2) cflag == 1: start of a multiple-rec; | |
 *      cflag == 2: middle of multiple-rec; | |
 *      cflag == 3: end of multiple-rec | |
 */ | |
class RecordIOWriter { | |
public: | |
/*! | |
* \brief magic number of recordio | |
   * note: ((kMagic >> 29U) & 7) > 3 | |
* this ensures lrec will not be kMagic | |
*/ | |
static const uint32_t kMagic = 0xced7230a; | |
/*! | |
* \brief encode the lrecord | |
* \param cflag cflag part of the lrecord | |
* \param length length part of lrecord | |
* \return the encoded data | |
*/ | |
inline static uint32_t EncodeLRec(uint32_t cflag, uint32_t length) { | |
return (cflag << 29U) | length; | |
} | |
/*! | |
* \brief decode the flag part of lrecord | |
* \param rec the lrecord | |
* \return the flag | |
*/ | |
inline static uint32_t DecodeFlag(uint32_t rec) { | |
return (rec >> 29U) & 7U; | |
} | |
/*! | |
* \brief decode the length part of lrecord | |
* \param rec the lrecord | |
* \return the length | |
*/ | |
inline static uint32_t DecodeLength(uint32_t rec) { | |
return rec & ((1U << 29U) - 1U); | |
} | |
/*! | |
* \brief constructor | |
   * \param stream the output stream to write to | |
*/ | |
explicit RecordIOWriter(Stream *stream) | |
: stream_(stream), seek_stream_(dynamic_cast<SeekStream*>(stream)), | |
except_counter_(0) { | |
CHECK(sizeof(uint32_t) == 4) << "uint32_t needs to be 4 bytes"; | |
} | |
/*! | |
* \brief write record to the stream | |
* \param buf the buffer of memory region | |
* \param size the size of record to write out | |
*/ | |
void WriteRecord(const void *buf, size_t size); | |
/*! | |
* \brief write record to the stream | |
* \param data the data to write out | |
*/ | |
inline void WriteRecord(const std::string &data) { | |
this->WriteRecord(data.c_str(), data.length()); | |
} | |
/*! | |
   * \return number of exceptions (occurrences of the magic number) | |
* during the writing process | |
*/ | |
inline size_t except_counter(void) const { | |
return except_counter_; | |
} | |
  /*! \brief tell the current position of the underlying stream */ | |
inline size_t Tell(void) { | |
CHECK(seek_stream_ != NULL) << "The input stream is not seekable"; | |
return seek_stream_->Tell(); | |
} | |
private: | |
/*! \brief output stream */ | |
Stream *stream_; | |
/*! \brief seekable stream */ | |
SeekStream *seek_stream_; | |
/*! \brief counts the number of exceptions */ | |
size_t except_counter_; | |
}; | |
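// Illustrative sketch (not part of dmlc-core): a minimal example of how the | |
// lrecord helpers above fit together; ExampleLRecRoundTrip is a hypothetical | |
// name used only for illustration. | |
inline void ExampleLRecRoundTrip() { | |
  // pack cflag = 1 (start of a multi-part record) with a 100 byte payload | |
  uint32_t lrec = RecordIOWriter::EncodeLRec(1U, 100U); | |
  CHECK_EQ(RecordIOWriter::DecodeFlag(lrec), 1U); | |
  CHECK_EQ(RecordIOWriter::DecodeLength(lrec), 100U); | |
  // on disk this record is laid out as: kMagic, lrec, 100 payload bytes, pad | |
} | |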
/*! | |
 * \brief reader of binary recordio that reads records from a stream | |
* \sa RecordIOWriter | |
*/ | |
class RecordIOReader { | |
public: | |
/*! | |
* \brief constructor | |
   * \param stream the input stream to read from | |
*/ | |
explicit RecordIOReader(Stream *stream) | |
: stream_(stream), seek_stream_(dynamic_cast<SeekStream*>(stream)), | |
end_of_stream_(false) { | |
CHECK(sizeof(uint32_t) == 4) << "uint32_t needs to be 4 bytes"; | |
} | |
/*! | |
* \brief read next complete record from stream | |
* \param out_rec used to store output record in string | |
   * \return true if the read was successful, false if end of stream was reached | |
*/ | |
bool NextRecord(std::string *out_rec); | |
/*! \brief seek to certain position of the input stream */ | |
inline void Seek(size_t pos) { | |
CHECK(seek_stream_ != NULL) << "The input stream is not seekable"; | |
seek_stream_->Seek(pos); | |
} | |
private: | |
  /*! \brief input stream */ | |
Stream *stream_; | |
SeekStream *seek_stream_; | |
/*! \brief whether we are at end of stream */ | |
bool end_of_stream_; | |
}; | |
/*! | |
* \brief reader of binary recordio from Blob returned by InputSplit | |
 * This class divides the blob into several independent parts specified by the caller, | |
 * and reads from one of them. | |
 * The part reading can be used together with InputSplit::NextChunk for | |
 * multi-threaded parsing (each thread takes one RecordIOChunkReader) | |
* | |
* \sa RecordIOWriter, InputSplit | |
*/ | |
class RecordIOChunkReader { | |
public: | |
/*! | |
* \brief constructor | |
* \param chunk source data returned by InputSplit | |
   * \param part_index which part we want to read | |
* \param num_parts number of total segments | |
*/ | |
explicit RecordIOChunkReader(InputSplit::Blob chunk, | |
unsigned part_index = 0, | |
unsigned num_parts = 1); | |
/*! | |
* \brief read next complete record from stream | |
* the blob contains the memory content | |
* NOTE: this function is not threadsafe, use one | |
* RecordIOChunkReader per thread | |
* \param out_rec used to store output blob, the header is already | |
* removed and out_rec only contains the memory content | |
   * \return true if the read was successful, false if the end was reached | |
*/ | |
bool NextRecord(InputSplit::Blob *out_rec); | |
private: | |
  /*! \brief internal temporary data */ | |
std::string temp_; | |
/*! \brief internal data pointer */ | |
char *pbegin_, *pend_; | |
}; | |
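// Illustrative sketch (not part of dmlc-core): how a worker thread might | |
// consume one chunk obtained from InputSplit::NextChunk; CountRecordsInChunk | |
// is a hypothetical helper name used only for illustration. | |
inline size_t CountRecordsInChunk(InputSplit::Blob chunk) { | |
  RecordIOChunkReader reader(chunk, 0, 1);  // read the whole chunk as one part | |
  InputSplit::Blob rec; | |
  size_t n = 0; | |
  while (reader.NextRecord(&rec)) { | |
    // rec.dptr / rec.size hold one record payload with the header stripped | |
    ++n; | |
  } | |
  return n; | |
} | |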
} // namespace dmlc | |
#endif // DMLC_RECORDIO_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/recordio.h ===== | |
//===== EXPANDING: ../dmlc-core/src/io/recordio_split.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file recordio_split.h | |
* \brief input split that splits recordio files | |
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_IO_RECORDIO_SPLIT_H_ | |
#define DMLC_IO_RECORDIO_SPLIT_H_ | |
namespace dmlc { | |
namespace io { | |
/*! \brief class that splits recordio files by records */ | |
class RecordIOSplitter : public InputSplitBase { | |
public: | |
RecordIOSplitter(FileSystem *fs, | |
const char *uri, | |
unsigned rank, | |
unsigned nsplit) { | |
this->Init(fs, uri, 4); | |
this->ResetPartition(rank, nsplit); | |
} | |
virtual bool ExtractNextRecord(Blob *out_rec, Chunk *chunk); | |
protected: | |
virtual size_t SeekRecordBegin(Stream *fi); | |
virtual const char* | |
FindLastRecordBegin(const char *begin, const char *end); | |
}; | |
} // namespace io | |
} // namespace dmlc | |
#endif // DMLC_IO_RECORDIO_SPLIT_H_ | |
//===== EXPANDED: ../dmlc-core/src/io/recordio_split.h ===== | |
namespace dmlc { | |
namespace io { | |
size_t RecordIOSplitter::SeekRecordBegin(Stream *fi) { | |
size_t nstep = 0; | |
uint32_t v, lrec; | |
while (true) { | |
if (fi->Read(&v, sizeof(v)) == 0) return nstep; | |
nstep += sizeof(v); | |
if (v == RecordIOWriter::kMagic) { | |
CHECK(fi->Read(&lrec, sizeof(lrec)) != 0) | |
<< "invalid record io format"; | |
nstep += sizeof(lrec); | |
uint32_t cflag = RecordIOWriter::DecodeFlag(lrec); | |
if (cflag == 0 || cflag == 1) break; | |
} | |
} | |
// should point at head of record | |
return nstep - 2 * sizeof(uint32_t); | |
} | |
const char* RecordIOSplitter::FindLastRecordBegin(const char *begin, | |
const char *end) { | |
CHECK_EQ((reinterpret_cast<size_t>(begin) & 3UL), 0U); | |
CHECK_EQ((reinterpret_cast<size_t>(end) & 3UL), 0U); | |
const uint32_t *pbegin = reinterpret_cast<const uint32_t *>(begin); | |
const uint32_t *p = reinterpret_cast<const uint32_t *>(end); | |
CHECK(p >= pbegin + 2); | |
for (p = p - 2; p != pbegin; --p) { | |
if (p[0] == RecordIOWriter::kMagic) { | |
uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]); | |
if (cflag == 0 || cflag == 1) { | |
return reinterpret_cast<const char*>(p); | |
} | |
} | |
} | |
return begin; | |
} | |
bool RecordIOSplitter::ExtractNextRecord(Blob *out_rec, Chunk *chunk) { | |
if (chunk->begin == chunk->end) return false; | |
CHECK(chunk->begin + 2 * sizeof(uint32_t) <= chunk->end) | |
<< "Invalid RecordIO Format"; | |
CHECK_EQ((reinterpret_cast<size_t>(chunk->begin) & 3UL), 0U); | |
CHECK_EQ((reinterpret_cast<size_t>(chunk->end) & 3UL), 0U); | |
uint32_t *p = reinterpret_cast<uint32_t *>(chunk->begin); | |
uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]); | |
uint32_t clen = RecordIOWriter::DecodeLength(p[1]); | |
// skip header | |
out_rec->dptr = chunk->begin + 2 * sizeof(uint32_t); | |
// move pbegin | |
chunk->begin += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U); | |
CHECK(chunk->begin <= chunk->end) << "Invalid RecordIO Format"; | |
out_rec->size = clen; | |
if (cflag == 0) return true; | |
const uint32_t kMagic = RecordIOWriter::kMagic; | |
// abnormal path, move data around to make a full part | |
CHECK(cflag == 1U) << "Invalid RecordIO Format"; | |
while (cflag != 3U) { | |
CHECK(chunk->begin + 2 * sizeof(uint32_t) <= chunk->end); | |
p = reinterpret_cast<uint32_t *>(chunk->begin); | |
CHECK(p[0] == RecordIOWriter::kMagic); | |
cflag = RecordIOWriter::DecodeFlag(p[1]); | |
clen = RecordIOWriter::DecodeLength(p[1]); | |
// pad kmagic in between | |
std::memcpy(reinterpret_cast<char*>(out_rec->dptr) + out_rec->size, | |
&kMagic, sizeof(kMagic)); | |
out_rec->size += sizeof(kMagic); | |
// move the rest of the blobs | |
if (clen != 0) { | |
std::memmove(reinterpret_cast<char*>(out_rec->dptr) + out_rec->size, | |
chunk->begin + 2 * sizeof(uint32_t), clen); | |
out_rec->size += clen; | |
} | |
chunk->begin += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U); | |
} | |
return true; | |
} | |
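// Worked example of the padding arithmetic above, with hypothetical sizes: | |
// a record whose payload length clen == 5 occupies 2 * sizeof(uint32_t) | |
// header bytes plus (((5 + 3U) >> 2U) << 2U) == 8 payload bytes on disk, | |
// so chunk->begin advances by 16 bytes in total to reach the next header. | |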
} // namespace io | |
} // namespace dmlc | |
//===== EXPANDED: ../dmlc-core/src/io/recordio_split.cc ===== | |
//===== EXPANDING: ../dmlc-core/src/io/input_split_base.cc ===== | |
// Copyright by Contributors | |
//===== EXPANDING: ../dmlc-core/include/dmlc/common.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file common.h | |
* \brief defines some common utility function. | |
*/ | |
#ifndef DMLC_COMMON_H_ | |
#define DMLC_COMMON_H_ | |
namespace dmlc { | |
/*! | |
* \brief Split a string by delimiter | |
 * \param s String to be split. | |
 * \param delim The delimiter. | |
 * \return a vector of the split substrings. | |
*/ | |
inline std::vector<std::string> Split(const std::string& s, char delim) { | |
std::string item; | |
std::istringstream is(s); | |
std::vector<std::string> ret; | |
while (std::getline(is, item, delim)) { | |
ret.push_back(item); | |
} | |
return ret; | |
} | |
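// Illustrative sketch (not part of dmlc-core): Split applied to a ';' separated | |
// URI list, mirroring its use in InputSplitBase::InitInputFileInfo below; | |
// ExampleSplitUriList is a hypothetical name used only for illustration. | |
inline std::vector<std::string> ExampleSplitUriList() { | |
  // yields {"data/part-0.rec", "data/part-1.rec"} | |
  return Split("data/part-0.rec;data/part-1.rec", ';'); | |
} | |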
} // namespace dmlc | |
#endif // DMLC_COMMON_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/common.h ===== | |
#if DMLC_USE_REGEX | |
#endif | |
namespace dmlc { | |
namespace io { | |
void InputSplitBase::Init(FileSystem *filesys, | |
const char *uri, | |
size_t align_bytes) { | |
this->filesys_ = filesys; | |
// initialize the path | |
this->InitInputFileInfo(uri); | |
file_offset_.resize(files_.size() + 1); | |
file_offset_[0] = 0; | |
for (size_t i = 0; i < files_.size(); ++i) { | |
file_offset_[i + 1] = file_offset_[i] + files_[i].size; | |
CHECK(files_[i].size % align_bytes == 0) | |
<< "file do not align by " << align_bytes << " bytes"; | |
} | |
this->align_bytes_ = align_bytes; | |
} | |
void InputSplitBase::ResetPartition(unsigned rank, | |
unsigned nsplit) { | |
size_t ntotal = file_offset_.back(); | |
size_t nstep = (ntotal + nsplit - 1) / nsplit; | |
  // round nstep up to a multiple of align_bytes_ | |
nstep = ((nstep + align_bytes_ - 1) / align_bytes_) * align_bytes_; | |
offset_begin_ = std::min(nstep * rank, ntotal); | |
offset_end_ = std::min(nstep * (rank + 1), ntotal); | |
offset_curr_ = offset_begin_; | |
if (offset_begin_ == offset_end_) return; | |
file_ptr_ = std::upper_bound(file_offset_.begin(), | |
file_offset_.end(), | |
offset_begin_) - file_offset_.begin() - 1; | |
file_ptr_end_ = std::upper_bound(file_offset_.begin(), | |
file_offset_.end(), | |
offset_end_) - file_offset_.begin() - 1; | |
if (fs_ != NULL) { | |
delete fs_; fs_ = NULL; | |
} | |
// find the exact ending position | |
if (offset_end_ != file_offset_[file_ptr_end_]) { | |
    CHECK(offset_end_ > file_offset_[file_ptr_end_]); | |
CHECK(file_ptr_end_ < files_.size()); | |
fs_ = filesys_->OpenForRead(files_[file_ptr_end_].path); | |
fs_->Seek(offset_end_ - file_offset_[file_ptr_end_]); | |
offset_end_ += SeekRecordBegin(fs_); | |
delete fs_; | |
} | |
fs_ = filesys_->OpenForRead(files_[file_ptr_].path); | |
if (offset_begin_ != file_offset_[file_ptr_]) { | |
fs_->Seek(offset_begin_ - file_offset_[file_ptr_]); | |
offset_begin_ += SeekRecordBegin(fs_); | |
} | |
this->BeforeFirst(); | |
} | |
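// Worked example of the partition arithmetic above, with hypothetical sizes: | |
// with files of 100, 60 and 40 bytes, file_offset_ is {0, 100, 160, 200}; | |
// for nsplit == 4 and align_bytes_ == 4, nstep is (200 + 3) / 4 == 50, | |
// rounded up to 52, so rank 0 covers [0, 52) before SeekRecordBegin moves | |
// the end forward to the next record boundary. | |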
void InputSplitBase::BeforeFirst(void) { | |
if (offset_begin_ >= offset_end_) return; | |
size_t fp = std::upper_bound(file_offset_.begin(), | |
file_offset_.end(), | |
offset_begin_) - file_offset_.begin() - 1; | |
if (file_ptr_ != fp) { | |
delete fs_; | |
file_ptr_ = fp; | |
fs_ = filesys_->OpenForRead(files_[file_ptr_].path); | |
} | |
// seek to beginning of stream | |
fs_->Seek(offset_begin_ - file_offset_[file_ptr_]); | |
offset_curr_ = offset_begin_; | |
tmp_chunk_.begin = tmp_chunk_.end = NULL; | |
// clear overflow buffer | |
overflow_.clear(); | |
} | |
InputSplitBase::~InputSplitBase(void) { | |
delete fs_; | |
  // no need to delete the filesystem, it is a singleton | |
} | |
std::string InputSplitBase::StripEnd(std::string str, char ch) { | |
while (str.length() != 0 && str[str.length() - 1] == ch) { | |
str.resize(str.length() - 1); | |
} | |
return str; | |
} | |
void InputSplitBase::InitInputFileInfo(const std::string& uri) { | |
  // split by ';' | |
const char dlm = ';'; | |
std::vector<std::string> file_list = Split(uri, dlm); | |
std::vector<URI> expanded_list; | |
  // expand each entry by matching it as a regex pattern | |
for (size_t i = 0; i < file_list.size(); ++i) { | |
URI path(file_list[i].c_str()); | |
size_t pos = path.name.rfind('/'); | |
if (pos == std::string::npos || pos + 1 == path.name.length()) { | |
expanded_list.push_back(path); | |
} else { | |
URI dir = path; | |
dir.name = path.name.substr(0, pos); | |
std::vector<FileInfo> dfiles; | |
filesys_->ListDirectory(dir, &dfiles); | |
bool exact_match = false; | |
for (size_t i = 0; i < dfiles.size(); ++i) { | |
if (StripEnd(dfiles[i].path.name, '/') == StripEnd(path.name, '/')) { | |
expanded_list.push_back(dfiles[i].path); | |
exact_match = true; | |
break; | |
} | |
} | |
#if DMLC_USE_REGEX | |
if (!exact_match) { | |
std::string spattern = path.name; | |
try { | |
std::regex pattern(spattern); | |
for (size_t i = 0; i < dfiles.size(); ++i) { | |
if (dfiles[i].type != kFile || dfiles[i].size == 0) continue; | |
std::string stripped = StripEnd(dfiles[i].path.name, '/'); | |
std::smatch base_match; | |
if (std::regex_match(stripped, base_match, pattern)) { | |
for (size_t j = 0; j < base_match.size(); ++j) { | |
if (base_match[j].str() == stripped) { | |
expanded_list.push_back(dfiles[i].path); break; | |
} | |
} | |
} | |
} | |
} catch (std::regex_error& e) { | |
LOG(FATAL) << e.what() << " bad regex " << spattern | |
<< "This could due to compiler version, g++-4.9 is needed"; | |
} | |
} | |
#endif // DMLC_USE_REGEX | |
} | |
} | |
for (size_t i = 0; i < expanded_list.size(); ++i) { | |
const URI& path = expanded_list[i]; | |
FileInfo info = filesys_->GetPathInfo(path); | |
if (info.type == kDirectory) { | |
std::vector<FileInfo> dfiles; | |
filesys_->ListDirectory(info.path, &dfiles); | |
for (size_t i = 0; i < dfiles.size(); ++i) { | |
if (dfiles[i].size != 0 && dfiles[i].type == kFile) { | |
files_.push_back(dfiles[i]); | |
} | |
} | |
} else { | |
if (info.size != 0) { | |
files_.push_back(info); | |
} | |
} | |
} | |
CHECK_NE(files_.size(), 0U) | |
<< "Cannot find any files that matches the URI patternz " << uri; | |
} | |
size_t InputSplitBase::Read(void *ptr, size_t size) { | |
if (offset_begin_ >= offset_end_) return 0; | |
if (offset_curr_ + size > offset_end_) { | |
size = offset_end_ - offset_curr_; | |
} | |
if (size == 0) return 0; | |
size_t nleft = size; | |
char *buf = reinterpret_cast<char*>(ptr); | |
while (true) { | |
size_t n = fs_->Read(buf, nleft); | |
nleft -= n; buf += n; | |
offset_curr_ += n; | |
if (nleft == 0) break; | |
if (n == 0) { | |
if (offset_curr_ != file_offset_[file_ptr_ + 1]) { | |
LOG(ERROR) << "curr=" << offset_curr_ | |
<< ",begin=" << offset_begin_ | |
<< ",end=" << offset_end_ | |
<< ",fileptr=" << file_ptr_ | |
<< ",fileoffset=" << file_offset_[file_ptr_ + 1]; | |
for (size_t i = 0; i < file_ptr_; ++i) { | |
LOG(ERROR) << "offset[" << i << "]=" << file_offset_[i]; | |
} | |
LOG(FATAL) << "file offset not calculated correctly"; | |
} | |
if (file_ptr_ + 1 >= files_.size()) break; | |
file_ptr_ += 1; | |
delete fs_; | |
fs_ = filesys_->OpenForRead(files_[file_ptr_].path); | |
} | |
} | |
return size - nleft; | |
} | |
bool InputSplitBase::ReadChunk(void *buf, size_t *size) { | |
size_t max_size = *size; | |
if (max_size <= overflow_.length()) { | |
*size = 0; return true; | |
} | |
if (overflow_.length() != 0) { | |
std::memcpy(buf, BeginPtr(overflow_), overflow_.length()); | |
} | |
size_t olen = overflow_.length(); | |
overflow_.resize(0); | |
size_t nread = this->Read(reinterpret_cast<char*>(buf) + olen, | |
max_size - olen); | |
nread += olen; | |
if (nread == 0) return false; | |
if (nread != max_size) { | |
*size = nread; | |
return true; | |
} else { | |
const char *bptr = reinterpret_cast<const char*>(buf); | |
// return the last position where a record starts | |
const char *bend = this->FindLastRecordBegin(bptr, bptr + max_size); | |
*size = bend - bptr; | |
overflow_.resize(max_size - *size); | |
if (overflow_.length() != 0) { | |
std::memcpy(BeginPtr(overflow_), bend, overflow_.length()); | |
} | |
return true; | |
} | |
} | |
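// Note: when ReadChunk fills the whole buffer, everything from the start of | |
// the last (possibly partial) record onward is stashed in overflow_ and | |
// replayed at the front of the next call, so no record is ever split across | |
// the chunks handed out to parsers. | |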
bool InputSplitBase::Chunk::Load(InputSplitBase *split, size_t buffer_size) { | |
if (buffer_size + 1 > data.size()) { | |
data.resize(buffer_size + 1); | |
} | |
while (true) { | |
    // leave one tail element as a guard for null-termination | |
size_t size = (data.size() - 1) * sizeof(size_t); | |
// set back to 0 for string safety | |
data.back() = 0; | |
if (!split->ReadChunk(BeginPtr(data), &size)) return false; | |
if (size == 0) { | |
data.resize(data.size() * 2); | |
} else { | |
begin = reinterpret_cast<char *>(BeginPtr(data)); | |
end = begin + size; | |
break; | |
} | |
} | |
return true; | |
} | |
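// Note: a zero-sized chunk from ReadChunk means the buffer could not hold | |
// even one complete record, so Load doubles the buffer and retries. | |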
bool InputSplitBase::ExtractNextChunk(Blob *out_chunk, Chunk *chunk) { | |
if (chunk->begin == chunk->end) return false; | |
out_chunk->dptr = chunk->begin; | |
out_chunk->size = chunk->end - chunk->begin; | |
chunk->begin = chunk->end; | |
return true; | |
} | |
} // namespace io | |
} // namespace dmlc | |
//===== EXPANDED: ../dmlc-core/src/io/input_split_base.cc ===== | |
//===== EXPANDING: ../dmlc-core/src/io/local_filesys.cc ===== | |
// Copyright by Contributors | |
extern "C" { | |
} | |
#ifndef _MSC_VER | |
extern "C" { | |
} | |
#else | |
#define stat _stat64 | |
#endif | |
//===== EXPANDING: ../dmlc-core/src/io/local_filesys.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file local_filesys.h | |
* \brief local access module | |
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_IO_LOCAL_FILESYS_H_ | |
#define DMLC_IO_LOCAL_FILESYS_H_ | |
namespace dmlc { | |
namespace io { | |
/*! \brief local file system */ | |
class LocalFileSystem : public FileSystem { | |
public: | |
/*! \brief destructor */ | |
virtual ~LocalFileSystem() {} | |
/*! | |
* \brief get information about a path | |
* \param path the path to the file | |
* \return the information about the file | |
*/ | |
virtual FileInfo GetPathInfo(const URI &path); | |
/*! | |
* \brief list files in a directory | |
* \param path to the file | |
* \param out_list the output information about the files | |
*/ | |
virtual void ListDirectory(const URI &path, std::vector<FileInfo> *out_list); | |
/*! | |
   * \brief open a stream; reports an error and exits if something goes wrong | |
   * NOTE: the stream can continue to work even after the filesystem is destructed | |
   * \param path path to the file | |
   * \param flag the mode flag of the open operation, e.g. "r", "w" | |
   * \param allow_null whether NULL can be returned, or directly report error | |
   * \return the created stream, can be NULL when allow_null == true and the file does not exist | |
*/ | |
virtual SeekStream *Open(const URI &path, | |
const char* const flag, | |
bool allow_null); | |
/*! | |
* \brief open a seekable stream for read | |
* \param path the path to the file | |
* \param allow_null whether NULL can be returned, or directly report error | |
   * \return the created stream, can be NULL when allow_null == true and the file does not exist | |
*/ | |
virtual SeekStream *OpenForRead(const URI &path, bool allow_null); | |
/*! | |
* \brief get a singleton of LocalFileSystem when needed | |
* \return a singleton instance | |
*/ | |
inline static LocalFileSystem *GetInstance(void) { | |
static LocalFileSystem instance; | |
return &instance; | |
} | |
private: | |
LocalFileSystem() {} | |
}; | |
} // namespace io | |
} // namespace dmlc | |
#endif // DMLC_IO_LOCAL_FILESYS_H_ | |
//===== EXPANDED: ../dmlc-core/src/io/local_filesys.h ===== | |
#if defined(__FreeBSD__) | |
#define fopen64 std::fopen | |
#endif | |
namespace dmlc { | |
namespace io { | |
/*! \brief implementation of file i/o stream */ | |
class FileStream : public SeekStream { | |
public: | |
explicit FileStream(FILE *fp, bool use_stdio) | |
: fp_(fp), use_stdio_(use_stdio) {} | |
virtual ~FileStream(void) { | |
this->Close(); | |
} | |
virtual size_t Read(void *ptr, size_t size) { | |
return std::fread(ptr, 1, size, fp_); | |
} | |
virtual void Write(const void *ptr, size_t size) { | |
CHECK(std::fwrite(ptr, 1, size, fp_) == size) | |
<< "FileStream.Write incomplete"; | |
} | |
virtual void Seek(size_t pos) { | |
CHECK(!std::fseek(fp_, static_cast<long>(pos), SEEK_SET)); // NOLINT(*) | |
} | |
virtual size_t Tell(void) { | |
return std::ftell(fp_); | |
} | |
virtual bool AtEnd(void) const { | |
return std::feof(fp_) != 0; | |
} | |
inline void Close(void) { | |
if (fp_ != NULL && !use_stdio_) { | |
std::fclose(fp_); fp_ = NULL; | |
} | |
} | |
private: | |
std::FILE *fp_; | |
bool use_stdio_; | |
}; | |
FileInfo LocalFileSystem::GetPathInfo(const URI &path) { | |
struct stat sb; | |
if (stat(path.name.c_str(), &sb) == -1) { | |
int errsv = errno; | |
LOG(FATAL) << "LocalFileSystem.GetPathInfo " << path.name | |
<< " Error:" << strerror(errsv); | |
} | |
FileInfo ret; | |
ret.path = path; | |
ret.size = sb.st_size; | |
if ((sb.st_mode & S_IFMT) == S_IFDIR) { | |
ret.type = kDirectory; | |
} else { | |
ret.type = kFile; | |
} | |
return ret; | |
} | |
void LocalFileSystem::ListDirectory(const URI &path, std::vector<FileInfo> *out_list) { | |
#ifndef _MSC_VER | |
DIR *dir = opendir(path.name.c_str()); | |
if (dir == NULL) { | |
int errsv = errno; | |
LOG(FATAL) << "LocalFileSystem.ListDirectory " << path.str() | |
<<" error: " << strerror(errsv); | |
} | |
out_list->clear(); | |
struct dirent *ent; | |
  /* list all the files and directories within the directory */ | |
while ((ent = readdir(dir)) != NULL) { | |
if (!strcmp(ent->d_name, ".")) continue; | |
if (!strcmp(ent->d_name, "..")) continue; | |
URI pp = path; | |
if (pp.name[pp.name.length() - 1] != '/') { | |
pp.name += '/'; | |
} | |
pp.name += ent->d_name; | |
out_list->push_back(GetPathInfo(pp)); | |
} | |
closedir(dir); | |
#else | |
WIN32_FIND_DATA fd; | |
std::string pattern = path.name + "/*"; | |
HANDLE handle = FindFirstFile(pattern.c_str(), &fd); | |
if (handle == INVALID_HANDLE_VALUE) { | |
int errsv = GetLastError(); | |
LOG(FATAL) << "LocalFileSystem.ListDirectory " << path.str() | |
<< " error: " << strerror(errsv); | |
} | |
do { | |
if (strcmp(fd.cFileName, ".") && strcmp(fd.cFileName, "..")) { | |
URI pp = path; | |
char clast = pp.name[pp.name.length() - 1]; | |
      if (pp.name == ".") { | |
        pp.name = fd.cFileName; | |
      } else { | |
        // append a separator only when the path does not already end with one | |
        if (clast != '/' && clast != '\\') { | |
          pp.name += '/'; | |
        } | |
        pp.name += fd.cFileName; | |
      } | |
out_list->push_back(GetPathInfo(pp)); | |
} | |
} while (FindNextFile(handle, &fd)); | |
FindClose(handle); | |
#endif | |
} | |
SeekStream *LocalFileSystem::Open(const URI &path, | |
const char* const mode, | |
bool allow_null) { | |
bool use_stdio = false; | |
FILE *fp = NULL; | |
const char *fname = path.name.c_str(); | |
using namespace std; | |
#ifndef DMLC_DISABLE_STDIN | |
if (!strcmp(fname, "stdin")) { | |
use_stdio = true; fp = stdin; | |
} | |
if (!strcmp(fname, "stdout")) { | |
use_stdio = true; fp = stdout; | |
} | |
#endif | |
if (!strncmp(fname, "file://", 7)) fname += 7; | |
if (!use_stdio) { | |
std::string flag = mode; | |
if (flag == "w") flag = "wb"; | |
if (flag == "r") flag = "rb"; | |
fp = fopen64(fname, flag.c_str()); | |
} | |
if (fp != NULL) { | |
return new FileStream(fp, use_stdio); | |
} else { | |
CHECK(allow_null) << " LocalFileSystem: fail to open \"" << path.str() << '\"'; | |
return NULL; | |
} | |
} | |
SeekStream *LocalFileSystem::OpenForRead(const URI &path, bool allow_null) { | |
return Open(path, "r", allow_null); | |
} | |
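// Illustrative usage (hypothetical, not part of dmlc-core): obtaining a | |
// stream through the singleton, mirroring how the splitters above get theirs. | |
//   SeekStream *fp = LocalFileSystem::GetInstance()->OpenForRead( | |
//       URI("file:///tmp/data.rec"), /*allow_null=*/true); | |
//   if (fp != NULL) { /* ... fp->Read(...) ... */ delete fp; } | |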
} // namespace io | |
} // namespace dmlc | |
//===== EXPANDED: ../dmlc-core/src/io/local_filesys.cc ===== | |
//===== EXPANDING: ../dmlc-core/src/data.cc ===== | |
// Copyright by Contributors | |
//===== EXPANDING: ../dmlc-core/include/dmlc/data.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file data.h | |
* \brief defines common input data structure, | |
* and interface for handling the input data | |
*/ | |
#ifndef DMLC_DATA_H_ | |
#define DMLC_DATA_H_ | |
//===== EXPANDING: ../dmlc-core/include/dmlc/registry.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file registry.h | |
* \brief Registry utility that helps to build registry singletons. | |
*/ | |
#ifndef DMLC_REGISTRY_H_ | |
#define DMLC_REGISTRY_H_ | |
//===== EXPANDING: ../dmlc-core/include/dmlc/parameter.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file parameter.h | |
* \brief Provide lightweight util to do parameter setup and checking. | |
*/ | |
#ifndef DMLC_PARAMETER_H_ | |
#define DMLC_PARAMETER_H_ | |
//===== EXPANDING: ../dmlc-core/include/dmlc/json.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file json.h | |
* \brief Lightweight JSON Reader/Writer that read save into C++ data structs. | |
* This includes STL composites and structures. | |
*/ | |
#ifndef DMLC_JSON_H_ | |
#define DMLC_JSON_H_ | |
// This code requires C++11 to compile | |
#if DMLC_USE_CXX11 | |
#if DMLC_STRICT_CXX11 && DMLC_ENABLE_RTTI | |
//===== EXPANDING: ../dmlc-core/include/dmlc/any.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file any.h | |
* \brief Container to hold any data type. | |
*/ | |
#ifndef DMLC_ANY_H_ | |
#define DMLC_ANY_H_ | |
// This code need c++11 to compile | |
namespace dmlc { | |
// forward declare any; | |
class any; | |
/*! | |
* Get a reference to content stored in the any as type T. | |
* This will cause an error if | |
* T does not match the type stored. | |
* This function is not part of std::any standard. | |
* | |
 * \param src The source any container. | |
* \return The reference of content | |
* \tparam T The type of the value to be fetched. | |
*/ | |
template<typename T> | |
inline T& get(any& src); // NOLINT(*) | |
/*! | |
* Get the const reference content stored in the any as type T. | |
* This will cause an error if | |
* T does not match the type stored. | |
* This function is not part of std::any standard. | |
* | |
 * \param src The source any container. | |
* \return The reference of content | |
* \tparam T The type of the value to be fetched. | |
*/ | |
template<typename T> | |
inline const T& get(const any& src); | |
/*! | |
* \brief An any class that is compatible to std::any in c++17. | |
* | |
* \code | |
* dmlc::any a = std::string("mydear"), b = 1; | |
* // get reference out and add it | |
* dmlc::get<int>(b) += 1; | |
* // a is now string | |
* LOG(INFO) << dmlc::get<std::string>(a); | |
* // a is now 2, the string stored will be properly destructed | |
* a = std::move(b); | |
* LOG(INFO) << dmlc::get<int>(a); | |
* \endcode | |
* \sa get | |
*/ | |
class any { | |
public: | |
/*! \brief default constructor */ | |
inline any() = default; | |
/*! | |
* \brief move constructor from another any | |
* \param other The other any to be moved | |
*/ | |
inline any(any&& other); // NOLINT(*) | |
/*! | |
* \brief copy constructor | |
* \param other The other any to be copied | |
*/ | |
inline any(const any& other); // NOLINT(*) | |
/*! | |
* \brief constructor from any types | |
* \param other The other types to be constructed into any. | |
* \tparam T The value type of other. | |
*/ | |
template<typename T> | |
inline any(T&& other); // NOLINT(*) | |
/*! \brief destructor */ | |
inline ~any(); | |
/*! | |
* \brief assign operator from other | |
* \param other The other any to be copy or moved. | |
* \return self | |
*/ | |
inline any& operator=(any&& other); | |
/*! | |
* \brief assign operator from other | |
* \param other The other any to be copy or moved. | |
* \return self | |
*/ | |
inline any& operator=(const any& other); | |
/*! | |
* \brief assign operator from any type. | |
* \param other The other any to be copy or moved. | |
* \tparam T The value type of other. | |
* \return self | |
*/ | |
template<typename T> | |
inline any& operator=(T&& other); | |
/*! | |
* \return whether the container is empty. | |
*/ | |
inline bool empty() const; | |
/*! | |
   * \brief clear the content of the container | |
*/ | |
inline void clear(); | |
/*! | |
* swap current content with other | |
* \param other The other data to be swapped. | |
*/ | |
inline void swap(any& other); // NOLINT(*) | |
/*! | |
* \return The type_info about the stored type. | |
*/ | |
inline const std::type_info& type() const; | |
private: | |
//! \cond Doxygen_Suppress | |
// declare of helper class | |
template<typename T> | |
class TypeOnHeap; | |
template<typename T> | |
class TypeOnStack; | |
template<typename T> | |
class TypeInfo; | |
// size of stack space, it takes 32 bytes for one any type. | |
static const size_t kStack = sizeof(void*) * 3; | |
static const size_t kAlign = sizeof(void*); | |
  // container uses dynamic storage only when the value does not fit in the stack space | |
union Data { | |
// stack space | |
std::aligned_storage<kStack, kAlign>::type stack; | |
// pointer to heap space | |
void* pheap; | |
}; | |
// type specific information | |
struct Type { | |
// destructor function | |
void (*destroy)(Data* data); | |
// copy constructor | |
void (*create_from_data)(Data* dst, const Data& src); | |
// the type info function | |
const std::type_info* ptype_info; | |
}; | |
  // constant to check if data can be stored on the stack | |
template<typename T> | |
struct data_on_stack { | |
static const bool value = alignof(T) <= kAlign && sizeof(T) <= kStack; | |
}; | |
// declare friend with | |
template<typename T> | |
friend T& get(any& src); // NOLINT(*) | |
template<typename T> | |
friend const T& get(const any& src); | |
// internal construct function | |
inline void construct(any&& other); | |
// internal construct function | |
inline void construct(const any& other); | |
// internal function to check if type is correct. | |
template<typename T> | |
inline void check_type() const; | |
// internal type specific information | |
const Type* type_{nullptr}; | |
// internal data | |
Data data_; | |
}; | |
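// Note: with kStack == 3 * sizeof(void*), a small trivially aligned value | |
// such as an int is kept in the in-object stack storage, while anything | |
// larger than kStack bytes (or with alignment stricter than a pointer) goes | |
// to the heap via Data::pheap; the choice is made per type, at compile time, | |
// by data_on_stack. | |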
template<typename T> | |
inline any::any(T&& other) { | |
typedef typename std::decay<T>::type DT; | |
if (std::is_same<DT, any>::value) { | |
this->construct(std::forward<T>(other)); | |
} else { | |
static_assert(std::is_copy_constructible<DT>::value, | |
"Any can only hold value that is copy constructable"); | |
type_ = TypeInfo<DT>::get_type(); | |
if (data_on_stack<DT>::value) { | |
new (&(data_.stack)) DT(std::forward<T>(other)); | |
} else { | |
data_.pheap = new DT(std::forward<T>(other)); | |
} | |
} | |
} | |
inline any::any(any&& other) { | |
this->construct(std::move(other)); | |
} | |
inline any::any(const any& other) { | |
this->construct(other); | |
} | |
inline void any::construct(any&& other) { | |
type_ = other.type_; | |
data_ = other.data_; | |
other.type_ = nullptr; | |
} | |
inline void any::construct(const any& other) { | |
type_ = other.type_; | |
if (type_ != nullptr) { | |
type_->create_from_data(&data_, other.data_); | |
} | |
} | |
inline any::~any() { | |
this->clear(); | |
} | |
inline any& any::operator=(any&& other) { | |
any(std::move(other)).swap(*this); | |
return *this; | |
} | |
inline any& any::operator=(const any& other) { | |
any(other).swap(*this); | |
return *this; | |
} | |
template<typename T> | |
inline any& any::operator=(T&& other) { | |
any(std::forward<T>(other)).swap(*this); | |
return *this; | |
} | |
inline void any::swap(any& other) { // NOLINT(*) | |
std::swap(type_, other.type_); | |
std::swap(data_, other.data_); | |
} | |
inline void any::clear() { | |
if (type_ != nullptr) { | |
if (type_->destroy != nullptr) { | |
type_->destroy(&data_); | |
} | |
type_ = nullptr; | |
} | |
} | |
inline bool any::empty() const { | |
return type_ == nullptr; | |
} | |
inline const std::type_info& any::type() const { | |
if (type_ != nullptr) { | |
return *(type_->ptype_info); | |
} else { | |
return typeid(void); | |
} | |
} | |
template<typename T> | |
inline void any::check_type() const { | |
CHECK(type_ != nullptr) | |
<< "The any container is empty" | |
<< " requested=" << typeid(T).name(); | |
CHECK(type_->ptype_info == &typeid(T)) | |
<< "The stored type mismatch" | |
<< " stored=" << type_->ptype_info->name() | |
<< " requested=" << typeid(T).name(); | |
} | |
template<typename T> | |
inline const T& get(const any& src) { | |
src.check_type<T>(); | |
return *any::TypeInfo<T>::get_ptr(&(src.data_)); | |
} | |
template<typename T> | |
inline T& get(any& src) { // NOLINT(*) | |
src.check_type<T>(); | |
return *any::TypeInfo<T>::get_ptr(&(src.data_)); | |
} | |
template<typename T> | |
class any::TypeOnHeap { | |
public: | |
inline static T* get_ptr(any::Data* data) { | |
return static_cast<T*>(data->pheap); | |
} | |
inline static const T* get_ptr(const any::Data* data) { | |
return static_cast<const T*>(data->pheap); | |
} | |
inline static void create_from_data(any::Data* dst, const any::Data& data) { | |
dst->pheap = new T(*get_ptr(&data)); | |
} | |
inline static void destroy(Data* data) { | |
delete static_cast<T*>(data->pheap); | |
} | |
}; | |
template<typename T> | |
class any::TypeOnStack { | |
public: | |
inline static T* get_ptr(any::Data* data) { | |
return reinterpret_cast<T*>(&(data->stack)); | |
} | |
inline static const T* get_ptr(const any::Data* data) { | |
return reinterpret_cast<const T*>(&(data->stack)); | |
} | |
inline static void create_from_data(any::Data* dst, const any::Data& data) { | |
new (&(dst->stack)) T(*get_ptr(&data)); | |
} | |
inline static void destroy(Data* data) { | |
T* dptr = reinterpret_cast<T*>(&(data->stack)); | |
dptr->~T(); | |
} | |
}; | |
template<typename T> | |
class any::TypeInfo | |
: public std::conditional<any::data_on_stack<T>::value, | |
any::TypeOnStack<T>, | |
any::TypeOnHeap<T> >::type { | |
public: | |
inline static const Type* get_type() { | |
static TypeInfo<T> tp; | |
return &(tp.type_); | |
} | |
private: | |
// local type | |
Type type_; | |
// constructor | |
TypeInfo() { | |
if (std::is_pod<T>::value) { | |
type_.destroy = nullptr; | |
} else { | |
type_.destroy = TypeInfo<T>::destroy; | |
} | |
type_.create_from_data = TypeInfo<T>::create_from_data; | |
type_.ptype_info = &typeid(T); | |
} | |
}; | |
//! \endcond | |
} // namespace dmlc | |
#endif // DMLC_ANY_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/any.h ===== | |
#endif // DMLC_STRICT_CXX11 | |
#endif // DMLC_USE_CXX11 | |
namespace dmlc { | |
/*! | |
* \brief Lightweight JSON Reader to read any STL compositions and structs. | |
 *  The user needs to know the schema of the input JSON. | |
* | |
*/ | |
class JSONReader { | |
public: | |
/*! | |
* \brief Constructor. | |
* \param is the input stream. | |
*/ | |
explicit JSONReader(std::istream *is) | |
: is_(is), | |
line_count_r_(0), | |
line_count_n_(0) {} | |
/*! | |
* \brief Parse next JSON string. | |
* \param out_str the output string. | |
* \throw dmlc::Error when next token is not string | |
*/ | |
inline void ReadString(std::string *out_str); | |
/*! | |
* \brief Read Number. | |
* \param out_value output value; | |
   * \throw dmlc::Error when next token is not a number of ValueType. | |
* \tparam ValueType type of the number | |
*/ | |
template<typename ValueType> | |
inline void ReadNumber(ValueType *out_value); | |
/*! | |
* \brief Begin parsing an object. | |
* \code | |
* std::string key; | |
* // value can be any type that is json serializable. | |
* std::string value; | |
* reader->BeginObject(); | |
* while (reader->NextObjectItem(&key)) { | |
   *    // do something with the key and the value | |
* reader->Read(&value); | |
* } | |
* \endcode | |
*/ | |
inline void BeginObject(); | |
/*! | |
* \brief Begin parsing an array. | |
   * \code | |
   *  // value can be any type that is json serializable. | |
   *  std::string value; | |
   *  reader->BeginArray(); | |
   *  while (reader->NextArrayItem()) { | |
   *    reader->Read(&value); | |
   *    // do something with value | |
   *  } | |
* \endcode | |
*/ | |
inline void BeginArray(); | |
/*! | |
* \brief Try to move to next object item. | |
* If this call is successful, user can proceed to call | |
* reader->Read to read in the value. | |
* \param out_key the key to the next object. | |
* \return true if the read is successful, false if we are at end of the object. | |
*/ | |
inline bool NextObjectItem(std::string *out_key); | |
/*! | |
* \brief Try to read the next element in the array. | |
* If this call is successful, user can proceed to call | |
* reader->Read to read in the value. | |
* \return true if the read is successful, false if we are at end of the array. | |
*/ | |
inline bool NextArrayItem(); | |
/*! | |
* \brief Read next ValueType. | |
* \param out_value any STL or json readable type to be read | |
* \throw dmlc::Error when the read of ValueType is not successful. | |
* \tparam ValueType the data type to be read. | |
*/ | |
template<typename ValueType> | |
inline void Read(ValueType *out_value); | |
  /*! \return a description of the current position, used for error messages */ | |
inline std::string line_info() const { | |
char temp[64]; | |
std::ostringstream os; | |
os << " Line " << std::max(line_count_r_, line_count_n_); | |
is_->getline(temp, 64); | |
os << ", around ^`" << temp << "`"; | |
return os.str(); | |
} | |
private: | |
/*! \brief internal reader stream */ | |
std::istream *is_; | |
/*! \brief "\\r" counter */ | |
size_t line_count_r_; | |
/*! \brief "\\n" counter */ | |
size_t line_count_n_; | |
/*! | |
* \brief record how many element processed in | |
* current array/object scope. | |
*/ | |
std::vector<size_t> scope_counter_; | |
/*! | |
* \brief Read next nonspace character. | |
* \return the next nonspace character. | |
*/ | |
inline int NextNonSpace(); | |
/*! | |
   * \brief Peek the next nonspace character without consuming it. | |
* \return the next nonspace character. | |
*/ | |
inline int PeekNextNonSpace(); | |
}; | |
/*! | |
 * \brief Lightweight JSON writer that can write any STL compositions. | |
*/ | |
class JSONWriter { | |
public: | |
/*! | |
* \brief Constructor. | |
* \param os the output stream. | |
*/ | |
explicit JSONWriter(std::ostream *os) | |
: os_(os) {} | |
/*! | |
   * \brief Write a string that does not contain escape characters. | |
* \param s the string to be written. | |
*/ | |
inline void WriteNoEscape(const std::string &s); | |
/*! | |
* \brief Write a string that can contain escape characters. | |
* \param s the string to be written. | |
*/ | |
inline void WriteString(const std::string &s); | |
/*! | |
   * \brief Write a numeric value. | |
* \param v the value to be written. | |
* \tparam ValueType The value type to be written. | |
*/ | |
template<typename ValueType> | |
inline void WriteNumber(const ValueType &v); | |
/*! | |
   * \brief Start writing an array. | |
   * \param multi_line whether to write the array across multiple lines. | |
* \code | |
* writer->BeginArray(); | |
* for (auto& v : vdata) { | |
* writer->WriteArrayItem(v); | |
* } | |
* writer->EndArray(); | |
* \endcode | |
*/ | |
inline void BeginArray(bool multi_line = true); | |
/*! \brief Finish writing an array. */ | |
inline void EndArray(); | |
/*! | |
   * \brief Start writing an object. | |
   * \param multi_line whether to write the object across multiple lines. | |
* \code | |
* writer->BeginObject(); | |
* for (auto& kv : vmap) { | |
* writer->WriteObjectKeyValue(kv.first, kv.second); | |
* } | |
* writer->EndObject(); | |
* \endcode | |
*/ | |
inline void BeginObject(bool multi_line = true); | |
/*! \brief Finish writing object. */ | |
inline void EndObject(); | |
/*! | |
* \brief Write key value pair in the object. | |
* \param key the key of the object. | |
* \param value the value of to be written. | |
* \tparam ValueType The value type to be written. | |
*/ | |
template<typename ValueType> | |
inline void WriteObjectKeyValue(const std::string &key, | |
const ValueType &value); | |
/*! | |
   * \brief Write the separator of the array, before writing the next element. | |
* User can proceed to call writer->Write to write next item | |
*/ | |
inline void WriteArraySeperator(); | |
/*! | |
* \brief Write value into array. | |
* \param value The value of to be written. | |
* \tparam ValueType The value type to be written. | |
*/ | |
template<typename ValueType> | |
inline void WriteArrayItem(const ValueType &value); | |
/*! | |
* \brief Write value to json. | |
* \param value any STL or json readable that can be written. | |
* \tparam ValueType the data type to be write. | |
*/ | |
template<typename ValueType> | |
inline void Write(const ValueType &value); | |
private: | |
/*! \brief Output stream */ | |
std::ostream *os_; | |
/*! | |
* \brief record how many element processed in | |
* current array/object scope. | |
*/ | |
std::vector<size_t> scope_counter_; | |
/*! \brief Record whether current is a multiline scope */ | |
std::vector<bool> scope_multi_line_; | |
/*! | |
   * \brief Write separating spaces and newlines | |
*/ | |
inline void WriteSeperator(); | |
}; | |
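/*! | |
 * Illustrative usage sketch (hypothetical, not part of dmlc-core): a round | |
 * trip through the two classes above with a std::map, relying only on the | |
 * handlers declared in this file. | |
 * \code | |
 *  std::map<std::string, int> data, copy; | |
 *  data["alpha"] = 1; | |
 *  std::ostringstream os; | |
 *  dmlc::JSONWriter writer(&os); | |
 *  writer.Write(data);            // serialize the map as a JSON object | |
 *  std::istringstream is(os.str()); | |
 *  dmlc::JSONReader reader(&is); | |
 *  reader.Read(&copy);            // copy["alpha"] == 1 afterwards | |
 * \endcode | |
 */ | |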
/*! | |
* \brief Helper class to read JSON into a class or struct object. | |
* \code | |
* struct Param { | |
* std::string name; | |
* int value; | |
* // define load function from JSON | |
* inline void Load(dmlc::JSONReader *reader) { | |
 *     dmlc::JSONObjectReadHelper helper; | |
* helper.DeclareField("name", &name); | |
* helper.DeclareField("value", &value); | |
* helper.ReadAllFields(reader); | |
* } | |
* }; | |
* \endcode | |
*/ | |
class JSONObjectReadHelper { | |
public: | |
/*! | |
* \brief Declare field of type T | |
   * \param key the key of the field. | |
* \param addr address of the data type. | |
* \tparam T the data type to be read, must be STL composition of JSON serializable. | |
*/ | |
template<typename T> | |
inline void DeclareField(const std::string &key, T *addr) { | |
DeclareFieldInternal(key, addr, false); | |
} | |
/*! | |
* \brief Declare optional field of type T | |
   * \param key the key of the field. | |
* \param addr address of the data type. | |
* \tparam T the data type to be read, must be STL composition of JSON serializable. | |
*/ | |
template<typename T> | |
inline void DeclareOptionalField(const std::string &key, T *addr) { | |
DeclareFieldInternal(key, addr, true); | |
} | |
/*! | |
* \brief Read in all the declared fields. | |
* \param reader the JSONReader to read the json. | |
*/ | |
inline void ReadAllFields(JSONReader *reader); | |
private: | |
/*! | |
* \brief Internal function to declare field. | |
   * \param key the key of the field. | |
* \param addr address of the data type. | |
* \param optional if set to true, no error will be reported if the key is not presented. | |
* \tparam T the data type to be read, must be STL composition of JSON serializable. | |
*/ | |
template<typename T> | |
inline void DeclareFieldInternal(const std::string &key, T *addr, bool optional); | |
/*! | |
* \brief The internal reader function. | |
* \param reader The reader to read. | |
* \param addr The memory address to read. | |
*/ | |
template<typename T> | |
inline static void ReaderFunction(JSONReader *reader, void *addr); | |
/*! \brief callback type to reader function */ | |
typedef void (*ReadFunction)(JSONReader *reader, void *addr); | |
/*! \brief internal data entry */ | |
struct Entry { | |
/*! \brief the reader function */ | |
ReadFunction func; | |
/*! \brief the address to read */ | |
void *addr; | |
/*! \brief whether it is optional */ | |
bool optional; | |
}; | |
/*! \brief the internal map of reader callbacks */ | |
std::map<std::string, Entry> map_; | |
}; | |
#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \ | |
static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager& \ | |
__make_AnyJSONType ## _ ## KeyName ## __ | |
/*! | |
* \def DMLC_JSON_ENABLE_ANY | |
 * \brief Macro to enable save/load JSON of dmlc::any whose actual type is Type. | |
* Any type will be saved as json array [KeyName, content] | |
* | |
* \param Type The type to be registered. | |
* \param KeyName The Type key assigned to the type, must be same during load. | |
*/ | |
#define DMLC_JSON_ENABLE_ANY(Type, KeyName) \ | |
DMLC_STR_CONCAT(DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName), __COUNTER__) = \ | |
::dmlc::json::AnyJSONManager::Global()->EnableType<Type>(#KeyName) \ | |
//! \cond Doxygen_Suppress | |
namespace json { | |
/*! | |
* \brief generic serialization handler | |
* \tparam T the type to be serialized | |
*/ | |
template<typename T> | |
struct Handler; | |
template<typename ValueType> | |
struct NumericHandler { | |
inline static void Write(JSONWriter *writer, const ValueType &value) { | |
writer->WriteNumber<ValueType>(value); | |
} | |
inline static void Read(JSONReader *reader, ValueType *value) { | |
reader->ReadNumber<ValueType>(value); | |
} | |
}; | |
template<typename ContainerType> | |
struct ArrayHandler { | |
inline static void Write(JSONWriter *writer, const ContainerType &array) { | |
typedef typename ContainerType::value_type ElemType; | |
writer->BeginArray(array.size() > 10 || !dmlc::is_pod<ElemType>::value); | |
for (typename ContainerType::const_iterator it = array.begin(); | |
it != array.end(); ++it) { | |
writer->WriteArrayItem(*it); | |
} | |
writer->EndArray(); | |
} | |
inline static void Read(JSONReader *reader, ContainerType *array) { | |
typedef typename ContainerType::value_type ElemType; | |
array->clear(); | |
reader->BeginArray(); | |
while (reader->NextArrayItem()) { | |
ElemType value; | |
Handler<ElemType>::Read(reader, &value); | |
array->insert(array->end(), value); | |
} | |
} | |
}; | |
template<typename ContainerType> | |
struct MapHandler{ | |
inline static void Write(JSONWriter *writer, const ContainerType &map) { | |
writer->BeginObject(map.size() > 1); | |
for (typename ContainerType::const_iterator it = map.begin(); it != map.end(); ++it) { | |
writer->WriteObjectKeyValue(it->first, it->second); | |
} | |
writer->EndObject(); | |
} | |
inline static void Read(JSONReader *reader, ContainerType *map) { | |
typedef typename ContainerType::mapped_type ElemType; | |
map->clear(); | |
reader->BeginObject(); | |
std::string key; | |
while (reader->NextObjectItem(&key)) { | |
ElemType value; | |
reader->Read(&value); | |
(*map)[key] = value; | |
} | |
} | |
}; | |
template<typename T> | |
struct CommonJSONSerializer { | |
inline static void Write(JSONWriter *writer, const T &value) { | |
value.Save(writer); | |
} | |
inline static void Read(JSONReader *reader, T *value) { | |
value->Load(reader); | |
} | |
}; | |
template<> | |
struct Handler<std::string> { | |
inline static void Write(JSONWriter *writer, const std::string &value) { | |
writer->WriteString(value); | |
} | |
inline static void Read(JSONReader *reader, std::string *str) { | |
reader->ReadString(str); | |
} | |
}; | |
template<typename T> | |
struct Handler<std::vector<T> > : public ArrayHandler<std::vector<T> > { | |
}; | |
template<typename K, typename V> | |
struct Handler<std::pair<K, V> > { | |
inline static void Write(JSONWriter *writer, const std::pair<K, V> &kv) { | |
writer->BeginArray(); | |
writer->WriteArrayItem(kv.first); | |
writer->WriteArrayItem(kv.second); | |
writer->EndArray(); | |
} | |
inline static void Read(JSONReader *reader, std::pair<K, V> *kv) { | |
reader->BeginArray(); | |
CHECK(reader->NextArrayItem()) | |
<< "Expect array of length 2"; | |
Handler<K>::Read(reader, &(kv->first)); | |
CHECK(reader->NextArrayItem()) | |
<< "Expect array of length 2"; | |
Handler<V>::Read(reader, &(kv->second)); | |
CHECK(!reader->NextArrayItem()) | |
<< "Expect array of length 2"; | |
} | |
}; | |
template<typename T> | |
struct Handler<std::list<T> > : public ArrayHandler<std::list<T> > { | |
}; | |
template<typename V> | |
struct Handler<std::map<std::string, V> > : public MapHandler<std::map<std::string, V> > { | |
}; | |
#if DMLC_USE_CXX11 | |
template<typename V> | |
struct Handler<std::unordered_map<std::string, V> > | |
: public MapHandler<std::unordered_map<std::string, V> > { | |
}; | |
#endif // DMLC_USE_CXX11 | |
template<typename T> | |
struct Handler { | |
inline static void Write(JSONWriter *writer, const T &data) { | |
typedef typename dmlc::IfThenElseType<dmlc::is_arithmetic<T>::value, | |
NumericHandler<T>, | |
CommonJSONSerializer<T> >::Type THandler; | |
THandler::Write(writer, data); | |
} | |
inline static void Read(JSONReader *reader, T *data) { | |
typedef typename dmlc::IfThenElseType<dmlc::is_arithmetic<T>::value, | |
NumericHandler<T>, | |
CommonJSONSerializer<T> >::Type THandler; | |
THandler::Read(reader, data); | |
} | |
}; | |
#if DMLC_STRICT_CXX11 && DMLC_ENABLE_RTTI | |
// Manager to store json serialization strategy. | |
class AnyJSONManager { | |
public: | |
template<typename T> | |
inline AnyJSONManager& EnableType(const std::string& type_name) { // NOLINT(*) | |
std::type_index tp = std::type_index(typeid(T)); | |
if (type_name_.count(tp) != 0) { | |
CHECK(type_name_.at(tp) == type_name) | |
<< "Type has already been registered as another typename " << type_name_.at(tp); | |
return *this; | |
} | |
CHECK(type_map_.count(type_name) == 0) | |
<< "Type name " << type_name << " already registered in registry"; | |
Entry e; | |
e.read = ReadAny<T>; | |
e.write = WriteAny<T>; | |
type_name_[tp] = type_name; | |
type_map_[type_name] = e; | |
return *this; | |
} | |
// return global singleton | |
inline static AnyJSONManager* Global() { | |
static AnyJSONManager inst; | |
return &inst; | |
} | |
private: | |
AnyJSONManager() {} | |
template<typename T> | |
inline static void WriteAny(JSONWriter *writer, const any &data) { | |
writer->Write(dmlc::get<T>(data)); | |
} | |
template<typename T> | |
inline static void ReadAny(JSONReader *reader, any* data) { | |
T temp; | |
reader->Read(&temp); | |
*data = std::move(temp); | |
} | |
// data entry to store vtable for any type | |
struct Entry { | |
void (*read)(JSONReader* reader, any *data); | |
void (*write)(JSONWriter* reader, const any& data); | |
}; | |
template<typename T> | |
friend struct Handler; | |
std::unordered_map<std::type_index, std::string> type_name_; | |
std::unordered_map<std::string, Entry> type_map_; | |
}; | |
template<> | |
struct Handler<any> { | |
inline static void Write(JSONWriter *writer, const any &data) { | |
std::unordered_map<std::type_index, std::string>& | |
nmap = AnyJSONManager::Global()->type_name_; | |
std::type_index id = std::type_index(data.type()); | |
auto it = nmap.find(id); | |
CHECK(it != nmap.end() && it->first == id) | |
<< "Type " << id.name() << " has not been registered via DMLC_JSON_ENABLE_ANY"; | |
std::string type_name = it->second; | |
AnyJSONManager::Entry e = AnyJSONManager::Global()->type_map_.at(type_name); | |
writer->BeginArray(false); | |
writer->WriteArrayItem(type_name); | |
writer->WriteArraySeperator(); | |
e.write(writer, data); | |
writer->EndArray(); | |
} | |
inline static void Read(JSONReader *reader, any *data) { | |
std::string type_name; | |
reader->BeginArray(); | |
CHECK(reader->NextArrayItem()) << "invalid any json format"; | |
Handler<std::string>::Read(reader, &type_name); | |
std::unordered_map<std::string, AnyJSONManager::Entry>& | |
tmap = AnyJSONManager::Global()->type_map_; | |
auto it = tmap.find(type_name); | |
CHECK(it != tmap.end() && it->first == type_name) | |
<< "Typename " << type_name << " has not been registered via DMLC_JSON_ENABLE_ANY"; | |
AnyJSONManager::Entry e = it->second; | |
CHECK(reader->NextArrayItem()) << "invalid any json format"; | |
e.read(reader, data); | |
CHECK(!reader->NextArrayItem()) << "invalid any json format"; | |
} | |
}; | |
#endif // DMLC_STRICT_CXX11 | |
} // namespace json | |
// implementations of JSONReader/Writer | |
inline int JSONReader::NextNonSpace() { | |
int ch; | |
do { | |
ch = is_->get(); | |
if (ch == '\n') ++line_count_n_; | |
if (ch == '\r') ++line_count_r_; | |
} while (isspace(ch)); | |
return ch; | |
} | |
inline int JSONReader::PeekNextNonSpace() { | |
int ch; | |
while (true) { | |
ch = is_->peek(); | |
if (ch == '\n') ++line_count_n_; | |
if (ch == '\r') ++line_count_r_; | |
if (!isspace(ch)) break; | |
is_->get(); | |
} | |
return ch; | |
} | |
inline void JSONReader::ReadString(std::string *out_str) { | |
int ch = NextNonSpace(); | |
CHECK_EQ(ch, '\"') | |
<< "Error at" << line_info() | |
<< ", Expect \'\"\' but get \'" << static_cast<char>(ch) << '\''; | |
std::ostringstream os; | |
while (true) { | |
ch = is_->get(); | |
if (ch == '\\') { | |
char sch = static_cast<char>(is_->get()); | |
switch (sch) { | |
case 'r': os << "\r"; break; | |
case 'n': os << "\n"; break; | |
case '\\': os << "\\"; break; | |
      case 't': os << "\t"; break; | |
case '\"': os << "\""; break; | |
default: LOG(FATAL) << "unknown string escape \\" << sch; | |
} | |
} else { | |
if (ch == '\"') break; | |
os << static_cast<char>(ch); | |
} | |
if (ch == EOF || ch == '\r' || ch == '\n') { | |
LOG(FATAL) | |
<< "Error at" << line_info() | |
<< ", Expect \'\"\' but reach end of line "; | |
} | |
} | |
*out_str = os.str(); | |
} | |
template<typename ValueType> | |
inline void JSONReader::ReadNumber(ValueType *out_value) { | |
*is_ >> *out_value; | |
CHECK(!is_->fail()) | |
<< "Error at" << line_info() | |
<< ", Expect number"; | |
} | |
inline void JSONReader::BeginObject() { | |
int ch = NextNonSpace(); | |
CHECK_EQ(ch, '{') | |
<< "Error at" << line_info() | |
<< ", Expect \'{\' but get \'" << static_cast<char>(ch) << '\''; | |
scope_counter_.push_back(0); | |
} | |
inline void JSONReader::BeginArray() { | |
int ch = NextNonSpace(); | |
CHECK_EQ(ch, '[') | |
<< "Error at" << line_info() | |
<< ", Expect \'{\' but get \'" << static_cast<char>(ch) << '\''; | |
scope_counter_.push_back(0); | |
} | |
inline bool JSONReader::NextObjectItem(std::string *out_key) { | |
bool next = true; | |
if (scope_counter_.back() != 0) { | |
int ch = NextNonSpace(); | |
if (ch == EOF) { | |
next = false; | |
} else if (ch == '}') { | |
next = false; | |
} else { | |
CHECK_EQ(ch, ',') | |
<< "Error at" << line_info() | |
<< ", JSON object expect \'}\' or \',\' \'" << static_cast<char>(ch) << '\''; | |
} | |
} else { | |
int ch = PeekNextNonSpace(); | |
if (ch == '}') { | |
is_->get(); | |
next = false; | |
} | |
} | |
if (!next) { | |
scope_counter_.pop_back(); | |
return false; | |
} else { | |
scope_counter_.back() += 1; | |
ReadString(out_key); | |
int ch = NextNonSpace(); | |
CHECK_EQ(ch, ':') | |
<< "Error at" << line_info() | |
<< ", Expect \':\' but get \'" << static_cast<char>(ch) << '\''; | |
return true; | |
} | |
} | |
inline bool JSONReader::NextArrayItem() { | |
bool next = true; | |
if (scope_counter_.back() != 0) { | |
int ch = NextNonSpace(); | |
if (ch == EOF) { | |
next = false; | |
} else if (ch == ']') { | |
next = false; | |
} else { | |
CHECK_EQ(ch, ',') | |
<< "Error at" << line_info() | |
<< ", JSON array expect \']\' or \',\'. Get \'" << static_cast<char>(ch) << "\' instead"; | |
} | |
} else { | |
int ch = PeekNextNonSpace(); | |
if (ch == ']') { | |
is_->get(); | |
next = false; | |
} | |
} | |
if (!next) { | |
scope_counter_.pop_back(); | |
return false; | |
} else { | |
scope_counter_.back() += 1; | |
return true; | |
} | |
} | |
template<typename ValueType> | |
inline void JSONReader::Read(ValueType *out_value) { | |
json::Handler<ValueType>::Read(this, out_value); | |
} | |
inline void JSONWriter::WriteNoEscape(const std::string &s) { | |
*os_ << '\"' << s << '\"'; | |
} | |
inline void JSONWriter::WriteString(const std::string &s) { | |
std::ostream &os = *os_; | |
os << '\"'; | |
for (size_t i = 0; i < s.length(); ++i) { | |
char ch = s[i]; | |
switch (ch) { | |
case '\r': os << "\\r"; break; | |
case '\n': os << "\\n"; break; | |
case '\\': os << "\\\\"; break; | |
case '\t': os << "\\t"; break; | |
case '\"': os << "\\\""; break; | |
default: os << ch; | |
} | |
} | |
os << '\"'; | |
} | |
template<typename ValueType> | |
inline void JSONWriter::WriteNumber(const ValueType &v) { | |
*os_ << v; | |
} | |
inline void JSONWriter::BeginArray(bool multi_line) { | |
*os_ << '['; | |
scope_multi_line_.push_back(multi_line); | |
scope_counter_.push_back(0); | |
} | |
inline void JSONWriter::EndArray() { | |
CHECK_NE(scope_multi_line_.size(), 0U); | |
CHECK_NE(scope_counter_.size(), 0U); | |
bool newline = scope_multi_line_.back(); | |
size_t nelem = scope_counter_.back(); | |
scope_multi_line_.pop_back(); | |
scope_counter_.pop_back(); | |
if (newline && nelem != 0) WriteSeperator(); | |
*os_ << ']'; | |
} | |
inline void JSONWriter::BeginObject(bool multi_line) { | |
*os_ << "{"; | |
scope_multi_line_.push_back(multi_line); | |
scope_counter_.push_back(0); | |
} | |
inline void JSONWriter::EndObject() { | |
CHECK_NE(scope_multi_line_.size(), 0U); | |
CHECK_NE(scope_counter_.size(), 0U); | |
bool newline = scope_multi_line_.back(); | |
size_t nelem = scope_counter_.back(); | |
scope_multi_line_.pop_back(); | |
scope_counter_.pop_back(); | |
if (newline && nelem != 0) WriteSeperator(); | |
*os_ << '}'; | |
} | |
template<typename ValueType> | |
inline void JSONWriter::WriteObjectKeyValue(const std::string &key, | |
const ValueType &value) { | |
std::ostream &os = *os_; | |
if (scope_counter_.back() == 0) { | |
WriteSeperator(); | |
os << '\"' << key << "\": "; | |
} else { | |
os << ", "; | |
WriteSeperator(); | |
os << '\"' << key << "\": "; | |
} | |
scope_counter_.back() += 1; | |
json::Handler<ValueType>::Write(this, value); | |
} | |
inline void JSONWriter::WriteArraySeperator() { | |
std::ostream &os = *os_; | |
if (scope_counter_.back() != 0) { | |
os << ", "; | |
} | |
scope_counter_.back() += 1; | |
WriteSeperator(); | |
} | |
template<typename ValueType> | |
inline void JSONWriter::WriteArrayItem(const ValueType &value) { | |
this->WriteArraySeperator(); | |
json::Handler<ValueType>::Write(this, value); | |
} | |
template<typename ValueType> | |
inline void JSONWriter::Write(const ValueType &value) { | |
size_t nscope = scope_multi_line_.size(); | |
json::Handler<ValueType>::Write(this, value); | |
CHECK_EQ(nscope, scope_multi_line_.size()) | |
<< "Uneven scope, did you call EndArray/EndObject after each BeginObject/Array?"; | |
} | |
inline void JSONWriter::WriteSeperator() { | |
if (scope_multi_line_.size() == 0 || scope_multi_line_.back()) { | |
*os_ << '\n' << std::string(scope_multi_line_.size() * 2, ' '); | |
} | |
} | |
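// Illustrative sketch of the output format: with multi_line == true the writer
// produces indented output such as | |
// | |
//   { | |
//     "key": value, | |
//     "other": value | |
//   } | |
// | |
// while multi_line == false keeps the whole object or array on a single line. | |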
inline void JSONObjectReadHelper::ReadAllFields(JSONReader *reader) { | |
reader->BeginObject(); | |
std::map<std::string, int> visited; | |
std::string key; | |
while (reader->NextObjectItem(&key)) { | |
if (map_.count(key) != 0) { | |
Entry e = map_[key]; | |
(*e.func)(reader, e.addr); | |
visited[key] = 0; | |
} else { | |
std::ostringstream os; | |
os << "JSONReader: Unknown field " << key << ", candidates are: \n"; | |
for (std::map<std::string, Entry>::iterator | |
it = map_.begin(); it != map_.end(); ++it) { | |
os << '\"' <<it->first << "\"\n"; | |
} | |
LOG(FATAL) << os.str(); | |
} | |
} | |
if (visited.size() != map_.size()) { | |
for (std::map<std::string, Entry>::iterator | |
it = map_.begin(); it != map_.end(); ++it) { | |
if (it->second.optional) continue; | |
CHECK_NE(visited.count(it->first), 0U) | |
<< "JSONReader: Missing field \"" << it->first << "\"\n At " | |
<< reader->line_info(); | |
} | |
} | |
} | |
template<typename T> | |
inline void JSONObjectReadHelper::ReaderFunction(JSONReader *reader, void *addr) { | |
json::Handler<T>::Read(reader, static_cast<T*>(addr)); | |
} | |
template<typename T> | |
inline void JSONObjectReadHelper:: | |
DeclareFieldInternal(const std::string &key, T *addr, bool optional) { | |
CHECK_EQ(map_.count(key), 0U) | |
<< "Adding duplicate field " << key; | |
Entry e; | |
e.func = ReaderFunction<T>; | |
e.addr = static_cast<void*>(addr); | |
e.optional = optional; | |
map_[key] = e; | |
} | |
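// Usage sketch (assuming the DeclareField/DeclareOptionalField wrappers declared
// earlier in this header, which forward to DeclareFieldInternal): | |
// | |
//   int n; std::string name; | |
//   JSONObjectReadHelper helper; | |
//   helper.DeclareField("n", &n); | |
//   helper.DeclareOptionalField("name", &name); | |
//   helper.ReadAllFields(reader);  // reads {"n": ..., "name": ...} | |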
//! \endcond | |
} // namespace dmlc | |
#endif // DMLC_JSON_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/json.h ===== | |
//===== EXPANDING: ../dmlc-core/include/dmlc/optional.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file optional.h | |
* \brief Container to hold optional data. | |
*/ | |
#ifndef DMLC_OPTIONAL_H_ | |
#define DMLC_OPTIONAL_H_ | |
namespace dmlc { | |
/*! \brief dummy type for assign null to optional */ | |
struct nullopt_t { | |
#if defined(_MSC_VER) && _MSC_VER < 1900 | |
/*! \brief dummy constructor */ | |
explicit nullopt_t(int a) {} | |
#else | |
/*! \brief dummy constructor */ | |
constexpr nullopt_t(int a) {} | |
#endif | |
}; | |
/*! Assign null to optional: optional<T> x = nullopt; */ | |
constexpr const nullopt_t nullopt = nullopt_t(0); | |
/*! | |
* \brief c++17 compatible optional class. | |
* | |
* At any time an optional<T> instance either | |
 * holds no value (string representation "None") | |
 * or holds a value of type T. | |
*/ | |
template<typename T> | |
class optional { | |
public: | |
/*! \brief construct an optional object that contains no value */ | |
optional() : is_none(true) {} | |
/*! \brief construct an optional object with value */ | |
explicit optional(const T& value) { | |
is_none = false; | |
new (&val) T(value); | |
} | |
/*! \brief construct an optional object with another optional object */ | |
optional(const optional<T>& other) { | |
is_none = other.is_none; | |
if (!is_none) { | |
new (&val) T(other.value()); | |
} | |
} | |
  /*! \brief destructor */ | |
~optional() { | |
if (!is_none) { | |
reinterpret_cast<T*>(&val)->~T(); | |
} | |
} | |
/*! \brief swap two optional */ | |
void swap(optional<T>& other) { | |
std::swap(val, other.val); | |
std::swap(is_none, other.is_none); | |
} | |
/*! \brief set this object to hold value | |
* \param value the value to hold | |
* \return return self to support chain assignment | |
*/ | |
optional<T>& operator=(const T& value) { | |
(optional<T>(value)).swap(*this); | |
return *this; | |
} | |
/*! \brief set this object to hold the same value with other | |
* \param other the other object | |
* \return return self to support chain assignment | |
*/ | |
optional<T>& operator=(const optional<T> &other) { | |
(optional<T>(other)).swap(*this); | |
return *this; | |
} | |
/*! \brief clear the value this object is holding. | |
* optional<T> x = nullopt; | |
*/ | |
optional<T>& operator=(nullopt_t) { | |
(optional<T>()).swap(*this); | |
return *this; | |
} | |
/*! \brief non-const dereference operator */ | |
T& operator*() { // NOLINT(*) | |
return *reinterpret_cast<T*>(&val); | |
} | |
/*! \brief const dereference operator */ | |
const T& operator*() const { | |
return *reinterpret_cast<const T*>(&val); | |
} | |
  /*! \brief return the held value. | |
* throws std::logic_error if holding no value | |
*/ | |
const T& value() const { | |
if (is_none) { | |
throw std::logic_error("bad optional access"); | |
} | |
return *reinterpret_cast<const T*>(&val); | |
} | |
/*! \brief whether this object is holding a value */ | |
explicit operator bool() const { return !is_none; } | |
private: | |
// whether this is none | |
bool is_none; | |
// on stack storage of value | |
typename std::aligned_storage<sizeof(T), alignof(T)>::type val; | |
}; | |
/*! \brief serialize an optional object to string. | |
* | |
* \code | |
* dmlc::optional<int> x; | |
* std::cout << x; // None | |
* x = 0; | |
* std::cout << x; // 0 | |
* \endcode | |
* | |
* \param os output stream | |
* \param t source optional<T> object | |
* \return output stream | |
*/ | |
template<typename T> | |
std::ostream &operator<<(std::ostream &os, const optional<T> &t) { | |
if (t) { | |
os << *t; | |
} else { | |
os << "None"; | |
} | |
return os; | |
} | |
/*! \brief parse a string object into optional<T> | |
* | |
* \code | |
* dmlc::optional<int> x; | |
* std::string s1 = "1"; | |
* std::istringstream is1(s1); | |
 * is1 >> x; // x == optional<int>(1) | |
* | |
* std::string s2 = "None"; | |
* std::istringstream is2(s2); | |
 * is2 >> x; // x == optional<int>() | |
* \endcode | |
* | |
* \param is input stream | |
* \param t target optional<T> object | |
* \return input stream | |
*/ | |
template<typename T> | |
std::istream &operator>>(std::istream &is, optional<T> &t) { | |
char buf[4]; | |
std::streampos origin = is.tellg(); | |
is.read(buf, 4); | |
if (is.fail() || buf[0] != 'N' || buf[1] != 'o' || | |
buf[2] != 'n' || buf[3] != 'e') { | |
is.clear(); | |
is.seekg(origin); | |
T x; | |
is >> x; | |
t = x; | |
} else { | |
t = nullopt; | |
} | |
return is; | |
} | |
/*! \brief description for optional int */ | |
DMLC_DECLARE_TYPE_NAME(optional<int>, "int or None"); | |
} // namespace dmlc | |
#endif // DMLC_OPTIONAL_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/optional.h ===== | |
namespace dmlc { | |
// this file is backward compatible with non-c++11 | |
/*! \brief Error thrown by parameter checking */ | |
struct ParamError : public dmlc::Error { | |
/*! | |
* \brief constructor | |
* \param msg error message | |
*/ | |
explicit ParamError(const std::string &msg) | |
: dmlc::Error(msg) {} | |
}; | |
/*! | |
* \brief Get environment variable with default. | |
* \param key the name of environment variable. | |
 * \param default_value the default value of the environment variable. | |
* \return The value received | |
*/ | |
template<typename ValueType> | |
inline ValueType GetEnv(const char *key, | |
ValueType default_value); | |
/*! \brief internal namespace for parameter management */ | |
namespace parameter { | |
// forward declare ParamManager | |
class ParamManager; | |
// forward declare FieldAccessEntry | |
class FieldAccessEntry; | |
// forward declare FieldEntry | |
template<typename DType> | |
class FieldEntry; | |
// forward declare ParamManagerSingleton | |
template<typename PType> | |
struct ParamManagerSingleton; | |
/*! \brief option in parameter initialization */ | |
enum ParamInitOption { | |
/*! \brief allow unknown parameters */ | |
kAllowUnknown, | |
/*! \brief need to match exact parameters */ | |
kAllMatch, | |
/*! \brief allow unmatched hidden field with format __*__ */ | |
kAllowHidden | |
}; | |
} // namespace parameter | |
/*! | |
* \brief Information about a parameter field in string representations. | |
*/ | |
struct ParamFieldInfo { | |
/*! \brief name of the field */ | |
std::string name; | |
/*! \brief type of the field in string format */ | |
std::string type; | |
/*! | |
* \brief detailed type information string | |
   *  This includes the default value, enum constraint and typename. | |
*/ | |
std::string type_info_str; | |
/*! \brief detailed description of the type */ | |
std::string description; | |
}; | |
/*! | |
 * \brief Parameter is the base type every parameter struct should inherit from | |
* The following code is a complete example to setup parameters. | |
* \code | |
* struct Param : public dmlc::Parameter<Param> { | |
* float learning_rate; | |
* int num_hidden; | |
* std::string name; | |
* // declare parameters in header file | |
* DMLC_DECLARE_PARAMETER(Param) { | |
* DMLC_DECLARE_FIELD(num_hidden).set_range(0, 1000); | |
* DMLC_DECLARE_FIELD(learning_rate).set_default(0.01f); | |
* DMLC_DECLARE_FIELD(name).set_default("hello"); | |
* } | |
* }; | |
* // register it in cc file | |
* DMLC_REGISTER_PARAMETER(Param); | |
* \endcode | |
* | |
* After that, the Param struct will get all the functions defined in Parameter. | |
* \tparam PType the type of parameter struct | |
* | |
* \sa DMLC_DECLARE_FIELD, DMLC_REGISTER_PARAMETER, DMLC_DECLARE_PARAMETER | |
*/ | |
template<typename PType> | |
struct Parameter { | |
public: | |
/*! | |
* \brief initialize the parameter by keyword arguments. | |
* This function will initialize the parameter struct, check consistency | |
* and throw error if something wrong happens. | |
* | |
* \param kwargs map of keyword arguments, or vector of pairs | |
   * \param option The option on initialization. | |
* \tparam Container container type | |
   * \throw ParamError when something goes wrong. | |
*/ | |
template<typename Container> | |
inline void Init(const Container &kwargs, | |
parameter::ParamInitOption option = parameter::kAllowHidden) { | |
PType::__MANAGER__()->RunInit(static_cast<PType*>(this), | |
kwargs.begin(), kwargs.end(), | |
NULL, | |
option); | |
} | |
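  // Usage sketch for Init (illustrative; Param and its fields refer to the | |
  // example in the class comment above): | |
  // | |
  //   Param param; | |
  //   std::map<std::string, std::string> kwargs; | |
  //   kwargs["num_hidden"] = "100"; | |
  //   kwargs["learning_rate"] = "0.1"; | |
  //   param.Init(kwargs);  // throws dmlc::ParamError on bad or missing values | |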
/*! | |
* \brief initialize the parameter by keyword arguments. | |
   *  This is the same as Init, but allows unknown arguments. | |
* | |
* \param kwargs map of keyword arguments, or vector of pairs | |
* \tparam Container container type | |
   * \throw ParamError when something goes wrong. | |
* \return vector of pairs of unknown arguments. | |
*/ | |
template<typename Container> | |
inline std::vector<std::pair<std::string, std::string> > | |
InitAllowUnknown(const Container &kwargs) { | |
std::vector<std::pair<std::string, std::string> > unknown; | |
PType::__MANAGER__()->RunInit(static_cast<PType*>(this), | |
kwargs.begin(), kwargs.end(), | |
&unknown, parameter::kAllowUnknown); | |
return unknown; | |
} | |
/*! | |
* \brief Return a dictionary representation of the parameters | |
* \return A dictionary that maps key -> value | |
*/ | |
inline std::map<std::string, std::string> __DICT__() const { | |
std::vector<std::pair<std::string, std::string> > vec | |
= PType::__MANAGER__()->GetDict(this->head()); | |
return std::map<std::string, std::string>(vec.begin(), vec.end()); | |
} | |
/*! | |
* \brief Write the parameters in JSON format. | |
* \param writer JSONWriter used for writing. | |
*/ | |
inline void Save(dmlc::JSONWriter *writer) const { | |
writer->Write(this->__DICT__()); | |
} | |
/*! | |
* \brief Load the parameters from JSON. | |
* \param reader JSONReader used for loading. | |
   * \throw ParamError when something goes wrong. | |
*/ | |
inline void Load(dmlc::JSONReader *reader) { | |
std::map<std::string, std::string> kwargs; | |
reader->Read(&kwargs); | |
this->Init(kwargs); | |
} | |
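  // Illustrative round-trip sketch for Save/Load (the stream names are only | |
  // examples): | |
  // | |
  //   std::ostringstream os; | |
  //   dmlc::JSONWriter writer(&os); | |
  //   param.Save(&writer);            // writes __DICT__() as a JSON object | |
  //   std::istringstream is(os.str()); | |
  //   dmlc::JSONReader reader(&is); | |
  //   param.Load(&reader);            // re-initializes param from that dict | |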
/*! | |
* \brief Get the fields of the parameters. | |
* \return List of ParamFieldInfo of each field. | |
*/ | |
inline static std::vector<ParamFieldInfo> __FIELDS__() { | |
return PType::__MANAGER__()->GetFieldInfo(); | |
} | |
/*! | |
* \brief Print docstring of the parameter | |
* \return the printed docstring | |
*/ | |
inline static std::string __DOC__() { | |
std::ostringstream os; | |
PType::__MANAGER__()->PrintDocString(os); | |
return os.str(); | |
} | |
protected: | |
/*! | |
   * \brief internal function to allow declaration of a parameter member | |
* \param manager the parameter manager | |
* \param key the key name of the parameter | |
* \param ref the reference to the parameter in the struct. | |
*/ | |
template<typename DType> | |
inline parameter::FieldEntry<DType>& DECLARE( | |
parameter::ParamManagerSingleton<PType> *manager, | |
const std::string &key, DType &ref) { // NOLINT(*) | |
parameter::FieldEntry<DType> *e = | |
new parameter::FieldEntry<DType>(); | |
e->Init(key, this->head(), ref); | |
manager->manager.AddEntry(key, e); | |
return *e; | |
} | |
private: | |
/*! \return Get head pointer of child structure */ | |
inline PType *head() const { | |
return static_cast<PType*>(const_cast<Parameter<PType>*>(this)); | |
} | |
}; | |
//! \cond Doxygen_Suppress | |
/*! | |
* \brief macro used to declare parameter | |
* | |
* Example: | |
* \code | |
* struct Param : public dmlc::Parameter<Param> { | |
* // declare parameters in header file | |
* DMLC_DECLARE_PARAMETER(Param) { | |
* // details of declarations | |
* } | |
* }; | |
* \endcode | |
* | |
 * This macro needs to be put in a source file so that registration only happens once. | |
* Refer to example code in Parameter for details | |
* | |
* \param PType the name of parameter struct. | |
* \sa Parameter | |
*/ | |
#define DMLC_DECLARE_PARAMETER(PType) \ | |
static ::dmlc::parameter::ParamManager *__MANAGER__(); \ | |
inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton<PType> *manager) \ | |
/*! | |
* \brief macro to declare fields | |
* \param FieldName the name of the field. | |
*/ | |
#define DMLC_DECLARE_FIELD(FieldName) this->DECLARE(manager, #FieldName, FieldName) | |
/*! | |
 * \brief macro to declare an alias of a field | |
* \param FieldName the name of the field. | |
* \param AliasName the name of the alias, must be declared after the field is declared. | |
*/ | |
#define DMLC_DECLARE_ALIAS(FieldName, AliasName) manager->manager.AddAlias(#FieldName, #AliasName) | |
/*! | |
* \brief Macro used to register parameter. | |
* | |
 * This macro needs to be put in a source file so that registration only happens once. | |
* Refer to example code in Parameter for details | |
* \param PType the type of parameter struct. | |
* \sa Parameter | |
*/ | |
#define DMLC_REGISTER_PARAMETER(PType) \ | |
::dmlc::parameter::ParamManager *PType::__MANAGER__() { \ | |
static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \ | |
return &inst.manager; \ | |
} \ | |
static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \ | |
__make__ ## PType ## ParamManager__ = \ | |
(*PType::__MANAGER__()) \ | |
//! \endcond | |
/*! | |
 * \brief internal namespace for parameter management | |
* There is no need to use it directly in normal case | |
*/ | |
namespace parameter { | |
/*! | |
* \brief FieldAccessEntry interface to help manage the parameters | |
* Each entry can be used to access one parameter in the Parameter struct. | |
* | |
 *  This is an internal interface used to manage parameters | |
*/ | |
class FieldAccessEntry { | |
public: | |
FieldAccessEntry() | |
: has_default_(false) {} | |
/*! \brief destructor */ | |
virtual ~FieldAccessEntry() {} | |
/*! | |
* \brief set the default value. | |
* \param head the pointer to the head of the struct | |
   * \throw error if no default value is present | |
*/ | |
virtual void SetDefault(void *head) const = 0; | |
/*! | |
* \brief set the parameter by string value | |
* \param head the pointer to the head of the struct | |
* \param value the value to be set | |
*/ | |
virtual void Set(void *head, const std::string &value) const = 0; | |
// check if value is OK | |
virtual void Check(void *head) const {} | |
/*! | |
* \brief get the string representation of value. | |
* \param head the pointer to the head of the struct | |
*/ | |
virtual std::string GetStringValue(void *head) const = 0; | |
/*! | |
* \brief Get field information | |
* \return the corresponding field information | |
*/ | |
virtual ParamFieldInfo GetFieldInfo() const = 0; | |
protected: | |
  /*! \brief whether this parameter has a default value */ | |
bool has_default_; | |
/*! \brief positional index of parameter in struct */ | |
size_t index_; | |
/*! \brief parameter key name */ | |
std::string key_; | |
/*! \brief parameter type */ | |
std::string type_; | |
/*! \brief description of the parameter */ | |
std::string description_; | |
/*! | |
* \brief print string representation of default value | |
   * \param os the stream to print the docstring to. | |
*/ | |
virtual void PrintDefaultValueString(std::ostream &os) const = 0; // NOLINT(*) | |
// allow ParamManager to modify self | |
friend class ParamManager; | |
}; | |
/*! | |
* \brief manager class to handle parameter structure for each type | |
 *  A manager will be created for each parameter structure. | |
*/ | |
class ParamManager { | |
public: | |
/*! \brief destructor */ | |
~ParamManager() { | |
for (size_t i = 0; i < entry_.size(); ++i) { | |
delete entry_[i]; | |
} | |
} | |
/*! | |
* \brief find the access entry by parameter key | |
* \param key the key of the parameter. | |
* \return pointer to FieldAccessEntry, NULL if nothing is found. | |
*/ | |
inline FieldAccessEntry *Find(const std::string &key) const { | |
std::map<std::string, FieldAccessEntry*>::const_iterator it = | |
entry_map_.find(key); | |
if (it == entry_map_.end()) return NULL; | |
return it->second; | |
} | |
/*! | |
* \brief set parameter by keyword arguments. | |
* \param head head to the parameter field. | |
* \param begin begin iterator of original kwargs | |
* \param end end iterator of original kwargs | |
* \param unknown_args optional, used to hold unknown arguments | |
   *          When it is specified, unknown arguments will be stored here instead of raising an error | |
* \tparam RandomAccessIterator iterator type | |
   * \throw ParamError when there is an unknown argument and unknown_args == NULL, or a required argument is missing. | |
*/ | |
template<typename RandomAccessIterator> | |
inline void RunInit(void *head, | |
RandomAccessIterator begin, | |
RandomAccessIterator end, | |
std::vector<std::pair<std::string, std::string> > *unknown_args, | |
parameter::ParamInitOption option) const { | |
std::set<FieldAccessEntry*> selected_args; | |
for (RandomAccessIterator it = begin; it != end; ++it) { | |
FieldAccessEntry *e = Find(it->first); | |
if (e != NULL) { | |
e->Set(head, it->second); | |
e->Check(head); | |
selected_args.insert(e); | |
} else { | |
if (unknown_args != NULL) { | |
unknown_args->push_back(*it); | |
} else { | |
if (option != parameter::kAllowUnknown) { | |
if (option == parameter::kAllowHidden && | |
it->first.length() > 4 && | |
it->first.find("__") == 0 && | |
it->first.rfind("__") == it->first.length()-2) { | |
continue; | |
} | |
std::ostringstream os; | |
os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n"; | |
os << "----------------\n"; | |
PrintDocString(os); | |
throw dmlc::ParamError(os.str()); | |
} | |
} | |
} | |
} | |
for (std::map<std::string, FieldAccessEntry*>::const_iterator it = entry_map_.begin(); | |
it != entry_map_.end(); ++it) { | |
if (selected_args.count(it->second) == 0) { | |
it->second->SetDefault(head); | |
} | |
} | |
} | |
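  // Note on kAllowHidden in RunInit above: keys wrapped in double underscores, | |
  // e.g. "__shape__", are silently skipped instead of being reported as unknown | |
  // arguments (the "__shape__" key here is only an illustrative example). | |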
/*! | |
* \brief internal function to add entry to manager, | |
* The manager will take ownership of the entry. | |
* \param key the key to the parameters | |
* \param e the pointer to the new entry. | |
*/ | |
inline void AddEntry(const std::string &key, FieldAccessEntry *e) { | |
e->index_ = entry_.size(); | |
// TODO(bing) better error message | |
if (entry_map_.count(key) != 0) { | |
LOG(FATAL) << "key " << key << " has already been registered in " << name_; | |
} | |
entry_.push_back(e); | |
entry_map_[key] = e; | |
} | |
/*! | |
   * \brief internal function to add an alias to an existing entry. | |
   *  The alias shares the same underlying entry as the original field. | |
   * \param field the name of the existing field. | |
   * \param alias the alias name to be added for the field. | |
*/ | |
inline void AddAlias(const std::string& field, const std::string& alias) { | |
if (entry_map_.count(field) == 0) { | |
LOG(FATAL) << "key " << field << " has not been registered in " << name_; | |
} | |
if (entry_map_.count(alias) != 0) { | |
LOG(FATAL) << "Alias " << alias << " has already been registered in " << name_; | |
} | |
entry_map_[alias] = entry_map_[field]; | |
} | |
/*! | |
* \brief set the name of parameter manager | |
* \param name the name to set | |
*/ | |
inline void set_name(const std::string &name) { | |
name_ = name; | |
} | |
/*! | |
* \brief get field information of each field. | |
* \return field information | |
*/ | |
inline std::vector<ParamFieldInfo> GetFieldInfo() const { | |
std::vector<ParamFieldInfo> ret(entry_.size()); | |
for (size_t i = 0; i < entry_.size(); ++i) { | |
ret[i] = entry_[i]->GetFieldInfo(); | |
} | |
return ret; | |
} | |
/*! | |
   * \brief Print readable docstring to ostream, add newline. | |
   * \param os the stream to print the docstring to. | |
*/ | |
inline void PrintDocString(std::ostream &os) const { // NOLINT(*) | |
for (size_t i = 0; i < entry_.size(); ++i) { | |
ParamFieldInfo info = entry_[i]->GetFieldInfo(); | |
os << info.name << " : " << info.type_info_str << '\n'; | |
if (info.description.length() != 0) { | |
os << " " << info.description << '\n'; | |
} | |
} | |
} | |
/*! | |
* \brief Get internal parameters in vector of pairs. | |
* \param head the head of the struct. | |
* \return the parameter dictionary. | |
*/ | |
inline std::vector<std::pair<std::string, std::string> > GetDict(void * head) const { | |
std::vector<std::pair<std::string, std::string> > ret; | |
for (std::map<std::string, FieldAccessEntry*>::const_iterator | |
it = entry_map_.begin(); it != entry_map_.end(); ++it) { | |
ret.push_back(std::make_pair(it->first, it->second->GetStringValue(head))); | |
} | |
return ret; | |
} | |
private: | |
/*! \brief parameter struct name */ | |
std::string name_; | |
/*! \brief positional list of entries */ | |
std::vector<FieldAccessEntry*> entry_; | |
/*! \brief map from key to entry */ | |
std::map<std::string, FieldAccessEntry*> entry_map_; | |
}; | |
//! \cond Doxygen_Suppress | |
// The following piece of code will be template heavy and less documented | |
// singleton parameter manager for certain type, used for initialization | |
template<typename PType> | |
struct ParamManagerSingleton { | |
ParamManager manager; | |
explicit ParamManagerSingleton(const std::string ¶m_name) { | |
PType param; | |
param.__DECLARE__(this); | |
manager.set_name(param_name); | |
} | |
}; | |
// Base class of FieldEntry | |
// implement set_default | |
template<typename TEntry, typename DType> | |
class FieldEntryBase : public FieldAccessEntry { | |
public: | |
// entry type | |
typedef TEntry EntryType; | |
// implement set value | |
virtual void Set(void *head, const std::string &value) const { | |
std::istringstream is(value); | |
is >> this->Get(head); | |
if (!is.fail()) { | |
while (!is.eof()) { | |
int ch = is.get(); | |
if (ch == EOF) { | |
is.clear(); break; | |
} | |
if (!isspace(ch)) { | |
is.setstate(std::ios::failbit); break; | |
} | |
} | |
} | |
if (is.fail()) { | |
std::ostringstream os; | |
os << "Invalid Parameter format for " << key_ | |
<< " expect " << type_ << " but value=\'" << value<< '\''; | |
throw dmlc::ParamError(os.str()); | |
} | |
} | |
virtual std::string GetStringValue(void *head) const { | |
std::ostringstream os; | |
PrintValue(os, this->Get(head)); | |
return os.str(); | |
} | |
virtual ParamFieldInfo GetFieldInfo() const { | |
ParamFieldInfo info; | |
std::ostringstream os; | |
info.name = key_; | |
info.type = type_; | |
os << type_; | |
if (has_default_) { | |
os << ',' << " optional, default="; | |
PrintDefaultValueString(os); | |
} else { | |
os << ", required"; | |
} | |
info.type_info_str = os.str(); | |
info.description = description_; | |
return info; | |
} | |
// implement set head to default value | |
virtual void SetDefault(void *head) const { | |
if (!has_default_) { | |
std::ostringstream os; | |
os << "Required parameter " << key_ | |
<< " of " << type_ << " is not presented"; | |
throw dmlc::ParamError(os.str()); | |
} else { | |
this->Get(head) = default_value_; | |
} | |
} | |
// return reference of self as derived type | |
inline TEntry &self() { | |
return *(static_cast<TEntry*>(this)); | |
} | |
// implement set_default | |
inline TEntry &set_default(const DType &default_value) { | |
default_value_ = default_value; | |
has_default_ = true; | |
// return self to allow chaining | |
return this->self(); | |
} | |
// implement describe | |
inline TEntry &describe(const std::string &description) { | |
description_ = description; | |
// return self to allow chaining | |
return this->self(); | |
} | |
// initialization function | |
inline void Init(const std::string &key, | |
void *head, DType &ref) { // NOLINT(*) | |
this->key_ = key; | |
if (this->type_.length() == 0) { | |
this->type_ = dmlc::type_name<DType>(); | |
} | |
this->offset_ = ((char*)&ref) - ((char*)head); // NOLINT(*) | |
} | |
protected: | |
// print the value | |
virtual void PrintValue(std::ostream &os, DType value) const { // NOLINT(*) | |
os << value; | |
} | |
virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) | |
PrintValue(os, default_value_); | |
} | |
// get the internal representation of parameter | |
// for example if this entry corresponds field param.learning_rate | |
// then Get(¶m) will return reference to param.learning_rate | |
inline DType &Get(void *head) const { | |
return *(DType*)((char*)(head) + offset_); // NOLINT(*) | |
} | |
// internal offset of the field | |
ptrdiff_t offset_; | |
// default value of field | |
DType default_value_; | |
}; | |
// parameter base for numeric types that have range | |
template<typename TEntry, typename DType> | |
class FieldEntryNumeric | |
: public FieldEntryBase<TEntry, DType> { | |
public: | |
FieldEntryNumeric() | |
: has_begin_(false), has_end_(false) {} | |
// implement set_range | |
virtual TEntry &set_range(DType begin, DType end) { | |
begin_ = begin; end_ = end; | |
has_begin_ = true; has_end_ = true; | |
return this->self(); | |
} | |
// implement set_range | |
virtual TEntry &set_lower_bound(DType begin) { | |
begin_ = begin; has_begin_ = true; | |
return this->self(); | |
} | |
// consistency check for numeric ranges | |
virtual void Check(void *head) const { | |
FieldEntryBase<TEntry, DType>::Check(head); | |
DType v = this->Get(head); | |
if (has_begin_ && has_end_) { | |
if (v < begin_ || v > end_) { | |
std::ostringstream os; | |
os << "value " << v << " for Parameter " << this->key_ | |
<< " exceed bound [" << begin_ << ',' << end_ <<']'; | |
throw dmlc::ParamError(os.str()); | |
} | |
} else if (has_begin_ && v < begin_) { | |
std::ostringstream os; | |
os << "value " << v << " for Parameter " << this->key_ | |
<< " should be greater equal to " << begin_; | |
throw dmlc::ParamError(os.str()); | |
} else if (has_end_ && v > end_) { | |
std::ostringstream os; | |
os << "value " << v << " for Parameter " << this->key_ | |
<< " should be smaller equal to " << end_; | |
throw dmlc::ParamError(os.str()); | |
} | |
} | |
protected: | |
  // whether it has begin and end bounds | |
bool has_begin_, has_end_; | |
// data bound | |
DType begin_, end_; | |
}; | |
/*! | |
* \brief FieldEntry defines parsing and checking behavior of DType. | |
 *  This class can be specialized to implement specific behaviors for particular types. | |
* \tparam DType the data type of the entry. | |
*/ | |
template<typename DType> | |
class FieldEntry : | |
public IfThenElseType<dmlc::is_arithmetic<DType>::value, | |
FieldEntryNumeric<FieldEntry<DType>, DType>, | |
FieldEntryBase<FieldEntry<DType>, DType> >::Type { | |
}; | |
// specialize define for int(enum) | |
template<> | |
class FieldEntry<int> | |
: public FieldEntryNumeric<FieldEntry<int>, int> { | |
public: | |
// construct | |
FieldEntry<int>() : is_enum_(false) {} | |
// parent | |
typedef FieldEntryNumeric<FieldEntry<int>, int> Parent; | |
// override set | |
virtual void Set(void *head, const std::string &value) const { | |
if (is_enum_) { | |
std::map<std::string, int>::const_iterator it = enum_map_.find(value); | |
std::ostringstream os; | |
if (it == enum_map_.end()) { | |
os << "Invalid Input: \'" << value; | |
os << "\', valid values are: "; | |
PrintEnums(os); | |
throw dmlc::ParamError(os.str()); | |
} else { | |
os << it->second; | |
Parent::Set(head, os.str()); | |
} | |
} else { | |
Parent::Set(head, value); | |
} | |
} | |
virtual ParamFieldInfo GetFieldInfo() const { | |
if (is_enum_) { | |
ParamFieldInfo info; | |
std::ostringstream os; | |
info.name = key_; | |
info.type = type_; | |
PrintEnums(os); | |
if (has_default_) { | |
os << ',' << "optional, default="; | |
PrintDefaultValueString(os); | |
} else { | |
os << ", required"; | |
} | |
info.type_info_str = os.str(); | |
info.description = description_; | |
return info; | |
} else { | |
return Parent::GetFieldInfo(); | |
} | |
} | |
// add enum | |
inline FieldEntry<int> &add_enum(const std::string &key, int value) { | |
if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \ | |
enum_back_map_.count(value) != 0) { | |
std::ostringstream os; | |
os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n"; | |
os << "Enums: "; | |
for (std::map<std::string, int>::const_iterator it = enum_map_.begin(); | |
it != enum_map_.end(); ++it) { | |
os << "(" << it->first << ": " << it->second << "), "; | |
} | |
throw dmlc::ParamError(os.str()); | |
} | |
enum_map_[key] = value; | |
enum_back_map_[value] = key; | |
is_enum_ = true; | |
return this->self(); | |
} | |
protected: | |
// enum flag | |
bool is_enum_; | |
// enum map | |
std::map<std::string, int> enum_map_; | |
// enum map | |
std::map<int, std::string> enum_back_map_; | |
// override print behavior | |
virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) | |
os << '\''; | |
PrintValue(os, default_value_); | |
os << '\''; | |
} | |
// override print default | |
virtual void PrintValue(std::ostream &os, int value) const { // NOLINT(*) | |
if (is_enum_) { | |
CHECK_NE(enum_back_map_.count(value), 0U) | |
<< "Value not found in enum declared"; | |
os << enum_back_map_.at(value); | |
} else { | |
os << value; | |
} | |
} | |
private: | |
inline void PrintEnums(std::ostream &os) const { // NOLINT(*) | |
os << '{'; | |
for (std::map<std::string, int>::const_iterator | |
it = enum_map_.begin(); it != enum_map_.end(); ++it) { | |
if (it != enum_map_.begin()) { | |
os << ", "; | |
} | |
os << "\'" << it->first << '\''; | |
} | |
os << '}'; | |
} | |
}; | |
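// Usage sketch for the enum support above (the struct and field names are | |
// illustrative examples only): | |
// | |
//   struct ActParam : public dmlc::Parameter<ActParam> { | |
//     int act_type; | |
//     DMLC_DECLARE_PARAMETER(ActParam) { | |
//       DMLC_DECLARE_FIELD(act_type) | |
//         .add_enum("relu", 0) | |
//         .add_enum("tanh", 1) | |
//         .set_default(0); | |
//     } | |
//   }; | |
//   // Init with {"act_type": "tanh"} stores 1; its string form is "tanh". | |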
// specialize define for optional<int>(enum) | |
template<> | |
class FieldEntry<optional<int> > | |
: public FieldEntryBase<FieldEntry<optional<int> >, optional<int> > { | |
public: | |
// construct | |
FieldEntry<optional<int> >() : is_enum_(false) {} | |
// parent | |
typedef FieldEntryBase<FieldEntry<optional<int> >, optional<int> > Parent; | |
// override set | |
virtual void Set(void *head, const std::string &value) const { | |
if (is_enum_ && value != "None") { | |
std::map<std::string, int>::const_iterator it = enum_map_.find(value); | |
std::ostringstream os; | |
if (it == enum_map_.end()) { | |
os << "Invalid Input: \'" << value; | |
os << "\', valid values are: "; | |
PrintEnums(os); | |
throw dmlc::ParamError(os.str()); | |
} else { | |
os << it->second; | |
Parent::Set(head, os.str()); | |
} | |
} else { | |
Parent::Set(head, value); | |
} | |
} | |
virtual ParamFieldInfo GetFieldInfo() const { | |
if (is_enum_) { | |
ParamFieldInfo info; | |
std::ostringstream os; | |
info.name = key_; | |
info.type = type_; | |
PrintEnums(os); | |
if (has_default_) { | |
os << ',' << "optional, default="; | |
PrintDefaultValueString(os); | |
} else { | |
os << ", required"; | |
} | |
info.type_info_str = os.str(); | |
info.description = description_; | |
return info; | |
} else { | |
return Parent::GetFieldInfo(); | |
} | |
} | |
// add enum | |
inline FieldEntry<optional<int> > &add_enum(const std::string &key, int value) { | |
CHECK_NE(key, "None") << "None is reserved for empty optional<int>"; | |
if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \ | |
enum_back_map_.count(value) != 0) { | |
std::ostringstream os; | |
os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n"; | |
os << "Enums: "; | |
for (std::map<std::string, int>::const_iterator it = enum_map_.begin(); | |
it != enum_map_.end(); ++it) { | |
os << "(" << it->first << ": " << it->second << "), "; | |
} | |
throw dmlc::ParamError(os.str()); | |
} | |
enum_map_[key] = value; | |
enum_back_map_[value] = key; | |
is_enum_ = true; | |
return this->self(); | |
} | |
protected: | |
// enum flag | |
bool is_enum_; | |
// enum map | |
std::map<std::string, int> enum_map_; | |
// enum map | |
std::map<int, std::string> enum_back_map_; | |
// override print behavior | |
virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) | |
os << '\''; | |
PrintValue(os, default_value_); | |
os << '\''; | |
} | |
// override print default | |
virtual void PrintValue(std::ostream &os, optional<int> value) const { // NOLINT(*) | |
if (is_enum_) { | |
if (!value) { | |
os << "None"; | |
} else { | |
CHECK_NE(enum_back_map_.count(value.value()), 0U) | |
<< "Value not found in enum declared"; | |
os << enum_back_map_.at(value.value()); | |
} | |
} else { | |
os << value; | |
} | |
} | |
private: | |
inline void PrintEnums(std::ostream &os) const { // NOLINT(*) | |
os << "{None"; | |
for (std::map<std::string, int>::const_iterator | |
it = enum_map_.begin(); it != enum_map_.end(); ++it) { | |
os << ", "; | |
os << "\'" << it->first << '\''; | |
} | |
os << '}'; | |
} | |
}; | |
// specialize define for string | |
template<> | |
class FieldEntry<std::string> | |
: public FieldEntryBase<FieldEntry<std::string>, std::string> { | |
public: | |
// parent class | |
typedef FieldEntryBase<FieldEntry<std::string>, std::string> Parent; | |
// override set | |
virtual void Set(void *head, const std::string &value) const { | |
this->Get(head) = value; | |
} | |
// override print default | |
virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) | |
os << '\'' << default_value_ << '\''; | |
} | |
}; | |
// specialize define for bool | |
template<> | |
class FieldEntry<bool> | |
: public FieldEntryBase<FieldEntry<bool>, bool> { | |
public: | |
// parent class | |
typedef FieldEntryBase<FieldEntry<bool>, bool> Parent; | |
// override set | |
virtual void Set(void *head, const std::string &value) const { | |
std::string lower_case; lower_case.resize(value.length()); | |
std::transform(value.begin(), value.end(), lower_case.begin(), ::tolower); | |
bool &ref = this->Get(head); | |
if (lower_case == "true") { | |
ref = true; | |
} else if (lower_case == "false") { | |
ref = false; | |
} else if (lower_case == "1") { | |
ref = true; | |
} else if (lower_case == "0") { | |
ref = false; | |
} else { | |
std::ostringstream os; | |
os << "Invalid Parameter format for " << key_ | |
<< " expect " << type_ << " but value=\'" << value<< '\''; | |
throw dmlc::ParamError(os.str()); | |
} | |
} | |
protected: | |
// print default string | |
virtual void PrintValue(std::ostream &os, bool value) const { // NOLINT(*) | |
if (value) { | |
os << "True"; | |
} else { | |
os << "False"; | |
} | |
} | |
}; | |
} // namespace parameter | |
//! \endcond | |
// implement GetEnv | |
template<typename ValueType> | |
inline ValueType GetEnv(const char *key, | |
ValueType default_value) { | |
const char *val = getenv(key); | |
if (val == NULL) return default_value; | |
ValueType ret; | |
parameter::FieldEntry<ValueType> e; | |
e.Init(key, &ret, ret); | |
e.Set(&ret, val); | |
return ret; | |
} | |
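// Usage sketch for GetEnv (the variable name is a hypothetical example): | |
// | |
//   int nthreads = dmlc::GetEnv("MY_NUM_THREADS", 4); | |
//   // returns 4 unless the environment variable MY_NUM_THREADS parses as int | |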
} // namespace dmlc | |
#endif // DMLC_PARAMETER_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/parameter.h ===== | |
namespace dmlc { | |
/*! | |
* \brief Registry class. | |
* Registry can be used to register global singletons. | |
 *  The most common use case is factory functions. | |
* | |
* \tparam EntryType Type of Registry entries, | |
 *                   EntryType needs to contain a name field. | |
*/ | |
template<typename EntryType> | |
class Registry { | |
public: | |
  /*! \return list of entries in the registry (excluding aliases) */ | |
inline static const std::vector<const EntryType*>& List() { | |
return Get()->const_list_; | |
} | |
  /*! \return list of all names registered in the registry, including aliases */ | |
inline static std::vector<std::string> ListAllNames() { | |
const std::map<std::string, EntryType*> &fmap = Get()->fmap_; | |
typename std::map<std::string, EntryType*>::const_iterator p; | |
std::vector<std::string> names; | |
for (p = fmap.begin(); p !=fmap.end(); ++p) { | |
names.push_back(p->first); | |
} | |
return names; | |
} | |
/*! | |
* \brief Find the entry with corresponding name. | |
* \param name name of the function | |
* \return the corresponding function, can be NULL | |
*/ | |
inline static const EntryType *Find(const std::string &name) { | |
const std::map<std::string, EntryType*> &fmap = Get()->fmap_; | |
typename std::map<std::string, EntryType*>::const_iterator p = fmap.find(name); | |
if (p != fmap.end()) { | |
return p->second; | |
} else { | |
return NULL; | |
} | |
} | |
/*! | |
* \brief Add alias to the key_name | |
* \param key_name The original entry key | |
* \param alias The alias key. | |
*/ | |
inline void AddAlias(const std::string& key_name, | |
const std::string& alias) { | |
EntryType* e = fmap_.at(key_name); | |
if (fmap_.count(alias)) { | |
CHECK_EQ(e, fmap_.at(alias)) | |
<< "Entry " << e->name << " already registered under different entry"; | |
} else { | |
fmap_[alias] = e; | |
} | |
} | |
/*! | |
* \brief Internal function to register a name function under name. | |
* \param name name of the function | |
* \return ref to the registered entry, used to set properties | |
*/ | |
inline EntryType &__REGISTER__(const std::string& name) { | |
CHECK_EQ(fmap_.count(name), 0U) | |
<< name << " already registered"; | |
EntryType *e = new EntryType(); | |
e->name = name; | |
fmap_[name] = e; | |
const_list_.push_back(e); | |
entry_list_.push_back(e); | |
return *e; | |
} | |
/*! | |
* \brief Internal function to either register or get registered entry | |
* \param name name of the function | |
* \return ref to the registered entry, used to set properties | |
*/ | |
inline EntryType &__REGISTER_OR_GET__(const std::string& name) { | |
if (fmap_.count(name) == 0) { | |
return __REGISTER__(name); | |
} else { | |
return *fmap_.at(name); | |
} | |
} | |
/*! | |
* \brief get a singleton of the Registry. | |
* This function can be defined by DMLC_ENABLE_REGISTRY. | |
* \return get a singleton | |
*/ | |
static Registry *Get(); | |
private: | |
/*! \brief list of entry types */ | |
std::vector<EntryType*> entry_list_; | |
/*! \brief list of entry types */ | |
std::vector<const EntryType*> const_list_; | |
/*! \brief map of name->function */ | |
std::map<std::string, EntryType*> fmap_; | |
/*! \brief constructor */ | |
Registry() {} | |
/*! \brief destructor */ | |
~Registry() { | |
for (size_t i = 0; i < entry_list_.size(); ++i) { | |
delete entry_list_[i]; | |
} | |
} | |
}; | |
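// Lookup sketch (illustrative; TreeFactory refers to the example in the | |
// FunctionRegEntryBase documentation below): | |
// | |
//   const TreeFactory* e = dmlc::Registry<TreeFactory>::Find("BinaryTree"); | |
//   if (e != NULL) { | |
//     Tree* tree = e->body();  // call the registered factory function | |
//   } | |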
/*! | |
* \brief Common base class for function registry. | |
* | |
* \code | |
* // This example demonstrates how to use Registry to create a factory of trees. | |
* struct TreeFactory : | |
* public FunctionRegEntryBase<TreeFactory, std::function<Tree*()> > { | |
* }; | |
* | |
 * // in an independent cc file | |
* namespace dmlc { | |
* DMLC_REGISTRY_ENABLE(TreeFactory); | |
* } | |
* // register binary tree constructor into the registry. | |
* DMLC_REGISTRY_REGISTER(TreeFactory, TreeFactory, BinaryTree) | |
* .describe("Constructor of BinaryTree") | |
* .set_body([]() { return new BinaryTree(); }); | |
* \endcode | |
* | |
 * \tparam EntryType The type of subclass that inherits from the base. | |
 * \tparam FunctionType The function type this registry is registered with. | |
*/ | |
template<typename EntryType, typename FunctionType> | |
class FunctionRegEntryBase { | |
public: | |
/*! \brief name of the entry */ | |
std::string name; | |
/*! \brief description of the entry */ | |
std::string description; | |
/*! \brief additional arguments to the factory function */ | |
std::vector<ParamFieldInfo> arguments; | |
/*! \brief Function body to create ProductType */ | |
FunctionType body; | |
/*! \brief Return type of the function */ | |
std::string return_type; | |
/*! | |
* \brief Set the function body. | |
* \param body Function body to set. | |
* \return reference to self. | |
*/ | |
inline EntryType &set_body(FunctionType body) { | |
this->body = body; | |
return this->self(); | |
} | |
/*! | |
* \brief Describe the function. | |
* \param description The description of the factory function. | |
* \return reference to self. | |
*/ | |
inline EntryType &describe(const std::string &description) { | |
this->description = description; | |
return this->self(); | |
} | |
/*! | |
* \brief Add argument information to the function. | |
* \param name Name of the argument. | |
* \param type Type of the argument. | |
* \param description Description of the argument. | |
* \return reference to self. | |
*/ | |
inline EntryType &add_argument(const std::string &name, | |
const std::string &type, | |
const std::string &description) { | |
ParamFieldInfo info; | |
info.name = name; | |
info.type = type; | |
info.type_info_str = info.type; | |
info.description = description; | |
arguments.push_back(info); | |
return this->self(); | |
} | |
/*! | |
 * \brief Append a list of arguments to the end. | |
* \param args Additional list of arguments. | |
* \return reference to self. | |
*/ | |
inline EntryType &add_arguments(const std::vector<ParamFieldInfo> &args) { | |
arguments.insert(arguments.end(), args.begin(), args.end()); | |
return this->self(); | |
} | |
/*! | |
* \brief Set the return type. | |
* \param type Return type of the function, could be Symbol or Symbol[] | |
* \return reference to self. | |
*/ | |
inline EntryType &set_return_type(const std::string &type) { | |
return_type = type; | |
return this->self(); | |
} | |
protected: | |
/*! | |
* \return reference of self as derived type | |
*/ | |
inline EntryType &self() { | |
return *(static_cast<EntryType*>(this)); | |
} | |
}; | |
/*! | |
* \def DMLC_REGISTRY_ENABLE | |
* \brief Macro to enable the registry of EntryType. | |
* This macro must be used under namespace dmlc, and only used once in cc file. | |
* \param EntryType Type of registry entry | |
*/ | |
#define DMLC_REGISTRY_ENABLE(EntryType) \ | |
template<> \ | |
Registry<EntryType > *Registry<EntryType >::Get() { \ | |
static Registry<EntryType > inst; \ | |
return &inst; \ | |
} \ | |
/*! | |
* \brief Generic macro to register an EntryType | |
* There is a complete example in FactoryRegistryEntryBase. | |
* | |
* \param EntryType The type of registry entry. | |
 * \param EntryTypeName The typename of EntryType, must not contain namespace :: . | |
* \param Name The name to be registered. | |
* \sa FactoryRegistryEntryBase | |
*/ | |
#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \ | |
static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \ | |
::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \ | |
/*! | |
 * \brief (Optional) Declare a file tag for the current file that contains object registrations. | |
 * | |
 * This will declare a dummy function that will be called by the registering file to | |
 * incur a link dependency. | |
 * | |
 * \param UniqueTag The unique tag used to represent the file. | |
* \sa DMLC_REGISTRY_LINK_TAG | |
*/ | |
#define DMLC_REGISTRY_FILE_TAG(UniqueTag) \ | |
int __dmlc_registry_file_tag_ ## UniqueTag ## __() { return 0; } | |
/*! | |
* \brief (Optional) Force link to all the objects registered in file tag. | |
* | |
* This macro must be used in the same file as DMLC_REGISTRY_ENABLE and | |
* in the same namespace as DMLC_REGISTRY_FILE_TAG | |
* | |
* DMLC_REGISTRY_FILE_TAG and DMLC_REGISTRY_LINK_TAG are optional macros for registration. | |
 * They are used to enforce linking of certain files during static linking. | |
 * | |
 * This is mainly used to solve problems when statically linking a library which contains backward registration. | |
 * Specifically, this prevents the objects in these file tags from being ignored by the compiler. | |
 * | |
 * For dynamic linking, this problem won't occur as everything is loaded by default. | |
 * | |
 * Use of this is optional as it will create an error when a file tag does not exist. | |
 * An alternative solution is to always ask the user to enable --whole-archive during static linking. | |
* | |
 * \code | |
* // in file objective_registry.cc | |
* DMLC_REGISTRY_ENABLE(MyObjective); | |
* DMLC_REGISTRY_LINK_TAG(regression_op); | |
* DMLC_REGISTRY_LINK_TAG(rank_op); | |
* | |
* // in file regression_op.cc | |
* // declare tag of this file. | |
* DMLC_REGISTRY_FILE_TAG(regression_op); | |
* DMLC_REGISTRY_REGISTER(MyObjective, logistic_reg, logistic_reg); | |
* // ... | |
* | |
* // in file rank_op.cc | |
* // declare tag of this file. | |
* DMLC_REGISTRY_FILE_TAG(rank_op); | |
* DMLC_REGISTRY_REGISTER(MyObjective, pairwiserank, pairwiserank); | |
* | |
* \endcode | |
* | |
 * \param UniqueTag The unique tag used to represent the file. | |
* \sa DMLC_REGISTRY_ENABLE, DMLC_REGISTRY_FILE_TAG | |
*/ | |
#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \ | |
int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \ | |
static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \ | |
__dmlc_registry_file_tag_ ## UniqueTag ## __(); | |
} // namespace dmlc | |
#endif // DMLC_REGISTRY_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/registry.h ===== | |
namespace dmlc { | |
/*! | |
 * \brief this defines the floating point type | |
* that will be used to store feature values | |
*/ | |
typedef float real_t; | |
/*! | |
* \brief this defines the unsigned integer type | |
* that can normally be used to store feature index | |
*/ | |
typedef unsigned index_t; | |
// This file describes common data structures that can be used | |
// for large-scale machine learning. This may not be a complete list, | |
// but we will keep the most common and useful ones, and keep adding new ones. | |
/*! | |
* \brief data iterator interface | |
* this is not a C++ style iterator, but nice for data pulling:) | |
* This interface is used to pull in the data | |
* The system can do some useful tricks for you like pre-fetching | |
* from disk and pre-computation. | |
* | |
* Usage example: | |
* \code | |
* | |
* itr->BeforeFirst(); | |
* while (itr->Next()) { | |
* const DType &batch = itr->Value(); | |
* // some computations | |
* } | |
* \endcode | |
* \tparam DType the data type | |
*/ | |
template<typename DType> | |
class DataIter { | |
public: | |
/*! \brief destructor */ | |
virtual ~DataIter(void) {} | |
/*! \brief set before first of the item */ | |
virtual void BeforeFirst(void) = 0; | |
/*! \brief move to next item */ | |
virtual bool Next(void) = 0; | |
/*! \brief get current data */ | |
virtual const DType &Value(void) const = 0; | |
}; | |
/*! | |
* \brief one row of training instance | |
* \tparam IndexType type of index | |
*/ | |
template<typename IndexType> | |
class Row { | |
public: | |
/*! \brief label of the instance */ | |
real_t label; | |
/*! \brief weight of the instance */ | |
real_t weight; | |
/*! \brief length of the sparse vector */ | |
size_t length; | |
/*! | |
* \brief index of each instance | |
*/ | |
const IndexType *index; | |
/*! | |
* \brief array value of each instance, this can be NULL | |
* indicating every value is set to be 1 | |
*/ | |
const real_t *value; | |
/*! | |
* \param i the input index | |
* \return i-th feature | |
*/ | |
inline IndexType get_index(size_t i) const { | |
return index[i]; | |
} | |
/*! | |
* \param i the input index | |
* \return i-th feature value, this function is always | |
* safe even when value == NULL | |
*/ | |
inline real_t get_value(size_t i) const { | |
return value == NULL ? 1.0f : value[i]; | |
} | |
/*! | |
   * \brief helper function to compute the dot product of the current row with a dense weight vector | |
   * \param weight the dense array of weights to take the dot product with | |
* \param size the size of the weight vector | |
* \tparam V type of the weight vector | |
* \return the result of dot product | |
*/ | |
template<typename V> | |
inline V SDot(const V *weight, size_t size) const { | |
V sum = static_cast<V>(0); | |
if (value == NULL) { | |
for (size_t i = 0; i < length; ++i) { | |
        CHECK(index[i] < size) << "feature index exceeds bound"; | |
sum += weight[index[i]]; | |
} | |
} else { | |
for (size_t i = 0; i < length; ++i) { | |
        CHECK(index[i] < size) << "feature index exceeds bound"; | |
sum += weight[index[i]] * value[i]; | |
} | |
} | |
return sum; | |
} | |
}; | |
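// Usage sketch for Row::SDot (iter, weight and num_feature are illustrative | |
// names, not part of the library): | |
// | |
//   const dmlc::RowBlock<unsigned>& batch = iter->Value(); | |
//   for (size_t i = 0; i < batch.size; ++i) { | |
//     dmlc::Row<unsigned> row = batch[i]; | |
//     float score = row.SDot(weight, num_feature);  // dot(row, weight) | |
//   } | |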
/*! | |
* \brief a block of data, containing several rows in sparse matrix | |
 *  This is useful for (streaming-style) algorithms that scan through rows of data; | |
 *  examples include: SGD, GD, L-BFGS, k-means | |
* | |
* The size of batch is usually large enough so that parallelizing over the rows | |
* can give significant speedup | |
* \tparam IndexType type to store the index used in row batch | |
*/ | |
template<typename IndexType> | |
struct RowBlock { | |
/*! \brief batch size */ | |
size_t size; | |
  /*! \brief array[size+1], row pointer to the beginning of each row */ | |
const size_t *offset; | |
/*! \brief array[size] label of each instance */ | |
const real_t *label; | |
  /*! \brief With weight: array[size] weight of each instance, otherwise nullptr */ | |
const real_t *weight; | |
/*! \brief feature index */ | |
const IndexType *index; | |
/*! \brief feature value, can be NULL, indicating all values are 1 */ | |
const real_t *value; | |
/*! | |
   * \brief get a specific row in the batch | |
   * \param rowid the row id within the batch | |
* \return the instance corresponding to the row | |
*/ | |
inline Row<IndexType> operator[](size_t rowid) const; | |
/*! \return memory cost of the block in bytes */ | |
inline size_t MemCostBytes(void) const { | |
size_t cost = size * (sizeof(size_t) + sizeof(real_t)); | |
if (weight != NULL) cost += size * sizeof(real_t); | |
size_t ndata = offset[size] - offset[0]; | |
if (index != NULL) cost += ndata * sizeof(IndexType); | |
if (value != NULL) cost += ndata * sizeof(real_t); | |
return cost; | |
} | |
/*! | |
* \brief slice a RowBlock to get rows in [begin, end) | |
* \param begin the begin row index | |
* \param end the end row index | |
* \return the sliced RowBlock | |
*/ | |
inline RowBlock Slice(size_t begin, size_t end) const { | |
CHECK(begin <= end && end <= size); | |
RowBlock ret; | |
ret.size = end - begin; | |
ret.label = label + begin; | |
if (weight != NULL) { | |
ret.weight = weight + begin; | |
} else { | |
ret.weight = NULL; | |
} | |
ret.offset = offset + begin; | |
ret.index = index; | |
ret.value = value; | |
return ret; | |
} | |
}; | |
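// Slicing sketch: Slice shares the underlying index/value arrays and only | |
// narrows size/offset/label/weight, so it is O(1). For example: | |
// | |
//   dmlc::RowBlock<unsigned> head = batch.Slice(0, batch.size / 2); | |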
/*! | |
* \brief Data structure that holds the data | |
* Row block iterator interface that gets RowBlocks | |
* Difference between RowBlockIter and Parser: | |
* RowBlockIter caches the data internally that can be used | |
* to iterate the dataset multiple times, | |
 * Parser holds very limited internal state and is usually
 * used to read the data only once
* | |
* \sa Parser | |
* \tparam IndexType type of index in RowBlock | |
 * The Create function is only implemented for IndexType uint64_t and uint32_t
*/ | |
template<typename IndexType> | |
class RowBlockIter : public DataIter<RowBlock<IndexType> > { | |
public: | |
/*! | |
* \brief create a new instance of iterator that returns rowbatch | |
   * by default, an in-memory iterator will be returned
* | |
* \param uri the uri of the input, can contain hdfs prefix | |
* \param part_index the part id of current input | |
* \param num_parts total number of splits | |
* \param type type of dataset can be: "libsvm", ... | |
* | |
* \return the created data iterator | |
*/ | |
static RowBlockIter<IndexType> * | |
Create(const char *uri, | |
unsigned part_index, | |
unsigned num_parts, | |
const char *type); | |
/*! \return maximum feature dimension in the dataset */ | |
virtual size_t NumCol() const = 0; | |
}; | |
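/*!
 * A hedged usage sketch of RowBlockIter::Create; "train.libsvm" is a
 * placeholder path, and the cached data can be scanned multiple times.
 * \code
 * RowBlockIter<uint32_t> *iter =
 *     RowBlockIter<uint32_t>::Create("train.libsvm", 0, 1, "libsvm");
 * for (int pass = 0; pass < 2; ++pass) {
 *   iter->BeforeFirst();
 *   while (iter->Next()) {
 *     const RowBlock<uint32_t> &batch = iter->Value();
 *     // process batch ...
 *   }
 * }
 * size_t num_col = iter->NumCol();  // maximum feature dimension
 * delete iter;
 * \endcode
 */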
/*! | |
* \brief parser interface that parses input data | |
* used to load dmlc data format into your own data format | |
* Difference between RowBlockIter and Parser: | |
* RowBlockIter caches the data internally that can be used | |
* to iterate the dataset multiple times, | |
 * Parser holds very limited internal state and is usually
 * used to read the data only once
* | |
* | |
* \sa RowBlockIter | |
* \tparam IndexType type of index in RowBlock | |
 * The Create function is only implemented for IndexType uint64_t and uint32_t
*/ | |
template <typename IndexType> | |
class Parser : public DataIter<RowBlock<IndexType> > { | |
public: | |
/*! | |
* \brief create a new instance of parser based on the "type" | |
* | |
* \param uri_ the uri of the input, can contain hdfs prefix | |
* \param part_index the part id of current input | |
* \param num_parts total number of splits | |
* \param type type of dataset can be: "libsvm", "auto", ... | |
* | |
* When "auto" is passed, the type is decided by format argument string in URI. | |
* | |
* \return the created parser | |
*/ | |
static Parser<IndexType> * | |
Create(const char *uri_, | |
unsigned part_index, | |
unsigned num_parts, | |
const char *type); | |
/*! \return size of bytes read so far */ | |
virtual size_t BytesRead(void) const = 0; | |
/*! \brief Factory type of the parser*/ | |
typedef Parser<IndexType>* (*Factory) | |
(const std::string& path, | |
const std::map<std::string, std::string>& args, | |
unsigned part_index, | |
unsigned num_parts); | |
}; | |
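/*!
 * A minimal sketch of the single-pass Parser usage pattern described above;
 * "data.libsvm" is a placeholder URI.
 * \code
 * Parser<uint32_t> *parser =
 *     Parser<uint32_t>::Create("data.libsvm", 0, 1, "libsvm");
 * size_t num_rows = 0;
 * while (parser->Next()) {
 *   num_rows += parser->Value().size;
 * }
 * LOG(INFO) << num_rows << " rows, " << parser->BytesRead() << " bytes read";
 * delete parser;
 * \endcode
 */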
/*! | |
* \brief registry entry of parser factory | |
* \tparam IndexType The type of index | |
*/ | |
template<typename IndexType> | |
struct ParserFactoryReg | |
: public FunctionRegEntryBase<ParserFactoryReg<IndexType>, | |
typename Parser<IndexType>::Factory> {}; | |
/*! | |
* \brief Register a new distributed parser to dmlc-core. | |
* | |
* \param IndexType The type of Batch index, can be uint32_t or uint64_t | |
 * \param TypeName The type name of the data.
* \param FactoryFunction The factory function that creates the parser. | |
* | |
* \begincode | |
* | |
 * // define the factory function; its signature must match Parser<IndexType>::Factory
 * template<typename IndexType>
 * Parser<IndexType>*
 * CreateLibSVMParser(const std::string& path,
 *                    const std::map<std::string, std::string>& args,
 *                    unsigned part_index, unsigned num_parts) {
 *   return new LibSVMParser<IndexType>(path, part_index, num_parts);
* } | |
* | |
* // Register it to DMLC | |
* // Then we can use Parser<uint32_t>::Create(uri, part_index, num_parts, "libsvm"); | |
* // to create the parser | |
* | |
* DMLC_REGISTER_DATA_PARSER(uint32_t, libsvm, CreateLibSVMParser<uint32_t>); | |
* DMLC_REGISTER_DATA_PARSER(uint64_t, libsvm, CreateLibSVMParser<uint64_t>); | |
* | |
* \endcode | |
*/ | |
#define DMLC_REGISTER_DATA_PARSER(IndexType, TypeName, FactoryFunction) \ | |
DMLC_REGISTRY_REGISTER(::dmlc::ParserFactoryReg<IndexType>, \ | |
ParserFactoryReg ## _ ## IndexType, TypeName) \ | |
.set_body(FactoryFunction) | |
// implementation of operator[] | |
template<typename IndexType> | |
inline Row<IndexType> | |
RowBlock<IndexType>::operator[](size_t rowid) const { | |
CHECK(rowid < size); | |
Row<IndexType> inst; | |
inst.label = label[rowid]; | |
if (weight != NULL) { | |
inst.weight = weight[rowid]; | |
} else { | |
inst.weight = 1.0f; | |
} | |
inst.length = offset[rowid + 1] - offset[rowid]; | |
inst.index = index + offset[rowid]; | |
if (value == NULL) { | |
inst.value = NULL; | |
} else { | |
inst.value = value + offset[rowid]; | |
} | |
return inst; | |
} | |
} // namespace dmlc | |
#endif // DMLC_DATA_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/data.h ===== | |
//===== EXPANDING: ../dmlc-core/src/io/uri_spec.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file uri_spec.h | |
 * \brief common specification of syntactic sugar in URI
 *  strings passed to dmlc Create functions,
 *  such as a local cache file
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_IO_URI_SPEC_H_ | |
#define DMLC_IO_URI_SPEC_H_ | |
namespace dmlc { | |
namespace io { | |
/*! | |
 * \brief a superset of URI
 *  that allows extra options (sugar) to be passed along
* Example: | |
* | |
* hdfs:///mylibsvm/?format=libsvm&clabel=0#mycache-file. | |
*/ | |
class URISpec { | |
public: | |
/*! \brief the real URI */ | |
std::string uri; | |
/*! \brief arguments in the URL */ | |
std::map<std::string, std::string> args; | |
/*! \brief the path to cache file */ | |
std::string cache_file; | |
/*! | |
* \brief constructor. | |
* \param uri The raw uri string. | |
   * \param part_index The partition index of the part.
* \param num_parts total number of parts. | |
*/ | |
explicit URISpec(const std::string& uri, | |
unsigned part_index, | |
unsigned num_parts) { | |
std::vector<std::string> name_cache = Split(uri, '#'); | |
if (name_cache.size() == 2) { | |
std::ostringstream os; | |
os << name_cache[1]; | |
if (num_parts != 1) { | |
os << ".split" << num_parts << ".part" << part_index; | |
} | |
this->cache_file = os.str(); | |
} else { | |
CHECK_EQ(name_cache.size(), 1U) | |
<< "only one `#` is allowed in file path for cachefile specification"; | |
} | |
std::vector<std::string> name_args = Split(name_cache[0], '?'); | |
if (name_args.size() == 2) { | |
std::vector<std::string> arg_list = Split(name_args[1], '&'); | |
for (size_t i = 0; i < arg_list.size(); ++i) { | |
std::istringstream is(arg_list[i]); | |
std::pair<std::string, std::string> kv; | |
CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format" | |
<< " for key in arg " << i + 1; | |
CHECK(std::getline(is, kv.second)) << "Invalid uri argument format" | |
<< " for value in arg " << i + 1; | |
this->args.insert(kv); | |
} | |
} else { | |
      CHECK_EQ(name_args.size(), 1U)
          << "only one `?` is allowed in file path for argument specification";
} | |
this->uri = name_args[0]; | |
} | |
}; | |
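/*!
 * A small sketch of how the constructor above decomposes a sugared URI
 * (the path is only an example):
 * \code
 * URISpec spec("hdfs:///mylibsvm/?format=libsvm&clabel=0#mycache", 0, 1);
 * // spec.uri        == "hdfs:///mylibsvm/"
 * // spec.args       == { {"format", "libsvm"}, {"clabel", "0"} }
 * // spec.cache_file == "mycache"  (a ".splitN.partK" suffix is appended when num_parts != 1)
 * \endcode
 */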
} // namespace io | |
} // namespace dmlc | |
#endif // DMLC_IO_URI_SPEC_H_ | |
//===== EXPANDED: ../dmlc-core/src/io/uri_spec.h ===== | |
//===== EXPANDING: ../dmlc-core/src/data/parser.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
 * \file parser.h
 * \brief iterator parser interface and a threaded parser wrapper
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_DATA_PARSER_H_ | |
#define DMLC_DATA_PARSER_H_ | |
//===== EXPANDING: ../dmlc-core/include/dmlc/threadediter.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file threadediter.h | |
* \brief thread backed iterator that can be used to implement | |
* general thread-based pipeline such as prefetch and pre-computation | |
* To use the functions in this header, C++11 is required | |
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_THREADEDITER_H_ | |
#define DMLC_THREADEDITER_H_ | |
// defines DMLC_USE_CXX11 | |
// this code depends on c++11 | |
#if DMLC_ENABLE_STD_THREAD | |
namespace dmlc { | |
/*! | |
 * \brief an iterator backed by a thread
 *  that eagerly pulls data from a single producer into a bounded buffer;
 *  the consumer can pull the data at its own rate
 *
 * NOTE: thread concurrency costs time; make sure to store a big blob of data in DType
* | |
* Usage example: | |
* \code | |
* ThreadedIter<DType> iter; | |
* iter.Init(&producer); | |
* // the following code can be in parallel | |
* DType *dptr; | |
* while (iter.Next(&dptr)) { | |
* // do something on dptr | |
* // recycle the space | |
* iter.Recycle(&dptr); | |
* } | |
* \endcode | |
* \tparam DType the type of data blob we support | |
*/ | |
template<typename DType> | |
class ThreadedIter : public DataIter<DType> { | |
public: | |
/*! | |
* \brief producer class interface | |
   * that the threaded iterator uses as the source
   * to produce the content
*/ | |
class Producer { | |
public: | |
// virtual destructor | |
virtual ~Producer() {} | |
/*! \brief reset the producer to beginning */ | |
virtual void BeforeFirst(void) { | |
NotImplemented(); | |
} | |
/*! | |
     * \brief load the data content into DType,
     *  the caller can pass in NULL or an existing address
     *  when inout_dptr is NULL:
     *     the producer needs to allocate a DType and fill in the content
     *  when inout_dptr is specified:
     *     the producer needs to fill the content into the address
     *     specified by inout_dptr, or delete that cell and create a new one
* | |
* \param inout_dptr used to pass in the data holder cell | |
* and return the address of the cell filled | |
* \return true if there is next record, false if we reach the end | |
*/ | |
virtual bool Next(DType **inout_dptr) = 0; | |
}; | |
/*! | |
* \brief constructor | |
* \param max_capacity maximum capacity of the queue | |
*/ | |
explicit ThreadedIter(size_t max_capacity = 8) | |
: producer_owned_(NULL), | |
producer_thread_(NULL), | |
max_capacity_(max_capacity), | |
nwait_consumer_(0), | |
nwait_producer_(0), | |
out_data_(NULL) {} | |
/*! \brief destructor */ | |
virtual ~ThreadedIter(void) { | |
this->Destroy(); | |
} | |
/*! | |
* \brief destroy all the related resources | |
* this is equivalent to destructor, can be used | |
   * to destroy the threaded iterator when the user thinks it is
   * appropriate; it is safe to call this multiple times
*/ | |
inline void Destroy(void); | |
/*! | |
* \brief set maximum capacity of the queue | |
* \param max_capacity maximum capacity of the queue | |
*/ | |
inline void set_max_capacity(size_t max_capacity) { | |
max_capacity_ = max_capacity; | |
} | |
/*! | |
* \brief initialize the producer and start the thread | |
* can only be called once | |
* \param producer pointer to the producer | |
* \param pass_ownership whether pass the ownership to the iter | |
* if this is true, the threaditer will delete the producer | |
* when destructed | |
*/ | |
inline void Init(Producer *producer, bool pass_ownership = false); | |
/*! | |
* \brief initialize the producer and start the thread | |
* pass in two function(closure) of producer to represent the producer | |
* the beforefirst function is optional, and defaults to not implemented | |
* NOTE: the closure must remain valid until the ThreadedIter destructs | |
* \param next the function called to get next element, see Producer.Next | |
* \param beforefirst the function to call to reset the producer, see Producer.BeforeFirst | |
*/ | |
inline void Init(std::function<bool(DType **)> next, | |
std::function<void()> beforefirst = NotImplemented); | |
/*! | |
* \brief get the next data, this function is threadsafe | |
* \param out_dptr used to hold the pointer to the record | |
* after the function call, the caller takes ownership of the pointer | |
* the caller can call recycle to return ownership back to the threaditer | |
* so that the pointer can be re-used | |
* \return true if there is next record, false if we reach the end | |
* \sa Recycle | |
*/ | |
inline bool Next(DType **out_dptr); | |
/*! | |
* \brief recycle the data cell, this function is threadsafe | |
* the threaditer can reuse the data cell for future data loading | |
* \param inout_dptr pointer to the dptr to recycle, after the function call | |
* the content of inout_dptr will be set to NULL | |
*/ | |
inline void Recycle(DType **inout_dptr); | |
/*! | |
* \brief adapt the iterator interface's Next | |
* NOTE: the call to this function is not threadsafe | |
* use the other Next instead | |
* \return true if there is next record, false if we reach the end | |
*/ | |
virtual bool Next(void) { | |
if (out_data_ != NULL) { | |
this->Recycle(&out_data_); | |
} | |
if (Next(&out_data_)) { | |
return true; | |
} else { | |
return false; | |
} | |
} | |
/*! | |
* \brief adapt the iterator interface's Value | |
* NOTE: the call to this function is not threadsafe | |
* use the other Next instead | |
*/ | |
virtual const DType &Value(void) const { | |
CHECK(out_data_ != NULL) << "Calling Value at beginning or end?"; | |
return *out_data_; | |
} | |
/*! \brief set the iterator before first location */ | |
virtual void BeforeFirst(void) { | |
std::unique_lock<std::mutex> lock(mutex_); | |
if (out_data_ != NULL) { | |
free_cells_.push(out_data_); | |
out_data_ = NULL; | |
} | |
if (producer_sig_ == kDestroy) return; | |
producer_sig_ = kBeforeFirst; | |
CHECK(!producer_sig_processed_); | |
if (nwait_producer_ != 0) { | |
producer_cond_.notify_one(); | |
} | |
CHECK(!producer_sig_processed_); | |
// wait until the request has been processed | |
consumer_cond_.wait(lock, [this]() { | |
return producer_sig_processed_; | |
}); | |
producer_sig_processed_ = false; | |
bool notify = nwait_producer_ != 0 && !produce_end_; | |
lock.unlock(); | |
// notify producer, in case they are waiting for the condition. | |
if (notify) producer_cond_.notify_one(); | |
} | |
private: | |
  /*! \brief default implementation used when BeforeFirst is not supported */
inline static void NotImplemented(void) { | |
LOG(FATAL) << "BeforeFirst is not supported"; | |
} | |
/*! \brief signals send to producer */ | |
enum Signal { | |
kProduce, | |
kBeforeFirst, | |
kDestroy | |
}; | |
/*! \brief producer class */ | |
Producer *producer_owned_; | |
/*! \brief signal to producer */ | |
Signal producer_sig_; | |
  /*! \brief whether the special signal other than kProduce is processed */
bool producer_sig_processed_; | |
/*! \brief thread that runs the producer */ | |
std::thread *producer_thread_; | |
/*! \brief whether produce ends */ | |
bool produce_end_; | |
/*! \brief maximum queue size */ | |
size_t max_capacity_; | |
/*! \brief internal mutex */ | |
std::mutex mutex_; | |
  /*! \brief number of consumers waiting */
unsigned nwait_consumer_; | |
  /*! \brief number of producers waiting */
unsigned nwait_producer_; | |
/*! \brief conditional variable for producer thread */ | |
std::condition_variable producer_cond_; | |
/*! \brief conditional variable for consumer threads */ | |
std::condition_variable consumer_cond_; | |
/*! \brief the current output cell */ | |
DType *out_data_; | |
/*! \brief internal queue of producer */ | |
std::queue<DType*> queue_; | |
/*! \brief free cells that can be used */ | |
std::queue<DType*> free_cells_; | |
}; | |
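/*!
 * A hedged sketch of the closure-based Init above: a producer that yields the
 * integers 0..99 into the bounded buffer, consumed with Next/Recycle. The
 * captured counter must stay valid for the lifetime of the iterator, and
 * BeforeFirst is unavailable because no reset closure is supplied.
 * \code
 * ThreadedIter<int> iter;
 * int counter = 0;
 * iter.Init([&counter](int **dptr) {
 *   if (counter >= 100) return false;
 *   if (*dptr == NULL) *dptr = new int();
 *   **dptr = counter++;
 *   return true;
 * });
 * int *value;
 * while (iter.Next(&value)) {
 *   // use *value, then hand the cell back so it can be reused
 *   iter.Recycle(&value);
 * }
 * \endcode
 */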
// implementation of functions | |
template<typename DType> | |
inline void ThreadedIter<DType>::Destroy(void) { | |
if (producer_thread_ != NULL) { | |
{ | |
// lock the mutex | |
std::lock_guard<std::mutex> lock(mutex_); | |
// send destroy signal | |
producer_sig_ = kDestroy; | |
if (nwait_producer_ != 0) { | |
producer_cond_.notify_one(); | |
} | |
} | |
producer_thread_->join(); | |
delete producer_thread_; | |
producer_thread_ = NULL; | |
} | |
// end of critical region | |
  // now the producer thread has exited
while (free_cells_.size() != 0) { | |
delete free_cells_.front(); | |
free_cells_.pop(); | |
} | |
while (queue_.size() != 0) { | |
delete queue_.front(); | |
queue_.pop(); | |
} | |
if (producer_owned_ != NULL) { | |
delete producer_owned_; | |
} | |
if (out_data_ != NULL) { | |
delete out_data_; out_data_ = NULL; | |
} | |
} | |
template<typename DType> | |
inline void ThreadedIter<DType>:: | |
Init(Producer *producer, bool pass_ownership) { | |
CHECK(producer_owned_ == NULL) << "can only call Init once"; | |
if (pass_ownership) producer_owned_ = producer; | |
auto next = [producer](DType **dptr) { | |
return producer->Next(dptr); | |
}; | |
auto beforefirst = [producer]() { | |
producer->BeforeFirst(); | |
}; | |
this->Init(next, beforefirst); | |
} | |
template<typename DType> | |
inline void ThreadedIter<DType>:: | |
Init(std::function<bool(DType **)> next, | |
std::function<void()> beforefirst) { | |
producer_sig_ = kProduce; | |
producer_sig_processed_ = false; | |
produce_end_ = false; | |
  // procedure running in the producer thread
// run producer thread | |
auto producer_fun = [this, next, beforefirst] () { | |
while (true) { | |
DType *cell = NULL; | |
{ | |
// lockscope | |
std::unique_lock<std::mutex> lock(mutex_); | |
++this->nwait_producer_; | |
producer_cond_.wait(lock, [this]() { | |
if (producer_sig_ == kProduce) { | |
bool ret = !produce_end_ && | |
(queue_.size() < max_capacity_ || free_cells_.size() != 0); | |
return ret; | |
} else { | |
return true; | |
} | |
}); | |
--this->nwait_producer_; | |
if (producer_sig_ == kProduce) { | |
if (free_cells_.size() != 0) { | |
cell = free_cells_.front(); | |
free_cells_.pop(); | |
} | |
} else if (producer_sig_ == kBeforeFirst) { | |
// reset the producer | |
beforefirst(); | |
// cleanup the queue | |
while (queue_.size() != 0) { | |
free_cells_.push(queue_.front()); | |
queue_.pop(); | |
} | |
// reset the state | |
produce_end_ = false; | |
producer_sig_processed_ = true; | |
producer_sig_ = kProduce; | |
          // notify consumers that the reset has been processed.
lock.unlock(); | |
consumer_cond_.notify_all(); | |
continue; | |
} else { | |
// destroy the thread | |
CHECK(producer_sig_ == kDestroy); | |
producer_sig_processed_ = true; | |
produce_end_ = true; | |
consumer_cond_.notify_all(); | |
return; | |
} | |
} // end of lock scope | |
// now without lock | |
produce_end_ = !next(&cell); | |
CHECK(cell != NULL || produce_end_); | |
bool notify; | |
{ | |
// lockscope | |
std::lock_guard<std::mutex> lock(mutex_); | |
if (!produce_end_) { | |
queue_.push(cell); | |
} else { | |
if (cell != NULL) free_cells_.push(cell); | |
} | |
// put things into queue | |
notify = nwait_consumer_ != 0; | |
} | |
if (notify) consumer_cond_.notify_all(); | |
} | |
}; | |
producer_thread_ = new std::thread(producer_fun); | |
} | |
template<typename DType> | |
inline bool ThreadedIter<DType>:: | |
Next(DType **out_dptr) { | |
if (producer_sig_ == kDestroy) return false; | |
std::unique_lock<std::mutex> lock(mutex_); | |
CHECK(producer_sig_ == kProduce) | |
<< "Make sure you call BeforeFirst not inconcurrent with Next!"; | |
++nwait_consumer_; | |
consumer_cond_.wait(lock, [this]() { | |
return queue_.size() != 0 || produce_end_; | |
}); | |
--nwait_consumer_; | |
if (queue_.size() != 0) { | |
*out_dptr = queue_.front(); | |
queue_.pop(); | |
bool notify = nwait_producer_ != 0 && !produce_end_; | |
lock.unlock(); | |
if (notify) producer_cond_.notify_one(); | |
return true; | |
} else { | |
CHECK(produce_end_); | |
return false; | |
} | |
} | |
template<typename DType> | |
inline void ThreadedIter<DType>::Recycle(DType **inout_dptr) { | |
bool notify; | |
{ | |
std::lock_guard<std::mutex> lock(mutex_); | |
free_cells_.push(*inout_dptr); | |
*inout_dptr = NULL; | |
notify = nwait_producer_ != 0 && !produce_end_; | |
} | |
if (notify) producer_cond_.notify_one(); | |
} | |
} // namespace dmlc | |
#endif // DMLC_ENABLE_STD_THREAD
#endif // DMLC_THREADEDITER_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/threadediter.h ===== | |
//===== EXPANDING: ../dmlc-core/src/data/row_block.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file row_block.h | |
* \brief additional data structure to support | |
* RowBlock data structure | |
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_DATA_ROW_BLOCK_H_ | |
#define DMLC_DATA_ROW_BLOCK_H_ | |
namespace dmlc { | |
namespace data { | |
/*! | |
* \brief dynamic data structure that holds | |
* a row block of data | |
* \tparam IndexType the type of index we are using | |
*/ | |
template<typename IndexType> | |
struct RowBlockContainer { | |
/*! \brief array[size+1], row pointer to beginning of each rows */ | |
std::vector<size_t> offset; | |
/*! \brief array[size] label of each instance */ | |
std::vector<real_t> label; | |
/*! \brief array[size] weight of each instance */ | |
std::vector<real_t> weight; | |
/*! \brief feature index */ | |
std::vector<IndexType> index; | |
/*! \brief feature value */ | |
std::vector<real_t> value; | |
/*! \brief maximum value of index */ | |
IndexType max_index; | |
// constructor | |
RowBlockContainer(void) { | |
this->Clear(); | |
} | |
/*! \brief convert to a row block */ | |
inline RowBlock<IndexType> GetBlock(void) const; | |
/*! | |
* \brief write the row block to a binary stream | |
* \param fo output stream | |
*/ | |
inline void Save(Stream *fo) const; | |
/*! | |
* \brief load row block from a binary stream | |
   * \param fi input stream
* \return false if at end of file | |
*/ | |
inline bool Load(Stream *fi); | |
/*! \brief clear the container */ | |
inline void Clear(void) { | |
offset.clear(); offset.push_back(0); | |
label.clear(); index.clear(); value.clear(); weight.clear(); | |
max_index = 0; | |
} | |
/*! \brief size of the data */ | |
inline size_t Size(void) const { | |
return offset.size() - 1; | |
} | |
/*! \return estimation of memory cost of this container */ | |
inline size_t MemCostBytes(void) const { | |
return offset.size() * sizeof(size_t) + | |
label.size() * sizeof(real_t) + | |
weight.size() * sizeof(real_t) + | |
index.size() * sizeof(IndexType) + | |
value.size() * sizeof(real_t); | |
} | |
/*! | |
* \brief push the row into container | |
* \param row the row to push back | |
* \tparam I the index type of the row | |
*/ | |
template<typename I> | |
inline void Push(Row<I> row) { | |
label.push_back(row.label); | |
weight.push_back(row.weight); | |
for (size_t i = 0; i < row.length; ++i) { | |
CHECK_LE(row.index[i], std::numeric_limits<IndexType>::max()) | |
<< "index exceed numeric bound of current type"; | |
IndexType findex = static_cast<IndexType>(row.index[i]); | |
index.push_back(findex); | |
max_index = std::max(max_index, findex); | |
} | |
if (row.value != NULL) { | |
for (size_t i = 0; i < row.length; ++i) { | |
value.push_back(row.value[i]); | |
} | |
} | |
offset.push_back(index.size()); | |
} | |
/*! | |
* \brief push the row block into container | |
* \param row the row to push back | |
* \tparam I the index type of the row | |
*/ | |
template<typename I> | |
inline void Push(RowBlock<I> batch) { | |
size_t size = label.size(); | |
label.resize(label.size() + batch.size); | |
std::memcpy(BeginPtr(label) + size, batch.label, | |
batch.size * sizeof(real_t)); | |
if (batch.weight != NULL) { | |
weight.insert(weight.end(), batch.weight, batch.weight + batch.size); | |
} | |
size_t ndata = batch.offset[batch.size] - batch.offset[0]; | |
index.resize(index.size() + ndata); | |
IndexType *ihead = BeginPtr(index) + offset.back(); | |
for (size_t i = 0; i < ndata; ++i) { | |
CHECK_LE(batch.index[i], std::numeric_limits<IndexType>::max()) | |
<< "index exceed numeric bound of current type"; | |
IndexType findex = static_cast<IndexType>(batch.index[i]); | |
ihead[i] = findex; | |
max_index = std::max(max_index, findex); | |
} | |
if (batch.value != NULL) { | |
value.resize(value.size() + ndata); | |
std::memcpy(BeginPtr(value) + value.size() - ndata, batch.value, | |
ndata * sizeof(real_t)); | |
} | |
size_t shift = offset[size]; | |
offset.resize(offset.size() + batch.size); | |
size_t *ohead = BeginPtr(offset) + size + 1; | |
for (size_t i = 0; i < batch.size; ++i) { | |
ohead[i] = shift + batch.offset[i + 1] - batch.offset[0]; | |
} | |
} | |
}; | |
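/*!
 * A brief sketch of filling the container above and viewing it as an immutable
 * RowBlock; note that GetBlock only aliases the internal vectors, so the
 * container must outlive the returned block.
 * \code
 * RowBlockContainer<uint32_t> cont;
 * uint32_t idx[2] = {3, 7};
 * real_t val[2] = {1.0f, 0.5f};
 * Row<uint32_t> row;
 * row.label = 1.0f; row.weight = 1.0f;
 * row.length = 2; row.index = idx; row.value = val;
 * cont.Push(row);
 * RowBlock<uint32_t> block = cont.GetBlock();  // block.size == 1
 * \endcode
 */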
template<typename IndexType> | |
inline RowBlock<IndexType> | |
RowBlockContainer<IndexType>::GetBlock(void) const { | |
// consistency check | |
if (label.size()) { | |
CHECK_EQ(label.size() + 1, offset.size()); | |
} | |
CHECK_EQ(offset.back(), index.size()); | |
CHECK(offset.back() == value.size() || value.size() == 0); | |
RowBlock<IndexType> data; | |
data.size = offset.size() - 1; | |
data.offset = BeginPtr(offset); | |
data.label = BeginPtr(label); | |
data.weight = BeginPtr(weight); | |
data.index = BeginPtr(index); | |
data.value = BeginPtr(value); | |
return data; | |
} | |
template<typename IndexType> | |
inline void | |
RowBlockContainer<IndexType>::Save(Stream *fo) const { | |
fo->Write(offset); | |
fo->Write(label); | |
fo->Write(weight); | |
fo->Write(index); | |
fo->Write(value); | |
fo->Write(&max_index, sizeof(IndexType)); | |
} | |
template<typename IndexType> | |
inline bool | |
RowBlockContainer<IndexType>::Load(Stream *fi) { | |
if (!fi->Read(&offset)) return false; | |
CHECK(fi->Read(&label)) << "Bad RowBlock format"; | |
CHECK(fi->Read(&weight)) << "Bad RowBlock format"; | |
CHECK(fi->Read(&index)) << "Bad RowBlock format"; | |
CHECK(fi->Read(&value)) << "Bad RowBlock format"; | |
CHECK(fi->Read(&max_index, sizeof(IndexType))) << "Bad RowBlock format"; | |
return true; | |
} | |
} // namespace data | |
} // namespace dmlc | |
#endif // DMLC_DATA_ROW_BLOCK_H_ | |
//===== EXPANDED: ../dmlc-core/src/data/row_block.h ===== | |
namespace dmlc { | |
namespace data { | |
/*! \brief declare thread class */ | |
template <typename IndexType> | |
class ThreadedParser; | |
/*! \brief base class for parser to parse data */ | |
template <typename IndexType> | |
class ParserImpl : public Parser<IndexType> { | |
public: | |
ParserImpl() : data_ptr_(0), data_end_(0) {} | |
// virtual destructor | |
virtual ~ParserImpl() {} | |
/*! \brief implement next */ | |
virtual bool Next(void) { | |
while (true) { | |
while (data_ptr_ < data_end_) { | |
data_ptr_ += 1; | |
if (data_[data_ptr_ - 1].Size() != 0) { | |
block_ = data_[data_ptr_ - 1].GetBlock(); | |
return true; | |
} | |
} | |
if (!ParseNext(&data_)) break; | |
data_ptr_ = 0; | |
data_end_ = static_cast<IndexType>(data_.size()); | |
} | |
return false; | |
} | |
virtual const RowBlock<IndexType> &Value(void) const { | |
return block_; | |
} | |
/*! \return size of bytes read so far */ | |
virtual size_t BytesRead(void) const = 0; | |
protected: | |
// allow ThreadedParser to see ParseNext | |
friend class ThreadedParser<IndexType>; | |
/*! | |
* \brief read in next several blocks of data | |
* \param data vector of data to be returned | |
* \return true if the data is loaded, false if reach end | |
*/ | |
virtual bool ParseNext(std::vector<RowBlockContainer<IndexType> > *data) = 0; | |
/*! \brief pointer to begin and end of data */ | |
IndexType data_ptr_, data_end_; | |
/*! \brief internal data */ | |
std::vector<RowBlockContainer<IndexType> > data_; | |
/*! \brief internal row block */ | |
RowBlock<IndexType> block_; | |
}; | |
#if DMLC_ENABLE_STD_THREAD | |
template <typename IndexType> | |
class ThreadedParser : public ParserImpl<IndexType> { | |
public: | |
explicit ThreadedParser(ParserImpl<IndexType> *base) | |
: base_(base), tmp_(NULL) { | |
iter_.set_max_capacity(8); | |
iter_.Init([base](std::vector<RowBlockContainer<IndexType> > **dptr) { | |
if (*dptr == NULL) { | |
*dptr = new std::vector<RowBlockContainer<IndexType> >(); | |
} | |
return base->ParseNext(*dptr); | |
}, [base]() {base->BeforeFirst();}); | |
} | |
virtual ~ThreadedParser(void) { | |
// stop things before base is deleted | |
iter_.Destroy(); | |
delete base_; | |
delete tmp_; | |
} | |
virtual void BeforeFirst() { | |
iter_.BeforeFirst(); | |
} | |
/*! \brief implement next */ | |
using ParserImpl<IndexType>::data_ptr_; | |
using ParserImpl<IndexType>::data_end_; | |
virtual bool Next(void) { | |
while (true) { | |
while (data_ptr_ < data_end_) { | |
data_ptr_ += 1; | |
if ((*tmp_)[data_ptr_ - 1].Size() != 0) { | |
this->block_ = (*tmp_)[data_ptr_ - 1].GetBlock(); | |
return true; | |
} | |
} | |
if (tmp_ != NULL) iter_.Recycle(&tmp_); | |
if (!iter_.Next(&tmp_)) break; | |
data_ptr_ = 0; data_end_ = tmp_->size(); | |
} | |
return false; | |
} | |
virtual size_t BytesRead(void) const { | |
return base_->BytesRead(); | |
} | |
protected: | |
virtual bool ParseNext(std::vector<RowBlockContainer<IndexType> > *data) { | |
LOG(FATAL) << "cannot call ParseNext"; return false; | |
} | |
private: | |
/*! \brief the place where we get the data */ | |
Parser<IndexType> *base_; | |
/*! \brief backend threaded iterator */ | |
ThreadedIter<std::vector<RowBlockContainer<IndexType> > > iter_; | |
/*! \brief current chunk of data */ | |
std::vector<RowBlockContainer<IndexType> > *tmp_; | |
}; | |
#endif // DMLC_ENABLE_STD_THREAD
} // namespace data | |
} // namespace dmlc | |
#endif // DMLC_DATA_PARSER_H_ | |
//===== EXPANDED: ../dmlc-core/src/data/parser.h ===== | |
//===== EXPANDING: ../dmlc-core/src/data/basic_row_iter.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file basic_row_iter.h | |
 * \brief row-based iterator that
 *  loads everything into memory and returns it
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_DATA_BASIC_ROW_ITER_H_ | |
#define DMLC_DATA_BASIC_ROW_ITER_H_ | |
//===== EXPANDING: ../dmlc-core/include/dmlc/timer.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file timer.h | |
* \brief cross platform timer for timing | |
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_TIMER_H_ | |
#define DMLC_TIMER_H_ | |
#if DMLC_USE_CXX11 | |
#endif | |
#ifdef __MACH__ | |
#endif | |
namespace dmlc { | |
/*! | |
* \brief return time in seconds | |
*/ | |
inline double GetTime(void) { | |
#if DMLC_USE_CXX11 | |
return std::chrono::duration<double>( | |
std::chrono::high_resolution_clock::now().time_since_epoch()).count(); | |
#elif defined __MACH__ | |
clock_serv_t cclock; | |
mach_timespec_t mts; | |
host_get_clock_service(mach_host_self(), CALENDAR_CLOCK, &cclock); | |
CHECK(clock_get_time(cclock, &mts) == 0) << "failed to get time"; | |
mach_port_deallocate(mach_task_self(), cclock); | |
return static_cast<double>(mts.tv_sec) + static_cast<double>(mts.tv_nsec) * 1e-9; | |
#else | |
#if defined(__unix__) || defined(__linux__) | |
timespec ts; | |
CHECK(clock_gettime(CLOCK_REALTIME, &ts) == 0) << "failed to get time"; | |
return static_cast<double>(ts.tv_sec) + static_cast<double>(ts.tv_nsec) * 1e-9; | |
#else | |
return static_cast<double>(time(NULL)); | |
#endif | |
#endif | |
} | |
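/*!
 * A tiny sketch of the timing pattern the row iterators below use: take the
 * difference of two GetTime() calls to obtain elapsed wall-clock seconds.
 * \code
 * double tstart = GetTime();
 * // ... read a few blocks of data ...
 * double elapsed = GetTime() - tstart;  // seconds, as a double
 * \endcode
 */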
} // namespace dmlc | |
#endif // DMLC_TIMER_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/timer.h ===== | |
namespace dmlc { | |
namespace data { | |
/*! | |
 * \brief basic row iterator that loads the whole dataset into memory
* \tparam IndexType the type of index we are using | |
*/ | |
template<typename IndexType> | |
class BasicRowIter: public RowBlockIter<IndexType> { | |
public: | |
explicit BasicRowIter(Parser<IndexType> *parser) | |
: at_head_(true) { | |
this->Init(parser); | |
delete parser; | |
} | |
virtual ~BasicRowIter() {} | |
virtual void BeforeFirst(void) { | |
at_head_ = true; | |
} | |
virtual bool Next(void) { | |
if (at_head_) { | |
at_head_ = false; | |
return true; | |
} else { | |
return false; | |
} | |
} | |
virtual const RowBlock<IndexType> &Value(void) const { | |
return row_; | |
} | |
virtual size_t NumCol(void) const { | |
return static_cast<size_t>(data_.max_index) + 1; | |
} | |
private: | |
// at head | |
bool at_head_; | |
// row block to store | |
RowBlock<IndexType> row_; | |
// back end data | |
RowBlockContainer<IndexType> data_; | |
// initialize | |
inline void Init(Parser<IndexType> *parser); | |
}; | |
template<typename IndexType> | |
inline void BasicRowIter<IndexType>::Init(Parser<IndexType> *parser) { | |
data_.Clear(); | |
double tstart = GetTime(); | |
size_t bytes_expect = 10UL << 20UL; | |
while (parser->Next()) { | |
data_.Push(parser->Value()); | |
double tdiff = GetTime() - tstart; | |
size_t bytes_read = parser->BytesRead(); | |
if (bytes_read >= bytes_expect) { | |
bytes_read = bytes_read >> 20UL; | |
LOG(INFO) << bytes_read << "MB read," | |
<< bytes_read / tdiff << " MB/sec"; | |
bytes_expect += 10UL << 20UL; | |
} | |
} | |
row_ = data_.GetBlock(); | |
double tdiff = GetTime() - tstart; | |
LOG(INFO) << "finish reading at " | |
<< (parser->BytesRead() >> 20UL) / tdiff | |
<< " MB/sec"; | |
} | |
} // namespace data | |
} // namespace dmlc | |
#endif  // DMLC_DATA_BASIC_ROW_ITER_H_
//===== EXPANDED: ../dmlc-core/src/data/basic_row_iter.h ===== | |
//===== EXPANDING: ../dmlc-core/src/data/disk_row_iter.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
 * \file disk_row_iter.h
* \brief row based iterator that | |
* caches things into disk and then load segments | |
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_DATA_DISK_ROW_ITER_H_ | |
#define DMLC_DATA_DISK_ROW_ITER_H_ | |
//===== EXPANDING: ../dmlc-core/src/data/libsvm_parser.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file libsvm_parser.h | |
* \brief iterator parser to parse libsvm format | |
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_DATA_LIBSVM_PARSER_H_ | |
#define DMLC_DATA_LIBSVM_PARSER_H_ | |
//===== EXPANDING: ../dmlc-core/src/data/text_parser.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file text_parser.h | |
* \brief iterator parser to parse text format | |
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_DATA_TEXT_PARSER_H_ | |
#define DMLC_DATA_TEXT_PARSER_H_ | |
//===== EXPANDING: ../dmlc-core/include/dmlc/omp.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file omp.h | |
* \brief header to handle OpenMP compatibility issues | |
*/ | |
#ifndef DMLC_OMP_H_ | |
#define DMLC_OMP_H_ | |
#if defined(_OPENMP) | |
#else | |
#ifndef DISABLE_OPENMP | |
// use pragma message instead of warning | |
#pragma message("Warning: OpenMP is not available, " \ | |
"project will be compiled into single-thread code. " \ | |
"Use OpenMP-enabled compiler to get benefit of multi-threading.") | |
#endif | |
//! \cond Doxygen_Suppress | |
inline int omp_get_thread_num() { return 0; } | |
inline int omp_get_num_threads() { return 1; } | |
inline int omp_get_max_threads() { return 1; } | |
inline int omp_get_num_procs() { return 1; } | |
inline void omp_set_num_threads(int nthread) {} | |
#endif | |
// loop variable used in openmp | |
namespace dmlc { | |
#ifdef _MSC_VER | |
typedef int omp_uint; | |
typedef long omp_ulong; // NOLINT(*) | |
#else | |
typedef unsigned omp_uint; | |
typedef unsigned long omp_ulong; // NOLINT(*) | |
#endif | |
//! \endcond | |
} // namespace dmlc | |
#endif // DMLC_OMP_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/omp.h ===== | |
namespace dmlc { | |
namespace data { | |
/*! | |
* \brief Text parser that parses the input lines | |
* and returns rows in input data | |
*/ | |
template <typename IndexType> | |
class TextParserBase : public ParserImpl<IndexType> { | |
public: | |
explicit TextParserBase(InputSplit *source, | |
int nthread) | |
: bytes_read_(0), source_(source) { | |
int maxthread; | |
#pragma omp parallel | |
{ | |
maxthread = std::max(omp_get_num_procs() / 2 - 4, 1); | |
} | |
nthread_ = std::min(maxthread, nthread); | |
} | |
virtual ~TextParserBase() { | |
delete source_; | |
} | |
virtual void BeforeFirst(void) { | |
source_->BeforeFirst(); | |
} | |
virtual size_t BytesRead(void) const { | |
return bytes_read_; | |
} | |
virtual bool ParseNext(std::vector<RowBlockContainer<IndexType> > *data) { | |
return FillData(data); | |
} | |
protected: | |
/*! | |
* \brief parse data into out | |
* \param begin beginning of buffer | |
* \param end end of buffer | |
*/ | |
virtual void ParseBlock(char *begin, | |
char *end, | |
RowBlockContainer<IndexType> *out) = 0; | |
/*! | |
* \brief read in next several blocks of data | |
* \param data vector of data to be returned | |
* \return true if the data is loaded, false if reach end | |
*/ | |
inline bool FillData(std::vector<RowBlockContainer<IndexType> > *data); | |
/*! | |
* \brief start from bptr, go backward and find first endof line | |
* \param bptr end position to go backward | |
* \param begin the beginning position of buffer | |
* \return position of first endof line going backward | |
*/ | |
inline char* BackFindEndLine(char *bptr, | |
char *begin) { | |
for (; bptr != begin; --bptr) { | |
if (*bptr == '\n' || *bptr == '\r') return bptr; | |
} | |
return begin; | |
} | |
private: | |
// nthread | |
int nthread_; | |
  // number of bytes read
size_t bytes_read_; | |
// source split that provides the data | |
InputSplit *source_; | |
}; | |
// implementation | |
template <typename IndexType> | |
inline bool TextParserBase<IndexType>:: | |
FillData(std::vector<RowBlockContainer<IndexType> > *data) { | |
InputSplit::Blob chunk; | |
if (!source_->NextChunk(&chunk)) return false; | |
const int nthread = omp_get_max_threads(); | |
// reserve space for data | |
data->resize(nthread); | |
bytes_read_ += chunk.size; | |
CHECK_NE(chunk.size, 0U); | |
char *head = reinterpret_cast<char*>(chunk.dptr); | |
#pragma omp parallel num_threads(nthread) | |
{ | |
// threadid | |
int tid = omp_get_thread_num(); | |
size_t nstep = (chunk.size + nthread - 1) / nthread; | |
size_t sbegin = std::min(tid * nstep, chunk.size); | |
size_t send = std::min((tid + 1) * nstep, chunk.size); | |
char *pbegin = BackFindEndLine(head + sbegin, head); | |
char *pend; | |
if (tid + 1 == nthread) { | |
pend = head + send; | |
} else { | |
pend = BackFindEndLine(head + send, head); | |
} | |
ParseBlock(pbegin, pend, &(*data)[tid]); | |
} | |
this->data_ptr_ = 0; | |
return true; | |
} | |
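/*!
 * A worked example of the chunk partitioning above, with hypothetical numbers:
 * for chunk.size == 100 and nthread == 4, nstep == (100 + 3) / 4 == 26, so the
 * raw byte ranges are [0,26), [26,52), [52,78), [78,100); BackFindEndLine then
 * snaps every boundary except the final one back to the preceding end-of-line,
 * so that no text line is split between two threads.
 */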
} // namespace data | |
} // namespace dmlc | |
#endif // DMLC_DATA_TEXT_PARSER_H_ | |
//===== EXPANDED: ../dmlc-core/src/data/text_parser.h ===== | |
//===== EXPANDING: ../dmlc-core/src/data/strtonum.h ===== | |
/*! | |
 * Copyright (c) 2015 by Contributors
* \file strtonum.h | |
* \brief A faster implementation of strtod, ... | |
*/ | |
#ifndef DMLC_DATA_STRTONUM_H_ | |
#define DMLC_DATA_STRTONUM_H_ | |
namespace dmlc { | |
namespace data { | |
inline bool isspace(char c) { | |
return (c == ' ' || c == '\t' || c == '\r' || c == '\n' || c == '\f'); | |
} | |
inline bool isblank(char c) { | |
return (c == ' ' || c == '\t'); | |
} | |
inline bool isdigit(char c) { | |
return (c >= '0' && c <= '9'); | |
} | |
inline bool isdigitchars(char c) { | |
return (c >= '0' && c <= '9') | |
|| c == '+' || c == '-' | |
|| c == '.' | |
|| c == 'e' || c == 'E'; | |
} | |
/*! | |
* \brief A faster version of strtof | |
* TODO the current version does not support INF, NAN, and hex number | |
*/ | |
inline float strtof(const char *nptr, char **endptr) { | |
const char *p = nptr; | |
// Skip leading white space, if any. Not necessary | |
while (isspace(*p) ) ++p; | |
// Get sign, if any. | |
bool sign = true; | |
if (*p == '-') { | |
sign = false; ++p; | |
} else if (*p == '+') { | |
++p; | |
} | |
// Get digits before decimal point or exponent, if any. | |
float value; | |
for (value = 0; isdigit(*p); ++p) { | |
value = value * 10.0f + (*p - '0'); | |
} | |
// Get digits after decimal point, if any. | |
if (*p == '.') { | |
uint64_t pow10 = 1; | |
uint64_t val2 = 0; | |
++p; | |
while (isdigit(*p)) { | |
val2 = val2 * 10 + (*p - '0'); | |
pow10 *= 10; | |
++p; | |
} | |
value += static_cast<float>( | |
static_cast<double>(val2) / static_cast<double>(pow10)); | |
} | |
// Handle exponent, if any. | |
if ((*p == 'e') || (*p == 'E')) { | |
++p; | |
bool frac = false; | |
float scale = 1.0; | |
unsigned expon; | |
// Get sign of exponent, if any. | |
if (*p == '-') { | |
frac = true; | |
++p; | |
} else if (*p == '+') { | |
++p; | |
} | |
// Get digits of exponent, if any. | |
for (expon = 0; isdigit(*p); p += 1) { | |
expon = expon * 10 + (*p - '0'); | |
} | |
if (expon > 38) expon = 38; | |
// Calculate scaling factor. | |
while (expon >= 8) { scale *= 1E8; expon -= 8; } | |
while (expon > 0) { scale *= 10.0; expon -= 1; } | |
// Return signed and scaled floating point result. | |
value = frac ? (value / scale) : (value * scale); | |
} | |
if (endptr) *endptr = (char*)p; // NOLINT(*) | |
return sign ? value : - value; | |
} | |
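/*!
 * A hedged example of the fast strtof above; it follows the calling convention
 * of the standard strtof, minus INF/NAN/hex support.
 * \code
 * const char *text = "3.25e2 rest";
 * char *endptr;
 * float v = strtof(text, &endptr);  // v == 325.0f, endptr points at " rest"
 * \endcode
 */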
/** | |
 * \brief A faster string-to-integer converter
* TODO only support base <=10 | |
*/ | |
template <typename V> | |
inline V strtoint(const char* nptr, char **endptr, int base) { | |
const char *p = nptr; | |
// Skip leading white space, if any. Not necessary | |
while (isspace(*p) ) ++p; | |
// Get sign if any | |
bool sign = true; | |
if (*p == '-') { | |
sign = false; ++p; | |
} else if (*p == '+') { | |
++p; | |
} | |
V value; | |
for (value = 0; isdigit(*p); ++p) { | |
value = value * base + (*p - '0'); | |
} | |
if (endptr) *endptr = (char*)p; // NOLINT(*) | |
return sign ? value : - value; | |
} | |
template <typename V> | |
inline V strtouint(const char* nptr, char **endptr, int base) { | |
const char *p = nptr; | |
// Skip leading white space, if any. Not necessary | |
while (isspace(*p)) ++p; | |
// Get sign if any | |
bool sign = true; | |
if (*p == '-') { | |
sign = false; ++p; | |
} else if (*p == '+') { | |
++p; | |
} | |
// we are parsing unsigned, so no minus sign should be found | |
CHECK_EQ(sign, true); | |
V value; | |
for (value = 0; isdigit(*p); ++p) { | |
value = value * base + (*p - '0'); | |
} | |
if (endptr) *endptr = (char*)p; // NOLINT(*) | |
return value; | |
} | |
inline uint64_t | |
strtoull(const char* nptr, char **endptr, int base) { | |
return strtouint<uint64_t>(nptr, endptr, base); | |
} | |
inline long atol(const char* p) { // NOLINT(*) | |
return strtoint<long>(p, 0, 10); // NOLINT(*) | |
} | |
inline float atof(const char *nptr) { | |
return strtof(nptr, 0); | |
} | |
template<typename T> | |
class Str2T { | |
public: | |
static inline T get(const char * begin, const char * end); | |
}; | |
template<typename T> | |
inline T Str2Type(const char * begin, const char * end) { | |
return Str2T<T>::get(begin, end); | |
} | |
template<> | |
class Str2T<int32_t> { | |
public: | |
static inline int32_t get(const char * begin, const char * end) { | |
return strtoint<int>(begin, NULL, 10); | |
} | |
}; | |
template<> | |
class Str2T<uint32_t> { | |
public: | |
static inline uint32_t get(const char * begin, const char * end) { | |
return strtouint<int>(begin, NULL, 10); | |
} | |
}; | |
template<> | |
class Str2T<int64_t> { | |
public: | |
static inline int64_t get(const char * begin, const char * end) { | |
return strtoint<int64_t>(begin, NULL, 10); | |
} | |
}; | |
template<> | |
class Str2T<uint64_t> { | |
public: | |
static inline uint64_t get(const char * begin, const char * end) { | |
return strtouint<uint64_t>(begin, NULL, 10); | |
} | |
}; | |
template<> | |
class Str2T<float> { | |
public: | |
static inline float get(const char * begin, const char * end) { | |
return atof(begin); | |
} | |
}; | |
/** | |
 * \brief Parse colon separated pair v1[:v2]
* \param begin: pointer to string | |
* \param end: one past end of string | |
* \param parseEnd: end string of parsed string | |
* \param v1: first value in the pair | |
* \param v2: second value in the pair | |
 * \return number of values parsed
*/ | |
template<typename T1, typename T2> | |
inline int ParsePair(const char * begin, const char * end, | |
const char ** endptr, T1 &v1, T2 &v2) { // NOLINT(*) | |
const char * p = begin; | |
while (p != end && !isdigitchars(*p)) ++p; | |
if (p == end) { | |
*endptr = end; | |
return 0; | |
} | |
const char * q = p; | |
while (q != end && isdigitchars(*q)) ++q; | |
v1 = Str2Type<T1>(p, q); | |
p = q; | |
while (p != end && isblank(*p)) ++p; | |
if (p == end || *p != ':') { | |
// only v1 | |
*endptr = p; | |
return 1; | |
} | |
p++; | |
while (p != end && !isdigitchars(*p)) ++p; | |
q = p; | |
while (q != end && isdigitchars(*q)) ++q; | |
*endptr = q; | |
v2 = Str2Type<T2>(p, q); | |
return 2; | |
} | |
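/*!
 * A short sketch of ParsePair on a libsvm-style token (hypothetical input):
 * \code
 * const char *tok = "12:0.5";
 * const char *endp;
 * unsigned idx; float val;
 * int n = ParsePair<unsigned, float>(tok, tok + 6, &endp, idx, val);
 * // n == 2, idx == 12, val == 0.5f; for the token "12" alone, n would be 1
 * \endcode
 */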
} // namespace data | |
} // namespace dmlc | |
#endif // DMLC_DATA_STRTONUM_H_ | |
//===== EXPANDED: ../dmlc-core/src/data/strtonum.h ===== | |
namespace dmlc { | |
namespace data { | |
/*! | |
 * \brief LibSVM format parser that parses the input lines
 *  and returns rows of the input data
*/ | |
template <typename IndexType> | |
class LibSVMParser : public TextParserBase<IndexType> { | |
public: | |
explicit LibSVMParser(InputSplit *source, | |
int nthread) | |
: TextParserBase<IndexType>(source, nthread) {} | |
protected: | |
virtual void ParseBlock(char *begin, | |
char *end, | |
RowBlockContainer<IndexType> *out); | |
}; | |
template <typename IndexType> | |
void LibSVMParser<IndexType>:: | |
ParseBlock(char *begin, | |
char *end, | |
RowBlockContainer<IndexType> *out) { | |
out->Clear(); | |
char * lbegin = begin; | |
char * lend = lbegin; | |
while (lbegin != end) { | |
// get line end | |
lend = lbegin + 1; | |
while (lend != end && *lend != '\n' && *lend != '\r') ++lend; | |
// parse label[:weight] | |
const char * p = lbegin; | |
const char * q = NULL; | |
real_t label; | |
real_t weight; | |
int r = ParsePair<real_t, real_t>(p, lend, &q, label, weight); | |
if (r < 1) { | |
// empty line | |
lbegin = lend; | |
continue; | |
} | |
if (r == 2) { | |
// has weight | |
out->weight.push_back(weight); | |
} | |
if (out->label.size() != 0) { | |
out->offset.push_back(out->index.size()); | |
} | |
out->label.push_back(label); | |
// parse feature[:value] | |
p = q; | |
while (p != lend) { | |
IndexType featureId; | |
real_t value; | |
int r = ParsePair<IndexType, real_t>(p, lend, &q, featureId, value); | |
if (r < 1) { | |
p = q; | |
continue; | |
} | |
out->index.push_back(featureId); | |
if (r == 2) { | |
// has value | |
out->value.push_back(value); | |
} | |
p = q; | |
} | |
// next line | |
lbegin = lend; | |
} | |
if (out->label.size() != 0) { | |
out->offset.push_back(out->index.size()); | |
} | |
CHECK(out->label.size() + 1 == out->offset.size()); | |
} | |
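/*!
 * For reference, a sketch of the line format accepted by ParseBlock above
 * (the example lines are hypothetical):
 * \code
 * // label[:weight] index[:value] index[:value] ...
 * // "1 3:0.5 7:2"     -> label 1, features {3: 0.5, 7: 2}
 * // "0:0.1 2:1 4:1"   -> label 0 with instance weight 0.1, features {2: 1, 4: 1}
 * \endcode
 */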
} // namespace data | |
} // namespace dmlc | |
#endif // DMLC_DATA_LIBSVM_PARSER_H_ | |
//===== EXPANDED: ../dmlc-core/src/data/libsvm_parser.h ===== | |
#if DMLC_ENABLE_STD_THREAD | |
namespace dmlc { | |
namespace data { | |
/*! | |
 * \brief row iterator that caches data to disk and then loads it back in segments
* \tparam IndexType the type of index we are using | |
*/ | |
template<typename IndexType> | |
class DiskRowIter: public RowBlockIter<IndexType> { | |
public: | |
// page size 64MB | |
static const size_t kPageSize = 64UL << 20UL; | |
/*! | |
* \brief disk row iterator constructor | |
* \param parser parser used to generate this | |
*/ | |
explicit DiskRowIter(Parser<IndexType> *parser, | |
const char *cache_file, | |
bool reuse_cache) | |
: cache_file_(cache_file), fi_(NULL) { | |
if (reuse_cache) { | |
if (!TryLoadCache()) { | |
this->BuildCache(parser); | |
CHECK(TryLoadCache()) | |
<< "failed to build cache file " << cache_file; | |
} | |
} else { | |
this->BuildCache(parser); | |
CHECK(TryLoadCache()) | |
<< "failed to build cache file " << cache_file; | |
} | |
delete parser; | |
} | |
virtual ~DiskRowIter(void) { | |
iter_.Destroy(); | |
delete fi_; | |
} | |
virtual void BeforeFirst(void) { | |
iter_.BeforeFirst(); | |
} | |
virtual bool Next(void) { | |
if (iter_.Next()) { | |
row_ = iter_.Value().GetBlock(); | |
return true; | |
} else { | |
return false; | |
} | |
} | |
virtual const RowBlock<IndexType> &Value(void) const { | |
return row_; | |
} | |
virtual size_t NumCol(void) const { | |
return num_col_; | |
} | |
private: | |
// file place | |
std::string cache_file_; | |
// input stream | |
SeekStream *fi_; | |
// maximum feature dimension | |
size_t num_col_; | |
// row block to store | |
RowBlock<IndexType> row_; | |
// iterator | |
ThreadedIter<RowBlockContainer<IndexType> > iter_; | |
// load disk cache file | |
inline bool TryLoadCache(void); | |
// build disk cache | |
inline void BuildCache(Parser<IndexType> *parser); | |
}; | |
// build disk cache | |
template<typename IndexType> | |
inline bool DiskRowIter<IndexType>::TryLoadCache(void) { | |
SeekStream *fi = SeekStream::CreateForRead(cache_file_.c_str(), true); | |
if (fi == NULL) return false; | |
this->fi_ = fi; | |
iter_.Init([fi](RowBlockContainer<IndexType> **dptr) { | |
          if (*dptr == NULL) {
*dptr = new RowBlockContainer<IndexType>(); | |
} | |
return (*dptr)->Load(fi); | |
}, | |
[fi]() { fi->Seek(0); }); | |
return true; | |
} | |
template<typename IndexType> | |
inline void DiskRowIter<IndexType>:: | |
BuildCache(Parser<IndexType> *parser) { | |
Stream *fo = Stream::Create(cache_file_.c_str(), "w"); | |
// back end data | |
RowBlockContainer<IndexType> data; | |
num_col_ = 0; | |
double tstart = GetTime(); | |
while (parser->Next()) { | |
data.Push(parser->Value()); | |
double tdiff = GetTime() - tstart; | |
if (data.MemCostBytes() >= kPageSize) { | |
size_t bytes_read = parser->BytesRead(); | |
bytes_read = bytes_read >> 20UL; | |
LOG(INFO) << bytes_read << "MB read," | |
<< bytes_read / tdiff << " MB/sec"; | |
data.Save(fo); | |
data.Clear(); | |
num_col_ = std::max(num_col_, | |
static_cast<size_t>(data.max_index) + 1); | |
} | |
} | |
if (data.Size() != 0) { | |
data.Save(fo); | |
} | |
delete fo; | |
double tdiff = GetTime() - tstart; | |
LOG(INFO) << "finish reading at %g MB/sec" | |
<< (parser->BytesRead() >> 20UL) / tdiff; | |
} | |
} // namespace data | |
} // namespace dmlc | |
#endif // DMLC_ENABLE_STD_THREAD
#endif // DMLC_DATA_DISK_ROW_ITER_H_ | |
//===== EXPANDED: ../dmlc-core/src/data/disk_row_iter.h ===== | |
//===== EXPANDING: ../dmlc-core/src/data/csv_parser.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file csv_parser.h | |
* \brief iterator parser to parse csv format | |
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_DATA_CSV_PARSER_H_ | |
#define DMLC_DATA_CSV_PARSER_H_ | |
namespace dmlc { | |
namespace data { | |
struct CSVParserParam : public Parameter<CSVParserParam> { | |
std::string format; | |
int label_column; | |
// declare parameters | |
DMLC_DECLARE_PARAMETER(CSVParserParam) { | |
DMLC_DECLARE_FIELD(format).set_default("csv") | |
.describe("File format."); | |
DMLC_DECLARE_FIELD(label_column).set_default(-1) | |
.describe("Column index that will put into label."); | |
} | |
}; | |
/*! | |
 * \brief CSVParser, parses a dense csv format.
 *  Currently a minimal implementation: all columns are treated
 *  as real-valued dense data; when the label column is not
 *  specified, the label is assigned 0.
* | |
* This should be extended in future to accept arguments of column types. | |
*/ | |
template <typename IndexType> | |
class CSVParser : public TextParserBase<IndexType> { | |
public: | |
explicit CSVParser(InputSplit *source, | |
const std::map<std::string, std::string>& args, | |
int nthread) | |
: TextParserBase<IndexType>(source, nthread) { | |
param_.Init(args); | |
CHECK_EQ(param_.format, "csv"); | |
} | |
protected: | |
virtual void ParseBlock(char *begin, | |
char *end, | |
RowBlockContainer<IndexType> *out); | |
private: | |
CSVParserParam param_; | |
}; | |
template <typename IndexType> | |
void CSVParser<IndexType>:: | |
ParseBlock(char *begin, | |
char *end, | |
RowBlockContainer<IndexType> *out) { | |
out->Clear(); | |
char * lbegin = begin; | |
char * lend = lbegin; | |
while (lbegin != end) { | |
// get line end | |
lend = lbegin + 1; | |
while (lend != end && *lend != '\n' && *lend != '\r') ++lend; | |
char* p = lbegin; | |
int column_index = 0; | |
IndexType idx = 0; | |
float label = 0.0f; | |
while (p != lend) { | |
char *endptr; | |
float v = strtof(p, &endptr); | |
p = endptr; | |
if (column_index == param_.label_column) { | |
label = v; | |
} else { | |
out->value.push_back(v); | |
out->index.push_back(idx++); | |
} | |
++column_index; | |
      // check the bound before dereferencing to avoid reading past the line end
      while (p != lend && *p != ',') ++p;
if (p != lend) ++p; | |
} | |
// skip empty line | |
    // check the bound before dereferencing to avoid reading past the buffer end
    while (lend != end && (*lend == '\n' || *lend == '\r')) ++lend;
lbegin = lend; | |
out->label.push_back(label); | |
out->offset.push_back(out->index.size()); | |
} | |
CHECK(out->label.size() + 1 == out->offset.size()); | |
} | |
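/*!
 * A brief sketch of how ParseBlock above interprets one csv line, assuming a
 * hypothetical input and label_column == 0:
 * \code
 * // "2,0.5,1.25,3" -> label 2, dense features [0.5, 1.25, 3] at indices 0, 1, 2
 * // with the default label_column == -1, all four values become features and the label is 0
 * \endcode
 */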
} // namespace data | |
} // namespace dmlc | |
#endif // DMLC_DATA_CSV_PARSER_H_ | |
//===== EXPANDED: ../dmlc-core/src/data/csv_parser.h ===== | |
namespace dmlc { | |
/*! \brief namespace for useful input data structure */ | |
namespace data { | |
template<typename IndexType> | |
Parser<IndexType> * | |
CreateLibSVMParser(const std::string& path, | |
const std::map<std::string, std::string>& args, | |
unsigned part_index, | |
unsigned num_parts) { | |
InputSplit* source = InputSplit::Create( | |
path.c_str(), part_index, num_parts, "text"); | |
ParserImpl<IndexType> *parser = new LibSVMParser<IndexType>(source, 2); | |
#if DMLC_ENABLE_STD_THREAD | |
parser = new ThreadedParser<IndexType>(parser); | |
#endif | |
return parser; | |
} | |
template<typename IndexType> | |
Parser<IndexType> * | |
CreateCSVParser(const std::string& path, | |
const std::map<std::string, std::string>& args, | |
unsigned part_index, | |
unsigned num_parts) { | |
InputSplit* source = InputSplit::Create( | |
path.c_str(), part_index, num_parts, "text"); | |
return new CSVParser<IndexType>(source, args, 2); | |
} | |
template<typename IndexType> | |
inline Parser<IndexType> * | |
CreateParser_(const char *uri_, | |
unsigned part_index, | |
unsigned num_parts, | |
const char *type) { | |
std::string ptype = type; | |
io::URISpec spec(uri_, part_index, num_parts); | |
if (ptype == "auto") { | |
if (spec.args.count("format") != 0) { | |
ptype = spec.args.at("format"); | |
} else { | |
ptype = "libsvm"; | |
} | |
} | |
const ParserFactoryReg<IndexType>* e = | |
Registry<ParserFactoryReg<IndexType> >::Get()->Find(ptype); | |
if (e == NULL) { | |
LOG(FATAL) << "Unknown data type " << ptype; | |
} | |
// create parser | |
return (*e->body)(spec.uri, spec.args, part_index, num_parts); | |
} | |
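/*!
 * A hedged sketch of the "auto" dispatch implemented above: a format argument
 * embedded in the URI selects the registered parser ("data.csv?format=csv" is a
 * placeholder path); without a format argument, "auto" falls back to "libsvm".
 * \code
 * Parser<uint32_t> *p =
 *     Parser<uint32_t>::Create("data.csv?format=csv", 0, 1, "auto");
 * \endcode
 */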
template<typename IndexType> | |
inline RowBlockIter<IndexType> * | |
CreateIter_(const char *uri_, | |
unsigned part_index, | |
unsigned num_parts, | |
const char *type) { | |
using namespace std; | |
io::URISpec spec(uri_, part_index, num_parts); | |
Parser<IndexType> *parser = CreateParser_<IndexType> | |
(spec.uri.c_str(), part_index, num_parts, type); | |
if (spec.cache_file.length() != 0) { | |
#if DMLC_ENABLE_STD_THREAD | |
return new DiskRowIter<IndexType>(parser, spec.cache_file.c_str(), true); | |
#else | |
LOG(FATAL) << "compile with c++0x or c++11 to enable cache file"; | |
return NULL; | |
#endif | |
} else { | |
return new BasicRowIter<IndexType>(parser); | |
} | |
} | |
DMLC_REGISTER_PARAMETER(CSVParserParam); | |
} // namespace data | |
// template specialization | |
template<> | |
RowBlockIter<uint32_t> * | |
RowBlockIter<uint32_t>::Create(const char *uri, | |
unsigned part_index, | |
unsigned num_parts, | |
const char *type) { | |
return data::CreateIter_<uint32_t>(uri, part_index, num_parts, type); | |
} | |
template<> | |
RowBlockIter<uint64_t> * | |
RowBlockIter<uint64_t>::Create(const char *uri, | |
unsigned part_index, | |
unsigned num_parts, | |
const char *type) { | |
return data::CreateIter_<uint64_t>(uri, part_index, num_parts, type); | |
} | |
template<> | |
Parser<uint32_t> * | |
Parser<uint32_t>::Create(const char *uri_, | |
unsigned part_index, | |
unsigned num_parts, | |
const char *type) { | |
return data::CreateParser_<uint32_t>(uri_, part_index, num_parts, type); | |
} | |
template<> | |
Parser<uint64_t> * | |
Parser<uint64_t>::Create(const char *uri_, | |
unsigned part_index, | |
unsigned num_parts, | |
const char *type) { | |
return data::CreateParser_<uint64_t>(uri_, part_index, num_parts, type); | |
} | |
// registry | |
DMLC_REGISTRY_ENABLE(ParserFactoryReg<uint32_t>); | |
DMLC_REGISTRY_ENABLE(ParserFactoryReg<uint64_t>); | |
DMLC_REGISTER_DATA_PARSER(uint32_t, libsvm, data::CreateLibSVMParser<uint32_t>); | |
DMLC_REGISTER_DATA_PARSER(uint64_t, libsvm, data::CreateLibSVMParser<uint64_t>); | |
DMLC_REGISTER_DATA_PARSER(uint32_t, csv, data::CreateCSVParser<uint32_t>); | |
} // namespace dmlc | |
//===== EXPANDED: ../dmlc-core/src/data.cc ===== | |
//===== EXPANDING: ../dmlc-core/src/io.cc ===== | |
// Copyright by Contributors | |
//===== EXPANDING: ../dmlc-core/src/io/single_file_split.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file single_file_split.h | |
 * \brief base implementation of line-splitter
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_IO_SINGLE_FILE_SPLIT_H_ | |
#define DMLC_IO_SINGLE_FILE_SPLIT_H_ | |
#if defined(__FreeBSD__) | |
#define fopen64 std::fopen | |
#endif | |
namespace dmlc { | |
namespace io { | |
/*! | |
* \brief line split implementation from single FILE | |
* simply returns lines of files, used for stdin | |
*/ | |
class SingleFileSplit : public InputSplit { | |
public: | |
explicit SingleFileSplit(const char *fname) | |
: use_stdin_(false), buffer_size_(kBufferSize), | |
chunk_begin_(NULL), chunk_end_(NULL) { | |
if (!std::strcmp(fname, "stdin")) { | |
#ifndef DMLC_STRICT_CXX98_ | |
use_stdin_ = true; fp_ = stdin; | |
#endif | |
} | |
if (!use_stdin_) { | |
fp_ = fopen64(fname, "rb"); | |
CHECK(fp_ != NULL) << "SingleFileSplit: fail to open " << fname; | |
} | |
buffer_.resize(kBufferSize); | |
} | |
virtual ~SingleFileSplit(void) { | |
if (!use_stdin_) std::fclose(fp_); | |
} | |
virtual void BeforeFirst(void) { | |
fseek(fp_, 0, SEEK_SET); | |
} | |
virtual void HintChunkSize(size_t chunk_size) { | |
buffer_size_ = std::max(chunk_size, buffer_size_); | |
} | |
virtual size_t GetTotalSize(void) { | |
struct stat buf; | |
fstat(fileno(fp_), &buf); | |
return buf.st_size; | |
} | |
virtual size_t Read(void *ptr, size_t size) { | |
return std::fread(ptr, 1, size, fp_); | |
} | |
virtual void ResetPartition(unsigned part_index, unsigned num_parts) { | |
CHECK(part_index == 0 && num_parts == 1); | |
this->BeforeFirst(); | |
} | |
virtual void Write(const void *ptr, size_t size) { | |
LOG(FATAL) << "InputSplit do not support write"; | |
} | |
virtual bool NextRecord(Blob *out_rec) { | |
if (chunk_begin_ == chunk_end_) { | |
if (!LoadChunk()) return false; | |
} | |
char *next = FindNextRecord(chunk_begin_, | |
chunk_end_); | |
out_rec->dptr = chunk_begin_; | |
out_rec->size = next - chunk_begin_; | |
chunk_begin_ = next; | |
return true; | |
} | |
virtual bool NextChunk(Blob *out_chunk) { | |
if (chunk_begin_ == chunk_end_) { | |
if (!LoadChunk()) return false; | |
} | |
out_chunk->dptr = chunk_begin_; | |
out_chunk->size = chunk_end_ - chunk_begin_; | |
chunk_begin_ = chunk_end_; | |
return true; | |
} | |
inline bool ReadChunk(void *buf, size_t *size) { | |
size_t max_size = *size; | |
if (max_size <= overflow_.length()) { | |
*size = 0; return true; | |
} | |
if (overflow_.length() != 0) { | |
std::memcpy(buf, BeginPtr(overflow_), overflow_.length()); | |
} | |
size_t olen = overflow_.length(); | |
overflow_.resize(0); | |
size_t nread = this->Read(reinterpret_cast<char*>(buf) + olen, | |
max_size - olen); | |
nread += olen; | |
if (nread == 0) return false; | |
if (nread != max_size) { | |
*size = nread; | |
return true; | |
} else { | |
const char *bptr = reinterpret_cast<const char*>(buf); | |
// return the last position where a record starts | |
const char *bend = this->FindLastRecordBegin(bptr, bptr + max_size); | |
*size = bend - bptr; | |
overflow_.resize(max_size - *size); | |
if (overflow_.length() != 0) { | |
std::memcpy(BeginPtr(overflow_), bend, overflow_.length()); | |
} | |
return true; | |
} | |
} | |
protected: | |
inline const char* FindLastRecordBegin(const char *begin, | |
const char *end) { | |
if (begin == end) return begin; | |
for (const char *p = end - 1; p != begin; --p) { | |
if (*p == '\n' || *p == '\r') return p + 1; | |
} | |
return begin; | |
} | |
inline char* FindNextRecord(char *begin, char *end) { | |
char *p; | |
for (p = begin; p != end; ++p) { | |
if (*p == '\n' || *p == '\r') break; | |
} | |
for (; p != end; ++p) { | |
if (*p != '\n' && *p != '\r') return p; | |
} | |
return end; | |
} | |
inline bool LoadChunk(void) { | |
if (buffer_.length() < buffer_size_) { | |
buffer_.resize(buffer_size_); | |
} | |
while (true) { | |
size_t size = buffer_.length(); | |
if (!ReadChunk(BeginPtr(buffer_), &size)) return false; | |
if (size == 0) { | |
buffer_.resize(buffer_.length() * 2); | |
} else { | |
chunk_begin_ = reinterpret_cast<char *>(BeginPtr(buffer_)); | |
chunk_end_ = chunk_begin_ + size; | |
break; | |
} | |
} | |
return true; | |
} | |
private: | |
// buffer size | |
static const size_t kBufferSize = 1 << 18UL; | |
// file | |
std::FILE *fp_; | |
bool use_stdin_; | |
// internal overflow | |
std::string overflow_; | |
// internal buffer | |
std::string buffer_; | |
// internal buffer size | |
size_t buffer_size_; | |
// beginning of chunk | |
char *chunk_begin_; | |
// end of chunk | |
char *chunk_end_; | |
}; | |
} // namespace io | |
} // namespace dmlc | |
#endif // DMLC_IO_SINGLE_FILE_SPLIT_H_ | |
//===== EXPANDED: ../dmlc-core/src/io/single_file_split.h ===== | |
//===== EXPANDING: ../dmlc-core/src/io/cached_input_split.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file cached_input_split.h | |
* \brief InputSplit that reads from an existing InputSplit | |
 * and caches the data to local disk; subsequent iterations
 * read from the locally cached data
* \author Tianqi Chen | |
*/ | |
#ifndef DMLC_IO_CACHED_INPUT_SPLIT_H_ | |
#define DMLC_IO_CACHED_INPUT_SPLIT_H_ | |
// this code depends on c++11 | |
#if DMLC_ENABLE_STD_THREAD | |
namespace dmlc { | |
namespace io { | |
/*! | |
* \brief InputSplit that reads from an existing InputSplit | |
 * and caches the data to local disk; subsequent iterations
 * read from the locally cached data
*/ | |
class CachedInputSplit : public InputSplit { | |
public: | |
/*! | |
* \brief constructor | |
* \param base source input split | |
* \param cache_file the path to cache file | |
   * \param reuse_exist_cache whether to reuse an existing cache file, if any
*/ | |
CachedInputSplit(InputSplitBase *base, | |
const char *cache_file, | |
bool reuse_exist_cache = true) | |
: buffer_size_(InputSplitBase::kBufferSize), | |
cache_file_(cache_file), | |
fo_(NULL), fi_(NULL), | |
base_(base), tmp_chunk_(NULL), | |
iter_preproc_(NULL) { | |
if (reuse_exist_cache) { | |
if (!this->InitCachedIter()) { | |
this->InitPreprocIter(); | |
} | |
} else { | |
this->InitPreprocIter(); | |
} | |
} | |
// destructor | |
virtual ~CachedInputSplit(void) { | |
// NOTE delete can handle NULL ptr | |
// deletion order matters | |
delete iter_preproc_; | |
delete fo_; | |
iter_cached_.Destroy(); | |
delete tmp_chunk_; | |
delete base_; | |
delete fi_; | |
} | |
virtual void BeforeFirst(void) { | |
// if preprocessing did not end | |
// pull data from preprocessing module | |
if (iter_preproc_ != NULL) { | |
if (tmp_chunk_ != NULL) { | |
iter_preproc_->Recycle(&tmp_chunk_); | |
} | |
while (iter_preproc_->Next(&tmp_chunk_)) { | |
iter_preproc_->Recycle(&tmp_chunk_); | |
} | |
// finalize the push out process | |
delete iter_preproc_; | |
delete fo_; | |
iter_preproc_ = NULL; | |
fo_ = NULL; | |
CHECK(this->InitCachedIter()) | |
<< "Failed to initialize CachedIter"; | |
} else { | |
iter_cached_.BeforeFirst(); | |
} | |
if (tmp_chunk_ != NULL) { | |
iter_cached_.Recycle(&tmp_chunk_); | |
} | |
} | |
virtual void ResetPartition(unsigned part_index, unsigned num_parts) { | |
LOG(FATAL) << "ResetPartition is not supported in CachedInputSplit"; | |
} | |
virtual void HintChunkSize(size_t chunk_size) { | |
buffer_size_ = std::max(chunk_size / sizeof(size_t), buffer_size_); | |
} | |
virtual size_t GetTotalSize(void) { | |
return base_->GetTotalSize(); | |
} | |
// implement next record | |
virtual bool NextRecord(Blob *out_rec) { | |
auto *iter = iter_preproc_ != NULL ? iter_preproc_ : &iter_cached_; | |
if (tmp_chunk_ == NULL) { | |
if (!iter->Next(&tmp_chunk_)) return false; | |
} | |
while (!base_->ExtractNextRecord(out_rec, tmp_chunk_)) { | |
iter->Recycle(&tmp_chunk_); | |
if (!iter->Next(&tmp_chunk_)) return false; | |
} | |
return true; | |
} | |
// implement next chunk | |
virtual bool NextChunk(Blob *out_chunk) { | |
auto *iter = iter_preproc_ != NULL ? iter_preproc_ : &iter_cached_; | |
if (tmp_chunk_ == NULL) { | |
if (!iter->Next(&tmp_chunk_)) return false; | |
} | |
while (!base_->ExtractNextChunk(out_chunk, tmp_chunk_)) { | |
iter->Recycle(&tmp_chunk_); | |
if (!iter->Next(&tmp_chunk_)) return false; | |
} | |
return true; | |
} | |
private: | |
/*! \brief internal buffer size */ | |
size_t buffer_size_; | |
/*! \brief cache file path */ | |
std::string cache_file_; | |
/*! \brief output stream to cache file*/ | |
dmlc::Stream *fo_; | |
/*! \brief input stream from cache file */ | |
dmlc::SeekStream *fi_; | |
/*! \brief the place where we get the data */ | |
InputSplitBase *base_; | |
/*! \brief current chunk of data */ | |
InputSplitBase::Chunk *tmp_chunk_; | |
/*! \brief backend thread iterator for preprocessing */ | |
ThreadedIter<InputSplitBase::Chunk> *iter_preproc_; | |
/*! \brief backend thread iterator for cache */ | |
ThreadedIter<InputSplitBase::Chunk> iter_cached_; | |
  /*! \brief initialize the preprocessing iterator */
inline void InitPreprocIter(void); | |
/*! | |
* \brief initialize the cached iterator | |
   * \return whether the file exists and
* initialization is successful | |
*/ | |
inline bool InitCachedIter(void); | |
}; | |
inline void CachedInputSplit::InitPreprocIter(void) {
fo_ = dmlc::Stream::Create(cache_file_.c_str(), "w"); | |
iter_preproc_ = new ThreadedIter<InputSplitBase::Chunk>(); | |
iter_preproc_->set_max_capacity(16); | |
iter_preproc_->Init([this](InputSplitBase::Chunk **dptr) { | |
if (*dptr == NULL) { | |
*dptr = new InputSplitBase::Chunk(buffer_size_); | |
} | |
auto *p = *dptr; | |
if (!p->Load(base_, buffer_size_)) return false; | |
// after loading, save to disk | |
size_t size = p->end - p->begin; | |
fo_->Write(&size, sizeof(size)); | |
fo_->Write(p->begin, size); | |
return true; | |
}); | |
} | |
inline bool CachedInputSplit::InitCachedIter(void) { | |
fi_ = dmlc::SeekStream::CreateForRead(cache_file_.c_str(), true); | |
if (fi_ == NULL) return false; | |
iter_cached_.Init([this](InputSplitBase::Chunk **dptr) { | |
if (*dptr == NULL) { | |
*dptr = new InputSplitBase::Chunk(buffer_size_); | |
} | |
auto *p = *dptr; | |
// read data from cache file | |
size_t size; | |
size_t nread = fi_->Read(&size, sizeof(size)); | |
if (nread == 0) return false; | |
CHECK(nread == sizeof(size)) | |
<< cache_file_ << " has invalid cache file format"; | |
p->data.resize(size / sizeof(size_t) + 1); | |
p->begin = reinterpret_cast<char*>(BeginPtr(p->data)); | |
p->end = p->begin + size; | |
CHECK(fi_->Read(p->begin, size) == size) | |
<< cache_file_ << " has invalid cache file format"; | |
return true; | |
}, | |
[this]() { fi_->Seek(0); }); | |
return true; | |
} | |
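/*!
 * Cache file layout, as a sketch inferred from InitPreprocIter and
 * InitCachedIter above (not a separate specification): the file is simply
 * a sequence of
 * \code
 *   [size_t chunk_size][chunk_size raw bytes of one chunk] ...
 * \endcode
 * repeated until end-of-file, which is why the reader stops as soon as the
 * length prefix can no longer be read.
 */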
} // namespace io | |
} // namespace dmlc | |
#endif  // DMLC_ENABLE_STD_THREAD
#endif // DMLC_IO_CACHED_INPUT_SPLIT_H_ | |
//===== EXPANDED: ../dmlc-core/src/io/cached_input_split.h ===== | |
#if DMLC_USE_HDFS | |
#endif | |
#if DMLC_USE_S3 | |
#endif | |
#if DMLC_USE_AZURE | |
#endif | |
namespace dmlc { | |
namespace io { | |
FileSystem *FileSystem::GetInstance(const URI &path) { | |
if (path.protocol == "file://" || path.protocol.length() == 0) { | |
return LocalFileSystem::GetInstance(); | |
} | |
if (path.protocol == "hdfs://") { | |
#if DMLC_USE_HDFS | |
return HDFSFileSystem::GetInstance(path.host); | |
#else | |
LOG(FATAL) << "Please compile with DMLC_USE_HDFS=1 to use hdfs"; | |
#endif | |
} | |
if (path.protocol == "s3://" || path.protocol == "http://" || path.protocol == "https://") { | |
#if DMLC_USE_S3 | |
return S3FileSystem::GetInstance(); | |
#else | |
LOG(FATAL) << "Please compile with DMLC_USE_S3=1 to use S3"; | |
#endif | |
} | |
if (path.protocol == "azure://") { | |
#if DMLC_USE_AZURE | |
return AzureFileSystem::GetInstance(); | |
#else | |
LOG(FATAL) << "Please compile with DMLC_USE_AZURE=1 to use Azure"; | |
#endif | |
} | |
LOG(FATAL) << "unknown filesystem protocol " + path.protocol; | |
return NULL; | |
} | |
} // namespace io | |
InputSplit* InputSplit::Create(const char *uri_, | |
unsigned part, | |
unsigned nsplit, | |
const char *type) { | |
using namespace std; | |
using namespace dmlc::io; | |
// allow cachefile in format path#cachefile | |
io::URISpec spec(uri_, part, nsplit); | |
if (!strcmp(spec.uri.c_str(), "stdin")) { | |
return new SingleFileSplit(spec.uri.c_str()); | |
} | |
CHECK(part < nsplit) << "invalid input parameter for InputSplit::Create"; | |
URI path(spec.uri.c_str()); | |
InputSplitBase *split = NULL; | |
if (!strcmp(type, "text")) { | |
split = new LineSplitter(FileSystem::GetInstance(path), | |
spec.uri.c_str(), part, nsplit); | |
} else if (!strcmp(type, "recordio")) { | |
split = new RecordIOSplitter(FileSystem::GetInstance(path), | |
spec.uri.c_str(), part, nsplit); | |
} else { | |
LOG(FATAL) << "unknown input split type " << type; | |
} | |
#if DMLC_ENABLE_STD_THREAD | |
if (spec.cache_file.length() == 0) { | |
return split; | |
} else { | |
return new CachedInputSplit(split, spec.cache_file.c_str()); | |
} | |
#else | |
CHECK(spec.cache_file.length() == 0) | |
<< "to enable cached file, compile with c++11"; | |
return split; | |
#endif | |
} | |
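/*!
 * Illustrative usage (a minimal sketch, not part of the library itself):
 * iterating over the records of a text split created by the factory above.
 * The file name is a placeholder; appending "#cachefile" to the URI routes
 * the split through CachedInputSplit as shown above.
 * \code
 *   dmlc::InputSplit *split =
 *       dmlc::InputSplit::Create("data.txt", 0, 1, "text");
 *   dmlc::InputSplit::Blob rec;
 *   while (split->NextRecord(&rec)) {
 *     // rec.dptr points at the record bytes, rec.size is its length
 *   }
 *   delete split;
 * \endcode
 */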
Stream *Stream::Create(const char *uri, | |
const char * const flag, | |
bool try_create) { | |
io::URI path(uri); | |
return io::FileSystem:: | |
GetInstance(path)->Open(path, flag, try_create); | |
} | |
SeekStream *SeekStream::CreateForRead(const char *uri, bool try_create) { | |
io::URI path(uri); | |
return io::FileSystem:: | |
GetInstance(path)->OpenForRead(path, try_create); | |
} | |
} // namespace dmlc | |
//===== EXPANDED: ../dmlc-core/src/io.cc ===== | |
//===== EXPANDING: ../dmlc-core/src/recordio.cc ===== | |
// Copyright by Contributors | |
namespace dmlc { | |
// implementation
void RecordIOWriter::WriteRecord(const void *buf, size_t size) { | |
CHECK(size < (1 << 29U)) | |
<< "RecordIO only accept record less than 2^29 bytes"; | |
const uint32_t umagic = kMagic; | |
// initialize the magic number, in stack | |
const char *magic = reinterpret_cast<const char*>(&umagic); | |
const char *bhead = reinterpret_cast<const char*>(buf); | |
uint32_t len = static_cast<uint32_t>(size); | |
uint32_t lower_align = (len >> 2U) << 2U; | |
uint32_t upper_align = ((len + 3U) >> 2U) << 2U; | |
uint32_t dptr = 0; | |
for (uint32_t i = 0; i < lower_align ; i += 4) { | |
// use char check for alignment safety reason | |
if (bhead[i] == magic[0] && | |
bhead[i + 1] == magic[1] && | |
bhead[i + 2] == magic[2] && | |
bhead[i + 3] == magic[3]) { | |
uint32_t lrec = EncodeLRec(dptr == 0 ? 1U : 2U, | |
i - dptr); | |
stream_->Write(magic, 4); | |
stream_->Write(&lrec, sizeof(lrec)); | |
if (i != dptr) { | |
stream_->Write(bhead + dptr, i - dptr); | |
} | |
dptr = i + 4; | |
except_counter_ += 1; | |
} | |
} | |
uint32_t lrec = EncodeLRec(dptr != 0 ? 3U : 0U, | |
len - dptr); | |
stream_->Write(magic, 4); | |
stream_->Write(&lrec, sizeof(lrec)); | |
if (len != dptr) { | |
stream_->Write(bhead + dptr, len - dptr); | |
} | |
// write padded bytes | |
uint32_t zero = 0; | |
if (upper_align != len) { | |
stream_->Write(&zero, upper_align - len); | |
} | |
} | |
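/*!
 * Framing sketch, inferred from WriteRecord above: each physical record is
 *   [kMagic (4 bytes)][lrec (4 bytes)][payload][zero padding to a 4-byte boundary]
 * where lrec encodes a continuation flag (0 = complete record, 1 = start,
 * 2 = middle, 3 = end) together with the payload length (< 2^29). A payload
 * that itself contains the magic word is split at that word into several
 * physical records, so a reader can always resynchronize on kMagic. For
 * example, a 5-byte record becomes kMagic, lrec with flag 0 and length 5,
 * the 5 payload bytes, and 3 bytes of zero padding.
 */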
bool RecordIOReader::NextRecord(std::string *out_rec) { | |
if (end_of_stream_) return false; | |
const uint32_t kMagic = RecordIOWriter::kMagic; | |
out_rec->clear(); | |
size_t size = 0; | |
while (true) { | |
uint32_t header[2]; | |
size_t nread = stream_->Read(header, sizeof(header)); | |
if (nread == 0) { | |
end_of_stream_ = true; return false; | |
} | |
    CHECK(nread == sizeof(header)) << "Invalid RecordIO File";
CHECK(header[0] == RecordIOWriter::kMagic) << "Invalid RecordIO File"; | |
uint32_t cflag = RecordIOWriter::DecodeFlag(header[1]); | |
uint32_t len = RecordIOWriter::DecodeLength(header[1]); | |
uint32_t upper_align = ((len + 3U) >> 2U) << 2U; | |
out_rec->resize(size + upper_align); | |
if (upper_align != 0) { | |
CHECK(stream_->Read(BeginPtr(*out_rec) + size, upper_align) == upper_align) | |
<< "Invalid RecordIO File upper_align=" << upper_align; | |
} | |
// squeeze back | |
size += len; out_rec->resize(size); | |
if (cflag == 0U || cflag == 3U) break; | |
out_rec->resize(size + sizeof(kMagic)); | |
std::memcpy(BeginPtr(*out_rec) + size, &kMagic, sizeof(kMagic)); | |
size += sizeof(kMagic); | |
} | |
return true; | |
} | |
// helper function to find next recordio head | |
inline char *FindNextRecordIOHead(char *begin, char *end) { | |
CHECK_EQ((reinterpret_cast<size_t>(begin) & 3UL), 0U); | |
CHECK_EQ((reinterpret_cast<size_t>(end) & 3UL), 0U); | |
uint32_t *p = reinterpret_cast<uint32_t *>(begin); | |
uint32_t *pend = reinterpret_cast<uint32_t *>(end); | |
for (; p + 1 < pend; ++p) { | |
if (p[0] == RecordIOWriter::kMagic) { | |
uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]); | |
if (cflag == 0 || cflag == 1) { | |
return reinterpret_cast<char*>(p); | |
} | |
} | |
} | |
return end; | |
} | |
RecordIOChunkReader::RecordIOChunkReader(InputSplit::Blob chunk, | |
unsigned part_index, | |
unsigned num_parts) { | |
size_t nstep = (chunk.size + num_parts - 1) / num_parts; | |
// align | |
nstep = ((nstep + 3UL) >> 2UL) << 2UL; | |
size_t begin = std::min(chunk.size, nstep * part_index); | |
size_t end = std::min(chunk.size, nstep * (part_index + 1)); | |
char *head = reinterpret_cast<char*>(chunk.dptr); | |
pbegin_ = FindNextRecordIOHead(head + begin, head + chunk.size); | |
pend_ = FindNextRecordIOHead(head + end, head + chunk.size); | |
} | |
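/*!
 * Worked example of the partitioning above (a sketch with made-up numbers):
 * for chunk.size = 100 and num_parts = 3, nstep = ceil(100 / 3) = 34,
 * aligned up to 36, so part 1 nominally covers bytes [36, 72). Both ends are
 * then snapped forward to the next record head found by FindNextRecordIOHead,
 * so every record is read by exactly one part.
 */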
bool RecordIOChunkReader::NextRecord(InputSplit::Blob *out_rec) { | |
if (pbegin_ >= pend_) return false; | |
uint32_t *p = reinterpret_cast<uint32_t *>(pbegin_); | |
CHECK(p[0] == RecordIOWriter::kMagic); | |
uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]); | |
uint32_t clen = RecordIOWriter::DecodeLength(p[1]); | |
if (cflag == 0) { | |
// skip header | |
out_rec->dptr = pbegin_ + 2 * sizeof(uint32_t); | |
// move pbegin | |
pbegin_ += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U); | |
CHECK(pbegin_ <= pend_) << "Invalid RecordIO Format"; | |
out_rec->size = clen; | |
return true; | |
} else { | |
const uint32_t kMagic = RecordIOWriter::kMagic; | |
// abnormal path, read into string | |
CHECK(cflag == 1U) << "Invalid RecordIO Format"; | |
temp_.resize(0); | |
while (true) { | |
CHECK(pbegin_ + 2 * sizeof(uint32_t) <= pend_); | |
p = reinterpret_cast<uint32_t *>(pbegin_); | |
CHECK(p[0] == RecordIOWriter::kMagic); | |
cflag = RecordIOWriter::DecodeFlag(p[1]); | |
clen = RecordIOWriter::DecodeLength(p[1]); | |
size_t tsize = temp_.length(); | |
temp_.resize(tsize + clen); | |
if (clen != 0) { | |
std::memcpy(BeginPtr(temp_) + tsize, | |
pbegin_ + 2 * sizeof(uint32_t), | |
clen); | |
tsize += clen; | |
} | |
pbegin_ += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U); | |
if (cflag == 3U) break; | |
temp_.resize(tsize + sizeof(kMagic)); | |
std::memcpy(BeginPtr(temp_) + tsize, &kMagic, sizeof(kMagic)); | |
} | |
out_rec->dptr = BeginPtr(temp_); | |
out_rec->size = temp_.length(); | |
return true; | |
} | |
} | |
} // namespace dmlc | |
//===== EXPANDED: ../dmlc-core/src/recordio.cc ===== | |
//===== EXPANDED: dmlc-minimum0.cc ===== | |
//===== EXPANDING: nnvm.cc ===== | |
//===== EXPANDING: ../mshadow/mshadow/tensor.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file tensor.h | |
* \brief header file of tensor data structure and functions | |
 * This lib requires explicit memory allocation and de-allocation;
 * all the data structures Tensor<cpu,1>, Tensor<gpu,1> are like handles (pointers),
 * and no memory allocation happens during calculation.
* | |
* For STL style tensor, see tensor_container.h | |
* \author Bing Xu, Tianqi Chen | |
*/ | |
#ifndef MSHADOW_TENSOR_H_ | |
#define MSHADOW_TENSOR_H_ | |
//===== EXPANDING: ../mshadow/mshadow/base.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file base.h | |
* \brief definitions of base types, operators, macros functions | |
* | |
* \author Bing Xu, Tianqi Chen | |
*/ | |
#ifndef MSHADOW_BASE_H_ | |
#define MSHADOW_BASE_H_ | |
#ifdef _MSC_VER | |
#ifndef _CRT_SECURE_NO_WARNINGS | |
#define _CRT_SECURE_NO_WARNINGS | |
#endif | |
#ifndef _CRT_SECURE_NO_DEPRECATE | |
#define _CRT_SECURE_NO_DEPRECATE | |
#endif | |
#define NOMINMAX | |
#endif | |
#ifdef _MSC_VER | |
//! \cond Doxygen_Suppress | |
typedef signed char int8_t; | |
typedef __int16 int16_t; | |
typedef __int32 int32_t; | |
typedef __int64 int64_t; | |
typedef unsigned char uint8_t; | |
typedef unsigned __int16 uint16_t; | |
typedef unsigned __int32 uint32_t; | |
typedef unsigned __int64 uint64_t; | |
//! \endcond | |
#else | |
#endif | |
// macro definitions
/*! | |
 * \brief if this macro is defined to be 1,
 * mshadow should compile without any other libs
*/ | |
#ifndef MSHADOW_STAND_ALONE | |
#define MSHADOW_STAND_ALONE 0 | |
#endif | |
/*! \brief whether do padding during allocation */ | |
#ifndef MSHADOW_ALLOC_PAD | |
#define MSHADOW_ALLOC_PAD true | |
#endif | |
/*! | |
* \brief | |
 * the x dimension of the data must be bigger than pad_size * ratio to be allocated padded memory,
 * otherwise tight allocation is used.
 * For example, if pad_ratio=2 and the GPU memory alignment size is 32,
 * then we will only allocate padded memory if the x dimension > 64;
 * set it to 0 and we will always allocate padded memory.
*/ | |
#ifndef MSHADOW_MIN_PAD_RATIO | |
#define MSHADOW_MIN_PAD_RATIO 2 | |
#endif | |
#if MSHADOW_STAND_ALONE | |
#define MSHADOW_USE_CBLAS 0 | |
#define MSHADOW_USE_MKL 0 | |
#define MSHADOW_USE_CUDA 0 | |
#endif | |
/*! | |
* \brief force user to use GPU stream during computation | |
 * an error will be raised when the default stream NULL is used
*/ | |
#ifndef MSHADOW_FORCE_STREAM | |
#define MSHADOW_FORCE_STREAM 1 | |
#endif | |
/*! \brief use CBLAS for BLAS */
#ifndef MSHADOW_USE_CBLAS | |
#define MSHADOW_USE_CBLAS 0 | |
#endif | |
/*! \brief use MKL for BLAS */ | |
#ifndef MSHADOW_USE_MKL | |
#define MSHADOW_USE_MKL 1 | |
#endif | |
/*! | |
* \brief use CUDA support, must ensure that the cuda include path is correct, | |
* or directly compile using nvcc | |
*/ | |
#ifndef MSHADOW_USE_CUDA | |
#define MSHADOW_USE_CUDA 1 | |
#endif | |
/*! | |
* \brief use CUDNN support, must ensure that the cudnn include path is correct | |
*/ | |
#ifndef MSHADOW_USE_CUDNN | |
#define MSHADOW_USE_CUDNN 0 | |
#endif | |
/*! | |
 * \brief CUDAARCH seems to be deprecated in future NVCC;
 * set this to 1 if you want to use a CUDA version older than 2.0
*/ | |
#ifndef MSHADOW_OLD_CUDA | |
#define MSHADOW_OLD_CUDA 0 | |
#endif | |
/*! | |
* \brief macro to decide existence of c++11 compiler | |
*/ | |
#ifndef MSHADOW_IN_CXX11 | |
#define MSHADOW_IN_CXX11 (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\ | |
__cplusplus >= 201103L || defined(_MSC_VER)) | |
#endif | |
/*! \brief whether use SSE */ | |
#ifndef MSHADOW_USE_SSE | |
#define MSHADOW_USE_SSE 1 | |
#endif | |
/*! \brief whether use NVML to get dynamic info */ | |
#ifndef MSHADOW_USE_NVML | |
#define MSHADOW_USE_NVML 0 | |
#endif | |
// SSE conflicts with cudacc
#ifdef __CUDACC__ | |
#undef MSHADOW_USE_SSE | |
#define MSHADOW_USE_SSE 0 | |
#endif | |
#if MSHADOW_USE_CBLAS | |
extern "C" { | |
} | |
#elif MSHADOW_USE_MKL | |
#endif | |
#if MSHADOW_USE_CUDA | |
#endif | |
#if MSHADOW_USE_CUDNN == 1 | |
#endif | |
#if MSHADOW_USE_NVML | |
#endif | |
// -------------------------------- | |
// MSHADOW_XINLINE is used for inlining template code for both CUDA and CPU code | |
#ifdef MSHADOW_XINLINE | |
#error "MSHADOW_XINLINE must not be defined" | |
#endif | |
#ifdef _MSC_VER | |
#define MSHADOW_FORCE_INLINE __forceinline | |
#pragma warning(disable : 4068) | |
#else | |
#define MSHADOW_FORCE_INLINE inline __attribute__((always_inline)) | |
#endif | |
#ifdef __CUDACC__ | |
#define MSHADOW_XINLINE MSHADOW_FORCE_INLINE __device__ __host__ | |
#else | |
#define MSHADOW_XINLINE MSHADOW_FORCE_INLINE | |
#endif | |
/*! \brief cpu force inline */ | |
#define MSHADOW_CINLINE MSHADOW_FORCE_INLINE | |
#if defined(__GXX_EXPERIMENTAL_CXX0X) ||\ | |
defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L | |
#define MSHADOW_CONSTEXPR constexpr | |
#else | |
#define MSHADOW_CONSTEXPR const | |
#endif | |
/*! | |
* \brief default data type for tensor string | |
* in code release, change it to default_real_t | |
* during development, change it to empty string so that missing | |
* template arguments can be detected | |
*/ | |
#ifndef MSHADOW_DEFAULT_DTYPE | |
#define MSHADOW_DEFAULT_DTYPE = default_real_t | |
#endif | |
/*! | |
 * \brief DMLC macro for logging
*/ | |
#ifndef MSHADOW_USE_GLOG | |
#define MSHADOW_USE_GLOG DMLC_USE_GLOG | |
#endif // MSHADOW_USE_GLOG | |
#if DMLC_USE_CXX11 | |
#define MSHADOW_THROW_EXCEPTION noexcept(false) | |
#define MSHADOW_NO_EXCEPTION noexcept(true) | |
#else | |
#define MSHADOW_THROW_EXCEPTION | |
#define MSHADOW_NO_EXCEPTION | |
#endif | |
/*! | |
* \brief Protected cuda call in mshadow | |
* \param func Expression to call. | |
* It checks for CUDA errors after invocation of the expression. | |
*/ | |
#define MSHADOW_CUDA_CALL(func) \ | |
{ \ | |
cudaError_t e = (func); \ | |
if (e == cudaErrorCudartUnloading) { \ | |
throw dmlc::Error(cudaGetErrorString(e)); \ | |
} \ | |
CHECK(e == cudaSuccess) \ | |
<< "CUDA: " << cudaGetErrorString(e); \ | |
} | |
/*! | |
* \brief Run function and catch error, log unknown error. | |
* \param func Expression to call. | |
*/ | |
#define MSHADOW_CATCH_ERROR(func) \ | |
{ \ | |
try { \ | |
(func); \ | |
} catch (const dmlc::Error &e) { \ | |
std::string what = e.what(); \ | |
if (what.find("driver shutting down") == std::string::npos) { \ | |
LOG(ERROR) << "Ignore CUDA Error " << what; \ | |
} \ | |
} \ | |
} | |
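/*!
 * Illustrative usage of the two macros above (a sketch; cudaMalloc is the
 * usual CUDA runtime call and ReleaseResources is a placeholder):
 * \code
 *   void *dptr;
 *   MSHADOW_CUDA_CALL(cudaMalloc(&dptr, 256));  // checks the returned cudaError_t
 *   MSHADOW_CATCH_ERROR(ReleaseResources());    // logs errors except driver shutdown
 * \endcode
 */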
//===== EXPANDING: ../mshadow/mshadow/half.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file half.h | |
* \brief definition of half (float16) type. | |
* | |
* \author Junyuan Xie | |
*/ | |
#ifndef MSHADOW_HALF_H_ | |
#define MSHADOW_HALF_H_ | |
#if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050) | |
#define MSHADOW_CUDA_HALF 1 | |
#if defined(__CUDA_ARCH__) | |
/*! \brief __half2float_warp */ | |
__host__ __device__ float __half2float_warp(const volatile __half& h) { /* NOLINT(*) */ | |
__half val; | |
val.x = h.x; | |
return __half2float(val); | |
} | |
#endif | |
#else | |
#define MSHADOW_CUDA_HALF 0 | |
#endif | |
/*! \brief namespace for mshadow */ | |
namespace mshadow { | |
/*! \brief namespace for host/device portable half-precision floats */
namespace half { | |
#define MSHADOW_HALF_OPERATOR(RTYPE, OP) \ | |
MSHADOW_XINLINE RTYPE operator OP (half_t a, half_t b) { \ | |
return RTYPE(float(a) OP float(b)); /* NOLINT(*) */ \ | |
} \ | |
template<typename T> \ | |
MSHADOW_XINLINE RTYPE operator OP (half_t a, T b) { \ | |
return RTYPE(float(a) OP float(b)); /* NOLINT(*) */ \ | |
} \ | |
template<typename T> \ | |
MSHADOW_XINLINE RTYPE operator OP (T a, half_t b) { \ | |
return RTYPE(float(a) OP float(b)); /* NOLINT(*) */ \ | |
} | |
#define MSHADOW_HALF_ASSIGNOP(AOP, OP) \ | |
template<typename T> \ | |
MSHADOW_XINLINE half_t operator AOP (const T& a) { \ | |
return *this = half_t(float(*this) OP float(a)); /* NOLINT(*)*/ \ | |
} \ | |
template<typename T> \ | |
MSHADOW_XINLINE half_t operator AOP (const volatile T& a) volatile { \ | |
return *this = half_t(float(*this) OP float(a)); /* NOLINT(*)*/ \ | |
} | |
#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__)) | |
#define MSHADOW_HALF_CONVERSIONOP(T) \ | |
MSHADOW_XINLINE operator T() const { \ | |
return T(__half2float(cuhalf_)); /* NOLINT(*)*/ \ | |
} \ | |
MSHADOW_XINLINE operator T() const volatile { \ | |
return T(__half2float_warp(cuhalf_)); /* NOLINT(*)*/ \ | |
} | |
#else | |
#define MSHADOW_HALF_CONVERSIONOP(T) \ | |
MSHADOW_XINLINE operator T() const { \ | |
return T(half2float(half_)); /* NOLINT(*)*/ \ | |
} \ | |
MSHADOW_XINLINE operator T() const volatile { \ | |
return T(half2float(half_)); /* NOLINT(*)*/ \ | |
} | |
#endif // (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__)) | |
class half_t { | |
public: | |
union { | |
uint16_t half_; | |
#if MSHADOW_CUDA_HALF | |
__half cuhalf_; | |
#endif // MSHADOW_CUDA_HALF | |
}; | |
static MSHADOW_XINLINE half_t Binary(uint16_t value) { | |
half_t res; | |
res.half_ = value; | |
return res; | |
} | |
MSHADOW_XINLINE half_t() {} | |
#if MSHADOW_CUDA_HALF | |
MSHADOW_XINLINE explicit half_t(const __half& value) { | |
cuhalf_ = value; | |
} | |
#endif // MSHADOW_CUDA_HALF | |
MSHADOW_XINLINE half_t(const float& value) { constructor(value); } | |
MSHADOW_XINLINE explicit half_t(const double& value) { constructor(value); } | |
MSHADOW_XINLINE explicit half_t(const uint8_t& value) { constructor(value); } | |
MSHADOW_XINLINE explicit half_t(const int32_t& value) { constructor(value); } | |
MSHADOW_XINLINE explicit half_t(const uint32_t& value) { constructor(value); } | |
MSHADOW_XINLINE explicit half_t(const int64_t& value) { constructor(value); } | |
MSHADOW_XINLINE explicit half_t(const uint64_t& value) { constructor(value); } | |
MSHADOW_HALF_CONVERSIONOP(float) | |
MSHADOW_HALF_ASSIGNOP(+=, +) | |
MSHADOW_HALF_ASSIGNOP(-=, -) | |
MSHADOW_HALF_ASSIGNOP(*=, *) | |
MSHADOW_HALF_ASSIGNOP(/=, /) | |
MSHADOW_XINLINE half_t operator+() { | |
return *this; | |
} | |
MSHADOW_XINLINE half_t operator-() { | |
return half_t(-float(*this)); // NOLINT(*) | |
} | |
MSHADOW_XINLINE half_t operator=(const half_t& a) { | |
half_ = a.half_; | |
return a; | |
} | |
template<typename T> | |
MSHADOW_XINLINE half_t operator=(const T& a) { | |
return *this = half_t(a); /* NOLINT(*)*/ | |
} | |
MSHADOW_XINLINE half_t operator=(const half_t& a) volatile { | |
half_ = a.half_; | |
return a; | |
} | |
template<typename T> | |
MSHADOW_XINLINE half_t operator=(const T& a) volatile { | |
return *this = half_t(a); /* NOLINT(*)*/ | |
} | |
private: | |
union Bits { | |
float f; | |
int32_t si; | |
uint32_t ui; | |
}; | |
static int const shift = 13; | |
static int const shiftSign = 16; | |
static int32_t const infN = 0x7F800000; // flt32 infinity | |
static int32_t const maxN = 0x477FE000; // max flt16 normal as a flt32 | |
static int32_t const minN = 0x38800000; // min flt16 normal as a flt32 | |
static int32_t const signN = 0x80000000; // flt32 sign bit | |
static int32_t const infC = infN >> shift; | |
static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32 | |
static int32_t const maxC = maxN >> shift; | |
static int32_t const minC = minN >> shift; | |
static int32_t const signC = signN >> shiftSign; // flt16 sign bit | |
static int32_t const mulN = 0x52000000; // (1 << 23) / minN | |
static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift)) | |
static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted | |
static int32_t const norC = 0x00400; // min flt32 normal down shifted | |
static int32_t const maxD = infC - maxC - 1; | |
static int32_t const minD = minC - subC - 1; | |
MSHADOW_XINLINE uint16_t float2half(const float& value) const { | |
Bits v, s; | |
v.f = value; | |
uint32_t sign = v.si & signN; | |
v.si ^= sign; | |
sign >>= shiftSign; // logical shift | |
s.si = mulN; | |
s.si = s.f * v.f; // correct subnormals | |
v.si ^= (s.si ^ v.si) & -(minN > v.si); | |
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN)); | |
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN)); | |
v.ui >>= shift; // logical shift | |
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC); | |
v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC); | |
return v.ui | sign; | |
} | |
MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile { // NOLINT (*) | |
Bits v, s; | |
v.f = value; | |
uint32_t sign = v.si & signN; | |
v.si ^= sign; | |
sign >>= shiftSign; // logical shift | |
s.si = mulN; | |
s.si = s.f * v.f; // correct subnormals | |
v.si ^= (s.si ^ v.si) & -(minN > v.si); | |
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN)); | |
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN)); | |
v.ui >>= shift; // logical shift | |
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC); | |
v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC); | |
return v.ui | sign; | |
} | |
MSHADOW_XINLINE float half2float(const uint16_t& value) const { | |
Bits v; | |
v.ui = value; | |
int32_t sign = v.si & signC; | |
v.si ^= sign; | |
sign <<= shiftSign; | |
v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); | |
v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); | |
Bits s; | |
s.si = mulC; | |
s.f *= v.si; | |
int32_t mask = -(norC > v.si); | |
v.si <<= shift; | |
v.si ^= (s.si ^ v.si) & mask; | |
v.si |= sign; | |
return v.f; | |
} | |
MSHADOW_XINLINE float half2float(const volatile uint16_t& value) const volatile { // NOLINT(*) | |
Bits v; | |
v.ui = value; | |
int32_t sign = v.si & signC; | |
v.si ^= sign; | |
sign <<= shiftSign; | |
v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); | |
v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); | |
Bits s; | |
s.si = mulC; | |
s.f *= v.si; | |
int32_t mask = -(norC > v.si); | |
v.si <<= shift; | |
v.si ^= (s.si ^ v.si) & mask; | |
v.si |= sign; | |
return v.f; | |
} | |
template<typename T> | |
MSHADOW_XINLINE void constructor(const T& value) { | |
#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__)) | |
cuhalf_ = __float2half(float(value)); // NOLINT(*) | |
#else | |
half_ = float2half(float(value)); // NOLINT(*) | |
#endif // (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__)) | |
} | |
}; | |
/*! \brief overloaded + operator for half_t */ | |
MSHADOW_HALF_OPERATOR(half_t, +) | |
/*! \brief overloaded - operator for half_t */ | |
MSHADOW_HALF_OPERATOR(half_t, -) | |
/*! \brief overloaded * operator for half_t */ | |
MSHADOW_HALF_OPERATOR(half_t, *) | |
/*! \brief overloaded / operator for half_t */ | |
MSHADOW_HALF_OPERATOR(half_t, /) | |
/*! \brief overloaded > operator for half_t */ | |
MSHADOW_HALF_OPERATOR(bool, >) | |
/*! \brief overloaded < operator for half_t */ | |
MSHADOW_HALF_OPERATOR(bool, <) | |
/*! \brief overloaded >= operator for half_t */ | |
MSHADOW_HALF_OPERATOR(bool, >=) | |
/*! \brief overloaded <= operator for half_t */ | |
MSHADOW_HALF_OPERATOR(bool, <=) | |
#define MSHADOW_HALF_MIN mshadow::half::half_t::Binary(0x0400); | |
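/*!
 * Illustrative usage (a sketch): half_t converts implicitly to and from
 * float, so mixed arithmetic is carried out in float and rounded back to
 * 16 bits on assignment.
 * \code
 *   mshadow::half::half_t h(1.5f);
 *   float f = h * 2.0f;  // computed as float, f == 3.0f
 *   h += 0.25;           // stored back as a 16-bit value
 * \endcode
 */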
} // namespace half | |
} // namespace mshadow | |
#endif // MSHADOW_HALF_H_ | |
//===== EXPANDED: ../mshadow/mshadow/half.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/logging.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file logging.h | |
* \brief defines logging macros of dmlc | |
* allows use of GLOG, fall back to internal | |
* implementation when disabled | |
*/ | |
#ifndef MSHADOW_LOGGING_H_ | |
#define MSHADOW_LOGGING_H_ | |
#ifndef DMLC_LOGGING_H_ | |
#define DMLC_LOGGING_H_ | |
namespace dmlc { | |
/*! \brief taken from DMLC directly */ | |
/*! | |
* \brief exception class that will be thrown by | |
* default logger if DMLC_LOG_FATAL_THROW == 1 | |
*/ | |
struct Error : public std::runtime_error { | |
/*! | |
* \brief constructor | |
* \param s the error message | |
*/ | |
explicit Error(const std::string &s) : std::runtime_error(s) {} | |
}; | |
} // namespace dmlc | |
#if defined(_MSC_VER) && _MSC_VER < 1900 | |
#define noexcept(a) | |
#endif | |
#if DMLC_USE_GLOG | |
namespace dmlc { | |
/*! \brief taken from DMLC directly */ | |
inline void InitLogging(const char* argv0) { | |
google::InitGoogleLogging(argv0); | |
} | |
} // namespace dmlc | |
#else | |
// use a light version of glog | |
#if defined(_MSC_VER) | |
#pragma warning(disable : 4722) | |
#endif | |
namespace dmlc { | |
inline void InitLogging(const char* argv0) { | |
// DO NOTHING | |
} | |
// Always-on checking | |
#define CHECK(x) \ | |
if (!(x)) \ | |
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check " \ | |
"failed: " #x << ' ' | |
#define CHECK_LT(x, y) CHECK((x) < (y)) | |
#define CHECK_GT(x, y) CHECK((x) > (y)) | |
#define CHECK_LE(x, y) CHECK((x) <= (y)) | |
#define CHECK_GE(x, y) CHECK((x) >= (y)) | |
#define CHECK_EQ(x, y) CHECK((x) == (y)) | |
#define CHECK_NE(x, y) CHECK((x) != (y)) | |
#define CHECK_NOTNULL(x) \ | |
((x) == NULL ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', (x) : (x)) // NOLINT(*) | |
// Debug-only checking. | |
#ifdef NDEBUG | |
#define DCHECK(x) \ | |
while (false) CHECK(x) | |
#define DCHECK_LT(x, y) \ | |
while (false) CHECK((x) < (y)) | |
#define DCHECK_GT(x, y) \ | |
while (false) CHECK((x) > (y)) | |
#define DCHECK_LE(x, y) \ | |
while (false) CHECK((x) <= (y)) | |
#define DCHECK_GE(x, y) \ | |
while (false) CHECK((x) >= (y)) | |
#define DCHECK_EQ(x, y) \ | |
while (false) CHECK((x) == (y)) | |
#define DCHECK_NE(x, y) \ | |
while (false) CHECK((x) != (y)) | |
#else | |
#define DCHECK(x) CHECK(x) | |
#define DCHECK_LT(x, y) CHECK((x) < (y)) | |
#define DCHECK_GT(x, y) CHECK((x) > (y)) | |
#define DCHECK_LE(x, y) CHECK((x) <= (y)) | |
#define DCHECK_GE(x, y) CHECK((x) >= (y)) | |
#define DCHECK_EQ(x, y) CHECK((x) == (y)) | |
#define DCHECK_NE(x, y) CHECK((x) != (y)) | |
#endif // NDEBUG | |
#define LOG_INFO dmlc::LogMessage(__FILE__, __LINE__) | |
#define LOG_ERROR LOG_INFO | |
#define LOG_WARNING LOG_INFO | |
#define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__) | |
#define LOG_QFATAL LOG_FATAL | |
// Poor man's version of VLOG
#define VLOG(x) LOG_INFO.stream() | |
#define LOG(severity) LOG_##severity.stream() | |
#define LG LOG_INFO.stream() | |
#define LOG_IF(severity, condition) \ | |
!(condition) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) | |
#ifdef NDEBUG | |
#define LOG_DFATAL LOG_ERROR | |
#define DFATAL ERROR | |
#define DLOG(severity) true ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) | |
#define DLOG_IF(severity, condition) \ | |
(true || !(condition)) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) | |
#else | |
#define LOG_DFATAL LOG_FATAL | |
#define DFATAL FATAL | |
#define DLOG(severity) LOG(severity) | |
#define DLOG_IF(severity, condition) LOG_IF(severity, condition) | |
#endif | |
// Poor man's version of LOG_EVERY_N
#define LOG_EVERY_N(severity, n) LOG(severity) | |
class DateLogger { | |
public: | |
DateLogger() { | |
#if defined(_MSC_VER) | |
_tzset(); | |
#endif | |
} | |
const char* HumanDate() { | |
#if defined(_MSC_VER) | |
_strtime_s(buffer_, sizeof(buffer_)); | |
#else | |
time_t time_value = time(NULL); | |
struct tm now; | |
localtime_r(&time_value, &now); | |
snprintf(buffer_, sizeof(buffer_), "%02d:%02d:%02d", now.tm_hour, | |
now.tm_min, now.tm_sec); | |
#endif | |
return buffer_; | |
} | |
private: | |
char buffer_[9]; | |
}; | |
class LogMessage { | |
public: | |
LogMessage(const char* file, int line) | |
: | |
#ifdef __ANDROID__ | |
log_stream_(std::cout) | |
#else | |
log_stream_(std::cerr) | |
#endif | |
{ | |
log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" | |
<< line << ": "; | |
} | |
~LogMessage() { log_stream_ << "\n"; } | |
std::ostream& stream() { return log_stream_; } | |
protected: | |
std::ostream& log_stream_; | |
private: | |
DateLogger pretty_date_; | |
LogMessage(const LogMessage&); | |
void operator=(const LogMessage&); | |
}; | |
#if DMLC_LOG_FATAL_THROW == 0 | |
class LogMessageFatal : public LogMessage { | |
public: | |
LogMessageFatal(const char* file, int line) : LogMessage(file, line) {} | |
~LogMessageFatal() { | |
log_stream_ << "\n"; | |
abort(); | |
} | |
private: | |
LogMessageFatal(const LogMessageFatal&); | |
void operator=(const LogMessageFatal&); | |
}; | |
#else | |
class LogMessageFatal { | |
public: | |
LogMessageFatal(const char* file, int line) { | |
log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" | |
<< line << ": "; | |
} | |
std::ostringstream &stream() { return log_stream_; } | |
~LogMessageFatal() DMLC_THROW_EXCEPTION { | |
// throwing out of destructor is evil | |
// hopefully we can do it here | |
throw Error(log_stream_.str()); | |
} | |
private: | |
std::ostringstream log_stream_; | |
DateLogger pretty_date_; | |
LogMessageFatal(const LogMessageFatal&); | |
void operator=(const LogMessageFatal&); | |
}; | |
#endif | |
// This class is used to explicitly ignore values in the conditional | |
// logging macros. This avoids compiler warnings like "value computed | |
// is not used" and "statement has no effect". | |
class LogMessageVoidify { | |
public: | |
LogMessageVoidify() {} | |
// This has to be an operator with a precedence lower than << but | |
// higher than "?:". See its usage. | |
void operator&(std::ostream&) {} | |
}; | |
} // namespace dmlc | |
#endif | |
#endif // DMLC_LOGGING_H_ | |
#endif // MSHADOW_LOGGING_H_ | |
//===== EXPANDED: ../mshadow/mshadow/logging.h ===== | |
/*! \brief namespace for mshadow */ | |
namespace mshadow { | |
/*! \brief buffer size for each random number generator */ | |
const unsigned kRandBufferSize = 1000000; | |
/*! \brief pi */ | |
const float kPi = 3.1415926f; | |
/*! \brief type that will be used for index */ | |
typedef unsigned index_t; | |
#ifdef _WIN32 | |
/*! \brief openmp index for windows */ | |
typedef int64_t openmp_index_t; | |
#else | |
/*! \brief openmp index for linux */ | |
typedef index_t openmp_index_t; | |
#endif | |
/*! \brief float point type that will be used in default by mshadow */ | |
typedef float default_real_t; | |
/*! \brief data type flag */ | |
enum TypeFlag { | |
kFloat32, | |
kFloat64, | |
kFloat16, | |
kUint8, | |
kInt32 | |
}; | |
template<typename DType> | |
struct DataType; | |
template<> | |
struct DataType<float> { | |
static const int kFlag = kFloat32; | |
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1) | |
static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_FLOAT; | |
typedef float ScaleType; | |
#endif | |
}; | |
template<> | |
struct DataType<double> { | |
static const int kFlag = kFloat64; | |
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1) | |
static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_DOUBLE; | |
typedef double ScaleType; | |
#endif | |
}; | |
template<> | |
struct DataType<half::half_t> { | |
static const int kFlag = kFloat16; | |
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1) | |
static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_HALF; | |
typedef float ScaleType; | |
#endif | |
}; | |
template<> | |
struct DataType<uint8_t> { | |
static const int kFlag = kUint8; | |
}; | |
template<> | |
struct DataType<int32_t> { | |
static const int kFlag = kInt32; | |
}; | |
/*! \brief type enum value for default real type */ | |
const int default_type_flag = DataType<default_real_t>::kFlag; | |
/*! layout flag */ | |
enum LayoutFlag { | |
kNCHW = 0, | |
kNHWC, | |
kCHWN, | |
kNCDHW = 1 << 5, | |
kNDHWC, | |
kCDHWN | |
}; | |
template<int layout> | |
struct LayoutType; | |
template<> | |
struct LayoutType<kNCHW> { | |
static const index_t kNdim = 4; | |
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) | |
static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW; | |
#else | |
static const int kCudnnFlag = -1; | |
#endif | |
}; | |
template<> | |
struct LayoutType<kNHWC> { | |
static const index_t kNdim = 4; | |
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) | |
static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC; | |
#else | |
static const int kCudnnFlag = -1; | |
#endif | |
}; | |
/*! \brief default layout for 4d tensor */ | |
const int default_layout = kNCHW; | |
template<> | |
struct LayoutType<kNCDHW> { | |
static const index_t kNdim = 5; | |
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) | |
static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW; | |
#else | |
static const int kCudnnFlag = -1; | |
#endif | |
}; | |
template<> | |
struct LayoutType<kNDHWC> { | |
static const index_t kNdim = 5; | |
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) | |
static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC; | |
#else | |
static const int kCudnnFlag = -1; | |
#endif | |
}; | |
/*! \brief default layout for 5d tensor */ | |
const int default_layout_5d = kNCDHW; | |
/*! \brief namespace for operators */ | |
namespace op { | |
// binary operator | |
/*! \brief mul operator */ | |
struct mul{ | |
/*! \brief map a, b to result using defined operation */ | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return a * b; | |
} | |
}; | |
/*! \brief plus operator */ | |
struct plus { | |
/*! \brief map a, b to result using defined operation */ | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return a + b; | |
} | |
}; | |
/*! \brief minus operator */ | |
struct minus { | |
/*! \brief map a, b to result using defined operation */ | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return a - b; | |
} | |
}; | |
/*! \brief divide operator */ | |
struct div { | |
/*! \brief map a, b to result using defined operation */ | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return a / b; | |
} | |
}; | |
/*! \brief get rhs */ | |
struct right { | |
/*! \brief map a, b to result using defined operation */ | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return b; | |
} | |
}; | |
// unary operator/ function: example | |
// these operators can be defined by the user,
// in the same style as binary and unary operators
// to use, simply write F<op::identity>( src ) | |
/*! \brief identity function that maps a real number to itself */
struct identity{ | |
/*! \brief map a to result using defined operation */ | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return a; | |
} | |
}; | |
} // namespace op | |
/*! \brief namespace for savers */ | |
namespace sv { | |
/*! \brief save to saver: = */ | |
struct saveto { | |
/*! \brief save b to a using save method */ | |
template<typename DType> | |
MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*) | |
a = b; | |
} | |
/*! \brief helper constant to use BLAS, alpha */ | |
inline static default_real_t AlphaBLAS(void) { return 1.0f; } | |
/*! \brief helper constant to use BLAS, beta */ | |
inline static default_real_t BetaBLAS(void) { return 0.0f; } | |
/*! \brief corresponding binary operator type */ | |
typedef op::right OPType; | |
}; | |
/*! \brief save to saver: += */ | |
struct plusto { | |
/*! \brief save b to a using save method */ | |
template<typename DType> | |
MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*) | |
a += b; | |
} | |
/*! \brief helper constant to use BLAS, alpha */ | |
inline static default_real_t AlphaBLAS(void) { return 1.0f; } | |
/*! \brief helper constant to use BLAS, beta */ | |
inline static default_real_t BetaBLAS(void) { return 1.0f; } | |
/*! \brief corresponding binary operator type */ | |
typedef op::plus OPType; | |
}; | |
/*! \brief minus to saver: -= */ | |
struct minusto { | |
/*! \brief save b to a using save method */ | |
template<typename DType> | |
MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*) | |
a -= b; | |
} | |
/*! \brief helper constant to use BLAS, alpha */ | |
inline static default_real_t AlphaBLAS(void) { return -1.0f; } | |
/*! \brief helper constant to use BLAS, beta */ | |
inline static default_real_t BetaBLAS(void) { return 1.0f; } | |
/*! \brief corresponding binary operator type */ | |
typedef op::minus OPType; | |
}; | |
/*! \brief multiply to saver: *= */ | |
struct multo { | |
/*! \brief save b to a using save method */ | |
template<typename DType> | |
MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*) | |
a *= b; | |
} | |
/*! \brief corresponding binary operator type */ | |
typedef op::mul OPType; | |
}; | |
/*! \brief divide to saver: /= */ | |
struct divto { | |
/*! \brief save b to a using save method */ | |
template<typename DType> | |
MSHADOW_XINLINE static void Save(DType& a, DType b) { // NOLINT(*) | |
a /= b; | |
} | |
/*! \brief corresponding binary operator type */ | |
typedef op::div OPType; | |
}; | |
} // namespace sv | |
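/*!
 * Illustrative note on savers (a sketch): a saver tells the expression
 * engine how to store a computed value into the destination, and
 * AlphaBLAS()/BetaBLAS() give the matching BLAS coefficients in
 * y = alpha * result + beta * y.
 * \code
 *   float dst = 1.0f;
 *   mshadow::sv::plusto::Save(dst, 2.0f);  // dst += 2.0f -> 3.0f
 *   mshadow::sv::saveto::Save(dst, 5.0f);  // dst  = 5.0f
 * \endcode
 */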
/*! \brief namespace for potential reducer operations */ | |
namespace red { | |
namespace limits { | |
/*! | |
* \brief minimum value of certain types | |
* \tparam DType data type | |
*/ | |
template<typename DType> | |
MSHADOW_XINLINE DType MinValue(void); | |
/*! \brief minimum value of float */ | |
template<> | |
MSHADOW_XINLINE float MinValue<float>(void) { | |
return -FLT_MAX; | |
} | |
/*! \brief minimum value of double */ | |
template<> | |
MSHADOW_XINLINE double MinValue<double>(void) { | |
return -DBL_MAX; | |
} | |
/*! \brief minimum value of half */ | |
template<> | |
MSHADOW_XINLINE half::half_t MinValue<half::half_t>(void) { | |
return MSHADOW_HALF_MIN; | |
} | |
/*! \brief minimum value of int */ | |
template<> | |
MSHADOW_XINLINE int MinValue<int>(void) { | |
return INT_MIN; | |
} | |
/*! \brief minimum value of uint8_t */
template<> | |
MSHADOW_XINLINE uint8_t MinValue<uint8_t>(void) { | |
return 0; | |
} | |
} // namespace limits | |
/*! \brief sum reducer */ | |
struct sum { | |
/*! \brief do reduction into dst */ | |
template<typename DType> | |
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*) | |
dst += src; | |
} | |
/*! | |
*\brief calculate gradient of redres with respect to redsrc, | |
* redres: reduced result, redsrc: one of reduction element | |
*/ | |
template<typename DType> | |
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) { | |
return 1; | |
} | |
/*! | |
*\brief set the initial value during reduction | |
*/ | |
template<typename DType> | |
MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) | |
initv = 0; | |
} | |
}; | |
/*! \brief maximum reducer */ | |
struct maximum { | |
/*! \brief do reduction into dst */ | |
template<typename DType> | |
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*) | |
using namespace std; | |
#ifdef __CUDACC__ | |
dst = ::max(dst, src); | |
#else | |
dst = max(dst, src); | |
#endif // __CUDACC__ | |
} | |
/*! | |
* \brief calculate gradient of redres with respect to redsrc, | |
* redres: reduced result, redsrc: one of reduction element | |
*/ | |
template<typename DType> | |
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) { | |
return redres == redsrc ? 1: 0; | |
} | |
/*! | |
*\brief set the initial value during reduction | |
*/ | |
template<typename DType> | |
MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) | |
initv = limits::MinValue<DType>(); | |
} | |
}; | |
/*! \brief minimum reducer */ | |
struct minimum { | |
/*! \brief do reduction into dst */ | |
template<typename DType> | |
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*) | |
using namespace std; | |
#ifdef __CUDACC__ | |
dst = ::min(dst, src); | |
#else | |
dst = min(dst, src); | |
#endif // __CUDACC__ | |
} | |
/*! | |
* \brief calculate gradient of redres with respect to redsrc, | |
* redres: reduced result, redsrc: one of reduction element | |
*/ | |
template<typename DType> | |
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) { | |
return redres == redsrc ? 1: 0; | |
} | |
/*! | |
*\brief set the initial value during reduction | |
*/ | |
template<typename DType> | |
MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) | |
initv = -limits::MinValue<DType>(); | |
} | |
}; | |
} // namespace red | |
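/*!
 * Illustrative usage of a reducer (a sketch; data and n are placeholders
 * for the range being reduced): SetInitValue seeds the accumulator and
 * Reduce folds each element into it, here computing a running maximum.
 * \code
 *   float acc;
 *   mshadow::red::maximum::SetInitValue(acc);
 *   for (int i = 0; i < n; ++i) {
 *     mshadow::red::maximum::Reduce(acc, data[i]);
 *   }
 * \endcode
 */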
#define MSHADOW_TYPE_SWITCH(type, DType, ...) \ | |
switch (type) { \ | |
case mshadow::kFloat32: \ | |
{ \ | |
typedef float DType; \ | |
{__VA_ARGS__} \ | |
} \ | |
break; \ | |
case mshadow::kFloat64: \ | |
{ \ | |
typedef double DType; \ | |
{__VA_ARGS__} \ | |
} \ | |
break; \ | |
case mshadow::kFloat16: \ | |
{ \ | |
typedef mshadow::half::half_t DType; \ | |
{__VA_ARGS__} \ | |
} \ | |
break; \ | |
case mshadow::kUint8: \ | |
{ \ | |
typedef uint8_t DType; \ | |
{__VA_ARGS__} \ | |
} \ | |
break; \ | |
case mshadow::kInt32: \ | |
{ \ | |
typedef int32_t DType; \ | |
{__VA_ARGS__} \ | |
} \ | |
break; \ | |
default: \ | |
LOG(FATAL) << "Unknown type enum " << type; \ | |
} | |
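/*!
 * Illustrative usage of MSHADOW_TYPE_SWITCH (a sketch; Fill is a placeholder
 * function template): the macro binds DType to the C++ type matching the
 * runtime type flag and instantiates the body once for it.
 * \code
 *   template<typename DType> void Fill(DType *dptr, size_t n, DType v);
 *   void FillAny(void *dptr, size_t n, int type_flag) {
 *     MSHADOW_TYPE_SWITCH(type_flag, DType, {
 *       Fill(static_cast<DType*>(dptr), n, DType(1));
 *     });
 *   }
 * \endcode
 */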
#define MSHADOW_REAL_TYPE_SWITCH(type, DType, ...) \ | |
switch (type) { \ | |
case mshadow::kFloat32: \ | |
{ \ | |
typedef float DType; \ | |
{__VA_ARGS__} \ | |
} \ | |
break; \ | |
case mshadow::kFloat64: \ | |
{ \ | |
typedef double DType; \ | |
{__VA_ARGS__} \ | |
} \ | |
break; \ | |
case mshadow::kFloat16: \ | |
{ \ | |
typedef mshadow::half::half_t DType; \ | |
{__VA_ARGS__} \ | |
} \ | |
break; \ | |
case mshadow::kUint8: \ | |
LOG(FATAL) << "This operation only support " \ | |
"floating point types not uint8"; \ | |
break; \ | |
case mshadow::kInt32: \ | |
LOG(FATAL) << "This operation only support " \ | |
"floating point types, not int32"; \ | |
break; \ | |
default: \ | |
LOG(FATAL) << "Unknown type enum " << type; \ | |
} | |
#define MSHADOW_LAYOUT_SWITCH(layout, Layout, ...) \ | |
switch (layout) { \ | |
case mshadow::kNCHW: \ | |
{ \ | |
const int Layout = kNCHW; \ | |
{__VA_ARGS__} \ | |
} \ | |
break; \ | |
case mshadow::kNHWC: \ | |
{ \ | |
const int Layout = kNHWC; \ | |
{__VA_ARGS__} \ | |
} \ | |
break; \ | |
case mshadow::kNCDHW: \ | |
{ \ | |
const int Layout = kNCDHW; \ | |
{__VA_ARGS__} \ | |
} \ | |
break; \ | |
case mshadow::kNDHWC: \ | |
{ \ | |
const int Layout = kNDHWC; \ | |
{__VA_ARGS__} \ | |
} \ | |
break; \ | |
default: \ | |
LOG(FATAL) << "Unknown layout enum " << layout; \ | |
} | |
/*! \brief get data type size from type enum */ | |
inline size_t mshadow_sizeof(int type) { | |
int size = 0; | |
MSHADOW_TYPE_SWITCH(type, DType, size = sizeof(DType);); | |
return size; | |
} | |
} // namespace mshadow | |
#endif // MSHADOW_BASE_H_ | |
//===== EXPANDED: ../mshadow/mshadow/base.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/expression.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file expression.h | |
* \brief definitions of abstract expressions and expressions template | |
* \author Tianqi Chen, Bing Xu | |
*/ | |
#ifndef MSHADOW_EXPRESSION_H_ | |
#define MSHADOW_EXPRESSION_H_ | |
namespace mshadow { | |
/*! | |
 * \brief namespace for abstract expressions and expression templates;
 * it has no dependency on tensor.h.
 * These data structures take no part in computations;
 * they are only used to define operations and represent expressions in a symbolic way
*/ | |
namespace expr { | |
/*! \brief type of expressions */ | |
namespace type { | |
// expression types are defined as a bitmask
// subtype relationship: kRValue < kMapper < kChainer < kComplex
/*! | |
 * \brief this expression directly corresponds to a data class,
* can be used to assign data | |
*/ | |
const int kRValue = 0; | |
/*! | |
 * \brief expression that contains element-wise tensor operations,
 * mapping an expression to the same shape
*/ | |
const int kMapper = 1; | |
/*! | |
 * \brief expression that can be chained with other expressions.
 * Usually it has a function Eval(i, j) defined, which pulls the result (i, j) from the input
 * expression and outputs the result at a certain position.
*/ | |
const int kChainer = 3; | |
/*! \brief other cases: e.g. dot product */
const int kComplex = 7; | |
} // namespace type | |
/*! | |
* \brief expression engine that actually interprets these expressions | |
 * this is a function template that needs to be implemented for specific expressions
* \tparam Saver the save method | |
* \tparam RValue the type of RValue to be saved | |
* \sa namespace sv | |
*/ | |
template<typename Saver, typename RValue, typename DType> | |
struct ExpEngine; | |
/*! \brief defines how expression exp can be evaluated and stored into dst */ | |
// template<typename EType> | |
// inline static void Eval(RValue *dst, const EType &exp); | |
/*! | |
* \brief base class for expression | |
 * \tparam SubType the inheriting class must put its type into this parameter
* \tparam DType the data type of each element in the expression | |
* \tparam exp_type expression type, see namespace type | |
*/ | |
template<typename SubType, typename DType, int exp_type> | |
struct Exp { | |
public: | |
/*! \return subtype instance of current class */ | |
inline const SubType& self(void) const { | |
return *static_cast<const SubType*>(this); | |
} | |
  /*! \return pointer to subtype instance of current class */
inline SubType* ptrself(void) { | |
return static_cast<SubType*>(this); | |
} | |
}; | |
/*! | |
* \brief scalar expression | |
* \tparam DType the data type of the scalar | |
*/ | |
template<typename DType> | |
struct ScalarExp: public Exp<ScalarExp<DType>, DType, type::kMapper> { | |
/*! \brief scalar value */ | |
DType scalar_; | |
/*! \brief implicit constructor, MUST NOT BE explicit */ | |
ScalarExp(DType scalar) : scalar_(scalar) {} // NOLINT(*) | |
}; | |
/*! \brief create a scalar expression */
template<typename DType> | |
inline ScalarExp<DType> scalar(DType s) { | |
return ScalarExp<DType>(s); | |
} | |
/*! | |
* \brief typecast expression, cast the type of elements | |
* \tparam DstDType the target type we want to cast into | |
 * \tparam SrcDType the source type we want to cast from
* \tparam EType the type of the source expression | |
* \tparam etype the type of expression after cast | |
*/ | |
template<typename DstDType, typename SrcDType, typename EType, int etype> | |
struct TypecastExp: | |
public Exp<TypecastExp<DstDType, SrcDType, EType, etype>, | |
DstDType, etype> { | |
/*! \brief expression to be typecasted */ | |
const EType &exp; | |
/*! \brief constructor */ | |
explicit TypecastExp(const EType &e) : exp(e) {} | |
}; | |
/*! \brief create a typecast expression */
template<typename DstDType, typename SrcDType, | |
typename EType, int etype> | |
inline TypecastExp<DstDType, SrcDType, EType, (etype|type::kMapper)> | |
tcast(const Exp<EType, SrcDType, etype> &exp) { | |
return TypecastExp<DstDType, SrcDType, EType, (etype|type::kMapper)>(exp.self()); | |
} | |
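/*
 * Usage sketch (illustrative only): tcast builds an element-wise cast expression.
 * Here an int tensor is converted into a float tensor of the same shape.
 *
 *   Tensor<cpu, 1, int>   src = NewTensor<cpu>(Shape1(4), 2);
 *   Tensor<cpu, 1, float> dst = NewTensor<cpu>(Shape1(4), 0.0f);
 *   dst = tcast<float>(src);      // element-wise int -> float conversion
 *   FreeSpace(&src); FreeSpace(&dst);
 */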
/*! \brief represent a transpose expression of a container */ | |
template<typename EType, typename DType> | |
struct TransposeExp: public Exp<TransposeExp<EType, DType>, | |
DType, type::kChainer> { | |
/*! \brief expression to be transposed */ | |
const EType &exp; | |
/*! \brief constructor */ | |
explicit TransposeExp(const EType &e) : exp(e) {} | |
/*! \brief transpose expression */ | |
inline const EType &T(void) const { | |
return exp; | |
} | |
}; | |
/*! | |
* \brief base class of all rvalues | |
 * \tparam Container the actual class of the data container, e.g. Tensor1D
* \tparam DataType the element data type of each element in the container | |
*/ | |
template<typename Container, typename DType> | |
class RValueExp: public Exp<Container, DType, type::kRValue> { | |
public: | |
/*! | |
*\brief transpose of a matrix | |
*\return transpose of current expression | |
*/ | |
inline const TransposeExp<Container, DType> T(void) const { | |
return TransposeExp<Container, DType>(this->self()); | |
} | |
/*! \brief operator overload */ | |
inline Container &operator+=(DType s) { | |
ExpEngine<sv::plusto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s)); | |
return *(this->ptrself()); | |
} | |
/*! \brief operator overload */ | |
inline Container &operator-=(DType s) { | |
ExpEngine<sv::minusto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s)); | |
return *(this->ptrself()); | |
} | |
/*! \brief operator overload */ | |
inline Container &operator*=(DType s) { | |
ExpEngine<sv::multo, Container, DType>::Eval(this->ptrself(), scalar<DType>(s)); | |
return *(this->ptrself()); | |
} | |
/*! \brief operator overload */ | |
inline Container &operator/=(DType s) { | |
ExpEngine<sv::divto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s)); | |
return *(this->ptrself()); | |
} | |
/*! \brief operator overload */ | |
inline Container &__assign(DType s) { | |
ExpEngine<sv::saveto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s)); | |
return *(this->ptrself()); | |
} | |
  /*! \brief assign from a general expression; note that we cannot define container = container here */
template<typename E, int etype> | |
inline Container &__assign(const Exp<E, DType, etype> &exp) { | |
ExpEngine<sv::saveto, Container, DType>::Eval(this->ptrself(), exp.self()); | |
return *(this->ptrself()); | |
} | |
/*! \brief operator overload, assign */ | |
inline Container &__assign(const Exp<Container, DType, type::kRValue> &exp); | |
/*! \brief implementation of operator+= */ | |
template<typename E, int etype> | |
inline Container &operator+=(const Exp<E, DType, etype> &exp) { | |
ExpEngine<sv::plusto, Container, DType>::Eval(this->ptrself(), exp.self()); | |
return *(this->ptrself()); | |
} | |
/*! \brief implementation of operator-= */ | |
template<typename E, int etype> | |
inline Container &operator-=(const Exp<E, DType, etype> &exp) { | |
ExpEngine<sv::minusto, Container, DType>::Eval(this->ptrself(), exp.self()); | |
return *(this->ptrself()); | |
} | |
/*! \brief implementation of operator*= */ | |
template<typename E, int etype> | |
inline Container &operator*=(const Exp<E, DType, etype> &exp) { | |
ExpEngine<sv::multo, Container, DType>::Eval(this->ptrself(), exp.self()); | |
return *(this->ptrself()); | |
} | |
/*! \brief implementation of operator/= */ | |
template<typename E, int etype> | |
inline Container &operator/=(const Exp<E, DType, etype> &exp) { | |
ExpEngine<sv::divto, Container, DType>::Eval(this->ptrself(), exp.self()); | |
return *(this->ptrself()); | |
} | |
}; | |
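/*
 * Usage sketch (illustrative only): any RValueExp subclass (e.g. Tensor) supports the
 * compound-assignment operators above; the right-hand side can be an arbitrary
 * expression template that ExpEngine evaluates element-wise.
 *
 *   Tensor<cpu, 2, float> a = NewTensor<cpu>(Shape2(2, 3), 1.0f);
 *   Tensor<cpu, 2, float> b = NewTensor<cpu>(Shape2(2, 3), 2.0f);
 *   a += b;                               // ExpEngine<sv::plusto, ...>
 *   a *= 2.0f;                            // scalar compound assignment
 *   a = a * b + scalar<float>(1.0f);      // element-wise expression, no temporaries
 *   FreeSpace(&a); FreeSpace(&b);
 */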
/*! | |
* \brief matrix multiplication expression dot(lhs[.T], rhs[.T]) | |
* \tparam TA type of lhs | |
* \tparam TB type of rhs | |
* \tparam ltrans whether lhs is transposed | |
* \tparam rtrans whether rhs is transposed | |
* \tparam DType the data type of the scalar | |
*/ | |
template<typename TA, typename TB, bool ltrans, bool rtrans, typename DType> | |
struct DotExp: public Exp<DotExp<TA, TB, ltrans, rtrans, DType>, | |
DType, type::kComplex> { | |
/*! \brief left operand */ | |
const TA &lhs_; | |
/*! \brief right operand */ | |
const TB &rhs_; | |
/*! \brief scale over result */ | |
DType scale_; | |
/*! \brief constructor */ | |
explicit DotExp(const TA &lhs, const TB &rhs, DType scale) | |
: lhs_(lhs), rhs_(rhs), scale_(scale) {} | |
}; | |
// definition of dot expression | |
/*! \brief dot operator def */ | |
template<typename TA, typename TB, typename DType> | |
inline DotExp<TA, TB, false, false, DType> | |
dot(const RValueExp<TA, DType> &lhs, const RValueExp<TB, DType> &rhs) { | |
return DotExp<TA, TB, false, false, DType>(lhs.self(), rhs.self(), DType(1.0f)); | |
} | |
/*! \brief dot operator def */ | |
template<typename TA, typename TB, typename DType> | |
inline DotExp<TA, TB, true, false, DType> | |
dot(const TransposeExp<TA, DType> &lhs, const RValueExp<TB, DType> &rhs) { | |
return DotExp<TA, TB, true, false, DType>(lhs.exp, rhs.self(), DType(1.0f)); | |
} | |
/*! \brief dot operator def */ | |
template<typename TA, typename TB, typename DType> | |
inline DotExp<TA, TB, false, true, DType> | |
dot(const RValueExp<TA, DType> &lhs, const TransposeExp<TB, DType> &rhs) { | |
return DotExp<TA, TB, false, true, DType>(lhs.self(), rhs.exp, DType(1.0f)); | |
} | |
/*! \brief dot operator def */ | |
template<typename TA, typename TB, typename DType> | |
inline DotExp<TA, TB, true, true, DType> | |
dot(const TransposeExp<TA, DType> &lhs, const TransposeExp<TB, DType> &rhs) { | |
return DotExp<TA, TB, true, true, DType>(lhs.exp, rhs.exp, DType(1.0f)); | |
} | |
/*! \brief batch_dot operator def */ | |
template<bool transpose_left, bool transpose_right, typename TA, typename TB, typename DType> | |
inline DotExp<TA, TB, transpose_left, transpose_right, DType> | |
batch_dot(const RValueExp<TA, DType> &lhs, const RValueExp<TB, DType> &rhs) { | |
return DotExp<TA, TB, transpose_left, transpose_right, DType>( | |
lhs.self(), rhs.self(), DType(1.0f)); | |
} | |
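/*
 * Usage sketch (illustrative only): dot() builds a kComplex expression that is handed
 * to the BLAS engine when assigned; .T() marks an operand as transposed.
 *
 *   Tensor<cpu, 2, float> A = NewTensor<cpu>(Shape2(3, 4), 1.0f);
 *   Tensor<cpu, 2, float> B = NewTensor<cpu>(Shape2(4, 5), 1.0f);
 *   Tensor<cpu, 2, float> C = NewTensor<cpu>(Shape2(3, 5), 0.0f);
 *   C  = dot(A, B);                       // C  = A * B
 *   C += dot(A, B);                       // accumulate into C
 *   Tensor<cpu, 2, float> D = NewTensor<cpu>(Shape2(4, 3), 1.0f);
 *   C  = dot(D.T(), B);                   // transposed left operand: C = D^T * B
 *   FreeSpace(&A); FreeSpace(&B); FreeSpace(&C); FreeSpace(&D);
 */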
//--------------- | |
// TernaryMapExp | |
// -------------- | |
/*! | |
* \brief ternary map expression | |
* \tparam OP operator | |
* \tparam TA type of item1 | |
 * \tparam TB type of item2
 * \tparam TC type of item3
 * \tparam etype expression type, see namespace type
*/ | |
template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype> | |
struct TernaryMapExp: public Exp<TernaryMapExp<OP, TA, TB, TC, DType, etype>, | |
DType, etype> { | |
/*! \brief first operand */ | |
const TA &item1_; | |
/*! \brief second operand */ | |
const TB &item2_; | |
/*! \brief third operand */ | |
const TC &item3_; | |
/*! \brief constructor */ | |
explicit TernaryMapExp(const TA &item1, const TB &item2, const TC &item3) | |
:item1_(item1), item2_(item2), item3_(item3) {} | |
}; | |
/*! \brief make expression */ | |
template<typename OP, typename TA, typename TB, typename TC, typename DType, int ta, int tb, int tc> | |
inline TernaryMapExp<OP, TA, TB, TC, DType, (ta|tb|tc|type::kMapper)> | |
MakeExp(const Exp<TA, DType, ta> &item1, const Exp<TB, DType, tb> &item2, | |
const Exp<TC, DType, tc> &item3) { | |
return TernaryMapExp<OP, TA, TB, TC, DType, | |
(ta|tb|tc|type::kMapper)>(item1.self(), item2.self(), item3.self()); | |
} | |
/*! | |
 * \brief shorthand for MakeExp; usage: F<op>(item1, item2, item3), creates a ternary operation expression
* \param item1 first operand | |
* \param item2 second operand | |
* \param item3 third operand | |
* \return the result expression | |
 * \tparam OP ternary operator
* \tparam TA item1 expression | |
* \tparam ta item1 expression type | |
* \tparam TB item2 expression | |
* \tparam tb item2 expression type | |
* \tparam TC item3 expression | |
* \tparam tc item3 expression type | |
* \sa mshadow::op | |
*/ | |
// Ternary | |
template<typename OP, typename TA, typename TB, typename TC, typename DType, int ta, int tb, int tc> | |
inline TernaryMapExp<OP, TA, TB, TC, DType, (ta|tb|tc|type::kMapper)> | |
F(const Exp<TA, DType, ta> &item1, const Exp<TB, DType, tb> &item2, | |
const Exp<TC, DType, tc> &item3) { | |
return MakeExp<OP>(item1, item2, item3); | |
} | |
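/*
 * Usage sketch (illustrative only): a ternary operator is any struct with a static
 * Map(a, b, c); the `madd` struct below is a hypothetical user-defined op, not part
 * of mshadow::op.
 *
 *   struct madd {
 *     MSHADOW_XINLINE static float Map(float a, float b, float c) {
 *       return a * b + c;                 // element-wise multiply-add
 *     }
 *   };
 *   Tensor<cpu, 2, float> x = NewTensor<cpu>(Shape2(2, 2), 1.0f);
 *   Tensor<cpu, 2, float> y = NewTensor<cpu>(Shape2(2, 2), 2.0f);
 *   Tensor<cpu, 2, float> z = NewTensor<cpu>(Shape2(2, 2), 3.0f);
 *   x = F<madd>(x, y, z);                 // x[i][j] = x[i][j] * y[i][j] + z[i][j]
 *   FreeSpace(&x); FreeSpace(&y); FreeSpace(&z);
 */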
//--------------- | |
// BinaryMapExp | |
// -------------- | |
/*! | |
* \brief binary map expression lhs [op] rhs | |
* \tparam OP operator | |
* \tparam TA type of lhs | |
* \tparam TB type of rhs | |
 * \tparam etype expression type, see namespace type
*/ | |
template<typename OP, typename TA, typename TB, typename DType, int etype> | |
struct BinaryMapExp: public Exp<BinaryMapExp<OP, TA, TB, DType, etype>, | |
DType, etype> { | |
/*! \brief left operand */ | |
const TA &lhs_; | |
/*! \brief right operand */ | |
const TB &rhs_; | |
/*! \brief constructor */ | |
explicit BinaryMapExp(const TA &lhs, const TB &rhs) | |
:lhs_(lhs), rhs_(rhs) {} | |
}; | |
/*! \brief make expression */ | |
template<typename OP, typename TA, typename TB, typename DType, int ta, int tb> | |
inline BinaryMapExp<OP, TA, TB, DType, (ta|tb|type::kMapper)> | |
MakeExp(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) { | |
return BinaryMapExp<OP, TA, TB, DType, | |
(ta|tb|type::kMapper)>(lhs.self(), rhs.self()); | |
} | |
/*! | |
 * \brief shorthand for MakeExp; usage: F<op>(lhs, rhs), creates a binary operation expression
* \param lhs left operand | |
* \param rhs right operand | |
* \return the result expression | |
 * \tparam OP binary operator
* \tparam TA lhs expression | |
* \tparam ta lhs expression type | |
* \tparam TB rhs expression | |
* \tparam tb rhs expression type | |
* \sa mshadow::op | |
*/ | |
template<typename OP, typename TA, typename TB, typename DType, int ta, int tb> | |
inline BinaryMapExp<OP, TA, TB, DType, (ta|tb|type::kMapper)> | |
F(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) { | |
return MakeExp<OP>(lhs, rhs); | |
} | |
// operator rules | |
/*! \brief operator overload */ | |
template<typename TA, typename TB, typename DType, int ta, int tb> | |
inline BinaryMapExp<op::plus, TA, TB, DType, (ta|tb|type::kMapper)> | |
operator+(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) { | |
return MakeExp<op::plus>(lhs, rhs); | |
} | |
/*! \brief operator overload */ | |
template<typename TA, typename TB, typename DType, int ta, int tb> | |
inline BinaryMapExp<op::minus, TA, TB, DType, (ta|tb|type::kMapper)> | |
operator-(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) { | |
return MakeExp<op::minus>(lhs, rhs); | |
} | |
/*! \brief operator overload */ | |
template<typename TA, typename TB, typename DType, int ta, int tb> | |
inline BinaryMapExp<op::mul, TA, TB, DType, (ta|tb|type::kMapper)> | |
operator*(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) { | |
return MakeExp<op::mul>(lhs, rhs); | |
} | |
/*! \brief operator overload */ | |
template<typename TA, typename TB, typename DType, int ta, int tb> | |
inline BinaryMapExp<op::div, TA, TB, DType, (ta|tb|type::kMapper)> | |
operator/(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) { | |
return MakeExp<op::div>(lhs, rhs); | |
} | |
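/*
 * Usage sketch (illustrative only): binary expressions come either from the operator
 * overloads above or from F<OP>() with a user-defined operator; `maximum` below is a
 * hypothetical example struct.
 *
 *   struct maximum {
 *     MSHADOW_XINLINE static float Map(float a, float b) {
 *       return a > b ? a : b;
 *     }
 *   };
 *   Tensor<cpu, 2, float> lhs = NewTensor<cpu>(Shape2(2, 2), 1.0f);
 *   Tensor<cpu, 2, float> rhs = NewTensor<cpu>(Shape2(2, 2), 2.0f);
 *   Tensor<cpu, 2, float> out = NewTensor<cpu>(Shape2(2, 2), 0.0f);
 *   out = F<maximum>(lhs, rhs);           // element-wise maximum
 *   out = lhs * rhs + lhs / rhs;          // the overloads build the same kind of tree
 *   FreeSpace(&lhs); FreeSpace(&rhs); FreeSpace(&out);
 */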
//--------------- | |
// UnaryMapExp | |
// -------------- | |
/*! | |
* \brief unary map expression op(src) | |
* \tparam OP operator | |
* \tparam TA type of src | |
 * \tparam etype expression type, see namespace type
*/ | |
template<typename OP, typename TA, typename DType, int etype> | |
struct UnaryMapExp: public Exp<UnaryMapExp<OP, TA, DType, etype>, | |
DType, etype> { | |
/*! \brief source expression */ | |
const TA &src_; | |
/*! \brief constructor */ | |
explicit UnaryMapExp(const TA &src) : src_(src) {} | |
}; | |
/*! \brief make expression */ | |
template<typename OP, typename TA, typename DType, int ta> | |
inline UnaryMapExp<OP, TA, DType, (ta|type::kMapper)> | |
MakeExp(const Exp<TA, DType, ta> &src) { | |
return UnaryMapExp<OP, TA, DType, (ta|type::kMapper)>(src.self()); | |
} | |
/*! | |
 * \brief shorthand for MakeExp; usage: F<op>(src), creates a unary operation expression
* \param src source expression | |
* \return the result expression | |
 * \tparam OP unary operator
* \tparam TA source expression | |
* \tparam ta source expression type | |
* \sa mshadow::op | |
*/ | |
template<typename OP, typename TA, typename DType, int ta> | |
inline UnaryMapExp<OP, TA, DType, (ta|type::kMapper)> | |
F(const Exp<TA, DType, ta> &src) { | |
return MakeExp<OP>(src); | |
} | |
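/*
 * Usage sketch (illustrative only): a unary operator is a struct with a static Map(a);
 * `relu` below is a hypothetical user-defined op, not part of mshadow::op.
 *
 *   struct relu {
 *     MSHADOW_XINLINE static float Map(float a) {
 *       return a > 0.0f ? a : 0.0f;
 *     }
 *   };
 *   Tensor<cpu, 2, float> x = NewTensor<cpu>(Shape2(2, 2), -1.0f);
 *   x = F<relu>(x);                       // element-wise rectification
 *   FreeSpace(&x);
 */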
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXPRESSION_H_ | |
//===== EXPANDED: ../mshadow/mshadow/expression.h ===== | |
namespace mshadow { | |
/*! \brief device name CPU */ | |
struct cpu { | |
/*! \brief whether this device is CPU or not */ | |
static const bool kDevCPU = true; | |
/*! \brief device flag number, identifies this device */ | |
static const int kDevMask = 1 << 0; | |
}; | |
/*! \brief device name GPU */ | |
struct gpu { | |
/*! \brief whether this device is CPU or not */ | |
static const bool kDevCPU = false; | |
/*! \brief device flag number, identifies this device */ | |
static const int kDevMask = 1 << 1; | |
}; | |
template<int ndim> | |
struct Shape; | |
/*! | |
* \brief allow string printing of the shape | |
* \param os the output stream | |
* \param shape the shape | |
* \return the ostream | |
*/ | |
template<int ndim> | |
inline std::ostream &operator<<(std::ostream &os, const Shape<ndim> &shape); // NOLINT(*) | |
/*! | |
* \brief shape of a tensor | |
* \tparam dimension dimension of tensor | |
*/ | |
template<int dimension> | |
struct Shape { | |
/*! \brief dimension of current shape */ | |
static const int kDimension = dimension; | |
/*! \brief dimension of current shape minus one */ | |
static const int kSubdim = dimension - 1; | |
/*! \brief storing the dimension information */ | |
index_t shape_[kDimension]; | |
/*! \brief default constructor, do nothing */ | |
MSHADOW_XINLINE Shape(void) {} | |
  /*! \brief copy constructor */
MSHADOW_XINLINE Shape(const Shape<kDimension> &s) { | |
#pragma unroll | |
for (int i = 0; i < kDimension; ++i) { | |
this->shape_[i] = s[i]; | |
} | |
} | |
/*! | |
* \brief get corresponding index | |
* \param idx dimension index | |
* \return the corresponding dimension size | |
*/ | |
MSHADOW_XINLINE index_t &operator[](index_t idx) { | |
return shape_[idx]; | |
} | |
/*! | |
* \brief get corresponding index | |
* \param idx dimension index | |
* \return the corresponding dimension size | |
*/ | |
MSHADOW_XINLINE const index_t &operator[](index_t idx) const { | |
return shape_[idx]; | |
} | |
/*! | |
* \return whether two shape equals | |
* \param s the shape to compare against | |
*/ | |
MSHADOW_XINLINE bool operator==(const Shape<kDimension> &s) const { | |
#pragma unroll | |
for (int i = 0; i < kDimension; ++i) { | |
if (s.shape_[i] != this->shape_[i]) return false; | |
} | |
return true; | |
} | |
/*! | |
* \return whether two shape not equal | |
* \param s the shape to compare against | |
*/ | |
MSHADOW_XINLINE bool operator!=(const Shape<kDimension> &s) const { | |
return !(*this == s); | |
} | |
/*! | |
* flatten the tensor, return a 1D shape | |
* \return the flat 1d shape | |
*/ | |
MSHADOW_XINLINE Shape<1> FlatTo1D(void) const { | |
Shape<1> s; | |
s[0] = this->Size(); | |
return s; | |
} | |
/*! | |
* flatten the higher dimension to second dimension, return a 2D shape | |
* \return the flat 2d shape | |
*/ | |
MSHADOW_XINLINE Shape<2> FlatTo2D(void) const { | |
Shape<2> s; | |
s.shape_[1] = this->shape_[kDimension - 1]; | |
index_t ymax = 1; | |
#pragma unroll | |
for (int i = 0; i < kDimension - 1; ++i) { | |
ymax *= this->shape_[i]; | |
} | |
s.shape_[0] = ymax; | |
return s; | |
} | |
/*! \return number of valid elements */ | |
MSHADOW_XINLINE size_t Size(void) const { | |
size_t size = this->shape_[0]; | |
#pragma unroll | |
for (int i = 1; i < kDimension; ++i) { | |
size *= this->shape_[i]; | |
} | |
return size; | |
} | |
/*! | |
* \return product shape in [dimstart,dimend) | |
* \param dimstart start dimension | |
* \param dimend end dimension | |
*/ | |
MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const { | |
index_t num = 1; | |
#pragma unroll | |
for (int i = dimstart; i < dimend; ++i) { | |
num *= this->shape_[i]; | |
} | |
return num; | |
} | |
/*! | |
* \brief get subshape that takes off largest dimension | |
   * \return subshape
*/ | |
MSHADOW_XINLINE Shape<kSubdim> SubShape(void) const { | |
Shape<kSubdim> s; | |
// for cuda | |
#pragma unroll | |
for (int i = 0; i < kSubdim; ++i) { | |
s.shape_[i] = this->shape_[i + 1]; | |
} | |
return s; | |
} | |
/*! | |
* \brief slice the shape from start to end | |
* \tparam dimstart start dimension | |
* \tparam dimend end dimension | |
* \return the sliced shape | |
*/ | |
template<int dimstart, int dimend> | |
MSHADOW_XINLINE Shape<dimend - dimstart> Slice(void) const { | |
Shape<dimend - dimstart> s; | |
#pragma unroll | |
for (int i = dimstart; i < dimend; ++i) { | |
s[i - dimstart] = this->shape_[i]; | |
} | |
return s; | |
} | |
//! \cond Doxygen_Suppress | |
template<int dim> | |
friend std::ostream &operator<<(std::ostream &os, const Shape<dim> &shape); // NOLINT(*) | |
//! \endcond | |
}; // Shape | |
//------------------------------------------------ | |
// useful construction functions to generate shape | |
//------------------------------------------------- | |
/*! | |
* \brief construct a one dimension shape, stride will equal s0 | |
* \param s0 size of dimension 0 | |
* \return the shape construction | |
*/ | |
MSHADOW_XINLINE Shape<1> Shape1(index_t s0) { | |
Shape<1> s; s[0] = s0; | |
return s; | |
} | |
/*! | |
 * \brief construct a two dimension shape, stride will equal s1
* \param s0 size of dimension 0 | |
* \param s1 size of dimension 1 | |
* \return the shape construction | |
*/ | |
MSHADOW_XINLINE Shape<2> Shape2(index_t s0, index_t s1) { | |
Shape<2> s; s[0] = s0; s[1] = s1; | |
return s; | |
} | |
/*! | |
 * \brief construct a three dimension shape, stride will equal s2
* \param s0 size of dimension 0 | |
* \param s1 size of dimension 1 | |
* \param s2 size of dimension 2 | |
* \return the shape construction | |
*/ | |
MSHADOW_XINLINE Shape<3> Shape3(index_t s0, index_t s1, index_t s2) { | |
Shape<3> s; | |
s[0] = s0; s[1] = s1; s[2] = s2; | |
return s; | |
} | |
/*! | |
 * \brief construct a four dimension shape, stride will equal s3
* \param s0 size of dimension 0 | |
* \param s1 size of dimension 1 | |
* \param s2 size of dimension 2 | |
* \param s3 size of dimension 3 | |
* \return the shape construction | |
*/ | |
MSHADOW_XINLINE Shape<4> Shape4(index_t s0, index_t s1, | |
index_t s2, index_t s3) { | |
Shape<4> s; | |
s[0] = s0; s[1] = s1; s[2] = s2; s[3] = s3; | |
return s; | |
} | |
/*! | |
 * \brief construct a five dimension shape, stride will equal s4
* \param s0 size of dimension 0 | |
* \param s1 size of dimension 1 | |
* \param s2 size of dimension 2 | |
* \param s3 size of dimension 3 | |
* \param s4 size of dimension 4 | |
* \return the shape construction | |
*/ | |
MSHADOW_XINLINE Shape<5> Shape5(index_t s0, index_t s1, index_t s2, | |
index_t s3, index_t s4) { | |
Shape<5> s; | |
s[0] = s0; s[1] = s1; s[2] = s2; s[3] = s3; s[4] = s4; | |
return s; | |
} | |
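/*
 * Usage sketch (illustrative only): a Shape is a plain array of dimension sizes; the
 * ShapeN helpers above simply fill it in.
 *
 *   Shape<3> s = Shape3(2, 3, 4);
 *   size_t total = s.Size();              // 24 elements
 *   Shape<2> flat = s.FlatTo2D();         // (6, 4): higher dimensions collapsed into dim 0
 *   index_t last = s[2];                  // 4
 */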
/*! | |
* \brief Convert shape in src_layout to shape in dst_layout | |
* \param src original shape | |
* \param src_layout layout of original shape | |
* \param dst_layout target layout | |
* \return shape in target layout | |
*/ | |
inline Shape<4> ConvertLayout(const Shape<4>& src, int src_layout, int dst_layout) { | |
Shape<4> dst; | |
switch (src_layout) { | |
case kNCHW: | |
dst = src; | |
break; | |
case kNHWC: | |
dst[0] = src[0]; | |
dst[2] = src[1]; | |
dst[3] = src[2]; | |
dst[1] = src[3]; | |
break; | |
default: | |
LOG(FATAL) << "Invalid layout for 4d shape " << src_layout; | |
} | |
Shape<4> dst2; | |
switch (dst_layout) { | |
case kNCHW: | |
return dst; | |
case kNHWC: | |
dst2[0] = dst[0]; | |
dst2[1] = dst[2]; | |
dst2[2] = dst[3]; | |
dst2[3] = dst[1]; | |
break; | |
default: | |
LOG(FATAL) << "Invalid layout for 4d shape " << src_layout; | |
} | |
return dst2; | |
} | |
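/*
 * Usage sketch (illustrative only): converting a 4-d shape between layouts only
 * permutes the dimension sizes, here from batch/channel/height/width to NHWC.
 *
 *   Shape<4> nchw = Shape4(8, 3, 32, 32);
 *   Shape<4> nhwc = ConvertLayout(nchw, kNCHW, kNHWC);   // (8, 32, 32, 3)
 */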
/*! | |
* \brief Convert shape in src_layout to shape in dst_layout | |
* \param src original shape | |
* \param src_layout layout of original shape | |
* \param dst_layout target layout | |
* \return shape in target layout | |
*/ | |
inline Shape<5> ConvertLayout(const Shape<5>& src, int src_layout, int dst_layout) { | |
Shape<5> dst; | |
switch (src_layout) { | |
case kNCDHW: | |
dst = src; | |
break; | |
case kNDHWC: | |
dst[0] = src[0]; | |
dst[2] = src[1]; | |
dst[3] = src[2]; | |
dst[4] = src[3]; | |
dst[1] = src[4]; | |
break; | |
default: | |
LOG(FATAL) << "Invalid layout for 5d shape " << src_layout; | |
} | |
Shape<5> dst2; | |
switch (dst_layout) { | |
case kNCDHW: | |
return dst; | |
case kNDHWC: | |
dst2[0] = dst[0]; | |
dst2[1] = dst[2]; | |
dst2[2] = dst[3]; | |
dst2[3] = dst[4]; | |
dst2[4] = dst[1]; | |
break; | |
default: | |
LOG(FATAL) << "Invalid layout for 5d shape " << src_layout; | |
} | |
return dst2; | |
} | |
/*! | |
 * \brief computation stream structure, used for asynchronous computation
*/ | |
template<typename Device> | |
struct Stream { | |
// this is only a dummy implementation for CPU | |
// for GPU, the actual implementation will be specialized in tensor_gpu-inl.h | |
/*! | |
* \brief wait for all the computation associated | |
* with this stream to complete | |
*/ | |
inline void Wait(void) {} | |
/*! | |
   * \brief query whether the stream is idle
   * \return true if the stream is idle and all the jobs have been completed
*/ | |
inline bool CheckIdle(void) { | |
return true; | |
} | |
/*! \brief create a blas handle */ | |
inline void CreateBlasHandle() {} | |
}; | |
/*! | |
* \brief Tensor RValue, this is the super type of all kinds of possible tensors | |
* \tparam Container the tensor type | |
* \tparam Device which device the tensor is on | |
* \tparam dimension dimension of the tensor | |
* \tparam DType the type of elements in the tensor | |
*/ | |
template<typename Container, typename Device, int dimension, typename DType> | |
struct TRValue: public expr::RValueExp<Container, DType> { | |
}; | |
// more compact template | |
/*! | |
* \brief general tensor | |
* \tparam Device which device the tensor is on | |
* \tparam dimension dimension of the tensor | |
* \tparam DType the type of elements in the tensor | |
*/ | |
template<typename Device, int dimension, | |
typename DType MSHADOW_DEFAULT_DTYPE> | |
struct Tensor: public TRValue<Tensor<Device, dimension, DType>, | |
Device, dimension, DType> { | |
public: | |
//-------------------------------- | |
  // struct members
//-------------------------------- | |
/*! \brief whether current type lies in cpu */ | |
static const bool kDevCPU = Device::kDevCPU; | |
/*! \brief dimension of subtype */ | |
static const int kSubdim = dimension - 1; | |
//-------------------------------- | |
  // struct members
//-------------------------------- | |
/*! \brief pointer to the data */ | |
DType *dptr_; | |
/*! \brief shape of the tensor */ | |
Shape<dimension> shape_; | |
/*! | |
* \brief storing the stride information in x dimension | |
   *  this is used to deal with pitch allocation in gpu or sse (align x dimension to 64 bit) for efficiency
*/ | |
index_t stride_; | |
/*! | |
* \brief stream where the computation lies | |
   *  a stream is a device-dependent concept on which each computation is carried out
*/ | |
Stream<Device> *stream_; | |
//-------------------------------- | |
// functions | |
//-------------------------------- | |
/*! \brief default constructor */ | |
MSHADOW_XINLINE Tensor(void) : stream_(NULL) {} | |
/*! \brief constructor from shape */ | |
MSHADOW_XINLINE Tensor(const Shape<dimension> &shape) | |
: shape_(shape), stream_(NULL) {} | |
/*! \brief constructor from data pointer and shape, without stride */ | |
MSHADOW_XINLINE Tensor(DType *dptr, const Shape<dimension> &shape) | |
: dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(NULL) {} | |
/*! \brief constructor from data pointer and shape, without stride */ | |
MSHADOW_XINLINE Tensor(DType *dptr, const Shape<dimension> &shape, | |
Stream<Device> *stream) | |
: dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(stream) {} | |
/*! \brief constructor from data pointer and shape */ | |
MSHADOW_XINLINE Tensor(DType *dptr, | |
const Shape<dimension> &shape, | |
index_t stride, Stream<Device> *stream) | |
: dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {} | |
/*! | |
* \brief set the stream to do computation of current tensor | |
* \param stream the computation stream | |
*/ | |
inline void set_stream(Stream<Device> *stream) { | |
this->stream_ = stream; | |
} | |
/*! | |
* \return memory cost of the tensor, including the aligned x dimension | |
* \tparam startdim the starting dimension | |
*/ | |
template<int startdim> | |
MSHADOW_XINLINE size_t MemSize(void) const { | |
size_t memsz = this->stride_; | |
#pragma unroll | |
for (int i = startdim; i < kSubdim; ++i) { | |
memsz *= this->shape_[i]; | |
} | |
return memsz; | |
} | |
/*! | |
   * \return whether the tensor's memory is contiguous,
   *         i.e. the size of the x dimension equals the stride
*/ | |
MSHADOW_XINLINE bool CheckContiguous(void) const { | |
return this->shape_[dimension - 1] == stride_; | |
} | |
/*! | |
* \return memory cost of the tensor, including the aligned x dimension | |
*/ | |
MSHADOW_XINLINE size_t MSize(void) const { | |
return this->MemSize<0>(); | |
} | |
/*! | |
* \brief return size of i-th dimension, start counting from highest dimension | |
   * \param idx the dimension counted from the highest dimension
* \return the size | |
*/ | |
MSHADOW_XINLINE index_t size(index_t idx) const { | |
return shape_[idx]; | |
} | |
/*! | |
* \brief flatten the tensor to 1 dimension | |
* \return tensor after flatten | |
*/ | |
MSHADOW_XINLINE Tensor<Device, 1, DType> FlatTo1D(void) const { | |
return Tensor<Device, 1, DType>(dptr_, shape_.FlatTo1D(), stride_, stream_); | |
} | |
/*! | |
* \brief flatten the tensor to 2 dimension, collapse the higher dimensions together | |
* \return tensor after flatten | |
*/ | |
MSHADOW_XINLINE Tensor<Device, 2, DType> FlatTo2D(void) const { | |
return Tensor<Device, 2, DType>(dptr_, shape_.FlatTo2D(), stride_, stream_); | |
} | |
/*! | |
   * \brief get a sub-tensor of dimension - 1 at the given index
* \param idx index | |
* \return the result tensor | |
*/ | |
MSHADOW_XINLINE Tensor<Device, kSubdim, DType> operator[](index_t idx) const { | |
return Tensor<Device, kSubdim, DType>(dptr_ + this->MemSize<1>() * idx, | |
shape_.SubShape(), stride_, stream_); | |
} | |
/*! | |
* \brief slice the tensor in highest dimension [begin,end) | |
* \param begin begin position of slice | |
* \param end end position of slice | |
* \return tensor after slice | |
*/ | |
MSHADOW_XINLINE Tensor<Device, dimension, DType> | |
Slice(index_t begin, index_t end) const { | |
Shape<dimension> s = this->shape_; | |
s[0] = end - begin; | |
return Tensor<Device, dimension, DType>(dptr_ + this->MemSize<1>() * begin, | |
s, stride_, stream_); | |
} | |
/*!\brief implement the assignment of same type */ | |
inline Tensor<Device, dimension, DType> & | |
operator=(const Tensor<Device, dimension, DType> &exp) { | |
dptr_ = exp.dptr_; | |
shape_ = exp.shape_; | |
stride_ = exp.stride_; | |
stream_ = exp.stream_; | |
return *this; | |
} | |
/*!\brief functions to fit expression template */ | |
template<typename E, int etype> | |
inline Tensor<Device, dimension, DType> & | |
operator=(const expr::Exp<E, DType, etype> &exp) { | |
return this->__assign(exp); | |
} | |
/*!\brief functions to fit expression template */ | |
inline Tensor<Device, dimension, DType> &operator=(const DType &exp) { | |
return this->__assign(exp); | |
} | |
}; | |
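/*
 * Usage sketch (illustrative only): a Tensor is a view over user-provided (or
 * AllocSpace-allocated) memory; operator[], Slice and FlatTo2D return views that
 * share the same dptr_.
 *
 *   float data[2 * 3] = {0.0f};
 *   Tensor<cpu, 2, float> mat(data, Shape2(2, 3));   // wraps data, stride_ = 3
 *   mat[1][2] = 5.0f;                                // operator[] peels off dimension 0
 *   Tensor<cpu, 2, float> row1 = mat.Slice(1, 2);    // view of row 1 only
 *   Tensor<cpu, 1, float> flat = mat.FlatTo1D();     // 1-d view of the same memory
 */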
/* | |
 * respecialized class Tensor1D, this is due to a different implementation of operator[]
*/ | |
template<typename Device, typename DType> | |
struct Tensor<Device, 1, DType>: | |
public TRValue<Tensor<Device, 1, DType>, Device, 1, DType> { | |
public: | |
DType *dptr_; | |
Shape<1> shape_; | |
index_t stride_; | |
Stream<Device> *stream_; | |
// constructor | |
MSHADOW_XINLINE Tensor(void) : stream_(NULL) {} | |
MSHADOW_XINLINE Tensor(const Shape<1> &shape) | |
: shape_(shape), stream_(NULL) {} | |
MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape) | |
: dptr_(dptr), shape_(shape), stride_(shape[0]), stream_(NULL) {} | |
MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape, Stream<Device> *stream) | |
: dptr_(dptr), shape_(shape), stride_(shape[0]), stream_(stream) {} | |
MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape, | |
index_t stride, Stream<Device> *stream) | |
: dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {} | |
inline void set_stream(Stream<Device> *stream) { | |
this->stream_ = stream; | |
} | |
MSHADOW_XINLINE Tensor<Device, 1, DType> FlatTo1D(void) const { | |
return *this; | |
} | |
MSHADOW_XINLINE Tensor<Device, 2, DType> FlatTo2D(void) const { | |
return Tensor<Device, 2, DType>(dptr_, shape_.FlatTo2D(), stride_, stream_); | |
} | |
MSHADOW_XINLINE Tensor<Device, 1, DType> Slice(index_t begin, index_t end) const { | |
Shape<1> s; | |
s[0] = end - begin; | |
return Tensor<Device, 1, DType>(dptr_ + begin, s, s[0], stream_); | |
} | |
MSHADOW_XINLINE bool CheckContiguous(void) const { | |
return true; | |
} | |
MSHADOW_XINLINE size_t MSize(void) const { | |
return shape_[0]; | |
} | |
MSHADOW_XINLINE index_t size(index_t i) const { | |
return shape_[0]; | |
} | |
MSHADOW_XINLINE DType &operator[](index_t idx) { | |
return dptr_[idx]; | |
} | |
MSHADOW_XINLINE const DType &operator[](index_t idx) const { | |
return dptr_[idx]; | |
} | |
/*!\brief implement the assignment of same type */ | |
inline Tensor<Device, 1, DType> & | |
operator=(const Tensor<Device, 1, DType> &exp) { | |
dptr_ = exp.dptr_; | |
shape_ = exp.shape_; | |
stride_ = exp.stride_; | |
stream_ = exp.stream_; | |
return *this; | |
} | |
template<typename E, int etype> | |
inline Tensor<Device, 1, DType> & | |
operator=(const expr::Exp<E, DType, etype> &exp) { | |
return this->__assign(exp); | |
} | |
inline Tensor<Device, 1, DType> &operator=(const DType &exp) { | |
return this->__assign(exp); | |
} | |
}; | |
//------------------------ | |
// Function Declarations | |
//----------------------- | |
/*! | |
 * \brief initialize tensor engine, used to call initialization functions of dependent libs
 *  this function should be called before all GPU tensor operations,
 *  for using tensors on CPU, this call is not actually needed
 * \param device_id GPU device id to be chosen
* \tparam Device the device type | |
*/ | |
template<typename Device> | |
inline void InitTensorEngine(int device_id = 0); | |
/*! | |
* \brief Shutdown tensor engine on current device | |
* this function should be called after all GPU tensor operations, | |
 *  for using tensors on CPU, this call is not actually needed
* \tparam Device the device type | |
*/ | |
template<typename Device> | |
inline void ShutdownTensorEngine(void); | |
/*! | |
* \brief set the device of current thread to work on | |
* \param devid the device id | |
* \tparam Device the device type | |
*/ | |
template<typename Device> | |
inline void SetDevice(int devid); | |
/*! | |
* \brief create a new stream from system | |
* \param create_blas_handle whether create blas handle in stream | |
* \param create_dnn_handle whether create cudnn handle in stream | |
* \return a pointer to the created stream | |
* \tparam Device the device type | |
*/ | |
template<typename Device> | |
inline Stream<Device> *NewStream(bool create_blas_handle, | |
bool create_dnn_handle); | |
/*! \brief default behavior: create cublas handle */ | |
template<typename Device> | |
inline Stream<Device> *NewStream() { | |
return NewStream<Device>(true, false); | |
} | |
/*! | |
* \brief delete the computing stream | |
* \param stream the stream parameter to be deleted | |
*/ | |
template<typename Device> | |
inline void DeleteStream(Stream<Device> *stream); | |
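/*
 * Usage sketch (illustrative only, requires a CUDA-enabled build): the typical stream
 * lifecycle is NewStream -> attach to tensors -> Wait -> DeleteStream.
 *
 *   InitTensorEngine<gpu>(0);
 *   Stream<gpu> *s = NewStream<gpu>(true, false);    // with a cuBLAS handle
 *   Tensor<gpu, 2, float> g = NewTensor<gpu>(Shape2(2, 2), 0.0f, MSHADOW_ALLOC_PAD, s);
 *   g += 1.0f;                                       // queued asynchronously on s
 *   s->Wait();                                       // block until the stream drains
 *   FreeSpace(&g);
 *   DeleteStream(s);
 *   ShutdownTensorEngine<gpu>();
 */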
/*! | |
 * \brief CPU/GPU: allocate space for CTensor, according to the shape in the obj
 *   this function is responsible for setting the stride_ in each obj.shape
 * \param obj the tensor object, with shape specified
 * \param pad whether to pad dimension 0 so that the last dimension is aligned;
 *   padding may help improve the efficiency of matrix multiplications
 *   if true, space is allocated with a stride_ that may not equal shape[0]
 *   if false, contiguous space is allocated
* \tparam dim specify the dim of tensor | |
* \tparam DType type of element in tensor | |
*/ | |
template<int dim, typename DType> | |
inline void AllocSpace(Tensor<cpu, dim, DType> *obj, | |
bool pad = MSHADOW_ALLOC_PAD); | |
/*! | |
 * \brief CPU/GPU: allocate space for CTensor, according to the shape in the obj
 *   this function is responsible for setting the stride_ in each obj.shape
 * \param obj the tensor object, with shape specified
 * \param pad whether to pad dimension 0 so that the last dimension is aligned;
 *   padding may help improve the efficiency of matrix multiplications
 *   if true, space is allocated with a stride_ that may not equal shape[0]
 *   if false, contiguous space is allocated
* \tparam dim specify the dim of tensor | |
* \tparam DType type of element in tensor | |
*/ | |
template<int dim, typename DType> | |
inline void AllocSpace(Tensor<gpu, dim, DType> *obj, | |
bool pad = MSHADOW_ALLOC_PAD); | |
/*! | |
* \brief CPU/GPU: free the space of tensor, will set obj.dptr to NULL | |
* \param obj the tensor object | |
* \tparam dim specify the dim of tensor | |
* \tparam DType type of element in tensor | |
*/ | |
template<int dim, typename DType> | |
inline void FreeSpace(Tensor<cpu, dim, DType> *obj); | |
/*! | |
* \brief CPU/GPU: free the space of tensor, will set obj.dptr to NULL | |
* \param obj the tensor object | |
* \tparam dim specify the dim of tensor | |
* \tparam DType type of element in tensor | |
*/ | |
template<int dim, typename DType> | |
inline void FreeSpace(Tensor<gpu, dim, DType> *obj); | |
/*! | |
* \brief CPU/GPU: short cut to allocate and initialize a Tensor | |
* \param shape: shape of tensor | |
* \param initv: initialization value | |
* \param pad : padding option | |
* \param stream : stream of tensor | |
* \tparam Device device of tensor | |
* \tparam DType type of element in tensor | |
 * \tparam dim dimension of tensor
* \return a new allocated tensor | |
* \sa AllocSpace | |
*/ | |
template<typename Device, typename DType, int dim> | |
inline Tensor<Device, dim, DType> NewTensor(const Shape<dim> &shape, | |
DType initv, | |
bool pad = MSHADOW_ALLOC_PAD, | |
Stream<Device> *stream = NULL); | |
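/*
 * Usage sketch (illustrative only): either allocate explicitly with AllocSpace or use
 * the NewTensor shortcut; both must eventually be paired with FreeSpace.
 *
 *   Tensor<cpu, 2, float> ts(Shape2(5, 4));
 *   AllocSpace(&ts);                      // fills in ts.dptr_ and ts.stride_
 *   ts = 0.0f;                            // element-wise fill
 *   FreeSpace(&ts);
 *
 *   Tensor<cpu, 2, float> t2 = NewTensor<cpu>(Shape2(5, 4), 1.0f);   // allocate + fill
 *   FreeSpace(&t2);
 */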
/*! | |
* \brief copy data from one tensor to another, with same shape | |
* \param dst target tensor | |
* \param src source tensor | |
 * \param stream the stream, when specified, the copy can exhibit asynchronous behavior
* \tparam dim specify the dim of tensor | |
* \tparam DType type of element in tensor | |
*/ | |
template<int dim, typename DType> | |
inline void Copy(Tensor<cpu, dim, DType> dst, | |
const Tensor<cpu, dim, DType> &src, | |
Stream<cpu> *stream = NULL); | |
/*! | |
* \brief copy data from one tensor to another, with same shape | |
* \param dst target tensor | |
* \param src source tensor | |
 * \param stream the stream, when specified, the copy can exhibit asynchronous behavior
* \tparam dim specify the dim of tensor | |
* \tparam DType type of element in tensor | |
*/ | |
template<int dim, typename DType> | |
inline void Copy(Tensor<cpu, dim, DType> dst, | |
const Tensor<gpu, dim, DType> &src, | |
Stream<gpu> *stream = NULL); | |
/*! | |
* \brief copy data from one tensor to another, with same shape | |
* \param dst target tensor | |
* \param src source tensor | |
 * \param stream the stream, when specified, the copy can exhibit asynchronous behavior
* \tparam dim specify the dim of tensor | |
* \tparam DType type of element in tensor | |
*/ | |
template<int dim, typename DType> | |
inline void Copy(Tensor<gpu, dim, DType> dst, | |
const Tensor<cpu, dim, DType> &src, | |
Stream<gpu> *stream = NULL); | |
/*! | |
* \brief copy data from one tensor to another, with same shape | |
* \param dst target tensor | |
* \param src source tensor | |
 * \param stream the stream, when specified, the copy can exhibit asynchronous behavior
* \tparam dim specify the dim of tensor | |
* \tparam DType type of element in tensor | |
*/ | |
template<int dim, typename DType> | |
inline void Copy(Tensor<gpu, dim, DType> dst, | |
const Tensor<gpu, dim, DType> &src, | |
Stream<gpu> *stream = NULL); | |
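/*
 * Usage sketch (illustrative only): Copy expects identically shaped tensors; the GPU
 * overloads additionally take a Stream<gpu>* and may run asynchronously on it.
 *
 *   Tensor<cpu, 2, float> src = NewTensor<cpu>(Shape2(2, 2), 3.0f);
 *   Tensor<cpu, 2, float> dst = NewTensor<cpu>(Shape2(2, 2), 0.0f);
 *   Copy(dst, src);                       // CPU-to-CPU copy
 *   FreeSpace(&src); FreeSpace(&dst);
 */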
/*! | |
* \brief CPU/GPU: normalize softmax: dst[i][j] = exp(energy[i][j]) /(sum_j exp(energy[i][j])) | |
* \param dst destination | |
* \param energy input energy | |
*/ | |
template<typename DType> | |
inline void Softmax(Tensor<cpu, 2, DType> dst, const Tensor<cpu, 2, DType> &energy); | |
/*! | |
* \brief CPU/GPU: normalize softmax: dst[i][j] = exp(energy[i][j]) /(sum_j exp(energy[i][j])) | |
* \param dst destination | |
* \param energy input energy | |
*/ | |
template<typename DType> | |
inline void Softmax(Tensor<gpu, 2, DType> dst, const Tensor<gpu, 2, DType> &energy); | |
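/*
 * Usage sketch (illustrative only): Softmax normalizes each row of a 2-d energy matrix
 * into a probability distribution.
 *
 *   Tensor<cpu, 2, float> energy = NewTensor<cpu>(Shape2(4, 10), 0.0f);
 *   Tensor<cpu, 2, float> prob   = NewTensor<cpu>(Shape2(4, 10), 0.0f);
 *   Softmax(prob, energy);                // prob[i][j] = exp(energy[i][j]) / sum_j exp(energy[i][j])
 *   FreeSpace(&energy); FreeSpace(&prob);
 */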
/*! | |
* \brief CPU/GPU: softmax gradient | |
* \param dst destination | |
* \param src source output | |
* \param label label info | |
*/ | |
template<typename DType> | |
inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst, | |
const Tensor<cpu, 2, DType> &src, | |
const Tensor<cpu, 1, DType> &label); | |
/*! | |
* \brief CPU/GPU: softmax gradient | |
* \param dst destination | |
* \param src source output | |
* \param label label info | |
*/ | |
template<typename DType> | |
inline void SoftmaxGrad(Tensor<gpu, 2, DType> dst, | |
const Tensor<gpu, 2, DType> &src, | |
const Tensor<gpu, 1, DType> &label); | |
/*! | |
* \brief CPU/GPU: Gradient accumulate of embedding matrix. | |
dst[index[i]] += src[i] | |
Called when the featuredim of src is much larger than the batchsize | |
* \param dst destination | |
* \param index index to take | |
* \param src source output | |
*/ | |
template<typename IndexType, typename DType> | |
inline void AddTakeGrad(Tensor<cpu, 2, DType> dst, | |
const Tensor<cpu, 1, IndexType>& index, | |
const Tensor<cpu, 2, DType> &src); | |
/*! | |
* \brief CPU/GPU: Gradient accumulate of embedding matrix. | |
dst[index[i]] += src[i] | |
Called when the featuredim of src is much larger than the batchsize | |
* \param dst destination | |
* \param index index to take | |
* \param src source output | |
*/ | |
template<typename IndexType, typename DType> | |
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst, | |
const Tensor<gpu, 1, IndexType>& index, | |
const Tensor<gpu, 2, DType> &src); | |
/*! | |
* \brief CPU/GPU: Gradient accumulate of embedding matrix. | |
dst[sorted[i]] += src[index[i]] | |
Called when the batchsize of src is larger than the featuredim | |
* \param dst destination | |
* \param sorted the sorted indices | |
* \param index original index of the sorted indices | |
* \param src source output | |
*/ | |
template<typename IndexType, typename DType> | |
inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst, | |
const Tensor<gpu, 1, IndexType>& sorted, | |
const Tensor<cpu, 1, IndexType>& index, | |
const Tensor<cpu, 2, DType> &src); | |
/*! | |
* \brief CPU/GPU: Gradient accumulate of embedding matrix. | |
dst[sorted[i]] += src[index[i]] | |
Called when the batchsize of src is larger than the featuredim | |
* \param dst destination | |
* \param sorted the sorted indices | |
* \param index original index of the sorted indices | |
* \param src source output | |
*/ | |
template<typename IndexType, typename DType> | |
inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst, | |
const Tensor<gpu, 1, IndexType>& sorted, | |
const Tensor<gpu, 1, IndexType>& index, | |
const Tensor<gpu, 2, DType> &src); | |
/*! | |
 * \brief CPU/GPU: Fill specific rows of the destination matrix with values from the source matrix.
dst[index[i]] = src[i] | |
Will use atomicAdd in the inner implementation and the result may not be deterministic. | |
* \param dst destination | |
* \param index the index to accumulate value | |
* \param src source output | |
*/ | |
template<typename IndexType, typename DType> | |
inline void IndexFill(Tensor<cpu, 2, DType> dst, | |
const Tensor<cpu, 1, IndexType>& index, | |
const Tensor<cpu, 2, DType> &src); | |
/*! | |
 * \brief CPU/GPU: Fill specific rows of the destination matrix with values from the source matrix.
dst[index[i]] = src[i] | |
Will use atomicAdd in the inner implementation and the result may not be deterministic. | |
* \param dst destination | |
* \param index the index to accumulate value | |
* \param src source output | |
*/ | |
template<typename IndexType, typename DType> | |
inline void IndexFill(Tensor<gpu, 2, DType> dst, | |
const Tensor<gpu, 1, IndexType>& index, | |
const Tensor<gpu, 2, DType> &src); | |
/*! | |
* \brief CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!) | |
* \param keys the keys to sort | |
* \param values the values that sorts w.r.t the key | |
* \param is_ascend whether to sort key in ascending order | |
*/ | |
template<typename KDType, typename VDType> | |
inline void SortByKey(Tensor<cpu, 1, KDType> keys, Tensor<cpu, 1, VDType> values, | |
bool is_ascend = true); | |
/*! | |
* \brief CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!) | |
* \param keys the keys to sort | |
* \param values the values that sorts w.r.t the key | |
* \param is_ascend whether to sort key in ascending order | |
*/ | |
template<typename KDType, typename VDType> | |
inline void SortByKey(Tensor<gpu, 1, KDType> keys, Tensor<gpu, 1, VDType> values, | |
bool is_ascend = true); | |
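/*
 * Usage sketch (illustrative only): SortByKey reorders both tensors so that the keys
 * become (stably) sorted while each value stays paired with its original key.
 *
 *   Tensor<cpu, 1, float> keys = NewTensor<cpu>(Shape1(3), 0.0f);
 *   Tensor<cpu, 1, int>   vals = NewTensor<cpu>(Shape1(3), 0);
 *   keys[0] = 2.0f; keys[1] = 0.5f; keys[2] = 1.0f;
 *   vals[0] = 0;    vals[1] = 1;    vals[2] = 2;
 *   SortByKey(keys, vals);                // keys -> {0.5, 1, 2}, vals -> {1, 2, 0}
 *   FreeSpace(&keys); FreeSpace(&vals);
 */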
/*! | |
* \brief CPU/GPU: Sort the keys within each segment. (Stable sort is performed!) | |
     Segments are defined as an ascending ordered vector like [0, 0, 0, 1, 1, 2, 3, 3, 3,...]
     We sort the keys labeled 0, 1, 2, 3, etc. separately.
     Currently only sorting in ascending order is supported!
* \param values the data to sort | |
* \param segments segment indicator | |
*/ | |
template<typename Device, typename VDType, typename SDType> | |
inline void VectorizedSort(Tensor<Device, 1, VDType> values, Tensor<Device, 1, SDType> segments); | |
// function declarations to support expression, no need to understand them | |
// these functions do not need to be directly used | |
/*! | |
 * \brief CPU/GPU: map an expression to a tensor, this function calls MapPlan
* \tparam Saver specify storage method | |
* \tparam R specifies the storage type of the tensor | |
* \tparam dim dim of the tensor, during usage, there is no need to specify this parameter | |
* \tparam DType the type of elements in the tensor | |
 * \tparam E specifies the expression type, no need to specify this parameter during usage
* \tparam etype expression type | |
* \param dst destination | |
* \param exp expression | |
 * \sa namespace mshadow::sv, mshadow::op, mshadow::expr
*/ | |
template<typename Saver, typename R, int dim, | |
typename DType, typename E, int etype> | |
inline void MapExp(TRValue<R, cpu, dim, DType> *dst, | |
const expr::Exp<E, DType, etype> &exp); | |
/*! | |
 * \brief CPU/GPU: map an expression to a tensor, this function calls MapPlan
* \tparam Saver specify storage method | |
* \tparam R specifies the storage type of the tensor | |
* \tparam dim dim of the tensor, during usage, there is no need to specify this parameter | |
* \tparam DType the type of elements in the tensor | |
 * \tparam E specifies the expression type, no need to specify this parameter during usage
* \tparam etype expression type | |
* \param dst destination | |
* \param exp expression | |
 * \sa namespace mshadow::sv, mshadow::op, mshadow::expr
*/ | |
template<typename Saver, typename R, int dim, | |
typename DType, typename E, int etype> | |
inline void MapExp(TRValue<R, gpu, dim, DType> *dst, | |
const expr::Exp<E, DType, etype> &exp); | |
/*! | |
 * \brief CPU/GPU: map an expression, do reduction to 1D Tensor in lowest dimension (dimension 0)
* \tparam Saver specify storage method | |
* \tparam Reducer specify a reducer method | |
* \tparam R specifies the storage type of the tensor | |
* \tparam DType the type of elements in the tensor | |
 * \tparam E specifies the expression type, no need to specify this parameter during usage
* \tparam etype expression type | |
* \param dst destination | |
* \param exp expression | |
* \param scale scale the result before save | |
 * \sa namespace mshadow::sv, mshadow::op, mshadow::red, mshadow::expr
*/ | |
template<typename Saver, typename Reducer, | |
typename R, typename DType, typename E, int etype> | |
inline void MapReduceKeepLowest(TRValue<R, cpu, 1, DType> *dst, | |
const expr::Exp<E, DType, etype> &exp, | |
DType scale = 1); | |
/*! | |
 * \brief CPU/GPU: map an expression, do reduction to 1D Tensor in lowest dimension (dimension 0)
* \tparam Saver specify storage method | |
* \tparam Reducer specify a reducer method | |
* \tparam R specifies the storage type of the tensor | |
* \tparam DType the type of elements in the tensor | |
 * \tparam E specifies the expression type, no need to specify this parameter during usage
* \tparam etype expression type | |
* \param dst destination | |
* \param exp expression | |
* \param scale scale the result before save | |
 * \sa namespace mshadow::sv, mshadow::op, mshadow::red, mshadow::expr
*/ | |
template<typename Saver, typename Reducer, typename R, | |
typename DType, typename E, int etype> | |
inline void MapReduceKeepLowest(TRValue<R, gpu, 1, DType> *dst, | |
const expr::Exp<E, DType, etype> &exp, | |
DType scale = 1); | |
/*! | |
 * \brief CPU/GPU: map an expression, do reduction to 1D Tensor in third dimension (dimension 2)
* \tparam Saver specify storage method | |
* \tparam Reducer specify a reducer method | |
* \tparam R specifies the storage type of the tensor | |
* \tparam DType the type of elements in the tensor | |
* \tparam dimkeep the target dimension to be kept, should be larger than 0, for 0, use MapReduceKeepLowest | |
 * \tparam E specifies the expression type, no need to specify this parameter during usage
* \tparam etype expression type | |
* \param dst destination | |
* \param exp expression | |
* \param scale scale the result before save | |
 * \sa namespace mshadow::sv, mshadow::op, mshadow::red, mshadow::expr
*/ | |
template<typename Saver, typename Reducer, int dimkeep, | |
typename R, typename DType, typename E, int etype> | |
inline void MapReduceKeepHighDim(TRValue<R, cpu, 1, DType> *dst, | |
const expr::Exp<E, DType, etype> &exp, | |
DType scale = 1); | |
/*! | |
 * \brief CPU/GPU: map an expression, do reduction to 1D Tensor in third dimension (dimension 2)
* \tparam Saver specify storage method | |
* \tparam Reducer specify a reducer method | |
* \tparam R specifies the storage type of the tensor | |
* \tparam DType the type of elements in the tensor | |
* \tparam dimkeep the target dimension to be kept, should be larger than 0, for 0, use MapReduceKeepLowest | |
 * \tparam E specifies the expression type, no need to specify this parameter during usage
* \tparam etype expression type | |
* \param dst destination | |
* \param exp expression | |
* \param scale scale the result before save | |
 * \sa namespace mshadow::sv, mshadow::op, mshadow::red, mshadow::expr
*/ | |
template<typename Saver, typename Reducer, int dimkeep, | |
typename R, typename DType, typename E, int etype> | |
inline void MapReduceKeepHighDim(TRValue<R, gpu, 1, DType> *dst, | |
const expr::Exp<E, DType, etype> &exp, | |
DType scale = 1); | |
/*! | |
* \brief CPU/GPU: 1 dimension vector dot | |
* \param dst Length 1 vector, used to hold the result. | |
* \param lhs Left operand vector | |
* \param rhs Right operand vector | |
*/ | |
template<typename Device, typename DType> | |
inline void VectorDot(Tensor<Device, 1, DType> dst, | |
const Tensor<Device, 1, DType> &lhs, | |
const Tensor<Device, 1, DType> &rhs); | |
/*! | |
* \brief CPU/GPU: dst = alpha * op(lhs) op(rhs) + beta * dst | |
* \param dst Length 3 tensor, used to hold the result | |
* \param lhs Left operand vector | |
* \param rhs Right operand vector | |
* \param alpha multiplier of op(lhs)op(rhs) | |
* \param beta multiplier of dst | |
* \param workspace Workspace for casting DType* to DType** (batched-view), must have size >= 3 * batch_size | |
*/ | |
template<bool transpose_left, bool transpose_right, typename Device, typename DType> | |
inline void BatchGEMM(Tensor<Device, 3, DType> dst, | |
const Tensor<Device, 3, DType> &lhs, | |
const Tensor<Device, 3, DType> &rhs, | |
DType alpha, | |
DType beta, | |
Tensor<Device, 1, DType*> workspace); | |
} // namespace mshadow | |
// include headers | |
//===== EXPANDING: ../mshadow/mshadow/stream_gpu-inl.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file stream_gpu-inl.h | |
* \brief implementation of GPU code | |
* \author Bing Xu, Tianqi Chen | |
*/ | |
#ifndef MSHADOW_STREAM_GPU_INL_H_ | |
#define MSHADOW_STREAM_GPU_INL_H_ | |
namespace mshadow { | |
#if MSHADOW_USE_CUDA == 1 | |
// Stream allocation
// actual implementation of GPU stream in CUDA | |
template<> | |
struct Stream<gpu> { | |
/*! \brief handle state */ | |
enum HandleState { | |
NoHandle = 0, | |
OwnHandle = 1, | |
}; | |
/*! \brief cudaStream */ | |
cudaStream_t stream_; | |
/*! \brief cublas handle */ | |
cublasHandle_t blas_handle_; | |
/*! \brief cudnn handle */ | |
#if MSHADOW_USE_CUDNN == 1 | |
cudnnHandle_t dnn_handle_; | |
#endif | |
/*! \brief cublas handle ownership */ | |
HandleState blas_handle_ownership_; | |
/*! \brief cudnn handle ownership */ | |
HandleState dnn_handle_ownership_; | |
Stream(void) : stream_(0), | |
blas_handle_ownership_(NoHandle), | |
dnn_handle_ownership_(NoHandle) {} | |
/*! | |
* \brief wait for all the computation associated | |
* with this stream to complete | |
*/ | |
inline void Wait(void) { | |
MSHADOW_CUDA_CALL(cudaStreamSynchronize(stream_)); | |
} | |
/*! | |
   * \brief query whether the stream is idle
   * \return true if the stream is idle and all the jobs have been completed
*/ | |
inline bool CheckIdle(void) { | |
cudaError_t err = cudaStreamQuery(stream_); | |
if (err == cudaSuccess) return true; | |
if (err == cudaErrorNotReady) return false; | |
LOG(FATAL) << cudaGetErrorString(err); | |
return false; | |
} | |
/*! | |
* \brief returns actual cudaStream_t given an input GPU stream pointer | |
* \param stream pointer to GPU stream | |
*/ | |
inline static cudaStream_t GetStream(Stream<gpu> *stream) { | |
if (stream == NULL) { | |
#if MSHADOW_FORCE_STREAM | |
LOG(FATAL) << "Default GPU stream was used when MSHADOW_FORCE_STREAM was on"; | |
#endif | |
return 0; | |
} else { | |
return stream->stream_; | |
} | |
} | |
/*! | |
* \brief return actual cublasHandle | |
   * \param stream pointer to GPU stream
*/ | |
inline static cublasHandle_t GetBlasHandle(Stream<gpu> *stream) { | |
if (stream == NULL) { | |
return 0; | |
} else { | |
CHECK_NE(stream->blas_handle_ownership_, NoHandle) | |
<< "No handle exist in source stream"; | |
return stream->blas_handle_; | |
} | |
} | |
  /*! \brief Destroy the cublas handle if we own it */
inline void DestoryBlasHandle() { | |
if (blas_handle_ownership_ == OwnHandle) { | |
cublasStatus_t err = cublasDestroy(blas_handle_); | |
blas_handle_ownership_ = NoHandle; | |
      CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Destroy cublas handle failed";
} | |
} | |
  /*! \brief Destroy the original blas handle and create a new one */
inline void CreateBlasHandle() { | |
this->DestoryBlasHandle(); | |
cublasStatus_t err = cublasCreate(&blas_handle_); | |
blas_handle_ownership_ = OwnHandle; | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Create cublas handle failed"; | |
} | |
// #if MSHADOW_USE_CUDNN && defined(__CUDACC__) | |
#if MSHADOW_USE_CUDNN == 1 | |
inline static cudnnHandle_t GetDnnHandle(Stream<gpu> *stream) { | |
if (stream == NULL) { | |
return 0; | |
} else { | |
CHECK_NE(stream->dnn_handle_ownership_, NoHandle) << "No handle exist in source stream"; | |
return stream->dnn_handle_; | |
} | |
} | |
#endif | |
inline void DestroyDnnHandle() { | |
// #if MSHADOW_USE_CUDNN && defined(__CUDACC__) | |
#if MSHADOW_USE_CUDNN == 1 | |
if (dnn_handle_ownership_ == OwnHandle) { | |
cudnnStatus_t err = cudnnDestroy(dnn_handle_); | |
CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err); | |
} | |
#endif | |
} | |
inline void CreateDnnHandle() { | |
// #if MSHADOW_USE_CUDNN == 1 && defined(__CUDACC__) | |
#if MSHADOW_USE_CUDNN == 1 | |
this->DestroyDnnHandle(); | |
cudnnStatus_t err = cudnnCreate(&dnn_handle_); | |
CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err); | |
err = cudnnSetStream(dnn_handle_, stream_); | |
CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err); | |
this->dnn_handle_ownership_ = OwnHandle; | |
#endif | |
} | |
}; | |
template<> | |
inline Stream<gpu> *NewStream<gpu>(bool create_blas_handle, | |
bool create_dnn_handle) { | |
Stream<gpu> *st = new Stream<gpu>(); | |
MSHADOW_CUDA_CALL(cudaStreamCreate(&st->stream_)); | |
if (create_blas_handle) { | |
st->CreateBlasHandle(); | |
} | |
if (create_dnn_handle) { | |
st->CreateDnnHandle(); | |
} | |
return st; | |
} | |
template<> | |
inline void DeleteStream<gpu>(Stream<gpu> *stream) { | |
MSHADOW_CUDA_CALL(cudaStreamDestroy(stream->stream_)); | |
stream->DestoryBlasHandle(); | |
stream->DestroyDnnHandle(); | |
delete stream; | |
} | |
#endif | |
} // namespace mshadow | |
#endif // MSHADOW_STREAM_GPU_INL_H_ | |
//===== EXPANDED: ../mshadow/mshadow/stream_gpu-inl.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension.h ===== | |
/*! | |
* Copyright by Contributors | |
* \file extension.h | |
* \brief some extension of expressions, | |
* used to support something beyond elementwise op | |
* \author Tianqi Chen, Bing Xu | |
*/ | |
#ifndef MSHADOW_EXTENSION_H_ | |
#define MSHADOW_EXTENSION_H_ | |
//===== EXPANDING: ../mshadow/mshadow/expr_engine-inl.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file expr_engine-inl.h | |
* \brief definitions of how expressions should be evaluated | |
* \author Tianqi Chen, Bing Xu | |
*/ | |
#ifndef MSHADOW_EXPR_ENGINE_INL_H_ | |
#define MSHADOW_EXPR_ENGINE_INL_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
 * \brief a general class that allows extensions that make tensors of some shape
* \tparam SubType type of subclass | |
* \tparam SrcExp source expression of the MakeTensorExp, the source of operation | |
* \tparam dim dimension of the expression | |
* \tparam DType the type of elements | |
*/ | |
template<typename SubType, typename SrcExp, int dim, typename DType> | |
struct MakeTensorExp | |
: public Exp<MakeTensorExp<SubType, SrcExp, dim, DType>, | |
DType, type::kChainer> { | |
/*! \brief the shape of this expression */ | |
Shape<dim> shape_; | |
/*! \brief true self of subtype */ | |
inline const SubType& real_self(void) const{ | |
return *static_cast<const SubType*>(this); | |
} | |
}; | |
//---------------------------------------------------------------------- | |
// This part of code gives plan that can be used to carry out execution | |
//--------------------------------------------------------------------- | |
// Declarations of plans | |
template<typename ExpType, typename DType> | |
class Plan { | |
public: | |
/*! | |
* \brief evaluate the expression at index [y][x] | |
* to be implemented by SubType, for RValue, the return type will be DType & | |
*/ | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const; | |
}; | |
// tensor plan | |
template <typename Device, int dim, typename DType> | |
class Plan<Tensor<Device, dim, DType>, DType> { | |
public: | |
explicit Plan(const Tensor<Device, dim, DType> &t) | |
: dptr_(t.dptr_), stride_(t.stride_) {} | |
// for RValue, the return type should be reference | |
MSHADOW_XINLINE DType &REval(index_t y, index_t x) { | |
return dptr_[y * stride_ + x]; | |
} | |
// const evaluation | |
MSHADOW_XINLINE const DType &Eval(index_t y, index_t x) const { | |
return dptr_[y * stride_ + x]; | |
} | |
private: | |
DType *dptr_; | |
index_t stride_; | |
}; | |
// special evaluation case for 1d tensor, no stride | |
template <typename Device, typename DType> | |
class Plan<Tensor<Device, 1, DType>, DType> { | |
public: | |
explicit Plan(const Tensor<Device, 1, DType> &t) : dptr_(t.dptr_) {} | |
MSHADOW_XINLINE DType &REval(index_t y, index_t x) { | |
return dptr_[x]; | |
} | |
MSHADOW_XINLINE const DType &Eval(index_t y, index_t x) const { | |
return dptr_[x]; | |
} | |
private: | |
DType *dptr_; | |
}; | |
// scalar | |
template<typename DType> | |
class Plan<ScalarExp<DType>, DType> { | |
public: | |
explicit Plan(DType scalar) : scalar_(scalar) {} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
return scalar_; | |
} | |
private: | |
DType scalar_; | |
}; | |
// unary expression | |
template<typename DstDType, typename SrcDType, | |
typename EType, int etype> | |
class Plan<TypecastExp<DstDType, SrcDType, EType, etype>, DstDType> { | |
public: | |
explicit Plan(const Plan<EType, SrcDType> &src) : src_(src) {} | |
MSHADOW_XINLINE DstDType Eval(index_t y, index_t x) const { | |
return DstDType(src_.Eval(y, x)); // NOLINT(*) | |
} | |
private: | |
Plan<EType, SrcDType> src_; | |
}; | |
// ternary expression | |
template<typename OP, typename TA, typename TB, typename TC, int etype, typename DType> | |
class Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>, DType> { | |
public: | |
explicit Plan(const Plan<TA, DType> &item1, const Plan<TB, DType> &item2, | |
const Plan<TC, DType> &item3) | |
: item1_(item1), item2_(item2), item3_(item3) {} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
return OP::Map(item1_.Eval(y, x), item2_.Eval(y, x), item3_.Eval(y, x)); | |
} | |
private: | |
Plan<TA, DType> item1_; | |
Plan<TB, DType> item2_; | |
Plan<TC, DType> item3_; | |
}; | |
// binary expression | |
template<typename OP, typename TA, typename TB, int etype, typename DType> | |
class Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType> { | |
public: | |
explicit Plan(const Plan<TA, DType> &lhs, const Plan<TB, DType> &rhs) | |
: lhs_(lhs), rhs_(rhs) {} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
return OP::Map(lhs_.Eval(y, x), rhs_.Eval(y, x)); | |
} | |
private: | |
Plan<TA, DType> lhs_; | |
Plan<TB, DType> rhs_; | |
}; | |
// unary expression | |
template<typename OP, typename TA, int etype, typename DType> | |
class Plan<UnaryMapExp<OP, TA, DType, etype>, DType> { | |
public: | |
explicit Plan(const Plan<TA, DType> &src) : src_(src) {} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
return OP::Map(src_.Eval(y, x)); | |
} | |
private: | |
Plan<TA, DType> src_; | |
}; | |
// remaps map tensor expression to subtype's plan | |
template<typename SubType, typename SrcExp, int dim, typename DType> | |
struct Plan<MakeTensorExp<SubType, SrcExp, dim, DType>, DType> { | |
public: | |
Plan(const Plan<SubType, DType> &src) : src_(src) {} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
return src_.Eval(y, x); | |
} | |
private: | |
Plan<SubType, DType> src_; | |
}; | |
// transpose | |
template<typename EType, typename DType> | |
class Plan<TransposeExp<EType, DType>, DType> { | |
public: | |
explicit Plan(const Plan<EType, DType> &src) : src_(src) {} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
return src_.Eval(x, y); | |
} | |
private: | |
Plan<EType, DType> src_; | |
}; | |
//---------------------------------------------------------------------- | |
// Mappings from expression to plans | |
//--------------------------------------------------------------------- | |
template<typename OP, typename TA, typename TB, typename DType, int etype> | |
inline Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType> | |
MakePlan(const BinaryMapExp<OP, TA, TB, DType, etype> &e); | |
template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype> | |
inline Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>, DType> | |
MakePlan(const TernaryMapExp<OP, TA, TB, TC, DType, etype> &e); | |
template<typename DType> | |
inline Plan<ScalarExp<DType>, DType> MakePlan(const ScalarExp<DType> &e) { | |
return Plan<ScalarExp<DType>, DType>(e.scalar_); | |
} | |
template<typename DstDType, typename SrcDType, typename EType, int etype> | |
inline Plan<TypecastExp<DstDType, SrcDType, EType, etype>, DstDType> | |
MakePlan(const TypecastExp<DstDType, SrcDType, EType, etype> &e) { | |
return Plan<TypecastExp<DstDType, SrcDType, EType, etype>, DstDType>(MakePlan(e.exp)); | |
} | |
template<typename T, typename DType> | |
inline Plan<T, DType> MakePlan(const RValueExp<T, DType> &e) { | |
return Plan<T, DType>(e.self()); | |
} | |
template<typename T, typename DType> | |
inline Plan<TransposeExp<T, DType>, DType> | |
MakePlan(const TransposeExp<T, DType> &e) { | |
return Plan<TransposeExp<T, DType>, DType>(MakePlan(e.exp)); | |
} | |
template<typename T, typename SrcExp, int dim, typename DType> | |
inline Plan<T, DType> | |
MakePlan(const MakeTensorExp<T, SrcExp, dim, DType> &e) { | |
return Plan<T, DType>(e.real_self()); | |
} | |
template<typename OP, typename TA, typename DType, int etype> | |
inline Plan<UnaryMapExp<OP, TA, DType, etype>, DType> | |
MakePlan(const UnaryMapExp<OP, TA, DType, etype> &e) { | |
return Plan<UnaryMapExp<OP, TA, DType, etype>, DType>(MakePlan(e.src_)); | |
} | |
template<typename OP, typename TA, typename TB, typename DType, int etype> | |
inline Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType> | |
MakePlan(const BinaryMapExp<OP, TA, TB, DType, etype> &e) { | |
return Plan<BinaryMapExp<OP, TA, TB, DType, etype>, | |
DType>(MakePlan(e.lhs_), MakePlan(e.rhs_)); | |
} | |
// Ternary | |
template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype> | |
inline Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>, DType> | |
MakePlan(const TernaryMapExp<OP, TA, TB, TC, DType, etype> &e) { | |
return Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>, | |
DType>(MakePlan(e.item1_), MakePlan(e.item2_), MakePlan(e.item3_)); | |
} | |
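// Evaluation sketch (illustration only): MakePlan turns an expression tree | |
// into a mirrored tree of Plan objects, and the assignment loop then calls | |
// Eval(y, x) on the root plan, which recurses into its children.  For an | |
// expression such as lhs + rhs * 2.0f the composite plan behaves like | |
// | |
//   // plan.Eval(y, x) == op::plus::Map(lhs[y][x], | |
//   //                                  op::mul::Map(rhs[y][x], 2.0f)) | |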
//---------------------------------------------------------------- | |
// Static Type inference and Type Checking | |
//---------------------------------------------------------------- | |
/*! | |
* \brief static type inference template, | |
* used to get the dimension of each expression, | |
* if ExpInfo<E>::kDim == -1, it means there is a mismatch in the expression | |
* if (ExpInfo<E>::kDevMask & cpu::kDevMask) != 0, this means this expression can be assigned to cpu | |
* \tparam E expression | |
*/ | |
template<typename E> | |
struct ExpInfo { | |
static const int kDim = -1; | |
static const int kDevMask = 0; | |
}; | |
template<typename DType> | |
struct ExpInfo< ScalarExp<DType> > { | |
static const int kDim = 0; | |
static const int kDevMask = 0xffff; | |
}; | |
template<typename E, typename DType> | |
struct ExpInfo<TransposeExp<E, DType> > { | |
static const int kDim = ExpInfo<E>::kDim; | |
static const int kDevMask = ExpInfo<E>::kDevMask; | |
}; | |
template<typename DstDType, typename SrcDType, typename EType, int etype> | |
struct ExpInfo<TypecastExp<DstDType, SrcDType, EType, etype> > { | |
static const int kDim = ExpInfo<EType>::kDim; | |
static const int kDevMask = ExpInfo<EType>::kDevMask; | |
}; | |
template<typename Device, int dim, typename DType> | |
struct ExpInfo<Tensor<Device, dim, DType> > { | |
static const int kDim = dim; | |
static const int kDevMask = Device::kDevMask; | |
}; | |
template<typename T, typename SrcExp, int dim, typename DType> | |
struct ExpInfo<MakeTensorExp<T, SrcExp, dim, DType> > { | |
static const int kDimSrc = ExpInfo<SrcExp>::kDim; | |
static const int kDim = kDimSrc >= 0 ? dim : -1; | |
static const int kDevMask = ExpInfo<SrcExp>::kDevMask; | |
}; | |
template<typename OP, typename TA, typename DType, int etype> | |
struct ExpInfo<UnaryMapExp<OP, TA, DType, etype> > { | |
static const int kDim = ExpInfo<TA>::kDim; | |
static const int kDevMask = ExpInfo<TA>::kDevMask; | |
}; | |
template<typename OP, typename TA, typename TB, typename DType, int etype> | |
struct ExpInfo<BinaryMapExp<OP, TA, TB, DType, etype> > { | |
static const int kDimLhs = ExpInfo<TA>::kDim; | |
static const int kDimRhs = ExpInfo<TB>::kDim; | |
static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ?\ | |
(kDimLhs == 0 ?\ | |
kDimRhs :\ | |
((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1; | |
static const int kDevMask = ExpInfo<TA>::kDevMask & ExpInfo<TB>::kDevMask; | |
}; | |
template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype> | |
struct ExpInfo<TernaryMapExp<OP, TA, TB, TC, DType, etype> > { | |
static const int kDimItem1 = ExpInfo<TA>::kDim; | |
static const int kDimItem2 = ExpInfo<TB>::kDim; | |
static const int kDimItem3 = ExpInfo<TC>::kDim; | |
static const int kDim = kDimItem1; | |
static const int kDevMask = ExpInfo<TA>::kDevMask & ExpInfo<TB>::kDevMask & ExpInfo<TC>::kDevMask; | |
}; | |
/*! \brief template to do type check */ | |
template<typename Device, int dim, typename DType, typename E> | |
struct TypeCheck { | |
/*! \brief dimension of expression*/ | |
static const int kExpDim = ExpInfo<E>::kDim; | |
/*! \brief whether the expression device type matches */ | |
static const bool kDevPass = (ExpInfo<E>::kDevMask & Device::kDevMask) != 0; | |
/*! \brief whether the expression can be mapped to expression of dim */ | |
static const bool kMapPass = (kExpDim == 0 || kExpDim == dim) && kDevPass; | |
/*! \brief whether the expression can be reduced to expression of dim */ | |
static const bool kRedPass = (kExpDim > dim) && kDevPass; | |
}; | |
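// Compile-time inference sketch (illustration only): for a 2-d CPU float | |
// tensor the traits above resolve roughly as | |
// | |
//   // ExpInfo<Tensor<cpu, 2, float> >::kDim == 2 | |
//   // (ExpInfo<Tensor<cpu, 2, float> >::kDevMask & cpu::kDevMask) != 0 | |
//   // TypeCheck<cpu, 2, float, Tensor<cpu, 2, float> >::kMapPass == true | |
//   // TypeCheck<cpu, 1, float, Tensor<cpu, 2, float> >::kRedPass == true | |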
/*! \brief used to help static type check*/ | |
template<bool kPass> | |
struct TypeCheckPass; | |
// TODO: add static_assert using C++11 | |
template<> | |
struct TypeCheckPass<false> {}; | |
template<> | |
struct TypeCheckPass<true> { | |
inline static void Error_All_Tensor_in_Exp_Must_Have_Same_Type(void) {} | |
inline static void Error_TypeCheck_Not_Pass_For_Reduce_Exp(void) {} | |
inline static void Error_Expression_Does_Not_Meet_Dimension_Req(void) {} | |
}; | |
//---------------------------------------------------------------- | |
// Runtime Stream Getting | |
//---------------------------------------------------------------- | |
template<typename Device, typename E> | |
struct StreamInfo { | |
inline static Stream<Device> *Get(const E &t); | |
}; | |
template<int dim, typename Device, typename DType> | |
struct StreamInfo<Device, Tensor<Device, dim, DType> > { | |
inline static Stream<Device> *Get(const Tensor<Device, dim, DType> &t) { | |
return t.stream_; | |
} | |
}; | |
//---------------------------------------------------------------- | |
// Runtime Shape Checking | |
//---------------------------------------------------------------- | |
/*! | |
* \brief runtime shape checking template | |
* get the shape of an expression, report error if shape mismatch | |
* \tparam dim the dimension of the shape | |
* \tparam E expression | |
*/ | |
template<int dim, typename E> | |
struct ShapeCheck { | |
inline static Shape<dim> Check(const E &t); | |
}; | |
template<int dim, typename DType> | |
struct ShapeCheck<dim, ScalarExp<DType> > { | |
inline static Shape<dim> Check(const ScalarExp<DType> &exp) { | |
// use lowest dimension to mark scalar exp | |
Shape<dim> shape; | |
for (int i = 0; i < dim; ++i) { | |
shape[i] = 0; | |
} | |
return shape; | |
} | |
}; | |
template<int dim, typename DstDType, typename SrcDType, typename EType, int etype> | |
struct ShapeCheck<dim, TypecastExp<DstDType, SrcDType, EType, etype> > { | |
inline static Shape<dim> | |
Check(const TypecastExp<DstDType, SrcDType, EType, etype> &exp) { | |
return ShapeCheck<dim, EType>::Check(exp.exp); | |
} | |
}; | |
template<int dim, typename E, typename DType> | |
struct ShapeCheck<dim, TransposeExp<E, DType> > { | |
inline static Shape<dim> Check(const TransposeExp<E, DType> &e) { | |
// swap the lowest two dimensions | |
Shape<dim> s = ShapeCheck<dim, E>::Check(e.exp); | |
std::swap(s[0], s[1]); | |
return s; | |
} | |
}; | |
template<int dim, typename Device, typename DType> | |
struct ShapeCheck<dim, Tensor<Device, dim, DType> > { | |
inline static Shape<dim> Check(const Tensor<Device, dim, DType> &t) { | |
return t.shape_; | |
} | |
}; | |
template<int dim, typename SrcExp, typename T, typename DType> | |
struct ShapeCheck<dim, MakeTensorExp<T, SrcExp, dim, DType> > { | |
inline static Shape<dim> | |
Check(const MakeTensorExp<T, SrcExp, dim, DType> &t) { | |
return t.shape_; | |
} | |
}; | |
template<int dim, typename OP, typename TA, typename DType, int etype> | |
struct ShapeCheck<dim, UnaryMapExp<OP, TA, DType, etype> > { | |
inline static Shape<dim> Check(const UnaryMapExp<OP, TA, DType, etype> &t) { | |
Shape<dim> s = ShapeCheck<dim, TA>::Check(t.src_); | |
return s; | |
} | |
}; | |
template<int dim, typename OP, typename TA, typename TB, | |
typename DType, int etype> | |
struct ShapeCheck<dim, BinaryMapExp<OP, TA, TB, DType, etype> > { | |
inline static Shape<dim> | |
Check(const BinaryMapExp<OP, TA, TB, DType, etype> &t) { | |
Shape<dim> shape1 = ShapeCheck<dim, TA>::Check(t.lhs_); | |
Shape<dim> shape2 = ShapeCheck<dim, TB>::Check(t.rhs_); | |
if (shape1[0] == 0) return shape2; | |
if (shape2[0] == 0) return shape1; | |
CHECK_EQ(shape1, shape2) << "BinaryMapExp: Shapes of operands are not the same, " << | |
"Shape1=" << shape1 << ", Shape2=" << shape2; | |
return shape1; | |
} | |
}; | |
template<int dim, typename OP, typename TA, typename TB, typename TC, | |
typename DType, int etype> | |
struct ShapeCheck<dim, TernaryMapExp<OP, TA, TB, TC, DType, etype> > { | |
inline static Shape<dim> | |
Check(const TernaryMapExp<OP, TA, TB, TC, DType, etype> &t) { | |
Shape<dim> shape1 = ShapeCheck<dim, TA>::Check(t.item1_); | |
Shape<dim> shape2 = ShapeCheck<dim, TB>::Check(t.item2_); | |
Shape<dim> shape3 = ShapeCheck<dim, TC>::Check(t.item3_); | |
bool same = (shape1 == shape2) && (shape2 == shape3); | |
CHECK(same) << "TernaryMapExp: Shapes of operands are not the same, " << | |
"Shape1=" << shape1 << ", Shape2=" << shape2 << ", Shape3=" << shape3; | |
return shape1; | |
} | |
}; | |
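// Runtime check sketch (illustration only): a scalar operand reports an | |
// all-zero shape and simply inherits the shape of the other operand, so | |
// (A + ScalarExp<float>(1.0f)) has the shape of A, while (A + B) requires | |
// A.shape_ == B.shape_ and otherwise triggers the CHECK above. | |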
} // namespace expr | |
} // namespace mshadow | |
// include definition of dot engine | |
//===== EXPANDING: ../mshadow/mshadow/dot_engine-inl.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file dot_engine-inl.h | |
* \brief definitions of how Matrix Multiplications can be evaluated | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_DOT_ENGINE_INL_H_ | |
#define MSHADOW_DOT_ENGINE_INL_H_ | |
//===== EXPANDING: ../mshadow/mshadow/extension/implicit_gemm.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file implicit_gemm.h | |
* \brief support for implicit GEMM operation | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_EXTENSION_IMPLICIT_GEMM_H_ | |
#define MSHADOW_EXTENSION_IMPLICIT_GEMM_H_ | |
//===== EXPANDING: ../mshadow/mshadow/packet-inl.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file packet-inl.h | |
* \brief Generic packet vectorization code | |
*/ | |
#ifndef MSHADOW_PACKET_INL_H_ | |
#define MSHADOW_PACKET_INL_H_ | |
#ifdef __APPLE__ | |
#else | |
#endif | |
namespace mshadow { | |
/*! \brief namespace of packet math*/ | |
namespace packet { | |
enum PacketArch { | |
kPlain, | |
kSSE2, | |
}; | |
#if MSHADOW_USE_SSE | |
#define MSHADOW_DEFAULT_PACKET ::mshadow::packet::kSSE2 | |
#else | |
#define MSHADOW_DEFAULT_PACKET ::mshadow::packet::kPlain | |
#endif | |
// whether packet operator is enabled. | |
/*! | |
* \brief Generic packet type | |
* \tparam DType The data type of the packet. | |
* \tparam Arch the Arch of the packet. | |
*/ | |
template<typename DType, PacketArch Arch = MSHADOW_DEFAULT_PACKET> | |
struct Packet; | |
template<PacketArch Arch> | |
struct AlignBytes { | |
static const index_t value = 4; | |
}; | |
} // namespace packet | |
} // namespace mshadow | |
namespace mshadow { | |
namespace packet { | |
/*! | |
* \brief analog to cudaMallocPitch, allocate an aligned space with num_line * lspace cells | |
* \param out_pitch output parameter, the actual space allocated for each line | |
* \param lspace number of cells required for each line | |
* \param num_line number of lines to be allocated | |
*/ | |
inline void* AlignedMallocPitch(size_t *out_pitch, | |
size_t lspace, | |
size_t num_line) { | |
const index_t bits = AlignBytes<MSHADOW_DEFAULT_PACKET>::value; | |
const index_t mask = (1 << bits) - 1; | |
size_t pitch = ((lspace + mask) >> bits) << bits; | |
*out_pitch = pitch; | |
#ifdef _MSC_VER | |
void *res = _aligned_malloc(pitch * num_line, 1 << bits); | |
#else | |
void *res; | |
int ret = posix_memalign(&res, 1 << bits, pitch * num_line); | |
CHECK_EQ(ret, 0) << "AlignedMallocPitch failed"; | |
#endif | |
if (res == NULL) { | |
LOG(FATAL) << "AlignedMallocPitch failed"; | |
} | |
return res; | |
} | |
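// Usage sketch (hypothetical sizes): allocate 10 lines of 100 floats each, | |
// with every line padded up to the packet alignment, then release it with | |
// AlignedFree. | |
// | |
//   // size_t pitch; | |
//   // void *buf = AlignedMallocPitch(&pitch, 100 * sizeof(float), 10); | |
//   // char *line_i = static_cast<char*>(buf) + i * pitch;  // i-th line | |
//   // AlignedFree(buf); | |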
/*! | |
* \brief free aligned space | |
* \param ptr pointer to space to be freed | |
*/ | |
inline void AlignedFree(void *ptr) { | |
#ifdef _MSC_VER | |
_aligned_free(ptr); | |
#else | |
free(ptr); | |
#endif | |
} | |
/*! \brief check if a pitch (size in bytes) is aligned */ | |
template<PacketArch Arch> | |
inline bool CheckAlign(size_t pitch) { | |
const index_t bits = AlignBytes<Arch>::value; | |
return !(pitch & ((1 << bits) - 1)); | |
} | |
/*! \brief check if a pointer is aligned */ | |
template<PacketArch Arch> | |
inline bool CheckAlign(void *ptr) { | |
return CheckAlign<Arch>(reinterpret_cast<size_t>(ptr)); | |
} | |
/*! | |
* \brief get the upper bound of an aligned index for an array of the given size | |
* \param size size of the array | |
* \tparam DType data type, whose size determines the alignment granularity in elements | |
*/ | |
template<typename DType, PacketArch Arch> | |
inline index_t UpperAlign(index_t size) { | |
const index_t bits = AlignBytes<MSHADOW_DEFAULT_PACKET>::value; | |
const index_t mask = (1 << bits) - 1; | |
const index_t fsize = sizeof(DType); | |
return (((size * fsize + mask) >> bits) << bits) / fsize; | |
} | |
/*! | |
* \brief get the lower bound of an aligned index for an array of the given size | |
* \param size size of the array | |
* \tparam DType data type, whose size determines the alignment granularity in elements | |
*/ | |
template<typename DType, PacketArch Arch> | |
inline index_t LowerAlign(index_t size) { | |
const index_t bits = AlignBytes<MSHADOW_DEFAULT_PACKET>::value; | |
const index_t fsize = sizeof(DType); | |
return (((size * fsize) >> bits) << bits) / fsize; | |
} | |
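// Worked example (illustration only): with the default 16-byte alignment | |
// (bits == 4) and DType == float, | |
//   // UpperAlign<float, MSHADOW_DEFAULT_PACKET>(10) == 12 | |
//   // LowerAlign<float, MSHADOW_DEFAULT_PACKET>(10) == 8 | |
// i.e. 10 floats are rounded up/down to a multiple of 4 floats per packet. | |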
/*! | |
* \brief generic Packet operator | |
* \tparam OP The operator | |
* \tparam DType The data type | |
* \tparam Arch The architecture. | |
*/ | |
template<typename OP, typename DType, PacketArch Arch> | |
struct PacketOp { | |
static const bool kEnabled = false; | |
}; | |
// specialization of operators | |
template<typename DType, PacketArch Arch> | |
struct PacketOp<op::plus, DType, Arch> { | |
static const bool kEnabled = true; | |
MSHADOW_CINLINE static Packet<DType, Arch> Map(const Packet<DType, Arch>& lhs, | |
const Packet<DType, Arch>& rhs) { | |
return lhs + rhs; | |
} | |
}; | |
template<typename DType, PacketArch Arch> | |
struct PacketOp<op::minus, DType, Arch> { | |
static const bool kEnabled = true; | |
MSHADOW_CINLINE static Packet<DType, Arch> Map(const Packet<DType, Arch>& lhs, | |
const Packet<DType, Arch>& rhs) { | |
return lhs - rhs; | |
} | |
}; | |
template<typename DType, PacketArch Arch> | |
struct PacketOp<op::mul, DType, Arch> { | |
static const bool kEnabled = true; | |
MSHADOW_CINLINE static Packet<DType, Arch> Map(const Packet<DType, Arch>& lhs, | |
const Packet<DType, Arch>& rhs) { | |
return lhs * rhs; | |
} | |
}; | |
template<typename DType, PacketArch Arch> | |
struct PacketOp<op::div, DType, Arch> { | |
static const bool kEnabled = true; | |
MSHADOW_CINLINE static Packet<DType, Arch> Map(const Packet<DType, Arch>& lhs, | |
const Packet<DType, Arch>& rhs) { | |
return lhs / rhs; | |
} | |
}; | |
template<typename DType, PacketArch Arch> | |
struct PacketOp<op::identity, DType, Arch> { | |
static const bool kEnabled = true; | |
MSHADOW_CINLINE static Packet<DType, Arch> Map(const Packet<DType, Arch>& src) { | |
return src; | |
} | |
}; | |
// savers to do storage | |
template<typename SV, typename TFloat, PacketArch Arch> | |
struct Saver{ | |
MSHADOW_CINLINE static void Save(TFloat *dst, const Packet<TFloat, Arch>& src) { | |
Packet<TFloat, Arch> lhs = Packet<TFloat, Arch>::Load(dst); | |
Packet<TFloat, Arch> ans = PacketOp<typename SV::OPType, TFloat, Arch>::Map(lhs, src); | |
ans.Store(dst); | |
} | |
}; | |
template<typename TFloat, PacketArch Arch> | |
struct Saver<sv::saveto, TFloat, Arch> { | |
MSHADOW_CINLINE static void Save(TFloat *dst, const Packet<TFloat, Arch>& src) { | |
src.Store(dst); | |
} | |
}; | |
} // namespace packet | |
} // namespace mshadow | |
//===== EXPANDING: ../mshadow/mshadow/packet/plain-inl.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file plain-inl.h | |
* \brief support of plain packet that use the plain datatype. | |
*/ | |
#ifndef MSHADOW_PACKET_PLAIN_INL_H_ | |
#define MSHADOW_PACKET_PLAIN_INL_H_ | |
namespace mshadow { | |
namespace packet { | |
template<typename DType> | |
struct Packet<DType, kPlain> { | |
public: | |
/*! \brief number of elements in the packet */ | |
static const index_t kSize = 1; | |
/*! \brief The internal data */ | |
DType data_; | |
// enable default copy constructor | |
Packet(void) {} | |
// constructor from the intrinsic type | |
explicit Packet(DType data) : data_(data) {} | |
// create a fill with the target value s | |
MSHADOW_CINLINE static Packet<DType, kPlain> Fill(DType s) { | |
return Packet<DType, kPlain>(s); | |
} | |
// load from address | |
MSHADOW_CINLINE static Packet<DType, kPlain> Load(const DType* src) { | |
return Packet<DType, kPlain>(*src); | |
} | |
// load from address | |
MSHADOW_CINLINE static Packet<DType, kPlain> LoadUnAligned(const DType* src) { | |
return Packet<DType, kPlain>(*src); | |
} | |
// fill it with value s | |
MSHADOW_CINLINE Packet<DType, kPlain>& operator=(DType s) { | |
data_ = s; | |
return *this; | |
} | |
// store data into dst | |
MSHADOW_CINLINE void Store(DType* dst) const { | |
*dst = data_; | |
} | |
// get the sum of all contents | |
MSHADOW_CINLINE DType Sum() const { | |
return data_; | |
} | |
}; | |
template<typename DType> | |
MSHADOW_CINLINE Packet<DType, kPlain> operator+(const Packet<DType, kPlain>& lhs, | |
const Packet<DType, kPlain>& rhs) { | |
return Packet<DType, kPlain>(lhs.data_ + rhs.data_); | |
} | |
template<typename DType> | |
MSHADOW_CINLINE Packet<DType, kPlain> operator-(const Packet<DType, kPlain>& lhs, | |
const Packet<DType, kPlain>& rhs) { | |
return Packet<DType, kPlain>(lhs.data_ - rhs.data_); | |
} | |
template<typename DType> | |
MSHADOW_CINLINE Packet<DType, kPlain> operator*(const Packet<DType, kPlain>& lhs, | |
const Packet<DType, kPlain>& rhs) { | |
return Packet<DType, kPlain>(lhs.data_ * rhs.data_); | |
} | |
template<typename DType> | |
MSHADOW_CINLINE Packet<DType, kPlain> operator/(const Packet<DType, kPlain>& lhs, | |
const Packet<DType, kPlain>& rhs) { | |
return Packet<DType, kPlain>(lhs.data_ / rhs.data_); | |
} | |
} // namespace packet | |
} // namespace mshadow | |
#endif // MSHADOW_PACKET_PLAIN_INL_H_ | |
//===== EXPANDED: ../mshadow/mshadow/packet/plain-inl.h ===== | |
#if MSHADOW_USE_SSE && !defined(__CUDACC__) | |
//===== EXPANDING: ../mshadow/mshadow/packet/sse-inl.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file sse-inl.h | |
* \brief support of sse2 packet optimization of some operations | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_PACKET_SSE_INL_H_ | |
#define MSHADOW_PACKET_SSE_INL_H_ | |
namespace mshadow { | |
namespace packet { | |
template<> | |
struct Packet<float, kSSE2> { | |
public: | |
/*! \brief number of float in vector */ | |
static const index_t kSize = 4; | |
/*! \brief The internal data */ | |
__m128 data_; | |
// enable default copy constructor | |
Packet(void) {} | |
// constructor from the intrinsic type | |
explicit Packet(__m128 data) : data_(data) {} | |
// create a fill with the target value s | |
MSHADOW_CINLINE static Packet<float, kSSE2> Fill(float s) { | |
return Packet<float, kSSE2>(_mm_set1_ps(s)); | |
} | |
// load from address | |
MSHADOW_CINLINE static Packet<float, kSSE2> Load(const float* src) { | |
return Packet<float, kSSE2>(_mm_load_ps(src)); | |
} | |
// load from address | |
MSHADOW_CINLINE static Packet<float, kSSE2> LoadUnAligned(const float* src) { | |
return Packet<float, kSSE2>(_mm_loadu_ps(src)); | |
} | |
// fill it with value s | |
MSHADOW_CINLINE Packet<float, kSSE2>& operator=(float s) { | |
data_ = _mm_set1_ps(s); | |
return *this; | |
} | |
// store data into dst | |
MSHADOW_CINLINE void Store(float* dst) const { | |
_mm_store_ps(dst, data_); | |
} | |
// get the sum of all contents | |
MSHADOW_CINLINE float Sum() const { | |
__m128 ans = _mm_add_ps(data_, _mm_movehl_ps(data_, data_)); | |
__m128 rst = _mm_add_ss(ans, _mm_shuffle_ps(ans, ans, 1)); | |
#if defined(_MSC_VER) && (_MSC_VER <= 1500) && defined(_WIN64) | |
return rst.m128_f32[0]; | |
#else | |
float rr = _mm_cvtss_f32(rst); | |
return rr; | |
#endif | |
} | |
}; | |
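// Usage sketch (illustration only): four floats are processed per SSE2 packet. | |
// | |
//   // float buf[4] = {1.f, 2.f, 3.f, 4.f}; | |
//   // Packet<float, kSSE2> p = Packet<float, kSSE2>::LoadUnAligned(buf); | |
//   // Packet<float, kSSE2> q = p + Packet<float, kSSE2>::Fill(1.0f); | |
//   // float total = q.Sum();  // == 14.0f | |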
/*! \brief vector real type for double */ | |
template<> | |
struct Packet<double, kSSE2> { | |
/*! \brief number of doubles in the vector */ | |
static const index_t kSize = 2; | |
// internal data | |
__m128d data_; | |
// constructor | |
Packet(void) {} | |
explicit Packet(__m128d data) : data_(data) {} | |
// create a fill with the target value s | |
MSHADOW_CINLINE static Packet<double, kSSE2> Fill(double s) { | |
return Packet<double, kSSE2>(_mm_set1_pd(s)); | |
} | |
// load from address | |
MSHADOW_CINLINE static Packet<double, kSSE2> Load(const double* src) { | |
return Packet<double, kSSE2>(_mm_load_pd(src)); | |
} | |
MSHADOW_CINLINE static Packet<double, kSSE2> LoadUnAligned(const double* src) { | |
return Packet<double, kSSE2>(_mm_loadu_pd(src)); | |
} | |
// fill it with value s | |
MSHADOW_CINLINE Packet<double, kSSE2>& operator=(double s) { | |
data_ = _mm_set1_pd(s); | |
return *this; | |
} | |
// store data into dst | |
MSHADOW_CINLINE void Store(double* dst) const { | |
_mm_store_pd(dst, data_); | |
} | |
// get sum of all content | |
inline double Sum(void) const { | |
__m128d tmp = _mm_add_sd(data_, _mm_unpackhi_pd(data_, data_)); | |
#if defined(_MSC_VER) && (_MSC_VER <= 1500) && defined(_WIN64) | |
return tmp.m128d_f64[0]; | |
#else | |
double ans = _mm_cvtsd_f64(tmp); | |
return ans; | |
#endif | |
} | |
}; | |
MSHADOW_CINLINE Packet<float, kSSE2> operator+(const Packet<float, kSSE2>& lhs, | |
const Packet<float, kSSE2>& rhs) { | |
return Packet<float, kSSE2>(_mm_add_ps(lhs.data_, rhs.data_)); | |
} | |
MSHADOW_CINLINE Packet<double, kSSE2> operator+(const Packet<double, kSSE2>& lhs, | |
const Packet<double, kSSE2>& rhs) { | |
return Packet<double, kSSE2>(_mm_add_pd(lhs.data_, rhs.data_)); | |
} | |
MSHADOW_CINLINE Packet<float, kSSE2> operator-(const Packet<float, kSSE2>& lhs, | |
const Packet<float, kSSE2>& rhs) { | |
return Packet<float, kSSE2>(_mm_sub_ps(lhs.data_, rhs.data_)); | |
} | |
MSHADOW_CINLINE Packet<double, kSSE2> operator-(const Packet<double, kSSE2>& lhs, | |
const Packet<double, kSSE2>& rhs) { | |
return Packet<double, kSSE2>(_mm_sub_pd(lhs.data_, rhs.data_)); | |
} | |
MSHADOW_CINLINE Packet<float, kSSE2> operator*(const Packet<float, kSSE2>& lhs, | |
const Packet<float, kSSE2>& rhs) { | |
return Packet<float, kSSE2>(_mm_mul_ps(lhs.data_, rhs.data_)); | |
} | |
MSHADOW_CINLINE Packet<double, kSSE2> operator*(const Packet<double, kSSE2>& lhs, | |
const Packet<double, kSSE2>& rhs) { | |
return Packet<double, kSSE2>(_mm_mul_pd(lhs.data_, rhs.data_)); | |
} | |
MSHADOW_CINLINE Packet<float, kSSE2> operator/(const Packet<float, kSSE2>& lhs, | |
const Packet<float, kSSE2>& rhs) { | |
return Packet<float, kSSE2>(_mm_div_ps(lhs.data_, rhs.data_)); | |
} | |
MSHADOW_CINLINE Packet<double, kSSE2> operator/(const Packet<double, kSSE2>& lhs, | |
const Packet<double, kSSE2>& rhs) { | |
return Packet<double, kSSE2>(_mm_div_pd(lhs.data_, rhs.data_)); | |
} | |
} // namespace packet | |
} // namespace mshadow | |
#endif // MSHADOW_PACKET_SSE_INL_H_ | |
//===== EXPANDED: ../mshadow/mshadow/packet/sse-inl.h ===== | |
#endif | |
namespace mshadow { | |
namespace expr { | |
typedef packet::PacketArch PacketArch; | |
// same as Plan, but uses packets | |
template<typename ExpType, typename DType, PacketArch Arch> | |
class PacketPlan { | |
public: | |
/*! | |
* \brief evaluate the expression at index [y][x], | |
* x will be aligned to Packet<DType, Arch>::kSize | |
*/ | |
MSHADOW_CINLINE packet::Packet<DType, Arch> EvalPacket(index_t y, index_t x) const; | |
MSHADOW_CINLINE DType Eval(index_t y, index_t x) const; | |
}; | |
template <typename Device, int dim, typename DType, PacketArch Arch> | |
class PacketPlan<Tensor<Device, dim, DType>, DType, Arch> { | |
public: | |
explicit PacketPlan(const Tensor<Device, dim, DType> &t) | |
:dptr_(t.dptr_), stride_(t.stride_) {} | |
MSHADOW_CINLINE packet::Packet<DType, Arch> EvalPacket(index_t y, index_t x) const { | |
return packet::Packet<DType, Arch>::Load(&dptr_[y * stride_ + x]); | |
} | |
MSHADOW_CINLINE DType Eval(index_t y, index_t x) const { | |
return dptr_[y * stride_ + x]; | |
} | |
private: | |
const DType *dptr_; | |
index_t stride_; | |
}; | |
template<typename DType, PacketArch Arch> | |
class PacketPlan<ScalarExp<DType>, DType, Arch> { | |
public: | |
explicit PacketPlan(DType scalar) : scalar_(scalar) {} | |
MSHADOW_CINLINE packet::Packet<DType, Arch> EvalPacket(index_t y, index_t x) const { | |
return packet::Packet<DType, Arch>::Fill(scalar_); | |
} | |
MSHADOW_CINLINE DType Eval(index_t y, index_t x) const { | |
return scalar_; | |
} | |
private: | |
DType scalar_; | |
}; | |
template<typename OP, typename TA, typename TB, int etype, typename DType, PacketArch Arch> | |
class PacketPlan<BinaryMapExp<OP, TA, TB, DType, etype>, DType, Arch> { | |
public: | |
PacketPlan(const PacketPlan<TA, DType, Arch> &lhs, const PacketPlan<TB, DType, Arch> &rhs) | |
: lhs_(lhs), rhs_(rhs) {} | |
MSHADOW_CINLINE packet::Packet<DType, Arch> EvalPacket(index_t y, index_t x) const { | |
return packet::PacketOp<OP, DType, Arch>::Map(lhs_.EvalPacket(y, x), rhs_.EvalPacket(y, x)); | |
} | |
MSHADOW_CINLINE DType Eval(index_t y, index_t x) const { | |
return OP::Map(lhs_.Eval(y, x), rhs_.Eval(y, x)); | |
} | |
private: | |
PacketPlan<TA, DType, Arch> lhs_; | |
PacketPlan<TB, DType, Arch> rhs_; | |
}; | |
template<typename OP, typename TA, int etype, typename DType, PacketArch Arch> | |
class PacketPlan<UnaryMapExp<OP, TA, DType, etype>, DType, Arch> { | |
public: | |
PacketPlan(const PacketPlan<TA, DType, Arch> &src) : src_(src) {} | |
MSHADOW_CINLINE packet::Packet<DType, Arch> EvalPacket(index_t y, index_t x) const { | |
return packet::PacketOp<OP, DType, Arch>::Map(src_.EvalPacket(y, x)); | |
} | |
MSHADOW_CINLINE DType Eval(index_t y, index_t x) const { | |
return OP::Map(src_.Eval(y, x)); | |
} | |
private: | |
PacketPlan<TA, DType, Arch> src_; | |
}; | |
template<PacketArch Arch, typename OP, typename TA, typename TB, typename DType, int etype> | |
inline PacketPlan<BinaryMapExp<OP, TA, TB, DType, etype>, DType, Arch> | |
MakePacketPlan(const BinaryMapExp<OP, TA, TB, DType, etype> &e); | |
template<PacketArch Arch, typename DType> | |
inline PacketPlan<ScalarExp<DType>, DType, Arch> MakePacketPlan(const ScalarExp<DType> &e) { | |
return PacketPlan<ScalarExp<DType>, DType, Arch>(e.scalar_); | |
} | |
template<PacketArch Arch, typename T, typename DType> | |
inline PacketPlan<T, DType, Arch> MakePacketPlan(const RValueExp<T, DType> &e) { | |
return PacketPlan<T, DType, Arch>(e.self()); | |
} | |
template<PacketArch Arch, typename T, int dim, typename DType> | |
inline PacketPlan<T, DType, Arch> | |
MakePacketPlan(const MakeTensorExp<T, cpu, dim, DType> &e) { | |
return PacketPlan<T, DType, Arch>(e.real_self()); | |
} | |
template<PacketArch Arch, typename OP, typename TA, typename DType, int etype> | |
inline PacketPlan<UnaryMapExp<OP, TA, DType, etype>, DType, Arch> | |
MakePacketPlan(const UnaryMapExp<OP, TA, DType, etype> &e) { | |
return PacketPlan<UnaryMapExp<OP, TA, DType, etype>, DType, Arch>(MakePacketPlan<Arch>(e.src_)); | |
} | |
template<PacketArch Arch, typename OP, typename TA, typename TB, typename DType, int etype> | |
inline PacketPlan<BinaryMapExp<OP, TA, TB, DType, etype>, DType, Arch> | |
MakePacketPlan(const BinaryMapExp<OP, TA, TB, DType, etype> &e) { | |
return PacketPlan<BinaryMapExp<OP, TA, TB, DType, etype>, | |
DType, Arch>(MakePacketPlan<Arch>(e.lhs_), MakePacketPlan<Arch>(e.rhs_)); | |
} | |
/*! | |
* \brief static check of whether packet evaluation is enabled for an expression | |
* | |
* \tparam E expression | |
* \tparam Arch the packet architecture | |
*/ | |
template<typename E, PacketArch Arch> | |
struct PacketCheck{ | |
static const bool kPass = false; | |
}; | |
template<PacketArch Arch> | |
struct PacketCheck<float, Arch> { | |
static const bool kPass = true; | |
}; | |
template<PacketArch Arch> | |
struct PacketCheck<double, Arch> { | |
static const bool kPass = true; | |
}; | |
template<typename DType, PacketArch Arch> | |
struct PacketCheck<ScalarExp<DType>, Arch> { | |
static const bool kPass = PacketCheck<DType, Arch>::kPass; | |
}; | |
template<int dim, typename DType, PacketArch Arch> | |
struct PacketCheck<Tensor<cpu, dim, DType>, Arch> { | |
static const bool kPass = PacketCheck<DType, Arch>::kPass; | |
}; | |
template<typename OP, typename TA, typename DType, int etype, PacketArch Arch> | |
struct PacketCheck<UnaryMapExp<OP, TA, DType, etype>, Arch> { | |
static const bool kPass = PacketCheck<TA, Arch>::kPass && | |
packet::PacketOp<OP, DType, Arch>::kEnabled; | |
}; | |
template<typename OP, typename TA, typename TB, typename DType, int etype, PacketArch Arch> | |
struct PacketCheck< BinaryMapExp<OP, TA, TB, DType, etype>, Arch> { | |
static const bool kPass = packet::PacketOp<OP, DType, Arch>::kEnabled && | |
PacketCheck<TA, Arch>::kPass && PacketCheck<TB, Arch>::kPass; | |
}; | |
//---------------------------------------------------- | |
// Check if data is aligned and allow packet operation | |
//---------------------------------------------------- | |
template<int dim, typename E, PacketArch Arch> | |
struct PacketAlignCheck { | |
inline static bool Check(const E &exp) { | |
return false; | |
} | |
}; | |
template<int dim, typename DType, PacketArch Arch> | |
struct PacketAlignCheck<dim, ScalarExp<DType>, Arch> { | |
inline static bool Check(const ScalarExp<DType> &exp) { | |
return true; | |
} | |
}; | |
template<int dim, typename DType, PacketArch Arch> | |
struct PacketAlignCheck<dim, Tensor<cpu, dim, DType>, Arch> { | |
inline static bool Check(const Tensor<cpu, dim, DType> &t) { | |
return packet::CheckAlign<Arch>(t.dptr_) && | |
packet::CheckAlign<Arch>(t.stride_ * sizeof(DType)); | |
} | |
}; | |
template<int dim, typename OP, typename TA, typename DType, int etype, PacketArch Arch> | |
struct PacketAlignCheck<dim, UnaryMapExp<OP, TA, DType, etype>, Arch> { | |
inline static bool Check(const UnaryMapExp<OP, TA, DType, etype> &t) { | |
return PacketAlignCheck<dim, TA, Arch>::Check(t.src_); | |
} | |
}; | |
template<int dim, typename OP, typename TA, typename TB, | |
typename DType, int etype, PacketArch Arch> | |
struct PacketAlignCheck<dim, BinaryMapExp<OP, TA, TB, DType, etype>, Arch> { | |
inline static bool Check(const BinaryMapExp<OP, TA, TB, DType, etype> &t) { | |
return PacketAlignCheck<dim, TA, Arch>::Check(t.lhs_) && | |
PacketAlignCheck<dim, TB, Arch>::Check(t.rhs_); | |
} | |
}; | |
/*! | |
* \brief use PacketPlan to compute result | |
*/ | |
template<typename SV, typename E, int dim, typename DType, PacketArch Arch> | |
inline void MapPacketPlan(Tensor<cpu, dim, DType> _dst, | |
const expr::PacketPlan<E, DType, Arch>& plan) { | |
Tensor<cpu, 2, DType> dst = _dst.FlatTo2D(); | |
const index_t xlen = packet::LowerAlign<DType, Arch>(dst.size(1)); | |
#if (MSHADOW_USE_CUDA == 0) | |
#pragma omp parallel for | |
#endif | |
for (openmp_index_t y = 0; y < dst.size(0); ++y) { | |
for (index_t x = 0; x < xlen; x += packet::Packet<DType, Arch>::kSize) { | |
packet::Saver<SV, DType, Arch>::Save(&dst[y][x], plan.EvalPacket(y, x)); | |
} | |
for (index_t x = xlen; x < dst.size(1); ++x) { | |
SV::Save(dst[y][x], plan.Eval(y, x)); | |
} | |
} | |
} | |
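// Note: MapPacketPlan vectorizes the bulk of every row via EvalPacket over | |
// the aligned prefix (xlen, a multiple of Packet::kSize) and handles the | |
// remaining tail columns with the scalar Eval path, so the result matches | |
// the plain Plan-based evaluation. | |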
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_PACKET_INL_H_ | |
//===== EXPANDED: ../mshadow/mshadow/packet-inl.h ===== | |
namespace mshadow { | |
namespace expr { | |
/*! | |
* \brief Matrix multiplication. | |
* \tparam LhsExp type of lhs expression | |
* \tparam RhsExp type of rhs expression | |
* \tparam DType the type of elements | |
*/ | |
template<typename LhsExp, typename RhsExp, typename DType> | |
struct ImplicitGEMMExp: | |
public Exp<ImplicitGEMMExp<LhsExp, RhsExp, DType>, | |
DType, type::kChainer> { | |
/*! \brief lhs operand */ | |
const LhsExp &lhs_; | |
/*! \brief rhs operand */ | |
const RhsExp &rhs_; | |
/*! \brief size of the inner (reduction) dimension of the product */ | |
index_t prod_size_; | |
/*! \brief the shape of this expression */ | |
Shape<2> shape_; | |
/*! \brief constructor */ | |
ImplicitGEMMExp(const LhsExp &lhs, const RhsExp &rhs) | |
: lhs_(lhs), rhs_(rhs) { | |
Shape<2> slhs = ShapeCheck<2, LhsExp>::Check(lhs_); | |
Shape<2> srhs = ShapeCheck<2, RhsExp>::Check(rhs_); | |
this->shape_ = mshadow::Shape2(slhs[0], srhs[1]); | |
prod_size_ = slhs[1]; | |
} | |
}; | |
template<typename LhsExp, typename RhsExp, typename DType, int e1, int e2> | |
inline ImplicitGEMMExp<LhsExp, RhsExp, DType> | |
implicit_dot(const Exp<LhsExp, DType, e1> &lhs, | |
const Exp<RhsExp, DType, e2> &rhs) { | |
TypeCheckPass<ExpInfo<LhsExp>::kDim == 2 && ExpInfo<RhsExp>::kDim == 2> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return ImplicitGEMMExp<LhsExp, RhsExp, DType>(lhs.self(), rhs.self()); | |
} | |
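// Usage sketch (hypothetical tensors, assumed allocated): implicit_dot | |
// evaluates the matrix product lazily inside the expression engine, without | |
// calling BLAS. | |
// | |
//   // Tensor<cpu, 2, float> A(Shape2(4, 8)), B(Shape2(8, 3)), C(Shape2(4, 3)); | |
//   // C = expr::implicit_dot(A, B);  // C[y][x] = sum_i A[y][i] * B[i][x] | |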
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename LhsExp, typename RhsExp, typename DType> | |
struct Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType> { | |
public: | |
explicit Plan(const ImplicitGEMMExp<LhsExp, RhsExp, DType> &e) | |
: lhs_(MakePlan(e.lhs_)), | |
rhs_(MakePlan(e.rhs_)), | |
prod_size_(e.prod_size_), | |
prod_size_lower_align_(packet::LowerAlign<DType, MSHADOW_DEFAULT_PACKET>(e.prod_size_)) { | |
} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
typedef packet::Packet<DType> Packet; | |
Packet sum = Packet::Fill(0); | |
DType lhs_temp[Packet::kSize], rhs_temp[Packet::kSize]; | |
for (index_t i = 0; i < prod_size_lower_align_; i += packet::Packet<DType>::kSize) { | |
// unroll | |
for (index_t j = 0; j < Packet::kSize; ++j) { | |
lhs_temp[j] = lhs_.Eval(y, i + j); | |
} | |
for (index_t j = 0; j < Packet::kSize; ++j) { | |
rhs_temp[j] = rhs_.Eval(i + j, x); | |
} | |
sum = sum + Packet::LoadUnAligned(lhs_temp) * Packet::LoadUnAligned(rhs_temp); | |
} | |
DType ret_result = sum.Sum(); | |
for (index_t i = prod_size_lower_align_; i < prod_size_; ++i) { | |
ret_result += lhs_.Eval(y, i) * rhs_.Eval(i, x); | |
} | |
return ret_result; | |
} | |
private: | |
expr::Plan<LhsExp, DType> lhs_; | |
expr::Plan<RhsExp, DType> rhs_; | |
const index_t prod_size_; | |
const index_t prod_size_lower_align_; | |
}; | |
template<typename LhsExp, typename RhsExp, typename DType> | |
inline Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType> | |
MakePlan(const ImplicitGEMMExp<LhsExp, RhsExp, DType> &exp) { | |
return Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType>(exp); | |
} | |
template<int dim, typename LhsExp, typename RhsExp, typename DType> | |
struct ShapeCheck<dim, ImplicitGEMMExp<LhsExp, RhsExp, DType> > { | |
inline static Shape<dim> | |
Check(const ImplicitGEMMExp<LhsExp, RhsExp, DType> &t) { | |
CHECK(dim == 2) | |
<< "ImplicitGEMMExp only supports 2 dimensions"; | |
Shape<dim> shape1 = ShapeCheck<dim, LhsExp>::Check(t.lhs_); | |
Shape<dim> shape2 = ShapeCheck<dim, RhsExp>::Check(t.rhs_); | |
CHECK_EQ(shape1[1], shape2[0]) | |
<< "implicit_dot: matrix shapes do not match"; | |
return t.shape_; | |
} | |
}; | |
template<typename LhsExp, typename RhsExp, typename DType> | |
struct ExpInfo<ImplicitGEMMExp<LhsExp, RhsExp, DType> > { | |
static const int kDim = 2; | |
static const int kDevMask = ExpInfo<LhsExp>::kDevMask & ExpInfo<RhsExp>::kDevMask; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_IMPLICIT_GEMM_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/implicit_gemm.h ===== | |
#ifdef __CUDACC__ | |
#endif // #ifdef __CUDACC__ | |
namespace mshadow { | |
/*! | |
* \brief CPU/GPU: Get a batched view of the src array. dst[i] = src + i * stride | |
* \param dst 2D pointer | |
* \param src 1D pointer | |
* \param num number of batches | |
* \param stride size of each batch | |
* \param stream the stream to run on (unused by the CPU implementation) | |
*/ | |
template<typename Device, typename DType> | |
inline void GetBatchedView(DType **dst, DType *src, int num, int stride, | |
Stream<Device> *stream); | |
template<typename DType> | |
inline void GetBatchedView(DType **dst, DType *src, int num, int stride, | |
Stream<cpu> *stream) { | |
for (int i = 0; i < num; i++) { | |
dst[i] = src + i * stride; | |
} | |
} | |
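// Usage sketch (hypothetical sizes): build per-batch pointers into one | |
// contiguous buffer of 5 batches with 12 elements each. | |
// | |
//   // float *data = ...;   // 5 * 12 contiguous floats | |
//   // float *views[5]; | |
//   // GetBatchedView(views, data, 5, 12, static_cast<Stream<cpu>*>(NULL)); | |
//   // now views[i] == data + i * 12 | |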
#ifdef __CUDACC__ | |
template<typename DType> | |
inline void GetBatchedView(DType **dst, DType *src, int num, int stride, | |
Stream<gpu> *stream) { | |
cuda::GetBatchedView(dst, src, num, stride, stream); | |
} | |
#endif // #ifdef __CUDACC__ | |
namespace expr { | |
//--------------------------------------------------------------------- | |
// Matrix Multiplications, depends on BLAS Engine | |
//--------------------------------------------------------------------- | |
template<typename SV, typename Device, int ddim, int ldim, | |
int rdim, bool ltrans, bool rtrans, typename DType> | |
struct DotEngine { | |
inline static void Eval(Tensor<Device, ddim, DType> *p_dst, | |
const Tensor<Device, ldim, DType> &lhs, | |
const Tensor<Device, rdim, DType> &rhs, | |
DType scale); | |
}; | |
// handles the dot product; uses the CblasColMajor convention | |
template<typename Device, typename DType = default_real_t> | |
struct BLASEngine { | |
inline static bool GetT(bool t) { | |
return t ? true : false; | |
} | |
inline static void SetStream(Stream<Device> *stream) { | |
} | |
inline static void gemm(Stream<Device> *stream, | |
bool transa, bool transb, | |
int m, int n, int k, DType alpha, | |
const DType *A, int lda, const DType *B, int ldb, | |
DType beta, DType *C, int ldc) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
inline static void batched_gemm(Stream<Device> *stream, | |
bool transa, bool transb, | |
int m, int n, int k, DType alpha, | |
const DType *A, int lda, const DType *B, int ldb, | |
DType beta, DType *C, int ldc, int batch_count, | |
DType **workspace) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
inline static void gemv(Stream<Device> *stream, | |
bool trans, int m, int n, | |
DType alpha, const DType *A, int lda, | |
const DType *X, int incX, | |
DType beta, DType *Y, int incY) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
inline static void batched_gemv(Stream<Device> *stream, | |
bool trans, int m, int n, | |
DType alpha, const DType *A, int lda, | |
const DType *X, int incX, | |
DType beta, DType *Y, int incY, int batch_count) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
inline static void ger(Stream<Device> *stream, | |
int m, int n, DType alpha, | |
const DType *X, int incX, | |
const DType *Y, int incY, DType *A, int lda) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
inline static void batched_ger(Stream<Device> *stream, | |
int m, int n, DType alpha, | |
const DType *X, int incX, | |
const DType *Y, int incY, DType *A, int lda, int batch_count) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
inline static void dot(Stream<Device> *stream, | |
int n, | |
const DType* X, int incX, | |
const DType* Y, int incY, | |
DType* ret) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
}; | |
#if MSHADOW_STAND_ALONE | |
template<> | |
struct BLASEngine<cpu, float> { | |
inline static bool GetT(bool t) { | |
return t ? true : false; | |
} | |
inline static void SetStream(Stream<cpu> *stream) { | |
} | |
inline static void gemm(Stream<cpu> *stream, | |
bool transa, bool transb, | |
int m, int n, int k, float alpha, | |
const float *A, int lda, const float *B, int ldb, | |
float beta, float *C, int ldc) { | |
if (alpha == 1.0f && beta == 0.0f) { | |
bool transpose_left = transb; | |
bool transpose_right = transa; | |
Tensor<cpu, 2, float> lhs((float*)B, Shape2(transpose_left ? k : n, transpose_left ? n : k)); // NOLINT(*) | |
Tensor<cpu, 2, float> rhs((float*)A, Shape2(transpose_right ? m : k, transpose_right ? k : m)); // NOLINT(*) | |
Tensor<cpu, 2, float> dst(C, Shape2(m, n)); | |
if (!transpose_left && !transpose_right) { | |
dst = expr::implicit_dot(lhs, rhs); return; | |
} else if (!transpose_left && transpose_right) { | |
dst = expr::implicit_dot(lhs, rhs.T()); return; | |
} else if (transpose_left && !transpose_right) { | |
dst = expr::implicit_dot(lhs.T(), rhs); return; | |
} else { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
} else { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
} | |
inline static void batched_gemm(Stream<cpu> *stream, | |
bool transa, bool transb, | |
int m, int n, int k, float alpha, | |
const float *A, int lda, const float *B, int ldb, | |
float beta, float *C, int ldc, int batch_count, | |
float **workspace) { | |
for (int i = 0; i < batch_count; ++i) { | |
gemm(stream, transa, transb, m, n, k, alpha, | |
A + i * m * k, lda, B + i * k * n, ldb, | |
beta, C + i * m * n, ldc); | |
} | |
} | |
inline static void gemv(Stream<cpu> *stream, | |
bool trans, int m, int n, | |
float alpha, const float *A, int lda, | |
const float *X, int incX, | |
float beta, float *Y, int incY) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
inline static void batched_gemv(Stream<cpu> *stream, | |
bool trans, int m, int n, | |
float alpha, const float *A, int lda, | |
const float *X, int incX, | |
float beta, float *Y, int incY, int batch_count) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
inline static void ger(Stream<cpu> *stream, | |
int m, int n, float alpha, | |
const float *X, int incX, | |
const float *Y, int incY, float *A, int lda) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
inline static void batched_ger(Stream<cpu> *stream, | |
int m, int n, float alpha, | |
const float *X, int incX, | |
const float *Y, int incY, float *A, int lda, int batch_count) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
inline static void dot(Stream<cpu> *stream, | |
int n, | |
const float* X, int incX, | |
const float* Y, int incY, | |
float* ret) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
}; | |
template<> | |
struct BLASEngine<cpu, double> { | |
inline static bool GetT(bool t) { | |
return t ? true : false; | |
} | |
inline static void SetStream(Stream<cpu> *stream) { | |
} | |
inline static void gemm(Stream<cpu> *stream, | |
bool transa, bool transb, | |
int m, int n, int k, double alpha, | |
const double *A, int lda, const double *B, int ldb, | |
double beta, double *C, int ldc) { | |
if (alpha == 1.0f && beta == 0.0f) { | |
bool transpose_left = transb; | |
bool transpose_right = transa; | |
Tensor<cpu, 2, double> lhs((double*)B, Shape2(transpose_left ? k : n, transpose_left ? n : k)); // NOLINT(*) | |
Tensor<cpu, 2, double> rhs((double*)A, Shape2(transpose_right ? m : k, transpose_right ? k : m)); // NOLINT(*) | |
Tensor<cpu, 2, double> dst(C, Shape2(m, n)); | |
if (!transpose_left && !transpose_right) { | |
dst = expr::implicit_dot(lhs, rhs); return; | |
} else if (!transpose_left && transpose_right) { | |
dst = expr::implicit_dot(lhs, rhs.T()); return; | |
} else if (transpose_left && !transpose_right) { | |
dst = expr::implicit_dot(lhs.T(), rhs); return; | |
} else { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
} else { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
} | |
inline static void batched_gemm(Stream<cpu> *stream, | |
bool transa, bool transb, | |
int m, int n, int k, double alpha, | |
const double *A, int lda, const double *B, int ldb, | |
double beta, double *C, int ldc, int batch_count, | |
double **workspace) { | |
for (int i = 0; i < batch_count; ++i) { | |
gemm(stream, transa, transb, m, n, k, alpha, | |
A + i * m * k, lda, B + i * k * n, ldb, | |
beta, C + i * m * n, ldc); | |
} | |
} | |
inline static void gemv(Stream<cpu> *stream, | |
bool trans, int m, int n, | |
double alpha, const double *A, int lda, | |
const double *X, int incX, | |
double beta, double *Y, int incY) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
inline static void batched_gemv(Stream<cpu> *stream, | |
bool trans, int m, int n, | |
double alpha, const double *A, int lda, | |
const double *X, int incX, | |
double beta, double *Y, int incY, int batch_count) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
inline static void ger(Stream<cpu> *stream, | |
int m, int n, double alpha, | |
const double *X, int incX, | |
const double *Y, int incY, double *A, int lda) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
inline static void batched_ger(Stream<cpu> *stream, | |
int m, int n, double alpha, | |
const double *X, int incX, | |
const double *Y, int incY, double *A, int lda, int batch_count) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
inline static void dot(Stream<cpu> *stream, | |
int n, | |
const double* X, int incX, | |
const double* Y, int incY, | |
double* ret) { | |
LOG(FATAL) << "Not implemented!"; | |
} | |
}; | |
#elif (MSHADOW_USE_MKL || MSHADOW_USE_CBLAS) // NOLINT(*) | |
template<> | |
struct BLASEngine<cpu, float> { | |
inline static CBLAS_TRANSPOSE GetT(bool t) { | |
return t ? CblasTrans : CblasNoTrans; | |
} | |
inline static void SetStream(Stream<cpu> *stream) { | |
} | |
inline static void gemm(Stream<cpu> *stream, | |
bool transa, bool transb, | |
int m, int n, int k, float alpha, | |
const float *A, int lda, const float *B, int ldb, | |
float beta, float *C, int ldc) { | |
cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb), | |
m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); | |
} | |
inline static void batched_gemm(Stream<cpu> *stream, | |
bool transa, bool transb, | |
int m, int n, int k, float alpha, | |
const float *A, int lda, const float *B, int ldb, | |
float beta, float *C, int ldc, int batch_count, | |
float **workspace) { | |
for (int i = 0; i < batch_count; ++i) { | |
gemm(stream, transa, transb, m, n, k, alpha, | |
A + i * m * k, lda, B + i * k * n, ldb, | |
beta, C + i * m * n, ldc); | |
} | |
} | |
inline static void gemv(Stream<cpu> *stream, | |
bool trans, int m, int n, | |
float alpha, const float *A, int lda, | |
const float *X, int incX, | |
float beta, float *Y, int incY) { | |
cblas_sgemv(CblasColMajor, GetT(trans), m, n, alpha, | |
A, lda, X, incX, beta, Y, incY); | |
} | |
inline static void batched_gemv(Stream<cpu> *stream, | |
bool trans, int m, int n, | |
float alpha, const float *A, int lda, | |
const float *X, int incX, | |
float beta, float *Y, int incY, int batch_count) { | |
for (int i = 0; i < batch_count; ++i) { | |
gemv(stream, trans, m, n, alpha, A + i * m * n, lda, | |
X + i * (trans ? m : n) * incX, incX, | |
beta, Y + i * (trans ? n : m) * incY, incY); | |
} | |
} | |
inline static void ger(Stream<cpu> *stream, | |
int m, int n, float alpha, | |
const float *X, int incX, | |
const float *Y, int incY, float *A, int lda) { | |
cblas_sger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda); | |
} | |
inline static void batched_ger(Stream<cpu> *stream, | |
int m, int n, float alpha, | |
const float *X, int incX, | |
const float *Y, int incY, float *A, int lda, int batch_count) { | |
for (int i = 0; i < batch_count; ++i) { | |
ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, | |
A + i * lda * n, lda); | |
} | |
} | |
inline static void dot(Stream<cpu> *stream, | |
int n, | |
const float* X, int incX, | |
const float* Y, int incY, | |
float* ret) { | |
*ret = cblas_sdot(n, X, incX, Y, incY); | |
} | |
}; | |
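// Usage sketch (illustration only): the engine follows the column-major BLAS | |
// convention, so a plain C = A * B call with m x k and k x n operands uses | |
// leading dimensions lda = m, ldb = k, ldc = m: | |
// | |
//   // BLASEngine<cpu, float>::gemm(NULL, false, false, m, n, k, 1.0f, | |
//   //                              A, m, B, k, 0.0f, C, m); | |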
template<> | |
struct BLASEngine<cpu, double> { | |
inline static CBLAS_TRANSPOSE GetT(bool t) { | |
return t ? CblasTrans : CblasNoTrans; | |
} | |
inline static void SetStream(Stream<cpu> *stream) { | |
} | |
inline static void gemm(Stream<cpu> *stream, | |
bool transa, bool transb, | |
int m, int n, int k, double alpha, | |
const double *A, int lda, const double *B, int ldb, | |
double beta, double *C, int ldc) { | |
cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb), | |
m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); | |
} | |
inline static void batched_gemm(Stream<cpu> *stream, | |
bool transa, bool transb, | |
int m, int n, int k, double alpha, | |
const double *A, int lda, const double *B, int ldb, | |
double beta, double *C, int ldc, int batch_count, | |
double **workspace) { | |
for (int i = 0; i < batch_count; ++i) { | |
gemm(stream, transa, transb, m, n, k, alpha, | |
A + i * m * k, lda, B + i * k * n, ldb, | |
beta, C + i * m * n, ldc); | |
} | |
} | |
inline static void gemv(Stream<cpu> *stream, | |
bool trans, int m, int n, double alpha, | |
const double *A, int lda, | |
const double *X, int incX, | |
double beta, double *Y, int incY) { | |
cblas_dgemv(CblasColMajor, GetT(trans), m, n, alpha, | |
A, lda, X, incX, beta, Y, incY); | |
} | |
inline static void batched_gemv(Stream<cpu> *stream, | |
bool trans, int m, int n, | |
double alpha, const double *A, int lda, | |
const double *X, int incX, | |
double beta, double *Y, int incY, int batch_count) { | |
for (int i = 0; i < batch_count; ++i) { | |
gemv(stream, trans, m, n, alpha, A + i * m * n, lda, | |
X + i * (trans ? m : n) * incX, incX, | |
beta, Y + i * (trans ? n : m) * incY, incY); | |
} | |
} | |
inline static void ger(Stream<cpu> *stream, | |
int m, int n, double alpha, | |
const double *X, int incX, | |
const double *Y, int incY, double *A, int lda) { | |
cblas_dger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda); | |
} | |
inline static void batched_ger(Stream<cpu> *stream, | |
int m, int n, double alpha, | |
const double *X, int incX, | |
const double *Y, int incY, double *A, int lda, int batch_count) { | |
for (int i = 0; i < batch_count; ++i) { | |
ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, | |
A + i * lda * n, lda); | |
} | |
} | |
inline static void dot(Stream<cpu> *stream, | |
int n, | |
const double* X, int incX, | |
const double* Y, int incY, | |
double* ret) { | |
*ret = cblas_ddot(n, X, incX, Y, incY); | |
} | |
}; | |
#endif // MSHADOW_USE_CBLAS || MSHADOW_USE_MKL || MSHADOW_STAND_ALONE | |
// CuBLAS redirect code | |
#if MSHADOW_USE_CUDA | |
// All cuBLAS calls go through here; uses the legacy API: not thread-safe | |
template<> | |
struct BLASEngine<gpu, half::half_t> { | |
inline static cublasOperation_t GetT(bool t) { | |
return t ? CUBLAS_OP_T : CUBLAS_OP_N; | |
} | |
inline static void SetStream(Stream<gpu> *stream) { | |
cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream), | |
Stream<gpu>::GetStream(stream)); | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas set stream fail"; | |
} | |
inline static void gemm(Stream<gpu> *stream, | |
bool transa, bool transb, | |
int m, int n, int k, half::half_t alpha, | |
const half::half_t *A, int lda, | |
const half::half_t *B, int ldb, half::half_t beta, | |
half::half_t *C, int ldc) { | |
#if defined(CUDA_VERSION) && CUDA_VERSION >= 7050 | |
#if MSHADOW_USE_PASCAL == 1 | |
cublasStatus_t err = cublasHgemm(Stream<gpu>::GetBlasHandle(stream), | |
GetT(transa), GetT(transb), m, n, k, &alpha.cuhalf_, | |
&A->cuhalf_, lda, &B->cuhalf_, ldb, &beta.cuhalf_, &C->cuhalf_, ldc); | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas Hgemm fail"; | |
#else | |
float alpha_f = float(alpha); // NOLINT(*) | |
float beta_f = float(beta); // NOLINT(*) | |
#if CUDA_VERSION >= 8000 | |
cublasStatus_t err = cublasSgemmEx(Stream<gpu>::GetBlasHandle(stream), | |
GetT(transa), GetT(transb), m, n, k, &alpha_f, | |
A, CUDA_R_16F, lda, B, CUDA_R_16F, | |
ldb, &beta_f, C, CUDA_R_16F, ldc); | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas SgemmEx fail"; | |
#else | |
cublasStatus_t err = cublasSgemmEx(Stream<gpu>::GetBlasHandle(stream), | |
GetT(transa), GetT(transb), m, n, k, &alpha_f, | |
A, CUBLAS_DATA_HALF, lda, B, CUBLAS_DATA_HALF, | |
ldb, &beta_f, C, CUBLAS_DATA_HALF, ldc); | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas SgemmEx fail"; | |
#endif // CUDA_VERSION >= 8000 | |
#endif // MSHADOW_USE_PASCAL == 1 | |
#else | |
LOG(FATAL) << "Require CUDA version >= 7.5!"; | |
#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 7050 | |
} | |
inline static void batched_gemm(Stream<gpu> *stream, | |
bool transa, bool transb, | |
int m, int n, int k, half::half_t alpha, | |
const half::half_t *A, int lda, const half::half_t *B, int ldb, | |
half::half_t beta, half::half_t *C, int ldc, int batch_count, | |
half::half_t **workspace) { | |
for (int i = 0; i < batch_count; ++i) { | |
gemm(stream, transa, transb, m, n, k, alpha, | |
A + i * m * k, lda, B + i * k * n, ldb, | |
beta, C + i * m * n, ldc); | |
} | |
} | |
inline static void gemv(Stream<gpu> *stream, | |
bool trans, int m, int n, half::half_t alpha, | |
const half::half_t *A, int lda, | |
const half::half_t *X, int incX, half::half_t beta, | |
half::half_t *Y, int incY) { | |
LOG(FATAL) << "Not implmented!"; | |
} | |
inline static void batched_gemv(Stream<gpu> *stream, | |
bool trans, int m, int n, | |
half::half_t alpha, const half::half_t *A, int lda, | |
const half::half_t *X, int incX, | |
half::half_t beta, half::half_t *Y, int incY, int batch_count) { | |
LOG(FATAL) << "Not implmented!"; | |
} | |
inline static void ger(Stream<gpu> *stream, | |
int m, int n, half::half_t alpha, | |
const half::half_t *X, int incX, | |
const half::half_t *Y, int incY, half::half_t *A, int lda) { | |
LOG(FATAL) << "Not implmented!"; | |
} | |
inline static void batched_ger(Stream<gpu> *stream, | |
int m, int n, half::half_t alpha, | |
const half::half_t *X, int incX, const half::half_t *Y, int incY, | |
half::half_t *A, int lda, int batch_count) { | |
LOG(FATAL) << "Not implmented!"; | |
} | |
inline static void dot(Stream<gpu> *stream, | |
int n, | |
const half::half_t* X, int incX, | |
const half::half_t* Y, int incY, | |
half::half_t *ret) { | |
LOG(FATAL) << "Not implmented!"; | |
} | |
}; | |
template<> | |
struct BLASEngine<gpu, float> { | |
inline static cublasOperation_t GetT(bool t) { | |
return t ? CUBLAS_OP_T : CUBLAS_OP_N; | |
} | |
inline static void SetStream(Stream<gpu> *stream) { | |
cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream), | |
Stream<gpu>::GetStream(stream)); | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail"; | |
} | |
inline static void gemm(Stream<gpu> *stream, | |
bool transa, bool transb, | |
int m, int n, int k, float alpha, | |
const float *A, int lda, | |
const float *B, int ldb, float beta, | |
float *C, int ldc) { | |
cublasStatus_t err = cublasSgemm(Stream<gpu>::GetBlasHandle(stream), | |
GetT(transa), GetT(transb), m, n, k, &alpha, | |
A, lda, B, ldb, &beta, C, ldc); | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemm fail"; | |
} | |
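  // Note on the batched_gemm below (explanatory comment): cublasSgemmBatched takes
  // arrays of device pointers, so the optional workspace must provide room for
  // 3 * batch_count float* entries, laid out as batch_count pointers into A,
  // then batch_count pointers into B, then batch_count pointers into C;
  // GetBatchedView fills each group from the contiguous A/B/C buffers. When the
  // caller passes workspace == NULL, a temporary device buffer is allocated with
  // cudaMalloc and freed again after the call.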
inline static void batched_gemm(Stream<gpu> *stream, | |
bool transa, bool transb, | |
int m, int n, int k, float alpha, | |
const float *A, int lda, const float *B, int ldb, | |
float beta, float *C, int ldc, int batch_count, | |
float **workspace) { | |
#if defined(__CUDACC__) && CUDA_VERSION >= 4010 | |
// Cast DType* to DType** using workspace as a buffer | |
bool alloc_workspace = false; | |
if (workspace == NULL) { | |
// Allocate the workspace if it's NULL. | |
// TODO(sxjscience) Try to move the allocation inside Tensor, which is thread-safe. | |
cudaMalloc(reinterpret_cast<void**>(&workspace), 3 * batch_count * sizeof(float*)); | |
alloc_workspace = true; | |
} | |
GetBatchedView(workspace, const_cast<float*>(A), batch_count, m * k, stream); | |
GetBatchedView(workspace + batch_count, | |
const_cast<float*>(B), batch_count, k * n, stream); | |
GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream); | |
cublasStatus_t err = cublasSgemmBatched(Stream<gpu>::GetBlasHandle(stream), | |
GetT(transa), GetT(transb), m, n, k, &alpha, | |
(const float**)workspace, lda, | |
(const float**)(workspace + batch_count), ldb, | |
&beta, workspace + 2 * batch_count, ldc, batch_count); | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmBatched fail"; | |
if (alloc_workspace) { | |
cudaFree(workspace); | |
} | |
#else | |
for (int i = 0; i < batch_count; ++i) { | |
gemm(stream, transa, transb, m, n, k, alpha, | |
A + i * m * k, lda, B + i * k * n, ldb, | |
beta, C + i * m * n, ldc); | |
} | |
#endif // defined(__CUDACC__) && CUDA_VERSION >= 4010 | |
} | |
inline static void gemv(Stream<gpu> *stream, | |
bool trans, int m, int n, float alpha, | |
const float *A, int lda, | |
const float *X, int incX, float beta, | |
float *Y, int incY) { | |
cublasStatus_t err = cublasSgemv(Stream<gpu>::GetBlasHandle(stream), | |
GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY); | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemv fail"; | |
} | |
inline static void batched_gemv(Stream<gpu> *stream, | |
bool trans, int m, int n, | |
float alpha, const float *A, int lda, | |
const float *X, int incX, | |
float beta, float *Y, int incY, int batch_count) { | |
for (int i = 0; i < batch_count; ++i) { | |
gemv(stream, trans, m, n, alpha, A + i * m * n, lda, | |
X + i * (trans ? m : n) * incX, incX, | |
beta, Y + i * (trans ? n : m) * incY, incY); | |
} | |
} | |
inline static void ger(Stream<gpu> *stream, | |
int m, int n, float alpha, | |
const float *X, int incX, | |
const float *Y, int incY, float *A, int lda) { | |
cublasStatus_t err = cublasSger(Stream<gpu>::GetBlasHandle(stream), | |
m, n, &alpha, X, incX, Y, incY, A, lda); | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sger fail"; | |
} | |
inline static void batched_ger(Stream<gpu> *stream, | |
int m, int n, float alpha, | |
const float *X, int incX, | |
const float *Y, int incY, float *A, int lda, int batch_count) { | |
for (int i = 0; i < batch_count; ++i) { | |
ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, | |
A + i * lda * n, lda); | |
} | |
} | |
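  // Note on the dot below (explanatory comment): the cuBLAS pointer mode is
  // temporarily switched to CUBLAS_POINTER_MODE_DEVICE, so `ret` must point to
  // device memory; the mode is restored to CUBLAS_POINTER_MODE_HOST afterwards.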
inline static void dot(Stream<gpu> *stream, | |
int n, | |
const float* X, int incX, | |
const float* Y, int incY, | |
float *ret) { | |
cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream), | |
CUBLAS_POINTER_MODE_DEVICE); | |
cublasStatus_t err = cublasSdot(Stream<gpu>::GetBlasHandle(stream), | |
n, X, incX, Y, incY, ret); | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail"; | |
cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream), | |
CUBLAS_POINTER_MODE_HOST); | |
} | |
}; | |
template<> | |
struct BLASEngine<gpu, double> { | |
inline static cublasOperation_t GetT(bool t) { | |
return t ? CUBLAS_OP_T : CUBLAS_OP_N; | |
} | |
inline static void SetStream(Stream<gpu> *stream) { | |
cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream), | |
Stream<gpu>::GetStream(stream)); | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail"; | |
} | |
inline static void gemm(Stream<gpu> *stream, | |
bool transa, bool transb, | |
int m, int n, int k, double alpha, | |
const double *A, int lda, | |
const double *B, int ldb, | |
double beta, double *C, int ldc) { | |
cublasStatus_t err = cublasDgemm(Stream<gpu>::GetBlasHandle(stream), | |
GetT(transa), GetT(transb), m, n, k, &alpha, | |
A, lda, B, ldb, &beta, C, ldc); | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemm fail"; | |
} | |
inline static void batched_gemm(Stream<gpu> *stream, | |
bool transa, bool transb, | |
int m, int n, int k, double alpha, | |
const double *A, int lda, const double *B, int ldb, | |
double beta, double *C, int ldc, int batch_count, | |
double **workspace) { | |
#if defined(__CUDACC__) && CUDA_VERSION >= 4010 | |
// Cast DType* to DType** using workspace as a buffer | |
bool alloc_workspace = false; | |
if (workspace == NULL) { | |
// Allocate the workspace if it's NULL. | |
// TODO(sxjscience) Try to move the allocation inside Tensor, which is thread-safe. | |
cudaMalloc(reinterpret_cast<void**>(&workspace), 3 * batch_count * sizeof(double*)); | |
alloc_workspace = true; | |
} | |
GetBatchedView(workspace, const_cast<double*>(A), batch_count, m * k, stream); | |
GetBatchedView(workspace + batch_count, | |
const_cast<double*>(B), batch_count, k * n, stream); | |
GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream); | |
cublasStatus_t err = cublasDgemmBatched(Stream<gpu>::GetBlasHandle(stream), | |
GetT(transa), GetT(transb), m, n, k, &alpha, | |
(const double**)workspace, lda, | |
(const double**)(workspace + batch_count), ldb, | |
&beta, workspace + 2 * batch_count, ldc, batch_count); | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmBatched fail"; | |
if (alloc_workspace) { | |
cudaFree(workspace); | |
} | |
#else | |
for (int i = 0; i < batch_count; ++i) { | |
gemm(stream, transa, transb, m, n, k, alpha, | |
A + i * m * k, lda, B + i * k * n, ldb, | |
beta, C + i * m * n, ldc); | |
} | |
#endif // defined(__CUDACC__) && CUDA_VERSION >= 4010 | |
} | |
inline static void gemv(Stream<gpu> *stream, | |
bool trans, int m, int n, double alpha, | |
const double *A, int lda, | |
const double *X, int incX, | |
double beta, double *Y, int incY) { | |
cublasStatus_t err = cublasDgemv(Stream<gpu>::GetBlasHandle(stream), | |
GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY); | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemv fail"; | |
} | |
inline static void batched_gemv(Stream<gpu> *stream, | |
bool trans, int m, int n, | |
double alpha, const double *A, int lda, | |
const double *X, int incX, | |
double beta, double *Y, int incY, int batch_count) { | |
for (int i = 0; i < batch_count; ++i) { | |
gemv(stream, trans, m, n, alpha, A + i * m * n, lda, | |
X + i * (trans ? m : n) * incX, incX, | |
beta, Y + i * (trans ? n : m) * incY, incY); | |
} | |
} | |
inline static void ger(Stream<gpu> *stream, | |
int m, int n, double alpha, | |
const double *X, int incX, | |
const double *Y, int incY, double *A, int lda) { | |
cublasStatus_t err = cublasDger(Stream<gpu>::GetBlasHandle(stream), | |
m, n, &alpha, X, incX, Y, incY, A, lda); | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dger fail"; | |
} | |
inline static void batched_ger(Stream<gpu> *stream, | |
int m, int n, double alpha, | |
const double *X, int incX, | |
const double *Y, int incY, double *A, int lda, int batch_count) { | |
for (int i = 0; i < batch_count; ++i) { | |
ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, | |
A + i * lda * n, lda); | |
} | |
} | |
inline static void dot(Stream<gpu> *stream, | |
int n, | |
const double* X, int incX, | |
const double* Y, int incY, | |
double *ret) { | |
cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream), | |
CUBLAS_POINTER_MODE_DEVICE); | |
cublasStatus_t err = cublasDdot(Stream<gpu>::GetBlasHandle(stream), | |
n, X, incX, Y, incY, ret); | |
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail"; | |
cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream), | |
CUBLAS_POINTER_MODE_HOST); | |
} | |
}; | |
#endif // MSHADOW_USE_CUDA | |
// helper function to decide which shape we are in | |
inline Shape<2> GetShape(const Shape<2> &shape, bool transpose) { | |
return transpose ? Shape2(shape[1], shape[0]) : shape; | |
} | |
// dst = dot(lhs[.T], rhs[.T]) | |
template<typename SV, typename xpu, | |
bool transpose_left, bool transpose_right, typename DType> | |
struct DotEngine<SV, xpu, 2, 2, 2, transpose_left, transpose_right, DType> { | |
inline static void Eval(Tensor<xpu, 2, DType> *p_dst, | |
const Tensor<xpu, 2, DType> &lhs, | |
const Tensor<xpu, 2, DType> &rhs, | |
DType scale) { | |
Tensor<xpu, 2, DType> &dst = *p_dst; | |
#if MSHADOW_STAND_ALONE | |
if (xpu::kDevMask == cpu::kDevMask && scale == 1.0f) { | |
if (!transpose_left && !transpose_right) { | |
dst = expr::implicit_dot(lhs, rhs); return; | |
} else if (!transpose_left && transpose_right) { | |
dst = expr::implicit_dot(lhs, rhs.T()); return; | |
} else if (transpose_left && !transpose_right) { | |
dst = expr::implicit_dot(lhs.T(), rhs); return; | |
} | |
} | |
#endif | |
// set kernel stream | |
    // if there is no stream, crash
BLASEngine<xpu, DType>::SetStream(dst.stream_); | |
Shape<2> sleft = GetShape(lhs.shape_, transpose_left); | |
Shape<2> sright = GetShape(rhs.shape_, transpose_right); | |
CHECK(dst.size(0) == sleft[0] && dst.size(1) == sright[1] && sleft[1] == sright[0]) | |
<< "dot-gemm: matrix shape mismatch"; | |
    // use column-major arguments to be compatible with most BLAS
BLASEngine<xpu, DType>::gemm | |
(dst.stream_, | |
transpose_right , transpose_left, | |
transpose_right ? rhs.size(0) : rhs.size(1), | |
transpose_left ? lhs.size(1) : lhs.size(0), | |
transpose_right ? rhs.size(1) : rhs.size(0), | |
DType(scale * SV::AlphaBLAS()), | |
rhs.dptr_, rhs.stride_, | |
lhs.dptr_, lhs.stride_, | |
DType(SV::BetaBLAS()), | |
dst.dptr_, dst.stride_); | |
} | |
}; | |
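// Worked example of the column-major convention above (explanatory comment):
// mshadow tensors are row-major while BLAS gemm assumes column-major storage,
// and a row-major matrix reinterpreted as column-major is its transpose.
// Hence dst = lhs * rhs is evaluated as dst^T = rhs^T * lhs^T: for
// dst (m x n) = lhs (m x k) * rhs (k x n) with no transposes, the call above
// passes rhs first and lhs second with BLAS dimensions (n, m, k), and uses the
// tensors' row strides as the leading dimensions.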
template<typename SV, typename xpu, bool transpose_right, typename DType> | |
struct DotEngine<SV, xpu, 1, 1, 2, false, transpose_right, DType> { | |
inline static void Eval(Tensor<xpu, 1, DType> *p_dst, | |
const Tensor<xpu, 1, DType> &lhs, | |
const Tensor<xpu, 2, DType> &rhs, | |
DType scale) { | |
Tensor<xpu, 1, DType> &dst = *p_dst; | |
// set kernel stream | |
    // if there is no stream, crash
BLASEngine<xpu, DType>::SetStream(dst.stream_); | |
Shape<2> sright = GetShape(rhs.shape_, transpose_right); | |
CHECK(dst.size(0) == sright[1] && lhs.size(0) == sright[0]) | |
<< "dot-gemv: matrix shape mismatch" | |
<< "dst: " << dst.shape_ << "\n" | |
<< "lhs: " << lhs.shape_ << "\n" | |
<< "rhs: " << sright << "\n"; | |
BLASEngine<xpu, DType>::gemv | |
(dst.stream_, | |
transpose_right, | |
rhs.size(1), rhs.size(0), scale * SV::AlphaBLAS(), | |
rhs.dptr_, rhs.stride_, | |
lhs.dptr_, 1, SV::BetaBLAS(), | |
dst.dptr_, 1); | |
} | |
}; | |
template<typename SV, typename xpu, typename DType> | |
struct DotEngine<SV, xpu, 2, 1, 1, true, false, DType> { | |
inline static void Eval(Tensor<xpu, 2, DType> *p_dst, | |
const Tensor<xpu, 1, DType> &lhs, | |
const Tensor<xpu, 1, DType> &rhs, | |
DType scale) { | |
Tensor<xpu, 2, DType> &dst = *p_dst; | |
// set kernel stream | |
    // if there is no stream, crash
BLASEngine<xpu, DType>::SetStream(dst.stream_); | |
CHECK(dst.size(0) == lhs.size(0) && dst.size(1) == rhs.size(0)) | |
<< "dot-ger: matrix shape mismatch" | |
<< "dst: " << dst.shape_ << "\n" | |
<< "lhs: " << lhs.shape_ << "\n" | |
<< "rhs: " << rhs.shape_; | |
if (SV::BetaBLAS() == 0.0f) { | |
BLASEngine<xpu, DType>::ger | |
(dst.stream_, rhs.size(0), lhs.size(0), scale * SV::AlphaBLAS(), | |
rhs.dptr_, 1, lhs.dptr_, 1, dst.dptr_, dst.stride_); | |
} else { | |
DotEngine<SV, xpu, 2, 2, 2, true, false, | |
DType>::Eval(p_dst, lhs.FlatTo2D(), rhs.FlatTo2D(), scale); | |
} | |
} | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_DOT_ENGINE_INL_H_ | |
//===== EXPANDED: ../mshadow/mshadow/dot_engine-inl.h ===== | |
namespace mshadow { | |
namespace expr { | |
/*! \brief an engine that evaluates complex expressions */
template<typename SV, typename RV, typename E, typename DType> | |
struct ExpComplexEngine { | |
inline static void Eval(RV *dst, const E &exp); | |
}; | |
/*! \brief the engine that dispatches simple operations*/ | |
template<typename SV, typename RV, typename DType> | |
struct ExpEngine { | |
template<typename E> | |
inline static void Eval(RV *dst, | |
const Exp<E, DType, type::kMapper> &exp) { | |
MapExp<SV>(dst, exp); | |
} | |
template<typename E> | |
inline static void Eval(RV *dst, | |
const Exp<E, DType, type::kChainer> &exp) { | |
MapExp<SV>(dst, exp); | |
} | |
template<typename E> | |
inline static void Eval(RV *dst, | |
const Exp<E, DType, type::kRValue> &exp) { | |
MapExp<SV>(dst, exp); | |
} | |
template<typename E> | |
inline static void Eval(RV *dst, | |
const Exp<E, DType, type::kComplex> &exp) { | |
ExpComplexEngine<SV, RV, E, DType>::Eval(dst->ptrself(), exp.self()); | |
} | |
}; | |
template<typename SV, typename Device, int dim, int ldim, | |
int rdim, bool ltrans, bool rtrans, typename DType> | |
struct ExpComplexEngine<SV, | |
Tensor<Device, dim, DType>, | |
DotExp<Tensor<Device, ldim, DType>, | |
Tensor<Device, rdim, DType>, | |
ltrans, rtrans, DType>, | |
DType> { | |
inline static void Eval(Tensor<Device, dim, DType> *dst, | |
const DotExp<Tensor<Device, ldim, DType>, | |
Tensor<Device, rdim, DType>, | |
ltrans, rtrans, DType> &exp) { | |
DotEngine<SV, Device, dim, ldim, rdim, | |
ltrans, rtrans, DType>::Eval(dst, exp.lhs_, exp.rhs_, exp.scale_); | |
} | |
}; | |
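// Dispatch note (explanatory comment): dst = dot(lhs, rhs) builds a DotExp
// tagged type::kComplex; ExpEngine::Eval routes kComplex expressions to
// ExpComplexEngine, and the specialization above forwards the DotExp case to
// DotEngine::Eval, which calls the BLASEngine gemm/gemv/ger wrappers.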
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXPR_ENGINE_INL_H_ | |
//===== EXPANDED: ../mshadow/mshadow/expr_engine-inl.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/broadcast.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file broadcast.h | |
* \brief support for broadcast and repmat | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_EXTENSION_BROADCAST_H_ | |
#define MSHADOW_EXTENSION_BROADCAST_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
* \brief broadcast Tensor1D into a higher dimension Tensor | |
* input: Tensor<Device,1>: ishape[0] | |
* output: Tensor<Device,dimdst> : oshape[dimcast] = ishape[0] | |
* \tparam SrcExp type of input expression | |
* \tparam DType the type of elements | |
* \tparam dimdst target tensor dimension | |
 * \tparam dimdst_m_cast dimdst - dimcast
*/ | |
template<typename SrcExp, typename DType, int dimdst, int dimdst_m_cast> | |
struct Broadcast1DExp: | |
public MakeTensorExp<Broadcast1DExp<SrcExp, DType, dimdst, dimdst_m_cast>, | |
SrcExp, dimdst, DType> { | |
/*! \brief source operand */ | |
const SrcExp &src_; | |
/*! \brief constructor */ | |
Broadcast1DExp(const SrcExp &src, Shape<dimdst> shape) | |
: src_(src) { | |
this->shape_ = shape; | |
} | |
}; | |
/*! | |
* \brief broadcast scalar into a higher dimension Tensor | |
* input: Tensor<Device,1>: ishape = {1} | |
 * output: Tensor<Device, dimdst> : every element equals the scalar src[0]
* \tparam SrcExp type of input expression | |
* \tparam DType the type of elements | |
* \tparam dimdst target tensor dimension | |
*/ | |
template<typename SrcExp, typename DType, int dimdst> | |
struct BroadcastScalarExp: | |
public MakeTensorExp<BroadcastScalarExp<SrcExp, DType, dimdst>, | |
SrcExp, dimdst, DType> { | |
/*! \brief source operand */ | |
const SrcExp &src_; | |
/*! \brief constructor */ | |
BroadcastScalarExp(const SrcExp &src, Shape<dimdst> shape) | |
: src_(src) { | |
this->shape_ = shape; | |
} | |
}; | |
/*! | |
 * \brief an expression that replicates a 1-dimensional tensor along dimension dimcast
* \param src Tensor<Device,1>: shape[0] | |
* \param shape shape of output | |
 * \return an expression with type Tensor<Device,dimdst>
* \tparam dimcast target dimension where the 1D tensor will be broadcasted | |
* \tparam SrcExp type of input expression | |
* \tparam DType the type of elements | |
* \tparam dimdst dimension of destination tensor | |
* \tparam dimcast_lowest the dimension we want to cast the data into | |
*/ | |
template<int dimcast, typename SrcExp, typename DType, | |
int etype, int dimdst> | |
inline Broadcast1DExp<SrcExp, DType, dimdst, dimdst - dimcast> | |
broadcast(const expr::Exp<SrcExp, DType, etype> &src, Shape<dimdst> shape) { | |
TypeCheckPass<dimcast < dimdst && ExpInfo<SrcExp>::kDim == 1> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
typedef ShapeCheck<1, SrcExp> ShapeCheckDim1SrcExp; | |
CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.self())[0], shape[dimcast]) | |
<< "broadcast, shape mismatch"; | |
return Broadcast1DExp<SrcExp, DType, dimdst, | |
dimdst - dimcast>(src.self(), shape); | |
} | |
/*! | |
 * \brief an expression that replicates a scalar tensor to the target shape.
* \param src Tensor<Device,1>: shape[0] == 1 | |
* \param shape shape of output | |
 * \return an expression with type Tensor<Device, dimdst>
* \tparam SrcExp type of input expression | |
* \tparam DType the type of elements | |
* \tparam dimdst dimension of destination tensor | |
*/ | |
template<typename SrcExp, typename DType, int etype, int dimdst> | |
inline BroadcastScalarExp<SrcExp, DType, dimdst> | |
broadcast_scalar(const expr::Exp<SrcExp, DType, etype> &src, Shape<dimdst> shape) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim == 1> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
typedef ShapeCheck<1, SrcExp> ShapeCheckDim1SrcExp; | |
CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.self())[0], 1) | |
<< "broadcast_scalar, source need to be scalar expression"; | |
return BroadcastScalarExp<SrcExp, DType, dimdst>(src.self(), shape); | |
} | |
// short cut functions | |
/*! | |
 * \brief an expression that replicates a 1-dimensional tensor nrow times
* \param src Tensor<Device,1>: shape[0] | |
* \param nrow number of rows to replicate | |
 * \return an expression with type Tensor<Device,2>, with size(0) = nrow and size(1) = src.size(0)
* \tparam Device which device it lies | |
*/ | |
template<typename SrcExp, typename DType, int etype> | |
inline Broadcast1DExp<SrcExp, DType, 2, 1> | |
repmat(const expr::Exp<SrcExp, DType, etype> &src, index_t nrow) { | |
return broadcast<1> | |
(src, Shape2(nrow, ShapeCheck<1, SrcExp>::Check(src.self())[0])); | |
} | |
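/*
 * Illustrative usage sketch (not part of the original source; assumes the cpu
 * tensors below are already allocated):
 *
 *   Tensor<cpu, 1, float> bias(Shape1(k));     // length-k vector
 *   Tensor<cpu, 2, float> out(Shape2(n, k));   // n rows, k columns
 *   out = repmat(bias, n);                     // out[i][j] = bias[j]
 *   out = broadcast<1>(bias, Shape2(n, k));    // same result, explicit form
 *   Tensor<cpu, 1, float> s(Shape1(1));        // scalar held in a 1-element tensor
 *   out = broadcast_scalar(s, out.shape_);     // every element equals s[0]
 */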
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename SrcExp, typename DType, int dimdst, int dimdst_m_cast> | |
struct Plan<Broadcast1DExp<SrcExp, DType, dimdst, dimdst_m_cast>, DType> { | |
public: | |
static const int dimcast = dimdst - dimdst_m_cast; | |
explicit Plan(const Broadcast1DExp<SrcExp, DType, dimdst, dimdst_m_cast> &e) | |
: src_(MakePlan(e.src_)), | |
ystride_(e.shape_.ProdShape(dimcast + 1, dimdst - 1)), | |
length_(e.shape_[dimcast]) { | |
TypeCheckPass<dimcast != dimdst - 1> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
return src_.Eval(0, (y / ystride_) % length_); | |
} | |
private: | |
expr::Plan<SrcExp, DType> src_; | |
const index_t ystride_, length_; | |
}; | |
/*! \brief execution plan of Broadcast1DExp */ | |
template<typename SrcExp, typename DType, int dimdst> | |
struct Plan<Broadcast1DExp<SrcExp, DType, dimdst, 1>, DType>{ | |
public: | |
explicit Plan(const Broadcast1DExp<SrcExp, DType, dimdst, 1> &e) | |
: src_(MakePlan(e.src_)) {} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
return src_.Eval(0, x); | |
} | |
private: | |
expr::Plan<SrcExp, DType> src_; | |
}; | |
/*! \brief execution plan of Broadcast1DExp */ | |
template<typename SrcExp, typename DType, int dimdst> | |
struct Plan<BroadcastScalarExp<SrcExp, DType, dimdst>, DType>{ | |
public: | |
explicit Plan(const BroadcastScalarExp<SrcExp, DType, dimdst> &e) | |
: src_(MakePlan(e.src_)) {} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
return src_.Eval(0, 0); | |
} | |
private: | |
expr::Plan<SrcExp, DType> src_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_BROADCAST_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/broadcast.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/unpack_patch2col.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file unpack_patch2col.h | |
* \brief support for unpack | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_ | |
#define MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
 * \brief unpack local (overlapping) patches of an image into columns of a matrix,
 *  which can be used to implement convolution; this expression allows unpacking a batch,
 *  i.e. it supports unpacking multiple images at once.
 *  After getting the unpacked mat, we can use: output = dot(weight, mat) to get the convolved results.
* \tparam SrcExp source expression | |
* \tparam dstdim destination dimension | |
*/ | |
template<typename SrcExp, typename DType, int srcdim> | |
struct UnpackPatchToColXExp: | |
public MakeTensorExp<UnpackPatchToColXExp<SrcExp, DType, srcdim>, | |
SrcExp, 2, DType>{ | |
/*! \brief source operand */ | |
const SrcExp &img_; | |
/*! \brief patch height */ | |
index_t psize_y_; | |
/*! \brief patch width */ | |
index_t psize_x_; | |
/*! \brief patch stride */ | |
index_t pstride_y_; | |
index_t pstride_x_; | |
/*! \brief patch dilate */ | |
index_t pdilate_y_; | |
index_t pdilate_x_; | |
/*! \brief number of input channel */ | |
index_t i_channel_; | |
/*! \brief height of img */ | |
index_t i_height_; | |
/*! \brief width of img */ | |
index_t i_width_; | |
/*! \brief constructor */ | |
UnpackPatchToColXExp(const SrcExp &img, | |
index_t psize_y, | |
index_t psize_x, | |
index_t pstride_y, | |
index_t pstride_x, | |
index_t pdilate_y, | |
index_t pdilate_x) | |
: img_(img), psize_y_(psize_y), psize_x_(psize_x), | |
pstride_y_(pstride_y), pstride_x_(pstride_x), | |
pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){ | |
Shape<srcdim> imshape = ShapeCheck<srcdim, SrcExp>::Check(img_); | |
CHECK(imshape[srcdim - 1] >= psize_x && imshape[srcdim - 2] >= psize_y) | |
<< "UnpackPatchToCol:image shape smaller than patch size"; | |
this->i_channel_ = imshape[srcdim - 3]; | |
this->i_height_ = imshape[srcdim - 2]; | |
this->i_width_ = imshape[srcdim - 1]; | |
// calculate number of batches | |
const index_t num = imshape.ProdShape(0, srcdim - 3); | |
const index_t o_height = (i_height_ - | |
(pdilate_y * (psize_y - 1) + 1)) / pstride_y + 1; | |
const index_t o_width = (i_width_ - | |
(pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1; | |
this->shape_[1] = o_height * o_width * num; | |
this->shape_[0] = psize_y * psize_x * i_channel_; | |
} | |
}; | |
/*! | |
 * \brief unpack local (overlapping) patches of an image into columns of a matrix, can be used to implement convolution
 *  after getting the unpacked mat, we can use: output = dot(weight, mat) to get the convolved results, the relations:
* | |
* weight; shape[0]: out_channel, shape[1]: ichannel * psize_y * psize_x | |
* output; shape[0]: out_channel, shape[1]: out_height * out_width * num_of_images | |
 *  out_height = (in_height - psize_y) / pstride + 1, this means imperfect patches are padded with 0
* out_width = (in_width - psize_x) / pstride + 1 | |
* | |
* \return mat target matrix; shape[0]: in_channel*psize_y*psize_x shape[1]: out_height*out_width * num_of_images | |
* \param img source image; shape[-3]: in_channels, shape[-2]: in_height, shape[-1]: in_width, can be 3D or 4D tensor(multiple images) | |
* \param psize_y height of each patch | |
* \param psize_x width of each patch | |
* \param pstride stride of each patch | |
* \param pdilate dilate of each patch | |
* \tparam SrcExp source expression | |
* \tparam DType the type of elements | |
* \tparam etype type of expression | |
*/ | |
template<typename SrcExp, typename DType, int etype> | |
inline UnpackPatchToColXExp<SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
unpack_patch2col(const Exp<SrcExp, DType, etype> &img, | |
index_t psize_y, index_t psize_x, index_t pstride, index_t pdilate) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 3> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return UnpackPatchToColXExp<SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
(img.self(), psize_y, psize_x, pstride, pstride, pdilate, pdilate); | |
} | |
/*! | |
 * use this overload to specify the stride and dilate separately in the y and x directions
*/ | |
template<typename SrcExp, typename DType, int etype> | |
inline UnpackPatchToColXExp<SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
unpack_patch2col(const Exp<SrcExp, DType, etype> &img, | |
index_t psize_y, index_t psize_x, index_t pstride_y_, index_t pstride_x_, | |
index_t pdilate_y_, index_t pdilate_x_) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 3> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return UnpackPatchToColXExp<SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
(img.self(), psize_y, psize_x, pstride_y_, pstride_x_, pdilate_y_, pdilate_x_); | |
} | |
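/*
 * Illustrative usage sketch (not part of the original source): a convolution
 * forward pass via unpack + gemm, assuming already-allocated cpu tensors, a
 * kernel of size (ky, kx), stride s and dilation d:
 *
 *   // data:   (batch, in_channel, in_height, in_width)
 *   // weight: (out_channel, in_channel * ky * kx)
 *   // col:    (in_channel * ky * kx, out_height * out_width * batch)
 *   col = unpack_patch2col(data, ky, kx, s, d);
 *   // out2d:  (out_channel, out_height * out_width * batch)
 *   out2d = dot(weight, col);
 */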
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename SrcExp, typename DType, int srcdim> | |
struct Plan<UnpackPatchToColXExp<SrcExp, DType, srcdim>, DType> { | |
public: | |
explicit Plan(const UnpackPatchToColXExp<SrcExp, DType, srcdim> &e) | |
:src_(MakePlan(e.img_)), | |
psize_y_(e.psize_y_), psize_x_(e.psize_x_), | |
pstride_y_(e.pstride_y_), pstride_x_(e.pstride_x_), | |
i_channel_(e.i_channel_), pdilate_y_(e.pdilate_y_), pdilate_x_(e.pdilate_x_), | |
i_height_(e.i_height_), i_width_(e.i_width_), | |
o_height_((i_height_ - (pdilate_y_ * (psize_y_ - 1) + 1)) / pstride_y_ + 1), | |
o_width_((i_width_ - (pdilate_x_ * (psize_x_ - 1) + 1)) / pstride_x_ + 1) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
const index_t x_offset = i % psize_x_ * pdilate_x_; | |
const index_t idivp = i / psize_x_; | |
const index_t y_offset = idivp % psize_y_ * pdilate_y_; | |
const index_t c = idivp / psize_y_; | |
const index_t x = (j % o_width_) * pstride_x_ + x_offset; | |
const index_t jdivw = j / o_width_; | |
const index_t y = (jdivw % o_height_) * pstride_y_ + y_offset; | |
const index_t n = jdivw / o_height_; | |
if (x < i_width_ && y < i_height_) { | |
return src_.Eval((n * i_channel_ + c) * i_height_ + y, x); | |
} else { | |
return DType(0.0f); | |
} | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t psize_y_, psize_x_, pstride_y_, pstride_x_, i_channel_; | |
const index_t pdilate_y_, pdilate_x_; | |
const index_t i_height_, i_width_, o_height_, o_width_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/unpack_patch2col.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/pack_col2patch.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file pack_col2patch.h | |
* \brief support for pack | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_EXTENSION_PACK_COL2PATCH_H_ | |
#define MSHADOW_EXTENSION_PACK_COL2PATCH_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
* \brief reverse operation of UnpackPatchToCol, | |
 *   used to backpropagate the gradient;
* this is a version supporting multiple images | |
* \tparam SrcExp source expression | |
* \tparam DType the type of elements | |
* \tparam dstdim destination dimension | |
*/ | |
template<typename SrcExp, typename DType, int dstdim> | |
struct PackColToPatchXExp: | |
public MakeTensorExp<PackColToPatchXExp<SrcExp, DType, dstdim>, | |
SrcExp, dstdim, DType> { | |
/*! \brief source operand */ | |
const SrcExp &src_; | |
/*! \brief patch height */ | |
index_t psize_y_; | |
  /*! \brief patch width */
index_t psize_x_; | |
/*! \brief patch stride */ | |
index_t pstride_y_; | |
index_t pstride_x_; | |
/*! \brief patch dilate */ | |
index_t pdilate_y_; | |
index_t pdilate_x_; | |
/*! \brief constructor */ | |
PackColToPatchXExp(const SrcExp &src, Shape<dstdim> imshape, | |
index_t psize_y, index_t psize_x, | |
index_t pstride_y, index_t pstride_x, | |
index_t pdilate_y, index_t pdilate_x) | |
:src_(src), psize_y_(psize_y), psize_x_(psize_x), | |
pstride_y_(pstride_y), pstride_x_(pstride_x), | |
pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){ | |
this->shape_ = imshape; | |
const index_t o_height = (imshape[dstdim - 2] - | |
(pdilate_y * (psize_y - 1)+ 1))/pstride_y + 1; | |
const index_t o_width = (imshape[dstdim - 1] - | |
(pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1; | |
Shape<2> sshape = ShapeCheck<2, SrcExp>::Check(src_); | |
CHECK_EQ(sshape[1], o_height * o_width * imshape.ProdShape(0, dstdim - 3)) | |
<< "PackColToPatchExp: src.size(1) mismatch"; | |
CHECK_EQ(sshape[0], psize_y * psize_x * imshape[dstdim - 3]) | |
<< "PackColToPatchExp: src.size(0) mismatch"; | |
} | |
}; | |
/*! | |
 * \brief reverse operation of unpack_patch2col, can be used to implement deconvolution
* \return packed img expression | |
* \param mat source matrix | |
* \param imshape shape of target img | |
* \param psize_y height of each patch | |
 * \param psize_x width of each patch
* \param pstride stride of each patch | |
* \tparam SrcExp source expression | |
* \tparam DType the type of elements | |
* \tparam dstdim destination dimension | |
* \tparam etype type of expression | |
*/ | |
template<typename SrcExp, typename DType, int dstdim, int etype> | |
inline PackColToPatchXExp<SrcExp, DType, dstdim> | |
pack_col2patch(const expr::Exp<SrcExp, DType, etype> &src, | |
Shape<dstdim> imshape, index_t psize_y, | |
index_t psize_x, index_t pstride, index_t pdilate) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim == 2> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y) | |
<< "PackColToPatch:image shape smaller than patch size"; | |
return PackColToPatchXExp<SrcExp, DType, dstdim>(src.self(), imshape, | |
psize_y, psize_x, pstride, pstride, | |
pdilate, pdilate); | |
} | |
/*! | |
 * use this overload to specify pstride_y/pstride_x and pdilate_y/pdilate_x separately
*/ | |
template<typename SrcExp, typename DType, int dstdim, int etype> | |
inline PackColToPatchXExp<SrcExp, DType, dstdim> | |
pack_col2patch(const expr::Exp<SrcExp, DType, etype> &src, | |
Shape<dstdim> imshape, index_t psize_y, | |
index_t psize_x, index_t pstride_y, index_t pstride_x, | |
index_t pdilate_y, index_t pdilate_x) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim == 2> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y) | |
<< "PackColToPatch:image shape smaller than patch size"; | |
return PackColToPatchXExp<SrcExp, DType, dstdim>(src.self(), imshape, | |
psize_y, psize_x, pstride_y, pstride_x, | |
pdilate_y, pdilate_x); | |
} | |
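/*
 * Illustrative usage sketch (not part of the original source): the backward /
 * deconvolution direction of the unpack scheme above, assuming col_grad holds
 * the gradient with respect to the unpacked columns and data_grad is an
 * already-allocated 4D cpu tensor of the original image shape:
 *
 *   // col_grad:  (in_channel * ky * kx, out_height * out_width * batch)
 *   // data_grad: (batch, in_channel, in_height, in_width)
 *   data_grad = pack_col2patch(col_grad, data_grad.shape_, ky, kx, s, d);
 */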
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename SrcExp, typename DType, int dstdim> | |
struct Plan<PackColToPatchXExp<SrcExp, DType, dstdim>, DType> { | |
public: | |
explicit Plan(const PackColToPatchXExp<SrcExp, DType, dstdim> &e) | |
:src_(MakePlan(e.src_)), psize_y_(e.psize_y_), | |
psize_x_(e.psize_x_), pstride_y_(e.pstride_y_), pstride_x_(e.pstride_x_), | |
i_channel_(e.shape_[dstdim - 3]), pdilate_y_(e.pdilate_y_), pdilate_x_(e.pdilate_x_), | |
i_height_(e.shape_[dstdim - 2]), | |
o_height_((e.shape_[dstdim - 2] - (pdilate_y_ * (psize_y_ - 1) + 1)) / | |
pstride_y_ + 1), | |
o_width_((e.shape_[dstdim - 1] - (pdilate_x_ * (psize_x_ - 1) + 1)) / | |
pstride_x_ + 1) { | |
// note: i/o convention are same as unpack | |
} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
using namespace std; | |
const index_t y = i % i_height_; | |
const index_t idivh = i / i_height_; | |
const index_t c = idivh % i_channel_; | |
const index_t n = idivh / i_channel_; | |
const index_t x = j; | |
const index_t psize_y_dilate = (pdilate_y_ * (psize_y_ - 1) + 1); | |
const index_t psize_x_dilate = (pdilate_x_ * (psize_x_ - 1) + 1); | |
const index_t py_min = | |
y < psize_y_dilate ? y % pdilate_y_ : (y-psize_y_dilate + pstride_y_) / pstride_y_; | |
const index_t px_min = | |
x < psize_x_dilate ? x % pdilate_x_ : (x-psize_x_dilate + pstride_x_) / pstride_x_; | |
const index_t py_max = min((y + pstride_y_) / pstride_y_, o_height_); | |
const index_t px_max = min((x + pstride_x_) / pstride_x_, o_width_); | |
DType res = static_cast<DType>(0); | |
for (index_t py = py_min; py < py_max; py += pdilate_y_) { | |
for (index_t px = px_min; px < px_max; px += pdilate_x_) { | |
res += src_.Eval(((c * psize_y_ + (y - py*pstride_y_) / pdilate_y_) * psize_x_ + | |
(x - px * pstride_x_) / pdilate_x_), | |
(n * o_height_ + py) * o_width_ + px); | |
} | |
} | |
return res; | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t psize_y_, psize_x_, pstride_y_, pstride_x_, i_channel_; | |
const index_t pdilate_y_, pdilate_x_; | |
const index_t i_height_, o_height_, o_width_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_PACK_COL2PATCH_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/pack_col2patch.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/reshape.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file reshape.h | |
* \brief support for reshape | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_EXTENSION_RESHAPE_H_ | |
#define MSHADOW_EXTENSION_RESHAPE_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
* \brief reshape the content to another shape | |
* input: Tensor<Device,dimsrc>: ishape | |
* output: Tensor<Device,dimdst> ishape.Size() == oshape.Size() | |
* \tparam SrcExp source expression | |
* \tparam dimdst target dimension | |
* \tparam dimsrc source dimension | |
*/ | |
template<typename SrcExp, typename DType, int dimdst, int dimsrc> | |
struct ReshapeExp: | |
public MakeTensorExp<ReshapeExp<SrcExp, DType, dimdst, dimsrc>, | |
SrcExp, dimdst, DType> { | |
/*! \brief source expression */ | |
const SrcExp &src_; | |
/*! \brief smallest dimension of input */ | |
index_t ishapex_; | |
/*! \brief constructor */ | |
ReshapeExp(const SrcExp &src, Shape<dimdst> shape) | |
: src_(src) { | |
Shape<dimsrc> ishape = ShapeCheck<dimsrc, SrcExp>::Check(src_); | |
CHECK_EQ(ishape.Size(), shape.Size()) << "reshape size must match"; | |
ishapex_ = ishape[dimsrc - 1]; | |
this->shape_ = shape; | |
} | |
}; | |
/*! | |
 * \brief an expression that reshapes a tensor to another shape
* \param src Tensor<Device,dimsrc>: | |
* \param oshape target shape | |
 * \return an expression with type Tensor<Device,dimdst>
* \tparam SrcExp source expression | |
* \tparam etype source expression type | |
* \tparam dimdst target dimension | |
*/ | |
template<typename SrcExp, typename DType, int etype, int dimdst> | |
inline ReshapeExp<SrcExp, DType, dimdst, ExpInfo<SrcExp>::kDim> | |
reshape(const Exp<SrcExp, DType, etype> &src, Shape<dimdst> oshape) { | |
return ReshapeExp<SrcExp, DType, dimdst, ExpInfo<SrcExp>::kDim> | |
(src.self(), oshape); | |
} | |
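/*
 * Illustrative usage sketch (not part of the original source): flattening a 3D
 * activation into a matrix, assuming the element counts match:
 *
 *   Tensor<cpu, 3, float> act(Shape3(b, h, w));
 *   Tensor<cpu, 2, float> mat(Shape2(b, h * w));
 *   mat = reshape(act, mat.shape_);   // same data, viewed as (b, h*w)
 */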
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename SrcExp, typename DType, int dimdst, int dimsrc> | |
struct Plan<ReshapeExp<SrcExp, DType, dimdst, dimsrc>, DType> { | |
public: | |
explicit Plan(const ReshapeExp<SrcExp, DType, dimdst, dimsrc> &e) | |
: src_(MakePlan(e.src_)), | |
oshapex_(e.shape_[dimdst - 1]), ishapex_(e.ishapex_) {} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
const index_t idx = y * oshapex_ + x; | |
return src_.Eval(idx / ishapex_, idx % ishapex_); | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t oshapex_, ishapex_; | |
}; | |
// special work plan for 1 dimensional data | |
template<typename SrcExp, typename DType, int dimdst> | |
struct Plan<ReshapeExp<SrcExp, DType, dimdst, 1>, DType> { | |
public: | |
explicit Plan(const ReshapeExp<SrcExp, DType, dimdst, 1> &e) | |
: src_(MakePlan(e.src_)), oshapex_(e.shape_[dimdst - 1]) { | |
} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
return src_.Eval(0, y * oshapex_ + x); | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t oshapex_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_RESHAPE_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/reshape.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/swapaxis.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file swapaxis.h | |
* \brief support for swapaxis | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_EXTENSION_SWAPAXIS_H_ | |
#define MSHADOW_EXTENSION_SWAPAXIS_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
 * \brief swap two axes of a tensor
* input: Tensor<Device,dim>: ishape | |
 * output: Tensor<Device,dimsrc> oshape[a1], oshape[a2] = ishape[a2], ishape[a1]
* | |
* \tparam SrcExp type of source expression | |
* \tparam DType the type of elements | |
* \tparam dimsrc source dimension, assert a1 > a2 | |
* \tparam m_a1 one dimension to be swapped, encoded by dimsrc - a1 | |
* \tparam a2 second dimension to be swapped, encoded by a2 | |
*/ | |
template<typename SrcExp, typename DType, int dimsrc, int m_a1, int a2> | |
struct SwapAxisExp: | |
public MakeTensorExp<SwapAxisExp<SrcExp, DType, dimsrc, m_a1, a2>, | |
SrcExp, dimsrc, DType> { | |
// decode the a1, a2 | |
static const int a1 = dimsrc - m_a1; | |
/*! \brief source expression */ | |
const SrcExp &src_; | |
/*! \brief constructor */ | |
explicit SwapAxisExp(const SrcExp &src) : src_(src) { | |
this->shape_ = ShapeCheck<dimsrc, SrcExp>::Check(src); | |
std::swap(this->shape_[a1], this->shape_[a2]); | |
} | |
}; | |
/*! | |
 * \brief an expression that swaps two axes of a tensor
 * \param src Tensor<Device,dimsrc>:
 * \return an expression with type Tensor<Device,dimsrc>
* \tparam a1 higher dimension to be swapped, assert a1 > a2 | |
* \tparam a2 lower dimension to be swapped | |
* \tparam SrcExp source expression | |
* \tparam DType the type of elements | |
* \tparam etype source expression type | |
*/ | |
template<int a1, int a2, typename SrcExp, typename DType, int etype> | |
inline SwapAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim, | |
ExpInfo<SrcExp>::kDim - a1, a2> | |
swapaxis(const Exp<SrcExp, DType, etype> &src) { | |
typedef ExpInfo<SrcExp> Info; | |
TypeCheckPass<Info::kDim >= a1 + 1 && Info::kDim >= a2 + 1 && | |
a2 < a1>::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return SwapAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim, | |
ExpInfo<SrcExp>::kDim - a1, a2>(src.self()); | |
} | |
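/*
 * Illustrative usage sketch (not part of the original source): swapping the
 * first two axes of a 4D tensor (a1 must be the larger axis index):
 *
 *   Tensor<cpu, 4, float> src(Shape4(n, c, h, w));
 *   Tensor<cpu, 4, float> dst(Shape4(c, n, h, w));
 *   dst = swapaxis<1, 0>(src);   // dst[j][i][y][x] = src[i][j][y][x]
 */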
template<typename SrcExp, typename DType, int dimsrc, int m_a1, int a2> | |
struct Plan<SwapAxisExp<SrcExp, DType, dimsrc, m_a1, a2>, DType> { | |
public: | |
// decode the a1 | |
static const int a1 = dimsrc - m_a1; | |
explicit Plan(const SwapAxisExp<SrcExp, DType, dimsrc, m_a1, a2> &e) | |
: src_(MakePlan(e.src_)), | |
shapey_(e.shape_.ProdShape(a1 + 1, dimsrc - 1)), | |
shapez_(e.shape_[a1]), | |
shapec_(e.shape_.ProdShape(a2 + 1, a1)), | |
shapen_(e.shape_[a2]) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
const index_t y = i % shapey_; | |
i /= shapey_; | |
const index_t z = i % shapez_; | |
i /= shapez_; | |
const index_t c = i % shapec_; | |
i /= shapec_; | |
const index_t n = i % shapen_; | |
// swap z and n | |
return src_.Eval(((((i / shapen_) * shapez_ + z) * shapec_ + | |
c) * shapen_ + n) * shapey_ + y, j); | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t shapey_, shapez_, shapec_, shapen_; | |
}; | |
template<typename SrcExp, typename DType, int dimsrc, int a2> | |
struct Plan<SwapAxisExp<SrcExp, DType, dimsrc, 1, a2>, DType> { | |
public: | |
explicit Plan(const SwapAxisExp<SrcExp, DType, dimsrc, 1, a2> &e) | |
: src_(MakePlan(e.src_)), | |
shapex_(e.shape_[dimsrc - 1]), | |
shapey_(e.shape_.ProdShape(a2 + 1, dimsrc - 1)), | |
shapez_(e.shape_[a2]) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t x) const { | |
// swap x and z | |
const index_t y = i % shapey_; | |
i /= shapey_; | |
const index_t z = i % shapez_; | |
const index_t n = i / shapez_; | |
return src_.Eval((n * shapex_ + x) * shapey_ + y , z); | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t shapex_, shapey_, shapez_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_SWAPAXIS_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/swapaxis.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/reduceto1d.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file reduceto1d.h | |
* \brief support for sum_rows and sumall_except_dim | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_EXTENSION_REDUCETO1D_H_ | |
#define MSHADOW_EXTENSION_REDUCETO1D_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
* \brief reduction to 1 dimension tensor | |
* input: Tensor<Device,k>: ishape | |
* output: Tensor<Device,1> shape[0] = ishape[dimkeep]; | |
* | |
* \tparam SrcExp type of expression to be reduced | |
* \tparam DType the data type of the scalar | |
* \tparam Reducer which reducer to use | |
* \tparam m_dimkeep which dimension to be kept, encoded with dimsrc - dimkeep | |
*/ | |
template<typename SrcExp, typename DType, typename Reducer, int m_dimkeep> | |
struct ReduceTo1DExp: | |
public Exp<ReduceTo1DExp<SrcExp, DType, Reducer, m_dimkeep>, | |
DType, type::kComplex> { | |
/*! \brief source operand */ | |
const SrcExp &src_; | |
  /*! \brief scaling factor applied to the reduction */
DType scale_; | |
  /*! \brief construct a reduce-to-1D expression from src and scale */
ReduceTo1DExp(const SrcExp& src, DType scale) : src_(src), scale_(scale) {} | |
}; | |
/*! | |
* \brief a sum over all dimensions, except dimkeep | |
* \param exp input expression that must be a matrix Tensor<?,2> | |
 * \return an expression with type Tensor<Device,1>
* \tparam dimkeep the dimension that will be kept | |
* \tparam SrcExp expression | |
* \tparam etype type of expression | |
*/ | |
template<int dimkeep, typename SrcExp, typename DType, int etype> | |
inline ReduceTo1DExp<SrcExp, DType, red::sum, | |
ExpInfo<SrcExp>::kDim - dimkeep> | |
sumall_except_dim(const Exp<SrcExp, DType, etype> &exp) { | |
return ReduceTo1DExp<SrcExp, DType, red::sum, | |
ExpInfo<SrcExp>::kDim - dimkeep>(exp.self(), DType(1)); | |
} | |
/*! | |
* \brief reduce over all dimensions, except dimkeep | |
* \param exp input expression that must be a matrix Tensor<?,2> | |
 * \return an expression with type Tensor<Device,1>
* \tparam dimkeep the dimension that will be kept | |
* \tparam SrcExp expression | |
* \tparam etype type of expression | |
*/ | |
template<int dimkeep, typename Reducer, typename SrcExp, typename DType, int etype> | |
inline ReduceTo1DExp<SrcExp, DType, Reducer, | |
ExpInfo<SrcExp>::kDim - dimkeep> | |
reduce_except_dim(const Exp<SrcExp, DType, etype> &exp) { | |
return ReduceTo1DExp<SrcExp, DType, Reducer, | |
ExpInfo<SrcExp>::kDim - dimkeep>(exp.self(), DType(1)); | |
} | |
/*! | |
* \brief a expression that sum over rows of a matrix | |
* \param exp input expression that must be a matrix Tensor<?, 2> | |
 * \return an expression with type Tensor<Device, 1>
* \tparam SrcExp expression | |
* \tparam etype type of expression | |
*/ | |
template<typename SrcExp, typename DType, int etype> | |
inline ReduceTo1DExp<SrcExp, DType, red::sum, 1> | |
sum_rows(const Exp<SrcExp, DType, etype> &exp) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim ==2> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return sumall_except_dim<1>(exp); | |
} | |
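/*
 * Illustrative usage sketch (not part of the original source): per-column and
 * per-channel sums, assuming already-allocated cpu tensors:
 *
 *   Tensor<cpu, 2, float> mat(Shape2(n, k));
 *   Tensor<cpu, 1, float> colsum(Shape1(k));
 *   colsum = sum_rows(mat);                 // colsum[j] = sum_i mat[i][j]
 *
 *   Tensor<cpu, 4, float> grad(Shape4(n, c, h, w));
 *   Tensor<cpu, 1, float> gbias(Shape1(c));
 *   gbias = sumall_except_dim<1>(grad);     // sum over all axes except channel
 */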
template<typename SV, typename Device, typename DType, | |
typename SrcExp, typename Reducer, int m_dimkeep> | |
struct ExpComplexEngine<SV, | |
Tensor<Device, 1, DType>, | |
ReduceTo1DExp<SrcExp, DType, Reducer, m_dimkeep>, | |
DType> { | |
static const int dimkeep = ExpInfo<SrcExp>::kDim - m_dimkeep; | |
inline static void Eval(Tensor<Device, 1, DType> *dst, | |
const ReduceTo1DExp<SrcExp, DType, | |
Reducer, m_dimkeep> &exp) { | |
TypeCheckPass<m_dimkeep != 1> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
MapReduceKeepHighDim<SV, Reducer, dimkeep>(dst, exp.src_, exp.scale_); | |
} | |
}; | |
template<typename SV, typename Device, typename DType, | |
typename SrcExp, typename Reducer> | |
struct ExpComplexEngine<SV, | |
Tensor<Device, 1, DType>, | |
ReduceTo1DExp<SrcExp, DType, Reducer, 1>, DType> { | |
inline static void Eval(Tensor<Device, 1, DType> *dst, | |
const ReduceTo1DExp<SrcExp, DType, Reducer, 1> &exp) { | |
MapReduceKeepLowest<SV, Reducer>(dst, exp.src_, exp.scale_); | |
} | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_REDUCETO1D_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/reduceto1d.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/spatial_pool.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file spatial_pool.h | |
* \brief support for spatial pooling | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_EXTENSION_SPATIAL_POOL_H_ | |
#define MSHADOW_EXTENSION_SPATIAL_POOL_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
 * \brief pooling expression, does reduction over local patches of an image
* \tparam Reducer reduction method during pooling | |
* \tparam SrcExp source expression to be pooled from | |
* \tparam DType the content data type | |
* \tparam srcdim dimension of src | |
*/ | |
template<typename Reducer, typename SrcExp, typename DType, int srcdim> | |
struct PoolingExp: | |
public MakeTensorExp<PoolingExp<Reducer, SrcExp, DType, srcdim>, | |
SrcExp, srcdim, DType> { | |
/*! \brief source operand */ | |
const SrcExp &src_; | |
/*! \brief kernel size in height */ | |
index_t ksize_y_; | |
/*! \brief kernel size in width */ | |
index_t ksize_x_; | |
  /*! \brief kernel stride in y direction */
  index_t kstride_y_;
  /*! \brief kernel stride in x direction */
index_t kstride_x_; | |
/*! \brief source height shape[1] */ | |
index_t src_height_; | |
/*! \brief source width shape[0] */ | |
index_t src_width_; | |
/*! \brief constructor */ | |
PoolingExp(const SrcExp &src, | |
index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) | |
: src_(src), ksize_y_(ksize_y), ksize_x_(ksize_x), | |
kstride_y_(kstride_y), kstride_x_(kstride_x) { | |
Shape<srcdim> sshape = ShapeCheck<srcdim, SrcExp>::Check(src_); | |
CHECK(sshape[srcdim - 1] >= ksize_x && sshape[srcdim - 2] >= ksize_y) | |
<< "PoolingExp: kernel must be smaller than image"; | |
this->src_height_ = sshape[srcdim - 2]; | |
this->src_width_ = sshape[srcdim - 1]; | |
this->shape_ = sshape; | |
this->shape_[srcdim - 2] = (src_height_ - ksize_y) / kstride_y + 1; | |
this->shape_[srcdim - 1] = (src_width_ - ksize_x) / kstride_x + 1; | |
} | |
/*! \brief constructor, specify shape */ | |
PoolingExp(const SrcExp &src, Shape<2> pshape, | |
index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) | |
: src_(src), ksize_y_(ksize_y), ksize_x_(ksize_x), | |
kstride_y_(kstride_y), kstride_x_(kstride_x) { | |
Shape<srcdim> sshape = ShapeCheck<srcdim, SrcExp>::Check(src_); | |
CHECK(sshape[srcdim - 1] >= ksize_x && sshape[srcdim - 2] >= ksize_y) | |
<< "PoolingExp: kernel must be smaller than image"; | |
this->src_height_ = sshape[srcdim - 2]; | |
this->src_width_ = sshape[srcdim - 1]; | |
this->shape_ = sshape; | |
this->shape_[srcdim - 2] = pshape[0]; | |
this->shape_[srcdim - 1] = pshape[1]; | |
} | |
}; | |
/*! | |
* \brief pooling subregion results together | |
* \param src source image, shape: (batch, channel, height, width) | |
* \param ksize_y kernel size in height | |
* \param ksize_x kernel size in width | |
 * \param kstride_y stride in y direction
 * \param kstride_x stride in x direction
* \return expression of pooled result | |
* \tparam Reducer reducer type | |
* \tparam SrcExp source expression | |
* \tparam DType the content data type | |
* \tparam etype type of expression | |
*/ | |
template<typename Reducer, typename SrcExp, typename DType, int etype> | |
inline PoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
pool(const Exp<SrcExp, DType, etype> &src, | |
index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return PoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
(src.self(), ksize_y, ksize_x, kstride_y, kstride_x); | |
} | |
/*! | |
* \brief same as pool, except the output shape is specified by pshape | |
* \param src source image | |
 * \param pshape output shape
 * \param ksize_y kernel size in y
 * \param ksize_x kernel size in x
 * \param kstride_y stride in y direction
 * \param kstride_x stride in x direction
* \return expression of pooled result | |
* \tparam Reducer reducer type | |
* \tparam SrcExp source expression | |
* \tparam DType the content data type | |
* \tparam etype type of expression | |
*/ | |
template<typename Reducer, typename SrcExp, | |
typename DType, int etype> | |
inline PoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
pool(const Exp<SrcExp, DType, etype> &src, Shape<2> pshape, | |
index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return PoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
(src.self(), pshape, ksize_y, ksize_x, kstride_y, kstride_x); | |
} | |
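/*
 * Illustrative usage sketch (not part of the original source): 2x2 max pooling
 * with stride 2, using the red::maximum reducer on already-allocated cpu tensors:
 *
 *   Tensor<cpu, 4, float> img(Shape4(n, c, h, w));
 *   Tensor<cpu, 4, float> pooled(Shape4(n, c, (h - 2) / 2 + 1, (w - 2) / 2 + 1));
 *   pooled = pool<red::maximum>(img, 2, 2, 2, 2);
 */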
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename Reducer, typename SrcExp, typename DType, int srcdim> | |
struct Plan<PoolingExp< Reducer, SrcExp, DType, srcdim>, DType> { | |
public: | |
explicit Plan(const PoolingExp<Reducer, SrcExp, DType, srcdim> &e) | |
: src_(MakePlan(e.src_)), | |
ksize_y_(e.ksize_y_), ksize_x_(e.ksize_x_), | |
kstride_y_(e.kstride_y_), kstride_x_(e.kstride_x_), | |
src_height_(e.src_height_), src_width_(e.src_width_), | |
new_height_(e.shape_[srcdim - 2]) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
using namespace std; | |
const index_t py = i % new_height_; | |
const index_t y_start = py * kstride_y_; | |
const index_t y_end = min(y_start + ksize_y_, src_height_); | |
const index_t px = j; | |
const index_t x_start = px * kstride_x_; | |
const index_t x_end = min(x_start + ksize_x_, src_width_); | |
const index_t c = i / new_height_; | |
DType res; Reducer::SetInitValue(res); | |
for (index_t y = y_start; y < y_end; ++y) { | |
for (index_t x = x_start; x < x_end; ++x) { | |
Reducer::Reduce(res, src_.Eval(c * src_height_ + y, x)); | |
} | |
} | |
return res; | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t ksize_y_, ksize_x_, kstride_y_, kstride_x_; | |
const index_t src_height_, src_width_; | |
const index_t new_height_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_SPATIAL_POOL_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/spatial_pool.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/spatial_unpool.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file spatial_unpool.h | |
* \brief support for unpool | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_ | |
#define MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
 * \brief unpooling expression, the reverse operation of pooling, used to pass the gradient back
* \tparam Reducer reduction method during pooling | |
* \tparam SrcExp source expression to be pooled from | |
* \tparam DType the content data type | |
* \tparam srcdim dimension of src | |
*/ | |
template<typename Reducer, typename SrcExp, typename DType, int srcdim> | |
struct UnPoolingExp: | |
public MakeTensorExp<UnPoolingExp<Reducer, SrcExp, DType, srcdim>, | |
SrcExp, srcdim, DType> { | |
/*! \brief source input, corresponds to src in pooling */ | |
const SrcExp &data_src_; | |
/*! \brief result of pooled data, corresponds to result of pooling */ | |
const SrcExp &data_pooled_; | |
  /*! \brief gradient data of pooled part, to be propagated down */
const SrcExp &grad_pooled_; | |
/*! \brief shape of pooled expression */ | |
index_t pshape_y_; | |
/*! \brief shape of pooled expression */ | |
index_t pshape_x_; | |
/*! \brief kernel size in height */ | |
index_t ksize_y_; | |
/*! \brief kernel size in width */ | |
index_t ksize_x_; | |
  /*! \brief kernel stride in y direction */
  index_t kstride_y_;
  /*! \brief kernel stride in x direction */
index_t kstride_x_; | |
/*! \brief constructor */ | |
UnPoolingExp(const SrcExp &data_src, | |
const SrcExp &data_pooled, | |
const SrcExp &grad_pooled, | |
index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) | |
: data_src_(data_src), data_pooled_(data_pooled), | |
grad_pooled_(grad_pooled), | |
ksize_y_(ksize_y), ksize_x_(ksize_x), | |
kstride_y_(kstride_y), kstride_x_(kstride_x) { | |
Shape<srcdim> pshape = ShapeCheck<srcdim, SrcExp>::Check(grad_pooled); | |
typedef ShapeCheck<srcdim, SrcExp> ShapeCheckSrcDimSrcExp; | |
CHECK_EQ(pshape, ShapeCheckSrcDimSrcExp::Check(data_pooled)) | |
<< "UnPoolingExp: pooled shape mismatch"; | |
Shape<srcdim> sshape = ShapeCheck<srcdim, SrcExp>::Check(data_src); | |
for (int k = 0; k < srcdim - 2; ++k) { | |
CHECK_EQ(pshape[k], sshape[k]) << "UnPoolingExp: pool and src shape mismatch"; | |
} | |
pshape_x_ = pshape[srcdim - 1]; | |
pshape_y_ = pshape[srcdim - 2]; | |
this->shape_ = sshape; | |
} | |
}; | |
/*! | |
 * \brief unpooling gradient for 4D, backpropagates the gradient value, the reverse operation of pooling;
 *   same as unpooling, but allows unequal kernel sizes
* \param data_src source input, corresponds to src in pooling | |
* \param data_pooled result of pooled data, corresponds to result of pooling | |
 * \param grad_pooled gradient data of pooled part, to be propagated down
* \param ksize_y kernel height | |
* \param ksize_x kernel width | |
 * \param kstride_y stride in y direction
 * \param kstride_x stride in x direction
* \return expression corresponding to unpooled 4D Tensor, storing backproped gradient | |
* \tparam Reducer reducer type | |
* \tparam SrcExp source expression | |
* \tparam DType the content data type | |
* \tparam etype type of expression | |
*/ | |
template<typename Reducer, typename SrcExp, typename DType, int etype> | |
inline UnPoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
unpool(const Exp<SrcExp, DType, etype> &data_src, | |
const Exp<SrcExp, DType, etype> &data_pooled, | |
const Exp<SrcExp, DType, etype> &grad_pooled, | |
index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) { | |
return UnPoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
(data_src.self(), data_pooled.self(), grad_pooled.self(), | |
ksize_y, ksize_x, kstride_y, kstride_x); | |
} | |
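/*
 * Illustrative usage sketch (not part of the original source): backpropagating
 * through the 2x2 max pooling from the previous sketch, where grad_pooled holds
 * the gradient with respect to the pooled output:
 *
 *   Tensor<cpu, 4, float> grad_in(img.shape_);
 *   grad_in = unpool<red::maximum>(img, pooled, grad_pooled, 2, 2, 2, 2);
 */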
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename Reducer, typename SrcExp, typename DType, int srcdim> | |
struct Plan<UnPoolingExp<Reducer, SrcExp, DType, srcdim>, DType> { | |
public: | |
explicit Plan(const UnPoolingExp<Reducer, SrcExp, DType, srcdim> &e) | |
: data_src_(MakePlan(e.data_src_)), data_pooled_(MakePlan(e.data_pooled_)), | |
grad_pooled_(MakePlan(e.grad_pooled_)), sshape_y_(e.shape_[srcdim - 2]), | |
pshape_y_(e.pshape_y_), pshape_x_(e.pshape_x_), | |
ksize_y_(e.ksize_y_), ksize_x_(e.ksize_x_), | |
kstride_y_(e.kstride_y_), kstride_x_(e.kstride_x_) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
using namespace std; | |
const index_t x = j; | |
const index_t y = i % sshape_y_; | |
const index_t c = i / sshape_y_; | |
const DType vsrc = data_src_.Eval(i, j); | |
const index_t py_min = | |
y < ksize_y_ ? 0 : (y - ksize_y_ + kstride_y_) / kstride_y_; | |
const index_t px_min = | |
x < ksize_x_ ? 0 : (x - ksize_x_ + kstride_x_) / kstride_x_; | |
const index_t py_max = min((y + kstride_y_) / kstride_y_, pshape_y_); | |
const index_t px_max = min((x + kstride_x_) / kstride_x_, pshape_x_); | |
DType val = static_cast<DType>(0); | |
for (index_t py = py_min; py < py_max; ++py) { | |
for (index_t px = px_min; px < px_max; ++px) { | |
val += Reducer::PartialGrad(vsrc, | |
data_pooled_.Eval(c * pshape_y_ + py, px)) * | |
grad_pooled_.Eval(c * pshape_y_ + py, px); | |
} | |
} | |
return val; | |
} | |
private: | |
Plan<SrcExp, DType> data_src_, data_pooled_, grad_pooled_; | |
const index_t sshape_y_, pshape_y_, pshape_x_; | |
const index_t ksize_y_, ksize_x_; | |
const index_t kstride_y_, kstride_x_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/spatial_unpool.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/channel_pool.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file channel_pool.h | |
* \brief support for chpool | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_EXTENSION_CHANNEL_POOL_H_ | |
#define MSHADOW_EXTENSION_CHANNEL_POOL_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
 * \brief channel pooling expression, does a reduction over neighboring (local) channels,
* used to implement local response normalization | |
* \tparam Reducer reduction method during pooling | |
* \tparam SrcExp source expression to be pooled from | |
* \tparam DType the type of elements | |
* \tparam srcdim dimension of src | |
*/ | |
template<typename Reducer, typename SrcExp, typename DType, int srcdim> | |
struct ChannelPoolingExp: | |
public MakeTensorExp<ChannelPoolingExp<Reducer, SrcExp, DType, srcdim>, | |
SrcExp, srcdim, DType> { | |
/*! \brief source operand */ | |
const SrcExp &src_; | |
/*! \brief neighbor size */ | |
index_t nsize_; | |
/*! \brief stride of pooling */ | |
index_t stride_; | |
/*! \brief pad of pooling of each side */ | |
index_t pad_; | |
index_t src_channel_; | |
/*! \brief constructor */ | |
ChannelPoolingExp(const SrcExp &src, index_t nsize, index_t stride, index_t pad) | |
: src_(src), nsize_(nsize), stride_(stride), pad_(pad) { | |
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_); | |
this->src_channel_ = this->shape_[srcdim - 3]; | |
CHECK_GE(this->shape_[srcdim - 3], nsize_) | |
<< "chpool: local size must be smaller than nchannels"; | |
this->shape_[srcdim - 3] = (this->src_channel_ - nsize + pad * 2 + 1) / stride; | |
} | |
}; | |
/*! | |
 * \brief channel pooling, does a reduction over neighboring (local) channels,
* used to implement local response normalization | |
* \param src source data | |
* \param nsize neighbor size | |
* \return expression of pooled result | |
* \tparam Reducer reducer type | |
* \tparam SrcExp source expression | |
* \tparam DType the type of elements | |
* \tparam etype type of expression | |
*/ | |
template<typename Reducer, typename SrcExp, typename DType, int etype> | |
inline ChannelPoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
chpool(const Exp<SrcExp, DType, etype> &src, index_t nsize) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 3> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
CHECK_EQ(nsize % 2, 1) << "chpool: if no pad is specified, local size must be odd"; | |
return ChannelPoolingExp<Reducer, SrcExp, | |
DType, ExpInfo<SrcExp>::kDim>(src.self(), nsize, 1, nsize / 2); | |
} | |
template<typename Reducer, typename SrcExp, typename DType, int etype> | |
inline ChannelPoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
chpool(const Exp<SrcExp, DType, etype> &src, index_t nsize, index_t stride, index_t pad) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 3> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return ChannelPoolingExp<Reducer, SrcExp, | |
DType, ExpInfo<SrcExp>::kDim>(src.self(), nsize, stride, pad); | |
} | |
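/*
 * Illustrative usage sketch (added for clarity, not part of the original header):
 * the across-channel sum used by local response normalization. Names, sizes and
 * the `red::sum` reducer are assumptions; the pad-free overload requires an odd
 * neighborhood size.
 *
 *   Tensor<cpu, 4> data(Shape4(n, c, h, w)), sums(Shape4(n, c, h, w));
 *   sums = chpool<red::sum>(data, 5);  // reduce over 5 neighboring channels
 */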
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename Reducer, typename SrcExp, typename DType, int srcdim> | |
struct Plan<ChannelPoolingExp<Reducer, SrcExp, DType, srcdim>, DType> { | |
public: | |
explicit Plan(const ChannelPoolingExp<Reducer, SrcExp, DType, srcdim> &e) | |
: src_(MakePlan(e.src_)), channel_(e.shape_[srcdim - 3]), | |
height_(e.shape_[srcdim - 2]), width_(e.shape_[srcdim - 1]), | |
hnsize_(e.nsize_), stride_(e.stride_), pad_(e.pad_), | |
src_channel_(e.src_channel_) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
using namespace std; | |
const index_t y = i % height_; | |
i /= height_; | |
const index_t c = i % channel_; | |
const index_t n = i / channel_; | |
const index_t x = j; | |
const index_t cstart = c * stride_ < pad_ ? 0 : c * stride_ - pad_; | |
const index_t cend = min(cstart + hnsize_, channel_); | |
DType res; Reducer::SetInitValue(res); | |
for (index_t cc = cstart; cc < cend; ++cc) { | |
Reducer::Reduce(res, src_.Eval((n * src_channel_ + cc) * height_ + y, x)); | |
} | |
return res; | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t channel_, height_, width_, hnsize_, stride_, pad_, src_channel_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_CHANNEL_POOL_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/channel_pool.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/channel_unpool.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
 * \file channel_unpool.h
 * \brief support for ch_unpool (channel unpooling)
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_ | |
#define MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
 * \brief channel unpooling expression, the reverse operation of channel pooling,
 *  used to pass the gradient of channel pooling back
* \tparam Reducer reduction method during pooling | |
* \tparam SrcExp source expression to be pooled from | |
* \tparam DType the type of elements | |
* \tparam srcdim dimension of src | |
*/ | |
template<typename Reducer, typename SrcExp, typename DType, int srcdim> | |
struct ChannelUnpoolingExp: | |
public MakeTensorExp<ChannelUnpoolingExp<Reducer, SrcExp, DType, srcdim>, | |
SrcExp, srcdim, DType> { | |
/*! \brief source input, corresponds to src in pooling */ | |
const SrcExp &data_src_; | |
/*! \brief result of pooled data, corresponds to result of pooling */ | |
const SrcExp &data_pooled_; | |
  /*! \brief gradient data of the pooled part, to be propagated down */
const SrcExp &grad_pooled_; | |
/*! \brief channel of pooled expression */ | |
index_t pchannel_; | |
  /*! \brief neighborhood size of the pooling */
  index_t nsize_;
  /*! \brief stride of the pooling */
  index_t kstride_;
/*! \brief pad */ | |
index_t pad_; | |
/*! \brief constructor */ | |
ChannelUnpoolingExp(const SrcExp &data_src, | |
const SrcExp &data_pooled, | |
const SrcExp &grad_pooled, | |
index_t nsize, index_t kstride, index_t pad) | |
: data_src_(data_src), data_pooled_(data_pooled), | |
grad_pooled_(grad_pooled), | |
nsize_(nsize), kstride_(kstride), pad_(pad) { | |
Shape<srcdim> pshape = ShapeCheck<srcdim, SrcExp>::Check(grad_pooled); | |
typedef ShapeCheck<srcdim, SrcExp> ShapeCheckSrcDimSrcExp; | |
CHECK_EQ(pshape, ShapeCheckSrcDimSrcExp::Check(data_pooled)) | |
<< "ChannelUnPoolingExp: data and grad shape mismatch"; | |
Shape<srcdim> sshape = ShapeCheck<srcdim, SrcExp>::Check(data_src); | |
for (int k = 0; k < srcdim; ++k) { | |
if (k == 1) { | |
continue; | |
} | |
CHECK_EQ(pshape[k], sshape[k]) | |
<< "ChannelUnPoolingExp: pooled tensor and src tensor shape mismatch" | |
<< pshape[k] | |
<< " vs " | |
<< sshape[k]; | |
} | |
pchannel_ = pshape[1]; | |
this->shape_ = sshape; | |
} | |
}; | |
/*! | |
* \brief channel unpooling, do unroll over (local nearby) channels | |
* \param src source data | |
* \param nsize neighbor size | |
* \param stride stride of the pooling | |
* \param pad number of padding at each side | |
* \return expression of pooled result | |
* \tparam Reducer reducer type | |
* \tparam SrcExp source expression | |
* \tparam DType the type of elements | |
* \tparam etype type of expression | |
*/ | |
template<typename Reducer, typename SrcExp, typename DType, int etype> | |
inline ChannelUnpoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
ch_unpool(const Exp<SrcExp, DType, etype> &data_src, | |
const Exp<SrcExp, DType, etype> &data_pooled, | |
const Exp<SrcExp, DType, etype> &grad_pooled, | |
index_t nsize, index_t stride, index_t pad) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 3> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return ChannelUnpoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
(data_src.self(), data_pooled.self(), grad_pooled.self(), nsize, stride, pad); | |
} | |
template<typename Reducer, typename SrcExp, typename DType, int etype> | |
inline ChannelUnpoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
ch_unpool(const Exp<SrcExp, DType, etype> &data_src, | |
const Exp<SrcExp, DType, etype> &data_pooled, | |
const Exp<SrcExp, DType, etype> &grad_pooled, index_t nsize) { | |
return ch_unpool(data_src, data_pooled, grad_pooled, nsize, 1, nsize / 2); | |
} | |
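/*
 * Illustrative usage sketch (added for clarity, not part of the original header):
 * passing a gradient back through the channel pooling shown above. All names,
 * shapes and the `red::sum` reducer are assumptions.
 *
 *   Tensor<cpu, 4> data(Shape4(n, c, h, w)), pooled(Shape4(n, c, h, w));
 *   Tensor<cpu, 4> pooled_grad(Shape4(n, c, h, w)), in_grad(Shape4(n, c, h, w));
 *   in_grad = ch_unpool<red::sum>(data, pooled, pooled_grad, 5);
 */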
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename Reducer, typename SrcExp, typename DType, int srcdim> | |
struct Plan<ChannelUnpoolingExp<Reducer, SrcExp, DType, srcdim>, DType> { | |
public: | |
explicit Plan(const ChannelUnpoolingExp<Reducer, SrcExp, DType, srcdim> &e) | |
: data_src_(e.data_src_), data_pooled_(e.data_pooled_), | |
grad_pooled_(e.grad_pooled_), channel_(e.shape_[srcdim - 3]), | |
height_(e.shape_[srcdim - 2]), pchannel_(e.pchannel_), | |
hnsize_(e.nsize_), stride_(e.kstride_), pad_(e.pad_) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
using namespace std; | |
const DType vsrc = data_src_.Eval(i, j); | |
const index_t y = i % height_; | |
i /= height_; | |
const index_t c = i % channel_; | |
const index_t n = i / channel_; | |
const index_t x = j; | |
const index_t cstart = c < hnsize_ - pad_ ? 0 | |
: (c - (hnsize_ - pad_) + stride_) / stride_; | |
const index_t cend = min((c + pad_ + stride_) / stride_, channel_); | |
DType val = static_cast<DType>(0); | |
for (index_t cc = cstart; cc < cend; ++cc) { | |
val += Reducer::PartialGrad(vsrc, | |
data_pooled_.Eval((n * pchannel_ + cc) * height_ + y, x)) * | |
grad_pooled_.Eval((n * pchannel_ + cc) * height_ + y, x); | |
} | |
return val; | |
} | |
private: | |
Plan<SrcExp, DType> data_src_, data_pooled_, grad_pooled_; | |
const index_t channel_, height_, pchannel_, hnsize_, stride_, pad_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/channel_unpool.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/pad.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file pad.h | |
* \brief support for pad | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_EXTENSION_PAD_H_ | |
#define MSHADOW_EXTENSION_PAD_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
 * \brief padding expression, pad an image with zeros
* \tparam SrcExp source expression | |
* \tparam DType the type of elements | |
* \tparam srcdim dimension of src | |
*/ | |
template<typename SrcExp, typename DType, int srcdim> | |
struct PaddingExp: | |
public MakeTensorExp<PaddingExp<SrcExp, DType, srcdim>, | |
SrcExp, srcdim, DType> { | |
/*! \brief source operand */ | |
const SrcExp &src_; | |
/*! \brief pad size in y */ | |
index_t pad_y_; | |
/*! \brief pad size in x */ | |
index_t pad_x_; | |
/*! \brief source tensor height */ | |
index_t src_height_; | |
/*! \brief source tensor width */ | |
index_t src_width_; | |
/*! \brief constructor */ | |
PaddingExp(const SrcExp &src, index_t pad_y, index_t pad_x) | |
: src_(src), pad_y_(pad_y), pad_x_(pad_x) { | |
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_); | |
src_height_ = this->shape_[srcdim - 2]; | |
src_width_ = this->shape_[srcdim - 1]; | |
this->shape_[srcdim - 2] += pad_y * 2; // height | |
this->shape_[srcdim - 1] += pad_x * 2; // width | |
} | |
}; | |
/*! | |
 * \brief padding expression, pad an image with zeros on the boundaries; padding affects the last two dimensions (height and width)
* \param src original image batches | |
* \param pad padding size | |
* \return expression corresponding to padded result | |
* \tparam SrcExp source expression | |
* \tparam DType the content data type | |
* \tparam etype type of expression | |
*/ | |
template<typename SrcExp, typename DType, int etype> | |
inline PaddingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
pad(const Exp<SrcExp, DType, etype> &src, index_t pad) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return PaddingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), pad, pad); | |
} | |
/*! | |
 * \brief padding expression, pad an image with zeros on the boundaries; padding affects the last two dimensions (height and width)
* \param src original image batches | |
* \param pad_y padding size in y | |
* \param pad_x padding size in x | |
* \return expression corresponding to padded result | |
* \tparam SrcExp source expression | |
* \tparam DType the content data type | |
* \tparam etype type of expression | |
*/ | |
template<typename SrcExp, typename DType, int etype> | |
inline PaddingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
pad(const Exp<SrcExp, DType, etype> &src, index_t pad_y, index_t pad_x) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return PaddingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
(src.self(), pad_y, pad_x); | |
} | |
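/*
 * Illustrative usage sketch (added for clarity, not part of the original header):
 * zero-padding a batch of images, e.g. before a convolution. Names and sizes
 * are assumptions.
 *
 *   Tensor<cpu, 4> img(Shape4(n, c, h, w)), padded(Shape4(n, c, h + 2, w + 4));
 *   padded = pad(img, 1, 2);  // 1 zero row top/bottom, 2 zero columns left/right
 */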
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename SrcExp, typename DType, int srcdim> | |
struct Plan<PaddingExp<SrcExp, DType, srcdim>, DType> { | |
public: | |
explicit Plan(const PaddingExp<SrcExp, DType, srcdim> &e) | |
: src_(MakePlan(e.src_)), | |
pad_y_(e.pad_y_), pad_x_(e.pad_x_), | |
new_height_(e.shape_[srcdim - 2]), | |
src_height_(e.src_height_), src_width_(e.src_width_) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
const index_t x = j; | |
const index_t y = i % new_height_; | |
const index_t c = i / new_height_; | |
if (y < pad_y_ || x < pad_x_) return static_cast<DType>(0); | |
const index_t h = y - pad_y_; | |
const index_t w = x - pad_x_; | |
if (h < src_height_ && w < src_width_) { | |
return src_.Eval(c * src_height_ + h, w); | |
} else { | |
return static_cast<DType>(0); | |
} | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t pad_y_; | |
const index_t pad_x_; | |
const index_t new_height_; | |
const index_t src_height_; | |
const index_t src_width_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_PAD_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/pad.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/crop.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file crop.h | |
* \brief support for crop | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_EXTENSION_CROP_H_ | |
#define MSHADOW_EXTENSION_CROP_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
* \brief crop expression, cut off the boundary region, reverse operation of padding | |
* \tparam SrcExp source expression to be pooled from | |
* \tparam DType the type of elements | |
* \tparam srcdim dimension of src | |
*/ | |
template<typename SrcExp, typename DType, int srcdim> | |
struct CroppingExp: | |
public MakeTensorExp<CroppingExp<SrcExp, DType, srcdim>, | |
SrcExp, srcdim, DType> { | |
/*! \brief source operand */ | |
const SrcExp &src_; | |
/*! \brief pad height */ | |
index_t pad_height_; | |
  /*! \brief pad width */
index_t pad_width_; | |
/*! \brief src height */ | |
index_t src_height_; | |
/*! \brief constructor */ | |
explicit CroppingExp(const SrcExp &src, Shape<2> cshape) | |
: src_(src) { | |
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_); | |
CHECK_GE(this->shape_[srcdim - 2], cshape[0]) << "CroppingExp: height requirement not met"; | |
CHECK_GE(this->shape_[srcdim - 1], cshape[1]) << "CroppingExp: width requirement not met"; | |
pad_height_ = (this->shape_[srcdim - 2] - cshape[0]) / 2; | |
pad_width_ = (this->shape_[srcdim - 1] - cshape[1]) / 2; | |
src_height_ = this->shape_[srcdim - 2]; | |
this->shape_[srcdim - 2] = cshape[0]; // height | |
this->shape_[srcdim - 1] = cshape[1]; // width | |
} | |
/*! \brief constructor */ | |
explicit CroppingExp(const SrcExp &src, Shape<2> cshape, | |
index_t start_height, index_t start_width) | |
: src_(src), pad_height_(start_height), pad_width_(start_width) { | |
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_); | |
CHECK_GE(this->shape_[srcdim - 2], cshape[0] + start_height) | |
<< "CroppingExp: height requirement not met"; | |
CHECK_GE(this->shape_[srcdim - 1], cshape[1] + start_width) | |
<< "CroppingExp: width requirement not met"; | |
src_height_ = this->shape_[srcdim - 2]; | |
this->shape_[srcdim - 2] = cshape[0]; // height | |
this->shape_[srcdim - 1] = cshape[1]; // width | |
} | |
}; // struct CroppingExp | |
/*! | |
 * \brief reverse operation of padding, cut off boundaries,
* crop output from center of input | |
* \param src original image batches | |
* \param oshape output shape to be cropped | |
 * \return expression corresponding to cropped result
* \tparam SrcExp source expression | |
* \tparam DType the type of elements | |
* \tparam etype type of expression | |
*/ | |
template<typename SrcExp, typename DType, int etype> | |
inline CroppingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
crop(const Exp<SrcExp, DType, etype> &src, Shape<2> oshape) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return CroppingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), oshape); | |
} | |
/*! | |
* \brief same as crop, but can specify starting position to do cropping | |
* \param src original image batches | |
* \param oshape output shape to be cropped | |
* \param start_height start height position to do cropping | |
* \param start_width start width position to do cropping | |
 * \return expression corresponding to cropped result
* \tparam SrcExp source expression | |
* \tparam DType the type of elements | |
* \tparam etype type of expression | |
*/ | |
template<typename SrcExp, typename DType, int etype> | |
inline CroppingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
crop(const Exp<SrcExp, DType, etype> &src, Shape<2> oshape, | |
index_t start_height, index_t start_width) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return CroppingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
(src.self(), oshape, start_height, start_width); | |
} | |
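/*
 * Illustrative usage sketch (added for clarity, not part of the original header):
 * cutting the padded borders back off, the inverse of pad() above. Names and
 * sizes are assumptions.
 *
 *   Tensor<cpu, 4> padded(Shape4(n, c, h + 2, w + 4)), restored(Shape4(n, c, h, w));
 *   restored = crop(padded, Shape2(h, w));        // centered crop
 *   restored = crop(padded, Shape2(h, w), 1, 2);  // crop starting at (y=1, x=2)
 */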
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename SrcExp, typename DType, int srcdim> | |
struct Plan<CroppingExp<SrcExp, DType, srcdim>, DType> { | |
public: | |
explicit Plan(const CroppingExp<SrcExp, DType, srcdim> &e) | |
: src_(MakePlan(e.src_)), | |
pad_height_(e.pad_height_), pad_width_(e.pad_width_), | |
new_height_(e.shape_[srcdim - 2]), src_height_(e.src_height_) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
const index_t x = j; | |
const index_t y = i % new_height_; | |
const index_t c = i / new_height_; | |
const index_t h = y + pad_height_; | |
const index_t w = x + pad_width_; | |
return src_.Eval(c * src_height_ + h, w); | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t pad_height_, pad_width_; | |
const index_t new_height_; | |
const index_t src_height_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_CROP_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/crop.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/mirror.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file mirror.h | |
* \brief support for mirror | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_EXTENSION_MIRROR_H_ | |
#define MSHADOW_EXTENSION_MIRROR_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
 * \brief mirror expression, mirror an image along the width dimension
* \tparam SrcExp source expression to be mirrored | |
* \tparam DType the type of elements | |
* \tparam srcdim dimension of src | |
*/ | |
template<typename SrcExp, typename DType, int srcdim> | |
struct MirroringExp: | |
public MakeTensorExp<MirroringExp<SrcExp, DType, srcdim>, | |
SrcExp, srcdim, DType> { | |
/*! \brief source operand */ | |
const SrcExp &src_; | |
/*! \brief constructor */ | |
explicit MirroringExp(const SrcExp &src) : src_(src) { | |
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_); | |
} | |
}; | |
/*! | |
* \brief mirroring expression, mirror images in width | |
* \param src original image batches | |
* \return expression corresponding to mirrored result | |
* \tparam SrcExp source expression | |
* \tparam DType the type of elements | |
* \tparam etype type of expression | |
*/ | |
template<typename SrcExp, typename DType, int etype> | |
inline MirroringExp<SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
mirror(const Exp<SrcExp, DType, etype> &src) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return MirroringExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self()); | |
} | |
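/*
 * Illustrative usage sketch (added for clarity, not part of the original header):
 * horizontal flip, as used for data augmentation. Names are assumptions.
 *
 *   Tensor<cpu, 4> img(Shape4(n, c, h, w)), flipped(Shape4(n, c, h, w));
 *   flipped = mirror(img);  // reverses the last (width) dimension
 */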
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename SrcExp, typename DType, int srcdim> | |
struct Plan<MirroringExp<SrcExp, DType, srcdim>, DType> { | |
public: | |
explicit Plan(const MirroringExp<SrcExp, DType, srcdim> &e) | |
: src_(MakePlan(e.src_)), width_(e.shape_[srcdim - 1]) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
return src_.Eval(i, width_ - j - 1); | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t width_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_MIRROR_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/mirror.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/concat.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file concat.h | |
* \brief support for concatenation | |
*/ | |
#ifndef MSHADOW_EXTENSION_CONCAT_H_ | |
#define MSHADOW_EXTENSION_CONCAT_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
 * \brief concat expression, concatenate two tensors along one dimension
* \tparam LhsExp left expression | |
* \tparam RhsExp right expression | |
* \tparam DType the type of elements | |
* \tparam srcdim dimension of src | |
* \tparam dimsrc_m_cat dimsrc - dimcat | |
*/ | |
template<typename LhsExp, typename RhsExp, | |
typename Device, typename DType, | |
int srcdim, int dimsrc_m_cat> | |
struct ConcatExp : public TRValue<ConcatExp<LhsExp, RhsExp, | |
Device, DType, | |
srcdim, dimsrc_m_cat>, | |
Device, srcdim, DType> { | |
static const int dimcat = srcdim - dimsrc_m_cat; | |
const LhsExp &src1_; | |
const RhsExp &src2_; | |
index_t dcat_src1_; | |
index_t dcat_src2_; | |
  Shape<srcdim> shape_;
ConcatExp(const LhsExp &src1, const RhsExp &src2) : src1_(src1), src2_(src2) { | |
Shape<srcdim> sshape1 = ShapeCheck<srcdim, LhsExp>::Check(src1_); | |
Shape<srcdim> sshape2 = ShapeCheck<srcdim, RhsExp>::Check(src2_); | |
#pragma unroll | |
for (int i = 0; i < srcdim; ++i) { | |
if (i != dimcat) { | |
CHECK_EQ(sshape1[i], sshape2[i]) << "ConcatExp: shape mismatch"; | |
} | |
} | |
this->shape_ = sshape1; | |
this->shape_[dimcat] = sshape1[dimcat] + sshape2[dimcat]; | |
this->dcat_src1_ = sshape1[dimcat]; | |
this->dcat_src2_ = sshape2[dimcat]; | |
} | |
template<typename E, int etype> | |
inline void | |
operator=(const expr::Exp<E, DType, etype> &exp) { | |
this->__assign(exp); | |
} | |
inline void | |
operator=(const DType &exp) { | |
this->__assign(exp); | |
} | |
}; // struct ConcatExp | |
/*! | |
 * \brief concatenate two 4D tensors
 * \param src1 source tensor1
 * \param src2 source tensor2
 * \return concatenated 4D tensor
 * \tparam cdim the dimension to concatenate on
* \tparam SrcExp source expression | |
* \tparam DType the type of elements | |
* \tparam etype type of expression | |
*/ | |
template<int cdim, typename LhsExp, typename RhsExp, | |
typename Device, typename DType, int srcdim> | |
inline ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, srcdim - cdim> | |
concat(const TRValue<LhsExp, Device, srcdim, DType> &src1, | |
const TRValue<RhsExp, Device, srcdim, DType> &src2) { | |
TypeCheckPass<ExpInfo<LhsExp>::kDim == ExpInfo<RhsExp>::kDim> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
TypeCheckPass<cdim < srcdim && ExpInfo<LhsExp>::kDim == srcdim> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, srcdim - cdim> | |
(src1.self(), src2.self()); | |
} | |
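/*
 * Illustrative usage sketch (added for clarity, not part of the original header):
 * channel-wise concatenation of two NCHW tensors, assuming cdim = 1 addresses
 * the channel axis. Names and sizes are assumptions.
 *
 *   Tensor<cpu, 4> a(Shape4(n, c1, h, w)), b(Shape4(n, c2, h, w));
 *   Tensor<cpu, 4> out(Shape4(n, c1 + c2, h, w));
 *   out = concat<1>(a, b);
 */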
//------------------------ | |
// engine plugin | |
//------------------------ | |
// runtime shapecheck | |
template<typename LhsExp, typename RhsExp, | |
typename Device, typename DType, | |
int srcdim, int dimsrc_m_cat> | |
struct ShapeCheck<srcdim, ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, dimsrc_m_cat> >{ | |
inline static Shape<srcdim> Check(const ConcatExp<LhsExp, RhsExp, | |
Device, DType, srcdim, dimsrc_m_cat> &t) { | |
return t.shape_; | |
} | |
}; | |
template<typename LhsExp, typename RhsExp, | |
typename Device, typename DType, | |
int srcdim, int dimsrc_m_cat> | |
struct StreamInfo<Device, ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, dimsrc_m_cat> >{ | |
inline static Stream<Device> * | |
Get(const ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, dimsrc_m_cat> &t) { | |
Stream<Device> *lhs = StreamInfo<Device, LhsExp>::Get(t.src1_); | |
Stream<Device> *rhs = StreamInfo<Device, RhsExp>::Get(t.src2_); | |
if (lhs != rhs) return NULL; | |
return lhs; | |
} | |
}; | |
// static typecheck | |
template<typename LhsExp, typename RhsExp, | |
typename Device, typename DType, | |
int srcdim, int dimsrc_m_cat> | |
struct ExpInfo<ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, dimsrc_m_cat> >{ | |
static const int kDimLhs = ExpInfo<LhsExp>::kDim; | |
static const int kDimRhs = ExpInfo<RhsExp>::kDim; | |
// copy from binarymap | |
static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ?\ | |
(kDimLhs == 0 ?\ | |
kDimRhs :\ | |
((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1; | |
static const int kDevMask = ExpInfo<LhsExp>::kDevMask & ExpInfo<RhsExp>::kDevMask; | |
}; | |
//---------------------- | |
// Execution plan | |
//--------------------- | |
template<typename LhsExp, typename RhsExp, | |
typename Device, typename DType, | |
int srcdim, int dimsrc_m_cat> | |
struct Plan<ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, dimsrc_m_cat>, DType> { | |
public: | |
static const int dimcat = srcdim - dimsrc_m_cat; | |
explicit Plan(const ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, dimsrc_m_cat> &e) | |
: src1_(MakePlan(e.src1_)), src2_(MakePlan(e.src2_)), | |
height_(e.shape_.ProdShape(dimcat + 1, srcdim - 1)), | |
ch_src1_(e.dcat_src1_), ch_src2_(e.dcat_src2_), ch_(e.shape_[dimcat]) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
const index_t y = i % height_; | |
i /= height_; | |
const index_t c = i % ch_; | |
const index_t b = i / ch_; | |
const index_t x = j; | |
if (c < ch_src1_) { | |
return src1_.Eval((b * ch_src1_ + c) * height_ + y, x); | |
} else { | |
return src2_.Eval((b * ch_src2_ + c - ch_src1_) * height_ + y, x); | |
} | |
} | |
MSHADOW_XINLINE DType &REval(index_t i, index_t j) { | |
const index_t y = i % height_; | |
i /= height_; | |
const index_t c = i % ch_; | |
const index_t b = i / ch_; | |
const index_t x = j; | |
if (c < ch_src1_) { | |
return src1_.REval((b * ch_src1_ + c) * height_ + y, x); | |
} else { | |
return src2_.REval((b * ch_src2_ + c - ch_src1_) * height_ + y, x); | |
} | |
} | |
private: | |
Plan<LhsExp, DType> src1_; | |
Plan<RhsExp, DType> src2_; | |
const index_t height_, ch_src1_, ch_src2_, ch_; | |
}; // struct Plan | |
// specialize for concat in x | |
template<typename LhsExp, typename RhsExp, | |
typename Device, typename DType, | |
int srcdim> | |
struct Plan<ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, 1>, DType> { | |
public: | |
explicit Plan(const ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, 1> &e) | |
: src1_(MakePlan(e.src1_)), src2_(MakePlan(e.src2_)), | |
width_src1_(e.dcat_src1_) {} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
if (x < width_src1_) { | |
return src1_.Eval(y, x); | |
} else { | |
return src2_.Eval(y, x - width_src1_); | |
} | |
} | |
MSHADOW_XINLINE DType &REval(index_t y, index_t x) { | |
if (x < width_src1_) { | |
return src1_.REval(y, x); | |
} else { | |
return src2_.REval(y, x - width_src1_); | |
} | |
} | |
private: | |
Plan<LhsExp, DType> src1_; | |
Plan<RhsExp, DType> src2_; | |
const index_t width_src1_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_CONCAT_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/concat.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/choose.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file choose.h | |
* \brief support for implicit array selection operation | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_EXTENSION_CHOOSE_H_ | |
#define MSHADOW_EXTENSION_CHOOSE_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
* \brief Make a choice of index in the lowest changing dimension. | |
* \tparam SrcExp type of lhs expression | |
* \tparam IndexExp type of index expression | |
* \tparam DType the type of elements | |
*/ | |
template<typename SrcExp, typename IndexExp, typename DType> | |
struct MatChooseRowElementExp: | |
public Exp<MatChooseRowElementExp<SrcExp, IndexExp, DType>, | |
DType, type::kChainer> { | |
/*! \brief source operand */ | |
const SrcExp &src_; | |
/*! \brief index operand */ | |
const IndexExp &index_; | |
/*! \brief constructor */ | |
MatChooseRowElementExp(const SrcExp &src, const IndexExp &index) | |
: src_(src), index_(index) {} | |
}; | |
template<typename SrcExp, typename IndexExp, | |
typename DType, typename IDType, int e1, int e2> | |
inline MatChooseRowElementExp<SrcExp, IndexExp, DType> | |
mat_choose_row_element(const Exp<SrcExp, DType, e1> &src, | |
const Exp<IndexExp, IDType, e2> &index) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim == 2 && ExpInfo<IndexExp>::kDim == 1> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return MatChooseRowElementExp<SrcExp, IndexExp, DType>(src.self(), index.self()); | |
} | |
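/*
 * Illustrative usage sketch (added for clarity, not part of the original header):
 * picking one element per row, e.g. the probability assigned to each row's label
 * in a softmax loss. Names are assumptions.
 *
 *   Tensor<cpu, 2> prob(Shape2(batch, num_class));
 *   Tensor<cpu, 1> label(Shape1(batch)), picked(Shape1(batch));
 *   picked = mat_choose_row_element(prob, label);  // picked[i] = prob[i][label[i]]
 */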
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename SrcExp, typename IndexExp, typename DType> | |
struct Plan<MatChooseRowElementExp<SrcExp, IndexExp, DType>, DType> { | |
public: | |
explicit Plan(const MatChooseRowElementExp<SrcExp, IndexExp, DType> &e) | |
: src_(MakePlan(e.src_)), | |
index_(MakePlan(e.index_)) { | |
} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
index_t idx = static_cast<index_t>(index_.Eval(0, x)); | |
return src_.Eval(x, idx); | |
} | |
private: | |
expr::Plan<SrcExp, DType> src_; | |
expr::Plan<IndexExp, DType> index_; | |
}; | |
template<typename SrcExp, typename IndexExp, typename DType> | |
inline Plan<MatChooseRowElementExp<SrcExp, IndexExp, DType>, DType> | |
MakePlan(const MatChooseRowElementExp<SrcExp, IndexExp, DType> &exp) { | |
return Plan<MatChooseRowElementExp<SrcExp, IndexExp, DType>, DType>(exp); | |
} | |
template<int dim, typename SrcExp, typename IndexExp, typename DType> | |
struct ShapeCheck<dim, MatChooseRowElementExp<SrcExp, IndexExp, DType> > { | |
inline static Shape<dim> | |
Check(const MatChooseRowElementExp<SrcExp, IndexExp, DType> &t) { | |
CHECK(dim == 1) | |
<< "MatChooseRowElementExp only support 1 dimension output"; | |
Shape<2> shape1 = ShapeCheck<2, SrcExp>::Check(t.src_); | |
Shape<dim> shape2 = ShapeCheck<dim, IndexExp>::Check(t.index_); | |
CHECK_EQ(shape1[0], shape2[0]) | |
<< "mat_choose_row_element index length and number of rows in matrix"; | |
return shape2; | |
} | |
}; | |
template<typename SrcExp, typename IndexExp, typename DType> | |
struct ExpInfo<MatChooseRowElementExp<SrcExp, IndexExp, DType> > { | |
static const int kDim = 1; | |
static const int kDevMask = ExpInfo<SrcExp>::kDevMask & ExpInfo<IndexExp>::kDevMask; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_CHOOSE_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/choose.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/fill.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file fill.h | |
* \brief support for implicit array filling operation | |
* \author Xingjian Shi | |
*/ | |
#ifndef MSHADOW_EXTENSION_FILL_H_ | |
#define MSHADOW_EXTENSION_FILL_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
* \brief Set value of a specific element in each line of the data matrix. | |
* \tparam SrcExp type of src expression | |
* \tparam ValExp type of val expression | |
* \tparam IndexExp type of index expression | |
* \tparam DType the type of ret expression | |
*/ | |
template<typename SrcExp, typename ValExp, typename IndexExp, typename DType> | |
struct MatFillRowElementExp: | |
public Exp<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType>, | |
DType, type::kChainer> { | |
/*! \brief src operand */ | |
const SrcExp &src_; | |
const ValExp &val_; | |
/*! \brief index operand */ | |
const IndexExp &index_; | |
/*! \brief constructor */ | |
MatFillRowElementExp(const SrcExp &src, const ValExp &val, const IndexExp &index) | |
: src_(src), val_(val), index_(index) {} | |
}; | |
template<typename SrcExp, typename ValExp, typename IndexExp, | |
typename SDType, typename VDType, typename IDType, int e1, int e2, int e3> | |
inline MatFillRowElementExp<SrcExp, ValExp, IndexExp, SDType> | |
mat_fill_row_element(const Exp<SrcExp, SDType, e1> &src, | |
const Exp<ValExp, VDType, e2> &val, | |
const Exp<IndexExp, IDType, e3> &index) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim == 2 && ExpInfo<ValExp>::kDim == 1 | |
&& ExpInfo<IndexExp>::kDim == 1>::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return MatFillRowElementExp<SrcExp, ValExp, IndexExp, SDType>(src.self(), | |
val.self(), index.self()); | |
} | |
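/*
 * Illustrative usage sketch (added for clarity, not part of the original header):
 * overwriting one element per row, e.g. replacing the true-label entry of a
 * probability matrix. Names are assumptions.
 *
 *   Tensor<cpu, 2> prob(Shape2(batch, num_class)), out(Shape2(batch, num_class));
 *   Tensor<cpu, 1> val(Shape1(batch)), label(Shape1(batch));
 *   out = mat_fill_row_element(prob, val, label);  // out[i][label[i]] = val[i]
 */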
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename SrcExp, typename ValExp, typename IndexExp, typename DType> | |
struct Plan<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType>, DType> { | |
public: | |
explicit Plan(const MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> &e) | |
: src_(MakePlan(e.src_)), | |
val_(MakePlan(e.val_)), | |
index_(MakePlan(e.index_)) { | |
} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
index_t idx = static_cast<index_t>(index_.Eval(0, y)); | |
if (idx == x) { | |
return static_cast<DType>(val_.Eval(0, y)); | |
} else { | |
return static_cast<DType>(src_.Eval(y, x)); | |
} | |
} | |
private: | |
expr::Plan<SrcExp, DType> src_; | |
expr::Plan<ValExp, DType> val_; | |
expr::Plan<IndexExp, DType> index_; | |
}; | |
template<typename SrcExp, typename ValExp, typename IndexExp, typename DType> | |
inline Plan<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType>, DType> | |
MakePlan(const MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> &exp) { | |
return Plan<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType>, DType>(exp); | |
} | |
template<int dim, typename SrcExp, typename ValExp, typename IndexExp, typename DType> | |
struct ShapeCheck<dim, MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> > { | |
inline static Shape<dim> | |
Check(const MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> &t) { | |
CHECK(dim == 2) | |
<< "MatFillRowElementExp only support 2 dimension output"; | |
Shape<2> shape_src = ShapeCheck<2, SrcExp>::Check(t.src_); | |
Shape<1> shape_val = ShapeCheck<1, ValExp>::Check(t.val_); | |
Shape<1> shape_index = ShapeCheck<1, IndexExp>::Check(t.index_); | |
CHECK((shape_src[0] == shape_index[0]) && (shape_index[0] == shape_val[0])) | |
<< "mat_fill_row_element index length, val length and number of rows in matrix"; | |
return shape_src; | |
} | |
}; | |
template<typename SrcExp, typename ValExp, typename IndexExp, typename DType> | |
struct ExpInfo<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> > { | |
static const int kDim = 2; | |
static const int kDevMask = | |
ExpInfo<SrcExp>::kDevMask & ExpInfo<ValExp>::kDevMask & ExpInfo<IndexExp>::kDevMask; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_FILL_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/fill.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/one_hot.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file one_hot.h | |
* \brief Create one-hot indicator array based on the index. | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_EXTENSION_ONE_HOT_H_ | |
#define MSHADOW_EXTENSION_ONE_HOT_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
* \brief Create a one-hot indicator array. | |
* \tparam IndexExp type of index expression | |
* \tparam DType the type of elements | |
*/ | |
template<typename IndexExp, typename DType> | |
struct OneHotEncodeExp: | |
public Exp<OneHotEncodeExp<IndexExp, DType>, | |
DType, type::kChainer> { | |
/*! \brief index operand */ | |
const IndexExp &index_; | |
/*! \brief number of choices we can have. */ | |
index_t num_choices_; | |
/*! \brief constructor */ | |
OneHotEncodeExp(const IndexExp &index, index_t num_choices) | |
: index_(index), num_choices_(num_choices) {} | |
}; | |
template<typename IndexExp, | |
typename IDType, int e1> | |
inline OneHotEncodeExp<IndexExp, default_real_t> | |
one_hot_encode(const Exp<IndexExp, IDType, e1> &index, index_t num_choices) { | |
TypeCheckPass<ExpInfo<IndexExp>::kDim == 1> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return OneHotEncodeExp<IndexExp, default_real_t>(index.self(), num_choices); | |
} | |
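/*
 * Illustrative usage sketch (added for clarity, not part of the original header):
 * expanding a label vector into a one-hot matrix. Names are assumptions.
 *
 *   Tensor<cpu, 1> label(Shape1(batch));
 *   Tensor<cpu, 2> onehot(Shape2(batch, num_class));
 *   onehot = one_hot_encode(label, num_class);  // onehot[i][label[i]] = 1, rest 0
 */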
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename IndexExp, typename DType> | |
struct Plan<OneHotEncodeExp<IndexExp, DType>, DType> { | |
public: | |
explicit Plan(const OneHotEncodeExp<IndexExp, DType> &e) | |
: index_(MakePlan(e.index_)) { | |
} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
index_t idx = static_cast<index_t>(index_.Eval(0, y)); | |
return static_cast<DType>(x == idx); | |
} | |
private: | |
expr::Plan<IndexExp, DType> index_; | |
}; | |
template<typename IndexExp, typename DType> | |
inline Plan<OneHotEncodeExp<IndexExp, DType>, DType> | |
MakePlan(const OneHotEncodeExp<IndexExp, DType> &exp) { | |
return Plan<OneHotEncodeExp<IndexExp, DType>, DType>(exp); | |
} | |
template<int dim, typename IndexExp, typename DType> | |
struct ShapeCheck<dim, OneHotEncodeExp<IndexExp, DType> > { | |
inline static Shape<dim> | |
Check(const OneHotEncodeExp<IndexExp, DType> &t) { | |
CHECK(dim == 2) | |
<< "OneHotEncodeExp only support 2 dimension output"; | |
Shape<1> shape = ShapeCheck<1, IndexExp>::Check(t.index_); | |
Shape<dim> ret; | |
ret[0] = shape[0]; | |
ret[1] = t.num_choices_; | |
return ret; | |
} | |
}; | |
template<typename IndexExp, typename DType> | |
struct ExpInfo<OneHotEncodeExp<IndexExp, DType> > { | |
static const int kDim = 2; | |
static const int kDevMask = ExpInfo<IndexExp>::kDevMask; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_ONE_HOT_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/one_hot.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/slice.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file slice.h | |
 * \brief support for slicing a certain dimension.
*/ | |
#ifndef MSHADOW_EXTENSION_SLICE_H_ | |
#define MSHADOW_EXTENSION_SLICE_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
 * \brief slice expression, slice a tensor along one dimension
 * \tparam SrcExp source expression
 * \tparam DType the type of elements
 * \tparam srcdim dimension of src
 * \tparam dimsrc_m_slice dimsrc - dimslice
*/ | |
template<typename SrcExp, | |
typename Device, typename DType, | |
int srcdim, int dimsrc_m_slice> | |
struct SliceExp : public TRValue<SliceExp<SrcExp, | |
Device, DType, | |
srcdim, dimsrc_m_slice>, | |
Device, srcdim, DType> { | |
static const int dimslice = srcdim - dimsrc_m_slice; | |
const SrcExp &src_; | |
index_t ch_begin_; | |
index_t ch_old_; | |
Shape<srcdim> shape_; | |
SliceExp(const SrcExp &src, index_t begin, index_t end) | |
: src_(src), ch_begin_(begin) { | |
shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_); | |
ch_old_ = shape_[dimslice]; | |
CHECK(begin < shape_[dimslice] && end <= shape_[dimslice]) | |
<< "The slice went out of range"; | |
shape_[dimslice] = end - begin; | |
} | |
template<typename E, int etype> | |
inline void | |
operator=(const expr::Exp<E, DType, etype> &exp) { | |
this->__assign(exp); | |
} | |
inline void | |
operator=(const DType &exp) { | |
this->__assign(exp); | |
} | |
}; // struct Slice | |
/*! | |
* \brief Slice a Tensor | |
* \param src source tensor | |
* \param begin The beginning slice. | |
* \param end The end slice. | |
* \return sliced tensor | |
* \tparam sdim the dimension to slice on | |
* \tparam SrcExp source expression | |
* \tparam DType the type of elements | |
* \tparam etype type of expression | |
*/ | |
template<int sdim, typename SrcExp, | |
typename Device, typename DType, int srcdim> | |
inline SliceExp<SrcExp, Device, DType, srcdim, srcdim - sdim> | |
slice(const TRValue<SrcExp, Device, srcdim, DType> &src, index_t begin, index_t end) { | |
TypeCheckPass<sdim < srcdim && ExpInfo<SrcExp>::kDim == srcdim> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return SliceExp<SrcExp, Device, DType, srcdim, srcdim - sdim>(src.self(), begin, end); | |
} | |
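/*
 * Illustrative usage sketch (added for clarity, not part of the original header):
 * taking channels [2, 5) of an NCHW batch, assuming sdim = 1 addresses the
 * channel axis. Names and sizes are assumptions.
 *
 *   Tensor<cpu, 4> data(Shape4(n, c, h, w)), part(Shape4(n, 3, h, w));
 *   part = slice<1>(data, 2, 5);
 */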
//------------------------ | |
// engine plugin | |
//------------------------ | |
// runtime shapecheck | |
template<typename SrcExp, | |
typename Device, typename DType, | |
int srcdim, int dimsrc_m_slice> | |
struct ShapeCheck<srcdim, SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> >{ | |
inline static Shape<srcdim> Check(const SliceExp<SrcExp, | |
Device, DType, srcdim, dimsrc_m_slice> &t) { | |
return t.shape_; | |
} | |
}; | |
template<typename SrcExp, | |
typename Device, typename DType, | |
int srcdim, int dimsrc_m_slice> | |
struct StreamInfo<Device, SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> >{ | |
inline static Stream<Device> * | |
Get(const SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> &t) { | |
return StreamInfo<Device, SrcExp>::Get(t.src_); | |
} | |
}; | |
// static typecheck | |
template<typename SrcExp, | |
typename Device, typename DType, | |
int srcdim, int dimsrc_m_slice> | |
struct ExpInfo<SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> >{ | |
static const int kDim = ExpInfo<SrcExp>::kDim; | |
static const int kDevMask = ExpInfo<SrcExp>::kDevMask; | |
}; | |
//---------------------- | |
// Execution plan | |
//--------------------- | |
template<typename SrcExp, | |
typename Device, typename DType, | |
int srcdim, int dimsrc_m_slice> | |
struct Plan<SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice>, DType> { | |
public: | |
static const int dimslice = srcdim - dimsrc_m_slice; | |
explicit Plan(const SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> &e) | |
: src_(MakePlan(e.src_)), | |
height_(e.shape_.ProdShape(dimslice + 1, srcdim - 1)), | |
ch_begin_(e.ch_begin_), ch_old_(e.ch_old_), ch_(e.shape_[dimslice]) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
const index_t y = i % height_; | |
i /= height_; | |
const index_t c = i % ch_ + ch_begin_; | |
const index_t b = i / ch_; | |
const index_t x = j; | |
return src_.Eval((b * ch_old_ + c) * height_ + y, x); | |
} | |
MSHADOW_XINLINE DType &REval(index_t i, index_t j) { | |
const index_t y = i % height_; | |
i /= height_; | |
const index_t c = i % ch_ + ch_begin_; | |
const index_t b = i / ch_; | |
const index_t x = j; | |
return src_.REval((b * ch_old_ + c) * height_ + y, x); | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t height_, ch_begin_, ch_old_, ch_; | |
}; // struct Plan | |
template<typename SrcExp, | |
typename Device, typename DType, | |
int srcdim> | |
struct Plan<SliceExp<SrcExp, Device, DType, srcdim, 1>, DType> { | |
public: | |
explicit Plan(const SliceExp<SrcExp, Device, DType, srcdim, 1> &e) | |
: src_(MakePlan(e.src_)), | |
ch_begin_(e.ch_begin_) {} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
return src_.Eval(y, x + ch_begin_); | |
} | |
MSHADOW_XINLINE DType &REval(index_t y, index_t x) { | |
return src_.REval(y, x + ch_begin_); | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t ch_begin_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_SLICE_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/slice.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/slice_ex.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
 * \file slice_ex.h
 * \brief support for slicing a general region of a tensor.
*/ | |
#ifndef MSHADOW_EXTENSION_SLICE_EX_H_ | |
#define MSHADOW_EXTENSION_SLICE_EX_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
 * \brief slice expression, slice a general region of a tensor
 * \tparam SrcExp source expression
 * \tparam DType the type of elements
 * \tparam srcdim dimension of src
*/ | |
template<typename SrcExp, typename Device, | |
typename DType, int srcdim> | |
struct SliceExExp : public TRValue<SliceExExp<SrcExp, | |
Device, DType, | |
srcdim>, | |
Device, srcdim, DType> { | |
const SrcExp &src_; | |
Shape<srcdim> src_shape_; | |
Shape<srcdim> shape_; | |
const Shape<srcdim> begin_; | |
const Shape<srcdim> end_; | |
SliceExExp(const SrcExp &src, Shape<srcdim> begin, Shape<srcdim> end) | |
: src_(src), begin_(begin), end_(end) { | |
src_shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_); | |
for (int i = 0; i < srcdim; ++i) { | |
shape_[i] = end_[i] - begin_[i]; | |
} | |
} | |
template<typename E, int etype> | |
inline void | |
operator=(const expr::Exp<E, DType, etype> &exp) { | |
this->__assign(exp); | |
} | |
inline void | |
operator=(const DType &exp) { | |
this->__assign(exp); | |
} | |
}; // struct SliceEx | |
/*! | |
* \brief SliceEx a Tensor | |
* \param src source tensor | |
* \param begin The beginning slice. | |
* \param end The end slice. | |
* \return sliced tensor | |
* \tparam sdim the dimension to slice on | |
* \tparam SrcExp source expression | |
* \tparam DType the type of elements | |
* \tparam etype type of expression | |
*/ | |
template<typename SrcExp, typename Device, | |
typename DType, int srcdim> | |
inline SliceExExp<SrcExp, Device, DType, srcdim> | |
slice(const TRValue<SrcExp, Device, srcdim, DType> &src, Shape<srcdim> begin, Shape<srcdim> end) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim == srcdim> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return SliceExExp<SrcExp, Device, DType, srcdim>(src.self(), begin, end); | |
} | |
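/*
 * Illustrative usage sketch (added for clarity, not part of the original header):
 * cutting a rectangular block with explicit begin/end coordinates in every
 * dimension. Names and sizes are assumptions.
 *
 *   Tensor<cpu, 2> mat(Shape2(8, 8)), block(Shape2(4, 4));
 *   block = slice(mat, Shape2(2, 2), Shape2(6, 6));  // rows and columns [2, 6)
 */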
//------------------------ | |
// engine plugin | |
//------------------------ | |
// runtime shapecheck | |
template<typename SrcExp, typename Device, | |
typename DType, int srcdim> | |
struct ShapeCheck<srcdim, SliceExExp<SrcExp, Device, DType, srcdim> >{ | |
inline static Shape<srcdim> Check(const SliceExExp<SrcExp, | |
Device, DType, srcdim> &t) { | |
return t.shape_; | |
} | |
}; | |
template<typename SrcExp, typename Device, | |
typename DType, int srcdim> | |
struct StreamInfo<Device, SliceExExp<SrcExp, Device, DType, srcdim> >{ | |
inline static Stream<Device> * | |
Get(const SliceExExp<SrcExp, Device, DType, srcdim> &t) { | |
return StreamInfo<Device, SrcExp>::Get(t.src_); | |
} | |
}; | |
// static typecheck | |
template<typename SrcExp, typename Device, | |
typename DType, int srcdim> | |
struct ExpInfo<SliceExExp<SrcExp, Device, DType, srcdim> >{ | |
static const int kDim = ExpInfo<SrcExp>::kDim; | |
static const int kDevMask = ExpInfo<SrcExp>::kDevMask; | |
}; | |
//---------------------- | |
// Execution plan | |
//--------------------- | |
template<typename SrcExp, typename Device, | |
typename DType, int srcdim> | |
struct Plan<SliceExExp<SrcExp, Device, DType, srcdim>, DType> { | |
public: | |
explicit Plan(const SliceExExp<SrcExp, Device, DType, srcdim> &e) | |
: src_(MakePlan(e.src_)), begin_(e.begin_), | |
src_shape_(e.src_shape_), shape_(e.shape_) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
index_t idx = 0; | |
index_t stride = 1; | |
#pragma unroll | |
for (int k = srcdim-2; k >= 0; --k) { | |
idx += stride * (i%shape_[k] + begin_[k]); | |
i /= shape_[k]; | |
stride *= src_shape_[k]; | |
} | |
return src_.Eval(idx, j + begin_[srcdim-1]); | |
} | |
MSHADOW_XINLINE DType &REval(index_t i, index_t j) { | |
index_t idx = 0; | |
index_t stride = 1; | |
#pragma unroll | |
for (int k = srcdim-2; k >= 0; --k) { | |
idx += stride * (i%shape_[k] + begin_[k]); | |
i /= shape_[k]; | |
stride *= src_shape_[k]; | |
} | |
return src_.REval(idx, j + begin_[srcdim-1]); | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const Shape<srcdim> begin_, src_shape_, shape_; | |
}; // struct Plan | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_SLICE_EX_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/slice_ex.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/take.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file take.h | |
* \brief | |
* \author Bing Xu | |
*/ | |
#ifndef MSHADOW_EXTENSION_TAKE_H_ | |
#define MSHADOW_EXTENSION_TAKE_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! \brief Take a column from a matrix | |
* \tparam IndexExp type of index expression | |
* \tparam SrcExp type of src expression | |
* \tparam DType data type | |
*/ | |
template<typename IndexExp, typename SrcExp, typename DType> | |
struct TakeExp: public Exp<TakeExp<IndexExp, SrcExp, DType>, | |
DType, type::kChainer> { | |
  /*! \brief index operand */
  const IndexExp &index_;
  /*! \brief embedding operand */
const SrcExp &src_; | |
/*! constructor */ | |
TakeExp(const IndexExp &index, const SrcExp &src) | |
: index_(index), src_(src) {} | |
}; // struct TakeExp | |
template<typename IndexExp, | |
typename SrcExp, | |
typename DType, | |
int e1, int e2> | |
inline TakeExp<IndexExp, SrcExp, DType> | |
take(const Exp<IndexExp, DType, e1> &index, | |
const Exp<SrcExp, DType, e2> &src) { | |
return TakeExp<IndexExp, SrcExp, DType>(index.self(), src.self()); | |
} | |
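/*
 * Illustrative usage sketch (added for clarity, not part of the original header):
 * an embedding lookup that gathers one weight row per index. Names are assumptions.
 *
 *   Tensor<cpu, 2> weight(Shape2(vocab, dim)), out(Shape2(batch, dim));
 *   Tensor<cpu, 1> idx(Shape1(batch));
 *   out = take(idx, weight);  // out[i] = weight[idx[i]]
 */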
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename IndexExp, typename SrcExp, typename DType> | |
struct Plan<TakeExp<IndexExp, SrcExp, DType>, DType> { | |
public: | |
explicit Plan(const TakeExp<IndexExp, SrcExp, DType> &e) | |
: index_(MakePlan(e.index_)), src_(MakePlan(e.src_)) { | |
} | |
// TODO(xx): discuss W shape: in * out or out * in | |
// Now I use in * out | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
index_t idx = static_cast<index_t>(index_.Eval(0, y)); | |
return static_cast<DType>(src_.Eval(idx, x)); | |
} | |
private: | |
expr::Plan<IndexExp, DType> index_; | |
expr::Plan<SrcExp, DType> src_; | |
}; // struct Plan | |
template<typename IndexExp, typename SrcExp, typename DType> | |
inline Plan<TakeExp<IndexExp, SrcExp, DType>, DType> | |
MakePlan(const TakeExp<IndexExp, SrcExp, DType> &exp) { | |
return Plan<TakeExp<IndexExp, SrcExp, DType>, DType>(exp); | |
} | |
template<int dim, typename IndexExp, typename SrcExp, typename DType> | |
struct ShapeCheck<dim, TakeExp<IndexExp, SrcExp, DType> > { | |
inline static Shape<dim> | |
Check(const TakeExp<IndexExp, SrcExp, DType> &t) { | |
CHECK(dim == 2) | |
<< "TakeExp only support 2D output"; | |
Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_); | |
Shape<2> wshape = ShapeCheck<2, SrcExp>::Check(t.src_); | |
Shape<dim> ret; | |
ret[0] = dshape[0]; | |
ret[1] = wshape[1]; | |
return ret; | |
} | |
}; | |
template<typename IndexExp, typename SrcExp, typename DType> | |
struct ExpInfo<TakeExp<IndexExp, SrcExp, DType> > { | |
static const int kDim = 2; | |
static const int kDevMask = ExpInfo<IndexExp>::kDevMask; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_TAKE_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/take.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/extension/take_grad.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file take_grad.h | |
* \brief | |
* \author Bing Xu | |
*/ | |
#ifndef MSHADOW_EXTENSION_TAKE_GRAD_H_ | |
#define MSHADOW_EXTENSION_TAKE_GRAD_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! \brief Calculate embedding gradient | |
* \tparam IndexExp type of index expression | |
* \tparam SrcExp type of src expression | |
* \tparam DType data type | |
*/ | |
template<typename IndexExp, typename SrcExp, typename DType> | |
struct TakeGradExp : public Exp<TakeGradExp<IndexExp, SrcExp, DType>, | |
DType, type::kChainer> { | |
  /*! \brief index operand */
  const IndexExp &index_;
  /*! \brief output gradient operand */
const SrcExp &src_; | |
  /*! \brief input dimension (number of rows of the gradient to produce) */
const index_t input_dim_; | |
/*! \brief constructor */ | |
TakeGradExp(const IndexExp &index, const SrcExp &src, const index_t input_dim) | |
: index_(index), src_(src), input_dim_(input_dim) {} | |
}; // struct TakeGradExp | |
template<typename IndexExp, | |
typename SrcExp, | |
typename DType, | |
int e1, int e2> | |
inline TakeGradExp<IndexExp, SrcExp, DType> | |
take_grad(const Exp<IndexExp, DType, e1> &index, | |
const Exp<SrcExp, DType, e2> &src, | |
const index_t input_dim) { | |
return TakeGradExp<IndexExp, SrcExp, DType>(index.self(), | |
src.self(), | |
input_dim); | |
} | |
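/*
 * Illustrative usage sketch (added for clarity, not part of the original header):
 * accumulating the embedding-weight gradient that matches the take() lookup
 * above. Names are assumptions.
 *
 *   Tensor<cpu, 2> out_grad(Shape2(batch, dim)), w_grad(Shape2(vocab, dim));
 *   Tensor<cpu, 1> idx(Shape1(batch));
 *   w_grad = take_grad(idx, out_grad, vocab);  // scatter-adds rows by index
 */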
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename IndexExp, typename SrcExp, typename DType> | |
struct Plan<TakeGradExp<IndexExp, SrcExp, DType>, DType> { | |
public: | |
explicit Plan(const TakeGradExp<IndexExp, SrcExp, DType> &e) | |
: index_(MakePlan(e.index_)), | |
src_(MakePlan(e.src_)), | |
batch_size_(ShapeCheck<1, IndexExp>::Check(e.index_)[0]) { | |
} | |
// now return shape: in * out | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
DType ret = 0.f; | |
for (index_t i = 0; i < batch_size_; ++i) { | |
index_t idx = static_cast<index_t>(index_.Eval(0, i)); | |
if (idx == y) { | |
ret += static_cast<DType>(src_.Eval(i, x)); | |
} | |
} | |
return ret; | |
} | |
private: | |
expr::Plan<IndexExp, DType> index_; | |
expr::Plan<SrcExp, DType> src_; | |
const index_t batch_size_; | |
}; // struct Plan | |
template<typename IndexExp, typename SrcExp, typename DType> | |
inline Plan<TakeGradExp<IndexExp, SrcExp, DType>, DType> | |
MakePlan(const TakeGradExp<IndexExp, SrcExp, DType> &exp) { | |
return Plan<TakeGradExp<IndexExp, SrcExp, DType>, DType>(exp); | |
} | |
template<int dim, typename IndexExp, typename SrcExp, typename DType> | |
struct ShapeCheck<dim, TakeGradExp<IndexExp, SrcExp, DType> > { | |
inline static Shape<dim> | |
Check(const TakeGradExp<IndexExp, SrcExp, DType> &t) { | |
CHECK(dim == 2) | |
<< "TakeGradExp only support 2D output"; | |
// Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_); | |
Shape<2> gshape = ShapeCheck<2, SrcExp>::Check(t.src_); | |
Shape<dim> ret; | |
ret[0] = t.input_dim_; | |
ret[1] = gshape[1]; | |
return ret; | |
} | |
}; // struct ShapeCheck | |
template<typename IndexExp, typename SrcExp, typename DType> | |
struct ExpInfo<TakeGradExp<IndexExp, SrcExp, DType> > { | |
static const int kDim = 2; | |
static const int kDevMask = ExpInfo<IndexExp>::kDevMask; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_TAKE_GRAD_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/take_grad.h ===== | |
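// ---------------------------------------------------------------------------
// Illustrative usage sketch for take_grad() (comment only, not part of the
// original amalgamation; tensor names and sizes are hypothetical). For every
// batch element i with index idx[i], row idx[i] of the result accumulates
// grad_out[i], which is the weight gradient of an embedding lookup.
//
//   using namespace mshadow;
//   using namespace mshadow::expr;
//   Tensor<cpu, 1, float> idx;       // shape (batch,)      looked-up row ids
//   Tensor<cpu, 2, float> grad_out;  // shape (batch, dim)  gradient w.r.t. output
//   Tensor<cpu, 2, float> grad_w;    // shape (vocab, dim)  gradient w.r.t. weights
//   grad_w = take_grad(idx, grad_out, vocab);
// ---------------------------------------------------------------------------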
//===== EXPANDING: ../mshadow/mshadow/extension/reduce_with_axis.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file reduce_with_axis.h | |
 * \brief reduce a tensor along a given axis
* \author Junyuan Xie | |
*/ | |
#ifndef MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ | |
#define MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! \brief reduce out the dimension of src labeled by axis. | |
* \tparam Reducer type of reducer | |
* \tparam SrcExp type of source expression | |
* \tparam DType data type | |
*/ | |
template<typename Reducer, typename SrcExp, typename DType, int dimsrc, bool mask, int dimdst> | |
struct ReduceWithAxisExp: | |
public MakeTensorExp<ReduceWithAxisExp<Reducer, SrcExp, DType, dimsrc, mask, dimdst>, | |
SrcExp, dimdst, DType> { | |
  /*! \brief source operand */
const SrcExp &src_; | |
/*! \brief size of last destination dimension */ | |
index_t last_dst_dim_; | |
/*! \brief size of trailing dimensions */ | |
index_t trailing_; | |
/*! \brief size of axis dimension */ | |
index_t size_; | |
/*! \brief size of last src dimension */ | |
index_t last_; | |
/*! constructor */ | |
explicit ReduceWithAxisExp(const SrcExp &src, int axis) | |
: src_(src) { | |
bool keepdim = (dimsrc == dimdst); | |
CHECK(dimsrc > axis) << "reduce axis out of bound"; | |
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src_); | |
for (int i = 0; i < axis; ++i) { | |
this->shape_[i] = src_shape[i]; | |
} | |
this->size_ = src_shape[axis]; | |
this->trailing_ = 1; | |
if (!keepdim) { | |
for (int i = axis + 1; i < dimsrc; ++i) { | |
this->trailing_ *= src_shape[i]; | |
this->shape_[i - 1] = src_shape[i]; | |
} | |
} else { | |
this->shape_[axis] = 1; | |
for (index_t i = axis + 1; i < dimsrc; ++i) { | |
this->trailing_ *= src_shape[i]; | |
this->shape_[i] = src_shape[i]; | |
} | |
} | |
this->last_ = src_shape[dimsrc - 1]; | |
this->last_dst_dim_ = this->shape_[dimdst - 1]; | |
} | |
}; // struct ReduceWithAxisExp | |
/*! | |
* \brief reduce out the dimension of src labeled by axis. | |
 * \tparam Reducer type of the reducing operation
 * \tparam mask whether to output the index of the reduced value (argmax/argmin) instead of the value
* \tparam SrcExp source expression | |
* \tparam DType data type | |
* \tparam etype type of the expression | |
*/ | |
template<typename Reducer, bool mask, typename SrcExp, typename DType, int etype> | |
inline ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask, | |
ExpInfo<SrcExp>::kDim - 1> | |
reduce_with_axis(const Exp<SrcExp, DType, etype> &src, int axis) { | |
return ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask, | |
ExpInfo<SrcExp>::kDim- 1>(src.self(), axis); | |
} | |
/*! | |
* \brief reduce out the dimension of src labeled by axis, keepdim turned on. | |
 * \tparam Reducer type of the reducing operation
 * \tparam mask whether to output the index of the reduced value (argmax/argmin) instead of the value
* \tparam SrcExp source expression | |
* \tparam DType data type | |
* \tparam etype type of the expression | |
*/ | |
template<typename Reducer, bool mask, typename SrcExp, typename DType, int etype> | |
inline ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask, | |
ExpInfo<SrcExp>::kDim> | |
reduce_keepdim(const Exp<SrcExp, DType, etype> &src, int axis) { | |
return ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask, | |
ExpInfo<SrcExp>::kDim>(src.self(), axis); | |
} | |
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename Reducer, typename SrcExp, typename DType, int dimsrc, bool mask, int dimdst> | |
struct Plan<ReduceWithAxisExp<Reducer, SrcExp, DType, dimsrc, mask, dimdst>, DType> { | |
public: | |
explicit Plan(const ReduceWithAxisExp<Reducer, SrcExp, DType, dimsrc, mask, dimdst> &e) | |
: src_(MakePlan(e.src_)), last_dst_dim_(e.last_dst_dim_), trailing_(e.trailing_), | |
size_(e.size_), last_(e.last_) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
index_t x = (i*last_dst_dim_ + j)/trailing_; | |
index_t y = (i*last_dst_dim_ + j)%trailing_; | |
if (mask) { | |
index_t idx = 0; | |
DType res; Reducer::SetInitValue(res); | |
for (index_t k = 0; k < size_; ++k) { | |
index_t z = (x*size_+k)*trailing_+y; | |
DType tmp = res; | |
Reducer::Reduce(res, src_.Eval(z/last_, z%last_)); | |
if (tmp != res) { | |
idx = k; | |
} | |
} | |
return static_cast<DType>(static_cast<int>(idx)); | |
} else { | |
DType res; Reducer::SetInitValue(res); | |
for (index_t k = 0; k < size_; ++k) { | |
index_t z = (x*size_+k)*trailing_+y; | |
Reducer::Reduce(res, src_.Eval(z/last_, z%last_)); | |
} | |
return res; | |
} | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t last_dst_dim_, trailing_, size_, last_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/reduce_with_axis.h ===== | |
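// ---------------------------------------------------------------------------
// Illustrative usage sketch for reduce_with_axis()/reduce_keepdim() (comment
// only; tensors and shapes are hypothetical). Reducing a (2, 3, 4) tensor over
// axis 1 with red::maximum yields shape (2, 4); with mask=true the plan
// returns the index of the extremum along that axis instead of the value.
//
//   Tensor<cpu, 3, float> data;      // shape (2, 3, 4)
//   Tensor<cpu, 2, float> maxval;    // shape (2, 4)
//   maxval = reduce_with_axis<red::maximum, false>(data, 1);
//   Tensor<cpu, 3, float> keep;      // shape (2, 1, 4)
//   keep = reduce_keepdim<red::maximum, false>(data, 1);
// ---------------------------------------------------------------------------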
//===== EXPANDING: ../mshadow/mshadow/extension/broadcast_with_axis.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file broadcast_with_axis.h | |
* \brief | |
* \author Junyuan Xie, Xingjian Shi | |
*/ | |
#ifndef MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ | |
#define MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
 * \brief Broadcasting the tensor in the given axis. If keepdim is off, a new broadcasting dimension is inserted after axis; otherwise the existing axis (which must have size 1) is broadcast.
* \tparam SrcExp source expression | |
* \tparam DType data type | |
* \tparam dimsrc source dimension | |
* \tparam dimdst destination dimension | |
*/ | |
template<typename SrcExp, typename DType, int dimsrc, int dimdst> | |
struct BroadcastWithAxisExp: | |
public MakeTensorExp<BroadcastWithAxisExp<SrcExp, DType, dimsrc, dimdst>, | |
SrcExp, dimdst, DType> { | |
  /*! \brief data operand */
const SrcExp &src_; | |
/*! \brief size of the last dimension of dst */ | |
index_t dst_last_; | |
/*! \brief product of the dimensions after the broadcasting axis */ | |
index_t trailing_; | |
/*! \brief new dimension of the broadcasting axis*/ | |
index_t size_; | |
/*! \brief size of the last dimension of src*/ | |
index_t last_; | |
/*! constructor */ | |
BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size) | |
: src_(src), size_(size) { | |
bool keepdim = (dimsrc == dimdst); | |
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src_); | |
this->trailing_ = 1; | |
if (!keepdim) { | |
CHECK(dimsrc > axis && axis >= -1) << "broadcast axis (no keepdim) out of bound, " << | |
"axis must be between -1 and" << dimsrc - 1 << ", given=" << axis << "."; | |
for (int i = 0; i <= axis; ++i) { | |
this->shape_[i] = src_shape[i]; | |
} | |
this->shape_[axis + 1] = size_; | |
for (int i = axis + 1; i < dimsrc; ++i) { | |
this->trailing_ *= src_shape[i]; | |
this->shape_[i + 1] = src_shape[i]; | |
} | |
} else { | |
CHECK(dimdst > axis && axis >= 0) << "broadcast axis (keepdim) out of bound, " << | |
"axis must be between 0 and" << dimdst - 1 << ", given=" << axis << "."; | |
CHECK_EQ(src_shape[axis], 1) << "Size of the dimension of the broadcasting axis must be 1" << | |
" when keepdim is on, src_shape[" << axis << "]=" << src_shape[axis] << "."; | |
for (int i = 0; i <= axis - 1; ++i) { | |
this->shape_[i] = src_shape[i]; | |
} | |
this->shape_[axis] = size_; | |
for (int i = axis + 1; i < dimdst; ++i) { | |
this->trailing_ *= src_shape[i]; | |
this->shape_[i] = src_shape[i]; | |
} | |
} | |
this->last_ = src_shape[dimsrc - 1]; | |
this->dst_last_ = this->shape_[dimdst - 1]; | |
} | |
}; // struct BroadcastWithAxisExp | |
/*! | |
* \brief Broadcasting the tensor after given axis. | |
* \tparam SrcExp source expression | |
* \tparam DType data type | |
* \tparam etype type of the expression | |
*/ | |
template<typename SrcExp, typename DType, int etype> | |
inline BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim, | |
ExpInfo<SrcExp>::kDim + 1> | |
broadcast_with_axis(const Exp<SrcExp, DType, etype> &src, const int axis, const index_t size) { | |
return BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim, | |
ExpInfo<SrcExp>::kDim + 1>(src.self(), axis, size); | |
} | |
/*! | |
* \brief Broadcasting the tensor in the given axis (keepdim turned on) | |
* \tparam SrcExp source expression | |
* \tparam DType data type | |
* \tparam etype type of the expression | |
*/ | |
template<typename SrcExp, typename DType, int etype> | |
inline BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim, | |
ExpInfo<SrcExp>::kDim> | |
broadcast_keepdim(const Exp<SrcExp, DType, etype> &src, const int axis, const index_t size) { | |
return BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim, | |
ExpInfo<SrcExp>::kDim>(src.self(), axis, size); | |
} | |
/*! | |
* \brief Broadcasting the tensor in multiple axes. The dimension of the source tensor | |
in the given axes must be 1. | |
* \tparam SrcExp source expression | |
* \tparam DType data type | |
* \tparam dimsrc source dimension | |
* \tparam axesnum number of broadcasting dimensions | |
*/ | |
template<typename SrcExp, typename DType, int dimsrc> | |
struct BroadcastWithMultiAxesExp : | |
public MakeTensorExp<BroadcastWithMultiAxesExp<SrcExp, DType, dimsrc>, | |
SrcExp, dimsrc, DType> { | |
  /*! \brief data operand */
const SrcExp &src_; | |
/*! \brief size of the last dimension of dst */ | |
index_t dst_last_; | |
/*! \brief number of broadcasting axes*/ | |
index_t axesnum_; | |
  /*! \brief product of the dimensions after the broadcasting axes */
Shape<dimsrc> trailings_; | |
/*! \brief new dimension of the broadcasting axes*/ | |
Shape<dimsrc> sizes_; | |
/*! \brief size of the last dimension of src*/ | |
index_t last_; | |
/*! constructor */ | |
template<typename TShape> | |
BroadcastWithMultiAxesExp(const SrcExp &src, const TShape& axes, const TShape& sizes) | |
: src_(src) { | |
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src_); | |
CHECK(axes.ndim() == sizes.ndim()) << "ndim of axes and sizes must be equal."; | |
this->axesnum_ = axes.ndim(); | |
    CHECK(this->axesnum_ <= dimsrc) << "Number of broadcasting axes must be smaller than "
      "the source ndim, number of axes=" << this->axesnum_ << " dimsrc=" << dimsrc;
for (index_t i = 0; i < this->axesnum_; i++) { | |
CHECK(dimsrc > axes[i]) << "broadcast axis (keepdim) out of bound, " << | |
"all axes must be between 0 and" << dimsrc - 1 << ", given axes[" << i << "] = " << axes[i] | |
<< "."; | |
CHECK_EQ(src_shape[axes[i]], 1) << "Size of the dimension of the broadcasting axis must be 1" | |
<< ", src_shape[" << axes[i] << "]=" << src_shape[axes[i]] << "."; | |
if (i < this->axesnum_ - 1) { | |
CHECK(axes[i] < axes[i + 1]) << "The given axes must be in increasing order."; | |
} | |
} | |
for (index_t i = 0; i < dimsrc; i++) { | |
this->shape_[i] = src_shape[i]; | |
this->sizes_[i] = 1; | |
this->trailings_[i] = 1; | |
} | |
for (index_t i = 0; i < this->axesnum_; i++) { | |
this->shape_[axes[i]] = sizes[i]; | |
this->sizes_[i] = sizes[i]; | |
} | |
for (index_t i = 0; i < this->axesnum_; i++) { | |
this->trailings_[i] = 1; | |
for (index_t j = axes[i] + 1; j < dimsrc; ++j) { | |
this->trailings_[i] *= this->shape_[j]; | |
} | |
} | |
this->last_ = src_shape[dimsrc - 1]; | |
this->dst_last_ = this->shape_[dimsrc - 1]; | |
} | |
}; // struct BroadcastWithMultiAxesExp | |
/*! | |
 * \brief Broadcasting the tensor in multiple axes; the source dimensions in those axes must be 1.
* \param src source | |
* \param axes broadcasting axes | |
* \param sizes sizes of the broadcasting axes | |
* \tparam SrcExp source expression | |
* \tparam DType data type | |
* \tparam etype type of the expression | |
* \tparam TShape the flexible shape type | |
*/ | |
template<typename SrcExp, typename DType, int etype, typename TShape> | |
inline BroadcastWithMultiAxesExp<SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
broadcast_multi_axes(const Exp<SrcExp, DType, etype> &src, | |
const TShape &axes, const TShape &sizes) { | |
return BroadcastWithMultiAxesExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), axes, sizes); | |
} | |
/*! | |
* \brief Broadcasting the tensor to the target shape, | |
dimension of different sizes must be 1 in the original tensor. | |
* \param src source | |
* \param target_shape shape of the target broadcasting tensor | |
* \tparam SrcExp source expression | |
* \tparam DType data type | |
* \tparam etype type of the expression | |
* \tparam TShape the flexible shape type | |
*/ | |
template<typename SrcExp, typename DType, int etype, typename TShape> | |
inline BroadcastWithMultiAxesExp<SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
broadcast_to(const Exp<SrcExp, DType, etype> &src, const TShape &target_shape) { | |
static const int dimsrc = ExpInfo<SrcExp>::kDim; | |
CHECK_EQ(target_shape.ndim(), dimsrc); | |
std::vector<index_t> axes_vec, sizes_vec; | |
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src.self()); | |
for (int i = 0; i < dimsrc; ++i) { | |
if (src_shape[i] != target_shape[i]) { | |
CHECK_EQ(src_shape[i], 1) << "broadcasting axis must have size 1, received shape=" | |
<< src_shape << " target_shape=" << target_shape; | |
axes_vec.push_back(i); | |
sizes_vec.push_back(target_shape[i]); | |
} | |
} | |
TShape axes = TShape(axes_vec.begin(), axes_vec.end()); | |
TShape sizes = TShape(sizes_vec.begin(), sizes_vec.end()); | |
return BroadcastWithMultiAxesExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), axes, sizes); | |
} | |
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename SrcExp, typename DType, int dimsrc, int dimdst> | |
struct Plan<BroadcastWithAxisExp<SrcExp, DType, dimsrc, dimdst>, DType> { | |
public: | |
explicit Plan(const BroadcastWithAxisExp<SrcExp, DType, dimsrc, dimdst> &e) | |
: src_(MakePlan(e.src_)), dst_last_(e.dst_last_), | |
trailing_(e.trailing_), size_(e.size_), last_(e.last_) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
index_t x = (i * dst_last_ + j) / trailing_ / size_; | |
index_t y = (i * dst_last_ + j) % trailing_; | |
index_t z = x * trailing_ + y; | |
return src_.Eval(z / last_, z % last_); | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t dst_last_, trailing_, size_, last_; | |
}; | |
template<typename SrcExp, typename DType, int dimsrc> | |
struct Plan<BroadcastWithMultiAxesExp<SrcExp, DType, dimsrc>, DType> { | |
public: | |
explicit Plan(const BroadcastWithMultiAxesExp<SrcExp, DType, dimsrc> &e) | |
: src_(MakePlan(e.src_)), dst_last_(e.dst_last_), last_(e.last_), axesnum_(e.axesnum_), | |
trailings_(e.trailings_), sizes_(e.sizes_) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
index_t indx = i * dst_last_ + j; | |
for (index_t p = 0; p < dimsrc; ++p) { | |
if (p >= axesnum_) { | |
break; | |
} | |
indx = (indx / trailings_[p] / sizes_[p]) * trailings_[p] + (indx % trailings_[p]); | |
} | |
return src_.Eval(indx / last_, indx % last_); | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t dst_last_, last_, axesnum_; | |
const Shape<dimsrc> trailings_, sizes_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/broadcast_with_axis.h ===== | |
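// ---------------------------------------------------------------------------
// Illustrative usage sketch for the broadcasting helpers above (comment only;
// shapes are hypothetical). broadcast_with_axis() inserts a new axis of the
// given size after `axis`, broadcast_keepdim() expands an existing size-1
// axis, and broadcast_to() expands every size-1 axis needed to reach a
// target shape.
//
//   Tensor<cpu, 2, float> a;             // shape (2, 5)
//   Tensor<cpu, 3, float> b;             // shape (2, 4, 5)
//   b = broadcast_with_axis(a, 0, 4);    // new axis of size 4 after axis 0
//   Tensor<cpu, 2, float> c;             // shape (2, 1)
//   Tensor<cpu, 2, float> d;             // shape (2, 7)
//   d = broadcast_keepdim(c, 1, 7);      // axis 1 grows from 1 to 7
// ---------------------------------------------------------------------------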
//===== EXPANDING: ../mshadow/mshadow/extension/spatial_upsampling_nearest.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
 * \file spatial_upsampling_nearest.h
 * \brief nearest-neighbor spatial upsampling
* \author Bing Xu | |
*/ | |
#ifndef MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_NEAREST_H_ | |
#define MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_NEAREST_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! \brief nearest neighbor upsampling
* out(x, y) = in(int(x / scale_x), int(y / scale_y)) | |
* \tparam SrcExp source expression | |
* \tparam DType data type | |
* \tparam srcdim source dimension | |
*/ | |
template<typename SrcExp, typename DType, int srcdim> | |
struct UpSamplingNearestExp : | |
public MakeTensorExp<UpSamplingNearestExp<SrcExp, DType, srcdim>, | |
SrcExp, srcdim, DType> { | |
  /*! \brief source operand */
  const SrcExp &src_;
  /*! \brief upsampling scale */
index_t scale_; | |
/*! \brief constructor */ | |
UpSamplingNearestExp(const SrcExp &src, index_t scale) | |
: src_(src), scale_(scale) { | |
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_); | |
this->shape_[srcdim - 2] *= scale_; | |
this->shape_[srcdim - 1] *= scale_; | |
} | |
}; | |
template<typename SrcExp, typename DType, int etype> | |
inline UpSamplingNearestExp<SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
upsampling_nearest(const Exp<SrcExp, DType, etype> &src, index_t scale) { | |
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2> | |
::Error_Expression_Does_Not_Meet_Dimension_Req(); | |
return UpSamplingNearestExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), scale); | |
} | |
template<typename SrcExp, typename DType, int srcdim> | |
struct Plan<UpSamplingNearestExp<SrcExp, DType, srcdim>, DType> { | |
public: | |
explicit Plan(const UpSamplingNearestExp<SrcExp, DType, srcdim> &e) | |
: src_(MakePlan(e.src_)), | |
scale_(e.scale_), | |
new_height_(e.shape_[srcdim - 2]), | |
src_height_(static_cast<index_t>(e.shape_[srcdim - 2] / e.scale_)) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
const index_t x = j; | |
const index_t y = i % new_height_; | |
const index_t c = i / new_height_; | |
const index_t h = static_cast<index_t>(y / scale_); | |
const index_t w = static_cast<index_t>(x / scale_); | |
return src_.Eval(c * src_height_ + h, w); | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t scale_; | |
const index_t new_height_; | |
const index_t src_height_; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_NEAREST_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/spatial_upsampling_nearest.h ===== | |
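// ---------------------------------------------------------------------------
// Illustrative usage sketch for upsampling_nearest() (comment only; shapes are
// hypothetical). Each input pixel is replicated scale x scale times, so the
// last two dimensions of a (batch, channel, H, W) tensor grow by `scale`.
//
//   Tensor<cpu, 4, float> in;        // shape (1, 3, 8, 8)
//   Tensor<cpu, 4, float> out;       // shape (1, 3, 16, 16)
//   out = upsampling_nearest(in, 2); // out(y, x) = in(y / 2, x / 2) per plane
// ---------------------------------------------------------------------------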
//===== EXPANDING: ../mshadow/mshadow/extension/transpose.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file transpose.h | |
* \brief support for transpose | |
* \author Junyuan Xie | |
*/ | |
#ifndef MSHADOW_EXTENSION_TRANSPOSE_H_ | |
#define MSHADOW_EXTENSION_TRANSPOSE_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
* \brief transpose axes of a tensor | |
 *  input: Tensor<Device, dimsrc>: ishape
 *  output: Tensor<Device, dimsrc>: oshape, where oshape[i] = ishape[axes[i]]
 *
 * \tparam SrcExp type of source expression
 * \tparam DType the type of elements
 * \tparam dimsrc source (and destination) dimension
*/ | |
template<typename SrcExp, typename DType, int dimsrc> | |
struct TransposeExExp: | |
public MakeTensorExp<TransposeExExp<SrcExp, DType, dimsrc>, | |
SrcExp, dimsrc, DType> { | |
/*! \brief source expression */ | |
const SrcExp &src_; | |
const Shape<dimsrc> axes_; | |
Shape<dimsrc> dst_in_src_stride_; // Holds the corresponding stride of the dst axes in src | |
index_t src_stride_; | |
/*! \brief constructor */ | |
explicit TransposeExExp(const SrcExp &src, Shape<dimsrc> axes) : src_(src), axes_(axes) { | |
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src); | |
src_stride_ = src_shape[dimsrc - 1]; | |
Shape<dimsrc> src_stride; | |
src_stride[dimsrc-1] = 1; | |
for (int i = dimsrc-2; i >= 0; --i) src_stride[i] = src_shape[i+1]*src_stride[i+1]; | |
for (int i = 0; i < dimsrc; ++i) { | |
dst_in_src_stride_[i] = src_stride[axes[i]]; | |
this->shape_[i] = src_shape[axes[i]]; | |
} | |
} | |
}; | |
/*! | |
 * \brief an expression that transposes a tensor according to an axes permutation
 * \param src source tensor expression
 * \param axes the permutation; output axis i takes its data from source axis axes[i]
 * \return a transposed expression
 * \tparam SrcExp source expression
 * \tparam DType the type of elements
 * \tparam etype source expression type
*/ | |
template<typename SrcExp, typename DType, int etype> | |
inline TransposeExExp<SrcExp, DType, ExpInfo<SrcExp>::kDim> | |
transpose(const Exp<SrcExp, DType, etype> &src, Shape<ExpInfo<SrcExp>::kDim> axes) { | |
return TransposeExExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), axes); | |
} | |
template<typename SrcExp, typename DType, int dimsrc> | |
struct Plan<TransposeExExp<SrcExp, DType, dimsrc>, DType> { | |
public: | |
explicit Plan(const TransposeExExp<SrcExp, DType, dimsrc> &e) | |
: src_(MakePlan(e.src_)), | |
src_stride_(e.src_stride_), | |
dst_in_src_stride_(e.dst_in_src_stride_), | |
dst_shape_(e.shape_) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
index_t idx = j * dst_in_src_stride_[dimsrc - 1]; | |
#pragma unroll | |
for (int k = dimsrc-2; k >= 0; --k) { | |
idx += (i % dst_shape_[k]) * dst_in_src_stride_[k]; | |
i /= dst_shape_[k]; | |
} | |
return src_.Eval(idx/src_stride_, idx%src_stride_); | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t src_stride_; | |
const Shape<dimsrc> dst_in_src_stride_, dst_shape_; | |
}; | |
/*! | |
* \brief transform contiguous indices of the source tensor to indices of the transposed tensor. | |
* input: Tensor<Device, k>: ishape | |
* output: Tensor<Device, k>: oshape = ishape | |
* | |
* \tparam SrcExp type of source expression | |
* \tparam DType the type of elements | |
* \tparam dimsrc source dimension | |
* \tparam etype source type | |
*/ | |
template<typename SrcExp, typename DType, int dimsrc, int etype> | |
struct TransposeIndicesExp: | |
public Exp<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType, etype> { | |
/*! \brief source expression */ | |
const SrcExp &src_indices_; // Expression of the source indices | |
  Shape<dimsrc> src_shape_;  // Shape of the source tensor
const Shape<dimsrc> axes_; // The transpose axes | |
Shape<dimsrc> src_in_dst_stride_; // Holds the corresponding stride of the source axes in dst | |
/*! \brief constructor */ | |
explicit TransposeIndicesExp(const SrcExp &src_indices, | |
Shape<dimsrc> src_shape, | |
Shape<dimsrc> axes) : src_indices_(src_indices), | |
src_shape_(src_shape), axes_(axes) { | |
Shape<dimsrc> dst_shape_; | |
Shape<dimsrc> dst_stride_; | |
bool axes_checking_flag[dimsrc] = { 0 }; | |
for (int i = 0; i < dimsrc; ++i) { | |
CHECK_LT(axes[i], dimsrc) | |
<< "Invalid axes input! All elements of axes must be between 0 and " << dimsrc | |
<< ", find axes=" << axes; | |
dst_shape_[i] = src_shape[axes[i]]; | |
axes_checking_flag[axes[i]] = true; | |
} | |
// check if the input axes is valid | |
for (int i = 0; i < dimsrc; ++i) { | |
CHECK_EQ(axes_checking_flag[i], true) | |
<< "Invalid axes input! All elements of axes must be between 0 and " << dimsrc | |
<< ", find axes=" << axes; | |
} | |
dst_stride_[dimsrc - 1] = 1; | |
for (int i = dimsrc - 2; i >= 0; --i) dst_stride_[i] = dst_shape_[i+1] * dst_stride_[i+1]; | |
for (int i = 0; i < dimsrc; ++i) { | |
src_in_dst_stride_[axes[i]] = dst_stride_[i]; | |
} | |
} | |
}; | |
/*! | |
 * \brief an expression that maps contiguous (row-major) indices of a source tensor
 *        to the corresponding indices in its transpose
 * \param src_indices expression producing source indices
 * \param src_shape shape of the source tensor
 * \param axes the transpose permutation
 * \return an expression producing indices into the transposed tensor
 * \tparam SrcExp source expression
 * \tparam DType the type of elements
 * \tparam etype source expression type
*/ | |
template<typename SrcExp, typename DType, int dimsrc, int etype> | |
inline TransposeIndicesExp<SrcExp, DType, dimsrc, etype> | |
transpose_indices(const Exp<SrcExp, DType, etype> &src_indices, | |
Shape<dimsrc> src_shape, | |
Shape<dimsrc> axes) { | |
return TransposeIndicesExp<SrcExp, DType, dimsrc, etype>(src_indices.self(), src_shape, axes); | |
} | |
template<typename SrcExp, typename DType, int dimsrc, int etype> | |
struct Plan<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType> { | |
public: | |
explicit Plan(const TransposeIndicesExp<SrcExp, DType, dimsrc, etype> &e) | |
: src_indices_(MakePlan(e.src_indices_)), | |
src_in_dst_stride_(e.src_in_dst_stride_), | |
src_shape_(e.src_shape_) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
index_t src_idx = static_cast<index_t>(src_indices_.Eval(i, j)); | |
index_t dst_idx = 0; | |
#pragma unroll | |
for (int k = dimsrc - 1; k >= 0; --k) { | |
dst_idx += (src_idx % src_shape_[k]) * src_in_dst_stride_[k]; | |
src_idx /= src_shape_[k]; | |
} | |
return static_cast<DType>(dst_idx); | |
} | |
private: | |
Plan<SrcExp, DType> src_indices_; | |
const Shape<dimsrc> src_in_dst_stride_, src_shape_; | |
}; | |
//---------------------- | |
// Execution plan | |
//---------------------- | |
/*! \brief make expression */ | |
template<typename SrcExp, typename DType, int dimsrc, int etype> | |
inline Plan<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType> | |
MakePlan(const TransposeIndicesExp<SrcExp, DType, dimsrc, etype> &e) { | |
return Plan<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType>(e); | |
} | |
template<int dim, typename SrcExp, typename DType, int dimsrc, int etype> | |
struct ShapeCheck<dim, TransposeIndicesExp<SrcExp, DType, dimsrc, etype> > { | |
inline static Shape<dim> | |
Check(const TransposeIndicesExp<SrcExp, DType, dimsrc, etype> &t) { | |
Shape<dim> s = ShapeCheck<dim, SrcExp>::Check(t.src_indices_); | |
return s; | |
} | |
}; | |
template<typename SrcExp, typename DType, int dimsrc, int etype> | |
struct ExpInfo<TransposeIndicesExp<SrcExp, DType, dimsrc, etype> > { | |
static const int kDim = ExpInfo<SrcExp>::kDim; | |
static const int kDevMask = ExpInfo<SrcExp>::kDevMask; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_TRANSPOSE_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/transpose.h ===== | |
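// ---------------------------------------------------------------------------
// Illustrative usage sketch for transpose() with an axes permutation (comment
// only; shapes are hypothetical). With axes = (1, 0, 2) a (2, 3, 4) tensor
// becomes (3, 2, 4): output axis i takes its size and data from input axis
// axes[i].
//
//   Tensor<cpu, 3, float> src;                 // shape (2, 3, 4)
//   Tensor<cpu, 3, float> dst;                 // shape (3, 2, 4)
//   dst = transpose(src, Shape3(1, 0, 2));
// ---------------------------------------------------------------------------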
//===== EXPANDING: ../mshadow/mshadow/extension/flip.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file flip.h | |
 * \brief support for flipping a tensor along a certain dimension.
* \author Junyuan Xie | |
*/ | |
#ifndef MSHADOW_EXTENSION_FLIP_H_ | |
#define MSHADOW_EXTENSION_FLIP_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
 * \brief flip expression: reverse a tensor along a given dimension
 * \tparam SrcExp source expression
 * \tparam DType the type of elements
 * \tparam srcdim dimension of src
*/ | |
template<typename SrcExp, typename Device, | |
typename DType, int srcdim> | |
struct FlipExp : public TRValue<FlipExp<SrcExp, | |
Device, DType, | |
srcdim>, | |
Device, srcdim, DType> { | |
const SrcExp &src_; | |
index_t trailing_; | |
index_t stride_; | |
index_t stride_j_; | |
Shape<srcdim> shape_; | |
FlipExp(const SrcExp &src, int dim) | |
: src_(src) { | |
shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_); | |
stride_ = shape_[dim]; | |
stride_j_ = shape_[srcdim-1]; | |
trailing_ = 1; | |
for (int i = dim + 1; i < srcdim; ++i) { | |
trailing_ *= shape_[i]; | |
} | |
} | |
template<typename E, int etype> | |
inline void | |
operator=(const expr::Exp<E, DType, etype> &exp) { | |
this->__assign(exp); | |
} | |
inline void | |
operator=(const DType &exp) { | |
this->__assign(exp); | |
} | |
}; // struct Flip | |
/*! | |
 * \brief Flip a Tensor along one dimension
 * \param src source tensor
 * \param dim the dimension to flip
 * \return flipped tensor expression
 * \tparam SrcExp source expression
 * \tparam Device device of the tensor
 * \tparam DType the type of elements
 * \tparam srcdim dimension of src
*/ | |
template<typename SrcExp, typename Device, | |
typename DType, int srcdim> | |
inline FlipExp<SrcExp, Device, DType, srcdim> | |
flip(const TRValue<SrcExp, Device, srcdim, DType> &src, int dim) { | |
return FlipExp<SrcExp, Device, DType, srcdim>(src.self(), dim); | |
} | |
//------------------------ | |
// engine plugin | |
//------------------------ | |
// runtime shapecheck | |
template<typename SrcExp, typename Device, | |
typename DType, int srcdim> | |
struct ShapeCheck<srcdim, FlipExp<SrcExp, Device, DType, srcdim> >{ | |
inline static Shape<srcdim> Check(const FlipExp<SrcExp, | |
Device, DType, srcdim> &t) { | |
return t.shape_; | |
} | |
}; | |
template<typename SrcExp, typename Device, | |
typename DType, int srcdim> | |
struct StreamInfo<Device, FlipExp<SrcExp, Device, DType, srcdim> >{ | |
inline static Stream<Device> * | |
Get(const FlipExp<SrcExp, Device, DType, srcdim> &t) { | |
return StreamInfo<Device, SrcExp>::Get(t.src_); | |
} | |
}; | |
// static typecheck | |
template<typename SrcExp, typename Device, | |
typename DType, int srcdim> | |
struct ExpInfo<FlipExp<SrcExp, Device, DType, srcdim> >{ | |
static const int kDim = ExpInfo<SrcExp>::kDim; | |
static const int kDevMask = ExpInfo<SrcExp>::kDevMask; | |
}; | |
//---------------------- | |
// Execution plan | |
//--------------------- | |
template<typename SrcExp, typename Device, | |
typename DType, int srcdim> | |
struct Plan<FlipExp<SrcExp, Device, DType, srcdim>, DType> { | |
public: | |
explicit Plan(const FlipExp<SrcExp, Device, DType, srcdim> &e) | |
: src_(MakePlan(e.src_)), stride_j_(e.stride_j_), | |
trailing_(e.trailing_), stride_(e.stride_) {} | |
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { | |
index_t idx = i*stride_j_+j; | |
const index_t low = idx%trailing_; | |
index_t high = idx/trailing_; | |
const index_t x = high%stride_; | |
high /= stride_; | |
idx = (high*stride_+stride_-1-x)*trailing_+low; | |
return src_.Eval(idx/stride_j_, idx%stride_j_); | |
} | |
MSHADOW_XINLINE DType &REval(index_t i, index_t j) const { | |
index_t idx = i*stride_j_+j; | |
const index_t low = idx%trailing_; | |
index_t high = idx/trailing_; | |
const index_t x = high%stride_; | |
high /= stride_; | |
idx = (high*stride_+stride_-1-x)*trailing_+low; | |
return src_.REval(idx/stride_j_, idx%stride_j_); | |
} | |
private: | |
Plan<SrcExp, DType> src_; | |
const index_t stride_j_, trailing_, stride_; | |
}; // struct Plan | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_FLIP_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/flip.h ===== | |
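// ---------------------------------------------------------------------------
// Illustrative usage sketch for flip() (comment only; shapes and values are
// hypothetical). flip(src, dim) reverses the order of elements along `dim`;
// flipping axis 1 of a (2, 3) tensor maps column j to column 2 - j.
//
//   Tensor<cpu, 2, float> src;       // shape (2, 3), rows [0 1 2; 3 4 5]
//   Tensor<cpu, 2, float> dst;       // shape (2, 3)
//   dst = flip(src, 1);              // rows become [2 1 0; 5 4 3]
// ---------------------------------------------------------------------------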
//===== EXPANDING: ../mshadow/mshadow/extension/complex.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file complex.h | |
* \brief support for complex operations | |
* \author Xingjian Shi | |
*/ | |
#ifndef MSHADOW_EXTENSION_COMPLEX_H_ | |
#define MSHADOW_EXTENSION_COMPLEX_H_ | |
namespace mshadow { | |
namespace op { | |
namespace complex { | |
enum BinaryCalculationType { kBinaryCC, kBinaryCR, kBinaryRC}; | |
enum UnitaryCalculationType { kUnitaryC2R, kUnitaryC2C }; | |
struct mul { | |
/*! \brief map a_real, a_imag, b_real, b_imag to result using defined operation */ | |
template<typename DType> | |
MSHADOW_XINLINE static DType RealMap(DType a_real, DType a_imag, | |
DType b_real, DType b_imag) { | |
return a_real * b_real - a_imag * b_imag; | |
} | |
template<typename DType> | |
MSHADOW_XINLINE static DType ImagMap(DType a_real, DType a_imag, | |
DType b_real, DType b_imag) { | |
return a_real * b_imag + b_real * a_imag; | |
} | |
}; | |
struct div { | |
/*! \brief map a_real, a_imag, b_real, b_imag to result using defined operation */ | |
template<typename DType> | |
MSHADOW_XINLINE static DType RealMap(DType a_real, DType a_imag, | |
DType b_real, DType b_imag) { | |
return (a_real * b_real + a_imag * b_imag) / (b_real * b_real + b_imag * b_imag); | |
} | |
template<typename DType> | |
MSHADOW_XINLINE static DType ImagMap(DType a_real, DType a_imag, | |
DType b_real, DType b_imag) { | |
return (b_real * a_imag - a_real * b_imag) / (b_real * b_real + b_imag * b_imag); | |
} | |
}; | |
struct conjugate { | |
template<typename TA, typename DType> | |
MSHADOW_XINLINE static DType RealMap(const expr::Plan<TA, DType> &src_, | |
index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { | |
return src_.Eval(real_i, real_j); | |
} | |
template<typename TA, typename DType> | |
MSHADOW_XINLINE static DType ImagMap(const expr::Plan<TA, DType> &src_, | |
index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { | |
return -src_.Eval(imag_i, imag_j); | |
} | |
}; | |
struct exchange { | |
template<typename TA, typename DType> | |
MSHADOW_XINLINE static DType RealMap(const expr::Plan<TA, DType> &src_, | |
index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { | |
return src_.Eval(imag_i, imag_j); | |
} | |
template<typename TA, typename DType> | |
MSHADOW_XINLINE static DType ImagMap(const expr::Plan<TA, DType> &src_, | |
index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { | |
return src_.Eval(real_i, real_j); | |
} | |
}; | |
struct abs_square { | |
template<typename TA, typename DType> | |
MSHADOW_XINLINE static DType RealMap(const expr::Plan<TA, DType> &src_, | |
index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { | |
DType real_val = src_.Eval(real_i, real_j); | |
DType image_val = src_.Eval(imag_i, imag_j); | |
return real_val * real_val + image_val * image_val; | |
} | |
}; | |
struct sum_real_imag { | |
template<typename TA, typename DType> | |
MSHADOW_XINLINE static DType RealMap(const expr::Plan<TA, DType> &src_, | |
index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { | |
DType real_val = src_.Eval(real_i, real_j); | |
DType image_val = src_.Eval(imag_i, imag_j); | |
return real_val + image_val; | |
} | |
}; | |
} // namespace complex | |
} // namespace op | |
namespace expr { | |
//-------------------- | |
// ComplexBinaryMapExp | |
//-------------------- | |
/*! | |
* \brief binary map expression lhs [op] rhs where lhs and rhs are complex tensors | |
* \tparam OP operator | |
* \tparam calctype type of the calculation | |
* \tparam TA type of lhs | |
* \tparam TB type of rhs | |
* \tparam etype expression type, sa namespace::type | |
*/ | |
template<int calctype, typename OP, typename TA, typename TB, typename DType, int etype> | |
struct ComplexBinaryMapExp : public Exp<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype>, | |
DType, etype> { | |
/*! \brief left operand */ | |
const TA &lhs_; | |
/*! \brief right operand */ | |
const TB &rhs_; | |
/*! \brief constructor */ | |
explicit ComplexBinaryMapExp(const TA &lhs, const TB &rhs) | |
:lhs_(lhs), rhs_(rhs) {} | |
}; | |
//------------------- | |
// ComplexConjExp | |
//------------------- | |
/*! | |
 * \brief unary expression op(src) where src is a complex tensor
 * \tparam calctype type of the calculation
 * \tparam OP operator
 * \tparam TA type of src
 * \tparam etype expression type, sa namespace::type
*/ | |
template<int calctype, typename OP, typename TA, typename DType, int etype> | |
struct ComplexUnitaryExp : public Exp<ComplexUnitaryExp<calctype, OP, TA, DType, etype>, | |
DType, etype> { | |
/*! \brief source expression */ | |
const TA &src_; | |
/*! \brief constructor */ | |
explicit ComplexUnitaryExp(const TA &src) : src_(src) {} | |
}; | |
template<int calctype, typename OP, typename TA, typename TB, typename DType, int ta, int tb> | |
inline ComplexBinaryMapExp<calctype, OP, TA, TB, DType, (ta | tb | type::kMapper)> | |
ComplexF(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) { | |
return ComplexBinaryMapExp<calctype, OP, TA, TB, DType, | |
(ta | tb | type::kMapper)>(lhs.self(), rhs.self()); | |
} | |
/*! | |
 * \brief construct a unary complex expression op(src) from a complex source tensor
 * \param src source tensor
 * \tparam e1 type of source expression
*/ | |
template<int calctype, typename OP, typename SrcExp, typename DType, int e1> | |
inline ComplexUnitaryExp<calctype, OP, SrcExp, DType, (e1 | type::kMapper)> | |
ComplexF(const Exp<SrcExp, DType, e1> &src) { | |
return ComplexUnitaryExp<calctype, OP, SrcExp, DType, (e1 | type::kMapper)>(src.self()); | |
} | |
/*! | |
 * \brief complex_mul_cc Complex multiplication of two complex tensors, A * B
*/ | |
template<typename TA, typename TB, typename DType, int ta, int tb> | |
inline ComplexBinaryMapExp<op::complex::kBinaryCC, op::complex::mul, | |
TA, TB, DType, (ta | tb | type::kMapper)> | |
complex_mul_cc(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) { | |
return ComplexF<op::complex::kBinaryCC, op::complex::mul>(lhs, rhs); | |
} | |
/*! | |
 * \brief complex_mul_cr Complex multiplication of a complex tensor A and a real tensor B
*/ | |
template<typename TA, typename TB, typename DType, int ta, int tb> | |
inline ComplexBinaryMapExp<op::complex::kBinaryCR, op::complex::mul, | |
TA, TB, DType, (ta | tb | type::kMapper)> | |
complex_mul_cr(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) { | |
return ComplexF<op::complex::kBinaryCR, op::complex::mul>(lhs, rhs); | |
} | |
/*! | |
 * \brief complex_mul_rc Complex multiplication of a real tensor A and a complex tensor B
*/ | |
template<typename TA, typename TB, typename DType, int ta, int tb> | |
inline ComplexBinaryMapExp<op::complex::kBinaryRC, op::complex::mul, | |
TA, TB, DType, (ta | tb | type::kMapper)> | |
complex_mul_rc(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) { | |
return ComplexF<op::complex::kBinaryRC, op::complex::mul>(lhs, rhs); | |
} | |
/*! | |
 * \brief complex_div_cc Complex division of two complex tensors, A / B
*/ | |
template<typename TA, typename TB, typename DType, int ta, int tb> | |
inline ComplexBinaryMapExp<op::complex::kBinaryCC, op::complex::div, | |
TA, TB, DType, (ta | tb | type::kMapper)> | |
complex_div_cc(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) { | |
return ComplexF<op::complex::kBinaryCC, op::complex::div>(lhs, rhs); | |
} | |
/*! | |
 * \brief complex_div_cr Complex division of a complex tensor A by a real tensor B
*/ | |
template<typename TA, typename TB, typename DType, int ta, int tb> | |
inline ComplexBinaryMapExp<op::complex::kBinaryCR, op::complex::div, | |
TA, TB, DType, (ta | tb | type::kMapper)> | |
complex_div_cr(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) { | |
return ComplexF<op::complex::kBinaryCR, op::complex::div>(lhs, rhs); | |
} | |
/*! | |
 * \brief complex_div_rc Complex division of a real tensor A by a complex tensor B
*/ | |
template<typename TA, typename TB, typename DType, int ta, int tb> | |
inline ComplexBinaryMapExp<op::complex::kBinaryRC, op::complex::div, | |
TA, TB, DType, (ta | tb | type::kMapper)> | |
complex_div_rc(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) { | |
return ComplexF<op::complex::kBinaryRC, op::complex::div>(lhs, rhs); | |
} | |
/*! | |
 * \brief conj Negate the imaginary part of A, where A is a complex tensor
* \param src source tensor | |
* \tparam e1 type of source expression | |
*/ | |
template<typename SrcExp, typename DType, int e1> | |
inline ComplexUnitaryExp<op::complex::kUnitaryC2C, op::complex::conjugate, | |
SrcExp, DType, (e1|type::kMapper)> | |
conj(const Exp<SrcExp, DType, e1> &src) { | |
return ComplexF<op::complex::kUnitaryC2C, op::complex::conjugate>(src); | |
} | |
/*! | |
* \brief complex_exchange Exchange the real and imaginary part of A where A is a complex tensor | |
* \param src source tensor | |
* \tparam e1 type of source expression | |
*/ | |
template<typename SrcExp, typename DType, int e1> | |
inline ComplexUnitaryExp<op::complex::kUnitaryC2C, op::complex::exchange, | |
SrcExp, DType, (e1|type::kMapper)> | |
complex_exchange(const Exp<SrcExp, DType, e1> &src) { | |
return ComplexF<op::complex::kUnitaryC2C, op::complex::exchange>(src); | |
} | |
/*! | |
* \brief complex_abs_square calculate the square of the modulus of A where A is a complex tensor | |
* \param src source tensor | |
* \tparam e1 type of source expression | |
*/ | |
template<typename SrcExp, typename DType, int e1> | |
inline ComplexUnitaryExp<op::complex::kUnitaryC2R, op::complex::abs_square, | |
SrcExp, DType, (e1 | type::kMapper)> | |
complex_abs_square(const Exp<SrcExp, DType, e1> &src) { | |
return ComplexF<op::complex::kUnitaryC2R, op::complex::abs_square>(src); | |
} | |
template<typename SrcExp, typename DType, int e1> | |
inline ComplexUnitaryExp<op::complex::kUnitaryC2R, op::complex::sum_real_imag, | |
SrcExp, DType, (e1 | type::kMapper)> | |
complex_sum_real_imag(const Exp<SrcExp, DType, e1> &src) { | |
return ComplexF<op::complex::kUnitaryC2R, op::complex::sum_real_imag>(src); | |
} | |
template<int dim, int calctype, typename OP, typename TA, typename TB, | |
typename DType, int etype> | |
struct ShapeCheck<dim, ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype> > { | |
inline static Shape<dim> | |
Check(const ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype> &t) { | |
Shape<dim> shape1 = ShapeCheck<dim, TA>::Check(t.lhs_); | |
Shape<dim> shape2 = ShapeCheck<dim, TB>::Check(t.rhs_); | |
if (shape1[0] == 0) return shape2; | |
if (shape2[0] == 0) return shape1; | |
if (calctype == op::complex::kBinaryCC) { | |
CHECK_EQ(shape1, shape2) << "ComplexBinaryMapExp (CC): Shapes of operands are not the same."; | |
CHECK_EQ(shape1[dim - 1] % 2, 0) << | |
"ComplexBinaryMapExp (CC): Shape of the last dimension is not even. " | |
"We must have real part + imaginary part."; | |
return shape1; | |
} else if (calctype == op::complex::kBinaryCR) { | |
for (int i = 0; i < dim - 1; ++i) { | |
CHECK_EQ(shape1.shape_[i], shape2.shape_[i]) << | |
"ComplexBinaryMapExp (CR): Shapes of operands are not the same."; | |
} | |
CHECK_EQ(shape1[dim - 1], shape2[dim - 1] * 2) << | |
"ComplexBinaryMapExp (CR): Shapes of operands do not match."; | |
return shape1; | |
} else if (calctype == op::complex::kBinaryRC) { | |
for (int i = 0; i < dim - 1; ++i) { | |
CHECK_EQ(shape1.shape_[i], shape2.shape_[i]) << | |
"ComplexBinaryMapExp (RC): Shapes of operands are not the same."; | |
} | |
CHECK_EQ(shape2[dim - 1], shape1[dim - 1] * 2) << | |
"ComplexBinaryMapExp (RC): Shapes of operands do not match."; | |
return shape2; | |
} else { | |
LOG(FATAL) << "ComplexBinaryMapExp: Unexpected Calculation Type!"; | |
return shape1; | |
} | |
} | |
}; | |
template<int dim, int calctype, typename OP, typename TA, typename DType, int etype> | |
struct ShapeCheck<dim, ComplexUnitaryExp<calctype, OP, TA, DType, etype> > { | |
inline static Shape<dim> Check(const ComplexUnitaryExp<calctype, OP, TA, DType, etype> &t) { | |
Shape<dim> s = ShapeCheck<dim, TA>::Check(t.src_); | |
CHECK_EQ(s[dim - 1] % 2, 0) << "ComplexUnitaryExp: Shape of the last dimension is not even. " | |
"We must have real + imaginary."; | |
if (calctype == op::complex::kUnitaryC2C) { | |
return s; | |
} else if (calctype == op::complex::kUnitaryC2R) { | |
Shape<dim> s_ret = s; | |
s_ret[dim - 1] /= 2; | |
return s_ret; | |
} else { | |
LOG(FATAL) << "ComplexUnitaryExp: Unexpected Calculation Type!"; | |
return s; | |
} | |
} | |
}; | |
// complex binary expression (cc) | |
template<typename OP, typename TA, typename TB, int etype, typename DType> | |
class Plan<ComplexBinaryMapExp<op::complex::kBinaryCC, OP, TA, TB, DType, etype>, DType> { | |
public: | |
explicit Plan(const Plan<TA, DType> &lhs, const Plan<TB, DType> &rhs) | |
: lhs_(lhs), rhs_(rhs) {} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
const index_t base_x = static_cast<index_t>(x / 2) * 2; | |
if (x % 2 == 0) { | |
return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1), | |
rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1)); | |
} else { | |
return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1), | |
rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1)); | |
} | |
} | |
private: | |
Plan<TA, DType> lhs_; | |
Plan<TB, DType> rhs_; | |
}; | |
// complex binary expression (cr) | |
template<typename OP, typename TA, typename TB, int etype, typename DType> | |
class Plan<ComplexBinaryMapExp<op::complex::kBinaryCR, OP, TA, TB, DType, etype>, DType> { | |
public: | |
explicit Plan(const Plan<TA, DType> &lhs, const Plan<TB, DType> &rhs) | |
: lhs_(lhs), rhs_(rhs) {} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
const index_t base_x = static_cast<index_t>(x / 2) * 2; | |
if (x % 2 == 0) { | |
return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1), | |
rhs_.Eval(y, base_x / 2), static_cast<DType>(0)); | |
} else { | |
return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1), | |
rhs_.Eval(y, base_x / 2), static_cast<DType>(0)); | |
} | |
} | |
private: | |
Plan<TA, DType> lhs_; | |
Plan<TB, DType> rhs_; | |
}; | |
// complex binary expression (rc) | |
template<typename OP, typename TA, typename TB, int etype, typename DType> | |
class Plan<ComplexBinaryMapExp<op::complex::kBinaryRC, OP, TA, TB, DType, etype>, DType> { | |
public: | |
explicit Plan(const Plan<TA, DType> &lhs, const Plan<TB, DType> &rhs) | |
: lhs_(lhs), rhs_(rhs) {} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
const index_t base_x = static_cast<index_t>(x / 2) * 2; | |
if (x % 2 == 0) { | |
return OP::RealMap(lhs_.Eval(y, base_x / 2), static_cast<DType>(0), | |
rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1)); | |
} else { | |
return OP::ImagMap(lhs_.Eval(y, base_x / 2), static_cast<DType>(0), | |
rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1)); | |
} | |
} | |
private: | |
Plan<TA, DType> lhs_; | |
Plan<TB, DType> rhs_; | |
}; | |
// complex unitary expression (c2c) | |
template<typename OP, typename TA, int etype, typename DType> | |
class Plan<ComplexUnitaryExp<op::complex::kUnitaryC2C, OP, TA, DType, etype>, DType> { | |
public: | |
explicit Plan(const Plan<TA, DType> &src) : src_(src) {} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
const index_t base_x = static_cast<index_t>(x / 2) * 2; | |
if (0 == x % 2) { | |
return OP::RealMap(src_, y, base_x, y, base_x + 1); | |
} else { | |
return OP::ImagMap(src_, y, base_x, y, base_x + 1); | |
} | |
} | |
private: | |
Plan<TA, DType> src_; | |
}; | |
// complex unitary expression (c2r) | |
template<typename OP, typename TA, int etype, typename DType> | |
class Plan<ComplexUnitaryExp<op::complex::kUnitaryC2R, OP, TA, DType, etype>, DType> { | |
public: | |
explicit Plan(const Plan<TA, DType> &src) : src_(src) {} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
return OP::RealMap(src_, y, x * 2, y, x * 2 + 1); | |
} | |
private: | |
Plan<TA, DType> src_; | |
}; | |
template<int calctype, typename OP, typename TA, typename TB, typename DType, int etype> | |
inline Plan<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype>, DType> | |
MakePlan(const ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype> &e) { | |
return Plan<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype>, | |
DType>(MakePlan(e.lhs_), MakePlan(e.rhs_)); | |
} | |
template<int calctype, typename OP, typename TA, typename DType, int etype> | |
inline Plan<ComplexUnitaryExp<calctype, OP, TA, DType, etype>, DType> | |
MakePlan(const ComplexUnitaryExp<calctype, OP, TA, DType, etype> &e) { | |
return Plan<ComplexUnitaryExp<calctype, OP, TA, DType, etype>, | |
DType>(MakePlan(e.src_)); | |
} | |
template<int calctype, typename OP, typename TA, typename TB, typename DType, int etype> | |
struct ExpInfo<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype> > { | |
static const int kDimLhs = ExpInfo<TA>::kDim; | |
static const int kDimRhs = ExpInfo<TB>::kDim; | |
static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ? \ | |
(kDimLhs == 0 ? \ | |
kDimRhs : \ | |
((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1; | |
static const int kDevMask = ExpInfo<TA>::kDevMask & ExpInfo<TB>::kDevMask; | |
}; | |
template<int calctype, typename OP, typename TA, typename DType, int etype> | |
struct ExpInfo<ComplexUnitaryExp<calctype, OP, TA, DType, etype> > { | |
static const int kDim = ExpInfo<TA>::kDim; | |
static const int kDevMask = ExpInfo<TA>::kDevMask; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_COMPLEX_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/complex.h ===== | |
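// ---------------------------------------------------------------------------
// Illustrative usage sketch for the complex helpers above (comment only;
// shapes are hypothetical). Complex tensors interleave real and imaginary
// parts along the last dimension, so N complex numbers per row give a last
// dimension of 2*N; complex_mul_cr() multiplies each complex entry of A by
// the matching real entry of B.
//
//   Tensor<cpu, 2, float> a;     // shape (4, 2*N): complex, re/im interleaved
//   Tensor<cpu, 2, float> b;     // shape (4, 2*N): complex
//   Tensor<cpu, 2, float> r;     // shape (4, N)  : real
//   Tensor<cpu, 2, float> out;   // shape (4, 2*N)
//   out = complex_mul_cc(a, b);  // element-wise complex product
//   out = complex_mul_cr(a, r);  // scale complex entries by real ones
//   Tensor<cpu, 2, float> mag;   // shape (4, N)
//   mag = complex_abs_square(a); // |a|^2 per complex entry
// ---------------------------------------------------------------------------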
//===== EXPANDING: ../mshadow/mshadow/extension/range.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file range.h | |
* \brief support generating a range vector | |
* \author Xingjian Shi | |
*/ | |
#ifndef MSHADOW_EXTENSION_RANGE_H_ | |
#define MSHADOW_EXTENSION_RANGE_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! | |
* \brief Generate a range vector similar to python: range(start, stop[, step][, repeat]). | |
If step is positive, the last element is the largest start + i * step less than stop | |
If step is negative, the last element is the smallest start + i * step greater than stop. | |
             All elements are repeated `repeat` times, e.g. range(0, 4, 2, 3) --> 0, 0, 0, 2, 2, 2
* \tparam DType the type of elements | |
*/ | |
template<typename DType> | |
struct RangeExp: | |
public Exp<RangeExp<DType>, DType, type::kMapper> { | |
const float start_; | |
const float stop_; | |
const float step_; | |
const int repeat_; | |
/*! \brief constructor */ | |
RangeExp(float start, float stop, float step, int repeat) | |
: start_(start), stop_(stop), step_(step), repeat_(repeat) {} | |
}; | |
template<typename DType> | |
inline RangeExp<DType> | |
range(float start, float stop, float step = 1, int repeat = 1) { | |
return RangeExp<DType>(start, stop, step, repeat); | |
} | |
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename DType> | |
struct Plan<RangeExp<DType>, DType> { | |
public: | |
explicit Plan(const RangeExp<DType> &e) | |
: start_(e.start_), | |
stop_(e.stop_), | |
step_(e.step_), | |
repeat_(e.repeat_) { | |
} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
return static_cast<DType>(start_ + | |
static_cast<float>((static_cast<int>(x) / repeat_)) * step_); | |
} | |
private: | |
const float start_; | |
const float stop_; | |
const float step_; | |
const int repeat_; | |
}; | |
template<typename DType> | |
inline Plan<RangeExp<DType>, DType> | |
MakePlan(const RangeExp<DType> &exp) { | |
return Plan<RangeExp<DType>, DType>(exp); | |
} | |
template<int dim, typename DType> | |
struct ShapeCheck<dim, RangeExp<DType> > { | |
inline static Shape<dim> | |
Check(const RangeExp<DType> &t) { | |
CHECK(dim == 1) | |
<< "RangeExp only support 1 dimension output, received " << dim; | |
CHECK(t.step_ != 0) | |
<< "RangeExp does not support step=0, received " << t.step_; | |
CHECK(t.repeat_ > 0) | |
<< "RangeExp only supports repeat > 0, received " << t.repeat_; | |
if (t.step_ > 0) { | |
CHECK(t.start_ < t.stop_) << "RangeExp does not support (start, stop, step) = " | |
<< "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")"; | |
return Shape1(t.repeat_ * ceil((t.stop_ - t.start_) / t.step_)); | |
} else { | |
CHECK(t.start_ > t.stop_) << "RangeExp does not support (start, stop, step)= " | |
<< "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")"; | |
return Shape1(t.repeat_ * ceil((t.stop_ - t.start_) / t.step_)); | |
} | |
} | |
}; | |
template<typename DType> | |
struct ExpInfo<RangeExp<DType> > { | |
static const int kDim = 1; | |
static const int kDevMask = 0xffff; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_RANGE_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/range.h ===== | |
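// ---------------------------------------------------------------------------
// Illustrative usage sketch for range() (comment only; values are
// hypothetical). range<DType>(start, stop, step, repeat) produces a 1-D
// vector just like Python's range(), with each value repeated `repeat` times.
//
//   Tensor<cpu, 1, float> r;         // shape (6,)
//   r = range<float>(0, 4, 2, 3);    // 0, 0, 0, 2, 2, 2
//   Tensor<cpu, 1, float> s;         // shape (5,)
//   s = range<float>(5, 0, -1);      // 5, 4, 3, 2, 1
// ---------------------------------------------------------------------------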
//===== EXPANDING: ../mshadow/mshadow/extension/mask.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file mask.h | |
 * \brief mask expression: broadcast a 1-D mask and multiply element-wise
* \author Bing Xu | |
*/ | |
#ifndef MSHADOW_EXTENSION_MASK_H_ | |
#define MSHADOW_EXTENSION_MASK_H_ | |
namespace mshadow { | |
namespace expr { | |
/*! \brief Broadcast a mask and do element-wise multiplication | |
* \tparam IndexExp type of index expression | |
* \tparam SrcExp type of src expression | |
* \tparam DType data type | |
*/ | |
template<typename IndexExp, typename SrcExp, typename DType> | |
struct MaskExp: public Exp<MaskExp<IndexExp, SrcExp, DType>, | |
DType, type::kChainer> { | |
  /*! \brief index operand */
  const IndexExp &index_;
  /*! \brief matrix operand */
const SrcExp &src_; | |
/*! constructor */ | |
MaskExp(const IndexExp &index, const SrcExp &src) | |
: index_(index), src_(src) {} | |
}; // struct MaskExp | |
template<typename IndexExp, | |
typename SrcExp, | |
typename DType, | |
int e1, int e2> | |
inline MaskExp<IndexExp, SrcExp, DType> | |
mask(const Exp<IndexExp, DType, e1> &index, | |
const Exp<SrcExp, DType, e2> &src) { | |
return MaskExp<IndexExp, SrcExp, DType>(index.self(), src.self()); | |
} | |
//---------------------- | |
// Execution plan | |
//---------------------- | |
template<typename IndexExp, typename SrcExp, typename DType> | |
struct Plan<MaskExp<IndexExp, SrcExp, DType>, DType> { | |
public: | |
explicit Plan(const MaskExp<IndexExp, SrcExp, DType> &e) | |
: index_(MakePlan(e.index_)), src_(MakePlan(e.src_)) { | |
} | |
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | |
return static_cast<DType>(src_.Eval(y, x) * index_.Eval(0, y)); | |
} | |
private: | |
expr::Plan<IndexExp, DType> index_; | |
expr::Plan<SrcExp, DType> src_; | |
}; // struct Plan | |
template<typename IndexExp, typename SrcExp, typename DType> | |
inline Plan<MaskExp<IndexExp, SrcExp, DType>, DType> | |
MakePlan(const MaskExp<IndexExp, SrcExp, DType> &exp) { | |
return Plan<MaskExp<IndexExp, SrcExp, DType>, DType>(exp); | |
} | |
template<int dim, typename IndexExp, typename SrcExp, typename DType> | |
struct ShapeCheck<dim, MaskExp<IndexExp, SrcExp, DType> > { | |
inline static Shape<dim> | |
Check(const MaskExp<IndexExp, SrcExp, DType> &t) { | |
CHECK(dim == 2) | |
<< "MaskExp only support 2D output"; | |
Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_); | |
Shape<2> wshape = ShapeCheck<2, SrcExp>::Check(t.src_); | |
    CHECK_EQ(dshape[0], wshape[0]) << "MaskExp requires inputs with the same first dimension";
Shape<dim> ret; | |
ret[0] = wshape[0]; | |
ret[1] = wshape[1]; | |
return ret; | |
} | |
}; | |
template<typename IndexExp, typename SrcExp, typename DType> | |
struct ExpInfo<MaskExp<IndexExp, SrcExp, DType> > { | |
static const int kDim = 2; | |
static const int kDevMask = ExpInfo<IndexExp>::kDevMask; | |
}; | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXTENSION_MASK_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension/mask.h ===== | |
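// ---------------------------------------------------------------------------
// Illustrative usage sketch for mask() (comment only; shapes are
// hypothetical). mask(index, src) multiplies every row i of the 2-D `src` by
// the scalar index[i], typically used to zero out padded rows of a batch.
//
//   Tensor<cpu, 1, float> keep;      // shape (batch,), entries 0 or 1
//   Tensor<cpu, 2, float> data;      // shape (batch, dim)
//   Tensor<cpu, 2, float> out;       // shape (batch, dim)
//   out = mask(keep, data);          // row i is data[i] * keep[i]
// ---------------------------------------------------------------------------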
#endif // MSHADOW_EXTENSION_H_ | |
//===== EXPANDED: ../mshadow/mshadow/extension.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/tensor_cpu-inl.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file tensor_cpu-inl.h | |
* \brief implementation of CPU host code | |
* \author Bing Xu, Tianqi Chen | |
*/ | |
#ifndef MSHADOW_TENSOR_CPU_INL_H_ | |
#define MSHADOW_TENSOR_CPU_INL_H_ | |
namespace mshadow { | |
template<> | |
inline void InitTensorEngine<cpu>(int dev_id) { | |
} | |
template<> | |
inline void ShutdownTensorEngine<cpu>(void) { | |
} | |
template<> | |
inline void SetDevice<cpu>(int devid) { | |
} | |
template<> | |
inline Stream<cpu> *NewStream<cpu>(bool create_blas_handle, | |
bool create_dnn_handle) { | |
return new Stream<cpu>(); | |
} | |
template<> | |
inline void DeleteStream<cpu>(Stream<cpu> *stream) { | |
delete stream; | |
} | |
template<int ndim> | |
inline std::ostream &operator<<(std::ostream &os, const Shape<ndim> &shape) { // NOLINT(*) | |
os << '('; | |
for (int i = 0; i < ndim; ++i) { | |
if (i != 0) os << ','; | |
os << shape[i]; | |
} | |
// python style tuple | |
if (ndim == 1) os << ','; | |
os << ')'; | |
return os; | |
} | |
template<typename xpu> | |
inline void *AllocHost_(size_t size); | |
template<typename xpu> | |
inline void FreeHost_(void * dptr); | |
#ifdef __CUDACC__ | |
template<> | |
inline void *AllocHost_<gpu>(size_t size) { | |
void *dptr; | |
MSHADOW_CUDA_CALL(cudaMallocHost(&dptr, size, cudaHostAllocPortable)); | |
return dptr; | |
} | |
template<> | |
inline void FreeHost_<gpu>(void *dptr) { | |
MSHADOW_CUDA_CALL(cudaFreeHost(dptr)); | |
} | |
#endif | |
template<> | |
inline void *AllocHost_<cpu>(size_t size) { | |
size_t pitch; | |
return packet::AlignedMallocPitch(&pitch, size, 1); | |
} | |
template<> | |
inline void FreeHost_<cpu>(void *dptr) { | |
packet::AlignedFree(dptr); | |
} | |
template<typename xpu, int dim, typename DType> | |
inline void AllocHost(Tensor<cpu, dim, DType> *obj) { | |
obj->stride_ = obj->size(dim - 1); | |
CHECK_EQ(obj->CheckContiguous(), true) << "AllocHost"; | |
void *dptr = AllocHost_<xpu>(obj->MSize() * sizeof(DType)); | |
obj->dptr_ = reinterpret_cast<DType*>(dptr); | |
} | |
template<typename xpu, int dim, typename DType> | |
inline void FreeHost(Tensor<cpu, dim, DType> *obj) { | |
if (obj->dptr_ == NULL) { | |
LOG(FATAL) << "FreeHost:: double free"; | |
} | |
FreeHost_<xpu>(obj->dptr_); | |
obj->dptr_ = NULL; | |
} | |
template<int dim, typename DType> | |
inline void AllocSpace(Tensor<cpu, dim, DType> *obj, bool pad) { | |
size_t pitch; | |
void *dptr; | |
if (pad) { | |
dptr = packet::AlignedMallocPitch | |
(&pitch, obj->size(dim - 1) * sizeof(DType), obj->shape_.FlatTo2D()[0]); | |
obj->stride_ = static_cast<index_t>(pitch / sizeof(DType)); | |
} else { | |
obj->stride_ = obj->size(dim - 1); | |
dptr = packet::AlignedMallocPitch | |
(&pitch, obj->shape_.Size() * sizeof(DType), 1); | |
} | |
obj->dptr_ = reinterpret_cast<DType*>(dptr); | |
} | |
template<typename Device, typename DType, int dim> | |
inline Tensor<Device, dim, DType> | |
NewTensor(const Shape<dim> &shape, DType initv, bool pad, Stream<Device> *stream_) { | |
Tensor<Device, dim, DType> obj(shape); | |
obj.stream_ = stream_; | |
AllocSpace(&obj, pad); | |
MapExp<sv::saveto>(&obj, expr::ScalarExp<DType>(initv)); | |
return obj; | |
} | |
template<int dim, typename DType> | |
inline void FreeSpace(Tensor<cpu, dim, DType> *obj) { | |
packet::AlignedFree(obj->dptr_); | |
obj->dptr_ = NULL; | |
} | |
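// Usage sketch for the allocation helpers above (illustrative only, guarded out of
// compilation; the function name and shape numbers are arbitrary assumptions):
#if 0
inline void ExampleNewTensorAndFree() {
  // allocate a padded 4x8 CPU tensor initialized to zero
  Tensor<cpu, 2, float> t =
      NewTensor<cpu>(Shape2(4, 8), 0.0f, true, static_cast<Stream<cpu>*>(NULL));
  t[1][2] = 3.0f;       // normal element access
  FreeSpace(&t);        // releases the buffer allocated above
}
#endif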
template<int dim, typename DType> | |
inline void Copy(Tensor<cpu, dim, DType> _dst, | |
const Tensor<cpu, dim, DType> &_src, | |
Stream<cpu> *stream) { | |
CHECK_EQ(_dst.shape_, _src.shape_) | |
<< "Copy:shape mismatch:" << _dst.shape_ << " vs " << _src.shape_; | |
if (_dst.CheckContiguous() && _src.CheckContiguous()) { | |
memcpy(_dst.dptr_, _src.dptr_, sizeof(DType) * _dst.shape_.Size()); | |
} else { | |
Tensor<cpu, 2, DType> dst = _dst.FlatTo2D(); | |
Tensor<cpu, 2, DType> src = _src.FlatTo2D(); | |
for (index_t y = 0; y < dst.size(0); ++y) { | |
memcpy(dst[y].dptr_, src[y].dptr_, sizeof(DType) * dst.size(1)); | |
} | |
} | |
} | |
template<typename Saver, typename R, int dim, | |
typename DType, typename E> | |
inline void MapPlan(TRValue<R, cpu, dim, DType> *dst, | |
const expr::Plan<E, DType> &plan) { | |
Shape<2> shape = expr::ShapeCheck<dim, R>::Check(dst->self()).FlatTo2D(); | |
expr::Plan<R, DType> dplan = expr::MakePlan(dst->self()); | |
#if (MSHADOW_USE_CUDA == 0) | |
#pragma omp parallel for | |
#endif | |
  // OpenMP is enabled here only for non-CUDA builds, as the default setting can otherwise throttle the CPU
for (openmp_index_t y = 0; y < shape[0]; ++y) { | |
for (index_t x = 0; x < shape[1]; ++x) { | |
// trust your compiler! -_- they will optimize it | |
Saver::template Save<DType>(dplan.REval(y, x), plan.Eval(y, x)); | |
} | |
} | |
} | |
// code to handle SSE optimization | |
template<bool pass_check, typename Saver, | |
typename R, int dim, | |
typename DType, typename E, int etype> | |
struct MapExpCPUEngine { | |
inline static void Map(TRValue<R, cpu, dim, DType> *dst, | |
const expr::Exp<E, DType, etype> &exp) { | |
MapPlan<Saver>(dst, MakePlan(exp.self())); | |
} | |
}; | |
template<typename SV, int dim, typename DType, typename E, int etype> | |
struct MapExpCPUEngine<true, SV, Tensor<cpu, dim, DType>, | |
dim, DType, E, etype> { | |
inline static void Map(Tensor<cpu, dim, DType> *dst, | |
const expr::Exp<E, DType, etype> &exp) { | |
if (expr::PacketAlignCheck<dim, E, MSHADOW_DEFAULT_PACKET>::Check(exp.self()) && | |
expr::PacketAlignCheck<dim, Tensor<cpu, dim, DType>, MSHADOW_DEFAULT_PACKET>::Check(*dst)) { | |
expr::MapPacketPlan<SV>(dst->self(), | |
expr::MakePacketPlan<MSHADOW_DEFAULT_PACKET>(exp.self())); | |
} else { | |
MapPlan<SV>(dst, MakePlan(exp.self())); | |
} | |
} | |
}; | |
template<typename Saver, typename R, int dim, | |
typename DType, typename E, int etype> | |
inline void MapExp(TRValue<R, cpu, dim, DType> *dst, | |
const expr::Exp<E, DType, etype> &exp) { | |
expr::TypeCheckPass<expr::TypeCheck<cpu, dim, DType, E>::kMapPass> | |
::Error_All_Tensor_in_Exp_Must_Have_Same_Type(); | |
Shape<dim> eshape = expr::ShapeCheck<dim, E>::Check(exp.self()); | |
Shape<dim> dshape = expr::ShapeCheck<dim, R>::Check(dst->self()); | |
CHECK(eshape[0] == 0 || eshape == dshape) | |
<< "Assignment: Shape of Tensors are not consistent with target, " | |
<< "eshape: " << eshape << " dshape:" << dshape; | |
MapExpCPUEngine<expr::PacketCheck<E, MSHADOW_DEFAULT_PACKET>::kPass, | |
Saver, R, dim, DType, E, etype> | |
::Map(dst->ptrself(), exp); | |
} | |
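// A minimal sketch of how MapExp is reached in practice (illustrative only, guarded
// out of compilation; tensor names are assumptions and all tensors are assumed to be
// pre-allocated with identical shapes):
#if 0
inline void ExampleMapExp(Tensor<cpu, 2, float> dst,
                          const Tensor<cpu, 2, float> &lhs,
                          const Tensor<cpu, 2, float> &rhs) {
  // the usual expression-template assignment dispatches to MapExp<sv::saveto>
  dst = lhs + rhs;
  // an explicit call with a different saver accumulates instead of overwriting
  MapExp<sv::plusto>(&dst, rhs * 2.0f);
}
#endif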
template<typename Saver, typename Reducer, | |
typename R, typename DType, typename E, int etype> | |
inline void MapReduceKeepLowest(TRValue<R, cpu, 1, DType> *dst, | |
const expr::Exp<E, DType, etype> &exp, | |
DType scale) { | |
expr::TypeCheckPass<expr::TypeCheck<cpu, 1, DType, E>::kRedPass> | |
::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); | |
Shape<2> eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E> | |
::Check(exp.self()).FlatTo2D(); | |
Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); | |
  CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimension does not match";
CHECK_NE(eshape[0], 0) << "can not reduce over empty tensor"; | |
// execution | |
expr::Plan<R, DType> dplan = MakePlan(dst->self()); | |
expr::Plan<E, DType> splan = MakePlan(exp.self()); | |
#if (MSHADOW_USE_CUDA == 0) | |
#pragma omp parallel for | |
#endif | |
for (openmp_index_t x = 0; x < eshape[1]; ++x) { | |
DType res = splan.Eval(0, x); | |
for (index_t y = 1; y < eshape[0]; ++y) { | |
Reducer::Reduce(res, splan.Eval(y, x)); | |
} | |
Saver::template Save<DType>(dplan.REval(0, x), res * scale); | |
} | |
} | |
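// Sketch of a direct call to the reducer above (illustrative only, guarded out of
// compilation): it reduces an (n, k) matrix along its first dimension into a
// length-k vector. The explicit call is shown for clarity; in practice this code
// path is usually reached through reduction expressions such as sum_rows().
#if 0
inline void ExampleColumnSum(Tensor<cpu, 1, float> col_sum,      // shape (k,)
                             const Tensor<cpu, 2, float> &mat) { // shape (n, k)
  MapReduceKeepLowest<sv::saveto, red::sum>(&col_sum, mat, 1.0f);
}
#endif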
template<typename Saver, typename Reducer, int dimkeep, | |
typename R, typename DType, typename E, int etype> | |
inline void MapReduceKeepHighDim(TRValue<R, cpu, 1, DType> *dst, | |
const expr::Exp<E, DType, etype> &exp, | |
DType scale) { | |
expr::TypeCheckPass<expr::TypeCheck<cpu, dimkeep, DType, E>::kRedPass> | |
::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); | |
typedef Shape<expr::ExpInfo<E>::kDim> EShape; | |
EShape eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E> | |
::Check(exp.self()); | |
Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); | |
CHECK_EQ(eshape[dimkeep], dshape[0]) | |
<< "MapReduceKeepHighDim::reduction dimension do not match"; | |
  // use equivalent form
Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep), | |
eshape[dimkeep], | |
eshape.ProdShape(dimkeep + 1, EShape::kSubdim), | |
eshape[EShape::kSubdim]); | |
// execution | |
expr::Plan<R, DType> dplan = MakePlan(dst->self()); | |
expr::Plan<E, DType> splan = MakePlan(exp.self()); | |
#if (MSHADOW_USE_CUDA == 0) | |
#pragma omp parallel for | |
#endif | |
for (openmp_index_t c = 0; c < pshape[1]; ++c) { | |
DType res; Reducer::SetInitValue(res); | |
for (index_t n = 0; n < pshape[0]; ++n) { | |
DType tres; Reducer::SetInitValue(tres); | |
for (index_t y = 0; y < pshape[2]; ++y) { | |
for (index_t x = 0; x < pshape[3]; ++x) { | |
Reducer::Reduce(tres, | |
splan.Eval((n * pshape[1] + c) * pshape[2] + y, x)); | |
} | |
} | |
Reducer::Reduce(res, tres); | |
} | |
Saver::template Save<DType>(dplan.REval(0, c), DType(res * scale)); | |
} | |
} | |
template<typename DType> | |
inline void Softmax(Tensor<cpu, 1, DType> dst, | |
const Tensor<cpu, 1, DType> &energy) { | |
DType mmax = energy[0]; | |
for (index_t x = 1; x < dst.size(0); ++x) { | |
if (mmax < energy[x]) mmax = energy[x]; | |
} | |
DType sum = DType(0.0f); | |
for (index_t x = 0; x < dst.size(0); ++x) { | |
dst[x] = std::exp(energy[x] - mmax); | |
sum += dst[x]; | |
} | |
for (index_t x = 0; x < dst.size(0); ++x) { | |
dst[x] /= sum; | |
} | |
} | |
template<typename DType> | |
inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst, | |
const Tensor<cpu, 2, DType> &src, | |
const Tensor<cpu, 1, DType> &label) { | |
#pragma omp parallel for | |
for (openmp_index_t y = 0; y < dst.size(0); ++y) { | |
const index_t k = static_cast<int>(label[y]); | |
for (index_t x = 0; x < dst.size(1); ++x) { | |
if (x == k) { | |
dst[y][k] = src[y][k] - 1.0f; | |
} else { | |
dst[y][x] = src[y][x]; | |
} | |
} | |
} | |
} | |
template<typename DType> | |
inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst, | |
const Tensor<cpu, 2, DType> &src, | |
const Tensor<cpu, 1, DType> &label, | |
const DType &ignore_label) { | |
#pragma omp parallel for | |
for (openmp_index_t y = 0; y < dst.size(0); ++y) { | |
const index_t k = static_cast<int>(label[y]); | |
for (index_t x = 0; x < dst.size(1); ++x) { | |
if (static_cast<int>(ignore_label) == k) { | |
dst[y][x] = 0.0f; | |
} else { | |
if (x == k) { | |
dst[y][k] = src[y][k] - 1.0f; | |
} else { | |
dst[y][x] = src[y][x]; | |
} | |
} | |
} | |
} | |
} | |
template<typename DType> | |
inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst, | |
const Tensor<cpu, 3, DType> &src, | |
const Tensor<cpu, 2, DType> &label) { | |
#pragma omp parallel for | |
for (openmp_index_t n = 0; n < dst.size(2); ++n) { | |
for (index_t y = 0; y < dst.size(0); ++y) { | |
const index_t k = static_cast<int>(label[y][n]); | |
for (index_t x = 0; x < dst.size(1); ++x) { | |
if (x == k) { | |
dst[y][k][n] = src[y][k][n] - 1.0f; | |
} else { | |
dst[y][x][n] = src[y][x][n]; | |
} | |
} | |
} | |
} | |
} | |
template<typename DType> | |
inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst, | |
const Tensor<cpu, 3, DType> &src, | |
const Tensor<cpu, 2, DType> &label, | |
const DType &ignore_label) { | |
#pragma omp parallel for | |
for (openmp_index_t n = 0; n < dst.size(2); ++n) { | |
for (index_t y = 0; y < dst.size(0); ++y) { | |
const index_t k = static_cast<int>(label[y][n]); | |
if (k == static_cast<int>(ignore_label)) { | |
for (index_t x = 0; x < dst.size(1); ++x) { | |
dst[y][x][n] = DType(0.0f); | |
} | |
} else { | |
for (index_t x = 0; x < dst.size(1); ++x) { | |
if (x == k) { | |
dst[y][k][n] = src[y][k][n] - 1.0f; | |
} else { | |
dst[y][x][n] = src[y][x][n]; | |
} | |
} | |
} | |
} | |
} | |
} | |
template<typename DType> | |
inline void Softmax(Tensor<cpu, 2, DType> dst, | |
const Tensor<cpu, 2, DType> &energy) { | |
CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch"; | |
#pragma omp parallel for | |
for (openmp_index_t y = 0; y < dst.size(0); ++y) { | |
Softmax(dst[y], energy[y]); | |
} | |
} | |
template<typename DType> | |
inline void Softmax(Tensor<cpu, 3, DType> dst, | |
const Tensor<cpu, 3, DType> &energy) { | |
CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch"; | |
#pragma omp parallel for | |
for (openmp_index_t y = 0; y < dst.size(0); ++y) { | |
for (index_t n = 0; n < dst.size(2); ++n) { | |
DType mmax = energy[y][0][n]; | |
for (index_t x = 1; x < dst.size(1); ++x) { | |
if (mmax < energy[y][x][n]) mmax = energy[y][x][n]; | |
} | |
DType sum = DType(0.0f); | |
for (index_t x = 0; x < dst.size(1); ++x) { | |
dst[y][x][n] = std::exp(energy[y][x][n] - mmax); | |
sum += dst[y][x][n]; | |
} | |
for (index_t x = 0; x < dst.size(1); ++x) { | |
dst[y][x][n] /= sum; | |
} | |
} | |
} | |
} | |
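// Minimal usage sketch for the CPU softmax above (illustrative only, guarded out of
// compilation; dst and energy are assumed to be pre-allocated with the same shape):
#if 0
inline void ExampleSoftmax(Tensor<cpu, 2, float> dst,             // (batch, n_class)
                           const Tensor<cpu, 2, float> &energy) {
  Softmax(dst, energy);  // each row of dst is a probability distribution afterwards
}
#endif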
template<typename IndexType, typename DType> | |
inline void AddTakeGrad(Tensor<cpu, 2, DType> dst, | |
const Tensor<cpu, 1, IndexType>& index, | |
const Tensor<cpu, 2, DType> &src) { | |
for (index_t y = 0; y < index.size(0); ++y) { | |
dst[index[y]] += src[y]; | |
} | |
} | |
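// Sketch of the typical use of AddTakeGrad: accumulating embedding gradients, where
// each row of src (the output gradient) is added into the dst row selected by index.
// Illustrative only and guarded out of compilation; names and shapes are assumptions.
#if 0
inline void ExampleEmbeddingBackward(Tensor<cpu, 2, float> grad_weight,          // (vocab, dim)
                                     const Tensor<cpu, 1, int> &word_ids,        // (batch,)
                                     const Tensor<cpu, 2, float> &grad_output) { // (batch, dim)
  grad_weight = 0.0f;                              // clear the accumulator first
  AddTakeGrad(grad_weight, word_ids, grad_output); // grad_weight[word_ids[i]] += grad_output[i]
}
#endif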
template<typename IndexType, typename DType> | |
inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst, | |
const Tensor<cpu, 1, IndexType>& sorted, | |
const Tensor<cpu, 1, IndexType>& index, | |
const Tensor<cpu, 2, DType> &src) { | |
for (index_t y = 0; y < sorted.size(0); ++y) { | |
dst[sorted[y]] += src[index[y]]; | |
} | |
} | |
template<typename IndexType, typename DType> | |
inline void IndexFill(Tensor<cpu, 2, DType> dst, | |
const Tensor<cpu, 1, IndexType>& index, | |
const Tensor<cpu, 2, DType> &src) { | |
for (index_t y = 0; y < index.size(0); ++y) { | |
for (index_t j = 0; j < src.size(1); j++) { | |
dst[index[y]][j] = src[y][j]; | |
} | |
} | |
} | |
template<typename KDType, typename VDType> | |
inline void SortByKey(Tensor<cpu, 1, KDType> keys, Tensor<cpu, 1, VDType> values, | |
bool is_ascend) { | |
CHECK_EQ(keys.CheckContiguous(), true); | |
CHECK_EQ(values.CheckContiguous(), true); | |
CHECK_EQ(keys.size(0), values.size(0)) | |
<< "The sizes of key/value are not equal! keys_size: " << keys.size(0) | |
<< "values_size: " << values.size(0); | |
std::vector<size_t> idx(keys.size(0)); | |
std::vector<KDType> keys_vec(keys.size(0)); | |
std::vector<VDType> values_vec(values.size(0)); | |
  for (index_t i = 0; i < keys.size(0); i++) {
idx[i] = i; | |
keys_vec[i] = keys[i]; | |
values_vec[i] = values[i]; | |
} | |
if (is_ascend) { | |
std::stable_sort(idx.begin(), idx.end(), | |
[&keys_vec](size_t i1, size_t i2) | |
{return keys_vec[i1] < keys_vec[i2]; }); | |
} else { | |
std::stable_sort(idx.begin(), idx.end(), | |
[&keys_vec](size_t i1, size_t i2) | |
{return keys_vec[i1] > keys_vec[i2]; }); | |
} | |
for (index_t i = 0; i < values.size(0); i++) { | |
keys[i] = keys_vec[idx[i]]; | |
values[i] = values_vec[idx[i]]; | |
} | |
} | |
template<typename Device, typename VDType, typename SDType> | |
inline void VectorizedSort(Tensor<Device, 1, VDType> values, Tensor<Device, 1, SDType> segments) { | |
// We can sort each segments using two stable sorts | |
SortByKey(values, segments, true); | |
SortByKey(segments, values, true); | |
} | |
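// Usage sketch for the sorting helpers above (illustrative only, guarded out of
// compilation). SortByKey reorders both 1-D tensors by the keys; VectorizedSort uses
// two stable sorts so values end up grouped by segment id and sorted within segments.
#if 0
inline void ExampleSorting(Tensor<cpu, 1, float> keys,
                           Tensor<cpu, 1, int> values,
                           Tensor<cpu, 1, float> seg_values,
                           Tensor<cpu, 1, int> segments) {
  SortByKey(keys, values, /*is_ascend=*/true);
  VectorizedSort(seg_values, segments);
}
#endif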
// blas related | |
template<typename Device, typename DType> | |
inline void VectorDot(Tensor<Device, 1, DType> dst, | |
const Tensor<Device, 1, DType> &lhs, | |
const Tensor<Device, 1, DType> &rhs) { | |
CHECK_EQ(lhs.size(0), rhs.size(0)) | |
<< "VectorDot: Shape mismatch"; | |
CHECK_EQ(dst.size(0), 1) | |
<< "VectorDot: expect dst to be scalar"; | |
expr::BLASEngine<Device, DType>::SetStream(lhs.stream_); | |
mshadow::expr::BLASEngine<Device, DType>::dot( | |
lhs.stream_, lhs.size(0), lhs.dptr_, 1, rhs.dptr_, 1, dst.dptr_); | |
} | |
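// Usage sketch for VectorDot (illustrative only, guarded out of compilation):
// dst must be a 1-element tensor, and lhs and rhs must have the same length.
#if 0
inline void ExampleVectorDot(Tensor<cpu, 1, float> dst,        // shape (1,)
                             const Tensor<cpu, 1, float> &lhs,
                             const Tensor<cpu, 1, float> &rhs) {
  VectorDot(dst, lhs, rhs);  // dst[0] = sum_i lhs[i] * rhs[i], via the BLAS dot routine
}
#endif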
template<bool transpose_left, bool transpose_right, typename Device, typename DType> | |
inline void BatchGEMM(Tensor<Device, 3, DType> dst, | |
const Tensor<Device, 3, DType> &lhs, | |
const Tensor<Device, 3, DType> &rhs, | |
DType alpha, | |
DType beta, | |
Tensor<Device, 1, DType*> workspace) { | |
index_t batch_size = dst.shape_[0]; | |
expr::BLASEngine<Device, DType>::SetStream(dst.stream_); | |
Shape<3> sleft = transpose_left ? Shape3(lhs.shape_[0], lhs.shape_[2], lhs.shape_[1]) | |
: lhs.shape_; | |
Shape<3> sright = transpose_right ? Shape3(rhs.shape_[0], rhs.shape_[2], rhs.shape_[1]) | |
: rhs.shape_; | |
CHECK_EQ(dst.CheckContiguous(), true); | |
CHECK_EQ(lhs.CheckContiguous(), true); | |
CHECK_EQ(rhs.CheckContiguous(), true); | |
CHECK(sleft[0] == batch_size && sright[0] == batch_size) | |
<< "BatchGEMM: batchsize must be equal." | |
<< "dst: " << dst.shape_ << "\n" | |
<< "lhs: " << sleft << "\n" | |
<< "rhs: " << sright << "\n"; | |
CHECK(dst.size(1) == sleft[1] && dst.size(2) == sright[2] && sleft[2] == sright[1]) | |
<< "BatchGEMM: matrix shape mismatch" | |
<< "dst: " << dst.shape_ << "\n" | |
<< "lhs: " << sleft << "\n" | |
<< "rhs: " << sright << "\n"; | |
  CHECK(workspace.size(0) >= 3 * batch_size)
      << "Workspace size must be at least " << 3 * batch_size;
CHECK_EQ(workspace.CheckContiguous(), true); | |
  // use column-major arguments to be compatible with most BLAS implementations
expr::BLASEngine<Device, DType>::batched_gemm | |
(dst.stream_, | |
transpose_right, transpose_left, | |
transpose_right ? rhs.size(1) : rhs.size(2), | |
transpose_left ? lhs.size(2) : lhs.size(1), | |
transpose_right ? rhs.size(2) : rhs.size(1), | |
alpha, | |
rhs.dptr_, rhs.stride_, | |
lhs.dptr_, lhs.stride_, | |
beta, | |
dst.dptr_, dst.stride_, batch_size, | |
workspace.dptr_); | |
} | |
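// Usage sketch for BatchGEMM (illustrative only, guarded out of compilation):
// computes C_i = alpha * A_i * B_i + beta * C_i for every i in the batch.
// The workspace tensor must provide at least 3 * batch pointers.
#if 0
inline void ExampleBatchGEMM(Tensor<cpu, 3, float> C,            // (batch, m, n)
                             const Tensor<cpu, 3, float> &A,     // (batch, m, k)
                             const Tensor<cpu, 3, float> &B,     // (batch, k, n)
                             Tensor<cpu, 1, float*> workspace) { // (3 * batch,)
  BatchGEMM<false, false>(C, A, B, 1.0f, 0.0f, workspace);
}
#endif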
} // namespace mshadow | |
#endif // MSHADOW_TENSOR_CPU_INL_H_ | |
//===== EXPANDED: ../mshadow/mshadow/tensor_cpu-inl.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/tensor_gpu-inl.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file tensor_gpu-inl.h | |
* \brief implementation of GPU host code | |
* \author Bing Xu, Tianqi Chen | |
*/ | |
#ifndef MSHADOW_TENSOR_GPU_INL_H_ | |
#define MSHADOW_TENSOR_GPU_INL_H_ | |
namespace mshadow { | |
#if MSHADOW_USE_CUDA | |
template<> | |
inline void InitTensorEngine<gpu>(int dev_id) { | |
cudaDeviceProp prop; | |
int device_id = 0; | |
int device_count = 0; | |
cudaGetDeviceCount(&device_count); | |
CHECK_GT(device_count, 0) << "Cannot find CUDA device. Please check CUDA-Configuration"; | |
if (dev_id < 0) { | |
device_id = 0; | |
} else { | |
device_id = dev_id; | |
} | |
CHECK_LT(device_id, device_count) << "Incorrect Device ID"; | |
MSHADOW_CUDA_CALL(cudaSetDevice(device_id)); | |
MSHADOW_CUDA_CALL(cudaGetDeviceProperties(&prop, device_id)); | |
} | |
template<> | |
inline void ShutdownTensorEngine<gpu>(void) { | |
} | |
template<> | |
inline void SetDevice<gpu>(int devid) { | |
MSHADOW_CUDA_CALL(cudaSetDevice(devid)); | |
} | |
template<int dim, typename DType> | |
inline void AllocSpace(Tensor<gpu, dim, DType> *obj, bool pad) { | |
size_t pitch; | |
// common choice for cuda mem align unit is 32 | |
if (pad && obj->size(dim - 1) >= MSHADOW_MIN_PAD_RATIO * 32) { | |
MSHADOW_CUDA_CALL(cudaMallocPitch(reinterpret_cast<void**>(&(obj->dptr_)), &pitch, | |
obj->size(dim - 1) * sizeof(DType), | |
obj->shape_.FlatTo2D()[0])); | |
obj->stride_ = static_cast<index_t>(pitch / sizeof(DType)); | |
} else { | |
obj->stride_ = obj->size(dim - 1); | |
MSHADOW_CUDA_CALL(cudaMallocPitch(reinterpret_cast<void**>(&(obj->dptr_)), &pitch, | |
obj->shape_.Size() * sizeof(DType), 1)); | |
} | |
} | |
template<int dim, typename DType> | |
inline void FreeSpace(Tensor<gpu, dim, DType> *obj) { | |
MSHADOW_CUDA_CALL(cudaFree(obj->dptr_)); | |
obj->dptr_ = NULL; | |
} | |
template<typename A, typename B, int dim, typename DType> | |
inline void Copy(Tensor<A, dim, DType> _dst, | |
Tensor<B, dim, DType> _src, | |
cudaMemcpyKind kind, | |
Stream<gpu> *stream) { | |
CHECK_EQ(_dst.shape_, _src.shape_) << "Copy:shape mismatch"; | |
Tensor<A, 2, DType> dst = _dst.FlatTo2D(); | |
Tensor<B, 2, DType> src = _src.FlatTo2D(); | |
MSHADOW_CUDA_CALL(cudaMemcpy2DAsync(dst.dptr_, dst.stride_ * sizeof(DType), | |
src.dptr_, src.stride_ * sizeof(DType), | |
dst.size(1) * sizeof(DType), | |
dst.size(0), kind, | |
Stream<gpu>::GetStream(stream))); | |
  // use synchronous call behavior for the default (NULL) stream
if (stream == NULL) { | |
MSHADOW_CUDA_CALL(cudaStreamSynchronize(0)); | |
} | |
} | |
template<int dim, typename DType> | |
inline void Copy(Tensor<cpu, dim, DType> dst, | |
const Tensor<gpu, dim, DType> &src, | |
Stream<gpu> *stream) { | |
Copy(dst, src, cudaMemcpyDeviceToHost, stream); | |
} | |
template<int dim, typename DType> | |
inline void Copy(Tensor<gpu, dim, DType> dst, | |
const Tensor<gpu, dim, DType> &src, | |
Stream<gpu> *stream) { | |
Copy(dst, src, cudaMemcpyDeviceToDevice, stream); | |
} | |
template<int dim, typename DType> | |
inline void Copy(Tensor<gpu, dim, DType> dst, | |
const Tensor<cpu, dim, DType> &src, | |
Stream<gpu> *stream) { | |
Copy(dst, src, cudaMemcpyHostToDevice, stream); | |
} | |
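// Usage sketch for host <-> device copies (illustrative only, guarded out of
// compilation): shapes must match, and passing a NULL stream makes the copy synchronous.
#if 0
inline void ExampleHostDeviceCopy(Tensor<cpu, 2, float> host,
                                  Tensor<gpu, 2, float> device) {
  Copy(device, host, static_cast<Stream<gpu>*>(NULL));  // host -> device
  Copy(host, device, static_cast<Stream<gpu>*>(NULL));  // device -> host
}
#endif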
#endif // MSHADOW_USE_CUDA | |
} // namespace mshadow | |
// the following part is included only if compiler is nvcc | |
#ifdef __CUDACC__ | |
namespace mshadow { | |
template<typename Saver, typename R, int dim, | |
typename DType, typename E, int etype> | |
inline void MapExp(TRValue<R, gpu, dim, DType> *dst, | |
const expr::Exp<E, DType, etype> &exp) { | |
expr::TypeCheckPass<expr::TypeCheck<gpu, dim, DType, E>::kMapPass> | |
::Error_All_Tensor_in_Exp_Must_Have_Same_Type(); | |
Shape<dim> eshape = expr::ShapeCheck<dim, E>::Check(exp.self()); | |
Shape<dim> dshape = expr::ShapeCheck<dim, R>::Check(dst->self()); | |
CHECK(eshape[0] == 0 || eshape == dshape) | |
<< "Assignment: Shape of Tensors are not consistent with target, " | |
<< "eshape: " << eshape << " dshape:" << dshape; | |
cuda::MapPlan<Saver>(MakePlan(dst->self()), | |
MakePlan(exp.self()), | |
dshape.FlatTo2D(), | |
Stream<gpu>::GetStream(expr::StreamInfo<gpu, R>::Get(dst->self()))); | |
} | |
template<typename Saver, typename Reducer, | |
typename R, typename DType, typename E, int etype> | |
inline void MapReduceKeepLowest(TRValue<R, gpu, 1, DType> *dst, | |
const expr::Exp<E, DType, etype> &exp, | |
DType scale) { | |
expr::TypeCheckPass<expr::TypeCheck<gpu, 1, DType, E>::kRedPass> | |
::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); | |
Shape<2> eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E> | |
::Check(exp.self()).FlatTo2D(); | |
Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); | |
  CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimension does not match";
CHECK_NE(eshape[0], 0) << "can not reduce over empty tensor"; | |
cuda::MapReduceKeepLowest<Saver, Reducer> | |
(MakePlan(dst->self()), MakePlan(exp.self()), scale, eshape, | |
Stream<gpu>::GetStream(expr::StreamInfo<gpu, R>::Get(dst->self()))); | |
} | |
template<typename Saver, typename Reducer, int dimkeep, | |
typename R, typename DType, typename E, int etype> | |
inline void MapReduceKeepHighDim(TRValue<R, gpu, 1, DType> *dst, | |
const expr::Exp<E, DType, etype> &exp, | |
DType scale) { | |
expr::TypeCheckPass<expr::TypeCheck<gpu, dimkeep, DType, E>::kRedPass> | |
::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); | |
typedef Shape<expr::ExpInfo<E>::kDim> EShape; | |
EShape eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E> | |
::Check(exp.self()); | |
Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); | |
  CHECK_EQ(eshape[dimkeep], dshape[0]) << "MapReduceKeepHighDim::reduction dimension does not match";
  // use equivalent form
Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep), | |
eshape[dimkeep], | |
eshape.ProdShape(dimkeep + 1, EShape::kSubdim), | |
eshape[EShape::kSubdim]); | |
  // call the equivalent MapReduceKeepDim1 kernel (dimkeep mapped to dimension 1)
cuda::MapReduceKeepDim1<Saver, Reducer> | |
(MakePlan(dst->self()), MakePlan(exp.self()), scale, pshape, | |
Stream<gpu>::GetStream(expr::StreamInfo<gpu, R>::Get(dst->self()))); | |
} | |
template<typename DType> | |
inline void Softmax(Tensor<gpu, 2, DType> dst, | |
const Tensor<gpu, 2, DType>& src) { | |
cuda::Softmax(dst, src); | |
} | |
template<typename DType> | |
inline void Softmax(Tensor<gpu, 3, DType> dst, | |
const Tensor<gpu, 3, DType>& src) { | |
cuda::Softmax(dst, src); | |
} | |
template<typename DType> | |
inline void SoftmaxGrad(Tensor<gpu, 2, DType> dst, | |
const Tensor<gpu, 2, DType> &src, | |
const Tensor<gpu, 1, DType> &label) { | |
cuda::SoftmaxGrad(dst, src, label); | |
} | |
template<typename DType> | |
inline void SoftmaxGrad(Tensor<gpu, 2, DType> dst, | |
const Tensor<gpu, 2, DType> &src, | |
const Tensor<gpu, 1, DType> &label, | |
const DType &ignore_label) { | |
cuda::SoftmaxGrad(dst, src, label, ignore_label); | |
} | |
template<typename DType> | |
inline void SoftmaxGrad(Tensor<gpu, 3, DType> dst, | |
const Tensor<gpu, 3, DType> &src, | |
const Tensor<gpu, 2, DType> &label) { | |
cuda::SoftmaxGrad(dst, src, label); | |
} | |
template<typename DType> | |
inline void SoftmaxGrad(Tensor<gpu, 3, DType> dst, | |
const Tensor<gpu, 3, DType> &src, | |
const Tensor<gpu, 2, DType> &label, | |
const DType &ignore_label) { | |
cuda::SoftmaxGrad(dst, src, label, ignore_label); | |
} | |
template<typename IndexType, typename DType> | |
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst, | |
const Tensor<gpu, 1, IndexType>& index, | |
const Tensor<gpu, 2, DType> &src) { | |
cuda::AddTakeGrad(dst, index, src); | |
} | |
template<typename IndexType, typename DType> | |
inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst, | |
const Tensor<gpu, 1, IndexType>& sorted, | |
const Tensor<gpu, 1, IndexType>& index, | |
const Tensor<gpu, 2, DType> &src) { | |
cuda::AddTakeGradLargeBatch(dst, sorted, index, src); | |
} | |
template<typename KDType, typename VDType> | |
inline void SortByKey(Tensor<gpu, 1, KDType> keys, Tensor<gpu, 1, VDType> values, | |
bool is_ascend) { | |
cuda::SortByKey(keys, values, is_ascend); | |
} | |
template<typename IndexType, typename DType> | |
inline void IndexFill(Tensor<gpu, 2, DType> dst, | |
const Tensor<gpu, 1, IndexType>& index, | |
const Tensor<gpu, 2, DType> &src) { | |
cuda::IndexFill(dst, index, src); | |
} | |
} // namespace mshadow | |
#endif // __CUDACC__ | |
#endif // MSHADOW_TENSOR_GPU_INL_H_ | |
//===== EXPANDED: ../mshadow/mshadow/tensor_gpu-inl.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/io.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file io.h | |
* \brief definitions of I/O functions for mshadow tensor | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_IO_H_ | |
#define MSHADOW_IO_H_ | |
namespace mshadow { | |
namespace utils { | |
/*! | |
* \brief interface of stream I/O, used to serialize data, | |
 *  mshadow is not restricted to this interface in SaveBinary/LoadBinary;
 *  it accepts any class that implements Read and Write
*/ | |
class IStream { | |
public: | |
/*! | |
* \brief read data from stream | |
* \param ptr pointer to memory buffer | |
* \param size size of block | |
   * \return usually the size of data actually read
*/ | |
virtual size_t Read(void *ptr, size_t size) = 0; | |
/*! | |
* \brief write data to stream | |
* \param ptr pointer to memory buffer | |
* \param size size of block | |
*/ | |
virtual void Write(const void *ptr, size_t size) = 0; | |
/*! \brief virtual destructor */ | |
virtual ~IStream(void) {} | |
}; | |
} // namespace utils | |
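// A minimal sketch of a stream type that satisfies the Read/Write contract above and
// can be passed to SaveBinary/LoadBinary below. Illustrative only and guarded out of
// compilation; the class name is an assumption and error handling is omitted.
#if 0
class StdioStream : public utils::IStream {
 public:
  explicit StdioStream(std::FILE *fp) : fp_(fp) {}
  virtual size_t Read(void *ptr, size_t size) {
    return std::fread(ptr, 1, size, fp_);   // returns the number of bytes read
  }
  virtual void Write(const void *ptr, size_t size) {
    std::fwrite(ptr, 1, size, fp_);
  }
 private:
  std::FILE *fp_;  // not owned; the caller opens and closes the file
};
#endif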
/*! | |
* \brief CPU/GPU: save a tensor by binary format, for GPU version, a temp Tensor<cpu,dim> storage will be allocated | |
* \param fo output binary stream | |
 * \param src source tensor
* \tparam dim dimension of tensor | |
* \tparam DType type of element in tensor | |
* \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. | |
*/ | |
template<int dim, typename DType, typename TStream> | |
inline void SaveBinary(TStream &fo, const Tensor<cpu, dim, DType> &src); // NOLINT(*) | |
/*! | |
* \brief CPU/GPU: save a tensor by binary format, for GPU version, a temp Tensor<cpu,dim> storage will be allocated | |
* \param fo output binary stream | |
 * \param src source tensor
* \tparam dim dimension of tensor | |
* \tparam DType type of element in tensor | |
* \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. | |
*/ | |
template<int dim, typename DType, typename TStream> | |
inline void SaveBinary(TStream &fo, const Tensor<gpu, dim, DType> &src); // NOLINT(*) | |
/*! | |
* \brief CPU/GPU: load a tensor by binary format, for GPU version, a temp Tensor<cpu,dim> storage will be allocated | |
 *        if pre_alloc is true, then space in dst is preallocated and must have the same shape as the tensor loaded
 *        if pre_alloc is false, then dst originally does not have space allocated, LoadBinary will allocate space for dst
 * \param fi input binary stream
 * \param dst destination tensor
* \param pre_alloc whether space is pre-allocated, if false, space allocation will happen | |
* \tparam dim dimension of tensor | |
* \tparam DType type of element in tensor | |
* \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. | |
*/ | |
template<int dim, typename DType, typename TStream> | |
inline void LoadBinary(TStream &fi, // NOLINT(*) | |
Tensor<cpu, dim, DType> *dst, bool pre_alloc); | |
/*! | |
* \brief CPU/GPU: load a tensor by binary format, for GPU version, a temp Tensor<cpu,dim> storage will be allocated | |
 *        if pre_alloc is true, then space in dst is preallocated and must have the same shape as the tensor loaded
 *        if pre_alloc is false, then dst originally does not have space allocated, LoadBinary will allocate space for dst
 * \param fi input binary stream
 * \param dst destination tensor
* \param pre_alloc whether space is pre-allocated, if false, space allocation will happen | |
* \tparam dim dimension of tensor | |
* \tparam DType type of element in tensor | |
* \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. | |
*/ | |
template<int dim, typename DType, typename TStream> | |
inline void LoadBinary(TStream &fi, // NOLINT(*) | |
Tensor<gpu, dim, DType> *dst, bool pre_alloc); | |
// implementations | |
template<int dim, typename DType, typename TStream> | |
inline void SaveBinary(TStream &fo, const Tensor<cpu, dim, DType> &src_) { // NOLINT(*) | |
fo.Write(&src_.shape_, sizeof(src_.shape_)); | |
Tensor<cpu, 2, DType> src = src_.FlatTo2D(); | |
for (index_t i = 0; i < src.size(0); ++i) { | |
fo.Write(src[i].dptr_, sizeof(DType) * src.size(1)); | |
} | |
} | |
template<int dim, typename DType, typename TStream> | |
inline void SaveBinary(TStream &fo, const Tensor<gpu, dim, DType> &src) { // NOLINT(*) | |
// copy to CPU, then save | |
Tensor<cpu, dim, DType> tmp(src.shape_); | |
AllocSpace(&tmp); | |
Stream<gpu> stream; | |
Copy(tmp, src, &stream); | |
SaveBinary(fo, tmp); | |
FreeSpace(&tmp); | |
} | |
template<int dim, typename DType, typename TStream> | |
inline void LoadBinary(TStream &fi, // NOLINT(*) | |
Tensor<cpu, dim, DType> *dst_, bool pre_alloc) { | |
Shape<dim> shape; | |
CHECK_NE(fi.Read(&shape, sizeof(shape)), 0) << "mshadow::LoadBinary"; | |
if (pre_alloc) { | |
    CHECK_EQ(shape, dst_->shape_) << "LoadBinary, shape does not match pre-allocated shape";
} else { | |
dst_->shape_ = shape; AllocSpace(dst_); | |
} | |
Tensor<cpu, 2, DType> dst = dst_->FlatTo2D(); | |
if (dst.size(0) == 0) return; | |
for (index_t i = 0; i < dst.size(0); ++i) { | |
CHECK_NE(fi.Read(dst[i].dptr_, sizeof(DType) * dst.size(1)), 0) << "mshadow::LoadBinary"; | |
} | |
} | |
template<int dim, typename DType, typename TStream> | |
inline void LoadBinary(TStream &fi, // NOLINT(*) | |
Tensor<gpu, dim, DType> *dst, bool pre_alloc) { | |
Tensor<cpu, dim, DType> tmp; | |
LoadBinary(fi, &tmp, false); | |
if (pre_alloc) { | |
    CHECK_EQ(tmp.shape_, dst->shape_) << "LoadBinary, shape does not match pre-allocated shape";
} else { | |
    dst->shape_ = tmp.shape_; AllocSpace(dst);
} | |
Stream<gpu> stream; | |
Copy(*dst, tmp, &stream); | |
FreeSpace(&tmp); | |
} | |
} // namespace mshadow | |
#endif // MSHADOW_IO_H_ | |
//===== EXPANDED: ../mshadow/mshadow/io.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/tensor_container.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file tensor_container.h | |
* \brief tensor container that does memory allocation and resize like STL | |
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_TENSOR_CONTAINER_H_ | |
#define MSHADOW_TENSOR_CONTAINER_H_ | |
namespace mshadow { | |
/*! | |
* \brief tensor container that does memory allocation and resize like STL, | |
 *  use it to avoid writing explicit FreeSpace calls in a class.
 *  Do not abuse it: efficiency comes from pre-allocation and avoiding re-allocation
* | |
* \tparam Device which device the tensor is on | |
* \tparam dimension dimension of the tensor | |
*/ | |
template<typename Device, int dimension, typename DType = default_real_t> | |
class TensorContainer: public Tensor<Device, dimension, DType> { | |
public: | |
/*! | |
* \brief constructor | |
* \param pad whether use padding alignment in space allocation | |
*/ | |
explicit TensorContainer(bool pad = MSHADOW_ALLOC_PAD) { | |
this->pad_ = pad; | |
this->dptr_ = data_.dptr_ = NULL; | |
this->shape_[0] = 0; | |
this->stride_ = 0; | |
this->data_.stride_ = 0; | |
this->data_.shape_[0] = 0; | |
} | |
/*! | |
* \brief constructor | |
   * \param shape initial shape
*/ | |
explicit TensorContainer(const Shape<dimension> &shape) { | |
this->pad_ = MSHADOW_ALLOC_PAD; | |
data_.dptr_ = NULL; | |
this->AllocByShape(shape); | |
} | |
/*! | |
* \brief constructor | |
   * \param shape initial shape
   * \param initv initial value
*/ | |
explicit TensorContainer(const Shape<dimension> &shape, DType initv) { | |
this->pad_ = MSHADOW_ALLOC_PAD; | |
data_.dptr_ = NULL; | |
this->AllocByShape(shape); | |
(*this) = initv; | |
} | |
/*! | |
* \brief copy constructor | |
* \param src source value | |
*/ | |
TensorContainer | |
(const TensorContainer<Device, dimension, DType> &src) | |
: pad_(src.pad_) { | |
this->dptr_ = data_.dptr_ = NULL; | |
this->shape_[0] = 0; | |
this->stride_ = 0; | |
this->data_.stride_ = 0; | |
this->data_.shape_[0] = 0; | |
this->stream_ = src.stream_; | |
if (src.dptr_ != NULL) { | |
this->AllocByShape(src.shape_); | |
mshadow::Copy(*this, src, this->stream_); | |
} | |
} | |
~TensorContainer(void) { | |
this->Release(); | |
} | |
/*! | |
* \brief resize the container to given shape, content is NOT preserved | |
* \param shape target shape | |
*/ | |
inline void Resize(const Shape<dimension> &shape) { | |
Shape<2> s2 = shape.FlatTo2D(); | |
if (s2.shape_[1] > data_.stride_ || s2.shape_[0] > data_.size(0)) { | |
this->AllocByShape(shape); | |
} else { | |
this->shape_ = shape; | |
if (this->pad_) { | |
this->stride_ = data_.stride_; | |
} else { | |
this->stride_ = s2.shape_[1]; | |
} | |
} | |
} | |
/*! | |
* \brief resize the container to given shape, and initialize, content is NOT preserved | |
* \param shape target shape | |
* \param initv initialization value | |
*/ | |
inline void Resize(const Shape<dimension> &shape, DType initv) { | |
this->Resize(shape); | |
(*this) = initv; | |
} | |
/*! \brief set whether padding is allowed in tensor */ | |
inline void set_pad(bool pad) { | |
this->pad_ = pad; | |
} | |
/*! | |
* \brief save by binary format | |
* \param fo output binary stream | |
* \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. | |
*/ | |
template<typename TStream> | |
inline void SaveBinary(TStream &fo) const { // NOLINT(*) | |
mshadow::SaveBinary(fo, *this); | |
} | |
/*! | |
* \brief load by binary format, a temp Tensor<cpu,dim> storage will be allocated | |
* \param fi input binary stream | |
* \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. | |
*/ | |
template<typename TStream> | |
inline void LoadBinary(TStream &fi) { // NOLINT(*) | |
Tensor<cpu, dimension, DType> tmp; | |
mshadow::LoadBinary(fi, &tmp, false); | |
this->Resize(tmp.shape_); | |
Stream<Device> stream; | |
Copy(*this, tmp, &stream); | |
mshadow::FreeSpace(&tmp); | |
} | |
/*! | |
* \brief assign operator from TensorContainer | |
* \param src source value | |
* \return reference of self | |
*/ | |
inline TensorContainer &operator= | |
(const TensorContainer<Device, dimension, DType> &src) { | |
this->pad_ = src.pad_; | |
this->stream_ = src.stream_; | |
if (src.dptr_ != NULL) { | |
this->Resize(src.shape_); | |
mshadow::Copy(*this, src, this->stream_); | |
} | |
return *this; | |
} | |
/*!\brief functions to fit expression template */ | |
inline Tensor<Device, dimension, DType> &operator=(DType s) { | |
return this->__assign(s); | |
} | |
/*!\brief functions to fit expression template */ | |
template<typename E> | |
inline Tensor<Device, dimension, DType> & | |
operator=(const expr::Exp<E, DType, expr::type::kMapper> &exp) { | |
return this->__assign(exp); | |
} | |
/*!\brief functions to fit expression template */ | |
template<typename E> | |
inline Tensor<Device, dimension, DType> & | |
operator=(const expr::Exp<E, DType, expr::type::kChainer> &exp) { | |
return this->__assign(exp); | |
} | |
/*!\brief functions to fit expression template */ | |
template<typename E> | |
inline Tensor<Device, dimension, DType> & | |
operator=(const expr::Exp<E, DType, expr::type::kComplex> &exp) { | |
return this->__assign(exp); | |
} | |
/*! | |
   * \brief Release the allocated space.
   *  The TensorContainer is still usable,
* but will restart allocating space when Resize is called. | |
*/ | |
inline void Release(void) { | |
if (data_.dptr_ != NULL) { | |
this->shape_[0] = 0; | |
this->stride_ = 0; | |
this->data_.stride_ = 0; | |
this->data_.shape_[0] = 0; | |
try { | |
mshadow::FreeSpace(&data_); | |
} catch (const dmlc::Error &e) { | |
this->dptr_ = data_.dptr_ = NULL; | |
throw e; | |
} | |
this->dptr_ = data_.dptr_ = NULL; | |
} | |
} | |
private: | |
/*! \brief whether we do padding in the space */ | |
bool pad_; | |
  /*! \brief data_ holds the currently allocated data space; its shape records the allocated capacity */
Tensor<Device, 2, DType> data_; | |
inline void AllocByShape(const Shape<dimension>& shape) { | |
if (data_.dptr_ != NULL) this->Release(); | |
data_.shape_ = shape.FlatTo2D(); | |
mshadow::AllocSpace(&data_, pad_); | |
this->dptr_ = data_.dptr_; | |
this->shape_ = shape; | |
if (this->pad_) { | |
this->stride_ = data_.stride_; | |
} else { | |
this->stride_ = data_.size(1); | |
} | |
} | |
}; | |
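// Usage sketch for TensorContainer (illustrative only, guarded out of compilation):
// unlike Tensor, the container owns its memory, so it can be resized and is released
// automatically in its destructor. Names and shape numbers are arbitrary assumptions.
#if 0
inline void ExampleTensorContainer() {
  TensorContainer<cpu, 2, float> buf(Shape2(3, 4), 0.0f);
  buf[1][2] = 5.0f;
  buf.Resize(Shape2(6, 4));   // note: contents are NOT preserved across Resize
}  // memory is released here by ~TensorContainer
#endif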
} // namespace mshadow | |
#endif // MSHADOW_TENSOR_CONTAINER_H_ | |
//===== EXPANDED: ../mshadow/mshadow/tensor_container.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/tensor_blob.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file tensor_blob.h | |
* \brief TBlob class that holds common representation of | |
 *  arbitrary dimension tensors, and can be transformed
 *  to a normal fixed-dimension tensor
* \author Tianqi Chen | |
*/ | |
#ifndef MSHADOW_TENSOR_BLOB_H_ | |
#define MSHADOW_TENSOR_BLOB_H_ | |
namespace mshadow { | |
/*! | |
* \brief dynamic shape class that can hold shape | |
* of arbitrary dimension | |
*/ | |
struct TShape { | |
public: | |
/*! \brief constructor */ | |
TShape() | |
: ndim_(0), | |
num_heap_allocated_(0), | |
data_heap_(NULL) {} | |
/*! | |
* \brief construct an "all-one" TShape with given dimension | |
* \param ndim the number of dimension of the shape | |
*/ | |
explicit TShape(index_t ndim) | |
: ndim_(ndim) { | |
if (ndim_ <= kStackCache) { | |
data_heap_ = NULL; | |
num_heap_allocated_ = 0; | |
std::fill_n(data_stack_, ndim_, 1); | |
} else { | |
data_heap_ = new index_t[ndim_]; | |
num_heap_allocated_ = ndim_; | |
std::fill_n(data_heap_, ndim_, 1); | |
} | |
} | |
/*! | |
* \brief constructor from TShape | |
* \param s the source shape | |
*/ | |
TShape(const TShape &s) | |
: ndim_(s.ndim_) { | |
if (ndim_ <= kStackCache) { | |
data_heap_ = NULL; | |
num_heap_allocated_ = 0; | |
std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_); | |
} else { | |
data_heap_ = new index_t[ndim_]; | |
num_heap_allocated_ = ndim_; | |
std::copy(s.data_heap_, s.data_heap_ + ndim_, data_heap_); | |
} | |
} | |
/*! | |
* \brief construct the TShape from content of iterator | |
* \param begin the beginning of iterator | |
* \param end end the end of the iterator | |
* \tparam RandomAccessIterator iterator type | |
*/ | |
template<typename RandomAccessIterator> | |
TShape(RandomAccessIterator begin, | |
RandomAccessIterator end) | |
: ndim_(0), | |
num_heap_allocated_(0), | |
data_heap_(NULL) { | |
this->CopyFrom(begin, end); | |
} | |
#if MSHADOW_IN_CXX11 | |
/*! | |
* \brief move constructor from TShape | |
* \param s the source shape | |
*/ | |
TShape(TShape &&s) | |
: ndim_(s.ndim_), | |
num_heap_allocated_(s.num_heap_allocated_), | |
data_heap_(s.data_heap_) { | |
if (ndim_ <= kStackCache) { | |
std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_); | |
} | |
// remove data heap space from s | |
s.data_heap_ = NULL; | |
} | |
/*! | |
* \brief move constructor from Shape | |
* \param s the source shape | |
*/ | |
template<int dim> | |
TShape(Shape<dim> &&s) // NOLINT(*) | |
: ndim_(0), | |
num_heap_allocated_(0), | |
data_heap_(NULL) { | |
this->CopyFrom(s.shape_, s.shape_ + dim); | |
} | |
#endif | |
/*! \brief destructor */ | |
~TShape() { | |
// data_heap_ can be NULL | |
delete [] data_heap_; | |
} | |
/*! | |
   * \brief copy shape from content between two iterators
* \param begin the beginning of iterator | |
* \param end the end of the iterator | |
* \tparam RandomAccessIterator iterator type | |
*/ | |
template<typename RandomAccessIterator> | |
inline void CopyFrom(RandomAccessIterator begin, | |
RandomAccessIterator end) { | |
this->SetDim(end - begin); | |
std::copy(begin, end, data()); | |
} | |
/*! | |
* \brief assignment from shape | |
* \param shape source shape | |
* \return reference of self | |
*/ | |
inline TShape &operator=(const TShape &shape) { | |
this->SetDim(shape.ndim_); | |
const index_t *src = shape.data(); | |
std::copy(src, src + ndim_, data()); | |
return *this; | |
} | |
/*! | |
* \brief assignment from vector | |
* \param shape source shape | |
* \return reference of self | |
*/ | |
inline TShape &operator=(const std::vector<index_t> &shape) { | |
this->CopyFrom(shape.begin(), shape.end()); | |
return *this; | |
} | |
/*! | |
* \brief assignment from shape | |
* \param shape source shape | |
* \tparam dim shape dimension | |
* \return reference of self | |
*/ | |
template<int dim> | |
inline TShape &operator=(const Shape<dim> &shape) { | |
this->SetDim(dim); | |
index_t *d = dim <= kStackCache ? data_stack_ : data_heap_; | |
for (int i = 0; i < dim; ++i) { | |
d[i] = shape[i]; | |
} | |
return *this; | |
} | |
/*! \return the data content of the shape */ | |
inline const index_t *data() const { | |
return ndim_ <= kStackCache ? data_stack_ : data_heap_; | |
} | |
/*! \return the data content of the shape */ | |
inline index_t *data() { | |
return ndim_ <= kStackCache ? data_stack_ : data_heap_; | |
} | |
/*! \brief return number of dimension of the tensor inside */ | |
inline index_t ndim(void) const { | |
return ndim_; | |
} | |
/*! | |
* \brief get corresponding index | |
* \param i dimension index | |
* \return the corresponding dimension size | |
*/ | |
inline index_t &operator[](index_t i) { | |
return data()[i]; | |
} | |
/*! | |
* \brief get corresponding index | |
* \param i dimension index | |
* \return the corresponding dimension size | |
*/ | |
inline const index_t &operator[](index_t i) const { | |
return data()[i]; | |
} | |
/*! \brief total number of elements in the tensor */ | |
inline size_t Size(void) const { | |
size_t size = 1; | |
const index_t *d = this->data(); | |
for (index_t i = 0; i < ndim_; ++i) { | |
size *= d[i]; | |
} | |
return size; | |
} | |
/*! | |
   * flatten all leading dimensions into the first dimension, keeping the last dimension, and return a 2D shape
* \return the flat 2d shape | |
*/ | |
inline Shape<2> FlatTo2D(void) const { | |
Shape<2> s; | |
if (ndim_ == 0) return Shape2(0, 0); | |
const index_t *d = this->data(); | |
s.shape_[1] = d[ndim_ - 1]; | |
index_t ymax = 1; | |
for (index_t i = 1; i < ndim_; ++i) { | |
ymax *= d[i - 1]; | |
} | |
s.shape_[0] = ymax; | |
return s; | |
} | |
/*! | |
* flatten the shape into three parts: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim) | |
* \param axis_begin The beginning axis specified. | |
* \param axis_end The ending axis specified. | |
* \return the flat 3d shape | |
*/ | |
inline Shape<3> FlatTo3D(index_t axis_begin, index_t axis_end) const { | |
CHECK(axis_end >= axis_begin); | |
Shape<3> s; | |
if (ndim_ == 0) return Shape3(0, 0, 0); | |
const index_t *d = this->data(); | |
s.shape_[0] = 1; | |
s.shape_[1] = 1; | |
s.shape_[2] = 1; | |
for (index_t i = 0; i < axis_begin; ++i) { | |
s.shape_[0] *= d[i]; | |
} | |
for (index_t i = axis_begin; i <= axis_end; ++i) { | |
s.shape_[1] *= d[i]; | |
} | |
for (index_t i = axis_end + 1; i < ndim_; ++i) { | |
s.shape_[2] *= d[i]; | |
} | |
return s; | |
} | |
/*! | |
* flatten the axis before and after the specified axis, so it becomes 3D tensor | |
* \param axis The axis specified. | |
* \return the flat 3d shape | |
*/ | |
inline Shape<3> FlatTo3D(index_t axis) const { | |
return FlatTo3D(axis, axis); | |
} | |
/*! | |
* \return product shape in [dimstart,dimend) | |
* \param dimstart start dimension | |
* \param dimend end dimension | |
*/ | |
inline index_t ProdShape(int dimstart, int dimend) const { | |
index_t num = 1; | |
const index_t *d = this->data(); | |
for (int i = dimstart; i < dimend; ++i) { | |
num *= d[i]; | |
} | |
return num; | |
} | |
/*! | |
* \brief get the shape of tensor specifying dim | |
* \return the shape requested | |
* \tparam dim dimension of the tensor | |
*/ | |
template<int dim> | |
inline Shape<dim> get(void) const { | |
    CHECK_EQ(dim, ndim_) << "dimension does not match target dimension " << dim << " vs " << ndim_;
const index_t *d = this->data(); | |
Shape<dim> s; | |
for (int i = 0; i < dim; ++i) { | |
s[i] = d[i]; | |
} | |
return s; | |
} | |
/*! | |
* \return whether two shape equals | |
* \param s the shape to compare against | |
*/ | |
inline bool operator==(const TShape &s) const { | |
if (ndim_ != s.ndim_) return false; | |
if (ndim_ <= kStackCache) { | |
for (index_t i = 0; i < ndim_; ++i) { | |
if (data_stack_[i] != s.data_stack_[i]) return false; | |
} | |
} else { | |
for (index_t i = 0; i < ndim_; ++i) { | |
if (data_heap_[i] != s.data_heap_[i]) return false; | |
} | |
} | |
return true; | |
} | |
/*! | |
* \return whether two shape not equals | |
* \param s the shape to compare against | |
*/ | |
inline bool operator!=(const TShape &s) const { | |
return !(*this == s); | |
} | |
/*! | |
* \return whether two shape equals | |
* \param s the shape to compare against | |
* \tparam dim dimension of the shape | |
*/ | |
template<int dim> | |
inline bool operator==(const Shape<dim> &s) const { | |
if (ndim_ != dim) return false; | |
const index_t *d = dim <= kStackCache ? data_stack_ : data_heap_; | |
for (index_t i = 0; i < dim; ++i) { | |
if (d[i] != s.shape_[i]) return false; | |
} | |
return true; | |
} | |
/*! | |
* \return whether two shape not equals | |
* \param s the shape to compare against | |
* \tparam dim dimension of the shape | |
*/ | |
template<int dim> | |
inline bool operator!=(const Shape<dim> &s) const { | |
return !(*this == s); | |
} | |
/*! | |
* \brief save the content into binary stream | |
* \param strm the output stream | |
   * \tparam TStream any stream type that has a Write method
*/ | |
template<typename TStream> | |
inline void Save(TStream *strm) const { | |
strm->Write(&ndim_, sizeof(ndim_)); | |
strm->Write(data(), sizeof(index_t) * ndim_); | |
} | |
/*! | |
* \brief load the content from binary stream | |
   * \param strm the input stream
   * \tparam TStream any stream type that has a Read method
* \return whether the load is successful | |
*/ | |
template<typename TStream> | |
inline bool Load(TStream *strm) { | |
if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false; | |
this->SetDim(ndim_); | |
size_t nread = sizeof(index_t) * ndim_; | |
if (strm->Read(data(), nread) != nread) return false; | |
return true; | |
} | |
friend std::ostream &operator<<(std::ostream &os, const TShape &shape); | |
friend std::istream &operator>>(std::istream &is, TShape &shape); | |
private: | |
// the shape will be stored in data_stack_ | |
// when dimension is smaller than kStackCache | |
// when it is bigger, it will be stored in data_heap_; | |
/*! \brief size of in stack space */ | |
static const index_t kStackCache = 4; | |
/*! \brief number of dimension of the shape */ | |
index_t ndim_; | |
/*! \brief number of cells allocated in data_heap_ */ | |
index_t num_heap_allocated_; | |
/*! \brief in stack space used to store shape when it is small */ | |
index_t data_stack_[kStackCache]; | |
/*! \brief space to store shape when dimension is big*/ | |
index_t *data_heap_; | |
/*! | |
* \brief internal function to set the dimension | |
* \param dim the dimension of the shape | |
*/ | |
inline void SetDim(index_t dim) { | |
if (dim > kStackCache && | |
dim > num_heap_allocated_) { | |
// data_heap_ can be NULL | |
delete [] data_heap_; | |
data_heap_ = new index_t[dim]; | |
num_heap_allocated_ = dim; | |
} | |
ndim_ = dim; | |
} | |
}; | |
/*! | |
* \brief allow string printing of the shape | |
* \param os the output stream | |
* \param shape the shape | |
* \return the ostream | |
*/ | |
inline std::ostream &operator<<(std::ostream &os, const TShape &shape) { | |
os << '('; | |
for (index_t i = 0; i < shape.ndim(); ++i) { | |
if (i != 0) os << ','; | |
os << shape[i]; | |
} | |
// python style tuple | |
if (shape.ndim() == 1) os << ','; | |
os << ')'; | |
return os; | |
} | |
/*! | |
* \brief read shape from the istream | |
* \param is the input stream | |
* \param shape the shape | |
* \return the istream | |
*/ | |
inline std::istream &operator>>(std::istream &is, TShape &shape) { | |
// get ( | |
while (true) { | |
char ch = is.peek(); | |
if (isdigit(ch)) { | |
index_t idx; | |
if (is >> idx) { | |
shape.CopyFrom(&idx, &idx + 1); | |
} | |
return is; | |
} | |
is.get(); | |
if (ch == '(') break; | |
if (!isspace(ch)) { | |
is.setstate(std::ios::failbit); | |
return is; | |
} | |
} | |
index_t idx; | |
std::vector<index_t> tmp; | |
while (is >> idx) { | |
tmp.push_back(idx); | |
char ch; | |
do { | |
ch = is.get(); | |
} while (isspace(ch)); | |
if (ch == 'L') { | |
ch = is.get(); | |
} | |
if (ch == ',') { | |
while (true) { | |
ch = is.peek(); | |
if (isspace(ch)) { | |
is.get(); continue; | |
} | |
if (ch == ')') { | |
is.get(); break; | |
} | |
break; | |
} | |
if (ch == ')') break; | |
} else if (ch == ')') { | |
break; | |
} else { | |
is.setstate(std::ios::failbit); | |
return is; | |
} | |
} | |
shape.CopyFrom(tmp.begin(), tmp.end()); | |
return is; | |
} | |
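// Usage sketch for TShape and its stream operators (illustrative only, guarded out
// of compilation; the shape numbers are arbitrary examples):
#if 0
inline void ExampleTShape() {
  TShape s;
  s = Shape3(2, 3, 4);                 // assign from a fixed-dimension Shape
  std::cout << s << std::endl;         // prints "(2,3,4)" in python-tuple style
  Shape<2> flat = s.FlatTo2D();        // (6, 4): leading dimensions collapsed
  Shape<3> fixed = s.get<3>();         // back to a static Shape<3>
  (void)flat; (void)fixed;
}
#endif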
/*! | |
* \brief tensor blob class that can be used to hold tensor of any dimension, | |
* any device and any data type, | |
 *  This is a weak type that can be used to transfer data through an interface.
 *  TBlob itself does not involve any arithmetic operations,
 *  but it can be converted to a tensor of fixed dimension for further operations.
 *
 *  Like Tensor, this data structure behaves like a pointer class and does not
 *  implicitly allocate or de-allocate space.
 *  This data structure can be helpful for holding tensors of different dimensions
 *  while they await further processing.
*/ | |
class TBlob { | |
public: | |
/*! \brief pointer to the data */ | |
void *dptr_; | |
/*! \brief shape of the tensor */ | |
TShape shape_; | |
/*! | |
* \brief storing the stride information in x dimension | |
*/ | |
index_t stride_; | |
/*! \brief device mask of the corresponding device */ | |
int dev_mask_; | |
/*! \brief type flag of the tensor blob */ | |
int type_flag_; | |
/*! \brief default constructor, default copy assign will work */ | |
TBlob(void) | |
: dptr_(NULL), dev_mask_(cpu::kDevMask), | |
type_flag_(DataType<default_real_t>::kFlag) {} | |
/*! | |
* \brief constructor that construct TBlob from contiguous memory | |
* \param dptr the pointer to the memory | |
* \param shape the shape of the data | |
* \param dev_mask the device mask, can be cpu::kDevMask or gpu::kDevMask | |
*/ | |
template<typename DType> | |
TBlob(DType *dptr, | |
const TShape &shape, | |
int dev_mask) | |
: dptr_(dptr), shape_(shape), | |
stride_(shape[shape.ndim() - 1]), | |
dev_mask_(dev_mask), | |
type_flag_(DataType<DType>::kFlag) {} | |
/*! | |
* \brief constructor that construct TBlob from contiguous memory | |
* \param dptr the pointer to the memory | |
* \param shape the shape of the data | |
* \param dev_mask the device mask, can be cpu::kDevMask or gpu::kDevMask | |
* \param type_flag the type flag. Can be one of enum mshadow::dtype | |
*/ | |
TBlob(void *dptr, | |
const TShape &shape, | |
int dev_mask, | |
int type_flag) | |
: dptr_(dptr), shape_(shape), | |
stride_(shape[shape.ndim() - 1]), | |
dev_mask_(dev_mask), | |
type_flag_(type_flag) {} | |
/*! | |
* \brief constructor from tensor | |
* \param src source tensor | |
* \tparam Device which device the tensor is on | |
* \tparam dim tensor dimension | |
* \tparam DType the type of elements in the tensor | |
*/ | |
template<typename Device, int dim, typename DType> | |
TBlob(const Tensor<Device, dim, DType> &src) { // NOLINT(*) | |
*this = src; | |
} | |
/*! | |
* \brief assignment from tensor | |
* \param src source tensor | |
* \tparam Device which device the tensor is on | |
* \tparam dim tensor dimension | |
* \tparam DType the type of elements in the tensor | |
* \return reference of self | |
*/ | |
template<typename Device, int dim, typename DType> | |
inline TBlob | |
&operator=(const Tensor<Device, dim, DType> &src) { | |
dptr_ = src.dptr_; | |
shape_ = src.shape_; | |
stride_ = src.stride_; | |
dev_mask_ = Device::kDevMask; | |
type_flag_ = DataType<DType>::kFlag; | |
return *this; | |
} | |
/*! | |
* \return whether the tensor's memory is continuous | |
*/ | |
inline bool CheckContiguous(void) const { | |
return shape_[shape_.ndim() - 1] == stride_; | |
} | |
/*! | |
* \brief flatten the tensor to 2 dimension, collapse the higher dimensions together | |
* \param stream the possible stream target tensor should reside on | |
* \tparam Device which device the tensor is on | |
* \tparam DType the type of elements in the tensor | |
* \return tensor after flatten | |
*/ | |
template<typename Device, typename DType> | |
inline Tensor<Device, 2, DType> FlatTo2D(Stream<Device> *stream = NULL) const { | |
    CHECK(Device::kDevMask == dev_mask_)
      << "TBlob.FlatTo2D: device type does not match specified type";
    CHECK(DataType<DType>::kFlag == type_flag_)
      << "TBlob.FlatTo2D: data type does not match specified type. "
      << "Expected: " << type_flag_ << " v.s. given " << DataType<DType>::kFlag;
return Tensor<Device, 2, DType>(static_cast<DType*>(dptr_), | |
shape_.FlatTo2D(), stride_, stream); | |
} | |
/*! \brief return number of dimension of the tensor inside */ | |
inline int ndim(void) const { | |
return shape_.ndim(); | |
} | |
/*! | |
* \brief return size of i-th dimension, start counting from highest dimension | |
* \param idx the dimension count from the highest dimension | |
* \return the size | |
*/ | |
inline index_t size(index_t idx) const { | |
return shape_[idx]; | |
} | |
/*! \brief total number of elements in the tensor */ | |
inline index_t Size(void) const { | |
return shape_.Size(); | |
} | |
/*! | |
* \brief fetch the tensor, with respect to specific dimension | |
   *  if dim does not match the stored dimension, an error will be issued
* \return the tensor requested | |
* \param stream the possible stream target tensor should reside on | |
* \tparam Device which device the tensor is on | |
* \tparam dim dimension of the tensor | |
* \tparam DType the type of elements in the tensor | |
*/ | |
template<typename Device, int dim, typename DType> | |
inline Tensor<Device, dim, DType> get(Stream<Device> *stream = NULL) const { | |
    CHECK(Device::kDevMask == dev_mask_)
      << "TBlob.get: device type does not match specified type";
    CHECK(DataType<DType>::kFlag == type_flag_)
      << "TBlob.get: data type does not match specified type. "
      << "Expected: " << type_flag_ << " v.s. given " << DataType<DType>::kFlag;
return Tensor<Device, dim, DType>(static_cast<DType*>(dptr_), | |
shape_.get<dim>(), | |
stride_, stream); | |
} | |
/*! | |
* \brief fetch a tensor in given shape | |
   *  If size does not match the stored size, an error will be issued
* \return the tensor requested | |
* \param shape the shape required | |
* \param stream the possible stream target tensor should reside on | |
* \tparam Device which device the tensor is on | |
* \tparam dim dimension of the tensor | |
* \tparam DType the type of elements in the tensor | |
*/ | |
template<typename Device, int dim, typename DType> | |
inline Tensor<Device, dim, DType> get_with_shape(const Shape<dim> &shape, | |
Stream<Device> *stream = NULL) const { | |
    CHECK(Device::kDevMask == dev_mask_)
      << "TBlob.get_with_shape: device type does not match specified type";
    CHECK(DataType<DType>::kFlag == type_flag_)
      << "TBlob.get_with_shape: data type does not match specified type. "
      << "Expected: " << type_flag_ << " v.s. given " << DataType<DType>::kFlag;
    CHECK_EQ(this->CheckContiguous(), true) << "TBlob.get_with_shape: must be contiguous";
CHECK_EQ(this->shape_.Size(), shape.Size()) | |
<< "TBlob.get_with_shape: new and old shape do not match total elements"; | |
return Tensor<Device, dim, DType>(static_cast<DType*>(dptr_), | |
shape, | |
shape[dim - 1], | |
stream); | |
} | |
/*! | |
* \brief flatten the tensor to 3 dimension, | |
* collapse the dimension before and after specified axis. | |
* \param axis The axis specified. | |
* \param stream the possible stream target tensor should reside on | |
* \tparam Device which device the tensor is on | |
* \tparam DType the type of elements in the tensor | |
* \return tensor after flatten | |
*/ | |
template<typename Device, typename DType> | |
inline Tensor<Device, 3, DType> FlatTo3D(int axis, Stream<Device> *stream = NULL) const { | |
return this->get_with_shape<Device, 3, DType>( | |
this->shape_.FlatTo3D(axis), stream); | |
} | |
/*! | |
* \brief flatten the tensor to 3 dimension, | |
* collapse the dimension: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim). | |
* \param axis_begin The beginning axis specified. | |
* \param axis_end The ending axis specified. | |
* \param stream the possible stream target tensor should reside on | |
* \tparam Device which device the tensor is on | |
* \tparam DType the type of elements in the tensor | |
* \return tensor after flatten | |
*/ | |
template<typename Device, typename DType> | |
inline Tensor<Device, 3, DType> FlatTo3D(int axis_begin, int axis_end, | |
Stream<Device> *stream = NULL) const { | |
return this->get_with_shape<Device, 3, DType>( | |
this->shape_.FlatTo3D(axis_begin, axis_end), stream); | |
} | |
}; | |
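// A minimal usage sketch of the typed accessors above (illustrative only; `data_ptr` is a
// hypothetical pointer to a contiguous 2x3 float buffer on CPU):
//
//   mshadow::Tensor<mshadow::cpu, 2, float> src(data_ptr, mshadow::Shape2(2, 3));
//   mshadow::TBlob blob(src);  // TBlob type-erases device, dimension and dtype
//   // dimension/device/dtype checked view; a mismatch triggers a CHECK failure
//   mshadow::Tensor<mshadow::cpu, 2, float> mat = blob.get<mshadow::cpu, 2, float>();
//   // reinterpret the same contiguous storage as shape (3, 2)
//   mshadow::Tensor<mshadow::cpu, 2, float> flat =
//       blob.get_with_shape<mshadow::cpu, 2, float>(mshadow::Shape2(3, 2));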
} // namespace mshadow | |
#endif // MSHADOW_TENSOR_BLOB_H_ | |
//===== EXPANDED: ../mshadow/mshadow/tensor_blob.h ===== | |
//===== EXPANDING: ../mshadow/mshadow/random.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file random.h | |
* \brief Random inline functions for tensor. | |
* \author Bing Xu, Tianqi Chen | |
* Based on curand|MKL|stdlib | |
*/ | |
#ifndef MSHADOW_RANDOM_H_ | |
#define MSHADOW_RANDOM_H_ | |
#if MSHADOW_IN_CXX11 | |
#endif | |
#if _MSC_VER | |
#define rand_r(x) rand() | |
#endif | |
namespace mshadow { | |
/*! | |
* \brief random number generator | |
* \tparam Device the device of random number generator | |
 * \tparam DType the target data type of the random numbers; can be float or double
*/ | |
template<typename Device, typename DType MSHADOW_DEFAULT_DTYPE> | |
class Random {}; | |
/*! \brief CPU random number generator */ | |
template<typename DType> | |
class Random<cpu, DType> { | |
public: | |
/*! | |
* \brief constructor of random engine | |
* \param seed random number seed | |
*/ | |
explicit Random(int seed) { | |
this->Seed(seed); | |
buffer_.Resize(Shape1(kRandBufferSize)); | |
} | |
~Random(void) { | |
} | |
/*! | |
* \brief seed random number generator using this seed | |
* \param seed seed of prng | |
*/ | |
inline void Seed(int seed) { | |
#if MSHADOW_IN_CXX11 | |
rnd_engine_.seed(seed); | |
#endif | |
this->rseed_ = static_cast<unsigned>(seed); | |
} | |
/*! | |
* \brief get random seed used in random generator | |
* \return seed in unsigned | |
*/ | |
inline unsigned GetSeed() const { | |
return rseed_; | |
} | |
/*! | |
* \brief set the stream of computation | |
* \param stream computation stream | |
*/ | |
inline void set_stream(Stream<cpu> *stream) { | |
} | |
/*! | |
* \brief generate data from uniform [a,b) | |
* \param dst destination | |
* \param a lower bound of uniform | |
* \param b upper bound of uniform | |
* \tparam dim dimension of tensor | |
*/ | |
template<int dim> | |
inline void SampleUniform(Tensor<cpu, dim, DType> *dst, | |
DType a = 0.0f, DType b = 1.0f) { | |
if (dst->CheckContiguous()) { | |
this->GenUniform(dst->dptr_, dst->shape_.Size(), a, b); | |
} else { | |
Tensor<cpu, 2, DType> mat = dst->FlatTo2D(); | |
for (index_t i = 0; i < mat.size(0); ++i) { | |
this->GenUniform(mat[i].dptr_, mat.size(1), a, b); | |
} | |
} | |
} | |
/*! | |
* \brief generate data from standard gaussian | |
* \param dst destination | |
* \param mu mean variable | |
* \param sigma standard deviation | |
* \tparam dim dimension of tensor | |
*/ | |
template<int dim> | |
inline void SampleGaussian(Tensor<cpu, dim, DType> *dst, | |
DType mu = 0.0f, DType sigma = 1.0f) { | |
if (sigma <= 0.0f) { | |
*dst = mu; return; | |
} | |
if (dst->CheckContiguous()) { | |
this->GenGaussian(dst->dptr_, dst->shape_.Size(), mu, sigma); | |
} else { | |
Tensor<cpu, 2, DType> mat = dst->FlatTo2D(); | |
for (index_t i = 0; i < mat.size(0); ++i) { | |
this->GenGaussian(mat[i].dptr_, mat.size(1), mu, sigma); | |
} | |
} | |
} | |
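  // A minimal usage sketch of the two samplers above (illustrative only; the seed,
  // shape, and tensor `data` are hypothetical):
  //
  //   mshadow::Random<mshadow::cpu, float> rnd(0);
  //   mshadow::Tensor<mshadow::cpu, 2, float> data(mshadow::Shape2(4, 8));
  //   mshadow::AllocSpace(&data);
  //   rnd.SampleUniform(&data, -1.0f, 1.0f);   // fill with U[-1, 1)
  //   rnd.SampleGaussian(&data, 0.0f, 2.0f);   // overwrite with N(0, 2^2)
  //   mshadow::FreeSpace(&data);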
  /*!
   * \brief return a temporary expression storing standard gaussian random variables.
   *        The temporary tensor is only valid until the next call of gaussian or uniform
   *        and can be used as part of an expression.
   *  Caution: an expression such as A = gaussian(s1) * gaussian(s2) gives an invalid result,
   *           since the second call gaussian(s2) invalidates gaussian(s1).
   *           A = gaussian(s1) * B + C; is correct; use at most one gaussian/uniform per expression.
   * \param shape shape of the tensor
   * \return a temporary expression storing standard gaussian random variables
   * \tparam dim dimension of tensor
   */
template<int dim> | |
inline expr::ReshapeExp<Tensor<cpu, 1, DType>, DType, dim, 1> | |
gaussian(Shape<dim> shape) { | |
buffer_.Resize(Shape1(shape.Size())); | |
this->SampleGaussian(&buffer_, 0.0f, 1.0f); | |
return expr::reshape(buffer_, shape); | |
} | |
  /*!
   * \brief return a temporary expression storing standard uniform [0,1) random variables.
   *        The temporary tensor is only valid until the next call of gaussian or uniform
   *        and can be used as part of an expression.
   *  Caution: an expression such as A = uniform(s1) * uniform(s2) gives an invalid result,
   *           since the second call uniform(s2) invalidates uniform(s1).
   *           A = uniform(s1) * B + C; is correct; use at most one gaussian/uniform per expression.
   * \param shape shape of the tensor
   * \return a temporary expression storing standard uniform [0,1) random variables
   * \tparam dim dimension of tensor
   */
template<int dim> | |
inline expr::ReshapeExp<Tensor<cpu, 1, DType>, DType, dim, 1> | |
uniform(Shape<dim> shape) { | |
buffer_.Resize(Shape1(shape.Size())); | |
this->SampleUniform(&buffer_, 0.0f, 1.0f); | |
return expr::reshape(buffer_, shape); | |
} | |
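  // A minimal sketch of the caution noted above (illustrative; A, B, C are hypothetical
  // contiguous 2-D float CPU tensors of matching shape and rnd is a Random<cpu, float>):
  //
  //   A = rnd.gaussian(A.shape_) * B + C;                     // correct: one random call
  //   // A = rnd.gaussian(A.shape_) * rnd.uniform(A.shape_);  // wrong: the second call
  //   //                                                      // invalidates the first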
private: | |
#if MSHADOW_IN_CXX11 | |
/*! \brief use c++11 random engine. */ | |
std::mt19937 rnd_engine_; | |
/*! \brief random number seed used in random engine */ | |
unsigned rseed_; | |
// implementing generators. | |
inline void GenUniform(DType *dptr, index_t size, DType a, DType b) { | |
std::uniform_real_distribution<DType> dist_uniform(a, b); | |
for (size_t i = 0; i < size; ++i) { | |
dptr[i] = dist_uniform(rnd_engine_); | |
} | |
} | |
inline void GenGaussian(DType *dptr, index_t size, DType mu, DType sigma) { | |
std::normal_distribution<DType> dist_normal(mu, sigma); | |
for (size_t i = 0; i < size; ++i) { | |
dptr[i] = dist_normal(rnd_engine_); | |
} | |
} | |
#else | |
/*! \brief random number seed used by PRNG */ | |
unsigned rseed_; | |
// functions | |
inline void GenUniform(float *dptr, index_t size, float a, float b) { | |
for (index_t j = 0; j < size; ++j) { | |
dptr[j] = static_cast<float>(RandNext()) * (b - a) + a; | |
} | |
} | |
inline void GenUniform(double *dptr, index_t size, double a, double b) { | |
for (index_t j = 0; j < size; ++j) { | |
dptr[j] = static_cast<double>(RandNext()) * (b - a) + a; | |
} | |
} | |
inline void GenGaussian(float *dptr, index_t size, float mu, float sigma) { | |
this->GenGaussianX(dptr, size, mu, sigma); | |
} | |
inline void GenGaussian(double *dptr, index_t size, double mu, double sigma) { | |
this->GenGaussianX(dptr, size, mu, sigma); | |
} | |
inline void GenGaussianX(DType *dptr, index_t size, DType mu, DType sigma) { | |
DType g1 = 0.0f, g2 = 0.0f; | |
for (index_t j = 0; j < size; ++j) { | |
if ((j & 1) == 0) { | |
this->SampleNormal2D(&g1, &g2); | |
dptr[j] = mu + g1 * sigma; | |
} else { | |
dptr[j] = mu + g2 * sigma; | |
} | |
} | |
} | |
/*! \brief get next random number from rand */ | |
inline DType RandNext(void) { | |
return static_cast<DType>(rand_r(&rseed_)) / | |
(static_cast<DType>(RAND_MAX) + 1.0f); | |
} | |
  /*! \brief return a real number uniform in (0,1) */
inline DType RandNext2(void) { | |
return (static_cast<DType>(rand_r(&rseed_)) + 1.0f) / | |
(static_cast<DType>(RAND_MAX) + 2.0f); | |
} | |
  /*!
   * \brief sample iid xx, yy ~ N(0,1) using the Marsaglia polar method
   * \param xx_ first gaussian output
   * \param yy_ second gaussian output
   */
inline void SampleNormal2D(DType *xx_, DType *yy_) { | |
DType &xx = *xx_, &yy = *yy_; | |
DType x, y, s; | |
do { | |
x = 2.0f * RandNext2() - 1.0f; | |
y = 2.0f * RandNext2() - 1.0f; | |
s = x * x + y * y; | |
} while (s >= 1.0f || s == 0.0f); | |
DType t = std::sqrt(-2.0f * std::log(s) / s); | |
xx = x * t; yy = y * t; | |
} | |
#endif | |
/*! \brief temporal space used to store random numbers */ | |
TensorContainer<cpu, 1, DType> buffer_; | |
}; // class Random<cpu, DType> | |
// only allow GPU PRNG when cuda is enabled | |
#if MSHADOW_USE_CUDA | |
/*! \brief GPU random number generator */ | |
template<typename DType> | |
class Random<gpu, DType> { | |
public: | |
/*! | |
* \brief constructor of random engine | |
* \param seed random number seed | |
*/ | |
explicit Random(int seed) { | |
curandStatus_t status; | |
status = curandCreateGenerator(&gen_, CURAND_RNG_PSEUDO_DEFAULT); | |
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Can not create CURAND Generator"; | |
this->Seed(seed); | |
buffer_.Resize(Shape1(kRandBufferSize)); | |
} | |
~Random(void) MSHADOW_THROW_EXCEPTION { | |
curandStatus_t status; | |
status = curandDestroyGenerator(gen_); | |
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Destory CURAND Gen failed"; | |
} | |
/*! | |
* \brief set the stream of computation | |
* \param stream computation stream | |
*/ | |
inline void set_stream(Stream<gpu> *stream) { | |
curandStatus_t status; | |
status = curandSetStream(gen_, Stream<gpu>::GetStream(stream)); | |
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "set_stream CURAND failed"; | |
} | |
/*! | |
* \brief seed random number generator using this seed | |
* \param seed seed of prng | |
*/ | |
inline void Seed(int seed) { | |
curandStatus_t status; | |
status = curandSetPseudoRandomGeneratorSeed(gen_, seed); | |
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Set CURAND seed failed."; | |
} | |
/*! | |
* \brief generate data from uniform [a,b) | |
* \param dst destination | |
* \param a lower bound of uniform | |
* \param b upper bound of uniform | |
* \tparam dim dimension of tensor | |
*/ | |
template<int dim> | |
inline void SampleUniform(Tensor<gpu, dim, DType> *dst, | |
DType a = 0.0f, DType b = 1.0f); | |
/*! | |
* \brief generate data from standard gaussian | |
* \param dst destination | |
* \param mu mean variable | |
* \param sigma standard deviation | |
* \tparam dim dimension of tensor | |
*/ | |
template<int dim> | |
inline void SampleGaussian(Tensor<gpu, dim, DType> *dst, | |
DType mu = 0.0f, DType sigma = 1.0f); | |
  /*!
   * \brief return a temporary expression storing standard gaussian random variables.
   *        The temporary tensor is only valid until the next call of gaussian or uniform
   *        and can be used as part of an expression.
   *  Caution: an expression such as A = gaussian(s1) * gaussian(s2) gives an invalid result,
   *           since the second call gaussian(s2) invalidates gaussian(s1).
   *           A = gaussian(s1) * B + C; is correct; use at most one gaussian/uniform per expression.
   * \param shape shape of the tensor
   * \param mu mean
   * \param sigma standard deviation
   * \return a temporary expression storing standard gaussian random variables
   * \tparam dim dimension of tensor
   */
template<int dim> | |
inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1> | |
gaussian(Shape<dim> shape, DType mu = 0.0f, DType sigma = 1.0f); | |
  /*!
   * \brief return a temporary expression storing standard uniform [0,1) random variables.
   *        The temporary tensor is only valid until the next call of gaussian or uniform
   *        and can be used as part of an expression.
   *  Caution: an expression such as A = uniform(s1) * uniform(s2) gives an invalid result,
   *           since the second call uniform(s2) invalidates uniform(s1).
   *           A = uniform(s1) * B + C; is correct; use at most one gaussian/uniform per expression.
   * \param shape shape of the tensor
   * \return a temporary expression storing standard uniform [0,1) random variables
   * \tparam dim dimension of tensor
   */
template<int dim> | |
inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1> | |
uniform(Shape<dim> shape); | |
private: | |
inline void GenGaussian(float *dptr, size_t size, float mu, float sigma) { | |
curandStatus_t status; | |
status = curandGenerateNormal(gen_, dptr, size, mu, sigma); | |
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Normal float failed." | |
<< " size = " << size | |
<< ",mu = " << mu | |
<< ",sigma = " << sigma; | |
} | |
inline void GenGaussian(double *dptr, size_t size, double mu, double sigma) { | |
curandStatus_t status; | |
status = curandGenerateNormalDouble(gen_, dptr, size, mu, sigma); | |
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Normal double failed." | |
<< " size = " << size | |
<< ",mu = " << mu | |
<< ",sigma = " << sigma; | |
} | |
inline void GenUniform(float *dptr, size_t size) { | |
curandStatus_t status; | |
status = curandGenerateUniform(gen_, dptr, size); | |
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform float failed." | |
<< " size = " << size; | |
} | |
inline void GenUniform(double *dptr, size_t size) { | |
curandStatus_t status; | |
status = curandGenerateUniformDouble(gen_, dptr, size); | |
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform double failed." | |
<< " size = " << size; | |
} | |
  /*! \brief random number generator */
  curandGenerator_t gen_;
  /*! \brief temporary buffer */
TensorContainer<gpu, 1, DType> buffer_; | |
}; // class Random<gpu, DType> | |
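// A minimal usage sketch (illustrative; assumes CUDA is enabled, `s` is an already
// created mshadow::Stream<gpu>*, and `data` is a contiguous GPU float tensor):
//
//   mshadow::Random<mshadow::gpu, float> rnd(0);
//   rnd.set_stream(s);                      // route curand generation through `s`
//   rnd.SampleGaussian(&data, 0.0f, 1.0f);  // falls back to gaussian() when the size is odd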
#endif // MSHADOW_USE_CUDA | |
#ifdef __CUDACC__ | |
// implementations that depends on cuda kernels | |
template<typename DType> | |
template<int dim> | |
inline void Random<gpu, DType>::SampleUniform( | |
Tensor<gpu, dim, DType> *dst, DType a, DType b) { | |
if (a == 0.0f && b == 1.0f) { | |
if (dst->CheckContiguous()) { | |
this->GenUniform(dst->dptr_, dst->shape_.Size()); | |
} else { | |
*dst = this->uniform(dst->shape_); | |
} | |
} else { | |
*dst = this->uniform(dst->shape_) * (b - a) + a; | |
} | |
} | |
template<typename DType> | |
template<int dim> | |
inline void Random<gpu, DType>::SampleGaussian( | |
Tensor<gpu, dim, DType> *dst, DType mu, DType sigma) { | |
  // Check whether the total size is even, since cuRAND only supports generating
  // an even number of normally distributed elements at a time.
if (dst->CheckContiguous() && (dst->shape_.Size() % 2 == 0)) { | |
this->GenGaussian(dst->dptr_, dst->shape_.Size(), mu, sigma); | |
} else { | |
*dst = this->gaussian(dst->shape_, mu, sigma); | |
} | |
} | |
template<typename DType> | |
template<int dim> | |
inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1> | |
Random<gpu, DType>::gaussian(Shape<dim> shape, DType mu, DType sigma) { | |
size_t aligned_sz = ((shape.Size() + 1UL) >> 1) << 1; | |
  // allocate an even (aligned) number of elements, as required by curandGenerateNormal;
  // the second Resize only shrinks the logical shape, the aligned storage stays allocated
  buffer_.Resize(Shape1(aligned_sz));
  buffer_.Resize(Shape1(shape.Size()));
this->GenGaussian(buffer_.dptr_, aligned_sz, mu, sigma); | |
return expr::reshape(buffer_, shape); | |
} | |
template<typename DType> | |
template<int dim> | |
inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1> | |
Random<gpu, DType>::uniform(Shape<dim> shape) { | |
buffer_.Resize(Shape1(shape.Size())); | |
this->GenUniform(buffer_.dptr_, buffer_.size(0)); | |
return expr::reshape(buffer_, shape); | |
} | |
#endif // __CUDACC__ | |
} // namespace mshadow | |
#endif // MSHADOW_RANDOM_H_ | |
//===== EXPANDED: ../mshadow/mshadow/random.h ===== | |
// add definition of scalar related operators | |
#ifdef MSHADOW_SCALAR_ | |
#error "MSHADOW_SCALAR_ must not be defined" | |
#endif | |
// enumerate all the scalar data type we aim to be good at | |
#define MSHADOW_SCALAR_ float | |
//===== EXPANDING: ../mshadow/mshadow/expr_scalar-inl.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file expr_scalar-inl.h | |
* \brief definitions of operators in expression with respect to scalar | |
* this file will be included several times, each time with MACRO MSHADOW_SCALAR_ to be different types | |
* | |
* DO NOT add pragma once or macro guard | |
* \author Tianqi Chen, Bing Xu | |
*/ | |
// a real include guard would be harmful here (this header is included multiple times);
// this guard exists only to satisfy cpplint and is undefined immediately below
#ifndef MSHADOW_EXPR_SCALAR_INL_H_ | |
#define MSHADOW_EXPR_SCALAR_INL_H_ | |
// undef the guard so it can be included multiple times | |
#undef MSHADOW_EXPR_SCALAR_INL_H_ | |
namespace mshadow { | |
namespace expr { | |
// DotExp | |
/*! \brief dot operator def */ | |
template<typename TA, typename TB, bool ltrans, bool rtrans> | |
inline DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_> | |
operator*(const DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_> &lhs, | |
MSHADOW_SCALAR_ rhs) { | |
return DotExp<TA, TB, ltrans, rtrans, | |
MSHADOW_SCALAR_>(lhs.lhs_, lhs.rhs_, lhs.scale_ * rhs); | |
} | |
/*! \brief scale of dot operation */ | |
template<typename TA, typename TB, bool ltrans, bool rtrans> | |
inline DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_> | |
operator*(MSHADOW_SCALAR_ lhs, | |
const DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_> &rhs) { | |
return DotExp<TA, TB, ltrans, rtrans, | |
MSHADOW_SCALAR_>(rhs.lhs_, rhs.rhs_, rhs.scale_ * lhs); | |
} | |
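// A minimal sketch of the scalar folding above (illustrative; lhs, rhs, out are
// hypothetical 2-D float CPU tensors with compatible shapes):
//
//   // the scalar is folded into the scale_ field of the DotExp, so the whole
//   // expression still maps onto a single scaled GEMM call
//   out = 0.5f * dot(lhs, rhs.T());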
/*! \brief operator overload */ | |
template<typename E, typename DType, typename R, int d> | |
inline ReduceTo1DExp<E, DType, R, d> | |
operator*(const ReduceTo1DExp<E, DType, R, d> &e, MSHADOW_SCALAR_ scale) { | |
return ReduceTo1DExp<E, DType, R, d>(e.src_, e.scale_ * scale); | |
} | |
/*! \brief operator overload */ | |
template<typename E, typename DType, typename R, int d> | |
inline ReduceTo1DExp<E, DType, R, d> | |
operator*(MSHADOW_SCALAR_ scale, const ReduceTo1DExp<E, DType, R, d> &e) { | |
return ReduceTo1DExp<E, DType, R, d>(e.src_, e.scale_ * scale); | |
} | |
/*! \brief operator overload for const */ | |
template<typename OP, typename TA, int ta> | |
inline BinaryMapExp<OP, TA, ScalarExp<MSHADOW_SCALAR_>, | |
MSHADOW_SCALAR_, (ta|type::kMapper)> | |
F(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs, const ScalarExp<MSHADOW_SCALAR_> &rhs) { | |
return MakeExp<OP>(lhs, rhs); | |
} | |
/*! \brief operator overload for const */ | |
template<typename OP, typename TB, int tb> | |
inline BinaryMapExp<OP, ScalarExp<MSHADOW_SCALAR_>, TB, | |
MSHADOW_SCALAR_, (tb|type::kMapper)> | |
F(const ScalarExp<MSHADOW_SCALAR_> &lhs, const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) { | |
return MakeExp<OP>(lhs, rhs); | |
} | |
/*! \brief operator overload for const */ | |
template<typename OP> | |
inline BinaryMapExp<OP, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>, | |
MSHADOW_SCALAR_, (1|type::kMapper)> | |
F(const ScalarExp<MSHADOW_SCALAR_> &lhs, const ScalarExp<MSHADOW_SCALAR_> &rhs) { | |
return MakeExp<OP>(lhs, rhs); | |
} | |
// constant operators | |
/*! \brief operator overload */ | |
template<typename TA, int ta> | |
inline BinaryMapExp<op::plus, TA, ScalarExp<MSHADOW_SCALAR_>, | |
MSHADOW_SCALAR_, (ta|type::kMapper)> | |
operator+(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs, | |
const ScalarExp<MSHADOW_SCALAR_> &rhs) { | |
return MakeExp<op::plus>(lhs, rhs); | |
} | |
/*! \brief operator overload */ | |
template<typename TA, int ta> | |
inline BinaryMapExp<op::minus, TA, ScalarExp<MSHADOW_SCALAR_>, | |
MSHADOW_SCALAR_, (ta|type::kMapper)> | |
operator-(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs, | |
const ScalarExp<MSHADOW_SCALAR_> &rhs) { | |
return MakeExp<op::minus>(lhs, rhs); | |
} | |
/*! \brief operator overload */ | |
template<typename TA, int ta> | |
inline BinaryMapExp<op::mul, TA, ScalarExp<MSHADOW_SCALAR_>, | |
MSHADOW_SCALAR_, (ta|type::kMapper)> | |
operator*(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs, | |
const ScalarExp<MSHADOW_SCALAR_> &rhs) { | |
return MakeExp<op::mul>(lhs, rhs); | |
} | |
/*! \brief operator overload */ | |
template<typename TA, int ta> | |
inline BinaryMapExp<op::div, TA, ScalarExp<MSHADOW_SCALAR_>, | |
MSHADOW_SCALAR_, (ta|type::kMapper)> | |
operator/(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs, | |
const ScalarExp<MSHADOW_SCALAR_> &rhs) { | |
return MakeExp<op::div>(lhs, rhs); | |
} | |
// constant operators 2 | |
/*! \brief operator overload */ | |
template<typename TB, int tb> | |
inline BinaryMapExp<op::plus, ScalarExp<MSHADOW_SCALAR_>, TB, | |
MSHADOW_SCALAR_, (tb|type::kMapper)> | |
operator+(const ScalarExp<MSHADOW_SCALAR_> &lhs, | |
const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) { | |
return MakeExp<op::plus>(lhs, rhs); | |
} | |
/*! \brief operator overload */ | |
template<typename TB, int tb> | |
inline BinaryMapExp<op::minus, ScalarExp<MSHADOW_SCALAR_>, TB, | |
MSHADOW_SCALAR_, (tb|type::kMapper)> | |
operator-(const ScalarExp<MSHADOW_SCALAR_> &lhs, | |
const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) { | |
return MakeExp<op::minus>(lhs, rhs); | |
} | |
/*! \brief operator overload */ | |
template<typename TB, int tb> | |
inline BinaryMapExp<op::mul, ScalarExp<MSHADOW_SCALAR_>, TB, | |
MSHADOW_SCALAR_, (tb|type::kMapper)> | |
operator*(const ScalarExp<MSHADOW_SCALAR_> &lhs, | |
const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) { | |
return MakeExp<op::mul>(lhs, rhs); | |
} | |
/*! \brief operator overload */ | |
template<typename TB, int tb> | |
inline BinaryMapExp<op::div, ScalarExp<MSHADOW_SCALAR_>, TB, | |
MSHADOW_SCALAR_, (tb|type::kMapper)> | |
operator/(const ScalarExp<MSHADOW_SCALAR_> &lhs, const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) { | |
return MakeExp<op::div>(lhs, rhs); | |
} | |
// constant operators 3 | |
/*! \brief operator overload */ | |
inline BinaryMapExp<op::plus, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>, | |
MSHADOW_SCALAR_, (1|type::kMapper)> | |
operator+(const ScalarExp<MSHADOW_SCALAR_> &lhs, | |
const ScalarExp<MSHADOW_SCALAR_> &rhs) { | |
return MakeExp<op::plus>(lhs, rhs); | |
} | |
/*! \brief operator overload */ | |
inline BinaryMapExp<op::minus, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>, | |
MSHADOW_SCALAR_, (1|type::kMapper)> | |
operator-(const ScalarExp<MSHADOW_SCALAR_> &lhs, | |
const ScalarExp<MSHADOW_SCALAR_> &rhs) { | |
return MakeExp<op::minus>(lhs, rhs); | |
} | |
/*! \brief operator overload */ | |
inline BinaryMapExp<op::mul, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>, | |
MSHADOW_SCALAR_, (1|type::kMapper)> | |
operator*(const ScalarExp<MSHADOW_SCALAR_> &lhs, | |
const ScalarExp<MSHADOW_SCALAR_> &rhs) { | |
return MakeExp<op::mul>(lhs, rhs); | |
} | |
/*! \brief operator overload */ | |
inline BinaryMapExp<op::div, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>, | |
MSHADOW_SCALAR_, (1|type::kMapper)> | |
operator/(const ScalarExp<MSHADOW_SCALAR_> &lhs, const ScalarExp<MSHADOW_SCALAR_> &rhs) { | |
return MakeExp<op::div>(lhs, rhs); | |
} | |
} // namespace expr | |
} // namespace mshadow | |
#endif // MSHADOW_EXPR_SCALAR_INL_H_ | |
//===== EXPANDED: ../mshadow/mshadow/expr_scalar-inl.h ===== | |
#undef MSHADOW_SCALAR_ | |
#define MSHADOW_SCALAR_ double | |
#undef MSHADOW_SCALAR_ | |
#define MSHADOW_SCALAR_ int | |
#undef MSHADOW_SCALAR_ | |
#define MSHADOW_SCALAR_ mshadow::half::half_t | |
#undef MSHADOW_SCALAR_ | |
#endif // MSHADOW_TENSOR_H_ | |
//===== EXPANDED: ../mshadow/mshadow/tensor.h ===== | |
//===== EXPANDING: ../include/mxnet/base.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file base.h | |
 * \brief configuration of mxnet as well as basic data structures.
*/ | |
#ifndef MXNET_BASE_H_ | |
#define MXNET_BASE_H_ | |
// nnvm headers for symbolic construction. | |
//===== EXPANDING: ../nnvm/include/nnvm/op.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file op.h | |
 * \brief Operator information structure.
*/ | |
#ifndef NNVM_OP_H_ | |
#define NNVM_OP_H_ | |
//===== EXPANDING: ../nnvm/include/nnvm/base.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file base.h | |
 * \brief Configuration of nnvm as well as basic data structures.
*/ | |
#ifndef NNVM_BASE_H_ | |
#define NNVM_BASE_H_ | |
//===== EXPANDING: ../dmlc-core/include/dmlc/memory.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file memory.h | |
 * \brief Additional memory handling utilities.
*/ | |
#ifndef DMLC_MEMORY_H_ | |
#define DMLC_MEMORY_H_ | |
//===== EXPANDING: ../dmlc-core/include/dmlc/thread_local.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file thread_local.h | |
* \brief Portable thread local storage. | |
*/ | |
#ifndef DMLC_THREAD_LOCAL_H_ | |
#define DMLC_THREAD_LOCAL_H_ | |
namespace dmlc { | |
// macro handling for thread-local variables
#ifdef __GNUC__ | |
#define MX_TREAD_LOCAL __thread | |
#elif __STDC_VERSION__ >= 201112L | |
#define MX_TREAD_LOCAL _Thread_local | |
#elif defined(_MSC_VER) | |
#define MX_TREAD_LOCAL __declspec(thread) | |
#endif | |
#ifndef MX_TREAD_LOCAL | |
#message("Warning: Threadlocal is not enabled"); | |
#endif | |
/*! | |
 * \brief A thread-local store for thread-local variables.
 *  Will return a thread-local singleton of type T.
 * \tparam T the type we would like to store
*/ | |
template<typename T> | |
class ThreadLocalStore { | |
public: | |
/*! \return get a thread local singleton */ | |
static T* Get() { | |
static MX_TREAD_LOCAL T* ptr = nullptr; | |
if (ptr == nullptr) { | |
ptr = new T(); | |
Singleton()->RegisterDelete(ptr); | |
} | |
return ptr; | |
} | |
private: | |
/*! \brief constructor */ | |
ThreadLocalStore() {} | |
/*! \brief destructor */ | |
~ThreadLocalStore() { | |
for (size_t i = 0; i < data_.size(); ++i) { | |
delete data_[i]; | |
} | |
} | |
/*! \return singleton of the store */ | |
static ThreadLocalStore<T> *Singleton() { | |
static ThreadLocalStore<T> inst; | |
return &inst; | |
} | |
  /*!
   * \brief register a pointer for internal deletion
   * \param str the pointer to be deleted when the store is destructed
   */
void RegisterDelete(T *str) { | |
std::unique_lock<std::mutex> lock(mutex_); | |
data_.push_back(str); | |
lock.unlock(); | |
} | |
/*! \brief internal mutex */ | |
std::mutex mutex_; | |
/*!\brief internal data */ | |
std::vector<T*> data_; | |
}; | |
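// A minimal usage sketch of ThreadLocalStore (illustrative; the Counter type is hypothetical):
//
//   struct Counter { int value = 0; };
//   // each calling thread lazily gets its own Counter instance; all instances are
//   // deleted when the store singleton is destructed at program exit
//   Counter* c = dmlc::ThreadLocalStore<Counter>::Get();
//   ++c->value;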
} // namespace dmlc | |
#endif // DMLC_THREAD_LOCAL_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/thread_local.h ===== | |
namespace dmlc { | |
/*! | |
 * \brief A memory pool that allocates memory of fixed size and alignment.
* \tparam size The size of each piece. | |
* \tparam align The alignment requirement of the memory. | |
*/ | |
template<size_t size, size_t align> | |
class MemoryPool { | |
public: | |
/*! \brief constructor */ | |
MemoryPool() { | |
static_assert(align % alignof(LinkedList) == 0, | |
"alignment requirement failed."); | |
curr_page_.reset(new Page()); | |
} | |
/*! \brief allocate a new memory of size */ | |
inline void* allocate() { | |
if (head_ != nullptr) { | |
LinkedList* ret = head_; | |
head_ = head_->next; | |
return ret; | |
} else { | |
if (page_ptr_ < kPageSize) { | |
return &(curr_page_->data[page_ptr_++]); | |
} else { | |
allocated_.push_back(std::move(curr_page_)); | |
curr_page_.reset(new Page()); | |
page_ptr_ = 1; | |
return &(curr_page_->data[0]); | |
} | |
} | |
} | |
/*! | |
* \brief deallocate a piece of memory | |
* \param p The pointer to the memory to be de-allocated. | |
*/ | |
inline void deallocate(void* p) { | |
LinkedList* ptr = static_cast<LinkedList*>(p); | |
ptr->next = head_; | |
head_ = ptr; | |
} | |
private: | |
  // number of fixed-size slots in each allocated page (pages are roughly 4MB)
static const int kPageSize = ((1 << 22) / size); | |
// page to be requested. | |
struct Page { | |
typename std::aligned_storage<size, align>::type data[kPageSize]; | |
}; | |
// internal linked list structure. | |
struct LinkedList { | |
LinkedList* next{nullptr}; | |
}; | |
// head of free list | |
LinkedList* head_{nullptr}; | |
// current free page | |
std::unique_ptr<Page> curr_page_; | |
// pointer to the current free page position. | |
size_t page_ptr_{0}; | |
// allocated pages. | |
std::vector<std::unique_ptr<Page> > allocated_; | |
}; | |
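// A minimal usage sketch of MemoryPool (illustrative; the Node type is hypothetical and
// kept pointer-aligned so that it satisfies the alignment static_assert above):
//
//   struct Node { Node* next; int payload; };
//   dmlc::MemoryPool<sizeof(Node), alignof(Node)> pool;
//   void* p = pool.allocate();            // grab one fixed-size slot
//   Node* n = new (p) Node{nullptr, 42};  // placement-construct into the slot
//   n->~Node();
//   pool.deallocate(p);                   // the slot returns to the free list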
/*! | |
 * \brief A thread-local allocator that gets memory from a thread-local memory pool.
 *  This is suitable for allocating objects that do not cross threads.
* \tparam T the type of the data to be allocated. | |
*/ | |
template<typename T> | |
class ThreadlocalAllocator { | |
public: | |
/*! \brief pointer type */ | |
typedef T* pointer; | |
/*! \brief const pointer type */ | |
typedef const T* const_ptr; | |
/*! \brief value type */ | |
typedef T value_type; | |
/*! \brief default constructor */ | |
ThreadlocalAllocator() {} | |
/*! | |
* \brief constructor from another allocator | |
* \param other another allocator | |
* \tparam U another type | |
*/ | |
template<typename U> | |
ThreadlocalAllocator(const ThreadlocalAllocator<U>& other) {} | |
/*! | |
* \brief allocate memory | |
* \param n number of blocks | |
* \return an uninitialized memory of type T. | |
*/ | |
inline T* allocate(size_t n) { | |
CHECK_EQ(n, 1); | |
typedef ThreadLocalStore<MemoryPool<sizeof(T), alignof(T)> > Store; | |
return static_cast<T*>(Store::Get()->allocate()); | |
} | |
/*! | |
* \brief deallocate memory | |
* \param p a memory to be returned. | |
* \param n number of blocks | |
*/ | |
inline void deallocate(T* p, size_t n) { | |
CHECK_EQ(n, 1); | |
typedef ThreadLocalStore<MemoryPool<sizeof(T), alignof(T)> > Store; | |
Store::Get()->deallocate(p); | |
} | |
}; | |
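// A minimal usage sketch of ThreadlocalAllocator (illustrative; only single-object
// allocations are supported, so n must always be 1):
//
//   dmlc::ThreadlocalAllocator<double> alloc;
//   double* p = alloc.allocate(1);   // served from the thread-local MemoryPool
//   *p = 3.14;
//   alloc.deallocate(p, 1);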
/*! | |
 * \brief a shared-pointer-like type that allocates objects
 *  from a thread-local object pool. This object is not thread-safe
 *  but can be faster than shared_ptr in certain use cases.
* \tparam T the data type. | |
*/ | |
template<typename T> | |
struct ThreadlocalSharedPtr { | |
public: | |
/*! \brief default constructor */ | |
ThreadlocalSharedPtr() : block_(nullptr) {} | |
/*! | |
* \brief constructor from nullptr | |
* \param other the nullptr type | |
*/ | |
ThreadlocalSharedPtr(std::nullptr_t other) : block_(nullptr) {} // NOLINT(*) | |
/*! | |
* \brief copy constructor | |
* \param other another pointer. | |
*/ | |
ThreadlocalSharedPtr(const ThreadlocalSharedPtr<T>& other) | |
: block_(other.block_) { | |
IncRef(block_); | |
} | |
/*! | |
* \brief move constructor | |
* \param other another pointer. | |
*/ | |
ThreadlocalSharedPtr(ThreadlocalSharedPtr<T>&& other) | |
: block_(other.block_) { | |
other.block_ = nullptr; | |
} | |
/*! | |
* \brief destructor | |
*/ | |
~ThreadlocalSharedPtr() { | |
DecRef(block_); | |
} | |
/*! | |
* \brief move assignment | |
* \param other another object to be assigned. | |
* \return self. | |
*/ | |
inline ThreadlocalSharedPtr<T>& operator=(ThreadlocalSharedPtr<T>&& other) { | |
DecRef(block_); | |
block_ = other.block_; | |
other.block_ = nullptr; | |
return *this; | |
} | |
/*! | |
* \brief copy assignment | |
* \param other another object to be assigned. | |
* \return self. | |
*/ | |
inline ThreadlocalSharedPtr<T> &operator=(const ThreadlocalSharedPtr<T>& other) { | |
DecRef(block_); | |
block_ = other.block_; | |
IncRef(block_); | |
return *this; | |
} | |
/*! \brief check if nullptr */ | |
inline bool operator==(std::nullptr_t other) const { | |
return block_ == nullptr; | |
} | |
/*! | |
* \return get the pointer content. | |
*/ | |
inline T* get() const { | |
if (block_ == nullptr) return nullptr; | |
return reinterpret_cast<T*>(&(block_->data)); | |
} | |
/*! | |
* \brief reset the pointer to nullptr. | |
*/ | |
inline void reset() { | |
DecRef(block_); | |
block_ = nullptr; | |
} | |
/*! \return if use_count == 1*/ | |
inline bool unique() const { | |
if (block_ == nullptr) return false; | |
return block_->use_count_ == 1; | |
} | |
/*! \return dereference pointer */ | |
inline T* operator*() const { | |
return reinterpret_cast<T*>(&(block_->data)); | |
} | |
/*! \return dereference pointer */ | |
inline T* operator->() const { | |
return reinterpret_cast<T*>(&(block_->data)); | |
} | |
/*! | |
* \brief create a new space from threadlocal storage and return it. | |
* \tparam Args the arguments. | |
* \param args The input argument | |
* \return the allocated pointer. | |
*/ | |
template <typename... Args> | |
inline static ThreadlocalSharedPtr<T> Create(Args&&... args) { | |
ThreadlocalAllocator<RefBlock> arena; | |
ThreadlocalSharedPtr<T> p; | |
p.block_ = arena.allocate(1); | |
p.block_->use_count_ = 1; | |
new (&(p.block_->data)) T(std::forward<Args>(args)...); | |
return p; | |
} | |
private: | |
// internal reference block | |
struct RefBlock { | |
typename std::aligned_storage<sizeof(T), alignof(T)>::type data; | |
unsigned use_count_; | |
}; | |
// decrease ref counter | |
inline static void DecRef(RefBlock* block) { | |
if (block != nullptr) { | |
if (--block->use_count_ == 0) { | |
ThreadlocalAllocator<RefBlock> arena; | |
T* dptr = reinterpret_cast<T*>(&(block->data)); | |
dptr->~T(); | |
arena.deallocate(block, 1); | |
} | |
} | |
} | |
// increase ref counter | |
inline static void IncRef(RefBlock* block) { | |
if (block != nullptr) { | |
++block->use_count_; | |
} | |
} | |
// internal block | |
RefBlock *block_; | |
}; | |
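// A minimal usage sketch of ThreadlocalSharedPtr (illustrative; the Blob type is
// hypothetical). The reference count is not atomic, so a pointer and all of its
// copies must stay within a single thread:
//
//   struct Blob { std::vector<int> data; };
//   auto p = dmlc::ThreadlocalSharedPtr<Blob>::Create();
//   p->data.push_back(1);
//   dmlc::ThreadlocalSharedPtr<Blob> q = p;  // bumps the shared use count
//   q.reset();                               // p still owns the object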
} // namespace dmlc | |
#endif // DMLC_MEMORY_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/memory.h ===== | |
//===== EXPANDING: ../dmlc-core/include/dmlc/array_view.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file array_view.h | |
* \brief Read only data structure to reference array | |
*/ | |
#ifndef DMLC_ARRAY_VIEW_H_ | |
#define DMLC_ARRAY_VIEW_H_ | |
namespace dmlc { | |
/*! | |
 * \brief Read-only data structure referencing a contiguous memory region of an array.
 *  Provides a unified view over std::vector, std::array and C-style arrays.
 *  This data structure does not guarantee the liveness of the referenced array.
 *
 *  Make sure not to use array_view to record data in async function closures.
 *  Also do not use array_view to create a reference to a temporary data structure.
* | |
* \tparam ValueType The value | |
* | |
* \code | |
* std::vector<int> myvec{1,2,3}; | |
* dmlc::array_view<int> view(myvec); | |
* // indexed visit to the view. | |
* LOG(INFO) << view[0]; | |
* | |
* for (int v : view) { | |
* // visit each element in the view | |
* } | |
* \endcode | |
*/ | |
template<typename ValueType> | |
class array_view { | |
public: | |
/*! \brief default constructor */ | |
array_view() = default; | |
/*! | |
* \brief default copy constructor | |
* \param other another array view. | |
*/ | |
array_view(const array_view<ValueType> &other) = default; // NOLINT(*) | |
/*! | |
* \brief default move constructor | |
* \param other another array view. | |
*/ | |
array_view(array_view<ValueType>&& other) = default; // NOLINT(*) | |
/*! | |
* \brief default assign constructor | |
* \param other another array view. | |
* \return self. | |
*/ | |
array_view<ValueType>& operator=(const array_view<ValueType>& other) = default; // NOLINT(*) | |
/*! | |
   * \brief construct array view from std::vector
* \param other vector container | |
*/ | |
array_view(const std::vector<ValueType>& other) { // NOLINT(*) | |
if (other.size() != 0) { | |
begin_ = &other[0]; size_ = other.size(); | |
} | |
} | |
/*! | |
   * \brief construct array view from std::array
* \param other another array view. | |
*/ | |
template<std::size_t size> | |
array_view(const std::array<ValueType, size>& other) { // NOLINT(*) | |
if (size != 0) { | |
begin_ = &other[0]; size_ = size; | |
} | |
} | |
/*! | |
* \brief construct array view from continuous segment | |
   * \param begin beginning pointer
* \param end end pointer | |
*/ | |
array_view(const ValueType* begin, const ValueType* end) { | |
if (begin < end) { | |
begin_ = begin; | |
size_ = end - begin; | |
} | |
} | |
/*! \return size of the array */ | |
inline size_t size() const { | |
return size_; | |
} | |
/*! \return begin of the array */ | |
inline const ValueType* begin() const { | |
return begin_; | |
} | |
/*! \return end point of the array */ | |
inline const ValueType* end() const { | |
return begin_ + size_; | |
} | |
/*! | |
* \brief get i-th element from the view | |
* \param i The index. | |
* \return const reference to i-th element. | |
*/ | |
inline const ValueType& operator[](size_t i) const { | |
return begin_[i]; | |
} | |
private: | |
/*! \brief the begin of the view */ | |
const ValueType* begin_{nullptr}; | |
/*! \brief The size of the view */ | |
size_t size_{0}; | |
}; | |
} // namespace dmlc | |
#endif // DMLC_ARRAY_VIEW_H_ | |
//===== EXPANDED: ../dmlc-core/include/dmlc/array_view.h ===== | |
namespace nnvm { | |
/*! \brief any type */ | |
using dmlc::any; | |
/*! \brief array_view type */
using dmlc::array_view; | |
/*!\brief getter function of any type */ | |
using dmlc::get; | |
} // namespace nnvm | |
#endif // NNVM_BASE_H_ | |
//===== EXPANDED: ../nnvm/include/nnvm/base.h ===== | |
namespace nnvm { | |
// forward declarations | |
class Node; | |
struct NodeAttrs; | |
template<typename ValueType> | |
class OpMap; | |
class OpGroup; | |
class OpRegistryEntry; | |
using dmlc::ParamFieldInfo; | |
/*! \brief constant to indicate it take any length of positional inputs */ | |
static const uint32_t kVarg = std::numeric_limits<uint32_t>::max(); | |
/*! | |
* \brief Operator structure. | |
* | |
* Besides the fields in the structure, | |
 * arbitrary additional information can be associated with each op.
* See function GetAttr for details. | |
* | |
* \code | |
* // Example usage of Op | |
* | |
 *  // registration of operators
* // NOTE that the attr function can register any | |
* // additional attributes to the operator | |
* NNVM_REGISTER_OP(add) | |
* .describe("add two inputs together") | |
* .set_num_inputs(2) | |
* .set_attr<OpKernel>("OpKernel<gpu>", AddKernel) | |
* .include("ElementwiseOpAttr"); | |
* | |
* // can register attribute by group | |
* // all the ops that include the group get the attribute. | |
* NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr) | |
* .set_attr<FInferShape>("FInferShape", ElementwiseInferShape); | |
* | |
* NNVM_REGISTER_OP(sub) | |
* .describe("substract one tensor from another") | |
* .set_num_inputs(2); | |
* | |
 *  // Can call register multiple times in different files
* // to register different part of information | |
* NNVM_REGISTER_OP(sub) | |
 *  .set_attr<OpKernel>("OpKernel<gpu>", SubKernel)
 *  .include("ElementwiseOpAttr");
* | |
* // get operators from registry. | |
* void my_function() { | |
* const Op* add = Op::Get("add"); | |
* const Op* sub = Op::Get("sub"); | |
* // query basic information about each operator. | |
 *    assert(add->name == "add");
 *    assert(add->num_inputs == 2);
* | |
* // get additional registered information, | |
 *    // Assume the user registered an OpKernel-typed attribute named "OpKernel<gpu>" on each operator.
* const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("OpKernel<gpu>"); | |
* // we can get the kernel functions by using operator as key. | |
* auto add_kernel = kernel[add]; | |
* auto sub_kernel = kernel[sub]; | |
* // subsequent code can make use of the queried kernel functions. | |
* } | |
* \endcode | |
*/ | |
class Op { | |
public: | |
/*! \brief name of the operator */ | |
std::string name; | |
/*! | |
* \brief detailed description of the operator | |
* This can be used to generate docstring automatically for the operator. | |
*/ | |
std::string description; | |
  /*! \brief description of inputs and keyword arguments */
std::vector<ParamFieldInfo> arguments; | |
/*! | |
* \brief number of inputs to the operator, | |
* -1 means it is variable length | |
   *  When get_num_inputs is present,
* the number will be decided by get_num_inputs instead. | |
* \sa get_num_inputs | |
*/ | |
uint32_t num_inputs = 1; | |
/*! | |
* \brief number of outputs of the operator | |
   *  When get_num_outputs is present,
   *  the number of outputs will be decided by
   *  the get_num_outputs function.
* \sa get_num_outputs | |
*/ | |
uint32_t num_outputs = 1; | |
/*! | |
* \brief get number of outputs given information about the node. | |
* \param attrs The attribute of the node | |
* \return number of outputs. | |
*/ | |
std::function<uint32_t(const NodeAttrs& attrs)> get_num_outputs = nullptr; | |
/*! | |
* \brief get number of inputs given information about the node. | |
* \param attrs The attribute of the node | |
* \return number of inputs | |
*/ | |
std::function<uint32_t(const NodeAttrs& attrs)> get_num_inputs = nullptr; | |
/*! | |
* \brief Attribute parser to parse the NodeAttrs information. | |
* | |
* This can help to get quick access to a parsed attribute | |
* object | |
* | |
* \code | |
* // Example usage of attr_parser. | |
* | |
* // Suppose we want to register operator sum. | |
* // The parameters about sum operator | |
* struct SumParam { | |
* int axis; | |
* }; | |
* // The parser function | |
* void SumAttrParser(NodeAttrs* attrs) { | |
* // This will be invoked during node construction. | |
* SumParam param; | |
* // parse axis string to integer | |
* param.axis = atoi(attrs->dict["axis"].c_str()); | |
* // set the parsed parameter | |
* attrs->parsed = std::move(param); | |
* } | |
* // The other function that can utilize the parsed result. | |
* TShape SumInferShape(const NodeAttrs& attrs, | |
* const std::vector<TShape>& ishapes) { | |
* // we can use the parsed version of param | |
* // without repeatively parsing the parameter | |
* const SumParam& param = nnvm::get<SumParam>(attrs.parsed); | |
* } | |
* \endcode | |
*/ | |
std::function<void(NodeAttrs* attrs)> attr_parser = nullptr; | |
// function fields. | |
/*! | |
* \brief setter function during registration | |
* Set the description of operator | |
* \param descr the description string. | |
* \return reference to self. | |
*/ | |
inline Op& describe(const std::string& descr); // NOLINT(*) | |
/*! | |
* \brief Add argument information to the function. | |
* \param name Name of the argument. | |
* \param type Type of the argument. | |
* \param description Description of the argument. | |
* \return reference to self. | |
*/ | |
inline Op& add_argument(const std::string &name, | |
const std::string &type, | |
const std::string &description); | |
/*! | |
   * \brief Append a list of arguments to the end.
* \param args Additional list of arguments. | |
* \return reference to self. | |
*/ | |
inline Op& add_arguments(const std::vector<ParamFieldInfo> &args); | |
/*! | |
* \brief Set the num_inputs | |
* \param n The number of inputs to be set. | |
* \return reference to self. | |
*/ | |
inline Op& set_num_inputs(uint32_t n); // NOLINT(*) | |
/*! | |
   * \brief Set the get_num_inputs function.
* \param fn The function to be set. | |
* \return reference to self. | |
*/ | |
inline Op& set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*) | |
/*! | |
* \brief Set the num_outputs | |
* \param n The number of outputs to be set. | |
* \return reference to self. | |
*/ | |
inline Op& set_num_outputs(uint32_t n); // NOLINT(*) | |
/*! | |
* \brief Set the get_num_outputs function. | |
* \param fn The function to be set. | |
* \return reference to self. | |
*/ | |
inline Op& set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*) | |
/*! | |
* \brief Set the attr_parser function. | |
   * \param fn The attribute parser function to be set.
* \return reference to self. | |
*/ | |
inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*) | |
/*! | |
* \brief Register additional attributes to operator. | |
* \param attr_name The name of the attribute. | |
* \param value The value to be set. | |
   * \param plevel The priority level of this set;
   *            a higher priority level attribute
   *            will replace a lower priority level attribute.
   *            Must be greater than 0.
   *
   *  The same attribute cannot be set twice with the same plevel.
* | |
* \tparam ValueType The type of the value to be set. | |
*/ | |
template<typename ValueType> | |
inline Op& set_attr(const std::string& attr_name, // NOLINT(*) | |
const ValueType& value, | |
int plevel = 10); | |
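  // A minimal sketch of the plevel semantics above (illustrative; FInferShape is the
  // attribute type used elsewhere in this file, DefaultShape/BetterShape are hypothetical):
  //
  //   NNVM_REGISTER_OP(myop)
  //   .set_attr<FInferShape>("FInferShape", DefaultShape);   // default plevel = 10
  //   // a later registration with a higher plevel overrides the earlier value
  //   NNVM_REGISTER_OP(myop)
  //   .set_attr<FInferShape>("FInferShape", BetterShape, 11);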
/*! | |
* \brief Add another alias to this operator. | |
* The same Op can be queried with Op::Get(alias) | |
* \param alias The alias of the operator. | |
* \return reference to self. | |
*/ | |
Op& add_alias(const std::string& alias); // NOLINT(*) | |
/*! | |
   * \brief Include all the attributes from a registered op group.
* \param group_name The name of the group. | |
* \return reference to self. | |
* | |
* \sa NNVM_REGISTER_OP_GROUP | |
*/ | |
Op& include(const std::string& group_name); | |
/*! | |
* \brief Get an Op for a given operator name. | |
* Will raise an error if the op has not been registered. | |
* \param op_name Name of the operator. | |
* \return Pointer to a Op, valid throughout program lifetime. | |
*/ | |
static const Op* Get(const std::string& op_name); | |
/*! | |
* \brief Get additional registered attribute about operators. | |
* If nothing has been registered, an empty OpMap will be returned. | |
* \param attr_name The name of the attribute. | |
* \return An OpMap of specified attr_name. | |
* \tparam ValueType The type of the attribute. | |
*/ | |
template<typename ValueType> | |
static const OpMap<ValueType>& GetAttr(const std::string& attr_name); | |
private: | |
template<typename ValueType> | |
friend class OpMap; | |
friend class OpGroup; | |
friend class dmlc::Registry<Op>; | |
// Program internal unique index of operator. | |
// Used to help index the program. | |
uint32_t index_{0}; | |
// internal constructor | |
Op(); | |
// get const reference to certain attribute | |
static const any* GetAttrMap(const std::string& key); | |
// update the attribute OpMap | |
static void UpdateAttrMap(const std::string& key, | |
std::function<void(any*)> updater); | |
// add a trigger based on tag matching on certain tag attribute | |
// This will apply trigger on all the op such that | |
// include the corresponding group. | |
// The trigger will also be applied to all future registrations | |
// that calls include | |
static void AddGroupTrigger(const std::string& group_name, | |
std::function<void(Op*)> trigger); | |
}; | |
/*! | |
* \brief A map data structure that takes Op* as key | |
* and returns ValueType | |
* \tparam ValueType The type of the value stored in map. | |
*/ | |
template<typename ValueType> | |
class OpMap { | |
public: | |
/*! | |
* \brief get the corresponding value element at op | |
* \param op The key to the map | |
* \return the const reference to the content value. | |
*/ | |
inline const ValueType& operator[](const Op* op) const; | |
/*! | |
* \brief get the corresponding value element at op with default value. | |
* \param op The key to the map | |
* \param def_value The default value when the key does not exist. | |
* \return the const reference to the content value. | |
*/ | |
inline const ValueType& get(const Op* op, const ValueType& def_value) const; | |
/*! | |
* \brief Check if the map has op as key. | |
* \param op The key to the map | |
* \return 1 if op is contained in map, 0 otherwise. | |
*/ | |
inline int count(const Op* op) const; | |
private: | |
friend class Op; | |
// internal attribute name | |
std::string attr_name_; | |
// internal data | |
std::vector<std::pair<ValueType, int> > data_; | |
OpMap() = default; | |
}; | |
/*! | |
* \brief auxiliary data structure used to | |
* set attributes to a group of operators | |
*/ | |
class OpGroup { | |
public: | |
/*! \brief the tag key to be matched */ | |
std::string group_name; | |
/*! | |
* \brief Register additional attributes to operator group. | |
* \param attr_name The name of the attribute. | |
* \param value The value to be set. | |
   * \param plevel The priority level of this set;
   *            a higher priority level attribute
   *            will replace a lower priority level attribute.
   *            Must be greater than 0.
   *
   *  The same attribute cannot be set twice with the same plevel.
* | |
* \tparam ValueType The type of the value to be set. | |
*/ | |
template<typename ValueType> | |
inline OpGroup& set_attr(const std::string& attr_name, // NOLINT(*) | |
const ValueType& value, | |
int plevel = 1); | |
}; | |
// internal macros used to define the registration entry variables
#define NNVM_REGISTER_VAR_DEF(OpName) \ | |
static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName | |
#define NNVM_REGISTER_GVAR_DEF(TagName) \ | |
static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName | |
/*! | |
* \def NNVM_REGISTER_OP | |
* \brief Register a new operator, or set attribute of the corresponding op. | |
* | |
* \param OpName The name of registry | |
* | |
* \code | |
* | |
* NNVM_REGISTER_OP(add) | |
* .describe("add two inputs together") | |
* .set_num_inputs(2) | |
* .set_attr<OpKernel>("gpu_kernel", AddKernel); | |
* | |
* \endcode | |
*/ | |
#define NNVM_REGISTER_OP(OpName) \ | |
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ | |
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName) | |
/*! | |
* \def NNVM_REGISTER_OP_GROUP | |
* \brief Register attribute to a group of operators. | |
* These attributes will be registered to Op that include the group. | |
* | |
* \param GroupName The name of the group. | |
* | |
* \code | |
* | |
* NNVM_REGISTER_OP(add) | |
* .include("ElementwiseOpAttr"); | |
* | |
* // register same attributes to all the ops that include the group | |
* NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr) | |
* .set_attr<FInferShape>("FInferShape", ElementwiseInferShape); | |
* | |
* NNVM_REGISTER_OP(mul) | |
* .include("ElementwiseOpAttr"); | |
* | |
* \endcode | |
*/ | |
#define NNVM_REGISTER_OP_GROUP(GroupName) \ | |
DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = \ | |
::nnvm::OpGroup {#GroupName} | |
// implementations of template functions after this. | |
// member function of Op | |
template<typename ValueType> | |
inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) { | |
const any* ref = GetAttrMap(key); | |
if (ref == nullptr) { | |
// update the attribute map of the key by creating new empty OpMap | |
UpdateAttrMap(key, [key](any* pmap) { | |
// use callback so it is in lockscope | |
if (pmap->empty()) { | |
OpMap<ValueType> pm; | |
pm.attr_name_ = key; | |
*pmap = std::move(pm); | |
} | |
}); | |
ref = GetAttrMap(key); | |
} | |
return nnvm::get<OpMap<ValueType> >(*ref); | |
} | |
template<typename ValueType> | |
inline Op& Op::set_attr( // NOLINT(*) | |
const std::string& attr_name, | |
const ValueType& value, | |
int plevel) { | |
CHECK_GT(plevel, 0) | |
<< "plevel in set_attr must be greater than 0"; | |
// update the attribute map of the key by creating new empty if needed. | |
UpdateAttrMap(attr_name, | |
[this, attr_name, value, plevel](any* pmap) { | |
// the callback is in lockscope so is threadsafe. | |
if (pmap->empty()) { | |
OpMap<ValueType> pm; | |
pm.attr_name_ = attr_name; | |
*pmap = std::move(pm); | |
} | |
CHECK(pmap->type() == typeid(OpMap<ValueType>)) | |
<< "Attribute " << attr_name | |
<< " of operator " << this->name | |
<< " is registered as inconsistent types" | |
<< " previously " << pmap->type().name() | |
<< " current " << typeid(OpMap<ValueType>).name(); | |
std::vector<std::pair<ValueType, int> >& vec = | |
nnvm::get<OpMap<ValueType> >(*pmap).data_; | |
          // resize the value vector if needed.
if (vec.size() <= index_) { | |
vec.resize(index_ + 1, | |
std::make_pair(ValueType(), 0)); | |
} | |
std::pair<ValueType, int>& p = vec[index_]; | |
CHECK(p.second != plevel) | |
<< "Attribute " << attr_name | |
<< " of operator " << this->name | |
<< " is already registered with same plevel=" << plevel; | |
if (p.second < plevel) { | |
vec[index_] = std::make_pair(value, plevel); | |
} | |
}); | |
return *this; | |
} | |
inline Op& Op::describe(const std::string& descr) { // NOLINT(*) | |
this->description = descr; | |
return *this; | |
} | |
inline Op& Op::add_argument(const std::string &name, | |
const std::string &type, | |
const std::string &description) { | |
arguments.push_back({name, type, type, description}); | |
return *this; | |
} | |
inline Op& Op::add_arguments(const std::vector<ParamFieldInfo> &args) { | |
this->arguments.insert(arguments.end(), args.begin(), args.end()); | |
return *this; | |
} | |
inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*) | |
this->num_inputs = n; | |
return *this; | |
} | |
inline Op& Op::set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*) | |
this->get_num_inputs = fn; | |
return *this; | |
} | |
inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*) | |
this->num_outputs = n; | |
return *this; | |
} | |
inline Op& Op::set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*) | |
this->get_num_outputs = fn; | |
return *this; | |
} | |
inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) { // NOLINT(*) | |
this->attr_parser = fn; | |
return *this; | |
} | |
// member functions of OpMap | |
template<typename ValueType> | |
inline int OpMap<ValueType>::count(const Op* op) const { | |
if (op == nullptr) return 0; | |
const uint32_t idx = op->index_; | |
return idx < data_.size() ? (data_[idx].second != 0) : 0; | |
} | |
template<typename ValueType> | |
inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const { | |
CHECK(op != nullptr); | |
const uint32_t idx = op->index_; | |
CHECK(idx < data_.size() && data_[idx].second) | |
<< "Attribute " << attr_name_ | |
<< " has not been registered for Operator " << op->name; | |
return data_[idx].first; | |
} | |
template<typename ValueType> | |
inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def_value) const { | |
if (op == nullptr) return def_value; | |
const uint32_t idx = op->index_; | |
if (idx < data_.size() && data_[idx].second) { | |
return data_[idx].first; | |
} else { | |
return def_value; | |
} | |
} | |
template<typename ValueType> | |
inline OpGroup& OpGroup::set_attr(const std::string& attr_name, | |
const ValueType& value, | |
int plevel) { | |
auto trigger = [attr_name, value, plevel](Op* op) { | |
op->set_attr<ValueType>(attr_name, value, plevel); | |
}; | |
Op::AddGroupTrigger(group_name, trigger); | |
return *this; | |
} | |
} // namespace nnvm | |
#endif // NNVM_OP_H_ | |
//===== EXPANDED: ../nnvm/include/nnvm/op.h ===== | |
//===== EXPANDING: ../nnvm/include/nnvm/tuple.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file tuple.h | |
* \brief Data structure Tuple and TShape to store dynamic sized shapes. | |
*/ | |
#ifndef NNVM_TUPLE_H_ | |
#define NNVM_TUPLE_H_ | |
namespace nnvm { | |
/*! \brief data type to store array index */ | |
typedef uint32_t index_t; | |
/*! | |
 * \brief A dynamically sized array data structure that is optimized for storing
 *        a small number of elements of the same type.
 *
 *  Data will be stored on the stack when the number of elements is small.
 *  It is suitable for holding the shape of a Tensor.
* | |
* \tparam ValueType The type of data stored inside tuple. | |
* \sa TShape | |
*/ | |
template<typename ValueType> | |
class Tuple { | |
public: | |
// Tuple requires the content to be simple data type. | |
static_assert(std::is_pod<ValueType>::value, | |
"Tuple only support simple data type like int"); | |
/*! \brief default constructor */ | |
Tuple() = default; | |
/*! \brief destructor */ | |
inline ~Tuple() { | |
delete [] data_heap_; | |
} | |
/*! | |
* \brief copy constructor from another tuple | |
* \param s the source tuple | |
*/ | |
inline Tuple(const Tuple<ValueType>& s) { | |
this->assign(s.begin(), s.end()); | |
} | |
/*! | |
* \brief constructor from initializer list | |
* \param init the initializer_list | |
*/ | |
inline Tuple(std::initializer_list<ValueType> init) { | |
this->assign(init.begin(), init.end()); | |
} | |
/*! | |
* \brief move constructor from Tuple | |
* \param src the source shape | |
*/ | |
inline Tuple(Tuple<ValueType>&& src) { // NOLINT(*) | |
this->swap(src); | |
} | |
/*! | |
* \brief construct the Tuple from content of iterator | |
   * \param begin the beginning of the iterator
   * \param end the end of the iterator
* \tparam RandomAccessIterator iterator type | |
*/ | |
template<typename RandomAccessIterator> | |
inline Tuple(RandomAccessIterator begin, | |
RandomAccessIterator end) { | |
this->assign(begin, end); | |
} | |
/*! | |
* \brief Assign content to tuple from iterator. | |
   * \param begin the beginning of the iterator
   * \param end the end of the iterator
* \tparam RandomAccessIterator iterator type | |
*/ | |
template<typename RandomAccessIterator> | |
inline void assign(RandomAccessIterator begin, | |
RandomAccessIterator end) { | |
this->SetDim(end - begin); | |
std::copy(begin, end, this->begin()); | |
} | |
/*! | |
* \brief Swap current object with other | |
* \param other another object to be swapped. | |
*/ | |
inline void swap(Tuple<ValueType>& other) { // NOLINT(*) | |
std::swap(ndim_, other.ndim_); | |
std::swap(num_heap_allocated_, other.num_heap_allocated_); | |
std::swap(data_stack_, other.data_stack_); | |
std::swap(data_heap_, other.data_heap_); | |
} | |
/*! | |
* \brief assignment from another tuple. | |
* \param src source tuple | |
* \return reference of self | |
*/ | |
inline Tuple<ValueType>& operator=(const Tuple<ValueType>& src) { | |
this->assign(src.begin(), src.end()); | |
return *this; | |
} | |
/*! | |
* \brief assignment from rvalue of another tuple. | |
* \param src source tuple | |
* \return reference of self | |
*/ | |
inline Tuple<ValueType>& operator=(Tuple<ValueType>&& src) { | |
Tuple<ValueType>(std::move(src)).swap(*this); | |
return *this; | |
} | |
/*! | |
* \brief assignment from initializer list | |
* \param init the source initializer list | |
* \return reference of self | |
*/ | |
inline Tuple<ValueType> &operator=(std::initializer_list<ValueType> init) { | |
this->assign(init.begin(), init.end()); | |
return *this; | |
} | |
/*! | |
* \return whether two tuple equals | |
* \param s the tuple to compare against | |
*/ | |
inline bool operator==(const Tuple<ValueType> &s) const { | |
if (ndim_ != s.ndim_) return false; | |
return std::equal(begin(), end(), s.begin()); | |
} | |
/*! | |
   * \return whether the two tuples are not equal
* \param s the tuple to compare against | |
*/ | |
inline bool operator!=(const Tuple<ValueType> &s) const { | |
return !(*this == s); | |
} | |
/*! \return the begin data pointer to content of the tuple */ | |
inline const ValueType *begin() const { | |
return ndim_ <= kStackCache ? data_stack_ : data_heap_; | |
} | |
/*! \return the begin data pointer to content of the tuple */ | |
inline ValueType *begin() { | |
return ndim_ <= kStackCache ? data_stack_ : data_heap_; | |
} | |
  /*! \return the data pointer to the end of the tuple */
inline const ValueType* end() const { | |
return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_); | |
} | |
  /*! \return the data pointer to the end of the tuple */
inline ValueType* end() { | |
return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_); | |
} | |
/*! \return number of dimension of the tuple */ | |
inline index_t ndim() const { | |
return ndim_; | |
} | |
/*! | |
* \brief get corresponding index | |
* \param i dimension index | |
* \return the corresponding dimension size | |
*/ | |
inline ValueType& operator[](index_t i) { | |
return begin()[i]; | |
} | |
/*! | |
* \brief get corresponding index | |
* \param i dimension index | |
* \return the corresponding dimension size | |
*/ | |
inline const ValueType& operator[](index_t i) const { | |
return begin()[i]; | |
} | |
/*! | |
* \brief Save Tuple to JSON. | |
* \param writer JSONWriter | |
*/ | |
inline void Save(dmlc::JSONWriter* writer) const { | |
std::vector<ValueType> tmp(begin(), end()); | |
writer->Write(tmp); | |
} | |
/*! | |
* \brief Load Tuple from JSON. | |
* \param reader JSONReader | |
*/ | |
inline void Load(dmlc::JSONReader* reader) { | |
std::vector<ValueType> tmp; | |
reader->Read(&tmp); | |
this->assign(tmp.begin(), tmp.end()); | |
} | |
/*! | |
* \brief allow output string of tuple to ostream | |
* \param os the output stream | |
* \param t the tuple | |
* \return the ostream | |
*/ | |
friend std::ostream &operator<<(std::ostream &os, const Tuple<ValueType> &t) { | |
os << '('; | |
const ValueType* begin = t.begin(); | |
const ValueType* end = t.end(); | |
for (const ValueType* it = begin; it != end; ++it) { | |
if (it != begin) os << ','; | |
os << *it; | |
} | |
// python style tuple | |
if (t.ndim() == 1) os << ','; | |
os << ')'; | |
return os; | |
} | |
/*! | |
* \brief read tuple from the istream | |
* \param is the input stream | |
* \param t The tuple | |
* \return the istream | |
*/ | |
friend std::istream &operator>>(std::istream &is, Tuple<ValueType> &t) { | |
// get ( | |
while (true) { | |
char ch = is.peek(); | |
if (isdigit(ch)) { | |
ValueType idx; | |
if (is >> idx) { | |
t.assign(&idx, &idx + 1); | |
} | |
return is; | |
} | |
is.get(); | |
if (ch == '(' || ch == '[') break; | |
if (!isspace(ch)) { | |
is.setstate(std::ios::failbit); | |
return is; | |
} | |
} | |
// Handle empty tuple | |
while (isspace(is.peek())) { | |
is.get(); | |
} | |
if (is.peek() == ')') { | |
is.get(); | |
return is; | |
} | |
// Handle non-empty tuple | |
ValueType idx; | |
std::vector<ValueType> tmp; | |
while (is >> idx) { | |
tmp.push_back(idx); | |
char ch; | |
do { | |
ch = is.get(); | |
} while (isspace(ch)); | |
if (std::is_integral<ValueType>::value && ch == 'L') { | |
ch = is.get(); | |
} | |
if (ch == ',') { | |
while (true) { | |
ch = is.peek(); | |
if (isspace(ch)) { | |
is.get(); continue; | |
} | |
if (ch == ')' || ch == ']') { | |
is.get(); break; | |
} | |
break; | |
} | |
if (ch == ')' || ch == ']') break; | |
} else if (ch == ')' || ch == ']') { | |
break; | |
} else { | |
is.setstate(std::ios::failbit); | |
return is; | |
} | |
} | |
t.assign(tmp.begin(), tmp.end()); | |
return is; | |
} | |
/*! | |
* \brief save the content into binary stream | |
* \param strm the output stream | |
   * \tparam TStream any stream type that has a write method
*/ | |
template<typename TStream> | |
inline void Save(TStream *strm) const { | |
strm->Write(&ndim_, sizeof(ndim_)); | |
strm->Write(begin(), sizeof(ValueType) * ndim_); | |
} | |
/*! | |
* \brief load the content from binary stream | |
   * \param strm the input stream
   * \tparam TStream any stream type that has a read method
* \return whether the load is successful | |
*/ | |
template<typename TStream> | |
inline bool Load(TStream *strm) { | |
if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false; | |
this->SetDim(ndim_); | |
size_t nread = sizeof(ValueType) * ndim_; | |
if (strm->Read(begin(), nread) != nread) return false; | |
return true; | |
} | |
protected: | |
// stack cache size | |
static const uint32_t kStackCache = 4; | |
/*! \brief number of dimension of the tuple */ | |
index_t ndim_{0}; | |
/*! \brief number of cells allocated in data_heap_ */ | |
index_t num_heap_allocated_{0}; | |
/*! \brief in stack space used to store shape when it is small */ | |
ValueType data_stack_[kStackCache]; | |
/*! \brief space to store shape when dimension is big*/ | |
ValueType* data_heap_{nullptr}; | |
// internal function to change the dimension | |
inline void SetDim(index_t dim) { | |
if (dim > kStackCache && | |
dim > num_heap_allocated_) { | |
delete [] data_heap_; | |
data_heap_ = new ValueType[dim]; | |
num_heap_allocated_ = dim; | |
} | |
ndim_ = dim; | |
} | |
}; | |
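/*
 * A minimal usage sketch of Tuple, kept as a comment; the element values and
 * stream contents below are illustrative assumptions, not part of the API:
 *
 *   nnvm::Tuple<int> t{2, 3, 4};      // small, stored in the stack cache
 *   t = {2, 3, 4, 5, 6};              // ndim() > kStackCache, spills to the heap
 *   std::ostringstream os;
 *   os << t;                          // writes "(2,3,4,5,6)"
 *   std::istringstream is("(7, 8, 9)");
 *   nnvm::Tuple<int> u;
 *   is >> u;                          // parses the python-style tuple back
 */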
/*! | |
* \brief A Shape class that is used to represent shape of each tensor. | |
*/ | |
class TShape : public Tuple<index_t> { | |
public: | |
/*! \brief default constructor */ | |
TShape() = default; | |
/*! | |
   * constructor to construct a shape with all dimensions set to 1.
* \param ndim the number of dimension | |
*/ | |
inline TShape(index_t ndim) { // NOLINT(*) | |
this->SetDim(ndim); | |
std::fill_n(begin(), ndim, 1); | |
} | |
/*! | |
* \brief copy constructor of TShape | |
* \param s source shape. | |
*/ | |
inline TShape(const Tuple<index_t>& s) { // NOLINT(*) | |
this->assign(s.begin(), s.end()); | |
} | |
/*! | |
* \brief constructor from initializer list | |
* \param init the initializer_list | |
*/ | |
inline TShape(std::initializer_list<index_t> init) { | |
this->assign(init.begin(), init.end()); | |
} | |
/*! | |
* \brief move constructor. | |
* \param s source shape. | |
*/ | |
inline TShape(Tuple<index_t>&& s) { // NOLINT(*) | |
this->swap(s); | |
} | |
/*! | |
* \brief construct the Tuple from content of iterator | |
   * \param begin the beginning of the iterator
   * \param end the end of the iterator
* \tparam RandomAccessIterator iterator type | |
*/ | |
template<typename RandomAccessIterator> | |
inline TShape(RandomAccessIterator begin, | |
RandomAccessIterator end) { | |
this->assign(begin, end); | |
} | |
/*! | |
* \brief assignment function from tshape | |
* \param src source shape. | |
* \return self. | |
*/ | |
inline TShape& operator=(const Tuple<index_t>& src) { | |
this->assign(src.begin(), src.end()); | |
return *this; | |
} | |
/*! | |
* \brief move assignment function from tshape | |
* \param src source shape. | |
* \return self. | |
*/ | |
inline TShape& operator=(Tuple<index_t>&& src) { // NOLINT(*) | |
TShape(std::move(src)).swap(*this); // NOLINT(*) | |
return *this; | |
} | |
/*! \return total number of elements in the shape */ | |
inline size_t Size() const { | |
size_t size = 1; | |
const index_t* start = begin(), *fin = end(); | |
for (const index_t* it = start; it != fin; ++it) { | |
size *= *it; | |
} | |
return size; | |
} | |
/*! | |
   * \return the product of the shape dimensions in [dimstart, dimend)
* \param dimstart start dimension | |
* \param dimend end dimension | |
*/ | |
inline index_t ProdShape(int dimstart, int dimend) const { | |
index_t num = 1; | |
const index_t *d = this->data(); | |
for (int i = dimstart; i < dimend; ++i) { | |
num *= d[i]; | |
} | |
return num; | |
} | |
/*! \return the begin data pointer to content of the tuple */ | |
inline const index_t *data() const { | |
return begin(); | |
} | |
/*! \return the begin data pointer to content of the tuple */ | |
inline index_t *data() { | |
return begin(); | |
} | |
#ifdef MSHADOW_XINLINE | |
template<int dim> | |
inline TShape(const mshadow::Shape<dim> &s) {// NOLINT(*) | |
this->assign(s.shape_, s.shape_ + dim); | |
} | |
template<int dim> | |
inline TShape(mshadow::Shape<dim> &&s) {// NOLINT(*) | |
this->assign(s.shape_, s.shape_ + dim); | |
} | |
/*! | |
* \brief assignment from shape | |
* \param shape source shape | |
* \tparam dim shape dimension | |
* \return reference of self | |
*/ | |
template<int dim> | |
inline TShape &operator=(const mshadow::Shape<dim> &shape) { | |
this->assign(shape.shape_, shape.shape_ + dim); | |
return *this; | |
} | |
/*! | |
* \brief get the shape of tensor specifying dim | |
* \return the shape requested | |
* \tparam dim dimension of the tensor | |
*/ | |
template<int dim> | |
inline mshadow::Shape<dim> get() const { | |
CHECK_EQ(dim, ndim()) | |
<< "dimension do not match target dimension " << dim << " vs " << ndim(); | |
const index_t *d = this->data(); | |
mshadow::Shape<dim> s; | |
for (int i = 0; i < dim; ++i) { | |
s[i] = d[i]; | |
} | |
return s; | |
} | |
/*! | |
   * flatten all dimensions except the last into the first dimension, return a 2D shape
* \return the flat 2d shape | |
*/ | |
inline mshadow::Shape<2> FlatTo2D(void) const { | |
mshadow::Shape<2> s; | |
if (ndim() == 0) return mshadow::Shape2(0, 0); | |
const index_t *d = this->data(); | |
s.shape_[1] = d[ndim() - 1]; | |
index_t ymax = 1; | |
for (index_t i = 1; i < ndim(); ++i) { | |
ymax *= d[i - 1]; | |
} | |
s.shape_[0] = ymax; | |
return s; | |
} | |
/*! | |
* flatten the shape into three parts: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim) | |
* \param axis_begin The beginning axis specified. | |
* \param axis_end The ending axis specified. | |
* \return the flat 3d shape | |
*/ | |
inline mshadow::Shape<3> FlatTo3D(index_t axis_begin, index_t axis_end) const { | |
CHECK(axis_end >= axis_begin); | |
mshadow::Shape<3> s; | |
if (ndim() == 0) return mshadow::Shape3(0, 0, 0); | |
const index_t *d = this->data(); | |
s.shape_[0] = 1; | |
s.shape_[1] = 1; | |
s.shape_[2] = 1; | |
for (index_t i = 0; i < axis_begin; ++i) { | |
s.shape_[0] *= d[i]; | |
} | |
for (index_t i = axis_begin; i <= axis_end; ++i) { | |
s.shape_[1] *= d[i]; | |
} | |
for (index_t i = axis_end + 1; i < ndim(); ++i) { | |
s.shape_[2] *= d[i]; | |
} | |
return s; | |
} | |
/*! | |
   * flatten the axes before and after the specified axis, so the shape becomes 3D
* \param axis The axis specified. | |
* \return the flat 3d shape | |
*/ | |
inline mshadow::Shape<3> FlatTo3D(index_t axis) const { | |
return FlatTo3D(axis, axis); | |
} | |
inline bool operator==(const TShape &s) const { | |
if (ndim() != s.ndim()) return false; | |
return std::equal(begin(), end(), s.begin()); | |
} | |
inline bool operator!=(const TShape &s) const { | |
return !(*this == s); | |
} | |
/*! | |
   * \return whether the two shapes are equal
* \param s the shape to compare against | |
* \tparam dim dimension of the shape | |
*/ | |
template<int dim> | |
inline bool operator==(const mshadow::Shape<dim> &s) const { | |
if (ndim_ != dim) return false; | |
const index_t *d = dim <= kStackCache ? data_stack_ : data_heap_; | |
for (index_t i = 0; i < dim; ++i) { | |
if (d[i] != s.shape_[i]) return false; | |
} | |
return true; | |
} | |
/*! | |
   * \return whether the two shapes are not equal
* \param s the shape to compare against | |
* \tparam dim dimension of the shape | |
*/ | |
template<int dim> | |
inline bool operator!=(const mshadow::Shape<dim> &s) const { | |
return !(*this == s); | |
} | |
#endif | |
}; | |
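/*
 * A minimal usage sketch of TShape; the shape values are illustrative only:
 *
 *   nnvm::TShape s{2, 3, 4};
 *   size_t total = s.Size();                  // 2 * 3 * 4 = 24
 *   nnvm::index_t tail = s.ProdShape(1, 3);   // 3 * 4 = 12
 *   // With MSHADOW_XINLINE defined, s.FlatTo2D() gives mshadow::Shape2(6, 4):
 *   // all leading dimensions collapse into the first axis.
 */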
} // namespace nnvm | |
#endif // NNVM_TUPLE_H_ | |
//===== EXPANDED: ../nnvm/include/nnvm/tuple.h ===== | |
//===== EXPANDING: ../nnvm/include/nnvm/symbolic.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file symbolic.h | |
* \brief Symbolic graph construction API | |
* | |
* This API is optional, but useful to allow user | |
* to construct NNVM Graph easily, and quickly create | |
* front-end host languages. | |
*/ | |
#ifndef NNVM_SYMBOLIC_H_ | |
#define NNVM_SYMBOLIC_H_ | |
//===== EXPANDING: ../nnvm/include/nnvm/node.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file node.h | |
* \brief Graph node data structure. | |
*/ | |
#ifndef NNVM_NODE_H_ | |
#define NNVM_NODE_H_ | |
namespace nnvm { | |
// Forward declare node. | |
class Node; | |
/*! | |
 * \brief we always use NodePtr as a reference pointer
 *  to the node, so this alias can be changed if needed.
* | |
* By default, NodePtr is a std::shared_ptr of node | |
*/ | |
using NodePtr = std::shared_ptr<Node>; | |
/*! \brief an entry that represents output data from a node */ | |
struct NodeEntry { | |
/*! \brief the source node of this data */ | |
NodePtr node; | |
/*! \brief index of output from the source. */ | |
uint32_t index; | |
/*! | |
* \brief version of input Variable. | |
* This field can only be nonzero when this->node is a Variable node. | |
   *  version is increased by one each time a Variable gets composed into a mutation Op.
   *  This information can be helpful for deciding the order of operations when a sequence of mutations happens.
*/ | |
uint32_t version; | |
}; | |
/*! | |
* \brief The attributes of the current operation node. | |
 *  Usually these are additional parameters like axis.
*/ | |
struct NodeAttrs { | |
/*! | |
* \brief The operator this node uses. | |
   *  For a placeholder variable, op == nullptr.
*/ | |
const Op *op{nullptr}; | |
/*! \brief name of the node */ | |
std::string name; | |
/*! \brief Vector representation of positional attributes */ | |
std::vector<double> scalars; | |
/*! \brief The dictionary representation of attributes */ | |
std::unordered_map<std::string, std::string> dict; | |
/*! | |
* \brief A parsed version of attributes, | |
* This is generated if OpProperty.attr_parser is registered. | |
* The object can be used to quickly access attributes. | |
*/ | |
any parsed; | |
}; | |
/*! | |
* \brief Node represents an operation in a computation graph. | |
*/ | |
class Node { | |
public: | |
/*! \brief The attributes in the node. */ | |
NodeAttrs attrs; | |
/*! \brief inputs to this node */ | |
std::vector<NodeEntry> inputs; | |
/*! | |
* \brief Optional control flow dependencies | |
   *  These give operations that must be performed before this operation.
*/ | |
std::vector<NodePtr> control_deps; | |
/*! \brief destructor of node */ | |
~Node(); | |
/*! \return operator in this node */ | |
inline const Op* op() const; | |
/*! | |
* \brief return whether node is placeholder variable. | |
* This is equivalent to op == nullptr | |
* \return whether node is placeholder input variable | |
*/ | |
inline bool is_variable() const; | |
/*! \return number of outputs from this node */ | |
inline uint32_t num_outputs() const; | |
/*! \return number of inputs from this node */ | |
inline uint32_t num_inputs() const; | |
/*! | |
* \brief create a new empty shared_ptr of Node. | |
* \return a created empty node. | |
*/ | |
static NodePtr Create(); | |
}; | |
// implementation of functions. | |
inline const Op* Node::op() const { | |
return this->attrs.op; | |
} | |
inline bool Node::is_variable() const { | |
return this->op() == nullptr; | |
} | |
inline uint32_t Node::num_outputs() const { | |
if (is_variable()) return 1; | |
if (this->op()->get_num_outputs == nullptr) { | |
return this->op()->num_outputs; | |
} else { | |
return this->op()->get_num_outputs(this->attrs); | |
} | |
} | |
inline uint32_t Node::num_inputs() const { | |
if (is_variable()) return 1; | |
if (this->op()->get_num_inputs == nullptr) { | |
return this->op()->num_inputs; | |
} else { | |
return this->op()->get_num_inputs(this->attrs); | |
} | |
} | |
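/*
 * A minimal sketch of wiring nodes together by hand. Op::Get comes from op.h
 * (not shown in this expansion) and the operator name "__add__" is an
 * illustrative assumption:
 *
 *   NodePtr x = Node::Create();
 *   x->attrs.name = "x";                  // op stays nullptr, so is_variable()
 *   NodePtr add = Node::Create();
 *   add->attrs.op = Op::Get("__add__");
 *   add->attrs.name = "add0";
 *   add->inputs.push_back(NodeEntry{x, 0, 0});
 */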
} // namespace nnvm | |
#endif // NNVM_NODE_H_ | |
//===== EXPANDED: ../nnvm/include/nnvm/node.h ===== | |
namespace nnvm { | |
/*! | |
 * \brief Symbol is a helper class used to represent the operator node in a Graph.
 *
 *  Symbol acts as an interface for building graphs from different components
 *  like Variable, Functor and Group. Symbol is also exported to the python front-end
 *  (while Graph is not) to enable quick testing and deployment. Conceptually,
 *  a symbol is the final operation of a graph and thus includes all the information
 *  required (the graph) to evaluate its output value.
*/ | |
class Symbol { | |
public: | |
/*! \brief option passed to ListAttr */ | |
enum ListAttrOption { | |
/*! \brief recursively list all attributes */ | |
kRecursive = 0, | |
/*! \brief only list attributes in current node */ | |
kShallow = 1 | |
}; | |
/*! \brief option passed to ListInputNames */ | |
enum ListInputOption { | |
/*! \brief list all the arguments */ | |
kAll = 0, | |
/*! \brief list only read only arguments */ | |
kReadOnlyArgs = 1, | |
/*! | |
* \brief List auxiliary states that can be mutated by the graph. | |
* This excludes the ReadOnly arguments | |
*/ | |
kAuxiliaryStates = 2 | |
}; | |
/*! \brief output entries contained in the symbol */ | |
std::vector<NodeEntry> outputs; | |
/*! | |
* \brief Copy the symbol. | |
* \return A deep copy of this symbol. | |
*/ | |
Symbol Copy() const; | |
/*! | |
* \brief Print the symbol info to output stream. | |
* \param os The output stream to print to. | |
*/ | |
void Print(std::ostream &os) const; // NOLINT(*) | |
/*! | |
* \brief Get the index-th element from the returned tuple. | |
* \param index Index of multi output. | |
* \return The symbol corresponds to the indexed element. | |
*/ | |
Symbol operator[] (size_t index) const; | |
/*! | |
* \brief List the input variable nodes. | |
* | |
* The order of the returned list is the same as the order of the input list to `operator()`. | |
* | |
* \param option The options to list the arguments. | |
* \return The arguments list of this symbol, they can be either named or unnamed (empty string). | |
* \sa ListInputOption | |
*/ | |
std::vector<NodePtr> ListInputs(ListInputOption option) const; | |
/*! | |
* \brief List the input names. | |
* | |
* The order of the returned list is the same as the order of the input list to `operator()`. | |
* | |
* \param option The options to list the arguments. | |
* \return The arguments list of this symbol, they can be either named or unnamed (empty string). | |
* \sa ListInputOption | |
*/ | |
std::vector<std::string> ListInputNames(ListInputOption option) const; | |
/*! | |
* \brief List the names of outputs for this symbol. | |
* | |
* For normal operators, it is usually symbol node name + "_output". | |
* | |
* \return get the descriptions of outputs for this symbol. | |
*/ | |
std::vector<std::string> ListOutputNames() const; | |
/*! | |
* \brief Compose the symbol with arguments, this changes the current symbol. | |
   * The kwargs passed in can be incomplete,
   *
   * in which case the remaining arguments will keep their original names.
* | |
* \param args Positional arguments. | |
* \param kwargs Keyword arguments for the symbol. | |
* \param name Name of returned symbol. | |
*/ | |
void Compose(const array_view<const Symbol*>& args, | |
const std::unordered_map<std::string, const Symbol*>& kwargs, | |
const std::string& name); | |
/*! | |
* \brief Apply the symbol as a function, compose with arguments | |
* | |
* This is equivalent to Copy then Compose. | |
* | |
* \param args Positional arguments for the symbol. | |
* \param kwargs Keyword arguments for the symbol. | |
* \param name Name of returned symbol. | |
* \return A new Symbol which is the composition of current symbol with its arguments. | |
*/ | |
Symbol operator () (const array_view<const Symbol*>& args, | |
const std::unordered_map<std::string, const Symbol*>& kwargs, | |
const std::string& name) const; | |
/*! | |
   * \brief Add control flow dependencies to the operators in symbols.
   *
   *  For a grouped symbol, an error will be raised. This mutates the current symbolic Node.
* | |
* \param src The symbols to depend on. | |
*/ | |
void AddControlDeps(const Symbol& src); | |
  /*!
* \brief Get all the internal nodes of the symbol. | |
* \return symbol A new symbol whose output contains all the outputs of the symbols | |
* including input variables and intermediate outputs. | |
*/ | |
Symbol GetInternals() const; | |
/*! | |
* \brief Set additional attributes to current node. | |
* | |
* This only works for symbol with outputs from single operators. | |
* For grouped symbol, an error will be raised. | |
* | |
* This function mutates the node's symbol and is not recommended. | |
* | |
* \param attrs The attributes to set. | |
*/ | |
void SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs); | |
/*! | |
* \brief Get attributes from the symbol. | |
* | |
* This only works for symbol with outputs from single operators. | |
* For grouped symbol, an error will be raised. | |
* | |
   * \param key Key of the attribute. When key == "name", it returns the name attribute.
* \param out The output value of the attribute. | |
* \return true If the attribute exists, false if the attribute does not exist. | |
*/ | |
bool GetAttr(const std::string& key, std::string* out) const; | |
/*! | |
* \brief Get attribute dictionary from the symbol. | |
* | |
* For grouped symbol, an error will be raised. | |
* | |
* \param option If recursive flag is set, the attributes of all children are retrieved. | |
   * The name of the symbol will be prepended to each key.
* \return The created attribute. | |
*/ | |
std::unordered_map<std::string, std::string> ListAttrs(ListAttrOption option) const; | |
/*! | |
* \brief Get attribute dictionary from the symbol and all children. | |
* | |
* For grouped symbol, an error will be raised. | |
* | |
* \return The created attribute in format <operator_name, key, value>. | |
*/ | |
std::vector<std::tuple<std::string, std::string, std::string> > | |
ListAttrsRecursive() const; | |
/*! | |
* \brief Create symbolic functor(AtomicSymbol) by given operator and attributes. | |
* \param op The operator. | |
* \param attrs The additional attributes. | |
* \return Symbol that can be used to call compose further. | |
*/ | |
static Symbol CreateFunctor(const Op* op, | |
std::unordered_map<std::string, std::string> attrs); | |
/*! | |
* \brief Create symbol node representing variable. | |
* \param name Name of the variable. | |
* \return The symbol. | |
*/ | |
static Symbol CreateVariable(const std::string& name); | |
/*! | |
* \brief Create equivalence of symbol by grouping the symbols together. | |
* \param symbols A list of symbols to be grouped. | |
* \return The grouped symbol. | |
*/ | |
static Symbol CreateGroup(const std::vector<Symbol>& symbols); | |
}; | |
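/*
 * A minimal sketch of symbolic composition with this class. The operator name
 * and attribute values are illustrative assumptions; Op::Get comes from op.h:
 *
 *   Symbol x = Symbol::CreateVariable("x");
 *   Symbol fc = Symbol::CreateFunctor(Op::Get("FullyConnected"),
 *                                     {{"num_hidden", "128"}});
 *   std::vector<const Symbol*> args = {&x};
 *   fc.Compose(args, {}, "fc1");          // positional compose, names the node
 *   std::vector<std::string> inputs = fc.ListInputNames(Symbol::kAll);
 */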
} // namespace nnvm | |
#endif // NNVM_SYMBOLIC_H_ | |
//===== EXPANDED: ../nnvm/include/nnvm/symbolic.h ===== | |
/*! | |
 * \brief whether to use opencv support
*/ | |
#ifndef MXNET_USE_OPENCV | |
#define MXNET_USE_OPENCV 1 | |
#endif | |
/*! | |
 * \brief whether to use cuda support
*/ | |
#ifndef MXNET_USE_CUDA | |
#define MXNET_USE_CUDA MSHADOW_USE_CUDA | |
#endif | |
/*! | |
 * \brief whether to use cudnn library for convolution
*/ | |
#ifndef MXNET_USE_CUDNN | |
#define MXNET_USE_CUDNN MSHADOW_USE_CUDNN | |
#endif | |
/*! \brief Error message for using gpu when MXNET_USE_CUDA==0 */ | |
#define MXNET_GPU_NOT_ENABLED_ERROR "GPU is not enabled" | |
/*! | |
* \brief define compatible keywords in g++ | |
 * Used to support g++-4.6 and g++-4.7
*/ | |
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__) | |
#if __GNUC__ == 4 && __GNUC_MINOR__ < 8 | |
#error "Currently we need g++ 4.8 or higher to fully support c++11 features" | |
#define override | |
#define final | |
#endif | |
#endif | |
/*! | |
* \brief define dllexport for Visual Studio | |
*/ | |
#ifdef _MSC_VER | |
#ifdef MXNET_EXPORTS | |
#define MXNET_API __declspec(dllexport) | |
#else | |
#define MXNET_API __declspec(dllimport) | |
#endif | |
#else | |
#define MXNET_API | |
#endif | |
/*! | |
* \brief define prediction only | |
*/ | |
#ifndef MXNET_PREDICT_ONLY | |
#define MXNET_PREDICT_ONLY 0 | |
#endif | |
/*! | |
* \brief define operator message for profiler | |
*/ | |
#if MXNET_USE_PROFILER | |
#define PROFILER_MESSAGE(msg) msg | |
#else | |
#define PROFILER_MESSAGE(msg) nullptr | |
#endif | |
/*! \brief major version */ | |
#define MXNET_MAJOR 0 | |
/*! \brief minor version */ | |
#define MXNET_MINOR 9 | |
/*! \brief patch version */ | |
#define MXNET_PATCH 3 | |
/*! \brief mxnet version */ | |
#define MXNET_VERSION (MXNET_MAJOR*10000 + MXNET_MINOR*100 + MXNET_PATCH) | |
/*! \brief helper for making version number */ | |
#define MXNET_MAKE_VERSION(major, minor, patch) ((major)*10000 + (minor)*100 + patch) | |
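/*
 * For example, MXNET_MAJOR=0, MXNET_MINOR=9, MXNET_PATCH=3 makes
 * MXNET_VERSION expand to 0*10000 + 9*100 + 3 = 903, so a compile-time
 * feature check can be written as:
 *
 *   #if MXNET_VERSION >= MXNET_MAKE_VERSION(0, 9, 0)
 *   // code that relies on 0.9+ behaviour
 *   #endif
 */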
/*! | |
* \brief define function name as profiler message | |
*/ | |
#define PROFILER_MESSAGE_FUNCNAME PROFILER_MESSAGE(__FUNCTION__) | |
/*! \brief namespace of mxnet */ | |
namespace mxnet { | |
/*! \brief mxnet cpu */ | |
typedef mshadow::cpu cpu; | |
/*! \brief mxnet gpu */ | |
typedef mshadow::gpu gpu; | |
/*! \brief index type usually use unsigned */ | |
typedef mshadow::index_t index_t; | |
/*! \brief data type that will be used to store ndarray */ | |
typedef mshadow::default_real_t real_t; | |
/*! \brief Shape data structure used to record shape information */ | |
using TShape = nnvm::TShape; | |
/*! \brief operator structure from NNVM */ | |
using Op = nnvm::Op; | |
/*! \brief Context information about the execution environment */ | |
struct Context { | |
/*! \brief Type of device */ | |
enum DeviceType { | |
kCPU = cpu::kDevMask, | |
kGPU = gpu::kDevMask, | |
kCPUPinned = 3 | |
}; | |
/*! \brief the device type we run the op on */ | |
DeviceType dev_type; | |
/*! \brief device id we are going to run it on */ | |
int32_t dev_id; | |
/*! \brief default constructor */ | |
Context() : dev_type(kCPU), dev_id(0) {} | |
/*! | |
* \brief Get corresponding device mask | |
* \return cpu::kDevMask or gpu::kDevMask | |
*/ | |
inline int dev_mask() const { | |
if (dev_type == kCPUPinned) return cpu::kDevMask; | |
return dev_type; | |
} | |
/*! | |
* \brief Comparator, used to enable Context as std::map key. | |
* \param b another context to compare | |
* \return compared result | |
*/ | |
inline bool operator<(const Context &b) const; | |
/*! | |
* \brief check if current context equals another one | |
* \param b another context to compare | |
* \return whether dev mask and id are same | |
*/ | |
inline bool operator==(const Context &b) const { | |
return dev_type == b.dev_type && dev_id == b.dev_id; | |
} | |
/*! | |
* \brief check if current context not equals another one | |
* \param b another context to compare | |
* \return whether they are not the same | |
*/ | |
inline bool operator!=(const Context &b) const { | |
return !(*this == b); | |
} | |
/*! | |
* \brief save the content into binary stream | |
* \param strm the output stream | |
*/ | |
inline void Save(dmlc::Stream *strm) const { | |
strm->Write(&dev_type, sizeof(dev_type)); | |
strm->Write(&dev_id, sizeof(dev_id)); | |
} | |
/*! | |
* \brief load the content from binary stream | |
* \param strm the output stream | |
* \return whether the load is successful | |
*/ | |
inline bool Load(dmlc::Stream *strm) { | |
if (strm->Read(&dev_type, sizeof(dev_type)) != sizeof(dev_type)) return false; | |
if (strm->Read(&dev_id, sizeof(int32_t)) != sizeof(int32_t)) return false; | |
return true; | |
} | |
/*! \brief the maximal device type */ | |
static const int32_t kMaxDevType = 4; | |
/*! \brief the maximal device index */ | |
static const int32_t kMaxDevID = 16; | |
/*! | |
* \brief Create a new context. | |
* \param dev_type device type. | |
* \param dev_id device id. -1 for current device. | |
*/ | |
inline static Context Create(DeviceType dev_type, int32_t dev_id = -1); | |
/*! \return CPU Context */ | |
inline static Context CPU(int32_t dev_id = 0); | |
/*! | |
* Create a GPU context. | |
   * \param dev_id the device id, -1 for the current GPU.
   * \return GPU Context.
*/ | |
inline static Context GPU(int32_t dev_id = -1); | |
/*! | |
* Create a pinned CPU context. | |
   * \param dev_id the device id for the corresponding GPU, -1 for the current GPU.
   * \return Pinned CPU context.
*/ | |
inline static Context CPUPinned(int32_t dev_id = -1); | |
/*! | |
* Create a context from string of the format [cpu|gpu|cpu_pinned](n) | |
* \param str the string pattern | |
* \return Context | |
*/ | |
inline static Context FromString(std::string str); | |
}; | |
/*! | |
* \brief execution time context. | |
* The information needed in runtime for actual execution. | |
*/ | |
struct RunContext { | |
/*! | |
* \brief the stream of the device, can be NULL or Stream<gpu>* in GPU mode | |
*/ | |
void *stream; | |
/*! | |
* \brief get mshadow stream from Context | |
* \return the mshadow stream | |
* \tparam xpu the device type of the stream | |
*/ | |
template<typename xpu> | |
inline mshadow::Stream<xpu>* get_stream() const { | |
return static_cast<mshadow::Stream<xpu>*>(stream); | |
} | |
}; | |
} // namespace mxnet | |
//! \cond Doxygen_Suppress | |
namespace mxnet { | |
// implementing Context | |
inline bool Context::operator<(const Context &b) const { | |
if (dev_type == b.dev_type) { | |
return dev_id < b.dev_id; | |
} else { | |
return dev_type < b.dev_type; | |
} | |
} | |
inline Context Context::Create(DeviceType dev_type, int32_t dev_id) { | |
Context ctx; | |
ctx.dev_type = dev_type; | |
if (dev_id < 0) { | |
ctx.dev_id = 0; | |
#if MXNET_USE_CUDA | |
if (dev_type != kCPU) { | |
CHECK_EQ(cudaGetDevice(&ctx.dev_id), cudaSuccess); | |
} | |
#endif | |
} else { | |
ctx.dev_id = dev_id; | |
} | |
return ctx; | |
} | |
inline Context Context::CPU(int32_t dev_id) { | |
return Create(kCPU, dev_id); | |
} | |
inline Context Context::CPUPinned(int32_t dev_id) { | |
return Create(kCPUPinned, dev_id); | |
} | |
inline Context Context::GPU(int32_t dev_id) { | |
return Create(kGPU, dev_id); | |
} | |
inline Context Context::FromString(std::string str) { | |
Context ret; | |
try { | |
std::string::size_type l = str.find('('); | |
CHECK_NE(l, std::string::npos); | |
std::string::size_type r = str.find(')'); | |
CHECK_EQ(r, str.length()-1); | |
std::string type = str.substr(0, l); | |
int id = std::stoi(str.substr(l+1, r-l-1)); | |
if (type == "cpu") { | |
ret = CPU(id); | |
} else if (type == "gpu") { | |
ret = GPU(id); | |
} else if (type == "cpu_pinned") { | |
ret = CPUPinned(id); | |
} else { | |
LOG(FATAL) << "Invalid context string " << str; | |
} | |
} catch (...) { | |
LOG(FATAL) << "Invalid context string " << str; | |
} | |
return ret; | |
} | |
inline std::ostream& operator<<(std::ostream &out, const Context &ctx) { | |
if (ctx.dev_type == Context::kCPU) { | |
out << "cpu("; | |
} else if (ctx.dev_type == Context::kGPU) { | |
out << "gpu("; | |
} else if (ctx.dev_type == Context::kCPUPinned) { | |
out << "cpu_pinned("; | |
} else { | |
out << "unknown("; | |
} | |
out << ctx.dev_id << ")"; | |
return out; | |
} | |
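/*
 * A minimal sketch of how these helpers fit together; the device ids are
 * illustrative only:
 *
 *   Context cpu0 = Context::CPU();                     // cpu(0)
 *   Context gpu1 = Context::Create(Context::kGPU, 1);  // gpu(1)
 *   Context parsed = Context::FromString("gpu(1)");    // round-trips operator<<
 *   CHECK(parsed == gpu1);
 *   std::map<Context, int> pool;                       // operator< allows map keys
 *   pool[cpu0] = 0;
 */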
} // namespace mxnet | |
//===== EXPANDING: ../include/mxnet/tensor_blob.h ===== | |
/*! | |
* Copyright (c) 2014 by Contributors | |
* \file tensor_blob.h | |
 * \brief TBlob class that holds a common representation of
 *  arbitrary-dimension tensors, and can be transformed
 *  to a normal fixed-dimension tensor
* \author Tianqi Chen | |
*/ | |
#ifndef MXNET_TENSOR_BLOB_H_ | |
#define MXNET_TENSOR_BLOB_H_ | |
#if MXNET_USE_MKL2017 == 1 | |
#endif | |
namespace mxnet { | |
/*! | |
 * \brief tensor blob class that can be used to hold a tensor of any dimension,
 *  any device and any data type.
 *  This is a weak type that can be used to transfer data through interfaces;
 *  TBlob itself does not involve any arithmetic operations,
 *  but it can be converted to a tensor of fixed dimension for further operations.
 *
 *  Like a tensor, this data structure behaves like a pointer class and does not
 *  implicitly allocate or de-allocate space.
 *  This data structure can be helpful for holding tensors of different dimensions
 *  while waiting for further processing.
*/ | |
class TBlob { | |
public: | |
/*! \brief pointer to the data */ | |
void *dptr_; | |
/*! \brief shape of the tensor */ | |
TShape shape_; | |
/*! | |
* \brief storing the stride information in x dimension | |
*/ | |
index_t stride_; | |
/*! \brief device mask of the corresponding device */ | |
int dev_mask_; | |
/*! \brief type flag of the tensor blob */ | |
int type_flag_; | |
/*! \brief storing mkl chunk buffer blob, use for experimental only */ | |
#if MKL_EXPERIMENTAL == 1 | |
std::shared_ptr<MKLMemHolder> Mkl_mem_; | |
#endif | |
/*! \brief default constructor, default copy assign will work */ | |
TBlob(void) | |
: dptr_(NULL), dev_mask_(cpu::kDevMask), | |
type_flag_(mshadow::DataType<real_t>::kFlag) { | |
#if MKL_EXPERIMENTAL == 1 | |
Mkl_mem_ = NULL; | |
#endif | |
} | |
/*! | |
* \brief constructor that construct TBlob from contiguous memory | |
* \param dptr the pointer to the memory | |
* \param shape the shape of the data | |
* \param dev_mask the device mask, can be cpu::kDevMask or gpu::kDevMask | |
*/ | |
template<typename DType> | |
TBlob(DType *dptr, | |
const TShape &shape, | |
int dev_mask) | |
: dptr_(dptr), shape_(shape), | |
stride_(shape[shape.ndim() - 1]), | |
dev_mask_(dev_mask), | |
type_flag_(mshadow::DataType<DType>::kFlag) { | |
#if MKL_EXPERIMENTAL == 1 | |
Mkl_mem_ = NULL; | |
#endif | |
} | |
/*! | |
* \brief constructor that construct TBlob from contiguous memory | |
* \param dptr the pointer to the memory | |
* \param shape the shape of the data | |
* \param dev_mask the device mask, can be cpu::kDevMask or gpu::kDevMask | |
* \param type_flag the type flag. Can be one of enum mshadow::dtype | |
*/ | |
TBlob(void *dptr, | |
const TShape &shape, | |
int dev_mask, | |
int type_flag) | |
: dptr_(dptr), shape_(shape), | |
stride_(shape[shape.ndim() - 1]), | |
dev_mask_(dev_mask), | |
type_flag_(type_flag) { | |
#if MKL_EXPERIMENTAL == 1 | |
Mkl_mem_ = NULL; | |
#endif | |
} | |
/*! | |
* \brief constructor from tensor | |
* \param src source tensor | |
* \tparam Device which device the tensor is on | |
* \tparam dim tensor dimension | |
* \tparam DType the type of elements in the tensor | |
*/ | |
template<typename Device, int dim, typename DType> | |
TBlob(const mshadow::Tensor<Device, dim, DType> &src) { // NOLINT(*) | |
*this = src; | |
#if MKL_EXPERIMENTAL == 1 | |
Mkl_mem_ = NULL; | |
#endif | |
} | |
/*! | |
* \brief assignment from tensor | |
* \param src source tensor | |
* \tparam Device which device the tensor is on | |
* \tparam dim tensor dimension | |
* \tparam DType the type of elements in the tensor | |
* \return reference of self | |
*/ | |
template<typename Device, int dim, typename DType> | |
inline TBlob | |
&operator=(const mshadow::Tensor<Device, dim, DType> &src) { | |
dptr_ = src.dptr_; | |
shape_ = src.shape_; | |
stride_ = src.stride_; | |
dev_mask_ = Device::kDevMask; | |
type_flag_ = mshadow::DataType<DType>::kFlag; | |
return *this; | |
} | |
/*! | |
   * \return whether the tensor's memory is contiguous
*/ | |
inline bool CheckContiguous(void) const { | |
return shape_[shape_.ndim() - 1] == stride_; | |
} | |
/*! | |
   * \brief flatten the tensor to 2 dimensions, collapsing the higher dimensions together
* \param stream the possible stream target tensor should reside on | |
* \tparam Device which device the tensor is on | |
* \tparam DType the type of elements in the tensor | |
* \return tensor after flatten | |
*/ | |
template<typename Device, typename DType> | |
inline mshadow::Tensor<Device, 2, DType> FlatTo2D( | |
mshadow::Stream<Device> *stream = NULL) const { | |
CHECK(Device::kDevMask == dev_mask_) | |
<< "TBlob.get: device type do not match specified type"; | |
CHECK(mshadow::DataType<DType>::kFlag == type_flag_) | |
<< "TBlob.get_with_shape: data type do not match specified type." | |
<< "Expected: " << type_flag_ << " v.s. given " << mshadow::DataType<DType>::kFlag; | |
#if MKL_EXPERIMENTAL == 1 | |
if (Mkl_mem_ != nullptr) { | |
Mkl_mem_->check_and_prv_to_cpu(dptr_); | |
} | |
#endif | |
return mshadow::Tensor<Device, 2, DType>(static_cast<DType*>(dptr_), | |
shape_.FlatTo2D(), stride_, stream); | |
} | |
/*! | |
   * \brief flatten the tensor to 1 dimension, collapsing all the dimensions together.
* \param stream the possible stream target tensor should reside on | |
* \tparam Device which device the tensor is on | |
* \tparam DType the type of elements in the tensor | |
* \return tensor after flatten | |
*/ | |
template<typename Device, typename DType> | |
inline mshadow::Tensor<Device, 1, DType> FlatTo1D( | |
mshadow::Stream<Device> *stream = NULL) const { | |
return this->get_with_shape<Device, 1, DType>( | |
mshadow::Shape1(shape_.Size()), stream); | |
} | |
/*! \brief return number of dimension of the tensor inside */ | |
inline int ndim(void) const { | |
return shape_.ndim(); | |
} | |
/*! | |
   * \brief return size of the i-th dimension, counting from the highest dimension
   * \param idx the dimension index, counted from the highest dimension
* \return the size | |
*/ | |
inline index_t size(index_t idx) const { | |
return shape_[idx]; | |
} | |
/*! \brief total number of elements in the tensor */ | |
inline index_t Size(void) const { | |
return shape_.Size(); | |
} | |
/*! \brief get pointer in dtype */ | |
template<typename DType> | |
inline DType* dptr() const { | |
CHECK(mshadow::DataType<DType>::kFlag == type_flag_) | |
<< "TBlob.dptr(): data type do not match specified type."; | |
return static_cast<DType*>(dptr_); | |
} | |
/*! | |
* \brief fetch the tensor, with respect to specific dimension | |
   *  if dim does not match the stored dimension, an error will be issued
* \return the tensor requested | |
* \param stream the possible stream target tensor should reside on | |
* \tparam Device which device the tensor is on | |
* \tparam dim dimension of the tensor | |
* \tparam DType the type of elements in the tensor | |
*/ | |
template<typename Device, int dim, typename DType> | |
inline mshadow::Tensor<Device, dim, DType> get(mshadow::Stream<Device> *stream = NULL) const { | |
CHECK(Device::kDevMask == dev_mask_) | |
<< "TBlob.get: device type do not match specified type"; | |
CHECK(mshadow::DataType<DType>::kFlag == type_flag_) | |
<< "TBlob.get_with_shape: data type do not match specified type." | |
<< "Expected: " << type_flag_ << " v.s. given " << mshadow::DataType<DType>::kFlag; | |
#if MKL_EXPERIMENTAL == 1 | |
if (Mkl_mem_ != nullptr) { | |
Mkl_mem_->check_and_prv_to_cpu(dptr_); | |
} | |
#endif | |
return mshadow::Tensor<Device, dim, DType>(static_cast<DType*>(dptr_), | |
shape_.get<dim>(), | |
stride_, stream); | |
} | |
/*! | |
   * \brief fetch a tensor in the given shape
   *  If the size does not match the stored size, an error will be issued
* \return the tensor requested | |
* \param shape the shape required | |
* \param stream the possible stream target tensor should reside on | |
* \tparam Device which device the tensor is on | |
* \tparam dim dimension of the tensor | |
* \tparam DType the type of elements in the tensor | |
*/ | |
template<typename Device, int dim, typename DType> | |
inline mshadow::Tensor<Device, dim, DType> get_with_shape( | |
const mshadow::Shape<dim> &shape, | |
mshadow::Stream<Device> *stream = NULL) const { | |
    CHECK(Device::kDevMask == dev_mask_)
      << "TBlob.get_with_shape: device type does not match specified type";
    CHECK(mshadow::DataType<DType>::kFlag == type_flag_)
      << "TBlob.get_with_shape: data type does not match specified type. "
      << "Expected: " << type_flag_ << " v.s. given " << mshadow::DataType<DType>::kFlag;
    CHECK_EQ(this->CheckContiguous(), true) << "TBlob.get_with_shape: must be contiguous";
    CHECK_EQ(this->shape_.Size(), shape.Size())
      << "TBlob.get_with_shape: new and old shapes do not match in total elements";
#if MKL_EXPERIMENTAL == 1 | |
if (Mkl_mem_ != nullptr) { | |
Mkl_mem_->check_and_prv_to_cpu(dptr_); | |
} | |
#endif | |
return mshadow::Tensor<Device, dim, DType>(static_cast<DType*>(dptr_), | |
shape, | |
shape[dim - 1], | |
stream); | |
} | |
/*! | |
   * \brief flatten the tensor to 3 dimensions,
   *  collapsing the dimensions before and after the specified axis.
* \param axis The axis specified. | |
* \param stream the possible stream target tensor should reside on | |
* \tparam Device which device the tensor is on | |
* \tparam DType the type of elements in the tensor | |
* \return tensor after flatten | |
*/ | |
template<typename Device, typename DType> | |
inline mshadow::Tensor<Device, 3, DType> FlatTo3D( | |
int axis, mshadow::Stream<Device> *stream = NULL) const { | |
return this->get_with_shape<Device, 3, DType>( | |
this->shape_.FlatTo3D(axis), stream); | |
} | |
/*! | |
   * \brief flatten the tensor to 3 dimensions,
   *  collapsing the dimensions into [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim).
* \param axis_begin The beginning axis specified. | |
* \param axis_end The ending axis specified. | |
* \param stream the possible stream target tensor should reside on | |
* \tparam Device which device the tensor is on | |
* \tparam DType the type of elements in the tensor | |
* \return tensor after flatten | |
*/ | |
template<typename Device, typename DType> | |
inline mshadow::Tensor<Device, 3, DType> FlatTo3D( | |
int axis_begin, int axis_end, | |
mshadow::Stream<Device> *stream = NULL) const { | |
return this->get_with_shape<Device, 3, DType>( | |
this->shape_.FlatTo3D(axis_begin, axis_end), stream); | |
} | |
}; | |
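/*
 * A minimal sketch of wrapping an existing CPU buffer in a TBlob; the buffer
 * size and shape are illustrative, and TBlob never owns or frees the memory:
 *
 *   std::vector<float> buf(2 * 3 * 4);
 *   TBlob blob(buf.data(), TShape({2, 3, 4}), cpu::kDevMask);
 *   CHECK_EQ(blob.ndim(), 3);
 *   CHECK_EQ(blob.Size(), 24);
 *   auto flat = blob.FlatTo2D<cpu, float>();  // mshadow::Tensor<cpu, 2, float> of shape (6, 4)
 */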
} // namespace mxnet | |
namespace dmlc { | |
// Add a few patches to support TShape in dmlc/parameter. | |
DMLC_DECLARE_TYPE_NAME(mxnet::TShape, "Shape(tuple)"); | |
DMLC_DECLARE_TYPE_NAME(nnvm::Tuple<int>, "Shape(tuple)"); | |
namespace parameter { | |
template<> | |
class FieldEntry<mxnet::TShape> | |
: public FieldEntryBase<FieldEntry<mxnet::TShape>, mxnet::TShape> { | |
public: | |
FieldEntry() : enforce_nonzero_(false), expect_ndim_(0) {} | |
// parent class | |
typedef FieldEntryBase<FieldEntry<mxnet::TShape>, mxnet::TShape> Parent; | |
virtual void Check(void *head) const { | |
Parent::Check(head); | |
mxnet::TShape &v = this->Get(head); | |
if (expect_ndim_ != 0 && v.ndim() != expect_ndim_) { | |
std::ostringstream os; | |
os << "value " << v << "for Parameter " << this->key_ | |
<< " has wrong dimensions, expected dimension=" << expect_ndim_; | |
throw dmlc::ParamError(os.str()); | |
} | |
if (enforce_nonzero_) { | |
for (mxnet::index_t i = 0; i < v.ndim(); ++i) { | |
if (v[i] == 0U) { | |
std::ostringstream os; | |
os << "value " << v << "for Parameter " << this->key_ | |
<< " is invalid, the input shape must be nonzero in all dimensions"; | |
throw dmlc::ParamError(os.str()); | |
} | |
} | |
} | |
} | |
inline FieldEntry<mxnet::TShape> &enforce_nonzero() { | |
this->enforce_nonzero_ = true; | |
return this->self(); | |
} | |
inline FieldEntry<mxnet::TShape> &set_expect_ndim(mxnet::index_t ndim) { | |
expect_ndim_ = ndim; | |
return this->self(); | |
} | |
private: | |
// whether all the entries need to be nonzero | |
bool enforce_nonzero_; | |
// expected number of dimension, default = 0 means no restriction. | |
mxnet::index_t expect_ndim_; | |
}; | |
} // namespace parameter | |
} // namespace dmlc | |
#endif // MXNET_TENSOR_BLOB_H_ | |
//===== EXPANDED: ../include/mxnet/tensor_blob.h ===== | |
//! \endcond | |
#endif // MXNET_BASE_H_ | |
//===== EXPANDED: ../include/mxnet/base.h ===== | |
//===== EXPANDING: ../nnvm/src/core/graph.cc ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file graph_attr_types.cc | |
* \brief Graph node data structure. | |
*/ | |
//===== EXPANDING: ../nnvm/include/nnvm/graph.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file graph.h | |
 * \brief Configuration of nnvm as well as basic data structures.
*/ | |
#ifndef NNVM_GRAPH_H_ | |
#define NNVM_GRAPH_H_ | |
namespace nnvm { | |
class IndexedGraph; | |
/*! | |
* \brief Symbolic computation graph. | |
* This is the intermediate representation for optimization pass. | |
*/ | |
class Graph { | |
public: | |
/*! \brief outputs of the computation graph. */ | |
std::vector<NodeEntry> outputs; | |
/*! | |
* \brief attributes of a graph | |
   *  Note that each attribute is a shared pointer and can be shared across graphs.
   *
   *  It is highly recommended to keep each attribute immutable.
   *  It is also safe to implement copy-on-write semantics:
   *
   *  copy when shared_ptr.unique is not true, and reuse the original space
   *  when shared_ptr.unique is true.
*/ | |
std::unordered_map<std::string, std::shared_ptr<any> > attrs; | |
/*! | |
* \brief Get the immutable attribute from attrs. | |
* \param attr_name the name of the attribute | |
* \return the reference to corresponding attribute | |
* \tparam T the type of the attribute. | |
*/ | |
template<typename T> | |
inline const T& GetAttr(const std::string& attr_name) const; | |
/*! | |
* \brief Get a move copy of the attribute, implement copy on write semantics. | |
* The content is moved if the reference counter of shared_ptr is 1. | |
* The attribute is erased from attrs after the call. | |
* | |
* \param attr_name the name of the attribute | |
* \return a new copy of the corresponding attribute. | |
* \tparam T the type of the attribute. | |
*/ | |
template<typename T> | |
inline T MoveCopyAttr(const std::string& attr_name); | |
/*! | |
   * \brief get an indexed graph of the current graph; if it does not exist, create it on demand
* \return The indexed graph. | |
* \sa IndexedGraph | |
*/ | |
const IndexedGraph& indexed_graph(); | |
private: | |
// internal structure of indexed graph | |
std::shared_ptr<const IndexedGraph> indexed_graph_; | |
}; | |
/*! | |
 * \brief Auxiliary data structure to index a graph.
 *  It maps Nodes in the graph to consecutive integers node_id.
 *  It also maps IndexedGraph::NodeEntry to consecutive integer entry_id.
 *  This allows storing properties of Node and NodeEntry in
 *  compact vectors and accessing them quickly without resorting to a hashmap.
* | |
* The node_id and entry_rptr are the same as the JSON graph produced by SaveJSON Pass. | |
*/ | |
class IndexedGraph { | |
public: | |
/*! \brief represents a data in the graph */ | |
struct NodeEntry { | |
/*! \brief the source node id in the computation graph */ | |
uint32_t node_id; | |
/*! \brief index of output from the source. */ | |
uint32_t index; | |
/*! \brief version of the node */ | |
uint32_t version; | |
}; | |
/*! \brief Node data structure in IndexedGraph */ | |
struct Node { | |
/*! \brief pointer to the source node */ | |
const nnvm::Node* source; | |
/*! \brief inputs to the node */ | |
array_view<NodeEntry> inputs; | |
/*! \brief control flow dependencies to the node */ | |
array_view<uint32_t> control_deps; | |
}; | |
/*! \return number of nodes in the graph */ | |
inline size_t num_nodes() const { | |
return nodes_.size(); | |
} | |
/*! \return total number of NodeEntry in the graph */ | |
inline size_t num_node_entries() const { | |
return entry_rptr_.back(); | |
} | |
/*! | |
   * \brief Get a unique entry id between 0 and num_node_entries()
* for a given IndexedGraph::NodeEntry | |
* \param node_id The node index | |
* \param index the output index | |
* \return the unique index. | |
*/ | |
inline uint32_t entry_id(uint32_t node_id, uint32_t index) const { | |
return entry_rptr_[node_id] + index; | |
} | |
/*! | |
   * \brief Get a unique entry id between 0 and num_node_entries()
* for a given IndexedGraph::NodeEntry | |
* \param e The entry to query for index. | |
* \return the unique index. | |
*/ | |
inline uint32_t entry_id(const NodeEntry& e) const { | |
return entry_rptr_[e.node_id] + e.index; | |
} | |
/*! | |
   * \brief Get a unique entry id between 0 and num_node_entries()
* for a given NodeEntry. | |
* \param e The entry to query for index. | |
* \return the unique index. | |
*/ | |
inline uint32_t entry_id(const nnvm::NodeEntry& e) const { | |
return entry_rptr_[node_id(e.node.get())] + e.index; | |
} | |
/*! | |
* \brief Get the corresponding node id for a given Node in the IndexedGraph. | |
* \param node The Node to query for index. | |
* \return the node index. | |
*/ | |
inline uint32_t node_id(const nnvm::Node* node) const { | |
return node2index_.at(node); | |
} | |
/*! | |
* \brief Get the corresponding Node structure for a given node_id. | |
* \param node_id The node id | |
* \return const reference to the corresponding IndexedGraph::Node | |
*/ | |
inline const Node& operator[](uint32_t node_id) const { | |
return nodes_[node_id]; | |
} | |
/*! | |
* \brief Get the corresponding Node structure | |
* \param node The pointer to the Node structure | |
* \return const reference to the corresponding IndexedGraph::Node | |
*/ | |
inline const Node& operator[](const nnvm::Node* node) const { | |
return nodes_[node_id(node)]; | |
} | |
/*! \return list of argument nodes */ | |
inline const std::vector<uint32_t>& input_nodes() const { | |
return input_nodes_; | |
} | |
/*! \return list of mutable nodes */ | |
inline const std::unordered_set<uint32_t>& mutable_input_nodes() const { | |
return mutable_input_nodes_; | |
} | |
/*! \return list of output entries */ | |
inline const std::vector<NodeEntry>& outputs() const { | |
return outputs_; | |
} | |
  // disallow copy assignment
IndexedGraph(const IndexedGraph&) = delete; | |
private: | |
friend class Graph; | |
/*! | |
   * \brief Construct an IndexedGraph from a normal Graph
* \param other The source graph. | |
*/ | |
explicit IndexedGraph(const Graph& other); | |
// Node pointers in CSR structure. | |
std::vector<Node> nodes_; | |
// Index to all input nodes. | |
std::vector<uint32_t> input_nodes_; | |
// Index to all mutable input nodes. | |
std::unordered_set<uint32_t> mutable_input_nodes_; | |
// space to store the outputs entries | |
std::vector<NodeEntry> outputs_; | |
// mapping from node to index. | |
std::unordered_map<const nnvm::Node*, uint32_t> node2index_; | |
// CSR pointer of node entries | |
std::vector<size_t> entry_rptr_; | |
  // space to store the input entries of each node
std::vector<NodeEntry> input_entries_; | |
// control flow dependencies | |
std::vector<uint32_t> control_deps_; | |
}; | |
/*! | |
* \brief perform a Post Order DFS visit to each node in the graph. | |
 *  This order is deterministic and is also topologically sorted.
* \param heads The heads in the graph. | |
* \param fvisit a function of type std::function<void(const std::shared_ptr<Node>&)> | |
* \tparam FVisit The function type to perform the visit. | |
*/ | |
template<typename FVisit> | |
inline void DFSVisit(const std::vector<NodeEntry>& heads, FVisit fvisit); | |
// inline function implementations | |
template<typename T> | |
inline const T& Graph::GetAttr(const std::string& attr_name) const { | |
auto it = attrs.find(attr_name); | |
CHECK(it != attrs.end()) | |
<< "Cannot find attribute " << attr_name << " in the graph"; | |
return nnvm::get<T>(*it->second); | |
} | |
template<typename T> | |
inline T Graph::MoveCopyAttr(const std::string& attr_name) { | |
auto it = attrs.find(attr_name); | |
CHECK(it != attrs.end()) | |
<< "Cannot find attribute " << attr_name << " in the graph"; | |
std::shared_ptr<any> sptr = it->second; | |
attrs.erase(it); | |
if (sptr.unique()) { | |
return std::move(nnvm::get<T>(*sptr)); | |
} else { | |
return nnvm::get<T>(*sptr); | |
} | |
} | |
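/*
 * A minimal sketch of attaching and reading a graph attribute. The attribute
 * name "shape_hint" and its type are illustrative, not a standard pass output:
 *
 *   Graph g;
 *   g.attrs["shape_hint"] =
 *       std::make_shared<any>(std::vector<TShape>{TShape{2, 3}});
 *   const auto& hint = g.GetAttr<std::vector<TShape> >("shape_hint");
 *   std::vector<TShape> owned = g.MoveCopyAttr<std::vector<TShape> >("shape_hint");
 */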
template <typename GNode, typename HashType, | |
typename FVisit, typename HashFunc, | |
typename InDegree, typename GetInput> | |
void PostOrderDFSVisit(const std::vector<GNode>& heads, | |
FVisit fvisit, | |
HashFunc hash, | |
InDegree indegree, | |
GetInput getinput) { | |
std::vector<std::pair<GNode, uint32_t> > stack; | |
std::unordered_set<HashType> visited; | |
for (auto& head : heads) { | |
HashType head_hash = hash(head); | |
if (visited.count(head_hash) == 0) { | |
stack.push_back(std::make_pair(head, 0)); | |
visited.insert(head_hash); | |
} | |
while (!stack.empty()) { | |
std::pair<GNode, uint32_t>& back = stack.back(); | |
if (back.second == indegree(back.first)) { | |
fvisit(back.first); | |
stack.pop_back(); | |
} else { | |
const GNode& input = getinput(back.first, back.second++); | |
HashType input_hash = hash(input); | |
if (visited.count(input_hash) == 0) { | |
stack.push_back(std::make_pair(input, 0)); | |
visited.insert(input_hash); | |
} | |
} | |
} | |
} | |
} | |
template<typename FVisit> | |
inline void DFSVisit(const std::vector<NodeEntry>& heads, | |
FVisit fvisit) { | |
typedef const NodePtr* GNode; | |
std::vector<GNode> head_nodes(heads.size()); | |
std::transform(heads.begin(), heads.end(), head_nodes.begin(), | |
[](const NodeEntry& e)->GNode { | |
return &e.node; | |
}); | |
PostOrderDFSVisit<GNode, Node*>( | |
head_nodes, | |
[fvisit](GNode n) { fvisit(*n); }, // FVisit | |
[](GNode n)->Node* { return n->get(); }, // HashFunc | |
[](GNode n)->uint32_t { // InDegree | |
return (*n)->inputs.size() + (*n)->control_deps.size(); | |
}, | |
[](GNode n, uint32_t index)->GNode { // GetInput | |
if (index < (*n)->inputs.size()) { | |
return &(*n)->inputs.at(index).node; | |
} else { | |
return &(*n)->control_deps.at(index - (*n)->inputs.size()); | |
} | |
}); | |
} | |
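/*
 * A minimal sketch of walking a graph with DFSVisit; `outputs` stands for any
 * std::vector<NodeEntry>, e.g. Graph::outputs or Symbol::outputs:
 *
 *   std::vector<std::string> order;
 *   DFSVisit(outputs, [&](const NodePtr& n) {
 *     order.push_back(n->attrs.name);   // a node is visited after all its inputs
 *   });
 */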
} // namespace nnvm | |
#endif // NNVM_GRAPH_H_ | |
//===== EXPANDED: ../nnvm/include/nnvm/graph.h ===== | |
//===== EXPANDING: ../nnvm/include/nnvm/op_attr_types.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file op_attr_types.h | |
* \brief Data structures that can appear in operator attributes. | |
*/ | |
#ifndef NNVM_OP_ATTR_TYPES_H_ | |
#define NNVM_OP_ATTR_TYPES_H_ | |
namespace nnvm { | |
// These types are optional attributes in each operator. | |
// Each attribute can be required by some passes. | |
/*! | |
 * \brief Return list of input argument names of each operator.
* | |
* \param attrs The attributes of the node. | |
* \return list of inputs | |
* \note Register under "FListInputNames", default return {"data"}. | |
* | |
* FListInputNames enables automatic variable creation for missing arguments. | |
*/ | |
using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>; | |
/*! | |
* \brief Return number of visible outputs by the user. | |
* | |
* \param attrs The attributes of the node. | |
* | |
* \note Register under "FNumVisibleOutputs", default not registered. | |
* This can be used to hide certain output from the user, | |
* but the additional outputs can be used to pass information from | |
* forward to gradient pass. | |
*/ | |
using FNumVisibleOutputs = std::function<uint32_t (const NodeAttrs& attrs)>; | |
/*! | |
 * \brief Return list of output argument names of each operator.
 *
 * \param attrs The attributes of the node.
 * \return list of outputs
 * \note Register under "FListOutputNames", default return {"outputs"}.
 *
 * FListOutputNames customizes naming for operator outputs.
*/ | |
using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>; | |
/*! | |
 * \brief Get the list of inputs that the operator will mutate.
 * \param attrs The attributes of the node.
 * \return list of input indices it mutates.
 *
 * \note Register under "FMutateInputs", by default no inputs are mutated.
 * FMutateInputs enables mutation order to be handled correctly.
*/ | |
using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attrs)>; | |
/*! | |
* \brief Inference function of certain type. | |
 * \tparam AttrType The type of the attribute to be inferred.
* \return whether all attributes are inferred. | |
*/ | |
template<typename AttrType> | |
using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs, | |
std::vector<AttrType> *in_attrs, | |
std::vector<AttrType> *out_attrs)>; | |
/*! | |
* \brief Shape inference function. | |
* Update the shapes given the input shape information. | |
* TShape.ndim() == 0 means the shape is still unknown. | |
* | |
* \note Register under "FInferShape", | |
* by default do not update any shapes. | |
* | |
* FInferShape is needed by shape inference | |
*/ | |
using FInferShape = FInferNodeEntryAttr<TShape>; | |
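/*
 * A minimal sketch of an FInferShape implementation for a one-input,
 * one-output elementwise operator; registering it on an Op is assumed to
 * happen elsewhere via the operator registry:
 *
 *   FInferShape same_shape = [](const NodeAttrs& attrs,
 *                               std::vector<TShape>* in_attrs,
 *                               std::vector<TShape>* out_attrs) {
 *     if ((*in_attrs)[0].ndim() == 0) return false;  // input shape still unknown
 *     (*out_attrs)[0] = (*in_attrs)[0];              // output mirrors the input
 *     return true;
 *   };
 */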
/*! | |
* \brief Type inference function. | |
* Update the type given the known type information. | |
* | |
* \note Register under "FInferType", | |
* by default set all the output types to 0. | |
*/ | |
using FInferType = FInferNodeEntryAttr<int>; | |
/*! | |
* \brief Whether this op is an explicit backward operator, | |
* If TIsBackward is true: | |
* - The first control_deps of the node points to the corresponding forward operator. | |
* | |
* \note Register under "TIsBackward" | |
* This enables easier shape/type inference for backward operators. | |
*/ | |
using TIsBackward = bool; | |
/*! | |
* \brief Get possible inplace options. | |
* This function enables optimization to reuse memory of inputs in output. | |
* \param attrs The attributes of the node | |
* \param in_data The input data. | |
* \param out_data The output data. | |
* \return list of pair of that maps input->output, | |
* indicating possible in place operations. | |
* | |
* \note Register under "FInplaceOption", by default no inplace can happen. | |
*/ | |
using FInplaceOption = std::function< | |
std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>; | |
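/*!
 * Illustrative sketch ("my_relu" is a hypothetical operator): a unary op
 * declares that output 0 may share memory with input 0, which the
 * PlanMemory pass can exploit.
 * \code
 * NNVM_REGISTER_OP(my_relu)
 * .set_attr<FInplaceOption>("FInplaceOption",
 *   [](const NodeAttrs& attrs) {
 *     return std::vector<std::pair<int, int> >{{0, 0}};
 *   });
 * \endcode
 */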
/*! | |
 * \brief Get the list of inputs whose content is not actually used by the operator.
 *  These are dummy inputs that can be used, for example, in zeros_like and ones_like.
 *
 * \param attrs The attributes of the node
 * \return list of input indices that are not used by the operator.
* | |
* \note Register under "FIgnoreInputs". | |
*/ | |
using FIgnoreInputs = std::function< | |
std::vector<uint32_t> (const NodeAttrs& attrs)>; | |
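/*!
 * Illustrative sketch ("my_zeros_like" is a hypothetical operator): the
 * input is only needed for its shape, so its content can be ignored when
 * counting data dependencies.
 * \code
 * NNVM_REGISTER_OP(my_zeros_like)
 * .set_attr<FIgnoreInputs>("FIgnoreInputs",
 *   [](const NodeAttrs& attrs) {
 *     return std::vector<uint32_t>{0};
 *   });
 * \endcode
 */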
/*! | |
 * \brief Get the gradient node of the op node.
 *  This function generates the backward graph of the node.
 * \param nodeptr The node to take the gradient of.
 * \param out_grads Gradients of the current node's outputs.
* \return gradients of the inputs | |
* | |
* \note Register under "FGradient" | |
*/ | |
using FGradient = std::function<std::vector<NodeEntry>( | |
const NodePtr& nodeptr, | |
const std::vector<NodeEntry>& out_grads)>; | |
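/*!
 * Illustrative sketch ("my_copy" is a hypothetical operator): an
 * identity-like op can simply pass the output gradient through as the
 * gradient of its single input.
 * \code
 * NNVM_REGISTER_OP(my_copy)
 * .set_attr<FGradient>("FGradient",
 *   [](const NodePtr& nodeptr, const std::vector<NodeEntry>& out_grads) {
 *     return std::vector<NodeEntry>{out_grads[0]};
 *   });
 * \endcode
 */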
/*! | |
* \brief Set the attributes of input variable. | |
* Usually used for setting initialization or weight decay. | |
* \param attrs The attributes of this node. | |
* \param var the input variable | |
* \param index index of var in all inputs | |
*/ | |
using FSetInputVarAttrOnCompose = std::function<void( | |
const NodeAttrs& attrs, | |
NodePtr var, | |
const int index)>; | |
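/*!
 * Illustrative sketch ("my_fc" and the "__wd_mult__" key are assumed for
 * the example only): give automatically created weight variables a default
 * weight-decay multiplier at composition time.
 * \code
 * NNVM_REGISTER_OP(my_fc)
 * .set_attr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose",
 *   [](const NodeAttrs& attrs, NodePtr var, const int index) {
 *     if (var->attrs.dict.count("__wd_mult__") == 0) {
 *       var->attrs.dict["__wd_mult__"] = "1.0";
 *     }
 *   });
 * \endcode
 */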
} // namespace nnvm | |
#endif // NNVM_OP_ATTR_TYPES_H_ | |
//===== EXPANDED: ../nnvm/include/nnvm/op_attr_types.h ===== | |
namespace nnvm { | |
const IndexedGraph& Graph::indexed_graph() { | |
if (indexed_graph_ == nullptr) { | |
indexed_graph_.reset(new IndexedGraph(*this)); | |
} | |
return *indexed_graph_; | |
} | |
// implement constructor from graph | |
IndexedGraph::IndexedGraph(const Graph &g) { | |
entry_rptr_.push_back(0); | |
std::vector<size_t> inputs_rptr{0}, control_rptr{0}; | |
DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr] | |
(const NodePtr& n) { | |
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max()); | |
uint32_t nid = static_cast<uint32_t>(nodes_.size()); | |
// nodes_ | |
IndexedGraph::Node new_node; | |
new_node.source = n.get(); | |
nodes_.emplace_back(std::move(new_node)); | |
      // input_nodes_
if (n->is_variable()) { | |
input_nodes_.push_back(nid); | |
} | |
// node2index_ | |
node2index_[n.get()] = nid; | |
// entry rptr | |
entry_rptr_.push_back(entry_rptr_.back() + n->num_outputs()); | |
// input entries | |
for (const auto& e : n->inputs) { | |
auto it = node2index_.find(e.node.get()); | |
CHECK(it != node2index_.end() && it->first == e.node.get()); | |
input_entries_.emplace_back(NodeEntry{it->second, e.index, e.version}); | |
} | |
inputs_rptr.push_back(input_entries_.size()); | |
// control deps | |
for (const auto& nptr : n->control_deps) { | |
auto it = node2index_.find(nptr.get()); | |
CHECK(it != node2index_.end() && it->first == nptr.get()); | |
control_deps_.push_back(it->second); | |
} | |
control_rptr.push_back(control_deps_.size()); | |
}); | |
for (const auto& e : g.outputs) { | |
outputs_.emplace_back(NodeEntry{ | |
node2index_.at(e.node.get()), e.index, e.version}); | |
} | |
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs"); | |
std::unordered_set<uint32_t> mutable_inputs; | |
// setup array view | |
// input_entries_ and control_rptr must not change after this step. | |
const NodeEntry* iptr = dmlc::BeginPtr(input_entries_); | |
for (size_t nid = 0; nid < nodes_.size(); ++nid) { | |
nodes_[nid].inputs = array_view<NodeEntry>( | |
iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]); | |
if (nodes_[nid].source->op() != nullptr && | |
fmutate_inputs.count(nodes_[nid].source->op())) { | |
for (uint32_t i : fmutate_inputs[nodes_[nid].source->op()](nodes_[nid].source->attrs)) { | |
mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id); | |
} | |
} | |
} | |
const uint32_t* cptr = dmlc::BeginPtr(control_deps_); | |
for (size_t nid = 0; nid < nodes_.size(); ++nid) { | |
nodes_[nid].control_deps = array_view<uint32_t>( | |
cptr + control_rptr[nid], cptr + control_rptr[nid + 1]); | |
} | |
} | |
} // namespace nnvm | |
//===== EXPANDED: ../nnvm/src/core/graph.cc ===== | |
//===== EXPANDING: ../nnvm/src/core/op.cc ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file op.cc | |
* \brief Support for operator registry. | |
*/ | |
namespace dmlc { | |
// enable registry | |
DMLC_REGISTRY_ENABLE(nnvm::Op); | |
} // namespace dmlc | |
namespace nnvm { | |
// single manager of operator information. | |
struct OpManager { | |
// mutex to avoid registration from multiple threads. | |
  // a recursive mutex is needed because triggers (which call UpdateAttrMap) run while it is held
std::recursive_mutex mutex; | |
// global operator counter | |
std::atomic<int> op_counter{0}; | |
// storage of additional attribute table. | |
std::unordered_map<std::string, std::unique_ptr<any> > attr; | |
// storage of existing triggers | |
std::unordered_map<std::string, std::vector<std::function<void(Op*)> > > tmap; | |
// group of each operator. | |
std::vector<std::unordered_set<std::string> > op_group; | |
  // get the singleton of the manager
static OpManager* Global() { | |
static OpManager inst; | |
return &inst; | |
} | |
}; | |
// constructor | |
Op::Op() { | |
OpManager* mgr = OpManager::Global(); | |
index_ = mgr->op_counter++; | |
} | |
Op& Op::add_alias(const std::string& alias) { // NOLINT(*) | |
dmlc::Registry<Op>::Get()->AddAlias(this->name, alias); | |
return *this; | |
} | |
// find operator by name | |
const Op* Op::Get(const std::string& name) { | |
const Op* op = dmlc::Registry<Op>::Find(name); | |
CHECK(op != nullptr) | |
<< "Operator " << name << " is not registered"; | |
return op; | |
} | |
// Get attribute map by key | |
const any* Op::GetAttrMap(const std::string& key) { | |
auto& dict = OpManager::Global()->attr; | |
auto it = dict.find(key); | |
if (it != dict.end()) { | |
return it->second.get(); | |
} else { | |
return nullptr; | |
} | |
} | |
// update attribute map | |
void Op::UpdateAttrMap(const std::string& key, | |
std::function<void(any*)> updater) { | |
OpManager* mgr = OpManager::Global(); | |
  std::lock_guard<std::recursive_mutex> lock(mgr->mutex);
std::unique_ptr<any>& value = mgr->attr[key]; | |
if (value.get() == nullptr) value.reset(new any()); | |
if (updater != nullptr) updater(value.get()); | |
} | |
void Op::AddGroupTrigger(const std::string& group_name, | |
std::function<void(Op*)> trigger) { | |
OpManager* mgr = OpManager::Global(); | |
  std::lock_guard<std::recursive_mutex> lock(mgr->mutex);
auto& tvec = mgr->tmap[group_name]; | |
tvec.push_back(trigger); | |
auto& op_group = mgr->op_group; | |
for (const Op* op : dmlc::Registry<Op>::List()) { | |
if (op->index_ < op_group.size() && | |
op_group[op->index_].count(group_name) != 0) { | |
trigger((Op*)op); // NOLINT(*) | |
} | |
} | |
} | |
Op& Op::include(const std::string& group_name) { | |
OpManager* mgr = OpManager::Global(); | |
  std::lock_guard<std::recursive_mutex> lock(mgr->mutex);
auto it = mgr->tmap.find(group_name); | |
if (it != mgr->tmap.end()) { | |
for (auto& trigger : it->second) { | |
trigger(this); | |
} | |
} | |
auto& op_group = mgr->op_group; | |
if (index_ >= op_group.size()) { | |
op_group.resize(index_ + 1); | |
} | |
op_group[index_].insert(group_name); | |
return *this; | |
} | |
} // namespace nnvm | |
//===== EXPANDED: ../nnvm/src/core/op.cc ===== | |
//===== EXPANDING: ../nnvm/src/core/symbolic.cc ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file symbolic.cc | |
* \brief Symbolic graph composition API. | |
*/ | |
namespace nnvm { | |
namespace symbol_constants { | |
const char *kNamespaceSeparator = "$"; | |
} // namespace symbol_constants | |
// auxiliary version attribute in variable.
struct VariableParam { | |
uint32_t version{0}; | |
}; | |
NodePtr CreateVariableNode(const std::string& name) { | |
NodePtr n = Node::Create(); | |
n->attrs.op = nullptr; | |
n->attrs.name = name; | |
n->attrs.parsed = VariableParam(); | |
return n; | |
} | |
// Scan over a node's inputs and update each version to the latest.
// If the node's op mutates a certain input variable,
// the version of that variable is increased.
// The version is used to implicitly order the mutation sequences.
inline void UpdateNodeVersion(Node *n) { | |
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs"); | |
for (NodeEntry& e : n->inputs) { | |
if (e.node->is_variable()) { | |
e.version = nnvm::get<VariableParam>(e.node->attrs.parsed).version; | |
} | |
} | |
if (fmutate_inputs.count(n->op()) != 0) { | |
for (uint32_t i : fmutate_inputs[n->op()](n->attrs)) { | |
NodeEntry& e = n->inputs[i]; | |
CHECK(e.node->is_variable()) | |
<< "Mutation target can only be Variable"; | |
// increase the version of the variable. | |
e.version = ++nnvm::get<VariableParam>(e.node->attrs.parsed).version; | |
} | |
} | |
} | |
inline std::string DefaultVarName(const std::string &op_name, | |
const std::string &arg_name) { | |
if (op_name.length() == 0) { | |
return arg_name; | |
} else { | |
return op_name + '_' + arg_name; | |
} | |
} | |
inline void KeywordArgumentMismatch(const char *source, | |
const std::vector<std::string>& user_args, | |
const array_view<std::string>& args) { | |
std::unordered_set<std::string> keys(args.begin(), args.end()); | |
  std::ostringstream msg;
msg << "\nCandidate arguments:\n"; | |
for (size_t i = 0; i < args.size(); ++i) { | |
msg << "\t[" << i << ']' << args[i] << '\n'; | |
} | |
for (const auto& key : user_args) { | |
if (keys.count(key) == 0) { | |
LOG(FATAL) << source | |
<< "Keyword argument name " << key << " not found." | |
<< msg.str(); | |
} | |
} | |
} | |
template<typename T> | |
inline std::vector<std::string> GetKeys( | |
const std::unordered_map<std::string, T>& kwargs) { | |
std::vector<std::string> keys(kwargs.size()); | |
std::transform(kwargs.begin(), kwargs.end(), keys.begin(), | |
[](decltype(*kwargs.begin())& kv) { return kv.first; }); | |
return keys; | |
} | |
// whether the symbol is atomic functor | |
inline bool IsAtomic(const std::vector<NodeEntry>& outputs) { | |
return outputs[0].node->inputs.size() == 0 && | |
outputs[0].node->control_deps.size() == 0; | |
} | |
// public functions | |
Symbol Symbol::Copy() const { | |
std::unordered_map<Node*, NodePtr> old_new; | |
// use DFSVisit to copy all the nodes | |
DFSVisit(this->outputs, [&old_new](const NodePtr& node) { | |
NodePtr np = Node::Create(); | |
np->attrs = node->attrs; | |
old_new[node.get()] = std::move(np); | |
}); | |
// connect nodes of new graph | |
for (const auto &kv : old_new) { | |
for (const NodeEntry& e : kv.first->inputs) { | |
Node *ptr = e.node.get(); | |
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version}); | |
} | |
for (const NodePtr& p : kv.first->control_deps) { | |
kv.second->control_deps.emplace_back(old_new[p.get()]); | |
} | |
} | |
// set the head | |
Symbol ret; | |
for (const NodeEntry &e : outputs) { | |
ret.outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index, e.version}); | |
} | |
return ret; | |
} | |
void Symbol::Print(std::ostream &os) const { | |
if (outputs.size() == 1 && | |
outputs[0].node->inputs.size() == 0 && | |
outputs[0].node->control_deps.size() == 0) { | |
if (outputs[0].node->is_variable()) { | |
os << "Variable:" << outputs[0].node->attrs.name << '\n'; | |
} else { | |
os << "AtomicFunctor "<< " Op:" << outputs[0].node->op()->name << '\n'; | |
} | |
} else { | |
    // use DFSVisit to print all the nodes
os << "Symbol Outputs:\n"; | |
for (size_t i = 0; i < outputs.size(); ++i) { | |
os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name | |
<< '(' << outputs[i].index << ")\n"; | |
} | |
DFSVisit(this->outputs, [&os](const NodePtr& node) { | |
if (node->is_variable()) { | |
os << "Variable:" << node->attrs.name << '\n'; | |
} else { | |
os << "--------------------\n"; | |
os << "Op:" << node->op()->name << ", Name=" << node->attrs.name << '\n' | |
<< "Inputs:\n"; | |
for (size_t i = 0; i < node->inputs.size(); ++i) { | |
const NodeEntry& e = node->inputs[i]; | |
os << "\targ[" << i << "]=" << e.node->attrs.name | |
<< '(' << e.index << ")"; | |
if (e.node->is_variable()) { | |
os << " version=" << e.version << '\n'; | |
} else { | |
os << '\n'; | |
} | |
} | |
if (!node->attrs.dict.empty()) { | |
os << "Attrs:\n"; | |
// make an ordered copy because unordered_map doesn't guarantee order. | |
std::map<std::string, std::string> sorted_dict( | |
node->attrs.dict.begin(), node->attrs.dict.end()); | |
for (auto &kv : sorted_dict) { | |
os << '\t' << kv.first << '=' << kv.second << '\n'; | |
} | |
} | |
if (node->control_deps.size() != 0) { | |
os << "Control deps:\n"; | |
for (size_t i = 0; i < node->control_deps.size(); ++i) { | |
os << "\tcdep[" << i << "]=" << node->control_deps[i]->attrs.name << '\n'; | |
} | |
} | |
} | |
}); | |
} | |
} | |
Symbol Symbol::operator[] (size_t index) const { | |
size_t nreturn = outputs.size(); | |
  CHECK_LT(index, nreturn) << "Symbol index out of range";
if (nreturn == 1) { | |
return *this; | |
} else { | |
Symbol s; | |
s.outputs.push_back(outputs[index]); | |
return s; | |
} | |
} | |
std::vector<NodePtr> Symbol::ListInputs(ListInputOption option) const { | |
std::vector<NodePtr> ret; | |
if (option == kAll) { | |
DFSVisit(this->outputs, [&ret](const NodePtr &node) { | |
if (node->is_variable()) { | |
ret.push_back(node); | |
} | |
}); | |
} else { | |
std::unordered_set<Node*> mutable_set; | |
std::vector<NodePtr> vlist; | |
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs"); | |
DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) { | |
if (node->is_variable()) { | |
vlist.push_back(node); | |
} else if (fmutate_inputs.count(node->op())) { | |
for (uint32_t i : fmutate_inputs[node->op()](node->attrs)){ | |
mutable_set.insert(node->inputs[i].node.get()); | |
} | |
} | |
}); | |
for (const NodePtr& node : vlist) { | |
if ((option == kReadOnlyArgs && mutable_set.count(node.get()) == 0) || | |
(option == kAuxiliaryStates && mutable_set.count(node.get()) != 0)) { | |
ret.emplace_back(node); | |
} | |
} | |
} | |
return ret; | |
} | |
std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const { | |
std::vector<NodePtr> inputs = ListInputs(option); | |
std::vector<std::string> ret(inputs.size()); | |
for (size_t i = 0; i < inputs.size(); ++i) { | |
ret[i] = inputs[i]->attrs.name; | |
} | |
return ret; | |
} | |
std::vector<std::string> Symbol::ListOutputNames() const { | |
static auto& flist_ouputs = Op::GetAttr<FListOutputNames>("FListOutputNames"); | |
std::vector<std::string> ret; | |
for (auto &head : outputs) { | |
if (head.node->is_variable()) { | |
ret.push_back(head.node->attrs.name); | |
} else { | |
const std::string& hname = head.node->attrs.name; | |
std::string rname; | |
FListOutputNames fn = flist_ouputs.get(head.node->op(), nullptr); | |
if (fn != nullptr) { | |
rname = fn(head.node->attrs)[head.index]; | |
} else { | |
rname = "output"; | |
if (head.node->num_outputs() != 1) { | |
std::ostringstream os; | |
os << rname << head.index; | |
rname = os.str(); | |
} | |
} | |
if (hname.length() == 0) { | |
ret.push_back(std::move(rname)); | |
} else { | |
ret.push_back(hname + '_' + rname); | |
} | |
} | |
} | |
return ret; | |
} | |
// compositional logic | |
void Symbol::Compose(const array_view<const Symbol*>& args, | |
const std::unordered_map<std::string, const Symbol*>& kwargs, | |
const std::string& name) { | |
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames"); | |
static auto& fset_attrs = Op::GetAttr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose"); | |
CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed"; | |
// parameter check. | |
for (size_t i = 0; i < args.size(); ++i) { | |
CHECK_EQ(args[i]->outputs.size(), 1U) | |
<< "Argument " << i << " is a tuple, single value is required"; | |
} | |
for (const auto& kv : kwargs) { | |
CHECK_EQ(kv.second->outputs.size(), 1U) | |
<< "Keyword Argument " << kv.first << " is a tuple, single value is required"; | |
} | |
// assign new name | |
outputs[0].node->attrs.name = name; | |
// Atomic functor composition. | |
if (IsAtomic(outputs)) { | |
Node* n = outputs[0].node.get(); | |
uint32_t n_req = n->num_inputs(); | |
if (n_req != kVarg) { | |
n->inputs.resize(n_req); | |
CHECK_LE(args.size(), n_req) | |
<< "Incorrect number of arguments, requires " << n_req | |
<< ", provided " << args.size(); | |
for (size_t i = 0; i < args.size(); ++i) { | |
n->inputs[i] = args[i]->outputs[0]; | |
} | |
// switch to keyword argument matching | |
if (args.size() != n_req) { | |
FListInputNames fn = flist_inputs.get(n->op(), nullptr); | |
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs); | |
if (arg_names.size() != n_req) { | |
          LOG(FATAL) << "Not enough arguments to call operator " << outputs[0].node->op()->name;
} | |
size_t nmatched = 0; | |
for (size_t i = args.size(); i < n_req; ++i) { | |
auto it = kwargs.find(arg_names[i]); | |
if (it != kwargs.end() && it->first == arg_names[i]) { | |
n->inputs[i] = it->second->outputs[0]; | |
++nmatched; | |
} else { | |
n->inputs[i] = NodeEntry{ | |
CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0}; | |
// copy attribute of parent over automatically created variables | |
n->inputs[i].node->attrs.dict = n->attrs.dict; | |
} | |
} | |
if (nmatched != kwargs.size()) { | |
n->inputs.clear(); | |
std::vector<std::string> keys = GetKeys(kwargs); | |
array_view<std::string> view(dmlc::BeginPtr(arg_names) + args.size(), | |
dmlc::BeginPtr(arg_names) + arg_names.size()); | |
KeywordArgumentMismatch("Symbol.Compose", keys, view); | |
} | |
} | |
} else { | |
      CHECK_EQ(kwargs.size(), 0U) << "Variable length functions do not accept kwargs";
n->inputs.reserve(args.size()); | |
for (const Symbol* s : args) { | |
n->inputs.push_back(s->outputs[0]); | |
} | |
} | |
UpdateNodeVersion(n); | |
FSetInputVarAttrOnCompose fn = fset_attrs.get(n->op(), nullptr); | |
if (fn != nullptr) { | |
for (size_t i = 0; i < n->inputs.size(); ++i) { | |
if (n->inputs[i].node->is_variable()) { | |
fn(n->attrs, n->inputs[i].node, i); | |
} | |
} | |
} | |
} else { | |
// general composition | |
CHECK_EQ(args.size(), 0U) | |
<< "General composition only support kwargs for now"; | |
size_t nmatched = 0; | |
size_t arg_counter = 0; | |
std::unordered_map<Node *, const NodeEntry*> replace_map; | |
// replace map stores the existing replacement plan for arguments node | |
auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map] | |
(const NodePtr &node) { | |
if (node->is_variable()) { | |
if (arg_counter < args.size()) { | |
replace_map[node.get()] = &(args[arg_counter]->outputs[0]); | |
++arg_counter; | |
} else { | |
// match kwargs | |
auto kit = kwargs.find(node->attrs.name); | |
if (kit != kwargs.end()) { | |
replace_map[node.get()] = &(kit->second->outputs[0]); | |
++nmatched; | |
} | |
} | |
} | |
}; | |
DFSVisit(this->outputs, find_replace_map); | |
if (nmatched == kwargs.size() && arg_counter <= args.size()) { | |
std::vector<Node*> update_nodes; | |
std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan; | |
auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes] | |
(const NodePtr &node) { | |
        // visit all the children, find possible replacements
bool repl = false; | |
for (size_t i = 0; i < node->inputs.size(); ++i) { | |
NodeEntry *e = &(node->inputs[i]); | |
if (e->node->is_variable()) { | |
auto iter = replace_map.find(e->node.get()); | |
if (iter != replace_map.end()) { | |
replace_plan.push_back(std::make_pair(e, iter->second)); | |
repl = true; | |
} | |
} | |
} | |
if (repl) update_nodes.push_back(node.get()); | |
}; | |
DFSVisit(this->outputs, find_replace_plan); | |
for (const auto& kv : replace_plan) { | |
*(kv.first) = *(kv.second); | |
} | |
for (Node* n : update_nodes) { | |
UpdateNodeVersion(n); | |
} | |
} else { | |
std::vector<std::string> keys = GetKeys(kwargs); | |
std::vector<std::string> arg_names = ListInputNames(kAll); | |
array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_counter, | |
dmlc::BeginPtr(arg_names) + arg_names.size()); | |
      KeywordArgumentMismatch("Symbol.Compose", keys, view);
} | |
} | |
} | |
Symbol Symbol::operator () (const array_view<const Symbol*>& args, | |
const std::unordered_map<std::string, const Symbol*>& kwargs, | |
const std::string& name) const { | |
Symbol s = this->Copy(); | |
s.Compose(args, kwargs, name); | |
return s; | |
} | |
void Symbol::AddControlDeps(const Symbol& src) { | |
CHECK_EQ(outputs.size(), 1U) | |
      << "AddControlDeps only works for non-grouped symbol";
Node* n = outputs[0].node.get(); | |
for (const NodeEntry& sp : src.outputs) { | |
n->control_deps.push_back(sp.node); | |
} | |
} | |
Symbol Symbol::GetInternals() const { | |
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs"); | |
Symbol ret; | |
DFSVisit(this->outputs, [&ret](const NodePtr& node) { | |
Node* n = node.get(); | |
if (n->is_variable()) { | |
// grab version from variable. | |
VariableParam& param = nnvm::get<VariableParam>(n->attrs.parsed); | |
ret.outputs.emplace_back(NodeEntry{node, 0, param.version}); | |
} else { | |
uint32_t nout = n->num_outputs(); | |
if (fnum_vis_output.count(n->op())) { | |
nout = fnum_vis_output[n->op()](n->attrs); | |
} | |
for (uint32_t i = 0; i < nout; ++i) { | |
ret.outputs.emplace_back(NodeEntry{node, i, 0}); | |
} | |
} | |
}); | |
return ret; | |
} | |
void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs) { | |
Node* node = outputs[0].node.get(); | |
for (const NodeEntry& e : outputs) { | |
CHECK(node == e.node.get()) | |
<< "Symbol.SetAttrs only works for non-grouped symbol"; | |
} | |
for (const auto& kv : attrs) { | |
if (kv.first == "name") { | |
node->attrs.name = kv.second; | |
} else { | |
node->attrs.dict[kv.first] = kv.second; | |
} | |
} | |
if (node->op() != nullptr && node->op()->attr_parser != nullptr) { | |
node->op()->attr_parser(&(node->attrs)); | |
} | |
} | |
bool Symbol::GetAttr(const std::string& key, std::string* out) const { | |
Node* node = outputs[0].node.get(); | |
for (const NodeEntry& e : outputs) { | |
if (node != e.node.get()) return false; | |
} | |
if (key == "name") { | |
*out = node->attrs.name; | |
return true; | |
} | |
auto it = node->attrs.dict.find(key); | |
if (it == node->attrs.dict.end()) return false; | |
*out = it->second; | |
return true; | |
} | |
std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption option) const { | |
if (option == kRecursive) { | |
std::unordered_map<std::string, std::string> ret; | |
DFSVisit(this->outputs, [&ret](const NodePtr& n) { | |
for (const auto& it : n->attrs.dict) { | |
ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second; | |
} | |
}); | |
return ret; | |
} else { | |
return outputs[0].node->attrs.dict; | |
} | |
} | |
std::vector<std::tuple<std::string, std::string, std::string> > | |
Symbol::ListAttrsRecursive() const { | |
std::vector<std::tuple<std::string, std::string, std::string> > ret; | |
DFSVisit(this->outputs, [&ret](const NodePtr& n) { | |
for (const auto& it : n->attrs.dict) { | |
ret.emplace_back(std::make_tuple(n->attrs.name, it.first, it.second)); | |
} | |
}); | |
return ret; | |
} | |
Symbol Symbol::CreateFunctor(const Op* op, | |
std::unordered_map<std::string, std::string> attrs) { | |
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs"); | |
Symbol s; | |
NodePtr n = Node::Create(); | |
n->attrs.op = op; | |
n->attrs.dict = std::move(attrs); | |
if (n->op()->attr_parser != nullptr) { | |
n->op()->attr_parser(&(n->attrs)); | |
} | |
uint32_t nout = n->num_outputs(); | |
if (fnum_vis_output.count(n->op())) { | |
nout = fnum_vis_output[n->op()](n->attrs); | |
} | |
for (uint32_t i = 0; i < nout; ++i) { | |
s.outputs.emplace_back(NodeEntry{n, i, 0}); | |
} | |
return s; | |
} | |
Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) { | |
Symbol ret; | |
for (const auto &s : symbols) { | |
ret.outputs.insert(ret.outputs.end(), s.outputs.begin(), s.outputs.end()); | |
} | |
return ret; | |
} | |
Symbol Symbol::CreateVariable(const std::string& name) { | |
Symbol s; | |
s.outputs.emplace_back(NodeEntry{CreateVariableNode(name), 0, 0}); | |
return s; | |
} | |
} // namespace nnvm | |
//===== EXPANDED: ../nnvm/src/core/symbolic.cc ===== | |
//===== EXPANDING: ../nnvm/src/core/node.cc ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file node.cc | |
* \brief Graph node data structure. | |
*/ | |
namespace nnvm { | |
Node::~Node() { | |
if (inputs.size() != 0) { | |
// explicit deletion via DFS | |
    // this is used to avoid stack overflow caused by a long chain of deletions
std::vector<Node*> stack{this}; | |
std::vector<NodePtr> to_delete; | |
while (!stack.empty()) { | |
Node* n = stack.back(); | |
stack.pop_back(); | |
for (NodeEntry& e : n->inputs) { | |
if (e.node.unique()) { | |
stack.push_back(e.node.get()); | |
to_delete.emplace_back(std::move(e.node)); | |
} else { | |
e.node.reset(); | |
} | |
} | |
for (NodePtr& sp : n->control_deps) { | |
if (sp.unique()) { | |
stack.push_back(sp.get()); | |
} else { | |
sp.reset(); | |
} | |
} | |
n->inputs.clear(); | |
} | |
} | |
} | |
NodePtr Node::Create() { | |
return std::make_shared<Node>(); | |
} | |
} // namespace nnvm | |
//===== EXPANDED: ../nnvm/src/core/node.cc ===== | |
//===== EXPANDING: ../nnvm/src/core/pass.cc ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file pass.cc | |
* \brief Support for pass registry. | |
*/ | |
//===== EXPANDING: ../nnvm/include/nnvm/pass.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file pass.h | |
* \brief Pass that can be applied to a graph. | |
*/ | |
#ifndef NNVM_PASS_H_ | |
#define NNVM_PASS_H_ | |
namespace nnvm { | |
/*! | |
* \brief A PassFunction is an "Operator on Graph". | |
* It takes a source graph and return a graph that may or may | |
* not be the same as the input one. | |
* | |
* A pass function can either change the graph structure (thus, | |
* generating a new Graph), or add new attributes to the graph. | |
* | |
* \param src The graph to be transformed. | |
* \return The generated graph. | |
*/ | |
typedef std::function<Graph (Graph src)> PassFunction; | |
/*! | |
* \brief Apply a series of pass transformations on the input graph. | |
* \param src The graph to be transformed. | |
* \param passes A list of pass names to be applied. | |
* \return The transformed graph | |
*/ | |
Graph ApplyPasses(Graph src, | |
const std::vector<std::string>& passes); | |
/*! | |
* \brief Apply one pass to the graph. | |
* \param src The graph to be transformed. | |
* \param pass The name of pass to be applied. | |
* \return The transformed graph. | |
*/ | |
inline Graph ApplyPass(Graph src, const std::string& pass) { | |
return ApplyPasses(src, {pass}); | |
} | |
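/*!
 * Illustrative usage sketch (some_symbol is assumed to have been built
 * elsewhere): passes are looked up by name and applied in order.
 * "OrderMutation", registered later in this file, returns a new graph only
 * when mutation is present.
 * \code
 * Graph g;
 * g.outputs = some_symbol.outputs;
 * g = ApplyPass(std::move(g), "OrderMutation");
 * \endcode
 */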
/*! | |
 * \brief Registry entry for pass functions.
*/ | |
struct PassFunctionReg | |
: public dmlc::FunctionRegEntryBase<PassFunctionReg, | |
PassFunction> { | |
/*! | |
* \brief Whether the pass will change graph structure | |
* If this is false, the pass will only change attributes. | |
*/ | |
bool change_graph{false}; | |
/*! \brief dependencies on operator attributes */ | |
std::vector<std::string> op_attr_dependency; | |
/*! \brief dependencies on attributes in the graph */ | |
std::vector<std::string> graph_attr_dependency; | |
/*! \brief generated targets of graph attributes */ | |
std::vector<std::string> graph_attr_targets; | |
/*! | |
* \brief Set whether this pass will change graph structure. | |
* \param v If true, the pass will change graph structure. | |
* \return Reference to self. | |
*/ | |
PassFunctionReg& set_change_graph(bool v) { // NOLINT(*) | |
change_graph = v; | |
return *this; | |
} | |
/*! | |
* \brief Declare that this pass will generate the given graph attribute name | |
* once it is applied on the graph. | |
* \param attr_name Name of the graph attribute. | |
* \return Reference to self. | |
*/ | |
PassFunctionReg& provide_graph_attr(const std::string& attr_name) { // NOLINT(*) | |
graph_attr_targets.push_back(attr_name); | |
return *this; | |
} | |
/*! | |
* \brief Declare this pass requires the given operator attribute to be | |
* available before being applied on the graph. | |
* \param attr_name Name of the attribute. | |
* \return Reference to self. | |
*/ | |
PassFunctionReg& depend_op_attr(const std::string& attr_name) { // NOLINT(*) | |
op_attr_dependency.push_back(attr_name); | |
return *this; | |
} | |
/*! | |
* \brief Declare this pass requires the given graph attribute to be | |
* available before being applied on the graph. | |
* \param attr_name Name of the attribute. | |
* \return Reference to self. | |
*/ | |
PassFunctionReg& depend_graph_attr(const std::string& attr_name) { // NOLINT(*) | |
graph_attr_dependency.push_back(attr_name); | |
return *this; | |
} | |
}; | |
/*! | |
* \def NNVM_REGISTER_PASS | |
 * \brief Macro to register pass functions.
* | |
* \code | |
* // example of registering a shape inference pass | |
* NNVM_REGISTER_PASS(InferShape) | |
* .describe("Shape Inference function, generate graph attributes") | |
* .provide_graph_attr("data_shape") | |
* .depend_graph_attr("indexed_graph") | |
* .depend_op_attr("infer_shape") | |
* .set_body([](const Graph& g) { | |
* // shape inference logic | |
* }); | |
* \endcode | |
*/ | |
#define NNVM_REGISTER_PASS(name) \ | |
DMLC_REGISTRY_REGISTER(::nnvm::PassFunctionReg, PassFunctionReg, name) | |
} // namespace nnvm | |
#endif // NNVM_PASS_H_ | |
//===== EXPANDED: ../nnvm/include/nnvm/pass.h ===== | |
namespace dmlc { | |
// enable registry | |
DMLC_REGISTRY_ENABLE(nnvm::PassFunctionReg); | |
} // namespace dmlc | |
namespace nnvm { | |
const PassFunctionReg* FindPassDep(const std::string&attr_name) { | |
for (auto* r : dmlc::Registry<PassFunctionReg>::List()) { | |
for (auto& s : r->graph_attr_targets) { | |
if (s == attr_name) return r; | |
} | |
} | |
return nullptr; | |
} | |
Graph ApplyPasses(Graph g, | |
const std::vector<std::string>& pass) { | |
std::vector<const PassFunctionReg*> fpass; | |
for (auto& name : pass) { | |
auto* reg = dmlc::Registry<PassFunctionReg>::Find(name); | |
CHECK(reg != nullptr) | |
<< "Cannot find pass " << name << " in the registry"; | |
fpass.push_back(reg); | |
} | |
for (auto r : fpass) { | |
for (auto& dep : r->graph_attr_dependency) { | |
if (g.attrs.count(dep) == 0) { | |
auto* pass_dep = FindPassDep(dep); | |
std::string msg; | |
if (pass_dep != nullptr) { | |
msg = " The attribute is provided by pass " + pass_dep->name; | |
} | |
LOG(FATAL) << "Graph attr dependency " << dep | |
<< " is required by pass " << r->name | |
<< " but is not available " | |
<< msg; | |
} | |
} | |
g = r->body(std::move(g)); | |
} | |
return g; | |
} | |
} // namespace nnvm | |
//===== EXPANDED: ../nnvm/src/core/pass.cc ===== | |
//===== EXPANDING: ../nnvm/src/pass/gradient.cc ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
 * \file gradient.cc
 * \brief Passes that take the gradient of the graph.
 *  This code was adapted from the mxnet codebase by Min Lin.
*/ | |
namespace nnvm { | |
namespace pass { | |
namespace { | |
// default aggregate gradient function | |
// requires the operators __zero__ and __ewise_sum__ to be registered.
NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) { | |
if (v.size() == 1) { | |
return std::move(v[0]); | |
} else if (v.size() == 0) { | |
NodePtr zero_node = Node::Create(); | |
zero_node->attrs.op = Op::Get("__zero__"); | |
return NodeEntry{zero_node, 0, 0}; | |
} else { | |
NodePtr sum_node = Node::Create(); | |
sum_node->attrs.op = Op::Get("__ewise_sum__"); | |
sum_node->inputs = std::move(v); | |
return NodeEntry{sum_node, 0, 0}; | |
} | |
} | |
// helper entry | |
struct GradEntry { | |
#ifdef _MSC_VER | |
NodeEntry sum = NodeEntry{nullptr, 0, 0}; | |
#else | |
NodeEntry sum{nullptr, 0, 0}; | |
#endif | |
std::vector<NodeEntry> grads; | |
bool need_attr_hint{true}; | |
}; | |
Graph Gradient(Graph src) { | |
using nnvm::FGradient; | |
using MirrorFun = std::function<int (const Node& node)>; | |
using AttrHintFun = std::function<NodeEntry (const NodeEntry& src, const NodeEntry &like)>; | |
  CHECK_NE(src.attrs.count("grad_ys"), 0U)
      << "Gradient requires grad_ys to be present.";
  CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U)
      << "Gradient requires grad_ys_out_grad to be present.";
  CHECK_NE(src.attrs.count("grad_xs"), 0U)
      << "Gradient requires grad_xs to be present.";
const std::vector<NodeEntry>& ys = | |
src.GetAttr<std::vector<NodeEntry> >("grad_ys"); | |
const std::vector<NodeEntry>& ys_out_grad = | |
src.GetAttr<std::vector<NodeEntry> >("grad_ys_out_grad"); | |
const std::vector<NodeEntry>& xs = | |
src.GetAttr<std::vector<NodeEntry> >("grad_xs"); | |
using AggFun = std::function<NodeEntry (std::vector<NodeEntry>&& inputs)>; | |
AggFun agg_fun = DefaultAggregateGradient; | |
if (src.attrs.count("grad_aggregate_fun") != 0) { | |
agg_fun = src.GetAttr<AggFun>("grad_aggregate_fun"); | |
} | |
MirrorFun mirror_fun = nullptr; | |
if (src.attrs.count("grad_mirror_fun") != 0) { | |
mirror_fun = src.GetAttr<MirrorFun>("grad_mirror_fun"); | |
} | |
AttrHintFun attr_hint_fun = nullptr; | |
if (src.attrs.count("attr_hint_fun") != 0) { | |
attr_hint_fun = src.GetAttr<AttrHintFun>("attr_hint_fun"); | |
} | |
// topo sort | |
std::vector<NodePtr> topo_order; | |
std::unordered_map<Node*, std::vector<GradEntry> > output_grads; | |
DFSVisit(ys, [&](const NodePtr& node) { | |
if (output_grads.count(node.get()) == 0) { | |
output_grads[node.get()].resize(node->num_outputs()); | |
} | |
topo_order.push_back(node); | |
}); | |
CHECK_EQ(ys.size(), ys_out_grad.size()); | |
for (size_t i = 0; i < ys.size(); ++i) { | |
NodeEntry ograd = ys_out_grad[i]; | |
output_grads[ys[i].node.get()][ys[i].index].grads = { ograd }; | |
} | |
  // construct the mirror (re-computation) strategy to reduce memory if needed
std::unordered_map<Node*, NodePtr> mirror_map; | |
if (mirror_fun != nullptr) { | |
for (const NodePtr& n : topo_order) { | |
if (mirror_fun(*n)) { | |
NodePtr new_node = Node::Create(); | |
*new_node = *n; | |
new_node->attrs.name += "_mirror"; | |
for (auto& e : new_node->inputs) { | |
e.node = mirror_map.at(e.node.get()); | |
} | |
for (auto& n : new_node->control_deps) { | |
n = mirror_map.at(n.get()); | |
} | |
mirror_map[n.get()] = std::move(new_node); | |
} else { | |
mirror_map[n.get()] = n; | |
} | |
} | |
} | |
// traverse backward | |
static auto& grad_fun_map = Op::GetAttr<FGradient>("FGradient"); | |
static auto& finfer_shape = Op::GetAttr<FInferShape>("FInferShape"); | |
std::vector<NodeEntry> out_agg_grads; | |
for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) { | |
const NodePtr& ptr = *rit; | |
if (ptr->is_variable()) continue; | |
out_agg_grads.clear(); | |
auto& out_grad_vec = output_grads.at(ptr.get()); | |
for (uint32_t i = 0; i < out_grad_vec.size(); ++i) { | |
GradEntry& e = out_grad_vec[i]; | |
e.sum = agg_fun(std::move(e.grads)); | |
if (e.need_attr_hint && attr_hint_fun != nullptr) { | |
e.sum = attr_hint_fun(e.sum, NodeEntry{ptr, 0, i}); | |
} | |
out_agg_grads.push_back(e.sum); | |
} | |
if ((*rit)->inputs.size() != 0) { | |
NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get())); | |
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()]( | |
fwd_node, out_agg_grads); | |
CHECK_EQ((*rit)->inputs.size(), input_grads.size()) | |
<< "Gradient function not returning enough gradient"; | |
auto git = input_grads.begin(); | |
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) { | |
auto& ge = output_grads[it->node.get()][it->index]; | |
// if any of the backward op can do shape inference, the hint is not necessary. | |
if (finfer_shape.count(git->node->op())) { | |
ge.need_attr_hint = false; | |
} | |
ge.grads.emplace_back(std::move(*git)); | |
} | |
} | |
} | |
// take out the xs' grads | |
Graph ret; | |
ret.outputs.reserve(xs.size()); | |
for (const NodeEntry& e : xs) { | |
GradEntry& entry = output_grads[e.node.get()][e.index]; | |
    // aggregate the gradients if the sum has not been computed yet
if (entry.sum.node.get() == nullptr) { | |
entry.sum = agg_fun(std::move(entry.grads)); | |
if (entry.need_attr_hint && attr_hint_fun != nullptr) { | |
entry.sum = attr_hint_fun(entry.sum, e); | |
} | |
} | |
ret.outputs.emplace_back(std::move(entry.sum)); | |
} | |
return ret; | |
} | |
// register pass | |
NNVM_REGISTER_PASS(Gradient) | |
.describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]") | |
.set_body(Gradient) | |
.set_change_graph(true) | |
.depend_graph_attr("grad_ys") | |
.depend_graph_attr("grad_xs") | |
.depend_graph_attr("grad_ys_out_grad"); | |
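// Illustrative usage sketch (ys, xs and head_grads are assumed to be
// std::vector<NodeEntry> prepared by the caller): the pass reads the entries
// to differentiate from graph attributes, so they are stored before applying.
//
//   Graph fwd;
//   fwd.outputs = ys;
//   fwd.attrs["grad_ys"] = std::make_shared<any>(ys);
//   fwd.attrs["grad_xs"] = std::make_shared<any>(xs);
//   fwd.attrs["grad_ys_out_grad"] = std::make_shared<any>(head_grads);
//   Graph grad = ApplyPass(std::move(fwd), "Gradient");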
} // namespace | |
} // namespace pass | |
} // namespace nnvm | |
//===== EXPANDED: ../nnvm/src/pass/gradient.cc ===== | |
//===== EXPANDING: ../nnvm/src/pass/order_mutation.cc ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file order_mutation.cc | |
 * \brief Add control flow dependencies between nodes,
 *  to correctly order mutations and reads and resolve
 *  write-after-read and read-after-write problems.
*/ | |
namespace nnvm { | |
namespace pass { | |
namespace { | |
template<typename T> | |
inline T get_with_default(const std::unordered_map<Node*, T> &map, | |
Node* key, | |
const T& def) { | |
auto it = map.find(key); | |
if (it != map.end()) return it->second; | |
return def; | |
} | |
inline bool IsMutate(const std::vector<uint32_t>& mutate_inputs, uint32_t i) { | |
return std::binary_search(mutate_inputs.begin(), mutate_inputs.end(), i); | |
} | |
Graph OrderMutation(const Graph& src) { | |
std::unordered_map<Node*, std::vector<NodeEntry> > version_hist; | |
DFSVisit(src.outputs, [&version_hist](const NodePtr& n) { | |
for (const NodeEntry& e : n->inputs) { | |
if (e.node->is_variable()) { | |
if (e.version != 0 && version_hist.count(e.node.get()) == 0) { | |
version_hist[e.node.get()] = std::vector<NodeEntry>{}; | |
} | |
} | |
} | |
}); | |
  // no mutation happens, everything is fine.
if (version_hist.size() == 0) return src; | |
// start preparing for remapping the nodes. | |
std::unordered_map<Node*, NodePtr> old_new; | |
auto prepare = [&version_hist, &old_new] (const NodePtr& n) { | |
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs"); | |
std::vector<uint32_t> mutate_inputs; | |
if (!n->is_variable() && fmutate_inputs.count(n->op())) { | |
mutate_inputs = fmutate_inputs[n->op()](n->attrs); | |
} | |
std::sort(mutate_inputs.begin(), mutate_inputs.end()); | |
bool need_repl = false; | |
for (size_t i = 0; i < n->inputs.size(); ++i) { | |
const NodeEntry& e = n->inputs[i]; | |
if (e.node->is_variable()) { | |
if (e.version != 0) need_repl = true; | |
auto it = version_hist.find(e.node.get()); | |
if (it != version_hist.end()) { | |
std::vector<NodeEntry>& vec = it->second; | |
vec.emplace_back(NodeEntry{n, IsMutate(mutate_inputs, i), e.version}); | |
} | |
} else { | |
if (old_new.count(e.node.get()) != 0) need_repl = true; | |
} | |
} | |
for (const NodePtr& p : n->control_deps) { | |
if (old_new.count(p.get()) != 0) need_repl = true; | |
} | |
if (need_repl) { | |
NodePtr np = Node::Create(); | |
np->attrs = n->attrs; | |
old_new[n.get()] = std::move(np); | |
} | |
}; | |
DFSVisit(src.outputs, prepare); | |
// comparator of history entry | |
auto comparator = [](const NodeEntry& a, const NodeEntry &b) { | |
if (a.version < b.version) return true; | |
if (a.version > b.version) return false; | |
return a.index > b.index; | |
}; | |
for (auto &kv : version_hist) { | |
std::sort(kv.second.begin(), kv.second.end(), comparator); | |
} | |
// copy the nodes, as well as add control deps | |
for (auto &kv : old_new) { | |
// copy the nodes | |
for (const NodeEntry& e : kv.first->inputs) { | |
auto it = old_new.find(e.node.get()); | |
if (it != old_new.end()) { | |
kv.second->inputs.emplace_back(NodeEntry{it->second, e.index, e.version}); | |
} else { | |
kv.second->inputs.push_back(e); | |
} | |
} | |
for (const NodePtr& p : kv.first->control_deps) { | |
kv.second->control_deps.emplace_back( | |
get_with_default(old_new, p.get(), p)); | |
} | |
// add control deps | |
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs"); | |
std::vector<uint32_t> mutate_inputs; | |
if (fmutate_inputs.count(kv.first->op())) { | |
mutate_inputs = fmutate_inputs[kv.first->op()](kv.first->attrs); | |
} | |
std::sort(mutate_inputs.begin(), mutate_inputs.end()); | |
for (size_t i = 0; i < kv.first->inputs.size(); ++i) { | |
const NodeEntry& e = kv.first->inputs[i]; | |
if (e.node->is_variable() && version_hist.count(e.node.get()) != 0) { | |
std::vector<NodeEntry>& vec = version_hist.at(e.node.get()); | |
auto it = std::lower_bound(vec.begin(), vec.end(), | |
NodeEntry{nullptr, 1, e.version}, | |
comparator); | |
if (IsMutate(mutate_inputs, i)) { | |
int read_dep = 0; | |
while (it != vec.begin()) { | |
--it; | |
if (it->index != 0) break; | |
++read_dep; | |
// depend on previous read | |
kv.second->control_deps.push_back( | |
get_with_default(old_new, it->node.get(), it->node)); | |
} | |
if (read_dep == 0 && it->index != 0) { | |
// depend on last write | |
kv.second->control_deps.push_back( | |
get_with_default(old_new, it->node.get(), it->node)); | |
} | |
} else { | |
// depend on last write | |
if (it->index != 0) { | |
kv.second->control_deps.push_back( | |
get_with_default(old_new, it->node.get(), it->node)); | |
} | |
} | |
} | |
} | |
} | |
Graph ret; | |
for (const NodeEntry &e : src.outputs) { | |
ret.outputs.emplace_back(NodeEntry{ | |
get_with_default(old_new, e.node.get(), e.node), e.index, e.version}); | |
} | |
return ret; | |
} | |
NNVM_REGISTER_PASS(OrderMutation) | |
.describe("Return a new graph that adds control dependencies, "\ | |
"to order the mutation and reads if mutation exists.") | |
.set_body(OrderMutation) | |
.set_change_graph(true); | |
} // namespace | |
} // namespace pass | |
} // namespace nnvm | |
//===== EXPANDED: ../nnvm/src/pass/order_mutation.cc ===== | |
//===== EXPANDING: ../nnvm/src/pass/plan_memory.cc ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file plan_memory.cc | |
* \brief Assign memory tag to each of the data entries. | |
*/ | |
//===== EXPANDING: ../nnvm/include/nnvm/graph_attr_types.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file graph_attr_types.h | |
* \brief Data structures that can appear in graph attributes. | |
*/ | |
#ifndef NNVM_GRAPH_ATTR_TYPES_H_ | |
#define NNVM_GRAPH_ATTR_TYPES_H_ | |
namespace nnvm { | |
/*! | |
* \brief The result holder of JSON serializer | |
* | |
* \note Stored under ret.attrs["json"], provided by Pass "SaveJSON" | |
* \code | |
* Graph ret = ApplyPass(src_graph, "SaveJSON"); | |
 *  const JSONString& json = ret.GetAttr<JSONString>("json");
* \endcode | |
*/ | |
using JSONString = std::string; | |
/*! | |
* \brief The result holder of shape of each NodeEntry in the graph. | |
* \note Stored under graph.attrs["shape"], provided by Pass "InferShape" | |
* | |
* \code | |
* Graph g = ApplyPass(src_graph, "InferShape"); | |
* const ShapeVector& shapes = g.GetAttr<ShapeVector>("shape"); | |
* // get shape by entry id | |
* TShape entry_shape = shapes[g.indexed_graph().entry_id(my_entry)]; | |
* \endcode | |
* | |
* \sa FInferShape | |
*/ | |
using ShapeVector = std::vector<TShape>; | |
/*! | |
* \brief The result holder of type of each NodeEntry in the graph. | |
* \note Stored under graph.attrs["dtype"], provided by Pass "InferType" | |
* | |
* \code | |
* Graph g = ApplyPass(src_graph, "InferType"); | |
* const DTypeVector& types = g.GetAttr<DTypeVector>("dtype"); | |
 *  // get type by entry id
 *  int entry_type = types[g.indexed_graph().entry_id(my_entry)];
* \endcode | |
* | |
* \sa FInferType | |
*/ | |
using DTypeVector = std::vector<int>; | |
/*! | |
* \brief The result holder of device of each operator in the graph. | |
* \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice" | |
* | |
* \code | |
* Graph g = ApplyPass(src_graph, "PlaceDevice"); | |
 *  const DeviceVector& device = g.GetAttr<DeviceVector>("device");
* // get device by node_id | |
* int device_type = device[g.indexed_graph().node_id(my_node)]; | |
* \endcode | |
*/ | |
using DeviceVector = std::vector<int>; | |
/*! | |
 * \brief The assignment map from device group name to device id.
* | |
* \note Stored under graph.attrs["device_assign_map"], needed by Pass "PlaceDevice" | |
* -1 means unknown device | |
*/ | |
using DeviceAssignMap = std::unordered_map<std::string, int>; | |
/*! | |
* \brief The result holder of storage id of each NodeEntry in the graph. | |
* | |
 * \note Stored under graph.attrs["storage_id"], provided by Pass "PlanMemory"
* Storage id is a continuous integer. | |
* If the storage id is -1 then the storage is not assigned. | |
* | |
* \code | |
* Graph g = ApplyPass(src_graph, "PlanMemory"); | |
 *  const StorageVector& storage = g.GetAttr<StorageVector>("storage_id");
* // get storage id by entry | |
* int storage_id = storage[g.indexed_graph().entry_id(my_entry)]; | |
* \endcode | |
*/ | |
using StorageVector = std::vector<int>; | |
} // namespace nnvm | |
#endif // NNVM_GRAPH_ATTR_TYPES_H_ | |
//===== EXPANDED: ../nnvm/include/nnvm/graph_attr_types.h ===== | |
//===== EXPANDING: ../nnvm/src/pass/graph_algorithm.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file graph_algorithm.h | |
 * \brief This header contains graph algorithms on the indexed graph.
 *  It is used to compute information such as whether two
 *  operations can run in parallel, and it helps with allocation.
*/ | |
#ifndef NNVM_PASS_GRAPH_ALGORITHM_H_ | |
#define NNVM_PASS_GRAPH_ALGORITHM_H_ | |
namespace nnvm { | |
namespace pass { | |
/*! | |
* \brief Find best path in the DAG, with reward defined | |
* by sum of reward of each node along the path. | |
 * \param graph the original indexed graph, whose nodes are in topological order.
 * \param node_reward the reward of each node.
* \param path the output path of nodes. | |
* \return the total reward of best path. | |
*/ | |
inline uint32_t FindBestPath( | |
const IndexedGraph& graph, | |
const std::vector<uint32_t>& node_reward, | |
std::vector<uint32_t>* path) { | |
const uint32_t num_nodes = static_cast<uint32_t>(graph.num_nodes()); | |
CHECK_EQ(num_nodes, node_reward.size()); | |
std::vector<uint32_t> best_reward(node_reward.size(), 0); | |
std::vector<uint32_t> next_node(node_reward.size(), num_nodes); | |
uint32_t best_solution = 0, best_start_node = 0; | |
// traverse in reverse topo order | |
for (uint32_t i = static_cast<uint32_t>(graph.num_nodes()); i != 0; --i) { | |
const uint32_t nid = i - 1; | |
best_reward[nid] += node_reward[nid]; | |
if (best_reward[nid] > best_solution) { | |
best_solution = best_reward[nid]; | |
best_start_node = nid; | |
} | |
for (const auto& e : graph[nid].inputs) { | |
const uint32_t prev = e.node_id; | |
if (best_reward[nid] > best_reward[prev]) { | |
best_reward[prev] = best_reward[nid]; | |
next_node[prev] = nid; | |
} | |
} | |
} | |
path->clear(); | |
uint32_t reward = 0; | |
for (uint32_t nid = best_start_node; nid < num_nodes; nid = next_node[nid]) { | |
path->push_back(nid); reward += node_reward[nid]; | |
} | |
CHECK_EQ(reward, best_solution); | |
return best_solution; | |
} | |
/*! | |
* \brief Color the nodes in the graph into index. | |
 *  The coloring algorithm tries to assign node groups
 *  such that nodes in the same group cannot run in parallel.
* | |
* \param graph the original indexed graph. | |
* \param node_importance The importance of the node | |
* \param max_ncolor maximum number of colors allowed. | |
* \param color the color index of each of the node. | |
* \return the total number of colors. | |
*/ | |
inline uint32_t ColorNodeGroup( | |
const IndexedGraph &graph, | |
std::vector<uint32_t> node_importance, | |
uint32_t max_ncolor, | |
std::vector<uint32_t> *color) { | |
CHECK_NE(max_ncolor, 0U); | |
CHECK_EQ(graph.num_nodes(), node_importance.size()); | |
color->clear(); | |
color->resize(graph.num_nodes(), max_ncolor); | |
uint32_t cindex; | |
// greedy algorithm, every time | |
// find a path with best reward and assign a new color | |
// All the nodes in the path cannot run in parallel. | |
for (cindex = 0; cindex < max_ncolor - 1; ++cindex) { | |
std::vector<uint32_t> path; | |
uint32_t reward = FindBestPath(graph, node_importance, &path); | |
if (reward == 0) break; | |
for (uint32_t nid : path) { | |
if (node_importance[nid] != 0) { | |
CHECK_EQ(color->at(nid), max_ncolor); | |
color->at(nid) = cindex; | |
// make the importance 0 after color is decided. | |
node_importance[nid] = 0; | |
} | |
} | |
} | |
  // assign cindex to the rest of the nodes
for (uint32_t i = 0; i < graph.num_nodes(); ++i) { | |
if (color->at(i) == max_ncolor) { | |
color->at(i) = cindex; | |
} | |
} | |
return cindex + 1; | |
} | |
} // namespace pass | |
} // namespace nnvm | |
#endif // NNVM_PASS_GRAPH_ALGORITHM_H_ | |
//===== EXPANDED: ../nnvm/src/pass/graph_algorithm.h ===== | |
namespace nnvm { | |
namespace pass { | |
namespace { | |
// simple graph based allocator. | |
class GraphAllocator { | |
public: | |
// storage id equals integer. | |
using StorageID = int; | |
// bad storage id | |
static const StorageID kBadStorageID = -1; | |
// external storage id | |
static const StorageID kExternalStorageID = -2; | |
// request a free storage | |
StorageID Request(int dev_id, int dtype, TShape shape, uint32_t node_id) { | |
if (shape.ndim() == 0) return kBadStorageID; | |
// search memory block in [size / match_range_, size * match_range_) | |
// TODO(tqchen) add size of the dtype, assume 4 bytes for now | |
size_t size = shape.Size() * 4; | |
if (match_range_ == 0) return this->Alloc(dev_id, size); | |
auto begin = free_.lower_bound(size / match_range_); | |
auto mid = free_.lower_bound(size); | |
auto end = free_.upper_bound(size * match_range_); | |
// search for memory blocks larger than requested | |
for (auto it = mid; it != end; ++it) { | |
StorageEntry *e = it->second; | |
if (e->device_id != dev_id) continue; | |
if (node_color_.size() != 0 && | |
node_color_[e->released_by_node] != node_color_[node_id]) continue; | |
      // Use exact matching strategy
      e->max_bytes = std::max(size, e->max_bytes);
      // found a match, erase it from the free list and return
free_.erase(it); | |
return e->id; | |
} | |
// then search for memory blocks smaller than requested space | |
for (auto it = mid; it != begin;) { | |
--it; | |
StorageEntry *e = it->second; | |
if (e->device_id != dev_id) continue; | |
if (node_color_.size() != 0 && | |
node_color_[e->released_by_node] != node_color_[node_id]) continue; | |
      // Use exact matching strategy
      e->max_bytes = std::max(size, e->max_bytes);
      // found a match, erase it from the free list and return
free_.erase(it); | |
return e->id; | |
} | |
// cannot find anything return a new one. | |
return this->Alloc(dev_id, size); | |
} | |
// release a memory space. | |
void Release(StorageID id, uint32_t node_id) { | |
CHECK_NE(id, kBadStorageID); | |
if (id == kExternalStorageID) return; | |
StorageEntry *e = data_[id].get(); | |
e->released_by_node = node_id; | |
free_.insert({e->max_bytes, e}); | |
} | |
  // total number of bytes allocated
size_t TotalAllocBytes() const { | |
size_t total = 0; | |
for (auto &p : data_) { | |
total += p->max_bytes; | |
} | |
return total; | |
} | |
// constructor | |
explicit GraphAllocator(const IndexedGraph* idx) : idx_(idx) { | |
this->Init(dmlc::GetEnv("NNVM_EXEC_MATCH_RANGE", 16), | |
dmlc::GetEnv("NNVM_EXEC_NUM_TEMP", 1)); | |
} | |
private: | |
// initialize the graph allocator | |
void Init(size_t match_range, uint32_t num_match_color) { | |
match_range_ = match_range; | |
num_match_color_ = num_match_color; | |
if (num_match_color_ > 1) { | |
std::vector<uint32_t> importance(idx_->num_nodes(), 0); | |
for (uint32_t nid = 0; nid < idx_->num_nodes(); ++nid) { | |
if ((*idx_)[nid].source->is_variable()) continue; | |
importance[nid] = 1; | |
} | |
num_match_color_ = pass::ColorNodeGroup( | |
*idx_, importance, num_match_color_, &node_color_); | |
} | |
} | |
StorageID Alloc(int dev_id, size_t size) { | |
StorageID id = static_cast<StorageID>(data_.size()); | |
std::unique_ptr<StorageEntry> ptr(new StorageEntry()); | |
ptr->id = id; | |
ptr->device_id = dev_id; | |
ptr->max_bytes = size; | |
data_.emplace_back(std::move(ptr)); | |
return id; | |
} | |
// internal storage entry | |
struct StorageEntry { | |
// the id of the entry. | |
StorageID id; | |
// the device id of the storage. | |
int device_id; | |
// maximum size of storage requested. | |
size_t max_bytes{0}; | |
// node index that released it last time | |
uint32_t released_by_node{0}; | |
}; | |
// scale used for rough match | |
size_t match_range_; | |
// whether use color based match algorithm | |
uint32_t num_match_color_{1}; | |
// the size of each dtype | |
std::vector<size_t> dtype_size_dict_; | |
// free list of storage entry | |
std::multimap<size_t, StorageEntry*> free_; | |
// all the storage resources available | |
std::vector<std::unique_ptr<StorageEntry> > data_; | |
// color of nodes in the graph, used for auxiliary policy making. | |
std::vector<uint32_t> node_color_; | |
// internal indexed graph | |
const IndexedGraph* idx_; | |
}; | |
// function to plan memory | |
Graph PlanMemory(Graph ret) { | |
// setup ref counter | |
const IndexedGraph& idx = ret.indexed_graph(); | |
static auto& fignore_inputs = Op::GetAttr<FIgnoreInputs>("FIgnoreInputs"); | |
// reference counter of each node | |
std::vector<uint32_t> ref_count(idx.num_node_entries(), 0); | |
// step 1: initialize reference count | |
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { | |
const auto& inode = idx[nid]; | |
if (inode.source->is_variable()) continue; | |
for (const auto& e : inode.inputs) { | |
++ref_count[idx.entry_id(e)]; | |
} | |
    // no dataflow dependency is needed for inputs that are ignored;
    // revoke the dependency counter for them.
if (fignore_inputs.count(inode.source->op()) != 0) { | |
auto ignore_inputs = fignore_inputs[inode.source->op()](inode.source->attrs); | |
for (uint32_t i : ignore_inputs) { | |
--ref_count[idx.entry_id(inode.inputs[i])]; | |
} | |
} | |
} | |
for (const auto& e : idx.outputs()) { | |
++ref_count[idx.entry_id(e)]; | |
} | |
// step 2: allocate memory. | |
StorageVector storage; | |
if (ret.attrs.count("storage") != 0) { | |
storage = ret.MoveCopyAttr<StorageVector>("storage"); | |
} else { | |
storage.resize(idx.num_node_entries(), -1); | |
} | |
std::vector<int> storage_inplace_index(idx.num_node_entries(), -1); | |
const ShapeVector& shape_vec = ret.GetAttr<ShapeVector>("shape"); | |
const DTypeVector& dtype_vec = ret.GetAttr<DTypeVector>("dtype"); | |
const DeviceVector* device_vec = nullptr; | |
static auto& finplace_option = Op::GetAttr<FInplaceOption>("FInplaceOption"); | |
if (ret.attrs.count("device") != 0) { | |
device_vec = &(ret.GetAttr<DeviceVector>("device")); | |
} | |
// the allocator. | |
GraphAllocator allocator(&idx); | |
// number of entries that are not statically allocated. | |
size_t num_not_allocated = 0; | |
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { | |
const auto& inode = idx[nid]; | |
if (inode.source->is_variable()) continue; | |
// check inplace option | |
if (finplace_option.count(inode.source->op()) != 0) { | |
auto inplace_pairs = finplace_option[inode.source->op()](inode.source->attrs); | |
for (auto& kv : inplace_pairs) { | |
uint32_t eid_out = idx.entry_id(nid, kv.second); | |
uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]); | |
if (ref_count[eid_in] == 1 && | |
ref_count[eid_out] != 0 && | |
storage[eid_out] == GraphAllocator::kBadStorageID && | |
storage[eid_in] != GraphAllocator::kBadStorageID && | |
shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && | |
dtype_vec[eid_out] == dtype_vec[eid_in]) { | |
// inplace optimization | |
storage[eid_out] = storage[eid_in]; | |
ref_count[eid_in] = 0; | |
storage_inplace_index[eid_out] = kv.first; | |
} | |
} | |
} | |
// normal allocation | |
const int dev_id = (device_vec != nullptr) ? device_vec->at(nid) : 0; | |
// allocate output | |
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { | |
uint32_t eid = idx.entry_id(nid, index); | |
if (storage[eid] == GraphAllocator::kBadStorageID) { | |
storage[eid] = allocator.Request(dev_id, dtype_vec[eid], shape_vec[eid], nid); | |
} | |
} | |
    // check if certain inputs are ignored. | |
std::vector<uint32_t> ignore_inputs; | |
if (fignore_inputs.count(inode.source->op()) != 0) { | |
ignore_inputs = fignore_inputs[inode.source->op()](inode.source->attrs); | |
std::sort(ignore_inputs.begin(), ignore_inputs.end()); | |
} | |
// then free inputs | |
for (size_t i = 0; i < inode.inputs.size(); ++i) { | |
// ref counter of ignored input is already decreased. | |
if (std::binary_search(ignore_inputs.begin(), ignore_inputs.end(), i)) continue; | |
const auto& e = inode.inputs[i]; | |
uint32_t eid = idx.entry_id(e); | |
      // ref_count == 0 means the entry has been taken by an inplace op | |
if (ref_count[eid] == 0) continue; | |
      // if we decrease it to zero, the entry is ready to be released | |
--ref_count[eid]; | |
if (ref_count[eid] == 0 && storage[eid] != GraphAllocator::kBadStorageID) { | |
allocator.Release(storage[eid], nid); | |
} | |
} | |
    // check if there are outputs that can be freed immediately | |
    // these outputs are not referenced by any operator. | |
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { | |
uint32_t eid = idx.entry_id(nid, index); | |
if (ref_count[eid] == 0 && storage[eid] != GraphAllocator::kBadStorageID) { | |
allocator.Release(storage[eid], nid); | |
// use -2 to indicate that the node was never touched. | |
storage_inplace_index[eid] = -2; | |
} | |
if (storage[eid] == GraphAllocator::kBadStorageID) { | |
++num_not_allocated; | |
} | |
} | |
} | |
ret.attrs["storage_id"] = std::make_shared<any>(std::move(storage)); | |
ret.attrs["storage_inplace_index"] = std::make_shared<any>(std::move(storage_inplace_index)); | |
ret.attrs["storage_allocated_bytes"] = std::make_shared<any>(allocator.TotalAllocBytes()); | |
ret.attrs["storage_num_not_allocated"] = std::make_shared<any>(num_not_allocated); | |
return ret; | |
} | |
NNVM_REGISTER_PASS(PlanMemory) | |
.describe("Plan the memory allocation of each node entries.") | |
.set_body(PlanMemory) | |
.set_change_graph(false) | |
.depend_graph_attr("dtype") | |
.depend_graph_attr("shape") | |
.provide_graph_attr("storage_id") | |
.provide_graph_attr("storage_inplace_index"); | |
} // namespace | |
} // namespace pass | |
} // namespace nnvm | |
//===== EXPANDED: ../nnvm/src/pass/plan_memory.cc ===== | |
//===== EXPANDING: ../nnvm/src/pass/infer_shape_type.cc ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
 * \file infer_shape_type.cc | |
 * \brief Infer the shapes given existing information. | |
*/ | |
namespace nnvm { | |
namespace pass { | |
namespace { | |
template<typename AttrType, typename IsNone, typename FDefault> | |
Graph InferAttr(Graph &&ret, | |
const AttrType empty_val, | |
const char* infer_name, | |
const char* input_name, | |
const char* attr_key_name, | |
const char* attr_name, | |
const char* unknown_name, | |
IsNone fis_none, | |
FDefault fdefault) { | |
using AttrVector = std::vector<AttrType>; | |
const IndexedGraph& idx = ret.indexed_graph(); | |
static auto& finfer_shape = | |
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name); | |
static auto& is_backward = | |
Op::GetAttr<TIsBackward>("TIsBackward"); | |
// gradient function, used to get node correspondence. | |
static auto& fgrad = | |
Op::GetAttr<FGradient>("FGradient"); | |
  // result attribute vector (one slot per node entry) | |
AttrVector rshape; | |
if (ret.attrs.count(attr_name) != 0) { | |
rshape = ret.MoveCopyAttr<AttrVector>(attr_name); | |
} else { | |
rshape.resize(idx.num_node_entries(), empty_val); | |
} | |
if (ret.attrs.count(input_name) != 0) { | |
const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name); | |
CHECK_LE(shape_args.size(), idx.input_nodes().size()) | |
<< "More provided shapes than number of arguments."; | |
for (size_t i = 0; i < shape_args.size(); ++i) { | |
rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i]; | |
} | |
// erase the provided arguments | |
ret.attrs.erase(input_name); | |
} | |
std::string shape_attr_key; | |
if (ret.attrs.count(attr_key_name) != 0) { | |
shape_attr_key = ret.GetAttr<std::string>(attr_key_name); | |
// erase the provided arguments | |
ret.attrs.erase(attr_key_name); | |
} | |
// Temp space for shape inference. | |
std::vector<AttrType> ishape, oshape; | |
// inference step function for nid | |
auto infer_step = [&](uint32_t nid, bool last_iter) { | |
const auto& inode = idx[nid]; | |
const uint32_t num_inputs = inode.inputs.size(); | |
const uint32_t num_outputs = inode.source->num_outputs(); | |
if (inode.source->is_variable()) { | |
// Variable node. No operator. Only one output entry. | |
CHECK(inode.source->op() == nullptr); | |
CHECK_EQ(num_outputs, 1U); | |
const uint32_t out_ent_id = idx.entry_id(nid, 0); | |
if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) { | |
auto it = inode.source->attrs.dict.find(shape_attr_key); | |
if (it != inode.source->attrs.dict.end()) { | |
std::istringstream is(it->second); | |
CHECK(is >> rshape[out_ent_id]) << "Invalid attribute"; | |
} | |
} | |
} else if (is_backward.get(inode.source->op(), false)) { | |
CHECK_GE(inode.control_deps.size(), 1U) | |
<< "BackwardOp need to have control_deps to its forward op"; | |
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; | |
NodePtr fwd_ptr = inode.source->control_deps[0]; | |
// use gradient function to find out the correspondence. | |
std::vector<NodeEntry> ograd(fwd_ptr->num_outputs()); | |
for (size_t i = 0; i < ograd.size(); ++i) { | |
ograd[i].index = static_cast<uint32_t>(i); | |
} | |
// input gradient list | |
auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd); | |
const Op* backward_op = inode.source->op(); | |
const Node* igrad_node = nullptr; | |
      // Input gradient assignment | |
for (size_t i = 0; i < igrad.size(); ++i) { | |
if (igrad[i].node->op() == backward_op) { | |
uint32_t eid = idx.entry_id(nid, igrad[i].index); | |
if (fis_none(rshape[eid])) { | |
rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])]; | |
} else { | |
CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])]) | |
<< "Backward shape inconsistent with the forward shape"; | |
} | |
if (igrad_node == nullptr) { | |
igrad_node = igrad[i].node.get(); | |
} else { | |
CHECK(igrad_node == igrad[i].node.get()); | |
} | |
} | |
} | |
// out grad entries | |
for (size_t i = 0; i < igrad_node->inputs.size(); ++i) { | |
const NodeEntry& e = igrad_node->inputs[i]; | |
if (e.node == nullptr) { | |
uint32_t eid = idx.entry_id(inode.inputs[i]); | |
if (fis_none(rshape[eid])) { | |
rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)]; | |
} | |
} | |
} | |
} else { | |
bool forward_known = true; | |
// Forward operator inference. | |
ishape.resize(num_inputs, empty_val); | |
for (uint32_t i = 0; i < ishape.size(); ++i) { | |
ishape[i] = rshape[idx.entry_id(inode.inputs[i])]; | |
if (fis_none(ishape[i])) forward_known = false; | |
} | |
oshape.resize(num_outputs, empty_val); | |
for (uint32_t i = 0; i < oshape.size(); ++i) { | |
oshape[i] = rshape[idx.entry_id(nid, i)]; | |
if (fis_none(oshape[i])) forward_known = false; | |
} | |
auto finfer = finfer_shape.get(inode.source->op(), fdefault); | |
if (!forward_known) { | |
if (finfer != nullptr) { | |
// Call inference function of the operator. | |
try { | |
forward_known = finfer(inode.source->attrs, &ishape, &oshape); | |
} catch (const std::exception& e) { | |
throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what()); | |
} | |
} else { | |
CHECK(!last_iter) | |
<< "Attribute " << infer_name | |
<< " is not registed by op " << inode.source->op()->name | |
<< " we are not able to complete the inference because of this"; | |
} | |
} | |
// Save to the result map. | |
for (uint32_t i = 0; i < num_inputs; ++i) { | |
rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; | |
} | |
for (uint32_t i = 0; i < num_outputs; ++i) { | |
rshape[idx.entry_id(nid, i)] = oshape[i]; | |
} | |
} | |
}; | |
size_t last_num_unknown; | |
size_t num_unknown = rshape.size(); | |
int i = 0; | |
do { | |
if (i % 2 == 0) { | |
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { | |
infer_step(nid, false); | |
} | |
} else { | |
// backward inference | |
for (uint32_t i = idx.num_nodes(); i != 0; --i) { | |
infer_step(i - 1, false); | |
} | |
} | |
last_num_unknown = num_unknown; | |
num_unknown = 0; | |
for (size_t j = 0; j < idx.num_node_entries(); ++j) { | |
if (fis_none(rshape[j])) { | |
++num_unknown; | |
} | |
} | |
++i; | |
} while (num_unknown > 0 && last_num_unknown > num_unknown); | |
// set the shapes | |
ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape)); | |
  // number of node entries whose attribute is still unknown. | |
ret.attrs[unknown_name] = std::make_shared<any>(num_unknown); | |
return ret; | |
} | |
NNVM_REGISTER_PASS(InferShape) | |
.describe("Infer the shape of each node entries.") | |
.set_body([](Graph ret) { | |
return InferAttr<TShape>( | |
std::move(ret), TShape(), | |
"FInferShape", "shape_inputs", "shape_attr_key", | |
"shape", "shape_num_unknown_nodes", | |
[](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; }, | |
nullptr); | |
}) | |
.set_change_graph(false) | |
.provide_graph_attr("shape"); | |
// inference function that forces all entries to share the same type | |
inline bool SameType(const NodeAttrs& attrs, | |
std::vector<int> *iattr, | |
std::vector<int> *oattr) { | |
int def_v = -1; | |
for (int v : *oattr) { | |
if (v != -1) { | |
def_v = v; break; | |
} | |
} | |
if (def_v == -1) { | |
for (int v : *iattr) { | |
if (v != -1) { | |
def_v = v; break; | |
} | |
} | |
} | |
if (def_v == -1) return false; | |
for (int& v : *oattr) { | |
v = def_v; | |
} | |
for (int& v : *iattr) { | |
v = def_v; | |
} | |
return true; | |
} | |
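// Worked example (illustrative): with *oattr = {-1} and *iattr = {-1, 0}, | |
// no output type is known, so def_v is taken from the first known input (0); | |
// every entry in both vectors then becomes 0 and the function returns true. | |
// If every entry is -1 it returns false and inference is retried later. | |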
NNVM_REGISTER_PASS(InferType) | |
.describe("Infer the dtype of each node entries.") | |
.set_body([](Graph ret) { | |
return InferAttr<int>( | |
std::move(ret), -1, | |
"FInferType", "dtype_inputs", "dtype_attr_key", | |
"dtype", "dtype_num_unknown_nodes", | |
[](const int t) { return t == -1; }, | |
SameType); | |
}) | |
.set_change_graph(false) | |
.provide_graph_attr("dtype"); | |
DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape); | |
DMLC_JSON_ENABLE_ANY(DTypeVector, list_int); | |
DMLC_JSON_ENABLE_ANY(size_t, size_t); | |
} // namespace | |
} // namespace pass | |
} // namespace nnvm | |
//===== EXPANDED: ../nnvm/src/pass/infer_shape_type.cc ===== | |
//===== EXPANDING: ../nnvm/src/pass/place_device.cc ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file place_device.cc | |
 * \brief Infer the device of each operator given known information. | |
 * Insert a copy node automatically when there is a cross-device reference. | |
*/ | |
namespace nnvm { | |
namespace pass { | |
namespace { | |
// simple logic to place devices according to the device_group hint | |
// insert a copy node when a cross-device reference occurs | |
Graph PlaceDevice(Graph src) { | |
CHECK(src.attrs.count("device_group_attr_key")) | |
<< "Need graph attribute \"device_group_attr_key\" in PlaceDevice"; | |
CHECK(src.attrs.count("device_assign_map")) | |
<< "Need graph attribute \"device_assign_map\" in PlaceDevice"; | |
CHECK(src.attrs.count("device_copy_op")) | |
<< "Need graph attribute \"device_copy_op\" in PlaceDevice"; | |
std::string device_group_attr_key = src.GetAttr<std::string>("device_group_attr_key"); | |
const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op")); | |
auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map"); | |
const IndexedGraph& idx = src.indexed_graph(); | |
static auto& is_backward = | |
Op::GetAttr<TIsBackward>("TIsBackward"); | |
DeviceVector device; | |
  // copy-on-write semantics | |
if (src.attrs.count("device") != 0) { | |
device = src.MoveCopyAttr<DeviceVector>("device"); | |
CHECK_EQ(device.size(), idx.num_nodes()); | |
} else { | |
device.resize(idx.num_nodes(), -1); | |
} | |
// forward pass | |
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { | |
const auto& inode = idx[nid]; | |
auto it = inode.source->attrs.dict.find(device_group_attr_key); | |
if (it != inode.source->attrs.dict.end()) { | |
const std::string& device_group = it->second; | |
auto dit = device_assign_map.find(device_group); | |
CHECK(dit != device_assign_map.end()) | |
<< "The device assignment not found for group " << device_group; | |
device[nid] = dit->second; | |
} else { | |
if (!inode.source->is_variable() && | |
is_backward.get(inode.source->op(), false)) { | |
if (device[inode.control_deps[0]] != -1) { | |
device[nid] = device[inode.control_deps[0]]; | |
} | |
} else { | |
for (const IndexedGraph::NodeEntry& e : inode.inputs) { | |
if (device[e.node_id] != -1) { | |
device[nid] = device[e.node_id]; break; | |
} | |
} | |
} | |
} | |
} | |
// backward pass | |
for (uint32_t i = idx.num_nodes(); i != 0; --i) { | |
uint32_t nid = i - 1; | |
const auto& inode = idx[nid]; | |
if (device[nid] == -1) continue; | |
for (const IndexedGraph::NodeEntry& e : inode.inputs) { | |
if (device[e.node_id] == -1) device[e.node_id] = device[nid]; | |
} | |
} | |
int num_dev = 1, other_dev_id = -1; | |
for (int& dev : device) { | |
if (dev == -1) dev = 0; | |
if (dev != other_dev_id) { | |
if (other_dev_id != -1) ++num_dev; | |
other_dev_id = dev; | |
} | |
} | |
if (num_dev == 1) { | |
src.attrs.erase("device_group_attr_key"); | |
src.attrs.erase("device_assign_map"); | |
src.attrs.erase("device_copy_op"); | |
src.attrs["device"] = std::make_shared<any>(std::move(device)); | |
return src; | |
} | |
std::map<std::tuple<uint32_t, uint32_t, int>, NodePtr> copy_map; | |
std::vector<NodePtr> new_node_map(idx.num_nodes(), nullptr); | |
std::unordered_map<const Node*, int> new_device_map; | |
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs"); | |
// insert copy node | |
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { | |
int dev_id = device[nid]; | |
const auto& inode = idx[nid]; | |
// check if mutation is needed | |
bool need_mutate = false; | |
if (!inode.source->is_variable() && fmutate_inputs.count(inode.source->op())) { | |
for (uint32_t index : fmutate_inputs[inode.source->op()](inode.source->attrs)) { | |
auto e = inode.inputs[index]; | |
if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) { | |
LOG(FATAL) << " mutable state cannot go across device" | |
<< " op=" << inode.source->op()->name | |
<< " input_state_index=" << index; | |
} | |
} | |
} | |
for (const IndexedGraph::NodeEntry& e : inode.inputs) { | |
if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) { | |
need_mutate = true; break; | |
} | |
} | |
if (!need_mutate) { | |
for (const uint32_t cid : inode.control_deps) { | |
if (new_node_map[cid] != nullptr) { | |
need_mutate = true; break; | |
} | |
} | |
} | |
if (inode.source->is_variable()) { | |
CHECK(!need_mutate) << "consistency check"; | |
} | |
if (need_mutate) { | |
NodePtr new_node = Node::Create(); | |
new_node->attrs = inode.source->attrs; | |
new_node->inputs.reserve(inode.inputs.size()); | |
for (size_t i = 0; i < inode.inputs.size(); ++i) { | |
const IndexedGraph::NodeEntry& e = inode.inputs[i]; | |
if (dev_id != device[e.node_id]) { | |
auto copy_key = std::make_tuple(e.node_id, e.index, dev_id); | |
auto it = copy_map.find(copy_key); | |
if (it != copy_map.end() && it->first == copy_key) { | |
new_node->inputs.emplace_back( | |
NodeEntry{it->second, 0, 0}); | |
} else { | |
NodePtr copy_node = Node::Create(); | |
std::ostringstream os; | |
os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy"; | |
copy_node->attrs.op = copy_op; | |
copy_node->attrs.name = os.str(); | |
if (new_node_map[e.node_id] != nullptr) { | |
copy_node->inputs.emplace_back( | |
NodeEntry{new_node_map[e.node_id], e.index, 0}); | |
} else { | |
copy_node->inputs.push_back(inode.source->inputs[i]); | |
} | |
if (copy_node->attrs.op->attr_parser != nullptr) { | |
copy_node->attrs.op->attr_parser(&(copy_node->attrs)); | |
} | |
copy_map[copy_key] = copy_node; | |
new_device_map[copy_node.get()] = dev_id; | |
new_node->inputs.emplace_back( | |
NodeEntry{std::move(copy_node), 0, 0}); | |
} | |
} else { | |
if (new_node_map[e.node_id] != nullptr) { | |
new_node->inputs.emplace_back( | |
NodeEntry{new_node_map[e.node_id], e.index, 0}); | |
} else { | |
new_node->inputs.push_back(inode.source->inputs[i]); | |
} | |
} | |
} | |
new_node->control_deps.reserve(inode.control_deps.size()); | |
for (size_t i = 0; i < inode.control_deps.size(); ++i) { | |
uint32_t cid = inode.control_deps[i]; | |
if (new_node_map[cid] != nullptr) { | |
new_node->control_deps.push_back(new_node_map[cid]); | |
} else { | |
new_node->control_deps.push_back(inode.source->control_deps[i]); | |
} | |
} | |
new_device_map[new_node.get()] = dev_id; | |
new_node_map[nid] = std::move(new_node); | |
} else { | |
new_device_map[inode.source] = dev_id; | |
} | |
} | |
// make the new graph | |
Graph ret; | |
for (const NodeEntry& e : src.outputs) { | |
if (new_node_map[idx.node_id(e.node.get())] != nullptr) { | |
ret.outputs.emplace_back( | |
NodeEntry{new_node_map[idx.node_id(e.node.get())], e.index, e.version}); | |
} else { | |
ret.outputs.emplace_back(e); | |
} | |
} | |
DeviceVector new_device_vec(ret.indexed_graph().num_nodes()); | |
for (uint32_t nid = 0; nid < ret.indexed_graph().num_nodes(); ++nid) { | |
auto source = ret.indexed_graph()[nid].source; | |
if (new_device_map.count(source) == 0) { | |
LOG(FATAL) << "canot find " << source; | |
} | |
new_device_vec[nid] = new_device_map.at(source); | |
} | |
ret.attrs["device"] = std::make_shared<any>(std::move(new_device_vec)); | |
return ret; | |
} | |
NNVM_REGISTER_PASS(PlaceDevice) | |
.describe("Infer the device type of each operator."\ | |
"Insert a copy node when there is cross device copy") | |
.set_body(PlaceDevice) | |
.set_change_graph(true) | |
.provide_graph_attr("device") | |
.depend_graph_attr("device_group_attr_key") | |
.depend_graph_attr("device_assign_map") | |
.depend_graph_attr("device_copy_op"); | |
DMLC_JSON_ENABLE_ANY(DeviceAssignMap, dict_str_int); | |
} // namespace | |
} // namespace pass | |
} // namespace nnvm | |
//===== EXPANDED: ../nnvm/src/pass/place_device.cc ===== | |
//===== EXPANDING: ../nnvm/src/pass/saveload_json.cc ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file saveload_json.cc | |
* \brief Save and load graph to/from JSON file. | |
*/ | |
//===== EXPANDING: ../nnvm/include/nnvm/pass_functions.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file pass_functions.h | |
* \brief Pass functions that simply redirect the calls to ApplyPass | |
* | |
* This file serves as documentation on how to use functions implemented in "src/pass". | |
* It is totally optional to add these functions when you add a new pass, since | |
* ApplyPass can be directly called. | |
*/ | |
#ifndef NNVM_PASS_FUNCTIONS_H_ | |
#define NNVM_PASS_FUNCTIONS_H_ | |
namespace nnvm { | |
namespace pass { | |
/*! | |
* \brief Load a graph from JSON string, redirects to "LoadJSON" pass. | |
* \param json_str The json string. | |
* \return Loaded graph. | |
*/ | |
inline Graph LoadJSON(const std::string& json_str) { | |
Graph ret; | |
ret.attrs["json"] = std::make_shared<any>(json_str); | |
return ApplyPass(ret, "LoadJSON"); | |
} | |
/*! | |
* \brief Save a graph to json, redirects to "SaveJSON" pass. | |
* \param graph The graph to be saved as json format. | |
* \return The json string. | |
*/ | |
inline std::string SaveJSON(Graph graph) { | |
Graph ret = ApplyPass(std::move(graph), "SaveJSON"); | |
return ret.GetAttr<std::string>("json"); | |
} | |
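// Example (illustrative sketch): a save/load round trip through the | |
// LoadJSON/SaveJSON passes. `net` stands for an already composed Symbol. | |
// | |
//   Graph g; | |
//   g.outputs = net.outputs; | |
//   std::string json = SaveJSON(g);    // redirects to the "SaveJSON" pass | |
//   Graph restored = LoadJSON(json);   // rebuilds the nodes from the string | |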
/*! | |
* \brief Add control flow dependencies between nodes. | |
* | |
* This function will enforce the correct order between | |
* write (mutable operators) and read (immutable operators) | |
 * to solve write-after-read and read-after-write problems. | |
* | |
* \param src The input graph. | |
* \return A graph with proper control flow dependencies added. | |
*/ | |
inline Graph OrderMutation(Graph src) { | |
return ApplyPass(std::move(src), "OrderMutation"); | |
} | |
/*! | |
* \brief Infer shapes in the graph given the information. | |
* \param graph The input graph. | |
* \param shape_inputs The shapes of input symbols to the graph. | |
* \param shape_attr_key The key to the node attribute that can indicate shape. This is | |
 *                        the place where manual hints for shapes can be injected. | |
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry. | |
* The index of ShapeVector is given by graph.indexed_graph().entry_id. | |
*/ | |
inline Graph InferShape(Graph graph, | |
ShapeVector shape_inputs, | |
std::string shape_attr_key = "") { | |
if (shape_inputs.size() != 0) { | |
graph.attrs["shape_inputs"] = std::make_shared<any>(std::move(shape_inputs)); | |
} | |
if (shape_attr_key.length() != 0) { | |
graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key)); | |
} | |
return ApplyPass(std::move(graph), "InferShape"); | |
} | |
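// Example (illustrative sketch): providing input shapes positionally, in the | |
// order of graph.indexed_graph().input_nodes(). Assumes `g` has two inputs. | |
// | |
//   ShapeVector in_shapes = {TShape{2, 3}, TShape{3, 4}}; | |
//   g = InferShape(std::move(g), in_shapes); | |
//   const auto& shapes = g.GetAttr<ShapeVector>("shape"); | |
//   size_t unknown = g.GetAttr<size_t>("shape_num_unknown_nodes"); | |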
/*! | |
* \brief Infer types in the graph given the information. | |
* \param graph The input graph. | |
* \param dtype_inputs The types of input symbols to the graph. | |
* \param dtype_attr_key The key to the node attribute that can indicate types. This is | |
 *                        the place where manual hints for types can be injected. | |
* \return A graph with new attribute "dtype" containing inferred type of each NodeEntry. | |
 *         The index of DTypeVector is given by graph.indexed_graph().entry_id. | |
*/ | |
inline Graph InferType(Graph graph, | |
DTypeVector dtype_inputs, | |
std::string dtype_attr_key = "") { | |
if (dtype_inputs.size() != 0) { | |
graph.attrs["dtype_inputs"] = std::make_shared<any>(std::move(dtype_inputs)); | |
} | |
if (dtype_attr_key.length() != 0) { | |
graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(dtype_attr_key)); | |
} | |
return ApplyPass(std::move(graph), "InferType"); | |
} | |
/*! | |
* \brief Place the devices for each operator in the graph. | |
* | |
* Current device placement is quite simple. Each operator is assigned to a "group" (stored | |
* in `device_group_attr_key` attribute). Each group is assigned to a device (stored in | |
* `device_assign_map` attribute). Operators will be placed to the device assigned to its | |
* group. Copy operators will be injected if cross device reference happens. | |
* | |
* \param graph The input graph. | |
* \param device_group_attr_key The attribute name for hints of device group. | |
* \param device_assign_map The assignment map of device. | |
* \param device_copy_op The name of copy op to be inserted when cross device copy happened. | |
* \return A graph with new attribute "device", cotaining device information of each node. | |
*/ | |
inline Graph PlaceDevice(Graph graph, | |
std::string device_group_attr_key, | |
DeviceAssignMap device_assign_map, | |
std::string device_copy_op) { | |
graph.attrs["device_group_attr_key"] = std::make_shared<any>(std::move(device_group_attr_key)); | |
graph.attrs["device_assign_map"] = std::make_shared<any>(std::move(device_assign_map)); | |
graph.attrs["device_copy_op"] = std::make_shared<any>(std::move(device_copy_op)); | |
return ApplyPass(std::move(graph), "PlaceDevice"); | |
} | |
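// Example (illustrative sketch): mapping operator groups to device ids. | |
// The group attribute key and the copy-op name below are placeholders; the | |
// copy operator must already be registered under that name. | |
// | |
//   DeviceAssignMap assign = {{"group0", 0}, {"group1", 1}}; | |
//   g = PlaceDevice(std::move(g), "__device_group__", assign, "device_copy_op"); | |
//   const auto& dev = g.GetAttr<DeviceVector>("device");  // device id per node | |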
/*! | |
 * \brief Get the gradient graph whose outputs are the gradients of ys with respect to xs. | |
* \param graph The input graph. | |
* \param ys The entries we want to take gradient from. | |
* \param xs The input to take gradient with respect to. | |
 * \param ys_out_grad The symbols for additional gradients to be propagated back to ys. | |
* \param aggregate_fun Aggregation function applied to aggregate the inputs. | |
* \param mirror_fun Optional mirror function to do mirror optimization and save memory. | |
 * \param attr_hint_fun Optional hint function that outputs a node like src, but whose attributes are the same as those of like. | |
* \return A new graph, whose outputs correspond to inputs of xs. | |
*/ | |
inline Graph Gradient( | |
Graph graph, | |
std::vector<NodeEntry> ys, | |
std::vector<NodeEntry> xs, | |
std::vector<NodeEntry> ys_out_grad, | |
std::function<NodeEntry(std::vector<NodeEntry>&& inputs)> aggregate_fun = nullptr, | |
std::function<int(const Node& node)> mirror_fun = nullptr, | |
std::function<NodeEntry(const NodeEntry& src, const NodeEntry &like)> | |
attr_hint_fun = nullptr) { | |
graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys)); | |
graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs)); | |
graph.attrs["grad_ys_out_grad"] = std::make_shared<any>(std::move(ys_out_grad)); | |
if (aggregate_fun != nullptr) { | |
graph.attrs["grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun); | |
} | |
if (mirror_fun != nullptr) { | |
graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun); | |
} | |
if (attr_hint_fun != nullptr) { | |
graph.attrs["attr_hint_fun"] = std::make_shared<any>(attr_hint_fun); | |
} | |
return ApplyPass(std::move(graph), "Gradient"); | |
} | |
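// Example (illustrative sketch): requesting a gradient graph. `loss_sym` is | |
// an assumed single-output Symbol, `w_node` and `one_node` assumed NodePtrs | |
// for a weight variable and a head-gradient entry. | |
// | |
//   std::vector<NodeEntry> ys = {loss_sym.outputs[0]}; | |
//   std::vector<NodeEntry> xs = {NodeEntry{w_node, 0, 0}}; | |
//   std::vector<NodeEntry> ys_grad = {NodeEntry{one_node, 0, 0}}; | |
//   Graph fwd; | |
//   fwd.outputs = ys; | |
//   Graph grad = Gradient(fwd, ys, xs, ys_grad); | |
//   // grad.outputs[i] is the gradient with respect to xs[i] | |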
} // namespace pass | |
} // namespace nnvm | |
#endif // NNVM_PASS_FUNCTIONS_H_ | |
//===== EXPANDED: ../nnvm/include/nnvm/pass_functions.h ===== | |
namespace dmlc { | |
namespace json { | |
// overload handler for shared ptr | |
template<> | |
struct Handler<std::shared_ptr<any> > { | |
inline static void Write(JSONWriter *writer, const std::shared_ptr<any> &data) { | |
writer->Write(*data); | |
} | |
inline static void Read(JSONReader *reader, std::shared_ptr<any> *data) { | |
any v; | |
reader->Read(&v); | |
*data = std::make_shared<any>(std::move(v)); | |
} | |
}; | |
} // namespace json | |
} // namespace dmlc | |
namespace nnvm { | |
namespace pass { | |
namespace { | |
// auxiliary node structure for serialization. | |
struct JSONNode { | |
// the node entry structure in serialized format | |
struct Entry { | |
uint32_t node_id; | |
uint32_t index; | |
uint32_t version; | |
void Save(dmlc::JSONWriter *writer) const { | |
writer->BeginArray(false); | |
writer->WriteArrayItem(node_id); | |
writer->WriteArrayItem(index); | |
writer->WriteArrayItem(version); | |
writer->EndArray(); | |
} | |
void Load(dmlc::JSONReader *reader) { | |
reader->BeginArray(); | |
CHECK(reader->NextArrayItem()) << "invalid json format"; | |
reader->Read(&node_id); | |
CHECK(reader->NextArrayItem()) << "invalid json format"; | |
reader->Read(&index); | |
if (reader->NextArrayItem()) { | |
reader->Read(&version); | |
CHECK(!reader->NextArrayItem()) << "invalid json format"; | |
} else { | |
version = 0; | |
} | |
} | |
}; | |
// pointer to the graph node | |
NodePtr node; | |
// inputs | |
std::vector<Entry> inputs; | |
// control flow dependencies | |
std::vector<uint32_t> control_deps; | |
// function to save JSON node. | |
void Save(dmlc::JSONWriter *writer) const { | |
writer->BeginObject(); | |
if (node->op() != nullptr) { | |
writer->WriteObjectKeyValue("op", node->op()->name); | |
} else { | |
std::string json_null = "null"; | |
writer->WriteObjectKeyValue("op", json_null); | |
} | |
writer->WriteObjectKeyValue("name", node->attrs.name); | |
if (node->attrs.dict.size() != 0) { | |
// write attributes in order; | |
std::map<std::string, std::string> dict( | |
node->attrs.dict.begin(), node->attrs.dict.end()); | |
writer->WriteObjectKeyValue("attr", dict); | |
} | |
writer->WriteObjectKeyValue("inputs", inputs); | |
if (control_deps.size() != 0) { | |
writer->WriteObjectKeyValue("control_deps", control_deps); | |
} | |
writer->EndObject(); | |
} | |
void Load(dmlc::JSONReader *reader) { | |
node = Node::Create(); | |
control_deps.clear(); | |
dmlc::JSONObjectReadHelper helper; | |
std::string op_type_str; | |
helper.DeclareField("op", &op_type_str); | |
helper.DeclareField("name", &(node->attrs.name)); | |
helper.DeclareField("inputs", &inputs); | |
helper.DeclareOptionalField("attr", &(node->attrs.dict)); | |
helper.DeclareOptionalField("control_deps", &control_deps); | |
  // backward-compatible code for legacy mxnet graphs. | |
int backward_source_id; | |
std::unordered_map<std::string, std::string> param; | |
helper.DeclareOptionalField("param", ¶m); | |
helper.DeclareOptionalField("backward_source_id", &backward_source_id); | |
helper.ReadAllFields(reader); | |
node->attrs.dict.insert(param.begin(), param.end()); | |
if (op_type_str != "null") { | |
try { | |
node->attrs.op = Op::Get(op_type_str); | |
} catch (const dmlc::Error &err) { | |
std::ostringstream os; | |
os << "Failed loading Op " << node->attrs.name | |
<< " of type " << op_type_str << ": " << err.what(); | |
throw dmlc::Error(os.str()); | |
} | |
} else { | |
node->attrs.op = nullptr; | |
} | |
} | |
}; | |
// graph structure to help read/save JSON. | |
struct JSONGraph { | |
std::vector<JSONNode> nodes; | |
std::vector<uint32_t> arg_nodes; | |
std::vector<uint32_t> node_row_ptr; | |
std::vector<JSONNode::Entry> heads; | |
std::unordered_map<std::string, std::shared_ptr<any> > attrs; | |
void Save(dmlc::JSONWriter *writer) const { | |
writer->BeginObject(); | |
writer->WriteObjectKeyValue("nodes", nodes); | |
writer->WriteObjectKeyValue("arg_nodes", arg_nodes); | |
writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr); | |
writer->WriteObjectKeyValue("heads", heads); | |
if (attrs.size() != 0) { | |
writer->WriteObjectKeyValue("attrs", attrs); | |
} | |
writer->EndObject(); | |
} | |
void Load(dmlc::JSONReader *reader) { | |
attrs.clear(); | |
dmlc::JSONObjectReadHelper helper; | |
helper.DeclareField("nodes", &nodes); | |
helper.DeclareField("arg_nodes", &arg_nodes); | |
helper.DeclareField("heads", &heads); | |
helper.DeclareOptionalField("node_row_ptr", &node_row_ptr); | |
helper.DeclareOptionalField("attrs", &attrs); | |
helper.ReadAllFields(reader); | |
} | |
}; | |
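// The serialized layout produced by Save and consumed by Load looks roughly | |
// like the following (illustrative example; the "exp" op name and the graph | |
// itself are made up, and the optional "attrs" section is omitted): | |
// | |
//   { | |
//     "nodes": [ | |
//       {"op": "null", "name": "x", "inputs": []}, | |
//       {"op": "exp",  "name": "exp0", "inputs": [[0, 0, 0]]} | |
//     ], | |
//     "arg_nodes": [0], | |
//     "node_row_ptr": [0, 1, 2], | |
//     "heads": [[1, 0, 0]] | |
//   } | |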
// Load a graph from JSON file. | |
Graph LoadJSON(Graph src) { | |
CHECK_NE(src.attrs.count("json"), 0U) | |
<< "Load JSON require json to be presented."; | |
const std::string &json_str = | |
nnvm::get<std::string>(*src.attrs.at("json")); | |
bool no_parse = false; | |
if (src.attrs.count("load_json_no_parse")) { | |
no_parse = nnvm::get<bool>(*src.attrs.at("load_json_no_parse")); | |
} | |
std::istringstream is(json_str); | |
dmlc::JSONReader reader(&is); | |
JSONGraph jgraph; | |
// load in json graph. | |
jgraph.Load(&reader); | |
  // connect the nodes | |
for (JSONNode &n : jgraph.nodes) { | |
n.node->inputs.reserve(n.inputs.size()); | |
for (const JSONNode::Entry &e : n.inputs) { | |
n.node->inputs.emplace_back( | |
NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); | |
} | |
n.node->control_deps.reserve(n.control_deps.size()); | |
for (uint32_t nid : n.control_deps) { | |
n.node->control_deps.push_back(jgraph.nodes[nid].node); | |
} | |
// rebuild attribute parser | |
if (!no_parse && n.node->op() != nullptr && | |
n.node->op()->attr_parser != nullptr) { | |
n.node->op()->attr_parser(&(n.node->attrs)); | |
} | |
} | |
  // consistency check | |
for (uint32_t nid : jgraph.arg_nodes) { | |
CHECK(jgraph.nodes[nid].node->is_variable()); | |
} | |
// return the graph | |
Graph ret; | |
ret.attrs = std::move(jgraph.attrs); | |
ret.outputs.reserve(jgraph.heads.size()); | |
for (const JSONNode::Entry &e : jgraph.heads) { | |
ret.outputs.emplace_back( | |
NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); | |
} | |
return ret; | |
} | |
// save a graph to json | |
Graph SaveJSON(Graph src) { | |
JSONGraph jgraph; | |
jgraph.attrs = src.attrs; | |
std::unordered_map<Node*, uint32_t> node2index; | |
jgraph.node_row_ptr.push_back(0); | |
DFSVisit(src.outputs, [&node2index, &jgraph](const NodePtr& n) { | |
uint32_t nid = static_cast<uint32_t>(jgraph.nodes.size()); | |
node2index[n.get()] = nid; | |
if (n->is_variable()) { | |
jgraph.arg_nodes.push_back(nid); | |
} | |
JSONNode jnode; | |
jnode.node = n; | |
jnode.inputs.reserve(n->inputs.size()); | |
for (const NodeEntry& e : n->inputs) { | |
jnode.inputs.emplace_back( | |
JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version}); | |
} | |
for (const NodePtr& c : n->control_deps) { | |
jnode.control_deps.push_back(node2index.at(c.get())); | |
} | |
jgraph.node_row_ptr.push_back( | |
jgraph.node_row_ptr.back() + n->num_outputs()); | |
jgraph.nodes.emplace_back(std::move(jnode)); | |
}); | |
for (const NodeEntry& e : src.outputs) { | |
jgraph.heads.push_back( | |
JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version}); | |
} | |
std::ostringstream os; | |
dmlc::JSONWriter writer(&os); | |
jgraph.Save(&writer); | |
Graph ret; | |
ret.attrs["json"] = std::make_shared<any>(os.str()); | |
return ret; | |
} | |
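// Note on node_row_ptr (illustrative): it is the running sum of num_outputs() | |
// over the nodes in the DFS order above. Three visited nodes with 1, 2 and 1 | |
// outputs give [0, 1, 3, 4], so consumers can recover the entry id of output | |
// `i` of node `nid` as node_row_ptr[nid] + i. | |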
// register pass | |
NNVM_REGISTER_PASS(LoadJSON) | |
.describe("Return a new Graph, loaded from src.attrs[\"json\"]") | |
.set_body(LoadJSON) | |
.set_change_graph(true) | |
.depend_graph_attr("json"); | |
NNVM_REGISTER_PASS(SaveJSON) | |
.describe("Return a new empty Graph. Save graph to ret.attrs[\"json\"]") | |
.set_body(SaveJSON) | |
.set_change_graph(true) | |
.provide_graph_attr("json"); | |
DMLC_JSON_ENABLE_ANY(std::string, str); | |
DMLC_JSON_ENABLE_ANY(std::vector<int>, list_int); | |
} // namespace | |
} // namespace pass | |
} // namespace nnvm | |
//===== EXPANDED: ../nnvm/src/pass/saveload_json.cc ===== | |
//===== EXPANDING: ../nnvm/src/c_api/c_api_error.cc ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file c_api_error.cc | |
* \brief C error handling | |
*/ | |
//===== EXPANDING: ../nnvm/src/c_api/c_api_common.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
 * \file c_api_common.h | |
* \brief Common fields of all C APIs | |
*/ | |
#ifndef NNVM_C_API_C_API_COMMON_H_ | |
#define NNVM_C_API_C_API_COMMON_H_ | |
//===== EXPANDING: ../nnvm/include/nnvm/c_api.h ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file c_api.h | |
* \brief C API of NNVM symbolic construction and pass. | |
* Enables construction and transformation of Graph | |
* in any other host languages. | |
*/ | |
#ifndef NNVM_C_API_H_ | |
#define NNVM_C_API_H_ | |
#ifdef __cplusplus | |
#define NNVM_EXTERN_C extern "C" | |
#else | |
#define NNVM_EXTERN_C | |
#endif | |
/*! \brief NNVM_DLL prefix for windows */ | |
#ifdef _WIN32 | |
#ifdef NNVM_EXPORTS | |
#define NNVM_DLL NNVM_EXTERN_C __declspec(dllexport) | |
#else | |
#define NNVM_DLL NNVM_EXTERN_C __declspec(dllimport) | |
#endif | |
#else | |
#define NNVM_DLL NNVM_EXTERN_C | |
#endif | |
/*! \brief manually define unsigned int */ | |
typedef unsigned int nn_uint; | |
/*! \brief handle to a function that takes param and creates symbol */ | |
typedef void *OpHandle; | |
/*! \brief handle to a symbol that can be bind as operator */ | |
typedef void *SymbolHandle; | |
/*! \brief handle to Graph */ | |
typedef void *GraphHandle; | |
/*! | |
* \brief Set the last error message needed by C API | |
* \param msg The error message to set. | |
*/ | |
NNVM_DLL void NNAPISetLastError(const char* msg); | |
/*! | |
* \brief return str message of the last error | |
 * all functions in this file return 0 on success | |
 * and -1 when an error occurs; | |
* NNGetLastError can be called to retrieve the error | |
* | |
 * this function is thread-safe and can be called from different threads | |
* \return error info | |
*/ | |
NNVM_DLL const char *NNGetLastError(void); | |
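/* | |
 * Typical caller-side pattern (illustrative sketch, not part of the API): | |
 * | |
 *   SymbolHandle sym; | |
 *   if (NNSymbolCreateVariable("x", &sym) != 0) { | |
 *     fprintf(stderr, "nnvm error: %s\n", NNGetLastError()); | |
 *   } | |
 */ | |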
/*! | |
 * \brief list all the available operator names, including alias entries. | |
* \param out_size the size of returned array | |
* \param out_array the output operator name array. | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNListAllOpNames(nn_uint *out_size, | |
const char*** out_array); | |
/*! | |
* \brief Get operator handle given name. | |
* \param op_name The name of the operator. | |
 * \param op_out The returned op handle. | |
*/ | |
NNVM_DLL int NNGetOpHandle(const char* op_name, | |
OpHandle* op_out); | |
/*! | |
* \brief list all the available operators. | |
 *  This won't include aliases; use NNListAllOpNames | |
 *  instead to get all alias names. | |
* | |
* \param out_size the size of returned array | |
* \param out_array the output AtomicSymbolCreator array | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNListUniqueOps(nn_uint *out_size, | |
OpHandle **out_array); | |
/*! | |
* \brief Get the detailed information about atomic symbol. | |
* \param op The operator handle. | |
* \param real_name The returned name of the creator. | |
* This name is not the alias name of the atomic symbol. | |
* \param description The returned description of the symbol. | |
* \param num_doc_args Number of arguments that contain documents. | |
* \param arg_names Name of the arguments of doc args | |
 * \param arg_type_infos Type information about the arguments. | |
* \param arg_descriptions Description information about the arguments. | |
* \param return_type Return type of the function, if any. | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNGetOpInfo(OpHandle op, | |
const char **real_name, | |
const char **description, | |
nn_uint *num_doc_args, | |
const char ***arg_names, | |
const char ***arg_type_infos, | |
const char ***arg_descriptions, | |
const char **return_type); | |
/*! | |
* \brief Create an AtomicSymbol functor. | |
* \param op The operator handle | |
* \param num_param the number of parameters | |
* \param keys the keys to the params | |
* \param vals the vals of the params | |
* \param out pointer to the created symbol handle | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op, | |
nn_uint num_param, | |
const char **keys, | |
const char **vals, | |
SymbolHandle *out); | |
/*! | |
* \brief Create a Variable Symbol. | |
* \param name name of the variable | |
* \param out pointer to the created symbol handle | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNSymbolCreateVariable(const char *name, SymbolHandle *out); | |
/*! | |
* \brief Create a Symbol by grouping list of symbols together | |
* \param num_symbols number of symbols to be grouped | |
* \param symbols array of symbol handles | |
* \param out pointer to the created symbol handle | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols, | |
SymbolHandle *symbols, | |
SymbolHandle *out); | |
/*! | |
* \brief Add src_dep to the handle as control dep. | |
* \param handle The symbol to add dependency edges on. | |
 * \param src_dep the source handle. | |
*/ | |
NNVM_DLL int NNAddControlDeps(SymbolHandle handle, | |
SymbolHandle src_dep); | |
/*! | |
* \brief Free the symbol handle. | |
* \param symbol the symbol | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNSymbolFree(SymbolHandle symbol); | |
/*! | |
* \brief Copy the symbol to another handle | |
* \param symbol the source symbol | |
* \param out used to hold the result of copy | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out); | |
/*! | |
* \brief Print the content of symbol, used for debug. | |
* \param symbol the symbol | |
* \param out_str pointer to hold the output string of the printing. | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char **out_str); | |
/*! | |
* \brief Get string attribute from symbol | |
* \param symbol the source symbol | |
* \param key The key of the symbol. | |
 * \param out The result attribute; can be NULL if the attribute does not exist. | |
* \param success Whether the result is contained in out. | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol, | |
const char* key, | |
const char** out, | |
int *success); | |
/*! | |
* \brief Set string attribute from symbol. | |
 * NOTE: Setting an attribute on a symbol can affect the semantics (mutable/immutable) of the symbolic graph. | |
* | |
 * Safe recommendation: use an immutable graph | |
* - Only allow set attributes during creation of new symbol as optional parameter | |
* | |
* Mutable graph (be careful about the semantics): | |
* - Allow set attr at any point. | |
 * - Mutating an attribute of a node shared by two graphs can confuse the user. | |
* | |
* \param symbol the source symbol | |
* \param num_param Number of parameters to set. | |
* \param keys The keys of the attribute | |
* \param values The value to be set | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNSymbolSetAttrs(SymbolHandle symbol, | |
nn_uint num_param, | |
const char** keys, | |
const char** values); | |
/*! | |
* \brief Get all attributes from symbol, including all descendents. | |
* \param symbol the source symbol | |
* \param recursive_option 0 for recursive, 1 for shallow. | |
* \param out_size The number of output attributes | |
* \param out 2*out_size strings representing key value pairs. | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol, | |
int recursive_option, | |
nn_uint *out_size, | |
const char*** out); | |
/*! | |
* \brief List inputs variables in the symbol. | |
* \param symbol the symbol | |
* \param option The option to list the inputs | |
* option=0 means list all arguments. | |
 *   option=1 means list arguments that are read only by the graph. | |
* option=2 means list arguments that are mutated by the graph. | |
* \param out_size output size | |
* \param out_sym_array the output array. | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol, | |
int option, | |
nn_uint *out_size, | |
SymbolHandle** out_sym_array); | |
/*! | |
* \brief List input names in the symbol. | |
* \param symbol the symbol | |
* \param option The option to list the inputs | |
* option=0 means list all arguments. | |
 *   option=1 means list arguments that are read only by the graph. | |
* option=2 means list arguments that are mutated by the graph. | |
* \param out_size output size | |
* \param out_str_array pointer to hold the output string array | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol, | |
int option, | |
nn_uint *out_size, | |
const char ***out_str_array); | |
/*! | |
 * \brief List output names in the symbol. | |
* \param symbol the symbol | |
* \param out_size output size | |
* \param out_str_array pointer to hold the output string array | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol, | |
nn_uint *out_size, | |
const char ***out_str_array); | |
/*! | |
* \brief Get a symbol that contains all the internals. | |
* \param symbol The symbol | |
* \param out The output symbol whose outputs are all the internals. | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol, | |
SymbolHandle *out); | |
/*! | |
 * \brief Get the index-th output of the symbol. | |
* \param symbol The symbol | |
* \param index the Index of the output. | |
* \param out The output symbol whose outputs are the index-th symbol. | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNSymbolGetOutput(SymbolHandle symbol, | |
nn_uint index, | |
SymbolHandle *out); | |
/*! | |
* \brief Compose the symbol on other symbols. | |
* | |
 * This function will change the sym handle. | |
 * To achieve function-apply behavior, copy the symbol first | |
 * before applying. | |
* | |
* \param sym the symbol to apply | |
* \param name the name of symbol | |
* \param num_args number of arguments | |
* \param keys the key of keyword args (optional) | |
* \param args arguments to sym | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNSymbolCompose(SymbolHandle sym, | |
const char* name, | |
nn_uint num_args, | |
const char** keys, | |
SymbolHandle* args); | |
// Graph IR API | |
/*! | |
* \brief create a graph handle from symbol | |
* \param symbol The symbol representing the graph. | |
* \param graph The graph handle created. | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph); | |
/*! | |
* \brief free the graph handle | |
* \param handle The handle to be freed. | |
*/ | |
NNVM_DLL int NNGraphFree(GraphHandle handle); | |
/*! | |
* \brief Get a new symbol from the graph. | |
* \param graph The graph handle. | |
* \param symbol The corresponding symbol | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol); | |
/*! | |
 * \brief Set an attribute in JSON format. | |
 * This feature allows passing graph attributes back and forth at reasonable speed. | |
* | |
* \param handle The graph handle. | |
* \param key The key to the attribute. | |
 * \param json_value The value needs to be in the format [type_name, value], | |
 *  where type_name is a type string registered on the C++ side via DMLC_JSON_ENABLE_ANY. | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle, | |
const char* key, | |
const char* json_value); | |
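/* | |
 * Illustrative call (the attribute value here is made up): the string must be | |
 * a JSON array [type_name, value]; "str" is registered on the C++ side via | |
 * DMLC_JSON_ENABLE_ANY(std::string, str). | |
 * | |
 *   NNGraphSetJSONAttr(handle, "shape_attr_key", "[\"str\", \"__shape__\"]"); | |
 */ | |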
/*! | |
 * \brief Get a serialized attribute from the graph. | |
 * This feature allows passing graph attributes back and forth at reasonable speed. | |
* | |
* \param handle The graph handle. | |
* \param key The key to the attribute. | |
 * \param json_out The result attribute; can be NULL if the attribute does not exist. | |
 *  The json_out is an array of [type_name, value], | |
 *  where type_name is a type string registered on the C++ side via DMLC_JSON_ENABLE_ANY. | |
* \param success Whether the result is contained in out. | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNGraphGetJSONAttr(SymbolHandle handle, | |
const char* key, | |
const char** json_out, | |
int *success); | |
/*! | |
 * \brief Set an attribute whose type is std::vector<NodeEntry> in C++. | |
 * This feature allows passing a list of symbolic variables for a gradient request. | |
* | |
 * \note This is a beta feature only used for test purposes. | |
* | |
* \param handle The graph handle. | |
* \param key The key to the attribute. | |
* \param list The symbol whose outputs represents the list of NodeEntry to be passed. | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle, | |
const char* key, | |
SymbolHandle list); | |
/*! | |
* \brief Apply passes on the src graph. | |
* \param src The source graph handle. | |
 * \param num_pass The number of passes to be applied. | |
 * \param pass_names The names of the passes. | |
* \param dst The result graph. | |
* \return 0 when success, -1 when failure happens | |
*/ | |
NNVM_DLL int NNGraphApplyPasses(GraphHandle src, | |
nn_uint num_pass, | |
const char** pass_names, | |
GraphHandle *dst); | |
#endif // NNVM_C_API_H_ | |
//===== EXPANDED: ../nnvm/include/nnvm/c_api.h ===== | |
/*! \brief macro to guard the beginning and end sections of all functions */ | |
#define API_BEGIN() try { | |
/*! \brief every function starts with API_BEGIN(); | |
and finishes with API_END() or API_END_HANDLE_ERROR */ | |
#define API_END() } catch(dmlc::Error &_except_) { return NNAPIHandleException(_except_); } return 0; // NOLINT(*) | |
/*! | |
* \brief every function starts with API_BEGIN(); | |
* and finishes with API_END() or API_END_HANDLE_ERROR | |
 * The finally clause contains the procedure to clean up state when an error happens. | |
*/ | |
#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return NNAPIHandleException(_except_); } return 0; // NOLINT(*) | |
/*! \brief entry to easily hold returned information */ | |
struct NNAPIThreadLocalEntry { | |
/*! \brief result holder for returning string */ | |
std::string ret_str; | |
/*! \brief result holder for returning strings */ | |
std::vector<std::string> ret_vec_str; | |
/*! \brief result holder for returning string pointers */ | |
std::vector<const char *> ret_vec_charp; | |
/*! \brief result holder for returning handles */ | |
std::vector<void *> ret_handles; | |
/*! \brief argument holder to hold symbol */ | |
std::unordered_map<std::string, const nnvm::Symbol*> kwarg_symbol; | |
}; | |
/*! \brief Thread local store that can be used to hold return values. */ | |
typedef dmlc::ThreadLocalStore<NNAPIThreadLocalEntry> NNAPIThreadLocalStore; | |
/*! | |
 * \brief handle the exception thrown | |
* \param e the exception | |
* \return the return value of API after exception is handled | |
*/ | |
inline int NNAPIHandleException(const dmlc::Error &e) { | |
NNAPISetLastError(e.what()); | |
return -1; | |
} | |
#endif // NNVM_C_API_C_API_COMMON_H_ | |
//===== EXPANDED: ../nnvm/src/c_api/c_api_common.h ===== | |
struct ErrorEntry { | |
std::string last_error; | |
}; | |
typedef dmlc::ThreadLocalStore<ErrorEntry> NNAPIErrorStore; | |
const char *NNGetLastError() { | |
return NNAPIErrorStore::Get()->last_error.c_str(); | |
} | |
void NNAPISetLastError(const char* msg) { | |
NNAPIErrorStore::Get()->last_error = msg; | |
} | |
//===== EXPANDED: ../nnvm/src/c_api/c_api_error.cc ===== | |
//===== EXPANDING: ../nnvm/src/c_api/c_api_graph.cc ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file c_api_graph.cc | |
* \brief C API related to Graph IR. | |
*/ | |
using namespace nnvm; | |
int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph) { | |
Graph* g = new Graph(); | |
API_BEGIN(); | |
g->outputs = static_cast<Symbol*>(symbol)->outputs; | |
*graph = g; | |
API_END_HANDLE_ERROR(delete g); | |
} | |
int NNGraphFree(GraphHandle handle) { | |
API_BEGIN(); | |
delete static_cast<Graph*>(handle); | |
API_END(); | |
} | |
int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) { | |
Symbol* s = new Symbol(); | |
API_BEGIN(); | |
s->outputs = static_cast<Graph*>(graph)->outputs; | |
*symbol = s; | |
API_END_HANDLE_ERROR(delete s); | |
} | |
int NNGraphSetNodeEntryListAttr_(GraphHandle handle, | |
const char* key, | |
SymbolHandle list) { | |
API_BEGIN(); | |
Symbol* s = static_cast<Symbol*>(list); | |
Graph* g = static_cast<Graph*>(handle); | |
g->attrs[std::string(key)] | |
= std::make_shared<any>(s->outputs); | |
API_END(); | |
} | |
int NNGraphSetJSONAttr(GraphHandle handle, | |
const char* key, | |
const char* json_value) { | |
API_BEGIN(); | |
Graph* g = static_cast<Graph*>(handle); | |
std::string temp(json_value); | |
std::istringstream is(temp); | |
dmlc::JSONReader reader(&is); | |
nnvm::any value; | |
reader.Read(&value); | |
g->attrs[std::string(key)] = std::make_shared<any>(std::move(value)); | |
API_END(); | |
} | |
int NNGraphGetJSONAttr(GraphHandle handle, | |
const char* key, | |
const char** json_out, | |
int *success) { | |
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); | |
API_BEGIN(); | |
Graph* g = static_cast<Graph*>(handle); | |
std::string skey(key); | |
auto it = g->attrs.find(skey); | |
if (it != g->attrs.end()) { | |
std::ostringstream os; | |
dmlc::JSONWriter writer(&os); | |
writer.Write(*it->second.get()); | |
ret->ret_str = os.str(); | |
*json_out = (ret->ret_str).c_str(); | |
*success = 1; | |
} else { | |
*success = 0; | |
} | |
API_END(); | |
} | |
int NNGraphApplyPasses(GraphHandle src, | |
nn_uint num_pass, | |
const char** pass_names, | |
GraphHandle *dst) { | |
Graph* g = new Graph(); | |
API_BEGIN(); | |
std::vector<std::string> vpass; | |
for (nn_uint i = 0; i < num_pass; ++i) { | |
vpass.emplace_back(std::string(pass_names[i])); | |
} | |
*g = ApplyPasses(*static_cast<Graph*>(src), vpass); | |
*dst = g; | |
API_END_HANDLE_ERROR(delete g); | |
} | |
//===== EXPANDED: ../nnvm/src/c_api/c_api_graph.cc ===== | |
//===== EXPANDING: ../nnvm/src/c_api/c_api_symbolic.cc ===== | |
/*! | |
* Copyright (c) 2016 by Contributors | |
* \file c_api_symbolic.cc | |
 * \brief C API related to symbolic graph composition. | |
*/ | |
using namespace nnvm; | |
int NNListAllOpNames(nn_uint *out_size, | |
const char*** out_array) { | |
API_BEGIN(); | |
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); | |
ret->ret_vec_str = dmlc::Registry<Op>::ListAllNames(); | |
ret->ret_vec_charp.clear(); | |
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { | |
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); | |
} | |
*out_array = dmlc::BeginPtr(ret->ret_vec_charp); | |
*out_size = static_cast<nn_uint>(ret->ret_vec_str.size()); | |
API_END(); | |
} | |
int NNGetOpHandle(const char* op_name, | |
OpHandle* op_out) { | |
API_BEGIN(); | |
*op_out = (OpHandle)Op::Get(op_name); // NOLINT(*) | |
API_END(); | |
} | |
int NNListUniqueOps(nn_uint *out_size, | |
OpHandle **out_array) { | |
API_BEGIN(); | |
auto &vec = dmlc::Registry<Op>::List(); | |
*out_size = static_cast<nn_uint>(vec.size()); | |
*out_array = (OpHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*) | |
API_END(); | |
} | |
int NNAddControlDeps(SymbolHandle handle, | |
SymbolHandle src_dep) { | |
API_BEGIN(); | |
static_cast<Symbol*>(handle)->AddControlDeps( | |
*static_cast<Symbol*>(src_dep)); | |
API_END(); | |
} | |
int NNGetOpInfo(OpHandle handle, | |
const char **name, | |
const char **description, | |
nn_uint *num_doc_args, | |
const char ***arg_names, | |
const char ***arg_type_infos, | |
const char ***arg_descriptions, | |
const char **return_type) { | |
const Op *op = static_cast<const Op *>(handle); | |
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); | |
API_BEGIN(); | |
*name = op->name.c_str(); | |
*description = op->description.c_str(); | |
*num_doc_args = static_cast<nn_uint>(op->arguments.size()); | |
if (return_type) *return_type = nullptr; | |
ret->ret_vec_charp.clear(); | |
for (size_t i = 0; i < op->arguments.size(); ++i) { | |
ret->ret_vec_charp.push_back(op->arguments[i].name.c_str()); | |
} | |
for (size_t i = 0; i < op->arguments.size(); ++i) { | |
ret->ret_vec_charp.push_back(op->arguments[i].type_info_str.c_str()); | |
} | |
for (size_t i = 0; i < op->arguments.size(); ++i) { | |
ret->ret_vec_charp.push_back(op->arguments[i].description.c_str()); | |
} | |
*arg_names = dmlc::BeginPtr(ret->ret_vec_charp); | |
*arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + op->arguments.size(); | |
*arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (op->arguments.size() * 2); | |
API_END(); | |
} | |
int NNSymbolCreateAtomicSymbol(OpHandle creator, | |
nn_uint num_param, | |
const char **keys, | |
const char **vals, | |
SymbolHandle *out) { | |
Symbol *s = new Symbol(); | |
API_BEGIN(); | |
const Op* op = static_cast<const Op*>(creator); | |
std::unordered_map<std::string, std::string> kwargs; | |
for (nn_uint i = 0; i < num_param; ++i) { | |
kwargs.insert({std::string(keys[i]), std::string(vals[i])}); | |
} | |
*s = Symbol::CreateFunctor(op, std::move(kwargs)); | |
*out = s; | |
API_END_HANDLE_ERROR(delete s;); | |
} | |
int NNSymbolCreateVariable(const char *name, SymbolHandle *out) { | |
Symbol *s = new Symbol(); | |
API_BEGIN(); | |
*s = Symbol::CreateVariable(name); | |
*out = s; | |
API_END_HANDLE_ERROR(delete s); | |
} | |
int NNSymbolCreateGroup(nn_uint num_symbols, | |
SymbolHandle *symbols, | |
SymbolHandle *out) { | |
Symbol *s = new Symbol(); | |
Symbol **sym_arr = (Symbol**)symbols; // NOLINT(*) | |
API_BEGIN(); | |
std::vector<Symbol> syms; | |
for (nn_uint i = 0; i < num_symbols; ++i) { | |
syms.push_back(*sym_arr[i]); | |
} | |
*s = Symbol::CreateGroup(syms); | |
*out = s; | |
API_END_HANDLE_ERROR(delete s); | |
} | |
int NNSymbolGetOutput(SymbolHandle symbol, | |
nn_uint index, | |
SymbolHandle *out) { | |
Symbol *s = new Symbol(); | |
API_BEGIN(); | |
*s = (*static_cast<Symbol*>(symbol))[index]; | |
*out = s; | |
API_END_HANDLE_ERROR(delete s); | |
} | |
int NNSymbolGetInternals(SymbolHandle symbol, | |
SymbolHandle *out) { | |
Symbol *s = new Symbol(); | |
API_BEGIN(); | |
*s = static_cast<Symbol*>(symbol)->GetInternals(); | |
*out = s; | |
API_END_HANDLE_ERROR(delete s); | |
} | |
int NNSymbolFree(SymbolHandle symbol) { | |
API_BEGIN(); | |
delete static_cast<Symbol*>(symbol); | |
API_END(); | |
} | |
int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out) { | |
Symbol *s = new Symbol(); | |
API_BEGIN(); | |
*s = static_cast<const Symbol*>(symbol)->Copy(); | |
*out = s; | |
API_END_HANDLE_ERROR(delete s); | |
} | |
int NNSymbolPrint(SymbolHandle symbol, const char **out_str) { | |
Symbol *s = static_cast<Symbol*>(symbol); | |
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); | |
API_BEGIN(); | |
std::ostringstream os; | |
s->Print(os); | |
ret->ret_str = os.str(); | |
*out_str = (ret->ret_str).c_str(); | |
API_END(); | |
} | |
int NNSymbolGetAttr(SymbolHandle symbol, | |
const char* key, | |
const char** out, | |
int* success) { | |
Symbol *s = static_cast<Symbol*>(symbol); | |
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); | |
API_BEGIN(); | |
if (s->GetAttr(key, &(ret->ret_str))) { | |
*out = (ret->ret_str).c_str(); | |
*success = 1; | |
} else { | |
*out = nullptr; | |
*success = 0; | |
} | |
API_END(); | |
} | |
int NNSymbolSetAttrs(SymbolHandle symbol, | |
nn_uint num_param, | |
const char** keys, | |
const char** vals) { | |
Symbol *s = static_cast<Symbol*>(symbol); | |
API_BEGIN(); | |
std::vector<std::pair<std::string, std::string> > kwargs; | |
for (nn_uint i = 0; i < num_param; ++i) { | |
kwargs.emplace_back( | |
std::make_pair(std::string(keys[i]), std::string(vals[i]))); | |
} | |
s->SetAttrs(kwargs); | |
API_END(); | |
} | |
int NNSymbolListAttrs(SymbolHandle symbol, | |
int option, | |
nn_uint *out_size, | |
const char*** out) { | |
Symbol *s = static_cast<Symbol*>(symbol); | |
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); | |
API_BEGIN(); | |
std::unordered_map<std::string, std::string> attr = | |
s->ListAttrs(static_cast<Symbol::ListAttrOption>(option)); // NOLINT(*) | |
std::vector<std::string>& attr_list = ret->ret_vec_str; | |
attr_list.clear(); | |
for (const auto& kv : attr) { | |
attr_list.push_back(kv.first); | |
attr_list.push_back(kv.second); | |
} | |
*out_size = attr.size(); | |
ret->ret_vec_charp.clear(); | |
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { | |
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); | |
} | |
*out = dmlc::BeginPtr(ret->ret_vec_charp); | |
API_END(); | |
} | |
int NNSymbolListInputVariables(SymbolHandle symbol, | |
int option, | |
nn_uint *out_size, | |
SymbolHandle** out_sym_array) { | |
Symbol *s = static_cast<Symbol*>(symbol); | |
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); | |
API_BEGIN(); | |
std::vector<NodePtr> vs = s->ListInputs(Symbol::ListInputOption(option)); | |
ret->ret_handles.clear(); | |
for (size_t i = 0; i < vs.size(); ++i) { | |
nnvm::Symbol* rs = new nnvm::Symbol(); | |
rs->outputs.push_back(NodeEntry{vs[i], 0, 0}); | |
ret->ret_handles.push_back(rs); | |
} | |
*out_size = static_cast<nn_uint>(vs.size()); | |
*out_sym_array = dmlc::BeginPtr(ret->ret_handles); | |
API_END(); | |
} | |
int NNSymbolListInputNames(SymbolHandle symbol, | |
int option, | |
nn_uint *out_size, | |
const char ***out_str_array) { | |
Symbol *s = static_cast<Symbol*>(symbol); | |
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); | |
API_BEGIN(); | |
ret->ret_vec_str = | |
s->ListInputNames(Symbol::ListInputOption(option)); | |
ret->ret_vec_charp.clear(); | |
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { | |
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); | |
} | |
*out_size = static_cast<nn_uint>(ret->ret_vec_charp.size()); | |
*out_str_array = dmlc::BeginPtr(ret->ret_vec_charp); | |
API_END(); | |
} | |
int NNSymbolListOutputNames(SymbolHandle symbol, | |
nn_uint *out_size, | |
const char ***out_str_array) { | |
Symbol *s = static_cast<Symbol*>(symbol); | |
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); | |
API_BEGIN(); | |
ret->ret_vec_str = s->ListOutputNames(); | |
ret->ret_vec_charp.clear(); | |
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { | |
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); | |
} | |
*out_size = static_cast<nn_uint>(ret->ret_vec_charp.size()); | |
*out_str_array = dmlc::BeginPtr(ret->ret_vec_charp); | |
API_END(); | |
} | |
int NNSymbolCompose(SymbolHandle sym, | |
const char *name, | |
nn_uint num_args, | |
const char** keys, | |
SymbolHandle* args) { | |
API_BEGIN(); | |
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); | |
std::string& s_name = ret->ret_str; | |
std::unordered_map<std::string, const Symbol*>& kwargs | |
= ret->kwarg_symbol; | |
kwargs.clear(); | |
if (name != nullptr) { | |
s_name = name; | |
} else { | |
s_name.clear(); | |
} | |
Symbol* s = static_cast<Symbol*>(sym); | |
if (keys == nullptr && num_args != 0) { | |
kwargs.clear(); | |
array_view<const Symbol*> parg( | |
(Symbol**)args, (Symbol**)args + num_args); // NOLINT(*) | |
s->Compose(parg, kwargs, s_name); | |
} else { | |
for (nn_uint i = 0; i < num_args; ++i) { | |
kwargs[keys[i]] = (Symbol*)args[i]; // NOLINT(*) | |
} | |
s->Compose(array_view<const Symbol*>(), kwargs, s_name); | |
} | |
API_END(); | |
} | |
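// Illustrative sketch (not part of the original source): how a C client might
// build "y = op(x)" through this API. The operator handle would normally come
// from an op lookup (e.g. NNGetOpHandle); error codes are ignored here for brevity.
inline int ExampleComposeVariable(OpHandle op_handle) {
  SymbolHandle x = nullptr;
  SymbolHandle y = nullptr;
  NNSymbolCreateVariable("x", &x);                                // leaf variable
  NNSymbolCreateAtomicSymbol(op_handle, 0, nullptr, nullptr, &y); // bare operator symbol
  SymbolHandle args[] = {x};
  NNSymbolCompose(y, "y", 1, nullptr, args);                      // positional compose
  return 0;
}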
//===== EXPANDED: ../nnvm/src/c_api/c_api_symbolic.cc ===== | |
//===== EXPANDED: nnvm.cc ===== | |
//===== EXPANDING: mxnet_predict0.cc ===== | |
// mxnet_predict0.cc
#if defined(__ANDROID__) || defined(__MXNET_JS__) | |
#define MSHADOW_USE_SSE 0 | |
#endif | |
//===== EXPANDING: ../src/ndarray/ndarray_function.cc ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file ndarray_function_cpu.cc | |
* \brief CPU Implementation of ndarray function. | |
*/ | |
// this will be invoked by gcc and compile CPU version | |
//===== EXPANDING: ../src/ndarray/ndarray_function.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file ndarray_op.h | |
* \brief the real execution functions of ndarray operations | |
*/ | |
#ifndef MXNET_NDARRAY_NDARRAY_FUNCTION_H_ | |
#define MXNET_NDARRAY_NDARRAY_FUNCTION_H_ | |
//===== EXPANDING: ../include/mxnet/resource.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file resource.h | |
* \brief Global resource allocation handling. | |
*/ | |
#ifndef MXNET_RESOURCE_H_ | |
#define MXNET_RESOURCE_H_ | |
//===== EXPANDING: ../include/mxnet/engine.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file engine.h | |
* \brief Engine that schedules all the operations according to dependency. | |
*/ | |
#ifndef MXNET_ENGINE_H_ | |
#define MXNET_ENGINE_H_ | |
#if DMLC_USE_CXX11 | |
#endif | |
namespace mxnet { | |
// forward declare engine | |
class Engine; | |
/*! \brief namespace of engine internal types. */ | |
namespace engine { | |
/*! \brief Internal representation of variable. */ | |
struct Var; | |
/*! \brief Internal representation of operator. */ | |
struct Opr; | |
/*! \brief Variable pointer type, usually held by the user to specify dependencies. */
typedef Var* VarHandle; | |
/*! \brief Operator pointer type, usually held by the user. */
typedef Opr* OprHandle; | |
/*! | |
* \brief OnComplete Callback to the engine, | |
* called by AsyncFn when action completes | |
*/ | |
class CallbackOnComplete { | |
public: | |
// use implicit copy and assign | |
  /*! \brief invoke the callback */
inline void operator()() const { | |
(*callback_)(engine_, param_); | |
} | |
private: | |
/*! \brief engine can see content of callback */ | |
friend class ::mxnet::Engine; | |
/*! \brief the real callback */ | |
void (*callback_)(Engine *, void *); | |
/*! \brief the engine class passed to callback */ | |
Engine* engine_; | |
/*! \brief the parameter set on callback */ | |
void* param_; | |
}; | |
} // namespace engine | |
#if DMLC_USE_CXX11 | |
/*! \brief Function property, used to hint what action is pushed to engine. */ | |
enum class FnProperty { | |
/*! \brief Normal operation */ | |
kNormal, | |
/*! \brief Copy operation from GPU to other devices */ | |
kCopyFromGPU, | |
/*! \brief Copy operation from CPU to other devices */ | |
kCopyToGPU, | |
/*! \brief Prioritized sync operation on CPU */ | |
kCPUPrioritized, | |
/*! \brief Asynchronous function call */ | |
kAsync | |
}; // enum class FnProperty | |
/*! | |
* \brief Dependency engine that schedules operations. | |
*/ | |
class MXNET_API Engine { | |
public: | |
/*! \brief callback on complete*/ | |
typedef engine::CallbackOnComplete CallbackOnComplete; | |
/*! \brief Synchronous operation to pass to engine. */ | |
typedef std::function<void(RunContext)> SyncFn; | |
/*! \brief Asynchronous operation to pass to engine. */ | |
typedef std::function<void(RunContext, CallbackOnComplete)> AsyncFn; | |
/*! \brief Variable pointer */ | |
typedef engine::VarHandle VarHandle; | |
/*! \brief Operator pointer */ | |
typedef engine::OprHandle OprHandle; | |
/*! | |
   * \brief Notify the engine about a shutdown.
   * This can help the engine print fewer messages to the display.
   *
   * Users do not have to call this function.
*/ | |
virtual void NotifyShutdown() = 0; | |
/*! | |
   * \brief Allocate a new variable; the variable can then
* be used to schedule the operation concurrently via dependency | |
* patterns. | |
* \return The new variable allocated. | |
*/ | |
virtual VarHandle NewVariable() = 0; | |
/*! | |
* \brief Create a new operator. The returned operator could be saved | |
   * externally so that it could be reused for scheduling.
* \param fn The execution function. | |
* \param const_vars The variables that current operation will use but not | |
* mutate. | |
* \param mutable_vars The variables that current operation will mutate. | |
* \param prop Property of the function. | |
* \param opr_name The operator name. | |
* \return The new operator allocated. | |
*/ | |
virtual OprHandle NewOperator(AsyncFn fn, | |
std::vector<VarHandle> const& const_vars, | |
std::vector<VarHandle> const& mutable_vars, | |
FnProperty prop = FnProperty::kNormal, | |
const char* opr_name = nullptr) = 0; | |
/*! | |
* \brief Delete the given operator. | |
* \param op The operator to delete. | |
* | |
* The delete will not happen immediately, but will wait until all the | |
* operations using this operator are completed. | |
*/ | |
virtual void DeleteOperator(OprHandle op) = 0; | |
/*! | |
* \brief Push an operator to the engine. | |
* \param op The operator to push. | |
* \param exec_ctx Execution context. | |
* \param priority Priority of the action, as hint to the engine. | |
   * \param profiling The variable that indicates whether to profile this operator.
*/ | |
virtual void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) = 0; | |
/*! | |
* \brief Push an asynchronous operation to the engine. | |
* \param exec_fun Execution function, this function takes a parameter | |
* on_complete that must be called when the execution | |
* completes. | |
* \param exec_ctx Execution context. | |
* \param const_vars The variables that current operation will use but not | |
* mutate. | |
* \param mutable_vars The variables that current operation will mutate. | |
* \param prop Property of the function. | |
* \param priority Priority of the action, as hint to the engine. | |
* \param opr_name The operator name. | |
*/ | |
virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx, | |
std::vector<VarHandle> const& const_vars, | |
std::vector<VarHandle> const& mutable_vars, | |
FnProperty prop = FnProperty::kNormal, | |
int priority = 0, | |
const char* opr_name = nullptr) = 0; | |
/*! | |
* \brief Schedule the deletion of a variable. | |
* | |
* The delete will not happen immediately, but will wait until all the | |
* operations depending on var are completed. | |
* | |
* \param delete_fn A function that will be called after the variable is | |
* deleted. | |
* \param exec_ctx Execution context. | |
* \param var The variable to be deleted. | |
*/ | |
virtual void DeleteVariable(SyncFn delete_fn, | |
Context exec_ctx, | |
VarHandle var) = 0; | |
/*! | |
* \brief Wait for a variable. | |
* \param var The variable we should wait for. This function returns when the | |
* variable is ready. | |
*/ | |
virtual void WaitForVar(VarHandle var) = 0; | |
/*! | |
* \brief Wait until all the activity of engine finishes. | |
*/ | |
virtual void WaitForAll() = 0; | |
/*!\brief virtual destructor */ | |
virtual ~Engine() noexcept(false) {} | |
/*! | |
* \return Engine singleton. | |
*/ | |
static Engine* Get(); | |
/*! | |
* \brief Get shared pointer reference to engine singleton. | |
   * Most users should not call this function.
   * This function is called by another singleton X that requires
   * the engine to be destructed after X.
* | |
* \return A shared pointer to Engine singleton. | |
*/ | |
static std::shared_ptr<Engine> _GetSharedRef(); | |
/*! | |
   * \brief Push a synchronous operation to the engine.
* \param exec_fn Execution function that executes the operation. | |
* \param exec_ctx Execution context. | |
* \param const_vars The variables that current operation will use but not | |
* mutate. | |
* \param mutable_vars The variables that current operation will mutate. | |
* \param prop Property of the function. | |
* \param priority Priority of the action, as hint to the engine. | |
* \param opr_name The operator name. | |
* \tparam SyncFn the synchronous function to be pushed. | |
*/ | |
inline void PushSync(SyncFn exec_fn, Context exec_ctx, | |
std::vector<VarHandle> const& const_vars, | |
std::vector<VarHandle> const& mutable_vars, | |
FnProperty prop = FnProperty::kNormal, | |
int priority = 0, | |
const char* opr_name = nullptr) { | |
this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) { | |
exec_fn(ctx); | |
on_complete(); | |
}, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name); | |
} | |
/*! | |
* \brief factory function to create OnComplete callback. | |
   * \param callback the static callback function.
   * \param param the parameter passed to the callback.
*/ | |
inline CallbackOnComplete CreateCallback( | |
void (*callback)(Engine *, void *), void *param) { | |
CallbackOnComplete ret; | |
ret.callback_ = callback; | |
ret.engine_ = this; | |
ret.param_ = param; | |
return ret; | |
} | |
}; // class Engine | |
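// Illustrative sketch (not part of the original source): pushing a synchronous
// no-op through the engine singleton and waiting for it. The variable names
// used here are hypothetical.
inline void ExampleEnginePushSync() {
  Engine* engine = Engine::Get();
  Engine::VarHandle var = engine->NewVariable();
  engine->PushSync([](RunContext) { /* CPU work that writes the data guarded by var */ },
                   Context::CPU(), {}, {var});
  engine->WaitForVar(var);  // block until the pushed operation has run
  engine->DeleteVariable([](RunContext) {}, Context::CPU(), var);
}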
#endif // DMLC_USE_CXX11 | |
} // namespace mxnet | |
#endif // MXNET_ENGINE_H_ | |
//===== EXPANDED: ../include/mxnet/engine.h ===== | |
namespace mxnet { | |
/*! | |
* \brief The resources that can be requested by Operator | |
*/ | |
struct ResourceRequest { | |
/*! \brief Resource type, indicating what the pointer type is */ | |
enum Type { | |
/*! \brief mshadow::Random<xpu> object */ | |
kRandom, | |
/*! \brief A dynamic temp space that can be arbitrary size */ | |
kTempSpace | |
}; | |
/*! \brief type of resources */ | |
Type type; | |
/*! \brief default constructor */ | |
ResourceRequest() {} | |
/*! | |
* \brief constructor, allow implicit conversion | |
* \param type type of resources | |
*/ | |
ResourceRequest(Type type) // NOLINT(*) | |
: type(type) {} | |
}; | |
/*! | |
* \brief Resources used by mxnet operations. | |
 * A resource is something special other than NDArray,
 * but it still participates in dependency tracking.
*/ | |
struct Resource { | |
/*! \brief The original request */ | |
ResourceRequest req; | |
/*! \brief engine variable */ | |
engine::VarHandle var; | |
  /*! \brief identifier of the resource, used for debugging purposes */
int32_t id; | |
/*! | |
* \brief pointer to the resource, do not use directly, | |
* access using member functions | |
*/ | |
void *ptr_; | |
/*! \brief default constructor */ | |
Resource() : id(0) {} | |
/*! | |
* \brief Get random number generator. | |
* \param stream The stream to use in the random number generator. | |
* \return the mshadow random number generator requested. | |
* \tparam xpu the device type of random number generator. | |
*/ | |
template<typename xpu, typename DType> | |
inline mshadow::Random<xpu, DType>* get_random( | |
mshadow::Stream<xpu> *stream) const { | |
CHECK_EQ(req.type, ResourceRequest::kRandom); | |
mshadow::Random<xpu, DType> *ret = | |
static_cast<mshadow::Random<xpu, DType>*>(ptr_); | |
ret->set_stream(stream); | |
return ret; | |
} | |
/*! | |
* \brief Get space requested as mshadow Tensor. | |
* The caller can request arbitrary size. | |
* | |
* This space can be shared with other calls to this->get_space. | |
   * So the caller needs to serialize the calls when using the conflicting space.
   * The old space can be freed; however, this incurs a synchronization
   * when running on a device, so that the launched kernels that depend on the temp space
   * can finish correctly.
* | |
   * \param shape the Shape of the returned tensor.
   * \param stream the stream of the returned tensor.
   * \return the mshadow tensor requested.
   * \tparam xpu the device type of the requested tensor.
   * \tparam ndim the number of dimensions of the tensor requested.
*/ | |
template<typename xpu, int ndim> | |
inline mshadow::Tensor<xpu, ndim, real_t> get_space( | |
mshadow::Shape<ndim> shape, mshadow::Stream<xpu> *stream) const { | |
return get_space_typed<xpu, ndim, real_t>(shape, stream); | |
} | |
/*! | |
* \brief Get cpu space requested as mshadow Tensor. | |
* The caller can request arbitrary size. | |
* | |
   * \param shape the Shape of the returned tensor.
   * \return the mshadow tensor requested.
   * \tparam ndim the number of dimensions of the tensor requested.
*/ | |
template<int ndim> | |
inline mshadow::Tensor<cpu, ndim, real_t> get_host_space( | |
mshadow::Shape<ndim> shape) const { | |
return get_host_space_typed<cpu, ndim, real_t>(shape); | |
} | |
/*! | |
* \brief Get space requested as mshadow Tensor in specified type. | |
* The caller can request arbitrary size. | |
* | |
   * \param shape the Shape of the returned tensor.
   * \param stream the stream of the returned tensor.
   * \return the mshadow tensor requested.
   * \tparam xpu the device type of the requested tensor.
   * \tparam ndim the number of dimensions of the tensor requested.
*/ | |
template<typename xpu, int ndim, typename DType> | |
inline mshadow::Tensor<xpu, ndim, DType> get_space_typed( | |
mshadow::Shape<ndim> shape, mshadow::Stream<xpu> *stream) const { | |
CHECK_EQ(req.type, ResourceRequest::kTempSpace); | |
return mshadow::Tensor<xpu, ndim, DType>( | |
reinterpret_cast<DType*>(get_space_internal(shape.Size() * sizeof(DType))), | |
shape, shape[ndim - 1], stream); | |
} | |
/*! | |
* \brief Get CPU space as mshadow Tensor in specified type. | |
* The caller can request arbitrary size. | |
* | |
   * \param shape the Shape of the returned tensor
   * \return the mshadow tensor requested
   * \tparam ndim the number of dimensions of the tensor requested
   * \tparam DType the requested data type
*/ | |
template<int ndim, typename DType> | |
inline mshadow::Tensor<cpu, ndim, DType> get_host_space_typed( | |
mshadow::Shape<ndim> shape) const { | |
return mshadow::Tensor<cpu, ndim, DType>( | |
reinterpret_cast<DType*>(get_host_space_internal(shape.Size() * sizeof(DType))), | |
shape, shape[ndim - 1], NULL); | |
} | |
/*! | |
* \brief internal function to get space from resources. | |
* \param size The size of the space. | |
* \return The allocated space. | |
*/ | |
void* get_space_internal(size_t size) const; | |
/*! | |
* \brief internal function to get cpu space from resources. | |
* \param size The size of space. | |
* \return The allocated space | |
*/ | |
void *get_host_space_internal(size_t size) const; | |
}; | |
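// Illustrative sketch (not part of the original source): requesting a 16x32
// float temp-space tensor from a Resource obtained with a kTempSpace request.
// The shape and stream used here are hypothetical.
template<typename xpu>
inline mshadow::Tensor<xpu, 2, real_t> ExampleGetTempSpace(
    const Resource& res, mshadow::Stream<xpu>* stream) {
  // the returned space may be shared with other get_space calls on the same Resource
  return res.get_space_typed<xpu, 2, real_t>(mshadow::Shape2(16, 32), stream);
}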
/*! \brief Global resource manager */ | |
class ResourceManager { | |
public: | |
/*! | |
* \brief Get resource of requested type. | |
* \param ctx the context of the request. | |
* \param req the resource request. | |
* \return the requested resource. | |
* \note The returned resource's ownership is | |
* still hold by the manager singleton. | |
*/ | |
virtual Resource Request(Context ctx, const ResourceRequest &req) = 0; | |
/*! | |
* \brief Seed all the allocated random numbers. | |
* \param seed the seed to the random number generators on all devices. | |
*/ | |
virtual void SeedRandom(uint32_t seed) = 0; | |
/*! \brief virtual destructor */ | |
virtual ~ResourceManager() DMLC_THROW_EXCEPTION {} | |
/*! | |
* \return Resource manager singleton. | |
*/ | |
static ResourceManager *Get(); | |
}; | |
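// Illustrative sketch (not part of the original source): asking the global
// manager for a CPU temp-space resource; ownership stays with the singleton.
inline Resource ExampleRequestTempSpace() {
  return ResourceManager::Get()->Request(
      Context::CPU(), ResourceRequest(ResourceRequest::kTempSpace));
}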
} // namespace mxnet | |
#endif // MXNET_RESOURCE_H_ | |
//===== EXPANDED: ../include/mxnet/resource.h ===== | |
//===== EXPANDING: ../src/operator/mshadow_op.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file mshadow_op.h | |
* \brief | |
* \author Bing Xu | |
*/ | |
#ifndef MXNET_OPERATOR_MSHADOW_OP_H_ | |
#define MXNET_OPERATOR_MSHADOW_OP_H_ | |
//===== EXPANDING: ../src/operator/special_functions-inl.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file special_functions-inl.h | |
* \brief | |
* \author Valentin Flunkert | |
*/ | |
#ifndef MXNET_OPERATOR_SPECIAL_FUNCTIONS_INL_H_ | |
#define MXNET_OPERATOR_SPECIAL_FUNCTIONS_INL_H_ | |
namespace mxnet { | |
namespace op { | |
namespace special_functions { | |
template<typename DType> | |
struct helper_numeric_limits { | |
MSHADOW_XINLINE static DType max(); | |
}; | |
template<> | |
struct helper_numeric_limits<double> { | |
MSHADOW_XINLINE static double max() { | |
return DBL_MAX; | |
} | |
}; | |
template<> | |
struct helper_numeric_limits<float> { | |
MSHADOW_XINLINE static double max() { | |
return FLT_MAX; | |
} | |
}; | |
// This code is based on the Cephes Library available at http://www.netlib.org/cephes
// The original author, Stephen Moshier, has kindly given permission to use this code | |
// in mxnet. (See email below). | |
// | |
// Date: Tue, 13 Sep 2016 09:28:20 -0400 | |
// From: Stephen Moshier | |
// To: Flunkert, Valentin | |
// Subject: Re: cephes code in mxnet | |
// | |
// Hello Valentin, | |
// | |
// Thank you for writing. You are welcome to use and modify the Cephes code | |
// and distribute it under the Apache license. | |
// | |
// Good luck with your project, | |
// Steve Moshier | |
// | |
// Cephes Math Library Release 2.2: June, 1992 | |
// Copyright 1984, 1987, 1992 by Stephen L. Moshier | |
// Direct inquiries to 30 Frost Street, Cambridge, MA 02140 | |
// | |
struct cephes { | |
/* | |
* Helper to evaluate a polynomial given an array of coefficients. | |
*/ | |
template <typename DType> | |
MSHADOW_XINLINE static DType polevl(DType x, const DType coef[], int N) { | |
DType ans; | |
DType const *p; | |
int i; | |
p = coef; | |
ans = *p++; | |
i = N; | |
do { | |
ans = ans * x + *p++; | |
} while ( --i ); | |
return( ans ); | |
} | |
/* | |
* Helper function for psi that handles double/float specific differences | |
* in the algorithm. | |
*/ | |
template<typename DType> | |
MSHADOW_XINLINE static DType psi_helper(DType s); | |
/* | |
* | |
* Psi (digamma) function | |
* | |
* | |
* SYNOPSIS: | |
* | |
* float x, y, psif(); | |
* | |
* y = psif( x ); | |
* | |
* | |
* DESCRIPTION: | |
* | |
* d - | |
* psi(x) = -- ln | (x) | |
* dx | |
* | |
* is the logarithmic derivative of the gamma function. | |
* For integer x, | |
* n-1 | |
* - | |
* psi(n) = -EUL + > 1/k. | |
* - | |
* k=1 | |
* | |
* This formula is used for 0 < n <= 10. If x is negative, it | |
* is transformed to a positive argument by the reflection | |
* formula psi(1-x) = psi(x) + pi cot(pi x). | |
* For general positive x, the argument is made greater than 10 | |
* using the recurrence psi(x+1) = psi(x) + 1/x. | |
* Then the following asymptotic expansion is applied: | |
* | |
* inf. B | |
* - 2k | |
* psi(x) = log(x) - 1/2x - > ------- | |
* - 2k | |
* k=1 2k x | |
* | |
* where the B2k are Bernoulli numbers. | |
* | |
* ACCURACY: | |
* Absolute error, relative when |psi| > 1 : | |
* arithmetic domain # trials peak rms | |
* IEEE -33,0 30000 8.2e-7 1.2e-7 | |
* IEEE 0,33 100000 7.3e-7 7.7e-8 | |
* | |
* ERROR MESSAGES: | |
* message condition value returned | |
* psi singularity x integer <=0 MAXNUMF | |
*/ | |
template<typename DType> | |
MSHADOW_XINLINE static DType psi(DType x) { | |
DType p, q, nz, s, w, y; | |
int i, n, negative; | |
DType EUL(0.57721566490153286061); | |
DType PI(3.14159265358979323846); | |
negative = 0; | |
nz = 0.0; | |
if ( x <= 0.0 ) { | |
negative = 1; | |
q = x; | |
p = std::floor(q); | |
if ( p == q ) { | |
return helper_numeric_limits<double>::max(); | |
} | |
/* Remove the zeros of tan(PI x) | |
* by subtracting the nearest integer from x | |
*/ | |
nz = q - p; | |
if ( nz != 0.5 ) { | |
if ( nz > 0.5 ) { | |
p += 1.0; | |
nz = q - p; | |
} | |
nz = PI/std::tan(PI*nz); | |
} else { | |
nz = 0.0; | |
} | |
x = 1.0 - x; | |
} | |
/* check for positive integer up to 10 */ | |
if ( (x <= 10.0) && (x == std::floor(x)) ) { | |
y = 0.0; | |
n = x; | |
for ( i = 1; i < n; i++ ) { | |
w = i; | |
y += 1.0/w; | |
} | |
y -= EUL; | |
goto done; | |
} | |
s = x; | |
w = 0.0; | |
while ( s < 10.0 ) { | |
w += 1.0/s; | |
s += 1.0; | |
} | |
y = psi_helper(s); | |
y = logf(s) - (0.5/s) - y - w; | |
done: | |
if ( negative ) { | |
y -= nz; | |
} | |
return(y); | |
} | |
}; | |
template<> | |
MSHADOW_XINLINE double cephes::psi_helper<double>(double s) { | |
double z; | |
const double A[] = { | |
8.33333333333333333333E-2, | |
-2.10927960927960927961E-2, | |
7.57575757575757575758E-3, | |
-4.16666666666666666667E-3, | |
3.96825396825396825397E-3, | |
-8.33333333333333333333E-3, | |
8.33333333333333333333E-2 | |
}; | |
if ( s < 1.0e17 ) { | |
z = 1.0/(s * s); | |
return z * cephes::polevl<double>(z, A, 6); | |
} else { | |
return 0.0; | |
} | |
} | |
template<> | |
MSHADOW_XINLINE float cephes::psi_helper<float>(float s) { | |
float z; | |
const float A[] = { | |
-4.16666666666666666667E-3f, | |
3.96825396825396825397E-3f, | |
-8.33333333333333333333E-3f, | |
8.33333333333333333333E-2f | |
}; | |
if ( s < 1.0e8 ) { | |
z = 1.0/(s * s); | |
return z * cephes::polevl<float>(z, A, 3); | |
} else { | |
return 0.0; | |
} | |
} | |
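// Illustrative sketch (not part of the original source): evaluating the digamma
// function; psi(1) equals minus the Euler-Mascheroni constant, about -0.5772.
inline float ExampleDigamma() {
  return cephes::psi<float>(1.0f);
}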
} // namespace special_functions | |
} // namespace op | |
} // namespace mxnet | |
#endif // MXNET_OPERATOR_SPECIAL_FUNCTIONS_INL_H_ | |
//===== EXPANDED: ../src/operator/special_functions-inl.h ===== | |
namespace mxnet { | |
namespace op { | |
namespace mshadow_op { | |
#ifdef __CUDA_ARCH__ | |
__constant__ const float PI = 3.14159265358979323846; | |
#else | |
const float PI = 3.14159265358979323846; | |
using std::isnan; | |
#endif | |
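// Each functor below defines a static Map() that mshadow applies element-wise
// through expressions such as F<mshadow_op::sigmoid>(tensor); the matching
// *_grad functors compute derivatives, some of them expressed in terms of the
// forward output (e.g. sigmoid_grad and tanh_grad).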
/*! \brief identity Operation */ | |
struct identity { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(a); | |
} | |
}; | |
struct identity_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(DType(1.0f)); | |
} | |
}; | |
struct left { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return a; | |
} | |
}; | |
struct right { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return b; | |
} | |
}; | |
struct negation { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(-a); | |
} | |
}; | |
/*! \brief sigmoid unit */ | |
struct sigmoid { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(DType(1.0f) / (DType(1.0f) + expf(-a))); | |
} | |
}; | |
struct sigmoid_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(a * (DType(1.0f) - a)); | |
} | |
}; | |
/*! \brief Rectified Linear Operation */ | |
struct relu { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(a > DType(0.0f) ? a : DType(0.0f)); | |
} | |
}; | |
struct relu_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(a > DType(0.0f) ? DType(1.0f) : DType(0.0f)); | |
} | |
}; | |
/*! \brief Leaky ReLU Operation */ | |
struct xelu { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(a > DType(0.0f) ? a : a * b); | |
} | |
}; | |
struct xelu_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(a > DType(0.0f) ? DType(1.0f) : b); | |
} | |
}; | |
/*! \brief Exponential Linear Unit */ | |
struct elu { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType x, DType a) { | |
return DType(x > DType(0.0f) ? x : a * (expf(x) - DType(1.0f))); | |
} | |
}; | |
struct elu_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType x, DType a) { | |
return DType(x > DType(0.0f) ? DType(1.0f) : a + x); | |
} | |
}; | |
struct tanh { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(tanhf( a )); | |
} | |
}; | |
struct tanh_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(DType(1.0f) - a * a); | |
} | |
}; | |
/*! \brief SoftReLU, also known as softplus activation. */ | |
struct softrelu { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(log1pf(expf(a))); | |
} | |
}; | |
struct softrelu_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(DType(1.0f) - expf(-a)); | |
} | |
}; | |
struct exp { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(expf(a)); | |
} | |
}; | |
struct expm1 { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(expm1f(a)); | |
} | |
}; | |
struct log { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(logf(a)); | |
} | |
}; | |
struct log10 { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(log10f(a)); | |
} | |
}; | |
struct log2 { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(log2f(a)); | |
} | |
}; | |
struct log_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(DType(1.0f) / a); | |
} | |
}; | |
struct sin { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(sinf(a)); | |
} | |
}; | |
struct sin_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(cosf(a)); | |
} | |
}; | |
struct log1p { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(log1pf(a)); | |
} | |
}; | |
struct log1p_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(DType(1.0f) / (DType(1.0f) + a)); | |
} | |
}; | |
struct cos { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(cosf(a)); | |
} | |
}; | |
struct cos_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(-sinf(a)); | |
} | |
}; | |
struct tan { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(tanf(a)); | |
} | |
}; | |
struct tan_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(a * a + 1); | |
} | |
}; | |
struct arcsin { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(asinf(a)); | |
} | |
}; | |
struct arcsin_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(1.0 / (sqrtf(1 - a*a))); | |
} | |
}; | |
struct arccos { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(acosf(a)); | |
} | |
}; | |
struct arccos_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(-1.0 / (sqrtf(1 - a*a))); | |
} | |
}; | |
struct arctan { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(atanf(a)); | |
} | |
}; | |
struct arctan_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(1 / (a*a + 1)); | |
} | |
}; | |
struct hypot { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(sqrtf(a * a + b * b)); | |
} | |
}; | |
struct hypot_grad_left { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(a/sqrtf(a * a + b * b)); | |
} | |
}; | |
struct hypot_grad_right { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(b/sqrtf(a * a + b * b)); | |
} | |
}; | |
struct degrees { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(180. / PI * a); | |
} | |
}; | |
struct degrees_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(180. / PI); | |
} | |
}; | |
struct radians { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(PI /180. * a); | |
} | |
}; | |
struct radians_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(PI / 180.); | |
} | |
}; | |
struct sinh { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(sinhf(a)); | |
} | |
}; | |
struct sinh_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(coshf(a)); | |
} | |
}; | |
struct cosh { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(coshf(a)); | |
} | |
}; | |
struct cosh_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(sinhf(a)); | |
} | |
}; | |
struct arcsinh { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(asinhf(a)); | |
} | |
}; | |
struct arcsinh_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(1.0 / (sqrtf(1 + a*a))); | |
} | |
}; | |
struct arccosh { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(acoshf(a)); | |
} | |
}; | |
struct arccosh_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(1.0 / (sqrtf(a*a - 1.0))); | |
} | |
}; | |
struct arctanh { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(atanhf(a)); | |
} | |
}; | |
struct arctanh_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(-1.0 / (a*a - 1.0)); | |
} | |
}; | |
struct square { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(a * a); | |
} | |
}; | |
struct square_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(DType(2.0f) * a); | |
} | |
}; | |
/*! \brief used to generate a Bernoulli mask */
struct threshold { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(a < b ? DType(1.0f) : DType(0.0f)); | |
} | |
}; | |
/*! \brief used to generate element-wise abs */
struct abs { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(fabsf(float(a))); // NOLINT(*) | |
} | |
}; | |
/*! \brief used to generate element-wise sign */
struct sign { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
if (a < 0.0f) return DType(-DType(1.0f)); | |
if (a > 0.0f) return DType(DType(1.0f)); | |
return DType(DType(0.0f)); | |
} | |
}; | |
struct sign_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(DType(0.0f)); | |
} | |
}; | |
/*! \brief used to generate element-wise power */
struct power { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(powf( a, b )); | |
} | |
}; | |
struct power_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(powf( a, b - 1 )*b); | |
} | |
}; | |
struct power_rgrad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(powf( a, b )*logf(a)); | |
} | |
}; | |
struct rpower { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(powf( b, a )); | |
} | |
}; | |
struct rpower_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(a*logf(b)); | |
} | |
}; | |
/*! \brief used to generate element-wise maximum */
struct maximum { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return a > b ? a : b; | |
} | |
}; | |
/*! \brief used to generate element-wise minimum */
struct minimum { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return a < b ? a : b; | |
} | |
}; | |
struct ge { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return a >= b ? DType(1) : DType(0); | |
} | |
}; | |
struct gt { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return a > b ? DType(1) : DType(0); | |
} | |
}; | |
struct lt { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return a < b ? DType(1) : DType(0); | |
} | |
}; | |
struct le { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return a <= b ? DType(1) : DType(0); | |
} | |
}; | |
struct eq { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return a == b ? DType(1) : DType(0); | |
} | |
}; | |
struct ne { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return a != b ? DType(1) : DType(0); | |
} | |
}; | |
/*! \brief used to generate element-wise sqrt */
struct square_root { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(sqrtf(a)); | |
} | |
}; | |
struct square_root_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(DType(0.5f) / a); | |
} | |
}; | |
/*! \brief used to generate element-wise reciprocal sqrt */
struct reciprocal_square_root { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(DType(1.0f)/sqrtf(a)); | |
} | |
}; | |
struct reciprocal_square_root_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(-(DType(1.0f) / (DType(2.0f) * a * sqrtf(a)))); | |
} | |
}; | |
/*! \brief used to generate element-wise round */
struct round { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(roundf(a)); | |
} | |
}; | |
/*! \brief used to generate element-wise ceil */
struct ceil { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(ceilf(a)); | |
} | |
}; | |
/*! \brief used to generate element-wise floor */
struct floor { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
return DType(floorf(a)); | |
} | |
}; | |
/*! \brief used to round number to nearest integer */ | |
struct rint { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
    float floor = floorf(a);
    float ceil = ceilf(a);
    float af = static_cast<float>(a);
    // return whichever integer is closer to a
    return DType((af - floor) <= (ceil - af) ? floor : ceil);
} | |
}; | |
/*! \brief used to round number to integer nearest to 0 */ | |
struct fix { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
    float floor = floorf(a);
    float ceil = ceilf(a);
    // truncate toward zero: use ceil for negative inputs, floor otherwise
    return DType(a < DType(0.0f) ? ceil : floor);
} | |
}; | |
/*! \brief used to generate the gradient of MAE loss */
struct minus_sign { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(a-b > DType(0.0f) ? DType(1.0f) : -DType(1.0f)); | |
} | |
}; | |
struct rminus { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(b-a); | |
} | |
}; | |
struct div_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(DType(1)/b); | |
} | |
}; | |
struct div_rgrad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(-a/(b*b)); | |
} | |
}; | |
struct rdiv { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(b/a); | |
} | |
}; | |
struct rdiv_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return DType(-b/(a*a)); | |
} | |
}; | |
struct clip { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType x, DType bound) { | |
if (x > bound) { | |
return bound; | |
} else if (x < -bound) { | |
return -bound; | |
} else { | |
return x; | |
} | |
} | |
}; | |
/***** gamma ******/ | |
struct gamma { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
// default implementation using floating precision | |
return DType(tgammaf(a)); | |
} | |
}; | |
template<> | |
MSHADOW_XINLINE double gamma::Map<double>(double a) { | |
return tgamma(a); | |
} | |
struct gamma_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
// default implementation using floating precision | |
return DType(tgammaf(a) * special_functions::cephes::psi<float>(a)); | |
} | |
}; | |
template<> | |
MSHADOW_XINLINE double gamma_grad::Map<double>(double a) { | |
return tgamma(a) * special_functions::cephes::psi<double>(a); | |
} | |
/***** gammaln ******/ | |
struct gammaln { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
// default implementation using floating precision | |
return DType(lgammaf(a)); | |
} | |
}; | |
template<> | |
MSHADOW_XINLINE double gammaln::Map<double>(double a) { | |
return lgamma(a); | |
} | |
struct gammaln_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a) { | |
// default implementation using floating precision | |
return DType(special_functions::cephes::psi<float>(a)); | |
} | |
}; | |
template<> | |
MSHADOW_XINLINE double gammaln_grad::Map<double>(double a) { | |
return special_functions::cephes::psi<double>(a); | |
} | |
/* Smooth L1 loss is a loss used for training the R-CNN family of detectors.
 * Smooth L1 loss function:
 * f(x) = 0.5 * (sigma * x) ^ 2,    |x| < 1 / sigma^2
 *      = |x| - 0.5 / sigma^2,      otherwise
 * When sigma = 1, it is equivalent to the Huber loss evaluated at
 * delta = 1.
 * smooth_l1_loss = w_out * f(w_in * x)
 * with w_in, w_out provided by input_data.
*/ | |
struct smooth_l1_loss { | |
// a is x, b is sigma2 | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
b *= b; | |
if (a > 1.0f / b) { | |
return a - 0.5f / b; | |
} else if (a < -1.0f / b) { | |
return -a - 0.5f / b; | |
} else { | |
return 0.5f * a * a * b; | |
} | |
} | |
}; // struct smooth_l1_loss | |
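// Illustrative sketch (not part of the original source): with sigma = 1 the
// loss is the Huber loss with delta = 1, so f(2) = |2| - 0.5 = 1.5 while
// f(0.5) = 0.5 * 0.5^2 = 0.125.
inline float ExampleSmoothL1AtTwo() {
  return smooth_l1_loss::Map<float>(2.0f, 1.0f);  // returns 1.5f
}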
/* The derivative of the smooth L1 loss is
 * f'(x) = sigma^2 * x,    |x| < 1 / sigma^2
 *       = sign(x),        otherwise
*/ | |
struct smooth_l1_gradient { | |
// a is x, b is sigma2 | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
b *= b; | |
if (a > 1.0f / b) { | |
return 1.0f; | |
} else if (a < -1.0f / b) { | |
return DType(-1); | |
} else { | |
return b * a; | |
} | |
} | |
}; // struct smooth_l1_gradient
/*! \brief product reducer */ | |
struct product { | |
/*! \brief do reduction into dst */ | |
template<typename DType> | |
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*) | |
dst *= src; | |
} | |
/*! | |
*\brief calculate gradient of redres with respect to redsrc, | |
   * redres: the reduced result, redsrc: one of the reduction elements
*/ | |
template<typename DType> | |
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) { | |
return redres / redsrc; | |
} | |
/*! | |
*\brief set the initial value during reduction | |
*/ | |
template<typename DType> | |
MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) | |
initv = 1; | |
} | |
}; | |
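// Illustrative sketch (not part of the original source): driving the product
// reducer by hand the way a reduction kernel would.
inline float ExampleProductReduce() {
  float acc;
  product::SetInitValue(acc);   // acc = 1
  product::Reduce(acc, 2.0f);   // acc = 2
  product::Reduce(acc, 3.0f);   // acc = 6
  return acc;                   // 2 * 3 = 6
}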
namespace isnan_typed { | |
template<typename DType> | |
MSHADOW_XINLINE bool IsNan(volatile DType val) { | |
return false; | |
} | |
template<> | |
MSHADOW_XINLINE bool IsNan(volatile float val) { | |
return isnan(val); | |
} | |
template<> | |
MSHADOW_XINLINE bool IsNan(volatile double val) { | |
return isnan(val); | |
} | |
template<> | |
MSHADOW_XINLINE bool IsNan(volatile long double val) { | |
return isnan(val); | |
} | |
template<> | |
MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val) { | |
return (val.half_ & 0x7fff) > 0x7c00; | |
} | |
}; // namespace isnan_typed | |
/*! \brief sum reducer that ignores NaN values in the input */ | |
struct nansum { | |
/*! \brief do reduction into dst */ | |
template<typename DType> | |
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*) | |
if (isnan_typed::IsNan(dst)) { | |
if (isnan_typed::IsNan(src)) { | |
dst = DType(0); | |
} else { | |
dst = src; | |
} | |
} else { | |
if (isnan_typed::IsNan(src)) { | |
dst = dst; | |
} else { | |
dst += src; | |
} | |
} | |
} | |
/*! | |
*\brief set the initial value during reduction | |
*/ | |
template<typename DType> | |
MSHADOW_XINLINE static void SetInitValue(DType & initv) { // NOLINT(*) | |
initv = 0; | |
} | |
}; | |
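// Illustrative sketch (not part of the original source): nansum skips NaN
// inputs, so reducing {1, NaN, 2} yields 3.
inline float ExampleNanSum() {
  float acc;
  nansum::SetInitValue(acc);                                      // acc = 0
  nansum::Reduce(acc, 1.0f);                                      // acc = 1
  nansum::Reduce(acc, std::numeric_limits<float>::quiet_NaN());   // acc unchanged
  nansum::Reduce(acc, 2.0f);                                      // acc = 3
  return acc;
}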
struct nansum_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return isnan_typed::IsNan(a) ? DType(0) : DType(1); | |
} | |
}; | |
/*! \brief product reducer that ignores NaN values in the input */ | |
struct nanprod { | |
/*! \brief do reduction into dst */ | |
template<typename DType> | |
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*) | |
if (isnan_typed::IsNan(dst)) { | |
if (isnan_typed::IsNan(src)) { | |
dst = DType(1); | |
} else { | |
dst = src; | |
} | |
} else { | |
if (isnan_typed::IsNan(src)) { | |
dst = dst; | |
} else { | |
dst *= src; | |
} | |
} | |
} | |
/*! | |
*\brief set the initial value during reduction | |
*/ | |
template<typename DType> | |
MSHADOW_XINLINE static void SetInitValue(DType & initv) { // NOLINT(*) | |
initv = 1; | |
} | |
}; | |
struct nanprod_grad { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
return isnan_typed::IsNan(a) ? DType(0) : b / a; | |
} | |
}; | |
} // namespace mshadow_op | |
} // namespace op | |
} // namespace mxnet | |
#endif // MXNET_OPERATOR_MSHADOW_OP_H_ | |
//===== EXPANDED: ../src/operator/mshadow_op.h ===== | |
namespace mxnet { | |
/*! \brief namespace to support all possible Ndarray operator */ | |
namespace ndarray { | |
struct BinaryBase { | |
inline static TShape GetShape(const TShape &lshape, const TShape &rshape) { | |
CHECK(lshape == rshape) << "operands shape mismatch"; | |
CHECK(lshape.ndim() != 0) << "source operand have zero dimension shape"; | |
return lshape; | |
} | |
}; | |
// operators | |
struct Plus : public BinaryBase { | |
typedef mshadow::op::plus mshadow_op; | |
}; | |
struct Minus : public BinaryBase { | |
typedef mshadow::op::minus mshadow_op; | |
}; | |
struct Mul : public BinaryBase { | |
typedef mshadow::op::mul mshadow_op; | |
}; | |
struct Div : public BinaryBase { | |
typedef mshadow::op::div mshadow_op; | |
}; | |
struct ClipMin : public BinaryBase { | |
struct mshadow_op { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
if (a < b) { | |
return b; | |
} else { | |
return a; | |
} | |
} | |
}; | |
}; | |
struct ClipMax : public BinaryBase { | |
struct mshadow_op { | |
template<typename DType> | |
MSHADOW_XINLINE static DType Map(DType a, DType b) { | |
if (a > b) { | |
return b; | |
} else { | |
return a; | |
} | |
} | |
}; | |
}; | |
struct OneHotEncode { | |
inline static TShape GetShape(const TShape &index, const TShape &proptype) { | |
CHECK(index.ndim() == 1 && proptype.ndim() == 2) << "OneHotEncode only support 1d index."; | |
CHECK_EQ(index[0], proptype[0]) << "OneHotEncode shape inconsistent"; | |
return proptype; | |
} | |
}; | |
struct MatChooseRowElem { | |
inline static TShape GetShape(const TShape &lshape, const TShape &rshape) { | |
CHECK(lshape.ndim() == 2 && rshape.ndim() == 1) | |
<< "choose_row_element only support 2D Matrix and 1D index"; | |
CHECK_EQ(lshape[0], rshape[0]) << "choose_row_element index and matrix shape mismatch"; | |
return rshape; | |
} | |
}; | |
struct MatFillRowElem { | |
inline static TShape GetShape(const TShape &lshape, const TShape &mshape, const TShape &rshape) { | |
CHECK(lshape.ndim() == 2 && mshape.ndim() == 1 && rshape.ndim() == 1) | |
<< "fill_row_element only support 2D Matrix, 1D value and 1D index"; | |
CHECK((lshape[0] == mshape[0]) && (mshape[0] == rshape[0])) | |
<< "choose_row_element index vector, value vector and matrix shape mismatch"; | |
return lshape; | |
} | |
}; | |
// type holder for random number generators | |
struct UniformDistribution {}; | |
struct GaussianDistribution {}; | |
template<typename Device> | |
void EvalClip(const TBlob &src, const real_t &a_min, const real_t &a_max, | |
TBlob *ret, RunContext ctx); | |
template<typename Device, typename OP> | |
void Eval(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs, TBlob *ret, RunContext ctx); | |
template<typename Device, typename OP> | |
void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx); | |
template<typename Device, typename OP> | |
void Eval(const TBlob &src, TBlob *ret, RunContext ctx); | |
template<typename Device, typename OP, bool reverse> | |
void Eval(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx); | |
template<typename Device> | |
void Eval(const real_t &rhs, TBlob *ret, RunContext ctx); | |
template<typename Device, typename Distribution> | |
void EvalRandom(const real_t &a, | |
const real_t &b, | |
const Resource &resource, | |
TBlob *ret, RunContext ctx); | |
// copy function when only cpu is involved | |
template<typename DeviceFrom, typename DeviceTo> | |
void Copy(const TBlob &from, TBlob *to, | |
Context from_ctx, Context to_ctx, | |
RunContext ctx); | |
template<typename Device> | |
void ElementwiseSum(const std::vector<TBlob> source, | |
TBlob *out, | |
RunContext ctx); | |
// broadcasting | |
template <typename Device> | |
void EvalBroadcast(TBlob const& src, TBlob* ret, int size, RunContext ctx); | |
} // namespace ndarray | |
} // namespace mxnet | |
#endif // MXNET_NDARRAY_NDARRAY_FUNCTION_H_ | |
//===== EXPANDED: ../src/ndarray/ndarray_function.h ===== | |
//===== EXPANDING: ../src/ndarray/ndarray_function-inl.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file ndarray_function-inl.h | |
* \brief The real implementation of NDArray functions. | |
*/ | |
#ifndef MXNET_NDARRAY_NDARRAY_FUNCTION_INL_H_ | |
#define MXNET_NDARRAY_NDARRAY_FUNCTION_INL_H_ | |
// this file will be included twice by CPU and GPU | |
// macro to help specialize evaluation function | |
#ifndef DECL_TERNARY | |
#define DECL_TERNARY(XPU, OP, FUN) \ | |
template<> \ | |
void Eval<XPU, OP>(const TBlob &lhs, const TBlob &mhs, \ | |
const TBlob &rhs, TBlob *ret, RunContext ctx) { \ | |
FUN<XPU, OP>(lhs, mhs, rhs, ret, ctx); \ | |
} | |
#endif | |
#ifndef DECL_BINARY | |
#define DECL_BINARY(XPU, OP, FUN) \ | |
template<> \ | |
void Eval<XPU, OP>(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx) { \ | |
FUN<XPU, OP>(lhs, rhs, ret, ctx); \ | |
} | |
#endif | |
#ifndef DECL_SCALAR | |
#define DECL_SCALAR(XPU, OP, FUN, REVERSE) \ | |
template<> \ | |
void Eval<XPU, OP, REVERSE>(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx) { \ | |
FUN<XPU, OP, REVERSE>(lhs, rhs, ret, ctx); \ | |
} | |
#endif | |
#if defined(__CUDACC__) | |
#define DEVICE gpu | |
#else | |
#define DEVICE cpu | |
#endif | |
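// For example, with the cpu DEVICE above, DECL_BINARY(DEVICE, Plus, EvalBinary_)
// below expands to the explicit specialization
//   template<> void Eval<cpu, Plus>(const TBlob&, const TBlob&, TBlob*, RunContext)
// which simply forwards to EvalBinary_<cpu, Plus>; compiling this file with
// nvcc (__CUDACC__) generates the matching gpu specializations instead.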
namespace mxnet { | |
namespace ndarray { | |
// true implementation | |
template<typename xpu, typename OP> | |
inline void EvalBinary_(const TBlob &lhs, const TBlob &rhs, | |
TBlob *ret, RunContext ctx) { | |
using namespace mshadow::expr; | |
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); | |
CHECK_EQ(ret->type_flag_, lhs.type_flag_) | |
<< "Only support input/output with the same data type"; | |
CHECK_EQ(ret->type_flag_, rhs.type_flag_) | |
<< "Only support input/output with the same data type"; | |
MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { | |
ret->FlatTo2D<xpu, DType>(s) | |
= F<typename OP::mshadow_op>(lhs.FlatTo2D<xpu, DType>(s), | |
rhs.FlatTo2D<xpu, DType>(s)); | |
}); | |
} | |
template<typename xpu, typename OP> | |
inline void EvalOneHot_(const TBlob &index, const TBlob &rhs, | |
TBlob *ret, RunContext ctx) { | |
using namespace mshadow::expr; | |
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); | |
// TODO(eric): support mixed type encoding, i.e. int index and float rhs. | |
CHECK_EQ(ret->type_flag_, mshadow::default_type_flag) | |
<< "one_hot_encode only support float32 as input/output"; | |
CHECK_EQ(rhs.type_flag_, mshadow::default_type_flag) | |
<< "one_hot_encode only support float32 as input/output"; | |
CHECK_EQ(index.type_flag_, mshadow::default_type_flag) | |
<< "one_hot_encode only support float32 as input/output"; | |
ret->get<xpu, 2, real_t>(s) = | |
one_hot_encode(index.get<xpu, 1, real_t>(s), | |
rhs.shape_[1]); | |
} | |
template<typename xpu, typename OP> | |
inline void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs, | |
TBlob *ret, RunContext ctx) { | |
using namespace mshadow::expr; | |
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); | |
// TODO(eric): support mixed type choose, i.e. int index and float rhs. | |
CHECK_EQ(ret->type_flag_, mshadow::default_type_flag) | |
<< "mat_choose_row_element only support float32 as input/output"; | |
CHECK_EQ(rhs.type_flag_, mshadow::default_type_flag) | |
<< "mat_choose_row_element only support float32 as input/output"; | |
CHECK_EQ(lhs.type_flag_, mshadow::default_type_flag) | |
<< "mat_choose_row_element only support float32 as input/output"; | |
ret->get<xpu, 1, real_t>(s) | |
= mat_choose_row_element(lhs.get<xpu, 2, real_t>(s), | |
rhs.get<xpu, 1, real_t>(s)); | |
} | |
template<typename xpu, typename OP> | |
inline void EvalMatFillRowElem_(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs, | |
TBlob *ret, RunContext ctx) { | |
using namespace mshadow::expr; | |
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); | |
ret->get<xpu, 2, real_t>(s) | |
= mat_fill_row_element(lhs.get<xpu, 2, real_t>(s), | |
mhs.get<xpu, 1, real_t>(s), | |
rhs.get<xpu, 1, real_t>(s)); | |
} | |
template<typename xpu, typename OP, bool reverse> | |
inline void EvalScalar_(const TBlob &lhs, const real_t &rhs, | |
TBlob *ret, RunContext ctx) { | |
using namespace mshadow::expr; | |
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); | |
CHECK_EQ(ret->type_flag_, lhs.type_flag_) | |
<< "Only support input/output with the same data type"; | |
if (reverse) { | |
MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { | |
ret->FlatTo2D<xpu, DType>(s) | |
= F<typename OP::mshadow_op>(scalar(DType(rhs)), lhs.FlatTo2D<xpu, DType>(s)); | |
}); | |
} else { | |
MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { | |
ret->FlatTo2D<xpu, DType>(s) | |
= F<typename OP::mshadow_op>(lhs.FlatTo2D<xpu, DType>(s), scalar(DType(rhs))); | |
}); | |
} | |
} | |
template<> | |
void EvalClip<DEVICE>(const TBlob &src, const real_t &a_min, const real_t &a_max, | |
TBlob *ret, RunContext ctx) { | |
typedef DEVICE xpu; | |
using namespace mshadow::expr; | |
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); | |
CHECK_EQ(ret->type_flag_, src.type_flag_) | |
<< "Only support input/output with the same data type"; | |
MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { | |
ret->FlatTo2D<xpu, DType>(s) | |
= F<ClipMax::mshadow_op>( | |
F<ClipMin::mshadow_op>(src.FlatTo2D<xpu, DType>(s), scalar(DType(a_min))), | |
scalar(DType(a_max))); | |
}); | |
} | |
template<> | |
void EvalRandom<DEVICE, UniformDistribution>( | |
const real_t &a, | |
const real_t &b, | |
const Resource &resource, | |
TBlob *ret, | |
RunContext ctx) { | |
typedef DEVICE xpu; | |
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); | |
switch (ret->type_flag_) { | |
case mshadow::kFloat32: | |
{ | |
mshadow::Random<xpu, float> *prnd = resource.get_random<xpu, float>(s); | |
mshadow::Tensor<xpu, 2, float> tmp = ret->FlatTo2D<xpu, float>(s); | |
prnd->SampleUniform(&tmp, float(a), float(b)); // NOLINT(*) | |
break; | |
} | |
case mshadow::kFloat64: | |
{ | |
mshadow::Random<xpu, double> *prnd = resource.get_random<xpu, double>(s); | |
mshadow::Tensor<xpu, 2, double> tmp = ret->FlatTo2D<xpu, double>(s); | |
prnd->SampleUniform(&tmp, double(a), double(b)); // NOLINT(*) | |
break; | |
} | |
default: | |
LOG(FATAL) << "Random only support float32 and float64"; | |
} | |
} | |
template<> | |
void EvalRandom<DEVICE, GaussianDistribution>( | |
const real_t &mu, | |
const real_t &sigma, | |
const Resource &resource, | |
TBlob *ret, | |
RunContext ctx) { | |
typedef DEVICE xpu; | |
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); | |
switch (ret->type_flag_) { | |
case mshadow::kFloat32: | |
{ | |
mshadow::Random<xpu, float> *prnd = resource.get_random<xpu, float>(s); | |
mshadow::Tensor<xpu, 2, float> tmp = ret->FlatTo2D<xpu, float>(s); | |
prnd->SampleGaussian(&tmp, float(mu), float(sigma)); // NOLINT(*) | |
break; | |
} | |
case mshadow::kFloat64: | |
{ | |
mshadow::Random<xpu, double> *prnd = resource.get_random<xpu, double>(s); | |
mshadow::Tensor<xpu, 2, double> tmp = ret->FlatTo2D<xpu, double>(s); | |
prnd->SampleGaussian(&tmp, double(mu), double(sigma)); // NOLINT(*) | |
break; | |
} | |
default: | |
LOG(FATAL) << "Random only support float32 and float64"; | |
} | |
} | |
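// Both EvalRandom specializations above only fill float32/float64 outputs;
// any other dtype aborts via LOG(FATAL). A minimal sketch of how a caller
// might invoke them (the names rsrc, out_blob and rctx are hypothetical):
//   // EvalRandom<DEVICE, UniformDistribution>(0.0f, 1.0f, rsrc, &out_blob, rctx);
//   // EvalRandom<DEVICE, GaussianDistribution>(0.0f, 1.0f, rsrc, &out_blob, rctx);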
template<> | |
void Eval<DEVICE>(const real_t &rhs, TBlob *ret, RunContext ctx) { | |
mshadow::Stream<DEVICE> *s = ctx.get_stream<DEVICE>(); | |
MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { | |
ret->FlatTo2D<DEVICE, DType>(s) = DType(rhs); | |
}); | |
} | |
template<> | |
void ElementwiseSum<DEVICE>(const std::vector<TBlob> source, | |
TBlob *dst, | |
RunContext ctx) { | |
typedef DEVICE xpu; | |
using namespace mshadow; | |
using namespace mshadow::expr; | |
Stream<xpu> *s = ctx.get_stream<xpu>(); | |
for (size_t i = 1; i < source.size(); ++i) { | |
CHECK_EQ(source[i].type_flag_, dst->type_flag_) | |
<< "Only support input/output with the same data type"; | |
} | |
MSHADOW_TYPE_SWITCH(dst->type_flag_, DType, { | |
Tensor<xpu, 2, DType> out = dst->FlatTo2D<xpu, DType>(s); | |
switch (source.size()) { | |
case 2: { | |
Tensor<xpu, 2, DType> in_0 = source[0].FlatTo2D<xpu, DType>(s); | |
Tensor<xpu, 2, DType> in_1 = source[1].FlatTo2D<xpu, DType>(s); | |
out = in_0 + in_1; | |
break; | |
} | |
case 3: { | |
Tensor<xpu, 2, DType> in_0 = source[0].FlatTo2D<xpu, DType>(s); | |
Tensor<xpu, 2, DType> in_1 = source[1].FlatTo2D<xpu, DType>(s); | |
Tensor<xpu, 2, DType> in_2 = source[2].FlatTo2D<xpu, DType>(s); | |
out = in_0 + in_1 + in_2; | |
break; | |
} | |
case 4: { | |
Tensor<xpu, 2, DType> in_0 = source[0].FlatTo2D<xpu, DType>(s); | |
Tensor<xpu, 2, DType> in_1 = source[1].FlatTo2D<xpu, DType>(s); | |
Tensor<xpu, 2, DType> in_2 = source[2].FlatTo2D<xpu, DType>(s); | |
Tensor<xpu, 2, DType> in_3 = source[3].FlatTo2D<xpu, DType>(s); | |
out = in_0 + in_1 + in_2 + in_3; | |
break; | |
} | |
default: { | |
Tensor<xpu, 2, DType> in_0 = source[0].FlatTo2D<xpu, DType>(s); | |
out = F<mshadow::op::identity>(in_0); | |
for (size_t i = 1; i < source.size(); ++i) { | |
out += source[i].FlatTo2D<xpu, DType>(s); | |
} | |
break; | |
} | |
} | |
}); | |
} | |
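// ElementwiseSum unrolls the common 2/3/4-input cases into a single fused
// mshadow expression; larger input counts fall back to a copy-then-accumulate
// loop. Sketch of the fallback for N inputs (illustrative, not from the source):
//   // out  = identity(in_0)
//   // out += in_1; out += in_2; ...; out += in_{N-1}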
template <> | |
void EvalBroadcast<DEVICE>(TBlob const& src, TBlob* ret, int size, RunContext ctx) { | |
typedef DEVICE xpu; | |
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); | |
mshadow::Tensor<xpu, 3> out = ret->get<xpu, 3, real_t>(s); | |
mshadow::Tensor<xpu, 2> in = src.get<xpu, 2, real_t>(s); | |
out = mshadow::expr::broadcast_with_axis(in, 0, size); | |
} | |
// declarations | |
DECL_BINARY(DEVICE, MatChooseRowElem, EvalMatChooseRowElem_) | |
DECL_TERNARY(DEVICE, MatFillRowElem, EvalMatFillRowElem_) | |
DECL_BINARY(DEVICE, OneHotEncode, EvalOneHot_) | |
DECL_BINARY(DEVICE, Plus, EvalBinary_) | |
DECL_BINARY(DEVICE, Minus, EvalBinary_) | |
DECL_BINARY(DEVICE, Mul, EvalBinary_) | |
DECL_BINARY(DEVICE, Div, EvalBinary_) | |
DECL_SCALAR(DEVICE, Plus, EvalScalar_, true) | |
DECL_SCALAR(DEVICE, Minus, EvalScalar_, true) | |
DECL_SCALAR(DEVICE, Mul, EvalScalar_, true) | |
DECL_SCALAR(DEVICE, Div, EvalScalar_, true) | |
// for reverse seq | |
DECL_SCALAR(DEVICE, Plus, EvalScalar_, false) | |
DECL_SCALAR(DEVICE, Minus, EvalScalar_, false) | |
DECL_SCALAR(DEVICE, Mul, EvalScalar_, false) | |
DECL_SCALAR(DEVICE, Div, EvalScalar_, false) | |
} // namespace ndarray | |
} // namespace mxnet | |
#endif // MXNET_NDARRAY_NDARRAY_FUNCTION_INL_H_ | |
//===== EXPANDED: ../src/ndarray/ndarray_function-inl.h ===== | |
namespace mxnet { | |
namespace ndarray { | |
template<> | |
void Copy<cpu, cpu>(const TBlob &from, TBlob *to, | |
Context from_ctx, Context to_ctx, | |
RunContext ctx) { | |
MSHADOW_TYPE_SWITCH(to->type_flag_, DType, { | |
if (to->type_flag_ == from.type_flag_) { | |
mshadow::Copy(to->FlatTo1D<cpu, DType>(), | |
from.FlatTo1D<cpu, DType>()); | |
} else { | |
MSHADOW_TYPE_SWITCH(from.type_flag_, SrcDType, { | |
to->FlatTo1D<cpu, DType>() = | |
mshadow::expr::tcast<DType>(from.FlatTo1D<cpu, SrcDType>()); | |
}) | |
} | |
}) | |
} | |
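// The CPU-to-CPU copy above has two paths: a plain mshadow::Copy when the
// source and destination dtypes match, and an element-wise tcast<DType>
// expression when they differ. Illustrative effect of the cast path
// (hypothetical values): copying a float64 blob [1.5, 2.5] into a float32
// blob yields [1.5f, 2.5f].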
} // namespace ndarray | |
} // namespace mxnet | |
//===== EXPANDED: ../src/ndarray/ndarray_function.cc ===== | |
//===== EXPANDING: ../src/ndarray/ndarray.cc ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file ndarray.cc | |
* \brief ndarray module of mxnet
*/ | |
//===== EXPANDING: ../include/mxnet/ndarray.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file ndarray.h | |
* \brief NDArray interface that handles array arithmetic.
*/ | |
#ifndef MXNET_NDARRAY_H_ | |
#define MXNET_NDARRAY_H_ | |
//===== EXPANDING: ../include/mxnet/storage.h ===== | |
/*! | |
* Copyright (c) 2015 by Contributors | |
* \file storage.h | |
* \brief Storage manager across multiple devices. | |
*/ | |
#ifndef MXNET_STORAGE_H_ | |
#define MXNET_STORAGE_H_ | |
namespace mxnet { | |
/*! | |
* \brief Storage manager across multiple devices. | |
*/ | |
class Storage { | |
public: | |
/*! | |
* \brief Storage handle. | |
*/ | |
struct Handle { | |
/*! | |
* \brief Pointer to the data. | |
*/ | |
void* dptr; | |
/*! | |
* \brief Size of the storage. | |
*/ | |
size_t size; | |
/*! | |
* \brief Context information about device and ID. | |
*/ | |
Context ctx; | |
}; | |
/*! | |
* \brief Allocate a new contiguous memory block of a given size.
* \param size Total size of memory in bytes. | |
* \param ctx Context information about the device and ID. | |
* \return Handle struct. | |
*/ | |
virtual Handle Alloc(size_t size, Context ctx) = 0; | |
/*! | |
* \brief Free storage. | |
* \param handle Handle struct.
*/ | |
virtual void Free(Handle handle) = 0; | |
/*! | |
* \brief Free storage directly, without putting it into the memory pool.
*  This may cause a synchronization of all previously launched device functions.
*
*  This function is suitable for container structures that need to grow
*  during the initial phase of iteration.
* | |
* \param handle Handle struct. | |
*/ | |
virtual void DirectFree(Handle handle) = 0; | |
/*! | |
* \brief Destructor. | |
*/ | |
virtual ~Storage() {} | |
/*! | |
* \return Storage singleton. | |
*/ | |
static Storage* Get(); | |
/*! | |
* \brief Get a shared pointer reference to the storage singleton.
*  Most users should not call this function.
*  It is called by another singleton X that requires
*  Storage to be destructed after X.
* | |
* \return A shared pointer to Storage singleton. | |
*/ | |
static std::shared_ptr<Storage> _GetSharedRef(); | |
}; // class Storage | |
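/*
 * Illustrative usage sketch for the Storage interface above (the 1 MB size
 * and the CPU context are arbitrary; Context::CPU() is assumed to be the
 * usual mxnet context factory):
 *
 *   Storage::Handle hdl = Storage::Get()->Alloc(1 << 20, Context::CPU());
 *   // ... use hdl.dptr as a raw buffer of hdl.size bytes ...
 *   Storage::Get()->Free(hdl);   // returns the block to the memory pool
 */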
} // namespace mxnet | |
#endif // MXNET_STORAGE_H_ | |
//===== EXPANDED: ../include/mxnet/storage.h ===== | |
#if MKL_EXPERIMENTAL == 1 | |
#endif | |
// check c++11 | |
#if DMLC_USE_CXX11 == 0 | |
#error "cxx11 was required for ndarray module" | |
#endif | |
namespace mxnet { | |
/*! | |
* \brief ndarray interface | |
*/ | |
class NDArray { | |
public: | |
/*! \brief default constructor */
NDArray() { | |
#if MKL_EXPERIMENTAL == 1 | |
Mkl_mem_ = MKLMemHolder::create(); | |
#endif | |
} | |
/*! | |
* \brief constructing a new dynamic NDArray | |
* \param shape the shape of array | |
* \param ctx context of NDArray | |
* \param delay_alloc whether to delay the allocation
* \param dtype data type of this ndarray | |
*/ | |
NDArray(const TShape &shape, Context ctx, | |
bool delay_alloc = false, int dtype = mshadow::default_type_flag) | |
: ptr_(std::make_shared<Chunk>(shape.Size(), ctx, delay_alloc, dtype)), | |
shape_(shape), offset_(0), dtype_(dtype) { | |
#if MKL_EXPERIMENTAL == 1 | |
Mkl_mem_ = std::make_shared<MKLMemHolder>(); | |
#endif | |
} | |
/*! | |
* \brief construct a static NDArray that shares data with a TBlob
*  Use with caution: allocate ONLY ONE NDArray for each TBlob, and
*  make sure the memory region remains available throughout the lifetime of the NDArray
* \param data the memory content of static data | |
* \param dev_id the device id this tensor sits at | |
*/ | |
NDArray(const TBlob &data, int dev_id) | |
: ptr_(std::make_shared<Chunk>(data, dev_id)), shape_(data.shape_), offset_(0), | |
dtype_(data.type_flag_) { | |
#if MKL_EXPERIMENTAL == 1 | |
Mkl_mem_ = std::make_shared<MKLMemHolder>(); | |
#endif | |
} | |
/*! | |
* \return the shape of current NDArray | |
*/ | |
inline const TShape &shape() const { | |
return shape_; | |
} | |
/*! | |
* \return the data TBlob | |
*/ | |
inline TBlob data() const { | |
TBlob res; | |
MSHADOW_TYPE_SWITCH(dtype_, DType, { | |
res = TBlob(static_cast<DType*>(ptr_->shandle.dptr) | |
+ offset_, shape_, ptr_->shandle.ctx.dev_mask()); | |
}); | |
#if MKL_EXPERIMENTAL == 1 | |
res.Mkl_mem_ = Mkl_mem_; | |
#endif | |
return res; | |
} | |
/*! | |
* \return a chunk of raw data in TBlob | |
*/ | |
inline TBlob raw_data(index_t offset, index_t length) const { | |
TBlob res; | |
TShape raw_shape(1); | |
raw_shape[0] = length; | |
MSHADOW_TYPE_SWITCH(dtype_, DType, { | |
res = TBlob(static_cast<DType*>(ptr_->shandle.dptr) | |
+ offset_ + offset, raw_shape, ptr_->shandle.ctx.dev_mask()); | |
}); | |
#if MKL_EXPERIMENTAL == 1 | |
res.Mkl_mem_ = Mkl_mem_; | |
#endif | |
return res; | |
} | |
/*! | |
* \return the context of the NDArray; only valid when the NDArray is not empty
*/ | |
inline Context ctx() const { | |
return ptr_->shandle.ctx; | |
} | |
/*! | |
* \return the data type of the NDArray; only valid when the NDArray is not empty
*/ | |
inline int dtype() const { | |
return dtype_; | |
} | |
/*! \return whether this ndarray is not initialized */ | |
inline bool is_none() const { | |
return ptr_.get() == nullptr; | |
} | |
/*! | |
* \brief Block until all the pending write operations with respect | |
* to current NDArray are finished, and read can be performed. | |
*/ | |
inline void WaitToRead() const { | |
if (is_none()) return; | |
Engine::Get()->WaitForVar(ptr_->var); | |
} | |
/*! | |
* \brief Block until all the pending read/write operations with respect | |
* to current NDArray are finished, and write can be performed. | |
*/ | |
inline void WaitToWrite() const { | |
if (is_none()) return; | |
/*! | |
* Push an empty mutable function to flush all preceding reads to the | |
* variable. | |
*/ | |
Engine::Get()->PushSync([](RunContext) {}, Context{}, {}, {ptr_->var}); | |
Engine::Get()->WaitForVar(ptr_->var); | |
} | |
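// Illustrative sketch of how the two waits above are typically paired with
// raw access (the array name `arr` is hypothetical):
//
//   arr.WaitToRead();                // all pending writes have completed
//   const TBlob blob = arr.data();   // safe to read blob.dptr_ now
//
//   arr.WaitToWrite();               // all pending reads and writes flushed
//   // safe to overwrite the underlying buffer in place now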
/*! \return the associated variable of the ndarray.*/ | |
inline Engine::VarHandle var() const { | |
return ptr_->var; | |
} | |
/*! | |
* \brief save the content into binary stream | |
* \param strm the output stream | |
*/ | |
void Save(dmlc::Stream *strm) const; | |
/*! | |
* \brief load the content from binary stream | |
* \param strm the input stream
* \return whether the load is successful | |
*/ | |
bool Load(dmlc::Stream *strm); | |
/*! | |
* \brief set all the elements in ndarray to be scalar | |
* \param scalar the scalar to set | |
* \return reference of self | |
*/ | |
NDArray &operator=(real_t scalar); | |
/*! | |
* \brief elementwise add to current space | |
* this mutates the current NDArray
* \param src the data to add | |
* \return reference of self | |
*/ | |
NDArray &operator+=(const NDArray &src); | |
/*! | |
* \brief elementwise add to current space | |
* this mutates the current NDArray
* \param src the data to add | |
* \return reference of self | |
*/ | |
NDArray &operator+=(const real_t &src); | |
/*! | |
* \brief elementwise subtract from current ndarray | |
* this mutates the current NDArray
* \param src the data to subtract | |
* \return reference of self | |
*/ | |
NDArray &operator-=(const NDArray &src); | |
/*! | |
* \brief elementwise subtract from current ndarray | |
* this mutates the current NDArray
* \param src the data to subtract | |
* \return reference of self | |
*/ | |
NDArray &operator-=(const real_t &src); | |
/*! | |
* \brief elementwise multiplication to current ndarray | |
* this mutates the current NDArray
* \param src the data to multiply by
* \return reference of self | |
*/ | |
NDArray &operator*=(const NDArray &src); | |
/*! | |
* \brief elementwise multiplication to current ndarray | |
* this mutates the current NDArray
* \param src the data to multiply by
* \return reference of self | |
*/ | |
NDArray &operator*=(const real_t &src); | |
/*! | |
* \brief elementwise division from current ndarray | |
* this mutates the current NDArray
* \param src the data to divide by
* \return reference of self | |
*/ | |
NDArray &operator/=(const NDArray &src); | |
/*! | |
* \brief elementwise division from current ndarray | |
* this mutates the current NDArray
* \param src the data to divide by
* \return reference of self | |
*/ | |
NDArray &operator/=(const real_t &src); | |
/*! | |
* \brief return transpose of current NDArray | |
* \return a new transposed NDArray | |
*/ | |
NDArray T() const; | |
/*! | |
* \brief return a new copy of this NDArray
* \param ctx the new context of this NDArray | |
* \return the new copy | |
*/ | |
NDArray Copy(Context ctx) const; | |
/*! | |
* \brief Do a synchronous copy from a contiguous CPU memory region.
*
*  This function will call WaitToWrite before the copy is performed.
*  This is useful for copying data from an existing memory region that is
*  not wrapped by an NDArray (thus its dependency is not tracked).
*
* \param data the data source to copy from.
* \param size the size of the source array, counted in elements of DType, not in raw bytes.
*/ | |
void SyncCopyFromCPU(const void *data, size_t size) const; | |
/*! | |
* \brief Do a synchronous copy to a contiguous CPU memory region.
*
*  This function will call WaitToRead before the copy is performed.
*  This is useful for copying data to an existing memory region that is
*  not wrapped by an NDArray (thus its dependency is not tracked).
*
* \param data the destination buffer to copy into.
* \param size the memory size to copy, counted in elements of DType, not in raw bytes.
*/ | |
void SyncCopyToCPU(void *data, size_t size) const; | |
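// Minimal sketch of the synchronous copy helpers declared above (assumes a
// float32 NDArray `arr` holding 4 elements; all names are hypothetical).
// Note that `size` counts elements of DType, not bytes:
//
//   std::vector<float> host(4, 0.0f);
//   arr.SyncCopyFromCPU(host.data(), host.size());  // host -> arr, waits to write
//   arr.SyncCopyToCPU(host.data(), host.size());    // arr -> host, waits to read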
/*! | |
* \brief Slice a NDArray | |
* \param begin begin index in first dim | |
* \param end end index in first dim | |
* \return sliced NDArray | |
*/ | |
inline NDArray Slice(index_t begin, index_t end) const { | |
NDArray ret = *this; | |
CHECK(!is_none()) << "NDArray is not initialized"; | |
CHECK_GE(shape_[0], end) << "Slice end index out of range"; | |
size_t length = shape_.ProdShape(1, shape_.ndim()); | |
ret.offset_ += begin * length; | |
ret.shape_[0] = end - begin; | |
return ret; | |
} | |
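// Sketch of Slice semantics (shapes are hypothetical): for an NDArray `arr`
// of shape (4, 3), arr.Slice(1, 3) yields a view of shape (2, 3) that shares
// memory with `arr`, with its offset advanced by 1 * 3 elements:
//
//   NDArray rows = arr.Slice(1, 3);   // rows.shape() == (2, 3), no copy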
/*! | |
* \brief Index a NDArray | |
* \param idx the index | |
* \return idx-th sub array NDArray | |
*/ | |
inline NDArray At(index_t idx) const { | |
NDArray ret = *this; | |
CHECK(!is_none()) << "NDArray is not initialized"; | |
CHECK_GT(shape_[0], idx) << "index out of range"; | |
size_t length = shape_.ProdShape(1, shape_.ndim()); | |
ret.offset_ += idx * length; | |
if (shape_.ndim() > 1) { | |
ret.shape_ = TShape(shape_.data()+1, shape_.data()+shape_.ndim()); | |
} else { | |
ret.shape_ = mshadow::Shape1(1); | |
} | |
return ret; | |
} | |
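// Sketch of At semantics (shapes are hypothetical): for an NDArray `arr` of
// shape (4, 3), arr.At(2) yields a view of shape (3,) over the third row;
// for a 1-D array the result has shape (1,):
//
//   NDArray row = arr.At(2);   // row.shape() == (3,), shares memory with arr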
/*! | |
* \brief Create a NDArray that shares memory with current one | |
* The new array's memory size must not exceed that of the current array.
* \param shape new shape | |
* \param dtype The data type. | |
* \return NDArray in new shape and type. | |
*/ | |
inline NDArray AsArray(const TShape &shape, int dtype) const { | |
CHECK_GE(shape_.Size() * mshadow::mshadow_sizeof(dtype_), | |
shape.Size() * mshadow::mshadow_sizeof(dtype)) | |
<< "NDArray.AsArray: target memory size is bigger"; | |
#if MKL_EXPERIMENTAL == 1 | |
if (Mkl_mem_ != nullptr) { | |
// convert prv to cpu | |
Mkl_mem_->check_and_prv_to_cpu(ptr_->shandle.dptr); | |
} | |
#endif | |
NDArray ret = *this; | |
ret.shape_ = shape; | |
ret.dtype_ = dtype; | |
return ret; | |
} | |
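// Sketch of AsArray semantics (shapes and dtypes are hypothetical): the new
// view may not need more bytes than the current array provides, e.g. a (4, 4)
// float32 array (64 bytes) can be viewed as an (8, 4) uint8 array (32 bytes):
//
//   TShape view_shape = mshadow::Shape2(8, 4);   // assumes Shape2 -> TShape assignment
//   NDArray bytes_view = arr.AsArray(view_shape, mshadow::kUint8);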
/*! | |
* \brief Get a reshaped NDArray
* \param shape new shape | |
* \return NDArray in new shape | |
*/ | |
inline NDArray Reshape(const TShape &shape) const { | |
CHECK_GE(shape_.Size(), shape.Size()) | |
<< "NDArray.Reshape: target shape size is different from current shape"; | |
NDArray ret = *this; | |
ret.shape_ = shape; | |
return ret; | |
} | |
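// Sketch of Reshape semantics (shapes are hypothetical): the target may not
// hold more elements than the source, and the result is a view sharing memory:
//
//   // arr.shape() == (2, 6)
//   TShape flat_shape = mshadow::Shape1(12);     // assumes Shape1 -> TShape assignment
//   NDArray flat = arr.Reshape(flat_shape);      // view with 12 elements, no copy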
/*! | |
* \brief Allocate the space if allocation was delayed.
*  This is an internal function used by the system; normal users should not call it.
*/ | |
inline void CheckAndAlloc() const { | |
ptr_->CheckAndAlloc(); | |
} | |
/*! | |
* \brief Save a list of NDArrays into the stream.
* \param fo The output stream.
* \param data the NDArrays to be saved.
* \param names the names of the NDArrays; optional, can be zero length.
*/ | |
static void Save(dmlc::Stream* fo, | |
const std::vector<NDArray>& data, | |
const std::vector<std::string>& names); | |
/*! | |
* \brief Load a list of NDArrays from the stream.