1
+ use std:: collections:: HashMap ;
2
+ use std:: collections:: HashSet ;
1
3
use std:: fmt;
2
4
use std:: fs;
3
5
use std:: io:: { Read , Write , BufWriter } ;
6
+ use dashmap:: DashMap ;
4
7
use rand:: prelude:: * ;
5
8
use indicatif:: { ProgressBar , ProgressStyle , ProgressIterator } ;
6
- use crossbeam:: scope;
7
9
use rayon:: prelude:: * ;
8
10
9
11
#[ derive( Debug , Clone ) ]
@@ -23,6 +25,8 @@ pub fn process(name: &str) -> Result<(), Box<dyn std::error::Error>> {
23
25
let mut file = String :: new ( ) ;
24
26
fs:: File :: open ( name) ?. read_to_string ( & mut file) ?;
25
27
println ! ( "Reading file..." ) ;
28
+ let bar = ProgressBar :: new_spinner ( ) ;
29
+ bar. enable_steady_tick ( 100 ) ;
26
30
for line in file. lines ( ) {
27
31
let mut parts = line. split_whitespace ( ) ;
28
32
let head = parts. next ( ) . ok_or ( "Wrong format!" ) ?;
@@ -34,68 +38,75 @@ pub fn process(name: &str) -> Result<(), Box<dyn std::error::Error>> {
34
38
relation : relation,
35
39
} ) ;
36
40
}
41
+ bar. finish ( ) ;
37
42
println ! ( "{} triples" , triples. len( ) ) ;
38
43
println ! ( "Building maps..." ) ;
39
- let bar = ProgressBar :: new ( triples. len ( ) as u64 ) ;
40
- bar. set_style (
41
- ProgressStyle :: default_bar ( )
42
- . template ( "[{elapsed} / {eta}]({per_sec}) {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}" )
43
- . progress_chars ( "##-" ) ,
44
- ) ;
45
- bar. set_draw_delta ( 100000 ) ;
46
- let num_cores = num_cpus:: get ( ) ;
47
- let mut maps = vec ! [ ] ;
48
- scope ( |s| {
49
- let mut threads = vec ! [ ] ;
50
- for chunk in triples. chunks ( triples. len ( ) / num_cores + 1 ) {
51
- threads. push ( s. spawn ( |_| {
52
- let mut map = std:: collections:: HashMap :: < & str , Vec < & Triple > > :: with_capacity ( chunk. len ( ) ) ;
53
- chunk. iter ( ) . for_each ( |triple| {
54
- // bar.inc(1);
55
- map. entry ( triple. head ) . or_insert ( Vec :: with_capacity ( 4 ) ) . push ( triple) ;
56
- } ) ;
57
- map
58
- } ) ) ;
59
- }
60
- threads. into_iter ( ) . for_each ( |thread| {
61
- let map = thread. join ( ) . unwrap ( ) ;
62
- maps. push ( map) ;
63
- bar. inc ( 1 ) ;
64
- } ) ;
65
- } ) . unwrap ( ) ;
66
- println ! ( "{} maps built with total size {}" , maps. len( ) , maps. iter( ) . map( |m| m. len( ) ) . sum:: <usize >( ) ) ;
44
+ let bar = ProgressBar :: new_spinner ( ) ;
45
+ bar. enable_steady_tick ( 100 ) ;
46
+ // let head_map = DashMap::new();
47
+ let relation_map: DashMap < & str , std:: collections:: HashMap < & str , Vec < & Triple > > > = DashMap :: new ( ) ;
48
+ triples. par_iter ( ) . for_each ( |triple| {
49
+ // bar.inc(1);
50
+ // head_map.entry(triple.head).or_insert(Vec::with_capacity(4)).push(triple);
51
+ relation_map. entry ( triple. head ) . or_default ( ) . entry ( triple. relation ) . or_default ( ) . push ( triple) ;
52
+ } ) ;
53
+ bar. finish ( ) ;
54
+ println ! ( "Map built with total size {}" , relation_map. len( ) ) ;
67
55
println ! ( "Finding neighbor..." ) ;
68
- let bar = ProgressBar :: new ( triples. len ( ) as u64 ) ;
69
- bar. set_style (
70
- ProgressStyle :: default_bar ( )
71
- . template ( "[{elapsed} / {eta}]({per_sec}) {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}" )
72
- . progress_chars ( "##-" ) ,
73
- ) ;
74
- bar. set_draw_delta ( 100000 ) ;
56
+ let bar = ProgressBar :: new_spinner ( ) ;
57
+ bar. enable_steady_tick ( 100 ) ;
75
58
let mut writer = BufWriter :: new ( fs:: File :: create ( String :: from ( name) + "_ptranse" ) ?) ;
76
59
77
- triples. par_iter ( ) . map ( |triple| -> ( & Triple , Option < & Triple > ) {
60
+ triples. par_iter ( ) . map ( |triple| -> ( & Triple , Option < ( & Triple , f64 ) > ) {
78
61
// bar.inc(1);
79
- let initial = thread_rng ( ) . gen_range ( 0 ..maps. len ( ) ) ;
80
- // iterator maps from id:
81
- for id in 0 ..num_cores {
82
- if let Some ( neighbors) = maps[ ( initial + id) % num_cores] . get ( triple. tail ) {
83
- return ( triple, Some ( neighbors. choose ( & mut thread_rng ( ) ) . unwrap ( ) ) ) ;
62
+ if let Some ( neighbors) = relation_map. get ( triple. tail ) {
63
+ let total_length = 0 ;
64
+ let mut map: HashMap < & str , HashMap < & str , HashSet < & Triple > > > = std:: collections:: HashMap :: new ( ) ;
65
+ relation_map. get ( triple. head ) . map ( |h| {
66
+ h. iter ( ) . for_each ( |r1| {
67
+ r1. 1 . iter ( ) . for_each ( |t1| {
68
+ relation_map. get ( t1. tail ) . map ( |t2| {
69
+ t2. iter ( ) . for_each ( |r2| {
70
+ let t = map. entry ( r1. 0 ) . or_default ( ) . entry ( r2. 0 ) . or_default ( ) ;
71
+ // t.extend(r2.1.iter());
72
+ } ) ;
73
+ } ) ;
74
+ } ) ;
75
+ } ) ;
76
+ // total_length += set.len();
77
+ } ) ;
78
+ let max = map. iter ( ) . filter_map ( |m| {
79
+ // find max in m.1
80
+ let mm = m. 1 . iter ( ) . max_by_key ( |d| d. 1 . len ( ) ) . map ( |d| ( d. 0 , d. 1 . len ( ) ) ) ;
81
+ if let Some ( ( r, l) ) = mm {
82
+ Some ( ( m. 0 , r, l) )
83
+ } else {
84
+ None
85
+ }
86
+ } ) . max_by_key ( |d| d. 2 ) ;
87
+ if let Some ( r1) = max {
88
+ // (triple, None)
89
+ let t = map. get ( r1. 0 ) . unwrap ( ) . get ( r1. 1 ) . unwrap ( ) ;
90
+ ( triple, Some ( t. iter ( ) . take ( 1 ) . next ( ) . unwrap ( ) , r1. 2 as f64 / total_length as f64 ) )
91
+ } else {
92
+ ( triple, None )
84
93
}
94
+ } else {
95
+ ( triple, None )
85
96
}
86
- ( triple, None )
87
97
// Some(map.get(triple.tail)?.choose(&mut rand::thread_rng())?.clone())
88
98
} ) . collect :: < Vec < _ > > ( ) . iter ( ) . for_each ( |triple| {
89
99
match triple {
90
100
( triple, Some ( neighbor) ) => {
91
- writer. write ( format ! ( "{}\t {}\t {}\n " , triple, neighbor. relation, neighbor. tail) . as_bytes ( ) ) . unwrap ( ) ;
101
+ writer. write ( format ! ( "{}\t {}\t {}\t {} \ n " , triple, neighbor. 0 . relation, neighbor. 0 . tail, neighbor . 1 ) . as_bytes ( ) ) . unwrap ( ) ;
92
102
} ,
93
103
( triple, None ) => {
94
104
writer. write ( format ! ( "{}\n " , triple) . as_bytes ( ) ) . unwrap ( ) ;
95
105
} ,
96
106
}
97
107
} ) ;
98
108
writer. flush ( ) ?;
109
+ bar. finish ( ) ;
99
110
println ! ( "Found neighbor!" ) ;
100
111
Ok ( ( ) )
101
112
}
0 commit comments