unordered_set.hh

// This file is a part of Aurelia.
// Copyright (C) 2010  Valentin David
// Copyright (C) 2010  University of Bergen
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program.  If not, see <http://www.gnu.org/licenses/>.
00017 
00018 #ifndef __UNORDERED_SET_HH
00019 # define __UNORDERED_SET_HH
00020 
# include <atomic>
# include <cassert>
# include <cstddef>
# include <functional>
# include <memory>
# include <new>
# include <type_traits>
# include <utility>

# include "sorted_forward_list.hh"
# include "vector.hh"
# include "../fast/reset_msb.hh"
# include "../fast/invert_bits.hh"
00025 
00026 namespace lockfree {
00027   template <typename T,
00028             typename Hash = std::hash<T>,
00029             typename Eq = std::equal_to<T>,
00030             typename Alloc = std::allocator<T>,
00031             typename Multi = std::false_type>
00032   class unordered_set {
00033     struct elts {
00034     private:
00035       size_t _invhash;
00036       typename std::aligned_storage<sizeof(T),
00037                                     std::alignment_of<T>::value>::type _elt;
00038 
00039       T& elt() {
00040         return *reinterpret_cast<T*>(&_elt);
00041       }
00042       const T& elt() const {
00043         return *reinterpret_cast<const T*>(&_elt);
00044       }
00045 
00046     public:
00047       size_t invhash() const {
00048         return _invhash;
00049       }
00050 
00051       const T& operator*() const {
00052         return elt();
00053       }
00054       elts(size_t _invhash, T&& elt): _invhash(_invhash) {
00055         assert(_invhash & 0x1);
00056         new (reinterpret_cast<T*>(&_elt)) T(std::move(elt));
00057       }
00058       elts(size_t _invhash): _invhash(_invhash) {
00059         assert(!(_invhash & 0x1));
00060       }
00061 
00062       elts() = delete;
00063       elts(const elts&) = delete;
00064       elts(elts&& other): _invhash(other._invhash) {
00065         if (other._invhash & 0x1) {
00066           new (reinterpret_cast<T*>(&_elt)) T(std::move(other.elt()));
00067         }
00068       }
00069 
00070       elts& operator=(const elts&) = delete;
00071       elts& operator=(elts&& other) {
00072         if (_invhash & 0x1) {
00073           if (other._invhash & 0x1)
00074             std::swap(*reinterpret_cast<T*>(&_elt), *reinterpret_cast<T*>(&other._elt));
00075           else
00076             new (reinterpret_cast<T*>(&other._elt)) T(std::move(elt()));
00077         }
00078         else {
00079           if (!(other._invhash & 0x1))
00080             new (reinterpret_cast<T*>(&_elt)) T(std::move(other.elt()));
00081         }
00082         std::swap(_invhash, other._invhash);
00083         return *this;
00084       }
00085 
00086       ~elts() {
00087         if (_invhash & 0x1)
00088           elt().~T();
00089       }
00090     };
00091 
00092     struct compare_elts {
00093       bool operator()(const elts& a, const elts& b) const {
00094         return a.invhash() < b.invhash();
00095       }
00096     };
00097 
00098     struct eq_elts {
00099       Eq eq;
00100       eq_elts(Eq eq): eq(eq) {}
00101 
00102       bool operator()(const elts& a, const elts& b) const {
00103         if (a.invhash() != b.invhash())
00104           return false;
00105         if (!(a.invhash()&0x1))
00106           return false;
00107         return eq(*a, *b);
00108       }
00109     };
00110 
00111     typedef typename Alloc::template rebind<elts>::other elt_alloc_t;
00112     typedef sorted_forward_list<elts,
00113                                 eq_elts, compare_elts,
00114                                 elt_alloc_t, Multi> list_t;
00115 
00116     typedef typename Alloc::template rebind<typename list_t::const_iterator>
00117     ::other it_alloc_t;
00118 
00119     typedef vector<typename list_t::const_iterator,
00120                    1024u, it_alloc_t> entries_t;
00121 
00122     double max_load_factor;
00123     Eq eq;
00124     Hash hash;
00125 
00126     list_t list;
00127     entries_t entries;
00128 
00129     std::atomic<size_t> _size;
00130 
00131   public:
00132     unordered_set(Hash hash = Hash(), Eq eq = Eq(), Alloc alloc = Alloc()):
00133       max_load_factor(.3),
00134       eq(eq), hash(hash), list(eq_elts(eq), compare_elts(),
00135                                alloc), entries(list.end(), alloc) {
00136       entries.resize(1);
00137       entries[0x0] = list.insert(elts(0x0)).first;
00138       _size.store(0);
00139     }
00140 
00141   private:
00142     typename list_t::const_iterator _get_bucket(size_t b) const {
00143       typename list_t::const_iterator i = entries[b];
00144       if (i == list.end()) {
00145         typename list_t::const_iterator parent =
00146           _get_bucket(fast::reset_msb(b));
00147         i = const_cast<unordered_set*>(this)->list
00148           .insert_after(elts(get_invert(b)&~0x1), parent, true).first;
00149         const_cast<unordered_set*>(this)->entries[b] = i;
00150       }
00151       return i;
00152     }
00153 
00154     size_t get_invert(size_t h) const {
00155       fast::invert_bits(h);
00156       return h | 0x1;
00157     }
00158 
00159     typename list_t::const_iterator get_bucket(size_t h) const {
00160       return _get_bucket(h%entries.size());
00161     }
00162 
00163   public:
00164     ~unordered_set() {
00165       clear();
00166       while (collect()) ;
00167     }
00168 
00169     struct iterator {
00170     private:
00171       typename list_t::const_iterator i;
00172 
00173     public:
00174       iterator(const typename list_t::const_iterator& i): i(i) {
00175         while ((this->i.get_node() == NULL)?false
00176                :(!((*(this->i)).invhash()&0x1)))
00177           ++(this->i);
00178       }
00179       iterator(const iterator&) = default;
00180       iterator() = delete;
00181 
00182       const T& operator*() const {
00183         return **i;
00184       }
00185 
00186       operator const typename list_t::const_iterator&() const {
00187         return i;
00188       }
00189 
00190       iterator& operator++() {
00191         do
00192           ++i;
00193         while ((this->i.get_node() == NULL)?false
00194                :(!((*(this->i)).invhash()&0x1)));
00195         return *this;
00196       }
00197 
00198       iterator operator++(int) {
00199         iterator tmp(*this);
00200         ++(*this);
00201         return tmp;
00202       }
00203 
00204       bool operator==(const iterator& other) const {
00205         return i == other.i;
00206       }
00207       bool operator!=(const iterator& other) const {
00208         return i != other.i;
00209       }
00210     };
00211 
00212     iterator end() const {
00213       return list.end();
00214     }
00215 
00216     iterator begin() const {
00217       return list.begin();
00218     }
00219 
00220     void clear() {
00221       for (iterator i = begin(); i != end();) {
00222         iterator tmp = i++;
00223         erase(tmp);
00224       }
00225     }
00226 
00227     iterator find(const T& elt) const {
00228       size_t h = hash(elt);
00229       typename list_t::const_iterator i = get_bucket(h);
00230       size_t inv = get_invert(h);
00231       size_t f = 0;
00232       for (;
00233            i != list.end(); ++i) {
00234         if (inv < (*i).invhash())
00235           return end();
00236         if (inv == (*i).invhash())
00237           if (eq(**i, elt))
00238             return iterator(i);
00239         ++f;
00240       }
00241       return end();
00242     }
00243 
00244     std::pair<iterator,bool> insert(T&& elt) {
00245       size_t h = hash(elt);
00246       typename list_t::const_iterator i = get_bucket(h);
00247       elts toinsert(get_invert(h), std::move(elt));
00248       std::pair<typename list_t::const_iterator,bool> ret =
00249         list.insert_after(std::move(toinsert), i);
00250       if (ret.second) {
00251         size_t nelts = ++_size;
00252         size_t size = entries.size();
00253         if ((((double)nelts) / size) > max_load_factor) {
00254           entries.resize(2*size);
00255         }
00256       }
00257       else {
00258         std::swap(*const_cast<T*>(&*toinsert), elt);
00259       }
00260       return std::pair<iterator,bool>(iterator(ret.first),ret.second);
00261     }
00262 
00263     size_t size() const {
00264       return _size.load();
00265     }
00266 
00267     bool collect() {
00268       entries.collect();
00269       return list.collect();
00270     }
00271 
00272     void erase(const T& t) {
00273       static_assert(!Multi::value, "Cannot erase elements on multiset");
00274       erase(find(t));
00275     }
00276 
00277     void erase(const iterator& i) {
00278       if (i == end())
00279         return ;
00280       typename list_t::const_iterator j = i;
00281       size_t h = (*j).invhash();
00282       fast::invert_bits(h);
00283       typename list_t::const_iterator b = get_bucket(h);
00284       if (list.erase(j, b))
00285         --_size;
00286     }
00287   };
00288 }
00289 
00290 #endif