max_shared_ptr.hh

00001 // This file is a part of Aurelia.
00002 // Copyright (C) 2010  Valentin David
00003 // Copyright (C) 2010  University of Bergen
00004 //
00005 // This program is free software: you can redistribute it and/or modify
00006 // it under the terms of the GNU General Public License as published by
00007 // the Free Software Foundation, either version 3 of the License, or
00008 // (at your option) any later version.
00009 //
00010 // This program is distributed in the hope that it will be useful,
00011 // but WITHOUT ANY WARRANTY; without even the implied warranty of
00012 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00013 // GNU General Public License for more details.
00014 //
00015 // You should have received a copy of the GNU General Public License
00016 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
00017 
00018 #ifndef __MAX_SHARED_PTR_HH
00019 # define __MAX_SHARED_PTR_HH
00020 
00021 # include "../lockfree/unordered_set.hh"
00022 # include <unordered_set>
00023 # include "../memory/pool_allocator.hh"
00024 # include <memory>
00025 # include <atomic>
00026 
00027 # include "../type_traits/is_same_ellipsis.hh"
00028 # include <type_traits/is_constructible.hh>
00029 # include "deref.hh"
00030 # include "../hash/mhash.hh"
00031 # include <wrappers/members.hh>
00032 
00033 namespace aurelia {
00034 
/// Intrusive reference counter embedded in every max_shared_ptr set entry.
///
/// The counter must be writable through const references, because set
/// entries are only reachable through (const) set iterators; the member is
/// therefore `mutable` instead of being written through a const_cast.
struct virt_ptr {
public:
  /// Returns the atomic reference counter, writable even from const objects.
  std::atomic<long>& count() const {
    return _count;
  }

private:
  // mutable: count() hands out a writable reference from a const object.
  // Writing through a const_cast is undefined if the object itself is const.
  mutable std::atomic<long> _count;

public:
  /// Starts with a zero count.
  virt_ptr(): _count(0) {
  }

  ~virt_ptr() {
  }

protected:
  /// Copies the current value of other's counter.
  virt_ptr(const virt_ptr& other): _count(other._count.load()) {
  }

public:
  /// Copies (does not reset) the current value of other's counter.
  virt_ptr(virt_ptr&& other) noexcept: _count(other._count.load()) {
  }

  /// Exchanges the counter values of *this and other.
  virt_ptr& operator=(virt_ptr&& other) noexcept {
    long tmp = _count.load();
    _count.store(other._count.load());
    other._count.store(tmp);
    return *this;
  }
};
00069 
00070   DEF_MEMBER_WRAPPER(collect);
00071 
00072   template <typename T,
00073             bool = IF(catsfoot::is_callable<member_collect(T&)>)>
00074   struct collector {
00075     static bool collect(T&) {
00076       return false;
00077     }
00078   };
00079 
00080   template <typename T>
00081   struct collector<T, true> {
00082     static bool collect(T& t) {
00083       return t.collect();
00084     }
00085   };
00086 
00090   template <typename T, typename Hash = mhash<T>,
00091             typename EqnFunc = std::equal_to<T>,
00092             typename Alloc = pool_allocator<T>,
00093             bool lockfree = true>
00094   struct max_shared_ptr {
00095   protected:
00096     struct ptr;
00097     typedef typename
00098     std::conditional<lockfree,
00099                      lockfree::unordered_set<ptr, deref_op<Hash>,
00100                                              deref_op<EqnFunc>,
00101                                              Alloc>,
00102                      std::unordered_set<ptr, deref_op<Hash>,
00103                                         deref_op<EqnFunc>,
00104                                         Alloc> >::type set_t;
00105     struct ptr: public virt_ptr {
00106     private:
00107       typedef virt_ptr super;
00108       T core;
00109 
00110       explicit ptr(T&& core): core(std::move(core)) {
00111       }
00112 
00113     public:
00114       static void init(set_t& _set) {
00115         //Ensure all allocators have been used before call to atexit
00116         //so that collect will be called before any call to static
00117         //cleanup of allocators.
00118         _set.insert(ptr(T()));
00119         _set.clear();
00120         while (collector<set_t>::collect(_set)) ;
00121         atexit(collect);
00122       }
00123 
00124       const T& operator*() const {
00125         return core;
00126       }
00127 
00128       const T* get() const {
00129         return &core;
00130       }
00131 
00132       ptr(const ptr& other): core(other.core) {}
00133 
00134       ptr(ptr&& other): super(std::move(other)),
00135                         core(std::move(other.core))
00136       {}
00137 
00138       ptr& operator=(ptr&& other) {
00139         std::swap(*(super*)this, (super&)other);
00140         std::swap(core, other.core);
00141         return *this;
00142       }
00143 
00144       static const typename set_t::iterator create(T&& core) {
00145         ptr tmp(std::move(core));
00146         while (true) {
00147           std::pair<typename set_t::iterator, bool> r =
00148             set().insert(std::move(tmp));
00149           if (r.second) {
00150             (*r.first).count()++;
00151             return r.first;
00152           }
00153           long count = (*r.first).count().load();
00154           if (count == 0)
00155             //Either being built, or destroyed.
00156             //Note in this case that ~ptr has been already called, but
00157             //not the deallocator. For this reason ~ptr should never
00158             //alter _count.
00159             continue ;
00160           if (!(*r.first).count().compare_exchange_strong(count, count+1))
00161             continue ;
00162 
00163           return r.first;
00164         }
00165       }
00166     };
00167   private:
00168 
00169     typename set_t::iterator _ptr;
00170 
00171   public:
00172     static void collect() {
00173       while (collector<set_t>::collect(set())) ;
00174     }
00175 
00176   private:
00177     struct init_set {
00178       set_t _set;
00179       init_set() {
00180         ptr::init(_set);
00181       }
00182     };
00183 
00184   protected:
00185     static set_t& set() {
00186       static init_set _set;
00187       return _set._set;
00188     }
00189 
00190   protected:
00191     explicit max_shared_ptr(const typename set_t::iterator& other)
00192       : _ptr(other) {
00193       ++((*_ptr).count());
00194     }
00195 
00196   public:
00197     ~max_shared_ptr() {
00198       if (_ptr == set().end())
00199         return ;
00200       if (0 == --((*_ptr).count())) {
00201         set().erase(_ptr);
00202       }
00203       _ptr = set().end();
00204     }
00205 
00206     max_shared_ptr(T&& core): _ptr(ptr::create(std::move(core))) {
00207     }
00208 
00209     max_shared_ptr(const max_shared_ptr& other): _ptr(other._ptr) {
00210       if (_ptr != set().end())
00211         (*_ptr).count()++;
00212     }
00213 
00214     max_shared_ptr(max_shared_ptr&& other): _ptr(other._ptr) {
00215       other._ptr = set().end();
00216     }
00217 
00218     max_shared_ptr& operator=(max_shared_ptr&& other) {
00219       std::swap(_ptr, other._ptr);
00220       return *this;
00221     }
00222 
00223     max_shared_ptr& operator=(const max_shared_ptr& other) {
00224       return *this = max_shared_ptr(other);
00225     }
00226 
00227     template <typename... Args,
00228               ENABLE_IF_NOT(type_traits::
00229                             is_same_ellipsis<max_shared_ptr,
00230                                              Args...>),
00231               ENABLE_IF(catsfoot::is_constructible<T(Args...)>)>
00232     max_shared_ptr(Args... args...):
00233       _ptr(ptr::create(T(std::forward<Args>(args)...))) {
00234     }
00235 
00236     max_shared_ptr(): _ptr(set().end()) {
00237     }
00238 
00239     bool operator==(const max_shared_ptr& other) const {
00240       return _ptr == other._ptr;
00241     }
00242 
00243     bool operator!=(const max_shared_ptr& other) const {
00244       return _ptr != other._ptr;
00245     }
00246 
00247     const T* get() const {
00248       return (*_ptr).get();
00249     }
00250 
00251     const T* operator->() const {
00252       return get();
00253     }
00254 
00255     const T& operator*() const {
00256       return *get();
00257     }
00258   };
00259 
00260 }
00261 
00262 
00263 
00264 #endif